import numpy
import logging
from .compute import compute_hessian
from .compute import compute_gradient
from .optimizers.steepest_descent import SteepestDescent
from .optimizers.newton import Newton
from .optimizers.bfgs import BFGS
from .coord.curve import get_step
from .coord.ric import RIC


def get_convstr(ec, gc, dc):
    """Return a three-character convergence summary such as ``"YNY"``.

    Each of ``ec``, ``gc`` and ``dc`` is evaluated for truthiness and
    mapped, in that order, to ``'Y'`` (truthy) or ``'N'`` (falsy).
    """
    return "".join("Y" if flag else "N" for flag in (ec, gc, dc))


def get_coords(method, newmol):
    """Build the coordinate-system object selected by ``method``.

    The method name is matched case-insensitively.  ``"redundant"`` and
    ``"ric"`` both return ``RIC(newmol)``; any other name raises
    ``Exception`` with the offending name in the message.
    """
    name = method.lower()
    if name != "redundant" and name != "ric":
        raise Exception("Unrecognized method: {}".format(method))
    return RIC(newmol)


def _opt(mol, computer, options, method):
    if options is None:
        options = {}
    max_iter = options["max_iter"] if "max_iter" in options else 500
    max_step = options["max_step"] if "max_step" in options else 0.4
    conv_grad = options["conv_grad"] if "conv_grad" in options else 1e-3
    conv_energy = options["conv_energy"] if "conv_energy" in options else 1e-4
    conv_disp = options["conv_disp"] if "conv_disp" in options else 1e-3
    hessian = options["hessian"] if "hessian" in options else None
    delta = options["delta"] if "delta" in options else None
    line_search = options["linesearch"] if "linesearch" in options else False

    newmol = mol.copy()
    if hessian is not None:
        if hessian.lower() == "full":
            theopt = Newton(maxstep=max_step)
        elif hessian.lower() == "bfgs":
            theopt = BFGS(maxstep=max_step)
        else:
            raise Exception("Unrecognized Hessian option")
    else:
        theopt = SteepestDescent(maxstep=max_step, linesearch=line_search)

    if theopt.maxderiv > computer.max_deriv():
        logging.warning("Finite differences will be used")
        if delta is None:
            delta = 0.001

    logging.info("Using " + method + " coordinates for optimization...")
    logging.info("Optimization parameters:")
    logging.info("    max. iter:    {}".format(max_iter))
    logging.info("    max. step:    {}".format(max_step))
    logging.info("    conv. energy: {:.1E}".format(conv_energy))
    logging.info("    conv. grad:   {:.1E}".format(conv_grad))
    logging.info("    conv. disp:   {:.1E}".format(conv_disp))
    if delta is not None:
        logging.info("    delta:        {:.1E}".format(delta))
    logging.info("Note: intensive estimates of dE, |g|, |d| are used for convergence")
    logging.info("--------------------------------------------------------------------------")
    logging.info(" iter      energy        grad.       disp.      |dE|     |g|     |d|   EGD")
    logging.info("--------------------------------------------------------------------------")

    Eold = None
    for i in range(max_iter + 1):
        # compute derivatives
        if theopt.maxderiv == 2:
            E, grad, hess = compute_hessian(newmol, computer, delta)
        elif theopt.maxderiv == 1:
            E, grad = compute_gradient(newmol, computer, delta)
            hess = None
        else:
            E = computer.energy(newmol)
            grad = hess = None

        # get step
        if method.lower() == "cartesian":
            step = theopt.step(E, grad, hess)
        else:
            csys = get_coords(method, newmol)
            B = csys.bmatrix_gen(newmol.coords)
            Binv = numpy.linalg.pinv(B.transpose((1, 0)), rcond=1e-6)

            # transform gradient
            cgrad = numpy.matmul(Binv, grad) if grad is not None else None

            if hess is not None:
                raise NotImplementedError
            # get step
            cstep = theopt.step(E, cgrad, None)

            # backtransform step to cartesians
            step = get_step(newmol, csys, cstep)

        # evalute convergence
        gg = numpy.linalg.norm(grad)
        eg = gg / numpy.sqrt(grad.shape[0])
        ee = abs(E - Eold)/grad.shape[0] if Eold is not None else 0.0
        dd = numpy.linalg.norm(step)
        ed = dd / numpy.sqrt(step.shape[0])

        econv = ee < conv_energy
        gconv = eg < conv_grad
        dconv = ed < conv_disp
        constr = get_convstr(econv, gconv, dconv)

        # update molecule and print iteration
        newmol.update(step, unit="Bohr")
        logging.info("{:4d}:  {:.8f}   {:.3E}   {:.3E}   {:.1E} {:.1E} {:.1E} {}".format(i, E, gg, dd, ee, eg, ed, constr))
        if econv and gconv and dconv:
            logging.info("Geometry optimization converged")
            break
        Eold = E
        if i == max_iter:
            logging.warning("Optimization did not converge!")
    return newmol


def optimize(mol, computer, method="Cartesian", options=None):
    """Public entry point for geometry optimization.

    Forwards ``mol``, ``computer``, ``options`` and ``method`` to ``_opt``
    and returns its result.  ``method`` defaults to ``"Cartesian"`` and
    ``options`` to ``None``.
    """
    optimized = _opt(mol, computer, options=options, method=method)
    return optimized
