import pandas as pd
import numpy as np
import os
import geopandas as gpd
from pyproj import Geod
from tqdm import tqdm
from datetime import datetime
import pycountry
from shapely.geometry import Point, LineString, MultiLineString
from shapely.wkt import loads
import math
from math import radians, cos, sin, asin, sqrt
from global_land_mask import globe
import reverse_geocoder

def get_modal_dict():
    """Load the per-country transport modal split as nested dictionaries.

    Reads ``Databases/EUROSTAT_transport_mode_split.xlsx`` relative to the
    current working directory, indexes the sheet by ``Country`` and divides
    the ``Rail``, ``Roads``, ``Inland waterways`` and ``Sea`` columns by 100.

    Returns:
        dict: ``{country: {mode: value / 100}}`` for the four mode columns.
    """
    workbook_path = os.path.abspath(os.getcwd()) + '/Databases/EUROSTAT_transport_mode_split.xlsx'
    mode_columns = ['Rail', 'Roads', 'Inland waterways', 'Sea']

    by_country = pd.read_excel(workbook_path).set_index('Country')
    shares = by_country[mode_columns] / 100
    return shares.T.to_dict()

def import_material_list(steel_and_aggregated_only=False, exclude_materials=[]):
    """Return the material names listed in the LCI inventory workbook.

    Material names are the column labels from position 13 onwards of the
    ``LCI_per_capacity`` sheet in ``LCI_data/fossil_infrastructure_inventory.xlsx``
    (resolved against the current working directory).

    Args:
        steel_and_aggregated_only: When true, keep only names that contain
            ``'steel'`` or whose first character is upper case.
        exclude_materials: Names to drop from the result. The default list is
            only read, never modified.

    Returns:
        list: The selected material names in sheet order.
    """
    workbook_path = os.path.abspath(os.getcwd()) + '//LCI_data//fossil_infrastructure_inventory.xlsx'
    sheet = pd.read_excel(workbook_path, sheet_name='LCI_per_capacity')

    selected = []
    for name in list(sheet.iloc[:, 13:].columns):
        if steel_and_aggregated_only and not ('steel' in name or name[0].isupper()):
            continue
        if name in exclude_materials:
            continue
        selected.append(name)
    return selected

def switch_lat_lon(coords):
    """Swap the two members of every coordinate pair.

    Args:
        coords: Iterable of 2-element pairs, unpacked as ``(lon, lat)``.

    Returns:
        list: New list of ``(lat, lon)`` tuples in the same order.
    """
    swapped = []
    for lon, lat in coords:
        swapped.append((lat, lon))
    return swapped

def is_land_or_water(lat, lon):
    """Label a coordinate as ``'onshore'`` or ``'offshore'``.

    The decision is delegated to ``globe.is_land(lat, lon)``: a truthy result
    gives ``'onshore'``, anything else ``'offshore'``.
    """
    return 'onshore' if globe.is_land(lat, lon) else 'offshore'

def return_stock_point_of_line(line_string):
    """Return a representative vertex of a line or multi-part line geometry.

    For an object exposing ``coords``, the vertex at index
    ``len(coords) // 2`` is returned. Otherwise the object is treated as a
    multi-part geometry: the part in ``geoms`` with the most vertices is
    selected (the first one wins on ties) and its middle vertex is returned.

    Args:
        line_string: Object with either a ``coords`` sequence or a ``geoms``
            collection of such objects (presumably shapely ``LineString`` /
            ``MultiLineString`` -- confirm against callers).

    Returns:
        The coordinate tuple stored at the middle index.

    Raises:
        ValueError: If a multi-part geometry contains no parts.
        AttributeError: If the object has neither usable ``coords`` nor
            ``geoms``.
    """
    try:
        coords = list(line_string.coords)
        return coords[len(coords) // 2]
    except (NotImplementedError, AttributeError, TypeError, IndexError):
        # Not a simple line (multi-part geometries do not provide ``coords``)
        # or an empty one: fall through to the multi-part handling below.
        pass

    parts = list(line_string.geoms)
    if not parts:
        raise ValueError('Geometry has no parts to take a stock point from')
    # max() returns the first maximal element, matching a strict ">" scan.
    longest = max(parts, key=lambda part: len(part.coords))
    return return_stock_point_of_line(longest)

def calculate_stock_distribution(datasets, materials):
    """Aggregate per-row uncertainty samples into total stock distributions.

    For every material and every dataset, each row's
    ``'<material> distribution [-]'`` array is multiplied by the row's
    ``'<material> [kg]'`` value. The median of each product is written back
    into the dataset as ``'<material> median [kg]'`` (the datasets are
    modified in place), and the products are summed over rows and datasets.

    Args:
        datasets: Iterable of table-like objects holding the two input
            columns for every material.
        materials: Iterable of material names.

    Returns:
        pandas.DataFrame: One ``'<material> [kg]'`` column per material with
        the summed distribution.
    """
    total_distribution_df = pd.DataFrame()
    for material in tqdm(materials):
        distribution_col = '{} distribution [-]'.format(material)
        stock_col = '{} [kg]'.format(material)
        median_col = '{} median [kg]'.format(material)

        per_dataset_totals = []
        for data in datasets:
            scaled = [unc_array * stock
                      for unc_array, stock in zip(data[distribution_col], data[stock_col])]
            data[median_col] = [np.median(sample) for sample in scaled]
            per_dataset_totals.append(sum(scaled))

        total_distribution_df[material + ' [kg]'] = np.array(per_dataset_totals).sum(axis=0)
    return total_distribution_df

def generate_LCI_uncertainty_distributions(LCI_sigma_data_dict, size):
    """Draw a lognormal sample for every (material, infrastructure) sigma.

    The global NumPy random seed is reset to 0 before each draw, so entries
    with the same sigma receive identical arrays and the global random state
    is changed as a side effect.

    Args:
        LCI_sigma_data_dict: ``{material: {infrastructure_key: sigma}}``.
        size: Sample size forwarded to ``np.random.lognormal``.

    Returns:
        dict: ``{material: {infrastructure_key: ndarray}}`` with samples drawn
        from a lognormal distribution whose underlying normal has mean
        ``log(1.0)`` and the given sigma.
    """
    LCI_distribution_dict = {}
    for material, sigma_by_infrastructure in LCI_sigma_data_dict.items():
        samples = {}
        for infr in sigma_by_infrastructure.keys():
            spread = sigma_by_infrastructure[infr]
            np.random.seed(0)
            samples[infr] = np.random.lognormal(mean=np.log(1.0), sigma=spread, size=size)
        LCI_distribution_dict[material] = samples
    return LCI_distribution_dict

def generate_linking_uncertainty_distributions(linking_uncertainty_dict, size):
    """Draw a uniform sample for every infrastructure's linking uncertainty.

    The global NumPy random seed is reset to 0 before each draw.

    Args:
        linking_uncertainty_dict: ``{infrastructure_key: spec}`` where each
            spec provides ``'linking uncertainty type'``, ``'min'`` and
            ``'max'``.
        size: Sample size forwarded to ``np.random.uniform``.

    Returns:
        dict: ``{infrastructure_key: ndarray}``.

    Raises:
        Exception: If the uncertainty type is anything other than
            ``'uniform'``.
    """
    linking_distribution_dict = {}
    for infr, spec in linking_uncertainty_dict.items():
        typ = spec['linking uncertainty type']
        low = spec['min']
        high = spec['max']
        if typ != 'uniform':
            raise Exception('Unknown uncertainty type', typ)
        np.random.seed(0)
        linking_distribution_dict[infr] = np.random.uniform(low, high, size)
    return linking_distribution_dict

def add_stocks(df, materials, scaling_cols, kg_concrete_per_kg_cement):
    """Add a ``'<material> [kg]'`` stock column for every material.

    For each material, the intensity columns (names containing
    ``'<material> '`` and ``'/'`` but not ``'[kg]'``) are multiplied
    element-wise with the ``scaling_cols`` columns and summed per row. For
    ``'Concrete'`` the same product computed from the ``'Cement [kg/...'``
    columns is added on top. ``df`` is modified in place.

    Args:
        df: pandas.DataFrame holding intensity and scaling columns.
        materials: Iterable of material names.
        scaling_cols: Columns multiplied with the intensity columns; the two
            selections must be broadcast-compatible.
        kg_concrete_per_kg_cement: NOTE(review): accepted but not used by this
            function -- confirm whether the cement term should be scaled by it.

    Returns:
        pandas.DataFrame: A new frame containing every column of ``df``.

    Raises:
        Exception: With the offending intensity columns and ``scaling_cols``
            as arguments when a product cannot be computed; the underlying
            error is attached as ``__cause__``.
    """
    for material in materials:
        material_cols = [col for col in df.columns
                         if material + ' ' in col and '/' in col and '[kg]' not in col]
        try:
            df[material + ' [kg]'] = (df[material_cols].to_numpy() * df[scaling_cols].to_numpy()).sum(axis=1)
        except Exception as err:
            raise Exception(material_cols, scaling_cols) from err

        if material == 'Concrete':
            cement_cols = [col for col in df.columns if 'Cement [kg/' in col]
            try:
                df[material + ' [kg]'] += (df[cement_cols].to_numpy() * df[scaling_cols].to_numpy()).sum(axis=1)
            except Exception as err:
                # Report the cement columns: they are the ones that failed here.
                raise Exception(cement_cols, scaling_cols) from err

    # NOTE(review): the previous filter ``if '[kg/'`` tested a non-empty string
    # literal and was therefore always true, so every column was returned.
    # That behaviour is kept explicitly; confirm whether the intensity columns
    # were meant to be filtered here.
    return df[list(df.columns)]

def check_materials(materials):
    """Require ``'Cement'`` to be listed whenever ``'Concrete'`` is.

    Raises:
        Exception: If ``materials`` contains ``'Concrete'`` but not
            ``'Cement'``.
    """
    if 'Concrete' in materials:
        if 'Cement' not in materials:
            raise Exception('List cement if concrete is assessed', materials)

def process_matrix(matrix):
    """Reduce every row of ``matrix`` to a single representative value.

    Per row: if any entry is below 1, the smallest of those entries is taken;
    otherwise, if any entry is above 1, the largest of those is taken;
    otherwise the row mean is used.

    Args:
        matrix: Iterable of NumPy arrays (for example a 2-D array).

    Returns:
        numpy.ndarray: One value per row.

    Raises:
        Exception: Carrying the offending row when it cannot be processed;
            the original error is attached as ``__cause__``.
    """
    result = []
    for row in matrix:
        try:
            below_one = row < 1
            if np.any(below_one):
                result.append(np.min(row[below_one]))
                continue
            above_one = row > 1
            if np.any(above_one):
                result.append(np.max(row[above_one]))
            else:
                result.append(np.mean(row))
        except Exception as err:
            raise Exception(row) from err
    return np.array(result)

def add_normalized_uncertainty_distribution(data, LCI_distribution_dict, linking_distribution_dict, materials, infrastructure, size):
    """Attach a ``'<material> distribution [-]'`` column to ``data``.

    For ``'pipeline'``, ``'power plant'`` and ``'coal mine'`` each row gets the
    product of the LCI sample and the linking sample stored under the key
    ``(infrastructure, fuel, class_or_type, ecoinvent_location)``; power
    plants read the ``class`` column, the others ``type``.

    For ``'wells'`` the product is built once for ``'natural gas'`` and once
    for ``'oil'``; both arrays are sorted, paired element-wise and reduced
    with ``process_matrix``.

    Args:
        data: Table with ``fuel`` (not needed for wells), ``class``/``type``
            and ``ecoinvent_location`` columns; modified in place.
        LCI_distribution_dict: ``{material: {key: ndarray}}``.
        linking_distribution_dict: ``{key: ndarray}``.
        materials: Iterable of material names.
        infrastructure: Infrastructure name selecting the branch above.
        size: Not used by this function.

    Returns:
        The same ``data`` object.
    """
    def _combined(material, key):
        # LCI sample times linking sample for one lookup key.
        return LCI_distribution_dict[material][key] * linking_distribution_dict[key]

    for material in tqdm(materials):
        if infrastructure in ['pipeline', 'power plant', 'coal mine']:
            type_key = 'class' if infrastructure == 'power plant' else 'type'
            data['{} distribution [-]'.format(material)] = [
                _combined(material, (infrastructure, fuel, clas, loc))
                for fuel, clas, loc in zip(data.fuel, data[type_key], data.ecoinvent_location)
            ]

        if infrastructure == 'wells':
            per_fuel = []
            for fuel in ['natural gas', 'oil']:
                per_fuel.append([
                    _combined(material, (infrastructure, fuel, clas, loc))
                    for clas, loc in zip(data['type'], data.ecoinvent_location)
                ])
            gas_arrays, oil_arrays = per_fuel

            reduced = []
            for gas_sample, oil_sample in zip(gas_arrays, oil_arrays):
                paired = np.array([np.sort(gas_sample), np.sort(oil_sample)]).T
                reduced.append(process_matrix(paired))
            data['{0} distribution [-]'.format(material)] = reduced
    return data

def get_LCI_data_including_uncertainties():
    """Load inventory, sigma, linking-uncertainty and density lookups.

    All data comes from ``LCI_data/fossil_infrastructure_inventory.xlsx``
    (relative to the current working directory), sheets ``LCI_per_capacity``,
    ``LCI_sigma`` and ``densities``. The first two are indexed by
    ``(type, fuel, class, location)``.

    Returns:
        tuple: ``(LCI_dict, LCI_sigma_data_dict, linking_uncertainty_dict,
        density_dict)`` where

        * ``LCI_dict`` maps column -> index key -> value for the
          ``LCI_per_capacity`` columns from position 9 (after indexing),
        * ``LCI_sigma_data_dict`` does the same for the ``LCI_sigma`` columns
          from position 3,
        * ``linking_uncertainty_dict`` maps index key -> column -> value for
          ``LCI_per_capacity`` columns 2 to 5,
        * ``density_dict`` maps material -> column -> value.
    """
    workbook_path = os.path.abspath(os.getcwd()) + '/LCI_data/fossil_infrastructure_inventory.xlsx'
    key_columns = ['type', 'fuel', 'class', 'location']

    LCI_data = pd.read_excel(workbook_path, sheet_name='LCI_per_capacity')
    LCI_sigma_data = pd.read_excel(workbook_path, sheet_name='LCI_sigma', header=0)
    density_data = pd.read_excel(workbook_path, sheet_name='densities', header=0)

    indexed_LCI = LCI_data.set_index(key_columns)
    indexed_sigma = LCI_sigma_data.set_index(key_columns)

    LCI_dict = indexed_LCI.iloc[:, 9:].to_dict()
    LCI_sigma_data_dict = indexed_sigma.iloc[:, 3:].to_dict()
    linking_uncertainty_dict = indexed_LCI.iloc[:, 2:6].T.to_dict()
    density_dict = density_data.set_index('material').T.to_dict()
    return LCI_dict, LCI_sigma_data_dict, linking_uncertainty_dict, density_dict

def add_material_intensity(data, LCI_dict, materials, infrastructure, density_dict):
    """Add per-unit material intensity columns to ``data``.

    Intensities are looked up in ``LCI_dict[material]`` under the key
    ``(infrastructure, fuel, class_or_type, ecoinvent_location)`` and
    multiplied by ``density_dict[material]['mode']`` when the material has a
    density entry (factor 1.0 otherwise).

    Column written per infrastructure:

    * ``'power plant'`` -> ``'<material> [kg/MW]'`` (uses the ``class`` column)
    * ``'pipeline'``    -> ``'<material> [kg/km]'`` (uses ``type``)
    * ``'coal mine'``   -> ``'<material> [kg/(Mt/a)]'`` (uses ``type``)
    * ``'wells'``       -> ``'<material> [kg/kg]'`` from the ``'oil'`` entries
      and ``'<material> [kg/m3]'`` from the ``'natural gas'`` entries

    ``data`` is modified in place and returned.
    """
    per_capacity_layouts = (
        ('power plant', '{} [kg/MW]', 'class'),
        ('pipeline', '{} [kg/km]', 'type'),
        ('coal mine', '{} [kg/(Mt/a)]', 'type'),
    )

    for material in materials:
        density_factor = density_dict[material]['mode'] if material in density_dict.keys() else 1.0

        for infra_name, column_template, type_key in per_capacity_layouts:
            if infrastructure == infra_name:
                data[column_template.format(material)] = [
                    LCI_dict[material][(infrastructure, fuel, clas, loc)] * density_factor
                    for fuel, clas, loc in zip(data.fuel, data[type_key], data.ecoinvent_location)
                ]

        if infrastructure == 'wells':
            def _well_intensities(fuel):
                # Intensity per row for one fuel, scaled by the density factor.
                return np.array([
                    LCI_dict[material][(infrastructure, fuel, clas, loc)] * density_factor
                    for clas, loc in zip(data['type'], data.ecoinvent_location)
                ])

            gas_well_intensities = _well_intensities('natural gas')
            oil_well_intensities = _well_intensities('oil')
            data['{} [kg/kg]'.format(material)] = oil_well_intensities
            data['{} [kg/m3]'.format(material)] = gas_well_intensities

    return data

def split_line_into_line_segments(line):
    """Split a line into one two-point ``LineString`` per consecutive vertex pair.

    Args:
        line: Geometry exposing ``coords``.

    Returns:
        list: ``len(coords) - 1`` segments (empty when there are fewer than
        two vertices).
    """
    vertices = list(line.coords)
    return [LineString([start, end]) for start, end in zip(vertices, vertices[1:])]

def segment_pipelines(pipelines):
    """Explode pipelines into individual two-point segments with metadata.

    Steps:
        1. ``WKTFormat`` geometries are split into their parts
           (``multiline_to_lines``) and exploded to one row per part. The
           ``Lines`` column is added to the input frame in place.
        2. Every part is split into two-point segments and exploded again.
        3. Per segment: length in km, a point on the segment, the country of
           that point and an onshore flag (1.0 when ``globe.is_land`` reports
           land for the point, else 0.0) stored under ``'% onshore'``.

    Args:
        pipelines: pandas.DataFrame with a ``WKTFormat`` geometry column.

    Returns:
        pandas.DataFrame: One row per segment with whichever of the columns
        ``ProjectID``, ``PipelineName``, ``fuel``, ``ecoinvent_location``,
        ``type``, ``Status``, ``Section Length [km]``, ``Line Segments``,
        ``% onshore``, ``PointOnSurface`` and ``Segment Country`` are present,
        plus ``Longitude`` and ``Latitude`` of the segment point.
    """
    pipelines['Lines'] = [multiline_to_lines(data) for data in pipelines['WKTFormat']]
    exp_pipelines = pipelines.explode('Lines')
    # The per-part length is not computed here: it would be overwritten by the
    # per-segment length below.
    exp_pipelines['Line Segments'] = [split_line_into_line_segments(line) for line in exp_pipelines['Lines']]
    segmented_exp_pipelines = exp_pipelines.explode('Line Segments')

    segments = list(segmented_exp_pipelines['Line Segments'])
    segmented_exp_pipelines['Section Length [km]'] = [get_length_of_linestring_in_km(line) for line in segments]

    # Compute the representative point once per segment.
    surface_points = [line.point_on_surface() for line in segments]
    point_tuples = [(point.x, point.y) for point in surface_points]
    segmented_exp_pipelines['PointOnSurface'] = point_tuples
    segmented_exp_pipelines['Segment Country'] = points_to_countries(list(point_tuples))
    segmented_exp_pipelines['% onshore'] = [float(globe.is_land(lon=p[0], lat=p[1])) for p in point_tuples]

    columns = ['ProjectID', 'PipelineName', 'fuel', 'ecoinvent_location', 'type',
               'Status', 'Section Length [km]', 'Line Segments', '% onshore',
               'PointOnSurface', 'Segment Country']
    available = [col for col in columns if col in segmented_exp_pipelines.columns]

    # Copy so the column assignments below act on an independent frame rather
    # than on a slice (avoids pandas' SettingWithCopyWarning).
    segmented_pipelines = segmented_exp_pipelines[available].copy()
    segmented_pipelines['Longitude'] = [p[0] for p in segmented_pipelines['PointOnSurface']]
    segmented_pipelines['Latitude'] = [p[1] for p in segmented_pipelines['PointOnSurface']]

    return segmented_pipelines

def multiline_to_lines(val):
    """Return the parts of a multi-part geometry as a list.

    Args:
        val: Any object. When it exposes an iterable ``geoms`` attribute
            (as multi-part shapely geometries do), its parts are returned.

    Returns:
        list of parts, or ``val`` itself when it has no usable ``geoms``.
    """
    try:
        return list(val.geoms)
    except (AttributeError, TypeError):
        # Single-part geometries and non-geometry values pass through as-is.
        return val

def points_to_countries(coords):
    """Resolve coordinate pairs to country names.

    The pairs are swapped with ``switch_lat_lon`` before being passed to
    ``reverse_geocoder.search``; the ``'cc'`` code of every hit is translated
    with ``alpha_2_to_name``.

    Args:
        coords: Sequence of 2-element pairs, treated as ``(lon, lat)`` by
            ``switch_lat_lon``.

    Returns:
        list: One country name per input pair.
    """
    matches = reverse_geocoder.search(switch_lat_lon(coords))
    countries = []
    for match in matches:
        countries.append(alpha_2_to_name(match['cc']))
    return countries

def alpha_2_to_name(alpha2):
    """Return the ``name`` of the pycountry entry found for ``alpha_2=alpha2``."""
    country = pycountry.countries.get(alpha_2=alpha2)
    return country.name

def get_production_start_year_dict(old_wells):
    """Map each active unit's ID to its production start year.

    NOTE(review): a function with the same name and the same body is defined
    again further down in this file; at import time that later definition
    replaces this one.

    Args:
        old_wells: pandas.DataFrame with ``Status``, ``Unit ID`` and
            ``Production start year`` columns.

    Returns:
        dict: ``{Unit ID: production start year}`` for rows whose status is
        ``operating``, ``mothballed`` or ``idle``. Missing values are replaced
        by the integer-truncated mean start year of those rows.
    """
    # Keep only active units and the two columns needed for the mapping.
    prod_year_data = old_wells[old_wells['Status'].isin(['operating', 'mothballed', 'idle'])][['Unit ID', 'Production start year']]
    
    # Values of exact type float/int pass through unchanged; anything else is
    # expected to be a string from which the ' (expected)' suffix is removed
    # before numeric conversion.
    prod_year_data['Production start year'] = pd.to_numeric([val if (type(val)==float or type(val)==int) 
                                                 else val.replace(' (expected)','') 
                                             for val in prod_year_data['Production start year']])
    
    # fillna() is applied to the whole two-column frame, so a missing
    # 'Unit ID' would be filled with the mean year as well.
    mean_production_year = int(pd.to_numeric(prod_year_data['Production start year']).mean())
    prod_year_data = prod_year_data.fillna(mean_production_year)
    production_start_year_dict = prod_year_data.set_index('Unit ID').to_dict()['Production start year']
    return production_start_year_dict

def get_EAF_from_GEM():
    """Load steel plants with EAF or IF equipment from the GEM steel tracker.

    Reads the ``Steel Plants`` sheet of
    ``Databases/Global-Steel-Plant-Tracker-April-2024-Standard-Copy-V1.xlsx``
    (relative to the current working directory). Rows are kept when the
    capacity status is ``operating``, ``mothballed`` or
    ``operating pre-retirement`` and either ``Main production equipment``
    contains ``'EAF'`` or ``Detailed production equipment`` contains ``'IF'``.
    ``'>0'`` entries in the other/unspecified capacity column become 0 and
    remaining missing values are filled with 0.

    Returns:
        pandas.DataFrame indexed by ``Plant ID`` with ``Latitude`` and
        ``Longitude`` parsed from the comma-separated ``Coordinates`` text.
    """
    workbook_path = os.path.abspath(os.getcwd()) + "/Databases/Global-Steel-Plant-Tracker-April-2024-Standard-Copy-V1.xlsx"
    other_capacity_col = 'Other/unspecified steel capacity (ttpa)'
    kept_columns = ['Plant ID', 'Country/Area', 'Region', 'Coordinates', 'Capacity operating status',
                    'Nominal EAF steel capacity (ttpa)', other_capacity_col,
                    'Main production equipment', 'Detailed production equipment']

    steel_dataset = pd.read_excel(workbook_path, sheet_name='Steel Plants')
    steel_dataset = steel_dataset[kept_columns]
    steel_dataset[other_capacity_col] = steel_dataset[other_capacity_col].replace('>0', 0)

    status_ok = steel_dataset['Capacity operating status'].isin(
        ['operating', 'mothballed', 'operating pre-retirement'])
    has_eaf = steel_dataset['Main production equipment'].str.contains('EAF')
    has_if = steel_dataset['Detailed production equipment'].str.contains('IF')
    steel_dataset = steel_dataset[status_ok & (has_eaf | has_if)].fillna(0)

    parsed = [tuple([float(ele) for ele in coord.split(',')])
              for coord in steel_dataset['Coordinates'].to_list()]
    lat, lon = zip(*parsed)
    steel_dataset['Latitude'] = list(lat)
    steel_dataset['Longitude'] = list(lon)
    return steel_dataset.set_index('Plant ID')

def get_EAF_from_GEM_gdp():
    """Return the ``get_EAF_from_GEM`` table as a GeoDataFrame.

    A ``geometry`` column is built from the ``Longitude`` / ``Latitude``
    columns with CRS ``EPSG:4326``.
    """
    geo_steel_dataset = gpd.GeoDataFrame(get_EAF_from_GEM())
    geo_steel_dataset['geometry'] = gpd.points_from_xy(
        geo_steel_dataset.Longitude,
        geo_steel_dataset.Latitude,
        crs="EPSG:4326",
    )
    return geo_steel_dataset

def extracting_locations_from_power_plants():
    """Build one table of active coal, oil and gas power plant units.

    Reads, relative to the current working directory,
    ``LCI_data/dict_fuel_type.csv`` (fuel-code lookup, column ``fuels``), the
    ``Units`` sheet of the coal plant tracker and the ``Gas & Oil Units``
    sheet of the oil and gas plant tracker. Units whose ``Status`` is
    ``operating``, ``mothballed`` or ``idle`` are kept.

    Derived columns:
        fuel: For coal units the lookup of ``Coal type``. For oil/gas units
            ``'natural gas'`` or ``'oil'`` when only that kind appears among
            the ``/``-separated ``Fuel`` codes, otherwise
            ``'oil & natural gas'`` (this includes units matching neither
            kind).
        class: ``'combined cycle'`` / ``'standard'`` for natural gas
            (``'CC'`` in ``Technology``), ``'<300 MW'`` / ``'>300 MW'`` for
            hard coal, otherwise the fuel name.
        ecoinvent_location: ``'GLO'`` for hard coal, ``'RoW'`` for
            ``'oil & natural gas'``, otherwise ``'RER'`` when ``Region`` is
            ``'Europe'`` and ``'RoW'`` elsewhere.

    Returns:
        pandas.DataFrame with columns ``Country``, ``Region``, ``ID``,
        ``Capacity (MW)``, ``Latitude``, ``Longitude``, ``fuel``,
        ``coordinates``, ``class`` and ``ecoinvent_location``.

    Raises:
        KeyError: If a ``Coal type`` value is missing from the fuel lookup.
    """
    cur_path = os.path.abspath(os.getcwd())
    active_statuses = ['operating', 'mothballed', 'idle']

    fuel_type_dict = pd.read_csv(cur_path+'/LCI_data/dict_fuel_type.csv', index_col=0).to_dict()['fuels']
    gas_codes = {key for key, kind in fuel_type_dict.items() if kind == 'gas'}
    oil_codes = {key for key, kind in fuel_type_dict.items() if kind == 'oil'}

    def _oil_gas_fuel(fuel_field):
        # Classify one '/'-separated fuel-code string.
        codes = set(fuel_field.split('/'))
        burns_gas = bool(gas_codes & codes)
        burns_oil = bool(oil_codes & codes)
        # Plain boolean logic; "~" on bool is bitwise inversion and is
        # deprecated since Python 3.12.
        if burns_gas and not burns_oil:
            return 'natural gas'
        if burns_oil and not burns_gas:
            return 'oil'
        return 'oil & natural gas'

    def _plant_class(fuel, cap, tech):
        if fuel == 'natural gas':
            # Technology may be missing (NaN); such units count as 'standard'.
            is_combined_cycle = isinstance(tech, str) and 'CC' in tech
            return 'combined cycle' if is_combined_cycle else 'standard'
        if fuel == 'hard coal':
            return '<300 MW' if cap < 300 else '>300 MW'
        return fuel

    def _location(fuel, region):
        if fuel == 'hard coal':
            return 'GLO'
        if fuel == 'oil & natural gas':
            return 'RoW'
        return 'RER' if region == 'Europe' else 'RoW'

    # Coal power plants
    filepath1 = cur_path+'/Databases/Global-Coal-Plant-Tracker-January-2024.xlsx'
    GCPT = pd.read_excel(filepath1, sheet_name='Units')
    GCPT = GCPT[GCPT['Status'].isin(active_statuses)]
    # .copy() so the column assignments below do not act on a slice.
    GCPT = GCPT[['Country', 'Region', 'GEM unit/phase ID', 'Capacity (MW)',
                 'Latitude', 'Longitude', 'Coal type']].copy()
    GCPT['fuel'] = [fuel_type_dict[fuel] for fuel in GCPT['Coal type']]
    GCPT['coordinates'] = list(zip(GCPT['Latitude'], GCPT['Longitude']))
    GCPT = GCPT[['Country', 'Region', 'GEM unit/phase ID', 'Capacity (MW)',
                 'Latitude', 'Longitude', 'fuel', 'coordinates']]
    GCPT = GCPT.rename(columns={'GEM unit/phase ID': 'ID'})

    # Oil and gas power plants
    filepath2 = cur_path+'/Databases/Global-Oil-and-Gas-Plant-Tracker-GOGPT-February-2024-v4.xlsx'
    GOGPT = pd.read_excel(filepath2, sheet_name='Gas & Oil Units')
    GOGPT = GOGPT[GOGPT['Status'].isin(active_statuses)]
    GOGPT = GOGPT[['Country', 'Region', 'GEM unit ID', 'Capacity (MW)',
                   'Latitude', 'Longitude', 'Fuel', 'Technology']].copy()
    GOGPT['fuel'] = [_oil_gas_fuel(fuel) for fuel in GOGPT['Fuel']]
    GOGPT['coordinates'] = list(zip(GOGPT['Latitude'], GOGPT['Longitude']))
    GOGPT = GOGPT[['Country', 'Region', 'GEM unit ID', 'Capacity (MW)',
                   'Latitude', 'Longitude', 'fuel', 'coordinates', 'Technology']]
    GOGPT = GOGPT.rename(columns={'GEM unit ID': 'ID'})

    # Coal rows have no 'Technology' column, so it is NaN for them after concat.
    GPPT = pd.concat([GCPT.reset_index(drop=True), GOGPT.reset_index(drop=True)])

    GPPT['Capacity (MW)'] = [float(val) for val in GPPT['Capacity (MW)']]
    GPPT['class'] = [_plant_class(fuel, cap, tech)
                     for fuel, cap, tech in zip(GPPT['fuel'], GPPT['Capacity (MW)'], GPPT['Technology'])]
    GPPT['ecoinvent_location'] = [_location(fuel, region)
                                  for fuel, region in zip(GPPT['fuel'], GPPT['Region'])]

    return GPPT[['Country', 'Region', 'ID', 'Capacity (MW)', 'Latitude', 'Longitude',
                 'fuel', 'coordinates', 'class', 'ecoinvent_location']]

def extracting_locations_from_gas_pipelines(threshold_cap, threshold_dm):
    """Load active gas pipelines with geometry, length and a stock location.

    Reads ``Databases/GEM-GGIT-Gas-Pipelines-December-2022.xlsx`` relative to
    the current working directory and keeps rows whose ``Status`` is ``Idle``,
    ``Mothballed`` or ``Operating`` and whose ``WKTFormat`` is not ``'--'``.

    Args:
        threshold_cap: Capacity threshold forwarded to
            ``get_capacity_of_pipelines``.
        threshold_dm: Diameter threshold forwarded to
            ``get_capacity_of_pipelines``.

    Returns:
        pandas.DataFrame with the identification columns, ``type`` (capacity
        class), ``WKTFormat`` parsed into geometries, ``Length``,
        ``% onshore``, ``stock_lat`` / ``stock_lon`` (duplicated as
        ``Latitude`` / ``Longitude``), ``fuel`` (``'natural gas'``) and
        ``ecoinvent_location`` (``'GLO'``).
    """
    cur_path = os.path.abspath(os.getcwd())
    filepath = cur_path+'/Databases/GEM-GGIT-Gas-Pipelines-December-2022.xlsx'
    gas_pipelines = pd.read_excel(filepath)

    is_active = gas_pipelines['Status'].isin(['Idle', 'Mothballed', 'Operating'])
    has_route = gas_pipelines['WKTFormat'] != '--'
    # .copy() so the column assignments below act on an independent frame
    # rather than on a slice of the workbook data.
    gas_pipelines = gas_pipelines[is_active & has_route].copy()

    # Classified before the column selection because the classification reads
    # the capacity and diameter columns, which are dropped afterwards.
    gas_pipelines['type'] = get_capacity_of_pipelines(gas_pipelines, threshold_cap, threshold_dm)

    gas_pipelines = gas_pipelines[['ProjectID', 'StartCountry', 'EndCountry', 'Countries', 'PipelineName',
                                   'Status', 'StartRegion', 'EndRegion', 'WKTFormat', 'type']].copy()

    gas_pipelines['WKTFormat'] = [loads(s) if isinstance(s, str) else s for s in gas_pipelines.WKTFormat]
    gas_pipelines['Length'] = [get_length_of_linestring_in_km(ls) for ls in gas_pipelines.WKTFormat]
    gas_pipelines['% onshore'] = [percent_land(return_points_between_linestring(line))
                                  for line in gas_pipelines.WKTFormat]

    # Stock points are unpacked as (lon, lat) pairs.
    stock_lons, stock_lats = zip(*[return_stock_point_of_line(line_string)
                                   for line_string in gas_pipelines.WKTFormat])
    gas_pipelines['stock_lat'] = stock_lats
    gas_pipelines['stock_lon'] = stock_lons
    gas_pipelines['Latitude'] = stock_lats
    gas_pipelines['Longitude'] = stock_lons
    gas_pipelines['fuel'] = ['natural gas'] * len(gas_pipelines)
    gas_pipelines['ecoinvent_location'] = ['GLO'] * len(gas_pipelines)

    return gas_pipelines

def extracting_locations_from_oil_pipelines():
    """Load active oil/NGL pipelines with geometry, length and stock location.

    Reads ``Databases/GOIT-Oil-NGL-Pipelines-June-2022-v2.xlsx`` relative to
    the current working directory and keeps rows whose ``Status`` is ``Idle``,
    ``Mothballed`` or ``Operating`` and whose ``WKTFormat`` is not ``'--'``.
    No regional filter is applied.

    Returns:
        pandas.DataFrame with the identification columns, ``WKTFormat`` parsed
        into geometries, ``Length``, ``% onshore``, ``stock_lat`` /
        ``stock_lon``, ``fuel`` (``'oil'``), ``type`` (``'offshore'`` when
        ``% onshore`` equals 0.0, else ``'onshore & offshore'``) and
        ``ecoinvent_location`` (``'GLO'`` for offshore, otherwise ``'RER'``
        when ``StartRegion`` is ``'Europe'`` and ``'RoW'`` elsewhere).
    """
    workbook_path = os.path.abspath(os.getcwd()) + '/Databases/GOIT-Oil-NGL-Pipelines-June-2022-v2.xlsx'
    oil_pipelines = pd.read_excel(workbook_path)

    kept_columns = ['ProjectID', 'StartCountry', 'EndCountry', 'PipelineName',
                    'Status', 'StartRegion', 'EndRegion', 'WKTFormat']
    oil_pipelines = oil_pipelines[kept_columns]
    oil_pipelines = oil_pipelines[oil_pipelines['Status'].isin(['Idle', 'Mothballed', 'Operating'])]
    oil_pipelines = oil_pipelines[oil_pipelines['WKTFormat'] != '--']

    oil_pipelines['WKTFormat'] = [loads(s) if type(s) is str else s for s in oil_pipelines.WKTFormat]
    oil_pipelines['Length'] = [get_length_of_linestring_in_km(geom) for geom in oil_pipelines.WKTFormat]
    oil_pipelines['% onshore'] = [percent_land(return_points_between_linestring(geom))
                                  for geom in oil_pipelines.WKTFormat]

    # Stock points are unpacked as (lon, lat) pairs.
    stock_points = [return_stock_point_of_line(geom) for geom in oil_pipelines.WKTFormat]
    stock_lons, stock_lats = zip(*stock_points)
    oil_pipelines['stock_lat'] = stock_lats
    oil_pipelines['stock_lon'] = stock_lons

    oil_pipelines['fuel'] = ['oil'] * len(oil_pipelines)

    pipeline_types = []
    for p_onshore in oil_pipelines['% onshore']:
        pipeline_types.append('offshore' if p_onshore == 0.0 else 'onshore & offshore')
    oil_pipelines['type'] = pipeline_types

    locations = []
    for typ, region in zip(oil_pipelines['type'], oil_pipelines['StartRegion']):
        if typ == 'offshore':
            locations.append('GLO')
        elif region == 'Europe':
            locations.append('RER')
        else:
            locations.append('RoW')
    oil_pipelines['ecoinvent_location'] = locations
    return oil_pipelines

def extracting_locations_from_coal_mines(coal_mine_capacity_utililsation): # Extract coal mine locations and capacities (GEM data)
    """Load active coal mines with capacity, fuel, type and ecoinvent location.

    NOTE(review): a function with the same name is defined again directly
    below this one; at import time that later definition replaces this one.

    No regional filter is applied in this function.

    Args:
        coal_mine_capacity_utililsation: Divisor applied to output figures
            that are not reported as ``'Capacity'`` to estimate capacity.

    Returns:
        pandas.DataFrame with columns ``Mine IDs``, ``Country``, ``Region``,
        ``Status``, ``Latitude``, ``Longitude``, ``coordinates``,
        ``Capacity [Mt/a]``, ``type``, ``fuel`` and ``ecoinvent_location``.
    """
    cur_path = os.path.abspath(os.getcwd())
    filepath  = cur_path+'/Databases/Global-Coal-Mine-Tracker-July-2022.xlsx'
    coalmines = pd.read_excel(filepath, sheet_name = 'Global Coal Mine Tracker')
    coalmines = coalmines[['Mine IDs', 'Country', 'Status','Latitude', 'Longitude',
                           'Mine Type', 'Coal Type','Region','Production or Capacity Data (Mtpa)',
                           'Coal Output (Annual, Mt)']]

    # Lookup from lower-cased coal type to fuel name (column 'fuels').
    fuel_type_dict = pd.read_csv(cur_path+'/LCI_data/dict_fuel_type.csv', index_col=0).to_dict()['fuels']

    
    # Keep active mines only.
    coalmines = coalmines[(coalmines['Status'].isin(['Idle', 'Mothballed','Operating']))]
    
    # Rows whose output is the placeholder '*' are dropped.
    coalmines = coalmines[coalmines['Coal Output (Annual, Mt)']!='*']
    # Figures reported as 'Capacity' are used directly; all others are divided
    # by the utilisation factor.
    coalmines['Capacity [Mt/a]'] = [val if var=='Capacity' else val/coal_mine_capacity_utililsation
        for val,var in zip(coalmines['Coal Output (Annual, Mt)'], coalmines['Production or Capacity Data (Mtpa)'])]
    # Country means are taken before the zero capacities are replaced below,
    # so those zeros are part of the mean.
    mean_capacity_per_country_dict = coalmines.groupby('Country')['Capacity [Mt/a]'].mean().reset_index().set_index('Country').to_dict()
    
    coalmines['fuel'] = [fuel_type_dict[coal.lower()] for coal in coalmines['Coal Type']]
    coalmines['coordinates'] = list(zip(coalmines['Latitude'], coalmines['Longitude']))

    coalmines['type'] = coalmines['Mine Type'].str.lower()
    # First matching rule wins: surface hard coal -> GLO, Chinese hard coal ->
    # CN, European lignite -> RER, everything else -> RoW.
    coalmines['ecoinvent_location'] = ['GLO' if (typ=='surface' and fuel=='hard coal') else 
                                       'CN' if (country=='China' and fuel=='hard coal') else 
                                       'RER' if (region=='Europe' and fuel=='lignite') else 'RoW'
               for typ, country, region, fuel in zip(coalmines['type'], coalmines['Country'], coalmines['Region'], coalmines['fuel'])]
    
    # Mines with zero capacity get the mean capacity of their country.
    no_data = coalmines[coalmines['Capacity [Mt/a]']==0]
    if len(no_data)>0:
        print('No capacity data found for {} coal mines. Filled data gap with mean capacity in country.'.format(len(no_data)))
        coalmines['Capacity [Mt/a]'] = [cap if cap>0 else mean_capacity_per_country_dict['Capacity [Mt/a]'][country] 
                                        for cap, country in zip(coalmines['Capacity [Mt/a]'], coalmines['Country'])]
    
    return coalmines[['Mine IDs','Country','Region', 'Status','Latitude', 'Longitude', 'coordinates',
                            'Capacity [Mt/a]', 'type', 'fuel','ecoinvent_location']]

def extracting_locations_from_coal_mines(coal_mine_capacity_utililsation):
    """Load active coal mines with capacity, fuel, type and ecoinvent location.

    Reads the ``Global Coal Mine Tracker`` sheet of
    ``Databases/Global-Coal-Mine-Tracker-July-2022.xlsx`` and the fuel lookup
    ``LCI_data/dict_fuel_type.csv`` (both relative to the current working
    directory). Mines with status ``Idle``, ``Mothballed`` or ``Operating``
    are kept and rows whose annual output is ``'*'`` are dropped. No regional
    filter is applied.

    Args:
        coal_mine_capacity_utililsation: Divisor applied to output figures
            that are not reported as ``'Capacity'`` to estimate capacity.

    Returns:
        pandas.DataFrame with columns ``Mine IDs``, ``Country``, ``Region``,
        ``Status``, ``Latitude``, ``Longitude``, ``coordinates``,
        ``Capacity [Mt/a]``, ``type``, ``fuel`` and ``ecoinvent_location``.
        Zero capacities are replaced by the country mean, which is computed
        before the replacement and therefore includes the zeros.
    """
    base_dir = os.path.abspath(os.getcwd())
    tracker_path = base_dir+'/Databases/Global-Coal-Mine-Tracker-July-2022.xlsx'
    source_columns = ['Mine IDs', 'Country', 'Status', 'Latitude', 'Longitude',
                      'Mine Type', 'Coal Type', 'Region', 'Production or Capacity Data (Mtpa)',
                      'Coal Output (Annual, Mt)']
    coalmines = pd.read_excel(tracker_path, sheet_name='Global Coal Mine Tracker')
    coalmines = coalmines[source_columns]

    fuel_type_dict = pd.read_csv(base_dir+'/LCI_data/dict_fuel_type.csv', index_col=0).to_dict()['fuels']

    coalmines = coalmines[coalmines['Status'].isin(['Idle', 'Mothballed', 'Operating'])]
    coalmines = coalmines[coalmines['Coal Output (Annual, Mt)'] != '*']

    capacities = []
    for output, reported_as in zip(coalmines['Coal Output (Annual, Mt)'],
                                   coalmines['Production or Capacity Data (Mtpa)']):
        if reported_as == 'Capacity':
            capacities.append(output)
        else:
            capacities.append(output / coal_mine_capacity_utililsation)
    coalmines['Capacity [Mt/a]'] = capacities

    country_means = (coalmines.groupby('Country')['Capacity [Mt/a]'].mean()
                     .reset_index().set_index('Country').to_dict())

    coalmines['fuel'] = [fuel_type_dict[coal.lower()] for coal in coalmines['Coal Type']]
    coalmines['coordinates'] = list(zip(coalmines['Latitude'], coalmines['Longitude']))
    coalmines['type'] = coalmines['Mine Type'].str.lower()

    def _location(typ, country, region, fuel):
        # First matching rule wins.
        if typ == 'surface' and fuel == 'hard coal':
            return 'GLO'
        if country == 'China' and fuel == 'hard coal':
            return 'CN'
        if region == 'Europe' and fuel == 'lignite':
            return 'RER'
        return 'RoW'

    coalmines['ecoinvent_location'] = [
        _location(typ, country, region, fuel)
        for typ, country, region, fuel in zip(coalmines['type'], coalmines['Country'],
                                              coalmines['Region'], coalmines['fuel'])
    ]

    without_capacity = coalmines[coalmines['Capacity [Mt/a]'] == 0]
    if len(without_capacity) > 0:
        print('No capacity data found for {} coal mines. Filled data gap with mean capacity in country.'.format(len(without_capacity)))
        coalmines['Capacity [Mt/a]'] = [
            cap if cap > 0 else country_means['Capacity [Mt/a]'][country]
            for cap, country in zip(coalmines['Capacity [Mt/a]'], coalmines['Country'])
        ]

    return coalmines[['Mine IDs', 'Country', 'Region', 'Status', 'Latitude', 'Longitude', 'coordinates',
                      'Capacity [Mt/a]', 'type', 'fuel', 'ecoinvent_location']]

def extracting_locations_from_oil_and_gas_wells(): #Extracting locations from oil and gas extraction (GEM data)
    """Load active oil and gas fields with cumulative production estimates.

    Reads the ``Main data`` and ``Production & reserves`` sheets of
    ``Databases/Global-Oil-and-Gas-Extraction-Tracker-March-2024.xlsx``
    relative to the current working directory.

    Returns:
        pandas.DataFrame: One row per unit of type ``field`` with status
        ``mothballed``, ``idle`` or ``operating`` and no missing value in the
        selected columns. Besides the source columns it carries ``type``
        (from ``is_land_or_water``), ``cum. gas prod.``, ``cum. oil prod.``,
        ``approximated`` (1 when the unit has no entry in the cumulative
        production lookup), ``coordinates`` and ``ecoinvent_location``
        (``'GLO'``).
    """
    # Redundant with the module-level ``import os``.
    import os
    cur_path = os.path.abspath(os.getcwd())
    filepath = cur_path+'/Databases/Global-Oil-and-Gas-Extraction-Tracker-March-2024.xlsx'
    wells = pd.read_excel(filepath, sheet_name = 'Main data')
    
    # Built from the unfiltered sheet, before columns and rows are restricted.
    production_start_year_dict=  get_production_start_year_dict(wells)
    
    metadata = pd.read_excel(filepath, sheet_name = 'Production & reserves')
    metadata = metadata[['Unit ID', 'Fuel description', 'Data year','Production/reserves','Quantity (converted)', 'Units (converted)']]
    
    # NOTE(review): get_cumulative_production_dict accepts the start-year
    # dict but does not use it.
    cumprod_dict = get_cumulative_production_dict(metadata, production_start_year_dict)

    wells = wells[['Unit ID', 'Unit type','Country','Latitude',
                   'Longitude','Status', 'Fuel type', 'Production start year']]

    # Active fields only; dropna() removes rows with any missing value in the
    # selected columns.
    wells = wells[(wells['Unit type']=='field') &
                 (wells['Status'].isin(['mothballed', 'idle', 'operating'])) 
                  #&(wells['Country'].isin(europe))
                 ].dropna()
    
    wells['type'] =  [is_land_or_water(lat, lon) for lat, lon in zip(wells['Latitude'], wells['Longitude'])]
    
    # Units without an entry in the lookup start as NaN.
    wells['cum. gas prod.'] = [cumprod_dict[ID]['gas'] if ID in cumprod_dict.keys() else np.nan
                                       for ID in wells['Unit ID']]
    wells['cum. oil prod.'] = [cumprod_dict[ID]['oil'] if ID in cumprod_dict.keys() else np.nan
                                       for ID in wells['Unit ID']]

    # Reset so that index labels equal row positions; the .iat writes below
    # rely on that.
    wells = wells.reset_index(drop=True)
    mean_prod_dict = wells[['Country', 'cum. gas prod.', 'cum. oil prod.']].groupby('Country').mean().to_dict()
    # Replace NaN cumulative production with the mean of the unit's country.
    for col in ['cum. gas prod.', 'cum. oil prod.']:
        col_nr = list(wells.columns).index(col)
        for idx, country, prod in list(zip(wells.index, wells['Country'], wells[col])):
            if np.isnan(prod):
                wells.iat[idx, col_nr] = mean_prod_dict[col][country]

    no_data = sum([ID not in cumprod_dict.keys() for ID in wells['Unit ID']])
    wells['approximated'] = [0 if ID in cumprod_dict.keys() else 1 for ID in wells['Unit ID']]
    print('No production data found for {} wells. Filled data gap with mean annual production for wells in country multiplied with operating years.'.format(no_data))
    #Filling data gaps
    # NOTE(review): int() raises on non-numeric text such as
    # '2020 (expected)' -- confirm such values cannot reach this point.
    wells['Production start year'] = [int(val) for val in wells['Production start year']]
    mean_annual_production_per_country = get_mean_annual_production_per_country()

    # Only values equal to 0 are re-estimated. The helper updates ``wells`` in
    # place and returns the same frame, so both calls act on one object.
    updated_wells = fill_cumulative_production_data_gaps(wells, 'cum. gas prod.', 'Production start year','Country','Unit ID',
                                           mean_annual_production_per_country, get_approx_cumulative_production_of_well)
    updated_wells = fill_cumulative_production_data_gaps(wells, 'cum. oil prod.', 'Production start year','Country','Unit ID',
                                           mean_annual_production_per_country, get_approx_cumulative_production_of_well)
    
    updated_wells['coordinates'] = list(zip(wells['Latitude'], wells['Longitude']))
    updated_wells['ecoinvent_location'] = ['GLO']*len(updated_wells)
    return updated_wells

def fill_cumulative_production_data_gaps(df, col_to_update, col_to_use1, col_to_use2,ID_col, dct, func):
    """Replace zero entries of one column with values computed by ``func``.

    For every row whose ``col_to_update`` value equals 0, the value becomes
    ``func(col_to_update, row[col_to_use1], row[col_to_use2], row[ID_col], dct)``.
    Other rows keep their value.

    Args:
        df: pandas.DataFrame; modified in place.
        col_to_update: Column whose zeros are replaced (also passed to ``func``).
        col_to_use1: Column supplying the second argument of ``func``.
        col_to_use2: Column supplying the third argument of ``func``.
        ID_col: Column supplying the fourth argument of ``func``.
        dct: Object passed through as the last argument of ``func``.
        func: Callable producing the replacement value.

    Returns:
        The same ``df`` object.
    """
    def _filled(row):
        current = row[col_to_update]
        if current == 0:
            return func(col_to_update, row[col_to_use1], row[col_to_use2], row[ID_col], dct)
        return current

    df[col_to_update] = df.apply(_filled, axis=1)
    return df

def get_approx_cumulative_production_of_well(col, start_year,country,ID, mean_annual_production_per_country):
    """Estimate cumulative production as mean annual production times age.

    Args:
        col: Column name; must contain ``'gas'`` (checked first) or ``'oil'``
            to select the fuel.
        start_year: Production start year, convertible with ``int``.
        country: Key into ``mean_annual_production_per_country``.
        ID: Unit identifier, used only in the diagnostic message.
        mean_annual_production_per_country: ``{country: {fuel: production}}``
            with fuel keys ``'natural gas'`` and ``'oil'``.

    Returns:
        The annual production multiplied by ``current year - start_year``, or
        0 when no production figure exists for the country/fuel (a message is
        printed in that case).

    Raises:
        Exception: If ``col`` names neither gas nor oil.
    """
    if 'gas' in col:
        fuel = 'natural gas'
    elif 'oil' in col:
        fuel = 'oil'
    else:
        raise Exception('Unknown fuel.', col)

    lifetime = int(datetime.now().year) - int(start_year)
    try:
        annual_production = mean_annual_production_per_country[country][fuel]
    except (KeyError, TypeError):
        # Missing country or fuel entry: fall back to zero production.
        print('No production data found for {1} well in {0}. ID: {2}'.format(country, fuel, ID))
        annual_production = 0
    return annual_production * lifetime

def get_mean_annual_production_per_country():
    """Compute mean annual oil and gas production per country.

    Reads the ``Production & reserves`` sheet of
    ``Databases/Global-Oil-and-Gas-Extraction-Tracker-March-2024.xlsx``
    (relative to the current working directory), keeps rows marked
    ``'production'`` and averages ``Quantity (converted)`` per unit and
    country. Rows with unit ``'million bbl/y'`` are labelled ``'oil'``, every
    other unit ``'natural gas'``; means are summed per fuel and country.

    Returns:
        dict: ``{country: {'natural gas': value, 'oil': value}}`` with missing
        combinations set to 0, gas multiplied by 10**6 and oil multiplied by
        10**6 * 136.4.
    """
    workbook_path = os.path.abspath(os.getcwd()) + '/Databases/Global-Oil-and-Gas-Extraction-Tracker-March-2024.xlsx'
    metadata = pd.read_excel(workbook_path, sheet_name='Production & reserves')

    production_rows = metadata[metadata['Production/reserves'] == 'production']
    unit_means = production_rows.groupby(['Units (converted)', 'Country'])['Quantity (converted)'].mean()
    unit_means = unit_means.reset_index().set_index('Country')

    fuels = []
    for unit in unit_means['Units (converted)']:
        fuels.append('oil' if unit == 'million bbl/y' else 'natural gas')
    unit_means['fuel'] = fuels

    per_fuel = unit_means.groupby(['fuel', 'Country'])['Quantity (converted)'].sum().reset_index()
    pivoted = per_fuel.pivot(index='Country', columns='fuel', values='Quantity (converted)').fillna(0)

    pivoted['natural gas'] = pivoted['natural gas'] * 10**6  # million m3 to cubic meters
    pivoted['oil'] = pivoted['oil'] * ((10**6) * 136.4)  # million bbl to kilogram
    return pivoted.T.to_dict()

def get_production_start_year_dict(old_wells):
    """Map each active unit's ID to its production start year.

    Args:
        old_wells: pandas.DataFrame with ``Status``, ``Unit ID`` and
            ``Production start year`` columns. It is not modified.

    Returns:
        dict: ``{Unit ID: production start year}`` for rows whose status is
        ``operating``, ``mothballed`` or ``idle``. A ``' (expected)'`` suffix
        on text values is stripped before numeric conversion, and missing
        values are replaced by the integer-truncated mean start year of those
        rows.
    """
    is_active = old_wells['Status'].isin(['operating', 'mothballed', 'idle'])
    # .copy() so the column assignment below acts on an independent frame
    # rather than on a slice of ``old_wells``.
    prod_year_data = old_wells.loc[is_active, ['Unit ID', 'Production start year']].copy()

    # Only strings carry the suffix; every other value (Python or NumPy
    # numbers, missing values) is handed to pd.to_numeric unchanged.
    cleaned = [val.replace(' (expected)', '') if isinstance(val, str) else val
               for val in prod_year_data['Production start year']]
    prod_year_data['Production start year'] = pd.to_numeric(cleaned)

    mean_production_year = int(prod_year_data['Production start year'].mean())
    # fillna() covers the whole two-column frame, as before.
    prod_year_data = prod_year_data.fillna(mean_production_year)
    return prod_year_data.set_index('Unit ID').to_dict()['Production start year']

def get_cumulative_production_dict(metadata, production_start_year_dict):
    """Collect the largest reported cumulative gas and oil production per unit.

    Args:
        metadata: DataFrame with the columns 'Unit ID', 'Units (converted)',
            'Production/reserves' and 'Quantity (converted)'.
        production_start_year_dict: accepted for interface compatibility;
            it is not read by this function.

    Returns:
        dict: ``{unit_id: {'gas': value, 'oil': value}}``. Gas is the maximum
        'million m³' quantity times 10**6, oil the maximum 'million bbl'
        quantity times 10**6 * 136.4. A fuel falls back to 0 whenever its
        value cannot be computed (e.g. no matching rows).
    """
    cumulative_by_unit = {}
    for unit_id in set(metadata['Unit ID']):
        unit_rows = metadata[(metadata['Unit ID'] == unit_id)]

        # Best effort: any failure (typically max() of an empty selection)
        # means "no data" and yields 0.
        try:
            gas_rows = unit_rows[(unit_rows['Units (converted)'] == 'million m³') &
                                 (unit_rows['Production/reserves'] == 'cumulative production')]
            gas_total = max(gas_rows['Quantity (converted)'])*10**6  # cubic meters
        except:
            gas_total = 0

        try:
            oil_rows = unit_rows[(unit_rows['Units (converted)'] == 'million bbl') &
                                 (unit_rows['Production/reserves'] == 'cumulative production')]
            oil_total = max(oil_rows['Quantity (converted)'])*(10**6)*136.4  # kg
        except:
            oil_total = 0

        cumulative_by_unit[unit_id] = {'gas': gas_total, 'oil': oil_total}
    return cumulative_by_unit

def _pipeline_value_to_float(value):
    """Return ``value`` as float, or NaN when it is missing or non-numeric (e.g. '--')."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return np.nan


def get_capacity_of_pipelines(df, threshold_cap, threshold_dm):
    """Classify every pipeline as 'low capacity' or 'high capacity'.

    The capacity column is used when it holds a number; otherwise the
    classification falls back to the diameter (converted to millimetres).

    Args:
        df: table with the columns 'CapacityBcm/y', 'Diameter' and
            'DiameterUnits' (units 'in' or 'mm' are understood).
        threshold_cap: limit compared against ``capacity * 1000 / 365 / 24``.
        threshold_dm: limit compared against the diameter in millimetres.

    Returns:
        list: one entry per row, 'low capacity', 'high capacity' or NaN when
        neither capacity nor diameter is usable.
    """
    capacity_types = []
    for cap, dm, dm_unit in zip(df['CapacityBcm/y'], df['Diameter'], df['DiameterUnits']):
        # NaN never compares equal to anything, so missing values have to be
        # detected with isnan rather than ``== np.nan`` / ``!= np.nan``.
        cap_mncm = _pipeline_value_to_float(cap)*1000/365/24

        if not np.isnan(cap_mncm):
            capacity_type = 'low capacity' if cap_mncm < threshold_cap else 'high capacity'
        else:
            # No usable capacity: fall back to the diameter in millimetres.
            if dm_unit == 'in':
                dm_mm = _pipeline_value_to_float(dm)*2.54*10
            elif dm_unit == 'mm':
                dm_mm = _pipeline_value_to_float(dm)
            else:
                dm_mm = np.nan

            if np.isnan(dm_mm):
                capacity_type = np.nan
            elif dm_mm < threshold_dm:
                capacity_type = 'low capacity'
            else:
                capacity_type = 'high capacity'
        capacity_types.append(capacity_type)
    return capacity_types

def get_length_of_linestring_in_km(line_string):
    """Return the geodesic length of a line geometry in kilometres.

    The length is measured on the WGS84 ellipsoid. If the geometry cannot be
    measured directly, its parts (``line_string.geoms``) are measured
    individually and summed.

    Args:
        line_string: a line geometry, or a multi-part geometry exposing
            ``.geoms``.

    Returns:
        float: total length in km.
    """
    geod = Geod(ellps="WGS84")
    try:
        # geometry_length returns metres.
        total_length = geod.geometry_length(line_string)/1000
    except Exception:
        # The recursive call already returns kilometres, so the sum must not
        # be divided by 1000 again.
        total_length = sum(get_length_of_linestring_in_km(part) for part in line_string.geoms)
    return total_length

def distance(lat1, lat2, lon1, lon2):
    """Great-circle distance in kilometres between two points (haversine).

    Note the argument order: both latitudes first, then both longitudes.
    All four values are given in degrees.
    """
    earth_radius_km = 6371  # use 3956 for miles

    # Work in radians from here on.
    lon1, lon2 = radians(lon1), radians(lon2)
    lat1, lat2 = radians(lat1), radians(lat2)

    half_dlat = (lat2 - lat1) / 2
    half_dlon = (lon2 - lon1) / 2

    # Haversine formula
    a = sin(half_dlat)**2 + cos(lat1) * cos(lat2) * sin(half_dlon)**2
    central_angle = 2 * asin(sqrt(a))

    return central_angle * earth_radius_km

def flatten(xss):
    """Concatenate the items of every iterable in ``xss`` into one flat list (one level deep)."""
    flat = []
    for xs in xss:
        flat.extend(xs)
    return flat

def percent_land(points):
    """
    Return the share (0..1) of the given coordinates that lie on the land mask.

    Each point is first read as (lat, lon); if that fails for any reason the
    components are tried in swapped order. When both attempts fail a message
    is printed and None is returned.
    """
    # (index of first argument, index of second argument) per attempt
    for first, second in ((0, 1), (1, 0)):
        try:
            on_land = [globe.is_land(point[first], point[second]) for point in points]
            return sum(on_land)/len(points)
        except:
            continue
    print('Problem: points are weird: ', points)
        
def return_points_between_linestring(ls):
    """Return interpolated points along a line geometry as a flat list.

    Every pair of consecutive vertices is passed to ``return_points_between``
    and the resulting points are returned as tuples. If that fails for any
    reason (e.g. the geometry has no ``coords``), the geometry is treated as
    multi-part and the points of all its ``geoms`` are concatenated.

    NOTE(review): vertex component 0 is passed as ``lat`` and component 1 as
    ``lon`` - confirm the coordinate order of the incoming geometries.
    """
    try:
        vertices = list(ls.coords)
        segments = [
            return_points_between(start[0], end[0], start[1], end[1])
            for start, end in zip(vertices, vertices[1:])
        ]
        return [tuple(coord) for segment in segments for coord in segment]
    except:
        per_part = [return_points_between_linestring(part) for part in ls.geoms]
        return flatten(per_part)

def get_ID_stock_country_dict(material):
    """Load country and median stock of ``material`` per infrastructure ID.

    Reads ``<cwd>/Results/stacked_stock_data.xlsx`` and returns a
    column-oriented dict keyed by 'Country' and '<material> median [kg]',
    each mapping 'Infr. ID' to the respective value.
    """
    workbook = os.path.abspath(os.getcwd()) + '/Results/stacked_stock_data.xlsx'
    stock_data = pd.read_excel(workbook, index_col=0)

    wanted_columns = ['Infr. ID', 'Country', '{} median [kg]'.format(material)]
    unique_rows = stock_data[wanted_columns].drop_duplicates()
    return unique_rows.set_index('Infr. ID').to_dict()

def get_distance_between_points(coord1, coord2):
    """Return the haversine distance in km between two coordinate pairs.

    Component 0 of each pair is passed to ``distance`` as latitude and
    component 1 as longitude. If the first attempt fails, ``coord2`` is
    assumed to hold strings containing non-breaking spaces; these are
    removed, the values converted to float and the distance recomputed.
    """
    try:
        return distance(coord1[0], coord2[0], coord1[1], coord2[1])
    except:
        lat2, lon2 = (float(coord2[0].replace('\xa0','')), float(coord2[1].replace('\xa0','')))
        return distance(coord1[0], lat2, coord1[1], lon2)
        
def return_points_between(lat1, lat2, lon1, lon2):
    """Return evenly spaced points between two coordinates, endpoints included.

    Roughly one point per 10 km of great-circle distance is generated by
    linear interpolation of the coordinate values.

    Args:
        lat1, lat2: first components of the start and end point.
        lon1, lon2: second components of the start and end point.

    Returns:
        list: ``[(lat1, lon1), (lat2, lon2)]`` for segments shorter than
        20 km, otherwise a list of ``[lat, lon]`` pairs from start to end.
    """
    step = 10 #km
    d = distance(lat1, lat2, lon1, lon2)
    interpols = int(d/step)
    # np.linspace with a single sample returns only the start point, which
    # would silently drop the end of the segment; keep both endpoints instead.
    if interpols < 2:
        return [(lat1, lon1), (lat2, lon2)]
    points = np.linspace([lat1, lon1], [lat2, lon2], interpols)
    return points.tolist()

def get_EAFs_data_and_distances_to_stock(material):
    """Build the distance matrix between stock locations and EAF plants.

    Loads the EAF table via ``get_EAF_from_GEM()`` and the stock table from
    ``<cwd>/Results/stacked_stock_data.xlsx``, keeps stock rows whose
    '<material> median [kg]' is non-zero and computes the distance from every
    EAF coordinate to every remaining stock coordinate.

    Returns:
        tuple: ``(EAFs_data, EAFs, distance_df, stock_country_dict)`` where
        ``distance_df`` has one row per 'Infr. ID' and one column per EAF
        index entry, ``EAFs`` is its column index and ``stock_country_dict``
        maps 'Infr. ID' to 'Country'.
    """
    cur_path   = os.path.abspath(os.getcwd())
    EAFs_data  = get_EAF_from_GEM()
    coords_eaf = EAFs_data.Coordinates
    filepath   = cur_path +'/Results/stacked_stock_data.xlsx'
    stock_data = pd.read_excel(filepath, index_col=0)
    # Built before the zero-stock filter below, so it also covers IDs that
    # hold none of the requested material.
    stock_country_dict = stock_data[['Infr. ID', 'Country']].drop_duplicates().set_index('Infr. ID').to_dict()['Country']
    stock_data = stock_data[stock_data['{} median [kg]'.format(material)]!=0]
    try:
        print('Try to import distances')
        # NOTE(review): read_excel is called without a path, so this
        # presumably always raises and the distances are always recomputed
        # below - confirm which cached file was meant to be loaded here.
        distance_df = pd.read_excel()
    except:
        # NOTE(review): eval() executes the coordinate strings as Python
        # code; only safe while both input files are trusted.
        # Outer loop: EAFs, inner loop: stock rows -> transposed afterwards so
        # that rows are stock IDs and columns are EAFs.
        distances = [[get_distance_between_points(eval(coord1), eval(coord2)) for coord2 in stock_data.coords]
                         for coord1 in coords_eaf]
        distance_df = pd.DataFrame(np.array(distances).T, index=stock_data['Infr. ID'], columns=EAFs_data.index)
    EAFs = distance_df.columns
    return EAFs_data, EAFs, distance_df, stock_country_dict

def get_stocks_per_country(material):
    """Return the summed median stock of ``material`` per country.

    Reads ``<cwd>/Results/fossil_infrastructure_material_stocks_per_country.xlsx``
    and sums the '<material> median [kg]' column by 'Country'.

    Returns:
        dict: ``{country: total_stock}``.
    """
    workbook = os.path.abspath(os.getcwd()) + '/Results/fossil_infrastructure_material_stocks_per_country.xlsx'
    stocks_per_country = pd.read_excel(workbook, index_col=0)

    stock_column = '{} median [kg]'.format(material)
    country_totals = stocks_per_country.groupby('Country')[stock_column].sum()
    return country_totals.to_dict()

def get_mean_distance_within_quantile(df, ID, q):
    """Mean of the values in ``df[ID]`` that do not exceed its ``q``-quantile.

    When the mean is not a single ``np.float64`` (e.g. ``df[ID]`` selects
    several columns), the first entry of the result is returned.
    """
    selected = df[ID]
    cutoff = selected.quantile(q)
    mean_below_cutoff = selected[selected <= cutoff].mean()

    if type(mean_below_cutoff) is np.float64:
        return mean_below_cutoff
    return list(mean_below_cutoff)[0]

def get_mean_distance_of_closest_eaf_based(df, ID, n):
    """Mean of the ``n`` smallest values in column ``ID`` of ``df``."""
    per_row = df[[ID]].mean(axis=1)
    ascending = per_row.sort_values(ascending=True)
    closest = ascending[:n]
    return closest.mean()

def get_mean_distance_of_closest_infr_based(df, ID, n):
    """Mean of the ``n`` smallest values in row ``ID`` of ``df`` (works on the transpose)."""
    transposed = df.T
    per_entry = transposed[[ID]].mean(axis=1)
    ascending = per_entry.sort_values(ascending=True)
    closest = ascending[:n]
    return closest.mean()

def replace_nan(x, mean_cap_per_country_dict):
    """Fill missing entries of a row with the mean value of the row's country.

    Args:
        x: row (Series) with a 'Country' entry.
        mean_cap_per_country_dict: fill value per country.

    Returns:
        ``x`` itself when nothing is missing, otherwise a filled copy.
    """
    country = x['Country']
    has_missing = any(x.isna())
    if not has_missing:
        return x
    fill_value = mean_cap_per_country_dict[country]
    return x.fillna(fill_value)

def get_mode_dict():
    """Load the inland transport mode split per country as fractions.

    Reads ``<cwd>/Results/EUROSTAT_inland_transport_mode_split.xlsx`` and
    divides the 'Rail', 'Roads', 'Inland waterways' and 'Sea' columns by 100.

    Returns:
        dict: ``{country: {mode: share}}``.
    """
    workbook = os.path.abspath(os.getcwd()) + '/Results/EUROSTAT_inland_transport_mode_split.xlsx'
    by_country = pd.read_excel(workbook).set_index('Country')
    shares = by_country[['Rail', 'Roads','Inland waterways','Sea']]/100
    return shares.T.to_dict()

def get_country_and_global_representative_transport_distances_stock_based(Ds_dict, route_factor, ID_and_stock_country_dict, material):
    """Aggregate per-ID representative distances to country and global level.

    Each ID's distance is multiplied by ``route_factor`` and weighted by the
    ID's share of its country's stock; the country values are then weighted
    by each country's share of the total stock to obtain the global value.

    Args:
        Ds_dict: representative distance [km] per infrastructure ID.
        route_factor: multiplier applied to every distance.
        ID_and_stock_country_dict: dict with the keys 'Country' and
            '<material> median [kg]', each mapping ID to a value.
        material: material name used to build the stock key.

    Returns:
        tuple: (DataFrame indexed by country with the single column
        'intra-country weighted repr. distance [km]', global distance).
    """
    dist_col = 'repr. distance [km]'
    weighted_col = 'intra-country weighted repr. distance [km]'
    country_of = ID_and_stock_country_dict['Country']
    stock_of = ID_and_stock_country_dict['{} median [kg]'.format(material)]

    # Per-ID table: distance, country and stock.
    ID_distances = pd.DataFrame({dist_col: Ds_dict})
    ID_distances['Country'] = [country_of[ID] for ID in ID_distances.index]
    ID_distances['Stock [kg]'] = [stock_of[ID] for ID in ID_distances.index]

    stock_country_dict = ID_distances.groupby('Country')['Stock [kg]'].sum().to_dict()
    intra_weights = []
    for ID_stock, country in zip(ID_distances['Stock [kg]'], ID_distances['Country']):
        intra_weights.append(ID_stock/stock_country_dict[country])
    ID_distances['intra-country stock weight'] = intra_weights
    ID_distances[weighted_col] = ID_distances[dist_col]*route_factor*ID_distances['intra-country stock weight']

    # Per-country table: stock-weighted distance and the country's stock share.
    country_distances = ID_distances.groupby('Country')[[weighted_col]].sum()
    country_distances['country stocks [ton]'] = [stock_country_dict[country] for country in country_distances.index]
    total_stock = country_distances['country stocks [ton]'].sum()
    country_distances['country stock weight'] = country_distances['country stocks [ton]']/total_stock

    glo_ds = (country_distances[weighted_col]*country_distances['country stock weight']).sum()
    return country_distances[[weighted_col]], glo_ds

def filter_scenario(data, scenario, method_contains, mode, not_mf_vals, product, slag_modes, process):
    """Select the result rows of one scenario / product / process combination.

    Side effect: an 'indicator' column holding ``(method[0], method[2])`` is
    added to the ``data`` frame passed in.

    Rows are dropped when the string form of their method contains '&', or
    'total' if ``mode == 'contribution'``. The remaining rows must match
    ``scenario`` (pattern match on the 'scenario' column), ``product``
    ('steel'), ``slag_modes`` ('slag') and ``process``, and their method must
    be listed in ``method_contains``.

    Returns:
        DataFrame: the filtered rows.
    """
    data['indicator'] = [(met[0], met[2]) for met in data.method]

    def lacks_some_excluded_value(method):
        # NOTE(review): a row is kept as soon as ONE of not_mf_vals is absent
        # from the method; confirm that all() was not intended here.
        return any([s not in str(method) for s in not_mf_vals])

    def has_no_ampersand(method):
        return '&' not in str(method)

    def is_not_total(method):
        return 'total' not in str(method)

    data = data[data['method'].apply(lacks_some_excluded_value)]
    data = data[data['method'].apply(has_no_ampersand)]
    if mode=='contribution':
        data = data[data['method'].apply(is_not_total)]

    row_mask = data['scenario'].str.contains(scenario)
    row_mask = row_mask & (data['steel']==product)
    row_mask = row_mask & (data['slag'].isin(slag_modes))
    row_mask = row_mask & (data['process']==process)
    selected = data[row_mask].reset_index(drop=True)

    return selected[selected['method'].isin(method_contains)]

def get_stock_dict(stocks):
    """Return the 2.5 %, 50 % and 97.5 % column quantiles of ``stocks``.

    Returns:
        dict: ``{'lower': {...}, 'median': {...}, 'upper': {...}}`` with one
        value per column of ``stocks``.
    """
    quantile_levels = (('lower', 0.025), ('median', 0.500), ('upper', 0.975))
    return {
        label: stocks.quantile(level, axis=0).to_dict()
        for label, level in quantile_levels
    }

def add_relative_savings(saving_dict2):
    """Add a 'steel sum' entry combining the three steel types.

    For each bound ('lower', 'median', 'upper') the primary and secondary
    values of reinforcing, chromium and low-alloyed steel are added up and
    stored under ``saving_dict2['steel sum']``; 'relative saving' holds the
    secondary sum divided by the primary sum.

    The dict is modified in place and also returned.
    """
    steels = ('reinforcing steel', 'chromium steel', 'low-alloyed steel')
    combined = {'primary': {}, 'secondary': {}, 'relative saving': {}}
    saving_dict2['steel sum'] = combined

    for bound in ('lower', 'median', 'upper'):
        # Look everything up first, then store the results.
        prim_a, prim_b, prim_c = [saving_dict2[steel]['primary'][bound] for steel in steels]
        sec_a, sec_b, sec_c = [saving_dict2[steel]['secondary'][bound] for steel in steels]

        combined['primary'][bound] = prim_a + prim_b + prim_c
        combined['secondary'][bound] = sec_a + sec_b + sec_c
        combined['relative saving'][bound] = combined['secondary'][bound] / combined['primary'][bound]
    return saving_dict2

def add_absolute_savings(saving_dict1, stock_dict):
    """Add a 'steel sum' entry with the summed savings of the three steel types.

    One sum is stored per key of ``stock_dict`` (only its keys are used).
    The dict is modified in place and also returned.
    """
    steels = ('reinforcing steel', 'chromium steel', 'low-alloyed steel')
    totals = {}
    saving_dict1['steel sum'] = totals
    for bound in stock_dict.keys():
        first, second, third = [saving_dict1[steel][bound] for steel in steels]
        totals[bound] = first + second + third
    return saving_dict1

def add_to_other(df, percentage=0.05):
    """Collapse minor columns of ``df`` into a single 'Other' column.

    A column is minor when the absolute value of its mean is below
    ``percentage`` times the smallest row-wise sum of absolute values.

    The input frame is left untouched. An already existing 'Other' column is
    kept and the minor columns are added to it.

    Args:
        df: numeric DataFrame.
        percentage: fraction of the smallest absolute row sum used as
            threshold.

    Returns:
        DataFrame: the remaining columns plus 'Other'.
    """
    column_means = df.mean()
    # Reference magnitude: the smallest row total of absolute values.
    smallest_abs_row_sum = abs(df).sum(axis=1).min()
    threshold = abs(percentage * smallest_abs_row_sum)

    minor_columns = column_means[abs(column_means) < threshold].index
    other = df[minor_columns].sum(axis=1)
    # A pre-existing, non-minor 'Other' column must not be lost.
    if 'Other' in df.columns and 'Other' not in minor_columns:
        other = other + df['Other']

    # drop() returns a new frame, so the caller's data is never modified.
    result = df.drop(columns=minor_columns)
    result['Other'] = other
    return result

def aggregate_absolute_savings(df, indicator_agg_dict):
    """Group the columns of ``df`` into aggregated indicators.

    Each column label is evaluated as a Python expression; element 1 of the
    result selects the aggregated indicator from ``indicator_agg_dict``.
    Columns sharing an aggregated indicator are summed, and minor indicators
    are merged into 'Other' via ``add_to_other`` (5 % threshold).

    NOTE(review): eval() runs the column labels as code - only safe for
    trusted input.
    """
    by_indicator = df.T
    agg_labels = []
    for idx in by_indicator.index:
        agg_labels.append(indicator_agg_dict[eval(idx)[1]])
    by_indicator['agg_indicator'] = agg_labels

    summed = by_indicator.groupby('agg_indicator').sum().T
    return add_to_other(summed, 0.05)

def get_errors(saving_dict):
    """Return asymmetric error-bar sizes from lower / median / upper frames.

    Row totals are taken over the columns of each frame.

    Returns:
        np.ndarray: two rows, ``median - lower`` and ``upper - median``.
    """
    median_total = saving_dict['median'].sum(axis=1)
    lower_total = saving_dict['lower'].sum(axis=1)
    upper_total = saving_dict['upper'].sum(axis=1)

    below = np.array(median_total - lower_total)
    above = np.array(upper_total - median_total)
    return np.array([below, above])

def get_production_costs(unit_externalities, year=2025):
    """Combine production costs with external costs per steel type.

    Reads ``<cwd>/Databases/production_costs.xlsx``, divides the values by
    1000 (to USD/kg) and transposes the table so that each row is one steel
    type. For the 'primary' and 'secondary' route an 'ext. cost <route>'
    column is added: the sum over all externality columns whose label does
    not contain 'resources', taken at ``year``.

    Args:
        unit_externalities: nested mapping ``[steel][route]`` -> DataFrame
            indexed by year.
        year: index label at which the external costs are read.

    Returns:
        DataFrame with the columns 'steel', 'prod. cost primary',
        'prod. cost secondary', 'ext. cost primary', 'ext. cost secondary'.
    """
    workbook = os.path.abspath(os.getcwd()) + '/Databases/production_costs.xlsx'
    prod_costs = pd.read_excel(workbook, index_col=0)/1000 #to USD/kg

    prod_cost_T = prod_costs.T.reset_index()
    # Any column that is neither 'BF-BOF' nor 'EAF' is labelled 'steel'.
    new_names = {'BF-BOF': 'prod. cost primary', 'EAF': 'prod. cost secondary'}
    prod_cost_T.columns = [new_names.get(col, 'steel') for col in prod_cost_T.columns]

    for process in ('primary', 'secondary'):
        external_costs = []
        for steel in prod_cost_T.steel:
            impacts = unit_externalities[steel][process]
            non_resource_cols = [col for col in impacts.columns if 'resources' not in col]
            external_costs.append(impacts[non_resource_cols].sum(axis=1).loc[year])
        prod_cost_T['ext. cost {}'.format(process)] = external_costs
    return prod_cost_T