import unittest
import numpy

import omega.finite_difference as fd
from omega.system import Molecule

has_pyscf = True
try:
    from omega.pyscf_interface import PyscfInterface, Options
except ImportError:
    has_pyscf = False


@unittest.skipUnless(has_pyscf, "Requires PySCF")
class FDHessianTest(unittest.TestCase):
    """Compare finite-difference Hessians with analytic PySCF Hessians.

    Each test builds a small molecule and checks two Hessians at the same
    geometry:

    * the analytic Hessian from ``PyscfInterface.hessian``;
    * a finite-difference Hessian from ``fd.fd_d2_1``.

    The root-mean-square elementwise difference between them must be below
    ``self.thresh``.
    """

    def setUp(self):
        # Finite-difference step passed to fd.fd_d2_1.  Units presumably
        # follow the Molecule coordinates -- TODO confirm.
        self.delta = 0.0012
        # Maximum allowed RMS elementwise difference between the Hessians.
        self.thresh = 1e-6

    @staticmethod
    def _hf_sto3g_computer():
        """Return a PyscfInterface configured for HF with the STO-3G basis."""
        op = Options()
        op.method = "HF"
        op.basis = "sto-3g"
        return PyscfInterface(op)

    def _check_fd_hessian(self, mol):
        """Assert that the FD Hessian of *mol* matches the analytic one.

        Both Hessians are evaluated at ``mol.coords``.  The check fails if
        their shapes differ or if their RMS elementwise difference is not
        below ``self.thresh``.
        """
        computer = self._hf_sto3g_computer()

        # Analytic Hessian at the reference geometry.  The energy and
        # gradient are also returned but are not needed here.
        _, _, hess = computer.hessian(mol)

        # Finite-difference Hessian at the same geometry.
        fd_hess, _ = fd.fd_d2_1(
            mol, mol.coords, computer, None, diag3=False, delta=self.delta)

        # An explicit shape check stops numpy broadcasting from silently
        # comparing arrays of different shapes.
        self.assertEqual(numpy.shape(fd_hess), numpy.shape(hess),
                         "FD and analytic Hessians differ in shape")

        # RMS elementwise difference: Frobenius norm / sqrt(element count).
        rms = numpy.linalg.norm(fd_hess - hess)
        rms /= numpy.sqrt(float(fd_hess.size))
        self.assertLess(rms, self.thresh, "Error in FD Hessian")

    def test_H2(self):
        H2 = Molecule()

        # reference geometry
        H2.add((0.0, 0.0, 0.0), name='H')
        H2.add((0.0, 0.0, 0.74), name='H')

        self._check_fd_hessian(H2)

    def test_CH4(self):
        CH4 = Molecule()

        # reference geometry
        CH4.add((0.5288, 0.1610, 0.9359), name='H')
        CH4.add((0.2051, 0.8240, -0.6786), name='H')
        CH4.add((0.3345, -0.9314, -0.4496), name='H')
        CH4.add((-1.0685, -0.0537, 0.1921), name='H')
        CH4.add((0.0000, 0.0000, 0.0000), name='C')

        self._check_fd_hessian(CH4)


# Allow running this test module directly as a script via unittest's runner.
if __name__ == '__main__':
    unittest.main()
