import numpy
import logging

from omega import constants
from omega import transrot

from .harmonic import get_function
from .matvec import get_matvec


class HarmonicPotential(object):
    """Driver for stochastic harmonic potential.

    Starting from the mass-weighted hessian of a reference system, ``run``
    replaces the part of that hessian spanned by a growing set of sample
    vectors with hessian-vector products evaluated for the target system,
    and logs the requested quantities after every sample.

    Attributes:
        sys: target system object (reads ``mol``, ``I``, ``R``, ``M``,
            ``linear``).
        ref: reference system object (same attributes, plus ``F2``, the
            reference hessian).
        computer: passed through unchanged to ``get_matvec``.
        dvecs, dvecs_ref: results of ``transrot.transrot`` for ``sys`` and
            ``ref`` (presumably the mass-weighted translation/rotation
            vectors -- confirm against ``omega.transrot``).
        P, Pref: results of ``transrot.half_projector_against`` for the two
            sets of vectors; used below as ``P^T A P``.
    """

    # Eigenvalues of the sample overlap matrix below this fraction of the
    # largest eigenvalue are treated as zero (linearly dependent samples).
    _OVERLAP_RTOL = 1.0e-12

    def __init__(self, sys, ref, computer):
        """Store the systems and build their trans/rot projectors."""
        self.sys = sys
        self.ref = ref
        self.computer = computer
        self.dvecs, self.P = self._transrot_projector(sys)
        self.dvecs_ref, self.Pref = self._transrot_projector(ref)

    @staticmethod
    def _transrot_projector(system):
        """Return ``(dvecs, half_projector)`` computed by ``transrot`` for ``system``."""
        dvecs = transrot.transrot(
                system.mol.natom, system.mol.coords, system.I, system.R,
                linear=system.linear, masses=system.M, mweight=True)
        return dvecs, transrot.half_projector_against(dvecs)

    @staticmethod
    def _project(A, P):
        """Return ``P^T A P``."""
        return numpy.matmul(P.transpose(), numpy.matmul(A, P))

    @staticmethod
    def _orthonormalizer(X):
        """Return ``V`` such that the columns of ``X V`` are orthonormal.

        ``V`` is built from the eigen-decomposition of the overlap matrix
        ``S = X^T X`` as ``V[:, j] = sv[:, j] / sqrt(se[j])``.  Directions
        whose overlap eigenvalue is numerically zero (linearly dependent or
        zero sample vectors) are dropped with a warning instead of being
        divided by ~0, so ``V`` may have fewer columns than ``X``.
        """
        S = numpy.matmul(X.transpose(), X)
        se, sv = numpy.linalg.eigh(S)
        tol = max(HarmonicPotential._OVERLAP_RTOL * se.max(), 0.0)
        keep = se > tol
        ndrop = se.size - numpy.count_nonzero(keep)
        if ndrop > 0:
            logging.warning(
                "{} linearly dependent sample direction(s) ignored.".format(ndrop))
        # Broadcasting divides column j by sqrt(se[j]).
        return sv[:, keep] / numpy.sqrt(se[keep])

    @staticmethod
    def _log_estimates(header, names, funcs, eigvals):
        """Sum each function over ``eigvals`` and log one line per quantity."""
        ests = [sum([func(w2) for w2 in eigvals]) for func in funcs]
        logging.info(header)
        for nm, val in zip(names, ests):
            logging.info("   {}: {:13.10f}".format(nm, val))

    def run(self, params, quantities=("ZPE", "Hvib", "Avib", "Svib"),
            T=298.15, nsample=1, method="Rayleigh"):
        """Log estimates of ``quantities`` as the hessian estimate is refined.

        Args:
            params: mapping with optional keys ``"rotor"`` and ``"cutoff"``.
                Both are divided by ``constants.hartree_to_cm_1`` before
                being handed to ``get_function`` (presumably a cm^-1 to
                Hartree conversion -- confirm against ``omega.constants``).
                The mapping itself is never modified.
            quantities: names passed one by one to ``get_function``.
            T: temperature entering ``kBT = T*kb/hartree_to_ev``.
            nsample: number of sample vectors (one hessian-vector product
                each).
            method: ``"modes"`` (case-insensitive) uses the eigenvectors of
                the projected reference hessian as samples, so ``nsample``
                must not exceed their number; any other value is forwarded
                to ``get_random(n, method)``.

        Returns:
            None.  All results are reported through ``logging``.

        Raises:
            Exception: if ``self.ref.F2`` is None.
        """
        from .random import get_random
        # Some constants
        kBT = T*constants.kb / constants.hartree_to_ev
        beta = 1.0 / (kBT + 1e-14)  # small shift avoids division by zero at T=0
        n = len(self.sys.M)
        if self.ref.F2 is None:
            logging.error("Reference hessian not found!")
            raise Exception("Reference hessian not found!")

        # Mass-weight the reference hessian: F2m_ij = F2_ij / sqrt(M_i M_j)
        Mi2 = 1.0/numpy.sqrt(numpy.asarray(self.ref.M))
        F2m = numpy.einsum('ij,i,j->ij', self.ref.F2, Mi2, Mi2)
        e = numpy.linalg.eigh(self._project(F2m, self.P))[0]
        if numpy.any(e < 0):
            logging.warning("Imaginary frequencies detected in reference.")

        # Eigenvectors in the reference projector basis are only needed
        # when they serve as sample vectors.
        use_modes = method.lower() == "modes"
        vref = None
        if use_modes:
            vref = numpy.linalg.eigh(self._project(F2m, self.Pref))[1]

        # Rebind instead of dividing in place so that array-valued entries
        # of ``params`` are not rescaled for the caller.
        rotor = params.get("rotor")
        cutoff = params.get("cutoff")
        if rotor is not None:
            rotor = rotor / constants.hartree_to_cm_1
        if cutoff is not None:
            cutoff = cutoff / constants.hartree_to_cm_1

        quantities = list(quantities)
        funcs = [get_function(s, beta, self.sys.I, rotor, cutoff) for s in quantities]
        self._log_estimates("Initial estimates:", quantities, funcs, e)

        Hx = numpy.zeros((n, nsample))  # column i: matvec(sample i)
        xs = numpy.zeros((n, nsample))  # column i: sample i
        matvec = get_matvec(self.sys.mol, self.computer)

        for i in range(nsample):
            if use_modes:
                x = numpy.matmul(self.Pref, vref[:, i])
            else:
                x = get_random(n, method)

            Hx[:, i] = matvec(x)
            xs[:, i] = x

            # Orthonormal basis Ux = X V of the space spanned by the samples
            V = self._orthonormalizer(xs[:, :(i + 1)])
            Ux = numpy.matmul(xs[:, :(i + 1)], V)

            # Hessian projected onto xs: (H X V)(X V)^T, then symmetrized
            Ha = numpy.matmul(numpy.matmul(Hx[:, :(i + 1)], V), Ux.transpose())
            Ha = 0.5*(Ha + Ha.transpose())

            # Add to reference hessian projected against xs
            P = numpy.eye(n) - numpy.matmul(Ux, Ux.transpose())
            Hest = numpy.matmul(P, numpy.matmul(F2m, P)) + Ha
            e = numpy.linalg.eigh(self._project(Hest, self.P))[0]
            nimag = numpy.count_nonzero(e < 0)
            if nimag > 0:
                # These eigenvalues belong to the updated estimate, not to
                # the reference hessian.
                logging.warning(
                    "{} imaginary frequencies detected in estimated hessian.".format(nimag))
            self._log_estimates(
                "Estimates for iteration {}:".format(i + 1), quantities, funcs, e)
