Example #1
def single_hourglass(output_channels,
                     n_levels=4,
                     channels=32,
                     channels_growth=2,
                     spatial_dims=2,
                     spacing=1,
                     norm_groups=4):
    '''Combines an hourglass block with input/output blocks to
    increase/decrease the number of channels.

    Notes:
    Expects the input to be already padded.
    '''

    Conv = get_nd_conv(spatial_dims)

    hglass = _stack_layers([
        Conv(channels, kernel_size=1, padding='same'),
        bottleneck_conv_block(channels, spatial_dims, norm_groups),
        hourglass_block(n_levels, channels, channels_growth, spatial_dims,
                        spacing, norm_groups),
        bottleneck_conv_block(channels, spatial_dims, norm_groups),
        Activation(),
        GroupNormalization(groups=norm_groups, axis=-1),
        Conv(channels, kernel_size=3, padding='same'),
        Activation(),
        GroupNormalization(groups=norm_groups, axis=-1),
        Conv(output_channels, kernel_size=1, padding='same'),
    ])

    return hglass
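A minimal usage sketch, assuming the module-level helpers referenced above (`get_nd_conv`, `_stack_layers`, `bottleneck_conv_block`, `hourglass_block`, and the no-argument `Activation` alias) are in scope, and that with `spacing=1` each of the `n_levels - 1` poolings halves the resolution:

import tensorflow as tf

# Build a 2D hourglass predicting 3 output channels.
hglass = single_hourglass(output_channels=3, n_levels=4, channels=32)

# 64 is divisible by 2**(n_levels - 1) = 8, so no extra padding is needed.
x = tf.keras.Input(shape=(64, 64, 1))
y = hglass(x)  # spatial size preserved: (None, 64, 64, 3)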
Example #2
def add_instance_seg_heads(model,
                           n_classes,
                           spacing=1.,
                           kernel_size=1,
                           class_activation=True):
    '''Attaches a semi-convolutional embeddings layer and a semantic 
    classification convolutional layer to the given model.
    
    Args:
        model: model to which output layers are added
        n_classes: number of semantic classes
        spacing: pixel/voxel spacing of the semi-conv embeddings
        kernel_size: kernel size of the appended layers
        class_activation: if True, apply a softmax/sigmoid activation to
            the semantic classification output
    '''

    spatial_dims = len(model.inputs[0].shape) - 2
    spacing = tuple(
        float(val) for val in np.broadcast_to(spacing, spatial_dims))

    if len(model.outputs) > 1:
        warnings.warn(
            'The model has {} outputs; outputs beyond the first will be '
            'ignored.'.format(len(model.outputs)))

    if not class_activation:
        activation = None
    elif n_classes > 1:
        activation = 'softmax'
    else:
        activation = 'sigmoid'

    last_layer = model.outputs[0]

    conv = get_nd_conv(spatial_dims)(n_classes,
                                     kernel_size=kernel_size,
                                     activation=activation,
                                     name='semantic_class',
                                     padding='same')

    semi_conv = get_nd_semiconv(spatial_dims)(spacing=spacing,
                                              kernel_size=kernel_size,
                                              name='embeddings',
                                              padding='same')

    semantic_class = conv(last_layer)
    embeddings = semi_conv(last_layer)

    return Model(inputs=model.inputs,
                 outputs=[embeddings, semantic_class],
                 name=model.name)
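A hedged usage sketch: attach the two heads to a backbone such as `GenericUnetBase` from Example #7, assuming both functions are importable from the same module:

backbone = GenericUnetBase(input_shape=(128, 128, 1), n_levels=4)
model = add_instance_seg_heads(backbone, n_classes=3, spacing=(1., 1.))

# The wrapped model now has two outputs: 'embeddings' and 'semantic_class'.
model.summary()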
Example #3
def rdc_block(n_groups=16,
              dilation_rates=(1, 2, 4, 8, 16),
              channels_per_group=32,
              k_size=3,
              spatial_dims=2,
              dropout=0.1):
    '''Grouped convolution block: a pointwise convolution for channel mixing
    followed by stacked dilated convolutions in each group.
    
    Notes
    -----
    Uses pre-activation to keep the residual path clear, as described in:

    He, Kaiming, et al. "Identity Mappings in Deep Residual Networks."
    European Conference on Computer Vision. Springer, Cham, 2016.
    pp. 630-645.
    '''

    Conv = get_nd_conv(spatial_dims)
    channels = channels_per_group * n_groups
    sd_conv = StackedDilatedConv(rank=spatial_dims,
                                 filters=channels,
                                 kernel_size=k_size,
                                 dilation_rates=dilation_rates,
                                 groups=n_groups,
                                 activation=LeakyReLU())

    # pointwise conv that mixes channels and reduces them from
    # input_ch + channels_per_group * n_groups back down to channels
    reduce_ch_conv = Conv(channels, 1)

    spatial_dropout = get_nd_spatial_dropout(spatial_dims)(dropout)

    def _call(x):

        x = spatial_dropout(x)
        x = LeakyReLU()(x)
        x = reduce_ch_conv(x)
        x = LeakyReLU()(x)
        x = sd_conv(x)

        return x

    return _call
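A minimal sketch of the block applied once outside the recurrent loop, assuming `StackedDilatedConv` and the `get_nd_*` helpers are in scope:

import tensorflow as tf

block = rdc_block(n_groups=4, dilation_rates=(1, 2, 4), channels_per_group=8)

x = tf.keras.Input(shape=(32, 32, 32))
y = block(x)  # spatial size preserved, 4 * 8 = 32 output channels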
Example #4
def bottleneck_conv_block(channels=32, spatial_dims=2, norm_groups=4):
    '''Residual bottleneck convolution block (1x1 -> 3x3 -> 1x1) with
    group normalization.

    Notes
    -----
    Uses pre-activation to keep the residual path clear, as described in:

    He, Kaiming, et al. "Identity Mappings in Deep Residual Networks."
    European Conference on Computer Vision. Springer, Cham, 2016.
    pp. 630-645.
    '''

    Conv = get_nd_conv(spatial_dims)

    conv_in = Conv(channels, kernel_size=1, padding='same')

    seq = _stack_layers([
        Activation(),
        GroupNormalization(groups=norm_groups, axis=-1),
        Conv(channels // 2, kernel_size=1, padding='same'),
        Activation(),
        GroupNormalization(groups=norm_groups, axis=-1),
        Conv(channels // 2, kernel_size=3, padding='same'),
        Activation(),
        GroupNormalization(groups=norm_groups, axis=-1),
        Conv(channels, kernel_size=1, padding='same'),
    ])

    def block(x):
        if x.shape[-1] != channels:
            # project the input to the output channel count so the residual
            # addition below is valid (not strictly a bottleneck anymore)
            x = Activation()(conv_in(x))

        return x + seq(x)

    return block
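A short sketch of the residual behaviour: when the input channel count differs from `channels`, the 1x1 `conv_in` projection makes the addition valid (shapes in the comments are for the assumed input below):

import tensorflow as tf

block = bottleneck_conv_block(channels=32, spatial_dims=2, norm_groups=4)

x = tf.keras.Input(shape=(64, 64, 16))  # 16 != 32 -> conv_in projects first
y = block(x)                            # (None, 64, 64, 32)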
Example #5
def hourglass_block(
        n_levels=4,
        channels=32,
        channels_growth=2,
        spatial_dims=2,
        spacing=1,
        norm_groups=4,
):
    Conv = get_nd_conv(spatial_dims)
    MaxPool = get_nd_maxpooling(spatial_dims)
    UpSampling = get_nd_upsampling(spatial_dims)

    # must be divisible by 2*norm_groups because of channels//2 in bottleneck block
    level_channels = [
        int((channels * channels_growth**l) // (2 * norm_groups) * 2 *
            norm_groups) for l in range(n_levels)
    ]
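    # A worked example with the defaults channels=32, channels_growth=2 and
    # norm_groups=4: level_channels evaluates to [32, 64, 128, 256], where
    # each entry is a multiple of 2 * norm_groups = 8.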

    # define layers ################################################
    # conv block for down path
    downs = [
        bottleneck_conv_block(level_channels[l], spatial_dims, norm_groups)
        for l in range(n_levels)
    ]

    # conv blocks for residual/skip paths
    skips = [
        bottleneck_conv_block(level_channels[l], spatial_dims, norm_groups)
        for l in range(n_levels)
    ]

    # conv block for middle layer
    mid = bottleneck_conv_block(level_channels[-1], spatial_dims, norm_groups)

    # conv blocks for up path
    ups = [
        _stack_layers([
            Conv(level_channels[l], kernel_size=1,
                 padding='same'),  # reduce concatenated channels by half
            LeakyReLU(),
            bottleneck_conv_block(level_channels[l], spatial_dims, norm_groups)
        ]) for l in range(n_levels - 1)
    ]

    pools = [
        MaxPool(anisotropic_kernel_size(spacing, l, n_levels - 1))
        for l in range(n_levels - 1)
    ]
    upsamples = [
        UpSampling(anisotropic_kernel_size(spacing, l, n_levels - 1))
        for l in range(n_levels - 1)
    ]

    def block(x):

        # down, keeping a handle on intermediate outputs to build skip connections
        level_outputs = [downs[0](x)]
        for down, pool in zip(downs[1:], pools):
            level_outputs.append(down(pool(level_outputs[-1])))

        # residual/skip
        for idx, skip in enumerate(skips):
            level_outputs[idx] = skip(level_outputs[idx])

        # middle
        x = mid(level_outputs.pop(-1))
        x = tf.identity(x, name="middle")

        # up
        for idx, (level_var, up, upsample) in enumerate(
                zip(level_outputs[::-1], ups[::-1], upsamples[::-1])):

            x = upsample(x)
            # name the concatenation layers for easy access; level 0 is the
            # last up layer (back to full image resolution)
            x = tf.concat([x, level_var],
                          axis=-1,
                          name='concat_l{}'.format(n_levels - 2 - idx))
            x = up(x)

        return x

    return block
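A usage sketch, assuming isotropic `spacing=1` so that each of the `n_levels - 1` poolings halves the resolution (the exact pool sizes come from `anisotropic_kernel_size`):

import tensorflow as tf

block = hourglass_block(n_levels=3, channels=32, spatial_dims=2)

# 64 is divisible by 2**(n_levels - 1) = 4, so pooling/upsampling round-trips.
x = tf.keras.Input(shape=(64, 64, 32))
y = block(x)  # (None, 64, 64, 32)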
Example #6
def GenericRDCnetBase(input_shape,
                      downsampling_factor,
                      n_downsampling_channels,
                      n_output_channels,
                      n_groups=16,
                      dilation_rates=(1, 2, 4, 8, 16),
                      channels_per_group=32,
                      n_steps=5,
                      dropout=0.1,
                      up_method='transposed_conv'):
    '''Fully convolutional model consisting of a stacked dilated convolution
    block applied recurrently.'''

    spatial_dims = len(input_shape) - 1
    downsampling_factor = tuple(
        np.broadcast_to(np.array(downsampling_factor), spatial_dims).tolist())

    recurrent_block = rdc_block(n_groups,
                                dilation_rates,
                                channels_per_group,
                                spatial_dims=spatial_dims,
                                dropout=dropout)
    n_features = channels_per_group * n_groups
    loop = delta_loop(n_features, recurrent_block, n_steps)

    in_kernel_size = tuple(max(3, f) for f in downsampling_factor)
    out_kernel_size = tuple(max(3, 2 * f) for f in downsampling_factor)

    Conv = get_nd_conv(spatial_dims)
    conv_in = Conv(n_downsampling_channels,
                   kernel_size=in_kernel_size,
                   strides=downsampling_factor,
                   padding='same')

    if up_method == 'transposed_conv':
        ConvTranspose = get_nd_conv_transposed(spatial_dims)
        conv_out = ConvTranspose(n_output_channels,
                                 kernel_size=out_kernel_size,
                                 strides=downsampling_factor,
                                 padding='same')
    elif up_method == 'upsample':
        UpSampling = get_nd_upsampling(spatial_dims)
        upsampling_out = UpSampling(downsampling_factor)
        conv_out = Conv(n_output_channels,
                        kernel_size=out_kernel_size,
                        padding='same')
    else:
        raise ValueError(
            '{} up_method not recognized, expects one of: transposed_conv, upsample'
            .format(up_method))

    input_padding = DynamicPaddingLayer(downsampling_factor,
                                        ndim=spatial_dims + 2)
    output_trimming = DynamicTrimmingLayer(ndim=spatial_dims + 2)

    inputs = Input(shape=input_shape)
    x = input_padding(inputs)
    x = conv_in(x)

    x = loop(x)
    x = LeakyReLU()(x)
    if up_method == 'upsample':
        x = upsampling_out(x)
    x = conv_out(x)
    x = output_trimming([inputs, x])

    name = 'RDCNet-F{}-DC{}-OC{}-G{}-DR{}-GC{}-S{}-D{}'.format(
        _format_tuple(downsampling_factor),
        n_downsampling_channels, n_output_channels, n_groups,
        _format_tuple(dilation_rates), channels_per_group, n_steps, dropout)

    return Model(inputs=inputs, outputs=[x], name=name)
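A hedged instantiation sketch; thanks to `DynamicPaddingLayer`/`DynamicTrimmingLayer` the spatial input size can stay undefined:

model = GenericRDCnetBase(input_shape=(None, None, 1),
                          downsampling_factor=4,
                          n_downsampling_channels=32,
                          n_output_channels=3)
model.summary()  # model.name encodes the chosen hyperparameters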
Example #7
def GenericUnetBase(input_shape=None,
                    input_tensor=None,
                    batch_size=None,
                    with_bn=False,
                    width=1,
                    n_levels=5,
                    n_blocks=2):
    '''UNet constructor for 2D and 3D.

    Parameters
    ----------
    input_shape: tuple or None
        Expected shape of the input tensor. Either input_shape or
        input_tensor has to be defined.
    input_tensor: Tensor or None
        Input tensor. Either input_shape or input_tensor has to be defined.
    batch_size: int or None
        Expected batch size.
    with_bn: bool
        If True, instantiate model with BatchNormalization.
    width: float
        Scales the number of features used in all layers. width=1.0 corresponds
        to the default of 64 features in the first level.
    n_levels: int
        Number of levels in the unet.
    n_blocks: int
        Number of blocks in each level.

    Notes
    -----
    * All dimensions are treated identically.
    * If you need more customization of the architecture, you might be
      interested in specializing UnetBuilder.

    '''
    if input_tensor is None and input_shape is None:
        raise ValueError('Either input_shape or input_tensor must be given!')

    if input_tensor is None:
        img_input = Input(batch_shape=(batch_size, ) + input_shape,
                          name='input')
    else:
        img_input = input_tensor

    ORIGINAL_FEATURES = 64

    # don't count batch and channel dimensions.
    spatial_ndim = len(img_input.shape) - 2

    # determine normalization
    norm_layer = BatchNormalization if with_bn else None

    builder = UnetBuilder(conv_layer=get_nd_conv(spatial_ndim),
                          downsampling_layer=get_nd_maxpooling(spatial_ndim),
                          upsampling_layer=get_nd_upsampling(spatial_ndim),
                          norm_layer=norm_layer,
                          n_levels=n_levels,
                          n_blocks=n_blocks,
                          base_features=int(width * ORIGINAL_FEATURES))

    # add padding...
    padding_factor = 2**n_levels
    x = DynamicPaddingLayer(factor=padding_factor,
                            ndim=spatial_ndim + 2,
                            name='dpad')(img_input)

    # construct unet.
    x = builder.build_unet_block(x)

    # ...and remove padding.
    x = DynamicTrimmingLayer(ndim=spatial_ndim + 2,
                             name='dtrim')([img_input, x])

    inputs = (get_source_inputs(input_tensor)
              if input_tensor is not None else img_input)
    return Model(inputs=inputs, outputs=x, name=builder.get_model_name())
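And a matching sketch for the UNet constructor; a three-entry input_shape yields a 2D model:

model = GenericUnetBase(input_shape=(None, None, 1),
                        with_bn=True,
                        width=0.5,  # 0.5 * 64 = 32 features in the first level
                        n_levels=4)
model.summary()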