import numpy as np
from puzzlebot_assembly.canvas import Canvas
from puzzlebot_assembly.behavior_lib import BehaviorLib

class RobotParam:
    """Geometric parameters shared by every robot in a team.

    Units are not stated in this file; the defaults look like metres
    (5 cm body) — NOTE(review): confirm.
    """

    def __init__(self, L=0.05,
            anchor_base_L=0.008,
            anchor_L=0.01):
        """
        L: body length, read by Robots.step() as ``robot_config.L``.
        anchor_base_L: anchor base length (stored as given).
        anchor_L: anchor length (stored as given).
        """
        self.L, self.anchor_base_L, self.anchor_L = L, anchor_base_L, anchor_L

class Robots:
    """Group of ``N`` robots driven by a queue of behaviors.

    The state ``x`` stacks 3 entries per robot and the input ``u`` stacks
    2 entries per robot (flat float64 arrays). Each call to :meth:`step`
    runs the behavior returned by ``self.behavs.current()``.
    """

    def __init__(self, N, controller, robot_param=None, eth=1e-3,
                pilot_ids=None, logger=None):
        """
        N: number of robots.
        controller: object whose ``fk(x, u)`` is used by start() to advance
            the state; also forwarded to BehaviorLib.
        robot_param: RobotParam instance; a default RobotParam() is created
            when None.
        eth: forwarded to BehaviorLib unchanged.
        pilot_ids: list of pilot robot indices, or a list whose first entry
            is 'first' (resolved in setup()). None (the default) means no
            pilots. A fresh list is created per instance so instances never
            share a mutable default.
        logger: optional object with an ``is_end`` flag checked by step();
            also forwarded to BehaviorLib.
        """
        self.N = N
        self.ctl = controller
        self.pilot_ids = pilot_ids if pilot_ids is not None else []
        self.x = np.zeros(3*N, dtype=np.float64)
        self.u = np.zeros(2*N, dtype=np.float64)
        self.prev_u = np.zeros(2*N)
        self.canvas = Canvas(N)
        self.time = {}  # filled by setup() with 't', 'dt', 'tmax'
        self.robot_config = (robot_param if robot_param is not None
                             else RobotParam())
        self.behavs = BehaviorLib(N, controller, eth=eth,
                                  robot_param=self.robot_config,
                                  logger=logger)
        self.logger = logger
        # Values computed once and reused across step() calls
        # (currently the traj_follow centre 'center_x' / 'center_y').
        self.stored_param = {}

    def setup(self, start, dt=0.1, tmax=15):
        """Initialise the clock, the behavior queue, the pilots and the state.

        start: array with N columns; its transpose is flattened into
            ``self.x``, so column i becomes robot i's slice of the state.
            Presumably shape (3, N) — NOTE(review): confirm.
        dt: time increment applied per iteration of start().
        tmax: start() loops while ``time['t'] < tmax``.
        """
        assert self.N == start.shape[1]

        self.time = {'t': 0, 'dt': dt, 'tmax': tmax}
        # Alternative behavior sequences, kept for experiments:
        #  self.behavs.add_bhav(self.behavs.wiggle)
        self.behavs.add_bhav(self.behavs.align_anchor_pool)
        #  self.behavs.add_bhav(self.behavs.traj_follow)
        #  self.behavs.add_bhav(self.behavs.go_to_goal)
        #  self.behavs.add_bhav(self.behavs.go_du)
        self.behavs.add_bhav(self.behavs.nothing)

        # 'first' selects the robot with the largest value in row 0 of
        # `start` (presumably the x position) as the single pilot.
        if self.pilot_ids and self.pilot_ids[0] == 'first':
            self.pilot_ids = [int(np.argmax(start[0, :]))]
            print("updated pilot_id:", self.pilot_ids)

        self.x = start.T.flatten()

    def start(self):
        """Run the built-in kinematic loop until tmax or a behavior ends it."""
        # this function is replaced in the simulation.py for Bullet
        # NOTE(review): `self.log` is never assigned in this class; it must be
        # attached externally before start()/generate_anim() are called.
        print("System started.")
        while self.time['t'] < self.time['tmax']:
            is_done = self.step(self.x, self.prev_u, self.time['t'])
            x = self.ctl.fk(self.x, self.u)
            #  x = self.ctl.fk_rk4(self.x, self.u)
            #  x = self.ctl.fk_exact(self.x, self.u)
            self.log.append_xu(self.x, self.u)
            self.x = x
            self.time['t'] += self.time['dt']
            if is_done:
                print("Simulation ended.")
                break

    def step(self, x, u, t):
        """Compute the next control input with the active behavior.

        x: 3N state vector (3 entries per robot).
        u: 2N input vector handed to the behavior (start() passes prev_u).
        t: current time; only used by the wiggle behavior.

        Returns True when the run should stop: either the logger reports
        ``is_end`` (inputs are zeroed), or ``behavs.current()`` returns a
        falsy value. Otherwise it returns False. When the behavior returns
        None, the behavior index is advanced and ``self.u`` is zeroed.
        Otherwise the old ``self.u`` is copied into ``self.prev_u`` and the
        new input is stored in ``self.u``.
        """
        N = self.N
        if self.logger is not None and self.logger.is_end:
            self.u = np.zeros(2*N, dtype=float)
            return True

        body_len = self.robot_config.L
        curr_bhav = self.behavs.current()
        if not curr_bhav:
            return True

        if curr_bhav == self.behavs.align_cp:
            # Contact points for robot pair (0, 1); one column per robot.
            cp = {(0, 1): np.array([[body_len/2, body_len/2, 0],
                                    [-body_len/2, body_len/2, 0]]).T}
            u = curr_bhav(x, u, cp)
        elif curr_bhav == self.behavs.go_du:
            prev_cp = self.behavs.align_pool_var['conn_dict']
            #  prev_cp = self.behavs.anchor_param['align_cp']
            # First input per robot decreases by 0.01 with its rank in
            # x[0::3]: the robot with the smallest value gets 0.06.
            gdu = np.array([[0.06 - i*0.010, 0.0] for i in range(N)]).T
            #  gdu = np.array([[0.06, 0] for i in range(N)]).T
            #  gdu = np.array([[0.06,0], [0.03,0], [0.03,0]]).T
            du = np.zeros([2, N])
            sort_idx = np.argsort(x[0::3])
            du[:, sort_idx] = gdu[:, :]
            du = du.T.flatten()
            # Second input per robot is minus its third state entry
            # (presumably the heading).
            du[1::2] = -x[2::3]
            print("x:", x)
            print("robot du:", du)
            u = curr_bhav(x, u, du, prev_cp=prev_cp)
        elif curr_bhav == self.behavs.go_to_goal:
            prev_cp = self.behavs.align_pool_var['conn_dict']
            goal = [2.0, x[1], 0]
            u = curr_bhav(x, u, goal=goal, prev_cp=prev_cp)
        elif curr_bhav == self.behavs.traj_follow:
            # The connection dict exists only after align_anchor_pool has
            # run; fall back to "no connections" when it is missing. Narrow
            # catch so genuine errors are not silently swallowed.
            try:
                prev_cp = self.behavs.align_anchor_pool_var['conn_dict']
            except (AttributeError, KeyError, TypeError):
                prev_cp = {}
            # Fix the trajectory centre once, at the robot with the largest
            # x[3i], and reuse it on later steps.
            if 'center_y' not in self.stored_param:
                lead = np.argsort(x[0::3])[-1]
                self.stored_param['center_x'] = x[3*lead]
                self.stored_param['center_y'] = x[3*lead+1]
            center_x = self.stored_param['center_x']
            center_y = self.stored_param['center_y']
            u = curr_bhav(x, u, traj="line", prev_cp=prev_cp,
                          traj_param={'y': center_y, 'x': center_x})
            #  u = curr_bhav(x, u, traj="wave", prev_cp=prev_cp, traj_param={'y': center_y, 'x': center_x, 'A': 0.05, 'T': 1.0})
        elif curr_bhav == self.behavs.wiggle:
            u = curr_bhav(x, u, t=t, vbias=0.008)
        else:
            u = curr_bhav(x, u, pilot_ids=self.pilot_ids)

        if u is None:
            # Current behavior finished: move on to the next one.
            self.behavs.bhav_id += 1
            self.u[:] = 0
            return False
        self.prev_u[:] = self.u[:]
        self.u = u
        return False

    def generate_anim(self):
        """Render the recorded log with the canvas."""
        # NOTE(review): relies on an externally attached `self.log`.
        self.canvas.animation(self.log)
