Ejemplo n.º 1
0
def custom_unet(input_shape,
                last_activation,
                n_depth=2,
                n_filter_base=16,
                kernel_size=(3, 3, 3),
                n_conv_per_depth=2,
                activation="relu",
                batch_norm=False,
                dropout=0.0,
                pool_size=(2, 2, 2),
                n_channel_out=1,
                residual=False,
                prob_out=False,
                eps_scale=1e-3):
    """Build a U-Net Keras model with a 1x1 convolutional output head.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample (no batch axis), passed to ``Input``.
    last_activation : str
        Activation applied to the output head; must not be None.
    n_depth, n_filter_base, kernel_size, n_conv_per_depth, activation,
    batch_norm, dropout, pool_size
        Forwarded to ``unet_block``. ``kernel_size`` must be odd along every
        axis; a length of 2 selects ``Conv2D``, otherwise ``Conv3D`` is used.
    n_channel_out : int
        Number of output channels of the head.
    residual : bool
        If true, the input is added to the linear head output before
        ``last_activation`` (input and output channel counts must match).
    prob_out : bool
        If true, a second softplus head shifted by ``eps_scale`` is concatenated
        to the output along the channel axis.
    eps_scale : float
        Constant added to the softplus head.

    Returns
    -------
    Model
        Keras model mapping the input to the (possibly concatenated) output.

    Raises
    ------
    ValueError
        If ``last_activation`` is None, or ``residual`` is requested with
        mismatching channel counts. An even kernel size is reported by passing
        a ValueError to ``_raise``.
    """
    if last_activation is None:
        raise ValueError(
            "last activation has to be given (e.g. 'sigmoid', 'relu')!")

    if not all(k % 2 == 1 for k in kernel_size):
        _raise(ValueError('kernel size should be odd in all dimensions.'))

    channels_last = backend_channels_last()
    channel_axis = -1 if channels_last else 1

    n_dim = len(kernel_size)
    conv = Conv2D if n_dim == 2 else Conv3D
    pointwise = (1, ) * n_dim

    # "input_layer" rather than "input" to avoid shadowing the builtin
    input_layer = Input(input_shape, name="input")
    features = unet_block(n_depth,
                          n_filter_base,
                          kernel_size,
                          activation=activation,
                          dropout=dropout,
                          batch_norm=batch_norm,
                          n_conv_per_depth=n_conv_per_depth,
                          pool=pool_size)(input_layer)

    output = conv(n_channel_out, pointwise, activation='linear')(features)
    if residual:
        # the channel axis of input_shape depends on the backend layout
        n_channel_in = input_shape[-1] if channels_last else input_shape[0]
        if n_channel_out != n_channel_in:
            raise ValueError(
                "number of input and output channels must be the same for a residual net."
            )
        output = Add()([output, input_layer])
    output = Activation(activation=last_activation)(output)

    if prob_out:
        scale = conv(n_channel_out, pointwise, activation='softplus')(features)
        scale = Lambda(lambda x: x + np.float32(eps_scale))(scale)
        output = Concatenate(axis=channel_axis)([output, scale])

    return Model(inputs=input_layer, outputs=output)
Ejemplo n.º 2
0
    def predict(self, img, resizer=PadAndCropResizer(), **predict_kwargs):
        """Predict object probabilities and star-convex distances for one image.

        Parameters
        ----------
        img : :class:`numpy.ndarray`
            Input image, either 2D (a channel axis is added) or 3D with the
            channel axis in the backend's channel position.
        resizer : :class:`csbdeep.data.Resizer` or None
            If necessary, input image is resized to enable neural network prediction and result is (possibly)
            resized to yield original image size. ``None`` means no resizing.
        **predict_kwargs
            Forwarded to ``self.keras_model.predict``.

        Returns
        -------
        (:class:`numpy.ndarray`,:class:`numpy.ndarray`)
            Returns the tuple (`prob`, `dist`) of per-pixel object probabilities and star-convex polygon distances.
            The channel axis is dropped from `prob` and moved last in `dist`.

        Raises
        ------
        ValueError
            Passed to ``_raise`` if ``resizer`` is not a Resizer, ``img`` is not
            2D/3D, or its channel count differs from ``config.n_channel_in``.
        """
        if resizer is None:
            resizer = NoResizer()
        if not isinstance(resizer, Resizer):
            _raise(ValueError())

        if img.ndim not in (2, 3):
            _raise(ValueError())

        channels_last = backend_channels_last()

        x = img
        if x.ndim == 2:
            x = np.expand_dims(x, (-1 if channels_last else 0))

        if channels_last:
            channel, axes = x.ndim - 1, 'YXC'
        else:
            channel, axes = 0, 'CYX'
        if self.config.n_channel_in != x.shape[channel]:
            _raise(ValueError())

        # resize: make each spatial axis divisible by 2**depth so the U-Net
        # downsampling steps line up; the channel axis is left alone
        divisor = 2**self.config.unet_n_depth
        axes_div_by = tuple(1 if a == 'C' else divisor for a in axes)
        x = resizer.before(x, axes, axes_div_by)

        # placeholder for the model's second (mask) input; its values are
        # uninitialized, so presumably it does not affect the outputs used here
        mask_shape = (x.shape[:-1] + (1, )) if channels_last else ((1, ) + x.shape[1:])
        dummy = np.empty((1, ) + mask_shape, np.float32)

        batch = np.expand_dims(x, 0)
        prob, dist = self.keras_model.predict([batch, dummy], **predict_kwargs)

        prob = resizer.after(prob[0], axes)
        dist = resizer.after(dist[0], axes)

        return np.take(prob, 0, axis=channel), np.moveaxis(dist, channel, -1)
Ejemplo n.º 3
0
    def __init__(self, n_rays=32, n_channel_in=1, **kwargs):
        """See class docstring."""

        # values fixed by the constructor arguments
        self.n_rays = n_rays
        self.n_channel_in = int(n_channel_in)

        # network defaults; any of these may be overridden through kwargs
        self.unet_n_depth = 3
        self.unet_kernel_size = (3, 3)
        self.unet_n_filter_base = 32
        self.net_conv_after_unet = 128

        # the image input has n_channel_in channels and the mask input one;
        # the channel axis is placed last or first according to the backend
        spatial = (None, None)
        if backend_channels_last():
            self.net_input_shape = spatial + (self.n_channel_in, )
            self.net_mask_shape = spatial + (1, )
        else:
            self.net_input_shape = (self.n_channel_in, ) + spatial
            self.net_mask_shape = (1, ) + spatial

        # training defaults
        self.train_shape_completion = False
        self.train_completion_crop = 32
        self.train_patch_size = (256, 256)

        self.train_dist_loss = 'mae'
        self.train_epochs = 100
        self.train_steps_per_epoch = 400
        self.train_learning_rate = 0.0003
        self.train_batch_size = 4
        self.train_tensorboard = True
        self.train_checkpoint = 'weights_best.h5'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 10}

        # caller-supplied overrides (no validation is performed)
        for name, value in kwargs.items():
            setattr(self, name, value)
Ejemplo n.º 4
0
    def _build_unet(self):
        """Assemble the 3D network with a U-Net backbone.

        The input image is brought to ``config.grid`` resolution by repeated
        conv + max-pool stages, then fed through ``unet_block`` (configured by
        every ``unet_*`` config attribute), an optional ``features`` convolution
        and two 1x1x1 output heads.

        Returns
        -------
        Model
            Keras model with inputs ``[input, dist_mask]`` and outputs
            ``[prob, dist]``.
        """
        assert self.config.backbone == 'unet'
        cfg = self.config

        input_img = Input(cfg.net_input_shape, name='input')

        # the mask input lives at the grid-subsampled resolution
        channels_last = backend_channels_last()
        mask_spatial = cfg.net_mask_shape[:-1] if channels_last else cfg.net_mask_shape[1:]
        reduced = tuple(None if n is None else n // g
                        for g, n in zip(cfg.grid, mask_spatial))
        grid_shape = (reduced + (1,)) if channels_last else ((1,) + reduced)
        input_mask = Input(grid_shape, name='dist_mask')

        # every 'unet_<name>' attribute becomes keyword '<name>' of unet_block
        prefix = 'unet_'
        unet_kwargs = {key[len(prefix):]: val
                       for key, val in vars(cfg).items() if key.startswith(prefix)}

        # maxpool input image to grid size: pool by 2 along every axis whose
        # cumulative pooling factor is still below the grid, until they match
        target = tuple(cfg.grid)
        pooled = np.array([1,1,1])
        pooled_img = input_img
        while tuple(pooled) != target:
            pool = 1 + (np.asarray(cfg.grid) > pooled)
            pooled *= pool
            for _ in range(cfg.unet_n_conv_per_depth):
                pooled_img = Conv3D(cfg.unet_n_filter_base, cfg.unet_kernel_size,
                                    padding="same", activation=cfg.unet_activation)(pooled_img)
            pooled_img = MaxPooling3D(pool)(pooled_img)

        unet = unet_block(**unet_kwargs)(pooled_img)
        if cfg.net_conv_after_unet > 0:
            unet = Conv3D(cfg.net_conv_after_unet, cfg.unet_kernel_size,
                          name='features', padding='same', activation=cfg.unet_activation)(unet)

        output_prob = Conv3D(1, (1,1,1), name='prob', padding='same', activation='sigmoid')(unet)
        output_dist = Conv3D(cfg.n_rays, (1,1,1), name='dist', padding='same', activation='linear')(unet)
        return Model([input_img,input_mask], [output_prob,output_dist])
Ejemplo n.º 5
0
    def __init__(self, axes='YX', n_rays=32, n_channel_in=1, grid=(1,1), backbone='unet', **kwargs):
        """See class docstring."""

        super().__init__(axes=axes, n_channel_in=n_channel_in, n_channel_out=1+n_rays)

        # values fixed by the constructor arguments
        self.n_rays   = int(n_rays)
        self.grid     = _normalize_grid(grid,2)
        self.backbone = str(backbone).lower()

        # only the U-Net backbone is available for the 2D model
        if self.backbone != 'unet':
            # TODO: resnet backbone for 2D model?
            raise ValueError("backbone '%s' not supported." % self.backbone)

        # backbone defaults (can be overwritten by kwargs below)
        self.unet_n_depth          = 3
        self.unet_kernel_size      = (3, 3)
        self.unet_n_filter_base    = 32
        self.unet_n_conv_per_depth = 2
        self.unet_pool             = (2, 2)
        self.unet_activation       = 'relu'
        self.unet_last_activation  = 'relu'
        self.unet_batch_norm       = False
        self.unet_dropout          = 0.0
        self.unet_prefix           = ''
        self.net_conv_after_unet   = 128

        # channel axis placed according to the backend's data layout
        if backend_channels_last():
            self.net_input_shape = (None, None, self.n_channel_in)
            self.net_mask_shape  = (None, None, 1)
        else:
            self.net_input_shape = (self.n_channel_in, None, None)
            self.net_mask_shape  = (1, None, None)

        self.train_shape_completion = False
        self.train_completion_crop  = 32
        self.train_patch_size       = (256, 256)
        self.train_background_reg   = 1e-4

        self.train_dist_loss       = 'mae'
        self.train_loss_weights    = (1, 0.2)
        self.train_epochs          = 400
        self.train_steps_per_epoch = 100
        self.train_learning_rate   = 0.0003
        self.train_batch_size      = 4
        self.train_n_val_patches   = None
        self.train_tensorboard     = True
        # the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
        old_keras = LooseVersion(keras.__version__) <= LooseVersion('2.1.5')
        min_delta_key = 'epsilon' if old_keras else 'min_delta'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 40, min_delta_key: 0}

        self.use_gpu = False

        # derived attributes must not be overwritten by the caller
        for derived in ('n_dim', 'n_channel_out'):
            kwargs.pop(derived, None)

        self.update_parameters(False, **kwargs)
Ejemplo n.º 6
0
    def _build_resnet(self):
        """Assemble the 3D network with a ResNet backbone.

        Two plain convolutions (7x7x7 then 3x3x3) are followed by
        ``config.resnet_n_blocks`` calls to ``resnet_block``. A block pools by 2
        along the axes whose cumulative pooling is still below ``config.grid``,
        and any block that pools uses twice the previous filter count. An
        optional ``features`` convolution and two 1x1x1 heads follow.

        Returns
        -------
        Model
            Keras model with inputs ``[input, dist_mask]`` and outputs
            ``[prob, dist]``.
        """
        assert self.config.backbone == 'resnet'
        cfg = self.config

        input_img = Input(cfg.net_input_shape, name='input')

        # the mask input lives at the grid-subsampled resolution
        channels_last = backend_channels_last()
        mask_spatial = cfg.net_mask_shape[:-1] if channels_last else cfg.net_mask_shape[1:]
        reduced = tuple(None if n is None else n // g
                        for g, n in zip(cfg.grid, mask_spatial))
        grid_shape = (reduced + (1, )) if channels_last else ((1, ) + reduced)
        input_mask = Input(grid_shape, name='dist_mask')

        init = cfg.resnet_kernel_init
        block_kwargs = dict(
            kernel_size=cfg.resnet_kernel_size,
            n_conv_per_block=cfg.resnet_n_conv_per_block,
            batch_norm=cfg.resnet_batch_norm,
            kernel_initializer=init,
            activation=cfg.resnet_activation,
        )

        n_filter = cfg.resnet_n_filter_base
        layer = Conv3D(n_filter, (7, 7, 7), padding="same",
                       kernel_initializer=init)(input_img)
        layer = Conv3D(n_filter, (3, 3, 3), padding="same",
                       kernel_initializer=init)(layer)

        grid = np.asarray(cfg.grid)
        pooled = np.array([1, 1, 1])
        for _ in range(cfg.resnet_n_blocks):
            # stride 2 along axes not yet pooled down to the grid, 1 elsewhere
            pool = 1 + (grid > pooled)
            pooled *= pool
            if any(p > 1 for p in pool):
                n_filter *= 2
            layer = resnet_block(n_filter, pool=tuple(pool), **block_kwargs)(layer)

        if cfg.net_conv_after_resnet > 0:
            layer = Conv3D(cfg.net_conv_after_resnet,
                           cfg.resnet_kernel_size,
                           name='features',
                           padding='same',
                           activation=cfg.resnet_activation)(layer)

        output_prob = Conv3D(1, (1, 1, 1), name='prob', padding='same',
                             activation='sigmoid')(layer)
        output_dist = Conv3D(cfg.n_rays, (1, 1, 1), name='dist', padding='same',
                             activation='linear')(layer)
        return Model([input_img, input_mask], [output_prob, output_dist])
Ejemplo n.º 7
0
def resnet_block(n_filter,
                 kernel_size=(3, 3),
                 pool=(1, 1),
                 n_conv_per_block=2,
                 batch_norm=False,
                 kernel_initializer='he_normal',
                 activation='relu'):
    """Create a residual block as a function of its input tensor.

    The main branch is ``n_conv_per_block`` convolutions. The first uses
    ``strides=pool``. Each is optionally followed by batch normalization, and
    every one except the last is followed by ``activation``. The main branch
    is added to a shortcut and activated once more. The shortcut is the input
    itself, or a 1x1 (strided) convolution of it when the block pools or when
    the input channel count differs from ``n_filter``.

    Parameters
    ----------
    n_filter : int
        Number of output filters of every convolution in the block.
    kernel_size : tuple
        Convolution kernel size; its length (2 or 3) selects Conv2D/Conv3D.
    pool : tuple
        Stride of the first convolution and of the shortcut projection.
    n_conv_per_block : int
        Number of convolutions in the main branch (at least 2).
    batch_norm : bool
        Insert BatchNormalization after each convolution (and drop the bias).
    kernel_initializer : str
        Initializer for all convolution kernels.
    activation : str
        Activation used inside and at the end of the block.

    Returns
    -------
    callable
        ``f(inp)`` returning the block output tensor.

    Raises
    ------
    ValueError
        Passed to ``_raise`` for ``n_conv_per_block < 2``, mismatching kernel
        and pool lengths, or a dimensionality other than 2 or 3.
    """
    n_conv_per_block >= 2 or _raise(
        ValueError('required: n_conv_per_block >= 2'))
    len(pool) == len(kernel_size) or _raise(
        ValueError('kernel and pool sizes must match.'))
    n_dim = len(kernel_size)
    n_dim in (2, 3) or _raise(ValueError('resnet_block only 2d or 3d.'))

    conv_layer = Conv2D if n_dim == 2 else Conv3D
    conv_kwargs = dict(
        padding='same',
        # a bias is redundant when batch normalization follows the conv
        use_bias=not batch_norm,
        kernel_initializer=kernel_initializer,
    )
    channel_axis = -1 if backend_channels_last() else 1

    def f(inp):
        # first convolution carries the (optional) downsampling stride
        x = conv_layer(n_filter, kernel_size, strides=pool, **conv_kwargs)(inp)
        if batch_norm:
            x = BatchNormalization(axis=channel_axis)(x)
        x = Activation(activation)(x)

        for _ in range(n_conv_per_block - 2):
            x = conv_layer(n_filter, kernel_size, **conv_kwargs)(x)
            if batch_norm:
                x = BatchNormalization(axis=channel_axis)(x)
            x = Activation(activation)(x)

        # last convolution: activation is applied only after the residual sum
        x = conv_layer(n_filter, kernel_size, **conv_kwargs)(x)
        if batch_norm:
            x = BatchNormalization(axis=channel_axis)(x)

        # project the shortcut when the block changes resolution or channel
        # count; the channel count is read from the backend's channel axis
        # (the last axis is a spatial axis for channels-first layouts)
        if any(p != 1 for p in pool) or n_filter != K.int_shape(inp)[channel_axis]:
            inp = conv_layer(n_filter, (1, ) * n_dim,
                             strides=pool,
                             **conv_kwargs)(inp)

        x = Add()([inp, x])
        x = Activation(activation)(x)
        return x

    return f
Ejemplo n.º 8
0
    def _build(self):
        """Assemble the 2D network for the 'unet' or 'unet2' backbone.

        The input image is brought to ``config.grid`` resolution by conv +
        max-pool stages. It is then passed through ``unet_block`` ('unet') or
        ``unet_block2`` ('unet2'), configured by every ``unet_*`` config
        attribute. An optional ``features`` convolution follows; for 'unet2'
        only, it is followed by a dropout layer. The heads are a sigmoid
        ``prob`` output and a ``dist`` output, which is linear or, when
        ``config.y_range`` is set, constrained to that range.

        Returns
        -------
        Model
            Keras model with inputs ``[input, dist_mask]`` and outputs
            ``[prob, dist]``.
        """
        cfg = self.config
        if cfg.backbone not in ('unet', 'unet2'):
            _raise(NotImplementedError())
        use_unet2 = cfg.backbone == 'unet2'

        input_img = Input(cfg.net_input_shape, name='input')

        # the mask input lives at the grid-subsampled resolution
        channels_last = backend_channels_last()
        mask_spatial = cfg.net_mask_shape[:-1] if channels_last else cfg.net_mask_shape[1:]
        reduced = tuple(None if n is None else n // g
                        for g, n in zip(cfg.grid, mask_spatial))
        grid_shape = (reduced + (1,)) if channels_last else ((1,) + reduced)
        input_mask = Input(grid_shape, name='dist_mask')

        # every 'unet_<name>' attribute becomes keyword '<name>' of the block
        prefix = 'unet_'
        unet_kwargs = {key[len(prefix):]: val
                       for key, val in vars(cfg).items() if key.startswith(prefix)}

        # maxpool input image to grid size: pool by 2 along every axis whose
        # cumulative pooling factor is still below the grid, until they match
        target = tuple(cfg.grid)
        pooled = np.array([1,1])
        pooled_img = input_img
        while tuple(pooled) != target:
            pool = 1 + (np.asarray(cfg.grid) > pooled)
            pooled *= pool
            for _ in range(cfg.unet_n_conv_per_depth):
                pooled_img = Conv2D(cfg.unet_n_filter_base, cfg.unet_kernel_size,
                                    padding='same', activation=cfg.unet_activation)(pooled_img)
            pooled_img = MaxPooling2D(pool)(pooled_img)

        backbone_block = unet_block2 if use_unet2 else unet_block
        unet = backbone_block(**unet_kwargs)(pooled_img)

        if cfg.net_conv_after_unet > 0:
            unet = Conv2D(cfg.net_conv_after_unet, cfg.unet_kernel_size,
                          name='features', padding='same', activation=cfg.unet_activation)(unet)
            if use_unet2:
                # extra dropout layer after the feature layer (unet2 only)
                unet = Dropout(rate = cfg.feature_dropout)(unet)

        if use_unet2:
            output_prob = Conv2D(1, (1,1), name='prob', padding='same', activation='sigmoid',
                                 kernel_initializer='glorot_uniform')(unet)
        else:
            output_prob = Conv2D(1, (1,1), name='prob', padding='same', activation='sigmoid')(unet)

        if cfg.y_range is None:
            output_dist = Conv2D(cfg.n_rays, (1,1), name='dist', padding='same', activation='linear')(unet)
        elif use_unet2:
            # linear conv followed by a RangedSig layer that carries the 'dist' name
            output_dist = Conv2D(cfg.n_rays, (1,1), name='linear', padding='same', activation='linear')(unet)
            output_dist = RangedSig(y_min=cfg.y_range[0], y_max=cfg.y_range[1], name='dist')(output_dist)
        else:
            y_min = cfg.y_range[0]
            y_max = cfg.y_range[1]
            output_dist = Conv2D(cfg.n_rays, (1,1), name='dist', padding='same',
                                 activation=Activation(lambda x: output_to_y_range(x, y_min, y_max)))(unet)

        return Model([input_img,input_mask], [output_prob,output_dist])
Ejemplo n.º 9
0
    def __init__(self, axes='YX', n_rays=32, n_channel_in=1, grid=(1,1), backbone='unet', **kwargs):
        """See class docstring."""

        super().__init__(axes=axes, n_channel_in=n_channel_in, n_channel_out=1+n_rays)

        # values fixed by the constructor arguments
        self.n_rays   = int(n_rays)
        self.grid     = _normalize_grid(grid,2)
        self.backbone = str(backbone).lower()

        # only the U-Net style backbones are available for the 2D model
        if self.backbone not in ('unet', 'unet2'):
            # TODO: resnet backbone for 2D model?
            raise ValueError("backbone '%s' not supported." % self.backbone)

        # backbone defaults (can be overwritten by kwargs below)
        self.unet_n_depth          = 3
        self.unet_kernel_size      = (3, 3)
        self.unet_n_filter_base    = 32
        self.unet_n_conv_per_depth = 2
        self.unet_pool             = (2, 2)
        self.unet_activation       = 'relu'
        self.unet_last_activation  = 'relu'
        self.unet_batch_norm       = False
        self.unet_dropout          = 0.0
        self.unet_prefix           = ''
        self.net_conv_after_unet   = 128
        if self.backbone == 'unet2':
            # kernel initialization for the unet2 block (collected with the
            # other 'unet_' attributes as keyword 'kernel_init')
            self.unet_kernel_init = 'he_uniform'

        # channel axis placed according to the backend's data layout
        if backend_channels_last():
            self.net_input_shape = (None, None, self.n_channel_in)
            self.net_mask_shape  = (None, None, 1)
        else:
            self.net_input_shape = (self.n_channel_in, None, None)
            self.net_mask_shape  = (1, None, None)

        self.train_shape_completion = False
        self.train_completion_crop  = 32
        self.train_patch_size       = (256, 256)
        # whether the dist loss will be normalized by mean(mask)
        self.norm_by_mask           = True
        self.train_background_reg   = 1e-4
        self.train_foreground_only  = 0.9

        self.train_dist_loss       = 'mae'
        self.train_loss_weights    = (1, 0.2)
        self.train_epochs          = 400
        self.train_steps_per_epoch = 100
        self.train_learning_rate   = 0.0003
        self.train_batch_size      = 4
        self.train_n_val_patches   = None
        self.train_tensorboard     = True
        # the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
        old_keras = LooseVersion(keras.__version__) <= LooseVersion('2.1.5')
        min_delta_key = 'epsilon' if old_keras else 'min_delta'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 40, min_delta_key: 0}
        # maximum learning rate for a one-cycle schedule (None: not used)
        self.train_one_cycle_lr_max = None
        # optional (min, max) range constraining the predicted distances
        # (None: unconstrained); e.g. [0.0, train_patch_size[0]/(2*grid[0])]
        self.y_range = None
        # EDT probability threshold: prob values below it are set to zero
        self.EDT_prob_threshold = 0
        # dropout rate of the layer after the feature layer
        self.feature_dropout = 0
        # EDT border constraint
        self.EDT_border_R = 9
        self.EDT_black_border = False

        self.use_gpu = False

        # derived attributes must not be overwritten by the caller
        for derived in ('n_dim', 'n_channel_out'):
            kwargs.pop(derived, None)

        self.update_parameters(False, **kwargs)
Ejemplo n.º 10
0
    def __init__(self, X, **kwargs):
        """See class docstring"""

        # X is empty if config is None: the data-derived attributes and
        # defaults below are then skipped and only kwargs are applied
        if X.size != 0:
            assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

            n_dim = len(X.shape) - 2

            if n_dim == 2:
                axes = 'SYXC'
            elif n_dim == 3:
                axes = 'SZYXC'

            # parse and check axes
            axes = axes_check_and_normalize(axes)
            ax = axes_dict(axes)
            ax = {a: (ax[a] is not None) for a in ax}

            (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
            not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

            axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
            axes = axes.replace('S', '')  # remove sample axis if it exists

            if backend_channels_last():
                if ax['C']:
                    axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
                else:
                    axes += 'C'
            else:
                if ax['C']:
                    axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
                else:
                    axes = 'C' + axes

            # normalization parameters: per-channel statistics over the last
            # axis of X, stored as strings
            n_channels = X.shape[-1]
            self.means = [str(np.mean(X[..., i])) for i in range(n_channels)]
            self.stds = [str(np.std(X[..., i])) for i in range(n_channels)]
            # directly set by parameters
            self.n_dim = n_dim
            self.axes = axes
            # fixed parameters
            self.n_channel_in = 1
            self.n_channel_out = 4
            self.train_loss = 'denoiseg'

            # default config (can be overwritten by kwargs below)

            self.unet_n_depth = 4
            self.relative_weights = [1.0, 1.0, 5.0]
            self.unet_kern_size = 3
            self.unet_n_first = 32
            self.unet_last_activation = 'linear'
            self.probabilistic = False
            self.unet_residual = False
            if backend_channels_last():
                self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
            else:
                self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

            self.train_epochs = 200
            self.train_steps_per_epoch = 400
            self.train_learning_rate = 0.0004
            self.train_batch_size = 128
            self.train_tensorboard = False
            self.train_checkpoint = 'weights_best.h5'
            self.train_checkpoint_last  = 'weights_last.h5'
            self.train_checkpoint_epoch = 'weights_now.h5'
            self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
            self.batch_norm = True
            self.n2v_perc_pix = 1.5
            self.n2v_patch_shape = (64, 64) if self.n_dim == 2 else (64, 64, 64)
            self.n2v_manipulator = 'uniform_withCP'
            self.n2v_neighborhood_radius = 5
            self.denoiseg_alpha = 0.5

        # 'probabilistic' and 'unet_residual' cannot be chosen by the caller:
        # they are forced to False regardless of what kwargs contained
        # (plain dict assignment; it cannot fail, so no try/except is needed)
        kwargs['probabilistic'] = False
        kwargs['unet_residual'] = False

        for k, v in kwargs.items():
            setattr(self, k, v)
Ejemplo n.º 11
0
    def __init__(self, axes='ZYX', rays=None, n_channel_in=1, grid=(1,1,1), anisotropy=None, backbone='resnet', **kwargs):
        """Configuration of the 3D model.

        Parameters
        ----------
        axes : str
            Axes of the input images (forwarded to the base config).
        rays : rays object, scalar or None
            Ray directions. If None, they are rebuilt from
            ``kwargs['rays_json']``. Failing that, ``Rays_GoldenSpiral`` is
            built with ``kwargs['n_rays']`` rays, or with 96 rays otherwise. A
            scalar is the number of golden-spiral rays. The object must
            support ``len()`` and ``to_json()``.
        n_channel_in : int
            Number of input channels.
        grid : tuple
            Subsampling grid, normalized by ``_normalize_grid(grid, 3)``.
        anisotropy : sequence or None
            Stored as a tuple. It is copied into the rays' JSON description
            when the rays' own anisotropy is None.
        backbone : str
            'unet' or 'resnet' (case-insensitive).
        **kwargs
            Overrides for the defaults, applied with ``update_parameters``. The
            derived keys 'n_dim', 'n_channel_out', 'n_rays' and 'rays_json' are
            discarded, after 'n_rays'/'rays_json' may have been used to build
            the rays.

        Raises
        ------
        ValueError
            If ``backbone`` is not supported.
        """
        if rays is None:
            if 'rays_json' in kwargs:
                rays = rays_from_json(kwargs['rays_json'])
            elif 'n_rays' in kwargs:
                rays = Rays_GoldenSpiral(kwargs['n_rays'])
            else:
                rays = Rays_GoldenSpiral(96)
        elif np.isscalar(rays):
            rays = Rays_GoldenSpiral(rays)

        super().__init__(axes=axes, n_channel_in=n_channel_in, n_channel_out=1+len(rays))

        # directly set by parameters
        self.n_rays                    = len(rays)
        self.grid                      = _normalize_grid(grid,3)
        self.anisotropy                = anisotropy if anisotropy is None else tuple(anisotropy)
        self.backbone                  = str(backbone).lower()
        self.rays_json                 = rays.to_json()

        rays_kwargs = self.rays_json['kwargs']
        if 'anisotropy' in rays_kwargs:
            rays_anisotropy = rays_kwargs['anisotropy']
            if rays_anisotropy is None and self.anisotropy is not None:
                rays_kwargs['anisotropy'] = self.anisotropy
                print("Changing 'anisotropy' of rays to %s" % str(anisotropy))
            else:
                # compare as tuples: self.anisotropy is always a tuple (or None),
                # whereas rays restored from JSON may carry a list (or an array),
                # which would otherwise always compare unequal
                if isinstance(rays_anisotropy, (list, tuple, np.ndarray)):
                    rays_anisotropy = tuple(rays_anisotropy)
                if rays_anisotropy != self.anisotropy:
                    warnings.warn("Mismatch of 'anisotropy' of rays and 'anisotropy'.")

        # default config (can be overwritten by kwargs below)
        if self.backbone == 'unet':
            self.unet_n_depth            = 2
            self.unet_kernel_size        = 3,3,3
            self.unet_n_filter_base      = 32
            self.unet_n_conv_per_depth   = 2
            self.unet_pool               = 2,2,2
            self.unet_activation         = 'relu'
            self.unet_last_activation    = 'relu'
            self.unet_batch_norm         = False
            self.unet_dropout            = 0.0
            self.unet_prefix             = ''
            self.net_conv_after_unet     = 128
        elif self.backbone == 'resnet':
            self.resnet_n_blocks         = 4
            self.resnet_kernel_size      = 3,3,3
            self.resnet_kernel_init      = 'he_normal'
            self.resnet_n_filter_base    = 32
            self.resnet_n_conv_per_block = 3
            self.resnet_activation       = 'relu'
            self.resnet_batch_norm       = False
            self.net_conv_after_resnet   = 128
        else:
            raise ValueError("backbone '%s' not supported." % self.backbone)

        # channel axis placed according to the backend's data layout
        if backend_channels_last():
            self.net_input_shape       = None,None,None,self.n_channel_in
            self.net_mask_shape        = None,None,None,1
        else:
            self.net_input_shape       = self.n_channel_in,None,None,None
            self.net_mask_shape        = 1,None,None,None

        # self.train_shape_completion    = False
        # self.train_completion_crop     = 32
        self.train_patch_size          = 128,128,128
        self.train_background_reg      = 1e-4

        self.train_dist_loss           = 'mae'
        self.train_loss_weights        = 1,0.2
        self.train_epochs              = 400
        self.train_steps_per_epoch     = 100
        self.train_learning_rate       = 0.0003
        self.train_batch_size          = 1
        self.train_n_val_patches       = None
        self.train_tensorboard         = True
        # the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
        min_delta_key = 'epsilon' if LooseVersion(keras.__version__)<=LooseVersion('2.1.5') else 'min_delta'
        self.train_reduce_lr           = {'factor': 0.5, 'patience': 40, min_delta_key: 0}

        self.use_gpu                   = False

        # remove derived attributes that shouldn't be overwritten
        for k in ('n_dim', 'n_channel_out', 'n_rays', 'rays_json'):
            kwargs.pop(k, None)

        self.update_parameters(False, **kwargs)
Ejemplo n.º 12
0
    def __init__(self, X, **kwargs):
        """See class docstring."""

        assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

        n_dim = len(X.shape) - 2
        n_channel_in = X.shape[-1]
        n_channel_out = n_channel_in
        # normalization statistics over the whole array (all channels pooled)
        mean = np.mean(X)
        std = np.std(X)

        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
        axes = axes.replace('S','') # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
            else:
                axes = 'C'+axes

        # normalization parameters
        self.mean                  = str(mean)
        self.std                   = str(std)
        # directly set by parameters
        self.n_dim                 = n_dim
        self.axes                  = axes
        self.n_channel_in          = int(n_channel_in)
        self.n_channel_out         = int(n_channel_out)

        # default config (can be overwritten by kwargs below)
        self.unet_residual         = False
        self.unet_n_depth          = 2
        self.unet_kern_size        = 5 if self.n_dim==2 else 3
        self.unet_n_first          = 32
        self.unet_last_activation  = 'linear'
        if backend_channels_last():
            self.unet_input_shape  = self.n_dim*(None,) + (self.n_channel_in,)
        else:
            self.unet_input_shape  = (self.n_channel_in,) + self.n_dim*(None,)

        self.train_loss            = 'mae'
        self.train_epochs          = 100
        self.train_steps_per_epoch = 400
        self.train_learning_rate   = 0.0004
        self.train_batch_size      = 16
        self.train_tensorboard     = True
        self.train_checkpoint      = 'weights_best.h5'
        self.train_reduce_lr       = {'factor': 0.5, 'patience': 10}
        self.batch_norm            = True
        self.n2v_perc_pix          = 1.5
        self.n2v_patch_shape       = (64, 64) if self.n_dim==2 else (64, 64, 64)
        self.n2v_manipulator       = 'uniform_withCP'
        self.n2v_neighborhood_radius = 5

        # disallow setting 'n_dim' manually: it is derived from X
        # (pop with a default instead of a bare try/except around del)
        kwargs.pop('n_dim', None)

        for k, v in kwargs.items():
            setattr(self, k, v)
Ejemplo n.º 13
0
def from_tensor(x, channel=0, single_sample=True):
    """Convert a backend-layout tensor back to an array with the channel axis at ``channel``.

    Parameters
    ----------
    x : array-like
        Data in the backend layout with a leading sample axis; the channel axis
        is last for ``channels_last`` and directly follows the sample axis
        (axis 1) otherwise.
    channel : int
        Target position of the channel axis in the returned array.
    single_sample : bool
        If True, only the first sample ``x[0]`` is returned (sample axis dropped).

    Returns
    -------
    numpy.ndarray
        View of the (selected) data with the channel axis moved to ``channel``.
    """
    data = x[0] if single_sample else x
    if backend_channels_last():
        source_axis = -1
    else:
        # channels_first: the channel axis is axis 1 of the batched tensor,
        # which becomes axis 0 once the sample axis has been dropped above.
        source_axis = 0 if single_sample else 1
    return np.moveaxis(data, source_axis, channel)
Ejemplo n.º 14
0
def load_training_data_direct(X, Y, validation_split=0, axes=None, n_images=None, verbose=False):
    """Prepare in-memory training data, optionally splitting off a validation set.

    No file is read: the image arrays are passed in directly.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of training input images; the first axis indexes images.
    Y : :class:`numpy.ndarray`
        Array of corresponding target images; must have the same shape as ``X``.
    validation_split : float
        Fraction of images to use as validation set during training
        (``0 <= validation_split < 1``). Validation images are taken from the
        end of the (possibly truncated) arrays.
    axes : str
        Axes of ``X`` and ``Y`` (e.g. ``'SYXC'``). Must be given, must contain
        ``C`` and must have one entry per array dimension.
    n_images : int, optional
        If given, only the first ``n_images`` images are used.
    verbose : bool, optional
        If True, print a summary of the prepared data.

    Returns
    -------
    tuple( tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), str )
        Returns two tuples (`X_train`, `Y_train`), (`X_val`, `Y_val`) of training and validation sets
        and the axes of the returned images (channel axis moved to where the backend expects it).
        The tuple of validation data will be ``None`` if ``validation_split = 0``.

    Raises
    ------
    AssertionError
        If the arrays, ``axes``, ``n_images`` or ``validation_split`` are inconsistent.
    """
    axes = axes_check_and_normalize(axes)

    assert X.shape == Y.shape, 'X and Y must have the same shape.'
    assert len(axes) == X.ndim, 'axes must have one entry per array dimension.'
    assert 'C' in axes, 'axes must contain a channel axis C.'
    if n_images is None:
        n_images = X.shape[0]
    assert 0 < n_images <= X.shape[0], 'n_images must be between 1 and the number of images.'
    assert 0 <= validation_split < 1, 'validation_split must be in the range [0, 1).'

    X, Y = X[:n_images], Y[:n_images]
    channel = axes_dict(axes)['C']

    data_val = None
    if validation_split > 0:
        n_val = int(round(n_images * validation_split))
        n_train = n_images - n_val
        assert 0 < n_val and 0 < n_train, 'validation_split leaves an empty training or validation set.'
        X_t, Y_t = X[-n_val:], Y[-n_val:]
        X, Y = X[:n_train], Y[:n_train]
        assert X.shape[0] == n_train and X_t.shape[0] == n_val
        data_val = (move_channel_for_backend(X_t, channel=channel),
                    move_channel_for_backend(Y_t, channel=channel))

    X = move_channel_for_backend(X, channel=channel)
    Y = move_channel_for_backend(Y, channel=channel)

    # axes of the returned arrays: the channel axis now sits where the backend expects it
    axes = axes.replace('C', '')
    if backend_channels_last():
        axes = axes + 'C'
    else:
        # channels_first: channel directly after the leading axis
        axes = axes[:1] + 'C' + axes[1:]

    if verbose:
        ax = axes_dict(axes)
        n_train = len(X)
        n_val = 0 if data_val is None else len(data_val[0])
        image_size = tuple(X.shape[ax[a]] for a in 'TZYX' if a in axes)
        n_dim = len(image_size)
        n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]

        print('number of training images:\t', n_train)
        print('number of validation images:\t', n_val)
        print('image size (%dD):\t\t' % n_dim, image_size)
        print('axes:\t\t\t\t', axes)
        print('channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    return (X, Y), data_val, axes
Ejemplo n.º 15
0
    def __init__(self, X, **kwargs):
        """Initialize the configuration, deriving defaults from training data ``X``.

        Parameters
        ----------
        X : numpy.ndarray
            Training data of shape ``SYXC`` (4D) or ``SZYXC`` (5D). Per-channel
            means/standard deviations are computed from it and stored as
            strings. If ``X`` is empty (``X.size == 0``), no defaults are set
            and every attribute read at the end of this method (``fit_std``,
            ``fit_mean``, ``n_back_modes``, ``n_fore_modes``, ``n_channel_in``,
            ``n_instance_seg``, ``n_back_i_modes``, ``n_fore_i_modes``) must be
            supplied via ``kwargs``.
        **kwargs
            Overrides for any attribute. ``probabilistic`` and
            ``unet_residual`` are always forced to ``False``; ``n_components``
            and ``n_channel_out`` are always recomputed.
        """
        if X.size != 0:

            assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

            n_dim = len(X.shape) - 2

            if n_dim == 2:
                axes = 'SYXC'
            elif n_dim == 3:
                axes = 'SZYXC'

            # parse and check axes
            axes = axes_check_and_normalize(axes)
            ax = axes_dict(axes)
            ax = {a: (ax[a] is not None) for a in ax}

            (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
            not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

            axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
            axes = axes.replace('S', '')  # remove sample axis if it exists

            if backend_channels_last():
                if ax['C']:
                    axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
                else:
                    axes += 'C'
            else:
                # NOTE(review): axes is hard-coded channel-last above, so this
                # branch always raises for a channels_first backend.
                if ax['C']:
                    axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
                else:
                    axes = 'C' + axes

            # per-channel statistics over all remaining axes
            means, stds = [], []
            for i in range(X.shape[-1]):
                means.append(np.mean(X[..., i]))
                stds.append(np.std(X[..., i]))

            # normalization parameters
            self.means = [str(el) for el in means]
            self.stds = [str(el) for el in stds]
            # directly set by parameters
            self.n_dim = n_dim
            self.axes = axes
            # fixed parameters
            if 'C' in axes:
                self.n_channel_in = X.shape[-1]
            else:
                self.n_channel_in = 1
            self.train_loss = 'demix'

            # default config (can be overwritten by kwargs below)

            self.unet_n_depth = 2
            self.unet_kern_size = 3
            self.unet_n_first = 64
            self.unet_last_activation = 'linear'
            self.probabilistic = False
            self.unet_residual = False
            if backend_channels_last():
                self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
            else:
                self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

            # training parameters
            self.train_epochs = 200
            self.train_steps_per_epoch = 50
            self.train_learning_rate = 0.0004
            self.train_batch_size = 64
            self.train_tensorboard = False
            self.train_checkpoint = 'weights_best.h5'
            self.train_checkpoint_last = 'weights_last.h5'
            self.train_checkpoint_epoch = 'weights_now.h5'
            self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
            self.batch_norm = False
            self.n2v_perc_pix = 1.5
            self.n2v_patch_shape = (128, 128) if self.n_dim == 2 else (64, 64, 64)
            self.n2v_manipulator = 'uniform_withCP'
            self.n2v_neighborhood_radius = 5

            self.single_net_per_channel = False

            self.channel_denoised = False
            self.multi_objective = True
            self.normalizer = 'none'
            self.weights_objectives = [0.1, 0, 0.45, 0.45]
            self.distributions = 'gauss'
            self.n2v_leave_center = False
            self.scale_aug = False
            self.structN2Vmask = None

            self.n_back_modes = 2
            self.n_fore_modes = 2
            self.n_instance_seg = 0
            self.n_back_i_modes = 1
            self.n_fore_i_modes = 1

            self.fit_std = False
            self.fit_mean = True

        # 'probabilistic' and 'unet_residual' cannot be enabled by the caller.
        kwargs['probabilistic'] = False
        kwargs['unet_residual'] = False

        for k in kwargs:
            setattr(self, k, kwargs[k])

        # Derived attributes: always recomputed, so kwargs cannot override them.
        self.n_components = 3 if (self.fit_std and self.fit_mean) else 2
        self.n_channel_out = (self.n_back_modes + self.n_fore_modes) * self.n_channel_in * self.n_components + \
                             self.n_instance_seg * (self.n_back_i_modes + self.n_fore_i_modes)
Ejemplo n.º 16
0
    def __init__(self, X, **kwargs):
        """Initialize a segmentation configuration from training data ``X``.

        Parameters
        ----------
        X : numpy.ndarray
            Non-empty training data of shape ``SYXC`` (4D) or ``SZYXC`` (5D)
            with exactly one channel.
        **kwargs
            Overrides for any default attribute set below. ``n_dim``,
            ``n_channel_in``, ``n_channel_out``, ``train_loss``,
            ``probabilistic`` and ``unet_residual`` are silently ignored.
        """
        assert X.size != 0
        assert len(X.shape) == 4 or len(
            X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

        n_dim = len(X.shape) - 2

        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(
            ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(
            ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(
            ValueError('sample axis S must be first.'))
        axes = axes.replace('S', '')  # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(
                    ValueError('channel axis must be last for backend (%s).' %
                               K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(
                    ValueError('channel axis must be first for backend (%s).' %
                               K.backend()))
            else:
                axes = 'C' + axes

        # single-channel input, three output channels
        assert X.shape[-1] == 1
        n_channel_in = 1
        n_channel_out = 3

        # directly set by parameters
        self.n_dim = n_dim
        self.axes = axes
        # fixed parameters
        self.n_channel_in = n_channel_in
        self.n_channel_out = n_channel_out
        self.train_loss = 'seg'

        # default config (can be overwritten by kwargs below)

        self.unet_n_depth = 4
        self.relative_weights = [1.0, 1.0, 5.0]
        self.unet_kern_size = 3
        self.unet_n_first = 32
        self.unet_last_activation = 'linear'
        self.probabilistic = False
        self.unet_residual = False
        if backend_channels_last():
            self.unet_input_shape = self.n_dim * (None, ) + (
                self.n_channel_in, )
        else:
            self.unet_input_shape = (
                self.n_channel_in, ) + self.n_dim * (None, )

        self.train_epochs = 200
        self.train_steps_per_epoch = 400
        self.train_learning_rate = 0.0004
        self.train_batch_size = 128
        self.train_tensorboard = False
        self.train_checkpoint = 'weights_best.h5'
        self.train_checkpoint_last = 'weights_last.h5'
        self.train_checkpoint_epoch = 'weights_now.h5'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
        self.batch_norm = True

        # These attributes are fixed above and must not be set manually.
        for key in ('n_dim', 'n_channel_in', 'n_channel_out', 'train_loss',
                    'probabilistic', 'unet_residual'):
            kwargs.pop(key, None)

        for k in kwargs:
            setattr(self, k, kwargs[k])
Ejemplo n.º 17
0
def tensor_num_channels(x):
    """Return the size of the channel axis of a batched tensor ``x``.

    The channel axis is the last one for the ``channels_last`` backend format
    and axis 1 (right after the batch axis) otherwise.
    """
    if backend_channels_last():
        return x.shape[-1]
    return x.shape[1]
Ejemplo n.º 18
0
import keras.backend as K
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, TensorBoard
from keras.layers import Input, Conv2D
from keras.models import Model
from keras.utils import Sequence
from keras.optimizers import Adam

from csbdeep.internals.blocks import unet_block
from csbdeep.utils import _raise, Path, load_json, save_json, backend_channels_last
from csbdeep.data import Resizer, NoResizer, PadAndCropResizer

from .utils import star_dist, edt_prob
from skimage.segmentation import clear_border

# This module only supports the 'channels_last' image data format; fail at import otherwise.
backend_channels_last() or _raise(NotImplementedError(
    "Keras is configured to use the '%s' image data format, which is currently not supported. "
    "Please change it to use 'channels_last' instead: "
    "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored"
    % K.image_data_format()))


def masked_loss(mask, penalty):
    """Build a loss that applies ``penalty`` to the prediction error, weighted by ``mask``.

    The returned ``_loss(d_true, d_pred)`` computes
    ``K.mean(mask * penalty(d_pred - d_true), axis=-1)``.
    """
    def _loss(d_true, d_pred):
        error = d_pred - d_true
        weighted = mask * penalty(error)
        return K.mean(weighted, axis=-1)

    return _loss


def masked_loss_mae(mask):
Ejemplo n.º 19
0
    def _build(self):
        """Build and compile the two-input / two-output Keras model.

        Inputs are the image (``config.net_input_shape``) and a ``dist_mask``
        whose spatial size is ``config.net_mask_shape`` divided by
        ``config.grid``. The image is brought to grid resolution by repeated
        conv + max-pooling stages, then passed through a U-Net configured from
        all ``config.unet_*`` attributes, an optional ``features`` convolution,
        and two 1x1 heads: ``prob`` (sigmoid, 1 channel) and ``dist`` (linear,
        ``config.n_rays`` channels).

        Returns
        -------
        Model
            The model, compiled with an Adagrad optimizer (learning rate
            ``config.lr``) and no explicit loss.

        Raises
        ------
        NotImplementedError
            If ``config.backbone`` is not ``'unet'``.
        ValueError
            If an entry of ``config.grid`` is not a positive power of two. The
            pooling schedule doubles the pooled factor per axis, so it could
            never reach such a grid and would loop forever.
        """
        self.config.backbone == 'unet' or _raise(NotImplementedError())

        input_img = Input(self.config.net_input_shape, name='input')
        if backend_channels_last():
            grid_shape = tuple(n // g if n is not None else None
                               for g, n in zip(self.config.grid, self.config.
                                               net_mask_shape[:-1])) + (1, )
        else:
            grid_shape = (1, ) + tuple(
                n // g if n is not None else None for g, n in zip(
                    self.config.grid, self.config.net_mask_shape[1:]))
        input_mask = Input(grid_shape, name='dist_mask')

        unet_kwargs = {
            k[len('unet_'):]: v
            for (k, v) in vars(self.config).items() if k.startswith('unet_')
        }

        grid = tuple(self.config.grid)
        all(int(g) > 0 and (int(g) & (int(g) - 1)) == 0 for g in grid) or _raise(
            ValueError('grid entries must be powers of 2, got %s.' % (grid,)))

        # maxpool input image to grid size: double the pooling factor on every
        # axis that has not yet reached its grid value
        pooled = np.array([1, 1])
        pooled_img = input_img
        while tuple(pooled) != grid:
            pool = 1 + (np.asarray(grid) > pooled)
            pooled *= pool
            for _ in range(self.config.unet_n_conv_per_depth):
                pooled_img = Conv2D(
                    self.config.unet_n_filter_base,
                    self.config.unet_kernel_size,
                    padding='same',
                    activation=self.config.unet_activation)(pooled_img)
            pooled_img = MaxPooling2D(pool)(pooled_img)

        unet = unet_block(**unet_kwargs)(pooled_img)
        if self.config.net_conv_after_unet > 0:
            unet = Conv2D(self.config.net_conv_after_unet,
                          self.config.unet_kernel_size,
                          name='features',
                          padding='same',
                          activation=self.config.unet_activation)(unet)

        output_prob = Conv2D(1, (1, 1),
                             name='prob',
                             padding='same',
                             activation='sigmoid')(unet)
        output_dist = Conv2D(self.config.n_rays, (1, 1),
                             name='dist',
                             padding='same',
                             activation='linear')(unet)

        compModel = Model([input_img, input_mask], [output_prob, output_dist])
        compModel.compile(optimizer=keras.optimizers.Adagrad(
            lr=self.config.lr))

        return compModel
Ejemplo n.º 20
0
def unet_block(n_depth=2,
               n_filter_base=16,
               kernel_size=(3, 3),
               n_conv_per_depth=2,
               activation="relu",
               batch_norm=False,
               dropout=0.0,
               last_activation=None,
               pool=(2, 2),
               prefix=''):
    """Return a function that applies a 2D or 3D U-Net to a Keras tensor.

    The encoder has ``n_depth`` levels of ``n_conv_per_depth`` conv blocks
    followed by max-pooling; level ``d`` uses ``n_filter_base * 2**d``
    filters. The decoder upsamples, concatenates the matching skip tensor
    along the channel axis and applies conv blocks again. The final conv block
    uses ``last_activation`` (defaults to ``activation``). All named layers
    get ``prefix`` prepended to their name.

    Raises
    ------
    ValueError
        If ``pool`` and ``kernel_size`` differ in length, or the dimensionality
        is not 2 or 3.
    """
    n_dim = len(kernel_size)
    if len(pool) != n_dim:
        raise ValueError('kernel and pool sizes must match.')
    if n_dim not in (2, 3):
        raise ValueError('unet_block only 2d or 3d.')

    if n_dim == 2:
        conv_block, pooling, upsampling = conv_block2, MaxPooling2D, UpSampling2D
    else:
        conv_block, pooling, upsampling = conv_block3, MaxPooling3D, UpSampling3D

    final_activation = activation if last_activation is None else last_activation

    channel_axis = -1 if backend_channels_last() else 1

    def _conv(n_filters, layer_name, act=activation):
        # one conv block with the shared settings; name gets the prefix
        return conv_block(n_filters,
                          *kernel_size,
                          dropout=dropout,
                          activation=act,
                          batch_norm=batch_norm,
                          name=prefix + layer_name)

    def _func(input):
        skips = []
        x = input

        # encoder
        for depth in range(n_depth):
            for k in range(n_conv_per_depth):
                x = _conv(n_filter_base * 2**depth,
                          "down_level_%s_no_%s" % (depth, k))(x)
            skips.append(x)
            x = pooling(pool, name=prefix + "max_%s" % depth)(x)

        # bottleneck
        for k in range(n_conv_per_depth - 1):
            x = _conv(n_filter_base * 2**n_depth, "middle_%s" % k)(x)
        x = _conv(n_filter_base * 2**max(0, n_depth - 1),
                  "middle_%s" % n_conv_per_depth)(x)

        # decoder with skip connections
        for depth in reversed(range(n_depth)):
            merge = Concatenate(axis=channel_axis)
            x = merge([upsampling(pool)(x), skips[depth]])
            for k in range(n_conv_per_depth - 1):
                x = _conv(n_filter_base * 2**depth,
                          "up_level_%s_no_%s" % (depth, k))(x)
            x = _conv(n_filter_base * 2**max(0, depth - 1),
                      "up_level_%s_no_%s" % (depth, n_conv_per_depth),
                      act=activation if depth > 0 else final_activation)(x)

        return x

    return _func