Esempio n. 1
0
    def __init__(self, args):
        """Initialize reader-training components: tensorizer, model, optimizer.

        If a checkpoint file is found via ``get_model_file``, encoder
        parameters stored in it override the corresponding ``args`` before
        component creation, and the full training state is restored at the
        end via ``_load_saved_state``.
        """
        self.args = args

        # local_rank == -1 means single-process training; use shard 0 then
        self.shard_id = args.local_rank if args.local_rank != -1 else 0
        self.distributed_factor = args.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # if a model file is specified, encoder parameters from the saved
        # state take precedence over command-line args
        model_file = get_model_file(self.args, self.args.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_encoder_params_from_state(saved_state.encoder_params, args)

        tensorizer, reader, optimizer = init_reader_components(
            args.encoder_model_type, args)

        # wrap model/optimizer for DDP and (optionally) fp16 per args
        reader, optimizer = setup_for_distributed_mode(reader, optimizer,
                                                       args.device, args.n_gpu,
                                                       args.local_rank,
                                                       args.fp16,
                                                       args.fp16_opt_level)
        self.reader = reader
        self.optimizer = optimizer
        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        if saved_state:
            # restore weights, optimizer/scheduler state and epoch counters
            self._load_saved_state(saved_state)
Esempio n. 2
0
def run(args):
    """Convert a DPR bi-encoder checkpoint into two HuggingFace-style
    pretrained model directories (question encoder and context encoder).

    Args:
        args: namespace providing ``pretrained_model`` (checkpoint path),
            ``out_dir`` (output root) and the encoder-init attributes read
            by ``init_biencoder_components``.
    """
    pretrained_model = args.pretrained_model
    out_root = args.out_dir

    saved_state = load_states_from_checkpoint(pretrained_model)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    out_q = os.path.join(out_root, "q_encoder")
    out_c = os.path.join(out_root, "c_encoder")
    os.makedirs(out_q, exist_ok=True)
    os.makedirs(out_c, exist_ok=True)

    # save tokenizer + model config/weights in HF format for each tower
    tensorizer.tokenizer.save_pretrained(out_q)
    encoder.question_model.save_pretrained(out_q)

    tensorizer.tokenizer.save_pretrained(out_c)
    encoder.ctx_model.save_pretrained(out_c)

    # Re-key the raw state dicts.  Replace the DPR tower prefix only at the
    # start of each key -- str.replace would also rewrite any accidental
    # later occurrence of the substring inside a key.
    q_save_dict = {
        "question_encoder.bert_model" + k[len("question_model"):]: v
        for k, v in saved_state.model_dict.items()
        if k.startswith("question_model")
    }
    c_save_dict = {
        "ctx_encoder.bert_model" + k[len("ctx_model"):]: v
        for k, v in saved_state.model_dict.items() if k.startswith("ctx_model")
    }
    torch.save(q_save_dict, os.path.join(out_q, PRETRAIN_MODEL_NAME))
    torch.save(c_save_dict, os.path.join(out_c, PRETRAIN_MODEL_NAME))
Esempio n. 3
0
def load_bi_encoders(dpr_cp_path: str, args):
    """Load a DPR checkpoint and return (ctx_encoder, question_encoder)."""

    saved_state = load_states_from_checkpoint(dpr_cp_path)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    print(args)
    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    def _restore(sub_encoder, prefix="ctx_model."):
        # Strip the tower prefix off the checkpoint keys and load weights
        # into the (possibly wrapped) model object.
        sub_encoder.eval()
        target = get_model_obj(sub_encoder)
        cut = len(prefix)
        state = {}
        for key, value in saved_state.model_dict.items():
            if key.startswith(prefix):
                state[key[cut:]] = value
        target.load_state_dict(state)
        return target

    final_ctx = _restore(encoder.ctx_model, "ctx_model.")
    final_question = _restore(encoder.question_model, "question_model.")

    return final_ctx, final_question
Esempio n. 4
0
def get_query_encoder(args):
    """Build the question encoder from ``args.model_file`` and load its weights."""
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    print_args(args)

    tensorizer, biencoder, _ = init_biencoder_components(
        args.encoder_model_type, args, inference_only=True)

    question_encoder = biencoder.question_model

    # NOTE: distributed/fp16 setup and eval() are intentionally not applied
    # here (previously present but disabled).

    logger.info(args.__dict__)

    # copy the question-tower weights out of the checkpoint into the model
    model_to_load = get_model_obj(question_encoder)
    logger.info('Loading saved model state ...')
    logger.debug('saved model keys =%s', saved_state.model_dict.keys())

    tower_prefix = 'question_model.'
    cut = len(tower_prefix)
    question_state = {
        name[cut:]: weight
        for name, weight in saved_state.model_dict.items()
        if name.startswith(tower_prefix)
    }
    model_to_load.load_state_dict(question_state)

    return model_to_load
Esempio n. 5
0
def main(args):
    """Encode context passages from a TSV file with the DPR ctx encoder and
    pickle the resulting (id, vector) pairs for one shard.

    Expects ``args`` to carry model_file, ctx_file (TSV: id, text, title),
    num_shards, shard_id, out_file and the usual encoder/distributed flags.
    """
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    print_args(args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    encoder = encoder.ctx_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16, args.fp16_opt_level)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')
    logger.debug('saved model keys =%s', saved_state.model_dict.keys())

    # checkpoint keys carry the tower name as a prefix; strip it
    prefix_len = len('ctx_model.')
    ctx_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith('ctx_model.')
    }
    model_to_load.load_state_dict(ctx_state)

    logger.info('reading data from file=%s', args.ctx_file)

    rows = []
    with open(args.ctx_file) as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        # file format: doc_id, doc_text, title
        rows.extend([(row[0], row[1], row[2]) for row in reader
                     if row[0] != 'id'])

    # Shard the passages.  The last shard takes the division remainder so
    # no passage is dropped (the previous end_idx computation silently lost
    # up to num_shards - 1 trailing rows).
    shard_size = len(rows) // args.num_shards
    start_idx = args.shard_id * shard_size
    end_idx = (len(rows) if args.shard_id == args.num_shards - 1
               else start_idx + shard_size)

    logger.info(
        'Producing encodings for passages range: %d to %d (out of total %d)',
        start_idx, end_idx, len(rows))
    rows = rows[start_idx:end_idx]

    data = gen_ctx_vectors(rows, encoder, tensorizer, True)

    file = args.out_file + '_' + str(args.shard_id)
    pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True)
    logger.info('Writing results to %s' % file)
    with open(file, mode='wb') as f:
        pickle.dump(data, f)

    logger.info('Total passages processed %d. Written to %s', len(data), file)
Esempio n. 6
0
    def __init__(self, cfg: DictConfig):
        """Set up bi-encoder training components from a Hydra config.

        Builds tensorizer/model/optimizer (optionally with gradient
        checkpointing), records whether the model is a matching variant,
        wraps everything for distributed / fp16 execution, and restores a
        checkpoint if one is found.
        """

        # local_rank == -1 means single-process training; use shard 0 then
        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # if model file is specified, encoder parameters from saved state should be used for initialization
        model_file = get_model_file(cfg, cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

        gradient_checkpointing = getattr(cfg, "gradient_checkpointing", False)
        tensorizer, model, optimizer = init_biencoder_components(
            cfg.encoder.encoder_model_type,
            cfg,
            gradient_checkpointing=gradient_checkpointing,
        )
        # record whether the model is one of the matching bi-encoder
        # variants; open_dict allows adding a new key to the struct config
        with omegaconf.open_dict(cfg):
            cfg.others = DictConfig({
                "is_matching":
                isinstance(model, (Match_BiEncoder, MatchGated_BiEncoder))
            })

        model, optimizer = setup_for_distributed_mode(
            model,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.biencoder = model
        self.optimizer = optimizer
        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)
        self.loss_function = init_loss(cfg.encoder.encoder_model_type, cfg)
        self.clustering = cfg.clustering
        if self.clustering:
            cfg.global_loss_buf_sz = 80000000  # this requires a lot of memory

        if saved_state:
            # restore weights, optimizer/scheduler state and epoch counters
            self._load_saved_state(saved_state)

        self.dev_iterator = None
Esempio n. 7
0
def get_retriever(args):
    """Build a DenseRetriever around the checkpointed question encoder.

    Loads question-tower weights from ``args.model_file``, creates a flat
    or HNSW FAISS indexer, and — when ``args.encoded_ctx_file`` is set —
    either deserializes a previously saved index or indexes the encoded
    context vectors from the matching files.
    """
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    encoder = encoder.question_model
    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')

    # checkpoint keys carry the tower name as a prefix; strip it
    prefix_len = len('question_model.')
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith('question_model.')
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info('Encoder vector_size=%d', vector_size)

    index_buffer_sz = args.index_buffer
    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size)
        index_buffer_sz = -1  # encode all at once
    else:
        index = DenseFlatIndexer(vector_size)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index,
                               args.device)

    # index all passages
    if len(args.encoded_ctx_file) > 0:
        ctx_files_pattern = args.encoded_ctx_file
        input_paths = glob.glob(ctx_files_pattern)

        # index path = embeddings filename of the first match minus its
        # trailing "_<shard>" suffix
        index_path = "_".join(input_paths[0].split("_")[:-1])
        if args.save_or_load_index and os.path.exists(index_path):
            retriever.index.deserialize(index_path)
        else:
            logger.info('Reading all passages data from files: %s',
                        input_paths)
            retriever.index_encoded_data(input_paths,
                                         buffer_size=index_buffer_sz)
            if args.save_or_load_index:
                retriever.index.serialize(index_path)
        # get questions & answers
    return retriever
Esempio n. 8
0
    def __init__(self, name, **config):
        """Initialize a DPR-based retrieval component.

        Restores the question encoder from ``config['model_file']``, builds
        (or loads) a FAISS index over pre-encoded passages, loads passage
        texts, and optionally a pickled KILT id mapping.
        """
        super().__init__(name)

        self.args = argparse.Namespace(**config)
        saved_state = load_states_from_checkpoint(self.args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, self.args)
        tensorizer, encoder, _ = init_biencoder_components(
            self.args.encoder_model_type, self.args, inference_only=True)
        encoder = encoder.question_model
        encoder, _ = setup_for_distributed_mode(
            encoder,
            None,
            self.args.device,
            self.args.n_gpu,
            self.args.local_rank,
            self.args.fp16,
        )
        encoder.eval()

        # load weights from the model file
        model_to_load = get_model_obj(encoder)

        # checkpoint keys carry the tower name as a prefix; strip it
        prefix_len = len("question_model.")
        question_encoder_state = {
            key[prefix_len:]: value
            for (key, value) in saved_state.model_dict.items()
            if key.startswith("question_model.")
        }
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()

        index_buffer_sz = self.args.index_buffer
        if self.args.hnsw_index:
            # HNSW index is pre-built and loaded from disk
            index = DenseHNSWFlatIndexer(vector_size)
            index.deserialize_from(self.args.hnsw_index_path)
        else:
            index = DenseFlatIndexer(vector_size)

        self.retriever = DenseRetriever(encoder, self.args.batch_size,
                                        tensorizer, index)

        # index all passages
        ctx_files_pattern = self.args.encoded_ctx_file
        input_paths = glob.glob(ctx_files_pattern)

        # the flat index is built in memory from the encoded ctx files
        if not self.args.hnsw_index:
            self.retriever.index_encoded_data(input_paths,
                                              buffer_size=index_buffer_sz)

        # not needed for now
        self.all_passages = load_passages(self.args.ctx_file)

        # optional pickled mapping of passage ids to KILT ids
        self.KILT_mapping = None
        if self.args.KILT_mapping:
            self.KILT_mapping = pickle.load(open(self.args.KILT_mapping, "rb"))
Esempio n. 9
0
    def __init__(self, cfg: DictConfig, save_temp_conf=False):
        """Initialize bi-encoder training for graded/binary relevance data.

        Args:
            cfg: Hydra/OmegaConf training configuration.
            save_temp_conf: when True, pickle ``cfg`` to a hard-coded
                Google Drive path (appears to be a Colab debugging aid).
        """
        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        if save_temp_conf:
            import pickle
            # NOTE(review): hard-coded Colab/Drive path — only valid there
            with open('/content/drive/MyDrive/conf.pickle', 'wb') as f:
                pickle.dump(cfg, f)

        logger.info("***** Initializing components for training *****")

        # graded dataset settings
        self.trainer_type = 'binary' if cfg.binary_trainer else 'graded'
        logger.info(f'trainer_type: {self.trainer_type}')
        self.relations = cfg.relations
        logger.info(f'relations: {self.relations}')

        # if model file is specified, encoder parameters from saved state should be used for initialization
        model_file = get_model_file(cfg, cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

        tensorizer, model, optimizer = init_biencoder_components(
            cfg.encoder.encoder_model_type, cfg
        )

        # wrap model/optimizer for DDP and (optionally) fp16
        model, optimizer = setup_for_distributed_mode(
            model,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
        )
        self.biencoder = model
        self.optimizer = optimizer
        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)

        if saved_state:
            # restore weights, optimizer/scheduler state and epoch counters
            self._load_saved_state(saved_state)

        self.dev_iterator = None
def main(args):
    """Encode paper abstracts from a JSON blob with the DPR context encoder
    and write one JSON line per paper: {"paper_id": ..., "embedding": ...}.
    """
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    print_args(args)

    tensorizer, biencoder, _ = init_biencoder_components(
        args.encoder_model_type, args, inference_only=True)

    ctx_encoder = biencoder.ctx_model

    ctx_encoder, _ = setup_for_distributed_mode(
        ctx_encoder, None, args.device, args.n_gpu, args.local_rank,
        args.fp16, args.fp16_opt_level)
    ctx_encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(ctx_encoder)
    logger.info('Loading saved model state ...')
    logger.debug('saved model keys =%s', saved_state.model_dict.keys())

    tower_prefix = 'ctx_model.'
    cut = len(tower_prefix)
    ctx_state = {
        name[cut:]: tensor
        for name, tensor in saved_state.model_dict.items()
        if name.startswith(tower_prefix)
    }
    model_to_load.load_state_dict(ctx_state)

    logger.info('reading data from file=%s', args.ctx_file)

    with open(args.ctx_file) as jsonfile:
        blob = json.load(jsonfile)

    # one (id, abstract, title) row per paper; missing fields become ""
    rows = [(paper["paper_id"],
             paper.get("abstract") or "",
             paper.get("title") or "")
            for paper in blob.values()]

    data = gen_ctx_vectors(rows, ctx_encoder, tensorizer, True)

    file = args.out_file
    pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True)
    logger.info('Writing results to %s' % file)
    with open(file, "w+") as f:
        for (s2_id, vec) in data:
            out = {"paper_id": s2_id, "embedding": vec.tolist()}
            f.write(json.dumps(out) + "\n")

    logger.info('Total passages processed %d. Written to %s', len(data), file)
Esempio n. 11
0
    def __init__(self, cfg: DictConfig):
        """Initialize reader-training components from a Hydra config.

        Builds tensorizer/reader/optimizer (optionally with gradient
        checkpointing), wraps them for distributed / fp16 execution, and
        restores a checkpoint if one is found.
        """
        self.cfg = cfg

        # local_rank == -1 means single-process training; use shard 0 then
        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # if a checkpoint exists, its encoder params override cfg values
        model_file = get_model_file(self.cfg, self.cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

        gradient_checkpointing = getattr(self.cfg, "gradient_checkpointing",
                                         False)
        tensorizer, reader, optimizer = init_reader_components(
            cfg.encoder.encoder_model_type,
            cfg,
            gradient_checkpointing=gradient_checkpointing,
        )

        reader, optimizer = setup_for_distributed_mode(
            reader,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.reader = reader
        self.optimizer = optimizer
        self.tensorizer = tensorizer
        self.debugging = getattr(self.cfg, "debugging", False)
        self.wiki_data = None
        self.dev_iterator = None
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        if saved_state:
            # restore weights, optimizer/scheduler state and epoch counters
            self._load_saved_state(saved_state)
Esempio n. 12
0
    def __init__(self, args, model_file):
        """Load a reader checkpoint for CUDA inference.

        The tensorizer is switched to variable-length padding, and the
        optimizer produced by component init is discarded (inference only).
        """
        self.args = args

        saved_state = load_states_from_checkpoint(model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)

        tensorizer, reader, optimizer = init_reader_components(
            args.encoder_model_type, args)
        tensorizer.pad_to_max = False
        del optimizer

        self.reader = reader.cuda().eval()
        self.tensorizer = tensorizer

        # copy checkpointed weights into the (possibly wrapped) model
        model_to_load = get_model_obj(self.reader)
        model_to_load.load_state_dict(saved_state.model_dict)
Esempio n. 13
0
    def __init__(self, cfg: DictConfig):
        """Initialize bi-encoder training components from a Hydra config.

        Builds tensorizer/model/optimizer, wraps them for distributed /
        fp16 execution, and restores a checkpoint if one is found.
        """
        # local_rank == -1 means single-process training; use shard 0 then
        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # if model file is specified, encoder parameters from saved state should be used for initialization
        model_file = get_model_file(cfg, cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

        tensorizer, model, optimizer = init_biencoder_components(
            cfg.encoder.encoder_model_type, cfg
        )

        model, optimizer = setup_for_distributed_mode(
            model,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
        )
        self.biencoder = model
        self.optimizer = optimizer
        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)

        if saved_state:
            # restore weights, optimizer/scheduler state and epoch counters
            self._load_saved_state(saved_state)

        self.dev_iterator = None
Esempio n. 14
0
def main(cfg: DictConfig):
    """Encode and dump passages from ``cfg.pair_file`` using the bi-encoder
    checkpoint given by ``cfg.model_file``.
    """
    # Configure logging BEFORE the first logger call so the format/level
    # actually apply (previously this ran after several logger.info calls).
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    assert cfg.pair_file, "Please specify passages source as pair_file param"
    assert cfg.model_file, "Please specify encoder checkpoint as model_file param"

    cfg = setup_cfg_gpu(cfg)

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    # cache dir is overridable via cfg; the fallback keeps the previous
    # hard-coded cluster path for backward compatibility
    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type,
        cfg,
        inference_only=True,
        cache_dir=getattr(cfg, 'cache_dir', '/n/fs/nlp-jl5167/cache'))

    # load weights from the model file
    logger.info("Loading saved model state ...")
    logger.debug("saved model keys =%s", saved_state.model_dict.keys())

    encoder.load_state_dict(saved_state.model_dict)
    encoder.to(cfg.device)

    # Set seed
    # set_seed(args)
    cfg.encoder.sequence_length = 512  # TODO: Passage can be cut by max_seq_length
    logger.info(f"Max seq length: {cfg.encoder.sequence_length}")

    # Dump passages
    dump_passages(cfg, encoder, tensorizer)
Esempio n. 15
0
    def __init__(self, cfg: DictConfig):
        """Initialize bi-encoder training, optionally under DeepSpeed.

        When ``cfg.deepspeed`` is set, the model is cast to half precision
        and the optimizer is replaced by ``DeepSpeedCPUAdam`` built over
        the existing optimizer's parameter groups.
        """
        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # if model file is specified, encoder parameters from saved state should be used for initialization
        model_file = get_model_file(cfg, cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

        tensorizer, model, optimizer = init_biencoder_components(
            cfg.encoder.encoder_model_type, cfg
        )

        if cfg.deepspeed:
            model.half()

            # XXX
           #no_decay = ["bias", "LayerNorm.weight"]
           #
           #optimizer_grouped_parameters = [
           #    {
           #        "params": [
           #            p
           #            for n, p in model.named_parameters()
           #            if not any(nd in n for nd in no_decay)
           #        ],
           #        "weight_decay": cfg.train.weight_decay,
           #    },
           #    {
           #        "params": [
           #            p
           #            for n, p in model.named_parameters()
           #            if any(nd in n for nd in no_decay)
           #        ],
           #        "weight_decay": 0.0,
           #    },
           #]

            # reuse the param groups computed by the standard optimizer init
            optimizer = DeepSpeedCPUAdam(optimizer.param_groups, lr=cfg.train.learning_rate, 
                    weight_decay=cfg.train.weight_decay)

        model, optimizer = setup_for_distributed_mode(
            model,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
        )

        self.biencoder = model
        self.optimizer = optimizer
        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        self.cfg = cfg
        self.ds_cfg = BiencoderDatasetsCfg(cfg)

        if saved_state:
            # restore weights, optimizer/scheduler state and epoch counters
            self._load_saved_state(saved_state)

        self.dev_iterator = None
Esempio n. 16
0
def main(cfg: DictConfig):
    """End-to-end dense retrieval evaluation.

    Encodes questions from ``cfg.qa_dataset`` with the checkpointed
    question encoder, builds (or loads) a FAISS index over pre-encoded
    passages, retrieves top docs, validates them against gold answers, and
    optionally saves results (including KILT-format output).
    """
    cfg = setup_cfg_gpu(cfg)
    logger.info("CFG (after gpu  configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)

    # select the encoder tower used for questions (attribute name via cfg)
    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device,
                                            cfg.n_gpu, cfg.local_rank,
                                            cfg.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    # checkpoint keys carry the tower name as a prefix; strip it
    encoder_prefix = (encoder_path if encoder_path else "question_model") + "."
    prefix_len = len(encoder_prefix)

    logger.info("Encoder state prefix %s", encoder_prefix)
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(encoder_prefix)
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    for ds_item in qa_src.data:
        question, answers = ds_item.query, ds_item.answers
        questions.append(question)
        question_answers.append(answers)

    # the indexer implementation is instantiated from the config
    index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
    logger.info("Index class %s ", type(index))
    index_buffer_sz = index.buffer_size
    index.init_index(vector_size)
    retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, query_token=qa_src.special_query_token)

    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    # context sources and the id prefix each one stamps on its passage ids
    id_prefixes = []
    ctx_sources = []
    for ctx_src in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)

    logger.info("id_prefixes per dataset: %s", id_prefixes)

    # index all passages
    ctx_files_patterns = cfg.encoded_ctx_files
    index_path = cfg.index_path

    logger.info("ctx_files_patterns: %s", ctx_files_patterns)
    if ctx_files_patterns:
        assert len(ctx_files_patterns) == len(
            id_prefixes), "ctx len={} pref leb={}".format(
                len(ctx_files_patterns), len(id_prefixes))
    else:
        assert (
            index_path
        ), "Either encoded_ctx_files or index_path parameter should be set."

    input_paths = []
    path_id_prefixes = []
    for i, pattern in enumerate(ctx_files_patterns):
        pattern_files = glob.glob(pattern)
        pattern_id_prefix = id_prefixes[i]
        input_paths.extend(pattern_files)
        path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))

    logger.info("Embeddings files id prefixes: %s", path_id_prefixes)

    # reuse a serialized index when available; otherwise build it from the
    # encoded ctx files (and serialize it if index_path was given)
    if index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths,
                                     index_buffer_sz,
                                     path_id_prefixes=path_id_prefixes)
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(),
                                                cfg.n_docs)

    # we no longer need the index
    retriever = None

    all_passages = {}
    for ctx_src in ctx_sources:
        ctx_src.load_data_to(all_passages)

    if len(all_passages) == 0:
        raise RuntimeError(
            "No passages data found. Please specify ctx_file param properly.")

    # score the retrieved docs against gold answers
    if cfg.validate_as_tables:
        questions_doc_hits = validate_tables(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )
    else:
        questions_doc_hits = validate(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )

    if cfg.out_file:
        save_results(
            all_passages,
            questions,
            question_answers,
            top_ids_and_scores,
            questions_doc_hits,
            cfg.out_file,
        )

    if cfg.kilt_out_file:
        # KILT conversion requires a KiltCsvCtxSrc among the ctx sources
        kilt_ctx = next(
            iter([
                ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)
            ]), None)
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        assert hasattr(cfg, "kilt_out_file")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file,
                                 cfg.kilt_out_file)
Esempio n. 17
0
def main(args):
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(
        args.encoder_model_type, args, inference_only=True)

    encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')

    prefix_len = len('question_model.')
    question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                              key.startswith('question_model.')}
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info('Encoder vector_size=%d', vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, args.index_buffer)
    else:
        index = DenseFlatIndexer(vector_size, args.index_buffer)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)

    index_path = "_".join(input_paths[0].split("_")[:-1])
    if args.save_or_load_index and (os.path.exists(index_path) or os.path.exists(index_path + ".index.dpr")):
        retriever.index.deserialize_from(index_path)
    else:
        logger.info('Reading all passages data from files: %s', input_paths)
        retriever.index.index_data(input_paths)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)
    # get questions & answers
    questions = []
    question_ids = []

    for ds_item in parse_qa_csv_file(args.qa_file):
        question_id, question = ds_item
        question_ids.append(question_id)
        questions.append(question)

    questions_tensor = retriever.generate_question_vectors(questions)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(
        questions_tensor.numpy(), args.n_docs)

    # all_passages = load_passages(args.ctx_file)

    # if len(all_passages) == 0:
    #     raise RuntimeError(
    #         'No passages data found. Please specify ctx_file param properly.')

    # questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers,
    #                               args.match)

    if args.out_file:
        save_results(question_ids,
                     top_ids_and_scores, args.out_file)
Esempio n. 18
0
def main(cfg: DictConfig):
    """End-to-end dense retrieval evaluation driven by a Hydra config.

    Loads a bi-encoder checkpoint, encodes the questions of cfg.qa_dataset,
    retrieves the top cfg.n_docs passages from either a local FAISS index or
    a remote RPC index, validates hits against gold answers and optionally
    saves results in plain and/or KILT format.
    """
    cfg = setup_cfg_gpu(cfg)
    saved_state = load_states_from_checkpoint(cfg.model_file)

    # Restore encoder hyper-params from the checkpoint into the config.
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG (after gpu  configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)

    logger.info("Loading saved model state ...")
    encoder.load_state(saved_state, strict=False)

    # Select which sub-encoder handles the queries: a custom attribute path
    # or the standard question model of the bi-encoder.
    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device,
                                            cfg.n_gpu, cfg.local_rank,
                                            cfg.fp16)
    encoder.eval()

    model_to_load = get_model_obj(encoder)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    questions_text = []
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    total_queries = len(qa_src)
    for i in range(total_queries):
        qa_sample = qa_src[i]
        question, answers = qa_sample.query, qa_sample.answers
        questions.append(question)
        question_answers.append(answers)

    logger.info("questions len %d", len(questions))
    logger.info("questions_text len %d", len(questions_text))

    if cfg.rpc_retriever_cfg_file:
        index_buffer_sz = 1000
        retriever = DenseRPCRetriever(
            encoder,
            cfg.batch_size,
            tensorizer,
            cfg.rpc_retriever_cfg_file,
            vector_size,
            use_l2_conversion=cfg.use_l2_conversion,
        )
    else:
        index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
        logger.info("Local Index class %s ", type(index))
        index_buffer_sz = index.buffer_size
        index.init_index(vector_size)
        retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer,
                                        index)

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, query_token=qa_src.special_query_token)

    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    # BUG FIX: context sources are needed later for validation, result saving
    # and KILT conversion even when a prebuilt index is deserialized.
    # Previously they were instantiated only inside the indexing branch,
    # so loading an existing index crashed with NameError on ctx_sources.
    id_prefixes = []
    ctx_sources = []
    for ctx_src_name in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src_name])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)
        logger.info("ctx_sources: %s", type(ctx_src))

    logger.info("id_prefixes per dataset: %s", id_prefixes)

    index_path = cfg.index_path
    if cfg.rpc_retriever_cfg_file and cfg.rpc_index_id:
        retriever.load_index(cfg.rpc_index_id)
    elif index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        # send data for indexing
        ctx_files_patterns = cfg.encoded_ctx_files

        logger.info("ctx_files_patterns: %s", ctx_files_patterns)
        if ctx_files_patterns:
            # one id prefix per embeddings file pattern
            assert len(ctx_files_patterns) == len(
                id_prefixes), "ctx len={} pref len={}".format(
                    len(ctx_files_patterns), len(id_prefixes))
        else:
            assert (
                index_path or cfg.rpc_index_id
            ), "Either encoded_ctx_files or index_path or rpc_index_id parameter should be set."

        input_paths = []
        path_id_prefixes = []
        for i, pattern in enumerate(ctx_files_patterns):
            pattern_files = glob.glob(pattern)
            pattern_id_prefix = id_prefixes[i]
            input_paths.extend(pattern_files)
            path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))
        logger.info("Embeddings files id prefixes: %s", path_id_prefixes)
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths,
                                     index_buffer_sz,
                                     path_id_prefixes=path_id_prefixes)
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    top_results_and_scores = retriever.get_top_docs(questions_tensor.numpy(),
                                                    cfg.n_docs)

    if cfg.use_rpc_meta:
        # RPC server returns passage meta itself -- no local passage store needed
        questions_doc_hits = validate_from_meta(
            question_answers,
            top_results_and_scores,
            cfg.validation_workers,
            cfg.match,
            cfg.rpc_meta_compressed,
        )
        if cfg.out_file:
            save_results_from_meta(
                questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
                cfg.rpc_meta_compressed,
            )
    else:
        all_passages = get_all_passages(ctx_sources)
        if cfg.validate_as_tables:

            questions_doc_hits = validate_tables(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )

        else:
            questions_doc_hits = validate(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )

        if cfg.out_file:
            save_results(
                all_passages,
                questions_text if questions_text else questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
            )

    if cfg.kilt_out_file:
        kilt_ctx = next(
            iter([
                ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)
            ]), None)
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        assert hasattr(cfg, "kilt_out_file")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file,
                                 cfg.kilt_out_file)
Esempio n. 19
0
def main(args):
    """Dense retrieval evaluation with optional ColBERT late-interaction rescoring.

    Supports standard DPR bi-encoders plus the 'hf_attention' and 'colbert'
    encoder types (which are not restored from a DPR checkpoint). For ColBERT,
    an extra rescoring pass runs over a memory-mapped token-embedding matrix,
    and both approximate and rescored results are validated/saved.
    """
    is_colbert = args.encoder_model_type == 'colbert'
    # 'hf_attention' and 'colbert' models are built from scratch, not from a
    # DPR bi-encoder checkpoint.
    uses_checkpoint = args.encoder_model_type not in ('hf_attention', 'colbert')

    # get questions & answers
    questions = []
    question_answers = []
    for ds_item in parse_qa_csv_file(args.qa_file):
        question, answers = ds_item
        questions.append(question)
        question_answers.append(answers)

    if uses_checkpoint:
        saved_state = load_states_from_checkpoint(args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)

    if uses_checkpoint:
        # only the question side of the bi-encoder is needed for retrieval
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank,
                                            args.fp16)
    encoder.eval()

    if uses_checkpoint:
        # load question-encoder weights from the checkpoint
        model_to_load = get_model_obj(encoder)
        logger.info('Loading saved model state ...')

        prefix_len = len('question_model.')
        question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                                  key.startswith('question_model.')}
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()
        logger.info('Encoder vector_size=%d', vector_size)
    else:
        vector_size = 16  # fixed embedding size used by the non-DPR encoders

    index_buffer_sz = args.index_buffer

    if args.index_type == 'hnsw':
        index = DenseHNSWFlatIndexer(vector_size, index_buffer_sz)
        index_buffer_sz = -1  # encode all at once
    elif args.index_type == 'custom':
        index = CustomIndexer(vector_size, index_buffer_sz)
    else:
        index = DenseFlatIndexer(vector_size)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index, is_colbert)

    # index all passages
    input_paths = glob.glob(args.encoded_ctx_file)
    print(input_paths)
    # Index files share the embeddings files' name minus the shard suffix.
    index_path = "_".join(input_paths[0].split("_")[:-1])
    memmap_path = 'memmap.npy'
    print(args.save_or_load_index, os.path.exists(index_path), index_path)

    def _colbert_memmap():
        # NOTE(review): shape (21015324, 250, 16) is hard-coded for the full
        # Wikipedia passage set -- confirm it matches the encoded data.
        return np.memmap(memmap_path, dtype=np.float32, mode='w+', shape=(21015324, 250, 16))

    if args.save_or_load_index and os.path.exists(index_path + '.index.dpr'):
        retriever.index.deserialize_from(index_path)
        memmap = _colbert_memmap() if is_colbert else None
    else:
        logger.info('Reading all passages data from files: %s', input_paths)
        memmap = _colbert_memmap() if is_colbert else None
        retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz, memmap=memmap)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)

    questions_tensor = retriever.generate_question_vectors(questions)

    # get top k results (approximate scores straight from the index)
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs, is_colbert=is_colbert)

    with open('approx_scores.pkl', 'wb') as f:
        pickle.dump(top_ids_and_scores, f)

    # free the FAISS index memory before the (optional) rescoring pass
    retriever.index.index.reset()

    top_ids_and_scores_colbert = None
    if is_colbert:
        logger.info('Colbert score')
        top_ids_and_scores_colbert = retriever.colbert_search(questions_tensor.numpy(), memmap, top_ids_and_scores, args.n_docs)
        # BUG FIX: dump the ColBERT scores only when they exist; previously
        # this dump was unconditional and raised NameError for every
        # non-ColBERT encoder type.
        with open('colbert_scores.pkl', 'wb') as f:
            pickle.dump(top_ids_and_scores_colbert, f)

    all_passages = load_passages(args.ctx_file)

    if len(all_passages) == 0:
        raise RuntimeError('No passages data found. Please specify ctx_file param properly.')

    questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers,
                                  args.match)
    if is_colbert:
        questions_doc_hits_colbert = validate(all_passages, question_answers, top_ids_and_scores_colbert, args.validation_workers,
                                              args.match)

    if args.out_file:
        save_results(all_passages, questions, question_answers, top_ids_and_scores, questions_doc_hits, args.out_file)
        if is_colbert:
            save_results(all_passages, questions, question_answers, top_ids_and_scores_colbert, questions_doc_hits_colbert, args.out_file + 'colbert')
Esempio n. 20
0
def main(cfg: DictConfig):
    """Encode one shard of a passage source with the bi-encoder and pickle it.

    Restores the ctx (or question) encoder from cfg.model_file, loads all
    passages from cfg.ctx_src, selects this worker's shard
    (cfg.shard_id of cfg.num_shards) and writes the generated
    (id, vector) data to "<cfg.out_file>_<shard_id>".
    """
    assert cfg.model_file, "Please specify encoder checkpoint as model_file param"
    assert cfg.ctx_src, "Please specify passages source as ctx_src param"
    print(os.getcwd())
    cfg = setup_cfg_gpu(cfg)

    saved_state = load_states_from_checkpoint(cfg.model_file)
    # Restore encoder hyper-params from the checkpoint into the config.
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)

    # Default to the context encoder; the question encoder can also be used.
    encoder = encoder.ctx_model if cfg.encoder_type == "ctx" else encoder.question_model

    encoder, _ = setup_for_distributed_mode(
        encoder,
        None,
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
    )
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")
    logger.debug("saved model keys =%s", saved_state.model_dict.keys())

    # Checkpoint stores the full bi-encoder; keep only 'ctx_model.*' entries
    # and strip the prefix so keys match the standalone encoder.
    prefix_len = len("ctx_model.")
    ctx_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("ctx_model.")
    }
    model_to_load.load_state_dict(ctx_state)

    logger.info("reading data source: %s", cfg.ctx_src)

    ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src])
    all_passages_dict = {}
    ctx_src.load_data_to(all_passages_dict)
    # list(dict.items()) instead of an equivalent hand-written comprehension
    all_passages = list(all_passages_dict.items())

    # select this worker's contiguous shard of passages
    shard_size = math.ceil(len(all_passages) / cfg.num_shards)
    start_idx = cfg.shard_id * shard_size
    end_idx = start_idx + shard_size

    logger.info(
        "Producing encodings for passages range: %d to %d (out of total %d)",
        start_idx,
        end_idx,
        len(all_passages),
    )
    shard_passages = all_passages[start_idx:end_idx]

    data = gen_ctx_vectors(cfg, shard_passages, encoder, tensorizer, True)

    file = cfg.out_file + "_" + str(cfg.shard_id)
    pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True)
    # lazy %-args (consistent with the rest of the file) instead of eager
    # interpolation inside the logging call
    logger.info("Writing results to %s", file)
    with open(file, mode="wb") as f:
        pickle.dump(data, f)

    logger.info("Total passages processed %d. Written to %s", len(data), file)
Esempio n. 21
0
    if args.retrieval_type=="tfidf":
        import drqa_retriever as retriever
        ranker = retriever.get_class('tfidf')(tfidf_path=args.tfidf_path)
        top_ids_and_scores = []
        for question in questions:
            psg_ids, scores = ranker.closest_docs(question, args.n_docs)
            top_ids_and_scores.append((psg_ids, scores))
    else:
        from dpr.models import init_biencoder_components
        from dpr.utils.data_utils import Tensorizer
        from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint
        from dpr.indexer.faiss_indexers import DenseIndexer, DenseHNSWFlatIndexer, DenseFlatIndexer
        from dense_retriever import DenseRetriever

        saved_state = load_states_from_checkpoint(args.dpr_model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)
        tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
        encoder = encoder.question_model
        setup_args_gpu(args)
        encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                                args.local_rank,
                                                args.fp16)
        encoder.eval()

        # load weights from the model file
        model_to_load = get_model_obj(encoder)
        prefix_len = len('question_model.')
        question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                                key.startswith('question_model.')}
        model_to_load.load_state_dict(question_encoder_state)
def get_bert_reader_components(cfg, inference_only: bool = False, **kwargs):
    """Build the tensorizer, InterPassageReader model and (optionally) optimizer.

    The reader is composed of two BERT encoders: a main per-passage encoder
    (optionally initialized from a pretrained bi-encoder checkpoint and
    optionally frozen) and an inter-passage encoder; their layer counts come
    from cfg.encoder.num_layers[0] and [1] respectively.

    Returns:
        (tensorizer, reader, optimizer) — optimizer is None when
        inference_only is True.
    """
    dropout = cfg.encoder.dropout if hasattr(cfg.encoder, "dropout") else 0.0
    num_layers = cfg.encoder.num_layers
    logger.info(
        f"Initializing InterPassageReader with number of layers for two components: {num_layers}."
    )

    encoder = HFBertEncoderWithNumLayers.init_encoder(
        cfg.encoder.pretrained_model_cfg,
        projection_dim=cfg.encoder.projection_dim,
        dropout=dropout,
        pretrained=cfg.encoder.pretrained,
        num_hidden_layers=num_layers[0],
        **kwargs)

    # Load state dict for the main encoder
    if cfg.encoder.pretrained_encoder_file is not None:
        # Two different pretrained sources would be ambiguous.
        assert cfg.encoder.pretrained_file is None, "Ambiguous pretrained model"
        assert cfg.encoder.encoder_initialize_from in [
            "question_encoder", "ctx_encoder"
        ]
        saved_state = load_states_from_checkpoint(
            cfg.encoder.pretrained_encoder_file).model_dict
        # Checkpoint keys are prefixed with the bi-encoder side they belong to.
        key_start = "question_model" if cfg.encoder.encoder_initialize_from == "question_encoder" else "ctx_model"

        new_saved_state = {}
        for key, param in saved_state.items():
            if not key.startswith(key_start):
                continue

            # Remove "xxx_model." part
            key = ".".join(key.split(".")[1:])

            # Truncate embeddings
            # If the pretrained embedding table is larger than this model's
            # (e.g. extra special tokens), optionally keep only a prefix.
            if "embeddings" in key and cfg.encoder.allow_embedding_size_mismatch:
                model_embeddings = [
                    v for k, v in encoder.named_parameters() if k == key
                ]
                assert len(model_embeddings) == 1
                model_embeddings = model_embeddings[0]

                # Pretrained table must be at least as large as the model's.
                assert len(param) >= len(model_embeddings)
                if len(param) > len(model_embeddings):
                    logger.info(
                        f"Truncating pretrained embedding ('{key}') size from {len(param)} to {len(model_embeddings)} "
                        f"by simply selecting the first {len(model_embeddings)} embeddings."
                    )
                    param = param[:len(model_embeddings)]

            new_saved_state[key] = param

        # Load to the model
        encoder.load_state(new_saved_state)

        # Freeze main encoder
        if cfg.encoder.encoder_freeze:
            for _, param in encoder.named_parameters():
                param.requires_grad = False

    elif cfg.encoder.encoder_freeze:
        # Freezing only makes sense for an encoder already trained on QA.
        raise ValueError(
            "You should not freeze a model that is not trained on a QA task.")

    # Second encoder attends across passages; no projection head.
    inter_passage_encoder = HFBertEncoderWithNumLayers.init_encoder(
        cfg.encoder.pretrained_model_cfg,
        projection_dim=0,
        dropout=dropout,
        pretrained=cfg.encoder.pretrained,
        num_hidden_layers=num_layers[1],
        **kwargs)

    reader = InterPassageReader(encoder,
                                inter_passage_encoder,
                                hidden_size=None)

    optimizer = (get_optimizer(
        reader,
        learning_rate=cfg.train.learning_rate,
        encoder_learning_rate=cfg.encoder.encoder_learning_rate,
        adam_eps=cfg.train.adam_eps,
        weight_decay=cfg.train.weight_decay,
    ) if not inference_only else None)

    tensorizer = get_bert_tensorizer(cfg)
    return tensorizer, reader, optimizer
Esempio n. 23
0
def setup_dpr(model_file,
              ctx_file,
              encoded_ctx_file,
              hnsw_index=False,
              save_or_load_index=False):
    global retriever
    global all_passages
    global answer_cache
    global answer_cache_path
    parameter_setting = model_file + ctx_file + encoded_ctx_file
    answer_cache_path = hashlib.sha1(
        parameter_setting.encode("utf-8")).hexdigest()
    if os.path.exists(answer_cache_path):
        answer_cache = pickle.load(open(answer_cache_path, 'rb'))
    else:
        answer_cache = {}
    parser = argparse.ArgumentParser()
    add_encoder_params(parser)
    add_tokenizer_params(parser)
    add_cuda_params(parser)

    args = parser.parse_args()
    args.model_file = model_file
    args.ctx_file = ctx_file
    args.encoded_ctx_file = encoded_ctx_file
    args.hnsw_index = hnsw_index
    args.save_or_load_index = save_or_load_index
    args.batch_size = 1  # TODO

    setup_args_gpu(args)
    print_args(args)

    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    prefix_len = len("question_model.")
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("question_model.")
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, 50000)
    else:
        index = DenseFlatIndexer(vector_size, 50000,
                                 "IVF65536,PQ64")  #IVF65536

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)

    index_path = "_".join(input_paths[0].split("_")[:-1])
    if args.save_or_load_index and (os.path.exists(index_path) or
                                    os.path.exists(index_path + ".index.dpr")):
        retriever.index.deserialize_from(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index.index_data(input_paths)

        if args.save_or_load_index:
            retriever.index.serialize(index_path)
        # get questions & answers

    all_passages = load_passages(args.ctx_file)