import unittest
import numpy
import os
from omega.system import Molecule
from omega.system import UnitCell
from omega.system import get_cell_from_hdf5
from omega.system import get_cell_from_ase
from omega.constants import angstrom_to_bohr

# Optional dependencies: tests needing them are skipped when unavailable.
try:
    import ase
except ImportError:
    has_ase = False
else:
    has_ase = True

try:
    import pyscf
except ImportError:
    has_pyscf = False
else:
    has_pyscf = True


class UnitCellTest(unittest.TestCase):
    """Tests for UnitCell construction, unit handling, I/O and supercells."""

    # Lattice parameter ``a`` (in Bohr) of the diamond cell used throughout:
    # lattice vectors are (0, a, a), (a, 0, a), (a, a, 0).
    DIAMOND_A_BOHR = 3.370137329

    @staticmethod
    def _diamond_cell(a, unit):
        """Build a two-carbon diamond-structure cell.

        Args:
            a: lattice parameter expressed in ``unit``; the lattice vectors
                are (0, a, a), (a, 0, a), (a, a, 0).
            unit: unit string forwarded to both ``Molecule.add`` and
                ``UnitCell`` (e.g. ``"Bohr"`` or ``"Angstrom"``).

        Returns:
            A UnitCell with carbon atoms at the origin and at (a/2, a/2, a/2).
        """
        lattice = [[0.0, a, a], [a, 0.0, a], [a, a, 0.0]]
        mol = Molecule()
        mol.add((0.0, 0.0, 0.0), name="C", unit=unit)
        mol.add((a/2, a/2, a/2), name="C", unit=unit)
        return UnitCell(mol=mol, lattice=lattice, unit=unit)

    def _assert_vectors_close(self, ref, out, tol=1e-14):
        """Assert two sequences of vectors match pairwise.

        Fails if the sequences have different lengths (a bare ``zip`` would
        silently ignore the extra entries) or if any pair differs by more
        than ``tol`` in Euclidean norm.
        """
        ref = list(ref)
        out = list(out)
        self.assertEqual(len(ref), len(out))
        for i, (r, o) in enumerate(zip(ref, out)):
            err = numpy.linalg.norm(numpy.asarray(o) - numpy.asarray(r))
            self.assertLess(err, tol, msg="vector %d differs by %g" % (i, err))

    def _assert_cells_close(self, ref, out, tol=1e-14):
        """Assert two cells have matching lattice vectors and atom coordinates."""
        self._assert_vectors_close(ref.lattice, out.lattice, tol)
        self._assert_vectors_close(ref.mol.coords, out.mol.coords, tol)

    def test_input_diamond(self):
        """The same cell given in Bohr and in Angstrom must compare equal."""
        ab = self.DIAMOND_A_BOHR
        aa = ab/angstrom_to_bohr
        ucb = self._diamond_cell(ab, "Bohr")
        uca = self._diamond_cell(aa, "Angstrom")
        self._assert_cells_close(ucb, uca)

    def test_h5py(self):
        """A cell saved to HDF5 and read back must match within 1e-14."""
        import tempfile
        uc = self._diamond_cell(self.DIAMOND_A_BOHR, "Bohr")
        # A private temporary directory avoids clobbering files in the
        # working directory and is cleaned up even if loading raises.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, "_test.h5")
            uc.save(filename)
            uc2 = get_cell_from_hdf5(filename)
        self._assert_cells_close(uc, uc2)

    @unittest.skipUnless(has_pyscf, "Requires PySCF")
    def test_supercell(self):
        """Per-cell energy must agree between k-point and supercell sampling.

        The primitive cell computed on a 2x2x2 k-mesh is compared with the
        2x2x2 supercell computed on a [1, 1, 1] mesh.
        """
        from omega.pyscf_interface import PBCPyscfInterface, PBCOptions
        mesh = [2, 2, 2]
        uc = self._diamond_cell(self.DIAMOND_A_BOHR, "Bohr")
        options = PBCOptions()
        options.method = 'lda'
        options.basis = 'gth-dzvp'
        options.pseudo = 'gth-pade'
        options.kmesh = list(mesh)
        options.precision = 1e-10
        options.scf_convergence = 1e-10
        computer = PBCPyscfInterface(options)
        Eref = computer.energy(uc)

        # The supercell replicates the primitive cell to match the k-mesh,
        # so its total energy should be ncell times the per-cell reference.
        suc = uc.supercell(list(mesh))
        ncell = int(numpy.prod(mesh))
        options.kmesh = [1, 1, 1]
        computer = PBCPyscfInterface(options)
        Eout = computer.energy(suc)
        self.assertLess(abs(Eout/ncell - Eref), 1e-5)

    @unittest.skipUnless(has_ase, "Requires ASE")
    def test_ase(self):
        """Lattice vectors imported from an ASE structure have the expected length."""
        from ase.io import read
        # Prefer the data file next to this test module; fall back to the
        # original working-directory-relative location.
        path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "xyz", "Si.cif")
        if not os.path.exists(path):
            path = "xyz/Si.cif"
        si_ase = read(path)
        cell = get_cell_from_ase(si_ase)
        # cell.lattice is divided by angstrom_to_bohr before comparison with
        # an Angstrom reference, i.e. it is assumed to be stored in Bohr.
        aarray = numpy.asarray(cell.lattice)/angstrom_to_bohr
        self.assertEqual(len(aarray), 3)
        # The original check lacked abs() and so accepted any shorter vector.
        # NOTE(review): confirm 2.7361 Angstrom against xyz/Si.cif now that
        # the comparison is two-sided.
        for vec in aarray:
            self.assertAlmostEqual(
                numpy.linalg.norm(vec), 2.7361, delta=1e-4)


if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()
