Example #1
0
    def from_dict(cls, plotter_dict):
        """Restore a plotter from a dictionary describing its state.

        This is the inverse of :func:`~noggin.plotter.LivePlot.to_dict`

        Parameters
        ----------
        plotter_dict : Dict[str, Any]
            The dictionary storing the state of the plotter to be
            restored.

        Returns
        -------
        noggin.LivePlot
            The restored plotter.

        Notes
        -----
        This is a class-method, the syntax for invoking it is:

        >>> loaded_plotter = LivePlot.from_dict(plotter_dict)

        To restore your plot from the loaded plotter, call:

        >>> loaded_plotter.plot()
        """
        new = cls(
            metrics=plotter_dict["metric_names"],
            max_fraction_spent_plotting=plotter_dict["max_fraction_spent_plotting"],
            last_n_batches=plotter_dict["last_n_batches"],
        )

        # Each serialized metric is rebuilt into a LiveMetric via its own
        # `from_dict`.
        new._train_metrics.update(
            (key, LiveMetric.from_dict(metric))
            for key, metric in plotter_dict["train_metrics"].items()
        )

        new._test_metrics.update(
            (key, LiveMetric.from_dict(metric))
            for key, metric in plotter_dict["test_metrics"].items()
        )

        # Restore the counters `_num_train_batch`, `_num_train_epoch`,
        # `_num_test_batch` and `_num_test_epoch` from their unprefixed keys.
        for train_mode, stat_mode in product(["train", "test"], ["batch", "epoch"]):
            item = "num_{}_{}".format(train_mode, stat_mode)
            setattr(new, "_" + item, plotter_dict[item])

        new._pltkwargs = plotter_dict["pltkwargs"]

        # Colour mappings are stored as defaultdicts so that a metric with no
        # recorded colour looks up as `None`.
        new._train_colors = defaultdict(lambda: None, plotter_dict["train_colors"])
        new._test_colors = defaultdict(lambda: None, plotter_dict["test_colors"])
        return new
Example #2
0
    def choose_metrics(self, num_train_metrics: int, num_test_metrics: int):
        """Append freshly constructed ``LiveMetric`` objects to the train/test lists.

        The train list receives the first ``num_train_metrics`` names of
        ``["metric-a", "metric-b", "metric-c"]`` and the test list the first
        ``num_test_metrics``; both selections are then reported via ``note``.
        """
        candidate_names = ["metric-a", "metric-b", "metric-c"]

        train_metric_names = candidate_names[:num_train_metrics]
        for metric_name in train_metric_names:
            self.train_metrics.append(LiveMetric(name=metric_name))

        test_metric_names = candidate_names[:num_test_metrics]
        for metric_name in test_metric_names:
            self.test_metrics.append(LiveMetric(name=metric_name))

        note("Train metric names: {}".format(train_metric_names))
        note("Test metric names: {}".format(test_metric_names))
Example #3
0
def test_from_dict_input_validation2(bad_input: dict, data: st.DataObject):
    """Ensure ``LiveMetric.from_dict`` rejects a metric dict corrupted by ``bad_input``.

    Values in ``bad_input`` that are Hypothesis search strategies are first
    drawn from (labelled with their key); the resulting entries then overwrite
    those of a valid metric dict taken from ``static_logger_dict``.
    """
    bad_input = {
        k: data.draw(v, label=k) if isinstance(v, st.SearchStrategy) else v
        for k, v in bad_input.items()
    }

    # Any valid metric dict serves as the base; use the first one available.
    metrics = next(iter(static_logger_dict.values()), None)
    assert metrics is not None, "static_logger_dict must contain at least one metric"

    input_dict = metrics.copy()
    input_dict.update(bad_input)

    assert input_dict

    with pytest.raises((ValueError, TypeError)):
        LiveMetric.from_dict(input_dict)
Example #4
0
    def dict_roundtrip(self):
        """Ensure `from_dict(to_dict())` round trip is successful

        Each listed public and private attribute of the restored metric must
        match the original in both exact type and value.
        """
        metrics_dict = self.livemetric.to_dict()
        new_metrics = LiveMetric.from_dict(metrics_dict=metrics_dict)

        for attr in [
            "name",
            "batch_data",
            "epoch_data",
            "epoch_domain",
            "_batch_data",
            "_epoch_data",
            "_epoch_domain",
            "_running_weighted_sum",
            "_total_weighting",
            "_cnt_since_epoch",
        ]:
            desired = getattr(self.livemetric, attr)
            actual = getattr(new_metrics, attr)
            # Exact type identity (not `isinstance`) is required: any change of
            # type, e.g. a list restored in place of an ndarray, is a lossy
            # round trip.
            assert type(actual) is type(desired), (
                "livemetric.{} changed type during round trip.\n"
                "Got: {}\nExpected: {}".format(attr, type(actual), type(desired))
            )
            assert_array_equal(
                actual,
                desired,
                err_msg="`LiveMetric.from_dict` did not round-trip successfully.\n"
                "livemetric.{} does not match.\nGot: {}\nExpected: {}"
                "".format(attr, actual, desired),
            )
Example #5
0
def test_trivial_case():
    """Sanity-check a LiveMetric holding two equally-weighted batches and one epoch."""
    metric = LiveMetric("a")
    for value in (1.0, 3.0):
        metric.add_datapoint(value, weighting=1.0)
    metric.set_epoch_datapoint(99)

    # Both batches carry weight 1.0, so the epoch value is their plain mean.
    expected_epoch_value = 1.0 / 2.0 + 3.0 / 2.0

    assert_array_equal(metric.batch_domain, np.array([1, 2]))
    assert_array_equal(metric.batch_data, np.array([1.0, 3.0]))
    assert_array_equal(metric.epoch_domain, np.array([99]))
    assert_array_equal(metric.epoch_data, np.array([expected_epoch_value]))

    # `to_dict` must store each of these arrays under the matching key.
    metric_dict = metric.to_dict()
    for attr_name in ("batch_data", "epoch_data", "epoch_domain"):
        assert_array_equal(
            metric_dict[attr_name],
            getattr(metric, attr_name),
            err_msg=attr_name + " does not map to the correct value in the metric-dict",
        )
Example #6
0
class LiveMetricChecker(RuleBasedStateMachine):
    """ Ensures that exercising the api of LiveMetric produces
    results that are consistent with a simplistic implementation"""

    def __init__(self):
        super().__init__()

        # Reference ("simplistic") model of the metric's state.
        self.batch_data = []
        # Weightings of the batch datapoints added since the last epoch.
        self._weights = []
        self.epoch_data = []
        self.epoch_domain = []
        self.livemetric = None  # type: LiveMetric
        self.name = None  # type: str

    @initialize(name=st.sampled_from(["a", "b", "c"]))
    def init_metric(self, name: str):
        self.livemetric = LiveMetric(name)
        self.name = name

    @rule(value=st.floats(-1e6, 1e6), weighting=st.one_of(st.none(), st.floats(0, 2)))
    def add_datapoint(self, value: float, weighting: Optional[float]):
        """Add a batch datapoint, exercising both an omitted and an explicit weighting."""
        if weighting is not None:
            self.livemetric.add_datapoint(value=value, weighting=weighting)
        else:
            self.livemetric.add_datapoint(value=value)
        self.batch_data.append(value)
        # The reference model treats an omitted weighting as 1.0.
        self._weights.append(weighting if weighting is not None else 1.0)

    @rule()
    def set_epoch_datapoint(self):
        """Close out an epoch on the metric and mirror it in the reference model."""
        self.livemetric.set_epoch_datapoint()

        # The reference model records an epoch only if batches were added
        # since the previous epoch.
        if self._weights:
            # This epoch's datapoints are the most recently added ones.
            batch_dat = np.array(self.batch_data)[-len(self._weights) :]
            weights = np.array(self._weights) / sum(self._weights)
            # All-zero weightings give 0/0 -> nan; treat those as weight 0.
            weights = np.nan_to_num(weights)
            epoch_mean = batch_dat @ weights
            self.epoch_data.append(epoch_mean)
            # The epoch is placed at the total number of batches seen so far.
            self.epoch_domain.append(len(self.batch_data))
            self._weights = []

    @rule()
    def show_repr(self):
        """ Ensure no side effects of calling `repr()`"""
        repr(self.livemetric)

    @precondition(lambda self: self.livemetric is not None)
    @invariant()
    def dict_roundtrip(self):
        """Ensure `from_dict(to_dict())` round trip is successful"""
        metrics_dict = self.livemetric.to_dict()
        new_metrics = LiveMetric.from_dict(metrics_dict=metrics_dict)

        for attr in [
            "name",
            "batch_data",
            "epoch_data",
            "epoch_domain",
            "_batch_data",
            "_epoch_data",
            "_epoch_domain",
            "_running_weighted_sum",
            "_total_weighting",
            "_cnt_since_epoch",
        ]:
            desired = getattr(self.livemetric, attr)
            actual = getattr(new_metrics, attr)
            # Exact type identity is required; any type change is a lossy round trip.
            assert type(actual) is type(desired), attr
            assert_array_equal(
                actual,
                desired,
                err_msg="`LiveMetric.from_dict` did not round-trip successfully.\n"
                "livemetric.{} does not match.\nGot: {}\nExpected: {}"
                "".format(attr, actual, desired),
            )

    def _check_repeated_access_is_consistent(self, attr: str):
        """Assert that reading ``livemetric.<attr>`` twice yields equal ndarrays."""
        first = getattr(self.livemetric, attr)
        second = getattr(self.livemetric, attr)
        assert isinstance(first, np.ndarray)
        assert isinstance(second, np.ndarray)
        assert_array_equal(
            first,
            second,
            err_msg="calling `LiveMetric.{}` two consecutive times produces "
            "different results".format(attr),
        )

    @rule()
    def check_batch_data_is_consistent(self):
        self._check_repeated_access_is_consistent("batch_data")

    @rule()
    def check_epoch_data_is_consistent(self):
        self._check_repeated_access_is_consistent("epoch_data")

    @rule()
    def check_epoch_domain_is_consistent(self):
        self._check_repeated_access_is_consistent("epoch_domain")

    @precondition(lambda self: self.livemetric is not None)
    @invariant()
    def compare(self):
        """Check the metric's exposed arrays against the reference model."""
        expected_batch_domain = np.arange(1, len(self.batch_data) + 1)
        expected_batch_data = np.asarray(self.batch_data)
        expected_epoch_data = np.asarray(self.epoch_data)
        expected_epoch_domain = np.asarray(self.epoch_domain)

        actual_batch_domain = self.livemetric.batch_domain
        actual_batch_data = self.livemetric.batch_data
        actual_epoch_data = self.livemetric.epoch_data
        actual_epoch_domain = self.livemetric.epoch_domain

        assert isinstance(actual_batch_domain, np.ndarray)
        assert isinstance(actual_batch_data, np.ndarray)
        assert isinstance(actual_epoch_data, np.ndarray)
        assert isinstance(actual_epoch_domain, np.ndarray)

        assert_array_equal(
            expected_batch_data,
            actual_batch_data,
            err_msg=err_msg(
                desired=expected_batch_data, actual=actual_batch_data, name="Batch Data"
            ),
        )
        assert_array_equal(
            expected_epoch_domain,
            actual_epoch_domain,
            err_msg=err_msg(
                desired=expected_epoch_domain,
                actual=actual_epoch_domain,
                name="Epoch Domain",
            ),
        )
        # Compare against the reference domain (previously the metric's
        # batch_domain was compared against itself).
        assert_array_equal(
            expected_batch_domain,
            actual_batch_domain,
            err_msg=err_msg(
                desired=expected_batch_domain,
                actual=actual_batch_domain,
                name="Batch Domain",
            ),
        )
        assert_allclose(
            actual=actual_epoch_data,
            desired=expected_epoch_data,
            err_msg=err_msg(
                desired=expected_epoch_data, actual=actual_epoch_data, name="Epoch Data"
            ),
        )

        assert self.livemetric.name == self.name
Example #7
0
def test_from_dict_input_validation(bad_input: st.SearchStrategy, data: st.DataObject):
    """``LiveMetric.from_dict`` must raise ValueError or TypeError on a drawn bad input."""
    drawn_input = data.draw(bad_input, label="bad_input")
    expected_errors = (ValueError, TypeError)
    with pytest.raises(expected_errors):
        LiveMetric.from_dict(drawn_input)
Example #8
0
def test_badname(name: Any):
    """Constructing a LiveMetric with an invalid ``name`` must raise TypeError."""
    expected_error = TypeError
    with pytest.raises(expected_error):
        LiveMetric(name)
Example #9
0
 def init_metric(self, name: str):
     """Create the LiveMetric under test and remember the name it was given."""
     metric = LiveMetric(name)
     self.livemetric = metric
     self.name = name
Example #10
0
def test_livemetrics(live_metrics: dict):
    """Ensure that each entry in live_metrics() can round-trip via LiveMetric"""
    # Only the serialized metric dicts are needed; their keys are unused.
    for metrics_dict in live_metrics.values():
        LiveMetric.from_dict(metrics_dict).to_dict()
Example #11
0
def test_metrics_dict(metrics_dict: dict):
    """A LiveMetric restored from ``metrics_dict`` can be serialized again via ``to_dict``."""
    restored = LiveMetric.from_dict(metrics_dict)
    restored.to_dict()