예제 #1
0
 def __init__(self):
     """Set up fresh state for one run of the state machine."""
     super().__init__()
     # Shadow copies of the expected metric state, maintained in parallel
     # with `self.logger` so the two representations can be compared.
     self.train_metrics = []  # type: List[LiveMetric]
     self.test_metrics = []  # type: List[LiveMetric]
     # The object under test.
     self.logger = LiveLogger()
     # Flags recording whether any batch-level data has been logged yet;
     # presumably used to gate comparison invariants — confirm against the rules.
     self.train_batch_set = False
     self.test_batch_set = False
     # Running count of train batches logged.
     self.num_train_batch = 0
예제 #2
0
def test_set_batch_missing_metric(logger: LiveLogger, data: st.DataObject):
    """Exercise batch-logging with a metric name the logger has not seen."""
    known_names = set(list(logger.train_metrics) + list(logger.test_metrics))

    # Draw an arbitrary name guaranteed not to collide with a known metric.
    unknown_name = data.draw(st.text().filter(lambda name: name not in known_names))
    unseen_batch = {unknown_name: 2}

    logger.set_train_batch(metrics=unseen_batch, batch_size=1)
    logger.set_test_batch(metrics=unseen_batch, batch_size=1)
예제 #3
0
def test_logger_xarray(logger: LiveLogger):
    """Check batch- and epoch-level xarray exports for both splits."""
    # Export both splits first (same call order as checking them one by one).
    exports = {split: logger.to_xarray(split) for split in ("train", "test")}

    for split, metrics in (("train", logger.train_metrics),
                           ("test", logger.test_metrics)):
        batch_arr, epoch_arr = exports[split]
        check_batch_xarray(metrics, batch_arr)
        check_epoch_xarray(metrics, epoch_arr)
예제 #4
0
    def choose_metrics(self, num_train_metrics: int, num_test_metrics: int,
                       data: st.SearchStrategy):
        """Draw metric names and colors; construct the plotter/logger pair.

        Requires at least one metric overall; draws one color per metric for
        each split, then builds a LivePlot configured with those colors and
        drawn plotting parameters, plus a fresh LiveLogger.
        """
        assume(num_train_metrics + num_test_metrics > 0)
        candidate_names = ["metric-a", "metric-b", "metric-c"]
        self.train_metric_names = candidate_names[:num_train_metrics]
        self.test_metric_names = candidate_names[:num_test_metrics]

        def draw_colors(count, label):
            # One matplotlib color per metric of the given split.
            return data.draw(
                st.lists(cst.matplotlib_colors(), min_size=count, max_size=count),
                label=label,
            )

        colors_for_train = draw_colors(num_train_metrics, "train_colors")
        colors_for_test = draw_colors(num_test_metrics, "test_colors")

        # Map each unique metric name to its per-split color specification.
        unique_names = sorted(set(self.train_metric_names + self.test_metric_names))
        metrics = OrderedDict((name, dict()) for name in unique_names)

        for name, color in zip(self.train_metric_names, colors_for_train):
            metrics[name]["train"] = color
        for name, color in zip(self.test_metric_names, colors_for_test):
            metrics[name]["test"] = color

        self.plotter = LivePlot(
            metrics,
            max_fraction_spent_plotting=data.draw(
                st.floats(0, 1), label="max_fraction_spent_plotting"
            ),
            last_n_batches=data.draw(
                st.none() | st.integers(1, 100), label="last_n_batches"
            ),
        )
        self.logger = LiveLogger()

        note("Train metric names: {}".format(self.train_metric_names))
        note("Test metric names: {}".format(self.test_metric_names))
예제 #5
0
def test_concat_experiments(logger: LiveLogger, num_exps: int,
                            data: st.DataObject):
    """Concatenated per-experiment xarrays keep coords, variables, and values."""
    metric_names = list(logger.train_metrics)
    assume(len(metric_names) > 0)

    def draw_batch():
        # One drawn float per known train metric.
        return {name: data.draw(st.floats(-1e6, 1e6)) for name in metric_names}

    # Log num_exps batches, snapshotting the batch-level export after each.
    batch_xarrays = []
    for _ in range(num_exps):
        logger.set_train_batch(draw_batch(), batch_size=1)
        batch_xarrays.append(logger.to_xarray("train")[0])

    out = concat_experiments(*batch_xarrays)
    assert list(out.coords["experiment"]) == list(range(num_exps))
    assert list(out.data_vars) == list(metric_names)

    # Each experiment slice (with NaN padding dropped) must match its snapshot.
    for exp_id in range(num_exps):
        single = out.isel(experiment=exp_id).drop_vars(names=["experiment"])
        for name in metric_names:
            assert_equal(
                batch_xarrays[exp_id].to_array(name),
                single.to_array(name).dropna(dim="iterations"),
            )
예제 #6
0
def test_trivial_case():
    """Sanity-check LiveLogger: two unit-weight train batches, then one epoch."""
    logger = LiveLogger()
    for value in (1.0, 3.0):
        logger.set_train_batch({"a": value}, batch_size=1)
    logger.set_train_epoch()

    recorded = logger.train_metrics["a"]
    assert_array_equal(recorded["batch_data"], np.array([1.0, 3.0]))
    assert_array_equal(recorded["epoch_domain"], np.array([2]))
    # Epoch value is the batch-size-weighted mean of the two batches.
    assert_array_equal(recorded["epoch_data"], np.array([1.0 / 2.0 + 3.0 / 2.0]))
예제 #7
0
    def check_from_dict_roundtrip(self):
        """Verify that to_dict/from_dict round-trips the logger's full state."""
        reconstructed = LiveLogger.from_dict(self.logger.to_dict())

        counter_attrs = (
            "_num_train_epoch",
            "_num_train_batch",
            "_num_test_epoch",
            "_num_test_batch",
        )
        for attr in counter_attrs:
            expected = getattr(self.logger, attr)
            observed = getattr(reconstructed, attr)
            assert observed == expected, (
                "`LiveLogger.from_dict` did not round-trip successfully.\n"
                "logger.{} does not match.\nGot: {}\nExpected: {}"
                "".format(attr, observed, expected)
            )

        compare_all_metrics(self.logger.train_metrics, reconstructed.train_metrics)
        compare_all_metrics(self.logger.test_metrics, reconstructed.test_metrics)
예제 #8
0
File: utils.py  Project: rsokl/noggin
def plot_logger(
    logger: LiveLogger,
    plot_batches: bool = True,
    last_n_batches: Optional[int] = None,
    colors: Optional[Dict[str, Union[ValidColor, Dict[str,
                                                      ValidColor]]]] = None,
    nrows: Optional[int] = None,
    ncols: int = 1,
    figsize: Optional[Tuple[int, int]] = None,
) -> Tuple[LivePlot, Figure, Union[Axes, np.ndarray]]:
    """Plots the data recorded by a :class:`~noggin.logger.LiveLogger` instance.

    Converts the logger to an instance of :class:`~noggin.plotter.LivePlot`.

    Parameters
    ----------
    logger : LiveLogger
        The logger whose train/test-split batch/epoch-level data will be plotted.

    plot_batches : bool, optional (default=True)
        If ``True`` include batch-level data in plot.

    last_n_batches : Optional[int]
        The maximum number of batches to be plotted at any given time.
        If ``None``, all of the data will be plotted.

    colors : Optional[Dict[str, Union[ValidColor, Dict[str, ValidColor]]]]
        ``colors`` can be a dictionary, specifying the colors used to plot
        the metrics. Two mappings are valid:
            - '<metric-name>' -> color-value  (specifies train-metric color only)
            - '<metric-name>' -> {'train'/'test' : color-value}
        If ``None``, default colors are used in the plot.

    nrows : Optional[int]
        Number of rows of the subplot grid. Metrics are added in
        row-major order to fill the grid.

    ncols : int, optional, default: 1
        Number of columns of the subplot grid. Metrics are added in
        row-major order to fill the grid.

    figsize : Optional[Tuple[int, int]]
        Specifies the width and height, respectively, of the figure.

    Raises
    ------
    TypeError
        If ``logger`` is not a ``LiveLogger`` instance.

    Returns
    -------
    Tuple[LivePlot, Figure, Union[Axes, np.ndarray]]
        The resulting plotter, matplotlib-figure, and axis (or array of axes)
    """

    if not isinstance(logger, LiveLogger):
        # Report the offending *type*; formatting the instance itself could
        # dump an arbitrarily large repr into the error message.
        raise TypeError(
            "`logger` must be an instance of `noggin.LiveLogger`, got {}".
            format(type(logger)))

    # Sorted union of the metric names recorded across both splits.
    metrics = sorted(
        set(
            list(logger.train_metrics.keys()) +
            list(logger.test_metrics.keys())))

    # max_fraction_spent_plotting=0.0: plotting is triggered explicitly below.
    plotter = LivePlot(
        metrics,
        max_fraction_spent_plotting=0.0,
        last_n_batches=last_n_batches,
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
    )

    if colors is not None:
        plotter.metric_colors = colors

    # Transfer the logger's recorded data into the plotter via dict round-trip:
    # the logger's state overwrites the matching keys of the plotter's dict.
    plotter_dict = plotter.to_dict()
    plotter_dict.update(logger.to_dict())
    plotter = LivePlot.from_dict(plotter_dict)

    plotter.plot(plot_batches=plot_batches)
    fig, ax = plotter.plot_objects
    return plotter, fig, ax
예제 #9
0
def test_loggers(logger: LiveLogger):
    """Any logger produced by loggers() must survive a to_dict/from_dict round-trip."""
    serialized = logger.to_dict()
    LiveLogger.from_dict(serialized)
예제 #10
0
class LiveLoggerStateMachine(RuleBasedStateMachine):
    """ Ensures that exercising the api of LiveLogger produces
    results that are consistent with a simplistic implementation"""

    def __init__(self):
        super().__init__()
        # Shadow records of what the logger *should* contain; each rule below
        # updates these in lockstep with `self.logger`.
        self.train_metrics = []  # type: List[LiveMetric]
        self.test_metrics = []  # type: List[LiveMetric]
        # The object under test.
        self.logger = LiveLogger()
        # Set once a batch is logged; gate the metric-comparison invariants.
        self.train_batch_set = False
        self.test_batch_set = False
        # Count of train batches logged; used to align test-epoch domains.
        self.num_train_batch = 0

    @initialize(num_train_metrics=st.integers(0, 3), num_test_metrics=st.integers(0, 3))
    def choose_metrics(self, num_train_metrics: int, num_test_metrics: int):
        """Pick 0-3 train and 0-3 test metric names; build shadow LiveMetrics."""
        train_metric_names = ["metric-a", "metric-b", "metric-c"][:num_train_metrics]
        for name in train_metric_names:
            self.train_metrics.append(LiveMetric(name=name))

        test_metric_names = ["metric-a", "metric-b", "metric-c"][:num_test_metrics]
        for name in test_metric_names:
            self.test_metrics.append(LiveMetric(name=name))

        note("Train metric names: {}".format(train_metric_names))
        note("Test metric names: {}".format(test_metric_names))

    @rule()
    def get_repr(self):
        """ Ensure no side effect """
        repr(self.logger)

    @rule(batch_size=st.integers(0, 2), data=st.data())
    def set_train_batch(self, batch_size: int, data: SearchStrategy):
        """Log one train batch (float or 0-d array per metric) to both the
        logger and the shadow metrics."""
        # Count the batch only when there are train metrics to record.
        if self.train_metrics:
            self.num_train_batch += 1

        self.train_batch_set = True
        batch = {
            metric.name: data.draw(
                st.floats(-1, 1) | st.floats(-1, 1).map(np.array), label=metric.name
            )
            for metric in self.train_metrics
        }
        self.logger.set_train_batch(metrics=batch, batch_size=batch_size)

        # Mirror the same datapoints into the shadow metrics.
        for metric in self.train_metrics:
            metric.add_datapoint(batch[metric.name], weighting=batch_size)

    @rule()
    def set_train_epoch(self):
        """Mark a train epoch in both the logger and the shadow metrics."""
        self.logger.set_train_epoch()
        for metric in self.train_metrics:
            metric.set_epoch_datapoint()

    @rule(batch_size=st.integers(0, 2), data=st.data())
    def set_test_batch(self, batch_size: int, data: SearchStrategy):
        """Log one test batch (float or 0-d array per metric) to both the
        logger and the shadow metrics."""
        self.test_batch_set = True
        batch = {
            metric.name: data.draw(
                st.floats(-1, 1) | st.floats(-1, 1).map(np.array), label=metric.name
            )
            for metric in self.test_metrics
        }
        self.logger.set_test_batch(metrics=batch, batch_size=batch_size)

        for metric in self.test_metrics:
            metric.add_datapoint(batch[metric.name], weighting=batch_size)

    @rule()
    def set_test_epoch(self):
        """Mark a test epoch in both the logger and the shadow metrics."""
        self.logger.set_test_epoch()

        # align test-epoch with train domain: metrics shared with the train
        # split use the current train-batch count as their x-position
        for metric in self.test_metrics:
            if metric.name in {m.name for m in self.train_metrics}:
                x = self.num_train_batch if self.num_train_batch > 0 else None
            else:
                x = None
            metric.set_epoch_datapoint(x)

    @precondition(lambda self: self.train_batch_set)
    @invariant()
    def compare_train_metrics(self):
        """The logger's train metrics must match the shadow train metrics."""
        logged_metrics = self.logger.train_metrics
        expected_metrics = dict(
            (metric.name, metric.to_dict()) for metric in self.train_metrics
        )
        compare_all_metrics(logged_metrics, expected_metrics)

    @precondition(lambda self: self.test_batch_set)
    @invariant()
    def compare_test_metrics(self):
        """The logger's test metrics must match the shadow test metrics."""
        logged_metrics = self.logger.test_metrics
        expected_metrics = dict(
            (metric.name, metric.to_dict()) for metric in self.test_metrics
        )
        compare_all_metrics(logged_metrics, expected_metrics)

    @invariant()
    def check_from_dict_roundtrip(self):
        """`to_dict`/`from_dict` must round-trip the logger's full state."""
        logger_dict = self.logger.to_dict()
        new_logger = LiveLogger.from_dict(logger_dict)

        # Private batch/epoch counters must survive the round-trip.
        for attr in [
            "_num_train_epoch",
            "_num_train_batch",
            "_num_test_epoch",
            "_num_test_batch",
        ]:
            desired = getattr(self.logger, attr)
            actual = getattr(new_logger, attr)
            assert actual == desired, (
                "`LiveLogger.from_dict` did not round-trip successfully.\n"
                "logger.{} does not match.\nGot: {}\nExpected: {}"
                "".format(attr, actual, desired)
            )

        compare_all_metrics(self.logger.train_metrics, new_logger.train_metrics)
        compare_all_metrics(self.logger.test_metrics, new_logger.test_metrics)

    @rule(save_via_live_object=st.booleans())
    def check_metric_io(self, save_via_live_object: bool):
        """Ensure the saving/loading metrics always produces self-consistent
        results with the logger"""
        from uuid import uuid4

        # Random filename avoids collisions between runs/examples.
        filename = str(uuid4())
        if save_via_live_object:
            save_metrics(filename, liveplot=self.logger)
        else:
            save_metrics(
                filename,
                train_metrics=self.logger.train_metrics,
                test_metrics=self.logger.test_metrics,
            )
        io_train_metrics, io_test_metrics = load_metrics(filename)

        compare_all_metrics(io_train_metrics, self.logger.train_metrics)
        compare_all_metrics(io_test_metrics, self.logger.test_metrics)
예제 #11
0
def test_logger_init_compatible_with_plotter(args: tuple, kwargs: dict):
    """Constructing a LiveLogger with plotter-style arguments should not raise."""
    _ = LiveLogger(*args, **kwargs)
예제 #12
0
def test_fuzz_set_epoch(logger: LiveLogger):
    """Setting a train epoch then a test epoch on an arbitrary logger should not raise."""
    for set_epoch in (logger.set_train_epoch, logger.set_test_epoch):
        set_epoch()
예제 #13
0
class LivePlotStateMachine(RuleBasedStateMachine):
    """Provides basic rules for exercising essential aspects of LivePlot"""
    def __init__(self):
        super().__init__()
        # Populated by `choose_metrics` during initialization.
        self.train_metric_names = []
        self.test_metric_names = []
        # Record whether batch-level data has been supplied for either split.
        self.train_batch_set = False
        self.test_batch_set = False
        # Both objects are constructed in `choose_metrics`.
        self.plotter = None  # type: LivePlot
        self.logger = None  # type: LiveLogger

    @initialize(
        num_train_metrics=st.integers(0, 3),
        num_test_metrics=st.integers(0, 3),
        data=st.data(),
    )
    def choose_metrics(self, num_train_metrics: int, num_test_metrics: int,
                       data: st.SearchStrategy):
        """Draw metric names and colors; construct the plotter/logger pair."""
        # At least one metric must exist overall.
        assume(num_train_metrics + num_test_metrics > 0)
        self.train_metric_names = ["metric-a", "metric-b",
                                   "metric-c"][:num_train_metrics]

        self.test_metric_names = ["metric-a", "metric-b",
                                  "metric-c"][:num_test_metrics]
        # One matplotlib color per metric, per split.
        train_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_train_metrics,
                max_size=num_train_metrics,
            ),
            label="train_colors",
        )

        test_colors = data.draw(
            st.lists(
                cst.matplotlib_colors(),
                min_size=num_test_metrics,
                max_size=num_test_metrics,
            ),
            label="test_colors",
        )

        # Map each unique metric name to its train/test color specification.
        metrics = OrderedDict((n, dict()) for n in sorted(
            set(self.train_metric_names + self.test_metric_names)))

        for metric, color in zip(self.train_metric_names, train_colors):
            metrics[metric]["train"] = color

        for metric, color in zip(self.test_metric_names, test_colors):
            metrics[metric]["test"] = color

        self.plotter = LivePlot(
            metrics,
            max_fraction_spent_plotting=data.draw(
                st.floats(0, 1), label="max_fraction_spent_plotting"),
            last_n_batches=data.draw(st.none() | st.integers(1, 100),
                                     label="last_n_batches"),
        )
        self.logger = LiveLogger()

        note("Train metric names: {}".format(self.train_metric_names))
        note("Test metric names: {}".format(self.test_metric_names))

    @rule(batch_size=st.integers(0, 2), data=st.data(), plot=st.booleans())
    def set_train_batch(self, batch_size: int, data: SearchStrategy,
                        plot: bool):
        """Log one drawn train batch to both the logger and the plotter."""
        self.train_batch_set = True

        batch = {
            name: data.draw(st.floats(-1, 1), label=name)
            for name in self.train_metric_names
        }
        self.logger.set_train_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_train_batch(metrics=batch,
                                     batch_size=batch_size,
                                     plot=plot)

    @rule()
    def set_train_epoch(self):
        """Mark a train epoch in both the logger and the plotter."""
        self.logger.set_train_epoch()
        self.plotter.set_train_epoch()

    @rule(batch_size=st.integers(0, 2), data=st.data())
    def set_test_batch(self, batch_size: int, data: SearchStrategy):
        """Log one drawn test batch to both the logger and the plotter."""
        self.test_batch_set = True

        batch = {
            name: data.draw(st.floats(-1, 1), label=name)
            for name in self.test_metric_names
        }
        self.logger.set_test_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_test_batch(metrics=batch, batch_size=batch_size)

    @rule()
    def set_test_epoch(self):
        """Mark a test epoch in both the logger and the plotter."""
        self.logger.set_test_epoch()
        self.plotter.set_test_epoch()

    def teardown(self):
        """Close all matplotlib figures opened during the run."""
        plt.close("all")
        super().teardown()
예제 #14
0
def test_logger_xarray_validate_inputs(logger: LiveLogger):
    """`to_xarray` should raise ValueError for an unrecognized split name."""
    bad_split = "traintest"
    with pytest.raises(ValueError):
        logger.to_xarray(bad_split)