def get_dist(self):
    """Return an equal-weight Gaussian mixture centred on the rows of ``self.mean``.

    Each row of ``self.mean`` is one component mean. Every component uses the
    same per-axis scale ``self.var``, passed to ``Normal`` as a standard
    deviation despite the name. The two trailing coordinates are reinterpreted
    as a single event with ``Independent``.

    Returns:
        ``D.MixtureSameFamily`` with uniform mixing weights.
    """
    n = len(self.mean)
    # Create the helper tensors on the means' device/dtype. A CPU Categorical
    # combined with GPU components fails inside MixtureSameFamily.log_prob.
    factory = {"device": self.mean.device, "dtype": self.mean.dtype}
    mix = D.Categorical(torch.ones(n, **factory))
    scale = self.var * torch.ones(n, 2, **factory)
    comp = D.Independent(D.Normal(self.mean, scale), 1)
    return D.MixtureSameFamily(mix, comp)
def log_prob(self, locations_3d, x_offset_3d, y_offset_3d, z_offset_3d, intensities_3d):
    """Log-likelihood of the ground-truth emitters under the predicted point process.

    The target volumes are converted by ``get_true_labels`` into emitter
    attributes ``xyzi``, emitter ``counts`` and a mask ``s_mask``. The score
    has two parts:

    * Count term: ``P = sigmoid(self.logits) + 1e-5`` acts as per-voxel
      probabilities. The observed count is scored under a Normal whose mean
      (``sum P``) and variance (``sum P - P**2``), both summed over dims 2, 3, 4,
      match those of a sum of independent Bernoulli(P) variables.
    * Spatial term: ``img_to_coord`` turns the per-voxel predictions
      (channels of ``self.xyzi_mu`` / ``self.xyzi_sigma``, unbound along dim 1,
      plus the normalised ``P``) into mixture components. Each true ``xyzi``
      vector (x, y, z, intensity by naming) is scored under the resulting
      Gaussian mixture, and the masked scores are summed over the last axis.

    Returns:
        ``count_prob + spatial_prob``. Presumably one value per batch element;
        TODO confirm against ``get_true_labels`` / ``img_to_coord``.
    """
    xyzi, counts, s_mask = get_true_labels(locations_3d, x_offset_3d, y_offset_3d, z_offset_3d, intensities_3d)
    # Split the 4 predicted channels (dim 1) and keep a singleton axis for each.
    x_mu, y_mu, z_mu, i_mu = (i.unsqueeze(1) for i in torch.unbind(self.xyzi_mu, dim=1))
    x_si, y_si, z_si, i_si = (
        i.unsqueeze(1) for i in torch.unbind(self.xyzi_sigma, dim=1))
    # The small offset keeps P away from 0.
    # NOTE(review): where sigmoid saturates to exactly 1.0 (float32), the
    # offset pushes P above 1, so that voxel's P - P**2 becomes slightly
    # negative. A clamp may be safer.
    P = torch.sigmoid(self.logits) + 0.00001
    count_mean = P.sum(dim=[2, 3, 4]).squeeze(-1)
    # Per-voxel Bernoulli variance, summed (original note: avoid the situation
    # where we have a perfect match).
    count_var = (P - P**2).sum(dim=[2, 3, 4]).squeeze(
        -1)
    count_dist = D.Normal(count_mean, torch.sqrt(count_var))
    count_prob = count_dist.log_prob(counts)
    # NOTE(review): this normalises over dims [1, 2, 3], while the count
    # statistics sum over [2, 3, 4]. Confirm which axes are intended.
    # (D.Categorical renormalises probs over its last axis anyway, if
    # mixture_probs_l is derived from this tensor.)
    mixture_probs = P / P.sum(dim=[1, 2, 3], keepdim=True)
    xyz_mu_list, _, _, i_mu_list, x_sigma_list, y_sigma_list, z_sigma_list, i_sigma_list, mixture_probs_l = img_to_coord(
        P, x_mu, y_mu, z_mu, i_mu, x_si, y_si, z_si, i_si, mixture_probs)
    xyzi_mu = torch.cat((xyz_mu_list, i_mu_list), dim=-1)
    # Original note: "to avoid NaN". The mechanism is not visible here;
    # presumably it lives in img_to_coord.
    xyzi_sigma = torch.cat(
        (x_sigma_list, y_sigma_list, z_sigma_list, i_sigma_list), dim=-1)
    mix = D.Categorical(mixture_probs_l.squeeze(-1))
    comp = D.Independent(D.Normal(xyzi_mu, xyzi_sigma), 1)
    spatial_gmm = D.MixtureSameFamily(mix, comp)
    # Emitters are moved to the leading axis for log_prob, then moved back.
    spatial_prob = spatial_gmm.log_prob(xyzi.transpose(0, 1)).transpose(0, 1)
    # Masked entries (presumably padding) contribute nothing to the sum.
    spatial_prob = (spatial_prob * s_mask).sum(-1)
    log_prob = count_prob + spatial_prob
    return log_prob
def gaussian_mixture_sampler(num_latent, num_mixtures=4, weights=None, means=None, cov=None):
    """Build a sampler for ``num_latent`` independent 1-D Gaussian mixtures.

    :param num_latent: number of independent mixtures (the batch size of the
        mixture distribution).
    :param num_mixtures: number of components per mixture.
    :param weights: mixing probabilities of shape ``(num_latent, num_mixtures)``.
        Random softmax weights are used when omitted.
    :param means: component means of shape ``(num_latent, num_mixtures, 1)``.
        Drawn as ``2 * randn`` when omitted.
    :param cov: component *standard deviations* (used as the ``Normal``
        scale) of shape ``(num_latent, num_mixtures, 1)``. Must be strictly
        positive. Drawn as ``|randn| + 1e-3`` when omitted.
    :return: callable ``n -> tensor`` that draws ``n`` samples. The result is
        ``squeeze()``-d, so singleton dimensions (including ``n == 1``) are
        dropped.
    """
    if weights is None:
        weights = torch.randn(num_latent, num_mixtures).softmax(dim=1)
    if means is None:
        means = torch.randn(num_latent, num_mixtures, 1) * 2
    if cov is None:
        # Scales must be strictly positive. Plain randn is negative about half
        # the time, and Normal's argument validation rejects that.
        cov = torch.randn(num_latent, num_mixtures, 1).abs() + 1e-3
    mix = dist.Categorical(weights)
    comp = dist.Independent(dist.Normal(means, cov), 1)
    gmm = dist.MixtureSameFamily(mix, comp)
    return lambda n: gmm.sample((n, )).squeeze()
def _goal_likelihood(self, y: torch.Tensor, goal: torch.Tensor, **hyperparams) -> torch.Tensor:
    """Score how well the final waypoint of plan `y` reaches the goals.

    The goals form an equal-weight mixture of isotropic Gaussians, one per
    goal, each with standard deviation ``epsilon``. The last position
    ``y[:, -1, :]`` is scored under that mixture, and the per-example
    log-likelihoods are averaged over the batch.

    Args:
        y: A plan under evaluation, with shape `[B, T, 2]`.
        goal: The goal locations, with shape `[B, K, 2]`.
        hyperparams: (keyword arguments) The goal-likelihood hyperparameters.
            Only ``epsilon`` (default ``1.0``) is read.

    Returns:
        The batch-mean log-likelihood of the plan `y` under the `goal`
        distribution.
    """
    # TODO(filangel): implement other goal likelihoods from the DIM paper
    batch, num_goals, _ = goal.shape
    eps = hyperparams.get("epsilon", 1.0)

    uniform_weights = D.Categorical(
        probs=torch.ones((batch, num_goals)).to(goal.device))
    per_goal = D.Independent(
        D.Normal(loc=goal, scale=torch.ones_like(goal) * eps),
        reinterpreted_batch_ndims=1,
    )
    mixture = D.MixtureSameFamily(
        mixture_distribution=uniform_weights,
        component_distribution=per_goal,
    )
    final_position = y[:, -1, :]
    return mixture.log_prob(final_position).mean(dim=0)
def decoder(self, z, encoded_history, current_state, y_e=None, train=False):
    """Autoregressively roll out 12 steps, returning one Gaussian mixture per step.

    Each step runs ``self.gru`` on the history encoding plus an action
    embedding and projects the hidden state with
    ``self.project_to_GMM_params``. The result is a mixture of diagonal 2-D
    Gaussians whose means accumulate the clamped deltas onto the previous
    means. A sample from each step's mixture is embedded with ``self.action``
    and fed to the next step.

    Args:
        z: mode selector, converted with ``to_one_hot(z, n_dims=self.num_modes)``.
        encoded_history: history encoding; reshaped to ``(bs, -1)``.
        current_state: last observed state, reshaped to ``(bs, -1)`` for
            ``self.action``. Presumably ``(bs, 2)``; TODO confirm.
        y_e: unused; kept for interface compatibility.
        train: unused (both original branches were identical); kept for
            interface compatibility.

    Returns:
        list of 12 ``D.MixtureSameFamily`` objects.
    """
    bs = encoded_history.shape[0]
    device = encoded_history.device
    history = encoded_history.reshape(bs, -1)
    # NOTE(review): F.dropout is called without `training=`, so (per its
    # default training=True) dropout is also active in eval mode. Pass
    # training=self.training if that is not intended.
    a_0 = F.dropout(self.action(current_state.reshape(bs, -1)), self.dropout_p)
    state = F.dropout(self.state(history), self.dropout_p)
    current_state = current_state.unsqueeze(1)

    # The mode weights depend only on z, so compute them once. They are moved
    # to the data device rather than a hard-coded `.cuda()`, so the decoder
    # also runs on CPU and on non-default GPUs.
    # NOTE(review): a one-hot vector used as logits gives the chosen mode
    # weight e / (e + num_modes - 1), not 1. Confirm this is intended.
    log_pis = to_one_hot(z, n_dims=self.num_modes).to(device)
    log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)

    gauses = []
    inp = F.dropout(torch.cat((history, a_0), dim=-1), self.dropout_p)
    for _ in range(12):
        h_state = self.gru(inp.reshape(bs, -1), state)
        # Only the deltas and log-sigmas are used here.
        _, deltas, log_sigmas, _ = self.project_to_GMM_params(h_state)
        deltas = torch.clamp(deltas, max=1.5, min=-1.5).reshape(bs, -1, 2)
        log_sigmas = log_sigmas.reshape(bs, -1, 2)
        mus = deltas + current_state
        current_state = mus
        variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2, max=1e3)
        m_diag = variance * torch.eye(2, device=variance.device)
        mix = D.Categorical(logits=log_pis)
        comp = D.MultivariateNormal(mus, m_diag)
        gmm = D.MixtureSameFamily(mix, comp)
        gauses.append(gmm)
        a_t = gmm.sample()  # possible grad problems?
        a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
        state = h_state
        inp = F.dropout(torch.cat((history, a_tt), dim=-1), self.dropout_p)
    return gauses
def forward(self, output_sizes, hold_seed=None, hold_initial_set=False):
    """
    Sample an initial set from the prior.

    :param output_sizes: Tensor([B,]) passed to the mask helpers.
    :param hold_seed: if not None, the *global* torch RNG is re-seeded with
        it (``torch.random.manual_seed``). ``eps`` is then drawn once and
        repeated over the batch. For a trainable GMM the RNG is re-seeded
        again before the component assignment, which is also shared across
        the batch.
    :param hold_initial_set: if True the mask comes from ``get_mask``,
        otherwise from ``sample_mask``.
    :return: (Tensor([B, N, D]) produced by ``self.output``, mask ``[B, N]``)
        with ``N = self.max_outputs``.
    """
    bsize = output_sizes.shape[0]
    if hold_initial_set:  # [B, N]
        x_mask = get_mask(output_sizes, self.max_outputs)
    else:
        x_mask = sample_mask(output_sizes, self.max_outputs)

    if hold_seed is not None:  # [B, N, Ds]
        torch.random.manual_seed(hold_seed)
        eps = torch.randn([1, self.max_outputs, self.dim_seed]).to(x_mask.device).repeat(bsize, 1, 1)
    else:
        eps = torch.randn([bsize, self.max_outputs, self.dim_seed]).to(x_mask.device)

    if self.n_mixtures == 1:
        # Single Gaussian: reparameterised sample with std = exp(logvar / 2).
        x = self.mu + torch.exp(self.logvar / 2.) * eps
    else:
        if self.train_gmm:
            # Differentiable component choice via hard Gumbel-softmax.
            if hold_seed is not None:
                torch.random.manual_seed(hold_seed)
                logits = self.logits.reshape([1, 1, self.n_mixtures]).repeat(1, self.max_outputs, 1)  # [1, N, M]
                onehot = F.gumbel_softmax(logits, tau=self.tau, hard=True).repeat(bsize, 1, 1).unsqueeze(-1)  # [B, N, M, 1]
            else:
                logits = self.logits.reshape([1, 1, self.n_mixtures]).repeat(bsize, self.max_outputs, 1)  # [B, N, M]
                onehot = F.gumbel_softmax(logits, tau=self.tau, hard=True).unsqueeze(-1)  # [B, N, M, 1]
            mu = self.mu.reshape([1, 1, self.n_mixtures, self.dim_seed])  # [1, 1, M, D]
            sig = self.sig.reshape([1, 1, self.n_mixtures, self.dim_seed])  # [1, 1, M, D]
            mu = (mu * onehot).sum(2)  # [B, N, D]
            sig = (sig * onehot).sum(2)  # [B, N, D]
            x = mu + sig * eps
        else:
            # self.logits are mixture *logits* (they feed gumbel_softmax above),
            # so pass them as such. Categorical's first positional argument is
            # `probs`. eps is not used in this branch.
            mix = D.Categorical(logits=self.logits)
            comp = D.Independent(D.Normal(self.mu, self.sig.abs()), 1)
            mixture = D.MixtureSameFamily(mix, comp)
            x = mixture.sample((output_sizes.size(0), self.max_outputs))
    x = self.output(x)  # [B, N, D]
    return x, x_mask
def __init__(self, is_test):
    """Set up a toy data-generating process.

    ``pX`` is an equal-weight mixture of ``U(0, 0.45)`` and ``U(0.35, 1.0)``,
    ``pY1`` is ``U(0, 1)``, and ``pY2(X)`` is ``N(sin(10 * X), 0.05)``.

    Args:
        is_test: currently unused (a uniform test-time ``pX`` used to be
            selected here); kept for interface compatibility.
    """
    super().__init__()
    mix = D.Categorical(torch.ones(2, ))
    # Argument validation is disabled so that pX.log_prob works on the whole
    # mixture support. With validation on, each Uniform component raises for
    # values outside its own interval (e.g. 0.9 under U(0, 0.45)) instead of
    # returning -inf, which the mixture's logsumexp would absorb.
    comp = D.Uniform(torch.tensor([0.0, 0.35]), torch.tensor([0.45, 1.0]), validate_args=False)
    self.pX = D.MixtureSameFamily(mix, comp, validate_args=False)
    self.pY1 = D.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
    self.pY2 = lambda X: D.Normal(torch.sin(10 * X), 0.05)
def _parameterize_distribution(self, hidden: torch.Tensor) -> D.Distribution:
    """Map ``hidden`` to a mixture of univariate Gaussians.

    ``self.mixture_linear`` yields the component logits,
    ``self.means_linear`` the component means, and ``self.stddev_linear`` the
    pre-softplus standard deviations. A trailing singleton axis is added to
    means and stddevs, so each component is a 1-D ``Independent`` Normal
    (event shape ``(1,)``).
    """
    # Feed the raw logits to Categorical. It normalises them with a
    # log-softmax, which gives the same distribution as probs=softmax(logits)
    # but avoids the softmax -> clamped log round trip that loses precision
    # for very unlikely components.
    c = D.Categorical(logits=self.mixture_linear(hidden))
    means = self.means_linear(hidden)[..., None]
    stddev = F.softplus(self.stddev_linear(hidden))[..., None]
    n = D.Independent(D.Normal(means, stddev), 1)
    return D.MixtureSameFamily(c, n)
def __init__(self, pX=None):
    """Store ``pX`` as the marginal of X, or build a random 2-D Gaussian mixture.

    Args:
        pX: optional distribution stored as ``self.pX``. When omitted, an
            equal-weight mixture of three diagonal 2-D Gaussians is used, with
            means drawn from ``N(0, 1)`` and a per-axis scale of 0.3.
    """
    super().__init__()
    if pX is not None:
        self.pX = pX
        return
    centres = torch.randn(3, 2)
    spread = 0.3 * torch.ones(3, 2)
    weights = D.Categorical(torch.ones(3, ))
    gaussians = D.Independent(D.Normal(centres, spread), 1)
    self.pX = D.MixtureSameFamily(weights, gaussians)
def __call__(self, scene: torch.Tensor, train=False):
    """Merge the predictions of every model in ``self.models`` into one mixture per step.

    Each model is called on ``scene`` and must return an indexable sequence of
    at least 12 ``D.MixtureSameFamily`` objects with ``MultivariateNormal``
    components. For each of the 12 steps, the mixture log-weights, means and
    covariances of all models are concatenated along dim 1 (the component
    axis). Each model's log-weights are already normalised, so the models
    contribute equally and keep their internal relative weighting.

    Args:
        scene: input forwarded unchanged to every model.
        train: unused; kept for interface compatibility.

    Returns:
        list of 12 ``D.MixtureSameFamily`` objects.
    """
    gmms = [model(scene) for model in self.models]
    combined_gmm = []
    for timestamp in range(12):
        step = [g[timestamp] for g in gmms]
        # These are (negative) log-weights, so they must be passed as
        # `logits`. Categorical's first positional argument is `probs`, which
        # would invert the weighting.
        logits = torch.cat([g.mixture_distribution.logits for g in step], dim=1)
        mus = torch.cat([g.component_distribution.mean for g in step], dim=1)
        # Keep the full covariance so correlated components survive the
        # merge. For diagonal components this equals variance * eye.
        covs = torch.cat([g.component_distribution.covariance_matrix for g in step], dim=1)
        mix = D.Categorical(logits=logits)
        comp = D.MultivariateNormal(mus, covs)
        combined_gmm.append(D.MixtureSameFamily(mix, comp))
    return combined_gmm
def dist(
    self,
    batch: dict[str, Union[torch.Tensor, list[torch.Tensor]]],
) -> distributions.Distribution:
    """Return the distribution of returns as a mixture of log-normals.

    The model (``self``) maps ``batch`` to mixture logits and to the
    ``mean`` / ``std`` arguments of each ``LogNormal`` component.

    Raises:
        GradientsError: if ``Categorical`` rejects the logits (e.g. they
            contain NaN after a bad gradient update).
    """
    logits, mean, std = self(batch)
    try:
        weights_dist = distributions.Categorical(logits=logits)
    except ValueError as err:
        # Chain the validation error explicitly so its details stay attached.
        raise GradientsError(
            f"Ошибка при обновлении градиентов: NaN in Categorical distribution"
        ) from err
    comp_dist = distributions.LogNormal(mean, std)
    return distributions.MixtureSameFamily(weights_dist, comp_dist)
def __call__(self):
    """Materialise this mixture as a backend-native distribution object.

    Returns ``torch.distributions.MixtureSameFamily`` when the active backend
    is ``"pytorch"`` and a TFP ``MixtureSameFamily`` otherwise. Whichever of
    ``logits`` / ``probs`` is set is broadcast to the batch shape of the
    component distributions. If ``self.distributions`` is still a probflow
    ``BaseDistribution``, it is converted first and the result is stored back
    on ``self``.
    """
    use_torch = get_backend() == "pytorch"
    if use_torch:
        import torch
        import torch.distributions as tod
    else:
        import tensorflow as tf
        from tensorflow_probability import distributions as tfd

    # Convert probflow distributions into backend distributions.
    if isinstance(self.distributions, BaseDistribution):
        self.distributions = self.distributions()

    # Broadcast the mixture weights to the components' batch shape.
    shape = self.distributions.batch_shape
    which = "logits" if self.logits is not None else "probs"
    cat_args = {"logits": None, "probs": None}
    if use_torch:
        cat_args[which], _ = torch.broadcast_tensors(self[which], torch.zeros(shape))
        return tod.MixtureSameFamily(tod.Categorical(**cat_args), self.distributions)
    cat_args[which] = tf.broadcast_to(self[which], shape)
    return tfd.MixtureSameFamily(tfd.Categorical(**cat_args), self.distributions)
def SampleGMM_detach(nsamps):
    """Draw ``nsamps`` samples from the module-level equal-weight prior GMM.

    The mixture uses the module globals ``mu_pr1`` (component means, one row
    per component) and ``var_pr1`` (component covariance matrices). Both are
    detached, so no gradient flows back into them.

    Args:
        nsamps: number of samples to draw.

    Returns:
        Tensor of shape ``(nsamps, Z)``, where ``Z = mu_pr1.shape[1]``, on
        ``mu_pr1``'s device.
    """
    K = mu_pr1.shape[0]
    # Uniform mixing weights on the means' device. A CPU Categorical combined
    # with GPU components fails in MixtureSameFamily.sample, because the
    # chosen component indices are gathered across devices.
    alpha_pr = torch.full((K,), 1.0 / K, device=mu_pr1.device)
    mix = distributions.Categorical(alpha_pr)
    comp = distributions.MultivariateNormal(mu_pr1.detach(), var_pr1.detach())
    gmm = distributions.MixtureSameFamily(mix, comp)
    return gmm.sample((nsamps, ))
def predict():
    """Roll the MDRNN one step ahead over the test set and report errors.

    For every test batch, states are encoded with ``to_latent``. The MDRNN
    predicts a Gaussian mixture over the next latent, one latent is sampled,
    and it is decoded with ``vae.decoder``. Decoded predictions are grouped
    per episode and pickled to ``cfg.logdir + '/' + cfg.resname + '.pkl'``.
    The mean / min / max squared error against the true next states is
    printed.
    """
    mdrnn.eval()
    preds = []
    gt = []
    n_episodes = test_dataset[-1][-2] + 1
    predictions = [[] for _ in range(n_episodes)]
    with torch.no_grad():
        for states, actions, next_states, _, episode, _ in test_loader:
            states = states.to(device)
            next_states = next_states.to(device)
            actions = actions.to(device)
            latent_obs, _ = to_latent(states, next_states, batch_size=1, sequence_horizon=1)
            # Check model's next state predictions
            mus, sigmas, logpi, _, _, _ = mdrnn(actions, latent_obs)
            # `logpi` holds log mixture weights, so it must be passed as
            # `logits`. Categorical's first positional argument is `probs`.
            mix = D.Categorical(logits=logpi)
            comp = D.Independent(D.Normal(mus, sigmas), 1)
            gmm = D.MixtureSameFamily(mix, comp)
            sample = gmm.sample()
            decoded_states = vae.decoder(sample).squeeze(0)
            decoded_states = decoded_states.cpu().detach().numpy()
            preds.append(decoded_states)
            for i in range(len(states)):
                predictions[episode[i].int()].append(np.expand_dims(decoded_states[i], axis=0))
            gt.append(next_states.cpu().detach().numpy())
    predictions = [np.stack(p) for p in predictions]
    preds = np.asarray(preds)
    gt = np.asarray(gt).squeeze(1)
    error = (preds - gt)**2
    path = cfg.logdir + '/' + cfg.resname + '.pkl'
    # Use a context manager so the file is flushed and closed deterministically.
    with open(path, 'wb') as f:
        pickle.dump(predictions, f)
    print("Mean Error: {}".format(error.mean(0)[0]))
    print("Min Error: {}".format(error.min(0)[0]))
    print("Max Error: {}".format(error.max(0)[0]))
def sample(self, n) -> Iterable[Tuple[Individual, t.Tensor]]:
    """Draw ``n`` individuals together with the log-probability of each draw.

    For every key of ``self.mixing_logits``, a mixture is built whose
    components are ``Normal(self.component_means, self.std)``, expanded to the
    key's mixture batch shape. All axes after the first are treated as event
    axes. One value per key is sampled without gradient, and its log-prob
    (with gradient) is added to the draw's total.

    Returns:
        list of ``(self.constructor(params), log_p)`` pairs.
    """
    base = d.Normal(loc=self.component_means, scale=self.std)

    def _draw_one():
        chosen = {}
        total_log_p = 0.0
        for key, key_logits in self.mixing_logits.items():
            weights = d.Categorical(logits=key_logits)
            broadcast = base.expand(weights.batch_shape + base.batch_shape)
            # Distributions are rebuilt per draw, so each draw's log_p has
            # its own autograd graph.
            mixture = d.MixtureSameFamily(
                weights, d.Independent(broadcast, self.component_means.ndim - 1))
            with t.no_grad():
                draw = mixture.sample()
            chosen[key] = draw
            total_log_p += mixture.log_prob(draw).sum()
        return self.constructor(chosen), total_log_p

    return [_draw_one() for _ in range(n)]
def sample_gmm(mix_logp, means, scales, corrs):
    """Draw one sample from a Gaussian mixture with correlated components.

    Args:
        mix_logp: mixture log-weights; the last axis indexes components.
        means: component means.
        scales, corrs: per-component scales and correlations, turned into
            covariance matrices by ``compute_cov2d`` (presumably 2-D; confirm).

    Returns:
        One sample with the mixture's batch + event shape.
    """
    covs = compute_cov2d(scales, corrs)
    # Pass the log-weights as logits instead of exponentiating them. exp()
    # can underflow to all zeros for very negative log-weights, which makes
    # probs-based Categorical invalid. Logits are normalised stably with a
    # log-softmax and describe the same distribution.
    mix = D.Categorical(logits=mix_logp)
    comp = D.MultivariateNormal(means, covs)
    gmm = D.MixtureSameFamily(mix, comp)
    return gmm.sample()
def decoder(self, z, encoded_history, current_state, train=False):
    """Autoregressively roll out 12 steps, returning one Gaussian mixture per step.

    Each step runs ``self.gru`` on the history encoding, an action embedding
    and a zeroed mode one-hot. The hidden state is projected with
    ``self.project_to_gmm_params`` into clamped deltas and log-sigmas, which
    define a mixture of diagonal 2-D Gaussians. A sample from each step is
    embedded with ``self.action`` and fed to the next step.

    Args:
        z: mode selector, converted with ``to_one_hot(z, n_dims=self.num_modes)``.
        encoded_history: history encoding; reshaped to ``(bs, -1)``.
        current_state: last observed state. Presumably ``(bs, 2)``; TODO confirm.
        train: unused (both original branches were identical); kept for
            interface compatibility.

    Returns:
        list of 12 ``D.MixtureSameFamily`` objects.
    """
    bs = encoded_history.shape[0]
    device = encoded_history.device
    history = encoded_history.reshape(bs, -1)
    # NOTE(review): F.dropout is called without `training=`, so (per its
    # default training=True) dropout is also active in eval mode. Pass
    # training=self.training if that is not intended.
    a_0 = F.dropout(self.action(current_state.reshape(bs, -1)), self.dropout_p)
    state = F.dropout(self.state(history), self.dropout_p)
    current_state = current_state.unsqueeze(1)
    gauses = []
    lp = to_one_hot(z, n_dims=self.num_modes).to(device)
    # The one-hot is multiplied by 0 before reaching the GRU, so the recurrent
    # input does not depend on z; only the input width is preserved.
    inp = F.dropout(torch.cat((history, a_0, 0 * lp), dim=-1), self.dropout_p)

    # Mode weights depend only on z, so compute them once.
    # NOTE(review): a one-hot vector used as logits gives the chosen mode
    # weight e / (e + num_modes - 1), not 1. Confirm this is intended.
    log_pis = to_one_hot(z, n_dims=self.num_modes).to(device)
    log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)

    for _ in range(12):
        h_state = self.gru(inp.reshape(bs, -1), state)
        # Only the deltas and log-sigmas are used here.
        _, deltas, log_sigmas, _ = self.project_to_gmm_params(h_state)
        deltas = torch.clamp(deltas, max=1.5, min=-1.5).reshape(bs, -1, 2)
        log_sigmas = log_sigmas.reshape(bs, -1, 2)
        mus = deltas + current_state
        current_state = mus
        variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2, max=1e3, min=1e-3)
        # The covariance is diagonal, so its Cholesky factor is exactly the
        # square root on the diagonal. This replaces the CPU round trip through
        # the deprecated torch.cholesky, and it stays on the device of the
        # means (z may live elsewhere).
        scale_tril = torch.sqrt(variance) * torch.eye(2, device=variance.device)
        mix = D.Categorical(logits=log_pis)
        comp = D.MultivariateNormal(mus, scale_tril=scale_tril)
        gmm = D.MixtureSameFamily(mix, comp)
        gauses.append(gmm)
        a_t = gmm.sample()  # TODO possible grad problems?
        a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
        state = h_state
        step_input = torch.cat((history, a_tt, 0 * lp), dim=-1)
        inp = F.dropout(step_input, self.dropout_p)
    return gauses
def create_gmm(self, log_w_t, mu_t, log_sig_t):
    """Build a diagonal Gaussian mixture from log-weights, means and log-stds.

    Args:
        log_w_t: mixture logits, ``Batchsize x K``.
        mu_t: component means, ``Batchsize x K x Da``.
        log_sig_t: component log standard deviations, same shape as ``mu_t``.

    Returns:
        ``D.MixtureSameFamily`` whose components are ``Da``-dimensional
        diagonal Gaussians (pattern from
        https://github.com/pytorch/pytorch/pull/22742).
    """
    weights = D.Categorical(logits=log_w_t)
    sigma = log_sig_t.exp()
    components = D.Independent(D.Normal(mu_t, sigma), reinterpreted_batch_ndims=1)
    return D.MixtureSameFamily(mixture_distribution=weights, component_distribution=components)
def forward(self, scene: torch.Tensor):
    """
    Predict per-agent position distributions for the next 12 timesteps.

    :param scene: tensor of shape num_peds, history_size, data_dim. Columns
        0:2, 2:4 and 4:6 are sliced as positions, velocities and
        accelerations.
    :return: list of 12 ``D.MixtureSameFamily`` objects (one per future step)
        over 2-D positions for every agent.
    """
    bs = scene.shape[0]
    device = scene.device
    poses = scene[:, :, :2]
    vel = scene[:, :, 2:4]
    acc = scene[:, :, 4:6]
    # Encoder outputs: num_peds, timestamps, 2*hidden_dim
    lstm_out_poses, _ = self.node_hist_encoder_poses(poses)
    lstm_out_acc, _ = self.node_hist_encoder_acc(acc)
    lstm_out_vel, _ = self.node_hist_encoder_vel(vel)
    lstm_out = lstm_out_vel + lstm_out_acc
    current_state = poses[:, -1, :]

    # Pairwise displacements between all agents at every history step,
    # flattened per agent and truncated / zero-padded to a fixed width.
    bs, seq, data_dim = poses.shape
    stacked = poses.permute(1, 0, 2).reshape(seq, -1).repeat(1, bs).reshape(
        seq, bs, bs * data_dim)
    deltas = (stacked - poses.permute(1, 0, 2).repeat(1, 1, bs))
    deltas = deltas.permute(1, 0, 2).reshape(bs, seq, bs, data_dim)
    deltas_flat = deltas.reshape(deltas.shape[0], deltas.shape[1], -1)
    max_size = 50  # TODO: fix
    if deltas_flat.shape[2] >= max_size:
        prep_for_deltas = deltas_flat[:, :, :max_size]
    else:
        # Allocate on the input's device instead of a hard-coded `.cuda()`.
        prep_for_deltas = torch.zeros(bs, seq, max_size, device=device)
        prep_for_deltas[:, :, :deltas_flat.shape[2]] = deltas_flat

    # NOTE(review): assumes a history of at least 8 steps.
    at_hidden = self.att.init_hidden(bs=bs)
    for i in range(8):
        at_output, at_hidden, _ = self.att(
            at_hidden, lstm_out_poses[:, i:i + 1, :], prep_for_deltas[:, i:i + 1, :])

    catted = torch.cat((lstm_out[:, -1:, :], at_output[:, -1:, :]), dim=2)
    context = catted.reshape(bs, -1)
    # NOTE(review): F.dropout is called without `training=`, so (per its
    # default training=True) dropout is also active in eval mode. Confirm.
    a_0 = F.dropout(self.action(current_state.reshape(bs, -1)), self.dropout_p)
    state = F.dropout(self.state(context), self.dropout_p)
    current_state = current_state.unsqueeze(1)
    gauses = []
    inp = F.dropout(torch.cat((context, a_0), dim=-1), self.dropout_p)
    for _ in range(12):
        h_state = self.gru(inp.reshape(bs, -1), state)
        log_pis, deltas, log_sigmas, _ = self.project_to_GMM_params(h_state)
        deltas = torch.clamp(deltas, max=1.5, min=-1.5)
        log_pis = log_pis.reshape(bs, -1)
        log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
        deltas = deltas.reshape(bs, -1, 2)
        log_sigmas = log_sigmas.reshape(bs, -1, 2)
        mus = deltas + current_state
        current_state = mus
        variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2, max=1e3)
        m_diag = variance * torch.eye(2, device=variance.device)
        # log_pis are normalised log-weights and must be passed as `logits`.
        # Categorical's first positional argument is `probs`, which would
        # turn them into inverted mixture weights.
        mix = D.Categorical(logits=log_pis)
        comp = D.MultivariateNormal(mus, m_diag)
        gmm = D.MixtureSameFamily(mix, comp)
        gauses.append(gmm)
        a_t = gmm.sample()  # possible grad problems?
        state = h_state
        # NOTE(review): unlike the first step (which uses self.action(...)),
        # the raw sample a_t is fed back here. The GRU input width only
        # matches if self.action outputs 2 features; confirm.
        inp = F.dropout(torch.cat((context, a_t), dim=-1), self.dropout_p)
    return gauses
def position_log_prob(self, x):
    """Log-density of ``x`` under the mixture restricted to its first two coordinates.

    Each component's mean is cut to ``[..., :2]`` and its Cholesky factor to
    the leading ``2 x 2`` block. For a lower-triangular factor, that block is
    the Cholesky factor of the marginal covariance, so this is the exact
    positional marginal. The original mixing weights are kept.
    """
    comps = self.component_distribution
    pos_loc = comps.mean[..., :2]
    pos_tril = comps.scale_tril[..., :2, :2]
    marginal = td.MixtureSameFamily(
        self.mixture_distribution,
        td.MultivariateNormal(loc=pos_loc, scale_tril=pos_tril),
    )
    return marginal.log_prob(x)
def forward(self, scene: torch.Tensor):
    """
    Predict per-agent position distributions for the next 12 timesteps.

    :param scene: tensor of shape num_peds, history_size, data_dim. Columns
        0:2, 2:4 and 4:6 are sliced as positions, velocities and
        accelerations.
    :return: list of 12 ``D.MixtureSameFamily`` objects (one per future step)
        over 2-D positions for every agent.
    """
    bs = scene.shape[0]
    poses = scene[:, :, :2]
    vel = scene[:, :, 2:4]
    acc = scene[:, :, 4:6]
    # Encoder outputs: num_peds, timestamps, 2*hidden_dim
    lstm_out_acc, _ = self.node_hist_encoder_acc(acc)
    lstm_out_vel, _ = self.node_hist_encoder_vel(vel)
    lstm_out_poses, _ = self.node_hist_encoder_poses(poses)
    lstm_out = lstm_out_vel + lstm_out_poses + lstm_out_acc
    current_pose = scene[:, -1, :2]  # num_people, data_dim
    current_state = poses[:, -1, :]
    # `n_peds` rather than `np`, so the local does not shadow numpy.
    n_peds, data_dim = current_pose.shape
    stacked = current_pose.flatten().repeat(n_peds).reshape(n_peds, n_peds * data_dim)
    # deltas[i, j] = pose_j - pose_i at the last observed step.
    deltas = (stacked - current_pose.repeat(1, n_peds)).reshape(
        n_peds, n_peds, data_dim)
    distruction, _ = self.edge_encoder(deltas)
    catted = torch.cat((lstm_out[:, -1:, :], distruction[:, -1:, :]), dim=1)
    context = catted.reshape(bs, -1)
    # NOTE(review): F.dropout is called without `training=`, so (per its
    # default training=True) dropout is also active in eval mode. Confirm.
    a_0 = F.dropout(self.action(current_state.reshape(bs, -1)), self.dropout_p)
    state = F.dropout(self.state(context), self.dropout_p)
    current_state = current_state.unsqueeze(1)
    gauses = []
    inp = F.dropout(torch.cat((context, a_0), dim=-1), self.dropout_p)
    for _ in range(12):
        h_state = self.gru(inp.reshape(bs, -1), state)
        log_pis, deltas, log_sigmas, _ = self.project_to_GMM_params(h_state)
        deltas = torch.clamp(deltas, max=1.5, min=-1.5)
        log_pis = log_pis.reshape(bs, -1)
        log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
        deltas = deltas.reshape(bs, -1, 2)
        log_sigmas = log_sigmas.reshape(bs, -1, 2)
        mus = deltas + current_state
        current_state = mus
        variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2, max=1e3)
        m_diag = variance * torch.eye(2, device=variance.device)
        # log_pis are normalised log-weights and must be passed as `logits`.
        # Categorical's first positional argument is `probs`, which would
        # turn them into inverted mixture weights.
        mix = D.Categorical(logits=log_pis)
        comp = D.MultivariateNormal(mus, m_diag)
        gmm = D.MixtureSameFamily(mix, comp)
        gauses.append(gmm)
        a_t = gmm.sample()  # possible grad problems?
        a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
        state = h_state
        inp = F.dropout(torch.cat((context, a_tt), dim=-1), self.dropout_p)
    return gauses
y = torch.pow(torch.abs(x), 1) y.mean().backward() print(x.grad) main() def log_prob(value, loc=0, var=1, p=0.7 ): # loc = torch.Tensor(loc) # var = torch.Tensor(var) # return -torch.log(2 * var) - torch.abs(var - loc)**p / var return -(torch.abs(value - loc) ** p) / var - math.log(2 * var) value = torch.linspace(-100,100,2000) x = torch.distributions.Laplace(0,1) y = torch.distributions.Normal(0,1) mix = dist.Categorical(torch.ones(5,)) comp = dist.Normal(0, torch.rand(5,)) # z = torch.distributions.MixtureSameFamily gmm = dist.MixtureSameFamily(mix, comp) plt.plot(value, torch.exp(x.log_prob(value)),color='blue') plt.plot(value, torch.exp(y.log_prob(value)),color='red') plt.plot(value, torch.exp(gmm.log_prob(value)),color='green') plt.plot(value, torch.exp(log_prob(value)),color='yellow') # plt.ylim(10**0, -10**2) plt.yscale('log') # plt.fill_between(value, torch.exp(x.log_prob(value))) plt.show() # main()
nclus, P, device=device) + args.clus_sep * torch.randn(nclus, P, device=device) var_pr1 = torch.zeros(nclus, P, P, device=device) #store centers of gaussian inputs to prior-encoder clscenlog = 'ClusterCenters_CELEBA_EPSWAE.txt' for i in range(nclus): var_pr1[i] = torch.eye(P, device=device).detach() np.savetxt(clscenlog, mu_pr1.detach().numpy()) alpha_pr = torch.zeros(nclus, device=device) for k in range(nclus): alpha_pr[k] = 1.0 / nclus mix = distributions.Categorical(alpha_pr) comp = distributions.MultivariateNormal(mu_pr1, var_pr1) gmm = distributions.MixtureSameFamily(mix, comp) for epoch in range(1, args.epochs + 1): recon_batch, data = train(epoch) if (epoch > 0 and epoch % args.save_interval == 0): torch.save(model.state_dict(), 'models/CELEBA_EPSWAE_AE_e' + str(epoch)) torch.save(prior.state_dict(), 'models/CELEBA_EPSWAE_PE_e' + str(epoch)) #generate sample plots sample = prior.GenNsamples_NormPrior(64) sample = model.decode(sample).cpu() save_image(sample.view(64, 3, leng, leng), 'results/CelebA_EPSWAE_Sample_e_' + str(epoch) + '.png')
def kde_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
) -> DataLoader:
    """Build a DataLoader of synthetic points generated from per-batch density fits.

    For every training batch of ``dm``:

    1. A spherical ``BayesianGaussianMixture`` with ``batch_size`` components
       is fitted (fewer for a smaller last batch).
    2. Components with weight < 1e-5 are dropped, and the rest form a
       ``MixtureSameFamily``.
    3. Starting points are sampled from it and refined with
       ``density_gradient_descent``.

    Args:
        dm: data module whose ``train_dataloader()`` yields ``(x, _)`` pairs.
        batch_size: mixture size per fit and batch size of the returned loader.
        N_hat_multiplier: scales the number of generated points per batch. It
            may be fractional; the sample count is rounded up.

    Returns:
        Shuffled ``DataLoader`` over a ``TensorDataset`` of all generated points.
    """
    import math

    gd_n_steps, gd_lr, gd_threshold = 5, 4e-1, 0.005
    # Spherical = each component has single variance.
    bgm = BayesianGaussianMixture(
        n_components=batch_size,
        covariance_type="spherical",
        warm_start=True,
    )
    x_hat = torch.Tensor()
    for x, _ in dm.train_dataloader():
        device = x.device
        x = x.detach().cpu().numpy()
        # Last batch might have less elements than origin n_components
        if x.shape[0] < bgm.n_components:
            bgm = BayesianGaussianMixture(
                n_components=x.shape[0],
                covariance_type="spherical",
            )
        # Estimate KDE
        bgm.fit(x)
        # [N_components, 1], [N_components, N_features], [N_components, 1]
        weights = torch.Tensor(bgm.weights_).to(device)
        means = torch.Tensor(bgm.means_).to(device)
        variances = torch.Tensor(bgm.covariances_).to(device)
        keep = weights >= 1e-5
        weights, means, variances = weights[keep], means[keep], variances[keep][:, None]
        n_selected_components = weights.shape[0]
        p_x = D.Independent(D.Normal(means, torch.sqrt(variances)), 1)
        p_x = D.MixtureSameFamily(D.Categorical(weights), p_x)
        # Sample shapes must be ints. A fractional multiplier would make this
        # a float and sample() would raise, so round up.
        n_samples = math.ceil(
            n_selected_components * ((batch_size // n_selected_components) + 1) * N_hat_multiplier
        )
        x_start = p_x.sample((n_samples,)).reshape(-1, x.shape[1])
        # Use GD
        _x_hat = density_gradient_descent(
            p_x,
            x_start,
            {"N_steps": gd_n_steps, "lr": gd_lr, "threshold": gd_threshold},
        )
        # Ensure same device
        if x_hat.device != device:
            x_hat = x_hat.to(device)
        x_hat = torch.cat((x_hat, _x_hat.detach()))

    return DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)
def MixtureLogistic(logits, loc, scale):
    """Return a mixture of ``Logistic(loc, scale)`` components.

    ``logits`` are the (unnormalised) mixture log-weights. As
    ``D.MixtureSameFamily`` requires, the rightmost batch axis of the
    components indexes the mixture components.
    """
    weights = D.Categorical(logits=logits)
    components = Logistic(loc, scale)
    return D.MixtureSameFamily(
        mixture_distribution=weights,
        component_distribution=components,
    )