Example #1
0
def evaluate(FLAGS,
             model,
             eval_set,
             log_entry,
             logger,
             trainer,
             vocabulary=None,
             show_sample=False,
             eval_index=0):
    """Run one full pass over an evaluation set and record stats in the log.

    Puts `model` into eval mode, iterates every batch of `eval_set`,
    accumulates class-accuracy and timing counters into an Accumulator,
    optionally samples one parse tree for display, and optionally writes a
    per-example eval report to disk.

    Args:
        FLAGS: global configuration flags (parser and report options).
        model: model under evaluation; called once per batch.
        eval_set: (filename, dataset) pair; `dataset` yields eval batches.
        log_entry: log record; one `evaluation` entry is appended and filled.
        logger: object with a `Log(str)` method, used to print a sample tree.
        trainer: unused here; accepted for call-site compatibility.
        vocabulary: vocabulary used to pretty-print sampled trees, if any.
        show_sample: when True, sample parse trees for display.
        eval_index: index appended to the eval report filename.

    Returns:
        (eval_class_accuracy, eval_transition_accuracy) as filled into the
        new eval log entry by `eval_stats`.
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch,
                       eval_transitions_batch,
                       eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        # Tree sampling is only available for RLSPINN with the internal parser.
        can_sample = (FLAGS.model_type == "RLSPINN"
                      and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        # NOTE(review): argmax is taken on the raw output (no softmax); argmax
        # is invariant under (log-)softmax, so this matches the variants that
        # normalize first.
        pred = output.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update Aggregate Accuracies
        # (nt + 1) / 2 presumably maps a shift-reduce transition count back to
        # a token count -- TODO confirm against the data pipeline.
        total_tokens += sum([(nt + 1) / 2
                             for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            # Per-example transitions are only recoverable for SPINN with the
            # internal parser; otherwise record None.
            transitions_per_example, _ = model.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                    FLAGS.model_type == "SPINN"
                    and FLAGS.use_internal_parser) else (None, None)

            if model.use_sentence_pair:
                # Sentence-pair batches stack the two sentences, so split the
                # transitions/trees into halves of `batch_size` each.
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:
                                                            batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[
                    batch_size:] if transitions_per_example is not None else None

                sent1_trees = tree_strs[:
                                        batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[
                    batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None

                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids,
                                output.data.cpu().numpy(), sent1_transitions,
                                sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    # eval_stats reads the accumulator and fills the eval log entry.
    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
Example #2
0
def evaluate(FLAGS, model, data_manager, eval_set, index, logger, step, vocabulary=None):
    """Evaluate `model` on one eval set, log formatted stats, write metrics.

    Runs the model in eval mode over every batch, accumulates class accuracy
    and timing, optionally writes a per-example eval report, and (for the
    primary eval set only) records step metrics.

    Args:
        FLAGS: global configuration flags (paths, parser and report options).
        model: model under evaluation; called once per batch.
        data_manager: dataset-specific helper forwarded to `eval_accumulate`.
        eval_set: (filename, dataset) pair; `dataset` yields eval batches.
        index: position of this eval set; metrics are only written for 0.
        logger: object with a `Log(str)` method for the formatted summary.
        step: current training step, forwarded to `eval_stats`/`eval_metrics`.
        vocabulary: unused here; accepted for call-site compatibility.

    Returns:
        (eval_class_accuracy, eval_transition_accuracy) from `eval_stats`.
    """
    filename, dataset = eval_set

    A = Accumulator()
    M = MetricsWriter(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name))
    reporter = EvalReporter()

    eval_str = eval_format(model)
    eval_extra_str = eval_extra_format(model)

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update aggregate token count. (nt + 1) / 2 presumably maps a
        # shift-reduce transition count back to a token count -- TODO confirm.
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            reporter_args = [pred, target, eval_ids, output.data.cpu().numpy()]
            if hasattr(model, 'transition_loss'):
                # A transition loss implies a SPINN-style parser whose
                # per-example transitions can be recovered for the report.
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
                if model.use_sentence_pair:
                    # Sentence-pair batches stack the two sentences; split.
                    batch_size = pred.size(0)
                    reporter_args.append(transitions_per_example[:batch_size])
                    reporter_args.append(transitions_per_example[batch_size:])
                else:
                    reporter_args.append(transitions_per_example)
            reporter.save_batch(*reporter_args)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    stats_args = eval_stats(model, A, step)
    stats_args['filename'] = filename

    logger.Log(eval_str.format(**stats_args))
    logger.Log(eval_extra_str.format(**stats_args))

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = stats_args['class_acc']
    eval_trans_acc = stats_args['transition_acc']

    # Only the primary (first) eval set contributes to step metrics.
    if index == 0:
        eval_metrics(M, stats_args, step)

    return eval_class_acc, eval_trans_acc
def evaluate(FLAGS, model, data_manager, eval_set, log_entry,
             logger, step, vocabulary=None, show_sample=False, eval_index=0):
    """Evaluate `model` on one eval set and record stats into `log_entry`.

    Supports Pyramid/ChoiPyramid temperature annealing: when the model type
    is pyramid-based, a temperature multiplier derived from `step` is passed
    to the model. Accumulates class accuracy and timing, optionally samples
    a parse tree for display, and optionally writes a per-example report.

    Args:
        FLAGS: global configuration flags (model type, parser/report options).
        model: model under evaluation; called once per batch.
        data_manager: dataset-specific helper forwarded to `eval_accumulate`.
        eval_set: (filename, dataset) pair; `dataset` yields eval batches.
        log_entry: log record; one `evaluation` entry is appended and filled.
        logger: object with a `Log(str)` method, used to print a sample tree.
        step: current training step; drives the pyramid temperature schedule.
        vocabulary: vocabulary used to pretty-print sampled trees, if any.
        show_sample: when True, sample parse trees for display.
        eval_index: index appended to the eval report filename.

    Returns:
        (eval_class_accuracy, eval_transition_accuracy) as filled into the
        new eval log entry by `eval_stats`.
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    # Pyramid models anneal a gumbel/softmax temperature with the step count;
    # optionally modulated by a cosine cycle.
    if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (
            step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (math.cos((step) /
                                                        FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        # TODO: Restore support in Pyramid if using.
        can_sample = FLAGS.model_type in ["ChoiPyramid"] or (
            FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            show_sample = False  # Only show one sample, regardless of the number of batches.

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update aggregate token count. (nt + 1) / 2 presumably maps a
        # shift-reduce transition count back to a token count -- TODO confirm.
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            # Per-example transitions are only recoverable for SPINN with the
            # internal parser; otherwise record None.
            transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given") if (FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser) else (None, None)

            if model.use_sentence_pair:
                # Sentence-pair batches stack the two sentences; split halves.
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[batch_size:] if transitions_per_example is not None else None

                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None

                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids, output.data.cpu().numpy(), sent1_transitions, sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    # eval_stats reads the accumulator and fills the eval log entry.
    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
def evaluate(FLAGS,
             model,
             data_manager,
             eval_set,
             log_entry,
             logger,
             step,
             vocabulary=None,
             show_sample=False,
             eval_index=0):
    """Evaluate the model and additionally score unlabeled-parse F1.

    Besides the usual class-accuracy/timing accumulation, this variant
    compares the model's sampled trees against silver-standard trees shipped
    with each batch: brackets from both are turned into "start,end" string
    sets and per-sentence precision/recall/F1 are accumulated under key 'f1'.

    Args:
        FLAGS: global configuration flags (model type, parser/report options).
        model: model under evaluation; must provide `get_samples`.
        data_manager: dataset-specific helper forwarded to `eval_accumulate`.
        eval_set: (filename, dataset) pair; batches additionally carry a
            silver tree per sentence (see unpacking below).
        log_entry: log record; one `evaluation` entry is appended and filled.
        logger: object with a `Log(str)` method, used to print a sample tree.
        step: current training step; drives the pyramid temperature schedule.
        vocabulary: vocabulary used when sampling trees.
        show_sample: when True, keep sampled trees for display.
        eval_index: index appended to the eval report filename.

    Returns:
        (eval_class_accuracy, eval_transition_accuracy, f1) read back from
        the populated eval log entry.
    """
    filename, dataset = eval_set

    A = Accumulator()
    index = len(log_entry.evaluation)
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    # Pyramid models anneal a temperature with the step count, optionally
    # modulated by a cosine cycle.
    if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps**(
            step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (math.cos(
                (step) / FLAGS.pyramid_temperature_cycle_length) + 1 +
                                               min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()

    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids, _, silver_tree = batch
        # eval_X_batch: <batch x maxlen x 2>
        # eval_y_batch: <batch >
        # silver_tree: per-sentence gold-ish brackets, '-1,-1' entries are padding
        # the dist is invalid for val

        # Run model. store_parse_masks is always on so trees can be sampled
        # for the F1 computation below.
        output = model(
            eval_X_batch,
            eval_transitions_batch,
            eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            store_parse_masks=True,
            example_lengths=eval_num_transitions_batch)

        # TODO: Restore support in Pyramid if using.
        can_sample = FLAGS.model_type in [
            "ChoiPyramid"
        ] or (FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            # tree_strs = prettyprint_trees(tmp_samples)
            tree_strs = [tree for tree in tmp_samples]

        # Re-sample all trees (not just one) for the F1 computation;
        # overwrites tmp_samples regardless of the display branch above.
        tmp_samples = model.get_samples(eval_X_batch,
                                        vocabulary,
                                        only_one=False)

        # Per sentence (1 or 2 per example, depending on sentence-pair mode)
        # compare model brackets against the silver brackets as string sets.
        for s in (range(int(model.use_sentence_pair) + 1)):
            for b in (range(silver_tree.shape[0])):
                model_out = tmp_samples[s * silver_tree.shape[0] + b]
                std_out = silver_tree[b, :, s]
                std_out = set([x for x in std_out if x != '-1,-1'])
                model_out_brackets, model_out_max_l = get_brackets(model_out)
                model_out = set(convert_brackets_to_string(model_out_brackets))

                # Both sides always contain the whole-sentence bracket so the
                # trivial span never distinguishes them.
                outmost_bracket = '{:d},{:d}'.format(0, model_out_max_l)
                std_out.add(outmost_bracket)
                model_out.add(outmost_bracket)

                # Unlabeled bracket precision/recall/F1; the 1e-8 terms guard
                # against division by zero.
                overlap = model_out & std_out
                prec = float(len(overlap)) / (len(model_out) + 1e-8)
                reca = float(len(overlap)) / (len(std_out) + 1e-8)
                if len(std_out) == 0:
                    reca = 1.
                    if len(model_out) == 0:
                        prec = 1.
                f1 = 2 * prec * reca / (prec + reca + 1e-8)
                A.add('f1', f1)

        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update Aggregate Accuracies
        # (nt + 1) / 2 presumably maps a shift-reduce transition count back
        # to a token count -- TODO confirm.
        total_tokens += sum([(nt + 1) / 2
                             for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            # Per-example transitions are only recoverable for SPINN with the
            # internal parser; otherwise record None.
            transitions_per_example, _ = model.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                    FLAGS.model_type == "SPINN"
                    and FLAGS.use_internal_parser) else (None, None)

            if model.use_sentence_pair:
                # Sentence-pair batches stack the two sentences; split halves.
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:
                                                            batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[
                    batch_size:] if transitions_per_example is not None else None

                sent1_trees = tree_strs[:
                                        batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[
                    batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None

                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids,
                                output.data.cpu().numpy(), sent1_transitions,
                                sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + str(tree_strs[0]))

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)  # get the eval statistics (e.g. average F1)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy
    eval_f1 = eval_log.f1

    return eval_class_acc, eval_trans_acc, eval_f1
Example #5
0
def evaluate(FLAGS,
             model,
             eval_set,
             log_entry,
             logger,
             trainer,
             vocabulary=None,
             show_sample=False,
             eval_index=0,
             target_vocabulary=None):
    """Evaluate a translation (MT) model: decode, score BLEU, record stats.

    Decodes every batch into token-id strings, writes reference and
    hypothesis files, scores them with the external multi-bleu.perl script,
    and stores the BLEU score in the accumulator's 'class_correct' slot so
    the standard `eval_stats` path reports it as the accuracy.

    Args:
        FLAGS: global configuration flags (paths, parser/report options).
        model: seq2seq-style model; `model.encoder` carries the parser.
        eval_set: (filename, dataset) pair; `dataset` yields eval batches.
        log_entry: log record; one `evaluation` entry is appended and filled.
        logger: object with a `Log(str)` method for samples and parse stats.
        trainer: unused here; accepted for call-site compatibility.
        vocabulary: vocabulary used to pretty-print sampled trees, if any.
        show_sample: when True, sample parse trees for display.
        eval_index: index appended to the eval report filename.
        target_vocabulary: unused here; accepted for call-site compatibility.

    Returns:
        (eval_class_accuracy, eval_transition_accuracy) from the eval log;
        class accuracy holds the BLEU score (see above).
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()
    ref_file_name = FLAGS.log_path + "/ref_file"
    pred_file_name = FLAGS.log_path + "/pred_file"
    full_ref = []
    full_pred = []
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch,
                       eval_transitions_batch,
                       eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       example_lengths=eval_num_transitions_batch)

        # Tree sampling is only available for RLSPINN with the internal parser.
        can_sample = (FLAGS.model_type == "RLSPINN"
                      and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.encoder.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)

        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Get reference translation (drop the final id, append a period).
        ref_out = [" ".join(map(str, k[:-1])) + " ." for k in eval_y_batch]
        full_ref += ref_out

        # Get predicted translation. `output` is iterated per decoding step;
        # token id 1 marks end-of-sequence for that example, after which its
        # tokens are ignored.
        predicted = [[] for _ in range(len(eval_y_batch))]
        done = []
        for x in output:
            for example_index, x_0 in enumerate(x):
                val = int(x_0)
                if val == 1:
                    if example_index in done:
                        continue
                    done.append(example_index)
                elif example_index not in done:
                    predicted[example_index].append(val)
        pred_out = [" ".join(map(str, k)) + " ." for k in predicted]
        full_pred += pred_out

        eval_accumulate(model, A, batch)

        # Update Aggregate Accuracies
        # (nt + 1) / 2 presumably maps a shift-reduce transition count back
        # to a token count -- TODO confirm.
        total_tokens += sum([(nt + 1) / 2
                             for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            # Per-example transitions are only recoverable for SPINN with the
            # internal parser; otherwise record None.
            transitions_per_example, _ = model.encoder.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                    FLAGS.model_type == "SPINN"
                    and FLAGS.use_internal_parser) else (None, None)

            sent1_transitions = transitions_per_example if transitions_per_example is not None else None
            sent2_transitions = None

            sent1_trees = tree_strs if tree_strs is not None else None
            sent2_trees = None
            reporter.save_batch(full_pred,
                                full_ref,
                                eval_ids, [None],
                                sent1_transitions,
                                sent2_transitions,
                                sent1_trees,
                                sent2_trees,
                                mt=True)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    # Write references and hypotheses, then score with the external script.
    with open(ref_file_name, "w") as reference_file:
        reference_file.write("\n".join(full_ref))
    with open(pred_file_name, "w") as predict_file:
        predict_file.write("\n".join(full_pred))

    bleu_score = os.popen("perl spinn/util/multi-bleu.perl " + ref_file_name +
                          " < " + pred_file_name).read()
    try:
        bleu_score = float(bleu_score)
    except ValueError:
        # The script emitted a non-numeric report (or nothing); treat as 0.
        bleu_score = 0.0

    end = time.time()
    total_time = end - start
    # Store BLEU in the class-accuracy slot so eval_stats reports it.
    A.add('class_correct', bleu_score)
    A.add('class_total', 1)
    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)
    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)
        stats = parse_comparison.run_main(
            data_type="mt",
            main_report_path_template=FLAGS.log_path + "/" +
            FLAGS.experiment_name + ".eval_set_0.report",
            main_data_path=FLAGS.source_eval_path)
        # To-do: include the following into log-formatter so it's reported in standard format.
        if tree_strs is not None:
            logger.Log(
                'F1 w/ GT: ' + str(stats['gt']) + '\n' +
                'F1 w/ LB: ' + str(stats['lb']) + '\n' +
                'F1 w/ RB: ' + str(stats['rb']) + '\n' +
                'Avg. tree depth: ' + str(stats['depth'])
                )

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
Example #6
0
def evaluate(model, eval_set, logger, metrics_logger, step, vocabulary=None):
    """Evaluate `model` on one eval set and log class/transition accuracy.

    Older variant: reads the global `FLAGS` instead of taking it as an
    argument, tracks accuracy counters manually (no Accumulator), and can
    dump per-example attention matrices to disk.

    NOTE(review): this snippet appears to be Python-2-era code (`izip`,
    `float(...)` division guards) -- confirm interpreter version before
    reuse.

    Args:
        model: model under evaluation; called once per batch.
        eval_set: (filename, dataset) pair; `dataset` yields eval batches.
        logger: object with a `Log(str)` method for the summary line.
        metrics_logger: object with `Log(name, value, step)` for metrics.
        step: current training step, used in log lines and file names.
        vocabulary: unused here; accepted for call-site compatibility.

    Returns:
        The evaluation class accuracy as a float (transition accuracy is
        logged but not returned).
    """
    filename, dataset = eval_set

    reporter = EvalReporter()

    # Evaluate
    class_correct = 0
    class_total = 0
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()

    transition_preds = []
    transition_targets = []

    for i, (eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids) in enumerate(dataset):
        # Optionally trim padding down to the longest example in the batch.
        if FLAGS.truncate_eval_batch:
            eval_X_batch, eval_transitions_batch = truncate(
                eval_X_batch, eval_transitions_batch, eval_num_transitions_batch)

        if FLAGS.saving_eval_attention_matrix:
            model.set_recording_attention_weight_matrix(True)

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,)

        if FLAGS.saving_eval_attention_matrix:
            # WARNING: only attention SPINN model have attention matrix
            # Append this batch's matrices (scaled to percentages) to a
            # per-step text file: id line, "rows,cols" line, then rows.
            attention_matrix = model.get_attention_matrix_from_last_forward()
            with open(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name, 'attention-matrix-{}.txt'.format(step)), 'a') as txtfile:
                for eval_id, attmat in izip(eval_ids, attention_matrix):
                    txtfile.write('{}\n'.format(eval_id))
                    txtfile.write('{},{}\n'.format(len(attmat), len(attmat[0])))
                    for row in attmat:
                        txtfile.write(','.join(['{:.1f}'.format(x*100.0) for x in row]))
                        txtfile.write('\n')
            model.set_recording_attention_weight_matrix(False) # reset it after run

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()
        pred = logits.data.max(1)[1].cpu() # get the index of the max log-probability
        class_correct += pred.eq(target).sum()
        class_total += target.size(0)

        # Optionally calculate transition loss/acc.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Update Aggregate Accuracies
        total_tokens += eval_num_transitions_batch.ravel().sum()

        # Accumulate stats for transition accuracy.
        if transition_loss is not None:
            transition_preds.append([m["t_preds"] for m in model.spinn.memories])
            transition_targets.append([m["t_given"] for m in model.spinn.memories])

        if FLAGS.write_eval_report:
            reporter_args = [pred, target, eval_ids, output.data.cpu().numpy()]
            if hasattr(model, 'transition_loss'):
                transition_preds_per_example = model.spinn.get_transition_preds_per_example()
                if model.use_sentence_pair:
                    # Sentence-pair batches stack the two sentences; split.
                    batch_size = pred.size(0)
                    sent1_preds = transition_preds_per_example[:batch_size]
                    sent2_preds = transition_preds_per_example[batch_size:]
                    reporter_args.append(sent1_preds)
                    reporter_args.append(sent2_preds)
                else:
                    reporter_args.append(transition_preds_per_example)
            reporter.save_batch(*reporter_args)

        # Print Progress
        progress_bar.step(i+1, total=total_batches)
    progress_bar.finish()

    end = time.time()
    total_time = end - start

    # Get time per token.
    time_metric = time_per_token([total_tokens], [total_time])

    # Get class accuracy.
    eval_class_acc = class_correct / float(class_total)

    # Get transition accuracy if applicable.
    if len(transition_preds) > 0:
        all_preds = np.array(flatten(transition_preds))
        all_truth = np.array(flatten(transition_targets))
        eval_trans_acc = (all_preds == all_truth).sum() / float(all_truth.shape[0])
    else:
        eval_trans_acc = 0.0

    logger.Log("Step: %i Eval acc: %f  %f %s Time: %5f" %
              (step, eval_class_acc, eval_trans_acc, filename, time_metric))

    metrics_logger.Log('eval_class_acc', eval_class_acc, step)
    metrics_logger.Log('eval_trans_acc', eval_trans_acc, step)

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".report")
        reporter.write_report(eval_report_path)

    return eval_class_acc