def run_interaction(self, train, compose):
    data = [
        {"image": np.ones((1, 2, 2, 2)).astype(np.float32), "label": np.ones((1, 2, 2, 2))}
        for _ in range(5)
    ]
    network = torch.nn.Linear(2, 2)
    lr = 1e-3
    opt = torch.optim.SGD(network.parameters(), lr)
    loss = torch.nn.L1Loss()
    train_transforms = Compose(
        [
            FindAllValidSlicesd(label="label", sids="sids"),
            AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ]
    )
    dataset = Dataset(data, transform=train_transforms)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)

    iteration_transforms = [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=["image", "label", "pred"]),
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        ToTensord(keys=("image", "label")),
    ]
    iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms

    i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5)
    self.assertEqual(len(i.transforms.transforms), 6, "Mismatch in expected transforms")

    # set up engine
    engine = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=data_loader,
        network=network,
        optimizer=opt,
        loss_function=loss,
        iteration_update=i,
    )
    engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)
    engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)

    engine.run()
    self.assertIsNotNone(engine.state.batch[0].get("guidance"), "guidance is missing")
    self.assertEqual(engine.state.best_metric, 9)

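# NOTE: the interaction tests in this section reference an `add_one` event
# handler that is not shown here. Below is a minimal sketch consistent with the
# final assertion (best_metric == 9); the helper in the original test suite may
# differ in detail.
def add_one(engine):
    # Each of the 5 inner iterations fires INNER_ITERATION_STARTED and
    # INNER_ITERATION_COMPLETED, i.e. 10 calls in total: the first call
    # initializes best_metric to 0, the remaining 9 each add one.
    if engine.state.best_metric == -1:
        engine.state.best_metric = 0
    else:
        engine.state.best_metric = engine.state.best_metric + 1
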
def train_pre_transforms(self, context: Context):
    return [
        LoadImaged(keys=("image", "label"), dtype=np.uint8),
        FilterImaged(keys="image", min_size=5),
        AsChannelFirstd(keys="image"),
        AddChanneld(keys="label"),
        ToTensord(keys="image"),
        TorchVisiond(
            keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04
        ),
        ToNumpyd(keys="image"),
        RandRotate90d(keys=("image", "label"), prob=0.5, spatial_axes=(0, 1)),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        AddInitialSeedPointExd(label="label", guidance="guidance"),
        AddGuidanceSignald(image="image", guidance="guidance", number_intensity_ch=3),
        EnsureTyped(keys=("image", "label")),
    ]

def run_interaction(self, train, compose):
    data = []
    for i in range(5):
        data.append({"image": torch.tensor([float(i)]), "label": torch.tensor([float(i)])})
    network = torch.nn.Linear(1, 1)
    lr = 1e-3
    opt = torch.optim.SGD(network.parameters(), lr)
    loss = torch.nn.L1Loss()
    dataset = Dataset(data, transform=None)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)

    iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]
    iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms

    i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5)
    self.assertEqual(len(i.transforms.transforms), 2, "Mismatch in expected transforms")

    # set up engine
    engine = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=data_loader,
        network=network,
        optimizer=opt,
        loss_function=loss,
        iteration_update=i,
    )
    engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)
    engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)

    engine.run()
    self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing")
    self.assertEqual(engine.state.best_metric, 9)

def test_numpy_input(self):
    test_data = np.array([[1, 2], [3, 4]])
    test_data = np.rot90(test_data)
    self.assertFalse(test_data.flags["C_CONTIGUOUS"])
    result = ToNumpyd(keys="img")({"img": test_data})["img"]
    self.assertTrue(isinstance(result, np.ndarray))
    self.assertTrue(result.flags["C_CONTIGUOUS"])
    np.testing.assert_allclose(result, test_data)

def test_tensor_input(self):
    test_data = torch.tensor([[1, 2], [3, 4]])
    test_data = test_data.rot90()
    self.assertFalse(test_data.is_contiguous())
    result = ToNumpyd(keys="img")({"img": test_data})["img"]
    self.assertTrue(isinstance(result, np.ndarray))
    self.assertTrue(result.flags["C_CONTIGUOUS"])
    np.testing.assert_allclose(result, test_data.numpy())

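# NOTE: not part of the tests above, just a standalone illustration of the
# contiguity behaviour they pin down. rot90 returns a strided view/tensor that
# is no longer C-contiguous; np.ascontiguousarray (or torch's .contiguous())
# produces the contiguous copy that ToNumpyd's output is asserted to be.
import numpy as np
import torch

a = np.rot90(np.array([[1, 2], [3, 4]]))
assert not a.flags["C_CONTIGUOUS"]
assert np.ascontiguousarray(a).flags["C_CONTIGUOUS"]

t = torch.tensor([[1, 2], [3, 4]]).rot90()
assert not t.is_contiguous()
assert t.contiguous().numpy().flags["C_CONTIGUOUS"]
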
def post_transforms(self, data=None):
    return [
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5),
        ToNumpyd(keys="pred"),
        RestoreLabeld(keys="pred", ref_image="image", mode="nearest"),
        AsChannelLastd(keys="pred"),
    ]

def post_transforms(self, data=None) -> Sequence[Callable]:
    return [
        EnsureTyped(keys="pred", device=data.get("device") if data else None),
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys="pred", argmax=True),
        SqueezeDimd(keys="pred", dim=0),
        ToNumpyd(keys="pred"),
        Restored(keys="pred", ref_image="image"),
    ]

def post_transforms(self, data=None) -> Sequence[Callable]:
    return [
        EnsureTyped(keys="pred", device=data.get("device") if data else None),
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys="pred", argmax=True),
        ToNumpyd(keys="pred"),
        Restored(keys="pred", ref_image="image"),
        BoundingBoxd(keys="pred", result="result", bbox="bbox"),
    ]

def post_transforms(self, data=None) -> Sequence[Callable]:
    return [
        EnsureTyped(keys="pred", device=data.get("device") if data else None),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5),
        ToNumpyd(keys="pred"),
        RestoreLabeld(keys="pred", ref_image="image", mode="nearest"),
        AsChannelLastd(keys="pred"),
    ]

def post_transforms(self, data=None) -> Sequence[Callable]:
    return [
        EnsureTyped(keys="pred", device=data.get("device") if data else None),
        Activationsd(keys="pred", softmax=len(self.labels) > 1, sigmoid=len(self.labels) == 1),
        AsDiscreted(keys="pred", argmax=len(self.labels) > 1, threshold=0.5 if len(self.labels) == 1 else None),
        SqueezeDimd(keys="pred", dim=0),
        ToNumpyd(keys=("image", "pred")),
        PostFilterLabeld(keys="pred", image="image"),
        FindContoursd(keys="pred", labels=self.labels),
    ]

def _default_transforms(image_key, label_key, pixdim):
    keys = [image_key] if label_key is None else [image_key, label_key]
    mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR]
    return Compose(
        [
            LoadImaged(keys=keys),
            AsChannelFirstd(keys=keys),
            Orientationd(keys=keys, axcodes="RAS"),
            Spacingd(keys=keys, pixdim=pixdim, mode=mode),
            FromMetaTensord(keys=keys),
            ToNumpyd(keys=keys + [PostFix.meta(k) for k in keys]),
        ]
    )

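# NOTE: a hypothetical call site for _default_transforms (file names and pixdim
# are illustrative, not from the source). With both keys present, the image is
# resampled bilinearly and the label with nearest-neighbour interpolation, as
# selected by the `mode` list above.
xform = _default_transforms("image", "label", pixdim=(1.0, 1.0, 1.0))
# data = xform({"image": "img.nii.gz", "label": "seg.nii.gz"})  # needs real files on disk
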
def get_click_transforms(self, context: Context):
    return [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=("image", "label", "pred")),
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        ToTensord(keys=("image", "label")),
    ]

def get_click_transforms(): return Compose([ Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys=("image", "label", "pred")), FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"), AddRandomGuidanced( guidance="guidance", discrepancy="discrepancy", probability="probability", ), AddGuidanceSignald(image="image", guidance="guidance"), EnsureTyped(keys=("image", "label")), ])
def get_click_transforms():
    return Compose(
        [
            Activationsd(keys='pred', sigmoid=True),
            ToNumpyd(keys=('image', 'label', 'pred', 'probability', 'guidance')),
            FindDiscrepancyRegionsd(label='label', pred='pred', discrepancy='discrepancy', batched=True),
            AddRandomGuidanced(guidance='guidance', discrepancy='discrepancy', probability='probability', batched=True),
            AddGuidanceSignald(image='image', guidance='guidance', batched=True),
            ToTensord(keys=('image', 'label')),
        ]
    )

def get_click_transforms(self, context: Context):
    return [
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys="pred", argmax=True),
        ToNumpyd(keys=("image", "label", "pred")),
        # Transforms for click simulation
        FindDiscrepancyRegionsCustomd(keys="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanceCustomd(
            keys="NA",
            guidance="guidance",
            discrepancy="discrepancy",
            probability="probability",
        ),
        AddGuidanceSignalCustomd(keys="image", guidance="guidance", number_intensity_ch=self.number_intensity_ch),
        # ToTensord(keys=("image", "label")),
    ]

def post_transforms(self, data=None) -> Sequence[Callable]:
    largest_cc = False if not data else data.get("largest_cc", False)
    applied_labels = list(self.labels.values()) if isinstance(self.labels, dict) else self.labels
    t = [
        EnsureTyped(keys="pred", device=data.get("device") if data else None),
        Activationsd(keys="pred", softmax=len(self.labels) > 1, sigmoid=len(self.labels) == 1),
        AsDiscreted(keys="pred", argmax=len(self.labels) > 1, threshold=0.5 if len(self.labels) == 1 else None),
    ]
    if largest_cc:
        t.append(KeepLargestConnectedComponentd(keys="pred", applied_labels=applied_labels))
    t.extend(
        [
            ToNumpyd(keys="pred"),
            Restored(keys="pred", ref_image="image"),
        ]
    )
    return t

def test_values(self):
    with tempfile.TemporaryDirectory() as tempdir:
        test_data1 = [
            ["subject_id", "label", "image", "ehr_0", "ehr_1", "ehr_2"],
            ["s000000", 5, "./imgs/s000000.png", 2.007843256, 2.29019618, 2.054902077],
            ["s000001", 0, "./imgs/s000001.png", 6.839215755, 6.474509716, 5.862744808],
            ["s000002", 4, "./imgs/s000002.png", 3.772548914, 4.211764812, 4.635294437],
            ["s000003", 1, "./imgs/s000003.png", 3.333333254, 3.235294342, 3.400000095],
            ["s000004", 9, "./imgs/s000004.png", 6.427451134, 6.254901886, 5.976470947],
        ]
        test_data2 = [
            ["subject_id", "ehr_3", "ehr_4", "ehr_5", "ehr_6", "ehr_7", "ehr_8"],
            ["s000000", 3.019608021, 3.807843208, 3.584313869, 3.141176462, 3.1960783, 4.211764812],
            ["s000001", 5.192157269, 5.274509907, 5.250980377, 4.647058964, 4.886274338, 4.392156601],
            ["s000002", 5.298039436, 9.545097351, 12.57254887, 6.799999714, 2.1960783, 1.882352948],
            ["s000003", 3.164705753, 3.086274624, 3.725490093, 3.698039293, 3.698039055, 3.701960802],
            ["s000004", 6.26274538, 7.717647076, 9.584313393, 6.082352638, 2.662744999, 2.34117651],
        ]
        test_data3 = [
            ["subject_id", "ehr_9", "ehr_10", "meta_0", "meta_1", "meta_2"],
            ["s000000", 6.301961422, 6.470588684, "TRUE", "TRUE", "TRUE"],
            ["s000001", 5.219608307, 7.827450752, "FALSE", "TRUE", "FALSE"],
            ["s000002", 1.882352948, 2.031372547, "TRUE", "FALSE", "TRUE"],
            ["s000003", 3.309803963, 3.729412079, "FALSE", "FALSE", "TRUE"],
            ["s000004", 2.062745094, 2.34117651, "FALSE", "TRUE", "TRUE"],
        ]

        def prepare_csv_file(data, filepath):
            with open(filepath, "a") as f:
                for d in data:
                    f.write((",".join([str(i) for i in d])) + "\n")

        filepath1 = os.path.join(tempdir, "test_data1.csv")
        filepath2 = os.path.join(tempdir, "test_data2.csv")
        filepath3 = os.path.join(tempdir, "test_data3.csv")
        prepare_csv_file(test_data1, filepath1)
        prepare_csv_file(test_data2, filepath2)
        prepare_csv_file(test_data3, filepath3)

        # test single CSV file
        dataset = CSVIterableDataset(filepath1)
        for i, item in enumerate(dataset):
            if i == 2:
                self.assertDictEqual(
                    {k: round(v, 4) if not isinstance(v, str) else v for k, v in item.items()},
                    {
                        "subject_id": "s000002",
                        "label": 4,
                        "image": "./imgs/s000002.png",
                        "ehr_0": 3.7725,
                        "ehr_1": 4.2118,
                        "ehr_2": 4.6353,
                    },
                )
                break

        # test reset iterables
        dataset.reset(filename=filepath3)
        for i, item in enumerate(dataset):
            if i == 3:
                self.assertEqual(item["meta_0"], False)

        # test multiple CSV files, join tables with kwargs
        dataset = CSVIterableDataset([filepath1, filepath2, filepath3], on="subject_id")
        for i, item in enumerate(dataset):
            if i == 3:
                self.assertDictEqual(
                    {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()},
                    {
                        "subject_id": "s000003",
                        "label": 1,
                        "image": "./imgs/s000003.png",
                        "ehr_0": 3.3333,
                        "ehr_1": 3.2353,
                        "ehr_2": 3.4000,
                        "ehr_3": 3.1647,
                        "ehr_4": 3.0863,
                        "ehr_5": 3.7255,
                        "ehr_6": 3.6980,
                        "ehr_7": 3.6980,
                        "ehr_8": 3.7020,
                        "ehr_9": 3.3098,
                        "ehr_10": 3.7294,
                        "meta_0": False,
                        "meta_1": False,
                        "meta_2": True,
                    },
                )

        # test selected columns and chunk size
        dataset = CSVIterableDataset(
            filename=[filepath1, filepath2, filepath3],
            chunksize=2,
            col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"],
        )
        for i, item in enumerate(dataset):
            if i == 3:
                self.assertDictEqual(
                    {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()},
                    {
                        "subject_id": "s000003",
                        "image": "./imgs/s000003.png",
                        "ehr_1": 3.2353,
                        "ehr_7": 3.6980,
                        "meta_1": False,
                    },
                )

        # test group columns
        dataset = CSVIterableDataset(
            filename=[filepath1, filepath2, filepath3],
            col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"],
            col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]},
        )
        for i, item in enumerate(dataset):
            if i == 3:
                np.testing.assert_allclose(
                    [round(i, 4) for i in item["ehr"]],
                    [3.3333, 3.2353, 3.4000, 3.1647, 3.0863, 3.7255, 3.6980, 3.6980, 3.7020, 3.3098, 3.7294],
                )
                np.testing.assert_allclose(item["meta12"], [False, True])

        # test transform
        dataset = CSVIterableDataset(
            chunksize=2,
            buffer_size=4,
            filename=[filepath1, filepath2, filepath3],
            col_groups={"ehr": [f"ehr_{i}" for i in range(5)]},
            transform=ToNumpyd(keys="ehr"),
            shuffle=True,
        )
        dataset.set_random_state(123)
        expected = [
            [3.7725, 4.2118, 4.6353, 5.298, 9.5451],
            [2.0078, 2.2902, 2.0549, 3.0196, 3.8078],
            [6.4275, 6.2549, 5.9765, 6.2627, 7.7176],
            [6.8392, 6.4745, 5.8627, 5.1922, 5.2745],
        ]
        for item, exp in zip(dataset, expected):
            self.assertTrue(isinstance(item["ehr"], np.ndarray))
            np.testing.assert_allclose(np.around(item["ehr"], 4), exp)

        # test multiple processes loading
        dataset = CSVIterableDataset(filepath1, transform=ToNumpyd(keys="label"))
        # set num workers = 0 for mac / win
        num_workers = 2 if sys.platform == "linux" else 0
        dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=2)
        for item in dataloader:
            # test the last item which only has 1 data
            if len(item) == 1:
                self.assertListEqual(item["subject_id"], ["s000002"])
                np.testing.assert_allclose(item["label"], [4])
                self.assertListEqual(item["image"], ["./imgs/s000002.png"])

def test_values(self):
    with tempfile.TemporaryDirectory() as tempdir:
        test_data1 = [
            ["subject_id", "label", "image", "ehr_0", "ehr_1", "ehr_2"],
            ["s000000", 5, "./imgs/s000000.png", 2.007843256, 2.29019618, 2.054902077],
            ["s000001", 0, "./imgs/s000001.png", 6.839215755, 6.474509716, 5.862744808],
            ["s000002", 4, "./imgs/s000002.png", 3.772548914, 4.211764812, 4.635294437],
            ["s000003", 1, "./imgs/s000003.png", 3.333333254, 3.235294342, 3.400000095],
            ["s000004", 9, "./imgs/s000004.png", 6.427451134, 6.254901886, 5.976470947],
        ]
        test_data2 = [
            ["subject_id", "ehr_3", "ehr_4", "ehr_5", "ehr_6", "ehr_7", "ehr_8"],
            ["s000000", 3.019608021, 3.807843208, 3.584313869, 3.141176462, 3.1960783, 4.211764812],
            ["s000001", 5.192157269, 5.274509907, 5.250980377, 4.647058964, 4.886274338, 4.392156601],
            ["s000002", 5.298039436, 9.545097351, 12.57254887, 6.799999714, 2.1960783, 1.882352948],
            ["s000003", 3.164705753, 3.086274624, 3.725490093, 3.698039293, 3.698039055, 3.701960802],
            ["s000004", 6.26274538, 7.717647076, 9.584313393, 6.082352638, 2.662744999, 2.34117651],
        ]
        test_data3 = [
            ["subject_id", "ehr_9", "ehr_10", "meta_0", "meta_1", "meta_2"],
            ["s000000", 6.301961422, 6.470588684, "TRUE", "TRUE", "TRUE"],
            ["s000001", 5.219608307, 7.827450752, "FALSE", "TRUE", "FALSE"],
            ["s000002", 1.882352948, 2.031372547, "TRUE", "FALSE", "TRUE"],
            ["s000003", 3.309803963, 3.729412079, "FALSE", "FALSE", "TRUE"],
            ["s000004", 2.062745094, 2.34117651, "FALSE", "TRUE", "TRUE"],
            # generate NaN values in the row
            ["s000005", 3.353655643, 1.675674543, "TRUE", "TRUE", "FALSE"],
        ]

        def prepare_csv_file(data, filepath):
            with open(filepath, "a") as f:
                for d in data:
                    f.write((",".join([str(i) for i in d])) + "\n")

        filepath1 = os.path.join(tempdir, "test_data1.csv")
        filepath2 = os.path.join(tempdir, "test_data2.csv")
        filepath3 = os.path.join(tempdir, "test_data3.csv")
        filepaths = [filepath1, filepath2, filepath3]
        prepare_csv_file(test_data1, filepath1)
        prepare_csv_file(test_data2, filepath2)
        prepare_csv_file(test_data3, filepath3)

        # test single CSV file
        dataset = CSVDataset(filepath1)
        self.assertDictEqual(
            {k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()},
            {
                "subject_id": "s000002",
                "label": 4,
                "image": "./imgs/s000002.png",
                "ehr_0": 3.7725,
                "ehr_1": 4.2118,
                "ehr_2": 4.6353,
            },
        )

        # test multiple CSV files, join tables with kwargs
        dataset = CSVDataset(filepaths, on="subject_id")
        self.assertDictEqual(
            {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in dataset[3].items()},
            {
                "subject_id": "s000003",
                "label": 1,
                "image": "./imgs/s000003.png",
                "ehr_0": 3.3333,
                "ehr_1": 3.2353,
                "ehr_2": 3.4000,
                "ehr_3": 3.1647,
                "ehr_4": 3.0863,
                "ehr_5": 3.7255,
                "ehr_6": 3.6980,
                "ehr_7": 3.6980,
                "ehr_8": 3.7020,
                "ehr_9": 3.3098,
                "ehr_10": 3.7294,
                "meta_0": False,
                "meta_1": False,
                "meta_2": True,
            },
        )

        # test selected rows and columns
        dataset = CSVDataset(
            src=filepaths,
            row_indices=[[0, 2], 3],  # load row: 0, 1, 3
            col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"],
        )
        self.assertEqual(len(dataset), 3)
        self.assertDictEqual(
            {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in dataset[-1].items()},
            {
                "subject_id": "s000003",
                "image": "./imgs/s000003.png",
                "ehr_1": 3.2353,
                "ehr_7": 3.6980,
                "meta_1": False,
            },
        )

        # test group columns
        dataset = CSVDataset(
            src=filepaths,
            row_indices=[1, 3],  # load row: 1, 3
            col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"],
            col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]},
        )
        np.testing.assert_allclose(
            [round(i, 4) for i in dataset[-1]["ehr"]],
            [3.3333, 3.2353, 3.4000, 3.1647, 3.0863, 3.7255, 3.6980, 3.6980, 3.7020, 3.3098, 3.7294],
        )
        np.testing.assert_allclose(dataset[-1]["meta12"], [False, True])

        # test transform
        dataset = CSVDataset(
            src=filepaths,
            col_groups={"ehr": [f"ehr_{i}" for i in range(5)]},
            transform=ToNumpyd(keys="ehr"),
        )
        self.assertEqual(len(dataset), 5)
        expected = [
            [2.0078, 2.2902, 2.0549, 3.0196, 3.8078],
            [6.8392, 6.4745, 5.8627, 5.1922, 5.2745],
            [3.7725, 4.2118, 4.6353, 5.2980, 9.5451],
            [3.3333, 3.2353, 3.4000, 3.1647, 3.0863],
            [6.4275, 6.2549, 5.9765, 6.2627, 7.7176],
        ]
        for item, exp in zip(dataset, expected):
            self.assertTrue(isinstance(item["ehr"], np.ndarray))
            np.testing.assert_allclose(np.around(item["ehr"], 4), exp)

        # test default values and dtype
        dataset = CSVDataset(
            src=filepaths,
            col_names=["subject_id", "image", "ehr_1", "ehr_9", "meta_1"],
            col_types={"image": {"type": str, "default": "No image"}, "ehr_1": {"type": int, "default": 0}},
            how="outer",  # generate NaN values in this merge mode
        )
        self.assertEqual(len(dataset), 6)
        self.assertEqual(dataset[-1]["image"], "No image")
        self.assertEqual(type(dataset[-1]["ehr_1"]), int)
        np.testing.assert_allclose(dataset[-1]["ehr_9"], 3.3537, rtol=1e-2)

        # test pre-loaded DataFrame
        df = pd.read_csv(filepath1)
        dataset = CSVDataset(src=df)
        self.assertDictEqual(
            {k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()},
            {
                "subject_id": "s000002",
                "label": 4,
                "image": "./imgs/s000002.png",
                "ehr_0": 3.7725,
                "ehr_1": 4.2118,
                "ehr_2": 4.6353,
            },
        )

        # test pre-loaded multiple DataFrames, join tables with kwargs
        dfs = [pd.read_csv(i) for i in filepaths]
        dataset = CSVDataset(src=dfs, on="subject_id")
        self.assertEqual(dataset[3]["subject_id"], "s000003")
        self.assertEqual(dataset[3]["label"], 1)
        self.assertEqual(round(dataset[3]["ehr_0"], 4), 3.3333)
        self.assertEqual(dataset[3]["meta_0"], False)
