Beispiel #1
0
def bottleneck(inp, output, internal_scale=4, asymmetric=0, dilated=0, downsample=False, dropout_rate=0.1,
               bn_momentum=0.9):
    """Build one ENet bottleneck block on ``inp`` and return its output tensor.

    Main branch: 1x1 projection (2x2 / stride 2 when downsampling) -> BN + PReLU
    -> regular 3x3, asymmetric (1xk then kx1) or dilated 3x3 conv -> BN + PReLU
    -> 1x1 expansion to ``output`` channels -> BN -> SpatialDropout2D.
    Shortcut branch: ``inp`` itself; when downsampling it is max-pooled and
    zero-padded along the channel axis up to ``output`` channels.
    The two branches are summed and passed through a final PReLU.

    Args:
        inp: input tensor. The code treats index 3 of its shape as the channel
            count, so a 4-D channels-last ``(batch, H, W, C)`` tensor is expected.
        output: number of output feature maps.
        internal_scale: the internal (bottleneck) width is ``output // internal_scale``.
        asymmetric: if non-zero, kernel length ``k`` of the factorised
            ``(1, k)`` + ``(k, 1)`` convolution pair.
        dilated: if non-zero, dilation rate of the 3x3 convolution.
        downsample: use a stride-2 projection and a max-pooled, channel-padded
            shortcut.
        dropout_rate: rate passed to ``SpatialDropout2D``.
        bn_momentum: Keras ``BatchNormalization`` momentum. Keras multiplies the
            *old* moving statistics by ``momentum``, whereas Torch (where ENet
            was written, momentum 0.1) multiplies the *new* batch statistics by
            it, so ENet's setting corresponds to 0.9 in Keras.

    Returns:
        The block's output tensor with ``output`` channels.

    Raises:
        ValueError: if both ``asymmetric`` and ``dilated`` are set, or if
            ``downsample`` is requested with fewer output than input channels.
    """
    # Validate before building any layers so a bad call leaves no orphan layers.
    if asymmetric and dilated:
        raise ValueError('asymmetric and dilated convolutions are mutually exclusive')
    if downsample:
        # Extra channels the shortcut must gain to match the main branch.
        pad_featmaps = output - inp.get_shape().as_list()[3]
        if pad_featmaps < 0:
            raise ValueError('downsampling bottleneck cannot reduce the number of feature maps '
                             '(output=%d, input=%d)' % (output, output - pad_featmaps))

    internal = output // internal_scale

    # --- main branch ---
    # 1x1 projection; when downsampling ENet replaces it with a 2x2 stride-2 conv.
    input_stride = 2 if downsample else 1
    encoder = Convolution2D(internal, (input_stride, input_stride), padding='same',
                            strides=(input_stride, input_stride), use_bias=False)(inp)
    encoder = BatchNormalization(momentum=bn_momentum)(encoder)
    encoder = PReLU(shared_axes=[1, 2])(encoder)

    # Main convolution: asymmetric, dilated, or a regular 3x3.
    if asymmetric:
        encoder = Convolution2D(internal, (1, asymmetric), padding='same', use_bias=False)(encoder)
        encoder = Convolution2D(internal, (asymmetric, 1), padding='same')(encoder)
    elif dilated:
        encoder = Convolution2D(internal, (3, 3), dilation_rate=(dilated, dilated), padding='same')(encoder)
    else:
        encoder = Convolution2D(internal, (3, 3), padding='same')(encoder)
    encoder = BatchNormalization(momentum=bn_momentum)(encoder)
    encoder = PReLU(shared_axes=[1, 2])(encoder)

    # 1x1 expansion back to `output` channels.
    encoder = Convolution2D(output, (1, 1), padding='same', use_bias=False)(encoder)
    encoder = BatchNormalization(momentum=bn_momentum)(encoder)
    encoder = SpatialDropout2D(dropout_rate)(encoder)

    # --- shortcut branch ---
    other = inp
    if downsample:
        other = MaxPooling2D()(other)
        # ZeroPadding2D only pads axes 1 and 2, so swap the last two axes to put
        # the channels on axis 2, pad them at the end, then swap back.
        other = Permute((1, 3, 2))(other)
        other = ZeroPadding2D(padding=((0, 0), (0, pad_featmaps)))(other)
        other = Permute((1, 3, 2))(other)

    encoder = add([encoder, other])
    encoder = PReLU(shared_axes=[1, 2])(encoder)
    return encoder