Code Example #1
import numpy as np
import pytest

# `compare_all_metrics` is the helper under test; a sketch of its expected
# behavior is given after this example.


# NOTE: assumed parametrization -- the original decorators are not shown.
# `mismatched_category=None` exercises the no-mismatch branch.
@pytest.mark.parametrize("mismatched_metric", ["a", "b"])
@pytest.mark.parametrize(
    "mismatched_category", [None, "batch_data", "epoch_data", "epoch_domain"]
)
def test_mismatched_metrics(mismatched_category, mismatched_metric):
    # `x` and `y` begin as identical metrics dicts.
    x = dict(
        a=dict(
            batch_data=np.array([0]),
            epoch_data=np.array([0]),
            epoch_domain=np.array([0]),
        ),
        b=dict(
            batch_data=np.array([0]),
            epoch_data=np.array([0]),
            epoch_domain=np.array([0]),
        ),
    )

    y = dict(
        a=dict(
            batch_data=np.array([0]),
            epoch_data=np.array([0]),
            epoch_domain=np.array([0]),
        ),
        b=dict(
            batch_data=np.array([0]),
            epoch_data=np.array([0]),
            epoch_domain=np.array([0]),
        ),
    )
    if mismatched_category is not None:
        # Perturb one array so the two dicts no longer match.
        x[mismatched_metric][mismatched_category] += 1
        with pytest.raises(AssertionError):
            compare_all_metrics(x, y)
    else:
        # With no perturbation, identical dicts must compare clean.
        compare_all_metrics(x, y)
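
Every example in this section exercises `compare_all_metrics`. The library's actual helper is not reproduced here; what follows is a minimal sketch of the behavior these tests rely on, assuming the layout shown above (metric name -> batch_data / epoch_data / epoch_domain arrays). Note that it compares the metric names order-insensitively, which is why Code Example #2 performs a separate ordered check.

# Hypothetical sketch -- not the library's actual implementation.
from numpy.testing import assert_array_equal


def compare_all_metrics(x, y):
    """Assert that two metrics dicts contain the same metric names and
    identical arrays for every category."""
    assert sorted(x) == sorted(y), "the metric names do not match"
    for name, categories in x.items():
        for category, data in categories.items():
            assert_array_equal(
                data,
                y[name][category],
                err_msg="{}/{} does not match".format(name, category),
            )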
Code Example #2
    def check_metric_io(self, save_via_object: bool):
        """Ensure the saving/loading metrics always produces self-consistent
        results with the plotter"""
        from uuid import uuid4

        filename = str(uuid4())
        if save_via_object:
            save_metrics(filename, liveplot=self.plotter)
        else:
            save_metrics(
                filename,
                train_metrics=self.plotter.train_metrics,
                test_metrics=self.plotter.test_metrics,
            )
        io_train_metrics, io_test_metrics = load_metrics(filename)

        plot_train_metrics = self.plotter.train_metrics
        plot_test_metrics = self.plotter.test_metrics

        assert tuple(io_test_metrics) == tuple(plot_test_metrics), (
            "The io test metrics do not match those from the LivePlot "
            "instance. Order matters for reproducing the plot.")

        compare_all_metrics(plot_train_metrics, io_train_metrics)
        compare_all_metrics(plot_test_metrics, io_test_metrics)
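
`check_metric_io` takes the save path as a parameter so that a single harness can exercise both branches. A driver along these lines would do it; the fixture name is illustrative, not from the library:

# Hypothetical driver -- `plot_checker` is an assumed fixture yielding an
# object that exposes a populated `plotter` and the method above.
import pytest


@pytest.mark.parametrize("save_via_object", [True, False])
def test_metric_io_roundtrip(plot_checker, save_via_object):
    plot_checker.check_metric_io(save_via_object)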
Code Example #3
import numpy as np
import pytest


def test_mismatched_number_of_metrics():
    # One metrics dict holds an entry, the other is empty -- must fail.
    with pytest.raises(AssertionError):
        compare_all_metrics(
            dict(a=dict(
                batch_data=np.array([0]),
                epoch_data=np.array([0]),
                epoch_domain=np.array([0]),
            )),
            dict(),
        )
Code Example #4
    def check_from_dict_roundtrip(self):
        """Ensure that `LivePlot.from_dict(plotter.to_dict())` survives a
        pickle round-trip and restores an equivalent plotter."""
        import pickle
        from uuid import uuid4

        plotter_dict = self.plotter.to_dict()
        filename = str(uuid4())
        with open(filename, "wb") as f:
            pickle.dump(plotter_dict, f)

        with open(filename, "rb") as f:
            loaded_dict = pickle.load(f)

        new_plotter = LivePlot.from_dict(loaded_dict)

        for attr in [
                "_num_train_epoch",
                "_num_train_batch",
                "_num_test_epoch",
                "_num_test_batch",
                "_metrics",
                "_pltkwargs",
                "metric_colors",
        ]:
            desired = getattr(self.plotter, attr)
            actual = getattr(new_plotter, attr)
            assert_array_equal(
                actual,
                desired,
                err_msg="`LivePlot.from_dict` did not round-trip successfully.\n"
                "plotter.{} does not match.\nGot: {}\nExpected: {}".format(
                    attr, actual, desired
                ),
            )

        compare_all_metrics(self.plotter.train_metrics,
                            new_plotter.train_metrics)
        compare_all_metrics(self.plotter.test_metrics,
                            new_plotter.test_metrics)

        assert isinstance(new_plotter._test_colors, type(self.plotter._test_colors))
        assert self.plotter._test_colors == new_plotter._test_colors
        assert self.plotter._test_colors[None] is new_plotter._test_colors[None]

        assert isinstance(new_plotter._train_colors, type(self.plotter._train_colors))
        assert self.plotter._train_colors == new_plotter._train_colors
        assert self.plotter._train_colors[None] is new_plotter._train_colors[None]

        # Check consistency for all public, non-callable attributes.
        excluded = {"plot_objects", "metrics", "test_metrics", "train_metrics"}
        public_attrs = (
            x
            for x in dir(self.plotter)
            if not x.startswith("_")
            and not callable(getattr(self.plotter, x))
            and x not in excluded
        )
        for attr in public_attrs:
            original_attr = getattr(self.plotter, attr)
            from_dict_attr = getattr(new_plotter, attr)
            assert original_attr == from_dict_attr, attr
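
The test pickles `plotter.to_dict()` rather than the plotter itself, which is the point of the `to_dict`/`from_dict` pair: the dict form carries only plain state, decoupling persistence from GUI objects such as figures and axes. A minimal sketch of that pattern, with purely illustrative names:

# Hypothetical sketch of the to_dict/from_dict pattern under test.
# The real LivePlot carries far more state; these names are illustrative.
class Recorder:
    def __init__(self):
        self._num_train_batch = 0
        self._metrics = ("loss",)

    def to_dict(self):
        # Export only plain, picklable state -- no figure or axis handles.
        return {
            "num_train_batch": self._num_train_batch,
            "metrics": self._metrics,
        }

    @classmethod
    def from_dict(cls, state):
        new = cls()
        new._num_train_batch = state["num_train_batch"]
        new._metrics = state["metrics"]
        return new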
Code Example #5
    def check_metric_io(self, save_via_live_object: bool):
        """Ensure the saving/loading metrics always produces self-consistent
        results with the logger"""
        from uuid import uuid4

        filename = str(uuid4())
        if save_via_live_object:
            # The `liveplot` argument is passed the logger here, mirroring
            # Code Example #2's use with a plotter.
            save_metrics(filename, liveplot=self.logger)
        else:
            save_metrics(
                filename,
                train_metrics=self.logger.train_metrics,
                test_metrics=self.logger.test_metrics,
            )
        io_train_metrics, io_test_metrics = load_metrics(filename)

        compare_all_metrics(io_train_metrics, self.logger.train_metrics)
        compare_all_metrics(io_test_metrics, self.logger.test_metrics)
Code Example #6
    def check_from_dict_roundtrip(self):
        """Ensure that `LiveLogger.from_dict(logger.to_dict())` restores an
        equivalent logger."""
        logger_dict = self.logger.to_dict()
        new_logger = LiveLogger.from_dict(logger_dict)

        for attr in [
            "_num_train_epoch",
            "_num_train_batch",
            "_num_test_epoch",
            "_num_test_batch",
        ]:
            desired = getattr(self.logger, attr)
            actual = getattr(new_logger, attr)
            assert actual == desired, (
                "`LiveLogger.from_dict` did not round-trip successfully.\n"
                "logger.{} does not match.\nGot: {}\nExpected: {}"
                "".format(attr, actual, desired)
            )

        compare_all_metrics(self.logger.train_metrics, new_logger.train_metrics)
        compare_all_metrics(self.logger.test_metrics, new_logger.test_metrics)
Code Example #7
    def compare_test_metrics(self):
        """The logger's and the plotter's test metrics must agree."""
        log_metrics = self.logger.test_metrics
        plot_metrics = self.plotter.test_metrics
        compare_all_metrics(log_metrics, plot_metrics)
Code Example #8
    def compare_test_metrics(self):
        """The logged test metrics must match those reconstructed from the
        test's own metric objects."""
        logged_metrics = self.logger.test_metrics
        expected_metrics = {
            metric.name: metric.to_dict() for metric in self.test_metrics
        }
        compare_all_metrics(logged_metrics, expected_metrics)
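
Code Example #8 presumes metric objects that expose a `name` and a `to_dict()` producing the same category layout used throughout this section. A stand-in like the following would satisfy that contract (hypothetical; the library's real metric class does far more):

# Hypothetical stand-in for the objects held in `self.test_metrics`.
import numpy as np


class DummyMetric:
    def __init__(self, name):
        self.name = name
        self.batch_data = np.array([0.0])
        self.epoch_data = np.array([0.0])
        self.epoch_domain = np.array([1])

    def to_dict(self):
        # The same layout that `compare_all_metrics` consumes.
        return dict(
            batch_data=self.batch_data,
            epoch_data=self.epoch_data,
            epoch_domain=self.epoch_domain,
        )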