Esempio n. 1
0
    def get_next_decoder_input(
        self,
        prev_input: torch.LongTensor,
        selection: torch.LongTensor,
        incr_state_inds: torch.LongTensor,
    ) -> torch.LongTensor:
        """
        Append the latest beam selection to the decoder input for RAG Turn.

        RAG Turn marginalizes over all the documents within turns: during generation
        the decoder input has length [n_docs*n_turns*bsz], while beam search yields
        selections of length [n_turns*bsz]. Each selected token is therefore repeated
        n_docs times before it is appended, so marginalization over documents can
        continue on the next step.

        :param prev_input:
            previous [bsz*n_docs*n_turns, seqlen] input to the decoder.
        :param selection:
            [bsz*n_turns]-length beam selection ids for the decoder; repeated
            n_docs times along dim 1.
        :param incr_state_inds:
            beam indices that are continuing in next generation; selected along
            dim 0 after ``_unstack_ctxt``.

        :return dec_input:
            the decoder input with the repeated selections appended on the last dim,
            re-stacked with ``_stack_ctxt``.
        """
        n_docs = self.n_docs
        # Unstack per n_docs, then keep only the beams that continue.
        unstacked = _unstack_ctxt(prev_input, n_docs)
        kept = unstacked.index_select(0, incr_state_inds)  # type: ignore
        # n_docs copies of each selected token, as a trailing length-1 dimension.
        appended_tokens = selection.repeat_interleave(n_docs, 1).unsqueeze(-1)
        extended = torch.cat([kept, appended_tokens], dim=-1)
        return _stack_ctxt(extended)  # type: ignore
    def corrupt_batch(
            self, positive_batch: torch.LongTensor
    ) -> torch.LongTensor:  # noqa: D102
        """
        Corrupt either the head or the tail of every (repeated) positive triple.

        The side is chosen per triple: the head is corrupted when a uniform draw is
        below ``self.corrupt_head_probability[relation]``, otherwise the tail is.
        The replacement entity is never equal to the original one.

        :param positive_batch: triples indexed as (head, relation, tail) in columns 0-2.
        :return: negatives reshaped to (-1, num_negs_per_pos, 3).
        """
        per_pos = self.num_negs_per_pos
        if per_pos > 1:
            positive_batch = positive_batch.repeat_interleave(
                repeats=per_pos, dim=0)

        total = positive_batch.shape[0]
        device = positive_batch.device

        # Work on a copy; integer indices carry no gradients anyway.
        negative_batch = positive_batch.clone()

        # Relation-specific probability of corrupting the head of each triple.
        p_head = self.corrupt_head_probability[positive_batch[:, 1]].to(
            device=device)
        corrupt_head = torch.rand(total, device=device) < p_head
        corrupt_tail = ~corrupt_head

        # Draw from [0, num_entities - 2]; the shift below maps the draws onto every
        # entity id except the original one.
        draws = torch.randint(
            self.num_entities - 1,
            size=(total, ),
            device=device,
        )

        for column, selected in ((0, corrupt_head), (2, corrupt_tail)):
            replacement = draws[selected]
            # Skip over the original value by shifting draws >= it up by one.
            replacement = replacement + (
                replacement >= positive_batch[selected, column]).long()
            negative_batch[selected, column] = replacement

        return negative_batch.view(-1, per_pos, 3)
Esempio n. 3
0
    def retrieve_and_concat(
        self,
        input: torch.LongTensor,
        input_lengths: torch.LongTensor,
        query_vec: torch.LongTensor,
        input_turns_cnt: torch.LongTensor,
    ) -> Tuple[torch.LongTensor, List[List[Document]], torch.Tensor]:
        """
        Retrieve documents for every query and concatenate them with the input.

        When ``input_turns_cnt`` is given, each input row (and its length) is
        repeated once per dialogue turn before concatenation.

        :param input:
            2D [bsz, seqlen] input to the encoder
        :param input_lengths:
            1D [bsz] lengths of each input item
        :param query_vec:
            2D [bsz*n_turns, seqlen] input for the retriever
        :param input_turns_cnt:
            1D [bsz] number of dialogue turns for each input example, or None

        :return (expanded_input, top_docs, top_doc_scores):
            expanded_input: [bsz * n_docs, seqlen+doc_len] tensor of context/document inputs
            top_docs: List of top documents for each input
            top_doc_scores: document scores for each document
        """
        top_docs, top_doc_scores = self.retriever.retrieve(query_vec)

        if input_turns_cnt is None:
            context, context_lengths = input, input_lengths
        else:
            # One copy of each context per dialogue turn.
            context = input.repeat_interleave(
                input_turns_cnt, dim=0)  # type: ignore
            context_lengths = input_lengths.repeat_interleave(
                input_turns_cnt, dim=0)  # type: ignore

        n_docs = top_doc_scores.size(1)
        expanded_input = self.concat_docs_and_input(
            context, context_lengths, top_docs, n_docs)
        return expanded_input, top_docs, top_doc_scores
Esempio n. 4
0
    def encoder(
        self,
        input: torch.LongTensor,
        input_lengths: torch.LongTensor,
        query_vec: torch.LongTensor,
        input_turns_cnt: torch.LongTensor,
        memory_vec: torch.LongTensor,
        num_memories: torch.LongTensor,
        query_generator_vec: torch.LongTensor,
        gold_doc_vec: torch.LongTensor,
        gold_doc_title_vec: torch.LongTensor,
        num_gold_docs: torch.LongTensor,
        memory_decoder_vec: torch.LongTensor,
        num_memory_decoder_vecs: torch.LongTensor,
        positions: Optional[torch.LongTensor] = None,
        segments: Optional[torch.LongTensor] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.BoolTensor,
        Optional[torch.LongTensor],
        Optional[List[List[Document]]],
        Optional[torch.Tensor],
    ]:
        """
        Run the parent encoder and concatenate its outputs per example.

        Every argument is forwarded positionally, in the same order, to
        ``super().encoder``; see the parent class for their meaning.

        :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
            encoder_out / encoder_mask: results of ``concat_enc_outs`` on the parent's
            encoder output and mask
            input_turns_cnt, top_docs, top_doc_scores: passed through from the parent.
        """
        parent_out = super().encoder(  # type: ignore
            input,
            input_lengths,
            query_vec,
            input_turns_cnt,
            memory_vec,
            num_memories,
            query_generator_vec,
            gold_doc_vec,
            gold_doc_title_vec,
            num_gold_docs,
            memory_decoder_vec,
            num_memory_decoder_vecs,
            positions,
            segments,
        )  # type: ignore
        enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = parent_out

        if input_turns_cnt is None:
            context = input
        else:
            # input_turns_cnt is [bsz]: repeat each context once per dialogue turn.
            context = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore

        new_out, new_mask = concat_enc_outs(
            context, enc_out, mask, self.embedding_size, self.pad_idx
        )
        return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
Esempio n. 5
0
    def encoder(
        self,
        input: torch.LongTensor,
        input_lengths: torch.LongTensor,
        query_vec: torch.LongTensor,
        input_turns_cnt: torch.LongTensor,
        positions: Optional[torch.LongTensor] = None,
        segments: Optional[torch.LongTensor] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.BoolTensor,
        Optional[torch.LongTensor],
        Optional[List[List[Document]]],
        Optional[torch.Tensor],
    ]:
        """
        Encode with the parent encoder, then concatenate all encoder outputs.

        :param input:
            2D [bsz, seqlen] input to the encoder
        :param input_lengths:
            1D [bsz] lengths of each input item
        :param query_vec:
            2D [bsz*n_turns, seqlen] input for the retriever
        :param input_turns_cnt:
            1D [bsz] number of dialogue turns for each input example
        :param positions, segments:
            optional, forwarded unchanged to the parent encoder

        :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
            encoder_out: *concatenated* encoded representations of context/document pairs
            encoder_mask: new mask for enc_out
            input_turns_cnt: pass along the input turns count for the decoder
            top_docs: List of top Documents for each batch example
            top_doc_scores: scores for each retrieved document.
        """
        parent_out = super().encoder(
            input, input_lengths, query_vec, input_turns_cnt, positions, segments
        )  # type: ignore
        enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = parent_out

        if input_turns_cnt is None:
            context = input
        else:
            # input_turns_cnt is [bsz]: repeat each context once per dialogue turn.
            context = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore

        new_out, new_mask = concat_enc_outs(
            context, enc_out, mask, self.embedding_size, self.pad_idx
        )
        return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
Esempio n. 6
0
    def corrupt_batch(
            self, positive_batch: torch.LongTensor
    ) -> torch.LongTensor:  # noqa: D102
        """
        Corrupt the (repeated) positive triples in equal contiguous chunks.

        After repeating every triple ``num_negs_per_pos`` times, the batch is split
        into consecutive chunks, one per entry of ``self._corruption_indices``. In
        each chunk, that column (0 = head, 1 = relation, 2 = tail) is replaced by a
        random id that differs from the original value.

        :param positive_batch: triples indexed as (head, relation, tail) in columns 0-2.
        :return: negatives reshaped to (-1, num_negs_per_pos, 3). An empty input
            yields an empty (0, num_negs_per_pos, 3) tensor.
        """
        if self.num_negs_per_pos > 1:
            positive_batch = positive_batch.repeat_interleave(
                repeats=self.num_negs_per_pos, dim=0)

        # Bind number of negatives to sample
        num_negs = positive_batch.shape[0]

        # Equally corrupt all sides. Clamp to 1 so an empty batch does not give
        # range() a zero step, which would raise ValueError.
        split_idx = max(
            1, int(math.ceil(num_negs / len(self._corruption_indices))))

        # Copy positive batch for corruption.
        # Do not detach, as no gradients should flow into the indices.
        negative_batch = positive_batch.clone()

        for index, start in zip(self._corruption_indices,
                                range(0, num_negs, split_idx)):
            stop = min(start + split_idx, num_negs)

            # Relations have a different index maximum than entities
            # At least make sure to not replace the triples by the original value
            index_max = (self.num_relations
                         if index == 1 else self.num_entities) - 1

            negative_batch[start:stop, index] = torch.randint(
                high=index_max,
                size=(stop - start, ),
                device=positive_batch.device,
            )

            # To make sure we don't replace the {head, relation, tail} by the
            # original value we shift all values greater or equal than the original value by one up
            # for that reason we choose the random value from [0, num_{heads, relations, tails} -1]
            negative_batch[start:stop,
                           index] += (negative_batch[start:stop, index] >=
                                      positive_batch[start:stop,
                                                     index]).long()

        return negative_batch.view(-1, self.num_negs_per_pos, 3)
Esempio n. 7
0
    def get_initial_forced_decoder_input(
        self,
        bsz: int,
        inputs: torch.LongTensor,
        n_docs: int,
        start_idx: int,
        end_idx: int,
        input_turns_cnt: Optional[torch.LongTensor] = None,
    ) -> torch.LongTensor:
        """
        Build the initial (teacher-forced) decoder input used during training.

        Targets are first repeated once per dialogue turn (when ``input_turns_cnt``
        is given), turned into forced decoder inputs by ``get_forced_decoder_inputs``,
        and finally each row is repeated ``n_docs`` times with copies kept adjacent.

        :param bsz:
            batchsize; replaced by the total number of turns when
            ``input_turns_cnt`` is given.
        :param inputs:
            inputs to decode
        :param n_docs:
            number of docs per input
        :param start_idx:
            start token idx
        :param end_idx:
            end token idx
        :param input_turns_cnt:
            an optional tensor containing the number of turns of each corresponding context.

        :return initial_input:
            initial input for the decoder.
        """
        if input_turns_cnt is not None:
            # One row per dialogue turn; the effective batch is the total turn count.
            inputs = inputs.repeat_interleave(input_turns_cnt,
                                              dim=0)  # type: ignore
            bsz = input_turns_cnt.sum()  # type: ignore
        forced = get_forced_decoder_inputs(
            inputs, bsz, start_idx, end_idx, self.generation_model
        )
        # Equivalent to repeat(1, n_docs).reshape(-1, width) for a 2D tensor:
        # n_docs consecutive copies of every row.
        return forced.repeat_interleave(n_docs, dim=0)  # type: ignore
Esempio n. 8
0
    def retrieve_and_concat(
        self,
        input: torch.LongTensor,
        input_lengths: torch.LongTensor,
        query_generator_vec: torch.LongTensor,
        query_vec: torch.LongTensor,
        input_turns_cnt: torch.LongTensor,
        memory_vec: torch.LongTensor,
        num_memories: torch.LongTensor,
        gold_doc_vec: torch.LongTensor,
        gold_doc_title_vec: torch.LongTensor,
        num_gold_docs: torch.LongTensor,
        memory_decoder_vec: torch.LongTensor,
        num_memory_decoder_vecs: torch.LongTensor,
        skip_search: torch.BoolTensor,
    ) -> Tuple[torch.LongTensor, List[List[Document]], torch.Tensor]:
        """
        Override RagModel.retrieve_and_concat to perform different retrieval, depending
        on the RetrieverType.

        Every row gets a retrieval type (search, long-term memory, or none). Each
        route fills its rows of ``top_docs`` / ``doc_scores``, and the documents are
        then concatenated with the input.

        :param input:
            2D [bsz, seqlen] input to the encoder
        :param input_lengths:
            1D [bsz] lengths of each input item
        :param query_generator_vec:
            optional input for ``self.query_generator.classify_retrieval``
        :param query_vec:
            input for the retrievers; one row per element of ``n_input``
        :param input_turns_cnt:
            optional number of dialogue turns per example; per-example tensors are
            repeated this many times along dim 0
        :param memory_vec, num_memories:
            optional long-term memory inputs and their counts
        :param gold_doc_vec, gold_doc_title_vec, num_gold_docs:
            optional gold documents, forwarded to ``_fill_docs_and_scores`` for search
        :param memory_decoder_vec, num_memory_decoder_vecs:
            optional inputs for ``self.memory_decoder.generate_memories``
        :param skip_search:
            forwarded to ``self.query_generator.classify_retrieval``

        :return (expanded_input, top_docs, top_doc_scores):
            expanded_input: input concatenated with its retrieved documents
            top_docs: retrieved documents for each row
            top_doc_scores: stacked scores of the retrieved documents
        """
        self.flush_previous_retriever_search_results()
        start = time.time()
        logging.debug(f'Begin encoder: {time.time() - start:.2f}')

        def _per_turn(tensor):
            # Repeat a per-example tensor once per dialogue turn; None passes through.
            if tensor is None:
                return None
            return tensor.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore

        if input_turns_cnt is not None:
            query_generator_vec = _per_turn(query_generator_vec)
            memory_vec = _per_turn(memory_vec)
            num_memories = _per_turn(num_memories)
            memory_decoder_vec = _per_turn(memory_decoder_vec)
            num_memory_decoder_vecs = _per_turn(num_memory_decoder_vecs)
        # Number of rows to retrieve for: total turns, or the plain batch size.
        n_input = (input_turns_cnt.sum().item()
                   if input_turns_cnt is not None else input.size(0))

        # 0a. Classify retrieval type, if necessary
        generated_memories = [[] for _ in range(int(n_input))]
        if memory_decoder_vec is not None:
            generated_memories = self.memory_decoder.generate_memories(
                memory_decoder_vec, num_memory_decoder_vecs)
        if self.should_generate_query:
            assert self.has_query_generator()
            retrieval_type, search_queries = self.query_generator.classify_retrieval(
                query_generator_vec, num_memories, generated_memories,
                skip_search)
            logging.debug(f'Classify Retrieval: {time.time() - start:.2f}')
        else:
            # Size by n_input (one entry per row of top_docs / doc_scores), not by
            # input.size(0): with input_turns_cnt the unexpanded batch is smaller,
            # and the remaining turns would never receive documents.
            # NOTE(review): the tensor is uninitialized, so presumably
            # get_retrieval_indices only relies on its length here — confirm.
            retrieval_type = torch.LongTensor(int(n_input))
            search_queries = None

        # 1. Retrieve
        top_docs: List[List[Document]] = [[] for _ in range(int(n_input))]
        doc_scores: List[List[torch.Tensor]] = [[]
                                                for _ in range(int(n_input))]

        # 1a. retrieve from faiss or search
        search_indices = self.get_retrieval_indices(retrieval_type,
                                                    RetrievalType.SEARCH)
        if search_indices.numel() > 0:
            search_docs, search_doc_scores = self.perform_search(
                search_queries, query_vec, search_indices)
            logging.debug(f'Search Complete: {time.time() - start:.2f}')
            logging.debug(f'search: {search_docs}')
            if gold_doc_vec is not None:
                logging.debug(f'num gold docs: {num_gold_docs}')
            self._fill_docs_and_scores(
                top_docs,
                doc_scores,
                search_indices,
                search_docs,
                search_doc_scores,
                gold_doc_vec,
                gold_doc_title_vec,
                num_gold_docs,
            )

        # 1b. memory search
        memory_indices = self.get_retrieval_indices(retrieval_type,
                                                    RetrievalType.MEMORY)
        if memory_indices.numel() > 0:
            memories, memory_scores = self.access_long_term_memory(
                query_vec,
                memory_indices,
                memory_vec,
                num_memories,
                memory_decoder_vec,
                generated_memories,
            )
            logging.debug(f'Memory Access Complete: {time.time() - start:.2f}')
            if memories is not None and memory_scores is not None:
                self._fill_docs_and_scores(top_docs, doc_scores,
                                           memory_indices, memories,
                                           memory_scores)

        # 1c. no search
        no_search_indices = self.get_retrieval_indices(retrieval_type,
                                                       RetrievalType.NONE)
        if no_search_indices.numel() > 0:
            dummy_docs, dummy_scores = self.dummy_retriever.retrieve(
                query_vec[no_search_indices]  # type: ignore
            )
            logging.debug('no search')
            self._fill_docs_and_scores(top_docs, doc_scores, no_search_indices,
                                       dummy_docs, dummy_scores)

        # 2. Expand the input
        if input_turns_cnt is not None:
            input = _per_turn(input)
            input_lengths = _per_turn(input_lengths)

        # Filtering empty doc_scores added due to dynamic batching (if used)
        doc_scores = [[s for s in ds if s is not None] for ds in doc_scores
                      if ds]
        top_doc_scores = torch.stack(
            [torch.cat([s_i for s_i in scores_i]) for scores_i in doc_scores])
        expanded_input = self.concat_docs_and_input(input, input_lengths,
                                                    top_docs,
                                                    top_doc_scores.size(1))
        return expanded_input, top_docs, top_doc_scores
Esempio n. 9
0
    def encoder(
        self,
        input: torch.LongTensor,
        input_lengths: torch.LongTensor,
        query_vec: torch.LongTensor,
        input_turns_cnt: torch.LongTensor,
        target_lengths: Optional[torch.LongTensor],
        positions: Optional[torch.LongTensor] = None,
        segments: Optional[torch.LongTensor] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.BoolTensor,
        Optional[torch.LongTensor],
        Optional[List[List[Document]]],
        Optional[torch.Tensor],
    ]:
        """
        Override FidModel.encoder to pack all the documents into one input example.

        :param input:
            2D [bsz, seqlen] input to the encoder
        :param input_lengths:
            1D [bsz] lengths of each input item
        :param query_vec:
            2D [bsz*n_turns, seqlen] input for the retriever
        :param input_turns_cnt:
            1D [bsz] number of dialogue turns for each input example
        :param target_lengths:
            optional 1D [bsz] lengths of each target item (for each input item); when
            given, room is reserved for the target when truncating docs and input.

        :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
            encoder_out: *concatenated* encoded representations of context/document pairs
            encoder_mask: new mask for enc_out
            input_turns_cnt: pass along the input turns count for the decoder
            top_docs: List of top Documents for each batch example
            top_doc_scores: scores for each retrieved document.
        """
        enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = RagModel.encoder(
            self, input, input_lengths, query_vec, input_turns_cnt, positions, segments
        )  # type: ignore
        # enc_out is indexed as [rows, seq_len] below and compared to pad_idx
        # (it is unsqueezed to a width-1 "embedding" at the end).
        seq_len, n_docs = enc_out.size(1), enc_out.size(0) // input.size(0)

        if input_turns_cnt is not None:
            # Input Turns is a tensor of dim [bsz]
            input = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore
        doc_starts = (enc_out == self.pad_idx).sum(dim=1)  # left padded
        doc_lens = (seq_len - input_lengths.repeat_interleave(n_docs)) - doc_starts
        # if no padding, docs are assumed to be min doc length long
        doc_lens[doc_lens.le(0)] = self.min_doc_len
        new_enc_out = enc_out.clone()
        # BEFORE:
        #  [
        #   pad...doc_0 / in_0
        #   pad...doc_1 / in_0
        #   ...
        #   pad...doc_n / in_m
        #                       ]
        total_length_i = 0
        for i, doc_len in enumerate(doc_lens):
            example = i // n_docs
            if i % n_docs == 0:
                total_length_i = 0
            # max doc length is determined by how much space we have after subtracting input length
            input_and_target = input_lengths[example] + (
                target_lengths[example] if target_lengths is not None else 0
            )
            max_doc_len = torch.div(
                (seq_len - input_and_target), n_docs, rounding_mode='floor'
            )
            max_doc_len = max(max_doc_len, self.min_doc_len)  # type: ignore
            doc_len = min(doc_len, max_doc_len)
            total_length_i += doc_len
            if i % n_docs == n_docs - 1:
                # keep the actual input context when processing the last doc,
                # reserving room for this example's target when one is given.
                max_input_len = self.truncate - total_length_i
                if target_lengths is not None:
                    # Use this example's target length only; subtracting the whole
                    # [bsz] tensor would make pad_end a vector and break the
                    # slicing below whenever bsz > 1.
                    max_input_len = max_input_len - target_lengths[example]
                clamped_input_length = input_lengths[example].clamp(
                    max=max_input_len
                )
                pad_end = seq_len - clamped_input_length - doc_len
            else:
                pad_end = seq_len - doc_len
            # we simply move the portion of the doc we want to keep to the end of the tensor,
            # and mask out everything else.
            mask[i, :pad_end] = False
            new_enc_out[i, pad_end : pad_end + doc_len] = enc_out[
                i, doc_starts[i] : doc_starts[i] + doc_len
            ]

        new_out, new_mask = concat_enc_outs(
            input, new_enc_out.unsqueeze(-1), mask, 1, self.pad_idx, right_padded=False
        )
        # After:
        #  [
        #   doc_0 / doc_1 / doc_2 ... in_0
        #   doc_0 / doc_1 / doc_2 ... in_m
        #                                   ]
        return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
Esempio n. 10
0
    def compute_loss(
        self,
        criterion: torch.nn.Module,
        scores: torch.Tensor,
        preds: torch.LongTensor,
        enc_state: Tuple[Any, ...],
        label_vec: torch.LongTensor,
    ) -> Tuple[torch.Tensor, List[int], torch.Tensor, torch.Tensor]:
        """
        Compute Loss for Rag Turn.

        RAG Turn Doc-Then-Turn computes loss with a normal NLL Loss
        (everything is marginalized beforehand)

        RAG Turn Doc-Only computes loss for each input turn; this loss can be
        weighted with a discount factor, applying less weight to prior turns (only for
        backpropagation purposes). Metrics use the undiscounted sum over turns.

        :param criterion:
            torch criterion module; its per-token output is reshaped to
            scores.shape[:-1], so it must not reduce.
        :param scores:
            model scores
        :param preds:
            model "predictions" of tokens
        :param enc_state:
            encoder states; element 2 is the input turns count (or None)
        :param label_vec:
            target tokens

        :return (loss, metric_loss, correct_tokens, target_tokens):
            loss: the loss through which we backpropagate
            metric_loss: loss we use for metrics
            correct_tokens: correct predictions from the model
            target_tokens: the ground truth tokens.
        """
        if scores.size(1) != label_vec.size(1):
            # ignore start
            scores = scores[:, 1:, :]
            preds = preds[:, 1:]  # type: ignore

        input_turns_cnt = enc_state[2]
        real_bsz = label_vec.size(0)
        # Scores carry one row per turn when they outnumber the labels.
        resize_label = real_bsz != scores.size(0)
        if resize_label:
            assert self.turn_marginalize == 'doc_only'
            label_vec = label_vec.repeat_interleave(input_turns_cnt,
                                                    dim=0)  # type: ignore

        # compute loss
        score_view = scores.reshape(-1, scores.size(-1))
        loss = criterion(score_view, label_vec.view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)
        metric_loss = loss.tolist()

        if resize_label:
            assert self.turn_marginalize == 'doc_only'
            per_turn_loss = loss
            # sum_across_turns groups rows by input_turns_cnt (as for
            # target_tokens / correct below), so both the discounted training loss
            # and the undiscounted metric loss must start from the per-turn losses,
            # not from the already-aggregated per-example loss.
            loss = sum_across_turns(per_turn_loss,
                                    input_turns_cnt,
                                    discount=self.discount_factor)
            metric_loss = sum_across_turns(per_turn_loss,
                                           input_turns_cnt).tolist()

        # compute metric counters
        notnull = label_vec.ne(self.null_idx)
        target_tokens = metric_target_tokens = notnull.long().sum(dim=-1)
        correct = metric_correct = ((label_vec == preds) * notnull).sum(dim=-1)
        if resize_label:
            metric_target_tokens = sum_across_turns(target_tokens,
                                                    input_turns_cnt)
            metric_correct = sum_across_turns(correct, input_turns_cnt)

        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token
        return loss, metric_loss, metric_correct, metric_target_tokens