def main(cfg: DictConfig):
    """Train a bi-encoder retriever, or validate an existing checkpoint.

    With train datasets configured, runs full training. With only a model
    file plus dev datasets, runs the two validation modes (NLL and average
    rank). Otherwise logs a warning and exits.
    """
    grad_accum_steps = cfg.train.gradient_accumulation_steps
    if grad_accum_steps < 1:
        raise ValueError(
            f"Invalid gradient_accumulation_steps parameter: {grad_accum_steps}, should be >= 1"
        )

    if cfg.output_dir is not None:
        os.makedirs(cfg.output_dir, exist_ok=True)

    cfg = setup_cfg_gpu(cfg)
    set_seed(cfg)

    # Only the main process (or a single-GPU run) prints the resolved config.
    if cfg.local_rank in (-1, 0):
        logger.info("CFG (after gpu configuration):")
        logger.info("%s", OmegaConf.to_yaml(cfg))

    biencoder_trainer = BiEncoderTrainer(cfg)

    if cfg.train_datasets:
        biencoder_trainer.run_train()
    elif cfg.model_file and cfg.dev_datasets:
        logger.info(
            "No train files are specified. Run 2 types of validation for specified model file"
        )
        biencoder_trainer.validate_nll()
        biencoder_trainer.validate_average_rank()
    else:
        logger.warning(
            "Neither train_file or (model_file & dev_file) parameters are specified. Nothing to do."
        )
def main(cfg: DictConfig):
    """Train an extractive reader, or run validation only when no train files are given.

    Also snapshots the effective (post-GPU-setup) config to ``config.yaml``
    in the working directory for reproducibility.
    """
    if cfg.output_dir is not None:
        os.makedirs(cfg.output_dir, exist_ok=True)

    cfg = setup_cfg_gpu(cfg)
    set_seed(cfg)

    get_gpu_info(rank=cfg.local_rank)  # for now only work with single-GPU and DDP mode

    if cfg.local_rank in [-1, 0]:
        logger.info("CFG (after gpu configuration):")
        logger.info("%s", OmegaConf.to_yaml(cfg))

    # Save the effective config for reproducibility.
    # FIX: previously this was yaml.dump(eval(str(cfg)), fout) — calling eval()
    # on the config's repr is unsafe and fragile; OmegaConf can serialize the
    # config to YAML directly.
    with open("config.yaml", "w") as fout:
        fout.write(OmegaConf.to_yaml(cfg))

    trainer = ReaderTrainer(cfg)

    if cfg.train_files is not None:
        trainer.run_train()
    elif cfg.dev_files:
        logger.info("No train files are specified. Run validation.")
        trainer.validate()
    else:
        logger.warning(
            "Neither train_file or (model_file & dev_file) parameters are specified. Nothing to do."
        )
def main(cfg: DictConfig):
    """Train a bi-encoder retriever, optionally under DeepSpeed, or validate a checkpoint."""
    grad_accum_steps = cfg.train.gradient_accumulation_steps
    if grad_accum_steps < 1:
        raise ValueError(
            f"Invalid gradient_accumulation_steps parameter: {grad_accum_steps}, should be >= 1"
        )

    if cfg.output_dir is not None:
        os.makedirs(cfg.output_dir, exist_ok=True)

    if cfg.deepspeed:
        # dist_init in — supply single-process defaults for the distributed
        # environment when the launcher did not export them.
        for env_name, env_default in (
            ("RANK", "0"),
            ("LOCAL_RANK", "0"),
            ("WORLD_SIZE", "1"),
            ("MASTER_PORT", "3600"),
            ("MASTER_ADDR", "127.0.0.1"),
        ):
            os.environ.setdefault(env_name, env_default)

    cfg = setup_cfg_gpu(cfg)
    set_seed(cfg)

    # Only the main process (or a single-GPU run) prints the resolved config.
    if cfg.local_rank in (-1, 0):
        logger.info("CFG (after gpu configuration):")
        logger.info("%s", OmegaConf.to_yaml(cfg))

    biencoder_trainer = BiEncoderTrainer(cfg)

    if cfg.train_datasets:
        biencoder_trainer.run_train()
    elif cfg.model_file and cfg.dev_datasets:
        logger.info(
            "No train files are specified. Run 2 types of validation for specified model file"
        )
        biencoder_trainer.validate_nll()
        biencoder_trainer.validate_average_rank()
    else:
        logger.warning(
            "Neither train_file or (model_file & dev_file) parameters are specified. Nothing to do."
        )

    if cfg.deepspeed:
        dist_cleanup()
def main(cfg: DictConfig):
    """Train the one-for-all (retriever + reader) model, or run validation only.

    With train datasets configured, runs full training. With a model file and
    dev datasets, runs the retriever and/or reader validations selected by
    ``evaluate_retriever`` / ``evaluate_reader``. Also snapshots the effective
    config to ``config.yaml`` for reproducibility.
    """
    if cfg.train.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                cfg.train.gradient_accumulation_steps
            )
        )

    if cfg.output_dir is not None:
        os.makedirs(cfg.output_dir, exist_ok=True)

    cfg = setup_cfg_gpu(cfg)
    set_seed(cfg)

    if cfg.local_rank in [-1, 0]:
        logger.info("CFG (after gpu configuration):")
        logger.info("%s", OmegaConf.to_yaml(cfg))

    # Save the effective config for reproducibility.
    # FIX: previously this was yaml.dump(eval(str(cfg)), fout) — calling eval()
    # on the config's repr is unsafe and fragile; OmegaConf can serialize the
    # config to YAML directly.
    with open("config.yaml", "w") as fout:
        fout.write(OmegaConf.to_yaml(cfg))

    trainer = OneForAllTrainer(cfg)

    if cfg.train_datasets and len(cfg.train_datasets) > 0:
        trainer.run_train()
    elif cfg.model_file and cfg.dev_datasets:
        logger.info("No train files are specified.")
        if cfg.evaluate_retriever:
            logger.info("Run 2 types of retriever validation for specified model file")
            trainer.validate_biencoder_nll()
            trainer.validate_biencoder_average_rank()
        if cfg.evaluate_reader:
            logger.info("Run reader validation for specified model file")
            trainer.validate_reader()
    else:
        logger.warning(
            "Neither train_file or (model_file & dev_file) parameters are specified. Nothing to do."
        )
def main(cfg: DictConfig):
    """Encode the passages referenced by ``pair_file`` with a saved bi-encoder.

    Loads the checkpoint named by ``model_file``, restores its encoder
    parameters into the config, builds the bi-encoder in inference mode,
    and hands off to ``dump_passages`` to produce the embeddings.
    """
    assert cfg.pair_file, "Please specify passages source as pair_file param"
    assert cfg.model_file, "Please specify encoder checkpoint as model_file param"

    cfg = setup_cfg_gpu(cfg)

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    # FIX: cache_dir was hard-coded to a user-specific cluster path
    # ('/n/fs/nlp-jl5167/cache'), which breaks on any other machine. Take it
    # from the config when present; None falls back to the library default.
    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type,
        cfg,
        inference_only=True,
        cache_dir=cfg.get("cache_dir", None),
    )

    # load weights from the model file
    logger.info("Loading saved model state ...")
    logger.debug("saved model keys =%s", saved_state.model_dict.keys())
    encoder.load_state_dict(saved_state.model_dict)
    encoder.to(cfg.device)

    cfg.encoder.sequence_length = 512  # TODO: Passage can be cut by max_seq_length
    logger.info(f"Max seq length: {cfg.encoder.sequence_length}")

    # Setup logging
    # NOTE(review): basicConfig is applied only here, after several logger
    # calls above — confirm logging is configured elsewhere at import time,
    # or move this to the top of the function.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Dump passages
    dump_passages(cfg, encoder, tensorizer)
def main(cfg: DictConfig):
    """Train an extractive reader, or run validation when only dev files are configured."""
    output_dir = cfg.output_dir
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    cfg = setup_cfg_gpu(cfg)
    set_seed(cfg)

    # Only the main process (or a single-GPU run) prints the resolved config.
    if cfg.local_rank in (-1, 0):
        logger.info("CFG (after gpu configuration):")
        logger.info("%s", OmegaConf.to_yaml(cfg))

    reader_trainer = ReaderTrainer(cfg)

    if cfg.train_files is not None:
        reader_trainer.run_train()
        return
    if cfg.dev_files:
        logger.info("No train files are specified. Run validation.")
        reader_trainer.validate()
        return
    logger.warning(
        "Neither train_file or (model_file & dev_file) parameters are specified. Nothing to do."
    )
def main(cfg: DictConfig):
    """End-to-end dense retrieval evaluation.

    Loads a question encoder from ``model_file``, embeds the queries of
    ``qa_dataset``, builds or deserializes a FAISS index over encoded passage
    files, retrieves the top ``n_docs`` passages per question, validates the
    hits against gold answers, and optionally saves results (plain and/or
    KILT-formatted).
    """
    cfg = setup_cfg_gpu(cfg)
    logger.info("CFG (after gpu configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    # Restore encoder hyper-parameters from the checkpoint into the config
    # before constructing the model components.
    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True
    )

    # Pick which sub-encoder to use: a named attribute when encoder_path is
    # set, otherwise the standard question encoder.
    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(
        encoder, None, cfg.device, cfg.n_gpu, cfg.local_rank, cfg.fp16
    )
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")
    # The checkpoint stores the whole bi-encoder; strip the sub-encoder
    # prefix so the keys match the selected encoder's state dict.
    encoder_prefix = (encoder_path if encoder_path else "question_model") + "."
    prefix_len = len(encoder_prefix)
    logger.info("Encoder state prefix %s", encoder_prefix)
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(encoder_prefix)
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    question_answers = []
    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return
    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)
    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()
    for ds_item in qa_src.data:
        question, answers = ds_item.query, ds_item.answers
        questions.append(question)
        question_answers.append(answers)

    # Build the local FAISS index and wrap it in a retriever.
    index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
    logger.info("Index class %s ", type(index))
    index_buffer_sz = index.buffer_size
    index.init_index(vector_size)
    retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, query_token=qa_src.special_query_token
    )
    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    # Collect the passage sources and the id prefix each one contributes.
    id_prefixes = []
    ctx_sources = []
    for ctx_src in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)
    logger.info("id_prefixes per dataset: %s", id_prefixes)

    # index all passages
    ctx_files_patterns = cfg.encoded_ctx_files
    index_path = cfg.index_path
    logger.info("ctx_files_patterns: %s", ctx_files_patterns)
    if ctx_files_patterns:
        # One glob pattern of encoded files per passage source is required.
        assert len(ctx_files_patterns) == len(
            id_prefixes
        ), "ctx len={} pref leb={}".format(len(ctx_files_patterns), len(id_prefixes))
    else:
        assert (
            index_path
        ), "Either encoded_ctx_files or index_path parameter should be set."

    # Expand the glob patterns, keeping each file associated with the id
    # prefix of the source it came from.
    input_paths = []
    path_id_prefixes = []
    for i, pattern in enumerate(ctx_files_patterns):
        pattern_files = glob.glob(pattern)
        pattern_id_prefix = id_prefixes[i]
        input_paths.extend(pattern_files)
        path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))
    logger.info("Embeddings files id prefixes: %s", path_id_prefixes)

    # Reuse a serialized index when available; otherwise index the encoded
    # passage files (and serialize the result when index_path is set).
    if index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(
            input_paths, index_buffer_sz, path_id_prefixes=path_id_prefixes
        )
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), cfg.n_docs)

    # we no longer need the index
    retriever = None

    # Load the raw passage texts needed for validation / result saving.
    all_passages = {}
    for ctx_src in ctx_sources:
        ctx_src.load_data_to(all_passages)

    if len(all_passages) == 0:
        raise RuntimeError("No passages data found. Please specify ctx_file param properly.")

    # Score retrieved passages against gold answers (table-specific matching
    # when validate_as_tables is set).
    if cfg.validate_as_tables:
        questions_doc_hits = validate_tables(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )
    else:
        questions_doc_hits = validate(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )

    if cfg.out_file:
        save_results(
            all_passages,
            questions,
            question_answers,
            top_ids_and_scores,
            questions_doc_hits,
            cfg.out_file,
        )

    # Optional KILT-format conversion; requires a KiltCsvCtxSrc among the
    # configured passage sources.
    if cfg.kilt_out_file:
        kilt_ctx = next(
            iter([ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)]), None
        )
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        assert hasattr(cfg, "kilt_out_file")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file, cfg.kilt_out_file)
def main(cfg: DictConfig):
    """Encode one shard of a passage collection into dense vectors and pickle them.

    Loads the checkpoint from ``model_file``, selects the context or question
    encoder per ``encoder_type``, splits the passages of ``ctx_src`` into
    ``num_shards`` shards, encodes shard ``shard_id``, and writes the result
    to ``<out_file>_<shard_id>`` as a pickle.
    """
    assert cfg.model_file, "Please specify encoder checkpoint as model_file param"
    assert cfg.ctx_src, "Please specify passages source as ctx_src param"

    print(os.getcwd())
    cfg = setup_cfg_gpu(cfg)

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True
    )

    encoder = encoder.ctx_model if cfg.encoder_type == "ctx" else encoder.question_model

    encoder, _ = setup_for_distributed_mode(
        encoder,
        None,
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
    )
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")
    logger.debug("saved model keys =%s", saved_state.model_dict.keys())

    # NOTE(review): state keys are always taken from the "ctx_model." prefix,
    # even when encoder_type selects the question model — confirm this is
    # intended for the non-"ctx" path.
    prefix_len = len("ctx_model.")
    ctx_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("ctx_model.")
    }
    model_to_load.load_state_dict(ctx_state)

    logger.info("reading data source: %s", cfg.ctx_src)
    ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src])
    all_passages_dict = {}
    ctx_src.load_data_to(all_passages_dict)
    all_passages = [(k, v) for k, v in all_passages_dict.items()]

    # Select this process's shard of the passage list.
    shard_size = math.ceil(len(all_passages) / cfg.num_shards)
    start_idx = cfg.shard_id * shard_size
    end_idx = start_idx + shard_size
    logger.info(
        "Producing encodings for passages range: %d to %d (out of total %d)",
        start_idx,
        end_idx,
        len(all_passages),
    )
    shard_passages = all_passages[start_idx:end_idx]

    data = gen_ctx_vectors(cfg, shard_passages, encoder, tensorizer, True)

    file = cfg.out_file + "_" + str(cfg.shard_id)
    pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True)
    # FIX: pass the argument lazily instead of eager %-formatting inside
    # logger.info (also breaks if the path itself contains a '%').
    logger.info("Writing results to %s", file)
    with open(file, mode="wb") as f:
        pickle.dump(data, f)

    logger.info("Total passages processed %d. Written to %s", len(data), file)
def main(cfg: DictConfig):
    """Dense retrieval evaluation with optional remote (RPC) index backend.

    Loads a question encoder from ``model_file``, embeds the ``qa_dataset``
    queries, then either uses a remote RPC index, deserializes a local FAISS
    index, or indexes encoded passage files. Retrieves the top ``n_docs``
    passages, validates hits against gold answers (via RPC metadata or local
    passage texts), and optionally saves plain and KILT-formatted results.
    """
    cfg = setup_cfg_gpu(cfg)

    # Restore encoder hyper-parameters from the checkpoint into the config.
    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG (after gpu configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True
    )

    logger.info("Loading saved model state ...")
    encoder.load_state(saved_state, strict=False)

    # Pick which sub-encoder to use: a named attribute when encoder_path is
    # set, otherwise the standard question encoder.
    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(
        encoder, None, cfg.device, cfg.n_gpu, cfg.local_rank, cfg.fp16
    )
    encoder.eval()

    model_to_load = get_model_obj(encoder)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    # NOTE(review): questions_text is never populated here, so save_results
    # always falls back to `questions` — confirm the text-variant path is
    # intentionally unused in this script.
    questions_text = []
    question_answers = []
    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return
    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)
    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    total_queries = len(qa_src)
    for i in range(total_queries):
        qa_sample = qa_src[i]
        question, answers = qa_sample.query, qa_sample.answers
        questions.append(question)
        question_answers.append(answers)

    logger.info("questions len %d", len(questions))
    logger.info("questions_text len %d", len(questions_text))

    # Choose the retrieval backend: remote RPC index vs local FAISS index.
    if cfg.rpc_retriever_cfg_file:
        index_buffer_sz = 1000
        retriever = DenseRPCRetriever(
            encoder,
            cfg.batch_size,
            tensorizer,
            cfg.rpc_retriever_cfg_file,
            vector_size,
            use_l2_conversion=cfg.use_l2_conversion,
        )
    else:
        index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
        logger.info("Local Index class %s ", type(index))
        index_buffer_sz = index.buffer_size
        index.init_index(vector_size)
        retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, query_token=qa_src.special_query_token
    )
    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    index_path = cfg.index_path
    # FIX: ctx_sources was only bound inside the indexing branch below, so the
    # paths that load a prebuilt index raised NameError later when the result
    # handling needed passage sources. Initialize it unconditionally.
    ctx_sources = []
    if cfg.rpc_retriever_cfg_file and cfg.rpc_index_id:
        retriever.load_index(cfg.rpc_index_id)
    elif index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        # send data for indexing
        id_prefixes = []
        for ctx_src in cfg.ctx_datatsets:
            ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
            id_prefixes.append(ctx_src.id_prefix)
            ctx_sources.append(ctx_src)
            logger.info("ctx_sources: %s", type(ctx_src))
        logger.info("id_prefixes per dataset: %s", id_prefixes)

        # index all passages
        ctx_files_patterns = cfg.encoded_ctx_files
        logger.info("ctx_files_patterns: %s", ctx_files_patterns)
        if ctx_files_patterns:
            # FIX: typo in the assert message ("pref leb" -> "pref len").
            assert len(ctx_files_patterns) == len(
                id_prefixes
            ), "ctx len={} pref len={}".format(len(ctx_files_patterns), len(id_prefixes))
        else:
            # FIX: typo in the assert message ("pr" -> "or").
            assert (
                index_path or cfg.rpc_index_id
            ), "Either encoded_ctx_files or index_path or rpc_index_id parameter should be set."

        # Expand glob patterns, keeping each file associated with the id
        # prefix of the source it came from.
        input_paths = []
        path_id_prefixes = []
        for i, pattern in enumerate(ctx_files_patterns):
            pattern_files = glob.glob(pattern)
            pattern_id_prefix = id_prefixes[i]
            input_paths.extend(pattern_files)
            path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))
        logger.info("Embeddings files id prefixes: %s", path_id_prefixes)
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(
            input_paths, index_buffer_sz, path_id_prefixes=path_id_prefixes
        )
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    top_results_and_scores = retriever.get_top_docs(questions_tensor.numpy(), cfg.n_docs)

    if cfg.use_rpc_meta:
        # Validate against metadata returned by the RPC backend — no local
        # passage texts needed.
        questions_doc_hits = validate_from_meta(
            question_answers,
            top_results_and_scores,
            cfg.validation_workers,
            cfg.match,
            cfg.rpc_meta_compressed,
        )
        if cfg.out_file:
            save_results_from_meta(
                questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
                cfg.rpc_meta_compressed,
            )
    else:
        all_passages = get_all_passages(ctx_sources)
        if cfg.validate_as_tables:
            questions_doc_hits = validate_tables(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )
        else:
            questions_doc_hits = validate(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )
        if cfg.out_file:
            save_results(
                all_passages,
                questions_text if questions_text else questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
            )

    # Optional KILT-format conversion; requires a KiltCsvCtxSrc among the
    # configured passage sources.
    if cfg.kilt_out_file:
        kilt_ctx = next(
            iter([ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)]), None
        )
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        assert hasattr(cfg, "kilt_out_file")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file, cfg.kilt_out_file)