コード例 #1
0
 def test_save_sample(self):
     """Round-trip a sample through ``save_sample`` and compare dimensionality.

     The tensor held in memory must have exactly one more dimension than the
     image that ``save_sample`` writes to disk.
     """
     dataset = ImagesDataset(self.subjects_list, transform=lambda x: x)
     len(dataset)  # exercised only for coverage
     first_sample = dataset[0]
     saved_path = self.dir / 'test.nii.gz'
     dataset.save_sample(first_sample, {'t1': saved_path})
     written = nib.load(str(saved_path))
     written_ndim = len(written.shape)
     sample_ndim = len(first_sample['t1'][DATA].shape)
     assert sample_ndim == written_ndim + 1
コード例 #2
0
 def test_no_load_transform(self):
     """Passing a transform together with ``load_image_data=False`` must raise ``ValueError``."""
     with self.assertRaises(ValueError):
         # The constructor itself is expected to raise, so its result is
         # intentionally not bound to a name.
         ImagesDataset(
             self.subjects_list,
             load_image_data=False,
             transform=lambda x: x,
         )
コード例 #3
0
    def set_data_loader_from_file_list(self,
                                       fin,
                                       transforms=None,
                                       mask_key=None,
                                       mask_regex=None,
                                       batch_size=1,
                                       num_workers=0,
                                       shuffel=True):
        """Build a training loader from a file list and reuse it for validation.

        Args:
            fin: File list forwarded to ``get_subject_list_from_file_list``.
            transforms: ``None``, a torchvision ``Compose`` (used as is), or
                anything else, which is passed to ``Compose``.
            mask_key: Forwarded to ``get_subject_list_from_file_list``.
            mask_regex: Forwarded to ``get_subject_list_from_file_list``.
            batch_size: Batch size of the ``DataLoader``.
            num_workers: Number of ``DataLoader`` worker processes.
            shuffel: Whether the loader shuffles samples (spelling kept for
                caller compatibility).

        Sets ``self.train_dataset``, ``self.train_dataloader`` and
        ``self.val_dataloader``; the last two are the same object.
        """
        subjects = get_subject_list_from_file_list(fin,
                                                   mask_regex=mask_regex,
                                                   mask_key=mask_key)

        already_composed = isinstance(
            transforms, torchvision.transforms.transforms.Compose)
        if transforms is not None and not already_composed:
            transforms = Compose(transforms)

        self.train_dataset = ImagesDataset(subjects, transform=transforms)
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=batch_size,
                                           shuffle=shuffel,
                                           num_workers=num_workers)
        # No separate validation split: validation reuses the training loader.
        self.val_dataloader = self.train_dataloader
コード例 #4
0
ファイル: utils.py プロジェクト: animesh/torchio
    def setUp(self):
        """Set up test fixtures: temp dir, fixed RNG seeds and four subjects.

        The subjects hold different image combinations: ``t1`` only;
        ``t1`` + ``label``; ``label`` only; and ``t1`` (with a
        ``pre_affine``) + ``t2`` + ``label``. ``self.sample`` is the last
        subject loaded through ``ImagesDataset``.
        """
        self.dir = Path(tempfile.gettempdir()).joinpath('.torchio_tests')
        self.dir.mkdir(exist_ok=True)
        # Seed Python's and NumPy's global RNGs for reproducible tests.
        random.seed(42)
        np.random.seed(42)

        # Homogeneous 4x4 matrix used as ``pre_affine`` for the last T1:
        # translation of 10 along the first axis, scale 1.2 on the third.
        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1],
        ])

        # List elements are evaluated left to right, so the images are
        # generated in a fixed order (relevant if ``get_image_path`` draws
        # from the seeded RNG).
        self.subjects_list = [
            Subject(t1=Image(self.get_image_path('t1_a'), INTENSITY)),
            Subject(
                t1=Image(self.get_image_path('t1_b'), INTENSITY),
                label=Image(self.get_image_path('label_b', binary=True),
                            LABEL),
            ),
            Subject(
                label=Image(self.get_image_path('label_c', binary=True),
                            LABEL),
            ),
            Subject(
                t1=Image(self.get_image_path('t1_d'), INTENSITY,
                         pre_affine=registration_matrix),
                t2=Image(self.get_image_path('t2_d'), INTENSITY),
                label=Image(self.get_image_path('label_d', binary=True),
                            LABEL),
            ),
        ]
        self.dataset = ImagesDataset(self.subjects_list)
        # Only the last subject contains every image key (t1, t2, label).
        self.sample = self.dataset[-1]
コード例 #5
0
ファイル: utils.py プロジェクト: zhangjh705/torchio
    def setUp(self):
        """Set up test fixtures: temp dir, fixed RNG seeds and four subjects.

        The subjects hold different image combinations: ``t1`` only;
        ``t1`` + ``label``; ``label`` only; and ``t1`` + ``t2`` + ``label``.
        ``self.sample`` is the last subject loaded through ``ImagesDataset``.
        """
        self.dir = Path(tempfile.gettempdir()).joinpath('.torchio_tests')
        self.dir.mkdir(exist_ok=True)
        # Seed Python's and NumPy's global RNGs for reproducible tests.
        random.seed(42)
        np.random.seed(42)

        # List elements are evaluated left to right, so the images are
        # generated in a fixed order (relevant if ``get_image_path`` draws
        # from the seeded RNG).
        self.subjects_list = [
            Subject(Image('t1', self.get_image_path('t1_a'), INTENSITY)),
            Subject(
                Image('t1', self.get_image_path('t1_b'), INTENSITY),
                Image('label', self.get_image_path('label_b', binary=True),
                      LABEL),
            ),
            Subject(
                Image('label', self.get_image_path('label_c', binary=True),
                      LABEL),
            ),
            Subject(
                Image('t1', self.get_image_path('t1_d'), INTENSITY),
                Image('t2', self.get_image_path('t2_d'), INTENSITY),
                Image('label', self.get_image_path('label_d', binary=True),
                      LABEL),
            ),
        ]
        self.dataset = ImagesDataset(self.subjects_list)
        # Only the last subject contains every image key (t1, t2, label).
        self.sample = self.dataset[-1]
コード例 #6
0
def inference_padding(paths_dict, model, transformation, device, pred_path,
                      cp_path, opt):
    """Run whole-volume inference on every subject and write label maps.

    Each subject is center-cropped/padded to ``(144, 192, 48)`` and segmented
    in a single forward pass. The argmax labels are then resampled (nearest
    neighbour) back onto the original T1 grid and written to
    ``pred_path.format(name)``.

    Args:
        paths_dict: Subjects passed to ``ImagesDataset``.
        model: Network called with a dict of input tensors; returns
            ``(logits, _)``.
        transformation: Transform applied by the dataset (may be ``None``).
        device: Device that the model and its inputs are moved to.
        pred_path: Format string with one ``{}`` placeholder for the name.
        cp_path: Path to the ``state_dict`` checkpoint.
        opt: Options object; ``opt.modalities`` lists the keys counted as
            input modalities.

    Raises:
        AssertionError: if a subject has neither 1 nor 2 of ``opt.modalities``.
    """
    print("[INFO] Loading model.")
    # map_location lets a checkpoint saved on GPU load onto ``device`` even
    # when that device is the CPU.
    model.load_state_dict(torch.load(cp_path, map_location=device))
    model.to(device)
    model.eval()

    subjects_dataset_inf = ImagesDataset(paths_dict, transform=transformation)
    nb_subject = len(subjects_dataset_inf)

    print("[INFO] Starting Inference.")
    print('Number of subjects to infer: {}'.format(nb_subject))

    crop_or_pad = CenterCropOrPad((144, 192, 48))

    # ``start=1`` so the progress message reads 1/N ... N/N.
    for index, batch in enumerate(subjects_dataset_inf, start=1):
        name = batch['T1']['stem'].split('T1')[0]
        affine = batch['T1']['affine']
        # Original T1 geometry, used to undo the crop/pad after prediction.
        reference = torchio.utils.nib_to_sitk(batch['T1'][DATA].numpy(),
                                              affine)

        batch_pad = crop_or_pad(batch)
        affine_pad = batch_pad['T1']['affine']

        nb_modalities = sum(1 for key in batch_pad.keys()
                            if key in opt.modalities)
        # ``name`` already has everything from 'T1' onward stripped.
        assert nb_modalities in (1, 2), \
            'Incorrect number of modalities for {}'.format(name)

        data = {'T1': batch_pad['T1'][DATA].unsqueeze(0).to(device)}
        if nb_modalities == 2:
            # T1 and FLAIR concatenated along dim 0, then a leading batch
            # dimension is added.
            data['FLAIR'] = torch.cat(
                [batch_pad['T1'][DATA], batch_pad['FLAIR'][DATA]],
                0).unsqueeze(0).to(device)

        with torch.no_grad():
            logits, _ = model(data)
            labels = logits.argmax(dim=1, keepdim=True)
            labels = labels[0, 0, ...].cpu().numpy()

        output = torchio.utils.nib_to_sitk(labels.astype(float), affine_pad)
        output = sitk.Resample(
            output,
            reference,
            sitk.Transform(),
            sitk.sitkNearestNeighbor,
        )
        sitk.WriteImage(output, pred_path.format(name))
        print('{}/{} - Inference done for {} with {} modalities'.format(
            index, nb_subject, name, nb_modalities))
コード例 #7
0
 def test_data_loader(self):
     """Iterate a ``DataLoader`` over Colin27 and access both image tensors."""
     from torch.utils.data import DataLoader
     dataset = ImagesDataset([torchio.datasets.Colin27()])
     for batch in DataLoader(dataset, batch_size=1, shuffle=True):
         # Lookups only: a missing 't1' or 'brain' entry fails the test.
         _ = batch['t1'][DATA]
         _ = batch['brain'][DATA]
コード例 #8
0
def main():
    """Demonstrate patch loading when subjects have different modalities.

    Two subjects (one with T1, T2 and label; one without T2) feed a ``Queue``
    of patches. Samples can have different keys, so an identity
    ``collate_fn`` is used and every batch is a plain list of sample dicts.
    """
    # Define training and patches sampling parameters
    num_epochs = 20
    patch_size = 128
    queue_length = 100
    samples_per_volume = 5
    batch_size = 2

    # Populate a list with images
    one_subject = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG75_T1.nii.gz',
                 torchio.INTENSITY),
        T2=Image('../BRATS2018_crop_renamed/LGG75_T2.nii.gz',
                 torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG75_Label.nii.gz',
                    torchio.LABEL),
    )

    # This subject doesn't have a T2 MRI!
    another_subject = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG74_T1.nii.gz',
                 torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG74_Label.nii.gz',
                    torchio.LABEL),
    )

    subjects = [
        one_subject,
        another_subject,
    ]

    subjects_dataset = ImagesDataset(subjects)
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        patch_size,
        ImageSampler,
    )

    # This collate_fn is needed in the case of missing modalities.
    # In this case, the batch will be composed of a *list* of samples instead
    # of the typical Python dictionary that is collated by default in PyTorch.
    batch_loader = DataLoader(
        queue_dataset,
        batch_size=batch_size,
        collate_fn=lambda x: x,
    )

    # Mock PyTorch model
    model = nn.Identity()

    for _ in range(num_epochs):
        for batch in batch_loader:  # batch is a *list* here, not a dictionary
            model(batch)  # output unused in this demo
            # Iterate the batch itself: the last batch of an epoch may hold
            # fewer than ``batch_size`` samples.
            print([sample.keys() for sample in batch])
    print()
コード例 #9
0
ファイル: utils.py プロジェクト: zhangjh705/torchio
 def get_inconsistent_sample(self):
     """Return a sample containing images of different shape.

     The T2 image is generated with shape ``(10, 20, 31)``, while T1 and the
     label use ``get_image_path``'s default shape.
     """
     subject = Subject(
         Image('t1', self.get_image_path('t1_d'), INTENSITY),
         Image('t2', self.get_image_path('t2_d', shape=(10, 20, 31)),
               INTENSITY),
         Image('label', self.get_image_path('label_d', binary=True), LABEL),
     )
     subjects_list = [subject]
     dataset = ImagesDataset(subjects_list)
     return dataset[0]
コード例 #10
0
def TorchIODataLoader3DTraining(config: SemSegConfig) -> torch.utils.data.DataLoader:
    """Build the shuffled TorchIO training loader described by ``config``.

    Image and label paths are paired positionally from ``config.train_images``
    and ``config.train_labels`` (``zip`` stops at the shorter of the two).
    Each pair becomes a ``Subject`` with a ``t1`` intensity image and a
    ``label`` image.

    Args:
        config: Provides ``train_images``, ``train_labels``,
            ``transform_train``, ``batch_size`` and ``num_workers``.

    Returns:
        A shuffling ``DataLoader`` over an ``ImagesDataset`` of the subjects.
    """
    print('Building TorchIO Training Set Loader...')
    subject_list = [
        torchio.Subject(
            t1=Image(type=torchio.INTENSITY, path=image_path),
            label=Image(type=torchio.LABEL, path=label_path),
        )
        for image_path, label_path in zip(config.train_images,
                                          config.train_labels)
    ]

    subjects_dataset = ImagesDataset(subject_list,
                                     transform=config.transform_train)
    train_data = torch.utils.data.DataLoader(subjects_dataset,
                                             batch_size=config.batch_size,
                                             shuffle=True,
                                             num_workers=config.num_workers)
    print('TorchIO Training Loader built!')
    return train_data
コード例 #11
0
ファイル: test_queue.py プロジェクト: xm-cmic/torchio
 def test_queue(self):
     """Draw patches from a ``Queue`` via a ``DataLoader`` and read both images."""
     queue_dataset = Queue(
         ImagesDataset(self.subjects_list),
         max_length=6,
         samples_per_volume=2,
         patch_size=10,
         sampler_class=ImageSampler,
         num_workers=2,
         verbose=True,
     )
     str(queue_dataset)  # exercise the string representation
     for batch in DataLoader(queue_dataset, batch_size=4):
         _ = batch['one_modality'][DATA]
         _ = batch['segmentation'][DATA]
コード例 #12
0
ファイル: utils.py プロジェクト: animesh/torchio
 def get_inconsistent_sample(self):
     """Return a sample containing images of different shape.

     T1 uses the default generated shape, T2 is ``(10, 20, 31)`` and the
     binary label map is ``(8, 17, 25)``.
     """
     t1 = Image(self.get_image_path('t1_d'), INTENSITY)
     t2 = Image(self.get_image_path('t2_d', shape=(10, 20, 31)), INTENSITY)
     label_path = self.get_image_path(
         'label_d',
         shape=(8, 17, 25),
         binary=True,
     )
     label = Image(label_path, LABEL)
     dataset = ImagesDataset([Subject(t1=t1, t2=t2, label=label)])
     return dataset[0]
コード例 #13
0
ファイル: inference.py プロジェクト: KCL-BMEIS/ScribbleDA
def inference_padding(paths_dict, model, transformation, device, pred_path,
                      cp_path, opt):
    """Segment each subject with ``model`` and write the label maps to disk.

    Only the last entry of the module-level ``MODALITIES`` is used as input.
    Each image is center-cropped/padded to ``(288, 128, 48)`` and passed to
    ``model(data, 'source')``. The argmax labels are resampled (nearest
    neighbour) back onto the original image grid and written to
    ``pred_path.format(stem)``.

    Args:
        paths_dict: Subjects passed to ``ImagesDataset``.
        model: Network returning ``(logits, _)``.
        transformation: Transform applied by the dataset (may be ``None``).
        device: Device that the model and its inputs are moved to.
        pred_path: Format string with one ``{}`` placeholder for the stem.
        cp_path: Path to the ``state_dict`` checkpoint.
        opt: Unused; kept for signature compatibility with callers.
    """
    # map_location lets a checkpoint saved on GPU load onto ``device``.
    model.load_state_dict(torch.load(cp_path, map_location=device))
    model.to(device)
    model.eval()

    subjects_dataset_inf = ImagesDataset(paths_dict, transform=transformation)
    crop_or_pad = CenterCropOrPad((288, 128, 48))
    mod_used = MODALITIES[-1]

    for batch in tqdm(subjects_dataset_inf):
        batch_pad = crop_or_pad(batch)

        # Use the requested device instead of a hard-coded ``.cuda()`` so the
        # function honours ``device`` (including CPU-only runs).
        data = batch_pad[mod_used][DATA].unsqueeze(0).to(device)

        # Original geometry, used to undo the crop/pad after prediction.
        reference = torchio.utils.nib_to_sitk(batch[mod_used][DATA].numpy(),
                                              batch[mod_used]['affine'])
        affine_pad = batch_pad[mod_used]['affine']
        name = batch[mod_used]['stem']

        with torch.no_grad():
            logits, _ = model(data, 'source')
            labels = logits.argmax(dim=1, keepdim=True)
            labels = labels[0, 0, ...].cpu().numpy()

        output = torchio.utils.nib_to_sitk(labels.astype(float), affine_pad)
        output = sitk.Resample(
            output,
            reference,
            sitk.Transform(),
            sitk.sitkNearestNeighbor,
        )
        sitk.WriteImage(output, pred_path.format(name))
コード例 #14
0
        #dir_cat = gdir(dir_img,'cat12')
        #fm = gfile(dir_cat, '^mask_brain', {"items": 1})
        #fp1 = gfile(dir_cat, '^p1', {"items": 1})
        #fp2 = gfile(dir_cat, '^p2', {"items": 1})

    one_suj = {'image': Image(fin, INTENSITY), 'brain': Image(fm[0], LABEL)}
    if len(fp1) == 1:
        one_suj['p1'] = Image(fp1[0], LABEL)
    if len(fp2) == 1:
        one_suj['p2'] = Image(fp2[0], LABEL)

    subject = [Subject(one_suj) for i in range(0, nb_sample)]
    #subject = [ one_suj for i in range(0,nb_sample) ]
    print('input list is duplicated {} '.format(len(subject)))
    #subject = Subject(subject)
    dataset = ImagesDataset(subject, transform=transfo)

    for i in range(0, nb_sample):

        sample = dataset[i]  #in n time sample[0] it is cumulativ

        image_dict = sample['image']
        volume_path = image_dict['path']
        dd = volume_path.split('/')
        volume_name = dd[len(dd) - 2] + '_' + image_dict['stem']
        #nb_saved = image_dict['index'] #

        fname = resdir_mvt + 'ssim_{}_sample{:05d}_suj_{}_mvt.csv'.format(
            image_dict['metrics']['ssim'], index, volume_name)

        t = dataset.get_transform()
コード例 #15
0
    Image('resection_gray_matter_left', images_dir / gp('gray_matter_left_seg'), None),
    Image('resection_resectable_left', images_dir / gp('resectable_left_seg'), None),
    Image('resection_gray_matter_right', images_dir / gp('gray_matter_right_seg'), None),
    Image('resection_resectable_right', images_dir / gp('resectable_right_seg'), None),
)

# Resection volumes read from a CSV; the ``Volume`` column feeds the transform.
df_volumes = pd.read_csv(Path('~/episurg/volumes.csv').expanduser())
volumes = df_volumes.Volume.values

transform = RandomResection(
    volumes=volumes,
    # sigmas_range=(0.75, 0.75),
    keep_original=True,
    verbose=True,
    # seed=42,
)

dataset = ImagesDataset([subject])

# Write 10 random resections of the same subject. ``dataset[0]`` is read
# inside the loop so each transform presumably starts from a freshly loaded
# sample (no separate warm-up load is needed).
for i in range(10):
    transformed = transform(dataset[0])
    dataset.save_sample(
        transformed,
        dict(
            image=f'/tmp/resected_{i}.nii.gz',
            # image_original='/tmp/resected_original.nii.gz',
            label=f'/tmp/resected_label_{i}.nii.gz',
        ),
    )
コード例 #16
0
    def set_data_loader(self,
                        train_csv_file='',
                        val_csv_file='',
                        transforms=None,
                        batch_size=1,
                        num_workers=0,
                        par_queue=None,
                        save_to_dir=None,
                        load_from_dir=None,
                        replicate_suj=0,
                        shuffel_train=True,
                        get_condition_csv=None,
                        get_condition_field='',
                        get_condition_nb_wanted=1 / 4,
                        collate_fn=None,
                        add_to_load=None,
                        add_to_load_regexp=None):
        """Create the train/validation datasets and their data loaders.

        Data source:

        * If ``load_from_dir`` is given (one directory, or a
          ``[train_dir, val_dir]`` pair), previously saved samples matching
          ``sample.*pt`` are loaded. When ``get_condition_csv`` is also given,
          the training files are subsampled over 100 equal-width bins of
          column ``get_condition_field``. At most
          ``round(N * get_condition_nb_wanted / 100)`` files are kept per bin,
          where N is the number of CSV rows.
        * Otherwise subjects are read from ``train_csv_file`` and
          ``val_csv_file``, and the training list may be repeated
          ``replicate_suj`` times.

        Loading mode:

        * If ``par_queue`` is given (keys ``'windows_size'``,
          ``'queue_length'``, ``'samples_per_volume'``), patches are drawn
          through ``Queue`` objects. Validation uses one patch per volume and
          no shuffling.
        * Otherwise whole samples are batched and ``self.train_dataset`` is
          set.

        Side effects: sets ``self.train_dataloader`` and
        ``self.val_dataloader``, and appends to ``self.log_string`` and
        ``self.res_name``.
        """
        if transforms is not None and not isinstance(
                transforms, torchvision.transforms.transforms.Compose):
            transforms = Compose(transforms)

        if load_from_dir is not None:
            if isinstance(load_from_dir, str):
                load_from_dir = [load_from_dir, load_from_dir]
            fsample_train = gfile(load_from_dir[0], 'sample.*pt')
            fsample_val = gfile(load_from_dir[1], 'sample.*pt')

            if get_condition_csv is not None:
                res = pd.read_csv(load_from_dir[0] + '/' + get_condition_csv)
                # NOTE(review): assumes the CSV rows are in the same order as
                # ``fsample_train``; the selected indices are applied to both.
                cond_val = res[get_condition_field].values

                edges = np.linspace(np.min(cond_val), np.max(cond_val), 101)
                nb_bins = len(edges) - 1
                nb_wanted_per_interval = int(
                    np.round(len(cond_val) * get_condition_nb_wanted / 100))
                y_select = []
                for i in range(nb_bins):
                    lower, upper = edges[i], edges[i + 1]
                    # Half-open bins [lower, upper); the last bin is closed so
                    # the maximum is kept. Strict inequalities on both sides
                    # would drop the minimum, the maximum and any value lying
                    # exactly on a bin edge.
                    if i == nb_bins - 1:
                        in_bin = (cond_val >= lower) & (cond_val <= upper)
                    else:
                        in_bin = (cond_val >= lower) & (cond_val < upper)
                    indsel = np.where(in_bin)[0]
                    nb_select = len(indsel)
                    if nb_select < nb_wanted_per_interval:
                        print(
                            ' only {} / {} for interval {} {:,.3f} |  {:,.3f} '
                            .format(nb_select, nb_wanted_per_interval, i,
                                    lower, upper))
                        y_select.append(indsel)
                    else:
                        # Keep a random subset of the wanted size.
                        pind = np.random.permutation(range(0, nb_select))
                        y_select.append(indsel[pind[0:nb_wanted_per_interval]])
                ind_select = np.hstack(y_select)
                y = cond_val[ind_select]
                fsample_train = [fsample_train[ii] for ii in ind_select]
                self.log_string += '\nfinal selection {} soit {:,.3f} % instead of {:,.3f} %'.format(
                    len(y),
                    len(y) / len(cond_val) * 100,
                    get_condition_nb_wanted * 100)

            self.log_string += '\nloading {} train sample from {}'.format(
                len(fsample_train), load_from_dir[0])
            self.log_string += '\nloading {} val   sample from {}'.format(
                len(fsample_val), load_from_dir[1])
            train_dataset = ImagesDataset(
                fsample_train,
                load_from_dir=load_from_dir[0],
                transform=transforms,
                add_to_load=add_to_load,
                add_to_load_regexp=add_to_load_regexp)
            self.train_csv_load_file_train = fsample_train

            val_dataset = ImagesDataset(fsample_val,
                                        load_from_dir=load_from_dir[1],
                                        transform=transforms,
                                        add_to_load=add_to_load,
                                        add_to_load_regexp=add_to_load_regexp)
            # Keep the validation file list under its own attribute so it
            # does not overwrite the training list stored just above.
            self.train_csv_load_file_val = fsample_val

        else:
            data_parameters = {'image': {'csv_file': train_csv_file}}
            data_parameters_val = {'image': {'csv_file': val_csv_file}}

            paths_dict, info = get_subject_list_and_csv_info_from_data_prameters(
                data_parameters, fpath_idx='filename')
            paths_dict_val, info_val = get_subject_list_and_csv_info_from_data_prameters(
                data_parameters_val, fpath_idx='filename', shuffle_order=False)

            if replicate_suj:
                replicated = []
                for _ in range(replicate_suj):
                    replicated.extend(paths_dict)
                paths_dict = replicated
                # Leading newline keeps this entry on its own log line, like
                # the other ``log_string`` entries.
                self.log_string += '\nReplicating train dataSet {} times, new length is {}'.format(
                    replicate_suj, len(paths_dict))

            train_dataset = ImagesDataset(paths_dict,
                                          transform=transforms,
                                          save_to_dir=save_to_dir)
            val_dataset = ImagesDataset(paths_dict_val,
                                        transform=transforms,
                                        save_to_dir=save_to_dir)

        self.res_name += '_B{}_nw{}'.format(batch_size, num_workers)

        if par_queue is not None:
            self.patch = True
            windows_size = par_queue['windows_size']
            if len(windows_size) == 1:
                # A single size means an isotropic (cubic) window.
                windows_size = [
                    windows_size[0], windows_size[0], windows_size[0]
                ]

            train_queue = Queue(train_dataset,
                                par_queue['queue_length'],
                                par_queue['samples_per_volume'],
                                windows_size,
                                ImageSampler,
                                num_workers=num_workers,
                                verbose=self.verbose)

            val_queue = Queue(val_dataset,
                              par_queue['queue_length'],
                              1,
                              windows_size,
                              ImageSampler,
                              num_workers=num_workers,
                              shuffle_subjects=False,
                              shuffle_patches=False,
                              verbose=self.verbose)
            self.res_name += '_spv{}'.format(par_queue['samples_per_volume'])

            self.train_dataloader = DataLoader(train_queue,
                                               batch_size=batch_size,
                                               shuffle=shuffel_train,
                                               collate_fn=collate_fn)
            self.val_dataloader = DataLoader(val_queue,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             collate_fn=collate_fn)

        else:
            self.train_dataset = train_dataset
            self.train_dataloader = DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffel_train,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn)
            self.val_dataloader = DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             collate_fn=collate_fn)
コード例 #17
0
ファイル: test_predic_cnn.py プロジェクト: GFabien/torchQC
# Write ``tensor`` with ``affine`` (both defined earlier in the script) to
# /tmp/toto.nii using the project helper ``write_image``.
write_image(tensor, affine, '/tmp/toto.nii')
# Motion parameters saved by an earlier simulation run; the CSV has no header.
mvt = pd.read_csv(
    '/home/romain/QCcnn/motion_cati_brain_ms/ssim_0.03557806462049484_sample00010_suj_cat12_brain_s_S02_Sag_MPRAGE_mvt.csv',
    header=None)
# NOTE(review): ``fpars`` is not used in the visible part of this script.
fpars = np.asarray(mvt)

transforms = get_motion_transform()

# One subject, given as a list containing a single T1 ``Image``.
suj = [[
    Image(
        'T1',
        '/home/romain/QCcnn/motion_cati_brain_ms/brain_s_S02_Sag_MPRAGE.nii.gz',
        'intensity'),
]]

dataset = ImagesDataset(suj, transform=transforms)
# Indexing the dataset presumably applies ``transforms`` to the loaded T1.
s = dataset[0]

# Display the first channel of the T1; ``ov`` is presumably nibabel's
# OrthoSlicer3D viewer, imported outside this view.
ov(s['T1']['data'][0])
tt = dataset.get_transform()
plt.figure()
# ``.T`` makes each row of ``fitpars`` a separate curve.
plt.plot(tt.fitpars.T)
dataset.save_sample(
    s, dict(T1='/home/romain/QCcnn/motion_cati_brain_ms/toto10.nii'))

# Look at the distribution of the metric on simulated motion
dir_cache = '/network/lustre/dtlake01/opendata/data/ds000030/rrr/CNN_cache'
# NOTE(review): ``gdir``/``gfile`` appear to list directories/files matching a
# regex — confirm against their definitions.
dd = gdir(dir_cache, 'mask_mv')
fr = gfile(dd, 'resul')
name_res = get_parent_path(dd)[1]
# One DataFrame per result file found above.
res = [pd.read_csv(ff) for ff in fr]
コード例 #18
0
# One CSV of file paths per image key, each tagged with its torchio image type.
data_parameters = {'image': {'csv_file': '/data/romain/data_exemple/file_ms.csv', 'type': torchio.INTENSITY},
                   'label1': {'csv_file': '/data/romain/data_exemple/file_p1.csv', 'type': torchio.LABEL},
                   'label2': {'csv_file': '/data/romain/data_exemple/file_p2.csv', 'type': torchio.LABEL},
                   'label3': {'csv_file': '/data/romain/data_exemple/file_p3.csv', 'type': torchio.LABEL},
                   'sampler': {'csv_file': '/data/romain/data_exemple/file_mask.csv', 'type': torchio.SAMPLING_MAP}}
paths_dict, info = get_subject_list_and_csv_info_from_data_prameters(data_parameters) #,shuffle_order=False)

landmarks_file = '/data/romain/data_exemple/landmarks_hcp100.npy'
# NOTE(review): ``transforms`` is reassigned four times below. Every transform
# is constructed, but only the last one (RandomBiasField) reaches ``Compose``.
transforms = (HistogramStandardization(landmarks_file, mask_field_name='sampler'),)

transforms = (RandomElasticDeformation(num_control_points=8, proportion_to_augment=1, deformation_std=25, image_interpolation=Interpolation.BSPLINE),)
transforms = (RandomMotion(seed=42, degrees=0, translation=15, num_transforms=2, verbose=True,proportion_to_augment=1),)
transforms = (RandomBiasField(coefficients_range=(-0.5, 0.5),order=3), )

transform = Compose(transforms) #should be done in ImagesDataset
# The same subjects with and without the transform, to compare outputs.
dataset = ImagesDataset(paths_dict, transform=transform)
dataset_not = ImagesDataset(paths_dict, transform=None)
dataload = torch.utils.data.DataLoader(dataset, num_workers=0, batch_size=1)
dataloadnot = torch.utils.data.DataLoader(dataset_not, num_workers=0, batch_size=1)

ddd = dataset[0] #next(iter(dataset))
# NOTE(review): np.squeeze(..., axis=1) raises unless that axis has length 1;
# confirm the expected shape of the 'data' tensor.
ii = np.squeeze( ddd['image']['data'][0], axis=1)

ddno = dataset_not[0] #

# Same comparison through the DataLoaders, which add a leading batch
# dimension (hence the extra ``0`` index). ``ddno`` and ``ii`` are overwritten.
dd= next(iter(dataload))
ddno = next(iter(dataloadnot))

ii = np.squeeze( dd['image']['data'][0,0,:],axis=1)
iio = np.squeeze( ddno['image']['data'][0,0,:],axis=1)
コード例 #19
0
ファイル: example_motion.py プロジェクト: zhangjh705/torchio
Another way of getting this result is by running the command-line tool:

$ torchio-transform ~/Dropbox/MRI/t1.nii.gz RandomMotion /tmp/t1_motion.nii.gz --seed 42 --kwargs "degrees=10 translation=10 num_transforms=3 proportion_to_augment=1"

"""

from pprint import pprint
from torchio import Image, ImagesDataset, transforms, INTENSITY, LABEL, Subject

# A single subject with a brain segmentation and its T1 image.
subject = Subject(
    Image('label', '~/Dropbox/MRI/t1_brain_seg.nii.gz', LABEL),
    Image('t1', '~/Dropbox/MRI/t1.nii.gz', INTENSITY),
)
subjects_list = [subject]

dataset = ImagesDataset(subjects_list)
sample = dataset[0]

# Random motion with a fixed seed (42).
transform = transforms.RandomMotion(
    seed=42, degrees=10, translation=10, num_transforms=3,
)
transformed = transform(sample)

# Show the motion parameters the transform stored on the t1 image.
for motion_key in (
        'random_motion_times',
        'random_motion_degrees',
        'random_motion_translation',
):
    pprint(transformed['t1'][motion_key])

# Save each transformed image to its own file.
for image_name, out_path in (
        ('t1', '/tmp/t1_motion.nii.gz'),
        ('label', '/tmp/t1_brain_seg_motion.nii.gz'),
):
    dataset.save_sample(transformed, {image_name: out_path})
コード例 #20
0
def inference_padding(paths_dict, model, transformation, device, pred_path,
                      cp_path, opt):
    """Run grid-patch inference on every subject and write the label maps.

    Processing steps for each subject:

    1. Pad (never crop) so every spatial dimension is at least the
       module-level ``patch_size``.
    2. Concatenate the images whose keys are in ``opt.modalities`` along
       dim 0.
    3. Split into ``opt.window_size`` patches with ``GridSampler`` and
       segment them one by one.
    4. Re-assemble with ``GridAggregator``.
    5. Resample (nearest neighbour) onto the original T1 grid and write to
       ``pred_path.format(name)``.

    Args:
        paths_dict: Subjects passed to ``ImagesDataset``.
        model: Network called with a dict of inputs; returns ``(logits, _)``.
        transformation: Transform applied by the dataset (may be ``None``).
        device: Device that the model and its inputs are moved to.
        pred_path: Format string with one ``{}`` placeholder for the name.
        cp_path: Path to the ``state_dict`` checkpoint.
        opt: Provides ``modalities`` and ``window_size``.

    Raises:
        AssertionError: if the concatenated input has neither 1 nor 4 channels.
    """
    print("[INFO] Loading model.")
    # map_location lets a checkpoint saved on GPU load onto ``device``.
    model.load_state_dict(torch.load(cp_path, map_location=device))
    model.to(device)
    model.eval()

    subjects_dataset_inf = ImagesDataset(paths_dict, transform=transformation)
    nb_subject = len(subjects_dataset_inf)

    # Passed to both GridSampler and GridAggregator.
    border = (0, 0, 0)

    print("[INFO] Starting Inference.")
    print('Number of subjects to infer: {}'.format(nb_subject))

    # ``start=1`` so the progress message reads 1/N ... N/N.
    for index, batch in enumerate(subjects_dataset_inf, start=1):
        original_shape = batch['T1'][DATA].shape[1:]
        name = batch['T1']['stem'].split('T1')[0]
        affine = batch['T1']['affine']
        # Original T1 geometry, used to undo the padding after prediction.
        reference = torchio.utils.nib_to_sitk(batch['T1'][DATA].numpy(),
                                              affine)

        # Grow each spatial dimension up to ``patch_size``; larger ones are
        # kept, so CenterCropOrPad only pads here.
        new_shape = tuple(
            max(dim, patch_size[i]) for i, dim in enumerate(original_shape))
        batch_pad = CenterCropOrPad(new_shape)(batch)
        affine_pad = batch_pad['T1']['affine']

        data = torch.cat([
            data_mod[DATA]
            for mod, data_mod in batch_pad.items() if mod in opt.modalities
        ], 0)

        nb_modalities = data.shape[0]
        # ``name`` already has everything from 'T1' onward stripped.
        assert nb_modalities in (1, 4), \
            'Incorrect number of modalities for {}'.format(name)

        sampler = GridSampler(data, opt.window_size, border)
        aggregator = GridAggregator(data, border)
        loader = DataLoader(sampler, batch_size=1)

        with torch.no_grad():
            for batch_elemt in loader:
                locations = batch_elemt['location']
                # The first channel always goes in as 'T1'; with four
                # channels, all of them are also passed as 'all'.
                input_tensor = {
                    'T1': batch_elemt['image'][:, :1, ...].to(device)
                }
                if nb_modalities == 4:
                    input_tensor['all'] = batch_elemt['image'].to(device)
                logits, _ = model(input_tensor)
                labels = logits.argmax(dim=1, keepdim=True)
                aggregator.add_batch(labels, locations)

        output = aggregator.output_array.astype(float)
        output = torchio.utils.nib_to_sitk(output, affine_pad)
        output = sitk.Resample(
            output,
            reference,
            sitk.Transform(),
            sitk.sitkNearestNeighbor,
        )
        sitk.WriteImage(output, pred_path.format(name))
        print('{}/{} - Inference done for {} with {} modalities'.format(
            index, nb_subject, name, nb_modalities))
コード例 #21
0
 def test_wrong_transform_init(self):
     """A non-callable ``transform`` (here a ``dict``) must raise ``ValueError``."""
     with self.assertRaises(ValueError):
         # The constructor itself is expected to raise, so its result is
         # intentionally not bound to a name.
         ImagesDataset(
             self.subjects_list,
             transform=dict(),
         )
コード例 #22
0
 def iterate_dataset(subjects_list):
     """Load every sample of an ``ImagesDataset`` built from ``subjects_list``, discarding the results."""
     for _sample in ImagesDataset(subjects_list):
         continue
コード例 #23
0
    Image('T2', '../BRATS2018_crop_renamed/LGG75_T2.nii.gz', torchio.INTENSITY),
    Image('label', '../BRATS2018_crop_renamed/LGG75_Label.nii.gz', torchio.LABEL),
)

# This subject doesn't have a T2 MRI!
another_subject = Subject(
    Image('T1', '../BRATS2018_crop_renamed/LGG74_T1.nii.gz', torchio.INTENSITY),
    Image('label', '../BRATS2018_crop_renamed/LGG74_Label.nii.gz', torchio.LABEL),
)

# The subjects hold different image keys (only ``one_subject`` has 'T2').
subjects = [
    one_subject,
    another_subject,
]

subjects_dataset = ImagesDataset(subjects)
# Patch queue over the subjects. Positional arguments: maximum queue length,
# samples drawn per volume, patch size and sampler class (values defined
# earlier in the script).
queue_dataset = Queue(
    subjects_dataset,
    queue_length,
    samples_per_volume,
    patch_size,
    ImageSampler,
)

# This collate_fn is needed in the case of missing modalities
# In this case, the batch will be composed by a *list* of samples instead of
# the typical Python dictionary that is collated by default in Pytorch
batch_loader = DataLoader(
    queue_dataset,
    batch_size=batch_size,
    collate_fn=lambda x: x,
コード例 #24
0
subject_list = []

# Pair each training image with its label (matched by position; ``zip``
# stops at the shorter list).
for train_image, train_label in zip(train_images, train_labels):
    image_path = os.path.join(train_images_folder, train_image)
    label_path = os.path.join(train_labels_folder, train_label)

    s1 = torchio.Subject(
        t1=Image(type=torchio.INTENSITY, path=image_path),
        label=Image(type=torchio.LABEL, path=label_path),
    )

    subject_list.append(s1)

subjects_dataset = ImagesDataset(subject_list, transform=transform)

# Print diagnostics for at most the first 10 subjects. Bounding by the
# dataset length (rather than ``train_images``) avoids an IndexError when
# there are fewer labels than images.
for idx in range(min(10, len(subjects_dataset))):
    subject_sample = subjects_dataset[idx]
    print("Iter {} on {}".format(idx + 1, len(train_images)))
    print("t1.shape       = {}".format(subject_sample.t1.shape))
    print("label.shape    = {}".format(subject_sample.label.shape))
    print("t1 [min - max] = [{:.1f} : {:.1f}]".format(subject_sample.t1.data.min(), subject_sample.t1.data.max()))
    print("label.unique   = {}".format(subject_sample.label.data.unique()))


config = SemSegMRIConfig()
train_data = torch.utils.data.DataLoader(subjects_dataset, batch_size=config.batch_size,
                                         shuffle=False, num_workers=config.num_workers)
コード例 #25
0
 def test_images_dataset_deprecated(self):
     """Constructing an ``ImagesDataset`` must emit a ``DeprecationWarning``."""
     subjects = self.subjects_list
     with self.assertWarns(DeprecationWarning):
         ImagesDataset(subjects)
コード例 #26
0
from torchio import Subject, Image, ImagesDataset
from torchio.transforms import RandomMotionFromTimeCourse
from torchio.metrics import SSIM3D, MetricWrapper, MapMetricWrapper
from torchio.metrics.ssim import functional_ssim
from torch.nn import MSELoss, L1Loss
import torch
from nibabel.viewers import OrthoSlicer3D as ov

# T1 volume and head mask of a single subject from the HCP data folder.
t1_path = "/data/romain/HCPdata/suj_100307/T1w_acpc_dc_restore.nii"
mask_path = "/data/romain/HCPdata/suj_100307/cat12/fill_mask_head.nii.gz"
# One subject: the T1 plus the same mask file registered under two keys,
# 'mask' and 'mask2', so the metrics below can reference either key.
dataset = ImagesDataset([
    Subject({
        "T1": Image(t1_path),
        "mask": Image(mask_path, type="mask"),
        "mask2": Image(mask_path, type="mask")
    })
])

metrics = {
    "L1":
    MetricWrapper("L1", L1Loss()),
    "L1_map":
    MapMetricWrapper("L1_map",
                     lambda x, y: torch.abs(x - y),
                     average_method="mean",
                     mask_keys=['mask2']),
    "L2":
    MetricWrapper("L2", MSELoss()),
    #"SSIM": SSIM3D(average_method="mean"),
    "SSIM_mask":
    SSIM3D(average_method="mean", mask_keys=["mask", "mask2"]),
コード例 #27
0
# Second subject, given as a plain dict of ``path``/``type`` entries; it has
# only a T1 and a label (no T2).
another_subject_dict = {
    'T1':
    dict(path='../BRATS2018_crop_renamed/LGG74_T1.nii.gz',
         type=torchio.INTENSITY),
    'label':
    dict(path='../BRATS2018_crop_renamed/LGG74_Label.nii.gz',
         type=torchio.LABEL),
}

subjects_paths = [
    one_subject_dict,
    another_subject_dict,
]

# ``transform`` is defined earlier in the script.
subjects_dataset = ImagesDataset(subjects_paths, transform=transform)

# Run a benchmark for different numbers of workers: 0 up to and including
# ``mp.cpu_count()`` (``mp`` is presumably ``multiprocessing``).
workers = range(mp.cpu_count() + 1)
for num_workers in workers:
    print('Number of workers:', num_workers)

    # Define the dataset as a queue of patches
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        patch_size,
        ImageSampler,
        num_workers=num_workers,
        shuffle_subjects=False,
コード例 #28
0
 def test_no_load(self):
     """Iterating a dataset built with ``load_image_data=False`` must not raise."""
     dataset = ImagesDataset(self.subjects_list, load_image_data=False)
     for _ in dataset:
         continue
コード例 #29
0
from torchio import Image, ImagesDataset, transforms, INTENSITY, LABEL
from torchvision.transforms import Compose
import numpy as np
from nibabel.viewers import OrthoSlicer3D as ov
from copy import deepcopy

# Fixed NumPy seed for reproducible random draws in this script.
np.random.seed(12)

out_dir = '/data/ghiles/'

# One subject expressed as a list of ``Image`` objects, so ``subject`` is
# already a list of subjects.
subject = [[
    Image('T1', '/data/romain/HCPdata/suj_100307/T1w_1mm.nii.gz', INTENSITY),
    Image('mask', '/data/romain/HCPdata/suj_100307/brain_mT1w_1mm.nii', LABEL)
]]
# NOTE(review): ``subjects_list`` wraps ``subject`` in yet another list and is
# not referenced in the visible code (ImagesDataset receives ``subject``).
subjects_list = [subject]
dataset = ImagesDataset(subject)
sample = dataset[0]
#sample = deepcopy(sample_orig)

# Number of time points in the motion time course.
nT = 100
time_points = [.55, 1.0]
# 6 x nT parameter time course; presumably 3 translations + 3 rotations —
# TODO confirm against RandomMotionTimeCourseAffines.
fitpars = np.zeros((6, nT))

# Row 1 is set to -15 from time index 55 onward (a single step change).
fitpars[1, 55:] = -15
#fitpars[dim_modif, :45] = -7.5
#fitpars[dim_modif, 45:] = 7.5

#ov(sample["T1"]["data"][0], sample["T1"]["affine"])

transform = RandomMotionTimeCourseAffines(fitpars=fitpars,
                                          time_points=time_points,
コード例 #30
0
from pprint import pprint
from torchio import ImagesDataset, transforms, INTENSITY

# One subject described by a dict of per-image path/type specs.
paths = [{
    't1': {'path': '~/Dropbox/MRI/t1.nii.gz', 'type': INTENSITY},
    'colin': {'path': '/tmp/colin27_t1_tal_lin.nii.gz', 'type': INTENSITY},
}]

dataset = ImagesDataset(paths)
sample = dataset[0]

# Seeded random-motion transform (seed 42), verbose output enabled.
transform = transforms.RandomMotion(
    seed=42, degrees=20, translation=15, num_transforms=3, verbose=True,
)
transformed = transform(sample)

pprint(transformed['t1']['random_motion_times'])

# Save each transformed image to its own file.
for image_key, output_file in (
        ('t1', '/tmp/t1_motion.nii.gz'),
        ('colin', '/tmp/colin_motion.nii.gz'),
):
    dataset.save_sample(transformed, {image_key: output_file})