def visualise_behaviour(
        env,
        args,
        policy,
        iter_idx,
        encoder=None,
        image_folder=None,
        return_pos=False,
        **kwargs,
):
    num_episodes = args.max_rollouts_per_task
    unwrapped_env = env.venv.unwrapped.envs[0].unwrapped

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
        episode_latent_samples = episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    start_state = state.clone()

    # if hasattr(args, 'hidden_size'):
    #     hidden_state = torch.zeros((1, args.hidden_size)).to(device)
    # else:
    #     hidden_state = None

    # keep track of what task we're in and the position of the cheetah
    pos = [[] for _ in range(args.max_rollouts_per_task)]
    start_pos = unwrapped_env.get_body_com("torso")[0].copy()

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []
        pos[episode_idx].append(start_pos)

        if encoder is not None:
            if episode_idx == 0:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            if step_idx == 1:
                episode_prev_obs[episode_idx].append(start_state.clone())
            else:
                episode_prev_obs[episode_idx].append(state.clone())

            # act
            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            _, action = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, deterministic=True)

            (state, belief, task), (rew, rew_normalised), done, info = utl.env_step(env, action, args)
            state = state.reshape((1, -1)).float().to(device)

            # keep track of position
            pos[episode_idx].append(unwrapped_env.get_body_com("torso")[0].copy())

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.reshape(1, -1).float().to(device), state, rew.reshape(1, -1).float().to(device),
                    hidden_state, return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew.clone())
            episode_actions[episode_idx].append(action.reshape(1, -1).clone())
            curr_rollout_rew.append(rew.clone())

            if info[0]['done_mdp'] and not done:
                start_state = info[0]['start_state']
                start_state = torch.from_numpy(start_state).reshape((1, -1)).float().to(device)
                start_pos = unwrapped_env.get_body_com("torso")[0].copy()
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(e) for e in episode_rewards]

    # plot the movement of the half-cheetah
    plt.figure(figsize=(7, 4 * num_episodes))
    min_x = min([min(p) for p in pos])
    max_x = max([max(p) for p in pos])
    span = max_x - min_x
    for i in range(num_episodes):
        plt.subplot(num_episodes, 1, i + 1)
        # (not plotting the last step because this gives weird artefacts)
        plt.plot(pos[i][:-1], range(len(pos[i][:-1])), 'k')
        plt.title('task: {}'.format(task), fontsize=15)
        plt.ylabel('steps (ep {})'.format(i), fontsize=15)
        if i == num_episodes - 1:
            plt.xlabel('position', fontsize=15)
        # else:
        #     plt.xticks([])
        plt.xlim(min_x - 0.05 * span, max_x + 0.05 * span)
        plt.plot([0, 0], [200, 200], 'b--', alpha=0.2)

    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_behaviour'.format(image_folder, iter_idx))
        plt.close()
    else:
        plt.show()

    if not return_pos:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns
    else:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns, pos

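# Hypothetical helper (not part of the code above): when the half-cheetah
# `visualise_behaviour` is called with `return_pos=True`, it also returns the
# raw per-episode `pos` lists. This sketch pads those lists to a common length
# and stacks them into a single numpy array, which can be convenient for
# offline analysis or custom plotting. The name and behaviour are assumptions,
# not part of the original codebase.
def positions_to_array(pos):
    import numpy as np  # likely already imported at module level
    max_len = max(len(p) for p in pos)
    # repeat the last recorded position so that shorter episodes line up
    padded = [list(p) + [p[-1]] * (max_len - len(p)) for p in pos]
    return np.array(padded)  # shape: [num_episodes, max_len]
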
def visualise_behaviour(
        self,
        env,
        args,
        policy,
        iter_idx,
        encoder=None,
        image_folder=None,
        return_pos=False,
        **kwargs,
):
    num_episodes = args.max_rollouts_per_task

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        episode_latent_samples = episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    start_obs_raw = state.clone()
    task = task.view(-1) if task is not None else None

    # initialise the hidden state (used as initial input to policy if we have a recurrent policy)
    if hasattr(args, 'hidden_size'):
        hidden_state = torch.zeros((1, args.hidden_size)).to(device)
    else:
        hidden_state = None

    # keep track of what task we're in and the position of the agent
    pos = [[] for _ in range(args.max_rollouts_per_task)]
    start_pos = state

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []
        pos[episode_idx].append(start_pos[0])

        if episode_idx == 0:
            if encoder is not None:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            else:
                curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

        if encoder is not None:
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            if step_idx == 1:
                episode_prev_obs[episode_idx].append(start_obs_raw.clone())
            else:
                episode_prev_obs[episode_idx].append(state.clone())

            # act
            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            _, action = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, deterministic=True)

            (state, belief, task), (rew, rew_normalised), done, info = utl.env_step(env, action, args)
            state = state.float().reshape((1, -1)).to(device)
            task = task.view(-1) if task is not None else None

            # keep track of position
            pos[episode_idx].append(state[0])

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.reshape(1, -1).float().to(device), state, rew.reshape(1, -1).float().to(device),
                    hidden_state, return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew.clone())
            episode_actions[episode_idx].append(action.clone())
            curr_rollout_rew.append(rew.clone())

            if info[0]['done_mdp'] and not done:
                start_obs_raw = info[0]['start_state']
                start_obs_raw = torch.from_numpy(start_obs_raw).float().reshape((1, -1)).to(device)
                start_pos = start_obs_raw
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.stack(e) for e in episode_actions]
    episode_rewards = [torch.cat(e) for e in episode_rewards]

    # plot the trajectories in the 2D goal space
    figsize = (5.5, 4)
    figure, axis = plt.subplots(1, 1, figsize=figsize)
    xlim = (-1.3, 1.3)
    if self.goal_sampler == semi_circle_goal_sampler:
        ylim = (-0.3, 1.3)
    else:
        ylim = (-1.3, 1.3)
    color_map = mpl.colors.ListedColormap(sns.color_palette("husl", num_episodes))

    observations = torch.stack([episode_prev_obs[i] for i in range(num_episodes)]).cpu().numpy()
    curr_task = env.get_task()

    # plot goal
    axis.scatter(*curr_task, marker='x', color='k', s=50)
    # radius where we get reward
    if hasattr(self, 'goal_radius'):
        circle1 = plt.Circle(curr_task, self.goal_radius, color='c', alpha=0.2, edgecolor='none')
        plt.gca().add_artist(circle1)

    for i in range(num_episodes):
        color = color_map(i)
        path = observations[i]

        # plot (semi-)circle on which the goals lie
        r = 1.0
        if self.goal_sampler == semi_circle_goal_sampler:
            angle = np.linspace(0, np.pi, 100)
        else:
            angle = np.linspace(0, 2 * np.pi, 100)
        goal_range = r * np.array((np.cos(angle), np.sin(angle)))
        plt.plot(goal_range[0], goal_range[1], 'k--', alpha=0.1)

        # plot trajectory
        axis.plot(path[:, 0], path[:, 1], '-', color=color, label=i)
        axis.scatter(*path[0, :2], marker='.', color=color, s=50)

    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.xticks([])
    plt.yticks([])
    plt.legend()
    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_behaviour.png'.format(image_folder, iter_idx), dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

    # plot the reward obtained at each environment step
    plt_rew = [episode_rewards[i][:episode_lengths[i]] for i in range(len(episode_rewards))]
    plt.plot(torch.cat(plt_rew).view(-1).cpu().numpy())
    plt.xlabel('env step')
    plt.ylabel('reward per step')
    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_rewards.png'.format(image_folder, iter_idx), dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

    if not return_pos:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns
    else:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns, pos

def train(self):
    """ Main training loop """
    start_time = time.time()

    # reset environments
    state, belief, task = utl.reset_env(self.envs, self.args)

    # insert initial observation / embeddings to rollout storage
    self.policy_storage.prev_state[0].copy_(state)

    # log once before training
    with torch.no_grad():
        self.log(None, None, start_time)

    for self.iter_idx in range(self.num_updates):

        # rollout policies for a few steps
        for step in range(self.args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob = utl.select_action(
                    args=self.args,
                    policy=self.policy,
                    state=state,
                    belief=belief,
                    task=task,
                    deterministic=False)

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(self.envs, action, self.args)

            # create mask for episode ends
            masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device)
            # bad_mask is true if episode ended because time limit was reached
            bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info.keys() else [1.0]
                                           for info in infos]).to(device)

            # reset environments that are done
            done_indices = np.argwhere(done.flatten()).flatten()
            if len(done_indices) > 0:
                state, belief, task = utl.reset_env(self.envs, self.args,
                                                    indices=done_indices, state=state)

            # add experience to policy buffer
            self.policy_storage.insert(
                state=state,
                belief=belief,
                task=task,
                actions=action,
                action_log_probs=action_log_prob,
                rewards_raw=rew_raw,
                rewards_normalised=rew_normalised,
                value_preds=value,
                masks=masks_done,
                bad_masks=bad_masks,
                done=torch.from_numpy(np.array(done, dtype=float)).unsqueeze(1),
            )

            self.frames += self.args.num_processes

        # --- UPDATE ---

        train_stats = self.update(state=state, belief=belief, task=task)

        # log
        run_stats = [action, action_log_prob, value]
        if train_stats is not None:
            with torch.no_grad():
                self.log(run_stats, train_stats, start_time)

        # clean up after update
        self.policy_storage.after_update()

def get_test_rollout(args, env, policy, encoder=None):
    num_episodes = args.max_rollouts_per_task

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
        episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    state = state.reshape((1, -1)).to(device)
    task = task.view(-1) if task is not None else None

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []

        if encoder is not None:
            if episode_idx == 0:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            episode_prev_obs[episode_idx].append(state.clone())

            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            _, action = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, deterministic=True)
            action = action.reshape((1, *action.shape))

            # observe reward and next obs
            (state, belief, task), (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args)
            state = state.reshape((1, -1)).to(device)
            task = task.view(-1) if task is not None else None

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(device),
                    state,
                    rew_raw.reshape((1, 1)).float().to(device),
                    hidden_state,
                    return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
            episode_actions[episode_idx].append(action.clone())
            curr_rollout_rew.append(rew_raw.clone())

            if infos[0]['done_mdp']:
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(r) for r in episode_rewards]

    return episode_latent_means, episode_latent_logvars, \
           episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
           episode_returns

def evaluate(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             encoder=None,
             num_episodes=None,
             ):
    env_name = args.env_name
    if hasattr(args, 'test_env_name'):
        env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = args.num_processes

    # --- set up the things we want to log ---

    # for each process, we log the returns during the first, second, ... episode
    # (such that we have a minimum of [num_episodes]; the last column is for
    #  any overflow and will be discarded at the end, because we need to wait until
    #  all processes have at least [num_episodes] many episodes)
    returns_per_episode = torch.zeros((num_processes, num_episodes + 1)).to(device)

    # --- initialise environments and latents ---

    envs = make_vec_envs(env_name,
                         seed=args.seed * 42 + iter_idx,
                         num_processes=num_processes,
                         gamma=args.policy_gamma,
                         device=device,
                         rank_offset=num_processes + 1,  # to use diff tmp folders than main processes
                         episodes_per_task=num_episodes,
                         normalise_rew=args.norm_rew_for_policy,
                         ret_rms=ret_rms,
                         tasks=tasks,
                         add_done_info=args.max_rollouts_per_task > 1,
                         )
    num_steps = envs._max_episode_steps

    # reset environments
    state, belief, task = utl.reset_env(envs, args)

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(device)

    if encoder is not None:
        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(num_processes)
    else:
        latent_sample = latent_mean = latent_logvar = hidden_state = None

    for episode_idx in range(num_episodes):

        for step_idx in range(num_steps):

            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              deterministic=True)

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(envs, action, args)
            done_mdp = [info['done_mdp'] for info in infos]

            if encoder is not None:
                # update the hidden state
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=encoder,
                    next_obs=state,
                    action=action,
                    reward=rew_raw,
                    done=None,
                    hidden_state=hidden_state)

            # add rewards
            returns_per_episode[range(num_processes), task_count] += rew_raw.view(-1)

            for i in np.argwhere(done_mdp).flatten():
                # count task up, but cap at num_episodes + 1
                task_count[i] = min(task_count[i] + 1, num_episodes)  # zero-indexed, so no +1
            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs, args, indices=done_indices, state=state)

    envs.close()

    return returns_per_episode[:, :num_episodes]

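# Hypothetical helper (not part of the code above): `evaluate` returns a
# [num_processes x num_episodes] tensor of per-episode returns. This sketch
# reduces that tensor to the mean and standard deviation of the return in each
# episode across the parallel processes, e.g. for logging. The function name
# is an assumption, not part of the original codebase.
def summarise_returns(returns_per_episode):
    mean_per_episode = returns_per_episode.mean(dim=0)  # shape: [num_episodes]
    std_per_episode = returns_per_episode.std(dim=0)    # shape: [num_episodes]
    return mean_per_episode, std_per_episode
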
def visualise_behaviour(env,
                        args,
                        policy,
                        iter_idx,
                        encoder=None,
                        reward_decoder=None,
                        image_folder=None,
                        **kwargs,
                        ):
    """
    Visualises the behaviour of the policy, together with the latent state and belief.
    The environment passed to this method should be a SubProcVec or DummyVecEnv, not the raw env!
    """

    num_episodes = args.max_rollouts_per_task
    unwrapped_env = env.venv.unwrapped.envs[0]

    # --- initialise things we want to keep track of ---

    episode_all_obs = [[] for _ in range(num_episodes)]
    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []
    episode_goals = []
    if args.pass_belief_to_policy and (encoder is None):
        episode_beliefs = [[] for _ in range(num_episodes)]
    else:
        episode_beliefs = None

    if encoder is not None:
        # keep track of latent spaces
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        episode_latent_samples = episode_latent_means = episode_latent_logvars = None

    curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

    # --- roll out policy ---

    env.reset_task()
    [state, belief, task] = utl.reset_env(env, args)
    start_obs = state.clone()

    for episode_idx in range(args.max_rollouts_per_task):

        curr_goal = env.get_task()
        curr_rollout_rew = []
        curr_rollout_goal = []

        if encoder is not None:

            if episode_idx == 0:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)

            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        episode_all_obs[episode_idx].append(start_obs.clone())
        if args.pass_belief_to_policy and (encoder is None):
            episode_beliefs[episode_idx].append(belief)

        for step_idx in range(1, env._max_episode_steps + 1):

            if step_idx == 1:
                episode_prev_obs[episode_idx].append(start_obs.clone())
            else:
                episode_prev_obs[episode_idx].append(state.clone())

            # act
            _, action, _ = utl.select_action(
                args=args,
                policy=policy,
                state=state.view(-1),
                belief=belief,
                task=task,
                deterministic=True,
                latent_sample=curr_latent_sample.view(-1) if (curr_latent_sample is not None) else None,
                latent_mean=curr_latent_mean.view(-1) if (curr_latent_mean is not None) else None,
                latent_logvar=curr_latent_logvar.view(-1) if (curr_latent_logvar is not None) else None,
            )

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args)

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(device),
                    state,
                    rew_raw.reshape((1, 1)).float().to(device),
                    hidden_state,
                    return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_all_obs[episode_idx].append(state.clone())
            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
            episode_actions[episode_idx].append(action.clone())

            curr_rollout_rew.append(rew_raw.clone())
            curr_rollout_goal.append(env.get_task().copy())

            if args.pass_belief_to_policy and (encoder is None):
                episode_beliefs[episode_idx].append(belief)

            if infos[0]['done_mdp'] and not done:
                start_obs = infos[0]['start_state']
                start_obs = torch.from_numpy(start_obs).float().reshape((1, -1)).to(device)
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)
        episode_goals.append(curr_goal)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(e) for e in episode_rewards]

    # plot behaviour & visualise belief in env

    rew_pred_means, rew_pred_vars = plot_bb(env, args, episode_all_obs, episode_goals, reward_decoder,
                                            episode_latent_means, episode_latent_logvars,
                                            image_folder, iter_idx, episode_beliefs)

    if reward_decoder:
        plot_rew_reconstruction(env, rew_pred_means, rew_pred_vars, image_folder, iter_idx)

    return episode_latent_means, episode_latent_logvars, \
           episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
           episode_returns

def visualise_behaviour(
        env,
        args,
        policy,
        iter_idx,
        encoder=None,
        image_folder=None,
        return_pos=False,
        **kwargs,
):
    num_episodes = args.max_rollouts_per_task
    unwrapped_env = env.venv.unwrapped.envs[0].unwrapped

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        episode_latent_samples = episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    start_obs_raw = state.clone()
    task = task.view(-1) if task is not None else None

    # initialise the hidden state (used as initial input to policy if we have a recurrent policy)
    if hasattr(args, 'hidden_size'):
        hidden_state = torch.zeros((1, args.hidden_size)).to(device)
    else:
        hidden_state = None

    # keep track of what task we're in and the position of the ant
    pos = [[] for _ in range(args.max_rollouts_per_task)]
    start_pos = unwrapped_env.get_body_com("torso")[:2].copy()

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []
        pos[episode_idx].append(start_pos)

        if episode_idx == 0:
            if encoder is not None:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            else:
                curr_latent_sample = curr_latent_mean = curr_latent_logvar = None

        if encoder is not None:
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            if step_idx == 1:
                episode_prev_obs[episode_idx].append(start_obs_raw.clone())
            else:
                episode_prev_obs[episode_idx].append(state.clone())

            # act
            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            _, action, _ = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task,
                                      deterministic=True)

            (state, belief, task), (rew, rew_normalised), done, info = utl.env_step(env, action, args)
            state = state.float().reshape((1, -1)).to(device)
            task = task.view(-1) if task is not None else None

            # keep track of position
            pos[episode_idx].append(unwrapped_env.get_body_com("torso")[:2].copy())

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.reshape(1, -1).float().to(device), state, rew.reshape(1, -1).float().to(device),
                    hidden_state, return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew.clone())
            episode_actions[episode_idx].append(action.clone())
            curr_rollout_rew.append(rew.clone())

            if info[0]['done_mdp'] and not done:
                start_obs_raw = info[0]['start_state']
                start_obs_raw = torch.from_numpy(start_obs_raw).float().reshape((1, -1)).to(device)
                start_pos = unwrapped_env.get_body_com("torso")[:2].copy()
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.stack(e) for e in episode_actions]
    episode_rewards = [torch.cat(e) for e in episode_rewards]

    # plot the movement of the ant
    # print(pos)
    plt.figure(figsize=(5, 4 * num_episodes))
    min_dim = -3.5
    max_dim = 3.5
    span = max_dim - min_dim
    for i in range(num_episodes):
        plt.subplot(num_episodes, 1, i + 1)

        x = list(map(lambda p: p[0], pos[i]))
        y = list(map(lambda p: p[1], pos[i]))
        plt.plot(x[0], y[0], 'bo')

        plt.scatter(x, y, 1, 'g')

        curr_task = env.get_task()
        plt.title('task: {}'.format(curr_task), fontsize=15)
        if 'Goal' in args.env_name:
            plt.plot(curr_task[0], curr_task[1], 'rx')

        plt.ylabel('y-position (ep {})'.format(i), fontsize=15)

        if i == num_episodes - 1:
            plt.xlabel('x-position', fontsize=15)
            plt.ylabel('y-position (ep {})'.format(i), fontsize=15)
        plt.xlim(min_dim - 0.05 * span, max_dim + 0.05 * span)
        plt.ylim(min_dim - 0.05 * span, max_dim + 0.05 * span)

    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_behaviour'.format(image_folder, iter_idx))
        plt.close()
    else:
        plt.show()

    if not return_pos:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns
    else:
        return episode_latent_means, episode_latent_logvars, \
               episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
               episode_returns, pos

def train(self):
    """ Main Meta-Training loop """
    start_time = time.time()

    # reset environments
    prev_state, belief, task = utl.reset_env(self.envs, self.args)

    # insert initial observation / embeddings to rollout storage
    self.policy_storage.prev_state[0].copy_(prev_state)

    # log once before training
    with torch.no_grad():
        self.log(None, None, start_time)

    for self.iter_idx in range(self.num_updates):

        # First, re-compute the hidden states given the current rollouts (since the VAE might've changed)
        with torch.no_grad():
            latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory()

        # add this initial hidden state to the policy storage
        assert len(self.policy_storage.latent_mean) == 0  # make sure we emptied buffers
        self.policy_storage.hidden_states[0].copy_(hidden_state)
        self.policy_storage.latent_samples.append(latent_sample.clone())
        self.policy_storage.latent_mean.append(latent_mean.clone())
        self.policy_storage.latent_logvar.append(latent_logvar.clone())

        # rollout policies for a few steps
        for step in range(self.args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action = utl.select_action(
                    args=self.args,
                    policy=self.policy,
                    state=prev_state,
                    belief=belief,
                    task=task,
                    deterministic=False,
                    latent_sample=latent_sample,
                    latent_mean=latent_mean,
                    latent_logvar=latent_logvar,
                )

            # take step in the environment
            [next_state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(self.envs, action, self.args)

            done = torch.from_numpy(np.array(done, dtype=int)).to(device).float().view((-1, 1))
            # create mask for episode ends
            masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device)
            # bad_mask is true if episode ended because time limit was reached
            bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info.keys() else [1.0]
                                           for info in infos]).to(device)

            with torch.no_grad():
                # compute next embedding (for next loop and/or value prediction bootstrap)
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=self.vae.encoder,
                    next_obs=next_state,
                    action=action,
                    reward=rew_raw,
                    done=done,
                    hidden_state=hidden_state)

            # before resetting, update the embedding and add to vae buffer
            # (last state might include useful task info)
            if not (self.args.disable_decoder and self.args.disable_kl_term):
                self.vae.rollout_storage.insert(prev_state.clone(),
                                                action.detach().clone(),
                                                next_state.clone(),
                                                rew_raw.clone(),
                                                done.clone(),
                                                task.clone() if task is not None else None)

            # add the obs before reset to the policy storage
            self.policy_storage.next_state[step] = next_state.clone()

            # reset environments that are done
            done_indices = np.argwhere(done.cpu().flatten()).flatten()
            if len(done_indices) > 0:
                next_state, belief, task = utl.reset_env(self.envs, self.args,
                                                         indices=done_indices, state=next_state)

            # TODO: deal with resampling for posterior sampling algorithm
            #     latent_sample = latent_sample
            #     latent_sample[i] = latent_sample[i]

            # add experience to policy buffer
            self.policy_storage.insert(
                state=next_state,
                belief=belief,
                task=task,
                actions=action,
                rewards_raw=rew_raw,
                rewards_normalised=rew_normalised,
                value_preds=value,
                masks=masks_done,
                bad_masks=bad_masks,
                done=done,
                hidden_states=hidden_state.squeeze(0),
                latent_sample=latent_sample,
                latent_mean=latent_mean,
                latent_logvar=latent_logvar,
            )

            prev_state = next_state

            self.frames += self.args.num_processes

        # --- UPDATE ---

        if self.args.precollect_len <= self.frames:

            # check if we are pre-training the VAE
            if self.args.pretrain_len > self.iter_idx:
                for p in range(self.args.num_vae_updates_per_pretrain):
                    self.vae.compute_vae_loss(update=True,
                                              pretrain_index=self.iter_idx * self.args.num_vae_updates_per_pretrain + p)
            # otherwise do the normal update (policy + vae)
            else:
                train_stats = self.update(state=prev_state,
                                          belief=belief,
                                          task=task,
                                          latent_sample=latent_sample,
                                          latent_mean=latent_mean,
                                          latent_logvar=latent_logvar)

                # log
                run_stats = [action, self.policy_storage.action_log_probs, value]
                with torch.no_grad():
                    self.log(run_stats, train_stats, start_time)

        # clean up after update
        self.policy_storage.after_update()

    self.envs.close()