Esempio n. 1
0
def test_confidence_thresholding_data_vs_acc_subset_per_class_vis_api(
        experiment_to_use
):
    """Check that the visualization API call writes figures as pdf and png.

    The plot is rendered once per file format into a temporary directory,
    and after each call the number of files with the matching extension is
    asserted.

    :param experiment_to_use: Object containing trained model and results to
        test visualization
    :return: None
    """
    probs = experiment_to_use.probabilities
    with TemporaryDirectory() as out_dir:
        for fmt in ('pdf', 'png'):
            # The same probabilities are passed twice, once per model name.
            visualize.confidence_thresholding_data_vs_acc_subset_per_class(
                [probs, probs],
                experiment_to_use.ground_truth,
                experiment_to_use.ground_truth_metadata,
                experiment_to_use.output_feature_name,
                top_n_classes=[3],
                labels_limit=0,
                subset='ground_truth',
                model_names=['Model1', 'Model2'],
                output_directory=out_dir,
                file_format=fmt
            )
            # Only count files of the format that was just requested.
            saved_figures = glob.glob(out_dir + '/*.{}'.format(fmt))
            # top_n_classes = 3, hence one figure per class is expected.
            assert 3 == len(saved_figures)
Esempio n. 2
0
def test_confidence_thresholding_data_vs_acc_subset_per_class_vis_api(
        csv_filename):
    """Ensure pdf and png figures can be saved via visualization API call.

    The figures are written into the experiment directory, which is removed
    when the test finishes, whether it passes or fails.

    :param csv_filename: csv fixture from tests.fixtures.filenames.csv_filename
    :return: None
    """
    experiment = Experiment(csv_filename)
    exp_dir = experiment.model.exp_dir_name
    # try/finally so a failed assertion or API error does not leave the
    # experiment directory behind on disk.
    try:
        probability = experiment.probability
        viz_outputs = ('pdf', 'png')
        for viz_output in viz_outputs:
            # Glob pattern for the format currently being tested.
            vis_output_pattern = exp_dir + '/*.{}'.format(viz_output)
            visualize.confidence_thresholding_data_vs_acc_subset_per_class(
                [probability, probability],
                experiment.ground_truth,
                experiment.ground_truth_metadata,
                experiment.output_feature_name,
                top_n_classes=[3],
                labels_limit=0,
                subset='ground_truth',
                model_names=['Model1', 'Model2'],
                output_directory=exp_dir,
                file_format=viz_output)
            figures = glob.glob(vis_output_pattern)
            # 3 figures should be saved because experiment setting
            # top_n_classes = 3, hence one figure per class
            assert 3 == len(figures)
    finally:
        shutil.rmtree(exp_dir, ignore_errors=True)