def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
    assert self.copy_attn is None  # TODO, no support yet.
    assert not self._coverage  # TODO, no support yet.

    attns = {}
    # Gather, for each batch entry, the encoder states indexed by the
    # target tokens (a pointer-style lookup instead of an embedding table).
    index_select = [
        torch.index_select(a, 0, i).unsqueeze(0)
        for a, i in zip(torch.transpose(memory_bank, 0, 1),
                        torch.t(torch.squeeze(tgt, 2)))
    ]
    emb = torch.transpose(torch.cat(index_select), 0, 1)

    if isinstance(self.rnn, nn.GRU):
        rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0])
    else:
        rnn_output, dec_state = self.rnn(emb, self.state["hidden"])

    # Check
    tgt_len, tgt_batch, _ = tgt.size()
    output_len, output_batch, _ = rnn_output.size()
    aeq(tgt_len, output_len)
    aeq(tgt_batch, output_batch)

    # Calculate the attention.
    p_attn = self.attn(rnn_output.transpose(0, 1).contiguous(),
                       memory_bank.transpose(0, 1),
                       memory_lengths=memory_lengths)
    attns["std"] = p_attn

    return dec_state, None, attns
def forward(self, src, img_feats, lengths=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)

    emb = self.embeddings(src)
    img_emb = self.img_to_emb(img_feats).unsqueeze(0)
    # prepend image "word"
    emb = torch.cat([img_emb, emb], dim=0)
    out = emb.transpose(0, 1).contiguous()
    words = src[:, :, 0].transpose(0, 1)
    # expand mask to account for the image "word"
    words = torch.cat([words[:, 0:1], words], dim=1)

    # CHECKS
    out_batch, out_len, _ = out.size()
    w_batch, w_len = words.size()
    aeq(out_batch, w_batch)
    aeq(out_len, w_len)
    # END CHECKS

    # Make mask.
    padding_idx = self.embeddings.word_padding_idx
    mask = words.data.eq(padding_idx).unsqueeze(1) \
        .expand(w_batch, w_len, w_len)

    # Run the forward pass of every layer of the transformer.
    for layer in self.transformer:
        out = layer(out, mask)
    out = self.layer_norm(out)

    return emb, out.transpose(0, 1).contiguous(), lengths
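# A short, self-contained sketch of the "image word" trick above: project the
# image features to the embedding size, prepend them as an extra time step,
# and widen the padding mask by repeating the first column (the image slot is
# never padding). The img_to_emb projection and all sizes here are
# illustrative stand-ins, not the module's real API.
import torch
import torch.nn as nn

src_len, batch, emb_dim, img_dim = 4, 2, 8, 16
emb = torch.randn(src_len, batch, emb_dim)          # token embeddings
img_feats = torch.randn(batch, img_dim)

img_to_emb = nn.Linear(img_dim, emb_dim)
img_emb = img_to_emb(img_feats).unsqueeze(0)        # (1, batch, emb_dim)
emb = torch.cat([img_emb, emb], dim=0)              # (src_len + 1, batch, emb_dim)

words = torch.randint(1, 10, (batch, src_len))      # token ids, batch-first
words = torch.cat([words[:, 0:1], words], dim=1)    # account for image slot
print(emb.size(0), words.size(1))                   # both src_len + 1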
def forward(self, source, memory_bank, memory_lengths=None,
            memory_turns=None, coverage=None):
    # Here we implement a hierarchical attention: word-level attention
    # within each turn, then turn-level attention over the turn contexts.
    # A compact standalone version is sketched below.
    if source.dim() == 2:
        source = source.unsqueeze(1)

    batch, source_tl, source_wl, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)

    # word level attention
    word_align = self.word_score(source, memory_bank.contiguous()
                                 .view(batch, -1, dim))
    # transform align (b, 1, tl * wl) -> (b * tl, 1, wl)
    word_align = word_align.view(batch * source_tl, 1, source_wl)
    if memory_lengths is not None:
        word_mask = sequence_mask_herd(memory_lengths.view(-1),
                                       max_len=word_align.size(-1))
        word_mask = word_mask.unsqueeze(1)  # Make it broadcastable.
        word_align.masked_fill_(1 - word_mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        word_align_vectors = F.softmax(
            word_align.view(batch * source_tl, source_wl), -1)
    else:
        word_align_vectors = sparsemax(
            word_align.view(batch * source_tl, source_wl), -1)

    # Zero out the attention of sentences that are entirely padding.
    sent_pad_mask = memory_lengths.view(-1).eq(0).unsqueeze(1)
    word_align_vectors = torch.mul(
        word_align_vectors,
        (1.0 - sent_pad_mask).type_as(word_align_vectors))
    word_align_vectors = word_align_vectors.view(
        batch * source_tl, target_l, source_wl)

    # Each word-level context vector is the weighted average
    # over the hidden states of one turn.
    cw = torch.bmm(word_align_vectors,
                   memory_bank.view(batch * source_tl, source_wl, -1))
    cw = cw.view(batch, source_tl, -1)

    # turn level attention
    turn_align = self.turn_score(source, cw)
    if memory_turns is not None:
        turn_mask = sequence_mask(memory_turns, max_len=turn_align.size(-1))
        turn_mask = turn_mask.unsqueeze(1)  # Make it broadcastable.
        turn_align.masked_fill_(1 - turn_mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        turn_align_vectors = F.softmax(
            turn_align.view(batch * target_l, source_tl), -1)
    else:
        turn_align_vectors = sparsemax(
            turn_align.view(batch * target_l, source_tl), -1)
    turn_align_vectors = turn_align_vectors.view(batch, target_l, source_tl)

    # Each turn-level context vector is the weighted average
    # over the word-level context vectors.
    ct = torch.bmm(turn_align_vectors, cw)

    return ct.squeeze(1), None
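# A minimal, self-contained sketch of the two-level (word -> turn) attention
# above, assuming softmax normalization and no padding masks. The names
# (b, tl, wl, d) and the dot-product scorer are illustrative, not the
# module's real word_score/turn_score API.
import torch
import torch.nn.functional as F

b, tl, wl, d = 2, 3, 4, 8                       # batch, turns, words/turn, dim
query = torch.randn(b, 1, d)                    # one decoder query per example
bank = torch.randn(b, tl, wl, d)                # hierarchical memory bank

# Word level: score the query against every word of every turn at once.
word_scores = torch.bmm(query, bank.view(b, tl * wl, d).transpose(1, 2))
word_alpha = F.softmax(word_scores.view(b * tl, 1, wl), dim=-1)
cw = torch.bmm(word_alpha, bank.view(b * tl, wl, d)).view(b, tl, d)

# Turn level: score the query against the per-turn context vectors.
turn_alpha = F.softmax(torch.bmm(query, cw.transpose(1, 2)), dim=-1)
ct = torch.bmm(turn_alpha, cw)                  # (b, 1, d)
print(ct.squeeze(1).shape)                      # torch.Size([2, 8])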
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    # The answer vector is already an averaged representation, so the
    # query can be scored against the memory bank directly.
    if source.dim() == 2:
        source = source.unsqueeze(1)

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    return c.squeeze(1), align_vectors
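# A minimal sketch of the masking pattern used throughout these attention
# modules: build a boolean length mask, fill padded positions with -inf
# before softmax, then take the weighted average. `seq_mask` is a local
# stand-in for the library's sequence_mask helper; all sizes are toy values.
import torch
import torch.nn.functional as F

def seq_mask(lengths, max_len):
    # (batch, max_len) boolean mask, True at valid positions.
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

batch, src_len, dim = 2, 5, 4
memory_bank = torch.randn(batch, src_len, dim)
query = torch.randn(batch, 1, dim)
lengths = torch.tensor([5, 3])

align = torch.bmm(query, memory_bank.transpose(1, 2))   # (batch, 1, src_len)
mask = seq_mask(lengths, src_len).unsqueeze(1)          # broadcastable
align = align.masked_fill(~mask, -float('inf'))
align_vectors = F.softmax(align, dim=-1)                # zeros at padding
c = torch.bmm(align_vectors, memory_bank)               # (batch, 1, dim)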
def forward(self, src, lengths=None):
    """See :obj:`EncoderBase.forward()`"""
    self._check_args(src, lengths)

    emb = self.embeddings(src)
    out = emb.transpose(0, 1).contiguous()
    words = src[:, :, 0].transpose(0, 1)

    # CHECKS
    out_batch, out_len, _ = out.size()
    w_batch, w_len = words.size()
    aeq(out_batch, w_batch)
    aeq(out_len, w_len)
    # END CHECKS

    # Make mask.
    padding_idx = self.embeddings.word_padding_idx
    mask = words.data.eq(padding_idx).unsqueeze(1) \
        .expand(w_batch, w_len, w_len)

    # Run the forward pass of every layer of the transformer.
    for i in range(self.num_layers):
        out = self.transformer[i](out, mask)
    out = self.layer_norm(out)

    return Variable(emb.data), out.transpose(0, 1).contiguous()
def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None): """ See StdRNNDecoder._run_forward_pass() for description of arguments and return values. """ # Additional args check. input_feed = self.state["input_feed"].squeeze(0) input_feed_batch, _ = input_feed.size() _, tgt_batch, _ = tgt.size() aeq(tgt_batch, input_feed_batch) # END Additional args check. dec_outs = [] attns = {"std": []} if self.copy_attn is not None or self._reuse_copy_attn: attns["copy"] = [] if self._coverage: attns["coverage"] = [] emb = self.embeddings(tgt) assert emb.dim() == 3 # len x batch x embedding_dim dec_state = self.state["hidden"] coverage = self.state["coverage"].squeeze(0) \ if self.state["coverage"] is not None else None # Input feed concatenates hidden state with # input at every time step. for emb_t in emb.split(1): decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1) rnn_output, dec_state = self.rnn(decoder_input, dec_state) decoder_output, p_attn = self.attn( rnn_output, memory_bank.transpose(0, 1), memory_lengths=memory_lengths) if self.context_gate is not None: # TODO: context gate should be employed # instead of second RNN transform. decoder_output = self.context_gate( decoder_input, rnn_output, decoder_output ) decoder_output = self.dropout(decoder_output) input_feed = decoder_output dec_outs += [decoder_output] attns["std"] += [p_attn] # Update the coverage attention. if self._coverage: coverage = p_attn if coverage is None else p_attn + coverage attns["coverage"] += [coverage] if self.copy_attn is not None: _, copy_attn = self.copy_attn( decoder_output, memory_bank.transpose(0, 1)) attns["copy"] += [copy_attn] elif self._reuse_copy_attn: attns["copy"] = attns["std"] return dec_state, dec_outs, attns
def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None): """ Private helper for running the specific RNN forward pass. Must be overriden by all subclasses. Args: tgt (LongTensor): a sequence of input tokens tensors ``(len, batch, nfeats)``. memory_bank (FloatTensor): output(tensor sequence) from the encoder RNN of size ``(src_len, batch, hidden_size)``. memory_lengths (LongTensor): the source memory_bank lengths. Returns: (Tensor, List[FloatTensor], Dict[str, List[FloatTensor]): * dec_state: final hidden state from the decoder. * dec_outs: an array of output of every time step from the decoder. * attns: a dictionary of different type of attention Tensor array of every time step from the decoder. """ assert self.copy_attn is None # TODO, no support yet. assert not self._coverage # TODO, no support yet. attns = {} emb = self.embeddings(tgt) if isinstance(self.rnn, nn.GRU): rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0]) else: rnn_output, dec_state = self.rnn(emb, self.state["hidden"]) # Check tgt_len, tgt_batch, _ = tgt.size() output_len, output_batch, _ = rnn_output.size() aeq(tgt_len, output_len) aeq(tgt_batch, output_batch) # Calculate the attention. if not self.attentional: dec_outs = rnn_output else: dec_outs, p_attn = self.attn(rnn_output.transpose(0, 1).contiguous(), memory_bank.transpose(0, 1), memory_lengths=memory_lengths) attns["std"] = p_attn # Calculate the context gate. if self.context_gate is not None: dec_outs = self.context_gate( emb.view(-1, emb.size(2)), rnn_output.view(-1, rnn_output.size(2)), dec_outs.view(-1, dec_outs.size(2))) dec_outs = dec_outs.view(tgt_len, tgt_batch, self.hidden_size) dec_outs = self.dropout(dec_outs) return dec_state, dec_outs, attns
def _compute_orthogonal_loss(self, sep_states):
    """
    The orthogonal loss computation function.

    sep_states: a tuple (stacked_sep_states, sep_states_lens)
    :return: a scalar, the orthogonal loss
    """
    # stacked_sep_states: [b_size, max_sep_num, src_h_size]
    stacked_sep_states, sep_states_lens = sep_states
    b_size, max_sep_num, src_h_size = stacked_sep_states.size()
    b_size_ = len(sep_states_lens)
    aeq(b_size, b_size_)
    device = stacked_sep_states.device

    # obtain the mask
    # [b_size, max_sep_num]
    mask = sequence_mask(torch.Tensor(sep_states_lens)).to(device)
    mask = mask.float()
    # [b_size, 1, max_sep_num]
    mask = mask.unsqueeze(1)
    # [b_size, max_sep_num, max_sep_num]
    mask_2d = torch.bmm(mask.transpose(1, 2), mask)

    # compute the loss
    # [b_size, max_sep_num, max_sep_num]
    identity = torch.eye(max_sep_num).unsqueeze(0).repeat(b_size, 1, 1).to(device)
    # [b_size, max_sep_num, max_sep_num]
    orthogonal_loss_ = torch.bmm(stacked_sep_states,
                                 stacked_sep_states.transpose(1, 2)) - identity
    orthogonal_loss_ = orthogonal_loss_ * mask_2d
    # [b_size]
    orthogonal_loss = torch.norm(orthogonal_loss_.view(b_size, -1), p=2, dim=1)

    return orthogonal_loss
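# A self-contained sketch of the orthogonal penalty computed above: drive the
# pairwise dot products of the state vectors toward the identity, ignoring
# padded rows via an outer-product mask. The row normalization here is an
# illustrative choice (it makes the diagonal exactly 1); only torch is assumed.
import torch
import torch.nn.functional as F

b_size, max_num, h_size = 2, 3, 4
states = F.normalize(torch.randn(b_size, max_num, h_size), dim=-1)
lens = torch.tensor([3, 2])

valid = (torch.arange(max_num).unsqueeze(0) < lens.unsqueeze(1)).float()
mask_2d = torch.bmm(valid.unsqueeze(2), valid.unsqueeze(1))   # (b, n, n)

gram = torch.bmm(states, states.transpose(1, 2))              # pairwise dots
penalty = (gram - torch.eye(max_num)) * mask_2d               # off-diagonals only
loss = torch.norm(penalty.view(b_size, -1), p=2, dim=1)       # (b_size,)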
def forward(self, src1, src2):
    """See :obj:`EncoderBase.forward()`"""
    # src: (seq_len, bsz, 1)
    emb1 = self.embeddings(src1)
    emb2 = self.embeddings(src2)
    # emb: (seq_len, bsz, dim)
    emb2_biased = emb2 + self.emb_bias
    emb = torch.cat([emb1, emb2_biased], dim=0)
    out = emb.transpose(0, 1).contiguous()
    src = torch.cat([src1, src2], dim=0)
    words = src[:, :, 0].transpose(0, 1)

    # CHECKS
    out_batch, out_len, _ = out.size()
    w_batch, w_len = words.size()
    aeq(out_batch, w_batch)
    aeq(out_len, w_len)
    # END CHECKS

    # Make mask.
    padding_idx = self.embeddings.word_padding_idx
    mask = words.data.eq(padding_idx).unsqueeze(1) \
        .expand(w_batch, w_len, w_len)

    # Run the forward pass of every layer of the transformer.
    for i in range(self.num_layers):
        out = self.transformer[i](out, mask)
    out = self.layer_norm(out)

    return Variable(emb.data), out.transpose(0, 1).contiguous()
def _compute_orthogonal_loss(self, batch, orthog_states):
    """
    The orthogonal loss computation function.

    :param batch: the current batch
    :param orthog_states: the orthog_states from the sent level decoder
    :return: a scalar, the orthogonal loss
    """
    # [b_size, s_num, tgt_s_len-1]
    valid_tgt = batch.tgt[:, :, 1:]
    b_size, s_num, _ = valid_tgt.size()
    b_size1, s_num1, _ = orthog_states.size()
    aeq(b_size, b_size1)
    aeq(s_num, s_num1)

    # obtain the mask
    # [b_size, s_num]
    mask = valid_tgt.ne(self.padding_idx).sum(dim=-1).ne(0)
    mask = mask.float()
    # [b_size, 1, s_num]
    mask = mask.unsqueeze(1)
    # [b_size, s_num, s_num]
    mask_2d = torch.bmm(mask.transpose(1, 2), mask)

    # compute the loss
    # [b_size, s_num, s_num]
    identity = torch.eye(s_num).unsqueeze(0).repeat(b_size, 1, 1).to(orthog_states.device)
    # [b_size, s_num, s_num]
    orthogonal_loss_ = torch.bmm(orthog_states,
                                 orthog_states.transpose(1, 2)) - identity
    orthogonal_loss_ = orthogonal_loss_ * mask_2d
    # [b_size]
    orthogonal_loss = torch.norm(orthogonal_loss_.view(b_size, -1), p=2, dim=1)

    return orthogonal_loss
def _check_args(self, src, lengths=None, hidden=None):
    if isinstance(src, dict):
        src = src['src']
    n_batch = src.size(1)
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
def forward(self, query, memory_bank, memory_lengths=None, **kwargs):
    """
    query (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
    memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
    memory_lengths (`LongTensor`): the source context lengths `[batch]`

    Returns the attention distribution `(tgt_len x batch x src_len)`.
    """
    src_batch, src_len, src_dim = memory_bank.size()
    tgt_batch, tgt_len, tgt_dim = query.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)

    align = self.score(query, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        align.masked_fill_(~mask.unsqueeze(1), -float('inf'))

    # It should not be necessary to view align as a 2d tensor, but
    # something is broken with sparsemax and it cannot handle a 3d tensor.
    return self.transform(align.view(-1, src_len)).view_as(align)
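# A tiny sketch of the flatten-then-restore workaround used above for
# normalizers that only accept 2-D input: collapse (batch, tgt_len) into one
# leading dimension, normalize each row, and view the result back. softmax
# stands in here for whatever row-wise transform self.transform applies.
import torch
import torch.nn.functional as F

batch, tgt_len, src_len = 2, 3, 5
align = torch.randn(batch, tgt_len, src_len)

flat = align.view(-1, src_len)          # (batch * tgt_len, src_len)
normalized = F.softmax(flat, dim=-1)    # any row-wise normalizer works here
restored = normalized.view_as(align)    # back to (batch, tgt_len, src_len)
assert torch.allclose(restored.sum(-1), torch.ones(batch, tgt_len))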
def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None): """ Private helper for running the specific RNN forward pass. Must be overriden by all subclasses. Args: tgt (LongTensor): a sequence of input tokens tensors [len x batch x nfeats]. memory_bank (FloatTensor): output(tensor sequence) from the encoder RNN of size (src_len x batch x hidden_size). state (FloatTensor): hidden state from the encoder RNN for initializing the decoder. memory_lengths (LongTensor): the source memory_bank lengths. Returns: dec_state (Tensor): final hidden state from the decoder. dec_outs ([FloatTensor]): an array of output of every time step from the decoder. attns (dict of (str, [FloatTensor]): a dictionary of different type of attention Tensor array of every time step from the decoder. """ assert not self._copy # TODO, no support yet. assert not self._coverage # TODO, no support yet. # Initialize local and return variables. attns = {} emb = self.embeddings(tgt) # Run the forward pass of the RNN. if isinstance(self.rnn, nn.GRU): rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0]) else: rnn_output, dec_state = self.rnn(emb, self.state["hidden"]) # Check tgt_len, tgt_batch, _ = tgt.size() output_len, output_batch, _ = rnn_output.size() aeq(tgt_len, output_len) aeq(tgt_batch, output_batch) # END # Calculate the attention. dec_outs, p_attn = self.attn( rnn_output.transpose(0, 1).contiguous(), memory_bank.transpose(0, 1), memory_lengths=memory_lengths ) attns["std"] = p_attn # Calculate the context gate. if self.context_gate is not None: dec_outs = self.context_gate( emb.view(-1, emb.size(2)), rnn_output.view(-1, rnn_output.size(2)), dec_outs.view(-1, dec_outs.size(2)) ) dec_outs = \ dec_outs.view(tgt_len, tgt_batch, self.hidden_size) dec_outs = self.dropout(dec_outs) return dec_state, dec_outs, attns
def _example_dict_iter(self, line, index):
    line = line.split()
    if self.line_truncate:
        line = line[:self.line_truncate]

    if self.side == 'tgt':
        words, feats, n_feats = TextDataset.extract_text_features(line)
        example_dict = {self.side: words, "indices": index}
    else:
        feats = None
        n_feats = 0
        graph = AMR.extract_amr_features(line, reentrancies)
        words = graph.traverse
        example_dict = {
            self.side: words,
            self.side + "_graph": graph,
            "indices": index
        }
    if feats:
        # All examples must have same number of features.
        aeq(self.n_feats, n_feats)
        prefix = self.side + "_feat_"
        example_dict.update(
            (prefix + str(j), f) for j, f in enumerate(feats))

    return example_dict
def score(self, h_t, h_s, type):
    """
    Args:
        h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
        h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`
        type: which projection to use ('qa_word', 'qa_sent', or 'pass')

    Returns:
        :obj:`FloatTensor`: raw attention scores (unnormalized)
        for each src index `[batch x tgt_len x src_len]`
    """
    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)

    h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
    if type == 'qa_word':
        h_t_ = self.qa_word_linear_in(h_t_)
    elif type == 'qa_sent':
        h_t_ = self.qa_sent_linear_in(h_t_)
    elif type == 'pass':
        h_t_ = self.pass_linear_in(h_t_)
    h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
    h_s_ = h_s.transpose(1, 2)
    # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
    return torch.bmm(h_t, h_s_)
def _check_args(self, src, lengths=None, hidden=None):
    if isinstance(src, tuple):
        src = src[0]
    _, n_batch, _ = src.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
def _run_forward_pass(self, tgt, memory_bank, state, answer,
                      memory_lengths=None):
    """
    See StdRNNDecoder._run_forward_pass() for description
    of arguments and return values.
    """
    # Additional args check.
    input_feed = state.input_feed.squeeze(0)
    input_feed_batch, _ = input_feed.size()
    _, tgt_batch = tgt.size()
    aeq(tgt_batch, input_feed_batch)
    # END Additional args check.

    # Initialize local and return variables.
    decoder_outputs = []
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []

    emb = self.embeddings(tgt.unsqueeze(-1))
    assert emb.dim() == 3  # len x batch x embedding_dim

    hidden = state.hidden
    coverage = None

    # Input feed concatenates hidden state with
    # input at every time step.
    for _, emb_t in enumerate(emb.split(1)):
        emb_t = emb_t.squeeze(0)
        decoder_input = torch.cat([emb_t, input_feed], 1)

        rnn_output, hidden = self.rnn(decoder_input, hidden)
        # construct query [h, ans] to interact with sources
        query = torch.cat([rnn_output, answer], 1)
        decoder_output, p_attn = self.attn(query,
                                           memory_bank.transpose(0, 1),
                                           memory_lengths=memory_lengths)

        decoder_output = self.dropout(decoder_output)
        input_feed = decoder_output

        decoder_outputs += [decoder_output]
        attns["std"] += [p_attn]

        # Run the forward pass of the copy attention layer.
        if self._copy and not self._reuse_copy_attn:
            _, copy_attn = self.copy_attn(decoder_output,
                                          memory_bank.transpose(0, 1))
            attns["copy"] += [copy_attn]
        elif self._copy:
            attns["copy"] = attns["std"]

    # Return result.
    return hidden, decoder_outputs, attns
def forward(self, tgt, memory_bank, state, memory_lengths=None,
            step=None, sent_encoder=None, src_sents=None, dec=None):
    """
    Args:
        tgt (`LongTensor`): sequences of padded tokens
            `[tgt_len x batch x nfeats]`.
        memory_bank (`FloatTensor`): vectors from the encoder
            `[src_len x batch x hidden]`.
        state (:obj:`onmt.models.DecoderState`): decoder state object
            to initialize the decoder
        memory_lengths (`LongTensor`): the padded source lengths `[batch]`.

    Returns:
        (`FloatTensor`, :obj:`onmt.Models.DecoderState`, `FloatTensor`):
        * decoder_outputs: output from the decoder (after attn)
          `[tgt_len x batch x hidden]`.
        * decoder_state: final hidden state from the decoder
        * attns: distribution over src at each tgt
          `[tgt_len x batch x src_len]`.
    """
    # Check
    assert isinstance(state, RNNDecoderState)
    # tgt.size() returns tgt length and batch
    _, tgt_batch, _ = tgt.size()
    _, memory_batch, _ = memory_bank.size()
    aeq(tgt_batch, memory_batch)
    # END

    # Run the forward pass of the RNN.
    decoder_final, decoder_outputs, attns = self._run_forward_pass(
        tgt, memory_bank, state, memory_lengths=memory_lengths,
        sent_encoder=sent_encoder, src_sents=src_sents, dec=dec)

    # Update the state with the result.
    final_output = decoder_outputs[-1]
    coverage = None
    if "coverage" in attns:
        coverage = attns["coverage"][-1].unsqueeze(0)
    state.update_state(decoder_final, final_output.unsqueeze(0), coverage)

    # Concatenates sequence of tensors along a new dimension.
    # NOTE: v0.3 to 0.4: decoder_outputs / attns[*] may not be list
    # (in particular in case of SRU); it was not raising an error in 0.3
    # since stack(Variable) was allowed.
    # In 0.4, SRU returns a tensor that shouldn't be stacked.
    if type(decoder_outputs) == list:
        decoder_outputs = torch.stack(decoder_outputs)
    for k in attns:
        if type(attns[k]) == list:
            attns[k] = torch.stack(attns[k])

    return decoder_outputs, state, attns
def forward(self, tgt, memory_bank, state, memory_lengths=None,
            wals_features=None, step=None):
    # Check
    assert isinstance(state, RNNDecoderStateDoublyAttentive)
    # tgt.size() returns tgt length and batch
    _, tgt_batch, _ = tgt.size()
    _, memory_batch, _ = memory_bank.size()
    _, wals_features_batch, _ = wals_features.size()
    aeq(tgt_batch, memory_batch)
    aeq(tgt_batch, wals_features_batch)
    # END

    # Run the forward pass of the RNN.
    decoder_final, decoder_outputs, decoder_outputs_wals, attns = \
        self._run_forward_pass(tgt, memory_bank, state,
                               wals_features=wals_features,
                               memory_lengths=memory_lengths)

    # Update the state with the result.
    final_output = decoder_outputs[-1]
    final_output_wals = decoder_outputs_wals[-1]

    coverage = None
    coverage_wals = None
    if "coverage" in attns:
        coverage = attns["coverage"][-1].unsqueeze(0)
    if "coverage_wals" in attns:
        coverage_wals = attns["coverage_wals"][-1].unsqueeze(0)

    state.update_state(decoder_final, final_output.unsqueeze(0),
                       final_output_wals.unsqueeze(0),
                       coverage, coverage_wals)

    # Concatenates sequence of tensors along a new dimension.
    # NOTE: v0.3 to 0.4: decoder_outputs / attns[*] may not be list
    # (in particular in case of SRU); it was not raising an error in 0.3
    # since stack(Variable) was allowed.
    # In 0.4, SRU returns a tensor that shouldn't be stacked.
    if type(decoder_outputs) == list:
        decoder_outputs = torch.stack(decoder_outputs)
    if type(decoder_outputs_wals) == list:
        decoder_outputs_wals = torch.stack(decoder_outputs_wals)
    for k in attns:
        if type(attns[k]) == list:
            attns[k] = torch.stack(attns[k])

    return decoder_outputs, decoder_outputs_wals, state, attns
def _check_args(self, src, lengths=None, hidden=None):
    n_batch = src.size(1)
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
        if src.size(0) != max(lengths):
            # src was truncated: shift all lengths down so the longest
            # one matches the actual number of time steps.
            lengths -= max(lengths) - src.size(0)
    return lengths
def _run_forward_pass(self, tgt, memory_bank, state, wals_features,
                      memory_lengths=None):
    assert not self._copy
    assert not self._coverage

    # Initialize local and return variables.
    attns = {}
    emb = self.embeddings(tgt)

    # Run the forward pass of the RNN.
    if isinstance(self.rnn, nn.GRU):
        rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
    else:
        rnn_output, decoder_final = self.rnn(emb, state.hidden)

    # Check
    tgt_len, tgt_batch, _ = tgt.size()
    output_len, output_batch, _ = rnn_output.size()
    aeq(tgt_len, output_len)
    aeq(tgt_batch, output_batch)
    # END

    # Calculate the attention.
    decoder_outputs, p_attn = self.attn(
        rnn_output.transpose(0, 1).contiguous(),
        memory_bank.transpose(0, 1),
        memory_lengths=memory_lengths)
    attns["std"] = p_attn

    decoder_outputs_wals, p_attn_wals = self.attn_wals(
        rnn_output.transpose(0, 1).contiguous(),
        wals_features.transpose(0, 1), None)
    attns["std_wals"] = p_attn_wals

    # Calculate the context gate.
    if self.context_gate is not None:
        decoder_outputs = self.context_gate(
            emb.view(-1, emb.size(2)),
            rnn_output.view(-1, rnn_output.size(2)),
            decoder_outputs.view(-1, decoder_outputs.size(2)))
        decoder_outputs = decoder_outputs.view(tgt_len, tgt_batch,
                                               self.hidden_size)
        decoder_outputs = self.dropout(decoder_outputs)
    else:
        decoder_outputs = self.dropout(decoder_outputs)

    # no context gate on WALS features
    decoder_outputs_wals = self.dropout(decoder_outputs_wals)

    # Return result.
    return decoder_final, decoder_outputs, decoder_outputs_wals, attns
def _run_forward_pass(self, tgt, word_memory_bank, sent_memory_bank,
                      sent_context, state, word_memory_lengths,
                      sent_memory_lengths, static_attn):
    """
    See StdRNNDecoder._run_forward_pass() for description
    of arguments and return values.
    """
    # Additional args check.
    input_feed = state.input_feed.squeeze(0)
    input_feed_batch, _ = input_feed.size()
    _, tgt_batch, _ = tgt.size()
    aeq(tgt_batch, input_feed_batch)
    # END Additional args check.

    # Initialize local and return variables.
    decoder_outputs = []
    attns = {"std": []}

    emb = self.embeddings(tgt)
    assert emb.dim() == 3  # len x batch x embedding_dim

    hidden = state.hidden

    # Input feed concatenates hidden state with
    # input at every time step.
    for emb_t in emb.split(1):
        emb_t = emb_t.squeeze(0)
        decoder_input = torch.cat([emb_t, input_feed], 1)

        rnn_output, hidden = self.rnn(decoder_input, hidden)
        # attn
        decoder_output, attn = self.attn(
            rnn_output, word_memory_bank, word_memory_lengths,
            sent_memory_bank, sent_memory_lengths, sent_context,
            static_attn)

        decoder_output = self.dropout(decoder_output)
        input_feed = decoder_output

        decoder_outputs += [decoder_output]
        attns["std"] += [attn]

    # Return result.
    return hidden, decoder_outputs, attns
def forward(self, hidden, attn, src_map, align=None, ptrs=None, tags=None):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.

    Args:
        hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
        attn (FloatTensor): attn for each ``(batch x tlen, src_len)``
        src_map (FloatTensor):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocab.
            ``(src_len, batch, extra_words)``
    """
    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, self.pad_idx] = -float('inf')
    prob = torch.softmax(logits, 1)

    # Probability of copying p(z=1) batch.
    p_copy = torch.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    if self.training and ptrs is not None:
        # Supervise the copy decision with the alignment: generate at
        # unaligned positions, copy (restricted to the pointed-to source
        # tokens) at aligned ones.
        align_unk = align.eq(0).float().view(-1, 1)
        align_not_unk = align.ne(0).float().view(-1, 1)
        out_prob = torch.mul(prob, align_unk)
        mul_attn = torch.mul(attn, align_not_unk)
        mul_attn = torch.mul(mul_attn, ptrs.view(-1, slen_).float())
    else:
        out_prob = torch.mul(prob, 1 - p_copy)
        # Mask disallowed copies.
        if tags is not None:
            mul_attn = torch.mul(attn, tags.t()) * 2
        else:
            mul_attn = attn
        mul_attn = torch.mul(mul_attn, p_copy)
    copy_prob = torch.bmm(
        mul_attn.view(-1, batch, slen).transpose(0, 1),
        src_map.transpose(0, 1)
    ).transpose(0, 1)
    # The copy probabilities capture the importance of each source word
    # under the (possibly supervised) copy decision.
    copy_prob = copy_prob.contiguous().view(-1, cvocab)
    return torch.cat([out_prob, copy_prob], 1), p_copy
def _run_forward_pass(self, tgt, word_memory_bank, sent_memory_bank, state,
                      word_memory_lengths, sent_memory_lengths, C):
    """
    See StdRNNDecoder._run_forward_pass() for description
    of arguments and return values.
    """
    # Additional args check.
    input_feed = state.input_feed.squeeze(0)
    input_feed_batch, _ = input_feed.size()
    _, tgt_batch, _ = tgt.size()
    aeq(tgt_batch, input_feed_batch)
    # END Additional args check.

    # Gate the static context C with the current input feed and pool it
    # into a single review vector r_t.
    C_final = self.V(
        torch.tanh(C + input_feed.unsqueeze(1).expand(-1, C.size(1), -1))
    ).expand(-1, -1, C.size(2)) * C
    r_t = torch.sum(C_final, 1)

    # Initialize local and return variables.
    decoder_outputs = []
    attns = {"std": []}

    emb = self.embeddings(tgt)
    assert emb.dim() == 3  # len x batch x embedding_dim

    hidden = state.hidden

    # Input feed concatenates hidden state with
    # input at every time step.
    for emb_t in emb.split(1):
        emb_t = emb_t.squeeze(0)
        decoder_input = torch.cat([emb_t, input_feed, r_t], 1)

        rnn_output, hidden = self.rnn(decoder_input, hidden)
        # attn
        decoder_output, attn = self.attn(rnn_output,
                                         word_memory_bank,
                                         word_memory_lengths,
                                         sent_memory_bank,
                                         sent_memory_lengths)

        decoder_output = self.dropout(decoder_output)
        input_feed = decoder_output
        C_final = self.V(
            torch.tanh(C + input_feed.unsqueeze(1).expand(
                -1, C.size(1), -1))).expand(-1, -1, C.size(2)) * C
        r_t = torch.sum(C_final, 1)

        decoder_outputs += [decoder_output]
        attns["std"] += [attn]

    # Return result.
    return hidden, decoder_outputs, attns
def _example_dict_iter(self, line, index):
    sessions = line.strip('\n').split('||')
    for s in sessions:
        assert len(s.split('\t')) == 11
    session_id = [s.split('\t')[0] for s in sessions]
    item_sku_id = [s.split('\t')[1] for s in sessions]
    user_log = [s.split('\t')[2] for s in sessions]
    operator = [s.split('\t')[3] for s in sessions]
    user_site_cy = [s.split('\t')[4] for s in sessions]
    user_site_pro = [s.split('\t')[5] for s in sessions]
    user_site_ct = [s.split('\t')[6] for s in sessions]
    stm = [int(s.split('\t')[7]) for s in sessions]
    page_ts = [int(s.split('\t')[8]) for s in sessions]
    item_name = [s.split('\t')[9].split() for s in sessions]
    item_comment = [s.split('\t')[10].split() for s in sessions]

    line = []
    if self.line_truncate:
        for tmp_name, tmp_comment in zip(item_name, item_comment):
            line.extend(tmp_name[:self.line_truncate])
            line.extend(tmp_comment[:self.line_truncate])
    else:
        for tmp_name, tmp_comment in zip(item_name, item_comment):
            line.extend(tmp_name)
            line.extend(tmp_comment)

    words, feats, n_feats = TextDataset.extract_text_features(line)

    example_dict = {
        self.side: words,
        self.side + "_session_id": session_id,
        self.side + "_item_sku": item_sku_id,
        self.side + "_user_log": user_log,
        self.side + "_operator": operator,
        self.side + "_site_cy": user_site_cy,
        self.side + "_site_pro": user_site_pro,
        self.side + "_site_ct": user_site_ct,
        self.side + "_stm": stm,
        self.side + "_page_ts": page_ts,
        "indices": index
    }
    if feats:
        # All examples must have same number of features.
        aeq(self.n_feats, n_feats)
        prefix = self.side + "_feat_"
        example_dict.update(
            (prefix + str(j), f) for j, f in enumerate(feats))

    return example_dict
def _check_args(self, src, lengths=None, hidden=None):
    _, n_batch, _ = src.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
def forward(self, hidden, attn, src_map): """ Compute a distribution over the target dictionary extended by the dynamic dictionary implied by copying source words. Args: hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)`` attn (FloatTensor): attn for each ``(batch x tlen, input_size)`` src_map (FloatTensor): A sparse indicator matrix mapping each source word to its index in the "extended" vocab containing. ``(src_len, batch, extra_words)`` """ if self.conv_first: attn = torch.unsqueeze(attn, 1) original_seq_len = src_map.shape[0] if original_seq_len % 3 == 0: attn = self.conv_transpose(attn) elif original_seq_len % 3 == 1: attn = self.conv_transpose_pad1(attn) else: attn = self.conv_transpose_pad2(attn) attn = torch.squeeze(attn, 1) # CHECKS batch_by_tlen, _ = hidden.size() batch_by_tlen_, slen = attn.size() slen_, batch, cvocab = src_map.size() aeq(batch_by_tlen, batch_by_tlen_) aeq(slen, slen_) # Original probabilities. logits = self.linear(hidden) logits[:, self.pad_idx] = -float('inf') prob = torch.softmax(logits, 1) # Probability of copying p(z=1) batch. p_copy = torch.sigmoid(self.linear_copy(hidden)) # Probability of not copying: p_{word}(w) * (1 - p(z)) out_prob = torch.mul(prob, 1 - p_copy) mul_attn = torch.mul(attn, p_copy) copy_prob = torch.bmm( mul_attn.view(-1, batch, slen).transpose(0, 1), src_map.transpose(0, 1)).transpose(0, 1) copy_prob = copy_prob.contiguous().view(-1, cvocab) return torch.cat([out_prob, copy_prob], 1)
def _example_dict_iter(self, line, index):
    line = line.split()
    if self.line_truncate:
        line = line[:self.line_truncate]
    words, feats, n_feats = TextDataset.extract_text_features(line)
    example_dict = {self.side: words, "indices": index}
    if feats:
        # All examples must have same number of features.
        aeq(self.n_feats, n_feats)
        prefix = self.side + "_feat_"
        example_dict.update(
            (prefix + str(j), f) for j, f in enumerate(feats))

    return example_dict
def forward(self, hidden, attn, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.

    Args:
        hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
        attn (FloatTensor): attn for each ``(batch x tlen, src_len)``
        src_map (FloatTensor):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocab.
            ``(src_len, batch, extra_words)``
    """
    # CHECKS
    # hidden = (tgt_len * batch, hidden)
    # attn = (tgt_len * batch, src_len)
    # src_map = (src_len, batch, cvocab)
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    # logits = (tgt_len * batch, tvocab)
    logits = self.linear(hidden)
    logits[:, self.pad_idx] = -float('inf')
    # prob = (tgt_len * batch, tvocab)
    prob = torch.softmax(logits, 1)

    # Probability of copying p(z=1) batch.
    # p_copy = (tgt_len * batch, 1)
    p_copy = torch.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    # out_prob = (tgt_len * batch, tvocab)
    out_prob = torch.mul(prob, 1 - p_copy)
    # mul_attn = (tgt_len * batch, src_len)
    mul_attn = torch.mul(attn, p_copy)
    # copy_prob = (batch, tgt_len, src_len) x (batch, src_len, cvocab)
    #   --> (batch, tgt_len, cvocab) --> (tgt_len, batch, cvocab)
    copy_prob = torch.bmm(
        mul_attn.view(-1, batch, slen).transpose(0, 1),
        src_map.transpose(0, 1)).transpose(0, 1)
    # copy_prob --> (tgt_len * batch, cvocab)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)
    # --> (tgt_len * batch, tvocab + cvocab)
    return torch.cat([out_prob, copy_prob], 1)
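# A self-contained sketch of the generate/copy mixture above, with toy sizes
# and fresh nn.Linear layers standing in for self.linear and self.linear_copy.
# tvocab/cvocab and the one-hot src_map are illustrative assumptions.
import torch

tlen_by_batch, hid = 6, 8          # tgt_len (3) * batch (2)
slen, batch, tvocab, cvocab = 4, 2, 10, 3
hidden = torch.randn(tlen_by_batch, hid)
attn = torch.softmax(torch.randn(tlen_by_batch, slen), dim=1)
src_map = torch.zeros(slen, batch, cvocab)
src_map[:, :, 0] = 1.0             # map every source token to extended id 0

linear = torch.nn.Linear(hid, tvocab)
linear_copy = torch.nn.Linear(hid, 1)

prob = torch.softmax(linear(hidden), dim=1)         # p_word(w)
p_copy = torch.sigmoid(linear_copy(hidden))         # p(z = 1)
out_prob = prob * (1 - p_copy)                      # generation share
mul_attn = attn * p_copy                            # copy share per src token
copy_prob = torch.bmm(mul_attn.view(-1, batch, slen).transpose(0, 1),
                      src_map.transpose(0, 1)).transpose(0, 1)
mixed = torch.cat([out_prob, copy_prob.contiguous().view(-1, cvocab)], 1)
print(mixed.sum(1))   # each row sums to ~1 across tvocab + cvocab entries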
def forward(self, hidden, his_attn, cur_attn, his_mid, cur_mid, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.

    Args:
        hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
        his_mid (FloatTensor): history hidden outputs
            ``(batch x tlen, input_size)``
        cur_mid (FloatTensor): current hidden outputs
            ``(batch x tlen, input_size)``
        his_attn (FloatTensor): attn over the history for each
            ``(batch x tlen, src_len)``
        cur_attn (FloatTensor): attn over the current source for each
            ``(batch x tlen, src_len)``
        src_map (FloatTensor):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocab.
            ``(src_len, batch, extra_words)``
    """
    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = his_attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, self.pad_idx] = -float('inf')
    prob = torch.softmax(logits, 1)

    # Probability of lambda
    feature = self.hidden_dense(hidden) + self.his_dense(his_mid) \
        + self.cur_dense(cur_mid)
    lambda_gate = torch.sigmoid(feature)

    # Probability of copying p(z=1) batch.
    p_copy = torch.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = torch.mul(prob, 1 - p_copy)
    attn = lambda_gate * his_attn + (1 - lambda_gate) * cur_attn
    mul_attn = torch.mul(attn, p_copy)
    copy_prob = torch.bmm(
        mul_attn.view(-1, batch, slen).transpose(0, 1),
        src_map.transpose(0, 1)).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)
    return torch.cat([out_prob, copy_prob], 1), attn
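# A minimal sketch of the lambda gate above: a sigmoid gate computed from the
# decoder state blends the history and current attention distributions before
# the copy mixture. The nn.Linear layers here are stand-ins for
# self.hidden_dense / self.his_dense / self.cur_dense; sizes are toy values.
import torch

n, hid, slen = 6, 8, 5
hidden = torch.randn(n, hid)
his_mid, cur_mid = torch.randn(n, hid), torch.randn(n, hid)
his_attn = torch.softmax(torch.randn(n, slen), dim=1)
cur_attn = torch.softmax(torch.randn(n, slen), dim=1)

hidden_dense = torch.nn.Linear(hid, 1)
his_dense = torch.nn.Linear(hid, 1)
cur_dense = torch.nn.Linear(hid, 1)

lambda_gate = torch.sigmoid(
    hidden_dense(hidden) + his_dense(his_mid) + cur_dense(cur_mid))
attn = lambda_gate * his_attn + (1 - lambda_gate) * cur_attn
print(attn.sum(1))   # still a distribution: rows sum to 1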
def forward(ctx, input, target):
    """
    input (FloatTensor): ``(n, num_classes)``.
    target (LongTensor): ``(n,)``, the indices of the target classes
    """
    input_batch, classes = input.size()
    target_batch = target.size(0)
    aeq(input_batch, target_batch)

    z_k = input.gather(1, target.unsqueeze(1)).squeeze()
    tau_z, support_size = _threshold_and_support(input, dim=1)
    support = input > tau_z
    x = torch.where(
        support, input**2 - tau_z**2,
        torch.tensor(0.0, device=input.device)
    ).sum(dim=1)
    ctx.save_for_backward(input, target, tau_z)

    # clamping necessary because of numerical errors: loss should be lower
    # bounded by zero, but negative values near zero are possible without
    # the clamp
    return torch.clamp(x / 2 - z_k + 0.5, min=0.0)
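# A standalone sketch of the kind of computation _threshold_and_support
# performs for sparsemax (Martins & Astudillo, 2016): sort the scores, find
# the support size k where 1 + k * z_(k) > cumsum_k, and set
# tau = (cumsum_k - 1) / k. This local helper is an assumption about the real
# function's behavior, not its actual implementation; only torch is assumed.
import torch

def threshold_and_support(z, dim=1):
    z_sorted, _ = torch.sort(z, descending=True, dim=dim)
    cumsum = z_sorted.cumsum(dim) - 1
    k = torch.arange(1, z.size(dim) + 1, device=z.device).float()
    support = k * z_sorted > cumsum                 # bool, monotone in k
    support_size = support.sum(dim=dim).unsqueeze(dim)
    tau = cumsum.gather(dim, support_size - 1) / support_size.float()
    return tau, support_size

z = torch.tensor([[1.5, 0.3, 0.2], [0.1, 0.1, 0.1]])
tau, k = threshold_and_support(z)
p = torch.clamp(z - tau, min=0.0)                   # sparsemax output
print(p)   # rows sum to 1; the first row is sparse: [1.0, 0.0, 0.0]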
def forward(self, hidden, attn, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.

    Args:
        hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
        attn (FloatTensor): attn for each ``(batch x tlen, src_len)``
        src_map (FloatTensor):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocab.
            ``(src_len, batch, extra_words)``
    """
    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, self.pad_idx] = -float('inf')
    prob = torch.softmax(logits, 1)

    # Probability of copying p(z=1) batch.
    p_copy = torch.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = torch.mul(prob, 1 - p_copy)
    mul_attn = torch.mul(attn, p_copy)
    copy_prob = torch.bmm(
        mul_attn.view(-1, batch, slen).transpose(0, 1),
        src_map.transpose(0, 1)
    ).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)
    return torch.cat([out_prob, copy_prob], 1)
def forward(self, base_target_emb, input_from_dec, encoder_out_top,
            encoder_out_combine):
    """
    Args:
        base_target_emb: target emb tensor
        input_from_dec: output of the decoder conv
        encoder_out_top: the key matrix for the calculation of the
            attention weight, which is the top output of the encoder conv
        encoder_out_combine: the value matrix for the attention-weighted
            sum, which is the combination of the base emb and the top
            output of the encoder
    """
    # checks
    # batch, channel, height, width = base_target_emb.size()
    batch, _, height, _ = base_target_emb.size()
    # batch_, channel_, height_, width_ = input_from_dec.size()
    batch_, _, height_, _ = input_from_dec.size()
    aeq(batch, batch_)
    aeq(height, height_)

    # enc_batch, enc_channel, enc_height = encoder_out_top.size()
    enc_batch, _, enc_height = encoder_out_top.size()
    # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size()
    enc_batch_, _, enc_height_ = encoder_out_combine.size()
    aeq(enc_batch, enc_batch_)
    aeq(enc_height, enc_height_)

    preatt = seq_linear(self.linear_in, input_from_dec)
    target = (base_target_emb + preatt) * SCALE_WEIGHT
    target = torch.squeeze(target, 3)
    target = torch.transpose(target, 1, 2)
    pre_attn = torch.bmm(target, encoder_out_top)

    if self.mask is not None:
        pre_attn.data.masked_fill_(self.mask, -float('inf'))

    attn = F.softmax(pre_attn, dim=2)

    context_output = torch.bmm(
        attn, torch.transpose(encoder_out_combine, 1, 2))
    context_output = torch.transpose(
        torch.unsqueeze(context_output, 3), 1, 2)

    return context_output, attn
def score(self, h_t, h_s):
    """
    Args:
        h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)``
        h_s (FloatTensor): sequence of sources ``(batch, src_len, dim)``

    Returns:
        FloatTensor: raw attention scores (unnormalized) for each src index
            ``(batch, tgt_len, src_len)``
    """
    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.dim, src_dim)

    if self.attn_type in ["general", "dot"]:
        if self.attn_type == "general":
            h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
            h_t_ = self.linear_in(h_t_)
            h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
        h_s_ = h_s.transpose(1, 2)
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s_)
    else:
        dim = self.dim
        wq = self.linear_query(h_t.view(-1, dim))
        wq = wq.view(tgt_batch, tgt_len, 1, dim)
        wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

        uh = self.linear_context(h_s.contiguous().view(-1, dim))
        uh = uh.view(src_batch, 1, src_len, dim)
        uh = uh.expand(src_batch, tgt_len, src_len, dim)

        # (batch, t_len, s_len, d)
        wquh = torch.tanh(wq + uh)

        return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
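# A small sketch of the "dot" and "general" score paths above, with a plain
# nn.Linear standing in for self.linear_in; the MLP branch is omitted.
# Shapes follow the docstring; all sizes are toy values.
import torch
import torch.nn as nn

batch, tgt_len, src_len, dim = 2, 3, 5, 4
h_t = torch.randn(batch, tgt_len, dim)
h_s = torch.randn(batch, src_len, dim)

# dot: plain inner products between queries and sources
dot_scores = torch.bmm(h_t, h_s.transpose(1, 2))        # (batch, tgt, src)

# general: project the queries first, then take inner products
linear_in = nn.Linear(dim, dim, bias=False)
h_t_proj = linear_in(h_t.view(-1, dim)).view(batch, tgt_len, dim)
gen_scores = torch.bmm(h_t_proj, h_s.transpose(1, 2))   # (batch, tgt, src)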
def forward(self, source, memory_bank, memory_lengths=None, coverage=None): """ Args: source (FloatTensor): query vectors ``(batch, tgt_len, dim)`` memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)`` memory_lengths (LongTensor): the source context lengths ``(batch,)`` coverage (FloatTensor): None (not supported yet) Returns: (FloatTensor, FloatTensor): * Computed vector ``(tgt_len, batch, dim)`` * Attention distribtutions for each query ``(tgt_len, batch, src_len)`` """ # one step input if source.dim() == 2: one_step = True source = source.unsqueeze(1) else: one_step = False batch, source_l, dim = memory_bank.size() batch_, target_l, dim_ = source.size() aeq(batch, batch_) aeq(dim, dim_) aeq(self.dim, dim) if coverage is not None: batch_, source_l_ = coverage.size() aeq(batch, batch_) aeq(source_l, source_l_) if coverage is not None: cover = coverage.view(-1).unsqueeze(1) memory_bank += self.linear_cover(cover).view_as(memory_bank) memory_bank = torch.tanh(memory_bank) # compute attention scores, as in Luong et al. align = self.score(source, memory_bank) if memory_lengths is not None: mask = sequence_mask(memory_lengths, max_len=align.size(-1)) mask = mask.unsqueeze(1) # Make it broadcastable. align.masked_fill_(1 - mask, -float('inf')) # Softmax or sparsemax to normalize attention weights if self.attn_func == "softmax": align_vectors = F.softmax(align.view(batch*target_l, source_l), -1) else: align_vectors = sparsemax(align.view(batch*target_l, source_l), -1) align_vectors = align_vectors.view(batch, target_l, source_l) # each context vector c_t is the weighted average # over all the source hidden states c = torch.bmm(align_vectors, memory_bank) # concatenate concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2) attn_h = self.linear_out(concat_c).view(batch, target_l, dim) if self.attn_type in ["general", "dot"]: attn_h = torch.tanh(attn_h) if one_step: attn_h = attn_h.squeeze(1) align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch, batch_) aeq(dim, dim_) batch_, source_l_ = align_vectors.size() aeq(batch, batch_) aeq(source_l, source_l_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors = align_vectors.transpose(0, 1).contiguous() # Check output sizes target_l_, batch_, dim_ = attn_h.size() aeq(target_l, target_l_) aeq(batch, batch_) aeq(dim, dim_) target_l_, batch_, source_l_ = align_vectors.size() aeq(target_l, target_l_) aeq(batch, batch_) aeq(source_l, source_l_) return attn_h, align_vectors
def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None): """ See StdRNNDecoder._run_forward_pass() for description of arguments and return values. """ # Additional args check. input_feed = self.state["input_feed"].squeeze(0) input_feed_batch, _ = input_feed.size() _, tgt_batch, _ = tgt.size() aeq(tgt_batch, input_feed_batch) # END Additional args check. dec_outs = [] attns = {} if self.attn is not None: attns["std"] = [] if self.copy_attn is not None or self._reuse_copy_attn: attns["copy"] = [] if self._coverage: attns["coverage"] = [] emb = self.embeddings(tgt) assert emb.dim() == 3 # len x batch x embedding_dim dec_state = self.state["hidden"] coverage = self.state["coverage"].squeeze(0) \ if self.state["coverage"] is not None else None # Input feed concatenates hidden state with # input at every time step. for emb_t in emb.split(1): decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1) rnn_output, dec_state = self.rnn(decoder_input, dec_state) if self.attentional: decoder_output, p_attn = self.attn( rnn_output, memory_bank.transpose(0, 1), memory_lengths=memory_lengths) attns["std"].append(p_attn) else: decoder_output = rnn_output if self.context_gate is not None: # TODO: context gate should be employed # instead of second RNN transform. decoder_output = self.context_gate( decoder_input, rnn_output, decoder_output ) decoder_output = self.dropout(decoder_output) input_feed = decoder_output dec_outs += [decoder_output] # Update the coverage attention. if self._coverage: coverage = p_attn if coverage is None else p_attn + coverage attns["coverage"] += [coverage] if self.copy_attn is not None: _, copy_attn = self.copy_attn( decoder_output, memory_bank.transpose(0, 1)) attns["copy"] += [copy_attn] elif self._reuse_copy_attn: attns["copy"] = attns["std"] return dec_state, dec_outs, attns
def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None): """ Private helper for running the specific RNN forward pass. Must be overriden by all subclasses. Args: tgt (LongTensor): a sequence of input tokens tensors ``(len, batch, nfeats)``. memory_bank (FloatTensor): output(tensor sequence) from the encoder RNN of size ``(src_len, batch, hidden_size)``. memory_lengths (LongTensor): the source memory_bank lengths. Returns: (Tensor, List[FloatTensor], Dict[str, List[FloatTensor]): * dec_state: final hidden state from the decoder. * dec_outs: an array of output of every time step from the decoder. * attns: a dictionary of different type of attention Tensor array of every time step from the decoder. """ assert self.copy_attn is None # TODO, no support yet. assert not self._coverage # TODO, no support yet. attns = {} emb = self.embeddings(tgt) if isinstance(self.rnn, nn.GRU): rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0]) else: rnn_output, dec_state = self.rnn(emb, self.state["hidden"]) # Check tgt_len, tgt_batch, _ = tgt.size() output_len, output_batch, _ = rnn_output.size() aeq(tgt_len, output_len) aeq(tgt_batch, output_batch) # Calculate the attention. if not self.attentional: dec_outs = rnn_output else: dec_outs, p_attn = self.attn( rnn_output.transpose(0, 1).contiguous(), memory_bank.transpose(0, 1), memory_lengths=memory_lengths ) attns["std"] = p_attn # Calculate the context gate. if self.context_gate is not None: dec_outs = self.context_gate( emb.view(-1, emb.size(2)), rnn_output.view(-1, rnn_output.size(2)), dec_outs.view(-1, dec_outs.size(2)) ) dec_outs = dec_outs.view(tgt_len, tgt_batch, self.hidden_size) dec_outs = self.dropout(dec_outs) return dec_state, dec_outs, attns