def main():
    """Train and validate a 3D DenseNet121 classifier on 20 IXI-T1 volumes.

    The first 10 listed images are used for training and the last 10 for
    validation. The model is trained with ``CrossEntropyLoss`` and Adam for a
    fixed number of epochs. Every ``val_interval`` epochs the accuracy and
    ROC AUC are computed on the validation loader, and the weights with the
    best accuracy so far are saved to
    ``best_metric_model_classification3d_dict.pth``. The per-step training
    loss and the validation accuracy are logged to a TensorBoard
    ``SummaryWriter``, which is closed even if training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # the path of ixi IXI-T1 dataset
    data_path = os.sep.join([".", "workspace", "data", "medical", "ixi", "IXI-T1"])
    images = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join([data_path, f]) for f in images]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    train_files = [{"img": img, "label": label} for img, label in zip(images[:10], labels[:10])]
    val_files = [{"img": img, "label": label} for img, label in zip(images[-10:], labels[-10:])]

    # Define transforms for image
    train_transforms = Compose(
        [
            LoadImaged(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
            EnsureTyped(keys=["img"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            EnsureTyped(keys=["img"]),
        ]
    )
    post_pred = Compose([EnsureType(), Activations(softmax=True)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

    # Define dataset, data loader; load one batch first as a sanity check
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    auc_metric = ROCAUCMetric()

    # start a typical PyTorch training
    max_epochs = 5
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    # number of steps per epoch; constant for the whole run, so compute it once
    epoch_len = len(train_ds) // train_loader.batch_size
    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["label"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    # accumulate raw outputs and labels over the whole validation set
                    y_pred = torch.tensor([], dtype=torch.float32, device=device)
                    y = torch.tensor([], dtype=torch.long, device=device)
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["label"].to(device)
                        y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                        y = torch.cat([y, val_labels], dim=0)

                    acc_value = torch.eq(y_pred.argmax(dim=1), y)
                    acc_metric = acc_value.sum().item() / len(acc_value)
                    y_onehot = [post_label(i) for i in decollate_batch(y)]
                    y_pred_act = [post_pred(i) for i in decollate_batch(y_pred)]
                    auc_metric(y_pred_act, y_onehot)
                    auc_result = auc_metric.aggregate()
                    auc_metric.reset()
                    del y_pred_act, y_onehot
                    # model selection is driven by accuracy; AUC is only reported
                    if acc_metric > best_metric:
                        best_metric = acc_metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model_classification3d_dict.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}".format(
                            epoch + 1, acc_metric, auc_result, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # flush and release the TensorBoard event file even when training fails
        writer.close()
def _build_segment_transforms(keys, resolution, patch_size):
    """Build the preprocessing pipeline used by ``segment``.

    Args:
        keys: dictionary keys to process, ``['image']`` or ``['image', 'label']``.
        resolution: voxel spacing passed to ``Spacingd``; when ``None`` the
            resampling step is left out.
        patch_size: spatial size passed to ``SpatialPadd``.

    Returns:
        A ``Compose`` of the dictionary transforms.
    """
    steps = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        # Optional CT windowing, currently disabled:
        # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
        # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        CropForegroundd(keys=keys, source_key='image'),  # crop to the image foreground
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
    ]
    if resolution is not None:
        # labels are resampled with nearest-neighbour, images with bilinear
        mode = ('bilinear', 'nearest') if 'label' in keys else 'bilinear'
        steps.append(Spacingd(keys=keys, pixdim=resolution, mode=mode))
    # pad at the end of each axis if the image is smaller than the patch
    steps.append(SpatialPadd(keys=keys, spatial_size=patch_size, method='end'))
    steps.append(ToTensord(keys=keys))
    return Compose(steps)


def segment(image, label, result, weights, resolution, patch_size, network, gpu_ids):
    """Segment a single image with sliding-window inference and save the result.

    Args:
        image: path of the image to segment.
        label: optional path of a reference label; when given, the Dice
            metric against it is printed.
        result: output file name for the segmentation.
        weights: weights handed to ``new_state_dict`` / ``new_state_dict_cpu``.
        resolution: spacing to resample to before inference, or ``None`` to
            skip resampling.
        patch_size: sliding-window ROI size (also the minimum padded size).
        network: ``'nnunet'`` or ``'unetr'``.
        gpu_ids: value for ``CUDA_VISIBLE_DEVICES``; ``'-1'`` forces the CPU.

    Raises:
        ValueError: if ``network`` is not one of the supported names.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if label is not None:
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
        keys = ['image', 'label']
    else:
        files = [{"image": image}]
        keys = ['image']

    # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(
        image, resolution)

    val_transforms = _build_segment_transforms(keys, resolution, patch_size)

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate, pin_memory=False)

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True),
    ])

    if gpu_ids != '-1':
        # try to use all the available GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    # build the network
    if network == 'nnunet':
        net = build_net()
    elif network == 'unetr':
        net = build_UNETR()
    else:
        # previously an unknown name fell through and failed later on an unbound ``net``
        raise ValueError(
            "Unsupported network {!r}; expected 'nnunet' or 'unetr'.".format(network))

    net = net.to(device)

    if gpu_ids == '-1':
        net.load_state_dict(new_state_dict_cpu(weights))
    else:
        net.load_state_dict(new_state_dict(weights))

    # define sliding window size and batch size for windows inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images = val_data["image"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            if label is not None:
                val_labels = val_data["label"].to(device)
                dice_metric(y_pred=val_outputs, y=val_labels)

        if label is not None:
            metric = dice_metric.aggregate().item()
            print("Evaluation Metric (Dice):", metric)

        result_array = val_outputs[0].squeeze().data.cpu().numpy()
        # Remove the pad if the image was smaller than the patch in some directions
        result_array = result_array[0:resampled_size[0], 0:resampled_size[1], 0:resampled_size[2]]

        # resample back to the original resolution
        if resolution is not None:
            result_array_np = np.transpose(result_array, (2, 1, 0))
            result_array_temp = sitk.GetImageFromArray(result_array_np)
            result_array_temp.SetSpacing(resolution)

            # The prediction goes through a temporary file so it can be
            # resampled by the same dictionary transforms used for loading.
            temp_path = 'temp_seg.nii'
            try:
                writer = sitk.ImageFileWriter()
                writer.SetFileName(temp_path)
                writer.Execute(result_array_temp)

                files = [{"image": temp_path}]
                files_transforms = Compose([
                    LoadImaged(keys=['image']),
                    AddChanneld(keys=['image']),
                    Spacingd(keys=['image'], pixdim=original_resolution, mode=('nearest')),
                    Resized(keys=['image'], spatial_size=crop_shape, mode=('nearest')),
                ])
                files_ds = Dataset(data=files, transform=files_transforms)
                files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

                for files_data in files_loader:
                    files_images = files_data["image"]
                    res = files_images.squeeze().data.numpy()

                result_array = np.rint(res)
            finally:
                # never leave the temporary file behind, even if resampling fails
                if os.path.exists(temp_path):
                    os.remove(temp_path)

        # recover the cropped background before saving the image
        empty_array = np.zeros(original_shape)
        empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1], coord1[2]:coord2[2]] = result_array

        result_seg = from_numpy_to_itk(empty_array, image)

        # save label
        writer = sitk.ImageFileWriter()
        writer.SetFileName(result)
        writer.Execute(result_seg)
        print("Saved Result at:", str(result))
def _iteration(
    self, engine: Engine, batchdata: Dict[str, Any]
) -> Dict[str, torch.Tensor]:
    """
    Callback function for the Supervised Evaluation processing logic of 1
    iteration in Ignite Engine.

    Runs the inferer (optionally averaged over flipped copies of the input
    when ``self.tta_val`` is set), converts the class probabilities to a label
    map, pastes it back into an array of the original image shape and writes
    it as a NIfTI file into ``self.output_dir``.

    Returns a dictionary with a single item:
        - ``"pred"``: the label map as a NumPy array of the original shape.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
        batchdata: input data for this iteration, usually can be dictionary
            or tuple of Tensor data. The keys ``image_meta_dict``,
            ``resample_flag``, ``anisotrophy_flag``, ``crop_shape``,
            ``original_shape`` and ``bbox`` are read from it.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) == 2:
        inputs, _ = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, _, args, kwargs = batch

    def _compute_pred():
        pred = self.inferer(inputs, self.network, *args, **kwargs).cpu()
        pred = nn.functional.softmax(pred, dim=1)
        if not self.tta_val:
            return pred

        # test-time augmentation: average over every combination of flips
        # along dims 2, 3 and 4
        ct = 1.0
        for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
            flip_inputs = torch.flip(inputs, dims=dims)
            # forward the same extra network arguments as the un-flipped pass
            flip_pred = torch.flip(
                self.inferer(flip_inputs, self.network, *args, **kwargs).cpu(),
                dims=dims,
            )
            flip_pred = nn.functional.softmax(flip_pred, dim=1)
            del flip_inputs
            pred += flip_pred
            del flip_pred
            ct += 1
        return pred / ct

    # execute forward computation
    with eval_mode(self.network):
        if self.amp:
            with torch.cuda.amp.autocast():
                predictions = _compute_pred()
        else:
            predictions = _compute_pred()

    inputs = inputs.cpu()
    # only the first item of the batch is post-processed and saved
    predictions = self.post_pred(decollate_batch(predictions)[0])

    affine = batchdata["image_meta_dict"]["affine"].numpy()[0]
    resample_flag = batchdata["resample_flag"]
    anisotrophy_flag = batchdata["anisotrophy_flag"]
    crop_shape = batchdata["crop_shape"][0].tolist()
    original_shape = batchdata["original_shape"][0].tolist()

    if resample_flag:
        # convert the prediction back to the original (after cropped) shape
        predictions = recovery_prediction(
            predictions.numpy(), [self.num_classes, *crop_shape], anisotrophy_flag
        )
    else:
        predictions = predictions.numpy()

    predictions = np.argmax(predictions, axis=0)

    # pad the prediction back to the original shape
    predictions_org = np.zeros([*original_shape])
    box_start, box_end = batchdata["bbox"][0]
    h_start, w_start, d_start = box_start
    h_end, w_end, d_end = box_end
    predictions_org[h_start:h_end, w_start:w_end, d_start:d_end] = predictions
    del predictions

    # basename handles the platform's path separator instead of assuming "/"
    filename = os.path.basename(batchdata["image_meta_dict"]["filename_or_obj"][0])

    print(
        "save {} with shape: {}, mean values: {}".format(
            filename, predictions_org.shape, predictions_org.mean()
        )
    )
    write_nifti(
        data=predictions_org,
        file_name=os.path.join(self.output_dir, filename),
        affine=affine,
        resample=False,
        output_dtype=np.uint8,
    )
    engine.fire_event(IterationEvents.FORWARD_COMPLETED)
    return {"pred": predictions_org}
# Validation pass: run sliding-window inference on every batch and fold the
# Dice values into the running ``metric_sum`` / ``metric_count`` accumulators.
dice_metric_val = np.zeros(number_class)
for val_data in val_loader:
    val_inputs = val_data["image"].to(device)
    val_labels = val_data["label"].to(device)
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
    # post-process predictions and labels item by item
    val_outputs = list(map(post_pred, decollate_batch(val_outputs)))
    val_labels = list(map(post_label, decollate_batch(val_labels)))
    value = dice_metric(y_pred=val_outputs, y=val_labels)
    # only the first entry of ``value`` is accumulated
    # NOTE(review): presumably one row per batch item - confirm batch size is 1
    first_item_scores = value[0]
    metric_count += len(first_item_scores)
    metric_sum += first_item_scores.sum().item()
def main(tempdir):
    """Train and validate a 2D UNet on synthetic image/mask pairs.

    Forty random PNG image/segmentation pairs are written to ``tempdir``.
    The first 20 are used for training and the last 20 for validation. The
    model is trained with ``DiceLoss`` and Adam; every ``val_interval`` epochs
    the mean Dice is computed with sliding-window inference and the best
    weights are saved to ``best_metric_model_segmentation2d_array.pth``.
    Losses, metrics and the last validation image/label/output are logged to
    a TensorBoard ``SummaryWriter``, which is closed even if training raises.

    Args:
        tempdir: existing directory in which the synthetic data is generated.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # generate 40 random image, mask pairs in the given directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    train_imtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        EnsureType(),
    ])
    train_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        EnsureType(),
    ])
    val_imtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        EnsureType(),
    ])
    val_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        EnsureType(),
    ])

    # define array dataset, data loader; load one batch first as a sanity check
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds,
                              batch_size=10,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20],
                            train_segtrans)
    train_loader = DataLoader(train_ds,
                              batch_size=4,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    max_epochs = 10
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    # values that do not change between iterations are computed once
    epoch_len = len(train_ds) // train_loader.batch_size
    roi_size = (96, 96)
    sw_batch_size = 4
    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(
                    device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data[0].to(
                            device), val_data[1].to(device)
                        val_outputs = sliding_window_inference(
                            val_images, roi_size, sw_batch_size, model)
                        val_outputs = [
                            post_trans(i) for i in decollate_batch(val_outputs)
                        ]
                        # compute metric for current iteration
                        dice_metric(y_pred=val_outputs, y=val_labels)
                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(),
                                   "best_metric_model_segmentation2d_array.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, metric, best_metric,
                                best_metric_epoch))
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="image")
                    plot_2d_or_3d_image(val_labels,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="label")
                    plot_2d_or_3d_image(val_outputs,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="output")

        print(
            f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
    finally:
        # flush and release the TensorBoard event file even when training fails
        writer.close()
def test_invert(self):
    """Check that ``Invertd`` undoes a chain of spatial transforms.

    A synthetic 3D image/label pair is pushed through a pipeline of
    deterministic and random transforms, then inverted twice: once with
    nearest interpolation for every key, and once with nearest interpolation
    for the first key only. The test checks the resulting keys, the restored
    shapes, the interpolation behaviour and how closely the inverted label
    matches the file on disk.

    The global determinism seed is always restored, even when an assertion
    fails, so a failure here cannot leak a fixed seed into other tests.
    """
    set_determinism(seed=0)
    try:
        im_fname, seg_fname = (
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100))
        transform = Compose([
            LoadImaged(KEYS, image_only=True),
            EnsureChannelFirstd(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, prob=0, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True,
                        dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label",
                       times=2,
                       names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image",
                       times=2,
                       names=["image_inverted", "image_inverted1"]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available(
        ) else 2

        dataset = Dataset(data, transform=transform)
        transform.inverse(dataset[0])
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
        inverter = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted", "label_inverted"],
            transform=transform,
            orig_keys=["label", "label"],
            nearest_interp=True,
            device="cpu",
        )

        inverter_1 = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted1", "label_inverted1"],
            transform=transform,
            orig_keys=["image", "image"],
            nearest_interp=[True, False],
            device="cpu",
        )

        expected_keys = [
            "image", "image_inverted", "image_inverted1", "label",
            "label_inverted", "label_inverted1"
        ]
        # execute 1 epoch
        for batch in loader:
            for item in decollate_batch(batch):
                item = inverter(item)
                item = inverter_1(item)

                self.assertListEqual(sorted(item), expected_keys)
                self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
                self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
                # check the nearest interpolation mode
                inverted = item["image_inverted"]
                torch.testing.assert_allclose(
                    inverted.to(torch.uint8).to(torch.float),
                    inverted.to(torch.float))
                self.assertTupleEqual(inverted.shape[1:], (100, 101, 107))
                inverted = item["label_inverted"]
                torch.testing.assert_allclose(
                    inverted.to(torch.uint8).to(torch.float),
                    inverted.to(torch.float))
                self.assertTupleEqual(inverted.shape[1:], (100, 101, 107))

                # check the case that different items use different interpolation mode to invert transforms
                inverted = item["image_inverted1"]
                # if the interpolation mode is nearest, accumulated diff should be smaller than 1
                self.assertLess(
                    torch.sum(
                        inverted.to(torch.float) -
                        inverted.to(torch.uint8).to(torch.float)).item(), 1.0)
                self.assertTupleEqual(inverted.shape, (1, 100, 101, 107))

                inverted = item["label_inverted1"]
                # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
                self.assertGreater(
                    torch.sum(
                        inverted.to(torch.float) -
                        inverted.to(torch.uint8).to(torch.float)).item(),
                    10000.0)
                self.assertTupleEqual(inverted.shape, (1, 100, 101, 107))

        # check labels match (uses the last item produced by the loop above)
        reverted = item["label_inverted"].detach().cpu().numpy().astype(
            np.int32)
        original = LoadImaged(KEYS, image_only=True)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = item["label_inverted"].meta["filename_or_obj"]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        n_diff = reverted.size - n_good
        print("invert diff", n_diff)
        # 25300: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1821: windows torch 1.10.0
        self.assertTrue(n_diff < 40000, f"diff. {n_diff}")
    finally:
        # restore non-deterministic behaviour regardless of the test outcome
        set_determinism(seed=None)
def main():
    """Train a 3D DenseNet121 on an IXI-T1 subset with Ignite engines.

    Builds a supervised trainer and evaluator, attaches checkpointing,
    console/TensorBoard stats handlers and early stopping on the validation
    AUC, runs validation after every epoch and prints the final trainer state.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # the path of ixi IXI-T1 dataset
    data_path = os.sep.join(
        [".", "workspace", "data", "medical", "ixi", "IXI-T1"])
    file_names = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    image_paths = [os.sep.join([data_path, name]) for name in file_names]

    # 2 binary labels for gender classification: man and woman
    targets = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        dtype=np.int64)

    def _as_records(paths, classes):
        # pair every image path with its class label
        return [{"img": path, "label": cls} for path, cls in zip(paths, classes)]

    # first ten samples train the model, last ten validate it
    train_files = _as_records(image_paths[:10], targets[:10])
    val_files = _as_records(image_paths[-10:], targets[-10:])

    # define transforms for image
    train_transforms = Compose([
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        EnsureTyped(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        EnsureTyped(keys=["img"]),
    ])

    use_pinned_memory = torch.cuda.is_available()

    # sanity check: load a single batch and show its shape and labels
    check_ds = monai.data.Dataset(data=train_files,
                                  transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              pin_memory=use_pinned_memory)
    first_batch = monai.utils.misc.first(check_loader)
    print(first_batch["img"].shape, first_batch["label"])

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = monai.networks.nets.DenseNet121(spatial_dims=3,
                                              in_channels=1,
                                              out_channels=2).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    learning_rate = 1e-5
    optimizer = torch.optim.Adam(network.parameters(), learning_rate)

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch["img"], batch["label"]), device,
                              non_blocking)

    trainer = create_supervised_trainer(network,
                                        optimizer,
                                        criterion,
                                        device,
                                        False,
                                        prepare_batch=prepare_batch)

    # checkpoint handler saving network params and optimizer stats after each epoch
    checkpoint_handler = ModelCheckpoint("./runs_dict/",
                                         "net",
                                         n_saved=10,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "net": network,
                                  "opt": optimizer
                              })

    def _passthrough(output):
        # the trainer output is already the loss value
        return output

    def _suppress(output):
        # no per-iteration value is reported for the evaluator
        return None

    def _trainer_epoch(_):
        # report evaluator stats against the trainer's epoch counter
        return trainer.state.epoch

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # no metrics are set on the trainer, so only the loss is printed
    train_stats_handler = StatsHandler(name="trainer",
                                       output_transform=_passthrough)
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots the same values as StatsHandler prints
    train_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=_passthrough)
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "AUC"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: ROCAUC()}

    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
    post_pred = Compose([EnsureType(), Activations(softmax=True)])

    def _to_metric_inputs(x, y, y_pred):
        # decollate the batch and post-process predictions and labels per item
        return ([post_pred(i) for i in decollate_batch(y_pred)],
                [post_label(i) for i in decollate_batch(y)])

    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(
        net if False else network,
        val_metrics,
        device,
        True,
        prepare_batch=prepare_batch,
        output_transform=_to_metric_inputs) if False else create_supervised_evaluator(
        network,
        val_metrics,
        device,
        True,
        prepare_batch=prepare_batch,
        output_transform=_to_metric_inputs)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=_suppress,
        global_epoch_transform=_trainer_epoch,
    )
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=_suppress,
        global_epoch_transform=_trainer_epoch,
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                handler=early_stopper)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=use_pinned_memory)

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files,
                                  transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=use_pinned_memory)

    train_epochs = 30
    final_state = trainer.run(train_loader, train_epochs)
    print(final_state)
def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", num_workers=10):
    """Train a 2D DenseNet121 classifier for a few epochs and track its AUC.

    Args:
        root_dir: directory in which ``best_metric_model.pth`` is saved.
        train_x: training inputs handed to ``MedNISTDataset``.
        train_y: training labels; the number of unique values sets the class count.
        val_x: validation inputs handed to ``MedNISTDataset``.
        val_y: validation labels.
        device: device the model and batches are moved to.
        num_workers: worker count for both data loaders.

    Returns:
        A tuple ``(epoch_loss_values, best_metric, best_metric_epoch)``.
    """
    monai.config.print_config()

    # define transforms for image and classification
    augmentation = [
        LoadImage(image_only=True),
        AddChannel(),
        Transpose(indices=[0, 2, 1]),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True, dtype=np.float64),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        ToTensor(),
    ]
    train_transforms = Compose(augmentation)
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), ToTensor()]
    )

    n_classes = len(np.unique(train_y))
    y_pred_trans = Compose([ToTensor(), Activations(softmax=True)])
    y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=n_classes)])
    auc_metric = ROCAUCMetric()

    # create train, val data loaders
    train_loader = DataLoader(
        MedNISTDataset(train_x, train_y, train_transforms),
        batch_size=300,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        MedNISTDataset(val_x, val_y, val_transforms),
        batch_size=300,
        num_workers=num_workers,
    )

    model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=n_classes).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    epoch_num = 4
    val_interval = 1

    # start training validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(epoch_num):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for step, batch in enumerate(train_loader, start=1):
            inputs = batch[0].to(device)
            labels = batch[1].to(device)
            optimizer.zero_grad()
            loss = loss_function(model(inputs), labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        with eval_mode(model):
            # seed both lists with empty tensors so dtype/device are fixed
            # even before the first validation batch is appended
            pred_chunks = [torch.tensor([], dtype=torch.float32, device=device)]
            label_chunks = [torch.tensor([], dtype=torch.long, device=device)]
            for val_batch in val_loader:
                val_images, val_labels = val_batch[0].to(device), val_batch[1].to(device)
                pred_chunks.append(model(val_images))
                label_chunks.append(val_labels)
            y_pred = torch.cat(pred_chunks, dim=0)
            y = torch.cat(label_chunks, dim=0)

            # compute accuracy
            hits = torch.eq(y_pred.argmax(dim=1), y)
            acc_metric = hits.sum().item() / len(hits)
            # decollate prediction and label and execute post processing
            y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)]
            y = [y_trans(i) for i in decollate_batch(y)]
            # compute AUC
            auc_metric(y_pred, y)
            auc_value = auc_metric.aggregate()
            auc_metric.reset()
            metric_values.append(auc_value)
            # keep the weights of the epoch with the highest AUC so far
            if auc_value > best_metric:
                best_metric = auc_value
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), model_filename)
                print("saved new best metric model")
            print(
                f"current epoch {epoch +1} current AUC: {auc_value:0.4f} "
                f"current accuracy: {acc_metric:0.4f} best AUC: {best_metric:0.4f} at epoch {best_metric_epoch}"
            )
    print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    return epoch_loss_values, best_metric, best_metric_epoch
def run_inference_test(root_dir, device="cuda:0"):
    """Run sliding-window inference with a saved UNet and return the aggregated Dice.

    Args:
        root_dir: directory containing the ``im*.nii.gz`` / ``seg*.nii.gz`` files and
            ``best_metric_model.pth``; predictions are saved under ``<root_dir>/output``.
        device: device the model and the input batches are moved to.

    Returns:
        The value of ``dice_metric.aggregate().item()`` over all validation items.
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # map_location remaps the stored tensors onto the requested device, so a checkpoint
    # written from a different device (e.g. saved on GPU, evaluated on CPU) still loads
    model.load_state_dict(torch.load(model_filename, map_location=device))

    # define sliding window size and batch size for windows inference
    sw_batch_size, roi_size = 4, (96, 96, 96)

    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=np.float32)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list and execute post processing for every item
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)
            saver.save_batch(val_outputs, val_data["img_meta_dict"])

    return dice_metric.aggregate().item()
def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None, None)):
    """Train a 3D UNet on the image/segmentation pairs in ``root_dir``.

    Args:
        root_dir: directory with ``img*.nii.gz`` / ``seg*.nii.gz`` files; TensorBoard logs go to
            ``<root_dir>/runs`` and the best checkpoint to ``<root_dir>/best_metric_model.pth``.
        device: device the model and the batches are moved to.
        cachedataset: ``2`` selects ``CacheDataset`` (``cache_rate=0.8``), ``3`` selects
            ``LMDBDataset`` (``cache_dir=root_dir``), any other value a plain ``Dataset``.
        readers: pair of reader arguments for ``LoadImaged``; index 0 is used by the training
            transforms and index 1 by the validation transforms.

    Returns:
        Tuple ``(epoch_loss_values, best_metric, best_metric_epoch)``.
    """
    monai.config.print_config()

    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    # first 20 pairs are used for training, last 20 pairs for validation
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[0]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"], reader=readers[1]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")

    # iterations per epoch, used for the progress print and the global TensorBoard step
    epoch_len = len(train_ds) // train_loader.batch_size
    # sliding window size and window batch size for validation
    sw_batch_size, roi_size = 4, (96, 96, 96)

    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")

        model.train()
        batch_losses = []
        for step, batch_data in enumerate(train_loader, start=1):
            inputs = batch_data["img"].to(device)
            labels = batch_data["seg"].to(device)
            optimizer.zero_grad()
            loss = loss_function(model(inputs), labels)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

        epoch_loss = sum(batch_losses) / len(batch_losses)
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        with eval_mode(model):
            # keep the last batch around so it can be plotted after the loop
            val_images, val_labels, val_outputs = None, None, None
            for val_data in val_loader:
                val_images = val_data["img"].to(device)
                val_labels = val_data["seg"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                # decollate prediction into a list and execute post processing for every item
                val_outputs = [val_post_tran(item) for item in decollate_batch(val_outputs)]
                # compute metrics
                dice_metric(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), model_filename)
                print("saved new best metric model")
            print(
                f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
# NOTE(review): headerless fragment. The enclosing function / loop is not visible here, so
# `val_loader_iterator`, `model`, `device`, `post_pred`, `post_label`, `dice_metric`,
# `metric_values`, `best_metric` and `epoch` are presumed to come from the surrounding scope,
# and the original nesting level of each statement could not be recovered. TODO confirm.

# fetch the next validation batch and move image and label to `device`
val_data = next(val_loader_iterator)
val_inputs, val_labels = (
    val_data["image"].to(device),
    val_data["label"].to(device),
)
roi_size = (160, 160, 160)
sw_batch_size = 4
# each stage below is wrapped in its own named, colour-coded nvtx.annotate range
with nvtx.annotate("sliding window", color="green"):
    val_outputs = sliding_window_inference(
        val_inputs, roi_size, sw_batch_size, model)
with nvtx.annotate("decollate batch", color="blue"):
    # split the batch into items and post-process predictions and labels one by one
    val_outputs = [
        post_pred(i) for i in decollate_batch(val_outputs)
    ]
    val_labels = [
        post_label(i) for i in decollate_batch(val_labels)
    ]
with nvtx.annotate("compute metric", color="yellow"):
    # compute metric for current iteration
    dice_metric(y_pred=val_outputs, y=val_labels)
# aggregate the accumulated Dice, then reset the metric state before it is used again
metric = dice_metric.aggregate().item()
dice_metric.reset()
metric_values.append(metric)
# remember the best metric seen so far and the (1-based) epoch it occurred in
if metric > best_metric:
    best_metric = metric
    best_metric_epoch = epoch + 1
def main(tempdir):
    """Evaluate a saved 3D UNet on synthetic volumes and print the mean Dice.

    Five synthetic image / segmentation pairs are written to ``tempdir``, the weights in
    ``best_metric_model_segmentation3d_dict.pth`` are loaded, sliding-window inference is run
    on every volume, and each thresholded prediction is saved under ``./output``.

    Args:
        tempdir: directory the synthetic NIfTI files are written to and read from.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")

    # a single device is selected here (GPU when available, otherwise CPU); the list form
    # is kept so the DataParallel branch below applies if more devices are ever listed
    devices = [torch.device("cuda" if torch.cuda.is_available() else "cpu")]
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(devices[0])

    # map_location remaps the stored tensors onto the selected device, so a checkpoint
    # saved from a GPU run still loads when this script falls back to the CPU
    model.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth", map_location=devices[0]))

    # if we have multiple GPUs, set data parallel to execute sliding window inference
    if len(devices) > 1:
        model = torch.nn.DataParallel(model, device_ids=devices)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            meta_data = decollate_batch(val_data["img_meta_dict"])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output, data in zip(val_outputs, meta_data):
                saver(val_output, data)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
def main(tempdir):
    """Evaluate a saved 2D UNet on synthetic PNG images and print the mean Dice.

    Five synthetic image / segmentation pairs are written to ``tempdir``, the weights in
    ``best_metric_model_segmentation2d_dict.pth`` are loaded, sliding-window inference is run
    on every image, and each thresholded prediction is saved under ``./output``.

    Args:
        tempdir: directory the synthetic PNG files are written to and read from.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys=["img", "seg"]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location remaps the stored tensors onto the selected device, so a checkpoint
    # saved from a GPU run still loads when this script falls back to the CPU
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth", map_location=device))

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
def test_saved_content(self):
    """Check the CSV files written by ``SaveClassificationd`` with and without a shared saver.

    Three batches of eight zero predictions are pushed through the savers; afterwards
    ``predictions1.csv`` must hold 24 rows and ``predictions2.csv`` 16 rows, each row being
    ``testfile<i>`` followed by zero values.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        # three batches whose filenames run testfile0..7, testfile8..15 and testfile16..23
        data = [
            {
                "pred": torch.zeros(8),
                PostFix.meta("image"): {
                    "filename_or_obj": ["testfile" + str(i) for i in range(first, first + 8)]
                },
            }
            for first in (0, 8, 16)
        ]

        saver = CSVSaver(
            output_dir=Path(tempdir), filename="predictions2.csv", overwrite=False, flush=False, delimiter="\t"
        )
        # set up test transforms
        post_trans = Compose(
            [
                CopyItemsd(keys=PostFix.meta("image"), times=1, names=PostFix.meta("pred")),
                # 1st saver saves data into CSV file
                SaveClassificationd(
                    keys="pred",
                    saver=None,
                    meta_keys=None,
                    output_dir=Path(tempdir),
                    filename="predictions1.csv",
                    delimiter="\t",
                    overwrite=True,
                ),
                # 2nd saver only saves data into the cache, manually finalize later
                SaveClassificationd(keys="pred", saver=saver, meta_key_postfix=PostFix.meta()),
            ]
        )

        # simulate inference 2 iterations
        for batch in data[:2]:
            for item in decollate_batch(batch):
                post_trans(item)
        # write into CSV file
        saver.finalize()

        # 3rd saver will not delete previous data due to `overwrite=False`
        trans2 = SaveClassificationd(
            keys="pred",
            saver=None,
            # specify meta key, so no need to copy anymore
            meta_keys=PostFix.meta("image"),
            output_dir=tempdir,
            filename="predictions1.csv",
            delimiter="\t",
            overwrite=False,
        )
        for item in decollate_batch(data[2]):
            trans2(item)

        def _test_file(filename, count):
            """Assert ``filename`` exists and holds ``count`` rows of the expected form."""
            filepath = os.path.join(tempdir, filename)
            self.assertTrue(os.path.exists(filepath))
            with open(filepath) as f:
                rows = list(csv.reader(f, delimiter="\t"))
            for idx, row in enumerate(rows):
                self.assertEqual(row[0], "testfile" + str(idx))
                self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)
            self.assertEqual(len(rows), count)

        _test_file("predictions1.csv", 24)
        _test_file("predictions2.csv", 16)
def main(tempdir):
    """Train a 3D UNet on synthetic volumes and save the best-Dice weights.

    Forty synthetic image / segmentation pairs are written to ``tempdir``; the first 20 are
    used for training and the last 20 for validation. The best state dict is stored in
    ``best_metric_model_segmentation3d_dict.pth`` and progress is logged to TensorBoard.

    Args:
        tempdir: directory the synthetic NIfTI files are written to and read from.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter()

    # iterations per epoch, used for the progress print and the global TensorBoard step
    epoch_len = len(train_ds) // train_loader.batch_size
    # sliding window size and window batch size for validation
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")

        model.train()
        batch_losses = []
        for step, batch_data in enumerate(train_loader, start=1):
            inputs = batch_data["img"].to(device)
            labels = batch_data["seg"].to(device)
            optimizer.zero_grad()
            loss = loss_function(model(inputs), labels)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

        epoch_loss = sum(batch_losses) / len(batch_losses)
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        model.eval()
        with torch.no_grad():
            # keep the last batch around so it can be plotted after the loop
            val_images, val_labels, val_outputs = None, None, None
            for val_data in val_loader:
                val_images = val_data["img"].to(device)
                val_labels = val_data["seg"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = [post_trans(item) for item in decollate_batch(val_outputs)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation3d_dict.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def evaluate(args):
    """Run distributed sliding-window evaluation of a saved UNet and print the Dice on rank 0.

    Args:
        args: object providing ``local_rank`` (index of the GPU used by this process) and
            ``dir`` (data directory; created and filled with 16 synthetic pairs by the
            process with ``local_rank == 0`` when it does not exist yet).
    """
    needs_data = args.local_rank == 0 and not os.path.exists(args.dir)
    if needs_data:
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # create an evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(dataset=val_ds, even_divisible=False, shuffle=False)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(
        val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create the UNet on the GPU assigned to this process
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])
    # config mapping to expected GPU device
    map_location = {"cuda:0": f"cuda:{args.local_rank}"}
    # load model parameters to GPU device
    state = torch.load("final_model.pth", map_location=map_location)
    model.load_state_dict(state)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(item) for item in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        # only the rank-0 process reports the result
        if dist.get_rank() == 0:
            print("evaluation metric:", metric)
        dist.destroy_process_group()
def _train_func(engine, batch):
    """Store the decollated batch on the engine state and return a list of zero tensors.

    The length of the returned list is ``8 + rank * 2``, where ``rank`` is read from the
    enclosing scope.
    """
    engine.state.batch = decollate_batch(batch)
    num_outputs = 8 + rank * 2
    outputs = []
    for _ in range(num_outputs):
        outputs.append(torch.zeros(1))
    return outputs
def test_train_timing(self):
    """Train a 3D UNet with cached datasets and AMP, then require a best Dice above 0.95.

    Reads ``img*.nii.gz`` / ``seg*.nii.gz`` pairs from ``self.data_dir``: the first 32 pairs
    are used for training and the last 9 for validation. Per-step, per-epoch and total
    timings are printed along the way.
    """
    images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:32], segs[:32])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-9:], segs[-9:])]

    device = torch.device("cuda:0")

    # define transforms for train and validation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from big
            # image based on pos / neg ratio
            # the image centers of negative samples
            # must be in valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"], prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ]
    )

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch

    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)

    # Novograd paper suggests to use a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # iterations per epoch (rounded up), used only for the progress print
    epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
    # sliding window size and window batch size for validation
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")

        model.train()
        batch_losses = []
        train_iter = iter(train_loader)
        step = 0
        while True:
            # start the step timer before the batch is fetched, as the batch fetch is timed too
            step_start = time.time()
            try:
                batch_data = next(train_iter)
            except StopIteration:
                break
            step += 1
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            batch_losses.append(loss.item())
            print(
                f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                f" step time: {(time.time() - step_start):.4f}"
            )

        epoch_loss = sum(batch_losses) / len(batch_losses)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(item) for item in decollate_batch(val_outputs)]
                    val_labels = [post_label(item) for item in decollate_batch(val_data["label"])]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                print(f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}")
        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

    total_time = time.time() - total_start
    print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}")
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
def main(tempdir):
    """Run 3D UNet inference on synthetic volumes and save predictions via inverted transforms.

    Five synthetic images are written to ``tempdir``, pre-processed (including a resize to
    96x96x96), and fed to a UNet loaded from ``best_metric_model_segmentation3d_dict.pth``.
    Each prediction is passed through ``Invertd`` with the pre-transforms, thresholded and
    written under ``./out``.

    Args:
        tempdir: directory the synthetic NIfTI files are written to and read from.
    """
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose(
        [
            LoadImaged(keys="img"),
            EnsureChannelFirstd(keys="img"),
            Orientationd(keys="img", axcodes="RAS"),
            Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys="img"),
        ]
    )
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)

    # define post transforms
    post_transforms = Compose(
        [
            EnsureTyped(keys="pred"),
            Activationsd(keys="pred", sigmoid=True),
            Invertd(
                # invert the `pred` data field, also support multiple fields
                keys="pred",
                transform=pre_transforms,
                # get the previously applied pre_transforms information on the `img` data field,
                # then invert `pred` based on this information. we can use same info
                # for multiple fields, also support different orig_keys for different fields
                orig_keys="img",
                # key field to save inverted meta data, every item maps to `keys`
                meta_keys="pred_meta_dict",
                # get the meta data from `img_meta_dict` field when inverting,
                # for example, may need the `affine` to invert `Spacingd` transform,
                # multiple fields can use the same meta data to invert
                orig_meta_keys="img_meta_dict",
                # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key,
                # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}",
                # otherwise, no need this arg during inverting
                meta_key_postfix="meta_dict",
                # don't change the interpolation mode to "nearest" when inverting transforms
                # to ensure a smooth output, then execute `AsDiscreted` transform
                nearest_interp=False,
                # convert to PyTorch Tensor after inverting
                to_tensor=True,
            ),
            AsDiscreted(keys="pred", threshold=0.5),
            SaveImaged(
                keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False
            ),
        ]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location remaps the stored tensors onto the selected device, so a checkpoint
    # saved from a GPU run still loads when this script falls back to the CPU
    net.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth", map_location=device))

    net.eval()
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch["img"].to(device)
            # define sliding window size and batch size for windows inference
            batch["pred"] = sliding_window_inference(
                inputs=inputs, roi_size=(96, 96, 96), sw_batch_size=4, predictor=net
            )
            # decollate the batch data into a list of dictionaries, then execute postprocessing
            # transforms; the transforms are run for their side effect (saving each prediction)
            for item in decollate_batch(batch):
                post_transforms(item)
def _iteration(self, engine: Engine, batchdata: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Run a single evaluation iteration and store its result in ``engine.state.output``.

    The returned dictionary contains:
        - IMAGE: the model input, moved to CPU after inference.
        - LABEL: the first decollated target after ``self.post_label``, with a leading axis added.
        - PRED: a zero tensor of shape ``[1, num_classes, *original_shape]`` whose
          ``batchdata["bbox"]`` region is filled with the post-processed prediction.

    Only the first item of the batch is post-processed (index ``[0]`` is used throughout).

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
        batchdata: input data for this iteration. Besides what ``self.prepare_batch``
            consumes, the keys ``resample_flag``, ``anisotrophy_flag``, ``crop_shape``,
            ``original_shape`` and ``bbox`` are read from it.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")

    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) == 2:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, targets, args, kwargs = batch

    targets = targets.cpu()

    # every flip combination used for test-time augmentation
    flip_dims = [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]

    def _compute_pred():
        """Return softmax probabilities, averaged over flips when ``self.tta_val`` is set."""
        pred = self.inferer(inputs, self.network, *args, **kwargs).cpu()
        pred = nn.functional.softmax(pred, dim=1)
        if not self.tta_val:
            return pred

        n_preds = 1.0
        for dims in flip_dims:
            flipped_inputs = torch.flip(inputs, dims=dims)
            # flip the output back so it lines up with the un-flipped prediction
            flipped_pred = torch.flip(self.inferer(flipped_inputs, self.network).cpu(), dims=dims)
            flipped_pred = nn.functional.softmax(flipped_pred, dim=1)
            del flipped_inputs
            pred += flipped_pred
            del flipped_pred
            n_preds += 1
        return pred / n_preds

    # execute forward computation
    with eval_mode(self.network):
        if self.amp:
            with torch.cuda.amp.autocast():
                predictions = _compute_pred()
        else:
            predictions = _compute_pred()

    inputs = inputs.cpu()
    predictions = self.post_pred(decollate_batch(predictions)[0])
    targets = self.post_label(decollate_batch(targets)[0])

    resample_flag = batchdata["resample_flag"]
    anisotrophy_flag = batchdata["anisotrophy_flag"]
    crop_shape = batchdata["crop_shape"][0].tolist()
    original_shape = batchdata["original_shape"][0].tolist()

    if resample_flag:
        # convert the prediction back to the original (after cropped) shape
        restored = recovery_prediction(
            predictions.numpy(), [self.num_classes, *crop_shape], anisotrophy_flag
        )
        predictions = torch.tensor(restored)

    # put iteration outputs into engine.state
    engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets.unsqueeze(0)}
    engine.state.output[Keys.PRED] = torch.zeros([1, self.num_classes, *original_shape])

    # pad the prediction back to the original shape
    (h_start, w_start, d_start), (h_end, w_end, d_end) = batchdata["bbox"][0]
    engine.state.output[Keys.PRED][
        0, :, h_start:h_end, w_start:w_end, d_start:d_end
    ] = predictions
    del predictions

    engine.fire_event(IterationEvents.FORWARD_COMPLETED)
    engine.fire_event(IterationEvents.MODEL_COMPLETED)
    return engine.state.output
def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]):
    """
    Prepare ``batchdata`` with or without simulated clicks, then run the engine iteration.

    With probability ``self.deepgrow_probability`` the batch goes through
    ``self.max_interactions`` rounds of inference plus ``self.transforms``.
    Otherwise the last two channels of every image are zeroed in place.
    The accumulated ``pos_click_sum`` / ``neg_click_sum`` counters are stored
    in ``engine.state.batch`` before ``engine._iteration`` is called.

    Args:
        engine: the trainer or evaluator driving this iteration.
        batchdata: the collated batch for this iteration.

    Raises:
        ValueError: When ``batchdata`` is None.
    """
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")

    pos_click_sum = 0
    neg_click_sum = 0
    simulate_clicks = np.random.choice(
        [True, False], p=[self.deepgrow_probability, 1 - self.deepgrow_probability]
    )

    if simulate_clicks:
        # increase pos_click_sum by 1-click for AddInitialSeedPointd pre_transform
        pos_click_sum += 1
        for j in range(self.max_interactions):
            inputs, _ = engine.prepare_batch(batchdata)
            inputs = inputs.to(engine.state.device)

            engine.fire_event(IterationEvents.INNER_ITERATION_STARTED)

            engine.network.eval()
            with torch.no_grad():
                if engine.amp:
                    with torch.cuda.amp.autocast():
                        predictions = engine.inferer(inputs, engine.network)
                else:
                    predictions = engine.inferer(inputs, engine.network)
            batchdata.update({CommonKeys.PRED: predictions})

            # same value for every item of this round: decays linearly with the
            # round index when training, fixed at 1.0 otherwise
            if self.train:
                click_probability = 1.0 - ((1.0 / self.max_interactions) * j)
            else:
                click_probability = 1.0

            # decollate/collate batchdata to execute click transforms
            batchdata_list = decollate_batch(batchdata, detach=True)
            for idx, item in enumerate(batchdata_list):
                item[self.click_probability_key] = click_probability
                batchdata_list[idx] = self.transforms(item)
            batchdata = list_data_collate(batchdata_list)

            # first item in batch only
            pos_click_sum += (batchdata_list[0].get("is_pos", 0)) * 1
            neg_click_sum += (batchdata_list[0].get("is_neg", 0)) * 1

            engine.fire_event(IterationEvents.INNER_ITERATION_COMPLETED)
    else:
        # zero out input guidance channels
        batchdata_list = decollate_batch(batchdata, detach=True)
        for item in batchdata_list:
            item[CommonKeys.IMAGE][-1] *= 0
            item[CommonKeys.IMAGE][-2] *= 0
        batchdata = list_data_collate(batchdata_list)

    # the click counters above were read from the first item in the batch only
    engine.state.batch = batchdata
    engine.state.batch.update({"pos_click_sum": torch.tensor(pos_click_sum)})
    engine.state.batch.update({"neg_click_sum": torch.tensor(neg_click_sum)})

    return engine._iteration(engine, batchdata)
def evaluate(args):
    """Evaluate a UNet on synthetic data across Horovod workers and report mean Dice.

    When ``args.dir`` does not exist, the worker with local rank 0 creates it and
    fills it with 16 synthetic image/segmentation pairs. Each worker then runs
    sliding-window inference on its ``DistributedSampler`` share of the data.
    Rank 0 loads ``final_model.pth`` and the parameters are broadcast from it.
    Rank 0 prints the aggregated metric.

    Args:
        args: object exposing a ``dir`` attribute, the data directory.
    """
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            nib.save(nib.Nifti1Image(im, np.eye(4)), os.path.join(args.dir, f"img{i:d}.nii.gz"))
            nib.save(nib.Nifti1Image(seg, np.eye(4)), os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": image_path, "seg": seg_path} for image_path, seg_path in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # create a evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create a evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False, num_replicas=hvd.size(), rank=hvd.rank())

    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    forkserver_available = (
        hasattr(mp, "_supports_context")
        and mp._supports_context
        and "forkserver" in mp.get_all_start_methods()
    )
    multiprocessing_context = "forkserver" if forkserver_available else None

    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create UNet and place it on this worker's GPU
    device = torch.device(f"cuda:{hvd.local_rank()}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["seg"].to(device)
            raw_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(output) for output in decollate_batch(raw_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if hvd.rank() == 0:
            print("evaluation metric:", metric)
def _train_func(engine, batch):
    """Store the decollated ``batch`` on ``engine.state`` and return one zero tensor of shape (1, 10, 10)."""
    decollated = decollate_batch(list(batch))
    engine.state.batch = decollated
    dummy_output = torch.zeros((1, 10, 10))
    return [dummy_output]
def main(tempdir):
    """Evaluate a trained 3D UNet on synthetic image/segmentation pairs.

    Five synthetic volumes and their segmentations are written into ``tempdir``
    and loaded through ``ImageDataset``. A UNet with weights from
    ``best_metric_model_segmentation3d_array.pth`` (path relative to the current
    working directory) runs sliding-window inference on each volume. Every
    thresholded prediction is saved under ``./output`` and the mean Dice over
    all volumes is printed.

    Args:
        tempdir: directory that receives the generated NIfTI files. It is not
            created by this function.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
    segtrans = Compose([AddChannel(), EnsureType()])
    val_ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # map_location keeps the CPU fallback above usable: without it, a checkpoint
    # saved from a GPU run cannot be deserialized on a CPU-only machine.
    model.load_state_dict(
        torch.load("best_metric_model_segmentation3d_array.pth", map_location=device)
    )

    # sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            # with image_only=False each batch is indexed as (image, label, meta data)
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            meta_data = decollate_batch(val_data[2])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output, data in zip(val_outputs, meta_data):
                saver(val_output, data)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
def main(tempdir):
    """Train a 3D UNet on synthetic data with Ignite, validating every 5 iterations.

    Forty synthetic image/segmentation pairs are written into ``tempdir``. The
    first 20 are used for training and the last 20 for validation. The trainer
    saves checkpoints to ``./runs_dict/`` after each epoch. An evaluator computes
    ``Mean_Dice`` on the validation loader and can stop training early.
    Statistics go to the log and to TensorBoard. The final trainer state is printed.

    Args:
        tempdir: directory that receives the generated NIfTI files. It is not
            created by this function.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # write 40 random image, mask pairs into tempdir
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        nib.save(nib.Nifti1Image(im, np.eye(4)), os.path.join(tempdir, f"img{i:d}.nii.gz"))
        nib.save(nib.Nifti1Image(seg, np.eye(4)), os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [
        {"img": image_path, "seg": seg_path}
        for image_path, seg_path in zip(images[:20], segs[:20])
    ]
    val_files = [
        {"img": image_path, "seg": seg_path}
        for image_path, seg_path in zip(images[-20:], segs[-20:])
    ]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"],
                label_key="seg",
                spatial_size=[96, 96, 96],
                pos=1,
                neg=1,
                num_samples=4,
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # pinned host memory is only requested when a CUDA device is present
    use_pinned_memory = torch.cuda.is_available()

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(
        check_ds,
        batch_size=2,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=use_pinned_memory,
    )
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=use_pinned_memory,
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(
        val_ds,
        batch_size=5,
        num_workers=8,
        collate_fn=list_data_collate,
        pin_memory=use_pinned_memory,
    )

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    lr = 1e-3
    opt = torch.optim.Adam(net.parameters(), lr)

    # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):
        """Turn the dictionary batch into the ``(img, seg)`` pair passed to ``_prepare_batch``."""
        pair = (batch["img"], batch["seg"])
        return _prepare_batch(pair, device, non_blocking)

    trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs_dict/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={"net": net, "opt": opt},
    )

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler(output_transform=lambda x: x)
    train_tensorboard_stats_handler.attach(trainer)

    validation_every_n_iters = 5
    # set parameters for validation
    metric_name = "Mean_Dice"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: MeanDice()}

    post_pred = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])

    # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(
        net,
        val_metrics,
        device,
        True,
        output_transform=lambda x, y, y_pred: (
            [post_pred(i) for i in decollate_batch(y_pred)],
            [post_label(i) for i in decollate_batch(y)],
        ),
        prepare_batch=prepare_batch,
    )

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        """Run the evaluator over the validation loader."""
        evaluator.run(val_loader)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer,
    )
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        # no need to print loss value, so disable per iteration output
        output_transform=lambda x: None,
        # fetch global epoch number from trainer
        global_epoch_transform=lambda x: trainer.state.epoch,
    )
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        # no need to plot loss value, so disable per iteration output
        output_transform=lambda x: None,
        # fetch global iteration number from trainer
        global_epoch_transform=lambda x: trainer.state.iteration,
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations.
    val_tensorboard_image_handler = TensorBoardImageHandler(
        batch_transform=lambda batch: (batch["img"], batch["seg"]),
        output_transform=lambda output: output[0],
        global_iter_transform=lambda x: trainer.state.epoch,
    )
    evaluator.add_event_handler(
        event_name=Events.ITERATION_COMPLETED(every=2),
        handler=val_tensorboard_image_handler,
    )

    train_epochs = 5
    state = trainer.run(train_loader, train_epochs)
    print(state)
def _train_func(engine, batch):
    """Store the decollated ``batch`` on ``engine.state`` and return 8 random float tensors of shape (1, 2, 2)."""
    engine.state.batch = decollate_batch(batch)
    outputs = []
    for _ in range(8):
        # integer values drawn from [0, 255), converted to float
        outputs.append(torch.randint(0, 255, (1, 2, 2)).float())
    return outputs
def test_transforms(self, case_id):
    """
    Check the transform chain ``case_id`` end to end: loading, saving and inversion.

    The chain is parsed from ``TEST_CASES`` and applied through a ``CacheDataset``.
    The first decollated item of the last batch is compared with the
    ``{case_id}_answer`` entry, written with ``SaveImageD`` and inverted with
    ``InvertD``. The inverted result is compared and written again. The saved
    inverted segmentation must differ from ``FILE_PATH_1`` in fewer than 40% of voxels.
    """
    set_determinism(2022)
    config = ConfigParser()
    config.read_config(TEST_CASES)
    config["input_keys"] = keys
    # transform instance
    test_case = config.get_parsed_content(id=case_id, instantiate=True, lazy=False)
    dataset = CacheDataset(self.files, transform=test_case)
    loader = DataLoader(dataset, batch_size=3, shuffle=True)
    for batch in loader:
        self.assertIsInstance(batch[keys[0]], MetaTensor)
        self.assertIsInstance(batch[keys[1]], MetaTensor)
        out = decollate_batch(batch)  # decollate every batch should work

    # tolerance settings shared by every comparison below
    tolerance = {"type_test": False, "atol": TINY_DIFF, "rtol": TINY_DIFF}

    # test forward patches
    loaded = out[0]
    if monai_config.USE_META_DICT:
        self.assertNotEqual(len(loaded), len(keys))
    else:
        self.assertEqual(len(loaded), len(keys))
    img, seg = loaded[keys[0]], loaded[keys[1]]
    # expected results
    expected = config.get_parsed_content(id=f"{case_id}_answer", instantiate=True)
    self.assertEqual(expected["load_shape"], list(batch[keys[0]].shape))
    for item in (img, seg):
        assert_allclose(expected["affine"], item.affine, **tolerance)

    test_cls = [type(xform).__name__ for xform in test_case.transforms]
    tracked_cls = [op[TraceKeys.CLASS_NAME] for op in img.applied_operations]
    # tracked items should be no more than the compose items.
    self.assertTrue(len(tracked_cls) <= len(test_cls))

    with tempfile.TemporaryDirectory() as tempdir:  # test writer
        writer = SaveImageD(keys, resample=False, output_dir=tempdir, output_postfix=case_id)
        writer(loaded)

    # test inverse
    inv = InvertD(keys, orig_keys=keys, transform=test_case, nearest_interp=True)
    out = inv(loaded)
    img, seg = out[keys[0]], out[keys[1]]
    for item in (img, seg):
        assert_allclose(expected["inv_affine"], item.affine, **tolerance)
    for item in (img, seg):
        self.assertFalse(item.applied_operations)
    for item in (img, seg):
        assert_allclose(expected["inv_shape"], item.shape, **tolerance)

    with tempfile.TemporaryDirectory() as tempdir:  # test writer
        writer = SaveImageD(keys, resample=False, output_dir=tempdir, output_postfix=case_id)
        writer(out)
        seg_file = os.path.join(tempdir, key_1, f"{key_1}_{case_id}.nii.gz")
        # read the file back while the temporary directory still exists
        segout = nib.load(seg_file).get_fdata()
        segin = nib.load(FILE_PATH_1).get_fdata()
        ndiff = np.sum(np.abs(segout - segin) > 0)
        total = np.prod(segout.shape)
    self.assertTrue(ndiff / total < 0.4, f"{ndiff / total}")