Example #1
    def forward(self, hidden, dec_rnn_output, concat_c, attn, copy_to_ext, copy_to_tgt):
        """
        Compute a distribution over the target dictionary extended by the dynamic dictionary implied by compying source words.
        Args:
           hidden (`FloatTensor`): hidden outputs `[tlen * batch, hidden_size]`
           attn (`FloatTensor`): attn for each `[tlen * batch, src_len]`
           copy_to_ext (`FloatTensor`): A sparse indicator matrix mapping each source word to its index in the "extended" vocab containing. `[src_len, batch]`
           copy_to_tgt (`FloatTensor`): A sparse indicator matrix mapping each source word to its index in the target vocab containing. `[src_len, batch]`
        """
        dec_seq_len = hidden.size(0)
        batch_size = hidden.size(1)
        # -> (targetL_ * batch_, rnn_size)
        hidden = hidden.view(dec_seq_len * batch_size, -1)
        dec_rnn_output = dec_rnn_output.view(dec_seq_len * batch_size, -1)
        concat_c = concat_c.view(dec_seq_len * batch_size, -1)
        # -> (targetL_ * batch_, sourceL_)
        attn = attn.view(dec_seq_len * batch_size, -1)

        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch = copy_to_ext.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        hidden = self.dropout(hidden)

        # Original probabilities.
        logits = self.linear(hidden)
        # logits[:, self.tgt_dict.stoi[table.IO.PAD_WORD]] = -float('inf')
        prob_log = F.log_softmax(logits, dim=1)
        # Probability of copying, p(z=1), for each target position.
        # copy = F.sigmoid(self.linear_copy(hidden))
        if self.copy_prb == 'hidden':
            copy = F.sigmoid(self.linear_copy(dec_rnn_output))
        elif self.copy_prb == 'hidden_context':
            copy = F.sigmoid(self.linear_copy(concat_c))
        else:
            raise NotImplementedError

        def safe_log(v):
            return torch.log(v.clamp(1e-3, 1 - 1e-3))

        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob_log = prob_log + safe_log(copy).expand_as(prob_log)
        mul_attn = torch.mul(attn, 1.0 - copy.expand_as(attn))
        # Map the copy attention onto the extended vocabulary.
        copy_to_ext_onehot = onehot(
            copy_to_ext, N=len(self.ext_dict), ignore_index=self.ext_dict.stoi[table.IO.UNK_WORD]).float()
        ext_copy_prob = torch.bmm(mul_attn.view(-1, batch, slen).transpose(0, 1),
                                  copy_to_ext_onehot.transpose(0, 1)).transpose(0, 1).contiguous().view(-1, len(self.ext_dict))
        ext_copy_prob_log = safe_log(ext_copy_prob) * (1e8)

        return torch.cat([prob_log, ext_copy_prob_log], 1).view(dec_seq_len, batch_size, -1)
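
    # Illustration only (not part of the original module): a minimal shape
    # sketch of the copy-to-extended-vocab step in forward() above. The toy
    # sizes (tlen=2, batch=3, slen=4, ext_vocab=7) and the uniform attention
    # are arbitrary assumptions, chosen just to show the tensor flow.
    @staticmethod
    def _ext_copy_shape_sketch():
        import torch
        tlen, batch, slen, ext_vocab = 2, 3, 4, 7
        # Attention already weighted by the copy gate: (tlen * batch, slen).
        mul_attn = torch.full((tlen * batch, slen), 1.0 / slen)
        # One-hot map from each source position to its extended-vocab index;
        # here every source word maps to index 1: (slen, batch, ext_vocab).
        copy_to_ext_onehot = torch.zeros(slen, batch, ext_vocab)
        copy_to_ext_onehot[:, :, 1] = 1.0
        # Same bmm pattern as above:
        # (batch, tlen, slen) x (batch, slen, ext_vocab) -> (batch, tlen, ext_vocab)
        ext_copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            copy_to_ext_onehot.transpose(0, 1)
        ).transpose(0, 1).contiguous().view(-1, ext_vocab)
        assert ext_copy_prob.size() == (tlen * batch, ext_vocab)
        return ext_copy_prob
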
    def score(self, h_t, h_s):
        """
        h_t (FloatTensor): batch x tgt_len x dim
        h_s (FloatTensor): batch x src_len x dim
        returns scores (FloatTensor): batch x tgt_len x src_len:
            raw attention scores for each src index
        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        # aeq(src_dim, tgt_dim)
        # aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_hidden > 0:
                h_t = self.transform_in(h_t)
                h_s = self.transform_in(h_s)
            if self.attn_type == "general":
                h_t = self.linear_in(h_t)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(
                -1, self.context_size))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = self.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
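
    # Illustration only (not part of the original module): a tiny sketch of
    # the "dot" branch of score() above, using toy sizes
    # (batch=2, tgt_len=3, src_len=4, dim=5) that are arbitrary assumptions.
    @staticmethod
    def _dot_score_sketch():
        import torch
        batch, tgt_len, src_len, dim = 2, 3, 4, 5
        h_t = torch.randn(batch, tgt_len, dim)   # decoder states
        h_s = torch.randn(batch, src_len, dim)   # encoder states
        # (batch, tgt_len, dim) x (batch, dim, src_len) -> (batch, tgt_len, src_len)
        scores = torch.bmm(h_t, h_s.transpose(1, 2))
        assert scores.size() == (batch, tgt_len, src_len)
        return scores
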
    def forward(self, input, context):
        """
        input (FloatTensor): batch x tgt_len x dim: output of the decoder RNN.
        context (FloatTensor): batch x src_len x dim: src hidden states
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = context.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        if self.mask is not None:
            beam_, batch_, sourceL_ = self.mask.size()
            aeq(batch, batch_ * beam_)
            aeq(sourceL, sourceL_)

        # compute attention scores, as in Luong et al.
        align = self.score(input, context)

        if self.mask is not None:
            mask_ = self.mask.view(batch, 1, sourceL).type(
                torch.bool)  # make it broadcastable
            align.data.masked_fill_(mask_, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch * targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, context)

        # concatenate
        concat_c = torch.cat([c, input], 2)
        if self.linear_out is None:
            attn_h = concat_c
        else:
            attn_h = self.linear_out(concat_c)
            if self.attn_type in ["general", "dot"]:
                attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            # aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            # aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
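
    # Illustration only (not part of the original module): a minimal
    # end-to-end sketch of the attention read in forward() above
    # (scores -> softmax -> weighted context -> concat). The toy sizes are
    # arbitrary assumptions, not values from this code.
    @staticmethod
    def _attention_read_sketch():
        import torch
        import torch.nn.functional as F
        batch, tgt_len, src_len, dim = 2, 3, 4, 5
        input = torch.randn(batch, tgt_len, dim)     # decoder states
        context = torch.randn(batch, src_len, dim)   # encoder states
        # Dot-product scores, then softmax over the source dimension.
        align = torch.bmm(input, context.transpose(1, 2))
        align_vectors = F.softmax(
            align.view(batch * tgt_len, src_len), dim=1
        ).view(batch, tgt_len, src_len)
        # Weighted average of source states, then concatenate with the input.
        c = torch.bmm(align_vectors, context)        # (batch, tgt_len, dim)
        concat_c = torch.cat([c, input], 2)          # (batch, tgt_len, 2 * dim)
        assert concat_c.size() == (batch, tgt_len, 2 * dim)
        return concat_c, align_vectors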