def test_metrics_writer(self):
    """Check that a caller-supplied ``SummaryWriter`` is used for logging.

    The handler is given a writer that logs into a temporary directory, so
    log files must show up inside that directory and the default ``./runs``
    directory must not be created.
    """
    default_dir = os.path.join('.', 'runs')
    # start from a clean state so the final "default dir is absent" check is meaningful
    shutil.rmtree(default_dir, ignore_errors=True)
    with tempfile.TemporaryDirectory() as temp_dir:

        # set up engine
        def _train_func(engine, batch):
            return batch + 1.0

        engine = Engine(_train_func)

        # set up dummy metric
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            current_metric = engine.state.metrics.get('acc', 0.1)
            engine.state.metrics['acc'] = current_metric + 0.1

        # set up testing handler
        writer = SummaryWriter(log_dir=temp_dir)
        try:
            stats_handler = TensorBoardStatsHandler(
                writer,
                output_transform=lambda x: {'loss': x * 2.0},
                global_epoch_transform=lambda x: x * 3.0)
            stats_handler.attach(engine)
            engine.run(range(3), max_epochs=2)
        finally:
            # release the event file before the temporary directory is removed
            writer.close()

        # check logging output: look *inside* the directory, because
        # glob.glob(temp_dir) matches the directory itself and is always non-empty
        self.assertTrue(len(glob.glob(os.path.join(temp_dir, '*'))) > 0)
        self.assertTrue(not os.path.exists(default_dir))
def test_metrics_writer(self):
    """Check iteration-level logging through a caller-supplied ``SummaryWriter``.

    The engine output is a one-element list, the handler is configured with
    ``iteration_log=True`` / ``epoch_log=False`` and the ``test`` state
    attribute, and the writer logs into a temporary directory that must
    contain at least one entry afterwards.
    """
    with tempfile.TemporaryDirectory() as tempdir:

        # set up engine
        def _train_func(engine, batch):
            return [batch + 1.0]

        engine = Engine(_train_func)

        # set up dummy metric
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            current_metric = engine.state.metrics.get("acc", 0.1)
            engine.state.metrics["acc"] = current_metric + 0.1
            # expose the previous value as a state attribute for `state_attributes`
            engine.state.test = current_metric

        # set up testing handler
        writer = SummaryWriter(log_dir=tempdir)
        stats_handler = TensorBoardStatsHandler(
            summary_writer=writer,
            iteration_log=True,
            epoch_log=False,
            output_transform=lambda x: {"loss": x[0] * 2.0},
            global_epoch_transform=lambda x: x * 3.0,
            state_attributes=["test"],
        )
        stats_handler.attach(engine)
        engine.run(range(3), max_epochs=2)
        writer.close()

        # check logging output: look *inside* the directory, because
        # glob.glob(tempdir) matches the directory itself and is always non-empty
        self.assertTrue(len(glob.glob(os.path.join(tempdir, "*"))) > 0)
def test_metrics_writer(self):
    """Check that a caller-supplied ``SummaryWriter`` logs into a fresh directory.

    The temporary directory is removed before the writer is created, so the
    test also checks that the path exists again after the run. The directory
    is cleaned up even when the run or an assertion fails.
    """
    tempdir = tempfile.mkdtemp()
    # only the unique path is needed; the directory itself must not exist yet
    shutil.rmtree(tempdir, ignore_errors=True)
    try:

        # set up engine
        def _train_func(engine, batch):
            return batch + 1.0

        engine = Engine(_train_func)

        # set up dummy metric
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            current_metric = engine.state.metrics.get("acc", 0.1)
            engine.state.metrics["acc"] = current_metric + 0.1

        # set up testing handler
        writer = SummaryWriter(log_dir=tempdir)
        try:
            stats_handler = TensorBoardStatsHandler(
                writer, output_transform=lambda x: {"loss": x * 2.0}, global_epoch_transform=lambda x: x * 3.0
            )
            stats_handler.attach(engine)
            engine.run(range(3), max_epochs=2)
        finally:
            # release the event file before the directory is deleted
            writer.close()

        # check logging output: the directory exists again and is not empty
        # (glob.glob(tempdir) alone would match the directory itself)
        self.assertTrue(os.path.exists(tempdir))
        self.assertTrue(len(glob.glob(os.path.join(tempdir, "*"))) > 0)
    finally:
        # ignore_errors so a cleanup problem never hides the real test failure
        shutil.rmtree(tempdir, ignore_errors=True)
def train_handlers(self, context: Context):
    """Assemble the handler list for the training engine.

    The list holds, in order: the LR-scheduler handler (when
    ``self.lr_scheduler_handler`` returns a truthy value), the console and
    TensorBoard loss handlers (on local rank 0 only), and a validation
    handler (when ``context.evaluator`` is set).
    """
    scheduler_handler = self.lr_scheduler_handler(context)
    handlers: List[Any] = [scheduler_handler] if scheduler_handler else []

    # loss reporting happens on local rank 0 only
    if context.local_rank == 0:
        console_stats = StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True))
        tensorboard_stats = TensorBoardStatsHandler(
            log_dir=context.events_dir,
            tag_name="train_loss",
            output_transform=from_engine(["loss"], first=True),
        )
        handlers += [console_stats, tensorboard_stats]

    if not context.evaluator:
        return handlers

    logger.info(
        f"{context.local_rank} - Adding Validation to run every '{self._val_interval}' interval"
    )
    validation = ValidationHandler(self._val_interval, validator=context.evaluator, epoch_level=True)
    handlers.append(validation)
    return handlers
def val_handlers(self, context: Context):
    """Return the validation handlers for local rank 0, ``None`` for other ranks.

    The handlers are only constructed on local rank 0. Building them first
    and discarding them afterwards would still instantiate a
    ``TensorBoardStatsHandler`` pointed at ``context.events_dir`` on every
    rank; returning early matches ``train_handlers``, which builds its
    reporting handlers inside the rank-0 branch.
    """
    if context.local_rank != 0:
        return None

    return [
        # output_transform returns None: no per-iteration value is reported
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=context.events_dir, output_transform=lambda x: None),
    ]
def test_metrics_print(self):
    """Check that the handler logs into the directory passed as ``log_dir``.

    No writer is supplied; the handler only receives ``log_dir=tempdir``, and
    the directory must contain at least one entry after the run.
    """
    with tempfile.TemporaryDirectory() as tempdir:

        # set up engine
        def _train_func(engine, batch):
            return batch + 1.0

        engine = Engine(_train_func)

        # set up dummy metric
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            current_metric = engine.state.metrics.get("acc", 0.1)
            engine.state.metrics["acc"] = current_metric + 0.1

        # set up testing handler
        stats_handler = TensorBoardStatsHandler(log_dir=tempdir)
        stats_handler.attach(engine)
        engine.run(range(3), max_epochs=2)

        # check logging output: look *inside* the directory, because
        # glob.glob(tempdir) matches the directory itself and is always non-empty
        self.assertTrue(len(glob.glob(os.path.join(tempdir, "*"))) > 0)
def run(self, date=None) -> str:
    """Create the output directory, register the handlers and run training.

    Args:
        date: optional datetime-like object used for the "Training started"
            message; the current time is used when it is ``None``.

    Returns:
        The path of the output directory created for this run.

    Raises:
        FileExistsError: if the timestamped output directory already exists.
    """
    now = date if date is not None else datetime.datetime.now()
    datetime_string = now.strftime('%d/%m/%Y %H:%M:%S')
    print(f'Training started: {datetime_string}')

    # NOTE(review): the directory timestamp comes from the wall clock even
    # when `date` is given (unchanged behaviour) - confirm whether `date`
    # should be used here as well.
    now = datetime.datetime.now()
    timedate_info = now.strftime('%Y-%m-%d_%H-%M-%S')

    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # creates `self.out_dir` when it is missing
    training_dir = os.path.join(self.out_dir, 'training')
    os.makedirs(training_dir, exist_ok=True)

    # the run directory must be new, so a plain mkdir is kept on purpose
    self.output_dir = os.path.join(training_dir, self.out_name + '_' + timedate_info)
    os.mkdir(self.output_dir)
    self.validator.output_dir = self.output_dir

    # share one summary writer between trainer and validator unless the
    # caller already provided their own
    if self.summary_writer is None:
        self.summary_writer = SummaryWriter(log_dir=self.output_dir)
    if self.validator.summary_writer is None:
        self.validator.summary_writer = self.summary_writer

    handlers = [
        MetricLogger(self.output_dir, validator=self.validator),
        ValidationHandler(
            validator=self.validator,
            start=self.validation_epoch,
            interval=self.validation_interval
        ),
        StatsHandler(tag_name="loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(
            summary_writer=self.summary_writer,
            tag_name="Loss",
            output_transform=lambda x: x["loss"]
        ),
    ]

    save_dict = {
        'network': self.network,
        'optimizer': self.optimizer
    }
    if self.lr_scheduler is not None:
        # the scheduler handler goes first and its state is checkpointed too
        handlers.insert(0, LrScheduleHandler(lr_scheduler=self.lr_scheduler, print_lr=True))
        save_dict['lr_scheduler'] = self.lr_scheduler
    handlers.append(
        CheckpointSaver(save_dir=self.output_dir, save_dict=save_dict, save_interval=1, n_saved=1)
    )

    self._register_handlers(handlers)
    super().run()
    return self.output_dir
def test_metrics_print(self):
    """Check that a handler built without arguments logs into ``./runs``.

    The default directory is removed first, the engine is run with a handler
    created without any arguments, and the directory must exist afterwards.
    It is deleted again at the end of the test.
    """
    runs_dir = os.path.join('.', 'runs')
    shutil.rmtree(runs_dir, ignore_errors=True)

    # engine whose step simply increments the batch value
    def _step(engine, batch):
        return batch + 1.0

    engine = Engine(_step)

    # dummy metric that grows by 0.1 at the end of every epoch
    @engine.on(Events.EPOCH_COMPLETED)
    def _bump_accuracy(engine):
        metrics = engine.state.metrics
        previous = metrics.get('acc', 0.1)
        metrics['acc'] = previous + 0.1

    # handler under test, created with default arguments only
    handler = TensorBoardStatsHandler()
    handler.attach(engine)
    engine.run(range(3), max_epochs=2)

    # the default log directory must have been created by the run
    created = os.path.exists(runs_dir)
    self.assertTrue(created)
    shutil.rmtree(runs_dir)
def run(self, global_epoch: int) -> None:
    """Run the parent engine, registering the handlers on global epoch 1.

    Handlers are registered only when ``global_epoch == 1``; every call then
    delegates to the parent ``run`` with the same ``global_epoch``.
    """
    is_first_epoch = global_epoch == 1
    if is_first_epoch:
        console_stats = StatsHandler()
        tensorboard_stats = TensorBoardStatsHandler(summary_writer=self.summary_writer)
        checkpointer = CheckpointSaver(
            save_dir=self.output_dir,
            save_dict={"network": self.network},
            save_key_metric=True,
        )
        metrics_saver = MetricsSaver(save_dir=self.output_dir, metrics=['Valid_AUC', 'Valid_ACC'])
        self._register_handlers(
            [console_stats, tensorboard_stats, checkpointer, metrics_saver, self.early_stop_handler]
        )

    return super().run(global_epoch=global_epoch)
def main():
    """Train a 3D UNet on synthetic NIfTI volumes with ignite and TensorBoard logging.

    Forty synthetic image/segmentation pairs are written to a temporary
    directory; the first 20 are used for training and the last 20 for
    validation. Checkpoints are written to ``./runs/``. The temporary
    directory is removed when the function exits, including on errors.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs;
    # the context manager removes the directory even if training fails
    with tempfile.TemporaryDirectory() as tempdir:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

        # define transforms for image and segmentation
        train_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            ToTensor()
        ])
        train_segtrans = Compose([
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            ToTensor()
        ])
        val_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            Resize((96, 96, 96)),
            ToTensor()
        ])
        val_segtrans = Compose([
            AddChannel(),
            Resize((96, 96, 96)),
            ToTensor()
        ])

        # define nifti dataset, data loader (sanity check of the first batch)
        check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader
        train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
        train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8,
                                  pin_memory=torch.cuda.is_available())
        # create a validation data loader
        val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        loss = monai.losses.DiceLoss(do_sigmoid=True)
        lr = 1e-3
        opt = torch.optim.Adam(net.parameters(), lr)
        # fall back to CPU when CUDA is unavailable, matching the
        # torch.cuda.is_available() checks used for pin_memory above
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
        # user can add output_transform to return other values, like: y_pred, y, etc.
        trainer = create_supervised_trainer(net, opt, loss, device, False)

        # adding checkpoint handler to save models (network params and optimizer stats) during training
        checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'net': net, 'opt': opt})

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't set metrics for trainer here, so just print loss, user can also customize print functions
        # and can use output_transform to convert engine.state.output if it's not a loss value
        train_stats_handler = StatsHandler(name='trainer')
        train_stats_handler.attach(trainer)

        # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch,
        # same as StatsHandler
        train_tensorboard_stats_handler = TensorBoardStatsHandler()
        train_tensorboard_stats_handler.attach(trainer)

        validation_every_n_epochs = 1
        # Set parameters for validation
        metric_name = 'Mean_Dice'
        # add evaluation metric to the evaluator engine
        val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}

        # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        evaluator = create_supervised_evaluator(net, val_metrics, device, True)

        @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
        def run_validation(engine):
            evaluator.run(val_loader)

        # add early stopping handler to evaluator
        early_stopper = EarlyStopping(patience=4,
                                      score_function=stopping_fn_from_metric(metric_name),
                                      trainer=trainer)
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

        # add stats event handler to print validation stats via evaluator
        val_stats_handler = StatsHandler(
            name='evaluator',
            output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
        val_stats_handler.attach(evaluator)

        # add handler to record metrics to TensorBoard at every validation epoch
        val_tensorboard_stats_handler = TensorBoardStatsHandler(
            output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
        val_tensorboard_stats_handler.attach(evaluator)

        # add handler to draw the first image and the corresponding label and model output in the last batch
        # here we draw the 3D output as GIF format along Depth axis, at every validation epoch
        val_tensorboard_image_handler = TensorBoardImageHandler(
            batch_transform=lambda batch: (batch[0], batch[1]),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch
        )
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

        train_epochs = 30
        trainer.run(train_loader, train_epochs)
def main(tempdir):
    """Train and validate a 3D UNet on synthetic data with the MONAI workflow engines.

    Forty synthetic image/label volumes are written into ``tempdir``; the
    first 20 feed a ``SupervisedTrainer`` and the last 20 a
    ``SupervisedEvaluator``. Statistics, images and checkpoints are written
    to ``./runs/``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # ------------------------------- dataset -------------------------------
    # write 40 random image/mask pairs into the given directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        volume, mask = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        nib.save(nib.Nifti1Image(volume, np.eye(4)), os.path.join(tempdir, f"img{i:d}.nii.gz"))
        nib.save(nib.Nifti1Image(mask, np.eye(4)), os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    image_paths = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[:20], label_paths[:20])
    ]
    val_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[-20:], label_paths[-20:])
    ]

    # transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # training loader: batch_size=2 with num_samples=4 crops per volume
    # yields 2 x 4 images per network step
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # validation loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # ------------------------------- network -------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # ---------------------- loss / optimizer / schedule ---------------------
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    def _build_post_transforms():
        # sigmoid -> threshold -> keep largest connected component of label 1
        return Compose(
            [
                Activationsd(keys="pred", sigmoid=True),
                AsDiscreted(keys="pred", threshold_values=True),
                KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            ]
        )

    def _pred_and_label(output):
        return (output["pred"], output["label"])

    def _loss_value(output):
        return output["loss"]

    def _no_iteration_output(output):
        return None

    # ------------------------------ evaluator ------------------------------
    val_post_transforms = _build_post_transforms()
    val_handlers = [
        StatsHandler(output_transform=_no_iteration_output),
        TensorBoardStatsHandler(log_dir="./runs/", output_transform=_no_iteration_output),
        TensorBoardImageHandler(
            log_dir="./runs/",
            batch_transform=lambda x: (x["image"], x["label"]),
            output_transform=lambda x: x["pred"],
        ),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
    ]

    # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP
    amp_enabled = monai.utils.get_torch_version_tuple() >= (1, 6)

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=_pred_and_label)
        },
        additional_metrics={"val_acc": Accuracy(output_transform=_pred_and_label)},
        val_handlers=val_handlers,
        amp=amp_enabled,
    )

    # ------------------------------- trainer -------------------------------
    train_post_transforms = _build_post_transforms()
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=_loss_value),
        TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=_loss_value),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=_pred_and_label)},
        train_handlers=train_handlers,
        amp=amp_enabled,
    )
    trainer.run()
def train(self, train_info, valid_info, hyperparameters, run_data_check=False):
    """Train ``ModelCT`` on the given samples and return the model directory.

    Args:
        train_info: iterable of ``(_, label)`` pairs for training; labels are
            converted with ``int()`` and counted as 0 (negative) or 1 (positive).
        valid_info: same structure as ``train_info``, used for validation.
        hyperparameters: mapping with the keys ``learning_rate``,
            ``weight_decay``, ``total_epoch``, ``multiplicator``,
            ``batch_size``, ``validation_epoch``, ``validation_interval``,
            ``H`` and ``L``.
        run_data_check: when true, only the first training batch is printed
            and displayed, then the process exits without training.

    Returns:
        Path of the directory holding checkpoints and the saved ``.npy`` results.

    Raises:
        FileExistsError: if the timestamped model directory already exists.
        ZeroDivisionError: if ``train_info`` contains no positive sample.
        SystemExit: after the data check when ``run_data_check`` is true.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not run_data_check:
        # one timestamp for both the log line and the folder name, so the
        # date and time parts cannot come from two different clock readings
        start_dt = datetime.datetime.now()
        start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
        print(f'Training started: {start_dt_string}')

        # 1. Create folders to save the model
        timedate_info = start_dt.strftime('%Y-%m-%d_%H-%M-%S')
        path_to_model = os.path.join(
            self.out_dir, 'trained_models',
            self.unique_name + '_' + timedate_info)
        # makedirs also creates a missing `trained_models` parent; it still
        # raises FileExistsError if the run directory itself already exists
        os.makedirs(path_to_model)

    # 2. Load hyperparameters
    learning_rate = hyperparameters['learning_rate']
    weight_decay = hyperparameters['weight_decay']
    total_epoch = hyperparameters['total_epoch']
    multiplicator = hyperparameters['multiplicator']
    batch_size = hyperparameters['batch_size']
    validation_epoch = hyperparameters['validation_epoch']
    validation_interval = hyperparameters['validation_interval']
    H = hyperparameters['H']
    L = hyperparameters['L']

    # 3. Consider class imbalance
    negative, positive = 0, 0
    for _, label in train_info:
        if int(label) == 0:
            negative += 1
        elif int(label) == 1:
            positive += 1

    # weight of the positive class = negatives per positive
    pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

    # 4. Create train and validation loaders (both use `batch_size`)
    train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                    train_info)
    valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                    valid_info)
    large_image_splitter(train_data, self.cache_dir)

    set_determinism(seed=100)
    train_trans, valid_trans = self.transformations(H, L)
    train_dataset = PersistentDataset(
        data=train_data[:],
        transform=train_trans,
        cache_dir=self.persistent_dataset_dir)
    valid_dataset = PersistentDataset(
        data=valid_data[:],
        transform=valid_trans,
        cache_dir=self.persistent_dataset_dir)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              pin_memory=self.pin_memory,
                              num_workers=self.num_workers,
                              collate_fn=PadListDataCollate(
                                  Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              pin_memory=self.pin_memory,
                              num_workers=self.num_workers,
                              collate_fn=PadListDataCollate(
                                  Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Perform data checks
    if run_data_check:
        check_data = monai.utils.misc.first(train_loader)
        print(check_data["image"].shape, check_data["label"])
        # the first batch may hold fewer than `batch_size` samples
        samples_to_show = min(batch_size, check_data["image"].shape[0])
        for i in range(samples_to_show):
            multi_slice_viewer(
                check_data["image"][i, 0, :, :, :],
                check_data["image_meta_dict"]["filename_or_obj"][i])
        # sys.exit raises the same SystemExit as the interactive `exit` helper
        sys.exit()

    # 5. Prepare model
    model = ModelCT().to(self.device)

    # 6. Define loss function, optimizer and scheduler
    loss_function = torch.nn.BCEWithLogitsLoss(
        pos_weight)  # pos_weight for class imbalance
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       multiplicator,
                                                       last_epoch=-1)

    # 7. Create post validation transforms and handlers
    path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
    writer = SummaryWriter(log_dir=path_to_tensorboard)
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])
    valid_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(summary_writer=writer,
                                output_transform=lambda x: None),
        CheckpointSaver(save_dir=path_to_model,
                        save_dict={"model": model},
                        save_key_metric=True),
        # NOTE(review): 'Valid_ACC' here and in step 11 differs from the
        # metric registered below as "Valid_Accuracy" - confirm which name
        # MetricsSaver / MetricLogger expect.
        MetricsSaver(save_dir=path_to_model,
                     metrics=['Valid_AUC', 'Valid_ACC']),
    ]

    # 8. Create validator
    discrete = AsDiscrete(threshold_values=True)
    evaluator = SupervisedEvaluator(
        device=self.device,
        val_data_loader=valid_loader,
        network=model,
        post_transform=valid_post_transforms,
        key_val_metric={
            "Valid_AUC":
            ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "Valid_Accuracy":
            Accuracy(output_transform=lambda x:
                     (discrete(x["pred"]), x["label"]))
        },
        val_handlers=valid_handlers,
        amp=self.amp,
    )

    # 9. Create trainer

    # Loss function does the last sigmoid, so we don't need it here.
    train_post_transforms = Compose([
        # Empty
    ])
    logger = MetricLogger(evaluator=evaluator)
    train_handlers = [
        logger,
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandlerCT(validator=evaluator,
                            start=validation_epoch,
                            interval=validation_interval,
                            epoch_level=True),
        StatsHandler(tag_name="loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer,
                                tag_name="Train_Loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=path_to_model,
                        save_dict={
                            "model": model,
                            "opt": optimizer
                        },
                        save_interval=1,
                        n_saved=1),
    ]

    trainer = SupervisedTrainer(
        device=self.device,
        max_epochs=total_epoch,
        train_data_loader=train_loader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_function,
        post_transform=train_post_transforms,
        train_handlers=train_handlers,
        amp=self.amp,
    )

    # 10. Run trainer
    trainer.run()

    # 11. Save results
    np.save(path_to_model + '/AUCS.npy',
            np.array(logger.metrics['Valid_AUC']))
    np.save(path_to_model + '/ACCS.npy',
            np.array(logger.metrics['Valid_ACC']))
    np.save(path_to_model + '/LOSSES.npy', np.array(logger.loss))
    np.save(path_to_model + '/PARAMETERS.npy', np.array(hyperparameters))

    return path_to_model
def run_training_test(root_dir, device="cuda:0", amp=False):
    """Run a short training/validation cycle on the volumes found in ``root_dir``.

    ``img*.nii.gz`` / ``seg*.nii.gz`` files are paired in sorted order; the
    first 20 pairs are used for training and the last 20 for validation.
    Logs and checkpoints are written into ``root_dir``.

    Returns:
        ``evaluator.state.best_metric`` after training has finished.
    """
    image_paths = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[:20], label_paths[:20])
    ]
    val_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[-20:], label_paths[-20:])
    ]

    # transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=[96, 96, 96],
            pos=1,
            neg=1,
            num_samples=4,
        ),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # training loader: batch_size=2 with num_samples=4 crops per volume
    # yields 2 x 4 images per network step
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # validation loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # UNet, DiceLoss, Adam optimizer and step LR schedule
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    def _build_post_transforms():
        # sigmoid -> threshold -> keep largest connected component of label 1
        return Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ])

    def _pred_and_label(output):
        return (output["pred"], output["label"])

    def _loss_value(output):
        return output["loss"]

    def _no_iteration_output(output):
        return None

    # normalise the flag to a plain bool for both engines
    use_amp = bool(amp)

    val_post_transforms = _build_post_transforms()
    val_handlers = [
        StatsHandler(output_transform=_no_iteration_output),
        TensorBoardStatsHandler(log_dir=root_dir, output_transform=_no_iteration_output),
        TensorBoardImageHandler(
            log_dir=root_dir,
            batch_transform=lambda x: (x["image"], x["label"]),
            output_transform=lambda x: x["pred"],
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=_pred_and_label)
        },
        additional_metrics={"val_acc": Accuracy(output_transform=_pred_and_label)},
        val_handlers=val_handlers,
        amp=use_amp,
    )

    train_post_transforms = _build_post_transforms()
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=_loss_value),
        TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=_loss_value),
        CheckpointSaver(
            save_dir=root_dir,
            save_dict={"net": net, "opt": opt},
            save_interval=2,
            epoch_level=True,
        ),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=_pred_and_label)},
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()

    return evaluator.state.best_metric
def main():
    """Train and validate a 3D segmentation UNet on synthetic data with Ignite.

    The function generates 40 synthetic image/segmentation pairs into a
    temporary directory, uses the first 20 for training and the last 20 for
    validation, and trains for 5 epochs. Validation is triggered every 5
    training iterations, with early stopping (patience 4) driven by the
    "Mean_Dice" metric. Checkpoints are written to ``./runs/``.

    The temporary directory is always removed, even when data generation or
    training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

        # define transforms for image and segmentation
        train_transforms = Compose(
            [
                LoadNiftid(keys=["img", "seg"]),
                AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
                ScaleIntensityd(keys=["img", "seg"]),
                RandCropByPosNegLabeld(
                    keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
                ),
                RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
                ToTensord(keys=["img", "seg"]),
            ]
        )
        val_transforms = Compose(
            [
                LoadNiftid(keys=["img", "seg"]),
                AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
                ScaleIntensityd(keys=["img", "seg"]),
                ToTensord(keys=["img", "seg"]),
            ]
        )

        # define dataset, data loader
        check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        check_loader = DataLoader(
            check_ds,
            batch_size=2,
            num_workers=4,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )
        check_data = monai.utils.misc.first(check_loader)
        print(check_data["img"].shape, check_data["seg"].shape)

        # create a training data loader
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        train_loader = DataLoader(
            train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )
        # create a validation data loader
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        val_loader = DataLoader(
            val_ds,
            batch_size=5,
            num_workers=8,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )

        # create UNet, DiceLoss and Adam optimizer
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        loss = monai.losses.DiceLoss(sigmoid=True)
        lr = 1e-3
        opt = torch.optim.Adam(net.parameters(), lr)
        # fall back to CPU when CUDA is unavailable instead of failing on "cuda:0"
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
        # user can add output_transform to return other values, like: y_pred, y, etc.
        def prepare_batch(batch, device=None, non_blocking=False):
            # the loaders yield dicts, so pick out the (img, seg) pair Ignite expects
            return _prepare_batch((batch["img"], batch["seg"]), device, non_blocking)

        trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)

        # adding checkpoint handler to save models (network params and optimizer stats) during training
        checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False)
        trainer.add_event_handler(
            event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt}
        )

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't set metrics for trainer here, so just print loss, user can also customize print functions
        # and can use output_transform to convert engine.state.output if it's not loss value
        train_stats_handler = StatsHandler(name="trainer")
        train_stats_handler.attach(trainer)

        # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
        train_tensorboard_stats_handler = TensorBoardStatsHandler()
        train_tensorboard_stats_handler.attach(trainer)

        # set parameters for validation
        validation_every_n_iters = 5
        metric_name = "Mean_Dice"
        # add evaluation metric to the evaluator engine
        val_metrics = {metric_name: MeanDice(sigmoid=True, to_onehot_y=False)}

        # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

        @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
        def run_validation(engine):
            evaluator.run(val_loader)

        # add early stopping handler to evaluator
        early_stopper = EarlyStopping(
            patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer
        )
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

        # add stats event handler to print validation stats via evaluator
        val_stats_handler = StatsHandler(
            name="evaluator",
            output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch,
        )  # fetch global epoch number from trainer
        val_stats_handler.attach(evaluator)

        # add handler to record metrics to TensorBoard at every validation epoch
        val_tensorboard_stats_handler = TensorBoardStatsHandler(
            output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.iteration,
        )  # fetch global iteration number from trainer
        val_tensorboard_stats_handler.attach(evaluator)

        # add handler to draw the first image and the corresponding label and model output in the last batch
        # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations.
        val_tensorboard_image_handler = TensorBoardImageHandler(
            batch_transform=lambda batch: (batch["img"], batch["seg"]),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch,
        )
        evaluator.add_event_handler(
            event_name=Events.ITERATION_COMPLETED(every=2), handler=val_tensorboard_image_handler
        )

        train_epochs = 5
        state = trainer.run(train_loader, train_epochs)
        print(state)
    finally:
        # always remove the synthetic volumes, including when training fails part-way
        shutil.rmtree(tempdir)
trainer = create_supervised_trainer(net, opt, loss, device, False) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) # StatsHandler prints loss at every iteration and print metrics at every epoch, # we don't set metrics for trainer here, so just print loss, user can also customize print functions # and can use output_transform to convert engine.state.output if it's not loss value train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler train_tensorboard_stats_handler = TensorBoardStatsHandler() train_tensorboard_stats_handler.attach(trainer) # set parameters for validation validation_every_n_epochs = 1 metric_name = 'Accuracy' # add evaluation metric to the evaluator engine val_metrics = {metric_name: Accuracy()} # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, # user can add output_transform to return other values evaluator = create_supervised_evaluator(net, val_metrics, device, True) # add stats event handler to print validation stats via evaluator val_stats_handler = StatsHandler( name='evaluator',
def train(cfg):
    """Train and evaluate a patch classification model on whole-slide images.

    Builds the preprocessing pipelines, datasets and data loaders, a
    ``resnet18``-based ``TorchVisionFCModel``, the optimizer and a cosine LR
    schedule, then runs a ``SupervisedTrainer`` that validates after every
    epoch.

    Args:
        cfg: configuration mapping. Keys read directly here: ``dataset_json``,
            ``data_root``, ``region_size``, ``grid_shape``, ``patch_size``,
            ``use_openslide``, ``num_workers``, ``batch_size``, ``pretrain``,
            ``novograd``, ``lr``, ``amp``, ``n_epochs`` and ``logdir``.
            ``cfg["amp"]`` is overwritten in place with the effective boolean
            setting (AMP is only kept enabled for torch >= 1.6).

    Raises:
        ValueError: if the training data loader yields no first batch.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)

    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose(
        [
            ToTensorD(keys="image"),
            TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            ),
            ToNumpyD(keys="image"),
            RandFlipD(keys="image", prob=0.5),
            RandRotate90D(keys="image", prob=0.5),
            CastToTypeD(keys="image", dtype=np.float32),
            RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
            ToTensorD(keys=("image", "label")),
        ]
    )
    valid_preprocess = Compose(
        [
            CastToTypeD(keys="image", dtype=np.float32),
            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
            ToTensorD(keys=("image", "label")),
        ]
    )

    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )

    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(
        train_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True
    )
    valid_dataloader = DataLoader(
        valid_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True
    )

    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")

    print("image: ")
    print(" shape", first_sample["image"].shape)
    print(" type: ", type(first_sample["image"]))
    print(" dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print(" shape", first_sample["label"].shape)
    print(" type: ", type(first_sample["label"]))
    print(" dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model
    model = TorchVisionFCModel("resnet18", num_classes=1, use_conv=True, pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler: keep AMP on only when requested and the installed torch is >= 1.6
    cfg["amp"] = bool(cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir, save_dict={"net": model}, save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
    ]
    val_postprocessing = Compose(
        [ActivationsD(keys="pred", sigmoid=True), AsDiscreteD(keys="pred", threshold=0.5)]
    )
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer
    # NOTE(review): the trainer handlers write to cfg["logdir"] while the evaluator
    # handlers above write to log_dir from create_log_dir(cfg) -- confirm this split is intended.
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(
            save_dir=cfg["logdir"],
            save_dict={"net": model, "opt": optimizer},
            save_interval=1,
            epoch_level=True,
        ),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(
            log_dir=cfg["logdir"],
            tag_name="train_loss",
            output_transform=from_engine(["loss"], first=True),
        ),
    ]
    train_postprocessing = Compose(
        [ActivationsD(keys="pred", sigmoid=True), AsDiscreteD(keys="pred", threshold=0.5)]
    )

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
def run_training_test(root_dir, device="cuda:0"):
    """Run a short GAN training workflow on the images found in ``root_dir``.

    Args:
        root_dir: directory searched for ``img*.nii.gz`` files; also used as
            the TensorBoard log directory and the checkpoint directory.
        device: device passed to the networks and to ``GanTrainer``.

    Returns:
        The ``state`` attribute of the trainer after ``run()`` completes.
    """
    image_paths = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    # zip() over a single list yields 1-tuples, so each "reals" entry is a one-element tuple
    train_files = [{"reals": img} for img in zip(image_paths)]

    # prepare real data
    real_pipeline = [
        LoadNiftid(keys=["reals"]),
        AsChannelFirstd(keys=["reals"]),
        ScaleIntensityd(keys=["reals"]),
        RandFlipd(keys=["reals"], prob=0.5),
        ToTensord(keys=["reals"]),
    ]
    train_transforms = Compose(real_pipeline)
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    # shared optimisation settings and target labels
    lr = 2e-4
    adam_betas = (0.5, 0.999)
    real_label = 1
    fake_label = 0

    # create discriminator
    disc_net = Discriminator(
        in_shape=(1, 64, 64),
        channels=(8, 16, 32, 64, 1),
        strides=(2, 2, 2, 2, 1),
        num_res_units=1,
        kernel_size=5,
    ).to(device)
    disc_net.apply(normal_init)
    disc_opt = torch.optim.Adam(disc_net.parameters(), lr, betas=adam_betas)
    disc_loss_criterion = torch.nn.BCELoss()

    def discriminator_loss(gen_images, real_images):
        # targets: one column per sample, filled with the real / fake label
        real_target = real_images.new_full((real_images.shape[0], 1), real_label)
        fake_target = gen_images.new_full((gen_images.shape[0], 1), fake_label)
        loss_on_real = disc_loss_criterion(disc_net(real_images), real_target)
        # detach so this loss does not back-propagate into the generator
        loss_on_fake = disc_loss_criterion(disc_net(gen_images.detach()), fake_target)
        summed = torch.add(loss_on_real, loss_on_fake)
        return torch.div(summed, 2)

    # create generator
    latent_size = 64
    gen_net = Generator(
        latent_shape=latent_size,
        start_shape=(latent_size, 8, 8),
        channels=[32, 16, 8, 1],
        strides=[2, 2, 2, 1],
    )
    gen_net.apply(normal_init)
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)
    gen_opt = torch.optim.Adam(gen_net.parameters(), lr, betas=adam_betas)
    gen_loss_criterion = torch.nn.BCELoss()

    def generator_loss(gen_images):
        # the generator is scored against the "real" label on the discriminator output
        disc_out = disc_net(gen_images)
        wanted = disc_out.new_full(disc_out.shape, real_label)
        return gen_loss_criterion(disc_out, wanted)

    def _select_losses(x):
        # keep only the generator and discriminator loss entries of the engine output
        return {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]}

    stats_logger = StatsHandler(name="training_loss", output_transform=_select_losses)
    tensorboard_logger = TensorBoardStatsHandler(
        log_dir=root_dir,
        tag_name="training_loss",
        output_transform=_select_losses,
    )
    checkpointer = CheckpointSaver(
        save_dir=root_dir,
        save_dict={"g_net": gen_net, "d_net": disc_net},
        save_interval=2,
        epoch_level=True,
    )

    trainer = GanTrainer(
        device,
        5,  # number of epochs
        train_loader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_train_steps=2,
        latent_shape=latent_size,
        key_train_metric=None,
        train_handlers=[stats_logger, tensorboard_logger, checkpointer],
    )
    trainer.run()

    return trainer.state
def main(tempdir):
    """Train a 3D segmentation UNet on synthetic array data using Ignite.

    Writes 40 synthetic image/segmentation pairs into ``tempdir``, trains on
    the first 20 and validates on the last 20 after every epoch, with early
    stopping on the "Mean_Dice" metric. Checkpoints go to ``./runs_array/``.

    Args:
        tempdir: existing directory that receives the generated NIfTI files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    affine = np.eye(4)
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        nib.save(nib.Nifti1Image(im, affine), os.path.join(tempdir, f"im{i:d}.nii.gz"))
        nib.save(nib.Nifti1Image(seg, affine), os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    patch = (96, 96, 96)
    train_imtrans = Compose(
        [ScaleIntensity(), AddChannel(), RandSpatialCrop(patch, random_size=False), EnsureType()]
    )
    train_segtrans = Compose([AddChannel(), RandSpatialCrop(patch, random_size=False), EnsureType()])
    val_imtrans = Compose([ScaleIntensity(), AddChannel(), Resize(patch), EnsureType()])
    val_segtrans = Compose([AddChannel(), Resize(patch), EnsureType()])

    use_pinned = torch.cuda.is_available()

    # define image dataset, data loader (sanity check on one batch)
    check_ds = ImageDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=use_pinned)
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ImageDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=use_pinned)

    # create a validation data loader
    val_ds = ImageDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=use_pinned)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net = net.to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    lr = 1e-3
    opt = torch.optim.Adam(net.parameters(), lr)

    # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # save network params and optimizer stats at the end of every training epoch
    checkpoint_handler = ModelCheckpoint("./runs_array/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={"net": net, "opt": opt},
    )

    # StatsHandler prints the loss per iteration and metrics per epoch; no trainer metrics
    # are set here, so only the loss is printed (the output is passed through unchanged)
    train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots the same values that StatsHandler prints
    train_tensorboard_stats_handler = TensorBoardStatsHandler(output_transform=lambda x: x)
    train_tensorboard_stats_handler.attach(trainer)

    # parameters for validation
    validation_every_n_epochs = 1
    metric_name = "Mean_Dice"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: MeanDice()}
    post_pred = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])

    def _postprocess(x, y, y_pred):
        # split the batch into items and post-process predictions and labels separately
        preds = [post_pred(i) for i in decollate_batch(y_pred)]
        labels = [post_label(i) for i in decollate_batch(y)]
        return preds, labels

    # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True, output_transform=_postprocess)

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer,
    )
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    def _trainer_epoch(x):
        # report validation results against the trainer's global epoch number
        return trainer.state.epoch

    # print validation stats via the evaluator; per-iteration output is disabled
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,
        global_epoch_transform=_trainer_epoch,
    )
    val_stats_handler.attach(evaluator)

    # record validation metrics to TensorBoard; per-iteration output is disabled
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,
        global_epoch_transform=_trainer_epoch,
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # here we draw the 3D output as GIF format along Depth axis, at every validation epoch
    val_tensorboard_image_handler = TensorBoardImageHandler(
        batch_transform=lambda batch: (batch[0], batch[1]),
        output_transform=lambda output: output[0],
        global_iter_transform=_trainer_epoch,
    )
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
    print(state)
def main():
    """Train a 3D DenseNet121 classifier on a fixed list of IXI-T1 volumes with Ignite.

    Uses the first 10 listed images for training and the last 10 for
    validation after every epoch, with early stopping on "Accuracy".
    Checkpoints are written to ``./runs_array/``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # the path of ixi IXI-T1 dataset
    data_path = os.path.join(".", "workspace", "data", "medical", "ixi", "IXI-T1")
    file_names = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.path.join(data_path, f) for f in file_names]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

    # define transforms
    target_size = (96, 96, 96)
    train_transforms = Compose(
        [ScaleIntensity(), AddChannel(), Resize(target_size), RandRotate90(), EnsureType()]
    )
    val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize(target_size), EnsureType()])

    # define image dataset, data loader (sanity check on one batch)
    check_ds = ImageDataset(image_files=images, labels=labels, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
    im, label = monai.utils.misc.first(check_loader)
    print(type(im), im.shape, label)

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2)
    net = net.to(device)
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # save network params and optimizer stats at the end of every training epoch
    checkpoint_handler = ModelCheckpoint("./runs_array/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={"net": net, "opt": opt},
    )

    # StatsHandler prints the loss per iteration and metrics per epoch; no trainer metrics
    # are set here, so only the loss is printed (the output is passed through unchanged)
    train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots the same values that StatsHandler prints
    train_tensorboard_stats_handler = TensorBoardStatsHandler(output_transform=lambda x: x)
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1
    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True)

    def _trainer_epoch(x):
        # report validation results against the trainer's global epoch number
        return trainer.state.epoch

    # print validation stats via the evaluator; per-iteration output is disabled
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,
        global_epoch_transform=_trainer_epoch,
    )
    val_stats_handler.attach(evaluator)

    # record validation metrics to TensorBoard at every epoch; per-iteration output is disabled
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,
        global_epoch_transform=_trainer_epoch,
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer,
    )
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # create a validation data loader
    val_ds = ImageDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = ImageDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
    print(state)
def configure(self):
    """Build the network, data pipelines and the training / evaluation engines.

    Calls ``self.set_device()`` first, then reads ``self.device``,
    ``self.multi_gpu``, ``self.data_list_file_path``, ``self.ckpt_dir``,
    ``self.amp``, ``self.learning_rate``, ``self.val_interval``,
    ``self.max_epochs`` and ``self.local_rank``.

    On return ``self.eval_engine`` holds a ``SupervisedEvaluator`` and
    ``self.train_engine`` a ``SupervisedTrainer``.
    """
    self.set_device()

    network = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    )
    network = network.to(self.device)
    if self.multi_gpu:
        network = DistributedDataParallel(
            module=network,
            device_ids=[self.device],
            find_unused_parameters=False,
        )

    pair_keys = ("image", "label")
    # the same intensity window is applied to training and validation images
    intensity_window = {
        "keys": "image",
        "a_min": -57,
        "a_max": 164,
        "b_min": 0.0,
        "b_max": 1.0,
        "clip": True,
    }

    # ---- training data ----
    train_transforms = Compose(
        [
            LoadImaged(keys=pair_keys),
            EnsureChannelFirstd(keys=pair_keys),
            Spacingd(keys=pair_keys, pixdim=[1.0, 1.0, 1.0], mode=["bilinear", "nearest"]),
            ScaleIntensityRanged(**intensity_window),
            CropForegroundd(keys=pair_keys, source_key="image"),
            RandCropByPosNegLabeld(
                keys=pair_keys,
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=pair_keys),
        ]
    )
    train_datalist = load_decathlon_datalist(self.data_list_file_path, True, "training")
    if self.multi_gpu:
        # each rank keeps only its own partition of the training list
        partitions = partition_dataset(
            data=train_datalist,
            shuffle=True,
            num_partitions=dist.get_world_size(),
            even_divisible=True,
        )
        train_datalist = partitions[dist.get_rank()]
    train_ds = CacheDataset(
        data=train_datalist,
        transform=train_transforms,
        cache_num=32,
        cache_rate=1.0,
        num_workers=4,
    )
    train_data_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    # ---- validation data ----
    val_transforms = Compose(
        [
            LoadImaged(keys=pair_keys),
            EnsureChannelFirstd(keys=pair_keys),
            ScaleIntensityRanged(**intensity_window),
            CropForegroundd(keys=pair_keys, source_key="image"),
            ToTensord(keys=pair_keys),
        ]
    )
    val_datalist = load_decathlon_datalist(self.data_list_file_path, True, "validation")
    # positional arguments; presumably cache_num=9, cache_rate=0.0, num_workers=4,
    # mirroring the keyword order used for train_ds -- TODO confirm
    val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
    val_data_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

    # ---- post-processing and metric ----
    post_transform = Compose(
        [
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(
                keys=["pred", "label"],
                argmax=[True, False],
                to_onehot=True,
                n_classes=2,
            ),
        ]
    )
    mean_dice = MeanDice(
        include_background=False,
        output_transform=lambda x: (x["pred"], x["label"]),
        device=self.device,
    )
    key_val_metric = {"val_mean_dice": mean_dice}

    # ---- evaluation engine ----
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(
            save_dir=self.ckpt_dir,
            save_dict={"model": network},
            save_key_metric=True,
        ),
        TensorBoardStatsHandler(log_dir=self.ckpt_dir, output_transform=lambda x: None),
    ]
    sliding_window = SlidingWindowInferer(roi_size=[160, 160, 160], sw_batch_size=4, overlap=0.5)
    self.eval_engine = SupervisedEvaluator(
        device=self.device,
        val_data_loader=val_data_loader,
        network=network,
        inferer=sliding_window,
        post_transform=post_transform,
        key_val_metric=key_val_metric,
        val_handlers=val_handlers,
        amp=self.amp,
    )

    # ---- training engine ----
    optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)

    def _loss_of(x):
        # the engine output is indexed by "loss" for the logging handlers
        return x["loss"]

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=self.eval_engine, interval=self.val_interval, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=_loss_of),
        TensorBoardStatsHandler(
            log_dir=self.ckpt_dir,
            tag_name="train_loss",
            output_transform=_loss_of,
        ),
    ]
    self.train_engine = SupervisedTrainer(
        device=self.device,
        max_epochs=self.max_epochs,
        train_data_loader=train_data_loader,
        network=network,
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=self.amp,
    )

    # ranks other than 0 only log warnings and above
    if self.local_rank > 0:
        for engine in (self.train_engine, self.eval_engine):
            engine.logger.setLevel(logging.WARNING)
def create_trainer(args):
    """Build the ``SupervisedTrainer`` (with an attached evaluator) from ``args``.

    Args:
        args: options object. Attributes read directly here: ``seed``,
            ``multi_gpu``, ``local_rank``, ``use_gpu``, ``roi_size``,
            ``model_size``, ``dimensions``, ``network``, ``channels``,
            ``resume``, ``model_filepath``, ``output``, ``image_interval``,
            ``save_interval``, ``max_val_interactions``, ``learning_rate``,
            ``val_freq``, ``epochs``, ``max_train_interactions`` and ``amp``.
            The whole object is also forwarded to ``get_loaders``.

    Returns:
        The configured trainer; the caller is expected to invoke ``run()``.
    """
    set_determinism(seed=args.seed)

    multi_gpu = args.multi_gpu
    local_rank = args.local_rank
    if multi_gpu:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device("cuda:{}".format(local_rank))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if args.use_gpu else "cpu")

    pre_transforms = get_pre_transforms(args.roi_size, args.model_size, args.dimensions)
    click_transforms = get_click_transforms()
    post_transform = get_post_transforms()

    train_loader, val_loader = get_loaders(args, pre_transforms)

    # define training components
    network = get_network(args.network, args.channels, args.dimensions).to(device)
    if multi_gpu:
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[local_rank], output_device=local_rank
        )

    if args.resume:
        logging.info('{}:: Loading Network...'.format(local_rank))
        # Map every stored tensor onto the device selected above. The previous
        # {"cuda:0": "cuda:<rank>"} remapping left CUDA-saved tensors on CUDA,
        # which fails when resuming on a CPU-only run.
        network.load_state_dict(torch.load(args.model_filepath, map_location=device))

    # define event-handlers for engine
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=args.output, output_transform=lambda x: None),
        DeepgrowStatsHandler(log_dir=args.output, tag_name='val_dice', image_interval=args.image_interval),
        CheckpointSaver(
            save_dir=args.output,
            save_dict={"net": network},
            save_key_metric=True,
            save_final=True,
            save_interval=args.save_interval,
            final_filename='model.pt',
        ),
    ]
    # only rank 0 logs and checkpoints during validation
    val_handlers = val_handlers if local_rank == 0 else None

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_val_interactions,
            key_probability='probability',
            train=False,
        ),
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_val_metric={
            "val_dice": MeanDice(
                include_background=False, output_transform=lambda x: (x["pred"], x["label"])
            )
        },
        val_handlers=val_handlers,
    )

    loss_function = DiceLoss(sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(network.parameters(), args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=args.val_freq, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(
            log_dir=args.output, tag_name="train_loss", output_transform=lambda x: x["loss"]
        ),
        CheckpointSaver(
            save_dir=args.output,
            save_dict={"net": network, "opt": optimizer, "lr": lr_scheduler},
            save_interval=args.save_interval * 2,
            save_final=True,
            final_filename='checkpoint.pt',
        ),
    ]
    # other ranks keep only the LR scheduler and validation handlers (the first two entries)
    train_handlers = train_handlers if local_rank == 0 else train_handlers[:2]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=args.epochs,
        train_data_loader=train_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_train_interactions,
            key_probability='probability',
            train=True,
        ),
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        amp=args.amp,
        key_train_metric={
            "train_dice": MeanDice(
                include_background=False, output_transform=lambda x: (x["pred"], x["label"])
            )
        },
        train_handlers=train_handlers,
    )
    return trainer
def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
    """Run a short 3D segmentation training workflow and return the best validation metric.

    Args:
        root_dir: directory containing ``img*.nii.gz`` / ``seg*.nii.gz`` pairs;
            also used as the TensorBoard log directory and checkpoint directory.
        device: device the network and both engines run on.
        amp: truthy to enable automatic mixed precision in both engines.
        num_workers: worker count for the training and validation data loaders.

    Returns:
        ``evaluator.state.best_metric`` after 5 training epochs, with
        validation triggered every 2 epochs.
    """
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    # First 20 sorted pairs for training, last 20 for validation.
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            RandCropByPosNegLabeld(
                keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=num_workers)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    # One writer shared by the train and validation TensorBoardStatsHandler instances.
    summary_writer = SummaryWriter(log_dir=root_dir)

    val_postprocessing = Compose(
        [
            ToTensord(keys=["pred", "label"]),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    class _TestEvalIterEvents:
        """No-op handler: only checks that the evaluator's iteration event can be subscribed to."""

        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)

        def _forward_completed(self, engine):
            pass

    val_handlers = [
        StatsHandler(iteration_log=False),
        TensorBoardStatsHandler(summary_writer=summary_writer, iteration_log=False),
        TensorBoardImageHandler(
            log_dir=root_dir, batch_transform=from_engine(["image", "label"]), output_transform=from_engine("pred")
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
        _TestEvalIterEvents(),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        metric_cmp_fn=lambda cur, prev: cur >= prev,  # if greater or equal, treat as new best metric
        val_handlers=val_handlers,
        amp=bool(amp),
        to_kwargs={"memory_format": torch.preserve_format},
        amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32},
    )

    train_postprocessing = Compose(
        [
            ToTensord(keys=["pred", "label"]),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    class _TestTrainIterEvents:
        """No-op handler: only checks that the trainer's iteration events can be subscribed to."""

        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)
            engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed)
            engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed)
            engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self._model_completed)

        def _forward_completed(self, engine):
            pass

        def _loss_completed(self, engine):
            pass

        def _backward_completed(self, engine):
            pass

        def _model_completed(self, engine):
            pass

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
        TensorBoardStatsHandler(
            summary_writer=summary_writer, tag_name="train_loss", output_transform=from_engine("loss", first=True)
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
        _TestTrainIterEvents(),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        postprocessing=train_postprocessing,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=bool(amp),
        optim_set_to_none=True,
        to_kwargs={"memory_format": torch.preserve_format},
        amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32},
    )
    try:
        trainer.run()
    finally:
        # Flush and release the event file even if training fails, so the
        # caller can inspect or delete root_dir afterwards.
        summary_writer.close()

    return evaluator.state.best_metric
def main(tempdir):
    """Template training workflow: data, network, loss, optimizer, LR schedule, trainer.

    The evaluation section is left as ``...`` placeholders to be filled in.
    Logs and checkpoints are written to ``./runs/``.

    NOTE(review): ``tempdir`` is accepted but never read in this body; confirm
    whether outputs were meant to go there instead of ``./runs/``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # ---- dataset ----
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

    # ---- network ----
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # ---- loss ----
    loss = monai.losses.DiceLoss(sigmoid=True)

    # ---- optimizer ----
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    # ---- learning-rate schedule ----
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    # ---- evaluation (placeholders) ----
    val_post_transforms = ...
    val_handlers = ...
    evaluator = ...

    # ---- training ----
    train_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    checkpoint_saver = CheckpointSaver(
        save_dir="./runs/",
        save_dict={"net": net, "opt": opt},
        save_interval=2,
        epoch_level=True,
    )
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda out: out["loss"]),
        TensorBoardStatsHandler(
            log_dir="./runs/",
            tag_name="train_loss",
            output_transform=lambda out: out["loss"],
        ),
        checkpoint_saver,
    ]

    # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
    use_amp = monai.utils.get_torch_version_tuple() >= (1, 6)

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_acc": Accuracy(output_transform=lambda out: (out["pred"], out["label"])),
        },
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()
def main():
    """Train a 3D DenseNet121 binary classifier on 20 IXI T1 volumes with ignite.

    The first 10 image/label pairs are used for training and the last 10 for
    validation. Validation runs after every training epoch, with early
    stopping on the "Accuracy" metric (patience 4). Checkpoints are written
    to ``./runs/``. Training runs for at most 30 epochs.

    The model runs on ``cuda:0`` when CUDA is available and on the CPU
    otherwise.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    train_files = [{"img": img, "label": label} for img, label in zip(images[:10], labels[:10])]
    val_files = [{"img": img, "label": label} for img, label in zip(images[-10:], labels[-10:])]

    # define transforms for image
    train_transforms = Compose(
        [
            LoadNiftid(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            ToTensord(keys=["img"]),
        ]
    )

    # define dataset, data loader; load one batch up front as a sanity check
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    )
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    # Fall back to the CPU when CUDA is missing, matching the
    # torch.cuda.is_available() guard already used for pin_memory.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch["img"], batch["label"]), device, non_blocking)

    trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={"net": net, "opt": opt},
    )

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy(), "AUC": ROCAUC(to_onehot_y=True, add_softmax=True)}
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )

    train_epochs = 30
    trainer.run(train_loader, train_epochs)
def main():
    """Train a basic 2D MONAI UNet for fetal brain segmentation with ignite.

    Ignite manages the training loop, the validation loop and checkpointing.
    All settings come from the YAML file passed as ``--config``. Sections and
    keys read by this function:

    * ``device``: ``cuda_device``, ``num_workers``
    * ``training``: ``loss_type``, ``batch_size_train``, ``batch_size_valid``,
      ``lr``, ``lr_decay``, ``nr_train_epochs``, ``validation_every_n_epochs``,
      ``sliding_window_validation``, optional ``model_to_load``, optional
      ``manual_seed``
    * ``data``: ``data_root``, ``training_list``, ``validation_list``
    * ``output``: ``out_model_dir``, ``output_subfix``, optional ``cache_dir``,
      ``max_nr_models_saved``, ``val_image_to_tensorboad``

    Raises:
        FileNotFoundError: if ``training.model_to_load`` is set but the file
            does not exist.
    """
    # ------------------------------------------------------------------
    # Read input and configuration parameters
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Run basic UNet with MONAI - Ignite version.')
    parser.add_argument('--config', dest='config', metavar='config', type=str, help='config file')
    args = parser.parse_args()

    with open(args.config) as f:
        # NOTE(review): FullLoader is meant for trusted files only; consider
        # yaml.safe_load if configs may come from an untrusted source.
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']

    # training and validation params
    # NOTE(review): loss_type is read (so the key must be present) but is not
    # used below; the loss is always DiceLoss.
    loss_type = config_info['training']['loss_type']
    batch_size_train = config_info['training']['batch_size_train']
    batch_size_valid = config_info['training']['batch_size_valid']
    lr = float(config_info['training']['lr'])
    lr_decay = config_info['training']['lr_decay']
    if lr_decay is not None:
        lr_decay = float(lr_decay)
    nr_train_epochs = config_info['training']['nr_train_epochs']
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    sliding_window_validation = config_info['training']['sliding_window_validation']

    # A missing key and an explicit null both mean "train from scratch".
    model_to_load = config_info['training'].get('model_to_load')
    if model_to_load is not None and not os.path.exists(model_to_load):
        raise FileNotFoundError("cannot find model: {}".format(model_to_load))
    seed = config_info['training'].get('manual_seed')

    # data params
    data_root = config_info['data']['data_root']
    training_list = config_info['data']['training_list']
    validation_list = config_info['data']['validation_list']

    # model saving: one time-stamped output directory per run
    out_model_dir = os.path.join(
        config_info['output']['out_model_dir'],
        datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + config_info['output']['output_subfix'],
    )
    print("Saving to directory ", out_model_dir)
    if 'cache_dir' in config_info['output']:
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    max_nr_models_saved = config_info['output']['max_nr_models_saved']
    val_image_to_tensorboad = config_info['output']['val_image_to_tensorboad']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)
    if seed is not None:
        # set manual seed if required (both numpy and torch)
        set_determinism(seed=seed)

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    # create cache directory to store results for Persistent Dataset
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # create training and validation data lists
    train_files = create_data_list(
        data_folder_list=data_root,
        subject_list=training_list,
        img_postfix='_Image',
        label_postfix='_Label',
    )
    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(
        data_folder_list=data_root,
        subject_list=validation_list,
        img_postfix='_Image',
        label_postfix='_Label',
    )
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
        RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2, spatial_axes=[0, 1],
                    interp_order=[1, 0], reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg']),
    ])

    # create a training data loader
    train_ds = monai.data.PersistentDataset(data=train_files, transform=train_transforms,
                                            cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True,
                              num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())

    # data preprocessing for validation:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    val_steps = [
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
    ]
    if sliding_window_validation:
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from validation set to emulate how loss is computed at training
        val_steps.append(RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False))
        val_steps.append(SqueezeDimd(keys=['img', 'seg'], dim=-1))
        do_shuffle = True
        collate_fn_to_use = list_data_collate
    val_steps.append(ToTensord(keys=['img', 'seg']))
    val_transforms = Compose(val_steps)

    # create a validation data loader
    val_ds = monai.data.PersistentDataset(data=val_files, transform=val_transforms,
                                          cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)

    # ------------------------------------------------------------------
    # Network preparation
    # ------------------------------------------------------------------
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )

    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()
    if lr_decay is not None:
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=lr_decay, last_epoch=-1)

    # ------------------------------------------------------------------
    # Set ignite trainer
    # ------------------------------------------------------------------
    # function to manage batch at training
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch['img'], batch['seg']), device, non_blocking)

    trainer = create_supervised_trainer(model=net, optimizer=opt, loss_fn=loss_function,
                                        device=device, non_blocking=False, prepare_batch=prepare_batch)

    if model_to_load is not None:
        # resume: restore network and optimizer state from the given checkpoint
        checkpoint_handler = CheckpointLoader(load_path=model_to_load, load_dict={'net': net, 'opt': opt})
        checkpoint_handler.attach(trainer)
    else:
        # fresh run: save network params and optimizer stats at the end of every epoch
        checkpoint_handler = ModelCheckpoint(out_model_dir, 'net', n_saved=max_nr_models_saved,
                                             require_empty=False)
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'net': net, 'opt': opt})

    # StatsHandler prints loss at every iteration and print metrics at every epoch
    train_stats_handler = StatsHandler(name='trainer')
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_tensorboard_stats_handler = TensorBoardStatsHandler(summary_writer=writer_train)
    train_tensorboard_stats_handler.attach(trainer)

    if lr_decay is not None:
        print("Using Exponential LR decay")
        lr_schedule_handler = LrScheduleHandler(lr_scheduler, print_lr=True, name="lr_scheduler",
                                                writer=writer_train)
        lr_schedule_handler.attach(trainer)

    # ------------------------------------------------------------------
    # Set ignite evaluator to perform validation at training
    # ------------------------------------------------------------------
    # set parameters for validation
    metric_name = 'Mean_Dice'
    # add evaluation metric to the evaluator engine
    val_metrics = {
        "Loss": 1.0 - MeanDice(add_sigmoid=True, to_onehot_y=False),
        "Mean_Dice": MeanDice(add_sigmoid=True, to_onehot_y=False),
    }

    def _sliding_window_processor(engine, batch):
        # One validation step on a full volume: returns (predictions, labels).
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch['img'].to(device), batch['seg'].to(device)
            roi_size = (96, 96, 1)
            seg_probs = sliding_window_inference(val_images, roi_size, batch_size_valid, net)
            return seg_probs, val_labels

    if sliding_window_validation:
        # use sliding window inference at validation
        print("3D evaluator is used")
        net.to(device)
        evaluator = Engine(_sliding_window_processor)
        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)
    else:
        # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        print("2D evaluator is used")
        evaluator = create_supervised_evaluator(model=net, metrics=val_metrics, device=device,
                                                non_blocking=True, prepare_batch=prepare_batch)

    # Iterations per epoch as the loader actually produces them. The loader
    # keeps the last partial batch, so flooring len(train_ds) / batch_size
    # under-counts and gives 0 for datasets smaller than one batch.
    epoch_len = len(train_loader)
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_valid,
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.iteration,  # fetch global iteration number from trainer
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # here we draw the 3D output as GIF format along the depth axis, at every validation iteration.
    if val_image_to_tensorboad:
        val_tensorboard_image_handler = TensorBoardImageHandler(
            summary_writer=writer_valid,
            batch_transform=lambda batch: (batch['img'], batch['seg']),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch,
        )
        evaluator.add_event_handler(event_name=Events.ITERATION_COMPLETED(every=1),
                                    handler=val_tensorboard_image_handler)

    # ------------------------------------------------------------------
    # Run training
    # ------------------------------------------------------------------
    try:
        trainer.run(train_loader, nr_train_epochs)
    finally:
        # Flush and release the TensorBoard event files even if training fails.
        writer_train.close()
        writer_valid.close()
    print("Done!")
def run_training(train_file_list, valid_file_list, config_info): """ Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks: * Data Preparation: Extract the filenames and prepare the training/validation processing transforms * Load Data: Load training and validation data to PyTorch DataLoader * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric. * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop. * Run training: The MONAI trainer is run, performing training and validation during training. Args: train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training: image filename in the first column and segmentation filename in the second column. The two columns should be separated by a comma. See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format. valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation: image filename in the first column and segmentation filename in the second column. The two columns should be separated by a comma. See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format. config_info: dict, contains configuration parameters for sampling, network and training. See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields. 
""" """ Read input and configuration parameters """ # print MONAI config information logging.basicConfig(stream=sys.stdout, level=logging.INFO) print_config() # print to log the parameter setups print(yaml.dump(config_info)) # extract network parameters, perform checks/set defaults if not present and print them to log if 'seg_labels' in config_info['training'].keys(): seg_labels = config_info['training']['seg_labels'] else: seg_labels = [1] nr_out_channels = len(seg_labels) print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels)) patch_size = config_info["training"]["inplane_size"] + [1] print("Considering patch size = {}".format(patch_size)) spacing = config_info["training"]["spacing"] print("Bringing all images to spacing = {}".format(spacing)) if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None: model_to_load = config_info['training']['model_to_load'] if not os.path.exists(model_to_load): raise FileNotFoundError("Cannot find model: {}".format(model_to_load)) else: print("Loading model from {}".format(model_to_load)) else: model_to_load = None # set up either GPU or CPU usage if torch.cuda.is_available(): print("\n#### GPU INFORMATION ###") print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name())) current_device = torch.device("cuda:0") else: current_device = torch.device("cpu") print("Using device: {}".format(current_device)) # set determinism if required if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None: seed = config_info['training']['manual_seed'] else: seed = None if seed is not None: print("Using determinism with seed = {}\n".format(seed)) set_determinism(seed=seed) """ Setup data output directory """ out_model_dir = os.path.join(config_info['output']['out_dir'], datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + 
config_info['output']['out_postfix']) print("Saving to directory {}\n".format(out_model_dir)) # create cache directory to store results for Persistent Dataset if 'cache_dir' in config_info['output'].keys(): out_cache_dir = config_info['output']['cache_dir'] else: out_cache_dir = os.path.join(out_model_dir, 'persistent_cache') persistent_cache: Path = Path(out_cache_dir) persistent_cache.mkdir(parents=True, exist_ok=True) """ Data preparation """ # Read the input files for training and validation print("*** Loading input data for training...") train_files = create_data_list_of_dictionaries(train_file_list) print("Number of inputs for training = {}".format(len(train_files))) val_files = create_data_list_of_dictionaries(valid_file_list) print("Number of inputs for validation = {}".format(len(val_files))) # Define MONAI processing transforms for the training data. This includes: # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3 # - CropForegroundd: Reduce the background from the MR image # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the # last direction (lowest resolution) to avoid introducing motion artefact resampling errors # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed # - NormalizeIntensityd: Apply whitening # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1] # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. 
bring it to size [B, C, N, M]) # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd, # RandFlipd) # - ToTensor: convert to pytorch tensor train_transforms = Compose( [ LoadNiftid(keys=["image", "label"]), AddChanneld(keys=["image", "label"]), CropForegroundd(keys=["image", "label"], source_key="image"), InPlaneSpacingd( keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), ), SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False), SqueezeDimd(keys=["image", "label"], dim=-1), RandZoomd( keys=["image", "label"], min_zoom=0.9, max_zoom=1.2, mode=("bilinear", "nearest"), align_corners=(True, None), prob=0.16, ), RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2, keep_size=True, mode=["bilinear", "nearest"], padding_mode=["zeros", "border"]), RandGaussianNoised(keys=["image"], std=0.01, prob=0.15), RandGaussianSmoothd( keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.15, ), RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15), RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5), ToTensord(keys=["image", "label"]), ] ) # Define MONAI processing transforms for the validation data # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3 # - CropForegroundd: Reduce the background from the MR image # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the # last direction (lowest resolution) to avoid introducing motion artefact resampling errors # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed # - NormalizeIntensityd: Apply whitening # - ToTensor: convert to pytorch tensor # NOTE: The validation data is kept 3D as a 2D 
sliding window approach is used throughout the volume at inference val_transforms = Compose( [ LoadNiftid(keys=["image", "label"]), AddChanneld(keys=["image", "label"]), CropForegroundd(keys=["image", "label"], source_key="image"), InPlaneSpacingd( keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), ), SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]), NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), ToTensord(keys=["image", "label"]), ] ) """ Load data """ # create training data loader train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir=persistent_cache) train_loader = DataLoader(train_ds, batch_size=config_info['training']['batch_size_train'], shuffle=True, num_workers=config_info['device']['num_workers']) check_train_data = misc.first(train_loader) print("Training data tensor shapes:") print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape)) # create validation data loader if config_info['training']['batch_size_valid'] != 1: raise Exception("Batch size different from 1 at validation ar currently not supported") val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache) val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=config_info['device']['num_workers']) check_valid_data = misc.first(val_loader) print("Validation data tensor shapes (Example):") print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape)) """ Network preparation """ print("*** Preparing the network ...") # automatically extracts the strides and kernels based on nnU-Net empirical rules spacings = spacing[:2] sizes = patch_size[:2] strides, kernels = [], [] while True: spacing_ratio = [sp / min(spacings) for sp in spacings] stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)] kernel = [3 if 
ratio <= 2 else 1 for ratio in spacing_ratio] if all(s == 1 for s in stride): break sizes = [i / j for i, j in zip(sizes, stride)] spacings = [i * j for i, j in zip(spacings, stride)] kernels.append(kernel) strides.append(stride) strides.insert(0, len(spacings) * [1]) kernels.append(len(spacings) * [3]) # initialise the network net = DynUNet( spatial_dims=2, in_channels=1, out_channels=nr_out_channels, kernel_size=kernels, strides=strides, upsample_kernel_size=strides[1:], norm_name="instance", deep_supervision=True, deep_supr_num=2, res_block=False, ).to(current_device) print(net) # define the loss function loss_function = choose_loss_function(nr_out_channels, config_info) # define the optimiser and the learning rate scheduler opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95) scheduler = torch.optim.lr_scheduler.LambdaLR( opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9 ) """ MONAI evaluator """ print("*** Preparing the dynUNet evaluator engine...\n") # val_post_transforms = Compose( # [ # Activationsd(keys="pred", sigmoid=True), # ] # ) val_handlers = [ StatsHandler(output_transform=lambda x: None), TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"), output_transform=lambda x: None, global_epoch_transform=lambda x: trainer.state.iteration), CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True, file_prefix='best_valid'), ] if config_info['output']['val_image_to_tensorboad']: val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"), batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"], interval=2)) # Define customized evaluator class DynUNetEvaluator(SupervisedEvaluator): def _iteration(self, engine, batchdata): inputs, targets = self.prepare_batch(batchdata) inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device) 
flip_inputs_1 = torch.flip(inputs, dims=(2,)) flip_inputs_2 = torch.flip(inputs, dims=(3,)) flip_inputs_3 = torch.flip(inputs, dims=(2, 3)) def _compute_pred(): pred = self.inferer(inputs, self.network) # use random flipping as data augmentation at inference flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,)) flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,)) flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3)) return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4 # execute forward computation self.network.eval() with torch.no_grad(): if self.amp: with torch.cuda.amp.autocast(): predictions = _compute_pred() else: predictions = _compute_pred() return {"image": inputs, "label": targets, "pred": predictions} evaluator = DynUNetEvaluator( device=current_device, val_data_loader=val_loader, network=net, inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0), post_transform=None, key_val_metric={ "Mean_dice": MeanDice( include_background=False, to_onehot_y=True, mutually_exclusive=True, output_transform=lambda x: (x["pred"], x["label"]), ) }, val_handlers=val_handlers, amp=False, ) """ MONAI trainer """ print("*** Preparing the dynUNet trainer engine...\n") # train_post_transforms = Compose( # [ # Activationsd(keys="pred", sigmoid=True), # ] # ) validation_every_n_epochs = config_info['training']['validation_every_n_epochs'] epoch_len = len(train_ds) // train_loader.batch_size validation_every_n_iters = validation_every_n_epochs * epoch_len # define event handlers for the trainer writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train")) train_handlers = [ LrScheduleHandler(lr_scheduler=scheduler, print_lr=True), ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False), StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), TensorBoardStatsHandler(summary_writer=writer_train, 
log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss", output_transform=lambda x: x["loss"], global_epoch_transform=lambda x: trainer.state.iteration), CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_final=True, save_interval=2, epoch_level=True, n_saved=config_info['output']['max_nr_models_saved']), ] if model_to_load is not None: train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt})) # define customized trainer class DynUNetTrainer(SupervisedTrainer): def _iteration(self, engine, batchdata): inputs, targets = self.prepare_batch(batchdata) inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device) def _compute_loss(preds, label): labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]] return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))]) self.network.train() self.optimizer.zero_grad() if self.amp and self.scaler is not None: with torch.cuda.amp.autocast(): predictions = self.inferer(inputs, self.network) loss = _compute_loss(predictions, targets) self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() else: predictions = self.inferer(inputs, self.network) loss = _compute_loss(predictions, targets).mean() loss.backward() self.optimizer.step() return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()} trainer = DynUNetTrainer( device=current_device, max_epochs=config_info['training']['nr_train_epochs'], train_data_loader=train_loader, network=net, optimizer=opt, loss_function=loss_function, inferer=SimpleInferer(), post_transform=None, key_train_metric=None, train_handlers=train_handlers, amp=False, ) """ Run training """ print("*** Run training...") trainer.run() print("Done!")
def _build_petct_transforms(image_shape, voxel_spacing, augment):
    """Build the dictionary-based preprocessing pipeline for one data split.

    A new set of transform objects is created on every call, so the training
    and validation pipelines never share transform instances.

    Args:
        image_shape: target volume shape handed to ``ResampleReshapeAlign``.
        voxel_spacing: target voxel spacing handed to ``ResampleReshapeAlign``.
        augment: when true, a random affine transform is inserted after the
            conversion to numpy and before intensity scaling.

    Returns:
        A ``Compose`` object wrapping the ordered list of transforms.
    """
    steps = [
        # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=["pet_img", "mask_img"],
                 method="otsu", tval=0.0, idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=["pet_img", "ct_img", "mask_img"],
                             origin="head", origin_key="pet_img"),
        Sitk2Numpy(keys=["pet_img", "ct_img", "mask_img"]),
    ]

    if augment:
        # The mask is interpolated with "nearest" while both image
        # modalities use "bilinear" (see the ``mode`` tuple).
        steps.append(
            RandAffined(keys=("pet_img", "ct_img", "mask_img"),
                        spatial_size=None, prob=0.4,
                        rotate_range=(0, np.pi / 30, np.pi / 15),
                        shear_range=None,
                        translate_range=(10, 10, 10),
                        scale_range=(0.1, 0.1, 0.1),
                        mode=("bilinear", "bilinear", "nearest"),
                        padding_mode="border")
        )

    steps.extend([
        # normalize input: clip each modality to a fixed window, map to [0, 1]
        ScaleIntensityRanged(
            keys=["pet_img"], a_min=0.0, a_max=25.0,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        ScaleIntensityRanged(
            keys=["ct_img"], a_min=-1000.0, a_max=1000.0,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        # Prepare for neural network.
        # NOTE(review): ConcatModality presumably writes the merged PET/CT
        # volume under the "image" key that ToTensord reads - confirm.
        ConcatModality(keys=["pet_img", "ct_img"]),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])
    return Compose(steps)


def main(config):
    """Train and validate a 3D UNet on PET/CT volumes with MONAI engines.

    Args:
        config: nested mapping. The following entries are read:

            * ``config['path']``: ``csv_path``, ``trained_model_path``,
              ``training_model_folder``
            * ``config['preprocessing']``: ``image_shape``, ``in_channels``,
              ``voxel_spacing``, ``data_augment``, ``resize``, ``origin``,
              ``normalize``, ``number_class``
            * ``config['model']``: ``architecture`` and
              ``config['model'][architecture]['cnn_params']`` (list values of
              ``cnn_params`` are converted to tuples in place)
            * ``config['training']``: ``epochs``, ``batch_size``, ``shuffle``,
              ``optimizer.opt_params``

    Raises:
        KeyError: if any of the configuration entries above is missing.

    Side effects:
        Creates ``<training_model_folder>/<timestamp>/logs`` and writes
        TensorBoard events and checkpoints to ``./runs/``.
    """
    # NOTE(review): ':' is not a valid character in Windows path components,
    # so this timestamp format only works on POSIX file systems.
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config['path']['csv_path']
    trained_model_path = config['path']['trained_model_path']  # if None, trained from scratch
    training_model_folder = os.path.join(
        config['path']['training_model_folder'], now)  # '/path/to/folder'
    logdir = os.path.join(training_model_folder, 'logs')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(training_model_folder, exist_ok=True)
    os.makedirs(logdir, exist_ok=True)

    # PET CT scan params
    image_shape = tuple(config['preprocessing']['image_shape'])  # (x, y, z)
    in_channels = config['preprocessing']['in_channels']
    voxel_spacing = tuple(
        config['preprocessing']['voxel_spacing'])  # (4.8, 4.8, 4.8) # in millimeter, (x, y, z)

    # Training params
    epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    shuffle = config['training']['shuffle']

    # NOTE(review): the values below are parsed (so a missing key still fails
    # fast) but nothing further down consumes them yet. Likewise the per-run
    # folders created above stay empty because every handler writes to
    # "./runs/". Wire them in or drop them once the intended behavior is
    # confirmed.
    data_augment = config['preprocessing']['data_augment']  # True # for training dataset only
    resize = config['preprocessing']['resize']  # True # not use yet
    origin = config['preprocessing']['origin']  # how to set the new origin
    normalize = config['preprocessing']['normalize']  # True # whether or not to normalize the inputs
    number_class = config['preprocessing']['number_class']  # 2
    opt_params = config['training']["optimizer"]["opt_params"]

    # CNN params
    architecture = config['model']['architecture']  # 'unet' or 'vnet'
    cnn_params = config['model'][architecture]['cnn_params']
    # transform list to tuple (existing keys are reassigned, so mutating the
    # dict while iterating over it is safe)
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)

    # Get Data
    DM = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, _test_images_paths = DM.get_train_val_test(
        wrap_with_dict=True)

    # Input preprocessing: augmentation for training only
    train_transforms = _build_petct_transforms(image_shape, voxel_spacing, augment=True)
    val_transforms = _build_petct_transforms(image_shape, voxel_spacing, augment=False)

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_images_paths,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=2)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_images_paths,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=batch_size,
                                       num_workers=2)

    # Model
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    def _image_and_mask(batch):
        # The samples carry the target under "mask_img", not "label".
        return batch['image'], batch['mask_img']

    def _pred_and_label(output):
        # Every metric compares the post-processed prediction with the target.
        return output["pred"], output["label"]

    # training
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/",
                                output_transform=lambda x: None),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={"net": net, "opt": opt},
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        # Must match the trainer. Without it the evaluator falls back to
        # MONAI's default batch preparation, which looks for a "label" key
        # that these samples do not carry, so validation would fail.
        prepare_batch=_image_and_mask,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True,
                                      output_transform=_pred_and_label)
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=_pred_and_label),
            "val_precision": Precision(output_transform=_pred_and_label),
            "val_recall": Recall(output_transform=_pred_and_label),
        },
        val_handlers=val_handlers,
        # ``amp`` is left at its default for evaluation.
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        # run the evaluator after every training epoch
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/",
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={"net": net, "opt": opt},
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        # Honor the configured epoch count (previously hard-coded to 5).
        max_epochs=epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        prepare_batch=_image_and_mask,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice": MeanDice(include_background=True,
                                        output_transform=_pred_and_label)
        },
        additional_metrics={
            "train_acc": Accuracy(output_transform=_pred_and_label),
            "train_precision": Precision(output_transform=_pred_and_label),
            "train_recall": Recall(output_transform=_pred_and_label),
        },
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    trainer.run()