def sampleiter(self, bs=1):
    """Endlessly yield minibatches drawn by ancestral sampling through the ground-truth MLPs.

    One sample is a tensor (1, M, N); a yielded minibatch is (bs, M, N); a single
    variable is (bs, 1, N). Variable ``idx`` is drawn conditioned on variables
    ``0 .. idx-1``; the not-yet-sampled slots are filled with zeros.
    """
    while True:
        with torch.no_grad():
            drawn = []  # hard (one-hot) draws so far, each (bs, 1, N)
            for idx in range(self.M):
                # Zero-pad the remaining slots so the MLP input is always (bs, M, N).
                padding = torch.zeros(bs, self.M - idx, self.N)
                inputs = torch.cat(drawn + [padding], dim=1)
                hidden = torch.einsum(
                    "hik,i,bik->bh", self.W0gt[idx], self.gammagt[idx], inputs
                )
                hidden = (hidden + self.B0gt[idx].unsqueeze(0)).relu()
                out = torch.einsum("oh,bh->bo", self.W1gt[idx], hidden)
                out = out + self.B1gt[idx].unsqueeze(0)
                probs = out.softmax(dim=1).unsqueeze(1)  # (bs, 1, N)
                drawn.append(OneHotCategorical(probs).sample())
            batch = torch.cat(drawn, dim=1)
        yield batch
def latent_prior_sample(self, y, n_batch, n_samples):
    """Draw latents from the generative prior p(z2) p(y) p(z1 | z2, y).

    Args:
        y: optional integer labels of length ``n_batch``. When ``None`` the
            labels are drawn uniformly over ``self.n_labels`` classes.
        n_batch: number of items per draw.
        n_samples: number of independent draws.

    Returns:
        dict with ``z1`` (sample of p(z1 | z2, y)), ``z2`` (standard-normal
        draw of shape ``(n_samples, n_batch, n_latent)``) and ``ys`` (one-hot
        labels of shape ``(n_samples, n_batch, n_labels)``).
    """
    n_cat = self.n_labels
    n_latent = self.n_latent
    u = Normal(torch.zeros(n_latent), torch.ones(n_latent)).sample((n_samples, n_batch))
    if y is None:
        uniform = (1.0 / n_cat) * torch.ones(n_cat)
        ys = OneHotCategorical(probs=uniform).sample((n_samples, n_batch))
    else:
        ys = torch.zeros(n_batch, n_cat)
        ys.scatter_(1, y.view(-1, 1), 1)
        ys = ys.view(1, n_batch, n_cat).expand(n_samples, n_batch, n_cat)
    z2_y = torch.cat([u, ys], dim=-1)
    pz1_z2m, pz1_z2_v = self.decoder_z1_z2(z2_y)
    # The decoder's second output is a variance ("_v"); Normal takes a standard
    # deviation, so take the square root (as the inference path does).
    z = Normal(pz1_z2m, pz1_z2_v.sqrt()).sample()
    return dict(z1=z, z2=u, ys=ys)
self.net = nn.Sequential(nn.Linear(s_dim, hidden), nn.ReLU(), nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, z_num)) def forward(self, s, log=False): feature = self.net(s) if log: return F.log_softmax(feature, dim=-1) else: return F.softmax(feature, dim=-1) if __name__ == "__main__": from torch.distributions import Categorical from torch.distributions import OneHotCategorical onehot = OneHotCategorical(torch.ones(4)) s = torch.FloatTensor([1, 2]) z = onehot.sample() #torch.LongTensor([0,0,1,0]) print(s, z) policy = Policy(s_dim=2, z_num=4, hidden=32, a_num=4) vnet = VNet(s_dim=2, z_num=4, hidden=32) qnet = QNet(s_dim=2, z_num=4, hidden=32, a_num=4) dis = Discriminator(s_dim=2, z_num=4, hidden=32) prob = policy(s, z) print(prob) dist = Categorical(prob) a = dist.sample() print(a) index = torch.LongTensor(range(1)) v = vnet(s, z)
def variational_posterior(self, logits: torch.Tensor):
    """Return the one-hot categorical posterior parameterised by ``logits``.

    The logits are passed directly rather than as ``probs=softmax(logits)``.
    Built from ``probs``, PyTorch recovers log-probabilities as
    ``log(clamp(p, eps))``, which floors every log-probability near -15.9
    (float32). That distorts ``log_prob`` and KL terms for confident
    posteriors. The probabilities of the returned distribution are unchanged.
    """
    return OneHotCategorical(logits=logits)
def forward(self, input, targets, args, n_particles, criterion, test=False):
    """Particle-filtered forward pass that computes the losses directly.

    This version takes the inputs and does not expose the logits; it
    evaluates the loss itself.

    Args:
        input: token ids, unpacked as ``(seq_len, batch_sz)``.
        targets: teacher-forcing ids fed through ``self.dec_embedding``.
        args: options object; ``args.anneal`` scales the prior/proposal term.
        n_particles: number of particles per batch element.
        criterion: loss on decoder logits; its per-row output is used as the NLL.
        test: if True draw hard one-hot samples, otherwise relaxed samples.

    Returns:
        ``(-loss.sum(), nll, seq_len * batch_sz, resamples)``.
    """
    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)

    # teacher-forcing
    out_emb = self.dropout(self.dec_embedding(targets))

    # Replicate per particle: [seq_len x batch_sz x nhid] -> [seq_len x (n_particles * batch_sz) x nhid].
    # Row ``p * batch_sz + b`` holds particle ``p`` of batch element ``b``.
    hidden_states = hidden_states.repeat(1, n_particles, 1)
    out_emb = out_emb.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    p_h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)  # initially zero
    h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    d_h = self.init_hidden(batch_sz * n_particles, self.nhid, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    for i in range(seq_len):
        h = self.z_decoder(hidden_states[i], h)
        logits = self.logits(h)

        # proposal q and the next z sample
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()
        h = z

        # prior
        if test:
            p = OneHotCategorical(logits=p_h)
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=p_h)

        # log-likelihood of the data given this z and the teacher-forcing input
        d_h = self.decoder(torch.cat([z, out_emb[i]], 1), d_h)
        decoder_logits = self.out_embedding(d_h)
        NLL = criterion(decoder_logits, input[i].repeat(n_particles))
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + args.anneal * (f_term - r_term)
        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        probs = accumulated_weights.data.exp()
        probs += 0.01
        probs = probs / probs.sum(0, keepdim=True)
        effective_sample_size = 1. / probs.pow(2).sum(0)

        # resample (all batch elements) if any element's relative ESS drops below 0.3
        if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
            resamples += 1
            # ancestors[b, j]: particle index that new particle j of batch element b copies
            ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)
            # Flatten into the particle-major row layout: new row j * batch_sz + b
            # reads old row ancestors[b, j] * batch_sz + b.
            batch_idx = torch.arange(batch_sz, device=ancestors.device).unsqueeze(1)
            unrolled_idx = (ancestors * batch_sz + batch_idx).t().contiguous().view(-1)
            h = torch.index_select(h, 0, unrolled_idx)
            p_h = torch.index_select(p_h, 0, unrolled_idx)
            d_h = torch.index_select(d_h, 0, unrolled_idx)
            # reset accumulated_weights
            accumulated_weights = -math.log(n_particles)

        if i != seq_len - 1:
            # build the next prior prediction, feeding in the correct ancestor
            p_h = self.ar_prior_logits(torch.cat([h, out_emb[i]], 1), p_h)

    # now, we calculate the final log-marginal estimator
    nll = nlls.view(seq_len, n_particles, batch_sz).mean(1).sum()
    return -loss.sum(), nll, (seq_len * batch_sz), resamples
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and print the mean loss per example.

    For the first batch, up to 8 inputs and their reconstructions are saved
    side by side to ``results/reconstruction_<epoch>.png``.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, log_prob, entropy = model(data)
            test_loss += loss_function(recon_batch, data, log_prob, entropy).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape with the actual batch size: a batch can be smaller
                # than args.batch_size (e.g. a small or final batch).
                comparison = torch.cat([data[:n], recon_batch.view(data.size(0), 1, 28, 28)[:n]])
                save_image(comparison.cpu(), 'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Decode 64 samples, each 20 one-hot codes drawn uniformly over 256 categories.
        m = OneHotCategorical(torch.ones(256) / 256.)
        sample = m.sample((64, 20))
        sample = sample.to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28), 'results/sample_' + str(epoch) + '.png')
def sample(self, params):
    """Draw one action from a mixture of diagonal Gaussians.

    ``params`` provides ``'pi'`` (mixture probabilities), ``'mean'`` and
    ``'log_std'``. One component per row is picked with a one-hot categorical
    draw. A Gaussian draw is made for every component, then masked so that
    only the picked one survives the sum over dim 1.
    """
    weights = params['pi']
    mu = params['mean']
    log_sigma = params['log_std']
    chosen = OneHotCategorical(weights).sample().unsqueeze(-1)
    noise = torch.randn_like(mu)
    candidates = mu + noise * log_sigma.exp()
    return (candidates * chosen).sum(1)
# Script exercising storch's RELAX and REBAR gradient estimators.
import storch
import torch
from torch.distributions import Bernoulli, OneHotCategorical
from storch.method import RELAX, REBAR, ARM

torch.manual_seed(0)  # fixed seed so sampled values are reproducible

# RELAX on a single Bernoulli whose probability (0.5) requires grad.
p = torch.tensor(0.5, requires_grad=True)
d = Bernoulli(p)
sample = RELAX("sample", in_dim=1)(d)
# sample = ARM('sample', n_samples=10)(d)
# Register the sample itself as the cost named "cost", then run storch's backward pass.
storch.add_cost(sample, "cost")
storch.backward()

# REBAR applied to four 3-way one-hot categoricals; each row of x is a
# probability vector (rows sum to 1).
method = REBAR("test", n_samples=1)
x = torch.Tensor([[0.2, 0.4, 0.4], [0.5, 0.1, 0.4], [0.2, 0.2, 0.6], [0.15, 0.15, 0.7]])
qx = OneHotCategorical(x)
print(method(qx))
class GaussianMixture(Distribution):
    """Mixture of diagonal Gaussians with per-row mixture weights.

    Args:
        normal_means: component means; component ``i`` is ``normal_means[:, :, i]``.
        normal_stds: component standard deviations, indexed like ``normal_means``.
        weights: mixture weights; ``weights[:, :, 0]`` are the per-row component
            probabilities and ``weights.shape[1]`` is the number of components.
    """

    def __init__(self, normal_means, normal_stds, weights):
        self.num_gaussians = weights.shape[1]
        self.normal_means = normal_means
        self.normal_stds = normal_stds
        self.normal = MultivariateDiagonalNormal(normal_means, normal_stds)
        self.normals = [
            MultivariateDiagonalNormal(normal_means[:, :, i], normal_stds[:, :, i])
            for i in range(self.num_gaussians)
        ]
        self.weights = weights
        self.categorical = OneHotCategorical(self.weights[:, :, 0])

    def log_prob(self, value):
        """Return ``log sum_i w_i N_i(value)``, reducing over components (dim 1).

        Uses the log-sum-exp trick: the component log-densities are shifted by
        their per-row maximum before exponentiating. Otherwise the sum
        underflows to ``log(0) = -inf`` whenever every component density is
        tiny (high-dimensional or far-away values). Weights stay
        multiplicative, so zero weights never go through ``log(0)``.
        """
        # presumably each component log_prob is (batch,), giving (batch, K) — TODO confirm
        log_p = torch.stack(
            [self.normals[i].log_prob(value) for i in range(self.num_gaussians)], -1
        )
        weights = self.weights[:, :, 0]
        # Per-row shift; detached because the result does not depend on it mathematically.
        shift = log_p.max(dim=1, keepdim=True)[0].detach()
        # Rows whose components are all -inf: avoid (-inf) - (-inf) = nan (result stays -inf).
        shift = shift.masked_fill(~torch.isfinite(shift), 0.0)
        p = (torch.exp(log_p - shift) * weights).sum(dim=1)
        return torch.log(p) + shift.squeeze(1)

    def sample(self):
        """Draw per-component Gaussian samples and keep the one chosen by a one-hot draw."""
        z = self.normal.sample().detach()
        c = self.categorical.sample()[:, :, None]
        s = torch.matmul(z, c)
        return torch.squeeze(s, 2)

    def rsample(self):
        """Like ``sample`` but with reparameterised Gaussian draws.

        The component selection is still a (non-differentiable) categorical draw.
        """
        z = (self.normal_means + self.normal_stds * MultivariateDiagonalNormal(
            ptu.zeros(self.normal_means.size()),
            ptu.ones(self.normal_stds.size())).sample())
        z.requires_grad_()
        c = self.categorical.sample()[:, :, None]
        s = torch.matmul(z, c)
        return torch.squeeze(s, 2)

    def mle_estimate(self):
        """Return the mean of the most likely component.

        This often computes the mode of the distribution, but not always.
        """
        c = ptu.zeros(self.weights.shape[:2])
        ind = torch.argmax(self.weights, dim=1)  # [:, 0]
        c.scatter_(1, ind, 1)
        s = torch.matmul(self.normal_means, c[:, :, None])
        return torch.squeeze(s, 2)

    def __repr__(self):
        s = "GaussianMixture(normal_means=%s, normal_stds=%s, weights=%s)"
        return s % (self.normal_means, self.normal_stds, self.weights)
class PPOTorchPolicy(TorchPolicy):
    """PPO policy conditioned on a one-hot skill, with optional DIAYN rewards.

    The skill ``z`` is concatenated to the observation for the policy and
    value networks. A discriminator is trained to predict the skill from the
    observation. When ``use_diayn`` is set, its log-probability becomes an
    intrinsic reward in ``postprocess_trajectory``.
    """

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        self.device = torch.device('cpu')

        # Get hyperparameters
        self.alpha = config['alpha']
        self.clip_ratio = config['clip_ratio']
        self.gamma = config['gamma']
        self.lam = config['lambda']
        self.lr_pi = config['lr_pi']
        self.lr_vf = config['lr_vf']
        self.model_hidden_sizes = config['model_hidden_sizes']
        self.num_skills = config['num_skills']
        self.skill_input = config['skill_input']
        self.target_kl = config['target_kl']  # NOTE(review): currently unused (no KL early stopping)
        self.use_diayn = config['use_diayn']
        self.use_env_rewards = config['use_env_rewards']
        self.use_gae = config['use_gae']

        # Uniform prior over skills, batch shape (1,).
        self.skills = OneHotCategorical(torch.ones((1, self.num_skills)))
        if self.skill_input is not None:
            # Fixed skill: one-hot row vector with a 1 at index ``skill_input``.
            skill_vec = [0.] * (self.num_skills - 1)
            skill_vec.insert(self.skill_input, 1.)
            self.z = torch.as_tensor([skill_vec], dtype=torch.float32)
        else:
            # Sampled lazily at the start of each episode in compute_actions.
            self.z = None

        # Initialize actor-critic model
        self.model = SkilledA2C(observation_space,
                                action_space,
                                hidden_sizes=self.model_hidden_sizes,
                                skills=self.skills).to(self.device)

        # Set up optimizers for policy, value function and discriminator
        self.pi_optimizer = Adam(self.model.pi.parameters(), self.lr_pi)
        self.vf_optimizer = Adam(self.model.vf.parameters(), self.lr_vf)
        self.disc_optimizer = Adam(self.model.discriminator.parameters(), self.lr_vf)

    def compute_loss_d(self, batch):
        """NLL of the true skill under the discriminator's log-probabilities."""
        obs, z = batch[SampleBatch.CUR_OBS], batch[SKILLS]
        logq_z = self.model.discriminator(obs)
        return nn.functional.nll_loss(logq_z, z.argmax(dim=-1))

    def compute_loss_pi(self, batch):
        """Clipped PPO surrogate loss plus diagnostics (approx. KL, entropy, clip fraction)."""
        obs, act, z = batch[SampleBatch.CUR_OBS], batch[ACTIVATIONS], batch[SKILLS]
        adv, logp_old = batch[Postprocessing.ADVANTAGES], batch[SampleBatch.ACTION_LOGP]
        clip_ratio = self.clip_ratio

        # Policy loss
        oz = torch.cat([obs, z], dim=-1)
        pi, logp = self.model.pi(oz, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clip_frac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clip_frac)

        return loss_pi, pi_info

    def compute_loss_v(self, batch):
        """Clipped value-function loss.

        The value network is evaluated on the current observations. VF_PREDS
        were computed on those (normalized) observations in compute_actions,
        and learn_on_batch normalizes CUR_OBS but not NEXT_OBS.
        """
        obs, z = batch[SampleBatch.CUR_OBS], batch[SKILLS]
        v_pred_old, v_targ = batch[SampleBatch.VF_PREDS], batch[Postprocessing.VALUE_TARGETS]

        oz = torch.cat([obs, z], dim=-1)
        v_pred = self.model.vf(oz)
        v_pred_clipped = v_pred_old + torch.clamp(
            v_pred - v_pred_old, -self.clip_ratio, self.clip_ratio)
        loss_clipped = (v_pred_clipped - v_targ).pow(2)
        loss_unclipped = (v_pred - v_targ).pow(2)
        return 0.5 * torch.max(loss_unclipped, loss_clipped).mean()

    def _convert_activation_to_action(self, activation):
        min_ = self.action_space.low
        max_ = self.action_space.high
        return tanh_to_action(activation, min_, max_)

    def _normalize_obs(self, obs):
        min_ = self.observation_space.low
        max_ = self.observation_space.high
        return normalize_obs(obs, min_, max_)

    @override(Policy)
    def compute_actions(self, obs, **kwargs):
        # Sample a skill at the start of each episode
        if self.z is None:
            self.z = self.skills.sample()

        o = self._normalize_obs(obs)
        a, v, logp_a, logq_z = self.model.step(
            torch.as_tensor(o, dtype=torch.float32), self.z)

        actions = self._convert_activation_to_action(a)
        extras = {
            ACTIVATIONS: a,
            SampleBatch.VF_PREDS: v,
            SampleBatch.ACTION_LOGP: logp_a,
            SKILLS: self.z.numpy(),
            SKILL_LOGQ: logq_z
        }
        return actions, [], extras

    @override(Policy)
    def postprocess_trajectory(self, batch, other_agent_batches=None, episode=None):
        """Adds the policy logits, VF preds, and advantages to the trajectory."""
        completed = batch["dones"][-1]
        if completed:
            # Force end of episode reward
            last_r = 0.0
            # Resample the skill next episode, unless a fixed skill was configured.
            if self.skill_input is None:
                self.z = None
        else:
            # Bootstrap with the value of the last next-observation.
            obs = [batch[SampleBatch.NEXT_OBS][-1]]
            o = self._normalize_obs(obs)
            _, last_r, _, _ = self.model.step(
                torch.as_tensor(o, dtype=torch.float32), self.z)
            last_r = last_r.item()

        # Compute DIAYN rewards
        if self.use_diayn:
            z = torch.as_tensor(batch[SKILLS], dtype=torch.float32)
            logp_z = self.skills.log_prob(z).numpy()
            # The skill is constant within a trajectory, so read its index from the first step.
            logq_z = batch[SKILL_LOGQ][:, z.argmax(dim=-1)[0].item()]
            entropy_reg = self.alpha * batch[SampleBatch.ACTION_LOGP]
            diayn_rewards = logq_z - logp_z - entropy_reg

            if self.use_env_rewards:
                batch[SampleBatch.REWARDS] += diayn_rewards
            else:
                batch[SampleBatch.REWARDS] = diayn_rewards

        batch = compute_advantages(batch,
                                   last_r,
                                   gamma=self.gamma,
                                   lambda_=self.lam,
                                   use_gae=self.use_gae)
        return batch

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        postprocessed_batch[SampleBatch.CUR_OBS] = self._normalize_obs(
            postprocessed_batch[SampleBatch.CUR_OBS])
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        # One gradient step for the policy
        self.pi_optimizer.zero_grad()
        loss_pi, pi_info = self.compute_loss_pi(train_batch)
        loss_pi.backward()
        self.pi_optimizer.step()

        # Value function learning
        self.vf_optimizer.zero_grad()
        loss_v = self.compute_loss_v(train_batch)
        loss_v.backward()
        self.vf_optimizer.step()

        # Discriminator learning
        self.disc_optimizer.zero_grad()
        loss_d = self.compute_loss_d(train_batch)
        loss_d.backward()
        self.disc_optimizer.step()

        grad_info = dict(pi_loss=loss_pi.item(),
                         vf_loss=loss_v.item(),
                         d_loss=loss_d.item(),
                         **pi_info)
        return {LEARNER_STATS_KEY: grad_info}
def forward(self, context=None):
    """Return a one-hot categorical over logits computed from ``context``.

    With a ``context``, ``self.network(context)`` is evaluated and cached on
    ``self.logits``. Without one, the logits cached by the previous call are
    reused.
    """
    if context is None:
        return OneHotCategorical(logits=self.logits)
    self.logits = self.network(context)
    return OneHotCategorical(logits=self.logits)
def forward(self, x):
    """Wrap ``self.categories_net(x)`` as the logits of a one-hot categorical."""
    category_logits = self.categories_net(x)
    distribution = OneHotCategorical(logits=category_logits)
    return distribution
def forward(self, input, args, n_particles, test=False): T = nn.Softmax(dim=0)(self.T) # NOTE: not in log-space pi = nn.Softmax(dim=0)(self.pi) emit = self.calc_emit() # run the input and teacher-forcing inputs through the embedding layers here seq_len, batch_sz = input.size() emb = self.inp_embedding(input) hidden = self.init_hidden(batch_sz, self.nhid, 2) # bidirectional hidden_states, (_, _) = self.encoder(emb, hidden) hidden_states = hidden_states.repeat(1, n_particles, 1) # run the z-decoder at this point, evaluating the NLL at each step z = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True) nlls = hidden_states.data.new(seq_len, batch_sz * n_particles) loss = 0 # now a log-prob prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim) logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True) for i in range(seq_len): # logits = self.logits(torch.cat([hidden_states[i], h], 1)) # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits))) logits = self.logits( nn.functional.relu( self.z_decoder(torch.cat([hidden_states[i], z], 1), logits))) # build the next z sample if test: q = OneHotCategorical(logits=logits) z = q.sample() else: q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits) z = q.rsample() lse = log_sum_exp(logits, dim=1).view(-1, 1) log_probs = logits - lse # now, compute the log-likelihood of the data given this z-sample # so emission is [batch_sz x z_dim], i.e. emission[i, j] is the log-probability of getting this # data for element i given choice z emission = F.embedding(input[i].repeat(n_particles), emit) NLL = -log_sum_exp(emission + log_probs, 1) nlls[i] = NLL.data KL = (log_probs.exp() * (log_probs - (prior_probs + 1e-16).log())).sum(1) loss += (NLL + KL) if i != seq_len - 1: prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2) # now, we calculate the final log-marginal estimator return loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), 0
def forward(self, input, args, n_particles, test=False):
    """Particle-filtered forward pass for an HMM-style model with a learned proposal.

    ``T`` and ``pi`` are log-softmax normalised over dim 0: column ``j`` of
    ``exp(T)`` is the next-state distribution given state ``j``. At each step
    the proposal ``q`` draws ``z`` and each particle is reweighted by
    ``-NLL + log p(z) - log q(z)``. The running log-marginal increment ``Z``
    is accumulated into the loss. When ``args.filter`` is set, all particles
    are resampled whenever any batch element's normalised effective sample
    size drops below 0.3.

    Returns:
        ``(-loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, resamples)``.
    """
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
    emit = self.calc_emit()

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    # Row ``p * batch_sz + b`` holds particle ``p`` of batch element ``b``.
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)

    loss = 0
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    # prior logits (log-probabilities) for the current step
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    # proposal logits; also used as the recurrent state of self.z_decoder
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    for i in range(seq_len):
        logits = self.logits(
            nn.functional.relu(
                self.z_decoder(torch.cat([hidden_states[i], h], 1), logits)))

        # build the next z sample
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()
        h = z

        # prior
        if test:
            p = OneHotCategorical(logits=prior_probs)
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_probs)

        # now, compute the log-likelihood of the data given this z-sample
        NLL = -self.decode(z, input[i].repeat(n_particles), (emit, ))  # diff. w.r.t. z
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)
        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        if args.filter:
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)  # probs is [n_particles, batch_sz]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                # ancestors[b, j]: particle copied into new particle j of batch element b
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)
                # Flatten into the particle-major row layout: new row j * batch_sz + b
                # reads old row ancestors[b, j] * batch_sz + b.
                batch_idx = torch.arange(batch_sz, device=ancestors.device).unsqueeze(1)
                unrolled_idx = (ancestors * batch_sz + batch_idx).t().contiguous().view(-1)
                # All per-particle recurrent state must follow its ancestor.
                h = torch.index_select(h, 0, unrolled_idx)
                logits = torch.index_select(logits, 0, unrolled_idx)
                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)

        if i != seq_len - 1:
            # log p(z_t = k | z_{t-1}) = log sum_j exp(T[k, j]) * z_{t-1}[j].
            # h (the possibly resampled z) is a probability vector, so mix in
            # probability space and take the log afterwards.
            prior_probs = (torch.matmul(h, T.exp().t()) + 1e-16).log()

    # now, we calculate the final log-marginal estimator
    return -loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), resamples
def get_actions(self, obs, prev_actions, actor_rnn_states, available_actions=None, use_target=False, t_env=None, use_gumbel=False, explore=False):
    """Compute actions from the actor (or target actor), optionally with exploration.

    Args:
        obs: either ``(batch_size, obs_dim)`` or ``(seq_len, batch_size, obs_dim)``.
        prev_actions: previous actions with the same rank as ``obs``, or None.
        actor_rnn_states: recurrent states passed to the actor.
        available_actions: optional availability mask for discrete actions.
        use_target: use ``self.target_actor``; in the discrete case this also
            uses a hard Gumbel-softmax, in the continuous case target noise.
        t_env: environment step used to evaluate the exploration schedule.
        use_gumbel: use a hard Gumbel-softmax for discrete actions.
        explore: epsilon-greedy (discrete) or Gaussian noise (continuous);
            only valid for non-sequence inputs.

    Returns:
        ``(actions, new_rnn_states, eps)``. ``eps`` is the epsilon used for
        discrete exploration, otherwise None. NOTE: the single-discrete
        exploration branch returns a numpy array; the others return tensors.
    """
    assert prev_actions is None or len(obs.shape) == len(prev_actions.shape)
    # obs is either an array of shape (batch_size, obs_dim) or (seq_len, batch_size, obs_dim)
    if len(obs.shape) == 2:
        batch_size = obs.shape[0]
        no_sequence = True
    else:
        batch_size = obs.shape[1]
        no_sequence = False

    eps = None
    if use_target:
        actor_out, new_rnn_states = self.target_actor(obs, prev_actions, actor_rnn_states)
    else:
        actor_out, new_rnn_states = self.actor(obs, prev_actions, actor_rnn_states)

    if self.discrete_action:
        if self.multidiscrete:
            if use_gumbel or explore or use_target:
                onehot_actions = [gumbel_softmax(a, hard=True) for a in actor_out]
            else:
                onehot_actions = [onehot_from_logits(a) for a in actor_out]
            onehot_actions = torch.cat(onehot_actions, dim=-1)

            if explore:
                # Same restriction as the single-discrete branch: batch_size above
                # is only the first obs dimension for non-sequence inputs.
                assert no_sequence, "Doesn't make sense to do exploration on a sequence!"
                # eps greedy exploration
                eps = self.exploration.eval(t_env)
                rand_numbers = torch.rand((batch_size, 1))
                take_random = (rand_numbers < eps).int().view(-1, 1)
                # random actions sample uniformly from action space
                random_actions = [
                    OneHotCategorical(logits=torch.ones(batch_size, dim)).sample()
                    for dim in self.act_dim
                ]
                random_actions = torch.cat(random_actions, dim=1)
                actions = (1 - take_random) * onehot_actions + take_random * random_actions
            else:
                actions = onehot_actions
        else:
            if use_gumbel or explore or use_target:
                onehot_actions = gumbel_softmax(actor_out, available_actions, hard=True)  # gumbel has a gradient
            else:
                onehot_actions = onehot_from_logits(actor_out, available_actions)  # no gradient

            if explore:
                assert no_sequence, "Doesn't make sense to do exploration on a sequence!"
                # eps greedy exploration
                eps = self.exploration.eval(t_env)
                rand_numbers = np.random.rand(batch_size, 1)
                # random actions sample uniformly from action space
                logits = torch.ones(batch_size, self.act_dim)
                random_actions = avail_choose(logits, available_actions).sample()
                random_actions = make_onehot(random_actions, batch_size, self.act_dim)
                take_random = (rand_numbers < eps).astype(float)
                actions = (1.0 - take_random) * onehot_actions.detach().cpu().numpy() \
                    + take_random * random_actions.cpu().numpy()
            else:
                actions = onehot_actions
    else:
        if explore:
            assert no_sequence, "Cannot do exploration on a sequence!"
            actions = gaussian_noise(actor_out.shape, self.args.act_noise_std) + actor_out
        elif use_target:
            target_noise = gaussian_noise(actor_out.shape, self.args.target_noise_std).clamp(
                -self.args.target_noise_clip, self.args.target_noise_clip)
            actions = actor_out + target_noise
        else:
            actions = actor_out

    return actions, new_rnn_states, eps
def sampled_filter(self, input, args, n_particles, emb, hidden_states):
    """Particle filter over HMM states with a per-step proposal from ``hidden_states``.

    ``T`` and ``pi`` are log-softmax normalised over dim 0: column ``j`` of
    ``exp(T)`` is the next-state distribution given state ``j``. In training
    mode, relaxed samples are used and ``backward`` is called here on the
    per-element loss. In eval mode, hard one-hot samples are used. When
    ``args.filter`` is set, particles are resampled whenever any batch
    element's normalised effective sample size drops below 0.3.

    Args:
        input: token ids, unpacked as ``(seq_len, batch_sz)``.
        args: options object; ``args.filter`` enables resampling.
        n_particles: particles per batch element.
        emb: not used by this method.
        hidden_states: per-step encoder states, replicated per particle here.

    Returns:
        ``(-loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, resamples)``.
    """
    seq_len, batch_sz = input.size()
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
    emit = self.calc_emit()

    # Row ``p * batch_sz + b`` holds particle ``p`` of batch element ``b``.
    hidden_states = hidden_states.repeat(1, n_particles, 1)
    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    # in log probability space
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)

    for i in range(seq_len):
        # the approximate posterior comes from the same thing as before
        logits = self.logits(hidden_states[i])

        if not self.training:
            # this is crucial!! (hard one-hot samples at evaluation time)
            p = OneHotCategorical(logits=prior_logits)
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()

        # now, compute the log-likelihood of the data given this z-sample
        emission = F.embedding(input[i].repeat(n_particles), emit)
        NLL = -(emission * z).sum(1)
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        # sample ancestors, and reindex everything
        if args.filter:
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)  # probs is [n_particles, batch_sz]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                # ancestors[b, j]: particle copied into new particle j of batch element b
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)
                # Flatten into the particle-major row layout: new row j * batch_sz + b
                # reads old row ancestors[b, j] * batch_sz + b.
                batch_idx = torch.arange(batch_sz, device=ancestors.device).unsqueeze(1)
                unrolled_idx = (ancestors * batch_sz + batch_idx).t().contiguous().view(-1)
                z = torch.index_select(z, 0, unrolled_idx)
                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)

        if i != seq_len - 1:
            # log p(z_t = k | z_{t-1}) = log sum_j exp(T[k, j]) * z_{t-1}[j].
            # z is a probability vector, so mix in probability space, then take the log.
            prior_logits = (torch.matmul(z, T.exp().t()) + 1e-16).log()

    if self.training:
        (-loss.sum() / (seq_len * batch_sz * n_particles)).backward(retain_graph=True)
    return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, resamples
def sample(self, obs, prev_acts, rnn_hidden_states, available_actions=None, sample_gumbel=False):
    """Sample one-hot actions (optionally multi-discrete) from the actor's logits.

    Args:
        obs, prev_acts, rnn_hidden_states: forwarded to ``self.forward``.
        available_actions: optional mask (tensor or numpy array); entries equal
            to 0 are excluded from sampling and from the greedy action.
        sample_gumbel: draw a differentiable hard Gumbel-softmax sample instead
            of a categorical sample.

    Returns:
        ``(sampled_actions, mean_action_logprobs, max_prob_actions, h_outs)``.
        ``mean_action_logprobs`` is ``sum_a p(a) log p(a)`` per action head,
        computed from the unmasked logits.
    """
    # TODO: review this method
    act_logits, h_outs = self.forward(obs, prev_acts, rnn_hidden_states)

    if self.multidiscrete:
        sampled_actions = []
        mean_action_logprobs = []
        max_prob_actions = []
        for act_logit in act_logits:
            categorical = OneHotCategorical(logits=act_logit)

            all_action_prob = categorical.probs
            # avoid log(0) for actions with exactly zero probability
            eps = (all_action_prob == 0.0) * 1e-6
            all_action_logprob = torch.log(all_action_prob + eps.float().detach())
            mean_action_logprob = (all_action_logprob * all_action_prob).sum(dim=-1).unsqueeze(-1)

            if sample_gumbel:
                # get a differentiable sample of the action
                sampled_action = gumbel_softmax(act_logit, hard=True)
            else:
                sampled_action = categorical.sample()

            max_prob_action = onehot_from_logits(act_logit)

            sampled_actions.append(sampled_action)
            mean_action_logprobs.append(mean_action_logprob)
            max_prob_actions.append(max_prob_action)

        sampled_actions = torch.cat(sampled_actions, dim=-1)
        mean_action_logprobs = torch.cat(mean_action_logprobs, dim=-1)
        max_prob_actions = torch.cat(max_prob_actions, dim=-1)

        return sampled_actions, mean_action_logprobs, max_prob_actions, h_outs

    categorical = OneHotCategorical(logits=act_logits)

    all_action_probs = categorical.probs
    # avoid log(0) for actions with exactly zero probability
    eps = (all_action_probs == 0.0) * 1e-6
    all_action_logprobs = torch.log(all_action_probs + eps.float().detach())
    mean_action_logprobs = (all_action_logprobs * all_action_probs).sum(dim=-1).unsqueeze(-1)

    if sample_gumbel:
        # get a differentiable sample of the action
        sampled_actions = gumbel_softmax(act_logits, available_actions, hard=True)
    else:
        if available_actions is not None:
            if isinstance(available_actions, np.ndarray):
                available_actions = torch.from_numpy(available_actions)
            # Mask out-of-place: the original act_logits is saved by autograd
            # (logsumexp inside the categorical above), so writing into it in
            # place would break backward through mean_action_logprobs.
            unavailable = (available_actions == 0).to(act_logits.device)
            act_logits = act_logits.masked_fill(unavailable, -1e10)
            sampled_actions = OneHotCategorical(logits=act_logits).sample()
        else:
            sampled_actions = categorical.sample()

    max_prob_actions = onehot_from_logits(act_logits, available_actions)
    return sampled_actions, mean_action_logprobs, max_prob_actions, h_outs
def forward(self, x):
    """Return a one-hot categorical whose logits are ``self.fc(self.body(x))``."""
    features = self.body(x)
    return OneHotCategorical(logits=self.fc(features))
def inference(
    self,
    x,
    y=None,
    temperature=None,
    n_samples=1,
    reparam=True,
    encoder_key="default",
    counts=None,
):
    """Sample z1, c and z2 and evaluate the generative and variational log-densities.

    Dimension choice
        (n_categories, n_is, n_batch, n_latent)

        log_q
        (n_categories, n_is, n_batch)

    Args:
        x: observed batch; ``len(x)`` is the batch size.
        y: optional integer labels; when None, labels are sampled from q(c | z1).
        temperature: relaxation temperature (required).
        n_samples: number of importance samples for z1.
        reparam: use reparameterised (relaxed) sampling.
        encoder_key: which encoder/classifier head to use.
        counts: if given, delegate to ``self.inference_defensive_sampling``.

    Returns:
        dict of latents, distribution parameters and log-densities.

    Raises:
        ValueError: if ``temperature`` is None.
    """
    if temperature is None:
        raise ValueError(
            "Please provide a temperature for the relaxed OneHot distribution"
        )

    if counts is not None:
        return self.inference_defensive_sampling(
            x=x, y=y, temperature=temperature, counts=counts
        )
    n_cat = self.n_labels
    n_batch = len(x)
    # Z | X
    inp = x
    q_z1 = self.encoder_z1[encoder_key](
        inp, n_samples=n_samples, reparam=reparam, squeeze=False
    )
    qz1_m = q_z1["q_m"]
    qz1_v = q_z1["q_v"]
    z1 = q_z1["latent"]
    assert z1.dim() == 3
    log_qz1_x = q_z1["dist"].log_prob(z1)
    dfs = q_z1.get("df", None)
    if q_z1["sum_last"]:
        log_qz1_x = log_qz1_x.sum(-1)
    z1s = z1

    # C | Z
    qc_z1 = self.classifier[encoder_key](z1)
    log_qc_z1 = qc_z1.log()
    qc_z1_all_probas = qc_z1
    # C
    if y is None:
        if reparam:
            cat_dist = RelaxedOneHotCategorical(
                temperature=temperature, probs=qc_z1
            )
            ys_probs = cat_dist.rsample()
        else:
            cat_dist = OneHotCategorical(probs=qc_z1)
            ys_probs = cat_dist.sample()
        # Harden the (possibly relaxed) sample into a one-hot vector.
        ys = (ys_probs == ys_probs.max(-1, keepdim=True).values).float()
        y_int = ys.argmax(-1)
    else:
        # Allocate on the latents' device instead of hard-coding CUDA, so the
        # model also runs on CPU; float32 as before.
        ys = torch.zeros(n_batch, n_cat, dtype=torch.float32, device=z1.device)
        ys.scatter_(1, y.view(-1, 1), 1)
        ys = ys.view(1, n_batch, n_cat).expand(n_samples, n_batch, n_cat)
        y_int = y.view(1, -1).expand(n_samples, n_batch)
    log_pc = self.y_prior.log_prob(y_int)
    assert y_int.unsqueeze(-1).shape == (n_samples, n_batch, 1), y_int.shape
    log_qc_z1 = torch.gather(log_qc_z1, dim=-1, index=y_int.unsqueeze(-1)).squeeze(
        -1
    )
    qc_z1 = torch.gather(qc_z1, dim=-1, index=y_int.unsqueeze(-1)).squeeze(-1)
    assert qc_z1.shape == (n_samples, n_batch)
    pc = log_pc.exp()

    # U | Z1, C
    z1_y = torch.cat([z1s, ys], dim=-1)
    q_z2_z1 = self.encoder_z2_z1[encoder_key](z1_y, n_samples=1, reparam=reparam)
    z2 = q_z2_z1["latent"]
    qz2_z1_m = q_z2_z1["q_m"]
    qz2_z1_v = q_z2_z1["q_v"]
    log_qz2_z1 = q_z2_z1["dist"].log_prob(z2)
    if q_z2_z1["sum_last"]:
        log_qz2_z1 = log_qz2_z1.sum(-1)
    z2_y = torch.cat([z2, ys], dim=-1)
    pz1_z2m, pz1_z2_v = self.decoder_z1_z2(z2_y)
    # the decoder's second output is a variance: Normal takes its square root
    log_pz1_z2 = Normal(pz1_z2m, pz1_z2_v.sqrt()).log_prob(z1).sum(-1)
    log_pz2 = Normal(torch.zeros_like(z2), torch.ones_like(z2)).log_prob(z2).sum(-1)

    px_z_loc = self.x_decoder(z1)
    log_px_z = Bernoulli(px_z_loc).log_prob(x).sum(-1)
    generative_density = log_pz2 + log_pc + log_pz1_z2 + log_px_z
    variational_density = log_qz1_x + log_qz2_z1
    log_ratio = generative_density - variational_density

    variables = dict(
        z1=z1,
        ys=ys,
        z2=z2,
        qz1_m=qz1_m,
        qz1_v=qz1_v,
        qz2_z1_m=qz2_z1_m,
        qz2_z1_v=qz2_z1_v,
        pz1_z2m=pz1_z2m,
        pz1_z2_v=pz1_z2_v,
        px_z_m=px_z_loc,
        log_qz1_x=log_qz1_x,
        qc_z1=qc_z1,
        log_qc_z1=log_qc_z1,
        log_qz2_z1=log_qz2_z1,
        log_pz2=log_pz2,
        log_pc=log_pc,
        pc=pc,
        log_pz1_z2=log_pz1_z2,
        log_px_z=log_px_z,
        generative_density=generative_density,
        variational_density=variational_density,
        log_ratio=log_ratio,
        qc_z1_all_probas=qc_z1_all_probas,
        df=dfs,
    )
    return variables
# Compute a reference gradient with storch's Expect method on a 6-way one-hot
# categorical, and define a helper that compares sampled gradients against it.
import storch
import torch
from torch.distributions import Bernoulli, OneHotCategorical

# NOTE(review): Expect presumably evaluates the exact expectation (enumerating
# all outcomes), which is why its gradient serves as the reference — confirm.
expect = storch.method.Expect("x")
probs = torch.tensor([0.95, 0.01, 0.01, 0.01, 0.01, 0.01], requires_grad=True)
indices = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
b = OneHotCategorical(probs=probs)
z = expect.sample(b)
# Cost is 2.4 times the 1-based index of the sampled category.
c = (2.4 * z * indices).sum(-1)
storch.add_cost(c, "no_baseline_cost")
storch.backward()
# Reference gradient w.r.t. the categorical's "probs" parameter.
expect_grad = z.grad["probs"].clone()


# NOTE(review): this name shadows the builtin ``eval`` at module level.
def eval(grads):
    """Print gradient-estimator diagnostics against ``expect_grad`` and return the bias.

    Args:
        grads: sampled gradients, gathered over a "variance" plate.

    Returns:
        sum over elements of the squared difference between the plate-mean
        gradient and ``expect_grad``.
    """
    print("----------------------------------")
    grad_samples = storch.gather_samples(grads, "variance")
    mean = storch.reduce_plates(grad_samples, plates=["variance"])
    print("mean grad", mean)
    print("expected grad", expect_grad)
    print("specific_diffs", (mean - expect_grad)**2)
    # mean squared error of individual gradient samples
    mse = storch.reduce_plates((grad_samples - expect_grad)**2).sum()
    print("MSE", mse)
    # squared bias of the averaged estimate
    bias = (storch.reduce_plates((mean - expect_grad)**2)).sum()
    print("bias", bias)
    return bias
def distribution(self, output_net):
    """Wrap raw network output as a one-hot categorical distribution.

    Args:
        output_net: unnormalised log-probabilities (logits); the last
            dimension indexes the categories.

    Returns:
        A ``OneHotCategorical`` parameterised by those logits.
    """
    categorical = OneHotCategorical(logits=output_net)
    return categorical
def forward(self, input, args, n_particles, test=False):
    """
    The major difference is that now we use a GRU to predict the prior z
    logits, instead of using a linear map T. I think trying to fit this GRU
    is really hard, I'm kind of concerned.

    Args:
        input: token ids, unpacked as ``seq_len, batch_sz = input.size()``.
        args: not used in this method.
        n_particles: ignored; overridden to 10 when ``test`` is true, else 1.
        test: evaluation flag; with 10 particles the loss/NLL are combined
            with ``log_sum_exp`` over particles.

    Returns:
        ``(loss.sum(), NLL.sum(), seq_len * batch_sz, 0)``.
    """
    if test:
        n_particles = 10
    else:
        n_particles = 1
    pi = F.log_softmax(self.pi, 0)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    # Particles are tiled along dim 1: flat row p * batch_sz + b is particle p
    # of batch element b (matching the .view(n_particles, batch_sz) below).
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_())
    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # prior over the first z, as log-probabilities
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    # Allocate on the encoder output's device/dtype instead of hard-coding
    # .cuda(), so this also runs on CPU or a non-default GPU.
    prior_h = Variable(
        hidden_states.data.new(batch_sz * n_particles, 50).zero_())
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None

    # use dropout on the teacher-forcing
    x_emb = self.lockdrop(emb, self.dropout_x)

    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the
        # inference network (its logits are detached)
        logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
        z = OneHotCategorical(logits=logits).sample()

        # this should be batch_sz x x_dim
        scores = torch.mm(self.project(torch.cat([h, z], 1)), self.emit.t())

        NLL = nn.CrossEntropyLoss(reduce=False)(
            scores, input[i].repeat(n_particles))
        # analytic KL(q || prior) between the two categoricals (log-space)
        KL = (logits.exp() * (logits - prior_logits)).sum(1)
        loss += (NLL + KL)

        nlls[i] = NLL.data

        # set things up for next time
        if i != seq_len - 1:
            feed = torch.cat(
                [emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
            prior_h = self.z_decoder(feed, prior_h)
            prior_logits = F.log_softmax(self.project_z(prior_h), 1)
            h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    if n_particles != 1:
        loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
        NLL = -log_sum_exp(
            -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                n_particles)  # not quite accurate, but what can you do
    else:
        NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
# Shape check for storch's ScoreFunctionWOR sampling.
# NOTE(review): `k` (presumably the sample budget) and `d_yv` (presumably the
# number of categories) are read but not defined here; they must be bound
# earlier. The logit vectors below have 4 entries, so the assertions
# presumably rely on d_yv == 4.
event = 2  # d1 is a batch of `event` categoricals (see l_entropy.repeat below)
plt_n1 = 3
plt_n2 = 2  # NOTE(review): not used in this snippet
# Define swr method: ScoreFunctionWOR on plate "z" with `k` samples
# (WOR presumably = without replacement; confirm in storch docs).
swr_method = storch.method.ScoreFunctionWOR("z", k, biased=True, use_baseline=False)
normal_method1 = storch.method.ScoreFunction("n1", n_samples=plt_n1)  # NOTE(review): unused in this snippet
# Low-entropy logits: category 2 (logit 2) dominates.
l_entropy = torch.tensor([-3.0, -3.0, 2, -2.0], requires_grad=True)
# Near-uniform (high-entropy) logits.
h_entropy = torch.tensor([-0.1, 0.1, 0.05, -0.05], requires_grad=True)
n_params = torch.tensor(0.0, requires_grad=True)
# `event` rows of the same logits -> batch shape (event,).
d1 = OneHotCategorical(logits=l_entropy.repeat((event, 1)))
d2 = OneHotCategorical(logits=h_entropy)
dn1 = Normal(n_params, 1.0)  # NOTE(review): unused in this snippet
# k x event x |D_yv|
z_1 = swr_method.sample(d1)
# k x |D_yv|
z_2 = swr_method.sample(d2)
print("z1", z_1)
print("z2", z_2)
# Both samples go through the same swr_method; the leading dimension is capped
# by the number of distinct joint outcomes (presumably storch enumerates
# z_1 and z_2 jointly on plate "z"; confirm).
assert z_1.shape == (min(k, d_yv ** event), event, d_yv)
assert z_2.shape == (min(k, d_yv ** (event + 1)), d_yv)
def forward(self, input, args, n_particles, test=False):
    """
    evaluation is the IWAE-10 bound

    During training, z is a reparameterised RelaxedOneHotCategorical
    (Gumbel-softmax) sample, so gradients reach the inference network. At
    test time a hard OneHotCategorical sample is drawn with 10 particles, and
    loss/NLL are combined with ``log_sum_exp`` over particles.

    Args:
        input: token ids, unpacked as ``seq_len, batch_sz = input.size()``.
        args: not used in this method.
        n_particles: ignored; overridden to 10 when ``test`` is true, else 1.
        test: evaluation flag.

    Returns:
        ``(loss.sum(), NLL.sum(), seq_len * batch_sz, 0)``.
    """
    if test:
        n_particles = 10
    else:
        n_particles = 1
    pi = F.log_softmax(self.pi, 0)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    # Particles are tiled along dim 1: flat row p * batch_sz + b is particle p
    # of batch element b (matching the .view(n_particles, batch_sz) below).
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step;
    # h is an (h, c) pair consumed by self.hidden_rnn
    h = (Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()),
         Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()))
    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # prior over the first z, as log-probabilities
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    # Allocate on the encoder output's device/dtype instead of hard-coding
    # .cuda(), so this also runs on CPU or a non-default GPU.
    prior_h = (Variable(hidden_states.data.new(batch_sz * n_particles, 50).zero_()),
               Variable(hidden_states.data.new(batch_sz * n_particles, 50).zero_()))
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None
    x_emb = self.lockdrop(emb, self.dropout_x)

    for i in range(seq_len):
        # build the next z sample; the logits are NOT detached, so the
        # inference network is trained through the relaxed sample below
        logits = F.log_softmax(self.logits(hidden_states[i]), 1)
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()

        # this should be batch_sz x x_dim
        scores = torch.mm(self.project(torch.cat([h[0], z], 1)), self.emit.t())

        NLL = nn.CrossEntropyLoss(reduce=False)(
            scores, input[i].repeat(n_particles))
        # analytic KL(q || prior) between the categoricals (log-space)
        KL = (logits.exp() * (logits - prior_logits)).sum(1)
        loss += (NLL + KL)

        nlls[i] = NLL.data

        # set things up for next time
        if i != seq_len - 1:
            feed = torch.cat(
                [emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
            prior_h = self.z_decoder(feed, prior_h)
            prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
            h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    if n_particles != 1:
        loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
        NLL = -log_sum_exp(
            -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                n_particles)  # not quite accurate, but what can you do
    else:
        NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
def __init__(self, probs: torch.Tensor, sections: Tuple):
    """Split ``probs`` along its last dimension into one-hot categoricals.

    Args:
        probs: probabilities whose last dimension concatenates every group.
        sections: group sizes along the last dimension, in the form
            accepted by ``torch.split``.
    """
    self._sections = sections
    chunks = torch.split(probs, sections, dim=-1)
    self._dists = [OneHotCategorical(probs=chunk) for chunk in chunks]
def forward(self, input, args, n_particles, test=False):
    """
    Sequential Monte Carlo pass over the sequence, using ``n_particles`` as given.

    At each step a hard z is drawn from the (detached) inference network q.
    Particles are weighted by ``-NLL + log p(z) - log q(z)``, and the
    per-step log normaliser ``Z`` is accumulated into ``loss``. All
    particles are resampled whenever any batch element's effective sample
    size falls below 30% of ``n_particles``.

    Args:
        input: token ids, unpacked as ``seq_len, batch_sz = input.size()``.
        args: not used in this method.
        n_particles: number of particles per batch element.
        test: accepted for interface compatibility; not used.

    Returns:
        ``(loss.sum(), NLL.sum(), seq_len * batch_sz * n_particles, 0)``.

    Raises:
        RuntimeError: if the normalised particle weights contain NaNs.
    """
    pi = F.log_softmax(self.pi, 0)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    # Particles are tiled along dim 1 (particle-major): flat row
    # p * batch_sz + b is particle p of batch element b. This matches the
    # .view(n_particles, batch_sz) used for the weights below.
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step;
    # h is an (h, c) pair consumed by self.hidden_rnn
    h = (Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()),
         Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()))
    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # prior over the first z, as log-probabilities
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    # Allocate on the encoder output's device/dtype instead of hard-coding .cuda().
    prior_h = (Variable(hidden_states.data.new(batch_sz * n_particles, 50).zero_()),
               Variable(hidden_states.data.new(batch_sz * n_particles, 50).zero_()))
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None
    x_emb = self.lockdrop(emb, self.dropout_x)

    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the
        # inference network
        logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
        q = OneHotCategorical(logits=logits)
        p = OneHotCategorical(logits=prior_logits)
        a = q.sample()

        # to guard against being too crazy
        smoothed = a + 1e-16
        z = smoothed / smoothed.sum(1, keepdim=True)

        # this should be batch_sz x x_dim
        scores = torch.mm(self.project(torch.cat([h[0], z], 1)), self.emit.t())
        NLL = nn.CrossEntropyLoss(reduce=False)(
            scores, input[i].repeat(n_particles))
        nlls[i] = NLL.data

        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        # NOTE(review): Z is a log-normaliser (log-likelihood increment), so
        # `loss` grows with likelihood; confirm the caller negates it before
        # minimising.
        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        # probs is [n_particles, batch_sz]; smoothed, renormalised weights
        probs = accumulated_weights.data.exp()
        probs += 0.01
        probs = probs / probs.sum(0, keepdim=True)
        effective_sample_size = 1. / probs.pow(2).sum(0)

        if any_nans(probs):
            raise RuntimeError(
                "NaN in particle weights at step %d" % i)

        # resample / RSAMP
        if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
            # ancestors[b, j]: which old particle new particle j of batch
            # element b copies, shape [batch_sz, n_particles]
            ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)
            batch_offsets = torch.arange(batch_sz).long().unsqueeze(1)
            if ancestors.is_cuda:
                batch_offsets = batch_offsets.cuda()
            # Source row of old particle ancestors[b, j] in the particle-major
            # layout is ancestors[b, j] * batch_sz + b. Transpose so that new
            # particle j of batch b lands at row j * batch_sz + b.
            unrolled_idx = Variable(
                (ancestors * batch_sz + batch_offsets).t().contiguous().view(-1))

            # shuffle!
            z = torch.index_select(z, 0, unrolled_idx)
            h_first, h_second = h
            h = (torch.index_select(h_first, 0, unrolled_idx),
                 torch.index_select(h_second, 0, unrolled_idx))
            prior_first, prior_second = prior_h
            prior_h = (torch.index_select(prior_first, 0, unrolled_idx),
                       torch.index_select(prior_second, 0, unrolled_idx))

            # reset accumulated_weights
            accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        # set things up for next time
        if i != seq_len - 1:
            feed = torch.cat(
                [emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
            prior_h = self.z_decoder(feed, prior_h)
            prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
            h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz * n_particles), 0
def prior(self, posterior: Distribution):
    """Return a uniform one-hot categorical prior shaped like ``posterior``.

    Args:
        posterior: a distribution exposing ``probs``; its last dimension is
            taken as the number of categories.

    Returns:
        ``OneHotCategorical`` with equal probability on every category and
        the same batch shape as ``posterior``.
    """
    probs = posterior.probs
    # Divide by the real category count rather than a fixed constant, so the
    # probabilities form a valid distribution for any number of categories.
    num_categories = probs.size(-1)
    return OneHotCategorical(probs=torch.ones_like(probs) / num_categories)
def forward(self, input, args, n_particles, test=False):
    """
    n_particles is interpreted as 1 for now to not screw anything up

    NOTE(review): the passed-in n_particles is always overridden (10 when
    ``test`` is true, otherwise 1), so the sentence above holds only in
    training mode.

    The prior over z is a first-order Markov chain. The first step uses
    softmax(self.pi). Each later step uses the column of
    softmax(self.T, dim=0) selected by the previous one-hot z (assumes
    self.T is (z_dim, z_dim); confirm).

    Returns:
        (loss.sum(), NLL.sum(), seq_len * batch_sz, 0)
    """
    if test:
        n_particles = 10
    else:
        n_particles = 1
    T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
    pi = nn.Softmax(dim=0)(self.pi)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    # tile particles along dim 1: flat row p * batch_sz + b (see .view below)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_())
    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # prior over the first z, in probability space (NOT log-probs)
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None
    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the
        # inference network (its logits are detached)
        logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
        z = OneHotCategorical(logits=logits).sample()

        # this should be batch_sz x x_dim
        feed = self.project(torch.cat([h, z], 1))  # batch_sz x hidden_dim
        scores = torch.mm(feed, self.emit.t())  # batch_sz x x_dim

        NLL = nn.CrossEntropyLoss(reduce=False)(
            scores, input[i].repeat(n_particles))
        # analytic KL(q || prior); 1e-16 keeps log() finite for zero prior mass
        KL = (logits.exp() * (logits - (prior_probs + 1e-16).log())).sum(1)
        loss += (NLL + KL)

        nlls[i] = NLL.data

        # set things up for next time
        if i != seq_len - 1:
            # prior_probs[b, k] = sum_j T[k, j] * z[b, j]: the column of T
            # picked out by the one-hot z
            prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)
            h = self.hidden_rnn(emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    if n_particles != 1:
        loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
        NLL = -log_sum_exp(
            -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                n_particles)  # not quite accurate, but what can you do
    else:
        NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
def forward(self, x):
    """Run ``x`` through ``self.network`` and interpret the output as logits.

    Returns:
        ``OneHotCategorical`` whose logits are ``self.network(x)``.
    """
    return OneHotCategorical(logits=self.network(x))
def kl_categorical(self, logits_q):
    """Analytic KL(q || p) between one-hot categoricals.

    Args:
        logits_q: logits of the posterior q. ``self.logits_p`` is broadcast
            to the same shape with ``expand_as`` to form the prior p.

    Returns:
        KL divergences with the batch shape of ``logits_q`` (the category
        dimension is reduced).
    """
    prior_logits = self.logits_p.expand_as(logits_q)
    return kl_divergence(
        OneHotCategorical(logits=logits_q),
        OneHotCategorical(logits=prior_logits),
    )