Exemple #1
0
def MaskRCNN_3D(backbone,
                num_classes,
                input_shape,
                norm_method='whole_image',
                crop_size=(14, 14, 14),
                weights=None,
                pooling=None,
                mask_dtype=K.floatx(),
                required_channels=3,
                **kwargs):
    """Builds a 3D RetinaNet-mask model on top of a named backbone.

    The input is normalized, projected to ``required_channels`` channels and
    handed to the backbone as its ``input_tensor``. The pyramid layer outputs
    of the backbone are then forwarded to ``retinanet_mask_3D`` through the
    ``backbone_layers`` keyword.

    Args:
        backbone: string, name of the backbone to use. It is also used as
            the prefix of the returned model's name.
        num_classes: Number of classes to classify.
        input_shape: The shape of the input data.
        norm_method: Normalization method passed to ``ImageNormalization3D``.
        crop_size: Crop size forwarded to ``retinanet_mask_3D``.
        weights: Forwarded to the backbone. One of ``None`` (random
            initialization), ``'imagenet'`` (pre-training on ImageNet),
            or the path to the weights file to be loaded.
        pooling: Optional pooling mode forwarded to the backbone, which is
            always built with ``include_top=False``. One of ``None``,
            ``'avg'`` (global average pooling) or ``'max'``
            (global max pooling).
        mask_dtype: dtype forwarded to ``retinanet_mask_3D``.
        required_channels: integer, the number of channels the backbone
            input is projected to. 3 is the default.
        **kwargs: Extra options forwarded to ``retinanet_mask_3D``. The
            ``backbone_layers`` entry is always overwritten here.

    Returns:
        The model built by ``retinanet_mask_3D``, named
        ``'<backbone>_retinanet_mask_3D'``.
    """
    inputs = Input(shape=input_shape)

    # The backbone expects a fixed channel count, so normalize the raw input
    # and project it to `required_channels` channels first.
    normalized = ImageNormalization3D(norm_method=norm_method)(inputs)
    backbone_input = TensorProduct(required_channels)(normalized)

    kwargs['backbone_layers'] = get_pyramid_layer_outputs(
        backbone,
        inputs,
        include_top=False,
        input_tensor=backbone_input,
        weights=weights,
        pooling=pooling)

    # Assemble the full model around the backbone features.
    model_name = '{}_retinanet_mask_3D'.format(backbone)
    return retinanet_mask_3D(inputs=inputs,
                             num_classes=num_classes,
                             crop_size=crop_size,
                             name=model_name,
                             mask_dtype=mask_dtype,
                             **kwargs)
Exemple #2
0
def bn_feature_net_skip_3D(receptive_field=61,
                           input_shape=(5, 256, 256, 1),
                           fgbg_model=None,
                           last_only=True,
                           n_skips=2,
                           norm_method='std',
                           padding_mode='reflect',
                           **kwargs):
    """Chains ``n_skips + 1`` 3D featurenets with skip-connections.

    Each featurenet after the first prediction receives the normalized image
    concatenated (on the channel axis) with the most recent prediction.

    Args:
        receptive_field (int): receptive field passed to every featurenet
            and used as the normalization filter size.
        input_shape (tuple): shape of the input tensor to create.
        fgbg_model: optional model applied to the raw inputs. Its layers are
            frozen and its (last) output seeds the first skip-connection.
        last_only (bool): if True the model only outputs the final
            prediction; otherwise it outputs every featurenet prediction
            (the ``fgbg_model`` prediction is never included).
        n_skips (int): number of skip-connections.
        norm_method (str): normalization method for ``ImageNormalization3D``.
        padding_mode (str): padding mode forwarded to ``bn_feature_net_3D``.
        **kwargs: other options forwarded to ``bn_feature_net_3D``.

    Returns:
        Model mapping the inputs to the selected prediction(s).
    """
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    inputs = Input(shape=input_shape)
    normalized = ImageNormalization3D(norm_method=norm_method,
                                      filter_size=receptive_field)(inputs)

    predictions = []
    has_fgbg = fgbg_model is not None

    if has_fgbg:
        # The foreground/background model is used as a fixed feature source.
        for layer in fgbg_model.layers:
            layer.trainable = False
        fgbg_prediction = fgbg_model(inputs)
        if isinstance(fgbg_prediction, list):
            fgbg_prediction = fgbg_prediction[-1]
        predictions.append(fgbg_prediction)

    for _ in range(n_skips + 1):
        if predictions:
            stage_input = Concatenate(axis=channel_axis)(
                [normalized, predictions[-1]])
        else:
            stage_input = normalized

        # Normalization already happened above, so each stage skips it.
        stage = bn_feature_net_3D(
            receptive_field=receptive_field,
            input_shape=stage_input.get_shape().as_list()[1:],
            norm_method=None,
            dilated=True,
            padding=True,
            padding_mode=padding_mode,
            **kwargs)
        predictions.append(stage(stage_input))

    if last_only:
        selected = predictions[-1]
    elif has_fgbg:
        # Drop the fgbg prediction; only featurenet outputs are exposed.
        selected = predictions[1:]
    else:
        selected = predictions

    return Model(inputs=inputs, outputs=selected)
Exemple #3
0
def bn_feature_net_skip_3D(receptive_field=61,
                           input_shape=(5, 256, 256, 1),
                           fgbg_model=None,
                           last_only=True,
                           n_skips=2,
                           norm_method='std',
                           padding_mode='reflect',
                           **kwargs):
    """Builds a stack of 3D featurenets joined by skip-connections.

    Args:
        receptive_field (int): receptive field of each featurenet; also the
            filter size of the input normalization.
        input_shape (tuple): shape used to create the input tensor.
        fgbg_model (tensorflow.keras.Model): optional model whose layers are
            frozen and whose output is concatenated with the normalized
            inputs to feed the first featurenet.
        last_only (bool): output only the final prediction when True,
            otherwise output the prediction of every featurenet.
        n_skips (int): number of skip-connections; ``n_skips + 1``
            featurenets are created.
        norm_method (str): the type of ImageNormalization to use.
        padding_mode (str): type of padding, one of 'reflect' or 'zero'.
        kwargs (dict): other model options defined in bn_feature_net_3D.

    Returns:
        tensorflow.keras.Model: 3D FeatureNet with skip-connections.
    """
    if K.image_data_format() == 'channels_first':
        channel_axis = 1
    else:
        channel_axis = -1

    inputs = Input(shape=input_shape)
    normalizer = ImageNormalization3D(norm_method=norm_method,
                                      filter_size=receptive_field)
    img = normalizer(inputs)

    stage_outputs = []

    if fgbg_model is not None:
        # Freeze the auxiliary model so only the featurenets train.
        for frozen_layer in fgbg_model.layers:
            frozen_layer.trainable = False
        fgbg_output = fgbg_model(inputs)
        if isinstance(fgbg_output, list):
            fgbg_output = fgbg_output[-1]
        stage_outputs.append(fgbg_output)

    for _ in range(n_skips + 1):
        # The first stage sees the bare image unless an fgbg output exists.
        if not stage_outputs:
            stage_input = img
        else:
            stage_input = Concatenate(axis=channel_axis)(
                [img, stage_outputs[-1]])

        stage_shape = stage_input.get_shape().as_list()[1:]
        featurenet = bn_feature_net_3D(receptive_field=receptive_field,
                                       input_shape=stage_shape,
                                       norm_method=None,
                                       dilated=True,
                                       padding=True,
                                       padding_mode=padding_mode,
                                       **kwargs)
        stage_outputs.append(featurenet(stage_input))

    if last_only:
        return Model(inputs=inputs, outputs=stage_outputs[-1])

    if fgbg_model is None:
        return Model(inputs=inputs, outputs=stage_outputs)

    # Skip the fgbg output, which is not a featurenet prediction.
    return Model(inputs=inputs, outputs=stage_outputs[1:])
Exemple #4
0
def bn_feature_net_3D(receptive_field=61,
                      n_frames=5,
                      input_shape=(5, 256, 256, 1),
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True):
    """Builds a 3D featurenet with the functional API.

    Args:
        receptive_field (int): receptive field of the network. Drives the
            number of convolution blocks, the padding width and the
            normalization filter size.
        n_frames (int): number of frames; kernel depth of the temporal
            convolution.
        input_shape (tuple): shape of the input tensor. Only used when
            ``dilated`` is True; otherwise the shape is derived from
            ``n_frames``, ``receptive_field`` and ``n_channels``.
        n_features (int): number of output features.
        n_channels (int): number of input channels.
        reg (float): L2 regularization factor for the kernels.
        n_conv_filters (int): number of convolutional filters.
        n_dense_filters (int): number of dense filters.
        VGG_mode (bool): double ``n_conv_filters`` after every pooling step.
        init (str): method for initializing weights.
        norm_method (str): ImageNormalization mode to use.
        location (bool): whether to concatenate a ``Location3D`` layer.
        dilated (bool): whether to use dilated pooling (forces padding).
        padding (bool): whether to pad the normalized input.
        padding_mode (str): type of padding, one of 'reflect' or 'zero'.
            Any other value adds no padding layer.
        multires (bool): concatenate the cropped output of every pooling
            layer before the dense convolutions.
        include_top (bool): whether to append the final Softmax layer.

    Returns:
        tensorflow.keras.Model: 3D FeatureNet
    """
    # Every tensor produced along the way is kept, in creation order, so the
    # multi-resolution mode can reach back to earlier outputs.
    tensors = []

    win = (receptive_field - 1) // 2
    win_z = (n_frames - 1) // 2

    # Dilated networks are always padded.
    use_padding = bool(padding) or bool(dilated)

    channels_first = K.image_data_format() == 'channels_first'
    if channels_first:
        channel_axis, row_axis, col_axis = 1, 3, 4
    else:
        channel_axis, row_axis, col_axis = -1, 2, 3

    if not dilated:
        # Patch-wise mode: the input is exactly one receptive field wide.
        if channels_first:
            input_shape = (n_channels, n_frames, receptive_field,
                           receptive_field)
        else:
            input_shape = (n_frames, receptive_field, receptive_field,
                           n_channels)

    def _append_bn_relu():
        """Appends batch normalization followed by a ReLU activation."""
        tensors.append(BatchNormalization(axis=channel_axis)(tensors[-1]))
        tensors.append(Activation('relu')(tensors[-1]))

    def _symmetric_crop(excess):
        """Splits ``excess`` into a (before, after) pair; odd goes after."""
        half = excess // 2
        if excess % 2 == 0:
            return (half, half)
        return (half, half + 1)

    tensors.append(Input(shape=input_shape))
    tensors.append(
        ImageNormalization3D(norm_method=norm_method,
                             filter_size=receptive_field)(tensors[-1]))

    if use_padding:
        pad_widths = (win_z, win, win)
        if padding_mode == 'reflect':
            tensors.append(
                ReflectionPadding3D(padding=pad_widths)(tensors[-1]))
        elif padding_mode == 'zero':
            tensors.append(ZeroPadding3D(padding=pad_widths)(tensors[-1]))

    if location:
        base = tensors[-1]
        loc = Location3D(in_shape=tuple(base.shape.as_list()[1:]))(base)
        tensors.append(loc)
        tensors.append(Concatenate(axis=channel_axis)([base, loc]))

    multires_taps = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    while rf_counter > 4:
        filter_size = 4 if rf_counter % 2 else 3
        conv = Conv3D(n_conv_filters, (1, filter_size, filter_size),
                      dilation_rate=(1, d, d),
                      kernel_initializer=init,
                      padding='valid',
                      kernel_regularizer=l2(reg))
        tensors.append(conv(tensors[-1]))
        _append_bn_relu()

        block_counter += 1
        rf_counter = rf_counter - (filter_size - 1)

        if block_counter % 2 != 0:
            continue

        # Every second block is followed by a pooling step.
        if dilated:
            tensors.append(
                DilatedMaxPool3D(dilation_rate=(1, d, d),
                                 pool_size=(1, 2, 2))(tensors[-1]))
            d *= 2
        else:
            tensors.append(MaxPool3D(pool_size=(1, 2, 2))(tensors[-1]))

        if VGG_mode:
            n_conv_filters *= 2

        rf_counter //= 2

        if multires:
            multires_taps.append(tensors[-1])

    if multires:
        cropped = []
        for tap in multires_taps:
            tap_shape = tap.get_shape().as_list()
            target_shape = tensors[-1].get_shape().as_list()

            row_crop = _symmetric_crop(
                int(tap_shape[row_axis] - target_shape[row_axis]))
            col_crop = _symmetric_crop(
                int(tap_shape[col_axis] - target_shape[col_axis]))

            # The time axis is never cropped.
            cropped.append(
                Cropping3D(cropping=((0, 0), row_crop, col_crop))(tap))
        tensors.append(Concatenate(axis=channel_axis)(cropped))

    tensors.append(
        Conv3D(n_dense_filters, (1, rf_counter, rf_counter),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(tensors[-1]))
    _append_bn_relu()

    tensors.append(
        Conv3D(n_dense_filters, (n_frames, 1, 1),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(tensors[-1]))
    _append_bn_relu()

    tensors.append(
        TensorProduct(n_dense_filters,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(tensors[-1]))
    _append_bn_relu()

    tensors.append(
        TensorProduct(n_features,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(tensors[-1]))

    if not dilated:
        tensors.append(Flatten()(tensors[-1]))

    if include_top:
        tensors.append(Softmax(axis=channel_axis)(tensors[-1]))

    return Model(inputs=tensors[0], outputs=tensors[-1])
Exemple #5
0
def bn_feature_net_3D(receptive_field=61,
                      n_frames=5,
                      input_shape=(5, 256, 256, 1),
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True,
                      temporal=None,
                      residual=False,
                      temporal_kernel_size=3):
    """Builds a 3D featurenet with an optional temporal merge step.

    Args:
        receptive_field (int): receptive field of the network. Drives the
            number of convolution blocks, the padding width and the
            normalization filter size.
        n_frames (int): number of frames; kernel depth of the temporal
            convolutions.
        input_shape (tuple): shape of the input tensor. Only used when
            ``dilated`` is True; otherwise the shape is derived from
            ``n_frames``, ``receptive_field`` and ``n_channels``.
        n_features (int): number of output features.
        n_channels (int): number of input channels.
        reg (float): L2 regularization factor for the kernels.
        n_conv_filters (int): number of convolutional filters.
        n_dense_filters (int): number of dense filters; also the filter
            count of the temporal layer.
        VGG_mode (bool): double ``n_conv_filters`` after every pooling step.
        init (str): method for initializing weights.
        norm_method (str): normalization method to use with the
            :mod:`deepcell.layers.normalization.ImageNormalization3D` layer.
        location (bool): whether to include a
            :mod:`deepcell.layers.location.Location3D` layer.
        dilated (bool): whether to use dilated pooling (forces padding).
        padding (bool): whether to pad the normalized input.
        padding_mode (str): type of padding, one of 'reflect' or 'zero'.
            Any other value adds no padding layer.
        multires (bool): enables multi-resolution mode.
        include_top (bool): whether to append the final Softmax layer.
        temporal (str): type of temporal operation, one of "conv", "lstm",
            "gru" (case-insensitive) or None to skip the step.
        residual (bool): when exactly ``True``, add the temporal output to
            its input instead of replacing it.
        temporal_kernel_size (int): size of the 2D kernel used in the
            temporal layer.

    Returns:
        tensorflow.keras.Model: 3D FeatureNet

    Raises:
        ValueError: if ``temporal`` is not a supported mode.
    """
    # Ordered record of every tensor in the graph; the functional API is
    # needed so multi-resolution mode can reference earlier outputs.
    net = []

    def _push(layer):
        """Applies ``layer`` to the latest tensor and records the result."""
        net.append(layer(net[-1]))

    win = (receptive_field - 1) // 2
    win_z = (n_frames - 1) // 2

    # Dilated pooling requires the padded input.
    use_padding = True if dilated else padding

    if K.image_data_format() == 'channels_first':
        channel_axis, row_axis, col_axis = 1, 3, 4
        patch_shape = (n_channels, n_frames, receptive_field,
                       receptive_field)
    else:
        channel_axis, row_axis, col_axis = -1, 2, 3
        patch_shape = (n_frames, receptive_field, receptive_field,
                       n_channels)

    if not dilated:
        input_shape = patch_shape

    net.append(Input(shape=input_shape))
    _push(ImageNormalization3D(norm_method=norm_method,
                               filter_size=receptive_field))

    if use_padding:
        pad = (win_z, win, win)
        if padding_mode == 'reflect':
            _push(ReflectionPadding3D(padding=pad))
        elif padding_mode == 'zero':
            _push(ZeroPadding3D(padding=pad))

    if location:
        _push(Location3D())
        net.append(Concatenate(axis=channel_axis)([net[-2], net[-1]]))

    pooled_outputs = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    while rf_counter > 4:
        filter_size = 3 if rf_counter % 2 == 0 else 4
        _push(Conv3D(n_conv_filters, (1, filter_size, filter_size),
                     dilation_rate=(1, d, d),
                     kernel_initializer=init,
                     padding='valid',
                     kernel_regularizer=l2(reg)))
        _push(BatchNormalization(axis=channel_axis))
        _push(Activation('relu'))

        block_counter += 1
        rf_counter -= filter_size - 1

        if block_counter % 2 == 0:
            # Pool after every second convolution block.
            if dilated:
                _push(DilatedMaxPool3D(dilation_rate=(1, d, d),
                                       pool_size=(1, 2, 2)))
                d *= 2
            else:
                _push(MaxPool3D(pool_size=(1, 2, 2)))

            if VGG_mode:
                n_conv_filters *= 2

            rf_counter = rf_counter // 2

            if multires:
                pooled_outputs.append(net[-1])

    if multires:
        crops = []
        for pooled in pooled_outputs:
            pooled_shape = pooled.get_shape().as_list()
            target_shape = net[-1].get_shape().as_list()

            # Split the size difference evenly; an odd remainder is
            # cropped from the trailing side.
            half, odd = divmod(
                int(pooled_shape[row_axis] - target_shape[row_axis]), 2)
            row_crop = (half, half + odd)

            half, odd = divmod(
                int(pooled_shape[col_axis] - target_shape[col_axis]), 2)
            col_crop = (half, half + odd)

            crops.append(
                Cropping3D(cropping=((0, 0), row_crop, col_crop))(pooled))
        net.append(Concatenate(axis=channel_axis)(crops))

    _push(Conv3D(n_dense_filters, (1, rf_counter, rf_counter),
                 dilation_rate=(1, d, d),
                 kernel_initializer=init,
                 padding='valid',
                 kernel_regularizer=l2(reg)))
    _push(BatchNormalization(axis=channel_axis))
    _push(Activation('relu'))

    _push(Conv3D(n_dense_filters, (n_frames, 1, 1),
                 dilation_rate=(1, d, d),
                 kernel_initializer=init,
                 padding='valid',
                 kernel_regularizer=l2(reg)))
    _push(BatchNormalization(axis=channel_axis))
    feature = Activation('relu')(net[-1])

    # Optionally merge temporal information into the feature tensor.
    if temporal is None:
        net.append(feature)
    else:
        mode = str(temporal).lower()
        if mode == 'conv':
            temporal_layer = Conv3D(
                n_dense_filters,
                (n_frames, temporal_kernel_size, temporal_kernel_size),
                kernel_initializer=init,
                padding='same',
                activation='relu',
                kernel_regularizer=l2(reg))
        elif mode == 'lstm':
            temporal_layer = ConvLSTM2D(filters=n_dense_filters,
                                        kernel_size=temporal_kernel_size,
                                        padding='same',
                                        kernel_initializer=init,
                                        activation='relu',
                                        kernel_regularizer=l2(reg),
                                        return_sequences=True)
        elif mode == 'gru':
            temporal_layer = ConvGRU2D(filters=n_dense_filters,
                                       kernel_size=temporal_kernel_size,
                                       padding='same',
                                       kernel_initializer=init,
                                       activation='relu',
                                       kernel_regularizer=l2(reg),
                                       return_sequences=True)
        else:
            raise ValueError(
                '`temporal` must be one of "conv", "lstm", "gru" or None')

        merged = temporal_layer(feature)
        if residual is True:
            merged = Add()([feature, merged])
        net.append(BatchNormalization(axis=channel_axis)(merged))

    _push(TensorProduct(n_dense_filters,
                        kernel_initializer=init,
                        kernel_regularizer=l2(reg)))
    _push(BatchNormalization(axis=channel_axis))
    _push(Activation('relu'))

    _push(TensorProduct(n_features,
                        kernel_initializer=init,
                        kernel_regularizer=l2(reg)))

    if not dilated:
        _push(Flatten())

    if include_top:
        _push(Softmax(axis=channel_axis, dtype=K.floatx()))

    return Model(inputs=net[0], outputs=net[-1])
Exemple #6
0
def bn_feature_net_3D(receptive_field=61,
                      n_frames=5,
                      input_shape=(5, 256, 256, 1),
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True):
    """Creates a 3D featurenet.

    Args:
        receptive_field (int): the receptive field of the neural network.
            Drives the number of convolution blocks, the padding width and
            the normalization filter size.
        n_frames (int): Number of frames; kernel depth of the temporal
            convolution.
        input_shape (tuple): Shape of the input tensor. Only used when
            ``dilated`` is True; otherwise the shape is derived from
            ``n_frames``, ``receptive_field`` and ``n_channels``.
        n_features (int): Number of output features.
        n_channels (int): Number of input channels.
        reg (float): L2 regularization factor for the kernels.
        n_conv_filters (int): Number of convolutional filters.
        n_dense_filters (int): Number of dense filters.
        VGG_mode (bool): Double ``n_conv_filters`` after every pooling step.
        init (str): Method for initializing weights.
        norm_method (str): ImageNormalization mode to use.
        location (bool): Whether to include location data.
        dilated (bool): Whether to use dilated pooling (forces padding).
        padding (bool): Whether to pad the normalized input.
        padding_mode (str): Type of padding, one of 'reflect' or 'zero'.
            Any other value adds no padding layer.
        multires (bool): Enables multi-resolution mode.
        include_top (bool): Whether to include the final Softmax layer.

    Returns:
        Model: 3D FeatureNet
    """
    # Create layers list (x) to store all of the layers.
    # We need to use the functional API to enable the multiresolution mode
    x = []

    win = (receptive_field - 1) // 2
    win_z = (n_frames - 1) // 2

    # Dilated pooling only works on the padded input.
    if dilated:
        padding = True

    if K.image_data_format() == 'channels_first':
        channel_axis = 1
        row_axis = 3
        col_axis = 4
        if not dilated:
            input_shape = (n_channels, n_frames, receptive_field,
                           receptive_field)
    else:
        channel_axis = -1
        row_axis = 2
        col_axis = 3
        if not dilated:
            input_shape = (n_frames, receptive_field, receptive_field,
                           n_channels)

    x.append(Input(shape=input_shape))
    x.append(
        ImageNormalization3D(norm_method=norm_method,
                             filter_size=receptive_field)(x[-1]))

    if padding:
        if padding_mode == 'reflect':
            x.append(ReflectionPadding3D(padding=(win_z, win, win))(x[-1]))
        elif padding_mode == 'zero':
            # Apply the layer to the latest tensor, x[-1]. The previous code
            # passed the literal list [-1], which is not a tensor.
            x.append(ZeroPadding3D(padding=(win_z, win, win))(x[-1]))

    if location:
        x.append(Location3D(in_shape=tuple(x[-1].shape.as_list()[1:]))(x[-1]))
        x.append(Concatenate(axis=channel_axis)([x[-2], x[-1]]))

    # Indices into x of the pooling outputs used by multi-resolution mode.
    layers_to_concat = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    while rf_counter > 4:
        filter_size = 3 if rf_counter % 2 == 0 else 4
        x.append(
            Conv3D(n_conv_filters, (1, filter_size, filter_size),
                   dilation_rate=(1, d, d),
                   kernel_initializer=init,
                   padding='valid',
                   kernel_regularizer=l2(reg))(x[-1]))
        x.append(BatchNormalization(axis=channel_axis)(x[-1]))
        x.append(Activation('relu')(x[-1]))

        block_counter += 1
        rf_counter -= filter_size - 1

        # Pool after every second convolution block.
        if block_counter % 2 == 0:
            if dilated:
                x.append(
                    DilatedMaxPool3D(dilation_rate=(1, d, d),
                                     pool_size=(1, 2, 2))(x[-1]))
                d *= 2
            else:
                x.append(MaxPool3D(pool_size=(1, 2, 2))(x[-1]))

            if VGG_mode:
                n_conv_filters *= 2

            rf_counter = rf_counter // 2

            if multires:
                layers_to_concat.append(len(x) - 1)

    if multires:
        c = []
        for layer_index in layers_to_concat:
            output_shape = x[layer_index].get_shape().as_list()
            target_shape = x[-1].get_shape().as_list()
            time_crop = (0, 0)

            # Split the size difference evenly; an odd remainder is cropped
            # from the trailing side.
            row_crop = int(output_shape[row_axis] - target_shape[row_axis])

            if row_crop % 2 == 0:
                row_crop = (row_crop // 2, row_crop // 2)
            else:
                row_crop = (row_crop // 2, row_crop // 2 + 1)

            col_crop = int(output_shape[col_axis] - target_shape[col_axis])

            if col_crop % 2 == 0:
                col_crop = (col_crop // 2, col_crop // 2)
            else:
                col_crop = (col_crop // 2, col_crop // 2 + 1)

            cropping = (time_crop, row_crop, col_crop)

            c.append(Cropping3D(cropping=cropping)(x[layer_index]))
        x.append(Concatenate(axis=channel_axis)(c))

    x.append(
        Conv3D(n_dense_filters, (1, rf_counter, rf_counter),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        Conv3D(n_dense_filters, (n_frames, 1, 1),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        TensorProd3D(n_dense_filters, n_dense_filters,
                     kernel_initializer=init,
                     kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        TensorProd3D(n_dense_filters, n_features,
                     kernel_initializer=init,
                     kernel_regularizer=l2(reg))(x[-1]))

    if not dilated:
        x.append(Flatten()(x[-1]))

    if include_top:
        x.append(Softmax(axis=channel_axis)(x[-1]))

    model = Model(inputs=x[0], outputs=x[-1])

    return model