Example #1
    def __init__(self, env, lqr_iter=500, mpc_T=20, slew_rate_penalty=None):
        self.env = env

        if self.env == 'pendulum':
            self.true_dx = pendulum.PendulumDx()
        elif self.env == 'cartpole':
            self.true_dx = cartpole.CartpoleDx()
        elif self.env == 'pendulum-complex':
            # Five parameters: gravity, mass, length (per the comment in
            # Example #5), plus two extra terms for the non-simple model.
            params = torch.tensor((10., 1., 1., 1., 0.1))
            self.true_dx = pendulum.PendulumDx(params, simple=False)
        else:
            assert False, 'unknown env: {}'.format(env)

        self.lqr_iter = lqr_iter
        self.mpc_T = mpc_T
        self.slew_rate_penalty = slew_rate_penalty

        self.grad_method = GradMethods.AUTO_DIFF

        self.train_data = None
        self.val_data = None
        self.test_data = None
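The if/elif dispatch above is a small factory over the example dynamics. A standalone sketch of the same pattern, assuming the dynamics classes live under mpc.env_dx as in the mpc.pytorch examples (the raise in place of assert False is my own substitution):

import torch
from mpc.env_dx import pendulum, cartpole


def make_true_dx(env_name):
    # Construct the ground-truth dynamics for a named environment.
    if env_name == 'pendulum':
        return pendulum.PendulumDx()
    if env_name == 'cartpole':
        return cartpole.CartpoleDx()
    if env_name == 'pendulum-complex':
        # g, m, l (per the comment in Example #5) plus two extra
        # parameters used by the non-simple pendulum model.
        params = torch.tensor((10., 1., 1., 1., 0.1))
        return pendulum.PendulumDx(params, simple=False)
    raise ValueError('unknown env: {}'.format(env_name))
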
Example #2
import argparse
import os

import numpy as np
import torch
import matplotlib.pyplot as plt

# Module path follows the mpc.pytorch examples (an assumption);
# solve_lqr is defined elsewhere in this script.
from mpc.env_dx import pendulum, cartpole


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='pendulum')
    args = parser.parse_args()

    n_batch = 1
    if args.env == 'pendulum':
        T = 20
        dx = pendulum.PendulumDx()
        xinit = torch.zeros(n_batch, dx.n_state)
        th = 1.0
        xinit[:, 0] = np.cos(th)
        xinit[:, 1] = np.sin(th)
        xinit[:, 2] = -0.5
    elif args.env == 'cartpole':
        T = 20
        dx = cartpole.CartpoleDx()
        xinit = torch.zeros(n_batch, dx.n_state)
        th = 0.5
        xinit[:, 2] = np.cos(th)
        xinit[:, 3] = np.sin(th)
    else:
        assert False, 'unrecognized env: {}'.format(args.env)

    q, p = dx.get_true_obj()

    u = None
    ep_length = 100
    for t in range(ep_length):
        x, u = solve_lqr(dx, xinit, q, p, T, dx.linesearch_decay,
                         dx.max_linesearch_iter, u)

        fig, ax = dx.get_frame(x[0])
        fig.savefig('{:03d}.png'.format(t))
        plt.close(fig)

        # Receding-horizon warm start: drop the executed first control, shift
        # the rest forward, and duplicate the tail to keep horizon length T.
        u = torch.cat((u[1:-1], u[-2:]), 0).contiguous()
        xinit = x[1]

    vid_file = 'ctrl_{}.mp4'.format(args.env)
    if os.path.exists(vid_file):
        os.remove(vid_file)
    # Assumes ffmpeg is available on PATH.
    cmd = ('ffmpeg -loglevel quiet '
           '-r 32 -f image2 -i %03d.png -vcodec '
           'libx264 -crf 25 -pix_fmt yuv420p {}').format(vid_file)
    os.system(cmd)
    for t in range(ep_length):
        os.remove('{:03d}.png'.format(t))
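solve_lqr is defined elsewhere in this script. A plausible sketch of it on top of mpc.pytorch's box-constrained iLQR solver, building the quadratic cost from the diagonal q and linear p returned by get_true_obj (the keyword arguments follow the mpc.pytorch API, but the specific settings here are illustrative assumptions):

import torch
from mpc import mpc
from mpc.mpc import QuadCost, GradMethods


def solve_lqr(dx, xinit, q, p, T, linesearch_decay, max_linesearch_iter,
              u_init=None):
    # One MPC solve: assemble a time-invariant quadratic cost
    # 0.5 * tau' diag(q) tau + p' tau over tau = [x; u], then run the
    # solver. A sketch; the real solve_lqr may differ in its settings.
    n_batch = xinit.shape[0]
    Q = torch.diag(q).unsqueeze(0).unsqueeze(0).repeat(T, n_batch, 1, 1)
    p_rep = p.unsqueeze(0).unsqueeze(0).repeat(T, n_batch, 1)

    x, u, _ = mpc.MPC(
        dx.n_state, dx.n_ctrl, T,
        n_batch=n_batch,
        u_lower=dx.lower, u_upper=dx.upper, u_init=u_init,
        lqr_iter=100,
        verbose=-1,
        exit_unconverged=False,
        grad_method=GradMethods.AUTO_DIFF,
        linesearch_decay=linesearch_decay,
        max_linesearch_iter=max_linesearch_iter,
    )(xinit, QuadCost(Q, p_rep), dx)
    return x, u
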
Example #3
    def dataset_loss(self, loader, warmstart=None):
        true_q, true_p = self.env.true_dx.get_true_obj()
        true_q, true_p = true_q.to(self.device), true_p.to(self.device)

        losses = []
        for i, (xinits, xs, us, idxs) in enumerate(loader):
            n_batch = xinits.shape[0]

            if self.mode == 'nn':
                # pred_u = self.policy(xinits)
                # pred_u = pred_u.reshape(-1, self.env.mpc_T, self.n_ctrl)
                pred_u = self.lstm_forward(xinits)
            elif self.mode in ['empc', 'sysid']:
                if self.env_name == 'pendulum-complex':
                    if self.learn_dx:
                        dx = pendulum.PendulumDx(self.env_params, simple=True)

                        # TODO: Hacky to have this here.
                        # class CombDx(nn.Module):
                        #     def __init__(self):
                        #         super().__init__()

                        #     def forward(_self, x, u):
                        #         return simple_dx(x,u) + 0.1*self.extra_dx(x,u)

                        # dx = CombDx()
                    else:
                        dx = pendulum.PendulumDx(self.env_params, simple=False)

                else:
                    dx = self.env.true_dx.__class__(self.env_params)

                if self.learn_cost:
                    q = torch.sigmoid(self.learn_q_logit)
                    p = q.sqrt() * self.learn_p
                else:
                    q, p = true_q, true_p

                _, pred_u = self.env.mpc(
                    dx,
                    xinits,
                    q,
                    p,
                    u_init=warmstart[idxs].transpose(0, 1),
                    # lqr_iter_override=100,
                )
                pred_u = pred_u.transpose(0, 1)
                warmstart[idxs] = pred_u
            else:
                assert False

            assert pred_u.shape == us.shape
            loss = (us.detach() - pred_u).pow(2).mean(dim=1)
            losses.append(loss)

        loss = torch.cat(losses).mean().item()
        return loss
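The warmstart[idxs] bookkeeping keeps one nominal control sequence per dataset example, so every MPC solve starts from the previous solution for that example. A small self-contained sketch of the pattern (fake_mpc and the sizes are illustrative):

import torch

n_examples, T, n_ctrl = 8, 5, 1
warmstart = torch.zeros(n_examples, T, n_ctrl)  # one trajectory per example


def fake_mpc(u_init):
    # Stand-in for env.mpc: returns a "refined" (T, batch, n_ctrl) sequence.
    return u_init + 0.1


for idxs in (torch.tensor([0, 3]), torch.tensor([1, 2])):
    u_init = warmstart[idxs].transpose(0, 1)   # batch-first -> time-first
    pred_u = fake_mpc(u_init)
    warmstart[idxs] = pred_u.transpose(0, 1)   # store batch-first again
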
Example #4
    def run(self):
        torch.manual_seed(self.seed)

        loss_names = ['epoch']
        loss_names.append('im_loss')
        if self.learn_dx:
            loss_names.append('sysid_loss')
        fname = os.path.join(self.save, 'train_losses.csv')
        train_loss_f = open(fname, 'w')
        train_loss_f.write('{}\n'.format(','.join(loss_names)))
        train_loss_f.flush()

        fname = os.path.join(self.save, 'val_test_losses.csv')
        vt_loss_f = open(fname, 'w')
        loss_names = ['epoch']
        loss_names += ['im_loss_val', 'im_loss_test']
        # if self.learn_dx:
        #     loss_names += ['sysid_loss_val', 'im_loss_val']
        vt_loss_f.write('{}\n'.format(','.join(loss_names)))
        vt_loss_f.flush()

        if self.learn_dx:
            fname = os.path.join(self.save, 'dx_hist.csv')
            dx_f = open(fname, 'w')
            dx_f.write(','.join(
                map(str,
                    self.env.true_dx.params.cpu().detach().numpy().tolist())))
            dx_f.write('\n')
            dx_f.flush()

        if self.learn_cost:
            fname = os.path.join(self.save, 'cost_hist.csv')
            cost_f = open(fname, 'w')
            cost_f.write(','.join(
                map(
                    str,
                    torch.cat((self.true_q,
                               self.true_p)).cpu().detach().numpy().tolist())))
            cost_f.write('\n')
            cost_f.flush()

        if self.mode == 'nn':
            opt = optim.Adam(
                list(self.state_emb.parameters()) +
                list(self.ctrl_emb.parameters()) +
                list(self.decode.parameters()) + list(self.cell.parameters()),
                1e-4)
        elif self.mode == 'empc':
            params1 = []

            if self.learn_cost:
                params1 += [self.learn_q_logit, self.learn_p]
            if self.learn_dx:
                params1.append(self.env_params)

            params = [{
                'params': params1,
                'lr': 1e-2,
                'alpha': 0.5,
            }]

            # if self.learn_dx and self.env_name == 'pendulum-complex':
            #     params.append({
            #         'params': self.extra_dx.parameters(),
            #         'lr': 1e-4,
            #     })

            opt = optim.RMSprop(params)
        elif self.mode == 'sysid':
            params = [{
                'params': self.env_params,
                'lr': 1e-2,
                'alpha': 0.5,
            }]

            # if self.env_name == 'pendulum-complex':
            # params.append({
            #     'params': self.extra_dx.parameters(),
            #     'lr': 1e-4,
            # })
            opt = optim.RMSprop(params)
        else:
            assert False

        T = self.env.mpc_T

        if self.mode in ['empc', 'sysid']:
            train_warmstart = torch.zeros(self.n_train, T,
                                          self.n_ctrl).to(self.device)
            val_warmstart = torch.zeros(self.env.val_data.shape[0], T,
                                        self.n_ctrl).to(self.device)
            test_warmstart = torch.zeros(self.env.test_data.shape[0], T,
                                         self.n_ctrl).to(self.device)
        else:
            train_warmstart = val_warmstart = test_warmstart = None

        train_data, train = self.make_data(self.env.train_data[:self.n_train],
                                           shuffle=True)
        val_data, val = self.make_data(self.env.val_data)
        test_data, test = self.make_data(self.env.test_data)

        best_val_loss = None

        true_q, true_p = self.env.true_dx.get_true_obj()
        true_q, true_p = true_q.to(self.device), true_p.to(self.device)

        n_train_batch = len(train)
        # nom_u = None # TODO

        learn_cost_round_robin_interval = 10
        cost_update_q = False

        for i in range(self.n_epoch):
            if i > 0 and i % learn_cost_round_robin_interval == 0:
                cost_update_q = not cost_update_q

            if self.mode in ['empc', 'sysid'] \
               and i % self.restart_warmstart_every == 0:
                train_warmstart.zero_()
                val_warmstart.zero_()
                test_warmstart.zero_()

            for j, (xinits, xs, us, idxs) in enumerate(train):
                if self.mode == 'nn':
                    # pred_u = self.policy(xinits)
                    # pred_u = pred_u.reshape(-1, self.env.mpc_T, self.n_ctrl)
                    pred_u = self.lstm_forward(xinits)
                    assert pred_u.shape == us.shape
                    im_loss = (us.detach() - pred_u).pow(2).mean()
                elif self.mode in ['empc', 'sysid']:
                    if self.learn_dx:
                        if self.env_name == 'pendulum-complex':
                            dx = pendulum.PendulumDx(self.env_params,
                                                     simple=True)

                            # TODO: Hacky to have this here.
                            # class CombDx(nn.Module):
                            #     def __init__(self):
                            #         super().__init__()

                            #     def forward(_self, x, u):
                            #         return simple_dx(x,u) + 0.1*self.extra_dx(x,u)

                            # dx = CombDx()
                        else:
                            dx = self.env.true_dx.__class__(self.env_params)
                    else:
                        dx = self.env.true_dx

                    if self.learn_cost:
                        q = torch.sigmoid(self.learn_q_logit)
                        p = q.sqrt() * self.learn_p
                    else:
                        q, p = true_q, true_p

                    nom_x, nom_u = self.env.mpc(
                        dx,
                        xinits,
                        q,
                        p,
                        u_init=train_warmstart[idxs].transpose(0, 1),
                        # u_init=nom_u,
                        # eps_override=0.1,
                        # lqr_iter_override=100,
                    )
                    nom_u = nom_u.transpose(0, 1)
                    train_warmstart[idxs] = nom_u
                    assert nom_u.shape == us.shape
                    im_loss = (us.detach() - nom_u).pow(2).mean()

                    if self.learn_dx:
                        xs_flat = xs[:, :-1].transpose(0, 2).contiguous().view(
                            self.n_state, -1).t()
                        us_flat = us[:, :-1].transpose(0, 2).contiguous().view(
                            self.n_ctrl, -1).t()
                        pred_next_x = dx(xs_flat, us_flat).t().view(
                            self.n_state, T - 1, -1).transpose(0, 2)
                        next_x = xs[:, 1:]
                        assert next_x.shape == pred_next_x.shape
                        sysid_loss = (next_x.detach() -
                                      pred_next_x).pow(2).mean()
                else:
                    assert False

                t = [i + j / n_train_batch, im_loss.item()]
                if self.learn_dx:
                    t.append(sysid_loss.item())
                t = ','.join(map(str, t))
                print(t)
                train_loss_f.write(t + '\n')
                train_loss_f.flush()
                opt.zero_grad()
                if self.mode == 'sysid':
                    sysid_loss.backward()
                else:
                    im_loss.backward()

                if self.learn_cost:
                    if cost_update_q:
                        print('only updating q')
                        self.learn_p.grad.zero_()
                    else:
                        print('only updating p')
                        self.learn_q_logit.grad.zero_()

                if self.learn_dx:
                    if self.env_name == 'pendulum-complex':
                        true_params = self.env.true_dx.params[:3]
                    else:
                        true_params = self.env.true_dx.params
                    print(
                        np.array_str(torch.stack(
                            (self.env_params,
                             true_params)).cpu().detach().numpy(),
                                     precision=2,
                                     suppress_small=True))
                    dx_f.write(','.join(
                        map(str,
                            self.env_params.cpu().detach().numpy().tolist())))
                    dx_f.write('\n')
                    dx_f.flush()
                if self.learn_cost:
                    print(
                        np.array_str(
                            torch.stack((
                                torch.cat((true_q, true_p)),
                                torch.cat((q, p)),
                                # torch.cat((q.grad, p.grad)),
                            )).cpu().detach().numpy(),
                            precision=2,
                            suppress_small=True))
                    cost_f.write(','.join(
                        map(str,
                            torch.cat(
                                (q, p)).cpu().detach().numpy().tolist())))
                    cost_f.write('\n')
                    cost_f.flush()

                opt.step()

            val_loss = self.dataset_loss(val, val_warmstart)
            test_loss = self.dataset_loss(test, test_warmstart)
            t = [i, val_loss, test_loss]
            t = ','.join(map(str, t))
            vt_loss_f.write(t + '\n')
            vt_loss_f.flush()

            self.last_epoch = i
            if best_val_loss is None or val_loss < best_val_loss:
                best_val_loss = val_loss
                fname = os.path.join(self.save, 'best.pkl')
                print('Saving best model to {}'.format(fname))
                with open(fname, 'wb') as f:
                    pkl.dump(self, f)
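Example #4's learned cost uses a sigmoid parameterization: q = sigmoid(learn_q_logit) keeps the diagonal weights in (0, 1), and p = sqrt(q) * learn_p ties the linear term's scale to the diagonal. The round-robin toggle then zeroes one parameter's gradient so q and p are updated in alternating phases. A minimal self-contained sketch (the sizes, initial values, and stand-in loss are illustrative):

import torch

n_sc = 4  # n_state + n_ctrl
learn_q_logit = torch.nn.Parameter(torch.zeros(n_sc))
learn_p = torch.nn.Parameter(0.1 * torch.randn(n_sc))
opt = torch.optim.RMSprop([learn_q_logit, learn_p], lr=1e-2, alpha=0.5)

q = torch.sigmoid(learn_q_logit)  # diagonal cost weights in (0, 1)
p = q.sqrt() * learn_p            # linear cost term, scale tied to q

loss = (q - 0.5).pow(2).sum() + p.pow(2).sum()  # stand-in loss
opt.zero_grad()
loss.backward()
learn_p.grad.zero_()  # round-robin phase: update only q this step
opt.step()
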
Example #5
import argparse
import time

import numpy as np
import torch

# Module paths are assumptions: pendulum and cartpole follow the mpc.pytorch
# examples, while Gemini_flight_dynamics and the NUM_ENSEMBLE_*/PATH_*
# constants are project-specific and defined elsewhere.
from mpc.env_dx import pendulum, cartpole


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Gemini_flight_dynamics')
    args = parser.parse_args()

    torch.manual_seed(1)
    n_batch = 1
    if args.env == 'pendulum':
        T = 5
        params = torch.tensor((10., 1., 1.))  # Gravity, mass, length.
        START = time.time()
        # Reuse one model for both control and prediction in this branch so
        # the dx_control/dx_predict references below stay defined.
        dx_control = dx_predict = pendulum.PendulumDx(params, simple=True)
        END = time.time()
        print('initialize model time:', END - START)
        xinit = torch.zeros(n_batch, dx_control.n_state)
        th = 1.0
        xinit[:, 0] = np.cos(th)
        xinit[:, 1] = np.sin(th)
        xinit[:, 2] = -0.5
    elif args.env == 'cartpole':
        T = 20
        dx_control = dx_predict = cartpole.CartpoleDx()
        xinit = torch.zeros(n_batch, dx_control.n_state)
        th = 0.5
        xinit[:, 2] = np.cos(th)
        xinit[:, 3] = np.sin(th)
    elif args.env == 'Gemini_flight_dynamics':
        T = 5
        # START = time.time()
        dx_control = Gemini_flight_dynamics.flight_dynamics(
            T, n_batch, NUM_ENSEMBLE_CONTROL, PATH_CONTROL)
        dx_predict = Gemini_flight_dynamics.flight_dynamics(
            T, n_batch, NUM_ENSEMBLE_PREDICT, PATH_PREDICT)
        # END = time.time()
        # print('initialize model time:', END - START)
        xinit = np.expand_dims(np.array([
            3.0953e-03, 3.8064e-03, 4.1662e-03, -7.4164e-02, 2.0541e-02,
            3.1526e+00, -2.7789e-02, 4.7386e-02, -3.4559e-02, 5.2193e-04,
            3.8532e-03, 3.3118e-03, -7.3201e-02, 2.0192e-02, 3.1536e+00,
            -1.0387e-01, -1.9162e-02, -5.6059e-02, -2.5459e-03, 5.3939e-03,
            1.8183e-03, -7.1328e-02, 2.0130e-02, 3.1550e+00, -4.6090e-02,
            5.5544e-02, -7.8294e-02, 6.0124e-02, -2.6421e-02, 8.4291e-03,
            6.8993e-01, 7.2849e-02, -1.3616e-02, 1.0076e-02, 6.9233e-01,
        ], dtype='single'), axis=0).repeat(n_batch, axis=0)

    else:
        assert False, 'unrecognized env: {}'.format(args.env)

    q, p = dx_control.get_true_obj()

    # Warm-start at the model's goal control; assumes the dynamics model
    # exposes a goal_ctrl attribute (the flight-dynamics model does).
    u = np.tile(dx_control.goal_ctrl, (T, n_batch, 1))
    # u = None
    ep_length = 100
    x_plot = []
    u_plot = []
    MPC_time = 0.
    # error_int = torch.zeros(T, n_batch, dx_control.n_ctrl)
    # prev_error = torch.zeros(T, n_batch, dx_control.n_ctrl)
    for t in range(ep_length):
        start_ilqr = time.time()
        x, u = solve_lqr(dx_control, xinit, q, p, T,
                         dx_control.linesearch_decay,
                         dx_control.max_linesearch_iter, u)
        end_ilqr = time.time()
        print('one step MPC:', end_ilqr - start_ilqr)
        MPC_time += end_ilqr - start_ilqr
        # print('epoch:', t, '| u:', u[0])
        x_plot.append(x)
        u_plot.append(u)
        # error = torch.cat((dx_control.goal_state[6:9], dx_control.goal_state[2:3])).repeat(T, n_batch, 1) - torch.cat((x[:,:,6:9], x[:,:,2:3]), dim=2)
        # error_int = util.eclamp(error_int + I * error * dt, -torch.tensor([0.5, 0.5, 0.5, 0.5]).repeat(T, n_batch, 1), torch.tensor([0.5, 0.5, 0.5, 0.5]).repeat(T, n_batch, 1))
        # u = util.eclamp(P * error + error_int + D * (error - prev_error) / dt + dx_control.goal_ctrl.repeat(T, n_batch, 1), dx_control.lower, dx_control.upper)
        # prev_error = error
        # print('initial u:', u)
        xinit = dx_predict(x[0], u[0])
        # u = torch.cat((u[1:-1], u[-2:]), 0).contiguous()
        # u = dx_control.goal_ctrl.repeat(T, n_batch, 1)
        # u = u[0].repeat(T, n_batch, 1)
        # xinit = x[1]
    print('average mpc one step time:', MPC_time / ep_length)
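The commented-out line u = torch.cat((u[1:-1], u[-2:]), 0) is the receding-horizon warm start used in Example #2: shift the solved control sequence one step forward and pad the tail; here the full previous solution is reused instead. A quick self-contained illustration of what the shift does (sizes are illustrative):

import torch

T, n_batch, n_ctrl = 5, 1, 4
u = torch.arange(T, dtype=torch.float32).view(T, 1, 1).repeat(1, n_batch, n_ctrl)

# Drop the executed first control, keep the middle, duplicate the tail so
# the warm start still has horizon length T.
u_shift = torch.cat((u[1:-1], u[-2:]), 0).contiguous()
assert u_shift.shape == u.shape
print(u_shift[:, 0, 0])  # tensor([1., 2., 3., 3., 4.])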