Example #1
def test_plotter_xarray(plotter: LivePlot):
    tr_batch, tr_epoch = plotter.to_xarray("train")
    te_batch, te_epoch = plotter.to_xarray("test")

    check_batch_xarray(plotter.train_metrics, tr_batch)
    check_epoch_xarray(plotter.train_metrics, tr_epoch)
    check_batch_xarray(plotter.test_metrics, te_batch)
    check_epoch_xarray(plotter.test_metrics, te_epoch)
Example #2
def create_plot(
    metrics: Metrics,
    max_fraction_spent_plotting: float = 0.05,
    last_n_batches: Optional[int] = None,
    nrows: Optional[int] = None,
    ncols: int = 1,
    figsize: Optional[Tuple[int, int]] = None,
) -> Tuple[LivePlot, Figure, ndarray]:
    """ Create matplotlib figure/axes, and a live-plotter, which publishes
    "live" training/testing metric data, at a batch and epoch level, to
    the figure.

    Returns
    -------
    Tuple[noggin.LivePlot, matplotlib.figure.Figure, numpy.ndarray(matplotlib.axes.Axes)]
        (LivePlot-instance, figure, array-of-axes)


    Examples
    --------
    Creating a live plot in a Jupyter notebook

    >>> %matplotlib notebook
    >>> import numpy as np
    >>> from noggin import create_plot, save_metrics
    >>> metrics = ["accuracy", "loss"]
    >>> plotter, fig, ax = create_plot(metrics)
    >>> for i, x in enumerate(np.linspace(0, 10, 100)):
    ...     # training
    ...     x += np.random.rand(1)*5
    ...     batch_metrics = {"accuracy": x**2, "loss": 1/x**.5}
    ...     plotter.set_train_batch(batch_metrics, batch_size=1, plot=True)
    ...
    ...     # cue training epoch
    ...     if i%10 == 0 and i > 0:
    ...         plotter.plot_train_epoch()
    ...
    ...         # cue test-time computations
    ...         for x in np.linspace(0, 10, 5):
    ...             x += (np.random.rand(1) - 0.5)*5
    ...             test_metrics = {"accuracy": x**2}
    ...             plotter.set_test_batch(test_metrics, batch_size=1)
    ...         plotter.plot_test_epoch()
    ...
    >>> plotter.plot()  # ensures final data gets plotted

    Saving the logged metrics

    >>> save_metrics("./metrics.npz", plotter) # save metrics to numpy-archive
    """
    live_plotter = LivePlot(
        metrics,
        max_fraction_spent_plotting=max_fraction_spent_plotting,
        last_n_batches=last_n_batches,
        figsize=figsize,
        ncols=ncols,
        nrows=nrows,
    )
    fig, ax = live_plotter.plot_objects
    return live_plotter, fig, ax
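
The figure and axes returned by create_plot are ordinary matplotlib objects and can be customized directly. A minimal non-notebook sketch, assuming only the create_plot API shown above (the metric names and values are illustrative):

import matplotlib.pyplot as plt
from noggin import create_plot

plotter, fig, ax = create_plot(["accuracy", "loss"], ncols=2)
plotter.set_train_batch({"accuracy": 0.50, "loss": 1.2}, batch_size=16, plot=False)
plotter.set_train_batch({"accuracy": 0.55, "loss": 1.1}, batch_size=16, plot=False)
plotter.plot_train_epoch()  # log the epoch and refresh the figure

fig.suptitle("Training progress")  # standard matplotlib customization
plt.show()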
Example #3
def test_flat_color_syntax(colors: list):
    metric_names = ascii_letters[:len(colors)]
    p = LivePlot({n: c for n, c in zip(metric_names, colors)})
    assert p.metric_colors == {
        n: dict(train=c)
        for n, c in zip(metric_names, colors)
    }
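
For context, the flat syntax exercised above maps each metric name directly to a color, which LivePlot records as that metric's train color. A small illustrative sketch (assuming LivePlot is importable from noggin.plotter, as referenced in the docstrings elsewhere in these examples):

from noggin.plotter import LivePlot

p = LivePlot({"accuracy": "C0", "loss": "red"})
assert p.metric_colors == {"accuracy": {"train": "C0"}, "loss": {"train": "red"}}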
Example #4
def plotters(draw) -> st.SearchStrategy[LivePlot]:
    train_metrics = draw(live_metrics())
    min_num_test = 1 if not train_metrics else 0
    test_metrics = draw(live_metrics(min_num_metrics=min_num_test))
    metric_names = sorted(set(train_metrics).union(set(test_metrics)))
    train_colors = {k: draw(matplotlib_colors()) for k in train_metrics}
    test_colors = {k: draw(matplotlib_colors()) for k in test_metrics}

    return LivePlot.from_dict(
        dict(
            train_metrics=train_metrics,
            test_metrics=test_metrics,
            num_train_epoch=max(
                (len(v["epoch_data"]) for v in train_metrics.values()),
                default=0),
            num_train_batch=max(
                (len(v["batch_data"]) for v in train_metrics.values()),
                default=0),
            num_test_epoch=max(
                (len(v["epoch_data"]) for v in test_metrics.values()),
                default=0),
            num_test_batch=max(
                (len(v["batch_data"]) for v in test_metrics.values()),
                default=0),
            max_fraction_spent_plotting=draw(st.floats(0, 1)),
            last_n_batches=draw(st.none() | st.integers(1, 100)),
            pltkwargs=dict(figsize=(3, 2), nrows=len(metric_names), ncols=1),
            train_colors=train_colors,
            test_colors=test_colors,
            metric_names=metric_names,
        ))
Example #5
def test_input_validation(bad_input: dict, data: st.DataObject):
    defaults = dict(metrics=["a"])
    defaults.update({
        k: cst.draw_if_strategy(data, v, label=k)
        for k, v in bad_input.items()
    })
    with pytest.raises((ValueError, TypeError)):
        LivePlot(**defaults)
Example #6
    def choose_metrics(self, num_train_metrics: int, num_test_metrics: int,
                       data: st.SearchStrategy):
        assume(num_train_metrics + num_test_metrics > 0)
        self.train_metric_names = ["metric-a", "metric-b",
                                   "metric-c"][:num_train_metrics]

        self.test_metric_names = ["metric-a", "metric-b",
                                  "metric-c"][:num_test_metrics]
        train_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_train_metrics,
                max_size=num_train_metrics,
            ),
            label="train_colors",
        )

        test_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_test_metrics,
                max_size=num_test_metrics,
            ),
            label="test_colors",
        )

        metrics = OrderedDict((n, dict()) for n in sorted(
            set(self.train_metric_names + self.test_metric_names)))

        for metric, color in zip(self.train_metric_names, train_colors):
            metrics[metric]["train"] = color

        for metric, color in zip(self.test_metric_names, test_colors):
            metrics[metric]["test"] = color

        self.plotter = LivePlot(
            metrics,
            max_fraction_spent_plotting=data.draw(
                st.floats(0, 1), label="max_fraction_spent_plotting"),
            last_n_batches=data.draw(st.none() | st.integers(1, 100),
                                     label="last_n_batches"),
        )
        self.logger = LiveLogger()

        note("Train metric names: {}".format(self.train_metric_names))
        note("Test metric names: {}".format(self.test_metric_names))
Example #7
def test_setting_color_for_non_metric_is_silent(plotter: LivePlot,
                                                data: st.DataObject):
    color = {
        data.draw(st.text(), label="non_metric"):
        data.draw(cst.matplotlib_colors(), label="color")
    }
    original_colors = plotter.metric_colors
    plotter.metric_colors = color
    assert plotter.metric_colors == original_colors
Example #8
def test_plot_grid(num_metrics, fig_layout, outer_type, shape):
    """Ensure that axes have the right type/shape for a given grid spec"""
    metric_names = list(ascii_letters[:num_metrics])

    fig, ax = LivePlot(metric_names, **fig_layout).plot_objects

    assert isinstance(fig, Figure)
    assert isinstance(ax, outer_type)
    if shape:
        assert ax.shape == shape
Example #9
    def check_from_dict_roundtrip(self):
        plotter_dict = self.plotter.to_dict()
        filename = str(uuid4())
        with open(filename, "wb") as f:
            pickle.dump(plotter_dict, f)

        with open(filename, "rb") as f:
            loaded_dict = pickle.load(f)

        new_plotter = LivePlot.from_dict(loaded_dict)

        for attr in [
                "_num_train_epoch",
                "_num_train_batch",
                "_num_test_epoch",
                "_num_test_batch",
                "_metrics",
                "_pltkwargs",
                "metric_colors",
        ]:
            desired = getattr(self.plotter, attr)
            actual = getattr(new_plotter, attr)
            assert_array_equal(
                actual,
                desired,
                err_msg=
                "LiveLogger.from_metrics did not round-trip successfully.\n"
                "logger.{} does not match.\nGot: {}\nExpected: {}"
                "".format(attr, actual, desired),
            )

        compare_all_metrics(self.plotter.train_metrics,
                            new_plotter.train_metrics)
        compare_all_metrics(self.plotter.test_metrics,
                            new_plotter.test_metrics)

        assert isinstance(new_plotter._test_colors,
                          type(self.plotter._test_colors))
        assert self.plotter._test_colors == new_plotter._test_colors
        assert self.plotter._test_colors[None] is new_plotter._test_colors[None]

        assert isinstance(self.plotter._train_colors,
                          type(new_plotter._train_colors))
        assert self.plotter._train_colors == new_plotter._train_colors
        assert self.plotter._train_colors[None] is new_plotter._train_colors[
            None]

        # check consistency for all public attributes
        excluded = {"plot_objects", "metrics", "test_metrics", "train_metrics"}
        for attr in (
                x for x in dir(self.plotter)
                if not x.startswith("_")
                and not callable(getattr(self.plotter, x))
                and x not in excluded
        ):
            original_attr = getattr(self.plotter, attr)
            from_dict_attr = getattr(new_plotter, attr)
            assert original_attr == from_dict_attr, attr
Example #10
def test_unregister_metric_warns():
    plotter = LivePlot(metrics=["a"])
    with pytest.warns(UserWarning):
        plotter.set_train_batch(dict(a=1, b=1), batch_size=1)

    with pytest.warns(UserWarning):
        plotter.set_test_batch(dict(a=1, c=1), batch_size=1)
Example #11
def test_trivial_case():
    """ Perform a trivial sanity check on live plotter"""
    plotter = LivePlot("a")
    plotter.set_train_batch(dict(a=1.0), batch_size=1, plot=False)
    plotter.set_train_batch(dict(a=3.0), batch_size=1, plot=False)
    plotter.set_train_epoch()

    assert_array_equal(plotter.train_metrics["a"]["batch_data"],
                       np.array([1.0, 3.0]))
    assert_array_equal(plotter.train_metrics["a"]["epoch_domain"],
                       np.array([2]))
    assert_array_equal(plotter.train_metrics["a"]["epoch_data"],
                       np.array([1.0 / 2.0 + 3.0 / 2.0]))
Example #12
def test_fuzz_plot_grid(num_metrics: int, nrows: int, ncols: int):
    plotter = LivePlot(list(ascii_letters[:num_metrics]),
                       nrows=nrows,
                       ncols=ncols)
    assert plotter._pltkwargs["nrows"] * plotter._pltkwargs[
        "ncols"] >= num_metrics
Example #13
def test_adaptive_plot_grid():
    plotter = LivePlot(list(ascii_letters[:5]), ncols=2)
    assert plotter._pltkwargs["nrows"] == 3
    assert plotter._pltkwargs["ncols"] == 2
Example #14
class LivePlotStateMachine(RuleBasedStateMachine):
    """Provides basic rules for exercising essential aspects of LivePlot"""
    def __init__(self):
        super().__init__()
        self.train_metric_names = []
        self.test_metric_names = []
        self.train_batch_set = False
        self.test_batch_set = False
        self.plotter = None  # type: LivePlot
        self.logger = None  # type: LiveLogger

    @initialize(
        num_train_metrics=st.integers(0, 3),
        num_test_metrics=st.integers(0, 3),
        data=st.data(),
    )
    def choose_metrics(self, num_train_metrics: int, num_test_metrics: int,
                       data: st.SearchStrategy):
        assume(num_train_metrics + num_test_metrics > 0)
        self.train_metric_names = ["metric-a", "metric-b",
                                   "metric-c"][:num_train_metrics]

        self.test_metric_names = ["metric-a", "metric-b",
                                  "metric-c"][:num_test_metrics]
        train_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_train_metrics,
                max_size=num_train_metrics,
            ),
            label="train_colors",
        )

        test_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_test_metrics,
                max_size=num_test_metrics,
            ),
            label="test_colors",
        )

        metrics = OrderedDict((n, dict()) for n in sorted(
            set(self.train_metric_names + self.test_metric_names)))

        for metric, color in zip(self.train_metric_names, train_colors):
            metrics[metric]["train"] = color

        for metric, color in zip(self.test_metric_names, test_colors):
            metrics[metric]["test"] = color

        self.plotter = LivePlot(
            metrics,
            max_fraction_spent_plotting=data.draw(
                st.floats(0, 1), label="max_fraction_spent_plotting"),
            last_n_batches=data.draw(st.none() | st.integers(1, 100),
                                     label="last_n_batches"),
        )
        self.logger = LiveLogger()

        note("Train metric names: {}".format(self.train_metric_names))
        note("Test metric names: {}".format(self.test_metric_names))

    @rule(batch_size=st.integers(0, 2), data=st.data(), plot=st.booleans())
    def set_train_batch(self, batch_size: int, data: SearchStrategy,
                        plot: bool):
        self.train_batch_set = True

        batch = {
            name: data.draw(st.floats(-1, 1), label=name)
            for name in self.train_metric_names
        }
        self.logger.set_train_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_train_batch(metrics=batch,
                                     batch_size=batch_size,
                                     plot=plot)

    @rule()
    def set_train_epoch(self):
        self.logger.set_train_epoch()
        self.plotter.set_train_epoch()

    @rule(batch_size=st.integers(0, 2), data=st.data())
    def set_test_batch(self, batch_size: int, data: SearchStrategy):
        self.test_batch_set = True

        batch = {
            name: data.draw(st.floats(-1, 1), label=name)
            for name in self.test_metric_names
        }
        self.logger.set_test_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_test_batch(metrics=batch, batch_size=batch_size)

    @rule()
    def set_test_epoch(self):
        self.logger.set_test_epoch()
        self.plotter.set_test_epoch()

    def teardown(self):
        plt.close("all")
        super().teardown()
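
To actually run this Hypothesis rule-based state machine under pytest, the standard pattern is to expose the unittest case that Hypothesis generates for it; a minimal sketch (the settings values are illustrative):

from hypothesis import settings

# LivePlotStateMachine as defined above
TestLivePlot = LivePlotStateMachine.TestCase
TestLivePlot.settings = settings(max_examples=50, deadline=None)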
Example #15
def test_bad_figsize(plotter: LivePlot, bad_size):
    with pytest.raises(ValueError):
        plotter.figsize = bad_size
Example #16
def test_plotters(plotter: LivePlot):
    """Ensure that loggers() can produce a Logger that can round-trip"""
    LivePlot.from_dict(plotter.to_dict())
Example #17
def test_color_setter_validation(plotter: LivePlot, bad_colors):
    with pytest.raises(TypeError):
        plotter.metric_colors = bad_colors
Example #18
def plot_logger(
    logger: LiveLogger,
    plot_batches: bool = True,
    last_n_batches: Optional[int] = None,
    colors: Optional[Dict[str, Union[ValidColor, Dict[str,
                                                      ValidColor]]]] = None,
    nrows: Optional[int] = None,
    ncols: int = 1,
    figsize: Optional[Tuple[int, int]] = None,
) -> Tuple[LivePlot, Figure, Union[Axes, np.ndarray]]:
    """Plots the data recorded by a :class:`~noggin.logger.LiveLogger` instance.

    Converts the logger to an instance of :class:`~noggin.plotter.LivePlot`.

    Parameters
    ----------
    logger : LiveLogger
        The logger whose train/test-split batch/epoch-level data will be plotted.

    plot_batches : bool, optional (default=True)
        If ``True`` include batch-level data in plot.

    last_n_batches : Optional[int]
        The maximum number of batches to be plotted at any given time.
        If ``None``, all of the data will be plotted.

    colors : Optional[Dict[str, Union[ValidColor, Dict[str, ValidColor]]]]
        ``colors`` can be a dictionary, specifying the colors used to plot
        the metrics. Two mappings are valid:
            - '<metric-name>' -> color-value  (specifies train-metric color only)
            - '<metric-name>' -> {'train'/'test' : color-value}
        If ``None``, default colors are used in the plot.

    nrows : Optional[int]
        Number of rows of the subplot grid. Metrics are added in
        row-major order to fill the grid.

    ncols : int, optional, default: 1
        Number of columns of the subplot grid. Metrics are added in
        row-major order to fill the grid.

    figsize : Optional[Tuple[int, int]]
        Specifies the width and height, respectively, of the figure.

    Returns
    -------
    Tuple[LivePlot, Figure, Union[Axes, np.ndarray]]
        The resulting plotter, matplotlib-figure, and axis (or array of axes)
    """

    if not isinstance(logger, LiveLogger):
        raise TypeError(
            "`logger` must be an instance of `noggin.LiveLogger`, got {}".
            format(logger))

    metrics = sorted(
        set(
            list(logger.train_metrics.keys()) +
            list(logger.test_metrics.keys())))

    plotter = LivePlot(
        metrics,
        max_fraction_spent_plotting=0.0,
        last_n_batches=last_n_batches,
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
    )

    plotter.last_n_batches = last_n_batches
    if colors is not None:
        plotter.metric_colors = colors

    plotter_dict = plotter.to_dict()

    plotter_dict.update(logger.to_dict())
    plotter = LivePlot.from_dict(plotter_dict)
    plotter.plot(plot_batches=plot_batches)
    fig, ax = plotter.plot_objects
    return plotter, fig, ax
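
A hypothetical end-to-end sketch of plot_logger, assuming LiveLogger is importable from noggin as the error message above implies (the metric name and values are illustrative):

from noggin import LiveLogger

# uses plot_logger as defined above
logger = LiveLogger()
for x in (0.1, 0.2, 0.3):
    logger.set_train_batch(metrics={"loss": 1.0 / (x + 1.0)}, batch_size=8)
logger.set_train_epoch()

# colors accepts either "loss" -> color or "loss" -> {"train"/"test": color}
plotter, fig, ax = plot_logger(logger, colors={"loss": {"train": "C1"}})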
Example #19
def test_set_color(plotter: LivePlot, colors: dict, data: st.DataObject):
    metric = data.draw(st.sampled_from(plotter.metrics), label="metric")
    plotter.metric_colors = {metric: colors}
    assert plotter.metric_colors[metric] == colors