def forward(self, input, memory_bank, memory_lengths=None, coverage=None, q_scores=None):
    """
    Compute one decoding step of variational attention.

    The query is scored against the memory bank to obtain the prior
    attention distribution p and the soft context built from it. Then,
    depending on `self.mode`, K alignments are sampled (from p or from the
    inference network's q) or every source position is enumerated, and one
    attentional hidden state is produced per alignment.

    Only one-step (input feeding) queries are supported. The previous
    multi-step branch was guarded by `assert False` and referenced
    undefined names, so it has been replaced by an explicit error.

    Args:
      input (`FloatTensor`): query vectors `[batch x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet); currently ignored
      q_scores (`Params`): the attention params from the inference network

    Returns:
      (`FloatTensor`, `FloatTensor`, `FloatTensor`, `FloatTensor`, `DistInfo`):

      * h_y: attentional hidden states, one per sampled/enumerated
        alignment `[K x batch x dim]`
      * h_c: attentional hidden state for the soft context
        `[batch x dim]`, or None when `self.mode == "wsram"`
      * context_c: soft (weighted average) context vector `[batch x dim]`
      * c_align_vectors: attention distribution of the prior
        `[batch x src_len]`
      * dist_info: `DistInfo` holding the q and p `Params` for this step
        (q is None when `q_scores` is None)

    Raises:
      AssertionError: if `input` is not a 2-D one-step query, or if
        `self.mode == "wsram"` and `q_scores` is None.
      ValueError: if `self.p_dist_type` or `self.mode` is not supported.
    """
    # Fail fast on multi-step input. This used to be an `assert False` at
    # the very end of the function, which is stripped under `python -O`.
    if input.dim() != 2:
        raise AssertionError(
            "Only one-step (input feeding) attention is supported; "
            "expected a 2-D query, got {} dims.".format(input.dim()))

    # one step input: add a singleton target-length axis
    input = input.unsqueeze(1)
    if q_scores is not None and q_scores.alpha is not None:
        q_scores = Params(
            alpha=q_scores.alpha.unsqueeze(1),
            log_alpha=q_scores.log_alpha.unsqueeze(1),
            dist_type=q_scores.dist_type,
        )

    batch, sourceL, dim = memory_bank.size()
    batch_, targetL, _ = input.size()
    aeq(batch, batch_)

    if self.p_dist_type != "categorical":
        # Previously this fell through and failed later with an
        # UnboundLocalError on `c_align_vectors`.
        raise ValueError(
            "Unsupported p_dist_type: {!r}".format(self.p_dist_type))

    # compute attention scores, as in Luong et al.
    # Params should be T x N x S
    mask = None
    scores = self.score(input, memory_bank)
    if memory_lengths is not None:
        # mask : N x T x S
        mask = sequence_mask(memory_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        scores.data.masked_fill_(~mask, -float('inf'))
    if self.k > 0 and self.k < scores.size(-1):
        # Keep only the k largest scores per query; the rest become -inf.
        # NOTE(review): the scores are rebuilt from `.data`, so no gradient
        # flows back through them while top-k is active -- confirm intended.
        topk, idx = scores.data.topk(self.k)
        new_attn_score = torch.zeros_like(scores.data).fill_(float("-inf"))
        new_attn_score = new_attn_score.scatter_(2, idx, topk)
        scores = new_attn_score
    log_scores = F.log_softmax(scores, dim=-1)
    scores = log_scores.exp()

    c_align_vectors = scores

    p_scores = Params(
        alpha=scores,
        log_alpha=log_scores,
        dist_type=self.p_dist_type,
    )

    # each context vector c_t is the weighted average
    # over all the source hidden states
    context_c = torch.bmm(c_align_vectors, memory_bank)
    if self.mode != 'wsram':
        concat_c = torch.cat([input, context_c], -1)
        # N x T x H
        h_c = self.tanh(self.linear_out(concat_c))
    else:
        h_c = None

    # sample or enumerate
    # y_align_vectors: K x N x T x S
    q_sample, p_sample, sample_log_probs = None, None, None
    sample_log_probs_q, sample_log_probs_p, sample_p_div_q_log = None, None, None
    sample_kwargs = dict(
        n_samples=self.n_samples,
        lengths=memory_lengths,
        mask=mask,
    )
    if self.mode == "sample":
        if q_scores is None or self.use_prior:
            p_sample, sample_log_probs = self.sample_attn(
                p_scores, **sample_kwargs)
            y_align_vectors = p_sample
        else:
            q_sample, sample_log_probs = self.sample_attn(
                q_scores, **sample_kwargs)
            y_align_vectors = q_sample
    elif self.mode == "gumbel":
        if q_scores is None or self.use_prior:
            p_sample, _ = self.sample_attn_gumbel(
                p_scores, self.temperature, **sample_kwargs)
            y_align_vectors = p_sample
        else:
            q_sample, _ = self.sample_attn_gumbel(
                q_scores, self.temperature, **sample_kwargs)
            y_align_vectors = q_sample
    elif self.mode == "enum" or self.mode == "exact":
        y_align_vectors = None
    elif self.mode == "wsram":
        assert q_scores is not None, "wsram mode requires q_scores"
        (q_sample, sample_log_probs_q,
         sample_log_probs_p, sample_p_div_q_log) = self.sample_attn_wsram(
            q_scores, p_scores, **sample_kwargs)
        y_align_vectors = q_sample
    else:
        # Previously an unknown mode failed later with an UnboundLocalError
        # on `y_align_vectors`.
        raise ValueError("Unsupported attention mode: {!r}".format(self.mode))

    # context_y: K x N x T x H
    if y_align_vectors is not None:
        tiled_memory = memory_bank.unsqueeze(0).repeat(
            self.n_samples, 1, 1, 1).view(-1, sourceL, dim)
        context_y = torch.bmm(
            y_align_vectors.view(-1, targetL, sourceL),
            tiled_memory,
        ).view(self.n_samples, batch, targetL, dim)
    else:
        # For enumerate, K = S.
        # memory_bank: N x S x H
        context_y = (
            memory_bank.unsqueeze(0).repeat(targetL, 1, 1, 1)  # T, N, S, H
            .permute(2, 1, 0, 3))  # S, N, T, H
    # Tile the query so that it can be paired with each of the K contexts.
    tiled_input = input.unsqueeze(0).repeat(context_y.size(0), 1, 1, 1)
    concat_y = torch.cat([tiled_input, context_y], -1)
    # K x N x T x H
    h_y = self.tanh(self.linear_out(concat_y))

    # Drop the singleton target-length axis again.
    if h_c is not None:
        # N x H
        h_c = h_c.squeeze(1)
    # N x S
    c_align_vectors = c_align_vectors.squeeze(1)
    context_c = context_c.squeeze(1)
    # K x N x H
    h_y = h_y.squeeze(2)

    if q_scores is not None:
        q_scores = Params(
            alpha=q_scores.alpha.squeeze(1)
            if q_scores.alpha is not None else None,
            dist_type=q_scores.dist_type,
            samples=q_sample.squeeze(2)
            if q_sample is not None else None,
            sample_log_probs=sample_log_probs.squeeze(2)
            if sample_log_probs is not None else None,
            sample_log_probs_q=sample_log_probs_q.squeeze(2)
            if sample_log_probs_q is not None else None,
            sample_log_probs_p=sample_log_probs_p.squeeze(2)
            if sample_log_probs_p is not None else None,
            sample_p_div_q_log=sample_p_div_q_log.squeeze(2)
            if sample_p_div_q_log is not None else None,
        )
    p_scores = Params(
        alpha=p_scores.alpha.squeeze(1),
        log_alpha=log_scores.squeeze(1),
        dist_type=p_scores.dist_type,
        samples=p_sample.squeeze(2) if p_sample is not None else None,
    )

    # Check output sizes
    if h_c is not None:
        batch_, _ = h_c.size()
        aeq(batch, batch_)
    batch_, sourceL_ = c_align_vectors.size()
    aeq(batch, batch_)
    aeq(sourceL, sourceL_)

    # For now, don't include samples.
    dist_info = DistInfo(
        q=q_scores,
        p=p_scores,
    )

    # h_y: samples from simplex, K x N x H
    # h_c: convex combination of memory_bank for input feeding, N x H
    # c_align_vectors: convex coefficients / boltzmann dist, N x S
    return h_y, h_c, context_c, c_align_vectors, dist_info
def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None, q_scores=None, tgt_emb=None):
    """
    Run the input-feed decoder over the target sequence, one step at a time.

    See StdRNNDecoder._run_forward_pass() for description
    of arguments and return values. In addition:

    Args:
      q_scores (`Params`): attention params from the inference network,
        indexed by time step along the first axis; None to use the prior.
      tgt_emb: precomputed target embeddings; when given, they are used
        instead of `self.dropout(self.embeddings(tgt))`.

    Returns:
      tuple: `(hidden, decoder_outputs, input_feed, attns, dist_info,
      decoder_outputs_baseline)` where

      * hidden: final RNN state
      * decoder_outputs: list with one `decoder_output_y` per time step
      * input_feed: the soft context of the last step
      * attns: dict of per-step attention lists ("std", and optionally
        "q", "copy", "coverage")
      * dist_info: `DistInfo` with the per-step q/p params stacked along
        a new leading time axis
      * decoder_outputs_baseline: list of the per-step soft-attention
        outputs (empty if the attention layer returns None for them)
    """
    # Additional args check.
    input_feed = state.input_feed.squeeze(0)
    input_feed_batch, _ = input_feed.size()
    _, tgt_batch, _ = tgt.size()
    aeq(tgt_batch, input_feed_batch)
    # END Additional args check.

    # Initialize local and return variables.
    decoder_outputs = []
    decoder_outputs_baseline = []
    dist_infos = []
    attns = {"std": []}
    if q_scores is not None:
        attns["q"] = []
    if self._copy:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []

    emb = self.dropout(self.embeddings(tgt)) if tgt_emb is None else tgt_emb
    assert emb.dim() == 3  # len x batch x embedding_dim

    hidden = state.hidden
    coverage = state.coverage.squeeze(0) \
        if state.coverage is not None else None

    # The attention layers take the memory bank batch-first; the transpose
    # does not depend on the time step, so do it once.
    memory_bank_bf = memory_bank.transpose(0, 1)

    # Input feed concatenates hidden state with
    # input at every time step.
    for i, emb_t in enumerate(emb.split(1)):
        emb_t = emb_t.squeeze(0)
        decoder_input = torch.cat([emb_t, input_feed], -1)

        rnn_output, hidden = self.rnn(decoder_input.unsqueeze(0), hidden)
        rnn_output = rnn_output.squeeze(0)
        if q_scores is not None:
            # map over tensor-like keys
            q_scores_i = Params(
                alpha=q_scores.alpha[i],
                log_alpha=q_scores.log_alpha[i],
                dist_type=q_scores.dist_type,
            )
        else:
            q_scores_i = None
        decoder_output_y, decoder_output_c, context_c, attn_c, dist_info = self.attn(
            rnn_output,
            memory_bank_bf,
            memory_lengths=memory_lengths,
            q_scores=q_scores_i)

        dist_infos.append(dist_info)
        if self.context_gate is not None and decoder_output_c is not None:
            # TODO: context gate should be employed
            # instead of second RNN transform.
            decoder_output_c = self.context_gate(
                decoder_input, rnn_output, decoder_output_c)
        if decoder_output_c is not None:
            decoder_output_c = self.dropout(decoder_output_c)
        # The soft context (not the attentional hidden state) is fed back.
        input_feed = context_c

        # decoder_output_y : K x N x H
        decoder_output_y = self.dropout(decoder_output_y)

        decoder_outputs.append(decoder_output_y)
        if decoder_output_c is not None:
            decoder_outputs_baseline.append(decoder_output_c)
        attns["std"].append(attn_c)
        if q_scores is not None:
            attns["q"].append(q_scores.alpha[i])

        # Update the coverage attention.
        if self._coverage:
            # Accumulate the attention returned by self.attn. This used to
            # reference `p_attn`, which is not defined in this function.
            coverage = coverage + attn_c \
                if coverage is not None else attn_c
            attns["coverage"].append(coverage)

        # Run the forward pass of the copy attention layer.
        if self._copy and not self._reuse_copy_attn:
            # This used to pass `decoder_output`, which is not defined here.
            # NOTE(review): the soft-attention output is used as the query
            # because it is the only single-vector-per-example output;
            # confirm this is the intended tensor (it is None when the
            # attention layer returns no baseline output).
            _, copy_attn = self.copy_attn(decoder_output_c, memory_bank_bf)
            attns["copy"].append(copy_attn)
        elif self._copy:
            attns["copy"] = attns["std"]

    def _stack_optional(dist, field):
        # Stack a per-step Params field over time, or return None when the
        # first step did not populate it.
        if getattr(getattr(dist_infos[0], dist), field) is None:
            return None
        return torch.stack(
            [getattr(getattr(d, dist), field) for d in dist_infos], dim=0)

    if q_scores is not None:
        q_info = Params(
            alpha=q_scores.alpha,
            dist_type=q_scores.dist_type,
            samples=_stack_optional("q", "samples"),
            log_alpha=q_scores.log_alpha,
            sample_log_probs=_stack_optional("q", "sample_log_probs"),
            sample_log_probs_q=_stack_optional("q", "sample_log_probs_q"),
            sample_log_probs_p=_stack_optional("q", "sample_log_probs_p"),
            sample_p_div_q_log=_stack_optional("q", "sample_p_div_q_log"),
        )
    else:
        q_info = None
    p_info = Params(
        alpha=torch.stack([d.p.alpha for d in dist_infos], dim=0),
        dist_type=dist_infos[0].p.dist_type,
        log_alpha=_stack_optional("p", "log_alpha"),
        samples=_stack_optional("p", "samples"),
    )
    dist_info = DistInfo(
        q=q_info,
        p=p_info,
    )

    return hidden, decoder_outputs, input_feed, attns, dist_info, decoder_outputs_baseline