def main(tempdir):
    """Evaluate a saved 2D UNet on freshly generated synthetic PNG images.

    Five random image/segmentation pairs are written into ``tempdir``,
    segmented with sliding-window inference, saved through ``SaveImage``
    (configured with ``output_dir="./output"``) and scored with the mean Dice
    metric, which is printed at the end.

    The weights are read from ``best_metric_model_segmentation2d_array.pth``
    relative to the current working directory.

    Args:
        tempdir: directory that receives the synthetic PNG files; it is not
            created by this function.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    # sorting keeps image i paired with segmentation i
    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    val_ds = ArrayDataset(images, imtrans, segs, segtrans)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # map_location lets a checkpoint written on a GPU be restored when this
    # script falls back to the CPU (or runs on a different GPU index)
    model.load_state_dict(
        torch.load("best_metric_model_segmentation2d_array.pth", map_location=device)
    )
    model.eval()

    # sliding window size and number of windows evaluated per forward pass
    roi_size = (96, 96)
    sw_batch_size = 4

    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # split the batch into single items before post-processing
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
post_label(i) for i in decollate_batch(val_labels) ] largest = KeepLargestConnectedComponent(applied_labels=[1]) #value = compute_meandice( # y_pred=val_outputs, # y=val_labels, # include_background=False, #) value = dice_metric(y_pred=val_outputs, y=val_labels) metric_count += len(value[0]) metric_sum += value[0].sum().item() # aggregate the final mean dice result metric = dice_metric.aggregate().item() # reset the status for next validation round dice_metric.reset() #metric = metric_sum / metric_count metric_values.append(metric) scheduler.step(metric) ## writer.add_scalar("val_mean_dice", metric, epoch + 1) ## writer.add_scalar("Learning rate", optimizer.param_groups[0]['lr'], epoch + 1) if metric > best_metric: best_metric = metric best_metric_epoch = epoch + 1 torch.save(model.state_dict(), os.path.join(out_dir, "best_metric_model.pth")) print("saved new best metric model") print(f"current epoch: {epoch + 1} current mean dice: {metric:.4f}" f"\nbest mean dice: {best_metric:.4f}"
def evaluate(args):
    """Evaluate a saved 3D UNet across Horovod workers on synthetic volumes.

    The worker with local rank 0 creates 16 synthetic image/mask NIfTI pairs
    in ``args.dir`` when that directory does not exist yet. Every worker then
    evaluates its own shard of the data (via ``DistributedSampler``) with
    sliding-window inference, and rank 0 prints the Dice value.

    The weights are read from ``final_model.pth`` on rank 0 and broadcast to
    the other workers.

    Args:
        args: namespace providing ``dir``, the directory of the evaluation data.
    """
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # collective no-op used as a barrier: without it the other workers could
    # list the directory while local rank 0 is still writing the files
    hvd.allreduce(torch.zeros(1), name="synthetic_data_barrier")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler that shards the dataset per worker
    val_sampler = DistributedSampler(val_ds, shuffle=False, num_replicas=hvd.size(), rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(mp, "_supports_context") and mp._supports_context and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # sliding window inference needs 1 image per iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create the UNet on the GPU that belongs to this worker
    device = torch.device(f"cuda:{hvd.local_rank()}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    if hvd.rank() == 0:
        # load model parameters for evaluation; map_location places them on
        # this worker's GPU regardless of where the checkpoint was written
        model.load_state_dict(torch.load("final_model.pth", map_location=device))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # sliding window size and number of windows evaluated per forward pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        # NOTE(review): only Horovod is initialised here; confirm that
        # aggregate() reduces across Horovod workers, otherwise rank 0 reports
        # the Dice of its own shard only.
        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if hvd.rank() == 0:
            print("evaluation metric:", metric)
def main(tempdir):
    """Train a 2D UNet on synthetic PNG images using array datasets.

    Forty random image/mask pairs are written into ``tempdir``; the first 20
    are used for training and the last 20 for validation. Validation runs
    every second epoch with sliding-window inference; the best weights (by
    mean Dice) are saved to ``best_metric_model_segmentation2d_array.pth`` in
    the current working directory. Losses, metrics and sample images are
    logged through a ``SummaryWriter``.

    Args:
        tempdir: directory that receives the synthetic PNG files; it is not
            created by this function.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # write 40 random image, mask pairs into the given directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    train_imtrans = Compose(
        [
            LoadImage(image_only=True),
            AddChannel(),
            ScaleIntensity(),
            RandSpatialCrop((96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            EnsureType(),
        ]
    )
    train_segtrans = Compose(
        [
            LoadImage(image_only=True),
            AddChannel(),
            ScaleIntensity(),
            RandSpatialCrop((96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            EnsureType(),
        ]
    )
    val_imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    val_segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])

    # sanity check: load one batch and print the tensor shapes
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20], train_segtrans)
    train_loader = DataLoader(
        train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available()
    )
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    max_epochs = 10
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    # sliding window size and number of windows evaluated per forward pass
    roi_size = (96, 96)
    sw_batch_size = 4
    # number of optimisation steps per epoch; constant, so computed once
    epoch_len = len(train_ds) // train_loader.batch_size

    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    # keep the last validation sample around for the image logs below
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                        # compute metric for current iteration
                        dice_metric(y_pred=val_outputs, y=val_labels)
                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                        print("saved new best metric model")
                    print(
                        f"current epoch: {epoch + 1} current mean dice: {metric:.4f} "
                        f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # close the event writer even when training aborts with an exception
        writer.close()
def test_train_timing(self):
    """Train a 3D UNet for a few epochs with the fast-training setup and check Dice.

    Reads ``img*.nii.gz`` / ``seg*.nii.gz`` from ``self.data_dir`` (first 32
    pairs for training, last 9 for validation), caches the deterministic part
    of the pipeline on ``cuda:0``, trains with AMP and the Novograd optimizer,
    validates every epoch and asserts that the best mean Dice exceeds 0.95.
    Per-step, per-epoch and total timings are printed.
    """
    image_paths = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    train_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[:32], label_paths[:32])
    ]
    val_files = [
        {"image": image_path, "label": label_path}
        for image_path, label_path in zip(image_paths[-9:], label_paths[-9:])
    ]

    device = torch.device("cuda:0")
    both = ["image", "label"]

    # training pipeline: deterministic preprocessing first, random augmentation last
    train_transforms = Compose(
        [
            LoadImaged(keys=both),
            EnsureChannelFirstd(keys=both),
            Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=both, source_key="image"),
            # foreground/background indices are computed once here so the
            # random crop below can reuse them
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # switch to Tensor data before moving it to the device
            EnsureTyped(keys=both),
            # keep the data on the GPU to avoid a CPU -> GPU copy in every epoch
            ToDeviced(keys=both, device=device),
            # sample patches from the big image based on the pos / neg ratio
            RandCropByPosNegLabeld(
                keys=both,
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=both, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=both, prob=0.5),
            RandRotate90d(keys=both, prob=0.5, spatial_axes=(1, 2)),
            RandZoomd(keys=both, prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
            RandRotated(
                keys=both,
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=both, prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
        ]
    )

    # validation pipeline: the same preprocessing without augmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=both),
            EnsureChannelFirstd(keys=both),
            Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=both, source_key="image"),
            EnsureTyped(keys=both),
            # keep the data on the GPU to avoid a CPU -> GPU copy in every epoch
            ToDeviced(keys=both, device=device),
        ]
    )

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # validate after every epoch

    # CacheDataset + ThreadDataLoader + DiceCE loss: the fast-training setup
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # no worker processes: `ThreadDataLoader` works with threads instead
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)

    # the Novograd paper suggests a bigger LR than Adam, because Adam
    # normalizes by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    grad_scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    best_metric = -1
    total_start = time.time()
    for epoch_number in range(1, max_epochs + 1):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch_number}/{max_epochs}")

        model.train()
        running_loss = 0
        step = 0
        for batch in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # forward pass and loss under automatic mixed precision
            with torch.cuda.amp.autocast():
                prediction = model(batch["image"])
                loss = loss_function(prediction, batch["label"])
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
            running_loss += loss.item()
            steps_per_epoch = math.ceil(len(train_ds) / train_loader.batch_size)
            print(
                f"{step}/{steps_per_epoch}, train_loss: {loss.item():.4f}"
                f" step time: {(time.time() - step_start):.4f}"
            )
        running_loss /= step
        print(f"epoch {epoch_number} average loss: {running_loss:.4f}")

        if epoch_number % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_batch in val_loader:
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    # sliding-window inference under automatic mixed precision
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(
                            val_batch["image"], roi_size, sw_batch_size, model
                        )
                    val_outputs = [post_pred(item) for item in decollate_batch(val_outputs)]
                    val_labels = [post_label(item) for item in decollate_batch(val_batch["label"])]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                best_metric = max(best_metric, metric)
                print(
                    f"epoch: {epoch_number} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                )

        print(f"time consuming of epoch {epoch_number} is: {(time.time() - epoch_start):.4f}")

    total_time = time.time() - total_start
    print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}")
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
def evaluate(args):
    """Evaluate a saved 3D UNet with ``torch.distributed`` on synthetic volumes.

    The process with local rank 0 creates 16 synthetic image/mask NIfTI pairs
    in ``args.dir`` when that directory does not exist yet. All processes wait
    at a barrier until the data is complete, evaluate their shard with
    sliding-window inference, and rank 0 prints the aggregated Dice value.

    The weights are read from ``final_model.pth`` in the current working
    directory and loaded into the ``DistributedDataParallel`` wrapper.

    Args:
        args: namespace providing ``local_rank`` (GPU index of this process)
            and ``dir`` (directory of the evaluation data).
    """
    # initialize the distributed evaluation process, every GPU runs in a process;
    # this happens first so that the barrier below is available
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # other ranks must not list the directory while local rank 0 is still
    # writing files into it
    dist.barrier()

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler that shards the dataset per process
    val_sampler = DistributedSampler(dataset=val_ds, even_divisible=False, shuffle=False)
    # sliding window inference needs 1 image per iteration
    val_loader = DataLoader(
        val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create the UNet on the GPU that belongs to this process
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])
    # config mapping to expected GPU device
    map_location = {"cuda:0": f"cuda:{args.local_rank}"}
    # load model parameters to GPU device
    model.load_state_dict(torch.load("final_model.pth", map_location=map_location))

    # sliding window size and number of windows evaluated per forward pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if dist.get_rank() == 0:
            print("evaluation metric:", metric)
        dist.destroy_process_group()
def main(tempdir):
    """Train a 3D UNet on synthetic NIfTI volumes using dictionary datasets.

    Forty random image/mask pairs are written into ``tempdir``; the first 20
    are used for training and the last 20 for validation. Validation runs
    every second epoch with sliding-window inference; the best weights (by
    mean Dice) are saved to ``best_metric_model_segmentation3d_dict.pth`` in
    the current working directory. Losses, metrics and sample images are
    logged through a ``SummaryWriter``.

    Args:
        tempdir: directory that receives the synthetic NIfTI files; it is not
            created by this function.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # write 40 random image, mask pairs into the given directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # sanity check: load one batch and print the tensor shapes
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    max_epochs = 5
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    # sliding window size and number of windows evaluated per forward pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    # number of optimisation steps per epoch; constant, so computed once
    epoch_len = len(train_ds) // train_loader.batch_size

    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    # keep the last validation sample around for the image logs below
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                        # compute metric for current iteration
                        dice_metric(y_pred=val_outputs, y=val_labels)
                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model_segmentation3d_dict.pth")
                        print("saved new best metric model")
                    print(
                        f"current epoch: {epoch + 1} current mean dice: {metric:.4f} "
                        f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # close the event writer even when training aborts with an exception
        writer.close()
def compute(args):
    """Compute the mean Dice of stored predictions in parallel and write reports.

    The process with local rank 0 creates 16 synthetic prediction/label NIfTI
    pairs in ``args.dir`` when that directory does not exist yet. All
    processes wait at a barrier until the data is complete, then each one
    post-processes and scores its own partition. The process with local rank
    0 prints the result and writes the metric reports to ``./output``.

    Args:
        args: namespace providing ``local_rank`` and ``dir`` (directory of
            the prediction/label files).
    """
    # initialize the distributed evaluation process, change to NCCL backend if computing on GPU;
    # this happens first so that the barrier below is available
    dist.init_process_group(backend="gloo", init_method="env://")

    # generate synthetic data for the example
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random pred, label pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # if have multiple nodes, set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            pred, label = create_test_image_3d(
                128, 128, 128, num_seg_classes=1, channel_dim=-1, noise_max=0.5
            )
            n = nib.Nifti1Image(pred, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"pred{i:d}.nii.gz"))
            n = nib.Nifti1Image(label, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"label{i:d}.nii.gz"))

    # other ranks must not list the directory while local rank 0 is still
    # writing files into it
    dist.barrier()

    preds = sorted(glob(os.path.join(args.dir, "pred*.nii.gz")))
    labels = sorted(glob(os.path.join(args.dir, "label*.nii.gz")))
    datalist = [{"pred": pred, "label": label} for pred, label in zip(preds, labels)]

    # split data for every subprocess, for example, 16 processes compute in parallel
    data_part = partition_dataset(
        data=datalist,
        num_partitions=dist.get_world_size(),
        shuffle=False,
        even_divisible=False,
    )[dist.get_rank()]

    # define transforms for predictions and labels
    transforms = Compose(
        [
            LoadImaged(keys=["pred", "label"]),
            EnsureChannelFirstd(keys=["pred", "label"]),
            ScaleIntensityd(keys="pred"),
            EnsureTyped(keys=["pred", "label"]),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    data_part = [transforms(item) for item in data_part]

    # compute metrics for current process
    metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    metric(y_pred=[i["pred"] for i in data_part], y=[i["label"] for i in data_part])
    filenames = [item["pred_meta_dict"]["filename_or_obj"] for item in data_part]
    # all-gather results from all the processes and reduce for final result
    result = metric.aggregate().item()
    filenames = string_list_all_gather(strings=filenames)

    if args.local_rank == 0:
        print("mean dice: ", result)
        # generate metrics reports at: output/mean_dice_raw.csv, output/mean_dice_summary.csv, output/metrics.csv
        write_metrics_reports(
            save_dir="./output",
            images=filenames,
            metrics={"mean_dice": result},
            metric_details={"mean_dice": metric.get_buffer()},
            summary_ops="*",
        )

    metric.reset()

    dist.destroy_process_group()
def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, None)):
    """Train a 3D UNet for six epochs on the volumes in ``root_dir``.

    Reads ``img*.nii.gz`` / ``seg*.nii.gz`` from ``root_dir`` (first 20 pairs
    for training, last 20 for validation), validates every second epoch with
    sliding-window inference, saves the best weights to
    ``<root_dir>/best_metric_model.pth`` and writes TensorBoard logs to
    ``<root_dir>/runs``.

    Args:
        root_dir: directory holding the input volumes; also receives the
            checkpoint, the log directory and (for ``cachedataset == 3``)
            the LMDB cache.
        device: device the model and batches are moved to.
        cachedataset: ``2`` selects ``CacheDataset`` (``cache_rate=0.8``),
            ``3`` selects ``LMDBDataset``, any other value the plain
            ``Dataset``.
        readers: pair of reader arguments handed to ``LoadImaged`` for the
            training and the validation pipeline respectively.

    Returns:
        Tuple ``(epoch_loss_values, best_metric, best_metric_epoch)``: the
        list of average losses per epoch, the best validation mean Dice and
        the 1-based epoch at which it was reached.
    """
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[0]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slight different results between PyTorch 1.5 an 1.6
            Spacingd(
                keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32
            ),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    # fixed seed for the random transforms of the training pipeline
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[1]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slight different results between PyTorch 1.5 an 1.6
            Spacingd(
                keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32
            ),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    # NOTE(review): other scripts in this file spell the threshold as
    # AsDiscrete(threshold=0.5); confirm which form the installed MONAI expects.
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    max_epochs = 6
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = []
    metric_values = []
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # sliding window size and number of windows evaluated per forward pass
    sw_batch_size, roi_size = 4, (96, 96, 96)
    # number of optimisation steps per epoch; constant, so computed once
    epoch_len = len(train_ds) // train_loader.batch_size

    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"Epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

            if (epoch + 1) % val_interval == 0:
                with eval_mode(model):
                    # keep the last validation sample around for the image logs below
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        # decollate prediction into a list and execute post processing for every item
                        val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
                        # compute metrics
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), model_filename)
                        print("saved new best metric model")
                    print(
                        f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                        f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    finally:
        # close the event writer even when training aborts with an exception
        writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
def main(tempdir):
    """Evaluate a saved 3D UNet on freshly generated synthetic NIfTI volumes.

    Five random image/segmentation pairs are written into ``tempdir``,
    segmented with sliding-window inference, saved through ``SaveImage``
    (configured with ``output_dir="./output"``) and scored with the mean Dice
    metric, which is printed at the end.

    The weights are read from ``best_metric_model_segmentation3d_dict.pth``
    relative to the current working directory.

    Args:
        tempdir: directory that receives the synthetic NIfTI files; it is not
            created by this function.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    # sorting keeps image i paired with segmentation i
    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs 1 image per iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")

    # single-device list; swap in the commented line to use all available GPUs
    devices = [torch.device("cuda" if torch.cuda.is_available() else "cpu")]
    # devices = get_devices_spec(None)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(devices[0])

    # map_location lets a checkpoint written on a GPU be restored when this
    # script falls back to the CPU (or runs on a different GPU index)
    model.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth", map_location=devices[0])
    )

    # if we have multiple GPUs, set data parallel to execute sliding window inference
    if len(devices) > 1:
        model = torch.nn.DataParallel(model, device_ids=devices)

    # sliding window size and number of windows evaluated per forward pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # split the batch into single items before post-processing
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            meta_data = decollate_batch(val_data["img_meta_dict"])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            # each prediction is saved together with the meta data of its input image
            for val_output, data in zip(val_outputs, meta_data):
                saver(val_output, data)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()