import os
import shutil
import tempfile
import unittest

import numpy

from omega.system import Molecule
from omega.system import get_mol_from_xyz
from omega.system import get_mol_from_ase
from omega.system import get_mol_from_hdf5
from omega.constants import elements
from omega.constants import atomic_masses
from omega.constants import charges
from omega.constants import angstrom_to_bohr

# Probe for the optional ``ase`` package; ``has_ase`` gates the ASE
# round-trip test via ``unittest.skipUnless`` below.
try:
    import ase  # noqa: F401  (imported only to check availability)
except ImportError:
    has_ase = False
else:
    has_ase = True


class MoleculeTest(unittest.TestCase):
    """Unit tests for ``omega.system.Molecule`` and its constructors.

    Covers atom lookup by element name and by charge, per-element masses
    and charges, and round trips through xyz, HDF5 and (optionally) ASE.
    """

    def setUp(self):
        # Absolute tolerance used for all floating-point comparisons.
        self.thresh = 1e-14

    def test_He(self):
        """Adding He by name or by charge gives the same mass, charge and name."""
        mymol = Molecule()
        mymol.add((0.0, 0.0, 0.0), name="He")
        self.assertAlmostEqual(
            mymol.masses[0], atomic_masses[1], delta=self.thresh)
        self.assertAlmostEqual(mymol.charges[0], 2.0, delta=self.thresh)

        mymol = Molecule()
        mymol.add((0.0, 0.0, 0.0), charge=2.0)
        self.assertAlmostEqual(
            mymol.masses[0], atomic_masses[1], delta=self.thresh)
        # numpy.bytes_ is the portable spelling: numpy.string_ (its old
        # alias) was removed in NumPy 2.0.
        self.assertEqual(mymol.names[0], numpy.bytes_("He"))

    def test_atoms(self):
        """Each entry of ``elements`` gets the matching tabulated mass and charge."""
        for i, el in enumerate(elements):
            mymol = Molecule()
            mymol.add((0.0, 0.0, 0.0), name=el)

            # Include the element in the message so a failure is locatable.
            msg = "Inconsistent masses for {}: {} {}".format(
                el, mymol.masses[0], atomic_masses[i])
            self.assertAlmostEqual(
                mymol.masses[0], atomic_masses[i], delta=self.thresh, msg=msg)
            msg = "Inconsistent charges for {}: {} {}".format(
                el, mymol.charges[0], charges[i])
            self.assertAlmostEqual(
                mymol.charges[0], charges[i], delta=self.thresh, msg=msg)

    def test_xyz(self):
        """Reading xyz/H2.xyz puts atom 0 at z=+0.36 and atom 1 at z=-0.36 (scaled)."""
        # Resolve the data file relative to this test module rather than the
        # current working directory, so the test runs from any directory.
        path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "xyz", "H2.xyz")
        mol = get_mol_from_xyz(path)
        exp = 0.36*angstrom_to_bohr
        act1 = mol.coords[0][2]
        act2 = -mol.coords[1][2]
        msg = "Error reading xyz--Expected: {} Actual: {}".format(exp, act1)
        self.assertAlmostEqual(exp, act1, delta=self.thresh, msg=msg)
        msg = "Error reading xyz--Expected: {} Actual: {}".format(exp, act2)
        self.assertAlmostEqual(exp, act2, delta=self.thresh, msg=msg)

    def test_h5py(self):
        """Saving a molecule to HDF5 and reloading it reproduces it exactly."""
        # create H2 molecule
        mol = Molecule()
        mol.add((0.0, 0.0, 0.76), charge=1)
        mol.add((0.0, 0.0, 0.0), charge=1)

        # Use a private scratch directory instead of a fixed name in the cwd.
        # Clean it up in ``finally`` so a failing save/load leaves nothing behind.
        tmpdir = tempfile.mkdtemp()
        try:
            fname = os.path.join(tmpdir, "_test.h5")
            mol.save(fname)
            nmol = get_mol_from_hdf5(fname)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

        # compare -- assert_array_equal accepts lists and numpy arrays alike.
        # It also checks shapes. A bare ``==`` on multi-element arrays would
        # make assertTrue raise "truth value is ambiguous".
        self.assertEqual(mol.natom, nmol.natom)
        numpy.testing.assert_array_equal(mol.names, nmol.names)
        numpy.testing.assert_array_equal(mol.charges, nmol.charges)
        numpy.testing.assert_array_equal(mol.masses, nmol.masses)
        numpy.testing.assert_array_equal(mol.coords, nmol.coords)

    @unittest.skipUnless(has_ase, "Requires ase")
    def test_ase(self):
        """Converting an N2 ase.Atoms to Molecule and back keeps masses and positions."""
        from ase import Atom, Atoms
        d = 1.104  # N2 bondlength (Angstrom)
        a = Atoms([Atom('N', (0, 0, 0)), Atom('N', (0, 0, d))])
        mol = get_mol_from_ase(a)
        a2 = mol.get_ase()
        # zip() below silently truncates, so a dropped atom would go unnoticed
        # without this explicit count check.
        self.assertEqual(len(a), len(a2))
        for mm1, mm2 in zip(a.get_masses(), a2.get_masses()):
            self.assertAlmostEqual(mm1, mm2, delta=self.thresh)
        for pp1, pp2 in zip(a.get_positions(), a2.get_positions()):
            c1 = numpy.asarray(pp1)
            c2 = numpy.asarray(pp2)
            self.assertLess(numpy.linalg.norm(c1 - c2), self.thresh)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
