/*-------------------------------------------------------------------------------
This file is part of Ranger.
    
Ranger is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Ranger is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Ranger. If not, see <http://www.gnu.org/licenses/>.

Written by: 

Marvin N. Wright
Institut für Medizinische Biometrie und Statistik
Universität zu Lübeck
Ratzeburger Allee 160
23562 Lübeck 

http://www.imbs-luebeck.de
wright@imbs.uni-luebeck.de
#-------------------------------------------------------------------------------*/

#ifndef TREESURVIVAL_H_
#define TREESURVIVAL_H_

#include "globals.h"
#include "Tree.h"

// Survival-outcome specialisation of Tree. On top of whatever Tree provides,
// it carries the ID of the status variable, a pointer to the grid of unique
// time points and, per node, a CHF (presumably "cumulative hazard function")
// vector. Only the two small inline helpers are defined in this header; all
// other members are defined out of line.
class TreeSurvival: public Tree {
public:
  // Construct a tree from the time point grid and the status variable ID.
  // NOTE(review): presumably these are stored in the members of the same
  // name -- confirm in the .cpp.
  TreeSurvival(std::vector<double>* unique_timepoints, size_t status_varID);

  // Construct a tree from the data of a previously saved (loaded) forest.
  // Note that chf is taken by value, so the argument is copied (or moved)
  // into the parameter at the call site.
  TreeSurvival(std::vector<std::vector<size_t>>& child_nodeIDs, std::vector<size_t>& split_varIDs,
      std::vector<double>& split_values, std::vector<std::vector<double>> chf, std::vector<double>* unique_timepoints);

  virtual ~TreeSurvival();

  // NOTE(review): the three functions below look like implementations of
  // hooks declared in Tree -- confirm against Tree.h.
  void addPrediction(size_t nodeID, size_t sampleID);
  void appendToFileInternal(std::ofstream& file);
  void computePermutationImportanceInternal(std::vector<std::vector<size_t>>* permutations);

  // Read-only access to the per-node CHF table (see the chf member below).
  const std::vector<std::vector<double>>& getChf() const {
    return this->chf;
  }

private:
  bool splitNodeInternal(size_t nodeID, std::unordered_set<size_t>& possible_split_varIDs);
  void createEmptyNodeInternal();

  double computePredictionAccuracyInternal();

  // TODO: Remove the commented-out declaration below once it is confirmed unused.
  //double predictOobPermuted(size_t permuted_varID, std::vector<std::vector<size_t>>* permutations);

  // Split search routines invoked from splitNodeInternal(). They set
  // split_varIDs and split_values.
  bool findBestSplitLogRank(size_t nodeID, std::unordered_set<size_t>& possible_split_varIDs);
  bool findBestSplitAUC(size_t nodeID, std::unordered_set<size_t>& possible_split_varIDs);

  // Helpers for the log-rank split criterion. The size_t* parameters are raw
  // arrays supplied by the caller.
  // NOTE(review): the required array lengths are not visible from this
  // header (presumably one entry per time point) -- confirm in the .cpp.
  void computeDeathCounts(size_t* num_deaths, size_t* num_samples_at_risk, size_t& num_unique_death_times,
      size_t nodeID);
  double computeLogRankTest(size_t nodeID, size_t varID, double split_value, size_t* num_deaths,
      size_t* num_samples_at_risk, size_t* num_deaths_left_child, size_t* num_samples_at_risk_left_child,
      size_t num_unique_death_times);
  void computeChildDeathCounts(size_t nodeID, size_t varID, double split_value, size_t* num_deaths_left_child,
      size_t* num_samples_at_risk_left_child, size_t* num_samples_left_child);

  // Helper for the AUC split criterion.
  double computeAucSplit(size_t nodeID, size_t varID, double split_value);

  // Size the prediction storage: `predictions` (declared outside this header,
  // presumably in Tree) ends up with num_predictions rows, each of length
  // num_timepoints. Elements created by the resize are 0; elements that
  // already existed keep their value.
  void reservePredictionMemory(size_t num_predictions) {
    // Step 1: adjust the number of rows, new rows start out empty.
    predictions.resize(num_predictions, std::vector<double>());

    // Step 2: give every row one slot per time point.
    for (auto row = predictions.begin(); row != predictions.end(); ++row) {
      row->resize(num_timepoints, 0);
    }
  }

  // ID of the variable holding the status.
  size_t status_varID;

  // Sorted unique time points over all individuals, not just those in this
  // tree's bootstrap sample.
  // NOTE(review): presumably a non-owning pointer shared between trees --
  // confirm against the destructor.
  std::vector<double>* unique_timepoints;
  // NOTE(review): presumably unique_timepoints->size() -- confirm where it is set.
  size_t num_timepoints;

  // CHF per node: terminal nodes hold one value for each unique time point,
  // all other nodes hold an empty vector.
  std::vector<std::vector<double>> chf;

  // NOTE(review): project macro, presumably defined in globals.h; by its name
  // it disables copy construction and copy assignment.
  DISALLOW_COPY_AND_ASSIGN(TreeSurvival);
};

#endif /* TREESURVIVAL_H_ */
