/**
 * Copyright (C) 2007-2013 Lawrence Murray
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option)
 * any later version.
 *
 * @author Lawrence Murray <lawrence@indii.org>
 * $Rev$
 * $Date$
 */
#include "EdgeDistanceMap.hpp"

#include <vector>

//#include <sys/time.h>
#ifdef _OPENMP
#include <omp.h>
#else
/* Serial stand-ins for the OpenMP runtime queries used in this file, so it
 * also builds without OpenMP: they report one thread whose number is 0. */
int omp_get_max_threads() { return 1; }
int omp_get_num_threads() { return 1; }
int omp_get_thread_num() { return 0; }
#endif

using namespace indii;

/* INDII_ALIGN(num): request num-byte alignment for the following declaration
 * (used for stack arrays passed to aligned SSE loads/stores). Both branches
 * honour num; the GCC/Clang branch used to hard-code 16 regardless. */
#ifdef _MSC_VER
#define INDII_ALIGN(num) __declspec(align(num))
#else
#define INDII_ALIGN(num) __attribute__((aligned (num)))
#endif

/**
 * Constructor. Decides whether map() may use the SSE sweeps.
 *
 * An unsupported SSE instruction raises a hardware fault (SIGILL on POSIX),
 * not a C++ exception. Executing a probe inside try/catch(int) therefore can
 * never take the fallback path: it either succeeds or kills the process.
 * Instead, ask the CPU directly where the compiler offers a query.
 */
EdgeDistanceMap::EdgeDistanceMap() {
#if defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))
  __builtin_cpu_init();
  useSSE = __builtin_cpu_supports("sse") != 0;
#else
  /* No run-time query available here. This file is compiled with SSE
   * intrinsics, so assume the target supports them (SSE/SSE2 are baseline
   * on x86-64). */
  useSSE = true;
#endif
}

/**
 * Column range [start, end) swept by thread @p tid in map()'s SSE passes.
 *
 * The columns [first, size2) are split evenly across @p num_threads threads.
 * Each share is truncated to a multiple of four floats, so when @p first is
 * a multiple of four every thread's start column is too; the SSE sweeps use
 * aligned loads/stores at those columns. The last thread takes the remainder.
 *
 * @param first First column of the sweep.
 * @param size2 Number of columns in the matrices.
 * @param tid Thread number within the team.
 * @param num_threads Number of threads in the team.
 * @param[out] start First column for this thread.
 * @param[out] end One past the last column for this thread.
 */
static void sseColumnRange(const int first, const int size2, const int tid,
    const int num_threads, int& start, int& end) {
  int len = (size2 - first)/num_threads;
  len -= len % 4; // keep each share a multiple of 16 bytes
  start = first + tid*len;
  if (tid == num_threads - 1) {
    /* NOTE(review): this rounds size2 + 2 down to a multiple of 4. When
     * size2 % 4 == 1 the last column is therefore not swept. When size2 % 4
     * is 2 or 3 the sweep touches columns at index size2 and beyond. Kept
     * as-is because the right bound depends on nested_matrix's row padding
     * -- confirm. */
    end = (size2 - 1 + 3)/4*4;
  } else {
    end = start + len;
  }
}

/**
 * Fill D with edge-weighted distances from the seed pixel (x0, y0).
 *
 * D is reset with D.inf() and the seed set to zero. Five alternating forward
 * and backward sweeps then propagate distances. Each step between
 * neighbouring pixels costs threshold(|difference in L|, k).
 * When useSSE is set, the OpenMP/SSE sweeps are used; otherwise the scalar
 * ones are.
 *
 * @param x0 Seed column.
 * @param y0 Seed row.
 * @param L Input matrix; must have the same dimensions as D.
 * @param D Output distance matrix, overwritten.
 * @param k Passed through to threshold().
 */
void EdgeDistanceMap::map(const int x0, const int y0, nested_matrix& L,
    nested_matrix& D, const float k) {
  /* pre-condition */
  assert (L.size1() == D.size1() && L.size2() == D.size2());

  D.inf(); // D boundary is inf, L boundary is zero
  D(y0, x0) = 0.0;

  const int rows = static_cast<int>(L.size1());
  const int cols = static_cast<int>(L.size2());

  if (useSSE) {
    /* One row counter per thread, used by forwardSSE()/backwardSSE() to keep
     * neighbouring threads in step. It is zero-initialised because a thread
     * may read its neighbour's counter before that neighbour has written
     * it. The vector also releases its storage; the previous raw new[]
     * leaked it. */
    std::vector<int> progress(omp_get_max_threads());

    /* the first sweep starts at the seed's column rounded down to a multiple
     * of four, so that it begins on a 16-byte boundary */
    const int x1 = x0 - (x0 % 4);

    #pragma omp parallel shared(progress)
    {
      const int tid = omp_get_thread_num();
      const int num_threads = omp_get_num_threads();
      int* const prog = progress.data();
      int x_start1, x_end1, x_start2, x_end2;

      sseColumnRange(x1, cols, tid, num_threads, x_start1, x_end1);
      sseColumnRange(0, cols, tid, num_threads, x_start2, x_end2);

      forwardSSE(x_start1, y0, x_end1, rows - 1, k, L, D, prog);
      backwardSSE(x_end2 - 4, rows - 1, x_start2, 0, k, L, D, prog);
      forwardSSE(x_start2, 0, x_end2, rows - 1, k, L, D, prog);
      backwardSSE(x_end2 - 4, rows - 1, x_start2, 0, k, L, D, prog);
      forwardSSE(x_start2, 0, x_end2, rows - 1, k, L, D, prog);
    }
  } else {
    /* Backward sweeps start at the last column (cols - 1), like the SSE
     * path. Starting at cols - 2 left that column without south-side
     * updates and failed backward()'s x0 >= 0 check when cols == 1. */
    forward(x0, y0, cols - 1, rows - 1, k, L, D);
    backward(cols - 1, rows - 1, 0, 0, k, L, D);
    forward(0, 0, cols - 1, rows - 1, k, L, D);
    backward(cols - 1, rows - 1, 0, 0, k, L, D);
    forward(0, 0, cols - 1, rows - 1, k, L, D);
  }
}

/**
 * Scalar forward (raster-order) sweep.
 *
 * For each row y0..y1 (top to bottom), visits columns x0..x1 left to right.
 * Each D(y, x) is lowered to the smallest of its current value and, for each
 * west, north-west, north and north-east neighbour, that neighbour's
 * distance plus threshold(|L(neighbour) - L(y, x)|, k).
 *
 * Reads cells one column/row outside the swept area (column x0 - 1,
 * column x1 + 1, row y - 1). map() comments that nested_matrix supplies a
 * boundary there (inf in D, zero in L) -- confirm for other callers.
 *
 * @param x0 First column of every row.
 * @param y0 First row.
 * @param x1 Last column of every row (inclusive). The loop used to ignore
 *           this and always run to the last column of D; map() passes that
 *           column, so its results are unchanged.
 * @param y1 Last row (inclusive).
 * @param k Passed to threshold().
 * @param L Matrix whose differences give the step costs.
 * @param D Distance matrix, updated in place.
 */
void EdgeDistanceMap::forward(const int x0, const int y0,
    const int x1, const int y1, const float k, nested_matrix& L, nested_matrix& D) {
  /* pre-condition */
  assert (x0 >= 0 && x0 < L.size2());
  assert (x1 >= 0 && x1 < L.size2());
  assert (y0 >= 0 && y0 < L.size1());
  assert (y1 >= 0 && y1 < L.size1());

  for (int y = y0; y <= y1; ++y) {
    /* window of neighbour values, seeded from column x0 - 1 and the row
     * above, then slid east one column per step */
    float ew = D(y, x0 - 1);
    float enw = D(y - 1, x0 - 1);
    float en = D(y - 1, x0);
    float lw = L(y, x0 - 1);
    float lnw = L(y - 1, x0 - 1);
    float ln = L(y - 1, x0);

    for (int x = x0; x <= x1; ++x) {
      const float l = L(y, x);
      const float ene = D(y - 1, x + 1);
      const float lne = L(y - 1, x + 1);

      float e = D(y, x);
      e = std::min(e, ew + threshold(fabsf(lw - l), k));
      e = std::min(e, enw + threshold(fabsf(lnw - l), k));
      e = std::min(e, en + threshold(fabsf(ln - l), k));
      e = std::min(e, ene + threshold(fabsf(lne - l), k));
      D(y, x) = e;

      /* shift the window one column east */
      ew = e;
      enw = en;
      en = ene;
      lw = l;
      lnw = ln;
      ln = lne;
    }
  }
}

/**
 * Scalar backward (reverse raster-order) sweep.
 *
 * For each row y0 down to y1, visits columns x0 down to x1 (right to left).
 * Each D(y, x) is lowered to the smallest of its current value and, for each
 * east, south-west, south and south-east neighbour, that neighbour's
 * distance plus threshold(|L(neighbour) - L(y, x)|, k).
 *
 * Reads cells one column/row outside the swept area (column x0 + 1,
 * column x1 - 1, row y + 1). map() comments that nested_matrix supplies a
 * boundary there (inf in D, zero in L) -- confirm for other callers.
 *
 * @param x0 First (rightmost) column of every row.
 * @param y0 First (bottom) row.
 * @param x1 Last (leftmost) column of every row (inclusive). The loop used
 *           to ignore this and always run down to column 0; map() passes 0,
 *           so its results are unchanged.
 * @param y1 Last (top) row (inclusive).
 * @param k Passed to threshold().
 * @param L Matrix whose differences give the step costs.
 * @param D Distance matrix, updated in place.
 */
void EdgeDistanceMap::backward(const int x0, const int y0,
    const int x1, const int y1, const float k, nested_matrix& L, nested_matrix& D) {
  /* pre-condition */
  assert (x0 >= 0 && x0 < L.size2());
  assert (x1 >= 0 && x1 < L.size2());
  assert (y0 >= 0 && y0 < L.size1());
  assert (y1 >= 0 && y1 < L.size1());

  for (int y = y0; y >= y1; --y) {
    /* window of neighbour values, seeded from column x0 + 1 and the row
     * below, then slid west one column per step */
    float ee = D(y, x0 + 1);
    float es = D(y + 1, x0);
    float ese = D(y + 1, x0 + 1);
    float le = L(y, x0 + 1);
    float ls = L(y + 1, x0);
    float lse = L(y + 1, x0 + 1);

    for (int x = x0; x >= x1; --x) {
      const float l = L(y, x);
      const float esw = D(y + 1, x - 1);
      const float lsw = L(y + 1, x - 1);

      float e = D(y, x);
      e = std::min(e, ee + threshold(fabsf(le - l), k));
      e = std::min(e, esw + threshold(fabsf(lsw - l), k));
      e = std::min(e, es + threshold(fabsf(ls - l), k));
      e = std::min(e, ese + threshold(fabsf(lse - l), k));
      D(y, x) = e;

      /* shift the window one column west */
      ee = e;
      ese = es;
      es = esw;
      le = l;
      lse = ls;
      ls = lsw;
    }
  }
}

/**
 * SSE forward (raster-order) sweep over columns [x0, x1) of rows y0..y1,
 * four columns per step.
 *
 * Each step first lowers D(y, x..x+3) against the north-west, north and
 * north-east neighbours in row y - 1, all four lanes at once. It then applies
 * the west neighbour one column at a time, because each column's west
 * neighbour is the column just updated.
 *
 * Meant to be called by every thread of an OpenMP parallel team, each with
 * its own column range. progress[tid] counts the rows this thread has
 * started/finished. A thread does not start row y until thread tid - 1
 * (in map(), the thread with the columns to the left) has finished row y.
 * The function ends with a barrier, so every thread of the team must call it.
 *
 * NOTE(review): _mm_load_ps/_mm_store_ps require 16-byte aligned addresses.
 * This assumes &D(y, x) and &L(y, x) are aligned at x0, x0 + 4, ..., which
 * depends on nested_matrix's row layout -- confirm. If x1 - x0 is not a
 * multiple of 4, the last step writes columns at or beyond x1.
 * NOTE(review): progress is a plain int array polled with omp flush. Reads of
 * D in row y - 1 just past x1 may see the neighbouring thread's old or new
 * values.
 *
 * @param x0 First column (start of a group of four).
 * @param y0 First row.
 * @param x1 One past the last column (exclusive).
 * @param y1 Last row (inclusive).
 * @param k Passed to threshold() with each intensity difference.
 * @param L Matrix whose differences give the step costs.
 * @param D Distance matrix, updated in place.
 * @param progress Per-thread row counters shared by the team, indexed by
 *        omp_get_thread_num().
 */
void EdgeDistanceMap::forwardSSE(const int x0, const int y0,
    const int x1, const int y1, const float k, nested_matrix& L, nested_matrix& D, int* progress) {
  const __m128 k1 = _mm_set1_ps(k);
  /* only the sign bit set; andnot with it yields |v| */
  const __m128 abs_mask = _mm_set1_ps(-0.0f);
  __m128 e, enw, en, ene, ene1;
  __m128 l, lnw, ln, lne, lne1;
  float fe, few, fl, flw;
  int x, y, i;

  int tid = omp_get_thread_num();
  /* announce the starting row before any neighbour polls this counter */
  progress[tid] = y0;
  #pragma omp flush

  for (y = y0; y <= y1; ++y) {
    /* make sure we're not ahead of adjacent thread */
    while (tid != 0 && progress[tid - 1] <= y) {
      /* spin until thread tid - 1 has finished row y */
      #pragma omp flush
    }

    x = x0;

    /* Seed the sliding window with row y - 1, columns x0-3..x0. Only the
     * last two lanes (x0-1, x0) are used by the first shuffle. */
    ene = _mm_loadu_ps(&D(y - 1, x - 3));
    lne = _mm_loadu_ps(&L(y - 1, x - 3));
    /* west neighbour of column x0, carried through the scalar west pass */
    few = D(y, x - 1);
    flw = L(y, x - 1);

    for (; x < x1; x += 4) {
      /* loads */
      e = _mm_load_ps(&D(y, x));
      ene1 = _mm_loadu_ps(&D(y - 1, x + 1));

      //_mm_prefetch(reinterpret_cast<char*>(&D(y, x) + 48), _MM_HINT_T0);
      //_mm_prefetch(reinterpret_cast<char*>(&D(y - 1, x) + 49), _MM_HINT_T0);

      /* enw = D(y-1, x-1..x+2), en = D(y-1, x..x+3); ene = D(y-1, x+1..x+4),
       * which is also the previous block for the next step */
      enw = _mm_shuffle_ps(ene, ene1, _MM_SHUFFLE(1,0,3,2));
      en = _mm_shuffle_ps(enw, ene1, _MM_SHUFFLE(2,1,2,1));
      ene = ene1;

      l = _mm_load_ps(&L(y, x));
      lne1 = _mm_loadu_ps(&L(y - 1, x + 1));

      //_mm_prefetch(reinterpret_cast<char*>(&L(y , x) + 48), _MM_HINT_T0);
      //_mm_prefetch(reinterpret_cast<char*>(&L(y - 1, x) + 49), _MM_HINT_T0);

      /* the same windowing applied to L */
      lnw = _mm_shuffle_ps(lne, lne1, _MM_SHUFFLE(1,0,3,2));
      ln = _mm_shuffle_ps(lnw, lne1, _MM_SHUFFLE(2,1,2,1));
      lne = lne1;

      /* distances to north-west, north and north-east */
      e = _mm_min_ps(e, _mm_add_ps(enw, threshold(_mm_andnot_ps(abs_mask, _mm_sub_ps(lnw, l)), k1)));
      e = _mm_min_ps(e, _mm_add_ps(en, threshold(_mm_andnot_ps(abs_mask, _mm_sub_ps(ln, l)), k1)));
      e = _mm_min_ps(e, _mm_add_ps(ene, threshold(_mm_andnot_ps(abs_mask, _mm_sub_ps(lne, l)), k1)));
      _mm_store_ps(&D(y, x), e);

      /* distances to west: sequential, since each column's west neighbour
       * is the column finalised just before it */
      for (i = 0; i < 4; ++i) {
        fe = D(y, x + i);
        fl = L(y, x + i);
        fe = std::min(fe, few + threshold(fabsf(flw - fl), k));
        D(y, x + i) = fe;
        few = fe;
        flw = fl;
      }
    }

    /* publish that row y is finished, releasing thread tid + 1 */
    ++progress[tid];
    #pragma omp flush
  }

  #pragma omp barrier
}

/**
 * SSE backward (reverse raster-order) sweep over rows y0 down to y1, four
 * columns per step.
 *
 * Steps cover D(y, x..x+3) for x = x0, x0 - 4, ... while x >= x1. Each step
 * first lowers the four cells against the south-east, south and south-west
 * neighbours in row y + 1, all at once. It then applies the east neighbour
 * one column at a time, right to left, because each column's east neighbour
 * is the column just updated.
 *
 * Meant to be called by every thread of an OpenMP parallel team, each with
 * its own column range. progress[tid] tracks this thread's current row,
 * counting down. A thread does not start row y until thread tid + 1 (in
 * map(), the thread with the columns to the right) has finished row y.
 * The function ends with a barrier, so every thread of the team must call it.
 *
 * NOTE(review): _mm_load_ps/_mm_store_ps require 16-byte aligned addresses.
 * This assumes &D(y, x) and &L(y, x) are aligned at x0, x0 - 4, ..., which
 * depends on nested_matrix's row layout -- confirm.
 * NOTE(review): progress is a plain int array polled with omp flush. Reads of
 * D in row y + 1 just left of the range may see the neighbouring thread's
 * old or new values.
 *
 * @param x0 Start column of the rightmost group of four.
 * @param y0 First (bottom) row.
 * @param x1 Lowest start column still processed (inclusive).
 * @param y1 Last (top) row (inclusive).
 * @param k Passed to threshold() with each intensity difference.
 * @param L Matrix whose differences give the step costs.
 * @param D Distance matrix, updated in place.
 * @param progress Per-thread row counters shared by the team, indexed by
 *        omp_get_thread_num().
 */
void EdgeDistanceMap::backwardSSE(const int x0, const int y0,
    const int x1, const int y1, const float k, nested_matrix& L, nested_matrix& D,
    int* progress) {
  /* only the sign bit set; andnot with it yields |v| */
  const __m128 abs_mask = _mm_set1_ps(-0.0f);
  const __m128 k1 = _mm_set1_ps(k);
  __m128 e, esw, es, ese, esw1;
  __m128 l, lsw, ls, lse, lsw1;
  float fe, fee, fl, fle;
  int x, y, i;

  int tid = omp_get_thread_num();
  int num_threads = omp_get_num_threads();
  /* announce the starting row before any neighbour polls this counter */
  progress[tid] = y0;
  #pragma omp flush

  for (y = y0; y >= y1; --y) {
    /* make sure we're not ahead of adjacent thread */
    while (tid != num_threads - 1 && progress[tid + 1] >= y) {
      /* spin until thread tid + 1 has finished row y */
      #pragma omp flush
    }

    x = x0;

    /* Seed the sliding window with row y + 1, columns x0+3..x0+6. Only the
     * first two lanes (x0+3, x0+4) are used by the first shuffle. */
    esw = _mm_loadu_ps(&D(y + 1, x + 3));
    lsw = _mm_loadu_ps(&L(y + 1, x + 3));
    /* east neighbour of column x0+3, carried through the scalar east pass */
    fee = D(y, x + 4);
    fle = L(y, x + 4);

    for (; x >= x1; x -= 4) {
      /* loads */
      e = _mm_load_ps(&D(y, x));
      esw1 = _mm_loadu_ps(&D(y + 1, x - 1));

      //_mm_prefetch(reinterpret_cast<char*>(&D(y, x) - 48), _MM_HINT_T0);
      //_mm_prefetch(reinterpret_cast<char*>(&D(y - 1, x) - 49), _MM_HINT_T0);

      /* ese = D(y+1, x+1..x+4), es = D(y+1, x..x+3); esw = D(y+1, x-1..x+2),
       * which is also the previous block for the next step */
      ese = _mm_shuffle_ps(esw1, esw, _MM_SHUFFLE(1,0,3,2));
      es = _mm_shuffle_ps(esw1, ese, _MM_SHUFFLE(2,1,2,1));
      esw = esw1;

      l = _mm_load_ps(&L(y, x));
      lsw1 = _mm_loadu_ps(&L(y + 1, x - 1));

      //_mm_prefetch(reinterpret_cast<char*>(&L(y , x) - 48), _MM_HINT_T0);
      //_mm_prefetch(reinterpret_cast<char*>(&L(y - 1, x) - 49), _MM_HINT_T0);

      /* the same windowing applied to L */
      lse = _mm_shuffle_ps(lsw1, lsw, _MM_SHUFFLE(1,0,3,2));
      ls = _mm_shuffle_ps(lsw1, lse, _MM_SHUFFLE(2,1,2,1));
      lsw = lsw1;

      /* distances to south-east, south and south-west */
      e = _mm_min_ps(e, _mm_add_ps(ese, threshold(_mm_andnot_ps(abs_mask, _mm_sub_ps(lse, l)), k1)));
      e = _mm_min_ps(e, _mm_add_ps(es, threshold(_mm_andnot_ps(abs_mask, _mm_sub_ps(ls, l)), k1)));
      e = _mm_min_ps(e, _mm_add_ps(esw, threshold(_mm_andnot_ps(abs_mask, _mm_sub_ps(lsw, l)), k1)));
      _mm_store_ps(&D(y,x), e);

      /* distances to east: sequential, right to left, since each column's
       * east neighbour is the column finalised just before it */
      for (i = 3; i >= 0; --i) {
        fe = D(y, x + i);
        fl = L(y, x + i);
        fe = std::min(fe, fee + threshold(fabsf(fle - fl), k));
        D(y, x + i) = fe;
        fee = fe;
        fle = fl;
      }
    }

    /* publish that row y is finished, releasing thread tid - 1 */
    --progress[tid];
    #pragma omp flush
  }

  #pragma omp barrier
}
