import numpy as np
import warnings
from time import time
from puzzlebot_assembly.utils import *
from puzzlebot_assembly.planner import Planner


class BehaviorLib:
    """Collection of multi-robot behaviors driven by a controller object.

    Behavior methods take the stacked state ``x`` (three entries per robot,
    sliced as ``x[3*i:3*(i+1)]``) and return a stacked control vector with
    two entries per robot (``u[2*i]``, ``u[2*i+1]``). Some behaviors return
    ``None`` to signal that they are finished.
    """

    def __init__(self, N, controller, pool, eth, bhav_list=None,
                 robot_param=None, logger=None):
        """Store the controller, tolerances and behavior bookkeeping.

        Args:
            N: number of robots.
            controller: controller object; its ``param`` attribute is cached
                as ``ctl_param`` (read for ``vmax`` / ``wmax``).
            pool: opaque object forwarded unchanged to ``controller.final``.
            eth: distance tolerance used by the alignment checks.
            bhav_list: optional list of behaviors. A fresh list is created
                when omitted so that instances never share it.
            robot_param: robot geometry object, read through attributes
                such as ``L``, ``anchor_L`` and ``anchor_base_L``.
            logger: optional object with a ``write(x, u, extra)`` method.
        """
        self.N = N
        self.ctl = controller
        # A shared ``[]`` default would leak behaviors added via add_bhav()
        # into every other instance created without an explicit list.
        self.bhav_list = [] if bhav_list is None else bhav_list
        self.bhav_id = 0
        self.eth = eth
        self.robot_param = {} if robot_param is None else robot_param
        self.align_pool_var = {}
        self.anchor_cps = {}
        self.ctl_param = controller.param
        self.fail_count = 0
        self.time_list = []
        self.status_list = []
        self.pool = pool
        self.logger = logger
        self.is_init = False
        # Assembly pattern; populated by init_anchor_pool().
        self.pattern = []
        # Row 0: last solve time, row 1: per-robot status code
        # (1 = anchor head aligned, 2 = pair connected).
        self.curr_status = np.zeros([2, N])

    def add_bhav(self, bhav):
        """Append a behavior to this instance's behavior list."""
        self.bhav_list.append(bhav)

    def wiggle(self, x, u, t=0, vbias=0, rid=[]):
        """Wiggle the robot with the largest x coordinate.

        The ``u`` and ``rid`` arguments are ignored. The control starts from
        zeros, and the wiggling robot is always the one with the largest
        ``x[3*i]``.

        Returns:
            Control vector of length ``2*N``. If every consecutive pair of
            robots (in index order) is more than ``1.3*L`` apart, the chosen
            robot drives at ``vmax``. Otherwise it moves at ``vbias`` while
            its second control entry alternates as ``2*wmax*sign(sin(5 t))``.
        """
        N = self.N
        L = self.robot_param.L
        u = np.zeros(2*N)
        rid = np.argsort(-x[0::3])[0]

        x_diff = np.diff(x.reshape([N, 3]).T[0:2, :], axis=1)
        if np.all(np.linalg.norm(x_diff, axis=0) > 1.3*L):
            u[rid*2] = self.ctl_param.vmax
            return u
        u[rid*2] = vbias
        u[rid*2+1] = 2*self.ctl_param.wmax * np.sign(np.sin(t*5))
        return u

    def init_anchor_param(self, cp_d):
        '''Build the per-pair state dictionary for one coupling pair.

        cp_d: 3-by-2 np array for the coupling pair position in local frame

        Returns a dict with keys ``status`` ("decoupled"), ``execute``
        ("wait"), ``anchor_index``, ``type`` ("anchor" or "knob") and three
        3-by-2 target arrays: ``align_cp``, ``insert_cp`` and
        ``maintain_cp``. Column j belongs to robot ``ids[j]``. The columns
        are swapped when the body robot is in column 1.

        Raises ValueError for an unknown connection type.
        '''
        L = self.robot_param.L
        bl = self.robot_param.anchor_base_L
        al = self.robot_param.anchor_L
        eth = self.eth

        [body_idx, _], conn_type = get_anchor_body_index(cp_d, L)

        # anchor status: decoupled -> head_aligned -> head_insert
        # execute status: wait, now, done
        anchor_param = {
            'status': "decoupled",
            'execute': "wait",
            'anchor_index': 1 - body_idx,
            'type': conn_type,
        }

        if conn_type == "anchor":
            anchor_param['align_cp'] = np.array([[L/2, 0, 0],
                                        [-L/2 - al - eth, 0, 0]]).T
            anchor_param['insert_cp'] = np.array([[L/2, 0, 0], [-L/2, 0, 0]]).T
            anchor_param['maintain_cp'] = np.array([[L/2, 0, 0],
                                        [-L/2 - bl - eth, 0, 0]]).T
        elif conn_type == "knob":
            anchor_param['align_cp'] = np.array([[-0.02, -L/2-0.02, -0.3],
                                        [L/2, L/2, 0.3]]).T
            anchor_param['insert_cp'] = np.array([[-0.001, -L/2, 0],
                                        [L/2, L/2, 0]]).T
            anchor_param['maintain_cp'] = np.array([
                                    [0.001, -L/2, 0],
                                    [L/2, L/2, 0.3]]).T
        else:
            raise ValueError("Connection Type Error!")

        if body_idx == 1:
            for key in ('align_cp', 'insert_cp', 'maintain_cp'):
                anchor_param[key] = np.fliplr(anchor_param[key])
        return anchor_param

    def update_cp_anchor(self, cp_list):
        """Map every pair key of ``cp_list`` to a fresh anchor-param dict."""
        return {ids: self.init_anchor_param(cp_list[ids]) for ids in cp_list}

    def get_current_dicts(self, x, anchor_cps, curr_dict, conn_dict):
        """Advance the coupling state machine of every active pair.

        Pairs whose ``execute`` flag is "wait" are skipped. For the others,
        ``anchor_param['status']`` is updated in place:

        * "decoupled": targets ``align_cp``. It becomes "head_aligned" when
          the anchor's align point lies inside the body robot (anchor type),
          or when both align points are within ``2*eth`` (knob type).
        * "head_aligned": targets ``insert_cp``. It drops back to
          "decoupled" if the heads separate, and becomes "head_insert" once
          the insertion points are within ``eth``. For the anchor type,
          the anchor insertion point lying inside the body robot also counts.
        * "head_insert": the pair is moved from ``curr_dict`` to
          ``conn_dict`` with its ``maintain_cp`` target (once).

        Returns:
            ``(curr_dict, conn_dict)``. Both are mutated in place and also
            returned.
        """
        eth = self.eth
        robot_param = self.robot_param

        for ids in anchor_cps:
            anchor_param = anchor_cps[ids]
            status = anchor_param['status']
            if anchor_param['execute'] == "wait":
                continue

            # Robot carrying the anchor head vs. robot receiving it.
            anchor_idx = anchor_param['anchor_index']
            body_idx = 1 - anchor_idx
            anchor_id = ids[anchor_idx]
            body_id = ids[body_idx]
            anchor_x = x[3*anchor_id:3*(anchor_id+1)]
            body_x = x[3*body_id:3*(body_id+1)]

            if status == "decoupled":
                curr_dict[ids] = anchor_param['align_cp']

                # Check whether the anchor head is already aligned.
                anchor_pt = body2world(anchor_x,
                            anchor_param['align_cp'][:, anchor_idx, np.newaxis])
                if anchor_param['type'] == "anchor":
                    if is_inside_robot(anchor_pt, body_x, robot_param.L,
                                       margin=0):
                        anchor_param['status'] = "head_aligned"
                elif anchor_param['type'] == "knob":
                    body_pt = body2world(body_x,
                            anchor_param['align_cp'][:, body_idx, np.newaxis])
                    if np.linalg.norm(body_pt - anchor_pt) < 2*eth:
                        anchor_param['status'] = "head_aligned"

            elif status == "head_aligned":
                curr_dict[ids] = anchor_param['insert_cp']

                # Fall back to "decoupled" if the heads drifted apart.
                if anchor_param['type'] == "anchor":
                    anchor_pt = body2world(anchor_x,
                            anchor_param['align_cp'][:, anchor_idx, np.newaxis])
                    if not is_inside_robot(anchor_pt, body_x, robot_param.L,
                                           margin=eth):
                        anchor_param['status'] = "decoupled"
                        continue
                elif anchor_param['type'] == "knob":
                    anchor_pt = body2world(anchor_x,
                            anchor_param['insert_cp'][:, anchor_idx, np.newaxis])
                    body_pt = body2world(body_x,
                            anchor_param['insert_cp'][:, body_idx, np.newaxis])
                    if np.linalg.norm(body_pt - anchor_pt) > 10*eth:
                        anchor_param['status'] = "decoupled"
                        continue

                # Promote to "head_insert" once the insertion points meet.
                if anchor_param['type'] in ("anchor", "knob"):
                    anchor_pt = body2world(anchor_x,
                            anchor_param['insert_cp'][:, anchor_idx, np.newaxis])
                    body_pt = body2world(body_x,
                            anchor_param['insert_cp'][:, body_idx, np.newaxis])
                    inserted = np.linalg.norm(body_pt - anchor_pt) < eth
                    if not inserted and anchor_param['type'] == "anchor":
                        inserted = is_inside_robot(anchor_pt, body_x,
                                                   robot_param.L, margin=eth)
                    if inserted:
                        anchor_param['status'] = "head_insert"

            elif status == "head_insert" and ids not in conn_dict:
                conn_dict[ids] = anchor_param['maintain_cp']
                curr_dict.pop(ids)
                # NOTE: a connection already in conn_dict is never
                # re-checked for disconnection here.

        return curr_dict, conn_dict

    def get_zero_ids(self, conn_dict, busy, pilot_ids=[]):
        """Return indices of robots whose control should be zeroed.

        A robot is masked out (keeps its control) if it is busy, or if it
        belongs to a connected pair where either member is busy. Pilot
        robots are always included in the returned indices.

        Args:
            conn_dict: mapping keyed by robot-index pairs.
            busy: boolean array of length ``N``.
            pilot_ids: robot indices that are always zeroed here.
        """
        mask = np.ones(self.N, dtype=bool)
        mask_conn_ids = []
        for cids in conn_dict:
            if busy[cids[0]] or busy[cids[1]]:
                mask_conn_ids += list(cids)
        mask[mask_conn_ids] = False
        mask[busy] = False
        mask[pilot_ids] = True
        return np.nonzero(mask)[0]

    def init_anchor_pool(self, x, pilot_ids=[], pattern=[]):
        """Plan the coupling pairs and reset the pool bookkeeping.

        Uses ``pattern`` if it is non-empty, otherwise ``[N]``.
        """
        N = self.N
        self.pattern = pattern if len(pattern) > 0 else [N]
        self.align_pool_var['planner'] = Planner(self.N)
        self.align_pool_var['busy'] = np.zeros(N, dtype=bool)
        p = self.align_pool_var['planner']
        self.align_pool_var['pair_dict'] = p.generate_pair_pool(
                                            self.pattern,
                                            x.reshape([N, 3]).T[0:2, :],
                                            pilot_ids=pilot_ids)
        self.anchor_cps = self.update_cp_anchor(self.align_pool_var['pair_dict'])
        self.align_pool_var['curr_dict'] = {}
        self.align_pool_var['conn_dict'] = {}
        self.align_pool_var['remain_dict'] = dict.fromkeys(self.align_pool_var['pair_dict'])
        self.align_pool_var['seg_dict'] = {i: [i] for i in range(N)}

    def recompute_remain_dict(self, x, pilot_ids=[]):
        """Re-plan the pairs from the current positions and swap in new ones.

        Returns early (leaving everything untouched) if a pair present in
        both plans has different coupling points. Otherwise pairs that were
        dropped are removed from ``remain_dict`` / ``curr_dict``, new pairs
        are added with their ``align_cp`` target, and ``self.anchor_cps`` is
        replaced by freshly initialized params.
        """
        N = self.N
        p = self.align_pool_var['planner']
        assert len(self.pattern) > 0
        new_dict = p.generate_pair_pool(self.pattern,
                                        x.reshape([N, 3]).T[0:2, :],
                                        pilot_ids=pilot_ids)
        new_anchor_cps = self.update_cp_anchor(new_dict)
        old_anchor_cps = self.anchor_cps
        remain_dict = self.align_pool_var['remain_dict']
        curr_dict = self.align_pool_var['curr_dict']

        new_pairs = []
        for ids in new_anchor_cps:
            if ids in old_anchor_cps:
                # A shared pair must keep the same coupling points.
                if not is_cp_same(ids, new_anchor_cps[ids]['align_cp'],
                                  ids, old_anchor_cps[ids]['align_cp']):
                    return
                continue
            new_pairs.append(ids)
        old_pairs = [ids for ids in old_anchor_cps if ids not in new_anchor_cps]
        assert len(new_pairs) == len(old_pairs)

        for ids in old_pairs:
            # A finished pair has already been popped from remain_dict.
            remain_dict.pop(ids, None)
            curr_dict.pop(ids, None)
        for ids in new_pairs:
            remain_dict[ids] = None
            curr_dict[ids] = new_anchor_cps[ids]['align_cp']
        self.anchor_cps = new_anchor_cps

    def align_anchor_pool(self, x, prev_u, pilot_ids=[]):
        """Run one control step of the pairwise coupling behavior.

        Selects idle pairs to execute, advances their coupling state,
        solves for the control, and marks finished pairs as done.

        Returns:
            Control vector of length ``2*N``, or ``None`` once every planned
            pair has been connected.
        """
        if len(self.align_pool_var) == 0:
            self.init_anchor_pool(x, pilot_ids=pilot_ids)
        robot_busy = self.align_pool_var['busy']
        if np.all(robot_busy == 0) and len(self.pattern) < 2:
            self.recompute_remain_dict(x, pilot_ids=pilot_ids)

        curr_dict = self.align_pool_var['curr_dict']
        conn_dict = self.align_pool_var['conn_dict']
        remain_dict = self.align_pool_var['remain_dict']
        seg_dict = self.align_pool_var['seg_dict']
        anchor_cps = self.anchor_cps

        print("curr_dict:", curr_dict)

        # Decide which pairs to execute: both robots and their segments idle.
        for k in remain_dict:
            if np.any(robot_busy[list(k)]):
                continue
            if k in conn_dict:
                continue
            robot_busy[list(k)] = True
            robot_busy[seg_dict[k[0]] + seg_dict[k[1]]] = True
            anchor_cps[k]['execute'] = "now"

        # Knob-type pairs get their heading boosted by the controller.
        boost_t_list = [ids for ids in curr_dict
                        if anchor_cps[ids]['type'] == "knob"]

        # Update contact pairs if needed.
        curr_dict, conn_dict = self.get_current_dicts(x, anchor_cps,
                                    curr_dict, conn_dict)
        zero_list = self.get_zero_ids(conn_dict, robot_busy,
                                pilot_ids=pilot_ids)
        u, time_elapse = self.align_cp(x, prev_u, curr_dict,
                                       prev_cp=conn_dict,
                                       zero_list=zero_list,
                                       boost_t_list=boost_t_list,
                                       segment_dict=seg_dict)
        u[2*zero_list] = 0
        u[2*zero_list+1] = 0

        if self.logger:
            self.curr_status[0, :] = time_elapse

        if len(curr_dict) > 0:
            curr_key = next(iter(curr_dict))
            if (anchor_cps[curr_key]['status'] == "head_aligned"
                    and anchor_cps[curr_key]['type'] == "anchor"):
                ic, ig = np.array(curr_key)
                if self.logger:
                    self.curr_status[1, [ic, ig]] = 1

                # set use_max to true on hardware, false in sim
                u[2*ic:2*(ic+1)] = self.ctl.diff_drive_goal(
                                            x[3*ic:3*(ic+1)],
                                            x[3*ig:3*(ig+1)],
                                            use_max=True)
                u[2*ig:2*(ig+1)] = self.ctl.diff_drive_goal(
                                            x[3*ig:3*(ig+1)],
                                            x[3*ic:3*(ic+1)],
                                            use_max=True)
            if len(pilot_ids) > 0:
                if (pilot_ids[0] in curr_key
                        and anchor_cps[curr_key]['status'] == "head_aligned"):
                    u[0::2] = self.ctl_param.vmax
                    u[1::2] = np.random.rand()*0.05 - 0.1
                    # With a plain list, ``2*pilot_ids`` would repeat the
                    # list instead of doubling the indices.
                    u[2*np.asarray(pilot_ids, dtype=int)] = self.ctl_param.vmax/3

        # Current pairs that reached conn_dict are done executing.
        for idx in conn_dict:
            param = anchor_cps.get(idx)
            if param is None or param['execute'] != "now":
                continue
            remain_dict.pop(idx)
            param['execute'] = "done"
            robot_busy[list(idx)] = 0
            robot_busy[seg_dict[idx[0]]] = 0
            robot_busy[seg_dict[idx[1]]] = 0
            # NOTE(review): segments only record direct partners, not the
            # transitive closure of connected robots — confirm intent.
            seg_dict[idx[0]].append(idx[1])
            seg_dict[idx[1]].append(idx[0])

            if self.logger:
                self.curr_status[1, list(idx)] = 2
        print("seg_dict:", seg_dict)

        if self.logger:
            self.logger.write(x, u, self.curr_status)

        if len(remain_dict) == 0:
            print("All pairs aligned.")
            return None
        print("u:", u)
        return u

    def init_align_pool(self, x):
        """Plan pairs with pattern ``[N]`` and reset the pool bookkeeping."""
        N = self.N
        self.align_pool_var['planner'] = Planner(self.N)
        self.align_pool_var['busy'] = np.zeros(N, dtype=bool)
        p = self.align_pool_var['planner']
        self.align_pool_var['pair_dict'] = p.generate_pair_pool(
                                            [N], x.reshape([N, 3]).T[0:2, :])
        self.align_pool_var['curr_dict'] = {}
        self.align_pool_var['conn_dict'] = {}
        self.align_pool_var['remain_dict'] = self.align_pool_var['pair_dict'].copy()
        self.align_pool_var['seg_dict'] = {i: [i] for i in range(N)}

    def align_cp(self, x, u, cp, prev_cp=[], zero_list=[],
                 boost_t_list=[], segment_dict={}):
        """Solve for a control that drives the coupling pairs together.

        ``zero_list`` and ``boost_t_list`` are accepted for API
        compatibility but are not used here.

        Returns:
            ``(u_vel, time_elapsed)``. A zero control and ``-1`` are returned
            when ``cp`` is empty, before the first solve that finishes in
            under 0.1 s (warm-up), or when the solver reports no objective.
        """
        print("cp_dis:", get_cp_dis(x, cp))
        if not cp:
            return np.zeros(2 * self.N), -1
        start = time()

        # Split the targets into already-satisfied and pending constraints.
        L = self.robot_param.L
        aligned_cps, unaligned_cps = check_cps(x, cp, prev_cp, L, self.eth)
        u_vel, obj_value = self.ctl.final(x, u, L, unaligned_cps,
                                          aligned_cps, segment_dict,
                                          self.pool)

        time_elapsed = time() - start
        print(time_elapsed)
        # The first fast solve marks the controller as warmed up.
        if time_elapsed < 0.1:
            self.is_init = True
        if not self.is_init:
            return np.zeros(2 * self.N), -1
        self.time_list.append(time_elapsed)

        if obj_value is None:
            self.fail_count += 1
            return np.zeros(2 * self.N), -1
        self.fail_count = 0
        return u_vel, time_elapsed

    def traj_follow(self, x, u, traj="line", prev_cp=[], traj_param={}):
        """Track a reference trajectory while keeping connections intact.

        Args:
            traj: "line" (horizontal line at ``y``) or "wave" (sinusoid with
                amplitude ``A`` and period ``T`` around ``(x, y)``).
            prev_cp: connection pairs whose constraints are kept.
            traj_param: dict with keys ``x`` and ``y``, plus ``A`` and ``T``
                for "wave".

        Raises:
            ValueError: for an unknown trajectory type.
        """
        print("x", x)
        assert 'x' in traj_param and 'y' in traj_param
        center_x = traj_param['x']
        center_y = traj_param['y']
        if traj == "line":
            get_goal = lambda x0: np.array([x0, np.zeros(len(x0))+center_y, np.zeros(len(x0))]).reshape([3, len(x0)]).T.flatten()
            v_sc = 0.9
        elif traj == "wave":
            assert 'A' in traj_param and 'T' in traj_param
            wA = traj_param['A']
            wT = 2*np.pi / traj_param['T']
            # Goal heading is the tangent angle of the sinusoid.
            get_goal = lambda x0: np.array([x0, wA*np.sin((x0 - center_x)*wT)+center_y, np.arctan(wT*wA*np.cos((x0-center_x)*wT))]).reshape([3, len(x0)]).T.flatten()
            v_sc = 0.8
        else:
            raise ValueError("Unknown trajectory type!")
        print("diff: ", np.linalg.norm((x - get_goal(x[0::3])).reshape([len(x)//3, 3])[:, 0:2], axis=1))

        self.ctl.init_opt(x, u)
        self.ctl.add_dynamics_constr()
        self.ctl.add_vwlim_constraint()
        self.ctl.add_align_poly_constr(prev_cp, self.robot_param.L)
        self.ctl.add_body_line_constr(prev_cp, self.robot_param.L)
        cost = 0
        cost += self.ctl.align_cp_cost({}, prev_cp)
        cost += self.ctl.traj_cost_first(get_goal, x, v_sc=v_sc)
        cost += self.ctl.stage_cost()
        cost += self.ctl.smooth_cost(u)
        u_vel, obj_value = self.ctl.optimize_cp(cost)

        if self.logger:
            self.logger.write(x, u_vel, get_goal(x[0::3]))

        print("u:", u_vel)
        return u_vel

    def go_du(self, x, u, gdu, prev_cp=[], end_x=-0.077, slow_x=-0.088):
        """Pass the precomputed control ``gdu`` through unchanged.

        ``x``, ``u``, ``prev_cp``, ``end_x`` and ``slow_x`` are currently
        unused and kept only for interface compatibility.
        """
        return gdu

    def current(self):
        """Return the active behavior, or ``None`` when the list is exhausted."""
        if self.bhav_id >= len(self.bhav_list):
            return None
        return self.bhav_list[self.bhav_id]

    def go_to_goal(self, x, u, goal=[0.0, 0.0, 0], prev_cp=[]):
        """Drive toward ``goal`` with an MPC step (no connection pairs).

        Returns a zero control when the first two entries of ``x - goal_vec``
        have norm below ``50*eth``. ``prev_cp`` is unused.

        NOTE(review): only ``x[0:2]`` (robot 0's first two state entries)
        is checked. Confirm whether all robots should be considered.
        """
        goal_vec = np.hstack([goal for _ in range(self.N)])  # one copy per robot
        if np.linalg.norm((x - goal_vec)[0:2]) < self.eth*50:
            return np.zeros(2*self.N)
        self.ctl.init_opt(x, u)
        self.ctl.add_dynamics_constr()
        self.ctl.add_vwlim_constraint()
        cost = 0
        cost += self.ctl.goal_cost(goal)
        cost += self.ctl.stage_cost()

        u_vel, obj_value = self.ctl.optimize_cp(cost)
        return u_vel

    def nothing(self, x, u=[], param=[], pilot_ids=[]):
        """Return a zero control vector of length ``2*N``."""
        return np.zeros(2 * self.N)