Example #1
0
def test_trivial_case():
    """Sanity-check `LiveMetric` on two equally-weighted datapoints and one epoch.

    Records the values 1.0 and 3.0 (each with weighting 1.0), sets a single
    epoch at domain value 99, and verifies the batch/epoch arrays as well as
    the corresponding entries of `to_dict()`.
    """
    values = (1.0, 3.0)
    epoch_marker = 99

    live = LiveMetric("a")
    for value in values:
        live.add_datapoint(value, weighting=1.0)
    live.set_epoch_datapoint(epoch_marker)

    # batch-level bookkeeping: 1-based batch domain, raw values as data
    assert_array_equal(live.batch_domain, np.array([1, 2]))
    assert_array_equal(live.batch_data, np.array(values))

    # epoch-level bookkeeping: with equal weights, the epoch value is the plain mean
    expected_epoch_mean = sum(values) / len(values)
    assert_array_equal(live.epoch_domain, np.array([epoch_marker]))
    assert_array_equal(live.epoch_data, np.array([expected_epoch_mean]))

    # the dict form must expose the same arrays as the attributes
    as_dict = live.to_dict()
    for attr in ("batch_data", "epoch_data", "epoch_domain"):
        assert_array_equal(
            as_dict[attr],
            getattr(live, attr),
            err_msg=attr + " does not map to the correct value in the metric-dict",
        )
Example #2
0
class LiveMetricChecker(RuleBasedStateMachine):
    """Ensures that exercising the api of LiveMetric produces
    results that are consistent with a simplistic implementation.

    The state machine keeps a plain-list reference model next to a real
    `LiveMetric`. Rules mutate both in lock-step; invariants then compare
    the metric's reported arrays against the model and verify that
    `to_dict` / `from_dict` round-trips the metric."""

    def __init__(self):
        super().__init__()

        # Reference model (plain Python lists)
        self.batch_data = []  # every value recorded so far
        self._weights = []  # weightings recorded since the last recorded epoch
        self.epoch_data = []  # expected weighted mean for each recorded epoch
        self.epoch_domain = []  # batch count at the moment each epoch was recorded
        self.livemetric = None  # type: LiveMetric
        self.name = None  # type: str

    @initialize(name=st.sampled_from(["a", "b", "c"]))
    def init_metric(self, name: str):
        """Create the `LiveMetric` under test and remember its name."""
        self.livemetric = LiveMetric(name)
        self.name = name

    @rule(value=st.floats(-1e6, 1e6), weighting=st.one_of(st.none(), st.floats(0, 2)))
    def add_datapoint(self, value: float, weighting: Optional[float]):
        """Record a datapoint in both the metric and the reference model.

        When `weighting` is None the metric is called without a weighting
        argument, and the model assumes a weight of 1.0 for that datapoint.
        """
        if weighting is not None:
            self.livemetric.add_datapoint(value=value, weighting=weighting)
        else:
            self.livemetric.add_datapoint(value=value)
        self.batch_data.append(value)
        self._weights.append(weighting if weighting is not None else 1.0)

    @rule()
    def set_epoch_datapoint(self):
        """Set an epoch on the metric and mirror it in the reference model.

        The model only records an epoch when at least one datapoint was
        added since the previously recorded epoch.
        """
        self.livemetric.set_epoch_datapoint()

        if self._weights:
            # only the datapoints added since the last epoch contribute
            batch_dat = np.array(self.batch_data)[-len(self._weights) :]
            weights = np.array(self._weights) / sum(self._weights)
            # weightings may all be zero, in which case the normalization
            # yields NaN; those entries are mapped to 0
            weights = np.nan_to_num(weights)
            epoch_mean = batch_dat @ weights
            self.epoch_data.append(epoch_mean)
            self.epoch_domain.append(len(self.batch_data))
            self._weights = []

    @rule()
    def show_repr(self):
        """ Ensure no side effects of calling `repr()`"""
        repr(self.livemetric)

    @precondition(lambda self: self.livemetric is not None)
    @invariant()
    def dict_roundtrip(self):
        """Ensure `from_dict(to_dict())` round trip is successful"""
        metrics_dict = self.livemetric.to_dict()
        new_metrics = LiveMetric.from_dict(metrics_dict=metrics_dict)

        for attr in [
            "name",
            "batch_data",
            "epoch_data",
            "epoch_domain",
            "_batch_data",
            "_epoch_data",
            "_epoch_domain",
            "_running_weighted_sum",
            "_total_weighting",
            "_cnt_since_epoch",
        ]:
            desired = getattr(self.livemetric, attr)
            actual = getattr(new_metrics, attr)
            # the round trip must preserve the exact type, not just the value
            assert type(actual) is type(desired), attr
            assert_array_equal(
                actual,
                desired,
                err_msg="`LiveMetric.from_dict` did not round-trip successfully.\n"
                "livemetric.{} does not match.\nGot: {}\nExpected: {}"
                "".format(attr, actual, desired),
            )

    def _check_property_is_stable(self, attr: str):
        """Assert that reading `self.livemetric.<attr>` twice gives equal ndarrays."""
        first = getattr(self.livemetric, attr)
        second = getattr(self.livemetric, attr)
        assert isinstance(first, np.ndarray)
        assert isinstance(second, np.ndarray)
        assert_array_equal(
            first,
            second,
            err_msg=(
                "calling `LiveMetric.{}` two "
                "consecutive times produces different "
                "results"
            ).format(attr),
        )

    @rule()
    def check_batch_data_is_consistent(self):
        """Reading `batch_data` twice in a row must give equal arrays."""
        self._check_property_is_stable("batch_data")

    @rule()
    def check_epoch_data_is_consistent(self):
        """Reading `epoch_data` twice in a row must give equal arrays."""
        self._check_property_is_stable("epoch_data")

    @rule()
    def check_epoch_domain_is_consistent(self):
        """Reading `epoch_domain` twice in a row must give equal arrays."""
        self._check_property_is_stable("epoch_domain")

    @precondition(lambda self: self.livemetric is not None)
    @invariant()
    def compare(self):
        """Compare every array reported by the metric against the reference model."""
        expected_batch_domain = np.arange(1, len(self.batch_data) + 1)
        expected_batch_data = np.asarray(self.batch_data)
        expected_epoch_data = np.asarray(self.epoch_data)
        expected_epoch_domain = np.asarray(self.epoch_domain)

        actual_batch_domain = self.livemetric.batch_domain
        actual_batch_data = self.livemetric.batch_data
        actual_epoch_data = self.livemetric.epoch_data
        actual_epoch_domain = self.livemetric.epoch_domain

        assert isinstance(actual_batch_domain, np.ndarray)
        assert isinstance(actual_batch_data, np.ndarray)
        assert isinstance(actual_epoch_data, np.ndarray)
        assert isinstance(actual_epoch_domain, np.ndarray)

        assert_array_equal(
            expected_batch_data,
            actual_batch_data,
            err_msg=err_msg(
                desired=expected_batch_data, actual=actual_batch_data, name="Batch Data"
            ),
        )
        assert_array_equal(
            expected_epoch_domain,
            actual_epoch_domain,
            err_msg=err_msg(
                desired=expected_epoch_domain,
                actual=actual_epoch_domain,
                name="Epoch Domain",
            ),
        )
        # compare against the model's expectation; comparing the metric's
        # batch domain with itself would make this check vacuous
        assert_array_equal(
            expected_batch_domain,
            actual_batch_domain,
            err_msg=err_msg(
                desired=expected_batch_domain,
                actual=actual_batch_domain,
                name="Batch Domain",
            ),
        )
        # epoch values are floating-point weighted means, hence the tolerance
        assert_allclose(
            actual=actual_epoch_data,
            desired=expected_epoch_data,
            err_msg=err_msg(
                desired=expected_epoch_data, actual=actual_epoch_data, name="Epoch Data"
            ),
        )

        assert self.livemetric.name == self.name