from copy import deepcopy
from itertools import combinations, product

import numpy

from .vutils import print_state


class VPT2(object):
    """Helper for a VPT2-style second-order energy correction.

    Only the enumeration of states connected to a reference state is
    implemented so far; :meth:`pt2_correction` still returns 0.0.
    """

    def __init__(self, sys, T=298.15):
        """Store harmonic data taken from ``sys``.

        Parameters
        ----------
        sys : object
            Must expose ``P2``, a square array.  The square roots of its
            diagonal are stored as ``omegas``, so the diagonal is treated as
            squared harmonic frequencies (units/coordinates assumed - TODO
            confirm against the caller).
        T : float, optional
            Stored as ``self.T`` but not used by this class yet (presumably
            a temperature in K, default 298.15).
        """
        self.omegas = numpy.sqrt(sys.P2.diagonal())
        self.nm = len(self.omegas)  # number of modes
        self.sys = sys
        self.T = T

    def connected_states(self, state, max_quanta=4):
        """Return every state within ``max_quanta`` total quanta of ``state``.

        A state is included when it differs from ``state`` in at least one
        mode, all its occupations are non-negative, and the summed absolute
        change over all modes is at most ``max_quanta``.  Results are grouped
        by how many modes change (1, 2, ...), then by mode combination in
        lexicographic order, then by ascending target occupations.  No state
        appears twice.

        Parameters
        ----------
        state : sequence of int
            Occupation numbers, one per mode.  Each result is a deep copy of
            ``state`` with some items reassigned, so the container must
            support item assignment (e.g. a list or numpy array).
        max_quanta : int, optional
            Maximum total number of quanta exchanged (default 4).

        Returns
        -------
        list
            The connected states; ``state`` itself is not modified.

        Raises
        ------
        ValueError
            If ``len(state)`` does not equal the number of modes.
        """
        if len(state) != self.nm:
            raise ValueError(
                "state has %d modes, expected %d" % (len(state), self.nm))

        slist = []
        for nchanged in range(1, max_quanta + 1):
            # Every changed mode moves by at least one quantum, so a single
            # mode can move by at most this many.
            reach = max_quanta - (nchanged - 1)
            for modes in combinations(range(self.nm), nchanged):
                ranges = [range(max(0, state[m] - reach), state[m] + reach + 1)
                          for m in modes]
                for targets in product(*ranges):
                    deltas = [abs(t - state[m]) for m, t in zip(modes, targets)]
                    # Every selected mode must really change; otherwise the
                    # state belongs to a smaller group and would be a
                    # duplicate.  The total change is capped.
                    if 0 in deltas or sum(deltas) > max_quanta:
                        continue
                    temp = deepcopy(state)
                    for m, t in zip(modes, targets):
                        temp[m] = t
                    slist.append(temp)
        return slist

    def pt2_correction(self, state, verbose=True):
        """Return the second-order energy correction for ``state``.

        Not finished: the connected states are enumerated and, when
        ``verbose`` is true, printed via ``print_state``.  No matrix
        elements are evaluated yet, so the result is always 0.0.

        Parameters
        ----------
        state : sequence of int
            Occupation numbers, one per mode.
        verbose : bool, optional
            Print every connected state (default True).

        Raises
        ------
        ValueError
            If ``len(state)`` does not equal the number of modes.
        """
        connected = self.connected_states(state)
        if verbose:
            for other in connected:
                print_state(other)

        # TODO: loop over `connected` and accumulate the matrix elements.
        E2 = 0.0

        return E2
