from genericpath import isfile
from jinja2 import FileSystemLoader, Environment
import os
import re
import glob
import shutil
import platform
import subprocess
from numpy import append
import pandas as pd

def precheck(args):
    """Validate the run configuration and create a fresh result directory.

    The checks run in a fixed order and terminate the process through
    ``exit`` with a message as soon as one fails: the project folder
    (``args.project_dir``/``args.project``) must exist, ``args.data_dir``
    must exist and be non-empty, and ``args.result_dir`` must differ from
    ``args.project_dir``.

    Afterwards a new directory named ``<solver>_<algorithm>_<dataset>`` is
    created under ``args.result_dir`` (a ``-N`` suffix is added when the name
    is taken), the ``original_prompt_code.txt`` template of the project is
    rendered with ``timeout=args.sat_timeout`` into ``best_prompt_code.txt``,
    and the solver-specific files are copied next to it.

    Returns:
        The path of the newly created result directory.

    Raises:
        ValueError: if ``args.solver_name`` is neither ``"easysat"`` nor
            ``"minisat"`` (raised after the directory has been created).
    """
    project_dir = os.path.join(args.project_dir, args.project)

    # Guard clauses; the order determines which message the user sees first.
    if not os.path.exists(project_dir):
        exit("The project_dir {} does not exist".format(project_dir))
    if not os.path.exists(args.data_dir):
        exit("The data/instance folder {} does not exist".format(args.data_dir))
    if not os.listdir(args.data_dir):
        exit("The data/instance folder {} is empty".format(args.data_dir))
    if args.result_dir == args.project_dir:
        exit('The project_dir and result_dir can not be identical.')

    # A trailing separator yields an empty basename; use the parent component then.
    dataset_name = (os.path.basename(args.data_dir)
                    or os.path.basename(os.path.dirname(args.data_dir)))

    base_dir = os.path.join(
        args.result_dir,
        "_".join((args.solver_name, args.algorithm, dataset_name)),
    )

    # Probe "<base>", "<base>-1", "<base>-2", ... until an unused name is found.
    result_dir = base_dir
    suffix = 1
    while os.path.exists(result_dir):
        result_dir = f"{base_dir}-{suffix}"
        suffix += 1
    os.makedirs(result_dir)

    prompt_template = Environment(
        loader=FileSystemLoader(project_dir)
    ).get_template("original_prompt_code.txt")
    rendered_prompt = prompt_template.render(timeout=args.sat_timeout)
    with open(os.path.join(result_dir, 'best_prompt_code.txt'), 'w') as f:
        f.write(rendered_prompt)

    solver = args.solver_name
    if solver == "easysat":
        shutil.copy(os.path.join(project_dir, args.original_hpp_code),
                    os.path.join(result_dir, 'best_hpp_file.txt'))
        shutil.copy(os.path.join(project_dir, 'heap.hpp'),
                    os.path.join(result_dir, 'heap.hpp'))
    elif solver == "minisat":
        shutil.copytree(os.path.join(project_dir, solver),
                        os.path.join(result_dir, solver))
    else:
        raise ValueError("Wrong Solver Name !!!")
    return result_dir


def extract_log_info(file_path):
    """Read a solver log and return its last line and the reported CPU time.

    Args:
        file_path: path of a UTF-8 encoded log file.

    Returns:
        A ``(last_line, cpu_time)`` tuple. ``last_line`` is the final line of
        the file with surrounding whitespace removed, or ``None`` for an empty
        file. ``cpu_time`` is the number from the first line matching
        ``CPU time : <number> s`` as a float, or ``None`` when no line matches
        (previously this case crashed with ``TypeError`` on ``float(None)``).

    Raises:
        ValueError: if the matched number is malformed (e.g. ``1.2.3``).
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        log_lines = [line.strip() for line in file]

    last_line = log_lines[-1] if log_lines else None

    # Only the first "CPU time" occurrence is used.
    cpu_time_pattern = re.compile(r'CPU time\s*:\s*([\d.]+)\s*s')
    cpu_time = None
    for line in log_lines:
        match = cpu_time_pattern.search(line)
        if match:
            cpu_time = float(match.group(1))
            break

    return last_line, cpu_time


def get_code(answer, seperator):
    """Extract the text enclosed by a pair of markers from ``answer``.

    Args:
        answer: the full text to search.
        seperator: a two-item sequence ``(opening, closing)`` of marker strings.

    Returns:
        The text between the first ``opening`` marker and the next ``closing``
        marker, with every ``'''`` and triple-backtick sequence removed.
        If the opening marker is absent an empty string is returned; if the
        closing marker is absent everything after the opening marker is used.
    """
    opening, closing = seperator[0], seperator[1]

    start = answer.find(opening)
    if start == -1:
        # Without this guard the -1 from find() produced an arbitrary slice.
        return ""
    start += len(opening)

    end = answer.find(closing, start)
    if end == -1:
        # A -1 end index would silently drop the last character.
        end = len(answer)

    content = answer[start:end]
    return content.replace("'''", "").replace("```", "")


def revise_file(file_name, save_dir, *args, **kwargs):
    """Render a Jinja2 template and write the result to a file.

    ``file_name`` is looked up by a ``FileSystemLoader`` rooted at the current
    working directory; ``*args``/``**kwargs`` are forwarded to the template's
    ``render`` call. Despite its name, ``save_dir`` is the path of the output
    file, which is opened in text write mode.
    """
    loader = FileSystemLoader('.')
    rendered = Environment(loader=loader).get_template(file_name).render(*args, **kwargs)

    with open(save_dir, 'w') as out:
        out.write(rendered)


def clean_files(folder_path, mode, *args, **kwargs):
    """Delete files inside ``folder_path`` according to ``mode``.

    Args:
        folder_path: directory whose direct children are inspected.
        mode: ``"all"`` removes every regular file (sub-directories are kept),
            ``"exe"`` removes files matching ``*.exe`` and reports failures on
            stdout instead of raising, ``"folder"`` is accepted but does
            nothing.
        *args, **kwargs: accepted for call compatibility and ignored.

    Raises:
        NotImplementedError: for any other ``mode``. (The original
            ``raise NotImplemented`` produced a ``TypeError`` because
            ``NotImplemented`` is not an exception class.)
    """
    if mode == "all":
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)
    elif mode == "exe":
        for file_path in glob.glob(os.path.join(folder_path, "*.exe")):
            try:
                os.remove(file_path)
            except OSError as e:
                # Best effort: report the failure and keep going.
                print(f"Error deleting {file_path}: {e}")
    elif mode == "folder":
        pass
    else:
        raise NotImplementedError(f"unsupported clean mode: {mode!r}")


def process_raw_results(folder_path, timeout, answers=None):
    """Aggregate per-run result files found in ``folder_path``.

    Only files named exactly ``<id>_<num>.txt`` (both parts decimal digits)
    are read. Each non-empty line that does not start with ``File name`` must
    contain tab-separated ``name``, integer ``duration`` and ``situation``
    (``satisfiable``/``unsatisfiable``/``timeout``, case-insensitive).
    Statistics are accumulated per ``<id>`` across all of its files.

    Args:
        folder_path: directory holding the raw result files.
        timeout: a duration at or above this value contributes
            ``2 * timeout`` to the PAR-2 sum instead of the duration itself.
        answers: training mode when not ``None``; ``answers[int(id)]`` is
            stored as the prompt for each id (when ``answers`` is non-empty).

    Returns:
        Training mode (``answers`` is not ``None``): a dict with the keys
        ``time``, ``prompt``, ``PAR-2``, ``satisfiable``, ``unsatisfiable``
        and ``timeout``, each mapping id strings to values.

        Evaluation mode (``answers`` is ``None``): ``(summary, rows)`` where
        ``summary`` holds the statistics of id ``'1'`` only (``total time``,
        the three situation counts, ``#question`` and the PAR-2 averaged over
        ``#question`` and rounded to two decimals) and ``rows`` is the list of
        ``(name, duration, situation)`` tuples that were read.

    Raises:
        KeyError: for an unknown situation label, or in evaluation mode when
            no file for id ``'1'`` exists.
        ZeroDivisionError: in evaluation mode when id ``'1'`` has no rows.
    """
    result = {
        "time": {},
        "prompt": {},
        "PAR-2": {},
        "satisfiable": {},
        "unsatisfiable": {},
        "timeout": {},
    }
    record_all_data = [] if answers is None else None

    # Escaped dot + fullmatch so that names such as "1_2.txt.bak" or
    # "1_2_txt" are no longer mistaken for result files.
    name_pattern = re.compile(r'(\d+)_(\d+)\.txt')

    for filename in os.listdir(folder_path):
        match = name_pattern.fullmatch(filename)
        if not match:
            continue
        run_id = match.group(1)
        file_path = os.path.join(folder_path, filename)
        if not os.path.isfile(file_path):
            continue

        tmp_total_time = 0
        tmp_par2 = 0
        tmp_situation = {"satisfiable": 0,
                         "unsatisfiable": 0,
                         "timeout": 0}
        with open(file_path, 'r') as file:
            for line in file:
                line = line.strip()
                # Skip blank lines (previously an IndexError) and the header.
                if not line or line.startswith('File name'):
                    continue
                parts = line.split('\t')
                duration = int(parts[1])
                situation_single = parts[2].lower()
                tmp_situation[situation_single] += 1
                tmp_total_time += duration
                tmp_par2 += duration if duration < timeout else 2 * timeout
                if record_all_data is not None:
                    record_all_data.append((parts[0], duration, situation_single))

        # Merge this file's totals into the per-id accumulators.
        if run_id in result["time"]:
            result["time"][run_id] += tmp_total_time
            result["PAR-2"][run_id] += tmp_par2
            for situation_key, count in tmp_situation.items():
                result[situation_key][run_id] += count
        else:
            result["time"][run_id] = tmp_total_time
            result["PAR-2"][run_id] = tmp_par2
            result["prompt"][run_id] = answers[int(run_id)] if answers else 'Evaluation Stage.'
            for situation_key, count in tmp_situation.items():
                result[situation_key][run_id] = count

    if answers is not None:  # train
        return result

    # eval: only id '1' is reported.
    result['total time'] = result.pop('time')
    result.pop('prompt')
    result_dict = {k: v['1'] for k, v in result.items()}
    result_dict['#question'] = (result_dict['satisfiable']
                                + result_dict['unsatisfiable']
                                + result_dict['timeout'])
    result_dict['PAR-2'] = round(result_dict['PAR-2'] / result_dict['#question'], 2)
    return result_dict, record_all_data


def collect_results(answers, repetition_dict, results, args):
    """Aggregate the current round's raw results and pick the fastest candidate.

    Raw files are read from ``./temp/results/`` via ``process_raw_results``
    with ``args.timeout``. When ``args.devoid_duplication`` is set, every
    prompt in ``repetition_dict`` that can be found in ``results["prompt"]``
    has its previously recorded entry carried over into the new result
    (lookup stops at the first prompt that cannot be found).

    Unlike the original, the carry-over copies every metric that ``results``
    holds for the key (not just ``time`` and ``prompt``), so that the PAR-2
    lookup for the best key does not fail or mix metrics from two different
    runs when the best candidate is a carried-over one.

    Args:
        answers: prompts of the current round, indexed by candidate id.
        repetition_dict: mapping whose values are prompts that were skipped
            as duplicates of earlier ones.
        results: earlier result dict with at least ``"time"`` and ``"prompt"``.
        args: namespace providing ``timeout`` and ``devoid_duplication``.

    Returns:
        ``(result, best)`` where ``best`` maps the key with the smallest total
        time to ``[time, prompt, PAR-2]``.

    Raises:
        ValueError: if no result is available at all (``min`` of an empty dict).
    """
    repetition_result = {
        "time": {},
        "prompt": {},
        "PAR-2": {},
        "satisfiable": {},
        "unsatisfiable": {},
        "timeout": {},
    }
    folder_path = './temp/results/'
    result = process_raw_results(folder_path=folder_path, timeout=args.timeout, answers=answers)
    if args.devoid_duplication:
        for value in repetition_dict.values():
            key = find_key_for_value(results["prompt"], value)
            if key is None:
                break
            repetition_result["time"][key] = results["time"][key]
            repetition_result["prompt"][key] = results["prompt"][key]
            # Remaining metrics are optional in ``results``; copy what exists.
            for metric in ("PAR-2", "satisfiable", "unsatisfiable", "timeout"):
                previous = results.get(metric, {})
                if key in previous:
                    repetition_result[metric][key] = previous[key]

        for metric, carried in repetition_result.items():
            result[metric].update(carried)

    best_key = min(result["time"], key=result["time"].get)
    return result, {best_key: [result["time"][best_key], result["prompt"][best_key], result["PAR-2"][best_key]]}


def collect_results_eval(raw_path, final_path, args):
    """Summarise evaluation results and append a report to ``final_path``.

    The raw files under ``raw_path`` are aggregated by ``process_raw_results``
    in evaluation mode using ``args.eval_timeout``. The report consists of a
    header line, one tab-separated line per recorded instance, and the summary
    dict rendered with ``str``. The file is opened in append mode (UTF-8).

    Returns:
        The summary dict produced by ``process_raw_results``.
    """
    result_dict, record_all_data = process_raw_results(
        folder_path=raw_path, timeout=args.eval_timeout, answers=None)  # eval mode

    report_lines = ["cnf File \t Duration \t Situation \n"]
    report_lines.extend(
        f"{cnf_name}\t{duration}\t{situation}\n"
        for cnf_name, duration, situation in record_all_data
    )
    report_lines.append(str(result_dict) + '\n')

    with open(final_path, 'a+', encoding='utf-8') as f:
        f.writelines(report_lines)
    return result_dict


def fill_core_codes(origin_file, target_file, answer_code,**kwargs):
    """Render ``origin_file`` into ``target_file`` with the answer code filled in.

    ``answer_code`` is passed to the template as ``replace_code``. The
    ``timeout`` and ``data_dir`` variables are rendered back to their own
    ``{{ ... }}`` placeholder text, so they remain placeholders in the output.
    Extra keyword arguments are forwarded to ``revise_file`` unchanged.
    """
    # Variables whose placeholder text must survive this rendering pass.
    template_values = dict(
        timeout='{{ timeout }}',
        data_dir='{{ data_dir }}',
        replace_code=answer_code,
    )
    revise_file(origin_file, target_file, **template_values, **kwargs)


def delete_InfiniteLoopInst(candidates, result_dict, results_folder='./temp/results/'):
    """Drop candidates that produced no result file and kill leftover solvers.

    For every name in ``candidates`` that is not a regular file inside
    ``results_folder``, the candidate id (the part before the first ``_``
    after removing the substring ``finished``) is removed from every inner
    mapping of ``result_dict``. ``result_dict`` is modified in place.

    Afterwards processes named ``EasySAT`` are terminated on a best-effort
    basis (``taskkill`` on Windows, ``pkill -f`` on Linux); a failure to do so
    is printed and otherwise ignored. Note that ``pkill`` exits non-zero when
    no process matched, which is reported the same way.

    Args:
        candidates: expected result file names.
        result_dict: mapping of metric name to ``{id: value}`` mappings.
        results_folder: directory in which the result files are looked up.

    Raises:
        NotImplementedError: on platforms other than Windows and Linux
            (after ``result_dict`` has already been updated).
    """
    for file_name in candidates:
        if os.path.isfile(os.path.join(results_folder, file_name)):
            continue
        # Missing result file: the candidate failed (e.g. never terminated).
        id_str = file_name.replace('finished', '').split('_')[0]
        for key in result_dict:
            if id_str in result_dict[key]:
                result_dict[key].pop(id_str)

    # Kill the solver processes. Maybe dangerous: matches by name only.
    system = platform.system()
    if system == 'Windows':
        try:
            # TODO check: taskkill image names usually carry an extension.
            subprocess.run(['taskkill', '/F', '/IM', 'EasySAT'], check=True, text=True)
        except (subprocess.SubprocessError, OSError) as e:
            print(f"wrong when killing procession: {e}")
    elif system == 'Linux':
        try:
            subprocess.run(['pkill', '-f', 'EasySAT'], check=True, stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE, text=True)
        except (subprocess.SubprocessError, OSError) as e:
            print(f"wrong when killing procession: {e}")
    else:
        raise NotImplementedError('sorry, we only support Wins Or Linux.')

    return


def copy_folder(src_folder, num, mode='train', target_folder = None):
    """Duplicate a source folder for parallel training runs or for evaluation.

    Args:
        src_folder: folder to copy; may be given with or without a trailing
            path separator.
        num: number of copies to create in ``train`` mode (unused in ``eval``).
        mode: ``'train'`` creates ``<src>_0/`` ... ``<src>_{num-1}/`` next to
            the source; ``'eval'`` creates a single copy at ``target_folder``.
        target_folder: destination for ``eval`` mode.

    Existing destination folders are deleted before copying.

    Raises:
        ValueError: in ``eval`` mode when ``target_folder`` is ``None``.
        NotImplementedError: for any other ``mode``.
    """
    if mode == 'train':
        # Strip trailing separators instead of blindly dropping the last
        # character, which corrupted the name when no separator was present.
        base = src_folder.rstrip(os.sep + (os.altsep or ''))
        for i in range(num):
            new_folder_path = base + "_{}/".format(i)
            if os.path.exists(new_folder_path):
                shutil.rmtree(new_folder_path)
            shutil.copytree(src_folder, new_folder_path)
    elif mode == 'eval':
        if target_folder is None:
            raise ValueError('please set target folder to save source files.')
        if os.path.exists(target_folder):
            shutil.rmtree(target_folder)
        shutil.copytree(src_folder, target_folder)
    else:
        raise NotImplementedError('please choose `mode` between `train` or `eval`.')


def find_key_for_value(results, value_to_find):
    """Return the first key of ``results`` whose value equals ``value_to_find``.

    Keys are examined in the mapping's iteration order; ``None`` is returned
    when no value compares equal.
    """
    matching_keys = (k for k, v in results.items() if v == value_to_find)
    return next(matching_keys, None)

def check_reIteration(round, best_result_dict, baseline):
    """Decide whether the prompt should be restarted after the first round.

    Returns ``True`` only when ``round`` is 1 and the first entry of
    ``best_result_dict`` beats the baseline on neither total time (index 0
    against ``baseline['time']``) nor PAR-2 (index 2 against
    ``baseline['PAR-2']``). Any other round always yields ``False``.
    """
    if round == 1:
        best_results = next(iter(best_result_dict.values()))
        beats_baseline = (best_results[0] < baseline['time']
                          or best_results[2] < baseline['PAR-2'])
        # Restart only if the candidate is no improvement at all.
        return not beats_baseline
    return False

def get_verified_function(func_candis, answered_function_code, best_code):
    """Pair each candidate function with its new and its current source code.

    Args:
        func_candis: names of the solver functions under consideration.
        answered_function_code: code snippets returned by the model.
        best_code: current best source, with each function enclosed between
            ``// start <name>`` and ``// end <name>`` markers.

    Returns:
        ``(new_code, curr_code)``. ``new_code`` maps ``replace_<name>_code``
        to the first snippet containing ``Solver::<name>(`` (empty string if
        none does). ``curr_code`` maps ``curr_<name>_code`` to the stripped
        text between the markers; functions without markers are omitted and a
        message is printed for each of them.
    """
    new_code = {
        f'replace_{name}_code': next(
            (snippet for snippet in answered_function_code
             if f'Solver::{name}(' in snippet),
            '',
        )
        for name in func_candis
    }

    curr_code = {}
    for name in func_candis:
        escaped = re.escape(name)
        # Non-greedy capture of everything between the two marker comments.
        marker_regex = r"// start " + escaped + "(.*?)" + r"// end " + escaped
        found = re.search(marker_regex, best_code, re.DOTALL)
        if not found:
            print("没有找到匹配的内容")
            continue
        curr_code[f'curr_{name}_code'] = found.group(1).strip()

    return new_code, curr_code

def get_advice(func_candis, advice_answer):
    """Extract the per-function advice sections from a model answer.

    Each section is expected between ``// start advice for <name>`` and
    ``// end advice for <name>`` markers.

    Args:
        func_candis: names of the functions advice is expected for.
        advice_answer: the full answer text.

    Returns:
        A dict mapping each function name to its stripped advice text, or
        ``False`` as soon as the section of any function is missing. (The
        original stored ``False`` and kept looping, which raised ``TypeError``
        when a later function did have a section.)
    """
    advice = {}

    for func_c in func_candis:
        start_pattern = r"// start advice for " + re.escape(func_c)
        end_pattern = r"// end advice for " + re.escape(func_c)
        pattern = f"{start_pattern}(.*?){end_pattern}"
        match = re.search(pattern, advice_answer, re.DOTALL)

        if match is None:
            # Incomplete answer: signal failure for the whole set.
            return False
        advice[f'{func_c}'] = match.group(1).strip()
    return advice

def rewrite_source_file(answered_function_code,rewrite_funcs,project_dir,result_dir,output_dir,best_code,args):
    """Build a new solver source and header with selected functions replaced.

    Steps, as performed by the code below:

    1. ``best_code`` is scanned line by line and written to the scratch file
       ``<output_dir>/tmp_best_code``. For every function in ``rewrite_funcs``
       the region delimited by lines exactly equal to ``// start <name>`` and
       ``// end <name>`` is replaced by the markers plus a Jinja placeholder
       ``{{ replace_<name>_code }}``. Scanning stops at a line exactly equal
       to ``int main(int argc, char **argv) {``.
    2. The scratch file is rendered as a Jinja template; each placeholder
       receives the first snippet of ``answered_function_code`` containing
       ``Solver::<name>(`` (empty string if none does).
    3. The template ``args.source_file`` from ``project_dir`` is rendered with
       ``replace_key_code`` set to the result of step 2 and written to
       ``<output_dir>/<args.source_file>``.
    4. The template ``args.hpp_file`` from ``project_dir`` is rendered with
       ``new_func_definition`` set to the content of
       ``<result_dir>/best_hpp_file.txt`` and written to
       ``<output_dir>/EasySAT.hpp``.

    NOTE(review): because the scratch file is rendered by Jinja, any Jinja
    syntax occurring in ``best_code`` itself would presumably be interpreted
    as template syntax as well - confirm this cannot happen for the sources
    in use.

    Returns:
        ``(output_code, source_file)``: the rendered function section and the
        path of the written source file.
    """
    best_code_lines = best_code.split('\n')
    

    with open(os.path.join(output_dir,'tmp_best_code'),'w') as tmpf:
        # write_flag is False while inside a region that is being replaced.
        write_flag = True
        current_func_c = ''
        print("start writing tmp_best_code !")

        for best_line in best_code_lines:
            for func_c in rewrite_funcs:

                # find position of rewritten functions
                if f'// start {func_c}' == best_line:
                    tmpf.write(f'// start {func_c}\n')
                    # Quadruple braces in str.format produce a literal "{{ ... }}".
                    tmpf.write('{{{{ replace_{}_code }}}}\n'.format(func_c))
                    tmpf.write(f'// end {func_c}\n')
                    current_func_c = func_c
                    write_flag=False
                    break
            if not write_flag:
                # The end marker was already emitted above, so it is consumed
                # here without being written again.
                if f'// end {current_func_c}' == best_line:
                    write_flag = True
                    continue
            # Everything from main() onwards is not copied.
            if best_line == 'int main(int argc, char **argv) {':
                break
            if write_flag:
                tmpf.write(best_line + '\n')
    print("successfully writing tmp_best_code !")
    env = Environment(loader=FileSystemLoader(output_dir))
    source_template = env.get_template('tmp_best_code')
    # Map each placeholder name to the first answer snippet defining that function.
    new_code = {}
    for func_c in rewrite_funcs:
        a_code = ''
        for answer_code in answered_function_code:
            if f'Solver::{func_c}(' in answer_code:
                a_code = answer_code
                break
        new_code[f'replace_{func_c}_code'] = a_code
    output_code = source_template.render(new_code)

    # Embed the rendered function section into the project's source template.
    env = Environment(loader=FileSystemLoader(project_dir))
    source_template = env.get_template(args.source_file)
    output = source_template.render(replace_key_code = output_code)
    source_file = os.path.join(output_dir,args.source_file)
    with open(source_file,'w') as tf:   # e.g. ./results/module_function-76/0-0/EasySAT.cpp
        tf.write(output)
    

    # Regenerate the header from the current best declarations.
    env = Environment(loader=FileSystemLoader(project_dir))
    hpp_template = env.get_template(args.hpp_file)
    header_output = ''
    with open(os.path.join(result_dir,'best_hpp_file.txt')) as hf:
        header_output = hf.read()
    output = hpp_template.render(new_func_definition=header_output)
    hpp_file = os.path.join(output_dir,'EasySAT.hpp')
    with open(hpp_file,'w') as tf:
        tf.write(output)

    return output_code, source_file


def rewrite_source_file_newfunction(added_func,added_func_code,position_func,position_code,project_dir,result_dir,output_dir,best_code,args):
    """Insert a new solver function and replace the function that calls it.

    Steps, as performed by the code below:

    1. ``best_code`` is copied line by line into the scratch file
       ``<output_dir>/tmp_best_code``. At the line exactly equal to
       ``// start <position_func>`` a marked section for ``added_func`` with
       the placeholder ``{{ replace_new_func_code }}`` is inserted in front,
       and the original body of ``position_func`` (up to its ``// end``
       marker) is replaced by ``{{ replace_revoke_func_code }}``. Copying
       stops at a line exactly equal to ``int main(int argc, char **argv) {``.
    2. The scratch file is rendered with ``added_func_code`` and
       ``position_code`` filling the two placeholders.
    3. The template ``args.source_file`` from ``project_dir`` is rendered with
       ``replace_key_code`` set to the result of step 2 and written to
       ``<output_dir>/<args.source_file>``.
    4. A declaration is derived from the first line of ``added_func_code``
       containing ``Solver::`` (that prefix and any ``{`` removed, ``;``
       appended) and added to the content of
       ``<result_dir>/best_hpp_file.txt``; the template ``args.hpp_file`` is
       rendered with it and written to ``<output_dir>/EasySAT.hpp``.

    Returns:
        ``(output_code, source_file, header_output)``: the rendered function
        section, the path of the written source file, and the header
        declarations including the new one.
    """
    best_code_lines = best_code.split('\n')
    
    with open(os.path.join(output_dir,'tmp_best_code'),'w') as tmpf:
        # check_flag is True while inside the body of position_func,
        # whose original lines are dropped.
        check_flag = False
        current_func_c = ''
        for best_line in best_code_lines:
            if f'// start {position_func}' == best_line:
                # New function goes directly in front of the calling function.
                tmpf.write(f'// start {added_func}\n')
                tmpf.write('{{ replace_new_func_code }}\n')
                tmpf.write(f'// end {added_func}\n\n')
                tmpf.write(best_line + '\n')
                tmpf.write('{{ replace_revoke_func_code }}\n')
                check_flag = True
                continue

            
            if f'// end {position_func}' == best_line:
                tmpf.write(best_line + '\n')
                check_flag = False
                continue             

            # Everything from main() onwards is not copied.
            if best_line == 'int main(int argc, char **argv) {':
                break
            if not check_flag:
                tmpf.write(best_line + '\n')


    env = Environment(loader=FileSystemLoader(output_dir))
    source_template = env.get_template('tmp_best_code')
    output_code = source_template.render(replace_new_func_code=added_func_code,replace_revoke_func_code=position_code)

    # Embed the rendered function section into the project's source template.
    env = Environment(loader=FileSystemLoader(project_dir))
    source_template = env.get_template(args.source_file)
    output = source_template.render(replace_key_code = output_code)
    source_file = os.path.join(output_dir,args.source_file)
    with open(source_file,'w') as tf:
        tf.write(output)
    

    # Turn the definition's signature line into a header declaration.
    lines = added_func_code.split('\n')
    def_code = ''
    for line in lines:
        if 'Solver::' in line:
            def_code = line.replace('Solver::','')
            def_code = def_code.replace('{','')
            def_code += ';\n'
            break



    env = Environment(loader=FileSystemLoader(project_dir))
    hpp_template = env.get_template(args.hpp_file)
    header_output = ''
    with open(os.path.join(result_dir,'best_hpp_file.txt')) as hf:
        header_output = hf.read()
    header_output += def_code
    output = hpp_template.render(new_func_definition=header_output)
    hpp_file = os.path.join(output_dir,'EasySAT.hpp')
    with open(hpp_file,'w') as tf:
        tf.write(output)
    return output_code, source_file, header_output



def par2(output_file,args):
    """Compute the PAR-2 score of a CSV result file.

    The CSV must provide ``result`` and ``time`` columns. A row whose result
    is ``'TIMEOUT'`` is charged ``2 * args.sat_timeout``; every other row is
    charged its ``time`` converted to float. The score is the mean charge
    over all rows.

    Raises:
        ZeroDivisionError: if the file contains no data rows.
    """
    table = pd.read_csv(output_file)
    times = table.time

    penalised_total = 0
    for row, elapsed in enumerate(times):
        if table['result'][row] == 'TIMEOUT':
            penalised_total += (args.sat_timeout * 2)
        else:
            penalised_total += float(elapsed)
    return penalised_total / len(times)

def par2_array(results,args):
    """Compute the PAR-2 score of in-memory results.

    Each entry of ``results`` is indexed as ``entry[1]`` (outcome) and
    ``entry[2]`` (time). ``'TIMEOUT'`` entries are charged
    ``2 * args.sat_timeout``. Any other entry is charged its time converted
    to float; if the time is missing or not convertible, ``ERROR !!!!!!!`` is
    printed and the entry is charged the timeout penalty instead.

    Only conversion and lookup errors are handled; the original bare
    ``except`` also swallowed ``KeyboardInterrupt`` and ``SystemExit``.

    Returns:
        ``(score, solver_num)`` where ``score`` is the mean charge over all
        entries and ``solver_num`` counts the non-``TIMEOUT`` entries
        (including those whose time could not be parsed). Both values are
        also printed.

    Raises:
        ZeroDivisionError: if ``results`` is empty.
    """
    r = 0
    solver_num = 0

    for entry in results:
        if entry[1] == 'TIMEOUT':
            r += (args.sat_timeout * 2)
        else:
            try:
                r += float(entry[2])
            except (TypeError, ValueError, LookupError):
                print("ERROR !!!!!!!")
                r += (args.sat_timeout * 2)
            solver_num += 1

    score = r / len(results)
    print("par2_array:", score, solver_num )
    return score, solver_num
# if __name__ == "__main__":
#     a = {'1': ['950\n', 'else if (conflicts % 1000 == 0 && fast_lbd_sum / lbd_queue_size > 5) restart();']}
#     value = list(a.values())[0][1]
#     print(value)