Example #1
    def encode_document(self, docs, docs_length):
        """
        Encode the documents into vectors.
        Parameters
        --------------------
            docs           -- 4d tensor (batch_size,session_length,num_candidates,max_doc_length)
            docs_length    -- 3d tensor (batch_size,session_length,num_candidates)
        Returns
        --------------------
            pooled_docs    -- 4d tensor (batch_size,session_length,num_candidates,nhid_document)
        """
        batch_size = docs.size(0)
        session_len = docs.size(1)
        num_candidates = docs.size(2)
        max_doc_len = docs.size(3)

        source_rep = docs.view(batch_size * session_len * num_candidates, -1).contiguous()
        source_len = docs_length.view(-1).contiguous()

        # (batch_size * session_len * num_candidates) x max_src_len x emsize
        source_word_rep = self.embedder(source_rep)
        # (batch_size * session_len * num_candidates) x max_src_len x nhid_document
        hidden, encoded_docs = self.document_encoder(source_word_rep, source_len)
        encoded_docs = self.dropout(encoded_docs)

        document_mask = sequence_mask(source_len, max_len=max_doc_len)
        # pooled_docs: (batch_size * session_len * num_candidates) x nhid_document
        pooled_docs = self.apply_pooling(encoded_docs,
                                         self.pool_type,
                                         dtype='document',
                                         mask=document_mask)
        # pooled_docs: batch_size x session_length x num_candidates x nhid_document
        pooled_docs = pooled_docs.view(batch_size, session_len, num_candidates, -1).contiguous()
        return pooled_docs
Example #2
    def encode(self, queries, query_length):
        """
        Encode the queries into matrices/vectors.
        Parameters
        --------------------
            queries         -- 3d tensor (batch_size,session_length,max_query_length)
            query_length    -- 2d tensor (batch_size,session_length)
        Returns
        --------------------
            pooled_queries  -- 3d tensor (batch_size,session_length,nhid_query * num_directions)
            encoded_queries -- 3d tensor (batch_size * session_length,max_query_length,nhid_query * num_directions)
            hidden          -- final hidden state(s) returned by the query encoder
        """
        batch_size = queries.size(0)
        session_len = queries.size(1)
        max_query_len = queries.size(2)

        source_rep = queries.view(batch_size * session_len, -1).contiguous()
        source_len = query_length.view(-1).contiguous()

        # (batch_size * session_len) x max_src_len x emsize
        source_word_rep = self.embedder(source_rep)
        # (batch_size * session_len) x max_src_len x nhid_query
        hidden, encoded_queries = self.query_encoder(source_word_rep, source_len)
        encoded_queries = self.dropout(encoded_queries)

        query_mask = sequence_mask(source_len, max_len=max_query_len)
        # pooled_queries: (batch_size * session_length) x nhid_query
        pooled_queries = self.apply_pooling(encoded_queries, self.pool_type, dtype='query', mask=query_mask)
        # batch_size x session_len x nhid_query
        pooled_queries = pooled_queries.view(batch_size, session_len, -1).contiguous()

        return pooled_queries, encoded_queries, hidden
Example #3
    def encode_clicks(self, docs, doc_labels):
        """
        Encode all the clicked documents of queries into vectors.
        Parameters
        --------------------
            docs            -- 4d tensor (batch_size,session_length,num_candidates,nhid_document)
            doc_labels      -- 3d tensor (batch_size,session_length,num_candidates)
        Returns
        --------------------
            encoded_clicks  -- 3d tensor (batch_size,session_length,nhid_document)
        """
        batch_size = docs.size(0)
        session_len = docs.size(1)
        num_candidates = docs.size(2)
        use_cuda = docs.is_cuda

        sorted_index = doc_labels.sort(2, descending=True)[1]
        sorted_docs = [torch.index_select(docs[i, j], 0, sorted_index[i, j])
                       for i in range(batch_size)
                       for j in range(session_len)]
        # batch_size*session_len x num_candidates x nhid_document
        sorted_docs = torch.stack(sorted_docs, 0)

        click_length = numpy.count_nonzero(doc_labels.view(batch_size * session_len,
                                                           -1).cpu().numpy(), axis=1)
        click_length = torch.from_numpy(click_length)
        # click_len_mask: (batch_size * session_len) x max_n_click
        click_len_mask = sequence_mask(click_length)

        click_mask = torch.ones(*sorted_docs.size()[:-1], dtype=torch.bool)
        click_mask[:, :click_len_mask.size(1)] = click_len_mask
        if use_cuda:
            click_mask = click_mask.cuda()

        att_weights = self.click_attn(sorted_docs.view(-1, sorted_docs.size(2))).squeeze(1)
        att_weights = att_weights.view(*sorted_docs.size()[:-1])
        att_weights.masked_fill_(~click_mask, -float('inf'))
        att_weights = f.softmax(att_weights, 1)
        encoded_clicks = torch.bmm(sorted_docs.transpose(1, 2), att_weights.unsqueeze(2)).squeeze(2)

        # encoded_clicks: batch_size x session_length x (nhid_doc * num_directions)
        encoded_clicks = encoded_clicks.contiguous().view(batch_size, session_len, -1)
        return encoded_clicks
Example #4
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """
        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
        Returns:
          (`FloatTensor`, `FloatTensor`, `FloatTensor` or None):
          * Computed vector `[batch x tgt_len x dim]`
          * Attention distributions for each query
             `[batch x tgt_len x src_len]`
          * Coverage vector `[batch x 1 x src_len]` (only when coverage is
             enabled and decoding one step at a time; otherwise None)
        """

        # one step input
        assert source.dim() == 3
        one_step = source.size(1) == 1

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(~mask, -float('inf'))

        # We adopt coverage attn described in Paulus et al., 2018
        # REF: https://arxiv.org/abs/1705.04304
        if self._coverage:
            maxes = torch.max(align, 2, keepdim=True)[0]
            exp_score = torch.exp(align - maxes)

            if one_step:
                if coverage is None:
                    # t = 1 in Eq(3) from Paulus et al., 2018
                    unnormalized_score = exp_score
                else:
                    # t = otherwise in Eq(3) from Paulus et al., 2018
                    assert coverage.dim() == 3  # B x 1 x slen
                    unnormalized_score = exp_score.div(coverage + 1e-20)
            else:
                multiplier = torch.tril(torch.ones(target_l - 1, target_l - 1))
                multiplier = multiplier.unsqueeze(0).expand(
                    batch, *multiplier.size())
                multiplier = multiplier.to(align.device)

                penalty = torch.bmm(multiplier,
                                    exp_score[:, :-1, :])  # B x tlen-1 x slen
                no_penalty = torch.ones_like(penalty[:, -1, :])  # B x slen
                penalty = torch.cat([no_penalty.unsqueeze(1), penalty],
                                    dim=1)  # B x tlen x slen
                assert exp_score.size() == penalty.size()
                unnormalized_score = exp_score.div(penalty + 1e-20)

            # Eq.(4) from Paulus et al., 2018
            align_vectors = unnormalized_score.div(
                unnormalized_score.sum(2, keepdim=True))

        # Softmax to normalize attention weights
        else:
            align_vectors = self.softmax(align.view(batch * target_l,
                                                    source_l))
            align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        # Check output sizes
        batch_, target_l_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, target_l_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

        coverage_vector = None
        if self._coverage and one_step:
            coverage_vector = exp_score  # B x 1 x slen

        return attn_h, align_vectors, coverage_vector