import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import scipy.stats as stats

# Publication styling applied to every figure this script produces: a Times
# New Roman serif face for the English figure. plot_chart switches to a
# CJK-capable sans-serif font list when it draws the Chinese figure.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    # Draw negative numbers with an ASCII hyphen-minus instead of the Unicode
    # minus sign (U+2212), which some CJK fonts do not provide.
    'axes.unicode_minus': False,
})

def _locate_data_file():
    """Return the first existing candidate path for the data workbook, or None.

    Candidates are checked relative to the current working directory first,
    then relative to this script's directory.
    """
    script_dir = os.path.dirname(__file__)
    candidates = [
        "hawthorne_plot_data.xlsx",
        "db/hawthorne_plot_data.xlsx",
        os.path.join(script_dir, "hawthorne_plot_data.xlsx"),
        os.path.join(script_dir, "db/hawthorne_plot_data.xlsx"),
    ]
    return next((path for path in candidates if os.path.exists(path)), None)


def _add_derived_scores(df):
    """Return a copy of ``df`` sorted by lesson, with depth and smoothed score columns.

    - 'Total Cognitive Depth Score' is kept if the workbook already provides it;
      otherwise it is taken from 'Total Interaction Score', which is already
      weighted (high-level +20, general +10) and so serves as the depth score.
    - 'Smoothed Score' is a centred 3-lesson moving average of the depth score.
      Rows where the window is incomplete (the first and last lesson) keep
      their raw value.
    """
    # rolling() operates on row order, so sort by lesson first; otherwise an
    # unsorted sheet would average non-adjacent lessons.
    df = df.sort_values('Lesson Index', kind='mergesort').reset_index(drop=True)
    if 'Total Cognitive Depth Score' not in df.columns:
        df['Total Cognitive Depth Score'] = df['Total Interaction Score']
    raw = df['Total Cognitive Depth Score']
    df['Smoothed Score'] = raw.rolling(window=3, center=True).mean().fillna(raw)
    return df


def reproduce_hawthorne():
    """Reproduce the Hawthorne-effect trend figure (English and Chinese versions).

    Loads ``hawthorne_plot_data.xlsx``, derives the Total Cognitive Depth Score
    and its 3-lesson moving average, and reports two statistics: the linear
    trend across lessons, and the step change between the Adaptation Phase
    (lessons <= 6) and the Stabilisation Phase (lessons >= 20). It then writes
    both figures into ``reproduced_figures/``.

    Input problems (missing file, missing columns, empty phases, zero
    baseline) are reported on stdout, and the function returns without plotting.
    """
    data_path = _locate_data_file()
    if not data_path:
        print("Error: Could not find 'hawthorne_plot_data.xlsx'. Please place it next to this script.")
        return

    print(f"Loading data from: {data_path}")
    df = pd.read_excel(data_path)

    # Section 5.4.2 design: analyse the Total Cognitive Depth Score rather than
    # the number of active students.
    target_metric = 'Total Cognitive Depth Score'
    columns = set(df.columns)
    if 'Lesson Index' not in columns or not columns & {target_metric, 'Total Interaction Score'}:
        print(f"Error: the data needs a 'Lesson Index' column and either "
              f"'{target_metric}' or 'Total Interaction Score'.")
        return

    df = _add_derived_scores(df)
    lessons = df['Lesson Index']
    scores = df[target_metric]

    # The step change is computed on RAW scores; the smoothed series is only
    # drawn for visual clarity. Adaptation = lessons 1-6, Stabilisation =
    # lessons 20 onwards. The Intervention Phase (7-19) in between is not averaged.
    phase1_end = 6
    phase2_start = 20
    phase1_scores = scores[lessons <= phase1_end]
    phase2_scores = scores[lessons >= phase2_start]
    if phase1_scores.empty or phase2_scores.empty:
        print(f"Error: need lessons <= {phase1_end} and >= {phase2_start} "
              f"to compute the step change.")
        return

    mean_phase1 = phase1_scores.mean()
    mean_phase2 = phase2_scores.mean()
    if mean_phase1 == 0:
        # Relative growth from a zero baseline is undefined (it would print inf%).
        print("Error: the Adaptation Phase mean is 0, so the growth rate is undefined.")
        return
    growth_rate = (mean_phase2 - mean_phase1) / mean_phase1 * 100

    print(f"Verified Growth Rate ({target_metric}): {growth_rate:.2f}% "
          f"(Lessons {lessons.min()}-{phase1_end} vs {phase2_start}-{lessons.max()})")

    # Longitudinal trend: does the score rise with lesson index? Both phases are
    # non-empty here, so linregress gets at least two distinct x values.
    slope, _intercept, r_value, p_value, _std_err = stats.linregress(lessons, scores)
    print(f"Linear trend ({target_metric}): slope = {slope:.3f} per lesson, "
          f"r = {r_value:.3f}, p = {p_value:.4g}")

    output_dir = "reproduced_figures"
    os.makedirs(output_dir, exist_ok=True)

    # English figure first, then the Chinese one.
    for lang in ('en', 'cn'):
        plot_chart(df, mean_phase1, mean_phase2, growth_rate,
                   phase1_end, phase2_start, output_dir, lang=lang)

    print(f"Reproduced charts saved to: {os.path.abspath(output_dir)}")

# Per-language text, output filename and rcParams overrides used by plot_chart.
_CHART_CONFIG = {
    'en': {
        'rc': {},
        'ylabel': "Total Substantive Engagement Score",
        'xlabel': "Lesson Sequence",
        'title': "",
        'phase_labels': ("Adaptation Phase", "Intervention Phase", "Stabilisation Phase"),
        'filename': "Fig3_Hawthorne_Trend.png",
        'smoothed_label': "3-Lesson Moving Average",
        'raw_label': "Raw Data",
    },
    'cn': {
        # Sans-serif font list containing CJK fonts for the Chinese labels.
        'rc': {
            'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'Arial'],
            'font.family': 'sans-serif',
        },
        'ylabel': "深度认知总分数 (分)",
        'xlabel': "课时序列",
        'title': "",
        'phase_labels': ("适应期", "干预实施期", "稳定期"),
        'filename': "Fig3_Hawthorne_Trend_cn.png",
        'smoothed_label': "3课时移动平均",
        'raw_label': "原始数据",
    },
}


def plot_chart(df, mean1, mean2, growth_rate, p1_end, p2_start, output_dir, lang='en'):
    """Draw the lesson-by-lesson trend chart with the phase step change and save it as PNG.

    Args:
        df: DataFrame with 'Lesson Index', 'Total Cognitive Depth Score' (raw)
            and 'Smoothed Score' columns.
        mean1: Mean raw score of the Adaptation Phase (lessons up to ``p1_end``).
        mean2: Mean raw score of the Stabilisation Phase (from ``p2_start``).
        growth_rate: Percentage change from ``mean1`` to ``mean2``; drawn with its sign.
        p1_end: Last lesson of the Adaptation Phase.
        p2_start: First lesson of the Stabilisation Phase.
        output_dir: Existing directory the PNG is written into.
        lang: 'en' for English labels; any other value produces the Chinese figure.

    Font overrides for the Chinese figure are applied through ``plt.rc_context``,
    so the global rcParams are restored afterwards and do not leak into later
    figures.
    """
    cfg = _CHART_CONFIG['en' if lang == 'en' else 'cn']
    adaptation_label, intervention_label, stabilisation_label = cfg['phase_labels']
    accent = '#d7191c'

    first_lesson = df['Lesson Index'].min()
    last_lesson = df['Lesson Index'].max()
    adaptation_x = (first_lesson + p1_end) / 2
    intervention_x = (p1_end + p2_start) / 2
    stabilisation_x = (p2_start + last_lesson) / 2

    with plt.rc_context(cfg['rc']):
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            # Raw series, drawn faintly so readers can see what was smoothed.
            sns.lineplot(data=df, x='Lesson Index', y='Total Cognitive Depth Score', ax=ax,
                         color='lightgray', alpha=0.5, linewidth=1,
                         label=cfg['raw_label'], zorder=1)
            # The 3-lesson moving average carries the visual trend.
            sns.lineplot(data=df, x='Lesson Index', y='Smoothed Score', ax=ax,
                         color='#2c7bb6', linewidth=2.5, marker='o', markersize=6,
                         label=cfg['smoothed_label'], zorder=2)

            # Step change: a horizontal segment at each end phase's raw mean.
            ax.hlines(y=mean1, xmin=first_lesson, xmax=p1_end, colors=accent,
                      linestyles='-', linewidth=3, zorder=3)
            ax.hlines(y=mean2, xmin=p2_start, xmax=last_lesson, colors=accent,
                      linestyles='-', linewidth=3, zorder=3)
            ax.text(adaptation_x, mean1 + 2, f"Mean = {mean1:.1f}", color=accent,
                    ha='center', fontweight='bold', fontsize=12)
            ax.text(stabilisation_x, mean2 + 2, f"Mean = {mean2:.1f}", color=accent,
                    ha='center', fontweight='bold', fontsize=12)

            # Dotted connector between the two means, labelled with the growth rate.
            ax.plot([p1_end, p2_start], [mean1, mean2], color=accent, linestyle=':',
                    linewidth=2, zorder=3)
            # The '+' format flag emits the sign itself, so a decline reads "-x%"
            # rather than "+-x%". The +50 vertical offset is in data units.
            ax.text(intervention_x, (mean1 + mean2) / 2 + 50, f"{growth_rate:+.1f}%",
                    ha='center', va='bottom', fontsize=14, fontweight='bold', color=accent,
                    bbox=dict(boxstyle="round,pad=0.3", fc="white", ec=accent, alpha=0.9))

            if cfg['title']:
                ax.set_title(cfg['title'], fontsize=14, fontweight='bold')
            ax.set_xlabel(cfg['xlabel'], fontsize=14, fontweight='bold')
            ax.set_ylabel(cfg['ylabel'], fontsize=14, fontweight='bold')

            # Borderless legend with a transparent background.
            legend = ax.legend(loc='upper left', fontsize=12)
            legend.get_frame().set_alpha(0.0)
            legend.get_frame().set_linewidth(0.0)

            # Minimal grid: faint horizontal lines only.
            ax.yaxis.grid(True, linestyle='--', alpha=0.3)
            ax.xaxis.grid(False)

            # Phase separators sit halfway between adjacent lessons.
            ax.axvline(x=p1_end + 0.5, color='gray', linestyle=':', alpha=0.4)
            ax.axvline(x=p2_start - 0.5, color='gray', linestyle=':', alpha=0.4)

            # Guarantee headroom (top of y-axis at least 65) for the phase labels.
            y_min, y_max = ax.get_ylim()
            if y_max < 65:
                ax.set_ylim(y_min, 65)
                y_max = 65
            label_y = y_max * 0.92
            for x, text in ((adaptation_x, adaptation_label),
                            (intervention_x, intervention_label),
                            (stabilisation_x, stabilisation_label)):
                ax.text(x, label_y, text, ha='center', color='black',
                        fontsize=12, fontweight='bold')

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            # Save inside the rc context: fonts are resolved at draw time.
            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, cfg['filename']), dpi=300)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
    print(f"Generated {cfg['filename']}")

if __name__ == "__main__":
    # Run the full reproduction (load data, compute statistics, save both
    # figures) only when executed as a script, not when imported.
    reproduce_hawthorne()
