Example #1
def skip_merge(skip_layers, upsampled_layers, skip_merge_type, data_format,
               num_dims, padding):
    """Skip connection concatenate/add to upsampled layer
    :param keras.layer skip_layers: as named
    :param keras.layer upsampled_layers: as named
    :param str skip_merge_type: [add, concat]
    :param str data_format: [channels_first, channels_last]
    :param int num_dims: as named
    :param str padding: same or valid
    :return: keras.layer skip merged layer
    """

    channel_axis = get_channel_axis(data_format)
    # crop input if padding='valid'
    if padding == 'valid':
        skip_layers = Lambda(_crop_layer,
                             arguments={
                                 'final_layer': upsampled_layers,
                                 'data_format': data_format,
                                 'num_dims': num_dims
                             })(skip_layers)

    if skip_merge_type == 'concat':
        layer = Concatenate(axis=channel_axis)([upsampled_layers, skip_layers])
    else:
        skip_layers = Lambda(pad_channels,
                             arguments={
                                 'final_layer': upsampled_layers,
                                 'channel_axis': channel_axis
                             })(skip_layers)
        layer = Add()([upsampled_layers, skip_layers])
    return layer
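
A minimal usage sketch with hypothetical shapes, assuming padding='same' (so the crop branch is skipped) and that skip_merge and its helpers are importable from this module:

from keras.layers import Input, UpSampling2D

skip = Input(shape=(64, 64, 16))           # encoder feature map
bottom = Input(shape=(32, 32, 32))         # deeper decoder feature map
up = UpSampling2D(size=(2, 2))(bottom)     # upsample to match skip spatially
merged = skip_merge(skip_layers=skip,
                    upsampled_layers=up,
                    skip_merge_type='concat',
                    data_format='channels_last',
                    num_dims=2,
                    padding='same')        # shape (64, 64, 48) after concat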
Example #2
def kl_divergence_loss(y_true, y_pred):
    """KL divergence loss
    D(y||y') = sum(p(y) * log(p(y) / p(y')))

    :param y_true: Ground truth
    :param y_pred: Prediction
    :return float: KL divergence loss
    """
    y_true = K.clip(y_true, K.epsilon(), 1)
    y_pred = K.clip(y_pred, K.epsilon(), 1)
    channel_axis = get_channel_axis(K.image_data_format())
    return K.sum(y_true * K.log(y_true / y_pred), axis=channel_axis)
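
To sanity-check the formula numerically, here is the same computation in plain NumPy (assuming channels_last, so the reduction runs over the trailing axis):

import numpy as np

eps = 1e-7
y_true = np.clip(np.array([[0.7, 0.2, 0.1]]), eps, 1)
y_pred = np.clip(np.array([[0.6, 0.3, 0.1]]), eps, 1)
kl = np.sum(y_true * np.log(y_true / y_pred), axis=-1)
# 0.7*log(7/6) + 0.2*log(2/3) + 0.1*log(1) ~= 0.0268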
Example #3
def mse_loss(y_true, y_pred, mean_loss=True):
    """Mean squared loss

    :param y_true: Ground truth
    :param y_pred: Prediction
    :return float: Mean squared error loss
    """
    if not mean_loss:
        return K.square(y_pred - y_true)

    channel_axis = get_channel_axis(K.image_data_format())
    return K.mean(K.square(y_pred - y_true), axis=channel_axis)
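
The same reduction written in NumPy for a quick check (channels_last assumed, so the channel axis is the trailing one):

import numpy as np

y_true = np.array([[1.0, 2.0, 3.0]])
y_pred = np.array([[1.5, 2.0, 2.0]])
mse = np.mean((y_pred - y_true) ** 2, axis=-1)  # (0.25 + 0 + 1) / 3 ~= 0.417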
Example #4
def mae_loss(y_true, y_pred, mean_loss=True):
    """Mean absolute error

    Keras losses by default calculate metrics along axis=-1, which works with
    image_format='channels_last'. The arrays do not seem to batch flattened,
    change axis if using 'channels_first
    """
    if not mean_loss:
        return K.abs(y_pred - y_true)

    channel_axis = get_channel_axis(K.image_data_format())
    return K.mean(K.abs(y_pred - y_true), axis=channel_axis)
Example #5
def binary_crossentropy_loss(y_true, y_pred, mean_loss=True):
    """Binary cross entropy loss
    :param y_true: Ground truth
    :param y_pred: Prediction
    :return float: Binary cross entropy loss
    """
    assert len(np.unique(y_true).tolist()) <= 2
    assert len(np.unique(y_pred).tolist()) <= 2

    if not mean_loss:
        return K.binary_crossentropy(y_true, y_pred)

    channel_axis = get_channel_axis(K.image_data_format())
    return K.mean(K.binary_crossentropy(y_true, y_pred), axis=channel_axis)
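
The underlying per-element definition, with the epsilon clipping Keras applies to probabilities, can be checked in NumPy:

import numpy as np

eps = 1e-7
y_true = np.array([[1.0, 0.0, 1.0]])
y_pred = np.clip(np.array([[1.0, 0.0, 0.0]]), eps, 1 - eps)
bce = -np.mean(y_true * np.log(y_pred)
               + (1 - y_true) * np.log(1 - y_pred), axis=-1)
# first two elements contribute ~0; the mislabeled third adds -log(eps) / 3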
Example #6
def _crop_layer(input_layer, final_layer, data_format, num_dims):
    """Crop input layer to match shape of final layer

    ONLY SYMMETRIC CROPPING IS HANDLED HERE!

    :param keras.layers input_layer: input layer to the block
    :param keras.layers final_layer: last layer of the conv block or skip
     layers in the U-Net
    :param str data_format: one of 'channels_first', 'channels_last'
    :param int num_dims: number of spatial dimensions (2 or 3)
    :return: keras.layer, input layer cropped if its shape differs from the
     final layer, else the input layer as is
    """

    input_shape = tf.shape(input_layer)
    final_shape = tf.shape(final_layer)
    # offsets for the top left corner of the crop
    if data_format == 'channels_first':
        offsets = [
            0, 0, (input_shape[2] - final_shape[2]) // 2,
            (input_shape[3] - final_shape[3]) // 2
        ]
        crop_shape = [-1, input_shape[1], final_shape[2], final_shape[3]]
        if num_dims == 3:
            offsets.append((input_shape[4] - final_shape[4]) // 2)
            crop_shape.append(final_shape[4])
    else:
        offsets = [
            0, (input_shape[1] - final_shape[1]) // 2,
            (input_shape[2] - final_shape[2]) // 2
        ]
        crop_shape = [-1, final_shape[1], final_shape[2]]
        if num_dims == 3:
            offsets.append((input_shape[3] - final_shape[3]) // 2)
            crop_shape.append(final_shape[3])
        offsets.append(0)
        crop_shape.append(input_shape[-1])

    # https://github.com/tensorflow/tensorflow/issues/19376
    input_cropped = tf.slice(input_layer, offsets, crop_shape)

    op_shape = final_layer.get_shape().as_list()
    channel_axis = get_channel_axis(data_format)
    op_shape[channel_axis] = input_layer.get_shape().as_list()[channel_axis]
    input_cropped.set_shape(tuple(op_shape))

    return input_cropped
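
A NumPy picture of the symmetric crop, assuming channels_last and 2D: a 70x70 input matched against a 66x66 final layer is trimmed by (70 - 66) // 2 = 2 pixels on every border:

import numpy as np

x = np.random.rand(1, 70, 70, 16)          # input layer, channels_last
off = (70 - 66) // 2                       # 2 pixels off each border
cropped = x[:, off:off + 66, off:off + 66, :]
assert cropped.shape == (1, 66, 66, 16)    # channels are left untouched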
Example #7
    def test_pad_channels(self):
        """Test pad_channels()

        Zero-pads the layer along the channel dimension when padding='same';
        zero-pads and crops when padding='valid'.
        """

        for idx, in_shape in enumerate([self.in_shape_2d, self.in_shape_3d]):
            # create a model that gives padded layer as output
            self.network_config['num_dims'] = \
                self.network_config['num_dims'] + idx

            in_layer = k_layers.Input(shape=in_shape, dtype='float32')
            conv_layer = get_keras_layer('conv',
                                         self.network_config['num_dims'])
            out_layer = conv_layer(
                filters=self.network_config['num_filters_per_block'][0],
                kernel_size=self.network_config['filter_size'],
                padding='same',
                data_format=self.network_config['data_format'])(in_layer)

            channel_axis = get_channel_axis(self.network_config['data_format'])
            layer_padded = k_layers.Lambda(conv_blocks.pad_channels,
                                           arguments={
                                               'final_layer': out_layer,
                                               'channel_axis': channel_axis
                                           })(in_layer)
            # layer padded has zeros in all channels except 8
            model = Model(in_layer, layer_padded)
            test_shape = list(in_shape)
            test_shape.insert(0, 1)
            test_image = np.ones(shape=test_shape)
            # forward pass
            out = model.predict(test_image, batch_size=1)
            # test shape: should be the same as conv_layer
            out_shape = list(in_shape)
            out_shape[0] = self.network_config['num_filters_per_block'][0]
            np.testing.assert_array_equal(out_layer.get_shape().as_list()[1:],
                                          out_shape)
            out = np.squeeze(out)
            # only slice 8 is not zero
            nose.tools.assert_equal(np.sum(out), np.sum(out[8]))
            np.testing.assert_array_equal(out[8], np.squeeze(test_image))
            nose.tools.assert_equal(np.sum(out[8]), np.prod(in_shape))
Example #8
def downsample_conv_block(layer,
                          network_config,
                          block_idx,
                          downsample_shape=None):
    """Conv-BN-activation block

    :param keras.layers layer: current input layer
    :param dict network_config: please check conv_block()
    :param int block_idx: block index in the network
    :param tuple downsample_shape: anisotropic downsampling kernel shape
    :return: keras.layers after downsampling and conv_block
    """

    conv = get_keras_layer(type='conv', num_dims=network_config['num_dims'])
    block_sequence = network_config['block_sequence'].split('-')
    for conv_idx in range(network_config['num_convs_per_block']):
        for cur_layer_type in block_sequence:
            if cur_layer_type == 'conv':
                if block_idx > 0 and conv_idx == 0:
                    if downsample_shape is None:
                        stride = (2, ) * network_config['num_dims']
                    else:
                        stride = downsample_shape
                else:
                    stride = (1, ) * network_config['num_dims']
                layer = conv(
                    filters=network_config['num_filters_per_block'][block_idx],
                    kernel_size=network_config['filter_size'],
                    strides=stride,
                    padding=network_config['padding'],
                    kernel_initializer=network_config['init'],
                    data_format=network_config['data_format'])(layer)
            elif cur_layer_type == 'bn' and network_config['batch_norm']:
                layer = BatchNormalization(axis=get_channel_axis(
                    network_config['data_format']))(layer)
            else:
                activation_layer_instance = create_activation_layer(
                    network_config['activation'])
                layer = activation_layer_instance(layer)

        if network_config['dropout']:
            layer = Dropout(network_config['dropout'])(layer)
    return layer
Example #9
def _split_ytrue_mask(y_true, n_channels):
    """Split the mask concatenated with y_true

    :param keras.tensor y_true: if channels_first, y_true has shape
     [batch_size, n_channels, y, x]. The mask is concatenated as channel
     n_channels + 1, giving shape [batch_size, n_channels + 1, y, x]
    :param int n_channels: number of channels in y_true
    :return:
     keras.tensor y_true_split - y_true with the mask removed
     keras.tensor mask_image - bool mask
    """

    try:
        split_axis = get_channel_axis(K.image_data_format())
        y_true_split, mask_image = tf.split(y_true, [n_channels, 1],
                                            axis=split_axis)
        return y_true_split, mask_image
    except Exception as e:
        print('Cannot separate mask and y_true: ' + str(e))
        raise
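
For intuition, here is the split itself on a toy channels_first batch (so the split axis is 1):

import numpy as np
import tensorflow as tf

# 2 target channels with the mask concatenated as a 3rd channel
y_true_with_mask = tf.constant(np.random.rand(4, 3, 8, 8), dtype=tf.float32)
targets, mask = tf.split(y_true_with_mask, [2, 1], axis=1)
# targets: (4, 2, 8, 8), mask: (4, 1, 8, 8)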
Example #10
def _merge_residual(final_layer, input_layer, data_format, num_dims,
                    kernel_init, padding):
    """Add residual connection from input to last layer
    :param keras.layers final_layer: last layer
    :param keras.layers input_layer: input_layer
    :param str data_format: [channels_first, channels_last]
    :param int num_dims: as named
    :param str kernel_init: kernel initializer from config
    :param str padding: same or valid
    :return: input_layer 1x1 / padded to match the shape of final_layer
     and added
    """

    channel_axis = get_channel_axis(data_format)
    conv_object = get_keras_layer(type='conv', num_dims=num_dims)
    num_final_layers = int(final_layer.get_shape()[channel_axis])
    num_input_layers = int(input_layer.get_shape()[channel_axis])
    # crop input if padding='valid'
    if padding == 'valid':
        input_layer = Lambda(_crop_layer,
                             arguments={
                                 'final_layer': final_layer,
                                 'data_format': data_format,
                                 'num_dims': num_dims
                             })(input_layer)

    if num_input_layers > num_final_layers:
        # use a 1x1 conv to get the desired number of feature maps
        input_layer = conv_object(filters=num_final_layers,
                                  kernel_size=(1, ) * num_dims,
                                  padding='same',
                                  kernel_initializer=kernel_init,
                                  data_format=data_format)(input_layer)
    elif num_input_layers < num_final_layers:
        # padding with zeros along channels
        input_layer = Lambda(pad_channels,
                             arguments={
                                 'final_layer': final_layer,
                                 'channel_axis': channel_axis
                             })(input_layer)
    layer = Add()([final_layer, input_layer])
    return layer
Example #11
def conv_block(layer, network_config, block_idx):
    """Convolution block

    Allowed block-seq: [conv-BN-activation, conv-activation-BN,
     BN-activation-conv]
    To accommodate params of advanced activations, activation is a dict with
     keys 'type' and 'params'.
    For a complete list of keys in network_config, refer to
    BaseConvNet.__init__() in base_conv_net.py

    :param keras.layers layer: current input layer
    :param dict network_config: dict with network related keys
    :param int block_idx: block index in the network
    :return: keras.layers after performing operations in block-sequence
     repeated for num_convs_per_block times
    TODO: data_format from network_config won't work for full 3D models in predict
    if depth is set to None
    """

    conv = get_keras_layer(type='conv', num_dims=network_config['num_dims'])
    block_sequence = network_config['block_sequence'].split('-')
    for _ in range(network_config['num_convs_per_block']):
        for cur_layer_type in block_sequence:
            if cur_layer_type == 'conv':
                layer = conv(
                    filters=network_config['num_filters_per_block'][block_idx],
                    kernel_size=network_config['filter_size'],
                    padding=network_config['padding'],
                    kernel_initializer=network_config['init'],
                    data_format=network_config['data_format'])(layer)
            elif cur_layer_type == 'bn' and network_config['batch_norm']:
                layer = BatchNormalization(axis=get_channel_axis(
                    network_config['data_format']))(layer)
            else:
                activation_layer_instance = create_activation_layer(
                    network_config['activation'])
                layer = activation_layer_instance(layer)

        if network_config['dropout']:
            layer = Dropout(network_config['dropout'])(layer)

    return layer
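
A hypothetical network_config covering only the keys the block functions above read; the authoritative list lives in BaseConvNet.__init__() in base_conv_net.py, so treat the values here as placeholders:

from keras.layers import Input

in_layer = Input(shape=(256, 256, 1))      # channels_last, single-channel 2D
network_config = {
    'num_dims': 2,
    'block_sequence': 'conv-bn-activation',
    'num_convs_per_block': 2,
    'num_filters_per_block': [16, 32, 64],
    'filter_size': 3,
    'padding': 'same',
    'init': 'he_normal',
    'data_format': 'channels_last',
    'batch_norm': True,
    'activation': {'type': 'relu', 'params': None},  # dict per the docstring
    'dropout': 0.2,
}
layer = conv_block(in_layer, network_config, block_idx=0)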
Example #12
def test_get_channel_axis_first():
    channel_axis = aux_utils.get_channel_axis('channels_first')
    nose.tools.assert_equal(channel_axis, 1)
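
The test above pins channels_first to axis 1. A minimal sketch consistent with that and with the axis=-1 convention the losses rely on for channels_last (an assumption about the mapping, not the project's verbatim code):

def get_channel_axis(data_format):
    """Return the channel axis: 1 for channels_first, else -1."""
    return 1 if data_format == 'channels_first' else -1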