def get_motion_transform(type='motion1'):
    """Build the augmentation pipeline selected by ``type``.

    Supported names:

    * ``'motion1'``: ``RandomMotionFromTimeCourse`` only.
    * ``'elastic1_and_motion1'``: ``RandomElasticDeformation`` followed by
      ``RandomMotionFromTimeCourse``.
    * ``'random_noise_1'``: ``RandomNoise(std=(0.020, 0.2))``.

    Args:
        type: Name of the pipeline to build. The name shadows the builtin
            ``type`` but is kept because callers may pass it by keyword.

    Returns:
        A ``Compose`` holding the transform(s) of the requested pipeline.

    Raises:
        ValueError: If ``type`` is not one of the supported names. The
            previous implementation fell through to ``return transforms``
            with the local never assigned, which surfaced as an opaque
            ``UnboundLocalError``.
    """
    supported = ('motion1', 'elastic1_and_motion1', 'random_noise_1')
    if type not in supported:
        raise ValueError(
            "Unknown transform type {!r}; expected one of: {}".format(
                type, ', '.join(supported)))

    if 'motion1' in type:
        # A first parameter set with (1, 6) / (3, 6) amplitude ranges used to
        # be assigned here and was immediately overwritten by this one, so
        # only these values ever took effect.
        dico_params_mot = {
            "maxDisp": (1, 4),
            "maxRot": (1, 4),
            "noiseBasePars": (5, 20, 0.8),
            "swallowFrequency": (2, 6, 0.5),
            "swallowMagnitude": (3, 4),
            "suddenFrequency": (2, 6, 0.5),
            "suddenMagnitude": (3, 4),
            "verbose": False,
            "keep_original": True,
            "proba_to_augment": 1,
            "preserve_center_pct": 0.1,
            "compare_to_original": True,
            "oversampling_pct": 0,
            "correct_motion": False,
        }

    if 'elastic1' in type:
        dico_elast = {
            'num_control_points': 6,
            'max_displacement': (30, 30, 30),
            'proportion_to_augment': 1,
            'image_interpolation': Interpolation.LINEAR,
        }

    if type == 'motion1':
        transforms = Compose([
            RandomMotionFromTimeCourse(**dico_params_mot),
        ])
    elif type == 'elastic1_and_motion1':
        transforms = Compose([
            RandomElasticDeformation(**dico_elast),
            RandomMotionFromTimeCourse(**dico_params_mot),
        ])
    else:  # 'random_noise_1' (the only remaining supported name)
        transforms = Compose([RandomNoise(std=(0.020, 0.2))])

    return transforms
MapMetricWrapper("L1_map", lambda x, y: torch.abs(x - y), average_method="mean", mask_keys=['mask2']), "L2": MetricWrapper("L2", MSELoss()), #"SSIM": SSIM3D(average_method="mean"), "SSIM_mask": SSIM3D(average_method="mean", mask_keys=["mask", "mask2"]), "SSIM_Wrapped": MetricWrapper("SSIM_wrapped", lambda x, y: functional_ssim(x, y, return_map=False), use_mask=True, mask_key="mask"), "ssim_base": MetricWrapper('SSIM_base', ssim3D) } motion_trsfm = RandomMotionFromTimeCourse(verbose=True, compare_to_original=True, metrics=metrics, oversampling_pct=0.0) dataset.set_transform(motion_trsfm) tf = dataset[0] computed_metrics = tf["T1"]["metrics"] print("Computed metrics: {}".format(computed_metrics)) ov(tf['T1']['data'].squeeze().numpy())
"suddenMagnitude": (3, 4), "verbose": False, "keep_original": True, "proba_to_augment": 1, "preserve_center_pct": 0.1, "keep_original": True, "compare_to_original": True, "oversampling_pct": 0, "correct_motion": True } fipar = pd.read_csv( '/home/romain/QCcnn/mask_mvt_val_cati_T1/ssim_0.6956839561462402_sample00220_suj_cat12_s_S07_3DT1_mvt.csv', header=None) dico_params['fitpars'] = fipar.values t = RandomMotionFromTimeCourse(**dico_params) dataset = ImagesDataset(suj, transform=Compose( (CenterCropOrPad(target_shape=(182, 218, 182)), t))) dataset = ImagesDataset(suj, transform=Compose( (CenterCropOrPad(target_shape=(176, 240, 256)), t))) dataset = ImagesDataset(suj, transform=Compose( (CenterCropOrPad(target_shape=(182, 218, 256)), t))) dataset = ImagesDataset(suj, transform=Compose((t, ))) s = dataset[0]
def get_motion_transform(type='motion1'):
    """Build the augmentation pipeline selected by ``type``.

    Supported names:

    * ``'motion1'``: ``RandomMotionFromTimeCourse`` only (with metrics).
    * ``'elastic1'``: ``RandomElasticDeformation`` only.
    * ``'elastic1_and_motion1'``: ``RandomElasticDeformation`` followed by
      ``RandomMotionFromTimeCourse``.
    * ``'random_noise_1'``: ``RandomNoise(std=(0.020, 0.2))``.
    * ``'AffFFT_random_noise'``: ``RandomAffine`` followed by ``RandomNoise``.

    Args:
        type: Name of the pipeline to build. The name shadows the builtin
            ``type`` but is kept because callers may pass it by keyword.

    Returns:
        A ``Compose`` holding the transform(s) of the requested pipeline.

    Raises:
        ValueError: If ``type`` is not one of the supported names. The
            previous implementation fell through to ``return transforms``
            with the local never assigned, which surfaced as an opaque
            ``UnboundLocalError``.
    """
    supported = (
        'motion1',
        'elastic1',
        'elastic1_and_motion1',
        'random_noise_1',
        'AffFFT_random_noise',
    )
    if type not in supported:
        raise ValueError(
            "Unknown transform type {!r}; expected one of: {}".format(
                type, ', '.join(supported)))

    if 'motion1' in type:
        # Imported lazily so the metric modules are only required when a
        # motion pipeline is actually requested.
        from torchio.metrics import SSIM3D, MetricWrapper, MapMetricWrapper
        from torchio.metrics.old_metrics import th_pearsonr

        # Metrics handed to RandomMotionFromTimeCourse through its ``metrics``
        # argument. ``torch`` and ``ssim3D`` are resolved from module scope.
        metrics = {
            "L1_map": MapMetricWrapper(
                "L1_map", lambda x, y: torch.abs(x - y),
                average_method="mean", mask_keys=['brain']),
            "SSIM_mask": SSIM3D(average_method="mean", mask_keys=["brain"]),
            "NCC": MetricWrapper(
                "NCC_th_brain", lambda x, y: th_pearsonr(x, y),
                use_mask=True, mask_key='brain'),
            "NCC2": MetricWrapper(
                "NCC_th", lambda x, y: th_pearsonr(x, y), use_mask=False),
            "ssim_base": MapMetricWrapper(
                'SSIM_base', lambda x, y: ssim3D(x, y, size_average=True),
                average_method="mean", mask_keys=['brain']),
        }

        # A first parameter set with (1, 6) / (3, 6) amplitude ranges and no
        # "metrics" entry used to be assigned here and was immediately
        # overwritten by this one, so only these values ever took effect.
        dico_params_mot = {
            "maxDisp": (1, 4),
            "maxRot": (1, 4),
            "noiseBasePars": (5, 20, 0.8),
            "swallowFrequency": (2, 6, 0.5),
            "swallowMagnitude": (3, 4),
            "suddenFrequency": (2, 6, 0.5),
            "suddenMagnitude": (3, 4),
            "verbose": False,
            "proba_to_augment": 1,
            "preserve_center_pct": 0.1,
            "compare_to_original": True,
            "metrics": metrics,
            "oversampling_pct": 0,
            "correct_motion": False,
        }

    if 'elastic1' in type:
        dico_elast = {
            'num_control_points': 6,
            'max_displacement': (30, 30, 30),
            'p': 1,
            'image_interpolation': Interpolation.LINEAR,
        }

    if type == 'motion1':
        transforms = Compose([
            RandomMotionFromTimeCourse(**dico_params_mot),
        ])
    elif type == 'elastic1':
        transforms = Compose([
            RandomElasticDeformation(**dico_elast),
        ])
    elif type == 'elastic1_and_motion1':
        transforms = Compose([
            RandomElasticDeformation(**dico_elast),
            RandomMotionFromTimeCourse(**dico_params_mot),
        ])
    elif type == 'random_noise_1':
        transforms = Compose([RandomNoise(std=(0.020, 0.2))])
    else:  # 'AffFFT_random_noise'
        # NOTE(review): the original had two consecutive, identical
        # `if type == 'AffFFT_random_noise'` branches. The first built a
        # RandomAffineFFT(scales=(0.8, 1.2), degrees=10,
        # oversampling_pct=0.2, p=0.75) pipeline, and the second always
        # overwrote it with the RandomAffine pipeline below. The effective
        # behaviour is kept; the second branch was possibly meant to use a
        # different type name -- confirm with the author.
        transforms = Compose([
            RandomAffine(scales=(0.8, 1.2), degrees=10, p=0.75,
                         image_interpolation=Interpolation.NEAREST),
            RandomNoise(std=(0.020, 0.2)),
        ])

    return transforms
# Parameter sweep: for every displacement-shift strategy, sigma and x0 value,
# generate motion parameters with corrupt_data(), build a
# RandomMotionFromTimeCourse transform from them, apply it to the first
# sample of the dataset and record the settings in `extra_info`.
# NOTE(review): the original indentation was lost; the nesting below is
# reconstructed from the syntax, and the loop body may continue past the last
# statement visible here.
out_path = '/data/romain/data_exemple/test2/'
if not os.path.exists(out_path): os.mkdir(out_path)
#plt.ioff()
data_ref, aff = read_image('/data/romain/data_exemple/suj_150423/mT1w_1mm.nii')
res, res_fitpar, extra_info = pd.DataFrame(), pd.DataFrame(), dict()

# Presumably defaults for stepping through the loop body interactively; the
# three names are rebound by the loops below.
disp_str = disp_str_list[0]; s = 2; xx = 100
for disp_str in disp_str_list:
    for s in [2, 20]: #[1, 2, 3, 5, 7, 10, 12 , 15, 20 ] : # [2,4,6] : #[1, 3 , 5 , 8, 10 , 12, 15, 20 , 25 ]:
        for xx in x0:
            dico_params['displacement_shift_strategy'] = disp_str
            # Motion parameters for this (xx, s) combination; amplitude is
            # fixed at 10 (the same constant is stored in extra_info['amp']).
            fp = corrupt_data(xx, sigma=s, method=mvt_type, amplitude=10, mvt_axes=mvt_axes)
            dico_params['fitpars'] = fp
            dico_params['nT'] = fp.shape[1]
            t = RandomMotionFromTimeCourse(**dico_params)
            # Subject types containing 'synth' get the `tlab` transform
            # applied before the motion transform.
            if 'synth' in suj_type:
                dataset = SubjectsDataset(suj, transform= torchio.Compose([tlab, t ]))
            else:
                dataset = SubjectsDataset(suj, transform= t )
            # Indexing the dataset applies the transform, which populates the
            # t.fitpars / t.to_substract attributes read below.
            sample = dataset[0]
            # out_path already ends with '/', so the resulting path contains
            # a double slash.
            fout = out_path + '/{}_{}_{}_s{}_freq{}_{}'.format(suj_type, mvt_axe_str, mvt_type, s, xx, disp_str)
            # NOTE(review): the tile width 200 is hard-coded; presumably it
            # has to equal fp.shape[1] (stored as 'nT' above) -- confirm.
            fit_pars = t.fitpars - np.tile(t.to_substract[..., np.newaxis],(1,200))
            # fig = plt.figure();plt.plot(fit_pars.T);plt.savefig(fout+'.png');plt.close(fig)
            #sample['image'].save(fout+'.nii')
            # Settings of the current iteration. `extra_info` is a single
            # dict overwritten on every pass; no accumulation into `res` or
            # `res_fitpar` is visible in this view.
            extra_info['x0'], extra_info['mvt_type'], extra_info['mvt_axe']= xx, mvt_type, mvt_axe_str
            extra_info['shift_type'], extra_info['sigma'], extra_info['amp'] = disp_str, s, 10
            extra_info['disp'] = np.sum(t.to_substract)