def test_value_class(self, input_data, expected_value):
    """Check ``DiceMetric`` with ``reduction="none"`` against reference values.

    Same test as for compute_meandice, but driven through the class interface.

    Args:
        input_data: mapping that holds ``y_pred`` and ``y`` plus any remaining
            keyword arguments, which are forwarded to the ``DiceMetric``
            constructor.
        expected_value: reference result, compared with ``atol=1e-4``.
    """
    # Work on a shallow copy: popping from the caller's dict would strip
    # ``y_pred``/``y`` from a test case that may be shared with other tests
    # or reused when this test is run again.
    metric_kwargs = dict(input_data)
    vals = {"y_pred": metric_kwargs.pop("y_pred"), "y": metric_kwargs.pop("y")}
    dice_metric = DiceMetric(**metric_kwargs, reduction="none")
    result, _ = dice_metric(**vals)
    np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
def main(tempdir):
    """Evaluate a saved 3D segmentation UNet on synthetic data and print the mean Dice.

    Five synthetic image/segmentation pairs are written to ``tempdir``, loaded
    back through an ``ImageDataset``, segmented with sliding-window inference
    and scored with ``DiceMetric``. Every prediction is also written under
    ``./output`` by a ``NiftiSaver``.

    Args:
        tempdir: directory that receives the generated ``im*.nii.gz`` and
            ``seg*.nii.gz`` files.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    val_ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location keeps the CPU fallback above usable: without it a checkpoint
    # saved from a GPU cannot be deserialized on a CPU-only machine.
    model.load_state_dict(torch.load("best_metric_model_segmentation3d_array.pth", map_location=device))
    model.eval()

    # sliding window size and batch size for windows inference (loop-invariant)
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = post_trans(val_outputs)
            value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
            # weight each batch value by its length so the final figure is a per-item mean
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            # third element of the batch is passed to the saver as metadata
            saver.save_batch(val_outputs, val_data[2])
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Run sliding-window inference over the data in ``root_dir`` and return the mean Dice.

    Image/segmentation pairs are discovered as ``im*.nii.gz`` / ``seg*.nii.gz``
    under ``root_dir``; the model weights are read from
    ``<root_dir>/best_metric_model.pth`` and thresholded predictions are
    written to ``<root_dir>/output``.

    Args:
        root_dir: directory containing the input volumes and the checkpoint.
        device: device on which the model and the batches are placed.

    Returns:
        The Dice sum weighted by ``not_nans``, divided by the ``not_nans`` total.
    """
    image_paths = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = []
    for image_path, seg_path in zip(image_paths, seg_paths):
        val_files.append({"img": image_path, "seg": seg_path})

    # transforms applied to every image / segmentation pair
    pipeline = [
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ]
    val_transforms = Compose(pipeline)
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs exactly one image per iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    network = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    model = network.to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    state = torch.load(model_filename)
    model.load_state_dict(state)
    model.eval()

    with torch.no_grad():
        dice_total = 0.0
        valid_total = 0
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=np.float32)
        for batch in val_loader:
            val_images = batch["img"].to(device)
            val_labels = batch["seg"].to(device)
            # sliding window size and batch size for windows inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            logits = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            dice = dice_metric(y_pred=logits, y=val_labels)
            valid = dice_metric.not_nans.item()
            valid_total += valid
            dice_total += dice.item() * valid
            # the metric applied its own sigmoid; binarise the raw outputs for saving
            binary_mask = (logits.sigmoid() >= 0.5).float()
            saver.save_batch(binary_mask, batch["img_meta_dict"])
        metric = dice_total / valid_total
    return metric
def evaluate(model, data_loader, device):
    """Compute mean Dice overall and for channels 0, 1 and 2 across all processes.

    Per-process sums and ``not_nans`` counts are accumulated in one tensor and
    combined with ``dist.all_reduce`` (SUM) after a barrier.

    Args:
        model: network to evaluate; switched to eval mode here.
        data_loader: iterable of batches with ``"image"`` and ``"label"`` entries.
        device: device for the accumulator tensor and the batches.

    Returns:
        Tuple ``(overall, tc, wt, et)`` of sum/count ratios, where TC, WT and ET
        are computed on channels 0, 1 and 2 respectively.
    """
    # layout: [sum, count] pairs for overall, TC, WT, ET
    totals = torch.zeros(8, dtype=torch.float, device=device)
    # channel selections matching the four pairs above
    channel_slices = (slice(None), slice(0, 1), slice(1, 2), slice(2, 3))

    model.eval()
    with torch.no_grad():
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
        for batch in data_loader:
            inputs = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)
            preds = post_trans(model(inputs))

            for slot, channels in enumerate(channel_slices):
                dice, not_nans = dice_metric(y_pred=preds[:, channels], y=labels[:, channels])
                totals[2 * slot] += dice.squeeze() * not_nans
                totals[2 * slot + 1] += not_nans

        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(totals, op=torch.distributed.ReduceOp.SUM)
        totals = totals.tolist()

    overall = totals[0] / totals[1]
    tumor_core = totals[2] / totals[3]
    whole_tumor = totals[4] / totals[5]
    enhancing = totals[6] / totals[7]
    return overall, tumor_core, whole_tumor, enhancing
def evaluate(model, data_loader, device):
    """Compute mean Dice overall and for channels 0, 1 and 2 on one process.

    The metric is built with ``sigmoid=True`` and receives the raw model
    outputs. After every call the ``not_nans`` attribute of the metric is read
    and used as the weight of that batch.

    Args:
        model: network to evaluate; switched to eval mode here.
        data_loader: iterable of batches with ``"image"`` and ``"label"`` entries.
        device: device the batches are moved to.

    Returns:
        Tuple ``(metric, metric_tc, metric_wt, metric_et)`` of weighted means,
        where TC, WT and ET are computed on channels 0, 1 and 2 respectively.
    """
    # index 0: all channels, 1: TC (channel 0), 2: WT (channel 1), 3: ET (channel 2)
    channel_slices = (slice(None), slice(0, 1), slice(1, 2), slice(2, 3))
    sums = [0.0, 0.0, 0.0, 0.0]
    counts = [0, 0, 0, 0]

    model.eval()
    with torch.no_grad():
        dice_metric = DiceMetric(include_background=True, sigmoid=True, reduction="mean")
        for batch in data_loader:
            inputs = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)
            outputs = model(inputs)

            for slot, channels in enumerate(channel_slices):
                dice = dice_metric(y_pred=outputs[:, channels], y=labels[:, channels])
                valid = dice_metric.not_nans.item()
                counts[slot] += valid
                sums[slot] += dice.item() * valid

    metric, metric_tc, metric_wt, metric_et = (
        total / count for total, count in zip(sums, counts)
    )
    return metric, metric_tc, metric_wt, metric_et
class MonaiMetrics(tools.Metrics):
    """Metrics adapter that scores predictions with a shared ``DiceMetric``."""

    # one metric object shared by every instance of the class
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    def score(self, y_true, y_pred):
        """Return the length-weighted mean of the Dice values over all pairs.

        Args:
            y_true: iterable of two-element items; the first element is used
                as the reference.
            y_pred: iterable of two-element items; the first element is used
                as the prediction.

        Returns:
            Sum of ``value.item() * len(value)`` divided by the sum of
            ``len(value)`` over all pairs.
        """
        weighted_total = 0.0
        n_items = 0
        with torch.no_grad():
            for true_item, pred_item in zip(y_true, y_pred):
                # each item unpacks into exactly two parts; only the first is scored
                target, _ = true_item
                prediction, _ = pred_item
                dice = self.dice_metric(y_pred=prediction, y=target)
                size = len(dice)
                n_items += size
                weighted_total += dice.item() * size
        return weighted_total / n_items
def __init__(
    self,
    include_background: bool = True,
    to_onehot_y: bool = False,
    mutually_exclusive: bool = False,
    sigmoid: bool = False,
    other_act: Optional[Callable] = None,
    logit_thresh: float = 0.5,
    output_transform: Callable = lambda x: x,
    device: Optional[torch.device] = None,
) -> None:
    """
    Set up the wrapped ``DiceMetric`` (mean reduction) and reset the running totals.

    Every option except ``output_transform`` and ``device`` is forwarded
    unchanged to ``DiceMetric``.

    Args:
        include_background: whether to include dice computation on the first channel
            of the predicted output. Defaults to True.
        to_onehot_y: one-hot conversion flag forwarded to ``DiceMetric``. Defaults to False.
        mutually_exclusive: if True, the output prediction will be converted into a binary
            matrix using a combination of argmax and to_onehot. Defaults to False.
        sigmoid: whether to add sigmoid function to the output prediction before computing
            Dice. Defaults to False.
        other_act: callable function to replace `sigmoid` as activation layer if needed,
            for example: `other_act = torch.tanh`. Defaults to ``None``.
        logit_thresh: the threshold value to round value to 0.0 and 1.0. Defaults to 0.5.
        output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
        device: device specification in case of distributed computation usage.

    See also:
        :py:meth:`monai.metrics.meandice.compute_meandice`
    """
    super().__init__(output_transform, device=device)
    dice_options = dict(
        include_background=include_background,
        to_onehot_y=to_onehot_y,
        mutually_exclusive=mutually_exclusive,
        sigmoid=sigmoid,
        other_act=other_act,
        logit_thresh=logit_thresh,
        reduction=MetricReduction.MEAN,
    )
    self.dice = DiceMetric(**dice_options)
    # running accumulators, starting from zero
    self._sum, self._num_examples = 0.0, 0
def __init__(
    self,
    include_background: bool = True,
    output_transform: Callable = lambda x: x,
    device: Optional[torch.device] = None,
) -> None:
    """
    Set up the wrapped ``DiceMetric`` (mean reduction) and reset the running totals.

    Args:
        include_background: whether to include dice computation on the first channel
            of the predicted output. Defaults to True.
        output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
        device: device specification in case of distributed computation usage.

    See also:
        :py:meth:`monai.metrics.meandice.compute_meandice`
    """
    super().__init__(output_transform, device=device)
    dice_options = dict(include_background=include_background, reduction=MetricReduction.MEAN)
    self.dice = DiceMetric(**dice_options)
    # running accumulators, starting from zero
    self._sum, self._num_examples = 0.0, 0
def _build_transforms(opt):
    """Return the ``(train, validation)`` transform pipelines for ``opt``.

    Both pipelines start with load / add-channel / normalize / scale and gain a
    ``Spacingd`` step only when ``opt.resolution`` is set. The training pipeline
    additionally applies the random augmentations and a fixed-size random crop
    of ``opt.patch_size``.
    """

    def preprocessing():
        # Fresh transform instances for each pipeline.
        steps = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]
        if opt.resolution is not None:
            steps.append(
                Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest'))
            )
        return steps

    # Need to concatenate multiple channels here if you want multichannel segmentation
    # Check other examples on Monai webpage.
    # NOTE: the np.random.uniform arguments below are drawn once, when the
    # transforms are constructed, and stay fixed for the rest of the run.
    augmentations = [
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
    ]

    train_transforms = Compose(preprocessing() + augmentations + [ToTensord(keys=['image', 'label'])])
    val_transforms = Compose(preprocessing() + [ToTensord(keys=['image', 'label'])])
    return train_transforms, val_transforms


def main():
    """Train the segmentation network and report Dice on train, validation and test data.

    Options come from ``Options().parse()``. After every epoch the network is
    evaluated with sliding-window inference on the validation, un-augmented
    training and test images; the weights with the best validation Dice are
    saved to ``best_metric_model.pth`` and scalars/images are logged to
    TensorBoard.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    num_gpus = len(opt.gpu_ids.split(',')) if opt.gpu_ids != '-1' else 0
    print('number of GPU:', num_gpus)

    # Data loader creation
    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # un-augmented copies, used to measure Dice on the whole training images
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training; every pass doubles the list length
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader
    train_dicts = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(train_images, train_segs)]
    train_dice_dicts = [{'image': image_name, 'label': label_name}
                        for image_name, label_name in zip(train_images_for_dice, train_segs_for_dice)]
    val_dicts = [{'image': image_name, 'label': label_name}
                 for image_name, label_name in zip(val_images, val_segs)]
    test_dicts = [{'image': image_name, 'label': label_name}
                  for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list
    train_transforms, val_transforms = _build_transforms(opt)

    pin_memory = torch.cuda.is_available()

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=opt.batch_size, shuffle=True,
                              num_workers=opt.workers, pin_memory=pin_memory)

    # create a training_dice data loader
    train_dice_ds = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(train_dice_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # create a test data loader
    test_ds = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # # try to use all the available GPUs
    # devices = get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    # loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)
    loss_function = monai.losses.DiceCELoss(sigmoid=True)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    def plot_dice(images_loader):
        """Return ``(mean dice, last images, last labels, last binarised outputs)`` for a loader.

        Must be called with gradients disabled and the network in eval mode.
        """
        metric_sum = 0.0
        metric_count = 0
        val_images = None
        val_labels = None
        val_outputs = None
        roi_size = opt.patch_size
        sw_batch_size = 4
        for data in images_loader:
            val_images, val_labels = data["image"].cuda(), data["label"].cuda()
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            # sigmoid + threshold: outputs are already binary after this step
            val_outputs = post_trans(val_outputs)
            value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
        metric = metric_sum / metric_count
        return metric, val_images, val_labels, val_outputs

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []  # validation Dice, one entry per validation run
    writer = SummaryWriter()
    epoch_len = len(train_ds) // train_loader.batch_size

    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)
                metric_values.append(metric)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, *_ = plot_dice(train_dice_loader)
                metric_test, *_ = plot_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)

                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                # val_outputs is already binarised by post_trans; running it through
                # sigmoid/threshold a second time would map every voxel to 1
                # (sigmoid(0) == 0.5 passes the ">= 0.5" test).
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def _define_validation_metric(self):
    """Create the validation metric and store it on ``self._validation_metric``.

    The metric is a ``DiceMetric`` that includes the background channel and
    uses ``"mean"`` reduction.
    """
    metric = DiceMetric(reduction="mean", include_background=True)
    self._validation_metric = metric
# Network: 3D UNet with one input channel and two output channels.
_unet = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
model = _unet.to(device)

# Optimisation: Dice + cross-entropy loss, Novograd at ten times the base
# learning rate, and a gradient scaler for mixed-precision training.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
optimizer = Novograd(model.parameters(), learning_rate * 10)
scaler = torch.cuda.amp.GradScaler()

# Validation metric plus the post-processing applied to predictions
# (argmax, then two-class one-hot) and to labels (two-class one-hot).
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

# Bookkeeping for the training loop.
best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[] for _ in range(3)]
epoch_loss_values, metric_values, epoch_times = [], [], []
total_start = time.time()
writer = SummaryWriter(log_dir=out_dir)
def evaluate(args):
    """Evaluate ``final_model.pth`` with Horovod and print the reduced Dice on rank 0.

    If ``args.dir`` does not exist, the process with local rank 0 creates it
    and fills it with 16 synthetic image/mask pairs. Each process evaluates
    its share of the data (via a ``DistributedSampler``), accumulates a Dice
    sum and a ``not_nans`` count, and the two values are combined with
    ``hvd.allreduce``.

    Args:
        args: options object; only ``args.dir`` (the data directory) is read.
    """
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    data_dir = args.dir
    if hvd.local_rank() == 0 and not os.path.exists(data_dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {data_dir} (this may take a while)")
        os.makedirs(data_dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for idx in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            image_nii = nib.Nifti1Image(im, np.eye(4))
            nib.save(image_nii, os.path.join(data_dir, f"img{idx:d}.nii.gz"))
            seg_nii = nib.Nifti1Image(seg, np.eye(4))
            nib.save(seg_nii, os.path.join(data_dir, f"seg{idx:d}.nii.gz"))

    image_paths = sorted(glob(os.path.join(data_dir, "img*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(data_dir, "seg*.nii.gz")))
    val_files = []
    for image_path, seg_path in zip(image_paths, seg_paths):
        val_files.append({"img": image_path, "seg": seg_path})

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create an evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds, num_replicas=hvd.size(), rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    forkserver_ok = (
        getattr(mp, "_supports_context", False)
        and "forkserver" in mp.get_all_start_methods()
    )
    multiprocessing_context = "forkserver" if forkserver_ok else None
    # sliding window inference needs exactly one image per iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    # create UNet on the GPU matching this process's local rank
    device = torch.device(f"cuda:{hvd.local_rank()}")
    network = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    model = network.to(device)

    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    model.eval()
    with torch.no_grad():
        # define PyTorch Tensor to record metrics result at each GPU
        # the first value is `sum` of all dice metric, the second value is `count` of not_nan items
        totals = torch.zeros(2, dtype=torch.float, device=device)
        for batch in val_loader:
            val_images = batch["img"].to(device)
            val_labels = batch["seg"].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            dice = dice_metric(y_pred=val_outputs, y=val_labels).squeeze()
            totals[0] += dice * dice_metric.not_nans
            totals[1] += dice_metric.not_nans
        # synchronizes all processes and reduce results
        print(f"metric in rank {hvd.rank()}: sum={totals[0].item()}, count={totals[1].item()}")
        avg_metric = hvd.allreduce(totals, name="mean_dice")
        if hvd.rank() == 0:
            print(f"average metric: sum={avg_metric[0].item()}, count={avg_metric[1].item()}")
            print("evaluation metric:", (avg_metric[0] / avg_metric[1]).item())
def main(tempdir):
    """Evaluate a saved 3D segmentation UNet on synthetic dict-style data and print the mean Dice.

    Five synthetic image/segmentation pairs are written to ``tempdir``, loaded
    through dictionary transforms, segmented with sliding-window inference and
    scored with ``DiceMetric``. Every prediction is also written under
    ``./output`` by a ``NiftiSaver``.

    Args:
        tempdir: directory that receives the generated ``im*.nii.gz`` and
            ``seg*.nii.gz`` files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # try to use all the available GPUs
    devices = get_devices_spec(None)
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(devices[0])

    # map_location remaps the stored tensors onto the device the model lives
    # on, so a checkpoint saved from a different device still loads.
    model.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth", map_location=devices[0]))

    # if we have multiple GPUs, set data parallel to execute sliding window inference
    if len(devices) > 1:
        model = torch.nn.DataParallel(model, device_ids=devices)

    # sliding window size and batch size for windows inference (loop-invariant)
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = post_trans(val_outputs)
            value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
            # weight each batch value by its length so the final figure is a per-item mean
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
# create UNet, DiceLoss and Adam optimizer # device = torch.device("cuda:0") # you don't need device, because Catalyst uses autoscaling model = monai.networks.nets.UNet( dimensions=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ) loss_function = monai.losses.DiceLoss(sigmoid=True) optimizer = torch.optim.Adam(model.parameters(), 1e-3) dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") class MonaiSupervisedRunner(dl.SupervisedRunner): def forward(self, batch): if self.is_train_loader: output = {self.output_key: self.model(batch[self.input_key])} elif self.is_valid_loader: roi_size = (96, 96, 96) sw_batch_size = 4 output = { self.output_key: sliding_window_inference(batch[self.input_key], roi_size, sw_batch_size, self.model) }
def segment(image, label, result, weights, resolution, patch_size, network, gpu_ids):
    """Segment one image with a trained network and write the mask to ``result``.

    The image is cropped to its foreground, intensity-normalised, optionally
    resampled to ``resolution`` and padded up to ``patch_size`` before
    sliding-window inference. The prediction is then un-padded, resampled
    back (when ``resolution`` was given), pasted into an array of the original
    shape and saved. When ``label`` is given, the Dice value is printed too.

    Args:
        image: path of the image to segment.
        label: path of the reference segmentation, or ``None``.
        result: output path for the predicted segmentation.
        weights: checkpoint reference handed to ``new_state_dict`` /
            ``new_state_dict_cpu``.
        resolution: target voxel spacing for inference, or ``None`` to keep
            the native spacing.
        patch_size: sliding-window size; also the minimum padded size.
        network: ``'nnunet'`` or ``'unetr'``.
        gpu_ids: value for ``CUDA_VISIBLE_DEVICES``; the string ``'-1'``
            forces CPU execution.

    Raises:
        ValueError: if ``network`` is not one of the supported names.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    has_label = label is not None
    if has_label:
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
        keys = ['image', 'label']
    else:
        files = [{"image": image}]
        keys = ['image']

    # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(
        image, resolution)

    # The pipeline is the same for all label/resolution combinations; only the
    # key set and the optional resampling step differ.
    transforms = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
        # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        CropForegroundd(keys=keys, source_key='image'),  # crop to the image foreground
        NormalizeIntensityd(keys=['image']),  # intensity
        ScaleIntensityd(keys=['image']),
    ]
    if resolution is not None:
        # labels are resampled with nearest-neighbour so they stay discrete
        spacing_mode = ('bilinear', 'nearest') if has_label else 'bilinear'
        transforms.append(Spacingd(keys=keys, pixdim=resolution, mode=spacing_mode))
    # pad if the image is smaller than the patch
    transforms.append(SpatialPadd(keys=keys, spatial_size=patch_size, method='end'))
    transforms.append(ToTensord(keys=keys))
    val_transforms = Compose(transforms)

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate, pin_memory=False)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    use_cpu = gpu_ids == '-1'
    if use_cpu:
        device = torch.device("cpu")
    else:
        # restrict the visible GPUs before choosing the device
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # build the network
    if network == 'nnunet':
        net = build_net()
    elif network == 'unetr':
        net = build_UNETR()
    else:
        # previously this fell through and failed later with an unbound `net`
        raise ValueError(f"unknown network {network!r}; expected 'nnunet' or 'unetr'")

    net = net.to(device)

    if use_cpu:
        net.load_state_dict(new_state_dict_cpu(weights))
    else:
        net.load_state_dict(new_state_dict(weights))

    # define sliding window size and batch size for windows inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images = val_data["image"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            if has_label:
                val_labels = val_data["label"].to(device)
                dice_metric(y_pred=val_outputs, y=val_labels)

        if has_label:
            metric = dice_metric.aggregate().item()
            print("Evaluation Metric (Dice):", metric)

        result_array = val_outputs[0].squeeze().data.cpu().numpy()
        # Remove the pad if the image was smaller than the patch in some directions
        result_array = result_array[0:resampled_size[0], 0:resampled_size[1], 0:resampled_size[2]]

        # resample back to the original resolution
        if resolution is not None:
            temp_path = 'temp_seg.nii'

            result_array_np = np.transpose(result_array, (2, 1, 0))
            result_array_temp = sitk.GetImageFromArray(result_array_np)
            result_array_temp.SetSpacing(resolution)

            # save temporary label so it can be resampled through the transform pipeline
            writer = sitk.ImageFileWriter()
            writer.SetFileName(temp_path)
            writer.Execute(result_array_temp)

            try:
                files = [{"image": temp_path}]
                files_transforms = Compose([
                    LoadImaged(keys=['image']),
                    AddChanneld(keys=['image']),
                    Spacingd(keys=['image'], pixdim=original_resolution, mode='nearest'),
                    Resized(keys=['image'], spatial_size=crop_shape, mode='nearest'),
                ])
                files_ds = Dataset(data=files, transform=files_transforms)
                files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

                for files_data in files_loader:
                    files_images = files_data["image"]
                    res = files_images.squeeze().data.numpy()

                result_array = np.rint(res)
            finally:
                # never leave the scratch file behind, even if resampling fails
                os.remove(temp_path)

        # recover the cropped background before saving the image
        empty_array = np.zeros(original_shape)
        empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1], coord1[2]:coord2[2]] = result_array

        result_seg = from_numpy_to_itk(empty_array, image)

        # save label
        writer = sitk.ImageFileWriter()
        writer.SetFileName(result)
        writer.Execute(result_seg)
        print("Saved Result at:", str(result))
def run_training_test(root_dir, device="cuda:0", cachedataset=0):
    """Train a 3D UNet for six epochs and validate it every second epoch.

    Image/segmentation pairs are read from ``img*.nii.gz`` / ``seg*.nii.gz``
    below ``root_dir``: the first 20 sorted pairs are used for training and the
    last 20 for validation. TensorBoard logs go to ``<root_dir>/runs`` and the
    best weights to ``<root_dir>/best_metric_model.pth``.

    Args:
        root_dir: directory holding the NIfTI files; also receives the outputs.
        device: device the model and batches are moved to.
        cachedataset: ``2`` selects ``CacheDataset``, ``3`` selects
            ``LMDBDataset``, anything else the plain ``Dataset``.

    Returns:
        ``(epoch_loss_values, best_metric, best_metric_epoch)``.
    """
    monai.config.print_config()

    image_paths = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": i, "seg": s} for i, s in zip(image_paths[:20], seg_paths[:20])]
    val_files = [{"img": i, "seg": s} for i, s in zip(image_paths[-20:], seg_paths[-20:])]

    both = ["img", "seg"]

    def deterministic_steps():
        """Return fresh instances of the loading/resampling steps shared by both pipelines."""
        return [
            LoadImaged(keys=both),
            AsChannelFirstd(keys=both, channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=both, pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
        ]

    # define transforms for image and segmentation
    train_transforms = Compose(
        deterministic_steps()
        + [
            RandCropByPosNegLabeld(
                keys=both, label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=both, prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=both),
        ]
    )
    train_transforms.set_random_state(1234)
    val_transforms = Compose(deterministic_steps() + [ToTensord(keys=both)])

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # batch_size=2 images, each cropped into 4 samples -> 2 x 4 patches per training step
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    num_epochs = 6
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")

    for epoch in range(num_epochs):
        epoch_number = epoch + 1
        print("-" * 10)
        print(f"Epoch {epoch_number}/{num_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for step, batch_data in enumerate(train_loader, start=1):
            inputs = batch_data["img"].to(device)
            labels = batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            epoch_loss += loss_value
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss_value:0.4f}")
            writer.add_scalar("train_loss", loss_value, epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch_number} average loss:{epoch_loss:0.4f}")

        if epoch_number % val_interval != 0:
            continue

        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images = val_data["img"].to(device)
                val_labels = val_data["seg"].to(device)
                sw_batch_size = 4
                roi_size = (96, 96, 96)
                raw_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = val_post_tran(raw_outputs)
                value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                valid = not_nans.item()
                metric_count += valid
                metric_sum += value.item() * valid
            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch_number
                torch.save(model.state_dict(), model_filename)
                print("saved new best metric model")
            print(
                f"current epoch {epoch_number} current mean dice: {metric:0.4f} "
                f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
            )
            writer.add_scalar("val_mean_dice", metric, epoch_number)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            for tensor, tag in ((val_images, "image"), (val_labels, "label"), (val_outputs, "output")):
                plot_2d_or_3d_image(tensor, epoch_number, writer, index=0, tag=tag)

    print(f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
def segment(image, label, result, weights, resolution, patch_size):
    """Segment one image with the trained network and write the mask to ``result``.

    The image is intensity-normalised and, when ``resolution`` is given,
    resampled to that spacing before sliding-window inference. The thresholded
    prediction is resized back to the grid of the input image when resampling
    was applied. When ``label`` is given, the Dice value is printed, and printed
    again after resampling back to the original resolution.

    Args:
        image: path of the image to segment.
        label: path of the reference segmentation, or ``None``.
        result: output path for the predicted segmentation.
        weights: checkpoint path (multi-GPU) or reference handed to
            ``new_state_dict`` (single GPU).
        resolution: target voxel spacing for inference, or ``None``.
        patch_size: sliding-window size.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    has_label = label is not None
    if has_label:
        files = [{"image": image, "label": label}]
        keys = ['image', 'label']
    else:
        files = [{"image": image}]
        keys = ['image']

    # One pipeline for every label/resolution combination; only the key set
    # and the optional resampling step differ.
    transforms = [
        LoadNiftid(keys=keys),
        AddChanneld(keys=keys),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
    ]
    if resolution is not None:
        # labels are resampled with nearest-neighbour so they stay discrete
        spacing_mode = ('bilinear', 'nearest') if has_label else 'bilinear'
        transforms.append(Spacingd(keys=keys, pixdim=resolution, mode=spacing_mode))
    transforms.append(ToTensord(keys=keys))
    val_transforms = Compose(transforms)

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    # try to use all the available GPUs
    devices = get_devices_spec(None)
    # NOTE(review): relies on a module-level `args` object and is set after the
    # devices were enumerated - confirm this ordering is intended.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids  # Multi-gpu selector for training

    # build the network
    if len(devices) > 1:
        net = build_net()
        net = net.to(devices[0])
        net = torch.nn.DataParallel(net, device_ids=devices)
        net.load_state_dict(torch.load(weights))
    else:
        net = build_net().cuda()
        net.load_state_dict(new_state_dict(weights))

    # define sliding window size and batch size for windows inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        for val_data in val_loader:
            val_images = val_data["image"].cuda()
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            if has_label:
                val_labels = val_data["label"].cuda()
                # the metric applies the sigmoid itself, so it gets the raw logits
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()

        if has_label:
            metric = metric_sum / metric_count
            print("Evaluation Metric (Dice):", metric)

        result_array = val_outputs.squeeze().data.cpu().numpy()

        if resolution is not None:
            # bring the prediction back onto the voxel grid of the input image
            reader = sitk.ImageFileReader()
            reader.SetFileName(image)
            image_itk = reader.Execute()

            res = sitk.GetImageFromArray(np.transpose(result_array, (2, 1, 0)))
            res = resize(res, (sitk.GetArrayFromImage(image_itk)).shape[::-1], sitk.sitkNearestNeighbor)
            res = np.rint(sitk.GetArrayFromImage(res))
            res = sitk.GetImageFromArray(res)
            res.SetDirection(image_itk.GetDirection())
            res.SetOrigin(image_itk.GetOrigin())
            res.SetSpacing(image_itk.GetSpacing())
        else:
            res = from_numpy_to_itk(result_array, image)

        # save label
        writer = sitk.ImageFileWriter()
        writer.SetFileName(result)
        writer.Execute(res)
        print("Saved Result at:", str(result))

        # The post-resampling Dice needs a reference; without a label there is
        # nothing to compare against (previously this tried to read `None`).
        if resolution is not None and has_label:

            def dice_coeff(seg, gt):
                """Return the Dice score of two masks, or NaN when both are empty."""
                seg = seg.flatten()
                gt = gt.flatten()
                denominator = float(gt.sum() + seg.sum())
                if denominator == 0:
                    # Dice is undefined for two empty masks; avoid ZeroDivisionError
                    return float("nan")
                return float(2 * (gt * seg).sum()) / denominator

            y_pred = sitk.GetArrayFromImage(sitk.ReadImage(result))
            y_true = sitk.GetArrayFromImage(sitk.ReadImage(label))
            print("Dice after resampling to original resolution:", dice_coeff(y_pred, y_true))
def main():
    """Fine-tune UNETR for 14-class segmentation, optionally from pretrained ViT weights.

    Paths and hyper-parameters are hard-coded placeholders at the top of the
    function. Training runs until ``max_iterations`` optimisation steps have
    been reached; every ``eval_num`` steps the model is validated with
    sliding-window inference, the best weights are written to
    ``<logdir>/best_metric_model.pth`` and a loss/Dice plot is refreshed.
    """
    print_config()

    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')

    # If use_pretrained is set to 0, ViT weights will not be loaded and random initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')

    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100

    # also creates missing parent directories and tolerates an existing logdir
    os.makedirs(logdir, exist_ok=True)

    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    datalist = load_decathlon_datalist(base_dir=data_dir,
                                       data_list_file_path=json_path,
                                       is_segmentation=True,
                                       data_list_key="training")
    val_files = load_decathlon_datalist(base_dir=data_dir,
                                        data_list_file_path=json_path,
                                        is_segmentation=True,
                                        data_list_key="validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    # sanity check: report the shapes of the first validation case
    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading Weights from the Path {}'.format(pretrained_path))
        # map_location keeps the checkpoint loadable on a machine without the
        # device it was saved from
        vit_dict = torch.load(pretrained_path, map_location=device)
        vit_weights = vit_dict['state_dict']

        # Delete the following variable names conv3d_transpose.weight, conv3d_transpose.bias,
        # conv3d_transpose_1.weight, conv3d_transpose_1.bias as they were used to match dimensions
        # while pretraining with ViTAutoEnc and are not a part of ViT backbone (this is used in UNETR)
        vit_weights.pop('conv3d_transpose_1.bias')
        vit_weights.pop('conv3d_transpose_1.weight')
        vit_weights.pop('conv3d_transpose.bias')
        vit_weights.pop('conv3d_transpose.weight')

        model.vit.load_state_dict(vit_weights)
        print('Pretrained Weights Succesfully Loaded !')

    elif use_pretrained == 0:
        print('No weights were loaded, all weights being used are randomly initialized!')

    model.to(device)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val):
        """Run sliding-window validation and return the mean of the per-step Dice readings."""
        model.eval()
        dice_vals = []

        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val):
                # follow the selected device instead of assuming CUDA
                val_inputs, val_labels = (batch["image"].to(device), batch["label"].to(device))
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
                val_labels_list = decollate_batch(val_labels)
                val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
                val_outputs_list = decollate_batch(val_outputs)
                val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                dice = dice_metric.aggregate().item()
                dice_vals.append(dice)
                # NOTE: `global_step` here is the variable of the enclosing main(),
                # which is only updated once train() returns.
                epoch_iterator_val.set_description("Validate (%d / %d Steps) (dice=%2.5f)" % (global_step, 10.0, dice))

            dice_metric.reset()

        mean_dice_val = np.mean(dice_vals)
        return mean_dice_val

    def train(global_step, train_loader, dice_val_best, global_step_best):
        """Run one pass over ``train_loader``, validating every ``eval_num`` steps.

        Returns the updated ``(global_step, dice_val_best, global_step_best)``.
        """
        model.train()
        epoch_loss = 0
        step = 0
        epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator, start=1):
            # follow the selected device instead of assuming CUDA
            x, y = (batch["image"].to(device), batch["label"].to(device))
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description("Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss))

            if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val)
                # validation() leaves the model in eval mode; switch back so the
                # remaining steps of this pass train in training mode
                model.train()

                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(), os.path.join(logdir, "best_metric_model.pth"))
                    print("Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val))
                else:
                    print("Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val))

                # refresh the loss / Dice curves on disk
                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
                y = epoch_loss_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                x = [eval_num * (i + 1) for i in range(len(metric_values))]
                y = metric_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.savefig(os.path.join(logdir, 'btcv_finetune_quick_update.png'))
                plt.clf()
                plt.close(1)

            global_step += 1
        return global_step, dice_val_best, global_step_best

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(
            global_step, train_loader, dice_val_best, global_step_best)
    model.load_state_dict(torch.load(os.path.join(logdir, "best_metric_model.pth"), map_location=device))

    print(f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}")
def _select_option(options, name, kind):
    """Return ``options[name]`` or raise a ValueError naming the valid choices."""
    try:
        return options[name]
    except KeyError:
        raise ValueError(f"unknown {kind} {name!r}; expected one of {sorted(options)}") from None


def main(cfg):
    """Runs main training procedure.

    Builds the model, loss, optimizer, metrics and scheduler selected in
    ``cfg``, optionally dumps augmentation examples to ``cfg['eval_dir']``,
    trains with ``Trainer`` and finally reloads the saved checkpoint to write
    prediction plots for the validation set into ``cfg['plot_dir']``.

    Only the selected architecture, optimizer and scheduler are instantiated,
    so config keys that belong to an unselected scheduler are not required.
    Building a single network also consumes fewer random numbers than
    building all of them, which can change the random initialisation
    obtained for a given seed.

    Args:
        cfg: configuration mapping; see the keys read below.

    Raises:
        ValueError: if the architecture, loss, optimizer, scheduler or a
            metric name in ``cfg`` is not supported.
    """
    print('Starting training...')
    print('Current working directory is:', os.getcwd())

    # fix random seeds for reproducibility
    seed_everything(seed=cfg['seed'])

    # neptune logging
    neptune.init(project_qualified_name=cfg['neptune_project_name'],
                 api_token=cfg['neptune_api_token'])
    neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg)

    # a single class is a binary problem; otherwise add one background channel
    num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
    background = not cfg['ignore_channels']

    aux_params = dict(
        pooling=cfg['pooling'],  # one of 'avg', 'max'
        dropout=cfg['dropout'],  # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # configure model (factories, so only the requested network is built)
    model_factories = {
        'unet': lambda: Unet(encoder_name=cfg['encoder_name'],
                             encoder_weights=cfg['encoder_weights'],
                             decoder_use_batchnorm=cfg['use_batchnorm'],
                             classes=num_classes,
                             activation=activation,
                             aux_params=aux_params),
        'unetplusplus': lambda: UnetPlusPlus(encoder_name=cfg['encoder_name'],
                                             encoder_weights=cfg['encoder_weights'],
                                             decoder_use_batchnorm=cfg['use_batchnorm'],
                                             classes=num_classes,
                                             activation=activation,
                                             aux_params=aux_params),
        'deeplabv3plus': lambda: DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                                               encoder_weights=cfg['encoder_weights'],
                                               classes=num_classes,
                                               activation=activation,
                                               aux_params=aux_params),
    }
    model = _select_option(model_factories, cfg['architecture'], 'architecture')()

    # configure loss
    loss_classes = {
        'dice_loss': DiceLoss,
        'generalized_dice': GeneralizedDiceLoss,
    }
    loss = _select_option(loss_classes, cfg['loss'], 'loss')(
        include_background=background, softmax=False, batch=cfg['combine'])

    # configure optimizer
    optimizer_classes = {
        'adam': Adam,
        'adamw': AdamW,
        'rmsprop': RMSprop,
    }
    optimizer = _select_option(optimizer_classes, cfg['optimizer'], 'optimizer')(
        [dict(params=model.parameters(), lr=cfg['lr'])])

    # configure metrics
    available_metrics = {
        'dice_score': DiceMetric(include_background=background, reduction='mean'),
        'dice_smp': Fscore(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']),
        'iou_smp': IoU(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']),
        'generalized_dice': GeneralizedDiceLoss(include_background=background,
                                                softmax=False,
                                                batch=cfg['combine']),
        'dice_loss': DiceLoss(include_background=background, softmax=False, batch=cfg['combine']),
        'cross_entropy': BCELoss(reduction='mean'),
        'accuracy': Accuracy(ignore_channels=cfg['ignore_channels']),
    }
    unknown_metrics = [m['name'] for m in cfg['metrics'] if m['name'] not in available_metrics]
    if unknown_metrics:
        raise ValueError(f"unknown metrics {unknown_metrics}; expected names from {sorted(available_metrics)}")
    # tuple of (metric, name, type)
    metrics = [(available_metrics[m['name']], m['name'], m['type']) for m in cfg['metrics']]

    # configure scheduler (factories, so only the requested one touches the optimizer)
    scheduler_factories = {
        'steplr': lambda: StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5),
        'cosine': lambda: CosineAnnealingLR(optimizer, cfg['epochs'], eta_min=cfg['eta_min'], last_epoch=-1),
    }
    scheduler = _select_option(scheduler_factories, cfg['scheduler'], 'scheduler')()

    # configure augmentations
    train_transform = load_train_transform(transform_type=cfg['transform'], patch_size=cfg['patch_size'])
    valid_transform = load_valid_transform(patch_size=cfg['patch_size'])

    train_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                        transform=train_transform,
                                        normalize=cfg['normalize'],
                                        tissuemix=cfg['tissuemix'],
                                        probability=cfg['probability'],
                                        blending=cfg['blending'],
                                        warping=cfg['warping'],
                                        color=cfg['color'])
    valid_dataset = SegmentationDataset(df_path=cfg['valid_data'],
                                        transform=valid_transform,
                                        normalize=cfg['normalize'])

    # save intermediate augmentations
    if cfg['eval_dir']:
        default_dataset = SegmentationDataset(df_path=cfg['train_data'], transform=None, normalize=None)
        transform_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                                transform=None,
                                                normalize=None,
                                                tissuemix=cfg['tissuemix'],
                                                probability=cfg['probability'],
                                                blending=cfg['blending'],
                                                warping=cfg['warping'],
                                                color=cfg['color'])

        # every tenth sample of the first 500
        for idx in range(0, min(500, len(default_dataset)), 10):
            image_input, image_mask = default_dataset[idx]
            image_input = image_input.transpose((1, 2, 0))
            image_input = image_input.astype(np.uint8)

            image_mask = image_mask.transpose(1, 2, 0)  # Why do we need transpose here?
            image_mask = image_mask.astype(np.uint8)
            image_mask = image_mask.squeeze()
            image_mask = image_mask * 255

            image_transform, _ = transform_dataset[idx]
            image_transform = image_transform.transpose((1, 2, 0)).astype(np.uint8)

            idx_str = str(idx).zfill(3)
            skimage.io.imsave(os.path.join(cfg['eval_dir'], f'{idx_str}a_image_input.png'),
                              image_input,
                              check_contrast=False)
            plt.imsave(os.path.join(cfg['eval_dir'], f'{idx_str}b_image_mask.png'),
                       image_mask,
                       vmin=0,
                       vmax=1)
            skimage.io.imsave(os.path.join(cfg['eval_dir'], f'{idx_str}c_image_transform.png'),
                              image_transform,
                              check_contrast=False)

        del transform_dataset

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=False)

    trainer = Trainer(model=model,
                      device=cfg['device'],
                      save_checkpoints=cfg['save_checkpoints'],
                      checkpoint_dir=cfg['checkpoint_dir'],
                      checkpoint_name=cfg['checkpoint_name'])
    trainer.compile(optimizer=optimizer, loss=loss, metrics=metrics, num_classes=num_classes)
    trainer.fit(train_loader,
                valid_loader,
                epochs=cfg['epochs'],
                scheduler=scheduler,
                verbose=cfg['verbose'],
                loss_weight=cfg['loss_weight'])

    # validation inference
    best_model = model
    checkpoint_path = os.path.join(cfg['checkpoint_dir'], cfg['checkpoint_name'])
    # map_location keeps the checkpoint loadable on the configured device
    best_model.load_state_dict(torch.load(checkpoint_path, map_location=cfg['device']))
    best_model.to(cfg['device'])
    best_model.eval()

    # setup directory to save plots
    if os.path.isdir(cfg['plot_dir']):
        # remove existing dir and content
        shutil.rmtree(cfg['plot_dir'])
    # create absolute destination
    os.makedirs(cfg['plot_dir'])

    # validation dataset without normalization (the validation transform is
    # still applied) for image visualization
    valid_dataset_vis = SegmentationDataset(df_path=cfg['valid_data'],
                                            transform=valid_transform,
                                            normalize=None)

    if cfg['save_checkpoints']:
        for n in range(len(valid_dataset)):
            image_vis = valid_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose((1, 2, 0))

            image, gt_mask = valid_dataset[n]
            gt_mask = gt_mask.transpose((1, 2, 0))
            gt_mask = gt_mask.squeeze()

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = best_model.predict(x_tensor)
            pr_mask = pr_mask.cpu().numpy().round()
            pr_mask = pr_mask.squeeze()

            save_predictions(out_path=cfg['plot_dir'],
                             index=n,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask,
                             average='macro')
def test_nans_class(self, params, input_data, expected_value):
    """Check the aggregated ``DiceMetric`` result against ``expected_value``.

    The metric is built from ``params``, fed ``input_data`` once, and the first
    element of the aggregated pair is compared with an absolute tolerance of 1e-4.
    """
    metric = DiceMetric(**params)
    # the call only accumulates state; its return value is not needed here
    metric(**input_data)
    aggregated_value, _unused = metric.aggregate()
    actual = aggregated_value.cpu().numpy()
    np.testing.assert_allclose(actual, expected_value, atol=1e-4)
def main_worker(args):
    """Per-process entry point for distributed BraTS training.

    Joins the NCCL process group, binds the process to ``cuda:<local_rank>``,
    builds the cached training/validation datasets and the selected network,
    then alternates training epochs with validation every
    ``args.val_interval`` epochs. Rank 0 writes the best weights to
    ``best_metric_model.pth``.

    Args:
        args: parsed options; reads ``local_rank``, ``dir``, ``cache_rate``,
            ``batch_size``, ``network``, ``lr``, ``epochs`` and ``val_interval``.

    Raises:
        FileNotFoundError: if ``args.dir`` does not exist.
    """
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        devnull = open(os.devnull, "w")
        sys.stdout = devnull
        sys.stderr = devnull
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    # use amp to accelerate training
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    total_start = time.time()
    both = ["image", "label"]

    def load_and_resample():
        """Return fresh instances of the loading/resampling steps shared by both pipelines."""
        return [
            # load 4 Nifti images and stack them together
            LoadImaged(keys=both),
            EnsureChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Orientationd(keys=both, axcodes="RAS"),
            Spacingd(
                keys=both,
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
        ]

    train_transforms = Compose(
        load_and_resample()
        + [
            EnsureTyped(keys=both),
            ToDeviced(keys=both, device=device),
            RandSpatialCropd(keys=both, roi_size=[224, 224, 144], random_size=False),
            RandFlipd(keys=both, prob=0.5, spatial_axis=0),
            RandFlipd(keys=both, prob=0.5, spatial_axis=1),
            RandFlipd(keys=both, prob=0.5, spatial_axis=2),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ]
    )

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.batch_size, shuffle=True)

    # validation transforms and dataset
    val_transforms = Compose(
        load_and_resample()
        + [
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            EnsureTyped(keys=both),
            ToDeviced(keys=both, device=device),
        ]
    )
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=args.batch_size, shuffle=False)

    # create network, loss function and optimizer
    if args.network == "SegResNet":
        network = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        )
    else:
        network = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
    model = network.to(device)

    loss_function = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
    optimizer = Novograd(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch_number in range(1, args.epochs + 1):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch_number}/{args.epochs}")
        epoch_loss = train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler)
        print(f"epoch {epoch_number} average loss: {epoch_loss:.4f}")

        if epoch_number % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, dice_metric, dice_metric_batch, post_trans
            )

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch_number
                # only one process writes the checkpoint
                if dist.get_rank() == 0:
                    torch.save(model.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch_number} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

        epoch_seconds = time.time() - epoch_start
        print(f"time consuming of epoch {epoch_number} is: {epoch_seconds:.4f}")

    total_seconds = time.time() - train_start
    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {total_seconds:.4f}"
    )
    dist.destroy_process_group()
def main(tempdir):
    """Evaluate a saved 3D UNet on synthetic volumes and save the predictions.

    Writes five synthetic image/segmentation NIfTI pairs into ``tempdir``,
    loads the weights stored in ``best_metric_model_segmentation3d_array.pth``,
    runs sliding-window inference on each volume, saves every thresholded
    prediction through ``SaveImage`` (``./output``) and prints the Dice value
    aggregated over all volumes.

    Args:
        tempdir: directory that receives the generated ``im*.nii.gz`` and
            ``seg*.nii.gz`` files (it is not created here).
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
    segtrans = Compose([AddChannel(), EnsureType()])
    val_ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location lets a checkpoint that was saved from a GPU be restored when
    # this script falls back to the CPU device selected above
    model.load_state_dict(
        torch.load("best_metric_model_segmentation3d_array.pth", map_location=device)
    )
    model.eval()

    # sliding window size and batch size for windows inference (same for every volume)
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    with torch.no_grad():
        for val_data in val_loader:
            # the loader yields (image, label, meta data) because image_only=False
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            meta_data = decollate_batch(val_data[2])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output, data in zip(val_outputs, meta_data):
                saver(val_output, data)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
def evaluate(args):
    """Evaluate ``final_model.pth`` with DistributedDataParallel, one process per GPU.

    If the data folder does not exist, the process with ``args.local_rank == 0``
    creates it and writes 16 synthetic image/label NIfTI pairs. Every process
    then joins the NCCL process group, runs sliding-window inference on the
    samples its ``DistributedSampler`` yields and accumulates a
    ``[dice sum, not-nan count]`` tensor. The tensors are summed across
    processes and rank 0 prints ``sum / count``.

    Args:
        args: object exposing ``local_rank`` (used as the CUDA device index)
            and ``dir`` (folder with ``img*.nii.gz`` / ``seg*.nii.gz`` files).
    """
    local_rank = args.local_rank
    data_dir = args.dir

    if local_rank == 0 and not os.path.exists(data_dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(data_dir)
        # fixed seed so that every node generates the same random data
        np.random.seed(seed=0)
        affine = np.eye(4)
        for idx in range(16):
            image, label = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            nib.save(nib.Nifti1Image(image, affine), os.path.join(data_dir, f"img{idx:d}.nii.gz"))
            nib.save(nib.Nifti1Image(label, affine), os.path.join(data_dir, f"seg{idx:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    image_paths = sorted(glob(os.path.join(data_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(data_dir, "seg*.nii.gz")))
    val_files = [
        {"img": image_path, "seg": label_path}
        for image_path, label_path in zip(image_paths, label_paths)
    ]

    # transforms applied to both the image and the segmentation
    keys = ["img", "seg"]
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=keys),
        ]
    )

    # evaluation dataset, per-process sampler and loader;
    # sliding window inference needs 1 image per iteration
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_sampler = DistributedSampler(val_ds)
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=val_sampler,
    )
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    # build the UNet on this process' GPU
    device = torch.device(f"cuda:{local_rank}")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[local_rank])
    # remap tensors stored for cuda:0 onto this process' device while loading
    checkpoint = torch.load("final_model.pth", map_location={"cuda:0": f"cuda:{local_rank}"})
    net.load_state_dict(checkpoint)
    net.eval()

    # sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    with torch.no_grad():
        # totals[0] is the `sum` of all dice values, totals[1] the `count` of not_nan items
        totals = torch.zeros(2, dtype=torch.float, device=device)
        for batch in val_loader:
            val_images = batch["img"].to(device)
            val_labels = batch["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            value = dice_metric(y_pred=val_outputs, y=val_labels).squeeze()
            totals[0] += value * dice_metric.not_nans
            totals[1] += dice_metric.not_nans

        # synchronize all processes and sum the per-process totals
        dist.barrier()
        dist.all_reduce(totals, op=torch.distributed.ReduceOp.SUM)
        dice_sum, valid_count = totals.tolist()
        if dist.get_rank() == 0:
            print("evaluation metric:", dice_sum / valid_count)
        dist.destroy_process_group()
def evaluate(args):
    """Evaluate ``final_model.pth`` with DistributedDataParallel and a cumulative Dice metric.

    If the data folder does not exist, the process with ``args.local_rank == 0``
    creates it and writes 16 synthetic image/label NIfTI pairs. Every process
    then joins the NCCL process group, runs sliding-window inference on its
    share of the data, thresholds the sigmoid outputs at 0.5 and feeds them to
    ``DiceMetric``. After the loop the metric is aggregated and rank 0 prints it.

    Args:
        args: object exposing ``local_rank`` (used as the CUDA device index)
            and ``dir`` (folder with ``img*.nii.gz`` / ``seg*.nii.gz`` files).
    """
    data_dir = args.dir

    if args.local_rank == 0 and not os.path.exists(data_dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(data_dir)
        # fixed seed so that every node generates the same random data
        np.random.seed(seed=0)
        affine = np.eye(4)
        for idx in range(16):
            image, label = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            nib.save(nib.Nifti1Image(image, affine), os.path.join(data_dir, f"img{idx:d}.nii.gz"))
            nib.save(nib.Nifti1Image(label, affine), os.path.join(data_dir, f"seg{idx:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    image_paths = sorted(glob(os.path.join(data_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(data_dir, "seg*.nii.gz")))
    val_files = [
        {"img": image_path, "seg": label_path}
        for image_path, label_path in zip(image_paths, label_paths)
    ]

    # only the image intensities are rescaled; the label is left untouched
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # evaluation dataset, per-process sampler and loader;
    # sliding window inference needs 1 image per iteration
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_sampler = DistributedSampler(dataset=val_ds, even_divisible=False, shuffle=False)
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # build the UNet on this process' GPU
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[device])
    # remap tensors stored for cuda:0 onto this process' device while loading
    checkpoint = torch.load("final_model.pth", map_location={"cuda:0": f"cuda:{args.local_rank}"})
    net.load_state_dict(checkpoint)
    net.eval()

    # sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    with torch.no_grad():
        for batch in val_loader:
            val_images = batch["img"].to(device)
            val_labels = batch["seg"].to(device)
            raw_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            val_outputs = [post_trans(item) for item in decollate_batch(raw_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        # collapse everything accumulated in the loop into a single number
        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if dist.get_rank() == 0:
            print("evaluation metric:", metric)
        dist.destroy_process_group()
def evaluate(gpu, args):
    """Evaluate ``final_model.pth`` in one process of a multi-node DDP job.

    The process rank is derived from the node index and the GPU index. The
    process joins the NCCL group, runs sliding-window inference on the samples
    its ``DistributedSampler`` yields and accumulates a
    ``[dice sum, not-nan count]`` tensor. The tensors are summed across
    processes and rank 0 prints ``sum / count``.

    Args:
        gpu: index of the GPU used by this process (also the CUDA device index).
        args: object exposing ``node``, ``gpus``, ``world_size`` and ``dir``
            (folder with ``img*.nii.gz`` / ``seg*.nii.gz`` files).
    """
    # every GPU runs in its own process, so the process rank is
    # (node index x GPU count of 1 node + GPU index)
    rank = args.node * args.gpus + gpu
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=args.world_size,
        rank=rank,
    )

    image_paths = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [
        {"img": image_path, "seg": label_path}
        for image_path, label_path in zip(image_paths, label_paths)
    ]

    # transforms applied to both the image and the segmentation
    keys = ["img", "seg"]
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=keys),
        ]
    )

    # evaluation dataset, per-process sampler and loader;
    # sliding window inference needs 1 image per iteration
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_sampler = DistributedSampler(val_ds, num_replicas=args.world_size, rank=rank)
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
        sampler=val_sampler,
    )
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    # build the UNet on this process' GPU
    device = torch.device(f"cuda:{gpu}")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[gpu])
    # remap tensors stored for cuda:0 onto this process' device while loading
    checkpoint = torch.load("final_model.pth", map_location={"cuda:0": f"cuda:{gpu}"})
    net.load_state_dict(checkpoint)
    net.eval()

    # sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    with torch.no_grad():
        # totals[0] is the `sum` of all dice values, totals[1] the `count` of not_nan items
        totals = torch.zeros(2, dtype=torch.float, device=device)
        for batch in val_loader:
            val_images = batch["img"].to(device)
            val_labels = batch["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            value = dice_metric(y_pred=val_outputs, y=val_labels).squeeze()
            totals[0] += value * dice_metric.not_nans
            totals[1] += dice_metric.not_nans

        # synchronize all processes and sum the per-process totals
        dist.barrier()
        dist.all_reduce(totals, op=torch.distributed.ReduceOp.SUM)
        dice_sum, valid_count = totals.tolist()
        if rank == 0:
            print("evaluation metric:", dice_sum / valid_count)
        dist.destroy_process_group()
def test_train_timing(self):
    """Train a 3D UNet for five epochs with the fast-training setup and check the Dice.

    Uses the first 32 image/label pairs under ``self.data_dir`` for training and
    the last 9 for validation, with fully cached datasets, thread-based loaders,
    AMP and Novograd. Validation runs after every epoch; the test asserts that
    the best mean Dice seen exceeds 0.95 and prints per-step/per-epoch timings.
    """
    image_paths = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    train_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[:32], label_paths[:32])
    ]
    val_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[-9:], label_paths[-9:])
    ]

    device = torch.device("cuda:0")
    both = ["image", "label"]
    interp = ("bilinear", "nearest")

    # transforms for training
    train_transforms = Compose(
        [
            LoadImaged(keys=both),
            EnsureChannelFirstd(keys=both),
            Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=interp),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=both, source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=both),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=both, device=device),
            # randomly crop out patch samples from the big image based on
            # pos / neg ratio; the image centers of negative samples must be
            # in valid image area
            RandCropByPosNegLabeld(
                keys=both,
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=both, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=both, prob=0.5),
            RandRotate90d(keys=both, prob=0.5, spatial_axes=(1, 2)),
            RandZoomd(keys=both, prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
            RandRotated(
                keys=both,
                prob=0.5,
                range_x=np.pi / 4,
                mode=interp,
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=both, prob=0.5, rotate_range=np.pi / 2, mode=interp),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
        ]
    )

    # transforms for validation: the deterministic part only
    val_transforms = Compose(
        [
            LoadImaged(keys=both),
            EnsureChannelFirstd(keys=both),
            Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=interp),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=both, source_key="image"),
            EnsureTyped(keys=both),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=both, device=device),
        ]
    )

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch

    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)

    # Novograd paper suggests to use a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # sliding window settings used by every validation pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_number = epoch + 1
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch_number}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for step, batch_data in enumerate(train_loader, start=1):
            step_start = time.time()
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(
                f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                f" step time: {(time.time() - step_start):.4f}"
            )
        epoch_loss /= step
        print(f"epoch {epoch_number} average loss: {epoch_loss:.4f}")

        if epoch_number % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        raw_outputs = sliding_window_inference(
                            val_data["image"], roi_size, sw_batch_size, model
                        )
                    val_outputs = [post_pred(item) for item in decollate_batch(raw_outputs)]
                    val_labels = [post_label(item) for item in decollate_batch(val_data["label"])]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                print(
                    f"epoch: {epoch_number} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                )
        print(f"time consuming of epoch {epoch_number} is: {(time.time() - epoch_start):.4f}")

    total_time = time.time() - total_start
    print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}")
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
out_channels=2, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, norm=Norm.BATCH, ).to(device) #loss_function = DiceLoss(to_onehot_y=True, softmax=True) #optimizer = torch.optim.Adam(model.parameters(), 1e-4) loss_function = DiceCELoss(include_background=True, to_onehot_y=True, softmax=True, lambda_dice=0.5, lambda_ce=0.5) optimizer = torch.optim.Adam(model.parameters(), 1e-3) dice_metric = DiceMetric(include_background=False, reduction="mean") scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5) ## """## Execute a typical PyTorch training process""" epoch_num = 300 val_interval = 2 best_metric = -1 best_metric_epoch = -1 epoch_loss_values = list() metric_values = list() post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2) post_label = AsDiscrete(to_onehot=True, n_classes=2) for epoch in range(epoch_num):
def main(tempdir):
    """Train a 3D UNet on synthetic dict-style data and track the best validation Dice.

    When ``tempdir`` does not exist it is created and filled with 40 synthetic
    image/segmentation NIfTI pairs; otherwise the files already in it are used.
    The first 20 pairs are used for training and the last 20 for validation.
    Training resumes from ``best_metric_model_segmentation3d_dict.pth`` when
    that file exists, validates every second epoch, overwrites that file
    whenever the mean Dice improves and logs losses, metrics and the last
    validation sample to TensorBoard.

    Args:
        tempdir: directory holding (or receiving) the ``img*.nii.gz`` and
            ``seg*.nii.gz`` files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create the data directory and 40 random image, mask pairs
    if not os.path.exists(tempdir):
        print(f"generating synthetic data to {tempdir}")
        # nib.save does not create missing parent folders, so make the
        # directory before writing into it
        os.makedirs(tempdir)
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=3, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
    else:
        print(f"Found data in {tempdir}")

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"],
                label_key="seg",
                spatial_size=[64, 64, 64],
                pos=1,
                neg=1,
                num_samples=4,
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.35, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model_path = "best_metric_model_segmentation3d_dict.pth"
    if os.path.exists(model_path):
        # map_location lets a checkpoint saved from a GPU be resumed on the
        # CPU fallback selected above
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Loaded model from file '{model_path}'")
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    num_epochs = 10
    # number of full batches per epoch; used for the progress print and the
    # TensorBoard global step
    epoch_len = len(train_ds) // train_loader.batch_size
    for epoch in range(num_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{num_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                # kept after the loop so the last sample can be plotted
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    roi_size = (64, 64, 64)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    value = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_path)
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def main(tempdir):
    """Train a 2D UNet on synthetic PNG images and track the best validation Dice.

    Writes 40 synthetic image/segmentation PNG pairs into ``tempdir``, trains
    on the first 20 pairs and validates on the last 20 every second epoch.
    Whenever the mean Dice improves the weights are saved to
    ``best_metric_model_segmentation2d_array.pth``. Losses, metrics and the
    last validation sample are logged to TensorBoard.

    Args:
        tempdir: directory that receives the generated ``img*.png`` and
            ``seg*.png`` files (it is not created here).
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create 40 random image, mask pairs in the given directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray(im.astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    # the array datasets below consume these path lists directly
    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    train_imtrans = Compose(
        [
            LoadImage(image_only=True),
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop((96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            ToTensor(),
        ]
    )
    train_segtrans = Compose(
        [
            LoadImage(image_only=True),
            AddChannel(),
            RandSpatialCrop((96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            ToTensor(),
        ]
    )
    val_imtrans = Compose([LoadImage(image_only=True), ScaleIntensity(), AddChannel(), ToTensor()])
    val_segtrans = Compose([LoadImage(image_only=True), AddChannel(), ToTensor()])

    # define array dataset, data loader
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20], train_segtrans)
    train_loader = DataLoader(
        train_ds,
        batch_size=4,
        shuffle=True,
        num_workers=8,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    num_epochs = 10
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    # number of full batches per epoch; used for the progress print and the
    # TensorBoard global step
    epoch_len = len(train_ds) // train_loader.batch_size
    for epoch in range(num_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{num_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                # kept after the loop so the last sample can be plotted
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = post_trans(val_outputs)
                    value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def evaluate(args):
    """Evaluate ``final_model.pth`` with Horovod, one process per GPU.

    If the data folder does not exist, the process with local rank 0 creates it
    and writes 16 synthetic image/label NIfTI pairs; all ranks wait for that to
    finish before listing the folder. Rank 0 loads the weights and broadcasts
    them. Each process then runs sliding-window inference on the samples its
    ``DistributedSampler`` yields, thresholds the sigmoid outputs at 0.5 and
    feeds them to ``DiceMetric``. Rank 0 prints the aggregated value.

    Args:
        args: object exposing ``dir`` (folder with ``img*.nii.gz`` /
            ``seg*.nii.gz`` files).
    """
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # hvd.init() already ran, so without a synchronisation point the other
    # ranks would list the folder while local rank 0 is still writing it.
    # An allreduce cannot complete before every rank has joined it, which
    # makes this dummy reduction act as a barrier.
    hvd.allreduce(torch.zeros(1), name="data_ready_barrier")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # create a evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create a evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False, num_replicas=hvd.size(), rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(mp, "_supports_context") and mp._supports_context and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create the UNet on this process' GPU
    device = torch.device(f"cuda:{hvd.local_rank()}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if hvd.rank() == 0:
            print("evaluation metric:", metric)