def forward(self, input, targets, args, n_particles, criterion, test=False):
    """
    This version takes the inputs and does not expose the logits, but instead computes the losses directly
    """
    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (h, c) = self.encoder(emb, hidden)

    # teacher-forcing
    out_emb = self.dropout(self.dec_embedding(targets))

    # now, we'll replicate it for each particle - it's currently [seq_len x batch_sz x nhid]
    hidden_states = hidden_states.repeat(1, n_particles, 1)
    out_emb = out_emb.repeat(1, n_particles, 1)
    # now [seq_len x (n_particles x batch_sz) x nhid]
    # out_emb, hidden_states should be viewed as (n_particles x batch_sz) - this means that's true for h as well

    # run the z-decoder at this point, evaluating the NLL at each step
    p_h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)  # initially zero
    h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    d_h = self.init_hidden(batch_sz * n_particles, self.nhid, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    for i in range(seq_len):
        h = self.z_decoder(hidden_states[i], h)
        logits = self.logits(h)

        # build the next z sample
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()
        h = z

        # prior
        if test:
            p = OneHotCategorical(logits=p_h)
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=p_h)

        # now, compute the log-likelihood of the data given this mean, and the input out_emb
        d_h = self.decoder(torch.cat([z, out_emb[i]], 1), d_h)
        decoder_logits = self.out_embedding(d_h)
        NLL = criterion(decoder_logits, input[i].repeat(n_particles))
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + args.anneal * (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        # sample ancestors, and reindex everything
        Z = log_sum_exp(wa, dim=0)  # line 7
        if (Z.data > 0.1).any():
            pdb.set_trace()
        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        probs = accumulated_weights.data.exp()
        probs += 0.01
        probs = probs / probs.sum(0, keepdim=True)
        effective_sample_size = 1. / probs.pow(2).sum(0)

        # resample / RSAMP if any batch element needs resampling
        if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
            resamples += 1
            ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

            # now, reindex, which is the most important thing
            offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
            if ancestors.is_cuda:
                offsets = offsets.cuda()
            unrolled_idx = Variable(ancestors + offsets).view(-1)
            h = torch.index_select(h, 0, unrolled_idx)
            p_h = torch.index_select(p_h, 0, unrolled_idx)
            d_h = torch.index_select(d_h, 0, unrolled_idx)

            # reset accumulated_weights
            accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # build the next mean prediction, feeding in the correct ancestor
            p_h = self.ar_prior_logits(torch.cat([h, out_emb[i]], 1), p_h)

    # now, we calculate the final log-marginal estimator
    nll = nlls.view(seq_len, n_particles, batch_sz).mean(1).sum()
    return -loss.sum(), nll, (seq_len * batch_sz), resamples
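# The functions in this file call a `log_sum_exp(tensor, dim)` helper that is assumed to be
# defined elsewhere in the repository. A minimal, numerically stable sketch of what such a
# helper could look like (an assumption, not necessarily the project's exact implementation):
def log_sum_exp(value, dim):
    """Stable log(sum(exp(value))) along `dim`; the reduced dimension is squeezed out."""
    m, _ = value.max(dim=dim, keepdim=True)  # subtract the per-slice max before exponentiating
    return (m + (value - m).exp().sum(dim=dim, keepdim=True).log()).squeeze(dim)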
def sampled_filter(self, input, args, n_particles, emb, hidden_states):
    seq_len, batch_sz = input.size()
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
    emit = self.calc_emit()

    hidden_states = hidden_states.repeat(1, n_particles, 1)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    # in log probability space
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)

    for i in range(seq_len):
        # the approximate posterior comes from the same thing as before
        logits = self.logits(hidden_states[i])

        if not self.training:
            # this is crucial!!
            p = OneHotCategorical(logits=prior_logits)
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()

        # now, compute the log-likelihood of the data given this z-sample
        emission = F.embedding(input[i].repeat(n_particles), emit)
        NLL = -(emission * z).sum(1)
        # NLL = -self.decode(z, input[i].repeat(n_particles), (emit,))  # diff. w.r.t. z
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

        # sample ancestors, and reindex everything
        if args.filter:
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # probs is [n_particles, batch_sz]
            # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
            # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)
                z = torch.index_select(z, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # now in log-probability space
            prior_logits = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

    if self.training:
        (-loss.sum() / (seq_len * batch_sz * n_particles)).backward(retain_graph=True)

    return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0
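# A small standalone sketch (hypothetical sizes, not taken from the model code) of the index
# arithmetic used in the resampling branch above: with offsets of n_particles * b, ancestor a
# chosen for batch element b maps to flat row a + n_particles * b before torch.index_select.
def _resampling_index_demo():
    """Hypothetical example: show how (ancestors + offsets) flattens to particle row indices."""
    import torch
    n_particles, batch_sz = 3, 2
    ancestors = torch.tensor([[2, 2, 0],   # batch element 0 keeps particles 2, 2, 0
                              [1, 0, 1]])  # batch element 1 keeps particles 1, 0, 1
    offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles)
    unrolled_idx = (ancestors + offsets).view(-1)
    return unrolled_idx  # tensor([2, 2, 0, 4, 3, 4]) -> rows gathered by torch.index_select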
def forward(self, input, args, n_particles, test=False):
    T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
    pi = nn.Softmax(dim=0)(self.pi)
    emit = self.calc_emit()

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    z = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # prior starts at pi - NOTE: probabilities here, not log-probs
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)

    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    for i in range(seq_len):
        # logits = self.logits(torch.cat([hidden_states[i], h], 1))
        # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
        logits = self.logits(nn.functional.relu(self.z_decoder(torch.cat([hidden_states[i], z], 1), logits)))

        # build the next z sample
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()

        lse = log_sum_exp(logits, dim=1).view(-1, 1)
        log_probs = logits - lse

        # now, compute the log-likelihood of the data given this z-sample
        # so emission is [batch_sz x z_dim], i.e. emission[i, j] is the log-probability of getting this
        # data for element i given choice z
        emission = F.embedding(input[i].repeat(n_particles), emit)

        NLL = -log_sum_exp(emission + log_probs, 1)
        nlls[i] = NLL.data
        KL = (log_probs.exp() * (log_probs - (prior_probs + 1e-16).log())).sum(1)
        loss += (NLL + KL)

        if i != seq_len - 1:
            prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

    # now, we calculate the final log-marginal estimator
    return loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), 0
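# The prior update above, (T.unsqueeze(0) * z.unsqueeze(1)).sum(2), is a batched matrix-vector
# product: with column-stochastic T (softmax over dim 0), new_prior[n, i] = sum_j T[i, j] * z[n, j],
# which is the same thing as z @ T.t(). A small equivalence check with hypothetical sizes:
def _prior_update_check(z_dim=4, n=6):
    import torch
    T = torch.softmax(torch.randn(z_dim, z_dim), dim=0)  # columns sum to 1
    z = torch.softmax(torch.randn(n, z_dim), dim=1)      # stand-in for relaxed one-hot samples
    a = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)
    b = z.mm(T.t())
    return torch.allclose(a, b)  # True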
def forward(self, input, args, n_particles, test=False):
    """
    evaluation is the IWAE-10 bound
    """
    if test:
        n_particles = 10
    else:
        n_particles = 1
    pi = F.log_softmax(self.pi, 0)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = (Variable(hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()),
         Variable(hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()))

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a log-prob
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    prior_h = (Variable(torch.zeros(batch_sz * n_particles, 50).cuda()),
               Variable(torch.zeros(batch_sz * n_particles, 50).cuda()))

    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    feed = None
    x_emb = self.lockdrop(emb, self.dropout_x)

    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the inference network
        logits = F.log_softmax(self.logits(hidden_states[i]), 1)
        if test:
            q = OneHotCategorical(logits=logits)
            # p = OneHotCategorical(logits=prior_logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            # p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
            z = q.rsample()

        # this should be batch_sz x x_dim
        scores = torch.mm(self.project(torch.cat([h[0], z], 1)), self.emit.t())
        NLL = nn.CrossEntropyLoss(reduce=False)(scores, input[i].repeat(n_particles))
        # KL = q.log_prob(z) - p.log_prob(z)
        KL = (logits.exp() * (logits - prior_logits)).sum(1)
        loss += (NLL + KL)
        # else:
        #     loss += (NLL + args.anneal * KL)
        nlls[i] = NLL.data

        # set things up for next time
        if i != seq_len - 1:
            feed = torch.cat([emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
            prior_h = self.z_decoder(feed, prior_h)
            prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
            h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    if n_particles != 1:
        loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
        NLL = -log_sum_exp(-nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(n_particles)  # not quite accurate, but what can you do
    else:
        NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
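# The aggregation at the end of this function is the IWAE bound over K = n_particles samples:
#   IWAE(x) = -log((1/K) * sum_k exp(-loss_k)) = -logsumexp_k(-loss_k) + log K.
# A minimal sketch of that reduction using modern PyTorch's torch.logsumexp (assuming the
# particle-major [n_particles * batch_sz] layout used in the loop above):
def _iwae_from_particle_losses(loss, n_particles, batch_sz):
    import math
    import torch
    per_particle = loss.view(n_particles, batch_sz)  # one negative ELBO per particle and batch element
    return -torch.logsumexp(-per_particle, dim=0) + math.log(n_particles)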
def forward(self, input, args, n_particles, test=False):
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
    emit = self.calc_emit()

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0
    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    # in log probability space
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)

    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    for i in range(seq_len):
        # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
        logits = self.logits(nn.functional.relu(self.z_decoder(torch.cat([hidden_states[i], h], 1), logits)))

        # build the next z sample
        if any_nans(logits):
            pdb.set_trace()
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()
        h = z

        # prior
        if any_nans(prior_probs):
            pdb.set_trace()
        if test:
            p = OneHotCategorical(logits=prior_probs)
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_probs)
        if any_nans(prior_probs):
            pdb.set_trace()
        if any_nans(logits):
            pdb.set_trace()

        # now, compute the log-likelihood of the data given this z-sample
        NLL = -self.decode(z, input[i].repeat(n_particles), (emit,))  # diff. w.r.t. z
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        # sample ancestors, and reindex everything
        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        if args.filter:
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # probs is [n_particles, batch_sz]
            # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
            # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)
                h = torch.index_select(h, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # now in log-probability space
            prior_probs = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

            # let's normalize things - slower, but safer
            # prior_probs += 0.01
            # prior_probs = prior_probs / prior_probs.sum(1, keepdim=True)
            #
            # if ((prior_probs.sum(1) - 1) > 1e-3).any()[0]:
            #     pdb.set_trace()

    if any_nans(loss):
        pdb.set_trace()

    # now, we calculate the final log-marginal estimator
    return -loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), resamples
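# The resampling trigger used throughout this file is based on the effective sample size (ESS)
# of the normalized particle weights, ESS = 1 / sum_k w_k^2, which ranges from 1 (all weight on
# one particle) to n_particles (uniform weights). A minimal sketch of that criterion, mirroring
# the arithmetic in the loops above (minus the 0.01 smoothing) and assuming `accumulated_weights`
# is an [n_particles, batch_sz] tensor of log-weights:
def _needs_resampling(accumulated_weights, n_particles, threshold=0.3):
    probs = accumulated_weights.exp()
    probs = probs / probs.sum(0, keepdim=True)  # normalize the weights per batch element
    ess = 1. / probs.pow(2).sum(0)              # effective sample size per batch element
    return (ess / n_particles) < threshold      # boolean mask: which batch elements to resample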