def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"):
    """Run sliding-window inference on a single image and return the expected output path.

    A freshly constructed (untrained) 3D UNet is applied to ``img_name`` with
    ``sliding_window_inference``. The thresholded prediction is written by
    ``SegmentationSaver`` under ``output_dir``.

    Args:
        batch_size: number of windows evaluated per forward pass (passed as the
            sliding-window batch size).
        img_name: path to the input image; the suffix stripping below assumes it
            ends with ``.nii.gz``.
        seg_name: path to the paired segmentation; it is loaded with the image but
            not used for inference.
        output_dir: directory the saver writes into.
        device: torch device string for the network and the inference.

    Returns:
        ``<output_dir>/<stem>/<stem>_seg.nii.gz`` where ``<stem>`` is the image
        file name without ``.nii.gz``.
    """
    dataset = NiftiDataset(
        [img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False
    )
    data_loader = DataLoader(dataset, batch_size=1, pin_memory=torch.cuda.is_available())

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)
    window_size = (16, 32, 48)

    def _sliding_window_processor(_engine, batch):
        model.eval()
        # each batch is (image, segmentation, meta_data); only the image is inferred
        image, _seg, _meta = batch
        with torch.no_grad():
            probs = sliding_window_inference(image.to(device), window_size, batch_size, model, device=device)
        return predict_segmentation(probs)

    engine = Engine(_sliding_window_processor)
    saver = SegmentationSaver(
        output_dir=output_dir,
        output_ext=".nii.gz",
        output_postfix="seg",
        batch_transform=lambda x: x[2],  # the saver reads meta data from batch[2]
    )
    saver.attach(engine)
    engine.run(data_loader)

    suffix = ".nii.gz"
    stem = os.path.basename(img_name)[: -len(suffix)]
    return os.path.join(output_dir, stem, f"{stem}_seg.nii.gz")
def main():
    """Evaluate a trained 3D DenseNet121 gender classifier on IXI T1 images.

    Loads weights from ``best_metric_model.pth``, prints the accuracy over the
    listed images and writes per-image predictions with ``CSVSaver`` to
    ``./output``. Runs on ``cuda:0`` when CUDA is available, otherwise on CPU.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0])

    # Define transforms for image
    val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()])

    # Define nifti dataset; image_only=False makes each item carry its meta data as the 3rd element
    val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False)
    # create a validation data loader
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    # Create DenseNet121; fall back to CPU so the script also runs without a GPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    # map_location allows a checkpoint saved on GPU to be loaded on a CPU-only machine
    model.load_state_dict(torch.load("best_metric_model.pth", map_location=device))
    model.eval()
    with torch.no_grad():
        num_correct = 0.0
        metric_count = 0
        saver = CSVSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = model(val_images).argmax(dim=1)
            value = torch.eq(val_outputs, val_labels)
            metric_count += len(value)
            num_correct += value.sum().item()
            saver.save_batch(val_outputs, val_data[2])
        metric = num_correct / metric_count
        print("evaluation metric:", metric)
        saver.finalize()
def main():
    """Evaluate a trained 3D UNet on freshly generated synthetic volumes.

    Writes five random image/segmentation pairs to a temporary directory,
    runs sliding-window inference with weights from ``best_metric_model.pth``,
    prints the mean Dice and saves thresholded predictions to ``./output``.
    The temporary directory is removed even if evaluation raises.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # TemporaryDirectory guarantees cleanup; a bare mkdtemp()/rmtree() leaked on errors
    with tempfile.TemporaryDirectory() as tempdir:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        segtrans = Compose([AddChannel(), ToTensor()])
        val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
        # sliding window inference for one image at every iteration
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())

        # fall back to CPU so the script also runs without a GPU
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        # map_location allows a checkpoint saved on GPU to be loaded on a CPU-only machine
        model.load_state_dict(torch.load("best_metric_model.pth", map_location=device))
        model.eval()

        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            saver = NiftiSaver(output_dir="./output")
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                value = compute_meandice(
                    y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, add_sigmoid=True
                )
                metric_count += len(value)
                metric_sum += value.sum().item()
                # binarize the logits before saving
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                saver.save_batch(val_outputs, val_data[2])
            metric = metric_sum / metric_count
            print("evaluation metric:", metric)
def main(tempdir):
    """Evaluate a trained 3D UNet on synthetic volumes written to ``tempdir``.

    Args:
        tempdir: existing directory that receives five synthetic image/segmentation
            pairs; the caller owns its lifetime.

    Prints the mean Dice computed by ``DiceMetric`` and saves thresholded
    predictions to ``./output``. Runs on CUDA when available, otherwise on CPU.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # map_location is required for the CPU fallback above: without it a checkpoint
    # saved on GPU fails to deserialize when CUDA is unavailable
    model.load_state_dict(torch.load("best_metric_model_segmentation3d_array.pth", map_location=device))
    model.eval()

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            # weight each batch-mean by the batch length so the final ratio is a per-sample mean
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            # binarize the logits before saving
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data[2])
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
train_transforms = Compose([ ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), ToTensor() ]) val_transforms = Compose( [ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()]) # Define nifti dataset, data loader check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms) check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) im, label = monai.utils.misc.first(check_loader) print(type(im), im.shape, label) # create a training data loader train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms) train_loader = DataLoader(train_ds, batch_size=2, shuffle=True,
def main():
    """Evaluate a UNet checkpoint on synthetic volumes using an ignite ``Engine``.

    Generates five synthetic image/segmentation pairs in a temporary directory,
    loads ``./runs/net_checkpoint_50.pth`` via ``CheckpointLoader``, runs
    sliding-window inference per image, reports ``Mean_Dice`` through
    ``StatsHandler`` and writes predicted segmentations with ``SegmentationSaver``.
    The temporary directory is removed even if evaluation raises.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # TemporaryDirectory guarantees cleanup; a bare mkdtemp()/rmtree() leaked on errors
    with tempfile.TemporaryDirectory() as tempdir:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        segtrans = Compose([AddChannel(), ToTensor()])
        ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

        device = torch.device("cuda:0")
        net = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        net.to(device)

        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        def _sliding_window_processor(engine, batch):
            net.eval()
            with torch.no_grad():
                val_images, val_labels = batch[0].to(device), batch[1].to(device)
                seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                # engine output is (prediction logits, ground truth) for MeanDice
                return seg_probs, val_labels

        evaluator = Engine(_sliding_window_processor)

        # add evaluation metric to the evaluator engine
        MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice")

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
        val_stats_handler = StatsHandler(
            name="evaluator",
            output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        )
        val_stats_handler.attach(evaluator)

        # for the array data format, assume the 3rd item of batch data is the meta_data
        # NOTE(review): "tempdir" here is a literal relative directory, not the temporary
        # directory above (which is deleted on exit) -- confirm this is intended.
        file_saver = SegmentationSaver(
            output_dir="tempdir",
            output_ext=".nii.gz",
            output_postfix="seg",
            name="evaluator",
            batch_transform=lambda x: x[2],
            output_transform=lambda output: predict_segmentation(output[0]),
        )
        file_saver.attach(evaluator)

        # the model was trained by "unet_training_array" example
        ckpt_loader = CheckpointLoader(load_path="./runs/net_checkpoint_50.pth", load_dict={"net": net})
        ckpt_loader.attach(evaluator)

        # sliding window inference for one image at every iteration
        loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
        evaluator.run(loader)
def main():
    """Train a 3D UNet on synthetic data with a plain PyTorch loop.

    Generates 40 random image/segmentation pairs in a temporary directory,
    trains on the first 20 and validates on the last 20 every ``val_interval``
    epochs using sliding-window inference. The model with the best mean Dice is
    saved to ``best_metric_model.pth``. Losses, metrics and sample volumes are
    logged to TensorBoard. The temporary directory is removed even if training
    raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs;
    # TemporaryDirectory guarantees cleanup even when training fails
    with tempfile.TemporaryDirectory() as tempdir:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        train_imtrans = Compose(
            [
                ScaleIntensity(),
                AddChannel(),
                RandSpatialCrop((96, 96, 96), random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 2)),
                ToTensor(),
            ]
        )
        train_segtrans = Compose(
            [
                AddChannel(),
                RandSpatialCrop((96, 96, 96), random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 2)),
                ToTensor(),
            ]
        )
        val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        val_segtrans = Compose([AddChannel(), ToTensor()])

        # define nifti dataset, data loader
        check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader
        train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
        train_loader = DataLoader(
            train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available()
        )
        # create a validation data loader
        val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        device = torch.device("cuda:0")
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(do_sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        max_epochs = 5
        val_interval = 2
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        # optimizer steps per epoch; len(DataLoader) rounds up (drop_last=False), so the
        # TensorBoard global step below never collides across epochs
        epoch_len = len(train_loader)
        # sliding window size and batch size for validation inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        writer = SummaryWriter()
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    # keep the last validation batch for TensorBoard plotting
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(
                            y_pred=val_outputs,
                            y=val_labels,
                            include_background=True,
                            to_onehot_y=False,
                            add_sigmoid=True,
                        )
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def main():
    """Train a 3D DenseNet121 gender classifier on IXI T1 images with ignite.

    Uses the first 10 images for training and the last 10 for validation
    (validation runs every ``validation_every_n_epochs`` epochs). Checkpoints of
    the network and optimizer are written to ``./runs/``. Training stops early
    when validation accuracy does not improve for 4 evaluations.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # 2 binary labels for gender classification: man and woman.
    # int64 is pinned because CrossEntropyLoss requires Long targets and the numpy
    # default integer is platform dependent (int32 on Windows with numpy < 2).
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

    # define transforms
    train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), ToTensor()])
    val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()])

    # define nifti dataset, data loader
    check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
    im, label = monai.utils.misc.first(check_loader)
    print(type(im), im.shape, label)

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2)
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device("cuda:0")

    # ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt}
    )

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1
    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}
    # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # create a validation data loader
    val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())

    train_epochs = 30
    trainer.run(train_loader, max_epochs=train_epochs)
def main():
    """Run a DenseNet121 checkpoint over IXI T1 images with an ignite evaluator.

    The network weights come from ``./runs/net_checkpoint_20.pth`` through
    ``CheckpointLoader``. Accuracy is reported by ``StatsHandler`` and per-image
    class predictions are written by ``ClassificationSaver``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0])

    # image_only=False keeps the meta data as the 3rd element of every batch
    val_ds = NiftiDataset(
        image_files=images,
        labels=labels,
        transform=Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()]),
        image_only=False,
    )

    net = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2)
    device = torch.device("cuda:0")

    metric_name = "Accuracy"
    val_metrics = {metric_name: Accuracy()}

    def prepare_batch(batch, device=None, non_blocking=False):
        # hand ignite only (image, label); the meta data item is dropped here
        image_and_label = (batch[0], batch[1])
        return _prepare_batch(image_and_label, device, non_blocking)

    # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration
    evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

    # report metrics only; per-iteration output printing is disabled
    stats_handler = StatsHandler(name="evaluator", output_transform=lambda x: None)
    stats_handler.attach(evaluator)

    # meta data is taken from batch[2]; predictions are the argmax over the class axis
    prediction_saver = ClassificationSaver(
        output_dir="tempdir",
        batch_transform=lambda batch: batch[2],
        output_transform=lambda output: output[0].argmax(1),
    )
    prediction_saver.attach(evaluator)

    # the model was trained by "densenet_training_array" example
    checkpoint_loader = CheckpointLoader(load_path="./runs/net_checkpoint_20.pth", load_dict={"net": net})
    checkpoint_loader.attach(evaluator)

    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
    evaluator.run(val_loader)
def main():
    """Train a 3D DenseNet121 binary classifier on masked volumes under ``data_dir``.

    Each case directory contributes ``image_masked.nii.gz``. In listing order, the
    first 13 positive and first 11 negative cases go to validation and all others
    to training. Validation accuracy is computed every ``val_interval`` epochs, and
    the best model is saved to
    ``/home/marafath/scratch/saved_models/best_metric_model_d121.pth``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    data_dir = '/home/marafath/scratch/eu_data'
    # one label per case, indexed by position in os.listdir(data_dir) below
    labels = np.load('eu_labels.npy')

    train_images = []
    train_labels = []
    val_images = []
    val_labels = []
    n_count = 0  # negatives assigned to validation so far
    p_count = 0  # positives assigned to validation so far
    # NOTE(review): os.listdir order is filesystem dependent; this pairing is only
    # correct if eu_labels.npy was written in the same listing order -- confirm.
    for idx, case in enumerate(os.listdir(data_dir)):
        label = labels[idx]
        image_path = os.path.join(data_dir, case, 'image_masked.nii.gz')
        if p_count < 13 and label == 1:
            val_images.append(image_path)
            val_labels.append(label)
            p_count += 1
        elif n_count < 11 and label == 0:
            val_images.append(image_path)
            val_labels.append(label)
            n_count += 1
        else:
            train_images.append(image_path)
            train_labels.append(label)

    # Define transforms
    train_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        SpatialPad((256, 256, 92), mode='constant'),
        Resize((256, 256, 92)),
        ToTensor(),
    ])
    val_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        SpatialPad((256, 256, 92), mode='constant'),
        Resize((256, 256, 92)),
        ToTensor(),
    ])

    # create a training data loader
    train_ds = NiftiDataset(image_files=train_images, labels=train_labels, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = NiftiDataset(image_files=val_images, labels=val_labels, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device('cuda:0')
    model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # finetuning:
    # model.load_state_dict(torch.load('best_metric_model_d121.pth'))

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    max_epochs = 100  # Number of epoch
    # optimizer steps per epoch; len(DataLoader) rounds up (drop_last=False), so the
    # TensorBoard global step below never collides across epochs
    epoch_len = len(train_loader)
    for epoch in range(max_epochs):
        print('-' * 10)
        print('epoch {}/{}'.format(epoch + 1, max_epochs))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            # CrossEntropyLoss needs Long targets
            inputs = batch_data[0].to(device)
            targets = batch_data[1].to(device=device, dtype=torch.int64)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item()))
            writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                num_correct = 0.0
                metric_count = 0
                for val_data in val_loader:
                    batch_images, batch_labels = val_data[0].to(device), val_data[1].to(device)
                    val_outputs = model(batch_images)
                    value = torch.eq(val_outputs.argmax(dim=1), batch_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                # per-epoch snapshot, disabled:
                # torch.save(model.state_dict(), 'model_d121_epoch_{}.pth'.format(epoch + 1))
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(),
                        '/home/marafath/scratch/saved_models/best_metric_model_d121.pth',
                    )
                    print('saved new best metric model')
                print(
                    'current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'.format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar('val_accuracy', metric, epoch + 1)
    print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
    writer.close()
n = nib.Nifti1Image(im, np.eye(4)) nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) n = nib.Nifti1Image(seg, np.eye(4)) nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) # define transforms for image and segmentation imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) segtrans = Compose([AddChannel(), ToTensor()]) ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) device = torch.device('cuda:0') net = UNet( dimensions=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ) net.to(device) # define sliding window size and batch size for windows inference
def main():
    """Train a 3D DenseNet121 gender classifier on IXI T1 images with a plain PyTorch loop.

    Trains on the first 10 images and validates on the last 10 every
    ``val_interval`` epochs. The model with the best validation accuracy is
    saved to ``best_metric_model.pth``, and losses/accuracy are logged to
    TensorBoard.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # Start from the filesystem root: os.sep.join(["workspace", ...]) produced the
    # relative path "workspace/data/...", unlike the "/workspace/data/..." used elsewhere.
    data_dir = os.path.join(os.sep, "workspace", "data", "medical", "ixi", "IXI-T1")
    file_names = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.path.join(data_dir, name) for name in file_names]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

    # Define transforms
    train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), ToTensor()])
    val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()])

    # Define nifti dataset, data loader
    check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
    im, label = monai.utils.misc.first(check_loader)
    print(type(im), im.shape, label)

    # create a training data loader
    train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda:0")
    model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)

    # start a typical PyTorch training
    max_epochs = 5
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # optimizer steps per epoch; len(DataLoader) rounds up (drop_last=False), so the
    # TensorBoard global step below never collides across epochs
    epoch_len = len(train_loader)
    writer = SummaryWriter()
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                num_correct = 0.0
                metric_count = 0
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    val_outputs = model(val_images)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_accuracy", metric, epoch + 1)
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
RandSpatialCrop((96, 96, 96), random_size=False), RandRotate90(prob=0.5, spatial_axes=(0, 2)), ToTensor() ]) train_segtrans = Compose([ AddChannel(), RandSpatialCrop((96, 96, 96), random_size=False), RandRotate90(prob=0.5, spatial_axes=(0, 2)), ToTensor() ]) val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) val_segtrans = Compose([AddChannel(), ToTensor()]) # define nifti dataset, data loader check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans) check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) im, seg = monai.utils.misc.first(check_loader) print(im.shape, seg.shape) # create a training data loader train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans) train_loader = DataLoader(train_ds, batch_size=4,
def main():
    """Train a 3D UNet segmentation network on synthetic NIfTI data with ignite.

    Forty random 128^3 image/mask pairs are written to a fresh temporary
    directory. The first 20 pairs are used for training (random 96^3 crops) and
    the last 20 for validation (resized to 96^3). Training uses Dice loss and
    Adam for up to 30 epochs. It saves a checkpoint of the network and optimizer
    to ``./runs/`` after every epoch and validates with Mean Dice after every
    epoch. Early stopping is driven by the validation metric, and stats and
    images are logged to stdout and TensorBoard.

    The temporary directory is always removed afterwards, even when data
    generation or training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        # Both lists are sorted with the same lexicographic key, so image k and
        # mask k still belong together (even though 'im10' sorts before 'im2').
        images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

        # define transforms for image and segmentation
        train_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            ToTensor()
        ])
        train_segtrans = Compose([
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            ToTensor()
        ])
        val_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            Resize((96, 96, 96)),
            ToTensor()
        ])
        val_segtrans = Compose([
            AddChannel(),
            Resize((96, 96, 96)),
            ToTensor()
        ])

        # define nifti dataset, data loader; fetch one batch to sanity-check shapes
        check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader (first 20 pairs)
        train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
        train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8,
                                  pin_memory=torch.cuda.is_available())
        # create a validation data loader (last 20 pairs)
        val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        loss = monai.losses.DiceLoss(do_sigmoid=True)
        lr = 1e-3
        opt = torch.optim.Adam(net.parameters(), lr)
        # fall back to CPU so the example still runs on machines without a GPU
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
        # user can add output_transform to return other values, like: y_pred, y, etc.
        trainer = create_supervised_trainer(net, opt, loss, device, False)

        # adding checkpoint handler to save models (network params and optimizer stats) during training
        checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'net': net, 'opt': opt})

        # StatsHandler prints loss at every iteration and prints metrics at every epoch.
        # No metrics are set for the trainer here, so it just prints the loss. Users can
        # customize the print functions, and can use output_transform to convert
        # engine.state.output if it is not a loss value.
        train_stats_handler = StatsHandler(name='trainer')
        train_stats_handler.attach(trainer)

        # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
        train_tensorboard_stats_handler = TensorBoardStatsHandler()
        train_tensorboard_stats_handler.attach(trainer)

        validation_every_n_epochs = 1
        # Set parameters for validation
        metric_name = 'Mean_Dice'
        # add evaluation metric to the evaluator engine
        val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}

        # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        evaluator = create_supervised_evaluator(net, val_metrics, device, True)

        @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
        def run_validation(engine):
            evaluator.run(val_loader)

        # add early stopping handler to evaluator; it stops the *trainer* when the
        # validation metric has not improved for `patience` validation rounds
        early_stopper = EarlyStopping(patience=4,
                                      score_function=stopping_fn_from_metric(metric_name),
                                      trainer=trainer)
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

        # add stats event handler to print validation stats via evaluator
        val_stats_handler = StatsHandler(
            name='evaluator',
            output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
        val_stats_handler.attach(evaluator)

        # add handler to record metrics to TensorBoard at every validation epoch
        val_tensorboard_stats_handler = TensorBoardStatsHandler(
            output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
        val_tensorboard_stats_handler.attach(evaluator)

        # Add a handler that draws, for the last batch, the first image, its label
        # and the model output. The 3D output is drawn as a GIF along the depth axis
        # at every validation epoch.
        val_tensorboard_image_handler = TensorBoardImageHandler(
            batch_transform=lambda batch: (batch[0], batch[1]),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch
        )
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

        train_epochs = 30
        trainer.run(train_loader, train_epochs)
    finally:
        # remove the synthetic data even if training raised or was interrupted
        shutil.rmtree(tempdir)
'/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' ] # 2 binary labels for gender classification: man and woman labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0]) # define transforms for image val_transforms = Compose( [ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()]) # define nifti dataset val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False) # create DenseNet121 net = monai.networks.nets.densenet.densenet121( spatial_dims=3, in_channels=1, out_channels=2, ) device = torch.device('cuda:0') metric_name = 'Accuracy' # add evaluation metric to the evaluator engine val_metrics = {metric_name: Accuracy()} def prepare_batch(batch, device=None, non_blocking=False):
def test_dataset(self):
    """Check NiftiDataset's loading modes against reference arrays.

    The reference arrays are small random binary volumes saved as NIfTI files.
    The cases covered are:

    - image only, with a custom dtype;
    - image with meta data;
    - image plus segmentation;
    - image with a transform;
    - a seg_transform given without seg_files, which should fail;
    - image, segmentation and labels;
    - randomized transforms that must stay synchronized between image and
      segmentation.

    The temporary directory is removed even when an assertion fails.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        full_names, ref_data = [], []
        for filename in FILENAMES:
            test_image = np.random.randint(0, 2, size=(4, 4, 4))
            ref_data.append(test_image)
            save_path = os.path.join(tempdir, filename)
            full_names.append(save_path)
            nib.save(nib.Nifti1Image(test_image, np.eye(4)), save_path)

        # default loading no meta
        dataset = NiftiDataset(full_names)
        for d, ref in zip(dataset, ref_data):
            np.testing.assert_allclose(d, ref, atol=1e-3)

        # loading no meta, with an explicit float16 dtype
        dataset = NiftiDataset(full_names, dtype=np.float16)
        for d, _ in zip(dataset, ref_data):
            self.assertEqual(d.dtype, np.float16)

        # loading with meta, no transform
        dataset = NiftiDataset(full_names, image_only=False)
        for d_tuple, ref in zip(dataset, ref_data):
            d, meta = d_tuple
            np.testing.assert_allclose(d, ref, atol=1e-3)
            np.testing.assert_allclose(meta["original_affine"], np.eye(4))

        # loading image/label, no meta
        dataset = NiftiDataset(full_names, seg_files=full_names, image_only=True)
        for d_tuple, ref in zip(dataset, ref_data):
            img, seg = d_tuple
            np.testing.assert_allclose(img, ref, atol=1e-3)
            np.testing.assert_allclose(seg, ref, atol=1e-3)

        # loading image only, with an image transform, no meta
        dataset = NiftiDataset(full_names, transform=lambda x: x + 1, image_only=True)
        for d, ref in zip(dataset, ref_data):
            np.testing.assert_allclose(d, ref + 1, atol=1e-3)

        # set seg transform, but no seg_files: accessing an item must fail
        with self.assertRaises(TypeError):
            dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True)
            _ = dataset[0]

        # loading image/label, with separate image and seg transforms, with meta
        dataset = NiftiDataset(
            full_names, transform=lambda x: x + 1, seg_files=full_names, seg_transform=lambda x: x + 2, image_only=False
        )
        for d_tuple, ref in zip(dataset, ref_data):
            img, seg, meta = d_tuple
            np.testing.assert_allclose(img, ref + 1, atol=1e-3)
            np.testing.assert_allclose(seg, ref + 2, atol=1e-3)
            np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)

        # loading image/seg/class label, with meta
        # NOTE(review): labels=[1, 2, 3] assumes FILENAMES has (at most) three entries.
        dataset = NiftiDataset(
            full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False
        )
        for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)):
            img, seg, label, meta = d_tuple
            np.testing.assert_allclose(img, ref + 1, atol=1e-3)
            np.testing.assert_allclose(seg, ref, atol=1e-3)
            np.testing.assert_allclose(idx + 1, label)
            np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)

        # loading image/label, with sync. transform: the same random transform applied
        # to image and seg must yield identical outputs that differ from the input
        dataset = NiftiDataset(
            full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False
        )
        for d_tuple, ref in zip(dataset, ref_data):
            img, seg, meta = d_tuple
            np.testing.assert_allclose(img, seg, atol=1e-3)
            self.assertFalse(np.allclose(img, ref))
            np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)
def main(tempdir):
    """Segment synthetic volumes with a sliding-window UNet and report Mean Dice.

    Five random 128^3 image/segmentation pairs are written into ``tempdir``.
    Each image is passed through a UNet using sliding-window inference, and the
    result is post-processed with sigmoid followed by thresholding. The
    predictions are scored against the stored segmentation with Mean Dice and
    saved as NIfTI files. The network weights come from a checkpoint that is
    loaded when the evaluator starts. The final ignite engine state is printed.

    Args:
        tempdir: an existing directory that receives the synthetic input files.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for idx in range(5):
        volume, mask = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        nib.save(nib.Nifti1Image(volume, np.eye(4)), os.path.join(tempdir, f"im{idx:d}.nii.gz"))
        nib.save(nib.Nifti1Image(mask, np.eye(4)), os.path.join(tempdir, f"seg{idx:d}.nii.gz"))

    images, segs = (sorted(glob(os.path.join(tempdir, pattern))) for pattern in ("im*.nii.gz", "seg*.nii.gz"))

    # image: rescale intensity + add channel dim; segmentation: add channel dim only
    ds = NiftiDataset(
        images,
        segs,
        transform=Compose([ScaleIntensity(), AddChannel(), ToTensor()]),
        seg_transform=Compose([AddChannel(), ToTensor()]),
        image_only=False,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # window size and number of windows evaluated per forward pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    # turn raw network outputs into binary masks
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            preds = post_trans(sliding_window_inference(inputs, roi_size, sw_batch_size, net))
        return preds, targets

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice().attach(evaluator, "Mean_Dice")

    # Only the metrics are printed; the per-iteration output is suppressed
    # because the evaluator produces no loss value.
    StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,
    ).attach(evaluator)

    # The third item of each batch is the meta data dict (image_only=False above).
    # NOTE(review): output_dir is the literal relative path "tempdir" (under the
    # current working directory), not the `tempdir` argument; confirm this is intended.
    SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda x: x[2],
        output_transform=lambda output: output[0],
    ).attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_loader = CheckpointLoader(load_path="./runs_array/net_checkpoint_100.pt", load_dict={"net": net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    print(evaluator.run(loader))