def evaluate(FLAGS, model, data_manager, eval_set, log_entry,
             logger, step, vocabulary=None, show_sample=False, eval_index=0):
    filename, dataset = eval_set

    A = Accumulator()
    index = len(log_entry.evaluation)
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
        # Anneal the parser's temperature: the per-10k-steps decay factor is raised to
        # step / 10000, and an optional cosine cycle rescales the multiplier between
        # roughly 0 and 1 (with a small floor so it stays positive).
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (
            step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (
                math.cos(step / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()
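    # Evaluation mode: dropout and batch-norm layers stay fixed for the whole loop below.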
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        can_sample = FLAGS.model_type in ["ChoiPyramid"] or (FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)  # TODO: Restore support in Pyramid if using.
        if show_sample and can_sample:
            tmp_samples = model.get_samples(eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            show_sample = False  # Only show one sample, regardless of the number of batches.


        # Normalize output.
        logits = F.log_softmax(output, dim=1)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Optionally calculate transition loss/acc.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Update Aggregate Accuracies
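        # Each example with n tokens has 2n - 1 shift/reduce transitions, so the token
        # count is recovered as (num_transitions + 1) / 2.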
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser:
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            if model.use_sentence_pair:
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[batch_size:] if transitions_per_example is not None else None

                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None

                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids, output.data.cpu().numpy(), sent1_transitions, sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
Example 2
def evaluate(FLAGS,
             model,
             eval_set,
             log_entry,
             logger,
             trainer,
             vocabulary=None,
             show_sample=False,
             eval_index=0):
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch,
                       eval_transitions_batch,
                       eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        can_sample = (FLAGS.model_type == "RLSPINN"
                      and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2
                             for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser:
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            if model.use_sentence_pair:
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[batch_size:] if transitions_per_example is not None else None

                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None

                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids,
                                output.data.cpu().numpy(), sent1_transitions,
                                sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
Example 3
def evaluate(FLAGS,
             model,
             eval_set,
             log_entry,
             logger,
             trainer,
             vocabulary=None,
             show_sample=False,
             eval_index=0,
             target_vocabulary=None):
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()
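    # For MT evaluation, dump the reference and predicted translations to disk so that
    # BLEU can be computed with the external multi-bleu.perl script below.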
    ref_file_name = FLAGS.log_path + "/ref_file"
    pred_file_name = FLAGS.log_path + "/pred_file"
    reference_file = open(ref_file_name, "w")
    predict_file = open(pred_file_name, "w")
    full_ref = []
    full_pred = []
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch,
                       eval_transitions_batch,
                       eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       example_lengths=eval_num_transitions_batch)

        can_sample = (FLAGS.model_type == "RLSPINN"
                      and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.encoder.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)

        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Get reference translation
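        # Each reference is the sequence of target token ids joined with spaces; the last
        # id (assumed to be a padding/end marker) is dropped and " ." is appended.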
        ref_out = [" ".join(map(str, k[:-1])) + " ." for k in eval_y_batch]
        full_ref += ref_out

        # Get predicted translation. `output` is iterated per decoding step, and each
        # step holds one predicted token id per example in the batch; a token id of 1
        # (assumed to be the end-of-sequence marker) closes that example.
        predicted = [[] for _ in range(len(eval_y_batch))]
        done = []
        for x in output:
            for index, x_0 in enumerate(x):
                val = int(x_0)
                if val == 1:
                    if index in done:
                        continue
                    done.append(index)
                elif index not in done:
                    predicted[index].append(val)
        pred_out = [" ".join(map(str, k)) + " ." for k in predicted]
        full_pred += pred_out

        eval_accumulate(model, A, batch)

        # Optionally calculate transition loss/acc.
        transition_loss = model.encoder.transition_loss if hasattr(
            model.encoder, 'transition_loss') else None

        # Update Aggregate Accuracies
        total_tokens += sum(
            [(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser:
                transitions_per_example, _ = model.encoder.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            sent1_transitions = transitions_per_example if transitions_per_example is not None else None
            sent2_transitions = None

            sent1_trees = tree_strs if tree_strs is not None else None
            sent2_trees = None
            reporter.save_batch(full_pred,
                                full_ref,
                                eval_ids, [None],
                                sent1_transitions,
                                sent2_transitions,
                                sent1_trees,
                                sent2_trees,
                                mt=True)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    reference_file.write("\n".join(full_ref))
    reference_file.close()
    predict_file.write("\n".join(full_pred))
    predict_file.close()

    # Score the predictions against the references with the Moses multi-bleu script.
    bleu_score = os.popen("perl spinn/util/multi-bleu.perl " + ref_file_name +
                          " < " + pred_file_name).read()
    try:
        bleu_score = float(bleu_score)
    except ValueError:
        # If the script's output cannot be parsed as a number, fall back to 0.0.
        bleu_score = 0.0

    end = time.time()
    total_time = end - start
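    # For MT there is no class accuracy; store BLEU as `class_correct` with a
    # `class_total` of 1 so that the class-accuracy slot effectively carries the BLEU score.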
    A.add('class_correct', bleu_score)
    A.add('class_total', 1)
    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)
    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)
        stats = parse_comparison.run_main(
            data_type="mt",
            main_report_path_template=FLAGS.log_path + "/" +
            FLAGS.experiment_name + ".eval_set_0.report",
            main_data_path=FLAGS.source_eval_path)
        # TODO: Include the following in the log-formatter so it's reported in the standard format.
        if tree_strs is not None:
            logger.Log(
                'F1 w/ GT: ' + str(stats['gt']) + '\n' +
                'F1 w/ LB: ' + str(stats['lb']) + '\n' +
                'F1 w/ RB: ' + str(stats['rb']) + '\n' +
                'Avg. tree depth: ' + str(stats['depth']))

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
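
A minimal sketch of how one of these evaluate() variants might be invoked from a training loop. The names dev_batches, vocab, and log_entry are illustrative assumptions; in the surrounding SPINN training scripts the FLAGS, logger, trainer, and logging protobufs are constructed elsewhere.

# Hypothetical call site for the second/third variants above (the signature without
# data_manager/step). `dev_batches` is assumed to be an iterable of batches in the
# format expected by get_batch(); evaluate() unpacks eval_set into (filename, dataset).
eval_set = ("dev_set.tsv", dev_batches)
class_acc, trans_acc = evaluate(FLAGS, model, eval_set, log_entry, logger, trainer,
                                vocabulary=vocab, show_sample=True, eval_index=0)
logger.Log("Eval class acc: " + str(class_acc) + ", transition acc: " + str(trans_acc))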