def score(self, h_t, h_s):
    """
    Args:
        h_t (FloatTensor): sequence of queries [batch x tgt_len x h_t_dim]
        h_s (FloatTensor): sequence of sources [batch x src_len x h_s_dim]

    Returns:
        raw attention scores for each src index [batch x tgt_len x src_len]
    """
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    utils.aeq(src_batch, tgt_batch)
    # utils.aeq(src_dim, tgt_dim)

    if self.attn_type == "bilinear":
        h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
        h_t_ = self.linear_in(h_t_)
        h_t = h_t_.view(tgt_batch, tgt_len, src_dim)
        h_s_ = h_s.transpose(1, 2)
        # (batch, tgt_len, src_dim) x (batch, src_dim, src_len) -> (batch, tgt_len, src_len)
        return torch.bmm(h_t, h_s_)
    else:
        dim = self.dim
        wq = self.linear_query(h_t.view(-1, dim))
        wq = wq.view(tgt_batch, tgt_len, 1, dim)
        wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

        uh = self.linear_context(h_s.contiguous().view(-1, dim))
        uh = uh.view(src_batch, 1, src_len, dim)
        uh = uh.expand(src_batch, tgt_len, src_len, dim)

        # (batch, t_len, s_len, d)
        wquh = torch.tanh(wq + uh)

        return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
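# Hedged standalone sketch of the "bilinear" branch above (not the class itself):
# a Linear layer first maps the query dimension onto the key dimension, then a
# batched matmul produces [batch x tgt_len x src_len] scores. Names and sizes
# below are illustrative assumptions, not taken from the repository config.
import torch
import torch.nn as nn

batch, tgt_len, src_len, tgt_dim, src_dim = 2, 3, 5, 8, 6
h_t = torch.randn(batch, tgt_len, tgt_dim)   # decoder queries
h_s = torch.randn(batch, src_len, src_dim)   # encoder keys

linear_in = nn.Linear(tgt_dim, src_dim, bias=False)  # plays the role of self.linear_in
scores = torch.bmm(linear_in(h_t), h_s.transpose(1, 2))
assert scores.shape == (batch, tgt_len, src_len)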
def _run_forward_pass(self, input, context, state, mask):
    """
    Only used for beam search. Only compatible with runs with attention.
    Todo: implementation without attention.

    Private helper for running the specific RNN forward pass.
    Must be overridden by all subclasses.

    Args:
        input (LongTensor): a sequence of input tokens tensors
            of size (len x batch x nfeats).
        context (FloatTensor): output (tensor sequence) from the encoder
            RNN of size (src_len x batch x hidden_size).
        state (FloatTensor): hidden state from the encoder RNN for
            initializing the decoder.

    Returns:
        hidden (Variable): final hidden state from the decoder.
        outputs ([FloatTensor]): an array of output of every time
            step from the decoder.
        attns (dict of (str, [FloatTensor])): a dictionary of different
            type of attention Tensor array of every time
            step from the decoder.
        coverage (FloatTensor, optional): coverage from the decoder.
    """
    # Initialize local and return variables.
    attns = {"std": []}
    coverage = None
    emb = self.embedding(input)

    # Project the encoder state into the decoder's hidden space if the sizes differ.
    if emb.size(2) != state.hidden[0].size(2):
        state.hidden = [self.project_encoder(state.hidden[0])]

    # Run the forward pass of the RNN.
    if isinstance(self.rnn, StackedGRU):
        rnn_output, hidden = self.rnn(emb, state.hidden[0])
    else:
        rnn_output, hidden = self.rnn(emb, state.hidden)

    # Result Check
    input_len, input_batch, _ = input.size()
    output_len, output_batch, _ = rnn_output.size()
    aeq(input_len, output_len)
    aeq(input_batch, output_batch)
    # END Result Check

    # Calculate the attention.
    attn_outputs, attn_scores = self.attention(
        rnn_output.contiguous(),               # (output_len, batch, d)
        context.transpose(0, 1).contiguous(),  # (batch, contxt_len, d)
        mask
    )
    attns["std"] = attn_scores
    outputs = attn_outputs                     # (input_len, batch, d)

    # Return result.
    return hidden, outputs, attns, coverage
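# Hedged sketch of the dimension check above: when the decoder embedding size
# differs from the encoder hidden size, a Linear projection (project_encoder is
# re-created here as an assumption) maps the encoder state into the decoder's
# hidden space before the RNN is unrolled. Shapes are illustrative only.
import torch
import torch.nn as nn

enc_hidden = torch.randn(1, 2, 512)          # layers x batch x enc_hidden_size
project_encoder = nn.Linear(512, 256)        # stand-in for self.project_encoder
dec_init = project_encoder(enc_hidden)       # layers x batch x dec_hidden_size
assert dec_init.shape == (1, 2, 256)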
def forward(self, inp):
    """
    Return the embeddings for words, and features if there are any.

    Args:
        inp (LongTensor): batch x len x nfeat

    Return:
        emb (Tensor): batch x len x self.embedding_size
    """
    if inp.dim() == 2:
        # batch x len
        emb = self.word_lookup_table(inp)
        return emb

    in_batch, in_length, nfeat = inp.size()
    # The number of features must match the number of embedding lookup tables.
    aeq(nfeat, len(self.emb_luts))

    if len(self.emb_luts) == 1:
        emb = self.word_lookup_table(inp.squeeze(2))
    else:
        feat_inputs = (feat.squeeze(2) for feat in inp.split(1, dim=2))
        features = [lut(feat) for lut, feat in zip(self.emb_luts, feat_inputs)]
        emb = self.merge(features)

    out_batch, out_length, emb_size = emb.size()
    aeq(in_batch, out_batch)
    aeq(in_length, out_length)
    aeq(emb_size, self.embedding_size())

    return emb
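# Hedged sketch of the multi-feature path above: one embedding table per feature
# column, looked up independently and then merged. Concatenation is used here as
# one common merge choice; the actual behaviour of self.merge is an assumption,
# as are the vocabulary and embedding sizes.
import torch
import torch.nn as nn

batch, length = 2, 4
vocab_sizes, emb_dims = [100, 10, 10], [32, 8, 8]
emb_luts = nn.ModuleList(nn.Embedding(v, d) for v, d in zip(vocab_sizes, emb_dims))

inp = torch.stack([torch.randint(0, v, (batch, length)) for v in vocab_sizes], dim=2)
feat_inputs = (feat.squeeze(2) for feat in inp.split(1, dim=2))
features = [lut(feat) for lut, feat in zip(emb_luts, feat_inputs)]
emb = torch.cat(features, dim=-1)            # batch x len x sum(emb_dims)
assert emb.shape == (batch, length, sum(emb_dims))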
def _run_forward_pass(self, ph_sel, phrase_bank, phrase_lengths=None):
    """
    Args:
        ph_sel (LongTensor): token ids of the selected phrases for each
            RR sentence [batch x max_sent_num x max_ph_per_sent x max_ph_len]
        phrase_bank (FloatTensor): embeddings for phrase collections
            [batch x len x nfeats]
        phrase_lengths (LongTensor): the lengths of phrase collections

    Returns:
        dec_state (Tensor): final hidden state from the decoder.
        dec_outs ([FloatTensor]): an array of output of every time step
            from the decoder.
        ph_attns ([FloatTensor]): phrase attention Tensor array of every
            time step from the decoder.
    """
    ph_sel_emb = self.embedding(ph_sel)
    ph_sel_emb = torch.sum(ph_sel_emb, -2)   # sum over the tokens in each phrase
    ph_sum_emb = torch.sum(ph_sel_emb, -2)   # sum over the phrases in each sentence

    rnn_output, dec_state = self.LSTM(ph_sum_emb, self.state["hidden"])
    self.rnn_output = rnn_output

    ph_batch, ph_len, _ = ph_sum_emb.size()
    output_batch, output_len, _ = rnn_output.size()
    utils.aeq(ph_len, output_len)
    utils.aeq(ph_batch, output_batch)

    dec_outs, ph_attn, ph_attn_raw = self.ph_attn(
        rnn_output.contiguous(),
        phrase_bank.contiguous(),
        memory_lengths=phrase_lengths,
        use_softmax=False)

    readouts = self.readout(dec_outs)
    # dec_outs:  [batch_size x max_sent_num x ph_emb_dim]
    # readouts:  [batch_size x max_sent_num x ph_vocab_size]
    return dec_state, dec_outs, ph_attn, ph_attn_raw, readouts
def original_forward(self, input, context, state, mask):
    """
    Only used for beam search.

    Forward through the decoder.

    Args:
        input (LongTensor): a sequence of input tokens tensors
            of size (len x batch x nfeats).
        context (FloatTensor): output (tensor sequence) from the encoder
            RNN of size (src_len x batch x hidden_size).
        state (FloatTensor): hidden state from the encoder RNN for
            initializing the decoder.

    Returns:
        outputs (FloatTensor): a Tensor sequence of output from the decoder
            of shape (len x batch x hidden_size).
        state (FloatTensor): final hidden state from the decoder.
        attns (dict of (str, FloatTensor)): a dictionary of different
            type of attention Tensor from the decoder
            of shape (src_len x batch).
    """
    # Args Check
    assert isinstance(state, RNNDecoderState)
    input_len, input_batch, _ = input.size()
    contxt_len, contxt_batch, _ = context.size()
    aeq(input_batch, contxt_batch)
    # END Args Check

    # Run the forward pass of the RNN.
    hidden, outputs, attns, coverage = self._run_forward_pass(
        input, context, state, mask)

    # Update the state with the result.
    final_output = outputs[-1]
    state.update_state(hidden, final_output.unsqueeze(0),
                       coverage.unsqueeze(0) if coverage is not None else None)

    # Concatenates sequence of tensors along a new dimension.
    outputs = torch.stack(outputs)
    for k in attns:
        attns[k] = torch.stack(attns[k])

    return outputs, state, attns
def _run_forward_pass(self, ph_sel, phrase_bank, phrase_lengths=None):
    """
    Args:
        ph_sel (batch_size x max_sent_num x max_ph_per_sent x max_ph_len):
            token ids for selected phrases
        phrase_bank (batch_size x max_ph_bank_size x dim):
            embeddings for phrases in each phrase bank for each sample
        phrase_lengths (batch_size): size of phrase bank for each sample

    Returns:
        dec_state (tuple of h and c): final hidden state from the decoder.
        dec_outs (batch_size x max_sent_num x dim): an array of output of
            every time step from the decoder.
        ph_attn_probs (batch_size x max_sent_num x max_ph_bank_size):
            normalized phrase attention for every time step from the decoder.
        ph_attn_logits (batch_size x max_sent_num x max_ph_bank_size):
            raw phrase attention scores for every time step from the decoder.
        stype_logits: sentence-type logits produced by the readout layer
            over dec_outs.
    """
    ph_sel_emb = self.embedding(ph_sel)
    ph_sel_emb = torch.sum(ph_sel_emb, -2)   # sum over all tokens in each phrase
    ph_sum_emb = torch.sum(ph_sel_emb, -2)   # sum over all phrases in each sentence

    rnn_output, dec_state = self.LSTM(ph_sum_emb, self.state["hidden"])
    self.rnn_output = rnn_output

    batch_size, max_sent_num, _ = ph_sum_emb.size()
    output_batch, output_len, _ = rnn_output.size()
    utils.aeq(max_sent_num, output_len)
    utils.aeq(batch_size, output_batch)

    dec_outs, ph_attn_probs, ph_attn_logits = self.ph_attn(
        rnn_output.contiguous(),
        phrase_bank.contiguous(),
        memory_lengths=phrase_lengths,
        use_softmax=False
    )

    stype_logits = self.readout(dec_outs)
    return dec_state, dec_outs, ph_attn_probs, ph_attn_logits, stype_logits
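# Hedged sketch of the two summations above: token embeddings are summed within
# each phrase, then phrase embeddings are summed within each sentence, giving one
# vector per sentence step for the LSTM. All shapes below are toy assumptions.
import torch
import torch.nn as nn

batch, max_sent, max_ph, max_ph_len, emb_dim = 2, 3, 4, 5, 16
embedding = nn.Embedding(100, emb_dim, padding_idx=0)

ph_sel = torch.randint(0, 100, (batch, max_sent, max_ph, max_ph_len))
ph_sel_emb = embedding(ph_sel)               # batch x sent x phrase x token x dim
ph_sel_emb = ph_sel_emb.sum(-2)              # sum over tokens in each phrase
ph_sum_emb = ph_sel_emb.sum(-2)              # sum over phrases in each sentence
assert ph_sum_emb.shape == (batch, max_sent, emb_dim)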
def forward(self, input, context, context_lengths=None, context_max_len=None):
    """
    Args:
        input (FloatTensor): batch x tgt_len x dim: decoder's rnn's output.
        context (FloatTensor): batch x src_len x dim: src hidden states
    """
    # one step input
    if input.dim() == 2:
        one_step = True
        input = input.unsqueeze(1)
    else:
        one_step = False

    batch, sourceL, dim = context.size()
    batch_, targetL, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)

    # compute attention scores, as in Luong et al.
    align = self.score(input, context)

    if context_lengths is not None:
        mask = self.sequence_mask(context_lengths, context_max_len)
        # (batch, max_len) -> (batch, 1, max_len), so the mask can broadcast
        mask = mask.unsqueeze(1)
        align.data.masked_fill_(1 - mask, -float('inf'))

    # Softmax to normalize attention weights
    align_vectors = torch.softmax(align, dim=-1)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, context)

    # concatenate
    concat_c = torch.cat([c, input], -1)

    # linear_out
    if self.linear_out is None:
        attn_h = concat_c
    else:
        attn_h = self.linear_out(concat_c)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

    # (batch, targetL, dim_), (batch, targetL, sourceL)
    return attn_h, align_vectors
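# Hedged sketch of the masking step above: positions past each source length are
# filled with -inf before the softmax so they receive zero attention weight.
# sequence_mask below is a local stand-in for the helper the module calls, and
# the tensor sizes are illustrative assumptions.
import torch

def sequence_mask(lengths, max_len):
    # True for valid positions, False for padding
    return torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

align = torch.randn(2, 3, 5)                            # batch x tgt_len x src_len
context_lengths = torch.tensor([5, 3])
mask = sequence_mask(context_lengths, 5).unsqueeze(1)   # batch x 1 x src_len
align = align.masked_fill(~mask, float('-inf'))
weights = torch.softmax(align, dim=-1)
assert torch.allclose(weights[1, :, 3:].sum(), torch.tensor(0.0))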
def score(self, h_t, h_s):
    """
    Args:
        h_t (FloatTensor): batch x tgt_len x dim
        h_s (FloatTensor): batch x src_len x dim

    Returns:
        scores (FloatTensor): batch x tgt_len x src_len:
            raw attention scores for each src index
    """
    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.dim, src_dim)

    if self.attn_type in ["general", "dot"]:
        if self.attn_hidden > 0:
            h_t = self.transform_in(h_t)
            h_s = self.transform_in(h_s)
        if self.attn_type == "general":
            h_t = self.linear_in(h_t)
        h_s_ = h_s.transpose(1, 2)
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s_)
    else:
        dim = self.dim
        wq = self.linear_query(h_t.view(-1, dim))
        wq = wq.view(tgt_batch, tgt_len, 1, dim)
        wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

        uh = self.linear_context(h_s.contiguous().view(-1, dim))
        uh = uh.view(src_batch, 1, src_len, dim)
        uh = uh.expand(src_batch, tgt_len, src_len, dim)

        # (batch, t_len, s_len, d)
        wquh = self.tanh(wq + uh)

        return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
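# Hedged standalone sketch of the additive ("mlp") branch above: project query
# and context separately, broadcast-add, apply tanh, then score with a final
# vector v. Layer names mirror the module's attributes but are re-created here
# as assumptions, and broadcasting replaces the explicit view/expand calls.
import torch
import torch.nn as nn

batch, tgt_len, src_len, dim = 2, 3, 5, 8
h_t, h_s = torch.randn(batch, tgt_len, dim), torch.randn(batch, src_len, dim)

linear_query = nn.Linear(dim, dim, bias=True)
linear_context = nn.Linear(dim, dim, bias=False)
v = nn.Linear(dim, 1, bias=False)

wq = linear_query(h_t).unsqueeze(2)          # batch x tgt_len x 1 x dim
uh = linear_context(h_s).unsqueeze(1)        # batch x 1 x src_len x dim
scores = v(torch.tanh(wq + uh)).squeeze(-1)  # batch x tgt_len x src_len
assert scores.shape == (batch, tgt_len, src_len)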
def forward(self, query, memory_bank, memory_lengths=None, use_softmax=True):
    """
    Args:
        query (FloatTensor): query vectors [batch x tgt_len x dim]
        memory_bank (FloatTensor): source vectors [batch x src_len x dim]
        memory_lengths (LongTensor): source context lengths [batch]
        use_softmax (bool): use softmax to produce the alignment scores,
            otherwise use a sigmoid for each individual one

    Returns:
        (FloatTensor, FloatTensor)
            computed attention weighted average: [batch x tgt_len x dim]
            attention distribution: [batch x tgt_len x src_len]
    """
    if query.dim() == 2:
        one_step = True
        query = query.unsqueeze(1)
    else:
        one_step = False

    src_batch, src_len, src_dim = memory_bank.size()
    query_batch, query_len, query_dim = query.size()
    utils.aeq(src_batch, query_batch)
    # utils.aeq(src_dim, query_dim)

    align = self.score(query, memory_bank)

    if memory_lengths is not None:
        mask = utils.sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)
        align.masked_fill_(1 - mask, -float('inf'))

    if use_softmax:
        align_vectors = F.softmax(
            align.view(src_batch * query_len, src_len), -1)
        align_vectors = align_vectors.view(src_batch, query_len, src_len)
    else:
        align_vectors = F.sigmoid(align)

    # c is the attention weighted context representation
    # [batch x tgt_len x hidden_size]
    c = torch.bmm(align_vectors, memory_bank)

    concat_c = torch.cat([c, query], 2).view(src_batch * query_len,
                                             src_dim + query_dim)
    attn_h = self.linear_out(concat_c).view(src_batch, query_len, query_dim)
    if self.attn_type == "bilinear":
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

        batch_, dim_ = attn_h.size()
        utils.aeq(src_batch, batch_)
        utils.aeq(src_dim, dim_)
        batch_, src_l_ = align_vectors.size()
        utils.aeq(src_batch, batch_)
        utils.aeq(src_len, src_l_)
    else:
        batch_, target_l_, dim_ = attn_h.size()
        utils.aeq(target_l_, query_len)
        utils.aeq(batch_, query_batch)
        utils.aeq(dim_, query_dim)
        batch_, target_l_, source_l_ = align_vectors.size()
        utils.aeq(target_l_, query_len)
        utils.aeq(batch_, query_batch)
        utils.aeq(source_l_, src_len)

    return attn_h, align_vectors, align
def forward(self, query, memory_bank, memory_lengths=None, use_softmax=True):
    """
    Args:
        query (FloatTensor): query vectors [batch x tgt_len x q_dim]
        memory_bank (FloatTensor): source vectors [batch x src_len x k_dim]
        memory_lengths (LongTensor): source context lengths [batch]
        use_softmax (bool): use softmax to produce the alignment scores,
            otherwise use a sigmoid for keyphrase selection

    Returns:
        attn_h (FloatTensor, batch x tgt_len x k_dim): weighted value vectors
            after attention
        align_vectors (FloatTensor, batch x tgt_len x src_len): normalized
            attention scores
        align (FloatTensor, batch x tgt_len x src_len): raw attention scores
            used for loss calculation
    """
    if query.dim() == 2:
        one_step = True
        query = query.unsqueeze(1)
    else:
        one_step = False

    src_batch, src_len, src_dim = memory_bank.size()
    query_batch, query_len, query_dim = query.size()
    utils.aeq(src_batch, query_batch)
    # utils.aeq(src_dim, query_dim)

    align = self.score(query, memory_bank)

    if memory_lengths is not None:
        mask = utils.sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1).long()
        align.masked_fill_((1 - mask).bool(), -float('inf'))

    if use_softmax:
        align_vectors = self.softmax(
            align.view(src_batch * query_len, src_len))
        align_vectors = align_vectors.view(src_batch, query_len, src_len)
    else:
        align_vectors = self.sigmoid(align)

    # c is the attention weighted context representation
    # [batch x tgt_len x hidden_size]
    c = torch.bmm(align_vectors, memory_bank)

    concat_c = torch.cat([c, query], 2).view(src_batch * query_len,
                                             src_dim + query_dim)
    attn_h = self.linear_out(concat_c).view(src_batch, query_len, query_dim)
    attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

        batch_, dim_ = attn_h.size()
        utils.aeq(src_batch, batch_)
        utils.aeq(src_dim, dim_)
        batch_, src_l_ = align_vectors.size()
        utils.aeq(src_batch, batch_)
        utils.aeq(src_len, src_l_)
    else:
        batch_, target_l_, dim_ = attn_h.size()
        utils.aeq(target_l_, query_len)
        utils.aeq(batch_, query_batch)
        utils.aeq(dim_, query_dim)
        batch_, target_l_, source_l_ = align_vectors.size()
        utils.aeq(target_l_, query_len)
        utils.aeq(batch_, query_batch)
        utils.aeq(source_l_, src_len)

    return attn_h, align_vectors, align
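# Hedged sketch of the use_softmax flag above: with softmax the weights over the
# phrase bank sum to one per decoder step, while with sigmoid each phrase gets an
# independent selection score in (0, 1). Values below are toy assumptions.
import torch

align = torch.randn(2, 3, 5)                 # batch x tgt_len x src_len raw scores
softmax_weights = torch.softmax(align, dim=-1)
sigmoid_scores = torch.sigmoid(align)

assert torch.allclose(softmax_weights.sum(-1), torch.ones(2, 3))
assert ((sigmoid_scores > 0) & (sigmoid_scores < 1)).all()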