def transform(self):
    if hp.mode == '3d':
        if hp.aug:
            training_transform = Compose([
                # ToCanonical(),
                CropOrPad(hp.crop_or_pad_size, padding_mode='reflect'),
                # RandomMotion(),
                RandomBiasField(),
                ZNormalization(),
                RandomNoise(),
                RandomFlip(axes=(0,)),
                OneOf({
                    RandomAffine(): 0.8,
                    RandomElasticDeformation(): 0.2,
                }),
            ])
        else:
            training_transform = Compose([
                CropOrPad((hp.crop_or_pad_size, hp.crop_or_pad_size,
                           hp.crop_or_pad_size), padding_mode='reflect'),
                ZNormalization(),
            ])
    elif hp.mode == '2d':
        if hp.aug:
            training_transform = Compose([
                CropOrPad(hp.crop_or_pad_size, padding_mode='reflect'),
                # RandomMotion(),
                RandomBiasField(),
                ZNormalization(),
                RandomNoise(),
                RandomFlip(axes=(0,)),
                OneOf({
                    RandomAffine(): 0.8,
                    RandomElasticDeformation(): 0.2,
                }),
            ])
        else:
            training_transform = Compose([
                CropOrPad((hp.crop_or_pad_size, hp.crop_or_pad_size,
                           hp.crop_or_pad_size), padding_mode='reflect'),
                ZNormalization(),
            ])
    else:
        raise Exception('no such kind of mode!')

    return training_transform
def __init__(self, h, w, nb_of_dims, latent_dim, use_coronal, use_sagital, p,
             experiment_name, parallel, model_weights='best_model.pth', thr=.5):
    self.model = PatchModel(h, w, nb_of_dims, latent_dim, p).cuda()
    if parallel:
        self.model = nn.DataParallel(self.model)
    self.model.load_state_dict(torch.load(model_weights))
    self.model.eval()

    self.h = h
    self.w = w
    self.best_t = thr
    self.nb_of_dims = nb_of_dims
    self.use_coronal = use_coronal
    self.use_sagital = use_sagital
    self.experiment_name = experiment_name

    gray_matter_template = nib.load('./data/MNI152_T1_1mm_brain_gray.nii.gz')
    self.gmpm = gray_matter_template.get_fdata() > 0

    t1_landmarks = Path('./data/t1_landmarks.npy')
    landmarks_dict = {'mri': t1_landmarks}
    histogram_transform = HistogramStandardization(landmarks_dict)
    znorm_transform = ZNormalization(masking_method=ZNormalization.mean)
    self.transform = torchio.transforms.Compose(
        [histogram_transform, znorm_transform])
def test_transforms(self):
    landmarks_dict = dict(
        t1=np.linspace(0, 100, 13),
        t2=np.linspace(0, 100, 13),
    )
    random_transforms = (
        RandomFlip(axes=(0, 1, 2), flip_probability=1),
        RandomNoise(),
        RandomBiasField(),
        RandomElasticDeformation(proportion_to_augment=1),
        RandomAffine(),
        RandomMotion(proportion_to_augment=1),
    )
    intensity_transforms = (
        Rescale(),
        ZNormalization(),
        HistogramStandardization(landmarks_dict=landmarks_dict),
    )
    for transform in random_transforms:
        sample = self.get_sample()
        transformed = transform(sample)
    for transform in intensity_transforms:
        sample = self.get_sample()
        transformed = transform(sample)
def transform(self):
    training_transform = Compose([
        ZNormalization(),
    ])
    return training_transform
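# Minimal usage sketch (an assumption, not part of the original class): applying the
# Compose([ZNormalization()]) pipeline above to a single TorchIO subject. The image
# path and the 'source' key are hypothetical placeholders.
import torchio as tio

subject = tio.Subject(source=tio.ScalarImage('example_volume.nii.gz'))
normalized = tio.Compose([tio.ZNormalization()])(subject)  # voxel intensities now have ~zero mean, unit std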
def test_transforms(self):
    landmarks_dict = dict(
        t1=np.linspace(0, 100, 13),
        t2=np.linspace(0, 100, 13),
    )
    transforms = (
        CenterCropOrPad((9, 21, 30)),
        ToCanonical(),
        Resample((1, 1.1, 1.25)),
        RandomFlip(axes=(0, 1, 2), flip_probability=1),
        RandomMotion(proportion_to_augment=1),
        RandomGhosting(proportion_to_augment=1, axes=(0, 1, 2)),
        RandomSpike(),
        RandomNoise(),
        RandomBlur(),
        RandomSwap(patch_size=2, num_iterations=5),
        Lambda(lambda x: 1.5 * x, types_to_apply=INTENSITY),
        RandomBiasField(),
        Rescale((0, 1)),
        ZNormalization(masking_method='label'),
        HistogramStandardization(landmarks_dict=landmarks_dict),
        RandomElasticDeformation(proportion_to_augment=1),
        RandomAffine(),
        Pad((1, 2, 3, 0, 5, 6)),
        Crop((3, 2, 8, 0, 1, 4)),
    )
    transformed = self.get_sample()
    for transform in transforms:
        transformed = transform(transformed)
def training_network(landmarks, dataset, subjects):
    training_transform = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        RandomMotion(),
        HistogramStandardization({'mri': landmarks}),
        RandomBiasField(),
        ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        RandomFlip(axes=(0,)),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    validation_transform = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        HistogramStandardization({'mri': landmarks}),
        ZNormalization(masking_method=ZNormalization.mean),
    ])

    training_split_ratio = 0.9
    num_subjects = len(dataset)
    num_training_subjects = int(training_split_ratio * num_subjects)

    training_subjects = subjects[:num_training_subjects]
    validation_subjects = subjects[num_training_subjects:]

    training_set = tio.SubjectsDataset(training_subjects, transform=training_transform)
    validation_set = tio.SubjectsDataset(validation_subjects, transform=validation_transform)

    print('Training set:', len(training_set), 'subjects')
    print('Validation set:', len(validation_set), 'subjects')
    return training_set, validation_set
def __init__(self, root_dir, img_range=(0, 0)):
    self.root_dir = root_dir
    self.img_range = img_range
    subject_lists = []

    # Check whether a label directory exists
    if self.root_dir[-1] != '/':
        self.root_dir += '/'
    self.is_labeled = os.path.isdir(self.root_dir + LABEL_DIR)

    self.files = [re.findall('[0-9]{4}', filename)[0]
                  for filename in os.listdir(self.root_dir + TRAIN_DIR)]
    self.files = sorted(self.files, key=lambda f: int(f))

    # Store all subjects in the list
    for img_num in range(img_range[0], img_range[1] + 1):
        img_file = os.path.join(self.root_dir, TRAIN_DIR,
                                IMG_PREFIX + self.files[img_num] + EXT)
        label_file = os.path.join(self.root_dir, LABEL_DIR,
                                  LABEL_PREFIX + self.files[img_num] + EXT)
        subject = torchio.Subject(
            torchio.Image('t1', img_file, torchio.INTENSITY),
            torchio.Image('label', label_file, torchio.LABEL),
        )
        subject_lists.append(subject)
        print(img_file)
        print(label_file)

    # Define transforms for data normalization and augmentation
    mtransforms = (
        ZNormalization(),
        # transforms.RandomNoise(std_range=(0, 0.25)),
        # transforms.RandomFlip(axes=(0,)),
    )

    self.subjects = torchio.ImagesDataset(subject_lists,
                                          transform=transforms.Compose(mtransforms))
    self.dataset = torchio.Queue(
        subjects_dataset=self.subjects,
        max_length=2,
        samples_per_volume=675,
        sampler_class=torchio.sampler.ImageSampler,
        patch_size=(240, 240, 3),
        num_workers=4,
        shuffle_subjects=False,
        shuffle_patches=True,
    )
    print("Dataset details\n Images: {}".format(self.img_range[1] - self.img_range[0] + 1))
def get_image_patches(input_img_name, mod_nb, gmpm=None, use_coronal=False,
                      use_sagital=False, input_mask_name=None, augment=True,
                      h=16, w=32, coef=.2, record_results=False, pred_labels=None):
    subject_dict = {
        'mri': torchio.Image(input_img_name, torchio.INTENSITY),
    }

    # torchio normalization
    t1_landmarks = Path(f'./data/t1_landmarks_{mod_nb}.npy')
    landmarks_dict = {'mri': t1_landmarks}
    histogram_transform = HistogramStandardization(landmarks_dict)
    znorm_transform = ZNormalization(masking_method=ZNormalization.mean)
    transform = torchio.transforms.Compose(
        [histogram_transform, znorm_transform])

    subject = torchio.Subject(subject_dict)
    zimage = transform(subject)
    target_np = zimage['mri'].data[0].numpy()

    if input_mask_name is not None:
        mask = nib.load(input_mask_name)
        mask_np = (mask.get_fdata() > 0).astype('float')
    else:
        mask_np = np.zeros_like(target_np)

    all_patches, all_labels, side_mask_np, mid_mask_np = get_patches_and_labels(
        target_np, gmpm, mask_np,
        use_coronal=use_coronal, use_sagital=use_sagital,
        h=h, w=w, coef=coef, augment=augment,
        record_results=record_results, pred_labels=pred_labels)

    if not record_results:
        return all_patches, all_labels
    else:
        return side_mask_np, mid_mask_np
# defining dict for pre-processing - key is the string and the value is the transform object
global_preprocessing_dict = {
    "to_canonical": to_canonical_transform,
    "threshold": threshold_transform,
    "clip": clip_transform,
    "clamp": clip_transform,
    "crop_external_zero_planes": CropExternalZeroplanes,
    "crop": crop_transform,
    "centercrop": centercrop_transform,
    "normalize_by_val": normalize_by_val_transform,
    "normalize_imagenet": normalize_imagenet_transform(),
    "normalize_standardize": normalize_standardize_transform(),
    "normalize_div_by_255": normalize_div_by_255_transform(),
    "normalize": ZNormalization(),
    "normalize_positive": ZNormalization(masking_method=positive_voxel_mask),
    "normalize_nonZero": ZNormalization(masking_method=nonzero_voxel_mask),
    "normalize_nonzero": ZNormalization(masking_method=nonzero_voxel_mask),
    "normalize_nonZero_masked": NonZeroNormalizeOnMaskedRegion(),
    "normalize_nonzero_masked": NonZeroNormalizeOnMaskedRegion(),
    "rgba2rgb": rgba2rgb_transform,
    "rgbatorgb": rgba2rgb_transform,
    "rgba_to_rgb": rgba2rgb_transform,
    "rgb2rgba": rgb2rgba_transform,
    "rgbtorgba": rgb2rgba_transform,
    "rgb_to_rgba": rgb2rgba_transform,
    "histogram_matching": histogram_matching,
    "stain_normalizer": stain_normalizer,
}
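# Hypothetical lookup sketch (not part of the original module): keys such as "normalize"
# map to ready-made transform instances, so a configuration string can be resolved and
# applied directly to a torchio.Subject.
znorm = global_preprocessing_dict["normalize"]  # a ZNormalization() instance
# normalized_subject = znorm(subject)           # 'subject' would be a torchio.Subject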
    return Lambda(function=partial(clip_intensities, min=min, max=max), p=p)


def rotate_90(axis, p=1):
    return Lambda(function=partial(tensor_rotate_90, axis=axis), p=p)


def rotate_180(axis, p=1):
    return Lambda(function=partial(tensor_rotate_180, axis=axis), p=p)


# defining dict for pre-processing - key is the string and the value is the transform object
global_preprocessing_dict = {
    'threshold': threshold_transform,
    'clip': clip_transform,
    'normalize': ZNormalization(),
    'normalize_nonZero': ZNormalization(masking_method=positive_voxel_mask),
    'normalize_nonZero_masked': NonZeroNormalizeOnMaskedRegion(),
    'crop_external_zero_planes': crop_external_zero_planes,
    'normalize_imagenet': normalize_imagenet,
    'normalize_standardize': normalize_standardize,
    'normalize_div_by_255': normalize_div_by_255,
}

# Defining a dictionary for augmentations - key is the string and the value is the augmentation object
global_augs_dict = {
    'affine': affine,
    'elastic': elastic,
    'kspace': mri_artifact,
    'bias': bias,
    'blur': blur,
def test():
    parser = argparse.ArgumentParser(
        description='PyTorch Medical Segmentation Testing')
    parser = parse_training_args(parser)
    args, _ = parser.parse_known_args()
    args = parser.parse_args()

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    from data_function import MedData_test

    os.makedirs(output_dir_test, exist_ok=True)

    if hp.mode == '2d':
        from models.two_d.unet import Unet
        model = Unet(in_channels=hp.in_class, classes=hp.out_class)

        # from models.two_d.miniseg import MiniSeg
        # model = MiniSeg(in_input=hp.in_class, classes=hp.out_class)

        # from models.two_d.fcn import FCN32s as fcn
        # model = fcn(in_class=hp.in_class, n_class=hp.out_class)

        # from models.two_d.segnet import SegNet
        # model = SegNet(input_nbr=hp.in_class, label_nbr=hp.out_class)

        # from models.two_d.deeplab import DeepLabV3
        # model = DeepLabV3(in_class=hp.in_class, class_num=hp.out_class)

        # from models.two_d.unetpp import ResNet34UnetPlus
        # model = ResNet34UnetPlus(num_channels=hp.in_class, num_class=hp.out_class)

        # from models.two_d.pspnet import PSPNet
        # model = PSPNet(in_class=hp.in_class, n_classes=hp.out_class)

    elif hp.mode == '3d':
        from models.three_d.unet3d import UNet
        model = UNet(in_channels=hp.in_class, n_classes=hp.out_class, base_n_filter=2)

        # from models.three_d.fcn3d import FCN_Net
        # model = FCN_Net(in_channels=hp.in_class, n_class=hp.out_class)

        # from models.three_d.highresnet import HighRes3DNet
        # model = HighRes3DNet(in_channels=hp.in_class, out_channels=hp.out_class)

        # from models.three_d.densenet3d import SkipDenseNet3D
        # model = SkipDenseNet3D(in_channels=hp.in_class, classes=hp.out_class)

        # from models.three_d.densevoxelnet3d import DenseVoxelNet
        # model = DenseVoxelNet(in_channels=hp.in_class, classes=hp.out_class)

        # from models.three_d.vnet3d import VNet
        # model = VNet(in_channels=hp.in_class, classes=hp.out_class)

    model = torch.nn.DataParallel(model, device_ids=devicess, output_device=[1])

    print("load model:", args.ckpt)
    print(os.path.join(args.output_dir, args.latest_checkpoint_file))
    ckpt = torch.load(os.path.join(args.output_dir, args.latest_checkpoint_file),
                      map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt["model"])
    model.cuda()

    test_dataset = MedData_test(source_test_dir, label_test_dir)
    znorm = ZNormalization()

    if hp.mode == '3d':
        patch_overlap = hp.patch_overlap
        patch_size = hp.patch_size
    elif hp.mode == '2d':
        patch_overlap = hp.patch_overlap
        patch_size = hp.patch_size

    for i, subj in enumerate(test_dataset.subjects):
        subj = znorm(subj)
        grid_sampler = torchio.inference.GridSampler(
            subj,
            patch_size,
            patch_overlap,
        )
        patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=16)
        aggregator = torchio.inference.GridAggregator(grid_sampler)
        aggregator_1 = torchio.inference.GridAggregator(grid_sampler)

        model.eval()
        with torch.no_grad():
            for patches_batch in tqdm(patch_loader):
                input_tensor = patches_batch['source'][torchio.DATA].to(device)
                locations = patches_batch[torchio.LOCATION]
                if hp.mode == '2d':
                    input_tensor = input_tensor.squeeze(4)
                outputs = model(input_tensor)
                if hp.mode == '2d':
                    outputs = outputs.unsqueeze(4)
                logits = torch.sigmoid(outputs)

                labels = logits.clone()
                labels[labels > 0.5] = 1
                labels[labels <= 0.5] = 0

                aggregator.add_batch(logits, locations)
                aggregator_1.add_batch(labels, locations)

        output_tensor = aggregator.get_output_tensor()
        output_tensor_1 = aggregator_1.get_output_tensor()

        affine = subj['source']['affine']

        if (hp.in_class == 1) and (hp.out_class == 1):
            label_image = torchio.ScalarImage(tensor=output_tensor.numpy(), affine=affine)
            label_image.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_float" + hp.save_arch))  # e.g. f"{i:04d}-result_float.mhd"

            output_image = torchio.ScalarImage(tensor=output_tensor_1.numpy(), affine=affine)
            output_image.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int" + hp.save_arch))
        else:
            output_tensor = output_tensor.unsqueeze(1)
            output_tensor_1 = output_tensor_1.unsqueeze(1)

            output_image_artery_float = torchio.ScalarImage(
                tensor=output_tensor[0].numpy(), affine=affine)
            output_image_artery_float.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_float_artery" + hp.save_arch))  # e.g. f"{i:04d}-result_float_artery.mhd"

            output_image_artery_int = torchio.ScalarImage(
                tensor=output_tensor_1[0].numpy(), affine=affine)
            output_image_artery_int.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int_artery" + hp.save_arch))

            output_image_lung_float = torchio.ScalarImage(
                tensor=output_tensor[1].numpy(), affine=affine)
            output_image_lung_float.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_float_lung" + hp.save_arch))

            output_image_lung_int = torchio.ScalarImage(
                tensor=output_tensor_1[1].numpy(), affine=affine)
            output_image_lung_int.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int_lung" + hp.save_arch))

            output_image_trachea_float = torchio.ScalarImage(
                tensor=output_tensor[2].numpy(), affine=affine)
            output_image_trachea_float.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_float_trachea" + hp.save_arch))

            output_image_trachea_int = torchio.ScalarImage(
                tensor=output_tensor_1[2].numpy(), affine=affine)
            output_image_trachea_int.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int_trachea" + hp.save_arch))

            output_image_vein_float = torchio.ScalarImage(
                tensor=output_tensor[3].numpy(), affine=affine)
            output_image_vein_float.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_float_vein" + hp.save_arch))

            output_image_vein_int = torchio.ScalarImage(
                tensor=output_tensor_1[3].numpy(), affine=affine)
            output_image_vein_int.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int_vein" + hp.save_arch))
def main(): opt = parsing_data() print("[INFO]Reading data") # Dictionary with data parameters for NiftyNet Reader if torch.cuda.is_available(): print('[INFO] GPU available.') device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") else: raise Exception( "[INFO] No GPU found or Wrong gpu id, please run without --cuda") # FOLDERS fold_dir = opt.model_dir fold_dir_model = os.path.join(fold_dir, 'models') if not os.path.exists(fold_dir_model): os.makedirs(fold_dir_model) save_path = os.path.join(fold_dir_model, './CP_{}.pth') output_path = os.path.join(fold_dir, 'output') if not os.path.exists(output_path): os.makedirs(output_path) output_path = os.path.join(output_path, 'output_{}.nii.gz') # LOGGING orig_stdout = sys.stdout if os.path.exists(os.path.join(fold_dir, 'out.txt')): compt = 0 while os.path.exists( os.path.join(fold_dir, 'out_' + str(compt) + '.txt')): compt += 1 f = open(os.path.join(fold_dir, 'out_' + str(compt) + '.txt'), 'w') else: f = open(os.path.join(fold_dir, 'out.txt'), 'w') sys.stdout = f # SPLITS split_path_source = opt.dataset_split_source assert os.path.isfile(split_path_source), 'source file not found' split_path_target = opt.dataset_split_target assert os.path.isfile(split_path_target), 'target file not found' split_path = dict() split_path['source'] = split_path_source split_path['target'] = split_path_target path_file = dict() path_file['source'] = opt.path_source path_file['target'] = opt.path_target list_split = [ 'training', 'validation', ] paths_dict = dict() for domain in ['source', 'target']: df_split = pd.read_csv(split_path[domain], header=None) list_file = dict() for split in list_split: list_file[split] = df_split[df_split[1].isin([split])][0].tolist() paths_dict_domain = {split: [] for split in list_split} for split in list_split: for subject in list_file[split]: subject_data = [] for modality in MODALITIES[domain]: subject_data.append( Image( modality, path_file[domain] + subject + modality + '.nii.gz', torchio.INTENSITY)) if split in ['training', 'validation']: subject_data.append( Image('label', path_file[domain] + subject + 'Label.nii.gz', torchio.LABEL)) #subject_data[] = paths_dict_domain[split].append(Subject(*subject_data)) print(domain, split, len(paths_dict_domain[split])) paths_dict[domain] = paths_dict_domain # PREPROCESSING transform_training = dict() transform_validation = dict() for domain in ['source', 'target']: transform_training[domain] = ( ToCanonical(), ZNormalization(), CenterCropOrPad((144, 192, 48)), RandomAffine(scales=(0.9, 1.1), degrees=10), RandomNoise(std_range=(0, 0.10)), RandomFlip(axes=(0, )), ) transform_training[domain] = Compose(transform_training[domain]) transform_validation[domain] = ( ToCanonical(), ZNormalization(), CenterCropOrPad((144, 192, 48)), ) transform_validation[domain] = Compose(transform_validation[domain]) transform = { 'training': transform_training, 'validation': transform_validation } # MODEL norm_op_kwargs = {'eps': 1e-5, 'affine': True} dropout_op_kwargs = {'p': 0, 'inplace': True} net_nonlin = nn.LeakyReLU net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} print("[INFO] Building model") model = Generic_UNet(input_modalities=MODALITIES_TARGET, base_num_features=32, num_classes=nb_classes, num_pool=4, num_conv_per_stage=2, feat_map_mul_on_downscale=2, conv_op=torch.nn.Conv3d, norm_op=torch.nn.InstanceNorm3d, norm_op_kwargs=norm_op_kwargs, nonlin=net_nonlin, nonlin_kwargs=net_nonlin_kwargs, convolutional_pooling=False, convolutional_upsampling=False, 
                         final_nonlin=torch.nn.Softmax(1))

    print("[INFO] Training")
    train(paths_dict, model, transform, device, save_path, opt)

    sys.stdout = orig_stdout
    f.close()
def transform(self):
    test_transform = Compose([
        ZNormalization(),
    ])
    return test_transform
        if len(label.shape) <= 3:
            label = np.expand_dims(label, axis=0)

        self.data_list.append(data)
        self.label_list.append(label)


if __name__ == "__main__":
    slice_n = 99
    spatial = RandomAffine(scales=1,
                           degrees=3,
                           isotropic=False,
                           default_pad_value='otsu',
                           image_interpolation='bspline')
    tmp_transform = Compose([spatial, ZNormalization()])

    dataset = HaNOarsDataset(f'./data/{"HaN_OAR"}_shrink{2}x_padded160', 10)
    dataset.filter_labels([
        OARS_LABELS.EYE_L, OARS_LABELS.EYE_R,
        OARS_LABELS.LENS_L, OARS_LABELS.LENS_R
    ], False)
    dataset.to_numpy()
    tmp_data, tmp_label = dataset[0]

    # tmp_label2 = dataset.label_list[0]
    # unique, counts = np.unique(tmp_label[tmp_label > 0], return_counts=True)
    # print(np.asarray((unique, counts)).T)
    # unique, counts = np.unique(tmp_label2[tmp_label2 > 0], return_counts=True)
    # print(np.asarray((unique, counts)).T)
def compose_transforms() -> Compose:
    print(f"{ctime()}: Setting up transformations...")
    """
    Our preprocessing options available in TorchIO are:

    * Intensity
        - NormalizationTransform
        - RescaleIntensity
        - ZNormalization
        - HistogramStandardization
    * Spatial
        - CropOrPad
        - Crop
        - Pad
        - Resample
        - ToCanonical

    We should read and experiment with these, but for now will just use a bunch
    with the default values.
    """
    preprocessors = [
        ToCanonical(p=1),
        ZNormalization(masking_method=None, p=1),  # alternately, use RescaleIntensity
    ]
    """
    Our augmentation options available in TorchIO are:

    * Spatial
        - RandomFlip
        - RandomAffine
        - RandomElasticDeformation
    * Intensity
        - RandomMotion
        - RandomGhosting
        - RandomSpike
        - RandomBiasField
        - RandomBlur
        - RandomNoise
        - RandomSwap

    We should read and experiment with these, but for now will just use a bunch
    with the default values.
    """
    augments = [
        RandomFlip(axes=(0, 1, 2), flip_probability=0.5),
        RandomAffine(image_interpolation="linear", p=0.8),  # default; compromise on speed + quality
        # the most processing-intensive option; leave out for now, see results
        # RandomElasticDeformation(p=1),
        RandomMotion(),
        RandomSpike(),
        RandomBiasField(),
        RandomBlur(),
        RandomNoise(),
    ]
    transform = Compose(preprocessors + augments)
    print(f"{ctime()}: Transformations registered.")
    return transform
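# Usage sketch (an assumption, not from the original file; it presumes a TorchIO version
# that provides SubjectsDataset and ScalarImage): the composed transform can be handed to
# a dataset so it is applied whenever a subject is loaded. The subject key and image path
# below are hypothetical placeholders.
import torchio as tio

subjects = [tio.Subject(img=tio.ScalarImage('sub-01_T1w.nii.gz'))]
dataset = tio.SubjectsDataset(subjects, transform=compose_transforms())
sample = dataset[0]  # preprocessing and augmentation run here, on access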
fig, ax = plt.subplots(dpi=100)
plot_histogram(ax, znormed.mri.data, label='Z-normed', alpha=1)
ax.set_title('Intensity values of one sample after z-normalization')
ax.set_xlabel('Intensity')
ax.grid()

training_transform = Compose([
    ToCanonical(),
    # Resample(4),
    CropOrPad((112, 112, 48), padding_mode=0),  # 'reflect'; original size 112, 112, 48
    RandomMotion(num_transforms=6, image_interpolation='nearest', p=0.2),
    HistogramStandardization({'mri': landmarks}),
    RandomBiasField(p=0.2),
    RandomBlur(p=0.2),
    ZNormalization(masking_method=ZNormalization.mean),
    RandomFlip(axes=['inferior-superior'], flip_probability=0.2),
    # RandomNoise(std=0.5, p=0.2),
    RandomGhosting(intensity=1.8, p=0.2),
    # RandomNoise(),
    # RandomFlip(axes=(0,)),
    # OneOf({
    #     RandomAffine(): 0.8,
    #     RandomElasticDeformation(): 0.2,
    # }),
])

validation_transform = Compose([
    ToCanonical(),
    # Resample(4),
    CropOrPad((112, 112, 48), padding_mode=0),  # original size 112, 112, 48
def test_z_normalization(self):
    transform = ZNormalization()
    transformed = transform(self.sample_subject)
    self.assertAlmostEqual(float(transformed.t1.data.mean()), 0., places=6)
    self.assertAlmostEqual(float(transformed.t1.data.std()), 1.)
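# Reference sketch (an assumption, not part of the original test suite): with no masking
# method, ZNormalization behaves like subtracting the image mean and dividing by the image
# standard deviation, which is what the assertions above check.
import torch

data = torch.rand(1, 8, 8, 8)                  # hypothetical (C, W, H, D) image tensor
znormed = (data - data.mean()) / data.std()
assert abs(float(znormed.mean())) < 1e-5       # approximately zero mean
assert abs(float(znormed.std()) - 1.0) < 1e-5  # unit standard deviation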
def main(): opt = parsing_data() print("[INFO] Reading data.") # Dictionary with data parameters for NiftyNet Reader if torch.cuda.is_available(): print('[INFO] GPU available.') device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") else: raise Exception( "[INFO] No GPU found or Wrong gpu id, please run without --cuda") # FOLDERS fold_dir = opt.model_dir checkpoint_path = os.path.join(fold_dir, 'models', './CP_{}.pth') checkpoint_path = checkpoint_path.format(opt.epoch_infe) assert os.path.isfile(checkpoint_path), 'no checkpoint found' output_path = opt.output_dir if not os.path.exists(output_path): os.makedirs(output_path) output_path = os.path.join(output_path, 'output_{}.nii.gz') # SPLITS split_path = opt.dataset_split assert os.path.isfile(split_path), 'split file not found' print('Split file found: {}'.format(split_path)) # Reading csv file df_split = pd.read_csv(split_path, header=None) list_file = dict() list_split = ['inference', 'validation'] for split in list_split: list_file[split] = df_split[df_split[1].isin([split.lower() ])][0].tolist() # filing paths add_name = '_sym' if opt.add_sym else '' paths_dict = {split: [] for split in list_split} for split in list_split: for subject in list_file[split]: subject_data = [] for modality in MODALITIES: subject_modality = opt.path_file + subject + modality + add_name + '.nii.gz' if os.path.isfile(subject_modality): subject_data.append( Image(modality, subject_modality, torchio.INTENSITY)) if len(subject_data) > 0: paths_dict[split].append(Subject(*subject_data)) transform_inference = ( ToCanonical(), ZNormalization(), ) transform_inference = Compose(transform_inference) # MODEL norm_op_kwargs = {'eps': 1e-5, 'affine': True} dropout_op_kwargs = {'p': 0, 'inplace': True} net_nonlin = nn.LeakyReLU net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} print("[INFO] Building model.") model = Generic_UNet(input_modalities=['T1', 'all'], base_num_features=32, num_classes=opt.nb_classes, num_pool=4, num_conv_per_stage=2, feat_map_mul_on_downscale=2, conv_op=torch.nn.Conv3d, norm_op=torch.nn.InstanceNorm3d, norm_op_kwargs=norm_op_kwargs, nonlin=net_nonlin, nonlin_kwargs=net_nonlin_kwargs, convolutional_pooling=False, convolutional_upsampling=False, final_nonlin=lambda x: x, input_features={ 'T1': 1, 'all': 4 }) paths_inf = paths_dict['inference'] + paths_dict['validation'] inference_padding(paths_inf, model, transform_inference, device, output_path, checkpoint_path, opt)
    RandomAffine,
)

# Mock PyTorch model
model = lambda x: x

# Define training and patches sampling parameters
num_epochs = 4
patch_size = 128
queue_length = 100
samples_per_volume = 1
batch_size = 2

# Define transforms for data normalization and augmentation
transforms = (
    ZNormalization(),
    RandomAffine(scales=(0.9, 1.1), degrees=10),
    RandomNoise(std_range=(0, 0.25)),
    RandomFlip(axes=(0,)),
)
transform = Compose(transforms)

# Populate a list with dictionaries of paths
one_subject_dict = {
    'T1': dict(path='../BRATS2018_crop_renamed/LGG75_T1.nii.gz',
               type=torchio.INTENSITY),
    'T2': dict(path='../BRATS2018_crop_renamed/LGG75_T2.nii.gz',
               type=torchio.INTENSITY),
    'label':
def main(): opt = parsing_data() print("[INFO] Reading data") # Dictionary with data parameters for NiftyNet Reader if torch.cuda.is_available(): print('[INFO] GPU available.') device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") else: raise Exception( "[INFO] No GPU found or Wrong gpu id, please run without --cuda") # FOLDERS fold_dir = opt.model_dir fold_dir_model = os.path.join(fold_dir, 'models') if not os.path.exists(fold_dir_model): os.makedirs(fold_dir_model) save_path = os.path.join(fold_dir_model, './CP_{}.pth') output_path = os.path.join(fold_dir, 'output') if not os.path.exists(output_path): os.makedirs(output_path) output_path = os.path.join(output_path, 'output_{}.nii.gz') # LOGGING orig_stdout = sys.stdout if os.path.exists(os.path.join(fold_dir, 'out.txt')): compt = 0 while os.path.exists( os.path.join(fold_dir, 'out_' + str(compt) + '.txt')): compt += 1 f = open(os.path.join(fold_dir, 'out_' + str(compt) + '.txt'), 'w') else: f = open(os.path.join(fold_dir, 'out.txt'), 'w') #sys.stdout = f print("[INFO] Hyperparameters") print('Alpha: {}'.format(opt.alpha)) print('Beta: {}'.format(opt.beta)) print('Beta_DA: {}'.format(opt.beta_da)) print('Weight Reg: {}'.format(opt.weight_crf)) # SPLITS split_path_source = opt.dataset_split_source assert os.path.isfile(split_path_source), 'source file not found' split_path_target = opt.dataset_split_target assert os.path.isfile(split_path_target), 'target file not found' split_path = dict() split_path['source'] = split_path_source split_path['target'] = split_path_target path_file = dict() path_file['source'] = opt.path_source path_file['target'] = opt.path_target list_split = ['training', 'validation', 'inference'] paths_dict = dict() for domain in ['source', 'target']: df_split = pd.read_csv(split_path[domain], header=None) list_file = dict() for split in list_split: list_file[split] = df_split[df_split[1].isin([split])][0].tolist() list_file['inference'] += list_file['validation'] paths_dict_domain = {split: [] for split in list_split} for split in list_split: for subject in list_file[split]: subject_data = [] for modality in MODALITIES[domain]: subject_data.append( Image( modality, path_file[domain] + subject + modality + '.nii.gz', torchio.INTENSITY)) if split in ['training', 'validation']: if domain == 'source': subject_data.append( Image( 'label', path_file[domain] + subject + 't1_seg.nii.gz', torchio.LABEL)) else: subject_data.append( Image( 'scribble', path_file[domain] + subject + 't2scribble_cor.nii.gz', torchio.LABEL)) #subject_data[] = paths_dict_domain[split].append(Subject(*subject_data)) print(domain, split, len(paths_dict_domain[split])) paths_dict[domain] = paths_dict_domain # PREPROCESSING transform_training = dict() transform_validation = dict() for domain in ['source', 'target']: transformations = ( ToCanonical(), ZNormalization(), CenterCropOrPad((288, 128, 48)), RandomAffine(scales=(0.9, 1.1), degrees=10), RandomNoise(std_range=(0, 0.10)), RandomFlip(axes=(0, )), ) transform_training[domain] = Compose(transformations) for domain in ['source', 'target']: transformations = (ToCanonical(), ZNormalization(), CenterCropOrPad((288, 128, 48))) transform_validation[domain] = Compose(transformations) transform = { 'training': transform_training, 'validation': transform_validation } # MODEL norm_op_kwargs = {'eps': 1e-5, 'affine': True} net_nonlin = nn.LeakyReLU net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} print("[INFO] Building model") model = UNet2D5(input_channels=1, base_num_features=16, 
                    num_classes=NB_CLASSES,
                    num_pool=4,
                    conv_op=nn.Conv3d,
                    norm_op=nn.InstanceNorm3d,
                    norm_op_kwargs=norm_op_kwargs,
                    nonlin=net_nonlin,
                    nonlin_kwargs=net_nonlin_kwargs)

    print("[INFO] Training")
    # criterion = DC_and_CE_loss({}, {})
    criterion = DC_CE(NB_CLASSES)
    train(paths_dict, model, transform, criterion, device, save_path, opt)