/*
 * Copyright (c) 1997 Massachusetts Institute of Technology
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to use, copy, modify, and distribute the Software without
 * restriction, provided the Software, including any modified copies made
 * under this license, is not distributed for a fee, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE MASSACHUSETTS INSTITUTE OF TECHNOLOGY BE LIABLE
 * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
 * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * Except as contained in this notice, the name of the Massachusetts
 * Institute of Technology shall not be used in advertising or otherwise
 * to promote the sale, use or other dealings in this Software without
 * prior written authorization from the Massachusetts Institute of
 * Technology.
 *
 */

#ifndef BENCH_UTILS_H
#define BENCH_UTILS_H

#include <bench-config.h>

#include <fftw.h>
#include <fftw-int.h>

#include <stdio.h>
#include <string.h>

/* FFT selection: the name of the single FFT to benchmark ("" means all
   of them) and its index (-1 when which_fft is ""). */
extern char which_fft[80];
extern int which_fft_index;

/* Output state shared by the benchmark programs. */
extern short bench_echo_dat;
extern FILE *bench_log_file;
extern FILE *bench_dat_file;

typedef struct bench_fft_data_struct bench_fft_data;

/* Bookkeeping record for one FFT implementation; records are chained
   through `next`. */
struct bench_fft_data_struct {
     char *name;
     int index;
     double cur_mflops;
     double norm_avg;
     int num_sizes;
     bench_fft_data *next;
};

/* fft_data_top: presumably the first record of the chain (confirm in the
   defining .c file).  fft_data_cur: the currently selected record; the
   benchmark macros do nothing while it is a null pointer. */
extern bench_fft_data *fft_data_top;
extern bench_fft_data *fft_data_cur;

/* printf-style output routines taking a format string plus arguments.
   Under GCC-compatible compilers the format attribute lets the compiler
   type-check each call's arguments against its format string. */
#ifdef __GNUC__
void log_printf(const char *fmt, ...) __attribute__ ((format (printf, 1, 2)));
void dat_printf(const char *fmt, ...) __attribute__ ((format (printf, 1, 2)));
#else
void log_printf(const char *fmt, ...);
void dat_printf(const char *fmt, ...);
#endif

/* Which transform sizes a benchmark run should include. */
typedef enum {
     POWERS_OF_TWO_ONLY = 0,     /* only power-of-two sizes */
     NON_POWERS_OF_TWO_ONLY = 1, /* only sizes that are not powers of two */
     ALL_FACTORS = 2             /* sizes of either kind */
} factor_type;

/* Benchmark drivers.  The bench_* variants work on complex data and the
   rbench_* variants on real data (matching the FFTW_COMPLEX / FFTW_REAL
   array helpers declared below); the suffix is the dimensionality. */
short bench_1d(short compute_accuracy, factor_type allowed_factors,
               int which_N, double max_MB);
short bench_3d(short compute_accuracy, factor_type allowed_factors,
               int which_N, double max_MB);
short rbench_1d(short compute_accuracy, factor_type allowed_factors,
                int which_N, double max_MB);
short rbench_2d(short compute_accuracy, factor_type allowed_factors,
                int which_N, double max_MB);

/* Complex-array helpers; N is the number of elements in arr. */
void bench_init_array(FFTW_COMPLEX *arr, int N);
void bench_init_array_for_check(FFTW_COMPLEX *arr, int N);
double bench_check_array(FFTW_COMPLEX *arr, int N, double scale);

/* Real-array counterparts of the complex helpers. */
void rbench_init_array(FFTW_REAL *arr, int N);
void rbench_init_array_for_check(FFTW_REAL *arr, int N);
double rbench_check_array(FFTW_REAL *arr, int N, double scale);

/* Conjugate / copy complex arrays of n elements; reim_alt is 1 for
   interleaved and 0 for split real/imaginary storage. */
void bench_conjugate_array(FFTW_COMPLEX *arr, int n, short reim_alt);
void bench_copy_array(FFTW_COMPLEX *from_arr, FFTW_COMPLEX *to_arr, int n);

/*
 * FORTRANIZE(x, X) maps a C-side routine name to the corresponding
 * FORTRAN symbol name.  x is the lower-case spelling and X the all-caps
 * spelling.  The platform tests are applied in this order:
 *
 *   Cray (CRAY, _UNICOS, _CRAYMPP)   -> X    (upper case)
 *   Solaris (SOLARIS)                -> x_   (lower case + underscore)
 *   RS/6000 (IBM6000, _AIX)          -> x    (lower case, no underscore)
 *   anything else                    -> x_   (lower case + underscore)
 */
#if defined(CRAY) || defined(_UNICOS) || defined(_CRAYMPP)
#  define FORTRANIZE(lower, UPPER) UPPER
#elif defined(SOLARIS)
#  define FORTRANIZE(lower, UPPER) lower##_
#elif defined(IBM6000) || defined(_AIX)
#  define FORTRANIZE(lower, UPPER) lower
#else
#  define FORTRANIZE(lower, UPPER) lower##_
#endif

/* WHEN_FORTRAN(expr) expands to expr when Fortran support is configured
   (HAVE_F77 defined) and to the constant 0 otherwise; in the latter case
   the argument is discarded entirely. */
#if defined(HAVE_F77)
#  define WHEN_FORTRAN(expr) expr
#else
#  define WHEN_FORTRAN(expr) 0
#endif

/* Nonzero iff non-negative n has at most one bit set, i.e. n is a power
   of two or zero.  n is evaluated more than once: avoid side effects. */
#define IS_POWER_OF_TWO(n) (((n) & ((n) - 1)) == 0)

/* NOTE(review): error threshold for accuracy checks; not referenced in
   this header, confirm its meaning at the use sites. */
#define MAX_BENCH_ERROR 0.1


/* Older FFTW releases provide no fftw_time_diff; fall back to plain
   subtraction (end minus start). */
#if !defined(FFTW_HAS_TIME_DIFF)
#  define fftw_time_diff(end, start) ((end) - (start))
#endif

/* FFT registration: enable/disable the current FFT, mark it skipped
   with a reason message, or register it by name. */
void set_fft_enabled(short enabled);
void set_fft_skip(short skip, const char *message);
short set_fft_name(const char *name, int for_computation);
void skip_benchmark(const char *why);

/* Result reporting and cleanup. */
void output_mean_error(double mean_error, int for_check_only);
void output_results(double t, int iters, int real_N, double mflops_scale);
void destroy_fft_data(void);

/* Guard for FFTs that only accept power-of-two sizes.  Must be expanded
   where `allowed_factors` (compared against NON_POWERS_OF_TWO_ONLY), `N`
   and `is_power_of_two` are in scope.  In a NON_POWERS_OF_TWO_ONLY run
   the FFT is disabled outright; otherwise a positive N that is not a
   power of two marks the FFT as skipped.  (No line continuation after
   `while (0)`: a trailing backslash would splice the next source line
   into this macro.) */
#define FFT_REQUIRE_POWER_OF_TWO do { \
     if (allowed_factors == NON_POWERS_OF_TWO_ONLY) \
          set_fft_enabled(0); \
     else if (N > 0 && !is_power_of_two) \
          set_fft_skip(1, "requires power of two size"); \
} while (0)

/* Register the FFT called name_string via set_fft_name, passing
   for_computation = (N != 0); expects N in the expanding scope. */
#define FFT_NAME(name_string) set_fft_name((name_string), (N) != 0)

/* Nonzero when an FFT record is currently selected (fft_data_cur is
   not a null pointer). */
#define FFT_OK (fft_data_cur != NULL)

/* Normalized-average bookkeeping. */
void init_normalized_averages(void);
void compute_normalized_averages(void);
void output_normalized_averages(void);

/* Reference transform used by DO_BENCHMARK_ND_AUX to validate results
   when reim_alt is not -1. */
void do_standard_fft(FFTW_COMPLEX *arr, int rank, int *size, int sign, int reim_alt);

/* Minimum duration, in seconds, of a timed batch: the benchmark loop
   keeps doubling its iteration count until a batch lasts this long. */
extern double min_bench_time;

/* Macro to do the ND benchmark.

   Parameters:
   bench = prefix of the array helpers to call: bench (complex data,
           bench_init_array etc.) or rbench (real data, rbench_init_array
           etc.); it is token-pasted, so it must be a bare identifier
   nffts = number of transforms executed by one run of bench_code; the
           measured time is divided by nffts before being passed to
           output_results
   mflops_scale = scale factor passed through to output_results
   rank = rank (dimensionality) of transform
   size[rank] = size of transform (array of integers); if product of
                sizes is zero, then print name of transform instead
   real_N = N to normalize time by (usually should be product of
            sizes, but may be different for CWP); the macro does nothing
            unless real_N > 0 and fft_data_cur is non-null
   in = input array
   out = output array (can equal in)
   bench_code = statement(s) executed in the timing loop
   fft = command to transform "in"
   scale_1 = scale factor to multiply by after fft to get unnormalized FFT
   sign = sign (+1/-1) in exponent of fft
   reim_alt = 1 if real & imag. parts are stored in alternating order,
              and 0 if they are stored separately (in consecutive, contiguous
              chunks)...this is used to check the correctness of the FFT
              using GO; set to -1 if you don't want to check at all
   ifft = command to perform inverse of fft, putting results in "in"
   scale = scale factor to multiply by after ifft to get true inverse
   acc = 0 to measure timing, nonzero to measure accuracy: 1 if ifft is
         available (fft followed by ifft is checked against the input),
         -1 otherwise (a second, conjugated fft stands in for the inverse)

   Timing mode runs bench_code once, then repeatedly doubles the
   iteration count until one timed batch lasts at least min_bench_time
   seconds; unless reim_alt is -1 it then also checks fft against
   do_standard_fft.

   Value arguments (acc, reim_alt, in, out, ...) are parenthesized in the
   expansion, so expressions such as `have_ifft ? 1 : -1` are safe to
   pass.  The expansion declares the locals iter, iters, N_tot, t,
   mean_error, start_t and end_t, which are visible to fft, ifft and
   bench_code.
*/
#define DO_BENCHMARK_ND_AUX(bench,nffts,mflops_scale,rank,size,real_N,in,out,bench_code,fft,scale_1,sign,reim_alt,ifft,scale,acc) do { \
     if (fft_data_cur && (real_N) > 0) { \
          fftw_time start_t,end_t; \
          int iter, iters = 1, N_tot = 1; \
          double t, mean_error; \
          for (iter = 0; iter < (rank); ++iter) N_tot *= (size)[iter]; \
          if ((acc) != 0) { \
               /* accuracy mode */ \
               bench##_init_array_for_check(in, N_tot); \
               fft; \
               if ((acc) < 0) { \
                    /* no ifft: invert via conj(fft(conj(.))) */ \
                    bench_conjugate_array((FFTW_COMPLEX*)(out),N_tot,(reim_alt)); \
                    bench_copy_array((FFTW_COMPLEX*)(out), \
                                     (FFTW_COMPLEX*)(in),N_tot); \
                    fft; \
                    bench_conjugate_array((FFTW_COMPLEX*)(out),N_tot,(reim_alt)); \
                    mean_error = bench##_check_array(out,N_tot, \
                                                     (scale_1)*(scale_1) \
                                                     * 1.0/N_tot); \
               } \
               else { \
                    ifft; \
                    mean_error = bench##_check_array(in,N_tot,(scale)); \
               } \
               output_mean_error(mean_error, 0); \
          } \
          else { \
               /* timing mode */ \
               bench##_init_array(in,N_tot); \
               bench_code; \
               do { \
                    start_t = fftw_get_time(); \
                    for (iter = 0; iter < iters; ++iter) { bench_code; } \
                    end_t = fftw_get_time(); \
                    t = fftw_time_to_sec(fftw_time_diff(end_t,start_t)); \
                    iters *= 2; \
               } while (t < min_bench_time); \
               iters /= 2; /* undo the doubling after the final batch */ \
               output_results(t / (nffts), iters, (real_N), (mflops_scale)); \
               if ((reim_alt) != -1) { \
                    bench##_init_array_for_check(in,N_tot); \
                    fft; \
                    do_standard_fft((FFTW_COMPLEX*)(out),(rank),(size),-(sign),(reim_alt)); \
                    mean_error = bench##_check_array(out,N_tot,(scale_1)*1.0/N_tot); \
                    output_mean_error(mean_error, 1); \
               } \
          } \
          log_printf("\n"); \
     } \
} while (0)

/* Complex-data benchmark: the timed code is fft itself, one transform
   per run (nffts = 1, mflops_scale = 1.0), array helpers bench_*. */
#define DO_BENCHMARK_ND(rank,size,real_N,in,out,fft,scale_1,sign,reim_alt,ifft,scale,acc) \
     DO_BENCHMARK_ND_AUX(bench, 1, 1.0, \
                         rank, size, real_N, in, out, \
                         fft, fft, scale_1, sign, reim_alt, \
                         ifft, scale, acc)

/* Real-data benchmark: each timed run is a forward plus an inverse
   transform ({ fft; ifft; }), hence nffts = 2 and mflops_scale = 0.5;
   scale_1 = 1.0, sign = -1.0, and reim_alt = -1 disables the
   do_standard_fft check.  Array helpers rbench_*. */
#define DO_RBENCHMARK_ND(rank,size,real_N,in,out,fft,ifft,scale,acc) \
     DO_BENCHMARK_ND_AUX(rbench, 2, 0.5, \
                         rank, size, real_N, in, out, \
                         { fft; ifft; }, fft, 1.0, -1.0, -1, \
                         ifft, scale, acc)

#endif
