glob(os.path.join(opt.labels_folder, 'train', 'label*.nii'))) data_dicts = [{ 'image': image_name, 'label': label_name } for image_name, label_name in zip(train_images, train_segs)] monai_transforms = [ LoadImaged(keys=['image', 'label']), AddChanneld(keys=['image', 'label']), Orientationd(keys=["image", "label"], axcodes="RAS"), # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215), CropForegroundd( keys=['image', 'label'], source_key='image', start_coord_key='foreground_start_coord', end_coord_key='foreground_end_coord', ), # crop CropForeground NormalizeIntensityd(keys=['image']), ScaleIntensityd(keys=['image']), # Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')), SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'), RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False), ToTensord(keys=[ 'image', 'label', 'foreground_start_coord', 'foreground_end_coord' ], ) ]
# Register spatial pad/crop/flip cases on the shared TESTS list.
# Each entry is (case name, "2D"/"3D" tag, 0, True, transform instance); the
# meaning of the numeric and boolean fields is defined by whatever consumes
# TESTS, which is not visible in this part of the file.
TESTS.extend([
    ("BorderPadd 3d", "3D", 0, True, BorderPadd(KEYS, [4])),
    ("DivisiblePadd 2d", "2D", 0, True, DivisiblePadd(KEYS, k=4)),
    ("DivisiblePadd 3d", "3D", 0, True, DivisiblePadd(KEYS, k=[4, 8, 11])),
    ("CenterSpatialCropd 2d", "2D", 0, True, CenterSpatialCropd(KEYS, roi_size=95)),
    ("CenterSpatialCropd 3d", "3D", 0, True, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98])),
    ("CropForegroundd 2d", "2D", 0, True, CropForegroundd(KEYS, source_key="label", margin=2)),
    ("CropForegroundd 3d", "3D", 0, True, CropForegroundd(KEYS, source_key="label", k_divisible=[5, 101, 2])),
    ("ResizeWithPadOrCropd 3d", "3D", 0, True, ResizeWithPadOrCropd(KEYS, [201, 150, 105])),
    ("Flipd 3d", "3D", 0, True, Flipd(KEYS, [1, 2])),
    # NOTE(review): identical to the previous entry -- confirm whether the
    # duplicate case is intentional before removing it.
    ("Flipd 3d", "3D", 0, True, Flipd(KEYS, [1, 2])),
    ("RandFlipd 3d", "3D", 0, True, RandFlipd(KEYS, 1, [1, 2])),
    ("RandAxisFlipd 3d", "3D", 0, True, RandAxisFlipd(KEYS, 1)),
])
def image_mixing(data, seed=None):
    """Interactively blend pairs of label-1 CT volumes and save the accepted ones.

    Every unordered pair of entries whose ``'_label'`` equals 1 is pushed
    through the same preprocessing pipeline, the two volumes are combined on
    the voxels where both are positive, and the result is shown with
    ``multi_slice_viewer``. The user then chooses to save it (``y``), skip it
    (``n``) or terminate the whole process (``e``).

    Args:
        data: iterable of dicts; each must provide ``'_label'``, ``'image'``
            and ``'seg'`` keys.
        seed: currently unused -- seeding of the shuffle was disabled in the
            original code; the parameter is kept for API compatibility.
    """
    import sys  # local import: only needed for the explicit exit below

    positives = [x for x in data if int(x['_label']) == 1]
    random.shuffle(positives)

    crop_foreground = CropForegroundd(keys=["image"],
                                      source_key="image",
                                      margin=(0, 0, 0),
                                      select_fn=lambda x: x != 0)
    # Window width and level used for the CT intensity window.
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    resize2 = Resized(keys=["image"],
                      spatial_size=(int(512 * 0.75), int(512 * 0.75), -1),
                      mode="area")
    resize1 = Resized(keys=["image"],
                      spatial_size=(-1, -1, 40),
                      mode="nearest")
    affine = Affined(keys=["image"],
                     scale_params=(1.0, 2.0, 1.0),
                     padding_mode='zeros')
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        affine,
        crop_foreground,
        resize1,
        resize2,
        SqueezeDimd(keys=["image"]),
    ])

    dirs = setup_directories()
    mixed_images_dir = os.path.join(dirs['data'], 'mixed_images')
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(mixed_images_dir, exist_ok=True)

    for first, second in itertools.combinations(positives, 2):
        first_item = {'image': first["image"], 'seg': first['seg']}
        second_item = {'image': second["image"], 'seg': second['seg']}
        first_vol = common_transform(first_item)["image"]
        second_vol = common_transform(second_item)["image"]

        # Keep only voxels that are positive in BOTH volumes, then take the
        # element-wise maximum of the two masked volumes.
        overlap = np.logical_and(first_vol > 0, second_vol > 0)
        img = np.maximum(overlap * first_vol, overlap * second_vol)
        multi_slice_viewer(img, "img1")

        while True:
            answer = input("Save image [y/n/e]: ").lower()
            if answer == 'y':
                # File name is a keyed BLAKE2b digest whose key is the
                # current timestamp.
                key = str(time.time()).encode('utf-8')
                name = blake2b(key=key, digest_size=16).hexdigest() + '.nii.gz'
                out_path = os.path.join(mixed_images_dir, name)
                write_nifti(img, out_path, resample=False)
                break
            if answer == 'n':
                break
            if answer == 'e':
                print("exeting")
                sys.exit()
            print("wrong input!")
def main(train_output):
    """Run the tester on the held-out 20% of the HACKATHON data.

    Args:
        train_output: forwarded to ``Tester`` as ``load_dir``.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Directories and torch device
    dirs = setup_directories()
    device, using_gpu = create_device("cuda")

    # HACKATHON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')

    def _parse_entry(line):
        """Split one ``name,label`` line into ``(name, int(label))``."""
        fields = line.strip().split(',')
        return (fields[0], int(fields[1]))

    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [_parse_entry(line) for line in fp]

    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    hackathon_data = get_data_from_info(image_dir,
                                        seg_dir,
                                        train_info_hackathon,
                                        dual_output=False)
    large_image_splitter(hackathon_data, dirs["cache"])
    balance_training_data(hackathon_data, seed=72)

    # PSUF data (disabled in the original):
    # psuf_dir = os.path.join(dirs["data"], 'psuf')
    # with open(os.path.join(psuf_dir, "train.txt"), 'r') as fp:
    #     train_info = [entry.strip().split(',') for entry in fp.readlines()]
    # image_dir = os.path.join(psuf_dir, 'images')
    # train_data_psuf = get_data_from_info(image_dir, None, train_info)

    # Only the test part of the split is used here; the train part is dropped.
    _unused_train_split, test_data_hackathon = train_test_split(
        hackathon_data, test_size=0.2, shuffle=True, random_state=42)

    # Individual transforms
    crop_foreground = CropForegroundd(
        keys=["image"],
        source_key="image",
        margin=(5, 5, 0),
    )
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    half_res = int(512 * 0.50)
    resize = Resized(keys=["image"],
                     spatial_size=(half_res, half_res, -1),
                     mode="trilinear")

    # Composite pipelines
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    test_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    # Built but not used further in this function (as in the original).
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # Dataset and loader
    test_dataset = PersistentDataset(data=test_data_hackathon[:],
                                     transform=test_transform,
                                     cache_dir=dirs["persistent"])
    collate = PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT)
    test_loader = DataLoader(test_dataset,
                             batch_size=2,
                             shuffle=True,
                             pin_memory=using_gpu,
                             num_workers=1,
                             collate_fn=collate)

    # Network
    network = nets.DenseNet121(spatial_dims=3, in_channels=1,
                               out_channels=1).to(device)

    # Post-processing applied to the "pred" entry
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])

    # Build and run the tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)
    tester.run()
AsDiscrete, Compose, CropForegroundd, LoadImaged, RandCropByPosNegLabeld, Rand3DElasticd, ToTensord, ) from DataSet import * from VNet import * from Training import * train_transform = Compose([ LoadImaged(keys=[DataType.Image, DataType.Label]), AddChanneld(keys=[DataType.Image, DataType.Label]), CropForegroundd(keys=[DataType.Image, DataType.Label], source_key=DataType.Image), RandCropByPosNegLabeld( keys=[DataType.Image, DataType.Label], label_key=DataType.Label, spatial_size=( 110, 110, 25 ), # Crop to slightly larger than desired for elastic deformation - minimises errors created from padding pos=1, neg=0, num_samples=1, image_key=DataType.Image, image_threshold=0, ), # This function handles both affine and elastic deformation together # To minimise padding issues, we will elastic and affine transform first on original before any cropping Rand3DElasticd(
def main():
    """Train a ViT auto-encoder with a reconstruction + contrastive objective.

    Reads the training/validation split from a JSON file, builds two
    differently augmented views of every cropped patch, trains ``ViTAutoEnc``
    and, on every ``val_interval``-th epoch, validates, stores the best
    checkpoint and refreshes the loss plots in the log directory.

    Returns:
        None
    """
    # TODO Defining file paths & output directory path
    json_path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    # exist_ok avoids the exists()/mkdir() race and also creates parents.
    os.makedirs(logdir_path, exist_ok=True)

    # Load the JSON split and prefix every image path with the data root
    with open(json_path, 'r') as json_f:
        json_data = json.load(json_f)

    train_data = json_data['training']
    val_data = json_data['validation']

    for entry in train_data:
        entry['image'] = os.path.join(data_root, entry['image'])
    for entry in val_data:
        entry['image'] = os.path.join(data_root, entry['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_data)))
    print(train_data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_data)))
    print(val_data)
    print('#' * 10)

    # Set determinism
    set_determinism(seed=123)

    # Training transforms (also used for validation below)
    train_transforms = Compose(
        [
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode=("bilinear")),
            ScaleIntensityRanged(
                keys=["image"], a_min=-57, a_max=164,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image"], source_key="image"),
            SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
            RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96),
                                    random_size=False, num_samples=2),
            # Keep an untouched copy as ground truth and a second copy that
            # receives its own augmentation.
            CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"],
                       allow_missing_keys=False),
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=5,
                                   dropout_holes=True, max_spatial_size=32),
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=20,
                                   dropout_holes=False, max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
            # Please note that if image and image_2 were passed to the same
            # transform call they would, because of the determinism, be
            # augmented in exactly the same way, which is not wanted here --
            # hence two separate calls.
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=5,
                                   dropout_holes=True, max_spatial_size=32),
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=20,
                                   dropout_holes=False, max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8),
        ]
    )

    # Sanity check: load one sample and report its shape
    check_ds = Dataset(data=train_data, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = (check_data["image"][0][0])
    print(f"image shape: {image.shape}")

    # Network (ViT backbone), losses and optimizer
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
        in_channels=1,
        img_size=(96, 96, 96),
        patch_size=(16, 16, 16),
        pos_embed='conv',
        hidden_size=768,
        mlp_dim=3072,
    )
    model = model.to(device)

    # Hyper-parameters for the training loop
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    # inf (instead of a finite sentinel) guarantees the first validation
    # result is always stored, whatever its magnitude.
    best_val_loss = float("inf")

    recon_loss = L1Loss()
    contrastive_loss = ContrastiveLoss(batch_size=batch_size * 2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Data loaders (original note: CacheDataset needs to be used)
    train_ds = Dataset(data=train_data, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    val_ds = Dataset(data=val_data, transform=train_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # Adjust the CL loss by the recon loss
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # Accumulate the individual loss components
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time-start_time}s")

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch + 1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            # No gradients are needed for validation; without no_grad() every
            # forward pass would build and retain an autograd graph.
            with torch.no_grad():
                for val_batch in val_loader:
                    val_step += 1
                    start_time = time.time()
                    inputs, gt_input = (
                        val_batch["image"].to(device),
                        val_batch["gt_image"].to(device),
                    )
                    print('Input shape: {}'.format(inputs.shape))
                    outputs, outputs_v2 = model(inputs)
                    val_loss = recon_loss(outputs, gt_input)
                    total_val_loss += val_loss.item()
                    end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            # The reported time covers the last validation batch only.
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, "
                  f"time taken: {end_time-start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                checkpoint = {
                    # 1-based epoch that produced this checkpoint (the
                    # original stored max_epochs here regardless of progress).
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            # Refresh the 2x2 grid of loss curves
            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None
set_determinism(seed=0) val_transforms = Compose([ LoadNiftid(keys=['image', 'label']), AddChanneld(keys=['image', 'label']), Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), mode=('bilinear', 'nearest')), Orientationd(keys=['image', 'label'], axcodes='RAS'), ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True), CropForegroundd(keys=['image', 'label'], source_key='image'), ToTensord(keys=['image', 'label']) ]) check_ds = monai.data.Dataset(data=val_files, transform=val_transforms) check_loader = monai.data.DataLoader(check_ds, batch_size=1) check_data = monai.utils.misc.first(check_loader) # plt.imshow(check_data['image'][0, 0, :, :, 80]) # plt.imshow(check_data['label'][0, 0, :, :, 80]) val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=params['num_workers']) # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) val_loader = monai.data.DataLoader(val_ds,
def test_value(self, argments, image, expected_data):
    """Check that ``CropForegroundd`` built from ``argments`` yields ``expected_data``.

    The transform is applied to ``image`` and the ``"img"`` entry of the
    result is compared with ``expected_data`` via ``assert_allclose``.
    """
    transform = CropForegroundd(**argments)
    cropped = transform(image)
    np.testing.assert_allclose(cropped["img"], expected_data)
# Register spatial pad/crop/flip/orientation cases on the shared TESTS list.
# Each entry is (case name, "2D"/"3D" tag, 0, transform instance); the meaning
# of the numeric field is defined by whatever consumes TESTS, which is not
# visible in this part of the file.
TESTS.extend([
    ("BorderPadd 2d", "2D", 0, BorderPadd(KEYS, [3, 7, 2, 5])),
    ("BorderPadd 2d", "2D", 0, BorderPadd(KEYS, [3, 7])),
    ("BorderPadd 3d", "3D", 0, BorderPadd(KEYS, [4])),
    ("DivisiblePadd 2d", "2D", 0, DivisiblePadd(KEYS, k=4)),
    ("DivisiblePadd 3d", "3D", 0, DivisiblePadd(KEYS, k=[4, 8, 11])),
    ("CenterSpatialCropd 2d", "2D", 0, CenterSpatialCropd(KEYS, roi_size=95)),
    ("CenterSpatialCropd 3d", "3D", 0, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98])),
    ("CropForegroundd 2d", "2D", 0, CropForegroundd(KEYS, source_key="label", margin=2)),
    ("CropForegroundd 3d", "3D", 0, CropForegroundd(KEYS, source_key="label", k_divisible=[5, 101, 2])),
    ("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105])),
    ("Flipd 3d", "3D", 0, Flipd(KEYS, [1, 2])),
    ("RandFlipd 3d", "3D", 0, RandFlipd(KEYS, 1, [1, 2])),
    ("RandAxisFlipd 3d", "3D", 0, RandAxisFlipd(KEYS, 1)),
])

# One Orientationd case per value of as_closest_canonical.
for acc in [True, False]:
    orientation = Orientationd(KEYS, "RAS", as_closest_canonical=acc)
    TESTS.append(("Orientationd 3d", "3D", 0, orientation))
def _build_segment_transforms(keys, resolution, patch_size):
    """Build the preprocessing pipeline applied before sliding-window inference.

    Args:
        keys: ``['image']`` or ``['image', 'label']``.
        resolution: target voxel spacing, or ``None`` to skip resampling.
        patch_size: minimum spatial size; smaller volumes are padded at the end.
    """
    steps = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
        # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        CropForegroundd(keys=keys, source_key='image'),  # crop to the image foreground
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
    ]
    if resolution is not None:
        # Labels are resampled with nearest-neighbour, images with bilinear.
        mode = ('bilinear', 'nearest') if 'label' in keys else 'bilinear'
        steps.append(Spacingd(keys=keys, pixdim=resolution, mode=mode))
    # Pad if the image is smaller than the patch.
    steps.append(SpatialPadd(keys=keys, spatial_size=patch_size, method='end'))
    steps.append(ToTensord(keys=keys))
    return Compose(steps)


def segment(image, label, result, weights, resolution, patch_size, network, gpu_ids):
    """Segment a single image and write the predicted mask to ``result``.

    The image is cropped to its foreground, optionally resampled to
    ``resolution``, run through sliding-window inference and post-processed.
    The prediction is then resampled back, pasted into an array of the
    original shape and saved. When ``label`` is given, the Dice score against
    it is printed as well.

    Args:
        image: path of the input image.
        label: path of a reference segmentation, or ``None``.
        result: output path for the predicted segmentation.
        weights: passed to ``new_state_dict`` / ``new_state_dict_cpu`` to
            obtain the network state dict.
        resolution: voxel spacing used for inference, or ``None`` to keep the
            native spacing.
        patch_size: sliding-window ROI size (also the minimum padded size).
        network: ``'nnunet'`` or ``'unetr'``.
        gpu_ids: value for ``CUDA_VISIBLE_DEVICES``, or ``'-1'`` for CPU.

    Raises:
        ValueError: if ``network`` is not one of the supported names.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if label is not None:
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
        keys = ['image', 'label']
    else:
        files = [{"image": image}]
        keys = ['image']

    # original size, size after crop_background, cropped roi coordinates,
    # cropped resampled roi size
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(
        image, resolution)

    val_transforms = _build_segment_transforms(keys, resolution, patch_size)

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=0,
                            collate_fn=list_data_collate,
                            pin_memory=False)

    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose([
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True)
    ])

    if gpu_ids != '-1':
        # try to use all the available GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    # build the network
    if network == 'nnunet':
        net = build_net()
    elif network == 'unetr':
        net = build_UNETR()
    else:
        # Previously an unknown name fell through and crashed later on the
        # unbound ``net`` variable.
        raise ValueError(
            "unsupported network {!r}; expected 'nnunet' or 'unetr'".format(network))

    net = net.to(device)

    if gpu_ids == '-1':
        net.load_state_dict(new_state_dict_cpu(weights))
    else:
        net.load_state_dict(new_state_dict(weights))

    # define sliding window size and batch size for windows inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images = val_data["image"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, net)
            val_outputs = [
                post_trans(i) for i in decollate_batch(val_outputs)
            ]
            if label is not None:
                val_labels = val_data["label"].to(device)
                dice_metric(y_pred=val_outputs, y=val_labels)

        if label is not None:
            metric = dice_metric.aggregate().item()
            print("Evaluation Metric (Dice):", metric)

    result_array = val_outputs[0].squeeze().data.cpu().numpy()
    # Remove the pad if the image was smaller than the patch in some directions
    result_array = result_array[0:resampled_size[0], 0:resampled_size[1],
                                0:resampled_size[2]]

    # resample back to the original resolution
    if resolution is not None:
        result_array_np = np.transpose(result_array, (2, 1, 0))
        result_array_temp = sitk.GetImageFromArray(result_array_np)
        result_array_temp.SetSpacing(resolution)

        temp_path = 'temp_seg.nii'
        try:
            # save temporary label
            writer = sitk.ImageFileWriter()
            writer.SetFileName(temp_path)
            writer.Execute(result_array_temp)

            files = [{"image": temp_path}]
            files_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                Spacingd(keys=['image'],
                         pixdim=original_resolution,
                         mode=('nearest')),
                Resized(keys=['image'],
                        spatial_size=crop_shape,
                        mode=('nearest')),
            ])
            files_ds = Dataset(data=files, transform=files_transforms)
            files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

            for files_data in files_loader:
                files_images = files_data["image"]
                res = files_images.squeeze().data.numpy()

            result_array = np.rint(res)
        finally:
            # Always clean up the scratch file, even if resampling failed.
            if os.path.exists(temp_path):
                os.remove(temp_path)

    # recover the cropped background before saving the image
    empty_array = np.zeros(original_shape)
    empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1],
                coord1[2]:coord2[2]] = result_array

    result_seg = from_numpy_to_itk(empty_array, image)

    # save label
    writer = sitk.ImageFileWriter()
    writer.SetFileName(result)
    writer.Execute(result_seg)
    print("Saved Result at:", str(result))
def _build_train_val_transforms(resolution, patch_size):
    """Return ``(train_steps, val_steps)`` transform lists used by ``main``.

    Args:
        resolution: target voxel spacing; ``None`` disables resampling.
        patch_size: training crop size and minimum padded size.
    """
    keys = ['image', 'label']
    interp = ('bilinear', 'nearest')

    def _preprocessing():
        """Deterministic steps shared by training and validation."""
        steps = [
            LoadImaged(keys=keys),
            AddChanneld(keys=keys),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # CT HU filter
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=keys, source_key='image'),  # crop to the image foreground
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]
        if resolution is not None:
            steps.append(Spacingd(keys=keys, pixdim=resolution, mode=interp))
        return steps

    # NOTE(review): the original used an upper bound of 15 for the noise std
    # when resampling and 1 otherwise; preserved as-is -- confirm intent.
    noise_std_max = 15 if resolution is not None else 1

    # The np.random.uniform(...) arguments below are drawn ONCE, when the
    # pipeline is built, not per sample.
    augmentations = [
        RandFlipd(keys=keys, prob=0.15, spatial_axis=1),
        RandFlipd(keys=keys, prob=0.15, spatial_axis=0),
        RandFlipd(keys=keys, prob=0.15, spatial_axis=2),
        RandAffined(keys=keys, mode=interp, prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=keys, mode=interp, prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=keys, mode=interp, prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=keys, mode=interp, prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15),
                            prob=0.1),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, noise_std_max)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
    ]

    train_steps = _preprocessing() + augmentations + [
        SpatialPadd(keys=keys, spatial_size=patch_size, method='end'),  # pad if the image is smaller than patch
        RandSpatialCropd(keys=keys, roi_size=patch_size, random_size=False),
        ToTensord(keys=keys),
    ]
    val_steps = _preprocessing() + [
        SpatialPadd(keys=keys, spatial_size=patch_size, method='end'),  # pad if the image is smaller than patch
        ToTensord(keys=keys),
    ]
    return train_steps, val_steps


def main():
    """Train a segmentation network and track Dice on train/val/test splits.

    Options come from ``Options().parse()``. After every epoch the Dice score
    is computed on the validation, training and test loaders, the best
    validation model is written to ``best_metric_model.pth`` and scalars and
    images are logged to TensorBoard.

    Raises:
        ValueError: if ``opt.network`` is neither ``'nnunet'`` nor ``'unetr'``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # check gpus
    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # Strings must be compared with ==; ``is`` tests object identity and is
    # not reliable for values parsed at runtime.
    if opt.network not in ('nnunet', 'unetr'):
        raise ValueError(
            "unsupported network {!r}; expected 'nnunet' or 'unetr'".format(opt.network))
    use_nnunet = opt.network == 'nnunet'

    # Data loader creation
    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # Un-augmented copies used to compute the training Dice.
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training: each pass doubles the lists, so the
    # final length is len * 2**increase_factor_data
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader
    def _pair(images, segs):
        """Zip image and label paths into MONAI-style dict entries."""
        return [{'image': image_name, 'label': label_name}
                for image_name, label_name in zip(images, segs)]

    train_dicts = _pair(train_images, train_segs)
    train_dice_dicts = _pair(train_images_for_dice, train_segs_for_dice)
    val_dicts = _pair(val_images, val_segs)
    test_dicts = _pair(test_images, test_segs)

    # Transforms
    train_steps, val_steps = _build_train_val_transforms(opt.resolution, opt.patch_size)
    train_transforms = Compose(train_steps)
    val_transforms = Compose(val_steps)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True,
                              collate_fn=list_data_collate, num_workers=opt.workers, pin_memory=False)

    # create a training_dice data loader
    train_dice_ds = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(train_dice_ds, batch_size=1, num_workers=opt.workers,
                                   collate_fn=list_data_collate, pin_memory=False)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=opt.workers,
                            collate_fn=list_data_collate, pin_memory=False)

    # create a test data loader
    test_ds = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers,
                             collate_fn=list_data_collate, pin_memory=False)

    # build the network
    if use_nnunet:
        net = build_net()
    else:
        net = build_UNETR()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark

    if use_nnunet:
        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99, weight_decay=3e-5, nesterov=True,)
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)
    else:
        # The UNETR branch uses a fixed learning rate and no scheduler.
        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)

    def evaluate_dice(images_loader):
        """Run sliding-window inference over a loader and return the mean Dice.

        Also returns the images, labels and outputs of the LAST batch, which
        are used for TensorBoard visualisation. Must be called under
        ``torch.no_grad()``.
        """
        last_images = None
        last_labels = None
        last_outputs = None
        for data in images_loader:
            last_images, last_labels = data["image"].cuda(), data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            last_outputs = sliding_window_inference(last_images, roi_size, sw_batch_size, net)
            last_outputs = [post_trans(i) for i in decollate_batch(last_outputs)]
            dice_metric(y_pred=last_outputs, y=last_labels)

        # aggregate the final mean dice result
        mean_dice = dice_metric.aggregate().item()
        # reset the status for next validation round
        dice_metric.reset()
        return mean_dice, last_images, last_labels, last_outputs

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    writer = SummaryWriter()
    epoch_len = len(check_train) // train_loader.batch_size

    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if use_nnunet:
            update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_images, val_labels, val_outputs = evaluate_dice(val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, _, _, _ = evaluate_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = evaluate_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")
                plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image")
                plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label")
                plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def main():
    """Train the segmentation network and report Dice on train/val/test data.

    All settings come from ``Options().parse()``.  The function builds the
    dictionary-transform pipelines and data loaders for the three splits,
    trains with a Tversky loss, and after every ``val_interval`` epochs
    evaluates sliding-window Dice on the validation, training and test
    loaders.  The weights with the best validation Dice are saved to
    ``best_metric_model.pth``; scalars and the last validation volume are
    written to TensorBoard.
    """
    opt = Options().parse()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    num_gpus = 0 if opt.gpu_ids == '-1' else len(opt.gpu_ids.split(','))
    print('number of GPU:', num_gpus)

    def list_split(split):
        """Return the sorted image and label paths of one data split."""
        images = sorted(
            glob(os.path.join(opt.images_folder, split, 'image*.nii')))
        segs = sorted(
            glob(os.path.join(opt.labels_folder, split, 'label*.nii')))
        return images, segs

    def to_dicts(images, segs):
        """Pair image and label paths into the dict records the transforms expect."""
        return [{
            'image': image_name,
            'label': label_name
        } for image_name, label_name in zip(images, segs)]

    # Data loader creation
    train_image_paths, train_seg_paths = list_split('train')
    val_image_paths, val_seg_paths = list_split('val')
    test_image_paths, test_seg_paths = list_split('test')

    # Un-augmented copies of the training lists, used to measure training Dice.
    train_image_paths_for_dice = list(train_image_paths)
    train_seg_paths_for_dice = list(train_seg_paths)

    # Augment the data list for training.  Each pass extends the list with
    # itself, so the list length is multiplied by 2 ** increase_factor_data.
    # NOTE(review): confirm exponential (rather than linear) growth is intended.
    for _ in range(int(opt.increase_factor_data)):
        train_image_paths.extend(train_image_paths)
        train_seg_paths.extend(train_seg_paths)

    print('Number of training patches per epoch:', len(train_image_paths))
    print('Number of training images per epoch:',
          len(train_image_paths_for_dice))
    print('Number of validation images per epoch:', len(val_image_paths))
    print('Number of test images per epoch:', len(test_image_paths))

    # Creation of data directories for data_loader
    train_dicts = to_dicts(train_image_paths, train_seg_paths)
    train_dice_dicts = to_dicts(train_image_paths_for_dice,
                                train_seg_paths_for_dice)
    val_dicts = to_dicts(val_image_paths, val_seg_paths)
    test_dicts = to_dicts(test_image_paths, test_seg_paths)

    def preprocessing():
        """Build fresh deterministic pre-processing transforms.

        Resampling with ``Spacingd`` is appended only when a target
        resolution was requested.
        """
        steps = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
        ]
        if opt.resolution is not None:
            steps.append(
                Spacingd(keys=['image', 'label'],
                         pixdim=opt.resolution,
                         mode=('bilinear', 'nearest')))
        return steps

    # Random augmentations, applied to training samples only.
    # NOTE(review): the np.random.uniform(...) arguments below are evaluated
    # once, when the transform objects are constructed, so the noise
    # mean/std and the intensity offset are fixed for the whole run.
    augmentation = [
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
        RandAffined(keys=['image', 'label'],
                    mode=('bilinear', 'nearest'),
                    prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2),
                    padding_mode="zeros"),
        RandAffined(keys=['image', 'label'],
                    mode=('bilinear', 'nearest'),
                    prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36),
                    padding_mode="zeros"),
        RandAffined(keys=['image', 'label'],
                    mode=('bilinear', 'nearest'),
                    prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36),
                    padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'],
                       mode=('bilinear', 'nearest'),
                       prob=0.1,
                       sigma_range=(5, 8),
                       magnitude_range=(100, 200),
                       scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        RandGaussianNoised(keys=['image'],
                           prob=0.1,
                           mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'],
                            offsets=np.random.uniform(0, 0.3),
                            prob=0.1),
        RandSpatialCropd(keys=['image', 'label'],
                         roi_size=opt.patch_size,
                         random_size=False),
    ]

    train_transforms = Compose(preprocessing() + augmentation +
                               [ToTensord(keys=['image', 'label'])])
    val_transforms = Compose(preprocessing() +
                             [ToTensord(keys=['image', 'label'])])

    def evaluation_loader(dicts):
        """Create a batch-size-1 loader that applies the validation transforms."""
        dataset = monai.data.Dataset(data=dicts, transform=val_transforms)
        return DataLoader(dataset,
                          batch_size=1,
                          num_workers=opt.workers,
                          pin_memory=torch.cuda.is_available())

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts,
                                     transform=train_transforms)
    train_loader = DataLoader(check_train,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.workers,
                              pin_memory=torch.cuda.is_available())

    # create the training-dice, validation and test data loaders
    train_dice_loader = evaluation_loader(train_dice_dicts)
    val_loader = evaluation_loader(val_dicts)
    test_loader = evaluation_loader(test_dicts)

    # try to use all the available GPUs
    # The returned device list is not used below; the call is retained in
    # case its checks are relied upon -- TODO confirm and remove if not.
    get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    loss_function = monai.losses.TverskyLoss(sigmoid=True,
                                             alpha=0.3,
                                             beta=0.7)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()

    # Number of optimisation steps per epoch; constant for the whole run.
    epoch_len = len(check_train) // train_loader.batch_size

    def plot_dice(images_loader):
        """Evaluate mean Dice over a loader with sliding-window inference.

        Returns the mean Dice together with the image, label and raw network
        output of the last batch.  For an empty loader the metric is NaN and
        the three tensors are None.  Intended to be called under
        ``torch.no_grad()`` with the network in eval mode.
        """
        metric_sum = 0.0
        metric_count = 0
        last_images = None
        last_labels = None
        last_outputs = None

        for data in images_loader:
            last_images = data["image"].cuda()
            last_labels = data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            last_outputs = sliding_window_inference(last_images, roi_size,
                                                    sw_batch_size, net)
            value = dice_metric(y_pred=last_outputs, y=last_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)

        # An empty split must not abort training with ZeroDivisionError.
        if metric_count:
            metric = metric_sum / metric_count
        else:
            metric = float("nan")
        metric_values.append(metric)
        return metric, last_images, last_labels, last_outputs

    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["image"].cuda()
            labels = batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_images, val_labels, val_outputs = plot_dice(
                    val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, _, _, _ = plot_dice(train_dice_loader)
                metric_test, _, _, _ = plot_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric_train, metric, metric_test,
                            best_metric, best_metric_epoch))

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)

                # plot the last model output as GIF image in TensorBoard with
                # the corresponding image and label (skipped when the
                # validation loader produced no batch)
                if val_outputs is not None:
                    val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                    plot_2d_or_3d_image(val_images,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="validation image")
                    plot_2d_or_3d_image(val_labels,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="validation label")
                    plot_2d_or_3d_image(val_outputs,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="validation inference")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
"2D", 0, CenterSpatialCropd(KEYS, roi_size=95), ) ) TESTS.append( ( "CenterSpatialCropd 3d", "3D", 0, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), ) ) TESTS.append(("CropForegroundd 2d", "2D", 0, CropForegroundd(KEYS, source_key="label", margin=2))) TESTS.append(("CropForegroundd 3d", "3D", 0, CropForegroundd(KEYS, source_key="label"))) TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) TESTS.append( ( "Flipd 3d", "3D", 0, Flipd(KEYS, [1, 2]), ) )
def statistics_crop(image, resolution):
    """Collect shape statistics for one image file.

    Args:
        image: path of the image file; read once with SimpleITK for its
            spacing and then through the MONAI transform pipelines.
        resolution: target ``pixdim`` for ``Spacingd``, or ``None`` to skip
            resampling.

    Returns:
        A tuple ``(original_shape, crop_shape, coord1, coord2,
        resampled_size, original_resolution)`` where ``coord1``/``coord2``
        are the foreground start/end coordinates recorded by
        ``CropForegroundd`` and ``resampled_size`` equals ``original_shape``
        when ``resolution`` is ``None``.
    """
    itk_reader = sitk.ImageFileReader()
    itk_reader.SetFileName(image)
    original_resolution = itk_reader.Execute().GetSpacing()

    def first_batch(steps):
        """Run the file through ``steps`` and return the first collated batch."""
        dataset = monai.data.Dataset(data=[{"image": image}],
                                     transform=Compose(steps))
        batches = DataLoader(dataset,
                             batch_size=1,
                             num_workers=0,
                             pin_memory=torch.cuda.is_available())
        return monai.utils.misc.first(batches)

    # original size
    batch = first_batch([
        LoadImaged(keys=['image']),
        AddChanneld(keys=['image']),
        ToTensord(keys=['image'])
    ])
    # Single-element unpacking strips the leading channel axis of the sample.
    (plain_volume,) = batch['image'][0]
    original_shape = plain_volume.numpy().shape

    # cropped foreground size
    batch = first_batch([
        LoadImaged(keys=['image']),
        AddChanneld(keys=['image']),
        CropForegroundd(
            keys=['image'],
            source_key='image',
            start_coord_key='foreground_start_coord',
            end_coord_key='foreground_end_coord',
        ),
        ToTensord(
            keys=['image', 'foreground_start_coord', 'foreground_end_coord'])
    ])
    cropped_sample = batch['image'][0]
    crop_shape = cropped_sample[0].numpy().shape
    coord1 = batch['foreground_start_coord'][0].numpy()
    coord2 = batch['foreground_end_coord'][0].numpy()

    # size after cropping and resampling to the requested resolution
    if resolution is None:
        resampled_size = original_shape
    else:
        batch = first_batch([
            LoadImaged(keys=['image']),
            AddChanneld(keys=['image']),
            CropForegroundd(keys=['image'], source_key='image'),
            Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')),
            ToTensord(keys=['image'])
        ])
        (resampled_volume,) = batch['image'][0]
        resampled_size = resampled_volume.numpy().shape

    return (original_shape, crop_shape, coord1, coord2, resampled_size,
            original_resolution)
pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest"), ) ), Range()(Orientationd(keys=["image", "label"], axcodes="RAS")), Range()( ScaleIntensityRanged( keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True, ) ), Range()(CropForegroundd(keys=["image", "label"], source_key="image")), # pre-compute foreground and background indexes # and cache them to accelerate training Range("Indexing")( FgBgToIndicesd( keys="label", fg_postfix="_fg", bg_postfix="_bg", image_key="image", ) ), EnsureTyped(keys=["image", "label"]), ToDeviced(keys=["image", "label"], device="cuda:0"), Range("RandCrop")( RandCropByPosNegLabeld( keys=["image", "label"],
def main():
    """Fine-tune UNETR for multi-organ segmentation.

    Paths and hyper-parameters are defined at the top of the function.  The
    function builds the transform chains and cached datasets, optionally
    loads pre-trained ViT backbone weights, then trains until
    ``max_iterations`` is reached, validating every ``eval_num`` steps.  The
    best model (by validation Dice) is saved to
    ``<logdir>/best_metric_model.pth`` and a loss/Dice plot is refreshed in
    ``logdir`` after every validation.
    """
    print_config()

    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')

    # If use_pretrained is set to 0, ViT weights will not be loaded and random
    # initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')

    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100

    # Creates missing parent directories too and tolerates an existing one.
    os.makedirs(logdir, exist_ok=True)

    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-175,
                             a_max=250,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    datalist = load_decathlon_datalist(base_dir=data_dir,
                                       data_list_file_path=json_path,
                                       is_segmentation=True,
                                       data_list_key="training")

    val_files = load_decathlon_datalist(base_dir=data_dir,
                                        data_list_file_path=json_path,
                                        is_segmentation=True,
                                        data_list_key="validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    val_ds = CacheDataset(data=val_files,
                          transform=val_transforms,
                          cache_num=6,
                          cache_rate=1.0,
                          num_workers=4)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    # Sanity check: report the shapes of one validation case.
    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading Weights from the Path {}'.format(pretrained_path))
        vit_dict = torch.load(pretrained_path)
        vit_weights = vit_dict['state_dict']

        # Delete the following variable names conv3d_transpose.weight,
        # conv3d_transpose.bias, conv3d_transpose_1.weight,
        # conv3d_transpose_1.bias as they were used to match dimensions
        # while pretraining with ViTAutoEnc and are not a part of ViT
        # backbone (this is used in UNETR)
        for key in ('conv3d_transpose_1.bias', 'conv3d_transpose_1.weight',
                    'conv3d_transpose.bias', 'conv3d_transpose.weight'):
            vit_weights.pop(key)

        model.vit.load_state_dict(vit_weights)
        print('Pretrained Weights Succesfully Loaded !')

    elif use_pretrained == 0:
        print(
            'No weights were loaded, all weights being used are randomly initialized!'
        )

    model.to(device)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=lr,
                                  weight_decay=1e-5)

    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val, current_step):
        """Return the mean Dice over the validation iterator.

        ``current_step`` is only used for the progress-bar text.  The metric
        buffer is aggregated after every batch for display; the value after
        the last batch is the mean over the whole set and is what is
        returned (NaN when the iterator yields nothing).  The model is left
        in eval mode; the caller switches it back.
        """
        model.eval()
        mean_dice_val = float("nan")
        with torch.no_grad():
            for batch in epoch_iterator_val:
                val_inputs = batch["image"].to(device)
                val_labels = batch["label"].to(device)
                val_outputs = sliding_window_inference(
                    val_inputs, (96, 96, 96), 4, model)
                val_labels_convert = [
                    post_label(val_label_tensor)
                    for val_label_tensor in decollate_batch(val_labels)
                ]
                val_output_convert = [
                    post_pred(val_pred_tensor)
                    for val_pred_tensor in decollate_batch(val_outputs)
                ]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # Running mean over every case seen so far in this pass.
                mean_dice_val = dice_metric.aggregate().item()
                # NOTE(review): the literal 10.0 is kept from the original
                # progress text; its intended meaning is unclear.
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" %
                    (current_step, 10.0, mean_dice_val))
            dice_metric.reset()
        return mean_dice_val

    def train(global_step, train_loader, dice_val_best, global_step_best):
        """Run one pass over ``train_loader``, validating every ``eval_num`` steps.

        Returns the updated ``(global_step, dice_val_best, global_step_best)``.
        """
        model.train()
        epoch_loss = 0
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator, start=1):
            x = batch["image"].to(device)
            y = batch["label"].to(device)
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" %
                (global_step, max_iterations, loss))

            if (global_step % eval_num == 0
                    and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(
                    val_loader,
                    desc="Validate (X / X Steps) (dice=X.X)",
                    dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val, global_step)
                # validation() switched to eval mode; resume training mode.
                model.train()

                # Record the average without mutating the running sum, so
                # later averages in this pass stay correct.
                epoch_loss_values.append(epoch_loss / step)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(),
                               os.path.join(logdir, "best_metric_model.pth"))
                    print(
                        "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))
                else:
                    print(
                        "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))

                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                xs = [
                    eval_num * (i + 1) for i in range(len(epoch_loss_values))
                ]
                plt.xlabel("Iteration")
                plt.plot(xs, epoch_loss_values)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                xs = [eval_num * (i + 1) for i in range(len(metric_values))]
                plt.xlabel("Iteration")
                plt.plot(xs, metric_values)
                plt.grid()
                plt.savefig(
                    os.path.join(logdir, 'btcv_finetune_quick_update.png'))
                plt.clf()
                plt.close(1)

            global_step += 1
        return global_step, dice_val_best, global_step_best

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(
            global_step, train_loader, dice_val_best, global_step_best)
    model.load_state_dict(
        torch.load(os.path.join(logdir, "best_metric_model.pth")))

    print(f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}")
TESTS.append(( "CenterSpatialCropd 2d", "2D", 0, CenterSpatialCropd(KEYS, roi_size=95), )) TESTS.append(( "CenterSpatialCropd 3d", "3D", 0, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), )) TESTS.append(("CropForegroundd 2d", "2D", 0, CropForegroundd(KEYS, source_key="label", margin=2))) TESTS.append( ("CropForegroundd 3d", "3D", 0, CropForegroundd(KEYS, source_key="label"))) TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) TESTS.append(( "Flipd 3d", "3D", 0, Flipd(KEYS, [1, 2]), )) TESTS.append((
def main(hparams):
    """Train a 2D UNet on CT slices with the settings in ``hparams``.

    Args:
        hparams: object exposing ``name``, ``batch_size``, ``patch_size``,
            ``epochs``, ``learning_rate`` and ``loss`` (``'Dice'`` or
            ``'CrossEntropy'``).

    Raises:
        ValueError: if ``hparams.loss`` is not one of the supported names.
    """
    print('===== INITIAL PARAMETERS =====')
    print('Model name: ', hparams.name)
    print('Batch size: ', hparams.batch_size)
    print('Patch size: ', hparams.patch_size)
    print('Epochs: ', hparams.epochs)
    print('Learning rate: ', hparams.learning_rate)
    print('Loss function: ', hparams.loss)
    print()

    ### Data collection
    data_dir = 'data/'
    print('Available directories: ', os.listdir(data_dir))

    # Get paths for images and masks
    images = sorted(glob.glob(data_dir + '**/*CTImg*', recursive=True))
    masks = sorted(glob.glob(data_dir + '**/*Mask*', recursive=True))

    # Dataset selection
    train_dicts = select_animals(images, masks, [12, 13, 14, 18, 20])
    val_dicts = select_animals(images, masks, [25])
    test_dicts = select_animals(images, masks, [27])
    data_keys = ['image', 'mask']

    # Data transformation
    data_transforms = Compose([
        LoadNiftid(keys=data_keys),
        AddChanneld(keys=data_keys),
        ScaleIntensityd(keys=data_keys),
        CropForegroundd(keys=data_keys, source_key='image'),
        RandSpatialCropd(keys=data_keys,
                         roi_size=(hparams.patch_size, hparams.patch_size, 1),
                         random_size=False),
    ])
    train_transforms = Compose([data_transforms, ToTensord(keys=data_keys)])
    val_transforms = Compose([data_transforms, ToTensord(keys=data_keys)])
    test_transforms = Compose([data_transforms, ToTensord(keys=data_keys)])

    # Data loaders
    data_loaders = {
        'train':
        create_loader(train_dicts,
                      batch_size=hparams.batch_size,
                      transforms=train_transforms,
                      shuffle=True),
        'val':
        create_loader(val_dicts, transforms=val_transforms),
        'test':
        create_loader(test_dicts, transforms=test_transforms)
    }
    for key in data_loaders:
        print(key, len(data_loaders[key]))

    ### Model training
    if hparams.loss == 'Dice':
        criterion = monai.losses.DiceLoss(to_onehot_y=True, do_softmax=True)
    elif hparams.loss == 'CrossEntropy':
        criterion = nn.CrossEntropyLoss()
    else:
        # Fail early with a clear message instead of a NameError on
        # `criterion` further down.
        raise ValueError(
            f"Unsupported loss {hparams.loss!r}; "
            "expected 'Dice' or 'CrossEntropy'")

    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=2,
        # NOTE(review): 258 looks like a typo for 256; left unchanged because
        # altering it would change the architecture of saved checkpoints.
        channels=(64, 128, 258, 512, 1024),
        strides=(2, 2, 2, 2),
        norm=monai.networks.layers.Norm.BATCH,
        criterion=criterion,
        hparams=hparams,
    )

    # Currently not passed to the Trainer (see the commented argument below).
    early_stopping = EarlyStopping('val_loss')
    logger = TensorBoardLogger('models/' + hparams.name + '/tb_logs',
                               name=hparams.name)

    trainer = Trainer(
        check_val_every_n_epoch=5,
        default_save_path='models/' + hparams.name + '/checkpoints',
        # early_stop_callback=early_stopping,
        gpus=1,
        max_epochs=hparams.epochs,
        # min_epochs=10,
        logger=logger)

    trainer.fit(model,
                train_dataloader=data_loaders['train'],
                val_dataloaders=data_loaders['val'])
def test_train_timing(self):
    """Fast-training integration test: train a 3D UNet for a few epochs.

    Uses ``CacheDataset`` + ``ThreadDataLoader`` with data moved to the GPU
    inside the transform chain, AMP for training and validation, and asserts
    that the best validation mean Dice exceeds 0.95.
    """
    both = ["image", "label"]
    device = torch.device("cuda:0")

    image_paths = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))

    # first 32 pairs for training, last 9 pairs for validation
    train_files = []
    for img, seg in zip(image_paths[:32], seg_paths[:32]):
        train_files.append({"image": img, "label": seg})
    val_files = []
    for img, seg in zip(image_paths[-9:], seg_paths[-9:]):
        val_files.append({"image": img, "label": seg})

    # define transforms for train and validation
    train_transforms = Compose([
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys=both),
        Spacingd(keys=both,
                 pixdim=(1.0, 1.0, 1.0),
                 mode=("bilinear", "nearest")),
        ScaleIntensityd(keys="image"),
        CropForegroundd(keys=both, source_key="image"),
        # pre-compute foreground and background indexes
        # and cache them to accelerate training
        FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
        # change to execute transforms with Tensor data
        EnsureTyped(keys=both),
        # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
        ToDeviced(keys=both, device=device),
        # randomly crop out patch samples from big image based on
        # pos / neg ratio; the image centers of negative samples
        # must be in valid image area
        RandCropByPosNegLabeld(
            keys=both,
            label_key="label",
            spatial_size=(64, 64, 64),
            pos=1,
            neg=1,
            num_samples=4,
            fg_indices_key="label_fg",
            bg_indices_key="label_bg",
        ),
        RandFlipd(keys=both, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=both, prob=0.5),
        RandRotate90d(keys=both, prob=0.5, spatial_axes=(1, 2)),
        RandZoomd(keys=both,
                  prob=0.5,
                  min_zoom=0.8,
                  max_zoom=1.2,
                  keep_size=True),
        RandRotated(
            keys=both,
            prob=0.5,
            range_x=np.pi / 4,
            mode=("bilinear", "nearest"),
            align_corners=True,
            dtype=np.float64,
        ),
        RandAffined(keys=both,
                    prob=0.5,
                    rotate_range=np.pi / 2,
                    mode=("bilinear", "nearest")),
        RandGaussianNoised(keys="image", prob=0.5),
        RandStdShiftIntensityd(keys="image",
                               prob=0.5,
                               factors=0.05,
                               nonzero=True),
    ])

    val_transforms = Compose([
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys=both),
        Spacingd(keys=both,
                 pixdim=(1.0, 1.0, 1.0),
                 mode=("bilinear", "nearest")),
        ScaleIntensityd(keys="image"),
        CropForegroundd(keys=both, source_key="image"),
        EnsureTyped(keys=both),
        # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
        ToDeviced(keys=both, device=device),
    ])

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch

    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files,
                            transform=train_transforms,
                            cache_rate=1.0,
                            num_workers=8)
    val_ds = CacheDataset(data=val_files,
                          transform=val_transforms,
                          cache_rate=1.0,
                          num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds,
                                    num_workers=0,
                                    batch_size=4,
                                    shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True,
                               softmax=True,
                               squared_pred=True,
                               batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)

    # Novograd paper suggests to use a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)

    # constants of the training / validation loops
    steps_per_epoch = math.ceil(len(train_ds) / train_loader.batch_size)
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            print(f"{step}/{steps_per_epoch}, train_loss: {loss.item():.4f}"
                  f" step time: {(time.time() - step_start):.4f}")
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        raw_outputs = sliding_window_inference(
                            val_data["image"], roi_size, sw_batch_size, model)

                    predictions = [
                        post_pred(item) for item in decollate_batch(raw_outputs)
                    ]
                    references = [
                        post_label(item)
                        for item in decollate_batch(val_data["label"])
                    ]
                    dice_metric(y_pred=predictions, y=references)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                print(
                    f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                )
        print(
            f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
        )

    total_time = time.time() - total_start
    print(
        f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
    )
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
# Saturate label values greater than 1 to 1 label_values = d["label"] d["label"] = (label_values > 0).astype(label_values.dtype) return d train_transforms = Compose([ LoadImaged(keys=["image", "label"]), # FixLabelAffineAndReduceClassesToOne(keys=["image", "label"]), # AddChanneld(keys=["image", "label"]), # Spacingd(keys=["image", "label"], pixdim=(1, 1, 1), mode=("bilinear", "nearest")), # Orientationd(keys=["image", "label"], axcodes="RAS"), # ScaleIntensityd(keys=["image"]), # CropForegroundd(keys=["image", "label"], source_key="image"), # RandCropByPosNegLabeld( keys=["image", "label"], label_key="label", # spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4, # image_key="image", image_threshold=0, ), # ToTensord(keys=["image", "label"]), # ]) val_transforms = Compose([ LoadImaged(keys=["image", "label"]), #
def configure(self):
    """Build the network, data pipelines, evaluator and trainer.

    Calls ``self.set_device()`` and then reads ``self.device``,
    ``self.multi_gpu``, ``self.data_list_file_path``, ``self.ckpt_dir``,
    ``self.amp``, ``self.learning_rate``, ``self.val_interval``,
    ``self.max_epochs`` and ``self.local_rank``.

    On return ``self.eval_engine`` (a ``SupervisedEvaluator``) and
    ``self.train_engine`` (a ``SupervisedTrainer``) are set; nothing is
    returned.
    """
    self.set_device()

    image_label_keys = ("image", "label")
    # Same intensity window for both pipelines; each pipeline still gets
    # its own transform instance built from these arguments.
    intensity_window = dict(
        keys="image",
        a_min=-57,
        a_max=164,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    )

    # ---- network -------------------------------------------------------
    network = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    )
    network = network.to(self.device)
    if self.multi_gpu:
        network = DistributedDataParallel(
            module=network,
            device_ids=[self.device],
            find_unused_parameters=False,
        )

    # ---- training data -------------------------------------------------
    train_steps = [
        LoadImaged(keys=image_label_keys),
        EnsureChannelFirstd(keys=image_label_keys),
        Spacingd(
            keys=image_label_keys,
            pixdim=[1.0, 1.0, 1.0],
            mode=["bilinear", "nearest"],
        ),
        ScaleIntensityRanged(**intensity_window),
        CropForegroundd(keys=image_label_keys, source_key="image"),
        RandCropByPosNegLabeld(
            keys=image_label_keys,
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=image_label_keys),
    ]
    train_transforms = Compose(train_steps)

    train_datalist = load_decathlon_datalist(
        self.data_list_file_path,
        is_segmentation=True,
        data_list_key="training",
    )
    if self.multi_gpu:
        # Each rank keeps only its own shard of the training list.
        shards = partition_dataset(
            data=train_datalist,
            shuffle=True,
            num_partitions=dist.get_world_size(),
            even_divisible=True,
        )
        train_datalist = shards[dist.get_rank()]

    train_ds = CacheDataset(
        data=train_datalist,
        transform=train_transforms,
        cache_num=32,
        cache_rate=1.0,
        num_workers=4,
    )
    train_data_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
    )

    # ---- validation data -----------------------------------------------
    # NOTE(review): unlike the training pipeline there is no Spacingd step
    # here -- confirm that validating at native spacing is intended.
    val_steps = [
        LoadImaged(keys=image_label_keys),
        EnsureChannelFirstd(keys=image_label_keys),
        ScaleIntensityRanged(**intensity_window),
        CropForegroundd(keys=image_label_keys, source_key="image"),
        ToTensord(keys=image_label_keys),
    ]
    val_transforms = Compose(val_steps)

    val_datalist = load_decathlon_datalist(
        self.data_list_file_path,
        is_segmentation=True,
        data_list_key="validation",
    )
    val_ds = CacheDataset(
        data=val_datalist,
        transform=val_transforms,
        cache_num=9,
        cache_rate=0.0,
        num_workers=4,
    )
    val_data_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=4,
    )

    # ---- post-processing shared by evaluator and trainer ----------------
    post_transform = Compose([
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(
            keys=["pred", "label"],
            argmax=[True, False],
            to_onehot=True,
            n_classes=2,
        ),
    ])

    def _pred_and_label(output):
        return (output["pred"], output["label"])

    def _nothing(output):
        return None

    def _loss_only(output):
        return output["loss"]

    # ---- evaluator -----------------------------------------------------
    # metric
    mean_dice = MeanDice(
        include_background=False,
        output_transform=_pred_and_label,
        device=self.device,
    )
    key_val_metric = {"val_mean_dice": mean_dice}

    val_handlers = [
        StatsHandler(output_transform=_nothing),
        CheckpointSaver(
            save_dir=self.ckpt_dir,
            save_dict={"model": network},
            save_key_metric=True,
        ),
        TensorBoardStatsHandler(
            log_dir=self.ckpt_dir,
            output_transform=_nothing,
        ),
    ]
    sliding_window = SlidingWindowInferer(
        roi_size=[160, 160, 160],
        sw_batch_size=4,
        overlap=0.5,
    )
    self.eval_engine = SupervisedEvaluator(
        device=self.device,
        val_data_loader=val_data_loader,
        network=network,
        inferer=sliding_window,
        post_transform=post_transform,
        key_val_metric=key_val_metric,
        val_handlers=val_handlers,
        amp=self.amp,
    )

    # ---- trainer -------------------------------------------------------
    optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=5000,
        gamma=0.1,
    )

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(
            validator=self.eval_engine,
            interval=self.val_interval,
            epoch_level=True,
        ),
        StatsHandler(tag_name="train_loss", output_transform=_loss_only),
        TensorBoardStatsHandler(
            log_dir=self.ckpt_dir,
            tag_name="train_loss",
            output_transform=_loss_only,
        ),
    ]
    self.train_engine = SupervisedTrainer(
        device=self.device,
        max_epochs=self.max_epochs,
        train_data_loader=train_data_loader,
        network=network,
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=self.amp,
    )

    # Ranks other than 0 only log warnings and above.
    if self.local_rank > 0:
        for engine in (self.train_engine, self.eval_engine):
            engine.logger.setLevel(logging.WARNING)
def test_test_time_augmentation(self):
    """Train a tiny 2D UNet briefly, then check TestTimeAugmentation outputs.

    The same random transform chain is used for training and for
    test-time augmentation. The assertions cover the shapes of ``mode``,
    ``mean`` and ``std``, the set of values in ``mode``, the range of
    ``mean``, and the type of ``vvc``.
    """
    # test different input data shape to pad list collate
    image_shape = (20, 40)
    expected_shape = (1,) + image_shape
    data_keys = ["image", "label"]
    n_train_images = 10

    training_items = self.get_data(n_train_images, image_shape)
    held_out_items = self.get_data(1, image_shape)

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    random_affine = RandAffined(
        data_keys,
        prob=1.0,
        spatial_size=(30, 30),
        rotate_range=(np.pi / 3, np.pi / 3),
        translate_range=(3, 3),
        scale_range=((0.8, 1), (0.8, 1)),
        padding_mode="zeros",
        mode=("bilinear", "nearest"),
        as_tensor_output=False,
    )
    transforms = Compose(
        [
            AddChanneld(data_keys),
            random_affine,
            CropForegroundd(data_keys, source_key="image"),
            DivisiblePadd(data_keys, 4),
        ]
    )

    train_ds = CacheDataset(training_items, transforms)
    # output might be different size, so pad so that they match
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        collate_fn=pad_list_data_collate,
    )

    model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2))
    model = model.to(device)
    loss_function = DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    num_epochs = 10
    for _ in trange(num_epochs):
        running_loss = 0
        for batch in train_loader:
            images = batch["image"].to(device)
            targets = batch["label"].to(device)
            optimizer.zero_grad()
            loss = loss_function(model(images), targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Per-epoch mean loss; computed as in the original but not asserted on.
        running_loss /= len(train_loader)

    post_trans = Compose(
        [Activations(sigmoid=True), AsDiscrete(threshold=0.5)]
    )
    tt_aug = TestTimeAugmentation(
        transform=transforms,
        batch_size=5,
        num_workers=0,
        inferrer_fn=model,
        device=device,
        to_tensor=True,
        output_device="cpu",
        post_func=post_trans,
    )
    mode, mean, std, vvc = tt_aug(held_out_items)

    self.assertEqual(mode.shape, expected_shape)
    self.assertEqual(mean.shape, expected_shape)
    self.assertTrue(all(np.unique(mode) == (0, 1)))
    self.assertGreaterEqual(mean.min(), 0.0)
    self.assertLessEqual(mean.max(), 1.0)
    self.assertEqual(std.shape, expected_shape)
    self.assertIsInstance(vvc, float)
def main():
    """Train, validate and test a 3D DenseNet121 on the HACKATHON data.

    Steps, in order:

    1. Configure logging, directories (``setup_directories``) and the
       torch device (``create_device``).
    2. Read ``<data>/HACKATHON/train.txt``, build the data list, split
       large images, create transformed copies and balance the list.
    3. Split the list into train / validation / test (80/20, then 80/20
       of the remainder) with fixed random states.
    4. Build the transform chains, ``PersistentDataset`` objects and
       data loaders.
    5. Build the network, loss and optimizer, then run ``Trainer``
       followed by ``Tester``.

    Takes no arguments and returns nothing.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # Load and randomize images

    # HACKATHON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    # Each line of train.txt is comma separated: the first field is kept as
    # a string and the second is parsed as an int. Blank lines (e.g. a
    # trailing empty line) are skipped instead of raising IndexError.
    train_info_hackathon = []
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        for entry in fp:
            entry = entry.strip()
            if not entry:
                continue
            fields = entry.split(',')
            train_info_hackathon.append((fields[0], int(fields[1])))
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir,
                                               seg_dir,
                                               train_info_hackathon,
                                               dual_output=False)
    _train_data_hackathon = large_image_splitter(_train_data_hackathon,
                                                 dirs["cache"])
    copy_list = transform_and_copy(_train_data_hackathon, dirs['cache'])
    balance_training_data2(_train_data_hackathon, copy_list, seed=72)

    # PSUF data (disabled)
    # psuf_dir = os.path.join(dirs["data"], 'psuf')
    # with open(os.path.join(psuf_dir, "train.txt"), 'r') as fp:
    #     train_info = [entry.strip().split(',') for entry in fp.readlines()]
    # image_dir = os.path.join(psuf_dir, 'images')
    # train_data_psuf = get_data_from_info(image_dir, None, train_info)

    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(
        _train_data_hackathon, test_size=0.2, shuffle=True, random_state=42)
    train_data_hackathon, valid_data_hackathon = train_test_split(
        train_split, test_size=0.2, shuffle=True, random_state=43)

    #balance_training_data(train_data_hackathon, seed=72)
    #balance_training_data(valid_data_hackathon, seed=73)
    #balance_training_data(test_data_hackathon, seed=74)

    # Setup transforms

    # Crop foreground
    crop_foreground = CropForegroundd(keys=["image"],
                                      source_key="image",
                                      margin=(5, 5, 0),
                                      select_fn=lambda x: x != 0)
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    # Random axis flip
    rand_x_flip = RandFlipd(keys=["image"], spatial_axis=0, prob=0.50)
    rand_y_flip = RandFlipd(keys=["image"], spatial_axis=1, prob=0.50)
    rand_z_flip = RandFlipd(keys=["image"], spatial_axis=2, prob=0.50)
    # Rand affine transform
    rand_affine = RandAffined(keys=["image"],
                              prob=0.5,
                              rotate_range=(0, 0, np.pi / 12),
                              shear_range=(0.07, 0.07, 0.0),
                              translate_range=(0, 0, 0),
                              scale_range=(0.07, 0.07, 0.0),
                              padding_mode="zeros")
    # Pad image to have height at least 30
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"],
                     spatial_size=(int(512 * 0.50), int(512 * 0.50), -1),
                     mode="trilinear")
    # Apply Gaussian noise
    rand_gaussian_noise = RandGaussianNoised(keys=["image"],
                                             prob=0.25,
                                             mean=0.0,
                                             std=0.1)

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        rand_x_flip,
        rand_y_flip,
        rand_z_flip,
        rand_affine,
        rand_gaussian_noise,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_valid_transform = Compose([
        common_transform,
        #rand_x_flip,
        #rand_y_flip,
        #rand_z_flip,
        #rand_affine,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_test_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    # Not used below while the PSUF data block above is disabled.
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # Setup data
    #set_determinism(seed=100)
    train_dataset = PersistentDataset(data=train_data_hackathon[:],
                                      transform=hackathon_train_transform,
                                      cache_dir=dirs["persistent"])
    valid_dataset = PersistentDataset(data=valid_data_hackathon[:],
                                      transform=hackathon_valid_transform,
                                      cache_dir=dirs["persistent"])
    test_dataset = PersistentDataset(data=test_data_hackathon[:],
                                     transform=hackathon_test_transform,
                                     cache_dir=dirs["persistent"])
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        #shuffle=True,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            train_data_hackathon,
            callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC,
                                      NumpyPadMode.CONSTANT))
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            valid_data_hackathon,
            callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC,
                                      NumpyPadMode.CONSTANT))
    test_loader = DataLoader(test_dataset,
                             batch_size=4,
                             shuffle=False,
                             pin_memory=using_gpu,
                             num_workers=2,
                             collate_fn=PadListDataCollate(
                                 Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1,
                               out_channels=1).to(device)
    # pos_weight for class imbalance
    _, n, p = calculate_class_imbalance(train_data_hackathon)
    pos_weight = torch.Tensor([n, p]).to(device)
    # NOTE(review): the first positional parameter of BCEWithLogitsLoss is
    # `weight`, not `pos_weight`, so this tensor is applied as `weight`.
    # It also has two elements while the network has a single output
    # channel. Left unchanged because the meaning of n and p cannot be
    # confirmed from here -- TODO confirm the intended weighting.
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight)
    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=1e-4,
                                 weight_decay=0)
    # NOTE(review): this scheduler is built but the Trainer below is given
    # lr_scheduler=None -- confirm whether it should be passed in.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       gamma=0.95,
                                                       last_epoch=-1)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        #Activationsd(keys="pred", softmax=True),
    ])
    validator = Validator(device=device,
                          val_data_loader=valid_loader,
                          network=network,
                          post_transform=valid_post_transforms,
                          amp=using_gpu,
                          non_blocking=using_gpu)

    trainer = Trainer(device=device,
                      out_dir=dirs["out"],
                      out_name="DenseNet121",
                      max_epochs=120,
                      validation_epoch=1,
                      validation_interval=1,
                      train_data_loader=train_loader,
                      network=network,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      lr_scheduler=None,
                      validator=validator,
                      amp=using_gpu,
                      non_blocking=using_gpu)

    # Debug helper (disabled): print per-batch shapes and sizes of the
    # validation loader, then exit.
    # x_max, y_max, z_max, size_max = 0, 0, 0, 0
    # for data in valid_loader:
    #     image = data["image"]
    #     label = data["label"]
    #     print()
    #     print(len(data['image_transforms']))
    #     #print(data['image_transforms'])
    #     print(label)
    #     shape = image.shape
    #     x_max = max(x_max, shape[-3])
    #     y_max = max(y_max, shape[-2])
    #     z_max = max(z_max, shape[-1])
    #     size = int(image.nelement()*image.element_size()/1024/1024)
    #     size_max = max(size_max, size)
    #     print("shape:", shape, "size:", str(size)+"MB")
    #     #multi_slice_viewer(image[0, 0, :, :, :], str(label))
    # print(x_max, y_max, z_max, str(size_max)+"MB")
    # exit()

    # Run trainer
    train_output = trainer.run()

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()