import os

import torch
from torch.utils.data import DataLoader, SequentialSampler

# NOTE: project-local modules and helpers (utils, data, nnquery,
# BiEncoderRanker, load_or_generate_candidate_pool, encode_candidate)
# are assumed to be defined or imported elsewhere in this repository.


def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(output_path)

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model
    device = reranker.device

    cand_encode_path = params.get("cand_encode_path", None)

    # Load or generate the candidate pool; it is needed both to compute the
    # candidate encoding and for top-k retrieval below.
    cand_pool_path = params.get("cand_pool_path", None)
    candidate_pool = load_or_generate_candidate_pool(
        tokenizer,
        params,
        logger,
        cand_pool_path,
    )

    candidate_encoding = None
    if cand_encode_path is not None:
        # Try to load a pre-computed candidate encoding; on success, skip
        # re-encoding the candidate pool.
        try:
            logger.info("Loading pre-generated candidate encoding.")
            candidate_encoding = torch.load(cand_encode_path)
        except Exception:
            logger.info("Loading failed. Generating candidate encoding.")

    if candidate_encoding is None:
        candidate_encoding = encode_candidate(
            reranker,
            candidate_pool,
            params["encode_batch_size"],
            silent=params["silent"],
            logger=logger,
            is_zeshel=params.get("zeshel", None),
        )

        if cand_encode_path is not None:
            # Save the candidate encoding to avoid re-computing it next time.
            logger.info("Saving candidate encoding to file " + cand_encode_path)
            torch.save(candidate_encoding, cand_encode_path)

    test_samples = utils.read_dataset(params["mode"], params["data_path"])
    logger.info("Read %d test samples." % len(test_samples))

    test_data, test_tensor_data = data.process_mention_data(
        test_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params["context_key"],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    test_sampler = SequentialSampler(test_tensor_data)
    test_dataloader = DataLoader(
        test_tensor_data,
        sampler=test_sampler,
        batch_size=params["encode_batch_size"],
    )

    # Retrieve the top-k candidates for each test mention.
    save_results = params.get("save_topk_result")
    new_data = nnquery.get_topk_predictions(
        reranker,
        test_dataloader,
        candidate_pool,
        candidate_encoding,
        params["silent"],
        logger,
        params["top_k"],
        params.get("zeshel", None),
        save_results,
    )

    if save_results:
        save_data_path = os.path.join(
            params["output_path"],
            "candidates_%s_top%d.t7" % (params["mode"], params["top_k"]),
        )
        torch.save(new_data, save_data_path)
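# Illustrative invocation sketch (not part of the original source). The keys
# below are exactly the ones main() reads; all *values* are placeholder
# assumptions and must be adapted to your data and model setup.
if __name__ == "__main__":
    example_params = {
        "output_path": "output/eval",   # where logs and results are written
        "data_path": "data/zeshel",     # dataset root read by utils.read_dataset
        "mode": "test",                 # dataset split name
        "cand_pool_path": None,         # optional pre-tokenized candidate pool
        "cand_encode_path": None,       # optional cached candidate encodings
        "encode_batch_size": 8,
        "max_context_length": 128,
        "max_cand_length": 128,
        "context_key": "context",
        "top_k": 64,
        "zeshel": True,                 # passed through to encoding/retrieval
        "silent": False,
        "debug": False,
        "save_topk_result": True,
    }
    # NOTE: BiEncoderRanker typically requires additional model-specific keys
    # (e.g. a checkpoint path) that are configured elsewhere in this repo.
    main(example_params)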