def forward(self, hidden, attn, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.
    Args:
        hidden (`FloatTensor`): hidden outputs `[batch, tlen, input_size]`
        attn (`FloatTensor`): attention over source tokens for each
            target step `[batch, tlen, slen]`
        src_map (`FloatTensor`):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocabulary.
            `[batch, src_len, extra_words]`
    """
    # CHECKS
    batch, tlen, _ = hidden.size()
    batch_, tlen_, slen = attn.size()
    batch, slen_, cvocab = src_map.size()
    aeq(tlen, tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, :, self.tgt_dict[constants.PAD_WORD]] = -self.eps
    prob = self.softmax(logits)

    # Probability of copying p(z=1) for each token.
    p_copy = self.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
    mul_attn = torch.mul(attn, p_copy.expand_as(attn))
    copy_prob = torch.bmm(mul_attn, src_map)  # `[batch, tlen, extra_words]`
    return torch.cat([out_prob, copy_prob], 2)
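# A minimal, self-contained sketch of the mixing step above, using toy tensors
# and plain torch ops instead of the module's layers. The sizes and random
# inputs below (vocab_size, cvocab, etc.) are assumptions for illustration only.
import torch

def copy_generator_sketch():
    batch, tlen, slen, vocab_size, cvocab = 2, 3, 4, 10, 5
    prob = torch.softmax(torch.randn(batch, tlen, vocab_size), dim=-1)  # p_word(w)
    attn = torch.softmax(torch.randn(batch, tlen, slen), dim=-1)        # copy attention
    src_map = torch.zeros(batch, slen, cvocab)                          # source -> extended-vocab indicator
    src_map[:, torch.arange(slen), torch.randint(0, cvocab, (slen,))] = 1.0
    p_copy = torch.sigmoid(torch.randn(batch, tlen, 1))                 # p(z=1)

    out_prob = prob * (1 - p_copy)                  # generate branch
    copy_prob = torch.bmm(attn * p_copy, src_map)   # copy branch, [batch, tlen, cvocab]
    extended = torch.cat([out_prob, copy_prob], 2)  # [batch, tlen, vocab_size + cvocab]
    # Rows sum to ~1 because the two branches are weighted by (1 - p_copy) and p_copy.
    assert torch.allclose(extended.sum(-1), torch.ones(batch, tlen), atol=1e-5)
    return extended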
def forward(self, tgt, memory_bank, state, memory_lengths=None):
    """
    Args:
        tgt (`LongTensor`): sequences of padded tokens
            `[batch x tgt_len x nfeats]`.
        memory_bank (`FloatTensor`): vectors from the encoder
            `[batch x src_len x hidden]`.
        state (:obj:`onmt.models.DecoderState`):
            decoder state object to initialize the decoder
        memory_lengths (`LongTensor`): the padded source lengths
            `[batch]`.
    Returns:
        (`FloatTensor`, :obj:`onmt.Models.DecoderState`, `FloatTensor`):
            * decoder_outputs: output from the decoder (after attn)
              `[batch x tgt_len x hidden]`.
            * decoder_state: final hidden state from the decoder
            * attns: distribution over src at each tgt
              `[batch x tgt_len x src_len]`.
    """
    # Check
    assert isinstance(state, RNNDecoderState)
    # tgt is batch-first: [batch x tgt_len x nfeats]
    tgt_batch, _, _ = tgt.size()
    if self.attn is not None:
        memory_batch, _, _ = memory_bank.size()
        aeq(tgt_batch, memory_batch)
    # END

    # Run the forward pass of the RNN.
    decoder_final, decoder_outputs, attns = self._run_forward_pass(
        tgt, memory_bank, state, memory_lengths=memory_lengths)

    coverage = None
    if "coverage" in attns:
        coverage = attns["coverage"]

    # Update the state with the result.
    state.update_state(decoder_final, coverage)

    return decoder_outputs, state, attns
def __call__(self, scores, align, target):
    # CHECKS
    batch, tlen, _ = scores.size()
    _, _tlen = target.size()
    aeq(tlen, _tlen)
    _, _tlen = align.size()
    aeq(tlen, _tlen)

    align = align.view(-1)
    target = target.view(-1)
    scores = scores.view(-1, scores.size(2))

    # Compute unks in align and target for readability
    align_unk = align.eq(constants.UNK).float()
    align_not_unk = align.ne(constants.UNK).float()
    target_unk = target.eq(constants.UNK).float()
    target_not_unk = target.ne(constants.UNK).float()

    # Copy probability of tokens in source
    out = scores.gather(1, align.view(-1, 1) + self.offset).view(-1)
    # Set scores for unk to 0 and add eps
    out = out.mul(align_not_unk) + self.eps

    # Get scores for tokens in target
    tmp = scores.gather(1, target.view(-1, 1)).view(-1)

    # Regular prob (no unks and unks that can't be copied)
    if not self.force_copy:
        # Add score for non-unks in target
        out = out + tmp.mul(target_not_unk)
        # Add score for when word is unk in both align and tgt
        out = out + tmp.mul(align_unk).mul(target_unk)
    else:
        # Forced copy. Add only probability for not-copied tokens
        out = out + tmp.mul(align_unk)

    loss = -out.log()
    return loss
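# A toy walk-through of the criterion above (force_copy=False). The constants
# UNK=0, vocab=6, extra=3, and eps are made up for illustration; `scores` is an
# already-mixed [batch*tlen, vocab + extra] distribution as produced by the
# copy generator, and the copy offset is assumed to equal the fixed vocab size.
import torch

def copy_criterion_sketch():
    UNK, vocab, extra, eps = 0, 6, 3, 1e-20
    scores = torch.softmax(torch.randn(4, vocab + extra), dim=-1)
    target = torch.tensor([2, UNK, 3, UNK])  # gold ids in the fixed vocab
    align = torch.tensor([UNK, 1, UNK, 2])   # copy slots in the extended vocab (UNK = not copiable)

    align_unk = align.eq(UNK).float()
    align_not_unk = align.ne(UNK).float()
    target_unk = target.eq(UNK).float()
    target_not_unk = target.ne(UNK).float()

    # Copy probability: look up the extended-vocab slot (shifted past the fixed vocab).
    out = scores.gather(1, align.view(-1, 1) + vocab).view(-1)
    out = out.mul(align_not_unk) + eps
    # Generation probability of the gold token in the fixed vocab.
    tmp = scores.gather(1, target.view(-1, 1)).view(-1)
    # Count generation prob for non-UNK targets, and for tokens UNK in both views.
    out = out + tmp.mul(target_not_unk) + tmp.mul(align_unk).mul(target_unk)
    return -out.log()  # per-token loss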
def score(self, h_t, h_s):
    """
    Args:
        h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
        h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`
    Returns:
        :obj:`FloatTensor`:
            raw attention scores (unnormalized) for each src index
            `[batch x tgt_len x src_len]`
    """
    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.dim, src_dim)

    if self.attn_type in ["general", "dot"]:
        if self.attn_type == "general":
            h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
            h_t_ = self.linear_in(h_t_)
            h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
        h_s_ = h_s.transpose(1, 2)
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s_)
    else:
        dim = self.dim
        wq = self.linear_query(h_t.view(-1, dim))
        wq = wq.view(tgt_batch, tgt_len, 1, dim)
        wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

        uh = self.linear_context(h_s.contiguous().view(-1, dim))
        uh = uh.view(src_batch, 1, src_len, dim)
        uh = uh.expand(src_batch, tgt_len, src_len, dim)

        # (batch, t_len, s_len, d)
        wquh = self.tanh(wq + uh)

        return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
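# Shape-level sketch of the "dot" and "general" scoring branches above, with a
# plain nn.Linear standing in for self.linear_in; the dimensions are arbitrary
# assumptions for illustration.
import torch
import torch.nn as nn

def attention_score_sketch(attn_type="general"):
    batch, tgt_len, src_len, dim = 2, 3, 5, 8
    h_t = torch.randn(batch, tgt_len, dim)  # queries
    h_s = torch.randn(batch, src_len, dim)  # sources
    if attn_type == "general":
        linear_in = nn.Linear(dim, dim, bias=False)
        h_t = linear_in(h_t.view(-1, dim)).view(batch, tgt_len, dim)
    # (batch, tgt_len, dim) x (batch, dim, src_len) -> (batch, tgt_len, src_len)
    return torch.bmm(h_t, h_s.transpose(1, 2))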
def forward(self, source, memory_bank, memory_lengths=None, coverage=None,
            softmax_weights=True):
    """
    Args:
        source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
        memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
        memory_lengths (`LongTensor`): the source context lengths `[batch]`
        coverage (`FloatTensor`): None (not supported yet)
    Returns:
        (`FloatTensor`, `FloatTensor`, `FloatTensor`):
            * Computed vector `[batch x tgt_len x dim]`
            * Attention distributions for each query
              `[batch x tgt_len x src_len]`
            * Coverage vector `[batch x 1 x src_len]` when coverage
              attention is used in one-step mode, else None
    """
    # one step input
    assert source.dim() == 3
    one_step = True if source.size(1) == 1 else False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.data.masked_fill_(~mask, -float('inf'))

    # We adopt coverage attn described in Paulus et al., 2018
    # REF: https://arxiv.org/abs/1705.04304
    if self._coverage:
        maxes = torch.max(align, 2, keepdim=True)[0]
        exp_score = torch.exp(align - maxes)

        if one_step:
            if coverage is None:
                # t = 1 in Eq(3) from Paulus et al., 2018
                unnormalized_score = exp_score
            else:
                # t = otherwise in Eq(3) from Paulus et al., 2018
                assert coverage.dim() == 3  # B x 1 x slen
                unnormalized_score = exp_score.div(coverage + 1e-20)
        else:
            multiplier = torch.tril(torch.ones(target_l - 1, target_l - 1))
            multiplier = multiplier.unsqueeze(0).expand(batch, *multiplier.size())
            multiplier = multiplier.cuda() if align.is_cuda else multiplier
            penalty = torch.bmm(multiplier, exp_score[:, :-1, :])  # B x tlen-1 x slen
            no_penalty = torch.ones_like(penalty[:, -1, :])  # B x slen
            penalty = torch.cat([no_penalty.unsqueeze(1), penalty], dim=1)  # B x tlen x slen
            assert exp_score.size() == penalty.size()
            unnormalized_score = exp_score.div(penalty + 1e-20)

        # Eq.(4) from Paulus et al., 2018
        align_vectors = unnormalized_score.div(
            unnormalized_score.sum(2, keepdim=True))
    else:
        # Softmax to normalize attention weights
        align_vectors = self.softmax(align)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = self.tanh(attn_h)

    # Check output sizes
    batch_, target_l_, dim_ = attn_h.size()
    aeq(target_l, target_l_)
    aeq(batch, batch_)
    aeq(dim, dim_)
    batch_, target_l_, source_l_ = align_vectors.size()
    aeq(target_l, target_l_)
    aeq(batch, batch_)
    aeq(source_l, source_l_)

    coverage_vector = None
    if self._coverage and one_step:
        coverage_vector = exp_score  # B x 1 x slen

    if softmax_weights:
        return attn_h, align_vectors, coverage_vector
    else:
        return attn_h, align, coverage_vector
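# Minimal sketch of the non-coverage path above: mask padded source positions,
# normalize with softmax, take the weighted average of memory_bank, then mix it
# back into the query with a concat + linear. Sizes and the dot-product scorer
# are illustrative assumptions standing in for self.score / self.linear_out.
import torch
import torch.nn as nn

def global_attention_sketch():
    batch, tgt_len, src_len, dim = 2, 3, 5, 8
    source = torch.randn(batch, tgt_len, dim)       # queries
    memory_bank = torch.randn(batch, src_len, dim)  # encoder states
    memory_lengths = torch.tensor([5, 3])           # true source lengths

    align = torch.bmm(source, memory_bank.transpose(1, 2))  # dot scores
    mask = torch.arange(src_len)[None, :] < memory_lengths[:, None]  # [batch, src_len]
    align = align.masked_fill(~mask.unsqueeze(1), -float('inf'))
    align_vectors = torch.softmax(align, dim=-1)    # [batch, tgt_len, src_len]

    c = torch.bmm(align_vectors, memory_bank)       # context vectors
    linear_out = nn.Linear(dim * 2, dim, bias=False)
    attn_h = torch.tanh(linear_out(torch.cat([c, source], 2)))
    return attn_h, align_vectors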
def _check_args(self, src, lengths=None, hidden=None):
    n_batch, _, _ = src.size()
    if lengths is not None:
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)
def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
    """
    Private helper for running the specific RNN forward pass.
    Must be overridden by all subclasses.
    Args:
        tgt (LongTensor): a sequence of input tokens tensors
            [batch x len x nfeats].
        memory_bank (FloatTensor): output (tensor sequence) from the
            encoder RNN of size (batch x src_len x hidden_size).
        state (FloatTensor): hidden state from the encoder RNN for
            initializing the decoder.
        memory_lengths (LongTensor): the source memory_bank lengths.
    Returns:
        decoder_final (Tensor): final hidden state from the decoder.
        decoder_outputs (Tensor): output from the decoder (after attn)
            `[batch x tgt_len x hidden]`.
        attns (Tensor): distribution over src at each tgt
            `[batch x tgt_len x src_len]`.
    """
    # Initialize local and return variables.
    attns = {}

    emb = tgt
    assert emb.dim() == 3
    coverage = state.coverage

    if isinstance(self.rnn, nn.GRU):
        rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
    else:
        rnn_output, decoder_final = self.rnn(emb, state.hidden)

    # Check
    tgt_batch, tgt_len, _ = tgt.size()
    output_batch, output_len, _ = rnn_output.size()
    aeq(tgt_len, output_len)
    aeq(tgt_batch, output_batch)
    # END

    # Calculate the attention.
    if self.attn is not None:
        decoder_outputs, p_attn, coverage_v = self.attn(
            rnn_output.contiguous(),
            memory_bank,
            memory_lengths=memory_lengths,
            coverage=coverage,
            softmax_weights=False
        )
        attns["std"] = p_attn
    else:
        decoder_outputs = rnn_output.contiguous()

    # Update the coverage attention.
    if self._coverage:
        if coverage_v is None:
            coverage = coverage + p_attn \
                if coverage is not None else p_attn
        else:
            coverage = coverage + coverage_v \
                if coverage is not None else coverage_v
        attns["coverage"] = coverage

    decoder_outputs = self.dropout(decoder_outputs)

    # Run the forward pass of the copy attention layer.
    if self._copy and not self._reuse_copy_attn:
        _, copy_attn, _ = self.copy_attn(decoder_outputs,
                                         memory_bank,
                                         memory_lengths=memory_lengths,
                                         softmax_weights=False)
        attns["copy"] = copy_attn
    elif self._copy:
        attns["copy"] = attns["std"]

    return decoder_final, decoder_outputs, attns
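# Sketch of the coverage bookkeeping above for the case where the attention
# layer does not return its own coverage vector (coverage_v is None): the
# running coverage is just the sum of the per-step attention distributions.
# The helper name and the toy step list below are assumptions for illustration.
import torch

def accumulate_coverage(p_attn_steps):
    """p_attn_steps: list of [batch x 1 x src_len] attention distributions."""
    coverage = None
    for p_attn in p_attn_steps:
        coverage = coverage + p_attn if coverage is not None else p_attn
    return coverage  # [batch x 1 x src_len], grows toward the total attention mass

# e.g. two decoding steps over a 4-token source:
steps = [torch.softmax(torch.randn(2, 1, 4), dim=-1) for _ in range(2)]
cov = accumulate_coverage(steps)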