import unittest
import numpy

from omega.system import MolSystem
from omega.transrot import transrot, projector_against, half_projector_against

has_pyscf = True
try:
    from omega.pyscf_interface import PyscfInterface, Options
except ImportError:
    has_pyscf = False


class ProjectorTest(unittest.TestCase):
    """Check the translation/rotation projectors from ``omega.transrot``.

    Each test computes an analytic Hessian through PySCF, mass-weights it,
    projects it with ``half_projector_against`` (reduced and ``full=True``)
    and ``projector_against``, and compares the projected Hessian against
    ``P2`` as read after ``MolSystem.get_normal_modes`` (presumably the
    squared vibrational frequencies -- confirm against ``omega.system``).
    """

    # Absolute tolerance shared by every comparison in this class.
    TOL = 1e-14

    @staticmethod
    def _hessian_system(filename, method, basis):
        """Load ``filename`` and compute its analytic Hessian and normal modes.

        The PySCF calculation uses ``method`` and ``basis`` with charge 0 and
        spin 0. Returns the populated ``MolSystem``.
        """
        system = MolSystem(filename=filename)
        op = Options()
        op.method = method
        op.charge = 0
        op.spin = 0
        op.basis = basis
        computer = PyscfInterface(op)
        system.compute_forces(computer, order=2)
        system.get_normal_modes()
        return system

    @staticmethod
    def _mass_weighted_hessian(system):
        """Return ``F2[i, j] / sqrt(M[i] * M[j])`` for ``system``.

        ``system.M`` must hold one mass per Hessian row/column.
        """
        Mi2 = 1.0/numpy.sqrt(numpy.asarray(system.M))
        return numpy.einsum('ij,i,j->ij', system.F2, Mi2, Mi2)

    @staticmethod
    def _project(U, F2m):
        """Return ``U.T @ F2m @ U``: the Hessian in the column basis of ``U``."""
        return numpy.einsum('pi,qj,pq->ij', U, U, F2m)

    def _assert_close(self, diff, label):
        """Fail unless ``diff`` is below ``TOL``, naming the projector used."""
        self.assertLess(
            diff, self.TOL, "Difference in {}: {}".format(label, diff))

    @unittest.skipUnless(has_pyscf, "Requires PySCF")
    def test_H2(self):
        system = self._hessian_system("xyz/H2.xyz", "HF", "ccpvtz")
        # H2 is linear with a single vibrational mode; use its P2 entry.
        w2 = system.P2.diagonal()[0]
        F2m = self._mass_weighted_hessian(system)

        # translations/rotations (linear molecule)
        dvecs = transrot(
            system.mol.natom, system.mol.coords, system.I, system.R,
            linear=True, masses=system.M, mweight=True)

        # Half-projector: the projected Hessian must be 1x1 and equal w2.
        # Checking the shape keeps a larger result from passing on [0, 0].
        F2mv = self._project(half_projector_against(dvecs), F2m)
        self.assertEqual(F2mv.shape, (1, 1))
        self._assert_close(abs(F2mv[0, 0] - w2), "half-projector")

        # Full half-projector: compare through the trace.
        F2mv = self._project(half_projector_against(dvecs, full=True), F2m)
        self._assert_close(abs(w2 - numpy.trace(F2mv)), "full half-projector")

        # Full projector: compare through the trace.
        F2mv = self._project(projector_against(dvecs), F2m)
        self._assert_close(abs(w2 - numpy.trace(F2mv)), "projector")

    @unittest.skipUnless(has_pyscf, "Requires PySCF")
    def test_H2O(self):
        system = self._hessian_system(
            "xyz/h2o_b3lyp_sto3g.xyz", "b3lyp", "sto-3g")
        F2m = self._mass_weighted_hessian(system)

        # translations/rotations (non-linear molecule)
        dvecs = transrot(
            system.mol.natom, system.mol.coords, system.I, system.R,
            linear=False, masses=system.M, mweight=True)

        # Every projected Hessian must have the same trace as P2.
        ref = numpy.trace(system.P2)

        F2mv = self._project(half_projector_against(dvecs), F2m)
        self._assert_close(abs(numpy.trace(F2mv) - ref), "half-projector")

        F2mv = self._project(half_projector_against(dvecs, full=True), F2m)
        self._assert_close(
            abs(numpy.trace(F2mv) - ref), "full half-projector")

        F2mv = self._project(projector_against(dvecs), F2m)
        self._assert_close(abs(numpy.trace(F2mv) - ref), "projector")


if __name__ == '__main__':
    # Running this file directly hands control to the unittest runner,
    # which collects the TestCase classes in this module.
    unittest.main(module='__main__')
