def get_transforms(self):
    """Build the training augmentation pipeline.

    Applies random flips and rotations, an elastic deformation, and finally
    converts the batch to 2D torch tensors.

    :return: a ``Compose`` of the four transform steps.
    """
    # Assemble the steps in order, then wrap them in a single Compose.
    pipeline = [
        RandomFlip3D(),
        RandomRotate(),
        # Hard coded for now
        ElasticTransform(alpha=2000., sigma=50.),
        AsTorchBatch(2),
    ]
    return Compose(*pipeline)
def get_additional_transforms(self, master_config):
    """Extend ``self.transforms`` with config-driven post-processing steps.

    :param master_config: dict or None. Recognized keys:
        ``downscale_and_crop`` — list of per-tensor configs, each holding an
        ``apply_to`` index plus kwargs for
        ``DownSampleAndCropTensorsInBatch``;
        ``crop_after_target`` — kwargs for ``VolumeAsymmetricCrop``.
    :return: the extended transform pipeline (always ends with a 3D
        ``AsTorchBatch``).
    """
    transforms = self.transforms if self.transforms is not None else Compose()
    master_config = {} if master_config is None else master_config

    # Replicate and downscale batch:
    if master_config.get("downscale_and_crop") is not None:
        # Pop 'apply_to' from shallow copies so the caller's config is not
        # mutated — the original popped in place, which made a second call
        # with the same config fail with a KeyError.
        ds_config = [dict(conf) for conf in master_config["downscale_and_crop"]]
        apply_to = [conf.pop('apply_to') for conf in ds_config]
        transforms.add(ReplicateTensorsInBatch(apply_to))
        for indx, conf in enumerate(ds_config):
            transforms.add(
                DownSampleAndCropTensorsInBatch(apply_to=[indx], order=None, **conf))

    # crop invalid affinity labels and elastic augment reflection padding assymetrically
    crop_config = master_config.get('crop_after_target', {})
    if crop_config:
        # One might need to crop after elastic transform to avoid edge artefacts of affinity
        # computation being warped into the FOV.
        transforms.add(VolumeAsymmetricCrop(**crop_config))

    transforms.add(AsTorchBatch(3, add_channel_axis_if_necessary=True))
    return transforms
def get_additional_transforms(self, transform_config):
    """Extend ``self.transforms`` with one replicate/downscale/crop step per
    scale level, followed by a 3D torch-batch conversion.

    :param transform_config: dict holding ``stack_scaling_factors``, one
        zoom factor per scale level.
    :return: the extended transform pipeline.
    """
    transforms = self.transforms if self.transforms is not None else Compose()

    scaling_factors = transform_config["stack_scaling_factors"]

    # Replicate and downscale batch — one replica per scale level:
    num_inputs = len(scaling_factors)
    transforms.add(ReplicateBatch(num_inputs))

    # The crop factors are the scaling factors in reverse order.
    crop_factors = deepcopy(scaling_factors)
    crop_factors.reverse()

    for replica_idx, (zoom_fact, crop_fact) in enumerate(zip(scaling_factors,
                                                             crop_factors)):
        transforms.add(
            DownsampleAndCrop3D(apply_to=[replica_idx],
                                order=2,
                                zoom_factor=zoom_fact,
                                crop_factor=crop_fact))

    transforms.add(AsTorchBatch(3))
    return transforms
def get_transforms(self):
    """Build the full training pipeline from ``self.transform_config``.

    Steps, each enabled by its config key:
      - ``random_flip`` (bool): random 3D flips plus random rotations;
      - ``elastic_transform`` (dict with ``apply``/``alpha``/``sigma``/``order``):
        elastic deformation;
      - ``downscale_and_crop`` (list of per-tensor configs with ``apply_to``):
        batch replication plus per-replica downscale/crop;
      - ``affinity_config`` (dict): affinity-target computation from GT labels;
      - ``crop_after_target`` (dict): asymmetric crop of invalid borders.
    Always ends with a 3D ``AsTorchBatch``.

    :return: the composed transform pipeline.
    """
    transforms = Compose()
    if self.transform_config.get('random_flip', False):
        transforms.add(RandomFlip3D())
        transforms.add(RandomRotate())

    # Elastic transforms can be skipped by setting elastic_transform to
    # false in the yaml config file.
    if self.transform_config.get('elastic_transform'):
        elastic_transform_config = self.transform_config.get('elastic_transform')
        if elastic_transform_config.get('apply', False):
            transforms.add(
                ElasticTransform(
                    alpha=elastic_transform_config.get('alpha', 2000.),
                    sigma=elastic_transform_config.get('sigma', 50.),
                    order=elastic_transform_config.get('order', 0)))

    # Replicate and downscale batch:
    nb_inputs = 1
    if self.transform_config.get("downscale_and_crop") is not None:
        # Pop 'apply_to' from shallow copies so self.transform_config is not
        # mutated — the original popped in place, which made a second call to
        # get_transforms() fail with a KeyError.
        ds_config = [dict(conf) for conf in self.transform_config["downscale_and_crop"]]
        apply_to = [conf.pop('apply_to') for conf in ds_config]
        # Number of replicas derived from the raw input (tensor 0); target
        # indices below are shifted past them.
        nb_inputs = (np.array(apply_to) == 0).sum()
        transforms.add(ReplicateTensorsInBatch(apply_to))
        for indx, conf in enumerate(ds_config):
            transforms.add(
                DownSampleAndCropTensorsInBatch(apply_to=[indx], order=None, **conf))

    # Check if to compute binary-affinity-targets from GT labels:
    if self.transform_config.get("affinity_config") is not None:
        # deepcopy: the pops below must not touch the stored config.
        affs_config = deepcopy(self.transform_config.get("affinity_config"))
        global_kwargs = affs_config.pop("global", {})

        aff_transform = Segmentation2AffinitiesDynamicOffsets if affs_config.pop("use_dynamic_offsets", False) \
            else affinity_config_to_transform

        for input_index in affs_config:
            # Per-input kwargs override the global ones.
            affs_kwargs = deepcopy(global_kwargs)
            affs_kwargs.update(affs_config[input_index])
            transforms.add(
                aff_transform(apply_to=[input_index + nb_inputs], **affs_kwargs))

    # crop invalid affinity labels and elastic augment reflection padding assymetrically
    crop_config = self.transform_config.get('crop_after_target', {})
    if crop_config:
        # One might need to crop after elastic transform to avoid edge artefacts of affinity
        # computation being warped into the FOV.
        transforms.add(VolumeAsymmetricCrop(**crop_config))

    transforms.add(AsTorchBatch(3))
    return transforms
def get_transforms(self):
    """Build the augmentation pipeline from ``self.master_config``.

    Always rotates; optionally applies an elastic deformation
    (``elastic_transform``) and a subsequent center crop
    (``crop_after_elastic_transform``); ends with a 2D torch-batch
    conversion.

    :return: a ``Compose`` of the selected steps.
    """
    steps = [RandomRotate()]

    if 'elastic_transform' in self.master_config:
        elastic_kwargs = self.master_config.get('elastic_transform', {})
        steps.append(ElasticTransform(**elastic_kwargs))

    if self.master_config.get('crop_after_elastic_transform', False):
        crop_kwargs = self.master_config.get('crop_after_elastic_transform')
        steps.append(CenterCrop(**crop_kwargs))

    steps.append(AsTorchBatch(2))
    return Compose(*steps)
def get_additional_transforms(self, master_config):
    """Extend ``self.transforms`` with config-driven post-processing steps.

    :param master_config: dict or None. Recognized keys:
        ``downscale_and_crop`` — list of per-tensor configs, each holding an
        ``apply_to`` index plus kwargs for ``DownsampleAndCrop3D``;
        ``crop_after_target`` — kwargs for ``VolumeAsymmetricCrop``.
    :return: the extended transform pipeline (always ends with a 3D
        ``AsTorchBatch``).
    """
    transforms = self.transforms if self.transforms is not None else Compose()
    master_config = {} if master_config is None else master_config
    # TODO: somehow merge with the trainer loader...

    # Replicate and downscale batch:
    if master_config.get("downscale_and_crop") is not None:
        # Pop 'apply_to' from shallow copies so the caller's config is not
        # mutated — the original popped in place, which made a second call
        # with the same config fail with a KeyError.
        ds_config = [dict(conf) for conf in master_config["downscale_and_crop"]]
        apply_to = [conf.pop('apply_to') for conf in ds_config]
        transforms.add(ReplicateBatchGeneralized(apply_to))
        for indx, conf in enumerate(ds_config):
            transforms.add(
                DownsampleAndCrop3D(apply_to=[indx], order=None, **conf))

    # # # affinity transforms for affinity targets
    # # we apply the affinity target calculation only to the segmentation (1)
    # if master_config.get("affinity_config") is not None:
    #     affs_config = master_config.get("affinity_config")
    #     global_kwargs = affs_config.pop("global", {})
    #     # TODO: define computed affs not in this way, but with a variable in config...
    #     nb_affs = len(affs_config)
    #     assert nb_affs == num_inputs
    #     # all_affs_kwargs = [deepcopy(global_kwargs) for _ in range(nb_affs)]
    #     for input_index in affs_config:
    #         affs_kwargs = deepcopy(global_kwargs)
    #         affs_kwargs.update(affs_config[input_index])
    #         transforms.add(affinity_config_to_transform(apply_to=[input_index+num_inputs], **affs_kwargs))

    # crop invalid affinity labels and elastic augment reflection padding assymetrically
    crop_config = master_config.get('crop_after_target', {})
    if crop_config:
        # One might need to crop after elastic transform to avoid edge artefacts of affinity
        # computation being warped into the FOV.
        transforms.add(VolumeAsymmetricCrop(**crop_config))

    transforms.add(AsTorchBatch(3, add_channel_axis_if_necessary=True))
    # transforms.add(CheckBatchAndChannelDim(3))
    return transforms
def get_transforms(self):
    """No augmentation: only convert the batch to 3D torch tensors."""
    return AsTorchBatch(3)
def get_additional_transforms(self):
    """Append a 3D torch-batch conversion to the existing pipeline.

    Starts from ``self.transforms`` when set, otherwise from an empty
    ``Compose``.

    :return: the extended transform pipeline.
    """
    if self.transforms is None:
        transforms = Compose()
    else:
        transforms = self.transforms
    transforms.add(AsTorchBatch(3))
    return transforms