def test_saved_content(self):
        with tempfile.TemporaryDirectory() as tempdir:

            # set up engine
            def _train_func(engine, batch):
                return torch.zeros(8)

            engine = Engine(_train_func)

            # set up testing handler
            saver = CSVSaver(output_dir=tempdir, filename="predictions2.csv")
            ClassificationSaver(output_dir=tempdir,
                                filename="predictions1.csv").attach(engine)
            ClassificationSaver(saver=saver).attach(engine)

            data = [{
                "filename_or_obj": ["testfile" + str(i) for i in range(8)]
            }]
            engine.run(data, max_epochs=1)

            def _test_file(filename):
                filepath = os.path.join(tempdir, filename)
                self.assertTrue(os.path.exists(filepath))
                with open(filepath, "r") as f:
                    reader = csv.reader(f)
                    i = 0
                    for row in reader:
                        self.assertEqual(row[0], "testfile" + str(i))
                        self.assertEqual(
                            np.array(row[1:]).astype(np.float32), 0.0)
                        i += 1
                    self.assertEqual(i, 8)

            _test_file("predictions1.csv")
            _test_file("predictions2.csv")
Example 2
    def __init__(
        self,
        keys: KeysCollection,
        meta_keys: Optional[KeysCollection] = None,
        meta_key_postfix: str = DEFAULT_POST_FIX,
        saver: Optional[CSVSaver] = None,
        output_dir: PathLike = "./",
        filename: str = "predictions.csv",
        delimiter: str = ",",
        overwrite: bool = True,
        flush: bool = True,
        allow_missing_keys: bool = False,
    ) -> None:
        """
        Args:
            keys: keys of the items corresponding to the model output; this transform only supports 1 key.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            meta_keys: explicitly indicate the key of the corresponding metadata dictionary.
                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
                the metadata is a dictionary object which contains: filename, original_shape, etc.
                it can be a sequence of strings, mapped one-to-one to the `keys`.
                if None, will try to construct meta_keys as `key_{meta_key_postfix}`.
                the filename of the input image is extracted from the metadata to save the classification results.
            meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`,
                so this key is needed to extract the metadata of the input image, such as the filename. default is `meta_dict`.
                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
                the metadata is a dictionary object which contains: filename, original_shape, etc.
                this arg only works when `meta_keys=None`. if there is no corresponding metadata, set it to `None`.
            saver: the saver instance to save classification results, if None, create a CSVSaver internally.
                the saver must provide `save(data, meta_data)` and `finalize()` APIs.
            output_dir: if `saver=None`, specify the directory to save the CSV file.
            filename: if `saver=None`, specify the name of the saved CSV file.
            delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
                to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
            overwrite: if `saver=None`, indicate whether to overwrite existing CSV file content. if True,
                will clear the file before saving; otherwise, will append new content to the CSV file.
            flush: if `saver=None`, indicate whether to write the cached data to the CSV file immediately
                in this transform and clear the cache. default to True.
                If False, the user may need to call `saver.finalize()` manually or use the `ClassificationSaver` handler.
            allow_missing_keys: don't raise exception if key is missing.

        """
        super().__init__(keys, allow_missing_keys)
        if len(self.keys) != 1:
            raise ValueError(
                "only 1 key is allowed when saving the classification result.")
        self.saver = saver or CSVSaver(output_dir=output_dir,
                                       filename=filename,
                                       overwrite=overwrite,
                                       flush=flush,
                                       delimiter=delimiter)
        self.flush = flush
        self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix,
                                                 len(self.keys))