コード例 #1
0
def test_leaky_relu():
    """leaky_relu must pass positives through and scale negatives by alpha."""
    alpha = tf.constant(0.1)
    # Build both inputs before both activations (same graph order as before).
    inputs = [tf.constant(1.), tf.constant(-1.)]
    expected = [1., -0.1]
    activations = [leaky_relu(inp, alpha) for inp in inputs]

    with tf.Session() as sess:
        for op, want in zip(activations, expected):
            got = sess.run(op)
            assert np.isclose(got, want), \
                'Got {} but expected {}'.format(got, want)
コード例 #2
0
ファイル: deepmedic.py プロジェクト: Mulugeta/DLTK
    def _build_normal_pathway(x):
        """Assemble the full-resolution DeepMedic pathway.

        `x` is center-cropped to ``normal_input_shape`` and then passed
        through one convolution per entry of ``normal_filters``. Every layer
        except the first is preceded by batch normalization and a PReLU (or a
        leaky ReLU with slope 0.01 when ``use_prelu`` is false). After layer
        ``idx`` a residual shortcut from ``layer_inputs[idx - 1]`` is added
        whenever ``idx + 1`` appears in ``normal_residuals``.

        Args:
            x (tf.Tensor): Tensor to crop and process.

        Returns:
            tf.Tensor: Feature map produced by the last layer.
        """
        with tf.variable_scope('normal_pathway'):
            tf.logging.info('Building normal pathway')
            cropped = crop_central_block(x, normal_input_shape)
            tf.logging.info('Input is {}'.format(
                cropped.get_shape().as_list()))

            # Input of every layer, kept so residual shortcuts can reach it.
            layer_inputs = []
            net = cropped

            for idx, n_filters in enumerate(normal_filters):
                with tf.variable_scope('layer_{}'.format(idx)):
                    layer_inputs.append(net)
                    if idx:
                        is_training = mode == tf.estimator.ModeKeys.TRAIN
                        net = tf.layers.batch_normalization(
                            net, training=is_training)
                        if use_prelu:
                            net = prelu(net)
                        else:
                            net = leaky_relu(net, 0.01)
                    net = tf.layers.conv3d(net,
                                           n_filters,
                                           normal_kernels[idx],
                                           normal_strides[idx],
                                           **conv_params)
                    # TODO: add pooling and dropout?!
                    if idx + 1 in normal_residuals:
                        net = _residual_connection(net,
                                                   layer_inputs[idx - 1])
                    tf.logging.info('Output of layer {} is {}'.format(
                        idx, net.get_shape().as_list()))
        tf.logging.info('Output is {}'.format(net.get_shape().as_list()))
        return net
    def _build_normal_pathway(x):
        """Assemble the full-resolution DeepMedic pathway.

        `x` is center-cropped to ``normal_input_shape`` and then passed
        through one convolution per entry of ``normal_filters``. Every layer
        except the first is preceded by batch normalization and a PReLU (or a
        leaky ReLU with slope 0.01 when ``use_prelu`` is false). After layer
        ``idx`` a residual shortcut from ``layer_inputs[idx - 1]`` is added
        whenever ``idx + 1`` appears in ``normal_residuals``.

        Args:
            x (tf.Tensor): Tensor to crop and process.

        Returns:
            tf.Tensor: Feature map produced by the last layer.
        """
        with tf.variable_scope('normal_pathway'):
            tf.logging.info('Building normal pathway')
            cropped = crop_central_block(x, normal_input_shape)
            tf.logging.info('Input is {}'.format(
                cropped.get_shape().as_list()))

            # Input of every layer, kept so residual shortcuts can reach it.
            layer_inputs = []
            net = cropped

            for idx, n_filters in enumerate(normal_filters):
                with tf.variable_scope('layer_{}'.format(idx)):
                    layer_inputs.append(net)
                    if idx:
                        is_training = mode == tf.estimator.ModeKeys.TRAIN
                        net = tf.layers.batch_normalization(
                            net, training=is_training)
                        if use_prelu:
                            net = prelu(net)
                        else:
                            net = leaky_relu(net, 0.01)
                    net = tf.layers.conv3d(net, n_filters,
                                           normal_kernels[idx],
                                           normal_strides[idx],
                                           **conv_params)
                    # TODO: add pooling and dropout?!
                    if idx + 1 in normal_residuals:
                        net = _residual_connection(net,
                                                   layer_inputs[idx - 1])
                    tf.logging.info('Output of layer {} is {}'.format(
                        idx,
                        net.get_shape().as_list()))
        tf.logging.info('Output is {}'.format(net.get_shape().as_list()))
        return net
コード例 #4
0
def upsample_and_concat_with_conv(inputs,
                                  inputs2,
                                  strides=(2, 2, 2),
                                  mode=tf.estimator.ModeKeys.EVAL,
                                  conv_params=global_conv_params):
    """Upsampling and concatenation layer according to [1].

    `inputs` is linearly upsampled by `strides` (via `linear_upsample_3d`,
    not a transpose convolution) and concatenated after `inputs2` along the
    channel axis. The result goes through batch normalization, a leaky ReLU
    and a 3x3x3 convolution that maps it back to the channel count of
    `inputs2`.

    [1] O. Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image
        Segmentation. MICCAI 2015.

    Args:
        inputs (tf.Tensor): Input features to be upsampled (rank 5 required).
        inputs2 (tf.Tensor): Higher resolution features from the encoder to
            concatenate. Must have the same rank as `inputs`, and (for the
            channel-axis concatenation) the same non-channel dimensions as
            the upsampled `inputs`.
        strides (tuple, optional): Upsampling factor per spatial dimension,
            forwarded to `linear_upsample_3d`.
        mode (str, optional): One of the tf.estimator.ModeKeys strings; batch
            normalization uses batch statistics only when it is TRAIN.
        conv_params (dict, optional): Extra keyword arguments forwarded to
            `tf.layers.conv3d`.

    Returns:
        tf.Tensor: Upsampled feature tensor with as many channels as
            `inputs2`.
    """
    assert len(inputs.get_shape().as_list()) == 5, \
        'inputs are required to have a rank of 5.'
    assert len(inputs.get_shape().as_list()) == len(inputs2.get_shape().as_list()), \
        'Ranks of input and input2 differ'

    # Upsample inputs
    inputs1 = linear_upsample_3d(inputs, strides)

    # Plain Python int for the filter count (`inputs2.shape[-1]` would be a
    # tf.Dimension object rather than an int).
    out_filters = inputs2.get_shape().as_list()[-1]

    with tf.variable_scope('reduce_channel_unit'):
        result = tf.concat(axis=-1, values=[inputs2, inputs1])
        result = tf.layers.batch_normalization(
            result, training=mode == tf.estimator.ModeKeys.TRAIN)
        result = leaky_relu(result)

        result = tf.layers.conv3d(inputs=result,
                                  filters=out_filters,
                                  kernel_size=(3, 3, 3),
                                  strides=(1, 1, 1),
                                  **conv_params)

    return result
    def _build_subsampled_pathways(x):
        """Build one low-resolution pathway per entry of `subsample_factors`.

        Each pathway center-crops the shared input `x` to
        ``subsampled_input_shapes[pathway]`` and downsamples it by
        ``subsample_factors[pathway]``. It then runs its own conv stack
        (BN + activation before every layer but the first, optional residual
        shortcuts) and upsamples the result again.

        Args:
            x (tf.Tensor): Input shared by all subsampled pathways.

        Returns:
            list: One output tf.Tensor per subsampled pathway.
        """
        pathways = []
        for pathway in range(len(subsample_factors)):
            with tf.variable_scope('subsampled_pathway_{}'.format(pathway)):
                tf.logging.info(
                    'Building subsampled pathway {}'.format(pathway))
                # Always crop from the shared input `x`. The working tensor
                # is `net`, so a pathway never consumes the previous
                # pathway's upsampled output.
                center_crop = crop_central_block(
                    x, subsampled_input_shapes[pathway])
                tf.logging.info('Input is {}'.format(
                    center_crop.get_shape().as_list()))

                # `subsampled_residuals` holds one residual tuple per pathway
                # (e.g. ((4, 6, 8),)). Testing an int against that nested
                # tuple is always False, so pick this pathway's entry. A
                # flat tuple is still accepted and shared by all pathways.
                residuals = subsampled_residuals
                if len(residuals) > 0 and hasattr(residuals[0], '__len__'):
                    residuals = residuals[pathway]

                layers = []

                net = _downsample(center_crop, subsample_factors[pathway])
                tf.logging.info('Downsampled input is {}'.format(
                    net.get_shape().as_list()))

                for i in range(len(subsampled_filters[pathway])):
                    with tf.variable_scope('layer_{}'.format(i)):
                        layers.append(net)
                        if i > 0:
                            net = tf.layers.batch_normalization(
                                net,
                                training=mode == tf.estimator.ModeKeys.TRAIN)
                            net = (prelu(net) if use_prelu
                                   else leaky_relu(net, 0.01))
                        net = tf.layers.conv3d(net,
                                               subsampled_filters[pathway][i],
                                               subsampled_kernels[pathway][i],
                                               subsampled_strides[pathway][i],
                                               **conv_params)
                        # TODO: add pooling and dropout?!
                        if i + 1 in residuals:
                            net = _residual_connection(net, layers[i - 1])
                        tf.logging.info('Output of layer {} is {}'.format(
                            i,
                            net.get_shape().as_list()))

                net = _upsample(net, subsample_factors[pathway])
                tf.logging.info('Output is {}'.format(
                    net.get_shape().as_list()))
                pathways.append(net)
        return pathways
コード例 #6
0
ファイル: deepmedic.py プロジェクト: Mulugeta/DLTK
    def _build_subsampled_pathways(x):
        """Build one low-resolution pathway per entry of `subsample_factors`.

        Each pathway center-crops the shared input `x` to
        ``subsampled_input_shapes[pathway]`` and downsamples it by
        ``subsample_factors[pathway]``. It then runs its own conv stack
        (BN + activation before every layer but the first, optional residual
        shortcuts) and upsamples the result again.

        Args:
            x (tf.Tensor): Input shared by all subsampled pathways.

        Returns:
            list: One output tf.Tensor per subsampled pathway.
        """
        pathways = []
        for pathway in range(len(subsample_factors)):
            with tf.variable_scope('subsampled_pathway_{}'.format(pathway)):
                tf.logging.info(
                    'Building subsampled pathway {}'.format(pathway))
                # Always crop from the shared input `x`. The working tensor
                # is `net`, so a pathway never consumes the previous
                # pathway's upsampled output.
                center_crop = crop_central_block(
                    x, subsampled_input_shapes[pathway])
                tf.logging.info('Input is {}'.format(
                    center_crop.get_shape().as_list()))

                # `subsampled_residuals` holds one residual tuple per pathway
                # (e.g. ((4, 6, 8),)). Testing an int against that nested
                # tuple is always False, so pick this pathway's entry. A
                # flat tuple is still accepted and shared by all pathways.
                residuals = subsampled_residuals
                if len(residuals) > 0 and hasattr(residuals[0], '__len__'):
                    residuals = residuals[pathway]

                layers = []

                net = _downsample(center_crop, subsample_factors[pathway])
                tf.logging.info('Downsampled input is {}'.format(
                    net.get_shape().as_list()))

                for i in range(len(subsampled_filters[pathway])):
                    with tf.variable_scope('layer_{}'.format(i)):
                        layers.append(net)
                        if i > 0:
                            net = tf.layers.batch_normalization(
                                net,
                                training=mode == tf.estimator.ModeKeys.TRAIN)
                            net = (prelu(net) if use_prelu
                                   else leaky_relu(net, 0.01))
                        net = tf.layers.conv3d(net,
                                               subsampled_filters[pathway][i],
                                               subsampled_kernels[pathway][i],
                                               subsampled_strides[pathway][i],
                                               **conv_params)
                        # TODO: add pooling and dropout?!
                        if i + 1 in residuals:
                            net = _residual_connection(net, layers[i - 1])
                        tf.logging.info('Output of layer {} is {}'.format(
                            i, net.get_shape().as_list()))

                net = _upsample(net, subsample_factors[pathway])
                tf.logging.info('Output is {}'.format(
                    net.get_shape().as_list()))
                pathways.append(net)
        return pathways
コード例 #7
0
def upsample_and_conv(inputs,
                      strides=(2, 2, 2),
                      mode=tf.estimator.ModeKeys.EVAL,
                      conv_params=global_conv_params,
                      name=None):
    """Linearly upsample `inputs`, then halve its channel count with a conv.

    The upsampled tensor goes through batch normalization, a leaky ReLU and a
    3x3x3 convolution with half as many filters as `inputs` has channels.

    Args:
        inputs (tf.Tensor): Input features (rank 5 required).
        strides (tuple, optional): Upsampling factor per spatial dimension,
            forwarded to `linear_upsample_3d`.
        mode (str, optional): One of the tf.estimator.ModeKeys strings; batch
            normalization uses batch statistics only when it is TRAIN.
        conv_params (dict, optional): Extra keyword arguments forwarded to
            `tf.layers.conv3d`.
        name (str, optional): Suffix of the variable scope
            ``'reduce_channel_unit_{name}'``. ``None`` yields the literal
            suffix ``'None'``, so pass distinct names when calling this more
            than once in the same scope.

    Returns:
        tf.Tensor: Upsampled tensor with ``channels(inputs) // 2`` channels.
    """
    assert len(inputs.get_shape().as_list()) == 5, \
        'inputs are required to have a rank of 5.'

    # Integer division is required: under Python 3 `/` yields a float, which
    # is not a valid filter count (and is fractional for odd channel counts).
    out_filters = inputs.get_shape().as_list()[-1] // 2

    # Upsample inputs
    with tf.variable_scope('reduce_channel_unit_{}'.format(name)):
        inputs1 = linear_upsample_3d(inputs, strides)

        result = tf.layers.batch_normalization(
            inputs1, training=mode == tf.estimator.ModeKeys.TRAIN)
        result = leaky_relu(result)

        result = tf.layers.conv3d(inputs=result,
                                  filters=out_filters,
                                  kernel_size=(3, 3, 3),
                                  strides=(1, 1, 1),
                                  **conv_params)

    return result
def lrelu(x):
    """Apply `leaky_relu` with a fixed negative slope of 0.1."""
    negative_slope = 0.1
    return leaky_relu(x, negative_slope)
コード例 #9
0
def dcgan_discriminator_3d(inputs,
                           filters=(64, 128, 256, 512),
                           strides=((2, 2, 2), (2, 2, 2), (1, 2, 2),
                                    (1, 2, 2)),
                           mode=tf.estimator.ModeKeys.EVAL,
                           use_bias=False):
    """
    Deep convolutional generative adversarial network (DCGAN) discriminator
    network with len(filters) resolution scales. On each resolution scale s
    a single 3x3x3 convolution with filters[s] filters and stride strides[s]
    (which performs the downsampling) is followed by batch normalization and
    a leaky ReLU with slope 0.2. The last feature map is flattened and mapped
    to a single logit by a dense layer.

    Args:
        inputs (tf.Tensor): Input tensor to the network, required to be of
            rank 5.
        filters (tuple, optional): Number of filters of the convolution at
            each resolution scale.
        strides (tuple, optional): Stride of the convolution at each
            resolution scale. Must have the same length as `filters`.
        mode (TYPE, optional): One of the tf.estimator.ModeKeys strings: TRAIN,
            EVAL or PREDICT.
        use_bias (bool, optional): Boolean, whether the convolutions use a
            bias. The final dense layer always uses one.

    Returns:
        dict: dictionary of output tensors: 'logits' (raw dense output),
            'probs' (sigmoid of the logits) and 'pred' (int32, 1 where
            'probs' > 0.5).

    Note: flattening uses the static shape of the last feature map, so all of
    its non-batch dimensions must be known at graph construction time.
    """
    outputs = {}
    assert len(strides) == len(filters)
    assert len(inputs.get_shape().as_list()) == 5,\
        'inputs are required to have a rank of 5.'

    conv_op = tf.layers.conv3d

    conv_params = {
        'padding': 'same',
        'use_bias': use_bias,
        'kernel_initializer': tf.uniform_unit_scaling_initializer(),
        'bias_initializer': tf.zeros_initializer(),
        'kernel_regularizer': None,
        'bias_regularizer': None
    }

    x = inputs
    tf.logging.info('Input tensor shape {}'.format(x.get_shape()))

    for res_scale in range(0, len(filters)):
        with tf.variable_scope('disc_unit_{}'.format(res_scale)):

            x = conv_op(inputs=x,
                        filters=filters[res_scale],
                        kernel_size=(3, 3, 3),
                        strides=strides[res_scale],
                        **conv_params)

            x = tf.layers.batch_normalization(
                x, training=mode == tf.estimator.ModeKeys.TRAIN)

            x = leaky_relu(x, 0.2)

    # Flatten everything but the batch dimension for the dense layer.
    x_shape = x.get_shape().as_list()
    x = tf.reshape(x, (tf.shape(x)[0], np.prod(x_shape[1:])))

    x = tf.layers.dense(inputs=x,
                        units=1,
                        use_bias=True,
                        kernel_initializer=conv_params['kernel_initializer'],
                        bias_initializer=conv_params['bias_initializer'],
                        kernel_regularizer=conv_params['kernel_regularizer'],
                        bias_regularizer=conv_params['bias_regularizer'],
                        name='out')

    outputs['logits'] = x

    outputs['probs'] = tf.nn.sigmoid(x)

    # Threshold the probability (equivalently logits > 0), not the raw
    # logits: `x > 0.5` would only fire for probabilities above
    # sigmoid(0.5) ~= 0.62.
    outputs['pred'] = tf.cast(outputs['probs'] > 0.5, tf.int32)

    return outputs
コード例 #10
0
def dcgan_generator_3d(inputs,
                       filters=(256, 128, 64, 32, 1),
                       kernel_size=((4, 4, 4), (3, 3, 3), (3, 3, 3), (3, 3, 3),
                                    (4, 4, 4)),
                       strides=((4, 4, 4), (1, 2, 2), (1, 2, 2), (1, 2, 2),
                                (1, 2, 2)),
                       mode=tf.estimator.ModeKeys.TRAIN,
                       use_bias=False):
    """
    Deep convolutional generative adversarial network (DCGAN) generator
    network with len(filters) resolution scales. On each resolution scale s
    the features are upsampled by a factor strides[s] with a trainable
    `linear_upsample_3d`. This is followed by a convolution with filters[s]
    filters and kernel size kernel_size[s], batch normalization and a leaky
    ReLU with slope 0.2.

    Args:
        inputs (tf.Tensor): Input noise tensor to the network, required to be
            of rank 5.
        filters (tuple, optional): Number of filters of the convolution at
            each resolution scale. The last entry is the number of output
            channels.
        kernel_size (tuple, optional): Convolution kernel size at each
            resolution scale; needs at least len(filters) entries.
        strides (tuple, optional): Upsampling factor at each resolution
            scale. Must have the same length as `filters`.
        mode (TYPE, optional): One of the tf.estimator.ModeKeys strings: TRAIN,
            EVAL or PREDICT
        use_bias (bool, optional): Boolean, whether the convolutions use a
            bias.

    Returns:
        dict: dictionary of output tensors; 'gen' holds the generated volume.

    Note: the last resolution scale also applies batch normalization and the
    leaky ReLU, so the output is not squashed into a bounded range.
    """
    outputs = {}
    assert len(strides) == len(filters)
    assert len(inputs.get_shape().as_list()) == 5, \
        'inputs are required to have a rank of 5.'

    conv_op = tf.layers.conv3d

    conv_params = {
        'padding': 'same',
        'use_bias': use_bias,
        'kernel_initializer': tf.uniform_unit_scaling_initializer(),
        'bias_initializer': tf.zeros_initializer(),
        'kernel_regularizer': None,
        'bias_regularizer': None
    }

    x = inputs
    tf.logging.info('Input tensor shape {}'.format(x.get_shape()))

    for res_scale in range(0, len(filters)):
        with tf.variable_scope('gen_unit_{}'.format(res_scale)):

            tf.logging.info('Generator at res_scale before up {} tensor '
                            'shape: {}'.format(res_scale, x.get_shape()))

            # Upsampling happens here; the convolution itself keeps the size.
            x = linear_upsample_3d(x, strides[res_scale], trainable=True)

            x = conv_op(inputs=x,
                        filters=filters[res_scale],
                        kernel_size=kernel_size[res_scale],
                        **conv_params)

            tf.logging.info('Generator at res_scale after up {} tensor '
                            'shape: {}'.format(res_scale, x.get_shape()))

            x = tf.layers.batch_normalization(
                x, training=mode == tf.estimator.ModeKeys.TRAIN)

            x = leaky_relu(x, 0.2)
            tf.logging.info('Generator at res_scale {} tensor shape: '
                            '{}'.format(res_scale, x.get_shape()))

    outputs['gen'] = x

    return outputs
コード例 #11
0
ファイル: deepmedic.py プロジェクト: Mulugeta/DLTK
def deepmedic_3d(inputs, num_classes,
                 normal_filters=(30, 30, 40, 40, 40, 40, 50, 50),
                 normal_strides=((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1),
                                 (1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
                 normal_kernels=((3, 3, 3), (3, 3, 3), (3, 3, 3), (3, 3, 3),
                                 (3, 3, 3), (3, 3, 3), (3, 3, 3), (3, 3, 3)),
                 normal_residuals=(4, 6, 8),
                 normal_input_shape=(25, 25, 25),
                 subsampled_filters=((30, 30, 40, 40, 40, 40, 50, 50),),
                 subsampled_strides=(((1, 1, 1), (1, 1, 1), (1, 1, 1),
                                      (1, 1, 1), (1, 1, 1), (1, 1, 1),
                                      (1, 1, 1), (1, 1, 1)),),
                 subsampled_kernels=(((3, 3, 3), (3, 3, 3), (3, 3, 3),
                                      (3, 3, 3), (3, 3, 3), (3, 3, 3),
                                      (3, 3, 3), (3, 3, 3)),),
                 subsampled_residuals=((4, 6, 8),),
                 subsampled_input_shapes=((57, 57, 57),),
                 subsample_factors=((3, 3, 3),),
                 fc_filters=(150, 150),
                 first_fc_kernel=(3, 3, 3),
                 fc_residuals=(2, ),
                 padding='VALID',
                 use_prelu=True,
                 mode=tf.estimator.ModeKeys.EVAL,
                 use_bias=True,
                 kernel_initializer=tf.initializers.variance_scaling(distribution='uniform'),
                 bias_initializer=tf.zeros_initializer(),
                 kernel_regularizer=None,
                 bias_regularizer=None):
    """
    Image segmentation network based on a DeepMedic architecture [1, 2].
    The architecture uses multiple processing paths with different
    resolutions: a normal (full-resolution) pathway and one or more
    subsampled pathways. Subsampled pathways average-pool their input,
    process it with their own conv stack and linearly upsample the result.
    All pathways are center-cropped to the size of the normal pathway,
    concatenated along the channel axis and fed to the convolutional fc
    layers.

    [1] Konstantinos Kamnitsas et al. Efficient Multi-Scale 3D CNN with Fully
        Connected CRF for Accurate Brain Lesion Segmentation. Medical Image
        Analysis, 2016.
    [2] Konstantinos Kamnitsas et al. Multi-Scale 3D CNNs for segmentation of
        brain Lesions in multi-modal MRI. ISLES challenge, MICCAI 2015.

    Note: We are currently using bilinear upsampling whereas the original
    implementation (https://github.com/Kamnitsask/deepmedic) uses repeat
    upsampling.

    Args:
        inputs (tf.Tensor): Input feature tensor to the network (rank 5
            required).
        num_classes (int): Number of output classes.
        normal_filters (array_like, optional): Number of filters for each layer
            for normal path.
        normal_strides (array_like, optional): Strides for each layer for
            normal path.
        normal_kernels (array_like, optional): Kernel size for each layer for
            normal path.
        normal_residuals (array_like, optional): 1-based indices of the layers
            after which a residual connection is added in the normal path.
        normal_input_shape (array_like, optional): Shape of input to normal
            path. Input to the network is center cropped to this shape.
        subsampled_filters (array_like, optional): Number of filters for each
            layer for each subsampled path.
        subsampled_strides (array_like, optional): Strides for each layer for
            each subsampled path.
        subsampled_kernels (array_like, optional): Kernel size for each layer
            for each subsampled path.
        subsampled_residuals (array_like, optional): Location of residual
            connections for each subsampled path (one tuple per path; a flat
            tuple is shared by all paths).
        subsampled_input_shapes (array_like, optional): Shape of input to
            subsampled paths. Input to the network is center cropped to this
            shape and then downsampled.
        subsample_factors (array_like, optional): Downsampling factors for
            each subsampled path.
        fc_filters (array_like, optional): Number of filters for the fc layers.
        first_fc_kernel (array_like, optional): Shape of the kernel of the
            first fc layer.
        fc_residuals (array_like, optional): Location of residual connections
            for the fc layers.
        padding (string, optional): Type of padding used for convolutions.
            Standard is `VALID`
        use_prelu (bool, optional): Flag to enable PReLU activation.
            Alternatively leaky ReLU is used. Defaults to `True`.
        mode (TYPE, optional): One of the tf.estimator.ModeKeys strings: TRAIN,
            EVAL or PREDICT
        use_bias (bool, optional): Boolean, whether the layer uses a bias. The
            final classification convolution always uses one.
        kernel_initializer (TYPE, optional): An initializer for the convolution
            kernel.
        bias_initializer (TYPE, optional): An initializer for the bias vector.
        kernel_regularizer (None, optional): Optional regularizer for the
            convolution kernel.
        bias_regularizer (None, optional): Optional regularizer for the bias
            vector.

    Returns:
        dict: dictionary of output tensors: 'logits' (num_classes channels),
            'y_prob' (softmax of the logits) and 'y_' (argmax over the last
            axis).

    """
    outputs = {}
    assert len(normal_filters) == len(normal_strides)
    assert len(normal_filters) == len(normal_kernels)
    assert len(inputs.get_shape().as_list()) == 5, \
        'inputs are required to have a rank of 5.'

    conv_params = {'use_bias': use_bias,
                   'kernel_initializer': kernel_initializer,
                   'bias_initializer': bias_initializer,
                   'kernel_regularizer': kernel_regularizer,
                   'bias_regularizer': bias_regularizer,
                   'padding': padding}

    def _residual_connection(x, prev_x):
        """Add `prev_x` onto the leading channels of `x`.

        `prev_x` is center-cropped to the spatial size of `x` and zero-padded
        along the channel axis up to the channel count of `x`, so it must not
        have more channels than `x`.
        """
        # crop previous to current size:
        prev_x = crop_central_block(prev_x, x.get_shape().as_list()[1:-1])

        # add prev_x to first channels of x
        to_pad = [[0, 0]] * (len(x.get_shape().as_list()) - 1)
        to_pad += [[0, x.get_shape().as_list()[-1] -
                    prev_x.get_shape().as_list()[-1]]]
        prev_x = tf.pad(prev_x, to_pad)

        return x + prev_x

    def _build_normal_pathway(x):
        """Build the full-resolution pathway on a center crop of `x`."""
        with tf.variable_scope('normal_pathway'):
            tf.logging.info('Building normal pathway')
            center_crop = crop_central_block(x, normal_input_shape)
            tf.logging.info('Input is {}'.format(
                center_crop.get_shape().as_list()))

            layers = []

            x = center_crop
            for i in range(len(normal_filters)):
                with tf.variable_scope('layer_{}'.format(i)):
                    layers.append(x)
                    if i > 0:
                        x = tf.layers.batch_normalization(
                            x, training=mode == tf.estimator.ModeKeys.TRAIN)
                        x = prelu(x) if use_prelu else leaky_relu(x, 0.01)
                    x = tf.layers.conv3d(x,
                                         normal_filters[i],
                                         normal_kernels[i],
                                         normal_strides[i],
                                         **conv_params)
                    # TODO: add pooling and dropout?!
                    if i + 1 in normal_residuals:
                        x = _residual_connection(x, layers[i - 1])
                    tf.logging.info('Output of layer {} is {}'.format(
                        i, x.get_shape().as_list()))
        tf.logging.info('Output is {}'.format(x.get_shape().as_list()))
        return x

    def _downsample(x, factor):
        """Average-pool `x` with window and stride equal to `factor`."""
        if isinstance(factor, int):
            factor = [factor] * (len(x.get_shape().as_list()) - 2)
        pool_func = tf.nn.avg_pool3d

        factor = list(factor)

        x = pool_func(x, [1, ] + factor + [1, ], [1, ] + factor + [1, ],
                      'VALID')
        return x

    def _build_subsampled_pathways(x):
        """Build one low-resolution pathway per entry of `subsample_factors`.

        Returns a list with one output tensor per subsampled pathway.
        """
        pathways = []
        for pathway in range(len(subsample_factors)):
            with tf.variable_scope('subsampled_pathway_{}'.format(pathway)):
                tf.logging.info(
                    'Building subsampled pathway {}'.format(pathway))
                # Always crop from the shared input `x`. The working tensor
                # is `net`, so a pathway never consumes the previous
                # pathway's upsampled output.
                center_crop = crop_central_block(
                    x, subsampled_input_shapes[pathway])
                tf.logging.info('Input is {}'.format(
                    center_crop.get_shape().as_list()))

                # `subsampled_residuals` holds one residual tuple per pathway
                # (e.g. ((4, 6, 8),)). Testing an int against that nested
                # tuple is always False, so pick this pathway's entry. A
                # flat tuple is still accepted and shared by all pathways.
                residuals = subsampled_residuals
                if len(residuals) > 0 and hasattr(residuals[0], '__len__'):
                    residuals = residuals[pathway]

                layers = []

                net = _downsample(center_crop, subsample_factors[pathway])
                tf.logging.info('Downsampled input is {}'.format(
                    net.get_shape().as_list()))

                for i in range(len(subsampled_filters[pathway])):
                    with tf.variable_scope('layer_{}'.format(i)):
                        layers.append(net)
                        if i > 0:
                            net = tf.layers.batch_normalization(
                                net,
                                training=mode == tf.estimator.ModeKeys.TRAIN)
                            net = (prelu(net) if use_prelu
                                   else leaky_relu(net, 0.01))
                        net = tf.layers.conv3d(net,
                                               subsampled_filters[pathway][i],
                                               subsampled_kernels[pathway][i],
                                               subsampled_strides[pathway][i],
                                               **conv_params)
                        # TODO: add pooling and dropout?!
                        if i + 1 in residuals:
                            net = _residual_connection(net, layers[i - 1])
                        tf.logging.info('Output of layer {} is {}'.format(
                            i, net.get_shape().as_list()))

                net = _upsample(net, subsample_factors[pathway])
                tf.logging.info('Output is {}'.format(
                    net.get_shape().as_list()))
                pathways.append(net)
        return pathways

    def _upsample(x, factor):
        """Linearly upsample `x` by `factor` per spatial dimension."""
        if isinstance(factor, int):
            factor = [factor] * (len(x.get_shape().as_list()) - 2)

        # TODO: build repeat upsampling

        x = linear_upsample_3d(x, strides=factor)
        return x

    x = inputs

    normal = _build_normal_pathway(x)
    pathways = _build_subsampled_pathways(x)

    # Crop every subsampled pathway to the spatial size of the normal one so
    # that all of them can be concatenated along the channel axis.
    normal_shape = normal.get_shape().as_list()[1:-1]
    paths = [normal]
    for pathway_out in pathways:
        paths.append(crop_central_block(pathway_out, normal_shape))

    x = tf.concat(paths, -1)

    layers = []
    for i in range(len(fc_filters)):
        with tf.variable_scope('fc_{}'.format(i)):
            layers.append(x)
            if i == 0 and any(k > 1 for k in first_fc_kernel):
                # Symmetrically pad the spatial dims by (kernel - 1) so that,
                # with the default VALID padding, the first fc convolution
                # keeps the spatial size.
                x_shape = x.get_shape().as_list()
                # CAUTION: https://docs.python.org/2/faq/programming.html#how-do-i-create-a-multidimensional-list
                x_pad = [[0, 0] for _ in range(len(x_shape))]
                for j in range(len(first_fc_kernel)):
                    to_pad = (first_fc_kernel[j] - 1)
                    x_pad[j + 1][0] = to_pad // 2
                    x_pad[j + 1][1] = to_pad - x_pad[j + 1][0]
                tf.logging.info('Padding input of first fc layer with '
                                '{}'.format(x_pad))
                x = tf.pad(x, x_pad, mode='SYMMETRIC')

            x = tf.layers.batch_normalization(
                x, training=mode == tf.estimator.ModeKeys.TRAIN)
            x = prelu(x) if use_prelu else leaky_relu(x, 0.01)
            x = tf.layers.conv3d(x, fc_filters[i],
                                 first_fc_kernel if i == 0 else 1,
                                 **conv_params)
            if i + 1 in fc_residuals:
                x = _residual_connection(x, layers[i - 1])

    with tf.variable_scope('last'):
        x = tf.layers.batch_normalization(
            x, training=mode == tf.estimator.ModeKeys.TRAIN)
        x = prelu(x) if use_prelu else leaky_relu(x, 0.01)
        # The classification layer always gets a bias, regardless of
        # `use_bias`.
        conv_params['use_bias'] = True
        x = tf.layers.conv3d(x, num_classes, 1, **conv_params)

    outputs['logits'] = x
    tf.logging.info('last conv shape %s', x.get_shape())

    with tf.variable_scope('pred'):
        y_prob = tf.nn.softmax(x)
        outputs['y_prob'] = y_prob
        y_ = tf.argmax(x, axis=-1)
        outputs['y_'] = y_

    return outputs
def deepmedic_3d(inputs,
                 num_classes,
                 normal_filters=(30, 30, 40, 40, 40, 40, 50, 50),
                 normal_strides=((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1),
                                 (1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
                 normal_kernels=((3, 3, 3), (3, 3, 3), (3, 3, 3), (3, 3, 3),
                                 (3, 3, 3), (3, 3, 3), (3, 3, 3), (3, 3, 3)),
                 normal_residuals=(4, 6, 8),
                 normal_input_shape=(25, 25, 25),
                 subsampled_filters=((30, 30, 40, 40, 40, 40, 50, 50), ),
                 subsampled_strides=(((1, 1, 1), (1, 1, 1), (1, 1, 1),
                                      (1, 1, 1), (1, 1, 1), (1, 1, 1),
                                      (1, 1, 1), (1, 1, 1)), ),
                 subsampled_kernels=(((3, 3, 3), (3, 3, 3), (3, 3, 3),
                                      (3, 3, 3), (3, 3, 3), (3, 3, 3),
                                      (3, 3, 3), (3, 3, 3)), ),
                 subsampled_residuals=((4, 6, 8), ),
                 subsampled_input_shapes=((57, 57, 57), ),
                 subsample_factors=((3, 3, 3), ),
                 fc_filters=(150, 150),
                 first_fc_kernel=(3, 3, 3),
                 fc_residuals=(2, ),
                 padding='VALID',
                 use_prelu=True,
                 mode=tf.estimator.ModeKeys.EVAL,
                 use_bias=True,
                 kernel_initializer=tf.initializers.variance_scaling(
                     distribution='uniform'),
                 bias_initializer=tf.zeros_initializer(),
                 kernel_regularizer=None,
                 bias_regularizer=None):
    """
    Image segmentation network based on a DeepMedic architecture [1, 2].
    The architecture uses multiple processing paths with different
    resolutions: a normal pathway operating on a centre crop of the input,
    and one or more subsampled pathways that average-pool a larger centre
    crop before convolving it (strided convolutions can additionally be
    configured per layer). The subsampled pathways are upsampled again, all
    pathways are centre-cropped to the spatial size of the normal pathway,
    concatenated along the channel axis and fed to the convolutional fc
    layers.

    [1] Konstantinos Kamnitsas et al. Efficient Multi-Scale 3D CNN with Fully
        Connected CRF for Accurate Brain Lesion Segmentation. Medical Image
        Analysis, 2016.
    [2] Konstantinos Kamnitsas et al. Multi-Scale 3D CNNs for segmentation of
        brain Lesions in multi-modal MRI. ISLES challenge, MICCAI 2015.

    Note: We are currently using linear upsampling (`linear_upsample_3d`)
    whereas the original implementation
    (https://github.com/Kamnitsask/deepmedic) uses repeat upsampling.

    Args:
        inputs (tf.Tensor): Input feature tensor to the network (rank 5
            required). Axes 1..3 are treated as spatial and the last axis as
            channels.
        num_classes (int): Number of output classes.
        normal_filters (array_like, optional): Number of filters for each layer
            for normal path.
        normal_strides (array_like, optional): Strides for each layer for
            normal path.
        normal_kernels (array_like, optional): Kernel size for each layer for
            normal path.
        normal_residuals (array_like, optional): 1-based layer indices at
            which residual connections are added in the normal path. A value
            `r` adds the input of layer `r - 1` (centre-cropped and
            zero-padded along channels) to the output of layer `r`, i.e. the
            shortcut spans two convolutions.
        normal_input_shape (array_like, optional): Shape of input to normal
            path. Input to the network is center cropped to this shape.
        subsampled_filters (array_like, optional): Number of filters for each
            layer for each subsampled path.
        subsampled_strides (array_like, optional): Strides for each layer for
            each subsampled path.
        subsampled_kernels (array_like, optional): Kernel size for each layer
            for each subsampled path.
        subsampled_residuals (array_like, optional): Residual connection
            locations for each subsampled path (same semantics as
            `normal_residuals`). A flat sequence of ints is applied to every
            subsampled path.
        subsampled_input_shapes (array_like, optional): Shape of input to
            subsampled paths. Input to the network is center cropped to this
            shape and then downsampled by average pooling.
        subsample_factors (array_like, optional): Downsampling factors for
            each subsampled path (an int or one factor per spatial dim).
        fc_filters (array_like, optional): Number of filters for the fc layers.
        first_fc_kernel (array_like, optional): Shape of the kernel of the
            first fc layer. If any extent is > 1, the fc input is first
            symmetrically padded by `kernel - 1` voxels per spatial dim.
        fc_residuals (array_like, optional): Location of residual connections
            for the fc layers (same semantics as `normal_residuals`).
        padding (string, optional): Type of padding used for convolutions.
            Standard is `VALID`
        use_prelu (bool, optional): Flag to enable PReLU activation.
            Alternatively leaky ReLU (slope 0.01) is used. Defaults to `True`.
        mode (str, optional): One of the tf.estimator.ModeKeys strings: TRAIN,
            EVAL or PREDICT. Batch normalisation runs in training mode only
            for TRAIN.
        use_bias (bool, optional): Whether the convolutions use a bias. The
            final classifier convolution always uses a bias.
        kernel_initializer (optional): An initializer for the convolution
            kernel.
        bias_initializer (optional): An initializer for the bias vector.
        kernel_regularizer (None, optional): Optional regularizer for the
            convolution kernel.
        bias_regularizer (None, optional): Optional regularizer for the bias
            vector.

    Returns:
        dict: dictionary of output tensors with keys
            'logits': output of the final 1x1x1 convolution
                (`num_classes` channels),
            'y_prob': softmax of the logits over the last axis,
            'y_': argmax of the logits over the last axis.

    """
    import numbers

    outputs = {}
    assert len(normal_filters) == len(normal_strides)
    assert len(normal_filters) == len(normal_kernels)
    assert len(inputs.get_shape().as_list()) == 5, \
        'inputs are required to have a rank of 5.'

    conv_params = {
        'use_bias': use_bias,
        'kernel_initializer': kernel_initializer,
        'bias_initializer': bias_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
        'padding': padding
    }

    # Batch normalisation uses batch statistics only while training.
    training = mode == tf.estimator.ModeKeys.TRAIN

    def _activation(x):
        # PReLU if requested, otherwise a leaky ReLU with slope 0.01.
        return prelu(x) if use_prelu else leaky_relu(x, 0.01)

    def _residual_connection(x, prev_x):
        """Add `prev_x` to the first channels of `x`.

        `prev_x` is centre-cropped to the spatial size of `x` and zero-padded
        at the end of its channel axis up to the channel count of `x`.
        Assumes `x` has at least as many channels as `prev_x`.
        """
        x_shape = x.get_shape().as_list()
        prev_x = crop_central_block(prev_x, x_shape[1:-1])
        channel_pad = x_shape[-1] - prev_x.get_shape().as_list()[-1]
        to_pad = [[0, 0]] * (len(x_shape) - 1) + [[0, channel_pad]]
        return x + tf.pad(prev_x, to_pad)

    def _conv_layers(x, filters, kernels, strides, residuals):
        """Apply a stack of `[BN -> activation ->] conv3d` layers.

        The first layer is a bare convolution; every later layer is
        pre-activated with batch normalisation and the activation. Each
        layer lives in its own `layer_{i}` variable scope.
        """
        layers = []
        for i, n_filters in enumerate(filters):
            with tf.variable_scope('layer_{}'.format(i)):
                layers.append(x)
                if i > 0:
                    x = tf.layers.batch_normalization(x, training=training)
                    x = _activation(x)
                x = tf.layers.conv3d(x, n_filters, kernels[i], strides[i],
                                     **conv_params)
                # TODO: add pooling and dropout?!
                if i + 1 in residuals:
                    # layers[i - 1] is the input of layer i - 1, so the
                    # shortcut skips two convolutions.
                    x = _residual_connection(x, layers[i - 1])
                tf.logging.info('Output of layer {} is {}'.format(
                    i, x.get_shape().as_list()))
        return x

    def _build_normal_pathway(x):
        with tf.variable_scope('normal_pathway'):
            tf.logging.info('Building normal pathway')
            x = crop_central_block(x, normal_input_shape)
            tf.logging.info('Input is {}'.format(x.get_shape().as_list()))
            x = _conv_layers(x, normal_filters, normal_kernels,
                             normal_strides, normal_residuals)
        tf.logging.info('Output is {}'.format(x.get_shape().as_list()))
        return x

    def _spatial_factor(x, factor):
        # Broadcast an int factor to every spatial dim of a rank-5 tensor.
        if isinstance(factor, int):
            return [factor] * (len(x.get_shape().as_list()) - 2)
        return list(factor)

    def _downsample(x, factor):
        # Non-overlapping average pooling by `factor` per spatial dim.
        window = [1] + _spatial_factor(x, factor) + [1]
        return tf.nn.avg_pool3d(x, window, window, 'VALID')

    def _upsample(x, factor):
        # TODO: build repeat upsampling
        return linear_upsample_3d(x, strides=_spatial_factor(x, factor))

    def _pathway_residuals(pathway):
        # `subsampled_residuals` is one collection per pathway; a flat
        # collection of ints is also accepted and shared by all pathways.
        if all(isinstance(r, numbers.Integral) for r in subsampled_residuals):
            return subsampled_residuals
        return subsampled_residuals[pathway]

    def _build_subsampled_pathways(x):
        pathways = []
        for pathway, factor in enumerate(subsample_factors):
            with tf.variable_scope('subsampled_pathway_{}'.format(pathway)):
                tf.logging.info(
                    'Building subsampled pathway {}'.format(pathway))
                # Use a separate name so every pathway crops the network
                # input `x`, not the previous pathway's output.
                y = crop_central_block(x, subsampled_input_shapes[pathway])
                tf.logging.info('Input is {}'.format(y.get_shape().as_list()))

                y = _downsample(y, factor)
                tf.logging.info('Downsampled input is {}'.format(
                    y.get_shape().as_list()))

                y = _conv_layers(y, subsampled_filters[pathway],
                                 subsampled_kernels[pathway],
                                 subsampled_strides[pathway],
                                 _pathway_residuals(pathway))

                y = _upsample(y, factor)
                tf.logging.info('Output is {}'.format(y.get_shape().as_list()))
                pathways.append(y)
        return pathways

    normal = _build_normal_pathway(inputs)
    pathways = _build_subsampled_pathways(inputs)

    # Bring every pathway to the spatial size of the normal pathway and
    # stack them along the channel axis.
    normal_shape = normal.get_shape().as_list()[1:-1]
    paths = [normal] + [crop_central_block(p, normal_shape) for p in pathways]
    x = tf.concat(paths, -1)

    layers = []
    for i, n_filters in enumerate(fc_filters):
        with tf.variable_scope('fc_{}'.format(i)):
            layers.append(x)
            if i == 0 and any(k > 1 for k in first_fc_kernel):
                # Pad the spatial dims by (kernel - 1) so that, with the
                # default 'VALID' padding, the first fc convolution
                # (stride 1) preserves the spatial size.
                # CAUTION: https://docs.python.org/2/faq/programming.html#how-do-i-create-a-multidimensional-list
                x_pad = [[0, 0] for _ in range(len(x.get_shape().as_list()))]
                for j, k in enumerate(first_fc_kernel):
                    to_pad = k - 1
                    x_pad[j + 1] = [to_pad // 2, to_pad - to_pad // 2]
                tf.logging.info('Padding first fc input with {}'.format(x_pad))
                x = tf.pad(x, x_pad, mode='SYMMETRIC')

            x = tf.layers.batch_normalization(x, training=training)
            x = _activation(x)
            x = tf.layers.conv3d(x, n_filters,
                                 first_fc_kernel if i == 0 else 1,
                                 **conv_params)
            if i + 1 in fc_residuals:
                x = _residual_connection(x, layers[i - 1])

    with tf.variable_scope('last'):
        x = tf.layers.batch_normalization(x, training=training)
        x = _activation(x)
        # The classifier layer always has a bias, regardless of `use_bias`.
        x = tf.layers.conv3d(x, num_classes, 1,
                             **dict(conv_params, use_bias=True))

    outputs['logits'] = x
    tf.logging.info('last conv shape %s', x.get_shape())

    with tf.variable_scope('pred'):
        outputs['y_prob'] = tf.nn.softmax(x)
        outputs['y_'] = tf.argmax(x, axis=-1)

    return outputs