def forward(self, input, lengths=None, hidden=None):
    """ See :obj:`EncoderBase.forward()`"""
    self._check_args(input, lengths, hidden)

    emb = self.embeddings(input)
    s_len, n_batch, emb_dim = emb.size()

    out = emb.transpose(0, 1).contiguous()
    words = input[:, :, 0].transpose(0, 1)
    # CHECKS
    out_batch, out_len, _ = out.size()
    w_batch, w_len = words.size()
    aeq(out_batch, w_batch)
    aeq(out_len, w_len)
    # END CHECKS

    # Make mask.
    padding_idx = self.embeddings.word_padding_idx
    mask = words.data.eq(padding_idx).unsqueeze(1) \
        .expand(w_batch, w_len, w_len)
    # Run the forward pass of every layer of the transformer.
    for i in range(self.num_layers):
        out = self.transformer[i](out, mask)
    out = self.layer_norm(out)

    return Variable(emb.data), out.transpose(0, 1).contiguous()
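# The size checks above and throughout this listing rely on aeq(), imported from
# onmt's utility module. A minimal sketch of that helper, assuming (as its usage
# here suggests) that it simply asserts all of its arguments are equal
# (hypothetical re-implementation, not necessarily the library's exact code):
def aeq(*args):
    """Assert that all arguments have the same value."""
    arguments = iter(args)
    first = next(arguments)
    assert all(arg == first for arg in arguments), \
        "Not all arguments have the same value: " + str(args)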
def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
    """
    Private helper for running the specific RNN forward pass.
    Must be overridden by all subclasses.
    Args:
        tgt (LongTensor): a sequence of input token tensors
            [len x batch x nfeats].
        memory_bank (FloatTensor): output (tensor sequence) from the
            encoder RNN of size (src_len x batch x hidden_size).
        state (FloatTensor): hidden state from the encoder RNN for
            initializing the decoder.
        memory_lengths (LongTensor): the source memory_bank lengths.
    Returns:
        decoder_final (Variable): final hidden state from the decoder.
        decoder_outputs ([FloatTensor]): an array with the output of every
            time step from the decoder.
        attns (dict of (str, [FloatTensor])): a dictionary of attention
            Tensors, one entry per attention type, collected from every
            time step of the decoder.
    """
    assert not self._copy  # TODO, no support yet.
    assert not self._coverage  # TODO, no support yet.

    # Initialize local and return variables.
    attns = {}
    emb = self.embeddings(tgt)

    # Run the forward pass of the RNN.
    if isinstance(self.rnn, nn.GRU):
        rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
    else:
        rnn_output, decoder_final = self.rnn(emb, state.hidden)

    # Check
    tgt_len, tgt_batch, _ = tgt.size()
    output_len, output_batch, _ = rnn_output.size()
    aeq(tgt_len, output_len)
    aeq(tgt_batch, output_batch)
    # END

    # Calculate the attention.
    decoder_outputs, p_attn = self.attn(
        rnn_output.transpose(0, 1).contiguous(),
        memory_bank.transpose(0, 1),
        memory_lengths=memory_lengths
    )
    attns["std"] = p_attn

    # Calculate the context gate.
    if self.context_gate is not None:
        decoder_outputs = self.context_gate(
            emb.view(-1, emb.size(2)),
            rnn_output.view(-1, rnn_output.size(2)),
            decoder_outputs.view(-1, decoder_outputs.size(2))
        )
        decoder_outputs = \
            decoder_outputs.view(tgt_len, tgt_batch, self.hidden_size)

    decoder_outputs = self.dropout(decoder_outputs)
    return decoder_final, decoder_outputs, attns
def _example_dict_iter(self, line, index): line = line.split() if self.line_truncate: line = line[:self.line_truncate] words, feats, n_feats = TextDataset.extract_text_features(line) example_dict = {self.side: words, "indices": index} if feats: # All examples must have same number of features. aeq(self.n_feats, n_feats) prefix = self.side + "_feat_" example_dict.update((prefix + str(j), f) for j, f in enumerate(feats)) return example_dict
def forward(self, input, context, state, context_lengths=None): """ Forward through the decoder. Args: input (LongTensor): a sequence of input tokens tensors of size (len x batch x nfeats). context (FloatTensor): output(tensor sequence) from the encoder RNN of size (src_len x batch x hidden_size). state (FloatTensor): hidden state from the encoder RNN for initializing the decoder. context_lengths (LongTensor): the source context lengths. Returns: outputs (FloatTensor): a Tensor sequence of output from the decoder of shape (len x batch x hidden_size). state (FloatTensor): final hidden state from the decoder. attns (dict of (str, FloatTensor)): a dictionary of different type of attention Tensor from the decoder of shape (src_len x batch). """ # Args Check assert isinstance(state, RNNDecoderState) input_len, input_batch, _ = input.size() contxt_len, contxt_batch, _ = context.size() aeq(input_batch, contxt_batch) # END Args Check # Run the forward pass of the RNN. hidden, outputs, attns, coverage, rnn_output, emb = \ self._run_forward_pass(input, context, state, context_lengths=context_lengths) # Update the state with the result. final_output = outputs[-1] state.update_state(hidden, final_output.unsqueeze(0), coverage.unsqueeze(0) if coverage is not None else None) # Concatenates sequence of tensors along a new dimension. outputs = torch.stack(outputs) rnn_output = torch.stack(rnn_output) for k in attns: attns[k] = torch.stack(attns[k]) return ( outputs, state, attns, # pointer_gen rnn_output, emb )
def _check_args_double_enc(self, input, lengths_src=None, lengths_inter=None, hidden_src=None, hidden_inter=None): #SRC s_len, n_batch, n_feats = input[0].size() if lengths_src is not None: n_batch_, = lengths_src.size() aeq(n_batch, n_batch_) #INTER s_len, n_batch, n_feats = input[1].size() if lengths_inter is not None: n_batch_, = lengths_inter.size() aeq(n_batch, n_batch_)
def forward(self, tgt, memory_bank, state, memory_lengths=None):
    """
    Args:
        tgt (`LongTensor`): sequences of padded tokens
                            `[tgt_len x batch x nfeats]`.
        memory_bank (`FloatTensor`): vectors from the encoder
                `[src_len x batch x hidden]`.
        state (:obj:`onmt.Models.DecoderState`):
                decoder state object to initialize the decoder
        memory_lengths (`LongTensor`): the padded source lengths
                `[batch]`.
    Returns:
        (`FloatTensor`,:obj:`onmt.Models.DecoderState`,`FloatTensor`):
            * decoder_outputs: output from the decoder (after attn)
                     `[tgt_len x batch x hidden]`.
            * decoder_state: final hidden state from the decoder
            * attns: distribution over src at each tgt
                    `[tgt_len x batch x src_len]`.
    """
    # Check
    assert isinstance(state, RNNDecoderState)
    tgt_len, tgt_batch, _ = tgt.size()
    _, memory_batch, _ = memory_bank.size()
    aeq(tgt_batch, memory_batch)
    # END

    # Run the forward pass of the RNN.
    decoder_final, decoder_outputs, attns = self._run_forward_pass(
        tgt, memory_bank, state, memory_lengths=memory_lengths)

    # Update the state with the result.
    final_output = decoder_outputs[-1]
    coverage = None
    if "coverage" in attns:
        coverage = attns["coverage"][-1].unsqueeze(0)
    state.update_state(decoder_final, final_output.unsqueeze(0), coverage)

    # Concatenates sequence of tensors along a new dimension.
    # Changed for torch 0.4: with input feeding the outputs are still a
    # list of per-step tensors and need stacking.
    if not isinstance(decoder_outputs, torch.Tensor):
        decoder_outputs = torch.stack(decoder_outputs)
    for k in attns:
        if not isinstance(attns[k], torch.Tensor):
            attns[k] = torch.stack(attns[k])

    return decoder_outputs, state, attns
def forward(self, input, context, state, context_lengths=None, **kwargs):
    """
    Args:
        input (`LongTensor`): sequences of padded tokens
                            `[tgt_len x batch x nfeats]`.
        context (`FloatTensor`): vectors from the encoder
                `[src_len x batch x hidden]`.
        state (:obj:`onmt.Models.DecoderState`):
                decoder state object to initialize the decoder
        context_lengths (`LongTensor`): the padded source lengths
                `[batch]`.
    Returns:
        (`FloatTensor`,:obj:`onmt.Models.DecoderState`,`FloatTensor`):
            * outputs: output from the decoder
                     `[tgt_len x batch x hidden]`.
            * state: final hidden state from the decoder
            * attns: distribution over src at each tgt
                    `[tgt_len x batch x src_len]`.
    """
    # Args Check
    assert isinstance(state, RNNDecoderState)
    input_len, input_batch, _ = input.size()
    contxt_len, contxt_batch, _ = context.size()
    aeq(input_batch, contxt_batch)
    # END Args Check

    # Run the forward pass of the RNN.
    # All the latent variables and additional inputs are found in kwargs.
    hidden, outputs, attns, coverage = self._run_forward_pass(
        input, context, state, context_lengths=context_lengths, **kwargs)

    # Update the state with the result.
    final_output = outputs[-1]
    state.update_state(hidden, final_output.unsqueeze(0),
                       coverage.unsqueeze(0)
                       if coverage is not None else None)

    # Concatenates sequence of tensors along a new dimension.
    outputs = torch.stack(outputs)
    for k in attns:
        if k not in ["q_latent", "p_latent"]:
            attns[k] = torch.stack(attns[k])

    return outputs, state, attns
def forward_mm(self, input, img_proj, lengths=None, hidden=None):
    """
    Args:
        input (:obj:`LongTensor`):
           padded sequences of sparse indices `[src_len x batch x nfeat]`
        lengths (:obj:`LongTensor`): length of each sequence `[batch]`
        hidden (class specific): initial hidden state.
    Returns:
        (tuple of :obj:`FloatTensor`, :obj:`FloatTensor`):
            * final encoder state, used to initialize decoder
               `[layers x batch x hidden]`
            * contexts for attention, `[src_len x batch x hidden]`
    """
    self._check_args(input, lengths, hidden)

    emb = self.embeddings(input)
    s_len, n_batch, emb_dim = emb.size()

    emb = emb.transpose(0, 1).contiguous()  # (batch, src_len, emb_dim)
    input_mm = torch.cat([emb, img_proj], dim=1)
    out = input_mm
    # words = input[:, :, 0].transpose(0, 1)  # (batch, src_len)
    words = emb[:, :, 0]
    # CHECKS
    out_batch, out_len, _ = out.size()
    w_batch, w_len = words.size()
    aeq(out_batch, w_batch)
    aeq(s_len, w_len)
    # END CHECKS

    # Make mask. The mask here has no effect.
    padding_idx = self.embeddings.word_padding_idx
    mask = words.data.eq(padding_idx).unsqueeze(1) \
        .expand(w_batch, out_len, s_len)
    # assert not words.data.eq(padding_idx).max(), \
    #     "there are some mask items equal to 1"

    # Run the forward pass of every layer of the transformer.
    for i in range(self.num_layers):
        out, attn = self.transformer[i](emb, out, mask)  # attn 1x49x10
    out = self.layer_norm(out)

    return Variable(input_mm.data), out.transpose(0, 1).contiguous(), attn
def forward(self, hidden, attn, src_map, align=None, ptrs=None):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.
    Args:
        hidden (`FloatTensor`): hidden outputs `[batch*tlen, input_size]`
        attn (`FloatTensor`): attn for each `[batch*tlen, src_len]`
        src_map (`FloatTensor`):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocab.
            `[src_len, batch, extra_words]`
    """
    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, self.tgt_dict.stoi[onmt.io.PAD_WORD]] = -float('inf')
    prob = F.softmax(logits)

    # Probability of copying p(z=1) batch.
    p_copy = F.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    if self.training:
        align_unk = align.eq(0).float().view(-1, 1)
        align_not_unk = align.ne(0).float().view(-1, 1)
        out_prob = torch.mul(prob, align_unk.expand_as(prob))
        mul_attn = torch.mul(attn, align_not_unk.expand_as(attn))
        mul_attn = torch.mul(mul_attn, ptrs.view(-1, slen_).float())
    else:
        out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
        mul_attn = torch.mul(attn, p_copy.expand_as(attn))

    copy_prob = torch.bmm(
        mul_attn.view(-1, batch, slen).transpose(0, 1),
        src_map.transpose(0, 1)).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)

    return torch.cat([out_prob, copy_prob], 1), p_copy
def score(self, h_t, h_s, entity_attn=False): """ Args: h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]` h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]` Returns: :obj:`FloatTensor`: raw attention scores (unnormalized) for each src index `[batch x tgt_len x src_len]` """ # Check input sizes src_batch, src_len, src_dim = h_s.size() tgt_batch, tgt_len, tgt_dim = h_t.size() aeq(src_batch, tgt_batch) aeq(self.dim, tgt_dim) if self.attn_type in ["general", "dot"]: if self.attn_type == "general": h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim) if entity_attn: h_t_ = self.linear_in_entity(h_t_) else: h_t_ = self.linear_in(h_t_) h_t = h_t_.view(tgt_batch, tgt_len, src_dim) h_s_ = h_s.transpose(1, 2) # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) return torch.bmm(h_t, h_s_) else: dim = self.dim wq = self.linear_query(h_t.view(-1, dim)) wq = wq.view(tgt_batch, tgt_len, 1, dim) wq = wq.expand(tgt_batch, tgt_len, src_len, dim) uh = self.linear_context(h_s.contiguous().view(-1, dim)) uh = uh.view(src_batch, 1, src_len, dim) uh = uh.expand(src_batch, tgt_len, src_len, dim) # (batch, t_len, s_len, d) wquh = self.tanh(wq + uh) return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
def forward(self, tgt, memory_bank, state, memory_lengths=None): """ Args: tgt (`LongTensor`): sequences of padded tokens `[tgt_len x batch x nfeats]`. memory_bank (`FloatTensor`): vectors from the encoder `[src_len x batch x hidden]`. state (:obj:`onmt.Models.DecoderState`): decoder state object to initialize the decoder memory_lengths (`LongTensor`): the padded source lengths `[batch]`. Returns: (`FloatTensor`,:obj:`onmt.Models.DecoderState`,`FloatTensor`): * decoder_outputs: output from the decoder (after attn) `[tgt_len x batch x hidden]`. * decoder_state: final hidden state from the decoder * attns: distribution over src at each tgt `[tgt_len x batch x src_len]`. """ # Check assert isinstance(state, RNNDecoderState) tgt_len, tgt_batch, _ = tgt.size() _, memory_batch, _ = memory_bank.size() aeq(tgt_batch, memory_batch) # END # Run the forward pass of the RNN. decoder_final, decoder_outputs, attns = self._run_forward_pass( tgt, memory_bank, state, memory_lengths=memory_lengths) # Update the state with the result. final_output = decoder_outputs[-1] coverage = None if "coverage" in attns: coverage = attns["coverage"][-1].unsqueeze(0) state.update_state(decoder_final, final_output.unsqueeze(0), coverage) # Concatenates sequence of tensors along a new dimension. decoder_outputs = torch.stack(decoder_outputs) for k in attns: attns[k] = torch.stack(attns[k]) return decoder_outputs, state, attns
def forward(self, input): """ Computes the partly linked embeddings for words. Args: input (`LongTensor`): index tensor `[len x batch x 1]` Return: `FloatTensor`: word embeddings `[len x batch x embedding_size]` """ in_length, in_batch, nfeat = input.size() aeq(nfeat, 1) flat_input = input.view(-1) cluster_indices = self.cluster_mapping.index_select(0, flat_input) concat = torch.cat([input, cluster_indices.view(input.shape)], dim=-1) emb = self.make_embedding(concat) return emb
def _get_word_context(self, query, context, index, mask_word): """ Verify sizes """ b_size, t_size, d_size = query.size() b_size_, s_size, d_size_ = context.size() aeq(d_size, d_size_) b_size__, c_size = index.size() aeq(b_size, b_size__) b_size__, t_size_, s_size_ = mask_word.size() aeq(b_size_, b_size__) aeq(s_size, s_size_) aeq(t_size, t_size_) """ Padding index of previous invalid sentences index (<0) to 0, and saving mask for sentences """ mask_sent = index < 0 index_ = copy.deepcopy(index) index_[mask_sent] = 0 """ Select context with index vector """ context_ = context.view(b_size_, -1).expand(b_size, b_size_, s_size * d_size) index__ = index_.unsqueeze(2).expand(b_size, c_size, s_size * d_size) context_word = torch.gather(context_, 1, Variable(index__, requires_grad=False)).view( b_size * c_size, s_size, d_size) """ Create complete mask for context: word padding + sentence padding """ mask_ = mask_word.contiguous().view(b_size_, -1).expand(b_size, b_size_, t_size_ * s_size) index__ = index_.unsqueeze(2).expand(b_size, c_size, t_size_ * s_size) context_pad_mask = torch.gather(mask_, 1, index__).view(b_size * c_size, t_size_, s_size) mask_sent_ = mask_sent.unsqueeze(2).expand( b_size, c_size, t_size_ * s_size).contiguous().view(b_size * c_size, t_size_, s_size) context_pad_mask[mask_sent_] = self.padding_idx """ Reshape query for future operations """ query_ = query.unsqueeze(1).expand(b_size, c_size, t_size, d_size).contiguous().view( b_size * c_size, t_size, d_size) return query_, context_word, context_pad_mask
def forward(self, input, lengths=None, hidden=None, contexts=None,
            neg=None, tau=0.5, scale=0.5):
    """ See :obj:`EncoderBase.forward()`"""
    self._check_args(input, lengths, hidden)

    emb = self.embeddings(input, contexts=contexts, neg=neg,
                          tau=tau, scale=scale)
    if neg is not None:
        sense_loss = emb[1]
        emb = emb[0]
    s_len, n_batch, emb_dim = emb.size()

    out = emb.transpose(0, 1).contiguous()
    words = input[:, :, 0].transpose(0, 1)
    # CHECKS
    out_batch, out_len, _ = out.size()
    w_batch, w_len = words.size()
    aeq(out_batch, w_batch)
    aeq(out_len, w_len)
    # END CHECKS

    # Make mask.
    padding_idx = self.embeddings.word_padding_idx
    mask = words.data.eq(padding_idx).unsqueeze(1) \
        .expand(w_batch, w_len, w_len)
    # Run the forward pass of every layer of the transformer.
    for i in range(self.num_layers):
        out = self.transformer[i](out, mask)
    out = self.layer_norm(out)

    if neg is None:
        return Variable(emb.data), out.transpose(0, 1).contiguous()
    else:
        return Variable(emb.data), out.transpose(
            0, 1).contiguous(), sense_loss
def forward(self, tgt, memory_bank, state, memory_lengths=None, q_scores=None, tgt_emb=None): # Check assert isinstance(state, RNNDecoderState) tgt_len, tgt_batch, _ = tgt.size() _, memory_batch, _ = memory_bank.size() aeq(tgt_batch, memory_batch) # END # Run the forward pass of the RNN. decoder_final, decoder_outputs, input_feed, attns, dist_info, decoder_outputs_baseline = self._run_forward_pass( tgt, memory_bank, state, memory_lengths=memory_lengths, q_scores=q_scores, tgt_emb=tgt_emb) # Update the state with the result. final_output = decoder_outputs[-1] coverage = None if "coverage" in attns: coverage = attns["coverage"][-1].unsqueeze(0) state.update_state(decoder_final, input_feed.unsqueeze(0), coverage) # Concatenates sequence of tensors along a new dimension. # T x K x N x H decoder_outputs = torch.stack(decoder_outputs, dim=0) if len(decoder_outputs_baseline) > 0: decoder_outputs_baseline = torch.stack(decoder_outputs_baseline, dim=0) else: decoder_outputs_baseline = None for k in attns: attns[k] = torch.stack(attns[k]) return decoder_outputs, state, attns, dist_info, decoder_outputs_baseline
def coalesce_datasets(datasets):
    """Coalesce all dataset instances."""
    final = datasets[0]
    for d in datasets[1:]:
        # `src_vocabs` is a list of `torchtext.vocab.Vocab`.
        # Each sentence transforms into one Vocab.
        # Coalesce them into one big list.
        final.src_vocabs += d.src_vocabs

        # All datasets have same number of features.
        aeq(final.n_src_feats, d.n_src_feats)
        aeq(final.n_tgt_feats, d.n_tgt_feats)

        # `examples` is a list of `torchtext.data.Example`.
        # Coalesce them into one big list.
        final.examples += d.examples

        # All datasets have same fields, no need to update.
    return final
def forward(self, input): in_length, in_batch, nfeat = input.size() aeq(nfeat, len(self.emb_luts)) emb = self.make_embedding(input) out_length, out_batch, emb_size = emb.size() aeq(in_length, out_length) aeq(in_batch, out_batch) aeq(emb_size, self.embedding_size) return emb
def forward(self, src, lengths=None, encoder_state=None, entities_list=None, entities_len=None): "See :obj:`EncoderBase.forward()`" self._check_args(src, lengths, encoder_state) emb = self.dropout(self.embeddings(src)) s_len, batch, emb_dim = emb.size() mean = emb.mean(0).expand(self.num_layers, batch, emb_dim) memory_bank = emb encoder_final = (mean, mean) s_len_batch, batch_ent, num_entities = entities_list.size() s_len_entities, batch_len_ent = entities_len.size() aeq(batch_len_ent, batch_ent) aeq(num_entities, s_len_entities) ent_emb = emb.unsqueeze(1).expand(-1, num_entities, -1, -1) s_len, batch, emb_dim = emb.size() ent_dim = entities_list.transpose(1,2).unsqueeze(3).expand(-1, -1, -1, emb_dim) ent_len_dim = entities_len.unsqueeze(2).expand(-1, -1, emb_dim) ent_emb = (ent_emb[:s_len_batch,:,:,:]*ent_dim).sum(0) ent_emb = ent_emb/ent_len_dim return encoder_final, memory_bank, ent_emb
def get_normal_scores(self, h_s, h_t):
    """
    h_s: [batch x src_length x rnn_size]
    h_t: [batch x tgt_length x rnn_size]
    """
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.rnn_size, src_dim)

    h_t_expand = h_t.unsqueeze(2).expand(-1, -1, src_len, -1)
    h_s_expand = h_s.unsqueeze(1).expand(-1, tgt_len, -1, -1)
    # [batch, tgt_len, src_len, src_dim]
    h_expand = torch.cat((h_t_expand, h_s_expand), dim=3)
    h_fold = h_expand.contiguous().view(-1, src_dim + tgt_dim)

    h_enc = self.softplus(self.linear_1(h_fold))
    h_enc = self.softplus(self.linear_2(h_enc))

    h_mean = self.softplus(self.mean_out(h_enc))
    h_var = self.softplus(self.var_out(h_enc))

    h_mean = h_mean.view(tgt_batch, tgt_len, src_len)
    h_var = h_var.view(tgt_batch, tgt_len, src_len)
    return [h_mean, h_var]
def score(self, h_t, h_s): # Check input sizes src_batch, src_len, src_dim = h_s.size() tgt_batch, tgt_len, tgt_dim = h_t.size() aeq(src_batch, tgt_batch) aeq(src_dim, tgt_dim) aeq(self.dim, src_dim) if self.attn_type in ["general", "dot"]: if self.attn_type == "general": h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim) h_t_ = self.linear_in(h_t_) h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim) h_s_ = h_s.transpose(1, 2) # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) return torch.bmm(h_t, h_s_) else: dim = self.dim wq = self.linear_query(h_t.view(-1, dim)) wq = wq.view(tgt_batch, tgt_len, 1, dim) wq = wq.expand(tgt_batch, tgt_len, src_len, dim) uh = self.linear_context(h_s.contiguous().view(-1, dim)) uh = uh.view(src_batch, 1, src_len, dim) uh = uh.expand(src_batch, tgt_len, src_len, dim) # (batch, t_len, s_len, d) wquh = self.tanh(wq + uh) return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
def forward(self, hidden, attn, src_map, rnn_output, src_emb):
    """
    Computes p(w) = p(z=1) p_{copy}(w|z=1) + p(z=0) * p_{softmax}(w|z=0)
    """
    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, self.tgt_dict.stoi[onmt.IO.PAD_WORD]] = -float('inf')
    prob = F.softmax(logits)

    # Probability of copying p(z=1) batch.
    if self.pointer_gen:
        # p_gen = sigm(w1*hidden + w2*decoder_state + w3*decoder_input)
        # hidden = post-attention hidden state
        # decoder_state = pre-attention hidden state
        copy = F.sigmoid(
            self.linear_hidden(hidden) +
            self.linear_decoder_state(rnn_output) +
            self.linear_decoder_input(src_emb))
    else:
        copy = F.sigmoid(self.linear_copy(hidden))

    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = torch.mul(
        prob, 1 - copy.expand_as(prob.unsqueeze(0))).squeeze(0)
    mul_attn = torch.mul(attn, copy.expand_as(attn.unsqueeze(0)))
    copy_prob = torch.bmm(
        mul_attn.view(-1, batch, slen).transpose(0, 1),
        src_map.transpose(0, 1)).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)

    return torch.cat([out_prob, copy_prob], 1)
def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None): assert not self._copy # TODO, no support yet. assert not self._coverage # TODO, no support yet. # Initialize local and return variables. attns = {} emb = self.embeddings(tgt) # Run the forward pass of the RNN. if isinstance(self.rnn, nn.GRU): rnn_output, decoder_final = self.rnn(emb, state.hidden[0]) else: rnn_output, decoder_final = self.rnn(emb, state.hidden) # Check tgt_len, tgt_batch, _ = tgt.size() output_len, output_batch, _ = rnn_output.size() aeq(tgt_len, output_len) aeq(tgt_batch, output_batch) # END # Calculate the attention. decoder_outputs, p_attn = self.attn(rnn_output.transpose( 0, 1).contiguous(), memory_bank.transpose(0, 1), memory_lengths=memory_lengths) attns["std"] = p_attn # Calculate the context gate. if self.context_gate is not None: decoder_outputs = self.context_gate( emb.view(-1, emb.size(2)), rnn_output.view(-1, rnn_output.size(2)), decoder_outputs.view(-1, decoder_outputs.size(2))) decoder_outputs = \ decoder_outputs.view(tgt_len, tgt_batch, self.hidden_size) decoder_outputs = self.dropout(decoder_outputs) return decoder_final, decoder_outputs, attns
def score(self, h_t, h_s): """ Args: h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]` h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]` Returns: :obj:`FloatTensor`: raw attention scores (unnormalized) for each src index `[batch x tgt_len x src_len]` """ # Check input sizes src_batch, src_len, src_dim = h_s.size() tgt_batch, tgt_len, tgt_dim = h_t.size() aeq(src_batch, tgt_batch) aeq(src_dim, tgt_dim) aeq(self.dim, src_dim) if self.attn_type in ["general", "dot"]: if self.attn_type == "general": h_t_ = h_t.view(tgt_batch*tgt_len, tgt_dim) h_t_ = self.linear_in(h_t_) h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim) h_s_ = h_s.transpose(1, 2) # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) return torch.bmm(h_t, h_s_)
def _get_word_context(self, query, context, index, mask_word): b_size, t_size, d_size = query.size() b_size_, s_size, d_size_ = context.size() aeq(d_size, d_size_) b_size__, c_size = index.size() aeq(b_size, b_size__) b_size__, t_size_, s_size_ = mask_word.size() aeq(b_size_, b_size__) aeq(s_size, s_size_) aeq(t_size, t_size_) # Create padding mask for previous sentences mask_sent = index < 0 index_ = copy.deepcopy(index) index_[mask_sent] = 0 # Get context context_ = context.view(b_size_, -1).expand(b_size, b_size_, s_size * d_size) index__ = index_.unsqueeze(2).expand(b_size, c_size, s_size * d_size) context_word = torch.gather(context_, 1, Variable(index__, requires_grad=False)).view(b_size * c_size, s_size, d_size) # Get mask for context mask_ = mask_word.contiguous().view(b_size_, -1).expand(b_size, b_size_, t_size_ * s_size) index__ = index_.unsqueeze(2).expand(b_size, c_size, t_size_ * s_size) context_pad_mask = torch.gather(mask_, 1, index__).view(b_size * c_size, t_size_, s_size) # Mask previous sentences mask_sent_ = mask_sent.unsqueeze(2).expand(b_size, c_size, t_size_ * s_size).contiguous().view(b_size * c_size, t_size_, s_size) context_pad_mask[mask_sent_] = self.padding_idx # Expand query for each context sentence query_ = query.unsqueeze(1).expand(b_size, c_size, t_size, d_size).contiguous().view(b_size * c_size, t_size, d_size) return query_, context_word, context_pad_mask,
def forward(self, input, contexts=None, neg=None, tau=0.5, scale=0.5): """ Computes the embeddings for words and features. Args: input (`LongTensor`): index tensor `[len x batch x nfeat]` Return: `FloatTensor`: word embeddings `[len x batch x embedding_size]` """ in_length, in_batch, nfeat = input.size() word_emb, sense_loss = None, None if self.SenseModule is not None: aeq(nfeat - 1, len(self.emb_luts)) assert contexts is not None if neg is None: word_emb = self.SenseModule(input[:, :, 0], contexts, tau=tau, scale=scale) else: word_emb, sense_loss = self.SenseModule(input[:, :, 0], contexts, neg=neg, tau=tau, scale=scale) else: aeq(nfeat, len(self.emb_luts)) emb = self.make_embedding((input, word_emb)) out_length, out_batch, emb_size = emb.size() aeq(in_length, out_length) aeq(in_batch, out_batch) aeq(emb_size, self.embedding_size) if neg is None: return emb else: return emb, sense_loss
def forward(self, input, context, state, context_lengths=None, r_input=None): # Args Check assert isinstance(state, RNNDecoderState) input_len, input_batch, _ = input.size() contxt_len, contxt_batch, _ = context.size() aeq(input_batch, contxt_batch) # END Args Check # Run the forward pass of the RNN. context = self.context_mlp(context) if self.training: bk_att_output, bk_rnn_output = self._run_backward_pass( r_input, context, state) self.bk_rnn_output = bk_rnn_output.detach() # self.bk_rnn_output = bk_rnn_output hidden, outputs, attns, coverage = self._run_forward_pass( input, context, state, context_lengths=context_lengths) # Update the state with the result. final_output = outputs[-1] state.update_state( hidden, final_output.unsqueeze(0), coverage.unsqueeze(0) if coverage is not None else None) # Concatenates sequence of tensors along a new dimension. outputs = torch.stack(outputs) for k in attns: attns[k] = torch.stack(attns[k]) if self.training: return outputs, bk_att_output, state, attns else: return outputs, state, attns
def _example_dict_iter(self, line, index): if self.symbol_representation == "char": line = list(line.strip()) elif self.symbol_representation == "word": line = line.split() if self.line_truncate: line = line[:self.line_truncate] words, feats, n_feats = TextDataset.extract_text_features(line) if self.revert: words = tuple(reversed(words)) feats = tuple(reversed(feats)) example_dict = {self.side: words, "indices": index} if feats: # All examples must have same number of features. aeq(self.n_feats, n_feats) prefix = self.side + "_feat_" example_dict.update( (prefix + str(j), f) for j, f in enumerate(feats)) return example_dict
def _example_dict_iter(self, line, index): line = line.split() if self.line_truncate: line = line[:self.line_truncate] words, feats, n_feats = TextDataset.extract_text_features(line) example_dict = {self.side: words, "indices": index} if self.side == 'tgt1': example_dict = { self.side: words, 'tgt1_planning': [int(word) for word in words], 'player_row_indices': [int(word) for word in words], 'team_row_indices': [int(word) for word in words], "indices": index } if feats: # All examples must have same number of features. aeq(self.n_feats, n_feats) prefix = self.side + "_feat_" example_dict.update( (prefix + str(j), f) for j, f in enumerate(feats)) return example_dict
def _get_sent_context(self, query, context_word, context_index, attn_word): b_size, t_size, d_size = query.size() _, c_size = context_index.size() # Sequence size now context_word is context size context_sent = context_word.view(b_size, c_size, t_size, d_size).transpose(1, 2).contiguous().view( b_size * t_size, c_size, d_size) # Creating the mask for padding by word and sentence mask_sent = context_index < 0 context_pad_mask = mask_sent.unsqueeze(1).expand(b_size, t_size, c_size).contiguous().view(b_size * t_size, -1) context_pad_mask = context_pad_mask.unsqueeze(1).contiguous() # Re-arrange the query query_ = query.view(b_size * t_size, 1, d_size) _, h, t, s = attn_word.size() aeq(t, t_size) attn_word = attn_word.view(b_size, c_size, h, t, s) return query_, context_sent, context_pad_mask, attn_word
def forward(self, base_target_emb, input, encoder_out_top,
            encoder_out_combine):
    """
    It works like Luong attention: conv attention takes a key matrix,
    a value matrix and a query vector. The attention weight is computed
    from the key matrix and the query vector, then used for a weighted
    sum over the value matrix. The same operation is applied in each
    decoder conv layer.
    Args:
        base_target_emb: target emb tensor
        input: output of decode conv
        encoder_out_top: the key matrix for calculation of attention
                         weight, which is the top output of encode conv
        encoder_out_combine: the value matrix for the attention-weighted
                             sum, which is the combination of base emb
                             and top output of encode
    """
    # checks
    batch, channel, height, width = base_target_emb.size()
    batch_, channel_, height_, width_ = input.size()
    aeq(batch, batch_)
    aeq(height, height_)

    enc_batch, enc_channel, enc_height = encoder_out_top.size()
    enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size()
    aeq(enc_batch, enc_batch_)
    aeq(enc_height, enc_height_)

    preatt = seq_linear(self.linear_in, input)
    target = (base_target_emb + preatt) * SCALE_WEIGHT
    target = torch.squeeze(target, 3)
    target = torch.transpose(target, 1, 2)
    pre_attn = torch.bmm(target, encoder_out_top)

    if self.mask is not None:
        pre_attn.data.masked_fill_(self.mask, -float('inf'))

    pre_attn = pre_attn.transpose(0, 2)
    attn = F.softmax(pre_attn)
    attn = attn.transpose(0, 2).contiguous()

    context_output = torch.bmm(
        attn, torch.transpose(encoder_out_combine, 1, 2))
    context_output = torch.transpose(
        torch.unsqueeze(context_output, 3), 1, 2)
    return context_output, attn
def forward(self, hidden, attn, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.
    Args:
        hidden (`FloatTensor`): hidden outputs `[batch*tlen, input_size]`
        attn (`FloatTensor`): attn for each `[batch*tlen, src_len]`
        src_map (`FloatTensor`):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocab.
            `[src_len, batch, extra_words]`
    """
    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, self.tgt_dict.stoi[onmt.io.PAD_WORD]] = -float('inf')
    prob = F.softmax(logits)

    # Probability of copying p(z=1) batch.
    p_copy = F.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
    mul_attn = torch.mul(attn, p_copy.expand_as(attn))
    copy_prob = torch.bmm(mul_attn.view(-1, batch, slen)
                          .transpose(0, 1),
                          src_map.transpose(0, 1)).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)

    return torch.cat([out_prob, copy_prob], 1)
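# A hedged, self-contained sketch of the src_map layout this copy generator
# consumes: one indicator row per source position over a small batch-local
# "extended" vocabulary. The sizes and indices below are invented for
# illustration only.
import torch

src_len, batch, extra_words = 4, 2, 5  # toy sizes, illustration only
# index of each source token in its batch-local extended vocab
src_ex_ids = torch.LongTensor([[1, 2], [2, 3], [3, 1], [0, 4]])  # [src_len, batch]
src_map_demo = torch.zeros(src_len, batch, extra_words)
src_map_demo.scatter_(2, src_ex_ids.unsqueeze(2), 1.0)
# Each [src_len, extra_words] slice is a one-hot matrix per batch element,
# so bmm(attn, src_map) accumulates attention mass per extended-vocab entry.
print(src_map_demo.size())  # torch.Size([4, 2, 5])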
def forward(self, hidden, attn, src_map):
    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, self.tgt_dict.stoi[onmt.io.PAD_WORD]] = -float('inf')
    prob = F.softmax(logits)

    # Probability of copying p(z=1) batch.
    p_copy = F.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
    mul_attn = torch.mul(attn, p_copy.expand_as(attn))
    copy_prob = torch.bmm(
        mul_attn.view(-1, batch, slen).transpose(0, 1),
        src_map.transpose(0, 1)).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)

    return torch.cat([out_prob, copy_prob], 1)
def forward(self, input): """ Return the embeddings for words, and features if there are any. Args: input (LongTensor): len x batch x nfeat Return: emb (FloatTensor): len x batch x self.embedding_size """ in_length, in_batch, nfeat = input.size() aeq(nfeat, len(self.emb_luts)) emb = self.make_embedding(input) out_length, out_batch, emb_size = emb.size() aeq(in_length, out_length) aeq(in_batch, out_batch) aeq(emb_size, self.embedding_size) return emb
def forward(self, input): """ Computes the embeddings for words and features. Args: input (`LongTensor`): index tensor `[len x batch x nfeat]` Return: `FloatTensor`: word embeddings `[len x batch x embedding_size]` """ in_length, in_batch, nfeat = input.size() aeq(nfeat, len(self.emb_luts)) emb = self.make_embedding(input) out_length, out_batch, emb_size = emb.size() aeq(in_length, out_length) aeq(in_batch, out_batch) aeq(emb_size, self.embedding_size) return emb
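# A hedged sketch of the index tensor the embedding forward above expects when
# one extra word feature (for example a POS tag) is present. The vocabulary
# sizes and feature are invented for illustration only.
import torch

seq_len, batch, nfeat = 5, 3, 2                           # word index + one feature column
word_ids = (torch.rand(seq_len, batch, 1) * 100).long()   # toy word vocab of 100
feat_ids = (torch.rand(seq_len, batch, 1) * 10).long()    # toy feature vocab of 10
inp = torch.cat([word_ids, feat_ids], dim=2)              # [len x batch x nfeat]
print(inp.size())  # torch.Size([5, 3, 2]); this is what forward(input) consumes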
def forward(self, base_target_emb, input, encoder_out_top,
            encoder_out_combine):
    """
    Args:
        base_target_emb: target emb tensor
        input: output of decode conv
        encoder_out_top: the key matrix for calculation of attention
                         weight, which is the top output of encode conv
        encoder_out_combine: the value matrix for the attention-weighted
                             sum, which is the combination of base emb
                             and top output of encode
    """
    # checks
    batch, channel, height, width = base_target_emb.size()
    batch_, channel_, height_, width_ = input.size()
    aeq(batch, batch_)
    aeq(height, height_)

    enc_batch, enc_channel, enc_height = encoder_out_top.size()
    enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size()
    aeq(enc_batch, enc_batch_)
    aeq(enc_height, enc_height_)

    preatt = seq_linear(self.linear_in, input)
    target = (base_target_emb + preatt) * SCALE_WEIGHT
    target = torch.squeeze(target, 3)
    target = torch.transpose(target, 1, 2)
    pre_attn = torch.bmm(target, encoder_out_top)

    if self.mask is not None:
        pre_attn.data.masked_fill_(self.mask, -float('inf'))

    pre_attn = pre_attn.transpose(0, 2)
    attn = F.softmax(pre_attn)
    attn = attn.transpose(0, 2).contiguous()

    context_output = torch.bmm(
        attn, torch.transpose(encoder_out_combine, 1, 2))
    context_output = torch.transpose(
        torch.unsqueeze(context_output, 3), 1, 2)
    return context_output, attn
def score(self, h_t, h_s): """ Args: h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]` h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]` Returns: :obj:`FloatTensor`: raw attention scores (unnormalized) for each src index `[batch x tgt_len x src_len]` """ # Check input sizes src_batch, src_len, src_dim = h_s.size() tgt_batch, tgt_len, tgt_dim = h_t.size() aeq(src_batch, tgt_batch) aeq(src_dim, tgt_dim) aeq(self.dim, src_dim) if self.attn_type in ["general", "dot"]: if self.attn_type == "general": h_t_ = h_t.view(tgt_batch*tgt_len, tgt_dim) h_t_ = self.linear_in(h_t_) h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim) h_s_ = h_s.transpose(1, 2) # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) return torch.bmm(h_t, h_s_) else: dim = self.dim wq = self.linear_query(h_t.view(-1, dim)) wq = wq.view(tgt_batch, tgt_len, 1, dim) wq = wq.expand(tgt_batch, tgt_len, src_len, dim) uh = self.linear_context(h_s.contiguous().view(-1, dim)) uh = uh.view(src_batch, 1, src_len, dim) uh = uh.expand(src_batch, tgt_len, src_len, dim) # (batch, t_len, s_len, d) wquh = self.tanh(wq + uh) return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
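# The "dot" branch of score() above reduces to a single batched matrix product.
# A minimal self-contained sketch of that path with toy tensors (shapes only,
# no learned weights; sizes invented for illustration):
import torch

batch, tgt_len, src_len, dim = 2, 3, 4, 8
h_t = torch.randn(batch, tgt_len, dim)        # queries
h_s = torch.randn(batch, src_len, dim)        # sources
scores = torch.bmm(h_t, h_s.transpose(1, 2))  # [batch, tgt_len, src_len]
print(scores.size())  # torch.Size([2, 3, 4])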
def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None): """ See StdRNNDecoder._run_forward_pass() for description of arguments and return values. """ # Additional args check. input_feed = state.input_feed.squeeze(0) input_feed_batch, _ = input_feed.size() tgt_len, tgt_batch, _ = tgt.size() aeq(tgt_batch, input_feed_batch) # END Additional args check. # Initialize local and return variables. decoder_outputs = [] attns = {"std": []} if self._copy: attns["copy"] = [] if self._coverage: attns["coverage"] = [] emb = self.embeddings(tgt) assert emb.dim() == 3 # len x batch x embedding_dim hidden = state.hidden coverage = state.coverage.squeeze(0) \ if state.coverage is not None else None # Input feed concatenates hidden state with # input at every time step. for i, emb_t in enumerate(emb.split(1)): emb_t = emb_t.squeeze(0) decoder_input = torch.cat([emb_t, input_feed], 1) rnn_output, hidden = self.rnn(decoder_input, hidden) decoder_output, p_attn = self.attn( rnn_output, memory_bank.transpose(0, 1), memory_lengths=memory_lengths) if self.context_gate is not None: # TODO: context gate should be employed # instead of second RNN transform. decoder_output = self.context_gate( decoder_input, rnn_output, decoder_output ) decoder_output = self.dropout(decoder_output) input_feed = decoder_output decoder_outputs += [decoder_output] attns["std"] += [p_attn] # Update the coverage attention. if self._coverage: coverage = coverage + p_attn \ if coverage is not None else p_attn attns["coverage"] += [coverage] # Run the forward pass of the copy attention layer. if self._copy and not self._reuse_copy_attn: _, copy_attn = self.copy_attn(decoder_output, memory_bank.transpose(0, 1)) attns["copy"] += [copy_attn] elif self._copy: attns["copy"] = attns["std"] # Return result. return hidden, decoder_outputs, attns
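# The input-feed loop above walks the target embeddings one time step at a
# time via emb.split(1). A small sketch of what that produces (toy sizes,
# illustration only):
import torch

tgt_len, batch, emb_dim = 4, 2, 6
emb = torch.randn(tgt_len, batch, emb_dim)
steps = emb.split(1)  # tuple of tgt_len tensors, each [1, batch, emb_dim]
print(len(steps), steps[0].squeeze(0).size())  # 4 torch.Size([2, 6])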
def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
        input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
        memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
        memory_lengths (`LongTensor`): the source context lengths `[batch]`
        coverage (`FloatTensor`): None (not supported yet)
    Returns:
        (`FloatTensor`, `FloatTensor`):
        * Computed vector `[tgt_len x batch x dim]`
        * Attention distributions for each query
            `[tgt_len x batch x src_len]`
    """
    # one step input
    if input.dim() == 2:
        one_step = True
        input = input.unsqueeze(1)
    else:
        one_step = False

    batch, sourceL, dim = memory_bank.size()
    batch_, targetL, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = self.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(input, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.data.masked_fill_(1 - mask, -float('inf'))

    # Softmax to normalize attention weights
    align_vectors = self.sm(align.view(batch*targetL, sourceL))
    align_vectors = align_vectors.view(batch, targetL, sourceL)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2)
    attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = self.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, sourceL_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()

        # Check output sizes
        targetL_, batch_, dim_ = attn_h.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        targetL_, batch_, sourceL_ = align_vectors.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    return attn_h, align_vectors
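# forward() above masks padded source positions with sequence_mask(memory_lengths).
# A minimal sketch of such a helper, assuming it returns a [batch, max_len] mask
# that is 1/True inside each sequence and 0/False on padding (hypothetical
# re-implementation, not necessarily the library's exact code):
import torch

def sequence_mask(lengths, max_len=None):
    """True for positions < length, False for padding."""
    batch_size = lengths.numel()
    max_len = max_len or int(lengths.max())
    positions = torch.arange(0, max_len).type_as(lengths)
    return positions.repeat(batch_size, 1).lt(lengths.unsqueeze(1))

print(sequence_mask(torch.LongTensor([3, 1, 2])))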
def _check_args(self, input, lengths=None, hidden=None): s_len, n_batch, n_feats = input.size() if lengths is not None: n_batch_, = lengths.size() aeq(n_batch, n_batch_)
def forward(self, key, value, query, mask=None):
    """
    Compute the context vector and the attention vectors.
    Args:
        key (`FloatTensor`): set of `key_len`
             key vectors `[batch, key_len, dim]`
        value (`FloatTensor`): set of `key_len`
             value vectors `[batch, key_len, dim]`
        query (`FloatTensor`): set of `query_len`
              query vectors `[batch, query_len, dim]`
        mask: binary mask indicating which keys have
              non-zero attention `[batch, query_len, key_len]`
    Returns:
        (`FloatTensor`, `FloatTensor`):
        * output context vectors `[batch, query_len, dim]`
        * one of the attention vectors `[batch, query_len, key_len]`
    """
    # CHECKS
    batch, k_len, d = key.size()
    batch_, k_len_, d_ = value.size()
    aeq(batch, batch_)
    aeq(k_len, k_len_)
    aeq(d, d_)
    batch_, q_len, d_ = query.size()
    aeq(batch, batch_)
    aeq(d, d_)
    aeq(self.model_dim % 8, 0)
    if mask is not None:
        batch_, q_len_, k_len_ = mask.size()
        aeq(batch_, batch)
        aeq(k_len_, k_len)
        aeq(q_len_, q_len)
    # END CHECKS

    batch_size = key.size(0)
    dim_per_head = self.dim_per_head
    head_count = self.head_count
    key_len = key.size(1)
    query_len = query.size(1)

    def shape(x):
        return x.view(batch_size, -1, head_count, dim_per_head) \
            .transpose(1, 2)

    def unshape(x):
        return x.transpose(1, 2).contiguous() \
            .view(batch_size, -1, head_count * dim_per_head)

    # 1) Project key, value, and query.
    key_up = shape(self.linear_keys(key))
    value_up = shape(self.linear_values(value))
    query_up = shape(self.linear_query(query))

    # 2) Calculate and scale scores.
    query_up = query_up / math.sqrt(dim_per_head)
    scores = torch.matmul(query_up, key_up.transpose(2, 3))

    if mask is not None:
        mask = mask.unsqueeze(1).expand_as(scores)
        scores = scores.masked_fill(Variable(mask), -1e18)

    # 3) Apply attention dropout and compute context vectors.
    attn = self.sm(scores)
    drop_attn = self.dropout(attn)
    context = unshape(torch.matmul(drop_attn, value_up))

    output = self.final_linear(context)
    # CHECK
    batch_, q_len_, d_ = output.size()
    aeq(q_len, q_len_)
    aeq(batch, batch_)
    aeq(d, d_)

    # Return one attn
    top_attn = attn \
        .view(batch_size, head_count, query_len, key_len)[:, 0, :, :] \
        .contiguous()
    # END CHECK

    return output, top_attn
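# The shape()/unshape() closures above implement the usual multi-head split and
# merge. A self-contained sketch of the same reshaping and scaled score
# computation with toy sizes (illustration only, not the module itself):
import math
import torch

batch, seq_len, heads, dim_per_head = 2, 5, 4, 16
model_dim = heads * dim_per_head
x = torch.randn(batch, seq_len, model_dim)

# split into heads: [batch, heads, seq_len, dim_per_head]
x_heads = x.view(batch, seq_len, heads, dim_per_head).transpose(1, 2)
# scaled dot-product scores per head: [batch, heads, seq_len, seq_len]
scores = torch.matmul(x_heads / math.sqrt(dim_per_head),
                      x_heads.transpose(2, 3))
# merge heads back: [batch, seq_len, model_dim]
merged = x_heads.transpose(1, 2).contiguous().view(batch, seq_len, model_dim)
print(scores.size(), merged.size())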
def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, previous_input=None): # Args Checks input_batch, input_len, _ = inputs.size() if previous_input is not None: pi_batch, _, _ = previous_input.size() aeq(pi_batch, input_batch) contxt_batch, contxt_len, _ = memory_bank.size() aeq(input_batch, contxt_batch) src_batch, t_len, s_len = src_pad_mask.size() tgt_batch, t_len_, t_len__ = tgt_pad_mask.size() aeq(input_batch, contxt_batch, src_batch, tgt_batch) # aeq(t_len, t_len_, t_len__, input_len) aeq(s_len, contxt_len) # END Args Checks dec_mask = torch.gt(tgt_pad_mask + self.mask[:, :tgt_pad_mask.size(1), :tgt_pad_mask.size(1)], 0) input_norm = self.layer_norm_1(inputs) all_input = input_norm if previous_input is not None: all_input = torch.cat((previous_input, input_norm), dim=1) dec_mask = None query, attn = self.self_attn(all_input, all_input, input_norm, mask=dec_mask) query = self.drop(query) + inputs query_norm = self.layer_norm_2(query) mid, attn = self.context_attn(memory_bank, memory_bank, query_norm, mask=src_pad_mask) output = self.feed_forward(self.drop(mid) + query) # CHECKS output_batch, output_len, _ = output.size() aeq(input_len, output_len) aeq(contxt_batch, output_batch) n_batch_, t_len_, s_len_ = attn.size() aeq(input_batch, n_batch_) aeq(contxt_len, s_len_) aeq(input_len, t_len_) # END CHECKS return output, attn, all_input
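# self.mask above is presumably a precomputed "subsequent positions" mask that
# keeps a target position from attending to later positions; dec_mask combines
# it with the target padding mask. A hedged sketch of how such a mask can be
# built (hypothetical helper, not necessarily the library's exact code):
import numpy as np
import torch

def get_attn_subsequent_mask(size):
    """Upper-triangular mask: 1 where position j > i (a future position), else 0."""
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask)

print(get_attn_subsequent_mask(4)[0])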
def forward(self, tgt, memory_bank, state, memory_lengths=None): """ See :obj:`onmt.modules.RNNDecoderBase.forward()` """ # CHECKS assert isinstance(state, TransformerDecoderState) tgt_len, tgt_batch, _ = tgt.size() memory_len, memory_batch, _ = memory_bank.size() aeq(tgt_batch, memory_batch) src = state.src src_words = src[:, :, 0].transpose(0, 1) tgt_words = tgt[:, :, 0].transpose(0, 1) src_batch, src_len = src_words.size() tgt_batch, tgt_len = tgt_words.size() aeq(tgt_batch, memory_batch, src_batch, tgt_batch) aeq(memory_len, src_len) if state.previous_input is not None: tgt = torch.cat([state.previous_input, tgt], 0) # END CHECKS # Initialize return variables. outputs = [] attns = {"std": []} if self._copy: attns["copy"] = [] # Run the forward pass of the TransformerDecoder. emb = self.embeddings(tgt) if state.previous_input is not None: emb = emb[state.previous_input.size(0):, ] assert emb.dim() == 3 # len x batch x embedding_dim output = emb.transpose(0, 1).contiguous() src_memory_bank = memory_bank.transpose(0, 1).contiguous() padding_idx = self.embeddings.word_padding_idx src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \ .expand(src_batch, tgt_len, src_len) tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \ .expand(tgt_batch, tgt_len, tgt_len) saved_inputs = [] for i in range(self.num_layers): prev_layer_input = None if state.previous_input is not None: prev_layer_input = state.previous_layer_inputs[i] output, attn, all_input \ = self.transformer_layers[i](output, src_memory_bank, src_pad_mask, tgt_pad_mask, previous_input=prev_layer_input) saved_inputs.append(all_input) saved_inputs = torch.stack(saved_inputs) output = self.layer_norm(output) # Process the result and update the attentions. outputs = output.transpose(0, 1).contiguous() attn = attn.transpose(0, 1).contiguous() attns["std"] = attn if self._copy: attns["copy"] = attn # Update the state. state = state.update_state(tgt, saved_inputs) return outputs, state, attns