def evaluate(FLAGS, model, eval_set, log_entry, logger, trainer,
             vocabulary=None, show_sample=False, eval_index=0):
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval",
        bar_length=60,
        enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()

    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        can_sample = (FLAGS.model_type == "RLSPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            transitions_per_example, _ = model.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser) else (None, None)

            if model.use_sentence_pair:
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[batch_size:] if transitions_per_example is not None else None
                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None
                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids, output.data.cpu().numpy(),
                                sent1_transitions, sent2_transitions,
                                sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
def evaluate(FLAGS, model, data_manager, eval_set, index, logger, step, vocabulary=None):
    filename, dataset = eval_set

    A = Accumulator()
    M = MetricsWriter(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name))
    reporter = EvalReporter()

    eval_str = eval_format(model)
    eval_extra_str = eval_extra_format(model)

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval",
        bar_length=60,
        enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    invalid = 0
    start = time.time()

    model.eval()

    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Optionally calculate transition loss/acc.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            reporter_args = [pred, target, eval_ids, output.data.cpu().numpy()]
            if hasattr(model, 'transition_loss'):
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
                if model.use_sentence_pair:
                    batch_size = pred.size(0)
                    sent1_transitions = transitions_per_example[:batch_size]
                    sent2_transitions = transitions_per_example[batch_size:]
                    reporter_args.append(sent1_transitions)
                    reporter_args.append(sent2_transitions)
                else:
                    reporter_args.append(transitions_per_example)
            reporter.save_batch(*reporter_args)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    stats_args = eval_stats(model, A, step)
    stats_args['filename'] = filename

    logger.Log(eval_str.format(**stats_args))
    logger.Log(eval_extra_str.format(**stats_args))

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = stats_args['class_acc']
    eval_trans_acc = stats_args['transition_acc']

    if index == 0:
        eval_metrics(M, stats_args, step)

    return eval_class_acc, eval_trans_acc
def evaluate(FLAGS, model, data_manager, eval_set, log_entry, logger, step,
             vocabulary=None, show_sample=False, eval_index=0):
    filename, dataset = eval_set

    A = Accumulator()
    index = len(log_entry.evaluation)
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval",
        bar_length=60,
        enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (
                math.cos(step / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()

    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        can_sample = FLAGS.model_type in ["ChoiPyramid"] or (
            FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)  # TODO: Restore support in Pyramid if using.
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Optionally calculate transition loss/acc.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            transitions_per_example, _ = model.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser) else (None, None)

            if model.use_sentence_pair:
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[batch_size:] if transitions_per_example is not None else None
                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None
                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids, output.data.cpu().numpy(),
                                sent1_transitions, sent2_transitions,
                                sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
def evaluate(FLAGS, model, data_manager, eval_set, log_entry, logger, step,
             vocabulary=None, show_sample=False, eval_index=0):
    filename, dataset = eval_set

    A = Accumulator()
    index = len(log_entry.evaluation)
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval",
        bar_length=60,
        enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (
                math.cos(step / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()

    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, \
            eval_num_transitions_batch, eval_ids, _, silver_tree = batch
        # eval_X_batch: <batch x maxlen x 2>
        # eval_y_batch: <batch>
        # silver_tree: the dist is invalid for val

        # Run model.
        output = model(
            eval_X_batch,
            eval_transitions_batch,
            eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            store_parse_masks=True,
            example_lengths=eval_num_transitions_batch)

        # TODO: Restore support in Pyramid if using.
        can_sample = FLAGS.model_type in ["ChoiPyramid"] or (
            FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            # tree_strs = prettyprint_trees(tmp_samples)
            tree_strs = [tree for tree in tmp_samples]

        # Compare the predicted brackets against the silver trees to accumulate unlabeled F1.
        tmp_samples = model.get_samples(eval_X_batch, vocabulary, only_one=False)
        for s in range(int(model.use_sentence_pair) + 1):
            for b in range(silver_tree.shape[0]):
                model_out = tmp_samples[s * silver_tree.shape[0] + b]
                std_out = silver_tree[b, :, s]
                std_out = set([x for x in std_out if x != '-1,-1'])
                model_out_brackets, model_out_max_l = get_brackets(model_out)
                model_out = set(convert_brackets_to_string(model_out_brackets))
                outmost_bracket = '{:d},{:d}'.format(0, model_out_max_l)
                std_out.add(outmost_bracket)
                model_out.add(outmost_bracket)
                overlap = model_out & std_out
                prec = float(len(overlap)) / (len(model_out) + 1e-8)
                reca = float(len(overlap)) / (len(std_out) + 1e-8)
                if len(std_out) == 0:
                    reca = 1.
                if len(model_out) == 0:
                    prec = 1.
                f1 = 2 * prec * reca / (prec + reca + 1e-8)
                A.add('f1', f1)

        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Optionally calculate transition loss/acc.
        # model.transition_loss if hasattr(model, 'transition_loss') else None
        # TODO: review this; the original line seems to have no effect.

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            transitions_per_example, _ = model.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser) else (None, None)

            if model.use_sentence_pair:
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[batch_size:] if transitions_per_example is not None else None
                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None
                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids, output.data.cpu().numpy(),
                                sent1_transitions, sent2_transitions,
                                sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + str(tree_strs[0]))

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)  # get the eval statistics (e.g. average F1)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy
    eval_f1 = eval_log.f1

    return eval_class_acc, eval_trans_acc, eval_f1
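# The bracket-overlap F1 above is easier to see factored out of the nested loops.
# This is a minimal illustrative sketch, not part of the codebase: `unlabeled_f1`
# is a hypothetical helper, and it assumes brackets are "start,end" strings as in
# the loop above (the outermost bracket already added to both sets).
def unlabeled_f1(pred_brackets, gold_brackets, eps=1e-8):
    pred, gold = set(pred_brackets), set(gold_brackets)
    overlap = pred & gold
    prec = len(overlap) / (len(pred) + eps) if pred else 1.0
    reca = len(overlap) / (len(gold) + eps) if gold else 1.0
    return 2 * prec * reca / (prec + reca + eps)

# Example: identical bracketings give F1 ~= 1.0; sharing one of two brackets gives ~0.5.
# print(unlabeled_f1({'0,4', '0,2', '2,4'}, {'0,4', '0,2', '2,4'}))  # ~1.0
# print(unlabeled_f1({'0,4', '0,2'}, {'0,4', '2,4'}))                # ~0.5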
def evaluate(FLAGS, model, eval_set, log_entry, logger, trainer,
             vocabulary=None, show_sample=False, eval_index=0, target_vocabulary=None):
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval",
        bar_length=60,
        enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()

    ref_file_name = FLAGS.log_path + "/ref_file"
    pred_file_name = FLAGS.log_path + "/pred_file"
    reference_file = open(ref_file_name, "w")
    predict_file = open(pred_file_name, "w")
    full_ref = []
    full_pred = []

    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       example_lengths=eval_num_transitions_batch)

        can_sample = (FLAGS.model_type == "RLSPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.encoder.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Get reference translation
        ref_out = [" ".join(map(str, k[:-1])) + " ." for k in eval_y_batch]
        full_ref += ref_out

        # Get predicted translation
        predicted = [[] for _ in range(len(eval_y_batch))]
        done = []
        for x in output:
            index = -1
            for x_0 in x:
                index += 1
                val = int(x_0)
                if val == 1:
                    if index in done:
                        continue
                    done.append(index)
                elif index not in done:
                    predicted[index].append(val)
        pred_out = [" ".join(map(str, k)) + " ." for k in predicted]
        full_pred += pred_out

        eval_accumulate(model, A, batch)

        # Optionally calculate transition loss/acc.
        transition_loss = model.encoder.transition_loss if hasattr(model.encoder, 'transition_loss') else None

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            transitions_per_example, _ = model.encoder.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser) else (None, None)

            sent1_transitions = transitions_per_example if transitions_per_example is not None else None
            sent2_transitions = None
            sent1_trees = tree_strs if tree_strs is not None else None
            sent2_trees = None

            reporter.save_batch(full_pred, full_ref, eval_ids, [None],
                                sent1_transitions, sent2_transitions,
                                sent1_trees, sent2_trees, mt=True)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    reference_file.write("\n".join(full_ref))
    reference_file.close()
    predict_file.write("\n".join(full_pred))
    predict_file.close()

    bleu_score = os.popen(
        "perl spinn/util/multi-bleu.perl " + ref_file_name + " < " + pred_file_name).read()
    try:
        bleu_score = float(bleu_score)
    except:
        bleu_score = 0.0

    end = time.time()
    total_time = end - start

    A.add('class_correct', bleu_score)
    A.add('class_total', 1)
    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)
        stats = parse_comparison.run_main(
            data_type="mt",
            main_report_path_template=FLAGS.log_path + "/" + FLAGS.experiment_name + ".eval_set_0.report",
            main_data_path=FLAGS.source_eval_path)
        # TODO: include the following in the log-formatter so it's reported in the standard format.
        if tree_strs is not None:
            logger.Log(
                'F1 w/ GT: ' + str(stats['gt']) + '\n' +
                'F1 w/ LB: ' + str(stats['lb']) + '\n' +
                'F1 w/ RB: ' + str(stats['rb']) + '\n' +
                'Avg. tree depth: ' + str(stats['depth']))

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
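# Illustrative sketch only: the bare float() cast above assumes the BLEU script prints
# a plain number. The stock Moses multi-bleu.perl instead prints a line like
# "BLEU = 23.45, ...", in which case a regex is more defensive. `parse_bleu` is a
# hypothetical helper, not part of this codebase; if the bundled script already emits
# a bare score, the float() above is fine.
import re
import subprocess

def parse_bleu(ref_path, pred_path, script="spinn/util/multi-bleu.perl"):
    with open(pred_path) as pred_file:
        out = subprocess.check_output(["perl", script, ref_path], stdin=pred_file)
    match = re.search(r"BLEU = ([\d.]+)", out.decode("utf-8"))
    return float(match.group(1)) if match else 0.0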
def evaluate(model, eval_set, logger, metrics_logger, step, vocabulary=None):
    filename, dataset = eval_set

    reporter = EvalReporter()

    # Evaluate
    class_correct = 0
    class_total = 0
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval",
        bar_length=60,
        enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()

    transition_preds = []
    transition_targets = []

    for i, (eval_X_batch, eval_transitions_batch, eval_y_batch,
            eval_num_transitions_batch, eval_ids) in enumerate(dataset):
        if FLAGS.truncate_eval_batch:
            eval_X_batch, eval_transitions_batch = truncate(
                eval_X_batch, eval_transitions_batch, eval_num_transitions_batch)

        if FLAGS.saving_eval_attention_matrix:
            model.set_recording_attention_weight_matrix(True)

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        if FLAGS.saving_eval_attention_matrix:
            # WARNING: only attention SPINN models have an attention matrix.
            attention_matrix = model.get_attention_matrix_from_last_forward()
            with open(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name,
                                   'attention-matrix-{}.txt'.format(step)), 'a') as txtfile:
                for eval_id, attmat in izip(eval_ids, attention_matrix):
                    txtfile.write('{}\n'.format(eval_id))
                    txtfile.write('{},{}\n'.format(len(attmat), len(attmat[0])))
                    for row in attmat:
                        txtfile.write(','.join(['{:.1f}'.format(x * 100.0) for x in row]))
                        txtfile.write('\n')
            model.set_recording_attention_weight_matrix(False)  # reset it after run

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
        class_correct += pred.eq(target).sum()
        class_total += target.size(0)

        # Optionally calculate transition loss/acc.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Update Aggregate Accuracies
        total_tokens += eval_num_transitions_batch.ravel().sum()

        # Accumulate stats for transition accuracy.
        if transition_loss is not None:
            transition_preds.append([m["t_preds"] for m in model.spinn.memories])
            transition_targets.append([m["t_given"] for m in model.spinn.memories])

        if FLAGS.write_eval_report:
            reporter_args = [pred, target, eval_ids, output.data.cpu().numpy()]
            if hasattr(model, 'transition_loss'):
                transition_preds_per_example = model.spinn.get_transition_preds_per_example()
                if model.use_sentence_pair:
                    batch_size = pred.size(0)
                    sent1_preds = transition_preds_per_example[:batch_size]
                    sent2_preds = transition_preds_per_example[batch_size:]
                    reporter_args.append(sent1_preds)
                    reporter_args.append(sent2_preds)
                else:
                    reporter_args.append(transition_preds_per_example)
            reporter.save_batch(*reporter_args)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    end = time.time()
    total_time = end - start

    # Get time per token.
    time_metric = time_per_token([total_tokens], [total_time])

    # Get class accuracy.
    eval_class_acc = class_correct / float(class_total)

    # Get transition accuracy if applicable.
    if len(transition_preds) > 0:
        all_preds = np.array(flatten(transition_preds))
        all_truth = np.array(flatten(transition_targets))
        eval_trans_acc = (all_preds == all_truth).sum() / float(all_truth.shape[0])
    else:
        eval_trans_acc = 0.0

    logger.Log("Step: %i Eval acc: %f %f %s Time: %5f" %
               (step, eval_class_acc, eval_trans_acc, filename, time_metric))

    metrics_logger.Log('eval_class_acc', eval_class_acc, step)
    metrics_logger.Log('eval_trans_acc', eval_trans_acc, step)

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".report")
        reporter.write_report(eval_report_path)

    return eval_class_acc
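# Hypothetical call site for the log_entry-based variants above, shown only to make the
# expected arguments concrete. The names `eval_iterators` and `trainer` are assumptions
# about the surrounding training script, not confirmed by this file; `log_entry`,
# `logger`, and `vocabulary` are the objects the functions above already expect.
# for eval_index, eval_set in enumerate(eval_iterators):
#     acc, trans_acc = evaluate(
#         FLAGS, model, eval_set, log_entry, logger, trainer,
#         vocabulary=vocabulary,
#         show_sample=(eval_index == 0),
#         eval_index=eval_index)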