import numpy
from .vutils import print_state
from .vutils import vci_matrixel


class VCI(object):
    """Configuration-interaction solver over a truncated occupation-number basis.

    A basis state is a list of ``nm`` non-negative integers, one per mode,
    and the basis keeps every state whose entries sum to at most ``nmax``.
    Matrix elements between states are delegated to ``vci_matrixel``.

    Attributes set by the constructor:
        omegas: element-wise square root of the diagonal of ``sys.P2``.
        nm: number of modes, i.e. ``len(omegas)``.
        sys: the object passed in; ``solve`` reads ``P1``..``P4`` from it.
        basis: list of states (each a list of ints) in lexicographic order,
            with the first mode varying slowest.
    """

    def __init__(self, sys, nmax, Ecut=None, T=298.15):
        """Build the truncated basis for ``sys``.

        Args:
            sys: object exposing ``P2`` with a ``diagonal()`` method; its
                ``P1``, ``P3`` and ``P4`` attributes are read later by
                ``solve``.
            nmax: largest total number of quanta allowed in a basis state.
            Ecut: accepted for interface compatibility; not used here.
            T: accepted for interface compatibility; not used here.
        """
        self.omegas = numpy.sqrt(sys.P2.diagonal())
        self.nm = len(self.omegas)
        self.sys = sys

        # Grow the states one mode at a time and give each new mode only the
        # quanta still left in the budget.  Enumerating all (nmax + 1) ** nm
        # combinations and filtering afterwards yields the same states in the
        # same order, but its cost is exponential in the number of modes.
        states = [[m] for m in range(nmax + 1)]
        for _ in range(self.nm - 1):
            states = [
                state + [m]
                for state in states
                for m in range(nmax - sum(state) + 1)
            ]
        self.basis = states

    def print_basis(self):
        """Print a header line, then every basis state via ``print_state``."""
        print("CI basis:")
        for state in self.basis:
            print_state(state)

    def solve(self):
        """Assemble the matrix in the current basis and return its eigenvalues.

        Returns:
            The eigenvalue array produced by ``numpy.linalg.eigh`` (ascending
            order).
        """
        size = len(self.basis)
        H = numpy.zeros((size, size))
        for i, bi in enumerate(self.basis):
            # numpy.linalg.eigh reads only the lower triangle (UPLO='L' is
            # the default), so evaluating the elements with j <= i gives the
            # same eigenvalues as evaluating all of them.  The value is
            # mirrored so H is still a complete symmetric matrix.
            for j in range(i + 1):
                element = vci_matrixel(
                    bi, self.basis[j], self.omegas, self.sys.P1,
                    self.sys.P2, self.sys.P3, self.sys.P4)
                H[i, j] = element
                H[j, i] = element

        # Only the eigenvalues are returned; the eigenvectors are discarded.
        return numpy.linalg.eigh(H)[0]
