def get_args():
    """Parse the index options and derive the index name and its file paths.

    The index name is assembled from the number of clusters, the coarse
    quantizer kind ('hnsw' or 'flat'), the fine quantizer, an optional
    '_first' marker and, when ``index_filter`` differs from its default of
    -1e8, a '_ft<int>' suffix.  The path-like options are then rewritten in
    place so that they live under ``<dump_dir>/start/<index_name>``.

    Returns:
        The parsed options object with ``index_name``, ``index_dir``,
        ``subindex_dir`` and the rewritten path attributes set.
    """
    parser = Options()
    parser.add_index_options()
    args = parser.parse()

    # Build the index name piece by piece.
    coarse_kind = 'flat'
    if args.hnsw:
        coarse_kind = 'hnsw'
    first_suffix = '_first' if args.first_passage else ''
    index_name = f'{args.num_clusters}_{coarse_kind}_{args.fine_quant}{first_suffix}'
    if args.index_filter != -1e8:
        # A non-default filter value becomes part of the name.
        index_name += f'_ft{int(args.index_filter)}'
    args.index_name = index_name

    index_dir = os.path.join(args.dump_dir, 'start', index_name)
    args.index_dir = index_dir

    # These options hold file names on entry and full paths on exit.
    for attr in ('quantizer_path', 'trained_index_path', 'inv_path'):
        setattr(args, attr, os.path.join(index_dir, getattr(args, attr)))

    args.subindex_dir = os.path.join(index_dir, args.subindex_name)

    if args.dump_paths is not None:
        # Explicit dump files were given: expand the comma-separated list
        # and write into the sub-index directory, with files named after
        # the offset.
        args.dump_paths = [
            os.path.join(args.dump_dir, args.phrase_dir, dump_name)
            for dump_name in args.dump_paths.split(',')
        ]
        args.index_path = os.path.join(args.subindex_dir, '%d.faiss' % args.offset)
        args.idx2id_path = os.path.join(args.subindex_dir, '%d.hdf5' % args.offset)
    else:
        # No explicit dump files: the index files go directly under index_dir.
        args.index_path = os.path.join(index_dir, args.index_path)
        args.idx2id_path = os.path.join(index_dir, args.idx2id_path)

    logger.info(f"Creating {args.index_name}...")
    return args
def __init__(self,
             load_dir,
             dump_dir,
             index_name='start/1048576_flat_OPQ96',
             device='cuda',
             verbose=False,
             **kwargs):
    """Build the default options, then load the encoder, phrase index and truecaser.

    Args:
        load_dir: Stored as ``args.load_dir`` and passed to ``set_encoder``.
        dump_dir: Stored as ``args.dump_dir``.
        index_name: Stored as ``args.index_name``.
        device: Passed to ``set_encoder``; ``args.cuda`` is true only when
            this equals ``'cuda'``.
        verbose: When false, the "densephrases" and "transformers" loggers
            are raised to WARNING and index loading is asked to skip logging.
        **kwargs: Extra option values merged into the options' attribute
            dict after the values above, so they take precedence.

    Raises:
        KeyError: If the ``CACHE_DIR`` or ``DATA_DIR`` environment variable
            is not set.
    """
    print(
        "This could take up to 15 mins depending on the file reading speed of HDD/SSD"
    )

    # Quiet the noisy loggers unless the caller asked for verbose output.
    if not verbose:
        for logger_name in ("densephrases", "transformers"):
            logging.getLogger(logger_name).setLevel(logging.WARNING)

    # Start from the default option values.
    parser = Options()
    parser.add_model_options()
    parser.add_index_options()
    parser.add_retrieval_options()
    parser.add_data_options()
    self.args = parser.parse()

    # Apply the constructor arguments on top of the defaults; explicit
    # keyword overrides are applied last.
    cfg = self.args
    cfg.load_dir = load_dir
    cfg.dump_dir = dump_dir
    cfg.cache_dir = os.environ['CACHE_DIR']
    cfg.index_name = index_name
    cfg.cuda = device == 'cuda'
    vars(cfg).update(kwargs)

    # Query encoder.
    self.set_encoder(load_dir, device)

    # Phrase index (MIPS).
    self.mips = load_phrase_index(self.args, ignore_logging=not verbose)

    # Truecaser, whose path is resolved relative to DATA_DIR.
    truecase_file = os.path.join(os.environ['DATA_DIR'], self.args.truecase_path)
    self.truecase = TrueCaser(truecase_file)

    print("Loading DensePhrases Completed!")
os.path.splitext(os.path.basename(pred_path))[0] + f'_{"sent" if args.return_sent else "psg"}-top{args.psg_top_k}{"_mark" if args.mark_phrase else ""}.json' ) logger.info(f"dump to {out_file}") json.dump(my_target, open(out_file, 'w'), indent=4) # Call subprocess for evaluation command = f'python scripts/postprocess/recall.py --k_values 1,5,20,100 --results_file {out_file} --ans_fn string' subprocess.run(command.split(' ')) if __name__ == '__main__': # See options in densephrases.options options = Options() options.add_model_options() options.add_index_options() options.add_retrieval_options() options.add_data_options() options.add_question_type_options() args = options.parse() # Seed for reproducibility random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) # Set wandb if args.wandb: wandb.init(project="DensePhrases Evaluation")