예제 #1
0
def test_trivial_case():
    """Sanity-check LivePlot on two single-sample train batches and one epoch.

    Two batches (values 1.0 and 3.0, batch size 1 each) are recorded without
    plotting, the epoch is closed, and the stored batch values, epoch
    position, and epoch value for metric ``"a"`` are checked.
    """
    plotter = LivePlot("a")
    for value in (1.0, 3.0):
        plotter.set_train_batch(dict(a=value), batch_size=1, plot=False)
    plotter.set_train_epoch()

    metric_a = plotter.train_metrics["a"]
    assert_array_equal(metric_a["batch_data"], np.array([1.0, 3.0]))
    # The epoch closed after the second recorded batch.
    assert_array_equal(metric_a["epoch_domain"], np.array([2]))
    # Expected epoch value is the mean of the two batch values.
    assert_array_equal(metric_a["epoch_data"],
                       np.array([1.0 / 2.0 + 3.0 / 2.0]))
예제 #2
0
class LivePlotStateMachine(RuleBasedStateMachine):
    """Hypothesis state machine that drives a LivePlot and a LiveLogger in
    lockstep through randomly interleaved train/test batch and epoch calls.
    """

    def __init__(self):
        super().__init__()
        # Both objects are created by ``choose_metrics`` (the initializer).
        self.plotter = None  # type: LivePlot
        self.logger = None  # type: LiveLogger
        self.train_metric_names = []
        self.test_metric_names = []
        # Set to True once a batch of the corresponding kind has been fed.
        self.train_batch_set = False
        self.test_batch_set = False

    @staticmethod
    def _draw_batch(data, names):
        """Draw one float in [-1, 1] per metric name, labelled by that name.

        Names are visited in the given order, so the draw sequence matches
        the order of ``names``.
        """
        batch = {}
        for name in names:
            batch[name] = data.draw(st.floats(-1, 1), label=name)
        return batch

    @initialize(
        num_train_metrics=st.integers(0, 3),
        num_test_metrics=st.integers(0, 3),
        data=st.data(),
    )
    def choose_metrics(self, num_train_metrics: int, num_test_metrics: int,
                       data: st.SearchStrategy):
        """Pick train/test metric names and colours, then build the
        LivePlot (with drawn plotting settings) and a fresh LiveLogger."""
        # Reject the degenerate case in which no metric is tracked at all.
        assume(num_train_metrics + num_test_metrics > 0)

        candidates = ["metric-a", "metric-b", "metric-c"]
        self.train_metric_names = candidates[:num_train_metrics]
        self.test_metric_names = candidates[:num_test_metrics]

        def draw_colors(count, label):
            # Exactly ``count`` colours: one per metric of the given kind.
            return data.draw(
                st.lists(cst.matplotlib_colors(),
                         min_size=count, max_size=count),
                label=label,
            )

        # Drawn in a fixed order: train colours first, then test colours.
        train_colors = draw_colors(num_train_metrics, "train_colors")
        test_colors = draw_colors(num_test_metrics, "test_colors")

        unique_names = sorted(
            set(self.train_metric_names) | set(self.test_metric_names))
        metrics = OrderedDict((name, {}) for name in unique_names)
        for kind, names, colors in (
                ("train", self.train_metric_names, train_colors),
                ("test", self.test_metric_names, test_colors),
        ):
            for name, color in zip(names, colors):
                metrics[name][kind] = color

        max_fraction = data.draw(st.floats(0, 1),
                                 label="max_fraction_spent_plotting")
        last_n = data.draw(st.none() | st.integers(1, 100),
                           label="last_n_batches")
        self.plotter = LivePlot(
            metrics,
            max_fraction_spent_plotting=max_fraction,
            last_n_batches=last_n,
        )
        self.logger = LiveLogger()

        note("Train metric names: {}".format(self.train_metric_names))
        note("Test metric names: {}".format(self.test_metric_names))

    @rule(batch_size=st.integers(0, 2), data=st.data(), plot=st.booleans())
    def set_train_batch(self, batch_size: int, data: SearchStrategy,
                        plot: bool):
        """Feed one random train batch to both the logger and the plotter."""
        self.train_batch_set = True
        batch = self._draw_batch(data, self.train_metric_names)
        self.logger.set_train_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_train_batch(metrics=batch, batch_size=batch_size,
                                     plot=plot)

    @rule()
    def set_train_epoch(self):
        """Close the current train epoch on the logger, then the plotter."""
        self.logger.set_train_epoch()
        self.plotter.set_train_epoch()

    @rule(batch_size=st.integers(0, 2), data=st.data())
    def set_test_batch(self, batch_size: int, data: SearchStrategy):
        """Feed one random test batch to both the logger and the plotter."""
        self.test_batch_set = True
        batch = self._draw_batch(data, self.test_metric_names)
        self.logger.set_test_batch(metrics=batch, batch_size=batch_size)
        self.plotter.set_test_batch(metrics=batch, batch_size=batch_size)

    @rule()
    def set_test_epoch(self):
        """Close the current test epoch on the logger, then the plotter."""
        self.logger.set_test_epoch()
        self.plotter.set_test_epoch()

    def teardown(self):
        """Close every open pyplot figure, then run the base teardown."""
        plt.close("all")
        super().teardown()