def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")):
    """Train a small 2D UNet for one epoch on synthetic images via a supervised trainer.

    The dataset generates a new random image/segmentation pair on every access
    and reports ``train_steps`` as its length.

    Returns:
        The value found in ``trainer.state.output`` once the single epoch is done.
    """

    class _TestBatch(Dataset):
        """Synthetic dataset: each item is a freshly generated image/label pair."""

        def __len__(self):
            return train_steps

        def __getitem__(self, _unused_id):
            image, label = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
            # prepend one axis to each array; the label is additionally cast to float32
            return image[None], label[None].astype(np.float32)

    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    )
    model = model.to(device)

    criterion = DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4)
    loader = DataLoader(_TestBatch(), batch_size=batch_size)

    trainer = create_supervised_trainer(model, optimizer, criterion, device, False)
    trainer.run(loader, 1)

    return trainer.state.output
def test_epistemic_scoring(self):
    """Briefly train a tiny 3D UNet, then check the entropy volume returned by EpistemicScoring."""
    input_size = (20, 20, 20)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    keys = ["image", "label"]
    num_training_ims = 10
    train_data = self.get_data(num_training_ims, input_size)
    test_data = self.get_data(1, input_size)

    # dictionary-based transforms for training, array-based ones for inference
    dict_steps = [
        AddChanneld(keys),
        CropForegroundd(keys, source_key="image"),
        DivisiblePadd(keys, 4),
    ]
    transforms = Compose(dict_steps)
    array_steps = [
        AddChannel(),
        CropForeground(),
        DivisiblePad(4),
    ]
    infer_transforms = Compose(array_steps)

    train_ds = CacheDataset(train_data, transforms)
    # output might be different size, so pad so that they match
    train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

    model = UNet(3, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
    loss_function = DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    num_epochs = 10
    for _ in trange(num_epochs):
        running_loss = 0
        for batch in train_loader:
            images = batch["image"].to(device)
            targets = batch["label"].to(device)
            optimizer.zero_grad()
            step_loss = loss_function(model(images), targets)
            step_loss.backward()
            optimizer.step()
            running_loss += step_loss.item()
        running_loss /= len(train_loader)

    entropy_score = EpistemicScoring(
        model=model,
        transforms=infer_transforms,
        roi_size=[20, 20, 20],
        num_samples=10,
    )

    # Call Individual Infer from Epistemic Scoring
    # (the same test image is stacked three times)
    ip_stack = np.array([test_data["image"]] * 3)
    score_3d = entropy_score.entropy_3d_volume(ip_stack)
    score_3d_sum = np.sum(score_3d)

    # Call Entropy Metric from Epistemic Scoring
    self.assertEqual(score_3d.shape, input_size)
    self.assertIsInstance(score_3d_sum, np.float32)
    self.assertGreater(score_3d_sum, 3.0)
def run_test(batch_size=64, train_steps=200, device=torch.device("cuda:0")):
    """Train a small 2D UNet for one pass over synthetic data using a hand-written loop.

    Returns:
        A tuple ``(mean step loss, number of steps)``.
    """

    class _TestBatch(Dataset):
        """Synthetic dataset applying the same random transform chain to image and label."""

        def __init__(self, transforms):
            self.transforms = transforms

        def __len__(self):
            return train_steps

        def __getitem__(self, _unused_id):
            image, label = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
            # re-seed with one shared value before each call so that image and
            # label go through the transform chain from the same random state
            shared_seed = np.random.randint(2147483647)
            self.transforms.set_random_state(seed=shared_seed)
            image = self.transforms(image)
            self.transforms.set_random_state(seed=shared_seed)
            label = self.transforms(label)
            return image, label

    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    )
    model = model.to(device)

    criterion = DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-2)

    pipeline = [
        AddChannel(),
        ScaleIntensity(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(),
        ToTensor(),
    ]
    train_transforms = Compose(pipeline)
    loader = DataLoader(_TestBatch(train_transforms), batch_size=batch_size, shuffle=True)

    model.train()
    total_loss = 0
    step = 0
    for step, (img, seg) in enumerate(loader, start=1):
        optimizer.zero_grad()
        prediction = model(img.to(device))
        batch_loss = criterion(prediction, seg.to(device))
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()

    total_loss /= step
    return total_loss, step
in_channels=1, out_channels=2, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, norm=Norm.BATCH, ).to(device) #loss_function = DiceLoss(to_onehot_y=True, softmax=True) #optimizer = torch.optim.Adam(model.parameters(), 1e-4) loss_function = DiceCELoss(include_background=True, to_onehot_y=True, softmax=True, lambda_dice=0.5, lambda_ce=0.5) optimizer = torch.optim.Adam(model.parameters(), 1e-3) dice_metric = DiceMetric(include_background=False, reduction="mean") scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5) ## """## Execute a typical PyTorch training process""" epoch_num = 300 val_interval = 2 best_metric = -1 best_metric_epoch = -1 epoch_loss_values = list() metric_values = list() post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2) post_label = AsDiscrete(to_onehot=True, n_classes=2)
def main_worker(args):
    """Run training and periodic validation inside one distributed worker process.

    The function joins the process group, builds the training and validation
    pipelines, trains for ``args.epochs`` epochs, validates every
    ``args.val_interval`` epochs, and finally lets rank 0 save the weights to
    ``final_model.pth``.

    Attributes read from ``args``: ``local_rank``, ``dir``, ``cache_rate``,
    ``batch_size``, ``workers``, ``log_dir``, ``network``, ``lr``, ``epochs``,
    ``print_freq`` and ``val_interval``.

    Raises:
        FileNotFoundError: if ``args.dir`` does not exist.
    """
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        # the handle is intentionally left open: it replaces stdout/stderr
        # for the rest of the process
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64], random_size=False),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    # NOTE(review): no DistributedSampler is used here; presumably
    # BratsCacheDataset partitions the data per rank - confirm.
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

    # validation transforms and dataset
    # (same chain as training, but with a center crop and no random augmentation)
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    # `writer` exists only on rank 0; every later use is guarded by the same rank check
    if dist.get_rank() == 0:
        # Logging for TensorBoard
        writer = SummaryWriter(log_dir=args.log_dir)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    if args.network == "UNet":
        model = UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    else:
        # any value other than "UNet" selects SegResNet
        model = SegResNet(in_channels=4, out_channels=3, init_filters=16, dropout_prob=0.2).to(device)
    loss_function = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5, amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):

        train_loss = train(train_loader, model, loss_function, optimizer, epoch, args, device)
        # time since the previous epoch finished (includes that epoch's validation, if any)
        epoch_time.update(time.time() - end)

        if epoch % args.print_freq == 0:
            progress.display(epoch)

        if dist.get_rank() == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)

        if (epoch + 1) % args.val_interval == 0:
            # evaluation is called on every rank; only rank 0 logs the result
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, device)

            if dist.get_rank() == 0:
                writer.add_scalar("Mean Dice/val", metric, epoch)
                writer.add_scalar("Mean Dice TC/val", metric_tc, epoch)
                writer.add_scalar("Mean Dice WT/val", metric_wt, epoch)
                writer.add_scalar("Mean Dice ET/val", metric_et, epoch)
                if metric > best_metric:
                    best_metric = metric
                    # stored 1-based, unlike the 0-based `epoch` used for the scalars above
                    best_metric_epoch = epoch + 1

                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if dist.get_rank() == 0:
        print(
            f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        # (the state dict saved here is that of the DistributedDataParallel wrapper)
        torch.save(model.state_dict(), "final_model.pth")
        writer.flush()
    dist.destroy_process_group()
def test_test_time_augmentation(self):
    """Briefly train a tiny 2D UNet, then validate the outputs of TestTimeAugmentation."""
    input_size = (20, 40)  # test different input data shape to pad list collate
    keys = ["image", "label"]
    num_training_ims = 10
    train_data = self.get_data(num_training_ims, input_size)
    test_data = self.get_data(1, input_size)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    random_affine = RandAffined(
        keys,
        prob=1.0,
        spatial_size=(30, 30),
        rotate_range=(np.pi / 3, np.pi / 3),
        translate_range=(3, 3),
        scale_range=((0.8, 1), (0.8, 1)),
        padding_mode="zeros",
        mode=("bilinear", "nearest"),
        as_tensor_output=False,
    )
    transforms = Compose(
        [
            AddChanneld(keys),
            random_affine,
            CropForegroundd(keys, source_key="image"),
            DivisiblePadd(keys, 4),
        ]
    )

    train_ds = CacheDataset(train_data, transforms)
    # output might be different size, so pad so that they match
    train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

    model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
    loss_function = DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    num_epochs = 10
    for _ in trange(num_epochs):
        running_loss = 0
        for batch in train_loader:
            images = batch["image"].to(device)
            targets = batch["label"].to(device)
            optimizer.zero_grad()
            step_loss = loss_function(model(images), targets)
            step_loss.backward()
            optimizer.step()
            running_loss += step_loss.item()
        running_loss /= len(train_loader)

    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    tt_aug = TestTimeAugmentation(
        transform=transforms,
        batch_size=5,
        num_workers=0,
        inferrer_fn=model,
        device=device,
        to_tensor=True,
        output_device="cpu",
        post_func=post_trans,
    )
    mode, mean, std, vvc = tt_aug(test_data)

    # every per-pixel statistic has the input shape with one leading channel
    expected_shape = (1,) + input_size
    self.assertEqual(mode.shape, expected_shape)
    self.assertEqual(mean.shape, expected_shape)
    self.assertTrue(all(np.unique(mode) == (0, 1)))
    self.assertGreaterEqual(mean.min(), 0.0)
    self.assertLessEqual(mean.max(), 1.0)
    self.assertEqual(std.shape, expected_shape)
    self.assertIsInstance(vvc, float)
class UNet2DSegmenter(AbstractBaseLearner): """Segmenter based on the U-Net architecture.""" def __init__( self, architecture: SegmentationArchitectures = SegmentationArchitectures.ResidualUNet2D, loss: SegmentationLosses = SegmentationLosses.GeneralizedDiceLoss, optimizer: Optimizers = Optimizers.Adam, mask_type: MaskType = MaskType.TIFF_LABELS, in_channels: int = 1, out_channels: int = 3, roi_size: Tuple[int, int] = (384, 384), num_filters_in_first_layer: int = 16, learning_rate: float = 0.001, weight_decay: float = 0.0001, momentum: float = 0.9, num_epochs: int = 400, batch_sizes: Tuple[int, int, int, int] = (8, 1, 1, 1), num_workers: Tuple[int, int, int, int] = (4, 4, 1, 1), validation_step: int = 2, sliding_window_batch_size: int = 4, class_names: Tuple[str, ...] = ("Background", "Object", "Border"), experiment_name: str = "Unet", model_name: str = "best_model", seed: int = 4294967295, working_dir: str = '.', stdout: TextIOWrapper = sys.stdout, stderr: TextIOWrapper = sys.stderr ): """Constructor. @param mask_type: MaskType Type of mask: defines file type, mask geometry and they way pixels are assigned to the various classes. @see qu.data.model.MaskType @param architecture: SegmentationArchitectures Core network architecture: one of (SegmentationArchitectures.ResidualUNet2D, SegmentationArchitectures.AttentionUNet2D) @param loss: SegmentationLosses Loss function: currently only SegmentationLosses.GeneralizedDiceLoss is supported @param optimizer: Optimizers Optimizer: one of (Optimizers.Adam, Optimizers.SGD) @param in_channels: int, optional: default = 1 Number of channels in the input (e.g. 1 for gray-value images). @param out_channels: int, optional: default = 3 Number of channels in the output (classes). @param roi_size: Tuple[int, int], optional: default = (384, 384) Crop area (and input size of the U-Net network) used for training and validation/prediction. @param num_filters_in_first_layer: int Number of filters in the first layer. 
Every subsequent layer doubles the number of filters. @param learning_rate: float, optional: default = 1e-3 Initial learning rate for the optimizer. @param weight_decay: float, optional: default = 1e-4 Weight decay of the learning rate for the optimizer. Used by the Adam optimizer. @param momentum: float, optional: default = 0.9 Momentum of the accelerated gradient for the optimizer. Used by the SGD optimizer. @param num_epochs: int, optional: default = 400 Number of epochs for training. @param batch_sizes: Tuple[int, int, int], optional: default = (8, 1, 1, 1) Batch sizes for training, validation, testing, and prediction, respectively. @param num_workers: Tuple[int, int, int], optional: default = (4, 4, 1, 1) Number of workers for training, validation, testing, and prediction, respectively. @param validation_step: int, optional: default = 2 Number of training steps before the next validation is performed. @param sliding_window_batch_size: int, optional: default = 4 Number of batches for sliding window inference during validation and prediction. @param class_names: Tuple[str, ...], optional: default = ("Background", "Object", "Border") Name of the classes for logging validation curves. @param experiment_name: str, optional: default = "" Name of the experiment that maps to the folder that contains training information (to be used by tensorboard). Please note, current datetime will be appended. @param model_name: str, optional: default = "best_model.ph" Name of the file that stores the best model. Please note, current datetime will be appended (before the extension). @param seed: int, optional; default = 4294967295 Set random seed for modules to enable or disable deterministic training. @param working_dir: str, optional, default = "." Working folder where to save the model weights and the logs for tensorboard. 
""" # Call base constructor super().__init__() # Standard pipe wrappers self._stdout = stdout self._stderr = stderr # Device (initialize as "cpu") self._device = "cpu" # Architecture, loss function and optimizer self._option_architecture = architecture self._option_loss = loss self._option_optimizer = optimizer self._learning_rate = learning_rate self._weight_decay = weight_decay self._momentum = momentum # Mask type self._mask_type = mask_type # Input and output channels self._in_channels = in_channels self._out_channels = out_channels # Define hyper parameters self._roi_size = roi_size self._num_filters_in_first_layer = num_filters_in_first_layer self._training_batch_size = batch_sizes[0] self._validation_batch_size = batch_sizes[1] self._test_batch_size = batch_sizes[2] self._prediction_batch_size = batch_sizes[3] self._training_num_workers = num_workers[0] self._validation_num_workers = num_workers[1] self._test_num_workers = num_workers[2] self._prediction_num_workers = num_workers[3] self._n_epochs = num_epochs self._validation_step = validation_step self._sliding_window_batch_size = sliding_window_batch_size # Other parameters self._class_names = out_channels * ["Unknown"] for i in range(min(out_channels, len(class_names))): self._class_names[i] = class_names[i] # Set monai seed set_determinism(seed=seed) # All file names self._train_image_names: list = [] self._train_mask_names: list = [] self._validation_image_names: list = [] self._validation_mask_names: list = [] self._test_image_names: list = [] self._test_mask_names: list = [] # Transforms self._train_image_transforms = None self._train_mask_transforms = None self._validation_image_transforms = None self._validation_mask_transforms = None self._test_image_transforms = None self._test_mask_transforms = None self._prediction_image_transforms = None self._validation_post_transforms = None self._test_post_transforms = None self._prediction_post_transforms = None # Datasets and data loaders 
self._train_dataset = None self._train_dataloader = None self._validation_dataset = None self._validation_dataloader = None self._test_dataset = None self._test_dataloader = None self._prediction_dataset = None self._prediction_dataloader = None # Set model architecture, loss function, metric and optimizer self._model = None self._training_loss_function = None self._optimizer = None self._validation_metric = None # Working directory, model file name and experiment name for Tensorboard logs. # The file names will be redefined at the beginning of the training. self._working_dir = Path(working_dir).resolve() self._raw_experiment_name = experiment_name self._raw_model_file_name = model_name # Keep track of the full path of the best model self._best_model = '' # Keep track of last error message self._message = "" def train(self) -> bool: """Run training in a separate thread (added to the global application ThreadPool).""" # Free memory on the GPU self._clear_session() # Check that the data is set properly if len(self._train_image_names) == 0 or \ len(self._train_mask_names) == 0 or \ len(self._validation_image_names) == 0 or \ len(self._validation_mask_names) == 0: self._message = "No training/validation data found." return False if len(self._train_image_names) != len(self._train_mask_names) == 0: self._message = "The number of training images does not match the number of training masks." return False if len(self._validation_image_names) != len(self._validation_mask_names) == 0: self._message = "The number of validation images does not match the number of validation masks." 
return False # Define the transforms self._define_training_transforms() # Define the datasets and data loaders self._define_training_data_loaders() # Instantiate the model self._define_model() # Define the loss function self._define_training_loss() # Define the optimizer (with default parameters) self._define_optimizer() # Define the validation metric self._define_validation_metric() # Define experiment name and model name experiment_name, model_file_name = self._prepare_experiment_and_model_names() # Keep track of the best model file name self._best_model = model_file_name # Enter the main training loop best_metric = -1 best_metric_epoch = -1 epoch_loss_values = list() metric_values = list() # Initialize TensorBoard's SummaryWriter writer = SummaryWriter(experiment_name) for epoch in range(self._n_epochs): # Inform self._print_header(f"Epoch {epoch + 1}/{self._n_epochs}") # Switch to training mode self._model.train() epoch_loss = 0 step = 0 for batch_data in self._train_dataloader: # Update step step += 1 # Get the next batch and move it to device inputs, labels = batch_data[0].to(self._device), batch_data[1].to(self._device) # Zero the gradient buffers self._optimizer.zero_grad() # Forward pass outputs = self._model(inputs) # Calculate the loss loss = self._training_loss_function(outputs, labels) # Back-propagate loss.backward() # Update weights (optimize) self._optimizer.step() # Update and store metrics epoch_loss += loss.item() epoch_len = len(self._train_dataset) / self._train_dataloader.batch_size if epoch_len != int(epoch_len): epoch_len = int(epoch_len) + 1 print(f"Batch {step}/{epoch_len}: train_loss = {loss.item():.4f}", file=self._stdout) epoch_loss /= step epoch_loss_values.append(epoch_loss) print(f"Average loss = {epoch_loss:.4f}", file=self._stdout) writer.add_scalar("average_train_loss", epoch_loss, epoch + 1) # Validation if (epoch + 1) % self._validation_step == 0: self._print_header("Validation") # Switch to evaluation mode self._model.eval() # 
Make sure not to update the gradients with torch.no_grad(): # Global metrics metric_sum = 0.0 metric_count = 0 metric = 0.0 # Keep track of the metrics for all classes metric_sum_classes = self._out_channels * [0.0] metric_count_classes = self._out_channels * [0] metric_classes = self._out_channels * [0.0] for val_data in self._validation_dataloader: # Get the next batch and move it to device val_images, val_labels = val_data[0].to(self._device), val_data[1].to(self._device) # Apply sliding inference over ROI size val_outputs = sliding_window_inference( val_images, self._roi_size, self._sliding_window_batch_size, self._model ) val_outputs = self._validation_post_transforms(val_outputs) # Compute overall metric value, not_nans = self._validation_metric( y_pred=val_outputs, y=val_labels ) not_nans = not_nans.item() metric_count += not_nans metric_sum += value.item() * not_nans # Compute metric for each class for c in range(self._out_channels): value_obj, not_nans = self._validation_metric( y_pred=val_outputs[:, c:c + 1], y=val_labels[:, c:c + 1] ) not_nans = not_nans.item() metric_count_classes[c] += not_nans metric_sum_classes[c] += value_obj.item() * not_nans # Global metric metric = metric_sum / metric_count metric_values.append(metric) # Metric per class for c in range(self._out_channels): metric_classes[c] = metric_sum_classes[c] / metric_count_classes[c] # Print summary print(f"Global metric = {metric:.4f} ", file=self._stdout) for c in range(self._out_channels): print(f"Class '{self._class_names[c]}' metric = {metric_classes[c]:.4f} ", file=self._stdout) # Do we have the best metric so far? 
if metric > best_metric: best_metric = metric best_metric_epoch = epoch + 1 torch.save( self._model.state_dict(), model_file_name ) print(f"New best global metric = {best_metric:.4f} at epoch: {best_metric_epoch}", file=self._stdout) print(f"Saved best model '{Path(model_file_name).name}'", file=self._stdout) # Add validation loss and metrics to log writer.add_scalar("val_mean_dice_loss", metric, epoch + 1) for c in range(self._out_channels): metric_name = f"val_{self._class_names[c].lower()}_metric" writer.add_scalar(metric_name, metric_classes[c], epoch + 1) print(f"Training completed. Best_metric = {best_metric:.4f} at epoch: {best_metric_epoch}", file=self._stdout) writer.close() # Return success return True def test_predict( self, target_folder: Union[Path, str] = '', model_path: Union[Path, str] = '' ) -> bool: """Run prediction on predefined test data. @param target_folder: Path|str, optional: default = '' Path to the folder where to store the predicted images. If not specified, if defaults to '{working_dir}/predictions'. See constructor. @param model_path: Path|str, optional: default = '' Full path to the model to use. If omitted and a training was just run, the path to the model with the best metric is already stored and will be used. @see get_best_model_path() @return True if the prediction was successful, False otherwise. 
""" # Inform self._print_header("Test prediction") # Get the device self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # If the model is not in memory, instantiate it first if self._model is None: self._define_model() # If the path to the best model was not set, use current one (if set) if model_path == '': model_path = self.get_best_model_path() # Try loading the model weights: they must be compatible # with the model in memory try: checkpoint = torch.load( model_path, map_location=torch.device('cpu') ) self._model.load_state_dict(checkpoint) print(f"Loaded best metric model {model_path}.", file=self._stdout) except Exception as e: self._message = "Error: there was a problem loading the model! Aborting." return False # If the target folder is not specified, set it to the standard predictions out if target_folder == '': target_folder = Path(self._working_dir) / "tests" else: target_folder = Path(target_folder) target_folder.mkdir(parents=True, exist_ok=True) # Switch to evaluation mode self._model.eval() indx = 0 # Make sure not to update the gradients with torch.no_grad(): for test_data in self._test_dataloader: # Get the next batch and move it to device test_images, test_masks = test_data[0].to(self._device), test_data[1].to(self._device) # Apply sliding inference over ROI size test_outputs = sliding_window_inference( test_images, self._roi_size, self._sliding_window_batch_size, self._model ) test_outputs = self._test_post_transforms(test_outputs) # Retrieve the image from the GPU (if needed) pred = test_outputs.cpu().numpy().squeeze() # Prepare the output file name basename = os.path.splitext(os.path.basename(self._test_image_names[indx]))[0] basename = basename.replace('train_', 'pred_') # Convert to label image label_img = self._prediction_to_label_tiff_image(pred) # Save label image as tiff file label_file_name = os.path.join( str(target_folder), basename + '.tif') with TiffWriter(label_file_name) as tif: tif.save(label_img) # Inform 
print(f"Saved {str(target_folder)}/{basename}.tif", file=self._stdout) # Update the index indx += 1 # Inform print(f"Test prediction completed.", file=self._stdout) # Return success return True def predict(self, input_folder: Union[Path, str], target_folder: Union[Path, str], model_path: Union[Path, str] ): """Run prediction. @param input_folder: Path|str Path to the folder where to store the predicted images. @param target_folder: Path|str Path to the folder where to store the predicted images. @param model_path: Path|str Full path to the model to use. @return True if the prediction was successful, False otherwise. """ # Inform self._print_header("Prediction") # Get the device self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # If the model is not in memory, instantiate it first if self._model is None: self._define_model() # Try loading the model weights: they must be compatible # with the model in memory try: checkpoint = torch.load( model_path, map_location=torch.device('cpu') ) self._model.load_state_dict(checkpoint) print(f"Loaded best metric model {model_path}.", file=self._stdout) except Exception as e: self._message = "Error: there was a problem loading the model! Aborting." return False # Make sure the target folder exists if type(target_folder) == str and target_folder == '': self._message = "Error: please specify a valid target folder! Aborting." return False target_folder = Path(target_folder) target_folder.mkdir(parents=True, exist_ok=True) # Get prediction dataloader if not self._define_prediction_data_loaders(input_folder): self._message = "Error: could not instantiate prediction dataloader! Aborting." 
return False # Switch to evaluation mode self._model.eval() indx = 0 # Make sure not to update the gradients with torch.no_grad(): for prediction_data in self._prediction_dataloader: # Get the next batch and move it to device prediction_images = prediction_data.to(self._device) # Apply sliding inference over ROI size prediction_outputs = sliding_window_inference( prediction_images, self._roi_size, self._sliding_window_batch_size, self._model ) prediction_outputs = self._prediction_post_transforms(prediction_outputs) # Retrieve the image from the GPU (if needed) pred = prediction_outputs.cpu().numpy().squeeze() # Prepare the output file name basename = os.path.splitext(os.path.basename(self._prediction_image_names[indx]))[0] basename = "pred_" + basename # Convert to label image label_img = self._prediction_to_label_tiff_image(pred) # Save label image as tiff file label_file_name = os.path.join( str(target_folder), basename + '.tif') with TiffWriter(label_file_name) as tif: tif.save(label_img) # Inform print(f"Saved {str(target_folder)}/{basename}.tif", file=self._stdout) # Update the index indx += 1 # Inform print(f"Prediction completed.", file=self._stdout) # Return success return True def set_training_data(self, train_image_names, train_mask_names, val_image_names, val_mask_names, test_image_names, test_mask_names) -> None: """Set all training files names. @param train_image_names: list List of training image names. @param train_mask_names: list List of training mask names. @param val_image_names: list List of validation image names. @param val_mask_names: list List of validation image names. @param test_image_names: list List of test image names. @param test_mask_names: list List of test image names. 
""" # First validate all data if len(train_image_names) != len(train_mask_names): raise ValueError("The number of training images does not match the number of training masks.") if len(val_image_names) != len(val_mask_names): raise ValueError("The number of validation images does not match the number of validation masks.") if len(test_image_names) != len(test_mask_names): raise ValueError("The number of test images does not match the number of test masks.") # Training data self._train_image_names = train_image_names self._train_mask_names = train_mask_names # Validation data self._validation_image_names = val_image_names self._validation_mask_names = val_mask_names # Test data self._test_image_names = test_image_names self._test_mask_names = test_mask_names @staticmethod def _prediction_to_label_tiff_image(prediction): """Save the prediction to a label image (TIFF)""" # Convert to label image label_img = one_hot_stack_to_label_image( prediction, first_index_is_background=True, channels_first=True, dtype=np.uint16 ) return label_img def _define_training_transforms(self): """Define and initialize all training data transforms. * training set images transform * training set masks transform * validation set images transform * validation set masks transform * validation set images post-transform * test set images transform * test set masks transform * test set images post-transform * prediction set images transform * prediction set images post-transform @return True if data transforms could be instantiated, False otherwise. """ if self._mask_type == MaskType.UNKNOWN: raise Exception("The mask type is unknown. Cannot continue!") # Depending on the mask type, we will need to adapt the Mask Loader # and Transform. We start by initializing the most common types. 
MaskLoader = LoadMask(self._mask_type) MaskTransform = Identity # Adapt the transform for the LABEL types if self._mask_type == MaskType.TIFF_LABELS or self._mask_type == MaskType.NUMPY_LABELS: MaskTransform = ToOneHot(num_classes=self._out_channels) # The H5_ONE_HOT type requires a different loader if self._mask_type == MaskType.H5_ONE_HOT: # MaskLoader: still missing raise Exception("HDF5 one-hot masks are not supported yet!") # Define transforms for training self._train_image_transforms = Compose( [ LoadImage(image_only=True), ScaleIntensity(), AddChannel(), RandSpatialCrop(self._roi_size, random_size=False), RandRotate90(prob=0.5, spatial_axes=(0, 1)), ToTensor() ] ) self._train_mask_transforms = Compose( [ MaskLoader, MaskTransform, RandSpatialCrop(self._roi_size, random_size=False), RandRotate90(prob=0.5, spatial_axes=(0, 1)), ToTensor() ] ) # Define transforms for validation self._validation_image_transforms = Compose( [ LoadImage(image_only=True), ScaleIntensity(), AddChannel(), ToTensor() ] ) self._validation_mask_transforms = Compose( [ MaskLoader, MaskTransform, ToTensor() ] ) # Define transforms for testing self._test_image_transforms = Compose( [ LoadImage(image_only=True), ScaleIntensity(), AddChannel(), ToTensor() ] ) self._test_mask_transforms = Compose( [ MaskLoader, MaskTransform, ToTensor() ] ) # Post transforms self._validation_post_transforms = Compose( [ Activations(softmax=True), AsDiscrete(threshold_values=True) ] ) self._test_post_transforms = Compose( [ Activations(softmax=True), AsDiscrete(threshold_values=True) ] ) def _define_training_data_loaders(self) -> bool: """Initialize training datasets and data loaders. @Note: in Windows, it is essential to set `persistent_workers=True` in the data loaders! @return True if datasets and data loaders could be instantiated, False otherwise. 
""" # Optimize arguments if sys.platform == 'win32': persistent_workers = True pin_memory = False else: persistent_workers = False pin_memory = torch.cuda.is_available() if len(self._train_image_names) == 0 or \ len(self._train_mask_names) == 0 or \ len(self._validation_image_names) == 0 or \ len(self._validation_mask_names) == 0 or \ len(self._test_image_names) == 0 or \ len(self._test_mask_names) == 0: self._train_dataset = None self._train_dataloader = None self._validation_dataset = None self._validation_dataloader = None self._test_dataset = None self._test_dataloader = None return False # Training self._train_dataset = ArrayDataset( self._train_image_names, self._train_image_transforms, self._train_mask_names, self._train_mask_transforms ) self._train_dataloader = DataLoader( self._train_dataset, batch_size=self._training_batch_size, shuffle=False, num_workers=self._training_num_workers, persistent_workers=persistent_workers, pin_memory=pin_memory ) # Validation self._validation_dataset = ArrayDataset( self._validation_image_names, self._validation_image_transforms, self._validation_mask_names, self._validation_mask_transforms ) self._validation_dataloader = DataLoader( self._validation_dataset, batch_size=self._validation_batch_size, shuffle=False, num_workers=self._validation_num_workers, persistent_workers=persistent_workers, pin_memory=pin_memory ) # Test self._test_dataset = ArrayDataset( self._test_image_names, self._test_image_transforms, self._test_mask_names, self._test_mask_transforms ) self._test_dataloader = DataLoader( self._test_dataset, batch_size=self._test_batch_size, shuffle=False, num_workers=self._test_num_workers, persistent_workers=persistent_workers, pin_memory=pin_memory ) return True def _define_prediction_transforms(self): """Define and initialize all prediction data transforms. * prediction set images transform * prediction set images post-transform @return True if data transforms could be instantiated, False otherwise. 
""" # Define transforms for prediction self._prediction_image_transforms = Compose( [ LoadImage(image_only=True), ScaleIntensity(), AddChannel(), ToTensor(), ] ) self._prediction_post_transforms = Compose( [ Activations(softmax=True), AsDiscrete(threshold_values=True), ] ) def _define_prediction_data_loaders( self, prediction_folder_path: Union[Path, str] ) -> bool: """Initialize prediction datasets and data loaders. @Note: in Windows, it is essential to set `persistent_workers=True` in the data loaders! @return True if datasets and data loaders could be instantiated, False otherwise. """ # Check that the path exists prediction_folder_path = Path(prediction_folder_path) if not prediction_folder_path.is_dir(): return False # Scan for images self._prediction_image_names = natsorted( glob(str(Path(prediction_folder_path) / "*.tif")) ) # Optimize arguments if sys.platform == 'win32': persistent_workers = True pin_memory = False else: persistent_workers = False pin_memory = torch.cuda.is_available() if len(self._prediction_image_names) == 0: self._prediction_dataset = None self._prediction_dataloader = None return False # Define the transforms self._define_prediction_transforms() # Prediction self._prediction_dataset = Dataset( self._prediction_image_names, self._prediction_image_transforms ) self._prediction_dataloader = DataLoader( self._prediction_dataset, batch_size=self._test_batch_size, shuffle=False, num_workers=self._test_num_workers, persistent_workers=persistent_workers, pin_memory=pin_memory ) return True def get_message(self): """Return last error message.""" return self._message def get_best_model_path(self): """Return the full path to the best model.""" return self._best_model def _clear_session(self) -> None: """Try clearing cache on the GPU.""" if self._device != "cpu": torch.cuda.empty_cache() def _define_model(self) -> None: """Instantiate the U-Net architecture.""" # Create U-Net self._device = torch.device("cuda" if torch.cuda.is_available() else 
"cpu") print(f"Using device '{self._device}'.", file=self._stdout) # Try to free memory on the GPU if self._device != "cpu": torch.cuda.empty_cache() # Instantiate the requested model if self._option_architecture == SegmentationArchitectures.ResidualUNet2D: # Monai's UNet self._model = UNet( dimensions=2, in_channels=self._in_channels, out_channels=self._out_channels, channels=tuple((self._num_filters_in_first_layer * 2**i for i in range(0, 5))), strides=(2, 2, 2, 2), num_res_units=2 ).to(self._device) elif self._option_architecture == SegmentationArchitectures.AttentionUNet2D: # Attention U-Net self._model = AttentionUNet2D( img_ch=self._in_channels, output_ch=self._out_channels, n1=self._num_filters_in_first_layer ).to(self._device) else: raise ValueError(f"Unexpected architecture {self._option_architecture}! Aborting.") def _define_training_loss(self) -> None: """Define the loss function.""" if self._option_loss == SegmentationLosses.GeneralizedDiceLoss: self._training_loss_function = GeneralizedDiceLoss( include_background=True, to_onehot_y=False, softmax=True, batch=True, ) else: raise ValueError(f"Unknown loss option {self._option_loss}! Aborting.") def _define_optimizer(self) -> None: """Define the optimizer.""" if self._model is None: return if self._option_optimizer == Optimizers.Adam: self._optimizer = Adam( self._model.parameters(), self._learning_rate, weight_decay=self._weight_decay, amsgrad=True ) elif self._option_optimizer == Optimizers.SGD: self._optimizer = SGD( self._model.parameters(), lr=self._learning_rate, momentum=self._momentum ) else: raise ValueError(f"Unknown optimizer option {self._option_optimizer}! Aborting.") def _define_validation_metric(self): """Define the metric for validation function.""" self._validation_metric = DiceMetric( include_background=True, reduction="mean" ) def _prepare_experiment_and_model_names(self) -> Tuple[str, str]: """Prepare the experiment and model names. 
@return experiment_file_name, model_file_name Current date time is appended and the full path is returned. """ # Make sure the "runs" subfolder exists runs_dir = Path(self._working_dir) / "runs" runs_dir.mkdir(parents=True, exist_ok=True) now = datetime.now() # current date and time date_time = now.strftime("%Y%m%d_%H%M%S") # Experiment name experiment_name = f"{self._raw_experiment_name}_{str(self._option_architecture)}_{date_time}" \ if self._raw_experiment_name != "" \ else f"{str(self._option_architecture)}_{date_time}" experiment_name = runs_dir / experiment_name # Best model file name name = Path(self._raw_model_file_name).stem model_file_name = f"{name}_{date_time}.pth" model_file_name = runs_dir / model_file_name return str(experiment_name), str(model_file_name) def _print_header(self, header_text, line_length=80, file=None): """Print a section header.""" if file is None: file = self._stdout print(f"{line_length * '-'}", file=file) print(f"{header_text}", file=self._stdout) print(f"{line_length * '-'}", file=file)
def configure(self):
    """Assemble the evaluation and training engines for a 3D two-class UNet.

    The device is selected through ``self.set_device()``.  The network is
    wrapped in ``DistributedDataParallel`` when ``self.multi_gpu`` is set.
    Training and validation items are read from ``self.data_list_file_path``
    (sections ``"training"`` and ``"validation"``).  The resulting engines
    are stored in ``self.eval_engine`` and ``self.train_engine``.
    """
    self.set_device()

    pair_keys = ("image", "label")
    # Intensity window applied to the image in both pipelines; every
    # pipeline still gets its own transform instance.
    intensity_window = dict(
        keys="image",
        a_min=-57,
        a_max=164,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    )

    def _discard(_engine_output):
        # Handlers given this transform receive no per-iteration output.
        return None

    def _loss_of(engine_output):
        return engine_output["loss"]

    def _pred_and_label(engine_output):
        return engine_output["pred"], engine_output["label"]

    # ------------------------------------------------------------- network
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    )
    model = model.to(self.device)
    if self.multi_gpu:
        model = DistributedDataParallel(
            module=model,
            device_ids=[self.device],
            find_unused_parameters=False,
        )

    # ------------------------------------------------------- training data
    train_pipeline = Compose([
        LoadImaged(keys=pair_keys),
        EnsureChannelFirstd(keys=pair_keys),
        Spacingd(
            keys=pair_keys,
            pixdim=[1.0, 1.0, 1.0],
            mode=["bilinear", "nearest"],
        ),
        ScaleIntensityRanged(**intensity_window),
        CropForegroundd(keys=pair_keys, source_key="image"),
        RandCropByPosNegLabeld(
            keys=pair_keys,
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=pair_keys),
    ])

    train_items = load_decathlon_datalist(self.data_list_file_path, True, "training")
    if self.multi_gpu:
        # Each rank keeps only its own shard of the training list.
        shards = partition_dataset(
            data=train_items,
            shuffle=True,
            num_partitions=dist.get_world_size(),
            even_divisible=True,
        )
        train_items = shards[dist.get_rank()]

    train_cache = CacheDataset(
        data=train_items,
        transform=train_pipeline,
        cache_num=32,
        cache_rate=1.0,
        num_workers=4,
    )
    train_batches = DataLoader(
        train_cache,
        batch_size=2,
        shuffle=True,
        num_workers=4,
    )

    # ----------------------------------------------------- validation data
    val_pipeline = Compose([
        LoadImaged(keys=pair_keys),
        EnsureChannelFirstd(keys=pair_keys),
        ScaleIntensityRanged(**intensity_window),
        CropForegroundd(keys=pair_keys, source_key="image"),
        ToTensord(keys=pair_keys),
    ])
    val_items = load_decathlon_datalist(self.data_list_file_path, True, "validation")
    # Positional arguments kept exactly as in the original call; presumably
    # they map to cache_num, cache_rate and num_workers — confirm against
    # the installed MONAI version.
    val_cache = CacheDataset(val_items, val_pipeline, 9, 0.0, 4)
    val_batches = DataLoader(
        val_cache,
        batch_size=1,
        shuffle=False,
        num_workers=4,
    )

    # ----------------------------------------- post-processing and metric
    post_steps = Compose([
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(
            keys=["pred", "label"],
            argmax=[True, False],
            to_onehot=True,
            n_classes=2,
        ),
    ])
    dice_metric = {
        "val_mean_dice": MeanDice(
            include_background=False,
            output_transform=_pred_and_label,
            device=self.device,
        )
    }

    # ---------------------------------------------------------- evaluator
    eval_handlers = [
        StatsHandler(output_transform=_discard),
        CheckpointSaver(
            save_dir=self.ckpt_dir,
            save_dict={"model": model},
            save_key_metric=True,
        ),
        TensorBoardStatsHandler(log_dir=self.ckpt_dir, output_transform=_discard),
    ]
    window_inferer = SlidingWindowInferer(
        roi_size=[160, 160, 160],
        sw_batch_size=4,
        overlap=0.5,
    )
    self.eval_engine = SupervisedEvaluator(
        device=self.device,
        val_data_loader=val_batches,
        network=model,
        inferer=window_inferer,
        post_transform=post_steps,
        key_val_metric=dice_metric,
        val_handlers=eval_handlers,
        amp=self.amp,
    )

    # ------------------------------------------------------------ trainer
    adam = torch.optim.Adam(model.parameters(), self.learning_rate)
    dice_loss = DiceLoss(to_onehot_y=True, softmax=True)
    step_lr = torch.optim.lr_scheduler.StepLR(adam, step_size=5000, gamma=0.1)

    fit_handlers = [
        LrScheduleHandler(lr_scheduler=step_lr, print_lr=True),
        ValidationHandler(
            validator=self.eval_engine,
            interval=self.val_interval,
            epoch_level=True,
        ),
        StatsHandler(tag_name="train_loss", output_transform=_loss_of),
        TensorBoardStatsHandler(
            log_dir=self.ckpt_dir,
            tag_name="train_loss",
            output_transform=_loss_of,
        ),
    ]
    self.train_engine = SupervisedTrainer(
        device=self.device,
        max_epochs=self.max_epochs,
        train_data_loader=train_batches,
        network=model,
        optimizer=adam,
        loss_function=dice_loss,
        inferer=SimpleInferer(),
        post_transform=post_steps,
        key_train_metric=None,
        train_handlers=fit_handlers,
        amp=self.amp,
    )

    # Ranks other than 0 only log warnings and above.
    if self.local_rank > 0:
        for engine in (self.train_engine, self.eval_engine):
            engine.logger.setLevel(logging.WARNING)
def test_train_timing(self):
    """Train a 3D UNet for a few epochs with AMP and check the best mean Dice.

    The first 32 ``img*/seg*`` pairs found in ``self.data_dir`` are used for
    training and the last 9 for validation.  Both datasets are cached with
    their tensors already moved to ``cuda:0``.  The test passes when the
    best validation mean Dice exceeds 0.95.
    """
    image_paths = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    train_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[:32], label_paths[:32])
    ]
    val_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[-9:], label_paths[-9:])
    ]

    device = torch.device("cuda:0")
    both = ["image", "label"]
    resample_modes = ("bilinear", "nearest")

    # Training pipeline: deterministic steps first, then the random
    # augmentations.
    train_transforms = Compose([
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys=both),
        Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=resample_modes),
        ScaleIntensityd(keys="image"),
        CropForegroundd(keys=both, source_key="image"),
        # Foreground / background indices are computed once here and reused
        # by the random crop below.
        FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
        # Convert to tensors, then move to the GPU, before the random steps.
        EnsureTyped(keys=both),
        ToDeviced(keys=both, device=device),
        # Patch sampling driven by the pre-computed index keys.
        RandCropByPosNegLabeld(
            keys=both,
            label_key="label",
            spatial_size=(64, 64, 64),
            pos=1,
            neg=1,
            num_samples=4,
            fg_indices_key="label_fg",
            bg_indices_key="label_bg",
        ),
        RandFlipd(keys=both, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=both, prob=0.5),
        RandRotate90d(keys=both, prob=0.5, spatial_axes=(1, 2)),
        RandZoomd(keys=both, prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
        RandRotated(
            keys=both,
            prob=0.5,
            range_x=np.pi / 4,
            mode=resample_modes,
            align_corners=True,
            dtype=np.float64,
        ),
        RandAffined(keys=both, prob=0.5, rotate_range=np.pi / 2, mode=resample_modes),
        RandGaussianNoised(keys="image", prob=0.5),
        RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
    ])

    # Validation pipeline: the deterministic part only.
    val_transforms = Compose([
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys=both),
        Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=resample_modes),
        ScaleIntensityd(keys="image"),
        CropForegroundd(keys=both, source_key="image"),
        EnsureTyped(keys=both),
        ToDeviced(keys=both, device=device),
    ])

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # validate after every epoch

    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # The thread-based loaders are used with zero worker processes.
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)

    # Novograd is run at ten times the base learning rate.
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # Constants that do not change between iterations.
    epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")

        # ---- training pass
        model.train()
        epoch_loss = 0
        step = 0
        for batch in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # Forward pass and loss under autocast; backward goes through
            # the gradient scaler.
            with torch.cuda.amp.autocast():
                logits = model(batch["image"])
                loss = loss_function(logits, batch["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                  f" step time: {(time.time() - step_start):.4f}")
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # ---- validation pass
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_batch in val_loader:
                    with torch.cuda.amp.autocast():
                        raw_pred = sliding_window_inference(
                            val_batch["image"], roi_size, sw_batch_size, model)
                    pred_items = [post_pred(item) for item in decollate_batch(raw_pred)]
                    label_items = [post_label(item) for item in decollate_batch(val_batch["label"])]
                    dice_metric(y_pred=pred_items, y=label_items)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                print(
                    f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                )
        print(
            f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
        )

    total_time = time.time() - total_start
    print(
        f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
    )
    # The best validation Dice must exceed the expected threshold.
    self.assertGreater(best_metric, 0.95)
# Training schedule.
max_epochs = 6
learning_rate = 1e-4
val_interval = 2

# Loss, metric and post-processing objects (none depend on the model).
loss_function = DiceCELoss(
    to_onehot_y=True,
    softmax=True,
    squared_pred=True,
    batch=True,
)
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

# 3D UNet: one input channel, two output channels, placed on `device`.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

# Novograd runs at ten times the base learning rate; the gradient scaler is
# created alongside it.
optimizer = Novograd(model.parameters(), learning_rate * 10)
scaler = torch.cuda.amp.GradScaler()

# Bookkeeping containers, all starting empty; -1 marks "no value yet".
best_metric = best_metric_epoch = -1
best_metrics_epochs_and_time = [[] for _ in range(3)]
epoch_loss_values, metric_values = [], []
class UNet_DF(pl.LightningModule):
    """Lightning module training a 3D UNet with a Dice + focal loss.

    ``hparams`` must expose ``patch_size``, ``batch_size``, ``num_workers``
    and ``lr``; these are the attributes this class reads.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # NOTE(review): 258 breaks the doubling pattern of the other channel
        # widths — possibly a typo for 256.  Confirm before changing, since
        # it alters the architecture.
        self.unet = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(64, 128, 258, 512, 1024),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
            dropout=0,
        )
        # 2D slices collected during validation, shown as a grid at epoch end.
        self.sample_masks = []

    # ------------------------------------------------------------ data setup
    def setup(self, stage):
        """Read the CSV file list, split it and build the cached datasets."""
        table = pd.read_csv(
            '/data/shared/prostate/yale_prostate/input_lists/MR_yale.csv')
        image_paths = table['IMAGE'][0:295].tolist()
        mask_paths = table['SEGM'][0:295].tolist()
        records = [
            {'image': image_path, 'mask': mask_path}
            for image_path, mask_path in zip(image_paths, mask_paths)
        ]
        train_records, val_records = train_test_split(records, test_size=0.15)

        # Shared preprocessing; the same random positive/negative crop is
        # applied to both splits.
        data_keys = ["image", "mask"]
        base_pipeline = Compose([
            LoadNiftid(keys=data_keys),
            AddChanneld(keys=data_keys),
            NormalizeIntensityd(keys="image"),
            RandCropByPosNegLabeld(keys=data_keys,
                                   label_key="mask",
                                   spatial_size=self.hparams.patch_size,
                                   num_samples=4,
                                   image_key="image"),
        ])

        def _cached(split_records):
            # Each dataset gets its own outer Compose / ToTensord instance.
            pipeline = Compose([base_pipeline, ToTensord(keys=data_keys)])
            return monai.data.CacheDataset(data=split_records,
                                           transform=pipeline,
                                           cache_rate=1.0)

        self.train_dataset = _cached(train_records)
        self.val_dataset = _cached(val_records)

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        options = self.hparams
        return monai.data.DataLoader(self.train_dataset,
                                     batch_size=options.batch_size,
                                     shuffle=True,
                                     num_workers=options.num_workers)

    def val_dataloader(self):
        """Return the loader over the validation dataset."""
        options = self.hparams
        return monai.data.DataLoader(self.val_dataset,
                                     batch_size=options.batch_size,
                                     num_workers=options.num_workers)

    # -------------------------------------------------------- training setup
    def forward(self, image):
        """Run the UNet on ``image``."""
        return self.unet(image)

    def criterion(self, y_hat, y):
        """Return Dice loss plus focal loss for predictions ``y_hat`` and target ``y``."""
        dice_term = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)(y_hat, y)
        focal_term = monai.losses.FocalLoss()(y_hat, y)
        return dice_term + focal_term

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        prediction = self(batch['image'])
        loss = self.criterion(prediction, batch['mask'])
        self.logger.log_metrics({"loss/train": loss}, self.global_step)
        return {'loss': loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over the UNet parameters."""
        return torch.optim.Adam(self.unet.parameters(), lr=self.hparams.lr)

    def validation_step(self, batch, batch_idx):
        """Return the validation loss; after epoch 0 also keep a sample slice."""
        inputs = batch["image"]
        labels = batch["mask"]
        outputs = self(inputs)

        if self.current_epoch != 0:
            # Middle slice (last axis) of the arg-max of the first item.
            label_map = outputs[0].argmax(0)
            middle = int(label_map.shape[2] / 2)
            self.sample_masks.append(
                label_map[:, :, middle].unsqueeze(0).detach())

        return {"val_loss": self.criterion(outputs, labels)}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and, after epoch 0, the sample grid."""
        step_losses = [step_out["val_loss"] for step_out in outputs]
        avg_loss = torch.stack(step_losses).mean()
        self.logger.log_metrics({"loss/val": avg_loss}, self.current_epoch)

        if self.current_epoch != 0:
            grid = torchvision.utils.make_grid(self.sample_masks)
            self.logger.experiment.add_image('sample_masks', grid,
                                             self.current_epoch)
            # Start the next epoch with an empty collection.
            self.sample_masks = []

        return {"val_loss": avg_loss}
class MaskGAN(pl.LightningModule):
    """Lightning module pairing a UNet mask generator with a discriminator.

    The generator is trained with a Dice loss plus half of an adversarial
    term; the discriminator is trained to separate reference masks from
    generated ones.  ``hparams`` must expose ``patch_size``, ``batch_size``,
    ``num_workers`` and ``lr``; these are the attributes this class reads.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # NOTE(review): 258 breaks the doubling pattern of the other channel
        # widths — possibly a typo for 256.  Confirm before changing, since
        # it alters the architecture.
        self.generator = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(64, 128, 258, 512, 1024),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
            dropout=0,
        )
        self.discriminator = Discriminator(
            in_shape=self.hparams.patch_size,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
        )
        # Output of the last generator step, reused by the discriminator step.
        self.generated_masks = None
        # 2D slices collected during validation, shown as a grid at epoch end.
        self.sample_masks = []

    # Data setup
    def setup(self, stage):
        """Read the CSV file list, split it and build the cached datasets."""
        data_df = pd.read_csv(
            '/data/shared/prostate/yale_prostate/input_lists/MR_yale.csv')
        train_imgs = data_df['IMAGE'][0:295].tolist()
        train_masks = data_df['SEGM'][0:295].tolist()
        train_dicts = [{
            'image': image,
            'mask': mask
        } for (image, mask) in zip(train_imgs, train_masks)]
        train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.15)

        # Basic transforms (the same random crop is used for both splits)
        data_keys = ["image", "mask"]
        data_transforms = Compose([
            LoadNiftid(keys=data_keys),
            AddChanneld(keys=data_keys),
            NormalizeIntensityd(keys="image"),
            RandCropByPosNegLabeld(keys=data_keys,
                                   label_key="mask",
                                   spatial_size=self.hparams.patch_size,
                                   num_samples=4,
                                   image_key="image"),
        ])
        self.train_dataset = monai.data.CacheDataset(
            data=train_dicts,
            transform=Compose([data_transforms,
                               ToTensord(keys=data_keys)]),
            cache_rate=1.0)
        self.val_dataset = monai.data.CacheDataset(
            data=val_dicts,
            transform=Compose([data_transforms,
                               ToTensord(keys=data_keys)]),
            cache_rate=1.0)

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return monai.data.DataLoader(self.train_dataset,
                                     batch_size=self.hparams.batch_size,
                                     shuffle=True,
                                     num_workers=self.hparams.num_workers)

    def val_dataloader(self):
        """Return the loader over the validation dataset."""
        return monai.data.DataLoader(self.val_dataset,
                                     batch_size=self.hparams.batch_size,
                                     num_workers=self.hparams.num_workers)

    # Training setup
    def forward(self, image):
        """Run the generator on ``image``."""
        return self.generator(image)

    def generator_loss(self, y_hat, y):
        """Return the Dice loss between generator output ``y_hat`` and mask ``y``."""
        dice_loss = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
        return dice_loss(y_hat, y)

    def adversarial_loss(self, y_hat, y):
        """Return the binary cross-entropy between ``y_hat`` and targets ``y``."""
        return F.binary_cross_entropy(y_hat, y)

    @staticmethod
    def _as_float_on(tensor, device):
        """Return ``tensor`` as float32 on ``device``.

        The previous ``.type(torch.FloatTensor).cuda(...)`` chain first
        converted to a CPU tensor type and then moved the result to CUDA,
        which also failed on CPU-only runs.
        """
        return tensor.to(device=device, dtype=torch.float32)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one generator (``optimizer_idx == 0``) or discriminator step.

        Returns:
            ``{'loss': tensor}`` holding the loss of the optimizer being
            stepped.
        """
        inputs, labels = batch['image'], batch['mask']
        batch_size = inputs.size(0)
        # Every auxiliary tensor follows the input's device, so the step
        # works both on GPU and on CPU.
        device = inputs.device

        # Generator training
        if optimizer_idx == 0:
            self.generated_masks = self(inputs)

            # Loss from difference between real and generated masks
            g_loss = self.generator_loss(self.generated_masks, labels)

            # Loss from discriminator: the generator wants the discriminator
            # to be wrong, so generated masks are paired with an all-ones
            # target.
            # NOTE(review): argmax yields integer indices, so this term does
            # not appear to back-propagate into the generator — confirm that
            # this is intended.
            fake_labels = torch.ones(batch_size, 1, device=device)
            d_loss = self.adversarial_loss(
                self.discriminator(
                    self._as_float_on(self.generated_masks.argmax(1), device)),
                fake_labels)

            avg_loss = g_loss + 0.5 * d_loss

            self.logger.log_metrics({"g_train/g_loss": g_loss},
                                    self.global_step)
            self.logger.log_metrics({"g_train/d_loss": d_loss},
                                    self.global_step)
            self.logger.log_metrics({"g_train/tot_loss": avg_loss},
                                    self.global_step)
            return {'loss': avg_loss}

        # Discriminator training
        # Learning real masks (target: all ones)
        real_labels = torch.ones(batch_size, 1, device=device)
        real_loss = self.adversarial_loss(
            self.discriminator(self._as_float_on(labels.squeeze(1), device)),
            real_labels)

        # Learning "fake" masks (target: all zeros); the generator output is
        # detached for this step.
        fake_labels = torch.zeros(batch_size, 1, device=device)
        fake_loss = self.adversarial_loss(
            self.discriminator(
                self._as_float_on(
                    self.generated_masks.argmax(1).detach(), device)),
            fake_labels)

        avg_loss = real_loss + fake_loss

        self.logger.log_metrics({"d_train/real_loss": real_loss},
                                self.global_step)
        self.logger.log_metrics({"d_train/fake_loss": fake_loss},
                                self.global_step)
        self.logger.log_metrics({"d_train/tot_loss": avg_loss},
                                self.global_step)
        return {'loss': avg_loss}

    def configure_optimizers(self):
        """Return the generator and discriminator Adam optimizers (no schedulers)."""
        lr = self.hparams.lr
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=lr)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        return [g_optimizer, d_optimizer], []

    def validation_step(self, batch, batch_idx):
        """Return the generator's Dice loss; after epoch 0 also keep a sample slice."""
        inputs, labels = (
            batch["image"],
            batch["mask"],
        )
        outputs = self(inputs)

        # Sample masks: middle slice (last axis) of the first item's arg-max
        if self.current_epoch != 0:
            middle = int(outputs[0].argmax(0).shape[2] / 2)
            image = outputs[0].argmax(0)[:, :, middle].unsqueeze(0).detach()
            self.sample_masks.append(image)

        loss = self.generator_loss(outputs, labels)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and, after epoch 0, the sample grid."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.logger.log_metrics({"val/loss": avg_loss}, self.current_epoch)

        if self.current_epoch != 0:
            grid = torchvision.utils.make_grid(self.sample_masks)
            self.logger.experiment.add_image('sample_masks', grid,
                                             self.current_epoch)
            self.sample_masks = []

        return {"val_loss": avg_loss}
def train_process(fast=False):
    """Train a 3D UNet for a fixed number of epochs, validating each epoch.

    Uses the module-level ``train_files`` / ``val_files`` lists and the
    transforms returned by ``transformations()``.  Whenever the validation
    mean Dice improves, the model weights are saved to ``'sLUMRTL644.pth'``.

    Args:
        fast: Accepted for interface compatibility; not read by this function.

    Returns:
        ``(epoch_num, epoch_loss_values, metric_values,
        best_metrics_epochs_and_time)``.  ``epoch_loss_values`` holds the
        average training loss per epoch and ``metric_values`` the validation
        mean Dice per validated epoch.  ``best_metrics_epochs_and_time`` is a
        list of three lists: best metric values, the epochs at which they
        occurred, and a third list this function never fills.
    """
    epoch_num = 10
    val_interval = 1
    train_trans, val_trans = transformations()
    train_ds = Dataset(data=train_files, transform=train_trans)
    val_ds = Dataset(data=val_files, transform=val_trans)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n1 = 16  # channel width of the first UNet level; doubled at each level
    model = UNet(dimensions=3,
                 in_channels=1,
                 out_channels=2,
                 channels=(n1 * 1, n1 * 2, n1 * 4, n1 * 8, n1 * 16),
                 strides=(2, 2, 2, 2)).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)
    optimizer = torch.optim.Adam(model.parameters(),
                                 1e-4,
                                 weight_decay=1e-5)

    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]
    epoch_loss_values = []
    metric_values = []
    # Must exist before the first validation: the "no improvement" branch
    # increments it, which raised UnboundLocalError when the first metric
    # did not beat the initial best (e.g. a NaN Dice).
    epochs_no_improve = 0
    # Number of batches per epoch; constant, so computed once.
    epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)

    for epoch in range(epoch_num):
        print(f"epoch {epoch + 1}/{epoch_num}")

        # ---- training pass
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs = batch_data['image'].to(device)
            labels = batch_data['label'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # ---- validation pass
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_inputs = val_data['image'].to(device)
                    val_labels = val_data['label'].to(device)
                    val_outputs = post_pred(model(val_inputs))
                    val_labels = post_label(val_labels)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)

                if metric > best_metric:
                    best_metric = metric
                    epochs_no_improve = 0
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    torch.save(model.state_dict(), 'sLUMRTL644.pth')
                else:
                    epochs_no_improve += 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    return epoch_num, epoch_loss_values, metric_values, best_metrics_epochs_and_time
# Plain PyTorch setup: a UNet trained with DiceLoss and the Adam optimizer.
device = torch.device("cuda:0")

# Training schedule.
max_epochs = 6
learning_rate = 1e-4
val_interval = 2

# Loss, metric and post-processing objects (none depend on the model).
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

# 3D UNet: one input channel, two output channels, placed on `device`.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
optimizer = Adam(model.parameters(), learning_rate)

# Bookkeeping containers, all starting empty; -1 marks "no value yet".
best_metric = best_metric_epoch = -1
best_metrics_epochs_and_time = [[] for _ in range(3)]
epoch_loss_values, metric_values, epoch_times = [], [], []

# The wall-clock start is taken before the TensorBoard writer is created.
total_start = time.time()
writer = SummaryWriter(log_dir=out_dir)
class UNet2DRestorer(AbstractBaseLearner):
    """Restorer based on the U-Net architecture."""

    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 roi_size: Tuple[int, int] = (384, 384),
                 num_epochs: int = 400,
                 batch_sizes: Tuple[int, int, int, int] = (8, 1, 1, 1),
                 num_workers: Tuple[int, int, int, int] = (4, 4, 1, 1),
                 validation_step: int = 2,
                 sliding_window_batch_size: int = 4,
                 experiment_name: str = "",
                 model_name: str = "best_model",
                 seed: int = 4294967295,
                 working_dir: str = '.',
                 stdout: TextIOWrapper = sys.stdout,
                 stderr: TextIOWrapper = sys.stderr):
        """Constructor.

        @param in_channels: int, optional: default = 1
            Number of channels in the input (e.g. 1 for gray-value images).

        @param out_channels: int, optional: default = 1
            Number of channels in the output.

        @param roi_size: Tuple[int, int], optional: default = (384, 384)
            Crop area (and input size of the U-Net network) used for training and
            validation/prediction.

        @param num_epochs: int, optional: default = 400
            Number of epochs for training.

        @param batch_sizes: Tuple[int, int, int, int], optional: default = (8, 1, 1, 1)
            Batch sizes for training, validation, testing, and prediction, respectively.

        @param num_workers: Tuple[int, int, int, int], optional: default = (4, 4, 1, 1)
            Number of workers for training, validation, testing, and prediction,
            respectively.

        @param validation_step: int, optional: default = 2
            Number of training epochs between two validation runs.

        @param sliding_window_batch_size: int, optional: default = 4
            Number of batches for sliding window inference during validation and
            prediction.

        @param experiment_name: str, optional: default = ""
            Name of the experiment that maps to the folder that contains training
            information (to be used by tensorboard). Please note, current datetime
            will be appended.

        @param model_name: str, optional: default = "best_model"
            Name of the file that stores the best model. Only the stem is used:
            current datetime and the ".pth" extension will be appended.

        @param seed: int, optional; default = 4294967295
            Random seed passed to `set_determinism()`.

        @param working_dir: str, optional, default = "."
            Working folder where to save the model weights and the logs for
            tensorboard.

        @param stdout: TextIOWrapper, optional: default = sys.stdout
            Stream that receives progress messages.

        @param stderr: TextIOWrapper, optional: default = sys.stderr
            Stream reserved for error output.
        """

        # Call base constructor
        super().__init__()

        # Standard pipe wrappers
        self._stdout = stdout
        self._stderr = stderr

        # Device (initialize as "cpu"; replaced by a torch.device later on)
        self._device = "cpu"

        # Input and output channels
        self._in_channels = in_channels
        self._out_channels = out_channels

        # Define hyper parameters
        self._roi_size = roi_size
        self._training_batch_size = batch_sizes[0]
        self._validation_batch_size = batch_sizes[1]
        self._test_batch_size = batch_sizes[2]
        self._prediction_batch_size = batch_sizes[3]
        self._training_num_workers = num_workers[0]
        self._validation_num_workers = num_workers[1]
        self._test_num_workers = num_workers[2]
        self._prediction_num_workers = num_workers[3]
        self._n_epochs = num_epochs
        self._validation_step = validation_step
        self._sliding_window_batch_size = sliding_window_batch_size

        # Set monai seed
        set_determinism(seed=seed)

        # All file names
        self._train_image_names: list = []
        self._train_target_names: list = []
        self._validation_image_names: list = []
        self._validation_target_names: list = []
        self._test_image_names: list = []
        self._test_target_names: list = []
        self._prediction_image_names: list = []

        # Transforms
        self._train_image_transforms = None
        self._train_target_transforms = None
        self._validation_image_transforms = None
        self._validation_target_transforms = None
        self._test_image_transforms = None
        self._test_target_transforms = None
        self._prediction_image_transforms = None
        self._validation_post_transforms = None
        self._test_post_transforms = None
        self._prediction_post_transforms = None

        # Datasets and data loaders
        self._train_dataset = None
        self._train_dataloader = None
        self._validation_dataset = None
        self._validation_dataloader = None
        self._test_dataset = None
        self._test_dataloader = None
        self._prediction_dataset = None
        self._prediction_dataloader = None

        # Model architecture, loss function, metric and optimizer
        self._model = None
        self._training_loss_function = None
        self._optimizer = None
        self._validation_metric = None

        # Working directory, model file name and experiment name for Tensorboard logs.
        # The file names will be redefined at the beginning of the training.
        self._working_dir = Path(working_dir).resolve()
        self._raw_experiment_name = experiment_name
        self._raw_model_file_name = model_name

        # Keep track of the full path of the best model
        self._best_model = ''

        # Keep track of last error message
        self._message = ""

    def train(self) -> bool:
        """Run the complete training (with periodic validation).

        The model with the lowest validation loss is saved to disk and its path
        can be retrieved with `get_best_model_path()`.

        @return True if the training was successful, False otherwise. In the
            latter case, the reason can be retrieved with `get_message()`.
        """

        # Free memory on the GPU
        self._clear_session()

        # Check that the data is set properly
        if len(self._train_image_names) == 0 or \
                len(self._train_target_names) == 0 or \
                len(self._validation_image_names) == 0 or \
                len(self._validation_target_names) == 0:
            self._message = "No training/validation data found."
            return False

        if len(self._train_image_names) != len(self._train_target_names):
            self._message = "The number of training images does not match the number of training targets."
            return False

        if len(self._validation_image_names) != len(
                self._validation_target_names):
            self._message = "The number of validation images does not match the number of validation targets."
            return False

        # Define the transforms
        self._define_transforms()

        # Define the datasets and data loaders. This fails (and resets all
        # loaders to None) if any of the training, validation or test lists is
        # empty, so the result must be checked before iterating.
        if not self._define_training_data_loaders():
            self._message = "Could not instantiate the training, validation and test data loaders."
            return False

        # Instantiate the model
        self._define_model()

        # Define the loss function
        self._define_training_loss()

        # Define the optimizer (with default parameters)
        self._define_optimizer()

        # Define the validation metric
        self._define_validation_metric()

        # Define experiment name and model name
        experiment_name, model_file_name = \
            self._prepare_experiment_and_model_names()

        # Keep track of the best model file name
        self._best_model = model_file_name

        # Enter the main training loop
        lowest_validation_loss = np.inf
        lowest_validation_epoch = -1

        # Initialize TensorBoard's SummaryWriter
        writer = SummaryWriter(experiment_name)

        try:
            for epoch in range(self._n_epochs):

                # Inform
                self._print_header(f"Epoch {epoch + 1}/{self._n_epochs}")

                # Training
                epoch_loss = self._run_training_epoch()
                print(f"Average loss = {epoch_loss:.4f}", file=self._stdout)
                writer.add_scalar("average_train_loss", epoch_loss, epoch + 1)

                # Validation (only every `validation_step` epochs)
                if (epoch + 1) % self._validation_step != 0:
                    continue

                self._print_header("Validation")
                validation_loss = self._run_validation()

                # Print summary
                print(f"Validation loss = {validation_loss:.4f} ",
                      file=self._stdout)

                # Do we have the best metric so far?
                if validation_loss < lowest_validation_loss:
                    lowest_validation_loss = validation_loss
                    lowest_validation_epoch = epoch + 1
                    torch.save(self._model.state_dict(), model_file_name)
                    print(
                        f"New lowest validation loss = {lowest_validation_loss:.4f} at epoch: {lowest_validation_epoch}",
                        file=self._stdout)
                    print(f"Saved best model '{Path(model_file_name).name}'",
                          file=self._stdout)

                # Add validation loss and metrics to log
                writer.add_scalar("val_mean_loss", validation_loss, epoch + 1)

            print(
                f"Training completed. Lowest validation loss = {lowest_validation_loss:.4f} at epoch: {lowest_validation_epoch}",
                file=self._stdout)
        finally:
            # Flush and release the TensorBoard log even if training failed
            writer.close()

        # Return success
        return True

    def _run_training_epoch(self) -> float:
        """Train the model for one epoch and return the average batch loss."""

        # Switch to training mode
        self._model.train()

        # Number of batches per epoch (constant, hence computed only once)
        num_batches = len(self._train_dataloader)

        epoch_loss = 0.0
        step = 0
        for step, batch_data in enumerate(self._train_dataloader, start=1):

            # Get the next batch and move it to device
            inputs = batch_data[0].to(self._device)
            labels = batch_data[1].to(self._device)

            # Zero the gradient buffers
            self._optimizer.zero_grad()

            # Forward pass
            outputs = self._model(inputs)

            # Calculate the loss
            loss = self._training_loss_function(outputs, labels)

            # Back-propagate
            loss.backward()

            # Update weights (optimize)
            self._optimizer.step()

            # Update and store metrics
            batch_loss = loss.item()
            epoch_loss += batch_loss
            print(f"Batch {step}/{num_batches}: train_loss = {batch_loss:.4f}",
                  file=self._stdout)

        return epoch_loss / max(step, 1)

    def _run_validation(self) -> float:
        """Run the model on the validation set and return the average batch loss."""

        # Switch to evaluation mode
        self._model.eval()

        loss_sum = 0.0
        num_batches = 0

        # Make sure not to update the gradients
        with torch.no_grad():
            for val_data in self._validation_dataloader:

                # Get the next batch and move it to device
                val_images = val_data[0].to(self._device)
                val_labels = val_data[1].to(self._device)

                # Apply sliding inference over ROI size
                val_outputs = sliding_window_inference(
                    val_images, self._roi_size,
                    self._sliding_window_batch_size, self._model)
                val_outputs = self._validation_post_transforms(val_outputs)

                # The validation loss is the training loss function applied
                # to the validation batch
                val_loss = self._training_loss_function(
                    val_outputs, val_labels)

                num_batches += 1
                loss_sum += val_loss.item()

        return loss_sum / max(num_batches, 1)

    def _load_model_weights(self, model_path: Union[Path, str]) -> bool:
        """Load the weights stored in `model_path` into the model in memory.

        @return True if the weights could be loaded, False otherwise (the
            reason is stored in the last error message).
        """
        # The weights must be compatible with the model in memory. The failure
        # causes are varied (missing file, corrupted file, mismatching
        # architecture), hence the broad catch; the cause is reported.
        try:
            checkpoint = torch.load(model_path,
                                    map_location=torch.device('cpu'))
            self._model.load_state_dict(checkpoint)
        except Exception as e:
            self._message = f"Error: there was a problem loading the model! Aborting. Reason: {e}"
            return False

        print(f"Loaded best metric model {model_path}.", file=self._stdout)
        return True

    def _save_prediction(self, prediction, target_folder: Path,
                         basename: str) -> None:
        """Save `prediction` as '{target_folder}/{basename}.tif' and report it."""
        pred_file_name = os.path.join(str(target_folder), basename + '.tif')
        with TiffWriter(pred_file_name) as tif:
            tif.save(prediction)

        # Inform
        print(f"Saved {str(target_folder)}/{basename}.tif", file=self._stdout)

    def test_predict(self,
                     target_folder: Union[Path, str] = '',
                     model_path: Union[Path, str] = '') -> bool:
        """Run prediction on predefined test data.

        @param target_folder: Path|str, optional: default = ''
            Path to the folder where to store the predicted images. If not specified,
            it defaults to '{working_dir}/tests'. See constructor.

        @param model_path: Path|str, optional: default = ''
            Full path to the model to use. If omitted and a training was just run, the
            path to the model with the best metric is already stored and will be used.

            @see get_best_model_path()

        @return True if the prediction was successful, False otherwise.
        """

        # Inform
        self._print_header("Test prediction")

        # The test data loader is only created by train(); without it there
        # is nothing to iterate over.
        if self._test_dataloader is None:
            self._message = "Error: the test data loader is not defined! Aborting."
            return False

        # Get the device
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # If the model is not in memory, instantiate it first
        if self._model is None:
            self._define_model()

        # If the path to the best model was not set, use current one (if set)
        if model_path == '':
            model_path = self.get_best_model_path()

        # Try loading the model weights
        if not self._load_model_weights(model_path):
            return False

        # If the target folder is not specified, set it to the standard
        # test predictions output folder
        if target_folder == '':
            target_folder = Path(self._working_dir) / "tests"
        else:
            target_folder = Path(target_folder)
        target_folder.mkdir(parents=True, exist_ok=True)

        # Switch to evaluation mode
        self._model.eval()

        # NOTE(review): output names are taken from `_test_image_names` by
        # batch index, which presumably assumes a test batch size of 1.
        indx = 0

        # Make sure not to update the gradients
        with torch.no_grad():
            for test_data in self._test_dataloader:

                # Get the next batch and move it to device (the targets are
                # not needed for prediction)
                test_images = test_data[0].to(self._device)

                # Apply sliding inference over ROI size
                test_outputs = sliding_window_inference(
                    test_images, self._roi_size,
                    self._sliding_window_batch_size, self._model)
                test_outputs = self._test_post_transforms(test_outputs)

                # The ToNumpy() transform already causes the Tensor
                # to be gathered from the GPU to the CPU
                pred = test_outputs.squeeze()

                # Prepare the output file name
                basename = os.path.splitext(
                    os.path.basename(self._test_image_names[indx]))[0]
                basename = basename.replace('train_', 'pred_')

                # Save the prediction as tiff file
                self._save_prediction(pred, target_folder, basename)

                # Update the index
                indx += 1

        # Inform
        print("Test prediction completed.", file=self._stdout)

        # Return success
        return True

    def predict(self, input_folder: Union[Path, str],
                target_folder: Union[Path, str],
                model_path: Union[Path, str]) -> bool:
        """Run prediction.

        @param input_folder: Path|str
            Path to the folder that contains the images ("*.tif") to predict on.

        @param target_folder: Path|str
            Path to the folder where to store the predicted images.

        @param model_path: Path|str
            Full path to the model to use.

        @return True if the prediction was successful, False otherwise.
        """
        # Inform
        self._print_header("Prediction")

        # Get the device
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # If the model is not in memory, instantiate it first
        if self._model is None:
            self._define_model()

        # Try loading the model weights
        if not self._load_model_weights(model_path):
            return False

        # Make sure the target folder exists
        if isinstance(target_folder, str) and target_folder == '':
            self._message = "Error: please specify a valid target folder! Aborting."
            return False

        target_folder = Path(target_folder)
        target_folder.mkdir(parents=True, exist_ok=True)

        # The transforms are otherwise only created by train(); make sure
        # they exist when predicting with a freshly constructed object.
        if self._prediction_image_transforms is None or \
                self._prediction_post_transforms is None:
            self._define_transforms()

        # Get prediction dataloader
        if not self._define_prediction_data_loaders(input_folder):
            self._message = "Error: could not instantiate prediction dataloader! Aborting."
            return False

        # Switch to evaluation mode
        self._model.eval()

        # NOTE(review): output names are taken from `_prediction_image_names`
        # by batch index, which presumably assumes a prediction batch size of 1.
        indx = 0

        # Make sure not to update the gradients
        with torch.no_grad():
            for prediction_data in self._prediction_dataloader:

                # Get the next batch and move it to device
                prediction_images = prediction_data.to(self._device)

                # Apply sliding inference over ROI size
                prediction_outputs = sliding_window_inference(
                    prediction_images, self._roi_size,
                    self._sliding_window_batch_size, self._model)
                prediction_outputs = self._prediction_post_transforms(
                    prediction_outputs)

                # The ToNumpy() transform already causes the Tensor
                # to be gathered from the GPU to the CPU
                pred = prediction_outputs.squeeze()

                # Prepare the output file name
                basename = os.path.splitext(
                    os.path.basename(self._prediction_image_names[indx]))[0]
                basename = "pred_" + basename

                # Save the prediction as tiff file
                self._save_prediction(pred, target_folder, basename)

                # Update the index
                indx += 1

        # Inform
        print("Prediction completed.", file=self._stdout)

        # Return success
        return True

    def set_training_data(self, train_image_names, train_mask_names,
                          val_image_names, val_mask_names, test_image_names,
                          test_mask_names) -> None:
        """Set all training files names.

        @param train_image_names: list
            List of training image names.

        @param train_mask_names: list
            List of training mask names.

        @param val_image_names: list
            List of validation image names.

        @param val_mask_names: list
            List of validation mask names.

        @param test_image_names: list
            List of test image names.

        @param test_mask_names: list
            List of test mask names.

        @raise ValueError if the number of images and masks of any of the
            three sets do not match.
        """

        # First validate all data
        if len(train_image_names) != len(train_mask_names):
            raise ValueError(
                "The number of training images does not match the number of training masks."
            )

        if len(val_image_names) != len(val_mask_names):
            raise ValueError(
                "The number of validation images does not match the number of validation masks."
            )

        if len(test_image_names) != len(test_mask_names):
            raise ValueError(
                "The number of test images does not match the number of test masks."
            )

        # Training data
        self._train_image_names = train_image_names
        self._train_target_names = train_mask_names

        # Validation data
        self._validation_image_names = val_image_names
        self._validation_target_names = val_mask_names

        # Test data
        self._test_image_names = test_image_names
        self._test_target_names = test_mask_names

    @staticmethod
    def _prediction_to_label_tiff_image(prediction):
        """Convert a (channels-first) one-hot prediction to a uint16 label image."""

        # Convert to label image
        label_img = one_hot_stack_to_label_image(
            prediction,
            first_index_is_background=True,
            channels_first=True,
            dtype=np.uint16)

        return label_img

    def _define_transforms(self):
        """Define and initialize all data transforms.

          * training set images transform
          * training set targets transform
          * validation set images transform
          * validation set targets transform
          * validation set images post-transform
          * test set images transform
          * test set targets transform
          * test set images post-transform
          * prediction set images transform
          * prediction set images post-transform

        The transforms are stored in the corresponding attributes; nothing
        is returned.
        """

        def scaled_loading() -> list:
            """Return fresh load / rescale / add-channel transforms."""
            return [
                LoadImage(image_only=True),
                ScaleIntensityRange(0, 65535, 0.0, 1.0, clip=False),
                AddChannel(),
            ]

        def training_transforms() -> Compose:
            """Loading followed by random crop and random 90-degree rotation."""
            return Compose(scaled_loading() + [
                RandSpatialCrop(self._roi_size, random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 1)),
                ToTensor(),
            ])

        def evaluation_transforms() -> Compose:
            """Loading without any augmentation."""
            return Compose(scaled_loading() + [ToTensor()])

        # Define transforms for training. Images and targets get separate
        # instances of the same pipeline.
        # NOTE(review): this relies on the dataset synchronizing the random
        # state of image and target transforms -- confirm for the MONAI
        # version in use.
        self._train_image_transforms = training_transforms()
        self._train_target_transforms = training_transforms()

        # Define transforms for validation
        self._validation_image_transforms = evaluation_transforms()
        self._validation_target_transforms = evaluation_transforms()

        # Define transforms for testing
        self._test_image_transforms = evaluation_transforms()
        self._test_target_transforms = evaluation_transforms()

        # Define transforms for prediction
        # NOTE(review): unlike all other image transforms, no
        # ScaleIntensityRange is applied here -- confirm this is intended.
        self._prediction_image_transforms = Compose(
            [LoadImage(image_only=True),
             AddChannel(),
             ToTensor()])

        # Post transforms
        self._validation_post_transforms = Compose([Identity()])
        self._test_post_transforms = Compose(
            [ToNumpy(), ScaleIntensity(0, 65535)])
        self._prediction_post_transforms = Compose(
            [ToNumpy(), ScaleIntensity(0, 65535)])

    @staticmethod
    def _data_loader_options(num_workers: int) -> dict:
        """Return the platform-dependent keyword arguments for a DataLoader.

        On Windows persistent workers are used and memory pinning is disabled;
        elsewhere memory is pinned whenever CUDA is available. Persistent
        workers are only requested when there is at least one worker, since
        DataLoader rejects `persistent_workers=True` with `num_workers=0`.
        """
        on_windows = sys.platform == 'win32'
        return {
            "num_workers": num_workers,
            "persistent_workers": on_windows and num_workers > 0,
            "pin_memory": (not on_windows) and torch.cuda.is_available(),
        }

    def _define_training_data_loaders(self) -> bool:
        """Initialize training datasets and data loaders.

        @Note: in Windows, it is essential to set `persistent_workers=True` in the
        data loaders!

        @return True if datasets and data loaders could be instantiated, False
            otherwise (i.e. if any of the training, validation or test file
            lists is empty).
        """

        if len(self._train_image_names) == 0 or \
                len(self._train_target_names) == 0 or \
                len(self._validation_image_names) == 0 or \
                len(self._validation_target_names) == 0 or \
                len(self._test_image_names) == 0 or \
                len(self._test_target_names) == 0:

            self._train_dataset = None
            self._train_dataloader = None
            self._validation_dataset = None
            self._validation_dataloader = None
            self._test_dataset = None
            self._test_dataloader = None

            return False

        # Training
        self._train_dataset = ArrayDataset(self._train_image_names,
                                           self._train_image_transforms,
                                           self._train_target_names,
                                           self._train_target_transforms)
        self._train_dataloader = DataLoader(
            self._train_dataset,
            batch_size=self._training_batch_size,
            shuffle=False,
            **self._data_loader_options(self._training_num_workers))

        # Validation
        self._validation_dataset = ArrayDataset(
            self._validation_image_names, self._validation_image_transforms,
            self._validation_target_names, self._validation_target_transforms)
        self._validation_dataloader = DataLoader(
            self._validation_dataset,
            batch_size=self._validation_batch_size,
            shuffle=False,
            **self._data_loader_options(self._validation_num_workers))

        # Test
        self._test_dataset = ArrayDataset(self._test_image_names,
                                          self._test_image_transforms,
                                          self._test_target_names,
                                          self._test_target_transforms)
        self._test_dataloader = DataLoader(
            self._test_dataset,
            batch_size=self._test_batch_size,
            shuffle=False,
            **self._data_loader_options(self._test_num_workers))

        return True

    def _define_prediction_data_loaders(
            self, prediction_folder_path: Union[Path, str]) -> bool:
        """Initialize prediction datasets and data loaders.

        @Note: in Windows, it is essential to set `persistent_workers=True` in the
        data loaders!

        @param prediction_folder_path: Path|str
            Folder that is scanned for "*.tif" images to predict on.

        @return True if datasets and data loaders could be instantiated, False
            otherwise (the folder does not exist or contains no "*.tif" files).
        """

        # Check that the path exists
        prediction_folder_path = Path(prediction_folder_path)
        if not prediction_folder_path.is_dir():
            return False

        # Scan for images
        self._prediction_image_names = natsorted(
            glob(str(prediction_folder_path / "*.tif")))

        if len(self._prediction_image_names) == 0:
            self._prediction_dataset = None
            self._prediction_dataloader = None
            return False

        # Prediction (uses the prediction-specific batch size and workers)
        self._prediction_dataset = Dataset(self._prediction_image_names,
                                           self._prediction_image_transforms)
        self._prediction_dataloader = DataLoader(
            self._prediction_dataset,
            batch_size=self._prediction_batch_size,
            shuffle=False,
            **self._data_loader_options(self._prediction_num_workers))

        return True

    def get_message(self):
        """Return last error message."""
        return self._message

    def get_best_model_path(self):
        """Return the full path to the best model."""
        return self._best_model

    def _clear_session(self) -> None:
        """Try clearing cache on the GPU."""
        # `_device` is either the initial "cpu" string or a torch.device;
        # normalizing it makes the comparison reliable in both cases.
        if torch.device(self._device).type != "cpu":
            torch.cuda.empty_cache()

    def _define_model(self) -> None:
        """Instantiate the U-Net architecture."""

        # Select the device
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device '{self._device}'.", file=self._stdout)

        # Try to free memory on the GPU
        self._clear_session()

        # Monai's UNet
        self._model = UNet(dimensions=2,
                           in_channels=self._in_channels,
                           out_channels=self._out_channels,
                           channels=(16, 32, 64, 128, 256),
                           strides=(2, 2, 2, 2),
                           num_res_units=2).to(self._device)

    def _define_training_loss(self) -> None:
        """Define the loss function."""

        # Use the MAE loss
        self._training_loss_function = L1Loss()

    def _define_optimizer(self,
                          learning_rate: float = 1e-3,
                          weight_decay: float = 1e-4) -> None:
        """Define the optimizer.

        Nothing is done if the model has not been instantiated yet.

        @param learning_rate: float, optional, default = 1e-3
            Initial learning rate for the optimizer.

        @param weight_decay: float, optional, default = 1e-4
            Weight decay passed to the optimizer.
        """

        if self._model is None:
            return

        self._optimizer = Adam(self._model.parameters(),
                               learning_rate,
                               weight_decay=weight_decay,
                               amsgrad=True)

    def _define_validation_metric(self):
        """Define the metric for validation function."""

        self._validation_metric = DiceMetric(include_background=True,
                                             reduction="mean")

    def _prepare_experiment_and_model_names(self) -> Tuple[str, str]:
        """Prepare the experiment and model names.

        The "runs" subfolder of the working directory is created if needed.

        @return experiment_file_name, model_file_name
            Current date time is appended and the full path is returned.
        """

        # Make sure the "runs" subfolder exists
        runs_dir = Path(self._working_dir) / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)

        # Current date and time
        date_time = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Experiment name
        if self._raw_experiment_name != "":
            experiment_name = f"{self._raw_experiment_name}_{date_time}"
        else:
            experiment_name = f"{date_time}"
        experiment_name = runs_dir / experiment_name

        # Best model file name (any extension in the raw name is dropped)
        name = Path(self._raw_model_file_name).stem
        model_file_name = runs_dir / f"{name}_{date_time}.pth"

        return str(experiment_name), str(model_file_name)

    def _print_header(self, header_text, line_length=80, file=None):
        """Print a section header.

        @param header_text: str
            Text printed between two separator lines.

        @param line_length: int, optional: default = 80
            Length of the separator lines.

        @param file: optional: default = None
            Stream to print to; the stdout passed to the constructor is used
            if omitted.
        """
        if file is None:
            file = self._stdout
        print(f"{line_length * '-'}", file=file)
        print(f"{header_text}", file=file)
        print(f"{line_length * '-'}", file=file)
def main(config):
    """Train a 3D U-Net on PET/CT data as described by `config`.

    The function reads paths, preprocessing, model and training settings from
    the nested `config` mapping, builds the training and validation pipelines,
    and runs a supervised trainer that validates after every epoch.

    @param config: nested mapping with the sections 'path', 'preprocessing',
        'model' and 'training'. A missing key raises KeyError.
    """
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config['path']['csv_path']

    # NOTE(review): read but currently unused (kept so that a missing key is
    # still reported): if None, trained from scratch
    trained_model_path = config['path']['trained_model_path']

    training_model_folder = os.path.join(
        config['path']['training_model_folder'], now)  # '/path/to/folder'
    logdir = os.path.join(training_model_folder, 'logs')
    # Creating the log folder also creates the (parent) model folder
    os.makedirs(logdir, exist_ok=True)
    # NOTE(review): the handlers below write to "./runs/", not to
    # `training_model_folder` / `logdir` -- confirm which one is intended.

    # PET CT scan params
    image_shape = tuple(config['preprocessing']['image_shape'])  # (x, y, z)
    in_channels = config['preprocessing']['in_channels']
    # (4.8, 4.8, 4.8), in millimeter, (x, y, z)
    voxel_spacing = tuple(config['preprocessing']['voxel_spacing'])

    # NOTE(review): the following settings are read but currently unused
    data_augment = config['preprocessing']['data_augment']
    resize = config['preprocessing']['resize']
    origin = config['preprocessing']['origin']
    normalize = config['preprocessing']['normalize']
    number_class = config['preprocessing']['number_class']

    # CNN params
    architecture = config['model']['architecture']  # 'unet' or 'vnet'

    cnn_params = config['model'][architecture]['cnn_params']
    # transform list to tuple (in place)
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)
    # NOTE(review): `cnn_params` is not passed to the network below.

    # Training params
    epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    shuffle = config['training']['shuffle']
    # NOTE(review): `opt_params` is not passed to the optimizer below.
    opt_params = config['training']["optimizer"]["opt_params"]

    # Get Data
    DM = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, test_images_paths = \
        DM.get_train_val_test(wrap_with_dict=True)

    def build_transforms(augment):
        """Return a fresh preprocessing pipeline, optionally with augmentation."""
        steps = [
            # read img + meta info
            LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
            Roi2Mask(keys=['pet_img', 'mask_img'],
                     method='otsu',
                     tval=0.0,
                     idx_channel=0),
            ResampleReshapeAlign(target_shape=image_shape,
                                 target_voxel_spacing=voxel_spacing,
                                 keys=['pet_img', "ct_img", 'mask_img'],
                                 origin='head',
                                 origin_key='pet_img'),
            Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        ]

        if augment:
            # user can also add other random transforms
            steps.append(
                RandAffined(keys=("pet_img", "ct_img", "mask_img"),
                            spatial_size=None,
                            prob=0.4,
                            rotate_range=(0, np.pi / 30, np.pi / 15),
                            shear_range=None,
                            translate_range=(10, 10, 10),
                            scale_range=(0.1, 0.1, 0.1),
                            mode=("bilinear", "bilinear", "nearest"),
                            padding_mode="border"))

        steps.extend([
            # normalize input
            ScaleIntensityRanged(
                keys=["pet_img"],
                a_min=0.0,
                a_max=25.0,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            ScaleIntensityRanged(
                keys=["ct_img"],
                a_min=-1000.0,
                a_max=1000.0,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            # Prepare for neural network
            # (presumably stores the concatenated modalities under "image")
            ConcatModality(keys=['pet_img', 'ct_img']),
            AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
            ToTensord(keys=["image", "mask_img"]),
        ])
        return Compose(steps)

    def prepare_batch(batch, device=None, non_blocking=False):
        """Extract (image, mask) from a batch produced by the pipelines above.

        Works both when the engine calls it with the batch only and when it
        also passes the target device.
        """
        image, mask = batch['image'], batch['mask_img']
        if device is not None:
            image = image.to(device=device, non_blocking=non_blocking)
            mask = mask.to(device=device, non_blocking=non_blocking)
        return image, mask

    # Input preprocessing
    # use data augmentation for training
    train_transforms = build_transforms(augment=True)
    # without data augmentation for validation
    val_transforms = build_transforms(augment=False)

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_images_paths,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=2)

    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_images_paths,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=batch_size,
                                       num_workers=2)

    # Model
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    # training
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/",
                                output_transform=lambda x: None),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        # The batches carry the target under 'mask_img', so the evaluator
        # needs the same batch preparation as the trainer.
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "val_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "val_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        # Run the evaluator after every epoch
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/",
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        # Honor the number of epochs requested in the configuration
        max_epochs=epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "train_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "train_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "train_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
        # AMP training is only enabled for PyTorch >= 1.6
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    trainer.run()