import numpy
from .finite_difference import fd_d1_0, fd_d2_1, fd_d3_2, fd_d4_2


def compute_gradient(obj, coords, computer, delta):
    """Return ``(energy, gradient)`` for ``obj``.

    If ``computer.max_deriv()`` reports analytic first derivatives (> 0),
    ``computer.gradient`` supplies both values directly.  Otherwise the
    energy is evaluated and the gradient is obtained from ``fd_d1_0``,
    which is given ``coords``, the reference energy and step ``delta``.
    """
    if computer.max_deriv() > 0:
        energy, grad = computer.gradient(obj)
        return energy, grad

    # No analytic gradient available: finite-difference the energy.
    energy = computer.energy(obj)
    # fd_d1_0 returns a pair; only its first element (the gradient) is used.
    grad, _unused = fd_d1_0(obj, coords, computer, energy, diag2=False, delta=delta)
    return energy, grad


def compute_hessian(obj, coords, computer, delta):
    """Return ``(energy, gradient, hessian)`` for ``obj``.

    Analytic second derivatives are used when ``computer.max_deriv()`` > 1.
    With exactly first derivatives available, the Hessian is built by
    ``fd_d2_1`` from the analytic gradient using step ``delta``.

    Raises:
        Exception: if the computer offers no analytic gradient.
    """
    order = computer.max_deriv()
    if order > 1:
        energy, grad, hess = computer.hessian(obj)
        return energy, grad, hess
    if order != 1:
        raise Exception("Gradient required for FD Hessian")

    energy, grad = computer.gradient(obj)
    # fd_d2_1 returns a pair; only its first element (the Hessian) is used.
    hess, _unused = fd_d2_1(obj, coords, computer, grad, delta=delta)
    return energy, grad, hess


def compute_d3(obj, coords, computer, delta):
    """Return ``(energy, gradient, hessian, third)`` for ``obj``.

    Requires ``computer.max_deriv()`` to be exactly 2: the analytic Hessian
    is finite-differenced by ``fd_d3_2`` (step ``delta``) to obtain the
    third derivatives.

    Raises:
        Exception: if analytic derivatives above 2nd order are reported
            (not handled here), or if no analytic Hessian is available.
    """
    order = computer.max_deriv()
    if order > 2:
        raise Exception("Analytic 3rd derivatives?")
    if order == 1:
        raise Exception("FD 3rd derivatives requires Hessian")
    if order != 2:
        raise Exception("FD 3rd derivatives requires gradient")

    energy, grad, hess = computer.hessian(obj)
    # fd_d3_2 returns a pair; only its first element is used here.
    third, _unused = fd_d3_2(obj, coords, computer, hess, delta=delta)
    return energy, grad, hess, third


def compute_d4(obj, coords, computer, delta):
    """Return ``(E, F1, F2, F3, F4)`` for ``obj``.

    Requires ``computer.max_deriv()`` to be exactly 2.  The analytic Hessian
    ``F2`` is finite-differenced by ``fd_d3_2`` to get ``F3`` and a set of
    per-index 4th-derivative slices; the remaining 4th-derivative terms come
    from ``fd_d4_2`` and are added on top.  ``F4`` has shape
    ``(N, N, N, N)`` with ``N = 3 * obj.natom``.

    Raises:
        Exception: if analytic derivatives above 2nd order are reported
            (not handled here), or if no analytic Hessian is available.
    """
    md = computer.max_deriv()
    if md > 2:
        raise Exception("Analytic 3rd/4th derivatives?")
    elif md == 1:
        raise Exception("FD 3rd/4th derivatives requires Hessian")
    elif md != 2:
        raise Exception("FD 3rd/4th derivatives requires gradient")

    # Only needed on the success path, so the capability errors above do not
    # depend on ``obj`` exposing ``natom``.
    # NOTE(review): 3 per atom presumably means x/y/z coordinates — confirm.
    N = 3 * obj.natom
    E, F1, F2 = computer.hessian(obj)
    F3, F4diag = fd_d3_2(obj, coords, computer, F2, delta=delta)

    F4 = numpy.zeros((N, N, N, N))
    # Fill the F4[i, i] slices from fd_d3_2's second output; presumably these
    # are the i == j blocks that fd_d4_2 does not produce — confirm.
    for i in range(N):
        F4[i, i] = F4diag[i]
    F4 += fd_d4_2(obj, coords, computer, delta=delta)
    return E, F1, F2, F3, F4


def compute_forces(mol, computer, order, delta=0.0012):
    """Compute the energy of ``mol`` and its derivatives up to ``order``.

    Dispatches to ``compute_gradient`` / ``compute_hessian`` / ``compute_d3``
    / ``compute_d4`` using ``mol.coords`` and finite-difference step
    ``delta``.  For ``order == 0`` only ``computer.energy`` is called.

    Returns:
        Tuple ``(E, F1, F2, F3, F4)``; derivative entries above ``order``
        are ``None``.

    Raises:
        Exception: if ``order`` is not one of 0, 1, 2, 3, 4 (or propagated
            from the selected helper when the computer lacks the needed
            analytic derivatives).
    """
    F1 = F2 = F3 = F4 = None

    if order == 0:
        E = computer.energy(mol)
    elif order == 1:
        E, F1 = compute_gradient(mol, mol.coords, computer, delta)
    elif order == 2:
        E, F1, F2 = compute_hessian(mol, mol.coords, computer, delta)
    elif order == 3:
        E, F1, F2, F3 = compute_d3(mol, mol.coords, computer, delta)
    elif order == 4:
        E, F1, F2, F3, F4 = compute_d4(mol, mol.coords, computer, delta)
    else:
        raise Exception("Unsupported force constant order: {}".format(order))
    return E, F1, F2, F3, F4
