def evaluate(args, last=True):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    setup_logging(args, local_rank)
    # only evaluate once
    if rank != 0:
        logging.info('Skipping node {}'.format(rank))
        return
    ctx_l = parse_ctx(args.gpus)
    logging.info(
        'Starting inference without horovod on the first node on device {}'.format(str(ctx_l)))

    cfg, tokenizer, qa_net, use_segmentation = get_network(
        args.model_name, ctx_l, args.classifier_dropout, dtype=args.eval_dtype)
    if args.eval_dtype == 'float16':
        qa_net.cast('float16')
    qa_net.hybridize()

    logging.info('Prepare dev data')
    dev_features = get_squad_features(args, tokenizer, segment='dev')
    dev_data_path = os.path.join(args.data_dir, 'dev-v{}.json'.format(args.version))
    dataset_processor = SquadDatasetProcessor(tokenizer=tokenizer,
                                              doc_stride=args.doc_stride,
                                              max_seq_length=args.max_seq_length,
                                              max_query_length=args.max_query_length)
    # Chunk every dev example; dev_chunk_feature_ptr holds prefix-sum offsets that
    # map each original example back to its span of chunked features.
    dev_all_chunk_features = []
    dev_chunk_feature_ptr = [0]
    for feature in dev_features:
        chunk_features = dataset_processor.process_sample(feature)
        dev_all_chunk_features.extend(chunk_features)
        dev_chunk_feature_ptr.append(dev_chunk_feature_ptr[-1] + len(chunk_features))

    def eval_validation(ckpt_name, best_eval):
        """
        Model inference during validation or final evaluation.
        """
        dev_dataloader = mx.gluon.data.DataLoader(
            dev_all_chunk_features,
            batchify_fn=dataset_processor.BatchifyFunction,
            batch_size=args.eval_batch_size,
            num_workers=0,
            shuffle=False)
        log_interval = args.eval_log_interval
        all_results = []
        epoch_tic = time.time()
        tic = time.time()
        epoch_size = len(dev_features)
        total_num = 0
        log_num = 0
        for batch_idx, dev_batch in enumerate(grouper(dev_dataloader, len(ctx_l))):
            # Predict for each chunk
            for sample, ctx in zip(dev_batch, ctx_l):
                if sample is None:
                    continue
                # Copy the data to device
                tokens = sample.data.as_in_ctx(ctx)
                total_num += len(tokens)
                log_num += len(tokens)
                segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
                valid_length = sample.valid_length.as_in_ctx(ctx)
                p_mask = sample.masks.as_in_ctx(ctx)
                p_mask = 1 - p_mask  # In the network, we use 1 --> no_mask, 0 --> mask
                start_top_logits, start_top_index, end_top_logits, end_top_index, answerable_logits \
                    = qa_net.inference(tokens, segment_ids, valid_length, p_mask,
                                       args.start_top_n, args.end_top_n)
                for i, qas_id in enumerate(sample.qas_id):
                    result = RawResultExtended(
                        qas_id=qas_id,
                        start_top_logits=start_top_logits[i].asnumpy(),
                        start_top_index=start_top_index[i].asnumpy(),
                        end_top_logits=end_top_logits[i].asnumpy(),
                        end_top_index=end_top_index[i].asnumpy(),
                        answerable_logits=answerable_logits[i].asnumpy())
                    all_results.append(result)

            # logging
            if (batch_idx + 1) % log_interval == 0:
                # Report the running throughput for this interval
                toc = time.time()
                logging.info(
                    '[batch {}], Time cost={:.2f},'
                    ' Throughput={:.2f} samples/s, ETA={:.2f}h'.format(
                        batch_idx + 1, toc - tic, log_num / (toc - tic),
                        (epoch_size - total_num) / (total_num / (toc - epoch_tic)) / 3600))
                tic = time.time()
                log_num = 0

        epoch_toc = time.time()
        logging.info('Time cost=%.2f s, Throughput=%.2f samples/s',
                     epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic))

        # Merge the chunk-level results back into per-example predictions.
        all_predictions = collections.OrderedDict()
        all_nbest_json = collections.OrderedDict()
        no_answer_score_json = collections.OrderedDict()
        for index, (left_index, right_index) in enumerate(
                zip(dev_chunk_feature_ptr[:-1], dev_chunk_feature_ptr[1:])):
            chunked_features = dev_all_chunk_features[left_index:right_index]
            results = all_results[left_index:right_index]
            original_feature = dev_features[index]
            qas_ids = set([result.qas_id for result in results] +
                          [feature.qas_id for feature in chunked_features])
            assert len(qas_ids) == 1, 'Mismatch occurred between features and results'
            example_qas_id = list(qas_ids)[0]
            assert example_qas_id == original_feature.qas_id, \
                'Mismatch occurred between original feature and chunked features'
            not_answerable_score, best_pred, nbest_json = predict_extended(
                original_feature=original_feature,
                chunked_features=chunked_features,
                results=results,
                n_best_size=args.n_best_size,
                max_answer_length=args.max_answer_length,
                start_top_n=args.start_top_n,
                end_top_n=args.end_top_n)
            no_answer_score_json[example_qas_id] = not_answerable_score
            all_predictions[example_qas_id] = best_pred
            all_nbest_json[example_qas_id] = nbest_json

        if args.version == '2.0':
            exact = 'best_exact'
            f1 = 'best_f1'
            na_prob = no_answer_score_json
        else:
            exact = 'exact'
            f1 = 'f1'
            na_prob = None

        cur_eval, revised_predictions = squad_eval(
            dev_data_path, all_predictions, na_prob, revise=na_prob is not None)
        logging.info('The evaluated results are {}'.format(json.dumps(cur_eval)))

        cur_metrics = 0.5 * (cur_eval[exact] + cur_eval[f1])
        if best_eval:
            best_metrics = 0.5 * (best_eval[exact] + best_eval[f1])
        else:
            best_metrics = 0.

        if cur_metrics > best_metrics:
            logging.info('The evaluated files are saved in {}'.format(args.output_dir))
            output_prediction_file = os.path.join(args.output_dir, 'predictions.json')
            output_nbest_file = os.path.join(args.output_dir, 'nbest_predictions.json')
            na_prob_file = os.path.join(args.output_dir, 'na_prob.json')
            revised_prediction_file = os.path.join(args.output_dir, 'revised_predictions.json')

            with open(output_prediction_file, 'w') as of:
                of.write(json.dumps(all_predictions, indent=4) + '\n')
            with open(output_nbest_file, 'w') as of:
                of.write(json.dumps(all_nbest_json, indent=4) + '\n')
            with open(na_prob_file, 'w') as of:
                of.write(json.dumps(no_answer_score_json, indent=4) + '\n')
            with open(revised_prediction_file, 'w') as of:
                of.write(json.dumps(revised_predictions, indent=4) + '\n')

            best_eval = cur_eval
            best_eval.update({'best_ckpt': ckpt_name})
        return best_eval

    # Gather the candidate checkpoints to evaluate.
    if args.param_checkpoint and args.param_checkpoint.endswith('.params'):
        ckpt_candidates = [args.param_checkpoint]
    else:
        ckpt_candidates = [f for f in os.listdir(args.output_dir) if f.endswith('.params')]
        ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
        ckpt_candidates = [os.path.join(args.output_dir, ele) for ele in ckpt_candidates]
    if last:
        ckpt_candidates = ckpt_candidates[-1:]

    best_eval = {}
    for ckpt_path in ckpt_candidates:
        logging.info('Starting to evaluate the checkpoint {}'.format(ckpt_path))
        qa_net.load_parameters(ckpt_path, ctx=ctx_l, cast_dtype=True)
        best_eval = eval_validation(ckpt_path, best_eval)
        logging.info('The best evaluated results are {}'.format(json.dumps(best_eval)))

    output_eval_results_file = os.path.join(args.output_dir, 'best_results.json')
    with open(output_eval_results_file, 'w') as of:
        of.write(json.dumps(best_eval, indent=4) + '\n')
    return best_eval
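# NOTE (illustrative sketch, not part of this project): both evaluate() functions
# in this section iterate via a `grouper` helper that splits the dataloader stream
# into groups of len(ctx_l) so each device gets one batch per step, padding the
# final group with None (hence the `if sample is None: continue` guards). The real
# helper is imported from elsewhere in the repo; a minimal stand-in with the
# assumed behavior, defined under a different name to avoid shadowing it, could
# look like this:
def _grouper_sketch(iterable, n, fillvalue=None):
    """Collect items into fixed-length groups, e.g.
    _grouper_sketch('ABCDEFG', 3) -> ('A', 'B', 'C'), ('D', 'E', 'F'), ('G', None, None).
    """
    import itertools
    iterators = [iter(iterable)] * n
    return itertools.zip_longest(*iterators, fillvalue=fillvalue)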
def evaluate(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    # setup_logging(args, local_rank)
    task = get_task(args.task_name)
    # only evaluate once
    if rank != 0:
        logging.info('Skipping node {}'.format(rank))
        return
    ctx_l = parse_ctx(args.gpus)
    logging.info(
        'Starting inference without horovod on the first node on device {}'.format(str(ctx_l)))

    cfg, tokenizer, classify_net, use_segmentation = \
        get_network(args.model_name, ctx_l, args.param_checkpoint, args.backbone_path, task)

    # Collect every checkpoint for this task/model under the output directory.
    candidate_ckpt = []
    detail_dir = os.path.join(args.output_dir, args.task_name)
    for name in os.listdir(detail_dir):
        if name.endswith('.params') and args.task_name in name and args.model_name in name:
            candidate_ckpt.append(os.path.join(detail_dir, name))
    best_ckpt = {}
    metrics = task.metric

    def evaluate_by_ckpt(ckpt_name, best_ckpt):
        classify_net.load_parameters(ckpt_name, ctx=ctx_l, cast_dtype=True)
        logging.info('Prepare dev data')
        dev_data, label = get_task_data(args, tokenizer, segment='eval', task=task)
        dev_batchify = bf.Group(bf.Group(bf.Pad(), bf.Pad(), bf.Stack()), bf.Stack())
        dataloader = DataLoader(dev_data,
                                batch_size=args.batch_size,
                                batchify_fn=dev_batchify,
                                shuffle=False)

        for sample_l in grouper(dataloader, len(ctx_l)):
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    continue
                # Copy the data to device and run the classifier
                (token_ids, token_types, valid_length), label = sample
                token_ids = mx.np.array(token_ids, ctx=ctx)
                token_types = mx.np.array(token_types, ctx=ctx)
                valid_length = mx.np.array(valid_length, ctx=ctx)
                scores = classify_net(token_ids, token_types, valid_length)
                if args.task_name == 'sts':
                    label = label.reshape((-1, 1))
                for metric in metrics:
                    metric.update([label], [scores])

        for metric in metrics:
            metric_name, result = metric.get()
            logging.info('checkpoint {} gets result: {}: {}'.format(
                ckpt_name, metric_name, result))
            if best_ckpt.get(metric_name, [0, ''])[0] < result:
                best_ckpt[metric_name] = [result, ckpt_name]

    for ckpt_name in candidate_ckpt:
        evaluate_by_ckpt(ckpt_name, best_ckpt)
    for metric_name in best_ckpt:
        logging.info('best result on metric {} is {}, from checkpoint {}'.format(
            metric_name, best_ckpt[metric_name][0], best_ckpt[metric_name][1]))
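# NOTE (hypothetical usage sketch): the classification evaluate() above only reads
# plain attributes off `args`, so it can be driven with a simple argparse.Namespace.
# The attribute names mirror exactly what the function accesses (comm_backend, gpus,
# task_name, model_name, param_checkpoint, backbone_path, output_dir, batch_size);
# get_task_data() may read additional fields internally, and all concrete values
# below are placeholders rather than defaults required by this project.
def _example_evaluate_call():
    from argparse import Namespace
    example_args = Namespace(
        comm_backend='device',      # forwarded to init_comm()
        gpus='0',                   # parsed by parse_ctx()
        task_name='sst',            # resolved through get_task(); placeholder task
        model_name='google_en_uncased_bert_base',  # placeholder model name
        param_checkpoint=None,      # forwarded to get_network()
        backbone_path=None,
        output_dir='./output',      # checkpoints are searched under output_dir/task_name
        batch_size=32,
    )
    evaluate(example_args)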