# Standard-library and third-party imports needed by this module.
import copy

import numpy as np
import torch

# Project-local imports; the exact module paths are assumed here and should
# match this repo's layout.
from polo_agent import POLOAgent
from models import MLP
from replay_buffer import ReplayBuffer
import traj


class BCAgent(POLOAgent):
    """
    An agent extending POLO that behavior-clones the planner's predicted
    actions and uses the learned policy as a prior for MPC.
    """

    def __init__(self, params):
        super(BCAgent, self).__init__(params)

        # Initialize policy network
        pol_params = self.params['p-bc']['pol_params']
        pol_params['input_size'] = self.N
        pol_params['output_size'] = self.M
        if 'final_activation' not in pol_params:
            pol_params['final_activation'] = torch.tanh
        self.pol = MLP(pol_params)

        # Create policy optimizer
        ppar = self.params['p-bc']['pol_optim']
        self.pol_optim = torch.optim.Adam(
            self.pol.parameters(), lr=ppar['lr'], weight_decay=ppar['reg'])

        # Use a replay buffer that will save planner actions
        self.pol_buf = ReplayBuffer(
            self.N, self.M, self.params['p-bc']['buf_size'])

        # Logging (store cum_rew, cum_emp_rew)
        self.hist['pols'] = np.zeros((self.T, 2))

        self.has_pol = True
        self.pol_cache = ()

    def get_action(self):
        """
        BCAgent generates a planned trajectory using the behavior-cloned
        policy and then optimizes it via MPC.
        """
        self.pol.eval()

        # Run a rollout using the policy starting from the current state
        infos = self.get_traj_info()

        self.hist['pols'][self.time] = infos[3:5]
        self.pol_cache = (infos[0], infos[2])

        self.prior_actions = infos[1]

        # Generate trajectory via MPC, warm-started with the policy rollout
        action = super(BCAgent, self).get_action(prior=self.prior_actions)

        # Add final planning trajectory to BC buffer, replacing the first
        # planned state with the actually observed state
        fin_states, fin_rews = self.cache[2], self.cache[3]
        fin_states = np.concatenate(([self.prev_obs], fin_states[1:]))

        # Only the first pb_pct fraction of the planned trajectory is saved
        pb_pct = self.params['p-bc']['pb_pct']
        pb_len = int(pb_pct * fin_states.shape[0])
        for t in range(pb_len):
            self.pol_buf.update(fin_states[t], fin_states[t + 1],
                                fin_rews[t], self.planned_actions[t], False)

        return action

    def do_updates(self):
        """
        Learn from the saved buffer of planned actions.
        """
        super(BCAgent, self).do_updates()
        if self.time % self.params['p-bc']['update_freq'] == 0:
            self.update_pol()

    def update_pol(self):
        """
        Update the policy via behavior cloning on the planner actions.
        """
        self.pol.train()
        params = self.params['p-bc']

        # Sample minibatch indices from the filled portion of the buffer
        size = min(self.pol_buf.size, self.pol_buf.total_in)
        num_inds = params['batch_size'] * params['grad_steps']
        inds = np.random.randint(0, size, size=num_inds)

        states = self.pol_buf.buffer['s'][inds]
        acts = self.pol_buf.buffer['a'][inds]

        states = torch.tensor(states, dtype=self.dtype)
        actions = torch.tensor(acts, dtype=self.dtype)

        for i in range(params['grad_steps']):
            bi, ei = i * params['batch_size'], (i + 1) * params['batch_size']

            # Train based on L2 distance between actions and predictions
            preds = self.pol.forward(states[bi:ei])
            preds = torch.squeeze(preds, dim=-1)
            targets = torch.squeeze(actions[bi:ei], dim=-1)
            loss = torch.nn.functional.mse_loss(preds, targets)

            self.pol_optim.zero_grad()
            loss.backward()
            self.pol_optim.step()

    def get_traj_info(self):
        """
        Run the policy for a full trajectory and return details about the
        trajectory.
        """
        env_state = self.env.sim.get_state() if self.mujoco else None

        infos = traj.eval_traj(
            copy.deepcopy(self.env), env_state, self.prev_obs,
            mujoco=self.mujoco, perturb=self.perturb, H=self.H,
            gamma=self.gamma, act_mode='deter', pt=(self.pol, 0),
            terminal=self.val_ens, tvel=self.tvel)

        return infos

    def print_logs(self):
        """
        BC-specific logging information.
        """
        bi, ei = super(BCAgent, self).print_logs()

        self.print('BC metrics', mode='head')
        self.print('policy traj rew', self.hist['pols'][self.time - 1][0])
        self.print('policy traj emp rew', self.hist['pols'][self.time - 1][1])

        return bi, ei

    def test_policy(self):
        """
        Run the BC action selection mechanism.
        """
        env = copy.deepcopy(self.env)
        obs = env.reset()

        if self.tvel is not None:
            env.set_target_vel(self.tvel)
            obs = env._get_obs()

        env_state = env.sim.get_state() if self.mujoco else None

        infos = traj.eval_traj(
            env, env_state, obs, mujoco=self.mujoco, perturb=self.perturb,
            H=self.eval_len, gamma=1, act_mode='deter', pt=(self.pol, 0),
            tvel=self.tvel)

        self.hist['pol_test'][self.time] = infos[3]
# Standard-library and third-party imports needed by this module.
import numpy as np
import torch

# Project-local imports; the exact module paths are assumed here and should
# match this repo's layout.
from mpc_agent import MPCAgent
from models import Ensemble
from replay_buffer import ReplayBuffer


class POLOAgent(MPCAgent):
    """
    MPC-based agent that uses the Plan Online, Learn Offline (POLO)
    framework (Lowrey et al., 2018) for trajectory optimization.
    """

    def __init__(self, params):
        super(POLOAgent, self).__init__(params)
        self.H_backup = self.params['polo']['H_backup']

        # Create ensemble of value functions
        model_params = params['polo']['ens_params']['model_params']
        model_params['input_size'] = self.N
        model_params['output_size'] = 1

        params['polo']['ens_params']['dtype'] = self.dtype
        params['polo']['ens_params']['device'] = self.device

        self.val_ens = Ensemble(self.params['polo']['ens_params'])

        # Learn from replay buffer
        self.polo_buf = ReplayBuffer(
            self.N, self.M, self.params['polo']['buf_size'])

        # Value (from forward), value mean, value std
        self.hist['vals'] = np.zeros((self.T, 3))

    def get_action(self, prior=None):
        """
        POLO selects an action via MPC optimization with an optimistic
        terminal value function.
        """
        self.val_ens.eval()

        # Get value of current state
        s = torch.tensor(self.prev_obs, dtype=self.dtype)
        s = s.to(device=self.device)
        current_val = self.val_ens.forward(s)[0]
        current_val = torch.squeeze(current_val, -1)
        current_val = current_val.detach().cpu().numpy()

        # Get prediction of every function in the ensemble
        preds = self.val_ens.get_preds_np(self.prev_obs)

        # Log information from the value function
        self.hist['vals'][self.time] = \
            np.array([current_val, np.mean(preds), np.std(preds)])

        # Run MPC to get action
        act = super(POLOAgent, self).get_action(
            terminal=self.val_ens, prior=prior)

        return act

    def action_taken(self, prev_obs, obs, rew, done, ifo):
        """
        Update the buffer for value function learning.
        """
        self.polo_buf.update(prev_obs, obs, rew, done)

    def do_updates(self):
        """
        POLO learns a value function from its past true history of
        interactions with the environment.
        """
        super(POLOAgent, self).do_updates()

        if self.time % self.params['polo']['update_freq'] == 0:
            self.val_ens.update_from_buf(
                self.polo_buf, self.params['polo']['grad_steps'],
                self.params['polo']['batch_size'],
                self.params['polo']['H_backup'], self.gamma)

    def print_logs(self):
        """
        POLO-specific logging information.
        """
        bi, ei = super(POLOAgent, self).print_logs()

        self.print('POLO metrics', mode='head')
        self.print('current state val', self.hist['vals'][self.time - 1][0])
        self.print('current state std', self.hist['vals'][self.time - 1][2])

        return bi, ei
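
# ---------------------------------------------------------------------------
# Standalone sketch of the optimistic ensemble value that POLO uses as the
# MPC terminal term. Lowrey et al., 2018 soften the max over ensemble
# members with a log-sum-exp; whether this repo's Ensemble.forward does
# exactly this depends on ens_params, so the kappa temperature and the
# aggregation below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import math

    obs_dim, n_members, kappa = 4, 6, 1.0
    members = [torch.nn.Sequential(torch.nn.Linear(obs_dim, 32),
                                   torch.nn.ReLU(),
                                   torch.nn.Linear(32, 1))
               for _ in range(n_members)]

    s = torch.randn(obs_dim)
    preds = torch.stack([m(s) for m in members])  # shape (n_members, 1)

    # Soft maximum over members: tends to the ensemble mean as kappa -> 0
    # and to the hard max as kappa -> inf, giving an optimism knob that
    # drives exploration toward states the ensemble disagrees on.
    optimistic_val = (torch.logsumexp(kappa * preds, dim=0)
                      - math.log(n_members)) / kappa
    print('optimistic value:', optimistic_val.item())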