Ejemplo n.º 1
0
    def __init__(self,
                 path,
                 defect_augmentation_config,
                 name=None,
                 path_in_file=None,
                 data_slice=None,
                 dtype='float32',
                 ignore_slice_list=None,
                 mean=None,
                 std=None,
                 sigma=None,
                 zero_mean_unit_variance=True,
                 p_augment_ws=0.,
                 **slicing_config):
        """Initialise the parent dataset, then build defect augmentation and a cast.

        Parameters
        ----------
        path, name, path_in_file, data_slice, dtype, mean, std, sigma,
        zero_mean_unit_variance, p_augment_ws, **slicing_config
            Forwarded unchanged to the parent ``__init__``.
        defect_augmentation_config
            Passed through ``yaml2dict`` and then to
            ``DefectAugmentation.from_config`` (presumably a dict or a path to
            a YAML file -- confirm against ``yaml2dict``).
        ignore_slice_list : optional
            Stored under the ``'ignore_slice_list'`` key of the defect
            augmentation config. This always overrides any value already
            present there, including when it is ``None``.
        """
        super().__init__(path=path,
                         path_in_file=path_in_file,
                         data_slice=data_slice,
                         name=name,
                         dtype=dtype,
                         mean=mean,
                         std=std,
                         sigma=sigma,
                         p_augment_ws=p_augment_ws,
                         zero_mean_unit_variance=zero_mean_unit_variance,
                         **slicing_config)

        # Build a new dict rather than calling .update() on yaml2dict's result.
        # yaml2dict may hand back the very dict the caller passed in, and
        # updating it in place would leak 'ignore_slice_list' into the
        # caller's config.
        defect_augmentation_config = {
            **yaml2dict(defect_augmentation_config),
            'ignore_slice_list': ignore_slice_list,
        }
        self.defect_augmentation = DefectAugmentation.from_config(
            defect_augmentation_config)
        self.cast = Cast(self.dtype)
Ejemplo n.º 2
0
 def get_transforms(self):
     """Return a non-inverted NEDT (gain ``self.nedt_gain``) followed by a cast.

     The NEDT is not inverted for ISBI because the labels give neuron
     probabilities rather than boundary probabilities.
     """
     nedt = NegativeExponentialDistanceTransform(gain=self.nedt_gain, invert=False)
     cast = Cast(self.dtype)
     return Compose(nedt, cast)
Ejemplo n.º 3
0
 def get_transforms(self):
     """Normalize, add Gaussian noise (sigma 0.025), then cast to ``self.dtype``."""
     # The noise step comes after Normalize because the raw data comes in as uint8.
     steps = (Normalize(), AdditiveGaussianNoise(sigma=.025), Cast(self.dtype))
     return Compose(*steps)
Ejemplo n.º 4
0
def test_full_pipeline():
    """Smoke-test TikTorch end to end on a 512x512 CREMI crop.

    Steps:
    1. Normalize the first z-slice and cast it to float32.
    2. Print the model's halo and the max shape reported by ``dry_run``.
    3. Run ``forward`` on the slice.

    Requires a TikTorch config and the CREMI HDF5 file at the hard-coded
    paths below. Returns 0 on completion.
    """
    import h5py
    import os
    # Set before the inferno imports below. NOTE(review): presumably this has
    # no effect if CUDA was already initialised elsewhere in the process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    from inferno.io.transform import Compose
    from inferno.io.transform.generic import Normalize, Cast
    tiktorch = TikTorch('/home/jo/config/')

    # Open read-only: older h5py versions default to read/write mode 'a'.
    with h5py.File('/home/jo/sfb1129/sample_C_20160501.hdf', 'r') as f:
        cremi_raw = f['volumes']['raw'][:, 0:512, 0:512]

    transform = Compose(Normalize(), Cast('float32'))
    # A single-element batch holding the normalized first z-slice.
    inputs = [transform(cremi_raw[0:1])]

    halo = tiktorch.halo
    max_shape = tiktorch.dry_run([512, 512])

    print(f'Halo: {halo}')
    print(f'max_shape: {max_shape}')

    tiktorch.forward(inputs)

    return 0
Ejemplo n.º 5
0
 def get_transforms(self):
     """Build the label pipeline: optional CC3D, optional binarization, then a cast.

     ``ConnectedComponents3D`` is included when ``self.label_volume`` is truthy.
     ``BinarizeSegmentation`` is included when ``self.binarize`` is truthy.
     The pipeline always ends with ``Cast(self.dtype)``.
     """
     steps = [ConnectedComponents3D()] if self.label_volume else []
     if self.binarize:
         steps.append(BinarizeSegmentation())
     steps.append(Cast(self.dtype))
     return Compose(*steps)
Ejemplo n.º 6
0
    def make_transforms(self):
        """Build the joint, label and raw pipelines from ``self.master_config``.

        Sets three attributes:

        * ``self.transforms``: pad to ``self.window_size``, random 3D flip and
          rotation, and an optional elastic transform. It also gets a target
          step applied to the segmentation (``apply_to=[1]``); which step is
          chosen by the ``'affinity_config'`` and ``'train_semantic'`` entries.
        * ``self.label_transforms``: one of ``None``, ``Cast('float32')``,
          ``Cast('int64')``, or ``Cast('int64')`` followed by 3D connected
          components.
        * ``self.raw_transforms``: float32 cast, normalization, then additive
          noise with sigma 0.025.
        """
        config = self.master_config
        transforms = Compose(PadTo(self.window_size), RandomFlip3D(), RandomRotate())

        elastic_config = config.get('elastic_transform')
        if elastic_config:
            transforms.add(ElasticTransform(alpha=elastic_config.get('alpha', 2000.),
                                            sigma=elastic_config.get('sigma', 50.),
                                            order=elastic_config.get('order', 0)))

        # Target transforms only touch the segmentation, i.e. the tensor at index 1.
        affinity_config = config.get('affinity_config', None)
        # Whether semantic labels are trained as well.
        train_semantic = config.get('train_semantic', False)

        if affinity_config is None and train_semantic:
            transforms.add(Semantics(apply_to=[1]))
            self.label_transforms = None
        elif affinity_config is None:
            self.label_transforms = Cast('float32')
        elif affinity_config == 'distances':
            # TODO read the bandwidths from the config
            self.label_transforms = Compose(Cast('int64'), ConnectedComponents3D())
            from ..transforms.distance_transform import SignedDistanceTransform
            transforms.add(SignedDistanceTransform(fg_bandwidth=8,
                                                   bg_bandwidth=32,
                                                   apply_to=[1]))
        elif train_semantic:
            # With semantics and affinities trained together, connected
            # components cannot be applied yet, so labels are only cast.
            self.label_transforms = Cast('int64')
            transforms.add(SemanticsAndAffinities(affinity_config, apply_to=[1]))
        else:
            self.label_transforms = Compose(Cast('int64'), ConnectedComponents3D())
            transforms.add(affinity_config_to_transform(apply_to=[1], **affinity_config))

        self.transforms = transforms
        self.raw_transforms = Compose(Cast('float32'), Normalize(), AdditiveNoise(sigma=0.025))
Ejemplo n.º 7
0
    def test_model(self):
        """Run the handler on a CREMI crop and save the first output channel.

        The crop size comes from the handler's ``binary_dry_run``. The image
        is written as a JPEG to a hard-coded path.
        """
        # NOTE(review): if this class is a unittest.TestCase, setUp() has
        # already run before this method, so this call repeats it -- confirm
        # it is needed.
        self.setUp()
        shape = self.handler.binary_dry_run([1250, 1250])
        transform = Compose(Normalize(), Cast('float32'))

        # Open read-only: older h5py versions default to read/write mode 'a'.
        with h5py.File(
                '/export/home/jhugger/sfb1129/sample_C_20160501.hdf', 'r') as f:
            cremi_raw = f['volumes']['raw'][0:1, 0:shape[0], 0:shape[1]]

        input_tensor = torch.from_numpy(transform(cremi_raw[0:1]))
        out = self.handler.forward(torch.unsqueeze(input_tensor, 0))
        # `import scipy` alone does not guarantee the `misc` submodule is loaded.
        import scipy.misc
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2, so this
        # requires an older SciPy (or a different image writer).
        scipy.misc.imsave('/export/home/jhugger/sfb1129/tiktorch/out.jpg',
                          out[0, 0].data.cpu().numpy())
Ejemplo n.º 8
0
 def get_transforms(self, mean, std, sigma, p_augment_ws,
                    zero_mean_unit_variance):
     """Build the raw-data pipeline: cast, normalize, then optional augmentations.

     Parameters
     ----------
     mean, std
         Passed to ``Normalize`` when ``zero_mean_unit_variance`` is true.
     sigma
         If not ``None``, an ``AdditiveNoise(sigma=sigma)`` step is appended.
     p_augment_ws
         If ``> 0``, a ``WatershedAugmentation(p_augment_ws, invert=True)``
         step is appended.
     zero_mean_unit_variance
         Selects ``Normalize(mean=mean, std=std)`` when true and
         ``Normalize01()`` when false.

     Raises
     ------
     RuntimeError
         If ``p_augment_ws > 0`` but ``WatershedAugmentation`` is ``None``.
     """
     transforms = Compose(Cast(self.dtype))
     # add normalization (zero mean / unit variance, or rescaling via Normalize01)
     if zero_mean_unit_variance:
         transforms.add(Normalize(mean=mean, std=std))
     else:
         transforms.add(Normalize01())
     # add noise transform if specified
     if sigma is not None:
         transforms.add(AdditiveNoise(sigma=sigma))
     # add watershed super-pixel augmentation if specified
     if p_augment_ws > 0.:
         # Explicit check instead of `assert`, which is stripped under `python -O`.
         if WatershedAugmentation is None:
             raise RuntimeError(
                 "p_augment_ws > 0 requires WatershedAugmentation, "
                 "but it is not available")
         transforms.add(WatershedAugmentation(p_augment_ws, invert=True))
     return transforms
Ejemplo n.º 9
0
 def get_transforms(self):
     """Build the label pipeline: cast, relabel, then compute affinity maps.

     Steps, in order:

     1. ``Cast(self.dtype)``.
     2. ``ConnectedComponents3D(label_segmentation=True)`` to relabel.
     3. ``Segmentation2MultiOrderAffinities`` with ``dim=self.affinity_dim``
        and ``orders=pyu.to_iterable(self.affinity_order)``.

     Step 3 is used regardless of how many orders were requested. It is
     configured to add a singleton channel dimension.
     """
     affinities = Segmentation2MultiOrderAffinities(
         dim=self.affinity_dim,
         orders=pyu.to_iterable(self.affinity_order),
         add_singleton_channel_dimension=True,
         retain_segmentation=self.retain_segmentation)
     pipeline = Compose()
     pipeline.add(Cast(self.dtype))
     pipeline.add(ConnectedComponents3D(label_segmentation=True))
     pipeline.add(affinities)
     return pipeline
Ejemplo n.º 10
0
def test_dunet():
    """Smoke-test TikTorch ``forward`` on the first two slices of a CREMI crop.

    Each slice is normalized and cast to float32 before the call.
    Requires a TikTorch config and the CREMI HDF5 file at the hard-coded
    paths below. Returns 0 on completion.
    """
    import h5py
    import os
    # Set before the inferno imports below. NOTE(review): presumably this has
    # no effect if CUDA was already initialised elsewhere in the process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    from inferno.io.transform import Compose
    from inferno.io.transform.generic import Normalize, Cast
    tiktorch = TikTorch('/home/jo/config/')

    # Open read-only: older h5py versions default to read/write mode 'a'.
    with h5py.File('/home/jo/sfb1129/sample_C_20160501.hdf', 'r') as f:
        cremi_raw = f['volumes']['raw'][:, 0:1024, 0:1024]

    transform = Compose(Normalize(), Cast('float32'))
    # One transformed single-slice input per z index in {0, 1}.
    inputs = [transform(cremi_raw[z:z + 1]) for z in range(2)]

    tiktorch.forward(inputs)
    return 0
Ejemplo n.º 11
0
    def test_model(self):
        """Run the handler on a (1, 1, 572, 572) input and save the first output channel.

        The same image is written as a JPEG to two hard-coded paths.
        """
        # NOTE(review): if this class is a unittest.TestCase, setUp() has
        # already run before this method, so this call repeats it -- confirm.
        self.setUp()
        transform = Compose(Normalize(), Cast("float32"))

        # Open read-only: older h5py versions default to read/write mode 'a'.
        with h5py.File(
                "/export/home/jhugger/sfb1129/sample_C_20160501.hdf", "r") as f:
            cremi_raw = f["volumes"]["raw"][0:1, 0:1248, 0:1248]

        input_tensor = torch.from_numpy(transform(cremi_raw[0:1]))
        # NOTE(review): the CREMI-derived tensor above is discarded here and
        # replaced by random data, so the model never sees the real crop --
        # confirm which input is intended.
        input_tensor = torch.rand(1, 572, 572)
        batch = torch.unsqueeze(input_tensor, 0)
        print(batch.shape)
        out = self.handler.forward(batch)
        # `import scipy` alone does not guarantee the `misc` submodule is loaded.
        import scipy.misc

        image = out[0, 0].data.cpu().numpy()
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2, so this
        # requires an older SciPy (or a different image writer).
        scipy.misc.imsave("/export/home/jhugger/sfb1129/tiktorch/out.jpg", image)
        scipy.misc.imsave("/home/jo/server/tiktorch/out.jpg", image)
Ejemplo n.º 12
0
 def get_transforms(self, mean, std):
     """Cast to ``self.dtype``, then normalize with the given ``mean``/``std``."""
     cast = Cast(self.dtype)
     normalize = Normalize(mean=mean, std=std)
     return Compose(cast, normalize)
Ejemplo n.º 13
0
 def get_transforms(self):
     """Return a non-inverted NEDT (gain ``self.nedt_gain``) followed by a cast to ``self.dtype``."""
     distance = NegativeExponentialDistanceTransform(gain=self.nedt_gain, invert=False)
     return Compose(distance, Cast(self.dtype))
Ejemplo n.º 14
0
 def get_transforms(self):
     """Convert segmentation to membranes, apply NEDT, then cast to ``self.dtype``."""
     to_membranes = Segmentation2Membranes(dtype=self.dtype)
     nedt = NegativeExponentialDistanceTransform(gain=self.nedt_gain)
     cast = Cast(self.dtype)
     return Compose(to_membranes, nedt, cast)
Ejemplo n.º 15
0
 def get_transforms(self):
     """Apply connected components, then cast to ``self.dtype``.

     2D components are used when ``self.apply_on_image`` is truthy; 3D otherwise.
     """
     components_cls = ConnectedComponents2D if self.apply_on_image else ConnectedComponents3D
     return Compose(components_cls(), Cast(self.dtype))
Ejemplo n.º 16
0
 def get_transforms(self):
     """Relabel with 3D connected components, then cast to ``self.dtype``."""
     relabel = ConnectedComponents3D(label_segmentation=True)
     return Compose(relabel, Cast(self.dtype))
Ejemplo n.º 17
0
 def get_transforms(self):
     """Return the label pipeline as a ``Compose``, whatever the configuration.

     ``ConnectedComponents3D`` is included when ``self.label_volume`` is
     truthy. The pipeline always ends with ``Cast(self.dtype)``.

     Previously the ``label_volume``-false branch returned a bare ``Cast``.
     Returning a ``Compose`` in both cases matches the other label pipelines
     and lets callers rely on ``Compose`` methods such as ``.add``.
     """
     transforms = Compose()
     if self.label_volume:
         transforms.add(ConnectedComponents3D())
     transforms.add(Cast(self.dtype))
     return transforms
Ejemplo n.º 18
0
 def get_transforms(self):
     """Return a pipeline whose only step casts to ``self.dtype``."""
     cast_only = Cast(self.dtype)
     return Compose(cast_only)
Ejemplo n.º 19
0
 def get_transforms(self):
     """Normalize, then cast to ``self.dtype``."""
     normalize, cast = Normalize(), Cast(self.dtype)
     return Compose(normalize, cast)