Example #1
def _shortcut(input,
              residual,
              subsample,
              upsample,
              weight_norm=False,
              normalization=None,
              weight_decay=None,
              init='he_normal',
              ndim=2,
              name=None):
    name = _get_unique_name('shortcut', name)

    # The shortcut is resized below (strided slicing / upsampling) to match
    # the residual's spatial shape; here, check whether its channel count
    # already matches so the 1x1 convolution can be skipped.
    equal_channels = residual._keras_shape[1] == input._keras_shape[1]

    shortcut = input

    # Downsample input
    if subsample:

        def downsample_output_shape(input_shape):
            output_shape = list(input_shape)
            output_shape[-2] = None if output_shape[-2] is None \
                                    else output_shape[-2]//2
            output_shape[-1] = None if output_shape[-1] is None \
                                    else output_shape[-1]//2
            return tuple(output_shape)

        if ndim == 2:
            shortcut = Lambda(lambda x: x[:, :, ::2, ::2],
                              output_shape=downsample_output_shape)(shortcut)
        elif ndim == 3:
            shortcut = Lambda(lambda x: x[:, :, :, ::2, ::2],
                              output_shape=downsample_output_shape)(shortcut)
        else:
            raise ValueError("ndim must be 2 or 3")

    # Upsample input
    if upsample:
        shortcut = UpSampling(size=2, ndim=ndim)(shortcut)

    # Adjust input channels to match residual
    if not equal_channels:
        shortcut = Convolution(filters=residual._keras_shape[1],
                               kernel_size=1,
                               ndim=ndim,
                               weight_norm=weight_norm,
                               kernel_initializer=init,
                               padding='valid',
                               kernel_regularizer=_l2(weight_decay),
                               name=name + '_conv2d')(shortcut)

    out = merge_add([shortcut, residual])
    if weight_norm and normalization is None:
        # Divide sum by two.
        out = Lambda(lambda x: x / 2., output_shape=lambda x: x)(out)

    return out
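
The downsampling branch above halves the spatial dimensions by taking every
second element along the last two axes, rather than pooling or convolving.
A minimal sketch of that shape arithmetic, assuming channels-first tensors
as in the example (the numpy array below is only an illustration):

import numpy as np

x = np.zeros((8, 16, 32, 32))    # (batch, channels, height, width)
downsampled = x[:, :, ::2, ::2]  # same slicing as the ndim=2 Lambda above
print(downsampled.shape)         # (8, 16, 16, 16)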
Example #2
    def f(input):
        residuals = []
        for i in range(num_residuals):
            residual = _norm_relu_conv(filters,
                                       kernel_size=1,
                                       subsample=subsample,
                                       normalization=normalization,
                                       weight_norm=weight_norm,
                                       weight_decay=weight_decay,
                                       norm_kwargs=norm_kwargs,
                                       init=init,
                                       nonlinearity=nonlinearity,
                                       ndim=ndim,
                                       name=name)(input)
            residual = _norm_relu_conv(filters,
                                       kernel_size=3,
                                       normalization=normalization,
                                       weight_norm=weight_norm,
                                       weight_decay=weight_decay,
                                       norm_kwargs=norm_kwargs,
                                       init=init,
                                       nonlinearity=nonlinearity,
                                       ndim=ndim,
                                       name=name)(residual)
            residual = _norm_relu_conv(filters * 4,
                                       kernel_size=1,
                                       upsample=upsample,
                                       normalization=normalization,
                                       weight_norm=weight_norm,
                                       weight_decay=weight_decay,
                                       norm_kwargs=norm_kwargs,
                                       init=init,
                                       nonlinearity=nonlinearity,
                                       ndim=ndim,
                                       name=name)(residual)
            if dropout > 0:
                if nonlinearity == 'selu':
                    residual = AlphaDropout(dropout)(residual)
                else:
                    residual = Dropout(dropout)(residual)
            residuals.append(residual)

        if len(residuals) > 1:
            output = merge_add(residuals)
        else:
            output = residuals[0]
        if skip:
            output = _shortcut(input,
                               output,
                               subsample=subsample,
                               upsample=upsample,
                               weight_norm=weight_norm,
                               normalization=normalization,
                               weight_decay=weight_decay,
                               init=init,
                               ndim=ndim,
                               name=name)
        return output
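
The three chained _norm_relu_conv calls form the usual 1x1 / 3x3 / 1x1
bottleneck, with the final 1x1 expanding back to filters * 4 channels. A
rough, hedged parameter-count comparison for one such unit (assuming 2D
convolutions on 4 * filters input channels, ignoring biases and
normalization):

filters = 64
bottleneck = (4*filters * 1*1 * filters      # 1x1 reduce
              + filters * 3*3 * filters      # 3x3
              + filters * 1*1 * 4*filters)   # 1x1 expand
plain = 2 * (4*filters * 3*3 * 4*filters)    # two 3x3 convs at full width
print(bottleneck, plain)                     # 69632 1179648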
Example #3
def merge(x, mode):
    if mode=='sum':
        out = merge_add(x)
    elif mode=='concat':
        channel_axis = get_channel_axis()
        out = merge_concat(x, axis=channel_axis)
    else:
        raise ValueError("Unrecognized merge mode: {}".format(mode))
    return out
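
merge_add and merge_concat in this helper behave like additive and
channel-wise merging, respectively. A minimal sketch with stock Keras 2
layers, under the assumption that they are thin wrappers around add and
concatenate:

from keras.layers import Input, add, concatenate

a = Input(shape=(16, 32, 32))           # channels-first: 16 feature maps
b = Input(shape=(16, 32, 32))
summed = add([a, b])                    # mode='sum'    -> 16 channels
stacked = concatenate([a, b], axis=1)   # mode='concat' -> 32 channels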
Example #4
    def f(input):
        residuals = []
        for i in range(num_residuals):
            residual = input
            if normalization is not None:
                residual = normalization(name=name + "_norm_" + str(i),
                                         **norm_kwargs)(residual)
            residual = get_nonlinearity(nonlinearity)(residual)
            if subsample:
                residual = MaxPooling(pool_size=2, ndim=ndim)(residual)
            residual = Convolution(filters=filters,
                                   kernel_size=3,
                                   ndim=ndim,
                                   weight_norm=weight_norm,
                                   kernel_initializer=init,
                                   padding='same',
                                   kernel_regularizer=_l2(weight_decay),
                                   name=name + "_conv2d_" + str(i))(residual)
            if dropout > 0:
                if nonlinearity == 'selu':
                    residual = AlphaDropout(dropout)(residual)
                else:
                    residual = Dropout(dropout)(residual)
            if upsample:
                residual = UpSampling(size=2, ndim=ndim)(residual)
            residuals.append(residual)

        if len(residuals) > 1:
            output = merge_add(residuals)
        else:
            output = residuals[0]
        if skip:
            output = _shortcut(input,
                               output,
                               subsample=subsample,
                               upsample=upsample,
                               normalization=normalization,
                               weight_norm=weight_norm,
                               weight_decay=weight_decay,
                               init=init,
                               ndim=ndim,
                               name=name)
        return output
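
Here subsampling and upsampling use pooling/upsampling layers instead of the
strided slice of Example #1. A minimal sketch of the shape effect, using the
stock Keras 2D layers that the MaxPooling/UpSampling wrappers are assumed to
dispatch to when ndim=2:

from keras.layers import Input, MaxPooling2D, UpSampling2D

x = Input(shape=(16, 32, 32))   # channels-first: (channels, H, W)
down = MaxPooling2D(pool_size=2, data_format='channels_first')(x)  # (16, 16, 16)
up = UpSampling2D(size=2, data_format='channels_first')(down)      # (16, 32, 32)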
Example #5
def first_block(x):
    outputs = []
    for i in range(num_first_conv):
        out = Convolution(filters=input_num_filters,
                          kernel_size=3,
                          ndim=ndim,
                          weight_norm=weight_norm,
                          kernel_initializer=init,
                          padding='same',
                          kernel_regularizer=_l2(weight_decay),
                          name=_unique('first_conv_' + str(i)))(x)
        outputs.append(out)
    if len(outputs) > 1:
        out = merge_add(outputs)
        if normalization is None:
            # Divide sum by two.
            out = Lambda(lambda x: x / 2.,
                         output_shape=lambda x: x)(out)
    else:
        out = outputs[0]
    return out
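
When several parallel first convolutions are summed without a following
normalization layer, the Lambda above divides by two so the merged tensor
stays on the same scale as a single branch. A tiny numeric sketch of that
arithmetic (the values are placeholders):

import numpy as np

a = np.full((4, 4), 0.8)      # pretend branch-1 activations
b = np.full((4, 4), 0.6)      # pretend branch-2 activations
print((a + b).mean())         # 1.4 -- a raw sum drifts upward
print(((a + b) / 2.).mean())  # 0.7 -- stays in the single-branch range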
Example #6
def final_block(x):
    outputs = []
    for i in range(num_final_conv):
        out = Convolution(filters=input_num_filters,
                          kernel_size=3,
                          ndim=ndim,
                          weight_norm=weight_norm,
                          kernel_initializer=init,
                          padding='same',
                          kernel_regularizer=_l2(weight_decay),
                          name=_unique('final_conv_' + str(i)))(x)
        if normalization is not None:
            out = normalization(name=_unique('final_norm_' +
                                             str(i)),
                                **norm_kwargs)(out)
        out = get_nonlinearity(nonlinearity)(out)
        outputs.append(out)
    if len(outputs) > 1:
        out = merge_add(outputs)
    else:
        out = outputs[0]
    return out
Example #7
def merge_into(x, into, skips, cycle, direction, depth):
    if x._keras_shape[1] != into._keras_shape[1]:
        if cycles_share_weights and depth in skips[cycle - 1][direction]:
            conv_layer = skips[cycle - 1][direction][depth]
        else:
            name = _unique('long_skip_' + str(direction) + '_' +
                           str(depth))
            conv_layer = Convolution(filters=into._keras_shape[1],
                                     kernel_size=1,
                                     ndim=ndim,
                                     weight_norm=weight_norm,
                                     kernel_initializer=init,
                                     padding='valid',
                                     kernel_regularizer=_l2(weight_decay),
                                     name=name)
        skips[cycle][direction][depth] = conv_layer
        x = conv_layer(x)
    out = merge_add([x, into])
    if normalization is None:
        # Divide sum by two.
        out = Lambda(lambda x: x / 2., output_shape=lambda x: x)(out)
    return out
Example #8
def _shortcut(input, residual, subsample, upsample, upsample_mode='repeat',
              weight_decay=None, init='he_normal', ndim=2, name=None):
    name = _get_unique_name('shortcut', name)
    channel_axis = get_channel_axis(ndim)
    shortcut = input
    
    # Downsample input
    if subsample:
        shortcut = _subsample(shortcut, ndim=ndim)
        
    # Upsample input
    if upsample:
        shortcut = _upsample(shortcut,
                             mode=upsample_mode,
                             ndim=ndim,
                             filters=shortcut._keras_shape[channel_axis],
                             kernel_size=2,
                             kernel_initializer=init,
                             kernel_regularizer=_l2(weight_decay),
                             name=name+"_upconv")
        
    # Expand channels of shortcut to match residual.
    # Stride appropriately to match residual (width, height)
    # Should be int if network architecture is correctly configured.
    equal_channels = residual._keras_shape[channel_axis] == \
                                            shortcut._keras_shape[channel_axis]
    if not equal_channels:
        shortcut = Convolution(filters=residual._keras_shape[channel_axis],
                               kernel_size=1, ndim=ndim,
                               kernel_initializer=init, padding='valid',
                               kernel_regularizer=_l2(weight_decay),
                               name=name+"_conv")(shortcut)
    
    out = merge_add([shortcut, residual])
        
    return out
Example #9
def assemble_model(input_shape,
                   num_classes,
                   num_init_blocks,
                   num_main_blocks,
                   main_block_depth,
                   input_num_filters,
                   num_cycles=1,
                   preprocessor_network=None,
                   postprocessor_network=None,
                   mainblock=None,
                   initblock=None,
                   nonlinearity='relu',
                   dropout=0.,
                   normalization=BatchNormalization,
                   weight_norm=False,
                   weight_decay=None,
                   norm_kwargs=None,
                   init='he_normal',
                   ndim=2,
                   cycles_share_weights=True,
                   num_residuals=1,
                   num_first_conv=1,
                   num_final_conv=1,
                   num_classifier=1,
                   num_outputs=1,
                   use_first_conv=True,
                   use_final_conv=True):
    """
    input_shape : tuple specifying the 2D image input shape.
    num_classes : number of classes in the segmentation output.
    num_init_blocks : the number of blocks of type initblock, above mainblocks.
        These blocks always have the same number of channels as the first
        convolutional layer in the model.
    num_main_blocks : the number of blocks of type mainblock, below initblocks.
        These blocks double (halve) in number of channels at each downsampling
        (upsampling).
    main_block_depth : an integer or list of integers specifying the number of
        repetitions of each mainblock. A list must contain as many values as
        there are main_blocks in the downward (or upward -- it's mirrored) path
        plus one for the across path.
    input_num_filters : the number of channels in the first (last)
        convolutional layer in the model (and of each initblock).
    num_cycles : number of times to cycle the down/up processing pair.
    preprocessor_network : a neural network for preprocessing the input data.
    postprocessor_network : a neural network for postprocessing the data fed
        to the classifier.
    mainblock : a layer defining the mainblock (bottleneck by default).
    initblock : a layer defining the initblock (basic_block_mp by default).
    nonlinearity : string or function specifying/defining the nonlinearity.
    dropout : the dropout probability, introduced in every block.
    normalization : the normalization to apply to layers (by default: batch
        normalization). If None, no normalization is applied.
    weight_norm : boolean, whether to use weight norm on conv layers.
    weight_decay : the weight decay (L2 penalty) used in every convolution.
    norm_kwargs : keyword arguments to pass to batch norm layers.
    init : string or function specifying the initializer for layers.
    ndim : the spatial dimensionality of the input and output (2 or 3)
    cycles_share_weights : share network weights across cycles.
    num_residuals : the number of parallel residual functions per block.
    num_first_conv : the number of parallel first convolutions.
    num_final_conv : the number of parallel final convolutions (+BN).
    num_classifier : the number of parallel linear classifiers.
    num_outputs : the number of model outputs, each with num_classifier
        classifiers.
    """
    '''
    By default, use depth 2 basic_block for mainblock
    '''
    if mainblock is None:
        mainblock = basic_block
    if initblock is None:
        initblock = basic_block_mp
    '''
    main_block_depth can be a single value or a list with one value per
    mainblock (plus one for the across path) -- if a list, check that its
    length is correct and that no depth is zero.
    '''
    if not hasattr(main_block_depth, '__len__'):
        if main_block_depth == 0:
            raise ValueError("main_block_depth must never be zero")
    else:
        if len(main_block_depth) != num_main_blocks + 1:
            raise ValueError("main_block_depth must have "
                             "`num_main_blocks+1` values when "
                             "passed as a list")
        for d in main_block_depth:
            if d == 0:
                raise ValueError("main_block_depth must never be zero")
    '''
    Returns the depth of a mainblock for a given pooling level.
    '''
    def get_repetitions(level):
        if hasattr(main_block_depth, '__len__'):
            return main_block_depth[level]
        return main_block_depth

    '''
    Merge tensors, changing the number of feature maps in the first input
    to match that of the second input. Feature maps in the first input are
    reweighted.
    
    If weight sharing is enabled, reuse old convolutions.
    '''

    def merge_into(x, into, skips, cycle, direction, depth):
        if x._keras_shape[1] != into._keras_shape[1]:
            if cycles_share_weights and depth in skips[cycle - 1][direction]:
                conv_layer = skips[cycle - 1][direction][depth]
            else:
                name = _unique('long_skip_' + str(direction) + '_' +
                               str(depth))
                conv_layer = Convolution(filters=into._keras_shape[1],
                                         kernel_size=1,
                                         ndim=ndim,
                                         weight_norm=weight_norm,
                                         kernel_initializer=init,
                                         padding='valid',
                                         kernel_regularizer=_l2(weight_decay),
                                         name=name)
            skips[cycle][direction][depth] = conv_layer
            x = conv_layer(x)
        out = merge_add([x, into])
        if normalization is None:
            # Divide sum by two.
            out = Lambda(lambda x: x / 2., output_shape=lambda x: x)(out)
        return out

    '''
    Given some block function and an input tensor, return a reusable model
    instantiating that block function. This is to allow weight sharing.
    '''

    def make_block(block_func, x):
        x_filters = x._keras_shape[1]
        input = Input(shape=(x_filters, ) + tuple([None] * ndim))
        model = Model(input, block_func(input))
        return model

    '''
    Constant kwargs passed to the init and main blocks.
    '''
    # Resolve the norm_kwargs default before it is captured in block_kwargs.
    if norm_kwargs is None:
        norm_kwargs = {}
    block_kwargs = {
        'skip': True,
        'dropout': dropout,
        'weight_norm': weight_norm,
        'weight_decay': weight_decay,
        'num_residuals': num_residuals,
        'norm_kwargs': norm_kwargs,
        'nonlinearity': nonlinearity,
        'init': init,
        'ndim': ndim
    }

    # INPUT
    input = Input(shape=input_shape)

    # Preprocessing
    if preprocessor_network is not None:
        input = preprocessor_network(input)
    '''
    Build the blocks for all cycles, contracting and expanding in each cycle.
    '''
    tensors = []  # feature tensors
    blocks = []  # residual block layers
    skips = []  # 1x1 kernel convolution layers on long skip connections
    x = input
    for cycle in range(num_cycles):
        # Create tensors and layer lists for this cycle.
        tensors.append({'down': {}, 'up': {}, 'across': {}})
        blocks.append({'down': {}, 'up': {}, 'across': {}})
        skips.append({'down': {}, 'up': {}, 'across': {}})

        # First convolution
        if cycle > 0:
            x = merge_into(x,
                           tensors[cycle - 1]['up'][0],
                           skips=skips,
                           cycle=cycle,
                           direction='down',
                           depth=0)
        if cycles_share_weights and cycle > 1:
            block = blocks[cycle - 1]['down'][0]
        else:

            def first_block(x):
                outputs = []
                for i in range(num_first_conv):
                    out = Convolution(filters=input_num_filters,
                                      kernel_size=3,
                                      ndim=ndim,
                                      weight_norm=weight_norm,
                                      kernel_initializer=init,
                                      padding='same',
                                      kernel_regularizer=_l2(weight_decay),
                                      name=_unique('first_conv_' + str(i)))(x)
                    outputs.append(out)
                if len(outputs) > 1:
                    out = merge_add(outputs)
                    if normalization is None:
                        # Divide sum by two.
                        out = Lambda(lambda x: x / 2.,
                                     output_shape=lambda x: x)(out)
                else:
                    out = outputs[0]
                return out

            block = make_block(first_block, x)
        if use_first_conv:
            x = block(x)
            blocks[cycle]['down'][0] = block
        else:
            blocks[cycle]['down'][0] = lambda x: x
        tensors[cycle]['down'][0] = x
        print("Cycle {} - FIRST DOWN: {}".format(cycle, x._keras_shape))

        # DOWN (initial subsampling blocks)
        for b in range(num_init_blocks):
            depth = b + 1
            if cycle > 0:
                x = merge_into(x,
                               tensors[cycle - 1]['up'][depth],
                               skips=skips,
                               cycle=cycle,
                               direction='down',
                               depth=depth)
            if cycles_share_weights and cycle > 1:
                block = blocks[cycle - 1]['down'][depth]
            else:
                block_func = residual_block(initblock,
                                            filters=input_num_filters,
                                            repetitions=1,
                                            subsample=True,
                                            upsample=False,
                                            normalization=normalization,
                                            name='d' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['down'][depth] = block
            tensors[cycle]['down'][depth] = x
            print("Cycle {} - INIT DOWN {}: {}".format(cycle, b,
                                                       x._keras_shape))

        # DOWN (resnet blocks)
        for b in range(num_main_blocks):
            depth = b + 1 + num_init_blocks
            if cycle > 0:
                x = merge_into(x,
                               tensors[cycle - 1]['up'][depth],
                               skips=skips,
                               cycle=cycle,
                               direction='down',
                               depth=depth)
            if cycles_share_weights and cycle > 1:
                block = blocks[cycle - 1]['down'][depth]
            else:
                block_func = residual_block(mainblock,
                                            filters=input_num_filters * (2**b),
                                            repetitions=get_repetitions(b),
                                            subsample=True,
                                            upsample=False,
                                            normalization=normalization,
                                            name='d' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['down'][depth] = block
            tensors[cycle]['down'][depth] = x
            print("Cycle {} - MAIN DOWN {}: {}".format(cycle, b,
                                                       x._keras_shape))

        # ACROSS
        if num_main_blocks:
            if cycle > 0:
                x = merge_into(x,
                               tensors[cycle - 1]['across'][0],
                               skips=skips,
                               cycle=cycle,
                               direction='across',
                               depth=0)
            if cycles_share_weights and cycle > 1:
                block = blocks[cycle - 1]['across'][0]
            else:
                block_func = residual_block( \
                                  mainblock,
                                  filters=input_num_filters*(2**b),
                                  repetitions=get_repetitions(num_main_blocks),
                                  subsample=True,
                                  upsample=True,
                                  normalization=normalization,
                                  name='a',
                                  **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['across'][0] = block
            tensors[cycle]['across'][0] = x
            print("Cycle {} - ACROSS: {}".format(cycle, x._keras_shape))

        # UP (resnet blocks)
        for b in range(num_main_blocks - 1, -1, -1):
            depth = b + 1 + num_init_blocks
            x = merge_into(x,
                           tensors[cycle]['down'][depth],
                           skips=skips,
                           cycle=cycle,
                           direction='up',
                           depth=depth)
            if cycles_share_weights and cycle > 0 and cycle < num_cycles - 1:
                block = blocks[cycle - 1]['up'][depth]
            else:

                block_func = residual_block(mainblock,
                                            filters=input_num_filters * (2**b),
                                            repetitions=get_repetitions(b),
                                            subsample=False,
                                            upsample=True,
                                            normalization=normalization,
                                            name='u' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['up'][depth] = block
            tensors[cycle]['up'][depth] = x
            print("Cycle {} - MAIN UP {}: {}".format(cycle, b, x._keras_shape))

        # UP (final upsampling blocks)
        for b in range(num_init_blocks - 1, -1, -1):
            depth = b + 1
            x = merge_into(x,
                           tensors[cycle]['down'][depth],
                           skips=skips,
                           cycle=cycle,
                           direction='up',
                           depth=depth)
            if cycles_share_weights and cycle > 0 and cycle < num_cycles - 1:
                block = blocks[cycle - 1]['up'][depth]
            else:
                block_func = residual_block(initblock,
                                            filters=input_num_filters,
                                            repetitions=1,
                                            subsample=False,
                                            upsample=True,
                                            normalization=normalization,
                                            name='u' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['up'][depth] = block
            tensors[cycle]['up'][depth] = x
            print("Cycle {} - INIT UP {}: {}".format(cycle, b, x._keras_shape))

        # Final convolution.
        x = merge_into(x,
                       tensors[cycle]['down'][0],
                       skips=skips,
                       cycle=cycle,
                       direction='up',
                       depth=0)
        if cycles_share_weights and cycle > 0 and cycle < num_cycles - 1:
            block = blocks[cycle - 1]['up'][0]
        else:

            def final_block(x):
                outputs = []
                for i in range(num_final_conv):
                    out = Convolution(filters=input_num_filters,
                                      kernel_size=3,
                                      ndim=ndim,
                                      weight_norm=weight_norm,
                                      kernel_initializer=init,
                                      padding='same',
                                      kernel_regularizer=_l2(weight_decay),
                                      name=_unique('final_conv_' + str(i)))(x)
                    if normalization is not None:
                        out = normalization(name=_unique('final_norm_' +
                                                         str(i)),
                                            **norm_kwargs)(out)
                    out = get_nonlinearity(nonlinearity)(out)
                    outputs.append(out)
                if len(outputs) > 1:
                    out = merge_add(outputs)
                else:
                    out = outputs[0]
                return out

            block = make_block(final_block, x)
        if use_final_conv:
            x = block(x)
            blocks[cycle]['up'][0] = block
        else:
            blocks[cycle]['up'][0] = lambda x: x
        tensors[cycle]['up'][0] = x
        if cycle > 0:
            # Merge preclassifier outputs across all cycles.
            x = merge_into(x,
                           tensors[cycle - 1]['up'][0],
                           skips=skips,
                           cycle=cycle,
                           direction='up',
                           depth=-1)
        print("Cycle {} - FIRST UP: {}".format(cycle, x._keras_shape))

    # Postprocessing
    if postprocessor_network is not None:
        x = postprocessor_network(x)

    # OUTPUTs (SIGMOID)
    all_outputs = []
    if num_classes is not None:
        for i in range(num_outputs):
            # Linear classifier
            classifiers = []
            for j in range(num_classifier):
                name = 'classifier_conv_' + str(j)
                if i > 0:
                    # backwards compatibility
                    name += '_out' + str(i)
                output = Convolution(filters=num_classes,
                                     kernel_size=1,
                                     ndim=ndim,
                                     activation='linear',
                                     kernel_regularizer=_l2(weight_decay),
                                     name=_unique(name))(x)
                classifiers.append(output)
            if len(classifiers) > 1:
                output = merge_add(classifiers)
            else:
                output = classifiers[0]
            if ndim == 2:
                output = Permute((2, 3, 1))(output)
            else:
                output = Permute((2, 3, 4, 1))(output)
            if num_classes == 1:
                output = Activation('sigmoid', name='sigmoid' + str(i))(output)
            else:
                output = Activation(_softmax, name='softmax' + str(i))(output)
            if ndim == 2:
                output_layer = Permute((3, 1, 2))
            else:
                output_layer = Permute((4, 1, 2, 3))
            output_layer.name = 'output_' + str(i)
            output = output_layer(output)
            all_outputs.append(output)
    else:
        # No classifier
        all_outputs = Activation('linear', name='output_0')(x)

    # MODEL
    model = Model(inputs=input, outputs=all_outputs)

    return model
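
The docstring lists many knobs; the sketch below is a hypothetical call
(shapes and hyperparameters are illustrative placeholders, not taken from
the original source), assuming a channels-first 2D segmentation setup:

model = assemble_model(input_shape=(1, 128, 128),      # (channels, H, W)
                       num_classes=2,
                       num_init_blocks=2,
                       num_main_blocks=3,
                       main_block_depth=[2, 2, 2, 2],   # num_main_blocks+1 entries
                       input_num_filters=32,
                       num_cycles=1,
                       dropout=0.1,
                       weight_decay=0.0001,
                       ndim=2)
model.summary()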
Example #10
def assemble_model_recurrent(input_shape,
                             num_filters,
                             num_classes,
                             normalization=LayerNorm,
                             norm_kwargs=None,
                             weight_norm=False,
                             num_outputs=1,
                             weight_decay=0.0005,
                             init='he_normal'):
    from recurrentshop import RecurrentModel
    assert (num_outputs == 1)

    if norm_kwargs is None:
        norm_kwargs = {}

    # Inputs
    model_input = Input(batch_shape=input_shape, name='model_input')
    input_t = Input(batch_shape=(input_shape[0], ) + input_shape[2:])
    hidden_input_t = Input(batch_shape=(input_shape[0], num_filters) +
                           input_shape[3:])

    # Common convolution kwargs.
    convolution_kwargs = {
        'filters': num_filters,
        'kernel_size': 3,
        'ndim': 2,
        'padding': 'same',
        'weight_norm': weight_norm,
        'kernel_initializer': init
    }

    # GRU input.
    x_t = Convolution(**convolution_kwargs,
                      kernel_regularizer=_l2(weight_decay),
                      activation='relu',
                      name=_unique('conv_x'))(input_t)
    if normalization is not None:
        x_t = normalization(**norm_kwargs)(x_t)

    # GRU block.
    gate_replace_x = Convolution(**convolution_kwargs,
                                 kernel_regularizer=_l2(weight_decay),
                                 activation='sigmoid',
                                 name=_unique('conv_gate_replace'))(x_t)
    #if normalization is not None:
    #gate_replace_x = normalization(**norm_kwargs)(gate_replace_x)
    gate_replace_h = Convolution(**convolution_kwargs,
                                 kernel_regularizer=_l2(weight_decay),
                                 activation='sigmoid',
                                 name=_unique('conv_gate_replace'))(
                                     hidden_input_t)
    #if normalization is not None:
    #gate_replace_h = normalization(**norm_kwargs)(gate_replace_h)
    gate_replace = merge_add([gate_replace_x, gate_replace_h])

    gate_read_x = Convolution(**convolution_kwargs,
                              kernel_regularizer=_l2(weight_decay),
                              activation='sigmoid',
                              name=_unique('conv_gate_read'))(x_t)
    #if normalization is not None:
    #gate_read_x = normalization(**norm_kwargs)(gate_read_x)
    gate_read_h = Convolution(**convolution_kwargs,
                              kernel_regularizer=_l2(weight_decay),
                              activation='sigmoid',
                              name=_unique('conv_gate_read'))(hidden_input_t)
    #if normalization is not None:
    #gate_read_h = normalization(**norm_kwargs)(gate_read_h)
    gate_read = merge_add([gate_read_x, gate_read_h])

    hidden_read_t = merge_multiply([gate_read, hidden_input_t])
    #if normalization is not None:
    #hidden_read_t = normalization(**norm_kwargs)(hidden_read_t)

    mix_t_pre = merge_concatenate([x_t, hidden_read_t], axis=1)
    mix_t = Convolution(**convolution_kwargs,
                        kernel_regularizer=_l2(weight_decay),
                        activation='tanh',
                        name=_unique('conv_mix'))(mix_t_pre)
    #if normalization is not None:
    #mix_t = normalization(**norm_kwargs)(mix_t)

    lambda_inputs = [mix_t, hidden_input_t, gate_replace]
    hidden_t = Lambda(function=lambda ins: ins[2] * ins[0] +
                      (1 - ins[2]) * ins[1],
                      output_shape=lambda x: x[0])(lambda_inputs)

    # GRU output.
    out_t = Convolution(**convolution_kwargs,
                        kernel_regularizer=_l2(weight_decay),
                        activation='relu',
                        name=_unique('conv_out'))(hidden_t)
    class_convolution_kwargs = copy.copy(convolution_kwargs)
    class_convolution_kwargs['filters'] = num_classes
    # Project the GRU output (not the raw hidden state) down to class maps.
    out_t = Convolution(**class_convolution_kwargs,
                        kernel_regularizer=_l2(weight_decay),
                        activation='linear',
                        name=_unique('conv_out'))(out_t)
    #if normalization is not None:
    #out_t = normalization(**norm_kwargs)(out_t)

    # Classifier.
    out_t = Permute((2, 3, 1))(out_t)
    if num_classes == 1:
        out_t = Activation('sigmoid')(out_t)
    else:
        out_t = Activation(_softmax)(out_t)
    out_t = Permute((3, 1, 2))(out_t)

    # Make it a recurrent block.
    #
    # NOTE: a bidirectional 'stateful' GRU has states passed between blocks
    # of the reverse path in non-temporal order. Only the forward pass is
    # stateful in sequential/temporal order.
    cobject = {LayerNorm.__name__: LayerNorm}
    output_layer = Bidirectional_(RecurrentModel(
        input=input_t,
        initial_states=[hidden_input_t],
        output=out_t,
        final_states=[hidden_t],
        stateful=True,
        return_sequences=True),
                                  merge_mode='sum',
                                  custom_objects=cobject)
    output_layer.name = 'output_0'
    model = Model(inputs=model_input, outputs=output_layer(model_input))
    return model
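
A minimal usage sketch, under the assumption (inferred from how input_shape
is sliced inside the function) that the batch shape is laid out as
(batch, timesteps, channels, height, width); the concrete numbers are
placeholders:

model = assemble_model_recurrent(input_shape=(4, 10, 1, 64, 64),
                                 num_filters=32,
                                 num_classes=2,
                                 weight_decay=0.0005)
model.summary()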