def forward(self, x_init, C, c, F, f=None):
    if self.no_op_forward:
        self.save_for_backward(x_init, C, c, F, f, self.current_x, self.current_u)
        return self.current_x, self.current_u

    if self.delta_space:
        # Taylor-expand the objective to do the backward pass in
        # the delta space.
        assert self.current_x is not None
        assert self.current_u is not None
        c_back = []
        for t in range(self.T):
            xt = self.current_x[t]
            ut = self.current_u[t]
            xut = torch.cat((xt, ut), 1)
            c_back.append(util.bmv(C[t], xut) + c[t])
        c_back = torch.stack(c_back)
        f_back = None
    else:
        assert False

    Ks, ks, self.back_out = self.lqr_backward(C, c_back, F, f_back)
    new_x, new_u, self.for_out = self.lqr_forward(x_init, C, c, F, f, Ks, ks)
    self.save_for_backward(x_init, C, c, F, f, new_x, new_u)

    return new_x, new_u
def linearize_dynamics(x, u, dynamics):
    """Linearize the dynamics around the trajectory (x, u).

    :param x: states, shape (time, batch, n_state)
    :param u: controls, shape (time, batch, n_ctrl)
    :param dynamics: callable mapping (x_t, u_t) -> x_{t+1}
    :return: large_F, f such that x_{t+1} ~= F_t [x_t; u_t] + f_t
    """
    assert x.shape[0] == u.shape[0]
    assert x.shape[1] == u.shape[1]
    n_state = x.shape[2]
    T = x.shape[0]
    x_init = x[0]
    # NOTE: the linearization is taken along the newly rolled-out trajectory,
    # not along the input x.  TODO: CHECK THIS
    x_ar = [x_init]
    large_F, f = [], []
    for t in range(T):
        if t < T - 1:
            xt = x_ar[t]
            ut = u[t]
            new_x = dynamics(xt, ut)

            # Linear dynamics approximation: build the Jacobians row by row.
            Rt, St = [], []
            for j in range(n_state):
                Rj, Sj = chainer.grad([F.sum(new_x[:, j])], [xt, ut])
                Rt.append(Rj)
                St.append(Sj)
                assert Sj is not None
            Rt = F.stack(Rt, axis=1)
            St = F.stack(St, axis=1)

            Ft = F.concat((Rt, St), axis=2)
            large_F.append(Ft)
            ft = new_x - bmv(Rt, xt) - bmv(St, ut)
            f.append(ft)
            x_ar.append(new_x)
    large_F = F.stack(large_F, 0)
    f = F.stack(f, 0)
    return large_F, f
def forward(self, Ks, ks):
    """LQR forward recursion.

    :param Ks: feedback gains solved in the backward recursion
    :param ks: feedforward terms solved in the backward recursion
    :return: x, u
    """
    assert len(Ks) == self.T, "Ks length error"
    new_x = [self.x_init]
    new_u = []
    for t in range(self.T):
        Kt = Ks[t]
        kt = ks[t]
        xt = new_x[t]
        assert list(xt.shape) == [self.n_batch, self.n_state], \
            str(xt.shape) + " xt dim mismatch: expected " + str([self.n_batch, self.n_state])
        new_ut = bmv(Kt, xt) + kt
        assert list(new_ut.shape) == [self.n_batch, self.n_ctrl], "actual: " + str(new_ut.shape)
        if self.u_zero_Index is not None:
            assert self.u_zero_Index[t].shape == new_ut.shape, \
                str(self.u_zero_Index[t].shape) + " : " + str(new_ut.shape)
            new_ut = F.where(self.u_zero_Index[t],
                             self.xp.zeros_like(new_ut.array), new_ut)
            assert list(new_ut.shape) == [self.n_batch, self.n_ctrl], "actual: " + str(new_ut.shape)
        new_u.append(new_ut)
        if t < self.T - 1:
            assert list(xt.shape) == [self.n_batch, self.n_state], "actual: " + str(xt.shape)
            assert list(new_u[t].shape) == [self.n_batch, self.n_ctrl], "actual: " + str(new_u[t].shape)
            xu_t = F.concat((xt, new_u[t]), axis=1)
            x = bmv(self.F[t], xu_t)
            if self.f is not None:
                x += self.f[t]
            assert list(x.shape) == [self.n_batch, self.n_state], \
                str(x.shape) + " x dim mismatch: expected " + str([self.n_batch, self.n_state])
            new_x.append(x)
    new_x = F.stack(new_x, axis=0)
    new_u = F.stack(new_u, axis=0)
    assert list(new_x.shape) == [self.T, self.n_batch, self.n_state], str(new_x.shape) + " new x dim mismatch"
    assert list(new_u.shape) == [self.T, self.n_batch, self.n_ctrl], "new u dim mismatch"
    return new_x, new_u
def approximate_cost(self, x, u, Cf, diff=True):
    with torch.enable_grad():
        tau = torch.cat((x, u), dim=2).data
        tau = Variable(tau, requires_grad=True)
        if self.slew_rate_penalty is not None:
            print("""
MPC Error: Using a non-convex cost with a slew rate penalty is not yet implemented.
The current implementation does not correctly do a line search.
More details: https://github.com/locuslab/mpc.pytorch/issues/12
""")
            sys.exit(-1)

            differences = tau[1:, :, -self.n_ctrl:] - tau[:-1, :, -self.n_ctrl:]
            slew_penalty = (self.slew_rate_penalty * differences.pow(2)).sum(-1)
        costs = list()
        hessians = list()
        grads = list()
        for t in range(self.T):
            tau_t = tau[t]
            if self.slew_rate_penalty is not None:
                cost = Cf(tau_t) + (slew_penalty[t - 1] if t > 0 else 0)
            else:
                cost = Cf(tau_t)

            grad = torch.autograd.grad(cost.sum(), tau_t, retain_graph=True)[0]
            hessian = list()
            for v_i in range(tau.shape[2]):
                hessian.append(
                    torch.autograd.grad(grad[:, v_i].sum(), tau_t,
                                        retain_graph=True)[0])
            hessian = torch.stack(hessian, dim=-1)
            costs.append(cost)
            grads.append(grad - util.bmv(hessian, tau_t))
            hessians.append(hessian)
        costs = torch.stack(costs, dim=0)
        grads = torch.stack(grads, dim=0)
        hessians = torch.stack(hessians, dim=0)
        if not diff:
            return hessians.data, grads.data, costs.data
        return hessians, grads, costs
def test_dynamics():
    import numpy as np
    np.random.seed(0)
    batch = 2
    time = 3
    n_state = 1
    n_ctrl = 1
    n_sc = n_state + n_ctrl
    x = np.random.randn(time, batch, n_state)
    u = np.random.randn(time, batch, n_ctrl)
    x = chainer.Variable(x)
    u = chainer.Variable(u)
    A = np.random.randn(batch, n_state, n_sc)
    B = np.random.randn(batch, n_state)
    A = chainer.Variable(A)
    B = chainer.Variable(B)
    dynamics = lambda s, c: bmv(A, F.concat((s, c), axis=1)) + B
    large_F, f = linearize_dynamics(x, u, dynamics)
    print("A", A)
    print("large_F", large_F)
    print("B", B)
    print("f", f)
def approximate_cost(x, u, Cf):
    """Quadratically approximate the cost function around the point (x, u).

    :param x: states, shape (time, batch, n_state)
    :param u: controls, shape (time, batch, n_ctrl)
    :param Cf: cost function mapping a (batch, n_sc) vector to one scalar per batch element
    :return: hessians, grads, costs
    """
    assert x.shape[0] == u.shape[0]
    assert x.shape[1] == u.shape[1]
    T = x.shape[0]
    tau = F.concat((x, u), axis=2)
    costs = []
    hessians = []
    grads = []
    # Loop over time.
    for t in range(T):
        tau_t = tau[t]
        cost = Cf(tau_t)  # Value of the cost function at tau_t.
        assert list(cost.shape) == [x.shape[1]]
        grad = chainer.grad([F.sum(cost)], [tau_t],
                            enable_double_backprop=True)[0]
        # Build the Hessian one coordinate at a time.
        hessian = []
        for v_i in range(tau.shape[2]):  # n_sc
            grad_line = F.sum(grad[:, v_i])
            hessian.append(chainer.grad([grad_line], [tau_t])[0])
        hessian = F.stack(hessian, axis=-1)
        costs.append(cost)
        # Shift the linear term so that the quadratic expansion
        # 0.5 tau^T H tau + c^T tau reproduces the gradient at tau_t.
        grads.append(grad - bmv(hessian, tau_t))
        hessians.append(hessian)
    costs = F.stack(costs)
    grads = F.stack(grads)
    hessians = F.stack(hessians)
    return hessians, grads, costs
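# A minimal usage sketch for the Chainer approximate_cost above, assuming the
# module-level imports already used here (chainer, numpy, F = chainer.functions)
# and that approximate_cost and bmv are in scope.  The helper name
# _check_approximate_cost is illustrative.  For a purely quadratic cost
# Cf(tau) = 0.5 * tau^T Q tau with symmetric Q, the returned Hessian should
# recover Q and the shifted linear term should be (near) zero.
def _check_approximate_cost():
    import numpy as np
    np.random.seed(0)
    T, n_batch, n_state, n_ctrl = 3, 2, 2, 1
    n_sc = n_state + n_ctrl
    Q = np.random.randn(n_batch, n_sc, n_sc).astype(np.float32)
    Q = Q @ Q.transpose(0, 2, 1)  # symmetric positive semi-definite
    Q_var = chainer.Variable(Q)
    x = chainer.Variable(np.random.randn(T, n_batch, n_state).astype(np.float32))
    u = chainer.Variable(np.random.randn(T, n_batch, n_ctrl).astype(np.float32))
    Cf = lambda tau: 0.5 * F.sum(tau * bmv(Q_var, tau), axis=1)
    hessians, grads, costs = approximate_cost(x, u, Cf)
    np.testing.assert_allclose(hessians[0].array, Q, rtol=1e-4, atol=1e-4)
    np.testing.assert_allclose(grads.array, 0., atol=1e-4)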
def backward(self, target_input_indexes, grad_outputs):
    """Backward pass of the differentiable LQR layer.

    :param target_input_indexes: indexes of the inputs that need gradients
    :param grad_outputs: upstream gradients w.r.t. (x, u)
    :return: gradients w.r.t. (x_init, C, c, F, f)
    """
    # Forward 2: calculate the dual variables with the backward recursion,
    # Equation (7) in [1].
    x_init, C, c, large_f = self.get_retained_inputs()
    C_Tx = C[self.T - 1, :, :self.n_state, :]
    c_Tx = c[self.T - 1, :, :self.n_state]
    assert list(C_Tx.shape) == [self.n_batch, self.n_state, self.n_sc]
    x, u = self.get_retained_outputs()
    taus = F.concat((x, u), axis=2).reshape(self.T, self.n_batch, self.n_sc)
    Lambda_T = bmv(C_Tx, taus[self.T - 1]) + c_Tx
    Lambdas = [Lambda_T]
    # Backward recursion for the dual variables.
    for i in range(self.T - 2, -1, -1):
        Lambda_tp1 = Lambdas[self.T - 2 - i]
        tau_t = taus[i]
        F_t = large_f[i]
        F_tx_T = F.transpose(F_t[:, :self.n_state, :self.n_state], axes=(0, 2, 1))
        C_tx = C[i][:, :self.n_state, :]
        c_tx = c[i][:, :self.n_state]
        Lambda_t = bmv(F_tx_T, Lambda_tp1) + bmv(C_tx, tau_t) + c_tx
        Lambdas.append(Lambda_t)
    Lambdas.reverse()

    # Backward 1: solve the LQR problem in the gradient space.
    grad_x, grad_u = grad_outputs
    xp = chainer.backend.get_array_module(*grad_outputs)
    zero_init = chainer.Variable(xp.zeros_like(x_init.data))
    zero_f = chainer.Variable(xp.zeros((self.T - 1, self.n_batch, self.n_state)))
    drl = F.concat((grad_x, grad_u), axis=2).reshape(self.T, self.n_batch, self.n_sc)
    lqr = LqrRecursion(zero_init, C, drl, large_f, zero_f,
                       self.T, self.n_state, self.n_ctrl)
    dx, du = lqr.solve_recursion()

    # Backward 2: calculate the differential dual variables.
    d_taus = F.concat((dx, du), axis=2).reshape(self.T, self.n_batch, self.n_sc)
    d_lambda_T = bmv(C_Tx, d_taus[self.T - 1]) + drl[self.T - 1][:, :self.n_state]
    d_lambdas = [d_lambda_T]
    for i in range(self.T - 2, -1, -1):
        d_lambda_tp1 = d_lambdas[self.T - 2 - i]
        d_tau_t = d_taus[i]
        F_t = large_f[i]
        F_tx_T = F.transpose(F_t[:, :self.n_state, :self.n_state], axes=(0, 2, 1))
        C_tx = C[i][:, :self.n_state, :]
        d_rl = drl[i][:, :self.n_state]
        d_lambda_t = bmv(F_tx_T, d_lambda_tp1) + bmv(C_tx, d_tau_t) + d_rl
        d_lambdas.append(d_lambda_t)
    d_lambdas.reverse()

    # Backward line 3: compute the derivatives, Equation (8) in [1].
    # The symmetrization needs the parentheses around the sum.
    dC = F.stack([
        0.5 * (bger(d_taus[t], taus[t]) + bger(taus[t], d_taus[t]))
        for t in range(self.T)
    ], axis=0)
    dc = F.stack([d_taus[t] for t in range(self.T)], axis=0)
    dF = F.stack([
        bger(d_lambdas[t + 1], taus[t]).data + bger(Lambdas[t + 1], d_taus[t]).data
        for t in range(self.T - 1)
    ], axis=0)
    df = F.stack(d_lambdas[:self.T - 1], axis=0)
    d_x_init = d_lambdas[0]
    return d_x_init, dC, dc, dF, df
def lqr_forward(self, x_init, C, c, F, f, Ks, ks):
    x = self.current_x
    u = self.current_u
    n_batch = C.size(1)

    old_cost = util.get_cost(self.T, u, self.true_cost, self.true_dynamics, x=x)

    current_cost = None
    alphas = torch.ones(n_batch).type_as(C)
    full_du_norm = None

    i = 0
    while (current_cost is None or \
           (old_cost is not None and \
            torch.any((current_cost > old_cost)).cpu().item() == 1)) and \
          i < self.max_linesearch_iter:
        new_u = []
        new_x = [x_init]
        dx = [torch.zeros_like(x_init)]
        objs = []
        for t in range(self.T):
            t_rev = self.T - 1 - t
            Kt = Ks[t_rev]
            kt = ks[t_rev]
            new_xt = new_x[t]
            xt = x[t]
            ut = u[t]
            dxt = dx[t]
            new_ut = util.bmv(Kt, dxt) + ut + torch.diag(alphas).mm(kt)

            # Currently unimplemented:
            assert not ((self.delta_u is not None) and (self.u_lower is None))

            if self.u_zero_I is not None:
                new_ut[self.u_zero_I[t]] = 0.

            if self.u_lower is not None:
                lb = self.get_bound('lower', t)
                ub = self.get_bound('upper', t)

                if self.delta_u is not None:
                    lb_limit, ub_limit = lb, ub
                    lb = u[t] - self.delta_u
                    ub = u[t] + self.delta_u
                    I = lb < lb_limit
                    lb[I] = lb_limit if isinstance(lb_limit, float) else lb_limit[I]
                    I = ub > ub_limit
                    ub[I] = ub_limit if isinstance(lb_limit, float) else ub_limit[I]
                new_ut = util.eclamp(new_ut, lb, ub)
            new_u.append(new_ut)

            new_xut = torch.cat((new_xt, new_ut), dim=1)
            if t < self.T - 1:
                if isinstance(self.true_dynamics, mpc.LinDx):
                    F, f = self.true_dynamics.F, self.true_dynamics.f
                    new_xtp1 = util.bmv(F[t], new_xut)
                    if f is not None and f.nelement() > 0:
                        new_xtp1 += f[t]
                else:
                    new_xtp1 = self.true_dynamics(
                        Variable(new_xt), Variable(new_ut)).data

                new_x.append(new_xtp1)
                dx.append(new_xtp1 - x[t + 1])

            if isinstance(self.true_cost, mpc.QuadCost):
                C, c = self.true_cost.C, self.true_cost.c
                obj = 0.5 * util.bquad(new_xut, C[t]) + util.bdot(new_xut, c[t])
            else:
                obj = self.true_cost(new_xut)

            objs.append(obj)

        objs = torch.stack(objs)
        current_cost = torch.sum(objs, dim=0)

        new_u = torch.stack(new_u)
        new_x = torch.stack(new_x)
        if full_du_norm is None:
            full_du_norm = (u - new_u).transpose(1, 2).contiguous().view(
                n_batch, -1).norm(2, 1)

        alphas[current_cost > old_cost] *= self.linesearch_decay
        i += 1

    # If the iteration limit is hit, some alphas
    # are one step too small.
    alphas[current_cost > old_cost] /= self.linesearch_decay
    alpha_du_norm = (u - new_u).transpose(1, 2).contiguous().view(
        n_batch, -1).norm(2, 1)

    return new_x, new_u, LqrForOut(
        objs, full_du_norm, alpha_du_norm, torch.mean(alphas), current_cost)
def lqr_backward(self, C, c, F, f):
    n_batch = C.size(1)

    u = self.current_u
    Ks = []
    ks = []
    prev_kt = None
    n_total_qp_iter = 0
    Vtp1 = vtp1 = None
    for t in range(self.T - 1, -1, -1):
        if t == self.T - 1:
            Qt = C[t]
            qt = c[t]
        else:
            Ft = F[t]
            Ft_T = Ft.transpose(1, 2)
            Qt = C[t] + Ft_T.bmm(Vtp1).bmm(Ft)
            if f is None or f.nelement() == 0:
                qt = c[t] + Ft_T.bmm(vtp1.unsqueeze(2)).squeeze(2)
            else:
                ft = f[t]
                qt = c[t] + Ft_T.bmm(Vtp1).bmm(ft.unsqueeze(2)).squeeze(2) + \
                    Ft_T.bmm(vtp1.unsqueeze(2)).squeeze(2)

        n_state = self.n_state
        Qt_xx = Qt[:, :n_state, :n_state]
        Qt_xu = Qt[:, :n_state, n_state:]
        Qt_ux = Qt[:, n_state:, :n_state]
        Qt_uu = Qt[:, n_state:, n_state:]
        qt_x = qt[:, :n_state]
        qt_u = qt[:, n_state:]

        if self.u_lower is None:
            if self.n_ctrl == 1 and self.u_zero_I is None:
                Kt = -(1. / Qt_uu) * Qt_ux
                kt = -(1. / Qt_uu.squeeze(2)) * qt_u
            else:
                if self.u_zero_I is None:
                    Qt_uu_inv = [
                        torch.pinverse(Qt_uu[i]) for i in range(Qt_uu.shape[0])
                    ]
                    Qt_uu_inv = torch.stack(Qt_uu_inv)
                    Kt = -Qt_uu_inv.bmm(Qt_ux)
                    kt = util.bmv(-Qt_uu_inv, qt_u)

                    # Qt_uu_LU = Qt_uu.btrifact()
                    # Kt = -Qt_ux.btrisolve(*Qt_uu_LU)
                    # kt = -qt_u.btrisolve(*Qt_uu_LU)
                else:
                    # Solve with zero constraints on the active controls.
                    I = self.u_zero_I[t]
                    notI = 1 - I

                    qt_u_ = qt_u.clone()
                    qt_u_[I] = 0

                    Qt_uu_ = Qt_uu.clone()

                    if I.is_cuda:
                        notI_ = notI.float()
                        Qt_uu_I = (1 - util.bger(notI_, notI_)).type_as(I)
                    else:
                        Qt_uu_I = 1 - util.bger(notI, notI)

                    Qt_uu_[Qt_uu_I] = 0.
                    Qt_uu_[util.bdiag(I)] += 1e-8

                    Qt_ux_ = Qt_ux.clone()
                    Qt_ux_[I.unsqueeze(2).repeat(1, 1, Qt_ux.size(2))] = 0.

                    if self.n_ctrl == 1:
                        Kt = -(1. / Qt_uu_) * Qt_ux_
                        kt = -(1. / Qt_uu.squeeze(2)) * qt_u_
                    else:
                        Qt_uu_LU_ = Qt_uu_.btrifact()
                        Kt = -Qt_ux_.btrisolve(*Qt_uu_LU_)
                        kt = -qt_u_.btrisolve(*Qt_uu_LU_)
        else:
            assert self.delta_space
            lb = self.get_bound('lower', t) - u[t]
            ub = self.get_bound('upper', t) - u[t]
            if self.delta_u is not None:
                lb[lb < -self.delta_u] = -self.delta_u
                ub[ub > self.delta_u] = self.delta_u
            kt, Qt_uu_free_LU, If, n_qp_iter = pnqp(
                Qt_uu, qt_u, lb, ub, x_init=prev_kt, n_iter=20)
            if self.verbose > 1:
                print('  + n_qp_iter: ', n_qp_iter + 1)
            n_total_qp_iter += 1 + n_qp_iter
            prev_kt = kt
            Qt_ux_ = Qt_ux.clone()
            Qt_ux_[(1 - If).unsqueeze(2).repeat(1, 1, Qt_ux.size(2))] = 0
            if self.n_ctrl == 1:
                # Bad naming, Qt_uu_free_LU isn't the LU in this case.
                Kt = -((1. / Qt_uu_free_LU) * Qt_ux_)
            else:
                Kt = -Qt_ux_.btrisolve(*Qt_uu_free_LU)

        Kt_T = Kt.transpose(1, 2)

        Ks.append(Kt)
        ks.append(kt)

        Vtp1 = Qt_xx + Qt_xu.bmm(Kt) + Kt_T.bmm(Qt_ux) + Kt_T.bmm(Qt_uu).bmm(Kt)
        vtp1 = qt_x + Qt_xu.bmm(kt.unsqueeze(2)).squeeze(2) + \
            Kt_T.bmm(qt_u.unsqueeze(2)).squeeze(2) + \
            Kt_T.bmm(Qt_uu).bmm(kt.unsqueeze(2)).squeeze(2)

    return Ks, ks, LqrBackOut(n_total_qp_iter=n_total_qp_iter)
def backward(self, dl_dx, dl_du):
    start = time.time()
    x_init, C, c, F, f, new_x, new_u = self.saved_tensors

    r = []
    for t in range(self.T):
        rt = torch.cat((dl_dx[t], dl_du[t]), 1)
        r.append(rt)
    r = torch.stack(r)

    if self.u_lower is None:
        I = None
    else:
        I = (torch.abs(new_u - self.u_lower) <= 1e-8) | \
            (torch.abs(new_u - self.u_upper) <= 1e-8)

    dx_init = Variable(torch.zeros_like(x_init))
    _mpc = mpc.MPC(
        self.n_state, self.n_ctrl, self.T,
        u_zero_I=I,
        u_init=None,
        lqr_iter=1,
        verbose=-1,
        n_batch=C.size(1),
        delta_u=None,
        # exit_unconverged=True, # It's really bad if this doesn't converge.
        exit_unconverged=False,  # It's really bad if this doesn't converge.
        eps=self.back_eps,
    )
    dx, du, _ = _mpc(dx_init, mpc.QuadCost(C, -r), mpc.LinDx(F, None))

    dx, du = dx.data, du.data
    dxu = torch.cat((dx, du), 2)
    xu = torch.cat((new_x, new_u), 2)

    dC = torch.zeros_like(C)
    for t in range(self.T):
        xut = torch.cat((new_x[t], new_u[t]), 1)
        dxut = dxu[t]
        dCt = -0.5 * (util.bger(dxut, xut) + util.bger(xut, dxut))
        dC[t] = dCt

    dc = -dxu

    lams = []
    prev_lam = None
    for t in range(self.T - 1, -1, -1):
        Ct_xx = C[t, :, :self.n_state, :self.n_state]
        Ct_xu = C[t, :, :self.n_state, self.n_state:]
        ct_x = c[t, :, :self.n_state]
        xt = new_x[t]
        ut = new_u[t]
        lamt = util.bmv(Ct_xx, xt) + util.bmv(Ct_xu, ut) + ct_x
        if prev_lam is not None:
            Fxt = F[t, :, :, :self.n_state].transpose(1, 2)
            lamt += util.bmv(Fxt, prev_lam)
        lams.append(lamt)
        prev_lam = lamt
    lams = list(reversed(lams))

    dlams = []
    prev_dlam = None
    for t in range(self.T - 1, -1, -1):
        dCt_xx = C[t, :, :self.n_state, :self.n_state]
        dCt_xu = C[t, :, :self.n_state, self.n_state:]
        drt_x = -r[t, :, :self.n_state]
        dxt = dx[t]
        dut = du[t]
        dlamt = util.bmv(dCt_xx, dxt) + util.bmv(dCt_xu, dut) + drt_x
        if prev_dlam is not None:
            Fxt = F[t, :, :, :self.n_state].transpose(1, 2)
            dlamt += util.bmv(Fxt, prev_dlam)
        dlams.append(dlamt)
        prev_dlam = dlamt
    dlams = torch.stack(list(reversed(dlams)))

    dF = torch.zeros_like(F)
    for t in range(self.T - 1):
        xut = xu[t]
        lamt = lams[t + 1]
        dxut = dxu[t]
        dlamt = dlams[t + 1]
        dF[t] = -(util.bger(dlamt, xut) + util.bger(lamt, dxut))

    if f.nelement() > 0:
        _dlams = dlams[1:]
        assert _dlams.shape == f.shape
        df = -_dlams
    else:
        df = torch.Tensor()

    dx_init = -dlams[0]

    self.backward_time = time.time() - start
    return dx_init, dC, dc, dF, df
def pnqp(H, q, lower, upper, x_init=None, n_iter=20):
    GAMMA = 0.1
    n_batch, n, _ = H.size()
    pnqp_I = 1e-11 * torch.eye(n).type_as(H).expand_as(H)

    def obj(x):
        return 0.5 * util.bquad(x, H) + util.bdot(q, x)

    if x_init is None:
        if n == 1:
            x_init = -(1. / H.squeeze(2)) * q
        else:
            # H_lu = H.btrifact()  # XXX deprecated!!!
            H_lu = torch.lu(H)
            x_init = -q.btrisolve(H_lu[0], H_lu[1])  # Clamped in the x assignment.
    else:
        x_init = x_init.clone()  # Don't over-write the original x_init.

    x = util.eclamp(x_init, lower, upper)

    # Active examples in the batch.
    J = torch.ones(n_batch).type_as(x).byte()

    for i in range(n_iter):
        g = util.bmv(H, x) + q

        Ic = ((x == lower) & (g > 0)) | ((x == upper) & (g < 0))
        If = 1 - Ic

        if If.is_cuda:
            Hff_I = util.bger(If.float(), If.float()).type_as(If)
            not_Hff_I = 1 - Hff_I
            Hfc_I = util.bger(If.float(), Ic.float()).type_as(If)
        else:
            Hff_I = util.bger(If, If)
            not_Hff_I = 1 - Hff_I
            Hfc_I = util.bger(If, Ic)

        g_ = g.clone()
        g_[Ic] = 0.
        H_ = H.clone()
        H_[not_Hff_I] = 0.0
        H_ += pnqp_I

        if n == 1:
            dx = -(1. / H_.squeeze(2)) * g_
        else:
            # H_lu_ = H_.btrifact()  # XXX deprecated!!!
            # The masked Hessian H_ must be factorized here, not H.
            H_lu_ = torch.lu(H_)
            dx = -g_.btrisolve(*H_lu_)

        J = torch.norm(dx, 2, 1) >= 1e-4
        m = J.sum().item()  # Number of active examples in the batch.
        if m == 0:
            return x, H_ if n == 1 else H_lu_, If, i

        alpha = torch.ones(n_batch).type_as(x)
        decay = 0.1
        max_armijo = GAMMA
        count = 0
        while max_armijo <= GAMMA and count < 10:
            # Crude way of making sure too much time isn't being spent
            # doing the line search.
            # assert count < 10

            maybe_x = util.eclamp(x + torch.diag(alpha).mm(dx), lower, upper)
            armijos = (GAMMA + 1e-6) * torch.ones(n_batch).type_as(x)
            armijos[J] = (obj(x) - obj(maybe_x))[J] / util.bdot(g, x - maybe_x)[J]
            I = armijos <= GAMMA
            alpha[I] *= decay
            max_armijo = torch.max(armijos)
            count += 1

        x = maybe_x

    # TODO: Maybe change this to a warning.
    print("[WARNING] pnqp warning: Did not converge")
    return x, H_ if n == 1 else H_lu_, If, i
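# A minimal usage sketch for pnqp, assuming the module-level imports above
# (torch, util) and that pnqp, util.bmv and util.eclamp are in scope; the
# helper name _check_pnqp is illustrative.  It solves a batch of small
# box-constrained QPs
#     min_x 0.5 x^T H x + q^T x   s.t.  lower <= x <= upper
# and checks the projected-gradient fixed-point condition
#     x* == clamp(x* - t * (H x* + q), lower, upper)  for a small t > 0.
def _check_pnqp():
    torch.manual_seed(0)
    n_batch, n = 4, 3
    L = torch.randn(n_batch, n, n)
    H = L.bmm(L.transpose(1, 2)) + 1e-2 * torch.eye(n).unsqueeze(0)  # SPD Hessians
    q = torch.randn(n_batch, n)
    lower = -0.5 * torch.ones(n_batch, n)
    upper = 0.5 * torch.ones(n_batch, n)
    x, _, If, n_iter = pnqp(H, q, lower, upper, n_iter=50)
    g = util.bmv(H, x) + q
    x_proj = util.eclamp(x - 1e-2 * g, lower, upper)
    assert (x - x_proj).abs().max().item() < 1e-3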
def linearize_dynamics(self, x, u, dynamics, diff):
    # TODO: Cleanup variable usage.

    n_batch = x[0].size(0)

    if self.grad_method == GradMethods.ANALYTIC:
        _u = Variable(u[:-1].view(-1, self.n_ctrl), requires_grad=True)
        _x = Variable(x[:-1].contiguous().view(-1, self.n_state),
                      requires_grad=True)

        # This inefficiently calls dynamics again, but is worth it because
        # we can efficiently compute grad_input for every time step at once.
        _new_x = dynamics(_x, _u)

        # This check is a little expensive and should only be done if
        # modifying this code.
        # assert torch.abs(_new_x.data - torch.cat(x[1:])).max() <= 1e-6

        if not diff:
            _new_x = _new_x.data
            _x = _x.data
            _u = _u.data
        R, S = dynamics.grad_input(_x, _u)

        f = _new_x - util.bmv(R, _x) - util.bmv(S, _u)
        f = f.view(self.T - 1, n_batch, self.n_state)

        R = R.contiguous().view(self.T - 1, n_batch, self.n_state, self.n_state)
        S = S.contiguous().view(self.T - 1, n_batch, self.n_state, self.n_ctrl)
        F = torch.cat((R, S), 3)

        if not diff:
            F, f = list(map(Variable, [F, f]))
        return F, f
    else:
        # TODO: This is inefficient and confusing.
        x_init = x[0]
        x = [x_init]
        F, f = [], []
        for t in range(self.T):
            if t < self.T - 1:
                xt = Variable(x[t], requires_grad=True)
                ut = Variable(u[t], requires_grad=True)
                xut = torch.cat((xt, ut), 1)
                new_x = dynamics(xt, ut)

                # Linear dynamics approximation.
                if self.grad_method in [GradMethods.AUTO_DIFF,
                                        GradMethods.ANALYTIC_CHECK]:
                    Rt, St = [], []
                    for j in range(self.n_state):
                        Rj, Sj = torch.autograd.grad(
                            new_x[:, j].sum(), [xt, ut], retain_graph=True)
                        if not diff:
                            Rj, Sj = Rj.data, Sj.data
                        Rt.append(Rj)
                        St.append(Sj)
                    Rt = torch.stack(Rt, dim=1)
                    St = torch.stack(St, dim=1)

                    if self.grad_method == GradMethods.ANALYTIC_CHECK:
                        assert False  # Not updated
                        Rt_autograd, St_autograd = Rt, St
                        Rt, St = dynamics.grad_input(xt, ut)
                        eps = 1e-8
                        if torch.max(torch.abs(Rt - Rt_autograd)).data[0] > eps or \
                           torch.max(torch.abs(St - St_autograd)).data[0] > eps:
                            print('''
nmpc.ANALYTIC_CHECK error: The analytic derivative of the dynamics function may be off.
''')
                        else:
                            print('''
nmpc.ANALYTIC_CHECK: The analytic derivative of the dynamics function seems correct.
Re-run with GradMethods.ANALYTIC to continue.
''')
                        sys.exit(0)
                elif self.grad_method == GradMethods.FINITE_DIFF:
                    Rt, St = [], []
                    for i in range(n_batch):
                        Ri = util.jacobian(
                            lambda s: dynamics(s, ut[i]), xt[i], 1e-4)
                        Si = util.jacobian(
                            lambda a: dynamics(xt[i], a), ut[i], 1e-4)
                        if not diff:
                            Ri, Si = Ri.data, Si.data
                        Rt.append(Ri)
                        St.append(Si)
                    Rt = torch.stack(Rt)
                    St = torch.stack(St)
                else:
                    assert False

                Ft = torch.cat((Rt, St), 2)
                F.append(Ft)

                if not diff:
                    xt, ut, new_x = xt.data, ut.data, new_x.data
                ft = new_x - util.bmv(Rt, xt) - util.bmv(St, ut)
                f.append(ft)

            if t < self.T - 1:
                x.append(util.detach_maybe(new_x))

        F = torch.stack(F, 0)
        f = torch.stack(f, 0)
        if not diff:
            F, f = list(map(Variable, [F, f]))
        return F, f
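# A hedged sketch of a dynamics module usable with the GradMethods.ANALYTIC
# branch above, assuming the grad_input convention used there: given batched
# (x, u), it returns the Jacobians R = d x_{t+1} / d x_t (batch, n_state, n_state)
# and S = d x_{t+1} / d u_t (batch, n_state, n_ctrl).  The class name and its
# attributes are illustrative, not part of the library.
class SimpleLinearDynamics(torch.nn.Module):
    def __init__(self, A, B):
        super().__init__()
        self.A = A  # (n_state, n_state)
        self.B = B  # (n_state, n_ctrl)

    def forward(self, x, u):
        # x_{t+1} = A x_t + B u_t, batched over the leading dimension.
        return x.mm(self.A.t()) + u.mm(self.B.t())

    def grad_input(self, x, u):
        # The Jacobians of a linear map are constant, so just expand them
        # over the batch.
        n_batch = x.size(0)
        R = self.A.unsqueeze(0).expand(n_batch, *self.A.size())
        S = self.B.unsqueeze(0).expand(n_batch, *self.B.size())
        return R, S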
def backward(self):
    """LQR backward recursion.

    Note: Ks, ks are returned in reversed order relative to the original.
    :return: Ks, ks gains
    """
    Ks = []
    ks = []
    Vt = None
    vt = None
    # Loop from t = self.T - 1 down to 0.
    for t in range(self.T - 1, -1, -1):
        # Terminal case.
        if t == self.T - 1:
            Qt = self.C[t]
            qt = self.c[t]
        else:
            Ft = self.F[t]
            Ft_T = F.transpose(Ft, axes=(0, 2, 1))
            assert Ft.dtype.kind == 'f', "Ft dtype"
            assert Vt.dtype.kind == 'f', "Vt dtype"
            Qt = self.C[t] + F.matmul(F.matmul(Ft_T, Vt), Ft)
            if self.f is None:
                # NOTE: the PyTorch version also treats f.nelement() == 0 as empty here.
                qt = self.c[t] + bmv(Ft_T, vt)
            else:
                ft = self.f[t]
                qt = self.c[t] + bmv(F.matmul(Ft_T, Vt), ft) + bmv(Ft_T, vt)
        assert list(qt.shape) == [self.n_batch, self.n_sc], "qt dim mismatch"
        assert list(Qt.shape) == [self.n_batch, self.n_sc, self.n_sc], str(Qt.shape) + " Qt dim mismatch"

        Qt_xx = Qt[:, :self.n_state, :self.n_state]
        Qt_xu = Qt[:, :self.n_state, self.n_state:]
        Qt_ux = Qt[:, self.n_state:, :self.n_state]
        Qt_uu = Qt[:, self.n_state:, self.n_state:]
        qt_x = qt[:, :self.n_state]
        qt_u = qt[:, self.n_state:]
        assert list(Qt_uu.shape) == [self.n_batch, self.n_ctrl, self.n_ctrl], "Qt_uu dim mismatch"
        assert list(Qt_ux.shape) == [self.n_batch, self.n_ctrl, self.n_state], "Qt_ux dim mismatch"
        assert list(Qt_xu.shape) == [self.n_batch, self.n_state, self.n_ctrl], "Qt_xu dim mismatch"
        assert list(qt_x.shape) == [self.n_batch, self.n_state], "qt_x dim mismatch"
        assert list(qt_u.shape) == [self.n_batch, self.n_ctrl], "qt_u dim mismatch"

        # Next calculate Kt and kt.
        # TODO: LU decomposition
        if self.n_ctrl == 1 and self.u_zero_Index is None:
            # Scalar control.
            Kt = -(1. / Qt_uu) * Qt_ux
            kt = -(1. / F.squeeze(Qt_uu, axis=2)) * qt_u
        elif self.u_zero_Index is None:
            # Matrix control.
            Qt_uu_inv = F.batch_inv(Qt_uu)
            Kt = -F.matmul(Qt_uu_inv, Qt_ux)
            kt = -bmv(Qt_uu_inv, qt_u)
        else:
            # u_zero_Index is not None: zero out the constrained controls.
            index = self.u_zero_Index[t]
            qt_u_ = copy.deepcopy(qt_u)
            qt_u_ = F.where(index, self.xp.zeros_like(qt_u_.array), qt_u_)
            Qt_uu_ = copy.deepcopy(Qt_uu)
            notI = 1.0 - F.cast(index, qt_u_.dtype)
            Qt_uu_I = 1 - bger(notI, notI)
            Qt_uu_I = F.cast(Qt_uu_I, 'bool')
            Qt_uu_ = F.where(Qt_uu_I, self.xp.zeros_like(Qt_uu_.array), Qt_uu_)
            index_qt_uu = self.xp.array(
                [self.xp.diagflat(index[i]) for i in range(index.shape[0])])
            # Add jitter to the masked Qt_uu_ on the constrained diagonal
            # (mirrors the PyTorch reference implementation).
            Qt_uu_ = F.where(F.cast(index_qt_uu, 'bool'), Qt_uu_ + 1e-8, Qt_uu_)
            Qt_ux_ = copy.deepcopy(Qt_ux)
            index_qt_ux = F.repeat(F.expand_dims(index, axis=2), Qt_ux.shape[2], axis=2)
            Qt_ux_ = F.where(index_qt_ux, self.xp.zeros_like(Qt_ux_.array), Qt_ux)
            if self.n_ctrl == 1:
                Kt = -(1. / Qt_uu_) * Qt_ux_  # NOTE: different from the original.
                kt = -(1. / F.squeeze(Qt_uu_, axis=2)) * qt_u_
            else:
                Qt_uu_LU_ = batch_lu_factor(Qt_uu_)
                Kt = -batch_lu_solve(Qt_uu_LU_, Qt_ux_)
                kt = -batch_lu_solve(Qt_uu_LU_, qt_u_)
        assert list(Kt.shape) == [self.n_batch, self.n_ctrl, self.n_state], "Kt dim mismatch"
        assert list(kt.shape) == [self.n_batch, self.n_ctrl], "kt dim mismatch"

        Kt_T = F.transpose(Kt, axes=(0, 2, 1))
        Ks.append(Kt)
        ks.append(kt)

        Vt = Qt_xx + F.matmul(Qt_xu, Kt) + F.matmul(Kt_T, Qt_ux) + F.matmul(F.matmul(Kt_T, Qt_uu), Kt)
        vt = qt_x + bmv(Qt_xu, kt) + bmv(Kt_T, qt_u) + bmv(F.matmul(Kt_T, Qt_uu), kt)

    assert len(Ks) == self.T, "Ks length error"
    Ks.reverse()
    ks.reverse()
    return Ks, ks
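# A minimal usage sketch of the backward/forward recursion pair above,
# assuming the LqrRecursion constructor used elsewhere in this module
# (x_init, C, c, F, f, T, n_state, n_ctrl) and that u_zero_Index defaults to
# None; shapes and the helper name are illustrative.  The same result can be
# obtained with lqr.solve_recursion().
def _solve_lqr_example():
    import numpy as np
    np.random.seed(0)
    T, n_batch, n_state, n_ctrl = 5, 2, 3, 2
    n_sc = n_state + n_ctrl
    C = np.tile(np.eye(n_sc, dtype=np.float32), (T, n_batch, 1, 1))  # quadratic cost weights
    c = np.zeros((T, n_batch, n_sc), dtype=np.float32)               # linear cost terms
    F_dyn = np.random.randn(T - 1, n_batch, n_state, n_sc).astype(np.float32)
    f_dyn = np.zeros((T - 1, n_batch, n_state), dtype=np.float32)
    x_init = np.random.randn(n_batch, n_state).astype(np.float32)
    lqr = LqrRecursion(chainer.Variable(x_init), chainer.Variable(C),
                       chainer.Variable(c), chainer.Variable(F_dyn),
                       chainer.Variable(f_dyn), T, n_state, n_ctrl)
    Ks, ks = lqr.backward()     # gains from the backward recursion
    x, u = lqr.forward(Ks, ks)  # roll the trajectory forward
    return x, u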