コード例 #1
0
def _get_parallel_coordinate_plot(
    study: Study,
    params: Optional[List[str]] = None,
    target: Optional[Callable[[FrozenTrial], float]] = None,
    target_name: str = "Objective Value",
) -> "Axes":
    """Draw a parallel coordinate plot of the completed trials of ``study``.

    Every completed trial becomes one poly-line. The first vertical axis shows
    the target value; each following axis shows one parameter, rescaled onto
    the target range so that all axes share the same drawing coordinates.

    Args:
        study: Study whose trials in the ``COMPLETE`` state are plotted.
        params: Names of the parameters to show. When ``None``, every
            parameter found in the completed trials is shown.
        target: Function mapping a trial to the value used for the first axis
            and for the line colour. When ``None``, ``trial.value`` is used.
        target_name: Label of the first axis and of the colour bar.

    Returns:
        The axes the plot was drawn on. When the study has no completed
        trials, a warning is logged and the axes are returned without lines.

    Raises:
        ValueError: If a name in ``params`` is not a parameter of any
            completed trial.
    """

    if target is None:

        def _target(t: FrozenTrial) -> float:
            return cast(float, t.value)

        target = _target
        reversescale = study.direction == StudyDirection.MINIMIZE
    else:
        reversescale = True

    # Set up the graph style.
    fig, ax = plt.subplots()
    cmap = plt.get_cmap("Blues_r" if reversescale else "Blues")
    ax.set_title("Parallel Coordinate Plot")
    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(False)

    # Prepare data for plotting.
    trials = [
        trial for trial in study.trials if trial.state == TrialState.COMPLETE
    ]

    if not trials:
        _logger.warning("Your study does not have any completed trials.")
        return ax

    all_params = {p_name for t in trials for p_name in t.params.keys()}
    if params is not None:
        for input_p_name in params:
            if input_p_name not in all_params:
                raise ValueError(
                    "Parameter {} does not exist in your study.".format(
                        input_p_name))
        all_params = set(params)
    sorted_params = sorted(all_params)

    obj_org = [target(t) for t in trials]
    obj_min = min(obj_org)
    obj_max = max(obj_org)
    obj_w = obj_max - obj_min
    # One row per trial; each row collects the y coordinate on every axis,
    # starting with the target value itself.
    dims_obj_base = [[o] for o in obj_org]

    cat_param_names = []
    cat_param_values = []
    cat_param_ticks = []
    param_values = []
    var_names = [target_name]
    for p_name in sorted_params:
        # Trials that lack this parameter contribute NaN.
        values = [
            t.params[p_name] if p_name in t.params else np.nan for t in trials
        ]

        if _is_log_scale(trials, p_name):
            # Lines are positioned by the exponent; ``values`` keeps the raw
            # numbers because the twin axis below is switched to log scale.
            values_for_lc = [math.log10(v) for v in values]
        elif _is_categorical(trials, p_name):
            # Assign each distinct category the next free integer, in order
            # of first appearance.
            vocab = defaultdict(
                lambda: len(vocab))  # type: DefaultDict[str, int]
            values = [vocab[v] for v in values]
            cat_param_names.append(p_name)
            vocab_item_sorted = sorted(vocab.items(), key=lambda x: x[1])
            cat_param_values.append([v[0] for v in vocab_item_sorted])
            cat_param_ticks.append([v[1] for v in vocab_item_sorted])
            values_for_lc = values
        else:
            values_for_lc = values

        p_min = min(values_for_lc)
        p_max = max(values_for_lc)
        p_w = p_max - p_min

        if p_w == 0.0:
            # A parameter with a single distinct value has no range to
            # normalise by; draw it at the middle of the axis instead of
            # dividing by zero.
            center = obj_w / 2 + obj_min
            for dims in dims_obj_base:
                dims.append(center)
        else:
            # Map [p_min, p_max] linearly onto [obj_min, obj_max].
            for dims, v in zip(dims_obj_base, values_for_lc):
                dims.append((v - p_min) / p_w * obj_w + obj_min)

        var_names.append(
            p_name if len(p_name) < 20 else "{}...".format(p_name[:17]))
        param_values.append(values)

    # Draw multiple line plots and axes.
    # Ref: https://stackoverflow.com/a/50029441
    n_params = len(sorted_params)
    ax.set_xlim(0, n_params)
    ax.set_ylim(obj_min, obj_max)
    x_positions = range(n_params + 1)
    segments = [np.column_stack([x_positions, y]) for y in dims_obj_base]
    lc = LineCollection(segments, cmap=cmap)
    # Exactly one colour value per line: the target value of its trial.
    lc.set_array(np.asarray(obj_org))
    axcb = fig.colorbar(lc, pad=0.1)
    axcb.set_label(target_name)
    plt.xticks(x_positions, var_names, rotation=330)

    for i, p_name in enumerate(sorted_params):
        # Each parameter gets its own twin y-axis whose right spine is moved
        # to the parameter's x position.
        ax2 = ax.twinx()
        ax2.set_ylim(min(param_values[i]), max(param_values[i]))
        if _is_log_scale(trials, p_name):
            ax2.set_yscale("log")
        ax2.spines["top"].set_visible(False)
        ax2.spines["bottom"].set_visible(False)
        ax2.get_xaxis().set_visible(False)
        ax2.plot([1] * len(param_values[i]), param_values[i], visible=False)
        ax2.spines["right"].set_position(("axes", (i + 1) / n_params))
        if p_name in cat_param_names:
            # Label the integer ticks with the original category values.
            idx = cat_param_names.index(p_name)
            tick_pos = cat_param_ticks[idx]
            tick_labels = cat_param_values[idx]
            ax2.set_yticks(tick_pos)
            ax2.set_yticklabels(tick_labels)

    ax.add_collection(lc)

    return ax
コード例 #2
0
ファイル: _parallel_coordinate.py プロジェクト: optuna/optuna
def _get_parallel_coordinate_plot(
    study: Study,
    params: Optional[List[str]] = None,
    target: Optional[Callable[[FrozenTrial], float]] = None,
    target_name: str = "Objective Value",
) -> "Axes":
    """Render the completed trials of ``study`` as a parallel coordinate plot.

    Each plotted trial is one poly-line: its first point lies on the target
    axis and every further point on the axis of one parameter, rescaled onto
    the target range. Trials whose numbers are reported by
    ``_get_skipped_trial_numbers`` are left out.

    Args:
        study: Study whose trials in the ``COMPLETE`` state are considered.
        params: Parameter names to show; ``None`` selects every parameter
            found in the completed trials.
        target: Maps a trial to the value shown on the first axis and used
            for colouring. ``None`` selects ``trial.value``.
        target_name: Label of the first axis and of the colour bar.

    Returns:
        The axes that were drawn on. If no trial is left to plot, a warning
        is logged and the axes are returned without lines.

    Raises:
        ValueError: If a name in ``params`` is not a parameter of any
            completed trial.
    """

    if target is not None:
        reversescale = True
    else:

        def _target(t: FrozenTrial) -> float:
            return cast(float, t.value)

        target = _target
        reversescale = study.direction == StudyDirection.MINIMIZE

    # Figure, colour map and frame styling.
    fig, ax = plt.subplots()
    cmap = plt.get_cmap("Blues_r" if reversescale else "Blues")
    ax.set_title("Parallel Coordinate Plot")
    for side in ("top", "bottom"):
        ax.spines[side].set_visible(False)

    # Only completed trials take part in the plot.
    trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
    if not trials:
        _logger.warning("Your study does not have any completed trials.")
        return ax

    all_params = set()
    for t in trials:
        all_params.update(t.params.keys())
    if params is not None:
        for input_p_name in params:
            if input_p_name in all_params:
                continue
            raise ValueError(
                "Parameter {} does not exist in your study.".format(
                    input_p_name))
        all_params = set(params)
    sorted_params = sorted(all_params)

    skipped_trial_numbers = _get_skipped_trial_numbers(trials, sorted_params)
    plotted_trials = [
        t for t in trials if t.number not in skipped_trial_numbers
    ]

    obj_org = [target(t) for t in plotted_trials]
    if not obj_org:
        _logger.warning(
            "Your study has only completed trials with missing parameters.")
        return ax

    obj_min, obj_max = min(obj_org), max(obj_org)
    obj_w = obj_max - obj_min
    # One row per plotted trial, holding its y coordinate on every axis.
    dims_obj_base = [[objective] for objective in obj_org]

    cat_param_names = []
    cat_param_values = []
    cat_param_ticks = []
    param_values = []
    var_names = [target_name]
    numeric_cat_params_indices: List[int] = []

    for param_index, p_name in enumerate(sorted_params):
        values = [t.params[p_name] for t in plotted_trials]

        if _is_categorical(trials, p_name):
            # Each unseen category receives the next free integer code.
            vocab = defaultdict(
                lambda: len(vocab))  # type: DefaultDict[str, int]

            if _is_numerical(trials, p_name):
                # Touch the categories in ascending order first so that the
                # integer codes follow the numeric order.
                for category in sorted(values):
                    vocab[category]
                numeric_cat_params_indices.append(param_index)

            values = [vocab[v] for v in values]

            cat_param_names.append(p_name)
            ordered_vocab = sorted(vocab.items(), key=lambda item: item[1])
            cat_param_values.append([label for label, _ in ordered_vocab])
            cat_param_ticks.append([code for _, code in ordered_vocab])

        # Log-scaled parameters are positioned by their exponent.
        values_for_lc = ([np.log10(v) for v in values] if _is_log_scale(
            trials, p_name) else values)

        p_min = min(values_for_lc)
        p_max = max(values_for_lc)
        p_w = p_max - p_min

        if p_w == 0.0:
            # No spread: put every line at the middle of the axis.
            center = obj_w / 2 + obj_min
            scaled = [center for _ in values]
        else:
            scaled = [(v - p_min) / p_w * obj_w + obj_min
                      for v in values_for_lc]
        for row, y in zip(dims_obj_base, scaled):
            row.append(y)

        if len(p_name) < 20:
            var_names.append(p_name)
        else:
            var_names.append("{}...".format(p_name[:17]))
        param_values.append(values)

    if numeric_cat_params_indices:
        # np.lexsort treats the LAST key as the primary one, so the keys are
        # handed over in reverse order.
        sort_keys = [
            param_values[index] for index in numeric_cat_params_indices
        ]
        sort_keys.reverse()
        sorted_idx = np.lexsort(sort_keys)
        # Reorder every parameter's values by the sorted index of the numeric
        # categorical parameters.
        param_values = [list(np.array(v)[sorted_idx]) for v in param_values]

    # Draw multiple line plots and axes.
    # Ref: https://stackoverflow.com/a/50029441
    n_params = len(sorted_params)
    ax.set_xlim(0, n_params)
    ax.set_ylim(obj_min, obj_max)
    segments = [
        np.column_stack([range(n_params + 1), row]) for row in dims_obj_base
    ]
    lc = LineCollection(segments, cmap=cmap)
    lc.set_array(np.asarray(obj_org))
    axcb = fig.colorbar(lc, pad=0.1)
    axcb.set_label(target_name)
    plt.xticks(range(n_params + 1), var_names, rotation=330)

    for i, p_name in enumerate(sorted_params):
        axis_values = param_values[i]
        # A twin y-axis per parameter, with its right spine moved to the
        # parameter's x position.
        ax2 = ax.twinx()
        ax2.set_ylim(min(axis_values), max(axis_values))
        if _is_log_scale(trials, p_name):
            ax2.set_yscale("log")
        ax2.spines["top"].set_visible(False)
        ax2.spines["bottom"].set_visible(False)
        ax2.xaxis.set_visible(False)
        ax2.plot([1] * len(axis_values), axis_values, visible=False)
        ax2.spines["right"].set_position(("axes", (i + 1) / n_params))
        if p_name in cat_param_names:
            # Show the original category values at the integer tick marks.
            idx = cat_param_names.index(p_name)
            ax2.set_yticks(cat_param_ticks[idx])
            ax2.set_yticklabels(cat_param_values[idx])

    ax.add_collection(lc)

    return ax
コード例 #3
0
def _get_parallel_coordinate_plot(info: _ParallelCoordinateInfo) -> "Axes":
    """Draw a parallel coordinate plot from the precomputed ``info``.

    The objective values form the first vertical axis. The values of every
    entry in ``info.dims_params`` are rescaled linearly from that dimension's
    ``range`` onto the objective's ``range`` so that one poly-line per
    objective value can be drawn through all axes.

    Args:
        info: Plot description providing ``reverse_scale``, ``target_name``,
            ``dim_objective`` and ``dims_params``.

    Returns:
        The axes that were drawn on. When there are no parameter dimensions
        or no objective values, the styled axes are returned without lines.
    """

    # Figure, colour map and frame styling.
    fig, ax = plt.subplots()
    cmap = plt.get_cmap("Blues_r" if info.reverse_scale else "Blues")
    ax.set_title("Parallel Coordinate Plot")
    for side in ("top", "bottom"):
        ax.spines[side].set_visible(False)

    # Nothing to draw without parameters or without objective values.
    n_params = len(info.dims_params)
    if n_params == 0 or len(info.dim_objective.values) == 0:
        return ax

    objective = info.dim_objective
    obj_min = objective.range[0]
    obj_max = objective.range[1]
    obj_w = obj_max - obj_min

    # One row per objective value, holding its y coordinate on every axis.
    dims_obj_base = [[value] for value in objective.values]
    for dim in info.dims_params:
        lower = dim.range[0]
        upper = dim.range[1]
        span = upper - lower

        if span == 0.0:
            # No spread: put every line at the middle of the axis.
            midpoint = obj_w / 2 + obj_min
            scaled = [midpoint for _ in dim.values]
        else:
            scaled = [(v - lower) / span * obj_w + obj_min
                      for v in dim.values]
        for row_index, y in enumerate(scaled):
            dims_obj_base[row_index].append(y)

    # Draw multiple line plots and axes.
    # Ref: https://stackoverflow.com/a/50029441
    ax.set_xlim(0, n_params)
    ax.set_ylim(obj_min, obj_max)
    segments = [
        np.column_stack([range(n_params + 1), row]) for row in dims_obj_base
    ]
    lc = LineCollection(segments, cmap=cmap)
    lc.set_array(np.asarray(objective.values))
    axcb = fig.colorbar(lc, pad=0.1)
    axcb.set_label(info.target_name)
    var_names = [objective.label]
    var_names.extend(dim.label for dim in info.dims_params)
    plt.xticks(range(n_params + 1), var_names, rotation=330)

    for position, dim in enumerate(info.dims_params, start=1):
        # A twin y-axis per parameter, with its right spine moved to the
        # parameter's x position.
        ax2 = ax.twinx()
        if not dim.is_log:
            ax2.set_ylim(dim.range[0], dim.range[1])
        else:
            # The range is treated as base-10 exponents, so the limits are
            # converted back with 10 ** x before switching to log scale.
            ax2.set_ylim(np.power(10, dim.range[0]),
                         np.power(10, dim.range[1]))
            ax2.set_yscale("log")
        ax2.spines["top"].set_visible(False)
        ax2.spines["bottom"].set_visible(False)
        ax2.xaxis.set_visible(False)
        ax2.spines["right"].set_position(("axes", position / n_params))
        if dim.is_cat:
            ax2.set_yticks(dim.tickvals)
            ax2.set_yticklabels(dim.ticktext)

    ax.add_collection(lc)

    return ax