/*
  This file is part of CDO. CDO is a collection of Operators to manipulate and analyse Climate model Data.

  Author: Uwe Schulzweida

*/

/*
   This module contains the following operators:

      Setmiss    setmissval      Set a new missing value
      Setmiss    setctomiss      Set constant to missing value
      Setmiss    setmisstoc      Set missing value to constant
      Setmiss    setrtomiss      Set range to missing value
      Setmiss    setvrange       Set range of valid value
*/

#include <cmath>
#include <cdi.h>

#include "process_int.h"
#include "param_conversion.h"

template <typename T>
static size_t
set_missval(size_t gridsize, Varray<T> &array, T missval, T new_missval)
{
  // Elements equal to the old missing value, or already equal to the new one, are normalized to
  // new_missval. Both values are also matched in their float-rounded form.
  const float oldMissvalF = static_cast<float>(missval);
  const float newMissvalF = static_cast<float>(new_missval);

  size_t replaced = 0;
  for (size_t idx = 0; idx < gridsize; ++idx)
    {
      const T value = array[idx];
      const bool isOldMiss = DBL_IS_EQUAL(value, missval) || DBL_IS_EQUAL(value, oldMissvalF);
      const bool isNewMiss = DBL_IS_EQUAL(value, new_missval) || DBL_IS_EQUAL(value, newMissvalF);
      if (!isOldMiss && !isNewMiss) continue;

      array[idx] = new_missval;
      ++replaced;
    }

  // Number of elements in [0, gridsize) that were set to new_missval.
  return replaced;
}

static void
set_missval(Field &field, double new_missval)
{
  // setmissval on one record: pick the float or double buffer and store the resulting missing-value count.
  const bool isFloat = (field.memType == MemType::Float);
  field.numMissVals = isFloat ? set_missval(field.size, field.vec_f, static_cast<float>(field.missval), static_cast<float>(new_missval))
                              : set_missval(field.size, field.vec_d, field.missval, new_missval);
}

template <typename T>
static size_t
set_const_to_miss(size_t gridsize, Varray<T> &array, T missval, T rconst)
{
  size_t numMissVals = 0;
  if (std::isnan(rconst))
    {
      for (size_t i = 0; i < gridsize; ++i)
        if (std::isnan(array[i]))
          {
            array[i] = missval;
            numMissVals++;
          }
    }
  else
    {
      for (size_t i = 0; i < gridsize; ++i)
        if (DBL_IS_EQUAL(array[i], rconst) || DBL_IS_EQUAL(array[i], (float) rconst))
          {
            array[i] = missval;
            numMissVals++;
          }
    }

  return numMissVals;
}

static void
set_const_to_miss(Field &field, double rconst)
{
  // setctomiss on one record: the kernel's result is added to the count the field already holds.
  const bool isFloat = (field.memType == MemType::Float);
  const auto added = isFloat ? set_const_to_miss(field.size, field.vec_f, static_cast<float>(field.missval), static_cast<float>(rconst))
                             : set_const_to_miss(field.size, field.vec_d, field.missval, rconst);
  field.numMissVals += added;
}

// setmisstoc kernel: replace every element in [0, gridsize) that equals missval (or missval rounded to
// float) with rconst.
// Returns the number of missing values left afterwards. Normally this is 0. If rconst itself equals
// missval (init() only warns about this case), the replaced elements are still missing and are counted,
// so the field does not report "no missing values" while it still contains them.
template <typename T>
static size_t
set_miss_to_const(size_t gridsize, Varray<T> &array, T missval, T rconst)
{
  const bool constIsMissval = DBL_IS_EQUAL(rconst, missval);

  size_t numMissVals = 0;
  for (size_t i = 0; i < gridsize; ++i)
    if (DBL_IS_EQUAL(array[i], missval) || DBL_IS_EQUAL(array[i], (float) missval))
      {
        array[i] = rconst;
        if (constIsMissval) numMissVals++;
      }

  return numMissVals;
}

static void
set_miss_to_const(Field &field, double rconst)
{
  // setmisstoc on one record: the kernel's result replaces the field's missing-value count.
  const bool isFloat = (field.memType == MemType::Float);
  field.numMissVals = isFloat ? set_miss_to_const(field.size, field.vec_f, static_cast<float>(field.missval), static_cast<float>(rconst))
                              : set_miss_to_const(field.size, field.vec_d, field.missval, rconst);
}

// setrtomiss kernel: replace every element in [0, gridsize) with rmin <= value <= rmax by missval.
// Returns the number of elements that BECAME missing. If missval itself lies inside [rmin, rmax], the
// elements already holding missval are overwritten but not counted. The caller adds this result to the
// field's existing count (NOTE(review): assumes that count already includes elements equal to missval,
// as produced by cdo_read_record), so counting them again would double count.
template <typename T>
static size_t
set_range_to_miss(size_t gridsize, Varray<T> &array, T missval, T rmin, T rmax)
{
  size_t numMissVals = 0;
  for (size_t i = 0; i < gridsize; ++i)
    if (array[i] >= rmin && array[i] <= rmax)
      {
        // An in-range value is never NaN (NaN fails both comparisons above), so plain inequality is an
        // exact "was not already missing" test here, including when missval is NaN.
        if (array[i] != missval) numMissVals++;
        array[i] = missval;
      }

  return numMissVals;
}

static void
set_range_to_miss(Field &field, double rmin, double rmax)
{
  // setrtomiss on one record: the kernel's result is added to the count the field already holds.
  const bool isFloat = (field.memType == MemType::Float);
  const auto added = isFloat ? set_range_to_miss(field.size, field.vec_f, static_cast<float>(field.missval), static_cast<float>(rmin),
                                                 static_cast<float>(rmax))
                             : set_range_to_miss(field.size, field.vec_d, field.missval, rmin, rmax);
  field.numMissVals += added;
}

template <typename T>
static size_t
set_valid_range(size_t gridsize, Varray<T> &array, T missval, T rmin, T rmax)
{
  for (size_t i = 0; i < gridsize; ++i)
    if (array[i] < rmin || array[i] > rmax) array[i] = missval;

  const auto numMissVals = varray_num_mv(gridsize, array, missval);

  return numMissVals;
}

static void
set_valid_range(Field &field, double rmin, double rmax)
{
  // setvrange on one record: the kernel recounts, so its result replaces the field's count.
  const bool isFloat = (field.memType == MemType::Float);
  field.numMissVals = isFloat ? set_valid_range(field.size, field.vec_f, static_cast<float>(field.missval), static_cast<float>(rmin),
                                                static_cast<float>(rmax))
                              : set_valid_range(field.size, field.vec_d, field.missval, rmin, rmax);
}

class Setmiss : public Process
{
public:
  using Process::Process;
  inline static CdoModule module = {
    .name = "Setmiss",
    .operators = { { "setmissval", 0, 0, "missing value", SetmissHelp },
                   { "setctomiss", 0, 0, "constant", SetmissHelp },
                   { "setmisstoc", 0, 0, "constant", SetmissHelp },
                   { "setrtomiss", 0, 0, "range(min,max)", SetmissHelp },
                   { "setvrange", 0, 0, "range(min,max)", SetmissHelp } },
    .aliases = {},
    .mode = EXPOSED,     // Module mode: 0:intern 1:extern
    .number = CDI_REAL,  // Allowed number type
    .constraints = { 1, 1, NoRestriction },
  };
  inline static RegisterEntry<Setmiss> registration = RegisterEntry<Setmiss>(module);
  int SETMISSVAL, SETCTOMISS, SETMISSTOC, SETRTOMISS, SETVRANGE;

  CdoStreamID streamID1;
  CdoStreamID streamID2;
  int taxisID1;
  int taxisID2;

  VarList varList;
  Field field;
  int operatorID;

  double rconst = 0.0;
  double rmin = 0.0;
  double rmax = 0.0;
  double new_missval = 0.0;

public:
  void
  init()
  {

    // clang-format off
SETMISSVAL = module.get_id("setmissval");
SETCTOMISS = module.get_id("setctomiss");
SETMISSTOC = module.get_id("setmisstoc");
SETRTOMISS = module.get_id("setrtomiss");
SETVRANGE = module.get_id("setvrange");
    // clang-format on

    operatorID = cdo_operator_id();

    if (operatorID == SETMISSVAL)
      {
        operator_check_argc(1);
        new_missval = parameter_to_double(cdo_operator_argv(0));
      }
    else if (operatorID == SETCTOMISS || operatorID == SETMISSTOC)
      {
        operator_check_argc(1);
        rconst = parameter_to_double(cdo_operator_argv(0));
      }
    else
      {
        operator_check_argc(2);
        rmin = parameter_to_double(cdo_operator_argv(0));
        rmax = parameter_to_double(cdo_operator_argv(1));
      }

    streamID1 = cdo_open_read(0);

    const auto vlistID1 = cdo_stream_inq_vlist(streamID1);
    const auto vlistID2 = vlistDuplicate(vlistID1);

    taxisID1 = vlistInqTaxis(vlistID1);
    taxisID2 = taxisDuplicate(taxisID1);
    vlistDefTaxis(vlistID2, taxisID2);

    varList_init(varList, vlistID1);

    if (operatorID == SETMISSVAL)
      {
        const auto nvars = vlistNvars(vlistID2);
        for (int varID = 0; varID < nvars; ++varID) vlistDefVarMissval(vlistID2, varID, new_missval);
      }
    else if (operatorID == SETMISSTOC)
      {
        auto nvars = vlistNvars(vlistID2);
        for (int varID = 0; varID < nvars; ++varID)
          {
            const auto &var = varList[varID];
            if (DBL_IS_EQUAL(rconst, var.missval))
              {
                cdo_warning("Missing value and constant have the same value!");
                break;
              }
          }
      }

    /*
    if (operatorID == SETVRANGE)
      {
        double range[2] = {rmin, rmax};
        nvars = vlistNvars(vlistID2);
        for (varID = 0; varID < nvars; ++varID)
          cdiDefAttFlt(vlistID2, varID, "valid_range", CDI_DATATYPE_FLT64, 2, range);
      }
    */
    streamID2 = cdo_open_write(1);
    cdo_def_vlist(streamID2, vlistID2);
  }
  void
  run()
  {
    int tsID = 0;
    while (true)
      {
        const auto nrecs = cdo_stream_inq_timestep(streamID1, tsID);
        if (nrecs == 0) break;

        cdo_taxis_copy_timestep(taxisID2, taxisID1);
        cdo_def_timestep(streamID2, tsID);

        for (int recID = 0; recID < nrecs; ++recID)
          {
            int varID, levelID;
            cdo_inq_record(streamID1, &varID, &levelID);
            const auto &var = varList[varID];
            field.init(var);
            cdo_read_record(streamID1, field);

            // clang-format off
          if      (operatorID == SETMISSVAL) set_missval(field, new_missval);
          else if (operatorID == SETCTOMISS) set_const_to_miss(field, rconst);
          else if (operatorID == SETMISSTOC) set_miss_to_const(field, rconst);
          else if (operatorID == SETRTOMISS) set_range_to_miss(field, rmin, rmax);
          else if (operatorID == SETVRANGE)  set_valid_range(field, rmin, rmax);
            // clang-format on

            cdo_def_record(streamID2, varID, levelID);
            cdo_write_record(streamID2, field);
          }

        tsID++;
      }
  }
  void
  close()
  {
    cdo_stream_close(streamID2);
    cdo_stream_close(streamID1);
  }
};
