from ast import arg
from hmac import new
import os,shutil
import argparse
import json
from sys import executable, stdout

import ray
import random
import copy
import sys

from jinja2 import FileSystemLoader, Environment
from autosat.utils import *
from autosat.llm_api.base_api import GPTCallAPI
from autosat.execution.execution_worker import ExecutionWorkerModSAT
from autosat.evaluation.evaluate import evaluate
# NOTE(review): this only binds a module-level Python name that nothing in this
# file reads. Presumably the intent was to turn off Ray's log de-duplication,
# which is configured through the RAY_DEDUP_LOGS *environment variable* --
# confirm, and if so set os.environ["RAY_DEDUP_LOGS"] = "0" before `import ray`.
RAY_DEDUP_LOGS=0

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score 

from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F



def generate_code_embeddings(code_samples):
    """Embed every code sample with the CodeT5+ embedding checkpoint.

    Each sample is first normalised with ``preprocess_code``, tokenised
    (truncated to 512 tokens) and run through the model without gradient
    tracking. The token axis of ``last_hidden_state`` is mean-pooled to get
    one vector per sample.

    Args:
        code_samples: iterable of source-code strings.

    Returns:
        ``np.ndarray`` stacking one embedding per sample, in input order.
    """
    checkpoint = "Salesforce/codet5p-110m-embedding"
    # The tokenizer and model are (re)loaded on every call.
    # NOTE(review): the model card for this checkpoint appears to load it with
    # trust_remote_code=True and to return pooled embeddings directly --
    # confirm that `last_hidden_state` is available on the model output.
    code_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)

    def embed_one(sample):
        normalised = preprocess_code(sample)
        batch = code_tokenizer(
            normalised,
            return_tensors='pt',
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            result = encoder(**batch)
        # Average over the token dimension, then drop the batch dimension.
        pooled = result.last_hidden_state.mean(dim=1).squeeze()
        return pooled.numpy()

    return np.array([embed_one(sample) for sample in code_samples])

def preprocess_code(cpp_code):
    """Normalise the first C++ function definition found in *cpp_code*.

    The function name is converted from snake_case to lowerCamelCase, the
    parameter list is flattened onto a single line, and every body line is
    re-indented with two spaces. Anything between the closing parenthesis and
    the opening brace (e.g. ``const``) is dropped from the signature.

    Fixes relative to the previous version:
      * the signature regex had an unclosed capture group, so ``re.search``
        raised ``re.error`` on every call; the return-type group is now closed
        after the optional pointer stars;
      * ``re`` is imported locally instead of relying on a star import;
      * the multi-line parameter layout was removed: the signature's
        whitespace is collapsed right afterwards, so that layout only produced
        stray spaces inside the parentheses.

    Args:
        cpp_code: source text expected to contain a function definition.

    Returns:
        The normalised ``"<signature> {\\n<body>\\n}"`` string, or *cpp_code*
        unchanged when no function signature is found.
    """
    import re  # local import: do not depend on `from autosat.utils import *`

    # groups: return type (with optional template args / pointer stars),
    #         function name, raw parameter text, text between ')' and '{'
    pattern = r'(\w+(?:<.*>)?\s+\**)\s*(\w+)\s*\(([\s\S]*?)\)\s*([^{]*)'
    match = re.search(pattern, cpp_code)
    if not match:
        return cpp_code

    return_type, func_name, params, _suffix = match.groups()
    body = cpp_code[match.end():].strip()

    # snake_case -> lowerCamelCase
    func_name = re.sub(r'_([a-z])', lambda m: m.group(1).upper(), func_name)
    func_name = func_name[0].lower() + func_name[1:]

    # Split on commas that are not inside template angle brackets.
    params = params.replace('\n', ' ').strip()
    param_list = [p.strip() for p in re.split(r',(?![^<>]*>)', params) if p.strip()]
    params_str = ', '.join(param_list)

    signature = f"{return_type.strip()} {func_name}({params_str})"
    signature = re.sub(r'\s+', ' ', signature)

    # Peel off the outer braces of the function body, if present.
    if body.startswith('{'):
        body = body[1:].strip()
    if body.endswith('}'):
        body = body[:-1].rstrip()

    indented_body = "\n".join(
        f"  {line.strip()}" if line.strip() else ""
        for line in body.split('\n')
    )

    return f"{signature} {{\n{indented_body}\n}}"


def cluster_and_calculate_entropy(embeddings, k_clusters):
    """Cluster *embeddings* with k-means and measure how evenly they spread.

    Args:
        embeddings: array-like of sample vectors accepted by ``KMeans.fit``.
        k_clusters: number of clusters to fit.

    Returns:
        Tuple ``(entropy, probabilities, cluster_counts, silhouette)`` where
        ``cluster_counts`` holds the number of samples per cluster,
        ``probabilities`` is that count divided by the number of samples,
        ``entropy`` is the base-2 Shannon entropy of ``probabilities`` and
        ``silhouette`` is the silhouette score of the fitted labelling.
    """
    clusterer = KMeans(
        n_clusters=k_clusters,
        init='k-means++',
        n_init=10,
        max_iter=300,
    )
    clusterer.fit(embeddings)
    assignments = clusterer.labels_
    silhouette = silhouette_score(embeddings, assignments)

    # minlength keeps a slot for clusters that received no sample.
    cluster_counts = np.bincount(assignments, minlength=k_clusters)
    probabilities = cluster_counts / len(embeddings)

    # Empty clusters are skipped (log2(0) is undefined).
    entropy = 0
    for share in probabilities[probabilities > 0]:
        entropy -= share * np.log2(share)

    return entropy, probabilities, cluster_counts, silhouette


# Names of every solver function whose code is substituted into the prompt.
function_candidates = [
    'rephase_condition',
    'rephase_function',
    'reduce_condition',
    'restart_condition',
    'restart_function',
    'varBumpActivity',
    'claBumpActivity',
]

# Subset rendered into the Solver.h template.
function_candidates_h = [
    'varBumpActivity',
    'claBumpActivity',
]

# Subset rendered into the Solver.cc template.
function_candidates_c = [
    'rephase_condition',
    'rephase_function',
    'reduce_condition',
    'restart_function',
    'restart_condition',
]


@ray.remote
def synchronized_asked(prompt_file, args, func_names, coder_temperature):
    """Ray task: send the coder prompt to the LLM and extract the answered code.

    The previous version also opened *prompt_file* and read it into a variable
    that was never used, leaking the file handle; that dead read is removed.
    The path itself is still handed to ``call_api``.

    Args:
        prompt_file: path of the prompt file passed to ``GPTCallAPI.call_api``.
        args: namespace providing ``api_base``, ``api_key`` and ``llm_model``.
        func_names: function names to extract from the answer.
        coder_temperature: sampling temperature forwarded to the API call.

    Returns:
        ``(answer, function_code)`` where ``function_code`` maps each name in
        *func_names* to whatever ``get_code`` extracts between the
        ``// start <name>`` and ``// end <name>`` markers of the answer.
    """
    llm_api = GPTCallAPI(
        api_base=args.api_base,
        api_key=args.api_key,
        model_name=args.llm_model,
        stream=False,
    )

    answer = llm_api.call_api(
        prompt_file=prompt_file,
        temperature=coder_temperature,
    )

    function_code = {
        func_name: get_code(
            answer,
            seperator=[f'// start {func_name}', f'// end {func_name}'],
        )
        for func_name in func_names
    }

    return answer, function_code


@ray.remote
def synchronized_asked_optimizer(prompt_file, args, coder_temperature):
    """Ray task: ask the LLM to rewrite one section of the coder prompt.

    The previous version also opened *prompt_file* and read it into a variable
    that was never used, leaking the file handle; that dead read is removed.
    The path itself is still handed to ``call_api``.

    Args:
        prompt_file: path of the prompt file passed to ``GPTCallAPI.call_api``.
        args: namespace providing ``api_base``, ``api_key`` and ``llm_model``.
        coder_temperature: sampling temperature forwarded to the API call.

    Returns:
        ``(answer, optimized_prompt_part)`` where the second item is whatever
        ``get_code`` extracts between the ``// start`` and ``// end`` markers.
    """
    llm_api = GPTCallAPI(
        api_base=args.api_base,
        api_key=args.api_key,
        model_name=args.llm_model,
        stream=False,
    )

    answer = llm_api.call_api(
        prompt_file=prompt_file,
        temperature=coder_temperature,
    )

    optimized_prompt_part = get_code(answer, seperator=['// start', '// end'])

    return answer, optimized_prompt_part


@ray.remote
def synchronized_asked_evaluator(prompt_file, args, func_names):
    """Ray task: ask the LLM whether the new code is a substantial improvement.

    The previous version also opened *prompt_file* and read it into a variable
    that was never used, leaking the file handle; that dead read is removed.
    The path itself is still handed to ``call_api``.

    Args:
        prompt_file: path of the prompt file passed to ``GPTCallAPI.call_api``.
        args: namespace providing ``api_base``, ``api_key`` and ``llm_model``.
        func_names: accepted for call-site symmetry; not used here.

    Returns:
        ``(answer, difference)`` where ``difference`` is a one-element list
        holding what ``get_code`` extracts between the
        ``whether it has substantially improved:`` line and the next newline.
    """
    llm_api = GPTCallAPI(
        api_base=args.api_base,
        api_key=args.api_key,
        model_name=args.llm_model,
        stream=False,
    )

    # The evaluator always runs at a fixed temperature of 0.2.
    answer = llm_api.call_api(prompt_file=prompt_file, temperature=0.2)

    difference = [
        get_code(
            answer,
            seperator=['whether it has substantially improved:\n', '\n'],
        )
    ]

    return answer, difference


def asynchronous_executed(args, output_dir, instances):
    """Run ``synchronized_executed`` on every instance with bounded parallelism.

    At most ``min(args.num_cpus, len(instances))`` Ray tasks are started up
    front; each time one finishes its result is collected and the next pending
    instance is submitted.

    Args:
        args: namespace forwarded to every task; ``num_cpus`` bounds the
            number of tasks started initially.
        output_dir: directory forwarded to every task.
        instances: sized iterable of instance names.

    Returns:
        List of task results in the order the tasks completed.
    """
    pending = iter(instances)
    in_flight = set()
    finished = []

    def submit_next():
        """Start a task for the next instance; return False when none is left."""
        try:
            instance = next(pending)
        except StopIteration:
            return False
        in_flight.add(synchronized_executed.remote(args, output_dir, instance))
        return True

    # Fill the initial window.
    for _ in range(min(args.num_cpus, len(instances))):
        if not submit_next():
            break

    # Drain: collect one finished task at a time and refill the freed slot.
    while in_flight:
        done, still_running = ray.wait(list(in_flight), num_returns=1)
        in_flight = set(still_running)
        finished.extend(ray.get(done))

        for _ in range(len(done)):
            if not submit_next():
                break

    return finished
 

@ray.remote
def synchronized_executed(arguments, output_dir, instance, *args, **kwargs):
    """Ray task: run the compiled solver on one instance and parse its log.

    The unused locals of the previous version (the source-file path and the
    ignored return value of ``execute``) are removed, and the local that
    shadowed the name ``time`` is renamed.

    Args:
        arguments: namespace providing ``executable_file``, ``data_dir`` and
            ``sat_timeout``.
        output_dir: build directory containing the executable; the log is
            written to ``<output_dir>/result/<instance>``.
        instance: instance file name, relative to ``arguments.data_dir``.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``[instance, status, time]`` as reported by ``extract_log_info``. An
        ``"INDETERMINATE"`` status is reported as ``"TIMEOUT"`` with the time
        set to ``arguments.sat_timeout``.
    """
    executable_file = os.path.join(output_dir, arguments.executable_file)
    output_file = os.path.join(output_dir, 'result', instance)

    ExecutionWorkerModSAT().execute(
        output_dir=output_dir,
        executable_file_path=executable_file,
        instance_file=os.path.join(arguments.data_dir, instance),
        output_file=output_file,
        timeout=arguments.sat_timeout,
    )

    status, elapsed = extract_log_info(output_file)
    if status == "INDETERMINATE":
        status = "TIMEOUT"
        elapsed = arguments.sat_timeout
    return [instance, status, elapsed]

def code_generate(args, func_candi_index, prompt_code_template, result_dir, num_samples=20):
    """Repeatedly ask the coder LLM for a new version of one solver function.

    Each round renders the coder prompt, queries the LLM, asks the evaluator
    LLM for its opinion (printed only), renders the answered code into the
    ``Solver.h`` / ``Solver.cc`` templates and compiles the result. A round
    is retried in a fresh ``<iteration>-<restart>`` directory until the
    build succeeds or four attempts have been made.

    Changes relative to the previous version:
      * ``ray.init`` is called with ``ignore_reinit_error=True`` so the
        function can be called more than once per process (and after Ray has
        already been started by an earlier ``.remote`` call);
      * the bare ``except: pass`` around directory creation is replaced by
        ``os.makedirs(..., exist_ok=True)``;
      * unused locals and their dead file/directory reads are removed;
      * Jinja template names are written with ``/`` instead of being built
        with ``os.path.join``;
      * the hard-coded 20 rounds became the ``num_samples`` parameter.

    Args:
        args: namespace providing ``project_dir``, ``project``, ``num_cpus``,
            ``coder_temperature``, ``evaluator_prompt`` and the LLM settings
            read by the remote tasks.
        func_candi_index: index into ``function_candidates`` of the function
            to regenerate.
        prompt_code_template: object with a ``render(context)`` method (e.g. a
            Jinja2 template); it is rendered with the current code of every
            candidate function.
        result_dir: directory receiving ``log.txt`` and the per-attempt
            sub-directories.
        num_samples: number of rounds, i.e. entries in the returned list.

    Returns:
        List with one entry per round: the ``{function name: code}`` mapping
        answered by the coder in the last attempt of that round.

    Raises:
        Whatever ``ray.get`` raises, including a timeout when the evaluator
        does not answer within 120 seconds.
    """
    func_candis = [function_candidates[func_candi_index]]
    func_candi_name = func_candis[0]

    project_dir = os.path.join(args.project_dir, args.project)
    with open(os.path.join(project_dir, "original_function.json")) as f:
        best_functions = json.load(f)

    # Ray may already be running; a second plain ray.init() would raise.
    ray.init(num_cpus=args.num_cpus, ignore_reinit_error=True)

    env = Environment(loader=FileSystemLoader(os.path.join(project_dir)))
    # Jinja template names always use "/" as separator, whatever the OS.
    coder_template = env.get_template("prompt_template/coder_template.txt")
    evaluator_prompt_template = env.get_template("prompt_template/" + args.evaluator_prompt)
    h_template = env.get_template("./Solver_template.h")
    c_template = env.get_template("./Solver_template.cc")

    log_path = os.path.join(result_dir, 'log.txt')
    code_base = []

    for iteration in range(1, num_samples + 1):
        print("#######################")
        print("iteration: ", iteration)

        # The log is truncated at the start of every round.
        with open(log_path, 'w') as f:
            f.write("")

        restart_flag = 0
        while restart_flag <= 3:
            # Start every attempt from an empty directory.
            output_dir = os.path.join(result_dir, '{}-{}'.format(iteration, restart_flag))
            shutil.rmtree(output_dir, ignore_errors=True)
            os.makedirs(output_dir, exist_ok=True)

            print("function_candis:", func_candis)

            render_prompt_context = {
                func: best_functions[func] for func in function_candidates
            }
            prompt_code = prompt_code_template.render(render_prompt_context)

            ##### Coder ######
            promp_output = coder_template.render(
                replace_key_code=prompt_code,
                func_name=func_candi_name)

            prompt_file = os.path.join(output_dir, 'prompt_{}'.format(func_candi_name))
            with open(prompt_file, 'w') as f:
                f.write(promp_output)

            task = synchronized_asked.remote(
                prompt_file, args, func_candis, args.coder_temperature)
            answer, answered_function_code = ray.get(task)

            with open(log_path, 'a') as f:
                f.write(f'answer: \n {answer} \n \
                        answer code: {answered_function_code} \n')

            ##### Evaluator ######
            original_code = {func: best_functions[func] for func in func_candis}
            new_code = answered_function_code

            with open(log_path, 'a') as f:
                f.write(f'new_code: \n {new_code} \n \
                       original_code: {original_code} \n')

            print("successfully get_verified_function! ")
            evaluator_prompt_output = evaluator_prompt_template.render(
                func_name=func_candis, new_code=new_code, original_code=original_code)
            evaluator_prompt_file = os.path.join(output_dir, 'evaluator_prompt')
            with open(evaluator_prompt_file, 'w') as f:
                f.write(evaluator_prompt_output)
            task = synchronized_asked_evaluator.remote(evaluator_prompt_file, args, func_candis)

            # The evaluator verdict is only printed; it does not gate the build.
            answer_evaluator, difference = ray.get(task, timeout=120)

            print("evaluator difference: ", difference)
            ######################

            # Build a private copy of the solver with the answered code spliced in.
            shutil.copytree(os.path.join(project_dir, 'modsat'), os.path.join(output_dir, 'modsat'))

            render_h_context = {
                func: answered_function_code[func] if func in func_candis else best_functions[func]
                for func in function_candidates_h
            }
            with open(os.path.join(output_dir, 'modsat', 'modsat', 'core', 'Solver.h'), 'w') as f:
                f.write(h_template.render(render_h_context))

            render_c_context = {
                func: answered_function_code[func] if func in func_candis else best_functions[func]
                for func in function_candidates_c
            }
            with open(os.path.join(output_dir, 'modsat', 'modsat', 'core', 'Solver.cc'), 'w') as f:
                f.write(c_template.render(render_c_context))

            execution_worker = ExecutionWorkerModSAT()
            success = execution_worker.compile(output_dir=output_dir)
            print("success: ", success)

            if success:
                print("successfully compiled! ")
                break

            print("Executing Error!!")
            restart_flag += 1

        # NOTE(review): the last answer is recorded even when all four build
        # attempts failed -- confirm that non-compiling code should count.
        code_base.append(new_code)

    return code_base

def main(args):
    """Iteratively optimise one section of the coder prompt.

    Every iteration picks one of the ``role`` / ``goal`` / ``tips`` sections,
    asks the optimiser LLM to rewrite it, generates code with the rewritten
    prompt, embeds and clusters that code, and keeps the rewritten section
    when the cluster entropy (taken at the K with the best silhouette score)
    beats the best entropy seen so far.

    Fixes relative to the previous version:
      * ``random.sample`` was called without ``k``; ``random.choice`` is used;
      * ``copy(prompt_part)`` called the ``copy`` module; ``copy.copy`` is used;
      * the optimiser prompt was rendered but never written, and its directory
        never created, although the LLM task is given that file path;
      * the optimiser task was started without its ``coder_temperature``
        argument;
      * ``code_generate`` was called with a non-existent keyword
        (``function_index``) and with a plain string where an object with
        ``render`` is required;
      * ``code_generate`` returns ``{name: code}`` mappings while the embedder
        needs strings, so the mappings are flattened first;
      * the entropy of the current iteration could be unbound or stale when no
        K produced a positive silhouette score, and K could exceed what the
        number of samples allows.

    Args:
        args: parsed command-line namespace.

    Returns:
        The prompt sections after optimisation (the previous version returned
        ``None`` and discarded them).
    """
    project_dir = os.path.join(args.project_dir, args.project)
    with open(os.path.join(project_dir, "prompt_template/original_coder.json")) as f:
        prompt_part = json.load(f)
    result_dir = precheck(args)

    # Start Ray explicitly so that num_cpus is honoured before the first task.
    if not ray.is_initialized():
        ray.init(num_cpus=args.num_cpus)

    env = Environment(loader=FileSystemLoader(os.path.join(project_dir)))
    whole_prompt_template = env.get_template("./prompt_template/coder_template.txt")
    optimized_prompt_template = env.get_template("./prompt_template/optimize_prompt.txt")

    # NOTE(review): the 1.0 fallback only applies when the namespace lacks
    # `coder_temperature`; confirm the intended default.
    coder_temperature = getattr(args, "coder_temperature", 1.0)

    content_candidate = ["role", "goal", "tips"]
    max_entropy = 0
    for i in range(args.max_iterations):
        output_dir = os.path.join(result_dir, '{}'.format(i))
        os.makedirs(output_dir, exist_ok=True)

        content_index = random.choice(content_candidate)
        prompt_needed_optimized = prompt_part[content_index]

        # Render the current coder prompt and the request to improve one part.
        whole_prompt = whole_prompt_template.render(prompt_part)
        optimized_prompt = optimized_prompt_template.render(
            whole_prompt=whole_prompt,
            original_prompt=prompt_needed_optimized
        )

        prompt_file = os.path.join(output_dir, 'prompt_optimized')
        with open(prompt_file, 'w') as f:
            f.write(optimized_prompt)

        task = synchronized_asked_optimizer.remote(prompt_file, args, coder_temperature)
        answer, optimized_prompt_part = ray.get(task)

        new_prompt_part = copy.copy(prompt_part)
        new_prompt_part[content_index] = optimized_prompt_part
        new_whole_prompt = whole_prompt_template.render(new_prompt_part)

        # code_generate calls `.render(context)` on this argument, so the
        # rendered text is wrapped in a template object.
        # NOTE(review): the text is therefore parsed by Jinja a second time;
        # confirm this is the intended input for code_generate.
        code_base = code_generate(
            args,
            func_candi_index=0,
            prompt_code_template=env.from_string(new_whole_prompt),
            result_dir=result_dir
        )

        # Each entry is a {function name: code} mapping; the embedder needs text.
        synthesized_code_samples = []
        for sample in code_base:
            if isinstance(sample, dict):
                sample = "\n".join(v for v in sample.values() if isinstance(v, str))
            if isinstance(sample, str) and sample:
                synthesized_code_samples.append(sample)
        embeddings = generate_code_embeddings(synthesized_code_samples)

        # Use the entropy at the K with the best silhouette score. The
        # silhouette score needs fewer clusters than samples.
        best_silhouette = 0
        curr_max_entropy = None
        for K in range(2, min(10, len(embeddings))):
            entropy, _probs, _counts, silhouette = cluster_and_calculate_entropy(embeddings, K)
            if silhouette > best_silhouette:
                best_silhouette = silhouette
                curr_max_entropy = entropy

        if curr_max_entropy is not None and curr_max_entropy > max_entropy:
            prompt_part[content_index] = optimized_prompt_part
            max_entropy = curr_max_entropy

    return prompt_part


if __name__ == '__main__':
    # Command-line entry point: parse the run configuration and start the
    # prompt-optimisation loop.
    parser = argparse.ArgumentParser()

    parser.add_argument('--max_iterations', type=int, default=20)
    parser.add_argument('--solver_name', type=str, default="modsat")
    parser.add_argument('--num_cpus', type=int, default=100)
    parser.add_argument('--llm_model',
                        type=str,
                        default="deepseek-chat",
                        choices=["gpt-4o-2024-08-06", "deepseek-reasoner", "o1-2024-12-17", "deepseek-chat"])
    parser.add_argument('--sat_timeout', type=int, default=15)
    parser.add_argument('--data_dir', type=str, default="./data/Zamkeller_train")
    parser.add_argument('--project', type=str, default="module_function")
    parser.add_argument('--source_file', type=str, default="./modsat/modsat/core/Solver.cc")
    parser.add_argument('--executable_file', type=str, default="modsat/bin/modsat")
    parser.add_argument('--project_dir', type=str, default="./examples/ModSAT")
    parser.add_argument('--result_dir', type=str, default="./results/")

    # The two options below are read by code_generate (args.coder_temperature,
    # args.evaluator_prompt) but were never declared, which made the run fail
    # with AttributeError.
    # NOTE(review): confirm the intended default temperature.
    parser.add_argument('--coder_temperature', type=float, default=1.0,
                        help="sampling temperature for the coder/optimizer LLM calls")
    parser.add_argument('--evaluator_prompt', type=str, required=True,
                        help="evaluator template file name inside <project>/prompt_template")

    parser.add_argument('--api_base', type=str, default='')
    parser.add_argument('--api_key', type=str, default='')
    args = parser.parse_args()

    main(args)
