import logging
import numpy
from .random import get_random
from .random import get_importance
from .lanczos import Lanczos


def stochastic_trace(nout, n, Xv, nsample, method, names=None):
    """Per-sample stochastic (Hutchinson-type) trace estimates for ``nout`` operators.

    Every sample draws one probe ``z = get_random(n, method=method)`` and
    stores ``z . y_j`` for each vector ``y_j`` in ``Xv(z)``.  The mean over
    samples estimates the trace when the probe distribution has identity
    covariance (e.g. Rademacher) -- that property depends on ``get_random``.

    Args:
        nout: number of operators evaluated by ``Xv``.
        n: length of each probe vector.
        Xv: callable taking a length-``n`` vector and returning a sequence of
            ``nout`` vectors (one matrix-vector product per operator).
        nsample: number of probes.
        method: probe distribution name, forwarded to ``get_random``.
        names: optional labels, one per operator; when given, the running
            mean and its standard error are logged after each sample.

    Returns:
        List of ``nout`` arrays of length ``nsample`` with the raw per-sample
        values (not averaged).
    """
    samples = [numpy.zeros((nsample)) for _ in range(nout)]
    for isample in range(nsample):
        probe = get_random(n, method=method)
        products = Xv(probe)
        for iout, prod in enumerate(products):
            samples[iout][isample] = numpy.dot(probe, prod)

        if names is None:
            continue
        assert(len(names) == nout)
        # numpy.var is the population variance; dividing by (k-1) for k
        # samples gives the squared standard error of the mean.  The tiny
        # offset avoids 0/0 on the first sample.
        denom = isample + 1e-4
        logging.info("Sample: {}".format(isample))
        for label, series in zip(names, samples):
            seen = series[:isample + 1]
            logging.info("   {}: {:13.10f} {:9.4E}".format(
                label, numpy.average(seen), numpy.sqrt(numpy.var(seen)/denom)))
        logging.info("")

    return samples


def stochastic_trace_imp(nout, n, Xv, nsample, vec, prob, P=None, names=None):
    """Importance-sampled stochastic trace estimate for ``nout`` operators.

    Each sample picks a column of ``vec`` through ``get_importance(prob)``.
    The column is rescaled to squared norm ``dim`` (its length) and
    optionally projected by ``P``.  The sample value is then
    ``(1/dim) * z . y_j / w``, where ``w`` is the weight returned with the
    index (presumably the probability of picking that column -- confirm
    against ``get_importance``).

    Args:
        nout: number of operators evaluated by ``Xv``.
        n: operator dimension (only used to validate ``P``).
        Xv: callable mapping a vector to a sequence of ``nout`` vectors.
        nsample: number of importance samples.
        vec: 2-D array whose columns are the candidate probe vectors.
        prob: sampling distribution handed to ``get_importance``.
        P: optional ``(n, n)`` matrix applied to each probe.
        names: optional labels; when given, running statistics are logged.

    Returns:
        List of ``nout`` averaged estimates.
    """
    estimates = [numpy.zeros((nsample)) for _ in range(nout)]
    for isample in range(nsample):
        col, weight = get_importance(prob)
        probe = vec[:, col]
        dim = probe.shape[0]
        # rescale so that |probe|^2 == dim
        probe = numpy.sqrt(dim)*probe/numpy.linalg.norm(probe)
        inv_dim = 1./float(dim)
        if P is not None:
            assert(P.shape[0] == P.shape[1] == n)
            probe = numpy.einsum('ij,j->i', P, probe)

        for iout, prod in enumerate(Xv(probe)):
            estimates[iout][isample] = inv_dim*numpy.dot(probe, prod)/weight

        if names is None:
            continue
        assert(len(names) == nout)
        # squared standard error of the mean; offset avoids 0/0 at k=1
        denom = isample + 1e-4
        logging.info("Sample: {}".format(isample))
        for label, series in zip(names, estimates):
            seen = series[:isample + 1]
            logging.info("   {}: {:13.10f} {:9.4E}".format(
                label, numpy.average(seen), numpy.sqrt(numpy.var(seen)/denom)))
        logging.info("")

    return [numpy.average(series) for series in estimates]


def stochastic_trace_pp(nout, n, Xv, nsample, kappa=1, names=None):
    """Hutch++-style stochastic trace estimate for ``nout`` operators.

    ``nsample`` Rademacher sketch vectors ``va`` are pushed through ``Xv``.
    For every output ``j``, a QR factorisation of the sketched products
    ``A_j Va`` yields an orthonormal basis ``Q``.  The trace is then split
    into:
      * an exact part ``tr(Q^T A_j Q)``, and
      * a residual part, estimated with ``kappa*nsample`` further Rademacher
        vectors ``vb`` as ``vb^T (I - QQ^T) A_j (I - QQ^T) vb``.

    Args:
        nout: number of operators evaluated by ``Xv``.
        n: length of the probe vectors.
        Xv: callable mapping a length-``n`` vector to a sequence of ``nout``
            vectors ``A_j v``.
        nsample: number of sketch vectors (columns of ``Va``).
        kappa: number of residual probes per sketch vector; cast to ``int``.
        names: optional labels (one per operator); when given, the running
            mean and standard error over the residual samples are logged.

    Returns:
        ``numpy.ndarray``; for ``nout >= 1`` its shape is
        ``(nout, kappa*nsample)``.  Entry ``[j, i]`` is the exact part plus
        the residual estimate from probe ``i``, so the row mean is the
        trace estimate for operator ``j``.

    Note:
        ``Xv`` is called ``nsample + kappa*nsample`` times for the probes.
        It is called once more per column of ``Q`` for every output, and
        only output ``j`` of those extra calls is used.
    """
    kappa = int(kappa)
    avgs = numpy.array([numpy.zeros((kappa*nsample)) for i in range(nout)])
    vveca = numpy.zeros((n, nsample))
    vvecb = numpy.zeros((n, kappa*nsample))
    Xveca = [numpy.zeros((n, nsample)) for i in range(nout)]
    Xvecb = [numpy.zeros((n, kappa*nsample)) for i in range(nout)]
    # Sketch phase: store A_j va for every sketch vector va.
    for i in range(nsample):
        va = get_random(n, method="Rademacher")
        # project
        #if P is not None:
        #    assert(P.shape[0] == P.shape[1] == n)
        #    va = numpy.einsum('ij,j->i', P, va)
        #    vb = numpy.einsum('ij,j->i', P, vb)
        vveca[:, i] = va

        Xvs = Xv(va)
        for X, x in zip(Xveca, Xvs):
            X[:, i] = x

    # Residual-probe phase: store vb and A_j vb for every residual probe.
    for i in range(kappa*nsample):
        vb = get_random(n, method="Rademacher")
        vvecb[:, i] = vb

        Xvs = Xv(vb)
        for X, x in zip(Xvecb, Xvs):
            X[:, i] = x

    for j, Xa in enumerate(Xveca):
        Xb = Xvecb[j]
        # Q: orthonormal basis for the range of the sketch A_j Va
        Q, R = numpy.linalg.qr(Xa)
        # AQ = A_j Q, built column by column; Xv returns every output but
        # only the j-th is kept here.
        AQ = numpy.zeros(Q.shape)
        for icol in range(Q.shape[1]):
            AQ[:, icol] = Xv(Q[:, icol])[j]

        # compute exact part tr(Q^T A_j Q); it is the same for every
        # sample, so it is broadcast across the whole row.
        #for i in range(nsample):
        #    avgs[j][i] = numpy.trace(
        #        numpy.einsum('qr,rs->qs', Q.transpose(), AQ))
        avgs[j, :] = numpy.trace(
                numpy.einsum('qr,rs->qs', Q.transpose(), AQ))

        # compute correction: Acor = A_j vb - A_j Q Q^T vb = A_j (I - QQ^T) vb,
        # then add vb^T Acor - (vb^T Q)(Q^T Acor) = vb^T (I - QQ^T) Acor.
        for i in range(kappa*nsample):
            Acor = Xb[:, i] - numpy.einsum(
                                'pq,qr,r->p', AQ, Q.transpose(), vvecb[:, i])
            avgs[j][i] += numpy.matmul(vvecb[:, i].transpose(), Acor)
            avgs[j][i] -= numpy.matmul(
                numpy.matmul(vvecb[:, i].transpose(), Q),
                numpy.matmul(Q.transpose(), Acor))

    # Progress report: running mean and standard error after each sample.
    for i in range(kappa*nsample):
        if names is not None:
            assert(len(names) == nout)
            # population variance / (k-1) is the squared standard error of
            # the mean of k samples; the offset avoids 0/0 at k=1.
            d = i + 1e-4
            logging.info("Sample: {}".format(i))
            for nm, xavg in zip(names, avgs):
                logging.info("   {}: {:13.10f} {:9.4E}".format(
                    nm,
                    numpy.average(xavg[:i+1]),
                    numpy.sqrt(numpy.var(xavg[:i+1])/d)))
            logging.info("")

    return avgs


def stochastic_trace_diff(nout, n, Xv1, Xv2, nsample, method, P1=None, P2=None, names=None):
    """Stochastic estimate of ``tr(P1^T A1_j P1) - tr(P2^T A2_j P2)`` per output.

    Both operators are probed with the same random vector ``v`` so that the
    difference benefits from correlated noise.  Each sample is
    ``v1 . A1_j v1 - v2 . A2_j v2``, where ``v1 = P1 v`` and ``v2 = P2 v``.
    A projector that is not supplied is treated as the identity.

    Args:
        nout: number of operators returned by ``Xv1`` and ``Xv2``.
        n: length of the probe vectors.
        Xv1, Xv2: callables mapping a vector to a sequence of ``nout`` vectors.
        nsample: number of probes.
        method: probe distribution name, forwarded to ``get_random``.
        P1, P2: optional ``(n, n)`` matrices applied to the probe before
            ``Xv1`` / ``Xv2`` respectively.
        names: optional labels; when given, running statistics are logged.

    Returns:
        List of ``nout`` arrays of length ``nsample`` with per-sample values.
    """
    logging.warning("This function is deprecated: Use stochastic trace instead!")
    avgs = [numpy.zeros((nsample)) for i in range(nout)]
    for i in range(nsample):
        v = get_random(n, method=method)
        # Default to the unprojected probe so P1/P2 are genuinely optional;
        # without this, omitting either projector raised NameError.
        v1 = v
        v2 = v
        if P1 is not None:
            assert(P1.shape[0] == P1.shape[1] == n)
            v1 = numpy.einsum('ij,j->i', P1, v)
        if P2 is not None:
            assert(P2.shape[0] == P2.shape[1] == n)
            v2 = numpy.einsum('ij,j->i', P2, v)

        Xv1s = Xv1(v1)
        Xv2s = Xv2(v2)
        for j, xv in enumerate(Xv1s):
            avgs[j][i] = (numpy.dot(v1, xv) - numpy.dot(v2, Xv2s[j]))

        if names is not None:
            assert(len(names) == nout)
            # squared standard error of the mean; offset avoids 0/0 at k=1
            d = i + 1e-4
            logging.info("Sample: {}".format(i))
            for nm, xavg in zip(names, avgs):
                logging.info("   {}: {:13.10f} {:9.4E}".format(
                    nm,
                    numpy.average(xavg[:i+1]),
                    numpy.sqrt(numpy.var(xavg[:i+1])/d)))
            logging.info("")

    return avgs


def stochastic_trace_diff_imp(nout, n, Xv1, Xv2, nsample, vec, prob, P1=None, P2=None, names=None):
    """Importance-sampled estimate of a trace difference for ``nout`` outputs.

    Each sample picks a column of ``vec`` via ``get_importance(prob)`` and
    rescales it to squared norm ``dim``.  Separate copies are projected by
    ``P1`` and ``P2`` (when given) and fed to ``Xv1`` / ``Xv2``.  The sample
    value is ``(1/dim) * (z1 . A1_j z1 - z2 . A2_j z2) / w``, where ``w`` is
    the weight returned by ``get_importance`` (presumably the sampling
    probability -- confirm).

    Args:
        nout: number of operators returned by ``Xv1`` and ``Xv2``.
        n: operator dimension (used to validate ``P1``/``P2``).
        Xv1, Xv2: callables mapping a vector to a sequence of ``nout`` vectors.
        nsample: number of importance samples.
        vec: 2-D array whose columns are candidate probe vectors.
        prob: sampling distribution handed to ``get_importance``.
        P1, P2: optional ``(n, n)`` matrices applied to the respective probe.
        names: optional labels; when given, running statistics are logged.

    Returns:
        List of ``nout`` averaged estimates.
    """
    estimates = [numpy.zeros((nsample)) for _ in range(nout)]
    for isample in range(nsample):
        col, weight = get_importance(prob)
        raw = vec[:, col]
        dim = raw.shape[0]
        probe1 = numpy.sqrt(dim)*raw/numpy.linalg.norm(raw)
        probe2 = probe1.copy()
        inv_dim = 1./float(dim)
        if P1 is not None:
            assert(P1.shape[0] == P1.shape[1] == n)
            probe1 = numpy.einsum('ij,j->i', P1, probe1)
        if P2 is not None:
            assert(P2.shape[0] == P2.shape[1] == n)
            probe2 = numpy.einsum('ij,j->i', P2, probe2)

        prods1 = Xv1(probe1)
        prods2 = Xv2(probe2)
        for iout, prod1 in enumerate(prods1):
            delta = numpy.dot(probe1, prod1) - numpy.dot(probe2, prods2[iout])
            estimates[iout][isample] = inv_dim*delta/weight

        if names is None:
            continue
        assert(len(names) == nout)
        # squared standard error of the mean; offset avoids 0/0 at k=1
        denom = isample + 1e-4
        logging.info("Sample: {}".format(isample))
        for label, series in zip(names, estimates):
            seen = series[:isample + 1]
            logging.info("   {}: {:13.10f} {:9.4E}".format(
                label, numpy.average(seen), numpy.sqrt(numpy.var(seen)/denom)))
        logging.info("")

    return [numpy.average(series) for series in estimates]


def stochastic_trace_diff_pp(nout, n, Xv1, Xv2, nsample, P1=None, P2=None, names=None):
    """Hutch++-style estimate of a trace difference for ``nout`` outputs.

    Two sets of Rademacher vectors are used:
      * sketch vectors ``va``, used to build orthonormal bases ``Q1``/``Q2``
        from ``A1_j Va1`` and ``A2_j Va2``;
      * residual probes ``vb``.
    Both sets are optionally projected by ``P1`` (for operator 1) and ``P2``
    (for operator 2); a missing projector is treated as the identity.

    Each sample is::

        tr(Q1^T A1_j Q1) + vb1^T (I - Q1Q1^T) A1_j (I - Q1Q1^T) vb1
      - tr(Q2^T A2_j Q2) - vb2^T (I - Q2Q2^T) A2_j (I - Q2Q2^T) vb2

    Args:
        nout: number of operators returned by ``Xv1`` and ``Xv2``.
        n: length of the probe vectors.
        Xv1, Xv2: callables mapping a vector to a sequence of ``nout`` vectors.
        nsample: number of sketch vectors and of residual probes.
        P1, P2: optional ``(n, n)`` matrices applied to the probes.
        names: optional labels; when given, running statistics are logged.

    Returns:
        List of ``nout`` arrays of length ``nsample`` with per-sample values.
    """
    logging.warning("This function is deprecated: Use stochastic trace instead!")

    def _range_basis(Xa, op, j):
        """Return an orthonormal basis ``Q`` of the columns of ``Xa`` and ``A_j Q``."""
        Q, _ = numpy.linalg.qr(Xa)
        AQ = numpy.zeros(Q.shape)
        for icol in range(Q.shape[1]):
            # op returns every output; only the j-th is needed here
            AQ[:, icol] = op(Q[:, icol])[j]
        return Q, AQ

    avgs = [numpy.zeros((nsample)) for i in range(nout)]
    vvecb1 = numpy.zeros((n, nsample))
    vvecb2 = numpy.zeros((n, nsample))
    Xveca1 = [numpy.zeros((n, nsample)) for i in range(nout)]
    Xvecb1 = [numpy.zeros((n, nsample)) for i in range(nout)]
    Xveca2 = [numpy.zeros((n, nsample)) for i in range(nout)]
    Xvecb2 = [numpy.zeros((n, nsample)) for i in range(nout)]
    for i in range(nsample):
        va = get_random(n, method="Rademacher")
        vb = get_random(n, method="Rademacher")
        # Default to the unprojected probes so P1/P2 are genuinely optional;
        # without this, omitting either projector raised NameError.
        va1, vb1 = va, vb
        va2, vb2 = va, vb
        if P1 is not None:
            assert(P1.shape[0] == P1.shape[1] == n)
            va1 = numpy.einsum('ij,j->i', P1, va)
            vb1 = numpy.einsum('ij,j->i', P1, vb)
        if P2 is not None:
            assert(P2.shape[0] == P2.shape[1] == n)
            va2 = numpy.einsum('ij,j->i', P2, va)
            vb2 = numpy.einsum('ij,j->i', P2, vb)
        vvecb1[:, i] = vb1
        vvecb2[:, i] = vb2

        # Same call order as before: A1 va1, A2 va2, A1 vb1, A2 vb2.
        for op, vec, store in ((Xv1, va1, Xveca1), (Xv2, va2, Xveca2),
                               (Xv1, vb1, Xvecb1), (Xv2, vb2, Xvecb2)):
            for X, x in zip(store, op(vec)):
                X[:, i] = x

    for j in range(nout):
        Q1, AQ1 = _range_basis(Xveca1[j], Xv1, j)
        Q2, AQ2 = _range_basis(Xveca2[j], Xv2, j)
        Xb1 = Xvecb1[j]
        Xb2 = Xvecb2[j]

        # exact part tr(Q1^T A1 Q1) - tr(Q2^T A2 Q2) is sample-independent
        exact = numpy.trace(numpy.einsum('qr,rs->qs', Q1.transpose(), AQ1))
        exact -= numpy.trace(numpy.einsum('qr,rs->qs', Q2.transpose(), AQ2))
        avgs[j][:] = exact

        # compute correction: Acor = A (I - QQ^T) vb, then add
        # vb^T (I - QQ^T) Acor for operator 1 and subtract it for operator 2
        for i in range(nsample):
            Acor1 = Xb1[:, i] - numpy.einsum(
                                'pq,qr,r->p', AQ1, Q1.transpose(), vvecb1[:, i])
            Acor2 = Xb2[:, i] - numpy.einsum(
                                'pq,qr,r->p', AQ2, Q2.transpose(), vvecb2[:, i])
            avgs[j][i] += numpy.matmul(vvecb1[:, i].transpose(), Acor1)
            avgs[j][i] -= numpy.matmul(
                numpy.matmul(vvecb1[:, i].transpose(), Q1),
                numpy.matmul(Q1.transpose(), Acor1))
            avgs[j][i] -= numpy.matmul(vvecb2[:, i].transpose(), Acor2)
            avgs[j][i] += numpy.matmul(
                numpy.matmul(vvecb2[:, i].transpose(), Q2),
                numpy.matmul(Q2.transpose(), Acor2))

    for i in range(nsample):
        if names is not None:
            assert(len(names) == nout)
            # squared standard error of the mean; offset avoids 0/0 at k=1
            d = i + 1e-4
            logging.info("Sample: {}".format(i))
            for nm, xavg in zip(names, avgs):
                logging.info("   {}: {:13.10f} {:9.4E}".format(
                    nm,
                    numpy.average(xavg[:i+1]),
                    numpy.sqrt(numpy.var(xavg[:i+1])/d)))
            logging.info("")

    return avgs


def stochastic_trace_bind(nout, n, n1, n2, Xv, Xv1, Xv2, nsample, method, P=None, P1=None, P2=None, names=None):
    """Stochastic estimate of ``tr(A_j) - tr(A1_j) - tr(A2_j)`` ("binding" trace).

    A single probe ``v`` of length ``n = n1 + n2`` is drawn.  It is split into
    ``v1 = v[:n1]`` and ``v2 = v[n1:]``; the split happens before any
    projection.  ``P``, ``P1`` and ``P2`` then optionally project ``v``,
    ``v1`` and ``v2`` respectively.  Each sample is
    ``v . A_j v - v1 . A1_j v1 - v2 . A2_j v2``.

    Args:
        nout: number of operators returned by each of ``Xv``, ``Xv1``, ``Xv2``.
        n, n1, n2: full and sub-block dimensions; ``n == n1 + n2``.
        Xv, Xv1, Xv2: callables mapping a vector to ``nout`` vectors.
        nsample: number of probes.
        method: probe distribution name, forwarded to ``get_random``.
        P, P1, P2: optional square projectors of size ``n``, ``n1``, ``n2``.
        names: optional labels; when given, running statistics are logged.

    Returns:
        List of ``nout`` arrays of length ``nsample`` with per-sample values.

    Raises:
        ValueError: if ``method`` is "rayleigh" (case-insensitive).
    """
    logging.warning("This function is deprecated: Use stochastic trace instead!")
    avgs = [numpy.zeros((nsample)) for i in range(nout)]
    assert(n == n1 + n2)
    # ``method.lower`` must be *called*: comparing the bound method itself to
    # a string is always False, which silently disabled this guard.
    if method.lower() == "rayleigh":
        raise ValueError(
            "Rayleigh-random vectors will converge slower for a binding calculation")
    for i in range(nsample):
        v = get_random(n, method=method)
        v1 = v[:n1]
        v2 = v[n1:]
        assert(len(v1) == n1)
        assert(len(v2) == n2)
        # project
        if P is not None:
            assert(P.shape[0] == P.shape[1] == n)
            v = numpy.einsum('ij,j->i', P, v)
        if P1 is not None:
            assert(P1.shape[0] == P1.shape[1] == n1)
            v1 = numpy.einsum('ij,j->i', P1, v1)
        if P2 is not None:
            assert(P2.shape[0] == P2.shape[1] == n2)
            v2 = numpy.einsum('ij,j->i', P2, v2)

        # each operator is applied exactly once per probe
        Xvs = Xv(v)
        Xv1s = Xv1(v1)
        Xv2s = Xv2(v2)
        for j, xv in enumerate(Xvs):
            avgs[j][i] = (
                numpy.dot(v, xv) - numpy.dot(v1, Xv1s[j]) - numpy.dot(v2, Xv2s[j]))

        if names is not None:
            assert(len(names) == nout)
            # squared standard error of the mean; offset avoids 0/0 at k=1
            d = i + 1e-4
            logging.info("Sample: {}".format(i))
            for nm, xavg in zip(names, avgs):
                logging.info("   {}: {:13.10f} {:9.4E}".format(
                    nm,
                    numpy.average(xavg[:i+1]),
                    numpy.sqrt(numpy.var(xavg[:i+1])/d)))
            logging.info("")

    return avgs


def exact_trace(nout, n, Xv):
    """Exact traces of ``nout`` operators by probing every unit basis vector.

    For each ``i`` in ``range(n)``, the basis vector ``e_i`` is passed to
    ``Xv``, and ``e_i . (A_j e_i)`` (the ``(i, i)`` diagonal element) is
    accumulated for every output ``j``.

    Args:
        nout: number of operators returned by ``Xv``.
        n: dimension of the operators.
        Xv: callable mapping a length-``n`` vector to a sequence of ``nout``
            vectors ``A_j v``.

    Returns:
        List of ``nout`` traces.

    Raises:
        AssertionError: if ``Xv`` does not return exactly ``nout`` vectors.
    """
    outs = nout*[0.0]
    for i in range(n):
        e_i = numpy.zeros((n))
        e_i[i] = 1.0
        ws = Xv(e_i)
        assert(len(ws) == nout)
        # use a separate index so the basis index ``i`` is not shadowed
        for j, w in enumerate(ws):
            outs[j] += numpy.dot(e_i, w)

    return outs


def exact_trace_diff(nout, n, Xv1, Xv2, P1=None, P2=None):
    """Exact ``tr(P1^T A1_j P1) - tr(P2^T A2_j P2)`` via unit basis vectors.

    For every basis vector ``e_i``, the projected copies ``v1 = P1 e_i`` and
    ``v2 = P2 e_i`` are formed.  A missing projector is treated as the
    identity.  The value ``v1 . A1_j v1 - v2 . A2_j v2`` is accumulated for
    every output ``j``.

    Args:
        nout: number of operators returned by ``Xv1`` and ``Xv2``.
        n: dimension of the operators.
        Xv1, Xv2: callables mapping a vector to a sequence of ``nout`` vectors.
        P1, P2: optional ``(n, n)`` matrices.

    Returns:
        List of ``nout`` trace differences.

    Raises:
        AssertionError: on a projector shape mismatch or a wrong output count.
    """
    logging.warning("This function is deprecated: Use exact_trace instead!")
    outs = nout*[0.0]
    for i in range(n):
        v = numpy.zeros((n))
        v[i] = 1.0
        v1 = v.copy()
        v2 = v.copy()
        if P1 is not None:
            assert(P1.shape[0] == P1.shape[1] == n)
            v1 = numpy.matmul(P1, v1)
        if P2 is not None:
            assert(P2.shape[0] == P2.shape[1] == n)
            v2 = numpy.matmul(P2, v2)

        ws1 = Xv1(v1)
        ws2 = Xv2(v2)
        assert(len(ws1) == nout)
        assert(len(ws2) == nout)
        # separate index so the basis index ``i`` is not shadowed
        for j, w1 in enumerate(ws1):
            outs[j] += (numpy.dot(v1, w1) - numpy.dot(v2, ws2[j]))

    return outs


def exact_trace_bind(nout, n, n1, n2, Xv, Xv1, Xv2, P=None, P1=None, P2=None, names=None):
    """Exact ``tr(A_j) - tr(A1_j) - tr(A2_j)`` ("binding" trace) via basis vectors.

    Each unit basis vector ``e_i`` of length ``n = n1 + n2`` is split into
    ``v1 = e_i[:n1]`` and ``v2 = e_i[n1:]``; the split happens before any
    projection.  ``P``, ``P1`` and ``P2`` then optionally project ``e_i``,
    ``v1`` and ``v2``.  The value ``v . A_j v - v1 . A1_j v1 - v2 . A2_j v2``
    is accumulated for every output ``j``.

    Args:
        nout: number of operators returned by each of ``Xv``, ``Xv1``, ``Xv2``.
        n, n1, n2: full and sub-block dimensions; ``n == n1 + n2``.
        Xv, Xv1, Xv2: callables mapping a vector to ``nout`` vectors.
        P, P1, P2: optional square projectors of size ``n``, ``n1``, ``n2``.
        names: accepted for signature compatibility; not used.

    Returns:
        List of ``nout`` accumulated values.
    """
    logging.warning("This function is deprecated: Use exact_trace instead!")
    outs = nout*[0.0]
    assert(n == n1 + n2)
    for i in range(n):
        v = numpy.zeros((n))
        v[i] = 1.0
        v1 = v[:n1]
        v2 = v[n1:]
        assert(len(v1) == n1)
        assert(len(v2) == n2)
        # project
        if P is not None:
            assert(P.shape[0] == P.shape[1] == n)
            v = numpy.einsum('ij,j->i', P, v)
        if P1 is not None:
            assert(P1.shape[0] == P1.shape[1] == n1)
            v1 = numpy.einsum('ij,j->i', P1, v1)
        if P2 is not None:
            assert(P2.shape[0] == P2.shape[1] == n2)
            v2 = numpy.einsum('ij,j->i', P2, v2)

        # each operator is applied exactly once per basis vector
        Xvs = Xv(v)
        Xv1s = Xv1(v1)
        Xv2s = Xv2(v2)
        for j, xv in enumerate(Xvs):
            outs[j] += (
                numpy.dot(v, xv) - numpy.dot(v1, Xv1s[j]) - numpy.dot(v2, Xv2s[j]))

    return outs


def stochastic_lanczos(n, matvec, nsample, funcs, names, m, Ph=None):
    """Perform nsample iterations of m-Lanczos (stochastic Lanczos quadrature).

    For each sample, a Rademacher probe is drawn.  When ``Ph`` is given, the
    probe has length ``Ph.shape[0]`` and is mapped to length ``n`` by
    ``probe @ Ph``.  An ``(m + 1)``-step Lanczos run from the probe gives a
    tridiagonal ``T`` with eigenpairs ``(theta_k, y_k)``.  For each ``f`` in
    ``funcs`` the sample value is ``len(probe) * sum_k f(theta_k) y_k[0]^2``.

    Args:
        n: operator dimension.
        matvec: operator handed to ``Lanczos``.
        nsample: number of probes.
        funcs: sequence of scalar functions ``f``.
        names: labels for logging, paired with ``funcs``.
        m: Lanczos depth (``m + 1`` steps are run).
        Ph: optional matrix with ``Ph.shape[1] == n``.

    Returns:
        List (one per function) of arrays of length ``nsample``.
    """
    nout = len(funcs)
    results = [numpy.zeros((nsample)) for _ in range(nout)]
    if Ph is None:
        draw_dim = n
    else:
        assert(n == Ph.shape[1])
        draw_dim = Ph.shape[0]

    for isample in range(nsample):
        probe = get_random(draw_dim, method="Rademacher")
        if Ph is not None:
            probe = numpy.matmul(probe, Ph)

        solver = Lanczos(matvec)
        solver.run(m + 1, v0=probe)
        evals, evecs = numpy.linalg.eigh(solver.Tmat())
        first = evecs[0, :]
        weights = first*first
        fvals = [numpy.asarray([f(e) for e in evals]) for f in funcs]

        logging.info("Sample: {}".format(isample))
        for label, series, fv in zip(names, results, fvals):
            # NOTE(review): the scale is the probe length, which equals |probe|^2
            # only for unprojected Rademacher probes; confirm this is intended
            # when Ph is given.
            series[isample] = probe.shape[0]*numpy.dot(fv, weights)
            seen = series[:isample + 1]
            logging.info("   {}: {:13.10f} {:9.4E}".format(
                label, numpy.average(seen),
                numpy.sqrt(numpy.var(seen)/(isample + 1))))
        logging.info("")

    return results


def stochastic_lanczos_diff(n, matvec, matvec2, nsample, funcs, names, m, Ph1=None, Ph2=None):
    """Perform nsample iterations of m-Lanczos on two operators and difference them.

    Each sample draws one Rademacher probe ``v``.  Without projectors, both
    operators use ``v`` directly.  With projectors, ``v`` has length
    ``Ph2.shape[0]`` and the operators use ``v @ Ph1`` and ``v @ Ph2``.  For
    each ``f`` in ``funcs``, the Lanczos quadrature value of the second
    operator is subtracted from that of the first.

    Args:
        n: operator dimension (probe length when no projectors are given).
        matvec, matvec2: operators handed to ``Lanczos``.
        nsample: number of probes.
        funcs: sequence of scalar functions ``f``.
        names: labels for logging, paired with ``funcs``.
        m: Lanczos depth (``m + 1`` steps are run).
        Ph1, Ph2: optional projection matrices; must be given together.

    Returns:
        List (one per function) of arrays of length ``nsample``.
    """
    # ``Lanczos`` is already imported at module level; the former
    # function-local re-import was redundant.
    nout = len(funcs)
    vals = [numpy.zeros((nsample)) for i in range(nout)]
    # Ph1 and Ph2 must be supplied together
    if Ph1 is not None:
        assert(Ph2 is not None)
    if Ph2 is not None:
        assert(Ph1 is not None)
    n2 = n if Ph2 is None else Ph2.shape[0]

    def _quadrature(op, v0):
        """Run (m + 1)-step Lanczos from v0; return f(theta) per func and e1 weights."""
        eng = Lanczos(op)
        eng.run(m + 1, v0=v0)
        ei, vi = numpy.linalg.eigh(eng.Tmat())
        ti = vi[0, :]
        return [numpy.asarray([func(e) for e in ei]) for func in funcs], ti*ti

    for i in range(nsample):
        if Ph1 is None:
            v = get_random(n, method="Rademacher")
            v1 = v2 = v
        else:
            v = get_random(n2, method="Rademacher")
            v1 = numpy.matmul(v, Ph1)
            v2 = numpy.matmul(v, Ph2)
        fis, w1 = _quadrature(matvec, v1)
        fis2, w2 = _quadrature(matvec2, v2)

        logging.info("Sample: {}".format(i))
        for nm, val, fi, fi2 in zip(names, vals, fis, fis2):
            first = v1.shape[0]*numpy.dot(fi, w1)
            second = v2.shape[0]*numpy.dot(fi2, w2)
            logging.debug("  {}:".format(nm))
            logging.debug("    First: {}".format(first))
            logging.debug("    Second: {}".format(second))
            val[i] = first - second
            logging.info("  {}: {:13.10f} {:9.4E}".format(
                nm,
                numpy.average(val[:i+1]),
                numpy.sqrt(numpy.var(val[:i+1])/(i+1))))
        logging.info("")

    return vals


def stochastic_lanczos_bind(n, n1, n2, matvec, matvec1, matvec2, nsample,
                            funcs, names, m, Ph=None, Ph1=None, Ph2=None):
    """Stochastic Lanczos quadrature for a "binding" difference of three operators.

    Each sample draws a Rademacher probe of length ``n = n1 + n2``.  The full
    probe is mapped with ``Ph``; its first ``n1`` and last ``n2`` entries are
    mapped with ``Ph1`` and ``Ph2``.  The sample value is
    ``q(matvec) - q(matvec1) - q(matvec2)``, where ``q`` is the scaled
    ``(m + 1)``-step Lanczos quadrature of each ``f`` in ``funcs``.

    Args:
        n, n1, n2: full and sub-block probe lengths; ``n == n1 + n2``.
        matvec, matvec1, matvec2: operators handed to ``Lanczos``.
        nsample: number of probes.
        funcs: sequence of scalar functions ``f``.
        names: labels for logging, paired with ``funcs``.
        m: Lanczos depth (``m + 1`` steps are run).
        Ph, Ph1, Ph2: projection matrices; all three are required.

    Returns:
        List (one per function) of arrays of length ``nsample``.
    """
    nout = len(funcs)
    assert(n == n1 + n2)
    assert(Ph is not None and Ph1 is not None and Ph2 is not None)
    results = [numpy.zeros((nsample)) for _ in range(nout)]

    def quadrature(op, start):
        """Lanczos from ``start``: f(theta) per function and squared first components."""
        solver = Lanczos(op)
        solver.run(m + 1, v0=start)
        evals, evecs = numpy.linalg.eigh(solver.Tmat())
        head = evecs[0, :]
        return [numpy.asarray([f(e) for e in evals]) for f in funcs], head*head

    for isample in range(nsample):
        probe = get_random(n, method="Rademacher")
        full = numpy.matmul(probe, Ph)
        part1 = numpy.matmul(probe[:n1], Ph1)
        part2 = numpy.matmul(probe[n1:], Ph2)

        fv0, wt0 = quadrature(matvec, full)
        fv1, wt1 = quadrature(matvec1, part1)
        fv2, wt2 = quadrature(matvec2, part2)

        logging.info("Sample: {}".format(isample))
        for label, series, f0, f1, f2 in zip(names, results, fv0, fv1, fv2):
            series[isample] = full.shape[0]*numpy.dot(f0, wt0)
            series[isample] -= part1.shape[0]*numpy.dot(f1, wt1)
            series[isample] -= part2.shape[0]*numpy.dot(f2, wt2)
            seen = series[:isample + 1]
            logging.info("   {}: {:13.10f} {:9.4E}".format(
                label, numpy.average(seen),
                numpy.sqrt(numpy.var(seen)/(isample + 1))))
        logging.info("")

    return results
