import numpy

from .interface import ComputerInterface, PBCComputerInterface

from pyscf import gto, dft, scf
from pyscf.hessian import rks as rks_hess
from pyscf.hessian import rhf as rhf_hess
from pyscf.pbc.gto import Cell
from pyscf.pbc import scf as pbc_scf


def build_mole(molecule, options):
    """Create and build a PySCF ``gto.Mole`` describing ``molecule``.

    The geometry is taken from ``molecule.write_xyz_string()``. Charge and spin
    are cast to ``int`` before being assigned to the Mole. Basis set and print
    level come from ``options``.

    Args:
        molecule: Object providing ``write_xyz_string()``, ``charge`` and ``spin``.
        options: ``Options``-like object; reads ``basis`` and ``verbose``.

    Returns:
        The built ``gto.Mole``.
    """
    pyscf_mol = gto.Mole()
    pyscf_mol.atom = molecule.write_xyz_string()
    pyscf_mol.basis = options.basis
    pyscf_mol.charge, pyscf_mol.spin = int(molecule.charge), int(molecule.spin)
    pyscf_mol.verbose = options.verbose
    pyscf_mol.build()
    return pyscf_mol


# Lower-case names of correlated (post-Hartree-Fock) wavefunction methods.
# These are not exchange-correlation functionals and must not be handed to a
# DFT mean field as ``xc``.
_POSTHF_METHODS = frozenset({
    "mp2", "mp3", "ccsd", "ccsd(t)", "cisd", "fci", "casci", "casscf",
})


def is_dft(method):
    """Return True if ``method`` should be run as a DFT functional.

    ``"hf"`` (any case) selects Hartree-Fock. Names listed in
    ``_POSTHF_METHODS`` are correlated wavefunction methods. Every other name
    is treated as an exchange-correlation functional.

    Previously every non-"hf" name counted as DFT, so post-HF methods could
    never be recognised by ``is_posthf``.
    """
    name = method.lower()
    return name != "hf" and name not in _POSTHF_METHODS


def is_posthf(method):
    """Return True when ``method`` is neither a DFT functional nor Hartree-Fock.

    The DFT test is delegated to ``is_dft``. Hartree-Fock is the name "hf",
    compared case-insensitively.
    """
    mean_field = is_dft(method) or method.lower() == "hf"
    return not mean_field


def get_mf(mol, options):
    """Build (but do not run) a PySCF mean-field object for ``mol``.

    Closed-shell molecules (``mol.spin == 0``) get a restricted method.
    Open-shell molecules get the unrestricted counterpart.

    Args:
        mol: A built PySCF ``gto.Mole``.
        options: ``Options``-like object; reads ``method`` and ``scf_conv``.

    Returns:
        ``dft.RKS``/``dft.UKS`` when ``is_dft(options.method)``, otherwise
        ``scf.RHF``/``scf.UHF``. The energy tolerance is always 1e-12. The
        gradient tolerance is ``options.scf_conv``, or 1e-8 when that is None.
    """
    restricted = mol.spin == 0
    if is_dft(options.method):
        # Take both spin cases from the dft module (the unrestricted case
        # previously came from scf.UKS while the restricted one used dft.RKS).
        mf = dft.RKS(mol) if restricted else dft.UKS(mol)
        # Dense, unpruned integration grid; presumably chosen to keep grid
        # error small relative to the tight SCF thresholds below.
        mf.grids.level = 5
        mf.grids.prune = False
        mf.xc = options.method
    else:
        mf = scf.RHF(mol) if restricted else scf.UHF(mol)
    mf.conv_tol = 1e-12
    mf.conv_tol_grad = 1e-8 if options.scf_conv is None else options.scf_conv
    return mf


class Options:
    """Settings for molecular PySCF runs (consumed by build_mole / get_mf).

    Every field starts as ``None`` except ``verbose``, which starts at 0.
    """

    def __init__(self):
        # method: "hf" or an exchange-correlation functional name for PySCF.
        # basis: basis specification assigned to Mole.basis.
        # scf_conv: SCF gradient convergence threshold; None selects 1e-8.
        self.method = self.basis = self.scf_conv = None
        # verbose: PySCF print level assigned to Mole.verbose.
        self.verbose = 0


class PyscfInterface(ComputerInterface):
    """Molecular energy / derivative calculator backed by PySCF mean fields.

    A fresh PySCF ``Mole`` and mean-field object are built on every call.
    Hartree-Fock and DFT provide analytic gradients and Hessians. Post-HF
    methods are rejected with ``NotImplementedError``.

    Args:
        options: ``Options``-like object (method, basis, scf_conv, verbose).
        save: If True, the density matrix from each SCF is stored and passed
            as the initial guess (``dm0``) to the next SCF.
    """

    def __init__(self, options, save=False):
        self.options = options
        # Analytic Hessians are available for HF and DFT, not post-HF methods.
        self.maxderiv = 1 if is_posthf(self.options.method) else 2
        self._dm0 = None
        self.save = save

    def max_deriv(self):
        """Return the highest analytic derivative order supported."""
        return self.maxderiv

    def _run_scf(self, mol):
        """Converge the SCF for ``mol``; return ``(pyscf_mol, mf, energy)``.

        The post-HF check runs first, so no SCF work is wasted on a method
        that cannot be handled.
        """
        if is_posthf(self.options.method):
            raise NotImplementedError(
                "Post-HF Wavefunction methods not yet implemented")
        pyscf_mol = build_mole(mol, self.options)
        mf = get_mf(pyscf_mol, self.options)
        energy = mf.kernel(dm0=self._dm0)
        if self.save:
            # Reuse this density as the next call's initial guess.
            self._dm0 = mf.make_rdm1()
        return pyscf_mol, mf, energy

    @staticmethod
    def _nuclear_gradient(pyscf_mol, mf):
        """Return the analytic nuclear gradient flattened to shape (3*natm,)."""
        grad = mf.nuc_grad_method()
        # Include the DFT integration-grid response in the gradient.
        grad.grid_response = True
        return grad.kernel().reshape(3 * pyscf_mol.natm)

    def energy(self, mol):
        """Return the SCF energy from ``mf.kernel`` for ``mol``."""
        _, _, Emf = self._run_scf(mol)
        return Emf

    def gradient(self, mol):
        """Return ``(energy, gradient)``; the gradient is a flat (3*natm,) array."""
        pyscf_mol, mf, Emf = self._run_scf(mol)
        F1 = self._nuclear_gradient(pyscf_mol, mf)
        return Emf, F1

    def hessian(self, mol):
        """Return ``(energy, gradient, hessian)``.

        The gradient has shape (3*natm,) and the Hessian (3*natm, 3*natm).
        """
        pyscf_mol, mf, Emf = self._run_scf(mol)
        F1 = self._nuclear_gradient(pyscf_mol, mf)
        natom = pyscf_mol.natm
        hess = mf.Hessian().kernel()
        # PySCF's Hessian is indexed (atom_i, atom_j, xyz_i, xyz_j); reorder
        # to (atom_i, xyz_i, atom_j, xyz_j) before flattening to a matrix.
        F2 = hess.transpose((0, 2, 1, 3)).reshape(3 * natom, 3 * natom)
        return Emf, F1, F2


def build_cell(uc, options):
    """Create and build a PySCF periodic ``Cell`` for the unit cell ``uc``.

    Lattice vectors and atomic coordinates are handed to PySCF in Bohr.
    Optional settings (``precision``, ``ke_cutoff``) are applied only when
    they are not None; otherwise PySCF's defaults are kept.

    Args:
        uc: Unit-cell object providing ``lattice`` and ``mol`` (with
            ``write_xyz_string(unit=...)``).
        options: ``PBCOptions``-like object; reads ``basis``, ``pseudo``,
            ``verbose``, ``precision`` and (if present) ``ke_cutoff``.

    Returns:
        The built ``Cell``.
    """
    cell = Cell()
    cell.unit = "B"  # use Bohr here
    cell.a = numpy.asarray(uc.lattice)
    cell.atom = uc.mol.write_xyz_string(unit="Bohr")
    cell.basis = options.basis
    cell.pseudo = options.pseudo
    cell.verbose = options.verbose
    if options.precision is not None:
        cell.precision = options.precision
    # ke_cutoff was previously accepted on PBCOptions but never applied.
    # getattr keeps option objects without the attribute working.
    ke_cutoff = getattr(options, "ke_cutoff", None)
    if ke_cutoff is not None:
        cell.ke_cutoff = ke_cutoff
    # NOTE(review): charge/spin of uc.mol are not transferred to the cell, so
    # cell.spin stays at PySCF's default; confirm whether open-shell cells
    # are meant to be supported.
    cell.build()
    return cell


def get_pbc_mf(cell, options):
    """Build (without running) a k-point mean-field object for ``cell``.

    The k-point set comes from ``cell.make_kpts`` using ``options.kmesh``,
    shifted by ``options.kpt`` (``scaled_center``). The method is chosen as
    follows:
      * a restricted variant when ``cell.spin == 0``, an unrestricted one
        otherwise;
      * KRKS/KUKS with ``xc = options.method`` for DFT methods;
      * KRHF/KUHF otherwise.
    The energy tolerance is 1e-12. The gradient tolerance is
    ``options.scf_conv``, falling back to 1e-8.
    """
    kpoints = cell.make_kpts(options.kmesh, scaled_center=options.kpt)
    closed_shell = cell.spin == 0
    if not is_dft(options.method):
        factory = pbc_scf.KRHF if closed_shell else pbc_scf.KUHF
        mean_field = factory(cell, kpoints)
    else:
        factory = pbc_scf.KRKS if closed_shell else pbc_scf.KUKS
        mean_field = factory(cell, kpoints)
        mean_field.xc = options.method
    mean_field.conv_tol = 1e-12
    grad_tol = options.scf_conv
    mean_field.conv_tol_grad = 1e-8 if grad_tol is None else grad_tol
    return mean_field


class PBCOptions:
    """Settings for periodic PySCF runs (consumed by build_cell / get_pbc_mf).

    Every field starts as ``None`` except ``verbose``, which starts at 0.
    """

    def __init__(self):
        # method: "hf" or an exchange-correlation functional name for PySCF.
        # basis / pseudo: basis set and pseudopotential assigned to the Cell.
        # ke_cutoff: kinetic-energy cutoff for the cell; None means unset.
        self.method = self.basis = self.pseudo = self.ke_cutoff = None
        # kmesh: k-point mesh for make_kpts; kpt: its scaled_center shift.
        self.kmesh = self.kpt = None
        # scf_conv: SCF gradient convergence threshold; None selects 1e-8.
        self.scf_conv = None
        # verbose: PySCF print level.
        self.verbose = 0
        # precision: Cell.precision override; None keeps PySCF's default.
        self.precision = None


class PBCPyscfInterface(PBCComputerInterface):
    """Periodic (k-point) energy / gradient calculator backed by PySCF.

    A fresh PySCF ``Cell`` and k-point mean-field object are built on every
    call. Post-HF methods are rejected with ``NotImplementedError``.

    Args:
        options: ``PBCOptions``-like object.
        save: If True, the density matrix from each SCF is stored and passed
            as the initial guess (``dm0``) to the next SCF.
    """

    def __init__(self, options, save=False):
        self.options = options
        # Only energies and analytic gradients are provided for periodic
        # systems; this class has no hessian method.
        self.maxderiv = 1
        self._dm0 = None
        self.save = save

    def max_deriv(self):
        """Return the highest analytic derivative order supported (1)."""
        return self.maxderiv

    def _run_scf(self, cell):
        """Converge the SCF for ``cell``; return ``(pyscf_cell, mf, energy)``.

        The post-HF check runs first, so no SCF work is wasted on a method
        that cannot be handled.
        """
        if is_posthf(self.options.method):
            raise NotImplementedError(
                "Post-HF Wavefunction methods not yet implemented")
        pyscf_cell = build_cell(cell, self.options)
        mf = get_pbc_mf(pyscf_cell, self.options)
        energy = mf.kernel(dm0=self._dm0)
        if self.save:
            # Reuse this density as the next call's initial guess.
            self._dm0 = mf.make_rdm1()
        return pyscf_cell, mf, energy

    def energy(self, cell):
        """Return the SCF energy from ``mf.kernel`` for ``cell``."""
        _, _, Emf = self._run_scf(cell)
        return Emf

    def gradient(self, cell):
        """Return ``(energy, gradient)``; the gradient is a flat (3*natm,) array."""
        pyscf_cell, mf, Emf = self._run_scf(cell)
        grad = mf.nuc_grad_method()
        # NOTE(review): unlike the molecular interface, grid_response is not
        # enabled here; presumably the periodic gradient code does not
        # support it. Confirm before enabling.
        F1 = grad.kernel().reshape(3 * pyscf_cell.natm)
        return Emf, F1
