Example #1
0
def bn_feature_net_3D(receptive_field=61,
                      n_frames=5,
                      input_shape=(5, 256, 256, 1),
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True,
                      temporal=None,
                      residual=False,
                      temporal_kernel_size=3):
    """Creates a 3D featurenet.

    Args:
        receptive_field (int): the receptive field of the neural network.
        n_frames (int): Number of frames.
        input_shape (tuple): If no input tensor, create one with this shape.
        n_features (int): Number of output features
        n_channels (int): number of input channels
        reg (int): regularization value
        n_conv_filters (int): number of convolutional filters
        n_dense_filters (int): number of dense filters
        VGG_mode (bool): If ``multires``, uses ``VGG_mode``
            for multiresolution
        init (str): Method for initalizing weights.
        norm_method (str): Normalization method to use with the
            :mod:`deepcell.layers.normalization.ImageNormalization3D` layer.
        location (bool): Whether to include a
            :mod:`deepcell.layers.location.Location3D` layer.
        dilated (bool): Whether to use dilated pooling.
        padding (bool): Whether to use padding.
        padding_mode (str): Type of padding, one of 'reflect' or 'zero'
        multires (bool): Enables multi-resolution mode
        include_top (bool): Whether to include the final layer of the model
        temporal (str): Type of temporal operation
        residual (bool): Whether to use temporal information as a residual
        temporal_kernel_size (int): size of 2D kernel used in temporal convolutions

    Returns:
        tensorflow.keras.Model: 3D FeatureNet
    """
    # Create layers list (x) to store all of the layers.
    # We need to use the functional API to enable the multiresolution mode
    x = []

    win = (receptive_field - 1) // 2
    win_z = (n_frames - 1) // 2

    if dilated:
        padding = True

    if K.image_data_format() == 'channels_first':
        channel_axis = 1
        time_axis = 2
        row_axis = 3
        col_axis = 4
        if not dilated:
            input_shape = (n_channels, n_frames, receptive_field,
                           receptive_field)
    else:
        channel_axis = -1
        time_axis = 1
        row_axis = 2
        col_axis = 3
        if not dilated:
            input_shape = (n_frames, receptive_field, receptive_field,
                           n_channels)

    x.append(Input(shape=input_shape))
    x.append(
        ImageNormalization3D(norm_method=norm_method,
                             filter_size=receptive_field)(x[-1]))

    if padding:
        if padding_mode == 'reflect':
            x.append(ReflectionPadding3D(padding=(win_z, win, win))(x[-1]))
        elif padding_mode == 'zero':
            x.append(ZeroPadding3D(padding=(win_z, win, win))(x[-1]))

    if location:
        x.append(Location3D()(x[-1]))
        x.append(Concatenate(axis=channel_axis)([x[-2], x[-1]]))

    layers_to_concat = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    while rf_counter > 4:
        filter_size = 3 if rf_counter % 2 == 0 else 4
        x.append(
            Conv3D(n_conv_filters, (1, filter_size, filter_size),
                   dilation_rate=(1, d, d),
                   kernel_initializer=init,
                   padding='valid',
                   kernel_regularizer=l2(reg))(x[-1]))
        x.append(BatchNormalization(axis=channel_axis)(x[-1]))
        x.append(Activation('relu')(x[-1]))

        block_counter += 1
        rf_counter -= filter_size - 1

        if block_counter % 2 == 0:
            if dilated:
                x.append(
                    DilatedMaxPool3D(dilation_rate=(1, d, d),
                                     pool_size=(1, 2, 2))(x[-1]))
                d *= 2
            else:
                x.append(MaxPool3D(pool_size=(1, 2, 2))(x[-1]))

            if VGG_mode:
                n_conv_filters *= 2

            rf_counter = rf_counter // 2

            if multires:
                layers_to_concat.append(len(x) - 1)

    if multires:
        c = []
        for l in layers_to_concat:
            output_shape = x[l].get_shape().as_list()
            target_shape = x[-1].get_shape().as_list()
            time_crop = (0, 0)

            row_crop = int(output_shape[row_axis] - target_shape[row_axis])

            if row_crop % 2 == 0:
                row_crop = (row_crop // 2, row_crop // 2)
            else:
                row_crop = (row_crop // 2, row_crop // 2 + 1)

            col_crop = int(output_shape[col_axis] - target_shape[col_axis])

            if col_crop % 2 == 0:
                col_crop = (col_crop // 2, col_crop // 2)
            else:
                col_crop = (col_crop // 2, col_crop // 2 + 1)

            cropping = (time_crop, row_crop, col_crop)

            c.append(Cropping3D(cropping=cropping)(x[l]))
        x.append(Concatenate(axis=channel_axis)(c))

    x.append(
        Conv3D(n_dense_filters, (1, rf_counter, rf_counter),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        Conv3D(n_dense_filters, (n_frames, 1, 1),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    feature = Activation('relu')(x[-1])

    def __merge_temporal_features(feature,
                                  mode='conv',
                                  residual=False,
                                  n_filters=256,
                                  n_frames=3,
                                  padding=True,
                                  temporal_kernel_size=3):
        if mode is None:
            return feature

        mode = str(mode).lower()
        if mode == 'conv':
            x = Conv3D(n_filters,
                       (n_frames, temporal_kernel_size, temporal_kernel_size),
                       kernel_initializer=init,
                       padding='same',
                       activation='relu',
                       kernel_regularizer=l2(reg))(feature)
        elif mode == 'lstm':
            x = ConvLSTM2D(filters=n_filters,
                           kernel_size=temporal_kernel_size,
                           padding='same',
                           kernel_initializer=init,
                           activation='relu',
                           kernel_regularizer=l2(reg),
                           return_sequences=True)(feature)
        elif mode == 'gru':
            x = ConvGRU2D(filters=n_filters,
                          kernel_size=temporal_kernel_size,
                          padding='same',
                          kernel_initializer=init,
                          activation='relu',
                          kernel_regularizer=l2(reg),
                          return_sequences=True)(feature)
        else:
            raise ValueError(
                '`temporal` must be one of "conv", "lstm", "gru" or None')

        if residual is True:
            temporal_feature = Add()([feature, x])
        else:
            temporal_feature = x
        temporal_feature_normed = BatchNormalization(
            axis=channel_axis)(temporal_feature)
        return temporal_feature_normed

    temporal_feature = __merge_temporal_features(
        feature,
        mode=temporal,
        residual=residual,
        n_filters=n_dense_filters,
        n_frames=n_frames,
        padding=padding,
        temporal_kernel_size=temporal_kernel_size)
    x.append(temporal_feature)

    x.append(
        TensorProduct(n_dense_filters,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        TensorProduct(n_features,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(x[-1]))

    if not dilated:
        x.append(Flatten()(x[-1]))

    if include_top:
        x.append(Softmax(axis=channel_axis, dtype=K.floatx())(x[-1]))

    model = Model(inputs=x[0], outputs=x[-1])

    return model
Example #2
0
def bn_feature_net_3D(receptive_field=61,
                      n_frames=5,
                      input_shape=(5, 256, 256, 1),
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True):
    """Creates a 3D featurenet.

    Args:
        receptive_field (int): the receptive field of the neural network.
        n_frames (int): Number of frames.
        input_shape (tuple): If no input tensor, create one with this shape.
        n_features (int): Number of output features
        n_channels (int): number of input channels
        reg (int): regularization value
        n_conv_filters (int): number of convolutional filters
        n_dense_filters (int): number of dense filters
        VGG_mode (bool): If multires, uses VGG_mode for multiresolution
        init (str): Method for initalizing weights.
        norm_method (str): ImageNormalization mode to use
        location (bool): Whether to include location data
        dilated (bool): Whether to use dilated pooling.
        padding (bool): Whether to use padding.
        padding_mode (str): Type of padding, one of 'reflect' or 'zero'
        multires (bool): Enables multi-resolution mode
        include_top (bool): Whether to include the final layer of the model

    Returns:
        tensorflow.keras.Model: 3D FeatureNet
    """
    # Create layers list (x) to store all of the layers.
    # We need to use the functional API to enable the multiresolution mode
    x = []

    win = (receptive_field - 1) // 2
    win_z = (n_frames - 1) // 2

    if dilated:
        padding = True

    if K.image_data_format() == 'channels_first':
        channel_axis = 1
        time_axis = 2
        row_axis = 3
        col_axis = 4
        if not dilated:
            input_shape = (n_channels, n_frames, receptive_field,
                           receptive_field)
    else:
        channel_axis = -1
        time_axis = 1
        row_axis = 2
        col_axis = 3
        if not dilated:
            input_shape = (n_frames, receptive_field, receptive_field,
                           n_channels)

    x.append(Input(shape=input_shape))
    x.append(
        ImageNormalization3D(norm_method=norm_method,
                             filter_size=receptive_field)(x[-1]))

    if padding:
        if padding_mode == 'reflect':
            x.append(ReflectionPadding3D(padding=(win_z, win, win))(x[-1]))
        elif padding_mode == 'zero':
            x.append(ZeroPadding3D(padding=(win_z, win, win))(x[-1]))

    if location:
        x.append(Location3D(in_shape=tuple(x[-1].shape.as_list()[1:]))(x[-1]))
        x.append(Concatenate(axis=channel_axis)([x[-2], x[-1]]))

    layers_to_concat = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    while rf_counter > 4:
        filter_size = 3 if rf_counter % 2 == 0 else 4
        x.append(
            Conv3D(n_conv_filters, (1, filter_size, filter_size),
                   dilation_rate=(1, d, d),
                   kernel_initializer=init,
                   padding='valid',
                   kernel_regularizer=l2(reg))(x[-1]))
        x.append(BatchNormalization(axis=channel_axis)(x[-1]))
        x.append(Activation('relu')(x[-1]))

        block_counter += 1
        rf_counter -= filter_size - 1

        if block_counter % 2 == 0:
            if dilated:
                x.append(
                    DilatedMaxPool3D(dilation_rate=(1, d, d),
                                     pool_size=(1, 2, 2))(x[-1]))
                d *= 2
            else:
                x.append(MaxPool3D(pool_size=(1, 2, 2))(x[-1]))

            if VGG_mode:
                n_conv_filters *= 2

            rf_counter = rf_counter // 2

            if multires:
                layers_to_concat.append(len(x) - 1)

    if multires:
        c = []
        for l in layers_to_concat:
            output_shape = x[l].get_shape().as_list()
            target_shape = x[-1].get_shape().as_list()
            time_crop = (0, 0)

            row_crop = int(output_shape[row_axis] - target_shape[row_axis])

            if row_crop % 2 == 0:
                row_crop = (row_crop // 2, row_crop // 2)
            else:
                row_crop = (row_crop // 2, row_crop // 2 + 1)

            col_crop = int(output_shape[col_axis] - target_shape[col_axis])

            if col_crop % 2 == 0:
                col_crop = (col_crop // 2, col_crop // 2)
            else:
                col_crop = (col_crop // 2, col_crop // 2 + 1)

            cropping = (time_crop, row_crop, col_crop)

            c.append(Cropping3D(cropping=cropping)(x[l]))
        x.append(Concatenate(axis=channel_axis)(c))

    x.append(
        Conv3D(n_dense_filters, (1, rf_counter, rf_counter),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        Conv3D(n_dense_filters, (n_frames, 1, 1),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        TensorProduct(n_dense_filters,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        TensorProduct(n_features,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(x[-1]))

    if not dilated:
        x.append(Flatten()(x[-1]))

    if include_top:
        x.append(Softmax(axis=channel_axis)(x[-1]))

    model = Model(inputs=x[0], outputs=x[-1])

    return model
Example #3
0
def bn_feature_net_3D(receptive_field=61,
                      n_frames=5,
                      input_shape=(5, 256, 256, 1),
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True):
    # Create layers list (x) to store all of the layers.
    # We need to use the functional API to enable the multiresolution mode
    x = []

    win = (receptive_field - 1) // 2
    win_z = (n_frames - 1) // 2

    if dilated:
        padding = True

    if K.image_data_format() == 'channels_first':
        channel_axis = 1
        time_axis = 2
        row_axis = 3
        col_axis = 4
        if not dilated:
            input_shape = (n_channels, n_frames, receptive_field, receptive_field)
    else:
        channel_axis = -1
        time_axis = 1
        row_axis = 2
        col_axis = 3
        if not dilated:
            input_shape = (n_frames, receptive_field, receptive_field, n_channels)

    x.append(Input(shape=input_shape))
    x.append(ImageNormalization3D(norm_method=norm_method, filter_size=receptive_field)(x[-1]))

    if padding:
        if padding_mode == 'reflect':
            x.append(ReflectionPadding3D(padding=(win_z, win, win))(x[-1]))
        elif padding_mode == 'zero':
            x.append(ZeroPadding3D(padding=(win_z, win, win))([-1]))

    if location:
        x.append(Location3D(in_shape=tuple(x[-1].shape.as_list()[1:]))(x[-1]))
        x.append(Concatenate(axis=channel_axis)([x[-2], x[-1]]))

    if multires:
        layers_to_concat = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    while rf_counter > 4:
        filter_size = 3 if rf_counter % 2 == 0 else 4
        x.append(Conv3D(n_conv_filters, (1, filter_size, filter_size), dilation_rate=(1, d, d), kernel_initializer=init, padding='valid', kernel_regularizer=l2(reg))(x[-1]))
        x.append(BatchNormalization(axis=channel_axis)(x[-1]))
        x.append(Activation('relu')(x[-1]))

        block_counter += 1
        rf_counter -= filter_size - 1

        if block_counter % 2 == 0:
            if dilated:
                x.append(DilatedMaxPool3D(dilation_rate=(1, d, d), pool_size=(1, 2, 2))(x[-1]))
                d *= 2
            else:
                x.append(MaxPool3D(pool_size=(1, 2, 2))(x[-1]))

            if VGG_mode:
                n_conv_filters *= 2

            rf_counter = rf_counter // 2

            if multires:
                layers_to_concat.append(len(x) - 1)

    if multires:
        c = []
        for l in layers_to_concat:
            output_shape = x[l].get_shape().as_list()
            target_shape = x[-1].get_shape().as_list()
            time_crop = (0, 0)

            row_crop = int(output_shape[row_axis] - target_shape[row_axis])

            if row_crop % 2 == 0:
                row_crop = (row_crop // 2, row_crop // 2)
            else:
                row_crop = (row_crop // 2, row_crop // 2 + 1)

            col_crop = int(output_shape[col_axis] - target_shape[col_axis])

            if col_crop % 2 == 0:
                col_crop = (col_crop // 2, col_crop // 2)
            else:
                col_crop = (col_crop // 2, col_crop // 2 + 1)

            cropping = (time_crop, row_crop, col_crop)

            c.append(Cropping3D(cropping=cropping)(x[l]))
        x.append(Concatenate(axis=channel_axis)(c))

    x.append(Conv3D(n_dense_filters, (1, rf_counter, rf_counter), dilation_rate=(1, d, d), kernel_initializer=init, padding='valid', kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(Conv3D(n_dense_filters, (n_frames, 1, 1), dilation_rate=(1, d, d), kernel_initializer=init, padding='valid', kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(TensorProd3D(n_dense_filters, n_dense_filters, kernel_initializer=init, kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(TensorProd3D(n_dense_filters, n_features, kernel_initializer=init, kernel_regularizer=l2(reg))(x[-1]))

    if not dilated:
        x.append(Flatten()(x[-1]))

    if include_top:
        x.append(Softmax(axis=channel_axis)(x[-1]))

    model = Model(inputs=x[0], outputs=x[-1])

    return model