def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
    """
    See StdRNNDecoder._run_forward_pass() for description
    of arguments and return values.
    """
    # Additional args check.
    input_feed = self.state["input_feed"].squeeze(0)
    input_feed_batch, _ = input_feed.size()
    _, tgt_batch, _ = tgt.size()
    aeq(tgt_batch, input_feed_batch)
    # END Additional args check.

    # Initialize local and return variables.
    dec_outs = []
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []

    emb = self.embeddings(tgt)
    assert emb.dim() == 3  # len x batch x embedding_dim

    dec_state = self.state["hidden"]
    coverage = self.state["coverage"].squeeze(0) \
        if self.state["coverage"] is not None else None

    # Input feed concatenates the previous attentional output with the
    # input embedding at every time step.
    for emb_t in emb.split(1):
        emb_t = emb_t.squeeze(0)
        decoder_input = torch.cat([emb_t, input_feed], 1)

        rnn_output, dec_state = self.rnn(decoder_input, dec_state)
        decoder_output, p_attn = self.attn(
            rnn_output,
            memory_bank.transpose(0, 1),
            memory_lengths=memory_lengths)
        decoder_output = self.dropout(decoder_output)
        input_feed = decoder_output

        dec_outs += [decoder_output]
        attns["std"] += [p_attn]

        # Update the coverage attention.
        if self._coverage:
            coverage = coverage + p_attn \
                if coverage is not None else p_attn
            attns["coverage"] += [coverage]

        # Run the forward pass of the copy attention layer.
        if self._copy and not self._reuse_copy_attn:
            _, copy_attn = self.copy_attn(decoder_output,
                                          memory_bank.transpose(0, 1))
            attns["copy"] += [copy_attn]
        elif self._copy:
            attns["copy"] = attns["std"]

    # Return result.
    return dec_state, dec_outs, attns
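# Illustrative sketch (toy dimensions, not part of the module above) of
# the input-feed recurrence: the previous attentional output is
# concatenated with the current embedding at every step. nn.LSTMCell
# stands in for self.rnn, and the raw hidden state stands in for the
# attention output.
import torch
import torch.nn as nn

emb_dim, hidden_size, batch, tgt_len = 8, 16, 4, 5
cell = nn.LSTMCell(emb_dim + hidden_size, hidden_size)
emb = torch.randn(tgt_len, batch, emb_dim)    # len x batch x emb_dim
input_feed = torch.zeros(batch, hidden_size)  # previous attentional output
h = (torch.zeros(batch, hidden_size), torch.zeros(batch, hidden_size))

dec_outs = []
for emb_t in emb.split(1):
    emb_t = emb_t.squeeze(0)
    # The RNN input is the embedding concatenated with the last
    # attentional output -- this is the "input feed".
    h = cell(torch.cat([emb_t, input_feed], 1), h)
    input_feed = h[0]  # in the real decoder this is the attention output
    dec_outs.append(h[0])

print(torch.stack(dec_outs).shape)  # torch.Size([5, 4, 16])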
def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
    """
    Private helper for running the specific RNN forward pass.
    Must be overridden by all subclasses.

    Args:
        tgt (LongTensor): a sequence of input token tensors
            [len x batch x nfeats].
        memory_bank (FloatTensor): output (tensor sequence) from the
            encoder RNN of size (src_len x batch x hidden_size).
        memory_lengths (LongTensor): the source memory_bank lengths.

    Returns:
        dec_state (Tensor): final hidden state from the decoder.
        dec_outs ([FloatTensor]): an array of the output of every time
            step from the decoder.
        attns (dict of (str, [FloatTensor])): a dictionary of attention
            Tensors, one per type, for every time step of the decoder.
    """
    assert not self._copy  # TODO, no support yet.
    assert not self._coverage  # TODO, no support yet.

    # Initialize local and return variables.
    tgt = tgt.unsqueeze(2)
    attns = {}
    emb = self.embeddings(tgt)

    # Run the forward pass of the RNN. nn.GRU takes a single hidden
    # tensor, while nn.LSTM takes an (h_0, c_0) tuple.
    if isinstance(self.rnn, nn.GRU):
        rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0])
    else:
        rnn_output, dec_state = self.rnn(emb, self.state["hidden"])

    # Check
    tgt_len, tgt_batch, _ = tgt.size()
    output_len, output_batch, _ = rnn_output.size()
    aeq(tgt_len, output_len)
    aeq(tgt_batch, output_batch)
    # END

    # Calculate the attention.
    dec_outs, p_attn = self.attn(
        rnn_output.transpose(0, 1).contiguous(),
        memory_bank.transpose(0, 1),
        memory_lengths=memory_lengths)
    attns["std"] = p_attn

    dec_outs = self.dropout(dec_outs)
    return dec_state, dec_outs, attns
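# Illustrative sketch of why the GRU branch above indexes
# self.state["hidden"][0]: nn.LSTM expects an (h_0, c_0) tuple, while
# nn.GRU expects a single hidden tensor, so a state stored uniformly as
# a tuple is unpacked for the GRU. Toy dimensions only.
import torch
import torch.nn as nn

seq_len, batch, input_size, hidden_size = 5, 3, 8, 16
x = torch.randn(seq_len, batch, input_size)
state = (torch.zeros(1, batch, hidden_size),)  # stored-as-tuple convention

gru = nn.GRU(input_size, hidden_size)
gru_out, h_n = gru(x, state[0])                # single hidden tensor

lstm = nn.LSTM(input_size, hidden_size)
lstm_out, (h_n, c_n) = lstm(x, (state[0], torch.zeros_like(state[0])))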
def forward(self, inputs):
    """
    Computes the embedding for words and features.

    Args:
        inputs (`LongTensor`): index tensor `[len x batch]`
    Return:
        `FloatTensor`: word embedding `[len x batch x embedding_size]`
    """
    in_length, in_batch = inputs.size()

    embedded = self.embedding(inputs)
    if self.dropout is not None:
        embedded = self.dropout(embedded)

    out_length, out_batch, embedded_size = embedded.size()
    aeq(in_length, out_length)
    aeq(in_batch, out_batch)
    aeq(embedded_size, self.embedding_size)

    return embedded
def forward(self, inputs, is_dropout=True):
    """
    Computes the embedding for words and features.

    Args:
        inputs (`LongTensor`): index tensor `[len x batch]`
    Return:
        `FloatTensor`: word embedding `[len x batch x embedding_size]`
    """
    dim = inputs.dim()
    if dim == 2:  # with batch
        in_length, in_batch = inputs.size()

    embedded = self.embedding(inputs)
    if self.dropout is not None and is_dropout:
        embedded = self.dropout(embedded)

    if dim == 2:
        out_length, out_batch, embedded_size = embedded.size()
        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(embedded_size, self.embedding_size)

    return embedded
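# Illustrative sketch of the lookup above with hypothetical sizes:
# indices of shape [len x batch] come out as [len x batch x embedding_size],
# and dropout is skipped when is_dropout=False (e.g. at inference time).
import torch
import torch.nn as nn

vocab_size, embedding_size = 100, 32
embedding = nn.Embedding(vocab_size, embedding_size)
dropout = nn.Dropout(p=0.3)

inputs = torch.randint(0, vocab_size, (7, 4))  # [len x batch]
embedded = dropout(embedding(inputs))          # [len x batch x embedding_size]
print(embedded.shape)                          # torch.Size([7, 4, 32])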
def score(self, decoder_output, encoder_outputs):
    """
    Args:
        decoder_output (`FloatTensor`): sequence of queries
            `[batch_size x tgt_len x hidden_size]`
        encoder_outputs (`FloatTensor`): sequence of sources
            `[batch_size x src_len x hidden_size]`

    Returns:
        :obj:`FloatTensor`: raw attention scores (unnormalized)
            for each src index `[batch_size x tgt_len x src_len]`
    """
    # Check decoder_output sizes
    src_batch, src_len, src_dim = encoder_outputs.size()
    tgt_batch, tgt_len, tgt_dim = decoder_output.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.hidden_size, src_dim)

    if self.attn_type in ["general", "dot"]:
        if self.attn_type == "general":
            decoder_output = self.linear_in(decoder_output)
        # (batch_size, t_len, d) x (batch_size, d, s_len)
        #   --> (batch_size, t_len, s_len)
        return torch.bmm(decoder_output, encoder_outputs.transpose(1, 2))
    else:
        hidden_size = self.hidden_size
        wq = self.linear_query(decoder_output.view(-1, hidden_size))
        wq = wq.view(tgt_batch, tgt_len, 1, hidden_size)
        wq = wq.expand(tgt_batch, tgt_len, src_len, hidden_size)

        uh = self.linear_context(
            encoder_outputs.contiguous().view(-1, hidden_size))
        uh = uh.view(src_batch, 1, src_len, hidden_size)
        uh = uh.expand(src_batch, tgt_len, src_len, hidden_size)

        # (batch_size, t_len, s_len, d)
        wquh = self.tanh(wq + uh)

        return self.v(wquh.view(-1, hidden_size)).view(
            tgt_batch, tgt_len, src_len)
def score(self, h_t, h_s):
    """
    Args:
        h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
        h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`

    Returns:
        :obj:`FloatTensor`: raw attention scores (unnormalized)
            for each src index `[batch x tgt_len x src_len]`
    """
    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.dim, src_dim)

    if self.attn_type in ["general", "dot"]:
        if self.attn_type == "general":
            h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
            h_t_ = self.linear_in(h_t_)
            h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
        h_s_ = h_s.transpose(1, 2)
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s_)
    else:
        dim = self.dim
        wq = self.linear_query(h_t.view(-1, dim))
        wq = wq.view(tgt_batch, tgt_len, 1, dim)
        wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

        uh = self.linear_context(h_s.contiguous().view(-1, dim))
        uh = uh.view(src_batch, 1, src_len, dim)
        uh = uh.expand(src_batch, tgt_len, src_len, dim)

        # (batch, t_len, s_len, d)
        wquh = torch.tanh(wq + uh)

        return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
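# Illustrative sketch of the "general" branch above: the query is first
# mapped through linear_in (W h_t), then batch-multiplied against the
# transposed sources, giving one unnormalized score per (target, source)
# pair. Toy dimensions only.
import torch
import torch.nn as nn

batch, tgt_len, src_len, dim = 2, 3, 5, 4
h_t = torch.randn(batch, tgt_len, dim)
h_s = torch.randn(batch, src_len, dim)
linear_in = nn.Linear(dim, dim, bias=False)

scores = torch.bmm(linear_in(h_t), h_s.transpose(1, 2))
print(scores.shape)  # torch.Size([2, 3, 5])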
def _check_args(self, src, lengths=None, hidden=None):
    _, n_batch, _ = src.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
def _check_args(self, tgt, memory_bank, state):
    assert isinstance(state, RNNDecoderState)
    tgt_len, tgt_batch = tgt.size()
    _, memory_batch, _ = memory_bank.size()
    aeq(tgt_batch, memory_batch)
def forward(self, decoder_output, encoder_outputs, encoder_inputs_length=None):
    """
    Args:
        decoder_output (`FloatTensor`): query vectors
            `[batch_size x tgt_len x hidden_size]`
        encoder_outputs (`FloatTensor`): source vectors
            `[batch_size x src_len x hidden_size]`
        encoder_inputs_length (`LongTensor`): the source context
            lengths `[batch_size]`

    Returns:
        (`FloatTensor`, `FloatTensor`):
        * Computed vector `[tgt_len x batch_size x hidden_size]`
        * Attention distributions for each query
          `[tgt_len x batch_size x src_len]`
    """
    # One-step decoder_output: insert a length-1 target dimension.
    if decoder_output.dim() == 2:
        one_step = True
        decoder_output = decoder_output.unsqueeze(1)
    else:
        one_step = False

    batch_size, source_l, hidden_size = encoder_outputs.size()
    batch_size_, target_l, hidden_size_ = decoder_output.size()
    aeq(batch_size, batch_size_)
    aeq(hidden_size, hidden_size_)
    aeq(self.hidden_size, hidden_size)

    # Compute attention scores, as in Luong et al.
    align = self.score(decoder_output,
                       encoder_outputs)  # [batch_size, t_len, s_len]

    if encoder_inputs_length is not None:
        # Obtain the mask for encoder_inputs_length.
        mask = sequence_mask(encoder_inputs_length)
        mask = mask.to(device=encoder_outputs.device)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        # masked_fill_(mask, value) fills elements of the tensor with
        # `value` where mask is one.
        align.data.masked_fill_(1 - mask, -float('inf'))

    # Softmax to normalize attention weights.
    align_vectors = self.softmax(align)

    # Each context vector c_t is the weighted average
    # over all the source hidden states.
    context_vector = torch.bmm(
        align_vectors, encoder_outputs)  # [batch_size, t_len, hidden_size]

    # Concatenate.
    concat_cv = torch.cat((context_vector, decoder_output),
                          dim=2)  # [batch_size, t_len, 2*hidden_size]
    attn_h = self.linear_out(concat_cv)  # [batch_size, t_len, hidden_size]
    if self.attn_type in ["general", "dot"]:
        attn_h = self.tanh(attn_h)  # tanh activation

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

        # Check output sizes
        batch_size_, hidden_size_ = attn_h.size()
        aeq(batch_size, batch_size_)
        aeq(hidden_size, hidden_size_)
        batch_size_, source_l_ = align_vectors.size()
        aeq(batch_size, batch_size_)
        aeq(source_l, source_l_)
    else:
        # [t_len, batch_size, hidden_size]
        attn_h = attn_h.transpose(0, 1).contiguous()
        # [t_len, batch_size, s_len]
        align_vectors = align_vectors.transpose(0, 1).contiguous()

        # Check output sizes
        target_l_, batch_size_, hidden_size_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch_size, batch_size_)
        aeq(hidden_size, hidden_size_)
        target_l_, batch_size_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch_size, batch_size_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors
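# Illustrative sketch of the one-step convention above: a 2-D query is
# treated as a length-1 target sequence on the way in and squeezed back
# on the way out, so the same module serves both stepwise decoding and
# whole-sequence scoring. Toy dimensions; the attention itself is elided.
import torch

batch, hidden_size = 4, 16
decoder_output = torch.randn(batch, hidden_size)  # a single decoding step

one_step = decoder_output.dim() == 2
if one_step:
    decoder_output = decoder_output.unsqueeze(1)  # [batch x 1 x hidden_size]

attn_h = decoder_output  # placeholder for the attention computation

if one_step:
    attn_h = attn_h.squeeze(1)  # back to [batch x hidden_size]
print(attn_h.shape)             # torch.Size([4, 16])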
def _check_args(self, input, lengths=None, hidden=None):
    s_len, n_batch = input.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
def _run_forward_pass(self, inputs, encoder_outputs, decoder_state,
                      encoder_inputs_length=None):
    """
    Private helper for running the specific RNN forward pass.
    Must be overridden by all subclasses.

    Args:
        inputs (LongTensor): a sequence of input token tensors
            [inputs_len x batch].
        encoder_outputs (FloatTensor): output (tensor sequence) from the
            encoder RNN of size (src_len x batch x hidden_size).
        decoder_state (FloatTensor): hidden state from the encoder RNN
            for initializing the decoder.
        encoder_inputs_length (LongTensor): the source encoder_outputs
            lengths.

    Returns:
        decoder_final (Variable): final hidden state from the decoder.
        decoder_outputs ([FloatTensor]): an array of the output of every
            time step from the decoder.
        attns (dict of (str, [FloatTensor])): a dictionary of attention
            Tensors, one per type, for every time step of the decoder.
    """
    # Initialize local and return variables.
    attns = {}
    embedded = self.embedding(inputs)

    # Run the forward pass of the RNN. GRU/RNN take a single hidden
    # tensor, while LSTM takes an (h, c) tuple.
    if isinstance(self.rnn, (nn.GRU, nn.RNN)):
        decoder_output, decoder_final = self.rnn(embedded,
                                                 decoder_state.hidden[0])
    else:  # LSTM
        decoder_output, decoder_final = self.rnn(embedded,
                                                 decoder_state.hidden)

    # Check
    inputs_len, tgt_batch = inputs.size()
    output_len, output_batch, _ = decoder_output.size()
    aeq(inputs_len, output_len)
    aeq(tgt_batch, output_batch)

    # Calculate the attention. Both tensors are transposed to
    # batch-first before entering the attention module.
    if self.attn_type is not None:
        decoder_output, p_attn = self.attn(
            decoder_output.transpose(0, 1),
            encoder_outputs.transpose(0, 1),
            encoder_inputs_length)
        attns["std"] = p_attn

    # Dropout.
    decoder_output = self.dropout(decoder_output)

    return decoder_final, decoder_output, attns
def _check_args(self, inputs, encoder_outputs, decoder_state):
    assert isinstance(decoder_state, RNNDecoderState)
    inputs_len, tgt_batch = inputs.size()
    _, memory_batch, _ = encoder_outputs.size()
    aeq(tgt_batch, memory_batch)
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
        source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
        memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
        memory_lengths (`LongTensor`): the source context lengths `[batch]`
        coverage (`FloatTensor`): None (not supported yet)

    Returns:
        (`FloatTensor`, `FloatTensor`):
        * Computed vector `[tgt_len x batch x dim]`
        * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    # One-step input: insert a length-1 target dimension.
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # Compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax to normalize attention weights.
    align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # Each context vector c_t is the weighted average
    # over all the source hidden states.
    c = torch.bmm(align_vectors, memory_bank)

    # Concatenate.
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()

        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors
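# Illustrative end-to-end toy run of the masking, softmax, and context
# steps above: padded source positions receive -inf scores, so their
# attention weight comes out exactly 0. A boolean mask built inline
# stands in for sequence_mask; all sizes are hypothetical.
import torch
import torch.nn.functional as F

batch, target_l, source_l, dim = 2, 1, 4, 8
align = torch.randn(batch, target_l, source_l)
memory_bank = torch.randn(batch, source_l, dim)
memory_lengths = torch.tensor([4, 2])  # second example has 2 padded steps

mask = torch.arange(source_l)[None, :] < memory_lengths[:, None]
align.masked_fill_(~mask.unsqueeze(1), -float('inf'))

align_vectors = F.softmax(align, dim=-1)
c = torch.bmm(align_vectors, memory_bank)  # [batch x target_l x dim]
print(align_vectors[1])  # weights on the padded positions are 0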