import unittest
import numpy

from harmony.evaluation import eval_chebyshev, eval_horner, eval_mono
from harmony.evaluation import eval_horner_matvec, eval_chebyshev_matvec_batch
from harmony.evaluation import eval_chebyshev_batch, eval_chebyshev_matvec


class EvaluationTest(unittest.TestCase):
    """Consistency tests for the evaluators in ``harmony.evaluation``.

    Every test builds a random 4x4 matrix and random length-6 coefficient
    vectors, evaluates the same polynomial through two different code
    paths and requires the results to agree to within ``TOL``.

    NOTE: inputs come from ``numpy.random`` without a fixed seed, so each
    run exercises different matrices and coefficients.
    """

    # Upper bound applied to every error measure computed in this class.
    TOL = 1e-14

    # ------------------------------------------------------------------
    # helpers (leading underscore keeps them out of test discovery)
    # ------------------------------------------------------------------
    @staticmethod
    def _apply(mat, vec):
        """Return the product of 2-D ``mat`` with 1-D ``vec`` via einsum."""
        return numpy.einsum('ij,j->i', mat, vec)

    @staticmethod
    def _make_matvec(mat):
        """Return a callable ``v -> mat . v`` for the ``*_matvec`` evaluators."""
        def matvec(v):
            return numpy.einsum('ij,j->i', mat, v)
        return matvec

    @staticmethod
    def _projector():
        """Return the 4x4 identity with its last diagonal entry zeroed."""
        proj = numpy.identity(4)
        proj[3, 3] = 0
        return proj

    @staticmethod
    def _rel_err(out, ref):
        """Return ``||out - ref|| / ||ref||`` using ``numpy.linalg.norm``."""
        return numpy.linalg.norm(out - ref)/numpy.linalg.norm(ref)

    def _assert_small(self, diff, msg=None):
        """Fail unless ``diff`` is strictly below ``TOL``.

        ``assertLess`` is used instead of ``assertTrue(diff < tol)`` so a
        failure reports both operands in addition to ``msg``.
        """
        if msg is None:
            msg = "Difference: {}".format(diff)
        self.assertLess(diff, self.TOL, msg)

    def _assert_close(self, out, ref):
        """Fail unless the relative error of ``out`` w.r.t. ``ref`` < ``TOL``."""
        self._assert_small(self._rel_err(out, ref))

    def _assert_all_close(self, outs, refs):
        """Pairwise relative-error check with the pair index in the message."""
        # All errors are computed before the first assertion is evaluated.
        diffs = [self._rel_err(o, r) for o, r in zip(outs, refs)]
        for i, d in enumerate(diffs):
            self._assert_small(d, "Difference ({}): {}".format(i, d))

    def _check_chebyshev_vs_horner(self, domain=None):
        """Compare ``eval_chebyshev`` with ``eval_horner`` on one series.

        The reference feeds ``eval_horner`` the power-basis coefficients
        that numpy obtains by converting the same Chebyshev series to a
        ``numpy.polynomial.Polynomial``.  When ``domain`` is ``None`` both
        numpy and ``eval_chebyshev`` are called without a ``domain``
        argument; otherwise the same ``domain`` is passed to both.
        """
        kwargs = {} if domain is None else {'domain': domain}
        coeff = numpy.random.rand(6)
        cheb = numpy.polynomial.chebyshev.Chebyshev(coeff, **kwargs)
        x = numpy.random.rand(4, 4)
        mono = cheb.convert(kind=numpy.polynomial.Polynomial).coef
        ref = eval_horner(x, mono)
        out = eval_chebyshev(x, coeff, **kwargs)
        self._assert_close(out, ref)

    # ------------------------------------------------------------------
    # Horner evaluation
    # ------------------------------------------------------------------
    def test_horner(self):
        """``eval_horner`` must reproduce ``eval_mono``."""
        x = numpy.random.rand(4, 4)
        coeff = numpy.random.rand(6)
        ref = eval_mono(x, coeff)
        out = eval_horner(x, coeff)
        # Unlike the other tests this one is not a relative error: the
        # default (Frobenius) norm of the difference is divided by 4.
        diff = numpy.linalg.norm(out - ref) / 4
        self._assert_small(diff)

    def test_horner_matvec(self):
        """``eval_horner_matvec`` must match ``eval_horner(x) . vec``."""
        vec = numpy.random.rand(4)
        x = numpy.random.rand(4, 4)
        coeff = numpy.random.rand(6)
        ref = self._apply(eval_horner(x, coeff), vec)
        out = eval_horner_matvec(vec, self._make_matvec(x), coeff)
        self._assert_close(out, ref)

    # ------------------------------------------------------------------
    # Chebyshev evaluation against numpy's basis conversion
    # ------------------------------------------------------------------
    def test_chebyshev(self):
        """Chebyshev evaluation with no ``domain`` argument."""
        self._check_chebyshev_vs_horner()

    def test_chebyshev_shifted(self):
        """Chebyshev evaluation on the domain ``[0, 2]``."""
        self._check_chebyshev_vs_horner(domain=[0, 2])

    def test_chebyshev_scaled(self):
        """Chebyshev evaluation on the domain ``[-10, 10]``."""
        self._check_chebyshev_vs_horner(domain=[-10, 10])

    def test_chebyshev_gen(self):
        """Chebyshev evaluation on the domain ``[0, 10]``."""
        self._check_chebyshev_vs_horner(domain=[0, 10])

    def test_chebyshev_gen_batch(self):
        """``eval_chebyshev_batch`` must match per-series ``eval_chebyshev``."""
        coeff1 = numpy.random.rand(6)
        coeff2 = numpy.random.rand(6)
        x = numpy.random.rand(4, 4)
        ref1 = eval_chebyshev(x, coeff1, domain=[0, 10])
        ref2 = eval_chebyshev(x, coeff2, domain=[0, 10])
        # The unpacking also requires exactly two results from the batch.
        out1, out2 = eval_chebyshev_batch(x, [coeff1, coeff2], domain=[0, 10])
        diff1 = self._rel_err(out1, ref1)
        diff2 = self._rel_err(out2, ref2)
        self._assert_small(diff1)
        self._assert_small(diff2)

    # ------------------------------------------------------------------
    # Chebyshev evaluation through a mat-vec callback
    # ------------------------------------------------------------------
    def test_chebyshev_matvec(self):
        """``eval_chebyshev_matvec`` must match ``eval_chebyshev(x) . vec``."""
        coeff = numpy.random.rand(6)
        x = numpy.random.rand(4, 4)
        matvec = self._make_matvec(x)
        vec = numpy.random.rand(4)
        ref = self._apply(eval_chebyshev(x, coeff), vec)
        out = eval_chebyshev_matvec(vec, matvec, coeff)
        self._assert_close(out, ref)

    def test_chebyshev_matvec_gen(self):
        """Same as ``test_chebyshev_matvec`` on ``[0, 10]``, plus ``proj``."""
        coeff = numpy.random.rand(6)
        # The numpy series is only used as the source of the ``domain``
        # value handed to the evaluators below.
        C = numpy.polynomial.chebyshev.Chebyshev(coeff, domain=[0, 10])
        x = numpy.random.rand(4, 4)
        matvec = self._make_matvec(x)
        vec = numpy.random.rand(4)
        Tx = eval_chebyshev(x, coeff, domain=C.domain)
        ref = self._apply(Tx, vec)
        out = eval_chebyshev_matvec(vec, matvec, coeff, domain=C.domain)
        self._assert_close(out, ref)

        # test with projector: the expected value applies Tx to P . vec
        P = self._projector()
        ref = self._apply(Tx, self._apply(P, vec))
        out = eval_chebyshev_matvec(
            vec, matvec, coeff, domain=C.domain, proj=P)
        self._assert_close(out, ref)

    def test_chebyshev_matvec_batch(self):
        """``eval_chebyshev_matvec_batch`` must match per-series results."""
        coeff1 = numpy.random.rand(6)
        coeff2 = numpy.random.rand(6)
        coeff3 = numpy.random.rand(6)
        coeffs = [coeff1, coeff2, coeff3]
        dom = [0, 10]
        x = numpy.random.rand(4, 4)
        matvec = self._make_matvec(x)
        vec = numpy.random.rand(4)
        Tx = [eval_chebyshev(x, coeff, domain=dom) for coeff in coeffs]
        ref = [self._apply(Txi, vec) for Txi in Tx]
        out = eval_chebyshev_matvec_batch(vec, matvec, coeffs, domain=dom)
        self._assert_all_close(out, ref)

        # test with projector: the expected value applies each Tx to P . vec
        P = self._projector()
        Pv = self._apply(P, vec)
        ref = [self._apply(T, Pv) for T in Tx]
        out = eval_chebyshev_matvec_batch(
            vec, matvec, coeffs, domain=dom, proj=P)
        self._assert_all_close(out, ref)


# Run the tests with unittest's command-line runner when this file is
# executed directly; nothing happens here when the module is imported.
if __name__ == '__main__':
    unittest.main()
