Code example #1
    def set_input_turn_cnt_vec(
        self, observation: Message, model: RagModel, query_str: str
    ) -> Message:
        """
        Compute the number of turns of input, and set the vec accordingly.

        :param observation:
            observation in which to set the vec
        :param model:
            model provided for access to retriever tokenizer
        :param query_str:
            the query string for computation of the input turns.

        :return observation:
            return the observation with the input turn vec set appropriately.
        """
        delimiter = model.get_retriever_delimiter()
        split_text_raw = query_str.split(delimiter)
        split_text: List[str] = []
        if self.n_turns > 1 and len(split_text_raw) > self.n_turns:
            end_off = self.n_turns - 1
            split_text = [
                delimiter.join(split_text_raw[:-end_off])
            ] + split_text_raw[-end_off:]
        else:
            split_text = split_text_raw

        input_turns_cnt = torch.LongTensor([len(split_text)])
        query_vecs = [model.tokenize_query(q) for q in split_text]
        # Override query vec
        observation.force_set('query_vec', query_vecs)
        observation['input_turn_cnt_vec'] = input_turns_cnt
        return observation
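
A minimal usage sketch (hypothetical agent, model, and obs objects; assumes the retriever delimiter is '\n' and self.n_turns == 2):

    # Three raw turns with n_turns == 2: the older turns are re-joined, so
    # the query splits into ['hello\nhow are you', 'fine thanks'].
    obs = agent.set_input_turn_cnt_vec(obs, model, 'hello\nhow are you\nfine thanks')
    assert obs['input_turn_cnt_vec'].item() == 2
    assert len(obs['query_vec']) == 2  # one tokenized query per turn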
Code example #2
    def build_regret_model(self) -> RagModel:
        """
        Build and return regret RagModel.

        Assume dictionary is the same.
        """
        model_file = self.opt['regret_model_file']
        if model_file:
            assert os.path.exists(
                model_file
            ), 'specify correct path for --regret-model-file'
            regret_opt = Opt.load(f'{model_file}.opt')
            regret_opt['n_docs'] = self.opt['n_docs']  # Urgent that this is the same
            # add keys that were not in this model when originally trained
            regret_opt.update(
                {k: v for k, v in self.opt.items() if k not in regret_opt}
            )
            retriever_shared = None
            if all(
                regret_opt[k] == self.opt[k]
                for k in [
                    'rag_retriever_type',
                    'path_to_index',
                    'path_to_dpr_passages',
                ]
            ):
                logging.warning('Sharing retrievers between model and regret model!')
                retriever_shared = self.model.encoder.retriever.share()

            model = RagModel(regret_opt, self.dict, retriever_shared=retriever_shared)
            with PathManager.open(self.opt['regret_model_file'], 'rb') as f:
                states = torch.load(
                    f,
                    map_location=lambda cpu, _: cpu,
                    pickle_module=parlai.utils.pickle,
                )
            assert 'model' in states
            model.load_state_dict(states['model'])
            if self.model_parallel:
                ph = PipelineHelper()
                ph.check_compatibility(self.opt)
                model = ph.make_parallel(model)
            else:
                model.cuda()
            if self.fp16:
                model = model.half()

            sync_parameters(model)
            train_params = trainable_parameters(model)
            total_params = total_parameters(model)
            logging.info(
                f"Total regret parameters: {total_params:,d} ({train_params:,d} trainable)"
            )
        else:
            model = self.model

        return model
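
A hedged usage sketch (the checkpoint path is a placeholder; the surrounding agent attributes are assumed from the snippet):

    # During agent setup: load a second RagModel for the regret
    # (retrieve-generate-retrieve) pass; if no file is given, the
    # agent's own model is reused.
    self.opt['regret_model_file'] = '/path/to/regret_model'  # placeholder
    self.regret_model = self.build_regret_model()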
Code example #3
    def augment_batch_for_generation(self, batch: Batch, model: RagModel) -> Batch:
        """
        Augment batch for generation.

        For RAG Sequence, we retrieve prior to generation, as we do not consider the
        document probabilities until after generating all of the beams.

        :param batch:
            batch to augment
        :param model:
            model to possibly help with augmenting

        :return batch:
            return batch with text vec swapped out.
        """
        (expanded_input, _, doc_scores) = model.retrieve_and_concat(
            batch.text_vec,
            batch.text_vec.ne(self.null_idx).sum(1),
            batch.query_vec,
            batch.input_turn_cnt_vec,
        )
        doc_log_probs = F.log_softmax(doc_scores, dim=1)
        batch.src_text_vec = batch.text_vec
        batch.text_vec = expanded_input
        batch.doc_log_probs = doc_log_probs
        batch.batchsize = batch.text_vec.size(0)

        return batch
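
A usage sketch (illustrative names; self is assumed to hold null_idx as in the snippet):

    # Expand each context into one row per retrieved document before beam search.
    batch = self.augment_batch_for_generation(batch, model)
    # batch.text_vec now has original_bsz * n_docs rows, and batch.doc_log_probs
    # holds log p(doc | context) for re-weighting completed beams afterwards.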
Code example #4
    @classmethod
    def thorough_generation(
        cls,
        hyps: List[torch.LongTensor],
        new_input: torch.LongTensor,
        null_idx: int,
        model: RagModel,
    ) -> List[Tuple[torch.LongTensor, torch.Tensor]]:
        """
        Apply RAG-sequence thorough generation for a single batch item.

        Recomputes model scores with given hypotheses, sorts accordingly.

        :param hyps:
            list of candidate hypotheses
        :param new_input:
            input for the model

        :return sorted_hyps:
            return list of (hyp, score) tuples, sorted by their score.
        """
        # deduplicate, exclude BOS Token
        hyps = list({str(h.tolist()): h[1:] for h in hyps}.values())  # type: ignore
        new_input = new_input.repeat(len(hyps), 1)  # type: ignore
        new_ys, _ = padded_tensor(
            hyps, fp16friendly=new_input.size(1) % FP16_PAD_SIZE == 0, pad_idx=null_idx
        )
        new_ys = new_ys.to(new_input.device)
        scores, *_ = model.seq2seq_forward_pass(new_input, new_ys)
        loss = cls._rag_sequence_loss(
            new_ys.unsqueeze(1).unsqueeze(-1), scores.unsqueeze(1), null_idx
        )  # type: ignore
        sorted_by_score = [
            (hyps[idx], loss[idx]) for idx in loss.sort().indices
        ]  # sorted by ascending loss, so the best-scoring hypothesis comes first
        return sorted_by_score
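
A usage sketch (assumes this is a classmethod on the RAG-Sequence model type; hyps are the beams for one batch item and new_input is that item's document-expanded input):

    sorted_hyps = RagSequence.thorough_generation(hyps, new_input, null_idx, model)
    best_hyp, best_loss = sorted_hyps[0]  # lowest loss, i.e. best-scoring hypothesis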
Code example #5
def build_rag_model(opt: Opt, dictionary: DictionaryAgent) -> RagModel:
    return RagModel(opt, dictionary)
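
A usage sketch (assumes opt and dictionary come from an already-initialized agent; the factory simply forwards them to the RagModel constructor):

    model = build_rag_model(self.opt, self.dict)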
Code example #6
    def encoder(
        self,
        input: torch.LongTensor,
        input_lengths: torch.LongTensor,
        query_vec: torch.LongTensor,
        input_turns_cnt: torch.LongTensor,
        target_lengths: Optional[torch.LongTensor],
        positions: Optional[torch.LongTensor] = None,
        segments: Optional[torch.LongTensor] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.BoolTensor,
        Optional[torch.LongTensor],
        Optional[List[List[Document]]],
        Optional[torch.Tensor],
    ]:
        """
        Override FidModel.encoder to pack all the documents into one input example.

        :param input:
            2D [bsz, seqlen] input to the encoder
        :param input_lengths:
            1D [bsz] lengths of each input item
        :param query_vec:
            2D [bsz*n_turns, seqlen] input for the retriever
        :param input_turns_cnt:
            1D [bsz] number of dialogue turns for each input example
        :param target_lengths:
            1D [bsz] lengths of each target item (for each input item)

        :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
            encoder_out: *concatenated* encoded representations of context/document pairs
            encoder_mask: new mask for enc_out
            input_turns_cnt: pass along the input turns count for the decoder
            top_docs: List of top Documents for each batch example
            top_doc_scores: scores for each retrieved document.
        """
        enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = RagModel.encoder(
            self, input, input_lengths, query_vec, input_turns_cnt, positions, segments
        )  # type: ignore
        seq_len, n_docs = enc_out.size(1), enc_out.size(0) // input.size(0)

        if input_turns_cnt is not None:
            # Input Turns is a tensor of dim [bsz]
            input = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore
        doc_starts = (enc_out == self.pad_idx).sum(dim=1)  # left padded
        doc_lens = (seq_len - input_lengths.repeat_interleave(n_docs)) - doc_starts
        # if no padding, docs are assumed to be min doc length long
        doc_lens[doc_lens.le(0)] = self.min_doc_len
        new_enc_out = enc_out.clone()
        # BEFORE:
        #  [
        #   pad...doc_0 / in_0
        #   pad...doc_1 / in_0
        #   ...
        #   pad...doc_n / in_m
        #                       ]
        total_length_i = 0
        for i, doc_len in enumerate(doc_lens):
            if i % n_docs == 0:
                total_length_i = 0
            # max doc length is determined by how much space we have after subtracting input length
            input_and_target = input_lengths[i // n_docs] + (
                target_lengths[i // n_docs] if target_lengths is not None else 0
            )
            max_doc_len = torch.div(
                (seq_len - input_and_target), n_docs, rounding_mode='floor'
            )
            max_doc_len = max(max_doc_len, self.min_doc_len)  # type: ignore
            doc_len = min(doc_len, max_doc_len)
            total_length_i += doc_len
            if i % n_docs == n_docs - 1:
                # keep the actual input context when processing the last doc.
                clamped_input_length = input_lengths[i // n_docs].clamp(
                    max=self.truncate - total_length_i
                )
                if target_lengths is not None:
                    clamped_input_length = input_lengths[i // n_docs].clamp(
                        max=self.truncate - total_length_i - target_lengths[i // n_docs]
                    )

                pad_end = seq_len - clamped_input_length - doc_len
            else:
                pad_end = seq_len - doc_len
            # we simply move the portion of the doc we want to keep to the end of the tensor,
            # and mask out everything else.
            mask[i, :pad_end] = False
            new_enc_out[i, pad_end : pad_end + doc_len] = enc_out[
                i, doc_starts[i] : doc_starts[i] + doc_len
            ]

        new_out, new_mask = concat_enc_outs(
            input, new_enc_out.unsqueeze(-1), mask, 1, self.pad_idx, right_padded=False
        )
        # AFTER:
        #  [
        #   doc_0 / doc_1 / doc_2 ... in_0
        #   doc_0 / doc_1 / doc_2 ... in_m
        #                                   ]
        return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
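
A worked example of the packing arithmetic above, under assumed sizes:

    # seq_len = 128, input_len = 32, target_len = 16, n_docs = 3
    # max_doc_len = floor((128 - (32 + 16)) / 3) = floor(80 / 3) = 26
    # Each document row keeps at most its first 26 document positions, moved
    # to the right edge of the row; on the last row of each example the
    # (clamped) input context is kept too, and concat_enc_outs then packs the
    # three document slices plus the context into one left-padded row per example.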