Ejemplo n.º 1
0
def compute_pi_logprob(mean_std_batch, action_arr):
    """Log-density of ``action_arr`` under a per-row diagonal Gaussian policy.

    Each row of ``mean_std_batch`` holds ``(mean, raw_std)``; the standard
    deviation actually used is ``softplus(raw_std) = ln(1 + exp(raw_std))``.
    When the module-level flag ``use_tanh`` is truthy, the actions are treated
    as tanh-squashed: they are mapped back through ``custom_atanh`` (defined
    elsewhere; presumably an inverse tanh) and the change-of-variables term
    ``log(1 - a**2)`` is subtracted.

    Args:
        mean_std_batch: tensor of shape (batch_size, 2).
        action_arr: tensor of shape (batch_size, n_actions) (1022 in the
            original caller), or a 1-D tensor that broadcasts against the
            per-row parameters.

    Returns:
        Element-wise log-probabilities laid out like ``action_arr``, i.e.
        (batch_size, actions_in_batch) in the 2-D case.

    Raises:
        AssertionError: if any resulting log-probability is NaN.
    """
    has_action_axis = len(action_arr.shape) > 1

    # Put the batch axis last so it broadcasts against the (batch_size,)
    # mean / std vectors taken from the transposed parameter tensor.
    params = mean_std_batch.permute(1, 0)  # (2, batch_size)
    actions_t = action_arr.permute(1, 0) if has_action_axis else action_arr

    squashed = use_tanh
    policy = Normal(params[0], F.softplus(params[1]))
    if squashed:
        # Undo the tanh squashing before evaluating the Gaussian density.
        actions_t = custom_atanh(actions_t)
    logprob = policy.log_prob(actions_t)

    if has_action_axis:
        logprob = logprob.permute(1, 0)  # back to (batch_size, actions_in_batch)

    if squashed:
        # Jacobian of tanh: d tanh(u)/du = 1 - tanh(u)**2.
        # NOTE(review): if |action| == 1 exactly this is log(0) = -inf, so the
        # result becomes +inf (or NaN if custom_atanh(+-1) is infinite); the
        # NaN check below does not catch the +inf case.
        logprob -= torch.log(1 - torch.pow(action_arr, 2))

    assert not torch.isnan(logprob).any()

    return logprob
Ejemplo n.º 2
0
    def forward(self, ss: List, phase_use_mode: bool = False) -> Tuple:
        """Sample (or take the mode of) the per-cell latents and decode glimpses.

        Args:
            ss: sequence of seven tensors, unpacked in this order:
                ``p_pres_logits, p_where_mean, p_where_std, p_depth_mean,
                p_depth_std, p_what_mean, p_what_std``.
            phase_use_mode: if True, use deterministic values (presence
                logits thresholded at 0 and the Gaussian means) instead of
                drawing reparameterized samples.

        Returns:
            ``(pa, lv)``, where ``pa = [o, a]`` are the sigmoid-activated
            appearance and alpha outputs of ``self.z_what_decoder_net``, and
            ``lv = [z_pres, z_where, z_depth, z_what, z_where_origin]``.
        """
        p_pres_logits, p_where_mean, p_where_std, p_depth_mean, \
            p_depth_std, p_what_mean, p_what_std = ss

        # Presence: hard threshold on the logits in mode, otherwise a relaxed
        # (differentiable) Bernoulli sample.
        if phase_use_mode:
            z_pres = (p_pres_logits > 0).float()
        else:
            z_pres = RelaxedBernoulli(
                logits=p_pres_logits,
                temperature=self.args.train.tau_pres).rsample()

        # z_where_scale, z_where_shift: (bs, dim, num_cell, num_cell)
        if phase_use_mode:
            z_where_scale, z_where_shift = p_where_mean.chunk(2, 1)
        else:
            z_where_scale, z_where_shift = \
                Normal(p_where_mean, p_where_std).rsample().chunk(2, 1)

        # Raw (pre-transform) where latents, detached from the graph.
        # z_where_origin: (bs, dim, num_cell, num_cell)
        z_where_origin = torch.cat(
            [z_where_scale.detach(), z_where_shift.detach()], dim=1)

        # tanh keeps each centre within one cell of its cell centre
        # (offset + 0.5); the affine map turns grid units [0, num_cell] into
        # [-1, 1]. NOTE(review): presumably self.offset holds each cell's grid
        # index; confirm against where it is built.
        z_where_shift = (2. / self.args.arch.num_cell) * \
            (self.offset + 0.5 + torch.tanh(z_where_shift)) - 1.

        # Scale in (0, 1) via sigmoid and a positive ratio via exp. The two
        # output components are scale / sqrt(ratio) and scale * sqrt(ratio),
        # so their product is scale**2 and their quotient is ratio.
        scale, ratio = z_where_scale.chunk(2, 1)
        scale = scale.sigmoid()
        ratio_sqrt = torch.exp(ratio).sqrt()
        z_where_scale = torch.cat([scale / ratio_sqrt, scale * ratio_sqrt], dim=1)
        # z_where: (bs, dim, num_cell, num_cell)
        z_where = torch.cat([z_where_scale, z_where_shift], dim=1)

        if phase_use_mode:
            z_depth = p_depth_mean
            z_what = p_what_mean
        else:
            z_depth = Normal(p_depth_mean, p_depth_std).rsample()
            z_what = Normal(p_what_mean, p_what_std).rsample()

        # One z_what vector per cell, shaped as a 1x1 "image" for the decoder.
        z_what_reshape = z_what.permute(0, 2, 3, 1).reshape(
            -1, self.args.z.z_what_dim, 1, 1)

        if self.args.data.inp_channel == 1 or not self.args.arch.phase_overlap:
            o = self.z_what_decoder_net(z_what_reshape).sigmoid()
            # Single-channel input or no phase overlap: alpha is all ones.
            a = o.new_ones(o.size())
        else:
            # Reaching here implies phase_overlap is truthy (the only other
            # case is handled above), so the former unreachable
            # `raise NotImplemented` branch is gone. The decoder output is
            # split into inp_channel appearance channels plus one alpha channel.
            o, a = self.z_what_decoder_net(z_what_reshape).split(
                [self.args.data.inp_channel, 1], dim=1)
            o, a = o.sigmoid(), a.sigmoid()

        lv = [z_pres, z_where, z_depth, z_what, z_where_origin]
        pa = [o, a]

        return pa, lv
Ejemplo n.º 3
0
def so3_entropy(w_eps, std, k=10):
    """Single-sample entropy estimate for a Gaussian wrapped onto SO(3).

    ``w_eps`` is treated as an axis-angle sample in the Lie algebra so(3).
    The rotation angle is only defined modulo 2*pi, so the density of the
    rotation is a sum over the preimages ``u * (theta + 2*pi*j)``, truncated
    to ``j = -k .. k``. Each term is a zero-mean diagonal Gaussian density
    with scale ``std``, times the volume factor
    ``theta**2 / (2 - 2*cos(theta))``. See appendix C of
    https://arxiv.org/pdf/1807.04689.pdf.

    Args:
        w_eps: (B, 3) tensor, sample on so(3).
        std: (B, 3) tensor, standard deviation of the distribution on so(3).
        k: wraps on each side of the principal angle; 2k+1 terms are summed.

    Returns:
        (B,) tensor holding the negative log of the summed density.
    """
    theta = w_eps.norm(p=2, dim=-1, keepdim=True)  # (B, 1)
    # NOTE(review): a zero-length w_eps makes this 0/0 = NaN.
    axis = w_eps / theta  # (B, 3)

    offsets = torch.arange(
        -k, k + 1, dtype=w_eps.dtype, device=w_eps.device) * (2 * np.pi)  # (2k+1,)
    angles = theta[:, None, :] + offsets.unsqueeze(-1)  # (B, 2k+1, 1)
    points = axis[:, None, :] * angles  # (B, 2k+1, 3)

    base = Normal(torch.zeros(3, device=w_eps.device), std)
    # The (B, 3) scale broadcasts over a leading axis, so put the wrap axis
    # first for log_prob and swap it back afterwards.
    per_axis = base.log_prob(points.permute(1, 0, 2)).permute(1, 0, 2)  # (B, 2k+1, 3)

    # Floor both factors so the ratio stays finite where theta**2 or
    # 2 - 2*cos(theta) approaches zero.
    floor = 1e-3
    ratio = angles.pow(2).clamp(min=floor) / (2 - 2 * torch.cos(angles)).clamp(min=floor)
    log_volume = ratio.log()  # (B, 2k+1, 1)

    log_terms = per_axis.sum(-1) + log_volume.sum(-1)  # (B, 2k+1)
    return -logsumexp(log_terms, -1)
 def sample(self, sample_size):
     """Draw ``sample_size`` samples per row of ``self.psi``.

     ``self.psi[:, 0, :]`` is the Gaussian mean and ``exp(self.psi[:, 1, :])``
     is its standard deviation. ``Normal.sample`` is not reparameterized, so
     no gradient flows back to ``psi``.

     Returns:
         Tensor with the sample axis at position 1, i.e.
         (psi.shape[0], sample_size, psi.shape[2]).
     """
     loc = self.psi[:, 0, :]
     scale = self.psi[:, 1, :].exp()
     draws = Normal(loc, scale).sample((sample_size,))  # (sample_size, N, D)
     return draws.permute(1, 0, 2)