def transform_mdp_to_bamdp_rollouts(vae, args, obs, actions, rewards, next_obs, terminals):
    '''
    Augment MDP rollouts with the task belief (posterior mean and log-variance)
    so they can be used as BAMDP rollouts.

    :param vae: VAE whose encoder produces the task posterior
    :param args: experiment configuration (uses task_embedding_size, belief_rewards, trajectory_len)
    :param obs: shape (trajectory_len, n_rollouts, dim)
    :param actions: same leading dimensions as obs
    :param rewards: same leading dimensions as obs
    :param next_obs: same leading dimensions as obs
    :param terminals: same leading dimensions as obs
    :return: (augmented_obs, belief_rewards, augmented_next_obs)
    '''
    # observations are augmented with the posterior mean and log-variance,
    # hence the extra 2 * task_embedding_size dimensions
    augmented_obs = ptu.zeros(
        (obs.shape[0], obs.shape[1], obs.shape[2] + 2 * args.task_embedding_size))
    augmented_next_obs = ptu.zeros(
        (obs.shape[0], obs.shape[1], obs.shape[2] + 2 * args.task_embedding_size))
    if args.belief_rewards:
        belief_rewards = ptu.zeros_like(rewards)
    else:
        belief_rewards = None

    with torch.no_grad():
        # prior belief for every rollout in the batch
        _, mean, logvar, hidden_state = vae.encoder.prior(batch_size=obs.shape[1])
    augmented_obs[0, :, :] = torch.cat((obs[0], mean[0], logvar[0]), dim=-1)

    for step in range(args.trajectory_len):
        # update the belief with the latest transition
        _, mean, logvar, hidden_state = utl.update_encoding(
            encoder=vae.encoder,
            obs=next_obs[step].unsqueeze(0),
            action=actions[step].unsqueeze(0),
            reward=rewards[step].unsqueeze(0),
            done=terminals[step].unsqueeze(0),
            hidden_state=hidden_state)

        # augment data with the updated belief
        augmented_next_obs[step, :, :] = torch.cat(
            (next_obs[step], mean, logvar), dim=-1)

        if args.belief_rewards:
            with torch.no_grad():
                belief_rewards[step, :, :] = vae.compute_belief_reward(
                    mean.unsqueeze(dim=0), logvar.unsqueeze(dim=0),
                    obs[step].unsqueeze(dim=0),
                    next_obs[step].unsqueeze(dim=0),
                    actions[step].unsqueeze(dim=0))

    # the augmented obs at step t+1 is the augmented next_obs at step t
    augmented_obs[1:, :, :] = augmented_next_obs[:-1, :, :].clone()

    return augmented_obs, belief_rewards, augmented_next_obs

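# Illustrative sketch (not part of the original code): a minimal, self-contained
# example of the belief augmentation performed above, where each observation is
# concatenated with the posterior mean and log-variance. All shapes and the
# helper name `_example_belief_augmentation` are assumptions for illustration only.
def _example_belief_augmentation():
    import torch
    trajectory_len, n_rollouts, obs_dim, task_embedding_size = 10, 4, 3, 5
    obs = torch.zeros(trajectory_len, n_rollouts, obs_dim)
    mean = torch.zeros(n_rollouts, task_embedding_size)    # posterior mean
    logvar = torch.zeros(n_rollouts, task_embedding_size)  # posterior log-variance
    # the augmented (BAMDP) observation is [s_t, mean, logvar]
    augmented = torch.cat((obs[0], mean, logvar), dim=-1)
    assert augmented.shape == (n_rollouts, obs_dim + 2 * task_embedding_size)
    return augmented
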
def get_augmented_obs(args, obs, posterior_sample=None, task_mu=None, task_std=None):
    # augment the raw observation with the task belief: either a sampled
    # embedding or the posterior mean and standard deviation
    obs_augmented = obs.clone()

    if posterior_sample is None:
        sample_embeddings = False
    else:
        sample_embeddings = args.sample_embeddings

    if not args.condition_policy_on_state:
        obs_augmented = ptu.zeros(0, )

    if sample_embeddings and (posterior_sample is not None):
        obs_augmented = torch.cat((obs_augmented, posterior_sample), dim=1)
    elif (task_mu is not None) and (task_std is not None):
        task_mu = task_mu.reshape((-1, task_mu.shape[-1]))
        task_std = task_std.reshape((-1, task_std.shape[-1]))
        obs_augmented = torch.cat((obs_augmented, task_mu, task_std), dim=-1)

    return obs_augmented

def rsample(self, return_pretanh_value=False):
    # reparameterised sample: scale and shift standard normal noise so that
    # gradients can flow through normal_mean and normal_std
    z = (self.normal_mean + self.normal_std *
         Normal(ptu.zeros(self.normal_mean.size()),
                ptu.ones(self.normal_std.size())).sample())
    if return_pretanh_value:
        return torch.tanh(z), z
    else:
        return torch.tanh(z)

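# Illustrative sketch (not part of the original code): the same reparameterised
# tanh-Gaussian sample expressed with torch.distributions. The function name and
# arguments below are assumptions for illustration only.
def _example_tanh_normal_rsample(mean, std, return_pretanh_value=False):
    import torch
    # rsample() draws eps ~ N(0, 1) and returns mean + std * eps,
    # so the sample stays differentiable w.r.t. mean and std
    z = torch.distributions.Normal(mean, std).rsample()
    if return_pretanh_value:
        return torch.tanh(z), z
    return torch.tanh(z)
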
def predict_rewards(learner, means, logvars):
    # Monte-Carlo estimate of the per-state reward predictions: for every
    # timestep, draw 50 latent samples from the posterior and average the
    # reward decoder's output over them.
    reward_preds = ptu.zeros([means.shape[0], learner.env.num_states])
    for t in range(reward_preds.shape[0]):
        task_samples = learner.vae.encoder._sample_gaussian(
            ptu.FloatTensor(means[t]), ptu.FloatTensor(logvars[t]), num=50)
        reward_preds[t, :] = learner.vae.reward_decoder(
            ptu.FloatTensor(task_samples), None).mean(dim=0).detach()
    return ptu.get_numpy(reward_preds)

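# Illustrative sketch (not part of the original code): Monte-Carlo reward
# prediction by averaging a decoder over latent samples drawn from the posterior
# N(mean, exp(logvar)). The dummy linear decoder and the function name are
# assumptions for illustration only.
def _example_mc_reward_prediction(mean, logvar, num_samples=50):
    import torch
    std = torch.exp(0.5 * logvar)
    # (num_samples, latent_dim) reparameterised samples from the task posterior
    samples = mean + std * torch.randn(num_samples, *mean.shape)
    dummy_reward_decoder = torch.nn.Linear(mean.shape[-1], 1)
    # average the decoded reward over the latent samples
    return dummy_reward_decoder(samples).mean(dim=0)
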
def update_step(vae, obs, actions, rewards, next_obs, args):
    episode_len, num_episodes, _ = obs.shape

    # get time-steps for ELBO computation
    if args.vae_batch_num_elbo_terms is not None:
        elbo_timesteps = np.stack([
            np.random.choice(range(0, args.trajectory_len + 1),
                             args.vae_batch_num_elbo_terms, replace=False)
            for _ in range(num_episodes)
        ])
    else:
        elbo_timesteps = np.repeat(
            np.arange(0, args.trajectory_len + 1).reshape(1, -1),
            num_episodes, axis=0)

    # pass through encoder
    # (output: (max_traj_len+1) x number of rollouts x latent_dim -- includes the prior!)
    _, latent_mean, latent_logvar, _ = vae.encoder(actions=actions,
                                                   states=next_obs,
                                                   rewards=rewards,
                                                   hidden_state=None,
                                                   return_prior=True)

    rew_recon_losses, state_recon_losses, kl_terms = [], [], []

    # for each task (episode) in our batch
    for episode_idx in range(num_episodes):

        # get the embedding values (size: (traj_length+1) x latent_dim; the +1 is for the prior)
        curr_means = latent_mean[:episode_len + 1, episode_idx, :]
        curr_logvars = latent_logvar[:episode_len + 1, episode_idx, :]
        # take one sample for each ELBO term
        curr_samples = vae.encoder._sample_gaussian(curr_means, curr_logvars)

        # select data from current rollout (result is traj_length x dim)
        curr_obs = obs[:, episode_idx, :]
        curr_next_obs = next_obs[:, episode_idx, :]
        curr_actions = actions[:, episode_idx, :]
        curr_rewards = rewards[:, episode_idx, :]

        num_latents = curr_samples.shape[0]  # includes the prior
        num_decodes = curr_obs.shape[0]

        # expand the latent to match the (x, y) pairs of the decoder
        dec_embedding = curr_samples.unsqueeze(0).expand(
            (num_decodes, *curr_samples.shape)).transpose(1, 0)
        # expand the (x, y) pairs to match the latent samples
        dec_obs = curr_obs.unsqueeze(0).expand((num_latents, *curr_obs.shape))
        dec_next_obs = curr_next_obs.unsqueeze(0).expand(
            (num_latents, *curr_next_obs.shape))
        dec_actions = curr_actions.unsqueeze(0).expand(
            (num_latents, *curr_actions.shape))
        dec_rewards = curr_rewards.unsqueeze(0).expand(
            (num_latents, *curr_rewards.shape))

        if args.decode_reward:
            # compute reconstruction loss for this trajectory
            # (for each timestep that was encoded, decode everything and sum it up)
            rrl = vae.compute_rew_reconstruction_loss(dec_embedding, dec_obs,
                                                      dec_next_obs, dec_actions,
                                                      dec_rewards)
            # sum along the trajectory which we decoded (sum in ELBO_t)
            if args.decode_only_past:
                curr_idx = 0
                past_reconstr_sum = []
                for i, idx_timestep in enumerate(elbo_timesteps[episode_idx]):
                    dec_until = idx_timestep
                    if dec_until != 0:
                        past_reconstr_sum.append(
                            rrl[curr_idx:curr_idx + dec_until].sum())
                    curr_idx += dec_until
                rrl = torch.stack(past_reconstr_sum)
            else:
                rrl = rrl.sum(dim=1)
            rew_recon_losses.append(rrl)

        if args.decode_state:
            srl = vae.compute_state_reconstruction_loss(dec_embedding, dec_obs,
                                                        dec_next_obs, dec_actions)
            srl = srl.sum(dim=1)
            state_recon_losses.append(srl)

        if not args.disable_stochasticity_in_latent:
            # compute the KL term for each ELBO term of the current trajectory
            kl = vae.compute_kl_loss(curr_means, curr_logvars,
                                     elbo_timesteps[episode_idx])
            kl_terms.append(kl)

    # sum the ELBO_t terms per task
    if args.decode_reward:
        rew_recon_losses = torch.stack(rew_recon_losses)
        rew_recon_losses = rew_recon_losses.sum(dim=1)
    else:
        rew_recon_losses = ptu.zeros(1)  # 0 -- but with option of .mean()

    if args.decode_state:
        state_recon_losses = torch.stack(state_recon_losses)
        state_recon_losses = state_recon_losses.sum(dim=1)
    else:
        state_recon_losses = ptu.zeros(1)

    if not args.disable_stochasticity_in_latent:
        kl_terms = torch.stack(kl_terms)
        kl_terms = kl_terms.sum(dim=1)
    else:
        kl_terms = ptu.zeros(1)

    # make sure we can compute gradients
    if not args.disable_stochasticity_in_latent:
        assert kl_terms.requires_grad
    if args.decode_reward:
        assert rew_recon_losses.requires_grad
    if args.decode_state:
        assert state_recon_losses.requires_grad

    return rew_recon_losses.mean(), state_recon_losses.mean(), kl_terms.mean()

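# Illustrative sketch (not part of the original code): the unsqueeze/expand/
# transpose trick used in update_step to pair every latent sample with every
# decoded timestep. Shapes and the helper name are assumptions for illustration only.
def _example_latent_expansion():
    import torch
    traj_len, latent_dim, obs_dim = 6, 5, 3
    curr_samples = torch.zeros(traj_len + 1, latent_dim)  # +1 for the prior
    curr_obs = torch.zeros(traj_len, obs_dim)
    num_latents, num_decodes = curr_samples.shape[0], curr_obs.shape[0]
    dec_embedding = curr_samples.unsqueeze(0).expand(
        (num_decodes, *curr_samples.shape)).transpose(1, 0)
    dec_obs = curr_obs.unsqueeze(0).expand((num_latents, *curr_obs.shape))
    # every latent (dim 0) is now paired with every timestep (dim 1)
    assert dec_embedding.shape == (num_latents, num_decodes, latent_dim)
    assert dec_obs.shape == (num_latents, num_decodes, obs_dim)
    return dec_embedding, dec_obs
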
def forward(self, inputs):
    if self.output_size != 0:
        return self.activation_function(self.fc(inputs))
    else:
        # this head is disabled (output_size == 0): return an empty tensor
        return ptu.zeros(0, )