/* Copyright (C) 2005-2008 Damien Stehle.
Copyright (C) 2007 David Cade.
Copyright (C) 2008 Xavier Pujol.

This file is part of the fplll Library.

The fplll Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

The fplll Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the fplll Library; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA. */

#include "main.h"

/* Reads the input matrix (and, for the CVP action, a target vector) through
   a Lexer, then dispatches to the routine matching o.action.
   The Lexer is built on o.inputFile when one was given, and is
   default-constructed otherwise. */
template<class ZT>
void run_action(Options& o) {
  ZZ_mat<ZT> mat;
  Vect< Z_NR<ZT> > cvpTarget;

  // Parse everything first, so that the lexer can be released before the
  // computation starts
  Lexer* input = o.inputFile ? new Lexer(o.inputFile) : new Lexer();
  *input >> mat;
  if (o.action == ACVP)
    *input >> cvpTarget;
  delete input;

  if (o.action == ALLL) {
    lll(o, mat);
  }
  else if (o.action == ASVP) {
    svp(o, mat);
  }
  else if (o.action == ACVP) {
    cvp(o, mat, cvpTarget);
  }
  else if (o.action == AHKZ) {
    hkz(o, mat);
  }
  else {
    ABORT_MSG("Unimplemented action");
  }
}

int main(int argc, char** argv) {
  Options o;
  read_options(argc, argv, o);
  switch (o.z) {
    case ZMPZ:    run_action<mpz_t>(o); break;
#ifndef FAST_BUILD
    case ZINT:    run_action<long int>(o); break;
    case ZDOUBLE: run_action<double>(o); break;
#endif
    default:      ABORT_MSG("Compiled without support for this integer type");
  }
  return 0;
}

/* LLL-reduction (all versions)
   ============================ */

/* Generic version of the wrapper entry point. It only reports an error
   through ABORT_MSG: the actual wrapper is provided by the mpz_t
   specialization. */
template<class ZT>
void lll_wrapper(Options& o, ZZ_mat<ZT>& m)
{
  ABORT_MSG("mpz_t required for the wrapper");
}

/* mpz_t version of the wrapper entry point: runs wrapper(...).LLL() on m
   with the eta and delta values from the options. A floating-point type
   given with -f is ignored, and a notice is written to cerr in that case. */
template<>
void lll_wrapper(Options& o, ZZ_mat<mpz_t>& m) {
  const bool floatTypeGiven = (o.f != FUNSPECIFIED);
  if (floatTypeGiven) {
    cerr << "The -f option is not taken into account for the wrapper.\n";
  }

  wrapper(&m, 0, o.eta, o.delta).LLL();
}

/* Runs the LLL variant selected by o.m on m, with ZT as integer type and FT
   as floating-point type. The "fast" variants are always instantiated with
   double and require o.f == FDOUBLE. The early-reduction variants are
   compiled only when FAST_BUILD is not defined. */
template<class ZT, class FT>
void lll_zf(Options& o, ZZ_mat<ZT>& m) {
  switch (o.m) {
    case MPROVED: {
      proved<ZT, FT>(&m, o.pr, o.eta, o.delta).LLL();
      break;
    }
    case MHEURISTIC: {
      heuristic<ZT, FT>(&m, o.pr, o.eta, o.delta, o.siegel).LLL();
      break;
    }
    case MFAST: {
      // FT is not used by this variant, hence the explicit check on o.f
      STD_CHECK(o.f == FDOUBLE, "Double required");
      fast<ZT, double>(&m, o.pr, o.eta, o.delta).LLL();
      break;
    }
    case MWRAPPER: {
      lll_wrapper(o, m);
      break;
    }
#ifndef FAST_BUILD
    case MHEUREARLY: {
      heuristic_early_red<ZT, FT>(&m, o.pr, o.eta, o.delta).LLL();
      break;
    }
    case MFASTEARLY: {
      // Same restriction as MFAST
      STD_CHECK(o.f == FDOUBLE, "Double required");
      fast_early_red<ZT, double>(&m, o.pr, o.eta, o.delta).LLL();
      break;
    }
#endif
    default: {
      ABORT_MSG("Compiled without support for this method");
      break;
    }
  }
}

/* LLL-reduces m in place, choosing the floating-point type from o.f
   (dpe_t when none was specified), and prints the matrix afterwards when
   print is true. */
template<class ZT>
void lll(Options& o, ZZ_mat<ZT>& m, bool print) {
  switch (o.f) {
    case FMPFR: {
      lll_zf<ZT, mpfr_t>(o, m);
      break;
    }
    case FDOUBLE: {
      lll_zf<ZT, double>(o, m);
      break;
    }
    case FDPE:
    case FUNSPECIFIED: {
      // dpe_t is also the type used when -f was not given
      lll_zf<ZT, dpe_t>(o, m);
      break;
    }
  }

  if (!print) return;
  m.print();
}

/* Other actions (z=mpz_t only)
   ============================ */

/* Generic version of the HKZ entry point. It only reports an error through
   ABORT_MSG: HKZ-reduction is provided by the mpz_t specialization. */
template<class ZT>
void hkz(Options& o, ZZ_mat<ZT>& m, bool print)
{
  ABORT_MSG("mpz required for HKZ");
}

/* HKZ-reduction (mpz_t specialization)
   Note: since we only force |mu_i,j| <= 0.51, the solution
   is not unique even for a generic matrix

   Outline, for a matrix m with d rows and n columns:
     1. LLL-reduce m.
     2. For each i in [0, d-2]: call enumerate() with start index i to get
        coefficients for rows i..d-1, insert the corresponding combination
        of those rows as a new row at position i, LLL-reduce the (d+1)-row
        matrix, check that its first row became zero and drop that row.
     3. Print m if print is true.
   m is modified in place. */

template<> void hkz(Options& o, ZZ_mat<mpz_t>& m, bool print) {
  int d = m.GetNumRows(), n = m.GetNumCols();
  FloatMatrix fMatrix(d, n), mu(d, d);
  FloatVect rdiag;
  IntVect solCoord(d);

  // Initial reduction; printing is deferred to the end of this function
  lll<mpz_t>(o, m, false);

  for (int i = 0; i < d - 1; i++) {
    // Computes the coordinates of a shortest vector
    // (floating-point copy of the current matrix, then Gram-Schmidt data)
    for (int j = 0; j < d; j++)
      for (int k = 0; k < n; k++)
        fMatrix(j, k).set_z(m(j, k));
    gramSchmidt(fMatrix, mu, rdiag);
    Evaluator evaluator(mu, rdiag);
    // NOTE(review): presumably rdiag[i] is the initial enumeration bound and
    // the last argument restricts the search to rows i..d-1 -- confirm
    // against the declaration of enumerate()
    enumerate(mu, rdiag, rdiag[i], evaluator, FloatVect(), FloatVect(), i);
    INTERNAL_CHECK(!evaluator.solCoord.empty(), "(hkz) KFP failure");
    // evaluator.solCoord is indexed from 0, solCoord from row i
    for (int j = i; j < d; j++)
      solCoord[j].set_f(evaluator.solCoord[j - i]);

    // Adds this vector to the matrix (at row i)
    // The new row is first built at index d as sum_{j>=i} solCoord[j]*row j,
    // then moved up to index i by successive swaps of adjacent rows
    m.SetNumRows(d + 1);
    for (int k = 0; k < n; k++) {
      m(d, k) = 0;
      for (int j = i; j < d; j++)
        m(d, k).addmul(m(j, k), solCoord[j]);
    }
    for (int j = d; j > i; j--)
      swap(m.GetVec(j), m.GetVec(j - 1));

    // LLL should find a linear dependency
    // (the d+1 rows are dependent, and the zero row is expected at index 0)
    lll<mpz_t>(o, m, false);
    for (int j = 0; j < n; j++)
      INTERNAL_CHECK(m(0, j) == 0, "(hkz) First vector is non-zero after LLL");

    // Moves the first vector to the last row
    // so that shrinking back to d rows discards exactly the zero row
    for (int j = 0; j < d; j++)
      swap(m.GetVec(j), m.GetVec(j + 1));
    m.SetNumRows(d);
  }
  if (print) m.print();
}

/* Generic version of the SVP entry point. It only reports an error through
   ABORT_MSG: the SVP solver is provided by the mpz_t specialization. */
template<class ZT>
void svp(Options& o, ZZ_mat<ZT>& m)
{
  ABORT_MSG("mpz required for SVP");
}

/* mpz_t version of the SVP entry point: LLL-reduces m (without printing
   it), calls Solver::solveSVP on the reduced matrix and writes the vector
   filled in by the solver to cout. */
template<>
void svp(Options& o, ZZ_mat<mpz_t>& m) {
  IntVect coords;
  Solver engine;

  // Solver configuration
  engine.verbose = 0;
  engine.precision = AUTOMATIC_PRECISION;
  engine.evaluatorType = evalSmart;

  lll<mpz_t>(o, m, false);
  engine.solveSVP(m, coords);

  cout << coords << endl;
}

/* Generic version of the CVP entry point. It only reports an error through
   ABORT_MSG: the CVP solver is provided by the mpz_t specialization. */
template<class ZT>
void cvp(Options& o, ZZ_mat<ZT>& m, Vect< Z_NR<ZT> >& target)
{
  ABORT_MSG("mpz required for CVP");
}

/* mpz_t version of the CVP entry point: LLL-reduces m (without printing
   it), calls Solver::solveCVP with the given target and writes the vector
   filled in by the solver to cout. */
template<>
void cvp(Options& o, ZZ_mat<mpz_t>& m, Vect< Z_NR<mpz_t> >& target) {
  IntVect coords;
  Solver engine;

  // Solver configuration
  engine.verbose = 0;
  engine.evaluatorType = evalFast;

  lll<mpz_t>(o, m, false);
  engine.solveCVP(m, target, coords);

  cout << coords << endl;
}

/* Command line parsing
   ==================== */

/* Parses the command line into o.

   argc, argv: the arguments received by main.
   o:          option structure to fill in. Only o.inputFile is reset here
               (to NULL); every other field is written only when the
               matching switch is present.

   Recognized switches: -a, -c, -d, -e, -f, -l, -m, -p, -r, -z and --help.
   --help prints the usage text and exits with status 0. Any other argument
   starting with '-' is reported through ABORT_MSG. At most one argument not
   starting with '-' is accepted; it is stored in o.inputFile.
   Every switch taking a value checks that the value is present. */
void read_options(int argc, char** argv, Options& o) {
  o.inputFile = NULL;

  for (int ac = 1; ac < argc; ac++) {
    if (strcmp(argv[ac], "-a") == 0) {
      ++ac;
      // argv[argc] is a null pointer: it must not be handed to strcmp
      STD_CHECK(ac < argc, "Missing value after -a switch");
      if (strcmp(argv[ac], "lll") == 0)
        o.action = ALLL;
      else if (strcmp(argv[ac], "hkz") == 0)
        o.action = AHKZ;
      else if (strcmp(argv[ac], "bkz") == 0)
        o.action = ABKZ;
      else if (strcmp(argv[ac], "svp") == 0)
        o.action = ASVP;
      else if (strcmp(argv[ac], "cvp") == 0)
        o.action = ACVP;
      else
        ABORT_MSG("Parse error in -a switch : "
                  << "lll, hkz, bkz, svp or cvp expected");
    }
    else if (strcmp(argv[ac], "-c") == 0) {
      ++ac;
      STD_CHECK(ac < argc, "Missing value after -c switch");
      o.c=atoi(argv[ac]);
    }
    else if (strcmp(argv[ac], "-d") == 0) {
      ++ac;
      STD_CHECK(ac < argc, "Missing value after -d switch");
      o.delta=atof(argv[ac]);
    }
    else if (strcmp(argv[ac], "-e") == 0) {
      ++ac;
      STD_CHECK(ac < argc, "Missing value after -e switch");
      o.eta=atof(argv[ac]);
    }
    else if (strcmp(argv[ac], "-f") == 0) {
      ++ac;
      STD_CHECK(ac < argc, "Missing value after -f switch");
      if (strcmp("mpfr",argv[ac])==0)
        o.f=FMPFR;
      else if (strcmp("dpe",argv[ac])==0)
        o.f=FDPE;
      else if (strcmp("double",argv[ac])==0)
        o.f=FDOUBLE;
      else
        ABORT_MSG("Parse error in -f switch : mpfr, dpe or double expected");
    }
    else if (strcmp(argv[ac], "-l") == 0) {
      ++ac;
      STD_CHECK(ac < argc, "Missing value after -l switch");
      // The switch is expressed in terms of "lovasz"; the stored flag is its
      // negation
      o.siegel = !atoi(argv[ac]);
    }
    else if (strcmp(argv[ac], "-m") == 0) {
      ++ac;
      STD_CHECK(ac < argc, "Missing value after -m switch");
      if (strcmp("proved",argv[ac])==0)
        o.m=MPROVED;
      else if (strcmp("heuristic",argv[ac])==0)
        o.m=MHEURISTIC;
      else if (strcmp("fast",argv[ac])==0)
        o.m=MFAST;
      else if (strcmp("fastearly",argv[ac])==0)
        o.m=MFASTEARLY;
      else if (strcmp("heuristicearly",argv[ac])==0)
        o.m=MHEUREARLY;
      else if (strcmp("wrapper",argv[ac])==0)
        o.m=MWRAPPER;
      else
        ABORT_MSG("Parse error in -m switch : proved, heuristic, fast, "
                  << "heuristicearly, fastearly, or wrapper expected");
    }
    else if (strcmp(argv[ac], "-p") == 0) {
      ++ac;
      STD_CHECK(ac < argc, "Missing value after -p switch");
      o.pr = atoi(argv[ac]);
    }
    else if (strcmp(argv[ac], "-r") == 0) {
      ++ac;
      STD_CHECK(ac < argc, "Missing value after -r switch");
      o.r = atoi(argv[ac]);
    }
    else if (strcmp(argv[ac], "-z") == 0) {
      ++ac;
      STD_CHECK(ac < argc, "Missing value after -z switch");
      if (strcmp("mpz",argv[ac])==0)
        o.z=ZMPZ;
      else if (strcmp("int",argv[ac])==0)
        o.z=ZINT;
      else if (strcmp("double",argv[ac])==0)
        o.z=ZDOUBLE;
      else
        ABORT_MSG("Parse error in -z switch : int, double or mpz expected");
    }
    else if (strcmp(argv[ac], "--help") == 0) {
      cout << "Usage: " << argv[0] << " [options] [file]\n"
           << "List of options:\n"
           << "  -a [lll|hkz|svp|cvp]\n"
           << "       lll = LLL-reduce the input matrix (default)\n"
           << "       hkz = HKZ-reduce the input matrix\n"
           << "       svp = solve SVP\n"
           << "       cvp = solve CVP\n"
           << "  -m [proved|heuristic|fast|heuristicearly|fastearly|wrapper]\n"
           << "       LLL version (default: wrapper)\n"
           << "  -z [int|mpz|double]\n"
           << "       Integer type in LLL (default: mpz)\n"
           << "  -f [mpfr|dpe|double]\n"
           << "       Floating-point type in LLL (default: dpe)\n"
           << "  -p <precision>\n"
           << "       Floating-point precision (only with -f mpfr)\n"
           << "  -d <delta> (default=0.99)\n"
           << "  -e <eta> (default=0.51)\n"
           << "  -l <lovasz>\n";
      exit(0);
    }
    else if (argv[ac][0] == '-') {
      ABORT_MSG("Invalid option '" << argv[ac] << "'. Try '" << argv[0]
        << " --help' for more information.");
    }
    else {
      // Positional argument: the input file, which may be given only once
      if (o.inputFile) ABORT_MSG("Too many input files");
      o.inputFile = argv[ac];
    }
  }
}
