def test_invert(self):
    """Check that TransformInverter undoes a chain of deterministic and random spatial transforms.

    A synthetic 3D image/label pair is transformed, batched by a DataLoader and fed through an
    ignite ``Engine`` with a ``TransformInverter`` attached. The test verifies the batch shapes,
    that every inverted output has shape ``(1, 100, 101, 107)`` with integral values, and that
    the inverted label of the last sample matches the original label except for one of two
    known voxel-difference counts.
    """
    set_determinism(seed=0)
    try:
        im_fname, seg_fname = [
            make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            ToTensord(KEYS),
            CastToTyped(KEYS, dtype=torch.uint8),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine
        def _train_func(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler; reuse the same worker count as the data loader
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            num_workers=num_workers,
        ).attach(engine)

        engine.run(loader, max_epochs=1)
    finally:
        # Restore non-deterministic behaviour even if setup or the run failed, so this
        # test cannot leak a fixed seed into other tests.
        set_determinism(seed=None)

    output = engine.state.output
    # 12 samples with batch_size=5 -> the final batch holds 2 samples.
    self.assertTupleEqual(output["image"].shape, (2, 1, 100, 100, 100))
    self.assertTupleEqual(output["label"].shape, (2, 1, 100, 100, 100))
    for i in output["image_inverted"] + output["label_inverted"]:
        # inverted values must still be integral (the chain ends with a uint8 cast)
        torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    # check labels match
    reverted = output["label_inverted"][-1].detach().cpu().numpy()[0].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = output["label_meta_dict"]["filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    self.assertIn((reverted.size - n_good), (25300, 1812), "diff. in two possible values")
def test_fail_random_but_not_invertible(self):
    """TestTimeAugmentation must reject a pipeline whose random step is not invertible."""
    pipeline = Compose(
        [
            AddChanneld("im"),
            Rand2DElasticd("im", None, None),
        ]
    )
    with self.assertRaises(RuntimeError):
        TestTimeAugmentation(pipeline, None, None, None)
def test_values(self):
    """Exercise DecathlonDataset on Task04_Hippocampus.

    The test downloads the data and reloads it from disk. It then reads it without
    transforms, reads the training split, and queries the dataset properties. Finally it
    deletes the data and checks the error raised when the directory is missing. If the
    initial download fails because of network errors, the test returns early (effectively
    skipped). An MD5 failure is tolerated only when its message starts with "md5 check".
    """
    testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
    transform = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        ScaleIntensityd(keys="image"),
        ToTensord(keys=["image", "label"]),
    ])

    def _test_dataset(dataset):
        self.assertEqual(len(dataset), 52)
        # Index once: each dataset[0] access re-runs the load/transform chain.
        first = dataset[0]
        self.assertIn("image", first)
        self.assertIn("label", first)
        self.assertIn("image_meta_dict", first)
        self.assertTupleEqual(first["image"].shape, (1, 36, 47, 44))

    try:  # will start downloading if testing_dir doesn't have the Decathlon files
        data = DecathlonDataset(
            root_dir=testing_dir,
            task="Task04_Hippocampus",
            transform=transform,
            section="validation",
            download=True,
        )
    except (ContentTooShortError, HTTPError, RuntimeError) as e:
        print(str(e))
        if isinstance(e, RuntimeError):
            # FIXME: skip MD5 check as current downloading method may fail
            self.assertTrue(str(e).startswith("md5 check"))
        return  # skipping this test due the network connection errors

    _test_dataset(data)
    data = DecathlonDataset(
        root_dir=testing_dir, task="Task04_Hippocampus", transform=transform, section="validation", download=False
    )
    _test_dataset(data)

    # test validation without transforms
    data = DecathlonDataset(root_dir=testing_dir, task="Task04_Hippocampus", section="validation", download=False)
    self.assertTupleEqual(data[0]["image"].shape, (36, 47, 44))
    self.assertEqual(len(data), 52)
    data = DecathlonDataset(root_dir=testing_dir, task="Task04_Hippocampus", section="training", download=False)
    self.assertTupleEqual(data[0]["image"].shape, (34, 56, 31))
    self.assertEqual(len(data), 208)

    # test dataset properties
    data = DecathlonDataset(root_dir=testing_dir, task="Task04_Hippocampus", section="validation", download=False)
    properties = data.get_properties(keys="labels")
    self.assertDictEqual(properties["labels"], {"0": "background", "1": "Anterior", "2": "Posterior"})

    shutil.rmtree(os.path.join(testing_dir, "Task04_Hippocampus"))
    # With the data directory removed, construction without download must fail. The
    # previous try/except passed silently if no error was raised at all.
    with self.assertRaises(RuntimeError) as ctx:
        DecathlonDataset(
            root_dir=testing_dir,
            task="Task04_Hippocampus",
            transform=transform,
            section="validation",
            download=False,
        )
    print(str(ctx.exception))
    self.assertTrue(str(ctx.exception).startswith("Cannot find dataset directory"))
def main():
    """Fine-tune UNETR on a Decathlon-style dataset, optionally from pretrained ViT weights.

    Fill in the paths marked '/to/be/defined' before running. Training runs until
    ``max_iterations`` optimizer steps have been taken. Every ``eval_num`` steps the model
    is validated with sliding-window inference. The model with the best mean Dice is saved
    to ``logdir/best_metric_model.pth``, and a loss/Dice plot is refreshed at each
    validation.
    """
    print_config()

    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')

    # If use_pretrained is set to 0, ViT weights will not be loaded and random initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')

    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100

    # makedirs also creates missing parent directories and is a no-op if logdir exists.
    os.makedirs(logdir, exist_ok=True)

    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[1], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[2], prob=0.10),
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    datalist = load_decathlon_datalist(base_dir=data_dir,
                                       data_list_file_path=json_path,
                                       is_segmentation=True,
                                       data_list_key="training")
    val_files = load_decathlon_datalist(base_dir=data_dir,
                                        data_list_file_path=json_path,
                                        is_segmentation=True,
                                        data_list_key="validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading Weights from the Path {}'.format(pretrained_path))
        # map_location lets GPU-saved checkpoints load on CPU-only machines; the model is
        # moved to `device` afterwards.
        vit_dict = torch.load(pretrained_path, map_location="cpu")
        vit_weights = vit_dict['state_dict']

        # Remove items of vit_weights if they are not in the ViT backbone (this is used in UNETR).
        # For example, some variables names like conv3d_transpose.weight, conv3d_transpose.bias,
        # conv3d_transpose_1.weight and conv3d_transpose_1.bias are used to match dimensions
        # while pretraining with ViTAutoEnc and are not a part of ViT backbone.
        model_dict = model.vit.state_dict()
        vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
        model_dict.update(vit_weights)
        model.vit.load_state_dict(model_dict)
        del model_dict, vit_weights, vit_dict
        print('Pretrained Weights Succesfully Loaded !')
    elif use_pretrained == 0:
        print('No weights were loaded, all weights being used are randomly initialized!')

    model.to(device)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val, current_step):
        """Run sliding-window inference over the validation loader and return the mean Dice."""
        model.eval()
        with torch.no_grad():
            for batch in epoch_iterator_val:
                val_inputs = batch["image"].to(device)
                val_labels = batch["label"].to(device)
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
                val_labels_convert = [post_label(t) for t in decollate_batch(val_labels)]
                val_output_convert = [post_pred(t) for t in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # Running Dice over the cases seen so far -- for the progress bar only.
                dice = dice_metric.aggregate().item()
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" % (current_step, max_iterations, dice))
            # Aggregate once over all cases. Averaging the running values above would
            # over-weight the first cases.
            mean_dice_val = dice_metric.aggregate().item()
            dice_metric.reset()
        return mean_dice_val

    def train(global_step, train_loader, dice_val_best, global_step_best):
        """Train for one pass over ``train_loader``, validating every ``eval_num`` steps."""
        model.train()
        epoch_loss = 0
        epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator, start=1):
            x = batch["image"].to(device)
            y = batch["label"].to(device)
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            loss_value = loss.item()
            epoch_loss += loss_value
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss_value))

            if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val, global_step)
                # Average loss so far this epoch; do not overwrite the running sum, or
                # later averages in the same epoch would be wrong.
                epoch_loss_values.append(epoch_loss / step)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(), os.path.join(logdir, "best_metric_model.pth"))
                    print("Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                          .format(dice_val_best, dice_val))
                else:
                    print("Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                          .format(dice_val_best, dice_val))

                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
                y = epoch_loss_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                x = [eval_num * (i + 1) for i in range(len(metric_values))]
                y = metric_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.savefig(os.path.join(logdir, 'btcv_finetune_quick_update.png'))
                plt.clf()
                plt.close(1)

                # validation() switched to eval mode; resume training mode for the rest of the epoch.
                model.train()

            global_step += 1
        return global_step, dice_val_best, global_step_best

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(
            global_step, train_loader, dice_val_best, global_step_best)

    model.load_state_dict(torch.load(os.path.join(logdir, "best_metric_model.pth"), map_location=device))

    print(f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}")
def test_values(self):
    """Exercise TciaDataset on one case of the QIN-PROSTATE-Repeatability collection.

    The test downloads one case (skipped if the download fails) and reloads it from disk.
    It checks file names and reads the data without transforms. Finally it deletes the
    data and checks the error raised when the directory is missing.
    """
    testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
    download_len = 1
    val_frac = 1.0
    collection = "QIN-PROSTATE-Repeatability"

    transform = Compose([
        LoadImaged(keys=["image", "seg"], reader="PydicomReader", label_dict=TCIA_LABEL_DICT[collection]),
        AddChanneld(keys="image"),
        ScaleIntensityd(keys="image"),
    ])

    def _test_dataset(dataset):
        self.assertEqual(len(dataset), int(download_len * val_frac))
        # Index once: each dataset[0] access re-runs the load/transform chain.
        first = dataset[0]
        self.assertIn("image", first)
        self.assertIn("seg", first)
        self.assertIsInstance(first["image"], MetaTensor)
        self.assertTupleEqual(first["image"].shape, (1, 256, 256, 24))
        self.assertTupleEqual(first["seg"].shape, (256, 256, 24, 4))

    with skip_if_downloading_fails():
        data = TciaDataset(
            root_dir=testing_dir,
            collection=collection,
            transform=transform,
            section="validation",
            download=True,
            download_len=download_len,
            copy_cache=False,
            val_frac=val_frac,
        )

    _test_dataset(data)
    data = TciaDataset(
        root_dir=testing_dir,
        collection=collection,
        transform=transform,
        section="validation",
        download=False,
        val_frac=val_frac,
    )
    _test_dataset(data)
    first = data[0]
    self.assertTrue(first["image"].meta["filename_or_obj"].endswith(
        "QIN-PROSTATE-Repeatability/PCAMPMRI-00015/1901/image"))
    self.assertTrue(first["seg"].meta["filename_or_obj"].endswith(
        "QIN-PROSTATE-Repeatability/PCAMPMRI-00015/1901/seg"))

    # test validation without transforms
    data = TciaDataset(
        root_dir=testing_dir, collection=collection, section="validation", download=False, val_frac=val_frac
    )
    self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24))
    self.assertEqual(len(data), int(download_len * val_frac))
    # NOTE(review): the original rebuilt an identical validation dataset for this check;
    # it only holds because val_frac == 1.0. A "training" section may have been intended.
    self.assertEqual(len(data), download_len)

    shutil.rmtree(os.path.join(testing_dir, collection))
    # With the data directory removed, construction without download must fail. The
    # previous try/except passed silently if no error was raised at all.
    with self.assertRaises(RuntimeError) as ctx:
        TciaDataset(
            root_dir=testing_dir,
            collection=collection,
            transform=transform,
            section="validation",
            download=False,
            val_frac=val_frac,
        )
    self.assertTrue(str(ctx.exception).startswith("Cannot find dataset directory"))
opt = Options().parse()

# Training images and labels. Each list is sorted by filename so that zip() below pairs the
# i-th image with the i-th label (assumes image/label filenames sort in matching order --
# confirm).
train_images = sorted(
    glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
train_segs = sorted(
    glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

# NOTE(review): zip(train_images, train_segs) yields 2-tuples, but three names are unpacked
# (image_name, label_name, mask_name). This raises ValueError as soon as zip yields an item.
# A third list of mask paths appears to be missing from the zip -- confirm its source.
data_dicts = [{
    'image': image_name,
    'label': label_name,
    'mask': mask_name
} for image_name, label_name, mask_name in zip(train_images, train_segs)]

# Preprocessing chain; the commented-out entries are disabled augmentations.
# NOTE(review): this list literal is not closed within the visible fragment.
monai_transforms = [
    LoadImaged(keys=['image', 'label']),
    AddChanneld(keys=['image', 'label']),
    NormalizeIntensityd(keys=['image']),
    ScaleIntensityd(keys=['image']),
    # Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
    # RandFlipd(keys=['image', 'label'], prob=1, spatial_axis=2),
    # RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1,
    #             rotate_range=(np.pi / 36, np.pi / 4, np.pi / 36)),
    # Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1,
    #                sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.20, 0.20, 0.20)),
    # RandAdjustContrastd(keys=['image'], gamma=(0.5, 3), prob=1),
    # RandGaussianNoised(keys=['image'], prob=1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
    # RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=1),
    # BorderPadd(keys=['image', 'label'],spatial_border=(16,16,0)),
    # RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
    # Orientationd(keys=["image", "label"], axcodes="PLI"),
    ToTensord(keys=['image', 'label'])
def test_shape(self, input_param, input_data, expected_shape):
    """Build AddChanneld from ``input_param``; both "img" and "seg" must reach ``expected_shape``."""
    add_channel = AddChanneld(**input_param)
    output = add_channel(input_data)
    for key in ("img", "seg"):
        self.assertEqual(output[key].shape, expected_shape)
def main():
    """Evaluate a trained 3D DenseNet121 gender classifier on ten IXI T1 volumes.

    Loads weights from ``best_metric_model_classification3d_dict.pth`` and prints the
    accuracy over the listed volumes. Per-image predicted classes are written by a
    ``CSVSaver`` under ``./output``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    ixi_t1_dir = ["workspace", "data", "medical", "ixi", "IXI-T1"]
    image_names = [
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join(ixi_t1_dir + [name]) for name in image_names]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    val_files = [{"img": img, "label": label} for img, label in zip(images, labels)]

    # Define transforms for image
    val_transforms = Compose([
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    # Create DenseNet121
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    # map_location lets a checkpoint saved on GPU load when only a CPU is available.
    model.load_state_dict(
        torch.load("best_metric_model_classification3d_dict.pth", map_location=device))
    model.eval()
    with torch.no_grad():
        num_correct = 0.0
        metric_count = 0
        saver = CSVSaver(output_dir="./output")
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["label"].to(device)
            val_outputs = model(val_images).argmax(dim=1)
            value = torch.eq(val_outputs, val_labels)
            metric_count += len(value)
            num_correct += value.sum().item()
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = num_correct / metric_count
        print("evaluation metric:", metric)
        saver.finalize()
def main():
    """Train a 3D DenseNet121 gender classifier on IXI T1 volumes with PyTorch-Ignite.

    The first ten volumes are used for training and the last ten for validation. Checkpoints
    go to ``./runs/`` every epoch, and stats are logged to stdout and TensorBoard. Validation
    accuracy drives early stopping with a patience of 4.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    ixi_t1_dir = ["workspace", "data", "medical", "ixi", "IXI-T1"]
    image_names = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join(ixi_t1_dir + [name]) for name in image_names]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    train_files = [{"img": img, "label": label} for img, label in zip(images[:10], labels[:10])]
    val_files = [{"img": img, "label": label} for img, label in zip(images[-10:], labels[-10:])]

    # define transforms for image
    train_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2)
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    # Fall back to CPU instead of failing outright on machines without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch["img"], batch["label"]), device, non_blocking)

    trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt}
    )

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy(), "AUC": ROCAUC(to_onehot_y=True, softmax=True)}
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(
        train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available()
    )

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
    print(state)
Rand2DElasticd, RandAffined, )

# Probability with which each random augmentation below is applied (passed as `prob`).
aug_prob = 0.5
# Dictionary keys the transforms operate on: the image and its segmentation.
keys = ("img", "seg")

# use these when interpolating binary segmentations to ensure values are 0 or 1 only
zoom_mode = monai.utils.enums.InterpolateMode.NEAREST
# A 2-tuple (BILINEAR, NEAREST), presumably paired element-wise with `keys` so the image is
# resampled bilinearly and the segmentation with nearest-neighbour -- confirm.
elast_mode = monai.utils.enums.GridSampleMode.BILINEAR, monai.utils.enums.GridSampleMode.NEAREST

# NOTE(review): RandZoomd receives the single NEAREST zoom_mode, which presumably applies to
# both keys, so "img" would also be zoomed with nearest-neighbour. A (bilinear, nearest)
# pair like elast_mode may have been intended -- confirm.
trans = Compose(
    [
        ScaleIntensityd(keys=("img",)),  # rescale image data to range [0,1]
        AddChanneld(keys=keys),  # add 1-size channel dimension
        RandRotate90d(keys=keys, prob=aug_prob),
        RandFlipd(keys=keys, prob=aug_prob),
        RandZoomd(keys=keys, prob=aug_prob, mode=zoom_mode),
        Rand2DElasticd(keys=keys, prob=aug_prob, spacing=10, magnitude_range=(-2, 2), mode=elast_mode),
        RandAffined(keys=keys, prob=aug_prob, rotate_range=1, translate_range=16, mode=elast_mode),
        ToTensord(keys=keys),  # convert to tensor
    ]
)

# One {"img", "seg"} dict per training image, pairing images and segmentations by index.
# NOTE(review): raises IndexError if train_segs is shorter than train_images.
data = [
    {"img": train_images[i], "seg": train_segs[i]} for i in range(len(train_images))
]
# CacheDataset built from the sample dicts and the transform chain above.
ds = CacheDataset(data, trans)
def _pet_ct_transforms(image_shape, voxel_spacing, augment):
    """Build the PET/CT preprocessing chain shared by training and validation.

    Args:
        image_shape: target shape passed to ``ResampleReshapeAlign``.
        voxel_spacing: target voxel spacing passed to ``ResampleReshapeAlign``.
        augment: when true, insert a ``RandAffined`` augmentation right after the
            conversion to NumPy arrays.

    Returns:
        A ``Compose`` whose output holds "image" and "mask_img" converted to tensors. The
        "image" entry is presumably produced by ``ConcatModality`` from the PET and CT
        volumes -- confirm.
    """
    transforms = [
        # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=['pet_img', 'mask_img'], method='otsu', tval=0.0, idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=['pet_img', "ct_img", 'mask_img'],
                             origin='head',
                             origin_key='pet_img'),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
    ]
    if augment:
        # user can also add other random transforms
        transforms.append(
            RandAffined(keys=("pet_img", "ct_img", "mask_img"),
                        spatial_size=None,
                        prob=0.4,
                        rotate_range=(0, np.pi / 30, np.pi / 15),
                        shear_range=None,
                        translate_range=(10, 10, 10),
                        scale_range=(0.1, 0.1, 0.1),
                        mode=("bilinear", "bilinear", "nearest"),
                        padding_mode="border"))
    transforms += [
        # normalize input
        ScaleIntensityRanged(keys=["pet_img"], a_min=0.0, a_max=25.0, b_min=0.0, b_max=1.0, clip=True),
        ScaleIntensityRanged(keys=["ct_img"], a_min=-1000.0, a_max=1000.0, b_min=0.0, b_max=1.0, clip=True),
        # Prepare for neural network
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ]
    return Compose(transforms)


def main(config):
    """Train a 3D UNet to segment PET/CT volumes as configured by ``config``.

    Paths, preprocessing, model and training settings are read from the nested ``config``
    mapping. The function builds cached train/validation loaders and runs a MONAI
    ``SupervisedTrainer`` for ``config['training']['epochs']`` epochs. Validation runs
    every epoch, with TensorBoard logging and checkpoints written under ``./runs/``.
    """
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config['path']['csv_path']
    # NOTE(review): read but not used yet -- resuming from a trained model is not implemented.
    trained_model_path = config['path']['trained_model_path']  # if None, trained from scratch
    training_model_folder = os.path.join(config['path']['training_model_folder'], now)  # '/path/to/folder'
    os.makedirs(training_model_folder, exist_ok=True)
    logdir = os.path.join(training_model_folder, 'logs')
    # NOTE(review): logdir is created but the handlers below still log/checkpoint to "./runs/".
    os.makedirs(logdir, exist_ok=True)

    # PET CT scan params
    image_shape = tuple(config['preprocessing']['image_shape'])  # (x, y, z)
    in_channels = config['preprocessing']['in_channels']
    voxel_spacing = tuple(config['preprocessing']['voxel_spacing'])  # (4.8, 4.8, 4.8) in millimeter, (x, y, z)
    data_augment = config['preprocessing']['data_augment']  # True, for training dataset only
    resize = config['preprocessing']['resize']  # True, not use yet
    origin = config['preprocessing']['origin']  # how to set the new origin (not wired; 'head' is used)
    normalize = config['preprocessing']['normalize']  # True, whether or not to normalize the inputs
    number_class = config['preprocessing']['number_class']  # 2

    # CNN params
    architecture = config['model']['architecture']  # 'unet' or 'vnet'
    cnn_params = config['model'][architecture]['cnn_params']
    # transform list to tuple
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)

    # Training params
    epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    shuffle = config['training']['shuffle']
    opt_params = config['training']["optimizer"]["opt_params"]  # NOTE(review): unused; lr below is fixed

    # Get Data
    DM = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, test_images_paths = DM.get_train_val_test(wrap_with_dict=True)

    # Input preprocessing: random augmentation only on the training set, and only when
    # the config enables it (the flag was previously read but ignored).
    train_transforms = _pet_ct_transforms(image_shape, voxel_spacing, augment=data_augment)
    val_transforms = _pet_ct_transforms(image_shape, voxel_spacing, augment=False)

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_images_paths, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle, num_workers=2)

    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_images_paths, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=batch_size, num_workers=2)

    # Model
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    def prepare_batch(batchdata, target_device=None, non_blocking=False):
        """Return the (input, target) pair used by both engines: ("image", "mask_img").

        Accepts the optional device/non_blocking arguments that some MONAI versions pass,
        and moves the tensors only when a device is given.
        """
        image, mask = batchdata['image'], batchdata['mask_img']
        if target_device is not None:
            image = image.to(device=target_device, non_blocking=non_blocking)
            mask = mask.to(device=target_device, non_blocking=non_blocking)
        return image, mask

    # training
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/", output_transform=lambda x: None),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        # The data carries "image"/"mask_img" rather than a "label" key, so the evaluator
        # needs the same batch preparation as the trainer.
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "val_precision": Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "val_recall": Recall(output_transform=lambda x: (x["pred"], x["label"])),
        },
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        # amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False,
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        # Honour the configured epoch count (previously hard-coded to 5).
        max_epochs=epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        prepare_batch=prepare_batch,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "train_precision": Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "train_recall": Recall(output_transform=lambda x: (x["pred"], x["label"])),
        },
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    trainer.run()
def evaluta_model(test_files, model_name):
    """Evaluate a saved DenseNetASPP checkpoint on ``test_files`` and report metrics.

    The predicted class (argmax over the 5 output channels) of every sample is
    written with ``CSVSaver`` to ``../result/GLeason/2d_output/<checkpoint stem>.csv``.
    A confusion matrix is plotted to ``confusion_matrix.png``. The quadratic
    weighted kappa and a classification report are printed. The kappa and the
    accuracy are appended to the module-level ``kappa_list`` / ``accuracy_list``.

    Args:
        test_files: data dicts with the ``modalDataKey`` image keys and a "label".
            The loaded batch is also expected to carry "dwiImg_meta_dict"
            (assumed to be produced by the loader for a "dwiImg" key; confirm).
        model_name: path to a checkpoint dict with a 'model' state_dict, as
            written by ``training``.

    Raises:
        ValueError: if ``test_files`` is empty (a zero batch size is invalid).
    """
    from sklearn.metrics import classification_report

    if not test_files:
        raise ValueError("test_files is empty; nothing to evaluate")

    test_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )
    device = torch.device("cpu")
    print(len(test_files))
    test_ds = monai.data.Dataset(data=test_files, transform=test_transforms)
    # pin_memory must be a bool. The original passed the ``torch.device`` class, which
    # is always truthy. Pinning only speeds host->GPU copies, and this runs on the CPU.
    test_loader = DataLoader(
        test_ds,
        batch_size=len(test_files),
        num_workers=2,
        pin_memory=device.type == "cuda",
    )

    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    # map_location lets a checkpoint saved on a GPU be restored on this CPU-only path.
    checkpoint = torch.load(model_name, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Collect labels and predictions from every batch so that the metrics below cover
    # the whole test set, not just the last batch.
    all_labels, all_preds = [], []
    with torch.no_grad():
        saver = CSVSaver(
            output_dir="../result/GLeason/2d_output/",
            filename=os.path.basename(model_name).split('.')[0] + '.csv',
        )
        for test_data in test_loader:
            test_images = test_data["inputs"].to(device)
            test_labels = test_data["label"].to(device)
            probabilities = torch.sigmoid(model(test_images))
            predicted = probabilities.argmax(dim=1)
            saver.save_batch(predicted, test_data["dwiImg_meta_dict"])
            all_labels.append(test_labels.cpu())
            all_preds.append(predicted.cpu())
        saver.finalize()

    y_true = torch.cat(all_labels).numpy()
    y_pred = torch.cat(all_preds).numpy()

    cm = confusion_matrix(y_true, y_pred)
    kappa_value = cohen_kappa_score(y_true, y_pred, weights='quadratic')
    print('quadratic weighted kappa=' + str(kappa_value))
    kappa_list.append(kappa_value)
    plot_confusion_matrix(cm, 'confusion_matrix.png', title='confusion matrix')

    print(classification_report(y_true, y_pred, digits=4))
    report = classification_report(y_true, y_pred, digits=4, output_dict=True)
    accuracy_list.append(report["accuracy"])
def training(train_files, val_files, log_dir):
    """Train a DenseNetASPP 5-class classifier and checkpoint the best-kappa model.

    Every epoch trains on ``train_files``. Every ``val_interval`` epochs the model is
    scored on ``val_files`` with quadratic weighted Cohen's kappa, and the
    model/optimizer/scheduler state is saved to ``log_dir`` whenever the kappa
    improves. If ``log_dir`` already exists, training resumes from it. Afterwards the
    loss and kappa curves are plotted and ``evaluta_model`` is run on ``val_files``.

    Args:
        train_files: data dicts with the ``modalDataKey`` image keys and a "label".
        val_files: validation data dicts in the same format.
        log_dir: path of the checkpoint *file* (read with ``torch.load`` and written
            with ``torch.save``, despite its name).
    """
    print(log_dir)
    train_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            RandRotate90d(keys=["inputs"], prob=0.8, spatial_axes=[0, 1]),
            RandAffined(keys=["inputs"], prob=0.8, scale_range=[0.1, 0.5]),
            RandZoomd(keys=["inputs"], prob=0.8, max_zoom=1.5, min_zoom=0.5),
            ToTensord(keys=["inputs"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )

    lr = 0.01
    batch_size = 256
    # pin_memory must be a bool. The original passed the ``torch.device`` class, which
    # is always truthy and enabled pinning even without a GPU.
    pin_memory = torch.cuda.is_available()

    # A first (unshuffled) batch is drawn only to estimate the class weights below.
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)
    check_data = monai.utils.misc.first(check_loader)

    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)

    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    classes = np.array([0, 1, 2, 3, 4])
    # Keyword arguments: recent scikit-learn releases make ``classes``/``y`` keyword-only.
    # NOTE(review): the weights come from one batch. If a class is missing from it,
    # compute_class_weight raises; consider using all training labels instead.
    class_weights = class_weight.compute_class_weight(
        class_weight='balanced', classes=classes, y=check_data["label"].numpy()
    )
    class_weights_tensor = torch.Tensor(class_weights).to(device)
    loss_function = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.5, last_epoch=-1)

    best_metric = -1
    best_metric_epoch = -1
    # If a saved checkpoint exists, load it and continue training from it.
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Older checkpoints lack these keys. Keep a fresh scheduler / metric for them.
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        best_metric = checkpoint.get('best_metric', best_metric)
        best_metric_epoch = checkpoint.get('best_metric_epoch', best_metric_epoch)
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
        # The checkpoint stores the last finished epoch, so resume with the next one.
        first_epoch = start_epoch + 1
    else:
        print('无保存模型,将从头开始训练!')
        # A fresh run starts at epoch 0. The original skipped it.
        first_epoch = 0

    epoch_num = 300
    val_interval = 2
    epoch_loss_values = list()
    metric_values = list()
    metric_epochs = list()  # 1-based epoch number of each validation, used for plotting
    writer = SummaryWriter()
    # len(train_loader) is already the number of batches per epoch. The original
    # divided it by batch_size again, which collapsed all epochs onto the same steps.
    steps_per_epoch = len(train_loader)
    for epoch in range(first_epoch, epoch_num):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["inputs"].to(device), batch_data["label"].to(device)
            outputs = model(inputs)
            loss = loss_function(outputs, labels.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{len(train_loader)}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), steps_per_epoch * epoch + step)
        epoch_loss /= step
        scheduler.step()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data["inputs"].to(device), val_data["label"].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
                y_pred_label = y_pred.argmax(dim=1)
                kappa_value = cohen_kappa_score(y.to("cpu"), y_pred_label.to("cpu"), weights='quadratic')
                metric_values.append(kappa_value)
                metric_epochs.append(epoch + 1)
                acc_value = torch.eq(y_pred_label, y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                if kappa_value > best_metric:
                    best_metric = kappa_value
                    best_metric_epoch = epoch + 1
                    checkpoint = {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'epoch': epoch,
                        # Plain Python types keep the checkpoint loadable with weights_only.
                        'best_metric': float(best_metric),
                        'best_metric_epoch': int(best_metric_epoch),
                    }
                    torch.save(checkpoint, log_dir)
                    print("saved new best metric model")
                print(
                    "current epoch: {} current Kappa: {:.4f} current accuracy: {:.4f} best Kappa: {:.4f} at epoch {}".format(
                        epoch + 1, kappa_value, acc_metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()

    plt.figure('train', (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Epoch Average Loss")
    # Use real epoch numbers so that resumed runs are plotted at the right positions.
    loss_epochs = list(range(first_epoch + 1, first_epoch + 1 + len(epoch_loss_values)))
    plt.xlabel('epoch')
    plt.plot(loss_epochs, epoch_loss_values)
    plt.subplot(1, 2, 2)
    # The tracked validation metric is quadratic weighted kappa, not ROC AUC.
    plt.title("Validation: Quadratic Weighted Kappa")
    plt.xlabel('epoch')
    plt.plot(metric_epochs, metric_values)
    plt.show()
    evaluta_model(val_files, log_dir)
def image_mixing(data, seed=None):
    """Interactively build and save "mixed" CT volumes from pairs of positive cases.

    Every pair of entries whose ``'_label'`` is 1 is preprocessed with the same chain:
    load, CT window (width 1500, level -600), CT segmentation, add channel,
    y-scaling affine, foreground crop, resize, and squeeze. The two volumes are then
    combined voxel-wise: within the intersection of their non-zero masks the larger
    value is kept, and elsewhere the result is 0. Each result is shown in a viewer.
    The user answers 'y' to save it as ``<data dir>/mixed_images/<hash>.nii.gz``,
    'n' to skip it, or 'e' to stop.

    Args:
        data: iterable of dicts with '_label', 'image' and 'seg' entries.
        seed: if not None, seeds the ``random`` module so that the pair order is
            reproducible.

    Raises:
        SystemExit: when the user answers 'e'.
    """
    if seed is not None:
        random.seed(seed)
    file_list = [x for x in data if int(x['_label']) == 1]
    random.shuffle(file_list)

    crop_foreground = CropForegroundd(
        keys=["image"], source_key="image", margin=(0, 0, 0), select_fn=lambda x: x != 0
    )
    WW, WL = 1500, -600  # CT window width / level
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    resize2 = Resized(keys=["image"], spatial_size=(int(512 * 0.75), int(512 * 0.75), -1), mode="area")
    resize1 = Resized(keys=["image"], spatial_size=(-1, -1, 40), mode="nearest")
    affine = Affined(keys=["image"], scale_params=(1.0, 2.0, 1.0), padding_mode='zeros')
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        affine,
        crop_foreground,
        resize1,
        resize2,
        SqueezeDimd(keys=["image"]),
    ])

    dirs = setup_directories()
    data_dir = dirs['data']
    mixed_images_dir = os.path.join(data_dir, 'mixed_images')
    # exist_ok avoids the check-then-create race and also creates missing parents.
    os.makedirs(mixed_images_dir, exist_ok=True)

    for img1, img2 in itertools.combinations(file_list, 2):
        img1 = {'image': img1["image"], 'seg': img1['seg']}
        img2 = {'image': img2["image"], 'seg': img2['seg']}
        img1_data = common_transform(img1)["image"]
        img2_data = common_transform(img2)["image"]
        img1_mask, img2_mask = (img1_data > 0), (img2_data > 0)
        # "presek" = intersection: voxels that are non-zero in both volumes.
        img_presek = np.logical_and(img1_mask, img2_mask)
        img = np.maximum(img_presek * img1_data, img_presek * img2_data)
        multi_slice_viewer(img, "img1")

        while True:
            save = input("Save image [y/n/e]: ").lower()
            if save == 'y':
                # Time-keyed hash gives a practically unique file name.
                k = str(time.time()).encode('utf-8')
                h = blake2b(key=k, digest_size=16)
                name = h.hexdigest() + '.nii.gz'
                out_path = os.path.join(mixed_images_dir, name)
                write_nifti(img, out_path, resample=False)
                break
            if save == 'n':
                break
            if save == 'e':
                print("exeting")
                # Equivalent to exit(), without relying on the site-installed builtin.
                raise SystemExit
            print("wrong input!")
def test_invert(self):
    """Check that ``TransformInverter`` undoes a randomized spatial pipeline.

    A synthetic 3D image/label pair (101 x 100 x 107) is pushed through a
    deterministic-seeded chain of orientation, spacing and random spatial
    transforms, ending at 100^3. Two inverters are attached to the same engine:

    * "inverted1": nearest interpolation for both keys; image kept as a Tensor and
      label as a NumPy array (see the assertions below).
    * "inverted2": nearest for "image" only; padded collation and a +10 post_func
      on the image.

    The inverted outputs must regain the original spatial shape. Nearest-mode
    outputs must stay integral, and the non-nearest label output must not.
    """
    set_determinism(seed=0)
    im_fname, seg_fname = [
        make_nifti_image(i)
        for i in create_test_image_3d(101, 100, 107, noise_max=100)
    ]
    transform = Compose([
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        ToTensord("image"),  # test to support both Tensor and Numpy array when inverting
        CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
    ])
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

    # set up engine
    def _train_func(engine, batch):
        # Every batch must already be at the pipeline's output size.
        self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
        engine.state.output = batch
        # Fire the event the inverters are attached to.
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output

    engine = Engine(_train_func)
    engine.register_events(*IterationEvents)

    # set up testing handler
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="label",
        nearest_interp=True,
        postfix="inverted1",
        to_tensor=[True, False],
        device="cpu",
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    ).attach(engine)

    # test different nearest interpolation values
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="image",
        nearest_interp=[True, False],
        post_func=[lambda x: x + 10, lambda x: x],
        postfix="inverted2",
        collate_fn=pad_list_data_collate,
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    ).attach(engine)

    engine.run(loader, max_epochs=1)
    set_determinism(seed=None)
    # 12 items with batch_size=5 -> the last batch holds 2 items.
    self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100))
    self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100))

    # check the nearest interpolation mode: values must survive a round-trip through uint8
    for i in engine.state.output["image_inverted1"]:
        torch.testing.assert_allclose(
            i.to(torch.uint8).to(torch.float), i.to(torch.float))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
    for i in engine.state.output["label_inverted1"]:
        np.testing.assert_allclose(
            i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    # check labels match: compare the last inverted label with the freshly loaded original
    reverted = engine.state.output["label_inverted1"][-1].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = engine.state.output["label_meta_dict"]["filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    # Number of mismatching voxels depends on the environment:
    # 25300: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1824: torch 1.5.1
    self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824), "diff. in 3 possible values")

    # check the case that different items use different interpolation mode to invert transforms
    d = engine.state.output["image_inverted2"]
    # if the interpolation mode is nearest, accumulated diff should be smaller than 1
    self.assertLess(
        torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
    self.assertTupleEqual(d.shape, (2, 1, 100, 101, 107))

    d = engine.state.output["label_inverted2"]
    # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
    self.assertGreater(
        torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
    self.assertTupleEqual(d.shape, (2, 1, 100, 101, 107))
def load_image(filename):
    """Load ``filename`` with LoadImaged and prepend a channel axis.

    Returns the transformed dictionary: the array is stored under "image", alongside
    any metadata entries that LoadImaged attaches.
    """
    pipeline = Compose([LoadImaged(keys="image"), AddChanneld(keys="image")])
    return pipeline({"image": filename})
def test_values(self):
    """Exercise DecathlonDataset on Task04_Hippocampus.

    Covers: download + load with transforms, reload without download, metadata
    filenames, untransformed shapes for both sections, ``get_properties``, and the
    RuntimeError raised once the dataset directory has been removed.
    """
    testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
    transform = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    def _test_dataset(dataset):
        self.assertEqual(len(dataset), 52)
        # Fetch once: each ``dataset[0]`` may re-run loading and the transform chain.
        item = dataset[0]
        self.assertIn("image", item)
        self.assertIn("label", item)
        self.assertIn(PostFix.meta("image"), item)
        self.assertTupleEqual(item["image"].shape, (1, 36, 47, 44))

    with skip_if_downloading_fails():
        data = DecathlonDataset(
            root_dir=testing_dir,
            task="Task04_Hippocampus",
            transform=transform,
            section="validation",
            download=True,
            copy_cache=False,
        )

    _test_dataset(data)
    data = DecathlonDataset(
        root_dir=testing_dir, task="Task04_Hippocampus", transform=transform, section="validation", download=False
    )
    _test_dataset(data)
    self.assertTrue(data[0][PostFix.meta("image")]["filename_or_obj"].endswith("hippocampus_163.nii.gz"))
    self.assertTrue(data[0][PostFix.meta("label")]["filename_or_obj"].endswith("hippocampus_163.nii.gz"))

    # test validation without transforms
    data = DecathlonDataset(root_dir=testing_dir, task="Task04_Hippocampus", section="validation", download=False)
    self.assertTupleEqual(data[0]["image"].shape, (36, 47, 44))
    self.assertEqual(len(data), 52)
    data = DecathlonDataset(root_dir=testing_dir, task="Task04_Hippocampus", section="training", download=False)
    self.assertTupleEqual(data[0]["image"].shape, (34, 56, 31))
    self.assertEqual(len(data), 208)

    # test dataset properties
    data = DecathlonDataset(
        root_dir=Path(testing_dir), task="Task04_Hippocampus", section="validation", download=False
    )
    properties = data.get_properties(keys="labels")
    self.assertDictEqual(properties["labels"], {"0": "background", "1": "Anterior", "2": "Posterior"})

    shutil.rmtree(os.path.join(testing_dir, "Task04_Hippocampus"))
    # With the data removed and download disabled, construction must fail. Unlike the
    # previous try/except, assertRaises also fails the test when nothing is raised.
    with self.assertRaises(RuntimeError) as ctx:
        DecathlonDataset(
            root_dir=testing_dir,
            task="Task04_Hippocampus",
            transform=transform,
            section="validation",
            download=False,
        )
    print(str(ctx.exception))
    self.assertTrue(str(ctx.exception).startswith("Cannot find dataset directory"))
def main(tempdir):
    """Run sliding-window UNet inference on 5 synthetic 2D images and report mean Dice.

    Synthetic image/mask PNGs are written into ``tempdir``. The weights are loaded from
    ``best_metric_model_segmentation2d_dict.pth`` in the working directory, and the
    thresholded predictions are written by ``PNGSaver`` into ``./output``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        # The synthetic image is float in [0, 1]. A bare uint8 cast truncated nearly
        # every pixel to 0, so scale to 0..255 first (as the training script does).
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        # Masks stay 0/1: only "img" is intensity-scaled below, so the labels must
        # already be binary for the Dice computation.
        Image.fromarray(seg.astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth"))
    model.eval()

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96)
    sw_batch_size = 4
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = PNGSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = post_trans(val_outputs)
            # With reduction="mean" the metric returns the averaged Dice (presumably a
            # 0-d tensor, on which len() raises) plus the number of non-NaN entries it
            # averaged. Weight each batch's mean by that count.
            value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += not_nans.item()
            metric_sum += value.item() * not_nans.item()
            saver.save_batch(val_outputs)
        metric = metric_sum / metric_count if metric_count else float("nan")
        print("evaluation metric:", metric)
def main(tempdir):
    """Train a 2D UNet on 40 synthetic image/mask pairs, validating every other epoch.

    PNG pairs are written into ``tempdir``. The first 20 pairs train the network and
    the last 20 validate it with sliding-window inference and mean Dice. The best
    weights are saved to ``best_metric_model_segmentation2d_dict.pth``. Loss, Dice and
    example image/label/output slices are logged to TensorBoard.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for n in range(40):
        image, mask = create_test_image_2d(128, 128, num_seg_classes=1)
        image_png = Image.fromarray((image * 255).astype("uint8"))
        image_png.save(os.path.join(tempdir, f"img{n:d}.png"))
        mask_png = Image.fromarray((mask * 255).astype("uint8"))
        mask_png.save(os.path.join(tempdir, f"seg{n:d}.png"))

    image_paths = sorted(glob(os.path.join(tempdir, "img*.png")))
    mask_paths = sorted(glob(os.path.join(tempdir, "seg*.png")))
    train_files = [dict(img=img, seg=seg) for img, seg in zip(image_paths[:20], mask_paths[:20])]
    val_files = [dict(img=img, seg=seg) for img, seg in zip(image_paths[-20:], mask_paths[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys=["img", "seg"]),
            RandCropByPosNegLabeld(
                keys=["img", "seg"],
                label_key="seg",
                spatial_size=[96, 96],
                pos=1,
                neg=1,
                num_samples=4,
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 1]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys=["img", "seg"]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # batch_size=2 with 4 crops per image -> 2 x 4 training samples per batch
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    max_epochs = 10
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter()
    # Nominal iterations per epoch, used for the progress print and the TensorBoard step.
    epoch_len = len(train_ds) // train_loader.batch_size
    roi_size = (96, 96)
    sw_batch_size = 4

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["img"].to(device)
            labels = batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            epoch_loss += batch_loss
            print(f"{step}/{epoch_len}, train_loss: {batch_loss:.4f}")
            writer.add_scalar("train_loss", batch_loss, epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        model.eval()
        with torch.no_grad():
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images = val_data["img"].to(device)
                val_labels = val_data["seg"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = [post_trans(item) for item in decollate_batch(val_outputs)]
                # accumulate the metric for this iteration
                dice_metric(y_pred=val_outputs, y=val_labels)
            # aggregate the final mean Dice, then reset for the next validation round
            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            metric_values.append(metric)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation2d_dict.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # log the last validation batch (image, label, prediction) to TensorBoard
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def main():
    """Evaluate a DenseNet121 gender classifier on ten IXI T1 volumes with Ignite.

    Weights are restored from ``./runs_dict/net_checkpoint_20.pth`` by a
    ``CheckpointLoader`` handler. Per-image predicted classes are written by
    ``ClassificationSaver`` into ``tempdir``, and the final engine state is printed.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # NOTE(review): the joined paths are relative ("workspace/..."); confirm whether an
    # absolute "/workspace/..." root was intended.
    t1_folder = ["workspace", "data", "medical", "ixi", "IXI-T1"]
    t1_names = [
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join(t1_folder + [name]) for name in t1_names]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    val_files = [dict(img=path, label=label) for path, label in zip(images, labels)]

    # define transforms for image
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            ToTensord(keys=["img"]),
        ]
    )

    # create DenseNet121
    net = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def prepare_batch(batch, device=None, non_blocking=False):
        # Ignite expects an (input, target) pair; pull both out of the dict batch.
        return _prepare_batch((batch["img"], batch["label"]), device, non_blocking)

    # The Ignite evaluator consumes (img, label) per iteration and outputs (y_pred, y);
    # add an output_transform to expose other values.
    metric_name = "Accuracy"
    evaluator = create_supervised_evaluator(
        net, {metric_name: Accuracy()}, device, True, prepare_batch=prepare_batch
    )

    # print validation stats; output_transform returns None to suppress per-iteration loss output
    StatsHandler(name="evaluator", output_transform=lambda x: None).attach(evaluator)

    # each dict batch carries its images' metadata under "img_meta_dict"; the saved
    # prediction is the argmax over dim 1 of the first output element
    ClassificationSaver(
        output_dir="tempdir",
        name="evaluator",
        batch_transform=lambda batch: batch["img_meta_dict"],
        output_transform=lambda output: output[0].argmax(1),
    ).attach(evaluator)

    # the model was trained by "densenet_training_dict" example
    CheckpointLoader(load_path="./runs_dict/net_checkpoint_20.pth", load_dict={"net": net}).attach(evaluator)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    state = evaluator.run(val_loader)
    print(state)
def _seg_preprocessing_transforms(opt):
    """Return new instances of the deterministic preprocessing shared by training and evaluation.

    Intensities are windowed to [-120, 170] and rescaled, the volume is cropped to its
    foreground and, when ``opt.resolution`` is not None, image and label are resampled with
    ``Spacingd`` (bilinear for the image, nearest for the label).
    """
    transforms = [
        LoadNiftid(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        ScaleIntensityRanged(
            keys=["image"], a_min=-120, a_max=170, b_min=0.0, b_max=1.0, clip=True,
        ),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
        CropForegroundd(keys=["image", "label"], source_key="image"),
    ]
    if opt.resolution is not None:
        transforms.append(
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest'))
        )
    return transforms


def _seg_augmentation_transforms(opt):
    """Return new instances of the training-only augmentations, patch crop and tensor conversion."""
    return [
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # NOTE(review): np.random.uniform runs once, when this list is built, so the noise mean/std
        # and the intensity offset are fixed for the whole run. If per-sample variation was
        # intended, this needs a different transform setup; confirm.
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=['image', 'label']),
    ]


def main():
    """Train a 3D segmentation network configured by the command-line ``Options``.

    Each epoch trains on random patches. Every ``val_interval`` epochs, sliding-window Dice is
    computed on the validation, (non-augmented) training and test sets. The weights with the
    best validation Dice are saved to ``best_metric_model.pth``. Losses, Dice scores and the
    last validation sample are logged to TensorBoard.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    num_gpus = len(opt.gpu_ids.split(',')) if opt.gpu_ids != '-1' else 0
    print('number of GPU:', num_gpus)

    # Data loader creation
    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))
    # Un-augmented copies of the training lists, used only to report training Dice.
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    # NOTE(review): each pass appends the list to itself, so its length is multiplied by
    # 2 ** increase_factor_data (exponential, not linear). Confirm this is the intended oversampling.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data dictionaries for the data loaders
    def to_dicts(image_names, label_names):
        return [{'image': image_name, 'label': label_name}
                for image_name, label_name in zip(image_names, label_names)]

    train_dicts = to_dicts(train_images, train_segs)
    train_dice_dicts = to_dicts(train_images_for_dice, train_segs_for_dice)
    val_dicts = to_dicts(val_images, val_segs)
    test_dicts = to_dicts(test_images, test_segs)

    # Transforms: shared preprocessing (+ optional resampling); augmentation for training only
    train_transform_list = _seg_preprocessing_transforms(opt) + _seg_augmentation_transforms(opt)
    val_transform_list = _seg_preprocessing_transforms(opt) + [ToTensord(keys=['image', 'label'])]
    train_transforms = Compose(train_transform_list)
    val_transforms = Compose(val_transform_list)

    pin_memory = torch.cuda.is_available()
    # create a training data loader
    train_ds = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=opt.batch_size, shuffle=True,
                              num_workers=opt.workers, pin_memory=pin_memory)
    # create a loader over the non-augmented training images for reporting training Dice
    train_dice_ds = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(train_dice_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)
    # create a test data loader
    test_ds = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # try to use all the available GPUs
    # NOTE(review): the returned device list is not used below; the call is kept as-is.
    get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    def compute_dice(images_loader):
        """Return ``(mean_dice, images, labels, raw_outputs)`` over ``images_loader``.

        The last three items come from the final batch (``None`` if the loader is empty) and are
        used for TensorBoard plotting. Intended to be called under ``torch.no_grad()``.
        """
        metric_sum = 0.0
        metric_count = 0
        images = labels = outputs = None
        roi_size = opt.patch_size
        sw_batch_size = 4
        for batch in images_loader:
            images, labels = batch["image"].cuda(), batch["label"].cuda()
            outputs = sliding_window_inference(images, roi_size, sw_batch_size, net)
            value = dice_metric(y_pred=outputs, y=labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
        # An empty loader yields NaN instead of a ZeroDivisionError in the middle of training.
        metric = metric_sum / metric_count if metric_count else float("nan")
        return metric, images, labels, outputs

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []  # validation Dice, one entry per evaluation
    writer = SummaryWriter()
    epoch_len = len(train_ds) // train_loader.batch_size
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_images, val_labels, val_outputs = compute_dice(val_loader)
                metric_values.append(metric)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                # Only the Dice value is needed for the training and test sets.
                metric_train = compute_dice(train_dice_loader)[0]
                metric_test = compute_dice(test_loader)[0]

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch))

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)

                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def main():
    """Train a 3D DenseNet121 gender classifier on 20 IXI T1 volumes (first 10 train, last 10 validate).

    Runs ``max_epochs`` epochs and evaluates validation accuracy and ROC AUC every
    ``val_interval`` epochs. The weights with the best validation accuracy are saved to
    ``best_metric_model_classification3d_dict.pth``; loss and accuracy go to TensorBoard.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # Paths are relative ("workspace/data/...") and joined with os.sep.
    image_names = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", name])
        for name in image_names
    ]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    train_files = [{"img": img, "label": label} for img, label in zip(images[:10], labels[:10])]
    val_files = [{"img": img, "label": label} for img, label in zip(images[-10:], labels[-10:])]

    # Define transforms for image
    train_transforms = Compose(
        [
            LoadImaged(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            ToTensord(keys=["img"]),
        ]
    )

    # Define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)

    # start a typical PyTorch training
    max_epochs = 5
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    writer = SummaryWriter()
    epoch_len = len(train_ds) // train_loader.batch_size  # steps per epoch, for progress/log indices
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["label"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                # Collect all validation logits and targets, then score them together.
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["label"].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)

                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, softmax=True)
                if acc_metric > best_metric:
                    best_metric = acc_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_classification3d_dict.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}"
                    .format(epoch + 1, acc_metric, auc_metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_accuracy", acc_metric, epoch + 1)

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def main(hparams):
    """Train a 2D UNet on CT slices with PyTorch Lightning, configured by ``hparams``.

    ``hparams`` is read for ``name``, ``batch_size``, ``patch_size``, ``epochs``,
    ``learning_rate`` and ``loss`` (``'Dice'`` or ``'CrossEntropy'``). Images and masks are
    found recursively under ``data/``. The ids passed to ``select_animals`` pick the
    train (12, 13, 14, 18, 20), validation (25) and test (27) splits. TensorBoard logs and
    checkpoints are written under ``models/<name>/``.

    Raises:
        ValueError: if ``hparams.loss`` is not one of the supported loss names.
    """
    print('===== INITIAL PARAMETERS =====')
    print('Model name: ', hparams.name)
    print('Batch size: ', hparams.batch_size)
    print('Patch size: ', hparams.patch_size)
    print('Epochs: ', hparams.epochs)
    print('Learning rate: ', hparams.learning_rate)
    print('Loss function: ', hparams.loss)
    print()

    ### Data collection
    data_dir = 'data/'
    print('Available directories: ', os.listdir(data_dir))

    # Get paths for images and masks
    images = sorted(glob.glob(os.path.join(data_dir, '**', '*CTImg*'), recursive=True))
    masks = sorted(glob.glob(os.path.join(data_dir, '**', '*Mask*'), recursive=True))

    # Dataset selection
    train_dicts = select_animals(images, masks, [12, 13, 14, 18, 20])
    val_dicts = select_animals(images, masks, [25])
    test_dicts = select_animals(images, masks, [27])
    data_keys = ['image', 'mask']

    # Data transformation
    data_transforms = Compose([
        LoadNiftid(keys=data_keys),
        AddChanneld(keys=data_keys),
        ScaleIntensityd(keys=data_keys),
        CropForegroundd(keys=data_keys, source_key='image'),
        RandSpatialCropd(
            keys=data_keys,
            roi_size=(hparams.patch_size, hparams.patch_size, 1),
            random_size=False,
        ),
    ])
    # NOTE(review): validation and test reuse the same pipeline, including the random patch crop.
    train_transforms = Compose([data_transforms, ToTensord(keys=data_keys)])
    val_transforms = Compose([data_transforms, ToTensord(keys=data_keys)])
    test_transforms = Compose([data_transforms, ToTensord(keys=data_keys)])

    # Data loaders
    data_loaders = {
        'train': create_loader(train_dicts, batch_size=hparams.batch_size, transforms=train_transforms, shuffle=True),
        'val': create_loader(val_dicts, transforms=val_transforms),
        'test': create_loader(test_dicts, transforms=test_transforms),
    }
    for key, loader in data_loaders.items():
        print(key, len(loader))

    ### Model training
    if hparams.loss == 'Dice':
        criterion = monai.losses.DiceLoss(to_onehot_y=True, do_softmax=True)
    elif hparams.loss == 'CrossEntropy':
        criterion = nn.CrossEntropyLoss()
    else:
        # Fail with a clear message rather than an UnboundLocalError on `criterion` below.
        raise ValueError(f"Unsupported loss {hparams.loss!r}; expected 'Dice' or 'CrossEntropy'.")

    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=2,
        # Channel widths double at each level. The third entry was 258, a typo for 256;
        # checkpoints saved with the 258-channel layout will not load into this model.
        channels=(64, 128, 256, 512, 1024),
        strides=(2, 2, 2, 2),
        norm=monai.networks.layers.Norm.BATCH,
        criterion=criterion,
        hparams=hparams,
    )

    early_stopping = EarlyStopping('val_loss')  # only referenced by the commented-out trainer option
    logger = TensorBoardLogger(f'models/{hparams.name}/tb_logs', name=hparams.name)

    trainer = Trainer(
        check_val_every_n_epoch=5,
        default_save_path=f'models/{hparams.name}/checkpoints',
        # early_stop_callback=early_stopping,
        gpus=1,
        max_epochs=hparams.epochs,
        # min_epochs=10,
        logger=logger,
    )
    trainer.fit(
        model,
        train_dataloader=data_loaders['train'],
        val_dataloaders=data_loaders['val'],
    )
def main(tempdir):
    """Evaluate a trained 2D UNet with sliding-window inference on synthetic PNG data.

    Five synthetic image/segmentation pairs are written to ``tempdir``. Predictions are
    sigmoid-thresholded at 0.5, saved as PNG files under ``./output`` and scored with
    ``DiceMetric``. Weights are read from ``best_metric_model_segmentation2d_dict.pth``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        EnsureTyped(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location keeps the load working when `device` fell back to CPU but the checkpoint
    # was saved from GPU tensors.
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth", map_location=device))

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96)
    sw_batch_size = 4

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
def test_fail_non_random(self):
    """TestTimeAugmentation must reject a pipeline built only from non-random transforms."""
    deterministic_only = Compose([AddChanneld("im"), SpatialPadd("im", 1)])
    with self.assertRaises(RuntimeError):
        TestTimeAugmentation(deterministic_only, None, None, None)
def test_inverse_inferred_seg(self):
    """Inverting a model output, singly and as a batch, restores the original label shape."""
    # 20 synthetic 2D pairs; labels are cast to float32 so they can be fed straight to the model.
    generated = [create_test_image_2d(100, 101) for _ in range(20)]
    test_data = [{"image": img, "label": lbl.astype(np.float32)} for img, lbl in generated]

    batch_size = 10
    # no worker processes on macOS
    num_workers = 0 if sys.platform == "darwin" else 2
    transforms = Compose([
        AddChanneld(KEYS),
        SpatialPadd(KEYS, (150, 153)),
        CenterSpatialCropd(KEYS, (110, 99)),
    ])
    num_invertible_transforms = len(
        [t for t in transforms.transforms if isinstance(t, InvertibleTransform)]
    )

    loader = DataLoader(
        CacheDataset(test_data, transform=transforms, progress=False),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(2, 4),
        strides=(2, ),
    ).to(device)

    batch = first(loader)
    segs = model(batch["label"].to(device)).detach().cpu()
    label_transform_key = "label" + InverseKeys.KEY_SUFFIX
    segs_dict = {"label": segs, label_transform_key: batch[label_transform_key]}

    # Invert one decollated model output after switching its interpolation mode to "nearest".
    seg_dict = first(decollate_batch(segs_dict))
    convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)
    with allow_missing_keys_mode(transforms):
        inv_seg = transforms.inverse(seg_dict)["label"]
    self.assertEqual(len(batch["label_transforms"]), num_invertible_transforms)
    self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

    # Invert the whole batch at once.
    batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation)
    with allow_missing_keys_mode(transforms):
        inv_batch = batch_inverter(segs_dict)
    self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
def test_test_time_augmentation(self):
    """Train a tiny 2D UNet briefly, then check what TestTimeAugmentation returns.

    ``TestTimeAugmentation`` returns ``(mode, mean, std, vvc)``. ``mode``, ``mean`` and ``std``
    should have the shape of one channel-first test image, and ``vvc`` should be a float.
    """
    input_size = (20, 20)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    keys = ["image", "label"]
    num_training_ims = 10
    train_data = self.get_data(num_training_ims, input_size)
    test_data = self.get_data(1, input_size)

    transforms = Compose([
        AddChanneld(keys),
        RandAffined(
            keys,
            prob=1.0,
            spatial_size=(30, 30),
            rotate_range=(np.pi / 3, np.pi / 3),
            translate_range=(3, 3),
            scale_range=((0.8, 1), (0.8, 1)),
            padding_mode="zeros",
            mode=("bilinear", "nearest"),
            as_tensor_output=False,
        ),
        CropForegroundd(keys, source_key="image"),
        DivisiblePadd(keys, 4),
    ])

    train_ds = CacheDataset(train_data, transforms)
    # output might be different size, so pad so that they match
    train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

    model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
    loss_function = DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    num_epochs = 10
    # The training loss is never asserted on, so it is not accumulated.
    for _ in trange(num_epochs):
        for batch_data in train_loader:
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

    post_trans = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True),
    ])

    def inferrer_fn(x):
        # binarised sigmoid prediction of the trained model
        return post_trans(model(x))

    tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device)
    mode, mean, std, vvc = tt_aug(test_data)
    self.assertEqual(mode.shape, (1, ) + input_size)
    self.assertEqual(mean.shape, (1, ) + input_size)
    # Compare as lists: an elementwise `== (0, 1)` broadcasts or errors when the unique set
    # does not have exactly two values, hiding the real failure.
    self.assertEqual(np.unique(mode).tolist(), [0, 1])
    self.assertEqual((mean.min(), mean.max()), (0.0, 1.0))
    self.assertEqual(std.shape, (1, ) + input_size)
    self.assertIsInstance(vvc, float)
"/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz", "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz" ] # 2 binary labels for gender classification: man and woman labels = np.array([ 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 ]) val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)] # define transforms for image val_transforms = Compose([ LoadNiftid(keys=['img']), AddChanneld(keys=['img']), ScaleIntensityd(keys=['img']), Resized(keys=['img'], spatial_size=(96, 96, 96)), ToTensord(keys=['img']) ]) # create DenseNet121 net = monai.networks.nets.densenet.densenet121( spatial_dims=3, in_channels=1, out_channels=2, ) device = torch.device("cuda:0") def prepare_batch(batch, device=None, non_blocking=False):
Here we use several transforms to augment the dataset: 1. `LoadImaged` loads the spleen CT images and labels from NIfTI format files. 1. `AddChanneld` as the original data doesn't have channel dim, add 1 dim to construct "channel first" shape. 1. `Spacingd` adjusts the spacing by `pixdim=(1.5, 1.5, 2.)` based on the affine matrix. 1. `Orientationd` unifies the data orientation based on the affine matrix. 1. `ScaleIntensityRanged` extracts intensity range [-57, 164] and scales to [0, 1]. 1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels. 1. `RandCropByPosNegLabeld` randomly crop patch samples from big image based on pos / neg ratio. The image centers of negative samples must be in valid body area. 1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on PyTorch affine transform. 1. `ToTensord` converts the numpy array to PyTorch Tensor for further steps. """ train_transforms = Compose([ LoadImaged(keys=["image", "label"]), AddChanneld(keys=["image", "label"]), Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")), Orientationd(keys=["image", "label"], axcodes="RAS"), ScaleIntensityRanged( keys=["image"], a_min=-1000, a_max=300, b_min=0.0, b_max=1.0, clip=True, ), #CropForegroundd(keys=["image", "label"], source_key="image"), RandCropByPosNegLabeld( keys=["image", "label"],
test_files = data_dicts[0:52] """## Set deterministic training for reproducibility""" set_determinism(seed=0) ''' Label 1: Bladder Label 2: Liver Label 3: Lungs Label 4: Heart Label 5: Pancreas ''' test_transforms = Compose([ LoadImaged(keys="image"), AddChanneld(keys="image"), Spacingd(keys="image", pixdim=(1.5, 1.5, 2.0), mode="bilinear"), Orientationd(keys="image", axcodes="RAS"), ScaleIntensityRanged( keys="image", a_min=-1000, a_max=300, b_min=0.0, b_max=1.0, clip=True, ), #CropForegroundd(keys=["image", "label"], source_key="image"), ToTensord(keys="image"), ]) test_ds = CacheDataset(data=test_files,