/* Copyright (C) 2008 Xavier Pujol.

This file is part of the fplll Library.

The fplll Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

The fplll Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the fplll Library; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA. */

#include "enumerate.h"

// Floating-point type used internally by the enumeration kernel.
typedef double real;
// Largest dimension supported by the fixed-size work arrays
// (enumerate() checks d <= DMAX before using them).
static const int DMAX = 100;
// One fixed-size work vector, indexed by enumeration level.
typedef real RVECT[DMAX];
// Transposed copy of mu for the current call: enumerate() stores
// mu(j + first, i + first) in mut[i][j].
// NOTE(review): this is shared static scratch, so concurrent calls to
// enumerate() would race on it.
static RVECT mut[DMAX];

/* Moves the enumeration one level up (k + 1) and advances x[k] to its next
   candidate value.

   - Levels below kMax zig-zag around their center: ddx flips sign each time
     and dx grows so that x visits c, c-+1, c+-1, c-+2, ...
   - At or above kMax, x[k] is only incremented (positive direction), and
     kMax is raised to k.

   Returns false, leaving k == kEnd-or-above, when there is no level left to
   move to (k reached kEnd); true otherwise. */
static inline bool nextPosUp(RVECT x, RVECT dx, RVECT ddx,
  int& k, int& kMax, int kEnd) {

  ++k;
  if (k >= kMax) {
    // Highest level touched so far: only step in the positive direction.
    if (k >= kEnd) return false;
    kMax = k;
    x[k] += 1;
    return true;
  }

  // Below kMax: next zig-zag step around the center.
  ddx[k] = -ddx[k];
  dx[k] = ddx[k] - dx[k];
  x[k] += dx[k];
  return true;
}

/* Runs the enumeration from level k until the next solution or until the
   search space is exhausted.

   At each step the partial distance dist[k] + rdiag[k] * (center[k] - x[k])^2
   is compared with maxDist:
   - within the bound: descend to level k-1, compute its center from
     centerPartSum and the coordinates above it (levels < kEnd only), and
     start at the rounded center;
   - outside the bound: let nextPosUp pick the next candidate one level up.

   Returns true when a level below 0 is reached. x[0..kEnd-1] then holds the
   solution, newMaxDist its distance and newKMax the current kMax. Returns
   false when k starts at or above kEnd, or when nextPosUp reports the end
   of the enumeration. */
static bool enumerateLoop(RVECT rdiag,
  RVECT center, RVECT dist, RVECT centerPartSum, RVECT x, RVECT dx, RVECT ddx,
  real maxDist, int k, int kEnd, int kMax,
  real& newMaxDist, int& newKMax) {

  if (k >= kEnd) return false;

  for (;;) {
    const real offset = center[k] - x[k];
    const real partialDist = dist[k] + offset * offset * rdiag[k];

    // Written as !(a <= b) rather than a > b so a NaN distance takes the
    // same branch as in the original formulation.
    if (!(partialDist <= maxDist)) {
      if (!nextPosUp(x, dx, ddx, k, kMax, kEnd))
        return false; // End of the enumeration
      continue;
    }

    --k;
    if (k < 0) {
      newMaxDist = partialDist;
      newKMax = kMax;
      return true; // New solution found
    }

    // Center of level k; the terms are subtracted from the top level down,
    // which keeps the floating-point summation order fixed.
    real c = centerPartSum[k];
    int j = kEnd - 1;
    while (j > k) {
      c -= x[j] * mut[k][j];
      --j;
    }

    center[k] = c;
    dist[k] = partialDist;
    x[k] = rint(c);
    dx[k] = real(0.0);
    ddx[k] = (c < x[k]) ? real(1.0) : real(-1.0);
  }
}

/* Enumerates the integer coordinate vectors x over levels [first, last)
   whose distance  sum_k rdiag[k] * (center[k] - x[k])^2  does not exceed
   fMaxDist. Each level's values are visited in zig-zag order around its
   center (see nextPosUp), and each vector found is passed to
   evaluator.evalSol.

   mu          - all entries mu(j, i) for i, j in [first, last) are copied
                 (transposed) into mut, and the centers use those with j > i.
                 Presumably these are the Gram-Schmidt coefficients; confirm.
   frdiag      - weight of each level, read from index first on.
                 Presumably the squared Gram-Schmidt norms; confirm.
   fMaxDist    - initial distance bound. It is rewritten with maxDist after
                 every solution.
   evaluator   - evalSol(coords, dist, maxDist) is called once per solution.
                 maxDist is passed as an lvalue and then copied back into
                 fMaxDist, so presumably evalSol may tighten the bound.
   targetCoord - if empty, solves SVP and all centers start at 0. Otherwise
                 these are CVP target coordinates, indexed from first.
   subTree     - fixed values of the top subTree.size() coordinates. Only
                 levels below kEnd = d - subTree.size() are enumerated.
   first, last - level range. last == -1 means mu.GetNumRows(), and
                 last - first must not exceed DMAX.

   In SVP mode with an all-zero subTree, the zero vector is skipped (x[0]
   starts at 1). The top non-zero coordinate is only ever incremented (kMax
   logic in nextPosUp), so only one vector of each pair +-x is visited. */
void enumerate(const FloatMatrix& mu, const FloatVect& frdiag, Float& fMaxDist,
  Evaluator& evaluator, const FloatVect& targetCoord,
  const FloatVect& subTree, int first, int last) {

  if (last == -1) last = mu.GetNumRows();
  int d = last - first; // Number of levels handled by this call

  bool solveSVP;      // true->SVP, false->CVP
  bool svpBeginning;  // true->SVP and all coordinates in subTree are zero
  int k;              // Current level in the enumeration
  int kEnd;           // The algorithm stops when k reaches kEnd
  int kMax = 0;       // Levels < kMax zig-zag; at k >= kMax x[k] only grows (<= kEnd)
  RVECT rdiag, x, dx, ddx, dist, center, centerPartSum;
  real newX, newDist = real(0.0), maxDist, newMaxDist;
  FloatVect fX(d);

  solveSVP = targetCoord.empty();
  INTERNAL_CHECK(d <= DMAX, "(kfp) Dimension is too high");
  svpBeginning = solveSVP;
  kEnd = d - subTree.size(); // Levels kEnd..d-1 are fixed by subTree

  // Float->real conversion and transposition of mu:
  // mut[i][j] = mu(j + first, i + first).
  // centerPartSum starts as the target (0 for SVP). The contributions of the
  // fixed subTree coordinates are subtracted from it in the loop below.
  maxDist = fMaxDist.get_d();
  for (int i = 0; i < d; i++) {
    rdiag[i] = frdiag[i + first].get_d();
    if (solveSVP)
      centerPartSum[i] = 0.0;
    else
      centerPartSum[i] = targetCoord[i + first].get_d();
    for (int j = 0; j < d; j++)
      mut[i][j] = mu(j + first, i + first).get_d();
  }

  // Prepares the loop (goes to the first vector): descends from level d-1,
  // taking the subTree value on fixed levels and the rounded center on free
  // levels. It stops early as soon as the partial distance exceeds maxDist.
  for (k = d - 1; k >= 0 && newDist <= maxDist; k--) {
    real newCenter = centerPartSum[k];
    for (int j = k + 1; j < kEnd; j++)
      newCenter -= x[j] * mut[k][j];

    if (k >= kEnd) {
      // Fixed level: use the subTree value and fold its contribution into
      // the centers of all lower levels.
      newX = subTree[k - kEnd].get_d();
      if (newX != 0.0) svpBeginning = false;
      for (int j = 0; j < k; j++)
        centerPartSum[j] -= newX * mut[j][k];
    }
    else {
      // Free level: start at the rounded center. ddx is set so that the
      // first zig-zag step goes to the other side of the center.
      newX = rint(newCenter);
      center[k] = newCenter;
      dist[k] = newDist;
      dx[k] = real(0.0);
      ddx[k] = newCenter < newX ? real(1.0) : real(-1.0);
    }
    x[k] = newX;
    real y = newCenter - newX;
    newDist += y * y * rdiag[k];
  }
  if (!svpBeginning)
    kMax = kEnd; // Every level below kEnd zig-zags over both signs
  else
    x[0] = real(1.0); // Excludes (0,...,0) from the enumeration
  k++;
  // now k >= 0 is the first level to examine: 0 if the whole first vector
  // is within maxDist, otherwise the level where the descent stopped. If
  // that level is >= kEnd, enumerateLoop returns false immediately.

  // enumerateLoop takes k by value and updates kMax through its last
  // argument. Each iteration of this loop handles one solution.
  while (enumerateLoop(rdiag,
    center, dist, centerPartSum, x, dx, ddx,
    maxDist, k, kEnd, kMax, newMaxDist, kMax)) {
    // We have found a solution
    for (int j = 0; j < d; j++)
      fX[j].set(x[j]);
    evaluator.evalSol(fX, newMaxDist, maxDist);
    fMaxDist.set(maxDist); // Exact
    // Solutions are always reported at level -1, so resume from there.
    k = -1;
    // Goes to the next step and continues the loop. If this returns false,
    // k >= kEnd and the next enumerateLoop call returns false.
    nextPosUp(x, dx, ddx, k, kMax, kEnd);
  }
}
