def __init__(self, env, lqr_iter=500, mpc_T=20, slew_rate_penalty=None):
    """Set up the experiment wrapper for a named environment.

    Args:
        env: environment name; one of 'pendulum', 'cartpole',
            or 'pendulum-complex'.
        lqr_iter: LQR iteration budget, stored for later use.
        mpc_T: MPC planning horizon, stored for later use.
        slew_rate_penalty: optional control slew-rate penalty, stored
            for later use.
    """
    self.env = env

    # Build the ground-truth dynamics model for the requested environment.
    if env == 'pendulum':
        true_dx = pendulum.PendulumDx()
    elif env == 'cartpole':
        true_dx = cartpole.CartpoleDx()
    elif env == 'pendulum-complex':
        # Five-parameter (non-simple) pendulum. The first three presumably
        # mirror the simple model's (gravity, mass, length); the last two
        # are extra complexity terms -- TODO confirm against PendulumDx.
        complex_params = torch.tensor((10., 1., 1., 1.0, 0.1))
        true_dx = pendulum.PendulumDx(complex_params, simple=False)
    else:
        assert False
    self.true_dx = true_dx

    self.lqr_iter = lqr_iter
    self.mpc_T = mpc_T
    self.slew_rate_penalty = slew_rate_penalty
    self.grad_method = GradMethods.AUTO_DIFF

    # Datasets are populated later by a separate loading step.
    self.train_data = None
    self.val_data = None
    self.test_data = None
def main():
    """Roll out a receding-horizon LQR controller and render a video.

    Saves one PNG frame per control step, stitches the frames into
    ctrl_<env>.mp4 with ffmpeg, and then deletes the frames.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='pendulum')
    args = parser.parse_args()

    n_batch = 1
    if args.env == 'pendulum':
        T = 20
        dx = pendulum.PendulumDx()
        # Initial state: angle encoded as (cos, sin) at 1 rad; the third
        # component is presumably angular velocity -- confirm in PendulumDx.
        xinit = torch.zeros(n_batch, dx.n_state)
        angle = 1.0
        xinit[:, 0] = np.cos(angle)
        xinit[:, 1] = np.sin(angle)
        xinit[:, 2] = -0.5
    elif args.env == 'cartpole':
        T = 20
        dx = cartpole.CartpoleDx()
        # Pole angle encoded as (cos, sin) at indices 2 and 3.
        xinit = torch.zeros(n_batch, dx.n_state)
        angle = 0.5
        xinit[:, 2] = np.cos(angle)
        xinit[:, 3] = np.sin(angle)
    else:
        assert False

    q, p = dx.get_true_obj()

    frame_fname = '{:03d}.png'.format
    ep_length = 100
    u = None
    for t in range(ep_length):
        x, u = solve_lqr(
            dx, xinit, q, p, T,
            dx.linesearch_decay, dx.max_linesearch_iter, u)

        # Render this step's frame to disk.
        fig, ax = dx.get_frame(x[0])
        fig.savefig(frame_fname(t))
        plt.close(fig)

        # Warm-start the next solve with the shifted control sequence and
        # advance the state one step along the planned trajectory.
        u = torch.cat((u[1:-1], u[-2:]), 0).contiguous()
        xinit = x[1]

    vid_file = 'ctrl_{}.mp4'.format(args.env)
    if os.path.exists(vid_file):
        os.remove(vid_file)
    cmd = ('/usr/bin/ffmpeg -loglevel quiet '
           '-r 32 -f image2 -i %03d.png -vcodec '
           'libx264 -crf 25 -pix_fmt yuv420p {}').format(vid_file)
    os.system(cmd)

    # Clean up the individual frames now that the video exists.
    for step in range(ep_length):
        os.remove(frame_fname(step))
def dataset_loss(self, loader, warmstart=None):
    """Mean squared imitation loss of the current model over a loader.

    Args:
        loader: iterable of (xinits, xs, us, idxs) batches, where `us`
            holds the expert controls and `idxs` indexes into the
            warm-start buffer.
        warmstart: persistent tensor of per-example warm-start controls,
            indexed by `idxs` and transposed to (T, n_batch, n_ctrl) for
            the solver; updated in place with each batch's solution.
            Required when self.mode is 'empc' or 'sysid'.

    Returns:
        float: mean of (us - pred_u)**2 over the whole dataset.

    Raises:
        ValueError: if self.mode is not one of 'nn', 'empc', 'sysid'.
    """
    true_q, true_p = self.env.true_dx.get_true_obj()
    true_q, true_p = true_q.to(self.device), true_p.to(self.device)

    losses = []
    for xinits, xs, us, idxs in loader:
        if self.mode == 'nn':
            pred_u = self.lstm_forward(xinits)
        elif self.mode in ['empc', 'sysid']:
            if self.env_name == 'pendulum-complex':
                # When learning the dynamics we fit the simple pendulum
                # model to the complex environment; otherwise evaluate
                # with the full model. (The two previously duplicated
                # branches differed only in this flag.)
                dx = pendulum.PendulumDx(self.env_params,
                                         simple=bool(self.learn_dx))
            else:
                dx = self.env.true_dx.__class__(self.env_params)
            if self.learn_cost:
                # q squashed into (0, 1); p rescaled by sqrt(q) to match
                # the training-time cost parameterization.
                q = torch.sigmoid(self.learn_q_logit)
                p = q.sqrt() * self.learn_p
            else:
                q, p = true_q, true_p
            _, pred_u = self.env.mpc(
                dx, xinits, q, p,
                u_init=warmstart[idxs].transpose(0, 1),
            )
            pred_u = pred_u.transpose(0, 1)
            # Persist the solution as the warm start for the next pass.
            warmstart[idxs] = pred_u
        else:
            # Previously an unknown mode fell through to a NameError on
            # pred_u; fail explicitly instead.
            raise ValueError('unrecognized mode: {}'.format(self.mode))

        assert pred_u.shape == us.shape
        losses.append((us.detach() - pred_u).pow(2).mean(dim=1))

    return torch.cat(losses).mean().item()
def run(self):
    """Train the model according to self.mode ('nn', 'empc', or 'sysid').

    Side effects: writes train_losses.csv, val_test_losses.csv, and
    (when enabled) dx_hist.csv / cost_hist.csv under self.save, and
    pickles the best-on-validation model to best.pkl.

    NOTE(review): the CSV file handles opened here are never closed;
    they are flushed after every write, so data is not lost, but the
    descriptors leak for the process lifetime.
    """
    torch.manual_seed(self.seed)

    # --- Per-iteration training-loss CSV. ---
    loss_names = ['epoch']
    loss_names.append('im_loss')
    if self.learn_dx:
        loss_names.append('sysid_loss')
    fname = os.path.join(self.save, 'train_losses.csv')
    train_loss_f = open(fname, 'w')
    train_loss_f.write('{}\n'.format(','.join(loss_names)))
    train_loss_f.flush()

    # --- Per-epoch validation/test-loss CSV. ---
    fname = os.path.join(self.save, 'val_test_losses.csv')
    vt_loss_f = open(fname, 'w')
    loss_names = ['epoch']
    loss_names += ['im_loss_val', 'im_loss_test']
    # if self.learn_dx:
    #     loss_names += ['sysid_loss_val', 'im_loss_val']
    vt_loss_f.write('{}\n'.format(','.join(loss_names)))
    vt_loss_f.flush()

    if self.learn_dx:
        # Dynamics-parameter history; the first row is the ground truth.
        fname = os.path.join(self.save, 'dx_hist.csv')
        dx_f = open(fname, 'w')
        dx_f.write(','.join(
            map(str, self.env.true_dx.params.cpu().detach().numpy().tolist())))
        dx_f.write('\n')
        dx_f.flush()
    if self.learn_cost:
        # Cost-parameter history; the first row is the ground truth.
        fname = os.path.join(self.save, 'cost_hist.csv')
        cost_f = open(fname, 'w')
        cost_f.write(','.join(
            map(
                str,
                torch.cat((self.true_q,
                           self.true_p)).cpu().detach().numpy().tolist())))
        cost_f.write('\n')
        cost_f.flush()

    # --- Optimizer over whatever is being learned in this mode. ---
    if self.mode == 'nn':
        opt = optim.Adam(
            list(self.state_emb.parameters()) +
            list(self.ctrl_emb.parameters()) +
            list(self.decode.parameters()) +
            list(self.cell.parameters()), 1e-4)
    elif self.mode == 'empc':
        params1 = []
        if self.learn_cost:
            params1 += [self.learn_q_logit, self.learn_p]
        if self.learn_dx:
            params1.append(self.env_params)
        params = [{
            'params': params1,
            'lr': 1e-2,
            'alpha': 0.5,
        }]
        # if self.learn_dx and self.env_name == 'pendulum-complex':
        #     params.append({
        #         'params': self.extra_dx.parameters(),
        #         'lr': 1e-4,
        #     })
        opt = optim.RMSprop(params)
    elif self.mode == 'sysid':
        params = [{
            'params': self.env_params,
            'lr': 1e-2,
            'alpha': 0.5,
        }]
        # if self.env_name == 'pendulum-complex':
        #     params.append({
        #         'params': self.extra_dx.parameters(),
        #         'lr': 1e-4,
        #     })
        opt = optim.RMSprop(params)
    else:
        assert False

    T = self.env.mpc_T
    # MPC-based modes keep a per-example warm-start buffer for the solver,
    # stored as (n_examples, T, n_ctrl) and transposed on use.
    if self.mode in ['empc', 'sysid']:
        train_warmstart = torch.zeros(self.n_train, T,
                                      self.n_ctrl).to(self.device)
        val_warmstart = torch.zeros(self.env.val_data.shape[0], T,
                                    self.n_ctrl).to(self.device)
        test_warmstart = torch.zeros(self.env.test_data.shape[0], T,
                                     self.n_ctrl).to(self.device)
    else:
        train_warmstart = val_warmstart = test_warmstart = None

    train_data, train = self.make_data(self.env.train_data[:self.n_train],
                                       shuffle=True)
    val_data, val = self.make_data(self.env.val_data)
    test_data, test = self.make_data(self.env.test_data)

    best_val_loss = None

    true_q, true_p = self.env.true_dx.get_true_obj()
    true_q, true_p = true_q.to(self.device), true_p.to(self.device)

    n_train_batch = len(train)

    # nom_u = None

    # TODO
    # When learning the cost, alternate between updating q and p every
    # this many epochs.
    learn_cost_round_robin_interval = 10
    cost_update_q = False
    for i in range(self.n_epoch):
        if i > 0 and i % learn_cost_round_robin_interval == 0:
            cost_update_q = not cost_update_q
        # Periodically reset the warm starts so the solver does not keep
        # tracking stale solutions as the learned parameters move.
        if self.mode in ['empc', 'sysid'] \
           and i % self.restart_warmstart_every == 0:
            train_warmstart.zero_()
            val_warmstart.zero_()
            test_warmstart.zero_()
        for j, (xinits, xs, us, idxs) in enumerate(train):
            if self.mode == 'nn':
                # pred_u = self.policy(xinits)
                # pred_u = pred_u.reshape(-1, self.env.mpc_T, self.n_ctrl)
                pred_u = self.lstm_forward(xinits)
                assert pred_u.shape == us.shape
                im_loss = (us.detach() - pred_u).pow(2).mean()
            elif self.mode in ['empc', 'sysid']:
                if self.learn_dx:
                    if self.env_name == 'pendulum-complex':
                        # Fit the simple model to the complex environment.
                        dx = pendulum.PendulumDx(self.env_params,
                                                 simple=True)
                        # TODO: Hacky to have this here.
                        # class CombDx(nn.Module):
                        #     def __init__(self):
                        #         super().__init__()
                        #     def forward(_self, x, u):
                        #         return simple_dx(x,u) + 0.1*self.extra_dx(x,u)
                        # dx = CombDx()
                    else:
                        dx = self.env.true_dx.__class__(self.env_params)
                else:
                    dx = self.env.true_dx
                if self.learn_cost:
                    # q squashed into (0, 1); p rescaled by sqrt(q).
                    q = torch.sigmoid(self.learn_q_logit)
                    p = q.sqrt() * self.learn_p
                else:
                    q, p = true_q, true_p
                # Solve the MPC problem, warm-started from this batch's
                # previous solution; persist the new one.
                nom_x, nom_u = self.env.mpc(
                    dx, xinits, q, p,
                    u_init=train_warmstart[idxs].transpose(0, 1),
                    # u_init=nom_u,
                    # eps_override=0.1,
                    # lqr_iter_override=100,
                )
                nom_u = nom_u.transpose(0, 1)
                train_warmstart[idxs] = nom_u
                assert nom_u.shape == us.shape
                # Imitation loss against the expert controls.
                im_loss = (us.detach() - nom_u).pow(2).mean()
                if self.learn_dx:
                    # One-step prediction (sysid) loss over all flattened
                    # (state, control) -> next-state transition pairs.
                    xs_flat = xs[:, :-1].transpose(0, 2).contiguous().view(
                        self.n_state, -1).t()
                    us_flat = us[:, :-1].transpose(0, 2).contiguous().view(
                        self.n_ctrl, -1).t()
                    pred_next_x = dx(xs_flat, us_flat).t().view(
                        self.n_state, T - 1, -1).transpose(0, 2)
                    next_x = xs[:, 1:]
                    assert next_x.shape == pred_next_x.shape
                    sysid_loss = (next_x.detach() - pred_next_x).pow(2).mean()
            else:
                assert False

            # Log fractional-epoch progress and losses.
            t = [i + j / n_train_batch, im_loss.item()]
            if self.learn_dx:
                t.append(sysid_loss.item())
            t = ','.join(map(str, t))
            print(t)
            train_loss_f.write(t + '\n')
            train_loss_f.flush()

            opt.zero_grad()
            # 'sysid' trains only on the prediction loss; the other modes
            # train on the imitation loss.
            if self.mode == 'sysid':
                sysid_loss.backward()
            else:
                im_loss.backward()
            if self.learn_cost:
                # Round-robin: zero the gradient of whichever cost
                # parameter is frozen in this phase.
                if cost_update_q:
                    print('only updating q')
                    self.learn_p.grad.zero_()
                else:
                    print('only updating p')
                    self.learn_q_logit.grad.zero_()
            if self.learn_dx:
                if self.env_name == 'pendulum-complex':
                    # Only the first three true params correspond to the
                    # simple model being learned.
                    true_params = self.env.true_dx.params[:3]
                else:
                    true_params = self.env.true_dx.params
                # Print learned vs true parameters side by side.
                print(
                    np.array_str(torch.stack(
                        (self.env_params,
                         true_params)).cpu().detach().numpy(),
                                 precision=2,
                                 suppress_small=True))
                dx_f.write(','.join(
                    map(str,
                        self.env_params.cpu().detach().numpy().tolist())))
                dx_f.write('\n')
                dx_f.flush()
            if self.learn_cost:
                # Print true vs learned cost parameters side by side.
                print(
                    np.array_str(
                        torch.stack((
                            torch.cat((true_q, true_p)),
                            torch.cat((q, p)),
                            # torch.cat((q.grad, p.grad)),
                        )).cpu().detach().numpy(),
                        precision=2,
                        suppress_small=True))
                cost_f.write(','.join(
                    map(str, torch.cat(
                        (q, p)).cpu().detach().numpy().tolist())))
                cost_f.write('\n')
                cost_f.flush()
            opt.step()

        # import ipdb; ipdb.set_trace()
        # if self.learn_cost:
        #     I = self.learn_q.data < 1e-6
        #     self.learn_q.data[I] = 1e-6

        # --- End-of-epoch evaluation and checkpointing. ---
        val_loss = self.dataset_loss(val, val_warmstart)
        test_loss = self.dataset_loss(test, test_warmstart)
        t = [i, val_loss, test_loss]
        t = ','.join(map(str, t))
        vt_loss_f.write(t + '\n')
        vt_loss_f.flush()

        self.last_epoch = i
        if best_val_loss is None or val_loss < best_val_loss:
            best_val_loss = val_loss
            fname = os.path.join(self.save, 'best.pkl')
            print('Saving best model to {}'.format(fname))
            with open(fname, 'wb') as f:
                pkl.dump(self, f)
def main():
    """Closed-loop MPC rollout with timing, defaulting to the Gemini
    flight-dynamics model.

    NOTE(review): the 'pendulum' and 'cartpole' branches only define
    `dx`, but the code after the branch uses `dx_control` and
    `dx_predict` -- those environments will raise NameError here.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Gemini_flight_dynamics')
    args = parser.parse_args()

    torch.manual_seed(1)

    n_batch = 1
    if args.env == 'pendulum':
        T = 5
        params = torch.tensor((10., 1., 1.))  # Gravity, mass, length.
        START = time.time()
        dx = pendulum.PendulumDx(params, simple=True)
        END = time.time()
        print('initialize model time:', END - START)
        xinit = torch.zeros(n_batch, dx.n_state)
        th = 1.0
        xinit[:, 0] = np.cos(th)
        xinit[:, 1] = np.sin(th)
        xinit[:, 2] = -0.5
    elif args.env == 'cartpole':
        T = 20
        dx = cartpole.CartpoleDx()
        xinit = torch.zeros(n_batch, dx.n_state)
        th = 0.5
        xinit[:, 2] = np.cos(th)
        xinit[:, 3] = np.sin(th)
    elif args.env == 'Gemini_flight_dynamics':
        T = 5
        # Separate model ensembles: one used inside the MPC solver, one
        # used to propagate the "real" state between solves.
        # NUM_ENSEMBLE_*/PATH_* are module-level constants defined
        # elsewhere in the file.
        # START = time.time()
        dx_control = Gemini_flight_dynamics.flight_dynamics(
            T, n_batch, NUM_ENSEMBLE_CONTROL, PATH_CONTROL)
        dx_predict = Gemini_flight_dynamics.flight_dynamics(
            T, n_batch, NUM_ENSEMBLE_PREDICT, PATH_PREDICT)
        # END = time.time()
        # print('initialize model time:', END - START)
        # Hard-coded 35-dimensional initial state (single precision),
        # tiled across the batch.
        xinit = np.expand_dims(np.array([
            3.0953e-03, 3.8064e-03, 4.1662e-03, -7.4164e-02, 2.0541e-02,
            3.1526e+00, -2.7789e-02, 4.7386e-02, -3.4559e-02, 5.2193e-04,
            3.8532e-03, 3.3118e-03, -7.3201e-02, 2.0192e-02, 3.1536e+00,
            -1.0387e-01, -1.9162e-02, -5.6059e-02, -2.5459e-03, 5.3939e-03,
            1.8183e-03, -7.1328e-02, 2.0130e-02, 3.1550e+00, -4.6090e-02,
            5.5544e-02, -7.8294e-02, 6.0124e-02, -2.6421e-02, 8.4291e-03,
            6.8993e-01, 7.2849e-02, -1.3616e-02, 1.0076e-02, 6.9233e-01
        ], dtype='single'), axis=0).repeat(n_batch, axis=0)
    else:
        assert False

    q, p = dx_control.get_true_obj()

    # Initial control sequence: the goal control repeated over the horizon.
    u = np.tile(dx_control.goal_ctrl, (T, n_batch, 1))
    # u= None
    ep_length = 100
    x_plot = []
    u_plot = []
    MPC_time = 0.
    # error_int = torch.zeros(T, n_batch, dx_control.n_ctrl)
    # prev_error = torch.zeros(T, n_batch, dx_control.n_ctrl)
    for t in range(ep_length):
        # Solve one receding-horizon problem, warm-started from the
        # previous solution, and time it.
        start_ilqr = time.time()
        x, u = solve_lqr(dx_control, xinit, q, p, T,
                         dx_control.linesearch_decay,
                         dx_control.max_linesearch_iter, u)
        end_ilqr = time.time()
        print('one step MPC:', end_ilqr - start_ilqr)
        MPC_time += end_ilqr - start_ilqr
        # print('epoch:', t, '| u:', u[0])
        x_plot.append(x)
        u_plot.append(u)
        # error = torch.cat((dx_control.goal_state[6:9], dx_control.goal_state[2:3])).repeat(T, n_batch, 1) - torch.cat((x[:,:,6:9], x[:,:,2:3]), dim=2)
        # error_int = util.eclamp(error_int + I * error * dt, -torch.tensor([0.5, 0.5, 0.5, 0.5]).repeat(T, n_batch, 1), torch.tensor([0.5, 0.5, 0.5, 0.5]).repeat(T, n_batch, 1))
        # u = util.eclamp(P * error + error_int + D * (error - prev_error) / dt + dx_control.goal_ctrl.repeat(T, n_batch, 1), dx_control.lower, dx_control.upper)
        # prev_error = error
        # print('initial u:', u)
        # Step the "real" state forward with the prediction ensemble.
        xinit = dx_predict(x[0], u[0])
        # u = torch.cat((u[1:-1], u[-2:]), 0).contiguous()
        # u = dx_control.goal_ctrl.repeat(T, n_batch, 1)
        # u = u[0].repeat(T, n_batch, 1)
        # xinit = x[1]
    print('average mpc one step time:', MPC_time / ep_length)