def min_aggregate(self, goal_latents, goal_latents_recon, goal_image,
                  pred_latents, pred_latents_recon, pred_image):
    """Score candidate actions by their best match against any goal latent.

    For every goal latent ``i``, ``self.get_single_costs`` is evaluated and
    its result is minimised over the last axis. This gives, per action, the
    cost of the best-matching predicted latent and that latent's index. An
    action's final cost is then the minimum over all goal latents. Actions
    are returned sorted by ascending cost.

    If ``self.plot_actions`` is truthy, a grid for the ``self.plot_actions``
    best actions is written to
    ``'{self.logging_directory}/mpc_pred_{self.image_suffix}.png'``. The grid
    has one column per shown action, preceded by a goal column. Its rows are:

    * predicted images;
    * one row per goal latent with the matched predicted-latent recon;
    * the recon that achieved each action's overall minimum.

    Args:
        goal_latents: indexed along dim 0, one entry per goal latent.
        goal_latents_recon: goal latent reconstructions; ``shape[0]`` is the
            number of goal latents.
        goal_image: concatenated along dim 0 with ``goal_latents_recon``;
            presumably ``(1, 3, 64, 64)`` -- confirm with caller.
        pred_latents: forwarded to ``get_single_costs``.
        pred_latents_recon: indexed as ``[action, latent_slot]``; presumably
            ``(n_actions, K, 3, 64, 64)`` -- confirm.
        pred_image: indexed by action along dim 0.

    Returns:
        Tuple of numpy arrays ``(sorted_costs, best_action_idxs,
        min_goal_latent_idx)``:

        * ``sorted_costs``: per-action costs in ascending order.
        * ``best_action_idxs``: the action indices producing that order.
        * ``min_goal_latent_idx``: for each action, in the ORIGINAL unsorted
          action order, the goal latent that achieved its minimum.
    """
    n_goal_latents = goal_latents_recon.shape[0]

    # For each goal latent, the best-matching predicted latent per action.
    per_goal_costs = []
    per_goal_latent_idxs = []
    for i in range(n_goal_latents):
        # NOTE(review): the reduction is over the last axis and its argmin
        # later indexes dim 1 of pred_latents_recon, so single_costs appears
        # to be (n_actions, K) -- confirm against get_single_costs.
        single_costs = self.get_single_costs(goal_latents[i], goal_latents_recon[i],
                                             pred_latents, pred_latents_recon)
        min_costs, latent_idx = single_costs.min(-1)  # (n_actions,)
        per_goal_costs.append(min_costs)
        per_goal_latent_idxs.append(latent_idx)
    costs = torch.stack(per_goal_costs)  # (n_goal_latents, n_actions)
    latent_idxs = torch.stack(per_goal_latent_idxs)  # (n_goal_latents, n_actions)

    # An action's cost is its best match over all goal latents.
    min_costs, min_goal_latent_idx = costs.min(0)  # (n_actions,)
    sorted_costs, best_action_idxs = min_costs.sort()  # (n_actions,)

    if self.plot_actions:
        plot_size = self.plot_actions
        # Only the best `plot_size` actions are drawn, so slice before
        # gathering images instead of gathering everything and discarding it.
        shown = best_action_idxs[:plot_size]  # (P,)
        shown_costs = sorted_costs[:plot_size]  # (P,)

        sorted_pred_images = pred_image[shown]  # (P, 3, 64, 64)
        # Row i holds, for each shown action, the predicted-latent recon that
        # best matched goal latent i. Advanced indexing broadcasts the (1, P)
        # action index against the (n_goal_latents, P) slot index.
        corresponding_pred_latent_recons = pred_latents_recon[
            shown.unsqueeze(0), latent_idxs[:, shown]]  # (n_goal_latents, P, 3, 64, 64)
        corresponding_costs = costs[:, shown]  # (n_goal_latents, P)
        # Recon of the predicted latent that achieved each shown action's overall min.
        min_corresponding_latent_recon = pred_latents_recon[
            shown, latent_idxs[min_goal_latent_idx[shown], shown]]  # (P, 3, 64, 64)

        full_plot = torch.cat([sorted_pred_images.unsqueeze(0),
                               corresponding_pred_latent_recons,
                               min_corresponding_latent_recon.unsqueeze(0)],
                              0)  # (n_goal_latents+2, P, 3, 64, 64)

        # Prepend a column with the goal image and goal latent reconstructions.
        goal_column = torch.cat([goal_image, goal_latents_recon, goal_image],
                                dim=0).unsqueeze(1)  # (n_goal_latents+2, 1, 3, 64, 64)
        full_plot = torch.cat([goal_column, full_plot], dim=1)  # (n_goal_latents+2, P+1, ...)

        # Captions: column 0 (goal column) stays 0.
        caption = np.zeros(full_plot.shape[:2])
        shown_costs_np = ptu.get_numpy(shown_costs)
        caption[0, 1:] = shown_costs_np
        caption[1:1 + n_goal_latents, 1:] = ptu.get_numpy(corresponding_costs)
        # The last row shows the min-cost match. Its cost equals the action's
        # sorted cost by construction of costs.min(0), so label it instead of
        # leaving it at 0.
        caption[-1, 1:] = shown_costs_np

        plot_multi_image(ptu.get_numpy(full_plot),
                         '{}/mpc_pred_{}.png'.format(self.logging_directory, self.image_suffix),
                         caption=caption)

    return ptu.get_numpy(sorted_costs), ptu.get_numpy(best_action_idxs), ptu.get_numpy(min_goal_latent_idx)
def plot_action_errors(self, env, actions, pred_recons, file_name, n_rows=5):
    """Plot predicted reconstructions in a grid captioned with action errors.

    ``env.get_action_error`` is called on the actions converted to numpy.
    The errors and ``pred_recons`` are both laid out row-major into an
    ``n_rows`` x ``B // n_rows`` grid. The result is written with
    ``plot_multi_image`` to ``'{log_dir}/{file_name}.png'``.

    Args:
        env: object exposing ``get_action_error(np.ndarray)``; assumed to
            return one error per action, i.e. shape ``(B,)`` -- confirm.
        actions: tensor of actions, converted via ``ptu.get_numpy``.
        pred_recons: tensor whose dim 0 is the batch ``B``; presumably
            ``(B, 3, D, D)``.
        file_name: output file stem, without extension.
        n_rows: number of grid rows (default 5, the previously hard-coded
            value). ``B`` must be divisible by it, otherwise the reshape
            raises.
    """
    errors = env.get_action_error(ptu.get_numpy(actions))  # (B,) np
    # reshape (not view) also accepts non-contiguous tensors; it yields the
    # same layout whenever view would have succeeded.
    full_plot = pred_recons.reshape(
        (n_rows, -1) + tuple(pred_recons.shape[1:]))  # (n_rows, B//n_rows, 3, D, D)
    caption = np.reshape(errors, (n_rows, -1))  # (n_rows, B//n_rows) np
    # min_aggregate writes under `self.logging_directory`, while this method
    # historically read `self.logging_dir`. Prefer the historical attribute
    # and fall back so both plotting methods work whichever one is set.
    log_dir = self.logging_dir if hasattr(self, 'logging_dir') else self.logging_directory
    plot_multi_image(ptu.get_numpy(full_plot), '{}/{}.png'.format(log_dir, file_name), caption=caption)