def test_roc_curves_from_test_statistics_vis_api(csv_filename):
    """Ensure ROC-curve figures can be saved as pdf and png via the API.

    Trains a small model on generated data and collects its test
    statistics. It then plots those statistics twice (as two labelled
    "models") with ``visualize.roc_curves_from_test_statistics``, once per
    file format. For each format it checks that exactly one figure with
    that extension is written to the experiment directory.

    :param csv_filename: csv fixture from tests.fixtures.filenames.csv_filename
    :return: None
    """
    input_features = [binary_feature(), bag_feature()]
    output_features = [binary_feature()]
    encoder = 'parallel_cnn'

    # Generate test data
    data_csv = generate_data(input_features, output_features, csv_filename)
    output_feature_name = output_features[0]['name']
    # NOTE(review): the encoder is set after data generation and on the
    # binary input feature; confirm this feature type actually uses it.
    input_features[0]['encoder'] = encoder
    model = run_api_experiment(input_features, output_features)
    data_df = read_csv(data_csv)
    model.train(data_df=data_df)
    try:
        # Element 1 of model.test's result is used as the test statistics.
        test_stats = model.test(data_df=data_df)[1]
        viz_outputs = ('pdf', 'png')
        for viz_output in viz_outputs:
            vis_output_pattern = model.exp_dir_name + '/*.{}'.format(
                viz_output)
            visualize.roc_curves_from_test_statistics(
                [test_stats, test_stats],
                output_feature_name,
                # Correct keyword (was misspelled ``model_namess``).
                model_names=['Model1', 'Model2'],
                output_directory=model.exp_dir_name,
                file_format=viz_output)
            figure_cnt = glob.glob(vis_output_pattern)
            assert 1 == len(figure_cnt)
    finally:
        # Remove the experiment directory even when an assertion fails.
        shutil.rmtree(model.exp_dir_name, ignore_errors=True)
def test_roc_curves_from_test_statistics_vis_api(csv_filename):
    """Ensure ROC-curve figures can be saved as pdf and png via the API.

    Trains a small model on generated data and evaluates it with overall
    statistics collected. It then plots the same test statistics twice (as
    two labelled "models") with ``visualize.roc_curves_from_test_statistics``,
    once per file format. For each format it checks that exactly one figure
    with that extension is written to the training output directory.

    :param csv_filename: csv fixture from tests.fixtures.filenames.csv_filename
    :return: None
    """
    input_features = [binary_feature(), bag_feature()]
    output_features = [binary_feature()]

    # Generate test data
    data_csv = generate_data(input_features, output_features, csv_filename)
    output_feature_name = output_features[0]['name']

    model = run_api_experiment(input_features, output_features)
    data_df = read_csv(data_csv)
    _, _, output_dir = model.train(dataset=data_df)
    try:
        # extract test metrics
        test_stats, _, _ = model.evaluate(dataset=data_df,
                                          collect_overall_stats=True,
                                          output_directory=output_dir)
        viz_outputs = ('pdf', 'png')
        for viz_output in viz_outputs:
            vis_output_pattern = os.path.join(
                output_dir, '*.{}'.format(viz_output))
            visualize.roc_curves_from_test_statistics(
                [test_stats, test_stats],
                output_feature_name,
                model_names=['Model1', 'Model2'],
                output_directory=output_dir,
                file_format=viz_output
            )
            figure_cnt = glob.glob(vis_output_pattern)
            assert 1 == len(figure_cnt)
    finally:
        # Remove the training output directory even when an assertion fails.
        shutil.rmtree(output_dir, ignore_errors=True)
# Example #3
# 0
def test_roc_curves_from_test_statistics_vis_api(csv_filename):
    """Ensure ROC-curve figures can be saved as pdf and png via the API.

    All artifacts (the generated CSV and the training results) live in a
    temporary directory that is removed when the ``with`` block exits.
    The test trains a small model and evaluates it with overall statistics
    collected. It then plots the same test statistics twice (as two
    labelled "models") with ``visualize.roc_curves_from_test_statistics``,
    once per file format. For each format it checks that exactly one figure
    with that extension is written to the training output directory.

    :param csv_filename: csv fixture from tests.fixtures.filenames.csv_filename
    :return: None
    """
    input_features = [binary_feature(), bag_feature()]
    output_features = [binary_feature()]

    with TemporaryDirectory() as tmpvizdir:
        # Generate test data
        data_csv = generate_data(input_features, output_features,
                                 os.path.join(tmpvizdir, csv_filename))
        output_feature_name = output_features[0]["name"]

        model = run_api_experiment(input_features, output_features)
        data_df = read_csv(data_csv)
        _, _, output_dir = model.train(dataset=data_df,
                                       output_directory=os.path.join(
                                           tmpvizdir, "results"))
        # extract test metrics
        test_stats, _, _ = model.evaluate(dataset=data_df,
                                          collect_overall_stats=True,
                                          output_directory=output_dir)
        viz_outputs = ("pdf", "png")
        for viz_output in viz_outputs:
            # Pattern matches figures of the current format (pdf or png).
            vis_output_pattern = os.path.join(output_dir, f"*.{viz_output}")
            visualize.roc_curves_from_test_statistics(
                [test_stats, test_stats],
                output_feature_name,
                model_names=["Model1", "Model2"],
                output_directory=output_dir,
                file_format=viz_output,
            )
            figure_cnt = glob.glob(vis_output_pattern)
            assert 1 == len(figure_cnt), figure_cnt