/*
  This file is part of CDO. CDO is a collection of Operators to manipulate and analyse Climate model Data.

  Author: Uwe Schulzweida

*/

/*
   This module contains the following operators:

        Timstat3        varquot2test
        Timstat3        meandiff2test
*/

#include <cdi.h>

#include "process_int.h"
#include "cdo_vlist.h"
#include "param_conversion.h"
#include "statistic.h"

// NIN:    number of input streams compared by the tests.
// NOUT:   number of output fields.
// NFWORK: double accumulators per variable/level: sum and sum of squares for each input stream.
// NIWORK: size_t accumulators per variable/level: number of accumulated values for each input stream.
#define NIN 2
#define NOUT 1
#define NFWORK 4
#define NIWORK 2

/*
  Variance-quotient test (operator varquot2test) for one field of gridsize points.

  fwork: accumulated sums of the two input series per grid point:
           fwork[0] = sum(x0), fwork[1] = sum(x0^2), fwork[2] = sum(x1), fwork[3] = sum(x1^2)
  iwork: number of accumulated values per grid point: iwork[0] = n0, iwork[1] = n1

  With Sk = sum(xk^2) - sum(xk)^2 / nk, the statistic S0 / (S0 + rconst * S1) is compared with the two
  fractiles returned by cdo::beta_distr_constants for the parameters (n0-1)/2 and (n1-1)/2 at level 1-risk.

  out[i] = 1 if the statistic lies on or outside the fractile interval, 0 if it lies inside, and missval if
  the statistic is missing or either series has fewer than two values (no fractiles available).
*/
static void
varquot2test(double rconst, double risk, size_t gridsize, double missval, const Varray2D<double> &fwork,
             const Varray2D<size_t> &iwork, Varray<double> &out)
{
  // NOTE(review): missval1/missval2 are presumably referenced implicitly by the *MN arithmetic macros
  // (defined outside this file) - keep both even though missval2 is not used directly here.
  auto missval1 = missval;
  auto missval2 = missval;

  for (size_t i = 0; i < gridsize; ++i)
    {
      double fnvals0 = iwork[0][i];
      double fnvals1 = iwork[1][i];

      auto temp0 = DIVMN(MULMN(fwork[0][i], fwork[0][i]), fnvals0);
      auto temp1 = DIVMN(MULMN(fwork[2][i], fwork[2][i]), fnvals1);
      auto temp2 = SUBMN(fwork[1][i], temp0);
      auto temp3 = SUBMN(fwork[3][i], temp1);
      auto statistic = DIVMN(temp2, ADDMN(temp2, MULMN(rconst, temp3)));

      // Without at least two values per series there are no fractiles; comparing against the missing-value
      // placeholder would otherwise report a (spurious) significant result.
      if (fnvals0 <= 1 || fnvals1 <= 1 || dbl_is_equal(statistic, missval1))
        {
          out[i] = missval1;
          continue;
        }

      auto fractil_1 = missval1, fractil_2 = missval1;
      cdo::beta_distr_constants((fnvals0 - 1) / 2, (fnvals1 - 1) / 2, 1 - risk, &fractil_1, &fractil_2);

      out[i] = (statistic <= fractil_1 || statistic >= fractil_2);
    }
}

static void
meandiff2test(double rconst, double risk, size_t gridsize, double missval, const Varray2D<double> &fwork,
              const Varray2D<size_t> iwork, Varray<double> &out)
{
  auto missval1 = missval;
  auto missval2 = missval;

  double meanFactor[] = { 1.0, -1.0 };
  double varFactor[] = { 1.0, 1.0 };

  for (size_t i = 0; i < gridsize; ++i)
    {
      double temp0 = 0.0;
      double degOfFreedom = -NIN;
      for (int j = 0; j < NIN; ++j)
        {
          double fnvals = iwork[j][i];
          auto tmp = DIVMN(MULMN(fwork[2 * j][i], fwork[2 * j][i]), fnvals);
          temp0 = ADDMN(temp0, DIVMN(SUBMN(fwork[2 * j + 1][i], tmp), varFactor[j]));
          degOfFreedom = ADDMN(degOfFreedom, fnvals);
        }

      if (!dbl_is_equal(temp0, missval1) && temp0 < 0) temp0 = 0;  // This is possible because of rounding errors

      auto stddev_estimator = SQRTMN(DIVMN(temp0, degOfFreedom));
      auto mean_estimator = -rconst;
      for (int j = 0; j < NIN; ++j)
        {
          double fnvals = iwork[j][i];
          mean_estimator = ADDMN(mean_estimator, MULMN(meanFactor[j], DIVMN(fwork[2 * j][i], fnvals)));
        }

      double temp1 = 0.0;
      for (int j = 0; j < NIN; ++j)
        {
          double fnvals = iwork[j][i];
          temp1 = ADDMN(temp1, DIVMN(MUL(MUL(meanFactor[j], meanFactor[j]), varFactor[j]), fnvals));
        }

      auto norm = SQRTMN(temp1);

      auto temp2 = DIVMN(DIVMN(mean_estimator, norm), stddev_estimator);
      auto fractil = (degOfFreedom < 1) ? missval1 : cdo::student_t_inv(degOfFreedom, 1 - risk / 2);

      out[i] = (dbl_is_equal(temp2, missval1) || dbl_is_equal(fractil, missval1)) ? missval1 : (std::fabs(temp2) >= fractil);
    }
}

class ModuleTimstat3
{
  CdiDateTime vDateTime{};
  int vlistID[NIN], vlistID2 = -1;
  Varray4D<double> fwork;
  Varray4D<size_t> iwork;
  int reached_eof[NIN];

  CdoStreamID streamID[NIN];
  int taxisID1;

  CdoStreamID streamID3;
  int taxisID3;

  double rconst;
  double risk;

  int maxrecs;
  size_t gridsizemax;

  int operatorID;

  Field in[NIN], out[NOUT];

  std::vector<RecordInfo> recList;
  VarList varList0;

  int VARQUOT2TEST, MEANDIFF2TEST;

public:
  void
  init(void *process)
  {
    cdo_initialize(process);

    // clang-format off
    VARQUOT2TEST  = cdo_operator_add("varquot2test",  0, 0, nullptr);
    MEANDIFF2TEST = cdo_operator_add("meandiff2test", 0, 0, nullptr);
    // clang-format on

    operatorID = cdo_operator_id();

    operator_input_arg("constant and risk (e.g. 0.05)");
    operator_check_argc(2);
    rconst = parameter_to_double(cdo_operator_argv(0));
    risk = parameter_to_double(cdo_operator_argv(1));

    if (rconst <= 0) cdo_abort("Constant must be positive!");
    if (risk <= 0 || risk >= 1) cdo_abort("Risk must be greater than 0 and lower than 1!");

    for (int is = 0; is < NIN; ++is)
      {
        streamID[is] = cdo_open_read(is);
        vlistID[is] = cdo_stream_inq_vlist(streamID[is]);
        if (is > 0)
          {
            vlistID2 = cdo_stream_inq_vlist(streamID[is]);
            vlist_compare(vlistID[0], vlistID2, CmpVlist::All);
          }
      }

    auto vlistID3 = vlistDuplicate(vlistID[0]);

    gridsizemax = vlistGridsizeMax(vlistID[0]);
    auto nvars = vlistNvars(vlistID[0]);

    varList_init(varList0, vlistID[0]);

    maxrecs = vlistNrecs(vlistID[0]);
    recList = std::vector<RecordInfo>(maxrecs);

    taxisID1 = vlistInqTaxis(vlistID[0]);
    taxisID3 = taxisDuplicate(taxisID1);

    vlistDefTaxis(vlistID3, taxisID3);
    streamID3 = cdo_open_write(2);
    cdo_def_vlist(streamID3, vlistID3);

    for (int i = 0; i < NIN; ++i) reached_eof[i] = 0;

    for (int i = 0; i < NIN; ++i) in[i].resize(gridsizemax);
    for (int i = 0; i < NOUT; ++i) out[i].resize(gridsizemax);

    for (int iw = 0; iw < NFWORK; ++iw) fwork.resize(nvars);
    for (int iw = 0; iw < NIWORK; ++iw) iwork.resize(nvars);

    for (int varID = 0; varID < nvars; ++varID)
      {
        const auto &var = varList0[varID];
        auto gridsize = var.gridsize;
        auto nlevels = var.nlevels;

        for (int iw = 0; iw < NFWORK; ++iw) fwork[varID].resize(nlevels);
        for (int iw = 0; iw < NIWORK; ++iw) iwork[varID].resize(nlevels);

        for (int levelID = 0; levelID < nlevels; ++levelID)
          {
            fwork[varID][levelID].resize(NFWORK);
            iwork[varID][levelID].resize(NIWORK);
            for (int iw = 0; iw < NFWORK; ++iw) fwork[varID][levelID][iw].resize(gridsize);
            for (int iw = 0; iw < NIWORK; ++iw) iwork[varID][levelID][iw].resize(gridsize, 0);
          }
      }
  }

  void
  run()
  {
    int tsID = 0;
    while (true)
      {
        int is;
        for (is = 0; is < NIN; ++is)
          {
            if (reached_eof[is]) continue;

            auto nrecs = cdo_stream_inq_timestep(streamID[is], tsID);
            if (nrecs == 0)
              {
                reached_eof[is] = 1;
                continue;
              }

            vDateTime = taxisInqVdatetime(taxisID1);

            for (int recID = 0; recID < nrecs; ++recID)
              {
                int varID, levelID;
                cdo_inq_record(streamID[is], &varID, &levelID);

                auto gridsize = gridInqSize(vlistInqVarGrid(vlistID[is], varID));

                in[is].missval = vlistInqVarMissval(vlistID[is], varID);

                if (tsID == 0 && is == 0) recList[recID].set(varID, levelID);

                cdo_read_record(streamID[is], in[is].vec_d.data(), &in[is].nmiss);

                auto &rfwork = fwork[varID][levelID];
                auto &riwork = iwork[varID][levelID];
                for (size_t i = 0; i < gridsize; ++i)
                  {
                    // if ( ( ! dbl_is_equal(array1[i], missval1) ) && ( ! dbl_is_equal(array2[i], missval2) ) )
                    {
                      rfwork[NIN * is + 0][i] += in[is].vec_d[i];
                      rfwork[NIN * is + 1][i] += in[is].vec_d[i] * in[is].vec_d[i];
                      riwork[is][i]++;
                    }
                  }
              }
          }

        for (is = 0; is < NIN; ++is)
          if (!reached_eof[is]) break;

        if (is == NIN) break;

        tsID++;
      }

    taxisDefVdatetime(taxisID3, vDateTime);
    cdo_def_timestep(streamID3, 0);

    for (int recID = 0; recID < maxrecs; ++recID)
      {
        auto [varID, levelID] = recList[recID].get();

        const auto &var = varList0[varID];

        auto &rfwork = fwork[varID][levelID];
        auto &riwork = iwork[varID][levelID];

        if (operatorID == VARQUOT2TEST) { varquot2test(rconst, risk, var.gridsize, var.missval, rfwork, riwork, out[0].vec_d); }
        else if (operatorID == MEANDIFF2TEST)
          {
            meandiff2test(rconst, risk, var.gridsize, var.missval, rfwork, riwork, out[0].vec_d);
          }

        out[0].missval = var.missval;
        cdo_def_record(streamID3, varID, levelID);
        cdo_write_record(streamID3, out[0].vec_d.data(), field_num_miss(out[0]));
      }
  }

  void
  close()
  {
    cdo_stream_close(streamID3);
    for (int is = 0; is < NIN; ++is) cdo_stream_close(streamID[is]);

    cdo_finish();
  }
};

void *
Timstat3(void *process)
{
  ModuleTimstat3 timstat3;
  timstat3.init(process);
  timstat3.run();
  timstat3.close();

  return nullptr;
}
