//
// Solve Algebraic Equations with Complex Coefficients
// 2008-06-20 Tomonori Kouya
//
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>

#ifndef M_PI
	#define M_PI (3.1415926535897932385)
#endif

using namespace std;

// Compute every n-th root of the complex number z.
//   ret[k] = |z|^(1/n) * (cos(t_k) + i sin(t_k)),  t_k = arg(z) / n + 2 * pi * k / n
//   for k = 0, 1, ..., n-1
// ret : output array, must have room for at least n elements
// z   : number whose roots are taken
// n   : order of the root
// Returns n on success, or -1 (ret is left untouched) when n <= 0.
int get_all_roots(complex<double> ret[], complex<double> z, long int n)
{
	if(n <= 0)
		return -1;

	const double order = (double)n;

	// modulus shared by all roots: |z|^(1/n)
	const double radius = std::pow(std::abs(z), 1.0 / order);

	// argument of the first root: arg(z) / n
	const double first_angle = std::arg(z) / order;

	// angular distance between consecutive roots: 2 * pi / n
	const double angle_step = 2.0 * M_PI / order;

	// walk around the circle, one root per step
	for(long int k = 0; k < n; ++k)
	{
		const double angle = first_angle + angle_step * (double)k;
		ret[k] = radius * complex<double>(std::cos(angle), std::sin(angle));
	}

	return n;
}

// Evaluate p(z) = coef[deg] * z^deg + ... + coef[1] * z + coef[0].
// coef : coefficients, coef[i] belongs to z^i; indices 0..deg are read
// z    : evaluation point
// deg  : degree of the polynomial
// Returns the value p(z).
complex<double> eval_poly(complex<double> coef[], complex<double> z, long int deg)
{
	// Horner scheme: start at the leading coefficient and fold downwards.
	complex<double> acc(coef[deg]);

	long int k = deg;
	while(k > 0)
	{
		--k;
		acc = acc * z + coef[k];
	}

	return acc;
}

// Multiply two polynomials.
//   p = pcoef[p_deg] x^p_deg + ... + pcoef[1] x + pcoef[0]
//   q = qcoef[q_deg] x^q_deg + ... + qcoef[1] x + qcoef[0]
//   r = p * q = ret_coef[p_deg + q_deg] x^(p_deg + q_deg) + ... + ret_coef[0]
// ret_coef : output, must have room for p_deg + q_deg + 1 elements;
//            it may alias pcoef or qcoef (the product is built in a scratch buffer)
// pcoef, p_deg : coefficients and degree of p
// qcoef, q_deg : coefficients and degree of q
// Returns the degree of the product, or -1 (ret_coef untouched) if either degree is negative.
int mul_poly(complex<double> ret_coef[], complex<double> pcoef[], long int p_deg, complex<double> qcoef[], long int q_deg)
{
	if((p_deg < 0) || (q_deg < 0))
		return -1;

	const long int ret_deg = p_deg + q_deg;

	// Scratch buffer, value-initialized to 0 and released automatically on return.
	std::vector<complex<double> > product(static_cast<std::size_t>(ret_deg) + 1);

	// r[i] = sum of p[j] * q[i - j] over every j with 0 <= j <= p_deg and 0 <= i - j <= q_deg
	for(long int i = ret_deg; i >= 0; i--)
	{
		const long int j_max = (i < p_deg) ? i : p_deg;
		const long int j_min = (i > q_deg) ? (i - q_deg) : 0;

		for(long int j = j_max; j >= j_min; j--)
			product[i] += pcoef[j] * qcoef[i - j];
	}

	// Copy out only after the whole product is known, so aliased inputs stay intact above.
	for(long int i = 0; i <= ret_deg; i++)
		ret_coef[i] = product[i];

	return static_cast<int>(ret_deg);
}

// Build a test polynomial from prescribed roots and print it.
//   coef[deg] x^deg + ... + coef[1] x + coef[0]
//     = const_mul * (x - ans[0]) * ... * (x - ans[deg - 1])
// ret_coef  : output coefficients, indices 0..deg are written
// ans       : the deg prescribed roots
// const_mul : constant factor applied to every coefficient
// deg       : number of roots, i.e. degree of the resulting polynomial
// The equation "p(x) = 0" is also written to cout in a form that can be pasted into MuPAD.
void test_set_complex_poly(complex<double> ret_coef[], complex<double> ans[], complex<double> const_mul, long int deg)
{
	// Seed the product with the first linear factor (x - ans[0]).
	ret_coef[1] = 1.0;
	ret_coef[0] = -ans[0];
	long int cur_deg = 1;

	// Multiply in the remaining linear factors (x - ans[k]), one at a time.
	complex<double> factor[2];
	long int k = 1;
	while(k < deg)
	{
		factor[1] = 1.0;
		factor[0] = -ans[k];
		cur_deg = mul_poly(ret_coef, ret_coef, cur_deg, factor, 1);
		++k;
	}

	// Apply the constant factor to every coefficient.
	long int m = 0;
	while(m <= deg)
	{
		ret_coef[m] *= const_mul;
		++m;
	}

	// Emit the equation for MuPAD: leading term, middle terms, then the constant term.
	cout << endl;
	cout << "(" << ret_coef[deg].real() << "+(" << ret_coef[deg].imag() << ")*I)*x^" << deg ;
	long int e = deg - 1;
	while(e > 0)
	{
		cout << "+(" << ret_coef[e].real() << "+(" << ret_coef[e].imag() << ")*I)*x^" << e ;
		--e;
	}
	cout << "+(" << ret_coef[0].real() << "+(" << ret_coef[0].imag() << ")*I)=0" << endl;
	cout << endl;
}

// Solve coef[2] * z^2 + coef[1] * z + coef[0] = 0 with complex coefficients.
// ret  : output roots
// coef : coefficients, coef[i] belongs to z^i
// Returns 2 (ret[0], ret[1] filled) for a genuine quadratic,
//         1 (ret[0] filled) when coef[2] == 0 and the equation is linear,
//        -1 (message on cerr, ret untouched) when coef[2] == coef[1] == 0.
int solve_quadratic_eq_complex(complex<double> ret[2], complex<double> coef[3])
{
	const complex<double> zero(0.0, 0.0);

	if(coef[2] != zero)
	{
		// both square roots of the discriminant coef[1]^2 - 4 * coef[2] * coef[0]
		complex<double> disc_root[2];
		get_all_roots(disc_root, coef[1] * coef[1] - 4.0 * coef[2] * coef[0], 2);

		// z = (-coef[1] + sqrt(D)) / (2 * coef[2]), once per square root
		const complex<double> denom = 2.0 * coef[2];
		ret[0] = (-coef[1] + disc_root[0]) / denom;
		ret[1] = (-coef[1] + disc_root[1]) / denom;

		return 2;
	}

	if(coef[1] != zero)
	{
		// linear equation: coef[1] * z + coef[0] = 0
		ret[0] = -coef[0] / coef[1];
		return 1;
	}

	// neither a quadratic nor a linear equation
	cerr << "Illigal coefficients!(solve_quadratic_eq_comlex)" << endl;
	return -1;
}

// Solve coef[3] * z^3 + coef[2] * z^2 + coef[1] * z + coef[0] = 0 with complex
// coefficients (Cardano's method).
// ret  : output roots
// coef : coefficients, coef[i] belongs to z^i
// Returns 3 for a genuine cubic. When coef[3] == 0 the equation is handed to
// solve_quadratic_eq_complex and its return value (2, 1 or -1) is passed through.
int solve_cubic_eq_complex(complex<double> ret[3], complex<double> coef[4])
{
	const complex<double> zero(0.0, 0.0);

	// Leading coefficient vanishes: the equation is quadratic or lower.
	if(coef[3] == zero)
	{
		complex<double> low_coef[3];
		low_coef[2] = coef[2];
		low_coef[1] = coef[1];
		low_coef[0] = coef[0];
		// argument order is (roots, coefficients)
		return solve_quadratic_eq_complex(ret, low_coef);
	}

	// Normalize to the monic cubic z^3 + b z^2 + c z + d.
	const complex<double> b = coef[2] / coef[3];
	const complex<double> c = coef[1] / coef[3];
	const complex<double> d = coef[0] / coef[3];

	// Substituting z = y - b / 3 gives the depressed cubic y^3 + 3 p y + q = 0.
	const complex<double> shift = b / 3.0;
	const complex<double> p = c / 3.0 - b * b / 9.0;
	const complex<double> q = (2.0 / 27.0) * b * b * b - b * c / 3.0 + d;

	// With y = u + v and u * v = -p, both u^3 and v^3 are roots of t^2 + q t - p^3 = 0.
	complex<double> res_coef[3];
	res_coef[2] = 1.0;
	res_coef[1] = q;
	res_coef[0] = -p * p * p;

	complex<double> res_sol[2];
	solve_quadratic_eq_complex(res_sol, res_coef);

	// Use the root of larger magnitude as u^3. If p == 0 one root of the resolvent
	// is 0; picking it would give u = v = 0 and lose the actual roots.
	const complex<double> u_cubed = (std::abs(res_sol[0]) >= std::abs(res_sol[1])) ? res_sol[0] : res_sol[1];

	// u is any cube root of u^3; v follows from u * v = -p.
	complex<double> cube_root[3];
	get_all_roots(cube_root, u_cubed, 3);
	const complex<double> u = cube_root[0];

	complex<double> v;
	if(std::abs(u) != 0.0)
		v = -p / u;
	else
		v = 0.0;	// only reached when p == 0 and q == 0 (triple root)

	// omega = exp(2 pi i / 3), the primitive cube root of unity
	const complex<double> omega = std::exp(complex<double>(0.0, 2.0 * M_PI / 3.0));
	const complex<double> omega2 = omega * omega;

	// y_k = omega^k u + omega^(2k) v, then undo the shift
	ret[0] = u + v - shift;
	ret[1] = omega * u + omega2 * v - shift;
	ret[2] = omega2 * u + omega * v - shift;

	return 3;
}


// Solve coef[4] * z^4 + coef[3] * z^3 + coef[2] * z^2 + coef[1] * z + coef[0] = 0
// with complex coefficients (Ferrari's method).
// ret  : output roots
// coef : coefficients, coef[i] belongs to z^i
// Returns 4 for a genuine quartic. When coef[4] == 0 the equation is handed to
// solve_cubic_eq_complex and its return value is passed through.
int solve_quartic_eq_complex(complex<double> ret[4], complex<double> coef[5])
{
	const complex<double> zero(0.0, 0.0);

	// Leading coefficient vanishes: the equation is cubic or lower.
	if(coef[4] == zero)
	{
		complex<double> low_coef[4];
		low_coef[3] = coef[3];
		low_coef[2] = coef[2];
		low_coef[1] = coef[1];
		low_coef[0] = coef[0];
		// argument order is (roots, coefficients)
		return solve_cubic_eq_complex(ret, low_coef);
	}

	// Normalize to the monic quartic z^4 + b z^3 + c z^2 + d z + e.
	const complex<double> b = coef[3] / coef[4];
	const complex<double> c = coef[2] / coef[4];
	const complex<double> d = coef[1] / coef[4];
	const complex<double> e = coef[0] / coef[4];

	// Substituting z = y - b / 4 gives the depressed quartic y^4 + p y^2 + q y + r = 0.
	const complex<double> shift = b / 4.0;
	const complex<double> b2 = b * b;
	const complex<double> p = c - 3.0 * b2 / 8.0;
	const complex<double> q = b2 * b / 8.0 - b * c / 2.0 + d;
	const complex<double> r = -3.0 * b2 * b2 / 256.0 + c * b2 / 16.0 - b * d / 4.0 + e;

	complex<double> quad_coef[3], quad_sol[2];

	if(q == zero)
	{
		// Biquadratic case: with w = y^2, solve w^2 + p w + r = 0 first.
		quad_coef[2] = 1.0;
		quad_coef[1] = p;
		quad_coef[0] = r;
		solve_quadratic_eq_complex(quad_sol, quad_coef);
		const complex<double> w0 = quad_sol[0];
		const complex<double> w1 = quad_sol[1];

		// y^2 - w0 = 0
		quad_coef[2] = 1.0;
		quad_coef[1] = 0.0;
		quad_coef[0] = -w0;
		solve_quadratic_eq_complex(quad_sol, quad_coef);
		ret[0] = quad_sol[0];
		ret[1] = quad_sol[1];

		// y^2 - w1 = 0
		quad_coef[2] = 1.0;
		quad_coef[1] = 0.0;
		quad_coef[0] = -w1;
		solve_quadratic_eq_complex(quad_sol, quad_coef);
		ret[2] = quad_sol[0];
		ret[3] = quad_sol[1];
	}
	else
	{
		// Resolvent cubic: z^3 - p z^2 - 4 r z + 4 p r - q^2 = 0
		complex<double> cubic_coef[4], cubic_sol[3];
		cubic_coef[3] = 1.0;
		cubic_coef[2] = -p;
		cubic_coef[1] = -4.0 * r;
		cubic_coef[0] = 4.0 * p * r - q * q;
		solve_cubic_eq_complex(cubic_sol, cubic_coef);

		// Any resolvent root works; take the one farthest from p so that the
		// division by sqrt(z - p) below is as well-conditioned as possible.
		// (z == p would require q == 0, which is excluded in this branch.)
		complex<double> z = cubic_sol[0];
		for(int i = 1; i < 3; i++)
		{
			if(std::abs(cubic_sol[i] - p) > std::abs(z - p))
				z = cubic_sol[i];
		}

		const complex<double> s = std::sqrt(z - p);	// computed once, used in both factors
		const complex<double> t = q / (2.0 * s);
		const complex<double> half_z = z / 2.0;

		// First factor: y^2 - s y + (z / 2 + q / (2 s)) = 0
		quad_coef[2] = 1.0;
		quad_coef[1] = -s;
		quad_coef[0] = half_z + t;
		solve_quadratic_eq_complex(quad_sol, quad_coef);
		ret[0] = quad_sol[0];
		ret[1] = quad_sol[1];

		// Second factor: y^2 + s y + (z / 2 - q / (2 s)) = 0
		quad_coef[2] = 1.0;
		quad_coef[1] = s;
		quad_coef[0] = half_z - t;
		solve_quadratic_eq_complex(quad_sol, quad_coef);
		ret[2] = quad_sol[0];
		ret[3] = quad_sol[1];
	}

	// Undo the substitution z = y - b / 4.
	for(int i = 0; i < 4; i++)
		ret[i] -= shift;

	return 4;
}

// Highest polynomial degree the solvers in this file support.
#define N 4

// Interactive driver: reads a degree (1..N) and the complex coefficients
// coef[0..deg] from cin, solves the equation and prints each root together
// with the polynomial value at that root (which should be close to 0).
// Returns 0 on success, -1 on invalid input or when the solver reports failure.
int main()
{
	long int deg;
	complex<double> coef[N+1], sol[N], true_sol[N];

/*
	// Produce the example with user-defined solutions
	true_sol[0] = complex<double>(1.0, 1.0);
	true_sol[1] = complex<double>(1.0, 2.0);
	true_sol[2] = complex<double>(1.0, 3.0);
	true_sol[3] = complex<double>(1.0, 4.0);

	test_set_complex_poly(coef, true_sol, 2, N);
*/

	cout << "Input the degree of polynomial( <= 4): ";

	// Reject unreadable input and any degree outside 1..N; coef has only N + 1 slots.
	if(!(cin >> deg) || (deg <= 0) || (deg > N))
	{
		cout << "Illigal degree!" << endl;
		return -1;
	}

	// Input coefficients as complex number
	for(long int i = 0; i <= deg; i++)
	{
		cout << "coef[" << i << "] = ";
		if(!(cin >> coef[i]))
		{
			cerr << "Invalid coefficient input!" << endl;
			return -1;
		}
	}

	// Number of roots actually produced by the solver (negative on failure).
	int n_sol = -1;
	switch(deg)
	{
		case 1:
			// treat the linear equation as a quadratic with zero leading coefficient
			coef[2] = 0.0;
			// fall through
		case 2:
			n_sol = solve_quadratic_eq_complex(sol, coef);
			break;
		case 3:
			n_sol = solve_cubic_eq_complex(sol, coef);
			break;
		case 4:
			n_sol = solve_quartic_eq_complex(sol, coef);
			break;
	}

	// The solver has already reported the problem on cerr.
	if(n_sol < 0)
		return -1;

	// Validate sol[i]: print only the roots that were produced.
	for(int i = 0; i < n_sol; i++)
		cout << i << ": " << sol[i] << "->" << eval_poly(coef, sol[i], deg) << endl;

	return 0;
}
