예제 #1
0
    def forward(self, Ks, ks):
        """ LQR forward recursion

        Rolls the state forward from ``self.x_init`` using the gains from
        the backward recursion: ``u_t = K_t x_t + k_t`` and
        ``x_{t+1} = F_t [x_t; u_t] (+ f_t)``.

        :param Ks: feedback gain matrices solved in the backward recursion
        :param ks: feedforward terms solved in the backward recursion
        :return: (x, u) stacked over time, shapes (T, n_batch, n_state)
                 and (T, n_batch, n_ctrl)
        """
        assert len(Ks) == self.T, "Ks length error"
        assert len(ks) == self.T, "ks length error"
        new_x = [self.x_init]
        new_u = []
        for t in range(self.T):
            Kt = Ks[t]
            kt = ks[t]
            xt = new_x[t]
            assert list(xt.shape) == [self.n_batch, self.n_state], \
                str(xt.shape) + " xt dim mismatch: expected" + \
                str([self.n_batch, self.n_state])
            # u_t = K_t x_t + k_t
            new_ut = xpbmv(Kt, xt) + kt
            assert list(new_ut.shape) == [self.n_batch, self.n_ctrl
                                          ], "actual:" + str(new_ut.shape)
            if self.u_zero_Index is not None:
                # Controls marked as active (clamped) are frozen to zero so
                # the active constraint set is respected in this recursion.
                assert self.u_zero_Index[t].shape == new_ut.shape, str(
                    self.u_zero_Index[t].shape) + " : " + str(new_ut.shape)
                new_ut = self.xp.where(self.u_zero_Index[t],
                                       self.xp.zeros_like(new_ut), new_ut)
            new_u.append(new_ut)
            if t < self.T - 1:
                # x_{t+1} = F_t [x_t; u_t] + f_t (linear dynamics)
                xu_t = self.xp.concatenate((xt, new_u[t]), axis=1)
                x = xpbmv(self.F[t], xu_t)
                if self.f is not None:
                    x += self.f[t]
                assert list(x.shape) == [self.n_batch, self.n_state], \
                    str(x.shape) + " x dim mismatch: expected" + \
                    str([self.n_batch, self.n_state])
                new_x.append(x)
        new_x = self.xp.stack(new_x, axis=0)
        new_u = self.xp.stack(new_u, axis=0)
        assert list(new_x.shape) == [
            self.T, self.n_batch, self.n_state
        ], str(new_x.shape) + " new x dim mismatch"
        assert list(new_u.shape) == [self.T, self.n_batch,
                                     self.n_ctrl], "new u dim mismatch"
        return new_x, new_u
예제 #2
0
    def forward(self, inputs):
        """ Link forward

        One constrained LQR solve around the current trajectory iterate:
        optionally Taylor-expands the cost (eq(12)/(13) in [1]), then runs
        the backward recursion for the gains and the forward recursion with
        line search.

        :param inputs: tuple (x_init, C_hat, c_hat, F_hat, f_hat) — the
                       initial state plus the (approximated) quadratic cost
                       and linearized dynamics
        :return: (x, u) trajectories of shape (T, n_batch, n_state) and
                 (T, n_batch, n_ctrl)
        """
        with chainer.no_backprop_mode():
            x_init, C_hat, c_hat, F_hat, f_hat = inputs
            # Retain everything; backward() re-reads all five inputs.
            self.retain_inputs((0, 1, 2, 3, 4))
            if self.no_op_forward:
                # Skip the solve and replay the cached trajectories.
                self.retain_outputs((0, 1))
                return self.current_states, self.controls
            x_init = to_xp(x_init)
            C_hat = to_xp(C_hat)
            c_hat = to_xp(c_hat)
            F_hat = to_xp(F_hat)
            f_hat = to_xp(f_hat)
            if self.need_expand is True:
                # Taylor expansion
                # grad(delta_x,delta_u) = hessian(0,0) @ (delta_x, delta_u) + grad(0,0)
                c_back = [
                ]  # eq(12) in [1], constant term in eq(5.12) can be removed
                for t in range(self.T):
                    xt = self.current_states[t]
                    ut = self.controls[t]
                    xut = self.xp.concatenate((xt, ut), axis=1)
                    assert xut.shape == (self.n_batch, self.n_sc), "expected " + str([self.n_batch, self.n_sc]) + \
                                                                   "acutal" + str(xut.shape)
                    c_back.append(xpbmv(C_hat[t], xut) + c_hat[t])
                c_hat = self.xp.stack(c_back)
                f_hat = None  # eq(13) in [1]
            Ks, ks, _backward = self.backward_rec(C_hat, c_hat, F_hat, f_hat)
            x, u, _forward = self.forward_rec(Ks, ks, self.true_cost,
                                              self.true_dynamics,
                                              self.ls_decay, self.max_ls_iter)
        assert list(x.shape) == [self.T, self.n_batch,
                                 self.n_state], "x dim mismatch"
        # Expose recursion diagnostics for callers.
        self.back_out = _backward
        self.for_out = _forward
        self.retain_outputs((0, 1))
        assert list(x.shape) == [self.T, self.n_batch, self.n_state]
        assert list(u.shape) == [self.T, self.n_batch, self.n_ctrl]
        assert not self.xp.isnan(u).any()
        return x, u
예제 #3
0
    def backward(self):
        """ LQR backward recursion

        Dynamic-programming pass that computes time-varying feedback gains
        Kt and feedforward terms kt from the quadratic cost-to-go.
        Note: Ks, ks are accumulated in reverse time order and reversed
        before returning (reversed version of original).
        :return: Ks, ks gain (length-T lists indexed by time step)
        """
        Ks = []
        ks = []
        Vt = None  # quadratic coefficient of the value function V_{t+1}
        vt = None  # linear coefficient of the value function v_{t+1}
        # self.T-1 to 0 loop
        for t in range(self.T - 1, -1, -1):
            # initial case
            if t == self.T - 1:
                Qt = self.C[t]
                qt = self.c[t]
            else:
                Ft = self.F[t]
                Ft_T = self.xp.transpose(Ft, axes=(0, 2, 1))
                assert Ft.dtype.kind == 'f', "Ft dtype"
                assert Vt.dtype.kind == 'f', "Vt dtype"
                # Q_t = C_t + F_t^T V_{t+1} F_t
                Qt = self.C[t] + Ft_T @ Vt @ Ft
                if self.f is None:
                    # NOTE f.nelement() == 0 condition ?
                    qt = self.c[t] + xpbmv(Ft_T, vt)
                else:
                    # f is not none
                    ft = self.f[t]
                    qt = self.c[t] + xpbmv(Ft_T @ Vt, ft) + xpbmv(Ft_T, vt)
            assert list(qt.shape) == [self.n_batch,
                                      self.n_sc], "qt dim mismatch"
            assert list(Qt.shape) == [self.n_batch, self.n_sc, self.n_sc
                                      ], str(Qt.shape) + " Qt dim mismatch"
            # Partition Q and q into state/control blocks.
            Qt_xx = Qt[:, :self.n_state, :self.n_state]
            Qt_xu = Qt[:, :self.n_state, self.n_state:]
            Qt_ux = Qt[:, self.n_state:, :self.n_state]
            Qt_uu = Qt[:, self.n_state:, self.n_state:]
            qt_x = qt[:, :self.n_state]
            qt_u = qt[:, self.n_state:]
            assert list(
                Qt_uu.shape) == [self.n_batch, self.n_ctrl,
                                 self.n_ctrl], "Qt_uu dim mismatch"
            assert list(
                Qt_ux.shape) == [self.n_batch, self.n_ctrl,
                                 self.n_state], "Qt_ux dim mismatch"
            assert list(
                Qt_xu.shape) == [self.n_batch, self.n_state,
                                 self.n_ctrl], "Qt_xu dim mismatch"
            assert list(qt_x.shape) == [self.n_batch,
                                        self.n_state], "qt_x dim mismatch"
            assert list(qt_u.shape) == [self.n_batch,
                                        self.n_ctrl], "qt_u dim mismatch"
            # Next calculate Kt and kt
            # TODO LU decomposition
            # NOTE(review): placeholder branch — the code assumes
            # self.u_zero_Index is always set for this class; confirm
            # before adding a None path.
            if True:
                # u_zero_index is not none
                index = self.u_zero_Index[t]
                # Zero the gradient entries of controls frozen at zero.
                qt_u_ = copy.deepcopy(qt_u)
                qt_u_[index] = 0.0
                Qt_uu_ = copy.deepcopy(Qt_uu)
                notI = 1.0 - index.astype('float')
                # Mask every Hessian entry touching a frozen control.
                Qt_uu_I = 1 - xpbger(notI, notI)
                Qt_uu_I = Qt_uu_I.astype('bool')
                Qt_uu_[Qt_uu_I] = 0.0
                # Qt_uu_ = F.where(Qt_uu_I, self.xp.zeros_like(Qt_uu_.array), Qt_uu_)
                # Tiny diagonal regularizer on the frozen entries keeps the
                # masked Hessian invertible.
                # NOTE(review): assumes `index` is a boolean mask so the
                # stacked diagflat result stays bool and indexes correctly —
                # confirm the caller's dtype.
                index_qt_uu = self.xp.array([
                    self.xp.diagflat(index[i]) for i in range(index.shape[0])
                ])
                Qt_uu_[index_qt_uu] += 1e-8
                # Qt_uu_ = F.where(F.cast(index_qt_uu, 'bool'), Qt_uu + 1e-8, Qt_uu)
                Qt_ux_ = copy.deepcopy(Qt_ux)
                index_qt_ux = self.xp.repeat(self.xp.expand_dims(index,
                                                                 axis=2),
                                             Qt_ux.shape[2],
                                             axis=2)
                Qt_ux_[index_qt_ux] = 0.0
                # Qt_ux_ = F.where(index_qt_ux, self.xp.zeros_like(Qt_ux_.array), Qt_ux)
                #  print("qt_u_", qt_u_)
                #  print("Qt_uu_", Qt_uu_)
                #  print("Qt_ux_", Qt_ux_)
                if self.n_ctrl == 1:
                    # Scalar control: invert directly instead of via LU.
                    Kt = -(1. /
                           Qt_uu_) * Qt_ux_  # NOTE different from original
                    kt = -(1. / self.xp.squeeze(Qt_uu_, axis=2)) * qt_u_
                else:
                    Qt_uu_LU_ = xpbatch_lu_factor(Qt_uu_)
                    Kt = -xpbatch_lu_solve(Qt_uu_LU_, Qt_ux_)
                    kt = -xpbatch_lu_solve(Qt_uu_LU_, qt_u_)
            assert list(Kt.shape) == [self.n_batch, self.n_ctrl,
                                      self.n_state], "Kt dim mismatch"
            assert list(kt.shape) == [self.n_batch,
                                      self.n_ctrl], "kt dim mismatch"
            Kt_T = self.xp.transpose(Kt, axes=(0, 2, 1))
            Ks.append(Kt)
            ks.append(kt)
            # Value function update for the next (earlier) time step.
            Vt = Qt_xx + self.xp.matmul(Qt_xu, Kt) + self.xp.matmul(
                Kt_T, Qt_ux) + self.xp.matmul(self.xp.matmul(Kt_T, Qt_uu), Kt)
            vt = qt_x + xpbmv(Qt_xu, kt) + xpbmv(Kt_T, qt_u) + xpbmv(
                self.xp.matmul(Kt_T, Qt_uu), kt)

        assert len(Ks) == self.T, "Ks length error"

        # Convert from reverse-time to forward-time order.
        Ks.reverse()
        ks.reverse()
        return Ks, ks
예제 #4
0
    def backward_rec(self, C_hat, c_hat, F_hat, f_hat):
        """ Backward recursion over the linearized trajectory

        Same dynamic-programming structure as plain LQR, but the
        feedforward step at each time is a box-constrained QP (PNQP) over
        the control bounds around the current iterate, and the feedback
        gain is restricted to the free (unclamped) control subspace.

        :param C_hat: approximated C (quadratic cost term)
        :param c_hat: approximated c (linear cost term)
        :param F_hat: approximated F (linearized dynamics)
        :param f_hat: approximated f (dynamics offset), may be None
        :return: (Ks, ks, LqrBackOut) — gains plus total QP iteration count
        """
        assert list(C_hat.shape) == [self.T, self.n_batch, self.n_sc, self.n_sc], \
            "C hat dim mismatch"
        assert list(c_hat.shape) == [self.T, self.n_batch, self.n_sc], \
            str(c_hat.shape) + " c hat dim mismatch: expected " + str([self.T, self.n_batch, self.n_sc])
        # Only T-1 dynamics matrices are needed; drop a trailing one.
        if list(F_hat.shape)[0] == self.T:
            F_hat = F_hat[:self.T - 1]
        else:
            assert (F_hat.shape[0]) == self.T - 1, "F_hat dimension"
        assert list(F_hat.shape) == [self.T - 1, self.n_batch, self.n_state, self.n_sc], \
            str(F_hat.shape) + " predicted:" + str(self.T - 1) + " " + \
            str(self.n_batch) + " " + str(self.n_state) + " " + str(self.n_sc) + "F_hat dim mismatch"
        if f_hat is not None:
            assert list(f_hat.shape) == [self.T - 1, self.n_batch, self.n_state] or \
                   list(f_hat.shape) == [self.T, self.n_batch, self.n_state], " f_hat dim mismatch"

        # Ks = []
        # ks = []
        Ks = self.xp.zeros((self.T, self.n_batch, self.n_ctrl, self.n_state))
        ks = self.xp.zeros((self.T, self.n_batch, self.n_ctrl))
        Vt = None  # quadratic coefficient of the cost-to-go
        vt = None  # linear coefficient of the cost-to-go
        prev_kt = None  # used for warm start up in Projected newton quadratic programmig
        n_total_qp_iter = 0
        # self.T-1 to 0 loop
        for t in range(self.T - 1, -1, -1):
            if t == self.T - 1:
                Qt = C_hat[t]
                qt = c_hat[t]
            else:
                Ft_hat = F_hat[t]
                Ft_hat_T = self.xp.transpose(Ft_hat, axes=(0, 2, 1))
                # Q_t = C_t + F_t^T V_{t+1} F_t
                Qt = C_hat[t] + Ft_hat_T @ Vt @ Ft_hat
                if f_hat is None:
                    qt = c_hat[t] + xpbmv(Ft_hat_T, vt)
                else:
                    # f is not none
                    ft = f_hat[t]
                    qt = c_hat[t] + xpbmv(Ft_hat_T @ Vt, ft) + xpbmv(
                        Ft_hat_T, vt)
            assert list(qt.shape) == [self.n_batch,
                                      self.n_sc], "qt dim mismatch"
            assert list(Qt.shape) == [self.n_batch, self.n_sc, self.n_sc
                                      ], str(Qt.shape) + " Qt dim mismatch"
            # Partition Q and q into state/control blocks.
            Qt_xx = Qt[:, :self.n_state, :self.n_state]
            Qt_xu = Qt[:, :self.n_state, self.n_state:]
            Qt_ux = Qt[:, self.n_state:, :self.n_state]
            Qt_uu = Qt[:, self.n_state:, self.n_state:]
            qt_x = qt[:, :self.n_state]
            qt_u = qt[:, self.n_state:]
            assert list(
                Qt_uu.shape) == [self.n_batch, self.n_ctrl,
                                 self.n_ctrl], "Qt_uu dim mismatch"
            assert list(
                Qt_ux.shape) == [self.n_batch, self.n_ctrl,
                                 self.n_state], "Qt_ux dim mismatch"
            assert list(
                Qt_xu.shape) == [self.n_batch, self.n_state,
                                 self.n_ctrl], "Qt_xu dim mismatch"
            assert list(qt_x.shape) == [self.n_batch,
                                        self.n_state], "qt_x dim mismatch"
            assert list(qt_u.shape) == [self.n_batch,
                                        self.n_ctrl], "qt_u dim mismatch"
            # calculate K and k
            # different from LQR case starts from here
            # lower_bound of control - current control
            assert not self.xp.isnan(self.controls[t]).any(), str(
                self.controls[t])
            assert not self.xp.isnan(self.u_lower[t]).any()
            assert not self.xp.isnan(self.u_upper[t]).any()
            lower_bound = self.u_lower[t] - self.controls[t]
            # upper_bound of control - current control
            upper_bound = self.u_upper[t] - self.controls[t]
            assert (lower_bound <= upper_bound).all(), " lower is larger than upper" \
                                                       + " lower: " + str(lower_bound) + "upper: " + str(upper_bound)
            # Box-constrained QP for the feedforward step kt; prev_kt
            # warm-starts the solver from the later time step.
            kt, Qt_uu_free_LU, Index_free, n_qp_iter = PNQP(Qt_uu,
                                                            qt_u,
                                                            lower_bound,
                                                            upper_bound,
                                                            x_init=prev_kt,
                                                            n_iter=20)
            if self.verbose is True:
                print('  + n_qp_iter in mpc step: ', n_qp_iter + 1)
            n_total_qp_iter += 1 + n_qp_iter
            prev_kt = kt
            # Zero the rows of Qt_ux belonging to clamped controls so the
            # feedback gain only acts on the free subspace.
            Qt_ux_copy = copy.deepcopy(Qt_ux)
            Index_Qt_ux_free = self.xp.repeat(self.xp.expand_dims(
                (1.0 - Index_free), axis=2),
                                              self.n_state,
                                              axis=2)
            Index_Qt_ux_free = Index_Qt_ux_free.astype('bool')
            Qt_ux_copy[Index_Qt_ux_free] = 0.0
            # Qt_ux_copy = F.where(Index_Qt_ux_free, self.xp.zeros_like(Qt_ux.data), Qt_ux_copy)
            if self.n_ctrl == 1:
                # Bad naming, Qt_uu_free_LU is scalar
                Kt = -((1. / Qt_uu_free_LU) * Qt_ux_copy)
            else:
                # Qt_uu K_{t,f} = - Qt_ux
                Kt = -xpbatch_lu_solve(Qt_uu_free_LU, Qt_ux_copy)
            assert list(Kt.shape) == [self.n_batch, self.n_ctrl,
                                      self.n_state], "Kt dim mismatch"
            assert list(kt.shape) == [self.n_batch,
                                      self.n_ctrl], "kt dim mismatch"
            Kt_T = self.xp.transpose(Kt, axes=(0, 2, 1))
            assert not self.xp.isnan(kt).any()
            assert not self.xp.isnan(Kt).any()
            Ks[t] = Kt
            ks[t] = kt
            # Value function update for the next (earlier) time step.
            Vt = Qt_xx + Qt_xu @ Kt + Kt_T @ Qt_ux + Kt_T @ Qt_uu @ Kt
            vt = qt_x + xpbmv(Qt_xu, kt) + xpbmv(Kt_T, qt_u) + xpbmv(
                (Kt_T @ Qt_uu), kt)

        assert len(Ks) == self.T, "Ks length error"
        '''
        Ks.reverse()
        ks.reverse()
        '''
        return Ks, ks, LqrBackOut(n_total_qp_iter=n_total_qp_iter)
예제 #5
0
    def backward(self, target_input_indexes, grad_outputs):
        """ Link backward: gradients of the constrained LQR solution.

        Solves an auxiliary LQR problem with the incoming gradients as the
        linear cost (active/clamped controls frozen to zero), then assembles
        the gradients w.r.t. x_init, C, c, F and f via costate recursions.

        :param target_input_indexes: indexes of inputs that need gradients
        :param grad_outputs: (dl_dx, dl_du) upstream gradients
        :return: (dx_init, dC, dc, dF, df) as chainer Variables
        """
        # print("backward is called")
        x_init, C_hat, c_hat, F_hat, f_hat = self.get_retained_inputs()
        dl_dx, dl_du = grad_outputs

        x_init = to_xp(x_init)
        C_hat = to_xp(C_hat)
        c_hat = to_xp(c_hat)
        F_hat = to_xp(F_hat)
        f_hat = to_xp(f_hat)
        dl_dx = to_xp(dl_dx)
        dl_du = to_xp(dl_du)
        # Missing upstream gradients are treated as zeros.
        if dl_dx is None:
            dl_dx = self.xp.zeros((self.T, self.n_batch, self.n_state))
        else:
            assert list(dl_dx.shape) == [self.T, self.n_batch, self.n_state]
        if dl_du is None:
            dl_du = self.xp.zeros((self.T, self.n_batch, self.n_ctrl))
        else:
            assert list(dl_du.shape) == [self.T, self.n_batch, self.n_ctrl]
        # just concatenating dl_dx, dl_du
        d_taus = self.xp.concatenate((dl_dx, dl_du), axis=2)
        # assert False, str(d_taus)
        # choose active control
        new_x, new_u = self.get_retained_outputs()
        new_x = to_xp(new_x)
        new_u = to_xp(new_u)
        # Controls sitting (numerically) on a bound form the active set.
        active_index = (self.xp.absolute(new_u - self.u_lower) <= 1e-8) | \
                       (self.xp.absolute(new_u - self.u_upper) <= 1e-8)
        dx_init_zero = self.xp.zeros_like(x_init)
        # backward pass LINE (1)
        '''
        print("I", active_index)
        print("r",d_taus)
        print("F",F_hat)
        print("C",C_hat)
        assert  False
        '''
        # Auxiliary LQR with the upstream gradients as the linear cost;
        # active controls are frozen to zero inside the solve.
        lqr = LQR_active(dx_init_zero,
                         C_hat,
                         -d_taus,
                         F_hat,
                         None,
                         self.T,
                         self.n_state,
                         self.n_ctrl,
                         u_zero_Index=active_index)
        dx, du = lqr.solve_recursion()
        # print("dx", dx)
        # print("du", du)
        # assert  False
        dxu = self.xp.concatenate((dx, du), axis=2)

        xu = self.xp.concatenate((new_x, new_u), axis=2)
        # Gradient w.r.t. the quadratic cost C: symmetrized outer product.
        dC = self.xp.zeros_like(C_hat)
        for t in range(self.T):
            xut = self.xp.concatenate((new_x[t], new_u[t]), axis=1)
            dxut = dxu[t]
            dCt = -0.5 * (xpbger(dxut, xut) + xpbger(xut, dxut))
            assert dC[t].shape == dCt.shape
            dC[t] = dCt
        dc = -dxu
        # Compute Lambda (Forward Pass line(2)) in Module(1)
        # lams = []
        lams = self.xp.zeros((self.T, self.n_batch, self.n_state))
        prev_lam = None
        for t in range(self.T - 1, -1, -1):
            Ct_xx = C_hat[t, :, :self.n_state, :self.n_state]
            Ct_xu = C_hat[t, :, :self.n_state, self.n_state:]
            ct_x = c_hat[t, :, :self.n_state]
            xt = new_x[t]
            ut = new_u[t]
            # Costate: lam_t = C_xx x_t + C_xu u_t + c_x (+ F_x^T lam_{t+1})
            lamt = xpbmv(Ct_xx, xt) + xpbmv(Ct_xu, ut) + ct_x
            if prev_lam is not None:
                Fxt = self.xp.transpose(F_hat[t, :, :, :self.n_state],
                                        axes=(0, 2, 1))
                lamt += xpbmv(Fxt, prev_lam)
            lams[t] = lamt
            prev_lam = lamt
        # lams = list(reversed(lams))
        # Backward Pass Line(3)
        # Compute the derivatives
        # d_Lambda
        # dlams = []
        # Same recursion applied to the differential quantities (dx, du).
        dlams = self.xp.zeros_like(lams)
        prev_dlam = None
        for t in range(self.T - 1, -1, -1):
            dCt_xx = C_hat[t, :, :self.n_state, :self.n_state]
            dCt_xu = C_hat[t, :, :self.n_state, self.n_state:]
            drt_x = -d_taus[t, :, :self.n_state]
            dxt = dx[t]
            dut = du[t]
            dlamt = xpbmv(dCt_xx, dxt) + xpbmv(dCt_xu, dut) + drt_x
            if prev_dlam is not None:
                Fxt = self.xp.transpose(F_hat[t, :, :, :self.n_state],
                                        axes=(0, 2, 1))
                dlamt += xpbmv(Fxt, prev_dlam)
            dlams[t] = dlamt
            prev_dlam = dlamt
        # dlams = self.xp.stack(list(reversed(dlams)))
        # d_F
        dF = self.xp.zeros_like(F_hat)
        for t in range(self.T - 1):
            xut = xu[t]
            lamt = lams[t + 1]
            dxut = dxu[t]
            dlamt = dlams[t + 1]
            append_to_dF = -(xpbger(dlamt, xut) + xpbger(lamt, dxut))
            assert dF[t].shape == append_to_dF.shape, str(
                dF[t].shape) + " : " + str(append_to_dF.shape)
            dF[t] = append_to_dF
        if f_hat is not None:
            _dlams = dlams[1:]
            assert _dlams.shape == f_hat.shape
            df = -_dlams
            df = chainer.Variable(df)
        else:
            # CHECK THIS
            # NOTE(review): Variable() carries no data here — confirm that
            # callers never consume df when f_hat is None.
            df = chainer.Variable()

        dx_init = -dlams[0]

        # print(dx_init.shape)
        # print(dC.shape)
        # print(dc.shape)
        # print(dF.shape)
        # print(dF)
        # assert False
        dx_init = chainer.Variable(dx_init)
        dC = chainer.Variable(dC)
        dc = chainer.Variable(dc)
        dF = chainer.Variable(dF)
        # print("dC", dC)
        # print("dc", dc)
        return dx_init, dC, dc, dF, df
예제 #6
0
    def forward_rec(self, Ks, ks, true_cost, true_dynamics, ls_decay,
                    max_ls_iter):
        """ Forward recursion and line search

        Rolls out the updated controls ``u_t + K_t dx_t + alpha * k_t``
        under the true dynamics, backtracking the per-batch step sizes
        ``alphas`` until no batch entry's trajectory cost exceeds the
        previous cost or the iteration limit is hit.

        :param Ks: feedback gains from the backward recursion
        :param ks: feedforward terms from the backward recursion
        :param true_cost: true Cost function (QuadCost or callable)
        :param true_dynamics: true dynamics function (LinDx or callable)
        :param ls_decay: line search decay ratio
        :param max_ls_iter: max line search iteration
        :return: (new_x, new_u, LqrForOut diagnostics)
        """
        assert len(Ks) == self.T, "Ks length error"
        states = self.current_states
        alphas = self.xp.ones(self.n_batch, dtype=self.controls.dtype)
        OLD_COST = xpget_cost(self.T,
                              self.controls,
                              true_cost,
                              true_dynamics,
                              x=states)
        current_cost = None
        n_iter = 0  # number of line search iterations
        full_du_norm = None  # initial change of u
        # Line search: alphas of not-yet-improved batch entries are decayed
        # until every entry meets the terminal condition.
        # BUGFIX: the original condition parsed as `(A and B) or C`, so once
        # current_cost was set the loop ignored max_ls_iter entirely; the
        # limit must bound the whole condition.
        while n_iter < max_ls_iter and (current_cost is None or
                                        (current_cost > OLD_COST).any()):
            assert type(alphas) == self.xp.ndarray, "alphas dtype error"
            new_x = [states[0]]
            new_u = []
            dx = [self.xp.zeros_like(states[0])]
            objs = []  # cost
            for t in range(self.T):
                Kt = Ks[t]
                kt = ks[t]
                new_xt = new_x[t]
                xt = states[t]
                ut = self.controls[t]
                dxt = dx[t]
                # Feedback part: u_t + K_t (new_x_t - x_t)
                new_ut = xpbmv(Kt, dxt) + ut

                assert not self.xp.isnan(new_ut).any()
                assert not self.xp.isinf(new_ut).any()
                # Feedforward part scaled per batch entry by its alpha.
                xp_alpha = self.xp.diagflat(alphas).astype(dtype=kt.dtype)
                assert not self.xp.isnan(xp_alpha).any()
                assert not self.xp.isinf(xp_alpha).any()
                assert not self.xp.isnan(kt).any()
                add_new_ut = xp_alpha @ kt
                assert not self.xp.isnan(add_new_ut).any(
                ), str(alphas) + " @" + str(kt) + "=" + str(add_new_ut)
                new_ut += add_new_ut
                assert not self.xp.isnan(new_ut).any(), str(xp_alpha) + str(kt)
                # Project back into the box constraints.
                new_ut = xpclamp(new_ut, self.u_lower[t], self.u_upper[t])
                # delta_u is None
                assert not self.xp.isnan(new_ut).any()
                new_u.append(new_ut)
                new_xut = self.xp.concatenate((new_xt, new_ut), axis=1)
                if t < self.T - 1:
                    # Calculate next x_{t+1}
                    # Dynamics is linear
                    if isinstance(true_dynamics, LinDx):
                        large_f, f = true_dynamics.F, true_dynamics.f
                        large_f = to_xp(large_f)
                        f = to_xp(f)
                        # new_x_{t+1}
                        new_xtp1 = xpbmv(large_f[t], new_xut)
                        if f is not None:
                            new_xtp1 += f[t]
                    else:
                        # Dynamics is non linear
                        new_xtp1 = true_dynamics(new_xt, new_ut)
                        new_xtp1 = to_xp(new_xtp1)
                    assert not self.xp.isnan(new_xtp1).any()
                    new_x.append(new_xtp1)
                    dx.append(new_xtp1 - states[t + 1])
                # Calculate cost
                # If cost is quadratic
                if isinstance(true_cost, QuadCost):
                    C = true_cost.C
                    c = true_cost.c
                    C = to_xp(C)
                    c = to_xp(c)
                    obj = 0.5 * xpbquad(new_xut, C[t]) + xpbdot(new_xut, c[t])
                else:
                    obj = true_cost(new_xut)
                objs.append(obj)
            objs = self.xp.stack(objs, axis=0)
            current_cost = self.xp.sum(objs, axis=0)
            new_x = self.xp.stack(new_x, axis=0)
            new_u = self.xp.stack(new_u, axis=0)
            # only update once
            if full_du_norm is None:
                du = self.controls - new_u
                du = self.xp.transpose(du, axes=(0, 2, 1)).reshape(
                    self.n_batch, self.T * self.n_ctrl)
                full_du_norm = self.xp.sqrt(self.xp.sum(du**2, axis=1))
            assert list(new_x.shape) == [
                self.T, self.n_batch, self.n_state
            ], str(new_x.shape) + " new x dim mismatch"
            assert list(new_u.shape) == [self.T, self.n_batch,
                                         self.n_ctrl], "new u dim mismatch"
            # Decay the step size for batch entries that did not improve.
            index_decay = current_cost > OLD_COST
            assert not self.xp.isinf(alphas).any()
            alphas[index_decay] *= ls_decay
            assert not self.xp.isinf(alphas).any(), str(ls_decay)
            n_iter += 1
        # TODO Check this decay
        # If the iteration limit is hit, some alphas
        # are one step too small.
        alphas[current_cost > OLD_COST] /= ls_decay
        du = self.controls - new_u
        du = self.xp.transpose(du,
                               axes=(0, 2, 1)).reshape(self.n_batch,
                                                       self.T * self.n_ctrl)
        alpha_du_norm = self.xp.sqrt(self.xp.sum(du**2, axis=1))
        res = LqrForOut(objs, full_du_norm, alpha_du_norm,
                        self.xp.mean(alphas), current_cost)
        assert not self.xp.isnan(new_x).any()
        assert not self.xp.isnan(new_u).any()
        return new_x, new_u, res
예제 #7
0
def PNQP(H, q, lower, upper, x_init=None, n_iter=20):
    """ Projected Newton QP solver.

    Minimizes 0.5 x^T H x + q^T x subject to lower <= x <= upper,
    batch-wise (first axis is the batch).

    :param H: (n_batch, n_dim, n_dim) Hessian
    :param q: (n_batch, n_dim) linear term
    :param lower: (n_batch, n_dim) elementwise lower bound
    :param upper: (n_batch, n_dim) elementwise upper bound
    :param x_init: optional warm-start point (copied, never mutated)
    :param n_iter: maximum Newton iterations
    :return: (x, H_f if n_dim == 1 else its LU factors, free-index mask,
              iteration index reached)
    Algorithm[1] in [2]
    1) Get indices: eq(15)
    2) Get Newton step: eq(16)
    3) Convergence: If |g_f|< epsilon << 1 terminate
    4) Line search
    """
    xp = get_array_module(H)
    # Unwrap chainer Variables; the solver operates on raw arrays.
    if isinstance(H, Variable):
        H = H.array
    if isinstance(q, Variable):
        q = q.array
    if isinstance(lower, Variable):
        lower = lower.array
    if isinstance(upper, Variable):
        upper = upper.array
    if x_init is not None and isinstance(x_init, Variable):
        x_init = x_init.array
    assert type(H) == xp.ndarray, str(H)
    assert (lower <= upper).all(), " lower is larger than upper" \
                                   + " lower: " + str(lower) + "upper: " + str(upper)
    n_batch = H.shape[0]
    n_dim = H.shape[1]
    assert list(H.shape) == [n_batch, n_dim, n_dim], "H dim mismatch"
    assert list(
        q.shape) == [n_batch,
                     n_dim], "q dim mismatch expected" + str([n_batch, n_dim])
    assert list(
        lower.shape) == [n_batch,
                         n_dim], "lower dim mismatch actual" + str(lower.shape)
    assert list(upper.shape) == [n_batch, n_dim], "upper dim mismatch"
    # Small identity regularizer keeps the masked Hessian invertible.
    I_pnqp = xpexpand_batch((1e-11 * xp.eye(n_dim)),
                            n_batch).reshape(n_batch, n_dim, n_dim)
    if x_init is None:
        # Initial guess: the unconstrained Newton point -H^{-1} q.
        if n_dim == 1:
            x_init = -(1.0 / xp.squeeze(H, axis=2)) * q
        else:
            H_lu = xpbatch_lu_factor(H)
            # Clamped in the x assignment
            # Hx = -q (Don't to unpack H_lu)
            x_init = -xpbatch_lu_solve(H_lu, q)
    else:
        # Don't over-write the original x_init.
        x_init = copy.deepcopy(x_init)
        x_init = to_xp(x_init)
        assert type(x_init[0][0]) != Variable
    # Begin with feasible guess
    assert type(x_init[0][0]) != Variable
    assert type(lower) != Variable
    assert type(upper) != Variable
    x = xpclamp(x_init, lower, upper)
    assert list(x.shape) == [n_batch, n_dim], "x dim mismatch"
    for i in range(n_iter):
        # 1. Get indices of clamped (active) vs. free coordinates: eq(15).
        grad = xpbmv(H, x) + q
        Index_c = ((x == lower) & (xp.greater(grad, 0.0)) |
                   ((x == upper) & (xp.less(grad, 0.0))))
        Index_c = 1.0 * Index_c
        Index_f = 1.0 - Index_c
        # 2. Get Newton step restricted to the free subspace: eq(16).
        Index_Hff = xpbger(Index_f, Index_f)
        Index_not_Hff = 1.0 - Index_Hff
        Index_c = Index_c.astype('bool')
        # Zero the gradient on clamped coordinates.
        g_f = copy.deepcopy(grad)
        g_f[Index_c] = 0.0
        # Zero Hessian entries outside the free-free block.
        # Bad implementation (when n_dim is large)
        H_f = copy.deepcopy(H)
        Index_not_Hff = Index_not_Hff.astype('bool')
        H_f[Index_not_Hff] = 0.0
        H_f += I_pnqp
        # calculate dx
        if n_dim == 1:
            dx = -(1.0 / xp.squeeze(H_f, axis=2)) * g_f
        else:
            H_lu_f = xpbatch_lu_factor(H_f)
            dx = -xpbatch_lu_solve(H_lu_f, g_f)
        # 3. Convergence: stop once every batch entry's step is tiny.
        norm = xp.sqrt(xp.sum(dx**2, axis=1))
        batch_large = norm >= 1e-4
        batch_large = batch_large.astype('float')
        num_large = xp.sum(batch_large)
        if num_large == 0:
            return x, H_f if n_dim == 1 else H_lu_f, Index_f, i
        # 4. Line search (Backtracking): shrink alpha per batch entry until
        # the sufficient-decrease ratio exceeds GAMMA somewhere.
        alpha = xp.ones(n_batch, dtype=x.dtype)
        DECAY = 0.1  # making alpha smaller
        max_lhs = xp.array(GAMMA)
        batch_large = batch_large.astype('bool')
        count = 0
        while max_lhs <= GAMMA and count < 10:
            x_hat = xpclamp(x + xp.diagflat(alpha) @ dx, lower, upper)
            lhs = (GAMMA + 1e-6) * (xp.ones(n_batch, dtype=x.dtype))
            # Sufficient-decrease ratio for not-yet-converged entries only.
            lhs[batch_large] = (calc_obj(H, q, x) - calc_obj(H, q, x_hat))[batch_large] \
                               / xpbdot(grad, x - x_hat)[batch_large]
            shrink_index = lhs <= GAMMA
            alpha[shrink_index] *= DECAY  # making smaller
            max_lhs = xp.max(lhs)
            # Don't write cnt += cnt +1 HERE
            count += 1
        x = x_hat

    warnings.warn(
        "Projected Newton Quadratic Programming warning: Did not converge")
    return x, H_f if n_dim == 1 else H_lu_f, Index_f, i