//
// Radix-2 FFT (recursive version)
// based on "Matrix Computations", 3rd ed. (G. H. Golub and C. F. Van Loan)
// Version 0.0 : 2008-07-14 Tomonori Kouya
// Version 0.1 : 2008-07-17 Tomonori Kouya
//

#include <iostream>
#include <complex>
#include <cmath>
#include <limits>
#include <vector>

// Use the std namespace throughout this file
using namespace std;

// Fallback value of pi for platforms whose <cmath> does not define M_PI.
#ifndef M_PI
#define M_PI (3.14159265358979323846)
#endif

// Upper bound of 2^15 samples -- presumably an intended FFT size limit;
// note that nothing below references it.
#define MAX_FFT_SIZE 32768

// FFT radix-2, recursive decimation-in-time step
// Transforms the 2^power2_n samples f[start_index + plus_index * k],
// k = 0, ..., 2^power2_n - 1, and writes the (unscaled) result to
// x[start_index + plus_index * k] for the same k.
//
// INPUT : complex<double> f[]  -- samples, read at the strided positions above
//         long int start_index -- first index of this sub-sequence
//         long int plus_index  -- stride between consecutive samples
//         long int power2_n    -- log2 of the sub-sequence length (must be >= 0)
//         int flag_inv         -- >= 1: kernel exp(+2*pi*i/n) (inverse),
//                                 otherwise kernel exp(-2*pi*i/n) (forward)
// OUTPUT: complex<double> x[]  -- transform, written at the same strided positions
// RETURN: 0 on success, -1 if power2_n is negative or 2^power2_n does not fit
//         in a long int
int fft_radix2_recursive(complex<double> x[], complex<double> f[], long int start_index, long int plus_index, long int power2_n, int flag_inv)
{
	// A negative exponent used to give total_n == 0, so the base case below
	// was never reached and the recursion never terminated.
	if(power2_n < 0 || power2_n >= numeric_limits<long int>::digits)
		return -1;

	// Exact integer power of two (no floating-point pow()).
	const long int total_n = 1L << power2_n;

	// power2_n == 0 (total_n == 1): the DFT of one sample is the sample itself
	if(total_n == 1)
	{
		x[start_index] = f[start_index];
		return 0;
	}

	const long int m = total_n / 2;
	const long int stride2 = plus_index * 2;

	// y_t: transform of the even-indexed samples, left in x at the even positions
	fft_radix2_recursive(x, f, start_index, stride2, power2_n - 1, flag_inv);

	// y_b: transform of the odd-indexed samples, left in x at the odd positions
	fft_radix2_recursive(x, f, start_index + plus_index, stride2, power2_n - 1, flag_inv);

	// Twiddle angle step: + for the inverse kernel, - for the forward kernel
	const double theta = (flag_inv >= 1 ? 2.0 : -2.0) * M_PI / (double)total_n;

	// Copy y_t and z := d .* y_b out of x first: the butterfly below writes
	// positions that overlap both the even and the odd halves.
	// std::vector releases the scratch storage even if an allocation throws.
	vector<complex<double> > y_t(m), z(m);
	for(long int i = 0; i < m; i++)
	{
		// d[i] = omega^i computed directly; repeated complex pow() is slower
		// and accumulates rounding error.
		const complex<double> d_i = polar(1.0, theta * (double)i);
		y_t[i] = x[start_index + stride2 * i];
		z[i]   = d_i * x[start_index + plus_index + stride2 * i];
	}

	// x := [y_t + z, y_t - z]^T
	for(long int i = 0; i < m; i++)
	{
		x[start_index + plus_index * i]       = y_t[i] + z[i];
		x[start_index + plus_index * (i + m)] = y_t[i] - z[i];
	}

	return 0;
}

// Forward FFT, radix-2
// Computes x[k] = (1/n) * sum_j f[j] * exp(-2*pi*i*j*k/n) with n = 2^power2_n.
// The 1/n normalisation is applied here, on the forward transform.
// INPUT : complex<double> f[2^power2_n], power2_n (must be >= 0)
// OUTPUT: complex<double> x[2^power2_n]
// RETURN: 0 on success, -1 if power2_n is out of range or the transform fails
int fft_radix2(complex<double> x[], complex<double> f[], long int power2_n)
{
	long int total_n, i;
	int status;

	// Reject exponents that are negative (never a valid size) or whose
	// 2^power2_n would overflow a long int.
	if(power2_n < 0 || power2_n >= numeric_limits<long int>::digits)
		return -1;

	// Exact integer power of two (no floating-point pow()).
	total_n = 1L << power2_n;

	status = fft_radix2_recursive(x, f, 0, 1, power2_n, 0);
	if(status != 0)
		return status;

	// 1/n * x
	for(i = 0; i < total_n; i++)
		x[i] /= (double)total_n;

	return 0;
}

// Inverse FFT, radix-2
// Computes x[k] = sum_j f[j] * exp(+2*pi*i*j*k/n) with n = 2^power2_n.
// No 1/n factor is applied here: fft_radix2() already scales by 1/n, so
// applying this to the output of fft_radix2() reproduces the input samples.
// INPUT : complex<double> f[2^power2_n], power2_n (must be >= 0)
// OUTPUT: complex<double> x[2^power2_n]
// RETURN: 0 on success, -1 if power2_n is out of range or the transform fails
int inv_fft_radix2(complex<double> x[], complex<double> f[], long int power2_n)
{
	// Reject exponents that are negative (never a valid size) or whose
	// 2^power2_n would overflow a long int.
	if(power2_n < 0 || power2_n >= numeric_limits<long int>::digits)
		return -1;

	// Propagate the recursive routine's status instead of discarding it.
	return fft_radix2_recursive(x, f, 0, 1, power2_n, 1);
}

// Exchange the two complex values pointed to by a and b.
// Note: a call such as swap(u[i], v[i]) with element references does not
// match this pointer overload; pass addresses, e.g. swap(&u[i], &v[i]).
void swap(complex<double> *a, complex<double> *b)
{
	const complex<double> saved = *a;
	*a = *b;
	*b = saved;
}

// Maximum number of samples the demo in main() can hold
#define MAXDIM 1024

// Demo driver: reads an exponent n and builds the 2^n samples
// f[i] = (i + 1) + (i + 1)i.
// It prints "index input spectrum" for the forward transform, then
// inverse-transforms the spectrum and prints "index spectrum recovered".
// Returns 0 on success, and -1 on invalid input (read failure, n < 0,
// 2^n > MAXDIM) or if a transform reports failure.
int main(void)
{
	complex<double> test_f[MAXDIM];
	complex<double> test_x[MAXDIM];
	long int i, dim, pow2_dim;

	// Input Dimension
	cout << "Input dimension of sample data as 2^n (n <= 10): ";
	cin >> pow2_dim;

	// Validate n before computing 2^n. A failed read gives no usable value,
	// a negative n is not a valid FFT size, and a very large n would overflow
	// the shift below.
	if(!cin || pow2_dim < 0 || pow2_dim >= numeric_limits<long int>::digits)
	{
		cerr << "Invalid exponent n" << endl;
		return -1;
	}

	// Exact integer power of two (no floating-point pow()).
	dim = 1L << pow2_dim;

	if(dim > MAXDIM)
	{
		cerr << "2^n exceeds " << MAXDIM << endl;
		return -1;
	}

	// Initial Values
	for(i = 0; i < dim; i++)
	{
		test_f[i] = complex<double>(i + 1, i + 1);
	}

	// Fourier Transform
	if(fft_radix2(test_x, test_f, pow2_dim) != 0)
		return -1;

	for(i = 0; i < dim; i++)
	{
		cout << i << " " << test_f[i] << " "  << test_x[i] << endl;
		// Move the spectrum into test_f so it becomes the inverse's input.
		// (Element references select std::swap here, not the pointer overload.)
		swap(test_x[i], test_f[i]);
	}

	// Inverse Fourier Transform
	if(inv_fft_radix2(test_x, test_f, pow2_dim) != 0)
		return -1;

	for(i = 0; i < dim; i++)
	{
		cout << i << " " << test_f[i] << " "  << test_x[i] << endl;
	}

	return 0;
}
