def retriever_factory(
    opt: Opt, dictionary: DictionaryAgent, shared=None
) -> Optional[RagRetriever]:
    """
    Build retriever.

    Override to build special BB2 Search Retrievers, if necessary.

    :param opt:
        ParlAI Opt
    :param dictionary:
        dictionary agent
    :param shared:
        shared objects.

    :return retriever:
        return a retriever for RAG.
    """
    if opt.get('converting'):
        return None
    retriever = RetrieverType(opt['rag_retriever_type'])
    if retriever is RetrieverType.SEARCH_ENGINE:
        return BB2SearchQuerySearchEngineRetriever(opt, dictionary, shared=shared)
    elif retriever is RetrieverType.SEARCH_TERM_FAISS:
        return BB2SearchQueryFaissIndexRetriever(opt, dictionary, shared=shared)
    elif retriever is RetrieverType.OBSERVATION_ECHO_RETRIEVER:
        return BB2ObservationEchoRetriever(opt, dictionary, shared=shared)
    else:
        return rag_retriever_factory(opt, dictionary, shared=shared)
def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None):
    super().__init__()
    self.retriever_type = RetrieverType(opt['rag_retriever_type'])
    if not (
        (self.retriever_type == RetrieverType.SEARCH_ENGINE)
        or (opt.get('retriever_debug_index') in [None, 'none'])
    ):
        if opt.get('retriever_debug_index') == 'exact':
            opt['path_to_index'] = WOW_INDEX_PATH
        else:
            opt['path_to_index'] = WOW_COMPRESSED_INDEX_PATH
        opt['path_to_dpr_passages'] = WOW_PASSAGES_PATH
    self.opt = opt
    self.print_docs = opt.get('print_docs', False)
    self.max_doc_len = opt['max_doc_token_length']
    self.max_query_len = opt['rag_query_truncate'] or 1024
    self.end_idx = dictionary[dictionary.end_token]
    self._tokenizer = RagRetrieverTokenizer(
        datapath=opt['datapath'],
        query_model=opt['query_model'],
        dictionary=dictionary,
        delimiter=opt.get('delimiter', '\n') or '\n',
    )
    self.fp16 = (
        not opt['no_cuda']
        and torch.cuda.is_available()
        and self.opt.get('fp16', False)
    )
def retriever_factory(
    opt: Opt, dictionary: DictionaryAgent, shared=None
) -> Optional[RagRetriever]:
    """
    Build retriever.

    :param opt:
        ParlAI Opt
    :param dictionary:
        dictionary agent
    :param shared:
        shared objects.

    :return retriever:
        return a retriever for RAG.
    """
    if opt.get('converting'):
        return None  # only build retriever when not converting a BART model
    retriever = RetrieverType(opt['rag_retriever_type'])
    if retriever is RetrieverType.DPR:
        return DPRRetriever(opt, dictionary, shared=shared)
    elif retriever is RetrieverType.TFIDF:
        return TFIDFRetriever(opt, dictionary, shared=shared)
    elif retriever is RetrieverType.DPR_THEN_POLY:
        return DPRThenPolyRetriever(opt, dictionary, shared=shared)
    elif retriever is RetrieverType.POLY_FAISS:
        return PolyFaissRetriever(opt, dictionary, shared=shared)
    elif retriever is RetrieverType.SEARCH_ENGINE:
        return SearchQuerySearchEngineRetriever(opt, dictionary, shared=shared)
    elif retriever is RetrieverType.SEARCH_TERM_FAISS:
        return SearchQueryFAISSIndexRetriever(opt, dictionary, shared=shared)
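The factories above share one pattern: dispatch on an enum built from `opt['rag_retriever_type']`, with an early `None` return while a checkpoint is being converted. A minimal dependency-free sketch of that pattern, using stub classes and hypothetical names rather than ParlAI's real retrievers:

```python
from enum import Enum
from typing import Optional


class RetrieverKind(Enum):
    # Stub stand-ins for a couple of RetrieverType values (names hypothetical)
    DPR = 'dpr'
    TFIDF = 'tfidf'


class StubRetriever:
    """Placeholder retriever; just records which kind built it."""

    def __init__(self, kind: RetrieverKind):
        self.kind = kind


def build_retriever(opt: dict) -> Optional[StubRetriever]:
    # Mirror the early return: skip building the retriever while converting
    if opt.get('converting'):
        return None
    # Enum construction doubles as validation: an unknown string raises ValueError
    kind = RetrieverKind(opt['rag_retriever_type'])
    return StubRetriever(kind)


assert build_retriever({'converting': True}) is None
r = build_retriever({'rag_retriever_type': 'dpr'})
assert r is not None and r.kind is RetrieverKind.DPR
```

Constructing the enum before branching means a typo in the option fails loudly at build time instead of silently falling through to a default branch.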
def __init__(self, opt, dictionary, retriever_shared=None):
    self.add_start_token = opt["add_start_token"]
    opt['converting'] = True  # set not to build the retriever
    FidModel.__init__(self, opt, dictionary, retriever_shared)
    opt['converting'] = False
    self.config = self.seq2seq_decoder.transformer.config
    self.embedding_size = self.config.n_embd
    self.lm_head = torch.nn.Linear(
        self.config.n_embd, self.config.vocab_size, bias=False
    )
    self._tie_weights(self.lm_head, self.seq2seq_decoder.transformer.wte)
    self.doc_delim = self.dict.txt2vec(Document.PASSAGE_DELIM)[0]
    self.min_doc_len = opt['min_doc_token_length']
    self.truncate = (
        opt['text_truncate'] if opt['text_truncate'] > -1 else opt['truncate']
    )
    if opt.get('filter_docs_with_label'):
        assert (
            RetrieverType(opt['rag_retriever_type']) == RetrieverType.SEARCH_ENGINE
        )
        self.retriever = FilterDocsForLabelSearchEngineRetrieverBase(
            opt, dictionary, shared=retriever_shared
        )  # type: ignore
    else:
        self.retriever = retriever_factory(opt, dictionary, shared=retriever_shared)
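The `_tie_weights` call above shares the token embedding matrix (`wte`) with the output projection (`lm_head`), as GPT-2-style models do. A torch-free sketch of what tying means, with hypothetical stub classes in place of real modules:

```python
class Embedding:
    """Stub embedding table: one weight row per vocab item."""

    def __init__(self, weight):
        self.weight = weight


class LinearHead:
    """Stub output projection over the vocabulary."""

    def __init__(self, weight):
        self.weight = weight


def tie_weights(head: LinearHead, embedding: Embedding) -> None:
    # Point both modules at the same underlying storage, so any
    # update to the embedding is seen by the output head (and vice versa).
    head.weight = embedding.weight


wte = Embedding([[0.1, 0.2], [0.3, 0.4]])
lm_head = LinearHead([[0.0, 0.0], [0.0, 0.0]])
tie_weights(lm_head, wte)
wte.weight[0][0] = 9.9  # an "update" to the embedding...
assert lm_head.weight[0][0] == 9.9  # ...is visible through the tied head
```

Tying halves the number of vocabulary-sized parameter matrices and keeps the input and output token representations consistent during training.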
def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None):
    # TODO: Get rid of this hack
    opt['converting'] = True
    super().__init__(opt, dictionary, retriever_shared)
    opt['converting'] = False
    self.opt = opt
    self.dummy_retriever = DummyRetriever(opt, dictionary)
    self.retriever = retriever_factory(opt, dictionary, shared=retriever_shared)
    assert self.retriever is not None
    query_encoder = (
        self.retriever.query_encoder
        if hasattr(self.retriever, 'query_encoder')
        and opt['share_search_and_memory_query_encoder']
        else None
    )
    self.long_term_memory = LongTermMemory(
        opt, dictionary, query_encoder
    )  # type: ignore
    self.query_generator = QueryGenerator(opt)
    self.memory_decoder = MemoryDecoder(opt)

    # attrs
    self.knowledge_access_method = KnowledgeAccessMethod(
        opt['knowledge_access_method']
    )
    self.search = RetrieverType(opt['rag_retriever_type']) in [
        RetrieverType.SEARCH_ENGINE,
        RetrieverType.SEARCH_TERM_FAISS,
    ]
    self.should_generate_query = (
        self.knowledge_access_method is KnowledgeAccessMethod.CLASSIFY or self.search
    ) and (
        self.knowledge_access_method
        not in [KnowledgeAccessMethod.MEMORY_ONLY, KnowledgeAccessMethod.NONE]
    )
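The TODO above flags the manual `opt['converting'] = True` / `= False` toggle as a hack: if `super().__init__` raises, the flag is left set. One possible cleanup (a hypothetical helper, not part of ParlAI) is a context manager that restores the prior state even on exceptions:

```python
from contextlib import contextmanager


@contextmanager
def temporary_flag(opt: dict, key: str, value=True):
    # Temporarily set opt[key], restoring the previous state on exit,
    # including when the wrapped block raises an exception.
    missing = object()
    prev = opt.get(key, missing)
    opt[key] = value
    try:
        yield opt
    finally:
        if prev is missing:
            del opt[key]
        else:
            opt[key] = prev


opt = {'rag_retriever_type': 'search_engine'}
with temporary_flag(opt, 'converting'):
    assert opt['converting'] is True  # the retriever build would be skipped here
assert 'converting' not in opt  # flag removed afterwards, not merely set False
```

This keeps the "skip building the retriever" intent local to the `with` block instead of leaking mutated option state across the constructor.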
def combo_fid_retriever_factory(
    opt: Opt, dictionary: DictionaryAgent, shared=None
) -> Optional[RagRetriever]:
    """
    Bypass call to standard retriever factory to possibly build our own retriever.
    """
    if opt.get('converting'):
        return None
    retriever = RetrieverType(opt['rag_retriever_type'])
    if retriever is RetrieverType.SEARCH_ENGINE:
        return ComboFidSearchQuerySearchEngineRetriever(
            opt, dictionary, shared=shared
        )  # type: ignore
    else:
        return retriever_factory(opt, dictionary, shared)
def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None):
    super().__init__(opt, dictionary, retriever_shared)
    if opt.get('filter_docs_with_label'):
        assert (
            RetrieverType(opt['rag_retriever_type']) == RetrieverType.SEARCH_ENGINE
        )
        self.retriever = FilterDocsForLabelSearchEngineRetrieverCombo(
            opt, dictionary, shared=retriever_shared
        )  # type: ignore
    else:
        self.retriever = combo_fid_retriever_factory(
            opt, dictionary, shared=retriever_shared
        )
    self.top_docs: List[List[Document]] = []