def _build_reranker(
    self, opt: Opt
) -> Tuple[torch.nn.Module, RagRetrieverTokenizer]:
    """
    Build the reranker.

    :param opt:
        original opt

    :return (module, dict):
        module: the model from the agent created via the options
        dict: a RagRetrieverTokenizer, the dictionary for the created model
    """
    rerank_opt = {**TRANSFORMER_RANKER_BASE_OPT, **self.get_reranker_opts(opt)}
    # create_agent is chatty at load time; every loader in this module
    # brackets construction with logging.disable()/enable() to silence it
    logging.disable()
    agent = create_agent(rerank_opt)
    logging.enable()
    assert isinstance(agent, TorchRankerAgent)
    return agent.model, RagRetrieverTokenizer('', agent.dict, max_length=360)

def get_classifier_model_and_dict(
    opt: Opt,
) -> Tuple[Optional[TorchAgent], Optional[DictionaryAgent]]:
    """
    Build the classifier model and dictionary.
    """
    model_file = modelzoo_path(
        opt['datapath'], opt['expanded_attention_classifier_model_file']
    )
    model, dictionary = None, None
    if model_file and os.path.exists(model_file):
        logging.info(f'Building polyencoder from path: {model_file}')
        logging.disable()
        overrides = {
            'model': 'return_code_weights_agent',
            'data_parallel': opt.get('data_parallel', False),
            'model_parallel': opt['model_parallel'],
            'delimiter': opt['delimiter'],
            'no_cuda': opt['no_cuda'],
            'fp16': opt['fp16'],
        }
        poly_agent = create_agent_from_model_file(model_file, overrides)
        logging.enable()
        logging.info('Poly Build Complete')
        dictionary = poly_agent.build_dictionary()
        model = poly_agent.model
    return model, dictionary

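# Usage sketch (illustrative, not part of the source): the function degrades
# gracefully when no checkpoint is present, so a minimal Opt carrying only the
# keys read above is enough to exercise it. The paths here are hypothetical.
def _example_build_classifier() -> None:
    opt = Opt(
        {
            'datapath': '/tmp/parlai_data',  # hypothetical datapath
            'expanded_attention_classifier_model_file': '/path/to/classifier',  # hypothetical
            'model_parallel': False,
            'delimiter': '\n',
            'no_cuda': True,
            'fp16': False,
        }
    )
    model, dictionary = get_classifier_model_and_dict(opt)
    if model is None:
        # the model file did not exist, so nothing was loaded
        logging.info('No classifier available; continuing without one.')
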
def __init__(self, opt: Opt):
    self.opt = opt
    self.agents = []
    self.agent_dict = None
    self.generations = []
    self.input_type = 'Memory'
    self.delimiter = opt.get('memory_decoder_delimiter', '\n')
    self.one_line_memories = opt.get('memory_decoder_one_line_memories', False)
    model_file = modelzoo_path(opt['datapath'], opt['memory_decoder_model_file'])
    if model_file and os.path.exists(model_file):
        logging.info(f'Building Memory Decoder from file: {model_file}')
        logging.disable()
        overrides = {
            'skip_generation': False,
            'inference': 'beam',
            'beam_size': opt.get('memory_decoder_beam_size', 3),
            'beam_min_length': opt.get('memory_decoder_beam_min_length', 10),
            'beam_block_ngram': 3,
        }
        if self.opt.get('memory_decoder_truncate', -1) > 0:
            overrides['text_truncate'] = self.opt['memory_decoder_truncate']
            overrides['truncate'] = self.opt['memory_decoder_truncate']
        base_agent = create_agent_from_model_file(model_file, opt_overrides=overrides)
        assert isinstance(base_agent, TorchAgent)
        self.agents = [base_agent]
        # one agent per batch row; the clones share the base agent's weights
        copies = max(100, opt['batchsize'] * opt.get('rag_turn_n_turns', 1))
        self.agents += [
            create_agent_from_shared(self.agents[0].share()) for _ in range(copies)
        ]
        self.agent_dict = self.agents[0].build_dictionary()
        logging.enable()

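# The clone pattern above recurs in every loader here: TorchAgent.share()
# returns a shared-state dict whose model entry points at the same underlying
# module, so create_agent_from_shared() yields lightweight per-batch-row
# agents without duplicating weights. A minimal sketch, assuming a valid
# generator checkpoint at a hypothetical path:
def _example_clone_agents(model_file: str, n_rows: int) -> list:
    base = create_agent_from_model_file(
        model_file, opt_overrides={'skip_generation': False}
    )
    # each clone keeps its own dialogue state but shares the base parameters
    return [base] + [
        create_agent_from_shared(base.share()) for _ in range(n_rows - 1)
    ]
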
def __init__(self, opt: Opt):
    self.opt = opt
    self.agents = []
    self.agent_dict = None
    self.generations = []
    self.input_type = 'Search'
    self.knowledge_access_method = KnowledgeAccessMethod(
        opt['knowledge_access_method']
    )
    model_file = modelzoo_path(opt['datapath'], opt['query_generator_model_file'])
    if model_file and os.path.exists(model_file):
        logging.info(f'Building Query Generator from file: {model_file}')
        logging.disable()
        overrides: Dict[str, Any] = {'skip_generation': False}
        overrides['inference'] = opt['query_generator_inference']
        overrides['beam_size'] = opt.get('query_generator_beam_size', 3)
        overrides['beam_min_length'] = opt.get('query_generator_beam_min_length', 2)
        if self.opt['query_generator_truncate'] > 0:
            overrides['text_truncate'] = self.opt['query_generator_truncate']
            overrides['truncate'] = self.opt['query_generator_truncate']
        base_agent = create_agent_from_model_file(model_file, opt_overrides=overrides)
        assert isinstance(base_agent, TorchAgent)
        self.agents = [base_agent]
        bsz = opt.get('batchsize', 1)
        rag_turn_n_turns = opt.get('rag_turn_n_turns', 1)
        if bsz > 1 or rag_turn_n_turns > 1:
            self.agents += [
                create_agent_from_shared(self.agents[0].share())
                for _ in range((bsz * rag_turn_n_turns) - 1)
            ]
        self.agent_dict = self.agents[0].build_dictionary()
        logging.enable()

def _build_model(self, opt: Opt) -> Tuple[PolyEncoderModule, DictionaryAgent]:
    """
    Build the dropout poly-encoder module.

    :param opt:
        options from the base RAG model

    :return (model, dict):
        the dropout poly-encoder module and its dictionary
    """
    model_file = modelzoo_path(opt['datapath'], opt['poly_faiss_model_file'])
    model_opt = Opt.load(f'{model_file}.opt')
    create_model_opt = {
        **{k: model_opt[k] for k in TRANSFORMER_RANKER_BASE_OPT},
        **{k: model_opt[k] for k in POLYENCODER_OPT_KEYS},
        'model': 'transformer/dropout_poly',
        'init_model': model_file,
        'dict_file': f'{model_file}.dict',
        # necessary opt args
        'multitask_weights': [1],
        # dropout_poly args
        'poly_dropout_reduction_type': model_opt['poly_dropout_reduction_type'],
        'poly_dropout_use_codes': model_opt.get('poly_dropout_use_codes', True),
    }
    logging.disable()
    agent = create_agent(Opt(create_model_opt))
    logging.enable()
    assert isinstance(agent, DropoutPolyAgent)
    return agent.model, agent.dict

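# The loader above reconstructs an agent from a trained checkpoint's stored
# options rather than from command-line opts; ParlAI saves `<model>.opt` and
# `<model>.dict` next to the weights. A minimal sketch of that idea, with a
# hypothetical checkpoint path:
def _example_read_stored_opt(model_file: str) -> None:
    stored = Opt.load(f'{model_file}.opt')
    # cherry-pick only the architecture keys, mirroring the dict-comprehension
    # merge above; POLYENCODER_OPT_KEYS is defined elsewhere in this module
    arch = {k: stored[k] for k in POLYENCODER_OPT_KEYS}
    logging.info(f'architecture opts: {arch}')
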
def init_search_query_generator(self, opt) -> TorchGeneratorAgent:
    model_file = opt['search_query_generator_model_file']
    logging.info('Loading search generator model')
    logging.disable()
    search_query_gen_agent = create_agent_from_model_file(
        model_file,
        opt_overrides={
            'skip_generation': False,
            'inference': opt['search_query_generator_inference'],
            'beam_min_length': opt['search_query_generator_beam_min_length'],
            'beam_size': opt['search_query_generator_beam_size'],
            'text_truncate': opt['search_query_generator_text_truncate'],
        },
    )
    logging.enable()
    logging.info('Search query generator model loading completed!')
    return search_query_gen_agent

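# Sketch of how the returned generator is typically driven (illustrative,
# assuming the standard ParlAI observe/act agent API): feed the conversation
# history as an observation and read the generated search query off the reply.
def _example_generate_search_query(agent: TorchGeneratorAgent, history: str) -> str:
    agent.observe({'text': history, 'episode_done': True})
    reply = agent.act()
    # act() returns a Message; 'text' holds the beam-decoded query
    return reply['text']
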
def __init__(self, opt: Opt):
    self.opt = opt
    self.agents = []
    self.agent_dict = None
    self.generations = []
    self.input_type = 'Search'
    self.knowledge_access_method = KnowledgeAccessMethod(
        opt['knowledge_access_method']
    )
    model_file = modelzoo_path(opt['datapath'], opt['query_generator_model_file'])
    if (
        self.knowledge_access_method is KnowledgeAccessMethod.SEARCH_ONLY
        and 'blenderbot2/query_generator/model' in model_file
    ):
        raise ValueError(
            'You cannot use the blenderbot2 query generator with search_only. '
            'Please consider setting --query-generator-model-file '
            'zoo:sea/bart_sq_gen/model instead.'
        )
    if model_file and os.path.exists(model_file):
        logging.info(f'Building Query Generator from file: {model_file}')
        logging.disable()
        overrides: Dict[str, Any] = {'skip_generation': False}
        overrides['inference'] = opt['query_generator_inference']
        overrides['beam_size'] = opt.get('query_generator_beam_size', 3)
        overrides['beam_min_length'] = opt.get('query_generator_beam_min_length', 2)
        overrides['model_parallel'] = opt['model_parallel']
        overrides['no_cuda'] = opt['no_cuda']
        if self.opt['query_generator_truncate'] > 0:
            overrides['text_truncate'] = self.opt['query_generator_truncate']
            overrides['truncate'] = self.opt['query_generator_truncate']
        base_agent = create_agent_from_model_file(model_file, opt_overrides=overrides)
        assert isinstance(base_agent, TorchAgent)
        self.agents = [base_agent]
        # account for both train and eval batch sizes when sizing the clone pool
        bsz = max(opt.get('batchsize') or 1, opt.get('eval_batchsize') or 1)
        rag_turn_n_turns = opt.get('rag_turn_n_turns', 1)
        if bsz > 1 or rag_turn_n_turns > 1:
            self.agents += [
                create_agent_from_shared(self.agents[0].share())
                for _ in range((bsz * rag_turn_n_turns) - 1)
            ]
        self.agent_dict = self.agents[0].build_dictionary()
        logging.enable()

@contextmanager  # from contextlib; required for the yield-based body below
def override_print(suppress=False, prefix=None):
    """
    Context manager to override print, suppressing or modifying output.

    Recommended usage is to call this with suppress=True for all non-primary
    workers, or to call it with a prefix of rank on all workers.

    >>> with override_print(prefix="rank{}".format(rank)):
    ...     my_computation()

    :param bool suppress:
        if true, all future print statements are no-ops.
    :param str prefix:
        if not None, this string is prefixed to all future print statements.
    """
    builtin_print = builtins.print

    def new_print(*args, **kwargs):
        if suppress:
            # do nothing
            return
        elif prefix:
            return builtin_print(prefix, *args, **kwargs)
        else:
            # default to normal print
            return builtin_print(*args, **kwargs)

    if prefix:
        logging.logger.add_format_prefix(prefix)
    if suppress:
        logging.disable()

    # override print for the duration of the context
    builtins.print = new_print
    yield

    # restore the builtin at the end of the context
    builtins.print = builtin_print
    if suppress:
        logging.enable()
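
# Usage sketch (illustrative, not part of the source): in a distributed
# launch, suppress stdout on every worker except the primary so logs are not
# duplicated once per process. `rank` would come from the process-group
# bootstrap (e.g. torch.distributed); it is a hypothetical parameter here.
def _example_distributed_entry(rank: int) -> None:
    with override_print(suppress=(rank != 0)):
        print('only the primary worker prints this')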