def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
    return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

    retrievers = dict()
    for key, kg_args in config_dict['retrievers'].items():
        file_path = kg_args['file_path']
        retrievers[key] = initialize_kg_retriever(key, file_path)
        max_length = kg_args['max_concept_length']
        retrievers[key].update_max_concept_length(max_length)

    config = cls(**config_dict)
    if retrievers:
        config.add_kgretrievers(retrievers)
        config.set_sizes()

    if hasattr(config, "pruned_heads"):
        config.pruned_heads = dict(
            (int(key), value) for key, value in config.pruned_heads.items())

    # Update config with kwargs if needed
    to_remove = []
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
            to_remove.append(key)
    for key in to_remove:
        kwargs.pop(key, None)

    logger.info("Model config %s", str(config))
    if return_unused_kwargs:
        return config, kwargs
    else:
        return config
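
# A minimal sketch (an assumption, not taken from the original repo) of the dict shape
# from_dict expects: each entry under "retrievers" supplies the file_path handed to
# initialize_kg_retriever and the max_concept_length forwarded to update_max_concept_length.
# All names and values below are hypothetical placeholders.
#
#   example_config_dict = {
#       "retrievers": {
#           "wordnet": {"file_path": "kgs/wordnet", "max_concept_length": 49},
#           "nell": {"file_path": "kgs/nell", "max_concept_length": 27},
#       },
#       # ... plus whatever fields cls(**config_dict) normally consumes
#   }
#   config = SomeKelmConfig.from_dict(example_config_dict)  # SomeKelmConfig is a stand-in name
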
def evaluate(args, model, processor, tokenizer, global_step, input_dir, prefix=""):
    retrievers = dict()
    for kg in args.use_kgs:
        logger.info("Initialize kg:{}".format(kg))
        kg_path = os.path.join(input_dir, args.kg_paths[kg])
        data_path = os.path.join(args.data_dir, args.kg_paths[kg])
        if not os.path.exists(kg_path):
            logger.warning("need prepare training dataset firstly, program exit")
            exit()
        retrievers[kg] = initialize_kg_retriever(kg, kg_path, data_path, args.cache_file_suffix)

    dataset, examples_tokenized, features, wn_synset_graphs, wn_synset_graphs_label_dict = \
        load_and_cache_examples(args, processor, retrievers, relation_list=args.relation_list,
                                input_dir=input_dir, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.mkdir(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    if args.local_rank != -1 and not isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Dataset size = %d", len(dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)

    if args.local_rank == -1:
        logger.warning("program exits and please use pytorch DDP framework")
        exit()
    else:
        # all_results = []
        all_start_logits = torch.tensor([], dtype=torch.float32, device=args.device)
        all_end_logits = torch.tensor([], dtype=torch.float32, device=args.device)
        all_unique_ids = []
        # start_time = timeit.default_timer()
        epoch_iterator = tqdm(eval_dataloader, desc="Evaluating Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)
            batch_synset_graphs = batch[3]
            with torch.no_grad():
                inputs = create_input(args, batch, global_step,
                                      batch_synset_graphs=batch_synset_graphs,
                                      wn_synset_graphs=wn_synset_graphs, evaluate=True)
                feature_indices = batch[3]
                outputs = model(**inputs)

            all_start_logits = torch.cat((all_start_logits, outputs[0]), dim=0)
            all_end_logits = torch.cat((all_end_logits, outputs[1]), dim=0)
            for i, feature_index in enumerate(feature_indices):
                eval_feature = features[feature_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_unique_ids.append(unique_id)

        all_unique_ids = torch.tensor(all_unique_ids, dtype=torch.long, device=args.device)

        start_time = timeit.default_timer()
        all_start_logits_list = [
            torch.zeros_like(all_start_logits, device=args.device)
            for _ in range(torch.distributed.get_world_size())
        ]
        all_end_logits_list = [
            torch.zeros_like(all_end_logits, device=args.device)
            for _ in range(torch.distributed.get_world_size())
        ]
        all_unique_ids_list = [
            torch.zeros_like(all_unique_ids, device=args.device)
            for _ in range(torch.distributed.get_world_size())
        ]
        all_gather(all_start_logits_list, all_start_logits)
        all_gather(all_end_logits_list, all_end_logits)
        all_gather(all_unique_ids_list, all_unique_ids)

        if args.local_rank == 0:
            start_time = timeit.default_timer()
            all_results = []
            for batch_idx, batch_unique_ids in enumerate(all_unique_ids_list):
                batch_start_logits = all_start_logits_list[batch_idx]
                batch_end_logits = all_end_logits_list[batch_idx]
                for i, unique_id in enumerate(batch_unique_ids):
                    start_logits, end_logits = to_list(batch_start_logits[i]), to_list(batch_end_logits[i])
                    result = RecordResult(int(unique_id.cpu().numpy()), start_logits, end_logits)
                    all_results.append(result)

            evalTime = timeit.default_timer() - start_time
            logger.info(" Evaluation done in total %f secs (%f sec per example)",
                        evalTime, evalTime / len(dataset))

            # Compute predictions
            output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
            output_result = os.path.join(args.output_dir, "results_{}.jsonl".format(prefix))

            predictions = RecordProcessor.compute_predictions_logits(
                examples_tokenized,
                features,
                all_results,
                args.n_best_size,
                args.max_answer_length,
                output_prediction_file,
                output_result,
                args.verbose_logging,
                os.path.join(args.data_dir, args.predict_file),
                tokenizer,
                is_testing=args.test,
            )

            # Compute the F1 and exact scores.
            if not args.test:
                results = RecordProcessor.record_evaluate(examples_tokenized, predictions,
                                                          relate_path=args.output_dir)
                return results
            else:
                return None
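
# Hedged sketch of the collective pattern evaluate() relies on: every rank evaluates its
# DistributedSampler shard, all_gather fills a list with one tensor per rank, and only
# rank 0 flattens the gathered tensors into results. The tensor name local_logits is
# illustrative; for the stock torch.distributed.all_gather the tensor shapes must match
# across ranks.
#
#   world_size = torch.distributed.get_world_size()
#   gathered = [torch.zeros_like(local_logits) for _ in range(world_size)]
#   all_gather(gathered, local_logits)          # same helper used above
#   if args.local_rank == 0:
#       merged = torch.cat(gathered, dim=0)     # one tensor covering the whole eval set
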
def main():
    parser = argparse.ArgumentParser()

    model_g = ArgumentGroup(parser, "model", "model configuration and path.")
    model_g.add_arg("dataset", str, "record", "used dataset")
    model_g.add_arg("is_update_max_concept", bool, True, "whether to update max concept for the kg retriever")
    model_g.add_arg("full_table", bool, True, "full_table")
    model_g.add_arg("test", bool, False, "whether to load the SuperGLUE test set")
    model_g.add_arg("use_wn", bool, True, "wn")
    model_g.add_arg("use_nell", bool, True, "nell")
    model_g.add_arg("sentinel_trainable", bool, False, "sentinel_trainable")
    model_g.add_arg("memory_bank_update", bool, False, "memory_bank_update")
    model_g.add_arg("memory_bank_update_steps", int, 500, "memory_bank_update_steps")
    model_g.add_arg("memory_bank_keep_coef", float, 0.0, "what percent to keep")
    model_g.add_arg("use_context_graph", bool, True, "use_context_graph")
    model_g.add_arg("schedule_strategy", str, "linear", "schedule_strategy")
    model_g.add_arg("tokenizer_path", str, "", "tokenizer_path")
    model_g.add_arg("save_model", bool, True, "whether to save the model")
    model_g.add_arg("data_preprocess", bool, False, "data process")
    model_g.add_arg("data_preprocess_evaluate", bool, False, "data_preprocess_evaluate")

    # multi-relational part
    model_g.add_arg("relation_agg", str, "sum", "the method to aggregate multi-relational neighbors")
    model_g.add_arg("is_lemma", bool, False, "whether to trigger lemma")
    model_g.add_arg("is_filter", bool, True, "whether to filter nodes not in wn18")
    model_g.add_arg("is_clean", bool, True, "whether to filter nodes not in repeated_id")
    model_g.add_arg("is_morphy", bool, False, "whether to apply morphy")
    model_g.add_arg("fewer_label", bool, False, "whether to use fewer labels")
    model_g.add_arg("label_rate", float, 0.1, "label rate")
    model_g.add_arg("relation_list", list, [
        "_hyponym", "_hypernym", "_derivationally_related_form",
        "_member_meronym", "_member_holonym", "_part_of", "_has_part",
        "_member_of_domain_topic", "_synset_domain_topic_of",
        "_instance_hyponym", "_instance_hypernym", "_also_see", "_verb_group",
        "_member_of_domain_region", "_synset_domain_region_of",
        "_member_of_domain_usage", "_synset_domain_usage_of", "_similar_to"
    ], "The used relation.")
    model_g.add_arg("is_all_relation", bool, True, "use all relations")
    model_g.add_arg("selected_relation", str, "_hyponym,_hypernym,_derivationally_related_form", "relations")
    model_g.add_arg("wn18_dir", str, "", "wn18 dir")

    # SSL part
    model_g.add_arg("use_consistent_loss_wn", bool, False,
                    "add consistent loss between entity embedding from WN.")
    model_g.add_arg("warm_up", int, 10000, "warm_up_iterations")
    model_g.add_arg("consistent_loss_wn_coeff", float, 2.0, "Weight decay if we apply some.")
    model_g.add_arg("consistent_loss_type", str, "kld", "consistent loss type")

    model_g.add_arg("mark", str, "test1", "mark")
    model_g.add_arg("tensorboard_dir", str, "./", "tensorboard_dir")
    model_g.add_arg("debug", bool, False, "debug")
    model_g.add_arg("model_name_or_path", str, "",
                    "Path to pretrained model or model identifier from huggingface.co/models")
    model_g.add_arg("config_name", str, "",
                    "Pretrained config name or path if not the same as model_name")
    model_g.add_arg("model_type", str, "kelm", "The classification model to be used.")
    model_g.add_arg("text_embed_model", str, "bert", "The model for embedding texts in KELM model.")
    model_g.add_arg("output_dir", str, "../outputs/test", "Path to save checkpoints.")
    model_g.add_arg("overwrite_output_dir", bool, True, "Overwrite the content of the output directory.")
    model_g.add_arg("tokenizer_name", str, "",
                    "Pretrained tokenizer name or path if not the same as model_name")
    model_g.add_arg("per_gpu_train_batch_size", int, 6, "Batch size per GPU/CPU for training.")
    model_g.add_arg("per_gpu_eval_batch_size", int, 4, "Batch size per GPU/CPU for evaluation.")
    model_g.add_arg("max_steps", int, -1,
                    "If > 0: set total number of training steps to perform. Override num_train_epochs.")
    model_g.add_arg("gradient_accumulation_steps", int, 1,
                    "Number of updates steps to accumulate before performing a backward/update pass.")
    model_g.add_arg("num_train_epochs", float, 10, "Total number of training epochs to perform.")
    model_g.add_arg("weight_decay", float, 0.01, "Weight decay if we apply some.")
    model_g.add_arg("learning_rate", float, 3e-4, "The initial learning rate for Adam.")
    model_g.add_arg("adam_epsilon", float, 1e-8, "Epsilon for Adam optimizer.")
    model_g.add_arg("warmup_steps", int, 10, "Linear warmup over warmup_steps.")
    model_g.add_arg("max_grad_norm", float, 1.0, "Max gradient norm.")
    model_g.add_arg("evaluate_steps", int, 2, "Evaluate every X updates steps.")
    model_g.add_arg("evaluate_epoch", float, 0.0, "evaluate every X update epoch")
    model_g.add_arg("save_steps", int, 1, "Save every X updates steps.")
    model_g.add_arg("evaluate_during_training", bool, True,
                    "Run evaluation during training at each logging step.")
    model_g.add_arg("n_best_size", int, 20,
                    "The total number of n-best predictions to generate in the nbest_predictions.json output file.")
    model_g.add_arg("verbose_logging", bool, False,
                    "If true, all of the warnings related to data processing will be printed. "
                    "A number of warnings are expected for a normal SQuAD evaluation.")
    model_g.add_arg("init_dir", str, "", "The path of loading pre-trained model.")
    model_g.add_arg("initializer_range", float, 0.02, "The initializer range for KELM")
    model_g.add_arg("cat_mul", bool, True, "The output part of vector in KELM")
    model_g.add_arg("cat_sub", bool, True, "The output part of vector in KELM")
    model_g.add_arg("cat_twotime", bool, True, "The output part of vector in KELM")
    model_g.add_arg("cat_twotime_mul", bool, True, "The output part of vector in KELM")
    model_g.add_arg("cat_twotime_sub", bool, False, "The output part of vector in KELM")

    data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
    data_g.add_arg("train_file", str, "record/train_0831.json", "ReCoRD json for training. E.g., train.json.")
    data_g.add_arg("predict_file", str, "record/dev_0831.json", "ReCoRD json for predictions. E.g. dev.json.")
    data_g.add_arg("cache_file_suffix", str, "test", "The suffix of cached file.")
    data_g.add_arg("cache_dir", str, "", "The cached data path.")
    data_g.add_arg("cache_store_dir", str, "", "The cached data path.")
    data_g.add_arg("data_dir", str, "",
                   "The input data dir. Should contain the .json files for the task."
                   + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
    data_g.add_arg("vocab_path", str, "vocab.txt", "Vocabulary path.")
    data_g.add_arg("do_lower_case", bool, False,
                   "Whether to lower case the input text. Should be True for uncased models and False for cased models.")
    data_g.add_arg("seed", int, 42, "Random seed.")
    data_g.add_arg("kg_paths", dict, {"wordnet": "kgs/", "nell": "kgs/"},
                   "The paths of knowledge graph files.")
    data_g.add_arg("wn_concept_embedding_path", str, "embedded/wn_concept2vec.txt",
                   "The embeddings of concept in knowledge graph : Wordnet.")
    data_g.add_arg("nell_concept_embedding_path", str, "embedded/nell_concept2vec.txt",
                   "The embeddings of concept in knowledge graph : Nell.")
    data_g.add_arg("use_kgs", list, ['nell', 'wordnet'], "The used knowledge graphs.")
    data_g.add_arg("doc_stride", int, 128,
                   "When splitting up a long document into chunks, how much stride to take between chunks.")
    data_g.add_arg("max_seq_length", int, 384, "Number of words of the longest sequence.")
    data_g.add_arg("max_query_length", int, 64, "Max query length.")
    data_g.add_arg("max_answer_length", int, 30, "Max answer length.")
    data_g.add_arg("no_stopwords", bool, True, "Whether to include stopwords.")
    data_g.add_arg("ignore_length", int, 0, "The smallest size of token.")
    data_g.add_arg("print_loss_step", int, 100, "The steps to print loss.")

    run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
    run_type_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
    run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
    run_type_g.add_arg("max_n_gpu", int, 100, "The maximum number of GPUs to use.")
    run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
    run_type_g.add_arg("num_iteration_per_drop_scope", int, 1,
                       "The iteration intervals to clean up temporary variables.")
    run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
    run_type_g.add_arg("do_eval", bool, False, "Whether to perform evaluation during training.")
    run_type_g.add_arg("do_predict", bool, False, "Whether to perform prediction.")
    run_type_g.add_arg("freeze", bool, True, "freeze bert parameters")
    run_type_g.add_arg("server_ip", str, "", "Can be used for distant debugging.")
    run_type_g.add_arg("chunksize", int, 1024,
                       "The chunksize for multiprocessing to convert examples to features.")
    run_type_g.add_arg("server_port", str, "", "Can be used for distant debugging.")
    run_type_g.add_arg("local_rank", int, -1, "Index for distributed training on gpus.")
    run_type_g.add_arg("threads", int, 50, "multiple threads for converting examples to features")
    run_type_g.add_arg("overwrite_cache", bool, False, "Overwrite the cached training and evaluation sets")
    run_type_g.add_arg("eval_all_checkpoints", bool, False,
                       "Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number")
    run_type_g.add_arg("min_diff_steps", int, 50, "The minimum saving steps before the last maximum steps.")

    args = parser.parse_args()

    logging.getLogger("transformers.modeling_utils").setLevel(logging.WARNING)  # Reduce model loading logs

    if not args.is_all_relation:
        args.relation_list = args.selected_relation.split(",")
        logger.info("not using all relations, relation_list: {}".format(args.relation_list))

    if args.doc_stride >= args.max_seq_length - args.max_query_length:
        logger.warning(
            "WARNING - You've set a doc stride which may be superior to the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
            and args.do_train and not args.overwrite_output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or not args.use_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and args.use_cuda else "cpu")
        args.n_gpu = 0 if not args.use_cuda else min(args.max_n_gpu, torch.cuda.device_count())
    else:
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    if args.local_rank in [-1, 0] and not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARNING,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.use_fp16,
    )

    # Set seed
    set_seed(args)
    logger.info("Parameters from arguments are:\n{}".format(args))

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum
    # if args.use_fp16 is set. Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
    # Note that running `--fp16_opt_level="O2"` will remove the need for this code, but it is still valid.
    if args.use_fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    processor = RecordProcessor(args)

    input_dir = os.path.join(
        args.cache_store_dir,
        "cached_{}_{}".format(args.model_type, str(args.cache_file_suffix)))
    if not os.path.exists(input_dir):
        os.mkdir(input_dir)

    if args.full_table:
        logger.warning("set full_table False and program exits")
        exit()
    else:
        args.wn_def_embed_mat_dir = os.path.join(
            input_dir, args.cache_file_suffix) + "_" + "definition_embedding"
        # if not os.path.exists(args.wn_def_embed_mat_dir):
        #     data_path = os.path.join(args.data_dir, args.kg_paths["wordnet"])
        #     definition_embedding_mat = create_definition_table(args, data_path)
        #
        #     torch.save({"definition_embedding_mat": definition_embedding_mat}, args.wn_def_embed_mat_dir)
        #
        #     logger.info("definition embedding is done. program exits.")
        #     exit()

    # create data
    retrievers = dict()
    for kg in args.use_kgs:
        logger.info("Initialize kg:{}".format(kg))
        kg_path = os.path.join(input_dir, args.kg_paths[kg])
        data_path = os.path.join(args.data_dir, args.kg_paths[kg])
        retrievers[kg] = initialize_kg_retriever(kg, kg_path, data_path, args.cache_file_suffix)

    if args.data_preprocess:
        logger.info("begin preprocess")
        create_dataset(args, processor, retrievers, relation_list=args.relation_list,
                       evaluate=args.data_preprocess_evaluate, input_dir=input_dir)
        logger.info("data preprocess is done")

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    tokenizer, model = configure_tokenizer_model(args, logger, retrievers)

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)

    results = evaluate(args, model, processor, tokenizer, 100, input_dir, prefix=args.mark)
    if args.local_rank in [-1, 0]:
        logger.info("results: {}".format(results))
        logger.info("eval is done")
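
# Hedged usage sketch: assuming ArgumentGroup registers every add_arg name as a "--<name>"
# flag, and since evaluate() exits unless it runs under DDP (args.local_rank != -1), an
# invocation would look roughly like the command below. The script name and paths are
# placeholders, not values from the repo.
#
#   python -m torch.distributed.launch --nproc_per_node=2 run_eval.py \
#       --model_name_or_path /path/to/checkpoint --data_dir /path/to/data \
#       --cache_store_dir /path/to/cache --cache_file_suffix test \
#       --output_dir ../outputs/test
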
def evaluate(args, model, processor, tokenizer, global_step, input_dir, prefix=""):
    retrievers = dict()
    for kg in args.use_kgs:
        logger.info("Initialize kg:{}".format(kg))
        kg_path = os.path.join(input_dir, args.kg_paths[kg])
        data_path = os.path.join(args.data_dir, args.kg_paths[kg])
        if not os.path.exists(kg_path):
            logger.warning("need prepare training dataset firstly, program exit")
            exit()
        retrievers[kg] = initialize_kg_retriever(kg, kg_path, data_path, args.cache_file_suffix)

    dataset, examples_tokenized, features, wn_synset_graphs, wn_synset_graphs_label_dict = \
        load_and_cache_examples(args, processor, retrievers, relation_list=args.relation_list,
                                input_dir=input_dir, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.mkdir(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset, shuffle=False)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # synset_graphs_batch = []
    # for batch_index in eval_dataloader.batch_sampler:
    #     synset_graphs_batch.append([i for i in batch_index])

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    if args.local_rank != -1 and not isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Dataset size = %d", len(dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)

    if args.local_rank == -1:
        logger.warning("program exits and please use pytorch DDP framework")
        exit()
    else:
        # all_results = []
        # all_start_logits = torch.tensor([], dtype=torch.float32, device=args.device)
        # all_end_logits = torch.tensor([], dtype=torch.float32, device=args.device)
        # all_unique_ids = []
        all_pred = torch.tensor([], dtype=torch.long, device=args.device)
        all_label_ids = torch.tensor([], dtype=torch.long, device=args.device)
        all_question_ids = torch.tensor([], dtype=torch.long, device=args.device)
        # start_time = timeit.default_timer()
        epoch_iterator = tqdm(eval_dataloader, desc="Evaluating Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)
            batch_synset_graphs = batch[3]
            with torch.no_grad():
                inputs = create_input(args, batch, global_step,
                                      batch_synset_graphs=batch_synset_graphs,
                                      wn_synset_graphs=wn_synset_graphs, evaluate=True)
                feature_indices = batch[3]
                outputs = model(**inputs)

            logits, label_ids, qas_ids = outputs[1], outputs[2], outputs[3]
            all_pred = torch.cat((all_pred, torch.argmax(logits, axis=-1)), dim=0)
            all_label_ids = torch.cat((all_label_ids, label_ids), dim=0)
            all_question_ids = torch.cat((all_question_ids, qas_ids), dim=0)

        start_time = timeit.default_timer()
        all_pred_list = [torch.zeros_like(all_pred, device=args.device)
                         for _ in range(torch.distributed.get_world_size())]
        all_label_ids_list = [torch.zeros_like(all_label_ids, device=args.device)
                              for _ in range(torch.distributed.get_world_size())]
        all_question_ids_list = [torch.zeros_like(all_question_ids, device=args.device)
                                 for _ in range(torch.distributed.get_world_size())]
        all_gather(all_pred_list, all_pred)
        all_gather(all_label_ids_list, all_label_ids)
        all_gather(all_question_ids_list, all_question_ids)
        logger.info("time for gather communication:{} in rank {}".format(
            timeit.default_timer() - start_time, args.local_rank))

        if args.local_rank == 0:
            all_results = []
            preds = np.asarray([], dtype=int)
            label_values = np.asarray([], dtype=int)
            question_ids = np.asarray([], dtype=int)
            for batch_idx, batch_preds in enumerate(all_pred_list):
                preds = np.concatenate((preds, batch_preds.cpu().detach().numpy()), axis=0)
                label_values = np.concatenate(
                    (label_values, all_label_ids_list[batch_idx].cpu().detach().numpy()), axis=0)
                question_ids = np.concatenate(
                    (question_ids, all_question_ids_list[batch_idx].cpu().detach().numpy()), axis=0)

            if not args.test:
                df = pd.DataFrame({'label_values': label_values, 'question_ids': question_ids})
                assert "label_values" in df.columns
                assert "question_ids" in df.columns
                df["preds"] = preds
                # noinspection PyUnresolvedReferences
                exact_match = (
                    df.groupby("question_ids")
                    .apply(lambda _: (_["preds"] == _["label_values"]).all())
                    .mean()
                )
                exact_match = float(exact_match)
                f1 = f1_score(y_true=df["label_values"], y_pred=df["preds"])
                results = {'exact_match': exact_match, 'f1': f1}
            else:
                results = None

            if args.write_preds:
                guids = []
                for f in features:
                    guids.append(f.guid[0])
                guids = np.asarray(guids, dtype='<U18')
                assert len(preds) == len(guids)
                write_prediction(preds, guids, "multirc", args.output_dir, prefix)

            return results
        else:
            return None
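
# Hedged, self-contained illustration of the MultiRC-style exact match computed above:
# a question counts as correct only if every one of its answer options is predicted
# correctly. The numbers are made up.
#
#   import pandas as pd
#   df = pd.DataFrame({
#       "question_ids": [0, 0, 1, 1],
#       "label_values": [1, 0, 1, 1],
#       "preds":        [1, 0, 1, 0],
#   })
#   em = df.groupby("question_ids").apply(
#       lambda g: (g["preds"] == g["label_values"]).all()).mean()
#   # em == 0.5: question 0 has all options right, question 1 does not.
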