Example #1
    def forward(self, x_init, C, c, F, f=None):
        if self.no_op_forward:
            self.save_for_backward(x_init, C, c, F, f, self.current_x,
                                   self.current_u)
            return self.current_x, self.current_u

        if self.delta_space:
            # Taylor-expand the objective to do the backward pass in
            # the delta space.
            assert self.current_x is not None
            assert self.current_u is not None
            c_back = []
            for t in range(self.T):
                xt = self.current_x[t]
                ut = self.current_u[t]
                xut = torch.cat((xt, ut), 1)
                c_back.append(util.bmv(C[t], xut) + c[t])
            c_back = torch.stack(c_back)
            f_back = None
        else:
            assert False

        Ks, ks, self.back_out = self.lqr_backward(C, c_back, F, f_back)
        new_x, new_u, self.for_out = self.lqr_forward(x_init, C, c, F, f, Ks,
                                                      ks)
        self.save_for_backward(x_init, C, c, F, f, new_x, new_u)

        return new_x, new_u
Example #2
def linearize_dynamics(x, u, dynamics):
    """linearize dynamics

    :param x:time batch n_state
    :param u:
    :param dynamics:
    :return:
    """
    assert x.shape[0] == u.shape[0]
    assert x.shape[1] == u.shape[1]
    n_state = x.shape[2]
    T = x.shape[0]
    x_init = x[0]
    x_ar = [x_init]
    # NOTE: the linearization is taken along the trajectory re-rolled with
    # `dynamics` (x_ar), not along the input states x.
    large_F, f = [], []
    for t in range(T):
        if t < T - 1:
            xt = x_ar[t]
            ut = u[t]
            #  print("x_ut.shape", xut.shape)
            new_x = dynamics(xt, ut)
            # Linear dynamics approximation.
            Rt, St = [], []
            for j in range(n_state):
                Rj, Sj = chainer.grad([F.sum(new_x[:, j])], [xt, ut])
                Rt.append(Rj)
                St.append(Sj)
                assert Sj is not None
            Rt = F.stack(Rt, axis=1)
            St = F.stack(St, axis=1)
            #  print("Rt shape", Rt.shape)
            #  print("St shape", St.shape)
            Ft = F.concat((Rt, St), axis=2)
            large_F.append(Ft)
            ft = new_x - bmv(Rt, xt) - bmv(St, ut)
            f.append(ft)
            x_ar.append(new_x)

    large_F = F.stack(large_F, 0)
    f = F.stack(f, 0)
    return large_F, f
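For reference, the pair returned above defines the time-varying affine model that the LQR recursion consumes. Reading it directly off the code, with R_t and S_t the Jacobians of `dynamics` with respect to x_t and u_t:

    F_t = [R_t  S_t],    f_t = dynamics(x_t, u_t) - R_t x_t - S_t u_t,
    x_{t+1} \approx F_t [x_t; u_t] + f_t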
Example #3
    def forward(self, Ks, ks):
        """ LQR forward recursion

        :param Ks: solved in backward recursion
        :param ks: solved in forward recursion
        :return: x, u
        """
        assert len(Ks) == self.T, "Ks length error"
        new_x = [self.x_init]
        new_u = []
        for t in range(self.T):
            Kt = Ks[t]
            kt = ks[t]
            xt = new_x[t]
            assert list(xt.shape) == [self.n_batch, self.n_state], str(xt.shape) + \
                                                                   " xt dim mismatch: expected" + str(
                [self.n_batch, self.n_state])
            new_ut = bmv(Kt, xt) + kt
            assert list(new_ut.shape) == [self.n_batch, self.n_ctrl], "actual:" + str(new_ut.shape)
            if self.u_zero_Index is not None:
                assert self.u_zero_Index[t].shape == new_ut.shape, str(self.u_zero_Index[t].shape) + " : " + str(
                    new_ut.shape)
                new_ut = F.where(self.u_zero_Index[t], self.xp.zeros_like(new_ut.array), new_ut)
                assert list(new_ut.shape) == [self.n_batch, self.n_ctrl], "actual:" + str(new_ut.shape)
            new_u.append(new_ut)
            if t < self.T - 1:
                assert list(xt.shape) == [self.n_batch, self.n_state], "actual:" + str(xt.shape)
                assert list(new_u[t].shape) == [self.n_batch, self.n_ctrl], "actual:" + str(new_u[t].shape)
                xu_t = F.concat((xt, new_u[t]), axis=1)
                x = bmv(self.F[t], xu_t)
                if self.f is not None:
                    x += self.f[t]
                assert list(x.shape) == [self.n_batch, self.n_state], str(x.shape) + \
                                                                      " x dim mismatch: expected" + str(
                    [self.n_batch, self.n_state])
                new_x.append(x)
        new_x = F.stack(new_x, axis=0)
        new_u = F.stack(new_u, axis=0)
        assert list(new_x.shape) == [self.T, self.n_batch, self.n_state], str(new_x.shape) + " new x dim mismatch"
        assert list(new_u.shape) == [self.T, self.n_batch, self.n_ctrl], "new u dim mismatch"
        return new_x, new_u
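Written out, the rollout above applies the affine policy from the backward recursion and then steps the affine dynamics; reading directly off the code (controls selected by u_zero_Index are forced to zero when that mask is given):

    u_t = K_t x_t + k_t
    x_{t+1} = F_t [x_t; u_t] + f_t    (the f_t term is skipped when self.f is None)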
Example #4
    def approximate_cost(self, x, u, Cf, diff=True):
        with torch.enable_grad():
            tau = torch.cat((x, u), dim=2).data
            tau = Variable(tau, requires_grad=True)
            if self.slew_rate_penalty is not None:
                print("""
MPC Error: Using a non-convex cost with a slew rate penalty is not yet implemented.
The current implementation does not correctly do a line search.
More details: https://github.com/locuslab/mpc.pytorch/issues/12
""")
                sys.exit(-1)
                differences = tau[1:, :, -self.n_ctrl:] - tau[:-1, :,
                                                              -self.n_ctrl:]
                slew_penalty = (self.slew_rate_penalty *
                                differences.pow(2)).sum(-1)
            costs = list()
            hessians = list()
            grads = list()
            for t in range(self.T):
                tau_t = tau[t]
                if self.slew_rate_penalty is not None:
                    cost = Cf(tau_t) + (slew_penalty[t - 1] if t > 0 else 0)
                else:
                    cost = Cf(tau_t)

                grad = torch.autograd.grad(cost.sum(),
                                           tau_t,
                                           retain_graph=True)[0]
                hessian = list()
                for v_i in range(tau.shape[2]):
                    hessian.append(
                        torch.autograd.grad(grad[:, v_i].sum(),
                                            tau_t,
                                            retain_graph=True)[0])
                hessian = torch.stack(hessian, dim=-1)
                costs.append(cost)
                grads.append(grad - util.bmv(hessian, tau_t))
                hessians.append(hessian)
            costs = torch.stack(costs, dim=0)
            grads = torch.stack(grads, dim=0)
            hessians = torch.stack(hessians, dim=0)
            if not diff:
                return hessians.data, grads.data, costs.data
            return hessians, grads, costs
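The triple returned above encodes a quadratic model of the cost expressed about tau = 0, which is why the gradient is shifted by the Hessian term before being stored; reading it off the code, with H_t and g_t the Hessian and gradient of Cf at the current tau_t:

    Cf(tau) \approx 0.5 tau^T H_t tau + (g_t - H_t tau_t)^T tau + const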
Example #5
def test_dynamics():
    import numpy as np
    np.random.seed(0)
    batch = 2
    time = 3
    n_state = 1
    n_ctrl = 1
    n_sc = n_state + n_ctrl
    x = np.random.randn(time, batch, n_state)
    u = np.random.randn(time, batch, n_ctrl)
    x = chainer.Variable(x)
    u = chainer.Variable(u)
    A = np.random.randn(batch, n_state, n_sc)
    B = np.random.randn(batch, n_state)
    A = chainer.Variable(A)
    B = chainer.Variable(B)
    dynamics = lambda s, c: bmv(A, F.concat((s, c), axis=1)) + B
    large_F, f = linearize_dynamics(x, u, dynamics)
    print("A", A)
    print("large_F", large_F)
    print("B", B)
    print("f", f)
Example #6
def approximate_cost(x, u, Cf):
    """ approximate cost function at point(x, u)

    :param x: time batch n_state
    :param u: time batch n_ctrl
    :param Cf:Cost Function need map vector to scalar
    :return: hessian, grads, costs
    """
    assert x.shape[0] == u.shape[0]
    assert x.shape[1] == u.shape[1]
    T = x.shape[0]
    tau = F.concat((x, u), axis=2)
    costs = []
    hessians = []
    grads = []
    # loop over time steps
    for t in range(T):
        tau_t = tau[t]
        cost = Cf(tau_t)  # value of cost function at tau
        assert list(cost.shape) == [x.shape[1]]
        # print("cost.shape", cost.shape)
        grad = chainer.grad([F.sum(cost)], [tau_t], enable_double_backprop=True)[0]  # need hessian
        hessian = []
        # differentiate each gradient component again to build the Hessian column by column
        for v_i in range(tau.shape[2]):
            # n_sc
            grad_line = F.sum(grad[:, v_i])
            hessian.append(chainer.grad([grad_line], [tau_t])[0])
        hessian = F.stack(hessian, axis=-1)
        costs.append(cost)
        # Shift the gradient so the linear term of the quadratic model is expressed about tau = 0.
        grads.append(grad - bmv(hessian, tau_t))
        hessians.append(hessian)
    costs = F.stack(costs)
    grads = F.stack(grads)
    hessians = F.stack(hessians)
    return hessians, grads, costs
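A minimal sanity check for the function above, added here as an illustration (the name test_approximate_cost is ours; it assumes approximate_cost and its bmv helper are importable from this module, as in the other examples). The cost below is already quadratic, so every returned Hessian should be 2*I and the shifted gradients should vanish:

def test_approximate_cost():
    import numpy as np
    T, n_batch, n_state, n_ctrl = 3, 2, 2, 1
    x = chainer.Variable(np.random.randn(T, n_batch, n_state).astype(np.float32))
    u = chainer.Variable(np.random.randn(T, n_batch, n_ctrl).astype(np.float32))
    # Cf maps an (n_batch, n_state + n_ctrl) slice to one scalar per batch element.
    Cf = lambda tau: F.sum(tau ** 2, axis=1)
    hessians, grads, costs = approximate_cost(x, u, Cf)
    # Hessian of sum(tau**2) is 2*I and the linear term about tau = 0 is zero.
    assert np.allclose(hessians.array, 2 * np.eye(n_state + n_ctrl), atol=1e-5)
    assert np.allclose(grads.array, 0, atol=1e-5)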
Example #7
    def backward(self, target_input_indexes, grad_outputs):
        """ Backward Pass

        :param target_input_indexes:
        :param grad_outputs:
        :return:
        """
        # Forward 2: compute the dual variables with a backward recursion, Equation (7) in [1]
        x_init, C, c, large_f = self.get_retained_inputs()
        C_Tx = C[self.T - 1, :, :self.n_state, :]
        c_Tx = c[self.T - 1, :, :self.n_state]
        assert list(C_Tx.shape) == [self.n_batch, self.n_state, self.n_sc]
        x, u = self.get_retained_outputs()
        taus = F.concat((x, u), axis=2).reshape(self.T, self.n_batch,
                                                self.n_sc)
        Lambda_T = bmv(C_Tx, taus[self.T - 1]) + c_Tx
        Lambdas = [Lambda_T]
        # backward recursion over the dual variables
        for i in range(self.T - 2, -1, -1):
            Lambda_tp1 = Lambdas[self.T - 2 - i]
            tau_t = taus[i]
            F_t = large_f[i]
            F_tx_T = F.transpose(F_t[:, :self.n_state, :self.n_state],
                                 axes=(0, 2, 1))
            C_tx = C[i][:, :self.n_state, :]
            c_tx = c[i][:, :self.n_state]
            Lambda_t = bmv(F_tx_T, Lambda_tp1) + bmv(C_tx, tau_t) + c_tx
            Lambdas.append(Lambda_t)
        Lambdas.reverse()
        # Backward 1
        grad_x, grad_u = grad_outputs
        xp = chainer.backend.get_array_module(*grad_outputs)
        zero_init = chainer.Variable(xp.zeros_like(x_init.data))
        zero_f = chainer.Variable(
            xp.zeros((self.T - 1, self.n_batch, self.n_state)))
        drl = F.concat((grad_x, grad_u),
                       axis=2).reshape(self.T, self.n_batch, self.n_sc)
        lqr = LqrRecursion(zero_init, C, drl, large_f, zero_f, self.T,
                           self.n_state, self.n_ctrl)
        dx, du = lqr.solve_recursion()
        # Backward 2: compute the dual variable perturbations d_lambda
        d_taus = F.concat((dx, du), axis=2).reshape(self.T, self.n_batch,
                                                    self.n_sc)
        d_lambda_T = bmv(
            C_Tx, d_taus[self.T - 1]) + drl[self.T - 1][:, :self.n_state]
        d_lambdas = [d_lambda_T]
        for i in range(self.T - 2, -1, -1):
            d_lambda_tp1 = d_lambdas[self.T - 2 - i]
            d_tau_t = d_taus[i]
            F_t = large_f[i]
            F_tx_T = F.transpose(F_t[:, :self.n_state, :self.n_state],
                                 axes=(0, 2, 1))
            C_tx = C[i][:, :self.n_state, :]
            d_rl = drl[i][:, :self.n_state]
            d_lambda_t = bmv(F_tx_T, d_lambda_tp1) + bmv(C_tx, d_tau_t) + d_rl
            d_lambdas.append(d_lambda_t)
        d_lambdas.reverse()
        # Backward 3: compute the parameter derivatives, Equation (8) in [1]
        dC = F.stack([
            0.5 * (bger(d_taus[t], taus[t]) + bger(taus[t], d_taus[t]))
            for t in range(self.T)
        ],
                     axis=0)
        dc = F.stack([d_taus[t] for t in range(self.T)], axis=0)
        dF = F.stack([
            bger(d_lambdas[t + 1], taus[t]).data +
            bger(Lambdas[t + 1], d_taus[t]).data for t in range(self.T - 1)
        ],
                     axis=0)
        df = F.stack(d_lambdas[:self.T - 1], axis=0)
        d_x_init = d_lambdas[0]
        '''
        print(d_x_init.shape)
        print(dC.shape)
        print(dc.shape)
        print(dF.shape)
        print(df.shape)
        '''
        return d_x_init, dC, dc, dF, df
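For reference, the dual recursion that the first loop above implements, read directly off the code (C_t^x and c_t^x are the state rows of C_t and c_t, and F_t^{xx} is the state-to-state block of F_t):

    \lambda_{T-1} = C_{T-1}^x \tau_{T-1} + c_{T-1}^x
    \lambda_t = (F_t^{xx})^T \lambda_{t+1} + C_t^x \tau_t + c_t^x,    t = T-2, ..., 0

The second loop runs the same recursion on the perturbed quantities, with tau replaced by d_tau and c replaced by the incoming gradient drl, to obtain the d_lambdas that feed the derivative expressions dC, dc, dF, df.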
Example #8
    def lqr_forward(self, x_init, C, c, F, f, Ks, ks):
        x = self.current_x
        u = self.current_u
        n_batch = C.size(1)

        old_cost = util.get_cost(self.T,
                                 u,
                                 self.true_cost,
                                 self.true_dynamics,
                                 x=x)

        current_cost = None
        alphas = torch.ones(n_batch).type_as(C)
        full_du_norm = None

        i = 0
        while (current_cost is None or \
               (old_cost is not None and \
                torch.any((current_cost > old_cost)).cpu().item() == 1)) and \
                i < self.max_linesearch_iter:
            new_u = []
            new_x = [x_init]
            dx = [torch.zeros_like(x_init)]
            objs = []
            for t in range(self.T):
                t_rev = self.T - 1 - t
                Kt = Ks[t_rev]
                kt = ks[t_rev]
                new_xt = new_x[t]
                xt = x[t]
                ut = u[t]
                dxt = dx[t]
                new_ut = util.bmv(Kt, dxt) + ut + torch.diag(alphas).mm(kt)

                # Currently unimplemented:
                assert not ((self.delta_u is not None) and
                            (self.u_lower is None))

                if self.u_zero_I is not None:
                    new_ut[self.u_zero_I[t]] = 0.

                if self.u_lower is not None:
                    lb = self.get_bound('lower', t)
                    ub = self.get_bound('upper', t)

                    if self.delta_u is not None:
                        lb_limit, ub_limit = lb, ub
                        lb = u[t] - self.delta_u
                        ub = u[t] + self.delta_u
                        I = lb < lb_limit
                        lb[I] = lb_limit if isinstance(lb_limit,
                                                       float) else lb_limit[I]
                        I = ub > ub_limit
                        ub[I] = ub_limit if isinstance(ub_limit,
                                                       float) else ub_limit[I]
                    new_ut = util.eclamp(new_ut, lb, ub)
                new_u.append(new_ut)

                new_xut = torch.cat((new_xt, new_ut), dim=1)
                if t < self.T - 1:
                    if isinstance(self.true_dynamics, mpc.LinDx):
                        F, f = self.true_dynamics.F, self.true_dynamics.f
                        new_xtp1 = util.bmv(F[t], new_xut)
                        if f is not None and f.nelement() > 0:
                            new_xtp1 += f[t]
                    else:
                        new_xtp1 = self.true_dynamics(Variable(new_xt),
                                                      Variable(new_ut)).data

                    new_x.append(new_xtp1)
                    dx.append(new_xtp1 - x[t + 1])

                if isinstance(self.true_cost, mpc.QuadCost):
                    C, c = self.true_cost.C, self.true_cost.c
                    obj = 0.5 * util.bquad(new_xut, C[t]) + util.bdot(
                        new_xut, c[t])
                else:
                    obj = self.true_cost(new_xut)

                objs.append(obj)

            objs = torch.stack(objs)
            current_cost = torch.sum(objs, dim=0)

            new_u = torch.stack(new_u)
            new_x = torch.stack(new_x)
            if full_du_norm is None:
                full_du_norm = (u - new_u).transpose(1, 2).contiguous().view(
                    n_batch, -1).norm(2, 1)

            alphas[current_cost > old_cost] *= self.linesearch_decay
            i += 1

        # If the iteration limit is hit, some alphas
        # are one step too small.
        alphas[current_cost > old_cost] /= self.linesearch_decay
        alpha_du_norm = (u - new_u).transpose(1, 2).contiguous().view(
            n_batch, -1).norm(2, 1)

        return new_x, new_u, LqrForOut(objs, full_du_norm, alpha_du_norm,
                                       torch.mean(alphas), current_cost)
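The control update inside the line search above can be written per batch element b with step size alpha_b; reading it off the code:

    \delta x_t = x_t^{new} - x_t
    u_t^{new} = u_t + K_t \delta x_t + \alpha_b k_t
    x_{t+1}^{new} = F_t [x_t^{new}; u_t^{new}] + f_t

alpha_b is decayed by linesearch_decay until that batch element's total cost no longer exceeds old_cost, or max_linesearch_iter is reached.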
Example #9
    def lqr_backward(self, C, c, F, f):
        n_batch = C.size(1)

        u = self.current_u
        Ks = []
        ks = []
        prev_kt = None
        n_total_qp_iter = 0
        Vtp1 = vtp1 = None
        for t in range(self.T - 1, -1, -1):
            if t == self.T - 1:
                Qt = C[t]
                qt = c[t]
            else:
                Ft = F[t]
                Ft_T = Ft.transpose(1, 2)
                Qt = C[t] + Ft_T.bmm(Vtp1).bmm(Ft)
                if f is None or f.nelement() == 0:
                    qt = c[t] + Ft_T.bmm(vtp1.unsqueeze(2)).squeeze(2)
                else:
                    ft = f[t]
                    qt = c[t] + Ft_T.bmm(Vtp1).bmm(ft.unsqueeze(2)).squeeze(2) + \
                         Ft_T.bmm(vtp1.unsqueeze(2)).squeeze(2)

            n_state = self.n_state
            Qt_xx = Qt[:, :n_state, :n_state]
            Qt_xu = Qt[:, :n_state, n_state:]
            Qt_ux = Qt[:, n_state:, :n_state]
            Qt_uu = Qt[:, n_state:, n_state:]
            qt_x = qt[:, :n_state]
            qt_u = qt[:, n_state:]

            if self.u_lower is None:
                if self.n_ctrl == 1 and self.u_zero_I is None:
                    Kt = -(1. / Qt_uu) * Qt_ux
                    kt = -(1. / Qt_uu.squeeze(2)) * qt_u
                else:
                    if self.u_zero_I is None:
                        Qt_uu_inv = [
                            torch.pinverse(Qt_uu[i])
                            for i in range(Qt_uu.shape[0])
                        ]
                        Qt_uu_inv = torch.stack(Qt_uu_inv)
                        Kt = -Qt_uu_inv.bmm(Qt_ux)
                        kt = util.bmv(-Qt_uu_inv, qt_u)

                        # Qt_uu_LU = Qt_uu.btrifact()
                        # Kt = -Qt_ux.btrisolve(*Qt_uu_LU)
                        # kt = -qt_u.btrisolve(*Qt_uu_LU)
                    else:
                        # Solve with zero constraints on the active controls.
                        I = self.u_zero_I[t]
                        notI = 1 - I

                        qt_u_ = qt_u.clone()
                        qt_u_[I] = 0

                        Qt_uu_ = Qt_uu.clone()

                        if I.is_cuda:
                            notI_ = notI.float()
                            Qt_uu_I = (1 - util.bger(notI_, notI_)).type_as(I)
                        else:
                            Qt_uu_I = 1 - util.bger(notI, notI)

                        Qt_uu_[Qt_uu_I] = 0.
                        Qt_uu_[util.bdiag(I)] += 1e-8

                        Qt_ux_ = Qt_ux.clone()
                        Qt_ux_[I.unsqueeze(2).repeat(1, 1, Qt_ux.size(2))] = 0.

                        if self.n_ctrl == 1:
                            Kt = -(1. / Qt_uu_) * Qt_ux_
                            kt = -(1. / Qt_uu_.squeeze(2)) * qt_u_
                        else:
                            Qt_uu_LU_ = Qt_uu_.btrifact()
                            Kt = -Qt_ux_.btrisolve(*Qt_uu_LU_)
                            kt = -qt_u_.btrisolve(*Qt_uu_LU_)
            else:
                assert self.delta_space
                lb = self.get_bound('lower', t) - u[t]
                ub = self.get_bound('upper', t) - u[t]
                if self.delta_u is not None:
                    lb[lb < -self.delta_u] = -self.delta_u
                    ub[ub > self.delta_u] = self.delta_u
                kt, Qt_uu_free_LU, If, n_qp_iter = pnqp(Qt_uu,
                                                        qt_u,
                                                        lb,
                                                        ub,
                                                        x_init=prev_kt,
                                                        n_iter=20)
                if self.verbose > 1:
                    print('  + n_qp_iter: ', n_qp_iter + 1)
                n_total_qp_iter += 1 + n_qp_iter
                prev_kt = kt
                Qt_ux_ = Qt_ux.clone()
                Qt_ux_[(1 - If).unsqueeze(2).repeat(1, 1, Qt_ux.size(2))] = 0
                if self.n_ctrl == 1:
                    # Bad naming, Qt_uu_free_LU isn't the LU in this case.
                    Kt = -((1. / Qt_uu_free_LU) * Qt_ux_)
                else:
                    Kt = -Qt_ux_.btrisolve(*Qt_uu_free_LU)

            Kt_T = Kt.transpose(1, 2)

            Ks.append(Kt)
            ks.append(kt)

            Vtp1 = Qt_xx + Qt_xu.bmm(Kt) + Kt_T.bmm(Qt_ux) + Kt_T.bmm(
                Qt_uu).bmm(Kt)
            vtp1 = qt_x + Qt_xu.bmm(kt.unsqueeze(2)).squeeze(2) + \
                   Kt_T.bmm(qt_u.unsqueeze(2)).squeeze(2) + \
                   Kt_T.bmm(Qt_uu).bmm(kt.unsqueeze(2)).squeeze(2)

        return Ks, ks, LqrBackOut(n_total_qp_iter=n_total_qp_iter)
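In the unconstrained case the loop above is the standard LQR (Riccati) backward recursion; reading it off the code, with Q_T = C_T and q_T = c_T at the final step:

    Q_t = C_t + F_t^T V_{t+1} F_t,    q_t = c_t + F_t^T V_{t+1} f_t + F_t^T v_{t+1}
    K_t = -Q_{t,uu}^{-1} Q_{t,ux},    k_t = -Q_{t,uu}^{-1} q_{t,u}
    V_t = Q_{t,xx} + Q_{t,xu} K_t + K_t^T Q_{t,ux} + K_t^T Q_{t,uu} K_t
    v_t = q_{t,x} + Q_{t,xu} k_t + K_t^T q_{t,u} + K_t^T Q_{t,uu} k_t

The box-constrained branch replaces the Q_{uu} solve with the pnqp projected-Newton QP and zeroes the rows of Q_{ux} for the clamped controls.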
Example #10
    def backward(self, dl_dx, dl_du):
        start = time.time()
        x_init, C, c, F, f, new_x, new_u = self.saved_tensors

        r = []
        for t in range(self.T):
            rt = torch.cat((dl_dx[t], dl_du[t]), 1)
            r.append(rt)
        r = torch.stack(r)

        if self.u_lower is None:
            I = None
        else:
            I = (torch.abs(new_u - self.u_lower) <= 1e-8) | \
                (torch.abs(new_u - self.u_upper) <= 1e-8)
        dx_init = Variable(torch.zeros_like(x_init))
        _mpc = mpc.MPC(
            self.n_state,
            self.n_ctrl,
            self.T,
            u_zero_I=I,
            u_init=None,
            lqr_iter=1,
            verbose=-1,
            n_batch=C.size(1),
            delta_u=None,
            # exit_unconverged=True, # It's really bad if this doesn't converge.
            exit_unconverged=False,  # It's really bad if this doesn't converge.
            eps=self.back_eps,
        )
        dx, du, _ = _mpc(dx_init, mpc.QuadCost(C, -r), mpc.LinDx(F, None))

        dx, du = dx.data, du.data
        dxu = torch.cat((dx, du), 2)
        xu = torch.cat((new_x, new_u), 2)

        dC = torch.zeros_like(C)
        for t in range(self.T):
            xut = torch.cat((new_x[t], new_u[t]), 1)
            dxut = dxu[t]
            dCt = -0.5 * (util.bger(dxut, xut) + util.bger(xut, dxut))
            dC[t] = dCt

        dc = -dxu

        lams = []
        prev_lam = None
        for t in range(self.T - 1, -1, -1):
            Ct_xx = C[t, :, :self.n_state, :self.n_state]
            Ct_xu = C[t, :, :self.n_state, self.n_state:]
            ct_x = c[t, :, :self.n_state]
            xt = new_x[t]
            ut = new_u[t]
            lamt = util.bmv(Ct_xx, xt) + util.bmv(Ct_xu, ut) + ct_x
            if prev_lam is not None:
                Fxt = F[t, :, :, :self.n_state].transpose(1, 2)
                lamt += util.bmv(Fxt, prev_lam)
            lams.append(lamt)
            prev_lam = lamt
        lams = list(reversed(lams))

        dlams = []
        prev_dlam = None
        for t in range(self.T - 1, -1, -1):
            dCt_xx = C[t, :, :self.n_state, :self.n_state]
            dCt_xu = C[t, :, :self.n_state, self.n_state:]
            drt_x = -r[t, :, :self.n_state]
            dxt = dx[t]
            dut = du[t]
            dlamt = util.bmv(dCt_xx, dxt) + util.bmv(dCt_xu, dut) + drt_x
            if prev_dlam is not None:
                Fxt = F[t, :, :, :self.n_state].transpose(1, 2)
                dlamt += util.bmv(Fxt, prev_dlam)
            dlams.append(dlamt)
            prev_dlam = dlamt
        dlams = torch.stack(list(reversed(dlams)))

        dF = torch.zeros_like(F)
        for t in range(self.T - 1):
            xut = xu[t]
            lamt = lams[t + 1]

            dxut = dxu[t]
            dlamt = dlams[t + 1]

            dF[t] = -(util.bger(dlamt, xut) + util.bger(lamt, dxut))

        if f.nelement() > 0:
            _dlams = dlams[1:]
            assert _dlams.shape == f.shape
            df = -_dlams
        else:
            df = torch.Tensor()

        dx_init = -dlams[0]

        self.backward_time = time.time() - start
        return dx_init, dC, dc, dF, df
Example #11
def pnqp(H, q, lower, upper, x_init=None, n_iter=20):
    GAMMA = 0.1
    n_batch, n, _ = H.size()
    pnqp_I = 1e-11 * torch.eye(n).type_as(H).expand_as(H)

    def obj(x):
        return 0.5 * util.bquad(x, H) + util.bdot(q, x)

    if x_init is None:
        if n == 1:
            x_init = -(1. / H.squeeze(2)) * q
        else:
            # H_lu = H.btrifact()  # XXX deprecated!!!
            H_lu = torch.lu(H)
            x_init = -q.btrisolve(H_lu[0],
                                  H_lu[1])  # Clamped in the x assignment.
    else:
        x_init = x_init.clone()  # Don't over-write the original x_init.

    x = util.eclamp(x_init, lower, upper)

    # Active examples in the batch.
    J = torch.ones(n_batch).type_as(x).byte()

    for i in range(n_iter):
        g = util.bmv(H, x) + q
        Ic = ((x == lower) & (g > 0)) | ((x == upper) & (g < 0))
        If = 1 - Ic

        if If.is_cuda:
            Hff_I = util.bger(If.float(), If.float()).type_as(If)
            not_Hff_I = 1 - Hff_I
            Hfc_I = util.bger(If.float(), Ic.float()).type_as(If)
        else:
            Hff_I = util.bger(If, If)
            not_Hff_I = 1 - Hff_I
            Hfc_I = util.bger(If, Ic)

        g_ = g.clone()
        g_[Ic] = 0.
        H_ = H.clone()
        H_[not_Hff_I] = 0.0
        H_ += pnqp_I

        if n == 1:
            dx = -(1. / H_.squeeze(2)) * g_
        else:
            # H_lu_ = H_.btrifact()  # XXX deprecated!!!
            H_lu_ = torch.lu(H_)
            dx = -g_.btrisolve(*H_lu_)

        J = torch.norm(dx, 2, 1) >= 1e-4
        m = J.sum().item()  # Number of active examples in the batch.
        if m == 0:
            return x, H_ if n == 1 else H_lu_, If, i

        alpha = torch.ones(n_batch).type_as(x)
        decay = 0.1
        max_armijo = GAMMA
        count = 0
        while max_armijo <= GAMMA and count < 10:
            # Crude way of making sure too much time isn't being spent
            # doing the line search.
            # assert count < 10

            maybe_x = util.eclamp(x + torch.diag(alpha).mm(dx), lower, upper)
            armijos = (GAMMA + 1e-6) * torch.ones(n_batch).type_as(x)
            armijos[J] = (obj(x) - obj(maybe_x))[J] / util.bdot(
                g, x - maybe_x)[J]
            I = armijos <= GAMMA
            alpha[I] *= decay
            max_armijo = torch.max(armijos)
            count += 1

        x = maybe_x

    # TODO: Maybe change this to a warning.
    print("[WARNING] pnqp warning: Did not converge")
    return x, H_ if n == 1 else H_lu_, If, i
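A small usage sketch, added as an illustration. It assumes pnqp and the util helpers are importable from this module and that the legacy PyTorch version this snippet targets is installed (comparison operators returning byte tensors, Tensor.btrisolve available). With H = I and box bounds [-1, 1] the problem is separable, so the minimizer of 0.5*x'Hx + q'x is just the unconstrained solution clamped to the box:

import torch

n_batch, n = 4, 3
H = torch.eye(n).unsqueeze(0).repeat(n_batch, 1, 1)
q = 3. * torch.randn(n_batch, n)
lower, upper = -torch.ones(n_batch, n), torch.ones(n_batch, n)
x, _, If, n_iter = pnqp(H, q, lower, upper, n_iter=20)
# For this separable problem x should match torch.clamp(-q, -1., 1.).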
Example #12
    def linearize_dynamics(self, x, u, dynamics, diff):
        # TODO: Cleanup variable usage.

        n_batch = x[0].size(0)

        if self.grad_method == GradMethods.ANALYTIC:
            _u = Variable(u[:-1].view(-1, self.n_ctrl), requires_grad=True)
            _x = Variable(x[:-1].contiguous().view(-1, self.n_state),
                          requires_grad=True)

            # This inefficiently calls dynamics again, but is worth it because
            # we can efficiently compute grad_input for every time step at once.
            _new_x = dynamics(_x, _u)

            # This check is a little expensive and should only be done if
            # modifying this code.
            # assert torch.abs(_new_x.data - torch.cat(x[1:])).max() <= 1e-6

            if not diff:
                _new_x = _new_x.data
                _x = _x.data
                _u = _u.data

            R, S = dynamics.grad_input(_x, _u)

            f = _new_x - util.bmv(R, _x) - util.bmv(S, _u)
            f = f.view(self.T - 1, n_batch, self.n_state)

            R = R.contiguous().view(self.T - 1, n_batch, self.n_state,
                                    self.n_state)
            S = S.contiguous().view(self.T - 1, n_batch, self.n_state,
                                    self.n_ctrl)
            F = torch.cat((R, S), 3)

            if not diff:
                F, f = list(map(Variable, [F, f]))
            return F, f
        else:
            # TODO: This is inefficient and confusing.
            x_init = x[0]
            x = [x_init]
            F, f = [], []
            for t in range(self.T):
                if t < self.T - 1:
                    xt = Variable(x[t], requires_grad=True)
                    ut = Variable(u[t], requires_grad=True)
                    xut = torch.cat((xt, ut), 1)
                    new_x = dynamics(xt, ut)

                    # Linear dynamics approximation.
                    if self.grad_method in [
                            GradMethods.AUTO_DIFF, GradMethods.ANALYTIC_CHECK
                    ]:
                        Rt, St = [], []
                        for j in range(self.n_state):
                            Rj, Sj = torch.autograd.grad(new_x[:, j].sum(),
                                                         [xt, ut],
                                                         retain_graph=True)
                            if not diff:
                                Rj, Sj = Rj.data, Sj.data
                            Rt.append(Rj)
                            St.append(Sj)
                        Rt = torch.stack(Rt, dim=1)
                        St = torch.stack(St, dim=1)

                        if self.grad_method == GradMethods.ANALYTIC_CHECK:
                            assert False  # Not updated
                            Rt_autograd, St_autograd = Rt, St
                            Rt, St = dynamics.grad_input(xt, ut)
                            eps = 1e-8
                            if torch.max(torch.abs(Rt - Rt_autograd)).data[0] > eps or \
                                    torch.max(torch.abs(St - St_autograd)).data[0] > eps:
                                print('''
        nmpc.ANALYTIC_CHECK error: The analytic derivative of the dynamics function may be off.
                                ''')
                            else:
                                print('''
        nmpc.ANALYTIC_CHECK: The analytic derivative of the dynamics function seems correct.
        Re-run with GradMethods.ANALYTIC to continue.
                                ''')
                            sys.exit(0)
                    elif self.grad_method == GradMethods.FINITE_DIFF:
                        Rt, St = [], []
                        for i in range(n_batch):
                            Ri = util.jacobian(lambda s: dynamics(s, ut[i]),
                                               xt[i], 1e-4)
                            Si = util.jacobian(lambda a: dynamics(xt[i], a),
                                               ut[i], 1e-4)
                            if not diff:
                                Ri, Si = Ri.data, Si.data
                            Rt.append(Ri)
                            St.append(Si)
                        Rt = torch.stack(Rt)
                        St = torch.stack(St)
                    else:
                        assert False

                    Ft = torch.cat((Rt, St), 2)
                    F.append(Ft)

                    if not diff:
                        xt, ut, new_x = xt.data, ut.data, new_x.data
                    ft = new_x - util.bmv(Rt, xt) - util.bmv(St, ut)
                    f.append(ft)

                if t < self.T - 1:
                    x.append(util.detach_maybe(new_x))

            F = torch.stack(F, 0)
            f = torch.stack(f, 0)
            if not diff:
                F, f = list(map(Variable, [F, f]))
            return F, f
Example #13
    def backward(self):
        """ LQR backward recursion
        Note: Ks ks is reversed version fo original
        :return: Ks, ks gain
        """
        Ks = []
        ks = []
        Vt = None
        vt = None
        # iterate backward from t = T-1 to 0
        for t in range(self.T - 1, -1, -1):
            # terminal time step
            if t == self.T - 1:
                Qt = self.C[t]
                qt = self.c[t]
            else:
                Ft = self.F[t]
                Ft_T = F.transpose(Ft, axes=(0, 2, 1))
                assert Ft.dtype.kind == 'f', "Ft dtype"
                assert Vt.dtype.kind == 'f', "Vt dtype"
                Qt = self.C[t] + F.matmul(F.matmul(Ft_T, Vt), Ft)
                if self.f is None:
                    # NOTE: the PyTorch original also treats an empty f (f.nelement() == 0) as None.
                    qt = self.c[t] + bmv(Ft_T, vt)
                else:
                    # f is not none
                    ft = self.f[t]
                    qt = self.c[t] + bmv(F.matmul(Ft_T, Vt), ft) + bmv(Ft_T, vt)
            assert list(qt.shape) == [self.n_batch, self.n_sc], "qt dim mismatch"
            assert list(Qt.shape) == [self.n_batch, self.n_sc, self.n_sc], str(Qt.shape) + " Qt dim mismatch"
            Qt_xx = Qt[:, :self.n_state, :self.n_state]
            Qt_xu = Qt[:, :self.n_state, self.n_state:]
            Qt_ux = Qt[:, self.n_state:, :self.n_state]
            Qt_uu = Qt[:, self.n_state:, self.n_state:]
            qt_x = qt[:, :self.n_state]
            qt_u = qt[:, self.n_state:]
            assert list(Qt_uu.shape) == [self.n_batch, self.n_ctrl, self.n_ctrl], "Qt_uu dim mismatch"
            assert list(Qt_ux.shape) == [self.n_batch, self.n_ctrl, self.n_state], "Qt_ux dim mismatch"
            assert list(Qt_xu.shape) == [self.n_batch, self.n_state, self.n_ctrl], "Qt_xu dim mismatch"
            assert list(qt_x.shape) == [self.n_batch, self.n_state], "qt_x dim mismatch"
            assert list(qt_u.shape) == [self.n_batch, self.n_ctrl], "qt_u dim mismatch"
            # Next calculate Kt and kt
            # TODO LU decomposition
            if self.n_ctrl == 1 and self.u_zero_Index is None:
                # scalar
                Kt = - (1. / Qt_uu) * Qt_ux
                kt = - (1. / F.squeeze(Qt_uu, axis=2)) * qt_u
            elif self.u_zero_Index is None:
                # matrix
                Qt_uu_inv = F.batch_inv(Qt_uu)
                Kt = - F.matmul(Qt_uu_inv, Qt_ux)
                kt = - bmv(Qt_uu_inv, qt_u)
            else:
                # u_zero_index is not none
                index = self.u_zero_Index[t]
                qt_u_ = copy.deepcopy(qt_u)
                qt_u_ = F.where(index, self.xp.zeros_like(qt_u_.array), qt_u_)
                Qt_uu_ = copy.deepcopy(Qt_uu)
                notI = 1.0 - F.cast(index, qt_u_.dtype)
                Qt_uu_I = 1 - bger(notI, notI)
                Qt_uu_I = F.cast(Qt_uu_I, 'bool')
                Qt_uu_ = F.where(Qt_uu_I, self.xp.zeros_like(Qt_uu_.array), Qt_uu_)
                index_qt_uu = self.xp.array([self.xp.diagflat(index[i]) for i in range(index.shape[0])])
                Qt_uu_ = F.where(F.cast(index_qt_uu, 'bool'), Qt_uu_ + 1e-8, Qt_uu_)
                Qt_ux_ = copy.deepcopy(Qt_ux)
                index_qt_ux = F.repeat(F.expand_dims(index, axis=2), Qt_ux.shape[2], axis=2)
                Qt_ux_ = F.where(index_qt_ux, self.xp.zeros_like(Qt_ux_.array), Qt_ux)
                #  print("qt_u_", qt_u_)
                #  print("Qt_uu_", Qt_uu_)
                #  print("Qt_ux_", Qt_ux_)
                if self.n_ctrl == 1:
                    Kt = - (1. / Qt_uu_) * Qt_ux_  # NOTE different from original
                    kt = - (1. / F.squeeze(Qt_uu_, axis=2)) * qt_u_
                else:
                    Qt_uu_LU_ = batch_lu_factor(Qt_uu_)
                    Kt = - batch_lu_solve(Qt_uu_LU_, Qt_ux_)
                    kt = - batch_lu_solve(Qt_uu_LU_, qt_u_)
            assert list(Kt.shape) == [self.n_batch, self.n_ctrl, self.n_state], "Kt dim mismatch"
            assert list(kt.shape) == [self.n_batch, self.n_ctrl], "kt dim mismatch"
            Kt_T = F.transpose(Kt, axes=(0, 2, 1))
            Ks.append(Kt)
            ks.append(kt)
            Vt = Qt_xx + F.matmul(Qt_xu, Kt) + F.matmul(Kt_T, Qt_ux) + F.matmul(F.matmul(Kt_T, Qt_uu), Kt)
            vt = qt_x + bmv(Qt_xu, kt) + bmv(Kt_T, qt_u) + bmv(F.matmul(Kt_T, Qt_uu), kt)

        assert len(Ks) == self.T, "Ks length error"

        Ks.reverse()
        ks.reverse()
        return Ks, ks