コード例 #1
0
ファイル: master.py プロジェクト: vzinche/neurofire
 def get_transforms(self, *, alpha=2000., sigma=50.):
     """Build the training augmentation pipeline.

     The pipeline applies ``RandomFlip3D``, ``RandomRotate`` and
     ``ElasticTransform`` in that order and finishes with ``AsTorchBatch(2)``.

     Args:
         alpha: Value forwarded as ``alpha`` to ``ElasticTransform``.
             Keyword-only; the default is the value previously hard-coded here.
         sigma: Value forwarded as ``sigma`` to ``ElasticTransform``.
             Keyword-only; the default is the value previously hard-coded here.

     Returns:
         A ``Compose`` of the transforms listed above.
     """
     return Compose(
         RandomFlip3D(),
         RandomRotate(),
         ElasticTransform(alpha=alpha, sigma=sigma),
         AsTorchBatch(2))
コード例 #2
0
ファイル: cremi.py プロジェクト: jamiegrieser/segmfriends
    def get_additional_transforms(self, master_config):
        """Extend the base transforms with optional down-scaling/cropping and a tensor conversion.

        Args:
            master_config: Configuration dict, or None (treated as empty).
                Keys read:

                * ``downscale_and_crop``: a list of dicts. Each dict must
                  contain ``apply_to``. The collected ``apply_to`` values are
                  passed to ``ReplicateTensorsInBatch``. The remaining keys of
                  the i-th dict are forwarded to
                  ``DownSampleAndCropTensorsInBatch(apply_to=[i], ...)``.
                * ``crop_after_target``: kwargs for ``VolumeAsymmetricCrop``.
                  Skipped when missing or empty.

                The config is read but never modified.

        Returns:
            The transform pipeline. If ``self.transforms`` is set, it is
            extended in place and returned. Otherwise a new ``Compose`` is
            created.
        """
        transforms = self.transforms if self.transforms is not None else Compose()

        master_config = {} if master_config is None else master_config

        # Replicate and downscale batch:
        ds_config = master_config.get("downscale_and_crop")
        if ds_config is not None:
            # Read 'apply_to' instead of popping it. Popping mutated the
            # caller's config, so a second call with the same config raised
            # KeyError.
            apply_to = [conf['apply_to'] for conf in ds_config]
            transforms.add(ReplicateTensorsInBatch(apply_to))
            for indx, conf in enumerate(ds_config):
                ds_kwargs = {key: value for key, value in conf.items()
                             if key != 'apply_to'}
                transforms.add(
                    DownSampleAndCropTensorsInBatch(apply_to=[indx],
                                                    order=None,
                                                    **ds_kwargs))

        # Crop invalid affinity labels and elastic-augment reflection padding asymmetrically.
        crop_config = master_config.get('crop_after_target', {})
        if crop_config:
            # One might need to crop after elastic transform to avoid edge artefacts of affinity
            # computation being warped into the FOV.
            transforms.add(VolumeAsymmetricCrop(**crop_config))

        transforms.add(AsTorchBatch(3, add_channel_axis_if_necessary=True))

        return transforms
コード例 #3
0
    def get_additional_transforms(self, transform_config):
        """Extend the base transforms with multi-scale replication and a tensor conversion.

        The batch is replicated once per entry of
        ``transform_config["stack_scaling_factors"]``. Copy ``i`` is then
        handled by ``DownsampleAndCrop3D`` (order 2). Its zoom factor is the
        i-th scaling factor, and its crop factor is the entry at the mirrored
        position (the list read back to front). ``AsTorchBatch(3)`` is appended
        last.

        Returns:
            The transform pipeline. If ``self.transforms`` is set, it is
            extended in place and returned. Otherwise a new ``Compose`` is
            created.
        """
        if self.transforms is None:
            transforms = Compose()
        else:
            transforms = self.transforms

        scaling_factors = transform_config["stack_scaling_factors"]

        # Replicate and downscale batch:
        n_copies = len(scaling_factors)
        transforms.add(ReplicateBatch(n_copies))

        # Crop factors are the zoom factors in reverse order. Because of the
        # deep copy, they share no nested objects with the zoom factors.
        crop_factors = deepcopy(scaling_factors)
        crop_factors.reverse()
        for copy_idx, (zoom, crop) in enumerate(zip(scaling_factors, crop_factors)):
            transforms.add(
                DownsampleAndCrop3D(apply_to=[copy_idx],
                                    order=2,
                                    zoom_factor=zoom,
                                    crop_factor=crop))

        transforms.add(AsTorchBatch(3))
        return transforms
コード例 #4
0
    def get_transforms(self):
        """Build the training transform pipeline from ``self.transform_config``.

        Each stage is enabled by a key of ``self.transform_config``:

        * ``random_flip`` (truthy): adds ``RandomFlip3D`` followed by
          ``RandomRotate``.
        * ``elastic_transform``: a dict whose truthy ``apply`` flag adds an
          ``ElasticTransform``. It uses ``alpha`` (default 2000.), ``sigma``
          (default 50.) and ``order`` (default 0).
        * ``downscale_and_crop``: a list of dicts, each with an ``apply_to``
          entry. The collected entries go to ``ReplicateTensorsInBatch``.
          The remaining keys of the i-th dict are forwarded to
          ``DownSampleAndCropTensorsInBatch(apply_to=[i], ...)``.
        * ``affinity_config``: adds per-input affinity-target transforms.
          ``global`` holds kwargs shared by every input. A truthy
          ``use_dynamic_offsets`` selects
          ``Segmentation2AffinitiesDynamicOffsets`` instead of
          ``affinity_config_to_transform``. Each transform is applied to index
          ``input_index + nb_inputs``. ``nb_inputs`` is the number of
          ``apply_to`` entries equal to 0, or 1 when there is no
          ``downscale_and_crop``.
        * ``crop_after_target``: kwargs for ``VolumeAsymmetricCrop``.

        ``AsTorchBatch(3)`` is always appended last. ``self.transform_config``
        is read but never modified.

        Returns:
            A new ``Compose`` with the selected transforms.
        """
        config = self.transform_config
        transforms = Compose()

        if config.get('random_flip', False):
            transforms.add(RandomFlip3D())
            transforms.add(RandomRotate())

        # Elastic transforms are skipped when 'elastic_transform' is missing or
        # falsy (e.g. set to false in the yaml config file), or when its
        # 'apply' flag is not set.
        elastic_transform_config = config.get('elastic_transform')
        if elastic_transform_config and elastic_transform_config.get('apply', False):
            transforms.add(
                ElasticTransform(
                    alpha=elastic_transform_config.get('alpha', 2000.),
                    sigma=elastic_transform_config.get('sigma', 50.),
                    order=elastic_transform_config.get('order', 0)))

        # Replicate and downscale batch:
        nb_inputs = 1
        ds_config = config.get("downscale_and_crop")
        if ds_config is not None:
            # Read 'apply_to' instead of popping it. Popping corrupted the
            # stored config, so a second call raised KeyError.
            apply_to = [conf['apply_to'] for conf in ds_config]
            nb_inputs = (np.array(apply_to) == 0).sum()
            transforms.add(ReplicateTensorsInBatch(apply_to))
            for indx, conf in enumerate(ds_config):
                ds_kwargs = {key: value for key, value in conf.items()
                             if key != 'apply_to'}
                transforms.add(
                    DownSampleAndCropTensorsInBatch(apply_to=[indx],
                                                    order=None,
                                                    **ds_kwargs))

        # Check if to compute binary-affinity-targets from GT labels:
        if config.get("affinity_config") is not None:
            # Deep copy so that the pops below leave the stored config intact.
            affs_config = deepcopy(config.get("affinity_config"))
            global_kwargs = affs_config.pop("global", {})

            if affs_config.pop("use_dynamic_offsets", False):
                aff_transform = Segmentation2AffinitiesDynamicOffsets
            else:
                aff_transform = affinity_config_to_transform

            for input_index, input_kwargs in affs_config.items():
                affs_kwargs = deepcopy(global_kwargs)
                affs_kwargs.update(input_kwargs)
                transforms.add(
                    aff_transform(apply_to=[input_index + nb_inputs],
                                  **affs_kwargs))

        # Crop invalid affinity labels and elastic-augment reflection padding asymmetrically.
        crop_config = config.get('crop_after_target', {})
        if crop_config:
            # One might need to crop after elastic transform to avoid edge artefacts of affinity
            # computation being warped into the FOV.
            transforms.add(VolumeAsymmetricCrop(**crop_config))

        transforms.add(AsTorchBatch(3))
        return transforms
コード例 #5
0
ファイル: artifact_source.py プロジェクト: vzinche/neurofire
 def get_transforms(self):
     """Assemble the augmentation pipeline from ``self.master_config``.

     The pipeline always starts with ``RandomRotate``. An ``ElasticTransform``
     is added when the config contains an ``elastic_transform`` key, with that
     value used as keyword arguments. A ``CenterCrop`` is added when
     ``crop_after_elastic_transform`` is truthy, with that value used as
     keyword arguments. ``AsTorchBatch(2)`` is always appended last.

     Returns:
         A ``Compose`` of the selected transforms.
     """
     config = self.master_config
     pipeline = [RandomRotate()]

     if 'elastic_transform' in config:
         elastic_kwargs = config.get('elastic_transform', {})
         pipeline.append(ElasticTransform(**elastic_kwargs))

     crop_kwargs = config.get('crop_after_elastic_transform', False)
     if crop_kwargs:
         pipeline.append(CenterCrop(**crop_kwargs))

     pipeline.append(AsTorchBatch(2))
     return Compose(*pipeline)
コード例 #6
0
    def get_additional_transforms(self, master_config):
        """Extend the base transforms with optional down-scaling/cropping and a tensor conversion.

        Args:
            master_config: Configuration dict, or None (treated as empty).
                Keys read:

                * ``downscale_and_crop``: a list of dicts. Each dict must
                  contain ``apply_to``. The collected ``apply_to`` values are
                  passed to ``ReplicateBatchGeneralized``. The remaining keys
                  of the i-th dict are forwarded to
                  ``DownsampleAndCrop3D(apply_to=[i], ...)``.
                * ``crop_after_target``: kwargs for ``VolumeAsymmetricCrop``.
                  Skipped when missing or empty.

                The config is read but never modified.

        Returns:
            The transform pipeline. If ``self.transforms`` is set, it is
            extended in place and returned. Otherwise a new ``Compose`` is
            created.
        """
        transforms = self.transforms if self.transforms is not None else Compose()

        master_config = {} if master_config is None else master_config
        # TODO: somehow merge with the trainer loader...

        # Replicate and downscale batch:
        ds_config = master_config.get("downscale_and_crop")
        if ds_config is not None:
            # Read 'apply_to' instead of popping it. Popping mutated the
            # caller's config, so a second call with the same config raised
            # KeyError.
            apply_to = [conf['apply_to'] for conf in ds_config]
            transforms.add(ReplicateBatchGeneralized(apply_to))
            for indx, conf in enumerate(ds_config):
                ds_kwargs = {key: value for key, value in conf.items()
                             if key != 'apply_to'}
                transforms.add(
                    DownsampleAndCrop3D(apply_to=[indx], order=None, **ds_kwargs))

        # NOTE: no affinity-target transforms are added in this pipeline.

        # Crop invalid affinity labels and elastic-augment reflection padding asymmetrically.
        crop_config = master_config.get('crop_after_target', {})
        if crop_config:
            # One might need to crop after elastic transform to avoid edge artefacts of affinity
            # computation being warped into the FOV.
            transforms.add(VolumeAsymmetricCrop(**crop_config))

        transforms.add(AsTorchBatch(3, add_channel_axis_if_necessary=True))

        return transforms
コード例 #7
0
ファイル: cremi.py プロジェクト: c-laun/quantized_vector_DT
 def get_transforms(self):
     """Return the single transform used here: ``AsTorchBatch(3)``."""
     return AsTorchBatch(3)
コード例 #8
0
 def get_additional_transforms(self):
     """Append ``AsTorchBatch(3)`` to the existing transforms.

     If ``self.transforms`` is set, it is extended in place. Otherwise a new
     ``Compose`` is created. The resulting pipeline is returned.
     """
     pipeline = Compose() if self.transforms is None else self.transforms
     pipeline.add(AsTorchBatch(3))
     return pipeline