import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

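# Schedule 1: a sample manufacturing schedule.
# Each non-blank row holds five whitespace-separated columns, in the order
# parse_schedule() reads them: job id, operation number, machine id,
# start time, end time.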
RAW_SCHEDULE_DATA_1 = """

J1      1       7               0.00            12.00
J2      1       5               0.00            34.00
J4      1       8               0.00            36.00
J5      1       1               0.00            53.00
J6      1       9               0.00            14.00
J7      1       2               0.00            18.00
J8      1       4               0.00            18.00
J1      2       10              12.00           56.00
J3      1       4               18.00           72.00
J6      2       2               18.00           28.00
J7      2       7               18.00           95.00
J2      2       2               34.00           61.00
J6      3       5               34.00           60.00
J4      2       9               36.00           54.00
J4      3       9               54.00           68.00
J1      3       8               56.00           96.00
J8      2       10              56.00           71.00
J6      4       1               60.00           107.00
J2      3       3               61.00           93.00
J4      4       6               68.00           90.00
J5      2       9               68.00           105.00
J8      3       2               71.00           82.00
J4      5       5               90.00           136.00
J8      4       6               90.00           112.00
J2      4       4               93.00           108.00
J3      2       7               95.00           106.00
J7      3       2               95.00           111.00
J1      4       3               96.00           136.00
J5      3       10              105.00          119.00
J3      3       9               106.00          117.00
J6      5       8               107.00          151.00
J2      5       4               108.00          158.00
J7      4       7               111.00          144.00
J5      4       2               119.00          142.00
J3      4       3               136.00          149.00
J1      5       10              139.00          175.00
J5      5       2               142.00          165.00
J7      5       5               144.00          175.00
J8      5       1               148.00          175.00
J3      5       6               149.00          162.00"""

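# Schedule 2: a different schedule to compare against Schedule 1.
# It uses the same five columns (job id, operation number, machine id,
# start time, end time) but contains jobs J1-J10 rather than J1-J8.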
RAW_SCHEDULE_DATA_2 = """J1      1       6               0.00            27.00
J3      1       4               0.00            13.00
J4      1       5               0.00            21.00
J8      1       7               0.00            14.00
J9      1       10              0.00            23.00
J10     1       2               0.00            33.00
J5      1       3               3.00            27.00
J3      2       1               13.00           93.00
J8      2       4               14.00           66.00
J7      1       7               16.00           70.00
J2      1       5               21.00           46.00
J1      2       9               27.00           54.00
J5      2       8               27.00           51.00
J4      2       3               31.00           85.00
J6      1       2               33.00           55.00
J10     2       10              33.00           44.00
J9      2       10              44.00           59.00
J5      3       5               51.00           93.00
J1      3       6               54.00           68.00
J2      2       2               55.00           65.00
J10     3       10              59.00           80.00
J6      2       8               60.00           72.00
J9      3       2               66.00           78.00
J2      3       7               70.00           107.00
J7      2       9               70.00           84.00
J1      4       8               72.00           91.00
J8      3       4               73.00           109.00
J6      3       2               78.00           96.00
J10     4       6               80.00           109.00
J4      3       9               85.00           135.00
J7      3       8               91.00           123.00
J3      3       5               93.00           104.00
J6      4       3               96.00           111.00
J9      4       2               96.00           142.00
J3      4       5               104.00          149.00
J2      4       7               107.00          135.00
J10     5       4               109.00          123.00
J5      4       6               110.00          149.00
J6      5       3               111.00          154.00
J8      4       1               111.00          154.00
J7      4       4               123.00          149.00
J1      5       8               127.00          168.00
J4      4       7               135.00          150.00
J9      5       2               142.00          154.00
J3      5       6               149.00          168.00
J7      5       4               149.00          168.00
J4      5       5               150.00          168.00
J2      5       3               154.00          168.00
J5      5       1               154.00          165.00
J8      5       2               154.00          168.00"""


def parse_schedule(raw_data):
    """Turn whitespace-separated schedule text into a list of task dicts.

    Each usable row must contain exactly five columns: job id, operation
    number, machine id, start time and end time. Rows with any other number
    of columns (including blank rows) are ignored.

    Returns:
        A list of dicts with the keys ``job_id`` (str), ``op_id`` (int),
        ``machine_id`` (int), ``start_time`` (float), ``end_time`` (float)
        and ``duration`` (``end_time - start_time``), in input order.

    Raises:
        ValueError: if a five-column row has a non-numeric operation,
            machine, start or end field.
    """
    parsed = []
    for row in raw_data.strip().split('\n'):
        fields = row.split()
        if len(fields) != 5:
            # Not a task row (blank line, stray text, ...): skip it.
            continue
        job, op, machine, begin, finish = fields
        record = {
            'job_id': job,
            'op_id': int(op),
            'machine_id': int(machine),
            'start_time': float(begin),
            'end_time': float(finish),
        }
        record['duration'] = record['end_time'] - record['start_time']
        parsed.append(record)
    return parsed

def _task_signatures_by_start(schedule):
    """Map each start time to the sorted (job_id, op_id, machine_id) tuples starting then."""
    grouped = {}
    for task in schedule:
        signature = (task['job_id'], task['op_id'], task['machine_id'])
        grouped.setdefault(task['start_time'], []).append(signature)
    # Sorting gives a canonical form, so the order of tasks in the input
    # does not influence the comparison.
    for signatures in grouped.values():
        signatures.sort()
    return grouped


def find_identical_timestamps(schedule1, schedule2):
    """Find the start times at which both schedules start exactly the same tasks.

    Tasks are grouped by ``start_time``. Two groups are considered identical
    when they contain the same (job_id, op_id, machine_id) combinations;
    end times and durations are not compared.

    Args:
        schedule1: Iterable of task dicts with the keys ``job_id``,
            ``op_id``, ``machine_id`` and ``start_time``.
        schedule2: Second iterable of task dicts with the same keys.

    Returns:
        A set of the matching start times (empty if there are none).
    """
    starts1 = _task_signatures_by_start(schedule1)
    starts2 = _task_signatures_by_start(schedule2)

    # A start time that occurs in only one schedule can never match (the other
    # side has no tasks there), so only the shared start times are compared.
    return {
        time
        for time in starts1.keys() & starts2.keys()
        if starts1[time] == starts2[time]
    }

def plot_gantt_with_filtered_tasks(tasks, full_tasks, title, ax, show_y_labels, legend_patches):
    """Draw a Gantt chart of ``tasks`` on ``ax``, laid out for the full schedule.

    Only ``tasks`` are drawn as bars, but the machine rows and the time-axis
    range are derived from ``full_tasks`` so that a filtered chart keeps the
    layout of the complete schedule.

    Args:
        tasks: Task dicts to draw, with the keys ``job_id``, ``op_id``,
            ``machine_id``, ``start_time`` and ``duration``.
        full_tasks: Task dicts of the complete schedule; their ``machine_id``
            values define the rows and their ``end_time`` values the x-range.
        title: Title of the axes.
        ax: Matplotlib axes to draw on.
        show_y_labels: If true, ticks are labelled ``Machine <id>``;
            otherwise the tick labels are blank.
        legend_patches: Legend handles to attach to ``ax``; no legend is
            drawn when this is empty or ``None``.

    Raises:
        KeyError: if a task in ``tasks`` uses a machine that does not occur
            in ``full_tasks``.
    """
    machines = sorted({t['machine_id'] for t in full_tasks})
    row_of = {machine: row for row, machine in enumerate(machines)}

    makespan = max((t['end_time'] for t in full_tasks), default=0)

    # Fixed palette for the job ids J1..J10.
    cmap = plt.get_cmap('Paired')
    job_colors = {f'J{i + 1}': cmap(i) for i in range(10)}
    # Jobs outside J1..J10 are still drawn (in a neutral colour) instead of
    # aborting the whole chart with a KeyError.
    unknown_job_color = 'lightgray'

    for task in tasks:
        row = row_of[task['machine_id']]
        start = task['start_time']
        duration = task['duration']
        color = job_colors.get(task['job_id'], unknown_job_color)

        ax.barh(row, duration, left=start, height=0.8, color=color, edgecolor='black', alpha=0.7)
        # Label each bar "<job>-<operation>" at its centre.
        ax.text(start + duration / 2, row, f'{task["job_id"]}-{task["op_id"]}', ha='center', va='center', fontsize=8, color='black')

    ax.set_title(title)
    ax.set_xlabel('Time')
    ax.set_yticks(range(len(machines)))
    if show_y_labels:
        ax.set_yticklabels([f'Machine {m}' for m in machines])
    else:
        ax.set_yticklabels([])
    ax.set_xlim(0, makespan + 1)
    ax.set_ylim(-0.5, len(machines) - 0.5)
    ax.grid(True, axis='x', linestyle='--', alpha=0.6)

    if legend_patches:
        # Anchor the legend just outside the right edge of the axes.
        ax.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

def main():
    """Plot, side by side, the tasks of both schedules that start at identical timestamps.

    Parses the two embedded schedules, determines the start times at which
    both schedules start the same tasks, and draws one Gantt chart per
    schedule containing only the tasks that start at those times. If there
    is no such start time, a message is printed and nothing is plotted.
    """
    schedule1_tasks = parse_schedule(RAW_SCHEDULE_DATA_1)
    schedule2_tasks = parse_schedule(RAW_SCHEDULE_DATA_2)

    # Find the timestamps where the schedules are identical
    identical_timestamps = find_identical_timestamps(schedule1_tasks, schedule2_tasks)

    if not identical_timestamps:
        print("No identical timestamps were found between the two schedules.")
        return

    # Filter tasks to keep only those at identical timestamps
    filtered_tasks_1 = [task for task in schedule1_tasks if task['start_time'] in identical_timestamps]
    filtered_tasks_2 = [task for task in schedule2_tasks if task['start_time'] in identical_timestamps]

    # NOTE(review): with sharey=True the y ticks/limits set for the second
    # chart presumably apply to the first one as well, so this layout assumes
    # both schedules use the same set of machines - confirm if that can differ.
    _, axes = plt.subplots(1, 2, figsize=(18, 8), sharey=True)

    # Create the legend patches. The mapping mirrors the job -> colour mapping
    # used inside plot_gantt_with_filtered_tasks.
    cmap = plt.get_cmap('Paired')
    job_colors = {f'J{i+1}': cmap(i % 10) for i in range(10)}
    # Iterate in insertion order (J1, J2, ..., J10); sorting the keys as
    # strings would put J10 directly after J1.
    legend_patches = [mpatches.Patch(color=color, label=job) for job, color in job_colors.items()]

    # Plot the first chart
    plot_gantt_with_filtered_tasks(filtered_tasks_1, schedule1_tasks, 'Schedule 1 (Only Identical)', axes[0], show_y_labels=True, legend_patches=None)

    # Plot the second chart (carries the legend shared by both charts)
    plot_gantt_with_filtered_tasks(filtered_tasks_2, schedule2_tasks, 'Schedule 2 (Only Identical)', axes[1], show_y_labels=True, legend_patches=legend_patches)

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()