from copy import deepcopy
import numpy
import logging
from . import constants


def get_fd1(coords, delta):
    """Build the singly displaced geometries used by central differences.

    Every Cartesian component ``coords[atom][axis]`` (axis in 0..2) is shifted
    by ``+delta`` and ``-delta`` in its own deep copy of ``coords``. The
    caller's ``coords`` is never modified.

    Args:
        coords: Per-atom coordinates, indexable as ``coords[atom][axis]``.
        delta: Step size added to / subtracted from one component.

    Returns:
        Tuple ``(forward, backward)`` of lists of length ``3*len(coords)``,
        ordered atom-major (index ``3*atom + axis``).
    """
    def displaced(atom, axis, step):
        moved = deepcopy(coords)
        moved[atom][axis] += step
        return moved

    slots = [(atom, axis) for atom in range(len(coords)) for axis in range(3)]
    forward = [displaced(atom, axis, delta) for atom, axis in slots]
    backward = [displaced(atom, axis, -delta) for atom, axis in slots]
    return forward, backward


def get_fd2(coords, delta):
    """Build the doubly displaced geometries for mixed second differences.

    For every ordered pair of Cartesian components ``(p, q)``, each given as
    ``(atom, axis)``, four deep copies of ``coords`` are produced with
    component p shifted by ``+/-delta`` and then component q shifted by
    ``+/-delta``. When ``p == q`` the two shifts accumulate on the same
    component (e.g. ``+2*delta`` for the ``(+, +)`` copy).

    Args:
        coords: Per-atom coordinates, indexable as ``coords[atom][axis]``.
        delta: Step size.

    Returns:
        Tuple ``(ffc, fbc, bfc, bbc)`` of lists of length ``(3*len(coords))**2``.
        Entry ``p*N + q`` (flattened index ``3*atom + axis``) holds the
        ``(+,+)``, ``(+,-)``, ``(-,+)`` and ``(-,-)`` displacement respectively.
    """
    def displaced(first, second, step_first, step_second):
        moved = deepcopy(coords)
        moved[first[0]][first[1]] += step_first
        moved[second[0]][second[1]] += step_second
        return moved

    slots = [(atom, axis) for atom in range(len(coords)) for axis in range(3)]
    pairs = [(p, q) for p in slots for q in slots]
    ffc = [displaced(p, q, delta, delta) for p, q in pairs]
    fbc = [displaced(p, q, delta, -delta) for p, q in pairs]
    bfc = [displaced(p, q, -delta, delta) for p, q in pairs]
    bbc = [displaced(p, q, -delta, -delta) for p, q in pairs]
    return ffc, fbc, bfc, bbc


def fd_d1_0(obj, coords, computer, E0, diag2=True, delta=0.0012):
    """Gradient (and optionally the Hessian diagonal) from energy differences.

    For each Cartesian component k, the energy is evaluated at ``coords`` with
    component k shifted by ``+delta`` and ``-delta``. Then:

        F1[k]     = (E+ - E-) / (2*delta)
        F2diag[k] = (E+ + E- - 2*E0) / delta**2      (only if diag2)

    Args:
        obj: Object with ``set(coords, unit=...)``. It is deep-copied, so the
            caller's object is not modified.
        coords: Per-atom coordinates, passed to ``set`` with ``unit="Bohr"``.
        computer: Object whose ``energy(obj)`` returns a scalar energy.
        E0: Energy at the undisplaced geometry (used only for F2diag).
        diag2: Whether to also compute the diagonal second derivatives.
        delta: Displacement step.

    Returns:
        ``(F1, F2diag)``: numpy arrays of length ``3*len(coords)``.
        ``F2diag`` is None when ``diag2`` is False.
    """
    probe = deepcopy(obj)
    plus, minus = get_fd1(coords, delta)
    ndisp = len(plus)
    ndof = 3*len(coords)
    assert(ndisp == len(minus))
    assert(ndisp == ndof)
    grad = numpy.zeros(ndof)
    curv = numpy.zeros(ndof) if diag2 else None
    for k, (cp, cm) in enumerate(zip(plus, minus)):
        probe.set(cp, unit="Bohr")
        e_plus = computer.energy(probe)
        probe.set(cm, unit="Bohr")
        e_minus = computer.energy(probe)
        grad[k] = (e_plus - e_minus)/(2*delta)
        if diag2:
            curv[k] = (e_plus + e_minus - 2*E0)/(delta*delta)
    return grad, curv


def fd_d2_1(obj, coords, computer, F1, diag3=True, delta=0.0012):
    """Hessian rows (and optional third-derivative rows) from gradient differences.

    For each Cartesian component i, the gradient is evaluated with component i
    shifted by ``+delta`` and ``-delta``. Then:

        F2[i]     = (g+ - g-) / (2*delta)
        F3diag[i] = (g+ + g- - 2*F1) / delta**2      (only if diag3)

    F2 is not symmetrized here.

    Args:
        obj: Object with ``set(coords, unit=...)``. It is deep-copied, so the
            caller's object is not modified.
        coords: Per-atom coordinates, passed to ``set`` with ``unit="Bohr"``.
        computer: Object whose ``gradient(obj)`` returns a sequence whose
            element 1 is the gradient (length ``3*len(coords)`` assumed).
        F1: Gradient at the undisplaced geometry (used only for F3diag).
        diag3: Whether to also compute F3diag.
        delta: Displacement step.

    Returns:
        ``(F2, F3diag)``: numpy arrays of shape ``(N, N)`` with
        ``N = 3*len(coords)``. ``F3diag`` is None when ``diag3`` is False.
    """
    probe = deepcopy(obj)
    plus, minus = get_fd1(coords, delta)
    ndisp = len(plus)
    ndof = 3*len(coords)
    assert(ndisp == len(minus))
    assert(ndisp == ndof)
    hess = numpy.zeros((ndof, ndof))
    d3 = numpy.zeros((ndof, ndof)) if diag3 else None
    for row, (cp, cm) in enumerate(zip(plus, minus)):
        probe.set(cp, unit="Bohr")
        g_plus = computer.gradient(probe)[1]
        probe.set(cm, unit="Bohr")
        g_minus = computer.gradient(probe)[1]
        hess[row] = (g_plus - g_minus)/(2*delta)
        if diag3:
            d3[row] = (g_plus + g_minus - 2*F1)/(delta*delta)
        logging.info("Done FD {} of {}".format(row + 1, ndisp))
    return hess, d3


def fd_d3_2(obj, coords, computer, F2, diag4=True, delta=0.0012):
    """Third derivatives (and optional fourth-derivative blocks) from Hessian differences.

    For each Cartesian component i, the Hessian is evaluated with component i
    shifted by ``+delta`` and ``-delta``. Then:

        F3[i]     = (H+ - H-) / (2*delta)
        F4diag[i] = (H+ + H- - 2*F2) / delta**2      (only if diag4)

    Args:
        obj: Object with ``set(coords, unit=...)``. It is deep-copied, so the
            caller's object is not modified.
        coords: Per-atom coordinates, passed to ``set`` with ``unit="Bohr"``.
        computer: Object whose ``hessian(obj)`` returns a sequence whose
            element 2 is the Hessian (shape ``(N, N)`` assumed).
        F2: Hessian at the undisplaced geometry (used only for F4diag).
        diag4: Whether to also compute F4diag.
        delta: Displacement step.

    Returns:
        ``(F3, F4diag)``: numpy arrays of shape ``(N, N, N)`` with
        ``N = 3*len(coords)``. ``F4diag`` is None when ``diag4`` is False.
    """
    temp = deepcopy(obj)
    fc, bc = get_fd1(coords, delta)
    nf = len(fc)
    nb = len(bc)
    N = 3*len(coords)
    assert(nf == nb)
    assert(nf == N)
    F3 = numpy.zeros((N, N, N))
    F4diag = numpy.zeros((N, N, N)) if diag4 else None
    for i in range(nf):
        temp.set(fc[i], unit="Bohr")
        hessf = computer.hessian(temp)[2]
        temp.set(bc[i], unit="Bohr")
        hessb = computer.hessian(temp)[2]
        F3[i] = (hessf - hessb)/(2*delta)
        if diag4:
            F4diag[i] = (hessf + hessb - 2*F2)/(delta*delta)
        # One progress message per displacement, matching fd_d2_1 / fd_d4_2.
        logging.info("Done FD {} of {}".format(i+1, nf))
    return F3, F4diag


def fd_d4_2(obj, coords, computer, delta=0.0012):
    """Off-diagonal fourth-derivative blocks from mixed differences of Hessians.

    Cartesian components are flattened as ``p = 3*atom + axis``. For every pair
    ``i < j``, the Hessian is evaluated at the four geometries displaced by
    ``(+delta, +delta)``, ``(+delta, -delta)``, ``(-delta, +delta)`` and
    ``(-delta, -delta)`` along ``(i, j)``. Then:

        F4[i, j] = (H_ff + H_bb - H_fb - H_bf) / (4 * delta**2)

    This value is also stored as ``F4[j, i]``.

    Displaced geometries are built one at a time. Materialising every pair up
    front (as ``get_fd2`` does) would allocate 4*N**2 full coordinate copies,
    of which only 4*N*(N-1)/2 are used.

    Args:
        obj: Object with ``set(coords, unit=...)``. It is deep-copied, so the
            caller's object is not modified.
        coords: Per-atom coordinates, indexable as ``coords[atom][axis]`` and
            passed to ``set`` with ``unit="Bohr"``.
        computer: Object whose ``hessian(obj)`` returns a sequence whose
            element 2 is the Hessian (shape ``(N, N)`` assumed).
        delta: Displacement step.

    Returns:
        numpy array of shape ``(N, N, N, N)`` with ``N = 3*len(coords)``.
        The diagonal blocks ``F4[i, i]`` are NOT computed and remain zero.
        NOTE(review): presumably the caller fills them from the ``F4diag``
        returned by ``fd_d3_2`` — confirm.
    """
    N = 3*len(coords)
    F4 = numpy.zeros((N, N, N, N))
    temp = deepcopy(obj)

    def displaced(i, step_i, j, step_j):
        # Same arithmetic and order as get_fd2: shift component i, then j.
        moved = deepcopy(coords)
        moved[i // 3][i % 3] += step_i
        moved[j // 3][j % 3] += step_j
        return moved

    def hess_at(geom):
        temp.set(geom, unit="Bohr")
        return computer.hessian(temp)[2]

    tot = N*(N-1)//2
    count = 0
    for i in range(N):
        for j in range(i+1, N):
            hessff = hess_at(displaced(i, delta, j, delta))
            hessfb = hess_at(displaced(i, delta, j, -delta))
            hessbf = hess_at(displaced(i, -delta, j, delta))
            hessbb = hess_at(displaced(i, -delta, j, -delta))
            F4[i, j] = (hessff + hessbb - hessfb - hessbf)/(4.0*delta*delta)
            F4[j, i] = F4[i, j]
            count += 1
            logging.info("Done FD {} of {}".format(count, tot))
    return F4


def fd_hess_on_vec(mol, computer, vec, mweight=False, method='sym', gc=None, delta=0.0012):
    """Approximate the Hessian-vector product H·vec by finite differences of gradients.

    The geometry is displaced by ``+/- delta*vec`` and the gradients are
    differenced:

        'sym'      : (g(+) - g(-)) / (2*delta)
        'forward'  : (g(+) - g0)   / delta
        'backward' : (g0 - g(-))   / delta

    ``g0`` is the gradient at ``mol``'s own geometry.

    Args:
        mol: Object with ``coords`` (indexable as ``coords[atom][axis]``),
            ``natom``, ``set(coords, unit=...)`` and, only when ``mweight`` is
            True, ``masses``. It is deep-copied, so ``mol`` itself is not
            modified.
        computer: Object whose ``gradient(obj)`` returns ``(energy, gradient)``.
        vec: Direction with ``3*len(mol.coords)`` components.
        mweight: If True, ``vec`` is taken in mass-weighted coordinates. The
            displacement is ``M**-1/2 * vec``, and the result (and every
            gradient used) is scaled by ``M**-1/2``. Masses are multiplied by
            ``constants.amu_to_el`` (presumably amu -> electron-mass units;
            confirm).
        method: 'sym', 'forward' or 'backward' (case-insensitive).
        gc: Optional precomputed gradient at ``mol``'s geometry. It is used by
            'forward'/'backward' and computed if None. It is scaled by
            ``M**-1/2`` when ``mweight`` is True, so pass it unweighted.
        delta: Displacement step. Coordinates are set with ``unit="Bohr"``.

    Returns:
        The approximate H·vec (type follows the gradient arrays returned by
        ``computer``).

    Raises:
        ValueError: if ``method`` is not one of the supported schemes. This is
            a subclass of ``Exception``, so existing handlers still catch it.
    """
    scheme = method.lower()
    if scheme not in ('sym', 'forward', 'backward'):
        # Fail before doing any (potentially expensive) gradient work.
        raise ValueError("Unrecognized FD method: {}".format(method))

    coords = mol.coords
    temp = deepcopy(mol)
    cp = deepcopy(coords)
    cm = deepcopy(coords)

    # Masses are only needed for mass weighting; don't require them otherwise.
    Mi2 = None
    if mweight:
        Marray = [mol.masses[i//3]*constants.amu_to_el for i in range(3*mol.natom)]
        Mi2 = 1.0/numpy.sqrt(numpy.asarray(Marray))
        vec = Mi2*vec

    count = 0
    for i in range(len(coords)):
        for a in range(3):
            cp[i][a] += delta*vec[count]
            cm[i][a] -= delta*vec[count]
            count += 1

    gp = gm = None
    if scheme in ('sym', 'forward'):
        temp.set(cp, unit="Bohr")
        _, gp = computer.gradient(temp)
    if scheme in ('sym', 'backward'):
        temp.set(cm, unit="Bohr")
        _, gm = computer.gradient(temp)
    # One-sided schemes need the reference gradient.
    if scheme != 'sym' and gc is None:
        _, gc = computer.gradient(mol)

    if mweight:
        if gp is not None:
            gp = Mi2*gp
        if gm is not None:
            gm = Mi2*gm
        if gc is not None:
            gc = Mi2*gc

    if scheme == 'sym':
        return (gp - gm)/(2.0*delta)
    if scheme == 'forward':
        return (gp - gc)/delta
    return (gc - gm)/delta
