import glob
import json
import os
import shutil
import warnings
import logging
import copy
import numpy as np
import pandas as pd
import scanpy as sc
from scipy.sparse import load_npz
from sklearn.preprocessing import MinMaxScaler

from thor.pp import WholeSlideImage, Spatial
from thor.markov_graph_diffusion import (
    estimate_expression_markov_graph_diffusion,
    markov_graph_diffusion_initialize
)
from thor.plotting.graph import plot_cell_graph
from thor.utils import (
    generate_cell_adata, get_adata_layer_array, get_spot_heterogeneity_cv,
    var_cos
)
from thor.VAE import VAE, IdentityGenerator, train_vae

logger = logging.getLogger(__name__)

# NOTE(review): the next two statements change process-wide state as soon as
# this module is imported: every Python warning is silenced and scanpy's
# verbosity is lowered to "error".
warnings.filterwarnings("ignore")
sc.settings.verbosity = "error"


MANY_GENES = 200

# Default settings of the Markov graph diffusion run
# (see `fineST.set_params` for the meaning of the most important keys).
default_run_params = {
    "initialize": True,
    "burn_in_steps": 5,
    "layer": None,
    "is_rawCount": False,
    "regulate_expression_mean": False,
    "stochastic_expression_neighbors_level": "spot",
    "n_iter": 20,
    "conn_key": "snn",
    "write_freq": 10,
    "out_prefix": "y",
    "sample_predicted_expression_fluctuation_scale": 1,
    "smooth_predicted_expression_steps": 0,
    "save_chain": False,
    "n_jobs": 1,
}

# Default settings of the cell-cell graph construction and of the transition
# matrix estimation (see `fineST.set_params`).
default_graph_params = {
    "n_neighbors": 5,
    "conn_key": "snn",
    "obs_keys": None,
    "reduced_dimension_transcriptome_obsm_key": "X_pca",
    "reduced_dimension_transcriptome_obsm_dims": 2,
    "geom_morph_ratio": 1,
    "geom_constraint": 0,
    "adjust_cell_network_by_transcriptome_scale": 0,
    "snn_threshold": 0.1,
    "smoothing_scale": 0.8,
    "conn_csr_matrix": None,
    "inflation_percentage": None,
    "node_features_obs_list": ["spot_heterogeneity"],
    "preferential_flow": True,
    "weigh_cells": True,
    "balance_cell_quality": False,
    "bcq_IQR": (0.15, 0.85),
}


class fineST:
    """Class for in silico cell gene expression inference

    Parameters
    ----------
    image_path : str
        Path to the HE staining image or an image of other types which is aligned to the spatial transcriptome spots (full resolution).
    name : str
        Name of the sample.
    spot_adata_path : str, optional
        Path to the processed spot adata (e.g., from the Visium sequencing data). 
        The counts/expression array (.X) and spots coordinates are required (.obsm["spatial"]). Expecting that adata.X is lognormalized.
        Either `spot_adata_path` or `st_dir` are needed. If `spot_adata_path` is provided, `st_dir` will be neglected.
    st_dir : str, optional
        Directory to the SpaceRanger output directory, where the count matrix and spatial directory can be found.
    cell_features_csv_path : str, optional
        Path to the CSV file that stored the cell features. First two columns are expected (exactly) to be the nuclei positions "x" and "y".
    cell_features_list : list or None, optional
        List of features to be used for generating the cell-cell graph. First two are expected (exactly) to be the nuclei positions "x" and "y". Features will be read from `cell_features_csv_path` csv file and the list will be used for selection. By default, if no external features are provided, those features 
        ["x", "y", "mean_gray", "std_gray", "entropy_img", "mean_r", "mean_g", "mean_b", "std_r", "std_g", "std_b"] are used.
    genes_path : str, optional
        Path to the file that contains a headless one column of the genes (same format as used in the `adata.var_names`) to be included for sure.
    save_dir : str or None, optional
        Path to the directory of saving fineST prediction results.
    recipe : str, optional
        Specifies the mode for predicting the gene expression. Valid choices are: "gene", "reduced", "mix".
            - "gene": use the user-set genes for prediction with *gene mode*. This includes the genes from the the `used_for_prediction` key in adata.var.
            - "reduced": use the reduced genes from the VAE model with *reduced mode*. Ignoring the genes from the `used_for_prediction` key in adata.var.
            - "mix": use the reduced genes from the VAE model using *reduced mode* and the rest of the user-set genes for prediction with *gene mode*.
    **kwargs : dict, optional
        Keyword arguments for any additional attributes to be set for the class. This allows future loading of the saved json file to create a new instance of the class.
    """

    def __init__(
        self,
        image_path,
        name,
        spot_adata_path=None,
        st_dir=None,
        cell_features_list=None,
        cell_features_csv_path=None,
        genes_path=None,
        save_dir=None,
        recipe="gene",
        **kwargs
    ):
        """Create a fineST instance; see the class docstring for the parameters.

        The save directory (default: ``./fineST_{name}``) is created if needed.
        When only `st_dir` is given, the SpaceRanger output is processed with
        :py:class:`Spatial`, an error asking for `spot_adata_path` is logged and
        the constructor returns early, leaving the instance only partially
        initialized.
        """
        self.name = name
        if save_dir is None:
            save_dir = os.path.join(os.getcwd(), f"fineST_{name}")

        save_dir = os.path.abspath(save_dir)
        image_path = os.path.abspath(image_path)

        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.image_path = image_path

        assert (spot_adata_path is not None) or (
            st_dir is not None
        ), "Spot-level transcriptome data are required!"

        if spot_adata_path is not None:
            self.spot_adata_path = os.path.abspath(spot_adata_path)
        else:
            self.st_dir = os.path.abspath(st_dir)
            st = Spatial(self.name, self.st_dir, image_path=self.image_path, save_dir=self.save_dir)
            st.process_transcriptome()
            logger.error("Please provide `spot_adata_path`")
            return

        self.set_cell_features_csv_path(cell_features_csv_path)
        self.set_cell_features_list(cell_features_list)

        self.genes = []
        if genes_path is not None:
            self.load_genes(genes_path)

        self.recipe = recipe
        # Private copies: set_params() updates these dicts in place, so sharing
        # the module-level defaults would leak one instance's settings into
        # every other instance (and into the defaults themselves).
        self.graph_params = copy.deepcopy(default_graph_params)
        self.run_params = copy.deepcopy(default_run_params)
        # Extra attributes (e.g. restored from a saved json) override the above.
        self.__dict__.update(kwargs)

    def prepare_input(self, mapping_margin=10, spot_identifier="spot_barcodes"):
        """ Prepare the cell-level input for the fineST estimation.

            If `self.adata` does not exist yet, it is built with :py:func:`generate_cell_adata` from the
            segmented cells in `self.cell_features_csv_path` and the spot data in `self.spot_adata_path`,
            keeping the features listed in `self.cell_features_list` (per the original description, this
            step drops segmentation outliers and maps spot expression onto nearby cells). Afterwards the
            "spot_heterogeneity" column of `adata.obs` is (re)computed from the image features with
            :py:func:`get_spot_heterogeneity_cv`; it is used later for building the cell-cell graph and the
            transition matrix.

        Parameters
        ----------
        mapping_margin : numeric, optional
            Margin for mapping the spot gene expression to the cells. Default is 10, which will attempt to map
            cells within a 10-spot radius of any spot (so almost all identified cells are mapped to nearest spots).
            Decrease this number if you would like to eliminate isolated cells.
        spot_identifier : str, optional
            Name forwarded to :py:func:`get_spot_heterogeneity_cv`; presumably the `adata.obs` column that
            identifies the spot of each cell. Default is "spot_barcodes".
        """
        if not hasattr(self, "adata"):
            self.adata = generate_cell_adata(
                self.cell_features_csv_path,
                self.spot_adata_path,
                obs_features=self.cell_features_list,
                mapping_margin=mapping_margin,
            )

        # The first two features are the 2D cell positions; only the remaining
        # image features enter the heterogeneity score.
        image_features = self.cell_features_list[2:]
        heterogeneity = get_spot_heterogeneity_cv(
            self.adata.obs, image_features, spot_identifier
        )
        self.adata.obs.loc[:, "spot_heterogeneity"] = heterogeneity

    def vae_training(
            self, vae_genes_set=None, min_mean_expression=0.1, **kwargs
    ):
        """ Train a VAE model for the spot-level transcriptome data.

        Parameters
        ----------
        vae_genes_set : set or iterable of str, optional
            Genes to be used for VAE training; any iterable of gene names is accepted. If None, all the genes
            flagged in `adata.var.used_for_prediction` with `adata.var.means` > `min_mean_expression` are used.
            Every gene must also be flagged in `adata.var.used_for_prediction`.
        min_mean_expression : float, optional
            Minimum mean expression for the genes to be used for VAE training (only used when `vae_genes_set` is None).
        kwargs : dict
            Keyword arguments for the :py:func:`VAE.train_vae` function. `save_prefix` is always overridden with
            `{save_dir}/VAE_models/{name}`.

        Returns
        -------
        None
            Sets `adata.var["used_for_vae"]` and `self.model_path`, and writes the sorted gene list to
            `{save_dir}/VAE_models/vae_genes_{name}.csv`.

        Notes
        -----
        This function internally calls the :py:func:`VAE.train_vae` function with the min-max scaled expression
        of the selected genes.
        """

        if vae_genes_set is None:
            vae_genes_set = set(
                self.adata.var_names[np.logical_and(
                    self.adata.var.used_for_prediction, self.adata.var.means
                    > min_mean_expression
                )]
            )
        else:
            # Accept lists, tuples, pandas Index, ... as well as sets.
            vae_genes_set = set(vae_genes_set)

        assert vae_genes_set.issubset(
            set(self.adata.var_names[self.adata.var.used_for_prediction])
        ), "All VAE genes must be flagged in `adata.var.used_for_prediction`."

        self.adata.var["used_for_vae"] = self.adata.var_names.isin(
            vae_genes_set
        )
        ad = self.adata[:, self.adata.var.used_for_vae]
        # `.X` may be a scipy sparse matrix or an already dense array.
        X = ad.X.toarray() if hasattr(ad.X, "toarray") else np.asarray(ad.X)
        X_norm = MinMaxScaler().fit_transform(X)

        model_path = os.path.join(self.save_dir, "VAE_models")
        kwargs["save_prefix"] = os.path.join(model_path, self.name)

        train_vae(X_norm, **kwargs)
        self.model_path = model_path

        # Make sure the folder exists before writing the gene list next to the models.
        os.makedirs(model_path, exist_ok=True)
        with open(
            os.path.join(model_path, f"vae_genes_{self.name}.csv"), "w"
        ) as tfile:
            # Sorted for a reproducible file (set iteration order is arbitrary).
            tfile.write("\n".join(sorted(vae_genes_set)))

    def load_genes(self, genes_file_path):
        """ Load the user-input genes to be used for prediction.

        Parameters
        ----------
        genes_file_path: :py:class:`str`
            Path to a header-less csv file whose first column lists the genes to be used for prediction.
            Genes should match the var_names used in adata.

        Returns
        -------
        None
            Update the `genes` attribute of the class.

        Notes
        -----
        Entries are read verbatim as strings: numeric-looking names are not converted to numbers and
        names such as "NA" or "None" are not turned into NaN, so they can be matched against
        `adata.var_names`.
        """
        genes_df = pd.read_csv(
            genes_file_path, header=None, dtype=str, keep_default_na=False
        )
        self.genes = genes_df.iloc[:, 0].tolist()

    def set_cell_features_csv_path(self, cell_features_csv_path=None):
        """ Register the csv file that holds the per-cell features.

        When `cell_features_csv_path` is given, its absolute path is stored. Otherwise the HE image at
        `self.image_path` is processed with :py:class:`WholeSlideImage` and the csv path it reports is
        stored instead. Either way the result ends up in `self.cell_features_csv_path`.
        """
        if cell_features_csv_path is None:
            slide = WholeSlideImage(self.image_path, name=self.name, save_dir=self.save_dir)
            slide.process()
            self.cell_features_csv_path = slide.cell_features_csv_path
        else:
            self.cell_features_csv_path = os.path.abspath(cell_features_csv_path)
        return None

    def set_cell_features_list(self, cell_features_list=None):
        """ Set the cell features to be used for the cell-cell graph construction.

        If `cell_features_list` is None, all the columns (except the index column) of
        `self.cell_features_csv_path` are used. Otherwise only the requested columns that exist in the csv
        are kept, in the csv column order; requested names missing from the csv are reported with a warning.
        """
        assert getattr(self, "cell_features_csv_path", None) is not None, "Please provide the cell features csv file path."

        # Only the header is needed to know the feature names; skip parsing the
        # (potentially very large) table body.
        feature_names = pd.read_csv(
            self.cell_features_csv_path, index_col=0, nrows=0
        ).columns

        if cell_features_list is None:
            self.cell_features_list = list(feature_names)
            return

        available = set(feature_names)
        missing = [f for f in cell_features_list if f not in available]
        if missing:
            logger.warning(
                f"Cell features not found in {self.cell_features_csv_path} and ignored: {missing}"
            )

        self.cell_features_list = list(
            feature_names[feature_names.isin(cell_features_list)]
        )

    def set_params(self, **kwargs):
        """ Set the parameters for the fineST estimation.

        The keyword parameters specified here overwrite the existing settings in place. A key is applied to
        every parameter dictionary that defines it (e.g. `conn_key` exists in both :py:attr:`run_params` and
        :py:attr:`graph_params`). Keys found in neither dictionary are ignored and reported with a warning.
        The complete lists of parameters are the keys of the :py:attr:`graph_params` and :py:attr:`run_params`
        attributes.

        The `graph_params` are the parameters for the cell-cell graph construction and the transition matrix
        estimation. Behaviors of some important parameters are listed below.

        Graph params:
            - n_neighbors: int
                Number of neighbors for the cell-cell graph construction. Default is 5. Increasing this number will increase the connectivity of the cell-cell graph.
            - geom_morph_ratio: float
                The ratio of the geometric distance and the morphological distance for the cell-cell graph construction. Default is 1. Increasing this number will lead to more local connections.
            - adjust_cell_network_by_transcriptome_scale: int or float
                The scale of the transcriptome heterogeneity used for adjusting the cell-cell graph relative to the morphological distance. Default is 0. Increasing this number will increase the contribution of the transcriptome to build the cell-cell graph. When it is > 0, :py:meth:`predict_gene_expression` first runs a burn-in diffusion to obtain the transcriptome representation.
            - reduced_dimension_transcriptome_obsm_key: str
                The key in `adata.obsm` for the reduced transcriptome representation. Default is "X_pca". If `reduced_dimension_transcriptome_obsm_key` is not in `adata.obsm`, the cell features will be used for the cell-cell graph construction.
            - snn_threshold: float
                The threshold of proportion of shared neighbors for connection. Default is 0.1. Increasing this number will increase the criteria and lead to a sparser cell-cell graph.

        A user-supplied cell-cell graph can also be used: the `conn_csr_matrix` parameter accepts the cell-cell graph as a scipy csr matrix.
        Some other parameters are important for the transition matrix estimation.

        Transition matrix params:
            - preferential_flow: bool
                Whether to use the preferential flow for the transition matrix estimation, so that information flows from the high quality
                cells to the low quality cells, where the quality is measured by the heterogeneity of the cells in the morphological space.
                Default is True. Setting this to False will lead to a symmetric transition.
            - weigh_cells: bool
                Whether to weigh the cells by the transcriptome heterogeneity for the transition matrix estimation. Default is True.
            - smoothing_scale: float
                The scale of the self-weight in the Lambda matrix. Default is 0.8. Increasing this number will preserve more original gene expression.
            - inflation_percentage: float
                How much to inflate the gene expression space after the transition matrix estimation. Default is None. Reasonable values range from
                [0, 10]. If this is None or 0, the gene expression space will not be inflated. Increasing this number will inflate the gene
                expression space for preserving the features space during the Markov graph diffusion. Read more about it in this paper `Taubin smoothing <https://graphics.stanford.edu/courses/cs468-01-fall/Papers/taubin-smoothing.pdf>`_.

        The `run_params` are the parameters for the Markov graph diffusion. Behaviors of some important parameters are listed below.

        Markov diffusion params:
            - n_iter: int
                Number of iterations for the Markov graph diffusion. Default is 20. Usually the estimation converges within 20 iterations.
            - initialize: bool
                Whether to initialize the cell-cell graph and the transition matrix. Default is True. Setting this to False will use the supplied/precomputed transition matrix (in `adata.obsp["transition_matrix"]`) for the Markov graph diffusion.
            - conn_key: str
                The key in `adata.obsp` for the cell-cell graph. Default is "snn".
            - layer: str
                The layer of the gene expression to be used for the Markov graph diffusion. Default is None. If this is None, the `.X` will be used.
            - is_rawCount: bool
                Whether the gene expression is raw count. Default is False. If this is True, the output gene expression will be in raw counts as well. Otherwise, the output gene expression will be in log-normalized counts.
            - stochastic_expression_neighbors_level: str
                The level of the neighbors to be used for the stochastic expression. Default is "spot". Valid values are "spot" and "cell". "spot" means the cells enclosed by neighboring spots will be used for the stochastic expression. "cell" means the neighbors of the cells will be used for the stochastic expression.
        """
        # Update in place so that every holder of these dicts sees the change.
        for params in (self.run_params, self.graph_params):
            params.update({key: value for key, value in kwargs.items() if key in params})

        unknown = sorted(set(kwargs).difference(self.run_params, self.graph_params))
        if unknown:
            logger.warning(f"Ignoring unknown parameter(s): {', '.join(unknown)}")

    def sanity_check(self):
        """ Check that the attributes needed by the prediction are present.

        The "gene" recipe only needs `adata`; any other recipe additionally needs `model_path` and
        `generate`. The first missing attribute is reported through the module logger.

        Returns
        -------
        Boolean
            True if all the required attributes are set, False otherwise.
        """
        if self.recipe in ["gene"]:
            needed = ("adata",)
        else:
            needed = ("adata", "model_path", "generate")

        first_missing = next((name for name in needed if not hasattr(self, name)), None)
        if first_missing is None:
            return True

        logger.error(
            f"Need to set attribute {first_missing} before running the prediction"
        )
        return False

    def predict_gene_expression(self, **kwargs):
        """ Predict the gene expression for the cells. Internally calls the :py:func:`estimate_expression_markov_graph_diffusion` function for the
        finest estimation. The keyword parameters specified here will overwrite existing settings.

        Steps: update the parameters, prepare the recipe, check the required attributes, optionally run the
        burn-in (when `adjust_cell_network_by_transcriptome_scale` > 0), optionally initialize the cell graph and
        transition matrix, write the input adata and the parameters to `save_dir` and call `self.save()`, then
        run the Markov graph diffusion with `save_dir/TEMP` as scratch directory. The scratch directory is
        removed afterwards unless `save_chain` is True.

        Parameters
        ----------
        kwargs : dict
            Parameters for the self.run_params and self.graph_params.

        Returns
        -------
        None or :py:class:`anndata.AnnData`
            None when the estimation ran; the predicted expression is written under `save_dir` by
            :py:func:`estimate_expression_markov_graph_diffusion` (load it with :py:meth:`load_result`).
            If the sanity check fails, the unchanged input `self.adata` is returned instead, and None if
            `self.adata` has not been created yet.
        """

        self.set_params(**kwargs)

        # prepare_recipe() reads self.adata, so bail out cleanly (instead of
        # raising AttributeError) when the input has not been prepared yet.
        if not hasattr(self, "adata"):
            logger.error("Need to set attribute adata before running the prediction")
            return None

        self.prepare_recipe()

        if not self.sanity_check():
            return self.adata

        temp_dir = os.path.join(self.save_dir, "TEMP")
        os.makedirs(temp_dir, exist_ok=True)

        # burn-in if you would like to include the effect of the input transcriptome in the cell-cell graph construction
        if self.graph_params["adjust_cell_network_by_transcriptome_scale"] > 0:
            self._burn_in(n_iter=self.run_params["burn_in_steps"], n_pcs=self.graph_params["reduced_dimension_transcriptome_obsm_dims"])

        # initialization: constructing cell graph and the transition matrix
        if self.run_params["initialize"]:
            markov_graph_diffusion_initialize(self.adata, **self.graph_params)

        self.data_pre_path = os.path.join(self.save_dir,
                                          f"{self.name}_adata_cell_input.h5ad")
        # Mean regulation is only applied in the "reduced" recipe.
        self.run_params["regulate_expression_mean"] = self.run_params["regulate_expression_mean"] and self.recipe == "reduced"
        self.write_adata(f"{self.name}_adata_cell_input.h5ad", self.adata)
        self.write_params()
        self.save()

        # run the finest estimation
        estimate_expression_markov_graph_diffusion(
            self.adata,
            conn_key=self.run_params["conn_key"],
            n_iter=self.run_params["n_iter"],
            input_layer=self.run_params["layer"],
            is_rawCount=self.run_params["is_rawCount"],
            stochastic_expression_neighbors_level=self.
            run_params["stochastic_expression_neighbors_level"],
            regulate_expression_mean=self.run_params["regulate_expression_mean"],
            smooth_predicted_expression_steps=self.
            run_params["smooth_predicted_expression_steps"],
            sample_predicted_expression_fluctuation_scale=self.
            run_params["sample_predicted_expression_fluctuation_scale"],
            n_jobs=self.run_params["n_jobs"],
            out_prefix=self.run_params["out_prefix"],
            write_freq=self.run_params["write_freq"],
            temp_dir=temp_dir,
            save_dir=self.save_dir,
            gen_module=self.generate,
        )

        # Clean up
        if not self.run_params["save_chain"]:
            try:
                shutil.rmtree(temp_dir)
            except OSError as e:
                logger.error(e.strerror)

    def _burn_in(self, n_iter=5, genes_included="highly_variable", n_pcs=2):
        """ Burn-in the Markov graph diffusion. To include the effect of the input transcriptome in the cell-cell graph construction, we run :py:func:`estimate_expression_markov_graph_diffusion`
        on a copy of this object, using only histology features (`adjust_cell_network_by_transcriptome_scale` = 0) and vanilla parameters, for `n_iter` steps (default: 5).
        The PCA of the burn-in result replaces `self.adata.obsm["X_pca"]`.

        Logging at INFO level and below is muted while the burn-in runs, and the `TQDM_DISABLE` environment
        variable is set to "1". Both are restored (and the scratch directory removed) even if the burn-in fails.

        Parameters
        ----------
        n_iter : int, optional
            Number of iterations for the Markov graph diffusion in burn-in stage. Default is 5. Nothing is done when < 1.
        genes_included : str, list or None, optional
            Genes to be used for PCA of the transcriptome. If None, use the genes flagged `used_for_prediction` in `self.adata.var`;
            if "highly_variable", the highly variable genes will be used (note: this runs `sc.pp.highly_variable_genes` on `self.adata`,
            which adds its result columns to `self.adata.var`); if "all", all the genes will be used. Any other value (e.g. a list of gene
            names) is used as-is. Default is "highly_variable".
        n_pcs : int, optional
            Number of PCs for PCA of the transcriptome. Default is 2.

        Returns
        -------
        None
            Update the `self.adata.obsm["X_pca"]` attribute of the class.
        """

        if n_iter < 1:
            return

        # Announce the burn-in before muting logging (a message logged after
        # logging.disable(logging.INFO) would be dropped).
        logger.info("Burn-in the Markov graph diffusion first.")

        # Silence logging and progress bars during the burn-in for cleaner output.
        previous_tqdm_disable = os.environ.get("TQDM_DISABLE")
        logging.disable(logging.INFO)
        os.environ["TQDM_DISABLE"] = '1'

        temp_dir = None
        try:
            # Create a copy of the original object
            burnin = self.copy()
            burnin.adata.X = get_adata_layer_array(burnin.adata, layer_key=burnin.run_params["layer"])

            temp_dir = os.path.join(burnin.save_dir, "BURNIN_TEMP")
            os.makedirs(temp_dir, exist_ok=True)

            # initialization: constructing cell graph and the transition matrix from histology only.
            # A fresh dict guarantees the override never leaks into self.graph_params.
            burnin.graph_params = {
                **burnin.graph_params,
                "adjust_cell_network_by_transcriptome_scale": 0,
            }
            if burnin.run_params["initialize"]:
                markov_graph_diffusion_initialize(burnin.adata, **burnin.graph_params)

            # set genes for prediction
            if genes_included is None:
                genes_included = list(self.adata.var_names[self.adata.var.used_for_prediction])
            elif genes_included == "highly_variable":
                # NOTE(review): this PCA result does not appear to be used below — confirm whether it can be dropped.
                sc.tl.pca(burnin.adata)
                sc.pp.highly_variable_genes(self.adata)
                genes_included = list(self.adata.var_names[self.adata.var.highly_variable])
            elif genes_included == "all":
                genes_included = list(self.adata.var_names)
            burnin.genes = genes_included
            burnin.set_genes_for_prediction(genes_selection_key=None)
            burnin.recipe = "gene"
            burnin.prepare_recipe()

            # run the markov graph diffusion
            estimate_expression_markov_graph_diffusion(
                burnin.adata,
                conn_key=burnin.run_params["conn_key"],
                n_iter=n_iter,
                input_layer=burnin.run_params["layer"],
                is_rawCount=burnin.run_params["is_rawCount"],
                stochastic_expression_neighbors_level=None,
                regulate_expression_mean=False,
                smooth_predicted_expression_steps=0,
                n_jobs=burnin.run_params["n_jobs"],
                out_prefix="burnin",
                write_freq=n_iter,
                temp_dir=temp_dir,
                save_dir=burnin.save_dir,
                gen_module=burnin.generate,
            )

            ad_burnin = burnin.load_result(f"burnin_{n_iter}.npz")
            sc.tl.pca(ad_burnin, n_comps=n_pcs)
            self.adata.obsm["X_pca"] = ad_burnin.obsm["X_pca"]

            # release the large intermediate objects early
            del ad_burnin
            del burnin
        finally:
            if temp_dir is not None:
                try:
                    shutil.rmtree(temp_dir)
                except OSError as e:
                    logger.error(e.strerror)

            # Restore the caller's tqdm setting instead of clobbering it.
            if previous_tqdm_disable is None:
                os.environ.pop("TQDM_DISABLE", None)
            else:
                os.environ["TQDM_DISABLE"] = previous_tqdm_disable
            logging.disable(logging.NOTSET)


    def write_adata(self, file_name, ad):
        """Save an AnnData object as an ``.h5ad`` file inside ``save_dir``.

        Parameters
        ----------
        file_name : str
            Target file name, interpreted relative to ``self.save_dir``.
        ad : :py:class:`anndata.AnnData`
            Object to write; its ``write_h5ad`` method performs the actual I/O.
        """
        ad.write_h5ad(os.path.join(self.save_dir, file_name))

    def write_params(self, exclude=("conn_csr_matrix",)):
        """Write the run_params and graph_params to json files in `save_dir`.

        The files are `{name}_run_params.json` and `{name}_graph_params.json`.

        Parameters
        ----------
        exclude : iterable of str, optional
            Keys left out of both files, typically values that cannot be serialized to json.
            Default is ("conn_csr_matrix",), the optional user-supplied scipy graph.
        """
        excluded = set(exclude or ())
        for suffix, params in (
            ("run_params", self.run_params),
            ("graph_params", self.graph_params),
        ):
            serializable = {k: v for k, v in params.items() if k not in excluded}
            # Serialize before opening the file so a failure cannot leave a
            # truncated json behind.
            text = json.dumps(serializable, indent=4)
            with open(
                os.path.join(self.save_dir, f"{self.name}_{suffix}.json"), "w"
            ) as fp:
                fp.write(text)

    def load_params(self, json_path):
        """Load parameters from a json file and apply them with :py:meth:`set_params`.

        Only keys that already exist in `self.run_params` or `self.graph_params` take effect; they
        update those dictionaries, which are used for the prediction.

        Parameters
        ----------
        json_path : str
            Path to the json file that contains the parameters.
        """
        with open(json_path) as handle:
            loaded = json.load(handle)
        self.set_params(**loaded)

    def load_result(self, file_name, layer_name=None):
        """Load a predicted gene expression matrix (scipy ``.npz``) into a new AnnData.

        If `self.adata` is not set, it is first read from `self.data_pre_path`. The returned object is
        a copy of `self.adata` restricted to the genes flagged `used_for_prediction`, with all
        existing layers removed.

        Parameters
        ----------
        file_name : str
            File name of the predicted gene expression, relative to `save_dir`.
        layer_name : str, optional
            Layer that receives the prediction. If None, the prediction replaces `.X`.

        Returns
        -------
        adata : :py:class:`anndata.AnnData`
            Anndata object with the predicted gene expression for the cells.
        """
        if not hasattr(self, "adata"):
            self.adata = sc.read_h5ad(self.data_pre_path)

        predicted = load_npz(os.path.join(self.save_dir, file_name))
        result = self.adata[:, self.adata.var.used_for_prediction].copy()
        del result.layers

        if layer_name is not None:
            result.layers[layer_name] = predicted
        else:
            result.X = predicted
        return result

    def prepare_recipe(self):
        """Prepare the running mode for the fineST estimation of gene expression.

        Supported recipe: "gene", "reduced", "mix".
            - "gene": use the user-set genes for prediction with *gene mode*. This includes the genes from the `used_for_prediction` key in adata.var.
              Sets `used_for_reduced` and `used_for_vae` to False and uses an :py:class:`IdentityGenerator`.
            - "reduced": use the reduced genes from the VAE model with *reduced mode*. Ignores the genes from the `used_for_prediction` key in adata.var
              (both `used_for_reduced` and `used_for_prediction` are copied from `used_for_vae`).
            - "mix": use the reduced genes from the VAE model using *reduced mode* and the rest of the user-set genes for prediction with *gene mode*.
              If no reduced gene can be obtained, all genes are flagged `used_for_prediction` and the recipe is switched to "gene" (permanently).

        Raises
        ------
        AssertionError
            If `self.recipe` is not one of the supported recipes.
        """

        assert self.recipe in ("gene", "reduced", "mix"), "Please specify one of the implemented recipes: `mix`, `gene`, or `reduced`"

        logger.info(f"Using mode {self.recipe}")
        if self.recipe == "mix":
            try:
                self.get_reduced_genes()
            except Exception as err:
                # Best effort: any failure (no VAE loaded, missing `used_for_vae`, ...)
                # falls back to gene mode below; keep the reason visible.
                logger.warning(f"Could not determine the reduced genes: {err!r}")
                self.adata.var["used_for_reduced"] = False

            if not self.adata.var["used_for_reduced"].any():
                logger.warning(
                    "Failed to get reduced genes. Using all genes for prediction."
                )
                self.adata.var["used_for_prediction"] = True
                self.recipe = "gene"

        # Not `elif`: a failed "mix" falls through to the "gene" branch.
        if self.recipe == "reduced":
            self.adata.var["used_for_reduced"] = self.adata.var["used_for_vae"]
            self.adata.var["used_for_prediction"] = self.adata.var["used_for_vae"]

        if self.recipe == "gene":
            self.adata.var["used_for_reduced"] = False
            self.adata.var["used_for_vae"] = False
            self.generate = IdentityGenerator()


    def visualize_cell_network(self, **kwargs):
        """Plot the cell-cell graph stored on `self.adata`.

        All keyword arguments are forwarded unchanged to :py:func:`plot_cell_graph`, whose return value is
        handed back to the caller.
        """
        plot_kwargs = dict(kwargs)
        return plot_cell_graph(self.adata, **plot_kwargs)

    def set_genes_for_prediction(self, genes_selection_key="highly_variable"):
        """ Set the genes to be used for prediction.
        This will update the `used_for_prediction` column in adata.var.

        Parameters
        ----------
        genes_selection_key : str, optional
            Boolean column of `adata.var` used to select the genes; the user-supplied genes (`self.genes`) are always
            added to the selection. By default, "highly_variable" is used. Spatially variable genes (e.g. from SPARK-X)
            are also recommended. The column must already exist in `adata.var` (compute it before calling this function).
            None or "all" is also supported: None uses only the user-supplied genes and "all" uses all the genes (this is
            almost never recommended).

        Returns
        -------
        None
            Update the `used_for_prediction` column in adata.var.

        Raises
        ------
        KeyError
            If `genes_selection_key` is not a column of `adata.var`.
        """

        ad = self.adata
        if genes_selection_key == "all":
            logger.warning(
                "Using all the genes. This may not be optimal and can be slow."
            )
            ad.var["used_for_prediction"] = True
        elif genes_selection_key is None:
            ad.var["used_for_prediction"] = ad.var.index.isin(self.genes)
        elif genes_selection_key not in ad.var:
            logger.error(
                f"Rerun it by providing a valid genes_selection_key. {genes_selection_key} is not a valid key in adata.var."
            )
            raise KeyError(genes_selection_key)
        else:
            selected = ad.var_names[ad.var[genes_selection_key]]
            # union of the key-selected genes and the user-supplied genes
            ad.var["used_for_prediction"] = (
                ad.var.index.isin(selected) | ad.var.index.isin(self.genes)
            )

    def get_reduced_genes(self, keep=0.9, min_mean_expression=0.5):
        """Obtain the genes whose expression will be estimated through diffusion of latent variables collectively.
        Not all genes used for VAE training are reconstructed faithfully. Therefore, only the genes with high reconstruction
        quality (measured by cosine similarity between the VAE reconstruction and the input spot expression) are used.

        Parameters
        ----------
        keep : float, optional
            The proportion of the VAE genes to be kept for the reduced mode. The genes are ranked according to the VAE
            reconstruction quality and those ranked within the best `keep * n_genes` are kept. Default is 0.9.
        min_mean_expression : float, optional
            Threshold of the mean spot expression for the genes to be used in reduced mode. Default is 0.5.

        Returns
        -------
        None
            Sets `self.adata.var["used_for_reduced"]`; when `used_for_prediction` exists, only genes also flagged there are kept.
        """

        ad_spot = sc.read_h5ad(self.spot_adata_path)

        assert "used_for_vae" in self.adata.var, "Please set up `used_for_vae` column in adata.var by running either `vae_training` or `load_vae_model`"

        genes = self.adata.var_names[self.adata.var.used_for_vae]

        assert len(
            genes
        ) > 1, "No gene set for VAE training. There is something wrong."

        X = get_adata_layer_array(ad_spot[:, genes])
        decoded = self.generate.decode(self.generate.encode(X))

        # Per-gene cosine distance between reconstruction and input (lower is better).
        # ravel() also flattens np.matrix results coming from sparse inputs.
        cos_dist = np.asarray(1 - var_cos(decoded, X)).ravel()
        mean_exp_genes = np.asarray(X.mean(axis=0)).ravel()

        # 1-based rank by reconstruction quality: rank 1 = best reconstructed gene.
        rank = np.argsort(np.argsort(cos_dist)) + 1
        # `<=` so that the documented proportion is actually kept (with `<`,
        # keep=1.0 would still drop a gene).
        well_reconstructed = rank <= keep * len(genes)
        expressed = mean_exp_genes >= min_mean_expression
        reduced_genes = list(genes[np.logical_and(well_reconstructed, expressed)])

        if "used_for_prediction" in self.adata.var:
            self.adata.var["used_for_reduced"] = np.logical_and(
                self.adata.var_names.isin(reduced_genes),
                self.adata.var.used_for_prediction
            )
        else:
            self.adata.var["used_for_reduced"] = self.adata.var_names.isin(
                reduced_genes
            )

    def load_vae_model(self, model_path=None):
        """Load a trained VAE model and the list of genes it was trained on.

        Parameters
        ----------
        model_path : str
            Folder containing the saved VAE. The following are read from this folder:

            * the encoder from the ``encoder`` path;
            * the decoder from the ``decoder`` path;
            * the training genes from ``vae_genes_{self.name}.csv``, a single header-less column of gene names.

        Returns
        -------
        None
            Sets the `generate` and `model_path` attributes and the `used_for_vae` column in adata.var.
            If `model_path` is None, an error is logged and nothing is changed.
        """

        if model_path is None:
            logger.error("Please provide a correct model path.")
            return

        # Read the associated gene list first, so that a missing or invalid file raises
        # before this instance's model/attributes are modified.
        vae_genes_file_path = os.path.join(
            model_path, f"vae_genes_{self.name}.csv"
        )
        # The single header-less column is used as the index.
        vae_genes = pd.read_csv(vae_genes_file_path, header=None, index_col=0).index

        vae = VAE()
        vae.load_models(
            encoder_model_path=os.path.join(model_path, "encoder"),
            decoder_model_path=os.path.join(model_path, "decoder"),
        )
        self.generate = vae
        self.model_path = model_path

        self.adata.var["used_for_vae"] = self.adata.var.index.isin(
            set(vae_genes)
        )
#        self.adata.obsm["X_vae"] = self.generate.encode(self.adata[:, self.adata.var["used_for_vae"]].X.toarray())

    def save(self, exclude=["generate", "adata", "conn_csr_matrix"]):
        """ 
        Save the attributes of the class to json file. The saved json file can be directly loaded to create a new instance of the class.

        Parameters
        ----------
        exclude : list, optional
            List of attributes to be excluded from saving. By default, ["generate", "adata"] are excluded.
        """
        
        attrs = self.__dict__.copy()
        for attr in exclude:
            if attr in attrs:
                del attrs[attr]
        with open(os.path.join(self.save_dir, f"{self.name}_fineST.json"), "w") as fp:
            json.dump(attrs, fp, indent=4)


    def __deepcopy__(self, memo):
        # Create a new instance of the class
        new_instance = self.__class__.__new__(self.__class__)
        memo[id(self)] = new_instance

        # Copy all attributes except 'generate'
        for k, v in self.__dict__.items():
            if k != 'generate':
                setattr(new_instance, k, copy.deepcopy(v, memo))
            else:
                new_instance.load_vae_model(self.model_path)

        return new_instance

    def copy(self):
        return copy.deepcopy(self)

    
    @classmethod
    def load(cls, json_path):
        """Create a new instance of the class from a JSON file written by ``save``.

        The saved attributes are passed to the constructor as keyword arguments. Then, on the new
        instance:

        * if ``data_pre_path`` is set (not None), ``adata`` is read from that h5ad file;
        * if ``model_path`` is set (not None), the VAE model is reloaded with ``load_vae_model``.

        Parameters
        ----------
        json_path : str
            Path to the saved json file.

        Returns
        -------
        cls
            The reconstructed instance.
        """
        with open(json_path, "r") as fp:
            saved_attrs = json.load(fp)

        instance = cls(**saved_attrs)

        # getattr(..., None) instead of hasattr: an attribute that exists but is None must be
        # skipped, otherwise sc.read_h5ad(None) would fail.
        data_pre_path = getattr(instance, "data_pre_path", None)
        if data_pre_path is not None:
            instance.adata = sc.read_h5ad(data_pre_path)

        model_path = getattr(instance, "model_path", None)
        if model_path is not None:
            instance.load_vae_model(model_path)

        return instance

