import numpy
from . import constants


def switch(w, sthr):
    """Smooth switching weight between 0 and 1 for a mode of frequency ``w``.

    Evaluates ``1 / (1 + (sthr / w)**4)``: the weight tends to 0 for
    ``w << sthr``, equals 0.5 at ``w == sthr`` and tends to 1 for
    ``w >> sthr``. Works element-wise when ``w`` is a numpy array.
    """
    ratio = sthr / w
    damping = ratio ** 4
    return 1.0 / (1.0 + damping)


def _classify_top(e, linear, rthresh=1e-4):
    """Print the rotor type and return the moments used for the rotational partition function.

    ``e`` must hold the principal moments of inertia in ascending order (as
    returned by ``numpy.linalg.eigh``); moments closer than ``rthresh`` are
    treated as degenerate. For a linear molecule the vanishing smallest
    moment is dropped from the returned array.

    Raises:
        ValueError: if ``linear`` is true but the smallest moment is not
            (numerically) zero.
    """
    if linear:
        print("Molecule is a linear top.")
        if not abs(e[0]) < rthresh:
            raise ValueError(
                "Molecule is flagged linear but its smallest principal "
                "moment is {}".format(e[0]))
        return e[1:]
    same_ab = abs(e[0] - e[1]) < rthresh
    same_bc = abs(e[1] - e[2]) < rthresh
    # Testing both adjacent gaps (rather than also |A - C|) keeps a
    # near-spherical top whose two gaps are each just below the threshold
    # from falling into the symmetric-top branches.
    if same_ab and same_bc:
        print("Molecule is a spherical top")
    elif same_ab:
        print("Molecule is a oblate symmetric top")
    elif same_bc:
        print("Molecule is an prolate symmetric top")
    else:
        print("Molecule is an asymmetric top")
    return e


def _harmonic_vib_terms(omega, beta, kb_ha):
    """Per-mode harmonic-oscillator thermal energy, free energy and entropy.

    ``omega`` (scalar or array) and the returned energies are in Hartree; the
    energies exclude the zero-point contribution. The entropy is in Ha/K.
    """
    ex = numpy.exp(-omega*beta)
    H = omega*ex/(1.0 - ex)
    A = numpy.log(1.0 - ex)/beta
    S = (H - A)*beta*kb_ha
    return H, A, S


def HO_RR_analysis(sys, T=298.15, P=1.01325E5, rotor=None, cutoff=None,
                   symmetry=None):
    """Print an ideal-gas harmonic-oscillator / rigid-rotor thermochemistry report.

    Prints the rotor type, principal moments, normal-mode wavenumbers, the
    zero-point energy and the vibrational, rotational and translational
    enthalpy, free energy and entropy. Nothing is returned.

    Args:
        sys: system object. Reads ``sys.I`` (inertia tensor), ``sys.linear``,
            ``sys.P2`` (its diagonal is used as the squared normal-mode
            frequencies in atomic units; ``sys.get_normal_modes()`` is called
            first when it is None) and ``sys.mol.masses`` (scaled by
            ``constants.amu_to_el``, so presumably amu — verify).
        T: temperature in K.
        P: pressure in Pa (converted with ``constants.au_to_pa``).
        rotor: optional threshold wavenumber (cm^-1). Low-frequency
            vibrational entropies are smoothly blended with free-rotor
            entropies (quasi-RRHO style). Only the entropy is affected.
        cutoff: optional wavenumber (cm^-1). Modes at or below it are given
            the entropy of a harmonic mode at the cutoff. Only the entropy is
            affected.
        symmetry: rotational symmetry number. When None, 1 is assumed and a
            warning is printed.

    Raises:
        ValueError: if both ``rotor`` and ``cutoff`` are given, if
            ``symmetry`` < 1, or if a molecule flagged linear has a non-zero
            smallest principal moment.
    """
    # Validate options before doing any work or printing anything.
    if rotor is not None and cutoff is not None:
        raise ValueError("Only one approximation may be applied to low-frequency modes")
    if symmetry is not None and symmetry < 1:
        raise ValueError("Rotational symmetry number must be >= 1")

    e, _ = numpy.linalg.eigh(sys.I)
    e = _classify_top(e, sys.linear)
    print("Principal moments (amu*Bohr^2):")
    for i, m in enumerate(e):
        print("    {}: {:.6f}".format(i+1, m/constants.amu_to_el))
    if symmetry is None:
        srot = 1
        print("WARNING rotational symmetry number is assumed to be 1 (this may be wrong)")
    else:
        srot = symmetry
        print("Rotational symmetry number: {}".format(srot))

    if sys.P2 is None:
        sys.get_normal_modes()

    omega2 = sys.P2.diagonal()
    # Sign is tracked per mode, so it is correct whatever the mode ordering.
    imaginary = omega2 < 0
    if imaginary.any():
        print("WARNING: Imaginary frequencies detected!")
    # NOTE(review): imaginary modes enter every thermodynamic sum below with
    # their magnitude, as if real; confirm this is intended for saddle points.
    omega = numpy.sqrt(numpy.abs(omega2))
    sign = numpy.where(imaginary, -1, 1)
    print("")
    print("Normal Modes (cm^{-1}):")
    for i, (s, w) in enumerate(zip(sign, omega)):
        print("  {:4d}: {:.2f}".format(i+1, s*w*constants.hartree_to_cm_1))
    print("")

    kb_ha = constants.kb/constants.hartree_to_ev  # Boltzmann constant, Ha/K
    kBT = T*constants.kb / constants.hartree_to_ev
    beta = 1.0 / (kBT + 1e-14)  # small shift avoids division by zero at T = 0
    zpe = omega.sum()/2.0
    Hvibv, Avibv, Svibv = _harmonic_vib_terms(omega, beta, kb_ha)

    if rotor is not None:
        print("Applying the rotor approximation to low-frequency modes...")
        avmom = numpy.trace(sys.I)/3.0
        sthr = rotor/constants.hartree_to_cm_1
        # Moment of inertia of a free rotor with the mode's frequency
        # (hbar = 1), capped by the average molecular moment.
        mu = 0.5/(omega + 1e-15)
        mu = mu*avmom/(mu + avmom)
        wt = switch(omega, sthr)
        srr = 0.5 + numpy.log(numpy.sqrt(2.0*numpy.pi*mu/beta))
        Svibv = wt*Svibv + (1.0 - wt)*srr*kb_ha

    if cutoff is not None:
        print("Applying the cutoff approximation to low-frequency modes...")
        cut = cutoff/constants.hartree_to_cm_1
        _, _, scut = _harmonic_vib_terms(cut, beta, kb_ha)
        Svibv = numpy.where(omega > cut, Svibv, scut)

    Hvib = Hvibv.sum()
    Avib = Avibv.sum()
    Svib = Svibv.sum()

    # Rigid-rotor partition function in atomic units (hbar = 1).
    if sys.linear:
        qr = 2*kBT*e[0]/srot
        Hr = kBT
        Sr = kb_ha*(numpy.log(qr) + 1.0)  # Ha/K
    else:
        qr = numpy.sqrt(numpy.pi)*numpy.sqrt(kBT*kBT*kBT)*numpy.sqrt(8.0*e[0]*e[1]*e[2])/srot
        Hr = 1.5*kBT
        Sr = kb_ha*(numpy.log(qr) + 1.5)  # Ha/K
    Gr = Hr - T*Sr

    # Ideal-gas translational terms (Sackur-Tetrode entropy).
    Mtot = sum(sys.mol.masses)*constants.amu_to_el
    P_au = P/constants.au_to_pa
    twopi3_2 = pow(2.0*numpy.pi, 1.5)
    qt = numpy.sqrt(Mtot*Mtot*Mtot*kBT*kBT*kBT)*kBT/(P_au*twopi3_2)
    St = kb_ha*(numpy.log(qt) + 5.0/2.0)
    # Enthalpy = 3/2 kT translational energy + PV (= kT per molecule), so
    # that Gt reduces to -kT ln(qt), consistent with the +5/2 in St.
    Ht = 5.0*kBT/2.0
    Gt = Ht - T*St

    print("Harmonic zero point energy: {:.4f} kcal/mol".format(zpe*constants.hartree_to_kcal))
    print("")
    print("Thermochemistry at T = {} K, P = {} pa".format(T, P))
    print("   Vib. Enthalpy (kcal/mol):      {: .4f}".format((Hvib + zpe)*constants.hartree_to_kcal))
    print("   Vib. Free Energy (kcal/mol):   {: .4f}".format((Avib + zpe)*constants.hartree_to_kcal))
    if rotor is not None or cutoff is not None:
        print("      Vib. Free Energy (approx):  {: .4f}".format((Hvib + zpe - T*Svib)*constants.hartree_to_kcal))
    print("   Vib. Entropy (cal/mol K):      {: .4f}".format(1000*(Hvib - Avib)*beta*kb_ha*constants.hartree_to_kcal))
    if rotor is not None or cutoff is not None:
        print("      Vib. Entropy (approx):      {: .4f}".format((1000*Svib)*constants.hartree_to_kcal))
    print("   Rot. Enthalpy (kcal/mol):      {: .4f}".format(Hr*constants.hartree_to_kcal))
    print("   Rot. Free Energy (kcal/mol):   {: .4f}".format(Gr*constants.hartree_to_kcal))
    print("   Rot. Entropy (cal/mol K):      {: .4f}".format(1000*Sr*constants.hartree_to_kcal))
    print("   Trans. Enthalpy (kcal/mol):    {: .4f}".format(Ht*constants.hartree_to_kcal))
    print("   Trans. Free Energy (kcal/mol): {: .4f}".format(Gt*constants.hartree_to_kcal))
    print("   Trans. Entropy (cal/mol K):    {: .4f}".format(1000*St*constants.hartree_to_kcal))
