Beispiel #1
0
    def _downsampling_block(self,
                            input_layer,
                            block_idx,
                            filter_shape=None,
                            downsample_shape=None):
        """Build one encoder block of the U-net.

        Applies a (residual) conv block to the input and, for every block
        except the last (bridge) one, follows it with a pooling layer.
        The pre-pooling activation is returned separately so it can serve
        as the skip connection for the decoder path.

        :param keras.layer input_layer: must be the output of Input layer
        :param int block_idx: index of this block along the encoder path
        :param tuple filter_shape: filter size is an int for most cases.
         filter_shape enables passing anisotropic filter shapes (required)
        :param tuple downsample_shape: anisotropic pooling shape (required)
        :return: (output layer after optional pooling, skip layer)
        """

        assert filter_shape is not None, 'Anisotropic filter shape is required'
        assert downsample_shape is not None, 'Downsample_shape is required'

        block_fn = residual_conv_block if self.config['residual'] \
            else conv_block
        layer = block_fn(layer=input_layer,
                         network_config=self.config,
                         block_idx=block_idx)
        skip_layers = layer

        # the last (bridge) block is not followed by pooling
        is_bridge_block = block_idx >= self.num_down_blocks - 1
        if not is_bridge_block:
            pool_object = get_keras_layer(type=self.config['pooling_type'],
                                          num_dims=self.config['num_dims'])
            layer = pool_object(pool_size=downsample_shape,
                                data_format=self.config['data_format'])(layer)

        return layer, skip_layers
    def _downsample_layer(self, layer):
        """Downsample a keras layer via pooling with size 2 per spatial dim."""

        num_dims = self.config['num_dims']
        pool_layer = get_keras_layer(type=self.config['pooling_type'],
                                     num_dims=num_dims)
        return pool_layer(pool_size=(2,) * num_dims,
                          data_format=self.config['data_format'])(layer)
Beispiel #3
0
    def test_pad_channels(self):
        """Test pad_channels()

        zero-pads the layer along the channel dimension when padding=same.
        zero-pads + crops when padding=valid
        """

        for idx, in_shape in enumerate([self.in_shape_2d, self.in_shape_3d]):
            # bump num_dims from 2D to 3D on the second iteration
            self.network_config['num_dims'] += idx

            # create a model that gives padded layer as output
            in_layer = k_layers.Input(shape=in_shape, dtype='float32')
            conv_layer = get_keras_layer('conv',
                                         self.network_config['num_dims'])
            out_layer = conv_layer(
                filters=self.network_config['num_filters_per_block'][0],
                kernel_size=self.network_config['filter_size'],
                padding='same',
                data_format=self.network_config['data_format'])(in_layer)

            channel_axis = get_channel_axis(self.network_config['data_format'])
            layer_padded = k_layers.Lambda(conv_blocks.pad_channels,
                                           arguments={
                                               'final_layer': out_layer,
                                               'channel_axis': channel_axis
                                           })(in_layer)
            # layer_padded has zeros in all channels except channel 8,
            # which carries the input
            model = Model(in_layer, layer_padded)
            test_shape = [1] + list(in_shape)
            test_image = np.ones(shape=test_shape)
            # forward pass
            out = model.predict(test_image, batch_size=1)
            # test shape: should be the same as conv_layer
            out_shape = list(in_shape)
            out_shape[0] = self.network_config['num_filters_per_block'][0]
            np.testing.assert_array_equal(out_layer.get_shape().as_list()[1:],
                                          out_shape)
            out = np.squeeze(out)
            # only slice 8 is non-zero and holds the all-ones input image
            nose.tools.assert_equal(np.sum(out), np.sum(out[8]))
            np.testing.assert_array_equal(out[8], np.squeeze(test_image))
            nose.tools.assert_equal(np.sum(out[8]), np.prod(in_shape))
Beispiel #4
0
def downsample_conv_block(layer,
                          network_config,
                          block_idx,
                          downsample_shape=None):
    """Conv-BN-activation block with strided-conv downsampling

    The first conv of every block past the input block (block_idx > 0)
    uses a stride (downsample_shape, or 2 per dim by default), so the
    downsampling is done by the convolution itself rather than pooling.

    :param keras.layers layer: current input layer
    :param dict network_config: please check conv_block()
    :param int block_idx: block index in the network
    :param tuple downsample_shape: anisotropic downsampling kernel shape
    :return: keras.layers after downsampling and conv_block
    """

    num_dims = network_config['num_dims']
    conv = get_keras_layer(type='conv', num_dims=num_dims)
    block_sequence = network_config['block_sequence'].split('-')
    unit_stride = (1, ) * num_dims
    down_stride = (2, ) * num_dims if downsample_shape is None \
        else downsample_shape

    for conv_idx in range(network_config['num_convs_per_block']):
        for cur_layer_type in block_sequence:
            if cur_layer_type == 'conv':
                # only the first conv of a non-input block downsamples
                downsampling = block_idx > 0 and conv_idx == 0
                layer = conv(
                    filters=network_config['num_filters_per_block'][block_idx],
                    kernel_size=network_config['filter_size'],
                    strides=down_stride if downsampling else unit_stride,
                    padding=network_config['padding'],
                    kernel_initializer=network_config['init'],
                    data_format=network_config['data_format'])(layer)
            elif cur_layer_type == 'bn' and network_config['batch_norm']:
                layer = BatchNormalization(axis=get_channel_axis(
                    network_config['data_format']))(layer)
            else:
                activation_layer_instance = create_activation_layer(
                    network_config['activation'])
                layer = activation_layer_instance(layer)

        if network_config['dropout']:
            layer = Dropout(network_config['dropout'])(layer)
    return layer
Beispiel #5
0
def _merge_residual(final_layer, input_layer, data_format, num_dims,
                    kernel_init, padding):
    """Add residual connection from input to last layer

    The input is first cropped (when padding='valid') to the spatial size
    of the final layer; its channel count is then matched to the final
    layer's — a 1x1 conv when the input has more channels, zero-padding
    along the channel axis when it has fewer — before the two are summed.

    :param keras.layers final_layer: last layer
    :param keras.layers input_layer: input_layer
    :param str data_format: [channels_first, channels_last]
    :param int num_dims: as named
    :param str kernel_init: kernel initializer from config
    :param str padding: same or valid
    :return: input_layer 1x1 / padded to match the shape of final_layer
     and added
    """

    channel_axis = get_channel_axis(data_format)
    conv_object = get_keras_layer(type='conv', num_dims=num_dims)
    num_final_layers = int(final_layer.get_shape()[channel_axis])
    num_input_layers = int(input_layer.get_shape()[channel_axis])

    if padding == 'valid':
        # crop input spatially so it lines up with the shrunken final layer
        input_layer = Lambda(_crop_layer,
                             arguments={
                                 'final_layer': final_layer,
                                 'data_format': data_format,
                                 'num_dims': num_dims
                             })(input_layer)

    if num_input_layers > num_final_layers:
        # 1x1 conv projects down to the desired number of feature maps
        input_layer = conv_object(filters=num_final_layers,
                                  kernel_size=(1, ) * num_dims,
                                  padding='same',
                                  kernel_initializer=kernel_init,
                                  data_format=data_format)(input_layer)
    elif num_input_layers < num_final_layers:
        # zero-pad the missing channels
        input_layer = Lambda(pad_channels,
                             arguments={
                                 'final_layer': final_layer,
                                 'channel_axis': channel_axis
                             })(input_layer)

    return Add()([final_layer, input_layer])
Beispiel #6
0
def conv_block(layer, network_config, block_idx):
    """Convolution block

    Allowed block-seq: [conv-BN-activation, conv-activation-BN,
     BN-activation-conv]
    To accommodate params of advanced activations, activation is a dict with
     keys 'type' and 'params'.
    For a complete list of keys in network_config, refer to
    BaseConvNet.__init__() in base_conv_net.py

    :param keras.layers layer: current input layer
    :param dict network_config: dict with network related keys
    :param int block_idx: block index in the network
    :return: keras.layers after performing operations in block-sequence
     repeated for num_convs_per_block times
    TODO: data_format from network_config won't work for full 3D models in predict
    if depth is set to None
    """

    data_format = network_config['data_format']
    conv = get_keras_layer(type='conv', num_dims=network_config['num_dims'])
    layer_sequence = network_config['block_sequence'].split('-')

    for _ in range(network_config['num_convs_per_block']):
        for layer_type in layer_sequence:
            if layer_type == 'conv':
                layer = conv(
                    filters=network_config['num_filters_per_block'][block_idx],
                    kernel_size=network_config['filter_size'],
                    padding=network_config['padding'],
                    kernel_initializer=network_config['init'],
                    data_format=data_format)(layer)
            elif layer_type == 'bn' and network_config['batch_norm']:
                layer = BatchNormalization(
                    axis=get_channel_axis(data_format))(layer)
            else:
                # anything else in the sequence is the activation slot
                activation_layer = create_activation_layer(
                    network_config['activation'])
                layer = activation_layer(layer)

        if network_config['dropout']:
            layer = Dropout(network_config['dropout'])(layer)

    return layer
Beispiel #7
0
    def build_net(self):
        """Assemble the network: encoder, decoder, and 1x1 output conv."""

        with tf.name_scope('input'):
            inputs = Input(shape=self._get_input_shape)
        cur_layer = inputs

        # ---------- Downsampling + middle blocks ---------
        skip_layers_list = []
        for block_idx in range(self.num_down_blocks + 1):
            with tf.name_scope('down_block_{}'.format(block_idx + 1)):
                cur_layer, cur_skip_layers = self._downsampling_block(
                    input_layer=cur_layer, block_idx=block_idx)
            skip_layers_list.append(cur_skip_layers)
        # the bridge/middle block output is not a skip connection
        skip_layers_list.pop()

        # ------------- Upsampling / decoding blocks -------------
        for block_idx in reversed(range(self.num_down_blocks)):
            with tf.name_scope('up_block_{}'.format(block_idx)):
                cur_layer = self._upsampling_block(
                    input_layers=cur_layer,
                    skip_layers=skip_layers_list[block_idx],
                    block_idx=block_idx)

        # ------------ output block ------------------------
        conv_object = get_keras_layer(type='conv',
                                      num_dims=self.config['num_dims'])
        with tf.name_scope('output'):
            cur_layer = conv_object(
                filters=self.config['num_target_channels'],
                kernel_size=(1, ) * self.config['num_dims'],
                padding=self.config['padding'],
                kernel_initializer=self.config['init'],
                data_format=self.config['data_format'])(cur_layer)
            outputs = Activation(self.config['final_activation'])(cur_layer)
        return inputs, outputs
Beispiel #8
0
def residual_downsample_conv_block(layer,
                                   network_config,
                                   block_idx,
                                   downsample_shape=None):
    """Convolution block where the last layer is merged (+) with input layer

    For block 0 the input is used as-is; for deeper blocks the conv path
    downsamples via strided convs and the input is pooled with the same
    shape so both branches match spatially before the residual merge.

    :param keras.layers layer: current input layer
    :param dict network_config: please check conv_block()
    :param int block_idx: block index in the network
    :param tuple downsample_shape: anisotropic downsampling kernel shape
    :return: keras.layers after conv-block and residual merge
    """

    num_dims = network_config['num_dims']
    data_format = network_config['data_format']
    if downsample_shape is None:
        downsample_shape = (2, ) * num_dims

    if block_idx == 0:
        # input block: no downsampling on either branch
        input_layer = layer
        final_layer = conv_block(layer, network_config, block_idx)
    else:
        # conv path downsamples via strided convolutions ...
        final_layer = downsample_conv_block(layer=layer,
                                            network_config=network_config,
                                            block_idx=block_idx,
                                            downsample_shape=downsample_shape)
        # ... and the shortcut path is pooled to the matching size
        pool_layer = get_keras_layer(type=network_config['pooling_type'],
                                     num_dims=num_dims)
        input_layer = pool_layer(pool_size=downsample_shape,
                                 data_format=data_format)(layer)

    return _merge_residual(final_layer=final_layer,
                           input_layer=input_layer,
                           data_format=data_format,
                           num_dims=num_dims,
                           kernel_init=network_config['init'],
                           padding=network_config['padding'])