import copy
import logging
import numpy

from .constants import elements
from .constants import atomic_masses
from .constants import charges
from .constants import amu_to_el
from .constants import bohr_to_angstrom
from .constants import angstrom_to_bohr

from . import normal_modes
from . import phonons
from .compute import compute_forces
from .compute import compute_gradient
from .compute import compute_hessian


def is_linear(coords, tol=1e-5):
    """Return True if all points in ``coords`` lie on one straight line.

    Every point from the third on is tested against the first two points
    A and B: a point C counts as collinear when the cross product of
    (C - A) and (C - B) has a norm of at most ``tol``.

    Args:
        coords: Sequence of Cartesian points (each with three components).
        tol (float): Absolute tolerance on the cross-product norm.

    Returns:
        bool: True for fewer than three points or when every point is
        collinear with the first two, False otherwise.
    """
    if len(coords) < 3:
        return True

    A = numpy.asarray(coords[0], dtype=float)
    B = numpy.asarray(coords[1], dtype=float)
    for point in coords[2:]:
        C = numpy.asarray(point, dtype=float)
        N = numpy.cross(C - A, C - B)
        if numpy.linalg.norm(N) > tol:
            return False
    # Every point passed the collinearity test; an explicit True is needed
    # because falling off the end would return the falsy None.
    return True


class Molecule(object):
    """Molecule object

    Attributes:
        names (list): Names of atoms, stored as numpy byte strings
            (decoded with ``.decode("utf-8")`` wherever text is needed)
        charges (list): Charges of atoms (e)
        masses (list): Masses of atoms (presumably amu: MolSystem and
            buildI multiply them by amu_to_el -- confirm)
        coords (list): Coordinates of atoms (Bohr) as a list of triples
        natom (int): Number of atoms
        charge (int): Overall charge of molecule
        spin (int): Spin projection (nalpha - nbeta)
        thresh (float): Tolerance for deciding that a charge is an integer
    """
    def __init__(self, fh5=None):
        """Create an empty molecule, or load one from an hdf5 group.

        Args:
            fh5: Optional h5py file/group laid out as written by
                ``save_to_hdf5``. Files without a "charge" attribute are
                treated as neutral, with the spin set from the parity of
                the total nuclear charge.
        """
        self.thresh = 1e-12
        if fh5 is None:
            self.names = []
            self.charges = []
            self.masses = []
            self.coords = []
            self.natom = 0
            self.charge = 0
            self.spin = 0
        else:
            self.natom = fh5.attrs["natom"]
            self.names = list(fh5["names"])
            self.charges = list(fh5["charges"])
            self.masses = list(fh5["masses"])
            self.coords = numpy.array(fh5["coords"]).tolist()
            if "charge" in fh5.attrs.keys():
                self.charge = fh5.attrs["charge"]
                self.spin = fh5.attrs["spin"]
            else:
                logging.warning("Old-style molecule file detected, consider replacing it")
                self.charge = 0
                self.spin = (0 if self._n_proton() % 2 == 0 else 1)

    def _n_proton(self):
        """Return the total nuclear charge (sum of the atomic charges)."""
        return sum(self.charges)

    def copy(self):
        """Create a deep copy of the molecule."""
        newmol = Molecule()
        newmol.names = copy.deepcopy(self.names)
        newmol.charges = copy.deepcopy(self.charges)
        newmol.masses = copy.deepcopy(self.masses)
        newmol.coords = copy.deepcopy(self.coords)
        newmol.natom = self.natom
        newmol.charge = self.charge
        newmol.spin = self.spin
        return newmol

    def _verify_lengths(self):
        """Assert that every per-atom list has exactly ``natom`` entries."""
        l = self.natom
        assert(len(self.names) == l)
        assert(len(self.charges) == l)
        assert(len(self.masses) == l)
        assert(len(self.coords) == l)

    def add(self, xyz, name=None, charge=None, mass=None, unit="Angstrom"):
        """Add an atom to the molecule.

        Args:
            xyz: Position of the atom (three numbers).
            name (str): Element symbol; derived from ``charge`` if omitted
                ("xx" for a non-integer charge).
            charge (float): Nuclear charge; looked up from ``name`` if omitted.
            mass (float): Atomic mass; looked up from the rounded charge if
                omitted.
            unit (str): Unit of ``xyz``; a value starting with "a"/"A" means
                Angstrom (converted to Bohr), anything else means Bohr.

        Raises:
            Exception: if neither ``name`` nor ``charge`` is given.
        """
        if name is None and charge is None:
            raise Exception("molecule::add(): name/charge not specified")
        if name is None:
            c = int(round(charge))
            assert(c > 0)
            if abs(c - charge) < self.thresh:
                name = elements[c - 1]
            else:
                # Non-integer charge: no element symbol applies.
                name = "xx"
        if charge is None:
            charge = charges[elements.index(name)]
        if mass is None:
            c = int(round(charge))
            mass = atomic_masses[c - 1]

        # convert Angstrom to bohr if necessary
        if unit[0].lower() == "a":
            cc = angstrom_to_bohr
            xyz_new = [xyz[0]*cc, xyz[1]*cc, xyz[2]*cc]
        else:
            xyz_new = list(xyz)
        # numpy.bytes_ is the type numpy.string_ aliased (the alias was
        # removed in NumPy 2.0).
        self.names.append(numpy.bytes_(name))
        self.charges.append(charge)
        self.masses.append(mass)
        self.coords.append(xyz_new)
        self.natom += 1
        self._verify_lengths()
        nproton = self._n_proton()
        self.spin = 1 if nproton % 2 == 1 else 0

    def update(self, step, unit="Angstrom"):
        """Take a step in the coordinate space of the molecule.

        Args:
            step (numpy.ndarray): Displacement with 3*natom elements, ordered
                atom by atom as (x, y, z).
            unit (str): Unit of ``step``; a value starting with "b"/"B" means
                Bohr, anything else Angstrom.

        Raises:
            Exception: if ``step`` does not have 3*natom elements.
        """
        if step.size != 3*self.natom:
            raise Exception("Update has wrong dimensions!")
        const = 1.0 if unit[0].lower() == 'b' else angstrom_to_bohr
        sss = const*step.reshape((self.natom, 3))
        for i in range(self.natom):
            for a in range(3):
                self.coords[i][a] += sss[i, a]

    def get_coords(self, unit="Angstrom"):
        """Return the coordinates in the requested unit.

        Bohr returns a deep copy of ``coords`` (list of lists); Angstrom
        returns a new list of tuples.

        Raises:
            Exception: for a unit starting with neither "a" nor "b".
        """
        if unit.lower()[0] == "b":
            return copy.deepcopy(self.coords)
        else:
            if unit.lower()[0] != 'a':
                raise Exception("Unrecognized unit: {}".format(unit))
            b2a = bohr_to_angstrom
            return [(b2a*c[0], b2a*c[1], b2a*c[2]) for c in self.coords]

    def set(self, coords, unit="Angstrom"):
        """Set the coordinates.

        Args:
            coords: One position per atom; deep-copied.
            unit (str): A value starting with "a"/"A" means Angstrom
                (converted to Bohr); anything else is stored as given.

        Raises:
            Exception: if the number of positions does not match.
        """
        if len(coords) != len(self.coords):
            raise Exception("Molecule::set(): Provided coordinates are the wrong shape")
        self.coords = copy.deepcopy(coords)
        if unit.lower()[0] == "a":
            a2b = angstrom_to_bohr
            for i, c in enumerate(self.coords):
                self.coords[i] = [a2b*c[0], a2b*c[1], a2b*c[2]]

    def write_xyz_string(self, unit="angstrom"):
        """Return the atom lines of the molecule in 'xyz' format.

        Only the "name x y z" lines are produced (no atom-count or comment
        header). Coordinates are in Angstrom when ``unit`` starts with
        "a"/"A", otherwise in Bohr.
        """
        self._verify_lengths()
        A = (unit[0] == "A" or unit[0] == "a")
        ss = bohr_to_angstrom if A else 1.0
        return "".join(
            "{} {: .12f} {: .12f} {: .12f}\n".format(
                name.decode("utf-8"), ss*xyz[0], ss*xyz[1], ss*xyz[2])
            for name, xyz in zip(self.names, self.coords))

    def write_xyz(self, filename, unit="angstrom"):
        """Write the molecule to an 'xyz' file.

        Raises:
            Exception: if any atom has the placeholder name "xx" (non-integer
                charge), which the xyz format cannot represent.
        """
        # Names are stored as byte strings, so compare the decoded text;
        # comparing "xx" against the raw bytes never matches.
        elem = [n.decode("utf-8") for n in self.names]
        if "xx" in elem:
            raise Exception(
                "Elements with non-integer charges are not compatible with the xyz format")

        self._verify_lengths()
        A = (unit[0] == "A" or unit[0] == "a")
        ss = bohr_to_angstrom if A else 1.0
        with open(filename, 'w') as f:
            f.write(str(len(elem)) + "\n")
            f.write("xyz file automatically generated\n")
            for name, xyz in zip(elem, self.coords):
                f.write("{} {: .8f} {: .8f} {: .8f}\n".format(
                    name, ss*xyz[0], ss*xyz[1], ss*xyz[2]))

    def save_to_hdf5(self, obj):
        """Save object data in hdf5 format.

        Writes natom/charge/spin as attributes and names/charges/masses/
        coords as datasets of ``obj``.
        """
        obj.attrs["natom"] = self.natom
        obj.attrs["charge"] = self.charge
        obj.attrs["spin"] = self.spin
        obj.create_dataset("names", data=numpy.asarray(self.names))
        obj.create_dataset("charges", data=numpy.asarray(self.charges))
        obj.create_dataset("masses", data=numpy.asarray(self.masses))
        obj.create_dataset("coords", data=numpy.asarray(self.coords))

    def save(self, filename):
        """Save object to .h5 file."""
        import h5py
        with h5py.File(filename, 'w') as f:
            self.save_to_hdf5(f)

    def get_ase(self):
        """Return an ase ``Atoms`` object with positions in Angstrom."""
        from ase import Atom, Atoms
        acoords = self.get_coords(unit="Angstrom")
        alist = [Atom(n.decode("utf-8"), c) for n, c in zip(self.names, acoords)]
        return Atoms(alist)

    def std_print(self, unit="Angstrom"):
        """Print out molecule data with 'print'."""
        print("Name  Charge     Mass          x             y             z    ")
        self._verify_lengths()
        A = (unit[0] == "A" or unit[0] == "a")
        for i, name in enumerate(self.names):
            xyz = self.coords[i]
            ss = bohr_to_angstrom if A else 1.0
            print(" {:>2}   {:6.3f} {:11.6f} {:13.8f} {:13.8f} {:13.8f}".format(
                name.decode('utf-8'), self.charges[i], self.masses[i], ss*xyz[0], ss*xyz[1], ss*xyz[2]))

    def print_distances(self, prec=3, ncols=6, unit="Angstrom"):
        """Print a table of interatomic distances.

        The table is printed in vertical slices of ``ncols`` atom columns;
        a final partial slice is preceded by a blank line.

        Args:
            prec (int): Number of decimals per distance.
            ncols (int): Number of atom columns per slice.
            unit (str): "Angstrom" (default) or "Bohr".
        """
        dmat = numpy.zeros((self.natom, self.natom))
        for i, c1 in enumerate(self.coords):
            for j, c2 in enumerate(self.coords[(i + 1):]):
                dist = numpy.linalg.norm(numpy.asarray(c1) - numpy.asarray(c2))
                dmat[i, j + i + 1] = dist
                dmat[j + i + 1, i] = dist
        if unit.lower()[0] == 'a':
            dmat *= bohr_to_angstrom

        ffmt = '{:' + str(prec + 3) + "." + str(prec) + 'f}'
        sfmt = '{:2s}'

        nslices = self.natom // ncols
        nremain = self.natom % ncols
        for i in range(nslices):
            self._print_distance_columns(dmat, i*ncols, (i + 1)*ncols, sfmt, ffmt, prec)
        if nremain > 0:
            print("")
            self._print_distance_columns(dmat, nslices*ncols, self.natom, sfmt, ffmt, prec)

    def _print_distance_columns(self, dmat, start, stop, sfmt, ffmt, prec):
        """Print the header and every row for columns [start, stop) of dmat."""
        width = stop - start
        hstring = "      " + (sfmt + " " * (prec + 2)) * width
        print(hstring.format(*[x.decode('utf-8') for x in self.names[start:stop]]))
        pstring = " " + sfmt + (" " + ffmt) * width
        for j, row in enumerate(dmat[:, start:stop]):
            print(pstring.format(self.names[j].decode('utf-8'), *row))


def get_mol_from_xyz(filename, unit="Angstrom"):
    """Read a molecule from an xyz file.

    The file has the atom count on its first line, a comment line, then one
    "symbol x y z" line per atom. Trailing blank lines are ignored.

    Args:
        filename (str): Path of the xyz file.
        unit (str): Unit of the coordinates in the file (passed to
            ``Molecule.add``).

    Returns:
        Molecule: The molecule read from the file.

    Raises:
        Exception: if the atom count does not match the number of atom
            lines, or an atom line does not have exactly four fields.
    """
    with open(filename) as f:
        lines = f.readlines()
    # Many programs append blank lines; do not let them break the count
    # check. Never strip the (possibly blank) comment line itself.
    while len(lines) > 2 and not lines[-1].strip():
        lines.pop()
    natoms = int(lines[0].strip())
    if len(lines) - 2 != natoms:
        raise Exception("xyz file has inconsistent number of atoms")

    mol = Molecule()
    for line in lines[2:]:
        items = line.split()
        if len(items) != 4:
            raise Exception("xyz file has a malformed atom line: {!r}".format(line))
        xyz = [float(items[1]), float(items[2]), float(items[3])]
        mol.add(xyz, name=items[0], unit=unit)
    return mol


def get_mol_from_hdf5(filename):
    """Load a Molecule from an hdf5 file written by ``Molecule.save``.

    The file is always closed, even if reading the molecule fails.
    """
    import h5py
    with h5py.File(filename, 'r') as f:
        # Molecule(fh5=...) copies every dataset/attribute it needs, so the
        # file can be closed as soon as it returns.
        return Molecule(fh5=f)


def get_mol_from_ase(atoms):
    """Build a Molecule from an ase ``Atoms`` object.

    Positions are converted from Angstrom to Bohr. The molecule is neutral,
    and its spin is the parity of the total nuclear charge. Periodic
    boundary conditions are ignored (a warning is logged).

    Args:
        atoms: ase ``Atoms`` object.

    Returns:
        Molecule: New molecule whose per-atom attributes are plain lists.
    """
    pbc = atoms.get_pbc()
    if pbc.any():
        logging.warning("Creating a molecule from 1 unit cell of an ase object with PBC")

    numbers = numpy.asarray(atoms.get_atomic_numbers())
    pos = atoms.get_positions()

    newmol = Molecule()
    newmol.natom = len(numbers)
    # numpy.bytes_ replaces numpy.string_, which was removed in NumPy 2.0.
    newmol.names = [numpy.bytes_(elements[c - 1]) for c in numbers]
    # Store plain lists (as Molecule documents) so Molecule.add can append.
    newmol.charges = numbers.tolist()
    newmol.masses = numpy.asarray(atoms.get_masses()).tolist()
    newmol.coords = (angstrom_to_bohr*numpy.asarray(pos)).tolist()
    newmol.charge = 0
    newmol.spin = int(sum(newmol.charges) % 2)
    return newmol


def buildI(mol):
    """Compute the center of mass and the inertia tensor of a molecule.

    Args:
        mol (Molecule): Molecule whose ``masses`` and ``coords`` are used.

    Returns:
        tuple: ``(R, I)``. ``R`` is the mass-weighted mean of ``mol.coords``
        (a length-3 numpy array in the units of the coordinates). ``I`` is
        the 3x3 inertia tensor about ``R``, multiplied by ``amu_to_el``.
    """
    total = 0.0
    weighted = [0.0, 0.0, 0.0]
    for k, m in enumerate(mol.masses):
        px, py, pz = mol.coords[k]
        total += m
        for axis, value in enumerate((px, py, pz)):
            weighted[axis] += m*value
    center = numpy.asarray(weighted)/total

    inertia = numpy.zeros((3, 3))
    for k, m in enumerate(mol.masses):
        dx, dy, dz = mol.coords[k] - center
        inertia[0, 0] += m*(dy*dy + dz*dz)
        inertia[1, 1] += m*(dz*dz + dx*dx)
        inertia[2, 2] += m*(dx*dx + dy*dy)
        inertia[0, 1] -= m*dx*dy
        inertia[0, 2] -= m*dx*dz
        inertia[1, 2] -= m*dy*dz
    # The tensor is symmetric; mirror the upper triangle.
    inertia[1, 0] = inertia[0, 1]
    inertia[2, 0] = inertia[0, 2]
    inertia[2, 1] = inertia[1, 2]
    inertia *= amu_to_el
    return center, inertia


class MolSystem(object):
    """Molecule system container

    Attributes:
        mol (Molecule): Geometry of molecule (Bohr).
        linear (bool): Linear molecule?
        M (list): Mass of each Cartesian coordinate (length 3*natom), i.e.
            ``mol.masses`` repeated three times and scaled by amu_to_el.
        R (array): Center of mass returned by ``buildI``.
        E (float): Electronic energy at given geometry (E_h).
        F1 (array): Cartesian gradient.
        F2 (array): Cartesian Hessian.
        F3 (array): Cartesian 3rd derivatives.
        F4 (array): Cartesian 4th derivatives.
        L (array): Transformation matrix to normal modes.
        P1 (array): Normal mode gradient.
        P2 (array): Normal mode Hessian (not guaranteed to be diagonal).
        P3 (array): Normal mode 3rd order anharmonicities.
        P4 (array): Normal mode 4th order anharmonicities.
        I (array): Inertia tensor
    """
    # Array quantities stored as hdf5 datasets, in read/write order.
    _DATASETS = ("F1", "F2", "F3", "F4", "L", "P1", "P2", "P3", "P4")

    def __init__(self, filename=None, mol=None):
        """Create the system from a Molecule or from an .xyz/.h5 file.

        Args:
            filename (str): Path to an ".xyz" file or to an ".h5" file
                written by ``save``.
            mol (Molecule): Molecule to wrap.

        Raises:
            Exception: if both or neither of ``filename``/``mol`` are given,
                or the file extension is not recognized.
        """
        # arrays of force constants
        self.E = None
        self.F1 = None
        self.F2 = None
        self.F3 = None
        self.F4 = None
        self.L = None
        self.P1 = None
        self.P2 = None
        self.P3 = None
        self.P4 = None
        self.I = None

        if mol is not None and filename is not None:
            raise Exception("MolSystem::__init__(): Provide either filename or mol, not both")
        if mol is not None:
            self.mol = mol
        elif filename is not None:
            ext = filename.split(".")[-1]
            if ext == "xyz":
                self.mol = get_mol_from_xyz(filename)
            elif ext == "h5":
                self._load_hdf5(filename)
            else:
                raise Exception("Unrecognized file extension: " + ext)
        else:
            raise Exception("MolSystem::__init__(): No molecule provided!")

        self.linear = is_linear(self.mol.coords)
        self.M = [self.mol.masses[i//3]*amu_to_el for i in range(3*self.mol.natom)]
        self.R, self.I = buildI(self.mol)
        if self.linear:
            e, v = numpy.linalg.eigh(self.I)
            logging.info("Linear molecule detected (zero principal moment: {})".format(e[0]))

    def _load_hdf5(self, filename):
        """Read the molecule and any stored force constants from an .h5 file."""
        import h5py
        with h5py.File(filename, 'r') as f:
            self.mol = Molecule(fh5=f["molecule"])
            if "E" in f.attrs.keys():
                self.E = f.attrs["E"]
            for key in self._DATASETS:
                if key in f.keys():
                    setattr(self, key, numpy.array(f[key]))

    def set_forces(self, E=None, F1=None, F2=None, F3=None, F4=None):
        """Set the Cartesian force constants."""
        self.E = E
        self.F1 = F1
        self.F2 = F2
        self.F3 = F3
        self.F4 = F4

    def compute_forces(self, computer, order=2, delta=0.0012):
        """Compute Cartesian force constants."""
        self.E, self.F1, self.F2, self.F3, self.F4 = compute_forces(self.mol, computer, order, delta=delta)

    @staticmethod
    def _transform_derivatives(L, M, F1=None, F3=None, F4=None):
        """Project Cartesian derivatives onto the normal modes in ``L``.

        Every Cartesian index is weighted by 1/sqrt(M) before it is
        contracted with ``L``. The gradient is weighted the same way as the
        3rd and 4th derivatives, since all of them are contracted with the
        same ``L``.

        Returns:
            tuple: ``(P1, P3, P4)``; an entry is None when its input is None.
        """
        Mi2 = 1.0/numpy.sqrt(numpy.asarray(M))
        P1 = P3 = P4 = None
        if F1 is not None:
            P1 = numpy.einsum('i,i,ip->p', F1, Mi2, L)
        if F3 is not None:
            F3m = numpy.einsum('ijk,i,j,k->ijk', F3, Mi2, Mi2, Mi2)
            P3 = numpy.einsum('ijk,ip,jq,kr->pqr', F3m, L, L, L)
        if F4 is not None:
            F4m = numpy.einsum('ijkl,i,j,k,l->ijkl', F4, Mi2, Mi2, Mi2, Mi2)
            P4 = numpy.einsum('ijkl,ip,jq,kr,ls->pqrs', F4m, L, L, L, L)
        return P1, P3, P4

    def get_normal_modes(self, method="proj"):
        """Get normal modes from Cartesian force constants.

        Sets ``L`` and ``P2``. ``P1``/``P3``/``P4`` are set only when the
        matching Cartesian derivative is available.

        Raises:
            Exception: if no Hessian (``F2``) is available.
        """
        if self.F2 is None:
            raise Exception("Cannot compute normal modes withouth Hessian")

        natom = len(self.mol.coords)
        e, self.L = normal_modes.get_normal_modes(
            natom, self.mol.coords, self.M, self.F2,
            method, linear=self.linear, R=self.R, I=self.I)

        self.P2 = numpy.diag(e)
        P1, P3, P4 = self._transform_derivatives(self.L, self.M, self.F1, self.F3, self.F4)
        if P1 is not None:
            self.P1 = P1
        if P3 is not None:
            self.P3 = P3
        if P4 is not None:
            self.P4 = P4

    def save_to_hdf5(self, obj):
        """Save object contents to the given hdf5 object."""
        gm = obj.create_group("molecule")
        self.mol.save_to_hdf5(gm)
        if self.E is not None:
            obj.attrs["E"] = self.E
        for key in self._DATASETS:
            value = getattr(self, key)
            if value is not None:
                obj.create_dataset(key, data=value)

    def save(self, filename):
        """Save object contents with hdf5."""
        import h5py
        with h5py.File(filename, 'w') as f:
            self.save_to_hdf5(f)


class UnitCell(object):
    """Unit cell class

    Attributes:
        mol (Molecule): Atoms of the cell (coordinates in Bohr).
        lattice (list): The three lattice vectors in Bohr; each entry is a
            list of three numbers, or None when that vector is not set.
        dim (int): Number of lattice vectors that are not None.
    """
    def __init__(self, mol=None, lattice=None, fh5=None, unit="Angstrom"):
        """Build a cell from a molecule and lattice, or load it from hdf5.

        Args:
            mol (Molecule): Atoms of the cell. When neither ``mol`` nor
                ``fh5`` is given, an empty Molecule is used.
            lattice: Three lattice vectors (each a 3-sequence or None);
                defaults to ``[None, None, None]``.
            fh5: Optional h5py file/group as written by ``save``. It
                replaces ``lattice`` and cannot be combined with ``mol``.
            unit (str): Unit of ``lattice`` ("Angstrom" or "Bohr").

        Raises:
            Exception: if ``fh5`` is combined with ``mol``, or the lattice
                does not have three entries.
        """
        if lattice is None:
            lattice = [None, None, None]
        self.set_latt(lattice, unit=unit)

        if fh5 is not None:
            if any(v is not None for v in lattice):
                logging.warning("Non-trivial user-input lattice will be overwritten!")
            if mol is not None:
                logging.error("File ({}) specified along with other inputs".format(fh5))
                raise Exception("Other inputs cannot be specified along with file")
            # save() writes self.lattice, which set_latt keeps in Bohr.
            self.lattice = numpy.array(fh5["latt"]).tolist()
            self.mol = Molecule(fh5=fh5["molecule"])
        elif mol is not None:
            self.mol = mol
        else:
            self.mol = Molecule()

        if len(self.lattice) != 3:
            raise Exception("UnitCell::__init__: Invalid lattice!")
        self.dim = self._count_dim()

    def _count_dim(self):
        """Return the number of lattice vectors that are set (not None)."""
        return sum(1 for v in self.lattice if v is not None)

    def copy(self):
        """Return a deep copy of the cell."""
        mol = self.mol.copy()
        latt = copy.deepcopy(self.lattice)
        newcell = UnitCell(mol=mol, lattice=latt, unit="Bohr")
        return newcell

    def add(self, xyz, name=None, charge=None, mass=None, unit="Angstrom"):
        """Add an atom to the cell (see ``Molecule.add``)."""
        self.mol.add(xyz, name=name, charge=charge, mass=mass, unit=unit)

    def update(self, step, unit="Angstrom"):
        """Displace the atoms (see ``Molecule.update``)."""
        self.mol.update(step, unit=unit)

    def set(self, coords, unit="Angstrom"):
        """Set the atomic coordinates (see ``Molecule.set``)."""
        self.mol.set(coords, unit=unit)

    def set_latt(self, lattice, unit="Angstrom"):
        """Set the lattice vectors.

        The vectors are copied into fresh lists, so tuples, numpy arrays and
        similar row containers are all accepted.

        Args:
            lattice: Three lattice vectors (each a 3-sequence or None).
            unit (str): A value starting with "a"/"A" means Angstrom
                (converted to Bohr); anything else means Bohr.

        Raises:
            NotImplementedError: for a 6-element lattice-parameter spec.
            Exception: for any other number of entries.
        """
        if len(lattice) == 6:
            raise NotImplementedError(
                "Lattice parameter (a, b, c, alpha, beta, gamma) input is not supported yet")
        if len(lattice) != 3:
            logging.error("Unrecognized lattice specification in UnitCell::set_latt")
            raise Exception("Unrecognized lattice specification!")
        if unit[0].lower() == 'a':
            a2b = angstrom_to_bohr
            self.lattice = [None if x is None else [a2b*x[0], a2b*x[1], a2b*x[2]]
                            for x in lattice]
        else:
            self.lattice = [None if x is None else [x[0], x[1], x[2]] for x in lattice]
        self.dim = self._count_dim()

    def dimension(self):
        """Return the number of lattice vectors that are set."""
        return self.dim

    def supercell(self, nimgs):
        """Return a new cell replicated ``nimgs[k]`` times along vector k.

        Images are generated with ``itertools.product`` (last index varies
        fastest) and atoms are added in Bohr. All three lattice vectors must
        be set; None entries are not supported.
        """
        import itertools
        newlat = [numpy.asarray(l)*n for l, n in zip(self.lattice, nimgs)]
        ranges = [range(n) for n in nimgs]
        newmol = Molecule()
        npl = numpy.asarray(self.lattice)
        for n1, n2, n3 in itertools.product(*ranges):
            T = n1*npl[0] + n2*npl[1] + n3*npl[2]
            for c, n in zip(self.mol.coords, self.mol.names):
                newmol.add(numpy.asarray(c) + T, name=n.decode("utf-8"), unit="Bohr")
        return UnitCell(mol=newmol, lattice=newlat, unit="Bohr")

    def save(self, filename):
        """Write the molecule (group "molecule") and lattice ("latt") to hdf5."""
        import h5py
        with h5py.File(filename, 'w') as f:
            gm = f.create_group("molecule")
            self.mol.save_to_hdf5(gm)
            f.create_dataset("latt", data=numpy.asarray(self.lattice))


def get_cell_from_hdf5(filename):
    """Load a UnitCell from an hdf5 file written by ``UnitCell.save``.

    The file is always closed, even if reading the cell fails.
    """
    import h5py
    with h5py.File(filename, 'r') as f:
        # UnitCell(fh5=...) copies the lattice and molecule data it needs,
        # so the file can be closed as soon as it returns.
        return UnitCell(fh5=f)


def get_cell_from_ase(atoms):
    """Build a UnitCell from an ase ``Atoms`` object.

    Positions and cell vectors are taken in Angstrom and stored in Bohr.
    The molecule is neutral, and its spin is the parity of the total
    nuclear charge.

    Args:
        atoms: ase ``Atoms`` object.

    Returns:
        UnitCell: New cell whose molecule has plain-list attributes.
    """
    numbers = numpy.asarray(atoms.get_atomic_numbers())
    pos = atoms.get_positions()

    newmol = Molecule()
    newmol.natom = len(numbers)
    # numpy.bytes_ replaces numpy.string_, which was removed in NumPy 2.0.
    newmol.names = [numpy.bytes_(elements[c - 1]) for c in numbers]
    # Store plain lists (as Molecule documents) so Molecule.add can append.
    newmol.charges = numbers.tolist()
    newmol.masses = numpy.asarray(atoms.get_masses()).tolist()
    newmol.coords = (angstrom_to_bohr*numpy.asarray(pos)).tolist()
    newmol.charge = 0
    newmol.spin = int(sum(newmol.charges) % 2)

    # Convert the ase cell (a Cell object in recent ase) into a plain nested
    # list before handing it to UnitCell.
    ase_cell = numpy.asarray(atoms.get_cell()).tolist()
    return UnitCell(mol=newmol, lattice=ase_cell)


def cluster(cell, mlat):
    """Build a finite cluster of translated copies of a unit cell.

    One copy of the cell's atoms is placed at every translation
    n1*a1 + n2*a2 + n3*a3, where a1..a3 are the rows of ``cell.lattice``
    (Bohr). The translations cover all non-negative integers with
    n1 + n2 + n3 < mlat, in lexicographic order of (n1, n2, n3).

    Args:
        cell (UnitCell): Cell to replicate.
        mlat (int): Bound on n1 + n2 + n3 (exclusive).

    Returns:
        Molecule: The assembled cluster.
    """
    result = Molecule()
    lattice = numpy.asarray(cell.lattice)
    images = [
        (n1, n2, n3)
        for n1 in range(mlat)
        for n2 in range(mlat - n1)
        for n3 in range(mlat - n1 - n2)
    ]
    for n1, n2, n3 in images:
        shift = n1*lattice[0] + n2*lattice[1] + n3*lattice[2]
        for position, label in zip(cell.mol.coords, cell.mol.names):
            result.add(numpy.asarray(position) + shift, name=label.decode("utf-8"), unit="Bohr")
    return result


class PBCSystem(object):
    """Container for a periodic system and its force constants.

    Attributes:
        cell (UnitCell): The unit cell.
        M (list): Mass of each Cartesian coordinate of the cell atoms
            (length 3*natom), i.e. the atomic masses scaled by amu_to_el.
        E: Energy set by ``set_forces``/``compute_forces``.
        F1: Cartesian gradient set by ``set_forces``/``compute_forces``.
        F2s: ``[supercell, hessian]`` pair stored by ``compute_hessian``.
        F2q: Reciprocal-space force constants (not set by this class).
        P2: Reserved; not set by this class.
    """
    def __init__(self, filename=None, cell=None):
        """Wrap a unit cell.

        Raises:
            NotImplementedError: if only ``filename`` is given (loading from
                a file is not supported yet).
            Exception: if both or neither of ``filename``/``cell`` are given.
        """
        # arrays of force constants
        self.E = None
        self.F1 = None
        self.F2s = None
        self.F2q = None
        self.P2 = None

        if cell is not None:
            if filename is not None:
                raise Exception("PBCSystem::__init__: Provide either a file or a unit cell, not both")
            self.cell = cell
        elif filename is not None:
            raise NotImplementedError("PBCSystem from file is not yet supported")
        else:
            raise Exception("PBCSystem::__init__: No unit cell provided!")

        self.M = [self.cell.mol.masses[i//3]*amu_to_el for i in range(3*self.cell.mol.natom)]

    def set_forces(self, E=None, F1=None):
        """Set the Cartesian force constants."""
        self.E = E
        self.F1 = F1

    def compute_forces(self, computer, delta=0.0012):
        """Compute the energy and Cartesian gradient of the unit cell."""
        self.E, self.F1 = compute_gradient(self.cell, self.cell.mol.coords, computer, delta=delta)

    def compute_hessian(self, computer, supercell=None, delta=0.0012):
        """Compute the Hessian of a supercell and store it in ``F2s``.

        Raises:
            NotImplementedError: if ``supercell`` is not given.
        """
        if supercell is None:
            raise NotImplementedError
        suc = self.cell.supercell(supercell)
        _, _, F2s = compute_hessian(suc, suc.mol.coords, computer, delta)
        self.F2s = [supercell, F2s]

    def get_phonons(self, q):
        """Compute the phonon frequencies for a given q-vector.

        Uses the supercell Hessian stored by ``compute_hessian``.

        Raises:
            NotImplementedError: if ``F2q`` is set (q-space input is not
                implemented).
            Exception: if no supercell Hessian is available.
        """
        if self.F2q is not None:
            raise NotImplementedError
        if self.F2s is None:
            raise Exception("PBCSystem::get_phonons: No supercell Hessian, call compute_hessian first")
        supercell, F2s = self.F2s
        suc = self.cell.supercell(supercell)
        # TODO: consider caching the computed frequencies
        return phonons.get_phonons_supercell(
            self.cell.mol.natom, self.M, self.cell.mol.coords,
            supercell, suc.mol.coords, F2s, q)
