Ejemplo n.º 1
0
def _get_pareto_front_3d(info: _ParetoFrontInfo) -> "Axes":
    """Draw the 3D Pareto-front scatter plot described by ``info``.

    Args:
        info: Plot description. This function reads ``target_names``,
            ``axis_order``, ``non_best_trials_with_values`` and
            ``best_trials_with_values`` from it. Every entry of the two
            ``*_trials_with_values`` collections is unpacked as a pair whose
            second element is indexed with the entries of ``axis_order``.

    Returns:
        The 3D axes the points were drawn on.
    """
    # Use ggplot style sheet for similar outputs to plotly.
    plt.style.use("ggplot")
    axes = plt.figure().add_subplot(projection="3d")
    axes.set_title("Pareto-front Plot")
    # Use tab10 colormap for similar outputs to plotly.
    palette = plt.get_cmap("tab10")

    # Label the x, y and z axes with the target selected for each position.
    label_setters = (axes.set_xlabel, axes.set_ylabel, axes.set_zlabel)
    for position, set_label in enumerate(label_setters):
        set_label(info.target_names[info.axis_order[position]])

    def coordinates(pairs, position):
        # One value per trial, taken from the objective mapped to this axis.
        return [values[info.axis_order[position]] for _, values in pairs]

    # (trials with values, colormap index, legend label), in drawing order.
    series = (
        (info.non_best_trials_with_values, 0, "Trial"),
        (info.best_trials_with_values, 3, "Best Trial"),
    )
    for pairs, color_index, label in series:
        if pairs is None or len(pairs) == 0:
            continue
        axes.scatter(
            xs=coordinates(pairs, 0),
            ys=coordinates(pairs, 1),
            zs=coordinates(pairs, 2),
            color=palette(color_index),
            label=label,
        )

    # A legend is drawn only when the non-best series was supplied at all
    # (even if empty) and something has actually been plotted.
    if info.non_best_trials_with_values is not None and axes.has_data():
        axes.legend()

    return axes
Ejemplo n.º 2
0
def _get_pareto_front_3d(
    study: Study,
    target_names: Optional[List[str]],
    include_dominated_trials: bool = False,
    axis_order: Optional[List[int]] = None,
) -> "Axes":
    """Draw a 3D scatter plot of the Pareto front of ``study``.

    Args:
        study: Study whose ``best_trials`` are plotted as "Best Trial".
        target_names: Three axis labels, indexed through ``axis_order``.
            ``None`` selects "Objective 0", "Objective 1", "Objective 2".
        include_dominated_trials: When true, the trials returned by
            ``_get_non_pareto_front_trials`` are plotted as "Trial" and a
            legend is added if the axes contain data.
        axis_order: Permutation of ``[0, 1, 2]`` choosing which objective
            goes on the x, y and z axis. ``None`` selects ``[0, 1, 2]``.

    Returns:
        The 3D axes the points were drawn on.

    Raises:
        ValueError: If ``target_names`` does not have exactly three entries,
            or ``axis_order`` is not three unique indices within 0..2.
    """
    # Validate the arguments before touching pyplot, so that a rejected call
    # neither changes the global style nor leaves an orphaned figure open.
    if target_names is None:
        target_names = ["Objective 0", "Objective 1", "Objective 2"]
    elif len(target_names) != 3:
        raise ValueError("The length of `target_names` is supposed to be 3.")

    if axis_order is None:
        axis_order = list(range(3))
    else:
        if len(axis_order) != 3:
            raise ValueError(
                f"Size of `axis_order` {axis_order}. Expect: 3, Actual: {len(axis_order)}."
            )
        if len(set(axis_order)) != 3:
            raise ValueError(
                f"Elements of given `axis_order` {axis_order} are not unique!."
            )
        if max(axis_order) > 2:
            raise ValueError(
                f"Given `axis_order` {axis_order} contains invalid index {max(axis_order)} "
                "higher than 2.")
        if min(axis_order) < 0:
            raise ValueError(
                f"Given `axis_order` {axis_order} contains invalid index {min(axis_order)} "
                "lower than 0.")

    # Read the Pareto front exactly once: every access to `study.best_trials`
    # may recompute it, and the list is deliberately never mutated here so a
    # shared/cached result cannot be corrupted.
    best_trials = study.best_trials
    if len(best_trials) == 0:
        _logger.warning("Your study does not have any completed trials.")

    # Keep dominated trials in their own list instead of appending them to
    # `best_trials` and slicing the combined list apart again.
    if include_dominated_trials:
        dominated_trials = _get_non_pareto_front_trials(study, best_trials)
    else:
        dominated_trials = []

    # Set up the graph style.
    plt.style.use(
        "ggplot")  # Use ggplot style sheet for similar outputs to plotly.
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.set_title("Pareto-front Plot")
    cmap = plt.get_cmap(
        "tab10")  # Use tab10 colormap for similar outputs to plotly.

    x_index, y_index, z_index = axis_order  # Length 3 is guaranteed above.
    ax.set_xlabel(target_names[x_index])
    ax.set_ylabel(target_names[y_index])
    ax.set_zlabel(target_names[z_index])

    if len(dominated_trials) > 0:
        ax.scatter(
            xs=[t.values[x_index] for t in dominated_trials],
            ys=[t.values[y_index] for t in dominated_trials],
            zs=[t.values[z_index] for t in dominated_trials],
            color=cmap(0),
            label="Trial",
        )

    if len(best_trials) > 0:
        ax.scatter(
            xs=[t.values[x_index] for t in best_trials],
            ys=[t.values[y_index] for t in best_trials],
            zs=[t.values[z_index] for t in best_trials],
            color=cmap(3),
            label="Best Trial",
        )

    if include_dominated_trials and ax.has_data():
        ax.legend()

    return ax