# eco_merge_pipeline_all_in_one_rdfix_v2.py
# Robust end-to-end: Orbis -> firm-year -> country-year; ILO; WDI
# R&D fix: derive rd_ratio from absolute R&D expenses when present; dynamic aggregation so missing blocks don't error

import os, re, warnings
from typing import List, Tuple
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)

# ====== USER PATHS (EDIT THESE) =========================================================
ORBISEXCEL = r"C:\Users\amirk\Downloads\Export3110-2.xlsx"
ILOFILE    = r"C:\Users\amirk\Downloads\ilo_gender_agg.csv"          # csv/xlsx
WDIFILE    = r"C:\Users\amirk\Downloads\worldbankindicators.xlsx"    # csv/xlsx
OUTDIR     = r"C:\Users\amirk\Downloads"
# ========================================================================================

# ====== FLAGS ===========================================================================
# JOIN_MODE: 'inner' keeps only country-years present in BOTH ILO and Orbis;
# 'left' keeps all Orbis country-years and leaves ILO columns NaN where missing.
JOIN_MODE = 'inner'   # 'inner' or 'left'
IMPUTE_ILO = False    # fill small ILO gaps (limit=2) + 1-year carry
IMPUTE_RD  = True     # within-firm interpolation for rd_ratio (limit=2), after expense backfill
# YEARS_MIN/MAX bound the firm-year panel; MERGE_YMIN/MAX bound the merged output window.
YEARS_MIN, YEARS_MAX = 2005, 2025
MERGE_YMIN, MERGE_YMAX = 2005, 2024
# ========================================================================================

# ISO 3166-1 alpha-2 -> alpha-3 fallback map used by normalize_iso3.
# Note: 'UK' is a common non-standard alias and is deliberately mapped to GBR.
ISO2_TO_ISO3 = {
    'US':'USA','CA':'CAN','MX':'MEX','GB':'GBR','UK':'GBR','IE':'IRL','FR':'FRA','DE':'DEU','IT':'ITA','ES':'ESP',
    'PT':'PRT','NL':'NLD','BE':'BEL','LU':'LUX','CH':'CHE','AT':'AUT','SE':'SWE','NO':'NOR','DK':'DNK','FI':'FIN',
    'IS':'ISL','EE':'EST','LV':'LVA','LT':'LTU','PL':'POL','CZ':'CZE','SK':'SVK','HU':'HUN','SI':'SVN','RO':'ROU',
    'BG':'BGR','HR':'HRV','GR':'GRC','CY':'CYP','MT':'MLT','RS':'SRB','BA':'BIH','MK':'MKD','AL':'ALB','UA':'UKR',
    'TR':'TUR','RU':'RUS','KZ':'KAZ','AM':'ARM','GE':'GEO','AZ':'AZE',
    'AU':'AUS','NZ':'NZL','JP':'JPN','KR':'KOR','CN':'CHN','HK':'HKG','MO':'MAC','SG':'SGP','MY':'MYS','ID':'IDN',
    'PH':'PHL','TH':'THA','VN':'VNM','KH':'KHM','LA':'LAO','MM':'MMR','PG':'PNG','FJ':'FJI',
    'IN':'IND','PK':'PAK','BD':'BGD','LK':'LKA','NP':'NPL',
    'AE':'ARE','SA':'SAU','QA':'QAT','KW':'KWT','OM':'OMN','BH':'BHR','JO':'JOR','LB':'LBN','EG':'EGY','MA':'MAR',
    'TN':'TUN','DZ':'DZA','IL':'ISR',
    'ZA':'ZAF','NG':'NGA','GH':'GHA','KE':'KEN','ET':'ETH','TZ':'TZA','UG':'UGA','SN':'SEN','CI':'CIV','CM':'CMR',
    'ZM':'ZMB','AR':'ARG','BR':'BRA','CL':'CHL','CO':'COL','PE':'PER','UY':'URY','PY':'PRY','EC':'ECU','BO':'BOL',
    'DO':'DOM','GT':'GTM','PA':'PAN','HN':'HND','CR':'CRI','SV':'SLV','VE':'VEN','CU':'CUB','JM':'JAM','BM':'BMU'
}
def normalize_iso3(x: str) -> str:
    """Best-effort conversion of a country code to ISO3.

    Non-strings are stringified (None becomes ""), whitespace is stripped
    and the code upper-cased.  A three-letter code is returned as-is, a
    two-letter code is looked up in ISO2_TO_ISO3 (falling back to itself),
    and anything else is returned unchanged after normalization.
    """
    if not isinstance(x, str):
        x = "" if x is None else str(x)
    code = x.strip().upper()
    if re.fullmatch(r'[A-Z]{3}', code):
        return code
    if re.fullmatch(r'[A-Z]{2}', code):
        return ISO2_TO_ISO3.get(code, code)
    return code

# Country display name -> ISO3, used when an ILO file carries names instead of
# codes (load_ilo_days raises on names missing from this map).
NAME_TO_ISO3 = {
    'Algeria':'DZA','Andorra':'AND','Argentina':'ARG','Armenia':'ARM','Australia':'AUS','Austria':'AUT',
    'Bahrain':'BHR','Belarus':'BLR','Belgium':'BEL','Belize':'BLZ','Bulgaria':'BGR','Chile':'CHL','Croatia':'HRV',
    'Cyprus':'CYP','Czechia':'CZE','Denmark':'DNK','Estonia':'EST','Finland':'FIN','France':'FRA','Georgia':'GEO',
    'Germany':'DEU','Greece':'GRC','Guyana':'GUY','Hungary':'HUN','Iceland':'ISL','Ireland':'IRL','Israel':'ISR',
    'Italy':'ITA','Kazakhstan':'KAZ','Kyrgyzstan':'KGZ','Latvia':'LVA','Lesotho':'LSO','Lithuania':'LTU',
    'Luxembourg':'LUX','Macao, China':'MAC','Malaysia':'MYS','Malta':'MLT','Mongolia':'MNG','Morocco':'MAR',
    'Netherlands':'NLD','New Zealand':'NZL','North Macedonia':'MKD','Norway':'NOR','Oman':'OMN','Pakistan':'PAK',
    'Poland':'POL','Portugal':'PRT','Qatar':'QAT','Republic of Korea':'KOR','Republic of Moldova':'MDA','Romania':'ROU',
    'Russian Federation':'RUS','Saudi Arabia':'SAU','Serbia':'SRB','Singapore':'SGP','Slovakia':'SVK','Slovenia':'SVN',
    'South Africa':'ZAF','Spain':'ESP','Sri Lanka':'LKA','Sweden':'SWE','Switzerland':'CHE','Thailand':'THA',
    'Timor-Leste':'TLS','Trinidad and Tobago':'TTO','Turkiye':'TUR','Türkiye':'TUR',
    'United Kingdom of Great Britain and Northern Ireland':'GBR','United States of America':'USA','Uzbekistan':'UZB',
    'Viet Nam':'VNM',"Lao People's Democratic Republic":'LAO','Dominican Republic':'DOM','Egypt':'EGY',
    'French Guiana':'GUF','Guadeloupe':'GLP','Montenegro':'MNE','Namibia':'NAM','Nicaragua':'NIC',
    'Occupied Palestinian Territory':'PSE','Panama':'PAN','Reunion':'REU','Bermuda':'BMU',
    'United Arab Emirates':'ARE','Saint Vincent and the Grenadines':'VCT','Senegal':'SEN'
}

# ISO3 -> coarse region bucket used by add_region; codes not listed here are
# labelled 'Other/Unknown' there.
REGION_MAP = {
    'AUT':'Europe','BEL':'Europe','BGR':'Europe','HRV':'Europe','CYP':'Europe','CZE':'Europe','DNK':'Europe','EST':'Europe',
    'FIN':'Europe','FRA':'Europe','DEU':'Europe','GRC':'Europe','HUN':'Europe','IRL':'Europe','ITA':'Europe','LVA':'Europe',
    'LTU':'Europe','LUX':'Europe','MLT':'Europe','NLD':'Europe','POL':'Europe','PRT':'Europe','ROU':'Europe','SVK':'Europe',
    'SVN':'Europe','ESP':'Europe','SWE':'Europe','GBR':'Europe','NOR':'Europe','ISL':'Europe','CHE':'Europe','UKR':'Europe',
    'SRB':'Europe','BIH':'Europe','MKD':'Europe','ALB':'Europe','TUR':'Europe','RUS':'Europe','KAZ':'Europe','ARM':'Europe',
    'GEO':'Europe','AZE':'Europe',
    'USA':'North America','CAN':'North America','MEX':'North America','BMU':'North America',
    'ARG':'Latin America & Caribbean','BRA':'Latin America & Caribbean','CHL':'Latin America & Caribbean',
    'COL':'Latin America & Caribbean','PER':'Latin America & Caribbean','URY':'Latin America & Caribbean',
    'PRY':'Latin America & Caribbean','ECU':'Latin America & Caribbean','BOL':'Latin America & Caribbean',
    'DOM':'Latin America & Caribbean','GTM':'Latin America & Caribbean','PAN':'Latin America & Caribbean',
    'HND':'Latin America & Caribbean','CRI':'Latin America & Caribbean','SLV':'Latin America & Caribbean',
    'VEN':'Latin America & Caribbean','CUB':'Latin America & Caribbean','JAM':'Latin America & Caribbean',
    'AUS':'East Asia & Pacific','NZL':'East Asia & Pacific','CHN':'East Asia & Pacific','HKG':'East Asia & Pacific',
    'MAC':'East Asia & Pacific','KOR':'East Asia & Pacific','JPN':'East Asia & Pacific','SGP':'East Asia & Pacific',
    'MYS':'East Asia & Pacific','IDN':'East Asia & Pacific','PHL':'East Asia & Pacific','THA':'East Asia & Pacific',
    'VNM':'East Asia & Pacific','KHM':'East Asia & Pacific','LAO':'East Asia & Pacific','MMR':'East Asia & Pacific',
    'PNG':'East Asia & Pacific','FJI':'East Asia & Pacific',
    'IND':'South Asia','PAK':'South Asia','BGD':'South Asia','LKA':'South Asia','NPL':'South Asia',
    'ARE':'Middle East & North Africa','SAU':'Middle East & North Africa','QAT':'Middle East & North Africa',
    'KWT':'Middle East & North Africa','OMN':'Middle East & North Africa','BHR':'Middle East & North Africa',
    'JOR':'Middle East & North Africa','LBN':'Middle East & North Africa','EGY':'Middle East & North Africa',
    'MAR':'Middle East & North Africa','TUN':'Middle East & North Africa','DZA':'Middle East & North Africa','ISR':'Middle East & North Africa',
    'ZAF':'Sub-Saharan Africa','NGA':'Sub-Saharan Africa','GHA':'Sub-Saharan Africa','KEN':'Sub-Saharan Africa',
    'ETH':'Sub-Saharan Africa','TZA':'Sub-Saharan Africa','UGA':'Sub-Saharan Africa','SEN':'Sub-Saharan Africa',
    'CIV':'Sub-Saharan Africa','CMR':'Sub-Saharan Africa','ZMB':'Sub-Saharan Africa',
}

def add_region(df: pd.DataFrame, iso_col='iso3'):
    """Normalize the ISO column in place and attach a 'region' column.

    Codes missing from REGION_MAP are labelled 'Other/Unknown'.  Mutates
    and returns the same DataFrame.
    """
    codes = df[iso_col].map(normalize_iso3)
    df[iso_col] = codes
    df['region'] = codes.map(REGION_MAP).fillna('Other/Unknown')
    return df

def cols_lower_strip(df: pd.DataFrame):
    """Return a copy of *df* with collapsed-whitespace, stripped, lower-cased column labels."""
    out = df.copy()
    out.columns = [re.sub(r'\s+', ' ', str(label)).strip().lower() for label in out.columns]
    return out

def read_any_tabular(path: str) -> pd.DataFrame:
    """Load a CSV/TXT or Excel file into a DataFrame.

    For text files a comma parse is tried first; the reader falls back to
    semicolon when the comma parse raises OR when it collapses everything
    into a single column whose header still contains ';'.  The latter check
    is needed because pd.read_csv does not raise on a wrong delimiter, so an
    except-only fallback would never fire for European-style CSVs.

    Raises:
        FileNotFoundError: *path* does not exist.
        ValueError: extension is not csv/txt/xlsx/xls.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"File not found: {path}")
    ext = os.path.splitext(path)[1].lower()
    if ext in ['.csv', '.txt']:
        try:
            df = pd.read_csv(path, low_memory=False)
        except Exception:
            return pd.read_csv(path, sep=';', low_memory=False)
        # Wrong-delimiter heuristic: one column and a ';' embedded in its name.
        if df.shape[1] == 1 and ';' in str(df.columns[0]):
            return pd.read_csv(path, sep=';', low_memory=False)
        return df
    elif ext in ['.xlsx', '.xls']:
        return pd.read_excel(path)
    else:
        raise ValueError(f"Unsupported file type: {ext}")

def year_in_bounds(y):
    """True when *y* coerces to an int inside [YEARS_MIN, YEARS_MAX]; False otherwise."""
    try:
        return YEARS_MIN <= int(y) <= YEARS_MAX
    except Exception:
        return False

def dedup_block(df: pd.DataFrame, valcol: str) -> pd.DataFrame:
    """Collapse duplicate (company_name, iso3, year) rows to one row each.

    *valcol* is coerced to numeric; each key keeps its first non-null value
    (NaN when the whole group is null).  None/empty input yields an empty,
    correctly-shaped frame.

    Uses GroupBy.first(), whose contract is exactly "first non-null per
    group" — same result as the original Python-level lambda, vectorized.
    """
    key = ['company_name', 'iso3', 'year']
    if df is None or len(df) == 0:
        return pd.DataFrame(columns=key + [valcol])
    df = df[key + [valcol]].copy()
    df[valcol] = pd.to_numeric(df[valcol], errors='coerce')
    return df.groupby(key, as_index=False)[valcol].first()

# ---------- ORBIS ----------
def build_company_year_from_orbis(path: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Reshape the wide Orbis export into firm-year and country-year panels.

    Steps: detect the country-ISO and company-name columns; melt every
    year-stamped metric block (revenue, employees, profit before tax, costs
    of employees, R&D ratio, absolute R&D) to long form; unify the m-USD /
    th-USD revenue and R&D variants into plain USD (m-USD preferred when both
    exist); outer-merge the blocks on (company_name, iso3, year); backfill
    rd_ratio from rd_exp_usd / revenue_usd; optionally interpolate short
    within-firm rd_ratio gaps (IMPUTE_RD); aggregate to company-year and then
    country-year, including only the metric columns actually present.

    Returns:
        (comp_year, country_year) — keyed on (iso3, company_name, year) and
        (iso3, year) respectively.

    Raises:
        RuntimeError: when no 'Country ISO code' column can be located.
    """
    wide = read_any_tabular(path)
    wide = cols_lower_strip(wide)

    # All header matching below runs against lower-cased, whitespace-collapsed names.
    iso_col = next((c for c in wide.columns if 'country iso code' in c or c in ['iso3','country_iso3','country code']), None)
    if iso_col is None:
        raise RuntimeError("Could not find 'Country ISO code' column in Orbis file.")
    wide['iso3'] = wide[iso_col].astype(str).map(normalize_iso3)

    name_col = next((c for c in wide.columns if c.startswith('company name')), None)
    if name_col is None:
        # No company-name column: fall back to the row index so keys stay unique.
        name_col = 'company_name'
        if name_col not in wide.columns: wide[name_col] = np.arange(len(wide))
    wide['company_name'] = wide[name_col].astype(str).str.strip()

    # Year token inside a column header.  NOTE(review): the first alternative
    # 20\d{2} already matches everything the other two do; redundant but harmless.
    ypat = r'(20\d{2}|201[0-9]|200[5-9])'

    rev_m_cols = [c for c in wide.columns if c.startswith('operating revenue') and 'm usd' in c]
    rev_t_cols = [c for c in wide.columns if c.startswith('operating revenue') and 'th usd' in c]
    emp_cols   = [c for c in wide.columns if c.startswith('number of employees') and re.search(ypat, c)]
    pbt_cols   = [c for c in wide.columns if c.startswith('profit (loss) before tax')]
    ce_cols    = [c for c in wide.columns if c.startswith('costs of employees')]

    # R&D ratio column(s)
    rd_ratio_cols = [c for c in wide.columns if c.startswith('r&d expenses / operating revenue')]

    # R&D absolute (broad detection: any R&D/research+develop + expenses/costs, excluding ratio)
    def is_rd_abs(col: str) -> bool:
        s = col
        has_rd_term = ('r&d' in s) or ('research' in s and 'develop' in s)
        has_cost_term = ('expenses' in s) or ('costs' in s)
        is_ratio = '/ operating revenue' in s
        return has_rd_term and has_cost_term and (not is_ratio)

    rd_abs_m_cols  = [c for c in wide.columns if is_rd_abs(c) and 'm usd'  in c]
    rd_abs_th_cols = [c for c in wide.columns if is_rd_abs(c) and 'th usd' in c]

    # Melt one block of year-stamped columns into (company_name, iso3, year, value) rows.
    def to_long(block_cols: List[str], valname: str):
        pairs = []
        for c in block_cols:
            m = re.search(ypat, c)
            if m: pairs.append((c, int(m.group(1))))
        if not pairs:
            return pd.DataFrame(columns=['company_name','iso3','year',valname])
        sub = wide[['company_name','iso3'] + [p[0] for p in pairs]].copy()
        ren = {col: re.search(ypat, col).group(1) for col in [p[0] for p in pairs]}
        sub = sub.rename(columns=ren)
        long_ = sub.melt(id_vars=['company_name','iso3'], var_name='year', value_name=valname)
        # NOTE(review): .dropna() shortens the series and the assignment
        # realigns on index, so any unparseable label would come back as NaN
        # (column may turn float).  Benign here because 'year' values come
        # from the regex matches above — confirm if ypat ever changes.
        long_['year'] = pd.to_numeric(long_['year'], errors='coerce').dropna().astype(int)
        long_[valname] = pd.to_numeric(long_[valname], errors='coerce')
        return long_

    # Build long blocks
    rev_m_long   = dedup_block(to_long(rev_m_cols,   'revenue_musd'), 'revenue_musd')
    rev_t_long   = dedup_block(to_long(rev_t_cols,   'revenue_thusd'),'revenue_thusd')
    emp_long     = dedup_block(to_long(emp_cols,     'employees'),    'employees')
    pbt_long     = dedup_block(to_long(pbt_cols,     'pbt_thusd'),    'pbt_thusd')
    ce_long      = dedup_block(to_long(ce_cols,      'ce_thusd'),     'ce_thusd')
    rd_ratio_long= dedup_block(to_long(rd_ratio_cols,'rd_ratio'),     'rd_ratio')
    rd_m_long    = dedup_block(to_long(rd_abs_m_cols,'rd_exp_musd'),  'rd_exp_musd')
    rd_th_long   = dedup_block(to_long(rd_abs_th_cols,'rd_exp_thusd'),'rd_exp_thusd')

    # Revenue unified — src_rev orders m-USD (0) before th-USD (1) so the
    # m-USD figure wins when a firm-year appears in both blocks.
    rev_candidates = []
    if len(rev_m_long):
        r = rev_m_long.copy(); r['revenue_usd'] = r['revenue_musd'] * 1e6; r['src_rev'] = 0
        rev_candidates.append(r[['company_name','iso3','year','revenue_usd','src_rev']])
    if len(rev_t_long):
        r = rev_t_long.copy(); r['revenue_usd'] = r['revenue_thusd'] * 1e3; r['src_rev'] = 1
        rev_candidates.append(r[['company_name','iso3','year','revenue_usd','src_rev']])
    rev = (pd.concat(rev_candidates, ignore_index=True)
           .sort_values(['company_name','iso3','year','src_rev'])
           .drop_duplicates(['company_name','iso3','year'], keep='first').drop(columns=['src_rev'])
           ) if rev_candidates else pd.DataFrame(columns=['company_name','iso3','year','revenue_usd'])

    # R&D abs unified — same preference scheme as revenue (m-USD over th-USD).
    rd_candidates = []
    if len(rd_m_long):
        r = rd_m_long.copy(); r['rd_exp_usd'] = r['rd_exp_musd'] * 1e6; r['src_rd'] = 0
        rd_candidates.append(r[['company_name','iso3','year','rd_exp_usd','src_rd']])
    if len(rd_th_long):
        r = rd_th_long.copy(); r['rd_exp_usd'] = r['rd_exp_thusd'] * 1e3; r['src_rd'] = 1
        rd_candidates.append(r[['company_name','iso3','year','rd_exp_usd','src_rd']])
    rd_abs = (pd.concat(rd_candidates, ignore_index=True)
              .sort_values(['company_name','iso3','year','src_rd'])
              .drop_duplicates(['company_name','iso3','year'], keep='first').drop(columns=['src_rd'])
              ) if rd_candidates else pd.DataFrame(columns=['company_name','iso3','year','rd_exp_usd'])

    # Merge all company-year pieces (empty blocks are skipped so missing
    # metric families don't break the outer merge).
    pieces = [rev, emp_long, pbt_long, ce_long, rd_ratio_long, rd_abs]
    pieces = [d for d in pieces if len(d)]
    long = pieces[0] if pieces else pd.DataFrame(columns=['company_name','iso3','year'])
    for d in pieces[1:]:
        if d.duplicated(['company_name','iso3','year']).any():
            d = d.groupby(['company_name','iso3','year'], as_index=False).agg('first')
        long = pd.merge(long, d, on=['company_name','iso3','year'], how='outer')

    # Keep analysis window
    long = long[long['year'].apply(year_in_bounds)]

    # Backfill rd_ratio from rd_exp_usd / revenue_usd when possible
    # (only where both are strictly positive; ratio capped at 5.0).
    if {'revenue_usd','rd_exp_usd'}.issubset(long.columns):
        cond = (long.get('rd_ratio', pd.Series(index=long.index)).isna()) & (long['revenue_usd']>0) & (long['rd_exp_usd']>0)
        if 'rd_ratio' not in long.columns:
            long['rd_ratio'] = np.nan
        long.loc[cond, 'rd_ratio'] = (long.loc[cond, 'rd_exp_usd'] / long.loc[cond, 'revenue_usd']).clip(upper=5.0)

    # Optional within-firm interpolation of rd_ratio (short gaps)
    if IMPUTE_RD and 'rd_ratio' in long.columns:
        long = long.sort_values(['iso3','company_name','year'])
        filled = []
        for (iso, firm), g in long.groupby(['iso3','company_name'], as_index=False):
            gg = g.copy()
            gg['rd_ratio'] = gg['rd_ratio'].interpolate(method='linear', limit=2, limit_direction='both')
            gg['rd_ratio'] = gg['rd_ratio'].ffill(limit=1).bfill(limit=1)
            filled.append(gg)
        long = pd.concat(filled, ignore_index=True)

    # Dynamic aggregation — include only columns that exist
    agg_specs = {
        'revenue_usd':'sum','employees':'sum','pbt_thusd':'sum','ce_thusd':'sum',
        'rd_ratio':'mean','rd_exp_usd':'sum'
    }
    agg = {k:v for k,v in agg_specs.items() if k in long.columns}

    comp_year = long.groupby(['iso3','company_name','year'], as_index=False).agg(agg)

    # Country-year aggregates
    def nz_mean(s):
        # Mean over non-zero values only (zeros are treated as placeholders).
        s = pd.to_numeric(s, errors='coerce'); return s.replace(0, np.nan).mean()

    cy_agg_specs = {'revenue_usd':'sum','employees':'sum','pbt_thusd':'sum','ce_thusd':'sum','rd_ratio': nz_mean,'rd_exp_usd':'sum'}
    cy_agg = {k:v for k,v in cy_agg_specs.items() if k in comp_year.columns}
    country_year = comp_year.groupby(['iso3','year'], as_index=False).agg(cy_agg)

    # Diagnostics for RD coverage
    rd_direct_cov = rd_ratio_long['rd_ratio'].notna().mean() if len(rd_ratio_long) else 0.0
    rd_abs_cov    = rd_abs['rd_exp_usd'].notna().mean() if len(rd_abs) else 0.0
    rd_final_cov  = comp_year['rd_ratio'].notna().mean() if 'rd_ratio' in comp_year.columns else 0.0

    print(f"Company-year built: n={len(comp_year):,}, firms={comp_year['company_name'].nunique():,}, years={int(comp_year['year'].min())}–{int(comp_year['year'].max())}")
    print(f"R&D ratio coverage → direct: {rd_direct_cov:.3f}, abs-exp: {rd_abs_cov:.3f}, final (after backfill/impute): {rd_final_cov:.3f}")

    return comp_year, country_year

# ---------- ILO ----------
def load_ilo_days(path: str) -> pd.DataFrame:
    """Load ILO gender-split days-lost data into (iso3, year, days_lost_female, days_lost_male).

    Three input layouts are tried, in order:
      1) wide, with detectable iso/year/female-days/male-days columns;
      2) wide, with a 'country' name column mapped through NAME_TO_ISO3
         (raises RuntimeError when any name is unmapped);
      3) long SDMX-style (iso/year/sex/value), pivoted to wide; an empty
         frame is returned when columns cannot be auto-detected.
    Rows are then clipped to 1990-2024 and to valid 3-letter ISO codes,
    optionally gap-imputed (IMPUTE_ILO), and both-zero placeholder rows are
    nulled out per country.
    """
    df = read_any_tabular(path); raw_cols = df.columns
    df_l = cols_lower_strip(df)

    # Candidate column names under several common export conventions.
    cand_iso = [c for c in df_l.columns if c in ['iso3','iso','country_iso3','country code','country_code','country iso3','country iso code','ref_area','code']]
    cand_year = [c for c in df_l.columns if c in ['year','yr','time','time_period']]
    cand_f = [c for c in df_l.columns if ('female' in c and 'days' in c) or c in ['days_lost_female','lost_days_female','dayslost_female','female_days_lost']]
    cand_m = [c for c in df_l.columns if ('male' in c and 'days' in c) or c in ['days_lost_male','lost_days_male','dayslost_male','male_days_lost']]

    if cand_iso and cand_year and cand_f and cand_m:
        # Layout 1: already wide with explicit iso/year/female/male columns.
        out = df_l[[cand_iso[0], cand_year[0], cand_f[0], cand_m[0]]].copy()
        out.columns = ['iso3','year','days_lost_female','days_lost_male']
        out['iso3'] = out['iso3'].astype(str).map(normalize_iso3)
    elif 'country' in df_l.columns and {'year','days_lost_female','days_lost_male'}.issubset(df_l.columns):
        # Layout 2: wide with country display names instead of ISO codes.
        out = df_l[['country','year','days_lost_female','days_lost_male']].copy()
        out['iso3'] = out['country'].map(NAME_TO_ISO3)
        if out['iso3'].isna().any():
            miss = sorted(out.loc[out['iso3'].isna(),'country'].unique().tolist())
            raise RuntimeError(f"[ILO] Unmapped country names: {miss[:10]} (total {len(miss)})")
        out = out[['iso3','year','days_lost_female','days_lost_male']].copy()
    else:
        # Layout 3: long/tidy form with a sex dimension and a value column.
        sex_col = next((c for c in df_l.columns if c in ['sex','sex_code','sex code']), None)
        val_col = next((c for c in df_l.columns if c in ['value','obs_value','obs value','obs_value_num','obs_value_numeric']), None)
        iso_col  = cand_iso[0] if cand_iso else None
        year_col = cand_year[0] if cand_year else None
        if iso_col is None or year_col is None or sex_col is None or val_col is None:
            print(f"[ILO] Could not auto-detect columns. Raw columns: {list(raw_cols)[:12]} ...")
            return pd.DataFrame(columns=['iso3','year','days_lost_female','days_lost_male'])
        tmp = df_l[[iso_col, year_col, sex_col, val_col]].copy()
        tmp.columns = ['iso3','year','sex','val']
        tmp['iso3'] = tmp['iso3'].astype(str).map(normalize_iso3)
        tmp['year'] = pd.to_numeric(tmp['year'], errors='coerce').astype('Int64')
        tmp = tmp.dropna(subset=['iso3','year']).copy(); tmp['year'] = tmp['year'].astype(int)
        def sx(s):
            # Map sex codes (F/FEMALE/2, M/MALE/1) to labels; anything else is dropped.
            s = str(s).strip().upper()
            return 'female' if s in ['F','FEMALE','2'] else ('male' if s in ['M','MALE','1'] else None)
        tmp['sx'] = tmp['sex'].map(sx); tmp = tmp[~tmp['sx'].isna()].copy()
        tmp = tmp[(tmp['year']>=1990) & (tmp['year']<=2024)]
        tmp['val'] = pd.to_numeric(tmp['val'], errors='coerce')
        # aggfunc='last' keeps the last observation when duplicates exist.
        out = tmp.pivot_table(index=['iso3','year'], columns='sx', values='val', aggfunc='last').reset_index()
        if 'female' not in out.columns: out['female'] = np.nan
        if 'male' not in out.columns: out['male'] = np.nan
        out = out.rename(columns={'female':'days_lost_female','male':'days_lost_male'})

    out['year'] = pd.to_numeric(out['year'], errors='coerce').astype('Int64')
    out = out.dropna(subset=['year']).copy(); out['year'] = out['year'].astype(int)
    out = out[(out['year']>=1990) & (out['year']<=2024)]
    out = out[out['iso3'].astype(str).str.fullmatch(r'[A-Z]{3}') == True].copy()

    if IMPUTE_ILO:
        # Per-country gap fill: linear interpolation for gaps up to 2 years.
        # NOTE(review): ffill() below is unlimited while bfill is capped at 1 —
        # the module-header flag comment promises a 1-year carry; confirm intent.
        out = out.sort_values(['iso3','year']).reset_index(drop=True)
        cols = ['days_lost_female','days_lost_male']
        filled = []
        for iso, g in out.groupby('iso3', as_index=False):
            gg = g.copy()
            for c in cols:
                gg[c] = gg[c].interpolate(method='linear', limit=2, limit_direction='both')
                gg[c] = gg[c].ffill().bfill(limit=1)
            filled.append(gg)
        out = pd.concat(filled, ignore_index=True)

    def zero_placeholder_fix(g):
        # Within a country, rows where BOTH sexes are zero are treated as
        # placeholders and set to NaN — but only when the country reports a
        # positive value in at least one year (otherwise zeros are kept).
        both_zero = (g['days_lost_female'].fillna(0)==0) & (g['days_lost_male'].fillna(0)==0)
        any_pos = ((g['days_lost_female']>0) | (g['days_lost_male']>0)).any()
        if any_pos: g.loc[both_zero, ['days_lost_female','days_lost_male']] = np.nan
        return g
    out = out.groupby('iso3', group_keys=False).apply(zero_placeholder_fix)

    if len(out):
        y0, y1 = int(out['year'].min()), int(out['year'].max())
        print(f"ILO loaded: n={len(out):,}, countries={out['iso3'].nunique():,}, years={y0}-{y1}")
    else:
        print("ILO loaded: n=0 (no valid iso3/year rows after cleaning).")
    return out

# ---------- WDI ----------
def load_wdi(path: str) -> pd.DataFrame:
    """Load World Bank WDI indicators into a wide (iso3, year, <indicator>...) frame.

    Format sniffing, in order:
      * Excel 'Data' sheet in the standard WDI export layout
        ('1960 [YR1960]'-style year columns) — melted then pivoted so each
        Series Code becomes a column;
      * any Excel sheet already in long iso3/year/indicator/value form;
      * any Excel sheet already wide with iso3+year id columns;
      * the same first two layouts for CSV input.
    Returns an empty DataFrame when the file is missing or no layout matches;
    read errors are caught and logged rather than raised.
    """
    # NOTE(review): shadows the module-level `import re`; redundant but harmless.
    import re
    if not os.path.isfile(path):
        print("[WDI] File not found; returning empty.")
        return pd.DataFrame()
    try:
        ext = os.path.splitext(path)[1].lower()
        if ext in ('.xlsx','.xls'):
            xl = pd.ExcelFile(path)
            if 'Data' in xl.sheet_names:
                # Standard WDI export: Data sheet with 'YYYY [YRYYYY]' columns.
                df = pd.read_excel(xl, 'Data')
                if {'Series Code','Country Code'}.issubset(df.columns):
                    year_cols = [c for c in df.columns if re.fullmatch(r'\d{4} \[YR\d{4}\]', str(c))]
                    if year_cols:
                        long = df.melt(id_vars=['Country Code','Series Code'], value_vars=year_cols,
                                       var_name='year_label', value_name='value')
                        long['year'] = long['year_label'].str.extract(r'^(\d{4})').astype(int)
                        long = long.rename(columns={'Country Code':'iso3','Series Code':'indicator'})
                        long['iso3'] = long['iso3'].astype(str).str.strip().str.upper()
                        long['value'] = pd.to_numeric(long['value'], errors='coerce')
                        long = long[(long['year']>=1960) & (long['year']<=2025)]
                        w = (long.pivot_table(index=['iso3','year'], columns='indicator', values='value', aggfunc='last')
                             .reset_index())
                        print(f"[WDI] Loaded Data sheet: rows={len(w):,}, indicators={w.shape[1]-2}")
                        return w
            # Fallback A: any sheet already in long iso3/year/indicator/value form.
            for sh in xl.sheet_names:
                df = cols_lower_strip(pd.read_excel(xl, sh))
                if {'iso3','year','indicator','value'}.issubset(df.columns):
                    df['iso3'] = df['iso3'].astype(str).str.strip().str.upper()
                    df['year'] = pd.to_numeric(df['year'], errors='coerce').astype('Int64')
                    df = df.dropna(subset=['iso3','year']).copy(); df['year'] = df['year'].astype(int)
                    df = df[(df['year']>=1960)&(df['year']<=2025)]
                    w = df.pivot_table(index=['iso3','year'], columns='indicator', values='value', aggfunc='last').reset_index()
                    print(f"[WDI] Loaded long sheet '{sh}': rows={len(w):,}, indicators={w.shape[1]-2}")
                    return w
            # Fallback B: any sheet already wide with iso3+year id columns.
            for sh in xl.sheet_names:
                df = cols_lower_strip(pd.read_excel(xl, sh))
                if {'iso3','year'}.issubset(df.columns):
                    id_cols = ['iso3','year']; val_cols = [c for c in df.columns if c not in id_cols]
                    for c in val_cols: df[c] = pd.to_numeric(df[c], errors='coerce')
                    df['iso3'] = df['iso3'].astype(str).str.strip().str.upper()
                    df['year'] = pd.to_numeric(df['year'], errors='coerce')
                    df = df.dropna(subset=['iso3','year']).copy(); df['year'] = df['year'].astype(int)
                    print(f"[WDI] Loaded wide sheet '{sh}': rows={len(df):,}, indicators={len(val_cols)}")
                    return df
        else:
            # CSV path: same standard-export detection, then the long layout.
            df = pd.read_csv(path, low_memory=False)
            if {'Series Code','Country Code'}.issubset(df.columns):
                year_cols = [c for c in df.columns if re.fullmatch(r'\d{4} \[YR\d{4}\]', str(c))]
                if year_cols:
                    long = df.melt(id_vars=['Country Code','Series Code'], value_vars=year_cols,
                                   var_name='year_label', value_name='value')
                    long['year'] = long['year_label'].str.extract(r'^(\d{4})').astype(int)
                    long = long.rename(columns={'Country Code':'iso3','Series Code':'indicator'})
                    long['iso3'] = long['iso3'].astype(str).str.strip().str.upper()
                    long['value'] = pd.to_numeric(long['value'], errors='coerce')
                    long = long[(long['year']>=1960) & (long['year']<=2025)]
                    w = (long.pivot_table(index=['iso3','year'], columns='indicator', values='value', aggfunc='last')
                         .reset_index())
                    print(f"[WDI] Loaded CSV: rows={len(w):,}, indicators={w.shape[1]-2}")
                    return w
            df = cols_lower_strip(df)
            if {'iso3','year','indicator','value'}.issubset(df.columns):
                df['iso3'] = df['iso3'].astype(str).str.strip().str.upper()
                df['year'] = pd.to_numeric(df['year'], errors='coerce').astype('Int64')
                df = df.dropna(subset=['iso3','year']).copy(); df['year'] = df['year'].astype(int)
                w = df.pivot_table(index=['iso3','year'], columns='indicator', values='value', aggfunc='last').reset_index()
                print(f"[WDI] Loaded CSV long: rows={len(w):,}, indicators={w.shape[1]-2}")
                return w
    except Exception as e:
        print(f"[WDI] Error reading: {e}")
    print("[WDI] Could not infer format; returning empty.")
    return pd.DataFrame()

# ---------- rolling ----------
def roll3_centered(df: pd.DataFrame, id_cols=('iso3','year')):
    """Centered 3-period rolling mean of every numeric column, per iso3 group.

    Each group is sorted by 'year' and smoothed with window=3, center=True,
    min_periods=2 (so series endpoints keep a two-year average).  Output
    columns are the id columns followed by the smoothed numeric columns;
    non-numeric, non-id columns are dropped.

    NOTE(review): the window is positional — it assumes each group's years
    are consecutive; a gap year would silently average non-adjacent years.

    Fixes vs. the original: 'year' is no longer pushed through the rolling
    mean only to be discarded, and empty input returns an empty frame
    instead of raising from pd.concat([]).
    """
    df = df.copy()
    num_cols = [c for c in df.select_dtypes(include=[np.number]).columns
                if c not in id_cols]
    out = []
    for _, g in df.groupby('iso3', as_index=False):
        g = g.sort_values('year')
        rolled = g[num_cols].rolling(window=3, center=True, min_periods=2).mean()
        # Index alignment: g and rolled share the same index, so concat is row-wise safe.
        out.append(pd.concat([g[list(id_cols)], rolled], axis=1))
    if not out:
        return df.loc[:, list(id_cols) + num_cols].copy()
    return pd.concat(out, axis=0, ignore_index=True)

def main():
    """Run the full pipeline and write six CSV outputs to OUTDIR.

    Steps: build Orbis firm-year/country-year panels; load ILO days-lost and
    WDI indicators; print country/year overlap diagnostics; smooth the Orbis
    country-year panel with a centered 3-year rolling mean; merge with ILO
    (per JOIN_MODE) and WDI (left); drop ISR; attach regions and compute
    regional medians; clean both-zero placeholder rows; write outputs and
    print a summary.
    """
    os.makedirs(OUTDIR, exist_ok=True)

    # 1) Orbis
    comp_year, country_year = build_company_year_from_orbis(ORBISEXCEL)
    print(f"Country-year built: n={len(country_year):,}, countries={country_year['iso3'].nunique():,}, years={int(country_year['year'].min())}–{int(country_year['year'].max())}")

    # 2) ILO
    ilo = load_ilo_days(ILOFILE)

    # 3) WDI
    wdi = load_wdi(WDIFILE)

    # 4) Overlap diag
    print("\n=== OVERLAP REPORT ===")
    orb_c = set(country_year['iso3'].unique())
    ilo_c = set(ilo['iso3'].unique()) if len(ilo) else set()
    print(f"Countries in Orbis country-year: {len(orb_c)}")
    print(f"Countries in ILO:              {len(ilo_c)}")
    common = sorted(orb_c & ilo_c)
    print(f"Common countries:              {len(common)}")
    print(f"Common list (first 20): {common[:20]}")
    if len(ilo):
        print("Year overlap:", f"{country_year['year'].min()}–{country_year['year'].max()} vs ILO {ilo['year'].min()}–{ilo['year'].max()} -> common {MERGE_YMIN}–{MERGE_YMAX}")
    else:
        print("Year overlap: (ILO empty)")

    # 5) Orbis roll3 window
    cy = country_year[(country_year['year']>=MERGE_YMIN) & (country_year['year']<=MERGE_YMAX)].copy()
    cy['iso3'] = cy['iso3'].astype(str).map(normalize_iso3)
    cy_r3 = roll3_centered(cy, id_cols=('iso3','year'))

    # 6) Merge with ILO per JOIN_MODE
    # 'left' keeps every Orbis country-year; 'inner' keeps only rows present in ILO.
    if len(ilo):
        ilo2 = ilo[(ilo['year']>=MERGE_YMIN) & (ilo['year']<=MERGE_YMAX)].copy()
        merged = (pd.merge(cy_r3, ilo2, on=['iso3','year'], how='left')
                  if JOIN_MODE.lower()=='left' else
                  pd.merge(ilo2, cy_r3, on=['iso3','year'], how='inner'))
    else:
        merged = cy_r3.copy()
        merged['days_lost_female'] = np.nan
        merged['days_lost_male'] = np.nan

    # 7) Merge WDI (left)
    if len(wdi):
        wdi2 = wdi.copy()
        wdi2['iso3'] = wdi2['iso3'].astype(str).map(normalize_iso3)
        if 'year' in wdi2.columns:
            wdi2['year'] = pd.to_numeric(wdi2['year'], errors='coerce')
            wdi2 = wdi2.dropna(subset=['year']).copy(); wdi2['year'] = wdi2['year'].astype(int)
            wdi2 = wdi2[(wdi2['year']>=MERGE_YMIN)&(wdi2['year']<=MERGE_YMAX)]
        merged = pd.merge(merged, wdi2, on=['iso3','year'], how='left')
    else:
        print("[WDI] merged: skipped (empty)")

    # 8) Exclude ISR
    # NOTE(review): hard-coded exclusion of Israel — confirm this is intentional
    # for the study sample.
    merged = merged[merged['iso3']!='ISR'].copy()

    # 9) Regions + aggregates (per-region, per-year medians of numeric columns)
    merged = add_region(merged, 'iso3')
    reg_agg = merged.groupby(['region','year'], as_index=False).median(numeric_only=True)

    # 10) Placeholder-zero cleanup
    # Rows where ALL core Orbis metrics are zero are treated as placeholders.
    zero_block = [c for c in ['revenue_usd','employees','pbt_thusd','ce_thusd'] if c in merged.columns]
    def zero_block_fix(g):
        if not zero_block: return g
        block = g[zero_block].fillna(0)
        all_zero = (block.sum(axis=1)==0)
        g.loc[all_zero, zero_block] = np.nan
        return g
    # NOTE(review): groupby.apply over the full frame may emit deprecation
    # warnings on newer pandas (grouping column inside apply); module-level
    # warning filter covers FutureWarning.
    merged = merged.groupby(['iso3'], group_keys=False).apply(zero_block_fix)

    # convenience total
    # min_count=1 keeps the total NaN when both components are NaN.
    if {'days_lost_female','days_lost_male'}.issubset(merged.columns):
        merged['ilo_days_total'] = merged[['days_lost_female','days_lost_male']].sum(axis=1, min_count=1)

    # 11) Write outputs
    out1 = os.path.join(OUTDIR, "orbis_company_year.csv")
    out2 = os.path.join(OUTDIR, "orbis_country_year.csv")
    out3 = os.path.join(OUTDIR, "orbis_country_year_roll3.csv")
    out4 = os.path.join(OUTDIR, "merged_ilo_orbis_roll3_wdi.csv")
    out5 = os.path.join(OUTDIR, "regional_aggregates_roll3.csv")
    out6 = os.path.join(OUTDIR, "overlap_diagnostics.csv")

    comp_year.to_csv(out1, index=False)
    country_year.to_csv(out2, index=False)
    cy_r3.to_csv(out3, index=False)
    merged.to_csv(out4, index=False)
    reg_agg.to_csv(out5, index=False)

    diag_rows = [{
        'orbis_countries': country_year['iso3'].nunique(),
        'ilo_countries': ilo['iso3'].nunique() if len(ilo) else 0,
        'common_countries': len(set(country_year['iso3']) & set(ilo['iso3'])) if len(ilo) else 0,
        'orbis_year_min': int(country_year['year'].min()),
        'orbis_year_max': int(country_year['year'].max()),
        'ilo_year_min': int(ilo['year'].min()) if len(ilo) else np.nan,
        'ilo_year_max': int(ilo['year'].max()) if len(ilo) else np.nan,
        'merged_rows': len(merged),
        'rd_ratio_company_nonmiss': float(comp_year['rd_ratio'].notna().mean()) if 'rd_ratio' in comp_year.columns else 0.0
    }]
    pd.DataFrame(diag_rows).to_csv(out6, index=False)

    # 12) Console summary
    print("\n=== DIAGNOSTICS ===")
    print(f"Company-year:                 n={len(comp_year):,}, firms={comp_year['company_name'].nunique():,}, years={int(comp_year['year'].min())}–{int(comp_year['year'].max())}")
    print(f"Country-year:                 n={len(country_year):,}, countries={country_year['iso3'].nunique():,}, years={int(country_year['year'].min())}–{int(country_year['year'].max())}")
    print(f"Country-year roll-3:          n={len(cy_r3):,}, countries={cy_r3['iso3'].nunique():,}, years={int(cy_r3['year'].min())}–{int(cy_r3['year'].max())}")
    print(f"Merged Orbis×ILO×WDI:         n={len(merged):,}, countries={merged['iso3'].nunique():,}, years={int(merged['year'].min())}–{int(merged['year'].max())}")
    if 'region' in merged.columns:
        print(f"Regions in merged:            {sorted(merged['region'].unique())}")

    print("\n=== DONE ===")
    print(f" - {out1}\n - {out2}\n - {out3}\n - {out4}\n - {out5}\n - {out6}")
if __name__ == "__main__":
    main()
