import argparse
import json
from sys import stdout
import ray
import random
import numpy as np
from autosat.utils import *
from autosat.llm_api.base_api import GPTCallAPI, LocalCallAPI
from autosat.execution.execution_worker import ExecutionWorkerModSAT
# NOTE(review): this only binds a module-level name that nothing in this file
# reads. If the intent was to turn off Ray's log deduplication, that is
# presumably controlled by the RAY_DEDUP_LOGS *environment variable* -- confirm.
RAY_DEDUP_LOGS=0


# Names of the solver functions the search is allowed to rewrite. The order
# matters: the "greedy" algorithm picks entries by index.
function_candidates = ['rephase_condition', 'rephase_function','reduce_condition', 'restart_condition', 'restart_function', 'varBumpActivity', 'claBumpActivity'] 
# Subset rendered into Solver_template.h (written out as Solver.h).
function_candidates_h = ['varBumpActivity', 'claBumpActivity'] 
# Subset rendered into Solver_template.cc (written out as Solver.cc).
function_candidates_c = ['rephase_condition', 'rephase_function', 'reduce_condition', 'restart_function', 'restart_condition'] 


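# Per-benchmark presets, kept for reference. To use one, copy its three lists
# into function_candidates / function_candidates_c / function_candidates_h
# above (note that the presets below spell the names without the plural "s").

# cryptography-ascon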
# function_candidate = ['rephase_condition', 'rephase_function', 'reduce_condition', 'varBumpActivity']
# function_candidate_c = ['rephase_condition', 'rephase_function', 'reduce_condition']
# function_candidate_h = ['varBumpActivity']

# register-allocation
# function_candidate = ['rephase_function', 'reduce_condition', 'restart_function', 'varBumpActivity']
# function_candidate_c = ['rephase_function', 'reduce_condition', 'restart_function'] 
# function_candidate_h = ['varBumpActivity']

# social-golfer
# function_candidate = ['rephase_condition', 'restart_condition', 'restart_function', 'varBumpActivity']
# function_candidate_c = ['rephase_condition', 'restart_condition', 'restart_function', ]
# function_candidate_h = ['varBumpActivity']

# hashtable-safety
# function_candidate = ['rephase_function', 'restart_condition', 'restart_function', 'claBumpActivity']
# function_candidate_c = ['rephase_function', 'restart_condition', 'restart_function']
# function_candidate_h = ['claBumpActivity']

# argumentation 2023
# function_candidate = ['rephase_condition', 'rephase_function', 'reduce_condition', 'varBumpActivity']
# function_candidate_c = ['rephase_condition', 'rephase_function', 'reduce_condition']
# function_candidate_h = ['varBumpActivity']

# argumentation 2024
# function_candidate = ['rephase_condition', 'rephase_function', 'reduce_condition', 'restart_function']
# function_candidate_c = ['rephase_condition', 'rephase_function', 'reduce_condition', 'restart_function']
# function_candidate_h = []

# hamiltonian
# function_candidate = ['reduce_condition', 'restart_condition', 'restart_function', 'varBumpActivity']
# function_candidate_c = ['reduce_condition', 'restart_condition', 'restart_function']
# function_candidate_h = ['varBumpActivity']

# MineSweeper
# function_candidate = ['rephase_function', 'restart_condition', 'reduce_condition', 'claBumpActivity']
# function_candidate_c = ['rephase_function', 'restart_condition', 'reduce_condition']
# function_candidate_h = ['claBumpActivity']

# KnightTour
# function_candidate = ['rephase_condition', 'reduce_condition', 'restart_condition', 'claBumpActivity']
# function_candidate_c = ['rephase_condition', 'reduce_condition', 'restart_condition']
# function_candidate_h = ['claBumpActivity']

# Zamkeller
# function_candidate = ['rephase_condition', 'reduce_condition', 'restart_condition', 'varBumpActivity']
# function_candidate_c = ['rephase_condition', 'reduce_condition', 'restart_condition']
# function_candidate_h = ['varBumpActivity']

# EDA
# function_candidate = ['rephase_function', 'restart_function', 'varBumpActivity', 'claBumpActivity']
# function_candidate_c = ['rephase_function', 'restart_function']
# function_candidate_h = ['varBumpActivity', 'claBumpActivity']


@ray.remote
def synchronized_asked(prompt_file, args, func_names, coder_temperature):
    """Ask the coder LLM to rewrite functions and pull the code out of its answer.

    Args:
        prompt_file: Path of the prompt file handed to the LLM API.
        args: Parsed command-line namespace; ``api_base``, ``api_key`` and
            ``llm_model`` are read from it.
        func_names: Iterable of function names to extract from the answer.
        coder_temperature: Sampling temperature passed to the LLM call.

    Returns:
        A ``(answer, function_code)`` tuple. ``function_code`` maps each name
        in ``func_names`` to whatever ``get_code`` extracts between the
        ``// start <name>`` and ``// end <name>`` markers of the answer.
    """
    llm_api = GPTCallAPI(api_base=args.api_base,
                         api_key=args.api_key,
                         model_name=args.llm_model,
                         stream=False)

    # The prompt is passed to the API by path. The original additionally
    # opened and read the file here without using the content or closing the
    # handle; that leaked one file descriptor per call and has been dropped.
    answer = llm_api.call_api(
        prompt_file=prompt_file,
        temperature=coder_temperature,
    )

    function_code = {
        func_name: get_code(
            answer,
            seperator=[f'// start {func_name}', f'// end {func_name}'],
        )
        for func_name in func_names
    }

    return answer, function_code


@ray.remote
def synchronized_asked_evaluator(prompt_file, args, func_names):
    """Ask the evaluator LLM whether the rewritten code is a real improvement.

    Args:
        prompt_file: Path of the evaluator prompt file handed to the LLM API.
        args: Parsed command-line namespace; ``api_base``, ``api_key`` and
            ``llm_model`` are read from it.
        func_names: Unused; kept so the signature matches the other
            ``synchronized_asked_*`` tasks.

    Returns:
        A ``(answer, difference)`` tuple. ``difference`` is a one-element list
        holding what ``get_code`` extracts between the line
        ``whether it has substantially improved:`` and the next newline.
    """
    llm_api = GPTCallAPI(
        api_base=args.api_base,
        api_key=args.api_key,
        model_name=args.llm_model,
        stream=False
    )

    # The prompt is passed to the API by path. The original additionally
    # opened and read the file here without using the content or closing the
    # handle; that leaked one file descriptor per call and has been dropped.
    answer = llm_api.call_api(prompt_file=prompt_file, temperature=0.2)

    difference = [
        get_code(answer, seperator=['whether it has substantially improved:\n', '\n'])
    ]

    return answer, difference

@ray.remote
def synchronized_asked_advisor(prompt_file, args, func_names):
    """Query the advisor LLM and return the advice parsed from its answer.

    Args:
        prompt_file: Path of the advisor prompt file handed to the LLM API.
        args: Parsed command-line namespace; ``api_base``, ``api_key`` and
            ``llm_model`` are read from it.
        func_names: Function names forwarded to ``get_advice``.

    Returns:
        Whatever ``get_advice(func_names, answer)`` produces for the answer.
    """
    client = GPTCallAPI(api_base=args.api_base,
                        api_key=args.api_key,
                        model_name=args.llm_model,
                        stream=False)

    # Fixed low temperature, same as the evaluator and repairer tasks.
    response = client.call_api(temperature=0.2, prompt_file=prompt_file)

    return get_advice(func_names, response)


@ray.remote
def synchronized_asked_repairer(prompt_file, args, func_names):
    """Ask the repairer LLM for fixed code and pull it out of the answer.

    Args:
        prompt_file: Path of the repairer prompt file handed to the LLM API.
        args: Parsed command-line namespace; ``api_base``, ``api_key`` and
            ``llm_model`` are read from it.
        func_names: Iterable of function names to extract from the answer.

    Returns:
        A ``(answer, function_code)`` tuple. ``function_code`` maps each name
        in ``func_names`` to whatever ``get_code`` extracts between the
        ``// start <name>`` and ``// end <name>`` markers of the answer.
    """
    llm_api = GPTCallAPI(api_base=args.api_base,
                         api_key=args.api_key,
                         model_name=args.llm_model,
                         stream=False)

    # The prompt is passed to the API by path. The original also opened the
    # file here and never read or closed it (a leaked descriptor per call).
    answer = llm_api.call_api(prompt_file=prompt_file, temperature=0.2)

    function_code = {
        func_name: get_code(
            answer,
            seperator=[f'// start {func_name}', f'// end {func_name}'],
        )
        for func_name in func_names
    }

    return answer, function_code


def asynchronous_executed(args, output_dir, instances):
    """Run the solver on every instance, keeping a bounded set of Ray tasks in flight.

    Up to ``min(args.num_cpus, len(instances))`` ``synchronized_executed``
    tasks are started at once; each finished task is replaced by the next
    pending instance until none are left.

    Args:
        args: Parsed command-line namespace; ``num_cpus`` is read here and the
            namespace is forwarded to every task.
        output_dir: Directory forwarded to every task.
        instances: Sized iterable of instance file names.

    Returns:
        List of the task results in completion order (not input order).
    """
    exhausted = object()  # sentinel: no instance left to schedule
    pending = iter(instances)
    in_flight = set()
    finished = []

    # Prime the pool.
    for _ in range(min(args.num_cpus, len(instances))):
        ins = next(pending, exhausted)
        if ins is exhausted:
            break
        in_flight.add(synchronized_executed.remote(args, output_dir, ins))

    # Drain: wait for one task, collect it, refill from the pending instances.
    while in_flight:
        done_tasks, still_running = ray.wait(list(in_flight), num_returns=1)
        in_flight = set(still_running)

        finished.extend(ray.get(done_tasks))

        for _ in done_tasks:
            ins = next(pending, exhausted)
            if ins is exhausted:
                break
            in_flight.add(synchronized_executed.remote(args, output_dir, ins))

    return finished
 

@ray.remote
def synchronized_executed(arguments, output_dir, instance, *args, **kwargs):
    """Run the compiled solver on one instance and report its outcome.

    Args:
        arguments: Parsed command-line namespace; ``executable_file``,
            ``data_dir`` and ``sat_timeout`` are read from it.
        output_dir: Directory holding the compiled solver; the solver output
            is written to ``<output_dir>/result/<instance>``.
        instance: File name of the instance inside ``arguments.data_dir``.
        *args, **kwargs: Accepted and ignored.

    Returns:
        ``[instance, status, time]`` as parsed from the solver log by
        ``extract_log_info``. A status of ``"INDETERMINATE"`` is reported as
        ``"TIMEOUT"`` with ``arguments.sat_timeout`` as its time.
    """
    execution_worker = ExecutionWorkerModSAT()
    executable_file = os.path.join(output_dir, arguments.executable_file)
    output_file = os.path.join(output_dir, 'result', instance)

    # NOTE(review): the flag returned by execute() is not consulted; the
    # outcome is taken from the log file alone, as in the original code.
    # A previously disabled model-verification step (running an external
    # ./check-sat on SATISFIABLE answers) has been removed from this body.
    execution_worker.execute(
        output_dir=output_dir,
        executable_file_path=executable_file,
        instance_file=os.path.join(arguments.data_dir, instance),
        output_file=output_file,
        timeout=arguments.sat_timeout
    )

    status, elapsed = extract_log_info(output_file)
    if status == "INDETERMINATE":
        # The solver gave no verdict within the limit: charge the full timeout.
        status = "TIMEOUT"
        elapsed = arguments.sat_timeout

    return [instance, status, elapsed]


def _render_solver_sources(env, project_dir, output_dir, functions):
    """Copy the modsat tree into output_dir and render Solver.h / Solver.cc into it."""
    shutil.copytree(os.path.join(project_dir, 'modsat'), os.path.join(output_dir, 'modsat'))
    core_dir = os.path.join(output_dir, 'modsat', 'modsat', 'core')
    targets = (
        ("./Solver_template.h", 'Solver.h', function_candidates_h),
        ("./Solver_template.cc", 'Solver.cc', function_candidates_c),
    )
    for template_name, target_name, func_names in targets:
        template = env.get_template(template_name)
        context = {func: functions[func] for func in func_names}
        rendered = template.render(context)
        with open(os.path.join(core_dir, target_name), 'w') as f:
            f.write(rendered)


def _benchmark_candidate(args, output_dir, instances, label):
    """Run the compiled solver on all instances; return par2_array's (score, solved_num)."""
    os.mkdir(os.path.join(output_dir, 'result'))
    this_result = asynchronous_executed(args, output_dir, instances)
    if len(this_result) != len(instances):
        print('something went wrong in {}'.format(label))
    return par2_array(this_result, args)


def _record_new_best(result_dir, best, label, solved_num):
    """Print a new best score and append it to best_log.txt."""
    print(f'best = {best} in {label}, solved_num = {solved_num}')
    stdout.flush()
    with open(os.path.join(result_dir, 'best_log.txt'), 'a') as af:
        af.write(f'best = {best} in {label}, solved_num = {solved_num} \n')


def evolution_search(args):
    """Run the LLM-guided search over the solver's heuristic functions.

    The baseline functions are rendered into the solver, compiled and
    benchmarked. Then, for ``args.max_iterations`` iterations, a subset of
    ``function_candidates`` is chosen, a coder LLM rewrites it, an evaluator
    LLM classifies the rewrite, and rewrites not classified as
    "No Improvement" are compiled and benchmarked. A rewrite is adopted when
    its score is below 99.5% of the best score so far. Each iteration allows
    up to four attempts (restart flags 0..3).

    Files written under the result directory returned by ``precheck``:
    ``best_log.txt``, ``log.txt``, ``answer_log.txt``, ``result.txt``,
    ``best_prompt_code.txt``, ``best_functions.json`` and one working
    directory per attempt.

    Args:
        args: Parsed command-line namespace. Read here: ``algorithm``,
            ``project_dir``, ``project``, ``load_original``, ``data_dir``,
            ``num_cpus``, ``coder_temperature``, ``rewrite_func_prompt``,
            ``evaluator_prompt`` and ``max_iterations``; the namespace is also
            forwarded to the Ray tasks and helpers.

    Raises:
        ValueError: If ``args.algorithm`` is neither ``"ea"`` nor ``"greedy"``.
    """
    # Fail fast; previously an unknown algorithm only surfaced as a NameError
    # after the baseline had already been compiled and benchmarked.
    if args.algorithm not in ("ea", "greedy"):
        raise ValueError(
            f"unsupported algorithm {args.algorithm!r}; expected 'ea' or 'greedy'")

    results = {}
    best = 10000000
    # Initialised up front so result.txt can always be written, even when the
    # baseline build fails.
    best_solved_num = 0
    no_improvement_num = 0
    substantial_improvement_num = 0
    parameter_tuning_num = 0

    result_dir = precheck(args)
    project_dir = os.path.join(args.project_dir, args.project)  # ./examples/EasySAT/module_function
    # NOTE(review): despite its name, load_original=True starts from
    # best_function.json and the default starts from original_function.json.
    # Kept as is -- confirm the intended meaning of the flag.
    if args.load_original:
        start_file = "best_function.json"
    else:
        start_file = "original_function.json"
    with open(os.path.join(project_dir, start_file)) as f:
        best_functions = json.load(f)

    instances = [file for file in os.listdir(args.data_dir) if '.cnf' in file]
    ray.init(num_cpus=args.num_cpus)
    env = Environment(loader=FileSystemLoader(os.path.join(project_dir)))

    # Baseline: build and benchmark the starting functions in original/.
    output_dir = os.path.join(result_dir, 'original')
    os.makedirs(output_dir)
    _render_solver_sources(env, project_dir, output_dir, best_functions)

    execution_worker = ExecutionWorkerModSAT()
    success = execution_worker.compile(output_dir=output_dir)

    if success:
        results['0'], solved_num = _benchmark_candidate(args, output_dir, instances, 0)

        best = results['0']
        best_solved_num = solved_num
        print(f'best = {best} in {0}, solved_num = {solved_num}')
        stdout.flush()

    # Header of the best-score log. The original guarded this with
    # `iteration == 0` inside the loop, which could never be true because the
    # counter is incremented first, so the header was never written.
    with open(os.path.join(result_dir, 'best_log.txt'), 'w') as af:
        af.write(f'Dataset: {args.data_dir} \n')

    iteration = 0
    num_function_candidate = len(function_candidates)

    while True:
        iteration += 1
        find_better = False

        if args.algorithm == "ea":
            # Draw how many functions to rewrite; redraw until at least one.
            flip_size = np.random.binomial(num_function_candidate, 1.0/num_function_candidate)
            while flip_size == 0:
                flip_size = np.random.binomial(num_function_candidate, 1.0/num_function_candidate)
            func_candi_indexs = random.sample(list(range(num_function_candidate)), flip_size)
            func_candis = [function_candidates[index] for index in func_candi_indexs]
        else:  # "greedy": cycle through the candidates one at a time
            func_candis = [function_candidates[iteration % num_function_candidate]]

        restart_flag = 0
        previous_no_improvement = False
        print("#######################")
        print("iteration: ", iteration)

        # log.txt only ever holds the current iteration.
        with open(os.path.join(result_dir, 'log.txt'), 'w') as f:
            f.write("")

        # At most four attempts per iteration (restart_flag 0..3).
        while restart_flag <= 3:
            # output dir:  ./results/module_function-num/1-0
            # result dir:  ./results/module_function-num
            label = '{}-{}'.format(iteration, restart_flag)
            output_dir = os.path.join(result_dir, label)
            if os.path.exists(output_dir):
                shutil.rmtree(output_dir, ignore_errors=True)
            # exist_ok covers a tree that rmtree could not fully remove,
            # without hiding unrelated errors behind a bare except.
            os.makedirs(output_dir, exist_ok=True)

            print("function_candis:", func_candis)

            with open(os.path.join(result_dir, 'best_prompt_code.txt')) as f:
                best_code = f.read()

            # Render the prompt code from the current best functions.
            prompt_code_template = env.get_template("./original_prompt_code.txt")
            render_prompt_context = {
                func: best_functions[func] for func in function_candidates
            }
            prompt_code = prompt_code_template.render(render_prompt_context)

            retry_note = f'\n Note that you already provided me codes that could not be compiled {restart_flag} time(s). \nPlease check if you provided me codes requiring additional namespace or packages.'

            # Both alternative paths are switched off by hard-coded flags.
            used_multi_coder = False
            used_repairer = False

            coder_temperature = args.coder_temperature
            if previous_no_improvement:
                # Be more exploratory after the evaluator rejected an attempt.
                coder_temperature += 0.3
            print("current coder temperature:: ", coder_temperature)

            ##### Multi-Coder (disabled): one LLM query per function #####
            if used_multi_coder:
                tasks = []
                for func_candi_name in func_candis:
                    prompt_template = env.get_template(args.rewrite_func_prompt)
                    promp_output = prompt_template.render(replace_key_code=best_code, func_name=func_candi_name)
                    prompt_file = os.path.join(output_dir, 'prompt_{}'.format(func_candi_name))

                    with open(prompt_file, 'w') as f:
                        f.write(promp_output)
                        if restart_flag > 0:
                            f.write(retry_note)

                    # func_names must be a list; a bare string would be
                    # iterated character by character by the task.
                    tasks.append(synchronized_asked.remote(
                        prompt_file, args, [func_candi_name], coder_temperature))

                try:
                    task_results = ray.get(tasks, timeout=120)
                except Exception:
                    restart_flag += 1
                    print("LLM QUERY ERROR!!")
                    continue

                # Merge the per-function answers into the same shapes the
                # single-coder path produces.
                answer = [task_result[0] for task_result in task_results]
                answered_function_code = {}
                for task_result in task_results:
                    answered_function_code.update(task_result[1])
                print("successfully asked! ")

            ##### Coder: one LLM query covering all selected functions #####
            else:
                # The original chose between two branches at random, but both
                # loaded this same template.
                prompt_template = env.get_template(os.path.join("prompt_template", 'prompt_rewrite_function_2.txt'))

                if len(func_candis) == 1:
                    func_candi_name = func_candis[0]
                else:
                    # e.g. "a,b,and c"
                    func_candi_name = ','.join(func_candis[:-1]) + ',and ' + func_candis[-1]

                promp_output = prompt_template.render(
                    replace_key_code=prompt_code,
                    func_name=func_candi_name)

                prompt_file = os.path.join(output_dir, 'prompt_{}'.format(func_candi_name))

                with open(prompt_file, 'w') as f:
                    f.write(promp_output)
                    if restart_flag > 0:
                        f.write(retry_note)

                task = synchronized_asked.remote(prompt_file, args, func_candis, coder_temperature)
                answer, answered_function_code = ray.get(task)

                with open(os.path.join(result_dir, 'log.txt'), 'a') as f:
                    f.write(f'answer: \n {answer} \n \
                            answer code: {answered_function_code} \n')

            ##### Evaluator #####
            original_code = {func: best_functions[func] for func in func_candis}
            new_code = answered_function_code

            with open(os.path.join(result_dir, 'log.txt'), 'a') as f:
                f.write(f'new_code: \n {new_code} \n \
                       original_code: {original_code} \n')

            print("successfully get_verified_function! ")
            evaluator_prompt_template = env.get_template(os.path.join("prompt_template", args.evaluator_prompt))
            evaluator_prompt_output = evaluator_prompt_template.render(func_name=func_candis, new_code=new_code, original_code=original_code)
            evaluator_prompt_file = os.path.join(output_dir, 'evaluator_prompt')
            with open(evaluator_prompt_file, 'w') as f:
                f.write(evaluator_prompt_output)
            task = synchronized_asked_evaluator.remote(evaluator_prompt_file, args, func_candis)

            answer_evaluator, difference = ray.get(task, timeout=120)

            print("evaluator difference: ", difference)

            with open(os.path.join(result_dir, 'answer_log.txt'), 'a') as af:
                af.write(f"################################## \n")
                af.write(f"iteration: {iteration} \n")
                af.write(f"restart flag: {restart_flag} \n")
                af.write(f"modified function: {func_candis} \n")
                af.write(f"total answer: {answer} \n")
                af.write(f"1111111111111111111111111111111111111111111111111111111111 \n")
                for func_name, func in answered_function_code.items():
                    af.write(f"answered_function_code: {func} \n")
                af.write(f'evaluator answer = {difference} \n')
                af.write(f'22222222222222222222222222222222222222222222222222222222222\n')

            verdict = difference[0]
            if "Substantial Improvement" in verdict:
                substantial_improvement_num += 1
            elif "Parameter Tuning" in verdict:
                parameter_tuning_num += 1
            elif "No Improvement" in verdict:
                # Rejected without compiling: retry with a hotter coder.
                no_improvement_num += 1
                previous_no_improvement = True
                restart_flag += 1
                continue
            print("restart_flag: ", restart_flag)

            ##### Build the candidate: rewritten functions on top of the best #####
            candidate_functions = dict(best_functions)
            for func in func_candis:
                candidate_functions[func] = answered_function_code[func]
            _render_solver_sources(env, project_dir, output_dir, candidate_functions)

            execution_worker = ExecutionWorkerModSAT()
            success = execution_worker.compile(output_dir=output_dir)
            print("success: ", success)

            if success:
                print("successfully compiled! ")
                results[label], solved_num = _benchmark_candidate(args, output_dir, instances, label)

                # Require at least a 0.5% gain over the best score so far.
                if results[label] < 0.995*best:
                    best = results[label]
                    # Keep the solved count in step with the best score so
                    # result.txt does not report the baseline's count.
                    best_solved_num = solved_num
                    find_better = True
                    # NOTE(review): this empties best_prompt_code.txt at the
                    # end of the iteration; kept from the original -- confirm
                    # whether the re-rendered prompt code was intended.
                    current_best_code = ""

                    for func in func_candis:
                        best_functions[func] = answered_function_code[func]

                    _record_new_best(result_dir, best, label, solved_num)

                break

            print("Executing Error!!")

            ##### Repairer (disabled) #####
            # NOTE(review): this path is incomplete as inherited: `bug_info`
            # is never assigned, and the repaired `code` is neither rendered
            # into the sources before recompiling nor stored on success.
            if used_repairer:
                repairer_prompt_template = env.get_template(os.path.join("prompt_template", "repairer_template.txt"))
                repairer_prompt_output = repairer_prompt_template.render(func_name=func_candis, new_code=new_code, bug=bug_info)
                repairer_prompt_file = os.path.join(output_dir, 'repairer_prompt')
                with open(repairer_prompt_file, 'w') as f:
                    f.write(repairer_prompt_output)
                task = synchronized_asked_repairer.remote(repairer_prompt_file, args, func_candis)
                answer_repairer, code = ray.get(task, timeout=120)
                execution_worker = ExecutionWorkerModSAT()
                success = execution_worker.compile(output_dir=output_dir)
                if success:
                    results[label], solved_num = _benchmark_candidate(args, output_dir, instances, label)

                    if results[label] < 0.995*best:
                        best = results[label]
                        best_solved_num = solved_num
                        find_better = True
                        current_best_code = ""

                        for func in func_candis:
                            best_functions[func] = answered_function_code[func]

                        _record_new_best(result_dir, best, label, solved_num)

            restart_flag += 1

        if find_better:
            with open(os.path.join(result_dir, 'best_prompt_code.txt'), 'w') as f:
                f.write(current_best_code)
            with open(os.path.join(result_dir, 'result.txt'), 'w') as f:
                f.write(str([best, best_solved_num]))

            with open(os.path.join(result_dir, 'best_functions.json'), 'w') as f:
                json.dump(best_functions, f, ensure_ascii=False, indent=4)

        # `>=` so a max_iterations of zero or less still terminates.
        if iteration >= args.max_iterations:
            with open(os.path.join(result_dir, 'result.txt'), 'a') as f:
                f.write(f'Substantial Improvement num {substantial_improvement_num} \n \
                        Parameter Tuning num {parameter_tuning_num} \n \
                        No Improvement num {no_improvement_num}')
            break


if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` turns every non-empty string into True, so
        ``--load_original False`` used to enable the flag.
        """
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()

    # Search settings.
    parser.add_argument('--max_iterations', type=int, default=60)
    parser.add_argument('--algorithm', type=str, default="ea")
    parser.add_argument('--solver_name', type=str, default="modsat")
    parser.add_argument('--num_cpus', type=int, default=100)
    parser.add_argument('--llm_model',
                        type=str,
                        default="deepseek-chat",
                        choices=["gpt-4o-2024-08-06", "gpt-4-1106-preview", "gpt-3.5-turbo", "deepseek-reasoner", "o1-2024-12-17", "deepseek-ai/DeepSeek-R1", "deepseek-chat"])
    parser.add_argument('--sat_timeout', type=int, default=15)

    # Input and output locations.
    parser.add_argument('--data_dir', type=str, default="./data/Zamkeller_train")
    parser.add_argument('--project', type=str, default="module_function")
    parser.add_argument('--source_file', type=str, default="./modsat/modsat/core/Solver.cc")
    parser.add_argument('--executable_file', type=str, default="modsat/bin/modsat")
    parser.add_argument('--project_dir', type=str, default="./examples/ModSAT")
    parser.add_argument('--result_dir', type=str, default="./results/")
    parser.add_argument('--load_original', type=_str2bool, default=False)

    # Prompt templates and LLM sampling.
    parser.add_argument('--advisor_prompt', type=str, default='advisor_template.txt')
    parser.add_argument('--rewrite_func_prompt', type=str, default='prompt_rewrite_function.txt')
    parser.add_argument('--coder_prompt', type=str, default='coder_template.txt')
    parser.add_argument('--evaluator_prompt', type=str, default='evaluator_template.txt')
    parser.add_argument('--coder_temperature', type=float, default=0.2)

    # LLM endpoint credentials.
    parser.add_argument('--api_base', type=str, default='')
    parser.add_argument('--api_key', type=str, default='')
    args = parser.parse_args()

    evolution_search(args)


