def test_mismatched_metrics(mismatched_category, mismatched_metric):
    """Check that `compare_all_metrics` flags a single differing entry.

    Two identical metric mappings (metrics "a" and "b", each holding
    `batch_data`, `epoch_data` and `epoch_domain` arrays) are built. When
    `mismatched_category` is None the mappings are left identical and the
    comparison must pass. Otherwise the array stored under
    ``x[mismatched_metric][mismatched_category]`` is incremented in place and
    the comparison must raise an `AssertionError`.
    """

    def _zeroed_metrics():
        # Build fresh arrays on every call so that the in-place increment
        # applied to one mapping can never leak into the other.
        return {
            name: {
                "batch_data": np.array([0]),
                "epoch_data": np.array([0]),
                "epoch_domain": np.array([0]),
            }
            for name in ("a", "b")
        }

    x = _zeroed_metrics()
    y = _zeroed_metrics()

    if mismatched_category is None:
        # Nothing was perturbed: the comparison must succeed.
        compare_all_metrics(x, y)
        return

    x[mismatched_metric][mismatched_category] += 1
    with pytest.raises(AssertionError):
        compare_all_metrics(x, y)
def check_metric_io(self, save_via_object: bool):
    """Save the plotter's metrics, load them back, and compare.

    Parameters
    ----------
    save_via_object : bool
        If True, the plotter itself is handed to `save_metrics` through its
        `liveplot` argument. Otherwise the plotter's train and test metrics
        are passed explicitly.
    """
    from uuid import uuid4

    plotter = self.plotter
    filename = str(uuid4())

    if not save_via_object:
        save_metrics(
            filename,
            train_metrics=plotter.train_metrics,
            test_metrics=plotter.test_metrics,
        )
    else:
        save_metrics(filename, liveplot=plotter)

    io_train_metrics, io_test_metrics = load_metrics(filename)

    plot_train_metrics = plotter.train_metrics
    plot_test_metrics = plotter.test_metrics

    # The loaded test-metric names must appear in the same order as the
    # plotter's own.
    same_order = tuple(io_test_metrics) == tuple(plot_test_metrics)
    assert same_order, (
        "The io test metrics do not match those from the LivePlot "
        "instance. Order matters for reproducing the plot."
    )

    compare_all_metrics(plot_train_metrics, io_train_metrics)
    compare_all_metrics(plot_test_metrics, io_test_metrics)
def test_mismatched_number_of_metrics():
    """Comparing a one-metric mapping against an empty one must fail.

    `compare_all_metrics` is expected to raise an `AssertionError` when the
    two mappings do not contain the same number of metrics.
    """
    single_metric = {
        "a": {
            "batch_data": np.array([0]),
            "epoch_data": np.array([0]),
            "epoch_domain": np.array([0]),
        }
    }
    no_metrics = {}

    with pytest.raises(AssertionError):
        compare_all_metrics(single_metric, no_metrics)
def check_from_dict_roundtrip(self):
    """Verify that `LivePlot.from_dict` rebuilds an equivalent plotter.

    The plotter's `to_dict()` output is pickled and unpickled, then passed to
    `LivePlot.from_dict`. The rebuilt plotter is compared with the original
    on: a fixed list of private attributes, the train/test metrics, the
    train/test color mappings, and every public, non-callable attribute
    (apart from a few that are skipped from the equality check).

    Raises
    ------
    AssertionError
        If any compared attribute or metric differs between the two plotters.
    """
    # Round-trip through pickle in memory. This exercises the same
    # serialization as dumping to disk, without leaving a uuid-named file
    # behind in the working directory on every invocation.
    plotter_dict = self.plotter.to_dict()
    loaded_dict = pickle.loads(pickle.dumps(plotter_dict))

    new_plotter = LivePlot.from_dict(loaded_dict)

    for attr in [
        "_num_train_epoch",
        "_num_train_batch",
        "_num_test_epoch",
        "_num_test_batch",
        "_metrics",
        "_pltkwargs",
        "metric_colors",
    ]:
        desired = getattr(self.plotter, attr)
        actual = getattr(new_plotter, attr)
        assert_array_equal(
            actual,
            desired,
            # The message names the class/method actually under test
            # (it previously referred to `LiveLogger.from_metrics`).
            err_msg="`LivePlot.from_dict` did not round-trip successfully.\n"
            "plotter.{} does not match.\nGot: {}\nExpected: {}"
            "".format(attr, actual, desired),
        )

    compare_all_metrics(self.plotter.train_metrics, new_plotter.train_metrics)
    compare_all_metrics(self.plotter.test_metrics, new_plotter.test_metrics)

    # The color mappings must keep their container type and contents; the
    # entry stored under the `None` key must additionally be the very same
    # object in both plotters.
    assert isinstance(new_plotter._test_colors, type(self.plotter._test_colors))
    assert self.plotter._test_colors == new_plotter._test_colors
    assert self.plotter._test_colors[None] is new_plotter._test_colors[None]

    assert isinstance(self.plotter._train_colors, type(new_plotter._train_colors))
    assert self.plotter._train_colors == new_plotter._train_colors
    assert self.plotter._train_colors[None] is new_plotter._train_colors[None]

    # Check consistency for all public, non-callable attributes. The metric
    # mappings were already compared above via `compare_all_metrics`.
    skipped = {"plot_objects", "metrics", "test_metrics", "train_metrics"}
    public_attrs = (
        name
        for name in dir(self.plotter)
        if not name.startswith("_")
        and not callable(getattr(self.plotter, name))
        and name not in skipped
    )
    for attr in public_attrs:
        original_attr = getattr(self.plotter, attr)
        from_dict_attr = getattr(new_plotter, attr)
        assert original_attr == from_dict_attr, attr
def check_metric_io(self, save_via_live_object: bool):
    """Save the logger's metrics, load them back, and compare.

    Parameters
    ----------
    save_via_live_object : bool
        If True, the logger itself is handed to `save_metrics` through its
        `liveplot` argument. Otherwise the logger's train and test metrics
        are passed explicitly.
    """
    from uuid import uuid4

    logger = self.logger
    filename = str(uuid4())

    if not save_via_live_object:
        save_metrics(
            filename,
            train_metrics=logger.train_metrics,
            test_metrics=logger.test_metrics,
        )
    else:
        save_metrics(filename, liveplot=logger)

    io_train_metrics, io_test_metrics = load_metrics(filename)

    # Whatever was loaded must match what the logger currently reports.
    compare_all_metrics(io_train_metrics, logger.train_metrics)
    compare_all_metrics(io_test_metrics, logger.test_metrics)
def check_from_dict_roundtrip(self):
    """Verify that `LiveLogger.from_dict` rebuilds an equivalent logger.

    The logger's `to_dict()` output is fed to `LiveLogger.from_dict`; the
    four batch/epoch counter attributes and the train/test metrics of the
    rebuilt logger must match those of the original.
    """
    original = self.logger
    new_logger = LiveLogger.from_dict(original.to_dict())

    counter_names = (
        "_num_train_epoch",
        "_num_train_batch",
        "_num_test_epoch",
        "_num_test_batch",
    )
    for attr in counter_names:
        desired = getattr(original, attr)
        actual = getattr(new_logger, attr)
        assert actual == desired, (
            "`LiveLogger.from_dict` did not round-trip successfully.\n"
            "logger.{} does not match.\nGot: {}\nExpected: {}"
            "".format(attr, actual, desired)
        )

    compare_all_metrics(original.train_metrics, new_logger.train_metrics)
    compare_all_metrics(original.test_metrics, new_logger.test_metrics)
def compare_test_metrics(self):
    """Check that the logger and the plotter report the same test metrics."""
    compare_all_metrics(self.logger.test_metrics, self.plotter.test_metrics)
def compare_test_metrics(self):
    """Check the logger's test metrics against the tracked reference metrics.

    The expected mapping is built from `self.test_metrics`, keyed by each
    entry's `name` and holding the result of its `to_dict()`.
    """
    logged_metrics = self.logger.test_metrics
    expected_metrics = {
        metric.name: metric.to_dict() for metric in self.test_metrics
    }
    compare_all_metrics(logged_metrics, expected_metrics)