def main():
    """Train a segmentation network and report Dice on train/val/test splits.

    All settings come from ``Options().parse()``.  Image/label pairs are
    discovered as ``image*.nii`` / ``label*.nii`` under
    ``<images_folder>/<split>`` and ``<labels_folder>/<split>`` for the
    ``train``, ``val`` and ``test`` splits.  After every epoch the network is
    evaluated with sliding-window inference on all three splits; the weights
    with the best validation Dice are written to ``best_metric_model.pth``
    and scalars/images are logged through a ``SummaryWriter``.

    Raises:
        ValueError: if ``opt.network`` is neither ``'nnunet'`` nor ``'unetr'``.
    """
    opt = Options().parse()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # '-1' is the "no GPU" sentinel; otherwise gpu_ids is a comma-separated list.
    num_gpus = 0 if opt.gpu_ids == '-1' else len(opt.gpu_ids.split(','))
    print('number of GPU:', num_gpus)

    # ------------------------------------------------------------------
    # File discovery
    # ------------------------------------------------------------------
    def find_pairs(split):
        """Return the sorted image paths and label paths of one split."""
        images = sorted(glob(os.path.join(opt.images_folder, split, 'image*.nii')))
        segs = sorted(glob(os.path.join(opt.labels_folder, split, 'label*.nii')))
        return images, segs

    def as_dicts(images, segs):
        """Pair image and label paths into the dict records MONAI expects."""
        return [{'image': image_name, 'label': label_name}
                for image_name, label_name in zip(images, segs)]

    train_images, train_segs = find_pairs('train')
    # Keep un-augmented copies: the training Dice is measured once per image.
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)
    val_images, val_segs = find_pairs('val')
    test_images, test_segs = find_pairs('test')

    # Each pass extends the list with itself, i.e. doubles the number of
    # patches drawn per epoch.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    train_dicts = as_dicts(train_images, train_segs)
    train_dice_dicts = as_dicts(train_images_for_dice, train_segs_for_dice)
    val_dicts = as_dicts(val_images, val_segs)
    test_dicts = as_dicts(test_images, test_segs)

    # ------------------------------------------------------------------
    # Transforms
    # ------------------------------------------------------------------
    both = ['image', 'label']
    interp = ('bilinear', 'nearest')  # image, label
    resample = opt.resolution is not None
    # NOTE(review): the original code draws the noise std from U(0, 15) when
    # resampling and from U(0, 1) otherwise; kept as-is, confirm it is intended.
    noise_std_max = 15 if resample else 1

    def preprocessing():
        """Deterministic steps shared by the training and evaluation pipelines."""
        steps = [
            LoadImaged(keys=both),
            AddChanneld(keys=both),
            # Optional CT HU window (disabled): ThresholdIntensityd at -135 / 215.
            CropForegroundd(keys=both, source_key='image'),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]
        if resample:
            steps.append(Spacingd(keys=both, pixdim=opt.resolution, mode=interp))
        return steps

    # NOTE: the np.random.uniform(...) arguments below are evaluated once,
    # here, so the sampled mean/std/offset stay fixed for the whole run.
    train_transforms = preprocessing() + [
        RandFlipd(keys=both, prob=0.15, spatial_axis=1),
        RandFlipd(keys=both, prob=0.15, spatial_axis=0),
        RandFlipd(keys=both, prob=0.15, spatial_axis=2),
        RandAffined(keys=both, mode=interp, prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=both, mode=interp, prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=both, mode=interp, prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=both, mode=interp, prob=0.1, sigma_range=(5, 8),
                       magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15), prob=0.1),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, noise_std_max)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        # Pad first so that the fixed-size crop below never exceeds the image.
        SpatialPadd(keys=both, spatial_size=opt.patch_size, method='end'),
        RandSpatialCropd(keys=both, roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=both),
    ]
    val_transforms = preprocessing() + [
        SpatialPadd(keys=both, spatial_size=opt.patch_size, method='end'),
        ToTensord(keys=both),
    ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # ------------------------------------------------------------------
    # Data loaders
    # ------------------------------------------------------------------
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True,
                              collate_fn=list_data_collate, num_workers=opt.workers,
                              pin_memory=False)

    def eval_loader(records):
        """Build a batch-size-1 loader that applies the evaluation transforms."""
        dataset = monai.data.Dataset(data=records, transform=val_transforms)
        return DataLoader(dataset, batch_size=1, num_workers=opt.workers,
                          collate_fn=list_data_collate, pin_memory=False)

    train_dice_loader = eval_loader(train_dice_dicts)
    val_loader = eval_loader(val_dicts)
    test_loader = eval_loader(test_dicts)

    # ------------------------------------------------------------------
    # Network, loss, optimiser
    # ------------------------------------------------------------------
    # Compare strings with ==, not `is`: identity of equal strings is an
    # interpreter detail, so `is` can leave every branch unmatched.
    if opt.network == 'nnunet':
        net = build_net()
    elif opt.network == 'unetr':
        net = build_UNETR()
    else:
        raise ValueError(
            f"unknown network {opt.network!r}; expected 'nnunet' or 'unetr'")
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark

    if opt.network == 'nnunet':
        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99,
                                weight_decay=3e-5, nesterov=True)
        # Polynomial decay of the learning rate over opt.epochs.
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)
    else:  # 'unetr' (validated above): fixed learning rate, no scheduler
        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)
        net_scheduler = None

    def plot_dice(images_loader):
        """Run sliding-window inference over a loader and return its mean Dice.

        Also returns the images, labels and post-processed outputs of the last
        batch (``None`` each if the loader is empty) so they can be plotted.
        Must be called under ``torch.no_grad()`` with ``net`` in eval mode.
        """
        val_images = val_labels = val_outputs = None
        for data in images_loader:
            val_images, val_labels = data["image"].cuda(), data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)
        # aggregate the final mean dice result
        metric = dice_metric.aggregate().item()
        # reset the status for the next evaluation round
        dice_metric.reset()
        return metric, val_images, val_labels, val_outputs

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    writer = SummaryWriter()
    # Number of full batches per epoch; used for progress output and as the
    # TensorBoard global-step stride.
    epoch_len = len(check_train) // train_loader.batch_size

    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if net_scheduler is not None:
            update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, _, _, _ = plot_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with
                # the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")
                plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image")
                plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label")
                plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def test_test_time_augmentation(self):
    """Briefly train a small 2D UNet, then check ``TestTimeAugmentation`` outputs.

    Verifies the shapes of the returned mode/mean/std, that the mode is
    binary, that the mean spans exactly [0, 1], and that the fourth return
    value is a float.
    """
    spatial_shape = (20, 20)
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    data_keys = ["image", "label"]
    n_train = 10
    train_samples = self.get_data(n_train, spatial_shape)
    test_samples = self.get_data(1, spatial_shape)

    pipeline = [
        AddChanneld(data_keys),
        RandAffined(
            data_keys,
            prob=1.0,
            spatial_size=(30, 30),
            rotate_range=(np.pi / 3, np.pi / 3),
            translate_range=(3, 3),
            scale_range=((0.8, 1), (0.8, 1)),
            padding_mode="zeros",
            mode=("bilinear", "nearest"),
            as_tensor_output=False,
        ),
        CropForegroundd(data_keys, source_key="image"),
        DivisiblePadd(data_keys, 4),
    ]
    transforms = Compose(pipeline)

    cached = CacheDataset(train_samples, transforms)
    # Samples may come out with different sizes, so pad them to match when batching.
    batches = DataLoader(cached, batch_size=2, collate_fn=pad_list_data_collate)

    model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
    criterion = DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    n_epochs = 10
    for _ in trange(n_epochs):
        running_loss = 0
        for batch in batches:
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        running_loss /= len(batches)

    post_trans = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True),
    ])

    def inferrer_fn(x):
        # Forward pass followed by sigmoid + thresholding.
        return post_trans(model(x))

    augmenter = TestTimeAugmentation(
        transforms,
        batch_size=5,
        num_workers=0,
        inferrer_fn=inferrer_fn,
        device=device,
    )
    mode, mean, std, vvc = augmenter(test_samples)

    expected_shape = (1, ) + spatial_shape
    self.assertEqual(mode.shape, expected_shape)
    self.assertEqual(mean.shape, expected_shape)
    self.assertTrue(all(np.unique(mode) == (0, 1)))
    self.assertEqual((mean.min(), mean.max()), (0.0, 1.0))
    self.assertEqual(std.shape, expected_shape)
    self.assertIsInstance(vvc, float)
cfg.thickness = [32, 96] cfg.train_aug = Compose([ Resized( keys=("input", "mask"), spatial_size=1120, size_mode="longest", mode="bilinear", align_corners=False, ), SpatialPadd(keys=("input", "mask"), spatial_size=(1120, 1120)), RandFlipd(keys=("input", "mask"), prob=0.5, spatial_axis=1), RandAffined( keys=("input", "mask"), prob=0.5, rotate_range=np.pi / 14.4, translate_range=(70, 70), scale_range=(0.1, 0.1), as_tensor_output=False, ), RandSpatialCropd( keys=("input", "mask"), roi_size=(cfg.img_size[0], cfg.img_size[1]), random_size=False, ), RandScaleIntensityd(keys="input", factors=(-0.2, 0.2), prob=0.5), RandShiftIntensityd(keys="input", offsets=(-51, 51), prob=0.5), RandLambdad(keys="input", func=lambda x: 255 - x, prob=0.5), RandCoarseDropoutd( keys=("input", "mask"), holes=8, spatial_size=(1, 1),
TESTS_3D = [( t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn, 3 ) for collate_fn in [None, pad_list_data_collate] for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(keys=KEYS, prob=0.5), Compose( [RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), ToTensord(keys=KEYS)]), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64), RandAffined(keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device( "cuda" if torch.cuda.is_available() else "cpu")), ]] TESTS_2D = [ (t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn, 2) for collate_fn in [None, pad_list_data_collate] for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]), RandAxisFlipd(keys=KEYS, prob=0.5), Compose([ RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), ToTensord(keys=KEYS) ]), RandZoomd(
has_nib = True else: _, has_nib = optional_import("nibabel") KEYS = ["image", "label"] TESTS = [ (t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn) for collate_fn in [None, pad_list_data_collate] for t in [ RandFlipd(keys=KEYS, spatial_axis=[1, 2]), RandAxisFlipd(keys=KEYS), RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, range_x=np.pi), RandAffined(keys=KEYS, rotate_range=np.pi), ] ] class TestInverseCollation(unittest.TestCase): """Test collation for of random transformations with prob == 0 and 1.""" def setUp(self): if not has_nib: self.skipTest("nibabel required for test_inverse") set_determinism(seed=0) im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)] load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)])
TESTS_3D = [( t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn, 3 ) for collate_fn in [None, pad_list_data_collate] for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(keys=KEYS, prob=0.5), RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), RandAffined( keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), as_tensor_output=False, ), ]] TESTS_2D = [( t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn, 2 ) for collate_fn in [None, pad_list_data_collate] for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]), RandAxisFlipd(keys=KEYS, prob=0.5), RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
def test_invert(self):
    """Attach two ``TransformInverter`` handlers to an engine and check their outputs.

    The first handler inverts with nearest interpolation for every key; the
    second uses a different interpolation setting and post-function per key.
    Shapes, integer-valuedness and the label round-trip error are asserted.
    """
    set_determinism(seed=0)
    image_path, label_path = [
        make_nifti_image(volume)
        for volume in create_test_image_3d(101, 100, 107, noise_max=100)
    ]

    steps = [
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        # Only "image" becomes a Tensor so that inversion is exercised on
        # both Tensor and NumPy array inputs.
        ToTensord("image"),
        CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
    ]
    transform = Compose(steps)
    data = [{"image": image_path, "label": label_path} for _ in range(12)]

    # Use no worker processes on macOS or when CUDA is available.
    single_process = sys.platform == "darwin" or torch.cuda.is_available()
    num_workers = 0 if single_process else 2

    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

    # Engine whose "training" step just forwards the batch as its output.
    def _train_func(engine, batch):
        self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
        engine.state.output = batch
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output

    engine = Engine(_train_func)
    engine.register_events(*IterationEvents)

    # Handler 1: nearest interpolation for all keys.
    nearest_inverter = TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="label",
        nearest_interp=True,
        postfix="inverted1",
        to_tensor=[True, False],
        device="cpu",
        num_workers=num_workers,
    )
    nearest_inverter.attach(engine)

    # Handler 2: per-key interpolation setting and post-function.
    mixed_inverter = TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="image",
        nearest_interp=[True, False],
        post_func=[lambda x: x + 10, lambda x: x],
        postfix="inverted2",
        collate_fn=pad_list_data_collate,
        num_workers=num_workers,
    )
    mixed_inverter.attach(engine)

    engine.run(loader, max_epochs=1)
    set_determinism(seed=None)

    output = engine.state.output
    self.assertTupleEqual(output["image"].shape, (2, 1, 100, 100, 100))
    self.assertTupleEqual(output["label"].shape, (2, 1, 100, 100, 100))

    # With nearest interpolation the inverted values survive a uint8 round trip.
    for item in output["image_inverted1"]:
        torch.testing.assert_allclose(
            item.to(torch.uint8).to(torch.float), item.to(torch.float))
        self.assertTupleEqual(item.shape, (1, 100, 101, 107))
    for item in output["label_inverted1"]:
        np.testing.assert_allclose(
            item.astype(np.uint8).astype(np.float32), item.astype(np.float32))
        self.assertTupleEqual(item.shape, (1, 100, 101, 107))

    # Compare the last inverted label with the label loaded straight from disk.
    reverted = output["label_inverted1"][-1].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = output["label_meta_dict"]["filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    n_mismatch = reverted.size - n_good
    print("invert diff", n_mismatch)
    # 25300: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1824: torch 1.5.1
    self.assertTrue(n_mismatch in (25300, 1812, 1824), "diff. in 3 possible values")

    def rounding_residual(tensor):
        # Summed difference between the values and their uint8 truncation.
        as_float = tensor.to(torch.float)
        return torch.sum(as_float - tensor.to(torch.uint8).to(torch.float)).item()

    # Items inverted with different interpolation modes:
    # nearest -> accumulated residual should be smaller than 1
    inverted_image = output["image_inverted2"]
    self.assertLess(rounding_residual(inverted_image), 1.0)
    self.assertTupleEqual(inverted_image.shape, (2, 1, 100, 101, 107))

    # not nearest -> accumulated residual should be greater than 10000
    inverted_label = output["label_inverted2"]
    self.assertGreater(rounding_residual(inverted_label), 10000.0)
    self.assertTupleEqual(inverted_label.shape, (2, 1, 100, 101, 107))
def main():
    """Train the network from ``build_net()`` and log Dice on train/val/test splits.

    Settings come from ``Options().parse()``.  ``image*.nii`` / ``label*.nii``
    pairs are read from the ``train``, ``val`` and ``test`` sub-folders, the
    network is trained with a Tversky loss, and after each epoch Dice is
    computed with sliding-window inference.  The weights with the best
    validation Dice are saved to ``best_metric_model.pth``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # '-1' is the "no GPU" sentinel; otherwise count the comma-separated ids.
    num_gpus = 0 if opt.gpu_ids == '-1' else len(opt.gpu_ids.split(','))
    print('number of GPU:', num_gpus)

    # ---- data discovery -------------------------------------------------
    def find_pairs(split):
        """Return sorted image paths and sorted label paths for one split."""
        image_pattern = os.path.join(opt.images_folder, split, 'image*.nii')
        images = sorted(glob(image_pattern))
        label_pattern = os.path.join(opt.labels_folder, split, 'label*.nii')
        return images, sorted(glob(label_pattern))

    def to_records(images, segs):
        """Zip image and label paths into ``{'image', 'label'}`` records."""
        return [{'image': image_name, 'label': label_name}
                for image_name, label_name in zip(images, segs)]

    train_images, train_segs = find_pairs('train')
    # Second, independent listing: stays un-augmented for the training Dice.
    train_images_for_dice, train_segs_for_dice = find_pairs('train')
    val_images, val_segs = find_pairs('val')
    test_images, test_segs = find_pairs('test')

    # Every pass appends the list to itself, doubling its length.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    train_dicts = to_records(train_images, train_segs)
    train_dice_dicts = to_records(train_images_for_dice, train_segs_for_dice)
    val_dicts = to_records(val_images, val_segs)
    test_dicts = to_records(test_images, test_segs)

    # ---- transforms -----------------------------------------------------
    both = ['image', 'label']
    interp = ('bilinear', 'nearest')  # image, label

    def base_steps():
        """Deterministic steps common to the training and evaluation pipelines."""
        steps = [
            LoadNiftid(keys=both),
            AddChanneld(keys=both),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=both, source_key="image"),
        ]
        # Resample only when a target resolution was requested.
        if opt.resolution is not None:
            steps.append(Spacingd(keys=both, pixdim=opt.resolution, mode=interp))
        return steps

    small = np.pi / 36
    # One RandAffined per axis that receives the large rotation range.
    rotation_ranges = [
        (small, small, np.pi * 2),
        (small, np.pi / 2, small),
        (np.pi / 2, small, small),
    ]

    train_steps = base_steps()
    train_steps += [RandFlipd(keys=both, prob=0.1, spatial_axis=axis) for axis in (1, 0, 2)]
    train_steps += [
        RandAffined(keys=both, mode=interp, prob=0.1, rotate_range=angles, padding_mode="zeros")
        for angles in rotation_ranges
    ]
    train_steps += [
        Rand3DElasticd(keys=both, mode=interp, prob=0.1, sigma_range=(5, 8),
                       magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # The uniform draws happen once, here, when the transforms are built.
        RandGaussianNoised(keys=['image'], prob=0.1,
                           mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        RandSpatialCropd(keys=both, roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=both),
    ]
    val_steps = base_steps() + [ToTensord(keys=both)]

    train_transforms = Compose(train_steps)
    val_transforms = Compose(val_steps)

    # ---- loaders --------------------------------------------------------
    pin = torch.cuda.is_available()

    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.workers,
                              pin_memory=pin)

    def eval_loader(records):
        """Batch-size-1 loader using the evaluation transforms."""
        dataset = monai.data.Dataset(data=records, transform=val_transforms)
        return DataLoader(dataset, batch_size=1, num_workers=opt.workers, pin_memory=pin)

    train_dice_loader = eval_loader(train_dice_dicts)
    val_loader = eval_loader(val_dicts)
    test_loader = eval_loader(test_dicts)

    # The result is not used further down in this function.
    devices = get_devices_spec(None)

    # ---- network, loss, optimiser ---------------------------------------
    net = build_net()
    net.cuda()
    if num_gpus > 1:
        net = torch.nn.DataParallel(net)
    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")
    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)
    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()

    def evaluate(loader):
        """Return the mean Dice over ``loader`` plus its last images/labels/outputs.

        The mean is weighted by ``len(value)`` of each metric result and is
        also appended to ``metric_values``.
        """
        weighted_total = 0.0
        n_scored = 0
        images = labels = outputs = None
        for sample in loader:
            images, labels = sample["image"].cuda(), sample["label"].cuda()
            outputs = sliding_window_inference(images, opt.patch_size, 4, net)
            value = dice_metric(y_pred=outputs, y=labels)
            n_scored += len(value)
            weighted_total += value.item() * len(value)
        mean_dice = weighted_total / n_scored
        metric_values.append(mean_dice)
        return mean_dice, images, labels, outputs

    # ---- training loop --------------------------------------------------
    for epoch in range(opt.epochs):
        epoch_no = epoch + 1
        print("-" * 10)
        print(f"epoch {epoch_no}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            loss = loss_function(net(inputs), labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            epoch_len = len(check_train) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch_no} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if epoch_no % val_interval != 0:
            continue

        net.eval()
        with torch.no_grad():
            metric, val_images, val_labels, val_outputs = evaluate(val_loader)

            # Keep the weights with the best validation Dice so far.
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch_no
                torch.save(net.state_dict(), "best_metric_model.pth")
                print("saved new best metric model")

            metric_train, train_images, train_labels, train_outputs = evaluate(train_dice_loader)
            metric_test, test_images, test_labels, test_outputs = evaluate(test_loader)

            # Logger bar
            print(
                "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}"
                .format(epoch_no, metric_train, metric, metric_test,
                        best_metric, best_metric_epoch))

            writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch_no)
            writer.add_scalar("Testing_dice", metric_test, epoch_no)
            writer.add_scalar("Training_dice", metric_train, epoch_no)
            writer.add_scalar("Validation_dice", metric, epoch_no)

            # Log the last validation sample: binarise the raw outputs first.
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            plot_2d_or_3d_image(val_images, epoch_no, writer, index=0, tag="validation image")
            plot_2d_or_3d_image(val_labels, epoch_no, writer, index=0, tag="validation label")
            plot_2d_or_3d_image(val_outputs, epoch_no, writer, index=0, tag="validation inference")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
def test_train_timing(self):
    """Fast-training integration test: short GPU training run, then a Dice check.

    Pairs the ``img*.nii.gz`` / ``seg*.nii.gz`` files found in
    ``self.data_dir``, trains a 3D UNet for ``max_epochs`` epochs on
    ``cuda:0`` with AMP, validates after every epoch using sliding-window
    inference and finally asserts that the best mean Dice is above 0.95.
    """
    image_paths = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))

    def _as_records(img_list, seg_list):
        # one {"image": ..., "label": ...} dict per matched pair
        return [{"image": img, "label": seg} for img, seg in zip(img_list, seg_list)]

    # the first 32 pairs are used for training, the last 9 for validation
    train_files = _as_records(image_paths[:32], label_paths[:32])
    val_files = _as_records(image_paths[-9:], label_paths[-9:])

    gpu = torch.device("cuda:0")
    both = ["image", "label"]

    def _deterministic_head():
        # fresh instances on every call so the two pipelines share no transform objects
        return [
            LoadImaged(keys=both),
            EnsureChannelFirstd(keys=both),
            Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=both, source_key="image"),
        ]

    # define transforms for train and validation
    train_transforms = Compose(
        _deterministic_head()
        + [
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=both),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=both, device=gpu),
            # randomly crop out patch samples from big image based on pos / neg ratio;
            # the image centers of negative samples must be in valid image area
            RandCropByPosNegLabeld(
                keys=both,
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=both, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=both, prob=0.5),
            RandRotate90d(keys=both, prob=0.5, spatial_axes=(1, 2)),
            RandZoomd(keys=both, prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
            RandRotated(
                keys=both,
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
            ),
            RandAffined(keys=both, prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
        ]
    )
    val_transforms = Compose(
        _deterministic_head()
        + [
            EnsureTyped(keys=both),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=both, device=gpu),
        ]
    )

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch

    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(gpu)

    # Novograd paper suggests to use a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # values that stay constant across every iteration
    epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")

        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                  f" step time: {(time.time() - step_start):.4f}")
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(
                            val_data["image"], roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_data["label"])]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                best_metric = max(best_metric, metric)
                print(
                    f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                )
        print(
            f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
        )

    total_time = time.time() - total_start
    print(
        f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
    )
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
translate_params=[10, 5, -4], scale_params=[0.8, 1.3], ), ) ) TESTS.append( ( "RandAffine 3d", "3D", 1e-1, RandAffined( KEYS, [155, 179, 192], prob=1, padding_mode="zeros", rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], shear_range=[(0.5, 0.5)], translate_range=[10, 5, -4], scale_range=[(0.8, 1.2), (0.9, 1.3)], ), ) ) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore def no_collation(x): return x
def main(config):
    """Train a 3D UNet on paired PET/CT volumes with MONAI's workflow engines.

    Args:
        config: nested mapping. The sections read here are ``path``
            (``csv_path``, ``trained_model_path``, ``training_model_folder``),
            ``preprocessing`` (``image_shape``, ``in_channels``,
            ``voxel_spacing``, ``data_augment``, ``resize``, ``origin``,
            ``normalize``, ``number_class``), ``model`` (``architecture`` and
            ``<architecture>.cnn_params``) and ``training`` (``epochs``,
            ``batch_size``, ``shuffle``, ``optimizer.opt_params``).
            List values inside ``cnn_params`` are converted to tuples in place.

    Side effects:
        Creates ``<training_model_folder>/<timestamp>/logs`` and writes
        TensorBoard events and checkpoints below ``./runs/``.
    """
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config['path']['csv_path']
    trained_model_path = config['path']['trained_model_path']  # if None, trained from scratch
    training_model_folder = os.path.join(config['path']['training_model_folder'], now)  # '/path/to/folder'
    logdir = os.path.join(training_model_folder, 'logs')
    # creates the timestamped run folder as well; exist_ok avoids the check-then-create race
    os.makedirs(logdir, exist_ok=True)

    # PET CT scan params
    image_shape = tuple(config['preprocessing']['image_shape'])  # (x, y, z)
    in_channels = config['preprocessing']['in_channels']
    voxel_spacing = tuple(config['preprocessing']['voxel_spacing'])  # (4.8, 4.8, 4.8), in millimeter, (x, y, z)
    # NOTE(review): the options below (and trained_model_path / opt_params) are read
    # but not consumed by the pipeline yet; confirm whether they should be wired in.
    data_augment = config['preprocessing']['data_augment']  # True, for training dataset only
    resize = config['preprocessing']['resize']  # True, not used yet
    origin = config['preprocessing']['origin']  # how to set the new origin
    normalize = config['preprocessing']['normalize']  # True, whether or not to normalize the inputs
    number_class = config['preprocessing']['number_class']  # 2

    # CNN params
    architecture = config['model']['architecture']  # 'unet' or 'vnet'
    cnn_params = config['model'][architecture]['cnn_params']
    # transform list to tuple
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)

    # Training params
    epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    shuffle = config['training']['shuffle']
    opt_params = config['training']["optimizer"]["opt_params"]

    # Get Data
    DM = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, test_images_paths = DM.get_train_val_test(
        wrap_with_dict=True)

    def build_transforms(augment):
        """Return the preprocessing pipeline; ``augment`` inserts the random affine step."""
        steps = [
            # read img + meta info
            LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
            Roi2Mask(keys=['pet_img', 'mask_img'], method='otsu', tval=0.0, idx_channel=0),
            ResampleReshapeAlign(target_shape=image_shape,
                                 target_voxel_spacing=voxel_spacing,
                                 keys=['pet_img', "ct_img", 'mask_img'],
                                 origin='head',
                                 origin_key='pet_img'),
            Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        ]
        if augment:
            # user can also add other random transforms
            steps.append(
                RandAffined(keys=("pet_img", "ct_img", "mask_img"),
                            spatial_size=None,
                            prob=0.4,
                            rotate_range=(0, np.pi / 30, np.pi / 15),
                            shear_range=None,
                            translate_range=(10, 10, 10),
                            scale_range=(0.1, 0.1, 0.1),
                            mode=("bilinear", "bilinear", "nearest"),
                            padding_mode="border"))
        steps += [
            # normalize input
            ScaleIntensityRanged(
                keys=["pet_img"],
                a_min=0.0,
                a_max=25.0,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            ScaleIntensityRanged(
                keys=["ct_img"],
                a_min=-1000.0,
                a_max=1000.0,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            # Prepare for neural network
            ConcatModality(keys=['pet_img', 'ct_img']),
            AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
            ToTensord(keys=["image", "mask_img"]),
        ]
        return Compose(steps)

    # Input preprocessing
    # use data augmentation for training, none for validation
    train_transforms = build_transforms(augment=True)
    val_transforms = build_transforms(augment=False)

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_images_paths,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=2)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_images_paths,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=batch_size,
                                       num_workers=2)

    # Model
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    def prepare_batch(batch):
        # the ground truth lives under 'mask_img', not the default 'label' key
        return batch['image'], batch['mask_img']

    def pred_and_label(output):
        # metric input shared by every handler below
        return output["pred"], output["label"]

    # training
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/",
                                output_transform=lambda x: None),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        # Fix: without this the evaluator fell back to its default batch
        # preparation while the trainer mapped 'mask_img' to the label, so the
        # validation metrics were not fed the mask as target.
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True, output_transform=pred_and_label)
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=pred_and_label),
            "val_precision": Precision(output_transform=pred_and_label),
            "val_recall": Recall(output_transform=pred_and_label),
        },
        val_handlers=val_handlers,
        # AMP evaluation is left disabled here
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/",
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        # Fix: honour the configured epoch count instead of a hard-coded 5
        max_epochs=epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice":
            MeanDice(include_background=True, output_transform=pred_and_label)
        },
        additional_metrics={
            "train_acc": Accuracy(output_transform=pred_and_label),
            "train_precision": Precision(output_transform=pred_and_label),
            "train_recall": Recall(output_transform=pred_and_label),
        },
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    trainer.run()
keys = ("img", "seg") # use these when interpolating binary segmentations to ensure values are 0 or 1 only zoom_mode = monai.utils.enums.InterpolateMode.NEAREST elast_mode = monai.utils.enums.GridSampleMode.BILINEAR, monai.utils.enums.GridSampleMode.NEAREST trans = Compose( [ ScaleIntensityd(keys=("img",)), # rescale image data to range [0,1] AddChanneld(keys=keys), # add 1-size channel dimension RandRotate90d(keys=keys, prob=aug_prob), RandFlipd(keys=keys, prob=aug_prob), RandZoomd(keys=keys, prob=aug_prob, mode=zoom_mode), Rand2DElasticd(keys=keys, prob=aug_prob, spacing=10, magnitude_range=(-2, 2), mode=elast_mode), RandAffined(keys=keys, prob=aug_prob, rotate_range=1, translate_range=16, mode=elast_mode), ToTensord(keys=keys), # convert to tensor ] ) data = [ {"img": train_images[i], "seg": train_segs[i]} for i in range(len(train_images)) ] ds = CacheDataset(data, trans) loader = DataLoader( dataset=ds, batch_size=batch_size, num_workers=num_workers, pin_memory=torch.cuda.is_available(),
def training(train_files, val_files, log_dir):
    """Train the DenseNetASPP classifier and keep the checkpoint with the best Kappa.

    Args:
        train_files: data items handed to ``monai.data.Dataset`` for training.
        val_files: data items handed to ``monai.data.Dataset`` for validation.
        log_dir: despite the name, this is used as the checkpoint *file* path:
            it is loaded with ``torch.load`` when it exists and overwritten with
            ``torch.save`` whenever validation Kappa improves.

    Side effects:
        Writes TensorBoard scalars, saves checkpoints to ``log_dir``, shows a
        matplotlib figure and finally calls ``evaluta_model(val_files, log_dir)``.

    NOTE(review): class weights are derived from the first training batch only,
    and on resume neither ``best_metric`` nor the LR scheduler state is restored.
    """
    # Define transforms for image
    print(log_dir)

    def build_transforms(augment):
        """Shared load/normalise/resize pipeline; ``augment`` adds the random steps."""
        steps = [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
        ]
        if augment:
            steps += [
                RandRotate90d(keys=["inputs"], prob=0.8, spatial_axes=[0, 1]),
                RandAffined(keys=["inputs"], prob=0.8, scale_range=[0.1, 0.5]),
                RandZoomd(keys=["inputs"], prob=0.8, max_zoom=1.5, min_zoom=0.5),
            ]
        steps.append(ToTensord(keys=["inputs"]))
        return Compose(steps)

    train_transforms = build_transforms(augment=True)
    val_transforms = build_transforms(augment=False)

    lr = 0.01
    batch_size = 256
    # Fix: pin_memory expects a bool; the original passed the torch.device class object
    pin_memory = torch.cuda.is_available()

    # Define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)
    check_data = monai.utils.misc.first(check_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2,
                              pin_memory=pin_memory)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)

    # Create Net, CrossEntropyLoss and Adam optimizer
    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    classes = np.array([0, 1, 2, 3, 4])
    # Fix: keyword arguments work on every scikit-learn release; newer ones
    # reject the positional form.
    class_weights = class_weight.compute_class_weight(
        'balanced', classes=classes, y=check_data["label"].numpy())
    class_weights_tensor = torch.Tensor(class_weights).to(device)
    loss_function = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.5, last_epoch=-1)

    # If a saved checkpoint exists, load it and continue training from it
    if os.path.exists(log_dir):
        # map_location lets a checkpoint saved on another device be restored here
        checkpoint = torch.load(log_dir, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
        # the checkpoint stores the last finished epoch index, so resume after it
        first_epoch = start_epoch + 1
    else:
        # Fix: a fresh run used to start at index 1 and silently skipped epoch 0
        first_epoch = 0
        print('无保存模型,将从头开始训练!')

    # start a typical PyTorch training
    epoch_num = 300
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    # Fix: len(train_loader) is already the number of batches per epoch; dividing
    # it by batch_size again made the TensorBoard steps of different epochs collide.
    epoch_len = len(train_loader)

    for epoch in range(first_epoch, epoch_num):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["inputs"].to(device), batch_data["label"].to(device)
            outputs = model(inputs)
            loss = loss_function(outputs, labels.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{len(train_loader)}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        scheduler.step()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data["inputs"].to(device), val_data["label"].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)

                predicted = y_pred.argmax(dim=1)
                kappa_value = cohen_kappa_score(y.to("cpu"), predicted.to("cpu"), weights='quadratic')
                metric_values.append(kappa_value)
                acc_value = torch.eq(predicted, y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                if kappa_value > best_metric:
                    best_metric = kappa_value
                    best_metric_epoch = epoch + 1
                    checkpoint = {'model': model.state_dict(),
                                  'optimizer': optimizer.state_dict(),
                                  'epoch': epoch
                                  }
                    torch.save(checkpoint, log_dir)
                    print("saved new best metric model")
                print(
                    "current epoch: {} current Kappa: {:.4f} current accuracy: {:.4f} best Kappa: {:.4f} at epoch {}".format(
                        epoch + 1, kappa_value, acc_metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_accuracy", acc_metric, epoch + 1)

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()

    plt.figure('train', (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Epoch Average Loss")
    x = [i + 1 for i in range(len(epoch_loss_values))]
    y = epoch_loss_values
    plt.xlabel('epoch')
    plt.plot(x, y)
    plt.subplot(1, 2, 2)
    # Fix: the plotted series is the quadratic weighted Kappa, not an ROC AUC
    plt.title("Validation: Quadratic weighted Kappa")
    x = [val_interval * (i + 1) for i in range(len(metric_values))]
    y = metric_values
    plt.xlabel('epoch')
    plt.plot(x, y)
    plt.show()

    evaluta_model(val_files, log_dir)
def test_invert(self):
    """Check that ``TransformInverter`` restores the pre-transform geometry.

    A synthetic 3D image/label pair is pushed through a chain of spatial and
    random transforms inside an ignite ``Engine``; the handler attached to the
    engine inverts the chain, and the test asserts the inverted shapes, that
    the inverted values are integral (uint8 round-trip) and that the number of
    label voxels differing from the original matches one of two known values.
    """
    set_determinism(seed=0)
    # Fix: restore the global RNG state even when setup or the engine run
    # raises, so a failure here cannot leak seeded determinism into other tests.
    try:
        im_fname, seg_fname = [
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            ToTensord(KEYS),
            CastToTyped(KEYS, dtype=torch.uint8),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine
        def _train_func(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            num_workers=num_workers,
        ).attach(engine)

        engine.run(loader, max_epochs=1)
    finally:
        set_determinism(seed=None)

    # 12 samples with batch_size=5 leave 2 items in the final batch
    self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100))
    self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100))
    for i in engine.state.output["image_inverted"] + engine.state.output["label_inverted"]:
        # inverted values must survive a uint8 round-trip unchanged
        torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    # check labels match
    reverted = engine.state.output["label_inverted"][-1].detach().cpu().numpy()[0].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = engine.state.output["label_meta_dict"]["filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    self.assertTrue((reverted.size - n_good) in (25300, 1812), "diff. in two possible values")
def main():
    """Train, validate and test a 3D DenseNet121 binary classifier on the HACKATHON CT data.

    Reads ``<data>/HACKATHON/train.txt`` (one ``name,label`` entry per line),
    prepares and balances the samples with the project helpers, splits them
    into train/validation/test subsets, builds the transform pipelines and
    data loaders, then runs the project ``Trainer`` followed by ``Tester``.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # Load and randomize images
    # HACKATHON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    train_info_hackathon = []
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        for entry in fp:
            fields = entry.strip().split(',')
            if fields == ['']:
                # tolerate blank lines (e.g. a trailing newline) instead of raising IndexError
                continue
            train_info_hackathon.append((fields[0], int(fields[1])))

    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir,
                                               seg_dir,
                                               train_info_hackathon,
                                               dual_output=False)
    _train_data_hackathon = large_image_splitter(_train_data_hackathon,
                                                 dirs["cache"])
    copy_list = transform_and_copy(_train_data_hackathon, dirs['cache'])
    balance_training_data2(_train_data_hackathon, copy_list, seed=72)

    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(_train_data_hackathon,
                                                        test_size=0.2,
                                                        shuffle=True,
                                                        random_state=42)
    train_data_hackathon, valid_data_hackathon = train_test_split(
        train_split, test_size=0.2, shuffle=True, random_state=43)

    # Setup transforms
    # Crop foreground
    crop_foreground = CropForegroundd(keys=["image"],
                                      source_key="image",
                                      margin=(5, 5, 0),
                                      select_fn=lambda x: x != 0)
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    # Random axis flip
    rand_x_flip = RandFlipd(keys=["image"], spatial_axis=0, prob=0.50)
    rand_y_flip = RandFlipd(keys=["image"], spatial_axis=1, prob=0.50)
    rand_z_flip = RandFlipd(keys=["image"], spatial_axis=2, prob=0.50)
    # Rand affine transform
    rand_affine = RandAffined(keys=["image"],
                              prob=0.5,
                              rotate_range=(0, 0, np.pi / 12),
                              shear_range=(0.07, 0.07, 0.0),
                              translate_range=(0, 0, 0),
                              scale_range=(0.07, 0.07, 0.0),
                              padding_mode="zeros")
    # Pad image to have height at least 30
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"],
                     spatial_size=(int(512 * 0.50), int(512 * 0.50), -1),
                     mode="trilinear")
    # Apply Gaussian noise
    rand_gaussian_noise = RandGaussianNoised(keys=["image"],
                                             prob=0.25,
                                             mean=0.0,
                                             std=0.1)

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        rand_x_flip,
        rand_y_flip,
        rand_z_flip,
        rand_affine,
        rand_gaussian_noise,
        ToTensord(keys=["image"]),
    ]).flatten()
    # validation and test use the deterministic pipeline only
    hackathon_valid_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_test_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()

    # Setup data
    train_dataset = PersistentDataset(data=train_data_hackathon[:],
                                      transform=hackathon_train_transform,
                                      cache_dir=dirs["persistent"])
    valid_dataset = PersistentDataset(data=valid_data_hackathon[:],
                                      transform=hackathon_valid_transform,
                                      cache_dir=dirs["persistent"])
    test_dataset = PersistentDataset(data=test_data_hackathon[:],
                                     transform=hackathon_test_transform,
                                     cache_dir=dirs["persistent"])
    # no `shuffle` for the training loader: ordering is delegated to the sampler
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            train_data_hackathon,
            callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            valid_data_hackathon,
            callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    test_loader = DataLoader(test_dataset,
                             batch_size=4,
                             shuffle=False,
                             pin_memory=using_gpu,
                             num_workers=2,
                             collate_fn=PadListDataCollate(
                                 Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1,
                               out_channels=1).to(device)
    # pos_weight for class imbalance
    _, n, p = calculate_class_imbalance(train_data_hackathon)
    # Fix: the first positional parameter of BCEWithLogitsLoss is `weight`, so
    # the original call never applied a positive-class weight. With a single
    # output logit, pos_weight is one value: negatives / positives.
    # NOTE(review): assumes n and p are the negative and positive sample
    # counts (or fractions) -- confirm against calculate_class_imbalance.
    pos_ratio = n / p if p else 1.0
    pos_weight = torch.tensor([pos_ratio], dtype=torch.float32).to(device)
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=0)
    # NOTE(review): this scheduler is built but the trainer below receives
    # lr_scheduler=None, so the learning rate never decays -- confirm intent.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       gamma=0.95,
                                                       last_epoch=-1)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])
    validator = Validator(device=device,
                          val_data_loader=valid_loader,
                          network=network,
                          post_transform=valid_post_transforms,
                          amp=using_gpu,
                          non_blocking=using_gpu)

    trainer = Trainer(device=device,
                      out_dir=dirs["out"],
                      out_name="DenseNet121",
                      max_epochs=120,
                      validation_epoch=1,
                      validation_interval=1,
                      train_data_loader=train_loader,
                      network=network,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      lr_scheduler=None,
                      validator=validator,
                      amp=using_gpu,
                      non_blocking=using_gpu)

    # Run trainer
    train_output = trainer.run()

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()