import matplotlib.pyplot as plt
import numpy as np
import torch

# assumed import path for the project-local `utl` helpers used below
# (provides sample_gaussian)
from utils import helpers as utl

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def compute_beliefs(env, args, reward_decoder, latent_mean, latent_logvar, goal):
    """Compute the agent's current reward belief for every cell of the gridworld
    by decoding reward predictions from (samples of) the latent task posterior."""
    num_cells = env.observation_space.high[0] + 1
    # unwrap the vectorised environment to access the gridworld helpers
    unwrapped_env = env.venv.unwrapped.envs[0]

    if not args.disable_stochasticity_in_latent:
        # take several samples from the latent distribution
        samples = utl.sample_gaussian(latent_mean.view(-1), latent_logvar.view(-1), 100)
    else:
        samples = torch.cat((latent_mean.view(-1), latent_logvar.view(-1))).unsqueeze(0)

    # compute reward predictions for those
    if reward_decoder.multi_head:
        rew_pred_means = torch.mean(reward_decoder(samples, None), dim=0)  # .reshape((1, -1))
        rew_pred_vars = torch.var(reward_decoder(samples, None), dim=0)  # .reshape((1, -1))
    else:
        tsm = []
        tsv = []
        for st in range(num_cells ** 2):
            task_id = unwrapped_env.id_to_task(torch.tensor([st]))
            curr_state = unwrapped_env.goal_to_onehot_id(task_id).expand((samples.shape[0], 2))
            if unwrapped_env.oracle:
                if isinstance(goal, np.ndarray):
                    goal = torch.from_numpy(goal)
                curr_state = torch.cat((curr_state, goal.repeat(curr_state.shape[0], 1).float()), dim=1)
            tsm.append(torch.mean(reward_decoder(samples, curr_state)))
            tsv.append(torch.var(reward_decoder(samples, curr_state)))
        rew_pred_means = torch.stack(tsm).reshape((1, -1))
        rew_pred_vars = torch.stack(tsv).reshape((1, -1))
    # rew_pred_means = rew_pred_means[-1][0]

    return rew_pred_means, rew_pred_vars
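# Usage sketch for compute_beliefs (hypothetical; the variable names below are
# illustrative and not defined in this module — the latent posterior would come
# from the trajectory encoder):
#
#     latent_mean, latent_logvar = ...  # posterior over the task latent
#     rew_means, rew_vars = compute_beliefs(env, args, reward_decoder,
#                                           latent_mean, latent_logvar, goal)
#     # one prediction per grid cell (num_cells ** 2 entries), e.g. for
#     # rendering the current reward belief as a heatmap:
#     plt.imshow(rew_means.detach().cpu().numpy().reshape(num_cells, num_cells))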
def plot_vae_loss(args, latent_means, latent_logvars, prev_obs, next_obs, actions, rewards, task,
                  image_folder, iter_idx, reward_decoder, state_decoder, task_decoder,
                  compute_task_reconstruction_loss, compute_rew_reconstruction_loss,
                  compute_state_reconstruction_loss, compute_kl_loss):
    """Plot the individual components of the VAE loss (KL term and reward/state/task
    reconstruction errors) against environment steps, one value per step of the
    rollouts (plus the prior)."""
    num_rollouts = len(latent_means)
    num_episode_steps = len(latent_means[0])

    if not args.disable_stochasticity_in_latent:
        num_samples = 10  # how many samples to use to get an average/std ELBO loss
    else:
        num_samples = 1

    latent_means = torch.cat(latent_means)
    latent_logvars = torch.cat(latent_logvars)
    prev_obs = torch.cat(prev_obs).to(device)
    next_obs = torch.cat(next_obs).to(device)
    actions = torch.cat(actions).to(device)
    rewards = torch.cat(rewards).to(device)

    # we will make predictions for each tuple in the trajectory, so expand the targets accordingly
    prev_obs = prev_obs.unsqueeze(0).expand(num_samples, *prev_obs.shape).to(device)
    next_obs = next_obs.unsqueeze(0).expand(num_samples, *next_obs.shape).to(device)
    actions = actions.unsqueeze(0).expand(num_samples, *actions.shape).to(device)
    rewards = rewards.unsqueeze(0).expand(num_samples, *rewards.shape).to(device)

    rew_reconstr_mean = []
    rew_reconstr_std = []
    rew_pred_std = []

    state_reconstr_mean = []
    state_reconstr_std = []
    state_pred_std = []

    task_reconstr_mean = []
    task_reconstr_std = []
    task_pred_std = []

    # compute the sum of ELBO_t's by looping through (trajectory length + prior)
    for i in range(len(latent_means)):

        curr_latent_mean = latent_means[i]
        curr_latent_logvar = latent_logvars[i]

        # compute the reconstruction loss
        if not args.disable_stochasticity_in_latent:
            # take several samples from the latent distribution
            latent_samples = utl.sample_gaussian(curr_latent_mean.view(-1),
                                                 curr_latent_logvar.view(-1),
                                                 num_samples)
        else:
            latent_samples = torch.cat((curr_latent_mean.view(-1),
                                        curr_latent_logvar.view(-1))).unsqueeze(0)

        # expand: each latent sample will be used to make predictions for the entire trajectory
        len_traj = prev_obs.shape[1]

        # compute reconstruction losses
        if task_decoder is not None:
            loss_task, task_pred = compute_task_reconstruction_loss(latent_samples, task,
                                                                    return_predictions=True)

            # average/std across the different samples
            task_reconstr_mean.append(loss_task.mean())
            task_reconstr_std.append(loss_task.std())
            task_pred_std.append(task_pred.std())

        latent_samples = latent_samples.unsqueeze(1).expand(num_samples, len_traj,
                                                            latent_samples.shape[-1])

        if reward_decoder is not None:
            loss_rew, rew_pred = compute_rew_reconstruction_loss(latent_samples, prev_obs,
                                                                 next_obs, actions, rewards,
                                                                 return_predictions=True)
            # sum along length of trajectory
            loss_rew = loss_rew.sum(dim=1)
            rew_pred = rew_pred.sum(dim=1)

            # average/std across the different samples
            rew_reconstr_mean.append(loss_rew.mean())
            rew_reconstr_std.append(loss_rew.std())
            rew_pred_std.append(rew_pred.std())

        if state_decoder is not None:
            loss_state, state_pred = compute_state_reconstruction_loss(latent_samples, prev_obs,
                                                                       next_obs, actions,
                                                                       return_predictions=True)
            # sum along length of trajectory
            loss_state = loss_state.sum(dim=1)
            state_pred = state_pred.sum(dim=1)

            # average/std across the different samples
            state_reconstr_mean.append(loss_state.mean())
            state_reconstr_std.append(loss_state.std())
            state_pred_std.append(state_pred.std())

    # kl term
    vae_kl_term = compute_kl_loss(latent_means, latent_logvars, None)

    # --- plot KL term ---

    x = range(len(vae_kl_term))

    plt.plot(x, vae_kl_term.cpu().detach().numpy(), 'b-')
    vae_kl_term = vae_kl_term.cpu()
    for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
        span = vae_kl_term.max() - vae_kl_term.min()
        plt.plot([tj + 0.5, tj + 0.5],
                 [vae_kl_term.min() - span * 0.05, vae_kl_term.max() + span * 0.05],
                 'k--', alpha=0.5)
    plt.xlabel('env steps', fontsize=15)
    plt.ylabel('KL term', fontsize=15)
    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_kl'.format(image_folder, iter_idx))
        plt.close()
    else:
        plt.show()

    # --- plot rew reconstruction ---

    if reward_decoder is not None:
        rew_reconstr_mean = torch.stack(rew_reconstr_mean).detach().cpu().numpy()
        rew_reconstr_std = torch.stack(rew_reconstr_std).detach().cpu().numpy()
        rew_pred_std = torch.stack(rew_pred_std).detach().cpu().numpy()

        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        p = plt.plot(x, rew_reconstr_mean, 'b-')
        plt.gca().fill_between(x,
                               rew_reconstr_mean - rew_reconstr_std,
                               rew_reconstr_mean + rew_reconstr_std,
                               facecolor=p[0].get_color(), alpha=0.1)
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            min_y = (rew_reconstr_mean - rew_reconstr_std).min()
            max_y = (rew_reconstr_mean + rew_reconstr_std).max()
            span = max_y - min_y
            plt.plot([tj + 0.5, tj + 0.5],
                     [min_y - span * 0.05, max_y + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('reward reconstruction error', fontsize=15)

        plt.subplot(1, 2, 2)
        plt.plot(x, rew_pred_std, 'b-')
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            span = rew_pred_std.max() - rew_pred_std.min()
            plt.plot([tj + 0.5, tj + 0.5],
                     [rew_pred_std.min() - span * 0.05, rew_pred_std.max() + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('std of rew reconstruction', fontsize=15)

        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_rew_reconstruction'.format(image_folder, iter_idx))
            plt.close()
        else:
            plt.show()

    # --- plot state reconstruction ---

    if state_decoder is not None:

        plt.figure(figsize=(12, 5))

        state_reconstr_mean = torch.stack(state_reconstr_mean).detach().cpu().numpy()
        state_reconstr_std = torch.stack(state_reconstr_std).detach().cpu().numpy()
        state_pred_std = torch.stack(state_pred_std).detach().cpu().numpy()

        plt.subplot(1, 2, 1)
        p = plt.plot(x, state_reconstr_mean, 'b-')
        plt.gca().fill_between(x,
                               state_reconstr_mean - state_reconstr_std,
                               state_reconstr_mean + state_reconstr_std,
                               facecolor=p[0].get_color(), alpha=0.1)
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            min_y = (state_reconstr_mean - state_reconstr_std).min()
            max_y = (state_reconstr_mean + state_reconstr_std).max()
            span = max_y - min_y
            plt.plot([tj + 0.5, tj + 0.5],
                     [min_y - span * 0.05, max_y + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('state reconstruction error', fontsize=15)

        plt.subplot(1, 2, 2)
        plt.plot(x, state_pred_std, 'b-')
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            span = state_pred_std.max() - state_pred_std.min()
            plt.plot([tj + 0.5, tj + 0.5],
                     [state_pred_std.min() - span * 0.05, state_pred_std.max() + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('std of state reconstruction', fontsize=15)

        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_state_reconstruction'.format(image_folder, iter_idx))
            plt.close()
        else:
            plt.show()

    # --- plot task reconstruction ---

    if task_decoder is not None:

        plt.figure(figsize=(12, 5))

        task_reconstr_mean = torch.stack(task_reconstr_mean).detach().cpu().numpy()
        task_reconstr_std = torch.stack(task_reconstr_std).detach().cpu().numpy()
        task_pred_std = torch.stack(task_pred_std).detach().cpu().numpy()

        plt.subplot(1, 2, 1)
        p = plt.plot(x, task_reconstr_mean, 'b-')
        plt.gca().fill_between(x,
                               task_reconstr_mean - task_reconstr_std,
                               task_reconstr_mean + task_reconstr_std,
                               facecolor=p[0].get_color(), alpha=0.1)
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            min_y = (task_reconstr_mean - task_reconstr_std).min()
            max_y = (task_reconstr_mean + task_reconstr_std).max()
            span = max_y - min_y
            plt.plot([tj + 0.5, tj + 0.5],
                     [min_y - span * 0.05, max_y + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('task reconstruction error', fontsize=15)

        plt.subplot(1, 2, 2)
        plt.plot(x, task_pred_std, 'b-')
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            span = task_pred_std.max() - task_pred_std.min()
            plt.plot([tj + 0.5, tj + 0.5],
                     [task_pred_std.min() - span * 0.05, task_pred_std.max() + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('std of task reconstruction', fontsize=15)

        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_task_reconstruction'.format(image_folder, iter_idx))
            plt.close()
        else:
            plt.show()
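# Usage sketch for plot_vae_loss (hypothetical; the `vae.*` attribute names are
# illustrative and assume a VAE object exposing its decoders and loss
# functions — substitute whatever the surrounding codebase provides). The
# latent_means/latent_logvars/prev_obs/... arguments are expected to be lists
# with one per-step tensor per rollout, matching the torch.cat calls above:
#
#     plot_vae_loss(args, latent_means, latent_logvars, prev_obs, next_obs,
#                   actions, rewards, task,
#                   image_folder='logs/figures', iter_idx=iter_idx,
#                   reward_decoder=vae.reward_decoder,
#                   state_decoder=vae.state_decoder,
#                   task_decoder=vae.task_decoder,
#                   compute_task_reconstruction_loss=vae.compute_task_reconstruction_loss,
#                   compute_rew_reconstruction_loss=vae.compute_rew_reconstruction_loss,
#                   compute_state_reconstruction_loss=vae.compute_state_reconstruction_loss,
#                   compute_kl_loss=vae.compute_kl_loss)
#
# With image_folder set, figures are saved as '{image_folder}/{iter_idx}_kl',
# '..._rew_reconstruction', and so on; with image_folder=None they are shown
# interactively.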