def test_saved_content(self):
    """Run an engine with two ClassificationSaver handlers and check both CSV outputs.

    One handler is configured through ``output_dir``/``filename`` and the other
    receives an explicit ``CSVSaver`` instance; both files must contain 8 rows of
    ``testfile<i>`` followed by all-zero prediction values.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        # set up engine: the process function returns 8 zeros for every batch
        def _train_func(engine, batch):
            return torch.zeros(8)

        engine = Engine(_train_func)

        # set up testing handlers: one configured via output_dir/filename,
        # the other given an externally constructed CSVSaver
        saver = CSVSaver(output_dir=tempdir, filename="predictions2.csv")
        ClassificationSaver(output_dir=tempdir, filename="predictions1.csv").attach(engine)
        ClassificationSaver(saver=saver).attach(engine)

        data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8)]}]
        engine.run(data, max_epochs=1)

        def _test_file(filename):
            """Assert `filename` in `tempdir` holds 8 rows of `testfile<i>` followed by zeros."""
            filepath = os.path.join(tempdir, filename)
            self.assertTrue(os.path.exists(filepath))
            # the csv module requires reader file objects to be opened with newline=""
            with open(filepath, "r", newline="") as f:
                rows = list(csv.reader(f))
            self.assertEqual(len(rows), 8)
            for i, row in enumerate(rows):
                self.assertEqual(row[0], "testfile" + str(i))
                # every row must carry at least one prediction value after the filename
                self.assertGreater(len(row), 1)
                # element-wise comparison: does not rely on the truthiness of a
                # (possibly multi-element) NumPy boolean array
                np.testing.assert_allclose(np.array(row[1:]).astype(np.float32), 0.0)

        _test_file("predictions1.csv")
        _test_file("predictions2.csv")
def __init__(
    self,
    keys: KeysCollection,
    meta_keys: Optional[KeysCollection] = None,
    meta_key_postfix: str = DEFAULT_POST_FIX,
    saver: Optional[CSVSaver] = None,
    output_dir: PathLike = "./",
    filename: str = "predictions.csv",
    delimiter: str = ",",
    overwrite: bool = True,
    flush: bool = True,
    allow_missing_keys: bool = False,
) -> None:
    """
    Args:
        keys: keys of the corresponding items to model output, this transform only supports 1 key.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        meta_keys: explicitly indicate the key of the corresponding metadata dictionary.
            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
            the metadata is a dictionary object which contains: filename, original_shape, etc.
            it can be a sequence of string, map to the `keys`.
            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
            will extract the filename of input image to save classification results.
        meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`.
            so need the key to extract the metadata of input image, like filename, etc. default is `meta_dict`.
            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
            the metadata is a dictionary object which contains: filename, original_shape, etc.
            this arg only works when `meta_keys=None`. if no corresponding metadata, set to `None`.
        saver: the saver instance to save classification results, if None, create a CSVSaver internally.
            the saver must provide `save(data, meta_data)` and `finalize()` APIs.
        output_dir: if `saver=None`, specify the directory to save the CSV file.
        filename: if `saver=None`, specify the name of the saved CSV file.
        delimiter: if `saver=None`, the delimiter character in the saved file, defaults to "," as the
            default output type is `csv`. to be consistent with:
            https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
        overwrite: if `saver=None`, indicate whether to overwrite existing CSV file content, if True,
            will clear the file before saving. otherwise, will append new content to the CSV file.
        flush: if `saver=None`, indicate whether to write the cache data to CSV file immediately
            in this transform and clear the cache. defaults to True.
            If False, may need user to call `saver.finalize()` manually or use `ClassificationSaver` handler.
        allow_missing_keys: don't raise exception if key is missing.

    Raises:
        ValueError: when more than one key (or no key) is given.

    """
    super().__init__(keys, allow_missing_keys)
    if len(self.keys) != 1:
        raise ValueError("only 1 key is allowed when saving the classification result.")
    # Compare against None explicitly (as documented) instead of using `saver or ...`:
    # a user-supplied saver object that evaluates falsy (e.g. defines __len__ returning 0)
    # must not be silently replaced by an internal CSVSaver.
    if saver is not None:
        self.saver = saver
    else:
        self.saver = CSVSaver(
            output_dir=output_dir,
            filename=filename,
            overwrite=overwrite,
            flush=flush,
            delimiter=delimiter,
        )
    self.flush = flush
    # one entry per key; only a single key is allowed (checked above)
    self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
    self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))