Пример #1
0
def retriever_factory(opt: Opt,
                      dictionary: DictionaryAgent,
                      shared=None) -> Optional[RagRetriever]:
    """
    Construct the retriever selected by ``opt['rag_retriever_type']``.

    The BB2 search-engine, search-term FAISS and observation-echo retrievers
    are built here; every other retriever type is delegated to
    ``rag_retriever_factory``.

    :param opt:
        ParlAI Opt
    :param dictionary:
        dictionary agent
    :param shared:
        shared objects.

    :return retriever:
        the retriever for RAG, or None when ``opt['converting']`` is truthy.
    """
    if opt.get('converting'):
        return None

    selected = RetrieverType(opt['rag_retriever_type'])
    # Every builder takes the same (opt, dictionary, shared=...) arguments, so
    # choose the callable first and invoke it once at the end.
    if selected is RetrieverType.SEARCH_ENGINE:
        builder = BB2SearchQuerySearchEngineRetriever
    elif selected is RetrieverType.SEARCH_TERM_FAISS:
        builder = BB2SearchQueryFaissIndexRetriever
    elif selected is RetrieverType.OBSERVATION_ECHO_RETRIEVER:
        builder = BB2ObservationEchoRetriever
    else:
        builder = rag_retriever_factory
    return builder(opt, dictionary, shared=shared)
Пример #2
0
 def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None):
     """
     Read retriever settings from ``opt`` and build the query tokenizer.

     When the retriever type is not SEARCH_ENGINE and
     ``opt['retriever_debug_index']`` is set to something other than
     None / 'none', ``opt`` is modified in place so that its index and
     passage paths point at the WoW debug files.

     :param opt:
         ParlAI Opt; may be mutated as described above.
     :param dictionary:
         dictionary agent, used for the end token index and the tokenizer.
     :param shared:
         shared objects; not read in this method.
     """
     super().__init__()
     self.retriever_type = RetrieverType(opt['rag_retriever_type'])

     debug_index = opt.get('retriever_debug_index')
     if (
         self.retriever_type != RetrieverType.SEARCH_ENGINE
         and debug_index not in [None, 'none']
     ):
         # 'exact' selects the uncompressed index; any other value selects
         # the compressed one.
         opt['path_to_index'] = (
             WOW_INDEX_PATH if debug_index == 'exact' else WOW_COMPRESSED_INDEX_PATH
         )
         opt['path_to_dpr_passages'] = WOW_PASSAGES_PATH

     self.opt = opt
     self.print_docs = opt.get('print_docs', False)
     self.max_doc_len = opt['max_doc_token_length']
     # A falsy truncation setting falls back to 1024.
     self.max_query_len = opt['rag_query_truncate'] or 1024
     self.end_idx = dictionary[dictionary.end_token]

     tokenizer_kwargs = dict(
         datapath=opt['datapath'],
         query_model=opt['query_model'],
         dictionary=dictionary,
         delimiter=opt.get('delimiter', '\n') or '\n',
     )
     self._tokenizer = RagRetrieverTokenizer(**tokenizer_kwargs)

     # fp16 is only honoured when CUDA is both allowed and available.
     cuda_usable = not opt['no_cuda'] and torch.cuda.is_available()
     self.fp16 = cuda_usable and opt.get('fp16', False)
Пример #3
0
def retriever_factory(
    opt: Opt, dictionary: DictionaryAgent, shared=None
) -> Optional[RagRetriever]:
    """
    Build the retriever named by ``opt['rag_retriever_type']``.

    :param opt:
        ParlAI Opt
    :param dictionary:
        dictionary agent
    :param shared:
        shared objects.

    :return retriever:
        a retriever for RAG; None when ``opt['converting']`` is truthy or
        when the retriever type matches none of the types handled here.
    """
    # only build retriever when not converting a BART model
    if opt.get('converting'):
        return None

    selected = RetrieverType(opt['rag_retriever_type'])
    # Ordered (type, class) pairs, checked by identity in the listed order.
    candidates = (
        (RetrieverType.DPR, DPRRetriever),
        (RetrieverType.TFIDF, TFIDFRetriever),
        (RetrieverType.DPR_THEN_POLY, DPRThenPolyRetriever),
        (RetrieverType.POLY_FAISS, PolyFaissRetriever),
        (RetrieverType.SEARCH_ENGINE, SearchQuerySearchEngineRetriever),
        (RetrieverType.SEARCH_TERM_FAISS, SearchQueryFAISSIndexRetriever),
    )
    for kind, retriever_cls in candidates:
        if selected is kind:
            return retriever_cls(opt, dictionary, shared=shared)
    # No handled type matched.
    return None
Пример #4
0
 def __init__(self, opt, dictionary, retriever_shared=None):
     """
     Initialise the model on top of ``FidModel`` and attach its own retriever.

     ``opt['converting']`` is switched on while ``FidModel.__init__`` runs so
     that the base constructor does not build a retriever; it is switched off
     again afterwards (also when the base constructor raises) and the
     retriever is then created here.

     :param opt:
         ParlAI Opt; ``opt['converting']`` is overwritten and left False.
     :param dictionary:
         dictionary agent
     :param retriever_shared:
         shared retriever objects, forwarded to the retriever.
     """
     self.add_start_token = opt["add_start_token"]
     opt['converting'] = True  # set not to build the retriever
     try:
         FidModel.__init__(self, opt, dictionary, retriever_shared)
     finally:
         # Clear the flag even on failure so the caller's opt is never left
         # stuck in the converting state.
         opt['converting'] = False

     config = self.seq2seq_decoder.transformer.config
     self.config = config
     self.embedding_size = config.n_embd
     self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)
     self._tie_weights(self.lm_head, self.seq2seq_decoder.transformer.wte)
     self.doc_delim = self.dict.txt2vec(Document.PASSAGE_DELIM)[0]
     self.min_doc_len = opt['min_doc_token_length']
     # text_truncate takes precedence unless it is negative.
     self.truncate = (
         opt['text_truncate'] if opt['text_truncate'] > -1 else opt['truncate']
     )
     if opt.get('filter_docs_with_label'):
         assert (
             RetrieverType(opt['rag_retriever_type']) == RetrieverType.SEARCH_ENGINE
         ), 'filter_docs_with_label requires the SEARCH_ENGINE retriever type'
         self.retriever = FilterDocsForLabelSearchEngineRetrieverBase(
             opt, dictionary, shared=retriever_shared
         )  # type: ignore
     else:
         self.retriever = retriever_factory(opt, dictionary, shared=retriever_shared)
Пример #5
0
    def __init__(self,
                 opt: Opt,
                 dictionary: DictionaryAgent,
                 retriever_shared=None):
        """
        Initialise the base model, then build the retriever and memory parts.

        ``opt['converting']`` is switched on while the parent constructor runs
        and switched off again afterwards (also when the parent constructor
        raises), after which the retriever is built through
        ``retriever_factory``.

        :param opt:
            ParlAI Opt; ``opt['converting']`` is overwritten and left False.
        :param dictionary:
            dictionary agent
        :param retriever_shared:
            shared retriever objects, forwarded to the retriever.
        """
        # TODO: Get rid of this hack
        opt['converting'] = True
        try:
            super().__init__(opt, dictionary, retriever_shared)
        finally:
            # Clear the flag even on failure so the caller's opt is never left
            # stuck in the converting state.
            opt['converting'] = False
        self.opt = opt
        self.dummy_retriever = DummyRetriever(opt, dictionary)
        self.retriever = retriever_factory(opt,
                                           dictionary,
                                           shared=retriever_shared)
        assert self.retriever is not None

        # Hand the retriever's query encoder to the long-term memory only when
        # the retriever has one and sharing is enabled in opt.
        query_encoder = None
        if (hasattr(self.retriever, 'query_encoder')
                and opt['share_search_and_memory_query_encoder']):
            query_encoder = self.retriever.query_encoder
        self.long_term_memory = LongTermMemory(opt, dictionary,
                                               query_encoder)  # type: ignore
        self.query_generator = QueryGenerator(opt)
        self.memory_decoder = MemoryDecoder(opt)

        # attrs
        self.knowledge_access_method = KnowledgeAccessMethod(
            opt['knowledge_access_method'])
        # True for the two search-query based retriever types.
        self.search = RetrieverType(opt['rag_retriever_type']) in [
            RetrieverType.SEARCH_ENGINE,
            RetrieverType.SEARCH_TERM_FAISS,
        ]
        # Generate a query when classifying or searching, except when the
        # access method is memory-only or none.
        self.should_generate_query = (
            self.knowledge_access_method is KnowledgeAccessMethod.CLASSIFY
            or self.search) and (self.knowledge_access_method not in [
                KnowledgeAccessMethod.MEMORY_ONLY, KnowledgeAccessMethod.NONE
            ])
Пример #6
0
def combo_fid_retriever_factory(opt: Opt,
                                dictionary: DictionaryAgent,
                                shared=None) -> Optional[RagRetriever]:
    """
    Build the combo FiD retriever, deferring to the standard factory if needed.

    Returns None when ``opt['converting']`` is truthy. A SEARCH_ENGINE
    retriever type yields a ``ComboFidSearchQuerySearchEngineRetriever``;
    every other type is handed to ``retriever_factory``.
    """
    if opt.get('converting'):
        return None

    selected = RetrieverType(opt['rag_retriever_type'])
    if selected is not RetrieverType.SEARCH_ENGINE:
        return retriever_factory(opt, dictionary, shared)
    return ComboFidSearchQuerySearchEngineRetriever(
        opt, dictionary, shared=shared)  # type: ignore
Пример #7
0
 def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None):
     """
     Initialise the parent model, then replace its retriever.

     With ``opt['filter_docs_with_label']`` set, the retriever type must be
     SEARCH_ENGINE and a ``FilterDocsForLabelSearchEngineRetrieverCombo`` is
     used; otherwise the retriever comes from ``combo_fid_retriever_factory``.

     :param opt:
         ParlAI Opt
     :param dictionary:
         dictionary agent
     :param retriever_shared:
         shared retriever objects, forwarded to the retriever.
     """
     super().__init__(opt, dictionary, retriever_shared)

     # Both builders accept (opt, dictionary, shared=...), so select one and
     # construct the retriever in a single place.
     if opt.get('filter_docs_with_label'):
         assert (
             RetrieverType(opt['rag_retriever_type']) == RetrieverType.SEARCH_ENGINE
         )
         build_retriever = FilterDocsForLabelSearchEngineRetrieverCombo
     else:
         build_retriever = combo_fid_retriever_factory
     self.retriever = build_retriever(
         opt, dictionary, shared=retriever_shared
     )  # type: ignore
     self.top_docs: List[List[Document]] = []