/*=========================================================================

  Program:   Insight Segmentation & Registration Toolkit
  Module:    $RCSfile: itkWeightedMeanSampleFilterTest.cxx,v $
  Language:  C++
  Date:      $Date: 2009-05-12 14:30:19 $
  Version:   $Revision: 1.2 $

  Copyright (c) Insight Software Consortium. All rights reserved.
  See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even 
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#if defined(_MSC_VER)
#pragma warning ( disable : 4786 )
#endif

#include "itkWeightedMeanSampleFilter.h"
#include "itkListSample.h"
#include "itkFixedArray.h"
#include "itkFunctionBase.h"

#include <cmath>
#include <cstdlib>
#include <iostream>

// Number of components in every measurement vector used by this test.
const unsigned int MeasurementVectorSize = 2;

// Fixed-length float vector holding a single measurement.
typedef itk::FixedArray< float, MeasurementVectorSize > MeasurementVectorType;

/** Weighting function used by the test below.
 *
 * Evaluate() maps a measurement vector to a weight: the vector whose
 * components are all 2 gets weight 1.0, every other vector gets 0.5. */
class WeightedMeanTestFunction :
  public itk::FunctionBase< MeasurementVectorType, double >
{
public:
  /** Standard class typedefs. */
  typedef WeightedMeanTestFunction                            Self;
  typedef itk::FunctionBase< MeasurementVectorType, double >  Superclass;
  typedef itk::SmartPointer< Self >                           Pointer;
  typedef itk::SmartPointer< const Self >                     ConstPointer;

  /** Standard macros. */
  itkTypeMacro(WeightedMeanTestFunction, FunctionBase);
  itkNewMacro(Self);

  /** Type of the argument passed to Evaluate(). */
  typedef MeasurementVectorType InputType;

  /** Type of the weight returned by Evaluate(). */
  typedef double OutputType;

  /** Return the weight associated with \c input:
   *  1.0 when \c input equals [2, 2], 0.5 otherwise. */
  OutputType Evaluate( const InputType& input ) const
    {
    // Reference vector with every component set to 2.
    InputType reference;
    reference.Fill(2.0f);

    const bool isReference = ( input == reference );
    return isReference ? 1.0 : 0.5;
    }

protected:
  WeightedMeanTestFunction() {}
  ~WeightedMeanTestFunction() {}
}; // end of class


int itkWeightedMeanSampleFilterTest(int, char* [] ) 
{
  std::cout << "WeightedMeanSampleFilter test \n \n";

  const unsigned int                  numberOfMeasurementVectors = 5;
  unsigned int                        counter;

  typedef itk::FixedArray< 
    float, MeasurementVectorSize >             MeasurementVectorType;

  typedef itk::Statistics::ListSample< 
    MeasurementVectorType >                    SampleType;

  SampleType::Pointer sample = SampleType::New();

  sample->SetMeasurementVectorSize( MeasurementVectorSize ); 

  MeasurementVectorType               measure;
  
  //reset counter
  counter = 0;

  while ( counter < numberOfMeasurementVectors ) 
    {
    for( unsigned int i=0; i<MeasurementVectorSize; i++)
      {
      measure[i] = counter;
      }
    sample->PushBack( measure );
    counter++;
    }

  typedef itk::Statistics::WeightedMeanSampleFilter< SampleType > 
    FilterType;

  FilterType::Pointer filter = FilterType::New();

  std::cout << filter->GetNameOfClass() << std::endl;
  filter->Print(std::cout);

  //Invoke update before adding an input. An exception should be 
  //thrown.
  try
    {
    filter->Update();
    std::cerr << "Exception should have been thrown since \
                    Update() is invoked without setting an input " << std::endl;
    return EXIT_FAILURE;
    }
  catch ( itk::ExceptionObject & excp )
    {
  std::cerr << "Exception caught: " << excp << std::endl;
  }

  if ( filter->GetInput() != NULL )
    {
    std::cerr << "GetInput() should return NULL if the input \
                     has not been set" << std::endl;
    return EXIT_FAILURE;
    }

  filter->ResetPipeline();
  filter->SetInput( sample );

  //run the filters without weighting coefficients
  try
    {
    filter->Update();
    }
  catch ( itk::ExceptionObject & excp )
    {
    std::cerr << "Exception caught: " << excp << std::endl;
    }
 
  const FilterType::MeasurementVectorDecoratedType * decorator = filter->GetOutput();
  FilterType::MeasurementVectorType    meanOutput  = decorator->Get();

  FilterType::MeasurementVectorType mean;

  mean[0] = 2.0;
  mean[1] = 2.0;

 FilterType::MeasurementVectorType::ValueType    epsilon = 1e-6; 

  if ( ( fabs( meanOutput[0] - mean[0]) > epsilon )  || 
       ( fabs( meanOutput[1] - mean[1]) > epsilon ))
    {
    std::cerr << "Wrong result " << std::endl;
    std::cerr << meanOutput[0] << " " << mean[0] << " "  
            << meanOutput[1] << " " << mean[1] << " " << std::endl;  
    std::cerr << "The result is not what is expected" << std::endl;
    return EXIT_FAILURE;
    }
 
  typedef FilterType::WeightArrayType  WeightArrayType;
  WeightArrayType weightArray(sample->Size());
  weightArray.Fill(1.0);

  filter->SetWeights( weightArray );

  try
    {
    filter->Update();
    }
  catch ( itk::ExceptionObject & excp )
    {
    std::cerr << "Exception caught: " << excp << std::endl;
    }
 
  decorator = filter->GetOutput();
  meanOutput  = decorator->Get();

  mean[0] = 2.0;
  mean[1] = 2.0;

  if ( ( fabs( meanOutput[0] - mean[0]) > epsilon )  || 
       ( fabs( meanOutput[1] - mean[1]) > epsilon ))
    {
    std::cerr << "Wrong result " << std::endl;
    std::cerr << meanOutput[0] << " " << mean[0] << " "  
            << meanOutput[1] << " " << mean[1] << " " << std::endl;  
    std::cerr << "The result is not what is expected" << std::endl;
    return EXIT_FAILURE;
    }

  //change the weight of the last element to 0.5 and recompute
  weightArray[numberOfMeasurementVectors - 1] = 0.5;
  filter->SetWeights( weightArray );

  try
    {
    filter->Update();
    }
  catch ( itk::ExceptionObject & excp )
    {
    std::cerr << "Exception caught: " << excp << std::endl;
    }
 
  decorator = filter->GetOutput();
  meanOutput  = decorator->Get();

  mean[0] = 1.7777778;
  mean[1] = 1.7777778;

  if ( ( fabs( meanOutput[0] - mean[0]) > epsilon )  || 
       ( fabs( meanOutput[1] - mean[1]) > epsilon ))
    {
    std::cerr << "Wrong result" << std::endl;
    std::cerr << meanOutput[0] << " " << mean[0] << " "  
            << meanOutput[1] << " " << mean[1] << " " << std::endl;  

    std::cerr << "The result is not what is expected" << std::endl;
    return EXIT_FAILURE;
    }
 
  //set the weight using a function
  WeightedMeanTestFunction::Pointer weightFunction = WeightedMeanTestFunction::New(); 
  filter->SetWeightingFunction( weightFunction.GetPointer() );

  try
    {
    filter->Update();
    }
  catch ( itk::ExceptionObject & excp )
    {
    std::cerr << "Exception caught: " << excp << std::endl;
    }
 
  decorator = filter->GetOutput();
  meanOutput  = decorator->Get();

  mean[0] = 2.0;
  mean[1] = 2.0;

  if ( ( fabs( meanOutput[0] - mean[0]) > epsilon )  || 
       ( fabs( meanOutput[1] - mean[1]) > epsilon ))
    {
    std::cerr << "Wrong result" << std::endl;
    std::cerr << meanOutput[0] << " " << mean[0] << " "  
            << meanOutput[1] << " " << mean[1] << " " << std::endl;  
    std::cerr << "The result is not what is expected" << std::endl;
    return EXIT_FAILURE;
    }
 
  std::cout << "Test passed." << std::endl;
  return EXIT_SUCCESS;
}
