import logging
import numpy
from .optimizer import Optimizer
from .line_search import linesearch


class BFGS(Optimizer):
    """Quasi-Newton optimizer that maintains a BFGS approximation of the Hessian.

    Calls to :meth:`step` alternate between two phases:

    1. Quasi-Newton phase: return ``-pinv(H) @ g`` using the current Hessian
       approximation ``H`` and remember the step, gradient and energy.
    2. Line-search phase (the next call, made at the point reached by that
       step): line-search along the remembered step using the energies and
       directional derivatives at both of its ends, apply a BFGS update to
       ``H``, and return the displacement from the current point to the
       line-search estimate.  Internal state is then reset so the following
       call is again a quasi-Newton phase.
    """

    # Relative tolerance for the curvature condition y.s > tol*|s|*|y|;
    # BFGS updates that fail it are skipped.
    curvature_tol = 1e-10

    def __init__(self, maxstep=None):
        """Create a BFGS optimizer.

        Args:
            maxstep: stored as ``self.maxstep``; not read anywhere in this
                class.  NOTE(review): presumably consumed by the Optimizer base
                class or the driver to limit step length — confirm.
        """
        self.maxstep = maxstep
        # NOTE(review): presumably tells the driver that only first
        # derivatives (gradients) are needed — confirm against Optimizer.
        self.maxderiv = 1
        self._hess = None  # BFGS Hessian approximation, created lazily in step()
        self._step = None  # step returned by the last quasi-Newton phase
        self._gk = None  # gradient at the start point of that step
        self._Ek = None  # energy at the start point of that step

    def step(self, E, g, hess):
        """Return the next displacement to apply to the coordinates.

        Args:
            E: energy at the current point; must not be None.
            g: gradient at the current point, a 1-D array of length n; must
                not be None.
            hess: ignored; the Hessian is approximated internally by BFGS
                updates.

        Returns:
            numpy.ndarray: displacement (length n) to add to the current
            coordinates.
        """
        assert E is not None
        assert g is not None
        n = g.shape[0]
        if self._hess is None:
            # Start from the identity, i.e. a first step of steepest descent.
            self._hess = numpy.eye(n)

        if self._step is None:
            return self._quasi_newton_step(E, g)
        return self._line_search_step(E, g)

    def _quasi_newton_step(self, E, g):
        """Compute -pinv(H) @ g and remember the step, gradient and energy."""
        assert self._gk is None
        assert self._Ek is None
        # pinv rather than inv so a singular approximation does not raise.
        hinv = numpy.linalg.pinv(self._hess, rcond=1e-9)
        step = -numpy.matmul(hinv, g)
        self._step = step
        # Copy so that in-place reuse of the caller's gradient buffer cannot
        # corrupt the gradient difference computed on the next call.
        self._gk = numpy.array(g)
        self._Ek = E
        return step

    def _line_search_step(self, E, g):
        """Line-search along the remembered step, update H, and reset state.

        Assumes the caller applied the remembered step, so the current point
        lies at fraction 1.0 along it and the start point at fraction 0.0.
        """
        prev_step = self._step
        # Directional derivatives along the remembered step at both ends.
        g1 = numpy.dot(self._gk, prev_step)
        g2 = numpy.dot(g, prev_step)
        # NOTE(review): linesearch presumably returns the fraction alpha of
        # prev_step at the estimated minimum, or None on failure — confirm.
        alpha = linesearch(self._Ek, g1, E, g2)
        if alpha is None:
            logging.info("   Line search failed, stepping back")
            # Move from fraction 1.0 back to fraction 0.1 of the step.
            step = -0.9*prev_step
        else:
            # Move from fraction 1.0 to fraction alpha of the step.
            step = (alpha - 1.0)*prev_step

        # Gradients are only known at the two endpoints of prev_step, so the
        # secant pair is (prev_step, g - g_k).  Pairing y with alpha*prev_step
        # would violate the secant condition H s = y whenever alpha != 1.
        self._update_hessian(prev_step, g - self._gk)

        self._step = None
        self._gk = None
        self._Ek = None
        return step

    def _update_hessian(self, sk, yk):
        """Apply H += y y^T/(y.s) - (H s)(H s)^T/(s.H s), if well-defined.

        The update is skipped when the curvature condition y.s > 0 (relative
        to ``curvature_tol``) fails or s.H s is not positive.  Applying it
        then would divide by zero or make the approximation indefinite.
        """
        ys = numpy.dot(yk, sk)
        Hs = numpy.matmul(self._hess, sk)
        sHs = numpy.dot(sk, Hs)
        tol = self.curvature_tol*numpy.linalg.norm(sk)*numpy.linalg.norm(yk)
        if ys <= tol or sHs <= 0.0:
            logging.info("   Skipping BFGS update: curvature condition not met")
            return
        self._hess += numpy.outer(yk, yk)/ys - numpy.outer(Hs, Hs)/sHs
