import casadi as ca
import numpy as np

from puzzlebot_assembly.utils import *

class ControlParam:
    """
    Limits, horizons and cost weights consumed by the MPC Controller.

    vmax, wmax: box bounds on the linear / angular velocity controls.
    uvmax, uwmax, gamma: stored as given; not referenced by the controller
        code in this module (presumably used elsewhere -- verify).
    mpc_horizon: number of control steps in the MPC horizon (stored as hmpc).
    constr_horizon: number of steps over which connection constraints are
        imposed (stored as hcst).
    eth: geometric tolerance; the constraint builders use 2*eth as margin.
    cost_Q: dictionary of weights for the individual cost terms.
    """
    def __init__(self, vmax=0.5,
                    uvmax=0.1,
                    wmax=1.0,
                    uwmax=0.5,
                    gamma=0.1,
                    mpc_horizon=10,
                    constr_horizon=10,
                    eth=1e-3):
        self.vmax, self.uvmax = vmax, uvmax
        self.wmax, self.uwmax = wmax, uwmax
        self.gamma = gamma
        self.hmpc = mpc_horizon     # MPC horizon length in steps
        self.hcst = constr_horizon  # steps that receive connection constraints
        self.eth = eth
        self.cost_Q = dict(
            # terminal weights of the connection pair being formed
            cp_xy=1e5, cp_t=1e3,
            # terminal weights of already-connected pairs
            prev_xy=1e-2, prev_t=1e-2,
            # stage weights of the connection pair being formed
            s_cp_xy=1e0, s_cp_t=1e-2,
            # stage weights of already-connected pairs
            s_prev_xy=1e-3, s_prev_t=1e-4,
            # rendezvous (pairwise pose agreement) weights
            rend_xy=1e0, rend_t=1e3,
            # control smoothness weights for v and w
            smooth_v=1e-1, smooth_w=1e0,
            # control effort weight
            Q_u=1e-2,
        )

class CasadiInterface:
    """
    Factory for the CasADi symbolic functions used by the MPC controller.

    Every factory method creates fresh ``ca.SX`` symbols and wraps the
    resulting expression in a ``ca.Function``. Robot poses are stacked per
    robot as [x, y, theta] (or [x, y, theta, v, w] for fk_opt_force).
    """
    def __init__(self, N, dt, state_len, M=0.1):
        # N: number of robots, dt: integration step, state_len: states per
        # robot. M is stored but not referenced by the methods of this class.
        self.N = N
        self.dt = dt
        self.state_len = state_len
        self.M = M

    def get_local_pt(self):
        """
        Build f(xi, cp) -> (cp_x, cp_y).

        xi: pose [x, y, theta] of a robot; cp: point [px, py] in that robot's
        body frame. The outputs are the world coordinates of cp.
        """
        xi = ca.SX.sym('xi', 3)
        cp = ca.SX.sym('cp', 2)
        theta = xi[2]
        cp_x = ca.cos(theta)*cp[0] - ca.sin(theta)*cp[1] + xi[0]
        cp_y = ca.sin(theta)*cp[0] + ca.cos(theta)*cp[1] + xi[1]

        return ca.Function("get_local_pt", [xi, cp], [cp_x, cp_y])

    def get_relative_pt(self):
        """
        Build f(xi, xj, xj_pt) -> x-coordinate only.

        xj_pt is a point given in robot j's body frame. The function returns
        the x-coordinate of that point expressed in robot i's body frame
        (i.e. the first row of R(-ti) * (pj + R(tj) * xj_pt - pi)).
        """
        xi = ca.SX.sym('xi', 3) # ego robot
        xj = ca.SX.sym('xj', 3) # other robot
        xj_pt = ca.SX.sym('xj_pt', 2) # point of interest on other robot
        ti = xi[2]
        tj = xj[2]
        xi_pt_x = ca.cos(tj-ti)*(xj_pt[0]) - ca.sin(tj-ti)*(xj_pt[1]) + ca.cos(-ti)*(xj[0]-xi[0]) - ca.sin(-ti)*(xj[1]-xi[1])
        return ca.Function("get_relative_pt", [xi, xj, xj_pt], [xi_pt_x])

    def fk_opt_force(self, N, dt):
        """
        One explicit-Euler step for N robots with second-order kinematics.

        State per robot: [x, y, theta, v, w]; control per robot: the rates of
        v and w. Returns ca.Function (x[5N], u[2N]) -> x + x_dot * dt.
        """
        x_sym = ca.SX.sym('x', 5*N)
        u_sym = ca.SX.sym('u', 2*N)

        theta = x_sym[2::5]
        vs = x_sym[3::5]
        x_dot = ca.SX.zeros(5*N)
        x_dot[0::5] = vs * ca.cos(theta)
        x_dot[1::5] = vs * ca.sin(theta)
        x_dot[2::5] = x_sym[4::5]
        x_dot[3::5] = u_sym[0::2]
        x_dot[4::5] = u_sym[1::2]

        return ca.Function("fk_opt", [x_sym, u_sym], [x_sym + (x_dot * dt)])

    def fk_opt(self, N, dt):
        """
        One explicit-Euler step of unicycle kinematics for N robots.

        State per robot: [x, y, theta]; control per robot: [v, w].
        Returns ca.Function (x[3N], u[2N]) -> x + x_dot * dt.
        """
        x_sym = ca.SX.sym('x', 3*N)
        u_sym = ca.SX.sym('u', 2*N)

        theta = x_sym[2::3]
        x_dot = ca.SX.zeros(3*N)
        x_dot[0::3] = u_sym[0::2] * ca.cos(theta)
        x_dot[1::3] = u_sym[0::2] * ca.sin(theta)
        x_dot[2::3] = u_sym[1::2]
        return ca.Function("fk_opt", [x_sym, u_sym], [x_sym + (x_dot * dt)])

    def dd_fx_opt(self, theta, N):
        """
        Input matrix F (3N x 2N) of unicycle kinematics, x_dot = F(theta) @ u.

        theta: column of N headings (assumed a casadi SX/MX vector -- confirm
        with callers). Rows per robot are [cos, 0], [sin, 0], [0, 1].
        """
        F = ca.SX.zeros(3*N, 2*N)
        F[0::3, 0::2] = ca.diag(ca.cos(theta))
        F[1::3, 0::2] = ca.diag(ca.sin(theta))
        F[2::3, 1::2] = ca.SX.eye(N)
        return F

    def fk_exact_opt(self, N, dt):
        """
        Exact one-step integration of unicycle kinematics for N robots.

        State per robot: [x, y, theta]; control per robot: [v, w], held
        constant over dt. Uses the closed form
            dx = v*dt*sinc(h)*cos(theta + h),  dy = v*dt*sinc(h)*sin(theta + h),
            dtheta = w*dt,  with h = w*dt/2 and sinc(h) = sin(h)/h,
        which equals v/w*(sin(theta+w*dt) - sin(theta)) (and the cos analogue)
        for w != 0 and reduces to straight-line motion as w -> 0.
        Returns ca.Function (x[3N], u[2N]) -> next state.
        """
        x_sym = ca.SX.sym('x', 3*N)
        u_sym = ca.SX.sym('u', 2*N)
        theta = x_sym[2::3]
        v = u_sym[0::2]
        w = u_sym[1::2]

        half = 0.5 * w * dt
        # sinc(h) with a Taylor fallback near h = 0; the guarded denominator
        # keeps the unused branch finite so neither values nor gradients NaN.
        small = ca.fabs(half) < 1e-6
        safe_half = ca.if_else(small, ca.SX.ones(N), half)
        sinc = ca.if_else(small, 1 - half**2 / 6, ca.sin(safe_half) / safe_half)
        arc = v * dt * sinc
        mid = theta + half

        dx = ca.SX.zeros(3*N)
        dx[0::3] = arc * ca.cos(mid)
        dx[1::3] = arc * ca.sin(mid)
        dx[2::3] = w * dt
        return ca.Function("fk_exact_opt", [x_sym, u_sym], [x_sym + dx])

class Controller:
    def __init__(self, N, dt, control_param, logger=None):
        """
        Set up the MPC controller.

        N: number of robots; dt: time step; control_param: a ControlParam
        (limits, horizons, cost weights); logger: stored unchanged.
        The optimization problem itself is created later by init_opt().
        """
        self.N, self.dt = N, dt
        self.param = control_param
        # each robot is modelled by the 3-state pose [x, y, theta]
        self.state_len = 3
        self.ca_int = CasadiInterface(N, dt, self.state_len, M=0.1)
        # build the symbolic helpers once per controller instance
        self.fk_opt = self.ca_int.fk_opt(N, dt)
        self.get_local_pt = self.ca_int.get_local_pt()
        self.get_relative_pt = self.ca_int.get_relative_pt()
        # IPOPT/casadi options: no solver printout, tight constraint tolerance
        self.ipopt_param = {
            "verbose": False,
            "ipopt.print_level": 0,
            "print_time": 0,
            'ipopt.sb': 'yes',
            "ipopt.constr_viol_tol": 1e-6,
        }
        # populated by init_opt()
        self.opt = None
        self.x = None  # 3 states [x, y, theta] per robot
        self.u = None  # 2 controls [v, w] per robot
        self.logger = logger

        # last prev_x handed to init_opt (read by add_body_line_constr)
        self.prev_x = None
    
    def fit_prev_x2opt(self, prev_x):
        """
        Lay a 3*N pose vector out in the optimizer's per-robot state format.

        Each robot's [x, y, theta] fills the first three of its state_len
        slots and any remaining slots are zero. Returns a flat float array
        of length N*state_len. prev_x must be a numpy array of 3*N entries.
        """
        padded = np.zeros((self.N, self.state_len))
        padded[:, :3] = prev_x.reshape(self.N, 3)
        return padded.ravel()

    def init_opt(self, prev_x, prev_u, prev_cp=[]):
        """
        Create a fresh casadi Opti problem with state and control variables.

        prev_x: current poses, 3*N vector with [x, y, theta] per robot; it is
            pinned as column 0 of the state trajectory and used as warm start.
        prev_u: 2*N control vector used as the initial guess of every control
            column.
        prev_cp: unused here.
        Stores the problem in self.opt, the (state_len*N) x (hmpc+1) state
        matrix in self.x and the (2*N) x hmpc control matrix in self.u.
        """
        param = self.param
        n_bots = self.N
        sl = self.state_len
        horizon = param.hmpc
        problem = ca.Opti()
        states = problem.variable(sl*n_bots, horizon + 1)
        controls = problem.variable(2*n_bots, horizon)

        # kept for the feasibility checks in add_body_line_constr
        self.prev_x = prev_x

        # the first state column equals the measured x, y and theta
        for k in range(3):
            problem.subject_to(states[k::sl, 0] == prev_x[k::3])

        # box bounds on linear (v) and angular (w) velocity
        problem.subject_to(problem.bounded(-param.vmax,
                            ca.vec(controls[0::2, :]), param.vmax))
        problem.subject_to(problem.bounded(-param.wmax,
                            ca.vec(controls[1::2, :]), param.wmax))

        # warm start: hold the current pose and the previous control
        x_guess = self.fit_prev_x2opt(prev_x)
        for col in range(horizon + 1):
            problem.set_initial(states[:, col], x_guess)
        for col in range(horizon):
            problem.set_initial(controls[:, col], prev_u)

        self.opt = problem
        self.x = states
        self.u = controls

    def add_dynamics_constr(self):
        """
        Chain the state columns through the discrete dynamics.

        For every step k in [0, hmpc) adds x[:, k+1] == fk_opt(x[:, k], u[:, k]).
        """
        opt, states, controls = self.opt, self.x, self.u
        step = self.fk_opt
        for k in range(self.param.hmpc):
            nxt = states[:, k + 1]
            opt.subject_to(nxt == step(states[:, k], controls[:, k]))

    def add_vwlim_constraint(self):
        """
        Restrict each (v, w) control pair to the diamond |v|/vmax + |w|/wmax <= 1.

        The absolute values are expanded into four linear inequalities per
        horizon step, one for each sign combination.
        """
        opt, param, controls = self.opt, self.param, self.u
        sign_pairs = ((1, 1), (-1, 1), (1, -1), (-1, -1))
        for ti in range(param.hmpc):
            v_col = controls[0::2, ti]
            w_col = controls[1::2, ti]
            for sv, sw in sign_pairs:
                opt.subject_to(sv/param.vmax * v_col +
                               sw/param.wmax * w_col <= 1)

    def add_square_constr(self, x_pt, pts):
        """
        Constrain point x_pt to lie inside the quadrilateral given by pts.

        pts = (xl, xr, xls, xrs); each point (and x_pt) is indexable as
        [x, y]. The quadrilateral is walked along the edges xls->xl, xrs->xls,
        xr->xrs, xl->xr, and for each edge (a -> b) the cross-product condition
        (p.y - a.y)*(b.x - a.x) >= (b.y - a.y)*(p.x - a.x) is added.
        The points must be supplied so that this walk is counter-clockwise.
        """
        opt = self.opt
        xl, xr, xls, xrs = pts
        px, py = x_pt[0], x_pt[1]
        for start, end in ((xls, xl), (xrs, xls), (xr, xrs), (xl, xr)):
            opt.subject_to((py - start[1])*(end[0] - start[0])
                           >= (end[1] - start[1])*(px - start[0]))

    def add_align_poly_constr(self, prev_cp, L, ex_type=None):
        """
        Keep the connection points of connected robots inside small boxes.

        For each pair in prev_cp, get_anchor_body_index(prev_cp[pair], L)
        decides which robot is the "body" and which is the "anchor", and
        returns the connection type. Types listed in ex_type are skipped.
        For steps 1..hcst-1, the anchor's connection point (mapped to world
        coordinates) is constrained via add_square_constr to a quadrilateral
        fixed in the body robot's frame:
          * 'anchor': x in [x_lim, L/2 + eth], y in [-y_lim, y_lim], with the
            experimental half-extents x_lim = y_lim = 0.01;
          * 'knob': around the body's own connection point (cx, cy):
            x in [cx - 2*eth, cx + 3*eth], y in [yl, yr], where the y margin
            is 5*eth on the side facing the body's x-axis and 2*eth on the
            outer side (5*eth on both sides when cy == 0).
        Here eth = 2 * param.eth.

        prev_cp: dict (id0, id1) -> array whose first two rows hold the
            body-frame connection points, one column per robot.
        L: robot length parameter passed to get_anchor_body_index.
        ex_type: connection types to skip (default: none).
        Raises ValueError for an unrecognized connection type.
        """
        if len(prev_cp) == 0: return
        if ex_type is None:
            ex_type = []
        param = self.param
        eth = param.eth*2
        sl = self.state_len
        get_local_pt = self.get_local_pt
        x = self.x
        # experimental half-extents of the box used for 'anchor' connections
        x_lim, y_lim = 0.01, 0.01

        for cp_ids in prev_cp:
            [body_idx, anchor_idx], conn_type = get_anchor_body_index(prev_cp[cp_ids], L)
            body_id = cp_ids[body_idx]
            anchor_id = cp_ids[anchor_idx]

            cp_d = prev_cp[cp_ids][0:2, :]
            if conn_type in ex_type: continue
            # corners (xl, xr, xls, xrs) in the body robot's frame, ordered so
            # add_square_constr walks them counter-clockwise
            if conn_type == 'anchor':
                corners = [[L/2+eth, -y_lim], [L/2+eth, y_lim],
                           [x_lim, -y_lim], [x_lim, y_lim]]
            elif conn_type == 'knob':
                cx, cy = cp_d[:, body_idx]
                yl = cy - (2*eth if cy < 0 else 5*eth)
                yr = cy + (2*eth if cy > 0 else 5*eth)
                corners = [[cx+3*eth, yl], [cx+3*eth, yr],
                           [cx-2*eth, yl], [cx-2*eth, yr]]
            else:
                raise ValueError("conn_type not recognized")
            anchor_pt = cp_d[:, anchor_idx]

            for ti in range(1, param.hcst):
                body_pose = x[sl*body_id:(sl*body_id+3), ti]
                x_pt = get_local_pt(x[sl*anchor_id:(sl*anchor_id+3), ti],
                                        ca.MX(anchor_pt))
                quad = [get_local_pt(body_pose, ca.MX(c)) for c in corners]
                self.add_square_constr(x_pt, quad)

    
    def add_body_line_constr(self, prev_cp, L):
        """
        Keep already-connected robot pairs from overlapping along their joint.

        For each pair in prev_cp the "body" robot is the first column whose
        connection point has local x exactly equal to L/2; the other robot is
        the "anchor". The current poses in self.prev_x (3 entries per robot)
        decide which separation condition already holds, with margin
        eth = 2*param.eth. For steps 1..hcst-1 the method then constrains
        either
          * the body robot's corners (L/2, +-L/2) to have x <= -L/2 + eth in the
            anchor robot's frame (preferred when it currently holds), or
          * the anchor robot's corners (-L/2, +-L/2) to have x >= L/2 - eth in
            the body robot's frame.
        If neither holds for the current poses, nothing is added for the pair.
        (Assumes get_relative_pt_num is the numeric twin of get_relative_pt.)

        prev_cp: dict (id0, id1) -> array whose first two rows hold the
            body-frame connection points, one column per robot.
        L: robot size; corners are taken at (+-L/2, +-L/2).
        """
        opt = self.opt
        states = self.x
        sl = self.state_len
        eth = self.param.eth*2
        hcst = self.param.hcst
        rel_x = self.get_relative_pt
        half = L/2

        for cp_ids in prev_cp:
            cp_d = prev_cp[cp_ids][0:2, :]
            front_cols = np.where(cp_d[0, :] == half)[0]
            assert(len(front_cols) > 0)
            body_idx = front_cols[0]
            anchor_idx = 1 - body_idx
            body_id = cp_ids[body_idx]
            anchor_id = cp_ids[anchor_idx]

            # measured poses of the two robots
            body_now = self.prev_x[3*body_id:(3*body_id+3)]
            anchor_now = self.prev_x[3*anchor_id:(3*anchor_id+3)]
            constr_mask = np.array([
                get_relative_pt_num(body_now, anchor_now, np.array([-half, half])) >= half-eth,
                get_relative_pt_num(body_now, anchor_now, np.array([-half, -half])) >= half-eth,
                get_relative_pt_num(anchor_now, body_now, np.array([half, half])) <= -half+eth,
                get_relative_pt_num(anchor_now, body_now, np.array([half, -half])) <= -half+eth,
            ])

            for ti in range(1, hcst):
                anchor_pose = states[sl*anchor_id:(sl*anchor_id+3), ti]
                body_pose = states[sl*body_id:(sl*body_id+3), ti]
                if constr_mask[2] and constr_mask[3]:
                    # body robot's front corners stay behind the anchor's back line
                    for corner_y in (half, -half):
                        front = rel_x(anchor_pose, body_pose, ca.MX([half, corner_y]))
                        opt.subject_to(front <= -half+eth)
                elif constr_mask[0] and constr_mask[1]:
                    # anchor robot's back corners stay ahead of the body's front line
                    for corner_y in (half, -half):
                        back = rel_x(body_pose, anchor_pose, ca.MX([-half, corner_y]))
                        opt.subject_to(back >= half-eth)

    def add_cp_cost(self, cp, ti, xy_param, t_param):
        """
        Squared, weighted alignment error of one connection pair at step ti.

        cp: dict; only its first key (i0, i1) is used. Column 0 of the value is
            the connection point in robot i0's body frame and column 1 the
            point in robot i1's frame, as [dx, dy] or [dx, dy, dtheta].
        xy_param: weight on the world-frame position mismatch of the points.
        t_param: weight on tan(0.5 * heading mismatch); used only when the
            offsets carry a third (angle) row.
        Returns a 1x1 casadi expression (err.T @ err).
        """
        sl = self.state_len
        states = self.x
        pair = next(iter(cp))
        i0, i1 = pair
        offsets = cp[pair]
        d0, d1 = offsets[:, 0], offsets[:, 1]
        n = d0.shape[0]
        s0, s1 = i0*sl, i1*sl
        t0 = states[s0+2, ti]
        t1 = states[s1+2, ti]

        err = states[s0:(s0+n), ti] - states[s1:(s1+n), ti]
        # add each robot's connection-point offset rotated into the world frame
        err[0] += (ca.cos(t0)*d0[0] - ca.sin(t0)*d0[1]
                   - (ca.cos(t1)*d1[0] - ca.sin(t1)*d1[1]))
        err[1] += (ca.sin(t0)*d0[0] + ca.cos(t0)*d0[1]
                   - (ca.sin(t1)*d1[0] + ca.cos(t1)*d1[1]))
        err *= xy_param
        if n > 2:
            # tan(0.5*a) is 2*pi-periodic, so no explicit angle wrapping is needed
            err[2] = ca.tan(0.5*((t0 - t1) - (d0[2] - d1[2]))) * t_param
        return ca.mtimes(err.T, err)

    def align_cp_cost(self, cp, prev_cp, boost_t=[]):
        '''
        Sum add_cp_cost over all horizon steps 0..hmpc for every pair.

        cp key: (i0, i1)
        cp item: 2-by-2: [[dx0, dy0], [dx1, dy1]].T
        cp: pairs being connected; use the "s_cp_*" weights on stage steps and
            "cp_*" on the terminal step. Pairs listed in boost_t get their
            angle weight multiplied by 10.
        prev_cp: already-connected pairs; "s_prev_*" stage / "prev_*" terminal.
        '''
        param = self.param
        weights = param.cost_Q
        terminal = param.hmpc
        total = 0
        for ti in range(terminal + 1):
            if ti < terminal:
                names = ("s_cp_xy", "s_cp_t", "s_prev_xy", "s_prev_t")
            else:
                names = ("cp_xy", "cp_t", "prev_xy", "prev_t")
            for key in cp:
                t_mult = 1e1 if key in boost_t else 1
                total += self.add_cp_cost({key: cp[key]}, ti,
                                          weights[names[0]],
                                          weights[names[1]]*t_mult)
            for key in prev_cp:
                total += self.add_cp_cost({key: prev_cp[key]}, ti,
                                          weights[names[2]],
                                          weights[names[3]])
        return total
    
    def init_cost(self, prev_x, zero_list=None):
        """
        Placeholder initial cost term; currently always returns 0.

        prev_x: 3*N pose vector. It is still passed through fit_prev_x2opt so
            that a malformed vector raises (ValueError from reshape).
        zero_list: unused.
        """
        self.fit_prev_x2opt(prev_x)
        return 0
    
    def segment_goal_cost(self, x, cp, segment_dict):
        '''
        Make the non-busy robots of a segment follow that segment's leader.

        x: ignored; the method reads self.x (kept for interface compatibility).
        cp: iterable of (i0, i1) robot pairs being connected. Every robot in a
            pair leads itself.
        segment_dict: maps robot index i -> collection of robot indices in i's
            segment (used for numpy fancy indexing).
        A robot whose segment contains exactly one leader becomes a follower
        of it. The follower is charged, for steps 1..hmpc-1, the squared
        position difference (weight cost_Q["s_prev_xy"]) and tan^2 of half the
        heading difference (weight cost_Q["s_prev_t"]) to its leader. It is
        also charged the squared difference of its v / w control sequences
        from the leader's, with the same two weights.
        Returns 0 when segment_dict is empty.
        '''
        if not segment_dict: return 0
        param = self.param
        N = self.N
        states, u = self.x, self.u
        sl = self.state_len

        # -1: no leader; robots in a connection pair lead themselves
        seg_leader_list = np.zeros(N, dtype=int)-1
        for i0, i1 in cp:
            seg_leader_list[i0] = i0
            seg_leader_list[i1] = i1
        for i in segment_dict:
            this_seg = segment_dict[i]
            if seg_leader_list[i] == i: continue
            # A leader is a member that leads itself. Members already assigned
            # a leader by an earlier iteration must not count, otherwise
            # segments with more than two robots would fail the assertion.
            leaders = [m for m in this_seg if seg_leader_list[m] == m]
            if len(leaders) == 0: continue
            # there should only be one leader in this segment
            assert(len(leaders) == 1)
            seg_leader_list[this_seg] = leaders[0]

        cost = 0
        pxy = param.cost_Q["s_prev_xy"]
        pt = param.cost_Q["s_prev_t"]
        # follower pose tracks the leader pose
        for i in range(N):
            li = seg_leader_list[i]
            if li == i or li == -1: continue
            for ti in range(1, param.hmpc):
                xc = states[sl*i:(sl*i+sl), ti]
                xg = states[sl*li:(sl*li+sl), ti]
                diff_xy = xg[0:2] - xc[0:2]
                diff_t = ca.tan(0.5*(xg[2] - xc[2]))
                cost += pxy * ca.mtimes(diff_xy.T, diff_xy)
                cost += pt * ca.mtimes(diff_t.T, diff_t)

        # follower controls match the leader controls
        for i in range(N):
            li = seg_leader_list[i]
            if li == i or li == -1: continue
            diff_u = u[2*i:2*(i+1), :] - u[2*li:2*(li+1), :]
            cost += ca.mtimes(diff_u[0, :], diff_u[0, :].T)*pxy
            cost += ca.mtimes(diff_u[1, :], diff_u[1, :].T)*pt
        return cost

    def goal_cost(self, goal):
        """
        Terminal cost pulling every robot toward the same goal pose.

        goal: list of len 3 [x, y, theta]
        The goal is repeated for all N robots (padded with [0, 0] when
        state_len == 5) and compared with the last state column; the squared
        error is weighted by cost_Q["cp_xy"].
        """
        sl = self.state_len
        assert(sl == 5 or sl == 3)
        if sl == 5:
            per_robot = goal + [0, 0]
        elif sl == 3:
            per_robot = goal
        else:
            # reachable only when asserts are disabled
            return 0
        goal_vec = np.hstack([per_robot] * self.N)
        x_diff = self.x[:, -1] - goal_vec
        return ca.mtimes(x_diff.T, x_diff) * self.param.cost_Q["cp_xy"]

    def traj_cost(self, traj_func, x0, v_sc=0.5):
        """
        Stage cost making all robots track a reference trajectory.

        traj_func: lambda function with input x (an array of N x-positions);
            its output is indexed per robot as [x, y, theta] and compared with
            the 3*N state column.
        x0: 3*N vector, initial state pose-x
        v_sc: fraction of param.vmax used to advance the reference along x.

        At step ti each robot's reference is traj_func evaluated at
            x0_x + max(cos(theta_ref), 0.5) * v_sc * vmax * ti * dt,
        where theta_ref is that robot's reference heading at its initial x.
        """
        N = self.N
        x = self.x
        param = self.param
        assert(len(x0) == 3*self.N)
        x_start = x0[0::3]
        # reference headings at the starting x positions do not depend on ti
        traj_x = traj_func(x_start)
        # per-robot advance rate along x (elementwise max with 0.5)
        x_rate = np.maximum(np.cos(traj_x[2::3]), 0.5)*v_sc*param.vmax
        cost = 0
        for ti in range(1, self.param.hmpc):
            curr_x = x_start + x_rate * ti * self.dt
            diff = traj_func(curr_x) - x[:, ti]
            # tan(0.5*a) penalises heading error without explicit wrapping
            diff[2::3] = np.tan(0.5 * diff[2::3])
            cost += ca.mtimes(diff[0::3].T, diff[0::3]) * param.cost_Q["cp_xy"]*1e0
            # tuning note: 1e1 for line, 1e0 for wave
            cost += ca.mtimes(diff[1::3].T, diff[1::3]) * param.cost_Q["cp_xy"]*1e0
            cost += ca.mtimes(diff[2::3].T, diff[2::3]) * param.cost_Q["cp_t"]*1e1
        return cost

    def traj_cost_first(self, traj_func, x0, v_sc=0.5):
        """
        Trajectory-tracking cost evaluated robot by robot, favouring early robots.

        traj_func: lambda function with input x (a length-1 array holding one
            robot's x-position); its output is indexed as [x, y, theta].
        x0: 3*N vector, initial state pose-x
        v_sc: fraction of param.vmax used to advance the reference along x.

        Robot i's reference at step ti is traj_func at
            x_i + max(cos(theta_ref_i), 0.5) * v_sc * vmax * ti * dt.
        Its x / y / heading errors are weighted by cp_xy*1 / cp_xy*10 /
        cp_t*10 and divided by (i + 1), so lower-index robots track tighter.
        """
        N = self.N
        x = self.x
        param = self.param
        # per-robot start position and advance rate are independent of ti
        per_robot = []
        for i in range(N):
            x_start = x0[(3*i):3*(i+1)][0::3]
            traj_x = traj_func(x_start)
            rate = np.maximum(np.cos(traj_x[2::3]), 0.5)*v_sc*param.vmax
            per_robot.append((x_start, rate))

        cost = 0
        for ti in range(1, self.param.hmpc):
            for i, (x_start, rate) in enumerate(per_robot):
                curr_x = x_start + rate * ti * self.dt
                diff = traj_func(curr_x) - x[3*i:3*(i+1), ti]
                diff[2::3] = np.tan(0.5 * diff[2::3])
                cost += ca.mtimes(diff[0::3].T, diff[0::3]) * param.cost_Q["cp_xy"]*1e0/(i+1)
                # tuning note: 1e1 for line, 1e0 for wave (was the complex typo 1e1j)
                cost += ca.mtimes(diff[1::3].T, diff[1::3]) * param.cost_Q["cp_xy"]*1e1/(i+1)
                cost += ca.mtimes(diff[2::3].T, diff[2::3]) * param.cost_Q["cp_t"]*1e1/(i+1)
        return cost

    def stage_cost(self):
        """
        Control-effort cost: Q_u * ||u[:, k]||^2 summed for k = 1..hmpc-1.

        The first control column (k = 0) is not included.
        """
        controls = self.u
        q_u = self.param.cost_Q["Q_u"]
        total = 0
        for k in range(1, self.param.hmpc):
            col = controls[:, k]
            total += ca.mtimes(col.T, col) * q_u
        return total

    def smooth_cost(self, prev_u):
        """
        Penalise changes between consecutive control columns.

        The first column is compared with prev_u, every later column with its
        predecessor. Before squaring, the v entries of each difference are
        scaled by cost_Q["smooth_v"] and the w entries by cost_Q["smooth_w"].
        """
        controls = self.u
        w_v = self.param.cost_Q["smooth_v"]
        w_w = self.param.cost_Q["smooth_w"]

        def weighted_sq(delta):
            # scale v and w components, then return the squared norm
            delta[0::2] *= w_v
            delta[1::2] *= w_w
            return ca.mtimes(delta.T, delta)

        total = weighted_sq(controls[:, 0] - prev_u)
        for k in range(1, self.param.hmpc):
            total += weighted_sq(controls[:, k] - controls[:, k - 1])
        return total

    def rendz_cost(self):
        """
        Rendezvous cost pulling every pair of robots toward the same pose.

        For steps 0..hmpc-1 and every unordered robot pair (j < i) adds
        rend_xy * ||p_j - p_i||^2 + rend_t * tan^2((theta_j - theta_i) / 2).
        """
        states = self.x
        sl = self.state_len
        param = self.param
        w_xy = param.cost_Q["rend_xy"]
        w_t = param.cost_Q["rend_t"]

        total = 0
        for ti in range(param.hmpc):
            poses = [states[sl*k:(sl*k+sl), ti] for k in range(self.N)]
            for i, pose_i in enumerate(poses):
                for pose_j in poses[:i]:
                    d_xy = pose_j[0:2] - pose_i[0:2]
                    d_t = ca.tan(0.5*(pose_j[2] - pose_i[2]))
                    total += w_xy * ca.mtimes(d_xy.T, d_xy)
                    total += w_t * ca.mtimes(d_t.T, d_t)
        return total

    def gdu_cost(self, gdu):
        """
        Pull every control column toward a fixed control vector gdu.

        gdu: numpy array (converted via .tolist()), presumably of length 2*N.
        Each column's squared deviation is weighted by the constant 1e5.
        """
        controls = self.u
        target = ca.MX(gdu.tolist())
        total = 0
        for col in range(self.param.hmpc):
            err = controls[:, col] - target
            total += ca.mtimes(err.T, err) * 1e5
        return total

    def optimize_cp(self, cost):
        """
        Minimise ``cost`` over the current Opti problem with IPOPT.

        Returns (first control column u[:, 0], optimal cost value). If solving
        or reading the solution raises, the error is printed and
        (zeros(2*N), None) is returned instead (best-effort fallback).
        """
        opt = self.opt
        opt.minimize(cost)
        opt.solver("ipopt", self.ipopt_param)
        try:
            sol = opt.solve()
            first_u = sol.value(self.u[:, 0])
            final_cost = sol.value(cost)
        except Exception as err:
            print(err)
        else:
            return first_u, final_cost
        return np.zeros(2*self.N), None

    def diff_drive_goal(self, xc, xg, use_max=False):
        '''
        One-step differential-drive command moving pose xc toward xg.

        xc: 3-by-1 ego robot pose, xg: 3-by-1 goal pose ([x, y, theta]).
        The pose error (heading part passed through wrap_pi) is projected on
        the robot heading for v and used directly for w, divided by dt and
        clipped to [-vmax, vmax] / [-wmax, wmax]. With use_max the linear
        speed becomes sign(v) * vmax (0 stays 0).
        return: 2-by-1 control
        '''
        err = xg - xc
        err[2] = wrap_pi(err[2])
        heading = xc[2]
        projection = np.array(
            [[np.cos(heading), np.sin(heading), 0],
             [0, 0, 1]])
        du = projection.dot(err) / self.dt
        v_lim = self.param.vmax
        w_lim = self.param.wmax
        du[0] = np.clip(du[0], -v_lim, v_lim)
        du[1] = np.clip(du[1], -w_lim, w_lim)
        if use_max:
            du[0] = np.sign(du[0]) * v_lim
        return du
