Example #1
def __init__(self, opt, dictionary, retriever_shared=None):
    self.add_start_token = opt["add_start_token"]
    # Signal FidModel.__init__ not to build the retriever; we build it below.
    opt['converting'] = True
    FidModel.__init__(self, opt, dictionary, retriever_shared)
    opt['converting'] = False
    self.config = self.seq2seq_decoder.transformer.config
    self.embedding_size = self.config.n_embd
    # LM head projects decoder hidden states to vocabulary logits; its weight
    # is tied to the decoder's token-embedding matrix below.
    self.lm_head = torch.nn.Linear(
        self.config.n_embd, self.config.vocab_size, bias=False
    )
    self._tie_weights(self.lm_head, self.seq2seq_decoder.transformer.wte)
    self.doc_delim = self.dict.txt2vec(Document.PASSAGE_DELIM)[0]
    self.min_doc_len = opt['min_doc_token_length']
    self.truncate = (
        opt['text_truncate'] if opt['text_truncate'] > -1 else opt['truncate']
    )
    if opt.get('filter_docs_with_label'):
        # Filtering documents against the label is only supported with the
        # search-engine retriever.
        assert (
            RetrieverType(opt['rag_retriever_type']) == RetrieverType.SEARCH_ENGINE
        )
        self.retriever = FilterDocsForLabelSearchEngineRetrieverBase(
            opt, dictionary, shared=retriever_shared
        )  # type: ignore
    else:
        self.retriever = retriever_factory(opt, dictionary, shared=retriever_shared)
Example #2
def combo_fid_retriever_factory(opt: Opt,
                                dictionary: DictionaryAgent,
                                shared=None) -> Optional[RagRetriever]:
    """
    Bypass call to standard retriever factory to possibly build our own retriever.
    """
    if opt.get('converting'):
        return None
    retriever = RetrieverType(opt['rag_retriever_type'])
    if retriever is RetrieverType.SEARCH_ENGINE:
        return ComboFidSearchQuerySearchEngineRetriever(
            opt, dictionary, shared=shared)  # type: ignore
    else:
        return retriever_factory(opt, dictionary, shared)
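The factory dispatches on a RetrieverType enum and short-circuits in conversion mode. A self-contained sketch of that dispatch pattern, with stand-in names (the real factory returns retriever objects, not strings):

from enum import Enum
from typing import Optional

class RetrieverType(Enum):
    SEARCH_ENGINE = 'search_engine'
    DPR = 'dpr'

def make_retriever(opt: dict) -> Optional[str]:
    if opt.get('converting'):
        return None  # conversion mode: no retriever is built
    if RetrieverType(opt['rag_retriever_type']) is RetrieverType.SEARCH_ENGINE:
        return 'ComboFidSearchQuerySearchEngineRetriever'
    return 'standard retriever from retriever_factory'

assert make_retriever({'converting': True}) is None
assert make_retriever({'rag_retriever_type': 'search_engine'}).startswith('Combo')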
Example #3
    def __init__(self, opt, dictionary, retriever_shared=None):
        from parlai.agents.rag.rag import RAG_MODELS

        self.pad_idx = dictionary[dictionary.null_token]
        self.start_idx = dictionary[dictionary.start_token]
        self.end_idx = dictionary[dictionary.end_token]
        super().__init__(self.pad_idx, self.start_idx, self.end_idx)
        self.fp16 = (not opt['no_cuda'] and torch.cuda.is_available()
                     and opt.get('fp16', False))
        self.dict = dictionary
        self.embeddings = create_embeddings(dictionary, opt['embedding_size'],
                                            self.pad_idx)
        # attrs
        self.rag_model_type = opt['rag_model_type']
        self._rag_model_interface = RAG_MODELS[self.rag_model_type](
            opt, self.pad_idx)
        self.generation_model = opt['generation_model']
        self.n_extra_positions = opt['n_extra_positions']
        self.n_positions = get_n_positions_from_options(
            opt) + opt['n_extra_positions']
        assert opt['n_extra_positions'] >= 0
        # expanded_input_truncate is overloaded: when n_extra_positions == 0 it
        # is the truncation length of the full expanded input; when > 0 it is
        # the maximum length of the knowledge tokens.
        self.expanded_input_truncate = min(
            opt['text_truncate'] or opt['truncate'],
            get_n_positions_from_options(opt))
        if self.n_extra_positions > 0:
            self.expanded_input_truncate = self.n_extra_positions
        self.min_doc_token_length = opt['min_doc_token_length']

        # modules
        self.retriever = retriever_factory(opt,
                                           dictionary,
                                           shared=retriever_shared)
        self.seq2seq_encoder = self.build_encoder(
            opt,
            dictionary=dictionary,
            embedding=self.embeddings,
            padding_idx=self.pad_idx,
        )
        self.seq2seq_decoder = self.build_decoder(opt,
                                                  embedding=self.embeddings,
                                                  padding_idx=self.pad_idx)