Example #1
    def __init__(self,
                 latent_dim,
                 nets,
                 **kwargs
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.task_enc, self.cnn_enc, self.policy, self.qf1, self.qf2, self.vf = nets
        self.target_vf = self.vf.copy()
        self.recurrent = kwargs['recurrent']
        self.reparam = kwargs['reparameterize']
        self.use_ib = kwargs['use_information_bottleneck']
        self.tau = kwargs['soft_target_tau']
        self.reward_scale = kwargs['reward_scale']
        self.sparse_rewards = kwargs['sparse_rewards']
        self.det_z = False
        self.obs_emb_dim = kwargs['obs_emb_dim']

        self.q1_buff = []
        self.n_updates = 0

        # initialize task embedding to zero
        # (task, latent dim)
        self.register_buffer('z', torch.zeros(1, latent_dim))
        # for incremental update, must keep track of number of datapoints accumulated
        self.register_buffer('num_z', torch.zeros(1))

        # initialize posterior to the prior
        if self.use_ib:
            self.z_dists = [torch.distributions.Normal(ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim))]
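All of these snippets rely on a `ptu` module for tensor creation (`ptu.zeros`, `ptu.ones`, `ptu.from_numpy`, ...). In rlkit-style codebases this is conventionally `rlkit.torch.pytorch_util`: thin wrappers around the matching `torch` factory functions that place everything on one globally configured device. A minimal sketch of that convention (an assumption about the helpers, not code from any of the repositories quoted here):

import torch

device = torch.device('cpu')  # assumed to be configured once at startup (e.g. set to 'cuda')

def zeros(*sizes, **kwargs):
    # same as torch.zeros, but on the globally configured device
    return torch.zeros(*sizes, device=device, **kwargs)

def ones(*sizes, **kwargs):
    return torch.ones(*sizes, device=device, **kwargs)

def from_numpy(np_array):
    # numpy array -> float tensor on the configured device
    return torch.from_numpy(np_array).float().to(device)

def get_numpy(tensor):
    # tensor -> numpy array on the CPU
    return tensor.detach().cpu().numpy()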
Example #2
 def dump_mixed_latents(self, epoch):
     n = 8
     batch, reconstructions = self.eval_data["test/last_batch"]
     x_t, env = batch["x_t"][:n], batch["env"][:n]
     z_comb = self.model.encode(x_t, env)
     z_pos = z_comb[:, :self.model.latent_sizes[0]]
     z_obj = z_comb[:, self.model.latent_sizes[0]:]
     grid = []
     for i in range(n):
         for j in range(n):
             if i + j == 0:
                 grid.append(
                     ptu.zeros(1, self.input_channels, self.imsize,
                               self.imsize))
             elif i == 0:
                 #grid.append(self.model.decode(torch.cat([z_pos[j], z_obj[i]], dim=1)))
                 grid.append(x_t[j].reshape(1, self.input_channels,
                                            self.imsize, self.imsize))
             elif j == 0:
                 #grid.append(self.model.decode(torch.cat([z_pos[j], z_obj[i]], dim=1)))
                 grid.append(env[i].reshape(1, self.input_channels,
                                            self.imsize, self.imsize))
             else:
                 z, z_c = z_pos[j].reshape(1, -1), z_obj[i].reshape(1, -1)
                 grid.append(self.model.decode(torch.cat([z, z_c], dim=1)))
     samples = torch.cat(grid)
     save_dir = osp.join(self.log_dir, 'mixed_latents_%d.png' % epoch)
     save_image(samples.data.cpu().transpose(2, 3), save_dir, nrow=n)
Example #3
    def forward(self, obs, context=None, cal_rew=True):
        ''' given context, get statistics under the current policy of a set of observations '''
        t, b, _ = obs.size()

        in_ = obs
        policy_outputs = self.policy(in_,
                                     reparameterize=True,
                                     return_log_prob=True)
        rew = None
        #in_=in_.view(t * b, -1)
        if cal_rew:
            encoder_output_next = self.context_encoder.forward_seq(context)
            z_mean_next = encoder_output_next[:, :, :self.latent_dim]
            z_var_next = F.softplus(encoder_output_next[:, :,
                                                        self.latent_dim:])
            var = ptu.ones(context.shape[0], 1, self.latent_dim)
            mean = ptu.zeros(context.shape[0], 1, self.latent_dim)
            z_mean = torch.cat([mean, z_mean_next], dim=1)[:, :-1, :]
            z_var = torch.cat([var, z_var_next], dim=1)[:, :-1, :]

            z_mean, z_var, z_mean_next, z_var_next = z_mean.contiguous(
            ), z_var.contiguous(), z_mean_next.contiguous(
            ), z_var_next.contiguous()
            z_mean, z_var, z_mean_next, z_var_next = z_mean.view(
                t * b, -1), z_var.view(t * b, -1), z_mean_next.view(
                    t * b, -1), z_var_next.view(t * b, -1)
            rew = self.compute_kl_div_vime(z_mean, z_var, z_mean_next,
                                           z_var_next)
            rew = rew.detach()

        return policy_outputs, rew  #, z_mean,z_var,z_mean_next,z_var_next
Example #4
    def from_vae_latents_to_lstm_latents(self, latents, lstm_hidden=None):
        batch_size, feature_size = latents.shape
        # print(latents.shape)
        lstm_input = latents
        lstm_input = lstm_input.view((1, batch_size, -1))

        if lstm_hidden is None:
            lstm_hidden = (ptu.zeros(self.lstm_num_layers, batch_size, self.lstm_hidden_size), \
                     ptu.zeros(self.lstm_num_layers, batch_size, self.lstm_hidden_size))

        h, hidden = self.lstm(
            lstm_input, lstm_hidden)  # [seq_len, batch_size, lstm_hidden_size]

        lstm_latent = self.lstm_fc(h)
        lstm_latent = lstm_latent.view((batch_size, -1))
        return lstm_latent
Example #5
 def compute_kl_div(self):
     ''' compute KL( q(z|c) || r(z) ) '''
     prior = torch.distributions.Normal(ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim))
     posteriors = [torch.distributions.Normal(mu, torch.sqrt(var)) for mu, var in zip(torch.unbind(self.z_means), torch.unbind(self.z_vars))]
     kl_divs = [torch.distributions.kl.kl_divergence(post, prior) for post in posteriors]
     kl_div_sum = torch.sum(torch.stack(kl_divs))
     return kl_div_sum
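Example #5 builds a standard-normal prior from `ptu.zeros`/`ptu.ones` and sums the KL divergence of each per-task posterior against it. A self-contained sketch of the same computation with plain `torch`, using made-up means and variances purely for illustration:

import torch
from torch.distributions import Normal, kl_divergence

latent_dim = 5
z_means = torch.randn(3, latent_dim)      # 3 tasks, illustrative values
z_vars = torch.rand(3, latent_dim) + 0.1  # keep variances strictly positive

prior = Normal(torch.zeros(latent_dim), torch.ones(latent_dim))
posteriors = [Normal(mu, var.sqrt()) for mu, var in zip(z_means, z_vars)]
kl_div_sum = torch.stack([kl_divergence(q, prior) for q in posteriors]).sum()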
Example #6
 def get_action(
     self,
     observation,
     use_raps_obs=False,
     use_true_actions=True,
     use_obs=True,
 ):
     """
     :param observation:
     :return: action, debug_dictionary
     """
     observation = ptu.from_numpy(np.array(observation))
     if self.state:
         prev_state, action = self.state
     else:
         prev_state = self.world_model.initial(observation.shape[0])
         action = ptu.zeros((observation.shape[0], self.action_dim))
     embed = self.world_model.encode(observation)
     new_state, _ = self.world_model.obs_step(prev_state, action, embed)
     feat = self.world_model.get_features(new_state)
     dist = self.actor(feat)
     action = dist.mode()
     if self.exploration:
         action = self.actor.compute_exploration_action(action, self.expl_amount)
     self.state = (new_state, action)
     return ptu.get_numpy(action), {"state": new_state}
Example #7
def compute_log_p_log_q_log_d(
    model,
    data,
    decoder_distribution='bernoulli',
    num_latents_to_sample=1,
    sampling_method='importance_sampling'
):
    assert data.dtype == np.float64, 'images should be normalized'
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs)
    batch_size = data.shape[0]
    representation_size = model.representation_size
    log_p, log_q, log_d = ptu.zeros((batch_size, num_latents_to_sample)), ptu.zeros(
        (batch_size, num_latents_to_sample)), ptu.zeros((batch_size, num_latents_to_sample))
    true_prior = Normal(ptu.zeros((batch_size, representation_size)),
                        ptu.ones((batch_size, representation_size)))
    mus, logvars = latent_distribution_params
    for i in range(num_latents_to_sample):
        if sampling_method == 'importance_sampling':
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'biased_sampling':
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'true_prior_sampling':
            latents = true_prior.rsample()
        else:
            raise EnvironmentError('Invalid Sampling Method Provided')

        stds = logvars.exp().pow(.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)
        if decoder_distribution == 'bernoulli':
            decoded = model.decode(latents)[0]
            log_d_x_given_z = torch.log(imgs * decoded + (1 - imgs) * (1 - decoded) + 1e-8).sum(dim=1)
        elif decoder_distribution == 'gaussian_identity_variance':
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise EnvironmentError('Invalid Decoder Distribution Provided')

        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
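The three tensors returned above (prior log-density, posterior log-density, and decoder log-likelihood per latent sample) are typically combined into an importance-weighted estimate of log p(x). A short sketch of that combination, offered as an assumption about how callers use the function rather than code from this repository:

import math
import torch

def estimate_log_likelihood(log_p, log_q, log_d):
    # log_p, log_q, log_d: (batch_size, num_latents_to_sample)
    k = log_p.shape[1]
    # log p(x) ~= logsumexp_i[ log p(z_i) + log p(x|z_i) - log q(z_i|x) ] - log k
    log_w = log_p + log_d - log_q
    return torch.logsumexp(log_w, dim=1) - math.log(k)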
Example #8
    def initial(self, batch_size):
        """
        :param batch_size: int

        :return state: Dict
            mean: (batch_size, stoch_size)
            std: (batch_size, stoch_size)
            deter: (batch_size, deter_size)
            stoch: (batch_size, stoch_size)
        """
        state = dict(
            mean=ptu.zeros([batch_size, self.stochastic_state_size]),
            std=ptu.zeros([batch_size, self.stochastic_state_size]),
            stoch=ptu.zeros([batch_size, self.stochastic_state_size]),
            deter=ptu.zeros([batch_size, self.deterministic_state_size]),
        )
        return state
Example #9
 def clear_z(self, num_tasks=1):
     if self.use_ib:
         self.z_dists = [torch.distributions.Normal(ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim)) for _ in range(num_tasks)]
         z = [d.rsample() for d in self.z_dists]
         self.z = torch.stack(z)
     else:
         self.z = self.z.new_full((num_tasks, self.latent_dim), 0)
     self.task_enc.reset(num_tasks) # clear hidden state in recurrent case
Example #10
 def clear_z(self, num_tasks=1):
     '''
     reset q(z|c) to the prior
     sample a new z from the prior
     '''
     self.z = ptu.zeros(num_tasks, self.pie_hidden_dim)
     self.context = None
     self.pie_snail.reset(num_tasks)
Example #11
 def clear_z(self, num_tasks=1):
     '''
     reset q(z|c) to the prior
     sample a new z from the prior
     '''
     # reset distribution over z to the prior
     mu = ptu.zeros(num_tasks, self.latent_dim)
     if self.use_ib:
         var = ptu.ones(num_tasks, self.latent_dim)
     else:
         var = ptu.zeros(num_tasks, self.latent_dim)
     self.z_means = mu
     self.z_vars = var
     # sample a new z from the prior
     self.sample_z()
     # reset the context collected so far
     self.context = None
Example #12
 def rsample(self):
     z = (self.normal_means + self.normal_stds * MultivariateDiagonalNormal(
         ptu.zeros(self.normal_means.size()),
         ptu.ones(self.normal_stds.size())).sample())
     z.requires_grad_()
     c = self.categorical.sample()[:, :, None]
     s = torch.gather(z, dim=2, index=c)
     return s[:, :, 0]
Example #13
 def rsample(self):
     z = (self.normal_means + self.normal_stds * MultivariateDiagonalNormal(
         ptu.zeros(self.normal_means.size()),
         ptu.ones(self.normal_stds.size())).sample())
     z.requires_grad_()
     c = self.categorical.sample()[:, :, None]
     s = torch.matmul(z, c)
     return torch.squeeze(s, 2)
Example #14
    def __init__(
        self,
        env,
        policy,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        discount=0.99,
        reward_scale=1.0,
        policy_lr=1e-3,
        qf_lr=1e-3,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
    ):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(
                    self.env.action_space.shape
                ).item()  # heuristic value from Tuomas
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class([self.log_alpha],
                                                   lr=policy_lr)

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(self.policy.parameters(),
                                                lr=policy_lr)
        self.qf1_optimizer = optimizer_class(self.qf1.parameters(), lr=qf_lr)
        self.qf2_optimizer = optimizer_class(self.qf2.parameters(), lr=qf_lr)

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
Example #15
 def compute_kl_div(self):
     prior = torch.distributions.Normal(ptu.zeros(self.latent_dim),
                                        ptu.ones(self.latent_dim))
     kl_divs = [
         torch.distributions.kl.kl_divergence(z_dist, prior)
         for z_dist in self.z_dists
     ]
     kl_div_sum = torch.sum(torch.stack(kl_divs))
     return kl_div_sum
Example #16
 def rsample(self, return_pretanh_value=False):
     z = (self.normal_mean + self.normal_std * Variable(
         Normal(ptu.zeros(self.normal_mean.size()),
                ptu.ones(self.normal_std.size())).sample()))
     # z.requires_grad_()
     if return_pretanh_value:
         return torch.tanh(z), z
     else:
         return torch.tanh(z)
Example #17
 def get_full_edges(self, num_node):
     edges = ptu.zeros(2, num_node * num_node, dtype=int)
     edges[0, :] = torch.arange(num_node).repeat(num_node, 1).transpose(
         0, 1).reshape(num_node * num_node)
     edges[1, :] = torch.arange(num_node).repeat(1, num_node).reshape(
         num_node * num_node)
     if not self.contain_self_loop:
         edges = pyg_utils.remove_self_loops(edges)[0]
     return edges
Example #18
 def clear_z(self, num_tasks=1):
     '''
     reset q(z|c) to the prior
     sample a new z from the prior
     '''
     # reset distribution over z to the prior
     mu = ptu.zeros(num_tasks, self.latent_dim)
     if self.use_ib:
         var = ptu.ones(num_tasks, self.latent_dim)
     else:
         var = ptu.zeros(num_tasks, self.latent_dim)
     self.z_means = mu
     self.z_vars = var
     # sample a new z from the prior
     # reset the context collected so far
     self.context = None
     # reset any hidden state in the encoder network (relevant for RNN)
     self.context_encoder.reset(num_tasks)
Example #19
 def rsample_with_pretanh(self):
     z = (
             self.normal_mean +
             self.normal_std *
             MultivariateDiagonalNormal(
                 ptu.zeros(self.normal_mean.size()),
                 ptu.ones(self.normal_std.size())
             ).sample()
     )
     return torch.tanh(z), z
Example #20
 def clear_z(self, num_tasks=1):
     '''
     reset q(z|c) to the prior
     sample a new z from the prior
     '''
     #  reset distribution over z to the prior
     mu = ptu.zeros(num_tasks, self.latent_dim)
     var = ptu.ones(num_tasks, self.latent_dim)
     self.z_means = mu
     self.z_vars = var
Example #21
    def mle_estimate(self):
        """Return the mean of the most likely component.

        This often computes the mode of the distribution, but not always.
        """
        c = ptu.zeros(self.weights.shape[:2])
        ind = torch.argmax(self.weights, dim=1)  # [:, 0]
        c.scatter_(1, ind, 1)
        s = torch.matmul(self.normal_means, c[:, :, None])
        return torch.squeeze(s, 2)
Example #22
 def get_tau(self, actions, fp=None):
     if self.tau_type == 'fix':
         presum_tau = ptu.zeros(len(actions), self.num_quantiles) + 1. / self.num_quantiles
     elif self.tau_type == 'iqn':  # add 0.1 to prevent tau getting too close
         presum_tau = ptu.rand(len(actions), self.num_quantiles) + 0.1
         presum_tau /= presum_tau.sum(dim=-1, keepdims=True)
     tau = torch.cumsum(presum_tau, dim=1)  # (N, T), note that they are tau1...tauN in the paper
     with torch.no_grad():
         tau_hat = ptu.zeros_like(tau)
         tau_hat[:, 0:1] = tau[:, 0:1] / 2.
         tau_hat[:, 1:] = (tau[:, 1:] + tau[:, :-1]) / 2.
     return tau, tau_hat, presum_tau
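In the 'fix' branch, `ptu.zeros(...) + 1. / self.num_quantiles` simply fills a tensor with uniform quantile fractions; `tau` is their cumulative sum and `tau_hat` the midpoint of each consecutive pair of `tau` values. A tiny worked example with plain `torch` (illustrative numbers only):

import torch

num_quantiles = 4
presum_tau = torch.full((1, num_quantiles), 1. / num_quantiles)  # [[0.25, 0.25, 0.25, 0.25]]
tau = torch.cumsum(presum_tau, dim=1)                            # [[0.25, 0.50, 0.75, 1.00]]
tau_hat = torch.zeros_like(tau)
tau_hat[:, 0:1] = tau[:, 0:1] / 2.                               # 0.125
tau_hat[:, 1:] = (tau[:, 1:] + tau[:, :-1]) / 2.                 # 0.375, 0.625, 0.875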
Example #23
    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        z = (self.normal_mean + self.normal_std *
             Normal(ptu.zeros(self.normal_mean.size()),
                    ptu.ones(self.normal_std.size())).sample())
        z.requires_grad_()

        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)
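Examples #16, #19, and #23 all implement the same tanh-Gaussian reparameterization: draw standard-normal noise from a distribution built with `ptu.zeros`/`ptu.ones`, shift and scale it by the policy mean and std, then squash with tanh. A self-contained sketch of the pattern with plain `torch`, using arbitrary shapes for illustration:

import torch
from torch.distributions import Normal

mean = torch.randn(8, 2)        # (batch, action_dim), illustrative values
std = torch.rand(8, 2) + 0.1

eps = Normal(torch.zeros_like(mean), torch.ones_like(std)).sample()
pre_tanh = mean + std * eps     # reparameterized Gaussian sample
action = torch.tanh(pre_tanh)   # squashed into (-1, 1)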
Example #24
    def clear_sequence_z(self, num_tasks=1, batch_size=1, traj_batch_size=1):
        assert self.recurrent_context_encoder is not None
        if self.r_cat_dim > 0:
            self.seq_z_cat = ptu.ones(num_tasks * batch_size * self.r_n_cat, self.r_cat_dim) / self.r_cat_dim
            self.seq_z_next_cat = None
        if self.r_cont_dim > 0:
            self.seq_z_cont_mean = ptu.zeros(num_tasks * batch_size, self.r_cont_dim)
            self.seq_z_cont_var = ptu.ones(num_tasks * batch_size, self.r_cont_dim)
            self.seq_z_next_cont_mean = None
            self.seq_z_next_cont_var = None
        if self.r_dir_dim > 0:
            if self.r_constraint == 'logitnormal':
                self.seq_z_dir_mean = ptu.zeros(num_tasks * batch_size * self.r_n_dir, self.r_dir_dim)
                self.seq_z_dir_var = ptu.ones(num_tasks * batch_size * self.r_n_dir, self.r_dir_dim) * self.r_var
                self.seq_z_next_dir_mean = None
                self.seq_z_next_dir_var = None
            elif self.r_constraint == 'dirichlet':
                self.seq_z_dir = ptu.ones(num_tasks * batch_size * self.r_n_dir, self.r_dir_dim) * self.r_alpha
                self.seq_z_next_dir = None

        self.sample_sequence_z()
        self.recurrent_context_encoder.reset(num_tasks*traj_batch_size)
Example #25
    def __init__(
            self,
            env,
            dsp,
            policy,
            classifier,
            search_buffer,

            policy_lr=1e-3,
            classifier_lr=1e-3,
            optimizer_class=optim.Adam,

            use_automatic_entropy_tuning=True,
            target_entropy=None,
    ):
        super().__init__()
        self.env = env
        self.dsp = dsp
        self.policy = policy
        self.classifier = classifier
        self.search_buffer = search_buffer

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.classifier_criterion = nn.MSELoss()
        self.dsp_optimizer = optimizer_class(
            self.dsp.parameters(),
            lr=policy_lr,
        )
        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.classifier_optimizer = optimizer_class(
            self.classifier.parameters(),
            lr=classifier_lr,
        )

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
Example #26
    def forward(self, *inputs):
        x = super().forward(*inputs)
        mean = self.mean_layer(x)
        logstd = self.logstd_layer(x)
        logstd = torch.clamp(logstd, LOGMIN, LOGMAX)
        std = torch.exp(logstd)

        unit_normal = Normal(ptu.zeros(mean.size()), ptu.ones(std.size()))
        eps = unit_normal.sample()
        pre_tanh_z = mean.cpu() + std.cpu() * eps
        action = torch.tanh(pre_tanh_z)
        logp = unit_normal.log_prob(eps)  # per-dimension log-density of the noise
        logp = logp.sum(dim=1, keepdim=True)  # sum over action dims (log of the product of densities)
        return action, pre_tanh_z, logp, mean, logstd
Example #27
    def compute_features(self, dataloader):
        self.prepare_for_inference()
        # features = np.zeros((self.num_trajectories, self.episode_length, self.feature_size), dtype=np.float32)
        features = ptu.zeros([self.num_trajectories, self.episode_length, self.feature_size], dtype=torch.float32)

        with torch.no_grad():
            for i, (input_tensor,) in enumerate(dataloader):
                # feature_batch = self.model(input_tensor.cuda()).cpu().numpy()
                feature_batch = self.forward(input_tensor)

                if i < len(dataloader) - 1:
                    features[i * self.batch_size_trajectory: (i + 1) * self.batch_size_trajectory] = feature_batch
                else:
                    features[i * self.batch_size_trajectory:] = feature_batch
        return features
Example #28
    def __init__(self,
                 latent_dim,
                 context_encoder,
                 policy,
                 reward_predictor,
                 use_next_obs_in_context=False,
                 _debug_ignore_context=False,
                 _debug_do_not_sqrt=False,
                 _debug_use_ground_truth_context=False):
        super().__init__()
        self.latent_dim = latent_dim

        self.context_encoder = context_encoder
        self.policy = policy
        self.reward_predictor = reward_predictor
        self.deterministic_policy = MakeDeterministic(self.policy)
        self._debug_ignore_context = _debug_ignore_context
        self._debug_use_ground_truth_context = _debug_use_ground_truth_context

        # self.recurrent = kwargs['recurrent']
        # self.use_ib = kwargs['use_information_bottleneck']
        # self.sparse_rewards = kwargs['sparse_rewards']
        self.use_next_obs_in_context = use_next_obs_in_context

        # initialize buffers for z dist and z
        # use buffers so latent context can be saved along with model weights
        self.register_buffer('z', torch.zeros(1, latent_dim))
        self.register_buffer('z_means', torch.zeros(1, latent_dim))
        self.register_buffer('z_vars', torch.zeros(1, latent_dim))

        self.z_means = None
        self.z_vars = None
        self.context = None
        self.z = None

        # rp = reward predictor
        # TODO: add back in reward predictor code
        self.z_means_rp = None
        self.z_vars_rp = None
        self.z_rp = None
        self.context_encoder_rp = context_encoder
        self._use_context_encoder_snapshot_for_reward_pred = False

        self.latent_prior = torch.distributions.Normal(
            ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim))

        self._debug_do_not_sqrt = _debug_do_not_sqrt
Example #29
    def __init__(self,
                 hidden_sizes,
                 obs_dim,
                 action_dim,
                 std=None,
                 init_w=1e-3,
                 min_log_std=None,
                 max_log_std=None,
                 num_gaussians=1,
                 std_architecture="shared",
                 **kwargs):
        super().__init__(
            hidden_sizes,
            input_size=obs_dim,
            output_size=action_dim * num_gaussians,
            init_w=init_w,
            # output_activation=torch.tanh,
            **kwargs)
        self.action_dim = action_dim
        self.num_gaussians = num_gaussians
        self.min_log_std = min_log_std
        self.max_log_std = max_log_std
        self.log_std = None
        self.std = std
        self.std_architecture = std_architecture
        if std is None:
            last_hidden_size = obs_dim
            if len(hidden_sizes) > 0:
                last_hidden_size = hidden_sizes[-1]

            if self.std_architecture == "shared":
                self.last_fc_log_std = nn.Linear(last_hidden_size,
                                                 action_dim * num_gaussians)
                self.last_fc_log_std.weight.data.uniform_(-init_w, init_w)
                self.last_fc_log_std.bias.data.uniform_(-init_w, init_w)
            elif self.std_architecture == "values":
                self.log_std_logits = nn.Parameter(
                    ptu.zeros(action_dim * num_gaussians, requires_grad=True))
            else:
                raise ValueError(self.std_architecture)
        else:
            self.log_std = np.log(std)
            assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX
        self.last_fc_weights = nn.Linear(last_hidden_size,
                                         action_dim * num_gaussians)
        self.last_fc_weights.weight.data.uniform_(-init_w, init_w)
        self.last_fc_weights.bias.data.uniform_(-init_w, init_w)
Example #30
    def compute_density(self, data):
        orig_data_length = len(data)
        data = np.vstack([data for _ in range(self.n_average)])
        data = ptu.from_numpy(data)
        if self.mode == 'biased':
            latents, means, log_vars, stds = (
                self.encoder.get_encoding_and_suff_stats(data))
            importance_weights = ptu.ones(data.shape[0])
        elif self.mode == 'prior':
            latents = ptu.randn(len(data), self.z_dim)
            importance_weights = ptu.ones(data.shape[0])
        elif self.mode == 'importance_sampling':
            latents, means, log_vars, stds = (
                self.encoder.get_encoding_and_suff_stats(data))
            prior = Normal(ptu.zeros(1), ptu.ones(1))
            prior_log_prob = prior.log_prob(latents).sum(dim=1)

            encoder_distrib = Normal(means, stds)
            encoder_log_prob = encoder_distrib.log_prob(latents).sum(dim=1)

            importance_weights = (prior_log_prob - encoder_log_prob).exp()
        else:
            raise NotImplementedError()

        unweighted_data_log_prob = self.compute_log_prob(
            data, self.decoder, latents).squeeze(1)
        unweighted_data_prob = unweighted_data_log_prob.exp()
        unnormalized_data_prob = unweighted_data_prob * importance_weights
        """
        Average over `n_average`
        """
        dp_split = torch.split(unnormalized_data_prob, orig_data_length, dim=0)
        # pre_avg.shape = ORIG_LEN x N_AVERAGE
        dp_stacked = torch.stack(dp_split, dim=1)
        # final.shape = ORIG_LEN
        unnormalized_dp = torch.sum(dp_stacked, dim=1, keepdim=False)
        """
        Compute the importance weight denominators.
        This requires summing across the `n_average` dimension.
        """
        iw_split = torch.split(importance_weights, orig_data_length, dim=0)
        iw_stacked = torch.stack(iw_split, dim=1)
        iw_denominators = iw_stacked.sum(dim=1, keepdim=False)

        final = unnormalized_dp / iw_denominators
        return ptu.get_numpy(final)