/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

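        Timstat2        timcor      correlates two data files on the same grid
        Timstat2        timcovar    covariance of two data files on the same grid
        Timstat2        timrmsd     root mean square deviation of two data files on the same grid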
*/

#include <cdi.h>

#include "functs.h"
#include "process_int.h"
#include "cdo_vlist.h"

// correlation in time
// Adds one timestep of paired data to the running sums used by correlation().
//
// For every grid cell where neither input is missing the following are updated:
//   work0 += x, work1 += y, work2 += x*x, work3 += y*y, work4 += x*y
// and nofvals is incremented. Cells with a missing value in either field are
// left untouched for this timestep.
static void
correlation_init(size_t gridsize, const double *array1, const double *array2, double missval1, double missval2, size_t *nofvals,
                 double *work0, double *work1, double *work2, double *work3, double *work4)
{
  for (size_t cell = 0; cell < gridsize; ++cell)
    {
      const double x = array1[cell];
      const double y = array2[cell];

      // A pair only contributes when both values are present.
      if (DBL_IS_EQUAL(x, missval1) || DBL_IS_EQUAL(y, missval2)) continue;

      work0[cell] += x;
      work1[cell] += y;
      work2[cell] += x * x;
      work3[cell] += y * y;
      work4[cell] += x * y;
      ++nofvals[cell];
    }
}

// Turns the sums accumulated by correlation_init() into a correlation
// coefficient per grid cell and stores it in work0.
//
// With n = nofvals[i], Sx = work0, Sy = work1, Sxx = work2, Syy = work3 and
// Sxy = work4 the result is
//
//   cor = (Sxy - Sx*Sy/n) / sqrt((Sxx - Sx*Sx/n) * (Syy - Sy*Sy/n))
//
// The arithmetic goes through the MULMN/SUBMN/DIVMN/SQRTMN macros, which are
// defined outside this file and read missval1/missval2 from the enclosing
// scope (NOTE(review): they presumably yield missval1 for missing operands,
// division by zero and the square root of a negative number - confirm against
// their definition).
//
// Cells without any valid pair get missval1. Returns the number of cells whose
// result is missing.
static size_t
correlation(size_t gridsize, double missval1, double missval2, size_t *nofvals, double *work0, double *work1, double *work2,
            double *work3, double *work4)
{
  size_t nmiss = 0;

  for (size_t i = 0; i < gridsize; ++i)
    {
      double cor;
      if (nofvals[i] > 0)
        {
          // Convert the count once so that every division is done in double.
          const double dnvals = static_cast<double>(nofvals[i]);

          const auto temp0 = MULMN(work0[i], work1[i]);
          const auto temp1 = SUBMN(work4[i], DIVMN(temp0, dnvals));
          const auto temp2 = MULMN(work0[i], work0[i]);
          const auto temp3 = MULMN(work1[i], work1[i]);
          const auto temp4 = SUBMN(work2[i], DIVMN(temp2, dnvals));
          const auto temp5 = SUBMN(work3[i], DIVMN(temp3, dnvals));
          const auto temp6 = MULMN(temp4, temp5);

          cor = DIVMN(temp1, SQRTMN(temp6));

          if (DBL_IS_EQUAL(cor, missval1))
            {
              // Keep the missing value as is: clamping it first would turn a
              // missval1 outside [-1, 1] into a valid looking -1 or 1 and the
              // cell would never be counted as missing.
              nmiss++;
            }
          else
            {
              // Rounding errors can push the result marginally outside [-1, 1].
              cor = std::min(std::max(cor, -1.0), 1.0);
            }
        }
      else
        {
          nmiss++;
          cor = missval1;
        }

      work0[i] = cor;
    }

  return nmiss;
}

// covariance in time
// Adds one timestep of paired data to the running sums used by covariance().
//
// For every grid cell where neither input is missing the following are updated:
//   work0 += x, work1 += y, work2 += x*y
// and nofvals is incremented. Cells with a missing value in either field are
// skipped for this timestep.
static void
covariance_init(size_t gridsize, const double *array1, const double *array2, double missval1, double missval2, size_t *nofvals,
                double *work0, double *work1, double *work2)
{
  for (size_t cell = 0; cell < gridsize; ++cell)
    {
      const double x = array1[cell];
      const double y = array2[cell];

      // A pair only contributes when both values are present.
      if (DBL_IS_EQUAL(x, missval1) || DBL_IS_EQUAL(y, missval2)) continue;

      work0[cell] += x;
      work1[cell] += y;
      work2[cell] += x * y;
      ++nofvals[cell];
    }
}

// Turns the sums accumulated by covariance_init() into a covariance per grid
// cell and stores it in work0.
//
// With n = nofvals[i], Sx = work0, Sy = work1 and Sxy = work2 the result is
//
//   covar = Sxy/n - (Sx*Sy)/(n*n)
//
// evaluated through the MULMN/DIVMN/SUBMN macros, which are defined outside
// this file and read missval1/missval2 from the enclosing scope.
//
// Cells without any valid pair get missval1. Returns the number of cells whose
// result is missing.
static size_t
covariance(size_t gridsize, double missval1, double missval2, size_t *nofvals, double *work0, double *work1, double *work2)
{
  size_t numMissing = 0;

  for (size_t cell = 0; cell < gridsize; ++cell)
    {
      double result = missval1;
      bool isMissing = true;

      if (nofvals[cell] > 0)
        {
          const double count = nofvals[cell];
          const auto productOfMeans = DIVMN(MULMN(work0[cell], work1[cell]), count * count);
          result = SUBMN(DIVMN(work2[cell], count), productOfMeans);
          isMissing = DBL_IS_EQUAL(result, missval1);
        }

      if (isMissing) ++numMissing;

      work0[cell] = result;
    }

  return numMissing;
}


// rms in time
// Adds one timestep of paired data to the running sums used by rmsd_compute().
//
// For every grid cell where neither x nor y is missing, the squared difference
// (x - y)^2 is added to rmsd and nofvals is incremented.
static void
rmsd_init(size_t gridsize, const double *x, const double *y, double xmissval, double ymissval, size_t *nofvals,
         double *rmsd)
{
  for (size_t cell = 0; cell < gridsize; ++cell)
    {
      // A pair only contributes when both values are present.
      if (DBL_IS_EQUAL(x[cell], xmissval) || DBL_IS_EQUAL(y[cell], ymissval)) continue;

      const double diff = x[cell] - y[cell];
      rmsd[cell] += diff * diff;
      ++nofvals[cell];
    }
}

// Converts the accumulated sums of squared differences into root mean square
// deviations, in place.
//
// On entry rmsd[i] holds the sum of squared differences of cell i and
// nofvals[i] the number of values that went into it. On exit rmsd[i] is
// sqrt(sum / count), or missval for cells that never received a value.
// Returns the number of cells set to missval.
static size_t
rmsd_compute(size_t gridsize, double missval, size_t *nofvals, double *rmsd)
{
  size_t numMissing = 0;

  for (size_t cell = 0; cell < gridsize; ++cell)
    {
      const auto count = nofvals[cell];
      if (count == 0)
        {
          rmsd[cell] = missval;
          ++numMissing;
          continue;
        }

      rmsd[cell] = std::sqrt(rmsd[cell] / static_cast<double>(count));
    }

  return numMissing;
}

// Operator entry point for timcor, timcovar and timrmsd.
//
// Reads two input streams timestep by timestep, accumulates per grid cell
// statistics over all timesteps for every variable and level, and writes a
// single output timestep containing the correlation, covariance or root mean
// square deviation, depending on the selected operator.
//
// The output timestep carries the date and time of the last timestep read from
// the first input stream. Always returns nullptr.
void *
Timstat2(void *process)
{
  // Date/time of the most recently read timestep of stream 1; they stay 0 when
  // stream 1 has no timesteps.
  int64_t vdate = 0;
  int vtime = 0;
  int varID, levelID;
  size_t nmiss = 0;

  cdoInitialize(process);

  // The third argument is read back below through cdoOperatorF2() as nwork,
  // the number of accumulation fields allocated per variable and level. It
  // matches the number of work arrays passed to the corresponding *_init
  // helper (5 for correlation, 3 for covariance, 1 for rmsd).
  // clang-format off
  cdoOperatorAdd("timcor",   func_cor,   5, nullptr);
  cdoOperatorAdd("timcovar", func_covar, 3, nullptr);
  cdoOperatorAdd("timrmsd",  func_rmsd,  1, nullptr);
  // clang-format on

  const auto operatorID = cdoOperatorID();
  const auto operfunc = cdoOperatorF1(operatorID);
  const auto nwork = cdoOperatorF2(operatorID);
  // Only timrmsd marks its output variables as constant in time (see below).
  const auto timeIsConst = (operfunc == func_rmsd);

  // NOTE(review): presumably enforces that the operator takes no arguments - confirm.
  operatorCheckArgc(0);

  const auto streamID1 = cdoOpenRead(0);
  const auto streamID2 = cdoOpenRead(1);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = cdoStreamInqVlist(streamID2);
  // The output variable list is a copy of the one of the first input stream.
  const auto vlistID3 = vlistDuplicate(vlistID1);

  // NOTE(review): presumably aborts when the two variable lists differ - confirm.
  vlistCompare(vlistID1, vlistID2, CMP_ALL);

  VarList varList1, varList2;
  varListInit(varList1, vlistID1);
  varListInit(varList2, vlistID2);

  const auto nvars = vlistNvars(vlistID1);
  auto nrecs = vlistNrecs(vlistID1);
  // Keep the record count of the variable list: nrecs is overwritten by the
  // per-timestep record count inside the read loop.
  const int nrecs3 = nrecs;
  // Variable/level of each record as seen in the first timestep; used to write
  // the output records in the same order.
  std::vector<int> recVarID(nrecs), recLevelID(nrecs);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  // const auto taxisID2 = vlistInqTaxis(vlistID2);
  const auto taxisID3 = taxisDuplicate(taxisID1);

  if (timeIsConst)
    for (varID = 0; varID < nvars; ++varID)
      vlistDefVarTimetype(vlistID3, varID, TIME_CONSTANT);

  vlistDefTaxis(vlistID3, taxisID3);
  const auto streamID3 = cdoOpenWrite(2);
  cdoDefVlist(streamID3, vlistID3);

  // Read buffers large enough for the biggest grid of stream 1.
  const auto gridsizemax = vlistGridsizeMax(vlistID1);
  Varray<double> array1(gridsizemax), array2(gridsizemax);

  // work[varID][levelID][k][cell]: k-th accumulation field;
  // nofvals[varID][levelID][cell]: number of valid value pairs accumulated.
  Varray4D<double> work(nvars);
  Varray3D<size_t> nofvals(nvars);

  // Allocate and zero-initialise all accumulators.
  for (varID = 0; varID < nvars; varID++)
    {
      const auto gridsize = varList1[varID].gridsize;
      const auto nlevs = varList1[varID].nlevels;

      work[varID].resize(nlevs);
      nofvals[varID].resize(nlevs);

      for (levelID = 0; levelID < nlevs; levelID++)
        {
          nofvals[varID][levelID].resize(gridsize, 0);
          work[varID][levelID].resize(nwork);
          for (int i = 0; i < nwork; i++) work[varID][levelID][i].resize(gridsize, 0.0);
        }
    }

  // Accumulation pass: the loop ends when stream 1 reports no more records.
  int tsID = 0;
  while ((nrecs = cdoStreamInqTimestep(streamID1, tsID)))
    {
      vdate = taxisInqVdate(taxisID1);
      vtime = taxisInqVtime(taxisID1);

      // A mismatch only triggers a warning; the record loop below still runs
      // with the record count of stream 1.
      auto nrecs2 = cdoStreamInqTimestep(streamID2, tsID);
      if (nrecs != nrecs2) cdoWarning("Input streams have different number of records!");

      for (int recID = 0; recID < nrecs; recID++)
        {
          // Both calls write to the same variables, so varID/levelID end up
          // holding the values reported for stream 2.
          cdoInqRecord(streamID1, &varID, &levelID);
          cdoInqRecord(streamID2, &varID, &levelID);

          if (tsID == 0)
            {
              recVarID[recID] = varID;
              recLevelID[recID] = levelID;
            }

          const auto gridsize = varList1[varID].gridsize;
          const auto missval1 = varList1[varID].missval;
          const auto missval2 = varList2[varID].missval;

          // The nmiss values returned by the reads are not used; the helpers
          // detect missing values by comparing against missval1/missval2.
          cdoReadRecord(streamID1, &array1[0], &nmiss);
          cdoReadRecord(streamID2, &array2[0], &nmiss);

          auto &rwork = work[varID][levelID];
          auto &rnofvals = nofvals[varID][levelID];

          if (operfunc == func_cor)
            {
              correlation_init(gridsize, array1.data(), array2.data(), missval1, missval2, rnofvals.data(),
                               rwork[0].data(), rwork[1].data(), rwork[2].data(), rwork[3].data(), rwork[4].data());
            }
          else if (operfunc == func_covar)
            {
              covariance_init(gridsize, array1.data(), array2.data(), missval1, missval2, rnofvals.data(),
                              rwork[0].data(), rwork[1].data(), rwork[2].data());
            }
          else if (operfunc == func_rmsd)
            {
              rmsd_init(gridsize, array1.data(), array2.data(), missval1, missval2, rnofvals.data(), rwork[0].data());
            }
        }

      tsID++;
    }

  // Output pass: a single timestep stamped with the last date/time read.
  tsID = 0;
  taxisDefVdate(taxisID3, vdate);
  taxisDefVtime(taxisID3, vtime);
  cdoDefTimestep(streamID3, tsID);

  for (int recID = 0; recID < nrecs3; recID++)
    {
      varID = recVarID[recID];
      levelID = recLevelID[recID];

      const auto gridsize = varList1[varID].gridsize;
      const auto missval1 = varList1[varID].missval;
      const auto missval2 = varList2[varID].missval;

      auto &rwork = work[varID][levelID];
      auto &rnofvals = nofvals[varID][levelID];

      // Each helper leaves its final result in rwork[0] and returns the number
      // of missing values in it.
      if (operfunc == func_cor)
        {
          nmiss = correlation(gridsize, missval1, missval2, rnofvals.data(), rwork[0].data(),
                              rwork[1].data(), rwork[2].data(), rwork[3].data(), rwork[4].data());
        }
      else if (operfunc == func_covar)
        {
          nmiss = covariance(gridsize, missval1, missval2, rnofvals.data(), rwork[0].data(),
                             rwork[1].data(), rwork[2].data());
        }
      else if (operfunc == func_rmsd)
        {
          nmiss = rmsd_compute(gridsize, missval1, rnofvals.data(), rwork[0].data());
        }

      cdoDefRecord(streamID3, varID, levelID);
      cdoWriteRecord(streamID3, rwork[0].data(), nmiss);
    }

  cdoStreamClose(streamID3);
  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  cdoFinish();

  return nullptr;
}
