def test_exceptions(self):
    with self.assertRaises(ValueError):  # no meta
        EnsureChannelFirstd("img")({"img": np.zeros((1, 2, 3)), "img_meta_dict": None})
    with self.assertRaises(ValueError):  # no meta channel
        EnsureChannelFirstd("img")({"img": np.zeros((1, 2, 3)), "img_meta_dict": {"original_channel_dim": None}})
    EnsureChannelFirstd("img", strict_check=False)({"img": np.zeros((1, 2, 3)), "img_meta_dict": None})
    EnsureChannelFirstd("img", strict_check=False)({"img": np.zeros((1, 2, 3)), "img_meta_dict": {"original_channel_dim": None}})

def test_exceptions(self):
    im = torch.zeros((1, 2, 3))
    with self.assertRaises(ValueError):  # no meta
        EnsureChannelFirstd("img")({"img": im})
    with self.assertRaises(ValueError):  # no meta channel
        EnsureChannelFirstd("img")({"img": MetaTensor(im, meta={"original_channel_dim": None})})
    EnsureChannelFirstd("img", strict_check=False)({"img": im})
    EnsureChannelFirstd("img", strict_check=False)({"img": MetaTensor(im, meta={"original_channel_dim": None})})

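# A minimal usage sketch (added for illustration, not part of the tests above):
# when the input does carry channel metadata, EnsureChannelFirstd moves that
# axis to position 0 instead of raising. Imports are spelled out so the snippet
# runs standalone.
import torch
from monai.data import MetaTensor
from monai.transforms import EnsureChannelFirstd

im = MetaTensor(torch.zeros((2, 3, 4)), meta={"original_channel_dim": -1})
out = EnsureChannelFirstd("img")({"img": im})
print(out["img"].shape)  # the last axis is moved first: (4, 2, 3)
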
def pre_transforms(self, data):
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        AddBackgroundScribblesFromROId(
            scribbles="label",
            scribbles_bg_label=self.scribbles_bg_label,
            scribbles_fg_label=self.scribbles_fg_label,
        ),
        # at the moment the optimisers are the bottleneck and take a long time,
        # therefore scale non-isotropically with big spacing
        Spacingd(keys=["image", "label"], pixdim=self.pix_dim, mode=["bilinear", "nearest"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys="image",
            a_min=self.intensity_range[0],
            a_max=self.intensity_range[1],
            b_min=self.intensity_range[2],
            b_max=self.intensity_range[3],
            clip=self.intensity_range[4],
        ),
        MakeLikelihoodFromScribblesHistogramd(
            image="image",
            scribbles="label",
            post_proc_label="prob",
            scribbles_bg_label=self.scribbles_bg_label,
            scribbles_fg_label=self.scribbles_fg_label,
            normalise=True,
        ),
    ]

def pre_transforms(self, data=None):
    t = [
        LoadImaged(keys="image", reader="ITKReader"),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys="image", axcodes="RAS"),
        ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    ]
    if self.type == InferType.DEEPEDIT:
        t.extend(
            [
                AddGuidanceFromPointsCustomd(ref_image="image", guidance="guidance", label_names=self.labels),
                Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
                ResizeGuidanceMultipleLabelCustomd(guidance="guidance", ref_image="image"),
                AddGuidanceSignalCustomd(keys="image", guidance="guidance", number_intensity_ch=self.number_intensity_ch),
            ]
        )
    else:
        t.extend(
            [
                Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
                DiscardAddGuidanced(keys="image", label_names=self.labels, number_intensity_ch=self.number_intensity_ch),
            ]
        )
    t.append(EnsureTyped(keys="image", device=data.get("device") if data else None))
    return t

def test_correct(self):
    with tempfile.TemporaryDirectory() as temp_dir:
        transforms = Compose(
            [
                LoadImaged(("im1", "im2")),
                EnsureChannelFirstd(("im1", "im2")),
                CopyItemsd(("im2", "im2_meta_dict"), names=("im3", "im3_meta_dict")),
                ResampleToMatchd("im3", "im1_meta_dict"),
                Lambda(update_fname),
                SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False),
            ]
        )
        data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
        # check that output sizes match
        assert_allclose(data["im1"].shape, data["im3"].shape)
        # and that the meta data has been updated accordingly
        assert_allclose(data["im3"].shape[1:], data["im3_meta_dict"]["spatial_shape"], type_test=False)
        assert_allclose(data["im3_meta_dict"]["affine"], data["im1_meta_dict"]["affine"])
        # check we're different from the original
        self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
        self.assertTrue(
            any(i != j for i, j in zip(data["im3_meta_dict"]["affine"].flatten(), data["im2_meta_dict"]["affine"].flatten()))
        )
        # test the inverse
        data = Invertd("im3", transforms, "im3")(data)
        assert_allclose(data["im2"].shape, data["im3"].shape)

def test_correct(self):
    transforms = Compose(
        [
            LoadImaged(("im1", "im2")),
            EnsureChannelFirstd(("im1", "im2")),
            CopyItemsd("im2", names="im3"),
            ResampleToMatchd("im3", "im1"),
            Lambda(update_fname),
            SaveImaged("im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, resample=False),
        ]
    )
    data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
    # check that output sizes match
    assert_allclose(data["im1"].shape, data["im3"].shape)
    # and that the meta data has been updated accordingly
    assert_allclose(data["im3"].affine, data["im1"].affine)
    # check we're different from the original
    self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
    self.assertTrue(any(i != j for i, j in zip(data["im3"].affine.flatten(), data["im2"].affine.flatten())))
    # test the inverse
    data = Invertd("im3", transforms)(data)
    assert_allclose(data["im2"].shape, data["im3"].shape)

def test_linear_consistent_dict(self, xform_cls, input_dict, atol):
    """xform cls testing itk consistency"""
    img = LoadImaged(keys, image_only=True, simple_keys=True)({keys[0]: FILE_PATH, keys[1]: FILE_PATH_1})
    img = EnsureChannelFirstd(keys)(img)
    ref_1 = {k: _create_itk_obj(img[k][0], img[k].affine) for k in keys}
    output = self.run_transform(img, xform_cls, input_dict)
    ref_2 = {k: _create_itk_obj(output[k][0], output[k].affine) for k in keys}
    expected = {k: _resample_to_affine(ref_1[k], ref_2[k]) for k in keys}
    # compare ref_2 and expected results from itk
    diff = {k: np.abs(itk.GetArrayFromImage(ref_2[k]) - itk.GetArrayFromImage(expected[k])) for k in keys}
    avg_diff = {k: np.mean(diff[k]) for k in keys}
    for k in keys:
        self.assertTrue(avg_diff[k] < atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")

def val_pre_transforms(self, context: Context): return [ LoadImaged(keys=("image", "label"), reader="ITKReader"), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), EnsureChannelFirstd(keys=("image", "label")), Orientationd(keys=["image", "label"], axcodes="RAS"), # This transform may not work well for MR images ScaleIntensityRanged(keys=("image"), a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), Resized(keys=("image", "label"), spatial_size=self.spatial_size, mode=("area", "nearest")), # Transforms for click simulation FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"), AddInitialSeedPointMissingLabelsd(keys="label", guidance="guidance", sids="sids"), AddGuidanceSignalCustomd( keys="image", guidance="guidance", number_intensity_ch=self.number_intensity_ch), # ToTensord(keys=("image", "label")), SelectItemsd(keys=("image", "label", "guidance", "label_names")), ]
def run_inference_test(root_dir, device="cuda:0"):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = SaveImage(
            output_dir=os.path.join(root_dir, "output"),
            dtype=np.float32,
            output_ext=".nii.gz",
            output_postfix="seg",
            mode="bilinear",
        )
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            val_meta = decollate_batch(val_data[PostFix.meta("img")])
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)
            for img, meta in zip(val_outputs, val_meta):
                # save a decollated batch of files
                saver(img, meta)
    return dice_metric.aggregate().item()

def test_load_png(self):
    spatial_size = (256, 256, 3)
    test_image = np.random.randint(0, 256, size=spatial_size)
    with tempfile.TemporaryDirectory() as tempdir:
        filename = os.path.join(tempdir, "test_image.png")
        Image.fromarray(test_image.astype("uint8")).save(filename)
        result = LoadImaged(keys="img")({"img": filename})
        result = EnsureChannelFirstd(keys="img")(result)
        self.assertEqual(result["img"].shape[0], 3)

def get_image_transforms():
    itk_reader = monai.data.ITKReader()
    # Define transforms for image
    image_transforms = Compose(
        [
            LoadImaged(keys=["img"], reader=itk_reader),
            EnsureChannelFirstd(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            ToTensord(keys=["img"]),
        ]
    )
    return image_transforms

def pre_transforms(self, data=None) -> Sequence[Callable]:
    return [
        LoadImaged(keys="image", reader="ITKReader"),
        EnsureChannelFirstd(keys="image"),
        Spacingd(keys="image", pixdim=self.target_spacing),
        ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        EnsureTyped(keys="image"),
    ]

def test_load_nifti(self, input_param, filenames, original_channel_dim):
    if original_channel_dim is None:
        test_image = np.random.rand(128, 128, 128)
    elif original_channel_dim == -1:
        test_image = np.random.rand(128, 128, 128, 1)
    with tempfile.TemporaryDirectory() as tempdir:
        for i, name in enumerate(filenames):
            filenames[i] = os.path.join(tempdir, name)
            nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
        result = LoadImaged(**input_param)({"img": filenames})
        result = EnsureChannelFirstd(**input_param)(result)
        self.assertEqual(result["img"].shape[0], len(filenames))

def val_pre_transforms(self, context: Context):
    return [
        LoadImaged(keys=("image", "label"), reader="ITKReader"),
        EnsureChannelFirstd(keys=("image", "label")),
        Spacingd(keys=("image", "label"), pixdim=self.target_spacing, mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        EnsureTyped(keys=("image", "label")),
        SelectItemsd(keys=("image", "label")),
    ]

def test_inverse(self):
    loader = Compose([LoadImaged(("im1", "im2")), EnsureChannelFirstd(("im1", "im2"))])
    data = loader({"im1": self.fnames[0], "im2": self.fnames[1]})
    tr = ResampleToMatch()
    im_mod = tr(data["im2"], data["im1"])
    self.assertNotEqual(im_mod.shape, data["im2"].shape)
    self.assertGreater(((im_mod.affine - data["im2"].affine) ** 2).sum() ** 0.5, 1e-2)
    # inverse
    im_mod2 = tr.inverse(im_mod)
    self.assertEqual(im_mod2.shape, data["im2"].shape)
    self.assertLess(((im_mod2.affine - data["im2"].affine) ** 2).sum() ** 0.5, 1e-2)
    self.assertEqual(im_mod2.applied_operations, [])

def test_correct(self, reader, writer):
    loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))])
    data = loader({"im1": self.fnames[0], "im2": self.fnames[1]})
    with self.assertRaises(ValueError):
        ResampleToMatch(mode=None)(img=data["im2"], img_dst=data["im1"])
    im_mod = ResampleToMatch()(data["im2"], data["im1"])
    saver = SaveImaged(
        "im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, writer=writer, resample=False
    )
    im_mod.meta["filename_or_obj"] = get_rand_fname()
    saver({"im3": im_mod})
    saved = nib.load(os.path.join(self.tmpdir, im_mod.meta["filename_or_obj"]))
    assert_allclose(data["im1"].shape[1:], saved.shape)
    assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19]))

def train_pre_transforms(self, context: Context):
    return [
        LoadImaged(keys=("image", "label"), reader="ITKReader"),
        NormalizeLabelsInDatasetd(keys="label", label_names=self._labels),  # specifically for missing labels
        EnsureChannelFirstd(keys=("image", "label")),
        Spacingd(keys=("image", "label"), pixdim=self.target_spacing, mode=("bilinear", "nearest")),
        CropForegroundd(keys=("image", "label"), source_key="image"),
        SpatialPadd(keys=("image", "label"), spatial_size=self.spatial_size),
        ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        RandCropByPosNegLabeld(
            keys=("image", "label"),
            label_key="label",
            spatial_size=self.spatial_size,
            pos=1,
            neg=1,
            num_samples=self.num_samples,
            image_key="image",
            image_threshold=0,
        ),
        EnsureTyped(keys=("image", "label"), device=context.device),
        RandFlipd(keys=("image", "label"), spatial_axis=[0], prob=0.10),
        RandFlipd(keys=("image", "label"), spatial_axis=[1], prob=0.10),
        RandFlipd(keys=("image", "label"), spatial_axis=[2], prob=0.10),
        RandRotate90d(keys=("image", "label"), prob=0.10, max_k=3),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        SelectItemsd(keys=("image", "label")),
    ]

def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num, num_samples):
    if mode != "test":
        keys = ["image", "label"]
    else:
        keys = ["image"]
    # 1. loading
    load_transforms = [LoadImaged(keys=keys), EnsureChannelFirstd(keys=keys)]
    # 2. sampling
    sample_transforms = [
        PreprocessAnisotropic(
            keys=keys,
            clip_values=clip_values[task_id],
            pixdim=spacing[task_id],
            normalize_values=normalize_values[task_id],
            model_mode=mode,
        ),
    ]
    # 3. spatial transforms
    if mode == "train":
        other_transforms = [
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size[task_id]),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=patch_size[task_id],
                pos=pos_sample_num,
                neg=neg_sample_num,
                num_samples=num_samples,
                image_key="image",
                image_threshold=0,
            ),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("trilinear", "nearest"),
                align_corners=(True, None),
                prob=0.15,
            ),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[2], prob=0.5),
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    elif mode == "validation":
        other_transforms = [
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    else:
        other_transforms = [
            CastToTyped(keys=["image"], dtype=np.float32),
            EnsureTyped(keys=["image"]),
        ]
    all_transforms = load_transforms + sample_transforms + other_transforms
    return Compose(all_transforms)

def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
    transform = Compose(
        [
            LoadImaged(KEYS, image_only=True),
            EnsureChannelFirstd(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, prob=0, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]),
        ]
    )
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]
    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
    dataset = Dataset(data, transform=transform)
    transform.inverse(dataset[0])
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
    inverter = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted", "label_inverted"],
        transform=transform,
        orig_keys=["label", "label"],
        nearest_interp=True,
        device="cpu",
    )
    inverter_1 = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted1", "label_inverted1"],
        transform=transform,
        orig_keys=["image", "image"],
        nearest_interp=[True, False],
        device="cpu",
    )
    expected_keys = ["image", "image_inverted", "image_inverted1", "label", "label_inverted", "label_inverted1"]
    # execute 1 epoch
    for d in loader:
        d = decollate_batch(d)
        for item in d:
            item = inverter(item)
            item = inverter_1(item)
            self.assertListEqual(sorted(item), expected_keys)
            self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
            self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
            # check the nearest interpolation mode
            i = item["image_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            i = item["label_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            # check the case that different items use different interpolation modes to invert transforms
            d = item["image_inverted1"]
            # if the interpolation mode is nearest, accumulated diff should be smaller than 1
            self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))
            d = item["label_inverted1"]
            # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
            self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))
    # check labels match
    reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
    original = LoadImaged(KEYS, image_only=True)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = item["label_inverted"].meta["filename_or_obj"]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    # 25300: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1821: windows torch 1.10.0
    self.assertTrue((reverted.size - n_good) < 40000, f"diff. {reverted.size - n_good}")
    set_determinism(seed=None)

def configure(self):
    self.set_device()
    network = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(self.device)
    if self.multi_gpu:
        network = DistributedDataParallel(
            module=network,
            device_ids=[self.device],
            find_unused_parameters=False,
        )

    train_transforms = Compose(
        [
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            Spacingd(keys=("image", "label"), pixdim=[1.0, 1.0, 1.0], mode=["bilinear", "nearest"]),
            ScaleIntensityRanged(keys="image", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            RandCropByPosNegLabeld(
                keys=("image", "label"),
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=("image", "label")),
        ]
    )
    train_datalist = load_decathlon_datalist(self.data_list_file_path, True, "training")
    if self.multi_gpu:
        train_datalist = partition_dataset(
            data=train_datalist,
            shuffle=True,
            num_partitions=dist.get_world_size(),
            even_divisible=True,
        )[dist.get_rank()]
    train_ds = CacheDataset(
        data=train_datalist,
        transform=train_transforms,
        cache_num=32,
        cache_rate=1.0,
        num_workers=4,
    )
    train_data_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    val_transforms = Compose(
        [
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            ScaleIntensityRanged(keys="image", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            ToTensord(keys=("image", "label")),
        ]
    )
    val_datalist = load_decathlon_datalist(self.data_list_file_path, True, "validation")
    val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
    val_data_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

    post_transform = Compose(
        [
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(keys=["pred", "label"], argmax=[True, False], to_onehot=True, n_classes=2),
        ]
    )
    # metric
    key_val_metric = {
        "val_mean_dice": MeanDice(
            include_background=False,
            output_transform=lambda x: (x["pred"], x["label"]),
            device=self.device,
        )
    }
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir=self.ckpt_dir, save_dict={"model": network}, save_key_metric=True),
        TensorBoardStatsHandler(log_dir=self.ckpt_dir, output_transform=lambda x: None),
    ]
    self.eval_engine = SupervisedEvaluator(
        device=self.device,
        val_data_loader=val_data_loader,
        network=network,
        inferer=SlidingWindowInferer(roi_size=[160, 160, 160], sw_batch_size=4, overlap=0.5),
        post_transform=post_transform,
        key_val_metric=key_val_metric,
        val_handlers=val_handlers,
        amp=self.amp,
    )

    optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=self.eval_engine, interval=self.val_interval, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=self.ckpt_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]),
    ]
    self.train_engine = SupervisedTrainer(
        device=self.device,
        max_epochs=self.max_epochs,
        train_data_loader=train_data_loader,
        network=network,
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=self.amp,
    )

    if self.local_rank > 0:
        self.train_engine.logger.setLevel(logging.WARNING)
        self.eval_engine.logger.setLevel(logging.WARNING)

def test_train_timing(self):
    images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:32], segs[:32])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-9:], segs[-9:])]
    device = torch.device("cuda:0")
    # define transforms for train and validation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from the big image based on pos / neg ratio;
            # the image centers of negative samples must be in the valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"], prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ]
    )

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch

    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)
    # the Novograd paper suggests using a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f} step time: {(time.time() - step_start):.4f}")
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_data["label"])]
                    dice_metric(y_pred=val_outputs, y=val_labels)
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                print(f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}")
        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

    total_time = time.time() - total_start
    print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}")
    # test expected metrics
    self.assertGreater(best_metric, 0.95)

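# The AMP recipe used in test_train_timing above, reduced to its core as a
# hedged sketch (the toy model and data are hypothetical, and a CUDA device is
# assumed): run the forward pass under autocast, scale the loss before backward
# so small fp16 gradients are not flushed to zero, then step/update the scaler.
import torch

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
inputs = torch.randn(4, 8, device="cuda")
targets = torch.randint(0, 2, (4,), device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales gradients, then optimizer.step()
scaler.update()                # adjusts the scale factor for the next step
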
def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, None)):
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[0]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[1]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            with eval_mode(model):
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    # decollate prediction into a list and execute post processing for every item
                    val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
                    # compute metrics
                    dice_metric(y_pred=val_outputs, y=val_labels)
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch + 1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch

out_dir = "./outputs_fast" train_images = sorted(glob.glob(os.path.join(data_root, "imagesTr", "*.nii.gz"))) train_labels = sorted(glob.glob(os.path.join(data_root, "labelsTr", "*.nii.gz"))) data_dicts = [ {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ] train_files, val_files = data_dicts[:-9], data_dicts[-9:] set_determinism(seed=0) train_transforms = Compose( [ Range("LoadImage")(LoadImaged(keys=["image", "label"])), Range()(EnsureChannelFirstd(keys=["image", "label"])), Range("Spacing")( Spacingd( keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest"), ) ), Range()(Orientationd(keys=["image", "label"], axcodes="RAS")), Range()( ScaleIntensityRanged( keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0,
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose(
        [
            LoadImaged(keys="img"),
            EnsureChannelFirstd(keys="img"),
            Orientationd(keys="img", axcodes="RAS"),
            Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
            ScaleIntensityd(keys="img"),
            ToTensord(keys="img"),
        ]
    )
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            Invertd(keys="pred", transform=pre_transforms, loader=dataloader, orig_keys="img", nearest_interp=True),
            SaveImaged(keys="pred_inverted", output_dir="./output", output_postfix="seg", resample=False),
        ]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    net.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))
    net.eval()

    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for windows inference
            d["pred"] = sliding_window_inference(inputs=images, roi_size=(96, 96, 96), sw_batch_size=4, predictor=net)
            # execute post transforms to invert spatial transforms and save to NIfTI files
            post_transforms(d)

def compute(args):
    # generate synthetic data for the example
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random pred, label pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # if there are multiple nodes, set the random seed to generate the same random data on every node
        np.random.seed(seed=0)
        for i in range(16):
            pred, label = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1, noise_max=0.5)
            n = nib.Nifti1Image(pred, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"pred{i:d}.nii.gz"))
            n = nib.Nifti1Image(label, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"label{i:d}.nii.gz"))

    # initialize the distributed evaluation process, change to NCCL backend if computing on GPU
    dist.init_process_group(backend="gloo", init_method="env://")

    preds = sorted(glob(os.path.join(args.dir, "pred*.nii.gz")))
    labels = sorted(glob(os.path.join(args.dir, "label*.nii.gz")))
    datalist = [{"pred": pred, "label": label} for pred, label in zip(preds, labels)]

    # split data for every subprocess, for example, 16 processes compute in parallel
    data_part = partition_dataset(
        data=datalist,
        num_partitions=dist.get_world_size(),
        shuffle=False,
        even_divisible=False,
    )[dist.get_rank()]

    # define transforms for predictions and labels
    transforms = Compose(
        [
            LoadImaged(keys=["pred", "label"]),
            EnsureChannelFirstd(keys=["pred", "label"]),
            ScaleIntensityd(keys="pred"),
            EnsureTyped(keys=["pred", "label"]),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    data_part = [transforms(item) for item in data_part]

    # compute metrics for the current process
    metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    metric(y_pred=[i["pred"] for i in data_part], y=[i["label"] for i in data_part])
    filenames = [item["pred_meta_dict"]["filename_or_obj"] for item in data_part]
    # all-gather results from all the processes and reduce for the final result
    result = metric.aggregate().item()
    filenames = string_list_all_gather(strings=filenames)

    if args.local_rank == 0:
        print("mean dice: ", result)
        # generate metrics reports at: output/mean_dice_raw.csv, output/mean_dice_summary.csv, output/metrics.csv
        write_metrics_reports(
            save_dir="./output",
            images=filenames,
            metrics={"mean_dice": result},
            metric_details={"mean_dice": metric.get_buffer()},
            summary_ops="*",
        )

    metric.reset()
    dist.destroy_process_group()

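# Single-process sketch of the metric accumulation pattern in compute() above
# (random tensors stand in for the synthetic NIfTI data): DiceMetric buffers
# the per-call results until aggregate() reduces them, and reset() clears them.
import torch
from monai.metrics import DiceMetric

metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
y_pred = torch.randint(0, 2, (2, 1, 16, 16, 16)).float()  # binary predictions
y = torch.randint(0, 2, (2, 1, 16, 16, 16)).float()       # binary ground truth
metric(y_pred=y_pred, y=y)
print("mean dice:", metric.aggregate().item())
metric.reset()
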
def main():
    parser = argparse.ArgumentParser(description="training")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint full path",
    )
    parser.add_argument(
        "--factor_ram_cost",
        default=0.0,
        type=float,
        help="factor to determine RAM cost in the searched architecture",
    )
    parser.add_argument(
        "--fold",
        action="store",
        required=True,
        help="fold index in N-fold cross-validation",
    )
    parser.add_argument(
        "--json",
        action="store",
        required=True,
        help="full path of .json file",
    )
    parser.add_argument(
        "--json_key",
        action="store",
        required=True,
        help="selected key in .json data list",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        help="local process rank",
    )
    parser.add_argument(
        "--num_folds",
        action="store",
        required=True,
        help="number of folds in cross-validation",
    )
    parser.add_argument(
        "--output_root",
        action="store",
        required=True,
        help="output root",
    )
    parser.add_argument(
        "--root",
        action="store",
        required=True,
        help="data root",
    )
    args = parser.parse_args()

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not os.path.exists(args.output_root):
        os.makedirs(args.output_root, exist_ok=True)

    amp = True
    determ = True
    factor_ram_cost = args.factor_ram_cost
    fold = int(args.fold)
    input_channels = 1
    learning_rate = 0.025
    learning_rate_arch = 0.001
    learning_rate_milestones = np.array([0.4, 0.8])
    num_images_per_batch = 1
    num_epochs = 1430  # around 20k iterations
    num_epochs_per_validation = 100
    num_epochs_warmup = 715
    num_folds = int(args.num_folds)
    num_patches_per_image = 1
    num_sw_batch_size = 6
    output_classes = 3
    overlap_ratio = 0.625
    patch_size = (96, 96, 96)
    patch_size_valid = (96, 96, 96)
    spacing = [1.0, 1.0, 1.0]

    print("factor_ram_cost", factor_ram_cost)

    # deterministic training
    if determ:
        set_determinism(seed=0)

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    # dist.barrier()
    world_size = dist.get_world_size()

    with open(args.json, "r") as f:
        json_data = json.load(f)

    split = len(json_data[args.json_key]) // num_folds
    list_train = json_data[args.json_key][: (split * fold)] + json_data[args.json_key][(split * (fold + 1)) :]
    list_valid = json_data[args.json_key][(split * fold) : (split * (fold + 1))]

    # training data
    files = []
    for _i in range(len(list_train)):
        str_img = os.path.join(args.root, list_train[_i]["image"])
        str_seg = os.path.join(args.root, list_train[_i]["label"])
        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue
        files.append({"image": str_img, "label": str_seg})
    train_files = files

    random.shuffle(train_files)

    train_files_w = train_files[: len(train_files) // 2]
    train_files_w = partition_dataset(
        data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True
    )[dist.get_rank()]
    print("train_files_w:", len(train_files_w))

    train_files_a = train_files[len(train_files) // 2 :]
    train_files_a = partition_dataset(
        data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True
    )[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    # validation data
    files = []
    for _i in range(len(list_valid)):
        str_img = os.path.join(args.root, list_valid[_i]["image"])
        str_seg = os.path.join(args.root, list_valid[_i]["label"])
        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue
        files.append({"image": str_img, "label": str_seg})
    val_files = files
    val_files = partition_dataset(
        data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False
    )[dist.get_rank()]
    print("val_files:", len(val_files))

    # network architecture
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, True)),
            CastToTyped(keys=["image"], dtype=torch.float32),
            ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True),
            CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)),
            CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
            Lambdad(
                keys=["label4crop"],
                func=lambda x: np.concatenate(
                    tuple(
                        [
                            ndimage.binary_dilation((x == _k).astype(x.dtype), iterations=48).astype(x.dtype)
                            for _k in range(output_classes)
                        ]
                    ),
                    axis=0,
                ),
                overwrite=True,
            ),
            EnsureTyped(keys=["image", "label"]),
            CastToTyped(keys=["image"], dtype=torch.float32),
            SpatialPadd(
                keys=["image", "label", "label4crop"], spatial_size=patch_size, mode=["reflect", "constant", "constant"]
            ),
            RandCropByLabelClassesd(
                keys=["image", "label"],
                label_key="label4crop",
                num_classes=output_classes,
                ratios=[1] * output_classes,
                spatial_size=patch_size,
                num_samples=num_patches_per_image,
            ),
            Lambdad(keys=["label4crop"], func=lambda x: 0),
            RandRotated(
                keys=["image", "label"], range_x=0.3, range_y=0.3, range_z=0.3, mode=["bilinear", "nearest"], prob=0.2
            ),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.8,
                max_zoom=1.2,
                mode=["trilinear", "nearest"],
                align_corners=[True, None],
                prob=0.16,
            ),
            RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.15),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5),
            RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
            RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5),
            RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5),
            CastToTyped(keys=["image", "label"], dtype=(torch.float32, torch.uint8)),
            ToTensord(keys=["image", "label"]),
        ]
    )

    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, True)),
            CastToTyped(keys=["image"], dtype=torch.float32),
            ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True),
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    train_ds_a = monai.data.CacheDataset(data=train_files_a, transform=train_transforms, cache_rate=1.0, num_workers=8)
    train_ds_w = monai.data.CacheDataset(data=train_files_w, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2)

    # monai.data.Dataset can be used as an alternative when debugging or when RAM is limited.
    # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms)
    # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    train_loader_a = ThreadDataLoader(train_ds_a, num_workers=0, batch_size=num_images_per_batch, shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w, num_workers=0, batch_size=num_images_per_batch, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)

    # DataLoader can be used as an alternative when ThreadDataLoader is less efficient.
    # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())

    dints_space = monai.networks.nets.TopologySearch(
        channel_mul=0.5,
        num_blocks=12,
        num_depths=4,
        use_downsample=True,
        device=device,
    )
    model = monai.networks.nets.DiNTS(
        dints_space=dints_space,
        in_channels=input_channels,
        num_classes=output_classes,
        use_downsample=True,
    )
    model = model.to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=output_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)])

    # loss function
    loss_func = monai.losses.DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        squared_pred=True,
        batch=True,
        smooth_nr=0.00001,
        smooth_dr=0.00001,
    )

    # optimizer
    optimizer = torch.optim.SGD(
        model.weight_parameters(), lr=learning_rate * world_size, momentum=0.9, weight_decay=0.00004
    )
    arch_optimizer_a = torch.optim.Adam(
        [dints_space.log_alpha_a], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
    )
    arch_optimizer_c = torch.optim.Adam(
        [dints_space.log_alpha_c], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
    )

    print()

    if torch.cuda.device_count() > 1:
        if dist.get_rank() == 0:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

    if args.checkpoint is not None and os.path.isfile(args.checkpoint):
        print("[info] fine-tuning pre-trained checkpoint {0:s}".format(args.checkpoint))
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        torch.cuda.empty_cache()
    else:
        print("[info] training from scratch")

    # amp
    if amp:
        from torch.cuda.amp import autocast, GradScaler

        scaler = GradScaler()
        if dist.get_rank() == 0:
            print("[info] amp enabled")

    # start a typical PyTorch training
    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    idx_iter = 0
    metric_values = list()

    if dist.get_rank() == 0:
        writer = SummaryWriter(log_dir=os.path.join(args.output_root, "Events"))
        with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
        decay = 0.5 ** np.sum(
            [(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones]
        )
        lr = learning_rate * decay
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                    else:
                        loss = loss_func(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

            if epoch < num_epochs_warmup:
                continue

            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = sample_a["image"].to(device), sample_a["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # linearly increase topology and RAM loss
            entropy_alpha_c = torch.tensor(0.0).to(device)
            entropy_alpha_a = torch.tensor(0.0).to(device)
            ram_cost_full = torch.tensor(0.0).to(device)
            ram_cost_usage = torch.tensor(0.0).to(device)
            ram_cost_loss = torch.tensor(0.0).to(device)
            topology_loss = torch.tensor(0.0).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(
                F.softmax(dints_space.log_alpha_c, dim=-1) * F.log_softmax(dints_space.log_alpha_c, dim=-1)
            ).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape, full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(factor_ram_cost - ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)

            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)
                    loss += combination_weights * (
                        (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
                    )
                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)
                loss += 1.0 * (
                    combination_weights * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
                )
                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}"
                )
                writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step)

        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )
            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]
                    value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)

                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals), axis=0)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        val1 = 1.0 - torch.isnan(value[0, 0]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # synchronizes all processes and reduce results
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric = metric.tolist()
                if dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print("evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                    node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode()
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(args.output_root, "search_code_" + str(idx_iter) + ".pth"),
                    )
                    print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(args.output_root, "progress.yaml"), "w") as out_file:
                        documents = yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, avg_metric, best_metric, best_metric_epoch
                        )
                    )

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(
                                epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter
                            )
                        )

            dist.barrier()

        torch.cuda.empty_cache()

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    if dist.get_rank() == 0:
        writer.close()

    dist.destroy_process_group()

    return

def main_worker(args): # disable logging for processes except 0 on every node if args.local_rank != 0: f = open(os.devnull, "w") sys.stdout = sys.stderr = f if not os.path.exists(args.dir): raise FileNotFoundError(f"missing directory {args.dir}") # initialize the distributed training process, every GPU runs in a process dist.init_process_group(backend="nccl", init_method="env://") device = torch.device(f"cuda:{args.local_rank}") torch.cuda.set_device(device) # use amp to accelerate training scaler = torch.cuda.amp.GradScaler() torch.backends.cudnn.benchmark = True total_start = time.time() train_transforms = Compose([ # load 4 Nifti images and stack them together LoadImaged(keys=["image", "label"]), EnsureChannelFirstd(keys="image"), ConvertToMultiChannelBasedOnBratsClassesd(keys="label"), Orientationd(keys=["image", "label"], axcodes="RAS"), Spacingd( keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"), ), EnsureTyped(keys=["image", "label"]), ToDeviced(keys=["image", "label"], device=device), RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False), RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0), RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1), RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2), NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), RandScaleIntensityd(keys="image", factors=0.1, prob=0.5), RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5), ]) # create a training data loader train_ds = BratsCacheDataset( root_dir=args.dir, transform=train_transforms, section="training", num_workers=4, cache_rate=args.cache_rate, shuffle=True, ) # ThreadDataLoader can be faster if no IO operations when caching all the data in memory train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.batch_size, shuffle=True) # validation transforms and dataset val_transforms = Compose([ LoadImaged(keys=["image", "label"]), EnsureChannelFirstd(keys="image"), ConvertToMultiChannelBasedOnBratsClassesd(keys="label"), Orientationd(keys=["image", "label"], axcodes="RAS"), Spacingd( keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest"), ), NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), EnsureTyped(keys=["image", "label"]), ToDeviced(keys=["image", "label"], device=device), ]) val_ds = BratsCacheDataset( root_dir=args.dir, transform=val_transforms, section="validation", num_workers=4, cache_rate=args.cache_rate, shuffle=False, ) # ThreadDataLoader can be faster if no IO operations when caching all the data in memory val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=args.batch_size, shuffle=False) # create network, loss function and optimizer if args.network == "SegResNet": model = SegResNet( blocks_down=[1, 2, 2, 4], blocks_up=[1, 1, 1], init_filters=16, in_channels=4, out_channels=3, dropout_prob=0.0, ).to(device) else: model = UNet( spatial_dims=3, in_channels=4, out_channels=3, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ).to(device) loss_function = DiceFocalLoss( smooth_nr=1e-5, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True, batch=True, ) optimizer = Novograd(model.parameters(), lr=args.lr) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs) # wrap the model with DistributedDataParallel module model = DistributedDataParallel(model, device_ids=[device]) dice_metric = DiceMetric(include_background=True, reduction="mean") 
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# start a typical PyTorch training loop
best_metric = -1
best_metric_epoch = -1
print(f"time elapsed before training: {time.time() - total_start}")
train_start = time.time()
for epoch in range(args.epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{args.epochs}")
    epoch_loss = train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % args.val_interval == 0:
        metric, metric_tc, metric_wt, metric_et = evaluate(model, val_loader, dice_metric, dice_metric_batch, post_trans)

        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            if dist.get_rank() == 0:
                torch.save(model.state_dict(), "best_metric_model.pth")
        print(
            f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
            f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
            f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
    print(f"time elapsed for epoch {epoch + 1}: {(time.time() - epoch_start):.4f}")

print(
    f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
    f" total train time: {(time.time() - train_start):.4f}"
)
dist.destroy_process_group()
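# The `train` and `evaluate` helpers invoked above are defined elsewhere in the
# script. As a rough orientation, a minimal AMP training epoch with GradScaler
# could look like the sketch below; the argument names mirror the call site, but
# the body is an illustrative assumption, not the tutorial's exact implementation.

def train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler):
    model.train()
    epoch_loss, step = 0.0, 0
    for batch_data in train_loader:
        step += 1
        # data is already on the GPU here thanks to ToDeviced in the transforms
        inputs, labels = batch_data["image"], batch_data["label"]
        optimizer.zero_grad(set_to_none=True)
        # run forward/loss under autocast, then scale gradients for stable fp16 steps
        with torch.cuda.amp.autocast():
            loss = loss_function(model(inputs), labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    lr_scheduler.step()
    return epoch_loss / max(step, 1)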
def main():
    # TODO: the file paths & output directory are hard-coded here and should be adapted
    json_Path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_Root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    if not os.path.exists(logdir_path):
        os.mkdir(logdir_path)

    # load the JSON split and prepend the root path to every image entry
    with open(json_Path, 'r') as json_f:
        json_Data = json.load(json_f)

    train_Data = json_Data['training']
    val_Data = json_Data['validation']

    for idx, _ in enumerate(train_Data):
        train_Data[idx]['image'] = os.path.join(data_Root, train_Data[idx]['image'])

    for idx, _ in enumerate(val_Data):
        val_Data[idx]['image'] = os.path.join(data_Root, val_Data[idx]['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_Data)))
    print(train_Data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_Data)))
    print(val_Data)
    print('#' * 10)

    # set determinism for reproducibility
    set_determinism(seed=123)

    # define training transforms
    train_Transforms = Compose(
        [
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode="bilinear"),
            ScaleIntensityRanged(
                keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image"], source_key="image"),
            SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
            RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), random_size=False, num_samples=2),
            CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=5,
                                   dropout_holes=True, max_spatial_size=32),
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=20,
                                   dropout_holes=False, max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
            # please note that if `image` and `image_2` were corrupted in the same transform
            # call, determinism would augment them identically, which is not desired here;
            # hence the second, separate set of calls below
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=5,
                                   dropout_holes=True, max_spatial_size=32),
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=20,
                                   dropout_holes=False, max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8),
        ]
    )

    check_ds = Dataset(data=train_Data, transform=train_Transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = check_data["image"][0][0]
    print(f"image shape: {image.shape}")

    # define the network (ViT autoencoder backbone)
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
        in_channels=1,
        img_size=(96, 96, 96),
        patch_size=(16, 16, 16),
        pos_embed='conv',
        hidden_size=768,
        mlp_dim=3072,
    )
    model = model.to(device)

    # define hyper-parameters for the training loop
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    best_val_loss = 1000.0

    # define losses & optimizer
    recon_loss = L1Loss()
    contrastive_loss = ContrastiveLoss(batch_size=batch_size * 2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # define the DataLoaders using MONAI; CacheDataset could be substituted to speed up loading
    train_ds = Dataset(data=train_Data, transform=train_Transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    val_ds = Dataset(data=val_Data, transform=train_Transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # scale the contrastive loss by the reconstruction loss
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # store the contrastive & reconstruction loss values
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time - start_time}s")

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch + 1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            with torch.no_grad():
                for val_batch in val_loader:
                    val_step += 1
                    start_time = time.time()
                    inputs, gt_input = (
                        val_batch["image"].to(device),
                        val_batch["gt_image"].to(device),
                    )
                    print('Input shape: {}'.format(inputs.shape))
                    outputs, _ = model(inputs)  # the second return value is the hidden states
                    val_loss = recon_loss(outputs, gt_input)
                    total_val_loss += val_loss.item()
                    end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, "
                  f"time taken: {end_time - start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                checkpoint = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None
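# Why flatten before the contrastive term: ContrastiveLoss compares one embedding
# vector per sample across the two corrupted views, so the (C, H, W, D) output is
# collapsed to a single dimension first. A small self-contained shape check,
# assuming a MONAI version where ContrastiveLoss still accepts a `batch_size`
# argument (newer releases infer it); the tensor sizes here are illustrative.

import torch
from monai.losses import ContrastiveLoss

per_view_batch = 4
out_v1 = torch.rand(per_view_batch, 1, 96, 96, 96)  # stand-ins for the two view outputs
out_v2 = torch.rand(per_view_batch, 1, 96, 96, 96)

flat_v1 = out_v1.flatten(start_dim=1, end_dim=4)  # -> (4, 96*96*96)
flat_v2 = out_v2.flatten(start_dim=1, end_dim=4)

# `batch_size` must match the number of samples in each view per forward pass,
# which is why the script above uses batch_size * 2 (two random crops per image)
loss = ContrastiveLoss(batch_size=per_view_batch, temperature=0.05)(flat_v1, flat_v2)
print(flat_v1.shape, loss.item())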
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose([
        LoadImaged(keys="img"),
        EnsureChannelFirstd(keys="img"),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys="img"),
    ])
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose([
        EnsureTyped(keys="pred"),
        Activationsd(keys="pred", sigmoid=True),
        Invertd(
            # invert the `pred` data field; multiple fields are also supported
            keys="pred",
            # use the pre_transforms information previously applied to the `img` field
            # to invert `pred`; the same info can serve multiple fields, and different
            # orig_keys may be given for different fields
            transform=pre_transforms,
            orig_keys="img",
            # key field to save the inverted meta data, every item maps to `keys`
            meta_keys="pred_meta_dict",
            # get the meta data from the `img_meta_dict` field when inverting, e.g. the
            # `affine` may be needed to invert the `Spacingd` transform;
            # multiple fields can use the same meta data to invert
            orig_meta_keys="img_meta_dict",
            # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key;
            # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}";
            # otherwise this arg is not needed during inverting
            meta_key_postfix="meta_dict",
            # don't change the interpolation mode to "nearest" when inverting transforms,
            # to ensure a smooth output; the `AsDiscreted` transform is executed afterwards
            nearest_interp=False,
            # convert to PyTorch Tensor after inverting
            to_tensor=True,
        ),
        AsDiscreted(keys="pred", threshold=0.5),
        SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    net.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))
    net.eval()

    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for windows inference
            d["pred"] = sliding_window_inference(inputs=images, roi_size=(96, 96, 96), sw_batch_size=4, predictor=net)
            # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
            d = [post_transforms(i) for i in decollate_batch(d)]
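# After decollating and post-processing, each dictionary in `d` holds an inverted,
# thresholded prediction at the original resolution (SaveImaged has already written
# it to ./out). If the arrays are also needed in memory, `from_engine` can pull
# them out of the decollated dicts; this snippet is a sketch that would sit inside
# the loop above and assumes a MONAI version providing monai.handlers.utils.from_engine.

from monai.handlers.utils import from_engine

preds = from_engine(["pred"])(d)  # list with one prediction tensor per sample
for p in preds:
    print(p.shape)  # back at the original spatial size thanks to Invertd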