import numpy


def energy(omegas, state):
    """Return the harmonic (zeroth-order) energy of a product basis state.

    Each mode contributes ``w * (n + 1/2)``, where ``w`` is the mode's
    frequency and ``n`` is its quantum number in ``state``. This is the same
    per-mode term ``0.5*(2*m + 1.0)*w`` that ``vci_matrixel`` adds on the
    diagonal.

    Parameters
    ----------
    omegas : sequence of float
        Harmonic frequency of each mode.
    state : sequence of int
        Quantum number of each mode; must have the same length as ``omegas``.

    Returns
    -------
    float
        ``sum(0.5*(2*n + 1)*w)`` over the paired entries of ``omegas`` and
        ``state``.

    Raises
    ------
    AssertionError
        If ``omegas`` and ``state`` have different lengths.
    """
    assert len(omegas) == len(state), "omegas and state must have equal length"
    E = 0.0
    # Weight each frequency by the mode's quantum number n. The previous code
    # used the mode's position in the list, so `state` was effectively ignored.
    for w, n in zip(omegas, state):
        E += 0.5*(2*n + 1)*w
    return E


def print_state(state):
    """Print ``state`` in ket notation, e.g. ``[0, 1, 2]`` -> ``| 0 1 2>``.

    Every entry is converted with ``str`` and preceded by a single space.
    The result is written to stdout with a trailing newline.
    """
    entries = "".join(" " + str(quantum) for quantum in state)
    print("|" + entries + ">")


def vci_matrixel(bi, bj, omegas, P1, P2, P3, P4):
    """Return one Hamiltonian matrix element between two product basis states.

    ``bi`` (bra) and ``bj`` (ket) hold one integer quantum number per mode,
    one entry per frequency in ``omegas``. The element is assembled from the
    following pieces:

    - on the diagonal, a ``0.5*(2*m + 1.0)*w`` term for every mode;
    - quartic ``P4`` terms, cubic ``P3`` terms and, when exactly two modes
      each differ by one quantum, a ``P2`` term.

    Which terms are added depends on how many modes differ between bra and
    ket and by how much.

    The prefactors (``sqrt(0.5*m/w)``, ``(n + 0.5)/w``, ...) appear to be
    harmonic-oscillator matrix elements of a coordinate
    ``q = (a + a^dagger)/sqrt(2*w)``. The 1/6 and 1/24 weights look like
    Taylor-expansion factors combined with index-permutation counts.
    NOTE(review): confirm this convention against where P2..P4 are produced.

    Parameters
    ----------
    bi, bj : sequence of int
        Quantum numbers of the bra and ket; both must have length
        ``len(omegas)``.
    omegas : sequence of float
        Frequency of each mode.
    P1 : object
        Not used: its only reference is the commented-out line in the
        ``d == 1`` branch.
    P2, P3, P4 : arrays indexed by 2, 3 and 4 mode indices
        Force constants. Only particular index orderings are read (e.g.
        ``P4[im, im, inn, inn]``, never ``P4[inn, im, im, inn]``), so these
        are presumably fully index-symmetric. TODO confirm.

    Returns
    -------
    scalar
        The matrix element. It is 0.0 when more than four modes differ, and
        for any difference pattern not handled by the branches below.

    Raises
    ------
    AssertionError
        If ``bi``, ``bj`` and ``omegas`` do not all have the same length
        (only when assertions are enabled).
    """
    nmode = len(omegas)
    assert(len(bi) == len(bj))
    assert(len(bi) == nmode)
    ds = []  # |bi[m] - bj[m]| for each mode m
    ns = []  # max(bi[m], bj[m]) for each mode m (the larger quantum number)
    ndiff = 0  # number of modes in which bra and ket differ
    diff = []  # indices of those modes, in increasing mode order
    # NOTE: inside this loop `im`/`jm` are quantum numbers, not mode indices
    # (later code reuses `im` as a mode index); `w` is not used here.
    for m, w in enumerate(omegas):
        im = bi[m]
        jm = bj[m]
        igj = im > jm
        x = im - jm if igj else jm - im
        if x > 0:
            ndiff = ndiff + 1
            diff.append(m)
        ds.append(x)
        ns.append(im if igj else jm)

    H = 0.0
    # bra and ket equal: T + m^2, m^4, m^2n^2
    if ndiff == 0:
        # Per mode: the 0.5*(2m+1)*w term and the diagonal P4[m,m,m,m] term.
        for im, w in enumerate(omegas):
            m = ns[im]
            H += 0.5*(2*m + 1.0)*w
            H += P4[im, im, im, im]*(6.0*m*(m + 1.0) + 3.0)/(24.0*4.0*w*w)
        # P4[m,m,n,n] coupling; each unordered mode pair (im < inn) once.
        for im, wm in enumerate(omegas):
            m = ns[im]
            for inn in range((im + 1), nmode):
                wn = omegas[inn]
                n = ns[inn]
                H += P4[im, im, inn, inn]*(1/4.0)*(m + 0.5)/wm*(n + 0.5)/wn
    # difference in one mode
    elif ndiff == 1:
        assert(len(diff) == 1)
        im = diff[0]
        m = ns[im]
        w = omegas[im]
        d = ds[im]
        assert(d > 0)
        if d == 1:  # m + mn^2 + m^3
            # NOTE(review): the linear P1 term below is disabled, so only the
            # cubic terms are added in this branch.
            #H += P1[im]*numpy.sqrt(0.5*m/w)
            H += P3[im, im, im]*(3.0/6.0)*numpy.sqrt(0.5*m/w)*0.5*m/w
            # P3[m,n,n] with every other mode n as an unchanged spectator.
            for inn, wn in enumerate(omegas):
                if inn == im:
                    continue
                n = ns[inn]
                H += P3[im, inn, inn]*(1/2.0)*numpy.sqrt(0.5*m/w)*(n + 0.5)/wn
        elif d == 2:  # (T + m^2) + m^2n^2 + m^4
            # Only quartic P4 terms are added here; no quadratic/harmonic
            # contribution is computed for d == 2.
            H += P4[im, im, im, im]*(1/24.0)*(m - 0.5)*numpy.sqrt(m*(m - 1.0))/(w*w)
            for inn, wn in enumerate(omegas):
                if inn == im:
                    continue
                n = ns[inn]
                H += P4[im, im, inn, inn]*(1/4.0)*numpy.sqrt(m*(m - 1.0))/(2.0*w)*(n + 0.5)/wn
        elif d == 3:  # m^3
            H += P3[im, im, im]*(1.0/6.0)*numpy.sqrt(m*(m - 1.0)*(m - 2.0)/(8.0*w*w*w))
        elif d == 4:  # m^4
            H += P4[im, im, im, im]*(1.0/24.0)*numpy.sqrt(m*(m - 1.0)*(m - 2.0)*(m - 3.0))/(4.0*w*w)
        else:
            # d > 4: no term is handled, H stays 0.0.
            pass

    # difference in two modes
    elif ndiff == 2:
        assert(len(diff) == 2)
        im = diff[0]
        inn = diff[1]
        m = ns[im]
        wm = omegas[im]
        dm = ds[im]
        n = ns[inn]
        wn = omegas[inn]
        dn = ds[inn]
        # Order the two modes so that dm <= dn. The branches below then only
        # need the patterns (1,1), (1,2), (1,3) and (2,2).
        if dm > dn:
            im, inn = inn, im
            m, n = n, m
            wm, wn = wn, wm
            dm, dn = dn, dm
        if dm == 1 and dn == 1:  # mn + m^3n + mn^3 + nmo^2
            H += P2[im, inn]*numpy.sqrt(0.5*m/wm)*numpy.sqrt(0.5*n/wn)
            H += P4[im, im, im, inn]*(1/6.0)*3.0*numpy.sqrt(m*m*m/(8.0*wm*wm*wm))*numpy.sqrt(0.5*n/wn)
            H += P4[im, inn, inn, inn]*(1/6.0)*numpy.sqrt(0.5*m/wm)*3.0*numpy.sqrt(n*n*n/(8.0*wn*wn*wn))
            # P4[m,n,o,o] with every other mode o as an unchanged spectator.
            for io, wo in enumerate(omegas):
                o = ns[io]
                if io == inn or io == im:
                    continue
                H += P4[im, inn, io, io]*(1/2.0)*numpy.sqrt(0.5*m/wm)*numpy.sqrt(0.5*n/wn)*(o + 0.5)/wo
        elif dm == 1 and dn == 2:  # mn^2
            H += P3[im, inn, inn]*(1/2.0)*numpy.sqrt(0.5*m/wm)*numpy.sqrt(n*(n - 1.0))/(2*wn)
        elif dm == 1 and dn == 3:  # mn^3
            H += P4[im, inn, inn, inn]*(1/6.0)*numpy.sqrt(0.5*m/wm)*numpy.sqrt(n*(n - 1.0)*(n - 2.0)/(8.0*wn*wn*wn))
        elif dm == 2 and dn == 2:  # m^2n^2
            H += P4[im, im, inn, inn]*(1/4.0)*numpy.sqrt(m*(m - 1.0))/(2*wm)*numpy.sqrt(n*(n - 1.0))/(2*wn)
        else:
            # Any other (dm, dn) pattern: no term is handled, H stays 0.0.
            pass
    # difference in three modes
    elif ndiff == 3:
        assert(len(diff) == 3)
        # No reordering here: im < inn < io by mode index. Each placement of a
        # single 2 among the differences therefore gets its own branch.
        im = diff[0]
        inn = diff[1]
        io = diff[2]
        dm = ds[im]
        dn = ds[inn]
        do = ds[io]
        wm = omegas[im]
        wn = omegas[inn]
        wo = omegas[io]
        m = ns[im]
        n = ns[inn]
        o = ns[io]
        if dm == 1 and dn == 1 and do == 1:     # mno
            H += P3[im, inn, io]*numpy.sqrt(0.5*m/wm)*numpy.sqrt(0.5*n/wn)*numpy.sqrt(0.5*o/wo)
        elif dm == 1 and dn == 1 and do == 2:   # mno^2
            H += P4[im, inn, io, io]*(1/2.0)*numpy.sqrt(0.5*m/wm)*numpy.sqrt(0.5*n/wn)*numpy.sqrt(o*(o - 1.0))/(2*wo)
        elif dm == 1 and dn == 2 and do == 1:   # mn^2o
            H += P4[im, inn, inn, io]*(1/2.0)*numpy.sqrt(0.5*m/wm)*numpy.sqrt(n*(n - 1.0))/(2*wn)*numpy.sqrt(0.5*o/wo)
        elif dm == 2 and dn == 1 and do == 1:   # m^2no
            H += P4[im, im, inn, io]*(1/2.0)*numpy.sqrt(m*(m - 1.0))/(2*wm)*numpy.sqrt(0.5*n/wn)*numpy.sqrt(0.5*o/wo)
        else:
            # Any other (dm, dn, do) pattern: H stays 0.0.
            pass
    # difference in four modes
    elif ndiff == 4:
        assert(len(diff) == 4)
        im = diff[0]
        inn = diff[1]
        io = diff[2]
        ip = diff[3]
        dm = ds[im]
        dn = ds[inn]
        do = ds[io]
        dp = ds[ip]
        m = ns[im]
        n = ns[inn]
        o = ns[io]
        p = ns[ip]
        wm = omegas[im]
        wn = omegas[inn]
        wo = omegas[io]
        wp = omegas[ip]
        # Only the pattern where each of the four modes differs by one quantum
        # contributes; any other pattern leaves H at 0.0.
        if dm == 1 and dn == 1 and do == 1 and dp == 1:
            H += P4[im, inn, io, ip]*numpy.sqrt(0.5*m/wm)*numpy.sqrt(0.5*n/wn)*numpy.sqrt(0.5*o/wo)*numpy.sqrt(0.5*p/wp)
    # More than four differing modes: no branch applies and H is 0.0.
    return H
