def test_values(self):
    """Check that a CacheDataset read through a multi-worker DataLoader returns the items in order.

    The dataset is built with ``cache_rate=0.5`` and ``cache_num=1``, presumably so that only
    part of ``datalist`` is pre-cached and the rest goes through the transform in the worker
    processes (TODO confirm against CacheDataset's caching rule). The items are plain strings,
    so the batch must contain exactly the input file names.
    """
    datalist = [
        {"image": "spleen_19.nii.gz", "label": "spleen_label_19.nii.gz"},
        {"image": "spleen_31.nii.gz", "label": "spleen_label_31.nii.gz"},
    ]
    transform = Compose(
        [
            DataStatsd(keys=["image", "label"], data_shape=False, value_range=False, data_value=True),
            SimulateDelayd(keys=["image", "label"], delay_time=0.1),
        ]
    )
    dataset = CacheDataset(data=datalist, transform=transform, cache_rate=0.5, cache_num=1)
    dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=2)

    # Materialise the batches so the test fails if the loader yields nothing. A bare
    # ``for`` loop around the assertions would pass vacuously in that case.
    batches = list(dataloader)
    # Two items with batch_size=2 form exactly one full batch.
    self.assertEqual(len(batches), 1)
    batch = batches[0]

    self.assertEqual(batch["image"][0], "spleen_19.nii.gz")
    self.assertEqual(batch["image"][1], "spleen_31.nii.gz")
    self.assertEqual(batch["label"][0], "spleen_label_19.nii.gz")
    self.assertEqual(batch["label"][1], "spleen_label_31.nii.gz")
def test_file(self, input_data, expected_print):
    """Check that DataStatsd's report reaches a file handler attached to its named logger.

    Args:
        input_data: data passed to the transform; the statistics are taken for key ``"img"``.
        expected_print: exact text expected in the log file. It is only compared off Windows,
            presumably because of platform-specific output differences (TODO confirm).
    """
    with tempfile.TemporaryDirectory() as tempdir:
        filename = os.path.join(tempdir, "test_stats.log")
        handler = logging.FileHandler(filename, mode="w")
        handler.setLevel(logging.INFO)
        name = "DataStats"
        logger = logging.getLogger(name)
        logger.addHandler(handler)
        input_param = {
            "keys": "img",
            "prefix": "test data",
            "data_shape": True,
            "value_range": True,
            "data_value": True,
            "additional_info": np.mean,
            "name": name,
        }
        try:
            transform = DataStatsd(**input_param)
            _ = transform(input_data)
        finally:
            # logging.getLogger(name) is a process-wide singleton. Detach and close every
            # handler on it even if the transform raises, so that:
            # - the log file is closed before the temp dir is deleted (required on Windows);
            # - no handlers leak into later tests.
            for h in logger.handlers[:]:
                h.close()
                logger.removeHandler(h)
        with open(filename) as f:
            content = f.read()
        if sys.platform != "win32":
            self.assertEqual(content, expected_print)
def test_file(self, input_data, expected_print):
    """Check that DataStatsd writes its report to a caller-supplied ``logger_handler``.

    Args:
        input_data: data passed to the transform; the statistics are taken for key ``"img"``.
        expected_print: exact text expected in the log file.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        filename = os.path.join(tempdir, "test_stats.log")
        handler = logging.FileHandler(filename, mode="w")
        input_param = {
            "keys": "img",
            "prefix": "test data",
            "data_shape": True,
            "value_range": True,
            "data_value": True,
            "additional_info": lambda x: np.mean(x),
            "logger_handler": handler,
        }
        transform = None
        try:
            transform = DataStatsd(**input_param)
            _ = transform(input_data)
        finally:
            # Detach first, then close the handler itself. Closing only ``handler.stream``
            # would leave a registered handler holding a closed stream. Doing this in
            # ``finally`` guarantees the file is closed before the temp dir is removed,
            # even if the transform raises.
            if transform is not None:
                transform.printer._logger.removeHandler(handler)
            handler.close()
        with open(filename) as f:
            content = f.read()
        self.assertEqual(content, expected_print)
def test_value(self, input_param, input_data, expected_print):
    """Run DataStatsd built from ``input_param`` on ``input_data``; only checks that it does not raise.

    NOTE(review): ``expected_print`` is received (presumably from a parameterized case list)
    but never compared, so this test cannot detect wrong statistics output. TODO: capture the
    emitted report and assert against ``expected_print``.
    """
    stats_transform = DataStatsd(**input_param)
    stats_transform(input_data)
def test_value(self, input_param, input_data, expected_print):
    """Check that the text recorded on the transform's printer matches ``expected_print``.

    Args:
        input_param: keyword arguments used to build the DataStatsd transform.
        input_data: data the transform is applied to.
        expected_print: exact text expected in ``printer.output`` after the call.
    """
    stats = DataStatsd(**input_param)
    stats(input_data)
    recorded = stats.printer.output
    self.assertEqual(recorded, expected_print)