from narwhals import col
from shiny import reactive
from shiny.express import input, render, ui

from functools import partial
from shiny.ui import page_navbar

import shared
import numpy as np
import pandas as pd
import time
import io

from shinywidgets import render_widget
import plotly.express as px
import matplotlib.colors as mcolors
import seaborn as sns
import plotly.graph_objects as go
import matplotlib.pyplot as plt

# Page-level options for the Express app: a fillable page titled "PODUAM",
# built with `page_navbar` (bound to id="page") so that each `ui.nav_panel`
# below becomes a navbar tab.
ui.page_opts(title="PODUAM", fillable=True, page_fn=partial(page_navbar, id="page"))

with ui.nav_panel("Predict chemical PODs"):
    # Input card: SMILES type selector, SMILES text box and the "Predict" button.
    with ui.card(full_screen=True):
        ui.card_header("Provide SMILES to make predictions")

        # Three columns (4/6/2 of the 12-unit grid): model type, SMILES, button.
        with ui.layout_columns(col_widths=[4, 6, 2]):
            ui.input_radio_buttons(
                "model_type",
                "Choose SMILES type",
                ['standardized (OPERA QSAR ready)', 'non-standardized (preliminary results)'],
                selected=None,
                inline=True,
                width='100%',
            )

            ui.input_text_area("smiles", "Enter SMILES (one per line)", width='100%')

            # Top margin pushes the button down, next to the text area.
            with ui.div(style="margin-top: 35px;"):
                ui.input_action_button(
                    "action_button",
                    "Predict",
                    style=(
                        "width: 80%; height: 50px; font-size: 18px; text-align: center; "
                        "display: flex; justify-content: center; align-items: center;"
                    ),
                )
                        
    with ui.card(full_screen=True):

        ui.card_header("Results")

        # Reactive state shared by the outputs of this card.
        # - calculation_done: set to True at the end of `prediction_result`;
        #   `show_legend` only renders the legend while it is True.
        calculation_done = reactive.value(False)
        # - elapsed_time: status text written by `prediction_result` and
        #   displayed by `show_elapsed_time`.
        elapsed_time = reactive.value("")
        # - download_df: numeric results table written by `prediction_result`
        #   and served as CSV by `download_csv` (None until a run completes).
        download_df = reactive.value(None)
        
        # Run prediction model
        @render.data_frame
        @reactive.event(input.action_button)  # Run when the button is clicked
        def prediction_result():
            """Run the PODUAM models on the entered SMILES and render the results.

            Steps (mirrored by the progress bar):
              1. read and validate the SMILES entered by the user,
              2. load the rd/nc models matching the selected SMILES type,
              3. compute RDKit descriptors and molecular weight,
              4. predict PODrd and PODnc,
              5. derive ED10 values and build the display/download tables.

            Side effects: resets and sets `calculation_done`, updates
            `elapsed_time` and stores the numeric table in `download_df`.

            Returns:
                A `render.DataGrid` with one row per SMILES, or None (after
                showing an error notification) when the input is not usable.
            """
            # Extrapolation factor POD -> ED10 and its uncertainty (P97.5/P50).
            pod_to_ed10_factor = 3.49
            gsd2_ed10 = 2.67
            invalid_input_message = "Please enter a valid SMILES string."

            def _derive_ed10(prediction):
                """Derive log10 ED10 values from log10 POD predictions.

                `prediction` is indexed by column labels 0 (point estimate),
                1 (2.5% bound) and 2 (97.5% bound). The ED10 bounds combine the
                prediction uncertainty with `gsd2_ed10` in quadrature (log10 space).
                """
                ed10 = np.log10((10**prediction) / pod_to_ed10_factor)

                gsd2_pred_lo = (10**prediction[0]) / (10**prediction[1])
                gsd2_pred_up = (10**prediction[2]) / (10**prediction[0])

                gsd2_lo = 10**((np.log10(gsd2_pred_lo)**2 + np.log10(gsd2_ed10)**2)**0.5)
                gsd2_up = 10**((np.log10(gsd2_pred_up)**2 + np.log10(gsd2_ed10)**2)**0.5)

                ed10[1] = np.log10(10**ed10[0] / gsd2_lo)
                ed10[2] = np.log10(10**ed10[0] * gsd2_up)
                return ed10

            def _format_interval(estimates):
                """Format each row as 'estimate [lower, upper]' with two decimals."""
                return estimates.apply(
                    lambda row: f"{row[0]:.2f} [{row[1]:.2f}, {row[2]:.2f}]", axis=1
                )

            with ui.Progress(min=0, max=5) as p:  # Define the progress bar
                # Clear the state of any previous run so no stale legend or
                # timing is shown if this run stops early.
                calculation_done.set(False)
                elapsed_time.set("")
                start_time = time.time()
                p.set(message="Initializing...", value=1)

                # One SMILES per line; ignore surrounding whitespace and blank lines.
                entered = [line.strip() for line in input.smiles().splitlines()]
                orig_smiles = pd.Series(
                    [line for line in entered if line], name='SMILES', dtype=object
                )
                if orig_smiles.empty:
                    # A data_frame renderer cannot display a plain string, so
                    # report the problem as a notification and render nothing.
                    ui.notification_show(invalid_input_message, type="error")
                    return None

                # Match on the prefix so the check stays in sync with the radio
                # button label ('standardized (OPERA QSAR ready)').
                standardized = str(input.model_type()).startswith('standardized')

                # Conditional SMILES pre-processing and model loading
                p.set(message="Loading models...", value=2)

                if standardized:
                    smiles = orig_smiles
                else:
                    smiles = shared.apply_canonical_ordered_smiles(orig_smiles)
                    # NOTE(review): presumably an empty string marks a SMILES that
                    # could not be processed; only the first entry is checked.
                    if smiles[0] == '':
                        ui.notification_show(invalid_input_message, type="error")
                        return None

                models_rd = shared.load_models(endpoint='rd', standardized=standardized)
                models_nc = shared.load_models(endpoint='nc', standardized=standardized)

                # Calculate descriptors
                p.set(message="Calculating descriptors...", value=3)

                descriptors = shared.apply_descriptors_rdkit(smiles)
                mw = shared.apply_descriptors_rdkit(smiles, rdkit_desc=['MolWt'])

                # Make predictions
                p.set(message="Making predictions...this may take a while...", value=4)

                prediction_rd = shared.make_prediction(models_rd, descriptors, mw)
                prediction_nc = shared.make_prediction(models_nc, descriptors, mw)

                # Finalizing results
                p.set(message="Almost there...", value=5)

                # Calculate ED10 with adapted uncertainty intervals
                ed10_rd = _derive_ed10(prediction_rd)
                ed10_nc = _derive_ed10(prediction_nc)

                # - prepare df for download (full-precision numeric values)
                df_csv = pd.concat(
                    [orig_smiles, smiles,
                     pd.DataFrame(prediction_rd), pd.DataFrame(prediction_nc),
                     pd.DataFrame(ed10_rd), pd.DataFrame(ed10_nc)],
                    axis=1,
                )
                df_csv.columns = ['Original SMILES', 'Processed SMILES',
                                  'Predicted PODrd (log10 mg/kg-d)', 'PODrd (2.5%)', 'PODrd (97.5%)',
                                  'Predicted PODnc (log10 mg/kg-d)', 'PODnc (2.5%)', 'PODnc (97.5%)',
                                  'Derived ED10rd (log10 mg/kg-d)', 'ED10rd (2.5%)', 'ED10rd (97.5%)',
                                  'Derived ED10nc (log10 mg/kg-d)', 'ED10nc (2.5%)', 'ED10nc (97.5%)']

                download_df.set(df_csv)

                # - prepare df for display (one formatted string per estimate)
                df = pd.concat(
                    [orig_smiles, smiles,
                     _format_interval(prediction_rd), _format_interval(prediction_nc),
                     _format_interval(ed10_rd), _format_interval(ed10_nc)],
                    axis=1,
                )
                df.columns = ['Original SMILES', 'Processed SMILES',
                              'Predicted PODrd (log10 mg/kg-d)', 'Predicted PODnc (log10 mg/kg-d)',
                              'Derived ED10rd (log10 mg/kg-d)', 'Derived ED10nc (log10 mg/kg-d)']

                # Keep the two (potentially very long) SMILES columns scrollable.
                col_styles = [
                    {
                        "col": [0, 1],
                        "style": {
                            "min-width": "300px",
                            "min-height": "50px",
                            "max-width": "350px",
                            "max-height": "100px",
                            "white-space": "nowrap",
                            "overflow": "auto",
                        },
                    }
                ]

                end_time = time.time()
                elapsed_time.set(f"⏱️ Completed in {end_time - start_time:.2f} seconds.")
                calculation_done.set(True)

                return render.DataGrid(df, styles=col_styles)

        # Display a legend
        @render.ui
        def show_legend():
            """Render the abbreviations legend once a prediction run has completed.

            Returns an empty string (nothing displayed) while `calculation_done`
            is False, otherwise a centered div holding the legend HTML.
            """
            if calculation_done():
                legend_markup = "".join((
                    "<p style='font-style: italic; font-size: 11px; line-height: 1.2;'>",
                    "<span style='font-weight: bold;'>Legend:</span> ",
                    "POD: point of departure, rd: reproductive/developmental toxicity, nc: general non-cancer toxicity, ED10: eﬀect dose at which 10% of the population shows an eﬀect M, also called HD<sub>M</sub><sup>10%</sup>",
                    " \n <span style='font-weight:bold;'> Values in square brackets represent the 95% confidence interval</span> ",
                    "</p>",
                ))
                return ui.div(
                    ui.HTML(legend_markup),
                    style="margin-top: 1px; text-align: center;",
                )

            # Don't display anything if calculation is not done
            return ''
        
        # Display time elapsed to run the models
        @render.text
        def show_elapsed_time():
            """Show the timing message stored in the `elapsed_time` reactive value."""
            timing_message = elapsed_time()
            return timing_message
        
        # Download button to save results as CSV 
        @render.download(label="Download prediction results as CSV", filename="PODUAM_results.csv")
        def download_csv():
            """Serve the latest prediction results as a CSV download.

            Reads the table stored in `download_df` and writes it (without the
            index) into an in-memory text buffer.

            Returns:
                An `io.StringIO` positioned at the start of the CSV text. When no
                prediction has been run yet the buffer is empty, so the download
                is an empty file rather than a failed request (a bare None is
                not an iterable download body).
            """
            csv_io = io.StringIO()
            df_csv = download_df()

            if df_csv is not None:
                df_csv.to_csv(csv_io, index=False)
                csv_io.seek(0)  # rewind so the content is read from the start

            return csv_io
        
    # About
    # Static card with a markdown description of the app: the models used, how
    # ED10 values are derived (factor 3.49, uncertainty factor 2.67) and what
    # the second page offers.
    with ui.card():
        ui.card_header("About this app", class_="bg-light")
        
        ui.markdown(
            """ This app allows generating predictions using the PODUAM models by von Borries et al. (2025), 
            trained on the extensive probabilistic POD dataset from [Aurisano et al. (2023)](https://doi.org/10.1289/EHP11524). 
            The models estimate points of departure (PODs) for reproductive/developmental and general non-cancer human toxicity, 
            including 95% confidence intervals. POD predictions are additionally used to derive ED10 values, representing the dose 
            at which 10% of the population is affected. This is achieved by applying an extrapolation factor of 3.49 and adjusting 
            confidence intervals using an uncertainty factor (P97.5/P50 = 2.67) as outlined in the [WHO/IPCS framework](https://iris.who.int/handle/10665/259858). 
            Page 2 allows exploring PODUAM predictions for >130,000 globally marketed chemicals leveraging the chemical space map 
            by [von Borries et al. (2023)](https://doi.org/10.1021/acs.est.3c05300).
            """
        )

with ui.nav_panel("Explore global chemical market"):  
        
        with ui.layout_sidebar():
            # Sidebar controls: which dataset to load (endpoint + chemical set)
            # and which column colors the chemical space map.
            with ui.sidebar(title="Select data options", width=300):
                ui.input_radio_buttons(
                    'endpoint',
                    'Toxicity endpoint',
                    ['reproductive/developmental', 'general non-cancer'],
                    selected=None,
                    inline=False,
                    width=None,
                )
                ui.input_radio_buttons(
                    'chem_set',
                    'Chemical set',
                    ['standardized (QSAR ready)', 'non-standardized (preliminary result)'],
                    selected=None,
                    inline=False,
                    width=None,
                )

                ui.input_select(
                    "hue_column",
                    "Select a column to color by",
                    choices=[
                        'Toxicity [log10 mg/kg-d]',
                        'Uncertainty [95% CI width]',
                        'Superclass',
                        'Class',
                        'Subclass',
                    ],
                )
                
            
            with ui.card(full_screen=True):
                ui.card_header("Chemical space map")

                @render_widget
                def chemical_space_plot():
                    """Draw the t-SNE chemical space map colored by the selected column.

                    Reads the dataset from `loaded_data()` and the coloring column
                    from `input.hue_column()`. The toxicity/uncertainty columns get a
                    diverging palette over their sorted unique values; every other
                    column is treated as categorical, ordered by frequency and
                    colored with the (repeated) tab20 palette.

                    Returns:
                        A Plotly figure; an empty titled scatter when no data is
                        available.
                    """
                    cached_df = loaded_data()  # Use cached data
                    if cached_df.empty:
                        return px.scatter(title="No data available")

                    # Work on a copy: `loaded_data()` is a cached reactive value that
                    # is also used by the data table, so it must not be modified here.
                    df = cached_df.copy()
                    hue_column = input.hue_column()

                    # Set palette
                    if hue_column in ['Toxicity [log10 mg/kg-d]', 'Uncertainty [95% CI width]']:
                        # Missing values get no palette entry (and cannot be sorted).
                        hue_order = sorted(df[hue_column].dropna().unique())
                        if hue_column == 'Toxicity [log10 mg/kg-d]':
                            hue_order = hue_order[::-1]
                        colorscale = sns.diverging_palette(250, 10, s=100, l=50, sep=1,
                                                           n=len(hue_order), center="light")
                        colors = [mcolors.to_hex(c) for c in colorscale]
                        color_map = dict(zip(hue_order, colors))

                        df = df.fillna('not assigned')
                    else:
                        # Fill missing values BEFORE building the categories, so that
                        # 'not assigned' is a valid category; filling a Categorical
                        # with an unknown label raises.
                        df = df.fillna('not assigned')

                        hue_order = df[hue_column].value_counts().index.tolist()  # sort by frequency
                        df[hue_column] = pd.Categorical(df[hue_column], categories=hue_order, ordered=True)

                        colors = [mcolors.rgb2hex(c) for c in plt.cm.tab20.colors]
                        # - repeat colors if more than 20 categories
                        repeat_colors = (colors * (len(hue_order) // len(colors) + 1))[:len(hue_order)]
                        color_map = dict(zip(hue_order, repeat_colors))

                    # Scatterplot
                    fig = px.scatter(df, x="TSNE1", y="TSNE2", color=hue_column,
                                     color_discrete_map=color_map,
                                     category_orders={hue_column: hue_order},
                                     hover_name='Chemical name',
                                     hover_data=['CAS RN', 'Predicted POD [95% CI]', 'Superclass', 'Class', 'Subclass'],
                                     render_mode="webgl",
                                     height=700, width=1200,
                                     )
                    # Small, translucent data markers; their own legend entries are
                    # hidden because the legend is built from the traces added below.
                    fig.update_traces(showlegend=False,
                                      marker=dict(size=3, opacity=0.3, line=dict(width=0)))

                    # Add custom legend traces to control the opacity of the legend markers
                    for category in hue_order:
                        fig.add_trace(go.Scatter(
                            x=[None], y=[None],  # Invisible points
                            mode='markers',
                            marker=dict(color=color_map[category], size=12, opacity=1),  # Full opacity for the legend
                            legendgroup=category,
                            showlegend=True,
                            name=category,
                        ))

                    fig.update_layout(
                        dragmode="zoom",  # - enables zooming
                        uirevision=True,  # Track the zoom and pan state
                        xaxis=dict(title="TSNE1", visible=False, showgrid=False, zeroline=False,
                                   fixedrange=False),
                        yaxis=dict(title="TSNE2", visible=False, showgrid=False, zeroline=False,
                                   fixedrange=False),
                        modebar=dict(add=["zoom", "pan", "resetScale"]),
                        plot_bgcolor="rgba(0,0,0,0)",
                        paper_bgcolor="rgba(0,0,0,0)",
                        legend_title_text=hue_column,
                        legend=dict(itemsizing='constant',  # Control legend item sizing
                                    itemwidth=30,  # Adjust this value to control the item width
                                    tracegroupgap=0,  # Gap between legend items
                                    x=1.0,  # Position the legend towards the right
                                    y=0.5,  # Center vertically
                                    xanchor='left',
                                    yanchor='middle',
                                    orientation='v',  # Arrange legend vertically
                                    font=dict(family="Arial", size=10)))

                    return fig


            with ui.card(full_screen=True):  # bottom: dataframe
                ui.card_header("Browse data table")

                # Render dataframe with filtering enabled
                @render.data_frame
                def display_data():
                    """Show the currently loaded dataset as a filterable data grid.

                    Reads the dataset from `loaded_data()`. The last two columns are
                    not displayed (presumably the t-SNE coordinates used only by the
                    plot; verify against `shared.load_data`).

                    Returns:
                        A `render.DataGrid`, or None (after showing a warning
                        notification) when the dataset is empty.
                    """
                    df = loaded_data()

                    if df.empty:
                        # A data_frame renderer cannot display a plain string, so
                        # surface the message as a notification and render nothing.
                        ui.notification_show("No data available. Please select options.", type="warning")
                        return None

                    # Same cell style for all columns: bounded size, scroll on overflow.
                    styles = [
                        {
                            "style": {
                                "min-width": "50px",
                                "min-height": "50px",
                                "max-width": "200px",
                                "max-height": "100px",
                                "white-space": "nowrap",
                                "overflow": "auto",
                            }
                        }
                    ]

                    return render.DataGrid(df.iloc[:, :-2], styles=styles, height='300px', filters=True)

                
# REACTIVE CODE
# - Reactive Cached Dataset (Loads Only When User Inputs Change)
@reactive.calc
def loaded_data():
    """Load the dataset matching the sidebar selection.

    Maps the 'endpoint' radio buttons to the endpoint code ("rd" for
    reproductive/developmental, otherwise "nc") and the 'chem_set' radio
    buttons to a boolean `standardized` flag, then delegates to
    `shared.load_data`. As a reactive calc, the result is cached until one of
    the two inputs changes.
    """
    is_repro_dev = input.endpoint() == "reproductive/developmental"
    use_standardized = input.chem_set() == "standardized (QSAR ready)"

    return shared.load_data(
        endpoint="rd" if is_repro_dev else "nc",
        standardized=use_standardized,
    )

