def test_set_batch_missing_metric(logger: LiveLogger, data: st.DataObject):
    keys = list(logger.train_metrics) + list(logger.test_metrics)
    missing_metrics = data.draw(
        st.text().filter(lambda x: x not in set(keys)).map(lambda x: {x: 2})
    )
    logger.set_train_batch(metrics=missing_metrics, batch_size=1)
    logger.set_test_batch(metrics=missing_metrics, batch_size=1)
def test_logger_xarray(logger: LiveLogger):
    tr_batch, tr_epoch = logger.to_xarray("train")
    te_batch, te_epoch = logger.to_xarray("test")

    check_batch_xarray(logger.train_metrics, tr_batch)
    check_epoch_xarray(logger.train_metrics, tr_epoch)
    check_batch_xarray(logger.test_metrics, te_batch)
    check_epoch_xarray(logger.test_metrics, te_epoch)
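# Illustration only (not part of the test suite): a minimal usage sketch of the
# behavior exercised by `test_logger_xarray` above. It assumes only that
# `LiveLogger.to_xarray("train")` returns a (batch-level, epoch-level) pair of
# xarray Datasets whose data variables are the logged metric names.
def _example_to_xarray():  # hypothetical helper, for illustration
    from noggin.logger import LiveLogger

    logger = LiveLogger()
    logger.set_train_batch(dict(loss=1.0), batch_size=4)
    logger.set_train_batch(dict(loss=0.5), batch_size=4)
    logger.set_train_epoch()

    batch_data, epoch_data = logger.to_xarray("train")
    print(batch_data["loss"])  # batch-level values of `loss`
    print(epoch_data["loss"])  # epoch-level value of `loss`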
def test_concat_experiments(logger: LiveLogger, num_exps: int, data: st.DataObject):
    metrics = list(logger.train_metrics)
    assume(len(metrics) > 0)

    logger.set_train_batch(
        {k: data.draw(st.floats(-1e6, 1e6)) for k in metrics}, batch_size=1
    )
    batch_xarrays = [logger.to_xarray("train")[0]]

    for n in range(num_exps - 1):
        logger.set_train_batch(
            {k: data.draw(st.floats(-1e6, 1e6)) for k in metrics}, batch_size=1
        )
        batch_xarrays.append(logger.to_xarray("train")[0])

    out = concat_experiments(*batch_xarrays)

    assert list(out.coords["experiment"]) == list(range(num_exps))
    assert list(out.data_vars) == list(metrics)

    for n in range(num_exps):
        for metric in metrics:
            assert_equal(
                batch_xarrays[n].to_array(metric),
                out.isel(experiment=n)
                .drop_vars(names=["experiment"])
                .to_array(metric)
                .dropna(dim="iterations"),
            )
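# Illustration only: a sketch of combining per-experiment batch-level Datasets,
# mirroring the test above. It relies on the same module-level import of
# `concat_experiments` that `test_concat_experiments` uses; nothing beyond the
# behavior asserted above is assumed.
def _example_concat_experiments():  # hypothetical helper, for illustration
    from noggin.logger import LiveLogger

    datasets = []
    for value in (1.0, 2.0, 3.0):  # three toy "experiments"
        logger = LiveLogger()
        logger.set_train_batch(dict(loss=value), batch_size=1)
        datasets.append(logger.to_xarray("train")[0])  # batch-level Dataset

    combined = concat_experiments(*datasets)
    # each experiment is stacked along a new "experiment" coordinate
    print(list(combined.coords["experiment"]))  # -> [0, 1, 2]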
def test_trivial_case():
    """Perform a trivial sanity check on live logger"""
    logger = LiveLogger()
    logger.set_train_batch(dict(a=1.0), batch_size=1)
    logger.set_train_batch(dict(a=3.0), batch_size=1)
    logger.set_train_epoch()

    assert_array_equal(logger.train_metrics["a"]["batch_data"], np.array([1.0, 3.0]))
    assert_array_equal(logger.train_metrics["a"]["epoch_domain"], np.array([2]))
    assert_array_equal(
        logger.train_metrics["a"]["epoch_data"], np.array([1.0 / 2.0 + 3.0 / 2.0])
    )
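# Illustration only: the shadow implementation in the state machine below calls
# `add_datapoint(..., weighting=batch_size)`, which suggests that an epoch's
# value is the batch-size-weighted mean of its batch-level values. Under that
# assumption, unequal batch sizes would behave like this sketch:
def _example_weighted_epoch_mean():  # hypothetical helper, for illustration
    from noggin.logger import LiveLogger

    logger = LiveLogger()
    logger.set_train_batch(dict(a=1.0), batch_size=2)  # weight 2
    logger.set_train_batch(dict(a=4.0), batch_size=1)  # weight 1
    logger.set_train_epoch()

    expected = (2 * 1.0 + 1 * 4.0) / (2 + 1)  # weighted mean = 2.0
    recorded = logger.train_metrics["a"]["epoch_data"][-1]
    print(recorded, expected)  # expected to agree under the assumption above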
def plot_logger(
    logger: LiveLogger,
    plot_batches: bool = True,
    last_n_batches: Optional[int] = None,
    colors: Optional[Dict[str, Union[ValidColor, Dict[str, ValidColor]]]] = None,
    nrows: Optional[int] = None,
    ncols: int = 1,
    figsize: Optional[Tuple[int, int]] = None,
) -> Tuple[LivePlot, Figure, Union[Axes, np.ndarray]]:
    """Plots the data recorded by a :class:`~noggin.logger.LiveLogger` instance.

    Converts the logger to an instance of :class:`~noggin.plotter.LivePlot`.

    Parameters
    ----------
    logger : LiveLogger
        The logger whose train/test-split batch/epoch-level data will be plotted.

    plot_batches : bool, optional (default=True)
        If ``True``, include batch-level data in the plot.

    last_n_batches : Optional[int]
        The maximum number of batches to be plotted at any given time.
        If ``None``, all of the data will be plotted.

    colors : Optional[Dict[str, Union[ValidColor, Dict[str, ValidColor]]]]
        ``colors`` can be a dictionary, specifying the colors used to plot
        the metrics. Two mappings are valid:

        - '<metric-name>' -> color-value  (specifies train-metric color only)
        - '<metric-name>' -> {'train'/'test' : color-value}

        If ``None``, default colors are used in the plot.

    nrows : Optional[int]
        Number of rows of the subplot grid. Metrics are added in
        row-major order to fill the grid.

    ncols : int, optional, default: 1
        Number of columns of the subplot grid. Metrics are added in
        row-major order to fill the grid.

    figsize : Optional[Tuple[int, int]]
        Specifies the width and height, respectively, of the figure.

    Returns
    -------
    Tuple[LivePlot, Figure, Union[Axes, np.ndarray]]
        The resulting plotter, matplotlib-figure, and axis (or array of axes)
    """
    if not isinstance(logger, LiveLogger):
        raise TypeError(
            "`logger` must be an instance of `noggin.LiveLogger`, got {}".format(logger)
        )

    metrics = sorted(
        set(list(logger.train_metrics.keys()) + list(logger.test_metrics.keys()))
    )

    plotter = LivePlot(
        metrics,
        max_fraction_spent_plotting=0.0,
        last_n_batches=last_n_batches,
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
    )
    plotter.last_n_batches = last_n_batches

    if colors is not None:
        plotter.metric_colors = colors

    plotter_dict = plotter.to_dict()
    plotter_dict.update(logger.to_dict())
    plotter = LivePlot.from_dict(plotter_dict)
    plotter.plot(plot_batches=plot_batches)
    fig, ax = plotter.plot_objects
    return plotter, fig, ax
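# Illustration only: a minimal usage sketch of `plot_logger` as defined above.
# The metric name and file path are arbitrary examples.
def _example_plot_logger():  # hypothetical helper, for illustration
    from noggin.logger import LiveLogger

    logger = LiveLogger()
    for step in range(10):
        logger.set_train_batch(dict(loss=1.0 / (step + 1)), batch_size=32)
    logger.set_train_epoch()

    # plot everything recorded by the logger, including batch-level data
    plotter, fig, ax = plot_logger(logger, plot_batches=True)
    fig.savefig("loss.png")  # the returned figure is a standard matplotlib Figure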
def test_loggers(logger: LiveLogger):
    """Ensure that loggers() can produce a Logger that can round-trip"""
    LiveLogger.from_dict(logger.to_dict())
class LiveLoggerStateMachine(RuleBasedStateMachine):
    """Ensures that exercising the api of LiveLogger produces results that are
    consistent with a simplistic implementation"""

    def __init__(self):
        super().__init__()
        self.train_metrics = []  # type: List[LiveMetric]
        self.test_metrics = []  # type: List[LiveMetric]
        self.logger = LiveLogger()
        self.train_batch_set = False
        self.test_batch_set = False
        self.num_train_batch = 0

    @initialize(
        num_train_metrics=st.integers(0, 3), num_test_metrics=st.integers(0, 3)
    )
    def choose_metrics(self, num_train_metrics: int, num_test_metrics: int):
        train_metric_names = ["metric-a", "metric-b", "metric-c"][:num_train_metrics]
        for name in train_metric_names:
            self.train_metrics.append(LiveMetric(name=name))

        test_metric_names = ["metric-a", "metric-b", "metric-c"][:num_test_metrics]
        for name in test_metric_names:
            self.test_metrics.append(LiveMetric(name=name))

        note("Train metric names: {}".format(train_metric_names))
        note("Test metric names: {}".format(test_metric_names))

    @rule()
    def get_repr(self):
        """Ensure that calling `repr` produces no side effects"""
        repr(self.logger)

    @rule(batch_size=st.integers(0, 2), data=st.data())
    def set_train_batch(self, batch_size: int, data: SearchStrategy):
        if self.train_metrics:
            self.num_train_batch += 1
        self.train_batch_set = True

        batch = {
            metric.name: data.draw(
                st.floats(-1, 1) | st.floats(-1, 1).map(np.array), label=metric.name
            )
            for metric in self.train_metrics
        }
        self.logger.set_train_batch(metrics=batch, batch_size=batch_size)

        for metric in self.train_metrics:
            metric.add_datapoint(batch[metric.name], weighting=batch_size)

    @rule()
    def set_train_epoch(self):
        self.logger.set_train_epoch()
        for metric in self.train_metrics:
            metric.set_epoch_datapoint()

    @rule(batch_size=st.integers(0, 2), data=st.data())
    def set_test_batch(self, batch_size: int, data: SearchStrategy):
        self.test_batch_set = True

        batch = {
            metric.name: data.draw(
                st.floats(-1, 1) | st.floats(-1, 1).map(np.array), label=metric.name
            )
            for metric in self.test_metrics
        }
        self.logger.set_test_batch(metrics=batch, batch_size=batch_size)

        for metric in self.test_metrics:
            metric.add_datapoint(batch[metric.name], weighting=batch_size)

    @rule()
    def set_test_epoch(self):
        self.logger.set_test_epoch()

        # align test-epoch with train domain
        for metric in self.test_metrics:
            if metric.name in {m.name for m in self.train_metrics}:
                x = self.num_train_batch if self.num_train_batch > 0 else None
            else:
                x = None
            metric.set_epoch_datapoint(x)

    @precondition(lambda self: self.train_batch_set)
    @invariant()
    def compare_train_metrics(self):
        logged_metrics = self.logger.train_metrics
        expected_metrics = dict(
            (metric.name, metric.to_dict()) for metric in self.train_metrics
        )
        compare_all_metrics(logged_metrics, expected_metrics)

    @precondition(lambda self: self.test_batch_set)
    @invariant()
    def compare_test_metrics(self):
        logged_metrics = self.logger.test_metrics
        expected_metrics = dict(
            (metric.name, metric.to_dict()) for metric in self.test_metrics
        )
        compare_all_metrics(logged_metrics, expected_metrics)

    @invariant()
    def check_from_dict_roundtrip(self):
        logger_dict = self.logger.to_dict()
        new_logger = LiveLogger.from_dict(logger_dict)

        for attr in [
            "_num_train_epoch",
            "_num_train_batch",
            "_num_test_epoch",
            "_num_test_batch",
        ]:
            desired = getattr(self.logger, attr)
            actual = getattr(new_logger, attr)
            assert actual == desired, (
                "`LiveLogger.from_dict` did not round-trip successfully.\n"
                "logger.{} does not match.\nGot: {}\nExpected: {}"
                "".format(attr, actual, desired)
            )

        compare_all_metrics(self.logger.train_metrics, new_logger.train_metrics)
        compare_all_metrics(self.logger.test_metrics, new_logger.test_metrics)

    @rule(save_via_live_object=st.booleans())
    def check_metric_io(self, save_via_live_object: bool):
        """Ensure that saving/loading metrics always produces results that are
        self-consistent with the logger"""
        from uuid import uuid4

        filename = str(uuid4())
        if save_via_live_object:
            save_metrics(filename, liveplot=self.logger)
        else:
            save_metrics(
                filename,
                train_metrics=self.logger.train_metrics,
                test_metrics=self.logger.test_metrics,
            )
        io_train_metrics, io_test_metrics = load_metrics(filename)

        compare_all_metrics(io_train_metrics, self.logger.train_metrics)
        compare_all_metrics(io_test_metrics, self.logger.test_metrics)
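# Hypothesis exposes a rule-based state machine to pytest through its
# auto-generated ``TestCase`` attribute. A minimal wiring sketch (the test-class
# name below is illustrative, not prescribed by this module):
TestLiveLogger = LiveLoggerStateMachine.TestCase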
def test_logger_init_compatible_with_plotter(args: tuple, kwargs: dict):
    LiveLogger(*args, **kwargs)
def test_fuzz_set_epoch(logger: LiveLogger):
    logger.set_train_epoch()
    logger.set_test_epoch()
class LivePlotStateMachine(RuleBasedStateMachine):
    """Provides basic rules for exercising essential aspects of LivePlot"""

    def __init__(self):
        super().__init__()
        self.train_metric_names = []
        self.test_metric_names = []
        self.train_batch_set = False
        self.test_batch_set = False
        self.plotter = None  # type: LivePlot
        self.logger = None  # type: LiveLogger

    @initialize(
        num_train_metrics=st.integers(0, 3),
        num_test_metrics=st.integers(0, 3),
        data=st.data(),
    )
    def choose_metrics(
        self, num_train_metrics: int, num_test_metrics: int, data: st.SearchStrategy
    ):
        assume(num_train_metrics + num_test_metrics > 0)
        self.train_metric_names = ["metric-a", "metric-b", "metric-c"][:num_train_metrics]
        self.test_metric_names = ["metric-a", "metric-b", "metric-c"][:num_test_metrics]

        train_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_train_metrics,
                max_size=num_train_metrics,
            ),
            label="train_colors",
        )
        test_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_test_metrics,
                max_size=num_test_metrics,
            ),
            label="test_colors",
        )

        metrics = OrderedDict(
            (n, dict())
            for n in sorted(set(self.train_metric_names + self.test_metric_names))
        )
        for metric, color in zip(self.train_metric_names, train_colors):
            metrics[metric]["train"] = color
        for metric, color in zip(self.test_metric_names, test_colors):
            metrics[metric]["test"] = color

        self.plotter = LivePlot(
            metrics,
            max_fraction_spent_plotting=data.draw(
                st.floats(0, 1), label="max_fraction_spent_plotting"
            ),
            last_n_batches=data.draw(
                st.none() | st.integers(1, 100), label="last_n_batches"
            ),
        )
        self.logger = LiveLogger()

        note("Train metric names: {}".format(self.train_metric_names))
        note("Test metric names: {}".format(self.test_metric_names))

    @rule(batch_size=st.integers(0, 2), data=st.data(), plot=st.booleans())
    def set_train_batch(self, batch_size: int, data: SearchStrategy, plot: bool):
        self.train_batch_set = True

        batch = {
            name: data.draw(st.floats(-1, 1), label=name)
            for name in self.train_metric_names
        }
        self.logger.set_train_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_train_batch(metrics=batch, batch_size=batch_size, plot=plot)

    @rule()
    def set_train_epoch(self):
        self.logger.set_train_epoch()
        self.plotter.set_train_epoch()

    @rule(batch_size=st.integers(0, 2), data=st.data())
    def set_test_batch(self, batch_size: int, data: SearchStrategy):
        self.test_batch_set = True

        batch = {
            name: data.draw(st.floats(-1, 1), label=name)
            for name in self.test_metric_names
        }
        self.logger.set_test_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_test_batch(metrics=batch, batch_size=batch_size)

    @rule()
    def set_test_epoch(self):
        self.logger.set_test_epoch()
        self.plotter.set_test_epoch()

    def teardown(self):
        plt.close("all")
        super().teardown()
def test_logger_xarray_validate_inputs(logger: LiveLogger):
    with pytest.raises(ValueError):
        logger.to_xarray("traintest")