def test_correct_results(self, degrees, spatial_axes, reshape, order, mode, cval, prefilter):
    rotate_fn = RandRotated(
        "img",
        degrees,
        prob=1.0,
        spatial_axes=spatial_axes,
        reshape=reshape,
        order=order,
        mode=mode,
        cval=cval,
        prefilter=prefilter,
    )
    rotate_fn.set_random_state(243)
    rotated = rotate_fn({"img": self.imt[0]})
    angle = rotate_fn.angle
    expected = []
    for channel in self.imt[0]:
        expected.append(
            scipy.ndimage.rotate(
                channel, angle, spatial_axes, reshape, order=order, mode=mode, cval=cval, prefilter=prefilter
            )
        )
    expected = np.stack(expected).astype(np.float32)
    self.assertTrue(np.allclose(expected, rotated["img"]))
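# The test above exercises an early RandRotated signature whose
# degrees/spatial_axes/order/cval arguments mirror scipy.ndimage.rotate.
# Below is a minimal sketch of the later dictionary-transform API used by
# the remaining snippets (assuming a recent MONAI release); the array
# shape and seed are illustrative only.
import numpy as np
from monai.transforms import RandRotated

rotate = RandRotated(keys="img", range_x=np.pi / 4, prob=1.0, keep_size=True)
rotate.set_random_state(243)
out = rotate({"img": np.random.rand(1, 32, 32).astype(np.float32)})
print(out["img"].shape)  # keep_size=True preserves (1, 32, 32)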
def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners):
    rotate_fn = RandRotated(
        "img",
        range_x=degrees,
        prob=1.0,
        keep_size=keep_size,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
    )
    rotate_fn.set_random_state(243)
    rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]})

    _order = 0 if mode == "nearest" else 1
    if padding_mode == "border":
        _mode = "nearest"
    elif padding_mode == "reflection":
        _mode = "reflect"
    else:
        _mode = "constant"
    angle = rotate_fn.x
    expected = scipy.ndimage.rotate(
        self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False
    )
    expected = np.stack(expected).astype(np.float32)
    self.assertTrue(np.allclose(expected, rotated["img"][0]))
def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):
    rotate_fn = RandRotated(
        "img",
        range_x=degrees,
        prob=1.0,
        keep_size=keep_size,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        dtype=np.float64,
    )
    im = im_type(self.imt[0])
    rotate_fn.set_random_state(243)
    rotated = rotate_fn({"img": im, "seg": im_type(self.segn[0])})

    _order = 0 if mode == "nearest" else 1
    if padding_mode == "border":
        _mode = "nearest"
    elif padding_mode == "reflection":
        _mode = "reflect"
    else:
        _mode = "constant"
    angle = rotate_fn.rand_rotate.x
    expected = scipy.ndimage.rotate(
        self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False
    )
    test_local_inversion(rotate_fn, rotated, {"img": im}, "img")
    for k, v in rotated.items():
        rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v
    expected = np.stack(expected).astype(np.float32)
    good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3))
    self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels")
def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected):
    rotate_fn = RandRotated(
        "img",
        range_x=x,
        range_y=y,
        range_z=z,
        prob=1.0,
        keep_size=keep_size,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        dtype=np.float64,
    )
    rotate_fn.set_random_state(243)
    rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])})
    np.testing.assert_allclose(rotated["img"].shape, expected)
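# A small standalone sketch of the keep_size behaviour tested above
# (assuming the current MONAI API; sizes and the fixed seed are
# illustrative): keep_size=False lets the output grow to hold the whole
# rotated field of view, keep_size=True crops/pads back to the input shape.
import numpy as np
from monai.transforms import RandRotated

data = {"img": np.zeros((1, 20, 30), dtype=np.float32)}
for keep_size in (True, False):
    fn = RandRotated("img", range_x=np.pi / 2, prob=1.0, keep_size=keep_size)
    fn.set_random_state(0)
    print(keep_size, fn(dict(data))["img"].shape)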
def pre_transforms(self):
    t = [
        LoadImaged(keys="image", reader="nibabelreader"),
        AddChanneld(keys="image"),
        # Spacing might not be needed as a resize transform is used later.
        # Spacingd(keys="image", pixdim=self.spacing),
        RandAffined(
            keys="image",
            prob=1,
            rotate_range=(np.pi / 4, np.pi / 4, np.pi / 4),
            padding_mode="zeros",
            as_tensor_output=False,
        ),
        RandFlipd(keys="image", prob=0.5, spatial_axis=0),
        RandRotated(keys="image", range_x=(-5, 5), range_y=(-5, 5), range_z=(-5, 5)),
        Resized(keys="image", spatial_size=self.spatial_size),
    ]
    # if using TTA for deepedit
    if self.deepedit:
        t.append(DiscardAddGuidanced(keys="image"))
    t.append(ToTensord(keys="image"))
    return Compose(t)
TESTS.append(("Zoomd 1d", "1D odd", 0, True, Zoomd(KEYS, zoom=2, keep_size=False)))
TESTS.append(("Zoomd 2d", "2D", 2e-1, True, Zoomd(KEYS, zoom=0.9)))
TESTS.append(("Zoomd 3d", "3D", 3e-2, True, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False)))
TESTS.append(("RandZoom 3d", "3D", 9e-2, True, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True)))
TESTS.append(("RandRotated, prob 0", "2D", 0, True, RandRotated(KEYS, prob=0, dtype=np.float64)))
TESTS.append(
    (
        "Rotated 2d",
        "2D",
        8e-2,
        True,
        Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False, dtype=np.float64),
    )
)
TESTS.append(
    (
        "Rotated 3d",
def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)]
    transform = Compose(
        [
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test to support both Tensor and Numpy array when inverting
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ]
    )
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num_workers = 0 for macOS or GPU transforms
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2
    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
    inverter = Invertd(
        keys=["image", "label"],
        transform=transform,
        loader=loader,
        orig_keys="label",
        nearest_interp=True,
        postfix="inverted",
        to_tensor=[True, False],
        device="cpu",
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    )

    # execute 1 epoch
    for d in loader:
        d = inverter(d)
        # this unit test only covers the basic function, test_handler_transform_inverter covers more
        self.assertTupleEqual(d["image"].shape[1:], (1, 100, 100, 100))
        self.assertTupleEqual(d["label"].shape[1:], (1, 100, 100, 100))
        # check the nearest interpolation mode
        for i in d["image_inverted"]:
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
        for i in d["label_inverted"]:
            np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    set_determinism(seed=None)
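# Why the uint8 round-trip check above works -- a sketch (MONAI API,
# illustrative sizes): nearest interpolation keeps a binary mask
# integer-valued under rotation, while bilinear introduces fractional
# values, so casting through uint8 is lossless only in the nearest case.
import numpy as np
from monai.transforms import Rotate

mask = np.zeros((1, 16, 16), dtype=np.float32)
mask[0, 4:12, 4:12] = 1.0
near = Rotate(angle=np.pi / 6, mode="nearest")(mask)
soft = Rotate(angle=np.pi / 6, mode="bilinear")(mask)
print(np.unique(np.asarray(near)))       # only 0.0 and 1.0
print(np.unique(np.asarray(soft)).size)  # many fractional values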
from monai.utils import set_determinism

TESTS: List[Tuple] = []

for pad_collate in [
    lambda x: pad_list_data_collate(batch=x, method="end", mode="constant"),
    PadListDataCollate(method="end", mode="constant"),
]:
    TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
    TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False, dtype=np.float64)))
    TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
    TESTS.append(
        (
            dict,
            pad_collate,
            Compose([RandRotate90d("image", prob=1, max_k=3), RandRotate90d("image", prob=1, max_k=4)]),
        )
    )
    TESTS.append(
    RandRotate90d,
    RandRotated,
    RandSpatialCrop,
    RandSpatialCropd,
    RandZoom,
    RandZoomd,
)
from monai.utils import set_determinism

TESTS: List[Tuple] = []

TESTS.append((dict, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
TESTS.append((dict, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((dict, RandRotate90d("image", prob=1, max_k=2)))
TESTS.append((list, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
TESTS.append((list, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((list, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((list, RandRotate90(prob=1, max_k=2)))
def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)]
    transform = Compose(
        [
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test to support both Tensor and Numpy array when inverting
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ]
    )
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num_workers = 0 for macOS or GPU transforms
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2
    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

    # set up engine
    def _train_func(engine, batch):
        self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
        engine.state.output = batch
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output

    engine = Engine(_train_func)
    engine.register_events(*IterationEvents)

    # set up testing handler
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="label",
        nearest_interp=True,
        postfix="inverted1",
        to_tensor=[True, False],
        device="cpu",
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    ).attach(engine)

    # test different nearest interpolation values
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="image",
        nearest_interp=[True, False],
        post_func=[lambda x: x + 10, lambda x: x],
        postfix="inverted2",
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    ).attach(engine)

    engine.run(loader, max_epochs=1)
    set_determinism(seed=None)
    self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100))
    self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100))
    # check the nearest interpolation mode
    for i in engine.state.output["image_inverted1"]:
        torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
    for i in engine.state.output["label_inverted1"]:
        np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
    # check labels match
    reverted = engine.state.output["label_inverted1"][-1].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = engine.state.output["label_meta_dict"]["filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    # 25300: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1824: torch 1.5.1
    self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824), "diff in 3 possible values")
    # check the case where different items use different interpolation modes to invert transforms
    for i in engine.state.output["image_inverted2"]:
        # if the interpolation mode is nearest, the accumulated diff should be smaller than 1
        self.assertLess(torch.sum(i.to(torch.float) - i.to(torch.uint8).to(torch.float)).item(), 1.0)
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
    for i in engine.state.output["label_inverted2"]:
        # if the interpolation mode is not nearest, the accumulated diff should be greater than 10000
        self.assertGreater(torch.sum(i.to(torch.float) - i.to(torch.uint8).to(torch.float)).item(), 10000.0)
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
_, has_nib = optional_import("nibabel")

KEYS = ["image", "label"]

TESTS_3D = [
    (
        t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"),
        t,
        collate_fn,
        3,
    )
    for collate_fn in [None, pad_list_data_collate]
    for t in [
        RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=KEYS, prob=0.5),
        Compose([RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), ToTensord(keys=KEYS)]),
        RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(keys=KEYS, prob=0.5, range_x=np.pi),
        RandAffined(
            keys=KEYS,
            prob=0.5,
            rotate_range=np.pi,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        ),
    ]
]

TESTS_2D = [
    (
        t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"),
        t,
        collate_fn,
        2,
    )
    for collate_fn in [None, pad_list_data_collate]
    for t in [
        RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]),
        RandAxisFlipd(keys=KEYS, prob=0.5),
        Compose([
            RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)),
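# A condensed sketch of what pad_list_data_collate does for the cases
# above (shapes are illustrative): samples whose spatial sizes diverge,
# e.g. after RandRotated with keep_size=False, are padded to a common
# size so the default stacking can proceed.
import numpy as np
from torch.utils.data import DataLoader
from monai.data import pad_list_data_collate

samples = [
    {"image": np.zeros((1, 10, 10), dtype=np.float32)},
    {"image": np.zeros((1, 12, 8), dtype=np.float32)},
]
loader = DataLoader(samples, batch_size=2, collate_fn=pad_list_data_collate)
print(next(iter(loader))["image"].shape)  # padded to (2, 1, 12, 10)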
def test_train_timing(self):
    images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:32], segs[:32])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-9:], segs[-9:])]

    device = torch.device("cuda:0")
    # define transforms for train and validation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indices
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache it to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from the big image based on pos / neg ratio;
            # the image centers of negative samples must be in the valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"], prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
        ]
    )

    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache it to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ]
    )

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch

    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)

    # the Novograd paper suggests using a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}, step time: {(time.time() - step_start):.4f}")
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_data["label"])]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                print(f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}")
        print(f"time elapsed for epoch {epoch + 1}: {(time.time() - epoch_start):.4f}")
    total_time = time.time() - total_start
    print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}")
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
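# The training loop above follows the standard torch.cuda.amp pattern;
# a condensed sketch of just that pattern (toy model and data, a CUDA
# device is assumed to be available):
import torch

model = torch.nn.Linear(4, 2).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
x, y = torch.randn(8, 4, device="cuda"), torch.randn(8, 2, device="cuda")
opt.zero_grad()
with torch.cuda.amp.autocast():  # forward pass in mixed precision
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()    # scale the loss to avoid fp16 underflow
scaler.step(opt)                 # unscale gradients, then step
scaler.update()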
"RandLambdad 3d", "3D", 5e-2, RandLambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True, prob=0.5), ) ) TESTS.append(("Zoomd 1d", "1D odd", 0, Zoomd(KEYS, zoom=2, keep_size=False))) TESTS.append(("Zoomd 2d", "2D", 2e-1, Zoomd(KEYS, zoom=0.9))) TESTS.append(("Zoomd 3d", "3D", 3e-2, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False))) TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) TESTS.append(("RandRotated, prob 0", "2D", 0, RandRotated(KEYS, prob=0))) TESTS.append( ("Rotated 2d", "2D", 8e-2, Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False)) ) TESTS.append( ( "Rotated 3d", "3D", 1e-1, Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore ) ) TESTS.append(
def main():
    """
    Read input and configuration parameters
    """
    parser = argparse.ArgumentParser(description='Run basic UNet with MONAI.')
    parser.add_argument('--config', dest='config', metavar='config', type=str, help='config file')
    args = parser.parse_args()

    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print the parameter setup to the log
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # training and validation params
    loss_type = config_info['training']['loss_type']
    batch_size_train = config_info['training']['batch_size_train']
    batch_size_valid = config_info['training']['batch_size_valid']
    lr = float(config_info['training']['lr'])
    nr_train_epochs = config_info['training']['nr_train_epochs']
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    sliding_window_validation = config_info['training']['sliding_window_validation']
    # data params
    data_root = config_info['data']['data_root']
    training_list = config_info['data']['training_list']
    validation_list = config_info['data']['validation_list']
    # model saving
    out_model_dir = os.path.join(config_info['output']['out_model_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['output_subfix'])
    print("Saving to directory ", out_model_dir)
    max_nr_models_saved = config_info['output']['max_nr_models_saved']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    torch.cuda.set_device(cuda_device)

    """
    Data Preparation
    """
    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')
    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to the right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
        RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2, spatial_axes=[0, 1], interp_order=[1, 0],
                    reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True,
                              num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    check_train_data = monai.utils.misc.first(train_loader)
    print("Training data tensor shapes")
    print(check_train_data['img'].shape, check_train_data['seg'].shape)

    # data preprocessing for validation:
    # - convert data to the right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    if sliding_window_validation:
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from the validation set to emulate how the loss is computed at training
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)
    check_valid_data = monai.utils.misc.first(val_loader)
    print("Validation data tensor shapes")
    print(check_valid_data['img'].shape, check_valid_data['seg'].shape)

    """
    Network preparation
    """
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()

    """
    Training loop
    """
    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    net.to(device)
    for epoch in range(nr_train_epochs):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, nr_train_epochs))
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device)
            opt.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            opt.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print("%d/%d, train_loss:%0.4f" % (step, epoch_len, loss.item()))
            writer_train.add_scalar('loss', loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

        if (epoch + 1) % validation_every_n_epochs == 0:
            net.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                check_tot_validation = 0
                for val_data in val_loader:
                    check_tot_validation += 1
                    val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
                    if sliding_window_validation:
                        print('Running sliding window validation')
                        roi_size = (96, 96, 1)
                        val_outputs = sliding_window_inference(val_images, roi_size, batch_size_valid, net)
                        value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                                 to_onehot_y=False, add_sigmoid=True)
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    else:
                        print('Running 2D validation')
                        # compute validation
                        val_outputs = net(val_images)
                        value = 1.0 - loss_function(val_outputs, val_labels)
                        metric_count += 1
                        metric_sum += value.item()
                print("Total number of data in validation: %d" % check_tot_validation)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), os.path.join(out_model_dir, 'best_metric_model.pth'))
                    print('saved new best metric model')
                print("current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                      % (epoch + 1, metric, best_metric, best_metric_epoch))
                epoch_len = len(train_ds) // train_loader.batch_size
                writer_valid.add_scalar('loss', 1.0 - metric, epoch_len * epoch + step)
                writer_valid.add_scalar('val_mean_dice', metric, epoch + 1)
                # plot the last model output as a GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer_valid, index=0, tag='image')
                plot_2d_or_3d_image(val_labels, epoch + 1, writer_valid, index=0, tag='label')
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer_valid, index=0, tag='output')

    print('train completed, best_metric: %0.4f at epoch: %d' % (best_metric, best_metric_epoch))
    writer_train.close()
    writer_valid.close()
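# A minimal sketch of the sliding-window validation call used above
# (network and sizes are illustrative placeholders; this assumes a MONAI
# release where UNet takes spatial_dims rather than the older dimensions):
import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet

net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16), strides=(2,), num_res_units=1)
vol = torch.rand(1, 1, 96, 96)
out = sliding_window_inference(vol, roi_size=(64, 64), sw_batch_size=4, predictor=net)
print(out.shape)  # the output spatial size matches the input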
def main():
    """
    Basic UNet as implemented in MONAI for Fetal Brain Segmentation,
    but using ignite to manage the training and validation loops and checkpointing.
    :return:
    """

    """
    Read input and configuration parameters
    """
    parser = argparse.ArgumentParser(description='Run basic UNet with MONAI - Ignite version.')
    parser.add_argument('--config', dest='config', metavar='config', type=str, help='config file')
    args = parser.parse_args()

    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print the parameter setup to the log
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # training and validation params
    loss_type = config_info['training']['loss_type']
    batch_size_train = config_info['training']['batch_size_train']
    batch_size_valid = config_info['training']['batch_size_valid']
    lr = float(config_info['training']['lr'])
    lr_decay = config_info['training']['lr_decay']
    if lr_decay is not None:
        lr_decay = float(lr_decay)
    nr_train_epochs = config_info['training']['nr_train_epochs']
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    sliding_window_validation = config_info['training']['sliding_window_validation']
    if 'model_to_load' in config_info['training'].keys():
        model_to_load = config_info['training']['model_to_load']
        if not os.path.exists(model_to_load):
            raise BlockingIOError("cannot find model: {}".format(model_to_load))
    else:
        model_to_load = None
    if 'manual_seed' in config_info['training'].keys():
        seed = config_info['training']['manual_seed']
    else:
        seed = None
    # data params
    data_root = config_info['data']['data_root']
    training_list = config_info['data']['training_list']
    validation_list = config_info['data']['validation_list']
    # model saving
    out_model_dir = os.path.join(config_info['output']['out_model_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['output_subfix'])
    print("Saving to directory ", out_model_dir)
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    max_nr_models_saved = config_info['output']['max_nr_models_saved']
    val_image_to_tensorboad = config_info['output']['val_image_to_tensorboad']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    torch.cuda.set_device(cuda_device)
    if seed is not None:
        # set manual seed if required (both numpy and torch)
        set_determinism(seed=seed)
        # # set torch-only seed
        # torch.manual_seed(seed)
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False

    """
    Data Preparation
    """
    # create a cache directory to store results for the Persistent Dataset
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')
    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to the right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
        RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2, spatial_axes=[0, 1], interp_order=[1, 0],
                    reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])

    # create a training data loader
    # train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0,
    #                                    num_workers=num_workers)
    train_ds = monai.data.PersistentDataset(data=train_files, transform=train_transforms,
                                            cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True,
                              num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    # check_train_data = monai.utils.misc.first(train_loader)
    # print("Training data tensor shapes")
    # print(check_train_data['img'].shape, check_train_data['seg'].shape)

    # data preprocessing for validation:
    # - convert data to the right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    if sliding_window_validation:
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from the validation set to emulate how the loss is computed at training
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate

    # create a validation data loader
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0,
    #                                  num_workers=num_workers)
    val_ds = monai.data.PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)
    # check_valid_data = monai.utils.misc.first(val_loader)
    # print("Validation data tensor shapes")
    # print(check_valid_data['img'].shape, check_valid_data['seg'].shape)

    """
    Network preparation
    """
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()
    if lr_decay is not None:
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=lr_decay, last_epoch=-1)

    """
    Set ignite trainer
    """
    # function to manage batches at training
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch['img'], batch['seg']), device, non_blocking)

    trainer = create_supervised_trainer(model=net, optimizer=opt, loss_fn=loss_function, device=device,
                                        non_blocking=False, prepare_batch=prepare_batch)

    # add a checkpoint handler to save models (network params and optimizer stats) during training
    if model_to_load is not None:
        checkpoint_handler = CheckpointLoader(load_path=model_to_load, load_dict={'net': net, 'opt': opt})
        checkpoint_handler.attach(trainer)
        state = trainer.state_dict()
    else:
        checkpoint_handler = ModelCheckpoint(out_model_dir, 'net', n_saved=max_nr_models_saved, require_empty=False)
        # trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=save_params)
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'net': net, 'opt': opt})

    # StatsHandler prints loss at every iteration and prints metrics at every epoch
    train_stats_handler = StatsHandler(name='trainer')
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_tensorboard_stats_handler = TensorBoardStatsHandler(summary_writer=writer_train)
    train_tensorboard_stats_handler.attach(trainer)

    if lr_decay is not None:
        print("Using Exponential LR decay")
        lr_schedule_handler = LrScheduleHandler(lr_scheduler, print_lr=True, name="lr_scheduler", writer=writer_train)
        lr_schedule_handler.attach(trainer)

    """
    Set ignite evaluator to perform validation at training
    """
    # set parameters for validation
    metric_name = 'Mean_Dice'
    # add evaluation metrics to the evaluator engine
    val_metrics = {
        "Loss": 1.0 - MeanDice(add_sigmoid=True, to_onehot_y=False),
        "Mean_Dice": MeanDice(add_sigmoid=True, to_onehot_y=False)
    }

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch['img'].to(device), batch['seg'].to(device)
            roi_size = (96, 96, 1)
            seg_probs = sliding_window_inference(val_images, roi_size, batch_size_valid, net)
            return seg_probs, val_labels

    if sliding_window_validation:
        # use sliding window inference at validation
        print("3D evaluator is used")
        net.to(device)
        evaluator = Engine(_sliding_window_processor)
        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)
    else:
        # the ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration;
        # the user can add output_transform to return other values
        print("2D evaluator is used")
        evaluator = create_supervised_evaluator(model=net, metrics=val_metrics, device=device,
                                                non_blocking=True, prepare_batch=prepare_batch)

    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add an early stopping handler to the evaluator
    # early_stopper = EarlyStopping(patience=4,
    #                               score_function=stopping_fn_from_metric(metric_name),
    #                               trainer=trainer)
    # evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # add a stats event handler to print validation stats via the evaluator
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None,  # no need to print loss value, so disable per-iteration output
        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from the trainer
    val_stats_handler.attach(evaluator)

    # add a handler to record metrics to TensorBoard at every validation epoch
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_valid,
        output_transform=lambda x: None,  # no need to plot loss value, so disable per-iteration output
        global_epoch_transform=lambda x: trainer.state.iteration)  # fetch global iteration number from the trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add a handler to draw the first image and the corresponding label and model output in the last batch;
    # here we draw the 3D output as a GIF along the depth axis, at every validation iteration
    if val_image_to_tensorboad:
        val_tensorboard_image_handler = TensorBoardImageHandler(
            summary_writer=writer_valid,
            batch_transform=lambda batch: (batch['img'], batch['seg']),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch)
        evaluator.add_event_handler(event_name=Events.ITERATION_COMPLETED(every=1),
                                    handler=val_tensorboard_image_handler)

    """
    Run training
    """
    state = trainer.run(train_loader, nr_train_epochs)
    print("Done!")
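# A toy sketch of the ignite pattern used above: a filtered event runs
# the evaluator every N trainer iterations (engines and data here are
# placeholders; N is illustrative).
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: batch)
evaluator = Engine(lambda engine, batch: batch)

@trainer.on(Events.ITERATION_COMPLETED(every=10))
def run_validation(engine):
    evaluator.run([[0]], max_epochs=1)

trainer.run([[0]] * 20, max_epochs=1)  # validation fires at iterations 10 and 20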
def main():
    parser = argparse.ArgumentParser(description="training")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint full path",
    )
    parser.add_argument(
        "--factor_ram_cost",
        default=0.0,
        type=float,
        help="factor to determine RAM cost in the searched architecture",
    )
    parser.add_argument(
        "--fold",
        action="store",
        required=True,
        help="fold index in N-fold cross-validation",
    )
    parser.add_argument(
        "--json",
        action="store",
        required=True,
        help="full path of .json file",
    )
    parser.add_argument(
        "--json_key",
        action="store",
        required=True,
        help="selected key in .json data list",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        help="local process rank",
    )
    parser.add_argument(
        "--num_folds",
        action="store",
        required=True,
        help="number of folds in cross-validation",
    )
    parser.add_argument(
        "--output_root",
        action="store",
        required=True,
        help="output root",
    )
    parser.add_argument(
        "--root",
        action="store",
        required=True,
        help="data root",
    )
    args = parser.parse_args()

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not os.path.exists(args.output_root):
        os.makedirs(args.output_root, exist_ok=True)

    amp = True
    determ = True
    factor_ram_cost = args.factor_ram_cost
    fold = int(args.fold)
    input_channels = 1
    learning_rate = 0.025
    learning_rate_arch = 0.001
    learning_rate_milestones = np.array([0.4, 0.8])
    num_images_per_batch = 1
    num_epochs = 1430  # around 20k iterations
    num_epochs_per_validation = 100
    num_epochs_warmup = 715
    num_folds = int(args.num_folds)
    num_patches_per_image = 1
    num_sw_batch_size = 6
    output_classes = 3
    overlap_ratio = 0.625
    patch_size = (96, 96, 96)
    patch_size_valid = (96, 96, 96)
    spacing = [1.0, 1.0, 1.0]

    print("factor_ram_cost", factor_ram_cost)

    # deterministic training
    if determ:
        set_determinism(seed=0)

    # initialize the distributed training process; every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    # dist.barrier()
    world_size = dist.get_world_size()

    with open(args.json, "r") as f:
        json_data = json.load(f)

    split = len(json_data[args.json_key]) // num_folds
    list_train = json_data[args.json_key][:(split * fold)] + json_data[args.json_key][(split * (fold + 1)):]
    list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))]

    # training data
    files = []
    for _i in range(len(list_train)):
        str_img = os.path.join(args.root, list_train[_i]["image"])
        str_seg = os.path.join(args.root, list_train[_i]["label"])
        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue
        files.append({"image": str_img, "label": str_seg})
    train_files = files

    random.shuffle(train_files)

    train_files_w = train_files[:len(train_files) // 2]
    train_files_w = partition_dataset(data=train_files_w, shuffle=True, num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_w:", len(train_files_w))
    train_files_a = train_files[len(train_files) // 2:]
    train_files_a = partition_dataset(data=train_files_a, shuffle=True, num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    # validation data
    files = []
    for _i in range(len(list_valid)):
        str_img = os.path.join(args.root, list_valid[_i]["image"])
        str_seg = os.path.join(args.root, list_valid[_i]["label"])
        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue
        files.append({"image": str_img, "label": str_seg})
    val_files = files
    val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size,
                                  even_divisible=False)[dist.get_rank()]
    print("val_files:", len(val_files))

    # network architecture
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)),
        CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
        Lambdad(
            keys=["label4crop"],
            func=lambda x: np.concatenate(
                tuple(
                    ndimage.binary_dilation((x == _k).astype(x.dtype), iterations=48).astype(x.dtype)
                    for _k in range(output_classes)
                ),
                axis=0,
            ),
            overwrite=True,
        ),
        EnsureTyped(keys=["image", "label"]),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        SpatialPadd(keys=["image", "label", "label4crop"], spatial_size=patch_size,
                    mode=["reflect", "constant", "constant"]),
        RandCropByLabelClassesd(keys=["image", "label"], label_key="label4crop", num_classes=output_classes,
                                ratios=[1] * output_classes, spatial_size=patch_size,
                                num_samples=num_patches_per_image),
        Lambdad(keys=["label4crop"], func=lambda x: 0),
        RandRotated(keys=["image", "label"], range_x=0.3, range_y=0.3, range_z=0.3,
                    mode=["bilinear", "nearest"], prob=0.2),
        RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2, mode=["trilinear", "nearest"],
                  align_corners=[True, None], prob=0.16),
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15),
                            prob=0.15),
        RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
        RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5),
        CastToTyped(keys=["image", "label"], dtype=(torch.float32, torch.uint8)),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    train_ds_a = monai.data.CacheDataset(data=train_files_a, transform=train_transforms,
                                         cache_rate=1.0, num_workers=8)
    train_ds_w = monai.data.CacheDataset(data=train_files_w, transform=train_transforms,
                                         cache_rate=1.0, num_workers=8)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2)

    # monai.data.Dataset can be used as an alternative when debugging or when RAM is limited.
    # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms)
    # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    train_loader_a = ThreadDataLoader(train_ds_a, num_workers=0, batch_size=num_images_per_batch, shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w, num_workers=0, batch_size=num_images_per_batch, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)

    # DataLoader can be used as an alternative when ThreadDataLoader is less efficient.
    # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True,
    #                             num_workers=2, pin_memory=torch.cuda.is_available())
    # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True,
    #                             num_workers=2, pin_memory=torch.cuda.is_available())
    # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2,
    #                         pin_memory=torch.cuda.is_available())

    dints_space = monai.networks.nets.TopologySearch(
        channel_mul=0.5,
        num_blocks=12,
        num_depths=4,
        use_downsample=True,
        device=device,
    )

    model = monai.networks.nets.DiNTS(
        dints_space=dints_space,
        in_channels=input_channels,
        num_classes=output_classes,
        use_downsample=True,
    )

    model = model.to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=output_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)])

    # loss function
    loss_func = monai.losses.DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        squared_pred=True,
        batch=True,
        smooth_nr=0.00001,
        smooth_dr=0.00001,
    )

    # optimizers
    optimizer = torch.optim.SGD(model.weight_parameters(), lr=learning_rate * world_size,
                                momentum=0.9, weight_decay=0.00004)
    arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a], lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999), weight_decay=0.0)
    arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c], lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999), weight_decay=0.0)

    print()
    if torch.cuda.device_count() > 1:
        if dist.get_rank() == 0:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

    if args.checkpoint is not None and os.path.isfile(args.checkpoint):
        print("[info] fine-tuning pre-trained checkpoint {0:s}".format(args.checkpoint))
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        torch.cuda.empty_cache()
    else:
        print("[info] training from scratch")

    # amp
    if amp:
        from torch.cuda.amp import autocast, GradScaler
        scaler = GradScaler()
        if dist.get_rank() == 0:
            print("[info] amp enabled")

    # start a typical PyTorch training
    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    idx_iter = 0
    metric_values = list()

    if dist.get_rank() == 0:
        writer = SummaryWriter(log_dir=os.path.join(args.output_root, "Events"))
        with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
        decay = 0.5 ** np.sum(
            [(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones])
        lr = learning_rate * decay
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                    else:
                        loss = loss_func(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

            if epoch < num_epochs_warmup:
                continue

            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = sample_a["image"].to(device), sample_a["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # linearly increase topology and RAM loss
            entropy_alpha_c = torch.tensor(0.).to(device)
            entropy_alpha_a = torch.tensor(0.).to(device)
            ram_cost_full = torch.tensor(0.).to(device)
            ram_cost_usage = torch.tensor(0.).to(device)
            ram_cost_loss = torch.tensor(0.).to(device)
            topology_loss = torch.tensor(0.).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1)
                                * F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape, full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(factor_ram_cost - ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)

            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)
                    loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c)
                                                   + ram_cost_loss + 0.001 * topology_loss)
                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)
                loss += 1.0 * (combination_weights * (entropy_alpha_a + entropy_alpha_c)
                               + ram_cost_loss + 0.001 * topology_loss)
                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19])
                      + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}")
                writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step)

        # synchronize all processes and reduce the results
        dist.barrier()
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, "
                f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )
            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, "
                    f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

                    value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)
                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals), axis=0)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        val1 = 1.0 - torch.isnan(value[0, 0]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # synchronize all processes and reduce the results
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric = metric.tolist()
                if dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print("evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                    node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode()
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(args.output_root, "search_code_" + str(idx_iter) + ".pth"),
                    )
                    print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(args.output_root, "progress.yaml"), "w") as out_file:
                        documents = yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, avg_metric, best_metric, best_metric_epoch
                        )
                    )

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(
                                epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter
                            )
                        )

                dist.barrier()

            torch.cuda.empty_cache()

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    if dist.get_rank() == 0:
        writer.close()

    dist.destroy_process_group()

    return
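# The search loop above alternates two optimizers: one batch updates the
# network weights with the architecture parameters frozen, the next
# updates the architecture parameters with the weights frozen. A toy
# sketch of that alternation (all tensors and losses are placeholders):
import torch

w = torch.nn.Parameter(torch.zeros(3))      # stands in for network weights
alpha = torch.nn.Parameter(torch.zeros(3))  # stands in for architecture parameters
opt_w = torch.optim.SGD([w], lr=0.1)
opt_a = torch.optim.Adam([alpha], lr=0.01)

for step in range(4):
    # weight update: freeze the architecture parameters
    alpha.requires_grad_(False)
    w.requires_grad_(True)
    opt_w.zero_grad()
    (w + alpha).pow(2).sum().backward()
    opt_w.step()
    # architecture update: freeze the weights
    w.requires_grad_(False)
    alpha.requires_grad_(True)
    opt_a.zero_grad()
    (w - alpha).pow(2).sum().backward()
    opt_a.step()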
    RandRotated,
    RandSpatialCrop,
    RandSpatialCropd,
    RandZoom,
    RandZoomd,
)
from monai.utils import set_determinism

TESTS: List[Tuple] = []
for pad_collate in [pad_list_data_collate, PadListDataCollate()]:
    TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
    TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)))
    TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
    TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2)))
    TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
    TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
    TESTS.append((list, pad_collate,
def test_invert(self): set_determinism(seed=0) im_fname, seg_fname = ( make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)) transform = Compose([ LoadImaged(KEYS), AddChanneld(KEYS), Orientationd(KEYS, "RPS"), Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32), ScaleIntensityd("image", minv=1, maxv=10), RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(KEYS, prob=0.5), RandRotate90d(KEYS, spatial_axes=(1, 2)), RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64), RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"), ResizeWithPadOrCropd(KEYS, 100), # test EnsureTensor for complicated dict data and invert it CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"), # test to support Tensor, Numpy array and dictionary when inverting EnsureTyped(keys=["image", "test_dict"]), ToTensord("image"), CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]), CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]), CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]), ]) data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] # num workers = 0 for mac or gpu transforms num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available( ) else 2 dataset = CacheDataset(data, transform=transform, progress=False) loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) inverter = Invertd( # `image` was not copied, invert the original value directly keys=["image_inverted", "label_inverted", "test_dict"], transform=transform, orig_keys=["label", "label", "test_dict"], meta_keys=[ PostFix.meta("image_inverted"), PostFix.meta("label_inverted"), None ], orig_meta_keys=[ PostFix.meta("label"), PostFix.meta("label"), None ], nearest_interp=True, to_tensor=[True, False, False], device="cpu", ) inverter_1 = Invertd( # `image` was not copied, invert the original value directly keys=["image_inverted1", "label_inverted1"], transform=transform, orig_keys=["image", "image"], meta_keys=[ PostFix.meta("image_inverted1"), PostFix.meta("label_inverted1") ], orig_meta_keys=[PostFix.meta("image"), PostFix.meta("image")], nearest_interp=[True, False], to_tensor=[True, True], device="cpu", ) expected_keys = [ "image", "image_inverted", "image_inverted1", PostFix.meta("image_inverted1"), PostFix.meta("image_inverted"), PostFix.meta("image"), "image_transforms", "label", "label_inverted", "label_inverted1", PostFix.meta("label_inverted1"), PostFix.meta("label_inverted"), PostFix.meta("label"), "label_transforms", "test_dict", "test_dict_transforms", ] # execute 1 epoch for d in loader: d = decollate_batch(d) for item in d: item = inverter(item) item = inverter_1(item) self.assertListEqual(sorted(item), expected_keys) self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100)) self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100)) # check the nearest interpolation mode i = item["image_inverted"] torch.testing.assert_allclose( i.to(torch.uint8).to(torch.float), i.to(torch.float)) self.assertTupleEqual(i.shape[1:], (100, 101, 107)) i = item["label_inverted"] torch.testing.assert_allclose( i.to(torch.uint8).to(torch.float), i.to(torch.float)) self.assertTupleEqual(i.shape[1:], (100, 101, 107)) # test inverted test_dict self.assertTrue( isinstance(item["test_dict"]["affine"], np.ndarray)) self.assertTrue( isinstance(item["test_dict"]["filename_or_obj"], str)) 
                # check the case that different items use different interpolation modes to invert transforms
                d = item["image_inverted1"]
                # if the interpolation mode is nearest, the accumulated diff should be smaller than 1
                self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

                d = item["label_inverted1"]
                # if the interpolation mode is not nearest, the accumulated diff should be greater than 10000
                self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

        # check labels match
        reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 34007: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1821: windows torch 1.10.0
        self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821), f"diff. {reverted.size - n_good}")

        set_determinism(seed=None)
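# Outside a unit test, the invert workflow exercised above reduces to: apply a
# traceable pipeline, decollate the batch, and run each item through Invertd.
# A minimal sketch, assuming a hypothetical single-image dataset ("img.nii.gz"
# and the short transform chain below are illustrative, not from the test):
from monai.data import DataLoader, Dataset, decollate_batch
from monai.transforms import AddChanneld, Compose, Invertd, LoadImaged, Spacingd

xform = Compose([
    LoadImaged(keys="image"),
    AddChanneld(keys="image"),
    Spacingd(keys="image", pixdim=(1.2, 1.01, 0.9), mode="bilinear"),
])
inverter = Invertd(keys="image", transform=xform, orig_keys="image", nearest_interp=True)
ds = Dataset(data=[{"image": "img.nii.gz"}], transform=xform)
for batch in DataLoader(ds, batch_size=1):
    for item in decollate_batch(batch):  # Invertd operates on per-item dicts
        inverted = inverter(item)  # "image" is resampled back onto its original grid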
    has_nib = True
else:
    _, has_nib = optional_import("nibabel")

KEYS = ["image", "label"]

TESTS = [
    (t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn)
    for collate_fn in [None, pad_list_data_collate]
    for t in [
        RandFlipd(keys=KEYS, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=KEYS),
        RandRotate90d(keys=KEYS, spatial_axes=(1, 2)),
        RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(keys=KEYS, range_x=np.pi),
        RandAffined(keys=KEYS, rotate_range=np.pi),
    ]
]


class TestInverseCollation(unittest.TestCase):
    """Test collation of random transformations with prob == 0 and 1."""

    def setUp(self):
        if not has_nib:
            self.skipTest("nibabel required for test_inverse")
        set_determinism(seed=0)
        im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)]
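# These cases pair random spatial transforms with pad_list_data_collate because
# transforms such as RandZoomd(..., keep_size=False) can leave each sample with
# a different spatial size, which the default collate cannot stack. A minimal
# usage sketch (the array shapes and the "image" key are illustrative):
import numpy as np
from monai.data import DataLoader, Dataset, pad_list_data_collate
from monai.transforms import RandZoomd

data = [{"image": np.zeros((1, 10, 10), dtype=np.float32)} for _ in range(4)]
ds = Dataset(data=data, transform=RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))
# pad_list_data_collate pads every sample up to the largest shape in the batch
loader = DataLoader(ds, batch_size=4, collate_fn=pad_list_data_collate)
batch = next(iter(loader))  # batch["image"] is one stacked, padded tensor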
"Zoomd 3d", "3D", 3e-2, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), )) TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) TESTS.append(( "RandRotated, prob 0", "2D", 0, RandRotated(KEYS, prob=0), )) TESTS.append(( "Rotated 2d", "2D", 8e-2, Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), )) TESTS.append(( "Rotated 3d", "3D",
_, has_nib = optional_import("nibabel")

KEYS = ["image", "label"]

TESTS_3D = [
    (
        t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"),
        t,
        collate_fn,
        3,
    )
    for collate_fn in [None, pad_list_data_collate]
    for t in [
        RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=KEYS, prob=0.5),
        Compose([RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), ToTensord(keys=KEYS)]),
        RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64),
        RandAffined(
            keys=KEYS,
            prob=0.5,
            rotate_range=np.pi,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        ),
    ]
]

TESTS_2D = [
    (
        t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"),
        t,
        collate_fn,
        2,
    )
    for collate_fn in [None, pad_list_data_collate]
    for t in [
        RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]),
        RandAxisFlipd(keys=KEYS, prob=0.5),
        Compose([
            RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)),
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:

    * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
    * Load Data: Load training and validation data into PyTorch DataLoaders
    * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
    * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
        during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
        on the 3D volume is used at evaluation. The mean 3D Dice is used as the validation metric.
    * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
    * Run training: The MONAI trainer is run, performing training and validation during training.

    Args:
        train_file_list: .txt or .csv file (with no header) storing two-column filenames for training:
            the image filename in the first column and the segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-column filenames for validation:
            the image filename in the first column and the segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
""" """ Read input and configuration parameters """ # print MONAI config information logging.basicConfig(stream=sys.stdout, level=logging.INFO) print_config() # print to log the parameter setups print(yaml.dump(config_info)) # extract network parameters, perform checks/set defaults if not present and print them to log if 'seg_labels' in config_info['training'].keys(): seg_labels = config_info['training']['seg_labels'] else: seg_labels = [1] nr_out_channels = len(seg_labels) print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels)) patch_size = config_info["training"]["inplane_size"] + [1] print("Considering patch size = {}".format(patch_size)) spacing = config_info["training"]["spacing"] print("Bringing all images to spacing = {}".format(spacing)) if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None: model_to_load = config_info['training']['model_to_load'] if not os.path.exists(model_to_load): raise FileNotFoundError("Cannot find model: {}".format(model_to_load)) else: print("Loading model from {}".format(model_to_load)) else: model_to_load = None # set up either GPU or CPU usage if torch.cuda.is_available(): print("\n#### GPU INFORMATION ###") print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name())) current_device = torch.device("cuda:0") else: current_device = torch.device("cpu") print("Using device: {}".format(current_device)) # set determinism if required if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None: seed = config_info['training']['manual_seed'] else: seed = None if seed is not None: print("Using determinism with seed = {}\n".format(seed)) set_determinism(seed=seed) """ Setup data output directory """ out_model_dir = os.path.join(config_info['output']['out_dir'], datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + config_info['output']['out_postfix']) print("Saving to directory {}\n".format(out_model_dir)) # create cache directory to store results for Persistent Dataset if 'cache_dir' in config_info['output'].keys(): out_cache_dir = config_info['output']['cache_dir'] else: out_cache_dir = os.path.join(out_model_dir, 'persistent_cache') persistent_cache: Path = Path(out_cache_dir) persistent_cache.mkdir(parents=True, exist_ok=True) """ Data preparation """ # Read the input files for training and validation print("*** Loading input data for training...") train_files = create_data_list_of_dictionaries(train_file_list) print("Number of inputs for training = {}".format(len(train_files))) val_files = create_data_list_of_dictionaries(valid_file_list) print("Number of inputs for validation = {}".format(len(val_files))) # Define MONAI processing transforms for the training data. This includes: # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3 # - CropForegroundd: Reduce the background from the MR image # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the # last direction (lowest resolution) to avoid introducing motion artefact resampling errors # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed # - NormalizeIntensityd: Apply whitening # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1] # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. 
    #   bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd,
    #   RandScaleIntensityd, RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(
                keys=["image", "label"],
                range_x=90,
                range_y=90,
                prob=0.2,
                keep_size=True,
                mode=["bilinear", "nearest"],
                padding_mode=["zeros", "border"],
            ),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #   last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    """
    Load data
    """
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise Exception("Batch size different from 1 at validation is currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))
{}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape)) """ Network preparation """ print("*** Preparing the network ...") # automatically extracts the strides and kernels based on nnU-Net empirical rules spacings = spacing[:2] sizes = patch_size[:2] strides, kernels = [], [] while True: spacing_ratio = [sp / min(spacings) for sp in spacings] stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)] kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio] if all(s == 1 for s in stride): break sizes = [i / j for i, j in zip(sizes, stride)] spacings = [i * j for i, j in zip(spacings, stride)] kernels.append(kernel) strides.append(stride) strides.insert(0, len(spacings) * [1]) kernels.append(len(spacings) * [3]) # initialise the network net = DynUNet( spatial_dims=2, in_channels=1, out_channels=nr_out_channels, kernel_size=kernels, strides=strides, upsample_kernel_size=strides[1:], norm_name="instance", deep_supervision=True, deep_supr_num=2, res_block=False, ).to(current_device) print(net) # define the loss function loss_function = choose_loss_function(nr_out_channels, config_info) # define the optimiser and the learning rate scheduler opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95) scheduler = torch.optim.lr_scheduler.LambdaLR( opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9 ) """ MONAI evaluator """ print("*** Preparing the dynUNet evaluator engine...\n") # val_post_transforms = Compose( # [ # Activationsd(keys="pred", sigmoid=True), # ] # ) val_handlers = [ StatsHandler(output_transform=lambda x: None), TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"), output_transform=lambda x: None, global_epoch_transform=lambda x: trainer.state.iteration), CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True, file_prefix='best_valid'), ] if config_info['output']['val_image_to_tensorboad']: val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"), batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"], interval=2)) # Define customized evaluator class DynUNetEvaluator(SupervisedEvaluator): def _iteration(self, engine, batchdata): inputs, targets = self.prepare_batch(batchdata) inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device) flip_inputs_1 = torch.flip(inputs, dims=(2,)) flip_inputs_2 = torch.flip(inputs, dims=(3,)) flip_inputs_3 = torch.flip(inputs, dims=(2, 3)) def _compute_pred(): pred = self.inferer(inputs, self.network) # use random flipping as data augmentation at inference flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,)) flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,)) flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3)) return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4 # execute forward computation self.network.eval() with torch.no_grad(): if self.amp: with torch.cuda.amp.autocast(): predictions = _compute_pred() else: predictions = _compute_pred() return {"image": inputs, "label": targets, "pred": predictions} evaluator = DynUNetEvaluator( device=current_device, val_data_loader=val_loader, network=net, inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0), post_transform=None, key_val_metric={ "Mean_dice": MeanDice( include_background=False, 
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    """
    MONAI trainer
    """
    print("*** Preparing the dynUNet trainer engine...\n")
    # train_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"),
                                tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir,
                        save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2,
                        epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                # deep supervision: match the label to each output scale and weight the losses by 0.5 ** i
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    """
    Run training
    """
    print("*** Run training...")
    trainer.run()
    print("Done!")
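# The "Network preparation" block above derives DynUNet kernels and strides with
# the nnU-Net empirical rule: downsample (stride 2) only along axes whose spacing
# is within 2x of the finest axis and whose size is still at least 8, and use a
# kernel of 3 (vs 1) by the same spacing-ratio test. A standalone sketch of that
# loop, run on illustrative values (the helper name is ours, not monaifbs API):
def derive_kernels_strides(spacings, sizes):
    strides, kernels = [], []
    while True:
        ratios = [sp / min(spacings) for sp in spacings]
        stride = [2 if r <= 2 and s >= 8 else 1 for r, s in zip(ratios, sizes)]
        kernel = [3 if r <= 2 else 1 for r in ratios]
        if all(s == 1 for s in stride):
            break
        sizes = [s / st for s, st in zip(sizes, stride)]
        spacings = [sp * st for sp, st in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])  # no downsampling at the input level
    kernels.append(len(spacings) * [3])  # deepest level keeps 3x3 kernels
    return kernels, strides

# e.g. a hypothetical 448x448 in-plane patch at isotropic 0.8 mm spacing
print(derive_kernels_strides([0.8, 0.8], [448, 448]))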
def test_invert(self): set_determinism(seed=0) im_fname, seg_fname = [ make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100) ] transform = Compose([ LoadImaged(KEYS), AddChanneld(KEYS), Orientationd(KEYS, "RPS"), Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32), ScaleIntensityd("image", minv=1, maxv=10), RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(KEYS, prob=0.5), RandRotate90d(KEYS, spatial_axes=(1, 2)), RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True), RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"), ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS), CastToTyped(KEYS, dtype=torch.uint8), ]) data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] # num workers = 0 for mac or gpu transforms num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available( ) else 2 dataset = CacheDataset(data, transform=transform, progress=False) loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) # set up engine def _train_func(engine, batch): self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100)) engine.state.output = batch engine.fire_event(IterationEvents.MODEL_COMPLETED) return engine.state.output engine = Engine(_train_func) engine.register_events(*IterationEvents) # set up testing handler TransformInverter( transform=transform, loader=loader, output_keys=["image", "label"], batch_keys="label", nearest_interp=True, num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2, ).attach(engine) engine.run(loader, max_epochs=1) set_determinism(seed=None) self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100)) self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100)) for i in engine.state.output["image_inverted"] + engine.state.output[ "label_inverted"]: torch.testing.assert_allclose( i.to(torch.uint8).to(torch.float), i.to(torch.float)) self.assertTupleEqual(i.shape, (1, 100, 101, 107)) # check labels match reverted = engine.state.output["label_inverted"][-1].detach().cpu( ).numpy()[0].astype(np.int32) original = LoadImaged(KEYS)(data[-1])["label"] n_good = np.sum(np.isclose(reverted, original, atol=1e-3)) reverted_name = engine.state.output["label_meta_dict"][ "filename_or_obj"][-1] original_name = data[-1]["label"] self.assertEqual(reverted_name, original_name) print("invert diff", reverted.size - n_good) self.assertTrue((reverted.size - n_good) in (25300, 1812), "diff. in two possible values")
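# Both invert tests rely on the same uint8 round-trip check: nearest-neighbour
# inversion only copies existing voxel values, so for an integer-valued label
# volume the cast to uint8 and back must be lossless. In isolation (the tensor
# below is illustrative, not from the tests):
import torch

i = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))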