Exemplo n.º 1
0
def test_compare_classifiers_multiclass_multimetric_vis_api(experiment_to_use):
    """Ensure pdf and png figures can be saved via visualization API call.

    For each output format, calls
    ``visualize.compare_classifiers_multiclass_multimetric`` comparing the
    same test statistics under two model names, writing into a temporary
    directory, and checks how many figures of that format were produced.

    :param experiment_to_use: Object containing trained model and results to
        test visualization
    :return: None
    """
    experiment = experiment_to_use
    # extract test stats only
    test_stats = experiment.test_stats_full
    model_names = ['Model1', 'Model2']
    # NOTE(review): assumes the API writes 4 figures per model for a single
    # top_n_classes entry, and that distinct model names give each model its
    # own file names (so they do not overwrite each other) -- confirm against
    # visualize.compare_classifiers_multiclass_multimetric.
    figures_per_model = 4
    viz_outputs = ('pdf', 'png')
    with TemporaryDirectory() as tmpvizdir:
        for viz_output in viz_outputs:
            # Glob only the current format: figures from the previous
            # iteration share tmpvizdir and must not be counted.
            vis_output_pattern = tmpvizdir + '/*.{}'.format(viz_output)
            visualize.compare_classifiers_multiclass_multimetric(
                [test_stats, test_stats],
                experiment.ground_truth_metadata,
                experiment.output_feature_name,
                top_n_classes=[6],
                # Must be spelled ``model_names``: a misspelled keyword is not
                # rejected by this call (presumably absorbed by **kwargs), so
                # the names would silently go unused.
                model_names=model_names,
                output_directory=tmpvizdir,
                file_format=viz_output
            )
            figure_files = glob.glob(vis_output_pattern)
            assert figures_per_model * len(model_names) == len(figure_files)
Exemplo n.º 2
0
def test_compare_classifiers_multiclass_multimetric_vis_api(csv_filename):
    """Ensure pdf and png figures can be saved via visualisation API call.

    Trains an ``Experiment`` on the csv fixture. For each output format, it
    then calls ``visualize.compare_classifiers_multiclass_multimetric``,
    comparing the same test statistics under two model names, and counts the
    figures of that format written to the experiment directory. The
    experiment directory is always removed afterwards, even if an assertion
    fails.

    :param csv_filename: csv fixture from tests.fixtures.filenames.csv_filename
    :return: None
    """
    experiment = Experiment(csv_filename)
    output_directory = experiment.model.exp_dir_name
    try:
        # presumably test_stats_full is a sequence whose element 1 holds the
        # test statistics -- confirm against Experiment.
        test_stats = experiment.test_stats_full[1]
        model_names = ['Model1', 'Model2']
        # NOTE(review): assumes the API writes 4 figures per model for a
        # single top_n_classes entry, and that distinct model names keep each
        # model's files separate -- confirm against
        # visualize.compare_classifiers_multiclass_multimetric.
        figures_per_model = 4
        viz_outputs = ('pdf', 'png')
        for viz_output in viz_outputs:
            # Glob only the current format so the previous iteration's
            # figures in the same directory are not counted.
            vis_output_pattern = output_directory + '/*.{}'.format(viz_output)
            visualize.compare_classifiers_multiclass_multimetric(
                [test_stats, test_stats],
                experiment.ground_truth_metadata,
                experiment.output_feature_name,
                top_n_classes=[6],
                # Must be spelled ``model_names``: a misspelled keyword is not
                # rejected by this call (presumably absorbed by **kwargs), so
                # the names would silently go unused.
                model_names=model_names,
                output_directory=output_directory,
                file_format=viz_output
            )
            figure_files = glob.glob(vis_output_pattern)
            assert figures_per_model * len(model_names) == len(figure_files)
    finally:
        # Clean up even when an assertion above fails, so failed runs do not
        # leave experiment directories behind.
        shutil.rmtree(output_directory, ignore_errors=True)