import logging
import numpy
from . import constants
from .transrot import translations
from .transrot import rotations
from .transrot import half_projector
from .transrot import half_projector_against


def _log_frequencies(title, eigenvalues):
    """Log ``title``, then ``sqrt(|ee|) * constants.hartree_to_cm_1`` for each eigenvalue ``ee``."""
    logging.info(title)
    for i, ee in enumerate(eigenvalues, start=1):
        freq = numpy.sqrt(abs(ee)) * constants.hartree_to_cm_1
        logging.info("    {}: {:.6f}".format(i, freq))


def _diag_normal_modes(F2m, linear, zero_thresh):
    """Diagonalize the full mass-weighted Hessian and drop the near-zero modes.

    Raises Exception if fewer than the expected number of zero modes
    (5 if ``linear`` else 6) are found; warns if more are found.
    """
    e, v = numpy.linalg.eigh(F2m)
    is_zero = numpy.abs(e) < zero_thresh
    nfound = int(numpy.count_nonzero(is_zero))
    nzeros = 5 if linear else 6
    if nfound < nzeros:
        logging.error("Only {} of {} expected zero-energy modes were found".format(nfound, nzeros))
        raise Exception("Could not find all zero-energy modes")
    if nfound > nzeros:
        logging.warning("{} low-energy modes are not included in vibrational analysis".format(nfound - nzeros))
    # eigh sorts eigenvalues in ascending order, so negative (imaginary-
    # frequency) modes come *before* the zero modes.  Select by mask rather
    # than slicing off the first `nfound` columns, which would throw away the
    # negative modes and keep some zero modes instead.
    keep = ~is_zero
    return e[keep], v[:, keep]


def _proj_normal_modes(natom, coords, masses, F2m, linear, R, I):
    """Diagonalize the mass-weighted Hessian projected away from translations/rotations.

    Also logs the frequencies of the translational and rotational blocks as a
    sanity check.
    """
    assert(I is not None)
    assert(R is not None)
    d1, d2, d3 = translations(natom, masses=masses, mweight=True)
    dtrans = (d1, d2, d3)

    # Cartesian rotation vectors: 2 for a linear molecule, otherwise 3.
    if linear:
        d4, d5 = rotations(natom, coords, I, R, linear=True, masses=masses, mweight=True)
        drot = (d4, d5)
    else:
        d4, d5, d6 = rotations(natom, coords, I, R, linear=False, masses=masses, mweight=True)
        drot = (d4, d5, d6)

    # Projection matrices built from the translation/rotation vectors.
    Uv = half_projector_against(dtrans + drot)
    Ut = half_projector(dtrans)
    Ur = half_projector(drot)

    # Normal modes in the projected (vibrational) space, mapped back to the
    # full mass-weighted Cartesian space.
    F2mv = numpy.matmul(Uv.transpose(), numpy.matmul(F2m, Uv))
    e, v = numpy.linalg.eigh(F2mv)
    L = numpy.matmul(Uv, v)

    # Translation/rotation Hessian eigenvalues, logged as a check only.
    et = numpy.linalg.eigvalsh(numpy.einsum('pi,qj,pq->ij', Ut, Ut, F2m))
    er = numpy.linalg.eigvalsh(numpy.einsum('pi,qj,pq->ij', Ur, Ur, F2m))
    _log_frequencies("Translational frequencies:", et)
    _log_frequencies("Rotational frequencies:", er)
    return e, L


def get_normal_modes(natom, coords, masses, F2, method,
                     linear=False, R=None, I=None, zero_thresh=3.33e-8):
    """Compute normal modes from a Cartesian second-derivative matrix.

    The matrix is mass-weighted as ``F2m[i, j] = F2[i, j] / sqrt(m_i * m_j)``
    and then diagonalized by one of two methods:

    * ``"diag"``: diagonalize ``F2m`` directly and discard every mode whose
      eigenvalue has magnitude below ``zero_thresh``.
    * ``"proj"``: project translations and rotations out of ``F2m`` before
      diagonalizing. The translational/rotational frequencies are logged
      as a check.

    Args:
        natom: Number of atoms; ``coords`` must have this length.
        coords: Per-atom coordinates. Only used by ``"proj"``, where they are
            forwarded to ``rotations``.
        masses: One mass per Cartesian coordinate (length ``3 * natom``).
        F2: Cartesian second-derivative matrix of shape
            ``(3 * natom, 3 * natom)``.
        method: ``"diag"`` or ``"proj"``.
        linear: Treat the molecule as linear, i.e. expect 5 rather than 6
            zero modes (``"diag"``) or use 2 rotation vectors (``"proj"``).
        R, I: Required for ``"proj"``; forwarded to ``rotations``.
            Presumably the principal axes and moments of inertia; confirm
            against ``transrot.rotations``.
        zero_thresh: ``"diag"`` only. Eigenvalues with absolute value below
            this are treated as zero (translation/rotation) modes.

    Returns:
        ``(w2, L)``: eigenvalues of the mass-weighted Hessian in ascending
        order, and the corresponding mass-weighted mode vectors as the
        columns of ``L`` (one row per Cartesian coordinate).

    Raises:
        Exception: if ``"diag"`` finds too few zero modes, or if ``method``
            is not recognized.
    """
    assert(len(coords) == natom)
    assert(len(masses) == 3*natom)
    Mi2 = 1.0/numpy.sqrt(numpy.asarray(masses))
    F2m = numpy.einsum('ij,i,j->ij', F2, Mi2, Mi2)

    if method == "diag":
        return _diag_normal_modes(F2m, linear, zero_thresh)
    if method == "proj":
        return _proj_normal_modes(natom, coords, masses, F2m, linear, R, I)
    raise Exception("Unrecognized method for normal mode computation: {}".format(method))
