import unittest
import numpy
from omega.system import Molecule, MolSystem
from omega.finite_difference import fd_hess_on_vec

has_pyscf = True
try:
    from omega.pyscf_interface import PyscfInterface, Options
except ImportError:
    has_pyscf = False


@unittest.skipUnless(has_pyscf, "Requires PySCF")
class FDHessVecTest(unittest.TestCase):
    """Check ``fd_hess_on_vec`` against analytic Hessian-vector products.

    Each test builds an H2 molecule and a B3LYP/cc-pVDZ ``PyscfInterface``.
    It then computes forces up to second order with
    ``MolSystem.compute_forces(..., order=2)`` and compares
    ``system.F2 @ vec`` (optionally mass weighted) with the finite-difference
    result returned by ``fd_hess_on_vec``.
    """

    # Tolerance used with fd_hess_on_vec's default method/step
    # (presumably central differences -- confirm in omega.finite_difference).
    DEFAULT_TOL = 1e-6
    # One-sided (forward/backward) differences are only first-order accurate
    # in the step size, so they are checked against a looser bound.
    ONE_SIDED_TOL = 5e-4
    # Step size passed to fd_hess_on_vec in the one-sided tests.
    ONE_SIDED_DELTA = 0.0005
    # Fixed seed so that a failing comparison can be reproduced exactly.
    SEED = 1234
    # Length of the probe vector: 2 atoms x 3 Cartesian coordinates.
    NCOORD = 6

    def setUp(self):
        # A fresh, seeded generator per test keeps the tests deterministic
        # and independent of execution order.
        self.rng = numpy.random.RandomState(self.SEED)

    @staticmethod
    def _make_h2():
        """Return an H2 molecule at the reference geometry."""
        mol = Molecule()
        # reference geometry
        mol.add((0.0, 0.0, 0.0), name='H')
        mol.add((0.0, 0.0, 0.74), name='H')
        return mol

    @staticmethod
    def _make_computer():
        """Return a PySCF interface configured for B3LYP/cc-pVDZ."""
        op = Options()
        op.method = "b3lyp"
        op.basis = "ccpvdz"
        return PyscfInterface(op)

    def _setup(self):
        """Build the shared fixtures for one test.

        Returns:
            (mol, computer, system, vec): the H2 molecule and the PySCF
            interface. ``system`` is a ``MolSystem`` on which
            ``compute_forces(computer, order=2)`` has been run. ``vec`` is a
            seeded random probe vector of length ``NCOORD``.
        """
        mol = self._make_h2()
        computer = self._make_computer()
        system = MolSystem(mol=mol)
        system.compute_forces(computer, order=2)
        vec = self.rng.rand(self.NCOORD)
        return mol, computer, system, vec

    @staticmethod
    def _mass_weighted(F2, M):
        """Return ``F2[i, j] / sqrt(M[i] * M[j])``.

        Assumes ``M`` holds one mass per Cartesian coordinate, matching the
        dimension of ``F2`` (TODO confirm against MolSystem.M).
        """
        Mi2 = 1.0 / numpy.sqrt(numpy.asarray(M))
        return numpy.einsum('ij,i,j->ij', F2, Mi2, Mi2)

    def _assert_close(self, ref, out, tol):
        """Assert that the 2-norm of ``ref - out`` is below ``tol``."""
        diff = numpy.linalg.norm(ref - out)
        self.assertLess(diff, tol, "Error: {}".format(diff))

    def test_H2(self):
        mol, computer, system, vec = self._setup()

        ref = numpy.einsum('ij,j->i', system.F2, vec)
        out = fd_hess_on_vec(mol, computer, vec)

        self._assert_close(ref, out, self.DEFAULT_TOL)

    def test_H2_mweight(self):
        mol, computer, system, vec = self._setup()

        F2m = self._mass_weighted(system.F2, system.M)
        ref = numpy.einsum('ij,j->i', F2m, vec)
        out = fd_hess_on_vec(mol, computer, vec, mweight=True)

        self._assert_close(ref, out, self.DEFAULT_TOL)

    def test_H2_forward(self):
        mol, computer, system, vec = self._setup()

        ref = numpy.einsum('ij,j->i', system.F2, vec)
        out = fd_hess_on_vec(
            mol, computer, vec, method="forward",
            delta=self.ONE_SIDED_DELTA)

        self._assert_close(ref, out, self.ONE_SIDED_TOL)

    def test_H2_backward(self):
        mol, computer, system, vec = self._setup()

        ref = numpy.einsum('ij,j->i', system.F2, vec)
        out = fd_hess_on_vec(
            mol, computer, vec, method="backward",
            delta=self.ONE_SIDED_DELTA)

        self._assert_close(ref, out, self.ONE_SIDED_TOL)

    def test_H2_fmweight(self):
        """Forward differences, mass weighted, with the reference gradient
        supplied through ``gc``."""
        mol, computer, system, vec = self._setup()

        F2m = self._mass_weighted(system.F2, system.M)
        ref = numpy.einsum('ij,j->i', F2m, vec)
        # Pass a copy so that fd_hess_on_vec cannot alter system.F1 if it
        # modifies gc in place (not verified here -- defensive).
        out = fd_hess_on_vec(
            mol, computer, vec, method="forward",
            mweight=True, gc=system.F1.copy(), delta=self.ONE_SIDED_DELTA)

        self._assert_close(ref, out, self.ONE_SIDED_TOL)


if __name__ == '__main__':
    # Running the file directly (rather than via a test runner) discovers
    # and executes every TestCase defined in this module.
    unittest.main()
