import numpy
from .bond import get_adjacency


def _eval_dihedral(c1, c2, c3, c4):
    """Compute the signed dihedral angle of the chain c1-c2--c3-c4.

    The outer vectors c1 - c2 and c4 - c3 are projected onto the plane
    perpendicular to the central bond c3 - c2, and the angle between the
    two projections is measured.  The angle is made negative when
    (p1 x p4) . (c3 - c2) > 0, where p1 and p4 are the projections.

    Args:
        c1, c2, c3, c4 (numpy.ndarray): Cartesian positions of the atoms.

    Returns:
        tuple: (phi, n1, n2) -- the signed angle in radians and the
        lengths of the two projected vectors.  If either projection has
        zero length the angle is NaN.
    """
    axis = c3 - c2
    axis_sq = numpy.linalg.norm(axis)**2

    def _perp(vec):
        # remove the component of vec along the central bond
        return vec - numpy.dot(vec, axis)*axis/axis_sq

    p1 = _perp(c1 - c2)
    p4 = _perp(c4 - c3)
    len1 = numpy.linalg.norm(p1)
    len4 = numpy.linalg.norm(p4)

    # clip guards arccos against round-off just outside [-1, 1]
    cos_phi = numpy.clip(numpy.dot(p1, p4)/(len1*len4), -1.0, 1.0)
    phi = numpy.arccos(cos_phi)
    if numpy.dot(numpy.cross(p1, p4), axis) > 0.0:
        phi = -phi
    return phi, len1, len4


def _eval_derivative(c1, c2, c3, c4):
    """Return the Cartesian gradient of the signed dihedral c1-c2--c3-c4.

    The angle is the one computed by ``_eval_dihedral``: it increases when
    atom 1 is rotated right-handedly about the c2 -> c3 axis, and decreases
    when atom 4 is.

    The closed-form gradient of Blondel & Karplus (J. Comput. Chem. 17,
    1132 (1996)) is used.  It has no singularity at phi = 0 or +/-pi, so no
    special-casing or finite differences are required.  It only needs the
    central bond to have non-zero length and neither outer atom to lie on
    the bond axis; otherwise the division by zero yields inf/NaN.

    Args:
        c1, c2, c3, c4 (numpy.ndarray): Cartesian positions of the atoms.

    Returns:
        tuple: (g1, g2, g3, g4), derivatives of the dihedral (radians) with
        respect to the positions of atoms 1-4.  They sum to zero
        (translational invariance).
    """
    u1 = c1 - c2
    u2 = c4 - c3
    vb = c3 - c2

    nb2 = numpy.dot(vb, vb)
    bhat = vb/numpy.sqrt(nb2)

    # Fractions of the outer vectors lying along the central bond.
    f1 = numpy.dot(u1, vb)/nb2
    f2 = numpy.dot(u2, vb)/nb2

    # Components perpendicular to the bond (the vectors whose angle is phi).
    v1 = u1 - f1*vb
    v2 = u2 - f2*vb

    # Outer atoms: the gradient points along bhat x v (the direction of a
    # rotation about the bond) with magnitude 1/|v|.
    g1 = numpy.cross(bhat, v1)/numpy.dot(v1, v1)
    g4 = -numpy.cross(bhat, v2)/numpy.dot(v2, v2)

    # Central atoms follow from translational and rotational invariance.
    g2 = (f1 - 1.0)*g1 + f2*g4
    g3 = -f1*g1 - (1.0 + f2)*g4
    return g1, g2, g3, g4


class Dihedral(object):
    """Representation of a dihedral angle

    The topology is 1-2--3-4: the angle is measured about the 2--3 bond.

    Attributes:
        i1 (int): Index of first atom
        i2 (int): Index of second (middle bond) atom
        i3 (int): Index of third (middle bond) atom
        i4 (int): Index of fourth atom
        e1 (str): Label (e.g. element) of first atom
        e2 (str): Label of second (middle bond) atom
        e3 (str): Label of third (middle bond) atom
        e4 (str): Label of fourth atom
        phi (float): Initial value of dihedral in radians
    """
    def __init__(self, i1, i2, i3, i4, e1, e2, e3, e4, phi):
        self.i1 = i1
        self.i2 = i2
        self.i3 = i3
        self.i4 = i4
        self.e1 = e1
        self.e2 = e2
        self.e3 = e3
        self.e4 = e4
        self.phi = phi

    def __repr__(self):
        return "{}({!r}, {!r}, {!r}, {!r}): {!r}".format(
            self.dihedral_type(), self.i1, self.i2, self.i3, self.i4,
            self.phi)

    def dihedral_type(self):
        """Return a string representing the type of dihedral angle"""
        return self.e1 + "-" + self.e2 + "--" + self.e3 + "-" + self.e4

    def equiv(self, i1, i2, i3, i4):
        """Return true if the given indices represent the same dihedral.

        The middle pair and the outer pair are compared as unordered sets,
        so the reversed ordering (i4, i3, i2, i1) matches.  Note that this
        also matches e.g. (i4, i2, i3, i1).
        """
        return {i2, i3} == {self.i2, self.i3} and {i1, i4} == {self.i1, self.i4}

    def value(self):
        """Return the initial value of the dihedral"""
        return self.phi

    def _indices(self):
        """Return the four atom indices in topological order."""
        return (self.i1, self.i2, self.i3, self.i4)

    def _positions(self, newcoords):
        """Return the positions of the four atoms as numpy arrays."""
        return [numpy.asarray(newcoords[i]) for i in self._indices()]

    def evaluate(self, newcoords):
        """Return the value of the dihedral at the given coordinates.

        Args:
            newcoords: per-atom Cartesian coordinates, indexable by atom
                index.

        The raw value lies in [-pi, pi].  If the initial value is larger
        than pi/2 in magnitude and the new value has the opposite sign,
        the angle has presumably crossed the +/-pi branch cut, so it is
        shifted by 2*pi to stay on the same side as the initial value.
        """
        xphi = _eval_dihedral(*self._positions(newcoords))[0]
        if (numpy.sign(xphi) != numpy.sign(self.phi)
                and abs(self.phi) > numpy.pi/2):
            xphi += (2*numpy.pi if xphi < 0.0 else -2*numpy.pi)
        return xphi

    def grad(self, newcoords):
        """Return the gradient of the dihedral at the given coordinates.

        The result is a flat array of length 3*len(newcoords); atom k owns
        entries 3*k .. 3*k + 2, and atoms not in the dihedral are zero.
        """
        grads = _eval_derivative(*self._positions(newcoords))
        gv = numpy.zeros(3*len(newcoords))
        for idx, g in zip(self._indices(), grads):
            gv[3*idx:3*idx + 3] = g
        return gv


def get_dihedrals(blist, flist, mol, vtol=1e-3):
    """Return a list of dihedral angles given a list of bonds.

    For every atom ``idx`` in each fragment of ``flist``, every bonded
    neighbour ``jdx`` defines a central bond idx--jdx; each path
    iidx-idx--jdx-jjdx through that bond is a candidate dihedral.  A
    dihedral is reported only once (equivalence as in ``Dihedral.equiv``).
    Paths whose end atoms coincide (three-membered rings) are skipped: the
    "dihedral" i-j-k-i is identically zero and its gradient cannot be
    represented with atom 1 == atom 4.

    If the norm of the vectors determining the value of the dihedral
    is less than vtol, the dihedral is considered to be ill-defined
    and is not included.

    Args:
        blist: bond list, passed to ``get_adjacency``.
        flist: iterable of fragments, each an iterable of atom indices.
        mol: molecule exposing ``natom``, ``coords`` and ``names``.
        vtol (float): minimum norm of the perpendicular vectors.

    Returns:
        list of Dihedral
    """
    adj = get_adjacency(blist, mol.natom)
    dlist = []
    # Keys of dihedrals already in dlist.  Comparing unordered middle and
    # outer pairs reproduces Dihedral.equiv with O(1) lookups.
    seen = set()

    def _label(name):
        # names were stored as bytes originally; also accept plain str
        return name.decode("utf-8") if hasattr(name, "decode") else name

    for frag in flist:
        for idx in frag:
            for jdx in adj[idx]:
                for iidx in adj[idx]:
                    if iidx == jdx:
                        continue
                    for jjdx in adj[jdx]:
                        if jjdx == idx or jjdx == iidx:
                            continue
                        key = (frozenset((idx, jdx)), frozenset((iidx, jjdx)))
                        if key in seen:
                            continue

                        phi, n1, n2 = _eval_dihedral(
                            numpy.asarray(mol.coords[iidx]),
                            numpy.asarray(mol.coords[idx]),
                            numpy.asarray(mol.coords[jdx]),
                            numpy.asarray(mol.coords[jjdx]))

                        if n1 > vtol and n2 > vtol:
                            seen.add(key)
                            dlist.append(Dihedral(
                                iidx, idx, jdx, jjdx,
                                _label(mol.names[iidx]),
                                _label(mol.names[idx]),
                                _label(mol.names[jdx]),
                                _label(mol.names[jjdx]),
                                phi))
    return dlist
