def forward(self, input, lengths=None, hidden=None):
    """ See :obj:`EncoderBase.forward()`"""
    self._check_args(input, lengths, hidden)

    emb = self.embeddings(input)
    s_len, n_batch, emb_dim = emb.size()

    out = emb.transpose(0, 1).contiguous()
    words = input[:, :, 0].transpose(0, 1)
    # CHECKS
    out_batch, out_len, _ = out.size()
    w_batch, w_len = words.size()
    aeq(out_batch, w_batch)
    aeq(out_len, w_len)
    # END CHECKS

    # Make mask.
    padding_idx = self.embeddings.word_padding_idx
    mask = words.data.eq(padding_idx).unsqueeze(1) \
        .expand(w_batch, w_len, w_len)

    # Run the forward pass of every layer of the transformer.
    for i in range(self.num_layers):
        out = self.transformer[i](out, mask)
    out = self.layer_norm(out)

    return Variable(emb.data), out.transpose(0, 1).contiguous()
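# A minimal, self-contained sketch of the padding-mask construction above,
# using made-up sizes; PAD_IDX and the tensors here are illustrative only.
# It shows how eq(pad).unsqueeze(1).expand yields a [batch, len, len] mask
# whose columns flag padded key positions for self-attention.
import torch

PAD_IDX = 1
words = torch.tensor([[4, 5, PAD_IDX],
                      [7, PAD_IDX, PAD_IDX]])          # [batch=2, len=3]
mask = words.eq(PAD_IDX).unsqueeze(1).expand(2, 3, 3)  # [batch, len, len]
# mask[b, i, j] is True when source position j is padding, for every query i.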
def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
    """
    Private helper for running the specific RNN forward pass.
    Must be overridden by all subclasses.

    Args:
        tgt (LongTensor): a sequence of input tokens tensors
            [len x batch x nfeats].
        memory_bank (FloatTensor): output(tensor sequence) from the encoder
            RNN of size (src_len x batch x hidden_size).
        state (FloatTensor): hidden state from the encoder RNN for
            initializing the decoder.
        memory_lengths (LongTensor): the source memory_bank lengths.
    Returns:
        decoder_final (Variable): final hidden state from the decoder.
        decoder_outputs ([FloatTensor]): an array of output of every time
            step from the decoder.
        attns (dict of (str, [FloatTensor])): a dictionary of attention
            tensors from every time step of the decoder, keyed by
            attention type.
    """
    assert not self._copy  # TODO, no support yet.
    assert not self._coverage  # TODO, no support yet.

    # Initialize local and return variables.
    attns = {}
    emb = self.dropout(self.embeddings(tgt))

    # Run the forward pass of the RNN.
    if isinstance(self.rnn, nn.GRU):
        rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
    else:
        rnn_output, decoder_final = self.rnn(emb, state.hidden)

    # Check
    tgt_len, tgt_batch, _ = tgt.size()
    output_len, output_batch, _ = rnn_output.size()
    aeq(tgt_len, output_len)
    aeq(tgt_batch, output_batch)
    # END

    # Calculate the attention.
    decoder_outputs, p_attn = self.attn(
        rnn_output.transpose(0, 1).contiguous(),
        memory_bank.transpose(0, 1),
        memory_lengths=memory_lengths)
    attns["std"] = p_attn

    # Calculate the context gate.
    if self.context_gate is not None:
        decoder_outputs = self.context_gate(
            emb.view(-1, emb.size(2)),
            rnn_output.view(-1, rnn_output.size(2)),
            decoder_outputs.view(-1, decoder_outputs.size(2)))
        decoder_outputs = \
            decoder_outputs.view(tgt_len, tgt_batch, self.hidden_size)

    decoder_outputs = self.dropout(decoder_outputs)
    return decoder_final, decoder_outputs, attns
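# A rough, self-contained sketch (toy sizes, not the repo's classes) of what
# the standard, non-input-feeding pass above computes: embed the whole target,
# run it through a single RNN call, then attend over the encoder states for
# all time steps at once with plain dot-product attention.
import torch
import torch.nn as nn
import torch.nn.functional as F

tgt_len, src_len, batch, hidden = 5, 7, 2, 8
emb = torch.randn(tgt_len, batch, hidden)          # already-embedded target
memory_bank = torch.randn(src_len, batch, hidden)  # encoder outputs

rnn = nn.LSTM(hidden, hidden)
rnn_output, final = rnn(emb)                       # [tgt_len, batch, hidden]

# Dot-product attention over all target steps in one batched matmul.
scores = torch.bmm(rnn_output.transpose(0, 1),     # [batch, tgt_len, hidden]
                   memory_bank.permute(1, 2, 0))   # [batch, hidden, src_len]
align = F.softmax(scores, dim=-1)                  # [batch, tgt_len, src_len]
context = torch.bmm(align, memory_bank.transpose(0, 1))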
def _example_dict_iter(self, line, index):
    line = line.split()
    if self.line_truncate:
        line = line[:self.line_truncate]
    words, feats, n_feats = TextDataset.extract_text_features(line)
    example_dict = {self.side: words, "indices": index}
    if feats:
        # All examples must have same number of features.
        aeq(self.n_feats, n_feats)

        prefix = self.side + "_feat_"
        example_dict.update((prefix + str(j), f)
                            for j, f in enumerate(feats))
    return example_dict
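# A rough illustration of the dict this method builds for side="src" when each
# token carries one feature. The "|" delimiter and the manual splitting are
# hypothetical stand-ins for the repo's extract_text_features helper.
line = "the|DT cat|NN sat|VB"
tokens = [t.split("|") for t in line.split()]
words = tuple(t[0] for t in tokens)
feats = [tuple(t[1] for t in tokens)]            # one feature stream

example_dict = {"src": words, "indices": 0}
example_dict.update(("src_feat_" + str(j), f) for j, f in enumerate(feats))
# {'src': ('the', 'cat', 'sat'), 'indices': 0,
#  'src_feat_0': ('DT', 'NN', 'VB')}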
def forward(self, tgt, memory_bank, state, memory_lengths=None):
    """
    Args:
        tgt (`LongTensor`): sequences of padded tokens
            `[tgt_len x batch x nfeats]`.
        memory_bank (`FloatTensor`): vectors from the encoder
            `[src_len x batch x hidden]`.
        state (:obj:`tools.Models.DecoderState`):
            decoder state object to initialize the decoder
        memory_lengths (`LongTensor`): the padded source lengths
            `[batch]`.
    Returns:
        (`FloatTensor`, :obj:`tools.Models.DecoderState`, `FloatTensor`):
            * decoder_outputs: output from the decoder (after attn)
              `[tgt_len x batch x hidden]`.
            * decoder_state: final hidden state from the decoder
            * attns: distribution over src at each tgt
              `[tgt_len x batch x src_len]`.
    """
    # Check
    assert isinstance(state, RNNDecoderState)
    tgt_len, tgt_batch, _ = tgt.size()
    _, memory_batch, _ = memory_bank.size()
    aeq(tgt_batch, memory_batch)
    # END

    # Run the forward pass of the RNN.
    decoder_final, decoder_outputs, input_feed, attns = \
        self._run_forward_pass(tgt, memory_bank, state,
                               memory_lengths=memory_lengths)

    # Update the state with the result.
    final_output = decoder_outputs[-1]
    coverage = None
    if "coverage" in attns:
        coverage = attns["coverage"][-1].unsqueeze(0)
    state.update_state(decoder_final, input_feed.unsqueeze(0), coverage)

    # Concatenates sequence of tensors along a new dimension.
    decoder_outputs = torch.stack(decoder_outputs)
    for k in attns:
        attns[k] = torch.stack(attns[k])

    return decoder_outputs, state, attns
def score(self, h_t, h_s):
    """
    Args:
        h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
        h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`

    Returns:
        :obj:`FloatTensor`:
            raw attention scores (unnormalized) for each src index
            `[batch x tgt_len x src_len]`
    """
    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)

    if self.attn_type in ["general", "dot"]:
        if self.attn_type == "general":
            h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
            h_t_ = self.linear_in(h_t_)
            h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
        h_s_ = h_s.transpose(1, 2)
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s_)
    else:
        dim = self.dim
        wq = self.linear_query(h_t.view(-1, self.tgt_dim))
        wq = wq.view(tgt_batch, tgt_len, 1, dim)
        wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

        uh = self.linear_context(h_s.contiguous().view(-1, self.src_dim))
        uh = uh.view(src_batch, 1, src_len, dim)
        uh = uh.expand(src_batch, tgt_len, src_len, dim)

        # (batch, t_len, s_len, d)
        wquh = self.tanh(wq + uh)

        return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
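# A tiny, self-contained sketch (illustrative sizes only) of the "general"
# scoring branch above: project the queries with a learned matrix, then take
# batched dot products against the source states.
import torch
import torch.nn as nn

batch, tgt_len, src_len, dim = 2, 3, 4, 8
h_t = torch.randn(batch, tgt_len, dim)
h_s = torch.randn(batch, src_len, dim)

linear_in = nn.Linear(dim, dim, bias=False)
h_t_proj = linear_in(h_t.view(-1, dim)).view(batch, tgt_len, dim)
scores = torch.bmm(h_t_proj, h_s.transpose(1, 2))  # [batch, tgt_len, src_len]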
def forward(self, hidden, attn, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.

    Args:
        hidden (`FloatTensor`): hidden outputs `[batch*tlen, input_size]`
        attn (`FloatTensor`): attn for each `[batch*tlen, src_len]`
        src_map (`FloatTensor`):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocabulary.
            `[src_len, batch, extra_words]`
    """
    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, self.tgt_dict.stoi[tools.io.PAD_WORD]] = -float('inf')
    prob = F.softmax(logits)

    # Probability of copying p(z=1) batch.
    p_copy = F.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
    mul_attn = torch.mul(attn, p_copy.expand_as(attn))
    copy_prob = torch.bmm(mul_attn.view(-1, batch, slen).transpose(0, 1),
                          src_map.transpose(0, 1)).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)
    return torch.cat([out_prob, copy_prob], 1)
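# A toy, self-contained illustration (made-up sizes, not the repo's API) of
# the copy-gate mixture above: the generator distribution is scaled by
# (1 - p_copy), the attention by p_copy, and the two parts are concatenated so
# the result sums to 1 over the extended vocabulary. For simplicity this skips
# the src_map scatter and keeps per-source-position copy mass.
import torch
import torch.nn.functional as F

vocab, src_len = 6, 4
logits = torch.randn(2, vocab)                     # [batch*tlen, vocab]
attn = F.softmax(torch.randn(2, src_len), dim=-1)  # [batch*tlen, src_len]
p_copy = torch.sigmoid(torch.randn(2, 1))          # copy gate p(z=1)

gen_prob = F.softmax(logits, dim=-1) * (1 - p_copy)
copy_prob = attn * p_copy
extended = torch.cat([gen_prob, copy_prob], 1)
assert torch.allclose(extended.sum(1), torch.ones(2))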
def forward(self, base_target_emb, input, encoder_out_top,
            encoder_out_combine):
    """
    Args:
        base_target_emb: target emb tensor
        input: output of decode conv
        encoder_out_top: the key matrix for calculation of attention weight,
            which is the top output of encode conv
        encoder_out_combine: the value matrix for the attention-weighted sum,
            which is the combination of base emb and top output of encode
    """
    # checks
    batch, channel, height, width = base_target_emb.size()
    batch_, channel_, height_, width_ = input.size()
    aeq(batch, batch_)
    aeq(height, height_)

    enc_batch, enc_channel, enc_height = encoder_out_top.size()
    enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size()
    aeq(enc_batch, enc_batch_)
    aeq(enc_height, enc_height_)

    preatt = seq_linear(self.linear_in, input)
    target = (base_target_emb + preatt) * SCALE_WEIGHT
    target = torch.squeeze(target, 3)
    target = torch.transpose(target, 1, 2)
    pre_attn = torch.bmm(target, encoder_out_top)

    if self.mask is not None:
        pre_attn.data.masked_fill_(self.mask, -float('inf'))

    pre_attn = pre_attn.transpose(0, 2)
    attn = F.softmax(pre_attn)
    attn = attn.transpose(0, 2).contiguous()
    context_output = torch.bmm(attn, torch.transpose(encoder_out_combine,
                                                     1, 2))
    context_output = torch.transpose(torch.unsqueeze(context_output, 3),
                                     1, 2)
    return context_output, attn
def forward(self, input):
    """
    Computes the embeddings for words and features.

    Args:
        input (`LongTensor`): index tensor `[len x batch x nfeat]`
    Return:
        `FloatTensor`: word embeddings `[len x batch x embedding_size]`
    """
    in_length, in_batch, nfeat = input.size()
    aeq(nfeat, len(self.emb_luts))

    emb = self.make_embedding(input)

    out_length, out_batch, emb_size = emb.size()
    aeq(in_length, out_length)
    aeq(in_batch, out_batch)
    aeq(emb_size, self.embedding_size)

    return emb
def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
        input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
        memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
        memory_lengths (`LongTensor`): the source context lengths `[batch]`
        coverage (`FloatTensor`): None (not supported yet)

    Returns:
        (`FloatTensor`, `FloatTensor`):

        * Computed vector `[tgt_len x batch x dim]`
        * Attention distributions for each query
          `[tgt_len x batch x src_len]`
    """
    # one step input
    if input.dim() == 2:
        one_step = True
        input = input.unsqueeze(1)
    else:
        one_step = False

    batch, sourceL, dim = memory_bank.size()
    batch_, targetL, dim_ = input.size()
    aeq(batch, batch_)
    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = self.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(input, memory_bank)
    self.p_attn_score = align

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.data.masked_fill_(1 - mask, -float('inf'))

    # Softmax to normalize attention weights
    align_vectors = self.sm(align.view(batch * targetL, sourceL))
    align_vectors = align_vectors.view(batch, targetL, sourceL)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    #concat_c = torch.cat([c, input], -1)
    concat_c = torch.cat([input, c], -1)
    attn_h = self.linear_out(concat_c)
    #if self.attn_type in ["general", "dot"]:
    if True or self.attn_type in ["general", "dot"]:
        attn_h = self.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        c = c.squeeze(1)

        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        batch_, sourceL_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()

        # Check output sizes
        targetL_, batch_, dim_ = attn_h.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        targetL_, batch_, sourceL_ = align_vectors.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    return attn_h, align_vectors, c
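# A brief, self-contained sketch (illustrative shapes only) of the length
# masking used above: positions beyond each source length are set to -inf
# before the softmax, so they receive zero attention mass.
import torch
import torch.nn.functional as F

lengths = torch.tensor([3, 2])                                # [batch]
scores = torch.randn(2, 1, 4)                                 # [batch, tgt_len, src_len]
mask = torch.arange(4).unsqueeze(0) < lengths.unsqueeze(1)    # [batch, src_len]
scores = scores.masked_fill(~mask.unsqueeze(1), -float('inf'))
align = F.softmax(scores, dim=-1)                             # padded positions -> 0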
def _check_args(self, input, lengths=None, hidden=None):
    s_len, n_batch, n_feats = input.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
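# The `aeq` helper used throughout these checks is defined elsewhere in the
# repo; a minimal sketch consistent with how it is called here (assert that
# all arguments are equal) could look like this.
def aeq(*args):
    """Assert that all arguments are equal to one another."""
    base = args[0]
    for a in args[1:]:
        assert a == base, "arguments are not all equal: %s" % str(args)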
def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
    """
    See StdRNNDecoder._run_forward_pass() for description
    of arguments and return values.
    """
    # Additional args check.
    input_feed = state.input_feed.squeeze(0)
    input_feed_batch, _ = input_feed.size()
    tgt_len, tgt_batch, _ = tgt.size()
    aeq(tgt_batch, input_feed_batch)
    # END Additional args check.

    # Initialize local and return variables.
    decoder_outputs = []
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []

    emb = self.dropout(self.embeddings(tgt))
    assert emb.dim() == 3  # len x batch x embedding_dim

    hidden = state.hidden
    coverage = state.coverage.squeeze(0) \
        if state.coverage is not None else None

    # Input feed concatenates hidden state with
    # input at every time step.
    # DBG
    self.p_attn_score = []
    self.dec_h = []
    self.src_context = []
    self.context = []
    for i, emb_t in enumerate(emb.split(1)):
        emb_t = emb_t.squeeze(0)
        decoder_input = torch.cat([emb_t, input_feed], 1)

        rnn_output, hidden = self.rnn(decoder_input.unsqueeze(0), hidden)
        rnn_output = rnn_output.squeeze(0)
        decoder_output, p_attn, input_feed = self.attn(
            rnn_output,
            memory_bank.transpose(0, 1),
            memory_lengths=memory_lengths)
        # DBG
        self.dec_h.append(rnn_output)
        self.p_attn_score.append(self.attn.p_attn_score)
        self.src_context.append(input_feed)
        self.context.append(decoder_output)

        if self.context_gate is not None:
            # TODO: context gate should be employed
            # instead of second RNN transform.
            decoder_output = self.context_gate(
                decoder_input, rnn_output, decoder_output)
        decoder_output = self.dropout(decoder_output)
        #input_feed = decoder_output

        decoder_outputs += [decoder_output]
        attns["std"] += [p_attn]

        # Update the coverage attention.
        if self._coverage:
            coverage = coverage + p_attn \
                if coverage is not None else p_attn
            attns["coverage"] += [coverage]

        # Run the forward pass of the copy attention layer.
        if self._copy and not self._reuse_copy_attn:
            _, copy_attn = self.copy_attn(decoder_output,
                                          memory_bank.transpose(0, 1))
            attns["copy"] += [copy_attn]
        elif self._copy:
            attns["copy"] = attns["std"]
    # Return result.
    return hidden, decoder_outputs, input_feed, attns
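# A compact, self-contained sketch (toy sizes; not the repo's classes) of the
# input-feeding loop above: at each step the previous attentional vector is
# concatenated with the current target embedding before the RNN step. Here
# the attentional output is fed back, a common input-feeding choice; the repo
# instead feeds back the attention's source context.
import torch
import torch.nn as nn
import torch.nn.functional as F

tgt_len, src_len, batch, hidden = 4, 6, 2, 8
emb = torch.randn(tgt_len, batch, hidden)
memory_bank = torch.randn(batch, src_len, hidden)   # already batch-first
rnn = nn.LSTM(2 * hidden, hidden)
linear_out = nn.Linear(2 * hidden, hidden)

input_feed = torch.zeros(batch, hidden)
state = None
outputs = []
for emb_t in emb.split(1):
    decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1)
    rnn_output, state = rnn(decoder_input.unsqueeze(0), state)
    rnn_output = rnn_output.squeeze(0)
    # dot-product attention for a single step
    align = F.softmax(torch.bmm(rnn_output.unsqueeze(1),
                                memory_bank.transpose(1, 2)), dim=-1)
    context = torch.bmm(align, memory_bank).squeeze(1)
    input_feed = torch.tanh(linear_out(torch.cat([rnn_output, context], 1)))
    outputs.append(input_feed)
outputs = torch.stack(outputs)                       # [tgt_len, batch, hidden]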
def forward(self, input, memory_bank, memory_lengths=None, coverage=None,
            q_scores=None):
    """
    Args:
        input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
        memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
        memory_lengths (`LongTensor`): the source context lengths `[batch]`
        coverage (`FloatTensor`): None (not supported yet)
        q_scores (`FloatTensor`): the attention params from the inference
            network

    Returns:
        (`FloatTensor`, `FloatTensor`):

        * Weighted context vector `[tgt_len x batch x dim]`
        * Attention distributions for each query
          `[tgt_len x batch x src_len]`
        * Unnormalized attention scores for each query
          `[batch x tgt_len x src_len]`
    """
    # one step input
    if input.dim() == 2:
        one_step = True
        input = input.unsqueeze(1)
        if q_scores is not None:
            # oh, I guess this is super messy
            if q_scores.alpha is not None:
                q_scores = Params(
                    alpha=q_scores.alpha.unsqueeze(1),
                    log_alpha=q_scores.log_alpha.unsqueeze(1),
                    dist_type=q_scores.dist_type,
                )
    else:
        one_step = False

    batch, sourceL, dim = memory_bank.size()
    batch_, targetL, dim_ = input.size()
    aeq(batch, batch_)

    # compute attention scores, as in Luong et al.
    # Params should be T x N x S
    if self.p_dist_type == "categorical":
        scores = self.score(input, memory_bank)
        if memory_lengths is not None:
            # mask : N x T x S
            mask = sequence_mask(memory_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            scores.data.masked_fill_(1 - mask, -float('inf'))
        if self.k > 0 and self.k < scores.size(-1):
            topk, idx = scores.data.topk(self.k)
            new_attn_score = torch.zeros_like(scores.data).fill_(
                float("-inf"))
            new_attn_score = new_attn_score.scatter_(2, idx, topk)
            scores = new_attn_score
        log_scores = F.log_softmax(scores, dim=-1)
        scores = log_scores.exp()

        c_align_vectors = scores

        p_scores = Params(
            alpha=scores,
            log_alpha=log_scores,
            dist_type=self.p_dist_type,
        )

    # each context vector c_t is the weighted average
    # over all the source hidden states
    context_c = torch.bmm(c_align_vectors, memory_bank)
    if self.mode != 'wsram':
        concat_c = torch.cat([input, context_c], -1)
        # N x T x H
        h_c = self.tanh(self.linear_out(concat_c))
    else:
        h_c = None

    # sample or enumerate
    # y_align_vectors: K x N x T x S
    q_sample, p_sample, sample_log_probs = None, None, None
    sample_log_probs_q, sample_log_probs_p, sample_p_div_q_log = \
        None, None, None
    if self.mode == "sample":
        if q_scores is None or self.use_prior:
            p_sample, sample_log_probs = self.sample_attn(
                p_scores, n_samples=self.n_samples,
                lengths=memory_lengths,
                mask=mask if memory_lengths is not None else None)
            y_align_vectors = p_sample
        else:
            q_sample, sample_log_probs = self.sample_attn(
                q_scores, n_samples=self.n_samples,
                lengths=memory_lengths,
                mask=mask if memory_lengths is not None else None)
            y_align_vectors = q_sample
    elif self.mode == "gumbel":
        if q_scores is None or self.use_prior:
            p_sample, _ = self.sample_attn_gumbel(
                p_scores, self.temperature, n_samples=self.n_samples,
                lengths=memory_lengths,
                mask=mask if memory_lengths is not None else None)
            y_align_vectors = p_sample
        else:
            q_sample, _ = self.sample_attn_gumbel(
                q_scores, self.temperature, n_samples=self.n_samples,
                lengths=memory_lengths,
                mask=mask if memory_lengths is not None else None)
            y_align_vectors = q_sample
    elif self.mode == "enum" or self.mode == "exact":
        y_align_vectors = None
    elif self.mode == "wsram":
        assert q_scores is not None
        q_sample, sample_log_probs_q, sample_log_probs_p, \
            sample_p_div_q_log = self.sample_attn_wsram(
                q_scores, p_scores, n_samples=self.n_samples,
                lengths=memory_lengths,
                mask=mask if memory_lengths is not None else None)
        y_align_vectors = q_sample

    # context_y: K x N x T x H
    if y_align_vectors is not None:
        context_y = torch.bmm(
            y_align_vectors.view(-1, targetL, sourceL),
            memory_bank.unsqueeze(0).repeat(self.n_samples, 1, 1, 1).view(
                -1, sourceL, dim)).view(
                    self.n_samples, batch, targetL, dim)
    else:
        # For enumerate, K = S.
        # memory_bank: N x S x H
        context_y = (
            memory_bank.unsqueeze(0).repeat(targetL, 1, 1, 1)  # T, N, S, H
            .permute(2, 1, 0, 3))  # S, N, T, H
    input = input.unsqueeze(0).repeat(context_y.size(0), 1, 1, 1)
    concat_y = torch.cat([input, context_y], -1)
    # K x N x T x H
    h_y = self.tanh(self.linear_out(concat_y))

    if one_step:
        if h_c is not None:
            # N x H
            h_c = h_c.squeeze(1)
        # N x S
        c_align_vectors = c_align_vectors.squeeze(1)
        context_c = context_c.squeeze(1)

        # K x N x H
        h_y = h_y.squeeze(2)
        # K x N x S
        #y_align_vectors = y_align_vectors.squeeze(2)

        q_scores = Params(
            alpha=q_scores.alpha.squeeze(1)
                if q_scores.alpha is not None else None,
            dist_type=q_scores.dist_type,
            samples=q_sample.squeeze(2) if q_sample is not None else None,
            sample_log_probs=sample_log_probs.squeeze(2)
                if sample_log_probs is not None else None,
            sample_log_probs_q=sample_log_probs_q.squeeze(2)
                if sample_log_probs_q is not None else None,
            sample_log_probs_p=sample_log_probs_p.squeeze(2)
                if sample_log_probs_p is not None else None,
            sample_p_div_q_log=sample_p_div_q_log.squeeze(2)
                if sample_p_div_q_log is not None else None,
        ) if q_scores is not None else None
        p_scores = Params(
            alpha=p_scores.alpha.squeeze(1),
            log_alpha=log_scores.squeeze(1),
            dist_type=p_scores.dist_type,
            samples=p_sample.squeeze(2) if p_sample is not None else None,
        )

        if h_c is not None:
            # Check output sizes
            batch_, dim_ = h_c.size()
            aeq(batch, batch_)
            batch_, sourceL_ = c_align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
    else:
        assert False  # Only support input feeding.
        # T x N x H
        h_c = h_c.transpose(0, 1).contiguous()
        # T x N x S
        c_align_vectors = c_align_vectors.transpose(0, 1).contiguous()

        # T x K x N x H
        h_y = h_y.permute(2, 0, 1, 3).contiguous()
        # T x K x N x S
        #y_align_vectors = y_align_vectors.permute(2, 0, 1, 3).contiguous()

        q_scores = Params(
            alpha=q_scores.alpha.transpose(0, 1).contiguous(),
            dist_type=q_scores.dist_type,
            samples=q_sample.permute(2, 0, 1, 3).contiguous(),
        )
        p_scores = Params(
            alpha=p_scores.alpha.transpose(0, 1).contiguous(),
            log_alpha=log_scores.transpose(0, 1).contiguous(),
            dist_type=p_scores.dist_type,
            samples=p_sample.permute(2, 0, 1, 3).contiguous(),
        )

        # Check output sizes
        targetL_, batch_, dim_ = h_c.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        targetL_, batch_, sourceL_ = c_align_vectors.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    # For now, don't include samples.
    dist_info = DistInfo(
        q=q_scores,
        p=p_scores,
    )

    # h_y: samples from simplex
    #   either K x N x H, or T x K x N x H
    # h_c: convex combination of memory_bank for input feeding
    #   either N x H, or T x N x H
    # align_vectors: convex coefficients / boltzmann dist
    #   either N x S, or T x N x S
    # raw_scores: unnormalized scores
    #   either N x S, or T x N x S
    return h_y, h_c, context_c, c_align_vectors, dist_info
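# A small, self-contained sketch (toy tensors only) of the top-k truncation
# applied to the scores above before normalization: all but the k largest
# scores per query are set to -inf, so the resulting softmax is supported on
# at most k source positions.
import torch
import torch.nn.functional as F

k = 2
scores = torch.randn(2, 3, 5)                 # [batch, tgt_len, src_len]
topk, idx = scores.topk(k, dim=-1)
truncated = torch.full_like(scores, float('-inf')).scatter_(-1, idx, topk)
align = F.softmax(truncated, dim=-1)          # at most k nonzero entries per row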
def forward(self, key, value, query, mask=None):
    """
    Compute the context vector and the attention vectors.

    Args:
        key (`FloatTensor`): set of `key_len`
            key vectors `[batch, key_len, dim]`
        value (`FloatTensor`): set of `key_len`
            value vectors `[batch, key_len, dim]`
        query (`FloatTensor`): set of `query_len`
            query vectors `[batch, query_len, dim]`
        mask: binary mask indicating which keys have
            non-zero attention `[batch, query_len, key_len]`
    Returns:
        (`FloatTensor`, `FloatTensor`):

        * output context vectors `[batch, query_len, dim]`
        * one of the attention vectors `[batch, query_len, key_len]`
    """
    # CHECKS
    batch, k_len, d = key.size()
    batch_, k_len_, d_ = value.size()
    aeq(batch, batch_)
    aeq(k_len, k_len_)
    aeq(d, d_)
    batch_, q_len, d_ = query.size()
    aeq(batch, batch_)
    aeq(d, d_)
    aeq(self.model_dim % 8, 0)
    if mask is not None:
        batch_, q_len_, k_len_ = mask.size()
        aeq(batch_, batch)
        aeq(k_len_, k_len)
        aeq(q_len_, q_len)
    # END CHECKS

    batch_size = key.size(0)
    dim_per_head = self.dim_per_head
    head_count = self.head_count
    key_len = key.size(1)
    query_len = query.size(1)

    def shape(x):
        return x.view(batch_size, -1, head_count, dim_per_head) \
            .transpose(1, 2)

    def unshape(x):
        return x.transpose(1, 2).contiguous() \
            .view(batch_size, -1, head_count * dim_per_head)

    # 1) Project key, value, and query.
    key_up = shape(self.linear_keys(key))
    value_up = shape(self.linear_values(value))
    query_up = shape(self.linear_query(query))

    # 2) Calculate and scale scores.
    query_up = query_up / math.sqrt(dim_per_head)
    scores = torch.matmul(query_up, key_up.transpose(2, 3))

    if mask is not None:
        mask = mask.unsqueeze(1).expand_as(scores)
        scores = scores.masked_fill(Variable(mask), -1e18)

    # 3) Apply attention dropout and compute context vectors.
    attn = self.sm(scores)
    drop_attn = self.dropout(attn)
    context = unshape(torch.matmul(drop_attn, value_up))

    output = self.final_linear(context)

    # CHECK
    batch_, q_len_, d_ = output.size()
    aeq(q_len, q_len_)
    aeq(batch, batch_)
    aeq(d, d_)

    # Return one attn
    top_attn = attn \
        .view(batch_size, head_count, query_len, key_len)[:, 0, :, :] \
        .contiguous()
    # END CHECK

    return output, top_attn
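# A condensed, self-contained sketch (illustrative shapes; not the repo's
# module) of the head split / scaled dot-product / recombine pattern above.
import math
import torch
import torch.nn.functional as F

batch, seq_len, heads, dim_per_head = 2, 5, 4, 8
model_dim = heads * dim_per_head
x = torch.randn(batch, seq_len, model_dim)

def shape(t):    # [B, L, H*D] -> [B, H, L, D]
    return t.view(batch, -1, heads, dim_per_head).transpose(1, 2)

def unshape(t):  # [B, H, L, D] -> [B, L, H*D]
    return t.transpose(1, 2).contiguous().view(batch, -1, model_dim)

q, k, v = shape(x), shape(x), shape(x)           # self-attention
scores = torch.matmul(q / math.sqrt(dim_per_head), k.transpose(2, 3))
attn = F.softmax(scores, dim=-1)
context = unshape(torch.matmul(attn, v))         # [batch, seq_len, model_dim]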
def forward(self, tgt, memory_bank, state, memory_lengths=None):
    """
    See :obj:`tools.modules.RNNDecoderBase.forward()`
    """
    # CHECKS
    assert isinstance(state, TransformerDecoderState)
    tgt_len, tgt_batch, _ = tgt.size()
    memory_len, memory_batch, _ = memory_bank.size()
    aeq(tgt_batch, memory_batch)

    src = state.src
    src_words = src[:, :, 0].transpose(0, 1)
    tgt_words = tgt[:, :, 0].transpose(0, 1)
    src_batch, src_len = src_words.size()
    tgt_batch, tgt_len = tgt_words.size()
    aeq(tgt_batch, memory_batch, src_batch, tgt_batch)
    aeq(memory_len, src_len)

    if state.previous_input is not None:
        tgt = torch.cat([state.previous_input, tgt], 0)
    # END CHECKS

    # Initialize return variables.
    outputs = []
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []

    # Run the forward pass of the TransformerDecoder.
    emb = self.embeddings(tgt)
    if state.previous_input is not None:
        emb = emb[state.previous_input.size(0):, ]
    assert emb.dim() == 3  # len x batch x embedding_dim

    output = emb.transpose(0, 1).contiguous()
    src_memory_bank = memory_bank.transpose(0, 1).contiguous()

    padding_idx = self.embeddings.word_padding_idx
    src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
        .expand(src_batch, tgt_len, src_len)
    tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
        .expand(tgt_batch, tgt_len, tgt_len)

    saved_inputs = []
    for i in range(self.num_layers):
        prev_layer_input = None
        if state.previous_input is not None:
            prev_layer_input = state.previous_layer_inputs[i]
        output, attn, all_input \
            = self.transformer_layers[i](output, src_memory_bank,
                                         src_pad_mask, tgt_pad_mask,
                                         previous_input=prev_layer_input)
        saved_inputs.append(all_input)

    saved_inputs = torch.stack(saved_inputs)
    output = self.layer_norm(output)

    # Process the result and update the attentions.
    outputs = output.transpose(0, 1).contiguous()
    attn = attn.transpose(0, 1).contiguous()

    attns["std"] = attn
    if self._copy:
        attns["copy"] = attn

    # Update the state.
    state = state.update_state(tgt, saved_inputs)

    return outputs, state, attns
def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
            previous_input=None):
    # Args Checks
    input_batch, input_len, _ = inputs.size()
    if previous_input is not None:
        pi_batch, _, _ = previous_input.size()
        aeq(pi_batch, input_batch)
    contxt_batch, contxt_len, _ = memory_bank.size()
    aeq(input_batch, contxt_batch)

    src_batch, t_len, s_len = src_pad_mask.size()
    tgt_batch, t_len_, t_len__ = tgt_pad_mask.size()
    aeq(input_batch, contxt_batch, src_batch, tgt_batch)
    # aeq(t_len, t_len_, t_len__, input_len)
    aeq(s_len, contxt_len)
    # END Args Checks

    dec_mask = torch.gt(tgt_pad_mask +
                        self.mask[:, :tgt_pad_mask.size(1),
                                  :tgt_pad_mask.size(1)], 0)
    input_norm = self.layer_norm_1(inputs)
    all_input = input_norm
    if previous_input is not None:
        all_input = torch.cat((previous_input, input_norm), dim=1)
        dec_mask = None
    query, attn = self.self_attn(all_input, all_input, input_norm,
                                 mask=dec_mask)
    query = self.drop(query) + inputs

    query_norm = self.layer_norm_2(query)
    mid, attn = self.context_attn(memory_bank, memory_bank, query_norm,
                                  mask=src_pad_mask)
    output = self.feed_forward(self.drop(mid) + query)

    # CHECKS
    output_batch, output_len, _ = output.size()
    aeq(input_len, output_len)
    aeq(contxt_batch, output_batch)

    n_batch_, t_len_, s_len_ = attn.size()
    aeq(input_batch, n_batch_)
    aeq(contxt_len, s_len_)
    aeq(input_len, t_len_)
    # END CHECKS

    return output, attn, all_input
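# A minimal, self-contained illustration (toy sizes) of how the decoder
# self-attention mask above combines the target padding mask with an
# upper-triangular "future" mask, so a position can attend neither to padding
# nor to later positions. PAD_IDX and the token tensor are illustrative only.
import torch

PAD_IDX = 1
tgt_words = torch.tensor([[5, 6, PAD_IDX]])                  # [batch=1, len=3]
tgt_len = tgt_words.size(1)

tgt_pad_mask = tgt_words.eq(PAD_IDX).unsqueeze(1) \
    .expand(1, tgt_len, tgt_len).to(torch.uint8)             # [B, T, T]
future_mask = torch.triu(torch.ones(1, tgt_len, tgt_len,
                                    dtype=torch.uint8), diagonal=1)
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
# dec_mask[b, i, j] is True when position j is padding or j > i.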