import os
import glob
import shutil
import subprocess
import numpy as np
import pandas as pd
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
# from ase.io import read, write
import spglib

def print_red(message):
    """Print ``message`` to stdout wrapped in the ANSI bright-red colour code.

    The colour is switched on with SGR 91 before the text and reset with
    SGR 0 afterwards, so later output is not affected.
    """
    colour_on, colour_off = "\033[91m", "\033[0m"
    coloured = f"{colour_on}{message}{colour_off}"
    print(coloured)

def print_green(message):
    """Print ``message`` to stdout wrapped in the ANSI bright-green colour code.

    The colour is switched on with SGR 92 before the text and reset with
    SGR 0 afterwards, so later output is not affected.
    """
    colour_on, colour_off = "\033[92m", "\033[0m"
    coloured = f"{colour_on}{message}{colour_off}"
    print(coloured)

def print_blue(message):
    """Print ``message`` to stdout wrapped in the ANSI bright-blue colour code.

    The previous implementation emitted SGR 93, which is bright *yellow*;
    SGR 94 is the bright-blue foreground code that matches this function's
    name. The colour is reset with SGR 0 after the text.
    """
    blue = "\033[94m"
    reset = "\033[0m"
    print(f"{blue}{message}{reset}")


def _run_logged(command, log_path):
    """Run ``command`` (an argument list, no shell) and capture its output.

    stdout and stderr are both written to ``log_path``, which is truncated
    first. Returns the ``subprocess.CompletedProcess``; a non-zero exit code
    does not raise, but a missing/non-executable binary raises ``OSError``.
    """
    with open(log_path, 'w') as log_file:
        return subprocess.run(command, stdout=log_file, stderr=subprocess.STDOUT)


def descriptor_generator(cwd, base_dir, zeo_path, name, structure_path, wiggle_room,
                         index=None, total=None):
    """Build the merged Zeo++ + RAC feature dictionary for a single MOF.

    Steps: (re)create the per-MOF working folders under
    ``{cwd}/feature_folders/{name}``, write a primitive-cell CIF, run the
    three Zeo++ ``network`` jobs and ``RAC_getter.py`` on it, parse their
    outputs, merge them and write ``{name}_features.csv`` in the MOF folder.

    Args:
        cwd: Directory under which ``feature_folders/`` is created.
        base_dir: Currently unused; kept so existing callers keep working.
        zeo_path: Directory containing the Zeo++ ``network`` executable.
        name: MOF identifier, used for folder and file names.
        structure_path: Path to the input CIF file.
        wiggle_room: Forwarded unchanged to ``RAC_getter.py``.
        index: Optional zero-based position of this MOF in the batch, used
            only for the progress banner.
        total: Optional batch size, used only for the progress banner.

    Returns:
        dict of merged features (RAC keys override Zeo++ keys on collision),
        or ``None`` if any stage failed and the MOF should be skipped.
    """
    # Progress banner. The original code read the driver loop's module-level
    # ``i`` / ``cif_paths`` / ``MOF_name`` directly, which raised NameError as
    # soon as this function was called from anywhere else. Explicit arguments
    # take precedence; the module globals remain only as a fallback so the
    # script prints the same banner as before.
    if index is None:
        index = globals().get('i')
    if total is None:
        known_paths = globals().get('cif_paths')
        total = len(known_paths) if known_paths is not None else None
    if index is not None and total is not None:
        print_blue(f"\nProcessing MOF {index + 1}/{total}: {name}")
    else:
        print_blue(f"\nProcessing MOF: {name}")

    # Per-MOF working folders (wiped and recreated on every run).
    current_MOF_folder = f'{cwd}/feature_folders/{name}'
    cif_folder = f'{current_MOF_folder}/cifs'
    RACs_folder = f'{current_MOF_folder}/RACs'
    zeo_folder = f'{current_MOF_folder}/zeo++'
    merged_descriptors_folder = f'{current_MOF_folder}/merged_descriptors'
    delete_and_remake_folders([current_MOF_folder, cif_folder, RACs_folder, zeo_folder, merged_descriptors_folder])

    # All downstream tools operate on the primitive-cell CIF.
    primitive_path = f'{cif_folder}/{name}_primitive.cif'
    if not get_primitive(structure_path, primitive_path):
        print(f"  Skipping MOF {name} due to get_primitive error.")
        return None
    structure_path = primitive_path

    # Zeo++ jobs: pore diameters (-res), surface area (-sa), pore volume (-volpo).
    # Argument lists (no shell) keep file names containing spaces or shell
    # metacharacters from breaking or injecting into the command line.
    network = f'{zeo_path}/network'
    pd_out = f'{zeo_folder}/{name}_pd.txt'
    sa_out = f'{zeo_folder}/{name}_sa.txt'
    pov_out = f'{zeo_folder}/{name}_pov.txt'
    zeo_jobs = [
        ([network, '-ha', '-res', pd_out, structure_path],
         f'{zeo_folder}/{name}_pd_error.log'),
        ([network, '-sa', '1.4', '1.4', '10000', sa_out, structure_path],
         f'{zeo_folder}/{name}_sa_error.log'),
        ([network, '-volpo', '1.4', '1.4', '10000', pov_out, structure_path],
         f'{zeo_folder}/{name}_pov_error.log'),
    ]
    try:
        for command, log_path in zeo_jobs:
            _run_logged(command, log_path)
    except OSError as e:
        print_red(f"  Failed to run Zeo++ {name}. Error: {e}")
        return None

    # RAC generation; RAC_getter.py is resolved relative to the process's
    # working directory and its output goes to the inherited stdout/stderr.
    rac_command = ['python3', 'RAC_getter.py', structure_path, name, RACs_folder, str(wiggle_room)]
    try:
        subprocess.run(rac_command)
    except OSError as e:
        print_red(f"  Failed to run RAC commands for {name}. Error: {e}")
        return None

    # Zeo++ check for failure: exit codes are not inspected, so the presence
    # of all three output files is the success criterion.
    if not all(os.path.exists(path) for path in (pd_out, sa_out, pov_out)):
        print_red(f"  Zeo++ files missing for {name}. Skipping this MOF.")
        return None
    # Extract geometry
    geo_dict = extract_zeopp_features(name, zeo_folder)
    if geo_dict is None:
        print_red(f"  Skipping MOF {name} due to missing Zeo++ data.")
        return None

    # MolSimplify check for failure: a log containing 'FAILED' aborts this MOF.
    # A missing log is not treated as a failure here; the CSV read below decides.
    rac_log = f'{RACs_folder}/RAC_getter_log.txt'
    if os.path.exists(rac_log):
        with open(rac_log, 'r') as f:
            if 'FAILED' in f.read():
                print_red(f"  RAC generation failed for {name}. Skipping MOF.")
                return None
    # Extract RACs
    rac_features = extract_rac_features(RACs_folder)
    if rac_features is None:
        print_red(f"  -Skipping MOF {name} due to missing RAC data.")
        return None

    # Merge the features (RAC values win on duplicate keys).
    merged_features = {**geo_dict, **rac_features}

    # Temporary output for each MOF
    pd.DataFrame([merged_features]).to_csv(f'{current_MOF_folder}/{name}_features.csv', index=False)
    print_green(f"  Temporary features file created for {name}.")

    return merged_features


def delete_and_remake_folders(folder_names):
    """Reset every directory in ``folder_names`` to an empty state.

    Directories are handled in the given order: an existing directory is
    removed together with its contents, then the path is (re)created,
    including any missing parents, and a confirmation line is printed.
    """
    for target in folder_names:
        directory_present = os.path.isdir(target)
        if directory_present:
            # Wipe stale results from a previous run.
            shutil.rmtree(target)
        os.makedirs(target, exist_ok=True)
        print(f"  -Directory created: {target}")

def get_primitive(structure_path, output_path):
    """Write a primitive-cell version of a CIF file to ``output_path``.

    The structure is loaded with pymatgen and passed to
    ``SpacegroupAnalyzer`` (``symprec=1e-3``). If ``find_primitive()``
    yields a truthy result it is written out; otherwise the structure as
    read is written instead, so an output file exists in both cases.

    Returns:
        True when a file was written, False if any step raised (the error
        is reported via ``print_red``).
    """
    try:
        print(f"  -Reading CIF file: {structure_path}")
        loaded = Structure.from_file(structure_path)

        print("  -Attempting to find primitive cell using SpacegroupAnalyzer...")
        analyzer = SpacegroupAnalyzer(loaded, symprec=1e-3)
        reduced = analyzer.find_primitive()

        if not reduced:
            # Fall back to the input cell so downstream tools still get a CIF.
            print_red("  -Warning: Could not find a primitive cell. Using original structure.")
            loaded.to(filename=output_path)
        else:
            print_green("  -Primitive cell found, saving structure.")
            reduced.to(filename=output_path)

        print_green(f"  -Primitive CIF saved to: {output_path}")
    except Exception as e:
        print_red(f"  -Error in get_primitive for {structure_path}: {e}")
        return False
    return True

def _value_after(row, label):
    """Return the float token that immediately follows ``label`` in ``row``.

    Matching is by substring and the first occurrence wins, so a label such
    as 'ASA_m^2/g:' would also match inside 'NASA_m^2/g:'. This relies on the
    ASA fields preceding the NASA ones in the line (TODO confirm against the
    Zeo++ version in use). Raises IndexError/ValueError on malformed input.
    """
    return float(row.split(label)[1].split()[0])


def extract_zeopp_features(name, zeo_folder):
    """Parse the Zeo++ output files of one MOF into a flat feature dict.

    Reads ``{name}_pd.txt``, ``{name}_sa.txt`` and ``{name}_pov.txt`` from
    ``zeo_folder``. A file that is missing, empty or malformed is reported
    with ``print_red`` and its features stay ``NaN``; values from one file are
    assigned all-or-nothing. The result is also written as a single-row CSV
    to ``{zeo_folder}/geometric_parameters.csv``.

    Args:
        name: MOF identifier used as the file-name prefix.
        zeo_folder: Directory holding the Zeo++ outputs.

    Returns:
        dict with keys ``name``, ``cif_file``, ``Di``, ``Df``, ``Dif``,
        ``unit_cell_volume``, ``rho``, ``VSA``, ``GSA``, ``VPOV``, ``GPOV``,
        ``POAV_vol_frac``, ``PONAV_vol_frac``, ``GPOAV``, ``GPONAV``,
        ``POAV``, ``PONAV``.
    """
    cif_file = name + '.cif'
    largest_included_sphere = largest_free_sphere = np.nan
    largest_included_sphere_along_free_sphere_path = np.nan
    unit_cell_volume = crystal_density = VSA = GSA = np.nan
    VPOV = GPOV = np.nan
    POAV = PONAV = GPOAV = GPONAV = np.nan
    POAV_volume_fraction = PONAV_volume_fraction = np.nan

    # --- Pore diameters: whitespace-separated columns, values in columns 1-3.
    pd_path = f'{zeo_folder}/{name}_pd.txt'
    if os.path.exists(pd_path):
        with open(pd_path) as f:
            # Blank lines (e.g. a trailing newline) used to raise IndexError.
            rows = [row for row in f if row.strip()]
        if not rows:
            print_red(f'  PD file is empty for {name}!')
        try:
            for row in rows:  # as before, the last data row wins
                columns = row.split()
                (largest_included_sphere,
                 largest_free_sphere,
                 largest_included_sphere_along_free_sphere_path) = (
                    float(columns[1]), float(columns[2]), float(columns[3]))
        except (IndexError, ValueError) as e:
            print_red(f'  Could not parse PD file for {name}: {e}')
    else:
        print_red(f'  PD file does not exist for {name}!')

    # --- Surface area: labelled values on the first line only.
    sa_path = f'{zeo_folder}/{name}_sa.txt'
    if os.path.exists(sa_path):
        with open(sa_path) as f:
            first_row = f.readline()
        try:
            unit_cell_volume, crystal_density, VSA, GSA = (
                _value_after(first_row, 'Unitcell_volume:'),
                _value_after(first_row, 'Density:'),
                _value_after(first_row, 'ASA_m^2/cm^3:'),  # volumetric surface area
                _value_after(first_row, 'ASA_m^2/g:'),     # gravimetric surface area
            )
        except (IndexError, ValueError) as e:
            print_red(f'  Could not parse SA file for {name}: {e}')
    else:
        print_red(f'  SA file does not exist for {name}!')

    # --- Pore volume: labelled values on the first line only.
    pov_path = f'{zeo_folder}/{name}_pov.txt'
    if os.path.exists(pov_path):
        with open(pov_path) as f:
            first_row = f.readline()
        try:
            density = _value_after(first_row, 'Density:')
            parsed = (
                _value_after(first_row, 'POAV_A^3:'),    # probe-occupiable accessible volume
                _value_after(first_row, 'PONAV_A^3:'),   # probe-occupiable non-accessible volume
                _value_after(first_row, 'POAV_cm^3/g:'),
                _value_after(first_row, 'PONAV_cm^3/g:'),
                _value_after(first_row, 'POAV_Volume_fraction:'),
                _value_after(first_row, 'PONAV_Volume_fraction:'),
            )
        except (IndexError, ValueError) as e:
            print_red(f'  Could not parse POV file for {name}: {e}')
        else:
            (POAV, PONAV, GPOAV, GPONAV,
             POAV_volume_fraction, PONAV_volume_fraction) = parsed
            VPOV = POAV_volume_fraction + PONAV_volume_fraction
            # Guard against a zero density instead of raising ZeroDivisionError.
            GPOV = VPOV / density if density else np.nan
    else:
        print_red(f'  POV file does not exist for {name}!')

    geo_dict = {'name': name, 'cif_file': cif_file, 'Di': largest_included_sphere, 'Df': largest_free_sphere, 'Dif': largest_included_sphere_along_free_sphere_path,
                'unit_cell_volume': unit_cell_volume, 'rho': crystal_density, 'VSA': VSA, 'GSA': GSA, 'VPOV': VPOV, 'GPOV': GPOV, 'POAV_vol_frac': POAV_volume_fraction,
                'PONAV_vol_frac': PONAV_volume_fraction, 'GPOAV': GPOAV, 'GPONAV': GPONAV, 'POAV': POAV, 'PONAV': PONAV}
    pd.DataFrame([geo_dict]).to_csv(f'{zeo_folder}/geometric_parameters.csv', index=False)
    return geo_dict

def extract_rac_features(RACs_folder):
    """Load the RAC descriptors written to ``RACs_folder``.

    Reads ``featurization_frame.csv`` from the folder and returns its first
    row as a ``{column: value}`` dict. Any failure (missing file, unreadable
    CSV, no data rows) is reported with ``print_red`` and yields ``None``.
    """
    try:
        frame = pd.read_csv(f'{RACs_folder}/featurization_frame.csv')
        records = frame.to_dict(orient='records')
        first_record = records[0]
    except Exception as e:
        print_red(f"  -Error reading RAC descriptors: {e}")
        return None
    return first_record
    


### Main Execution ###

# Driver: featurise every CIF found in ``base_dir`` and collect the results
# into ``test-MOFs.csv`` in the current working directory.
#
# NOTE: ``i``, ``cif_paths`` and ``MOF_name`` are deliberately kept as
# module-level names: descriptor_generator's progress banner may look them
# up in module scope.
cwd = os.getcwd()
base_dir = "/scratch/ml/coremof/"
zeo_path = "/scratch/ml/zeo/"

# exist_ok=True already makes this a no-op when the folder is present.
os.makedirs(f'{cwd}/feature_folders', exist_ok=True)

cif_paths = sorted(glob.glob(f'{base_dir}/*.cif'))
print_blue(f"\n\nFound {len(cif_paths)} CIF files to process.")

final_df_content = []
skipped_mofs = []
wiggle_room = 1  # forwarded to RAC_getter.py for every MOF

for i, cp in enumerate(cif_paths):
    # splitext removes only the trailing extension; the previous
    # str.replace('.cif', '') also stripped '.cif' occurring inside the stem.
    MOF_name = os.path.splitext(os.path.basename(cp))[0]

    mof_features = descriptor_generator(cwd, base_dir, zeo_path, MOF_name, cp, wiggle_room)

    if mof_features:
        final_df_content.append(mof_features)
    else:
        skipped_mofs.append(MOF_name)

final_df = pd.DataFrame(final_df_content)
if not final_df.empty:
    final_output_path = f'{cwd}/test-MOFs.csv'
    final_df = final_df.sort_values(by=['name'])
    final_df.to_csv(final_output_path, index=False)
    print_green(f"\n\nSUCCESS! Processing completed. Output saved to {final_output_path}")
else:
    print_red("\n\n TOTAL FAILURE! No MOFs were successfully processed. No output file generated.")

# Report skipped structures in both outcomes; previously the list was
# silently dropped when every MOF failed.
if skipped_mofs:
    print("\nNote: The following MOFs were skipped due to errors or missing data:")
    for mof in skipped_mofs:
        print(f"- {mof}")


