def run_interaction(self, train, compose):
    """Drive a SupervisedTrainer through an ``Interaction`` iteration update.

    Five one-element samples (image == label == sample index) form a single batch
    for a 1 -> 1 linear layer. The interaction post-processes predictions with a
    sigmoid activation followed by numpy conversion, given either as a plain list
    or wrapped in ``Compose`` depending on ``compose``.

    Args:
        train: forwarded unchanged to ``Interaction(train=...)``.
        compose: if truthy, wrap the iteration transforms in ``Compose``.
    """
    samples = [
        {"image": torch.tensor([float(idx)]), "label": torch.tensor([float(idx)])}
        for idx in range(5)
    ]
    model = torch.nn.Linear(1, 1)
    learning_rate = 1e-3
    optimizer = torch.optim.SGD(model.parameters(), learning_rate)
    criterion = torch.nn.L1Loss()
    loader = torch.utils.data.DataLoader(Dataset(samples, transform=None), batch_size=5)

    steps = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]
    if compose:
        steps = Compose(steps)
    interaction = Interaction(transforms=steps, train=train, max_interactions=5)
    # Whether or not ``compose`` was set, the interaction must expose both steps here.
    self.assertEqual(len(interaction.transforms.transforms), 2, "Mismatch in expected transforms")

    # set up engine
    engine = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=loader,
        network=model,
        optimizer=optimizer,
        loss_function=criterion,
        iteration_update=interaction,
    )
    for event in (IterationEvents.INNER_ITERATION_STARTED, IterationEvents.INNER_ITERATION_COMPLETED):
        engine.add_event_handler(event, add_one)
    engine.run()

    self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing")
    self.assertEqual(engine.state.best_metric, 9)
def run_interaction(self, train, compose):
    """Drive a SupervisedTrainer through an ``Interaction`` with click-guidance transforms.

    Five all-ones 2x2x2 volumes are preprocessed to seed an initial guidance point
    and batched together. Each inner interaction then recomputes the discrepancy
    between prediction and label, samples new guidance from it and re-adds the
    guidance signal. The six iteration transforms are given either as a plain list
    or wrapped in ``Compose`` depending on ``compose``.

    Args:
        train: forwarded unchanged to ``Interaction(train=...)``.
        compose: if truthy, wrap the iteration transforms in ``Compose``.
    """
    samples = [
        {"image": np.ones((1, 2, 2, 2)).astype(np.float32), "label": np.ones((1, 2, 2, 2))}
        for _ in range(5)
    ]
    model = torch.nn.Linear(2, 2)
    learning_rate = 1e-3
    optimizer = torch.optim.SGD(model.parameters(), learning_rate)
    criterion = torch.nn.L1Loss()

    preprocessing = Compose(
        [
            FindAllValidSlicesd(label="label", sids="sids"),
            AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ]
    )
    loader = torch.utils.data.DataLoader(Dataset(samples, transform=preprocessing), batch_size=5)

    click_steps = [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=["image", "label", "pred"]),
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        ToTensord(keys=("image", "label")),
    ]
    if compose:
        click_steps = Compose(click_steps)
    interaction = Interaction(transforms=click_steps, train=train, max_interactions=5)
    # Whether or not ``compose`` was set, all six steps must be visible here.
    self.assertEqual(len(interaction.transforms.transforms), 6, "Mismatch in expected transforms")

    # set up engine
    engine = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=loader,
        network=model,
        optimizer=optimizer,
        loss_function=criterion,
        iteration_update=interaction,
    )
    for event in (IterationEvents.INNER_ITERATION_STARTED, IterationEvents.INNER_ITERATION_COMPLETED):
        engine.add_event_handler(event, add_one)
    engine.run()

    self.assertIsNotNone(engine.state.batch[0].get("guidance"), "guidance is missing")
    self.assertEqual(engine.state.best_metric, 9)
def main():
    """Train a 3D UNet on synthetic NIfTI volumes with MONAI's ignite workflows.

    Writes 40 random image/segmentation pairs to a fresh temporary directory.
    Trains on the first 20 for 5 epochs and validates on the last 20 every 2
    epochs with sliding-window inference. The temporary directory is removed at
    the end, including when data generation or training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        train_files = [{Keys.IMAGE: img, Keys.LABEL: seg} for img, seg in zip(images[:20], segs[:20])]
        val_files = [{Keys.IMAGE: img, Keys.LABEL: seg} for img, seg in zip(images[-20:], segs[-20:])]

        # define transforms for image and segmentation
        train_transforms = Compose(
            [
                LoadNiftid(keys=[Keys.IMAGE, Keys.LABEL]),
                AsChannelFirstd(keys=[Keys.IMAGE, Keys.LABEL], channel_dim=-1),
                ScaleIntensityd(keys=[Keys.IMAGE, Keys.LABEL]),
                RandCropByPosNegLabeld(
                    keys=[Keys.IMAGE, Keys.LABEL],
                    label_key=Keys.LABEL,
                    size=[96, 96, 96],
                    pos=1,
                    neg=1,
                    num_samples=4,
                ),
                RandRotate90d(keys=[Keys.IMAGE, Keys.LABEL], prob=0.5, spatial_axes=[0, 2]),
                ToTensord(keys=[Keys.IMAGE, Keys.LABEL]),
            ]
        )
        val_transforms = Compose(
            [
                LoadNiftid(keys=[Keys.IMAGE, Keys.LABEL]),
                AsChannelFirstd(keys=[Keys.IMAGE, Keys.LABEL], channel_dim=-1),
                ScaleIntensityd(keys=[Keys.IMAGE, Keys.LABEL]),
                ToTensord(keys=[Keys.IMAGE, Keys.LABEL]),
            ]
        )

        # create a training data loader
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate)
        # create a validation data loader
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)

        # create UNet, DiceLoss and Adam optimizer
        # fall back to CPU so the example still runs on machines without CUDA
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss = monai.losses.DiceLoss(do_sigmoid=True)
        opt = torch.optim.Adam(net.parameters(), 1e-3)

        val_handlers = [StatsHandler(output_transform=lambda x: None)]
        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=val_loader,
            network=net,
            inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
            val_handlers=val_handlers,
            key_val_metric={
                "val_mean_dice": MeanDice(
                    include_background=True,
                    add_sigmoid=True,
                    output_transform=lambda x: (x[Keys.PRED], x[Keys.LABEL]),
                )
            },
            additional_metrics=None,
        )

        train_handlers = [
            ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
            StatsHandler(tag_name="train_loss", output_transform=lambda x: x[Keys.INFO][Keys.LOSS]),
        ]
        trainer = SupervisedTrainer(
            device=device,
            max_epochs=5,
            train_data_loader=train_loader,
            network=net,
            optimizer=opt,
            loss_function=loss,
            inferer=SimpleInferer(),
            train_handlers=train_handlers,
            amp=False,
            key_train_metric=None,
        )
        trainer.run()
    finally:
        # always remove the synthetic data, even if generation or training fails
        shutil.rmtree(tempdir)
def main(tempdir):
    """Synthetic-data 3D segmentation training with MONAI ignite engines.

    Generates 40 random image/mask NIfTI pairs inside ``tempdir`` and trains a
    UNet on the first 20 for 5 epochs. Training uses a step LR decay and a
    post-processed accuracy metric. The model is validated on the last 20 every
    2 epochs with sliding-window inference. TensorBoard logs and checkpoints are
    written to ``./runs/``.
    """

    def pred_and_label(output):
        """Select the (prediction, ground truth) pair that the metrics consume."""
        return output["pred"], output["label"]

    def loss_value(output):
        """Select the scalar training loss from the engine output."""
        return output["loss"]

    def build_post_transforms():
        """Sigmoid, threshold, then keep only the largest connected foreground component."""
        return Compose(
            [
                Activationsd(keys="pred", sigmoid=True),
                AsDiscreted(keys="pred", threshold_values=True),
                KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            ]
        )

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    ################################ DATASET ################################
    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for idx in range(40):
        volume, mask = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        nib.save(nib.Nifti1Image(volume, np.eye(4)), os.path.join(tempdir, f"img{idx:d}.nii.gz"))
        nib.save(nib.Nifti1Image(mask, np.eye(4)), os.path.join(tempdir, f"seg{idx:d}.nii.gz"))

    image_paths = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"image": im, "label": lb} for im, lb in zip(image_paths[:20], label_paths[:20])]
    val_files = [{"image": im, "label": lb} for im, lb in zip(image_paths[-20:], label_paths[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                pos=1,
                neg=1,
                num_samples=4,
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # half of the training set and the whole validation set are cached in memory;
    # batch_size=2 with num_samples=4 crops gives 2 x 4 patches per training step
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    ################################ DATASET ################################

    ################################ NETWORK ################################
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    ################################ NETWORK ################################

    ################################ LOSS ################################
    loss = monai.losses.DiceLoss(sigmoid=True)
    ################################ LOSS ################################

    ################################ OPT ################################
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    ################################ OPT ################################

    ################################ LR ################################
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    ################################ LR ################################

    # AMP is only enabled on PyTorch >= 1.6 (no FP16 support before that)
    use_amp = monai.utils.get_torch_version_tuple() >= (1, 6)

    val_post_transforms = build_post_transforms()
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/", output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir="./runs/",
            batch_transform=lambda x: (x["image"], x["label"]),
            output_transform=lambda x: x["pred"],
        ),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={"val_mean_dice": MeanDice(include_background=True, output_transform=pred_and_label)},
        additional_metrics={"val_acc": Accuracy(output_transform=pred_and_label)},
        val_handlers=val_handlers,
        amp=use_amp,
    )

    train_post_transforms = build_post_transforms()
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=loss_value),
        TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=loss_value),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=pred_and_label)},
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()
def train(self, train_info, valid_info, hyperparameters, run_data_check=False):
    """Train a binary ``ModelCT`` classifier and return the folder holding its artefacts.

    Args:
        train_info: re-iterable of ``(sample, label)`` pairs. Only labels 0 and 1
            are counted for the class-balance weight.
        valid_info: validation counterpart of ``train_info``.
        hyperparameters: dict with keys ``learning_rate``, ``weight_decay``,
            ``total_epoch``, ``multiplicator``, ``batch_size``,
            ``validation_epoch``, ``validation_interval``, ``H`` and ``L``.
        run_data_check: if true, display the first training batch with
            ``multi_slice_viewer`` and terminate the process (``sys.exit``)
            instead of training.

    Returns:
        Path of the newly created model folder. It contains checkpoints, saved
        metrics and the ``AUCS``/``ACCS``/``LOSSES``/``PARAMETERS`` ``.npy`` files.

    Raises:
        ValueError: if ``train_info`` contains no positive (label 1) sample, in
            which case ``pos_weight`` cannot be computed.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # One clock reading for both the log line and the folder name, so the date and
    # time parts of the folder name can never come from different readings.
    now = datetime.datetime.now()
    if not run_data_check:
        print(f'Training started: {now.strftime("%d/%m/%Y %H:%M:%S")}')

    # 1. Create folders to save the model. makedirs also creates a missing
    #    'trained_models' parent; an already existing run folder still raises.
    timedate_info = now.strftime('%Y-%m-%d_%H-%M-%S')
    path_to_model = os.path.join(self.out_dir, 'trained_models', self.unique_name + '_' + timedate_info)
    os.makedirs(path_to_model)

    # 2. Load hyperparameters
    learning_rate = hyperparameters['learning_rate']
    weight_decay = hyperparameters['weight_decay']
    total_epoch = hyperparameters['total_epoch']
    multiplicator = hyperparameters['multiplicator']
    batch_size = hyperparameters['batch_size']
    validation_epoch = hyperparameters['validation_epoch']
    validation_interval = hyperparameters['validation_interval']
    H = hyperparameters['H']
    L = hyperparameters['L']

    # 3. Consider class imbalance: positives are weighted by the negative/positive ratio
    negative, positive = 0, 0
    for _, label in train_info:
        if int(label) == 0:
            negative += 1
        elif int(label) == 1:
            positive += 1
    if positive == 0:
        raise ValueError("train_info contains no positive (label 1) samples; cannot compute pos_weight")
    pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

    # 4. Create train and validation loaders (both use hyperparameters['batch_size'])
    train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir, train_info)
    valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir, valid_info)
    large_image_splitter(train_data, self.cache_dir)

    set_determinism(seed=100)
    train_trans, valid_trans = self.transformations(H, L)
    train_dataset = PersistentDataset(
        data=train_data[:], transform=train_trans, cache_dir=self.persistent_dataset_dir)
    valid_dataset = PersistentDataset(
        data=valid_data[:], transform=valid_trans, cache_dir=self.persistent_dataset_dir)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=self.pin_memory,
        num_workers=self.num_workers,
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT),
    )
    # Validation metrics (AUC, accuracy) do not depend on sample order, so no shuffling.
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=self.pin_memory,
        num_workers=self.num_workers,
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT),
    )

    # Perform data checks: show every volume of the first training batch, then stop.
    if run_data_check:
        check_data = monai.utils.misc.first(train_loader)
        print(check_data["image"].shape, check_data["label"])
        # iterate over the actual batch length: it is smaller than batch_size
        # when the dataset has fewer samples than one batch
        for i in range(check_data["image"].shape[0]):
            multi_slice_viewer(
                check_data["image"][i, 0, :, :, :],
                check_data["image_meta_dict"]["filename_or_obj"][i])
        sys.exit()

    # 5. Prepare model
    model = ModelCT().to(self.device)

    # 6. Define loss function, optimizer and scheduler
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight)  # pos_weight for class imbalance
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, multiplicator, last_epoch=-1)

    # 7. Create post validation transforms and handlers
    path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
    writer = SummaryWriter(log_dir=path_to_tensorboard)

    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])
    valid_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(summary_writer=writer, output_transform=lambda x: None),
        CheckpointSaver(save_dir=path_to_model, save_dict={"model": model}, save_key_metric=True),
        # names must match the metric keys registered on the evaluator below
        MetricsSaver(save_dir=path_to_model, metrics=['Valid_AUC', 'Valid_ACC']),
    ]

    # 8. Create validator
    discrete = AsDiscrete(threshold_values=True)
    evaluator = SupervisedEvaluator(
        device=self.device,
        val_data_loader=valid_loader,
        network=model,
        post_transform=valid_post_transforms,
        key_val_metric={
            "Valid_AUC": ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            # keyed 'Valid_ACC' so MetricsSaver and the final ACCS.npy export find it
            "Valid_ACC": Accuracy(output_transform=lambda x: (discrete(x["pred"]), x["label"]))
        },
        val_handlers=valid_handlers,
        amp=self.amp,
    )

    # 9. Create trainer
    # Loss function does the last sigmoid, so we dont need it here.
    train_post_transforms = Compose([])
    logger = MetricLogger(evaluator=evaluator)
    train_handlers = [
        logger,
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandlerCT(validator=evaluator,
                            start=validation_epoch,
                            interval=validation_interval,
                            epoch_level=True),
        StatsHandler(tag_name="loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer,
                                tag_name="Train_Loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=path_to_model,
                        save_dict={"model": model, "opt": optimizer},
                        save_interval=1,
                        n_saved=1),
    ]

    trainer = SupervisedTrainer(
        device=self.device,
        max_epochs=total_epoch,
        train_data_loader=train_loader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_function,
        post_transform=train_post_transforms,
        train_handlers=train_handlers,
        amp=self.amp,
    )

    # 10. Run trainer; always flush and close the TensorBoard writer
    try:
        trainer.run()
    finally:
        writer.close()

    # 11. Save results
    np.save(os.path.join(path_to_model, 'AUCS.npy'), np.array(logger.metrics['Valid_AUC']))
    np.save(os.path.join(path_to_model, 'ACCS.npy'), np.array(logger.metrics['Valid_ACC']))
    np.save(os.path.join(path_to_model, 'LOSSES.npy'), np.array(logger.loss))
    # a dict becomes a 0-d object array; loading it needs np.load(..., allow_pickle=True)
    np.save(os.path.join(path_to_model, 'PARAMETERS.npy'), np.array(hyperparameters))
    return path_to_model
def run_training_test(root_dir, device="cuda:0", amp=False):
    """Run a short UNet training on the NIfTI pairs stored in ``root_dir``.

    The first 20 ``img*/seg*.nii.gz`` pairs are used for training (5 epochs) and
    the last 20 for validation every 2 epochs. TensorBoard logs and checkpoints
    are written into ``root_dir``.

    Args:
        root_dir: directory containing the ``img*.nii.gz`` / ``seg*.nii.gz`` files.
        device: device the network and engines run on.
        amp: enable automatic mixed precision in both engines when truthy.

    Returns:
        The evaluator's best key metric (validation mean Dice).
    """

    def pred_and_label(output):
        """Select the (prediction, ground truth) pair that the metrics consume."""
        return output["pred"], output["label"]

    def loss_value(output):
        """Select the scalar training loss from the engine output."""
        return output["loss"]

    def post_transforms():
        """Sigmoid, threshold, then keep only the largest connected foreground component."""
        return Compose(
            [
                Activationsd(keys="pred", sigmoid=True),
                AsDiscreted(keys="pred", threshold_values=True),
                KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            ]
        )

    use_amp = bool(amp)

    image_paths = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"image": im, "label": lb} for im, lb in zip(image_paths[:20], label_paths[:20])]
    val_files = [{"image": im, "label": lb} for im, lb in zip(image_paths[-20:], label_paths[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                pos=1,
                neg=1,
                num_samples=4,
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # batch_size=2 with num_samples=4 crops gives 2 x 4 patches per training step
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    val_post_transforms = post_transforms()
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir=root_dir,
            batch_transform=lambda x: (x["image"], x["label"]),
            output_transform=lambda x: x["pred"],
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={"val_mean_dice": MeanDice(include_background=True, output_transform=pred_and_label)},
        additional_metrics={"val_acc": Accuracy(output_transform=pred_and_label)},
        val_handlers=val_handlers,
        amp=use_amp,
    )

    train_post_transforms = post_transforms()
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=loss_value),
        TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=loss_value),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=pred_and_label)},
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()

    return evaluator.state.best_metric
def train(args):
    """Train a UNet with DistributedDataParallel, one process per GPU.

    ``args.local_rank`` selects the CUDA device. ``args.dir`` holds the training
    volumes; local rank 0 synthesises them when the directory does not exist yet,
    using a fixed seed so every node produces the same data. The NCCL process
    group is initialised from environment variables (``env://``). It is always
    destroyed on exit, even if training raises.
    """
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 40 random image, mask pairs for training
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    try:
        images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
        train_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        train_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
                ScaleIntensityd(keys="image"),
                RandCropByPosNegLabeld(
                    keys=["image", "label"],
                    label_key="label",
                    spatial_size=[96, 96, 96],
                    pos=1,
                    neg=1,
                    num_samples=4,
                ),
                RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
                ToTensord(keys=["image", "label"]),
            ]
        )

        # create a training data loader
        train_ds = Dataset(data=train_files, transform=train_transforms)
        # each process sees its own shard of the dataset
        train_sampler = DistributedSampler(train_ds)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        # shuffle must stay False: DataLoader rejects shuffle=True together with a sampler
        train_loader = DataLoader(
            train_ds,
            batch_size=2,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
            sampler=train_sampler,
        )

        # create UNet, DiceLoss and Adam optimizer
        device = torch.device(f"cuda:{args.local_rank}")
        torch.cuda.set_device(device)
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss = monai.losses.DiceLoss(sigmoid=True)
        opt = torch.optim.Adam(net.parameters(), 1e-3)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
        # wrap the model with DistributedDataParallel module
        net = DistributedDataParallel(net, device_ids=[device])

        train_post_transforms = Compose(
            [
                Activationsd(keys="pred", sigmoid=True),
                AsDiscreted(keys="pred", threshold_values=True),
                KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            ]
        )
        train_handlers = [
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ]
        # only the global rank 0 logs and writes checkpoints
        if dist.get_rank() == 0:
            logging.basicConfig(stream=sys.stdout, level=logging.INFO)
            train_handlers.extend(
                [
                    StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
                    CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2),
                ]
            )

        trainer = SupervisedTrainer(
            device=device,
            max_epochs=5,
            train_data_loader=train_loader,
            network=net,
            optimizer=opt,
            loss_function=loss,
            inferer=SimpleInferer(),
            # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
            amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False,
            post_transform=train_post_transforms,
            key_train_metric={
                "train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]), device=device)
            },
            train_handlers=train_handlers,
        )
        trainer.run()
    finally:
        # release NCCL resources even when training fails
        dist.destroy_process_group()
def train(index):
    """Train a DynUNet on fold ``index`` and return the trained network.

    This function does not define ``opt`` (options), ``device``,
    ``train_loaders`` or ``val_loaders``; it reads them from the module scope.
    Validation runs every 5 epochs. Checkpoints are written to ``./runs/``.
    """
    # ---------- Build the nn-Unet network ------------
    if opt.resolution is None:
        sizes, spacings = opt.patch_size, opt.spacing
    else:
        sizes, spacings = opt.patch_size, opt.resolution

    # Derive per-level strides and kernels:
    # - An axis is downsampled by 2 while its spacing is at most twice the finest
    #   spacing and its size is still >= 8.
    # - Axes whose spacing exceeds twice the finest get a kernel of 1.
    # The loop stops once no axis can be downsampled any further.
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    net = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        res_block=True,
    ).to(device)

    # Sanity check: push one random batch through the network and print a layer
    # summary. The tensor is created directly on ``device`` so it matches where
    # ``net`` lives, and no autograd graph is needed for this check.
    from torchsummaryX import summary

    data = torch.randn(
        int(opt.batch_size),
        int(opt.in_channels),
        int(opt.patch_size[0]),
        int(opt.patch_size[1]),
        int(opt.patch_size[2]),
        device=device,
    )
    with torch.no_grad():
        out = net(data)
        summary(net, data)
    print("out size: {}".format(out.size()))

    # ---------- Optimisation ------------
    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    # polynomial decay of the learning rate towards 0 at opt.epochs
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs)**0.9)
    loss_function = monai.losses.DiceCELoss(sigmoid=True)

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loaders[index],
        network=net,
        inferer=SlidingWindowInferer(roi_size=opt.patch_size, sw_batch_size=opt.batch_size, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers)

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=5, epoch_level=True),
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={"net": net, "opt": optim},
                        save_final=True,
                        epoch_level=True),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=opt.epochs,
        train_data_loader=train_loaders[index],
        network=net,
        optimizer=optim,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        amp=False,
        train_handlers=train_handlers,
    )
    trainer.run()
    return net
def train(cfg):
    """Train a ResNet-18 patch classifier on whole-slide images.

    ``cfg`` supplies the following:

    - dataset locations: ``dataset_json``, ``data_root``
    - patch geometry: ``region_size``, ``grid_shape``, ``patch_size``
    - loader settings: ``num_workers``, ``batch_size``
    - optimisation settings: ``lr``, ``novograd``, ``n_epochs``, ``pretrain``,
      ``amp``
    - reader choice: ``use_openslide``

    ``cfg["amp"]`` is overwritten with whether AMP is actually used. Checkpoints
    and TensorBoard scalars of both the trainer and the evaluator are written to
    the directory returned by ``create_log_dir(cfg)``.

    Raises:
        ValueError: if the training loader yields no batch.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)

    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image",
                     name="ColorJitter",
                     brightness=64.0 / 255.0,
                     contrast=0.75,
                     saturation=0.25,
                     hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])

    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    image_reader_name = "openslide" if cfg["use_openslide"] else "cuCIM"

    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name=image_reader_name,
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name=image_reader_name,
    )

    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)

    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")

    print("image: ")
    print("    shape", first_sample["image"].shape)
    print("    type: ", type(first_sample["image"]))
    print("    dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print("    shape", first_sample["label"].shape)
    print("    type: ", type(first_sample["label"]))
    print("    dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model
    model = TorchVisionFCModel("resnet18", num_classes=1, use_conv=True, pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function (applies the sigmoid itself, so the network outputs logits)
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP only when requested and supported (PyTorch >= 1.6)
    cfg["amp"] = bool(cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir, save_dict={"net": model}, save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5),
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer -- writes into the same per-run log_dir as the evaluator, so that
    # checkpoints of different runs do not overwrite each other and the train and
    # validation TensorBoard scalars end up side by side
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model, "opt": optimizer},
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=log_dir,
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"], first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5),
    ])

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
def main(tempdir):
    """Template training script: 3D UNet + DiceLoss + Adam driven by MONAI's SupervisedTrainer.

    The body is organised into ``#### NAME ####`` banner sections. The evaluation
    section is intentionally left as ``...`` placeholders (``val_post_transforms``,
    ``val_handlers``, ``evaluator``) to be filled in. As written, the Ellipsis object
    bound to ``evaluator`` is handed to ``ValidationHandler``, so the script presumably
    cannot run until that section is completed.

    NOTE(review): ``tempdir`` is never used in the body. ``train_files``, ``val_files``,
    ``train_transforms`` and ``val_transforms`` are not defined here; presumably they
    are module-level names. Confirm before use.

    Args:
        tempdir: accepted but unused.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    ################################ DATASET ################################
    # get dataset
    # cache_rate differs per split: 0.5 for training, 1.0 for validation
    # (fraction of items CacheDataset pre-computes, per MONAI docs).
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
    ################################ DATASET ################################
    ################################ NETWORK ################################
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): newer MONAI releases name the first argument `spatial_dims`;
    # `dimensions` is the older spelling. Confirm against the installed version.
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    ################################ NETWORK ################################
    ################################ LOSS ################################
    # sigmoid=True: the single-channel network output is passed through a sigmoid inside the loss
    loss = monai.losses.DiceLoss(sigmoid=True)
    ################################ LOSS ################################
    ################################ OPT ################################
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    ################################ OPT ################################
    ################################ LR ################################
    # StepLR multiplies the learning rate by gamma=0.1 every step_size=2 scheduler steps
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    ################################ LR ################################
    ################################ Evaluation ################################
    # TODO: fill in. `...` is a placeholder, not a usable transform/handler list/evaluator.
    val_post_transforms = ...
    val_handlers = ...
    evaluator = ...
    # sigmoid activation -> thresholding -> keep only the largest connected component of label 1
    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        # run the (placeholder) evaluator every 2 epochs
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=lambda x: x["loss"]),
        # save network and optimizer state every 2 epochs
        CheckpointSaver(save_dir="./runs/", save_dict={
            "net": net,
            "opt": opt
        }, save_interval=2, epoch_level=True),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )
    trainer.run()
def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
    """Train a small 3D UNet on NIfTI volumes in ``root_dir`` and return the best validation metric.

    Image/segmentation pairs are found with the ``img*.nii.gz`` / ``seg*.nii.gz`` globs.
    The first 20 sorted pairs are used for training and the last 20 for validation.
    A ``SupervisedTrainer`` runs for 5 epochs and triggers a ``SupervisedEvaluator``
    every 2 epochs. TensorBoard events and checkpoints are written to ``root_dir``.

    Args:
        root_dir: directory holding the input volumes; also used as the output directory.
        device: device the network is moved to and both engines run on.
        amp: when truthy, both engines run with AMP and float16 autocast dtype.
        num_workers: worker count for both data loaders.

    Returns:
        ``evaluator.state.best_metric`` once training has finished.
    """
    use_amp = bool(amp)
    amp_dtype = torch.float16 if use_amp else torch.float32
    keys = ["image", "label"]

    image_paths = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))

    def _pair_up(imgs, segs):
        # one {"image", "label"} record per (image, segmentation) pair
        return [{"image": img, "label": seg} for img, seg in zip(imgs, segs)]

    train_records = _pair_up(image_paths[:20], seg_paths[:20])
    val_records = _pair_up(image_paths[-20:], seg_paths[-20:])

    def _load_and_scale():
        # Shared preprocessing prefix; each call builds fresh transform objects.
        return [
            LoadImaged(keys=keys),
            AsChannelFirstd(keys=keys, channel_dim=-1),
            ScaleIntensityd(keys=keys),
        ]

    train_pipeline = Compose(
        _load_and_scale()
        + [
            RandCropByPosNegLabeld(
                keys=keys, label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=keys, prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=keys),
        ]
    )
    val_pipeline = Compose(_load_and_scale() + [ToTensord(keys=keys)])

    train_dataset = monai.data.CacheDataset(data=train_records, transform=train_pipeline, cache_rate=0.5)
    # batch_size=2 volumes, each cropped into num_samples=4 patches -> 2 x 4 patches per step
    train_loader = monai.data.DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=num_workers)
    val_dataset = monai.data.CacheDataset(data=val_records, transform=val_pipeline, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_dataset, batch_size=1, num_workers=num_workers)

    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    summary_writer = SummaryWriter(log_dir=root_dir)

    def _postprocessing():
        # sigmoid -> threshold at 0.5 -> keep the largest connected component of label 1
        return Compose(
            [
                ToTensord(keys=["pred", "label"]),
                Activationsd(keys="pred", sigmoid=True),
                AsDiscreted(keys="pred", threshold=0.5),
                KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            ]
        )

    class _NoOpIterationHooks:
        """Register a do-nothing handler for each given IterationEvents member on attach.

        Attaching these checks that the engine exposes the custom iteration events.
        """

        def __init__(self, *events):
            self._events = events

        def attach(self, engine):
            for event in self._events:
                engine.add_event_handler(event, self._on_event)

        def _on_event(self, engine):
            pass

    val_postprocessing = _postprocessing()
    val_handlers = [
        StatsHandler(iteration_log=False),
        TensorBoardStatsHandler(summary_writer=summary_writer, iteration_log=False),
        TensorBoardImageHandler(
            log_dir=root_dir, batch_transform=from_engine(["image", "label"]), output_transform=from_engine("pred")
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
        _NoOpIterationHooks(IterationEvents.FORWARD_COMPLETED),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        metric_cmp_fn=lambda cur, prev: cur >= prev,  # if greater or equal, treat as new best metric
        val_handlers=val_handlers,
        amp=use_amp,
        to_kwargs={"memory_format": torch.preserve_format},
        amp_kwargs={"dtype": amp_dtype},
    )

    train_postprocessing = _postprocessing()
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
        TensorBoardStatsHandler(
            summary_writer=summary_writer, tag_name="train_loss", output_transform=from_engine("loss", first=True)
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
        _NoOpIterationHooks(
            IterationEvents.FORWARD_COMPLETED,
            IterationEvents.LOSS_COMPLETED,
            IterationEvents.BACKWARD_COMPLETED,
            IterationEvents.MODEL_COMPLETED,
        ),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        postprocessing=train_postprocessing,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=use_amp,
        optim_set_to_none=True,
        to_kwargs={"memory_format": torch.preserve_format},
        amp_kwargs={"dtype": amp_dtype},
    )
    trainer.run()
    return evaluator.state.best_metric
def run_interaction(self, train):
    """Run one epoch of DeepEdit-style interactive training on random volumes and check the result.

    Five random single-channel 15x15x15 volumes with binary labels are fed through a
    ``SupervisedTrainer`` whose ``iteration_update`` is an ``Interaction``. Afterwards
    the test checks that the batch carries ``guidance`` and that ``best_metric``
    equals 1. ``best_metric`` is updated by the externally defined ``add_one`` handler.

    Args:
        train: forwarded unchanged to ``Interaction(train=...)``.
    """
    label_names = {"spleen": 1, "background": 0}
    num_labels = len(label_names)
    np.random.seed(0)  # reproducible random volumes

    def _random_sample():
        # image is drawn before label, one sample at a time
        return {
            "image": np.random.randint(0, 256, size=(1, 15, 15, 15)).astype(np.float32),
            "label": np.random.randint(0, 2, size=(1, 15, 15, 15)),
            "label_names": label_names,
        }

    samples = [_random_sample() for _ in range(5)]

    # 3 input channels: presumably 1 intensity channel plus one guidance channel per label
    model = torch.nn.Conv3d(3, num_labels, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = DiceCELoss(to_onehot_y=True, softmax=True)

    preprocessing = Compose(
        [
            FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
            AddInitialSeedPointMissingLabelsd(keys="label", guidance="guidance", sids="sids"),
            AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
            ToTensord(keys=("image", "label")),
        ]
    )
    loader = torch.utils.data.DataLoader(Dataset(samples, transform=preprocessing), batch_size=5)

    click_steps = [
        FindDiscrepancyRegionsDeepEditd(keys="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanceDeepEditd(
            keys="NA", guidance="guidance", discrepancy="discrepancy", probability="probability"
        ),
        AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
        ToTensord(keys=("image", "label")),
    ]
    output_steps = [
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=num_labels),
        SplitPredsLabeld(keys="pred"),
        ToTensord(keys=("image", "label")),
    ]
    click_pipeline = Compose(click_steps)
    output_pipeline = Compose(output_steps)

    interaction = Interaction(
        deepgrow_probability=1.0,
        transforms=click_pipeline,
        click_probability_key="probability",
        train=train,
        label_names=label_names,
    )
    self.assertEqual(len(interaction.transforms.transforms), 4, "Mismatch in expected transforms")

    # set up engine
    engine = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=loader,
        network=model,
        optimizer=optimizer,
        loss_function=criterion,
        postprocessing=output_pipeline,
        iteration_update=interaction,
    )
    for event in (IterationEvents.INNER_ITERATION_STARTED, IterationEvents.INNER_ITERATION_COMPLETED):
        engine.add_event_handler(event, add_one)
    engine.run()

    self.assertIsNotNone(engine.state.batch[0].get("guidance"), "guidance is missing")
    self.assertEqual(engine.state.best_metric, 1)
def _build_pet_ct_transforms(image_shape, voxel_spacing, augment):
    """Build the PET/CT preprocessing pipeline used by ``main``.

    Steps, in order:
      1. Load the volumes and convert the ROI to a mask.
      2. Resample/align to ``image_shape`` and ``voxel_spacing``.
      3. Convert to numpy.
      4. Optionally apply a random affine augmentation.
      5. Window intensities: PET [0, 25] -> [0, 1] and CT [-1000, 1000] -> [0, 1], both clipped.
      6. Concatenate the modalities, add a channel axis to the mask, and convert
         ``image`` and ``mask_img`` to tensors.

    Args:
        image_shape: target ``(x, y, z)`` shape for ``ResampleReshapeAlign``.
        voxel_spacing: target ``(x, y, z)`` voxel spacing for ``ResampleReshapeAlign``.
        augment: if true, insert ``RandAffined`` right after ``Sitk2Numpy``.

    Returns:
        A ``Compose`` of the transforms above.
    """
    transforms = [
        # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=["pet_img", "mask_img"], method="otsu", tval=0.0, idx_channel=0),
        ResampleReshapeAlign(
            target_shape=image_shape,
            target_voxel_spacing=voxel_spacing,
            keys=["pet_img", "ct_img", "mask_img"],
            origin="head",
            origin_key="pet_img",
        ),
        Sitk2Numpy(keys=["pet_img", "ct_img", "mask_img"]),
    ]
    if augment:
        transforms.append(
            RandAffined(
                keys=("pet_img", "ct_img", "mask_img"),
                spatial_size=None,
                prob=0.4,
                rotate_range=(0, np.pi / 30, np.pi / 15),
                shear_range=None,
                translate_range=(10, 10, 10),
                scale_range=(0.1, 0.1, 0.1),
                # nearest-neighbour for the mask so label values are not interpolated
                mode=("bilinear", "bilinear", "nearest"),
                padding_mode="border",
            )
        )
    transforms += [
        # normalize input
        ScaleIntensityRanged(keys=["pet_img"], a_min=0.0, a_max=25.0, b_min=0.0, b_max=1.0, clip=True),
        ScaleIntensityRanged(keys=["ct_img"], a_min=-1000.0, a_max=1000.0, b_min=0.0, b_max=1.0, clip=True),
        # Prepare for neural network; ConcatModality presumably produces the "image" key
        ConcatModality(keys=["pet_img", "ct_img"]),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ]
    return Compose(transforms)


def main(config):
    """Train a 3D UNet on PET/CT volumes described by a nested configuration dict.

    Reads the ``path``, ``preprocessing``, ``model`` and ``training`` sections of
    ``config``. Builds cached datasets from the CSV at ``config['path']['csv_path']``,
    then trains with a ``SupervisedTrainer`` that validates every epoch. Checkpoints
    go to a timestamped folder under ``config['path']['training_model_folder']``;
    TensorBoard logs go to its ``logs`` subfolder.

    Args:
        config: nested configuration dict. Missing keys raise ``KeyError``. List values
            under ``config['model'][architecture]['cnn_params']`` are converted to
            tuples in place.
    """
    # NOTE(review): ':' in this folder name is not a valid path character on Windows.
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config["path"]["csv_path"]
    # NOTE(review): read but not used; the network is always trained from scratch.
    trained_model_path = config["path"]["trained_model_path"]  # if None, trained from scratch
    training_model_folder = os.path.join(config["path"]["training_model_folder"], now)
    logdir = os.path.join(training_model_folder, "logs")
    # creates training_model_folder as an intermediate directory as well
    os.makedirs(logdir, exist_ok=True)

    # PET CT scan params
    preprocessing = config["preprocessing"]
    image_shape = tuple(preprocessing["image_shape"])  # (x, y, z)
    in_channels = preprocessing["in_channels"]
    voxel_spacing = tuple(preprocessing["voxel_spacing"])  # in millimeter, (x, y, z)
    data_augment = preprocessing["data_augment"]  # applied to the training dataset only
    # NOTE(review): the following settings are read but not used yet.
    resize = preprocessing["resize"]
    origin = preprocessing["origin"]
    normalize = preprocessing["normalize"]
    number_class = preprocessing["number_class"]

    # CNN params
    architecture = config["model"]["architecture"]  # 'unet' or 'vnet'
    cnn_params = config["model"][architecture]["cnn_params"]
    # transform list to tuple (mutates the caller's config dict)
    # NOTE(review): cnn_params is not passed to the network below, which is hard-coded.
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)

    # Training params
    training = config["training"]
    epochs = training["epochs"]
    batch_size = training["batch_size"]
    shuffle = training["shuffle"]
    # NOTE(review): read but not used; the optimizer below uses a fixed lr of 1e-3.
    opt_params = training["optimizer"]["opt_params"]

    # Get Data
    data_manager = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, test_images_paths = data_manager.get_train_val_test(
        wrap_with_dict=True
    )

    # Input preprocessing: augmentation (if enabled) for training only
    train_transforms = _build_pet_ct_transforms(image_shape, voxel_spacing, augment=data_augment)
    val_transforms = _build_pet_ct_transforms(image_shape, voxel_spacing, augment=False)

    # create a training data loader (half of the items are cached)
    train_ds = monai.data.CacheDataset(data=train_images_paths, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle, num_workers=2)
    # create a validation data loader (all items are cached)
    val_ds = monai.data.CacheDataset(data=val_images_paths, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=batch_size, num_workers=2)

    # Model: create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    def prepare_batch(batchdata, device=None, non_blocking=False, **kwargs):
        """Return ``(image, mask_img)`` from a batch dict, moved to ``device``.

        The target is stored under ``mask_img`` instead of MONAI's default ``label``
        key, so both engines need this in place of ``default_prepare_batch``. The
        engines pass the device and non-blocking flag (and possibly extra kwargs)
        along with the batch, so those parameters must be accepted.
        """
        return (
            batchdata["image"].to(device=device, non_blocking=non_blocking),
            batchdata["mask_img"].to(device=device, non_blocking=non_blocking),
        )

    def pred_and_label(output):
        # engine output dicts expose the target under "label" (filled from prepare_batch)
        return output["pred"], output["label"]

    # validation
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=logdir, output_transform=lambda x: None),
        CheckpointSaver(save_dir=training_model_folder, save_dict={"net": net, "opt": opt}, save_key_metric=True),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=pred_and_label)
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=pred_and_label),
            "val_precision": Precision(output_transform=pred_and_label),
            "val_recall": Recall(output_transform=pred_and_label),
        },
        val_handlers=val_handlers,
    )

    # training
    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=logdir, tag_name="train_loss", output_transform=lambda x: x["loss"]),
        CheckpointSaver(
            save_dir=training_model_folder, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True
        ),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice": MeanDice(include_background=True, output_transform=pred_and_label)
        },
        additional_metrics={
            "train_acc": Accuracy(output_transform=pred_and_label),
            "train_precision": Precision(output_transform=pred_and_label),
            "train_recall": Recall(output_transform=pred_and_label),
        },
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    trainer.run()