Example #1
    def _build_reranker(
            self, opt: Opt) -> Tuple[torch.nn.Module, RagRetrieverTokenizer]:
        """
        Builds reranker.

        :param opt:
            original opt

        :return (module, dict):
            module: the model from the agent created via the options
            dict: a RagRetrieverTokenizer wrapping the dictionary for the created model
        """
        rerank_opt = {
            **TRANSFORMER_RANKER_BASE_OPT,
            **self.get_reranker_opts(opt),
        }
        logging.disable()
        agent = create_agent(rerank_opt)
        logging.enable()
        assert isinstance(agent, TorchRankerAgent)

        return agent.model, RagRetrieverTokenizer('',
                                                  agent.dict,
                                                  max_length=360)
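
Note that every example on this page pairs logging.disable() with a manual logging.enable(), so an exception raised in between would leave logging silenced. Below is a minimal sketch of a safer wrapper, assuming only that parlai.utils.logging exposes disable() and enable() as used above:

from contextlib import contextmanager

import parlai.utils.logging as logging


@contextmanager
def logging_disabled():
    """Temporarily silence ParlAI logging, restoring it on exit."""
    logging.disable()
    try:
        yield
    finally:
        # re-enable logging even if the wrapped code raises
        logging.enable()

With this helper, the body of Example #1 could read `with logging_disabled(): agent = create_agent(rerank_opt)`.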
Example #2
def get_classifier_model_and_dict(
    opt: Opt
) -> Tuple[Optional[TorchAgent], Optional[DictionaryAgent]]:
    """
    Build classifier model and dictionary.
    """
    model_file = modelzoo_path(
        opt['datapath'], opt['expanded_attention_classifier_model_file']
    )
    model, dictionary = None, None
    if model_file and os.path.exists(model_file):
        logging.info(f'Building polyencoder from path: {model_file}')
        logging.disable()
        overrides = {
            'model': 'return_code_weights_agent',
            'data_parallel': opt.get('data_parallel', False),
            'model_parallel': opt['model_parallel'],
            'delimiter': opt['delimiter'],
            'no_cuda': opt['no_cuda'],
            'fp16': opt['fp16'],
        }
        poly_agent = create_agent_from_model_file(model_file, overrides)
        logging.enable()
        logging.info('Poly Build Complete')
        dictionary = poly_agent.build_dictionary()
        model = poly_agent.model
    return model, dictionary
Example #3
 def __init__(self, opt: Opt):
     self.opt = opt
     self.agents = []
     self.agent_dict = None
     self.generations = []
     self.input_type = 'Memory'
     self.delimiter = opt.get('memory_decoder_delimiter', '\n')
     self.one_line_memories = opt.get('memory_decoder_one_line_memories', False)
     model_file = modelzoo_path(opt['datapath'], opt['memory_decoder_model_file'])
     if model_file and os.path.exists(model_file):
         logging.info(f'Building Memory Decoder from file: {model_file}')
         logging.disable()
         overrides = {
             'skip_generation': False,
             'inference': 'beam',
             'beam_size': opt.get('memory_decoder_beam_size', 3),
             'beam_min_length': opt.get('memory_decoder_beam_min_length', 10),
             'beam_block_ngram': 3,
         }
         if self.opt.get('memory_decoder_truncate', -1) > 0:
             overrides['text_truncate'] = self.opt['memory_decoder_truncate']
             overrides['truncate'] = self.opt['memory_decoder_truncate']
         base_agent = create_agent_from_model_file(
             model_file, opt_overrides=overrides
         )
         assert isinstance(base_agent, TorchAgent)
         self.agents = [base_agent]
         assert isinstance(self.agents[0], TorchAgent)
         copies = max(100, (opt['batchsize'] * opt.get('rag_turn_n_turns', 1)))
         self.agents += [
             create_agent_from_shared(self.agents[0].share()) for _ in range(copies)
         ]
         self.agent_dict = self.agents[0].build_dictionary()
         logging.enable()
Example #4
 def __init__(self, opt: Opt):
     self.opt = opt
     self.agents = []
     self.agent_dict = None
     self.generations = []
     self.input_type = 'Search'
     self.knowledge_access_method = KnowledgeAccessMethod(
         opt['knowledge_access_method']
     )
     model_file = modelzoo_path(opt['datapath'], opt['query_generator_model_file'])
     if model_file and os.path.exists(model_file):
         logging.info(f'Building Query Generator from file: {model_file}')
         logging.disable()
         overrides: Dict[str, Any] = {'skip_generation': False}
         overrides['inference'] = opt['query_generator_inference']
         overrides['beam_size'] = opt.get('query_generator_beam_size', 3)
         overrides['beam_min_length'] = opt.get('query_generator_beam_min_length', 2)
         if self.opt['query_generator_truncate'] > 0:
             overrides['text_truncate'] = self.opt['query_generator_truncate']
             overrides['truncate'] = self.opt['query_generator_truncate']
         base_agent = create_agent_from_model_file(
             model_file, opt_overrides=overrides
         )
         assert isinstance(base_agent, TorchAgent)
         self.agents = [base_agent]
         bsz = opt.get('batchsize', 1)
         rag_turn_n_turns = opt.get('rag_turn_n_turns', 1)
         if bsz > 1 or rag_turn_n_turns > 1:
             self.agents += [
                 create_agent_from_shared(self.agents[0].share())
                 for _ in range((bsz * rag_turn_n_turns) - 1)
             ]
         self.agent_dict = self.agents[0].build_dictionary()
         logging.enable()
Example #5
    def _build_model(self, opt: Opt) -> Tuple[PolyEncoderModule, DictionaryAgent]:
        """
        Build poly-encoder module.

        :param opt:
            options from base RAG Model

        :return (module, dict):
            the dropout poly-encoder module and its dictionary agent
        """
        model_file = modelzoo_path(opt['datapath'], opt['poly_faiss_model_file'])
        model_opt = Opt.load(f'{model_file}.opt')

        create_model_opt = {
            **{k: model_opt[k] for k in TRANSFORMER_RANKER_BASE_OPT},
            **{k: model_opt[k] for k in POLYENCODER_OPT_KEYS},
            'model': 'transformer/dropout_poly',
            'init_model': model_file,
            'dict_file': f'{model_file}.dict',
            # necessary opt args
            'multitask_weights': [1],
            # dropout_poly args
            'poly_dropout_reduction_type': model_opt['poly_dropout_reduction_type'],
            'poly_dropout_use_codes': model_opt.get('poly_dropout_use_codes', True),
        }
        logging.disable()
        agent = create_agent(Opt(create_model_opt))
        logging.enable()
        assert isinstance(agent, DropoutPolyAgent)
        return agent.model, agent.dict
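
Example #5 relies on ParlAI's checkpoint layout: a saved model file sits next to sidecar files holding its training options and dictionary. A minimal sketch of that convention (the path below is illustrative, not from the source):

from parlai.core.opt import Opt

model_file = '/path/to/model'  # hypothetical checkpoint path
model_opt = Opt.load(f'{model_file}.opt')  # options saved at train time
dict_file = f'{model_file}.dict'           # dictionary saved alongside
print(model_opt['model'])                  # e.g. 'transformer/dropout_poly'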
Example #6
 def init_search_query_generator(self, opt) -> TorchGeneratorAgent:
     model_file = opt['search_query_generator_model_file']
     logging.info('Loading search generator model')
     logging.disable()
     search_query_gen_agent = create_agent_from_model_file(
         model_file,
         opt_overrides={
             'skip_generation': False,
             'inference': opt['search_query_generator_inference'],
             'beam_min_length': opt['search_query_generator_beam_min_length'],
             'beam_size': opt['search_query_generator_beam_size'],
             'text_truncate': opt['search_query_generator_text_truncate'],
         },
     )
     logging.enable()
     logging.info('Search query generator model loading completed!')
     return search_query_gen_agent
Example #7
 def __init__(self, opt: Opt):
     self.opt = opt
     self.agents = []
     self.agent_dict = None
     self.generations = []
     self.input_type = 'Search'
     self.knowledge_access_method = KnowledgeAccessMethod(
         opt['knowledge_access_method'])
     model_file = modelzoo_path(opt['datapath'],
                                opt['query_generator_model_file'])
     if (self.knowledge_access_method is KnowledgeAccessMethod.SEARCH_ONLY
             and 'blenderbot2/query_generator/model' in model_file):
         raise ValueError(
             'You cannot use the blenderbot2 query generator with search_only. Please '
             'consider setting --query-generator-model-file zoo:sea/bart_sq_gen/model '
             'instead.')
     if model_file and os.path.exists(model_file):
         logging.info(f'Building Query Generator from file: {model_file}')
         logging.disable()
         overrides: Dict[str, Any] = {'skip_generation': False}
         overrides['inference'] = opt['query_generator_inference']
         overrides['beam_size'] = opt.get('query_generator_beam_size', 3)
         overrides['beam_min_length'] = opt.get(
             'query_generator_beam_min_length', 2)
         overrides['model_parallel'] = opt['model_parallel']
         overrides['no_cuda'] = opt['no_cuda']
         if self.opt['query_generator_truncate'] > 0:
             overrides['text_truncate'] = self.opt[
                 'query_generator_truncate']
             overrides['truncate'] = self.opt['query_generator_truncate']
         base_agent = create_agent_from_model_file(model_file,
                                                   opt_overrides=overrides)
         assert isinstance(base_agent, TorchAgent)
         self.agents = [base_agent]
         bsz = max(
             opt.get('batchsize') or 1,
             opt.get('eval_batchsize') or 1)
         rag_turn_n_turns = opt.get('rag_turn_n_turns', 1)
         if bsz > 1 or rag_turn_n_turns > 1:
             self.agents += [
                 create_agent_from_shared(self.agents[0].share())
                 for _ in range((bsz * rag_turn_n_turns) - 1)
             ]
         self.agent_dict = self.agents[0].build_dictionary()
         logging.enable()
Example #8
@contextlib.contextmanager
def override_print(suppress=False, prefix=None):
    """
    Context manager to override the print to suppress or modify output.

    Recommended usage is to call this with suppress=True for all non-primary
    workers, or to call it with a prefix of the rank on all workers.

    >>> with override_print(prefix="rank{}".format(rank)):
    ...     my_computation()

    :param bool suppress:
        if true, all future print statements are noops.
    :param str prefix:
        if not None, this string is prefixed to all future print statements.
    """
    builtin_print = builtins.print

    def new_print(*args, **kwargs):
        if suppress:
            # do nothing
            return
        elif prefix:
            return builtin_print(prefix, *args, **kwargs)
        else:
            # default to normal print
            return builtin_print(*args, **kwargs)

    if prefix:
        logging.logger.add_format_prefix(prefix)
    if suppress:
        logging.disable()

    # override the print for now
    builtins.print = new_print
    yield
    # bring it back at the end of the context
    builtins.print = builtin_print

    if suppress:
        logging.enable()
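
A hypothetical usage sketch following the docstring above; `rank` and `my_computation` are illustrative placeholders, not from the source:

rank = 0  # e.g. torch.distributed.get_rank()

# only the primary worker prints
with override_print(suppress=(rank != 0)):
    my_computation()

# alternatively, every worker prints, tagged with its rank
with override_print(prefix="rank{}".format(rank)):
    my_computation()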
Example #9
        if not actions:
            continue

        readme.append(f'__{ag.title.title()}__\n\n')
        readme.append("| Argument | Description |\n")
        readme.append("|----------|----------|\n")
        for row in actions:
            text = "| " + " | ".join(row) + " |"
            text = text.replace("\n", "<br>")
            readme.append(f"{text}\n")
        readme.append("\n\n")
    return readme


logging.disable()

mutators = setup_mutator_registry()


def _display_data(**kwargs):
    with capture_output() as output:
        DisplayData.main(**kwargs)
    return output.getvalue()


with open('mutators_list.inc', 'w') as fout:
    output = _display_data(task=TASK)
    fout.write("## Original output\n\n")
    fout.write(
        "We show the unmutated output of the examples for reference:\n\n")