/**
 * Copyright (C) 2007-2013 Lawrence Murray
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option)
 * any later version.
 *
 * @author Lawrence Murray <lawrence@indii.org>
 * $Rev$
 * $Date$
 */
#ifndef INDII_KRIG_SIMPLEKRIGER_HPP
#define INDII_KRIG_SIMPLEKRIGER_HPP

#include "EdgeDistanceMap.hpp"
#include "nested_matrix.hpp"
#include "../image/ImageResource.hpp"
#include "../image/ColourSpace.hpp"

#include "boost/numeric/ublas/vector.hpp"
#include "boost/numeric/ublas/matrix.hpp"
#include "boost/numeric/ublas/matrix_proxy.hpp"
#include "boost/numeric/ublas/operation.hpp"
#include "boost/numeric/ublas/lu.hpp"

#include <list>

namespace indii {
/**
 * Simple Kriging interpolator over an image.
 *
 * Each control point carries an output value. The interpolated output at a
 * pixel is w^T s, where s_i = s2 * exp(-d_i^2 / (2 * l^2)), d_i is the value
 * of control point i's distance map at that pixel, w is the weight vector
 * and l is the scaled length (see getScaledLength()).
 */
class SimpleKriger {
public:
  /**
   * Constructor.
   *
   * @param res Image resource. Held as a raw pointer; NOTE(review): ownership
   * and lifetime are not visible here -- presumably the caller must keep it
   * alive for the lifetime of this object.
   */
  SimpleKriger(ImageResource* res);

  /**
   * Get number of control points.
   *
   * @return Number of control points.
   */
  int getNumControls() const;

  /**
   * Add control point.
   *
   * @param x X-coordinate of control point.
   * @param y Y-coordinate of control point.
   * @param z Output of control point (defaults to 0).
   */
  void addControl(const int x, const int y, const float z = 0.0f);

  /**
   * Remove control point.
   *
   * @param i Index of the control point to remove.
   */
  void removeControl(const int i);

  /**
   * Get control point coordinates.
   *
   * @param i Index of control point, in [0, getNumControls()).
   * @param[out] x X-coordinate; must not be NULL.
   * @param[out] y Y-coordinate; must not be NULL.
   */
  void getControl(const int i, int* x, int* y);

  /**
   * Move a control point. Recomputes that control point's distance map and
   * covariances, the sorted distances, and the weights.
   *
   * @param i Index of control point.
   * @param x X-coordinate.
   * @param y Y-coordinate.
   */
  void setControl(const int i, const int x, const int y);

  /**
   * Get the output at a control point.
   *
   * @param i Index of control point, in [0, getNumControls()).
   *
   * @return Output.
   */
  float getOutput(const int i);

  /**
   * Set the output at a control point. Recomputes the weights.
   *
   * @param i Index of control point, in [0, getNumControls()).
   * @param z Output.
   */
  void setOutput(const int i, const float z);

  /**
   * Get length scale.
   *
   * @return Length scale, a quantile in (0, 1] of the sorted distances.
   */
  float getLength() const;

  /**
   * Set length scale. Recomputes covariances and weights.
   *
   * @param len Length scale; must satisfy 0 < len <= 1. Interpreted as a
   * quantile of the sorted distances (see getScaledLength()).
   */
  void setLength(const float len);

  /**
   * Get scaled length.
   *
   * @return Length scale for current distance maps: the distance at
   * quantile len of the sorted distances, or len itself if there are no
   * distances.
   */
  float getScaledLength() const;

  /**
   * Get output standard deviation.
   *
   * @return Output standard deviation.
   */
  float getOutputStd() const;

  /**
   * Set output standard deviation. Recomputes covariances and weights.
   *
   * @param outputStd Output standard deviation.
   */
  void setOutputStd(const float outputStd);

  /**
   * Get edge sensitivity.
   *
   * @return Edge sensitivity, in [0, 1].
   */
  float getEdge() const;

  /**
   * Set edge sensitivity. Recomputes distance maps, sorted distances,
   * covariances and weights.
   *
   * @param edge Edge sensitivity; must be in [0, 1].
   */
  void setEdge(const float edge);

  /**
   * Query output over a rectangle.
   *
   * @tparam V1 Vector type assignable from a uBLAS vector expression.
   * @param rect Rectangle to query.
   * @param scale Factor mapping rectangle coordinates to distance map
   * coordinates (source pixel is (scale*(rect.x+x), scale*(rect.y+y)),
   * truncated to integers).
   * @param[out] l Outputs in row-major order within rect; must already have
   * size rect.width * rect.height.
   */
  template<class V1>
  void query(const wxRect& rect, const float scale, V1& l) const;

  /**
   * Query output at a single point.
   *
   * @param x X-coordinate, in distance map coordinates.
   * @param y Y-coordinate, in distance map coordinates.
   *
   * @return Interpolated output at (x, y).
   */
  float query(const int x, const int y);

private:
  /**
   * Update lightness channel of source image.
   */
  void updateLightness();

  /**
   * Update distance map for given control point.
   *
   * @param i Control point index.
   */
  void updateDistance(const int i);

  /**
   * Update distance maps.
   */
  void updateDistance();

  /**
   * Update spread function. NOTE(review): presumably recomputes the sorted
   * distances d used by getScaledLength() -- confirm in the implementation.
   */
  void updateSpread();

  /**
   * Update covariance map for given control point.
   *
   * @param i Control point index.
   */
  void updateCovariance(const int i);

  /**
   * Update covariance maps.
   */
  void updateCovariance();

  /**
   * Update weight matrix.
   */
  void updateWeights();

  /**
   * Image resource. NOTE(review): not obviously owned -- confirm lifetime.
   */
  ImageResource* res;

  /**
   * Length scale, as a quantile in (0, 1] of the sorted distances d.
   */
  float len;

  /**
   * Output standard deviation.
   */
  float outputStd;

  /**
   * Signal covariance; multiplies the exponential covariance function in
   * query(). NOTE(review): presumably derived from outputStd -- confirm.
   */
  float s2;

  /**
   * Edge sensitivity, in [0, 1].
   */
  float edge;

  /**
   * Original lightness channel of image.
   */
  nested_matrix L;

  /**
   * Control point distance maps, one per control point in index order;
   * each is indexed as (y, x).
   */
  std::list<nested_matrix> D;

  /**
   * Sorted distances used to compute the scaled length.
   */
  boost::numeric::ublas::vector<float> d;

  /**
   * Control points; row i holds the (x, y) coordinates of control point i.
   */
  boost::numeric::ublas::matrix<int> v;

  /**
   * Weight vector; outputs are computed as w^T s.
   */
  boost::numeric::ublas::vector<float> w;

  /**
   * Control point covariance matrix.
   */
  boost::numeric::ublas::matrix<float> C;

  /**
   * LU decomposition of covariance matrix. Note Cholesky factorisation would
   * be better here, but Boost.uBLAS provides no such function. As the number
   * of control points should be few anyway, the LU decomposition is
   * preferred to the additional LAPACK and Boost bindings dependencies that
   * the Cholesky would introduce.
   */
  boost::numeric::ublas::matrix<float> LU;

  /**
   * Permutation matrix for LU decomposition.
   */
  boost::numeric::ublas::permutation_matrix<std::size_t> pm;

  /**
   * Control point outputs, one per control point in index order.
   */
  boost::numeric::ublas::vector<float> z;

  /**
   * Colour space model.
   */
  ColourSpace cs;

  /**
   * Distance mapper.
   */
  EdgeDistanceMap distance;
};
}

inline int indii::SimpleKriger::getNumControls() const {
  /* each control point occupies one row of v */
  const int count = static_cast<int>(v.size1());
  return count;
}

inline void indii::SimpleKriger::getControl(const int i, int* x, int* y) {
  /* pre-conditions */
  assert(i >= 0 && i < getNumControls());
  assert(x != NULL && y != NULL);

  /* row i of v stores (x, y); written in x-then-y order */
  const int px = v(i, 0);
  const int py = v(i, 1);
  *x = px;
  *y = py;
}

/*
 * Move control point i to (x, y), then refresh all state derived from its
 * position: its distance map, the sorted distances behind the scaled length,
 * its covariances and finally the kriging weights (in that dependency order).
 */
inline void indii::SimpleKriger::setControl(const int i, const int x,
    const int y) {
  /* pre-conditions: same index range as the other accessors */
  assert(i >= 0 && i < getNumControls());

  v(i, 0) = x;
  v(i, 1) = y;

  updateDistance(i);
  updateSpread();
  updateCovariance(i);
  updateWeights();
}

inline float indii::SimpleKriger::getOutput(const int i) {
  /* pre-condition */
  assert(i >= 0 && i < getNumControls());

  return this->z(i);
}

inline void indii::SimpleKriger::setOutput(const int i, const float z) {
  /* pre-condition */
  assert(i >= 0 && i < getNumControls());

  this->z(i) = z;
  updateWeights();
}

inline float indii::SimpleKriger::getLength() const {
  return len;
}

inline void indii::SimpleKriger::setLength(const float len) {
  /* pre-condition */
  assert(len > 0.0f && len <= 1.0f);

  /* store, then rebuild everything that depends on the length scale */
  this->len = len;

  updateCovariance();
  updateWeights();
}

/*
 * The length scale len is interpreted as a quantile of the sorted distances
 * d; the scaled length is the distance found at that quantile. When no
 * distances are available, len is returned unchanged.
 */
inline float indii::SimpleKriger::getScaledLength() const {
  const std::size_t n = d.size();
  if (n == 0) {
    return len;
  }

  /* guard the float->unsigned conversion against non-positive or NaN len
   * (only reachable when asserts are disabled) */
  const float quantile = (len > 0.0f) ? len : 0.0f;

  /* len == 1 is a legal value, giving quantile * n == n; clamp to the last
   * element so the lookup stays in bounds */
  std::size_t q = static_cast<std::size_t>(quantile * n);
  if (q >= n) {
    q = n - 1;
  }
  return d[q];
}

inline float indii::SimpleKriger::getOutputStd() const {
  return outputStd;
}

inline void indii::SimpleKriger::setOutputStd(const float outputStd) {
  /* the parameter shadows the member; assign through this */
  this->outputStd = outputStd;

  /* covariances, and hence weights, depend on the output deviation */
  updateCovariance();
  updateWeights();
}

inline float indii::SimpleKriger::getEdge() const {
  return edge;
}

inline void indii::SimpleKriger::setEdge(const float edge) {
  /* pre-condition */
  assert(edge >= 0.0f && edge <= 1.0f);

  this->edge = edge;

  /* edge sensitivity feeds the distance maps; everything downstream of them
   * must be rebuilt in dependency order */
  updateDistance();
  updateSpread();
  updateCovariance();
  updateWeights();
}

template<class V1>
inline void indii::SimpleKriger::query(const wxRect& rect, const float scale, V1& l) const {
  /* pre-condition */
  assert((int )l.size() == rect.height * rect.width);

  const float scaledLen = getScaledLength();
  const float z = -0.5f / (scaledLen * scaledLen);

  boost::numeric::ublas::matrix<float> S(getNumControls(),
      rect.height * rect.width);
  int x, y, x1, y1, i;
  float d;

  /* compute covariances */
  std::list<nested_matrix>::const_iterator iter = this->D.begin();
  for (i = 0; i < getNumControls(); ++i, ++iter) {
    const nested_matrix& D = *iter;
    for (y = 0; y < rect.height; ++y) {
      for (x = 0; x < rect.width; ++x) {
        x1 = scale * (rect.x + x);
        y1 = scale * (rect.y + y);

        d = D(y1, x1);
        S(i, y * rect.width + x) = s2 * expf(z * d * d);
      }
    }
  }

  /* compute outputs */
  l = prod(trans(w), S);
}

inline float indii::SimpleKriger::query(const int x, const int y) {
  const float scaledLen = getScaledLength();
  const float z = -0.5f / (scaledLen * scaledLen);

  boost::numeric::ublas::vector<float> S(getNumControls());
  float d;
  int i;
  std::list<nested_matrix>::const_iterator iter = this->D.begin();
  for (i = 0; i < getNumControls(); ++i, ++iter) {
    const nested_matrix& D = *iter;
    d = D(y, x);
    S(i) = s2 * expf(z * d * d);
  }

  return inner_prod(w, S);
}

#endif
