def forward_backward(self, input):
    """
    input: Variable([seq_len x batch_size])
    """
    input = input.long()
    seq_len, batch_size = input.size()
    alpha = [None for i in range(seq_len)]
    beta = [None for i in range(seq_len)]

    T = F.log_softmax(self.T, 0)
    pi = F.log_softmax(self.pi, 0)
    emit = self.calc_emit()

    # forward pass
    alpha[0] = self.log_prob(input[0], (emit, )) + pi.view(1, -1)
    beta[-1] = Variable(torch.zeros(batch_size, self.z_dim))
    if T.is_cuda:
        beta[-1] = beta[-1].cuda()

    for t in range(1, seq_len):
        logprod = alpha[t - 1].unsqueeze(2).expand(
            batch_size, self.z_dim, self.z_dim) + T.t().unsqueeze(0)
        alpha[t] = self.log_prob(input[t], (emit, )) + log_sum_exp(logprod, 1)

    # keep around for now, but unnecessary in our models
    # for t in range(seq_len - 2, -1, -1):
    #     beta_expand = beta[t + 1].unsqueeze(1).expand(batch_size, self.z_dim, self.z_dim)
    #     beta[t] = log_sum_exp(beta_expand + T.t().unsqueeze(0), 2) + emit[input[t + 1]]

    log_marginal = log_sum_exp(alpha[-1] + beta[-1], dim=-1)

    return alpha, beta, log_marginal
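# Hedged sketch: `log_sum_exp(tensor, dim)` is assumed throughout to be a numerically stable
# logsumexp reduction over `dim`. A minimal standalone version (illustrative only; the repo's
# own helper may differ, and `torch` is assumed to be imported in this module):
def _log_sum_exp_sketch(tensor, dim):
    # subtract the per-slice max before exponentiating to avoid overflow, then add it back
    m, _ = tensor.max(dim, keepdim=True)
    return (m + (tensor - m).exp().sum(dim, keepdim=True).log()).squeeze(dim)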
def forward(self, input, args, n_particles, test=False):
    """
    n_particles is interpreted as 1 for now to not screw anything up
    """
    n_particles = 1
    T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
    pi = nn.Softmax(dim=0)(self.pi)
    emit = self.calc_emit()

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    z = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a value in probability space (pi is a plain softmax here)
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    for i in range(seq_len):
        # logits = self.logits(torch.cat([hidden_states[i], h], 1))
        # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
        logits = self.logits(hidden_states[i])

        # build the next z sample
        q = RelaxedOneHotCategorical(temperature=Variable(
            torch.Tensor([args.temp]).cuda()), logits=logits)
        z = q.sample()

        lse = log_sum_exp(logits, dim=1).view(-1, 1)
        log_probs = logits - lse

        # now, compute the log-likelihood of the data given this z-sample
        # emission is [batch_sz x z_dim], i.e. emission[b, j] is the log-probability of the
        # observed token for batch element b under latent state j
        emission = F.embedding(input[i].repeat(n_particles), emit)

        NLL = -log_sum_exp(emission + log_probs, 1)
        nlls[i] = NLL.data
        KL = (log_probs.exp() * (log_probs - (prior_probs + 1e-16).log())).sum(1)
        loss += (NLL + KL)

        if i != seq_len - 1:
            prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

    # now, we calculate the final log-marginal estimator
    return loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), 0
def mutual_info(self, x, lengths):
    """
    *modified from https://github.com/jxhe/vae-lagging-encoder*

    Calculate the approximate mutual information between z & x under distribution q(z|x):

        I(x, z) = E_x E_{q(z|x)}[log q(z|x)] - E_x E_{q(z|x)}[log q(z)]

    :param x: input sentence. (seq_len, batch_size)
    :param lengths: (list[int]) length of each sequence in the batch.
    :return: (float) approximate mutual information. can be non-negative when n_z > 1.
    """
    mu, logvar = self.encoder(x, lengths)
    x_batch, nz = mu.size()

    # E_{q(z|x)}[log q(z|x)] = negative entropy of the diagonal Gaussian, averaged over the batch
    neg_entropy = (-0.5 * nz * math.log(2 * math.pi) -
                   0.5 * (1 + logvar).sum(-1)).mean()

    # [z_batch, 1, nz]
    z, kld = self.reparameterize(mu, logvar)
    z = z.unsqueeze(1)

    # [1, x_batch, nz]
    mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
    var = logvar.exp()

    # (z_batch, x_batch, nz)
    dev = z - mu  # dimension broadcast

    # (z_batch, x_batch)
    log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
        0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

    # log q(z): aggregate posterior
    # [z_batch]
    log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)

    return (neg_entropy - log_qz.mean(-1)).item()
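# Note on the estimator above: log q(z) is approximated by the aggregate posterior over the batch,
#   log q(z) ~= logsumexp_j log q(z | x_j) - log(batch_size),
# (the `log_density` / `log_qz` computation), so the return value is a Monte Carlo estimate of
# E_x E_{q(z|x)}[log q(z|x)] - E_z[log q(z)].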
def forward(self, input, args, n_particles, test=False):
    """
    If n_particles != 1, this is the IWAE estimator, which doesn't make sense here
    """
    n_particles = 1
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # in log-space, intentionally
    emit = self.calc_emit()  # also in log-space

    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    elbo = 0
    NLL = 0

    # now a logit
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    prev_probs = None

    for i in range(seq_len):
        logits = F.log_softmax(self.logits(hidden_states[i]), 1)  # log q(z_i)
        probs = logits.exp()  # q(z_i)
        emission = F.embedding(input[i].repeat(n_particles), emit)  # log p(x_i | z_i)

        # unary potentials
        elbo += (emission * probs).sum(1)  # E_q[log p(x_i | z_i)]
        NLL += -(emission * probs).sum(1).data

        # binary potentials q(z_t) q(z_{t - 1}) log p(z_t | z_{t - 1})
        if i != 0:
            elbo += (prev_probs.unsqueeze(1) * probs.unsqueeze(2) *
                     T.unsqueeze(0)).sum(2).sum(1)
        else:
            # add the log p(z_1) term
            elbo += (probs * prior_logits).sum(1)

        # entropy term: + H[q(z_i)] = -E_q[log q(z_i)]
        elbo -= (logits * probs).sum(1)
        prev_probs = probs

    if n_particles != 1:
        elbo = log_sum_exp(elbo.view(n_particles, batch_sz), 0) - math.log(n_particles)
        NLL = NLL.view(n_particles, batch_sz).mean(0)

    # now, we calculate the final log-marginal estimator
    return -elbo.sum(), NLL.sum(), (seq_len * batch_sz), 0
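# Note on the objective above: because q factorizes over time steps and each z_t is a
# categorical over z_dim states, the loop accumulates the exact ELBO
#   ELBO = sum_t E_q[log p(x_t | z_t)] + E_q[log p(z_1)]
#        + sum_{t>1} sum_{i,j} q(z_t = i) q(z_{t-1} = j) log p(z_t = i | z_{t-1} = j)
#        + sum_t H[q(z_t)],
# with no sampling required.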
def forward_backward(self, input, speedup=False):
    """
    Modify the forward-backward to compute beta[t], since we need that for
    checking the sampling in the particle filter case
    """
    input = input.long()
    seq_len, batch_size = input.size()
    alpha = [None for i in range(seq_len)]
    beta = [None for i in range(seq_len)]

    T = F.log_softmax(self.T, 0)
    pi = F.log_softmax(self.pi, 0)
    emit = self.calc_emit()

    # forward pass
    alpha[0] = self.log_prob(input[0], (emit, )) + pi.view(1, -1)
    beta[-1] = Variable(torch.zeros(batch_size, self.z_dim))
    if T.is_cuda:
        beta[-1] = beta[-1].cuda()

    for t in range(1, seq_len):
        logprod = alpha[t - 1].unsqueeze(2).expand(
            batch_size, self.z_dim, self.z_dim) + T.t().unsqueeze(0)
        alpha[t] = self.log_prob(input[t], (emit, )) + log_sum_exp(logprod, 1)

    log_marginal = log_sum_exp(alpha[-1] + beta[-1], dim=-1)

    if speedup:
        return 0, 0, log_marginal
    else:
        # backward pass
        for t in range(seq_len - 2, -1, -1):
            beta[t] = log_sum_exp(
                T.unsqueeze(0) + beta[t + 1].unsqueeze(2) +
                F.embedding(input[t + 1], emit).unsqueeze(2), 1)

        return [
            alpha[i] + beta[i] - log_marginal.unsqueeze(1)
            for i in range(seq_len)
        ], 0, log_marginal
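# Recursions implemented above, all in log-space:
#   alpha_t(i) = log p(x_t | z_t = i) + logsumexp_j [ alpha_{t-1}(j) + log p(z_t = i | z_{t-1} = j) ]
#   beta_t(j)  = logsumexp_i [ log p(z_{t+1} = i | z_t = j) + log p(x_{t+1} | z_{t+1} = i) + beta_{t+1}(i) ]
# and the returned list holds the smoothed posteriors
#   log p(z_t | x_{1:T}) = alpha_t + beta_t - log p(x_{1:T}).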
def forward(self, input, args, test=False):
    NO_HMM = False

    seq_len, batch_size = input.size()

    # compute the loss as the sum of the forward-backward loss
    if not NO_HMM:
        alpha, _, log_marginal = self.forward_backward(input)

    emb = self.inp_embedding(input)
    T = F.log_softmax(self.T, 0)
    pi = F.log_softmax(self.pi, 0).unsqueeze(0).expand(batch_size, self.z_dim)

    if self.separate_opt:
        pi = pi.detach()
        T = T.detach()

    h = (Variable(torch.zeros(batch_size, self.lstm_hidden_size).cuda()),
         Variable(torch.zeros(batch_size, self.lstm_hidden_size).cuda()))

    NLL = 0

    # now, compute the filtered posterior and feed it, together with the LSTM state, into
    # the output projection; note that \alpha(t) contains information about the current x,
    # so we need to prop forward
    current_state = None

    for i in range(seq_len):
        if not NO_HMM:
            if i == 0:
                hmm_post = pi
            else:
                hmm_post = log_sum_exp(
                    T.unsqueeze(0) + current_state.unsqueeze(1), 2)

        if NO_HMM:
            hmm_post = Variable(torch.zeros(batch_size, self.z_dim).cuda())
        else:
            hmm_post = hmm_post.exp()

        scores = self.project(torch.cat([h[0], hmm_post], 1))
        NLL += nn.CrossEntropyLoss(size_average=False)(scores, input[i])

        # feed information from the current state into the next prediction (i.e. teacher-forcing)
        h = self.lstm(emb[i], h)
        if not NO_HMM:
            current_state = F.log_softmax(alpha[i], 1)
            if self.separate_opt:
                current_state = current_state.detach()

    if NO_HMM:
        loss = NLL.sum()
    else:
        loss = -log_marginal.sum() + NLL.sum()

    return loss, NLL.data.sum()
def sampled_elbo(self, input, args, n_particles, emb, hidden_states):
    seq_len, batch_sz = input.size()
    T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
    pi = nn.Softmax(dim=0)(self.pi)
    emit = self.calc_emit()

    hidden_states = hidden_states.repeat(1, n_particles, 1)
    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a value in probability space
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    for i in range(seq_len):
        logits = self.logits(hidden_states[i])

        # build the next z sample
        p = RelaxedOneHotCategorical(temperature=self.temp_prior, probs=prior_probs)
        q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
        z = q.rsample()

        log_probs = F.log_softmax(logits, dim=1)

        # now, compute the log-likelihood of the data given this z-sample
        # so emission is [batch_sz x z_dim], i.e. emission[i, j] is the log-probability of getting this
        # data for element i given choice z
        emission = F.embedding(input[i].repeat(n_particles), emit)

        NLL = -log_sum_exp(emission + log_probs, 1)
        nlls[i] = NLL.data
        KL = q.log_prob(z) - p.log_prob(z)  # pretty inexact
        loss += (NLL + KL)

        if i != seq_len - 1:
            prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

    (loss.sum() / (seq_len * batch_sz * n_particles)).backward(retain_graph=True)

    return loss, 0, seq_len * batch_sz * n_particles, 0
def sampled_filter(self, input, args, n_particles, emb, hidden_states):
    seq_len, batch_sz = input.size()
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
    emit = self.calc_emit()

    hidden_states = hidden_states.repeat(1, n_particles, 1)
    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    # in log probability space
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)

    for i in range(seq_len):
        # the approximate posterior comes from the same thing as before
        logits = self.logits(hidden_states[i])

        if not self.training:  # this is crucial!!
            p = OneHotCategorical(logits=prior_logits)
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                         logits=prior_logits)
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()

        # now, compute the log-likelihood of the data given this z-sample
        emission = F.embedding(input[i].repeat(n_particles), emit)
        NLL = -(emission * z).sum(1)
        # NLL = -self.decode(z, input[i].repeat(n_particles), (emit,))  # diff. w.r.t. z
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

        # sample ancestors, and reindex everything
        if args.filter:
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # probs is [n_particles, batch_sz]
            # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
            # offsets [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                    1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)
                z = torch.index_select(z, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # now in log-probability space
            prior_logits = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

    if self.training:
        (-loss.sum() / (seq_len * batch_sz * n_particles)).backward(retain_graph=True)

    return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0
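# Hedged sketch (illustrative only, not part of the model): the resampling step used above,
# isolated. `_resample_sketch` is a hypothetical helper; it computes the effective sample size
# from normalized log-weights and, when it drops below a threshold for any batch element
# (mirroring the check above), redraws ancestors per batch element with multinomial resampling.
# Assumes `torch` and `math` are already imported in this module.
def _resample_sketch(log_weights, particles, threshold=0.3):
    # log_weights: [n_particles, batch_sz], normalized so logsumexp over dim 0 is 0
    # particles:   [n_particles, batch_sz, d] per-particle state
    n_particles, batch_sz = log_weights.size()
    probs = log_weights.exp()
    ess = 1. / probs.pow(2).sum(0)  # effective sample size, one value per batch element
    if ((ess / n_particles) < threshold).sum() > 0:
        # ancestors[b, k]: which current particle the new particle k of batch element b copies
        ancestors = torch.multinomial(probs.t(), n_particles, replacement=True)
        gather_idx = ancestors.t().unsqueeze(-1).expand(-1, -1, particles.size(-1))
        particles = particles.gather(0, gather_idx)
        # after resampling, the weights are reset to uniform
        log_weights = torch.full_like(log_weights, -math.log(n_particles))
    return log_weights, particles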
def sampled_iwae(self, input, args, n_particles, loss, tokens):
    seq_len, batch_sz = input.size()
    loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
    (loss.sum() / tokens).backward(retain_graph=True)
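# Note on `sampled_iwae`: writing the per-particle loss as l_k = -log w_k, the quantity
# backpropagated above is the negated IWAE bound
#   L_IWAE = -log( (1/K) * sum_k w_k ) = -logsumexp_k(-l_k) + log K,
# normalized by the token count before calling .backward().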
def forward(self, input, args, n_particles, test=False):
    """
    n_particles is forced to 1 during training; at test time we evaluate the IWAE-10 bound
    """
    if test:
        n_particles = 10
    else:
        n_particles = 1
    T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
    pi = nn.Softmax(dim=0)(self.pi)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_())

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a value in probability space (pi is a plain softmax here)
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None

    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the inference network
        logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
        z = OneHotCategorical(logits=logits).sample()

        # this should be batch_sz x x_dim
        feed = self.project(torch.cat([h, z], 1))  # batch_sz x hidden_dim
        scores = torch.mm(feed, self.emit.t())  # batch_sz x x_dim

        NLL = nn.CrossEntropyLoss(reduce=False)(scores, input[i].repeat(n_particles))
        KL = (logits.exp() * (logits - (prior_probs + 1e-16).log())).sum(1)
        loss += (NLL + KL)
        nlls[i] = NLL.data

        # set things up for next time
        if i != seq_len - 1:
            prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)
            h = self.hidden_rnn(emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    if n_particles != 1:
        loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
        NLL = -log_sum_exp(
            -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                n_particles)  # not quite accurate, but what can you do
    else:
        NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
def forward(self, input, args, n_particles, test=False):
    """
    evaluation is the IWAE-10 bound
    """
    pi = F.log_softmax(self.pi, 0)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = (Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()),
         Variable(
             hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()))

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a log-prob
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    prior_h = (Variable(torch.zeros(batch_sz * n_particles, 50).cuda()),
               Variable(torch.zeros(batch_sz * n_particles, 50).cuda()))

    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None

    x_emb = self.lockdrop(emb, self.dropout_x)

    if test:
        pdb.set_trace()

    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the inference network
        logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()

        # if test:
        q = OneHotCategorical(logits=logits)
        p = OneHotCategorical(logits=prior_logits)
        a = q.sample()
        # else:
        #     q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
        #     p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
        #     a = q.rsample()

        # to guard against being too crazy
        b = a + 1e-16
        z = b / b.sum(1, keepdim=True)

        # this should be batch_sz x x_dim
        scores = torch.mm(self.project(torch.cat([h[0], z], 1)), self.emit.t())

        NLL = nn.CrossEntropyLoss(reduce=False)(scores, input[i].repeat(n_particles))
        nlls[i] = NLL.data

        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

        probs = accumulated_weights.data.exp()
        probs += 0.01
        probs = probs / probs.sum(0, keepdim=True)
        effective_sample_size = 1. / probs.pow(2).sum(0)

        if any_nans(probs):
            pdb.set_trace()

        # probs is [n_particles, batch_sz]
        # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
        # offsets [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

        # resample / RSAMP
        if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
            ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

            # now, reindex, which is the most important thing
            offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                1).repeat(1, n_particles).long()
            if ancestors.is_cuda:
                offsets = offsets.cuda()
            unrolled_idx = Variable(ancestors + offsets).view(-1)

            # shuffle!
            z = torch.index_select(z, 0, unrolled_idx)
            a, b = h
            h = torch.index_select(a, 0, unrolled_idx), torch.index_select(
                b, 0, unrolled_idx)
            a, b = prior_h
            prior_h = torch.index_select(a, 0, unrolled_idx), torch.index_select(
                b, 0, unrolled_idx)

            # reset accumulated_weights
            accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        # set things up for next time
        if i != seq_len - 1:
            feed = torch.cat([emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
            prior_h = self.z_decoder(feed, prior_h)
            prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
            h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz * n_particles), 0
def forward(self, input, args, n_particles, test=False):
    """
    evaluation is the IWAE-10 bound
    """
    if test:
        n_particles = 10
    else:
        n_particles = 1
    pi = F.log_softmax(self.pi, 0)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = (Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()),
         Variable(
             hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()))

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a log-prob
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    prior_h = (Variable(torch.zeros(batch_sz * n_particles, 50).cuda()),
               Variable(torch.zeros(batch_sz * n_particles, 50).cuda()))

    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None

    x_emb = self.lockdrop(emb, self.dropout_x)

    for i in range(seq_len):
        # build the next z sample (here the inference network is trained through the relaxed sample)
        logits = F.log_softmax(self.logits(hidden_states[i]), 1)

        if test:
            q = OneHotCategorical(logits=logits)
            # p = OneHotCategorical(logits=prior_logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            # p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
            z = q.rsample()

        # this should be batch_sz x x_dim
        scores = torch.mm(self.project(torch.cat([h[0], z], 1)), self.emit.t())

        NLL = nn.CrossEntropyLoss(reduce=False)(scores, input[i].repeat(n_particles))
        # KL = q.log_prob(z) - p.log_prob(z)
        KL = (logits.exp() * (logits - prior_logits)).sum(1)
        loss += (NLL + KL)
        # else:
        #     loss += (NLL + args.anneal * KL)
        nlls[i] = NLL.data

        # set things up for next time
        if i != seq_len - 1:
            feed = torch.cat([emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
            prior_h = self.z_decoder(feed, prior_h)
            prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
            h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    if n_particles != 1:
        loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
        NLL = -log_sum_exp(
            -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                n_particles)  # not quite accurate, but what can you do
    else:
        NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
def forward(self, input, args, n_particles, test=False):
    """
    The major difference is that now we use a GRU to predict the prior z logits,
    instead of using a linear map T. I think trying to fit this GRU is really hard,
    I'm kind of concerned
    """
    if test:
        n_particles = 10
    else:
        n_particles = 1
    pi = F.log_softmax(self.pi, 0)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = Variable(
        hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_())

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a log-prob
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    prior_h = Variable(torch.zeros(batch_sz * n_particles, 50).cuda())

    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None

    # use dropout on the teacher-forcing
    x_emb = self.lockdrop(emb, self.dropout_x)

    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the inference network
        logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
        z = OneHotCategorical(logits=logits).sample()

        # this should be batch_sz x x_dim
        scores = torch.mm(self.project(torch.cat([h, z], 1)), self.emit.t())

        NLL = nn.CrossEntropyLoss(reduce=False)(scores, input[i].repeat(n_particles))
        KL = (logits.exp() * (logits - prior_logits)).sum(1)
        loss += (NLL + KL)
        nlls[i] = NLL.data

        # set things up for next time
        if i != seq_len - 1:
            feed = torch.cat([emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
            prior_h = self.z_decoder(feed, prior_h)
            prior_logits = F.log_softmax(self.project_z(prior_h), 1)
            h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    if n_particles != 1:
        loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
        NLL = -log_sum_exp(
            -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                n_particles)  # not quite accurate, but what can you do
    else:
        NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
def forward(self, input, args, n_particles, test=False):
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
    emit = self.calc_emit()

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    # in log probability space
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    for i in range(seq_len):
        # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
        logits = self.logits(
            nn.functional.relu(
                self.z_decoder(torch.cat([hidden_states[i], h], 1), logits)))

        # build the next z sample
        if any_nans(logits):
            pdb.set_trace()
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()
        h = z

        # prior
        if any_nans(prior_probs):
            pdb.set_trace()
        if test:
            p = OneHotCategorical(logits=prior_probs)
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                         logits=prior_probs)
        if any_nans(prior_probs):
            pdb.set_trace()
        if any_nans(logits):
            pdb.set_trace()

        # now, compute the log-likelihood of the data given this z-sample
        NLL = -self.decode(z, input[i].repeat(n_particles), (emit, ))  # diff. w.r.t. z
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7
        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        # sample ancestors, and reindex everything
        if args.filter:
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # probs is [n_particles, batch_sz]
            # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
            # offsets [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                    1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)
                h = torch.index_select(h, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # now in log-probability space
            prior_probs = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

            # let's normalize things - slower, but safer
            # prior_probs += 0.01
            # prior_probs = prior_probs / prior_probs.sum(1, keepdim=True)
            #
            # if ((prior_probs.sum(1) - 1) > 1e-3).any()[0]:
            #     pdb.set_trace()

    if any_nans(loss):
        pdb.set_trace()

    # now, we calculate the final log-marginal estimator
    return -loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), resamples