Exemplo n.º 1
0
    def __init__(self, args):
        """Set up the reader model, optimizer and tensorizer for training.

        When ``get_model_file`` resolves a checkpoint, its state is loaded and
        its encoder parameters are applied to ``args`` before the components
        are created; the loaded state is then handed to
        ``_load_saved_state`` after the default training counters are set.

        Args:
            args: Run configuration. This method reads ``local_rank``,
                ``distributed_world_size``, ``checkpoint_file_name``,
                ``encoder_model_type``, ``device``, ``n_gpu``, ``fp16`` and
                ``fp16_opt_level`` from it.
        """
        self.args = args

        # A local rank of -1 is mapped to shard 0.
        rank = args.local_rank
        self.shard_id = 0 if rank == -1 else rank
        self.distributed_factor = args.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        checkpoint_path = get_model_file(args, args.checkpoint_file_name)
        if checkpoint_path:
            saved_state = load_states_from_checkpoint(checkpoint_path)
            set_encoder_params_from_state(saved_state.encoder_params, args)
        else:
            saved_state = None

        tensorizer, raw_reader, raw_optimizer = init_reader_components(
            args.encoder_model_type, args)

        self.reader, self.optimizer = setup_for_distributed_mode(
            raw_reader,
            raw_optimizer,
            args.device,
            args.n_gpu,
            args.local_rank,
            args.fp16,
            args.fp16_opt_level,
        )
        self.tensorizer = tensorizer

        # Defaults for a fresh run; set before _load_saved_state is called.
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None

        if saved_state:
            self._load_saved_state(saved_state)
Exemplo n.º 2
0
def main(args):
    """Encode one shard of a TSV passage file with the context encoder.

    Loads the checkpoint at ``args.model_file``, copies its ``ctx_model.*``
    weights into the context encoder, reads ``(id, text, title)`` rows from
    ``args.ctx_file``, encodes shard ``args.shard_id`` out of
    ``args.num_shards`` with ``gen_ctx_vectors`` and pickles the result to
    ``<args.out_file>_<args.shard_id>``.

    Args:
        args: Run configuration. Besides the fields above, this function
            reads ``encoder_model_type``, ``device``, ``n_gpu``,
            ``local_rank``, ``fp16`` and ``fp16_opt_level``.
    """
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    print_args(args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    encoder = encoder.ctx_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16, args.fp16_opt_level)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')
    logger.debug('saved model keys =%s', saved_state.model_dict.keys())

    # Keep only the context-encoder entries and strip their prefix.
    prefix = 'ctx_model.'
    prefix_len = len(prefix)
    ctx_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(prefix)
    }
    model_to_load.load_state_dict(ctx_state)

    logger.info('reading data from file=%s', args.ctx_file)

    # newline='' lets the csv module do its own newline handling, as its
    # documentation requires for file objects.
    with open(args.ctx_file, newline='') as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        # file format: doc_id, doc_text, title (header row starts with 'id')
        rows = [(row[0], row[1], row[2]) for row in reader if row[0] != 'id']

    # Ceiling division: flooring here would silently drop the trailing
    # len(rows) % num_shards passages, which no shard would ever encode.
    total_rows = len(rows)
    shard_size = -(-total_rows // args.num_shards)
    start_idx = args.shard_id * shard_size
    end_idx = min(start_idx + shard_size, total_rows)

    logger.info(
        'Producing encodings for passages range: %d to %d (out of total %d)',
        start_idx, end_idx, total_rows)
    rows = rows[start_idx:end_idx]

    data = gen_ctx_vectors(rows, encoder, tensorizer, True)

    out_path = args.out_file + '_' + str(args.shard_id)
    pathlib.Path(os.path.dirname(out_path)).mkdir(parents=True, exist_ok=True)
    logger.info('Writing results to %s', out_path)
    with open(out_path, mode='wb') as f:
        pickle.dump(data, f)

    logger.info('Total passages processed %d. Written to %s', len(data),
                out_path)
Exemplo n.º 3
0
def get_retriever(args):
    """Build a ``DenseRetriever`` around the checkpoint's question encoder.

    The checkpoint at ``args.model_file`` is loaded, its
    ``question_model.*`` weights are copied into the question encoder, and an
    index (HNSW when ``args.hnsw_index`` is set, flat otherwise) is attached.
    When ``args.encoded_ctx_file`` is given, the matching files are indexed,
    or a previously serialized index is loaded when
    ``args.save_or_load_index`` is set and the index path exists.

    Args:
        args: Run configuration. Reads ``model_file``, ``encoder_model_type``,
            ``device``, ``n_gpu``, ``local_rank``, ``fp16``, ``index_buffer``,
            ``hnsw_index``, ``batch_size``, ``encoded_ctx_file`` and
            ``save_or_load_index``.

    Returns:
        The configured ``DenseRetriever``.

    Raises:
        RuntimeError: If ``args.encoded_ctx_file`` is set but matches no file.
    """
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    encoder = encoder.question_model
    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')

    # Keep only the question-encoder entries and strip their prefix.
    prefix = 'question_model.'
    prefix_len = len(prefix)
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(prefix)
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info('Encoder vector_size=%d', vector_size)

    index_buffer_sz = args.index_buffer
    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size)
        index_buffer_sz = -1  # encode all at once
    else:
        index = DenseFlatIndexer(vector_size)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index,
                               args.device)

    # index all passages; truthiness also covers encoded_ctx_file=None
    if args.encoded_ctx_file:
        ctx_files_pattern = args.encoded_ctx_file
        input_paths = glob.glob(ctx_files_pattern)
        if not input_paths:
            # Fail with a clear message instead of an IndexError below.
            raise RuntimeError(
                'No encoded passage files match pattern: %s'
                % ctx_files_pattern)

        # The index path is the first matched file name with its last
        # "_"-separated component removed.
        index_path = "_".join(input_paths[0].split("_")[:-1])
        if args.save_or_load_index and os.path.exists(index_path):
            retriever.index.deserialize(index_path)
        else:
            logger.info('Reading all passages data from files: %s',
                        input_paths)
            retriever.index_encoded_data(input_paths,
                                         buffer_size=index_buffer_sz)
            if args.save_or_load_index:
                retriever.index.serialize(index_path)
    return retriever
Exemplo n.º 4
0
    def __init__(self, name, **config):
        """Create a retriever wrapper backed by the DPR question encoder.

        ``config`` is turned into an ``argparse.Namespace`` stored on
        ``self.args``. The checkpoint at ``model_file`` is loaded, its
        ``question_model.*`` weights are copied into the question encoder,
        and a ``DenseRetriever`` is built on either an HNSW index
        deserialized from ``hnsw_index_path`` or a flat index filled from the
        files matching ``encoded_ctx_file``.

        Args:
            name: Passed through to the parent class initializer.
            **config: Settings read here: ``model_file``,
                ``encoder_model_type``, ``device``, ``n_gpu``, ``local_rank``,
                ``fp16``, ``index_buffer``, ``hnsw_index``,
                ``hnsw_index_path``, ``batch_size``, ``encoded_ctx_file``,
                ``ctx_file`` and ``KILT_mapping``.
        """
        super().__init__(name)

        self.args = argparse.Namespace(**config)
        saved_state = load_states_from_checkpoint(self.args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, self.args)
        tensorizer, encoder, _ = init_biencoder_components(
            self.args.encoder_model_type, self.args, inference_only=True)
        encoder = encoder.question_model
        encoder, _ = setup_for_distributed_mode(
            encoder,
            None,
            self.args.device,
            self.args.n_gpu,
            self.args.local_rank,
            self.args.fp16,
        )
        encoder.eval()

        # load weights from the model file
        model_to_load = get_model_obj(encoder)

        # Keep only the question-encoder entries and strip their prefix.
        prefix_len = len("question_model.")
        question_encoder_state = {
            key[prefix_len:]: value
            for (key, value) in saved_state.model_dict.items()
            if key.startswith("question_model.")
        }
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()

        index_buffer_sz = self.args.index_buffer
        if self.args.hnsw_index:
            # The HNSW index is loaded ready-made from disk.
            index = DenseHNSWFlatIndexer(vector_size)
            index.deserialize_from(self.args.hnsw_index_path)
        else:
            index = DenseFlatIndexer(vector_size)

        self.retriever = DenseRetriever(encoder, self.args.batch_size,
                                        tensorizer, index)

        # index all passages
        ctx_files_pattern = self.args.encoded_ctx_file
        input_paths = glob.glob(ctx_files_pattern)

        # Only the flat index has to be populated from the encoded files.
        if not self.args.hnsw_index:
            self.retriever.index_encoded_data(input_paths,
                                              buffer_size=index_buffer_sz)

        # not needed for now
        self.all_passages = load_passages(self.args.ctx_file)

        self.KILT_mapping = None
        if self.args.KILT_mapping:
            # NOTE: unpickling can execute arbitrary code; only point
            # KILT_mapping at trusted files. The context manager closes the
            # handle that the previous bare open() call leaked.
            with open(self.args.KILT_mapping, "rb") as mapping_file:
                self.KILT_mapping = pickle.load(mapping_file)
Exemplo n.º 5
0
    def __init__(self, cfg: DictConfig):
        """Set up the bi-encoder, optimizer, tensorizer and loss for training.

        When ``get_model_file`` resolves a checkpoint, its encoder parameters
        are applied to ``cfg`` before the components are built, and the
        loaded state is handed to ``_load_saved_state`` once the default
        training counters are set. ``cfg.others.is_matching`` records whether
        the created model is a ``Match_BiEncoder`` or
        ``MatchGated_BiEncoder``.

        Args:
            cfg: Run configuration; it is modified in place (``others`` and,
                when ``cfg.clustering`` is truthy, ``global_loss_buf_sz``).
        """
        # A local rank of -1 is mapped to shard 0.
        local_rank = cfg.local_rank
        self.shard_id = 0 if local_rank == -1 else local_rank
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # A resolved checkpoint overrides the encoder parameters in cfg
        # before any component is created.
        checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
        if checkpoint_path:
            saved_state = load_states_from_checkpoint(checkpoint_path)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)
        else:
            saved_state = None

        use_grad_ckpt = getattr(cfg, "gradient_checkpointing", False)
        tensorizer, biencoder, optimizer = init_biencoder_components(
            cfg.encoder.encoder_model_type,
            cfg,
            gradient_checkpointing=use_grad_ckpt,
        )

        is_matching = isinstance(biencoder,
                                 (Match_BiEncoder, MatchGated_BiEncoder))
        with omegaconf.open_dict(cfg):
            cfg.others = DictConfig({"is_matching": is_matching})

        self.biencoder, self.optimizer = setup_for_distributed_mode(
            biencoder,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=use_grad_ckpt,
        )
        self.tensorizer = tensorizer

        # Defaults for a fresh run; set before _load_saved_state is called.
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None

        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)
        self.loss_function = init_loss(cfg.encoder.encoder_model_type, cfg)

        self.clustering = cfg.clustering
        if self.clustering:
            cfg.global_loss_buf_sz = 80000000  # this requires a lot of memory

        if saved_state:
            self._load_saved_state(saved_state)

        self.dev_iterator = None
def main(args):
    """Encode papers from a JSON file and write one embedding per line.

    Loads the checkpoint at ``args.model_file``, copies its ``ctx_model.*``
    weights into the context encoder, reads the papers stored as values of
    the JSON object in ``args.ctx_file``, encodes ``(paper_id, abstract,
    title)`` rows with ``gen_ctx_vectors`` and writes JSON lines of the form
    ``{"paper_id": ..., "embedding": [...]}`` to ``args.out_file``.

    Args:
        args: Run configuration. Besides the fields above, this function
            reads ``encoder_model_type``, ``device``, ``n_gpu``,
            ``local_rank``, ``fp16`` and ``fp16_opt_level``.
    """
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    print_args(args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    encoder = encoder.ctx_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16, args.fp16_opt_level)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')
    logger.debug('saved model keys =%s', saved_state.model_dict.keys())

    # Keep only the context-encoder entries and strip their prefix.
    prefix = 'ctx_model.'
    prefix_len = len(prefix)
    ctx_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(prefix)
    }
    model_to_load.load_state_dict(ctx_state)

    logger.info('reading data from file=%s', args.ctx_file)

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(args.ctx_file, encoding='utf-8') as jsonfile:
        blob = json.load(jsonfile)

    # Row layout is (id, text, title); a missing or null title/abstract
    # becomes an empty string.
    rows = [
        (paper["paper_id"], paper.get("abstract") or "",
         paper.get("title") or "")
        for paper in blob.values()
    ]

    data = gen_ctx_vectors(rows, encoder, tensorizer, True)

    out_path = args.out_file
    pathlib.Path(os.path.dirname(out_path)).mkdir(parents=True, exist_ok=True)
    logger.info('Writing results to %s', out_path)
    # The file is only written, so plain "w" is enough.
    with open(out_path, "w", encoding='utf-8') as f:
        for (s2_id, vec) in data:
            out = {"paper_id": s2_id, "embedding": vec.tolist()}
            f.write(json.dumps(out) + "\n")

    logger.info('Total passages processed %d. Written to %s', len(data),
                out_path)
Exemplo n.º 7
0
    def __init__(self, cfg: DictConfig, save_temp_conf=False):
        """Set up the bi-encoder, optimizer and tensorizer for training.

        When ``get_model_file`` resolves a checkpoint, its encoder parameters
        are applied to ``cfg`` before the components are built, and the
        loaded state is handed to ``_load_saved_state`` once the default
        training counters are set.

        Args:
            cfg: Run configuration.
            save_temp_conf: When truthy, ``cfg`` is pickled to the fixed path
                ``/content/drive/MyDrive/conf.pickle`` before anything else
                is initialized.
        """
        # A local rank of -1 is mapped to shard 0.
        local_rank = cfg.local_rank
        self.shard_id = 0 if local_rank == -1 else local_rank
        self.distributed_factor = cfg.distributed_world_size or 1

        if save_temp_conf:
            import pickle
            with open('/content/drive/MyDrive/conf.pickle', 'wb') as conf_dump:
                pickle.dump(cfg, conf_dump)

        logger.info("***** Initializing components for training *****")

        # graded dataset settings
        if cfg.binary_trainer:
            self.trainer_type = 'binary'
        else:
            self.trainer_type = 'graded'
        logger.info(f'trainer_type: {self.trainer_type}')
        self.relations = cfg.relations
        logger.info(f'relations: {self.relations}')

        # A resolved checkpoint overrides the encoder parameters in cfg
        # before any component is created.
        checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
        if checkpoint_path:
            saved_state = load_states_from_checkpoint(checkpoint_path)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)
        else:
            saved_state = None

        tensorizer, biencoder, optimizer = init_biencoder_components(
            cfg.encoder.encoder_model_type, cfg
        )

        self.biencoder, self.optimizer = setup_for_distributed_mode(
            biencoder,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
        )
        self.tensorizer = tensorizer

        # Defaults for a fresh run; set before _load_saved_state is called.
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None

        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)

        if saved_state:
            self._load_saved_state(saved_state)

        self.dev_iterator = None
Exemplo n.º 8
0
    def __init__(self, cfg: DictConfig):
        """Set up the reader model, optimizer and tensorizer for training.

        When ``get_model_file`` resolves a checkpoint, its encoder parameters
        are applied to ``cfg`` before the components are built, and the
        loaded state is handed to ``_load_saved_state`` once the default
        training counters are set.

        Args:
            cfg: Run configuration; the optional ``gradient_checkpointing``
                and ``debugging`` entries default to ``False`` when absent.
        """
        self.cfg = cfg

        # A local rank of -1 is mapped to shard 0.
        local_rank = cfg.local_rank
        self.shard_id = 0 if local_rank == -1 else local_rank
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
        if checkpoint_path:
            saved_state = load_states_from_checkpoint(checkpoint_path)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)
        else:
            saved_state = None

        use_grad_ckpt = getattr(cfg, "gradient_checkpointing", False)
        tensorizer, raw_reader, raw_optimizer = init_reader_components(
            cfg.encoder.encoder_model_type,
            cfg,
            gradient_checkpointing=use_grad_ckpt,
        )

        self.reader, self.optimizer = setup_for_distributed_mode(
            raw_reader,
            raw_optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=use_grad_ckpt,
        )
        self.tensorizer = tensorizer
        self.debugging = getattr(cfg, "debugging", False)
        self.wiki_data = None
        self.dev_iterator = None

        # Defaults for a fresh run; set before _load_saved_state is called.
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None

        if saved_state:
            self._load_saved_state(saved_state)
Exemplo n.º 9
0
    def __init__(self, cfg: DictConfig):
        """Set up the bi-encoder, optimizer and tensorizer for training.

        When ``get_model_file`` resolves a checkpoint, its encoder parameters
        are applied to ``cfg`` before the components are built, and the
        loaded state is handed to ``_load_saved_state`` once the default
        training counters are set.

        Args:
            cfg: Run configuration.
        """
        # A local rank of -1 is mapped to shard 0.
        local_rank = cfg.local_rank
        self.shard_id = 0 if local_rank == -1 else local_rank
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # A resolved checkpoint overrides the encoder parameters in cfg
        # before any component is created.
        checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
        if checkpoint_path:
            saved_state = load_states_from_checkpoint(checkpoint_path)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)
        else:
            saved_state = None

        tensorizer, biencoder, optimizer = init_biencoder_components(
            cfg.encoder.encoder_model_type, cfg
        )

        self.biencoder, self.optimizer = setup_for_distributed_mode(
            biencoder,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
        )
        self.tensorizer = tensorizer

        # Defaults for a fresh run; set before _load_saved_state is called.
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None

        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)

        if saved_state:
            self._load_saved_state(saved_state)

        self.dev_iterator = None
Exemplo n.º 10
0
def main(cfg: DictConfig):
    """Retrieve top passages for a QA dataset, then validate and save them.

    The function loads the checkpoint at ``cfg.model_file``, restores the
    selected encoder (``cfg.encoder_path`` or the question encoder), encodes
    the questions of ``cfg.qa_dataset``, builds or loads the passage index,
    retrieves ``cfg.n_docs`` passages per question, validates the hits and
    optionally writes results to ``cfg.out_file`` and ``cfg.kilt_out_file``.

    Args:
        cfg: Run configuration.

    Raises:
        RuntimeError: If no passages could be loaded, or if KILT output is
            requested without a ``KiltCsvCtxSrc`` context source.
        AssertionError: If the number of encoded-file patterns differs from
            the number of context sources, or if neither
            ``encoded_ctx_files`` nor ``index_path`` is set.
    """
    cfg = setup_cfg_gpu(cfg)
    logger.info("CFG (after gpu  configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)

    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device,
                                            cfg.n_gpu, cfg.local_rank,
                                            cfg.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    # The state-dict prefix follows the selected sub-model's attribute name.
    encoder_prefix = (encoder_path if encoder_path else "question_model") + "."
    prefix_len = len(encoder_prefix)

    logger.info("Encoder state prefix %s", encoder_prefix)
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(encoder_prefix)
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    for ds_item in qa_src.data:
        questions.append(ds_item.query)
        question_answers.append(ds_item.answers)

    index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
    logger.info("Index class %s ", type(index))
    index_buffer_sz = index.buffer_size
    index.init_index(vector_size)
    retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, query_token=qa_src.special_query_token)

    # NOTE(review): the selector is attached after the question vectors have
    # already been generated; confirm whether it is meant to affect them.
    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    # "ctx_datatsets" is the spelling of the configuration key read here.
    id_prefixes = []
    ctx_sources = []
    for ctx_src_key in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src_key])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)

    logger.info("id_prefixes per dataset: %s", id_prefixes)

    # index all passages
    ctx_files_patterns = cfg.encoded_ctx_files
    index_path = cfg.index_path

    logger.info("ctx_files_patterns: %s", ctx_files_patterns)
    if ctx_files_patterns:
        assert len(ctx_files_patterns) == len(
            id_prefixes), "ctx len={} pref leb={}".format(
                len(ctx_files_patterns), len(id_prefixes))
    else:
        assert (
            index_path
        ), "Either encoded_ctx_files or index_path parameter should be set."

    # encoded_ctx_files may be unset (None) when only index_path is given;
    # treat that as "no files" instead of failing inside enumerate().
    input_paths = []
    path_id_prefixes = []
    for i, pattern in enumerate(ctx_files_patterns or []):
        pattern_files = glob.glob(pattern)
        pattern_id_prefix = id_prefixes[i]
        input_paths.extend(pattern_files)
        path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))

    logger.info("Embeddings files id prefixes: %s", path_id_prefixes)

    if index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths,
                                     index_buffer_sz,
                                     path_id_prefixes=path_id_prefixes)
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(),
                                                cfg.n_docs)

    # we no longer need the index
    retriever = None

    all_passages = {}
    for ctx_src in ctx_sources:
        ctx_src.load_data_to(all_passages)

    if len(all_passages) == 0:
        raise RuntimeError(
            "No passages data found. Please specify ctx_file param properly.")

    if cfg.validate_as_tables:
        questions_doc_hits = validate_tables(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )
    else:
        questions_doc_hits = validate(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )

    if cfg.out_file:
        save_results(
            all_passages,
            questions,
            question_answers,
            top_ids_and_scores,
            questions_doc_hits,
            cfg.out_file,
        )

    if cfg.kilt_out_file:
        # Use the first KILT-compatible context source, if there is one.
        kilt_ctx = next(
            (ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)),
            None)
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file,
                                 cfg.kilt_out_file)
Exemplo n.º 11
0
def main(args):
    """Retrieve the top passages for the questions in a QA file.

    Loads the checkpoint at ``args.model_file``, copies its
    ``question_model.*`` weights into the question encoder, builds or loads
    the passage index from the files matching ``args.encoded_ctx_file``,
    retrieves ``args.n_docs`` passages per question of ``args.qa_file`` and,
    when ``args.out_file`` is set, saves question ids with the retrieved ids
    and scores.

    Args:
        args: Run configuration. Besides the fields above, this function
            reads ``encoder_model_type``, ``device``, ``n_gpu``,
            ``local_rank``, ``fp16``, ``hnsw_index``, ``index_buffer``,
            ``batch_size`` and ``save_or_load_index``.

    Raises:
        RuntimeError: If ``args.encoded_ctx_file`` matches no file.
    """
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(
        args.encoder_model_type, args, inference_only=True)

    encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')

    # Keep only the question-encoder entries and strip their prefix.
    prefix = 'question_model.'
    prefix_len = len(prefix)
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(prefix)
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info('Encoder vector_size=%d', vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, args.index_buffer)
    else:
        index = DenseFlatIndexer(vector_size, args.index_buffer)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    if not input_paths:
        # Fail with a clear message instead of an IndexError below.
        raise RuntimeError(
            'No encoded passage files match pattern: %s' % ctx_files_pattern)

    # The index path is the first matched file name with its last
    # "_"-separated component removed.
    index_path = "_".join(input_paths[0].split("_")[:-1])
    index_on_disk = (os.path.exists(index_path)
                     or os.path.exists(index_path + ".index.dpr"))
    if args.save_or_load_index and index_on_disk:
        retriever.index.deserialize_from(index_path)
    else:
        logger.info('Reading all passages data from files: %s', input_paths)
        retriever.index.index_data(input_paths)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)

    # get questions and their ids
    questions = []
    question_ids = []
    for question_id, question in parse_qa_csv_file(args.qa_file):
        question_ids.append(question_id)
        questions.append(question)

    questions_tensor = retriever.generate_question_vectors(questions)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(
        questions_tensor.numpy(), args.n_docs)

    # No answer validation is performed in this script; only the retrieved
    # ids and scores are saved.
    if args.out_file:
        save_results(question_ids,
                     top_ids_and_scores, args.out_file)
Exemplo n.º 12
0
def main(cfg: DictConfig):
    """Encode one shard of a passage source with the selected encoder.

    Loads the checkpoint at ``cfg.model_file``, selects the context encoder
    when ``cfg.encoder_type == "ctx"`` and the question encoder otherwise,
    restores that encoder's weights, loads all passages of ``cfg.ctx_src``,
    encodes shard ``cfg.shard_id`` out of ``cfg.num_shards`` with
    ``gen_ctx_vectors`` and pickles the result to
    ``<cfg.out_file>_<cfg.shard_id>``.

    Args:
        cfg: Run configuration.

    Raises:
        AssertionError: If ``cfg.model_file`` or ``cfg.ctx_src`` is not set.
    """

    assert cfg.model_file, "Please specify encoder checkpoint as model_file param"
    assert cfg.ctx_src, "Please specify passages source as ctx_src param"
    print(os.getcwd())
    cfg = setup_cfg_gpu(cfg)

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)

    use_ctx_encoder = cfg.encoder_type == "ctx"
    encoder = encoder.ctx_model if use_ctx_encoder else encoder.question_model

    encoder, _ = setup_for_distributed_mode(
        encoder,
        None,
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
    )
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")
    logger.debug("saved model keys =%s", saved_state.model_dict.keys())

    # The weights must come from the same sub-model that was selected above;
    # always taking "ctx_model." would load the context encoder's weights
    # into the question encoder.
    state_prefix = "ctx_model." if use_ctx_encoder else "question_model."
    prefix_len = len(state_prefix)
    encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(state_prefix)
    }
    model_to_load.load_state_dict(encoder_state)

    logger.info("reading data source: %s", cfg.ctx_src)

    ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src])
    all_passages_dict = {}
    ctx_src.load_data_to(all_passages_dict)
    all_passages = list(all_passages_dict.items())

    # Ceiling division so the last shard picks up the remainder.
    shard_size = math.ceil(len(all_passages) / cfg.num_shards)
    start_idx = cfg.shard_id * shard_size
    end_idx = start_idx + shard_size

    logger.info(
        "Producing encodings for passages range: %d to %d (out of total %d)",
        start_idx,
        end_idx,
        len(all_passages),
    )
    shard_passages = all_passages[start_idx:end_idx]

    data = gen_ctx_vectors(cfg, shard_passages, encoder, tensorizer, True)

    out_path = cfg.out_file + "_" + str(cfg.shard_id)
    pathlib.Path(os.path.dirname(out_path)).mkdir(parents=True, exist_ok=True)
    logger.info("Writing results to %s", out_path)
    with open(out_path, mode="wb") as f:
        pickle.dump(data, f)

    logger.info("Total passages processed %d. Written to %s", len(data),
                out_path)
Exemplo n.º 13
0
def _instantiate_ctx_sources(cfg):
    """Instantiate every context source named in ``cfg.ctx_datatsets``."""
    # "ctx_datatsets" is the spelling of the configuration key read here.
    sources = []
    for ctx_src_key in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src_key])
        sources.append(ctx_src)
        logger.info("ctx_sources: %s", type(ctx_src))
    return sources


def main(cfg: DictConfig):
    """Retrieve top passages for a QA dataset, then validate and save them.

    The function loads the checkpoint at ``cfg.model_file``, restores the
    bi-encoder state, selects ``cfg.encoder_path`` or the question encoder,
    encodes the questions of ``cfg.qa_dataset`` and retrieves ``cfg.n_docs``
    passages per question through either a ``DenseRPCRetriever`` (when
    ``cfg.rpc_retriever_cfg_file`` is set) or a ``LocalFaissRetriever``. The
    index is loaded by RPC id, deserialized from ``cfg.index_path``, or built
    from the encoded context files. Hits are validated and optionally written
    to ``cfg.out_file`` and ``cfg.kilt_out_file``.

    Args:
        cfg: Run configuration.

    Raises:
        RuntimeError: If KILT output is requested without a
            ``KiltCsvCtxSrc`` context source.
        AssertionError: If the number of encoded-file patterns differs from
            the number of context sources, or if none of
            ``encoded_ctx_files``, ``index_path`` and ``rpc_index_id`` is set
            when an index has to be built.
    """
    cfg = setup_cfg_gpu(cfg)
    saved_state = load_states_from_checkpoint(cfg.model_file)

    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG (after gpu  configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)

    logger.info("Loading saved model state ...")
    encoder.load_state(saved_state, strict=False)

    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device,
                                            cfg.n_gpu, cfg.local_rank,
                                            cfg.fp16)
    encoder.eval()

    model_to_load = get_model_obj(encoder)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    questions_text = []  # never filled here; save_results falls back to questions
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    total_queries = len(qa_src)
    for i in range(total_queries):
        qa_sample = qa_src[i]
        questions.append(qa_sample.query)
        question_answers.append(qa_sample.answers)

    logger.info("questions len %d", len(questions))
    logger.info("questions_text len %d", len(questions_text))

    # Only the local FAISS path creates an index object; keep the name bound
    # on the RPC path so the checks below cannot hit an unbound local.
    index = None
    if cfg.rpc_retriever_cfg_file:
        index_buffer_sz = 1000
        retriever = DenseRPCRetriever(
            encoder,
            cfg.batch_size,
            tensorizer,
            cfg.rpc_retriever_cfg_file,
            vector_size,
            use_l2_conversion=cfg.use_l2_conversion,
        )
    else:
        index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
        logger.info("Local Index class %s ", type(index))
        index_buffer_sz = index.buffer_size
        index.init_index(vector_size)
        retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer,
                                        index)

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, query_token=qa_src.special_query_token)

    # NOTE(review): the selector is attached after the question vectors have
    # already been generated; confirm whether it is meant to affect them.
    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    # Context sources are created lazily: while building an index here, or
    # later if passages / KILT conversion turn out to need them.
    ctx_sources = None

    index_path = cfg.index_path
    if cfg.rpc_retriever_cfg_file and cfg.rpc_index_id:
        retriever.load_index(cfg.rpc_index_id)
    elif index_path and index is not None and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        # send data for indexing
        ctx_sources = _instantiate_ctx_sources(cfg)
        id_prefixes = [ctx_src.id_prefix for ctx_src in ctx_sources]

        logger.info("id_prefixes per dataset: %s", id_prefixes)

        # index all passages
        ctx_files_patterns = cfg.encoded_ctx_files

        logger.info("ctx_files_patterns: %s", ctx_files_patterns)
        if ctx_files_patterns:
            assert len(ctx_files_patterns) == len(
                id_prefixes), "ctx len={} pref leb={}".format(
                    len(ctx_files_patterns), len(id_prefixes))
        else:
            assert (
                index_path or cfg.rpc_index_id
            ), "Either encoded_ctx_files or index_path pr rpc_index_id parameter should be set."

        # encoded_ctx_files may be unset (None); treat that as "no files"
        # instead of failing inside enumerate().
        input_paths = []
        path_id_prefixes = []
        for i, pattern in enumerate(ctx_files_patterns or []):
            pattern_files = glob.glob(pattern)
            pattern_id_prefix = id_prefixes[i]
            input_paths.extend(pattern_files)
            path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))
        logger.info("Embeddings files id prefixes: %s", path_id_prefixes)
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths,
                                     index_buffer_sz,
                                     path_id_prefixes=path_id_prefixes)
        # NOTE(review): serialization is limited to the local index object;
        # confirm whether the RPC retriever should persist its index too.
        if index_path and index is not None:
            retriever.index.serialize(index_path)

    # get top k results
    top_results_and_scores = retriever.get_top_docs(questions_tensor.numpy(),
                                                    cfg.n_docs)

    # When an existing index was loaded above, the context sources have not
    # been created yet although passage validation and KILT conversion use
    # them; create them now instead of failing on an unbound name.
    if ctx_sources is None and (not cfg.use_rpc_meta or cfg.kilt_out_file):
        ctx_sources = _instantiate_ctx_sources(cfg)

    if cfg.use_rpc_meta:
        questions_doc_hits = validate_from_meta(
            question_answers,
            top_results_and_scores,
            cfg.validation_workers,
            cfg.match,
            cfg.rpc_meta_compressed,
        )
        if cfg.out_file:
            save_results_from_meta(
                questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
                cfg.rpc_meta_compressed,
            )
    else:
        all_passages = get_all_passages(ctx_sources)
        if cfg.validate_as_tables:
            questions_doc_hits = validate_tables(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )
        else:
            questions_doc_hits = validate(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )

        if cfg.out_file:
            save_results(
                all_passages,
                questions_text if questions_text else questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
            )

    if cfg.kilt_out_file:
        # Use the first KILT-compatible context source, if there is one.
        kilt_ctx = next(
            (ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)),
            None)
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file,
                                 cfg.kilt_out_file)
Exemplo n.º 14
0
def setup_dpr(model_file,
              ctx_file,
              encoded_ctx_file,
              hnsw_index=False,
              save_or_load_index=False):
    """Initialise the module-level DPR retriever, passages and answer cache.

    Populates the globals ``retriever``, ``all_passages``, ``answer_cache``
    and ``answer_cache_path``; nothing is returned.

    Args:
        model_file: Path of the checkpoint holding the question encoder.
        ctx_file: Passage file handed to ``load_passages``.
        encoded_ctx_file: Glob pattern matching the encoded passage files.
        hnsw_index: Use ``DenseHNSWFlatIndexer`` instead of
            ``DenseFlatIndexer``.
        save_or_load_index: Reuse a serialized index when one exists next to
            the encoded files, and serialize a freshly built one otherwise.

    Raises:
        RuntimeError: If ``encoded_ctx_file`` matches no files.
    """
    global retriever
    global all_passages
    global answer_cache
    global answer_cache_path

    # The cache file name is derived from the three input paths, so each
    # model/passages combination gets its own cache file.
    parameter_setting = model_file + ctx_file + encoded_ctx_file
    answer_cache_path = hashlib.sha1(
        parameter_setting.encode("utf-8")).hexdigest()
    if os.path.exists(answer_cache_path):
        # NOTE(review): unpickling runs arbitrary code; the cache file must
        # only ever come from a trusted source.
        with open(answer_cache_path, 'rb') as cache_file:
            answer_cache = pickle.load(cache_file)
    else:
        answer_cache = {}

    # NOTE(review): parse_args() reads sys.argv of the running process, so
    # unrelated command-line flags of the host program can make this fail.
    parser = argparse.ArgumentParser()
    add_encoder_params(parser)
    add_tokenizer_params(parser)
    add_cuda_params(parser)

    args = parser.parse_args()
    args.model_file = model_file
    args.ctx_file = ctx_file
    args.encoded_ctx_file = encoded_ctx_file
    args.hnsw_index = hnsw_index
    args.save_or_load_index = save_or_load_index
    args.batch_size = 1  # TODO

    setup_args_gpu(args)
    print_args(args)

    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    # Only the question tower is needed for retrieval.
    encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    # Keep only the question encoder weights and strip their key prefix so
    # the keys match the standalone question model.
    prefix_len = len("question_model.")
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("question_model.")
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, 50000)
    else:
        index = DenseFlatIndexer(vector_size, 50000,
                                 "IVF65536,PQ64")  #IVF65536

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    if not input_paths:
        # Without this check the index-path derivation below dies with an
        # uninformative IndexError.
        raise RuntimeError(
            "No encoded passage files match pattern: %s" % ctx_files_pattern)

    # The index location is the first encoded file's path with its last
    # "_"-separated component removed.
    index_path = "_".join(input_paths[0].split("_")[:-1])
    if args.save_or_load_index and (os.path.exists(index_path) or
                                    os.path.exists(index_path + ".index.dpr")):
        retriever.index.deserialize_from(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index.index_data(input_paths)

        if args.save_or_load_index:
            retriever.index.serialize(index_path)

    all_passages = load_passages(args.ctx_file)
Exemplo n.º 15
0
def main(args):
    """Retrieve passages for the questions in ``args.qa_file`` and validate them.

    Steps: read questions and answers, build the question encoder (loading
    weights from ``args.model_file`` unless the encoder type is
    ``'hf_attention'`` or ``'colbert'``), build or load the passage index,
    search the top ``args.n_docs`` passages per question, optionally re-score
    them with ``retriever.colbert_search``, validate the hits and, when
    ``args.out_file`` is set, save the results.

    Side effects: writes ``approx_scores.pkl`` (and ``colbert_scores.pkl``
    for ColBERT) into the working directory, and for ColBERT creates or
    reopens ``memmap.npy``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``qa_file``, ``encoder_model_type``, ``model_file``, ``device``,
            ``n_gpu``, ``local_rank``, ``fp16``, ``index_buffer``,
            ``index_type``, ``batch_size``, ``encoded_ctx_file``,
            ``save_or_load_index``, ``n_docs``, ``ctx_file``,
            ``validation_workers``, ``match`` and ``out_file``.

    Raises:
        RuntimeError: If ``args.encoded_ctx_file`` matches no files, or if no
            passages could be loaded from ``args.ctx_file``.
    """
    # Encoder types that are built without loading a DPR checkpoint.
    checkpoint_free_types = ('hf_attention', 'colbert')
    # NOTE(review): presumably (passages, tokens per passage, embedding dim)
    # for one specific corpus -- confirm before running on other data.
    colbert_memmap_shape = (21015324, 250, 16)

    questions = []
    question_answers = []
    for question, answers in parse_qa_csv_file(args.qa_file):
        questions.append(question)
        question_answers.append(answers)

    if args.encoder_model_type not in checkpoint_free_types:
        saved_state = load_states_from_checkpoint(args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)

    # Computed only after the checkpoint parameters were applied to ``args``,
    # because that call may rewrite ``args.encoder_model_type``.
    is_colbert = args.encoder_model_type == 'colbert'
    uses_checkpoint = args.encoder_model_type not in checkpoint_free_types

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)

    if uses_checkpoint:
        # Only the question tower is needed for retrieval.
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank,
                                            args.fp16)
    encoder.eval()

    if uses_checkpoint:
        # load weights from the model file
        model_to_load = get_model_obj(encoder)
        logger.info('Loading saved model state ...')

        # Keep only the question encoder weights and strip their key prefix.
        prefix_len = len('question_model.')
        question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                                  key.startswith('question_model.')}
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()
        logger.info('Encoder vector_size=%d', vector_size)
    else:
        vector_size = 16

    index_buffer_sz = args.index_buffer

    if args.index_type == 'hnsw':
        index = DenseHNSWFlatIndexer(vector_size, index_buffer_sz)
        index_buffer_sz = -1  # encode all at once
    elif args.index_type == 'custom':
        index = CustomIndexer(vector_size, index_buffer_sz)
    else:
        index = DenseFlatIndexer(vector_size)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index, is_colbert)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    print(input_paths)
    if not input_paths:
        # Without this check the index-path derivation below dies with an
        # uninformative IndexError.
        raise RuntimeError('No encoded passage files match pattern: %s' % ctx_files_pattern)

    # The index location is the first encoded file's path with its last
    # "_"-separated component removed.
    index_path = "_".join(input_paths[0].split("_")[:-1])
    memmap_path = 'memmap.npy'
    print(args.save_or_load_index, os.path.exists(index_path), index_path)

    memmap = None
    if args.save_or_load_index and os.path.exists(index_path+'.index.dpr'):
        retriever.index.deserialize_from(index_path)
        if is_colbert:
            # Reopen the existing file: mode 'w+' would overwrite the
            # embeddings that were (presumably) stored when the index was
            # built, leaving the ColBERT re-scoring to run on zeros.
            memmap = np.memmap(memmap_path, dtype=np.float32, mode='r+', shape=colbert_memmap_shape)
    else:
        logger.info('Reading all passages data from files: %s', input_paths)
        if is_colbert:
            # Fresh build: create (or overwrite) the backing file.
            memmap = np.memmap(memmap_path, dtype=np.float32, mode='w+', shape=colbert_memmap_shape)
        retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz, memmap=memmap)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)

    # get questions & answers
    questions_tensor = retriever.generate_question_vectors(questions)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs, is_colbert=is_colbert)

    with open('approx_scores.pkl', 'wb') as f:
        pickle.dump(top_ids_and_scores, f)

    # The index is reset once the approximate search is done, before the
    # ColBERT re-scoring and passage loading below.
    retriever.index.index.reset()

    top_ids_and_scores_colbert = None
    if is_colbert:
        logger.info('Colbert score')
        top_ids_and_scores_colbert = retriever.colbert_search(questions_tensor.numpy(), memmap, top_ids_and_scores, args.n_docs)

    all_passages = load_passages(args.ctx_file)

    if len(all_passages) == 0:
        raise RuntimeError('No passages data found. Please specify ctx_file param properly.')

    if is_colbert:
        # Previously dumped unconditionally, which raised NameError for
        # every non-ColBERT encoder.
        with open('colbert_scores.pkl', 'wb') as f:
            pickle.dump(top_ids_and_scores_colbert, f)

    questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers,
                                  args.match)
    if is_colbert:
        questions_doc_hits_colbert = validate(all_passages, question_answers, top_ids_and_scores_colbert,
                                              args.validation_workers, args.match)

    if args.out_file:
        save_results(all_passages, questions, question_answers, top_ids_and_scores, questions_doc_hits, args.out_file)
        if is_colbert:
            save_results(all_passages, questions, question_answers, top_ids_and_scores_colbert,
                         questions_doc_hits_colbert, args.out_file+'colbert')
Exemplo n.º 16
0
            psg_ids, scores = ranker.closest_docs(question, args.n_docs)
            top_ids_and_scores.append((psg_ids, scores))
    else:
        from dpr.models import init_biencoder_components
        from dpr.utils.data_utils import Tensorizer
        from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint
        from dpr.indexer.faiss_indexers import DenseIndexer, DenseHNSWFlatIndexer, DenseFlatIndexer
        from dense_retriever import DenseRetriever

        saved_state = load_states_from_checkpoint(args.dpr_model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)
        tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
        encoder = encoder.question_model
        setup_args_gpu(args)
        encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                                args.local_rank,
                                                args.fp16)
        encoder.eval()

        # load weights from the model file
        model_to_load = get_model_obj(encoder)
        prefix_len = len('question_model.')
        question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                                key.startswith('question_model.')}
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()

        index_buffer_sz = args.index_buffer
        if args.hnsw_index:
            index = DenseHNSWFlatIndexer(vector_size)
            index_buffer_sz = -1  # encode all at once
Exemplo n.º 17
0
    def __init__(self, cfg: DictConfig):
        """Build the one-for-all model, its optimizers and trainer bookkeeping.

        When ``get_model_file`` yields a checkpoint, its encoder parameters
        are copied into ``cfg`` before any model is created. A checkpoint
        that is not a ``CheckpointStateOFA`` is converted with
        ``convert_from_old_state_to_ofa`` and wrapped in a
        ``SimpleOneForAllModel``; the asserts restrict that path to runs
        without training datasets.

        Args:
            cfg: Run configuration. Several of its fields are overwritten
                here: always on the legacy-checkpoint path, and
                ``global_loss_buf_sz`` when clustering is enabled.
        """
        self.shard_id = 0 if cfg.local_rank == -1 else cfg.local_rank
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # Encoder parameters stored in the checkpoint override those in cfg.
        checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
        has_checkpoint = bool(checkpoint_path)
        saved_state = None
        if has_checkpoint:
            saved_state = load_states_from_checkpoint_legacy(checkpoint_path)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

        # "Legacy" means a checkpoint was found but is not in OFA format.
        is_legacy_checkpoint = has_checkpoint and not isinstance(
            saved_state, CheckpointStateOFA
        )

        if not is_legacy_checkpoint:
            # Either no checkpoint at all or a native OFA checkpoint.
            self.mode = "normal"
            gradient_checkpointing = getattr(cfg, "gradient_checkpointing", False)
            ofa_components = init_ofa_model(
                cfg.encoder.encoder_model_type, cfg, gradient_checkpointing=gradient_checkpointing,
            )
            tensorizer, model, biencoder_optimizer, reader_optimizer, forward_fn = ofa_components
        else:
            # Old-format checkpoints are accepted for evaluation-only runs.
            assert isinstance(saved_state, CheckpointState)
            assert cfg.train_datasets is None or len(cfg.train_datasets) == 0
            saved_state, self.mode = convert_from_old_state_to_ofa(saved_state)

            if self.mode == "biencoder":
                # Retriever-only evaluation.
                assert cfg.evaluate_retriever and (not cfg.evaluate_reader)
                tensorizer, biencoder, _ = init_biencoder_components(
                    cfg.encoder.encoder_model_type, cfg, inference_only=True,
                )
                reader = None
            else:
                # Reader-only evaluation.
                assert cfg.evaluate_reader and (not cfg.evaluate_retriever)
                tensorizer, reader, _ = init_reader_components(
                    cfg.encoder.encoder_model_type, cfg, inference_only=True,
                )
                biencoder = None

            # Wrap whichever component exists so the rest of the trainer
            # can treat it as a one-for-all model.
            model = SimpleOneForAllModel(
                biencoder=biencoder, reader=reader, tensorizer=tensorizer,
            )

            # Switch off everything the converted checkpoint cannot provide.
            cfg.ignore_checkpoint_optimizer = True
            cfg.ignore_checkpoint_offset = True
            cfg.gradient_checkpointing = False
            cfg.fp16 = False

            # Placeholders so the shared code below works unchanged.
            gradient_checkpointing = False
            biencoder_optimizer = None
            reader_optimizer = None
            forward_fn = ofa_simple_fw_pass  # always the simplest

        model, wrapped_optimizers = setup_for_distributed_mode(
            model,
            [biencoder_optimizer, reader_optimizer],
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=gradient_checkpointing,
        )
        biencoder_optimizer, reader_optimizer = wrapped_optimizers

        self.forward_fn = forward_fn
        self.model = model
        self.cfg = cfg
        self.ds_cfg = OneForAllDatasetsCfg(cfg)

        self.biencoder_optimizer = biencoder_optimizer
        self.biencoder_scheduler_state = None
        self.reader_optimizer = reader_optimizer
        self.reader_scheduler_state = None

        self.clustering = cfg.biencoder.clustering
        if self.clustering:
            cfg.global_loss_buf_sz = 72000000  # this requires a lot of memory

        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.best_validation_result = None
        self.best_cp_name = None

        # Biencoder loss function (note that reader loss is automatically computed)
        self.biencoder_loss_function: BiEncoderNllLoss = init_loss(cfg.encoder.encoder_model_type, cfg)

        # Restore state only after every attribute above has been set.
        if saved_state:
            self._load_saved_state(saved_state)

        self.dev_iterator = None
Exemplo n.º 18
0
    def __init__(self, cfg: DictConfig):
        """Create the bi-encoder, its optimizer and the trainer bookkeeping.

        When ``get_model_file`` yields a checkpoint, its encoder parameters
        are copied into ``cfg`` before the model is built, and the saved
        state is restored through ``_load_saved_state`` once all attributes
        are set.

        Args:
            cfg: Run configuration; also used to build ``BiencoderDatasetsCfg``.
        """
        self.shard_id = 0 if cfg.local_rank == -1 else cfg.local_rank
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # Encoder parameters stored in the checkpoint override those in cfg.
        checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
        saved_state = None
        if checkpoint_path:
            saved_state = load_states_from_checkpoint(checkpoint_path)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

        components = init_biencoder_components(cfg.encoder.encoder_model_type, cfg)
        tensorizer, biencoder, optimizer = components

        if cfg.deepspeed:
            # DeepSpeed path: the model is cast to half precision and the
            # optimizer is replaced by DeepSpeedCPUAdam. The parameter groups
            # of the freshly built optimizer are reused as-is; a hand-rolled
            # decay / no-decay grouping that used to live here is disabled.
            biencoder.half()
            optimizer = DeepSpeedCPUAdam(
                optimizer.param_groups,
                lr=cfg.train.learning_rate,
                weight_decay=cfg.train.weight_decay,
            )

        biencoder, optimizer = setup_for_distributed_mode(
            biencoder,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
        )

        self.biencoder = biencoder
        self.optimizer = optimizer
        self.tensorizer = tensorizer

        # Training progress; presumably updated by _load_saved_state when a
        # checkpoint is resumed.
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None

        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)

        # Restore state only after every attribute above has been set.
        if saved_state:
            self._load_saved_state(saved_state)

        self.dev_iterator = None