// Copyright (c) 2007, Yoshimasa Tsuruoka

//
// Jorge Nocedal, "Updating Quasi-Newton Matrices With Limited Storage",
// Mathematics of Computation, Vol. 35, No. 151, pp. 773-782, 1980.
//

#ifndef _LBFGS_HPP_
#define _LBFGS_HPP_

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
#include "MathVec.hpp"

using namespace std;

template <typename FuncGrad, typename RealNumber = double>
class LBFGS {
public:
  static const unsigned int DEFAULT_NUM_VECTORS = 7;
  static const RealNumber LINE_SEARCH_ALPHA = 0.1;
  static const RealNumber LINE_SEARCH_BETA  = 0.5;
  static const RealNumber MIN_GRAD_NORM = 0.0001;
  static const RealNumber MIN_FUNC_DIFF = 0.000001;

  static const unsigned int MAX_LINE_SEARCH_ITERATIONS = 100;

protected:
  FuncGrad& func_grad;
  unsigned int num_vectors;

  size_t dim;  // dimension
  MathVec<RealNumber> x;       // parameter
  MathVec<RealNumber> grad, grad1;    // gradient
  MathVec<RealNumber> dx;      // difference
  RealNumber f;    // function value
  RealNumber grad_norm;  // norm of gradient
  std::vector<MathVec<RealNumber> > s, y;
  std::vector<RealNumber> z;    // rho
  bool is_beginning;
  bool is_converged;
  size_t iter;

public:
  LBFGS(FuncGrad& f, size_t d, unsigned int v = DEFAULT_NUM_VECTORS)
    : func_grad(f), num_vectors(v),
      dim(d), x(d), grad(d), grad1(d), dx(d), f(0.0), grad_norm(0.0), s(v), y(v), z(v),
      is_beginning(true), is_converged(false), iter(0)
  {}
  virtual ~LBFGS() {}

public:
  virtual RealNumber 
  backtracking_line_search(const MathVec<RealNumber> & grad0, const RealNumber f0, 
                           MathVec<RealNumber> & dx, MathVec<RealNumber> & grad1)
  {
    RealNumber t = 1.0 / this->LINE_SEARCH_BETA;
    RealNumber old_t = 0.0;
    RealNumber tolerance = this->LINE_SEARCH_ALPHA * dot_product(dx, grad0);

    RealNumber f = 0.0;
    unsigned int i = 0;
    do {
      if (++i > this->MAX_LINE_SEARCH_ITERATIONS) {
        cerr << "Warning: line search not terminated" << endl;
        is_converged = true;
        break;
      }
      t *= this->LINE_SEARCH_BETA;
      //if(t==0) cerr << " looping infinitely " << endl;
      //cerr << "line search: " << t << endl;
      //x = x0 + t * dx;
      x += (t-old_t)*dx;
      f = func_grad(x.STLVec(), grad1.STLVec());
      //        cout << "*";
      //cerr << "f=" << f << ", old=" << f0 << endl;
      //cerr << f0 + t * tolerance << endl;
      old_t = t;
    } while (f > f0 + t * tolerance);

    dx *= t;

    // check convergence
    if (fabs(1.0 - fabs(f / f0)) < this->MIN_FUNC_DIFF) {
      is_converged = true;
    }

    return f;
  }

  virtual MathVec<RealNumber> 
  approximate_Hg(const unsigned int iter, const MathVec<RealNumber> & grad)
  {
    int offset, bound;
    if (iter <= num_vectors) { offset = 0;        bound = iter; }
    else           { offset = iter - num_vectors; bound = num_vectors;    }

    MathVec<RealNumber> q = grad;
    RealNumber alpha[num_vectors], beta[num_vectors];
    for (int i = bound - 1; i >= 0; i--) {
      const int j = (i + offset) % num_vectors;
      alpha[i]    = z[j]   * dot_product(s[j], q);
      q          += -alpha[i] * y[j];
    }
    if (iter > 0) {
      const int j = (iter - 1) % num_vectors;
      const RealNumber gamma = ((1.0 / z[j]) / dot_product(y[j], y[j]));
      //    static RealNumber gamma;
      //    if (gamma == 0) gamma = ((1.0 / z[j]) / dot_product(y[j], y[j]));
      q *= gamma;
    }
    for (int i = 0; i <= bound - 1; i++) {
      const int j = (i + offset) % num_vectors;
      beta[i]     = z[j] * dot_product(y[j], q);
      q          += s[j] * (alpha[i] - beta[i]);
    }

    return q;
  }

public:
  virtual void initialize(size_t d = 0) {
    // resize dimension
    dim = d;
    x.resize(d);
    fill(x.STLVec().begin(), x.STLVec().end(), 0.0);
    grad.resize(d);
    grad1.resize(d);
    dx.resize(d);
    is_beginning = true;
    is_converged = false;
    iter = 0;
  }
  virtual void initialize(const std::vector<RealNumber>& weight) {
    // resize dimension
    dim = weight.size();
    x = weight;
    grad.resize(dim);
    grad1.resize(dim);
    dx.resize(dim);
    is_beginning = true;
    is_converged = false;
    iter = 0;
  }

  virtual void iteration() {
    if(is_beginning) {
      //cerr << "beginning" << endl;
      f = func_grad(x.STLVec(), grad.STLVec());
      is_beginning = false;
    }
    //cerr << "iteration " << iter << ": obj = " << f << endl;
    //cerr << "|grad|=" << sqrt(dot_product(grad, grad)) << endl;
    grad_norm = sqrt(dot_product(grad, grad));
    if (grad_norm < this->MIN_GRAD_NORM) {
        is_converged = true;
        return;
    }

    dx = -1.0 * approximate_Hg(iter, grad);

    f = backtracking_line_search(grad, f, dx, grad1);

    s[iter % num_vectors] = dx;
    y[iter % num_vectors] = grad1 - grad;
    z[iter % num_vectors] = 1.0 / dot_product(y[iter % num_vectors], s[iter % num_vectors]);
    grad = grad1;
    ++iter;
  }

  virtual bool isConverged() const{
    return is_converged;
  }
  virtual RealNumber funcVal() const {
    return f;
  }
  virtual RealNumber gradNorm() const {
    return grad_norm;
  }

//   vector<RealNumber> 
//   perform_LBFGS(const vector<RealNumber> & x0)
//   {
//     const size_t dim = x0.size();
//     MathVec x = x0;

//     MathVec grad(dim), dx(dim);
//     RealNumber f = func_grad(x.STLVec(), grad.STLVec());

//     MathVec s[num_vectors], y[num_vectors];
//     RealNumber z[num_vectors];  rho

//     for (int iter = 0; iter < max_iter; iter++) {

//       cerr << "iteration " << iter << ": obj = " << f << endl;

//       cerr << "|grad|=" << sqrt(dot_product(grad, grad)) << endl;
//       if (sqrt(dot_product(grad, grad)) < this->MIN_GRAD_NORM) break;

//       dx = -1 * approximate_Hg(iter, grad, s, y, z);

//       MathVec x1(dim), grad1(dim);
//       f = backtracking_line_search(x, grad, f, dx, x1, grad1);

//       s[iter % num_vectors] = x1 - x;
//       y[iter % num_vectors] = grad1 - grad;
//       z[iter % num_vectors] = 1.0 / dot_product(y[iter % num_vectors], s[iter % num_vectors]);
//       x = x1;
//       grad = grad1;
//     }

//     return x.STLVec();
//   }

};

//template<class FuncGrad>
//std::vector<RealNumber> 
//perform_LBFGS(FuncGrad func_grad, const std::vector<RealNumber> & x0);

// std::vector<RealNumber> 
// perform_LBFGS(RealNumber (*func_grad)(const std::vector<RealNumber> &, std::vector<RealNumber> &), 
// 	      const std::vector<RealNumber> & x0);


// std::vector<RealNumber> 
// perform_OWLQN(RealNumber (*func_grad)(const std::vector<RealNumber> &, std::vector<RealNumber> &), 
// 	      const std::vector<RealNumber> & x0,
// 	      const RealNumber C);

// const int    LBFGS_M = 7;

#endif
