from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Normal, RelaxedBernoulli


def compute_pi_logprob(mean_std_batch, action_arr):
    # mean_std_batch: (batch_size, 2) -- raw mean and pre-softplus std
    # action_arr: (batch_size, 1022) or (batch_size,)
    # std is parameterized as softplus(param) = ln(1 + exp(param))
    # `use_tanh` and `custom_atanh` are expected to exist at module level
    permuted_mean_std_batch = mean_std_batch.permute(1, 0)  # (2, batch_size)
    permuted_action_arr = (action_arr.permute(1, 0)
                           if len(action_arr.shape) > 1 else action_arr)
    dist = Normal(permuted_mean_std_batch[0],
                  F.softplus(permuted_mean_std_batch[1]))
    if use_tanh:
        # actions were squashed by tanh; score the pre-squash values
        logprob = dist.log_prob(custom_atanh(permuted_action_arr))
        logprob = logprob.permute(1, 0) if len(action_arr.shape) > 1 else logprob
        # change-of-variables correction for the tanh squashing
        logprob -= torch.log(1 - torch.pow(action_arr, 2))
    else:
        logprob = dist.log_prob(permuted_action_arr)
        logprob = logprob.permute(1, 0) if len(action_arr.shape) > 1 else logprob
    assert not torch.isnan(logprob).any()
    # (batch_size, actions_in_batch), or (batch_size,) for 1-D actions
    return logprob
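
# Hypothetical usage sketch for compute_pi_logprob, assuming the imports
# above and that `use_tanh` is a module-level flag read by the function
# (set to False here so the `custom_atanh` helper is never needed).
use_tanh = False

mean_std = torch.randn(4, 2)     # (batch_size, 2): raw mean, pre-softplus std
actions = torch.randn(4, 1022)   # (batch_size, num_actions)
logp = compute_pi_logprob(mean_std, actions)
print(logp.shape)                # torch.Size([4, 1022])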
def forward(self, ss: List, phase_use_mode: bool = False) -> Tuple:
    p_pres_logits, p_where_mean, p_where_std, p_depth_mean, \
        p_depth_std, p_what_mean, p_what_std = ss

    if phase_use_mode:
        # deterministic mode: threshold the presence logits
        z_pres = (p_pres_logits > 0).float()
    else:
        z_pres = RelaxedBernoulli(logits=p_pres_logits,
                                  temperature=self.args.train.tau_pres).rsample()

    # z_where_scale, z_where_shift: (bs, dim, num_cell, num_cell)
    if phase_use_mode:
        z_where_scale, z_where_shift = p_where_mean.chunk(2, 1)
    else:
        z_where_scale, z_where_shift = \
            Normal(p_where_mean, p_where_std).rsample().chunk(2, 1)

    # z_where_origin: (bs, dim, num_cell, num_cell)
    z_where_origin = \
        torch.cat([z_where_scale.detach(), z_where_shift.detach()], dim=1)

    # map the raw shift to absolute coordinates in [-1, 1], offset per cell
    z_where_shift = \
        (2. / self.args.arch.num_cell) * \
        (self.offset + 0.5 + torch.tanh(z_where_shift)) - 1.

    # decompose the raw scale into overall size and aspect ratio
    scale, ratio = z_where_scale.chunk(2, 1)
    scale = scale.sigmoid()
    ratio = torch.exp(ratio)
    ratio_sqrt = ratio.sqrt()
    z_where_scale = torch.cat([scale / ratio_sqrt, scale * ratio_sqrt], dim=1)

    # z_where: (bs, dim, num_cell, num_cell)
    z_where = torch.cat([z_where_scale, z_where_shift], dim=1)

    if phase_use_mode:
        z_depth = p_depth_mean
        z_what = p_what_mean
    else:
        z_depth = Normal(p_depth_mean, p_depth_std).rsample()
        z_what = Normal(p_what_mean, p_what_std).rsample()

    z_what_reshape = z_what.permute(0, 2, 3, 1). \
        reshape(-1, self.args.z.z_what_dim). \
        view(-1, self.args.z.z_what_dim, 1, 1)

    if self.args.data.inp_channel == 1 or not self.args.arch.phase_overlap:
        o = self.z_what_decoder_net(z_what_reshape)
        o = o.sigmoid()
        a = o.new_ones(o.size())
    elif self.args.arch.phase_overlap:
        o, a = self.z_what_decoder_net(z_what_reshape).split(
            [self.args.data.inp_channel, 1], dim=1)
        o, a = o.sigmoid(), a.sigmoid()
    else:
        raise NotImplementedError

    lv = [z_pres, z_where, z_depth, z_what, z_where_origin]
    pa = [o, a]
    return pa, lv
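
# Standalone sketch of the z_where scale/ratio parameterization used in
# forward() above, on hypothetical 2-D tensors (not tied to self.args):
# the box area depends only on `scale`, while `ratio` controls its shape.
raw_scale, raw_ratio = torch.randn(8, 1), torch.randn(8, 1)
scale = raw_scale.sigmoid()   # overall box size in (0, 1)
ratio = raw_ratio.exp()       # aspect ratio > 0
ratio_sqrt = ratio.sqrt()
w, h = scale / ratio_sqrt, scale * ratio_sqrt
# (scale / sqrt(ratio)) * (scale * sqrt(ratio)) == scale ** 2
assert torch.allclose(w * h, scale ** 2)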
def so3_entropy(w_eps, std, k=10):
    '''
    w_eps (Tensor of dim Bx3): sample from the Gaussian on so(3)
    std (Tensor of dim Bx3): std of the distribution on so(3)
    k: use 2k+1 samples for the truncated summation
    '''
    # entropy of a Gaussian distribution on SO(3);
    # see appendix C of https://arxiv.org/pdf/1807.04689.pdf
    theta = w_eps.norm(p=2, dim=-1, keepdim=True)  # [B, 1] rotation angle
    u = w_eps / theta  # [B, 3] rotation axis
    angles = 2 * np.pi * torch.arange(
        -k, k + 1, dtype=w_eps.dtype, device=w_eps.device)  # 2k+1
    theta_hat = theta[:, None, :] + angles[:, None]  # [B, 2k+1, 1]
    x = u[:, None, :] * theta_hat  # [B, 2k+1, 3]
    log_p = Normal(torch.zeros(3, device=w_eps.device),
                   std).log_prob(x.permute([1, 0, 2]))  # [2k+1, B, 3]
    log_p = log_p.permute([1, 0, 2])  # [B, 2k+1, 3]
    clamp = 1e-3
    # log volume-change term; clamped to avoid log(0) near theta_hat = 0
    log_vol = torch.log(
        (theta_hat ** 2).clamp(min=clamp) /
        (2 - 2 * torch.cos(theta_hat)).clamp(min=clamp))  # [B, 2k+1, 1]
    log_p = log_p.sum(-1) + log_vol.sum(-1)  # [B, 2k+1]
    entropy = -torch.logsumexp(log_p, -1)
    return entropy
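
# Usage sketch for so3_entropy with hypothetical values; `std` must be
# positive, and `w_eps` should be a sample drawn from the corresponding
# tangent-space Gaussian (reparameterized here as std * eps).
std = torch.full((5, 3), 0.1)
w_eps = torch.randn(5, 3) * std   # reparameterized so(3) sample
H = so3_entropy(w_eps, std, k=10)
print(H.shape)                    # torch.Size([5])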
def sample(self, sample_size):
    # psi[:, 1, :] stores the log-std; exponentiate to get a positive scale
    sigma = torch.exp(self.psi[:, 1, :])
    samples = Normal(self.psi[:, 0, :],
                     sigma).sample(torch.Size([sample_size]))
    # (sample_size, n, dim) -> (n, sample_size, dim)
    samples = samples.permute(1, 0, 2)
    return samples
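
# Standalone re-derivation of the same sampling step, assuming a
# hypothetical psi of shape (n, 2, dim) where psi[:, 0, :] holds the
# mean and psi[:, 1, :] the log-std of each Gaussian.
psi = torch.randn(7, 2, 16)
sigma = torch.exp(psi[:, 1, :])                            # (7, 16)
s = Normal(psi[:, 0, :], sigma).sample(torch.Size([32]))  # (32, 7, 16)
s = s.permute(1, 0, 2)                                     # (7, 32, 16)
print(s.shape)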