def test_plotter_xarray(plotter: LivePlot):
    """Verify ``LivePlot.to_xarray`` output for both the train and test split."""
    converted = {split: plotter.to_xarray(split) for split in ("train", "test")}
    train_batch, train_epoch = converted["train"]
    test_batch, test_epoch = converted["test"]

    check_batch_xarray(plotter.train_metrics, train_batch)
    check_epoch_xarray(plotter.train_metrics, train_epoch)
    check_batch_xarray(plotter.test_metrics, test_batch)
    check_epoch_xarray(plotter.test_metrics, test_epoch)
def create_plot(
    metrics: Metrics,
    max_fraction_spent_plotting: float = 0.05,
    last_n_batches: Optional[int] = None,
    nrows: Optional[int] = None,
    ncols: int = 1,
    figsize: Optional[Tuple[int, int]] = None,
) -> Tuple[LivePlot, Figure, ndarray]:
    """ Create matplotlib figure/axes, and a live-plotter, which publishes
    "live" training/testing metric data, at a batch and epoch level, to
    the figure.

    Parameters
    ----------
    metrics : Metrics
        The metrics to be plotted; forwarded unchanged to ``LivePlot``.
        Accepted forms include a single metric name, a sequence of metric
        names, or a mapping of metric names to colors. A color may be
        given as a single color-value (train color only) or as a
        ``{'train'/'test': color-value}`` dictionary.

    max_fraction_spent_plotting : float, optional (default=0.05)
        Forwarded to ``LivePlot``. Judging by its name, this is the maximum
        fraction of run-time that may be spent on plotting. See ``LivePlot``
        for the exact semantics.

    last_n_batches : Optional[int]
        The maximum number of batches to be plotted at any given time.
        If ``None``, all of the data will be plotted.

    nrows : Optional[int]
        Number of rows of the subplot grid. If ``None``, enough rows are
        used to fit every metric given ``ncols``.

    ncols : int, optional (default=1)
        Number of columns of the subplot grid.

    figsize : Optional[Tuple[int, int]]
        Specifies the width and height, respectively, of the figure.

    Returns
    -------
    Tuple[liveplot.LivePlot, matplotlib.figure.Figure, numpy.ndarray(matplotlib.axes.Axes)]
        (LivePlot-instance, figure, array-of-axes)

    Examples
    --------
    Creating a live plot in a Jupyter notebook

    >>> %matplotlib notebook
    >>> import numpy as np
    >>> from noggin import create_plot, save_metrics
    >>> metrics = ["accuracy", "loss"]
    >>> plotter, fig, ax = create_plot(metrics)
    >>> for i, x in enumerate(np.linspace(0, 10, 100)):
    ...     # training
    ...     x += np.random.rand(1)*5
    ...     batch_metrics = {"accuracy": x**2, "loss": 1/x**.5}
    ...     plotter.set_train_batch(batch_metrics, batch_size=1, plot=True)
    ...
    ...     # cue training epoch
    ...     if i%10 == 0 and i > 0:
    ...         plotter.plot_train_epoch()
    ...
    ...         # cue test-time computations
    ...         for x in np.linspace(0, 10, 5):
    ...             x += (np.random.rand(1) - 0.5)*5
    ...             test_metrics = {"accuracy": x**2}
    ...             plotter.set_test_batch(test_metrics, batch_size=1)
    ...         plotter.plot_test_epoch()
    ...
    ... plotter.plot()  # ensures final data gets plotted

    Saving the logged metrics

    >>> save_metrics("./metrics.npz", plotter) # save metrics to numpy-archive
    """
    live_plotter = LivePlot(
        metrics,
        max_fraction_spent_plotting=max_fraction_spent_plotting,
        last_n_batches=last_n_batches,
        figsize=figsize,
        ncols=ncols,
        nrows=nrows,
    )
    fig, ax = live_plotter.plot_objects
    return live_plotter, fig, ax
def test_flat_color_syntax(colors: list):
    """A flat ``{name: color}`` mapping assigns only the train color."""
    names = ascii_letters[:len(colors)]
    pairs = list(zip(names, colors))
    plotter = LivePlot(dict(pairs))
    expected = {name: {"train": color} for name, color in pairs}
    assert plotter.metric_colors == expected
def plotters(draw) -> st.SearchStrategy[LivePlot]:
    """Draw a ``LivePlot`` rebuilt via ``from_dict`` from random metric data."""
    train_metrics = draw(live_metrics())
    # require at least one test metric whenever no train metric was drawn
    test_metrics = draw(
        live_metrics(min_num_metrics=0 if train_metrics else 1))
    all_names = sorted(set(train_metrics) | set(test_metrics))
    train_colors = {name: draw(matplotlib_colors()) for name in train_metrics}
    test_colors = {name: draw(matplotlib_colors()) for name in test_metrics}

    def longest(split_metrics, key):
        # length of the longest ``key`` series in the split (0 if empty)
        return max((len(stats[key]) for stats in split_metrics.values()),
                   default=0)

    state = dict(
        train_metrics=train_metrics,
        test_metrics=test_metrics,
        num_train_epoch=longest(train_metrics, "epoch_data"),
        num_train_batch=longest(train_metrics, "batch_data"),
        num_test_epoch=longest(test_metrics, "epoch_data"),
        num_test_batch=longest(test_metrics, "batch_data"),
        max_fraction_spent_plotting=draw(st.floats(0, 1)),
        last_n_batches=draw(st.none() | st.integers(1, 100)),
        pltkwargs=dict(figsize=(3, 2), nrows=len(all_names), ncols=1),
        train_colors=train_colors,
        test_colors=test_colors,
        metric_names=all_names,
    )
    return LivePlot.from_dict(state)
def test_input_validation(bad_input: dict, data: st.DataObject):
    """Invalid constructor arguments must raise ValueError or TypeError."""
    kwargs = {"metrics": ["a"]}
    for name, value in bad_input.items():
        kwargs[name] = cst.draw_if_strategy(data, value, label=name)
    with pytest.raises((ValueError, TypeError)):
        LivePlot(**kwargs)
def choose_metrics(self, num_train_metrics: int, num_test_metrics: int,
                   data: st.DataObject):
    """Pick train/test metric names and colors, then build the plotter.

    Parameters
    ----------
    num_train_metrics : int
        How many of the candidate metric names are used for training.
    num_test_metrics : int
        How many of the candidate metric names are used for testing.
    data : st.DataObject
        Hypothesis data object (from ``st.data()``) used to draw colors and
        ``LivePlot`` settings interactively.
    """
    # a plotter with no metrics at all is not an interesting case
    assume(num_train_metrics + num_test_metrics > 0)
    candidate_names = ["metric-a", "metric-b", "metric-c"]
    self.train_metric_names = candidate_names[:num_train_metrics]
    self.test_metric_names = candidate_names[:num_test_metrics]

    train_colors = data.draw(
        st.lists(
            cst.matplotlib_colors(),
            min_size=num_train_metrics,
            max_size=num_train_metrics,
        ),
        label="train_colors",
    )

    test_colors = data.draw(
        st.lists(
            cst.matplotlib_colors(),
            min_size=num_test_metrics,
            max_size=num_test_metrics,
        ),
        label="test_colors",
    )

    # metric-name -> {"train"/"test": color}, in sorted name order
    metrics = OrderedDict((n, dict()) for n in sorted(
        set(self.train_metric_names + self.test_metric_names)))

    for metric, color in zip(self.train_metric_names, train_colors):
        metrics[metric]["train"] = color

    for metric, color in zip(self.test_metric_names, test_colors):
        metrics[metric]["test"] = color

    self.plotter = LivePlot(
        metrics,
        max_fraction_spent_plotting=data.draw(
            st.floats(0, 1), label="max_fraction_spent_plotting"),
        last_n_batches=data.draw(st.none() | st.integers(1, 100),
                                 label="last_n_batches"),
    )
    self.logger = LiveLogger()

    note("Train metric names: {}".format(self.train_metric_names))
    note("Test metric names: {}".format(self.test_metric_names))
def test_setting_color_for_non_metric_is_silent(plotter: LivePlot,
                                                data: st.DataObject):
    """Assigning a color to a name that is not a metric changes nothing."""
    import copy

    # The name must genuinely not be a metric. Otherwise setting its color
    # is a legitimate change and the test would fail spuriously.
    non_metric = data.draw(
        st.text().filter(lambda name: name not in plotter.metrics),
        label="non_metric",
    )
    color = {non_metric: data.draw(cst.matplotlib_colors(), label="color")}

    # Snapshot the colors so the comparison cannot pass vacuously if the
    # property hands back a live reference to internal state.
    original_colors = copy.deepcopy(plotter.metric_colors)
    plotter.metric_colors = color
    assert plotter.metric_colors == original_colors
def test_plot_grid(num_metrics, fig_layout, outer_type, shape):
    """Check the figure/axes types (and axes shape) for a given grid spec."""
    plotter = LivePlot(list(ascii_letters[:num_metrics]), **fig_layout)
    figure, axes = plotter.plot_objects
    assert isinstance(figure, Figure)
    assert isinstance(axes, outer_type)
    if not shape:
        return
    assert axes.shape == shape
def check_from_dict_roundtrip(self):
    """Check that ``LivePlot.to_dict``/``from_dict`` round-trips through pickle.

    The dictionary from ``self.plotter.to_dict()`` is pickled and unpickled.
    It is then used to build a new plotter. The new plotter's counters,
    metrics, plotting kwargs, colors and non-callable public attributes are
    compared against the original's.
    """
    plotter_dict = self.plotter.to_dict()
    # Pickle in memory instead of via a uuid-named file in the working
    # directory. That file was never cleaned up.
    loaded_dict = pickle.loads(pickle.dumps(plotter_dict))
    new_plotter = LivePlot.from_dict(loaded_dict)

    for attr in [
            "_num_train_epoch",
            "_num_train_batch",
            "_num_test_epoch",
            "_num_test_batch",
            "_metrics",
            "_pltkwargs",
            "metric_colors",
    ]:
        desired = getattr(self.plotter, attr)
        actual = getattr(new_plotter, attr)
        assert_array_equal(
            actual,
            desired,
            err_msg="LivePlot.from_dict did not round-trip successfully.\n"
            "plotter.{} does not match.\nGot: {}\nExpected: {}"
            "".format(attr, actual, desired),
        )

    compare_all_metrics(self.plotter.train_metrics, new_plotter.train_metrics)
    compare_all_metrics(self.plotter.test_metrics, new_plotter.test_metrics)

    # The restored color containers must keep the original's type, content
    # and behaviour for the ``None`` key. Both splits are checked the same way.
    assert isinstance(new_plotter._test_colors, type(self.plotter._test_colors))
    assert self.plotter._test_colors == new_plotter._test_colors
    assert self.plotter._test_colors[None] is new_plotter._test_colors[None]

    assert isinstance(new_plotter._train_colors,
                      type(self.plotter._train_colors))
    assert self.plotter._train_colors == new_plotter._train_colors
    assert self.plotter._train_colors[None] is new_plotter._train_colors[None]

    # check consistency for all public, non-callable attributes
    skipped = {"plot_objects", "metrics", "test_metrics", "train_metrics"}
    for attr in (
            x for x in dir(self.plotter)
            if not x.startswith("_")
            and not callable(getattr(self.plotter, x))
            and x not in skipped):
        original_attr = getattr(self.plotter, attr)
        from_dict_attr = getattr(new_plotter, attr)
        assert original_attr == from_dict_attr, attr
def test_unregister_metric_warns():
    """Logging a metric the plotter was not built with emits a UserWarning."""
    plotter = LivePlot(metrics=["a"])
    cases = [
        (plotter.set_train_batch, dict(a=1, b=1)),
        (plotter.set_test_batch, dict(a=1, c=1)),
    ]
    for setter, batch in cases:
        with pytest.warns(UserWarning):
            setter(batch, batch_size=1)
def test_trivial_case():
    """Sanity-check batch and epoch bookkeeping on a single-metric plotter."""
    plotter = LivePlot("a")
    for value in (1.0, 3.0):
        plotter.set_train_batch(dict(a=value), batch_size=1, plot=False)
    plotter.set_train_epoch()

    stats = plotter.train_metrics["a"]
    assert_array_equal(stats["batch_data"], np.array([1.0, 3.0]))
    assert_array_equal(stats["epoch_domain"], np.array([2]))
    # the epoch value is the mean of the two unit-size batches
    assert_array_equal(stats["epoch_data"], np.array([(1.0 + 3.0) / 2.0]))
def test_fuzz_plot_grid(num_metrics: int, nrows: int, ncols: int):
    """The resolved grid always has at least one cell per metric."""
    names = list(ascii_letters[:num_metrics])
    layout = LivePlot(names, nrows=nrows, ncols=ncols)._pltkwargs
    num_cells = layout["nrows"] * layout["ncols"]
    assert num_cells >= num_metrics
def test_adaptive_plot_grid():
    """With ``ncols`` fixed at 2, five metrics need three rows."""
    layout = LivePlot(list(ascii_letters[:5]), ncols=2)._pltkwargs
    assert (layout["nrows"], layout["ncols"]) == (3, 2)
class LivePlotStateMachine(RuleBasedStateMachine):
    """Provides basic rules for exercising essential aspects of LivePlot"""

    def __init__(self):
        super().__init__()
        self.train_metric_names = []
        self.test_metric_names = []
        self.train_batch_set = False
        self.test_batch_set = False
        self.plotter = None  # type: LivePlot
        self.logger = None  # type: LiveLogger

    @initialize(
        num_train_metrics=st.integers(0, 3),
        num_test_metrics=st.integers(0, 3),
        data=st.data(),
    )
    def choose_metrics(self, num_train_metrics: int, num_test_metrics: int,
                       data: st.DataObject):
        """Pick train/test metric names and colors; build plotter and logger.

        ``data`` is the Hypothesis data object produced by ``st.data()``.
        It is used to draw colors and ``LivePlot`` settings.
        """
        # a plotter with no metrics at all is not an interesting case
        assume(num_train_metrics + num_test_metrics > 0)
        self.train_metric_names = ["metric-a", "metric-b",
                                   "metric-c"][:num_train_metrics]

        self.test_metric_names = ["metric-a", "metric-b",
                                  "metric-c"][:num_test_metrics]
        train_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_train_metrics,
                max_size=num_train_metrics,
            ),
            label="train_colors",
        )

        test_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_test_metrics,
                max_size=num_test_metrics,
            ),
            label="test_colors",
        )

        # metric-name -> {"train"/"test": color}, in sorted name order
        metrics = OrderedDict((n, dict()) for n in sorted(
            set(self.train_metric_names + self.test_metric_names)))

        for metric, color in zip(self.train_metric_names, train_colors):
            metrics[metric]["train"] = color

        for metric, color in zip(self.test_metric_names, test_colors):
            metrics[metric]["test"] = color

        self.plotter = LivePlot(
            metrics,
            max_fraction_spent_plotting=data.draw(
                st.floats(0, 1), label="max_fraction_spent_plotting"),
            last_n_batches=data.draw(st.none() | st.integers(1, 100),
                                     label="last_n_batches"),
        )
        self.logger = LiveLogger()

        note("Train metric names: {}".format(self.train_metric_names))
        note("Test metric names: {}".format(self.test_metric_names))

    @rule(batch_size=st.integers(0, 2), data=st.data(), plot=st.booleans())
    def set_train_batch(self, batch_size: int, data: st.DataObject,
                        plot: bool):
        """Feed the same random train batch to both the logger and plotter."""
        self.train_batch_set = True

        batch = {
            name: data.draw(st.floats(-1, 1), label=name)
            for name in self.train_metric_names
        }
        self.logger.set_train_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_train_batch(metrics=batch,
                                     batch_size=batch_size,
                                     plot=plot)

    @rule()
    def set_train_epoch(self):
        """Close out a train epoch on both the logger and plotter."""
        self.logger.set_train_epoch()
        self.plotter.set_train_epoch()

    @rule(batch_size=st.integers(0, 2), data=st.data())
    def set_test_batch(self, batch_size: int, data: st.DataObject):
        """Feed the same random test batch to both the logger and plotter."""
        self.test_batch_set = True

        batch = {
            name: data.draw(st.floats(-1, 1), label=name)
            for name in self.test_metric_names
        }
        self.logger.set_test_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_test_batch(metrics=batch, batch_size=batch_size)

    @rule()
    def set_test_epoch(self):
        """Close out a test epoch on both the logger and plotter."""
        self.logger.set_test_epoch()
        self.plotter.set_test_epoch()

    def teardown(self):
        # close every matplotlib figure opened during this run
        plt.close("all")
        super().teardown()
def test_bad_figsize(plotter: LivePlot, bad_size):
    """Assigning an invalid ``figsize`` is rejected with ValueError."""
    with pytest.raises(ValueError):
        setattr(plotter, "figsize", bad_size)
def test_plotters(plotter: LivePlot):
    """Ensure that plotters() produces a LivePlot whose ``to_dict`` output
    can be fed back through ``LivePlot.from_dict`` without error."""
    LivePlot.from_dict(plotter.to_dict())
def test_color_setter_validation(plotter: LivePlot, bad_colors):
    """Malformed color specifications are rejected with TypeError."""
    with pytest.raises(TypeError):
        setattr(plotter, "metric_colors", bad_colors)
def plot_logger(
    logger: LiveLogger,
    plot_batches: bool = True,
    last_n_batches: Optional[int] = None,
    colors: Optional[Dict[str, Union[ValidColor, Dict[str, ValidColor]]]] = None,
    nrows: Optional[int] = None,
    ncols: int = 1,
    figsize: Optional[Tuple[int, int]] = None,
) -> Tuple[LivePlot, Figure, Union[Axes, np.ndarray]]:
    """Plots the data recorded by a :class:`~noggin.logger.LiveLogger` instance.

    Converts the logger to an instance of :class:`~noggin.plotter.LivePlot`.

    Parameters
    ----------
    logger : LiveLogger
        The logger whose train/test-split batch/epoch-level data will be plotted.

    plot_batches : bool, optional (default=True)
        If ``True`` include batch-level data in plot.

    last_n_batches : Optional[int]
        The maximum number of batches to be plotted at any given time.
        If ``None``, all of the data will be plotted.

    colors : Optional[Dict[str, Union[ValidColor, Dict[str, ValidColor]]]]
        ``colors`` can be a dictionary, specifying the colors used to plot
        the metrics. Two mappings are valid:
            - '<metric-name>' -> color-value  (specifies train-metric color only)
            - '<metric-name>' -> {'train'/'test' : color-value}

        If ``None``, default colors are used in the plot.

    nrows : Optional[int]
        Number of rows of the subplot grid. Metrics are added in
        row-major order to fill the grid.

    ncols : int, optional, default: 1
        Number of columns of the subplot grid. Metrics are added in
        row-major order to fill the grid.

    figsize : Optional[Sequence[float, float]]
        Specifies the width and height, respectively, of the figure.

    Returns
    -------
    Tuple[LivePlot, Figure, Union[Axes, np.ndarray]]
        The resulting plotter, matplotlib-figure, and axis (or array of axes)

    Raises
    ------
    TypeError
        If ``logger`` is not a ``LiveLogger`` instance.
    """
    if not isinstance(logger, LiveLogger):
        raise TypeError(
            "`logger` must be an instance of `noggin.LiveLogger`, got {}".
            format(logger))

    # every metric seen by the logger in either split, in sorted order
    metrics = sorted(set(logger.train_metrics).union(logger.test_metrics))

    # This plotter only supplies layout/color settings. The logged data is
    # merged in below via to_dict/from_dict.
    plotter = LivePlot(
        metrics,
        max_fraction_spent_plotting=0.0,
        last_n_batches=last_n_batches,
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
    )
    # NOTE(review): appears redundant with the constructor argument above;
    # kept to preserve behavior.
    plotter.last_n_batches = last_n_batches

    if colors is not None:
        plotter.metric_colors = colors

    # Overlay the logger's serialized state onto the plotter's, then rebuild
    # a plotter holding the logged data.
    plotter_dict = plotter.to_dict()
    plotter_dict.update(logger.to_dict())
    plotter = LivePlot.from_dict(plotter_dict)
    plotter.plot(plot_batches=plot_batches)
    fig, ax = plotter.plot_objects
    return plotter, fig, ax
def test_set_color(plotter: LivePlot, colors: dict, data: st.DataObject):
    """A color spec assigned to an existing metric is reported back as-is."""
    chosen = data.draw(st.sampled_from(plotter.metrics), label="metric")
    plotter.metric_colors = {chosen: colors}
    assert colors == plotter.metric_colors[chosen]