def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
    """Run the decoder RNN over the full target sequence in a single call.

    Args:
        tgt (LongTensor): target token ids ``[tgt_len x batch x nfeats]``.
        memory_bank (FloatTensor): encoder outputs
            ``[src_len x batch x hidden_size]``.
        state: encoder-derived hidden state used to seed the decoder RNN.
        memory_lengths (LongTensor): lengths of the source sequences.

    Returns:
        (decoder_final, decoder_outputs, attns): final RNN hidden state,
        per-step decoder outputs, and a dict of attention tensors.
    """
    assert not self._copy  # TODO, no support yet.
    assert not self._coverage  # TODO, no support yet.

    attns = {}
    emb = self.embeddings(tgt)

    # A GRU carries a single hidden tensor; other RNNs (LSTM) carry a tuple.
    if isinstance(self.rnn, nn.GRU):
        rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
    else:
        rnn_output, decoder_final = self.rnn(emb, state.hidden)

    # Sanity check: the RNN output must line up with the target sequence.
    tgt_len, tgt_batch, _ = tgt.size()
    out_len, out_batch, _ = rnn_output.size()
    aeq(tgt_len, out_len)
    aeq(tgt_batch, out_batch)

    # Attend over the encoder memory bank (attention layer is batch-first).
    decoder_outputs, p_attn = self.attn(
        rnn_output.transpose(0, 1).contiguous(),
        memory_bank.transpose(0, 1),
        memory_lengths=memory_lengths)
    attns["std"] = p_attn

    # Optionally blend embedding / RNN / attention info via the context gate;
    # the gate works on flattened (len*batch) rows, so reshape back after.
    if self.context_gate is not None:
        decoder_outputs = self.context_gate(
            emb.view(-1, emb.size(2)),
            rnn_output.view(-1, rnn_output.size(2)),
            decoder_outputs.view(-1, decoder_outputs.size(2)))
        decoder_outputs = decoder_outputs.view(
            tgt_len, tgt_batch, self.hidden_size)

    decoder_outputs = self.dropout(decoder_outputs)
    return decoder_final, decoder_outputs, attns
def _example_dict_iter(self, line, index):
    """Build one example dict from a raw text line.

    Tokenizes the line (optionally truncating it), splits off word-level
    features, and keys the words under ``self.side``.
    """
    tokens = line.split()
    if self.line_truncate:
        tokens = tokens[:self.line_truncate]
    words, feats, n_feats = TextDataset.extract_text_features(tokens)

    example_dict = {self.side: words, "indices": index}
    if feats:
        # All examples must have same number of features.
        aeq(self.n_feats, n_feats)
        prefix = self.side + "_feat_"
        example_dict.update((prefix + str(j), f)
                            for j, f in enumerate(feats))
    return example_dict
def forward(self, tgt, memory_bank, state, memory_lengths=None):
    """
    Args:
        tgt (`LongTensor`): sequences of padded tokens
            `[tgt_len x batch x nfeats]`.
        memory_bank (`FloatTensor`): vectors from the encoder
            `[src_len x batch x hidden]`.
        state (:obj:`mtos.Models.DecoderState`):
            decoder state object to initialize the decoder
        memory_lengths (`LongTensor`): the padded source lengths
            `[batch]`.
    Returns:
        (`FloatTensor`,:obj:`mtos.Models.DecoderState`,`FloatTensor`):
            * decoder_outputs: output from the decoder (after attn)
              `[tgt_len x batch x hidden]`.
            * decoder_state: final hidden state from the decoder
            * attns: distribution over src at each tgt
              `[tgt_len x batch x src_len]`.
    """
    # Argument sanity checks: batch sizes must agree across tgt and memory.
    assert isinstance(state, RNNDecoderState)
    tgt_len, tgt_batch, _ = tgt.size()
    _, memory_batch, _ = memory_bank.size()
    aeq(tgt_batch, memory_batch)

    # Delegate the actual RNN/attention pass to the subclass implementation.
    decoder_final, decoder_outputs, attns = self._run_forward_pass(
        tgt, memory_bank, state, memory_lengths=memory_lengths)

    # Update the state with the last step's output (and coverage, if any).
    final_output = decoder_outputs[-1]
    coverage = None
    if "coverage" in attns:
        coverage = attns["coverage"][-1].unsqueeze(0)
    state.update_state(decoder_final, final_output.unsqueeze(0), coverage)

    # Stack the per-step sequences along a new leading (time) dimension.
    decoder_outputs = torch.stack(decoder_outputs)
    for k in attns:
        attns[k] = torch.stack(attns[k])

    return decoder_outputs, state, attns
def _check_args(self, input, lengths=None, hidden=None):
    """Validate encoder inputs: `lengths`, when given, must match the batch."""
    _, n_batch, _ = input.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
    """
    Input-feed decoding: run the RNN one target position at a time,
    concatenating the previous attentional output onto each input step.

    See StdRNNDecoder._run_forward_pass() for description
    of arguments and return values.
    """
    # Additional args check: input feed must match the target batch size.
    input_feed = state.input_feed.squeeze(0)
    input_feed_batch, _ = input_feed.size()
    tgt_len, tgt_batch, _ = tgt.size()
    aeq(tgt_batch, input_feed_batch)
    # END Additional args check.

    # Initialize local and return variables.
    decoder_outputs = []
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []

    emb = self.embeddings(tgt)
    assert emb.dim() == 3  # len x batch x embedding_dim

    hidden = state.hidden
    coverage = state.coverage.squeeze(0) \
        if state.coverage is not None else None

    # Input feed concatenates hidden state with input at every time step.
    for emb_t in emb.split(1):
        emb_t = emb_t.squeeze(0)
        decoder_input = torch.cat([emb_t, input_feed], 1)

        rnn_output, hidden = self.rnn(decoder_input, hidden)
        decoder_output, p_attn = self.attn(
            rnn_output,
            memory_bank.transpose(0, 1),
            memory_lengths=memory_lengths)
        if self.context_gate is not None:
            # TODO: context gate should be employed
            # instead of second RNN transform.
            decoder_output = self.context_gate(
                decoder_input, rnn_output, decoder_output)
        decoder_output = self.dropout(decoder_output)
        # The attentional output becomes the next step's input feed.
        input_feed = decoder_output

        decoder_outputs += [decoder_output]
        attns["std"] += [p_attn]

        # Update the coverage attention (running sum of attention weights).
        if self._coverage:
            coverage = coverage + p_attn \
                if coverage is not None else p_attn
            attns["coverage"] += [coverage]

        # Run the forward pass of the copy attention layer, or reuse the
        # standard attention distribution when configured to do so.
        if self._copy and not self._reuse_copy_attn:
            _, copy_attn = self.copy_attn(decoder_output,
                                          memory_bank.transpose(0, 1))
            attns["copy"] += [copy_attn]
        elif self._copy:
            attns["copy"] = attns["std"]

    # Return result.
    return hidden, decoder_outputs, attns