def forward(self, state_features):
    """Map encoded state features to an action distribution.

    The distribution family is selected by ``self.dist``:
    'tanh_normal', 'trunc_normal', 'one_hot', or 'relaxed_one_hot'.
    Raises NotImplementedError for any other value.
    """
    x = self.feedforward_model(state_features)
    if self.dist == 'tanh_normal':
        # Split network output into mean / std halves along the last dim.
        mean, std = th.chunk(x, 2, -1)
        # Soft-squash the mean into (-mean_scale, mean_scale).
        mean = self.mean_scale * th.tanh(mean / self.mean_scale)
        std = F.softplus(std + self.raw_init_std) + self.min_std
        dist = td.Normal(mean, std)
        # TODO: fix nan problem
        dist = td.TransformedDistribution(dist, td.TanhTransform(cache_size=1))
        dist = td.Independent(dist, 1)
        dist = SampleDist(dist)
    elif self.dist == 'trunc_normal':
        mean, std = th.chunk(x, 2, -1)
        # Scaled sigmoid keeps std in (min_std, 2 + min_std).
        std = 2 * th.sigmoid((std + self.raw_init_std) / 2) + self.min_std
        from rls.nn.dists.TruncatedNormal import \
            TruncatedNormal as TruncNormalDist
        # Normal truncated to the action range [-1, 1].
        dist = TruncNormalDist(th.tanh(mean), std, -1, 1)
        dist = td.Independent(dist, 1)
    elif self.dist == 'one_hot':
        dist = td.OneHotCategoricalStraightThrough(logits=x)
    elif self.dist == 'relaxed_one_hot':
        # Fixed relaxation temperature of 0.1.
        dist = td.RelaxedOneHotCategorical(th.tensor(0.1), logits=x)
    else:
        raise NotImplementedError(f"{self.dist} is not implemented.")
    return dist
def losses_clustering(self, x, x_hat, mu_z, logvar_z, z):
    """Compute the clustering-VAE loss terms.

    Returns a tuple (BCE, KLD, KLD_c, zero placeholder, pc_given_z):
    reconstruction loss, responsibility-weighted KL to the mixture
    components, KL of the responsibilities to the mixture prior, a zero
    tensor placeholder, and the cluster responsibilities p(c|z).
    """
    if not self.computes_std:
        # Stored values are log-variances; convert to standard deviations.
        std_z = torch.exp(logvar_z / 2)
        std_c = torch.exp(self.logvar_c / 2)
    else:
        # NOTE(review): despite the names, these are treated as log-std
        # here when `computes_std` is set — confirm against the encoder.
        std_z = torch.exp(logvar_z)
        std_c = torch.exp(self.logvar_c)
    # Mixture prior weights; Categorical normalizes the sigmoid outputs.
    pi = distributions.Categorical(torch.sigmoid(self.pi)).probs
    pc_given_z = self.pc_given_z(z)
    # Reconstruction term, rescaled from a per-pixel mean to a per-image sum.
    BCE = F.binary_cross_entropy_with_logits(
        x_hat, x, reduction='mean') * self.width * self.height
    # Expected KL between q(z|x) and each cluster component, weighted by
    # the responsibilities p(c|z); broadcast over the component axis.
    KLD = torch.sum(pc_given_z * distributions.kl_divergence(
        distributions.Independent(distributions.Normal(
            mu_z[:, None, :], std_z[:, None, :]),
            reinterpreted_batch_ndims=1),
        distributions.Independent(distributions.Normal(
            self.mu_c[None, :, :], std_c[None, :, :]),
            reinterpreted_batch_ndims=1)),
                    dim=1).mean()
    # KL between responsibilities and the (learned) mixture prior.
    KLD_c = distributions.kl_divergence(
        distributions.Categorical(pc_given_z),
        distributions.Categorical(pi[None, :])).mean()
    return BCE, KLD, KLD_c, torch.tensor(0).float(), pc_given_z
def forward(self, x, beta=1.0, switch=1.0, iw_samples=1):
    """VAE forward pass with an optional adversarial decoder branch.

    When `switch` is truthy, an adversarial network scores real vs.
    reconstructed inputs and also produces the decoder std.
    Returns (iw_elbo - advert_loss, mean log px, mean kl, x_mu[0], x_std,
    z[0], z_mu, z_std).
    """
    # Encoder step
    z_mu, z_std = self.encoder(x)
    q_dist = D.Independent(D.Normal(z_mu, z_std), 1)
    # Draw `iw_samples` importance samples from the posterior.
    z = q_dist.rsample([iw_samples])
    # Decoder step
    x_mu, x_std = self.decoder(z, switch)
    if switch:
        # Real inputs labelled 0, reconstructions labelled 1.
        valid = torch.zeros((x.shape[0], 1), device=x.device)
        fake = torch.ones((x.shape[0], 1), device=x.device)
        labels = torch.cat([valid, fake], dim=0)
        x_cat = torch.cat([x.repeat(iw_samples, 1, 1), x_mu], dim=1)
        prop = self.adverserial(x_cat)
        advert_loss = F.binary_cross_entropy(prop, labels.repeat(
            iw_samples, 1, 1), reduction='sum')
        # Decoder std derived from the discriminator scores on real inputs.
        x_std = self.dec_std(prop[:, :x.shape[0]])
    else:
        advert_loss = 0
    p_dist = D.Independent(D.Normal(x_mu, x_std), 1)
    # Calculate loss
    prior = D.Independent(
        D.Normal(torch.zeros_like(z), torch.ones_like(z)), 1)
    log_px = p_dist.log_prob(x)
    kl = q_dist.log_prob(z) - prior.log_prob(z)
    # NOTE(review): `.mean()` collapses the per-sample ELBO to a scalar
    # BEFORE the importance-weighted logsumexp over dim 0; for
    # iw_samples > 1 an IWAE bound would logsumexp over the sample dim
    # first — confirm this is intended.
    elbo = (log_px - beta * kl).mean()
    iw_elbo = elbo.logsumexp(dim=0) - torch.tensor(float(iw_samples)).log()
    return iw_elbo.mean() - advert_loss, log_px.mean(), kl.mean(
    ), x_mu[0], x_std, z[0], z_mu, z_std
def reparameterize(self, mu, var):
    """Draw a reparameterized sample from the posterior N(mu, var) and cache
    both the posterior and the label-conditioned prior on the instance.

    Side effects: stores ``self.pred_dist`` (posterior) and ``self.prior``.
    Returns the reparameterized sample.
    """
    posterior = dist.Independent(dist.Normal(mu, var), 1)
    self.pred_dist = posterior
    sample = posterior.rsample()
    # Prior parameters are produced from the one-hot label embedding;
    # `sp` maps the raw scale output to a positive value.
    prior_loc = self.loc(self.onehot)
    prior_scale = self.sp(self.scale(self.onehot))
    self.prior = dist.Independent(dist.Normal(prior_loc, prior_scale), 1)
    return sample
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
             fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K,
             num_highways, dropout, speaker_latent_dims, speaker_encoder_dims,
             n_speakers, noise_latent_dims, noise_encoder_dims):
    """Tacotron augmented with VAE speaker/noise latents and
    domain-adversarial classifiers.

    Builds the standard Tacotron encoder/decoder/postnet stack, two
    variational encoders (speaker and noise) selected by
    ``hp.encoder_model``, four adversarial classifiers, and fixed
    standard-normal priors for both latent spaces.
    """
    super().__init__()
    self.n_mels = n_mels
    self.lstm_dims = lstm_dims
    self.decoder_dims = decoder_dims
    # Standard Tacotron #############################################################
    self.encoder = Encoder(embed_dims, num_chars, encoder_dims, encoder_K,
                           num_highways, dropout)
    self.encoder_proj = nn.Linear(decoder_dims, decoder_dims, bias=False)
    # Decoder is conditioned on both latent codes.
    self.decoder = Decoder(n_mels, decoder_dims, lstm_dims,
                           speaker_latent_dims, noise_latent_dims)
    self.postnet = CBHG(postnet_K, n_mels + noise_latent_dims, postnet_dims,
                        [256, n_mels + noise_latent_dims], num_highways)
    self.post_proj = nn.Linear(postnet_dims * 2, fft_bins, bias=False)
    # VAE Domain Adversarial ########################################################
    if hp.encoder_model == "CNN":
        self.speaker_encoder = CNNEncoder(n_mels, speaker_latent_dims,
                                          speaker_encoder_dims)
        self.noise_encoder = CNNEncoder(n_mels, noise_latent_dims,
                                        noise_encoder_dims)
    elif hp.encoder_model == "CNNRNN":
        self.speaker_encoder = CNNRNNEncoder(n_mels, speaker_latent_dims,
                                             speaker_encoder_dims)
        self.noise_encoder = CNNRNNEncoder(n_mels, noise_latent_dims,
                                           noise_encoder_dims)
    # Adversarial heads: each latent is classified for both speaker id and
    # noise condition (2-way).
    self.speaker_speaker = Classifier(speaker_latent_dims, n_speakers)
    self.speaker_noise = Classifier(speaker_latent_dims, 2)
    self.noise_speaker = Classifier(noise_latent_dims, n_speakers)
    self.noise_noise = Classifier(noise_latent_dims, 2)
    ## speaker encoder prior (fixed standard normal; not trained)
    self.speaker_latent_loc = nn.Parameter(
        torch.zeros(speaker_latent_dims), requires_grad=False)
    self.speaker_latent_scale = nn.Parameter(
        torch.ones(speaker_latent_dims), requires_grad=False)
    self.speaker_latent_prior = dist.Independent(
        dist.Normal(self.speaker_latent_loc, self.speaker_latent_scale), 1)
    ## noise encoder prior (fixed standard normal; not trained)
    self.noise_latent_loc = nn.Parameter(torch.zeros(noise_latent_dims),
                                         requires_grad=False)
    self.noise_latent_scale = nn.Parameter(torch.ones(noise_latent_dims),
                                           requires_grad=False)
    self.noise_latent_prior = dist.Independent(
        dist.Normal(self.noise_latent_loc, self.noise_latent_scale), 1)
    #################################################################################
    self.init_model()
    self.num_params()
    # Training step counter and reduction factor, persisted in checkpoints.
    self.register_buffer("step", torch.zeros(1).long())
    self.register_buffer("r", torch.tensor(0).long())
def kl_penalty(self) -> torch.Tensor:
    """Return KL(q || p) between the variational posterior and a standard
    normal prior, used as the prior-penalty term of the ELBO.

    The two trailing dimensions of ``q_mean`` are treated as event dims,
    so the result is a scalar per remaining batch element.
    """
    scale = torch.exp(self.log_q_scale)
    posterior = dist.Independent(
        dist.Normal(self.q_mean, scale),
        reinterpreted_batch_ndims=2,
    )
    # Standard-normal prior with the same shape as the posterior mean.
    prior = dist.Independent(
        dist.Normal(torch.zeros_like(self.q_mean),
                    torch.ones_like(self.q_mean)),
        reinterpreted_batch_ndims=2,
    )
    return dist.kl_divergence(posterior, prior)
def forward(self, image, **kwargs):
    """Predict a (low-rank) Gaussian distribution over per-pixel class logits.

    The covariance is masked to zero outside the ROI given by
    ``kwargs['sampling_mask']`` because it tends to blow up in the
    background. Falls back to an independent-normal model when the
    low-rank covariance is not invertible.

    Returns:
        (logit_mean, output_dict) where output_dict holds detached
        mean/covariance views plus the reshaped distribution.
    """
    logits = F.relu(super().forward(image, **kwargs)[0])
    batch_size = logits.shape[0]
    event_shape = (self.num_classes, ) + logits.shape[2:]
    mean = self.mean_l(logits)
    cov_diag = self.log_cov_diag_l(logits).exp() + self.epsilon
    mean = mean.view((batch_size, -1))
    cov_diag = cov_diag.view((batch_size, -1))
    cov_factor = self.cov_factor_l(logits)
    cov_factor = cov_factor.view(
        (batch_size, self.rank, self.num_classes, -1))
    cov_factor = cov_factor.flatten(2, 3)
    cov_factor = cov_factor.transpose(1, 2)
    # Covariance in the background tends to blow up to infinity, hence set
    # to 0 outside the ROI.
    mask = kwargs['sampling_mask']
    mask = mask.unsqueeze(1).expand((batch_size, self.num_classes) +
                                    mask.shape[1:]).reshape(batch_size, -1)
    cov_factor = cov_factor * mask.unsqueeze(-1)
    cov_diag = cov_diag * mask + self.epsilon
    if self.diagonal:
        base_distribution = td.Independent(
            td.Normal(loc=mean, scale=torch.sqrt(cov_diag)), 1)
    else:
        try:
            base_distribution = td.LowRankMultivariateNormal(
                loc=mean, cov_factor=cov_factor, cov_diag=cov_diag)
        except (RuntimeError, ValueError):
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed; torch raises RuntimeError (Cholesky /
            # linalg failure) or ValueError (argument validation) when the
            # covariance is degenerate.
            print(
                'Covariance became not invertible using independent normals for this batch!'
            )
            base_distribution = td.Independent(
                td.Normal(loc=mean, scale=torch.sqrt(cov_diag)), 1)
    distribution = ReshapedDistribution(base_distribution, event_shape)
    shape = (batch_size, ) + event_shape
    logit_mean = mean.view(shape)
    cov_diag_view = cov_diag.view(shape).detach()
    cov_factor_view = cov_factor.transpose(
        2, 1).view((batch_size, self.num_classes * self.rank) +
                   event_shape[1:]).detach()
    output_dict = {
        'logit_mean': logit_mean.detach(),
        'cov_diag': cov_diag_view,
        'cov_factor': cov_factor_view,
        'distribution': distribution
    }
    return logit_mean, output_dict
def clone_dist(self, dist, detach=False):
    """Rebuild `dist` from its moments, optionally detaching them from the
    autograd graph so the clone carries no gradient history."""
    if self._rssm_type == 'discrete':
        probs = dist.mean
        if detach:
            probs = th.detach(probs)
        return td.Independent(OneHotDistFlattenSample(probs), 1)
    # Continuous case: clone a diagonal Gaussian from mean and stddev.
    mu, sigma = dist.mean, dist.stddev
    if detach:
        mu = th.detach(mu)
        sigma = th.detach(sigma)
    return td.Independent(td.Normal(mu, sigma), 1)
def _train(self, BATCH): output = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] if self.is_continuous: mu, log_std = output # [T, B, A], [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) new_log_prob = dist.log_prob(BATCH.action).unsqueeze( -1) # [T, B, 1] entropy = dist.entropy().mean() # 1 else: logits = output # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] new_log_prob = (BATCH.action * logp_all).sum( -1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1).mean() # 1 ratio = (new_log_prob - BATCH.log_prob).exp() # [T, B, 1] actor_loss = -(ratio * BATCH.gae_adv).mean() # 1 flat_grads = grads_flatten(actor_loss, self.actor, retain_graph=True).detach() # [1,] if self.is_continuous: kl = td.kl_divergence( td.Independent(td.Normal(BATCH.mu, BATCH.log_std.exp()), 1), td.Independent(td.Normal(mu, log_std.exp()), 1)).mean() else: kl = (BATCH.logp_all.exp() * (BATCH.logp_all - logp_all)).sum(-1).mean() # 1 flat_kl_grad = grads_flatten(kl, self.actor, create_graph=True) search_direction = -self._conjugate_gradients( flat_grads, flat_kl_grad, cg_iters=self._cg_iters) # [1,] with th.no_grad(): flat_params = th.cat( [param.data.view(-1) for param in self.actor.parameters()]) new_flat_params = flat_params + self.actor_step_size * search_direction set_from_flat_params(self.actor, new_flat_params) for _ in range(self._train_critic_iters): value = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] td_error = BATCH.discounted_reward - value # [T, B, 1] critic_loss = td_error.square().mean() # 1 self.critic_oplr.optimize(critic_loss) return { 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/entropy': entropy.mean(), 'LEARNING_RATE/critic_lr': self.critic_oplr.lr }
def forward(self, x, beta=1.0, epsilon=1e-5):
    """Single-sample VAE forward pass.

    `epsilon` is added to the std of both posterior and likelihood for
    numerical stability; `beta` weights the KL term.
    Returns (mean elbo, mean log px, mean kl, x_mu, x_var, z, z_mu, z_var).
    """
    z_mu, z_var = self.encoder(x)
    posterior = D.Independent(D.Normal(z_mu, z_var.sqrt() + epsilon), 1)
    z = posterior.rsample()
    x_mu, x_var = self.decoder(z)
    likelihood = D.Independent(D.Normal(x_mu, x_var.sqrt() + epsilon), 1)
    # Standard-normal prior over the latent.
    prior = D.Independent(
        D.Normal(torch.zeros_like(z), torch.ones_like(z)), 1)
    log_px = likelihood.log_prob(x)
    # Single-sample Monte-Carlo KL estimate.
    kl = posterior.log_prob(z) - prior.log_prob(z)
    elbo = log_px - beta * kl
    return elbo.mean(), log_px.mean(), kl.mean(), x_mu, x_var, z, z_mu, z_var
def forward(self, x, beta=1.0, epsilon=1e-5, Q=0.5):
    """Quantile-regression VAE forward pass.

    NOTE(review): `beta` and `Q` are accepted but never used — the KL
    weight is hard-coded to 0.28 and the quantiles to 0.15/0.5/0.85;
    confirm this is intentional.
    Returns (mean elbo, mean pinball loss, mean kl, x_mu, z, z_mu, z_var).
    """
    z_mu, z_var = self.encoder(x)
    q_dist = D.Independent(D.Normal(z_mu, z_var.sqrt() + epsilon), 1)
    z = q_dist.rsample()
    x_mu = self.decoder(z)
    prior = D.Independent(
        D.Normal(torch.zeros_like(z), torch.ones_like(z)), 1)
    # Pinball (quantile) losses at the 0.15 / 0.5 / 0.85 quantiles; the
    # decoder emits three stacked 4-dim predictions (cols 0:4, 4:8, 8:12).
    log_px_Q1 = torch.sum(
        torch.max(0.15 * (x - x_mu[:, 0:4]),
                  (0.15 - 1) * (x - x_mu[:, 0:4])).view(-1, 4), (1))
    log_px_Q2 = torch.sum(
        torch.max(0.5 * (x - x_mu[:, 4:8]),
                  (0.5 - 1) * (x - x_mu[:, 4:8])).view(-1, 4), (1))
    log_px_Q3 = torch.sum(
        torch.max(0.85 * (x - x_mu[:, 8:12]),
                  (0.85 - 1) * (x - x_mu[:, 8:12])).view(-1, 4), (1))
    log_px = (log_px_Q1 + log_px_Q2 + log_px_Q3) / 3
    kl = q_dist.log_prob(z) - prior.log_prob(z)
    # `log_px` here is a loss (not a log-likelihood), hence the negation.
    elbo = -log_px - 0.28 * kl
    return elbo.mean(), log_px.mean(), kl.mean(), x_mu, z, z_mu, z_var
def test_inv():
    """The inverse flow's log|det J| must mirror the forward flow's, and
    inverting twice must return the original flow object."""
    flow = flowtorch.bijectors.AffineAutoregressive(
        flowtorch.params.DenseAutoregressive())
    tdist, params = flow(
        dist.Independent(dist.Normal(torch.zeros(2), torch.ones(2)), 1))
    inv_flow = flow.inv()
    inv_tdist, inv_params = inv_flow(
        dist.Independent(dist.Normal(torch.zeros(2), torch.ones(2)), 1))
    x = torch.zeros(1, 2)
    y = flow.forward(x, params, context=torch.empty(0))
    # Jacobian of x -> y under the flow equals the jacobian of y -> x
    # under the inverse flow.
    assert tdist.bijector.log_abs_det_jacobian(
        x, y, params, context=torch.empty(0)
    ) == inv_tdist.bijector.log_abs_det_jacobian(
        y, x, inv_params, context=torch.empty(0))
    # Double inversion round-trips to the original flow.
    assert flow.inv().inv == flow
def test_conditional_2gmm():
    """A context-conditioned inverse flow should learn to map the base
    distribution onto either of two Gaussians depending on the context sign."""
    context_size = 2
    flow = flowtorch.bijectors.Compose(
        [
            flowtorch.bijectors.AffineAutoregressive(
                context_size=context_size) for _ in range(2)
        ],
        context_size=context_size,
    ).inv()
    base_dist = dist.Normal(torch.zeros(2), torch.ones(2))
    new_cond_dist, params_module = flow(base_dist)
    # Two targets centred at +5 and -5.
    target_dist_0 = dist.Independent(
        dist.Normal(torch.zeros(2) + 5, torch.ones(2) * 0.5), 1)
    target_dist_1 = dist.Independent(
        dist.Normal(torch.zeros(2) - 5, torch.ones(2) * 0.5), 1)
    opt = torch.optim.Adam(params_module.parameters(), lr=5e-3)
    for idx in range(501):
        opt.zero_grad()
        # Alternate targets, keyed by the sign of the conditioning context.
        if idx % 2 == 0:
            target_dist = target_dist_0
            context = torch.ones(context_size)
        else:
            target_dist = target_dist_1
            context = -1 * torch.ones(context_size)
        marginal = new_cond_dist.condition(context)
        y = marginal.rsample((1000, ))
        # Reverse-KL style objective between flow marginal and target.
        loss = -target_dist.log_prob(y) + marginal.log_prob(y)
        loss = loss.mean()
        if idx % 100 == 0:
            print("epoch", idx, "loss", loss)
        loss.backward()
        opt.step()
    # Sample means should land near the respective target means.
    assert (new_cond_dist.condition(torch.ones(context_size)).sample(
        (1000, )).mean() - 5.0).norm().item() < 0.1
    assert (new_cond_dist.condition(-1 * torch.ones(context_size)).sample(
        (1000, )).mean() + 5.0).norm().item() < 0.1
def gaussian_mixture_sampler(num_latent,
                             num_mixtures=4,
                             weights=None,
                             means=None,
                             cov=None):
    """Build a sampler for a per-latent-dimension Gaussian mixture.

    :param num_latent: number of independent latent dimensions.
    :param num_mixtures: mixture components per latent dimension.
    :param weights: optional [num_latent, num_mixtures] mixture weights
        (Categorical normalizes them).
    :param means: optional [num_latent, num_mixtures, 1] component means.
    :param cov: optional [num_latent, num_mixtures, 1] component scales
        (standard deviations; must be strictly positive).
    :return: callable mapping n -> tensor of n samples, shape [n, num_latent]
        (trailing singleton dims squeezed).
    """
    if weights is None:
        weights = torch.randn(num_latent, num_mixtures).softmax(dim=1)
    if means is None:
        means = torch.randn(num_latent, num_mixtures, 1) * 2
    if cov is None:
        # Bug fix: `torch.randn` yields negative values roughly half the
        # time, which is an invalid Normal scale (construction raises under
        # torch's default argument validation). Draw strictly positive
        # scales instead.
        cov = torch.rand(num_latent, num_mixtures, 1) + 1e-3
    mix = dist.Categorical(weights)
    comp = dist.Independent(dist.Normal(means, cov), 1)
    gmm = dist.MixtureSameFamily(mix, comp)
    return lambda n: gmm.sample((n, )).squeeze()
def detach(self):
    """Detach the cached mean/log-std from the autograd graph and rebuild
    the wrapped Normal and Independent distributions in place."""
    self.mean = self.mean.detach()
    self.log_std = self.log_std.detach()
    # Drop the stale distribution objects before rebuilding them from the
    # detached parameters.
    self.normal = None
    self.diagn = None
    scale = self.log_std.exp()
    self.normal = P.Normal(self.mean, scale)
    self.diagn = P.Independent(self.normal, 1)
def test_neals_funnel_vi():
    """Variational fit of an autoregressive flow to Neal's funnel; fitted
    samples should be KS-indistinguishable from true funnel samples."""
    torch.manual_seed(42)
    nf = NealsFunnel()
    flow = flowtorch.bijectors.AffineAutoregressive(
        flowtorch.params.DenseAutoregressive())
    tdist, params = flow(
        dist.Independent(dist.Normal(torch.zeros(2), torch.ones(2)), 1))
    opt = torch.optim.Adam(params.parameters(), lr=1e-3)
    num_elbo_mc_samples = 100
    for _ in range(400):
        z0 = tdist.base_dist.rsample(sample_shape=(num_elbo_mc_samples, ))
        zk = flow._forward(z0, params, context=torch.empty(0))
        ldj = flow._log_abs_det_jacobian(z0, zk, params,
                                         context=torch.empty(0))
        # Monte-Carlo negative ELBO: E_q[log q(z) - log p(z)].
        neg_elbo = -nf.log_prob(zk).sum()
        neg_elbo += tdist.base_dist.log_prob(z0).sum() - ldj.sum()
        neg_elbo /= num_elbo_mc_samples
        # Skip the optimizer step when the objective diverges to NaN.
        if not torch.isnan(neg_elbo):
            neg_elbo.backward()
            opt.step()
            opt.zero_grad()
    nf_samples = NealsFunnel().sample((20, )).squeeze().numpy()
    vi_samples = tdist.sample((20, )).detach().numpy()
    # Two-sample Kolmogorov-Smirnov test per coordinate.
    assert scipy.stats.ks_2samp(nf_samples[:, 0],
                                vi_samples[:, 0]).pvalue >= 0.05
    assert scipy.stats.ks_2samp(nf_samples[:, 1],
                                vi_samples[:, 1]).pvalue >= 0.05
def select_action(self, obs): q = self.q_net(obs, rnncs=self.rnncs) # [B, P] self.rnncs_ = self.q_net.get_rnncs() pi = self.intra_option_net(obs, rnncs=self.rnncs) # [B, P, A] beta = self.termination_net(obs, rnncs=self.rnncs) # [B, P] options_onehot = F.one_hot(self.options, self.options_num).float() # [B, P] options_onehot_expanded = options_onehot.unsqueeze(-1) # [B, P, 1] pi = (pi * options_onehot_expanded).sum(-2) # [B, A] if self.is_continuous: mu = pi.tanh() # [B, A] log_std = self.log_std[self.options] # [B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) actions = dist.sample().clamp(-1, 1) # [B, A] else: pi = pi / self.boltzmann_temperature # [B, A] dist = td.Categorical(logits=pi) actions = dist.sample() # [B, ] max_options = q.argmax(-1).long() # [B, P] => [B, ] if self.use_eps_greedy: # epsilon greedy if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): self.new_options = self._generate_random_options() else: self.new_options = max_options else: beta_probs = (beta * options_onehot).sum(-1) # [B, P] => [B,] beta_dist = td.Bernoulli(probs=beta_probs) self.new_options = th.where(beta_dist.sample() < 1, self.options, max_options) return actions, Data(action=actions, last_options=self.options, options=self.new_options)
def independent(self, reinterpreted_batch_ndims=1):
    """Wrap this distribution in ``td.Independent``.

    Convenience shorthand for the common pattern of reinterpreting the
    trailing batch dimension(s) as event dimensions, as if the caller had
    written ``td.Independent(distribution=OurDistribution..., ...)``.
    """
    return td.Independent(self, reinterpreted_batch_ndims)
def create_gmm(system, gmm_scale=0.05):
    """Build a Gaussian-mixture distribution with one isotropic kernel per
    point of `system`.

    Arguments:
        system: set of points from which the gmm is produced; singleton
            dims are squeezed, last dim is the point dimension, second to
            last the number of points.
        gmm_scale: stdev of the kernel placed on each point.

    Returns:
        A MixtureSameFamily distribution with equal component weights.
    """
    points = torch.squeeze(system)
    dim = points.shape[-1]
    count = points.shape[-2]
    # Weight every concept point equally.
    weights = D.Categorical(torch.ones(count, ))
    # Isotropic Gaussian kernel (diagonal covariance) of stdev gmm_scale.
    kernels = D.Independent(
        D.Normal(points, gmm_scale * torch.ones(dim, )), 1)
    return D.mixture_same_family.MixtureSameFamily(weights, kernels)
def forward(self, input: torch.Tensor, proposal: distributions.Normal,
            reconstruction: torch.Tensor) -> torch.Tensor:
    """Gaussian-posterior VAE loss.

    Returns (-elbo, -reconstruction_term, kl_regularization); the first
    element is the quantity to minimize.
    """
    if self.likelihood == 'bernoulli':
        likelihood = distributions.Bernoulli(probs=reconstruction)
    else:
        # Fixed unit-variance Gaussian likelihood.
        likelihood = distributions.Normal(reconstruction,
                                          torch.ones_like(reconstruction))
    # NOTE(review): reinterpreted_batch_ndims=-1 is unusual — torch
    # documents a non-negative count of trailing batch dims; confirm this
    # behaves as intended on the installed torch version.
    likelihood = distributions.Independent(likelihood,
                                           reinterpreted_batch_ndims=-1)
    reconstruction_loss = likelihood.log_prob(input).mean()
    assert proposal.loc.dim(
    ) == 2, "proposal.shape == [*, D], D is shape of isotopic gaussian"
    # Standard-normal prior matching the proposal's shape.
    prior = distributions.Normal(torch.zeros_like(proposal.loc),
                                 torch.ones_like(proposal.scale))
    regularization = distributions.kl_divergence(proposal,
                                                 prior).sum(dim=-1).mean()
    # evidence lower bound (maximize)
    total_loss = reconstruction_loss - self.beta * regularization
    return -total_loss, -reconstruction_loss, regularization
def _build_dist(self, output):
    """Turn raw network output into the RSSM stochastic-state distribution
    (categorical for 'discrete', diagonal Gaussian otherwise)."""
    if self._rssm_type == 'discrete':
        logits = output.view(
            output.shape[:-1] + (self.stoch_dim, self._discretes))  # [B, s, d]
        return td.Independent(OneHotDistFlattenSample(logits=logits), 1)
    else:
        mean, stddev = th.chunk(output, 2, -1)  # [B, *]
        # Configurable activation for the raw stddev head; unknown values
        # leave the raw output unchanged.
        if self._std_act == 'softplus':
            stddev = F.softplus(stddev)
        elif self._std_act == 'sigmoid':
            stddev = th.sigmoid(stddev)
        elif self._std_act == 'sigmoid2':
            stddev = 2. * th.sigmoid(stddev / 2.)
        # Floor the stddev to keep the Gaussian well-conditioned.
        stddev = stddev + self._min_stddev  # [B, *]
        return td.Independent(td.Normal(mean, stddev), 1)
def update(self, observations, actions, adv_n=None):
    """Policy-gradient update: maximize E[advantage * log pi(a|s)].

    Raises ValueError when `adv_n` is None. Returns the scalar loss value.
    """
    # TODO: update the policy and return the loss
    # observations = ptu.from_numpy(observations)
    # actions = ptu.from_numpy(actions)
    if adv_n is not None:
        # adv_n = ptu.from_numpy(adv_n)
        pass
    else:
        # in which circumstances can adv_n be None?? seems no
        raise ValueError("adv_n is None!?")
    action_dist = self.forward(observations)
    if self.discrete:
        log_pi = action_dist.log_prob(actions)
    else:
        if len(action_dist.batch_shape) == 1:
            log_pi = action_dist.log_prob(actions)
        else:
            # Treat the action dims as a single event so log_prob sums
            # over them and matches adv_n's shape.
            action_dist_new = distributions.Independent(action_dist, 1)
            log_pi = action_dist_new.log_prob(actions)
    assert adv_n.ndim == log_pi.ndim
    sums = adv_n * log_pi
    # sums = torch.tensor(sums)l
    # loss = sum(sums)
    loss = -torch.sum(
        sums
    )  # `optimizer.step()` MINIMIZES a loss but we want to MAXIMIZE expectation
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()  # .item() extracts a Python float from the 0-dim tensor
def _train(self, BATCH): v = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] td_error = BATCH.discounted_reward - v # [T, B, 1] critic_loss = td_error.square().mean() # 1 self.critic_oplr.optimize(critic_loss) if self.is_continuous: mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) log_act_prob = dist.log_prob(BATCH.action).unsqueeze( -1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] log_act_prob = (BATCH.action * logp_all).sum( -1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum( -1, keepdim=True) # [T, B, 1] # advantage = BATCH.discounted_reward - v.detach() # [T, B, 1] actor_loss = -(log_act_prob * BATCH.gae_adv + self.beta * entropy).mean() # 1 self.actor_oplr.optimize(actor_loss) return { 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/entropy': entropy.mean(), 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr }
def select_action(self, obs):
    """Sample an action from the current policy for rollout collection,
    together with value and log-prob bookkeeping.

    Returns (action, Data(action, value, log_prob[, rnncs])).
    """
    if self.is_continuous:
        if self._share_net:
            mu, log_std, value = self.net(obs, rnncs=self.rnncs)  # [B, A]
            self.rnncs_ = self.net.get_rnncs()
        else:
            mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            self.rnncs_ = self.actor.get_rnncs()
            value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        # Samples are clipped to the valid action range.
        action = dist.sample().clamp(-1, 1)  # [B, A]
        log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    else:
        if self._share_net:
            logits, value = self.net(obs,
                                     rnncs=self.rnncs)  # [B, A], [B, 1]
            self.rnncs_ = self.net.get_rnncs()
        else:
            logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            self.rnncs_ = self.actor.get_rnncs()
            value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
        norm_dist = td.Categorical(logits=logits)
        action = norm_dist.sample()  # [B,]
        log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    # eps keeps downstream log-prob ratios away from log(0).
    acts_info = Data(action=action,
                     value=value,
                     log_prob=log_prob + th.finfo().eps)
    if self.use_rnn:
        acts_info.update(rnncs=self.rnncs)
    return action, acts_info
def squashed_diagonal_gaussian_head(x):
    """Split `x` into mean / log-scale halves along the last dim and return
    a diagonal Gaussian over the action dimensions.

    The log-scale is clamped to [-20, 2] for numerical stability.
    """
    loc, raw_log_scale = torch.chunk(x, 2, dim=-1)
    clamped_log_scale = torch.clamp(raw_log_scale, -20.0, 2.0)
    # scale = sqrt(exp(2 * log_scale)); kept in this form to match the
    # original numerical path exactly.
    scale = torch.sqrt(torch.exp(clamped_log_scale * 2))
    head = distributions.Normal(loc=loc, scale=scale)
    return distributions.Independent(head, 1)
def __init__(self, x, layers, num_components=100, device=None, old=False):
    """Build a VAE with residual (Jacobian-aware) encoder/decoder networks.

    `layers` gives the layer widths from data dim to latent dim; `old`
    selects the legacy decoder layout for loading old checkpoints. The
    std-decoder is left unset and must be created via init_std after the
    mean network is trained.
    """
    super(VAE_bodies, self).__init__()
    self.device = device
    self.p = int(layers[0])  # Dimension of x
    self.d = int(layers[-1])  # Dimension of z
    self.h = layers  # [1:-1] # Dimension of hidden layers
    self.num_components = num_components
    enc = []
    for k in range(len(layers) - 1):
        in_features = int(layers[k])
        out_features = int(layers[k + 1])
        enc.append(
            nnj.ResidualBlock(nnj.Linear(in_features, out_features),
                              nnj.Softplus()))
    # Final encoder layer emits mean and (raw) scale, hence 2*d outputs.
    enc.append(nnj.Linear(out_features, int(self.d * 2)))
    dec = []
    for k in reversed(range(len(layers) - 1)):
        in_features = int(layers[k + 1])
        out_features = int(layers[k])
        if not old:  # temporary to load old models TODO: delete
            if out_features != layers[0]:
                dec.append(
                    nnj.ResidualBlock(
                        nnj.Linear(in_features, out_features),
                        nnj.Softplus()))
            else:
                # Output layer squashes to (0, 1) via sigmoid.
                dec.append(
                    nnj.ResidualBlock(
                        nnj.Linear(in_features, out_features),
                        nnj.Sigmoid()))
        else:
            dec.append(
                nnj.ResidualBlock(nnj.Linear(in_features, out_features),
                                  nnj.Softplus()))
            if out_features == layers[0]:
                dec.append(nnj.Sigmoid())
    # Note how we use 'nnj' instead of 'nn' -- this gives automatic
    # computation of Jacobians of the implemented neural network.
    # The embed function is required to also return Jacobians if
    # requested; by using 'nnj' this becomes a trivial constraint.
    self.encoder = nnj.Sequential(*enc)
    self.decoder_loc = nnj.Sequential(*dec)
    self.init_decoder_scale = 0.01 * torch.ones(self.p, device=self.device)
    # Standard-normal prior over the latent space.
    self.prior_loc = torch.zeros(self.d, device=self.device)
    self.prior_scale = torch.ones(self.d, device=self.device)
    self.prior = td.Independent(
        td.Normal(loc=self.prior_loc, scale=self.prior_scale), 1)
    # Create a blank std-network.
    # It is important to call init_std after training the mean, but before training the std
    self.dec_std = None
    self.to(self.device)
def miwae_loss(iota_x, mask, d, K, p_z, encoder, decoder):
    """MIWAE negative importance-weighted bound for data with missing values.

    iota_x: imputed inputs [batch, p]; mask: 1 where observed; d: latent
    dim; K: importance samples; p_z: latent prior; encoder/decoder: the
    networks. Returns the scalar negative bound to minimize.
    """
    batch_size = iota_x.shape[0]
    p = iota_x.shape[1]
    out_encoder = encoder(iota_x)
    # q(z | x_obs): first d columns are means, next d are softplus'd scales.
    q_zgivenxobs = td.Independent(
        td.Normal(loc=out_encoder[..., :d],
                  scale=torch.nn.Softplus()(out_encoder[..., d:(2 * d)])),
        1)
    zgivenx = q_zgivenxobs.rsample([K])
    zgivenx_flat = zgivenx.reshape([K * batch_size, d])
    out_decoder = decoder(zgivenx_flat)
    # Student-t observation model: mean / scale / degrees of freedom, with
    # floors keeping scale positive and df >= 3.
    all_means_obs_model = out_decoder[..., :p]
    all_scales_obs_model = torch.nn.Softplus()(
        out_decoder[..., p:(2 * p)]) + 0.001
    all_degfreedom_obs_model = torch.nn.Softplus()(
        out_decoder[..., (2 * p):(3 * p)]) + 3
    data_flat = torch.Tensor.repeat(iota_x, [K, 1]).reshape([-1, 1])
    tiledmask = torch.Tensor.repeat(mask, [K, 1])
    all_log_pxgivenz_flat = torch.distributions.StudentT(
        loc=all_means_obs_model.reshape([-1, 1]),
        scale=all_scales_obs_model.reshape([-1, 1]),
        df=all_degfreedom_obs_model.reshape([-1, 1])).log_prob(data_flat)
    all_log_pxgivenz = all_log_pxgivenz_flat.reshape([K * batch_size, p])
    # Only observed entries (mask == 1) contribute to the likelihood.
    logpxobsgivenz = torch.sum(all_log_pxgivenz * tiledmask,
                               1).reshape([K, batch_size])
    logpz = p_z.log_prob(zgivenx)
    logq = q_zgivenxobs.log_prob(zgivenx)
    # Importance-weighted bound over the K posterior samples.
    neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq,
                                            0))
    return neg_bound
def forward(self, input: torch.Tensor,
            proposal: distributions.RelaxedOneHotCategorical,
            proposal_sample: torch.Tensor,
            reconstruction: torch.Tensor) -> torch.Tensor:
    """Relaxed-categorical (Gumbel-softmax) VAE loss.

    Returns (-elbo, -reconstruction_term, kl_regularization); the first
    element is the quantity to minimize.
    """
    if self.likelihood == 'bernoulli':
        likelihood = distributions.Bernoulli(probs=reconstruction)
    else:
        # Fixed unit-variance Gaussian likelihood.
        likelihood = distributions.Normal(reconstruction,
                                          torch.ones_like(reconstruction))
    # NOTE(review): reinterpreted_batch_ndims=-1 is unusual — torch
    # documents a non-negative count of trailing batch dims; confirm this
    # behaves as intended on the installed torch version.
    likelihood = distributions.Independent(likelihood,
                                           reinterpreted_batch_ndims=-1)
    reconstruction_loss = likelihood.log_prob(input).mean()
    assert proposal.logits.dim(
    ) == 2, "proposal.shape == [*, D], D is shape of isotopic gaussian"
    # Uniform relaxed-categorical prior at the proposal's temperature.
    prior = distributions.RelaxedOneHotCategorical(proposal.temperature,
                                                   logits=torch.ones_like(
                                                       proposal.logits))
    # Single-sample Monte-Carlo KL estimate from the proposal sample.
    regularization = (proposal.log_prob(proposal_sample) -
                      prior.log_prob(proposal_sample)).mean()
    # evidence lower bound (maximize)
    total_loss = reconstruction_loss - self.beta * regularization
    return -total_loss, -reconstruction_loss, regularization
def log_prob(self, locations_3d, x_offset_3d, y_offset_3d, z_offset_3d,
             intensities_3d):
    """Log-likelihood of ground-truth emitters under the predicted count
    distribution and spatial (xyz + intensity) Gaussian mixture.

    Returns the per-batch-element log-probability (count term + masked
    spatial term).
    """
    xyzi, counts, s_mask = get_true_labels(locations_3d, x_offset_3d,
                                           y_offset_3d, z_offset_3d,
                                           intensities_3d)
    x_mu, y_mu, z_mu, i_mu = (i.unsqueeze(1)
                              for i in torch.unbind(self.xyzi_mu, dim=1))
    x_si, y_si, z_si, i_si = (
        i.unsqueeze(1) for i in torch.unbind(self.xyzi_sigma, dim=1))
    # Per-voxel detection probabilities; epsilon keeps them off exact 0/1.
    P = torch.sigmoid(self.logits) + 0.00001
    # Gaussian approximation of the emitter-count distribution
    # (mean/variance of a sum of Bernoullis).
    count_mean = P.sum(dim=[2, 3, 4]).squeeze(-1)
    count_var = (P - P**2).sum(dim=[2, 3, 4]).squeeze(
        -1)  # avoid situation where we have perfect match
    count_dist = D.Normal(count_mean, torch.sqrt(count_var))
    count_prob = count_dist.log_prob(counts)
    mixture_probs = P / P.sum(dim=[1, 2, 3], keepdim=True)
    xyz_mu_list, _, _, i_mu_list, x_sigma_list, y_sigma_list, z_sigma_list, i_sigma_list, mixture_probs_l = img_to_coord(
        P, x_mu, y_mu, z_mu, i_mu, x_si, y_si, z_si, i_si, mixture_probs)
    xyzi_mu = torch.cat((xyz_mu_list, i_mu_list), dim=-1)
    xyzi_sigma = torch.cat(
        (x_sigma_list, y_sigma_list, z_sigma_list, i_sigma_list), dim=-1)
    # to avoid NaN
    mix = D.Categorical(mixture_probs_l.squeeze(-1))
    comp = D.Independent(D.Normal(xyzi_mu, xyzi_sigma), 1)
    spatial_gmm = D.MixtureSameFamily(mix, comp)
    spatial_prob = spatial_gmm.log_prob(xyzi.transpose(0,
                                                       1)).transpose(0, 1)
    # Mask out padded (non-existent) emitters before summing.
    spatial_prob = (spatial_prob * s_mask).sum(-1)
    log_prob = count_prob + spatial_prob
    return log_prob
def _goal_likelihood(self, y: torch.Tensor, goal: torch.Tensor, **hyperparams) -> torch.Tensor: """Returns the goal-likelihood of a plan `y`, given `goal`. Args: y: A plan under evaluation, with shape `[B, T, 2]`. goal: The goal locations, with shape `[B, K, 2]`. hyperparams: (keyword arguments) The goal-likelihood hyperparameters. Returns: The log-likelihodd of the plan `y` under the `goal` distribution. """ # Parses tensor dimensions. B, K, _ = goal.shape # Fetches goal-likelihood hyperparameters. epsilon = hyperparams.get("epsilon", 1.0) # TODO(filangel): implement other goal likelihoods from the DIM paper # Initializes the goal distribution. goal_distribution = D.MixtureSameFamily( mixture_distribution=D.Categorical( probs=torch.ones((B, K)).to(goal.device)), component_distribution=D.Independent( D.Normal(loc=goal, scale=torch.ones_like(goal) * epsilon), reinterpreted_batch_ndims=1, )) return torch.mean(goal_distribution.log_prob(y[:, -1, :]), dim=0)