Example #1
    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.
        Args:
           hidden (`FloatTensor`): hidden outputs `[batch, tlen, input_size]`
           attn (`FloatTensor`): attention scores for each target step `[batch, tlen, slen]`
           src_map (`FloatTensor`):
             A sparse indicator matrix mapping each source word to
             its index in the "extended" vocabulary.
             `[batch, src_len, extra_words]`
        """
        # CHECKS
        batch, tlen, _ = hidden.size()
        batch_, tlen_, slen = attn.size()
        batch, slen_, cvocab = src_map.size()
        aeq(tlen, tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, :, constants.PAD] = -self.eps
        prob = self.softmax(logits)

        # Probability of copying p(z=1) batch.
        p_copy = self.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
        mul_attn = torch.mul(attn, p_copy.expand_as(attn))
        copy_prob = torch.bmm(mul_attn,
                              src_map)  # `[batch, tlen, extra_words]`
        return torch.cat([out_prob, copy_prob], 2)
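The key step above is the batched matrix product that projects copy attention over source positions onto the extended vocabulary. A minimal shape sketch, with made-up sizes and a random one-hot src_map (not the library's actual data pipeline):

import torch

batch, tlen, slen, cvocab = 2, 3, 5, 4
attn = torch.softmax(torch.randn(batch, tlen, slen), dim=-1)   # copy attention
src_map = torch.zeros(batch, slen, cvocab)
# one-hot: each source position points at some extended-vocab index
src_map.scatter_(2, torch.randint(cvocab, (batch, slen, 1)), 1.0)

copy_prob = torch.bmm(attn, src_map)   # [batch, tlen, cvocab]
print(copy_prob.shape)                 # torch.Size([2, 3, 4])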
Example #2
    def _check_args(self, src, lengths=None, hidden=None):
        n_batch, _, _ = src.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
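The aeq helper used throughout these snippets comes from OpenNMT-py's utilities and simply asserts that all of its arguments are equal. A minimal stand-in, if you want to run the checks without the library:

def aeq(*args):
    """Assert that all arguments are equal (stand-in for onmt.utils.misc.aeq)."""
    assert all(a == args[0] for a in args[1:]), \
        "Not all arguments have the same value: %s" % str(args)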
Example #3
    def forward(self, tgt, memory_bank, state, memory_lengths=None):
        """
        Args:
            tgt (`LongTensor`): sequences of padded tokens
                 `[batch x tgt_len x nfeats]`.
            memory_bank (`FloatTensor`): vectors from the encoder
                 `[batch x src_len x hidden]`.
            state (:obj:`onmt.models.DecoderState`):
                 decoder state object to initialize the decoder
            memory_lengths (`LongTensor`): the padded source lengths
                `[batch]`.
        Returns:
            (`FloatTensor`, :obj:`onmt.models.DecoderState`, `FloatTensor`):
                * decoder_outputs: output from the decoder (after attn)
                         `[batch x tgt_len x hidden]`.
                * decoder_state: final hidden state from the decoder
                * attns: distribution over src at each tgt
                        `[batch x tgt_len x src_len]`.
        """
        # Check
        assert isinstance(state, RNNDecoderState)
        # tgt is batch-first: [batch, tgt_len, nfeats]
        tgt_batch, _, _ = tgt.size()
        if self.attn is not None:
            memory_batch, _, _ = memory_bank.size()
            aeq(tgt_batch, memory_batch)
        # END

        # Run the forward pass of the RNN.
        decoder_final, decoder_outputs, attns = self._run_forward_pass(
            tgt, memory_bank, state, memory_lengths=memory_lengths)

        coverage = None
        if "coverage" in attns:
            coverage = attns["coverage"]
        # Update the state with the result.
        state.update_state(decoder_final, coverage)

        return decoder_outputs, state, attns
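For orientation, a hypothetical sketch of how a caller might consume the returned decoder outputs; the generator below is illustrative only (a linear projection to vocabulary log-probabilities) and is not defined in the snippet above:

import torch
import torch.nn as nn

hidden_size, vocab_size = 256, 10000
generator = nn.Sequential(nn.Linear(hidden_size, vocab_size),
                          nn.LogSoftmax(dim=-1))

decoder_outputs = torch.randn(2, 7, hidden_size)   # [batch, tgt_len, hidden]
log_probs = generator(decoder_outputs)             # [batch, tgt_len, vocab]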
Example #4
    def __call__(self, scores, align, target):
        # CHECKS
        batch, tlen, _ = scores.size()
        _, _tlen = target.size()
        aeq(tlen, _tlen)
        _, _tlen = align.size()
        aeq(tlen, _tlen)

        align = align.view(-1)
        target = target.view(-1)
        scores = scores.view(-1, scores.size(2))

        # Compute unks in align and target for readability
        align_unk = align.eq(constants.UNK).float()
        align_not_unk = align.ne(constants.UNK).float()
        target_unk = target.eq(constants.UNK).float()
        target_not_unk = target.ne(constants.UNK).float()

        # Copy probability of tokens in source
        out = scores.gather(1, align.view(-1, 1) + self.offset).view(-1)
        # Set scores for unk to 0 and add eps
        out = out.mul(align_not_unk) + self.eps
        # Get scores for tokens in target
        tmp = scores.gather(1, target.view(-1, 1)).view(-1)

        # Regular prob (no unks and unks that can't be copied)
        if not self.force_copy:
            # Add score for non-unks in target
            out = out + tmp.mul(target_not_unk)
            # Add score for when word is unk in both align and tgt
            out = out + tmp.mul(align_unk).mul(target_unk)
        else:
            # Forced copy. Add only probability for not-copied tokens
            out = out + tmp.mul(align_unk)

        loss = -out.log()
        return loss
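The copy scores live after the base vocabulary in the flattened score matrix, which is what the offset gather retrieves. A tiny worked sketch with hypothetical sizes, where offset plays the role of self.offset (the base vocabulary size):

import torch

vocab, extra = 6, 3                      # base vocab plus extended copy slots
scores = torch.rand(4, vocab + extra)    # flattened to [batch * tlen, vocab + extra]
align = torch.tensor([2, 0, 1, 0])       # copy indices into the extended part (0 = UNK)
offset = vocab                           # copy scores start right after the base vocab

copy_scores = scores.gather(1, align.view(-1, 1) + offset).view(-1)
print(copy_scores.shape)                 # torch.Size([4])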
Example #5
    def score(self, h_t, h_s):
        """
        Args:
          h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
          h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`
        Returns:
          :obj:`FloatTensor`:
           raw attention scores (unnormalized) for each src index
          `[batch x tgt_len x src_len]`
        """
        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = self.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
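The "dot" branch reduces to a single batched matrix product between queries and transposed source states. A shape-only sketch with made-up sizes:

import torch

batch, tgt_len, src_len, dim = 2, 3, 5, 8
h_t = torch.randn(batch, tgt_len, dim)   # queries
h_s = torch.randn(batch, src_len, dim)   # source states

scores = torch.bmm(h_t, h_s.transpose(1, 2))   # [batch, tgt_len, src_len]
print(scores.shape)                            # torch.Size([2, 3, 5])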
Example #6
    def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
        """
        Private helper for running the specific RNN forward pass.
        Must be overridden by all subclasses.
        Args:
            tgt (LongTensor): a sequence of input token tensors
                                 [batch x len x nfeats].
            memory_bank (FloatTensor): output (tensor sequence) from the encoder
                        RNN of size (batch x src_len x hidden_size).
            state (FloatTensor): hidden state from the encoder RNN for
                                 initializing the decoder.
            memory_lengths (LongTensor): the source memory_bank lengths.
        Returns:
            decoder_final (Tensor): final hidden state from the decoder.
            decoder_outputs (Tensor): output from the decoder (after attn)
                         `[batch x tgt_len x hidden]`.
            attns (Tensor): distribution over src at each tgt
                        `[batch x tgt_len x src_len]`.
        """
        # Initialize local and return variables.
        attns = {}

        emb = tgt
        assert emb.dim() == 3

        coverage = state.coverage

        if isinstance(self.rnn, nn.GRU):
            rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
        else:
            rnn_output, decoder_final = self.rnn(emb, state.hidden)

        # Check
        tgt_batch, tgt_len, _ = tgt.size()
        output_batch, output_len, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)
        # END

        # Calculate the attention.
        if self.attn is not None:
            decoder_outputs, p_attn, coverage_v = self.attn(
                rnn_output.contiguous(),
                memory_bank,
                memory_lengths=memory_lengths,
                coverage=coverage
            )
            attns["std"] = p_attn
        else:
            decoder_outputs = rnn_output.contiguous()

        # Update the coverage attention.
        if self._coverage:
            if coverage_v is None:
                coverage = coverage + p_attn \
                    if coverage is not None else p_attn
            else:
                coverage = coverage + coverage_v \
                    if coverage is not None else coverage_v
            attns["coverage"] = coverage

        decoder_outputs = self.dropout(decoder_outputs)
        # Run the forward pass of the copy attention layer.
        if self._copy and not self._reuse_copy_attn:
            _, copy_attn, _ = self.copy_attn(decoder_outputs,
                                             memory_bank,
                                             memory_lengths=memory_lengths)
            attns["copy"] = copy_attn
        elif self._copy:
            attns["copy"] = attns["std"]

        return decoder_final, decoder_outputs, attns
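The GRU/LSTM branch at the top of the method exists because nn.GRU takes a single hidden tensor while nn.LSTM expects an (h, c) tuple, hence state.hidden[0] versus state.hidden. A standalone sketch of that difference, with hypothetical sizes:

import torch
import torch.nn as nn

batch, tgt_len, emb_dim, hidden_size = 2, 4, 8, 8
emb = torch.randn(batch, tgt_len, emb_dim)
h0 = torch.zeros(1, batch, hidden_size)

gru = nn.GRU(emb_dim, hidden_size, batch_first=True)
out_gru, h_n = gru(emb, h0)                           # hidden state is a single tensor

lstm = nn.LSTM(emb_dim, hidden_size, batch_first=True)
out_lstm, (h_n, c_n) = lstm(emb, (h0, h0.clone()))    # hidden state is an (h, c) tuple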
Example #7
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """
        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): coverage vector `[batch x 1 x src_len]`, or None
        Returns:
          (`FloatTensor`, `FloatTensor`):
          * Computed vector `[batch x tgt_len x dim]`
          * Attention distributions for each query
             `[batch x tgt_len x src_len]`
        """

        # one step input
        assert source.dim() == 3
        one_step = source.size(1) == 1

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(~mask, -float('inf'))

        # We adopt coverage attn described in Paulus et al., 2018
        # REF: https://arxiv.org/abs/1705.04304
        if self._coverage:
            maxes = torch.max(align, 2, keepdim=True)[0]
            exp_score = torch.exp(align - maxes)

            if one_step:
                if coverage is None:
                    # t = 1 in Eq(3) from Paulus et al., 2018
                    unnormalized_score = exp_score
                else:
                    # t = otherwise in Eq(3) from Paulus et al., 2018
                    assert coverage.dim() == 3  # B x 1 x slen
                    unnormalized_score = exp_score.div(coverage + 1e-20)
            else:
                multiplier = torch.tril(torch.ones(target_l - 1, target_l - 1))
                multiplier = multiplier.unsqueeze(0).expand(
                    batch, *multiplier.size())
                multiplier = torch.autograd.Variable(multiplier)
                multiplier = multiplier.cuda() if align.is_cuda else multiplier

                penalty = torch.bmm(multiplier,
                                    exp_score[:, :-1, :])  # B x tlen-1 x slen
                no_penalty = torch.ones_like(penalty[:, -1, :])  # B x slen
                penalty = torch.cat([no_penalty.unsqueeze(1), penalty],
                                    dim=1)  # B x tlen x slen
                assert exp_score.size() == penalty.size()
                unnormalized_score = exp_score.div(penalty + 1e-20)

            # Eq.(4) from Paulus et al., 2018
            align_vectors = unnormalized_score.div(
                unnormalized_score.sum(2, keepdim=True))

        # Softmax to normalize attention weights
        else:
            align_vectors = self.softmax(align.view(batch * target_l,
                                                    source_l))
            align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        # Check output sizes
        batch_, target_l_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, target_l_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

        coverage_vector = None
        if self._coverage and one_step:
            coverage_vector = exp_score  # B x 1 x slen

        return attn_h, align_vectors, coverage_vector
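The length masking before normalization is what keeps padded source positions out of the attention distribution. A minimal stand-in for sequence_mask plus the masked fill, with made-up sizes:

import torch

batch, tgt_len, src_len = 2, 3, 5
align = torch.randn(batch, tgt_len, src_len)
memory_lengths = torch.tensor([5, 3])    # true source lengths per batch entry

mask = torch.arange(src_len)[None, :] < memory_lengths[:, None]   # [batch, src_len]
align = align.masked_fill(~mask.unsqueeze(1), float('-inf'))
align_vectors = torch.softmax(align, dim=-1)   # padded positions get zero weight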