Example #1
    def __call__(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine; it can be a trainer, validator, or evaluator.
        """
        ws = idist.get_world_size()
        if self.save_rank >= ws:
            raise ValueError("target rank is greater than the distributed group size.")

        # all gather file names across ranks
        _images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames

        # only save metrics to file in specified rank
        if idist.get_rank() == self.save_rank:
            _metrics = {}
            if self.metrics is not None and len(engine.state.metrics) > 0:
                _metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or "*" in self.metrics}
            _metric_details = {}
            if self.metric_details is not None and len(engine.state.metric_details) > 0:
                for k, v in engine.state.metric_details.items():
                    if k in self.metric_details or "*" in self.metric_details:
                        _metric_details[k] = v

            write_metrics_reports(
                save_dir=self.save_dir,
                images=_images,
                metrics=_metrics,
                metric_details=_metric_details,
                summary_ops=self.summary_ops,
                deli=self.deli,
                output_type=self.output_type,
            )
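A minimal usage sketch for the handler above, assuming MONAI's MetricsSaver and an Ignite evaluator built elsewhere (the save directory and metric selections are illustrative; keyword names follow this snippet's attributes and may differ in other releases):

    from monai.handlers import MetricsSaver

    metrics_saver = MetricsSaver(
        save_dir="./eval_metrics",  # hypothetical output directory for the report files
        metrics="*",                # keep every entry in engine.state.metrics
        metric_details="*",         # keep every per-image detail in engine.state.metric_details
        summary_ops=["mean", "median", "max", "90percentile"],
        save_rank=0,                # only this rank writes the files
        deli="\t",
    )
    metrics_saver.attach(evaluator)  # evaluator is an existing Ignite/MONAI engine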
Example #2
    def test_content(self):
        with tempfile.TemporaryDirectory() as tempdir:
            write_metrics_reports(
                save_dir=tempdir,
                images=["filepath1", "filepath2"],
                metrics={
                    "metric1": 1,
                    "metric2": 2
                },
                metric_details={
                    "metric3": torch.tensor([[1, 2], [2, 3]]),
                    "metric4": torch.tensor([[5, 6], [7, 8]])
                },
                summary_ops=["mean", "median", "max", "90percentile"],
                deli="\t",
                output_type="csv",
            )

            # check that metrics.csv exists and verify its content
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metrics.csv")))
            with open(os.path.join(tempdir, "metrics.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"])
            # check that metric3_raw.csv exists and verify its content
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_raw.csv")))
            with open(os.path.join(tempdir, "metric3_raw.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    if i > 0:
                        self.assertEqual(row, [
                            f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"
                        ])
            # check that metric3_summary.csv exists and verify its content
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
            with open(os.path.join(tempdir, "metric3_summary.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    if i == 1:
                        self.assertEqual(
                            row, ["class0\t1.5000\t1.5000\t2.0000\t1.9000"])
                    elif i == 2:
                        self.assertEqual(
                            row, ["class1\t2.5000\t2.5000\t3.0000\t2.9000"])
                    elif i == 3:
                        self.assertEqual(
                            row, ["mean\t2.0000\t2.0000\t2.5000\t2.4000"])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_raw.csv")))
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_summary.csv")))
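The expected summary numbers above can be reproduced by hand; a minimal NumPy sketch (assuming, as the expected 1.9000 implies, that the "90percentile" op uses linearly interpolated percentiles):

    import numpy as np

    # metric3 is [[1, 2], [2, 3]]: over the two images, class0 is [1, 2]
    class0 = np.array([1.0, 2.0])
    print(np.mean(class0), np.median(class0), np.max(class0))  # 1.5 1.5 2.0
    print(np.percentile(class0, 90))                           # 1.9
    # hence the asserted row "class0\t1.5000\t1.5000\t2.0000\t1.9000"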
Example #3
    def __call__(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine; it can be a trainer, validator, or evaluator.
        """
        ws = idist.get_world_size()
        if self.save_rank >= ws:
            raise ValueError(
                "target rank is greater than the distributed group size.")

        _images = self._filenames
        if ws > 1:
            # ranks may hold different numbers of filenames, so join the local
            # list into a single delimited string before gathering
            _filenames = self.deli.join(_images)
            if get_torch_version_tuple() > (1, 6, 0):
                # all gather across all processes
                _filenames = self.deli.join(idist.all_gather(_filenames))
            else:
                raise RuntimeError(
                    "MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0."
                )
            # split the gathered string back into a flat list of filenames
            _images = _filenames.split(self.deli)

        # only save metrics to file in specified rank
        if idist.get_rank() == self.save_rank:
            _metrics = {}
            if self.metrics is not None and len(engine.state.metrics) > 0:
                _metrics = {
                    k: v
                    for k, v in engine.state.metrics.items()
                    if k in self.metrics or "*" in self.metrics
                }
            _metric_details = {}
            if self.metric_details is not None and len(
                    engine.state.metric_details) > 0:
                for k, v in engine.state.metric_details.items():
                    if k in self.metric_details or "*" in self.metric_details:
                        _metric_details[k] = v

            write_metrics_reports(
                save_dir=self.save_dir,
                images=_images,
                metrics=_metrics,
                metric_details=_metric_details,
                summary_ops=self.summary_ops,
                deli=self.deli,
                output_type=self.output_type,
            )
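The join/gather/split pattern above can be checked without a process group; a minimal single-process sketch (the two-rank filename lists are illustrative assumptions, and a plain string join stands in for idist.all_gather):

    deli = "\t"
    rank_filenames = [["img1.nii", "img2.nii"], ["img3.nii"]]  # hypothetical per-rank lists
    # each rank joins its local names into one string, so ranks can hold different counts
    per_rank = [deli.join(names) for names in rank_filenames]
    # joining the gathered strings and splitting again yields the flat global list
    flat = deli.join(per_rank).split(deli)
    assert flat == ["img1.nii", "img2.nii", "img3.nii"]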