from pathlib import Path
import pickle
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.ML.Descriptors import MoleculeDescriptors
from functools import lru_cache

app_dir = Path(__file__).parent

# Functions to calculate chemical descriptors based on SMILES
def myMolFromSmiles(smiles):
    """ Function to create mol object from SMILES performing partial sanitization when necessary

    Inputs
    ----------
    smiles : str, mandatory
        SMILES string

    Outputs
    ----------
    mol: object or None
        RDKit mol object. None is returned when the SMILES cannot be parsed
        or when the partial-sanitization fallback raises.

    """

    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return mol

    # Default (full) sanitization failed: parse without sanitization and run
    # only a reduced set of sanitization steps.
    partial_flags = (
        Chem.SanitizeFlags.SANITIZE_FINDRADICALS | Chem.SanitizeFlags.SANITIZE_KEKULIZE |
        Chem.SanitizeFlags.SANITIZE_SETAROMATICITY | Chem.SanitizeFlags.SANITIZE_SETCONJUGATION |
        Chem.SanitizeFlags.SANITIZE_SETHYBRIDIZATION | Chem.SanitizeFlags.SANITIZE_SYMMRINGS
    )
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is None:
            raise ValueError('SMILES could not be parsed')
        mol.UpdatePropertyCache(strict=False)
        # catchErrors=True: failing sanitization steps are reported through the
        # return value (ignored here, best effort) instead of raising.
        Chem.SanitizeMol(mol, partial_flags, catchErrors=True)
        print('Partial sanitization: ' + smiles)
    except Exception:
        # Keep the result consistent with the message: never hand back a
        # half-processed molecule.
        print('Partial sanitization failed - return none: ' + smiles)
        return None

    return mol


def calculate_descriptors_rdkit(smiles, rdkit_desc='all'):
    """ Wrapper function that calculates all RDKit molecular descriptors listed in rdkit_desc for a single SMILES

    Inputs
    ----------
    smiles : str, mandatory
        The SMILES string. The sentinel value 'invalid SMILES' yields NaN for every descriptor.
    rdkit_desc: 'all' or list, default: 'all'
        The specified list of RDKit descriptor names. If 'all', all available RDKit descriptors are calculated

    Outputs
    ----------
    tuple of calculated RDKit descriptors, in the order of rdkit_desc.
    A tuple of NaN of the same length is returned when the SMILES is the
    'invalid SMILES' sentinel or cannot be converted into a mol object.
    """

    # isinstance guard: comparing an array-like descriptor list with a string
    # would be evaluated element-wise instead of yielding a single boolean.
    if isinstance(rdkit_desc, str) and rdkit_desc == 'all':
        # list all molecular descriptors in RDKit
        rdkit_desc = [x[0] for x in Chem.Descriptors._descList]

    missing = tuple(np.full(len(rdkit_desc), np.nan))

    if smiles == 'invalid SMILES':
        return missing

    mol = myMolFromSmiles(smiles)
    if mol is None:
        # Unparsable SMILES: report missing values, as for the sentinel above,
        # rather than passing None on to the descriptor calculator.
        return missing

    calculator = MoleculeDescriptors.MolecularDescriptorCalculator(rdkit_desc)
    return calculator.CalcDescriptors(mol)


def apply_descriptors_rdkit(series, rdkit_desc='all'):
    """ Wrapper function that calculates all RDKit molecular descriptors listed in rdkit_desc for a series of SMILES

    Inputs
    ----------
    series : pandas series, mandatory
        The series containing SMILES
    rdkit_desc: 'all' or list, default: 'all'
        The specified list of RDKit descriptor names. If 'all', all available RDKit descriptors are calculated

    Outputs
    ----------
    dataframe of calculated RDKit descriptors: one row per SMILES (in the
    order of the input series, with a fresh 0..n-1 index) and one column per
    descriptor name
    """

    # isinstance guard: comparing an array-like descriptor list with a string
    # would be evaluated element-wise instead of yielding a single boolean.
    if isinstance(rdkit_desc, str) and rdkit_desc == 'all':
        # list all molecular descriptors in RDKit
        rdkit_desc = [x[0] for x in Chem.Descriptors._descList]

    records = series.apply(calculate_descriptors_rdkit, rdkit_desc=rdkit_desc)

    # Hand from_records a plain list of tuples: a Series is indexed by label,
    # which is fragile when the input index is not 0..n-1 (e.g. after filtering).
    return pd.DataFrame.from_records(records.tolist(), columns=rdkit_desc)


# Functions to load and run pre-processing and prediction models
@lru_cache(maxsize=None)
def load_models(endpoint='nc', standardized=True):
    """ Load the pre-processing pipeline, descriptor list and prediction model for one endpoint

    Results are memoized per (endpoint, standardized) combination, so repeated
    calls return the very same dictionary object; callers should not mutate it.

    Inputs
    ----------
    endpoint : str, default: 'nc'
        Endpoint tag inserted into the model file names
    standardized: bool, default: True
        If True the 'final_*' files are loaded, otherwise the 'diverse_*' files

    Outputs
    ----------
    models : dict
        'pipe': unpickled pre-processing pipeline,
        'desc_list': list built from the first column of the descriptor csv,
        'main': unpickled prediction model
    """

    prefix = 'final' if standardized else 'diverse'
    model_dir = app_dir / 'final_models'

    pipe_name = f'{prefix}_model_pipe_rdkit_{endpoint}.pkl'
    desc_name = f'{prefix}_set_desc_rdkit_{endpoint}.csv'
    model_name = f'{prefix}_model_poduam_CI95_rdkit_{endpoint}.pkl'

    models = dict()

    # NOTE: pickle executes arbitrary code on load - only files shipped with
    # the application must ever be placed in this directory.
    # Context managers make sure the file handles are closed again.
    with open(model_dir / pipe_name, 'rb') as fh:
        models['pipe'] = pickle.load(fh)

    models['desc_list'] = pd.read_csv(model_dir / desc_name).iloc[:, 0].tolist()

    with open(model_dir / model_name, 'rb') as fh:
        models['main'] = pickle.load(fh)

    return models
    
# Function to make predictions
def make_prediction(models, descriptors, mw):
    """ Run the pre-processing pipeline and the prediction model on a descriptor table

    Inputs
    ----------
    models : dict, mandatory
        Dictionary with the keys 'pipe' (object with a transform method),
        'main' (object with a predict method) and 'desc_list' (column names
        kept as model input)
    descriptors : pandas dataframe, mandatory
        Descriptor table; its column names are re-attached to the transformed data
    mw : pandas dataframe, mandatory
        Table with a 'MolWt' column, combined index-wise with the predictions

    Outputs
    ----------
    dataframe with three columns (median, lower bound, upper bound), each
    computed as log10(10 ** prediction * MolWt * 1e3)
    """
    preprocessor = models['pipe']
    estimator = models['main']

    # Transform the descriptors, restore the column labels and keep only the
    # descriptors listed for the model
    features = pd.DataFrame(preprocessor.transform(descriptors))
    features.columns = descriptors.columns
    features = features[models['desc_list']]

    # The return value of predict is not used; the results are read from
    # attributes of the model object afterwards (presumably filled in by
    # predict - confirm against the model class).
    estimator.predict(features)

    def _rescale(log_values):
        # presumably a log10 mol -> log10 mg unit conversion via MolWt - TODO confirm
        return np.log10(10 ** pd.Series(log_values) * mw['MolWt'] * 1e3)

    columns = [
        _rescale(estimator.test_y_median_base),
        _rescale(estimator.test_y_lower_uacqrs),
        _rescale(estimator.test_y_upper_uacqrs),
    ]

    return pd.concat(columns, axis=1)


# Functions to canonicalize non-standardized SMILES
def mol_with_atom_index(mol):
    """ Support function that labels every atom of a mol object with its index + 1

    Inputs
    ----------
    mol : object, mandatory
        RDKit mol object created from SMILES or InChI

    Outputs
    ----------
    mol: object
        The same mol object (modified in place) with atom map numbers set

    """

    for current in mol.GetAtoms():
        # atom indices start at 0, while a map number of 0 means "no number"
        one_based = current.GetIdx() + 1
        current.SetAtomMapNum(one_based)

    return mol


def mol_without_atom_index(mol):
    """ Support function that resets the atom map number of every atom to 0

    Inputs
    ----------
    mol : object, mandatory
        RDKit mol object with numbered atoms

    Outputs
    ----------
    mol: object
        The same mol object (modified in place) without numbering

    """
    all_atoms = mol.GetAtoms()
    for current in all_atoms:
        current.SetAtomMapNum(0)

    return mol


def remove_chiral_centers(smiles):
    """ Support function to strip the chiral tags from a SMILES

    Inputs
    ----------
    smiles : str, mandatory
        SMILES string

    Outputs
    ----------
    res_smiles: str
        SMILES string regenerated after resetting the chiral tag of every
        center reported by Chem.FindMolChiralCenters. The input string is
        returned unchanged when no mol object can be created from it.

    """
    mol = myMolFromSmiles(smiles)
    if mol is None:
        # nothing to strip - hand the input back untouched
        return smiles

    # each entry starts with the atom index of the chiral center
    for center in Chem.FindMolChiralCenters(mol):
        atom_idx = center[0]
        mol.GetAtomWithIdx(atom_idx).SetChiralTag(Chem.ChiralType.CHI_UNSPECIFIED)

    return Chem.MolToSmiles(mol)


def remove_cis_trans(smiles):
    """ Support function to strip the bond stereo (cis/trans) labels from a SMILES

    Inputs
    ----------
    smiles : str, mandatory
        SMILES string

    Outputs
    ----------
    res_smiles: str
        SMILES string regenerated after setting every bond labelled E, Z,
        cis, trans or "any" to no stereo. The input string is returned
        unchanged when no mol object can be created from it.

    """
    mol = myMolFromSmiles(smiles)
    if mol is None:
        # nothing to strip - hand the input back untouched
        return smiles

    stereo = Chem.rdchem.BondStereo
    labelled = {stereo.STEREOE, stereo.STEREOZ,
                stereo.STEREOCIS, stereo.STEREOTRANS,
                stereo.STEREOANY}

    for bond in mol.GetBonds():
        if bond.GetStereo() in labelled:
            bond.SetStereo(stereo.STEREONONE)

    return Chem.MolToSmiles(mol)


def create_canonical_ordered_smiles(smiles, remove_numbers=True):
    """ Function creates canonicalized ordered SMILES with optional atom numbering
    by removing stereo information, creating mol objects from re-ordered SMILES
    strings, applying tautomerization and (optionally) adding atom numbering
    before converting back to SMILES

    Inputs
    ----------
    smiles : str, mandatory
        SMILES string
    remove_numbers: bool, default: True
        if True returns SMILES without explicit atom numbering,
        if False every atom carries its index + 1 as atom map number

    Outputs
    ----------
    canonical_order_smiles : str
        canonicalized ordered SMILES string, or the sentinel 'invalid SMILES'
        when no mol object can be created

    """

    # remove stereochemistry
    smiles = remove_chiral_centers(remove_cis_trans(smiles))

    # reorder SMILES: parse once, write the SMILES out and parse it again
    mol = myMolFromSmiles(smiles)
    new_mol = None if mol is None else myMolFromSmiles(Chem.MolToSmiles(mol))

    if new_mol is None:  # return the sentinel if still no mol
        print('No mol: ' + smiles)
        return 'invalid SMILES'

    # Tautomerize (best effort: on failure the untautomerized mol is kept)
    try:
        enumerator = rdMolStandardize.TautomerEnumerator()
        new_mol = enumerator.Canonicalize(new_mol)
    except Exception:
        print('No tautomerization:' + smiles)

    if remove_numbers:
        # clears any atom map numbers, including ones present in the input
        new_mol = mol_without_atom_index(new_mol)
    else:
        new_mol = mol_with_atom_index(new_mol)

    return Chem.MolToSmiles(new_mol)


def apply_canonical_ordered_smiles(series, remove_numbers=True):
    """ Function to create canonicalized ordered SMILES for all SMILES in a pandas series

    Inputs
    ----------
    series : pandas series, mandatory
        Series containing SMILES
    remove_numbers: bool, default: True
        if True returns SMILES without explicit atom numbering

    Outputs
    ----------
    Pandas series containing canonicalized ordered SMILES

    """

    def _canonicalize(single_smiles):
        # forward the numbering option to the single-SMILES function
        return create_canonical_ordered_smiles(single_smiles, remove_numbers=remove_numbers)

    return series.apply(_canonicalize)


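# Load data for predicted marketed chemicals
# NOTE: the commented-out snippet below appears to be the one-off CSV -> parquet
# conversion that produced the parquet files read by load_data(); kept for reference.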
# endpoint='rd'
# standardized=False

# df_name = f'out_final_model_market_{endpoint}.csv' if standardized else f'out_diverse_model_market_{endpoint}.csv'
# df = pd.read_csv(app_dir / 'data' / df_name)
# df.to_parquet(app_dir / 'data' / df_name.replace('.csv', '.parquet'))

@lru_cache(maxsize=None)
def load_data(endpoint='nc', standardized=True):
    """ Load and format the table of predictions for marketed chemicals

    Results are memoized per (endpoint, standardized) combination, so repeated
    calls return the very same dataframe object; callers that need to modify
    it should work on a copy.

    Inputs
    ----------
    endpoint : str, default: 'nc'
        Endpoint tag inserted into the parquet file name
    standardized: bool, default: True
        If True the 'out_final_model_market_*' file is read, otherwise the
        'out_diverse_model_market_*' file

    Outputs
    ----------
    dataframe with renamed display columns, two binned string columns used
    for coloring ('Toxicity [log10 mg/kg-d]', 'Uncertainty [95% CI width]')
    and the helper columns 'yhat_lo_mg', 'yhat_up_mg' and 'uhat'
    """
    prefix = 'out_final_model_market' if standardized else 'out_diverse_model_market'
    df_name = f'{prefix}_{endpoint}.parquet'

    extra_list = ['Canonical_QSARr', 'Salt_Solvent'] if standardized else ['canonical order SMILES']
    helper_list = ['yhat_lo_mg', 'yhat_up_mg', 'uhat']
    usecols = ['INCHIKEY', 'CASRN', 'PREFERRED_NAME', 'yhat_mg', 'dJ', 'TSNE1', 'TSNE2',
               'Kingdom', 'Superclass', 'Class', 'Subclass', 'SMILES']

    df = pd.read_parquet(app_dir / 'data' / df_name, columns=usecols + extra_list + helper_list)

    # Create bins for hue coloring
    bins_uhat = pd.IntervalIndex.from_tuples([(0, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 10)])
    bins_yhat = pd.IntervalIndex.from_tuples([(-10, 0), (0, 0.5), (0.5, 1), (1, 1.5), (1.5, 2), (2, 10)])

    # Binning has to happen before 'yhat_mg' is turned into a display string below
    df['Toxicity [log10 mg/kg-d]'] = pd.cut(df['yhat_mg'], bins_yhat).astype(str)
    df['Uncertainty [95% CI width]'] = pd.cut(df['uhat'], bins_uhat).astype(str)

    # Parameter processing and re-naming
    # - strip list-style brackets/quotes from the stored strings
    df['CASRN'] = df['CASRN'].str.strip("[]'").str.strip("[]").str.replace(r"'(.*?)'", r"\1", regex=True)
    df['PREFERRED_NAME'] = df['PREFERRED_NAME'].str.strip("[]'")

    # - "value [lower, upper]" display string; built with zip instead of a
    #   row-wise apply with row[0]-style integer keys, which pandas deprecated
    #   for label-indexed Series
    df['yhat_mg'] = [
        f"{mid:.2f} [{low:.2f}, {high:.2f}]"
        for mid, low, high in zip(df['yhat_mg'], df['yhat_lo_mg'], df['yhat_up_mg'])
    ]

    for column in ('dJ', 'TSNE1', 'TSNE2'):
        df[column] = df[column].round(2)

    usenames = ['InChIKey', 'CAS RN', 'Chemical name', 'Predicted POD [95% CI]',
                'Jaccard distance', 'TSNE1', 'TSNE2',
                'Kingdom', 'Superclass', 'Class', 'Subclass', 'SMILES (original)']
    extranames = ['SMILES (QSAR-ready)', 'stripped'] if standardized else ['SMILES (canonical)']

    # Helper columns keep their original names; only the display columns are renamed
    df.rename(columns=dict(zip(usecols + extra_list, usenames + extranames)), inplace=True)

    return df
