def test_array_values(self):
    input_data = {"img": [[0, 1], [1, 2]], "seg": [[0, 1], [1, 2]]}
    result = CopyItemsd(keys="img", times=1, names="img_1")(input_data)
    self.assertTrue("img_1" in result)
    result["img_1"][0][0] += 1
    np.testing.assert_allclose(result["img"], [[0, 1], [1, 2]])
    np.testing.assert_allclose(result["img_1"], [[1, 1], [1, 2]])
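# Illustrative standalone sketch (assumes MONAI and NumPy are installed): CopyItemsd
# performs a deep copy, which is why mutating the copy above leaves the source key
# untouched.
import numpy as np
from monai.transforms import CopyItemsd

data = {"img": np.array([[0, 1], [1, 2]])}
data = CopyItemsd(keys="img", times=1, names="img_1")(data)
data["img_1"] += 1
assert np.array_equal(data["img"], [[0, 1], [1, 2]])    # original unchanged
assert np.array_equal(data["img_1"], [[1, 2], [2, 3]])  # only the copy changed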
def test_correct(self):
    with tempfile.TemporaryDirectory() as temp_dir:
        transforms = Compose([
            LoadImaged(("im1", "im2")),
            EnsureChannelFirstd(("im1", "im2")),
            CopyItemsd(("im2", "im2_meta_dict"), names=("im3", "im3_meta_dict")),
            ResampleToMatchd("im3", "im1_meta_dict"),
            Lambda(update_fname),
            SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False),
        ])
        data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
        # check that output sizes match
        assert_allclose(data["im1"].shape, data["im3"].shape)
        # and that the meta data has been updated accordingly
        assert_allclose(data["im3"].shape[1:], data["im3_meta_dict"]["spatial_shape"], type_test=False)
        assert_allclose(data["im3_meta_dict"]["affine"], data["im1_meta_dict"]["affine"])
        # check we're different from the original
        self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
        self.assertTrue(
            any(i != j for i, j in zip(data["im3_meta_dict"]["affine"].flatten(),
                                       data["im2_meta_dict"]["affine"].flatten()))
        )
        # test the inverse
        data = Invertd("im3", transforms, "im3")(data)
        assert_allclose(data["im2"].shape, data["im3"].shape)
def test_correct(self):
    transforms = Compose([
        LoadImaged(("im1", "im2")),
        EnsureChannelFirstd(("im1", "im2")),
        CopyItemsd("im2", names="im3"),
        ResampleToMatchd("im3", "im1"),
        Lambda(update_fname),
        SaveImaged("im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, resample=False),
    ])
    data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
    # check that output sizes match
    assert_allclose(data["im1"].shape, data["im3"].shape)
    # and that the meta data has been updated accordingly
    assert_allclose(data["im3"].affine, data["im1"].affine)
    # check we're different from the original
    self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
    self.assertTrue(any(i != j for i, j in zip(data["im3"].affine.flatten(), data["im2"].affine.flatten())))
    # test the inverse
    data = Invertd("im3", transforms)(data)
    assert_allclose(data["im2"].shape, data["im3"].shape)
def test_numpy_values(self, keys, times, names):
    input_data = {"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[3, 4], [4, 5]])}
    result = CopyItemsd(keys=keys, times=times, names=names)(input_data)
    for name in ensure_tuple(names):
        self.assertTrue(name in result)
    result["img_1"] += 1
    np.testing.assert_allclose(result["img_1"], np.array([[1, 2], [2, 3]]))
    np.testing.assert_allclose(result["img"], np.array([[0, 1], [1, 2]]))
def test_default_names(self):
    input_data = {"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[3, 4], [4, 5]])}
    result = CopyItemsd(keys=["img", "seg"], times=2, names=None)(input_data)
    for name in ["img_0", "seg_0", "img_1", "seg_1"]:
        self.assertTrue(name in result)
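# Illustrative sketch of the default naming rule exercised above: when `names` is
# None, CopyItemsd derives the copy keys as f"{key}_{i}" for i in range(times).
import numpy as np
from monai.transforms import CopyItemsd

d = CopyItemsd(keys=["img", "seg"], times=2, names=None)({"img": np.zeros(2), "seg": np.ones(2)})
assert {"img_0", "img_1", "seg_0", "seg_1"}.issubset(d.keys())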
def test_graph_tensor_values(self):
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
    net = torch.nn.PReLU().to(device)
    with eval_mode(net):
        pred = net(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device))
    input_data = {"pred": pred, "seg": torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device)}
    result = CopyItemsd(keys="pred", times=1, names="pred_1")(input_data)
    self.assertTrue("pred_1" in result)
    result["pred_1"] += 1.0
    torch.testing.assert_allclose(result["pred"], torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device))
    torch.testing.assert_allclose(result["pred_1"], torch.tensor([[1.0, 2.0], [2.0, 3.0]], device=device))
def test_tensor_values(self):
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
    input_data = {
        "img": torch.tensor([[0, 1], [1, 2]], device=device),
        "seg": torch.tensor([[0, 1], [1, 2]], device=device),
    }
    result = CopyItemsd(keys="img", times=1, names="img_1")(input_data)
    self.assertTrue("img_1" in result)
    result["img_1"] += 1
    torch.testing.assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]], device=device))
    torch.testing.assert_allclose(result["img_1"], torch.tensor([[1, 2], [2, 3]], device=device))
def test_compute(self):
    data = [
        {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]},
        {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]},
    ]
    handlers = [
        DecollateBatch(event="MODEL_COMPLETED"),
        PostProcessing(
            transform=Compose([
                Activationsd(keys="pred", sigmoid=True),
                CopyItemsd(keys="filename", times=1, names="filename_bak"),
                AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2),
            ])
        ),
    ]
    # set up the engine; the PostProcessing handler works together with the postprocessing transforms of the engine
    engine = SupervisedEvaluator(
        device=torch.device("cpu:0"),
        val_data_loader=data,
        epoch_length=2,
        network=torch.nn.PReLU(),
        # set decollate=False and execute some postprocessing first, then decollate in handlers
        postprocessing=lambda x: dict(pred=x["pred"] + 1.0),
        decollate=False,
        val_handlers=handlers,
    )
    engine.run()

    expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]])
    for o, e in zip(engine.state.output, expected):
        torch.testing.assert_allclose(o["pred"], e)
        filename = o.get("filename_bak")
        if filename is not None:
            self.assertEqual(filename, "test2")
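# Minimal sketch of the step the DecollateBatch handler performs above: a batched
# dict becomes a list of per-sample dicts, so per-item transforms can then run.
import torch
from monai.data import decollate_batch

batch = {"pred": torch.zeros(2, 1, 2, 1), "filename": ["test1", "test2"]}
items = decollate_batch(batch)
assert len(items) == 2 and items[1]["filename"] == "test2"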
def get_xforms_with_synthesis(mode="synthesis", keys=("image", "label"), keys2=("image", "label", "synthetic_lesion")):
    """returns a composed transform for train/val/infer."""
    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
        CopyItemsd(keys, times=1, names=["image_1", "label_1"]),
    ]
    if mode == "synthesis":
        xforms.extend([
            SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"),  # ensure at least 192x192
            RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=(192, 192, 16), num_samples=3),
            TransCustom(keys, path_synthesis, read_cea_aug_slice2,
                        pseudo_healthy_with_texture, scans_syns, decreasing_sequence,
                        GEN=15, POST_PROCESS=True, mask_outer_ring=True, new_value=0.5),
            RandAffined(
                # keys,
                keys2,
                prob=0.15,
                rotate_range=(0.05, 0.05, None),  # 3 parameters control the transform on 3 dimensions
                scale_range=(0.1, 0.1, None),
                mode=("bilinear", "nearest", "bilinear"),
                # mode=("bilinear", "nearest"),
                as_tensor_output=False,
            ),
            RandGaussianNoised((keys2[0], keys2[2]), prob=0.15, std=0.01),
            # RandGaussianNoised(keys[0], prob=0.15, std=0.01),
            RandFlipd(keys, spatial_axis=0, prob=0.5),
            RandFlipd(keys, spatial_axis=1, prob=0.5),
            RandFlipd(keys, spatial_axis=2, prob=0.5),
            TransCustom2(0.333),
        ])
    dtype = (np.float32, np.uint8)
    # dtype = (np.float32, np.uint8, np.float32)
    xforms.extend([CastToTyped(keys, dtype=dtype)])
    return monai.transforms.Compose(xforms)
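# Illustrative aside on the CopyItemsd call above (an assumption worth checking
# against your MONAI version): with multiple keys, `names` must contain
# len(keys) * times entries, grouped per repetition in the same order as `keys`.
import numpy as np
from monai.transforms import CopyItemsd

d = CopyItemsd(keys=["image", "label"], times=1, names=["image_1", "label_1"])(
    {"image": np.zeros(3), "label": np.ones(3)}
)
assert "image_1" in d and "label_1" in d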
def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
    transform = Compose([
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        # test EnsureTyped for complicated dict data and invert it
        CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
        # test to support Tensor, Numpy array and dictionary when inverting
        EnsureTyped(keys=["image", "test_dict"]),
        ToTensord("image"),
        CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),
        CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]),
    ])
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for non-linux platforms or gpu transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
    inverter = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted", "label_inverted", "test_dict"],
        transform=transform,
        orig_keys=["label", "label", "test_dict"],
        meta_keys=[PostFix.meta("image_inverted"), PostFix.meta("label_inverted"), None],
        orig_meta_keys=[PostFix.meta("label"), PostFix.meta("label"), None],
        nearest_interp=True,
        to_tensor=[True, False, False],
        device="cpu",
    )
    inverter_1 = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted1", "label_inverted1"],
        transform=transform,
        orig_keys=["image", "image"],
        meta_keys=[PostFix.meta("image_inverted1"), PostFix.meta("label_inverted1")],
        orig_meta_keys=[PostFix.meta("image"), PostFix.meta("image")],
        nearest_interp=[True, False],
        to_tensor=[True, True],
        device="cpu",
    )
    expected_keys = [
        "image",
        "image_inverted",
        "image_inverted1",
        PostFix.meta("image_inverted1"),
        PostFix.meta("image_inverted"),
        PostFix.meta("image"),
        "image_transforms",
        "label",
        "label_inverted",
        "label_inverted1",
        PostFix.meta("label_inverted1"),
        PostFix.meta("label_inverted"),
        PostFix.meta("label"),
        "label_transforms",
        "test_dict",
        "test_dict_transforms",
    ]
    # execute 1 epoch
    for d in loader:
        d = decollate_batch(d)
        for item in d:
            item = inverter(item)
            item = inverter_1(item)

            self.assertListEqual(sorted(item), expected_keys)
            self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
            self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
            # check the nearest interpolation mode
            i = item["image_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            i = item["label_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            # test inverted test_dict
            self.assertTrue(isinstance(item["test_dict"]["affine"], np.ndarray))
            self.assertTrue(isinstance(item["test_dict"]["filename_or_obj"], str))

            # check the case that different items use different interpolation mode to invert transforms
            d = item["image_inverted1"]
            # if the interpolation mode is nearest, accumulated diff should be smaller than 1
            self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))

            d = item["label_inverted1"]
            # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
            self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))

    # check labels match
    reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    # 25300: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1821: windows torch 1.10.0
    self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821), f"diff. {reverted.size - n_good}")

    set_determinism(seed=None)
def main():
    # TODO: define file paths & output directory
    json_Path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_Root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    if not os.path.exists(logdir_path):
        os.mkdir(logdir_path)

    # Load JSON & Append Root Path
    with open(json_Path, 'r') as json_f:
        json_Data = json.load(json_f)

    train_Data = json_Data['training']
    val_Data = json_Data['validation']

    for idx, each_d in enumerate(train_Data):
        train_Data[idx]['image'] = os.path.join(data_Root, train_Data[idx]['image'])

    for idx, each_d in enumerate(val_Data):
        val_Data[idx]['image'] = os.path.join(data_Root, val_Data[idx]['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_Data)))
    print(train_Data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_Data)))
    print(val_Data)
    print('#' * 10)

    # Set Determinism
    set_determinism(seed=123)

    # Define Training Transforms
    train_Transforms = Compose(
        [
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode="bilinear"),
            ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["image"], source_key="image"),
            SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
            RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), random_size=False, num_samples=2),
            CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=5,
                                   dropout_holes=True, max_spatial_size=32),
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=20,
                                   dropout_holes=False, max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
            # Please note that if image and image_2 were augmented via the same transform call,
            # determinism would make them identical, which is not wanted here; hence two separate calls
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=5,
                                   dropout_holes=True, max_spatial_size=32),
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=20,
                                   dropout_holes=False, max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8),
        ]
    )

    check_ds = Dataset(data=train_Data, transform=train_Transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = check_data["image"][0][0]
    print(f"image shape: {image.shape}")

    # Define Network ViT backbone & Loss & Optimizer
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
        in_channels=1,
        img_size=(96, 96, 96),
        patch_size=(16, 16, 16),
        pos_embed='conv',
        hidden_size=768,
        mlp_dim=3072,
    )
    model = model.to(device)

    # Define Hyper-parameters for the training loop
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    best_val_loss = 1000.0

    recon_loss = L1Loss()
    contrastive_loss = ContrastiveLoss(batch_size=batch_size * 2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Define DataLoaders using MONAI (CacheDataset could be used instead to speed up loading)
    train_ds = Dataset(data=train_Data, transform=train_Transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    val_ds = Dataset(data=val_Data, transform=train_Transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # Adjust the CL loss by the Recon Loss
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # CL & Recon Loss Storage of Value
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time - start_time}s")

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch + 1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            for val_batch in val_loader:
                val_step += 1
                start_time = time.time()
                inputs, gt_input = (
                    val_batch["image"].to(device),
                    val_batch["gt_image"].to(device),
                )
                print('Input shape: {}'.format(inputs.shape))
                outputs, outputs_v2 = model(inputs)
                val_loss = recon_loss(outputs, gt_input)
                total_val_loss += val_loss.item()
                end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, "
                  f"time taken: {end_time - start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                checkpoint = {
                    'epoch': max_epochs,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None
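# Hedged sketch (illustrative shapes and keys) of the two-view pattern used above:
# copy the clean crop first with CopyItemsd, then corrupt "image" and "image_2"
# through separate random-transform instances so the two views (almost surely)
# differ, while "gt_image" stays clean for the reconstruction loss.
import numpy as np
from monai.transforms import Compose, CopyItemsd, RandGaussianNoised

two_view = Compose([
    CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"]),
    RandGaussianNoised(keys=["image"], prob=1.0, std=0.1),    # view 1
    RandGaussianNoised(keys=["image_2"], prob=1.0, std=0.1),  # view 2, independent draw
])
out = two_view({"image": np.zeros((1, 4, 4, 4), dtype=np.float32)})
assert np.allclose(out["gt_image"], 0.0)              # untouched copy
assert not np.allclose(out["image"], out["image_2"])  # views differ with high probability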
import unittest

import torch
from parameterized import parameterized

from monai.engines import SupervisedEvaluator
from monai.handlers import PostProcessing
from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd

# test lambda function as `transform`
TEST_CASE_1 = [{"transform": lambda x: dict(pred=x["pred"] + 1.0)}, torch.tensor([[[[1.9975], [1.9997]]]])]

# test composed post transforms as `transform`
TEST_CASE_2 = [
    {
        "transform": Compose([
            CopyItemsd(keys="filename", times=1, names="filename_bak"),
            AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2),
        ])
    },
    torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]),
]


class TestHandlerPostProcessing(unittest.TestCase):
    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
    def test_compute(self, input_params, expected):
        data = [
            {
def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)]
    transform = Compose([
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        ToTensord("image"),  # test to support both Tensor and Numpy array when inverting
        CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        CopyItemsd("label", times=1, names="label_inverted"),
    ])
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2
    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
    inverter = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image", "label_inverted"],
        transform=transform,
        loader=loader,
        orig_keys="label",
        meta_keys=["image_meta_dict", "label_inverted_meta_dict"],
        orig_meta_keys="label_meta_dict",
        nearest_interp=True,
        to_tensor=[True, False],
        device="cpu",
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    )

    # execute 1 epoch
    for d in loader:
        d = inverter(d)
        # this unit test only covers basic function, test_handler_transform_inverter covers more
        self.assertTupleEqual(d["label"].shape[1:], (1, 100, 100, 100))
        # check the nearest interpolation mode
        for i in d["image"]:
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
        for i in d["label_inverted"]:
            np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    set_determinism(seed=None)
def test_saved_content(self):
    with tempfile.TemporaryDirectory() as tempdir:
        data = [
            {
                "pred": torch.zeros(8),
                PostFix.meta("image"): {"filename_or_obj": ["testfile" + str(i) for i in range(8)]},
            },
            {
                "pred": torch.zeros(8),
                PostFix.meta("image"): {"filename_or_obj": ["testfile" + str(i) for i in range(8, 16)]},
            },
            {
                "pred": torch.zeros(8),
                PostFix.meta("image"): {"filename_or_obj": ["testfile" + str(i) for i in range(16, 24)]},
            },
        ]

        saver = CSVSaver(output_dir=Path(tempdir), filename="predictions2.csv",
                         overwrite=False, flush=False, delimiter="\t")
        # set up test transforms
        post_trans = Compose([
            CopyItemsd(keys=PostFix.meta("image"), times=1, names=PostFix.meta("pred")),
            # 1st saver saves data into a CSV file
            SaveClassificationd(
                keys="pred",
                saver=None,
                meta_keys=None,
                output_dir=Path(tempdir),
                filename="predictions1.csv",
                delimiter="\t",
                overwrite=True,
            ),
            # 2nd saver only saves data into the cache, manually finalize later
            SaveClassificationd(keys="pred", saver=saver, meta_key_postfix=PostFix.meta()),
        ])
        # simulate inference: 2 iterations
        d = decollate_batch(data[0])
        for i in d:
            post_trans(i)
        d = decollate_batch(data[1])
        for i in d:
            post_trans(i)
        # write into CSV file
        saver.finalize()

        # 3rd saver will not delete previous data due to `overwrite=False`
        trans2 = SaveClassificationd(
            keys="pred",
            saver=None,
            meta_keys=PostFix.meta("image"),  # specify meta key, so no need to copy anymore
            output_dir=tempdir,
            filename="predictions1.csv",
            delimiter="\t",
            overwrite=False,
        )
        d = decollate_batch(data[2])
        for i in d:
            trans2(i)

        def _test_file(filename, count):
            filepath = os.path.join(tempdir, filename)
            self.assertTrue(os.path.exists(filepath))
            with open(filepath) as f:
                reader = csv.reader(f, delimiter="\t")
                i = 0
                for row in reader:
                    self.assertEqual(row[0], "testfile" + str(i))
                    self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)
                    i += 1
                self.assertEqual(i, count)

        _test_file("predictions1.csv", 24)
        _test_file("predictions2.csv", 16)
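# Illustrative aside (assumes MONAI with monai.utils.PostFix): CopyItemsd deep-copies
# arbitrary values, including plain meta dictionaries, which is what lets the test
# above mirror "image_meta_dict" into "pred_meta_dict" before saving.
from monai.transforms import CopyItemsd
from monai.utils import PostFix

d = {PostFix.meta("image"): {"filename_or_obj": "testfile0"}}
d = CopyItemsd(keys=PostFix.meta("image"), times=1, names=PostFix.meta("pred"))(d)
d[PostFix.meta("pred")]["filename_or_obj"] = "renamed"
assert d[PostFix.meta("image")]["filename_or_obj"] == "testfile0"  # source unchanged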
def main():
    parser = argparse.ArgumentParser(description="training")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint full path",
    )
    parser.add_argument(
        "--factor_ram_cost",
        default=0.0,
        type=float,
        help="factor to determine RAM cost in the searched architecture",
    )
    parser.add_argument(
        "--fold",
        action="store",
        required=True,
        help="fold index in N-fold cross-validation",
    )
    parser.add_argument(
        "--json",
        action="store",
        required=True,
        help="full path of .json file",
    )
    parser.add_argument(
        "--json_key",
        action="store",
        required=True,
        help="selected key in .json data list",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        required=True,
        help="local process rank",
    )
    parser.add_argument(
        "--num_folds",
        action="store",
        required=True,
        help="number of folds in cross-validation",
    )
    parser.add_argument(
        "--output_root",
        action="store",
        required=True,
        help="output root",
    )
    parser.add_argument(
        "--root",
        action="store",
        required=True,
        help="data root",
    )
    args = parser.parse_args()

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not os.path.exists(args.output_root):
        os.makedirs(args.output_root, exist_ok=True)

    amp = True
    determ = True
    factor_ram_cost = args.factor_ram_cost
    fold = int(args.fold)
    input_channels = 1
    learning_rate = 0.025
    learning_rate_arch = 0.001
    learning_rate_milestones = np.array([0.4, 0.8])
    num_images_per_batch = 1
    num_epochs = 1430  # around 20k iterations
    num_epochs_per_validation = 100
    num_epochs_warmup = 715
    num_folds = int(args.num_folds)
    num_patches_per_image = 1
    num_sw_batch_size = 6
    output_classes = 3
    overlap_ratio = 0.625
    patch_size = (96, 96, 96)
    patch_size_valid = (96, 96, 96)
    spacing = [1.0, 1.0, 1.0]

    print("factor_ram_cost", factor_ram_cost)

    # deterministic training
    if determ:
        set_determinism(seed=0)

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    # dist.barrier()
    world_size = dist.get_world_size()

    with open(args.json, "r") as f:
        json_data = json.load(f)

    split = len(json_data[args.json_key]) // num_folds
    list_train = json_data[args.json_key][:(split * fold)] + json_data[args.json_key][(split * (fold + 1)):]
    list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))]

    # training data
    files = []
    for _i in range(len(list_train)):
        str_img = os.path.join(args.root, list_train[_i]["image"])
        str_seg = os.path.join(args.root, list_train[_i]["label"])
        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue
        files.append({"image": str_img, "label": str_seg})
    train_files = files

    random.shuffle(train_files)

    train_files_w = train_files[: len(train_files) // 2]
    train_files_w = partition_dataset(data=train_files_w, shuffle=True,
                                      num_partitions=world_size, even_divisible=True)[dist.get_rank()]
    print("train_files_w:", len(train_files_w))
    train_files_a = train_files[len(train_files) // 2:]
    train_files_a = partition_dataset(data=train_files_a, shuffle=True,
                                      num_partitions=world_size, even_divisible=True)[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    # validation data
    files = []
    for _i in range(len(list_valid)):
        str_img = os.path.join(args.root, list_valid[_i]["image"])
        str_seg = os.path.join(args.root, list_valid[_i]["label"])
        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue
        files.append({"image": str_img, "label": str_seg})
    val_files = files
    val_files = partition_dataset(data=val_files, shuffle=False,
                                  num_partitions=world_size, even_divisible=False)[dist.get_rank()]
    print("val_files:", len(val_files))

    # network architecture
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=torch.float32),
        ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)),
        CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
        Lambdad(
            keys=["label4crop"],
            func=lambda x: np.concatenate(tuple(
                ndimage.binary_dilation((x == _k).astype(x.dtype), iterations=48).astype(x.dtype)
                for _k in range(output_classes)
            ), axis=0),
            overwrite=True,
        ),
        EnsureTyped(keys=["image", "label"]),
        CastToTyped(keys=["image"], dtype=torch.float32),
        SpatialPadd(keys=["image", "label", "label4crop"], spatial_size=patch_size,
                    mode=["reflect", "constant", "constant"]),
        RandCropByLabelClassesd(keys=["image", "label"], label_key="label4crop",
                                num_classes=output_classes, ratios=[1] * output_classes,
                                spatial_size=patch_size, num_samples=num_patches_per_image),
        Lambdad(keys=["label4crop"], func=lambda x: 0),
        RandRotated(keys=["image", "label"], range_x=0.3, range_y=0.3, range_z=0.3,
                    mode=["bilinear", "nearest"], prob=0.2),
        RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2,
                  mode=["trilinear", "nearest"], align_corners=[True, None], prob=0.16),
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15), prob=0.15),
        RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
        RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5),
        CastToTyped(keys=["image", "label"], dtype=(torch.float32, torch.uint8)),
        ToTensord(keys=["image", "label"]),
    ])

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=torch.float32),
        ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    train_ds_a = monai.data.CacheDataset(data=train_files_a, transform=train_transforms,
                                         cache_rate=1.0, num_workers=8)
    train_ds_w = monai.data.CacheDataset(data=train_files_w, transform=train_transforms,
                                         cache_rate=1.0, num_workers=8)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms,
                                     cache_rate=1.0, num_workers=2)

    # monai.data.Dataset can be used as an alternative when debugging or RAM space is limited:
    # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms)
    # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    train_loader_a = ThreadDataLoader(train_ds_a, num_workers=0, batch_size=num_images_per_batch, shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w, num_workers=0, batch_size=num_images_per_batch, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)

    # DataLoader can be used as an alternative when ThreadDataLoader is less efficient:
    # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())

    dints_space = monai.networks.nets.TopologySearch(
        channel_mul=0.5,
        num_blocks=12,
        num_depths=4,
        use_downsample=True,
        device=device,
    )

    model = monai.networks.nets.DiNTS(
        dints_space=dints_space,
        in_channels=input_channels,
        num_classes=output_classes,
        use_downsample=True,
    )

    model = model.to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=output_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)])

    # loss function
    loss_func = monai.losses.DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        squared_pred=True,
        batch=True,
        smooth_nr=0.00001,
        smooth_dr=0.00001,
    )

    # optimizers
    optimizer = torch.optim.SGD(model.weight_parameters(), lr=learning_rate * world_size,
                                momentum=0.9, weight_decay=0.00004)
    arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a], lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999), weight_decay=0.0)
    arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c], lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999), weight_decay=0.0)

    print()

    if torch.cuda.device_count() > 1:
        if dist.get_rank() == 0:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

    if args.checkpoint is not None and os.path.isfile(args.checkpoint):
        print("[info] fine-tuning pre-trained checkpoint {0:s}".format(args.checkpoint))
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        torch.cuda.empty_cache()
    else:
        print("[info] training from scratch")

    # amp
    if amp:
        from torch.cuda.amp import autocast, GradScaler
        scaler = GradScaler()
        if dist.get_rank() == 0:
            print("[info] amp enabled")

    # start a typical PyTorch training
    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    idx_iter = 0
    metric_values = list()

    if dist.get_rank() == 0:
        writer = SummaryWriter(log_dir=os.path.join(args.output_root, "Events"))

        with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
        decay = 0.5 ** np.sum(
            [(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones]
        )
        lr = learning_rate * decay
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                    else:
                        loss = loss_func(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

            if epoch < num_epochs_warmup:
                continue

            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = sample_a["image"].to(device), sample_a["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # linearly increase the topology and RAM losses
            entropy_alpha_c = torch.tensor(0.0).to(device)
            entropy_alpha_a = torch.tensor(0.0).to(device)
            ram_cost_full = torch.tensor(0.0).to(device)
            ram_cost_usage = torch.tensor(0.0).to(device)
            ram_cost_loss = torch.tensor(0.0).to(device)
            topology_loss = torch.tensor(0.0).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1)
                                * F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape, full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(factor_ram_cost - ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)

            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)

                    loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c)
                                                   + ram_cost_loss + 0.001 * topology_loss)

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)

                loss += 1.0 * (combination_weights * (entropy_alpha_a + entropy_alpha_c)
                               + ram_cost_loss + 0.001 * topology_loss)

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}")
                writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step)

        # synchronize all processes and reduce results
        dist.barrier()
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, "
                f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, "
                    f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]
                    value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)
                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals), axis=0)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        val1 = 1.0 - torch.isnan(value[0, 0]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # synchronize all processes and reduce results
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric = metric.tolist()
                if dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print("evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                        node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode()
                        torch.save(
                            {
                                "node_a": node_a_d,
                                "arch_code_a": arch_code_a_d,
                                "arch_code_a_max": arch_code_a_max_d,
                                "arch_code_c": arch_code_c_d,
                                "iter_num": idx_iter,
                                "epochs": epoch + 1,
                                "best_dsc": best_metric,
                                "best_path": best_metric_iterations,
                            },
                            os.path.join(args.output_root, "search_code_" + str(idx_iter) + ".pth"),
                        )
                        print("saved new best metric model")

                        dict_file = {}
                        dict_file["best_avg_dice_score"] = float(best_metric)
                        dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch)
                        dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                        with open(os.path.join(args.output_root, "progress.yaml"), "w") as out_file:
                            yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, avg_metric, best_metric, best_metric_epoch)
                    )

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(
                                epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter
                            )
                        )

                dist.barrier()

            torch.cuda.empty_cache()

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    if dist.get_rank() == 0:
        writer.close()

    dist.destroy_process_group()

    return
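# Hedged standalone sketch of the `label4crop` trick used in the training transforms
# above: each class mask is binary-dilated and the masks are stacked into one
# multi-channel guide, so RandCropByLabelClassesd can sample crop centers near (not
# only on) each class; the guide is then zeroed out by the Lambdad right after
# cropping. Shapes and the iteration count here are toy values.
import numpy as np
from scipy import ndimage

output_classes = 3
label = np.zeros((1, 8, 8, 8), dtype=np.uint8)  # channel-first toy label map
label[0, 4, 4, 4] = 1
label4crop = np.concatenate(
    [
        ndimage.binary_dilation((label == _k).astype(label.dtype), iterations=2).astype(label.dtype)
        for _k in range(output_classes)
    ],
    axis=0,
)
assert label4crop.shape == (output_classes, 8, 8, 8)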
def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)]
    transform = Compose([
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        ToTensord("image"),  # test to support both Tensor and Numpy array when inverting
        CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        CopyItemsd("label", times=2, names=["label_inverted1", "label_inverted2"]),
        CopyItemsd("image", times=2, names=["image_inverted1", "image_inverted2"]),
    ])
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2
    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

    # set up engine
    def _train_func(engine, batch):
        self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
        engine.state.output = batch
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output

    engine = Engine(_train_func)
    engine.register_events(*IterationEvents)

    # set up testing handler
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image_inverted1", "label_inverted1"],
        batch_keys="label",
        meta_keys=["image_inverted1_meta_dict", "label_inverted1_meta_dict"],
        batch_meta_keys="label_meta_dict",
        nearest_interp=True,
        to_tensor=[True, False],
        device="cpu",
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    ).attach(engine)

    # test different nearest interpolation values
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image_inverted2", "label_inverted2"],
        batch_keys="image",
        meta_keys=None,
        batch_meta_keys="image_meta_dict",
        meta_key_postfix="meta_dict",
        nearest_interp=[True, False],
        post_func=[lambda x: x + 10, lambda x: x],
        collate_fn=pad_list_data_collate,
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    ).attach(engine)

    engine.run(loader, max_epochs=1)
    set_determinism(seed=None)
    self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100))
    self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100))
    # check the nearest interpolation mode
    for i in engine.state.output["image_inverted1"]:
        torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
    for i in engine.state.output["label_inverted1"]:
        np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    # check labels match
    reverted = engine.state.output["label_inverted1"][-1].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = engine.state.batch["label_inverted1_meta_dict"][-1]["filename_or_obj"]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    # 25300: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1824: torch 1.5.1
    self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824), "diff. in 3 possible values")

    # check the case that different items use different interpolation mode to invert transforms
    d = engine.state.output["image_inverted2"]
    # if the interpolation mode is nearest, accumulated diff should be smaller than 1
    self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
    self.assertTupleEqual(d.shape, (2, 1, 100, 101, 107))

    d = engine.state.output["label_inverted2"]
    # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
    self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
    self.assertTupleEqual(d.shape, (2, 1, 100, 101, 107))
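# Aside on the uint8 round-trip check used in these tests: a float tensor survives
# .to(torch.uint8).to(torch.float) unchanged only when its values are already
# integral (and within [0, 255]), so the check separates nearest-neighbour output
# (integral labels) from bilinear output (fractional values).
import torch

nearest_like = torch.tensor([0.0, 1.0, 2.0])
bilinear_like = torch.tensor([0.0, 0.5, 1.5])
assert torch.equal(nearest_like.to(torch.uint8).to(torch.float), nearest_like)
assert not torch.equal(bilinear_like.to(torch.uint8).to(torch.float), bilinear_like)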