import numpy


# form (optionally mass-weighted) vectors of translations
def translations(natom, masses=None, mweight=False):
    """Build the three rigid-translation vectors along x, y and z.

    Each vector has length ``3*natom`` and holds 1.0 in the slot of the
    corresponding Cartesian component of every atom, 0.0 elsewhere.

    Parameters
    ----------
    natom : int
        Number of atoms.
    masses : sequence, optional
        Required when ``mweight`` is true. The mass of atom ``i`` is read
        from ``masses[3*i]``, so this appears to be a per-coordinate array
        of length ``3*natom`` -- TODO confirm against callers.
    mweight : bool
        If true, scale the three entries of atom ``i`` by ``sqrt(mass_i)``.

    Returns
    -------
    tuple of three numpy arrays ``(d1, d2, d3)``.

    Raises
    ------
    Exception
        If ``mweight`` is requested without ``masses``.
    """
    if mweight and masses is None:
        raise Exception("Masses must be provided for mass weighting")

    ncoord = 3*natom
    vecs = []
    for axis in range(3):
        vec = numpy.zeros(ncoord)
        # every third entry, starting at this axis, belongs to that axis
        vec[axis::3] = 1.0
        vecs.append(vec)

    if mweight:
        for atom in range(natom):
            scale = numpy.sqrt(masses[3*atom])
            for vec in vecs:
                vec[3*atom:3*atom + 3] *= scale

    return tuple(vecs)


# form (optionally mass-weighted) vectors of rotations
def rotations(natom, coords, I, R, linear=False, masses=None, mweight=False):
    """Build rigid-rotation vectors about the eigenvectors of ``I``.

    For every atom the displacement is ``cross(coords[i] - R, axis)``, where
    ``axis`` is a column of the eigenvector matrix from
    ``numpy.linalg.eigh(I)``. All three columns are used by default; with
    ``linear=True`` only columns 1 and 2 are used (column 0, belonging to
    the lowest eigenvalue, is skipped).

    Parameters
    ----------
    natom : int
        Number of atoms.
    coords : indexable
        ``coords[i]`` is the position of atom ``i``.
    I : (3, 3) array
        Symmetric matrix to diagonalise -- presumably the inertia tensor.
    R : array
        Reference point subtracted from every position -- presumably the
        center of mass; verify against callers.
    linear : bool
        If true, return two vectors instead of three.
    masses : sequence, optional
        Required when ``mweight`` is true; the mass of atom ``i`` is read
        from ``masses[3*i]``.
    mweight : bool
        If true, scale the three entries of atom ``i`` by ``sqrt(mass_i)``.

    Returns
    -------
    ``(d4, d5)`` if ``linear`` else ``(d4, d5, d6)``; each has length
    ``3*natom``.

    Raises
    ------
    Exception
        If ``mweight`` is requested without ``masses``.
    """
    if mweight and masses is None:
        raise Exception("Masses must be provided for mass weighting")

    first = 1 if linear else 0
    columns = range(first, 3)
    vecs = [numpy.zeros(3*natom) for _ in columns]
    _, axes = numpy.linalg.eigh(I)

    for atom in range(natom):
        rel = coords[atom] - R
        block = slice(3*atom, 3*(atom + 1))
        for vec, k in zip(vecs, columns):
            vec[block] = numpy.cross(rel, axes[:, k])

    if mweight:
        for atom in range(natom):
            scale = numpy.sqrt(masses[3*atom])
            for vec in vecs:
                vec[3*atom:3*atom + 3] *= scale

    return tuple(vecs)


def transrot(natom, coords, I, R, linear=False, masses=None, mweight=False):
    """Return the translation vectors followed by the rotation vectors.

    Thin wrapper around ``translations`` and ``rotations``; all arguments
    are forwarded unchanged.

    The ``mweight`` flag is applied to both sets of vectors. Previously the
    rotations were always requested with ``mweight=True``, so a call with
    the default ``mweight=False`` and no masses raised, and a call with
    masses returned unweighted translations next to weighted rotations.

    Parameters
    ----------
    natom : int
        Number of atoms.
    coords, I, R
        Passed to ``rotations`` (atom positions, matrix to diagonalise,
        reference point).
    linear : bool
        If true, only two rotation vectors are produced.
    masses : sequence, optional
        Required when ``mweight`` is true.
    mweight : bool
        If true, mass-weight both translations and rotations.

    Returns
    -------
    tuple
        ``(d1, d2, d3, d4, d5)`` if ``linear`` else
        ``(d1, d2, d3, d4, d5, d6)``.

    Raises
    ------
    Exception
        If ``mweight`` is requested without ``masses``.
    """
    trans = translations(natom, masses=masses, mweight=mweight)
    rots = rotations(natom, coords, I, R, linear=bool(linear),
                     masses=masses, mweight=mweight)
    return tuple(trans) + tuple(rots)


# form matrix whose columns are the normalized translation/rotation vectors
def half_projector(vecs):
    """Stack the normalized input vectors as columns of a matrix.

    Parameters
    ----------
    vecs : sequence of 1-D arrays
        Must be non-empty; all vectors are expected to share the length of
        ``vecs[0]``.

    Returns
    -------
    numpy array of shape ``(len(vecs[0]), len(vecs))`` whose ``i``-th
    column is ``vecs[i] / norm(vecs[i])``.
    """
    ncol = len(vecs)
    nrow = len(vecs[0])
    Uv = numpy.zeros((nrow, ncol))
    # rows of the transposed view are the columns of Uv, so writes go through
    for column, vec in zip(Uv.T, vecs):
        column[:] = vec/numpy.linalg.norm(vec)
    return Uv


# form projector onto translations/rotations
def projector(vecs):
    """Sum the outer products of the normalized input vectors.

    The result is a true projector onto the span of ``vecs`` only when the
    vectors are mutually orthogonal; no orthogonalization is done here.

    Parameters
    ----------
    vecs : sequence of 1-D arrays
        Must be non-empty; ``len(vecs[0])`` fixes the matrix dimension.

    Returns
    -------
    numpy array of shape ``(dim, dim)``.
    """
    size = len(vecs[0])
    Pv = numpy.zeros((size, size))
    for vec in vecs:
        unit = vec/numpy.linalg.norm(vec)
        Pv += numpy.outer(unit, unit)
    return Pv


# form (pseudo-unitary) projector^{1/2} against translations/rotations
def half_projector_against(vecs, method='svd', full=False):
    """Return columns spanning the space left over after removing ``vecs``.

    The normalized vectors are combined into ``P = M M^T`` and ``P`` is
    decomposed by SVD. The left singular vectors from index ``len(vecs)``
    onward are returned; these span the orthogonal complement of ``vecs``
    provided the input vectors are linearly independent.

    Parameters
    ----------
    vecs : sequence of 1-D arrays
        Must be non-empty; vectors to project against.
    method : str
        Only ``'svd'`` (case-insensitive) is recognised.
    full : bool
        If false, return the ``(dim, dim - nv)`` block of retained columns.
        If true, return a ``(dim, dim)`` matrix whose first ``nv`` columns
        are zero and whose remaining columns are the retained ones.

    Raises
    ------
    Exception
        If ``method`` is not ``'svd'``.
    """
    nvec, ncoord = len(vecs), len(vecs[0])

    if not method.lower() == "svd":
        raise Exception("Unrecognized projection method: {}".format(method))

    basis = numpy.zeros((ncoord, nvec))
    for column, vec in zip(basis.T, vecs):
        column[:] = vec/numpy.linalg.norm(vec)
    onto = numpy.einsum('pi,qi->pq', basis, basis)

    # singular values come back in descending order, so the trailing
    # columns belong to the (near-)zero singular values of `onto`
    left = numpy.linalg.svd(onto)[0]
    if not full:
        return left[:, nvec:]

    padded = numpy.zeros(left.shape)
    padded[:, nvec:] = left[:, nvec:]
    return padded


def projector_against(vecs):
    """Return ``1 - P``, where ``P`` is the matrix built by ``projector(vecs)``.

    Parameters
    ----------
    vecs : sequence of 1-D arrays
        Vectors to project against; forwarded to ``projector``.

    Returns
    -------
    numpy array of shape ``(dim, dim)``.
    """
    onto = projector(vecs)
    identity = numpy.eye(onto.shape[0])
    return identity - onto
