# -*- coding: utf-8 -*-
"""
Train and compare regression models (DNN, LSTM, CNN, Transformer, random
forest, XGBoost and an MLP labelled "KAN") on the Ma4 dataset.

Created on Mon Mar 17 11:56:07 2025

@author: admin
"""

import h5py
import numpy as np
import time
import torch
import torch.nn as nn
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from torch.utils.data import TensorDataset, DataLoader
from matplotlib.colors import LinearSegmentedColormap

# ================== Global plotting configuration ==================
# Publication-style matplotlib defaults, applied once at import time.
_PLOT_RC = {
    # Typography
    'font.family': 'Times New Roman',
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    # Line and axis weights
    'axes.linewidth': 1.5,
    'lines.linewidth': 2.5,
    'lines.markersize': 10,
    # Resolution and saving: PDF output cropped tightly to the content
    'figure.dpi': 600,
    'savefig.format': 'pdf',
    'savefig.bbox': 'tight',
    # Legend and math text (STIX matches the Times serif look)
    'legend.fontsize': 12,
    'mathtext.fontset': 'stix',
    # Open frame: hide the top and right spines
    'axes.spines.top': False,
    'axes.spines.right': False,
    # Faint dotted background grid
    'axes.grid': True,
    'grid.linestyle': ':',
    'grid.alpha': 0.4,
}
plt.rcParams.update(_PLOT_RC)
sns.set_palette(palette="pastel")

# ================== Data preprocessing ==================
def load_and_preprocess(path=r'E:\流体\paper 1\Figure\python figure\Ma4.mat',
                        test_size=0.2):
    """Load the dataset from an HDF5-based ``.mat`` file, split it and standardize it.

    The split is chronological (``shuffle=False``): the last ``test_size``
    fraction of samples becomes the test set. All normalization statistics
    are estimated on the training portion only and then applied to both
    portions, so no information from the test set leaks into training.

    Parameters
    ----------
    path : str
        Path to the file containing the datasets ``'X'`` and ``'Y'``.
        Defaults to the original hard-coded location.
    test_size : float
        Fraction of samples held out for testing (forwarded to
        ``train_test_split``).

    Returns
    -------
    tuple
        ``(X_train_3d, X_test_3d, X_train_2d, X_test_2d, y_train, y_test,
        y_scaler)``. The ``*_2d`` arrays are the 3-D arrays flattened per
        sample. ``y_scaler`` is the ``StandardScaler`` fitted on the training
        targets and is used later to undo the target standardization.
    """
    with h5py.File(path, 'r') as f:
        X = np.array(f['X'][:]).astype('float32')  # expected (3000, 50, 6) — TODO confirm
        y = np.array(f['Y'][:]).astype('float32')  # expected (3000, 1) — TODO confirm

    # Split FIRST so every statistic below is computed from training data only.
    X_train_3d, X_test_3d, y_train, y_test = train_test_split(
        X, y, test_size=test_size, shuffle=False
    )

    # Per-channel standardization: statistics over the sample and time axes.
    X_mean = X_train_3d.mean(axis=(0, 1), keepdims=True)
    # The epsilon guards against division by zero on constant channels.
    X_std = X_train_3d.std(axis=(0, 1), keepdims=True) + 1e-8
    X_train_3d = (X_train_3d - X_mean) / X_std
    X_test_3d = (X_test_3d - X_mean) / X_std

    # Target standardization, fitted on the training targets only.
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train)
    y_test = y_scaler.transform(y_test)

    # Flattened copies for models that take a plain feature vector.
    X_train_2d = X_train_3d.reshape(X_train_3d.shape[0], -1)
    X_test_2d = X_test_3d.reshape(X_test_3d.shape[0], -1)

    return X_train_3d, X_test_3d, X_train_2d, X_test_2d, y_train, y_test, y_scaler

# ================== Model definitions ==================
class KAN(nn.Module):
    """Fully connected PyTorch regressor used under the label "KAN".

    NOTE(review): despite the name, this is a plain MLP with fixed GELU
    activations. It is not a Kolmogorov-Arnold Network, since it has no
    learnable univariate edge functions. The name is kept because callers
    reference it.

    Layer widths: input_dim -> 512 -> 256 -> 128 -> 1. A GELU follows every
    hidden linear layer. LayerNorm is applied after the second linear layer,
    and dropout (0.3, then 0.2) after the first two hidden activations.
    """

    def __init__(self, input_dim):
        """Build the network for inputs whose last dimension is ``input_dim``."""
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            # A nonlinearity is needed here. Two stacked Linear layers with
            # nothing between them collapse into a single affine map.
            nn.GELU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return predictions with the last dimension reduced to size 1."""
        return self.net(x)

class _LearnedPositionEmbedding(Layer):
    """Add a trainable embedding of each time-step index to the input sequence."""

    def __init__(self, seq_len, dim, **kwargs):
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.dim = dim
        self.pos_embedding = Embedding(input_dim=seq_len, output_dim=dim)

    def call(self, x):
        positions = tf.range(start=0, limit=self.seq_len, delta=1)
        # The (seq_len, dim) embedding broadcasts over the batch axis of x.
        return x + self.pos_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update({'seq_len': self.seq_len, 'dim': self.dim})
        return config


def build_transformer(input_shape, num_heads=4, key_dim=16, learning_rate=0.0001):
    """Build and compile a single-block Transformer-encoder regressor.

    Parameters
    ----------
    input_shape : tuple of int
        ``(seq_len, channels)`` of one sample.
    num_heads, key_dim : int
        Multi-head attention configuration. The defaults match the original.
    learning_rate : float
        Adam learning rate. The default matches the original.

    Returns
    -------
    Model
        Keras model compiled with MSE loss that outputs one value per sample.
    """
    inputs = Input(shape=input_shape)

    # Learned positional encoding. It must live inside a layer. Calling
    # Embedding on a constant tf.range outside the functional graph evaluates
    # it once, eagerly. The embedding weights then end up as a frozen random
    # constant that is never part of the model's trainable weights.
    x = _LearnedPositionEmbedding(input_shape[0], input_shape[1])(inputs)

    # Self-attention with a residual connection, followed by layer norm.
    attn = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(x, x)
    x = LayerNormalization(epsilon=1e-6)(Add()([x, attn]))

    # Average over time steps and regress a single value.
    x = GlobalAveragePooling1D()(x)
    outputs = Dense(1)(x)

    model = Model(inputs, outputs)
    model.compile(optimizer=Adam(learning_rate), loss='mse')
    return model

# ================== Model training ==================
def _fit_model(model_type, X_train, y_train):
    """Build, compile and train one model of ``model_type``; return the fitted model.

    Parameters
    ----------
    model_type : str
        One of ``'dnn'``, ``'lstm'``, ``'cnn'``, ``'transformer'``, ``'rf'``,
        ``'xgb'``, ``'kan'``.
    X_train : np.ndarray
        Flattened ``(samples, features)`` input for dnn/rf/xgb/kan. Sequence
        ``(samples, steps, channels)`` input for lstm/cnn/transformer.
    y_train : np.ndarray
        Standardized training targets.

    Returns
    -------
    The fitted model. The PyTorch ``KAN`` model is moved back to the CPU.

    Raises
    ------
    ValueError
        If ``model_type`` is not recognised.
    """
    # Sequence models take (steps, channels) from the data instead of a
    # hard-coded (50, 4). The loader documents 6 channels, so a fixed shape
    # would not match the input.
    seq_shape = tuple(X_train.shape[1:])

    if model_type == 'dnn':
        model = Sequential([
            Dense(512, activation='gelu', input_shape=(X_train.shape[1],)),
            BatchNormalization(),
            Dropout(0.3),
            Dense(256, activation='gelu'),
            Dense(1)
        ])
        model.compile(optimizer=Adam(0.0005), loss='mse')
        model.fit(X_train, y_train, epochs=300, batch_size=64, verbose=0)
        return model

    if model_type == 'lstm':
        model = Sequential([
            Bidirectional(LSTM(128, return_sequences=True), input_shape=seq_shape),
            LayerNormalization(),
            Bidirectional(LSTM(64)),
            Dense(64, activation='gelu'),
            Dense(1)
        ])
        model.compile(optimizer=Adam(0.001), loss='mse')
        model.fit(X_train, y_train, epochs=200, batch_size=32, verbose=0)
        return model

    if model_type == 'cnn':
        model = Sequential([
            Conv1D(64, 5, activation='gelu', padding='same', input_shape=seq_shape),
            MaxPooling1D(2),
            Conv1D(128, 3, activation='gelu', padding='same'),
            GlobalAveragePooling1D(),
            Dense(1)
        ])
        model.compile(optimizer=Adam(0.0005), loss='mse')
        model.fit(X_train, y_train, epochs=200, batch_size=64, verbose=0)
        return model

    if model_type == 'transformer':
        model = build_transformer(seq_shape)
        model.fit(X_train, y_train, epochs=200, batch_size=64, verbose=0)
        return model

    if model_type == 'rf':
        model = RandomForestRegressor(n_estimators=300, max_depth=12,
                                      min_samples_split=5, random_state=42)
        model.fit(X_train, y_train.ravel())
        return model

    if model_type == 'xgb':
        model = XGBRegressor(n_estimators=300, learning_rate=0.1,
                             max_depth=6, subsample=0.8, random_state=42)
        model.fit(X_train, y_train.ravel())
        return model

    if model_type == 'kan':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = KAN(X_train.shape[1]).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
        criterion = nn.HuberLoss()

        X_tensor = torch.FloatTensor(X_train).to(device)
        y_tensor = torch.FloatTensor(y_train).view(-1, 1).to(device)

        # Full-batch training: every step uses the whole training set.
        model.train()
        for _ in range(300):
            optimizer.zero_grad()
            loss = criterion(model(X_tensor), y_tensor)
            loss.backward()
            optimizer.step()
        return model.cpu()

    raise ValueError(f"Unknown model type: {model_type!r}")


def _predict(model, model_type, X_test):
    """Return the flattened, still standardized predictions of ``model`` on ``X_test``."""
    if model_type == 'kan':
        model.eval()  # switch dropout off for inference
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X_test.reshape(X_test.shape[0], -1))
            return model(X_tensor).numpy().flatten()
    return model.predict(X_test).flatten()


def train_all_models():
    """Train every model on the same split and evaluate each one on the test set.

    If one model fails, its error message and traceback are printed and it is
    skipped, so the remaining models still run.

    Returns
    -------
    tuple
        ``(results, X_test_3d, X_test_2d, y_test, y_scaler)``. ``results``
        maps each model-type key (``'dnn'``, ``'lstm'``, ...) to a dict with
        the fitted ``'model'``, the de-standardized ``'y_true'`` and
        ``'y_pred'``, and the metrics ``'MAE'``, ``'RAE'``, ``'R2'`` and
        ``'Time'``. ``'Time'`` is the seconds spent on training plus
        evaluation.
    """
    import traceback

    X_train_3d, X_test_3d, X_train_2d, X_test_2d, y_train, y_test, y_scaler = load_and_preprocess()

    results = {}

    # (display name, type key, training inputs)
    models = [
        ('DNN', 'dnn', X_train_2d),
        ('LSTM', 'lstm', X_train_3d),
        ('CNN', 'cnn', X_train_3d),
        ('Transformer', 'transformer', X_train_3d),
        ('RandomForest', 'rf', X_train_2d),
        ('XGBoost', 'xgb', X_train_2d),
        ('KAN', 'kan', X_train_2d)
    ]
    sequence_models = {'lstm', 'cnn', 'transformer'}

    for name, model_type, X_train in models:
        try:
            print(f"Training {name}...")
            start_time = time.time()

            model = _fit_model(model_type, X_train, y_train)

            # Evaluate on test data with the same layout that was used for training.
            X_test = X_test_3d if model_type in sequence_models else X_test_2d
            y_pred = _predict(model, model_type, X_test)

            # Undo the target standardization before computing the metrics.
            y_true = y_scaler.inverse_transform(y_test.reshape(-1, 1)).flatten()
            y_pred = y_scaler.inverse_transform(y_pred.reshape(-1, 1)).flatten()

            results[model_type] = {
                'model': model,
                'y_true': y_true,
                'y_pred': y_pred,
                'MAE': mean_absolute_error(y_true, y_pred),
                # NOTE(review): this is the mean relative error |pred - true| / |true|
                # (a MAPE expressed as a fraction). It is not the textbook relative
                # absolute error sum|pred - true| / sum|true - mean(true)|, and it
                # becomes inf/nan if any y_true is 0.
                'RAE': np.sum(np.abs((y_pred - y_true) / y_true)) / len(y_pred),
                'R2': r2_score(y_true, y_pred),
                'Time': time.time() - start_time,
            }

        except Exception as e:
            # Best effort: report the failure ("training failed") and move on.
            print(f"\n{name} 训练失败: {str(e)}\n")
            traceback.print_exc()
    return results, X_test_3d, X_test_2d, y_test, y_scaler

# ================== Main program ==================
if __name__ == "__main__":
    # Results stay bound at module level so they can be inspected
    # interactively after the script finishes.
    results, X_test_3d, X_test_2d, y_test, y_scaler = train_all_models()

    # Report the test-set metrics, which were previously computed but never shown.
    if results:
        summary = pd.DataFrame.from_dict(
            {key: {metric: res[metric] for metric in ('MAE', 'RAE', 'R2', 'Time')}
             for key, res in results.items()},
            orient='index',
        )
        print(summary.sort_values('R2', ascending=False).to_string())
    else:
        print("No model finished training; nothing to report.")


