Example #1
def transform_mdp_to_bamdp_rollouts(vae, args, obs, actions, rewards, next_obs,
                                    terminals):
    '''
    Augment raw MDP rollouts with the current task belief (posterior mean and
    log-variance) so they can be treated as BAMDP transitions.

    :param vae: VAE whose encoder maintains the task posterior
    :param args: needs task_embedding_size, belief_rewards and trajectory_len
    :param obs: shape (trajectory_len, n_rollouts, dim)
    :param actions: same leading dimensions as obs
    :param rewards: same leading dimensions as obs
    :param next_obs: same leading dimensions as obs
    :param terminals: same leading dimensions as obs
    :return: augmented_obs, belief_rewards (None unless args.belief_rewards), augmented_next_obs
    '''

    augmented_obs = ptu.zeros((obs.shape[0], obs.shape[1],
                               obs.shape[2] + 2 * args.task_embedding_size))
    augmented_next_obs = ptu.zeros(
        (obs.shape[0], obs.shape[1],
         obs.shape[2] + 2 * args.task_embedding_size))
    if args.belief_rewards:
        belief_rewards = ptu.zeros_like(rewards)
    else:
        belief_rewards = None

    with torch.no_grad():
        _, mean, logvar, hidden_state = vae.encoder.prior(
            batch_size=obs.shape[1])
        augmented_obs[0, :, :] = torch.cat((obs[0], mean[0], logvar[0]),
                                           dim=-1)
    for step in range(args.trajectory_len):
        # update encoding
        _, mean, logvar, hidden_state = utl.update_encoding(
            encoder=vae.encoder,
            obs=next_obs[step].unsqueeze(0),
            action=actions[step].unsqueeze(0),
            reward=rewards[step].unsqueeze(0),
            done=terminals[step].unsqueeze(0),
            hidden_state=hidden_state)

        # augment data
        augmented_next_obs[step, :, :] = torch.cat(
            (next_obs[step], mean, logvar), dim=-1)
        if args.belief_rewards:
            with torch.no_grad():
                belief_rewards[step, :, :] = vae.compute_belief_reward(
                    mean.unsqueeze(dim=0), logvar.unsqueeze(dim=0),
                    obs[step].unsqueeze(dim=0),
                    next_obs[step].unsqueeze(dim=0),
                    actions[step].unsqueeze(dim=0))

    augmented_obs[1:, :, :] = augmented_next_obs[:-1, :, :].clone()

    return augmented_obs, belief_rewards, augmented_next_obs
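
A minimal, self-contained sketch of the shape bookkeeping in the augmentation step, using plain torch with toy sizes; here ptu.zeros is assumed to be a device-aware wrapper around torch.zeros, and mean/logvar are stand-ins for the encoder's posterior outputs.

import torch

# toy sizes: trajectory_len=4, n_rollouts=2, obs_dim=3, task_embedding_size=5
T, B, obs_dim, latent_dim = 4, 2, 3, 5
obs = torch.randn(T, B, obs_dim)
mean = torch.zeros(B, latent_dim)     # stand-in for the posterior mean
logvar = torch.zeros(B, latent_dim)   # stand-in for the posterior log-variance

# each augmented observation carries obs_dim + 2 * task_embedding_size features
augmented = torch.cat((obs[0], mean, logvar), dim=-1)
assert augmented.shape == (B, obs_dim + 2 * latent_dim)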
Example #2
def get_augmented_obs(args,
                      obs,
                      posterior_sample=None,
                      task_mu=None,
                      task_std=None):
    """Concatenate the task belief (posterior sample, or mean and std) onto the observation."""
    obs_augmented = obs.clone()

    if posterior_sample is None:
        sample_embeddings = False
    else:
        sample_embeddings = args.sample_embeddings

    if not args.condition_policy_on_state:
        # drop the raw state: the policy is conditioned on the belief only
        obs_augmented = ptu.zeros(0, )

    if sample_embeddings and (posterior_sample is not None):
        obs_augmented = torch.cat((obs_augmented, posterior_sample), dim=1)
    elif (task_mu is not None) and (task_std is not None):
        task_mu = task_mu.reshape((-1, task_mu.shape[-1]))
        task_std = task_std.reshape((-1, task_std.shape[-1]))
        obs_augmented = torch.cat((obs_augmented, task_mu, task_std), dim=-1)

    return obs_augmented
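
A small sketch of the task_mu/task_std branch with plain torch and made-up shapes; the (1, batch, latent_dim) layout of the belief tensors is an assumption about what the recurrent encoder emits.

import torch

obs = torch.randn(8, 4)           # batch of 8 observations, obs_dim = 4
task_mu = torch.randn(1, 8, 5)    # assumed (1, batch, latent_dim) encoder output
task_std = torch.randn(1, 8, 5)

# flatten the leading dims so mu/std line up with the observation batch,
# then concatenate along the feature axis as in the branch above
mu = task_mu.reshape(-1, task_mu.shape[-1])
std = task_std.reshape(-1, task_std.shape[-1])
augmented = torch.cat((obs, mu, std), dim=-1)
assert augmented.shape == (8, 4 + 5 + 5)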
Example #3
def rsample(self, return_pretanh_value=False):
    # reparameterization trick: the standard-normal noise carries no gradient,
    # so gradients flow only through normal_mean and normal_std
    z = self.normal_mean + self.normal_std * Normal(
        ptu.zeros(self.normal_mean.size()),
        ptu.ones(self.normal_std.size())).sample()
    if return_pretanh_value:
        return torch.tanh(z), z
    else:
        return torch.tanh(z)
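
For reference, a self-contained toy version of the same reparameterized (rsample) pattern with plain torch tensors; mean and std below are arbitrary stand-ins, not the distribution's attributes.

import torch
from torch.distributions import Normal

mean = torch.zeros(3, requires_grad=True)
std = torch.ones(3, requires_grad=True)

eps = Normal(torch.zeros_like(mean), torch.ones_like(std)).sample()  # noise carries no gradient
z = mean + std * eps          # pre-tanh value
action = torch.tanh(z)        # squashed sample in (-1, 1)
action.sum().backward()       # gradients reach mean and std via the reparameterization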
Example #4
def predict_rewards(learner, means, logvars):
    """Monte-Carlo estimate of the decoder's reward prediction for each belief (mean, logvar)."""
    reward_preds = ptu.zeros([means.shape[0], learner.env.num_states])
    for t in range(reward_preds.shape[0]):
        task_samples = learner.vae.encoder._sample_gaussian(
            ptu.FloatTensor(means[t]), ptu.FloatTensor(logvars[t]), num=50)
        reward_preds[t, :] = learner.vae.reward_decoder(
            ptu.FloatTensor(task_samples), None).mean(dim=0).detach()

    return ptu.get_numpy(reward_preds)
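
The loop body is a Monte-Carlo average over latent samples. Below is a stand-alone sketch with a toy linear decoder in place of vae.reward_decoder; all names and sizes are made up.

import torch

latent_dim, num_samples, num_states = 5, 50, 10
mean = torch.randn(latent_dim)
logvar = torch.randn(latent_dim)
decoder = torch.nn.Linear(latent_dim, num_states)   # stand-in for vae.reward_decoder

std = torch.exp(0.5 * logvar)
samples = mean + std * torch.randn(num_samples, latent_dim)   # (num_samples, latent_dim)
reward_pred = decoder(samples).mean(dim=0).detach()           # average prediction over samples
assert reward_pred.shape == (num_states,)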
Example #5
def update_step(vae, obs, actions, rewards, next_obs, args):
    episode_len, num_episodes, _ = obs.shape

    # get time-steps for ELBO computation
    if args.vae_batch_num_elbo_terms is not None:
        elbo_timesteps = np.stack([
            np.random.choice(range(0, args.trajectory_len + 1),
                             args.vae_batch_num_elbo_terms,
                             replace=False) for _ in range(num_episodes)
        ])
    else:
        elbo_timesteps = np.repeat(
            np.arange(0, args.trajectory_len + 1).reshape(1, -1),
            num_episodes, axis=0)

    # pass through encoder (outputs will be: (max_traj_len+1) x number of rollouts x latent_dim -- includes the prior!)
    _, latent_mean, latent_logvar, _ = vae.encoder(actions=actions,
                                                   states=next_obs,
                                                   rewards=rewards,
                                                   hidden_state=None,
                                                   return_prior=True)

    rew_recon_losses, state_recon_losses, task_recon_losses, kl_terms = [], [], [], []

    # for each task we have in our batch
    for episode_idx in range(num_episodes):

        # get the embedding values (size: traj_length+1 * latent_dim; the +1 is for the prior)
        curr_means = latent_mean[:episode_len + 1, episode_idx, :]
        curr_logvars = latent_logvar[:episode_len + 1, episode_idx, :]
        # take one sample for each ELBO term
        curr_samples = vae.encoder._sample_gaussian(curr_means, curr_logvars)

        # select data from current rollout (result is traj_length * obs_dim)
        curr_obs = obs[:, episode_idx, :]
        curr_next_obs = next_obs[:, episode_idx, :]
        curr_actions = actions[:, episode_idx, :]
        curr_rewards = rewards[:, episode_idx, :]

        num_latents = curr_samples.shape[0]  # includes the prior
        num_decodes = curr_obs.shape[0]

        # expand the latent to match the (x, y) pairs of the decoder
        dec_embedding = curr_samples.unsqueeze(0).expand(
            (num_decodes, *curr_samples.shape)).transpose(1, 0)

        # expand the transition data so each latent is decoded against every
        # (state, action, next state, reward) tuple of the trajectory
        dec_obs = curr_obs.unsqueeze(0).expand((num_latents, *curr_obs.shape))
        dec_next_obs = curr_next_obs.unsqueeze(0).expand(
            (num_latents, *curr_next_obs.shape))
        dec_actions = curr_actions.unsqueeze(0).expand(
            (num_latents, *curr_actions.shape))
        dec_rewards = curr_rewards.unsqueeze(0).expand(
            (num_latents, *curr_rewards.shape))

        if args.decode_reward:
            # compute reconstruction loss for this trajectory
            # (for each timestep that was encoded, decode everything and sum it up)
            rrl = vae.compute_rew_reconstruction_loss(dec_embedding, dec_obs,
                                                      dec_next_obs,
                                                      dec_actions, dec_rewards)
            # sum along the trajectory which we decoded (sum in ELBO_t)
            if args.decode_only_past:
                curr_idx = 0
                past_reconstr_sum = []
                for i, idx_timestep in enumerate(elbo_timesteps[episode_idx]):
                    dec_until = idx_timestep
                    if dec_until != 0:
                        past_reconstr_sum.append(rrl[curr_idx:curr_idx +
                                                     dec_until].sum())
                    curr_idx += dec_until
                rrl = torch.stack(past_reconstr_sum)
            else:
                rrl = rrl.sum(dim=1)
            rew_recon_losses.append(rrl)
        if args.decode_state:
            srl = vae.compute_state_reconstruction_loss(
                dec_embedding, dec_obs, dec_next_obs, dec_actions)
            srl = srl.sum(dim=1)
            state_recon_losses.append(srl)
        if not args.disable_stochasticity_in_latent:
            # compute the KL term for each ELBO term of the current trajectory
            kl = vae.compute_kl_loss(curr_means, curr_logvars,
                                     elbo_timesteps[episode_idx])
            kl_terms.append(kl)

    # sum the ELBO_t terms per task
    if args.decode_reward:
        rew_recon_losses = torch.stack(rew_recon_losses)
        rew_recon_losses = rew_recon_losses.sum(dim=1)
    else:
        rew_recon_losses = ptu.zeros(1)  # zero placeholder so the .mean() below still works

    if args.decode_state:
        state_recon_losses = torch.stack(state_recon_losses)
        state_recon_losses = state_recon_losses.sum(dim=1)
    else:
        state_recon_losses = ptu.zeros(1)

    if not args.disable_stochasticity_in_latent:
        kl_terms = torch.stack(kl_terms)
        kl_terms = kl_terms.sum(dim=1)
    else:
        kl_terms = ptu.zeros(1)

    # make sure we can compute gradients
    if not args.disable_stochasticity_in_latent:
        assert kl_terms.requires_grad
    if args.decode_reward:
        assert rew_recon_losses.requires_grad
    if args.decode_state:
        assert state_recon_losses.requires_grad

    return rew_recon_losses.mean(), state_recon_losses.mean(), kl_terms.mean()
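
The unsqueeze/expand/transpose bookkeeping above pairs every latent (one per ELBO term, prior included) with every timestep of the trajectory. A toy-shape sketch of just that step:

import torch

traj_len, latent_dim, obs_dim = 6, 5, 3
curr_samples = torch.randn(traj_len + 1, latent_dim)   # one latent per ELBO term, incl. the prior
curr_obs = torch.randn(traj_len, obs_dim)

num_latents = curr_samples.shape[0]
num_decodes = curr_obs.shape[0]

dec_embedding = curr_samples.unsqueeze(0).expand(
    (num_decodes, *curr_samples.shape)).transpose(1, 0)
dec_obs = curr_obs.unsqueeze(0).expand((num_latents, *curr_obs.shape))

assert dec_embedding.shape == (num_latents, num_decodes, latent_dim)
assert dec_obs.shape == (num_latents, num_decodes, obs_dim)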
Example #6
def forward(self, inputs):
    if self.output_size != 0:
        return self.activation_function(self.fc(inputs))
    else:
        # output_size == 0: return an empty tensor instead of features
        return ptu.zeros(0, )
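
Returning an empty tensor when output_size is zero is convenient because concatenating a zero-width tensor along the feature axis is a no-op, so callers can cat the head's output unconditionally. A toy check, using a (batch, 0) tensor for clarity (the code above returns a 1-D empty tensor):

import torch

features = torch.randn(4, 8)
empty = torch.zeros(4, 0)     # an output head that is switched off
combined = torch.cat((features, empty), dim=-1)
assert combined.shape == (4, 8)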