예제 #1
0
def plot_to_json(obj):
    """Turn a matplotlib or plotly figure into a payload for the wandb server.

    A matplotlib object is first translated into a plotly figure with
    ``plotly.tools.mpl_to_plotly``; ``util.get_module`` is asked for
    ``plotly.tools`` with a ``required=`` install hint for users without
    plotly. A plotly figure is returned as ``{"_type": "plotly", "plot": ...}``
    where the plot is ``obj.to_plotly_json()`` passed through
    ``numpy_arrays_to_lists``. Anything else is handed back untouched.
    """
    if util.is_matplotlib_typename(util.get_full_typename(obj)):
        plotly_tools = util.get_module(
            "plotly.tools",
            required="plotly is required to log interactive plots, install with: pip install plotly or convert the plot to an image with `wandb.Image(plt)`"
        )
        obj = plotly_tools.mpl_to_plotly(obj)

    if not util.is_plotly_typename(util.get_full_typename(obj)):
        # Not a plot (any more): nothing to convert.
        return obj

    plotly_json = obj.to_plotly_json()
    return {"_type": "plotly", "plot": numpy_arrays_to_lists(plotly_json)}
예제 #2
0
def plot_to_json(obj):
    """Convert a matplotlib or plotly figure into a JSON-ready dict.

    A matplotlib object is first translated to plotly with
    ``plotly.tools.mpl_to_plotly``; ``util.get_module`` is given a
    ``required=`` install hint for users without plotly.

    Returns:
        ``{"_type": "plotly", "plot": <plotly JSON>}`` for a plotly figure,
        where every numpy array/scalar inside the plotly JSON has been
        replaced by a plain Python list/value; any other object unchanged.
    """
    if util.is_matplotlib_typename(util.get_full_typename(obj)):
        tools = util.get_module(
            "plotly.tools",
            required=
            "plotly is required to log interactive plots, install with: pip install plotly or convert the plot to an image with `wandb.Image(plt)`"
        )
        obj = tools.mpl_to_plotly(obj)

    if util.is_plotly_typename(util.get_full_typename(obj)):
        # ``to_plotly_json`` keeps the figure's data as-is, so numpy arrays
        # can survive in the payload; turn them into plain lists.
        plot = _plotly_arrays_to_lists(obj.to_plotly_json())
        return {"_type": "plotly", "plot": plot}
    else:
        return obj


def _plotly_arrays_to_lists(payload):
    """Return *payload* with numpy arrays/scalars replaced by Python data.

    Dicts are rebuilt key by key and lists/tuples element by element (tuples
    come back as lists, which is how JSON encodes them anyway). numpy values
    are recognised by their type's module so numpy need not be importable.
    """
    if isinstance(payload, dict):
        return {key: _plotly_arrays_to_lists(value)
                for key, value in payload.items()}
    if isinstance(payload, (list, tuple)):
        return [_plotly_arrays_to_lists(value) for value in payload]
    module = type(payload).__module__ or ""
    if module.partition(".")[0] == "numpy" and hasattr(payload, "tolist"):
        # ndarray.tolist() yields nested lists of Python scalars; numpy
        # scalar types return the matching Python scalar.
        return payload.tolist()
    return payload
예제 #3
0
def val_to_json(key, val, mode="summary", step=None):
    """Convert a wandb datatype to its JSON representation.

    Args:
        key: Name the value is logged under. Passed to the media
            ``transform`` helpers (and used in image file names).
        val: Value to convert. Matplotlib and plotly figures, single
            ``IterableMedia`` objects, non-empty sequences of one media type,
            ``Histogram``, ``Graph`` and ``Table`` values are converted;
            anything else is returned unchanged.
        mode: Only ``"history"`` is treated specially: ``Graph`` values are
            rejected there.
        step: Step used when naming media; ``None`` becomes ``"summary"``.

    Returns:
        The converted value, or *val* unchanged if no conversion applies.

    Raises:
        ValueError: if a sequence mixes media with non-media items or mixes
            different media types, or if a ``Graph`` is logged with
            ``mode="history"``.
    """
    # ``collections.Sequence`` was removed in Python 3.10; the old location
    # is only used on interpreters that lack ``collections.abc`` (Python 2).
    try:
        from collections.abc import Sequence
    except ImportError:
        from collections import Sequence

    converted = val
    typename = util.get_full_typename(val)
    if util.is_matplotlib_typename(typename):
        # This handles plots with images in it because plotly doesn't support it
        # TODO: should we handle a list of plots?
        val = util.ensure_matplotlib_figure(val)
        if any(len(ax.images) > 0 for ax in val.axes):
            PILImage = util.get_module(
                "PIL.Image",
                required=
                "Logging plots with images requires pil: pip install pillow")
            # Render the figure into an in-memory buffer and log it as an
            # Image instead of a plotly chart.
            buf = six.BytesIO()
            val.savefig(buf)
            val = Image(PILImage.open(buf))
        else:
            converted = util.convert_plots(val)
    elif util.is_plotly_typename(typename):
        converted = util.convert_plots(val)
    if isinstance(val, IterableMedia):
        # Single media objects go through the same list-based transforms.
        val = [val]

    if isinstance(val, Sequence) and len(val) > 0:
        is_media = [isinstance(v, IterableMedia) for v in val]
        if all(is_media):
            # The transform is picked from the first item, so every item must
            # be an instance of that same media class.
            media_cls = next(
                (cls for cls in (Image, Audio, Html, Object3D)
                 if isinstance(val[0], cls)), None)
            if media_cls is not None and not all(
                    isinstance(v, media_cls) for v in val):
                raise ValueError(
                    "Mixed media types in the same list aren't supported")
            cwd = wandb.run.dir if wandb.run else "."
            if step is None:
                step = "summary"
            if media_cls is Image:
                converted = Image.transform(val, cwd,
                                            "{}_{}.jpg".format(key, step))
            elif media_cls is Audio:
                converted = Audio.transform(val, cwd, key, step)
            elif media_cls is Html:
                converted = Html.transform(val, cwd, key, step)
            elif media_cls is Object3D:
                converted = Object3D.transform(val, cwd, key, step)
        elif any(is_media):
            raise ValueError(
                "Mixed media types in the same list aren't supported")
    elif isinstance(val, Histogram):
        converted = Histogram.transform(val)
    elif isinstance(val, Graph):
        if mode == "history":
            raise ValueError("Graphs are only supported in summary")
        converted = Graph.transform(val)
    elif isinstance(val, Table):
        converted = Table.transform(val)
    return converted
예제 #4
0
def val_to_json(run, key, val, step='summary'):
    """Convert a wandb datatype to its JSON representation.

    Args:
        run: Run that unbound media is bound to; also passed on to
            ``to_json``, ``seq_to_json`` and ``data_frame_to_json``.
        key: Name the value is logged under.
        val: Value to convert. pandas DataFrames, matplotlib/plotly figures,
            ``WBValue`` instances and non-empty sequences of ``WBValue`` are
            converted; anything else is returned unchanged.
        step: History step, or ``'summary'`` for summary values.

    Returns:
        The JSON-ready representation of *val*. A non-empty sequence of
        ``WBValue`` objects becomes either one batched representation (when
        all items share the first item's ``BatchableMedia`` type) or a list
        with one converted entry per item.

    Raises:
        AssertionError: if a DataFrame is logged with a step other than
            ``'summary'`` (only while assertions are enabled).
    """
    # ``collections.Sequence`` was removed in Python 3.10; the old location
    # is only used on interpreters that lack ``collections.abc`` (Python 2).
    try:
        from collections.abc import Sequence
    except ImportError:
        from collections import Sequence

    converted = val
    typename = util.get_full_typename(val)

    if util.is_pandas_data_frame(val):
        assert step == 'summary', "We don't yet support DataFrames in History."
        return data_frame_to_json(val, run, key, step)
    elif util.is_matplotlib_typename(typename):
        # This handles plots with images in it because plotly doesn't support it
        # TODO: should we handle a list of plots?
        val = util.ensure_matplotlib_figure(val)
        if any(len(ax.images) > 0 for ax in val.axes):
            PILImage = util.get_module(
                "PIL.Image",
                required=
                "Logging plots with images requires pil: pip install pillow")
            # Render the figure to an in-memory image; the resulting Image is
            # serialized by the WBValue branch below (assuming Image is a
            # WBValue, as the media classes here appear to be).
            buf = six.BytesIO()
            val.savefig(buf)
            val = Image(PILImage.open(buf))
        else:
            converted = plot_to_json(val)
    elif util.is_plotly_typename(typename):
        converted = plot_to_json(val)
    elif (isinstance(val, Sequence) and len(val) > 0
          and all(isinstance(v, WBValue) for v in val)):
        # The length check matters: ``all()`` is vacuously true for an empty
        # sequence, which would otherwise turn ``""`` or ``()`` into ``[]``.
        # This check will break down if Image/Audio/... have child classes.
        if isinstance(val[0], BatchableMedia) and all(
                isinstance(v, type(val[0])) for v in val):
            return val[0].seq_to_json(val, run, key, step)
        else:
            # TODO(adrian): Good idea to pass on the same key here? Maybe include
            # the array index?
            # There is a bug here: if this array contains two arrays of the same type of
            # anonymous media objects, their eventual names will collide.
            # Heterogeneous lists used to be rejected with a "Mixed media types
            # in the same list aren't supported" ValueError (the frontend
            # doesn't handle heterogenous arrays); they are now converted
            # element by element instead.
            return [val_to_json(run, key, v, step=step) for v in val]

    if isinstance(val, WBValue):
        # Media must be bound to the run (under this key/step) before it
        # can be serialized.
        if isinstance(val, Media) and not val.is_bound():
            val.bind_to_run(run, key, step)
        return val.to_json(run)

    return converted