def update(self,
           args,
           policy_storage,
           encoder=None,                  # variBAD encoder
           rlloss_through_encoder=False,  # whether or not to backprop the RL loss through the encoder
           compute_vae_loss=None          # function that can compute the VAE loss
           ):

    # -- get action values --
    advantages = policy_storage.returns[:-1] - policy_storage.value_preds[:-1]

    if rlloss_through_encoder:
        # re-compute the embeddings (to build the computation graph from scratch)
        utl.recompute_embeddings(policy_storage, encoder, sample=False, update_idx=0)

    data_generator = policy_storage.feed_forward_generator(advantages, 1)
    for sample in data_generator:

        obs_batch, actions_batch, latent_sample_batch, latent_mean_batch, latent_logvar_batch, \
            value_preds_batch, return_batch, old_action_log_probs_batch, adv_targ = sample

        if not rlloss_through_encoder:
            obs_batch = obs_batch.detach()
            if latent_sample_batch is not None:
                latent_sample_batch = latent_sample_batch.detach()
                latent_mean_batch = latent_mean_batch.detach()
                latent_logvar_batch = latent_logvar_batch.detach()

        obs_aug = utl.get_augmented_obs(args=args, obs=obs_batch,
                                        latent_sample=latent_sample_batch,
                                        latent_mean=latent_mean_batch,
                                        latent_logvar=latent_logvar_batch)

        values, action_log_probs, dist_entropy, action_mean, action_logstd = \
            self.actor_critic.evaluate_actions(obs_aug, actions_batch, return_action_mean=True)

        # -- UPDATE --

        # zero out the gradients
        self.optimizer.zero_grad()
        if rlloss_through_encoder:
            self.optimiser_vae.zero_grad()

        # compute policy loss and backprop
        value_loss = (return_batch - values).pow(2).mean()
        action_loss = -(adv_targ.detach() * action_log_probs).mean()

        # (loss = value loss + action loss + entropy loss, weighted)
        loss = value_loss * self.value_loss_coef + action_loss - dist_entropy * self.entropy_coef

        # compute vae loss and backprop
        if rlloss_through_encoder:
            loss += args.vae_loss_coeff * compute_vae_loss()

        # compute gradients (will attach to all networks involved in this computation)
        loss.backward()
        nn.utils.clip_grad_norm_(self.actor_critic.parameters(), args.policy_max_grad_norm)
        if encoder is not None and rlloss_through_encoder:
            nn.utils.clip_grad_norm_(encoder.parameters(), args.policy_max_grad_norm)

        # update
        self.optimizer.step()
        if rlloss_through_encoder:
            self.optimiser_vae.step()

    if (not rlloss_through_encoder) and (self.optimiser_vae is not None):
        for _ in range(args.num_vae_updates - 1):
            compute_vae_loss(update=True)

    return value_loss, action_loss, dist_entropy, loss
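# Illustrative sketch (not part of the original file): how the combined actor-critic
# objective above is composed, shown on dummy tensors so it runs standalone. The batch
# size (8) and coefficient values (0.5, 0.01) are assumptions chosen for the example;
# in the update above they come from self.value_loss_coef and self.entropy_coef.
import torch

returns = torch.randn(8, 1)
values = torch.randn(8, 1, requires_grad=True)
log_probs = torch.randn(8, 1, requires_grad=True)
entropy = torch.tensor(1.2, requires_grad=True)

advantages = (returns - values).detach()                 # advantages are treated as constants
value_loss = (returns - values).pow(2).mean()            # critic regression towards the returns
action_loss = -(advantages * log_probs).mean()           # vanilla policy gradient (no ratio clipping)
loss = 0.5 * value_loss + action_loss - 0.01 * entropy   # weighted sum, as in the update above
loss.backward()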
def update(self,
           args,
           policy_storage,
           encoder=None,                  # variBAD encoder
           rlloss_through_encoder=False,  # whether or not to backprop the RL loss through the encoder
           compute_vae_loss=None          # function that can compute the VAE loss
           ):

    # -- get action values --
    advantages = policy_storage.returns[:-1] - policy_storage.value_preds[:-1]
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

    # if this is true, we will update the VAE at every PPO update;
    # otherwise, we update it after we update the policy
    if rlloss_through_encoder:
        # recompute embeddings (to build the computation graph)
        utl.recompute_embeddings(policy_storage, encoder, sample=False, update_idx=0)

    value_loss_epoch = 0
    action_loss_epoch = 0
    dist_entropy_epoch = 0
    loss_epoch = 0
    for e in range(self.ppo_epoch):

        data_generator = policy_storage.feed_forward_generator(advantages, self.num_mini_batch)
        for sample in data_generator:

            obs_batch, actions_batch, latent_sample_batch, latent_mean_batch, latent_logvar_batch, \
                value_preds_batch, return_batch, old_action_log_probs_batch, adv_targ = sample

            if not rlloss_through_encoder:
                obs_batch = obs_batch.detach()
                if latent_sample_batch is not None:
                    latent_sample_batch = latent_sample_batch.detach()
                    latent_mean_batch = latent_mean_batch.detach()
                    latent_logvar_batch = latent_logvar_batch.detach()

            obs_aug = utl.get_augmented_obs(args, obs_batch,
                                            latent_sample=latent_sample_batch,
                                            latent_mean=latent_mean_batch,
                                            latent_logvar=latent_logvar_batch,
                                            )

            # reshape to do a single forward pass for all steps
            values, action_log_probs, dist_entropy, action_mean, action_logstd = \
                self.actor_critic.evaluate_actions(obs_aug, actions_batch, return_action_mean=True)

            # clipped surrogate objective
            ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
            surr1 = ratio * adv_targ
            surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
            action_loss = -torch.min(surr1, surr2).mean()

            if self.use_huber_loss and self.use_clipped_value_loss:
                value_pred_clipped = value_preds_batch + \
                    (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
                value_losses = F.smooth_l1_loss(values, return_batch, reduction='none')
                value_losses_clipped = F.smooth_l1_loss(value_pred_clipped, return_batch, reduction='none')
                value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean()
            elif self.use_huber_loss:
                value_loss = F.smooth_l1_loss(values, return_batch)
            elif self.use_clipped_value_loss:
                value_pred_clipped = value_preds_batch + \
                    (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
                value_losses = (values - return_batch).pow(2)
                value_losses_clipped = (value_pred_clipped - return_batch).pow(2)
                value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean()
            else:
                value_loss = 0.5 * (return_batch - values).pow(2).mean()

            # zero out the gradients
            self.optimiser.zero_grad()
            if rlloss_through_encoder:
                self.optimiser_vae.zero_grad()

            # compute policy loss and backprop
            loss = value_loss * self.value_loss_coef + action_loss - dist_entropy * self.entropy_coef

            # compute vae loss and backprop
            if rlloss_through_encoder:
                loss += args.vae_loss_coeff * compute_vae_loss()

            # compute gradients (will attach to all networks involved in this computation)
            loss.backward()
            nn.utils.clip_grad_norm_(self.actor_critic.parameters(), args.policy_max_grad_norm)
            if (encoder is not None) and rlloss_through_encoder:
                nn.utils.clip_grad_norm_(encoder.parameters(), args.policy_max_grad_norm)

            # update
            self.optimiser.step()
            if rlloss_through_encoder:
                self.optimiser_vae.step()

            value_loss_epoch += value_loss.item()
            action_loss_epoch += action_loss.item()
            dist_entropy_epoch += dist_entropy.item()
            loss_epoch += loss.item()

        if rlloss_through_encoder:
            # recompute embeddings (to build the computation graph)
            utl.recompute_embeddings(policy_storage, encoder, sample=False, update_idx=e + 1)

    if (not rlloss_through_encoder) and (self.optimiser_vae is not None):
        for _ in range(args.num_vae_updates):
            compute_vae_loss(update=True)

    num_updates = self.ppo_epoch * self.num_mini_batch
    value_loss_epoch /= num_updates
    action_loss_epoch /= num_updates
    dist_entropy_epoch /= num_updates
    loss_epoch /= num_updates

    return value_loss_epoch, action_loss_epoch, dist_entropy_epoch, loss_epoch
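# Illustrative sketch (not part of the original file): the PPO clipped surrogate used above,
# evaluated on dummy tensors. The batch size and clip_param=0.2 are assumptions for the example;
# in the update above the clip range comes from self.clip_param.
import torch

old_log_probs = torch.randn(16, 1)
new_log_probs = old_log_probs + 0.3 * torch.randn(16, 1)
adv = torch.randn(16, 1)
clip_param = 0.2

ratio = torch.exp(new_log_probs - old_log_probs)
surr1 = ratio * adv
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv
# taking the elementwise minimum makes the objective pessimistic: the policy gains nothing
# from pushing the probability ratio outside [1 - clip_param, 1 + clip_param]
action_loss = -torch.min(surr1, surr2).mean()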
def visualise_behaviour(env,
                        args,
                        policy,
                        iter_idx,
                        encoder=None,
                        image_folder=None,
                        return_pos=False,
                        **kwargs,
                        ):
    num_episodes = args.max_rollouts_per_task
    unwrapped_env = env.venv.unwrapped.envs[0].unwrapped

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        episode_latent_samples = episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    (obs_raw, obs_normalised) = env.reset()
    obs_raw = obs_raw.float().reshape((1, -1)).to(device)
    obs_normalised = obs_normalised.float().reshape((1, -1)).to(device)
    start_obs_raw = obs_raw.clone()

    # initialise the hidden state (used as initial input to the policy if it is recurrent)
    if hasattr(args, 'hidden_size'):
        hidden_state = torch.zeros((1, args.hidden_size)).to(device)
    else:
        hidden_state = None

    # keep track of what task we're in and the position of the ant
    task = env.get_task()
    pos = [[] for _ in range(args.max_rollouts_per_task)]
    start_pos = unwrapped_env.get_body_com("torso")[:2].copy()

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []
        pos[episode_idx].append(start_pos)

        if episode_idx == 0:
            if encoder is not None:
                # reset to the prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            else:
                curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

        if encoder is not None:
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            if step_idx == 1:
                episode_prev_obs[episode_idx].append(start_obs_raw.clone())
            else:
                episode_prev_obs[episode_idx].append(obs_raw.clone())

            # act
            o_aug = utl.get_augmented_obs(args,
                                          obs_normalised if args.norm_obs_for_policy else obs_raw,
                                          curr_latent_sample, curr_latent_mean, curr_latent_logvar)
            _, action, _ = policy.act(o_aug, deterministic=True)

            (obs_raw, obs_normalised), (rew_raw, rew_normalised), done, info = env.step(action.cpu().detach())
            obs_raw = obs_raw.float().reshape((1, -1)).to(device)
            obs_normalised = obs_normalised.float().reshape((1, -1)).to(device)

            # keep track of position
            pos[episode_idx].append(unwrapped_env.get_body_com("torso")[:2].copy())

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(device),
                    obs_raw,
                    rew_raw.reshape((1, 1)).float().to(device),
                    hidden_state,
                    return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(obs_raw.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
            curr_rollout_rew.append(rew_raw.clone())  # accumulate the per-episode return
            episode_actions[episode_idx].append(action.clone())

            if info[0]['done_mdp'] and not done:
                start_obs_raw = info[0]['start_state']
                start_obs_raw = torch.from_numpy(start_obs_raw).float().reshape((1, -1)).to(device)
                start_pos = unwrapped_env.get_body_com("torso")[:2].copy()
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(e) for e in episode_rewards]

    # plot the movement of the ant
    plt.figure(figsize=(5, 4 * num_episodes))
    min_dim = -3.5
    max_dim = 3.5
    span = max_dim - min_dim

    for i in range(num_episodes):
        plt.subplot(num_episodes, 1, i + 1)

        x = list(map(lambda p: p[0], pos[i]))
        y = list(map(lambda p: p[1], pos[i]))
        plt.plot(x[0], y[0], 'bo')
        plt.scatter(x, y, 1, 'g')
        plt.title('task: {}'.format(task), fontsize=15)
        if args.env_name == 'AntGoal-v0':
            plt.plot(task[0], task[1], 'rx')
        plt.ylabel('y-position (ep {})'.format(i), fontsize=15)
        if i == num_episodes - 1:
            plt.xlabel('x-position', fontsize=15)
            plt.ylabel('y-position (ep {})'.format(i), fontsize=15)
        plt.xlim(min_dim - 0.05 * span, max_dim + 0.05 * span)
        plt.ylim(min_dim - 0.05 * span, max_dim + 0.05 * span)

    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_behaviour'.format(image_folder, iter_idx))
        plt.close()
    else:
        plt.show()

    if not return_pos:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns
    else:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns, pos
def visualise_behaviour(env,
                        args,
                        policy,
                        iter_idx,
                        encoder=None,
                        image_folder=None,
                        **kwargs,
                        ):
    # TODO: are we going to use the decoders for anything? Some visualisations?

    num_episodes = args.max_rollouts_per_task
    unwrapped_env = env.venv.unwrapped.envs[0].unwrapped

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
        sample_embeddings = args.sample_embeddings
    else:
        episode_latent_samples = episode_latent_means = episode_latent_logvars = None
        sample_embeddings = False

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    (obs_raw, obs_normalised) = env.reset()
    obs_raw = obs_raw.float().reshape((1, -1)).to(device)
    obs_normalised = obs_normalised.float().reshape((1, -1)).to(device)
    start_obs_raw = obs_raw.clone()

    # initialise the hidden state (used as initial input to the policy if it is recurrent)
    if hasattr(args, 'hidden_size'):
        hidden_state = torch.zeros((1, args.hidden_size)).to(device)
    else:
        hidden_state = None

    # keep track of what task we're in and the position of the cheetah
    task = env.get_task()
    pos = [[] for _ in range(args.max_rollouts_per_task)]
    pos[0] = [unwrapped_env.get_body_com("torso")[0]]

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []

        if episode_idx == 0:
            if encoder is not None:
                # reset to the prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            else:
                curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

        if encoder is not None:
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        # keep track of position
        pos[episode_idx].append(unwrapped_env.get_body_com("torso")[0].copy())

        for step_idx in range(1, env._max_episode_steps + 1):

            if step_idx == 1:
                episode_prev_obs[episode_idx].append(start_obs_raw.clone())
            else:
                episode_prev_obs[episode_idx].append(obs_raw.clone())

            # act
            o_aug = utl.get_augmented_obs(args,
                                          obs_normalised if args.norm_obs_for_policy else obs_raw,
                                          curr_latent_sample, curr_latent_mean, curr_latent_logvar)
            _, action, _ = policy.act(o_aug, deterministic=True)

            (obs_raw, obs_normalised), (rew_raw, rew_normalised), done, info = env.step(action.cpu().detach())
            obs_raw = obs_raw.float().reshape((1, -1)).to(device)
            obs_normalised = obs_normalised.float().reshape((1, -1)).to(device)

            # keep track of position
            pos[episode_idx].append(unwrapped_env.get_body_com("torso")[0].copy())

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(device),
                    obs_raw,
                    torch.tensor(rew_raw).reshape((1, 1)).float().to(device),
                    hidden_state,
                    return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(obs_raw.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
            curr_rollout_rew.append(rew_raw.clone())  # accumulate the per-episode return
            episode_actions[episode_idx].append(action.clone())

            if info[0]['done_mdp'] and not done:
                start_obs_raw = info[0]['start_state']
                start_obs_raw = torch.from_numpy(start_obs_raw).float().reshape((1, -1)).to(device)
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(e) for e in episode_rewards]

    # plot the movement of the half-cheetah
    plt.figure(figsize=(7, 4 * num_episodes))
    min_x = min([min(p) for p in pos])
    max_x = max([max(p) for p in pos])
    span = max_x - min_x

    for i in range(num_episodes):
        plt.subplot(num_episodes, 1, i + 1)
        plt.plot(pos[i], range(len(pos[i])), 'k')
        plt.title('task: {}'.format(task), fontsize=15)
        plt.ylabel('steps (ep {})'.format(i), fontsize=15)
        if i == num_episodes - 1:
            plt.xlabel('position', fontsize=15)
        else:
            plt.xticks([])
        plt.xlim(min_x - 0.05 * span, max_x + 0.05 * span)

    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_behaviour'.format(image_folder, iter_idx))
        plt.close()
    else:
        plt.show()

    return episode_latent_means, episode_latent_logvars, \
           episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
           episode_returns
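# Illustrative sketch (not part of the original file): the kind of figure the plotting loop
# above produces, drawn here from synthetic 1D torso positions so it runs standalone. The
# random-walk positions and the two-episode layout are assumptions for the example only.
import numpy as np
import matplotlib.pyplot as plt

fake_pos = [np.cumsum(np.random.randn(200) * 0.05) for _ in range(2)]  # fake x-positions per episode
plt.figure(figsize=(7, 4 * len(fake_pos)))
for i, p in enumerate(fake_pos):
    plt.subplot(len(fake_pos), 1, i + 1)
    plt.plot(p, range(len(p)), 'k')               # x-position against step index, as above
    plt.ylabel('steps (ep {})'.format(i))
plt.xlabel('position')
plt.tight_layout()
plt.show()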
def get_value(self, obs, latent_sample, latent_mean, latent_logvar):
    obs = utl.get_augmented_obs(self.args, obs, latent_sample, latent_mean, latent_logvar)
    return self.policy.actor_critic.get_value(obs).detach()
def get_value(self, obs):
    obs = utl.get_augmented_obs(args=self.args, obs=obs)
    return self.policy.actor_critic.get_value(obs).detach()
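# Assumed behaviour of utl.get_augmented_obs, sketched for orientation only: the real helper
# lives in the project's utility module and also consults flags in `args` (e.g. which latent
# statistics to append and whether to condition on the state at all). The name
# get_augmented_obs_sketch and the signature below are illustrative assumptions, not the
# project's API: the idea is simply to concatenate the observation with whichever latent
# statistics are provided, along the feature dimension.
import torch

def get_augmented_obs_sketch(obs, latent_sample=None, latent_mean=None, latent_logvar=None):
    parts = [obs]
    for latent in (latent_sample, latent_mean, latent_logvar):
        if latent is not None:
            # flatten per-sample latents and append them to the observation features
            parts.append(latent.reshape(obs.shape[0], -1))
    return torch.cat(parts, dim=-1)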