def multivariate_score_method(self, p, q, feature_list):
    # Note: if feature_list contains all features, this reduces to the joint score method.
    p_mean = p.mean(axis=0)
    p_cov = np.cov(p, rowvar=False)
    p_cov += 1e-5 * np.eye(p_cov.shape[0])  # jitter to keep the covariance positive definite
    q_mean = q.mean(axis=0)
    q_cov = np.cov(q, rowvar=False)
    q_cov += 1e-5 * np.eye(q_cov.shape[0])
    p_hat = MultivariateNormal(loc=torch.from_numpy(p_mean),
                               covariance_matrix=torch.from_numpy(p_cov))
    q_hat = MultivariateNormal(loc=torch.from_numpy(q_mean),
                               covariance_matrix=torch.from_numpy(q_cov))
    running_score = 0
    for idx in range(self.n_conditional_expectation):
        for sample in [p_hat.sample(), q_hat.sample()]:
            sample.requires_grad_(True)
            log_p_sample = p_hat.log_prob(sample)
            log_q_sample = q_hat.log_prob(sample)
            # grad returns a tensor inside a tuple, hence [0]
            p_grad = torch.autograd.grad(log_p_sample, sample)[0]
            q_grad = torch.autograd.grad(log_q_sample, sample)[0]
            p_score_vector = p_grad[feature_list].numpy()
            q_score_vector = q_grad[feature_list].numpy()
            score = np.sum((p_score_vector - q_score_vector) ** 2)
            running_score += score
    return running_score / (self.n_conditional_expectation * 2)
def Gaussian2DLoss(target_coord, prediction_tensor, beta=5, alpha=1):
    '''
    Compute the negative log likelihood of (x, y) under a bivariate Gaussian
    with parameters (mx, my, sx, sy, sxy). The PDF is

        k * (sx * sy - sxy ** 2) ** -0.5 * exp(-0.5 * (p - mu)^T * Sigma^-1 * (p - mu))

    Assumes a batch of n points:
        target_coord: (n, 2), each row is (x, y)
        prediction_tensor: (n, 5), each row is (mx, my, sx, sy, sxy)
    '''
    # get the predicted mean vector and covariance matrix of the Gaussian
    mean_vec, cov_mat = gaussian_prediction(prediction_tensor)
    # the bivariate pdf
    pdf = MultivariateNormal(mean_vec, cov_mat)
    # raw density values
    raw_probs = torch.exp(pdf.log_prob(target_coord))
    # NLL loss
    losses = -pdf.log_prob(target_coord)
    # element-wise minimum of the NLL and a surrogate controlled by (beta, alpha)
    final_loss, _ = torch.min(torch.stack((
        losses,
        (beta - raw_probs * ((beta + np.log(alpha)) / alpha)).abs()
    ), dim=-1), dim=-1)
    final_loss = torch.mean(final_loss)
    return final_loss
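# A minimal sketch of the helper assumed above: `gaussian_prediction` is defined
# elsewhere, so the exact construction below (exp-scales, tanh-squashed
# correlation) is an illustrative assumption, not the original implementation.
import torch

def gaussian_prediction_sketch(prediction_tensor):
    mx, my, sx, sy, sxy = prediction_tensor.unbind(dim=-1)
    sx, sy = torch.exp(sx), torch.exp(sy)  # positive scales
    rho = torch.tanh(sxy)                  # correlation in (-1, 1) keeps Sigma PSD
    mean_vec = torch.stack((mx, my), dim=-1)                    # (n, 2)
    cov_mat = torch.stack((
        torch.stack((sx * sx, rho * sx * sy), dim=-1),
        torch.stack((rho * sx * sy, sy * sy), dim=-1),
    ), dim=-2)                                                  # (n, 2, 2)
    return mean_vec, cov_mat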
class NegativeGaussianLoss(nn.Module):
    """ Standard Normal likelihood (negative) """

    def __init__(self, size):
        super().__init__()
        self.size = size
        self.dim = dim = int(np.prod(size))
        self.N = MultivariateNormal(torch.zeros(dim, device='cuda'),
                                    torch.eye(dim, device='cuda'))

    def forward(self, input, context=None):
        return -self.log_prob(input, context).sum(-1)

    def log_prob(self, input, context=None, sum=True):
        try:
            p = self.N.log_prob(input.view(-1, self.dim))
        except RuntimeError:
            # view() fails on non-contiguous inputs; fall back to reshape()
            p = self.N.log_prob(input.reshape(-1, self.dim))
        return p

    def sample(self, n_samples, context=None):
        x = self.N.sample((n_samples,)).view(n_samples, *self.size)
        log_px = self.log_prob(x, context)
        return x, log_px
def univariate_score_method(self, p, q):
    p_mean = p.mean(axis=0)
    p_cov = np.cov(p, rowvar=False)
    p_cov += 1e-5 * np.eye(p_cov.shape[0])  # jitter to keep the covariance positive definite
    q_mean = q.mean(axis=0)
    q_cov = np.cov(q, rowvar=False)
    q_cov += 1e-5 * np.eye(q_cov.shape[0])
    p_hat = MultivariateNormal(loc=torch.from_numpy(p_mean),
                               covariance_matrix=torch.from_numpy(p_cov))
    q_hat = MultivariateNormal(loc=torch.from_numpy(q_mean),
                               covariance_matrix=torch.from_numpy(q_cov))
    running_score = np.zeros(self.n_dim)
    for idx in range(self.n_conditional_expectation):
        for sample in [p_hat.sample(), q_hat.sample()]:
            sample.requires_grad_(True)
            log_p_sample = p_hat.log_prob(sample)
            log_q_sample = q_hat.log_prob(sample)
            # grad returns a tensor inside a tuple, hence [0]
            p_grad = torch.autograd.grad(log_p_sample, sample)[0]
            q_grad = torch.autograd.grad(log_q_sample, sample)[0]
            score = (p_grad - q_grad).data.numpy() ** 2
            running_score += score
    return running_score / (self.n_conditional_expectation * 2)
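# Both score methods above estimate a Fisher-divergence-style quantity: the
# expected squared difference between the scores grad_x log p(x) and
# grad_x log q(x), averaged over samples from both fitted Gaussians. A
# self-contained sketch with synthetic data (names here are illustrative):
import numpy as np
import torch
from torch.distributions import MultivariateNormal

def fit_gaussian(samples):
    mean = torch.from_numpy(samples.mean(axis=0))
    cov = torch.from_numpy(np.cov(samples, rowvar=False) + 1e-5 * np.eye(samples.shape[1]))
    return MultivariateNormal(mean, cov)

p_hat = fit_gaussian(np.random.randn(500, 3))
q_hat = fit_gaussian(np.random.randn(500, 3) + 0.5)
x = p_hat.sample().requires_grad_(True)
score_p = torch.autograd.grad(p_hat.log_prob(x), x)[0]
score_q = torch.autograd.grad(q_hat.log_prob(x), x)[0]
print(((score_p - score_q) ** 2).sum())  # one sample's contribution to the score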
def kl_gaussian_gaussian_mc(mu_q, logvar_q, mu_p, logvar_p, num_samples=1):
    """
    Monte Carlo estimate of the KL-divergence KL(q || p) between two diagonal Gaussians.

    :param mu_q: (FloatTensor) - shape: (batch_size x input_size) - mean of the first distributions (Normal).
    :param logvar_q: (FloatTensor) - shape: (batch_size x input_size) - log variance of the first distributions (Normal).
    :param mu_p: (FloatTensor) - shape: (batch_size x input_size) - mean of the second distributions (Normal).
    :param logvar_p: (FloatTensor) - shape: (batch_size x input_size) - log variance of the second distributions (Normal).
    :param num_samples: (int) - number of samples for the Monte Carlo estimate of the KL-divergence.
    :return: (FloatTensor) - shape: (batch_size,) - KL-divergence KL(q||p).
    """
    batch_size = mu_q.size(0)
    input_size = int(np.prod(mu_q.size()[1:]))
    # expand every parameter to (batch_size, num_samples, input_size)
    mu_q = mu_q.view(batch_size, -1).unsqueeze(1).expand(batch_size, num_samples, input_size)
    logvar_q = logvar_q.view(batch_size, -1).unsqueeze(1).expand(batch_size, num_samples, input_size)
    mu_p = mu_p.view(batch_size, -1).unsqueeze(1).expand(batch_size, num_samples, input_size)
    logvar_p = logvar_p.view(batch_size, -1).unsqueeze(1).expand(batch_size, num_samples, input_size)
    dist_q = MultivariateNormal(mu_q, torch.diag_embed(torch.exp(logvar_q)))
    dist_p = MultivariateNormal(mu_p, torch.diag_embed(torch.exp(logvar_p)))
    z = dist_q.sample()  # (batch_size, num_samples, input_size)
    # average log q(z) - log p(z) over the sample dimension
    kld = (dist_q.log_prob(z) - dist_p.log_prob(z)).sum(1) / num_samples
    return kld
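# Sanity-check sketch: for diagonal Gaussians KL(q||p) has a closed form, so the
# Monte Carlo estimate above should approach it as num_samples grows.
import torch

def kl_diag_analytic(mu_q, logvar_q, mu_p, logvar_p):
    # KL(q||p) = 0.5 * sum(logvar_p - logvar_q + (var_q + (mu_q - mu_p)^2) / var_p - 1)
    return 0.5 * (logvar_p - logvar_q
                  + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp()
                  - 1).sum(dim=1)

mu_q, logvar_q = torch.randn(4, 3), torch.randn(4, 3)
mu_p, logvar_p = torch.randn(4, 3), torch.randn(4, 3)
print(kl_diag_analytic(mu_q, logvar_q, mu_p, logvar_p))
print(kl_gaussian_gaussian_mc(mu_q, logvar_q, mu_p, logvar_p, num_samples=10000))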
class StackSimpleAffine(nn.Module):
    def __init__(self, transforms, dim=2):
        super().__init__()
        self.dim = dim
        self.transforms = nn.ModuleList(transforms)
        self.distribution = MultivariateNormal(torch.zeros(self.dim),
                                               torch.eye(self.dim))

    def log_probability(self, x):
        log_prob = torch.zeros(x.shape[0])
        for transform in reversed(self.transforms):
            x, inv_log_det_jac = transform.inverse(x)
            log_prob += inv_log_det_jac
        log_prob += self.distribution.log_prob(x)
        return log_prob

    def rsample(self, num_samples):
        x = self.distribution.sample((num_samples,))
        log_prob = self.distribution.log_prob(x)
        for transform in self.transforms:
            x, log_det_jac = transform.forward(x)
            log_prob += log_det_jac
        return x, log_prob
def get_log_prob_n_step(self, x, output_torch_var, n_step, num_sample, only_last_step_log_prob):
    # repeat the input num_sample times to approximate the next-state distribution
    # via samples (note: this is not very efficient)
    x = x.repeat(num_sample, 1, 1)
    output_torch_var = output_torch_var.repeat(num_sample, 1, 1)
    state_tensor = x[:, 0, 0:4]
    log_probs = 0
    for t in range(n_step):
        action_tensor = x[:, t, 4]
        means = self.predict_forward_model_deterministic(state_tensor, action_tensor)
        multivariate_normal = MultivariateNormal(
            means, covariance_matrix=batch_diagonal(self.std.pow(2)))
        if only_last_step_log_prob and t == n_step - 1:
            log_probs = multivariate_normal.log_prob(output_torch_var[:, t, :]).unsqueeze(1)
        else:
            log_probs += multivariate_normal.log_prob(output_torch_var[:, t, :]).unsqueeze(1)
        # now take a step with the sampled accelerations to compute the new state
        output = multivariate_normal.sample()  # sampled output for the Euler update
        # update the x states
        state_tensor[:, self.x_state_idx] += state_tensor[:, self.x_dot_state_idx] * self.tau
        # update the x_dot states
        state_tensor[:, self.x_dot_state_idx] += output * self.tau
        # map theta between -pi and pi
        state_tensor[:, 2] = torch.fmod(state_tensor[:, 2] + np.pi, np.pi * 2) - np.pi
        # clamp the velocities to their max values so we cannot get numerical errors
        state_tensor[:, 1] = torch.clamp(state_tensor[:, 1], -self.high[1], self.high[1])
        state_tensor[:, 3] = torch.clamp(state_tensor[:, 3], -self.high[3], self.high[3])
    return log_probs
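# `batch_diagonal` is defined elsewhere; assuming it builds (a batch of) diagonal
# covariance matrices from a vector of variances, torch.diag_embed is a drop-in
# sketch of the same behavior:
import torch

def batch_diagonal_sketch(v):
    # v: (..., n) -> (..., n, n) with v along the diagonal
    return torch.diag_embed(v)

print(batch_diagonal_sketch(torch.tensor([1.0, 2.0, 3.0])))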
class RealNVPCNN(nn.Module):
    def __init__(self, masks):
        super(RealNVPCNN, self).__init__()
        self.in_channels = masks[0].size(0)
        self.image_width = masks[0].size(1)
        pixels = self.in_channels * self.image_width ** 2
        self.masks = nn.ParameterList([
            nn.Parameter(torch.Tensor(mask), requires_grad=False)
            for mask in masks
        ])
        self.layers = nn.ModuleList(
            [RealNVPNodeCNN(mask, self.in_channels) for mask in self.masks])
        self.distribution = MultivariateNormal(torch.zeros(pixels),
                                               torch.eye(pixels))

    def log_probability(self, x):
        log_prob = torch.zeros(x.shape[0])
        for layer in reversed(self.layers):
            x, inv_log_det_jac = layer.inverse(x)
            log_prob += inv_log_det_jac
        log_prob += self.distribution.log_prob(x.view(x.shape[0], -1))
        return log_prob

    def rsample(self, num_samples):
        x = self.distribution.sample((num_samples,))
        log_prob = self.distribution.log_prob(x)
        x = x.view(num_samples, self.in_channels, self.image_width, self.image_width)
        for layer in self.layers:
            x, log_det_jac = layer.forward(x)
            log_prob += log_det_jac
        return x, log_prob

    def sample_each_step(self, num_samples):
        samples = []
        x = self.distribution.sample((num_samples,))
        samples.append(x.detach().numpy())
        for layer in self.layers:
            x, _ = layer.forward(x)
            samples.append(x.detach().numpy())
        return samples
def dual_improvement(
    self, eta: Union[to.Tensor, np.ndarray], param_samples: to.Tensor, w: to.Tensor
) -> Union[to.Tensor, np.ndarray]:
    """
    Compute the REPS dual function value for policy improvement.

    :param eta: Lagrangian multiplier (optimization variable of the dual)
    :param param_samples: all sampled policy parameters
    :param w: weights of the policy parameter samples
    :return: dual loss value
    """
    # The sample weights have been computed by minimizing dual_evaluation, don't track the gradient twice
    assert w.requires_grad is False

    with to.no_grad():
        distr_old = MultivariateNormal(self._policy.param_values, self._expl_strat.cov.data)

    if self.optim_mode == "scipy" and not isinstance(eta, to.Tensor):
        # We can arrive here during the 'normal' REPS routine, but also when computing the gradient (jac)
        # for the scipy optimizer. In the latter case, eta is already a tensor.
        eta = to.from_numpy(eta).to(to.get_default_dtype())

    self.wml(eta, param_samples, w)

    distr_new = MultivariateNormal(self._policy.param_values, self._expl_strat.cov.data)
    logprobs = distr_new.log_prob(param_samples)
    kl = kl_divergence(distr_new, distr_old)  # mode seeking a.k.a. exclusive KL

    if self.optim_mode == "scipy":
        loss = w.numpy() @ logprobs.numpy() + eta * (self.eps - kl.numpy())
    else:
        loss = w @ logprobs + eta * (self.eps - kl)
    return loss
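# As a reading aid, and assuming the standard episodic REPS formulation, the
# quantity assembled above is the policy-improvement Lagrangian
#
#     L(eta) = sum_i w_i * log pi_new(theta_i) + eta * (eps - KL(pi_new || pi_old)),
#
# where eps is the KL trust-region bound and the weights w_i come from
# minimizing the evaluation dual.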
def train(model, data_loader, epochs, optimizer):
    train_loss = []
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for t, posterior in data_loader:
            optimizer.zero_grad()
            posterior = posterior.permute(0, 2, 1)
            t = t.reshape(128, 1, -1)  # assumes a fixed batch size of 128
            est_posterior, log_det = model(posterior, t)
            ll = MultivariateNormal(posterior, 0.25 * torch.eye(2))
            loss = -ll.log_prob(est_posterior) + log_det
            loss = loss.mean()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        print()
        print(epoch, loss.item())
        print("Posterior: {}".format(posterior[-1]))
        print("Estimated Posterior: {}".format(est_posterior[-1]))
        print(est_posterior.shape)
        print("Train epoch: {}, train_loss: {}".format(epoch, loss.item()))
        train_loss.append(np.mean(train_losses))
    return model, train_loss
def forward(self, x):
    b, c, w, h = x.size()
    # Step 1: embedding for each local point
    x_embedded = self.embedding(x)
    # Step 2: distribution over locations
    # TODO: Learn a local point for each channel.
    multi_norm = MultivariateNormal(
        loc=self.normal_loc, scale_tril=self.normal_scal.diag_embed())
    # location map, shape [b, w, h, local_num, 2]
    location_map = self.get_location_mask(x, b, w, h, self.local_num)
    pdf = multi_norm.log_prob(location_map * self.position_scal).exp()
    # Step 3: value embedding
    x_value = x.expand(self.local_num, b, c, w, h).reshape(self.local_num * b, c, w, h)
    x_value = self.value_embed(x_value).reshape(self.local_num, b, c, w, h).permute(1, 2, 3, 4, 0)
    # Step 4: embedded value times probability density
    increment = (x_value * pdf.unsqueeze(dim=1)).mean(dim=-1)
    return x + increment
def test_joint_feature_generator(model, test_loader):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    data = model.data
    model.eval()
    _, n_features, signal_len = next(iter(test_loader))[0].shape
    test_loss = 0
    if data == 'mimic':
        tvec = [24]
    elif data == 'mimic_int':
        tvec = [5]
    else:
        num = 1
        tvec = [int(tt) for tt in np.logspace(1.0, np.log10(signal_len), num=num)]
    for i, (signals, labels) in enumerate(test_loader):
        for t in tvec:
            mean, covariance = model.likelihood_distribution(signals[:, :, :t])
            dist = MultivariateNormal(loc=mean, covariance_matrix=covariance)
            reconstruction_loss = -dist.log_prob(signals[:, :, t].to(device)).mean()
            test_loss = test_loss + reconstruction_loss.item()
    return test_loss / (i + 1)
def plot_variational_post(mean, covariance, beta, hparams, label="baseline"):
    mu = mean[0]
    cov = torch.tensor([[torch.exp(covariance[0][0]), 0.],
                        [0., torch.exp(covariance[0][1])]])
    note_taking("variational q mu at beta={} is {}".format(
        beta, mu.detach().cpu().numpy()))
    note_taking("variational q cov at beta={} is {}".format(
        beta, cov.detach().cpu().numpy()))
    m = MultivariateNormal(mu, cov)
    nbins = 300
    x = np.linspace(-2, 2, nbins)
    y = np.linspace(-2, 2, nbins)
    x_grid, y_grid = np.meshgrid(x, y)
    samples_grid = np.vstack([x_grid.flatten(), y_grid.flatten()]).transpose()
    density = torch.exp(
        m.log_prob(
            torch.from_numpy(samples_grid).to(hparams.tensor_type).to(hparams.device)))
    plt.contourf(
        x_grid, y_grid,
        density.view((nbins, nbins)).detach().cpu().numpy(), 20, cmap='jet')
    plt.colorbar()
    path = hparams.messenger.arxiv_dir + hparams.hparam_set + "_{}_var_q_b{}.pdf".format(label, beta)
    plt.savefig(path)
    plt.close()
def jump(self, xi, t):
    n = xi.size(0)
    # spatial multivariate Gaussian
    m_gauss = MultivariateNormal(self.mu_jump * torch.ones(n),
                                 self.std_jump * torch.eye(n))
    # Poisson process: probability of an arrival at time t
    exp_d = Exponential(self.lambd)
    # independent events, multiply probabilities
    p = torch.exp(m_gauss.log_prob(xi)) * (1 - torch.exp(exp_d.log_prob(self.last_jump)))
    # one sample from a Bernoulli trial
    event = Bernoulli(p).sample([1])
    if event:
        coord_before = xi
        xi = self.jump_event(xi, t)
        coord_after = xi
        # save jump coordinate info
        self.log_jump(t, coord_before, coord_after)
        self.last_jump = 0
    # if no jump, increase the counter for the Bernoulli trial
    else:
        self.last_jump += 1
    return xi
def nll_full_rank(self, target, mu, tril_elements, reduce=True):
    """Evaluate the NLL for a single Gaussian with a full-rank covariance matrix

    Parameters
    ----------
    target : torch.Tensor of shape [batch_size, Y_dim]
        Y labels
    mu : torch.Tensor of shape [batch_size, Y_dim]
        network prediction of the mu (mean parameter) of the BNN posterior
    tril_elements : torch.Tensor of shape [batch_size, Y_dim*(Y_dim + 1)//2]
        network prediction of the lower-triangular elements of the Cholesky
        factor of the precision matrix, with the diagonal in log space
    reduce : bool
        whether to take the mean across the batch

    Returns
    -------
    torch.Tensor
        NLL values: a scalar if `reduce` is True, else of shape [batch_size,]
    """
    batch_size, _ = target.shape
    tril = torch.zeros([batch_size, self.Y_dim, self.Y_dim],
                       device=self.device, dtype=None)
    tril[:, self.tril_idx[0], self.tril_idx[1]] = tril_elements
    log_diag_tril = torch.diagonal(tril, offset=0, dim1=1, dim2=2)  # [batch_size, Y_dim]
    # exponentiate the diagonal so it is strictly positive
    tril[:, torch.eye(self.Y_dim, dtype=bool)] = torch.exp(log_diag_tril)
    prec_mat = torch.bmm(tril, torch.transpose(tril, 1, 2))  # L L^T, the precision matrix
    mvn = MultivariateNormal(loc=mu, precision_matrix=prec_mat)
    loss = -mvn.log_prob(target)
    if reduce:
        return torch.mean(loss, dim=0)  # float
    else:
        return loss  # [batch_size,]
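# A standalone sketch of the parametrization used above: the network emits the
# Y_dim*(Y_dim+1)//2 lower-triangular entries with the diagonal in log space, so
# exponentiating guarantees positivity and L @ L.T is a valid precision matrix.
# Names and shapes below are illustrative.
import torch

Y_dim, batch_size = 3, 2
tril_idx = torch.tril_indices(Y_dim, Y_dim)
tril_elements = torch.randn(batch_size, Y_dim * (Y_dim + 1) // 2)
tril = torch.zeros(batch_size, Y_dim, Y_dim)
tril[:, tril_idx[0], tril_idx[1]] = tril_elements
log_diag = torch.diagonal(tril, dim1=1, dim2=2)
tril[:, torch.eye(Y_dim, dtype=torch.bool)] = torch.exp(log_diag)  # strictly positive diagonal
prec_mat = tril @ tril.transpose(1, 2)  # positive definite by construction
print(torch.linalg.eigvalsh(prec_mat).min() > 0)  # sanity check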
def plot_analytical_post(mean, covariance, beta, hparams):
    """ Plot the analytical posterior for visualization. """
    mu = mean[:, 0]
    cov = covariance
    note_taking("analytical q mu at beta={} is {}".format(
        beta, mu.detach().cpu().numpy()))
    note_taking("analytical q cov at beta={} is {}".format(
        beta, cov.detach().cpu().numpy()))
    m = MultivariateNormal(mu, cov)
    nbins = 300
    x = np.linspace(-2, 2, nbins)
    y = np.linspace(-2, 2, nbins)
    x_grid, y_grid = np.meshgrid(x, y)
    samples_grid = np.vstack([x_grid.flatten(), y_grid.flatten()]).transpose()
    density = torch.exp(
        m.log_prob(
            torch.from_numpy(samples_grid).to(hparams.tensor_type).to(hparams.device)))
    plt.contourf(x_grid, y_grid,
                 density.view((nbins, nbins)).detach().cpu().numpy(), 20, cmap='jet')
    plt.colorbar()
    path = hparams.messenger.arxiv_dir + hparams.hparam_set + "_analytical_q_b{}.pdf".format(beta)
    plt.savefig(path)
    plt.close()
def log_prior(z):
    dim = z.shape[1]
    mean = torch.zeros(dim).cuda()
    cov = torch.eye(dim).cuda()
    m = MultivariateNormal(mean, cov)
    # log_prob is differentiable w.r.t. z, so gradients flow through it directly
    return m.log_prob(z)
class ParticleSmootherSystemWrapper:
    def __init__(self, sys, R):
        self._sys = sys
        self._Rmvn = MultivariateNormal(torch.zeros(R.shape[0],), R)

    def __call__(self, x, t):
        """
        t (int): time
        x (torch.Tensor): (N, n) particles
        """
        x = x.unsqueeze(1)
        t = torch.tensor([float(t)])
        nx = self._sys.step(t, x)
        return nx.squeeze(1)

    def obsll(self, x, y):
        """
        x (torch.Tensor): (N, n) particles
        y (torch.Tensor): (1, m) observation
        """
        y_ = self._sys.observe(None, x.unsqueeze(1)).squeeze(1)
        dy = y - y_
        logprob_y = self._Rmvn.log_prob(dy).unsqueeze(1)
        return logprob_y

    @property
    def _xdim(self):
        return self._sys.xdim
class RealNVP(nn.Module):
    def __init__(self, masks, hidden_size):
        super(RealNVP, self).__init__()
        self.dim = len(masks[0])
        self.hidden_size = hidden_size
        self.masks = nn.ParameterList([
            nn.Parameter(torch.Tensor(mask), requires_grad=False)
            for mask in masks
        ])
        self.layers = nn.ModuleList(
            [RealNVPNode(mask, self.hidden_size) for mask in self.masks])
        self.distribution = MultivariateNormal(torch.zeros(self.dim),
                                               torch.eye(self.dim))

    def log_probability(self, x):
        log_prob = torch.zeros(x.shape[0])
        for layer in reversed(self.layers):
            x, inv_log_det_jac = layer.inverse(x)
            log_prob += inv_log_det_jac
        log_prob += self.distribution.log_prob(x)
        return log_prob

    def rsample(self, num_samples):
        x = self.distribution.sample((num_samples,))
        log_prob = self.distribution.log_prob(x)
        for layer in self.layers:
            x, log_det_jac = layer.forward(x)
            log_prob += log_det_jac
        return x, log_prob

    def sample_each_step(self, num_samples):
        samples = []
        x = self.distribution.sample((num_samples,))
        samples.append(x.detach().numpy())
        for layer in self.layers:
            x, _ = layer.forward(x)
            samples.append(x.detach().numpy())
        return samples
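# A minimal usage sketch for RealNVP above. The RealNVPNode internals live
# elsewhere, so the alternating binary masks and hidden size are assumptions
# chosen purely for illustration.
import torch

masks = [[0.0, 1.0], [1.0, 0.0]] * 3           # alternate which coordinate is frozen
flow = RealNVP(masks, hidden_size=32)
samples, log_prob = flow.rsample(num_samples=8)  # draw samples with their log-density
nll = -flow.log_probability(torch.randn(8, 2)).mean()  # maximum-likelihood objective
nll.backward()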
def get_W(self, p, b, bb):
    zero = p.new_tensor(0)
    sigma = self.hw / 2
    cov = stackify(((sigma ** 2, zero), (zero, sigma ** 2)))
    d = MultivariateNormal(p, covariance_matrix=cov)
    return torch.exp(d.log_prob(grid2ps(*bb_grid(bb)))).reshape(
        tuple(bb_sz(bb).long()))
def compute_log_prob_noise(self, observations, noise):
    noise_stds = self.noise_decoder(observations)
    # squaring the stds (plus a small floor) keeps the covariance positive
    m = MultivariateNormal(
        torch.zeros_like(noise_stds),
        (noise_stds ** 2 + 0.01) * torch.eye(self.action_dim))
    # the second argument is the noise whose log-probability is evaluated
    log_prob_noise = m.log_prob(noise)
    return log_prob_noise
def run(self, x):
    x = Variable(Tensor(x))
    u, logstd = self(x)
    sigma2 = torch.exp(2 * logstd) * self.output_id  # variance = exp(2 * logstd)
    d = MultivariateNormal(u, sigma2)  # might want to use N independent Gaussians instead
    action = d.sample()
    self.history_of_log_probs.append(d.log_prob(action))
    return action
def p_bump(x, npy=False):
    ndims = x.shape[-1]
    if npy:
        z = MVN(np.zeros(ndims)).pdf(x)
    else:
        dist = TMVN(torch.zeros(ndims), torch.eye(ndims))
        z = torch.exp(dist.log_prob(x.float()))
    return 100 * z
def get_log_prob(self, x, output_torch_var):
    means, log_std, std = self.forward(x)
    # calculate the difference in x_dot states from actions, scaled by tau
    multivariate_normal = MultivariateNormal(
        means, covariance_matrix=batch_diagonal(std.pow(2)))
    # NOTE: we need the unsqueeze(1) to work with the rlrl framework, else we get in
    # trouble when calculating advantage * log_prob
    log_prob = multivariate_normal.log_prob(output_torch_var).unsqueeze(1)
    return log_prob
def reparameterize(self, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    mu = torch.zeros(len(std[0])).cuda()
    # eps * std is distributed as N(0, diag(std^2)), so the covariance
    # must use the squared standard deviations
    covar_mat = torch.diag_embed(std.pow(2))
    m = MultivariateNormal(mu, covar_mat)
    log_prob_a = m.log_prob(eps.mul(std))
    return eps.mul(std), log_prob_a
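# For a diagonal Gaussian the same log-density can be computed without
# materializing a full covariance matrix: a factorized Normal gives an
# equivalent (and cheaper) formulation of the log_prob above.
import torch
from torch.distributions import Normal

logvar = torch.randn(4, 3)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
# sum of per-dimension log-densities == MultivariateNormal with diag(std^2)
log_prob = Normal(torch.zeros_like(std), std).log_prob(eps * std).sum(dim=1)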
def run(self, x):
    x = Variable(Tensor(x))
    # the action space is continuous
    u = self(x)
    sigma2 = torch.exp(self.logstd_raw) * self.outputid
    d = MultivariateNormal(u, sigma2)
    action = d.sample()
    self.history_of_log_probs.append(d.log_prob(action))
    return action
def run(self, x):
    x = Variable(x)
    u, logstd = self(x)
    sigma2 = torch.exp(2 * logstd) * self.outputid  # variance = exp(2 * logstd)
    d = MultivariateNormal(u, sigma2)  # might want to use N independent Gaussians instead
    action = d.sample()
    log_prob = d.log_prob(action)
    self.history_of_log_probs.append(log_prob)
    return action, log_prob
def run(self, x):
    x = Variable(x)
    # the action space is continuous
    u = self(x)
    sigma2 = torch.exp(self.logstd_raw) * self.outputid
    d = MultivariateNormal(u, sigma2)
    action = d.sample()
    log_prob = d.log_prob(action)
    return action, log_prob
def log_prob_forward_model(self, state_tensor, action_tensor, next_state_tensor):
    # calculate qdd from the current state and the next state
    means = self.predict_forward_model_deterministic(state_tensor, action_tensor)
    multivariate_normal = MultivariateNormal(
        means, covariance_matrix=torch.diag(self.std.pow(2)))
    log_prob = multivariate_normal.log_prob(next_state_tensor)
    return log_prob
def get_log_prob_action(self, goals, observations, actions):
    latent_and_observation = torch.cat([goals, observations], dim=2)
    action_means, action_stds = torch.split(
        self.action_decoder(latent_and_observation), self.action_dim, dim=2)
    # squaring the stds (plus a small floor) keeps the covariance positive
    m = MultivariateNormal(action_means,
                           (10 * action_stds ** 2 + 0.001) * torch.eye(self.action_dim))
    log_prob_actions = m.log_prob(actions)
    return log_prob_actions
def compute_log_prob_goals(self, observations, goals):
    goal_means, goal_stds = torch.split(self.goal_decoder(observations),
                                        self.goal_dim, dim=2)
    # squaring the stds (plus a small floor) keeps the covariance positive
    m = MultivariateNormal(goal_means,
                           (goal_stds ** 2 + 0.01) * torch.eye(self.goal_dim))
    log_prob_goals = m.log_prob(goals)
    return log_prob_goals