import numpy
from ..constants import elements, covalent_radius, angstrom_to_bohr


class Bond(object):
    """Representation of a bond between two atoms.

    Attributes:
        i1 (int): Index of first atom
        i2 (int): Index of second atom
        e1 (str): Element label of first atom
        e2 (str): Element label of second atom
        dist (float): Bond distance when initialized
    """
    def __init__(self, i1, i2, e1, e2, dist):
        self.i1 = i1
        self.i2 = i2
        self.e1 = e1
        self.e2 = e2
        self.dist = dist

    def __repr__(self):
        # Format: "<e1>-<e2>(<i1>, <i2>): <dist>", e.g. "H-O(0, 1): 1.8"
        return "{}({!r}, {!r}): {!r}".format(
            self.bond_type(), self.i1, self.i2, self.dist)

    def bond_type(self):
        """Return the bond type as "<e1>-<e2>"."""
        return self.e1 + "-" + self.e2

    def includes(self, idx):
        """Return True if atom ``idx`` is one of the two bonded atoms."""
        return idx in (self.i1, self.i2)

    def value(self):
        """Return the bond distance stored at construction time."""
        return self.dist

    def evaluate(self, newcoords):
        """Return the bond distance for a new set of coordinates.

        Args:
            newcoords: per-atom positions indexable by atom index; row ``k``
                is the position of atom ``k``.

        Returns:
            float: Euclidean distance between atoms ``i1`` and ``i2``.
        """
        c1 = numpy.asarray(newcoords[self.i1])
        c2 = numpy.asarray(newcoords[self.i2])
        return numpy.linalg.norm(c1 - c2)

    def grad(self, newcoords):
        """Return the gradient of the bond distance at the given coordinates.

        Args:
            newcoords: per-atom positions; row ``k`` is the position of atom
                ``k``. Each row is assumed to have 3 components, matching the
                3*natom layout of the result.

        Returns:
            numpy.ndarray: flat vector of length ``3*len(newcoords)``. Only
            the slots of atoms ``i1`` and ``i2`` are filled; all other
            entries are zero.
        """
        gv = numpy.zeros(3*len(newcoords))
        diff = (numpy.asarray(newcoords[self.i1])
                - numpy.asarray(newcoords[self.i2]))
        # d|r1 - r2|/dr1 is the unit vector pointing from atom 2 to atom 1;
        # the derivative with respect to r2 is its negative. Coincident atoms
        # give a zero norm and therefore a non-finite gradient.
        unit = diff/numpy.linalg.norm(diff)
        o1 = 3*self.i1
        o2 = 3*self.i2
        gv[o1:o1 + 3] = unit
        gv[o2:o2 + 3] = -unit
        return gv


def get_bond_list(mol, const=1.4):
    """Return a list of Bond objects for every atom pair within a cutoff.

    Atoms ``i < j`` are considered bonded when their distance is smaller than
    ``const * (r_i + r_j)``, where ``r`` is the tabulated covalent radius of
    each element. Radii are tabulated in pm and converted to bohr before the
    comparison.

    Args:
        mol: object exposing ``names`` (element labels, ``bytes`` or ``str``)
            and ``coords`` (one position per atom). Only the first
            ``min(len(names), len(coords))`` atoms are considered.
        const (float): scale factor applied to the sum of covalent radii.

    Returns:
        list of Bond: bonds with ``i1 < i2``, ordered by ``(i1, i2)``.

    Raises:
        ValueError: if an element label is not found in ``elements``.
    """
    natom = min(len(mol.names), len(mol.coords))
    if natom < 2:
        return []
    # Decode labels, look up radii and convert coordinates once per atom
    # rather than once per pair (the pair loop is O(natom**2)).
    labels = [n.decode("utf-8") if isinstance(n, bytes) else n
              for n in mol.names[:natom]]
    radii = [covalent_radius[elements.index(lab)] for lab in labels]
    coords = [numpy.asarray(c) for c in mol.coords[:natom]]
    # Covalent radii are in pm: pm -> angstrom (/100) -> bohr.
    to_bohr = angstrom_to_bohr/100.0

    blist = []
    for i in range(natom - 1):
        for j in range(i + 1, natom):
            max_bond = const*(radii[i] + radii[j])*to_bohr
            dist = numpy.linalg.norm(coords[i] - coords[j])
            if dist < max_bond:
                blist.append(Bond(i, j, labels[i], labels[j], dist))
    return blist


def get_adjacency(blist, natom):
    """Build the neighbour set of every atom from a list of bonds.

    Args:
        blist: iterable of bond objects exposing integer atom indices
            ``i1`` and ``i2``.
        natom (int): number of atoms; sets the length of the result.

    Returns:
        list of set: entry ``k`` holds the indices of all atoms bonded to
        atom ``k`` (empty for atoms with no bonds).
    """
    neighbours = [set() for _ in range(natom)]
    for bond in blist:
        first, second = bond.i1, bond.i2
        # Adjacency is symmetric: record the bond from both ends.
        neighbours[first].add(second)
        neighbours[second].add(first)
    return neighbours
