def forward(self, Ks, ks):
    """
    LQR forward recursion.

    Rolls the dynamics forward from ``self.x_init``, applying the
    time-varying affine policy u_t = K_t x_t + k_t at every step.

    :param Ks: feedback gains K_t solved in the backward recursion
    :param ks: feedforward terms k_t solved in the backward recursion
    :return: (x, u) stacked over time, shapes (T, n_batch, n_state)
             and (T, n_batch, n_ctrl)
    """
    assert len(Ks) == self.T, "Ks length error"
    states = [self.x_init]
    controls = []
    for step in range(self.T):
        gain = Ks[step]
        offset = ks[step]
        state = states[step]
        assert list(state.shape) == [self.n_batch, self.n_state], str(state.shape) + \
            " xt dim mismatch: expected" + str([self.n_batch, self.n_state])
        # Affine policy: u_t = K_t x_t + k_t.
        control = xpbmv(gain, state) + offset
        assert list(control.shape) == [self.n_batch, self.n_ctrl
                                       ], "actual:" + str(control.shape)
        if self.u_zero_Index is not None:
            # Force the masked control entries to exactly zero.
            assert self.u_zero_Index[step].shape == control.shape, str(
                self.u_zero_Index[step].shape) + " : " + str(control.shape)
            control = self.xp.where(self.u_zero_Index[step],
                                    self.xp.zeros_like(control), control)
            assert list(control.shape) == [self.n_batch, self.n_ctrl
                                           ], "actual:" + str(control.shape)
        controls.append(control)
        if step < self.T - 1:
            # Propagate the linear dynamics: x_{t+1} = F_t [x_t; u_t] (+ f_t).
            assert list(state.shape) == [self.n_batch, self.n_state
                                         ], "actual:" + str(state.shape)
            assert list(controls[step].shape) == [
                self.n_batch, self.n_ctrl
            ], "actual:" + str(controls[step].shape)
        tau = self.xp.concatenate((state, controls[step]), axis=1) \
            if step < self.T - 1 else None
        if step < self.T - 1:
            next_state = xpbmv(self.F[step], tau)
            if self.f is not None:
                next_state += self.f[step]
            assert list(next_state.shape) == [self.n_batch, self.n_state], str(next_state.shape) + \
                " x dim mismatch: expected" + str([self.n_batch, self.n_state])
            states.append(next_state)
    stacked_x = self.xp.stack(states, axis=0)
    stacked_u = self.xp.stack(controls, axis=0)
    assert list(stacked_x.shape) == [
        self.T, self.n_batch, self.n_state
    ], str(stacked_x.shape) + " new x dim mismatch"
    assert list(stacked_u.shape) == [self.T, self.n_batch,
                                     self.n_ctrl], "new u dim mismatch"
    return stacked_x, stacked_u
def forward(self, inputs):
    """
    Link forward: one iLQR iteration (backward recursion + line-searched
    forward recursion) over the current trajectory iterate.

    :param inputs: tuple (x_init, C_hat, c_hat, F_hat, f_hat) of the initial
                   state and the quadratic-cost / linearized-dynamics tensors
    :return: (x, u) updated state and control trajectories stacked over time
    """
    # No gradients are tracked here; the analytic fixed-point derivative is
    # implemented separately in backward().
    with chainer.no_backprop_mode():
        x_init, C_hat, c_hat, F_hat, f_hat = inputs
        self.retain_inputs((0, 1, 2, 3, 4))
        if self.no_op_forward:
            # Pass-through mode: reuse the stored trajectory unchanged.
            self.retain_outputs((0, 1))
            return self.current_states, self.controls
        x_init = to_xp(x_init)
        C_hat = to_xp(C_hat)
        c_hat = to_xp(c_hat)
        F_hat = to_xp(F_hat)
        f_hat = to_xp(f_hat)
        if self.need_expand is True:
            # Taylor expansion
            # grad(delta_x,delta_u) = hessian(0,0) @ (delta_x, delta_u) + grad(0,0)
            c_back = [
            ]  # eq(12) in [1], constant term in eq(5.12) can be removed
            for t in range(self.T):
                xt = self.current_states[t]
                ut = self.controls[t]
                xut = self.xp.concatenate((xt, ut), axis=1)
                assert xut.shape == (self.n_batch, self.n_sc), "expected " + str([self.n_batch, self.n_sc]) + \
                    "acutal" + str(xut.shape)
                c_back.append(xpbmv(C_hat[t], xut) + c_hat[t])
            c_hat = self.xp.stack(c_back)
            f_hat = None  # eq(13) in [1]
        Ks, ks, _backward = self.backward_rec(C_hat, c_hat, F_hat, f_hat)
        x, u, _forward = self.forward_rec(Ks, ks, self.true_cost,
                                          self.true_dynamics, self.ls_decay,
                                          self.max_ls_iter)
        assert list(x.shape) == [self.T, self.n_batch,
                                 self.n_state], "x dim mismatch"
        # Stash solver statistics for the caller / debugging.
        self.back_out = _backward
        self.for_out = _forward
        self.retain_outputs((0, 1))
        assert list(x.shape) == [self.T, self.n_batch, self.n_state]
        assert list(u.shape) == [self.T, self.n_batch, self.n_ctrl]
        assert not self.xp.isnan(u).any()
        return x, u
def backward(self):
    """
    LQR backward recursion.
    Note: Ks ks is reversed version of original
    :return: Ks, ks gain
    """
    Ks = []
    ks = []
    Vt = None  # value-function Hessian V_{t+1}, carried between steps
    vt = None  # value-function gradient v_{t+1}
    # self.T-1 to 0 loop
    for t in range(self.T - 1, -1, -1):
        # initial case
        if t == self.T - 1:
            Qt = self.C[t]
            qt = self.c[t]
        else:
            # Q_t = C_t + F_t^T V_{t+1} F_t (standard Riccati recursion).
            Ft = self.F[t]
            Ft_T = self.xp.transpose(Ft, axes=(0, 2, 1))
            assert Ft.dtype.kind == 'f', "Ft dtype"
            assert Vt.dtype.kind == 'f', "Vt dtype"
            Qt = self.C[t] + Ft_T @ Vt @ Ft
            if self.f is None:
                # NOTE f.nelement() == 0 condition ?
                qt = self.c[t] + xpbmv(Ft_T, vt)
            else:
                # f is not none: the dynamics offset contributes F^T V f.
                ft = self.f[t]
                qt = self.c[t] + xpbmv(Ft_T @ Vt, ft) + xpbmv(Ft_T, vt)
        assert list(qt.shape) == [self.n_batch, self.n_sc], "qt dim mismatch"
        assert list(Qt.shape) == [self.n_batch, self.n_sc, self.n_sc
                                  ], str(Qt.shape) + " Qt dim mismatch"
        # Partition Q_t / q_t into state and control sub-blocks.
        Qt_xx = Qt[:, :self.n_state, :self.n_state]
        Qt_xu = Qt[:, :self.n_state, self.n_state:]
        Qt_ux = Qt[:, self.n_state:, :self.n_state]
        Qt_uu = Qt[:, self.n_state:, self.n_state:]
        qt_x = qt[:, :self.n_state]
        qt_u = qt[:, self.n_state:]
        assert list(
            Qt_uu.shape) == [self.n_batch, self.n_ctrl,
                             self.n_ctrl], "Qt_uu dim mismatch"
        assert list(
            Qt_ux.shape) == [self.n_batch, self.n_ctrl,
                             self.n_state], "Qt_ux dim mismatch"
        assert list(
            Qt_xu.shape) == [self.n_batch, self.n_state,
                             self.n_ctrl], "Qt_xu dim mismatch"
        assert list(qt_x.shape) == [self.n_batch,
                                    self.n_state], "qt_x dim mismatch"
        assert list(qt_u.shape) == [self.n_batch,
                                    self.n_ctrl], "qt_u dim mismatch"
        # Next calculate Kt and kt
        # TODO LU decomposition
        if True:
            # u_zero_index is not none
            # NOTE(review): this branch assumes self.u_zero_Index is always
            # set for this class — confirm against the constructor.
            index = self.u_zero_Index[t]
            # Zero out the gradient/Hessian rows & columns of the masked
            # controls so the solve leaves them at exactly zero.
            qt_u_ = copy.deepcopy(qt_u)
            qt_u_[index] = 0.0
            Qt_uu_ = copy.deepcopy(Qt_uu)
            notI = 1.0 - index.astype('float')
            Qt_uu_I = 1 - xpbger(notI, notI)
            Qt_uu_I = Qt_uu_I.astype('bool')
            Qt_uu_[Qt_uu_I] = 0.0
            # Qt_uu_ = F.where(Qt_uu_I, self.xp.zeros_like(Qt_uu_.array), Qt_uu_)
            index_qt_uu = self.xp.array([
                self.xp.diagflat(index[i]) for i in range(index.shape[0])
            ])
            # Regularize the masked diagonal so Qt_uu_ stays invertible.
            Qt_uu_[index_qt_uu] += 1e-8
            # Qt_uu_ = F.where(F.cast(index_qt_uu, 'bool'), Qt_uu + 1e-8, Qt_uu)
            Qt_ux_ = copy.deepcopy(Qt_ux)
            index_qt_ux = self.xp.repeat(self.xp.expand_dims(index, axis=2),
                                         Qt_ux.shape[2], axis=2)
            Qt_ux_[index_qt_ux] = 0.0
            # Qt_ux_ = F.where(index_qt_ux, self.xp.zeros_like(Qt_ux_.array), Qt_ux)
            # print("qt_u_", qt_u_)
            # print("Qt_uu_", Qt_uu_)
            # print("Qt_ux_", Qt_ux_)
        # Solve Qt_uu K_t = -Qt_ux and Qt_uu k_t = -qt_u.
        if self.n_ctrl == 1:
            Kt = -(1. / Qt_uu_) * Qt_ux_  # NOTE different from original
            kt = -(1. / self.xp.squeeze(Qt_uu_, axis=2)) * qt_u_
        else:
            Qt_uu_LU_ = xpbatch_lu_factor(Qt_uu_)
            Kt = -xpbatch_lu_solve(Qt_uu_LU_, Qt_ux_)
            kt = -xpbatch_lu_solve(Qt_uu_LU_, qt_u_)
        assert list(Kt.shape) == [self.n_batch, self.n_ctrl,
                                  self.n_state], "Kt dim mismatch"
        assert list(kt.shape) == [self.n_batch,
                                  self.n_ctrl], "kt dim mismatch"
        Kt_T = self.xp.transpose(Kt, axes=(0, 2, 1))
        Ks.append(Kt)
        ks.append(kt)
        # Value-function update for the next (earlier) step.
        Vt = Qt_xx + self.xp.matmul(Qt_xu, Kt) + self.xp.matmul(
            Kt_T, Qt_ux) + self.xp.matmul(self.xp.matmul(Kt_T, Qt_uu), Kt)
        vt = qt_x + xpbmv(Qt_xu, kt) + xpbmv(Kt_T, qt_u) + xpbmv(
            self.xp.matmul(Kt_T, Qt_uu), kt)
    assert len(Ks) == self.T, "Ks length error"
    # Gains were appended from t = T-1 down to 0; flip to time order.
    Ks.reverse()
    ks.reverse()
    return Ks, ks
def backward_rec(self, C_hat, c_hat, F_hat, f_hat):
    """
    Back ward recursion over the linearized trajectory.

    Like the plain LQR backward pass, but each step's control sub-problem is
    solved as a box-constrained QP (via PNQP) so the gains respect the
    control limits u_lower/u_upper.

    :param C_hat: approximated C
    :param c_hat: approximated c
    :param F_hat: approximated F
    :param f_hat: approximated f
    :return: (Ks, ks, LqrBackOut) — gains in time order plus QP statistics
    """
    assert list(C_hat.shape) == [self.T, self.n_batch, self.n_sc, self.n_sc], \
        "C hat dim mismatch"
    assert list(c_hat.shape) == [self.T, self.n_batch, self.n_sc], \
        str(c_hat.shape) + " c hat dim mismatch: expected " + str([self.T, self.n_batch, self.n_sc])
    # The last time step has no outgoing dynamics; accept either length.
    if list(F_hat.shape)[0] == self.T:
        F_hat = F_hat[:self.T - 1]
    else:
        assert (F_hat.shape[0]) == self.T - 1, "F_hat dimension"
    assert list(F_hat.shape) == [self.T - 1, self.n_batch, self.n_state, self.n_sc], \
        str(F_hat.shape) + " predicted:" + str(self.T - 1) + " " + \
        str(self.n_batch) + " " + str(self.n_state) + " " + str(self.n_sc) + "F_hat dim mismatch"
    if f_hat is not None:
        assert list(f_hat.shape) == [self.T - 1, self.n_batch, self.n_state] or \
            list(f_hat.shape) == [self.T, self.n_batch, self.n_state], " f_hat dim mismatch"
    # Ks = []
    # ks = []
    Ks = self.xp.zeros((self.T, self.n_batch, self.n_ctrl, self.n_state))
    ks = self.xp.zeros((self.T, self.n_batch, self.n_ctrl))
    Vt = None  # value-function Hessian carried between steps
    vt = None  # value-function gradient carried between steps
    prev_kt = None  # used for warm start up in Projected newton quadratic programming
    n_total_qp_iter = 0
    # self.T-1 to 0 loop
    for t in range(self.T - 1, -1, -1):
        if t == self.T - 1:
            Qt = C_hat[t]
            qt = c_hat[t]
        else:
            # Q_t = C_t + F_t^T V_{t+1} F_t (Riccati recursion).
            Ft_hat = F_hat[t]
            Ft_hat_T = self.xp.transpose(Ft_hat, axes=(0, 2, 1))
            Qt = C_hat[t] + Ft_hat_T @ Vt @ Ft_hat
            if f_hat is None:
                qt = c_hat[t] + xpbmv(Ft_hat_T, vt)
            else:
                # f is not none
                ft = f_hat[t]
                qt = c_hat[t] + xpbmv(Ft_hat_T @ Vt, ft) + xpbmv(
                    Ft_hat_T, vt)
        assert list(qt.shape) == [self.n_batch, self.n_sc], "qt dim mismatch"
        assert list(Qt.shape) == [self.n_batch, self.n_sc, self.n_sc
                                  ], str(Qt.shape) + " Qt dim mismatch"
        # Partition Q_t / q_t into state and control sub-blocks.
        Qt_xx = Qt[:, :self.n_state, :self.n_state]
        Qt_xu = Qt[:, :self.n_state, self.n_state:]
        Qt_ux = Qt[:, self.n_state:, :self.n_state]
        Qt_uu = Qt[:, self.n_state:, self.n_state:]
        qt_x = qt[:, :self.n_state]
        qt_u = qt[:, self.n_state:]
        assert list(
            Qt_uu.shape) == [self.n_batch, self.n_ctrl,
                             self.n_ctrl], "Qt_uu dim mismatch"
        assert list(
            Qt_ux.shape) == [self.n_batch, self.n_ctrl,
                             self.n_state], "Qt_ux dim mismatch"
        assert list(
            Qt_xu.shape) == [self.n_batch, self.n_state,
                             self.n_ctrl], "Qt_xu dim mismatch"
        assert list(qt_x.shape) == [self.n_batch,
                                    self.n_state], "qt_x dim mismatch"
        assert list(qt_u.shape) == [self.n_batch,
                                    self.n_ctrl], "qt_u dim mismatch"
        # calculate K and k
        # different from LQR case starts from here
        # lower_bound of control - current control
        assert not self.xp.isnan(self.controls[t]).any(), str(
            self.controls[t])
        assert not self.xp.isnan(self.u_lower[t]).any()
        assert not self.xp.isnan(self.u_upper[t]).any()
        lower_bound = self.u_lower[t] - self.controls[t]
        # upper_bound of control - current control
        upper_bound = self.u_upper[t] - self.controls[t]
        assert (lower_bound <= upper_bound).all(), " lower is larger than upper" \
            + " lower: " + str(lower_bound) + "upper: " + str(upper_bound)
        # Box-constrained QP in the control increment; warm-started with the
        # previous step's feedforward term.
        kt, Qt_uu_free_LU, Index_free, n_qp_iter = PNQP(Qt_uu, qt_u,
                                                        lower_bound,
                                                        upper_bound,
                                                        x_init=prev_kt,
                                                        n_iter=20)
        if self.verbose is True:
            print(' + n_qp_iter in mpc step: ', n_qp_iter + 1)
        n_total_qp_iter += 1 + n_qp_iter
        prev_kt = kt
        # Feedback gain is solved only on the free (unclamped) control set.
        Qt_ux_copy = copy.deepcopy(Qt_ux)
        Index_Qt_ux_free = self.xp.repeat(self.xp.expand_dims(
            (1.0 - Index_free), axis=2), self.n_state, axis=2)
        Index_Qt_ux_free = Index_Qt_ux_free.astype('bool')
        Qt_ux_copy[Index_Qt_ux_free] = 0.0
        # Qt_ux_copy = F.where(Index_Qt_ux_free, self.xp.zeros_like(Qt_ux.data), Qt_ux_copy)
        if self.n_ctrl == 1:
            # Bad naming, Qt_uu_free_LU is scalar
            Kt = -((1. / Qt_uu_free_LU) * Qt_ux_copy)
        else:
            # Qt_uu K_{t,f} = - Qt_ux
            Kt = -xpbatch_lu_solve(Qt_uu_free_LU, Qt_ux_copy)
        assert list(Kt.shape) == [self.n_batch, self.n_ctrl,
                                  self.n_state], "Kt dim mismatch"
        assert list(kt.shape) == [self.n_batch,
                                  self.n_ctrl], "kt dim mismatch"
        Kt_T = self.xp.transpose(Kt, axes=(0, 2, 1))
        assert not self.xp.isnan(kt).any()
        assert not self.xp.isnan(Kt).any()
        # Stored directly at index t, so no final reversal is needed.
        Ks[t] = Kt
        ks[t] = kt
        # Value-function update for the next (earlier) step.
        Vt = Qt_xx + Qt_xu @ Kt + Kt_T @ Qt_ux + Kt_T @ Qt_uu @ Kt
        vt = qt_x + xpbmv(Qt_xu, kt) + xpbmv(Kt_T, qt_u) + xpbmv(
            (Kt_T @ Qt_uu), kt)
    assert len(Ks) == self.T, "Ks length error"
    '''
    Ks.reverse()
    ks.reverse()
    '''
    return Ks, ks, LqrBackOut(n_total_qp_iter=n_total_qp_iter)
def backward(self, target_input_indexes, grad_outputs):
    """
    Analytic gradient of the MPC fixed point.

    Differentiates through the optimal trajectory by solving an auxiliary
    LQR problem with the incoming gradients as the linear cost, then
    assembling the derivatives w.r.t. each input via outer products with
    the Lagrange multipliers.

    :param target_input_indexes: indices of inputs needing gradients
                                 (unused here; all five are produced)
    :param grad_outputs: (dl_dx, dl_du) upstream gradients w.r.t. the
                         state and control trajectories
    :return: gradients (dx_init, dC, dc, dF, df) as chainer Variables
    """
    # print("backward is called")
    x_init, C_hat, c_hat, F_hat, f_hat = self.get_retained_inputs()
    dl_dx, dl_du = grad_outputs
    x_init = to_xp(x_init)
    C_hat = to_xp(C_hat)
    c_hat = to_xp(c_hat)
    F_hat = to_xp(F_hat)
    f_hat = to_xp(f_hat)
    dl_dx = to_xp(dl_dx)
    dl_du = to_xp(dl_du)
    # Missing upstream gradients are treated as zeros.
    if dl_dx is None:
        dl_dx = self.xp.zeros((self.T, self.n_batch, self.n_state))
    else:
        assert list(dl_dx.shape) == [self.T, self.n_batch, self.n_state]
    if dl_du is None:
        dl_du = self.xp.zeros((self.T, self.n_batch, self.n_ctrl))
    else:
        assert list(dl_du.shape) == [self.T, self.n_batch, self.n_ctrl]
    # just concatenating dl_dx, dl_du
    d_taus = self.xp.concatenate((dl_dx, dl_du), axis=2)
    # assert False, str(d_taus)
    # choose active control
    new_x, new_u = self.get_retained_outputs()
    new_x = to_xp(new_x)
    new_u = to_xp(new_u)
    # Controls within 1e-8 of a bound are treated as actively constrained
    # and frozen (zeroed) in the auxiliary LQR solve.
    active_index = (self.xp.absolute(new_u - self.u_lower) <= 1e-8) | \
        (self.xp.absolute(new_u - self.u_upper) <= 1e-8)
    dx_init_zero = self.xp.zeros_like(x_init)
    # backward pass LINE (1)
    '''
    print("I", active_index)
    print("r",d_taus)
    print("F",F_hat)
    print("C",C_hat)
    assert False
    '''
    lqr = LQR_active(dx_init_zero, C_hat, -d_taus, F_hat, None, self.T,
                     self.n_state, self.n_ctrl, u_zero_Index=active_index)
    dx, du = lqr.solve_recursion()
    # print("dx", dx)
    # print("du", du)
    # assert False
    dxu = self.xp.concatenate((dx, du), axis=2)
    xu = self.xp.concatenate((new_x, new_u), axis=2)
    # dC_t = -1/2 (dxu xu^T + xu dxu^T): symmetrized outer product.
    dC = self.xp.zeros_like(C_hat)
    for t in range(self.T):
        xut = self.xp.concatenate((new_x[t], new_u[t]), axis=1)
        dxut = dxu[t]
        dCt = -0.5 * (xpbger(dxut, xut) + xpbger(xut, dxut))
        assert dC[t].shape == dCt.shape
        dC[t] = dCt
    dc = -dxu
    # Compute Lambda (Forward Pass line(2)) in Module(1)
    # lams = []
    lams = self.xp.zeros((self.T, self.n_batch, self.n_state))
    prev_lam = None
    for t in range(self.T - 1, -1, -1):
        Ct_xx = C_hat[t, :, :self.n_state, :self.n_state]
        Ct_xu = C_hat[t, :, :self.n_state, self.n_state:]
        ct_x = c_hat[t, :, :self.n_state]
        xt = new_x[t]
        ut = new_u[t]
        lamt = xpbmv(Ct_xx, xt) + xpbmv(Ct_xu, ut) + ct_x
        if prev_lam is not None:
            # Accumulate the co-state from the later time step.
            Fxt = self.xp.transpose(F_hat[t, :, :, :self.n_state],
                                    axes=(0, 2, 1))
            lamt += xpbmv(Fxt, prev_lam)
        lams[t] = lamt
        prev_lam = lamt
    # lams = list(reversed(lams))
    # Backward Pass Line(3)
    # Compute the derivatives
    # d_Lambda
    # dlams = []
    dlams = self.xp.zeros_like(lams)
    prev_dlam = None
    for t in range(self.T - 1, -1, -1):
        dCt_xx = C_hat[t, :, :self.n_state, :self.n_state]
        dCt_xu = C_hat[t, :, :self.n_state, self.n_state:]
        drt_x = -d_taus[t, :, :self.n_state]
        dxt = dx[t]
        dut = du[t]
        dlamt = xpbmv(dCt_xx, dxt) + xpbmv(dCt_xu, dut) + drt_x
        if prev_dlam is not None:
            Fxt = self.xp.transpose(F_hat[t, :, :, :self.n_state],
                                    axes=(0, 2, 1))
            dlamt += xpbmv(Fxt, prev_dlam)
        dlams[t] = dlamt
        prev_dlam = dlamt
    # dlams = self.xp.stack(list(reversed(dlams)))
    # d_F
    dF = self.xp.zeros_like(F_hat)
    for t in range(self.T - 1):
        xut = xu[t]
        lamt = lams[t + 1]
        dxut = dxu[t]
        dlamt = dlams[t + 1]
        append_to_dF = -(xpbger(dlamt, xut) + xpbger(lamt, dxut))
        assert dF[t].shape == append_to_dF.shape, str(
            dF[t].shape) + " : " + str(append_to_dF.shape)
        dF[t] = append_to_dF
    if f_hat is not None:
        _dlams = dlams[1:]
        assert _dlams.shape == f_hat.shape
        df = -_dlams
        df = chainer.Variable(df)
    else:
        # CHECK THIS
        # NOTE(review): an empty chainer.Variable() is returned when f_hat
        # is absent — confirm the framework accepts this as "no gradient".
        df = chainer.Variable()
    dx_init = -dlams[0]
    # print(dx_init.shape)
    # print(dC.shape)
    # print(dc.shape)
    # print(dF.shape)
    # print(dF)
    # assert False
    dx_init = chainer.Variable(dx_init)
    dC = chainer.Variable(dC)
    dc = chainer.Variable(dc)
    dF = chainer.Variable(dF)
    # print("dC", dC)
    # print("dc", dc)
    return dx_init, dC, dc, dF, df
def forward_rec(self, Ks, ks, true_cost, true_dynamics, ls_decay,
                max_ls_iter):
    """
    Forward recursion and line search.

    Rolls the true dynamics forward with the affine policy
    u_t = alpha * k_t + K_t (x_t - x_t^prev) + u_t^prev, backtracking on the
    per-batch step size alpha until no batch's cost exceeds the previous
    iterate's cost or the iteration cap is reached.

    :param Ks: feedback gains from the backward recursion
    :param ks: feedforward terms from the backward recursion
    :param true_cost: true Cost function (QuadCost or callable)
    :param true_dynamics: true dynamics function (LinDx or callable)
    :param ls_decay: line search decay ratio (multiplies alpha each retry)
    :param max_ls_iter: max line search iteration
    :return: (new_x, new_u, LqrForOut statistics)
    """
    assert len(Ks) == self.T, "Ks length error"
    states = self.current_states
    alphas = self.xp.ones(self.n_batch, dtype=self.controls.dtype)
    OLD_COST = xpget_cost(self.T, self.controls, true_cost, true_dynamics,
                          x=states)
    current_cost = None
    n_iter = 0  # number of line search iterations
    full_du_norm = None  # initial change of u
    # Line search terminate condition: alpha (per batch) is decreased until
    # every batch meets the cost-decrease condition or the cap is reached.
    # NOTE(fix): the original condition,
    #   n_iter < max_ls_iter and current_cost is None or (...).any()
    # parses as (A and B) or C, so max_ls_iter stopped bounding the loop
    # after the first pass and the search could spin forever. The cap must
    # apply to every pass after the mandatory first one.
    while current_cost is None or (n_iter < max_ls_iter and
                                   (current_cost > OLD_COST).any()):
        assert type(alphas) == self.xp.ndarray, "alphas dtype error"
        new_x = [states[0]]
        new_u = []
        dx = [self.xp.zeros_like(states[0])]
        objs = []  # cost
        for t in range(self.T):
            Kt = Ks[t]
            kt = ks[t]
            new_xt = new_x[t]
            xt = states[t]
            ut = self.controls[t]
            dxt = dx[t]
            # Feedback on the state deviation from the previous iterate.
            new_ut = xpbmv(Kt, dxt) + ut
            assert not self.xp.isnan(new_ut).any()
            assert not self.xp.isinf(new_ut).any()
            # Per-batch step size applied to the feedforward term.
            xp_alpha = self.xp.diagflat(alphas).astype(dtype=kt.dtype)
            assert not self.xp.isnan(xp_alpha).any()
            assert not self.xp.isinf(xp_alpha).any()
            assert not self.xp.isnan(kt).any()
            add_new_ut = xp_alpha @ kt
            assert not self.xp.isnan(add_new_ut).any(
            ), str(alphas) + " @" + str(kt) + "=" + str(add_new_ut)
            new_ut += add_new_ut
            assert not self.xp.isnan(new_ut).any(), str(xp_alpha) + str(kt)
            # Project onto the box constraints.
            new_ut = xpclamp(new_ut, self.u_lower[t], self.u_upper[t])
            # delta_u is None
            assert not self.xp.isnan(new_ut).any()
            new_u.append(new_ut)
            new_xut = self.xp.concatenate((new_xt, new_ut), axis=1)
            if t < self.T - 1:
                # Calculate next x_{t+1}
                if isinstance(true_dynamics, LinDx):
                    # Dynamics is linear
                    large_f, f = true_dynamics.F, true_dynamics.f
                    large_f = to_xp(large_f)
                    f = to_xp(f)
                    # new_x_{t+1}
                    new_xtp1 = xpbmv(large_f[t], new_xut)
                    if f is not None:
                        new_xtp1 += f[t]
                else:
                    # Dynamics is non linear
                    new_xtp1 = true_dynamics(new_xt, new_ut)
                    new_xtp1 = to_xp(new_xtp1)
                assert not self.xp.isnan(new_xtp1).any()
                new_x.append(new_xtp1)
                dx.append(new_xtp1 - states[t + 1])
            # Calculate cost
            if isinstance(true_cost, QuadCost):
                # Cost is quadratic: 0.5 tau^T C tau + c^T tau.
                C = true_cost.C
                c = true_cost.c
                C = to_xp(C)
                c = to_xp(c)
                obj = 0.5 * xpbquad(new_xut, C[t]) + xpbdot(new_xut, c[t])
            else:
                obj = true_cost(new_xut)
            objs.append(obj)
        objs = self.xp.stack(objs, axis=0)
        current_cost = self.xp.sum(objs, axis=0)
        new_x = self.xp.stack(new_x, axis=0)
        new_u = self.xp.stack(new_u, axis=0)
        # only update once: record the unscaled control change of the first
        # (alpha = 1) pass.
        if full_du_norm is None:
            du = self.controls - new_u
            du = self.xp.transpose(du, axes=(0, 2, 1)).reshape(
                self.n_batch, self.T * self.n_ctrl)
            full_du_norm = self.xp.sqrt(self.xp.sum(du**2, axis=1))
        assert list(new_x.shape) == [
            self.T, self.n_batch, self.n_state
        ], str(new_x.shape) + " new x dim mismatch"
        assert list(new_u.shape) == [self.T, self.n_batch,
                                     self.n_ctrl], "new u dim mismatch"
        # Shrink alpha only for the batches whose cost got worse.
        index_decay = current_cost > OLD_COST
        assert not self.xp.isinf(alphas).any()
        alphas[index_decay] *= ls_decay
        assert not self.xp.isinf(alphas).any(), str(ls_decay)
        n_iter += 1
    # TODO Check this decay
    # If the iteration limit is hit, some alphas
    # are one step too small.
    alphas[current_cost > OLD_COST] /= ls_decay
    du = self.controls - new_u
    du = self.xp.transpose(du, axes=(0, 2, 1)).reshape(
        self.n_batch, self.T * self.n_ctrl)
    alpha_du_norm = self.xp.sqrt(self.xp.sum(du**2, axis=1))
    res = LqrForOut(objs, full_du_norm, alpha_du_norm,
                    self.xp.mean(alphas), current_cost)
    assert not self.xp.isnan(new_x).any()
    assert not self.xp.isnan(new_u).any()
    return new_x, new_u, res
def PNQP(H, q, lower, upper, x_init=None, n_iter=20):
    """
    Projected newton qp solver.

    Minimizes 0.5 x^T H x + q^T x subject to lower <= x <= upper,
    independently per batch element.

    :param H: (n_batch, n_dim, n_dim) QP Hessian
    :param q: (n_batch, n_dim) linear term
    :param lower: (n_batch, n_dim) elementwise lower bound
    :param upper: (n_batch, n_dim) elementwise upper bound
    :param x_init: optional warm start (copied; never modified in place)
    :param n_iter: maximum number of projected-Newton iterations
    :return: (x, H_f or its LU factors on the free set, Index_f, i)

    Algorithm[1] in [2]
    1) Get indices: eq(15)
    2) Get Newton step: eq(16)
    3) Convergence: If |g_f|< epsilon << 1 terminate
    4) Line search
    """
    xp = get_array_module(H)
    # Unwrap chainer Variables: the solver operates on raw arrays only.
    if type(H) == Variable:
        H = H.array
    if type(q) == Variable:
        q = q.array
    if type(lower) == Variable:
        lower = lower.array
    if type(upper) == Variable:
        upper = upper.array
    if x_init is not None and type(x_init) == Variable:
        x_init = x_init.array
    assert type(H) == xp.ndarray, str(H)
    assert (lower <= upper).all(), " lower is larger than upper" \
        + " lower: " + str(lower) + "upper: " + str(upper)
    n_batch = H.shape[0]
    n_dim = H.shape[1]
    assert list(H.shape) == [n_batch, n_dim, n_dim], "H dim mismatch"
    assert list(
        q.shape) == [n_batch, n_dim], "q dim mismatch expected" + str([n_batch, n_dim])
    assert list(
        lower.shape) == [n_batch, n_dim], "lower dim mismatch actual" + str(lower.shape)
    assert list(upper.shape) == [n_batch, n_dim], "upper dim mismatch"
    # small identity matrix: regularizes the masked Hessian so it stays
    # invertible after rows/columns of the clamped set are zeroed.
    I_pnqp = xpexpand_batch((1e-11 * xp.eye(n_dim)),
                            n_batch).reshape(n_batch, n_dim, n_dim)
    if x_init is None:
        # make initial guess: the unconstrained Newton point H x = -q.
        if n_dim == 1:
            x_init = -(1.0 / xp.squeeze(H, axis=2)) * q
        else:
            H_lu = xpbatch_lu_factor(H)
            # Clamped in the x assignment
            x_init = -xpbatch_lu_solve(H_lu, q)
    else:
        # Don't over-write the original x_init.
        x_init = copy.deepcopy(x_init)
        x_init = to_xp(x_init)
    assert type(x_init[0][0]) != Variable
    assert type(lower) != Variable
    assert type(upper) != Variable
    # Begin with feasible guess
    x = xpclamp(x_init, lower, upper)
    assert list(x.shape) == [n_batch, n_dim], "x dim mismatch"
    for i in range(n_iter):
        # 1. Get indices: an entry is "clamped" when it sits on a bound and
        # the gradient pushes it further outside; the rest are "free".
        grad = xpbmv(H, x) + q
        assert type(H) == xp.ndarray
        assert type(grad) == xp.ndarray
        assert type(x) == xp.ndarray
        assert type(lower) == xp.ndarray
        assert type(upper) == xp.ndarray
        Index_c = ((x == lower) & (xp.greater(grad, 0.0)) |
                   ((x == upper) & (xp.less(grad, 0.0))))
        Index_c = 1.0 * Index_c
        Index_f = 1.0 - Index_c
        # 2. Get Newton step restricted to the free subspace.
        Index_Hff = xpbger(Index_f, Index_f)
        Index_not_Hff = 1.0 - Index_Hff
        Index_c = Index_c.astype('bool')
        g_f = copy.deepcopy(grad)
        g_f[Index_c] = 0.0
        # Bad implementation (when n_dim is large)
        H_f = copy.deepcopy(H)
        Index_not_Hff = Index_not_Hff.astype('bool')
        H_f[Index_not_Hff] = 0.0
        H_f += I_pnqp
        # calculate dx
        if n_dim == 1:
            dx = -(1.0 / xp.squeeze(H_f, axis=2)) * g_f
        else:
            H_lu_f = xpbatch_lu_factor(H_f)
            dx = -xpbatch_lu_solve(H_lu_f, g_f)
        # 3. Convergence: stop once every batch's Newton step is tiny.
        # NOTE(fix): the original repeated this check as
        # `if num_large.data == 0:` — ndarray.data is a memoryview and never
        # compares equal to 0, so that branch was dead; one check suffices.
        norm = xp.sqrt(xp.sum(dx**2, axis=1))
        batch_large = norm >= 1e-4
        batch_large = batch_large.astype('float')
        num_large = xp.sum(batch_large)
        if num_large == 0:
            return x, H_f if n_dim == 1 else H_lu_f, Index_f, i
        # 4. Line search (Backtracking): shrink alpha per batch until the
        # Armijo-style ratio exceeds GAMMA (only batches with a large step
        # participate; the others get a passing sentinel value).
        alpha = xp.ones(n_batch, dtype=x.dtype)
        DECAY = 0.1  # making alpha smaller
        max_lhs = xp.array(GAMMA)
        batch_large = batch_large.astype('bool')
        count = 0
        while max_lhs <= GAMMA and count < 10:
            x_hat = xpclamp(x + xp.diagflat(alpha) @ dx, lower, upper)
            lhs = (GAMMA + 1e-6) * (xp.ones(n_batch, dtype=x.dtype))
            lhs[batch_large] = (calc_obj(H, q, x) - calc_obj(H, q, x_hat))[batch_large] \
                / xpbdot(grad, x - x_hat)[batch_large]
            I = lhs <= GAMMA
            alpha[I] *= DECAY  # making smaller
            max_lhs = xp.max(lhs)
            # Don't write cnt += cnt +1 HERE
            count += 1
        x = x_hat
    warnings.warn(
        "Projected Newton Quadratic Programming warning: Did not converge")
    return x, H_f if n_dim == 1 else H_lu_f, Index_f, i