def cmr_3d_sax_transform(self):
        """Build the train/valid ``ts.Compose`` pipelines for 3D SAX volumes.

        Both pipelines pad to ``self.scale_size``, convert to tensors, cast both
        inputs to float, normalise with ``ts.NormalizeMedic``, move channels back
        last, add a leading axis and finally cast the two inputs to
        ``float`` / ``long``. Only the training pipeline applies a random flip and
        a random affine warp, and it crops randomly instead of with
        ``ts.SpecialCrop(crop_type=0)``. Both crops use ``self.patch_size``.

        :return: dict with keys ``'train'`` and ``'valid'``.
        """
        def front_end():
            # Stages shared by both pipelines, before any augmentation.
            return [
                ts.Pad(size=self.scale_size),
                ts.ToTensor(),
                ts.ChannelsFirst(),
                ts.TypeCast(['float', 'float']),
            ]

        def normalisation():
            # norm_flag=(True, False) presumably normalises only the first (image)
            # input. ts.NormalizeMedicPercentile is a drop-in alternative.
            return [
                ts.NormalizeMedic(norm_flag=(True, False)),
                ts.ChannelsLast(),
                ts.AddChannel(axis=0),
            ]

        train_steps = front_end()
        train_steps.append(ts.RandomFlip(h=True, v=True, p=self.random_flip_prob))
        # interp pairs presumably apply per input: bilinear for the image,
        # nearest for the label map.
        train_steps.append(
            ts.RandomAffine(
                rotation_range=self.rotate_val,
                translation_range=self.shift_val,
                zoom_range=self.scale_val,
                interp=('bilinear', 'nearest'),
            )
        )
        train_steps.extend(normalisation())
        train_steps.append(ts.RandomCrop(size=self.patch_size))
        train_steps.append(ts.TypeCast(['float', 'long']))
        train_pipeline = ts.Compose(train_steps)

        valid_steps = front_end() + normalisation()
        valid_steps.append(ts.SpecialCrop(size=self.patch_size, crop_type=0))
        valid_steps.append(ts.TypeCast(['float', 'long']))
        valid_pipeline = ts.Compose(valid_steps)

        return {'train': train_pipeline, 'valid': valid_pipeline}
Exemplo n.º 2
0
    def gsd_pCT_train_transform(self, seed=None):
        """Build the training augmentation pipeline for the GSD pCT maps.

        Every random transform receives the same integer ``seed`` and the same
        ``self.max_output_channels``.

        :param seed: integer seed shared by all random transforms. When ``None``,
            one is drawn with ``np.random.randint(0, 9999)``.
        :return: a ``ts.Compose`` pipeline.
        """
        if seed is None:
            # torch needs an integer seed, hence randint rather than a float draw.
            seed = np.random.randint(0, 9999)

        # Keyword arguments common to every random transform below.
        common = dict(seed=seed, max_output_channels=self.max_output_channels)

        steps = [
            ts.ToTensor(),
            ts.Pad(size=self.scale_size),
            ts.TypeCast(['float', 'float']),
            RandomFlipTransform(axes=self.flip_axis, p=self.random_flip_prob, **common),
            RandomElasticTransform(
                p=self.random_elastic_prob,
                image_interpolation=Interpolation.BSPLINE,
                max_displacement=self.max_deform,
                num_control_points=self.elastic_control_points,
                **common
            ),
            RandomAffineTransform(
                scales=self.scale_val,
                degrees=self.rotate_val,
                isotropic=True,
                default_pad_value=0,
                image_interpolation=Interpolation.BSPLINE,
                p=self.random_affine_prob,
                **common
            ),
            RandomNoiseTransform(p=self.random_noise_prob, std=self.noise_std, **common),
            ts.ChannelsFirst(),
            # TODO: apply channel-wise normalisation
            # (ts.NormalizeMedicPercentile is an alternative normaliser).
            ts.NormalizeMedic(norm_flag=(True, False)),
            # TODO: add random-crop augmentation once the torchsample RandomCrop bug
            # is fixed (it seems to need ts.ChannelsLast() first).
            ts.TypeCast(['float', 'long']),
        ]
        return ts.Compose(steps)
Exemplo n.º 3
0
    def test_3d_sax_transform(self):
        """Build the test-time pipeline for 3D SAX volumes.

        The pipeline handles a single input: a one-element ``TypeCast`` list and
        a scalar ``norm_flag``. It applies no cropping. Padding is done by
        ``ts.PadFactorNumpy`` with ``self.division_factor``, presumably up to a
        multiple of that factor; confirm against the ts implementation.

        :return: dict ``{'test': <ts.Compose pipeline>}``.
        """
        # Padding runs before ToTensor. The *Numpy op presumably expects an ndarray.
        steps = [ts.PadFactorNumpy(factor=self.division_factor), ts.ToTensor()]
        steps += [ts.ChannelsFirst(), ts.TypeCast(['float'])]
        # ts.NormalizeMedicPercentile(norm_flag=True) is an alternative normaliser.
        steps.append(ts.NormalizeMedic(norm_flag=True))
        steps += [ts.ChannelsLast(), ts.AddChannel(axis=0)]
        return {'test': ts.Compose(steps)}
Exemplo n.º 4
0
    def gsd_pCT_valid_transform(self, seed=None):
        """Build the validation pipeline for the GSD pCT maps.

        :param seed: ignored, because this pipeline contains no random transforms.
            It is presumably accepted for call-compatibility with
            ``gsd_pCT_train_transform``.
        :return: a ``ts.Compose`` pipeline.
        """
        pipeline = [ts.ToTensor(), ts.Pad(size=self.scale_size), ts.ChannelsFirst()]
        pipeline.append(ts.TypeCast(['float', 'float']))
        # ts.NormalizeMedicPercentile(norm_flag=(True, False)) is an alternative.
        pipeline.append(ts.NormalizeMedic(norm_flag=(True, False)))
        # NOTE: cropping (ts.ChannelsLast + ts.SpecialCrop(crop_type=0) to
        # self.patch_size) is currently disabled.
        pipeline.append(ts.TypeCast(['float', 'long']))
        return ts.Compose(pipeline)
Exemplo n.º 5
0
    def gsd_pCT_transform(self):
        '''
        Data augmentation transformations for the Geneva Stroke dataset (pCT maps).

        Both pipelines convert to a tensor, pad to ``self.scale_size``, move
        channels first, normalise with ``ts.NormalizeMedic`` and finally cast the
        two inputs to ``float`` / ``long``. The training pipeline additionally
        applies a random horizontal/vertical flip with probability
        ``self.random_flip_prob``. In the training pipeline the float cast happens
        before ``ChannelsFirst``; in the validation pipeline it happens after.

        :return: dict with keys ``'train'`` and ``'valid'``, each a ``ts.Compose``
            pipeline.
        '''

        train_transform = ts.Compose([
            ts.ToTensor(),
            ts.Pad(size=self.scale_size),
            ts.TypeCast(['float', 'float']),
            ts.RandomFlip(h=True, v=True, p=self.random_flip_prob),
            # TODO: ts.RandomAffine does not support channels; try a newer
            # torchsample or torchvision before re-enabling affine augmentation.
            ts.ChannelsFirst(),
            # TODO: apply channel-wise normalisation
            # (ts.NormalizeMedicPercentile is an alternative normaliser).
            # norm_flag=(True, False) presumably normalises only the first input.
            ts.NormalizeMedic(norm_flag=(True, False)),
            # TODO: fork torchsample and fix the RandomCrop bug, then add
            # ts.RandomCrop(size=self.patch_size) (seems to need ts.ChannelsLast() first).
            ts.TypeCast(['float', 'long'])
        ])

        valid_transform = ts.Compose([
            ts.ToTensor(),
            ts.Pad(size=self.scale_size),
            ts.ChannelsFirst(),
            ts.TypeCast(['float', 'float']),
            # ts.NormalizeMedicPercentile(norm_flag=(True, False)) is an alternative.
            ts.NormalizeMedic(norm_flag=(True, False)),
            # Cropping (ts.SpecialCrop(size=self.patch_size, crop_type=0)) is disabled
            # to match the training pipeline.
            ts.TypeCast(['float', 'long'])
        ])

        return {'train': train_transform, 'valid': valid_transform}