コード例 #1
0
ファイル: test_utilities.py プロジェクト: nimmen/skll
def test_compute_eval_from_predictions():
    """
    Test compute_eval_from_predictions function console script.

    Runs ``cefp.main`` on a fixed predictions/examples pair, parses the
    tab-separated score rows it writes to stdout, and checks the Pearson
    correlation and unweighted kappa values.
    """

    pred_path = join(_my_dir, 'other',
                     'test_compute_eval_from_predictions.predictions')
    input_path = join(_my_dir, 'other',
                      'test_compute_eval_from_predictions.jsonlines')

    # we need to capture stdout since that's what main() writes to
    compute_eval_from_predictions_cmd = [input_path, pred_path, 'pearson',
                                         'unweighted_kappa']
    # Save the real streams before entering ``try`` so the ``finally``
    # clause can never refer to unbound names.
    old_stdout = sys.stdout
    old_stderr = sys.stderr
    try:
        sys.stdout = mystdout = StringIO()
        sys.stderr = mystderr = StringIO()
        cefp.main(compute_eval_from_predictions_cmd)
        score_rows = mystdout.getvalue().strip().split('\n')
        err = mystderr.getvalue()
    finally:
        sys.stdout = old_stdout
        sys.stderr = old_stderr
    # Echo the captured stderr only once the real stdout is back; printing
    # it inside the ``try`` would write it into the discarded StringIO.
    print(err)

    scores = {}
    for score_row in score_rows:
        # Each row has three tab-separated fields: score, metric name, and a
        # third field (apparently the predictions path) that is not needed.
        # Use ``_`` so ``pred_path`` above is not silently rebound.
        score, metric_name, _ = score_row.split('\t')
        scores[metric_name] = float(score)

    assert_almost_equal(scores['pearson'], 0.6197797868009122)
    assert_almost_equal(scores['unweighted_kappa'], 0.2)
コード例 #2
0
def test_warning_when_prediction_method_and_no_probabilities():
    """
    Test compute_eval_from_predictions logs a warning if a prediction method is provided
    but the predictions file doesn't contain probabilities.

    The last record captured by ``LogCapture`` must be the warning saying the
    ``highest`` prediction method is being ignored.
    """
    import logging

    # NOTE(review): nose's LogCapture is assumed to attach its handler to the
    # root logger (and may change its level). Remember the level here and
    # detach the handler at the end so the capture does not leak into
    # later tests.
    root_logger = logging.getLogger()
    old_level = root_logger.level

    lc = LogCapture()
    lc.begin()
    try:
        pred_path = join(_my_dir, 'other',
                         'test_compute_eval_from_predictions_predictions.tsv')
        input_path = join(_my_dir, 'other',
                          'test_compute_eval_from_predictions.jsonlines')

        # we need to capture stdout since that's what main() writes to
        compute_eval_from_predictions_cmd = [input_path, pred_path, 'pearson',
                                             'unweighted_kappa', '--method', 'highest']
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        try:
            sys.stdout = StringIO()
            sys.stderr = mystderr = StringIO()
            cefp.main(compute_eval_from_predictions_cmd)
            err = mystderr.getvalue()
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr
        # Print captured stderr only after the real stdout is restored.
        print(err)

        log_msg = ("skll.utilities.compute_eval_from_predictions: WARNING: A prediction "
                   "method was provided, but the predictions file doesn't contain "
                   "probabilities. Ignoring prediction method 'highest'.")

        # Fail with a clear message instead of an IndexError when nothing
        # was logged at all.
        assert lc.handler.buffer, "no log records were captured"
        eq_(lc.handler.buffer[-1], log_msg)
    finally:
        root_logger.removeHandler(lc.handler)
        root_logger.setLevel(old_level)
コード例 #3
0
def test_conflicting_prediction_and_example_ids():
    """
    Make sure compute_eval_from_predictions breaks with ValueError when the
    predictions and the examples do not have the same set of ids.

    The test fails if ``cefp.main`` completes without raising ValueError.
    """
    pred_path = join(_my_dir, 'other',
                     'test_compute_eval_from_predictions_probs_predictions.tsv')
    input_path = join(_my_dir, 'other',
                      'test_compute_eval_from_predictions_different_ids.jsonlines')

    compute_eval_from_predictions_cmd = [input_path, pred_path, 'pearson']
    # Assert the expected exception explicitly; otherwise a ValueError would
    # error the test and a missing one would pass silently.
    try:
        cefp.main(compute_eval_from_predictions_cmd)
    except ValueError:
        pass
    else:
        raise AssertionError("compute_eval_from_predictions did not raise "
                             "ValueError for mismatched prediction/example ids")
コード例 #4
0
def test_compute_eval_from_predictions_breaks_with_expval_and_nonnumeric_classes():
    """
    Make sure compute_eval_from_predictions breaks with ValueError when predictions are
    calculated via expected_value and the classes are non numeric.

    The test fails if ``cefp.main`` completes without raising ValueError.
    """

    pred_path = join(_my_dir, 'other',
                     'test_compute_eval_from_predictions_nonnumeric_classes_predictions.tsv')
    input_path = join(_my_dir, 'other',
                      'test_compute_eval_from_predictions_nonnumeric_classes.jsonlines')

    compute_eval_from_predictions_cmd = [input_path, pred_path, 'explained_variance',
                                         'r2', '--method', 'expected_value']
    # Assert the expected exception explicitly; otherwise a ValueError would
    # error the test and a missing one would pass silently.
    try:
        cefp.main(compute_eval_from_predictions_cmd)
    except ValueError:
        pass
    else:
        raise AssertionError("compute_eval_from_predictions did not raise "
                             "ValueError for expected_value with non-numeric "
                             "classes")
コード例 #5
0
def test_compute_eval_from_predictions_with_probs():
    """
    Test compute_eval_from_predictions function console script, with probabilities in
    the predictions file.

    Checks the default prediction method (pearson / unweighted_kappa) and the
    ``expected_value`` method (explained_variance / r2) on the same files.
    """

    def _run_and_collect_scores(cmd):
        """Run ``cefp.main(cmd)`` with captured output; return {metric: score}."""
        # Save the real streams before ``try`` so ``finally`` can always
        # restore them.
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        try:
            sys.stdout = mystdout = StringIO()
            sys.stderr = mystderr = StringIO()
            cefp.main(cmd)
            score_rows = mystdout.getvalue().strip().split('\n')
            err = mystderr.getvalue()
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr
        # Echo captured stderr only after the real stdout is restored.
        print(err)

        scores = {}
        for score_row in score_rows:
            # Rows are "<score>\t<metric>\t<third field>"; the third field is
            # discarded into ``_`` so it cannot rebind the outer ``pred_path``
            # that the second invocation below relies on.
            score, metric_name, _ = score_row.split('\t')
            scores[metric_name] = float(score)
        return scores

    pred_path = join(_my_dir, 'other',
                     'test_compute_eval_from_predictions_probs_predictions.tsv')
    input_path = join(_my_dir, 'other',
                      'test_compute_eval_from_predictions_probs.jsonlines')

    # we need to capture stdout since that's what main() writes to
    scores = _run_and_collect_scores([input_path, pred_path, 'pearson',
                                      'unweighted_kappa'])

    assert_almost_equal(scores['pearson'], 0.6197797868009122)
    assert_almost_equal(scores['unweighted_kappa'], 0.2)

    #
    # Test expected value predictions method
    #
    scores = _run_and_collect_scores([input_path, pred_path, 'explained_variance',
                                      'r2', '--method', 'expected_value'])

    assert_almost_equal(scores['r2'], 0.19999999999999996)
    assert_almost_equal(scores['explained_variance'], 0.23809523809523792)