import os
import unittest

import numpy

from omega.system import Molecule, get_mol_from_xyz
from omega.coord.ric import RIC


class RICTest(unittest.TestCase):
    def test_H2(self):
        mymol = Molecule()
        mymol.add((0.0, 0.0, 0.0), name="H")
        mymol.add((0.0, 0.0, 0.8), name="H")

        myic = RIC(mymol)
        Bref = myic.bmatrix_fd(mymol.coords)
        Bout = myic.bmatrix(mymol.coords)
        diff = numpy.linalg.norm(Bout - Bref)
        self.assertTrue(diff < 1e-5)

    def test_H2O(self):
        mymol = get_mol_from_xyz("xyz/h2o_b3lyp_sto3g.xyz")
        myic = RIC(mymol)
        Bref = myic.bmatrix_fd(mymol.coords)
        Bout = myic.bmatrix(mymol.coords)
        diff = numpy.linalg.norm(Bout - Bref)
        self.assertTrue(diff < 1e-5)

    def test_ethane(self):
        mymol = Molecule()
        mymol.add([+0.000000000,  0.000000000,  0.000000000], name='C', unit="Bohr")
        mymol.add([+0.000000000,  0.000000000,  3.200000000], name='C', unit="Bohr")
        mymol.add([+0.000000000,  2.050000000,  0.000000000], name='H', unit="Bohr")
        mymol.add([-1.775352078, -1.025000000,  0.000000000], name='H', unit="Bohr")
        mymol.add([+1.775352078, -1.025000000,  0.000000000], name='H', unit="Bohr")
        mymol.add([+1.775352078,  1.025000000,  3.200000000], name='H', unit="Bohr")
        mymol.add([-1.775352078,  1.025000000,  3.200000000], name='H', unit="Bohr")
        mymol.add([+0.000000000, -2.050000000,  3.200000000], name='H', unit="Bohr")

        myic = RIC(mymol)
        Bref = myic.bmatrix_fd(mymol.coords)
        Bout = myic.bmatrix(mymol.coords)
        diff = numpy.linalg.norm(Bout - Bref)
        self.assertTrue(diff < 1e-5)

    def test_benzene(self):
        mymol = get_mol_from_xyz("xyz/Benzene_b3lyp_sto3g.xyz")
        myic = RIC(mymol)
        Bref = myic.bmatrix_fd(mymol.coords, delta=5e-4)
        Bout = myic.bmatrix(mymol.coords)
        diff = numpy.linalg.norm(Bout - Bref)/numpy.sqrt(Bout.size)
        self.assertTrue(diff < 1e-5)


if __name__ == "__main__":
    # Run the test cases when this module is executed as a script.
    unittest.main()
