# =============================================================================
# FIXED & ENHANCED: Integrate ILO Fatal Injuries (by sex & migrant status)
# Fixes: groupby syntax + comprehensive country mapping
# Output: merged_fatal_orbis_roll3_wdi.csv + fatal_overlap_diagnostics.csv
# =============================================================================

import pandas as pd
import numpy as np
import os
from pathlib import Path

# ----------------------- 1. SET PATH -----------------------
# Every input/output file name below is relative, so it resolves against this
# working directory. Set the ILO_DATA_DIR environment variable to run the
# script elsewhere; without it the original hard-coded folder is used.
data_dir = os.environ.get("ILO_DATA_DIR", r"C:\Users\amirk\Downloads")
os.chdir(data_dir)
print("Working directory:", os.getcwd())

# ----------------------- 2. LOAD ILO FATAL DATA -----------------------
# Read the ILO fatal-injury export from the working directory. Fail early with a
# clear message if it is absent. If the default decoding fails, retry the read
# with latin-1.
fatal_path = "ilo_fatal_inj.csv"
fatal_file = Path(fatal_path)
if not fatal_file.exists():
    raise FileNotFoundError(f"{fatal_path} not found!")

read_opts = {"low_memory": False}
try:
    fatal = pd.read_csv(fatal_path, **read_opts)
except UnicodeDecodeError:
    fatal = pd.read_csv(fatal_path, encoding="latin-1", **read_opts)

n_raw_rows = len(fatal)
first_cols = list(fatal.columns[:12])
print(f"Raw fatal rows: {n_raw_rows:,}")
print("Columns:", first_cols)

# ----------------------- 3. RENAME & CLEAN -----------------------
# Map the labelled export columns to short analysis names and keep only those.
fatal = fatal.rename(columns={
    "ref_area.label": "country",
    "sex.label": "sex",
    "classif1.label": "migrant_status",
    "time": "year",
    "obs_value": "fatal_count",
})

fatal = fatal[["country", "year", "sex", "migrant_status", "fatal_count"]].copy()

# Coerce the value column to numbers. If it were read as object/str, the later
# groupby-sum would concatenate strings and the log transform would fail.
# Non-numeric entries become NaN and are counted so the loss is visible.
fatal_count_num = pd.to_numeric(fatal["fatal_count"], errors="coerce")
n_bad_counts = int((fatal_count_num.isna() & fatal["fatal_count"].notna()).sum())
if n_bad_counts:
    print(f"Non-numeric fatal_count values set to NaN: {n_bad_counts:,}")
fatal["fatal_count"] = fatal_count_num

# Years: coerce, drop unparseable, cast to int, restrict to the 2005-2024 window.
fatal["year"] = pd.to_numeric(fatal["year"], errors="coerce")
fatal = fatal.dropna(subset=["year"])
fatal["year"] = fatal["year"].astype(int)
fatal = fatal[fatal["year"].between(2005, 2024)]

# ----------------------- 4. COMPREHENSIVE COUNTRY → ISO3 MAPPING -----------------------
# Exact-match lookup from ILO country labels (ref_area.label) to ISO3 codes.
# Some countries appear under more than one spelling/label variant. The table is
# NOT exhaustive: labels missing here are reported as unmapped and dropped in
# the next step, so extend it when that report lists countries you need.
country_to_iso3 = {
    # Europe & Central Asia
    "Albania": "ALB", "Andorra": "AND", "Austria": "AUT", "Belgium": "BEL",
    "Bosnia and Herzegovina": "BIH", "Bulgaria": "BGR", "Croatia": "HRV",
    "Cyprus": "CYP", "Czechia": "CZE", "Czech Republic": "CZE",
    "Denmark": "DNK", "Estonia": "EST",
    "Finland": "FIN", "France": "FRA", "Germany": "DEU", "Greece": "GRC",
    "Hungary": "HUN", "Iceland": "ISL", "Ireland": "IRL", "Italy": "ITA",
    "Latvia": "LVA", "Liechtenstein": "LIE", "Lithuania": "LTU",
    "Luxembourg": "LUX", "Malta": "MLT",
    "Montenegro": "MNE", "Netherlands": "NLD",
    "Netherlands (Kingdom of the)": "NLD", "North Macedonia": "MKD",
    "Norway": "NOR", "Poland": "POL", "Portugal": "PRT", "Romania": "ROU",
    "Serbia": "SRB", "Slovakia": "SVK", "Slovenia": "SVN", "Spain": "ESP",
    "Sweden": "SWE", "Switzerland": "CHE",
    "United Kingdom of Great Britain and Northern Ireland": "GBR",
    "United Kingdom": "GBR",
    "Türkiye": "TUR", "Turkiye": "TUR", "Turkey": "TUR",
    "Russian Federation": "RUS", "Ukraine": "UKR", "Belarus": "BLR",
    "Republic of Moldova": "MDA", "Georgia": "GEO", "Armenia": "ARM",
    "Azerbaijan": "AZE", "Kazakhstan": "KAZ", "Kyrgyzstan": "KGZ",
    "Tajikistan": "TJK",
    # Americas
    "Argentina": "ARG", "Brazil": "BRA", "Canada": "CAN", "Chile": "CHL",
    "Colombia": "COL", "Costa Rica": "CRI", "Mexico": "MEX", "Peru": "PER",
    "United States of America": "USA", "United States": "USA",
    "Bolivia (Plurinational State of)": "BOL", "Ecuador": "ECU",
    "Paraguay": "PRY", "Uruguay": "URY",
    "Venezuela (Bolivarian Republic of)": "VEN", "Guatemala": "GTM",
    "Honduras": "HND", "El Salvador": "SLV", "Nicaragua": "NIC",
    "Panama": "PAN", "Dominican Republic": "DOM", "Jamaica": "JAM",
    "Trinidad and Tobago": "TTO",
    # Asia & Pacific
    "Australia": "AUS", "China": "CHN", "India": "IND", "Indonesia": "IDN",
    "Japan": "JPN", "Republic of Korea": "KOR", "Korea, Republic of": "KOR",
    "Malaysia": "MYS",
    "New Zealand": "NZL", "Pakistan": "PAK", "Philippines": "PHL",
    "Singapore": "SGP", "Thailand": "THA", "Viet Nam": "VNM",
    "Hong Kong, China": "HKG", "Macau, China": "MAC", "Taiwan, China": "TWN",
    "Sri Lanka": "LKA", "Bangladesh": "BGD", "Nepal": "NPL",
    "Myanmar": "MMR", "Cambodia": "KHM",
    "Lao People's Democratic Republic": "LAO", "Mongolia": "MNG",
    "Brunei Darussalam": "BRN", "Iran (Islamic Republic of)": "IRN",
    "Fiji": "FJI",
    # Middle East & Africa
    "Egypt": "EGY", "Israel": "ISR", "Jordan": "JOR", "Qatar": "QAT",
    "Saudi Arabia": "SAU", "South Africa": "ZAF", "United Arab Emirates": "ARE",
    "Zimbabwe": "ZWE", "Uzbekistan": "UZB",
    "Bahrain": "BHR", "Kuwait": "KWT", "Oman": "OMN", "Lebanon": "LBN",
    "Morocco": "MAR", "Tunisia": "TUN", "Algeria": "DZA", "Kenya": "KEN",
    "Nigeria": "NGA", "Ghana": "GHA", "Ethiopia": "ETH",
    "United Republic of Tanzania": "TZA", "Uganda": "UGA",
    "Mauritius": "MUS", "Botswana": "BWA", "Namibia": "NAM",
    "Zambia": "ZMB", "Malawi": "MWI", "Mozambique": "MOZ",
    "Senegal": "SEN", "Côte d'Ivoire": "CIV", "Cameroon": "CMR",
    "Rwanda": "RWA", "Madagascar": "MDG", "Seychelles": "SYC",
    # Others
    "Netherlands Antilles": "ANT", "Occupied Palestinian Territory": "PSE",
}

# Exact-match lookup; strip surrounding whitespace first so padded labels in the
# CSV still match. Missing labels (NaN -> "nan") simply map to NaN.
fatal["iso3"] = fatal["country"].astype(str).str.strip().map(country_to_iso3)

# Report mapping success. NaN labels are excluded and values are stringified
# before sorting: sorting a mix of str and float would raise TypeError.
unmapped_mask = fatal["iso3"].isna()
unmapped = sorted(map(str, fatal.loc[unmapped_mask, "country"].dropna().unique()))
more = "..." if len(unmapped) > 20 else ""
print(f"Unmapped countries ({len(unmapped)}): {unmapped[:20]}{more}")
print(f"Rows without iso3 after mapping: {unmapped_mask.sum():,}")

# Drop rows whose country could not be mapped; they cannot be merged on iso3.
fatal = fatal.dropna(subset=["iso3"])

# ----------------------- 5. EXTRACT MIGRANT-FEMALE FATAL COUNT -----------------------
# The migrant filter expects a prefixed label ("Migrant status: Migrants"). The
# sex label may carry a similar "Sex: " prefix depending on the export, so
# accept both "Female" and "Sex: Female".
sex_clean = fatal["sex"].astype(str).str.replace(r"^Sex:\s*", "", regex=True).str.strip()
status_clean = fatal["migrant_status"].astype(str).str.strip()

fatal_mig_fem = fatal.loc[
    (status_clean == "Migrant status: Migrants") & (sex_clean == "Female"),
    ["iso3", "year", "fatal_count"],
].rename(columns={"fatal_count": "fatal_mig_fem"})

# Defensive: summing an object/str column would concatenate instead of add.
fatal_mig_fem["fatal_mig_fem"] = pd.to_numeric(fatal_mig_fem["fatal_mig_fem"], errors="coerce")

# Several rows can fall into one iso3-year cell. For example, "Türkiye" and
# "Turkiye" both map to TUR, and the export may also carry more than one source
# per country-year (not checked here). Those rows are summed, so flag it.
cell_sizes = fatal_mig_fem.groupby(["iso3", "year"]).size()
n_multi = int((cell_sizes > 1).sum())
if n_multi:
    print(f"WARNING: {n_multi:,} iso3-year cells have multiple migrant-female rows; summing them.")

# min_count=1 keeps a cell NaN when all of its counts are missing, instead of
# the default sum() behaviour of reporting a spurious 0.
fatal_mig_fem = fatal_mig_fem.groupby(["iso3", "year"], as_index=False)["fatal_mig_fem"].sum(min_count=1)

# log(1 + x) keeps zero counts finite.
fatal_mig_fem["ln_fatal_mig_fem"] = np.log1p(fatal_mig_fem["fatal_mig_fem"])

print(f"Migrant-female fatal observations: {len(fatal_mig_fem):,}")
print("Sample:")
print(fatal_mig_fem.head(10))

# ----------------------- 6. LOAD MASTER DATA -----------------------
master_path = "merged_ilo_orbis_roll3_wdi.csv"
master = pd.read_csv(master_path, low_memory=False)
print(f"Master panel rows: {len(master):,}")

# Normalise the merge keys. iso3 is stripped and upper-cased to match the codes
# produced from country_to_iso3. Missing codes stay NaN: astype(str) alone would
# turn them into the literal string "NAN", which would then be counted as a
# country in the diagnostics.
iso3_raw = master["iso3"]
master["iso3"] = iso3_raw.astype(str).str.strip().str.upper().where(iso3_raw.notna())
master["year"] = pd.to_numeric(master["year"], errors="coerce")

# ----------------------- 7. MERGE -----------------------
# Left join on country-year: every master row survives; rows with no matching
# ILO migrant-female figure get NaN in the fatal columns.
join_keys = ["iso3", "year"]
merged = pd.merge(master, fatal_mig_fem, how="left", on=join_keys)

n_with_fatal = merged["fatal_mig_fem"].notna().sum()
print(f"Final merged rows: {len(merged):,}")
print(f"Non-missing fatal_mig_fem: {n_with_fatal:,}")

# ----------------------- 8. DIAGNOSTICS -----------------------
# Countries/years that actually received a fatal value in the merged panel.
has_fatal = merged["fatal_mig_fem"].notna()
countries_with_fatal = sorted(merged.loc[has_fatal, "iso3"].unique())
years_with_fatal = sorted(merged.loc[has_fatal, "year"].unique())

# Overlap is measured against the ILO source itself. countries_with_fatal comes
# from a LEFT merge onto master, so intersecting it with master would always
# just return len(countries_with_fatal) and say nothing about coverage.
master_countries = set(master["iso3"].dropna())
source_countries = set(fatal_mig_fem["iso3"])

diag = {
    "fatal_countries": len(countries_with_fatal),
    "fatal_years_min": min(years_with_fatal) if years_with_fatal else None,
    "fatal_years_max": max(years_with_fatal) if years_with_fatal else None,
    "fatal_observations": int(has_fatal.sum()),
    "master_countries": master["iso3"].nunique(),
    "common_countries": len(master_countries & source_countries),
    "merged_rows": len(merged),
    # Number of countries with any migrant-female fatal data in the ILO file.
    "fatal_source_countries": len(source_countries),
}

pd.DataFrame([diag]).to_csv("fatal_overlap_diagnostics.csv", index=False)
print("Diagnostics saved → fatal_overlap_diagnostics.csv")

# ----------------------- 9. SAVE FINAL -----------------------
# Write the merged panel (without the pandas index) and print a short summary.
merged.to_csv("merged_fatal_orbis_roll3_wdi.csv", index=False)
for summary_line in (
    "FINAL DATASET SAVED → merged_fatal_orbis_roll3_wdi.csv",
    f"Countries with data: {countries_with_fatal}",
    "Ready for Stata!",
):
    print(summary_line)
