def get_query_encoder(args):
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    print_args(args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
    encoder = encoder.question_model
    # encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
    #                                         args.local_rank,
    #                                         args.fp16,
    #                                         args.fp16_opt_level)
    # encoder.eval()
    logger.info(args.__dict__)

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')
    logger.debug('saved model keys =%s', saved_state.model_dict.keys())

    prefix_len = len('question_model.')
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith('question_model.')
    }
    model_to_load.load_state_dict(question_encoder_state)
    return model_to_load
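# Usage sketch (not part of the original source): build `args` the same way the
# other entry points here do, then load the question encoder. The option helpers
# are the ones used in setup_dpr() further down.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    add_encoder_params(parser)
    add_tokenizer_params(parser)
    add_cuda_params(parser)
    args = parser.parse_args()
    setup_args_gpu(args)
    question_encoder = get_query_encoder(args)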
def _load_saved_state(self, saved_state: CheckpointStateOFA):
    epoch = saved_state.epoch
    # offset is currently ignored since all checkpoints are made after full epochs
    offset = saved_state.offset
    if offset == 0:  # epoch has been completed
        epoch += 1
    logger.info("Loading checkpoint @ batch=%s and epoch=%s", offset, epoch)

    if self.cfg.ignore_checkpoint_offset:
        self.start_epoch = 0
        self.start_batch = 0
    else:
        self.start_epoch = epoch
        # TODO: offset doesn't work for multiset currently
        self.start_batch = 0  # offset

    model_to_load = get_model_obj(self.model)
    logger.info("Loading saved model state ...")
    model_to_load.load_state(saved_state)

    if not self.cfg.ignore_checkpoint_optimizer:
        if saved_state.biencoder_optimizer_dict:
            logger.info("Loading saved biencoder optimizer state ...")
            self.biencoder_optimizer.load_state_dict(saved_state.biencoder_optimizer_dict)
        if saved_state.biencoder_scheduler_dict:
            self.biencoder_scheduler_state = saved_state.biencoder_scheduler_dict
        if saved_state.reader_optimizer_dict:
            logger.info("Loading saved reader optimizer state ...")
            self.reader_optimizer.load_state_dict(saved_state.reader_optimizer_dict)
        if saved_state.reader_scheduler_dict:
            self.reader_scheduler_state = saved_state.reader_scheduler_dict
def main(args):
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    print_args(args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
    encoder = encoder.ctx_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank, args.fp16, args.fp16_opt_level)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')
    logger.debug('saved model keys =%s', saved_state.model_dict.keys())

    prefix_len = len('ctx_model.')
    ctx_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith('ctx_model.')
    }
    model_to_load.load_state_dict(ctx_state)

    logger.info('reading data from file=%s', args.ctx_file)

    rows = []
    with open(args.ctx_file) as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        # file format: doc_id, doc_text, title
        rows.extend([(row[0], row[1], row[2]) for row in reader if row[0] != 'id'])

    # note: int() truncates, so the trailing len(rows) % num_shards rows are not
    # covered unless num_shards divides len(rows) evenly
    shard_size = int(len(rows) / args.num_shards)
    start_idx = args.shard_id * shard_size
    end_idx = start_idx + shard_size

    logger.info('Producing encodings for passages range: %d to %d (out of total %d)',
                start_idx, end_idx, len(rows))
    rows = rows[start_idx:end_idx]
    data = gen_ctx_vectors(rows, encoder, tensorizer, True)

    file = args.out_file + '_' + str(args.shard_id)
    pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True)
    logger.info('Writing results to %s' % file)
    with open(file, mode='wb') as f:
        pickle.dump(data, f)

    logger.info('Total passages processed %d. Written to %s', len(data), file)
def __init__(self, name, **config):
    super().__init__(name)
    self.args = argparse.Namespace(**config)

    saved_state = load_states_from_checkpoint(self.args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, self.args)

    tensorizer, encoder, _ = init_biencoder_components(
        self.args.encoder_model_type, self.args, inference_only=True)
    encoder = encoder.question_model
    encoder, _ = setup_for_distributed_mode(
        encoder,
        None,
        self.args.device,
        self.args.n_gpu,
        self.args.local_rank,
        self.args.fp16,
    )
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    prefix_len = len("question_model.")
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("question_model.")
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()

    index_buffer_sz = self.args.index_buffer
    if self.args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size)
        index.deserialize_from(self.args.hnsw_index_path)
    else:
        index = DenseFlatIndexer(vector_size)

    self.retriever = DenseRetriever(encoder, self.args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = self.args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    if not self.args.hnsw_index:
        self.retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz)

    # not needed for now
    self.all_passages = load_passages(self.args.ctx_file)

    self.KILT_mapping = None
    if self.args.KILT_mapping:
        self.KILT_mapping = pickle.load(open(self.args.KILT_mapping, "rb"))
def get_retriever(args):
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
    encoder = encoder.question_model
    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank, args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')

    prefix_len = len('question_model.')
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith('question_model.')
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info('Encoder vector_size=%d', vector_size)

    index_buffer_sz = args.index_buffer
    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size)
        index_buffer_sz = -1  # encode all at once
    else:
        index = DenseFlatIndexer(vector_size)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index, args.device)

    # index all passages
    if len(args.encoded_ctx_file) > 0:
        ctx_files_pattern = args.encoded_ctx_file
        input_paths = glob.glob(ctx_files_pattern)
        index_path = "_".join(input_paths[0].split("_")[:-1])
        if args.save_or_load_index and os.path.exists(index_path):
            retriever.index.deserialize(index_path)
        else:
            logger.info('Reading all passages data from files: %s', input_paths)
            retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz)
            if args.save_or_load_index:
                retriever.index.serialize(index_path)

    return retriever
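# Hedged usage sketch for get_retriever(): encode a question and fetch the top
# passages, mirroring the calls made in the retrieval mains below. The question
# text and k=10 are placeholders.
retriever = get_retriever(args)
questions_tensor = retriever.generate_question_vectors(['who wrote hamlet?'])
top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), 10)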
def load_saved_state_into_model(encoder, saved_state, prefix="ctx_model."):
    # note: `saved_state` must be passed in (the original snippet referenced an
    # undefined global); obtain it via load_states_from_checkpoint(model_file)
    encoder.eval()
    model_to_load = get_model_obj(encoder)
    prefix_len = len(prefix)
    encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(prefix)
    }
    model_to_load.load_state_dict(encoder_state)
    return model_to_load
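# Example (sketch): with the fixed signature, the helper can load either tower
# of a biencoder checkpoint by switching the prefix. `biencoder` is a
# placeholder for the model returned by init_biencoder_components().
saved_state = load_states_from_checkpoint(args.model_file)
ctx_encoder = load_saved_state_into_model(biencoder.ctx_model, saved_state, prefix='ctx_model.')
question_encoder = load_saved_state_into_model(biencoder.question_model, saved_state,
                                               prefix='question_model.')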
def main(args):
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)
    print_args(args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
    encoder = encoder.ctx_model
    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank, args.fp16, args.fp16_opt_level)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')
    logger.debug('saved model keys =%s', saved_state.model_dict.keys())

    prefix_len = len('ctx_model.')
    ctx_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith('ctx_model.')
    }
    model_to_load.load_state_dict(ctx_state)

    logger.info('reading data from file=%s', args.ctx_file)

    rows = []
    with open(args.ctx_file) as jsonfile:
        blob = json.load(jsonfile)
        for paper in blob.values():
            s2_id = paper["paper_id"]
            title = paper.get("title") or ""
            abstract = paper.get("abstract") or ""
            rows.append((s2_id, abstract, title))

    data = gen_ctx_vectors(rows, encoder, tensorizer, True)

    file = args.out_file
    pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True)
    logger.info('Writing results to %s' % file)
    with open(file, "w+") as f:
        for (s2_id, vec) in data:
            out = {"paper_id": s2_id, "embedding": vec.tolist()}
            f.write(json.dumps(out) + "\n")

    logger.info('Total passages processed %d. Written to %s', len(data), file)
def _load_saved_state(self, saved_state: CheckpointState):
    epoch = saved_state.epoch
    offset = saved_state.offset
    if offset == 0:  # epoch has been completed
        epoch += 1
    # logger.info('Loading checkpoint @ batch=%s and epoch=%s', offset, epoch)
    # self.start_epoch = epoch
    # self.start_batch = offset

    model_to_load = get_model_obj(self.reader)
    if saved_state.model_dict:
        logger.info('Loading model weights from saved state ...')
        model_to_load.load_state_dict(saved_state.model_dict)
def _save_checkpoint(self, scheduler, epoch: int, offset: int) -> str:
    args = self.args
    model_to_save = get_model_obj(self.reader)
    cp = os.path.join(args.output_dir,
                      args.checkpoint_file_name + '.' + str(epoch) +
                      ('.' + str(offset) if offset > 0 else ''))

    meta_params = get_encoder_params_state(args)
    state = CheckpointState(model_to_save.state_dict(),
                            self.optimizer.state_dict(),
                            scheduler.state_dict(),
                            offset,
                            epoch,
                            meta_params)
    torch.save(state._asdict(), cp)
    return cp
def _save_checkpoint(self, scheduler, epoch: int, offset: int) -> str:
    cfg = self.cfg
    model_to_save = get_model_obj(self.biencoder)
    cp = os.path.join(cfg.output_dir, cfg.checkpoint_file_name + "." + str(epoch))
    meta_params = get_encoder_params_state_from_cfg(cfg)
    state = CheckpointState(
        model_to_save.get_state_dict(),
        self.optimizer.state_dict(),
        scheduler.state_dict(),
        offset,
        epoch,
        meta_params,
    )
    torch.save(state._asdict(), cp)
    logger.info("Saved checkpoint at %s", cp)
    return cp
def __init__(self, args, model_file):
    self.args = args

    saved_state = load_states_from_checkpoint(model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, reader, optimizer = init_reader_components(args.encoder_model_type, args)
    tensorizer.pad_to_max = False
    del optimizer

    reader = reader.cuda()
    reader = reader.eval()
    self.reader = reader
    self.tensorizer = tensorizer

    model_to_load = get_model_obj(self.reader)
    model_to_load.load_state_dict(saved_state.model_dict)
def _load_saved_state(self, saved_state: CheckpointState):
    epoch = saved_state.epoch
    offset = saved_state.offset
    if offset == 0:  # epoch has been completed
        epoch += 1
    logger.info("Loading checkpoint @ batch=%s and epoch=%s", offset, epoch)
    self.start_epoch = epoch
    self.start_batch = offset

    model_to_load = get_model_obj(self.reader)
    if saved_state.model_dict:
        logger.info("Loading model weights from saved state ...")
        model_to_load.load_state_dict(saved_state.model_dict)

    if saved_state.optimizer_dict:
        logger.info("Loading saved optimizer state ...")
        self.optimizer.load_state_dict(saved_state.optimizer_dict)
    self.scheduler_state = saved_state.scheduler_dict
def _load_saved_state(self, saved_state: CheckpointState):
    epoch = saved_state.epoch
    offset = saved_state.offset
    if offset == 0:  # epoch has been completed
        epoch += 1
    logger.info('Loading checkpoint @ batch=%s and epoch=%s', offset, epoch)
    self.start_epoch = epoch
    self.start_batch = offset

    model_to_load = get_model_obj(self.biencoder)
    logger.info('Loading saved model state ...')
    # set strict=False if you use an extra projection
    model_to_load.load_state_dict(saved_state.model_dict, strict=False)

    if saved_state.optimizer_dict:
        logger.info('Loading saved optimizer state ...')
        self.optimizer.load_state_dict(saved_state.optimizer_dict)

    if saved_state.scheduler_dict:
        self.scheduler_state = saved_state.scheduler_dict
def _save_checkpoint(self, scheduler, epoch: int, offset: int) -> str:
    args = self.args
    model_to_save = get_model_obj(self.biencoder)
    cp = os.path.join(
        args.output_dir,
        args.checkpoint_file_name + "." + str(epoch) +
        ("." + str(offset) if offset > 0 else ""),
    )
    meta_params = get_encoder_params_state(args)
    state = CheckpointState(
        model_to_save.state_dict(),
        self.optimizer.state_dict(),
        scheduler.state_dict(),
        offset,
        epoch,
        meta_params,
    )
    torch.save(state._asdict(), cp)
    logger.info("Saved checkpoint at %s", cp)
    return cp
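# For reference, a CheckpointState consistent with the call sites above: it is
# constructed positionally with six fields and read back via attribute access
# (saved_state.model_dict, .optimizer_dict, .scheduler_dict, .offset, .epoch,
# .encoder_params). A sketch; the real definition lives in dpr/utils/model_utils.py.
import collections

CheckpointState = collections.namedtuple(
    "CheckpointState",
    ["model_dict", "optimizer_dict", "scheduler_dict", "offset", "epoch", "encoder_params"],
)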
def setup_dpr(model_file, ctx_file, encoded_ctx_file, hnsw_index=False, save_or_load_index=False):
    global retriever
    global all_passages
    global answer_cache
    global answer_cache_path

    parameter_setting = model_file + ctx_file + encoded_ctx_file
    answer_cache_path = hashlib.sha1(parameter_setting.encode("utf-8")).hexdigest()
    if os.path.exists(answer_cache_path):
        answer_cache = pickle.load(open(answer_cache_path, 'rb'))
    else:
        answer_cache = {}

    parser = argparse.ArgumentParser()
    add_encoder_params(parser)
    add_tokenizer_params(parser)
    add_cuda_params(parser)
    args = parser.parse_args()

    args.model_file = model_file
    args.ctx_file = ctx_file
    args.encoded_ctx_file = encoded_ctx_file
    args.hnsw_index = hnsw_index
    args.save_or_load_index = save_or_load_index
    args.batch_size = 1  # TODO

    setup_args_gpu(args)
    print_args(args)

    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
    encoder = encoder.question_model
    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank, args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    prefix_len = len("question_model.")
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("question_model.")
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, 50000)
    else:
        index = DenseFlatIndexer(vector_size, 50000, "IVF65536,PQ64")  # faiss index factory spec

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    index_path = "_".join(input_paths[0].split("_")[:-1])
    if args.save_or_load_index and (os.path.exists(index_path) or os.path.exists(index_path + ".index.dpr")):
        retriever.index.deserialize_from(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index.index_data(input_paths)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)

    all_passages = load_passages(args.ctx_file)
def main(cfg: DictConfig):
    cfg = setup_cfg_gpu(cfg)
    logger.info("CFG (after gpu configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True)

    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device, cfg.n_gpu,
                                            cfg.local_rank, cfg.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")
    encoder_prefix = (encoder_path if encoder_path else "question_model") + "."
    prefix_len = len(encoder_prefix)

    logger.info("Encoder state prefix %s", encoder_prefix)
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(encoder_prefix)
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    for ds_item in qa_src.data:
        question, answers = ds_item.query, ds_item.answers
        questions.append(question)
        question_answers.append(answers)

    index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
    logger.info("Index class %s ", type(index))
    index_buffer_sz = index.buffer_size
    index.init_index(vector_size)
    retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(questions,
                                                           query_token=qa_src.special_query_token)

    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    id_prefixes = []
    ctx_sources = []
    for ctx_src in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)

    logger.info("id_prefixes per dataset: %s", id_prefixes)

    # index all passages
    ctx_files_patterns = cfg.encoded_ctx_files
    index_path = cfg.index_path

    logger.info("ctx_files_patterns: %s", ctx_files_patterns)
    if ctx_files_patterns:
        assert len(ctx_files_patterns) == len(id_prefixes), \
            "ctx len={} pref len={}".format(len(ctx_files_patterns), len(id_prefixes))
    else:
        assert index_path, "Either encoded_ctx_files or index_path parameter should be set."

    input_paths = []
    path_id_prefixes = []
    for i, pattern in enumerate(ctx_files_patterns):
        pattern_files = glob.glob(pattern)
        pattern_id_prefix = id_prefixes[i]
        input_paths.extend(pattern_files)
        path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))

    logger.info("Embeddings files id prefixes: %s", path_id_prefixes)

    if index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths, index_buffer_sz,
                                     path_id_prefixes=path_id_prefixes)
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), cfg.n_docs)

    # we no longer need the index
    retriever = None

    all_passages = {}
    for ctx_src in ctx_sources:
        ctx_src.load_data_to(all_passages)

    if len(all_passages) == 0:
        raise RuntimeError("No passages data found. Please specify ctx_file param properly.")

    if cfg.validate_as_tables:
        questions_doc_hits = validate_tables(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )
    else:
        questions_doc_hits = validate(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )

    if cfg.out_file:
        save_results(
            all_passages,
            questions,
            question_answers,
            top_ids_and_scores,
            questions_doc_hits,
            cfg.out_file,
        )

    if cfg.kilt_out_file:
        kilt_ctx = next(iter([ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)]), None)
        if not kilt_ctx:
            raise RuntimeError("No KILT compatible context file provided")
        assert hasattr(cfg, "kilt_out_file")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file, cfg.kilt_out_file)
def main(args):
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
    encoder = encoder.question_model
    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank, args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')

    prefix_len = len('question_model.')
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith('question_model.')
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info('Encoder vector_size=%d', vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, args.index_buffer)
    else:
        index = DenseFlatIndexer(vector_size, args.index_buffer)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    index_path = "_".join(input_paths[0].split("_")[:-1])
    if args.save_or_load_index and (os.path.exists(index_path) or os.path.exists(index_path + ".index.dpr")):
        retriever.index.deserialize_from(index_path)
    else:
        logger.info('Reading all passages data from files: %s', input_paths)
        retriever.index.index_data(input_paths)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)

    # get questions & answers
    questions = []
    question_ids = []
    for ds_item in parse_qa_csv_file(args.qa_file):
        question_id, question = ds_item
        question_ids.append(question_id)
        questions.append(question)

    questions_tensor = retriever.generate_question_vectors(questions)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs)

    # all_passages = load_passages(args.ctx_file)
    # if len(all_passages) == 0:
    #     raise RuntimeError('No passages data found. Please specify ctx_file param properly.')
    # questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores,
    #                               args.validation_workers, args.match)

    if args.out_file:
        save_results(question_ids, top_ids_and_scores, args.out_file)
def main(cfg: DictConfig):
    cfg = setup_cfg_gpu(cfg)

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG (after gpu configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True)

    logger.info("Loading saved model state ...")
    encoder.load_state(saved_state, strict=False)

    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device, cfg.n_gpu,
                                            cfg.local_rank, cfg.fp16)
    encoder.eval()

    model_to_load = get_model_obj(encoder)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    questions_text = []
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    total_queries = len(qa_src)
    for i in range(total_queries):
        qa_sample = qa_src[i]
        question, answers = qa_sample.query, qa_sample.answers
        questions.append(question)
        question_answers.append(answers)

    logger.info("questions len %d", len(questions))
    logger.info("questions_text len %d", len(questions_text))

    if cfg.rpc_retriever_cfg_file:
        index_buffer_sz = 1000
        retriever = DenseRPCRetriever(
            encoder,
            cfg.batch_size,
            tensorizer,
            cfg.rpc_retriever_cfg_file,
            vector_size,
            use_l2_conversion=cfg.use_l2_conversion,
        )
    else:
        index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
        logger.info("Local Index class %s ", type(index))
        index_buffer_sz = index.buffer_size
        index.init_index(vector_size)
        retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(questions,
                                                           query_token=qa_src.special_query_token)

    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    index_path = cfg.index_path
    if cfg.rpc_retriever_cfg_file and cfg.rpc_index_id:
        retriever.load_index(cfg.rpc_index_id)
    elif index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        # send data for indexing
        id_prefixes = []
        ctx_sources = []
        for ctx_src in cfg.ctx_datatsets:
            ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
            id_prefixes.append(ctx_src.id_prefix)
            ctx_sources.append(ctx_src)
            logger.info("ctx_sources: %s", type(ctx_src))

        logger.info("id_prefixes per dataset: %s", id_prefixes)

        # index all passages
        ctx_files_patterns = cfg.encoded_ctx_files

        logger.info("ctx_files_patterns: %s", ctx_files_patterns)
        if ctx_files_patterns:
            assert len(ctx_files_patterns) == len(id_prefixes), \
                "ctx len={} pref len={}".format(len(ctx_files_patterns), len(id_prefixes))
        else:
            assert index_path or cfg.rpc_index_id, \
                "Either encoded_ctx_files or index_path or rpc_index_id parameter should be set."

        input_paths = []
        path_id_prefixes = []
        for i, pattern in enumerate(ctx_files_patterns):
            pattern_files = glob.glob(pattern)
            pattern_id_prefix = id_prefixes[i]
            input_paths.extend(pattern_files)
            path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))

        logger.info("Embeddings files id prefixes: %s", path_id_prefixes)
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths, index_buffer_sz,
                                     path_id_prefixes=path_id_prefixes)
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    top_results_and_scores = retriever.get_top_docs(questions_tensor.numpy(), cfg.n_docs)

    if cfg.use_rpc_meta:
        questions_doc_hits = validate_from_meta(
            question_answers,
            top_results_and_scores,
            cfg.validation_workers,
            cfg.match,
            cfg.rpc_meta_compressed,
        )
        if cfg.out_file:
            save_results_from_meta(
                questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
                cfg.rpc_meta_compressed,
            )
    else:
        all_passages = get_all_passages(ctx_sources)
        if cfg.validate_as_tables:
            questions_doc_hits = validate_tables(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )
        else:
            questions_doc_hits = validate(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )
        if cfg.out_file:
            save_results(
                all_passages,
                questions_text if questions_text else questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
            )

    if cfg.kilt_out_file:
        kilt_ctx = next(iter([ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)]), None)
        if not kilt_ctx:
            raise RuntimeError("No KILT compatible context file provided")
        assert hasattr(cfg, "kilt_out_file")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file, cfg.kilt_out_file)
def main(args):
    questions = []
    question_answers = []
    for i, ds_item in enumerate(parse_qa_csv_file(args.qa_file)):
        # if i == 10:
        #     break
        question, answers = ds_item
        questions.append(question)
        question_answers.append(answers)

    is_colbert = args.encoder_model_type == 'colbert'
    # hf_attention and colbert models do not use the standard biencoder checkpoint
    has_dpr_checkpoint = args.encoder_model_type not in ('hf_attention', 'colbert')

    if has_dpr_checkpoint:
        saved_state = load_states_from_checkpoint(args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)

    if has_dpr_checkpoint:
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank, args.fp16)
    encoder.eval()

    if has_dpr_checkpoint:
        # load weights from the model file
        model_to_load = get_model_obj(encoder)
        logger.info('Loading saved model state ...')

        prefix_len = len('question_model.')
        question_encoder_state = {
            key[prefix_len:]: value
            for (key, value) in saved_state.model_dict.items()
            if key.startswith('question_model.')
        }
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()
        logger.info('Encoder vector_size=%d', vector_size)
    else:
        vector_size = 16

    index_buffer_sz = args.index_buffer
    if args.index_type == 'hnsw':
        index = DenseHNSWFlatIndexer(vector_size, index_buffer_sz)
        index_buffer_sz = -1  # encode all at once
    elif args.index_type == 'custom':
        index = CustomIndexer(vector_size, index_buffer_sz)
    else:
        index = DenseFlatIndexer(vector_size)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index, is_colbert)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    # logger.info('Reading all passages data from files: %s', input_paths)
    # memmap = retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz, memmap=is_colbert)
    print(input_paths)
    index_path = "_".join(input_paths[0].split("_")[:-1])
    memmap_path = 'memmap.npy'
    print(args.save_or_load_index, os.path.exists(index_path), index_path)

    if args.save_or_load_index and os.path.exists(index_path + '.index.dpr'):
        retriever.index.deserialize_from(index_path)
        if is_colbert:
            memmap = np.memmap(memmap_path, dtype=np.float32, mode='w+', shape=(21015324, 250, 16))
        else:
            memmap = None
    else:
        logger.info('Reading all passages data from files: %s', input_paths)
        if is_colbert:
            memmap = np.memmap(memmap_path, dtype=np.float32, mode='w+', shape=(21015324, 250, 16))
        else:
            memmap = None
        retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz, memmap=memmap)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)

    # encode questions
    questions_tensor = retriever.generate_question_vectors(questions)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs,
                                                is_colbert=is_colbert)
    with open('approx_scores.pkl', 'wb') as f:
        pickle.dump(top_ids_and_scores, f)
    retriever.index.index.reset()

    if is_colbert:
        logger.info('Colbert score')
        top_ids_and_scores_colbert = retriever.colbert_search(questions_tensor.numpy(), memmap,
                                                              top_ids_and_scores, args.n_docs)

    all_passages = load_passages(args.ctx_file)
    if len(all_passages) == 0:
        raise RuntimeError('No passages data found. Please specify ctx_file param properly.')

    if is_colbert:
        # guard added: top_ids_and_scores_colbert only exists for the colbert model type
        with open('colbert_scores.pkl', 'wb') as f:
            pickle.dump(top_ids_and_scores_colbert, f)

    questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores,
                                  args.validation_workers, args.match)
    if is_colbert:
        questions_doc_hits_colbert = validate(all_passages, question_answers, top_ids_and_scores_colbert,
                                              args.validation_workers, args.match)

    if args.out_file:
        save_results(all_passages, questions, question_answers, top_ids_and_scores,
                     questions_doc_hits, args.out_file)
        if is_colbert:
            save_results(all_passages, questions, question_answers, top_ids_and_scores_colbert,
                         questions_doc_hits_colbert, args.out_file + 'colbert')
def main(cfg: DictConfig):
    assert cfg.model_file, "Please specify encoder checkpoint as model_file param"
    assert cfg.ctx_src, "Please specify passages source as ctx_src param"

    print(os.getcwd())
    cfg = setup_cfg_gpu(cfg)

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True)

    encoder = encoder.ctx_model if cfg.encoder_type == "ctx" else encoder.question_model

    encoder, _ = setup_for_distributed_mode(
        encoder,
        None,
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
    )
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")
    logger.debug("saved model keys =%s", saved_state.model_dict.keys())

    prefix_len = len("ctx_model.")
    ctx_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("ctx_model.")
    }
    model_to_load.load_state_dict(ctx_state)

    logger.info("reading data source: %s", cfg.ctx_src)

    ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src])
    all_passages_dict = {}
    ctx_src.load_data_to(all_passages_dict)
    all_passages = [(k, v) for k, v in all_passages_dict.items()]

    shard_size = math.ceil(len(all_passages) / cfg.num_shards)
    start_idx = cfg.shard_id * shard_size
    end_idx = start_idx + shard_size

    logger.info(
        "Producing encodings for passages range: %d to %d (out of total %d)",
        start_idx,
        end_idx,
        len(all_passages),
    )
    shard_passages = all_passages[start_idx:end_idx]

    data = gen_ctx_vectors(cfg, shard_passages, encoder, tensorizer, True)

    file = cfg.out_file + "_" + str(cfg.shard_id)
    pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True)
    logger.info("Writing results to %s" % file)
    with open(file, mode="wb") as f:
        pickle.dump(data, f)

    logger.info("Total passages processed %d. Written to %s", len(data), file)
from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint
from dpr.indexer.faiss_indexers import DenseIndexer, DenseHNSWFlatIndexer, DenseFlatIndexer
from dense_retriever import DenseRetriever

saved_state = load_states_from_checkpoint(args.dpr_model_file)
set_encoder_params_from_state(saved_state.encoder_params, args)

tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
encoder = encoder.question_model

setup_args_gpu(args)
encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                        args.local_rank, args.fp16)
encoder.eval()

# load weights from the model file
model_to_load = get_model_obj(encoder)
prefix_len = len('question_model.')
question_encoder_state = {
    key[prefix_len:]: value
    for (key, value) in saved_state.model_dict.items()
    if key.startswith('question_model.')
}
model_to_load.load_state_dict(question_encoder_state)
vector_size = model_to_load.get_out_size()

index_buffer_sz = args.index_buffer
if args.hnsw_index:
    index = DenseHNSWFlatIndexer(vector_size)
    index_buffer_sz = -1  # encode all at once
else:
    index = DenseFlatIndexer(vector_size)

retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)
retriever.index.deserialize_from(args.dense_index_path)
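# Sketch (not in the original snippet): once the prebuilt index is deserialized,
# the retriever can serve queries directly with the same calls used by the
# retrieval mains above; the question text is a placeholder.
questions_tensor = retriever.generate_question_vectors(['example question'])
top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs)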
def validate_nll(self) -> float:
    logger.info("NLL validation ...")
    cfg = self.cfg
    self.biencoder.eval()

    if not self.dev_iterator:
        self.dev_iterator = self.get_data_iterator(cfg.train.dev_batch_size, False,
                                                   shuffle=False, rank=cfg.local_rank)
    data_iterator = self.dev_iterator

    total_loss = 0.0
    start_time = time.time()
    total_correct_predictions = 0
    num_hard_negatives = cfg.train.hard_negatives
    num_other_negatives = cfg.train.other_negatives
    log_result_step = cfg.train.log_batch_step
    batches = 0
    dataset = 0
    biencoder = get_model_obj(self.biencoder)

    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch
        logger.info("Eval step: %d, rank=%s", i, cfg.local_rank)

        biencoder_input = biencoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=False,
        )

        # get the token to be used for representation selection
        ds_cfg = self.ds_cfg.dev_datasets[dataset]
        rep_positions = ds_cfg.selector.get_positions(biencoder_input.question_ids, self.tensorizer)
        encoder_type = ds_cfg.encoder_type

        loss, correct_cnt = _do_biencoder_fwd_pass(
            self.biencoder,
            biencoder_input,
            self.tensorizer,
            cfg,
            encoder_type=encoder_type,
            rep_positions=rep_positions,
        )
        total_loss += loss.item()
        total_correct_predictions += correct_cnt
        batches += 1

        if (i + 1) % log_result_step == 0:
            logger.info("Eval step: %d, used_time=%f sec., loss=%f", i,
                        time.time() - start_time, loss.item())

    total_loss = total_loss / batches
    total_samples = batches * cfg.train.dev_batch_size * self.distributed_factor
    correct_ratio = float(total_correct_predictions / total_samples)
    logger.info(
        "NLL Validation: loss = %f. correct prediction ratio %d/%d ~ %f",
        total_loss,
        total_correct_predictions,
        total_samples,
        correct_ratio,
    )
    return total_loss
def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: MultiSetDataIterator,
):
    cfg = self.cfg
    rolling_train_loss = 0.0
    epoch_loss = 0
    epoch_correct_predictions = 0

    log_result_step = cfg.train.log_batch_step
    rolling_loss_step = cfg.train.train_rolling_loss_step
    num_hard_negatives = cfg.train.hard_negatives
    num_other_negatives = cfg.train.other_negatives
    seed = cfg.seed
    self.biencoder.train()
    epoch_batches = train_data_iterator.max_iterations
    data_iteration = 0

    biencoder = get_model_obj(self.biencoder)
    dataset = 0
    for i, samples_batch in enumerate(train_data_iterator.iterate_ds_data(epoch=epoch)):
        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch

        ds_cfg = self.ds_cfg.train_datasets[dataset]
        special_token = ds_cfg.special_token
        encoder_type = ds_cfg.encoder_type
        shuffle_positives = ds_cfg.shuffle_positives

        # to be able to resume shuffled ctx pools
        data_iteration = train_data_iterator.get_iteration()
        random.seed(seed + epoch + data_iteration)

        biencoder_batch = biencoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=True,
            shuffle_positives=shuffle_positives,
            query_token=special_token,
        )

        # get the token to be used for representation selection
        from dpr.utils.data_utils import DEFAULT_SELECTOR

        selector = ds_cfg.selector if ds_cfg else DEFAULT_SELECTOR
        rep_positions = selector.get_positions(biencoder_batch.question_ids, self.tensorizer)

        loss_scale = cfg.loss_scale_factors[dataset] if cfg.loss_scale_factors else None
        loss, correct_cnt = _do_biencoder_fwd_pass(
            self.biencoder,
            biencoder_batch,
            self.tensorizer,
            cfg,
            encoder_type=encoder_type,
            rep_positions=rep_positions,
            loss_scale=loss_scale,
        )

        epoch_correct_predictions += correct_cnt
        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        if cfg.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if cfg.train.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                               cfg.train.max_grad_norm)
        else:
            loss.backward()
            if cfg.train.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.biencoder.parameters(),
                                               cfg.train.max_grad_norm)

        if (i + 1) % cfg.train.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.biencoder.zero_grad()

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                loss.item(),
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info("Avg. loss per last %d batches: %f", rolling_loss_step,
                        latest_rolling_train_av_loss)
            rolling_train_loss = 0.0

        if data_iteration % eval_step == 0:
            logger.info(
                "rank=%d, Validation: Epoch: %d Step: %d/%d",
                cfg.local_rank,
                epoch,
                data_iteration,
                epoch_batches,
            )
            self.validate_and_save(epoch, train_data_iterator.get_iteration(), scheduler)
            self.biencoder.train()

    logger.info("Epoch finished on %d", cfg.local_rank)
    self.validate_and_save(epoch, data_iteration, scheduler)

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    logger.info("epoch total correct predictions=%d", epoch_correct_predictions)
def validate_average_rank(self) -> float:
    """
    Validates the biencoder model using each question's gold passage's rank across the set of
    passages from the dataset. It generates vectors for a specified number of negative passages
    from each question (see --val_av_rank_xxx params) and stores them in RAM along with the
    question vectors.
    The similarity scores are then calculated for the entire
    num_questions x (num_questions x num_passages_per_question) matrix and sorted per question.
    Each question's gold passage rank in that sorted list of scores is averaged across all questions.
    :return: averaged rank number
    """
    logger.info("Average rank validation ...")

    cfg = self.cfg
    self.biencoder.eval()
    distributed_factor = self.distributed_factor

    if not self.dev_iterator:
        self.dev_iterator = self.get_data_iterator(cfg.train.dev_batch_size, False,
                                                   shuffle=False, rank=cfg.local_rank)
    data_iterator = self.dev_iterator

    sub_batch_size = cfg.train.val_av_rank_bsz
    sim_score_f = BiEncoderNllLoss.get_similarity_function()
    q_representations = []
    ctx_representations = []
    positive_idx_per_question = []

    num_hard_negatives = cfg.train.val_av_rank_hard_neg
    num_other_negatives = cfg.train.val_av_rank_other_neg

    log_result_step = cfg.train.log_batch_step
    dataset = 0
    biencoder = get_model_obj(self.biencoder)

    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        # samples += 1
        if len(q_representations) > cfg.train.val_av_rank_max_qs / distributed_factor:
            break

        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch

        biencoder_input = biencoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=False,
        )
        total_ctxs = len(ctx_representations)
        ctxs_ids = biencoder_input.context_ids
        ctxs_segments = biencoder_input.ctx_segments
        bsz = ctxs_ids.size(0)

        # get the token to be used for representation selection
        ds_cfg = self.ds_cfg.dev_datasets[dataset]
        encoder_type = ds_cfg.encoder_type
        rep_positions = ds_cfg.selector.get_positions(biencoder_input.question_ids, self.tensorizer)

        # split the contexts batch into sub batches since it is supposed to be too large
        # to be processed in one batch
        for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):
            q_ids, q_segments = ((biencoder_input.question_ids, biencoder_input.question_segments)
                                 if j == 0 else (None, None))

            if j == 0 and cfg.n_gpu > 1 and q_ids.size(0) == 1:
                # if we are in DP (but not in DDP) mode, all model input tensors should have
                # batch size >1 or 0, otherwise the other input tensors will be split but
                # only the first split will be called
                continue

            ctx_ids_batch = ctxs_ids[batch_start:batch_start + sub_batch_size]
            ctx_seg_batch = ctxs_segments[batch_start:batch_start + sub_batch_size]

            q_attn_mask = self.tensorizer.get_attn_mask(q_ids)
            ctx_attn_mask = self.tensorizer.get_attn_mask(ctx_ids_batch)
            with torch.no_grad():
                q_dense, ctx_dense = self.biencoder(
                    q_ids,
                    q_segments,
                    q_attn_mask,
                    ctx_ids_batch,
                    ctx_seg_batch,
                    ctx_attn_mask,
                    encoder_type=encoder_type,
                    representation_token_pos=rep_positions,
                )
            if q_dense is not None:
                q_representations.extend(q_dense.cpu().split(1, dim=0))
            ctx_representations.extend(ctx_dense.cpu().split(1, dim=0))

        batch_positive_idxs = biencoder_input.is_positive
        positive_idx_per_question.extend([total_ctxs + v for v in batch_positive_idxs])

        if (i + 1) % log_result_step == 0:
            logger.info(
                "Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d",
                i,
                len(ctx_representations),
                len(q_representations),
            )

    ctx_representations = torch.cat(ctx_representations, dim=0)
    q_representations = torch.cat(q_representations, dim=0)

    logger.info("Av.rank validation: total q_vectors size=%s", q_representations.size())
    logger.info("Av.rank validation: total ctx_vectors size=%s", ctx_representations.size())

    q_num = q_representations.size(0)
    assert q_num == len(positive_idx_per_question)

    scores = sim_score_f(q_representations, ctx_representations)
    values, indices = torch.sort(scores, dim=1, descending=True)

    rank = 0
    for i, idx in enumerate(positive_idx_per_question):
        # aggregate the rank of the known gold passage in the sorted results for each question
        gold_idx = (indices[i] == idx).nonzero()
        rank += gold_idx.item()

    if distributed_factor > 1:
        # each node calculated its own rank; exchange the information between nodes
        # and calculate the "global" average rank
        # NOTE: the set of passages is still unique for every node
        eval_stats = all_gather_list([rank, q_num], max_size=100)
        for i, item in enumerate(eval_stats):
            remote_rank, remote_q_num = item
            if i != cfg.local_rank:
                rank += remote_rank
                q_num += remote_q_num

    av_rank = float(rank / q_num)
    logger.info("Av.rank validation: average rank %s, total questions=%d", av_rank, q_num)
    return av_rank
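# Tiny worked example (sketch) of the rank aggregation in validate_average_rank():
# two questions scored against three passages; the gold passages sit at column
# indices 0 and 2.
import torch

scores = torch.tensor([[0.9, 0.1, 0.3],
                       [0.2, 0.5, 0.8]])
_, indices = torch.sort(scores, dim=1, descending=True)
positive_idx_per_question = [0, 2]
rank = sum((indices[i] == idx).nonzero().item()
           for i, idx in enumerate(positive_idx_per_question))
# both golds sort first, so rank == 0 and the average rank is 0.0
av_rank = rank / len(positive_idx_per_question)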