def test_random_shape(self, input_param, input_shape, expected_shape):
    """Seeded RandSpatialCrop must produce ``expected_shape`` for every array backend."""
    for array_type in TEST_NDARRAYS_ALL:
        with self.subTest(im_type=array_type):
            crop = RandSpatialCrop(**input_param)
            crop.set_random_state(seed=123)
            # random 0/1 data of the requested shape, wrapped in the backend type under test
            data = array_type(np.random.randint(0, 2, input_shape))
            cropped = crop(data)
            self.assertTupleEqual(cropped.shape, expected_shape)
def run_test(batch_size=64, train_steps=200, device=torch.device("cuda:0")):
    """Run one training pass of a small 2D UNet on synthetic image/segmentation data.

    Args:
        batch_size: number of samples per batch fed to the network.
        train_steps: number of synthetic samples the dataset reports via ``__len__``.
        device: device the network and every batch are moved to.

    Returns:
        ``(mean_loss, n_batches)``: the Dice loss averaged over the batches seen and
        the number of batches that were processed.
    """

    class _TestBatch(Dataset):
        def __init__(self, transforms):
            self.transforms = transforms

        def __getitem__(self, _unused_id):
            image, label = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
            seed = np.random.randint(2147483647)
            # re-seed before each call so image and label get the same random crop/rotation
            self.transforms.set_random_state(seed=seed)
            image = self.transforms(image)
            self.transforms.set_random_state(seed=seed)
            label = self.transforms(label)
            return image, label

        def __len__(self):
            return train_steps

    net = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_fn = DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(net.parameters(), 1e-2)
    pipeline = Compose([AddChannel(), ScaleIntensity(), RandSpatialCrop((96, 96), random_size=False), RandRotate90(), ToTensor()])
    loader = DataLoader(_TestBatch(pipeline), batch_size=batch_size, shuffle=True)

    net.train()
    total_loss = 0
    n_batches = 0
    for images, labels in loader:
        n_batches += 1
        optimizer.zero_grad()
        prediction = net(images.to(device))
        batch_loss = loss_fn(prediction, labels.to(device))
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()
    return total_loss / n_batches, n_batches
def test_value(self, input_param, input_data):
    """Each array type's crop must equal the expected window of ``input_data``."""
    for to_type in TEST_NDARRAYS:
        cropper = RandSpatialCrop(**input_param)
        result = cropper(to_type(input_data))
        # window per spatial dim: starts at 2 - size // 2 and spans the drawn crop size
        bounds = [(2 - size // 2, 2 + size - size // 2) for size in cropper._size]
        (y0, y1), (x0, x1) = bounds[0], bounds[1]
        expected = input_data[:, y0:y1, x0:x1]
        assert_allclose(result, expected, type_test=False)
def test_value(self, input_param, input_data):
    """For every backend, the crop must equal the expected window of ``input_data``."""
    for array_type in TEST_NDARRAYS_ALL:
        with self.subTest(im_type=array_type):
            cropper = RandSpatialCrop(**input_param)
            result = cropper(array_type(input_data))
            # window per spatial dim: starts at 2 - size // 2 and spans the drawn crop size
            bounds = [(2 - size // 2, 2 + size - size // 2) for size in cropper._size]
            (y0, y1), (x0, x1) = bounds[0], bounds[1]
            expected = input_data[:, y0:y1, x0:x1]
            assert_allclose(result, expected, type_test="tensor")
def main(tempdir):
    """Train and validate a 2D UNet on synthetic PNG image/mask pairs.

    Forty synthetic image/segmentation pairs are written into ``tempdir``. The caller
    owns that directory and is responsible for removing it. The first 20 pairs train the
    network and the last 20 validate it every ``val_interval`` epochs. The best model by
    validation mean Dice is saved to ``best_metric_model_segmentation2d_array.pth``.
    Losses, metrics and sample images are logged to TensorBoard.

    Args:
        tempdir: existing directory in which the synthetic PNG files are written.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # write 40 random image/mask pairs into the caller-provided directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        # create_test_image_2d rescales the image to [0, 1]. Stretch it to the uint8
        # range before casting, otherwise nearly every pixel truncates to 0.
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        # the mask is cast as-is; the segmentation transforms below apply no intensity scaling
        Image.fromarray(seg.astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    train_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    train_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    val_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        ToTensor()
    ])
    val_segtrans = Compose(
        [LoadImage(image_only=True), AddChannel(), ToTensor()])

    # define array dataset, data loader
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds,
                              batch_size=10,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20],
                            train_segtrans)
    train_loader = DataLoader(train_ds,
                              batch_size=4,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:],
                          val_segtrans)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose(
        [Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    max_epochs = 10
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Batches per epoch. len() of the loader counts a final partial batch, so every
    # TensorBoard point below gets a distinct global step.
    epoch_len = len(train_loader)
    roi_size = (96, 96)
    sw_batch_size = 4
    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data[0].to(
                            device), val_data[1].to(device)
                        val_outputs = sliding_window_inference(
                            val_images, roi_size, sw_batch_size, model)
                        val_outputs = post_trans(val_outputs)
                        value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                        # numel() matches len() for a one-element 1-D result and,
                        # unlike len(), also works if the reduced metric is 0-d
                        n_values = value.numel()
                        metric_count += n_values
                        metric_sum += value.item() * n_values
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(),
                                   "best_metric_model_segmentation2d_array.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, metric, best_metric, best_metric_epoch))
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(
            f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
    finally:
        # release the TensorBoard writer even if training raises
        writer.close()
# Each case is (sample container type, collate function, random transform).
# dict is paired with the dictionary-based ``*d`` transforms and list with the array
# transforms. Every transform is configured so its output spatial size can vary:
# random_size=True, keep_size=False, or random 90-degree rotations.
TESTS.extend([
    (dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)),
    (dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)),
    (dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)),
    (dict, pad_collate, RandRotate90d("image", prob=1, max_k=2)),
    (list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True)),
    (list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False)),
    (list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)),
    (list, pad_collate, RandRotate90(prob=1, max_k=2)),
])


class _Dataset(torch.utils.data.Dataset):
    """Dataset that stores images, labels and a transform object."""

    def __init__(self, images, labels, transforms):
        self.images, self.labels, self.transforms = images, labels, transforms
def main():
    """Train a 3D UNet on synthetic NIfTI volumes with PyTorch Ignite.

    Forty synthetic image/segmentation volumes are written to a new temporary
    directory. The first 20 train the network and the last 20 are evaluated every
    ``validation_every_n_epochs`` epochs. Checkpoints are written to ``./runs/``, and an
    EarlyStopping handler with ``patience=4`` watches the validation mean Dice. The
    temporary directory is removed whether training finishes or raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

        # define transforms for image and segmentation
        train_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            ToTensor()
        ])
        train_segtrans = Compose([
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            ToTensor()
        ])
        val_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            Resize((96, 96, 96)),
            ToTensor()
        ])
        val_segtrans = Compose([
            AddChannel(),
            Resize((96, 96, 96)),
            ToTensor()
        ])

        # define nifti dataset, data loader
        check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader
        train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
        train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8,
                                  pin_memory=torch.cuda.is_available())
        # create a validation data loader
        val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

        # fall back to the CPU when CUDA is unavailable, matching the pin_memory checks above
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Create UNet, DiceLoss and Adam optimizer. The network is placed on the device
        # before the optimizer is built: newer Ignite versions do not move the model
        # inside create_supervised_trainer.
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss = monai.losses.DiceLoss(do_sigmoid=True)
        lr = 1e-3
        opt = torch.optim.Adam(net.parameters(), lr)

        # ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
        # user can add output_transform to return other values, like: y_pred, y, etc.
        trainer = create_supervised_trainer(net, opt, loss, device, False)

        # adding checkpoint handler to save models (network params and optimizer stats) during training
        checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'net': net, 'opt': opt})

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't set metrics for trainer here, so just print loss, user can also customize print functions
        # and can use output_transform to convert engine.state.output if it's not a loss value
        train_stats_handler = StatsHandler(name='trainer')
        train_stats_handler.attach(trainer)

        # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
        train_tensorboard_stats_handler = TensorBoardStatsHandler()
        train_tensorboard_stats_handler.attach(trainer)

        validation_every_n_epochs = 1
        # Set parameters for validation
        metric_name = 'Mean_Dice'
        # add evaluation metric to the evaluator engine
        val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}

        # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        evaluator = create_supervised_evaluator(net, val_metrics, device, True)

        @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
        def run_validation(engine):
            evaluator.run(val_loader)

        # add early stopping handler to evaluator
        early_stopper = EarlyStopping(patience=4,
                                      score_function=stopping_fn_from_metric(metric_name),
                                      trainer=trainer)
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

        # add stats event handler to print validation stats via evaluator
        val_stats_handler = StatsHandler(
            name='evaluator',
            output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
        val_stats_handler.attach(evaluator)

        # add handler to record metrics to TensorBoard at every validation epoch
        val_tensorboard_stats_handler = TensorBoardStatsHandler(
            output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
            global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
        val_tensorboard_stats_handler.attach(evaluator)

        # add handler to draw the first image and the corresponding label and model output in the last batch
        # here we draw the 3D output as GIF format along Depth axis, at every validation epoch
        val_tensorboard_image_handler = TensorBoardImageHandler(
            batch_transform=lambda batch: (batch[0], batch[1]),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch
        )
        evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

        train_epochs = 30
        trainer.run(train_loader, train_epochs)
    finally:
        # remove the synthetic data even if generation or training raised
        shutil.rmtree(tempdir)
def test_random_shape(self, input_param, input_data, expected_shape):
    """With a fixed seed, the random crop's output shape must equal ``expected_shape``."""
    crop = RandSpatialCrop(**input_param)
    crop.set_random_state(seed=123)
    self.assertTupleEqual(crop(input_data).shape, expected_shape)
def test_shape(self, input_param, input_data, expected_shape):
    """The cropped output shape must equal ``expected_shape``."""
    transform = RandSpatialCrop(**input_param)
    output = transform(input_data)
    self.assertTupleEqual(output.shape, expected_shape)
def _define_training_transforms(self):
    """Define and initialize all training data transforms.

    Defines:
      * training set images transform
      * training set masks transform
      * validation set images transform
      * validation set masks transform
      * validation set images post-transform
      * test set images transform
      * test set masks transform
      * test set images post-transform

    The mask loader and mask transform depend on ``self._mask_type``:
      * TIFF_LABELS and NUMPY_LABELS masks are converted with
        ``ToOneHot(num_classes=self._out_channels)``.
      * Other supported types pass through an identity transform.

    @return True once all transforms have been defined.
    @raise Exception if the mask type is UNKNOWN or the not-yet-supported H5_ONE_HOT.
    """
    if self._mask_type == MaskType.UNKNOWN:
        raise Exception("The mask type is unknown. Cannot continue!")

    # Depending on the mask type, we will need to adapt the Mask Loader
    # and Transform. We start by initializing the most common types.
    MaskLoader = LoadMask(self._mask_type)
    # Must be an Identity *instance*. Compose calls each entry with the data, so
    # storing the class would construct Identity(mask) instead of passing the mask through.
    MaskTransform = Identity()

    # Adapt the transform for the LABEL types
    if self._mask_type in (MaskType.TIFF_LABELS, MaskType.NUMPY_LABELS):
        MaskTransform = ToOneHot(num_classes=self._out_channels)

    # The H5_ONE_HOT type requires a different loader
    if self._mask_type == MaskType.H5_ONE_HOT:
        # MaskLoader: still missing
        raise Exception("HDF5 one-hot masks are not supported yet!")

    # Define transforms for training
    self._train_image_transforms = Compose(
        [
            LoadImage(image_only=True),
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop(self._roi_size, random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            ToTensor()
        ]
    )
    self._train_mask_transforms = Compose(
        [
            MaskLoader,
            MaskTransform,
            RandSpatialCrop(self._roi_size, random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            ToTensor()
        ]
    )

    # Define transforms for validation
    self._validation_image_transforms = Compose(
        [
            LoadImage(image_only=True),
            ScaleIntensity(),
            AddChannel(),
            ToTensor()
        ]
    )
    self._validation_mask_transforms = Compose(
        [
            MaskLoader,
            MaskTransform,
            ToTensor()
        ]
    )

    # Define transforms for testing
    self._test_image_transforms = Compose(
        [
            LoadImage(image_only=True),
            ScaleIntensity(),
            AddChannel(),
            ToTensor()
        ]
    )
    self._test_mask_transforms = Compose(
        [
            MaskLoader,
            MaskTransform,
            ToTensor()
        ]
    )

    # Post transforms
    self._validation_post_transforms = Compose(
        [
            Activations(softmax=True),
            AsDiscrete(threshold_values=True)
        ]
    )

    self._test_post_transforms = Compose(
        [
            Activations(softmax=True),
            AsDiscrete(threshold_values=True)
        ]
    )

    return True
def test_value(self, input_param, input_data):
    """The crop output must match the expected window of ``input_data``."""
    cropper = RandSpatialCrop(**input_param)
    result = cropper(input_data)
    # window per spatial dim: starts at 2 - size // 2 and spans the drawn crop size
    windows = [(2 - size // 2, 2 + size - size // 2) for size in cropper._size]
    (h_start, h_stop), (w_start, w_stop) = windows[0], windows[1]
    np.testing.assert_allclose(result, input_data[:, h_start:h_stop, w_start:w_stop])
def main():
    """Train a 3D UNet on synthetic NIfTI volumes with a plain PyTorch loop.

    Forty synthetic image/segmentation volumes are written to a new temporary
    directory. The first 20 are used for training with random 96^3 crops and random
    rotations, and the last 20 are evaluated every ``val_interval`` epochs with
    sliding-window inference. The best model by mean Dice is saved to
    ``best_metric_model.pth``. The temporary directory is removed whether training
    finishes or raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

        # define transforms for image and segmentation
        train_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 2)),
            ToTensor()
        ])
        train_segtrans = Compose([
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 2)),
            ToTensor()
        ])
        val_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            ToTensor()
        ])
        val_segtrans = Compose([
            AddChannel(),
            ToTensor()
        ])

        # define nifti dataset, data loader
        check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader
        train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
        train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8,
                                  pin_memory=torch.cuda.is_available())
        # create a validation data loader
        val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        # fall back to the CPU when CUDA is unavailable, matching the pin_memory checks above
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(do_sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        max_epochs = 5
        val_interval = 2
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        # Batches per epoch. len() of the loader counts a final partial batch, so every
        # TensorBoard point below gets a distinct global step.
        epoch_len = len(train_loader)
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        writer = SummaryWriter()
        for epoch in range(max_epochs):
            print('-' * 10)
            print('epoch {}/{}'.format(epoch + 1, max_epochs))
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item()))
                writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                                 to_onehot_y=False, add_sigmoid=True)
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), 'best_metric_model.pth')
                        print('saved new best metric model')
                    print('current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format(
                        epoch + 1, metric, best_metric, best_metric_epoch))
                    writer.add_scalar('val_mean_dice', metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image')
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label')
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output')
        print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
        writer.close()
    finally:
        # remove the synthetic data even if generation or training raised
        shutil.rmtree(tempdir)
import monai
from monai.data import ArrayDataset, create_test_image_2d
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, LoadImage, RandRotate90, RandSpatialCrop, ScaleIntensity, ToTensor, LoadNumpy, LoadNifti
from monai.visualize import plot_2d_or_3d_image
from torch.utils.data import Dataset, DataLoader
from tutils import *

tconfig.set_print_info(True)

# Image pipeline: convert to a tensor, add a leading channel axis, then take a
# fixed-size random 96x96 crop.
_image_steps = [
    ToTensor(),
    AddChannel(),
    RandSpatialCrop((96, 96), random_size=False),
]
train_imtrans = Compose(_image_steps)
# Other steps tried in this pipeline while experimenting (currently disabled):
#   RandRotate90(prob=0.5, spatial_axes=(0, 1)), ScaleIntensity(), LoadNifti(),
#   and an AddChannel()/ToTensor() ordering variant.

# Segmentation pipeline: load a NIfTI file, add a channel axis, randomly rotate by
# 90 degrees in the first two spatial axes, then convert to a tensor.
_seg_steps = [
    LoadNifti(),
    AddChannel(),
    RandRotate90(prob=0.5, spatial_axes=(0, 1)),
    ToTensor(),
]
train_segtrans = Compose(_seg_steps)

# For testing
# Each case is (sample container type, random transform).
# dict is paired with the dictionary-based ``*d`` transforms and list with the array
# transforms. Every transform is configured so its output spatial size can vary:
# random_size=True, keep_size=False, or random 90-degree rotations.
TESTS.extend([
    (dict, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)),
    (dict, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)),
    (dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)),
    (dict, RandRotate90d("image", prob=1, max_k=2)),
    (list, RandSpatialCrop(roi_size=[8, 7], random_size=True)),
    (list, RandRotate(prob=1, range_x=np.pi, keep_size=False)),
    (list, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)),
    (list, RandRotate90(prob=1, max_k=2)),
])


class _Dataset(torch.utils.data.Dataset):
    """Dataset that stores images, labels and a transform object; its length is the number of images."""

    def __init__(self, images, labels, transforms):
        self.images, self.labels, self.transforms = images, labels, transforms

    def __len__(self):
        return len(self.images)
def _define_transforms(self):
    """Define and initialize all data transforms.

    Defines:
      * training set images transform
      * training set targets transform
      * validation set images transform
      * validation set targets transform
      * validation set images post-transform
      * test set images transform
      * test set targets transform
      * test set images post-transform
      * prediction set images transform
      * prediction set images post-transform

    Processing per set:
      * Training, validation and test images and targets are loaded, then mapped
        from [0, 65535] to [0.0, 1.0] with ``ScaleIntensityRange`` (``clip=False``),
        then given a channel axis.
      * Training pairs additionally get a random ``self._roi_size`` crop and a random
        90-degree rotation.
      * Test and prediction outputs are converted to NumPy and rescaled with
        ``ScaleIntensity(0, 65535)``.

    @return True once all transforms have been defined.
    """

    def _load_and_normalize():
        # Build fresh instances on every call so each Compose owns its own transforms.
        return [
            LoadImage(image_only=True),
            ScaleIntensityRange(0, 65535, 0.0, 1.0, clip=False),
            AddChannel(),
        ]

    def _random_augmentation():
        # Random crop + rotation shared by the training image and target pipelines.
        return [
            RandSpatialCrop(self._roi_size, random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ]

    # Define transforms for training
    self._train_image_transforms = Compose(
        _load_and_normalize() + _random_augmentation() + [ToTensor()])
    self._train_target_transforms = Compose(
        _load_and_normalize() + _random_augmentation() + [ToTensor()])

    # Define transforms for validation
    self._validation_image_transforms = Compose(_load_and_normalize() + [ToTensor()])
    self._validation_target_transforms = Compose(_load_and_normalize() + [ToTensor()])

    # Define transforms for testing
    self._test_image_transforms = Compose(_load_and_normalize() + [ToTensor()])
    self._test_target_transforms = Compose(_load_and_normalize() + [ToTensor()])

    # Define transforms for prediction
    # NOTE(review): unlike the training/validation/test pipelines, prediction images
    # are not passed through ScaleIntensityRange(0, 65535, 0.0, 1.0). Confirm that the
    # prediction inputs are already in the range the network was trained on.
    self._prediction_image_transforms = Compose(
        [LoadImage(image_only=True), AddChannel(), ToTensor()])

    # Post transforms
    self._validation_post_transforms = Compose([Identity()])
    self._test_post_transforms = Compose(
        [ToNumpy(), ScaleIntensity(0, 65535)])
    self._prediction_post_transforms = Compose(
        [ToNumpy(), ScaleIntensity(0, 65535)])

    return True
def main(tempdir):
    """Train a 3D UNet on synthetic NIfTI volumes with PyTorch Ignite.

    Forty synthetic volume/segmentation pairs are written into ``tempdir``. The caller
    supplies and owns that directory. The first 20 pairs train the network and the
    last 20 are evaluated after every epoch. Checkpoints go to ``./runs_array/``, and an
    EarlyStopping handler with ``patience=4`` watches the validation mean Dice. The
    final Ignite engine state is printed when training ends.

    Args:
        tempdir: existing directory in which the synthetic ``.nii.gz`` files are written.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # --- synthetic data: 40 volume/label pairs written into the given directory ---
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for idx in range(40):
        volume, label = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        nib.save(nib.Nifti1Image(volume, np.eye(4)), os.path.join(tempdir, f"im{idx:d}.nii.gz"))
        nib.save(nib.Nifti1Image(label, np.eye(4)), os.path.join(tempdir, f"seg{idx:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # --- transforms: random crops for training, resizing for validation ---
    patch = (96, 96, 96)
    train_imtrans = Compose([ScaleIntensity(), AddChannel(), RandSpatialCrop(patch, random_size=False), EnsureType()])
    train_segtrans = Compose([AddChannel(), RandSpatialCrop(patch, random_size=False), EnsureType()])
    val_imtrans = Compose([ScaleIntensity(), AddChannel(), Resize(patch), EnsureType()])
    val_segtrans = Compose([AddChannel(), Resize(patch), EnsureType()])

    # --- datasets and loaders ---
    pin = torch.cuda.is_available()
    check_ds = ImageDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=pin)
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    train_ds = ImageDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=pin)
    val_ds = ImageDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=pin)

    # --- model, loss, optimizer ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_fn = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    # --- trainer: consumes batch=(img, seg) and emits the loss every iteration ---
    trainer = create_supervised_trainer(net, opt, loss_fn, device, False)
    # keep the 10 most recent per-epoch checkpoints of network and optimizer state
    checkpoint_handler = ModelCheckpoint("./runs_array/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save=dict(net=net, opt=opt),
    )
    # print and plot the raw loss output of each iteration
    StatsHandler(name="trainer", output_transform=lambda x: x).attach(trainer)
    TensorBoardStatsHandler(output_transform=lambda x: x).attach(trainer)

    # --- evaluator: decollated, post-processed predictions and labels feed MeanDice ---
    validation_every_n_epochs = 1
    metric_name = "Mean_Dice"
    val_metrics = {metric_name: MeanDice()}
    post_pred = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])

    def _to_metric_inputs(x, y, y_pred):
        predictions = [post_pred(item) for item in decollate_batch(y_pred)]
        targets = [post_label(item) for item in decollate_batch(y)]
        return predictions, targets

    evaluator = create_supervised_evaluator(net, val_metrics, device, True, output_transform=_to_metric_inputs)

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    def _trainer_epoch(_):
        # report validation results against the trainer's epoch counter
        return trainer.state.epoch

    # validation handlers: per-iteration output is suppressed, only epoch metrics are reported
    StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,
        global_epoch_transform=_trainer_epoch,
    ).attach(evaluator)
    TensorBoardStatsHandler(
        output_transform=lambda x: None,
        global_epoch_transform=_trainer_epoch,
    ).attach(evaluator)

    # draw the first image, label and model output of the last validation batch (3D shown as GIF along depth)
    val_tensorboard_image_handler = TensorBoardImageHandler(
        batch_transform=lambda batch: (batch[0], batch[1]),
        output_transform=lambda output: output[0],
        global_iter_transform=_trainer_epoch,
    )
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

    final_state = trainer.run(train_loader, 30)
    print(final_state)