def test_one_hot_categorical_shape(self):
    dist = OneHotCategorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
    self.assertEqual(dist._batch_shape, torch.Size((3,)))
    self.assertEqual(dist._event_shape, torch.Size((2,)))
    self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
    self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
    self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3)))
class TestMultiOneHotCategorical(unittest.TestCase):
    def setUp(self) -> None:
        self.test_probs = torch.tensor(
            [[0.3, 0.2, 0.4, 0.1, 0.25, 0.5, 0.25, 0.3, 0.4, 0.1, 0.1, 0.1],
             [0.2, 0.3, 0.1, 0.4, 0.5, 0.3, 0.2, 0.2, 0.3, 0.2, 0.2, 0.1]])
        self.test_sections = (4, 3, 5)
        self.test_actions = torch.tensor(
            [[0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
             [0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0.]]).long()
        self.test_sectioned_actions = torch.split(self.test_actions, self.test_sections, dim=-1)
        self.test_multi_onehot_categorical = MultiOneHotCategorical(self.test_probs, self.test_sections)
        self.test_onehot_categorical1 = OneHotCategorical(self.test_probs[:, :4])
        self.test_onehot_categorical2 = OneHotCategorical(self.test_probs[:, 4:7])
        self.test_onehot_categorical3 = OneHotCategorical(self.test_probs[:, 7:])

    def test_log_prob(self):
        test_cat1_log_prob = self.test_onehot_categorical1.log_prob(self.test_sectioned_actions[0])
        test_cat2_log_prob = self.test_onehot_categorical2.log_prob(self.test_sectioned_actions[1])
        test_cat3_log_prob = self.test_onehot_categorical3.log_prob(self.test_sectioned_actions[2])
        test_multi_cat_log_prob = self.test_multi_onehot_categorical.log_prob(self.test_actions)
        self.assertEqual(test_cat1_log_prob.shape, test_multi_cat_log_prob.shape)
        self.assertTrue(
            torch.equal(test_cat1_log_prob + test_cat2_log_prob + test_cat3_log_prob,
                        test_multi_cat_log_prob))

    def test_sample(self):
        test_cat1_sample = self.test_onehot_categorical1.sample()
        test_cat2_sample = self.test_onehot_categorical2.sample()
        test_cat3_sample = self.test_onehot_categorical3.sample()
        test_cat_sample = torch.cat([test_cat1_sample, test_cat2_sample, test_cat3_sample], dim=-1)
        test_multi_cat_sample = self.test_multi_onehot_categorical.sample()
        self.assertEqual(test_cat_sample.shape, test_multi_cat_sample.shape)
        self.assertTrue(torch.equal(test_cat_sample.sum(dim=-1), test_multi_cat_sample.sum(dim=-1)))

    def test_entropy(self):
        test_cat1_entropy = self.test_onehot_categorical1.entropy()
        test_cat2_entropy = self.test_onehot_categorical2.entropy()
        test_cat3_entropy = self.test_onehot_categorical3.entropy()
        test_multi_cat_entropy = self.test_multi_onehot_categorical.entropy()
        self.assertTrue(torch.equal(test_cat1_entropy + test_cat2_entropy + test_cat3_entropy,
                                    test_multi_cat_entropy), "Expected same entropy!!!")
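# `MultiOneHotCategorical` is exercised by the tests above but not shown here. A minimal
# sketch of a class with the behaviour the tests assert (log_prob and entropy sum over
# the sections, samples concatenate along the last dim); this is an assumed illustration,
# not the implementation under test:
import torch
from torch.distributions import OneHotCategorical


class MultiOneHotCategorical:
    def __init__(self, probs, sections):
        # one independent OneHotCategorical per section of the flat probability vector
        self.sections = list(sections)
        self.dists = [OneHotCategorical(p) for p in torch.split(probs, self.sections, dim=-1)]

    def log_prob(self, value):
        parts = torch.split(value, self.sections, dim=-1)
        return sum(d.log_prob(v) for d, v in zip(self.dists, parts))

    def sample(self, sample_shape=torch.Size()):
        return torch.cat([d.sample(sample_shape) for d in self.dists], dim=-1)

    def entropy(self):
        return sum(d.entropy() for d in self.dists)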
def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
    """
    Returns a sample of the policy on the input, along with the log probability of the sample.

    Args:
        inp: The input tensor to put through the network.

    Returns:
        A one-hot sample from the categorical distribution of the network, its log
        probability, and None.
    """
    linear = self.linear(inp)
    value = self.value(linear)
    value = value.view(*value.shape[:-1], -1, self.num_classes)

    probs = self.probs(value)
    dist = OneHotCategorical(probs)
    sample = dist.sample()
    log_prob = dist.log_prob(sample)

    # Straight-through gradient trick: the forward value stays the hard one-hot sample,
    # while gradients flow to `probs` in the backward pass.
    sample = sample + probs - probs.detach()

    return sample, log_prob, None
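# A small standalone sketch (separate from the module above) of the straight-through
# trick used in that forward pass: `sample + probs - probs.detach()` keeps the hard
# one-hot value in the forward direction while routing gradients through `probs`.
import torch
from torch.distributions import OneHotCategorical

logits = torch.randn(4, 3, requires_grad=True)
probs = torch.softmax(logits, dim=-1)
hard = OneHotCategorical(probs=probs).sample()       # discrete sample, no gradient path
straight_through = hard + probs - probs.detach()     # same value, differentiable w.r.t. logits

loss = (straight_through * torch.randn(4, 3)).sum()
loss.backward()
assert logits.grad is not None                        # gradients reach the logits
assert torch.equal(straight_through.detach(), hard)   # forward value is still one-hot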
def forward(self, x):
    # For convenience we use torch.distributions to sample and compute the values of
    # interest for the distribution; see https://pytorch.org/docs/stable/distributions.html
    # for more details.
    probs = self.encode(x.view(-1, 784))
    m = OneHotCategorical(probs)
    action = m.sample()
    log_prob = m.log_prob(action)
    entropy = m.entropy()
    return self.decode(action), log_prob, entropy
def test_one_hot_categorical_2d(self):
    probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
    probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
    p = Variable(torch.Tensor(probabilities), requires_grad=True)
    s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
    self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
    self.assertEqual(OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
    self.assertEqual(OneHotCategorical(p).sample_n(6).size(), (6, 2, 3))
    self._gradcheck_log_prob(OneHotCategorical, (p,))

    dist = OneHotCategorical(p)
    x = dist.sample()
    self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))
def forward(self, input, targets, args, n_particles, criterion, test=False):
    """
    This version takes the inputs, and does not expose the logits, but instead
    computes the losses directly.
    """
    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (h, c) = self.encoder(emb, hidden)

    # teacher-forcing
    out_emb = self.dropout(self.dec_embedding(targets))

    # now, we'll replicate it for each particle - it's currently [seq_len x batch_sz x nhid]
    hidden_states = hidden_states.repeat(1, n_particles, 1)
    out_emb = out_emb.repeat(1, n_particles, 1)
    # now [seq_len x (n_particles x batch_sz) x nhid]
    # out_emb, hidden_states should be viewed as (n_particles x batch_sz) - this means that's true for h as well

    # run the z-decoder at this point, evaluating the NLL at each step
    p_h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)  # initially zero
    h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    d_h = self.init_hidden(batch_sz * n_particles, self.nhid, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    for i in range(seq_len):
        h = self.z_decoder(hidden_states[i], h)
        logits = self.logits(h)

        # build the next z sample
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()
        h = z

        # prior
        if test:
            p = OneHotCategorical(logits=p_h)
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=p_h)

        # now, compute the log-likelihood of the data given this mean, and the input out_emb
        d_h = self.decoder(torch.cat([z, out_emb[i]], 1), d_h)
        decoder_logits = self.out_embedding(d_h)
        NLL = criterion(decoder_logits, input[i].repeat(n_particles))
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + args.anneal * (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        # sample ancestors, and reindex everything
        Z = log_sum_exp(wa, dim=0)  # line 7
        if (Z.data > 0.1).any():
            pdb.set_trace()
        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        probs = accumulated_weights.data.exp()
        probs += 0.01
        probs = probs / probs.sum(0, keepdim=True)
        effective_sample_size = 1. / probs.pow(2).sum(0)

        # resample / RSAMP if 3 batch elements need resampling
        if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
            resamples += 1
            ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

            # now, reindex, which is the most important thing
            offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
            if ancestors.is_cuda:
                offsets = offsets.cuda()
            unrolled_idx = Variable(ancestors.t().contiguous() + offsets).view(-1)
            h = torch.index_select(h, 0, unrolled_idx)
            p_h = torch.index_select(p_h, 0, unrolled_idx)
            d_h = torch.index_select(d_h, 0, unrolled_idx)

            # reset accumulated_weights
            accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # build the next mean prediction, feeding in the correct ancestor
            p_h = self.ar_prior_logits(torch.cat([h, out_emb[i]], 1), p_h)

    # now, we calculate the final log-marginal estimator
    nll = nlls.view(seq_len, n_particles, batch_sz).mean(1).sum()
    return -loss.sum(), nll, (seq_len * batch_sz), resamples
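# `log_sum_exp` is used throughout the particle-filtering code in this section but is not
# defined in these snippets. A minimal sketch of the numerically stable version it
# presumably refers to (name and signature are assumptions; recent PyTorch also provides
# torch.logsumexp directly):
import torch


def log_sum_exp(value, dim=0):
    """Numerically stable log(sum(exp(value))) along `dim`, removing that dim."""
    m, _ = torch.max(value, dim=dim, keepdim=True)
    return (m + torch.log(torch.sum(torch.exp(value - m), dim=dim, keepdim=True))).squeeze(dim)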
def sampled_filter(self, input, args, n_particles, emb, hidden_states):
    seq_len, batch_sz = input.size()
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
    emit = self.calc_emit()

    hidden_states = hidden_states.repeat(1, n_particles, 1)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    # in log probability space
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)

    for i in range(seq_len):
        # the approximate posterior comes from the same thing as before
        logits = self.logits(hidden_states[i])

        if not self.training:
            # this is crucial!!
            p = OneHotCategorical(logits=prior_logits)
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()

        # now, compute the log-likelihood of the data given this z-sample
        emission = F.embedding(input[i].repeat(n_particles), emit)
        NLL = -(emission * z).sum(1)
        # NLL = -self.decode(z, input[i].repeat(n_particles), (emit,))  # diff. w.r.t. z
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

        # sample ancestors, and reindex everything
        if args.filter:
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # probs is [n_particles, batch_sz]
            # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
            # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)
                z = torch.index_select(z, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # now in log-probability space
            prior_logits = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

    if self.training:
        (-loss.sum() / (seq_len * batch_sz * n_particles)).backward(retain_graph=True)

    return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0
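# The resampling branches above and below trigger on the effective sample size (ESS) of
# the normalized particle weights. A small standalone sketch of that criterion, assuming
# log-weights laid out as [n_particles, batch_sz] as in the code above (the helper name
# and threshold default are illustrative):
import torch


def needs_resampling(log_weights, threshold=0.3):
    """Return a [batch_sz] bool mask: True where ESS / n_particles falls below threshold."""
    n_particles = log_weights.size(0)
    probs = torch.softmax(log_weights, dim=0)   # normalize per batch element
    ess = 1.0 / probs.pow(2).sum(0)             # ESS = 1 / sum_k w_k^2
    return (ess / n_particles) < threshold


log_w = torch.randn(5, 3)                       # 5 particles, 3 batch elements
mask = needs_resampling(log_w)
# multinomial ancestor sampling per batch element, as in the resampling branches above
ancestors = torch.multinomial(torch.softmax(log_w, 0).t(), 5, replacement=True)  # [batch_sz, n_particles]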
def forward(self, input, args, n_particles, test=False):
    """
    evaluation is the IWAE-10 bound
    """
    pi = F.log_softmax(self.pi, 0)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = (Variable(hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()),
         Variable(hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()))

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a log-prob
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    prior_h = (Variable(torch.zeros(batch_sz * n_particles, 50).cuda()),
               Variable(torch.zeros(batch_sz * n_particles, 50).cuda()))

    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None

    x_emb = self.lockdrop(emb, self.dropout_x)

    if test:
        pdb.set_trace()

    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the inference network
        logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()

        # if test:
        q = OneHotCategorical(logits=logits)
        p = OneHotCategorical(logits=prior_logits)
        a = q.sample()
        # else:
        #     q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
        #     p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
        #     a = q.rsample()

        # to guard against being too crazy
        b = a + 1e-16
        z = b / b.sum(1, keepdim=True)

        # this should be batch_sz x x_dim
        scores = torch.mm(self.project(torch.cat([h[0], z], 1)), self.emit.t())
        NLL = nn.CrossEntropyLoss(reduce=False)(scores, input[i].repeat(n_particles))
        nlls[i] = NLL.data

        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

        probs = accumulated_weights.data.exp()
        probs += 0.01
        probs = probs / probs.sum(0, keepdim=True)
        effective_sample_size = 1. / probs.pow(2).sum(0)

        if any_nans(probs):
            pdb.set_trace()

        # probs is [n_particles, batch_sz]
        # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
        # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

        # resample / RSAMP
        if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
            ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

            # now, reindex, which is the most important thing
            offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
            if ancestors.is_cuda:
                offsets = offsets.cuda()
            unrolled_idx = Variable(ancestors + offsets).view(-1)

            # shuffle!
            z = torch.index_select(z, 0, unrolled_idx)
            a, b = h
            h = torch.index_select(a, 0, unrolled_idx), torch.index_select(b, 0, unrolled_idx)
            a, b = prior_h
            prior_h = torch.index_select(a, 0, unrolled_idx), torch.index_select(b, 0, unrolled_idx)

            # reset accumulated_weights
            accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        # set things up for next time
        if i != seq_len - 1:
            feed = torch.cat([emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
            prior_h = self.z_decoder(feed, prior_h)
            prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
            h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    NLL = nlls
    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz * n_particles), 0
class PPOTorchPolicy(TorchPolicy):
    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        self.device = torch.device('cpu')

        # Get hyperparameters
        self.alpha = config['alpha']
        self.clip_ratio = config['clip_ratio']
        self.gamma = config['gamma']
        self.lam = config['lambda']
        self.lr_pi = config['lr_pi']
        self.lr_vf = config['lr_vf']
        self.model_hidden_sizes = config['model_hidden_sizes']
        self.num_skills = config['num_skills']
        self.skill_input = config['skill_input']
        self.target_kl = config['target_kl']
        self.use_diayn = config['use_diayn']
        self.use_env_rewards = config['use_env_rewards']
        self.use_gae = config['use_gae']

        # Initialize actor-critic model
        self.skills = OneHotCategorical(torch.ones((1, self.num_skills)))
        if self.skill_input is not None:
            skill_vec = [0.] * (self.num_skills - 1)
            skill_vec.insert(self.skill_input, 1.)
            self.z = torch.as_tensor([skill_vec], dtype=torch.float32)
        else:
            self.z = None
        self.model = SkilledA2C(observation_space,
                                action_space,
                                hidden_sizes=self.model_hidden_sizes,
                                skills=self.skills).to(self.device)

        # Set up optimizers for policy and value function
        self.pi_optimizer = Adam(self.model.pi.parameters(), self.lr_pi)
        self.vf_optimizer = Adam(self.model.vf.parameters(), self.lr_vf)
        self.disc_optimizer = Adam(self.model.discriminator.parameters(), self.lr_vf)

    def compute_loss_d(self, batch):
        obs, z = batch[SampleBatch.CUR_OBS], batch[SKILLS]
        logq_z = self.model.discriminator(obs)
        return nn.functional.nll_loss(logq_z, z.argmax(dim=-1))

    def compute_loss_pi(self, batch):
        obs, act, z = batch[SampleBatch.CUR_OBS], batch[ACTIVATIONS], batch[SKILLS]
        adv, logp_old = batch[Postprocessing.ADVANTAGES], batch[SampleBatch.ACTION_LOGP]
        clip_ratio = self.clip_ratio

        # Policy loss
        oz = torch.cat([obs, z], dim=-1)
        pi, logp = self.model.pi(oz, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clip_frac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clip_frac)

        return loss_pi, pi_info

    def compute_loss_v(self, batch):
        obs, z = batch[SampleBatch.NEXT_OBS], batch[SKILLS]
        v_pred_old, v_targ = batch[SampleBatch.VF_PREDS], batch[Postprocessing.VALUE_TARGETS]

        oz = torch.cat([obs, z], dim=-1)
        v_pred = self.model.vf(oz)
        v_pred_clipped = v_pred_old + torch.clamp(v_pred - v_pred_old, -self.clip_ratio, self.clip_ratio)
        loss_clipped = (v_pred_clipped - v_targ).pow(2)
        loss_unclipped = (v_pred - v_targ).pow(2)

        return 0.5 * torch.max(loss_unclipped, loss_clipped).mean()

    def _convert_activation_to_action(self, activation):
        min_ = self.action_space.low
        max_ = self.action_space.high
        return tanh_to_action(activation, min_, max_)

    def _normalize_obs(self, obs):
        min_ = self.observation_space.low
        max_ = self.observation_space.high
        return normalize_obs(obs, min_, max_)

    @override(Policy)
    def compute_actions(self, obs, **kwargs):
        # Sample a skill at the start of each episode
        if self.z is None:
            self.z = self.skills.sample()

        o = self._normalize_obs(obs)
        a, v, logp_a, logq_z = self.model.step(torch.as_tensor(o, dtype=torch.float32), self.z)
        actions = self._convert_activation_to_action(a)

        extras = {
            ACTIVATIONS: a,
            SampleBatch.VF_PREDS: v,
            SampleBatch.ACTION_LOGP: logp_a,
            SKILLS: self.z.numpy(),
            SKILL_LOGQ: logq_z
        }

        return actions, [], extras

    @override(Policy)
    def postprocess_trajectory(self, batch, other_agent_batches=None, episode=None):
        """Adds the policy logits, VF preds, and advantages to the trajectory."""
        completed = batch["dones"][-1]
        if completed:
            # Force end of episode reward
            last_r = 0.0
            # Reset skill at the end of each episode
            self.z = None
        else:
            next_state = []
            for i in range(self.num_state_tensors()):
                next_state.append([batch["state_out_{}".format(i)][-1]])
            obs = [batch[SampleBatch.NEXT_OBS][-1]]
            o = self._normalize_obs(obs)
            _, last_r, _, _ = self.model.step(torch.as_tensor(o, dtype=torch.float32), self.z)
            last_r = last_r.item()

        # Compute DIAYN rewards
        if self.use_diayn:
            z = torch.as_tensor(batch[SKILLS], dtype=torch.float32)
            logp_z = self.skills.log_prob(z).numpy()
            logq_z = batch[SKILL_LOGQ][:, z.argmax(dim=-1)[0].item()]
            entropy_reg = self.alpha * batch[SampleBatch.ACTION_LOGP]
            diayn_rewards = logq_z - logp_z - entropy_reg

            if self.use_env_rewards:
                batch[SampleBatch.REWARDS] += diayn_rewards
            else:
                batch[SampleBatch.REWARDS] = diayn_rewards

        batch = compute_advantages(batch,
                                   last_r,
                                   gamma=self.gamma,
                                   lambda_=self.lam,
                                   use_gae=self.use_gae)
        return batch

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        postprocessed_batch[SampleBatch.CUR_OBS] = self._normalize_obs(
            postprocessed_batch[SampleBatch.CUR_OBS])
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        # Train policy with multiple steps of gradient descent
        self.pi_optimizer.zero_grad()
        loss_pi, pi_info = self.compute_loss_pi(train_batch)
        # if pi_info['kl'] > 1.5 * self.target_kl:
        #     logger.info('Early stopping at step %d due to reaching max kl.' % i)
        #     return
        loss_pi.backward()
        self.pi_optimizer.step()

        # Value function learning
        self.vf_optimizer.zero_grad()
        loss_v = self.compute_loss_v(train_batch)
        loss_v.backward()
        self.vf_optimizer.step()

        # Discriminator learning
        self.disc_optimizer.zero_grad()
        loss_d = self.compute_loss_d(train_batch)
        loss_d.backward()
        self.disc_optimizer.step()

        grad_info = dict(pi_loss=loss_pi.item(),
                         vf_loss=loss_v.item(),
                         d_loss=loss_d.item(),
                         **pi_info)

        return {LEARNER_STATS_KEY: grad_info}
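# The DIAYN reward in postprocess_trajectory uses a fixed uniform skill prior, so
# log p(z) is simply -log(num_skills) for any one-hot skill. A tiny standalone check of
# that fact (num_skills=4 is an arbitrary value chosen for illustration):
import math
import torch
from torch.distributions import OneHotCategorical

skills = OneHotCategorical(torch.ones((1, 4)))   # uniform prior over 4 skills
z = skills.sample()
assert torch.allclose(skills.log_prob(z), torch.tensor([-math.log(4.)]))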
def forward(self, input, args, n_particles, test=False):
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
    emit = self.calc_emit()

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    # in log probability space
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    for i in range(seq_len):
        # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
        logits = self.logits(
            nn.functional.relu(
                self.z_decoder(torch.cat([hidden_states[i], h], 1), logits)))

        # build the next z sample
        if any_nans(logits):
            pdb.set_trace()
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()
        h = z

        # prior
        if any_nans(prior_probs):
            pdb.set_trace()
        if test:
            p = OneHotCategorical(logits=prior_probs)
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_probs)
        if any_nans(prior_probs):
            pdb.set_trace()
        if any_nans(logits):
            pdb.set_trace()

        # now, compute the log-likelihood of the data given this z-sample
        NLL = -self.decode(z, input[i].repeat(n_particles), (emit, ))  # diff. w.r.t. z
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        # sample ancestors, and reindex everything
        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        if args.filter:
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # probs is [n_particles, batch_sz]
            # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
            # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)
                h = torch.index_select(h, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # now in probability space
            prior_probs = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

            # let's normalize things - slower, but safer
            # prior_probs += 0.01
            # prior_probs = prior_probs / prior_probs.sum(1, keepdim=True)
            #
            # if ((prior_probs.sum(1) - 1) > 1e-3).any()[0]:
            #     pdb.set_trace()

        if any_nans(loss):
            pdb.set_trace()

    # now, we calculate the final log-marginal estimator
    return -loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), resamples