Example #1
    def test_concat_enc_outs(self):
        enc_output = torch.rand(self.bsz * self.n_docs, self.seqlen, self.esz)
        enc_input, mask = self._create_input_and_mask()
        # Right padded
        mask = mask.repeat_interleave(self.n_docs, dim=0)
        _, new_mask = concat_enc_outs(
            enc_input, enc_output, mask, self.esz, self.pad_idx
        )
        ########################################################################
        # Assertion: new mask has `True` elements in first (n_docs * seqlen_i) #
        # tokens in concatenated output                                        #
        ########################################################################
        assert all(
            new_mask[i, : self.batch_lens[i] * self.n_docs].sum()
            == self.n_docs * self.batch_lens[i]
            for i in range(self.bsz)
        )
        # Left padded
        enc_input, mask = self._create_input_and_mask(right_padded=False)
        mask = mask.repeat_interleave(self.n_docs, dim=0)
        _, new_mask = concat_enc_outs(
            enc_input, enc_output, mask, self.esz, self.pad_idx, right_padded=False
        )
        #######################################################################
        # Assertion: new mask has `True` elements in last (n_docs * seqlen_i) #
        # tokens in concatenated output                                       #
        #######################################################################
        assert all(
            new_mask[i, -(self.batch_lens[i] * self.n_docs):].sum()
            == self.n_docs * self.batch_lens[i]
            for i in range(self.bsz)
        )
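
For reference, here is a simplified sketch of the contract these assertions exercise: flatten each example's n_docs encoder outputs, keep only the unmasked positions, and re-pad on the right (or left). This is an illustration only, not ParlAI's concat_enc_outs, whose actual implementation may differ in details:

import torch
from typing import List, Tuple

def concat_enc_outs_sketch(
    enc_input: torch.LongTensor,  # [bsz, seqlen]; used only for bsz
    enc_out: torch.Tensor,        # [bsz * n_docs, seqlen, esz]
    mask: torch.BoolTensor,       # [bsz * n_docs, seqlen]
    embedding_size: int,
    right_padded: bool = True,
) -> Tuple[torch.Tensor, torch.BoolTensor]:
    bsz = enc_input.size(0)
    n_docs = enc_out.size(0) // bsz
    outs: List[torch.Tensor] = []
    lengths: List[int] = []
    for i in range(bsz):
        # gather the unmasked positions of all n_docs rows for example i
        rows = slice(i * n_docs, (i + 1) * n_docs)
        flat_mask = mask[rows].reshape(-1)
        flat_out = enc_out[rows].reshape(-1, embedding_size)[flat_mask]
        outs.append(flat_out)
        lengths.append(flat_out.size(0))
    max_len = max(lengths)
    new_out = enc_out.new_zeros(bsz, max_len, embedding_size)
    new_mask = mask.new_zeros(bsz, max_len)
    for i, (out_i, len_i) in enumerate(zip(outs, lengths)):
        if right_padded:
            new_out[i, :len_i] = out_i
            new_mask[i, :len_i] = True
        else:
            new_out[i, -len_i:] = out_i
            new_mask[i, -len_i:] = True
    return new_out, new_mask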
Example #2
    def encoder(
        self,
        input: torch.LongTensor,
        input_lengths: torch.LongTensor,
        query_vec: torch.LongTensor,
        input_turns_cnt: torch.LongTensor,
        memory_vec: torch.LongTensor,
        num_memories: torch.LongTensor,
        query_generator_vec: torch.LongTensor,
        gold_doc_vec: torch.LongTensor,
        gold_doc_title_vec: torch.LongTensor,
        num_gold_docs: torch.LongTensor,
        memory_decoder_vec: torch.LongTensor,
        num_memory_decoder_vecs: torch.LongTensor,
        positions: Optional[torch.LongTensor] = None,
        segments: Optional[torch.LongTensor] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.BoolTensor,
        Optional[torch.LongTensor],
        Optional[List[List[Document]]],
        Optional[torch.Tensor],
    ]:
        enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = super().encoder(  # type: ignore
            input,
            input_lengths,
            query_vec,
            input_turns_cnt,
            memory_vec,
            num_memories,
            query_generator_vec,
            gold_doc_vec,
            gold_doc_title_vec,
            num_gold_docs,
            memory_decoder_vec,
            num_memory_decoder_vecs,
            positions,
            segments,
        )

        if input_turns_cnt is not None:
            # input_turns_cnt is a 1D [bsz] tensor; expand input to one row per dialogue turn
            input = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore

        new_out, new_mask = concat_enc_outs(
            input, enc_out, mask, self.embedding_size, self.pad_idx
        )

        return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
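
The repeat_interleave step above keeps input aligned with enc_out once multi-turn examples are flattened: enc_out carries one row per (example, turn) pair, so each input row must be duplicated input_turns_cnt[i] times before concat_enc_outs pairs them up. A toy illustration (all values made up):

import torch

input_turns_cnt = torch.tensor([2, 3])       # example 0 has 2 turns, example 1 has 3
input = torch.tensor([[10, 11], [20, 21]])   # [bsz, seqlen] toy token rows
expanded = input.repeat_interleave(input_turns_cnt, dim=0)
# expanded is [5, 2]: row 0 appears twice, then row 1 three times,
# matching the 2 + 3 = 5 flattened rows of enc_out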
Example #3
    def encoder(
        self,
        input: torch.LongTensor,
        input_lengths: torch.LongTensor,
        query_vec: torch.LongTensor,
        input_turns_cnt: torch.LongTensor,
        target_lengths: Optional[torch.LongTensor],
        positions: Optional[torch.LongTensor] = None,
        segments: Optional[torch.LongTensor] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.BoolTensor,
        Optional[torch.LongTensor],
        Optional[List[List[Document]]],
        Optional[torch.Tensor],
    ]:
        """
        Override FidModel.encoder to pack all the documents into one input example.

        :param input:
            2D [bsz, seqlen] input to the encoder
        :param input_lengths:
            1D [bsz] lengths of each input item
        :param query_vec:
            2D [bsz*n_turns, seqlen] input for the retriever
        :param input_turns_cnt:
            1D [bsz] number of dialogue turns for each input example
        :param target_lengths:
            1D [bsz] lengths of each target item (one per input item)

        :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
            encoder_out: *concatenated* encoded representations of context/document pairs
            encoder_mask: new mask for enc_out
            input_turns_cnt: pass along the input turns count for the decoder
            top_docs: List of top Documents for each batch example
            top_doc_scores: scores for each retrieved document.
        """
        enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = RagModel.encoder(
            self, input, input_lengths, query_vec, input_turns_cnt, positions, segments
        )  # type: ignore
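        # enc_out here is 2D ([bsz * n_docs, seq_len]): it is compared against
        # self.pad_idx below and only unsqueezed to an embedding size of 1
        # when handed to concat_enc_outs, so n_docs can be recovered from the
        # row count.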
        seq_len, n_docs = enc_out.size(1), enc_out.size(0) // input.size(0)

        if input_turns_cnt is not None:
            # input_turns_cnt is a 1D [bsz] tensor; expand input to one row per dialogue turn
            input = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore
        # rows are left padded, so the pad count per row is the index at which
        # its document begins
        doc_starts = (enc_out == self.pad_idx).sum(dim=1)
        # doc length = seq_len minus the leading pads and the trailing input context
        doc_lens = (seq_len - input_lengths.repeat_interleave(n_docs)) - doc_starts
        # rows with no padding left over get the minimum doc length as a fallback
        doc_lens[doc_lens.le(0)] = self.min_doc_len
        new_enc_out = enc_out.clone()
        # BEFORE:
        #  [
        #   pad...doc_0 / in_0
        #   pad...doc_1 / in_0
        #   ...
        #   pad...doc_n / in_m
        #                       ]
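        # total_length_i tracks how many (clamped) doc tokens have been kept for
        # the current batch example; it resets every n_docs rows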
        total_length_i = 0
        for i, doc_len in enumerate(doc_lens):
            if i % n_docs == 0:
                total_length_i = 0
            # the max doc length is whatever space remains after subtracting the input (and target) length
            input_and_target = input_lengths[i // n_docs] + (
                target_lengths[i // n_docs] if target_lengths is not None else 0
            )
            max_doc_len = torch.div(
                (seq_len - input_and_target), n_docs, rounding_mode='floor'
            )
            max_doc_len = max(max_doc_len, self.min_doc_len)  # type: ignore
            doc_len = min(doc_len, max_doc_len)
            total_length_i += doc_len
            if i % n_docs == n_docs - 1:
                # keep the actual input context when processing the last doc.
                clamped_input_length = input_lengths[i // n_docs].clamp(
                    max=self.truncate - total_length_i
                )
                if target_lengths is not None:
                    clamped_input_length = input_lengths[i // n_docs].clamp(
                        max=self.truncate - total_length_i - target_lengths[i // n_docs]
                    )

                pad_end = seq_len - clamped_input_length - doc_len
            else:
                pad_end = seq_len - doc_len
            # we simply move the portion of the doc we want to keep to the end of the tensor,
            # and mask out everything else.
            mask[i, :pad_end] = False
            new_enc_out[i, pad_end : pad_end + doc_len] = enc_out[
                i, doc_starts[i] : doc_starts[i] + doc_len
            ]

        new_out, new_mask = concat_enc_outs(
            input, new_enc_out.unsqueeze(-1), mask, 1, self.pad_idx, right_padded=False
        )
        # AFTER:
        #  [
        #   doc_0 / doc_1 / doc_2 ... in_0
        #   doc_0 / doc_1 / doc_2 ... in_m
        #                                   ]
        return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
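
To make the packing budget concrete, here is a small worked example with hypothetical sizes (none of these numbers come from the code above):

# hypothetical sizes for one batch example
seq_len, n_docs = 512, 5
input_len, target_len = 112, 0

# the loop computes this with torch.div(..., rounding_mode='floor')
max_doc_len = (seq_len - (input_len + target_len)) // n_docs  # (512 - 112) // 5 = 80

# pad_end per row: non-final rows keep only their doc; the final row also
# keeps the real input context after its doc
pad_end_middle = seq_len - max_doc_len               # 432
pad_end_last = seq_len - input_len - max_doc_len     # 320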