예제 #1
0
def make_plot(
        plot_lists,
        use_median=False,
        plot_width=None,
        plot_height=None,
        title=None,
    ):
    """
    Draw one stacked subplot per y-axis attribute and return it as HTML.

    plot_lists is a list of lists.
    Each outer list represents different y-axis attributes.
    Each inner list represents different experiments to run, within that y-axis
    attribute.
    Each plot is an AttrDict which should have the elements used below.

    Every plot with a truthy ``x_key`` contributes ``(plot_key, means)`` to an
    x-axis table. Inner lists whose first plot has a truthy ``x_key`` are not
    drawn. When the table is non-empty, the plot at position ``idx`` of each
    drawn inner list takes the ``means`` of the ``idx``-th table entry as its x
    values, and the first entry's ``plot_key`` labels every x-axis; otherwise
    x is the sample index and the label is ``'iteration'``.

    :param plot_lists: list of lists of plots, as described above.
    :param use_median: plot ``percentile50`` with a ``percentile25``..
        ``percentile75`` band instead of ``means`` with a +/- ``stds`` band.
    :param plot_width: figure width for the plotly layout.
    :param plot_height: figure height for the plotly layout.
    :param title: figure title.
    :return: plotly ``<div>`` HTML (plotly.js not embedded). If the first plot
        of the last drawn inner list has a ``footnote`` key, a ``<div>`` with
        one ``legend: footnote`` line per plot of that list is appended.
    """
    # (plot_key, means) of every plot flagged as an x-axis source, in order.
    x_axis = [
        (subplot['plot_key'], subplot['means'])
        for plot_list in plot_lists
        for subplot in plot_list
        if subplot['x_key']
    ]
    # Keep only the y-axis lists. An empty inner list is kept (drawn as an
    # empty subplot) instead of raising IndexError on plot_list[0].
    plot_lists = [
        list(plot_list)
        for plot_list in plot_lists
        if not (plot_list and plot_list[0]['x_key'])
    ]
    xlabel = x_axis[0][0] if x_axis else 'iteration'

    num_y_axes = len(plot_lists)
    fig = tools.make_subplots(rows=num_y_axes, cols=1, print_grid=False)
    fig['layout'].update(
        width=plot_width,
        height=plot_height,
        title=title,
    )

    for y_idx, plot_list in enumerate(plot_lists):
        # plotly is 1-indexed like matplotlib for subplots
        y_idx_plotly = y_idx + 1
        for idx, plt in enumerate(plot_list):
            color = core.color_defaults[idx % len(core.color_defaults)]
            if use_median:
                center = plt.percentile50
                y_upper = list(plt.percentile75)
                y_lower = list(plt.percentile25)
            else:
                center = plt.means
                y_upper = list(plt.means + plt.stds)
                y_lower = list(plt.means - plt.stds)
            y = list(center)
            if x_axis:
                # Matched by position: the idx-th x-axis source feeds this plot.
                x = list(x_axis[idx][1])
            else:
                x = list(range(len(center)))

            # Band trace: walk the upper edge forward and the lower edge
            # backward, fill the enclosed area, and hide the outline (alpha 0).
            errors = go.Scatter(
                x=x + x[::-1],
                y=y_upper + y_lower[::-1],
                fill='tozerox',
                fillcolor=core.hex_to_rgb(color, 0.2),
                line=go.scatter.Line(color=core.hex_to_rgb(color, 0)),
                showlegend=False,
                legendgroup=plt.legend,
                hoverinfo='none',
            )
            values = go.Scatter(
                x=x,
                y=y,
                name=plt.legend,
                legendgroup=plt.legend,
                line=dict(color=core.hex_to_rgb(color)),
                hoverlabel=dict(namelength=-1),
                hoverinfo='all',
            )
            fig.append_trace(values, y_idx_plotly, 1)
            fig.append_trace(errors, y_idx_plotly, 1)

            # Separate name so the figure-level ``title`` argument is not
            # overwritten by the per-axis title.
            axis_title = plt.plot_key
            if len(axis_title) > 30:
                # Break long keys at each '/' and bold the last component.
                title_parts = axis_title.split('/')
                axis_title = "<br />/".join(
                    title_parts[:-1]
                    + [r"<b>{}</b>".format(t) for t in title_parts[-1:]]
                )
            fig['layout']['yaxis{}'.format(y_idx_plotly)].update(
                title=axis_title,
            )
            fig['layout']['xaxis{}'.format(y_idx_plotly)].update(
                title=xlabel,
            )

    fig_div = po.plot(fig, output_type='div', include_plotlyjs=False)
    # Footnotes come from the last drawn y-axis list (as before); an empty or
    # missing list now yields no footnote instead of raising.
    footnote_plots = plot_lists[-1] if plot_lists else []
    if footnote_plots and "footnote" in footnote_plots[0]:
        footnote = "<br />".join([
            r"<span><b>%s</b></span>: <span>%s</span>" % (
                plt.legend, plt.footnote)
            for plt in footnote_plots
        ])
        return r"%s<div>%s</div>" % (fig_div, footnote)
    return fig_div
예제 #2
0
def make_plot(plot_list,
              use_median=False,
              plot_width=None,
              plot_height=None,
              title=None):
    """Overlay every plot in ``plot_list`` on one figure and return it as HTML.

    Each plot becomes a centre line plus a shaded band in the same colour and
    legend group: ``means`` +/- ``stds`` by default, or ``percentile50`` with a
    ``percentile25``..``percentile75`` band when ``use_median`` is set.

    Side effect: prints three bracketed, comma-terminated lists to stdout --
    the per-plot ``np.mean`` of the 25th, 50th and 75th percentile curves
    (empty ``[]`` lists unless ``use_median``).

    :param plot_list: plots (AttrDict-like) with ``legend`` and the fields
        above; an optional ``footnote`` key on the first plot enables notes.
    :param use_median: use percentiles instead of mean/std.
    :param plot_width: figure width for the plotly layout.
    :param plot_height: figure height for the plotly layout.
    :param title: figure title.
    :return: plotly ``<div>`` HTML (plotly.js not embedded), followed by a
        footnote ``<div>`` when the first plot has a ``footnote`` key.
    """
    data = []
    p25, p50, p75 = [], [], []
    for idx, plt in enumerate(plot_list):
        color = core.color_defaults[idx % len(core.color_defaults)]
        if use_median:
            p25.append(np.mean(plt.percentile25))
            p50.append(np.mean(plt.percentile50))
            p75.append(np.mean(plt.percentile75))
            x = list(range(len(plt.percentile50)))
            y = list(plt.percentile50)
            y_upper = list(plt.percentile75)
            y_lower = list(plt.percentile25)
        else:
            x = list(range(len(plt.means)))
            y = list(plt.means)
            y_upper = list(plt.means + plt.stds)
            y_lower = list(plt.means - plt.stds)

        # Band trace: upper edge forward, lower edge backward, filled; a plain
        # dict replaces the deprecated ``go.Line`` wrapper (same contents).
        data.append(
            go.Scatter(x=x + x[::-1],
                       y=y_upper + y_lower[::-1],
                       fill='tozerox',
                       fillcolor=core.hex_to_rgb(color, 0.2),
                       line=dict(color='transparent'),
                       showlegend=False,
                       legendgroup=plt.legend,
                       hoverinfo='none'))
        data.append(
            go.Scatter(
                x=x,
                y=y,
                name=plt.legend,
                legendgroup=plt.legend,
                line=dict(color=core.hex_to_rgb(color)),
            ))

    # Same text as before ("[a,b,]"), built with join instead of repeated +=.
    for summary in (p25, p50, p75):
        print('[' + ''.join(str(value) + ',' for value in summary) + ']')

    layout = go.Layout(
        legend=dict(
            x=1,
            y=1,
        ),
        width=plot_width,
        height=plot_height,
        title=title,
    )
    fig = go.Figure(data=data, layout=layout)
    fig_div = po.plot(fig, output_type='div', include_plotlyjs=False)
    # Guard against an empty plot_list (previously IndexError).
    if plot_list and "footnote" in plot_list[0]:
        footnote = "<br />".join([
            r"<span><b>%s</b></span>: <span>%s</span>" %
            (plt.legend, plt.footnote) for plt in plot_list
        ])
        return r"%s<div>%s</div>" % (fig_div, footnote)
    return fig_div
예제 #3
0
def make_plot(plot_list, title=None):
    """Overlay every plot in ``plot_list`` on one figure and return it as HTML.

    Each plot is drawn according to its ``display_mode``:

    * ``"mean_std"`` / ``"mean_se"``: ``means`` line plus a shaded band of
      +/- ``stds`` or +/- ``ses`` respectively.
    * ``"individual"``: one line per row of ``ys``; only the first row shows
      a legend entry.

    All traces of one plot share its colour and ``legend`` group, and use
    ``xs`` as x values.

    :param plot_list: plots (AttrDict-like) with ``xs``, ``legend``,
        ``display_mode`` and the fields required by that mode; an optional
        ``footnote`` key on the first plot enables notes.
    :param title: figure title.
    :return: plotly ``<div>`` HTML (plotly.js not embedded), followed by a
        footnote ``<div>`` when the first plot has a ``footnote`` key.
    :raises NotImplementedError: for an unknown ``display_mode``.
    """
    data = []
    for idx, plt in enumerate(plot_list):
        color = core.color_defaults[idx % len(core.color_defaults)]
        x = list(plt.xs)

        if plt.display_mode in ("mean_std", "mean_se"):
            spread = plt.stds if plt.display_mode == "mean_std" else plt.ses
            y = list(plt.means)
            y_upper = list(plt.means + spread)
            y_lower = list(plt.means - spread)
            # Band trace: upper edge forward, lower edge backward, filled; a
            # plain dict replaces the deprecated ``go.Line`` wrapper.
            data.append(
                go.Scatter(x=x + x[::-1],
                           y=y_upper + y_lower[::-1],
                           fill='tozerox',
                           fillcolor=core.hex_to_rgb(color, 0.2),
                           line=dict(color='transparent'),
                           showlegend=False,
                           legendgroup=plt.legend,
                           hoverinfo='none'))
            data.append(
                go.Scatter(
                    x=x,
                    y=y,
                    name=plt.legend,
                    legendgroup=plt.legend,
                    line=dict(color=core.hex_to_rgb(color)),
                ))
        elif plt.display_mode == "individual":
            # Distinct name so the outer plot index is not shadowed.
            for run_idx, y in enumerate(plt.ys):
                data.append(
                    go.Scatter(
                        x=x,
                        y=y,
                        name=plt.legend,
                        legendgroup=plt.legend,
                        line=dict(color=core.hex_to_rgb(color)),
                        showlegend=run_idx == 0,
                    ))
        else:
            raise NotImplementedError(
                "Unsupported display_mode: {!r}".format(plt.display_mode))

    layout = go.Layout(
        legend=dict(
            x=1,
            y=1,
        ),
        title=title,
    )
    fig = go.Figure(data=data, layout=layout)
    fig_div = po.plot(fig, output_type='div', include_plotlyjs=False)
    # Guard against an empty plot_list (previously IndexError).
    if plot_list and "footnote" in plot_list[0]:
        footnote = "<br />".join([
            r"<span><b>%s</b></span>: <span>%s</span>" %
            (plt.legend, plt.footnote) for plt in plot_list
        ])
        return r"%s<div>%s</div>" % (fig_div, footnote)
    return fig_div
예제 #4
0
def make_plot(plot_list,
              use_median=False,
              use_five_numbers=False,
              plot_width=None,
              plot_height=None,
              title=None,
              xlim=None,
              ylim=None):
    """Overlay every plot in ``plot_list`` on one figure and return it as HTML.

    Each plot becomes a centre line plus a shaded band in the same colour and
    legend group:

    * default: ``means`` +/- ``stds``;
    * ``use_median``: ``percentile50`` with a 25th..75th percentile band;
    * ``use_five_numbers`` (ignored when ``use_median`` is set): as the median
      mode, plus dotted lines for ``percentile0`` and ``percentile100``.

    If a plot has a ``custom_x`` attribute it replaces the default index x.

    Side effect: prints three bracketed, comma-terminated lists to stdout --
    the per-plot ``np.mean`` of the 25th, 50th and 75th percentile curves
    (empty ``[]`` lists in the default mode).

    :param plot_list: plots (AttrDict-like) with ``legend`` and the fields
        required by the chosen mode; an optional ``footnote`` key on the first
        plot enables notes.
    :param use_median: use percentiles instead of mean/std.
    :param use_five_numbers: additionally draw min/max percentile lines.
    :param plot_width: figure width for the plotly layout.
    :param plot_height: figure height for the plotly layout.
    :param title: figure title.
    :param xlim: x-axis ``range`` passed to plotly (``None`` = autorange).
    :param ylim: y-axis ``range`` passed to plotly (``None`` = autorange).
    :return: plotly ``<div>`` HTML (plotly.js not embedded), followed by a
        footnote ``<div>`` when the first plot has a ``footnote`` key.
    """
    data = []
    p25, p50, p75 = [], [], []
    for idx, plt in enumerate(plot_list):
        color = core.color_defaults[idx % len(core.color_defaults)]
        if use_median or use_five_numbers:
            p25.append(np.mean(plt.percentile25))
            p50.append(np.mean(plt.percentile50))
            p75.append(np.mean(plt.percentile75))
            x = list(range(len(plt.percentile50)))
            y = list(plt.percentile50)
            y_upper = list(plt.percentile75)
            y_lower = list(plt.percentile25)
            if use_median:
                y_extras = []
            else:
                # Min/max curves drawn as extra dotted lines.
                y_extras = [list(plt.percentile0), list(plt.percentile100)]
        else:
            x = list(range(len(plt.means)))
            y = list(plt.means)
            y_upper = list(plt.means + plt.stds)
            y_lower = list(plt.means - plt.stds)
            y_extras = []

        if hasattr(plt, "custom_x"):
            x = list(plt.custom_x)

        # Band trace: upper edge forward, lower edge backward, filled; plain
        # dicts replace the deprecated ``go.Line`` wrapper (same contents).
        data.append(
            go.Scatter(x=x + x[::-1],
                       y=y_upper + y_lower[::-1],
                       fill='tozerox',
                       fillcolor=core.hex_to_rgb(color, 0.2),
                       line=dict(color='hsva(0,0,0,0)'),
                       showlegend=False,
                       legendgroup=plt.legend,
                       hoverinfo='none'))
        data.append(
            go.Scatter(
                x=x,
                y=y,
                name=plt.legend,
                legendgroup=plt.legend,
                line=dict(color=core.hex_to_rgb(color)),
            ))

        for y_extra in y_extras:
            data.append(
                go.Scatter(
                    x=x,
                    y=y_extra,
                    showlegend=False,
                    legendgroup=plt.legend,
                    # dash choices: solid, dot, dash, longdash, dashdot,
                    # longdashdot
                    line=dict(color=core.hex_to_rgb(color), dash='dot'),
                ))

    # Same text as before ("[a,b,]"), built with join instead of repeated +=.
    for summary in (p25, p50, p75):
        print('[' + ''.join(str(num) + ',' for num in summary) + ']')

    layout = go.Layout(
        legend=dict(
            x=1,
            y=1,
        ),
        width=plot_width,
        height=plot_height,
        title=title,
        # Plain dicts replace the deprecated go.XAxis / go.YAxis wrappers.
        xaxis=dict(range=xlim),
        yaxis=dict(range=ylim),
    )
    fig = go.Figure(data=data, layout=layout)
    fig_div = po.plot(fig, output_type='div', include_plotlyjs=False)
    # Guard against an empty plot_list (previously IndexError).
    if plot_list and "footnote" in plot_list[0]:
        footnote = "<br />".join([
            r"<span><b>%s</b></span>: <span>%s</span>" %
            (plt.legend, plt.footnote) for plt in plot_list
        ])
        return r"%s<div>%s</div>" % (fig_div, footnote)
    return fig_div
예제 #5
0
def make_plot(plot_list, title=None):
    """Overlay every plot in ``plot_list`` on one figure and return it as HTML.

    Each plot is drawn according to its ``display_mode``:

    * ``"mean_std"`` / ``"mean_se"``: ``means`` line plus a shaded band of
      +/- ``stds`` or +/- ``ses`` respectively.
    * ``"individual"``: one line per row of ``ys``; only the first row shows
      a legend entry.

    All traces of one plot share its colour and ``legend`` group, and use
    ``xs`` as x values.

    :param plot_list: plots (AttrDict-like) with ``xs``, ``legend``,
        ``display_mode`` and the fields required by that mode; an optional
        ``footnote`` key on the first plot enables notes.
    :param title: figure title.
    :return: plotly ``<div>`` HTML (plotly.js not embedded), followed by a
        footnote ``<div>`` when the first plot has a ``footnote`` key.
    :raises NotImplementedError: for an unknown ``display_mode``.
    """
    data = []
    for idx, plt in enumerate(plot_list):
        color = core.color_defaults[idx % len(core.color_defaults)]
        x = list(plt.xs)

        if plt.display_mode in ("mean_std", "mean_se"):
            spread = plt.stds if plt.display_mode == "mean_std" else plt.ses
            y = list(plt.means)
            y_upper = list(plt.means + spread)
            y_lower = list(plt.means - spread)
            # Band trace: upper edge forward, lower edge backward, filled; a
            # plain dict replaces the deprecated ``go.Line`` wrapper.
            data.append(go.Scatter(
                x=x + x[::-1],
                y=y_upper + y_lower[::-1],
                fill='tozerox',
                fillcolor=core.hex_to_rgb(color, 0.2),
                line=dict(color='transparent'),
                showlegend=False,
                legendgroup=plt.legend,
                hoverinfo='none'
            ))
            data.append(go.Scatter(
                x=x,
                y=y,
                name=plt.legend,
                legendgroup=plt.legend,
                line=dict(color=core.hex_to_rgb(color)),
            ))
        elif plt.display_mode == "individual":
            # Distinct name so the outer plot index is not shadowed.
            for run_idx, y in enumerate(plt.ys):
                data.append(go.Scatter(
                    x=x,
                    y=y,
                    name=plt.legend,
                    legendgroup=plt.legend,
                    line=dict(color=core.hex_to_rgb(color)),
                    showlegend=run_idx == 0,
                ))
        else:
            raise NotImplementedError(
                "Unsupported display_mode: {!r}".format(plt.display_mode))

    layout = go.Layout(
        legend=dict(
            x=1,
            y=1,
        ),
        title=title,
    )
    fig = go.Figure(data=data, layout=layout)
    fig_div = po.plot(fig, output_type='div', include_plotlyjs=False)
    # Guard against an empty plot_list (previously IndexError).
    if plot_list and "footnote" in plot_list[0]:
        footnote = "<br />".join([
            r"<span><b>%s</b></span>: <span>%s</span>" % (
                plt.legend, plt.footnote)
            for plt in plot_list
        ])
        return r"%s<div>%s</div>" % (fig_div, footnote)
    return fig_div
예제 #6
0
def make_plot(
        plot_lists,
        use_median=False,
        plot_width=None,
        plot_height=None,
        title=None,
    ):
    """
    Overlay every plot from every inner list on a single subplot.

    plot_lists is a list of lists.
    Each outer list represents different y-axis attributes.
    Each inner list represents different experiments to run, within that y-axis
    attribute.
    Each plot is an AttrDict which should have the elements used below.

    Colours cycle over the flattened sequence of plots. Each plot becomes a
    line named by its ``plot_key`` plus a shaded band (+/- ``stds``, or the
    25th..75th percentile band with ``use_median``), grouped by ``legend``.
    The y-axis title is the ``plot_key`` of the last plot drawn.

    :param plot_lists: list of lists of plots, as described above.
    :param use_median: plot ``percentile50`` with a percentile band instead of
        ``means`` with a std band.
    :param plot_width: figure width for the plotly layout.
    :param plot_height: figure height for the plotly layout.
    :param title: figure title.
    :return: plotly ``<div>`` HTML (plotly.js not embedded). If the first plot
        of the last inner list has a ``footnote`` key, a ``<div>`` with one
        ``legend: footnote`` line per plot of that list is appended.
    """
    fig = tools.make_subplots(rows=1, cols=1, print_grid=False)
    fig['layout'].update(
        width=plot_width,
        height=plot_height,
        title=title,
    )
    # Everything goes into the single subplot (plotly subplots are 1-indexed).
    y_idx_plotly = 1

    all_plots = (plt for plot_list in plot_lists for plt in plot_list)
    for color_idx, plt in enumerate(all_plots):
        color = core.color_defaults[color_idx % len(core.color_defaults)]

        if use_median:
            x = list(range(len(plt.percentile50)))
            y = list(plt.percentile50)
            y_upper = list(plt.percentile75)
            y_lower = list(plt.percentile25)
        else:
            x = list(range(len(plt.means)))
            y = list(plt.means)
            y_upper = list(plt.means + plt.stds)
            y_lower = list(plt.means - plt.stds)

        # Band trace: upper edge forward, lower edge backward, filled; the
        # outline uses alpha 0. A plain dict replaces deprecated ``go.Line``.
        errors = go.Scatter(
            x=x + x[::-1],
            y=y_upper + y_lower[::-1],
            fill='tozerox',
            fillcolor=core.hex_to_rgb(color, 0.2),
            line=dict(color=core.hex_to_rgb(color, 0)),
            showlegend=False,
            legendgroup=plt.legend,
            hoverinfo='none'
        )
        values = go.Scatter(
            x=x,
            y=y,
            name=plt.plot_key,
            legendgroup=plt.legend,
            line=dict(color=core.hex_to_rgb(color)),
        )
        fig.append_trace(values, y_idx_plotly, 1)
        fig.append_trace(errors, y_idx_plotly, 1)
        fig['layout']['yaxis{}'.format(y_idx_plotly)].update(
            title=plt.plot_key,
        )

    fig_div = po.plot(fig, output_type='div', include_plotlyjs=False)
    # Footnotes come from the last inner list, as before; with no (or an
    # empty) last list there is no footnote instead of a NameError/IndexError.
    footnote_plots = plot_lists[-1] if plot_lists else []
    if footnote_plots and "footnote" in footnote_plots[0]:
        footnote = "<br />".join([
            r"<span><b>%s</b></span>: <span>%s</span>" % (
                plt.legend, plt.footnote)
            for plt in footnote_plots
        ])
        return r"%s<div>%s</div>" % (fig_div, footnote)
    return fig_div