import numpy
from .bond import get_adjacency


def _eval_angle(c1, c2, c3):
    """Return the angle (in radians) at vertex c2 formed by c1-c2-c3.

    Args:
        c1, c2, c3: Coordinate vectors (numpy arrays) of the first, middle
            and third point.

    Returns:
        The angle in [0, pi]. The result is nan if c1 or c3 coincides with
        c2, because the normalisation divides by a zero length.
    """
    ij = c1 - c2
    kj = c3 - c2
    cos = numpy.dot(ij, kj)/(numpy.linalg.norm(ij)*numpy.linalg.norm(kj))
    # For (near-)collinear points, floating-point rounding can make |cos|
    # marginally larger than 1, where arccos would return nan; clip to the
    # valid domain.
    return numpy.arccos(numpy.clip(cos, -1.0, 1.0))


def _eval_deriv(c1, c2, c3):
    """Return the Cartesian gradient of the angle c1-c2-c3.

    Args:
        c1, c2, c3: Coordinate vectors (numpy arrays) of the first, middle
            and third point.

    Returns:
        Tuple (g1, g2, g3) holding d(theta)/d(c1), d(theta)/d(c2) and
        d(theta)/d(c3), each with the same shape as the inputs.

    The derivative d(theta)/d(cos) = -1/sin(theta) is singular for linear
    angles (theta = 0 or pi), so the result is infinite there.
    """
    ij = c1 - c2
    kj = c3 - c2
    nij = numpy.linalg.norm(ij)
    nkj = numpy.linalg.norm(kj)
    uij = ij/nij
    ukj = kj/nkj
    # Rounding can push |cos| slightly above 1 for (near-)linear geometries.
    # Clip it so the sqrt below never receives a negative argument (nan).
    cos = numpy.clip(numpy.dot(uij, ukj), -1.0, 1.0)
    dtheta_dcos = -1.0/numpy.sqrt(1.0 - cos*cos)
    # Gradient of cos(theta) with respect to the two end points.
    dcos1 = (ukj - cos*uij)/nij
    dcos3 = (uij - cos*ukj)/nkj
    # Moving all three points together leaves the angle unchanged, so the
    # middle-point gradient is minus the sum of the end-point gradients.
    dcos2 = -(dcos1 + dcos3)
    return dtheta_dcos*dcos1, dtheta_dcos*dcos2, dtheta_dcos*dcos3


class Angle(object):
    """A representation of a bond angle i1-i2-i3 with vertex atom i2.

    Attributes:
        i1 (int): Index of first atom
        i2 (int): Index of second (middle) atom
        i3 (int): Index of third atom
        e1 (str): Element label of first atom
        e2 (str): Element label of second (middle) atom
        e3 (str): Element label of third atom
        theta (float): Initial value of bond angle (radians when built by
            ``get_bond_angles``, which stores the result of an arccos)
    """
    def __init__(self, i1, i2, i3, e1, e2, e3, theta):
        self.i1 = i1
        self.i2 = i2
        self.i3 = i3
        self.e1 = e1
        self.e2 = e2
        self.e3 = e3
        self.theta = theta

    def __repr__(self):
        # e.g. "H-O-H(0, 1, 2): 1.82"
        return "%s(%r, %r, %r): %r" % (
            self.angle_type(), self.i1, self.i2, self.i3, self.theta)

    def angle_type(self):
        """Return string representation of angle type, e.g. "H-O-H"."""
        return "-".join((self.e1, self.e2, self.e3))

    def includes(self, idx):
        """Return True if the angle includes the atom specified by idx."""
        return idx in (self.i1, self.i2, self.i3)

    def equiv(self, i1, i2, i3):
        """Return True if the given indices reference the same angle.

        The middle index must match exactly; the end indices may appear in
        either order, since i1-i2-i3 and i3-i2-i1 are the same angle.
        """
        if i2 != self.i2:
            return False
        return (i1, i3) in ((self.i1, self.i3), (self.i3, self.i1))

    def value(self):
        """Return the stored (initial) value of the angle."""
        return self.theta

    def _positions(self, newcoords):
        """Return the coordinates of atoms i1, i2, i3 as numpy arrays."""
        return (numpy.asarray(newcoords[self.i1]),
                numpy.asarray(newcoords[self.i2]),
                numpy.asarray(newcoords[self.i3]))

    def evaluate(self, newcoords):
        """Return the value of the angle evaluated at the given coordinates.

        Args:
            newcoords: Per-atom coordinates, indexable by atom index.
        """
        return _eval_angle(*self._positions(newcoords))

    def grad(self, newcoords):
        """Return the gradient of the angle evaluated at the given coordinates.

        Args:
            newcoords: Per-atom coordinates, indexable by atom index; its
                length is taken as the number of atoms.

        Returns:
            A flat array of length 3*natom. The three components for each of
            i1, i2, i3 are filled in; every other entry is zero.
        """
        natom = len(newcoords)
        gv = numpy.zeros(3*natom)
        derivs = _eval_deriv(*self._positions(newcoords))
        for idx, g in zip((self.i1, self.i2, self.i3), derivs):
            gv[3*idx:3*idx + 3] = g
        return gv


def _element_label(name):
    """Return an atom name as text, decoding UTF-8 bytes when necessary."""
    if isinstance(name, bytes):
        return name.decode("utf-8")
    return name


def get_bond_angles(blist, flist, mol, min_angle=0.2*numpy.pi,
                    max_angle=0.95*numpy.pi):
    """Return a list of bond angles given a list of bonds.

    Every triple i-j-k is considered once (i-j-k and k-j-i count as the same
    angle). The triple must satisfy three conditions:
      - j is adjacent to both i and k, and i != k;
      - at least one end atom belongs to a fragment of ``flist``;
      - the angle lies strictly between ``min_angle`` and ``max_angle``.

    Args:
        blist: Bond list, passed to ``get_adjacency`` together with
            ``mol.natom`` to build per-atom neighbour lists.
        flist: Iterable of fragments, each an iterable of atom indices.
        mol: Molecule providing ``natom``, ``coords`` (indexable by atom) and
            ``names`` (bytes or str labels, indexable by atom).
        min_angle: Exclusive lower bound in radians (default 0.2*pi).
        max_angle: Exclusive upper bound in radians (default 0.95*pi).

    Returns:
        List of Angle objects in discovery order.
    """
    adj = get_adjacency(blist, mol.natom)
    alist = []
    # Angles already considered, keyed by (middle, smaller end, larger end)
    # so i-j-k and k-j-i share a key. The angle value is symmetric in the
    # end atoms, so a rejected angle stays rejected and need not be retried.
    seen = set()

    for frag in flist:
        for idx in frag:
            # Walk adjacent triples idx-jdx-kdx.
            for jdx in adj[idx]:
                for kdx in adj[jdx]:
                    if kdx == idx:
                        continue
                    key = (jdx, min(idx, kdx), max(idx, kdx))
                    if key in seen:
                        continue
                    seen.add(key)

                    ic = numpy.asarray(mol.coords[idx])
                    jc = numpy.asarray(mol.coords[jdx])
                    kc = numpy.asarray(mol.coords[kdx])
                    rad = _eval_angle(ic, jc, kc)
                    # Keep only angles inside the open interval; the upper
                    # bound excludes (near-)linear angles.
                    if min_angle < rad < max_angle:
                        ei = _element_label(mol.names[idx])
                        ej = _element_label(mol.names[jdx])
                        ek = _element_label(mol.names[kdx])
                        alist.append(Angle(idx, jdx, kdx, ei, ej, ek, rad))
    return alist
