{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import GridSearchCV\n", "from sklearn.datasets import make_classification\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import roc_auc_score" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from sksurv.ensemble import RandomSurvivalForest, GradientBoostingSurvivalAnalysis, ExtraSurvivalTrees" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from sklearn.experimental import enable_iterative_imputer # noqa\n", "from sklearn.impute import IterativeImputer" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import gurobipy as gp\n", "from gurobipy import GRB" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Read the CSV file into a DataFrame\n", "df_crlm_1 = pd.read_csv(\"crlm_chemo_data_imputed.csv\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "df_xray = df_crlm_1" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "df_xray['RFS_event_5_years'] = np.where((df_xray['RCT_RFS_event'] == 1) & (df_xray['RCT_TIME'] <= 60), 1, 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Risk buckets and X-ray model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "to_train = df_xray[df_xray['adjuvantchemo'] == 0]\n", "\n", "# Selecting variables\n", "X = to_train[[\"age\", \"gender\", \"T\", \"N\", \"rightleft\", \"neochemo\",\n", " \"cea\", \"DFI\", \"size\", \n", " \"bilobar\", \"number\"]]\n", "\n", "\n", "y = to_train[[\"RCT_RFS_event\", \"RCT_TIME\"]]\n", "\n", "y_values = to_train[[\"RCT_RFS_event\", \"RCT_TIME\"]].values\n", "\n", "# Create a structured array\n", "y_structured = np.zeros(y_values.shape[0], dtype=[('event', '?'), ('time', ' 1:\n", " score += 1\n", " if row['size'] > 5:\n", " score += 1\n", " if row['cea'] > 200:\n", " score += 1\n", " \n", " return score\n", "\n", "# Apply the function to each row to calculate the risk score\n", "df_xray['fong_score'] = df_xray.apply(calculate_fong_score, axis=1)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Function to calculate GAME score\n", "def calculate_game_score(row):\n", " score = 0\n", " # KRAS mutation\n", " score += row['KRAS']\n", " \n", " # Tumor Burden Score (TBS)\n", " tbs = (row['number']**2 + row['size']**2) ** 0.5\n", " if 3 <= tbs < 9:\n", " score += 1\n", " elif tbs >= 9:\n", " score += 2\n", " \n", " # CEA level\n", " if row['cea'] >= 20:\n", " score += 1\n", " \n", " # Primary CRC lymph node metastases\n", " score += row['N']\n", " \n", " return score\n", "\n", "# Apply the function to each row to calculate the risk score\n", "df_xray['game_score'] = df_xray.apply(calculate_game_score, axis=1)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Time to calculate survival probability\n", "time_to_calculate = 60\n", "X_chemo = df_xray.copy()\n", "X_chemo_pred = X_chemo[[\"age\", \"gender\", \"T\", \"N\", \"rightleft\", \"neochemo\",\n", " \"cea\", \"DFI\", \"size\", \n", " \"bilobar\", \"number\"]]\n", "\n", "surv_funcs = estimator.predict_survival_function(X_chemo_pred.iloc[:])\n", "surv_probs_duration = [surv_funcs[i](time_to_calculate) for i in range(len(X_chemo))]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "df_xray = X_chemo\n", "df_xray[\"rec_score\"] = [1 - surv_funcs[i](time_to_calculate) for i in range(len(X_chemo))]" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.0, 0.17, 0.33, 0.5, 0.67, 0.83, 1.0]\n" ] } ], "source": [ "n_bins = 7\n", "bins = [round(i / (n_bins - 1), 2) for i in range(n_bins)]\n", "print(bins)\n", "rec_score_buckets = [f\"{bins[i]}-{bins[i+1]}\" for i in range(len(bins)-1)]\n", "df_xray['rec_score_category'] = pd.cut(df_xray['rec_score'], bins=bins, labels=rec_score_buckets, include_lowest=True)\n", "rec_score_buckets_2 = [i+1 for i in range(len(bins)-1)]\n", "df_xray['bucket'] = pd.cut(df_xray['rec_score'], bins=bins, labels=rec_score_buckets_2, include_lowest=True)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "grouped_df = df_xray.groupby(['bucket', 'adjuvantchemo']).size().unstack(fill_value=0)\n", "df_anycat = df_xray" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create binary variables based on the specified cutoffs\n", "df_anycat['size_more_than_5'] = (df_anycat['size'] >= 5).astype(int) # size >= 5 -> 1, otherwise 0\n", "df_anycat['number_more_than_4'] = (df_anycat['number'] >= 4).astype(int) # number >= 4 -> 1, otherwise 0\n", "df_anycat['cea_more_than_20'] = (df_anycat['cea'] >= 20).astype(int) # cea >= 20 -> 1, otherwise 0\n", "df_anycat['DFI_more_than_12'] = (df_anycat['DFI'] >= 12).astype(int) # DFI >= 12 -> 1, otherwise 0\n", "\n", "# Display the first few rows to verify the changes\n", "df_anycat.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Prognostic Matching" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import MinMaxScaler\n", "\n", "df_anycat_0_sampled = pd.DataFrame(columns=df_anycat.columns)\n", "df_anycat_1_sampled = pd.DataFrame(columns=df_anycat.columns)\n", "\n", "# Extracting arrays from each column\n", "array_list = []\n", "for column in grouped_df.columns:\n", " column_array = grouped_df[column].values\n", " array_list.append(column_array)\n", "\n", "samples_bucket_0 = array_list[0]\n", "samples_bucket_1 = array_list[1]\n", "print(samples_bucket_0)\n", "print(samples_bucket_1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for bucket in range(1, len(bins)):\n", " subset_0 = df_anycat[(df_anycat['adjuvantchemo'] == 0) & (df_anycat['bucket'] == bucket)].sample(n=samples_bucket_0[bucket-1], replace=False)\n", " subset_1 = df_anycat[(df_anycat['adjuvantchemo'] == 1) & (df_anycat['bucket'] == bucket)].sample(n=samples_bucket_1[bucket-1], replace=False)\n", " df_anycat_0_sampled = pd.concat([df_anycat_0_sampled, subset_0])\n", " df_anycat_1_sampled = pd.concat([df_anycat_1_sampled, subset_1])\n", "\n", "# Reset the index of the selected_rows DataFrame\n", "resulting_data_0_1 = pd.concat([df_anycat_0_sampled,df_anycat_1_sampled])\n", "resulting_data_0_1.reset_index(drop=True, inplace=True)\n", "\n", "# scaler = MinMaxScaler()\n", "# scaled_values = scaler.fit_transform(resulting_data_0_1[['T']])\n", "# scaled_df = pd.DataFrame(scaled_values, columns=['T'])\n", "# resulting_data_0_1_norm = pd.concat([resulting_data_0_1.drop(columns = ['T']), scaled_df], axis=1)\n", "resulting_data_0_1_norm = resulting_data_0_1[['size_more_than_5', 'N', 'rec_score', 'rightleft', \n", " 'number_more_than_4', 'cea_more_than_20', 'DFI_more_than_12']]\n", "resulting_data_0_1_norm['bucket'] = resulting_data_0_1['bucket']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# specify how many treated to keep in matching\n", "n_matched_bucket = [0, 2, 10, 28, 28, 5]\n", "total_pairs = sum(n_matched_bucket)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from scipy.spatial.distance import cdist\n", "\n", "# Combine all buckets into one dataset for the optimization\n", "resulting_data_all = resulting_data_0_1.copy()\n", "resulting_data_all_norm = resulting_data_0_1_norm.copy()\n", "\n", "# Calculate distance matrix for all patients across buckets\n", "resulting_data_all_norm = resulting_data_all_norm.astype(float)\n", "features = resulting_data_all_norm.columns\n", "matrix = cdist(resulting_data_all_norm[features], resulting_data_all_norm[features], metric='euclidean')\n", "matrix = np.nan_to_num(matrix, nan=1e5)\n", "\n", "# Extract important variables\n", "n = len(resulting_data_all)\n", "treat = resulting_data_all['adjuvantchemo']\n", "bucket = resulting_data_all['bucket']\n", "\n", "# Record patients in each bucket \n", "bucket_patient = {}\n", "for b in range(1, n_bins):\n", " df = resulting_data_all_norm.loc[resulting_data_all_norm['bucket'] == b]\n", " bucket_patient[b] = df.index.tolist()\n", " \n", " \n", "# Extract covariates that are reported in the trial\n", "rightleft = resulting_data_all['rightleft']\n", "number_more_than_4 = resulting_data_all['number_more_than_4']\n", "size_more_than_5 = resulting_data_all['size_more_than_5']\n", "rec_score = resulting_data_all['RFS_event_5_years']\n", "n_score = resulting_data_all['N']\n", "\n", "# Set target trial values\n", "targets = {\n", " 'rightleft': 0.8,\n", " 'number_more_than_4': 0.09,\n", " 'size_more_than_5': 0.145,\n", " 'rec_score': 0.62,\n", " 'n_score': 0.4\n", "}\n", "\n", "# Initialize the optimization model\n", "m = gp.Model()\n", "\n", "# Create 2D array of binary decision variables\n", "x = m.addVars(n, n, vtype=gp.GRB.BINARY, name=\"x\")\n", "\n", "# Calculate the totals for untreated and treated groups\n", "totals_treated = {\n", " 'rightleft': gp.quicksum(x[i, j] * treat[i] * rightleft[i] for i in range(n) for j in range(n)),\n", " 'number_more_than_4': gp.quicksum(x[i, j] * treat[i] * number_more_than_4[i] for i in range(n) for j in range(n)),\n", " 'size_more_than_5': gp.quicksum(x[i, j] * treat[i] * size_more_than_5[i] for i in range(n) for j in range(n)),\n", " 'rec_score': gp.quicksum(x[i, j] * treat[i] * rec_score[i] for i in range(n) for j in range(n)),\n", " 'n_score': gp.quicksum(x[i, j] * treat[i] * n_score[i] for i in range(n) for j in range(n))\n", "}\n", "\n", "totals_untreated = {\n", " 'rightleft': gp.quicksum(x[i, j] * (1 - treat[i]) * rightleft[i] for i in range(n) for j in range(n)),\n", " 'number_more_than_4': gp.quicksum(x[i, j] * (1 - treat[i]) * number_more_than_4[i] for i in range(n) for j in range(n)),\n", " 'size_more_than_5': gp.quicksum(x[i, j] * (1 - treat[i]) * size_more_than_5[i] for i in range(n) for j in range(n)),\n", " 'rec_score': gp.quicksum(x[i, j] * (1 - treat[i]) * rec_score[i] for i in range(n) for j in range(n)),\n", " 'n_score': gp.quicksum(x[i, j] * (1 - treat[i]) * n_score[i] for i in range(n) for j in range(n))\n", "}\n", "\n", "# Calculate means for untreated and treated groups\n", "means_untreated = {key: totals_untreated[key] / total_pairs for key in totals_untreated}\n", "means_treated = {key: totals_treated[key] / total_pairs for key in totals_treated}\n", "\n", "# Add auxiliary variables for absolute deviations\n", "z_vars_untreated = {key: m.addVar(vtype=gp.GRB.CONTINUOUS, name=f\"z_{key}_untreated\") for key in targets}\n", "z_vars_treated = {key: m.addVar(vtype=gp.GRB.CONTINUOUS, name=f\"z_{key}_treated\") for key in targets}\n", "\n", "# Add constraints for absolute differences\n", "for key in targets:\n", " m.addConstr(z_vars_untreated[key] >= means_untreated[key] - targets[key], name=f\"abs_pos_{key}_untreated\")\n", " m.addConstr(z_vars_untreated[key] >= targets[key] - means_untreated[key], name=f\"abs_neg_{key}_untreated\")\n", " \n", " m.addConstr(z_vars_treated[key] >= means_treated[key] - targets[key], name=f\"abs_pos_{key}_treated\")\n", " m.addConstr(z_vars_treated[key] >= targets[key] - means_treated[key], name=f\"abs_neg_{key}_treated\")\n", "\n", "# Set the objective function: minimize the sum of deviations and penalize matching costs\n", "m.setObjective(\n", " gp.quicksum(z_vars_untreated[key] + z_vars_treated[key] for key in targets) +\n", " gp.quicksum(x[i, j] * matrix[i][j] for i in range(n) for j in range(n)),\n", " gp.GRB.MINIMIZE\n", ")\n", "\n", "# Ensure that each patient is matched with at most one other patient\n", "m.addConstrs((gp.quicksum(x[i,j] for j in range(n)) <= 1 for i in range(n)), name=\"only_one_j2\")\n", "m.addConstrs((gp.quicksum(x[i,j] for i in range(n)) <= 1 for j in range(n)), name=\"only_one_i2\")\n", "\n", "# Prevent matching between two patients who both received treatment\n", "m.addConstrs(((treat[i] + treat[j]) * x[i,j] <= x[i,j] for i in range(n) for j in range(n)), name=\"not_two\")\n", "\n", "# Enforce that patients can only be matched within the same bucket\n", "m.addConstrs((bucket[i] * x[i,j] == bucket[j] * x[i,j] for i in range(n) for j in range(n)), name=\"same_bucket\")\n", "\n", "# Ensure number of patients matched with in each cohort \n", "for b in range(1,n_bins):\n", " m.addConstr(gp.quicksum(x[i,j] for i in bucket_patient[b] if treat[i] == 0 for j in bucket_patient[b]) == n_matched_bucket[b-1], name=\"match_treat_0\")\n", " m.addConstr(gp.quicksum(x[i,j] for i in bucket_patient[b] if treat[i] == 1 for j in bucket_patient[b]) == n_matched_bucket[b-1], name=\"match_treat_1\")\n", "\n", "# Avoid diagonal matching (no self-matching)\n", "m.addConstrs((x[i,i] == 0 for i in range(n)), name=\"not_diag\")\n", "\n", "# Optimize the model\n", "m.optimize()\n", "\n", "# Get the optimal solution\n", "x_opt = m.getAttr('x', x)\n", "\n", "# Convert the solution to a 2D array for further analysis\n", "x_arr = np.zeros((n, n))\n", "for i in range(n):\n", " for j in range(n):\n", " x_arr[i][j] = x_opt[i, j]\n", "\n", "# Filter results to get only the matched patients\n", "row_sums = np.sum(x_arr, axis=1)\n", "resulting_data_4 = resulting_data_all[row_sums == 1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "resulting_data_4.to_csv('crlm_data_cohort_after_matching.csv', index=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }