import logging
import numpy
from .optimizer import Optimizer
from .line_search import linesearch


class SteepestDescent(Optimizer):
    """Steepest-descent optimizer with an optional two-point line search.

    A "descent" call proposes the step ``-g``, scaled down when its size
    (``norm(step) / sqrt(step.shape[0])``) exceeds ``maxstep``.  When
    ``linesearch`` is enabled, descent calls alternate with "correction"
    calls: the call after a descent step combines the energies and the
    directional derivatives at both ends of that step, via ``linesearch``,
    to choose a point along it, and returns the displacement needed to
    reach that point.
    """

    def __init__(self, maxstep=None, linesearch=False):
        """Configure the optimizer.

        Args:
            maxstep: upper bound on the descent-step size, measured as
                ``norm(step) / sqrt(step.shape[0])``; ``None`` disables the
                limit.  It is not applied to line-search correction steps.
            linesearch: if true, every descent step is followed (on the next
                call to ``step``) by a line-search correction step.  This
                parameter shadows the module-level ``linesearch`` function
                only inside ``__init__``.
        """
        # NOTE(review): Optimizer.__init__ is not called here; confirm the
        # base class needs no initialization.
        self.maxstep = maxstep
        # Highest derivative order this optimizer uses (gradient only);
        # presumably consulted by the Optimizer base class -- confirm.
        self.maxderiv = 1
        self.linesearch = linesearch
        # State saved from the last descent step for the line search.  All
        # three are None when the next call should take a fresh descent step.
        self._step = None
        self._gk = None
        self._Ek = None

    def step(self, E, g, hess):
        """Return the next displacement to apply to the current point.

        Args:
            E: energy at the current point.
            g: gradient at the current point, as a numpy array (1-D, or any
                shape whose elementwise product with the step is meaningful).
            hess: unused; accepted for interface compatibility.

        Returns:
            A numpy array with the same shape as ``g``.

        Raises:
            ValueError: if ``E`` or ``g`` is None.
        """
        if E is None:
            raise ValueError("SteepestDescent.step requires an energy E")
        if g is None:
            raise ValueError("SteepestDescent.step requires a gradient g")

        # steepest descent step
        if self._step is None:
            step = -1.0*g
            # Step size: norm divided by sqrt of the leading dimension.  For a
            # 1-D array this is the RMS element; for a 2-D array (Frobenius
            # norm) it is the RMS row norm.
            sd = numpy.linalg.norm(step)/numpy.sqrt(step.shape[0])
            if self.maxstep is not None and sd > self.maxstep:
                scale = self.maxstep/sd
                logging.info("   Stepsize exceeds maximum, scaling by %s", scale)
                step *= scale
            if self.linesearch:
                # Remember the start of this step for the next call.  g is
                # copied so that a caller reusing its gradient buffer in place
                # cannot overwrite the saved start-point gradient.
                self._step = step
                self._gk = numpy.copy(g)
                self._Ek = E
            return step

        # line-search step: (self._Ek, self._gk) were evaluated at the start
        # of self._step and, assuming the caller applied that step in full,
        # (E, g) are the values at its end.
        E1 = self._Ek
        # Directional derivatives along the saved step.  vdot flattens its
        # inputs, so this is the scalar sum(g * step) for any array shape
        # (numpy.dot would be a matrix product for 2-D arrays).
        g1 = numpy.vdot(self._gk, self._step)
        E2 = E
        g2 = numpy.vdot(g, self._step)
        alpha = linesearch(E1, g1, E2, g2)
        if alpha is None:
            logging.info("   Line search failed, stepping back")
            # Retreat 90% of the way back toward the start of the step.
            step = -0.9*self._step
        else:
            logging.debug("line-search step: %.4f", alpha)
            # alpha is measured in units of the saved step from its start;
            # the current point sits at alpha = 1.
            # NOTE(review): maxstep is not enforced on this correction.
            step = (alpha - 1.0)*self._step
        # Reset so the next call takes a fresh descent step.
        self._step = None
        self._gk = None
        self._Ek = None
        return step
