def get_value(self, state, belief, task, latent_sample, latent_mean, latent_logvar):
    latent = utl.get_latent_for_policy(self.args,
                                       latent_sample=latent_sample,
                                       latent_mean=latent_mean,
                                       latent_logvar=latent_logvar)
    return self.policy.actor_critic.get_value(state=state, belief=belief, task=task, latent=latent).detach()
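# Note: `utl.get_latent_for_policy` is defined elsewhere in the codebase. A
# minimal sketch of the behaviour assumed throughout this file (an assumption,
# not the actual implementation): depending on the args, the policy receives
# either the sampled latent or the concatenated posterior mean/log-variance.

import torch


def get_latent_for_policy_sketch(args, latent_sample=None, latent_mean=None, latent_logvar=None):
    """Hypothetical stand-in for utl.get_latent_for_policy, for illustration only."""
    if latent_sample is None and latent_mean is None and latent_logvar is None:
        return None
    if getattr(args, 'sample_embeddings', False):  # assumed flag name
        return latent_sample
    # default: pass the full posterior (mean and log-variance) to the policy
    return torch.cat((latent_mean, latent_logvar), dim=-1)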
def before_update(self, policy):
    latent = utl.get_latent_for_policy(
        self.args,
        latent_sample=torch.stack(self.latent_samples[:-1]) if self.latent_samples is not None else None,
        latent_mean=torch.stack(self.latent_mean[:-1]) if self.latent_mean is not None else None,
        latent_logvar=torch.stack(self.latent_logvar[:-1]) if self.latent_logvar is not None else None)
    _, action_log_probs, _ = policy.evaluate_actions(
        self.prev_state[:-1], latent,
        self.beliefs[:-1] if self.beliefs is not None else None,
        self.tasks[:-1] if self.tasks is not None else None,
        self.actions)
    self.action_log_probs = action_log_probs.detach()
def update_rms(self, args, policy_storage):
    """ Update normalisation parameters for inputs with current data """
    if self.pass_state_to_policy and self.norm_state:
        state = policy_storage.prev_state[:-1]
        self.state_rms.update(state)
    if self.pass_latent_to_policy and self.norm_latent:
        latent = utl.get_latent_for_policy(
            args,
            torch.cat(policy_storage.latent_samples[:-1]),
            torch.cat(policy_storage.latent_mean[:-1]),
            torch.cat(policy_storage.latent_logvar[:-1]))
        self.latent_rms.update(latent)
    if self.pass_belief_to_policy and self.norm_belief:
        self.belief_rms.update(policy_storage.beliefs[:-1])
    if self.pass_task_to_policy and self.norm_task:
        self.task_rms.update(policy_storage.tasks[:-1])
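# The `*_rms` objects above are assumed to be running mean/std trackers of the
# kind commonly used for input normalisation, exposing an `update(x)` method.
# A minimal sketch under that assumption (not the actual class used here):

import torch


class RunningMeanStdSketch:
    """Tracks a running mean and variance via batched (parallel) Welford updates."""

    def __init__(self, shape=()):
        self.mean = torch.zeros(shape)
        self.var = torch.ones(shape)
        self.count = 1e-4  # avoids division by zero before the first update

    def update(self, x):
        # combine the batch statistics with the running statistics
        batch_mean, batch_var = x.mean(dim=0), x.var(dim=0)
        batch_count = x.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count \
            + delta.pow(2) * self.count * batch_count / total
        self.mean, self.var, self.count = new_mean, m2 / total, total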
def update(self,
           policy_storage,
           encoder=None,  # variBAD encoder
           rlloss_through_encoder=False,  # whether or not to backprop the RL loss through the encoder
           compute_vae_loss=None  # function that can compute the VAE loss
           ):

    # -- get action values --
    advantages = policy_storage.returns[:-1] - policy_storage.value_preds[:-1]
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

    # if this is true, we will update the VAE at every PPO update;
    # otherwise, we update it after we update the policy
    if rlloss_through_encoder:
        # recompute embeddings (to build the computation graph)
        utl.recompute_embeddings(policy_storage, encoder, sample=False, update_idx=0,
                                 detach_every=self.args.tbptt_stepsize if hasattr(self.args, 'tbptt_stepsize') else None)

    # update the normalisation parameters of the policy inputs before updating
    self.actor_critic.update_rms(args=self.args, policy_storage=policy_storage)

    # call this to make sure that the action_log_probs are computed
    # (needs to be done right here because of caching when normalising the inputs)
    policy_storage.before_update(self.actor_critic)

    value_loss_epoch = 0
    action_loss_epoch = 0
    dist_entropy_epoch = 0
    loss_epoch = 0
    for e in range(self.ppo_epoch):

        data_generator = policy_storage.feed_forward_generator(advantages, self.num_mini_batch)
        for sample in data_generator:

            state_batch, belief_batch, task_batch, \
                actions_batch, latent_sample_batch, latent_mean_batch, latent_logvar_batch, value_preds_batch, \
                return_batch, old_action_log_probs_batch, adv_targ = sample

            if not rlloss_through_encoder:
                state_batch = state_batch.detach()
                if latent_sample_batch is not None:
                    latent_sample_batch = latent_sample_batch.detach()
                    latent_mean_batch = latent_mean_batch.detach()
                    latent_logvar_batch = latent_logvar_batch.detach()

            latent_batch = utl.get_latent_for_policy(args=self.args,
                                                     latent_sample=latent_sample_batch,
                                                     latent_mean=latent_mean_batch,
                                                     latent_logvar=latent_logvar_batch)

            # reshape to do a single forward pass for all steps
            values, action_log_probs, dist_entropy = \
                self.actor_critic.evaluate_actions(state=state_batch, latent=latent_batch,
                                                   belief=belief_batch, task=task_batch,
                                                   action=actions_batch)

            ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
            surr1 = ratio * adv_targ
            surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
            action_loss = -torch.min(surr1, surr2).mean()

            if self.use_huber_loss and self.use_clipped_value_loss:
                value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
                value_losses = F.smooth_l1_loss(values, return_batch, reduction='none')
                value_losses_clipped = F.smooth_l1_loss(value_pred_clipped, return_batch, reduction='none')
                value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean()
            elif self.use_huber_loss:
                value_loss = F.smooth_l1_loss(values, return_batch)
            elif self.use_clipped_value_loss:
                value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
                value_losses = (values - return_batch).pow(2)
                value_losses_clipped = (value_pred_clipped - return_batch).pow(2)
                value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean()
            else:
                value_loss = 0.5 * (return_batch - values).pow(2).mean()

            # zero out the gradients
            self.optimiser.zero_grad()
            if rlloss_through_encoder:
                self.optimiser_vae.zero_grad()

            # compute the policy loss
            loss = value_loss * self.value_loss_coef + action_loss - dist_entropy * self.entropy_coef

            # add the vae loss
            if rlloss_through_encoder:
                loss += self.args.vae_loss_coeff * compute_vae_loss()

            # compute gradients (will attach to all networks involved in this computation)
            loss.backward()

            # clip gradients
            nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.args.policy_max_grad_norm)
            if rlloss_through_encoder:
                if self.args.encoder_max_grad_norm is not None:
                    nn.utils.clip_grad_norm_(encoder.parameters(), self.args.encoder_max_grad_norm)

            # update
            self.optimiser.step()
            if rlloss_through_encoder:
                self.optimiser_vae.step()

            value_loss_epoch += value_loss.item()
            action_loss_epoch += action_loss.item()
            dist_entropy_epoch += dist_entropy.item()
            loss_epoch += loss.item()

        if rlloss_through_encoder:
            # recompute embeddings (to build the computation graph)
            utl.recompute_embeddings(policy_storage, encoder, sample=False, update_idx=e + 1,
                                     detach_every=self.args.tbptt_stepsize if hasattr(self.args, 'tbptt_stepsize') else None)

    if (not rlloss_through_encoder) and (self.optimiser_vae is not None):
        for _ in range(self.args.num_vae_updates):
            compute_vae_loss(update=True)

    if self.lr_scheduler_policy is not None:
        self.lr_scheduler_policy.step()
    if self.lr_scheduler_encoder is not None:
        self.lr_scheduler_encoder.step()

    num_updates = self.ppo_epoch * self.num_mini_batch
    value_loss_epoch /= num_updates
    action_loss_epoch /= num_updates
    dist_entropy_epoch /= num_updates
    loss_epoch /= num_updates

    return value_loss_epoch, action_loss_epoch, dist_entropy_epoch, loss_epoch
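# For reference, the clipped surrogate objective computed in the update above,
# in isolation. A minimal sketch with hypothetical tensor arguments:

import torch


def ppo_clipped_policy_loss(log_probs, old_log_probs, advantages, clip_param):
    """L_clip = -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)], with r = pi / pi_old."""
    ratio = torch.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    return -torch.min(surr1, surr2).mean()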
def visualise_behaviour(env,
                        args,
                        policy,
                        iter_idx,
                        encoder=None,
                        image_folder=None,
                        return_pos=False,
                        **kwargs,
                        ):

    num_episodes = args.max_rollouts_per_task
    unwrapped_env = env.venv.unwrapped.envs[0].unwrapped

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
        episode_latent_samples = episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    start_state = state.clone()

    # if hasattr(args, 'hidden_size'):
    #     hidden_state = torch.zeros((1, args.hidden_size)).to(device)
    # else:
    #     hidden_state = None

    # keep track of what task we're in and the position of the cheetah
    pos = [[] for _ in range(args.max_rollouts_per_task)]
    start_pos = unwrapped_env.get_body_com("torso")[0].copy()

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []
        pos[episode_idx].append(start_pos)

        if encoder is not None:
            if episode_idx == 0:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            if step_idx == 1:
                episode_prev_obs[episode_idx].append(start_state.clone())
            else:
                episode_prev_obs[episode_idx].append(state.clone())

            # act
            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            _, action = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, deterministic=True)

            (state, belief, task), (rew, rew_normalised), done, info = utl.env_step(env, action, args)
            state = state.reshape((1, -1)).float().to(device)

            # keep track of position
            pos[episode_idx].append(unwrapped_env.get_body_com("torso")[0].copy())

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.reshape(1, -1).float().to(device), state,
                    rew.reshape(1, -1).float().to(device),
                    hidden_state, return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew.clone())
            episode_actions[episode_idx].append(action.reshape(1, -1).clone())
            curr_rollout_rew.append(rew.clone())  # accumulate rewards so episode_returns sums the per-episode return

            if info[0]['done_mdp'] and not done:
                start_state = info[0]['start_state']
                start_state = torch.from_numpy(start_state).reshape((1, -1)).float().to(device)
                start_pos = unwrapped_env.get_body_com("torso")[0].copy()
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]
    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(e) for e in episode_rewards]

    # plot the movement of the half-cheetah
    plt.figure(figsize=(7, 4 * num_episodes))
    min_x = min([min(p) for p in pos])
    max_x = max([max(p) for p in pos])
    span = max_x - min_x
    for i in range(num_episodes):
        plt.subplot(num_episodes, 1, i + 1)
        # (not plotting the last step because this gives weird artefacts)
        plt.plot(pos[i][:-1], range(len(pos[i][:-1])), 'k')
        plt.title('task: {}'.format(task), fontsize=15)
        plt.ylabel('steps (ep {})'.format(i), fontsize=15)
        if i == num_episodes - 1:
            plt.xlabel('position', fontsize=15)
        # else:
        #     plt.xticks([])
        plt.xlim(min_x - 0.05 * span, max_x + 0.05 * span)
        plt.plot([0, 0], [0, 200], 'b--', alpha=0.2)  # dashed vertical reference line at the start position
    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_behaviour'.format(image_folder, iter_idx))
        plt.close()
    else:
        plt.show()

    if not return_pos:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns
    else:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns, pos
def update(self,
           policy_storage,
           encoder=None,  # variBAD encoder
           rlloss_through_encoder=False,  # whether or not to backprop the RL loss through the encoder
           compute_vae_loss=None  # function that can compute the VAE loss
           ):

    # get action values
    advantages = policy_storage.returns[:-1] - policy_storage.value_preds[:-1]

    if rlloss_through_encoder:
        # re-compute the encoding (to build the computation graph from scratch)
        utl.recompute_embeddings(policy_storage, encoder, sample=False, update_idx=0,
                                 detach_every=self.args.tbptt_stepsize if hasattr(self.args, 'tbptt_stepsize') else None)

    # update the normalisation parameters of the policy inputs before updating
    self.actor_critic.update_rms(args=self.args, policy_storage=policy_storage)

    data_generator = policy_storage.feed_forward_generator(advantages, 1)
    for sample in data_generator:

        state_batch, belief_batch, task_batch, \
            actions_batch, latent_sample_batch, latent_mean_batch, latent_logvar_batch, value_preds_batch, \
            return_batch, old_action_log_probs_batch, adv_targ = sample

        if not rlloss_through_encoder:
            state_batch = state_batch.detach()
            if latent_sample_batch is not None:
                latent_sample_batch = latent_sample_batch.detach()
                latent_mean_batch = latent_mean_batch.detach()
                latent_logvar_batch = latent_logvar_batch.detach()

        latent_batch = utl.get_latent_for_policy(args=self.args,
                                                 latent_sample=latent_sample_batch,
                                                 latent_mean=latent_mean_batch,
                                                 latent_logvar=latent_logvar_batch)

        values, action_log_probs, dist_entropy, action_mean, action_logstd = \
            self.actor_critic.evaluate_actions(state=state_batch, latent=latent_batch,
                                               belief=belief_batch, task=task_batch,
                                               action=actions_batch, return_action_mean=True)

        # -- UPDATE --

        # zero out the gradients
        self.optimiser.zero_grad()
        if rlloss_through_encoder:
            self.optimiser_vae.zero_grad()

        # compute the policy loss
        value_loss = (return_batch - values).pow(2).mean()
        action_loss = -(adv_targ.detach() * action_log_probs).mean()

        # (loss = value loss + action loss + entropy loss, weighted)
        loss = value_loss * self.value_loss_coef + action_loss - dist_entropy * self.entropy_coef

        # add the vae loss
        if rlloss_through_encoder:
            loss += self.args.vae_loss_coeff * compute_vae_loss()

        # compute gradients (will attach to all networks involved in this computation)
        loss.backward()
        nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.args.policy_max_grad_norm)
        if encoder is not None and rlloss_through_encoder:
            nn.utils.clip_grad_norm_(encoder.parameters(), self.args.policy_max_grad_norm)

        # update
        self.optimiser.step()
        if rlloss_through_encoder:
            self.optimiser_vae.step()

    if (not rlloss_through_encoder) and (self.optimiser_vae is not None):
        for _ in range(self.args.num_vae_updates):
            compute_vae_loss(update=True)

    if self.lr_scheduler_policy is not None:
        self.lr_scheduler_policy.step()
    if self.lr_scheduler_encoder is not None:
        self.lr_scheduler_encoder.step()

    return value_loss, action_loss, dist_entropy, loss
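# Unlike the multi-epoch, minibatched PPO update above, this A2C-style update
# makes a single pass over the rollout and uses the unclipped policy-gradient
# loss. In isolation (a sketch with hypothetical tensor arguments):


def a2c_policy_loss(log_probs, advantages):
    """Vanilla policy gradient: -E[A * log pi(a|s)], with detached advantages."""
    return -(advantages.detach() * log_probs).mean()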
def visualise_behaviour(self,
                        env,
                        args,
                        policy,
                        iter_idx,
                        encoder=None,
                        image_folder=None,
                        return_pos=False,
                        **kwargs,
                        ):

    num_episodes = args.max_rollouts_per_task

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        episode_latent_samples = episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    start_obs_raw = state.clone()
    task = task.view(-1) if task is not None else None

    # initialise the hidden state (used as initial input if we have a recurrent policy)
    if hasattr(args, 'hidden_size'):
        hidden_state = torch.zeros((1, args.hidden_size)).to(device)
    else:
        hidden_state = None

    # keep track of what task we're in and the position of the agent
    pos = [[] for _ in range(args.max_rollouts_per_task)]
    start_pos = state

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []
        pos[episode_idx].append(start_pos[0])

        if episode_idx == 0:
            if encoder is not None:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            else:
                curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

        if encoder is not None:
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            if step_idx == 1:
                episode_prev_obs[episode_idx].append(start_obs_raw.clone())
            else:
                episode_prev_obs[episode_idx].append(state.clone())

            # act
            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            _, action = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, deterministic=True)

            (state, belief, task), (rew, rew_normalised), done, info = utl.env_step(env, action, args)
            state = state.float().reshape((1, -1)).to(device)
            task = task.view(-1) if task is not None else None

            # keep track of position
            pos[episode_idx].append(state[0])

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.reshape(1, -1).float().to(device), state,
                    rew.reshape(1, -1).float().to(device),
                    hidden_state, return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew.clone())
            episode_actions[episode_idx].append(action.clone())
            curr_rollout_rew.append(rew.clone())  # accumulate rewards so episode_returns sums the per-episode return

            if info[0]['done_mdp'] and not done:
                start_obs_raw = info[0]['start_state']
                start_obs_raw = torch.from_numpy(start_obs_raw).float().reshape((1, -1)).to(device)
                start_pos = start_obs_raw
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]
    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.stack(e) for e in episode_actions]
    episode_rewards = [torch.cat(e) for e in episode_rewards]

    # plot the trajectories in the 2D goal environment
    figsize = (5.5, 4)
    figure, axis = plt.subplots(1, 1, figsize=figsize)
    xlim = (-1.3, 1.3)
    if self.goal_sampler == semi_circle_goal_sampler:
        ylim = (-0.3, 1.3)
    else:
        ylim = (-1.3, 1.3)
    color_map = mpl.colors.ListedColormap(sns.color_palette("husl", num_episodes))

    observations = torch.stack([episode_prev_obs[i] for i in range(num_episodes)]).cpu().numpy()
    curr_task = env.get_task()

    # plot goal
    axis.scatter(*curr_task, marker='x', color='k', s=50)
    # radius where we get reward
    if hasattr(self, 'goal_radius'):
        # use facecolor (not color) so that edgecolor='none' is respected
        circle1 = plt.Circle(curr_task, self.goal_radius, facecolor='c', alpha=0.2, edgecolor='none')
        plt.gca().add_artist(circle1)

    for i in range(num_episodes):
        color = color_map(i)
        path = observations[i]

        # plot (semi-)circle on which the goals lie
        r = 1.0
        if self.goal_sampler == semi_circle_goal_sampler:
            angle = np.linspace(0, np.pi, 100)
        else:
            angle = np.linspace(0, 2 * np.pi, 100)
        goal_range = r * np.array((np.cos(angle), np.sin(angle)))
        plt.plot(goal_range[0], goal_range[1], 'k--', alpha=0.1)

        # plot trajectory
        axis.plot(path[:, 0], path[:, 1], '-', color=color, label=i)
        axis.scatter(*path[0, :2], marker='.', color=color, s=50)

    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.xticks([])
    plt.yticks([])
    plt.legend()
    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_behaviour.png'.format(image_folder, iter_idx), dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

    # plot the reward per step across episodes
    plt_rew = [episode_rewards[i][:episode_lengths[i]] for i in range(len(episode_rewards))]
    plt.plot(torch.cat(plt_rew).view(-1).cpu().numpy())
    plt.xlabel('env step')
    plt.ylabel('reward per step')
    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_rewards.png'.format(image_folder, iter_idx), dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

    if not return_pos:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns
    else:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns, pos
def get_test_rollout(args, env, policy, encoder=None):
    num_episodes = args.max_rollouts_per_task

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
        episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    state = state.reshape((1, -1)).to(device)
    task = task.view(-1) if task is not None else None

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []

        if encoder is not None:
            if episode_idx == 0:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            episode_prev_obs[episode_idx].append(state.clone())

            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            _, action = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, deterministic=True)
            action = action.reshape((1, *action.shape))

            # observe reward and next obs
            (state, belief, task), (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args)
            state = state.reshape((1, -1)).to(device)
            task = task.view(-1) if task is not None else None

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(device), state,
                    rew_raw.reshape((1, 1)).float().to(device),
                    hidden_state, return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
            episode_actions[episode_idx].append(action.clone())
            curr_rollout_rew.append(rew_raw.clone())  # accumulate rewards so episode_returns sums the per-episode return

            if infos[0]['done_mdp']:
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]
    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(r) for r in episode_rewards]

    return episode_latent_means, episode_latent_logvars, \
           episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
           episode_returns
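# Hypothetical usage of `get_test_rollout` (the env/policy/encoder construction
# lives elsewhere in the codebase; the names below are placeholders):
#
#     latent_means, latent_logvars, prev_obs, next_obs, actions, rewards, returns = \
#         get_test_rollout(args, env, policy, encoder=encoder)
#     print('per-episode returns:', returns)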
def visualise_behaviour(env,
                        args,
                        policy,
                        iter_idx,
                        encoder=None,
                        image_folder=None,
                        return_pos=False,
                        **kwargs,
                        ):

    num_episodes = args.max_rollouts_per_task
    unwrapped_env = env.venv.unwrapped.envs[0].unwrapped

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        episode_latent_samples = episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    start_obs_raw = state.clone()
    task = task.view(-1) if task is not None else None

    # initialise the hidden state (used as initial input if we have a recurrent policy)
    if hasattr(args, 'hidden_size'):
        hidden_state = torch.zeros((1, args.hidden_size)).to(device)
    else:
        hidden_state = None

    # keep track of what task we're in and the position of the ant
    pos = [[] for _ in range(args.max_rollouts_per_task)]
    start_pos = unwrapped_env.get_body_com("torso")[:2].copy()

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []
        pos[episode_idx].append(start_pos)

        if episode_idx == 0:
            if encoder is not None:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            else:
                curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

        if encoder is not None:
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            if step_idx == 1:
                episode_prev_obs[episode_idx].append(start_obs_raw.clone())
            else:
                episode_prev_obs[episode_idx].append(state.clone())

            # act
            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            _, action, _ = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, deterministic=True)

            (state, belief, task), (rew, rew_normalised), done, info = utl.env_step(env, action, args)
            state = state.float().reshape((1, -1)).to(device)
            task = task.view(-1) if task is not None else None

            # keep track of position
            pos[episode_idx].append(unwrapped_env.get_body_com("torso")[:2].copy())

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.reshape(1, -1).float().to(device), state,
                    rew.reshape(1, -1).float().to(device),
                    hidden_state, return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew.clone())
            episode_actions[episode_idx].append(action.clone())
            curr_rollout_rew.append(rew.clone())  # accumulate rewards so episode_returns sums the per-episode return

            if info[0]['done_mdp'] and not done:
                start_obs_raw = info[0]['start_state']
                start_obs_raw = torch.from_numpy(start_obs_raw).float().reshape((1, -1)).to(device)
                start_pos = unwrapped_env.get_body_com("torso")[:2].copy()
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]
    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.stack(e) for e in episode_actions]
    episode_rewards = [torch.cat(e) for e in episode_rewards]

    # plot the movement of the ant
    # print(pos)
    plt.figure(figsize=(5, 4 * num_episodes))
    min_dim = -3.5
    max_dim = 3.5
    span = max_dim - min_dim
    for i in range(num_episodes):
        plt.subplot(num_episodes, 1, i + 1)
        x = list(map(lambda p: p[0], pos[i]))
        y = list(map(lambda p: p[1], pos[i]))
        plt.plot(x[0], y[0], 'bo')
        plt.scatter(x, y, 1, 'g')
        curr_task = env.get_task()
        plt.title('task: {}'.format(curr_task), fontsize=15)
        if 'Goal' in args.env_name:
            plt.plot(curr_task[0], curr_task[1], 'rx')
        plt.ylabel('y-position (ep {})'.format(i), fontsize=15)
        if i == num_episodes - 1:
            plt.xlabel('x-position', fontsize=15)
        plt.xlim(min_dim - 0.05 * span, max_dim + 0.05 * span)
        plt.ylim(min_dim - 0.05 * span, max_dim + 0.05 * span)
    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_behaviour'.format(image_folder, iter_idx))
        plt.close()
    else:
        plt.show()

    if not return_pos:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns
    else:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns, pos