def from_dict(cls, plotter_dict):
    """Restore a plotter from a dictionary of its recorded state.

    This is the inverse of :func:`~noggin.plotter.LivePlot.to_dict`

    Parameters
    ----------
    plotter_dict : Dict[str, Any]
        The dictionary storing the state of the plotter to be restored.

    Returns
    -------
    noggin.LivePlot
        The restored plotter.

    Notes
    -----
    This is a class-method, the syntax for invoking it is:

    >>> loaded_plotter = LivePlot.from_dict(plotter_dict)

    To restore your plot from the loaded plotter, call:

    >>> loaded_plotter.plot()
    """
    new = cls(
        metrics=plotter_dict["metric_names"],
        max_fraction_spent_plotting=plotter_dict["max_fraction_spent_plotting"],
        last_n_batches=plotter_dict["last_n_batches"],
    )

    # Rebuild every recorded metric from its own serialized state.
    new._train_metrics.update(
        (key, LiveMetric.from_dict(metric))
        for key, metric in plotter_dict["train_metrics"].items()
    )
    new._test_metrics.update(
        (key, LiveMetric.from_dict(metric))
        for key, metric in plotter_dict["test_metrics"].items()
    )

    # Restore the counters _num_train_batch, _num_train_epoch,
    # _num_test_batch and _num_test_epoch.
    for train_mode, stat_mode in product(["train", "test"], ["batch", "epoch"]):
        item = "num_{}_{}".format(train_mode, stat_mode)
        setattr(new, "_" + item, plotter_dict[item])

    new._pltkwargs = plotter_dict["pltkwargs"]

    # Colors are stored as plain dicts; wrap them so that looking up a metric
    # without an assigned color yields ``None`` instead of raising KeyError.
    train_colors = defaultdict(lambda: None)
    train_colors.update(plotter_dict["train_colors"])
    test_colors = defaultdict(lambda: None)
    test_colors.update(plotter_dict["test_colors"])
    new._train_colors = train_colors
    new._test_colors = test_colors
    return new
def choose_metrics(self, num_train_metrics: int, num_test_metrics: int):
    """Register the first N candidate names as train and test metrics.

    Parameters
    ----------
    num_train_metrics : int
        Slice bound applied to the candidate names for the train metrics.
    num_test_metrics : int
        Slice bound applied to the candidate names for the test metrics.
    """
    candidates = ["metric-a", "metric-b", "metric-c"]

    train_metric_names = candidates[:num_train_metrics]
    for metric_name in train_metric_names:
        self.train_metrics.append(LiveMetric(name=metric_name))

    test_metric_names = candidates[:num_test_metrics]
    for metric_name in test_metric_names:
        self.test_metrics.append(LiveMetric(name=metric_name))

    # Surface the chosen names in hypothesis' failure report.
    note("Train metric names: {}".format(train_metric_names))
    note("Test metric names: {}".format(test_metric_names))
def test_from_dict_input_validation2(bad_input: dict, data: st.DataObject):
    """Overriding entries of a valid metric dict with ``bad_input`` must make
    ``LiveMetric.from_dict`` raise ``ValueError`` or ``TypeError``.

    Values of ``bad_input`` that are search strategies are drawn (labelled by
    their key) before being applied; other values are used as-is.
    """
    overrides = {}
    for key, value in bad_input.items():
        if isinstance(value, st.SearchStrategy):
            value = data.draw(value, label=key)
        overrides[key] = value

    # Take the first valid metric dict as the template to corrupt; if there is
    # none, ``input_dict`` stays empty and the assertion below fails.
    input_dict = {}
    for template in static_logger_dict.values():
        input_dict = template.copy()
        input_dict.update(overrides)
        break

    assert input_dict
    with pytest.raises((ValueError, TypeError)):
        LiveMetric.from_dict(input_dict)
def dict_roundtrip(self):
    """Check that ``LiveMetric.from_dict(to_dict())`` reproduces ``self.livemetric``.

    Each listed attribute must match in exact type and in value after the
    round trip.
    """
    serialized = self.livemetric.to_dict()
    restored = LiveMetric.from_dict(metrics_dict=serialized)
    checked_attrs = (
        "name",
        "batch_data",
        "epoch_data",
        "epoch_domain",
        "_batch_data",
        "_epoch_data",
        "_epoch_domain",
        "_running_weighted_sum",
        "_total_weighting",
        "_cnt_since_epoch",
    )
    for attr in checked_attrs:
        expected = getattr(self.livemetric, attr)
        got = getattr(restored, attr)
        # exact type match is intended: the round trip must not change types
        assert type(got) == type(expected), attr
        message = (
            "`LiveMetric.from_dict` did not round-trip successfully.\n"
            "livemetric.{} does not match.\nGot: {}\nExpected: {}"
        ).format(attr, got, expected)
        assert_array_equal(got, expected, err_msg=message)
def test_trivial_case():
    """Perform a trivial sanity check on live metric.

    Two equally weighted batch points (1.0 and 3.0) followed by one epoch
    marker at 99 should give an epoch value equal to their mean.
    """
    metric = LiveMetric("a")
    for value in (1.0, 3.0):
        metric.add_datapoint(value, weighting=1.0)
    metric.set_epoch_datapoint(99)

    expectations = [
        ("batch_domain", np.array([1, 2])),
        ("batch_data", np.array([1.0, 3.0])),
        ("epoch_domain", np.array([99])),
        ("epoch_data", np.array([1.0 / 2.0 + 3.0 / 2.0])),
    ]
    for attr, desired in expectations:
        assert_array_equal(getattr(metric, attr), desired)

    # The serialized form must expose the same arrays as the metric itself.
    serialized = metric.to_dict()
    for attr in ("batch_data", "epoch_data", "epoch_domain"):
        assert_array_equal(
            serialized[attr],
            getattr(metric, attr),
            err_msg=attr + " does not map to the correct value in the metric-dict",
        )
class LiveMetricChecker(RuleBasedStateMachine):
    """Ensures that exercising the api of LiveMetric produces results that are
    consistent with a simplistic reference implementation.

    The checker keeps its own plain-list record of the batch data, epoch data
    and epoch domain, and compares ``self.livemetric`` against that record
    after every step."""

    def __init__(self):
        super().__init__()
        # Reference model of the metric under test.
        self.batch_data = []
        self._weights = []  # weightings of batch points since the last epoch
        self.epoch_data = []
        self.epoch_domain = []
        self.livemetric = None  # type: LiveMetric
        self.name = None  # type: str

    @initialize(name=st.sampled_from(["a", "b", "c"]))
    def init_metric(self, name: str):
        """Create the LiveMetric under test."""
        self.livemetric = LiveMetric(name)
        self.name = name

    @rule(value=st.floats(-1e6, 1e6), weighting=st.one_of(st.none(), st.floats(0, 2)))
    def add_datapoint(self, value: float, weighting: Optional[float]):
        """Add a batch datapoint, exercising both the explicit-weighting and
        the default-weighting call forms."""
        if weighting is not None:
            self.livemetric.add_datapoint(value=value, weighting=weighting)
        else:
            self.livemetric.add_datapoint(value=value)
        self.batch_data.append(value)
        # The reference model treats an omitted weighting as 1.0.
        self._weights.append(weighting if weighting is not None else 1.0)

    @rule()
    def set_epoch_datapoint(self):
        """Mark an epoch and record the expected weighted-mean epoch value."""
        self.livemetric.set_epoch_datapoint()
        if self._weights:
            # Only the batch points added since the previous epoch contribute.
            batch_dat = np.array(self.batch_data)[-len(self._weights) :]
            weights = np.array(self._weights) / sum(self._weights)
            # All-zero weightings normalize to nan (0/0); treat them as 0.
            weights = np.nan_to_num(weights)
            epoch_mean = batch_dat @ weights
            self.epoch_data.append(epoch_mean)
            self.epoch_domain.append(len(self.batch_data))
            self._weights = []

    @rule()
    def show_repr(self):
        """Ensure no side effects of calling `repr()`"""
        repr(self.livemetric)

    @precondition(lambda self: self.livemetric is not None)
    @invariant()
    def dict_roundtrip(self):
        """Ensure `from_dict(to_dict())` round trip is successful"""
        metrics_dict = self.livemetric.to_dict()
        new_metrics = LiveMetric.from_dict(metrics_dict=metrics_dict)

        for attr in [
            "name",
            "batch_data",
            "epoch_data",
            "epoch_domain",
            "_batch_data",
            "_epoch_data",
            "_epoch_domain",
            "_running_weighted_sum",
            "_total_weighting",
            "_cnt_since_epoch",
        ]:
            desired = getattr(self.livemetric, attr)
            actual = getattr(new_metrics, attr)
            # exact type match is intended: the round trip must not change types
            assert type(actual) == type(desired), attr
            assert_array_equal(
                actual,
                desired,
                err_msg="`LiveMetric.from_dict` did not round-trip successfully.\n"
                "livemetric.{} does not match.\nGot: {}\nExpected: {}"
                "".format(attr, actual, desired),
            )

    def _check_property_is_consistent(self, attr: str):
        """Assert that reading ``self.livemetric.<attr>`` twice in a row yields
        two ndarrays with equal contents."""
        first = getattr(self.livemetric, attr)
        second = getattr(self.livemetric, attr)
        assert isinstance(first, np.ndarray)
        assert isinstance(second, np.ndarray)
        assert_array_equal(
            first,
            second,
            err_msg="calling `LiveMetric.{}` two consecutive times produces "
            "different results".format(attr),
        )

    @rule()
    def check_batch_data_is_consistent(self):
        self._check_property_is_consistent("batch_data")

    @rule()
    def check_epoch_data_is_consistent(self):
        self._check_property_is_consistent("epoch_data")

    @rule()
    def check_epoch_domain_is_consistent(self):
        self._check_property_is_consistent("epoch_domain")

    @precondition(lambda self: self.livemetric is not None)
    @invariant()
    def compare(self):
        """Compare the metric's data and domains against the reference model."""
        expected_batch_domain = np.arange(1, len(self.batch_data) + 1)
        expected_batch_data = np.asarray(self.batch_data)
        expected_epoch_data = np.asarray(self.epoch_data)
        expected_epoch_domain = np.asarray(self.epoch_domain)

        actual_batch_domain = self.livemetric.batch_domain
        actual_batch_data = self.livemetric.batch_data
        actual_epoch_data = self.livemetric.epoch_data
        actual_epoch_domain = self.livemetric.epoch_domain

        assert isinstance(actual_batch_domain, np.ndarray)
        assert isinstance(actual_batch_data, np.ndarray)
        assert isinstance(actual_epoch_data, np.ndarray)
        assert isinstance(actual_epoch_domain, np.ndarray)

        assert_array_equal(
            expected_batch_data,
            actual_batch_data,
            err_msg=err_msg(
                desired=expected_batch_data, actual=actual_batch_data, name="Batch Data"
            ),
        )
        assert_array_equal(
            expected_epoch_domain,
            actual_epoch_domain,
            err_msg=err_msg(
                desired=expected_epoch_domain,
                actual=actual_epoch_domain,
                name="Epoch Domain",
            ),
        )
        # Compare against the reference domain, not the metric with itself.
        assert_array_equal(
            expected_batch_domain,
            actual_batch_domain,
            err_msg=err_msg(
                desired=expected_batch_domain,
                actual=actual_batch_domain,
                name="Batch Domain",
            ),
        )
        assert_allclose(
            actual=actual_epoch_data,
            desired=expected_epoch_data,
            err_msg=err_msg(
                desired=expected_epoch_data, actual=actual_epoch_data, name="Epoch Data"
            ),
        )
        assert self.livemetric.name == self.name
def test_from_dict_input_validation(bad_input: st.SearchStrategy, data: st.DataObject):
    """``LiveMetric.from_dict`` must reject a drawn malformed input by raising
    ``ValueError`` or ``TypeError``."""
    malformed = data.draw(bad_input, label="bad_input")
    expected_errors = (ValueError, TypeError)
    with pytest.raises(expected_errors):
        LiveMetric.from_dict(malformed)
def test_badname(name: Any):
    """Constructing a ``LiveMetric`` with the supplied invalid ``name`` must
    raise ``TypeError``."""
    expected_error = TypeError
    with pytest.raises(expected_error):
        LiveMetric(name)
def init_metric(self, name: str):
    """Create the LiveMetric under test and remember the name it was given."""
    metric = LiveMetric(name)
    self.livemetric, self.name = metric, name
def test_livemetrics(live_metrics: dict):
    """Ensure that each entry in live_metrics() can round-trip via LiveMetric.

    Only the metric dicts are used; the keys (metric names) are not needed.
    """
    for metrics_dict in live_metrics.values():
        LiveMetric.from_dict(metrics_dict).to_dict()
def test_metrics_dict(metrics_dict: dict):
    """Check that a dict drawn from ``metrics_dict()`` survives
    ``LiveMetric.from_dict`` followed by ``to_dict`` without error."""
    restored = LiveMetric.from_dict(metrics_dict)
    restored.to_dict()