def test_plot_intermediate_values() -> None:
    """Check the lines drawn by ``plot_intermediate_values`` for several studies.

    Covered cases: a study without trials, a trial that reports intermediate
    values, a mix of reporting and non-reporting trials, a single
    non-reporting trial, and a failed trial. After each case the current
    figure is rendered into an in-memory buffer.
    """

    def objective(trial: Trial, report_intermediate_values: bool) -> float:
        # Optionally report 1.0 at step 0 and 2.0 at step 1, then return 0.0.
        if report_intermediate_values:
            for step, value in enumerate((1.0, 2.0)):
                trial.report(value, step=step)
        return 0.0

    def reporting_objective(trial: Trial) -> float:
        return objective(trial, True)

    def silent_objective(trial: Trial) -> float:
        return objective(trial, False)

    def fail_objective(_: Trial) -> float:
        raise ValueError

    def expect_no_lines(axes) -> None:
        # Nothing should have been drawn; the figure must still be savable.
        assert len(axes.get_lines()) == 0
        plt.savefig(BytesIO())

    def expect_single_reported_line(axes) -> None:
        # Exactly one line, carrying the two reported (step, value) pairs.
        drawn = axes.get_lines()
        assert len(drawn) == 1
        only_line = drawn[0]
        assert list(only_line.get_xdata()) == [0, 1]
        assert list(only_line.get_ydata()) == [1.0, 2.0]
        plt.savefig(BytesIO())

    # Case 1: no trials at all.
    empty_study = prepare_study_with_trials(no_trials=True)
    expect_no_lines(plot_intermediate_values(empty_study))

    # Case 2: a single trial that reports intermediate values.
    reporting_study = create_study()
    reporting_study.optimize(reporting_objective, n_trials=1)
    expect_single_reported_line(plot_intermediate_values(reporting_study))

    # Case 3: add a trial without intermediate values to the same study.
    # The non-reporting trial is expected to be ignored.
    reporting_study.optimize(silent_objective, n_trials=1)
    assert len(reporting_study.trials) == 2
    expect_single_reported_line(plot_intermediate_values(reporting_study))

    # Case 4: only one trial, and it has no intermediate values.
    silent_study = create_study()
    silent_study.optimize(silent_objective, n_trials=1)
    expect_no_lines(plot_intermediate_values(silent_study))

    # Case 5: failed trials are ignored.
    failing_study = create_study()
    failing_study.optimize(fail_objective, n_trials=1, catch=(ValueError, ))
    expect_no_lines(plot_intermediate_values(failing_study))
Esempio n. 2
0
def test_plot_intermediate_values() -> None:
    """Check whether ``plot_intermediate_values`` yields a figure with data.

    Covered cases: a study without trials, a trial that reports intermediate
    values, a mix of reporting and non-reporting trials, a single
    non-reporting trial, and a failed trial.
    """

    def objective(trial: Trial, report_intermediate_values: bool) -> float:
        # Optionally report 1.0 at step 0 and 2.0 at step 1, then return 0.0.
        if report_intermediate_values:
            for step, value in enumerate((1.0, 2.0)):
                trial.report(value, step=step)
        return 0.0

    def reporting_objective(trial: Trial) -> float:
        return objective(trial, True)

    def silent_objective(trial: Trial) -> float:
        return objective(trial, False)

    def fail_objective(_: Trial) -> float:
        raise ValueError

    # Case 1: no trials at all.
    empty_study = prepare_study_with_trials(no_trials=True)
    empty_figure = plot_intermediate_values(empty_study)
    assert not empty_figure.has_data()

    # Case 2: a single trial that reports intermediate values.
    # TODO(ytknzw): Add more specific assertion with the test case.
    reporting_study = create_study()
    reporting_study.optimize(reporting_objective, n_trials=1)
    reporting_figure = plot_intermediate_values(reporting_study)
    assert reporting_figure.has_data()

    # Case 3: add a trial without intermediate values to the same study.
    # The non-reporting trial is expected to be ignored.
    # TODO(ytknzw): Add more specific assertion with the test case.
    reporting_study.optimize(silent_objective, n_trials=1)
    assert len(reporting_study.trials) == 2
    mixed_figure = plot_intermediate_values(reporting_study)
    assert mixed_figure.has_data()

    # Case 4: only one trial, and it has no intermediate values.
    silent_study = create_study()
    silent_study.optimize(silent_objective, n_trials=1)
    silent_figure = plot_intermediate_values(silent_study)
    assert not silent_figure.has_data()

    # Case 5: failed trials are ignored.
    failing_study = create_study()
    failing_study.optimize(fail_objective, n_trials=1, catch=(ValueError,))
    failed_figure = plot_intermediate_values(failing_study)
    assert not failed_figure.has_data()
Esempio n. 3
0
def _log_plots(run,
               study: optuna.Study,
               visualization_backend='plotly',
               log_plot_contour=True,
               log_plot_edf=True,
               log_plot_parallel_coordinate=True,
               log_plot_param_importances=True,
               log_plot_pareto_front=True,
               log_plot_slice=True,
               log_plot_intermediate_values=True,
               log_plot_optimization_history=True,
               ):
    """Render the enabled Optuna plots for ``study`` and store them on ``run``.

    Each plot is converted with ``neptune.types.File.as_html`` and assigned to
    ``run['visualizations/<plot name>']``.

    Args:
        run: Object supporting item assignment; receives the HTML files.
        study: Study whose trials are visualized.
        visualization_backend: ``'plotly'`` or ``'matplotlib'``; selects which
            Optuna visualization module is imported.
        log_plot_contour: Log the contour plot (only if some trial has params).
        log_plot_edf: Log the EDF plot.
        log_plot_parallel_coordinate: Log the parallel-coordinate plot.
        log_plot_param_importances: Log parameter importances (only if more
            than one trial is COMPLETE or PRUNED).
        log_plot_pareto_front: Log the Pareto front (only for multi-objective
            studies with the plotly backend).
        log_plot_slice: Log the slice plot (only if some trial has params).
        log_plot_intermediate_values: Log intermediate values (only if some
            trial reported any).
        log_plot_optimization_history: Log the optimization history.

    Raises:
        NotImplementedError: If ``visualization_backend`` is not one of the
            two supported names.
    """
    if visualization_backend == 'matplotlib':
        import optuna.visualization.matplotlib as vis
    elif visualization_backend == 'plotly':
        import optuna.visualization as vis
    else:
        raise NotImplementedError(f'{visualization_backend} visualisation backend is not implemented')

    # ``is_available`` is a function in Optuna's visualization modules, so it
    # has to be called; the bare attribute is always truthy and would never
    # skip plotting when the backend library is missing.
    if not vis.is_available():
        return

    def _log_figure(plot_name, figure):
        # Store one rendered plot under the shared 'visualizations/' namespace.
        run[f'visualizations/{plot_name}'] = neptune.types.File.as_html(figure)

    # Read the trials once and reuse them for every precondition below.
    trials = study.trials
    has_params = any(trial.params for trial in trials)

    if log_plot_contour and has_params:
        _log_figure('plot_contour', vis.plot_contour(study))

    if log_plot_edf:
        _log_figure('plot_edf', vis.plot_edf(study))

    if log_plot_parallel_coordinate:
        _log_figure('plot_parallel_coordinate', vis.plot_parallel_coordinate(study))

    if log_plot_param_importances:
        finished_states = (
            optuna.trial.TrialState.COMPLETE,
            optuna.trial.TrialState.PRUNED,
        )
        if len(study.get_trials(states=finished_states)) > 1:
            try:
                _log_figure('plot_param_importances', vis.plot_param_importances(study))
            except (RuntimeError, ValueError, ZeroDivisionError):
                # Unable to compute importances; deliberately best-effort.
                pass

    if log_plot_pareto_front and study._is_multi_objective() and visualization_backend == 'plotly':
        _log_figure('plot_pareto_front', vis.plot_pareto_front(study))

    if log_plot_slice and has_params:
        _log_figure('plot_slice', vis.plot_slice(study))

    # Only plot intermediate values when at least one trial reported some.
    if log_plot_intermediate_values and any(trial.intermediate_values for trial in trials):
        _log_figure('plot_intermediate_values', vis.plot_intermediate_values(study))

    if log_plot_optimization_history:
        _log_figure('plot_optimization_history', vis.plot_optimization_history(study))