import itertools
import unittest

import numpy

from omega import constants
from omega.system import Molecule
from omega.system import MolSystem
from omega.vci import VCI


def get_N2H2_L():
    """Return the hard-coded 12x6 ``L`` matrix used by the N2H2 tests.

    The matrix is filled one column at a time from literal values. The
    caller assigns the result to ``MolSystem.L``.

    NOTE(review): the 12 rows presumably correspond to 4 atoms x 3 Cartesian
    components and the 6 columns to the vibrational modes -- confirm against
    ``omega.system.MolSystem``.

    Returns:
        numpy.ndarray: A freshly allocated array of shape ``(12, 6)``.
    """
    L = numpy.zeros((12, 6))
    L[:, 0] = numpy.array([
        +4.868929840184689e-17,  3.385075527403519e-17,
        +1.593384987245489e-02, -1.088843536945972e-17,
        -4.628114281774521e-18, -1.146786252304376e-03,
        +3.686300295670469e-18, -8.804280867943412e-20,
        -1.146786252304376e-03,  5.137968899340944e-17,
        +3.167718326133343e-17,  1.593384987245490e-02])
    L[:, 1] = numpy.array([
        -1.339593966627884e-02, -8.627651604889293e-03,
        +6.191544449762012e-17,  9.641285420006624e-04,
        +6.209467472931893e-04, -4.456159755711691e-18,
        +9.641285420006676e-04,  6.209467472931852e-04,
        -4.456159755711694e-18, -1.339593966627881e-02,
        -8.627651604889271e-03,  6.191544449762015e-17])
    L[:, 2] = numpy.array([
        +2.486225704821227e-03,  2.396920945123838e-03,
        -3.139997542379306e-17, -3.958892610038608e-03,
        -1.748323209326894e-03,  2.259909590396644e-18,
        +3.958892610038609e-03,  1.748323209326896e-03,
        +2.259909590396645e-18, -2.486225704821222e-03,
        -2.396920945123838e-03, -3.139997542379308e-17])
    L[:, 3] = numpy.array([
        -1.310827794325998e-02, -6.192490065361955e-03,
        -1.061407619959202e-17, -1.513124807616606e-03,
        +1.473675821973409e-03,  7.639131009791472e-19,
        +1.513124807616603e-03, -1.473675821973413e-03,
        +7.639131009791476e-19,  1.310827794326002e-02,
        +6.192490065361976e-03, -1.061407619959203e-17])
    L[:, 4] = numpy.array([
        +8.754143647062822e-03, -1.333879390485947e-02,
        -1.326759524949003e-18, -6.047741248395949e-04,
        +9.492536245048907e-04,  9.548913762239340e-20,
        +6.047741248396238e-04, -9.492536245049479e-04,
        +9.548913762239345e-20, -8.754143647063322e-03,
        +1.333879390486026e-02, -1.326759524949003e-18])
    L[:, 5] = numpy.array([
        -8.627651604889534e-03,  1.339593966627920e-02,
        -1.769012699932004e-18,  6.209467472931965e-04,
        -9.641285420006930e-04,  1.273188501631912e-19,
        +6.209467472931678e-04, -9.641285420006383e-04,
        +1.273188501631913e-19, -8.627651604889024e-03,
        +1.339593966627844e-02, -1.769012699932004e-18])
    # The original fell off the end here and implicitly returned None, so
    # the caller never received the matrix it had just asked for.
    return L


def read2(fname, nmode, P2):
    """Fill the symmetric two-index array ``P2`` in place from a text file.

    Each non-blank line is whitespace separated: a 1-based first index,
    followed by one value for every second index from that first index up to
    ``nmode`` (i.e. the upper triangle of one row). Every value is written to
    both ``P2[i, j]`` and ``P2[j, i]``.

    Args:
        fname: Path of the file to read.
        nmode: Size of each dimension; determines how many values are read
            from each line.
        P2: Array-like supporting ``P2[i, j] = value``; modified in place.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If an index or value on a line cannot be parsed.
        IndexError: If a line holds fewer values than ``nmode`` requires.
    """
    # The context manager closes the file even when parsing raises.
    with open(fname) as infile:
        for line in infile:
            s = line.split()
            if not s:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            i = int(s[0]) - 1
            for j in range(i, nmode):
                # The first value on the line (s[1]) is the diagonal element.
                val = float(s[j - i + 1])
                P2[i, j] = val
                P2[j, i] = val


def read3(fname, nmode, P3):
    """Fill the fully symmetric three-index array ``P3`` in place from a file.

    Each non-blank line is whitespace separated: two 1-based indices,
    followed by one value for every third index from the second index up to
    ``nmode``. Every value is written to all permutations of its index
    triple.

    Args:
        fname: Path of the file to read.
        nmode: Size of each dimension; determines how many values are read
            from each line.
        P3: Array-like supporting ``P3[i, j, k] = value``; modified in place.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If an index or value on a line cannot be parsed.
        IndexError: If a line holds fewer values than ``nmode`` requires.
    """
    # The context manager closes the file even when parsing raises.
    with open(fname) as infile:
        for line in infile:
            s = line.split()
            if not s:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            i = int(s[0]) - 1
            j = int(s[1]) - 1
            for k in range(j, nmode):
                # Values start at s[2], which belongs to k == j.
                val = float(s[k - j + 2])
                # Repeated indices yield duplicate permutations; rewriting
                # the same value is harmless.
                for idx in itertools.permutations((i, j, k)):
                    P3[idx] = val


def read4(fname, nmode, P4):
    """Fill the fully symmetric four-index array ``P4`` in place from a file.

    Each non-blank line is whitespace separated: three 1-based indices,
    followed by one value for every fourth index from the third index up to
    ``nmode``. Every value is written to all 24 permutations of its index
    quadruple.

    Args:
        fname: Path of the file to read.
        nmode: Size of each dimension; determines how many values are read
            from each line.
        P4: Array-like supporting ``P4[i, j, k, l] = value``; modified in
            place.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If an index or value on a line cannot be parsed.
        IndexError: If a line holds fewer values than ``nmode`` requires.
    """
    # The context manager closes the file even when parsing raises.
    with open(fname) as infile:
        for line in infile:
            s = line.split()
            if not s:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            i = int(s[0]) - 1
            j = int(s[1]) - 1
            k = int(s[2]) - 1
            for l in range(k, nmode):
                # Values start at s[3], which belongs to l == k.
                val = float(s[l - k + 3])
                # Repeated indices yield duplicate permutations; rewriting
                # the same value is harmless.
                for idx in itertools.permutations((i, j, k, l)):
                    P4[idx] = val


class VCITest(unittest.TestCase):
    """Compare VCI energies for N2H2 against stored reference values."""

    def setUp(self):
        # Tolerance for the energy comparison. It is divided by
        # constants.hartree_to_cm_1 before use, so it is presumably given
        # in cm^-1 -- confirm against omega.constants.
        self.thresh = 1e-4

    def _test_N2H2(self, N):
        """Build the N2H2 system, run ``VCI(system, N)``, return its result.

        Force-constant files are read from the ``forces/`` directory relative
        to the current working directory.

        Args:
            N: Passed unchanged as the second argument to ``VCI``.

        Returns:
            Whatever ``VCI.solve()`` returns; the tests index it like a
            sequence of energies.
        """
        N2H2 = Molecule()
        N2H2.add(
            [2.18628828338372e+00, -1.29218485926721e+00, 0.00000000000000e+00],
            name="H", unit="Bohr")
        N2H2.add(
            [1.13263930531833e+00, 3.46521429709182e-01, 0.00000000000000e+00],
            name="N", unit="Bohr")
        N2H2.add(
            [-1.13263930531833e+00, -3.46521429709182e-01, 0.00000000000000e+00],
            name="N", unit="Bohr")
        N2H2.add(
            [-2.18628828338372e+00, 1.29218485926721e+00, 0.00000000000000e+00],
            name="H", unit="Bohr")

        L = get_N2H2_L()
        P2 = numpy.zeros((6, 6))
        P3 = numpy.zeros((6, 6, 6))
        P4 = numpy.zeros((6, 6, 6, 6))
        read2("forces/N2H2_2.dat", 6, P2)
        read3("forces/N2H2_3.dat", 6, P3)
        read4("forces/N2H2_4.dat", 6, P4)
        # Named ``system`` rather than ``sys`` to avoid confusion with the
        # standard-library module of that name.
        system = MolSystem(mol=N2H2)
        system.L = L
        system.P2 = P2
        system.P3 = P3
        system.P4 = P4
        ci = VCI(system, N)
        return ci.solve()

    def _assert_energies(self, eout, eref):
        """Assert that ``eout[i]`` matches ``eref[i]`` for every reference.

        Only the first ``len(eref)`` entries of ``eout`` are examined. The
        failure message reports both values multiplied by
        ``constants.hartree_to_cm_1``.

        Args:
            eout: Indexable sequence of computed energies.
            eref: Sequence of reference energies.
        """
        htcm = constants.hartree_to_cm_1
        thresh = self.thresh/htcm
        for i, e in enumerate(eref):
            diff = abs(e - eout[i])
            msg = "Mode {}--Expected: {} Actual: {}".format(
                i, e*htcm, eout[i]*htcm)
            self.assertTrue(diff < thresh, msg)

    def test_N2H2_1(self):
        eout = self._test_N2H2(1)
        eref = numpy.array([
            0.0279851978576,
            0.0340138124507,
            0.0342115763996,
            0.0353438491542,
            0.035627841379,
            0.043574472661,
            0.0436990869888,
        ])
        self._assert_energies(eout, eref)

    def test_N2H2_2(self):
        eout = self._test_N2H2(2)
        eref = numpy.array([
            0.0279730614363,
            0.0338185967207,
            0.0339880687088,
            0.034951918873,
            0.0351981208782,
            0.0397485382901,
            0.03998719162,
            0.0401773235893,
            0.0410911710624,
            0.041204331806,
            0.0413668443712,
            0.0415910435056,
            0.0421409143886,
            0.0424532563903,
            0.0425728701453,
            0.0428162509098,
            0.0431699503389,
        ])
        self._assert_energies(eout, eref)

    def test_N2H2_3(self):
        eout = self._test_N2H2(3)
        eref = numpy.array([
            0.0278432685046,
            0.0337986955971,
            0.0339705839134,
            0.0349254591665,
            0.0351739646554,
            0.0396180714152,
            0.0398169634163,
            0.0399679472654,
            0.0407464786428,
            0.040866432672,
            0.0410115908437,
            0.0411129950859,
            0.0418502559586,
            0.0420261900262,
            0.0423036662633,
            0.0423267630333,
            0.0424332604406,
        ])
        self._assert_energies(eout, eref)

    def test_N2H2_4(self):
        eout = self._test_N2H2(4)
        eref = numpy.array([
            0.0278321851223,
            0.0336394896025,
            0.0337987523779,
            0.0347696260281,
            0.0350217132093,
            0.0395761378834,
            0.0397861649933,
            0.0399289694248,
            0.0407088969972,
            0.0408068257134,
            0.0409759296111,
            0.0410583464969,
            0.0417480860537,
            0.0418803613178,
            0.042111436707,
            0.0421957549186,
            0.042370526132,
        ])
        self._assert_energies(eout, eref)


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
