示例#1
0
    def __init__(self, cfg: DictConfig):
        """Create the bi-encoder training components described by ``cfg``.

        If a checkpoint file is found, its encoder parameters are copied into
        ``cfg`` before the model is built, and the checkpoint is handed to
        ``_load_saved_state`` once every attribute has been initialised.

        Args:
            cfg: Training configuration. Mutated in place: ``cfg.others`` is set
                to ``{"is_matching": ...}`` and, when ``cfg.clustering`` is
                truthy, ``cfg.global_loss_buf_sz`` is raised to 80,000,000.
        """
        local_rank = cfg.local_rank
        self.shard_id = 0 if local_rank == -1 else local_rank
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # Encoder parameters stored in a checkpoint override the config so the
        # model built below matches the saved weights.
        checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
        checkpoint = None
        if checkpoint_path:
            checkpoint = load_states_from_checkpoint(checkpoint_path)
            set_cfg_params_from_state(checkpoint.encoder_params, cfg)

        use_grad_ckpt = getattr(cfg, "gradient_checkpointing", False)
        tensorizer, bi_encoder, opt = init_biencoder_components(
            cfg.encoder.encoder_model_type, cfg, gradient_checkpointing=use_grad_ckpt
        )

        # Record whether a matching-style bi-encoder was built; open_dict lets
        # the new key be added to a struct-mode config.
        matching = isinstance(bi_encoder, (Match_BiEncoder, MatchGated_BiEncoder))
        with omegaconf.open_dict(cfg):
            cfg.others = DictConfig({"is_matching": matching})

        bi_encoder, opt = setup_for_distributed_mode(
            bi_encoder,
            opt,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=use_grad_ckpt,
        )

        # Model, optimiser and tokenisation components.
        self.biencoder = bi_encoder
        self.optimizer = opt
        self.tensorizer = tensorizer

        # Training-progress bookkeeping (starts from scratch).
        self.start_epoch = self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = self.best_cp_name = None

        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)
        self.loss_function = init_loss(cfg.encoder.encoder_model_type, cfg)

        self.clustering = cfg.clustering
        if self.clustering:
            cfg.global_loss_buf_sz = 80000000  # this requires a lot of memory

        if checkpoint:
            self._load_saved_state(checkpoint)

        self.dev_iterator = None
示例#2
0
    def __init__(self, cfg: DictConfig, save_temp_conf=False,
                 temp_conf_path='/content/drive/MyDrive/conf.pickle'):
        """Create the bi-encoder training components, with graded/binary settings.

        Args:
            cfg: Training configuration. If a checkpoint file is found, its
                encoder parameters are copied into ``cfg`` before the model is
                built, and the checkpoint is applied via ``_load_saved_state``
                at the end.
            save_temp_conf: When truthy, pickle ``cfg`` to ``temp_conf_path``
                before anything else is initialised (debugging aid).
            temp_conf_path: Destination of the pickled config. Defaults to the
                location that was previously hard-coded, so existing callers
                behave exactly as before.
        """
        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        if save_temp_conf:
            import pickle
            with open(temp_conf_path, 'wb') as f:
                pickle.dump(cfg, f)

        logger.info("***** Initializing components for training *****")

        # graded dataset settings: ``cfg.binary_trainer`` selects binary
        # relevance labels, otherwise graded labels are used.
        self.trainer_type = 'binary' if cfg.binary_trainer else 'graded'
        logger.info('trainer_type: %s', self.trainer_type)
        self.relations = cfg.relations
        logger.info('relations: %s', self.relations)

        # if model file is specified, encoder parameters from saved state should be used for initialization
        model_file = get_model_file(cfg, cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

        tensorizer, model, optimizer = init_biencoder_components(
            cfg.encoder.encoder_model_type, cfg
        )

        model, optimizer = setup_for_distributed_mode(
            model,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
        )
        self.biencoder = model
        self.optimizer = optimizer
        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)

        if saved_state:
            self._load_saved_state(saved_state)

        self.dev_iterator = None
示例#3
0
    def __init__(self, cfg: DictConfig):
        """Create the reader training components described by ``cfg``.

        If a checkpoint file is found, its encoder parameters are copied into
        ``cfg`` before the reader is built, and the checkpoint is handed to
        ``_load_saved_state`` after all attributes are initialised.

        Args:
            cfg: Training configuration; stored as ``self.cfg``.
        """
        self.cfg = cfg

        rank = cfg.local_rank
        self.shard_id = rank if rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
        checkpoint = None
        if checkpoint_path:
            checkpoint = load_states_from_checkpoint(checkpoint_path)
            set_cfg_params_from_state(checkpoint.encoder_params, cfg)

        grad_ckpt = getattr(cfg, "gradient_checkpointing", False)
        tensorizer, reader_model, opt = init_reader_components(
            cfg.encoder.encoder_model_type, cfg, gradient_checkpointing=grad_ckpt
        )

        reader_model, opt = setup_for_distributed_mode(
            reader_model,
            opt,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=grad_ckpt,
        )

        self.reader, self.optimizer, self.tensorizer = reader_model, opt, tensorizer
        self.debugging = getattr(cfg, "debugging", False)

        # Lazily populated / progress-tracking state.
        self.wiki_data = self.dev_iterator = None
        self.start_epoch = self.start_batch = 0
        self.scheduler_state = self.best_validation_result = self.best_cp_name = None

        if checkpoint:
            self._load_saved_state(checkpoint)
    def __init__(self, cfg: DictConfig):
        """Create the bi-encoder training components described by ``cfg``.

        If a checkpoint file is found, its encoder parameters are copied into
        ``cfg`` before the model is built, and the checkpoint is handed to
        ``_load_saved_state`` after all attributes are initialised.

        Args:
            cfg: Training configuration; stored as ``self.cfg``.
        """
        rank = cfg.local_rank
        self.shard_id = 0 if rank == -1 else rank
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # Encoder parameters stored in a checkpoint override the config so the
        # model built below matches the saved weights.
        checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
        checkpoint = None
        if checkpoint_path:
            checkpoint = load_states_from_checkpoint(checkpoint_path)
            set_cfg_params_from_state(checkpoint.encoder_params, cfg)

        tensorizer, bi_encoder, opt = init_biencoder_components(
            cfg.encoder.encoder_model_type, cfg
        )
        bi_encoder, opt = setup_for_distributed_mode(
            bi_encoder,
            opt,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
        )

        self.biencoder, self.optimizer, self.tensorizer = bi_encoder, opt, tensorizer

        # Training-progress bookkeeping (starts from scratch).
        self.start_epoch = self.start_batch = 0
        self.scheduler_state = self.best_validation_result = self.best_cp_name = None

        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)

        if checkpoint:
            self._load_saved_state(checkpoint)

        self.dev_iterator = None
示例#5
0
def main(cfg: DictConfig):
    """Load a bi-encoder checkpoint and dump passage encodings via ``dump_passages``.

    Args:
        cfg: Hydra configuration. ``pair_file`` (passage source) and
            ``model_file`` (encoder checkpoint) must be set. The config is
            updated with the GPU settings and the checkpoint's encoder
            parameters, and ``encoder.sequence_length`` is forced to 512.

    Raises:
        AssertionError: If ``pair_file`` or ``model_file`` is not set.
    """
    # Configure logging before the first log call. It used to run at the very
    # end, so INFO messages emitted earlier were dropped when no handler had
    # been installed yet. basicConfig does nothing if the root logger already
    # has handlers, so an externally configured logging setup is unaffected.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    assert cfg.pair_file, "Please specify passages source as pair_file param"
    assert cfg.model_file, "Please specify encoder checkpoint as model_file param"

    cfg = setup_cfg_gpu(cfg)

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    # NOTE(review): user-specific cache directory; consider making it configurable.
    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type,
        cfg,
        inference_only=True,
        cache_dir='/n/fs/nlp-jl5167/cache')

    # load weights from the model file
    logger.info("Loading saved model state ...")
    logger.debug("saved model keys =%s", saved_state.model_dict.keys())

    encoder.load_state_dict(saved_state.model_dict)
    encoder.to(cfg.device)
    # Encoding only: switch off train-mode behaviour such as dropout, as the
    # other encoding entry points do before running the encoder.
    encoder.eval()

    cfg.encoder.sequence_length = 512  # TODO: Passage can be cut by max_seq_length
    logger.info(f"Max seq length: {cfg.encoder.sequence_length}")

    # Dump passages
    dump_passages(cfg, encoder, tensorizer)
示例#6
0
    def __init__(self, cfg: DictConfig):
        """Create the one-for-all (retriever + reader) training components.

        Without a checkpoint, or with a ``CheckpointStateOFA`` checkpoint, a full
        one-for-all model with separate bi-encoder and reader optimizers is
        built (``self.mode == "normal"``). A legacy ``CheckpointState``
        checkpoint is accepted only for evaluation. It is converted to the
        one-for-all format and wrapped in a ``SimpleOneForAllModel`` that holds
        only the bi-encoder or only the reader, with no optimizers.

        Args:
            cfg: Training configuration. Mutated in place:
                - checkpoint encoder parameters are copied into it;
                - a legacy checkpoint disables optimizer/offset restoration,
                  gradient checkpointing and fp16;
                - biencoder clustering raises ``global_loss_buf_sz``.

        Raises:
            AssertionError: If a legacy checkpoint is combined with training
                datasets, or the ``evaluate_retriever``/``evaluate_reader``
                flags do not match the mode the checkpoint was converted to.
        """

        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # if model file is specified, encoder parameters from saved state should be used for initialization
        model_file = get_model_file(cfg, cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            # The legacy loader may return either checkpoint format; the type
            # is dispatched on below.
            saved_state = load_states_from_checkpoint_legacy(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

            if isinstance(saved_state, CheckpointStateOFA):
                # Checkpoint already in one-for-all format: same setup as a
                # fresh (checkpoint-less) run.
                self.mode = "normal"
                # Initialize everything
                gradient_checkpointing = getattr(cfg, "gradient_checkpointing", False)
                tensorizer, model, biencoder_optimizer, reader_optimizer, forward_fn = init_ofa_model(
                    cfg.encoder.encoder_model_type, cfg, gradient_checkpointing=gradient_checkpointing,
                )

            else:
                # Only allowed during evaluation-only mode
                assert isinstance(saved_state, CheckpointState)
                assert cfg.train_datasets is None or len(cfg.train_datasets) == 0
                # Convert from old state to OFA state; the returned mode tells
                # which single component ("biencoder" or otherwise reader) the
                # old checkpoint contains.
                saved_state, self.mode = convert_from_old_state_to_ofa(saved_state)

                if self.mode == "biencoder":
                    # Sanity check: a bi-encoder checkpoint can only be used to evaluate the retriever
                    assert cfg.evaluate_retriever and (not cfg.evaluate_reader)
                    # Initialize everything
                    tensorizer, biencoder, _ = init_biencoder_components(
                        cfg.encoder.encoder_model_type, cfg, inference_only=True,
                    )
                    reader = None
                else:
                    # Sanity check: a reader checkpoint can only be used to evaluate the reader
                    assert cfg.evaluate_reader and (not cfg.evaluate_retriever)
                    # Initialize everything
                    tensorizer, reader, _ = init_reader_components(
                        cfg.encoder.encoder_model_type, cfg, inference_only=True,
                    )
                    biencoder = None

                # Create a "fake" one-for-all model wrapping whichever
                # component exists (the other one is None)
                model = SimpleOneForAllModel(
                    biencoder=biencoder, reader=reader, tensorizer=tensorizer,
                )

                # Modify config: evaluation-only, so skip restoring optimizer
                # state and training offsets, and disable gradient
                # checkpointing and fp16.
                cfg.ignore_checkpoint_optimizer = True
                cfg.ignore_checkpoint_offset = True
                cfg.gradient_checkpointing = False
                cfg.fp16 = False
                # Placeholders so the shared setup code below receives the
                # same variables as in "normal" mode (no optimizers here).
                gradient_checkpointing = False
                biencoder_optimizer = None
                reader_optimizer = None
                forward_fn = ofa_simple_fw_pass  # always the simplest

        else:
            # No checkpoint: build a fresh one-for-all model.
            self.mode = "normal"
            # Initialize everything
            gradient_checkpointing = getattr(cfg, "gradient_checkpointing", False)
            tensorizer, model, biencoder_optimizer, reader_optimizer, forward_fn = init_ofa_model(
                cfg.encoder.encoder_model_type, cfg, gradient_checkpointing=gradient_checkpointing,
            )

        # Both optimizers are passed as a list and unpacked back in the same
        # order (they are None for a legacy, evaluation-only checkpoint).
        model, (biencoder_optimizer, reader_optimizer) = setup_for_distributed_mode(
            model,
            [biencoder_optimizer, reader_optimizer],
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=gradient_checkpointing,
        )

        self.forward_fn = forward_fn
        self.model = model
        self.cfg = cfg
        self.ds_cfg = OneForAllDatasetsCfg(cfg)
        self.biencoder_optimizer = biencoder_optimizer
        self.biencoder_scheduler_state = None
        self.reader_optimizer = reader_optimizer
        self.reader_scheduler_state = None
        self.clustering = cfg.biencoder.clustering
        if self.clustering:
            cfg.global_loss_buf_sz = 72000000  # this requires a lot of memory

        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.best_validation_result = None
        self.best_cp_name = None

        # Biencoder loss function (note that reader loss is automatically computed)
        self.biencoder_loss_function: BiEncoderNllLoss = init_loss(cfg.encoder.encoder_model_type, cfg)

        if saved_state:
            self._load_saved_state(saved_state)

        self.dev_iterator = None
示例#7
0
def main(cfg: DictConfig):
    """Retrieve passages for a QA dataset with a local FAISS index and validate them.

    Steps:
        1. Encode the questions of ``cfg.qa_dataset`` with the checkpoint's
           question encoder (or the sub-encoder named by ``cfg.encoder_path``).
        2. Get the index: deserialize it from ``cfg.index_path`` when it exists
           there, otherwise build it from the pre-encoded files matched by
           ``cfg.encoded_ctx_files`` and optionally serialize it back.
        3. Fetch the top ``cfg.n_docs`` passages per question.
        4. Validate the hits and optionally save results in DPR and KILT formats.

    Args:
        cfg: Hydra configuration; updated with GPU settings and the
            checkpoint's encoder parameters.

    Raises:
        RuntimeError: If no passages could be loaded from the context sources,
            or ``cfg.kilt_out_file`` is set but no KILT context source exists.
    """
    cfg = setup_cfg_gpu(cfg)
    logger.info("CFG (after gpu  configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)

    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device,
                                            cfg.n_gpu, cfg.local_rank,
                                            cfg.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    # Checkpoint keys are prefixed with the sub-model attribute name
    # (e.g. "question_model."); strip it so they match the selected encoder.
    encoder_prefix = (encoder_path if encoder_path else "question_model") + "."
    prefix_len = len(encoder_prefix)

    logger.info("Encoder state prefix %s", encoder_prefix)
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(encoder_prefix)
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    for ds_item in qa_src.data:
        questions.append(ds_item.query)
        question_answers.append(ds_item.answers)

    index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
    logger.info("Index class %s ", type(index))
    index_buffer_sz = index.buffer_size
    index.init_index(vector_size)
    retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)

    # Install the custom representation-token selector *before* encoding the
    # questions. It was previously assigned after generate_question_vectors.
    # From that point on the retriever only indexes pre-encoded passage files
    # and searches with the finished question vectors, so the assignment
    # could not influence the question representations.
    # (Assumes generate_question_vectors consults retriever.selector; confirm.)
    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, query_token=qa_src.special_query_token)

    id_prefixes = []
    ctx_sources = []
    for ctx_src in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)

    logger.info("id_prefixes per dataset: %s", id_prefixes)

    # index all passages
    ctx_files_patterns = cfg.encoded_ctx_files
    index_path = cfg.index_path

    logger.info("ctx_files_patterns: %s", ctx_files_patterns)
    if ctx_files_patterns:
        assert len(ctx_files_patterns) == len(
            id_prefixes), "ctx len={} pref leb={}".format(
                len(ctx_files_patterns), len(id_prefixes))
    else:
        assert (
            index_path
        ), "Either encoded_ctx_files or index_path parameter should be set."

    # One id prefix per embedding file, taken from the context source whose
    # pattern matched it. ``encoded_ctx_files`` may be None when only an
    # index_path is given; iterating None would raise TypeError.
    input_paths = []
    path_id_prefixes = []
    for i, pattern in enumerate(ctx_files_patterns or []):
        pattern_files = glob.glob(pattern)
        pattern_id_prefix = id_prefixes[i]
        input_paths.extend(pattern_files)
        path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))

    logger.info("Embeddings files id prefixes: %s", path_id_prefixes)

    if index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths,
                                     index_buffer_sz,
                                     path_id_prefixes=path_id_prefixes)
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(),
                                                cfg.n_docs)

    # we no longer need the index
    retriever = None

    all_passages = {}
    for ctx_src in ctx_sources:
        ctx_src.load_data_to(all_passages)

    if len(all_passages) == 0:
        raise RuntimeError(
            "No passages data found. Please specify ctx_file param properly.")

    validate_fn = validate_tables if cfg.validate_as_tables else validate
    questions_doc_hits = validate_fn(
        all_passages,
        question_answers,
        top_ids_and_scores,
        cfg.validation_workers,
        cfg.match,
    )

    if cfg.out_file:
        save_results(
            all_passages,
            questions,
            question_answers,
            top_ids_and_scores,
            questions_doc_hits,
            cfg.out_file,
        )

    if cfg.kilt_out_file:
        # KILT conversion needs the first KILT-compatible context source and
        # reads the DPR results written to cfg.out_file above.
        kilt_ctx = next(
            (ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)),
            None)
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file,
                                 cfg.kilt_out_file)
示例#8
0
def main(cfg: DictConfig):
    """Encode one shard of the passages in ``cfg.ctx_src`` and pickle the vectors.

    The passages are split into ``cfg.num_shards`` contiguous shards of
    ``ceil(total / num_shards)`` items. Shard ``cfg.shard_id`` is encoded with
    the checkpoint's ``ctx_model`` (when ``cfg.encoder_type == "ctx"``) or
    ``question_model``. The output of ``gen_ctx_vectors`` is pickled to
    ``"<cfg.out_file>_<cfg.shard_id>"``; parent directories are created as
    needed.

    Raises:
        AssertionError: If ``model_file`` or ``ctx_src`` is not set.
    """
    assert cfg.model_file, "Please specify encoder checkpoint as model_file param"
    assert cfg.ctx_src, "Please specify passages source as ctx_src param"
    logger.info("Working directory: %s", os.getcwd())
    cfg = setup_cfg_gpu(cfg)

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)

    # The attribute name of the selected sub-encoder is also the key prefix of
    # its weights in the checkpoint's model_dict.
    sub_model_name = "ctx_model" if cfg.encoder_type == "ctx" else "question_model"
    encoder = getattr(encoder, sub_model_name)

    encoder, _ = setup_for_distributed_mode(
        encoder,
        None,
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
    )
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")
    logger.debug("saved model keys =%s", saved_state.model_dict.keys())

    # Previously the "ctx_model." weights were loaded even when the question
    # encoder had been selected; load the selected sub-encoder's own weights.
    state_prefix = sub_model_name + "."
    prefix_len = len(state_prefix)
    encoder_state = {
        key[prefix_len:]: value
        for key, value in saved_state.model_dict.items()
        if key.startswith(state_prefix)
    }
    model_to_load.load_state_dict(encoder_state)

    logger.info("reading data source: %s", cfg.ctx_src)

    ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src])
    all_passages_dict = {}
    ctx_src.load_data_to(all_passages_dict)
    all_passages = list(all_passages_dict.items())

    total = len(all_passages)
    shard_size = math.ceil(total / cfg.num_shards)
    start_idx = cfg.shard_id * shard_size
    # Clamp so the logged range does not run past the last passage; slicing
    # below is unaffected.
    end_idx = min(start_idx + shard_size, total)

    logger.info(
        "Producing encodings for passages range: %d to %d (out of total %d)",
        start_idx,
        end_idx,
        total,
    )
    shard_passages = all_passages[start_idx:end_idx]

    data = gen_ctx_vectors(cfg, shard_passages, encoder, tensorizer, True)

    out_path = cfg.out_file + "_" + str(cfg.shard_id)
    pathlib.Path(os.path.dirname(out_path)).mkdir(parents=True, exist_ok=True)
    logger.info("Writing results to %s", out_path)
    with open(out_path, mode="wb") as f:
        pickle.dump(data, f)

    logger.info("Total passages processed %d. Written to %s", len(data), out_path)
示例#9
0
def _instantiate_ctx_sources(cfg: DictConfig):
    """Instantiate every context source listed in ``cfg.ctx_datatsets``.

    Returns:
        ``(id_prefixes, ctx_sources)``: each source's ``id_prefix`` and the
        instantiated sources, both in config order.
    """
    id_prefixes = []
    ctx_sources = []
    for ctx_src in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)
        logger.info("ctx_sources: %s", type(ctx_src))

    logger.info("id_prefixes per dataset: %s", id_prefixes)
    return id_prefixes, ctx_sources


def main(cfg: DictConfig):
    """Retrieve passages for a QA dataset (remote or local index) and validate them.

    Retrieval uses a ``DenseRPCRetriever`` when ``cfg.rpc_retriever_cfg_file``
    is set, otherwise a local FAISS index. The index is obtained in one of
    three ways:
        - loaded by ``cfg.rpc_index_id`` (RPC);
        - deserialized from ``cfg.index_path`` (local, when it exists);
        - built from the pre-encoded files in ``cfg.encoded_ctx_files``.

    Args:
        cfg: Hydra configuration; updated with GPU settings and the
            checkpoint's encoder parameters.

    Raises:
        RuntimeError: If ``cfg.kilt_out_file`` is set but none of the context
            sources is a ``KiltCsvCtxSrc``.
    """
    cfg = setup_cfg_gpu(cfg)
    saved_state = load_states_from_checkpoint(cfg.model_file)

    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG (after gpu  configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)

    logger.info("Loading saved model state ...")
    encoder.load_state(saved_state, strict=False)

    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device,
                                            cfg.n_gpu, cfg.local_rank,
                                            cfg.fp16)
    encoder.eval()

    model_to_load = get_model_obj(encoder)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    questions_text = []  # never filled here; save_results falls back to `questions`
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    total_queries = len(qa_src)
    for i in range(total_queries):
        qa_sample = qa_src[i]
        question, answers = qa_sample.query, qa_sample.answers
        questions.append(question)
        question_answers.append(answers)

    logger.info("questions len %d", len(questions))
    logger.info("questions_text len %d", len(questions_text))

    # Only the local FAISS path creates an index object. Initialising it to
    # None avoids a NameError when the RPC retriever is used without
    # rpc_index_id but with an index_path.
    index = None
    if cfg.rpc_retriever_cfg_file:
        index_buffer_sz = 1000
        retriever = DenseRPCRetriever(
            encoder,
            cfg.batch_size,
            tensorizer,
            cfg.rpc_retriever_cfg_file,
            vector_size,
            use_l2_conversion=cfg.use_l2_conversion,
        )
    else:
        index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
        logger.info("Local Index class %s ", type(index))
        index_buffer_sz = index.buffer_size
        index.init_index(vector_size)
        retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer,
                                        index)

    # Install the custom representation-token selector *before* encoding the
    # questions. Assigned afterwards (as before), it could not affect them:
    # from there on only pre-encoded passages and finished question vectors
    # are used. (Assumes generate_question_vectors consults
    # retriever.selector; confirm.)
    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, query_token=qa_src.special_query_token)

    # Context sources are instantiated while building the index, or lazily
    # below when validation/KILT needs them. Previously they were only
    # defined on the indexing path, so loading an existing index led to a
    # NameError later.
    ctx_sources = None

    index_path = cfg.index_path
    if cfg.rpc_retriever_cfg_file and cfg.rpc_index_id:
        retriever.load_index(cfg.rpc_index_id)
    elif index_path and index is not None and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        # send data for indexing
        id_prefixes, ctx_sources = _instantiate_ctx_sources(cfg)

        # index all passages
        ctx_files_patterns = cfg.encoded_ctx_files

        logger.info("ctx_files_patterns: %s", ctx_files_patterns)
        if ctx_files_patterns:
            assert len(ctx_files_patterns) == len(
                id_prefixes), "ctx len={} pref leb={}".format(
                    len(ctx_files_patterns), len(id_prefixes))
        else:
            assert (
                index_path or cfg.rpc_index_id
            ), "Either encoded_ctx_files or index_path pr rpc_index_id parameter should be set."

        # One id prefix per embedding file; ``encoded_ctx_files`` may be None,
        # which would make enumerate() raise TypeError.
        input_paths = []
        path_id_prefixes = []
        for i, pattern in enumerate(ctx_files_patterns or []):
            pattern_files = glob.glob(pattern)
            pattern_id_prefix = id_prefixes[i]
            input_paths.extend(pattern_files)
            path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))
        logger.info("Embeddings files id prefixes: %s", path_id_prefixes)
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths,
                                     index_buffer_sz,
                                     path_id_prefixes=path_id_prefixes)
        # Only a local FAISS index can be serialized to disk.
        if index_path and index is not None:
            retriever.index.serialize(index_path)

    # get top k results
    top_results_and_scores = retriever.get_top_docs(questions_tensor.numpy(),
                                                    cfg.n_docs)

    if cfg.use_rpc_meta:
        questions_doc_hits = validate_from_meta(
            question_answers,
            top_results_and_scores,
            cfg.validation_workers,
            cfg.match,
            cfg.rpc_meta_compressed,
        )
        if cfg.out_file:
            save_results_from_meta(
                questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
                cfg.rpc_meta_compressed,
            )
    else:
        if ctx_sources is None:
            _, ctx_sources = _instantiate_ctx_sources(cfg)
        all_passages = get_all_passages(ctx_sources)
        validate_fn = validate_tables if cfg.validate_as_tables else validate
        questions_doc_hits = validate_fn(
            all_passages,
            question_answers,
            top_results_and_scores,
            cfg.validation_workers,
            cfg.match,
        )

        if cfg.out_file:
            save_results(
                all_passages,
                questions_text if questions_text else questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
            )

    if cfg.kilt_out_file:
        if ctx_sources is None:
            _, ctx_sources = _instantiate_ctx_sources(cfg)
        kilt_ctx = next(
            (ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)),
            None)
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file,
                                 cfg.kilt_out_file)
示例#10
0
    def __init__(self, cfg: DictConfig):
        """Create the bi-encoder training components, optionally set up for DeepSpeed.

        When ``cfg.deepspeed`` is truthy, the model is cast to half precision
        and its optimizer is replaced by ``DeepSpeedCPUAdam`` over the same
        parameter groups before distributed setup.

        Args:
            cfg: Training configuration; stored as ``self.cfg``. If a
                checkpoint file is found, its encoder parameters are copied into
                ``cfg`` before the model is built, and the checkpoint is applied
                via ``_load_saved_state`` at the end.
        """
        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # if model file is specified, encoder parameters from saved state should be used for initialization
        model_file = get_model_file(cfg, cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

        tensorizer, model, optimizer = init_biencoder_components(
            cfg.encoder.encoder_model_type, cfg
        )

        if cfg.deepspeed:
            # Cast the model's floating-point parameters to fp16 in place.
            model.half()

            # Rebuild the optimizer as DeepSpeed's CPU Adam over the parameter
            # groups created by init_biencoder_components.
            # NOTE(review): torch-style optimizers use constructor kwargs only
            # as defaults for keys missing from a group; groups copied from an
            # existing optimizer likely already carry "lr"/"weight_decay", so
            # these kwargs may have no effect - confirm.
            optimizer = DeepSpeedCPUAdam(optimizer.param_groups, lr=cfg.train.learning_rate,
                    weight_decay=cfg.train.weight_decay)

        model, optimizer = setup_for_distributed_mode(
            model,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
        )

        self.biencoder = model
        self.optimizer = optimizer
        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)

        if saved_state:
            self._load_saved_state(saved_state)

        self.dev_iterator = None