Пример #1
0
def get_model(input_shape, num_classes):
    """Build a 3D encoder/decoder segmentation model with residual shortcuts.

    Args:
        input_shape: shape tuple passed to tf.keras.Input.
        num_classes: number of output channels for the per-voxel softmax.

    Returns:
        A tf.keras.Model mapping inputs to per-voxel class probabilities.
    """
    inputs = tf.keras.Input(input_shape)

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv3D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    residual_source = x  # saved activation for the next shortcut

    # Three downsampling blocks, identical apart from feature depth.
    for depth in (64, 128, 256):
        for _ in range(2):
            x = layers.Activation("relu")(x)
            x = SeparableConv3D(depth, 3, padding="same")(x)
            x = layers.BatchNormalization()(x)

        x = layers.MaxPooling3D(3, strides=2, padding="same")(x)

        # 1x1x1 strided conv projects the saved activation to matching shape.
        shortcut = layers.Conv3D(depth, 1, strides=2, padding="same")(
            residual_source
        )
        x = layers.add([x, shortcut])
        residual_source = x

    ### [Second half of the network: upsampling inputs] ###

    for depth in (256, 128, 64, 32):
        for _ in range(2):
            x = layers.Activation("relu")(x)
            x = layers.Conv3DTranspose(depth, 3, padding="same")(x)
            x = layers.BatchNormalization()(x)

        x = layers.UpSampling3D(2)(x)

        # Upsample and project the saved activation, then add it back.
        shortcut = layers.UpSampling3D(2)(residual_source)
        shortcut = layers.Conv3D(depth, 1, padding="same")(shortcut)
        x = layers.add([x, shortcut])
        residual_source = x

    # Per-voxel classification head.
    outputs = layers.Conv3D(num_classes, 3, activation="softmax", padding="same")(x)

    return tf.keras.Model(inputs, outputs)
Пример #2
0
def decoder(convs,
            nb_filter,
            decoder_depth=3,
            decoder_num=0,
            layer_act='relu',
            activation='sigmoid',
            bn_axis=-1):
    """Decoder path: repeatedly upsample, merge with encoder skips, refine.

    Args:
        convs: encoder feature maps, coarsest last; convs[-1] seeds the decoder.
        nb_filter: filter-count list passed through to standard_unit.
        decoder_depth: number of upsample/merge/conv steps.
        decoder_num: integer tag appended to layer names.
        layer_act: activation passed to standard_unit.
        activation: activation of the final 1x1x1 output conv.
        bn_axis: channel axis used for concatenation.

    Returns:
        (outputs, conv_last): per-step feature maps and the 1-channel head.
    """
    outputs = []
    current = convs[-1]
    for step in range(decoder_depth):
        tag = str(step) + str(decoder_num)
        upsampled = layers.UpSampling3D(name='up' + tag)(current)
        merged = layers.concatenate([upsampled, convs[-(step + 2)]],
                                    name='merge' + tag,
                                    axis=bn_axis)
        # NOTE(review): nb_filter[2] is used at every depth — confirm a
        # per-step index was not intended.
        current = standard_unit(merged,
                                stage=tag,
                                nb_filter=nb_filter[2],
                                layer_act=layer_act)
        outputs.append(current)

    conv_last = layers.Conv3D(1, (1, 1, 1),
                              activation=activation,
                              name='output_' + str(decoder_num),
                              kernel_initializer='he_normal',
                              padding='same')(current)

    return outputs, conv_last
Пример #3
0
def build_up_step(input, conv_link, filters_no):
    """One synthesis-path (decoder) level of the U-Net.

    Upsamples `input` 2x, refines it with a conv-bnorm-relu layer,
    concatenates the matching contraction-path feature map, then applies
    two more conv-bnorm-relu layers.

    *conv-bnorm-relu layer: conv layer with added batch normalization and
    relu, see conv_bnorm_relu().

    Args:
        input: tensor fed to layers.UpSampling3D.
        conv_link: skip-connection tensor from the corresponding level of
            the contraction path.
        filters_no (int): number of convolutional filters at this level.

    Returns:
        Output tensor of the final conv-bnorm-relu layer.
    """

    up = conv_bnorm_relu(layers.UpSampling3D(size=(2, 2, 2))(input), filters_no)
    merged = layers.concatenate([conv_link, up], axis=4)

    out = conv_bnorm_relu(merged, filters_no)
    return conv_bnorm_relu(out, filters_no)
Пример #4
0
def retinanet(inputs, K, A):
    """Retinanet architecture with peter's simple backbone architecture.

    Args:
        inputs: dict of input tensors; inputs['dat'] is the image volume.
        K: number of object classes predicted per anchor.
        A: number of anchors per spatial location.

    Returns:
        Model mapping inputs to a dict of classification ('cls-c*') and
        box-regression ('reg-c*') logit tensors at three pyramid levels.
    """
    kwargs = {'kernel_size': (1, 3, 3), 'padding': 'same'}
    conv = lambda x, filters, strides: layers.Conv3D(
        filters=filters, strides=strides, **kwargs)(x)
    norm = lambda x: layers.BatchNormalization()(x)
    relu = lambda x: layers.LeakyReLU()(x)

    # conv1: stride-1 conv-bn-relu; conv2: in-plane stride-2 conv-bn-relu.
    conv1 = lambda filters, x: relu(norm(conv(x, filters, strides=1)))
    conv2 = lambda filters, x: relu(norm(conv(x, filters, strides=(1, 2, 2))))

    # Backbone: six stages of increasing width, downsampling in-plane only.
    l1 = conv1(8, inputs['dat'])
    l2 = conv1(16, conv2(16, l1))
    l3 = conv1(24, conv2(24, l2))
    l4 = conv1(32, conv2(32, l3))
    l5 = conv1(48, conv2(48, l4))
    l6 = conv1(64, conv2(64, l5))

    # Top-down feature-pyramid path: upsample and add 1x1 projections.
    zoom = lambda x: layers.UpSampling3D(size=(1, 2, 2))(x)

    proj = lambda filters, x: layers.Conv3D(filters=filters,
                                            strides=1,
                                            kernel_size=(1, 1, 1),
                                            padding='same',
                                            kernel_initializer='he_normal')(x)

    l7 = proj(64, l6)
    l8 = conv1(64, zoom(l7) + proj(64, l5))
    l9 = conv1(64, zoom(l8) + proj(64, l4))

    # (Removed no-op self-assignments `K = K` / `A = A`.)
    logits = {}

    # --- C3 head (finest pyramid level, from l9)
    c3_cls = conv1(64, conv1(64, l9))
    c3_reg = conv1(64, conv1(64, l9))
    logits['cls-c3'] = layers.Conv3D(filters=(A * K), name='cls-c3',
                                     **kwargs)(c3_cls)
    logits['reg-c3'] = layers.Conv3D(filters=(A * 4), name='reg-c3',
                                     **kwargs)(c3_reg)

    # --- C4 head (from l8)
    c4_cls = conv1(64, conv1(64, l8))
    c4_reg = conv1(64, conv1(64, l8))
    logits['cls-c4'] = layers.Conv3D(filters=(A * K), name='cls-c4',
                                     **kwargs)(c4_cls)
    logits['reg-c4'] = layers.Conv3D(filters=(A * 4), name='reg-c4',
                                     **kwargs)(c4_reg)

    # --- C5 head (coarsest level, from l7)
    c5_cls = conv1(64, conv1(64, l7))
    c5_reg = conv1(64, conv1(64, l7))
    logits['cls-c5'] = layers.Conv3D(filters=(A * K), name='cls-c5',
                                     **kwargs)(c5_cls)
    logits['reg-c5'] = layers.Conv3D(filters=(A * 4), name='reg-c5',
                                     **kwargs)(c5_reg)

    model = Model(inputs=inputs, outputs=logits)
    return model
Пример #5
0
def feature_pyramid_3d(inputs, filter_ratio):
    """Build FPN levels P3-P7 from three backbone feature maps.

    Args:
        inputs: list of three tensors, coarsest last (C3, C4, C5 order).
        filter_ratio: multiplier applied to the base width of 256.

    Returns:
        [p3, p4, p5, p6, p7] pyramid feature tensors.
    """
    # All pyramid convs share one width; compute it once.
    width = int(256 * filter_ratio)
    kwargs1 = {
        'kernel_size': (1, 1, 1),
        'padding': 'valid',
    }
    kwargs3 = {
        'kernel_size': (1, 3, 3),
        'padding': 'same',
    }

    def conv1(x, filters, strides):
        return layers.Conv3D(filters=filters, strides=strides, **kwargs1)(x)

    def conv3(x, filters, strides):
        return layers.Conv3D(filters=filters, strides=strides, **kwargs3)(x)

    def fp_block(top, lateral):
        # Upsample the top map in-plane and add the projected lateral map.
        upsampled = layers.UpSampling3D(size=(1, 2, 2))(top)
        return layers.Add()([upsampled, conv1(lateral, width, strides=1)])

    p5 = conv1(inputs[2], width, strides=1)
    fp4 = fp_block(p5, inputs[1])
    p4 = conv3(fp4, width, strides=1)
    fp3 = fp_block(fp4, inputs[0])
    p3 = conv3(fp3, width, strides=1)
    # P6/P7 extend the pyramid below P5 via strided convs.
    p6 = conv3(p5, width, strides=(1, 2, 2))
    p7 = conv3(layers.LeakyReLU()(p6), width, strides=(1, 2, 2))
    return [p3, p4, p5, p6, p7]
Пример #6
0
    def __init__(self, out_shape, strides=1, ksize=3, shortcut=False):
        """Build the layers of an upsampling generator residual block.

        Args:
            out_shape: number of output channels of every conv layer.
            strides: unused here; every conv below is created with strides=1.
            ksize: kernel size of the four main conv layers.
            shortcut: if True, also build an upsample + 1x1 conv projection
                for the residual branch.
        """
        super(ResBlock_generator, self).__init__()
        self.shortcut = shortcut

        # Main branch: upsample, then four conv -> BN -> LeakyReLU stages.
        self.upSample = layers.UpSampling3D()
        self.conv_0 = layers.Conv3D(out_shape,
                                    kernel_size=ksize,
                                    strides=1,
                                    padding='same',
                                    name='rg_conv1',
                                    use_bias=False)
        self.bn_0 = layers.BatchNormalization()
        self.PRelu0 = layers.LeakyReLU(name='G_LeakyReLU1')
        self.conv_1 = layers.Conv3D(out_shape,
                                    kernel_size=ksize,
                                    strides=1,
                                    padding='same',
                                    name='rg_conv2',
                                    use_bias=False)
        self.bn_1 = layers.BatchNormalization()
        self.PRelu1 = layers.LeakyReLU(name='G_LeakyReLU2')
        self.conv_2 = layers.Conv3D(out_shape,
                                    kernel_size=ksize,
                                    strides=1,
                                    padding='same',
                                    name='rg_conv3',
                                    use_bias=False)
        self.bn_2 = layers.BatchNormalization()
        self.PRelu2 = layers.LeakyReLU(name='G_LeakyReLU3')
        self.conv_3 = layers.Conv3D(out_shape,
                                    kernel_size=ksize,
                                    strides=1,
                                    padding='same',
                                    name='rg_conv4',
                                    use_bias=False)

        self.bn_3 = layers.BatchNormalization()

        # Residual branch: 1x1 projection so channel counts match after
        # upsampling. Only built when requested.
        if shortcut:
            self.upSample_shortcut = layers.UpSampling3D()
            self.conv_shortcut = layers.Conv3D(out_shape,
                                               kernel_size=1,
                                               strides=1,
                                               padding='same',
                                               use_bias=False)

        self.PRelu3 = layers.LeakyReLU(name='G_LeakyReLU4')
Пример #7
0
def up_unit(conv4_1, conv3_1, stage):
    """Upsample `conv4_1`, attend it against the skip `conv3_1`, and merge.

    Uses the module-level `bn_axis` as the concatenation axis.
    """
    upsampled = layers.UpSampling3D(name='up' + stage)(conv4_1)
    # Attention between the encoder skip and the upsampled decoder features.
    attended = layers.Attention()([conv3_1, upsampled])
    merged = layers.concatenate([upsampled, attended],
                                name='merge' + stage,
                                axis=bn_axis)
    return merged
Пример #8
0
def decoder_att_321(conv1_1,
                    conv2_1,
                    conv3_1,
                    conv4_1,
                    decoder_num=0,
                    layer_act='relu',
                    bn_axis=-1):
    """Attention-gated three-step decoder (levels 3 -> 2 -> 1).

    Each step upsamples the previous feature map, attends it against the
    matching encoder feature map, concatenates both, and refines the result
    with a standard_unit block.

    Args:
        conv1_1, conv2_1, conv3_1, conv4_1: encoder feature maps, finest first.
        decoder_num: integer tag appended to layer names so several decoders
            can coexist in one model.
        layer_act: activation passed to standard_unit.
        bn_axis: channel axis used for concatenation.

    Returns:
        Single-channel output tensor.

    NOTE(review): relies on module-level `nb_filter` and `activation`
    globals — confirm they are defined in this module.
    """
    up3_3 = layers.UpSampling3D(name='up33_' + str(decoder_num))(conv4_1)
    att3_3 = layers.Attention()([up3_3, conv3_1])
    # BUG FIX: decoder_num is an int (default 0); concatenating it without
    # str() raised TypeError. Every sibling name already uses str(decoder_num).
    conv3_3 = layers.concatenate([up3_3, att3_3],
                                 name='merge33_' + str(decoder_num),
                                 axis=bn_axis)
    conv3_3 = standard_unit(conv3_3,
                            stage='33_' + str(decoder_num),
                            nb_filter=nb_filter[2],
                            layer_act=layer_act)

    up2_4 = layers.UpSampling3D(name='up24_' + str(decoder_num))(conv3_3)
    att2_4 = layers.Attention()([up2_4, conv2_1])
    conv2_4 = layers.concatenate([up2_4, att2_4],
                                 name='merge24_' + str(decoder_num),
                                 axis=bn_axis)
    conv2_4 = standard_unit(conv2_4,
                            stage='24_' + str(decoder_num),
                            nb_filter=nb_filter[1],
                            layer_act=layer_act)

    up1_5 = layers.UpSampling3D(name='up15_' + str(decoder_num))(conv2_4)
    att1_5 = layers.Attention()([up1_5, conv1_1])
    conv1_5 = layers.concatenate([up1_5, att1_5],
                                 name='merge15_' + str(decoder_num),
                                 axis=bn_axis)
    conv1_5 = standard_unit(conv1_5,
                            stage='15_' + str(decoder_num),
                            nb_filter=nb_filter[0],
                            layer_act=layer_act)

    unet_output = layers.Conv3D(1, (1, 1, 1),
                                activation=activation,
                                name='output_' + str(decoder_num),
                                kernel_initializer='he_normal',
                                padding='same')(conv1_5)
    return unet_output
Пример #9
0
def Unet3D_MultiOutputs(shape, num_class=1, filters=32, first=4, end=3, model_depth=3, activation='sigmoid', layer_act='relu', pooling='max'):
    """3D U-Net that exposes intermediate per-level predictions as outputs.

    Builds a 4-level encoder/decoder; at every level a 1x1x1 conv head
    (upsampled back to full resolution) produces an auxiliary output, and the
    `first`/`end` flags select which tensors become model outputs.

    NOTE(review): several return branches reference `last_out` and `conv2_4`,
    which are never defined in this function — those branches would raise
    NameError; confirm intent. Also, `up_unit` is called here with
    `nb_filter`/`layer_act` kwargs, which a same-named helper elsewhere in
    this file does not accept — presumably a different helper; verify.
    """
    nb_filter = [filters*2**i for i in range(5)]
    img_input = Input(shape=shape, name='main_input')
    # Encoder level 1 plus its full-resolution auxiliary head.
    conv1_1 = standard_unit(img_input, stage=model_depth-3, nb_filter=nb_filter[model_depth-3], layer_act=layer_act)
    conv1_2 = layers.Conv3D(num_class, (1, 1, 1), activation=activation, name='out_1', kernel_initializer='he_normal', padding='same')(conv1_1)
    pool1 = pooling_unit(conv1_1, stage=model_depth-3, pooling=pooling)

    # Encoder level 2; head upsampled 2x back to input resolution.
    conv2_1 = standard_unit(pool1, stage=model_depth-2, nb_filter=nb_filter[model_depth-2], layer_act=layer_act)
    conv2_2 = layers.UpSampling3D(size=2**1)(conv2_1)
    conv2_2 = layers.Conv3D(num_class, (1, 1, 1), activation=activation, name='out_2', kernel_initializer='he_normal', padding='same')(conv2_2)
    pool2 = pooling_unit(conv2_1, stage=model_depth-2, pooling=pooling)

    # Encoder level 3; head upsampled 4x.
    conv3_1 = standard_unit(pool2, stage=model_depth-1, nb_filter=nb_filter[model_depth-1], layer_act=layer_act)
    conv3_2 = layers.UpSampling3D(size=2**2)(conv3_1)
    conv3_2 = layers.Conv3D(num_class, (1, 1, 1), activation=activation, name='out_3', kernel_initializer='he_normal', padding='same')(conv3_2)
    pool3 = pooling_unit(conv3_1, stage=model_depth-1, pooling=pooling)

    # Bottleneck; head upsampled 8x.
    conv4_1 = standard_unit(pool3, stage=model_depth, nb_filter=nb_filter[model_depth], layer_act=layer_act)
    conv4_2 = layers.UpSampling3D(size=2**3)(conv4_1)
    conv4_2 = layers.Conv3D(num_class, (1, 1, 1), activation=activation, name='out_4', kernel_initializer='he_normal', padding='same')(conv4_2)

    # Decoder path with per-level auxiliary heads (note: conv3_2/conv2_2/
    # conv1_2 are rebound here, shadowing the encoder-side heads above).
    conv3_2 = up_unit(conv4_1, conv3_1, stage=model_depth-1, nb_filter=nb_filter[model_depth-1], layer_act=layer_act)
    conv3_3 = layers.UpSampling3D(size=2**2,)(conv3_2)
    conv3_3 = layers.Conv3D(num_class, (1, 1, 1), activation=activation, name='out_5', kernel_initializer='he_normal', padding='same')(conv3_3)
    conv2_2 = up_unit(conv3_2, conv2_1, stage=model_depth-2, nb_filter=nb_filter[model_depth-2], layer_act=layer_act)
    conv2_3 = layers.UpSampling3D(size=2**1)(conv2_2)
    conv2_3 = layers.Conv3D(num_class, (1, 1, 1), activation=activation, name='out_6', kernel_initializer='he_normal', padding='same')(conv2_3)
    conv1_2 = up_unit(conv2_2, conv1_1, stage=model_depth-3, nb_filter=nb_filter[model_depth-3], layer_act=layer_act)
    conv1_3 = layers.Conv3D(num_class, (1, 1, 1), activation=activation, name='out_7', kernel_initializer='he_normal', padding='same')(conv1_2)


    # Output selection; branches referencing last_out / conv2_4 are broken
    # as-is (undefined names) — see NOTE in the docstring.
    if first==1 and end==1:
        return Model(inputs=img_input, outputs=[conv1_2, last_out])
    elif first==2 and end==2:
        return Model(inputs=img_input, outputs=[conv1_2, conv2_2, conv2_4, conv1_3])
    elif first==3 and end==3:
        return Model(inputs=img_input, outputs=[conv1_2, conv2_2, conv3_2, conv3_3, conv2_3, conv1_3])
    elif first==4 and end==3:
        return Model(inputs=img_input, outputs=[conv1_2, conv2_2, conv3_2, conv4_1, conv3_3, conv2_3, conv1_3])
    elif first==0 and end==1:
        return Model(inputs=img_input, outputs=[last_out])
    elif first==0 and end==2:
        return Model(inputs=img_input, outputs=[conv2_4, last_out])
    elif first==0 and end==3:
        return Model(inputs=img_input, outputs=[conv3_3, conv2_4, last_out])
Пример #10
0
 def __init__(self, scale: tuple, interp=NEAREST, scope='UPS'):
     """Create an upscaling op for 1-D/2-D/3-D inputs.

     Args:
         scale: per-dimension scale factors; its length selects the
             UpSampling1D/2D/3D variant.
         interp: interpolation mode key, mapped through TF_INTERP.
         scope: name scope passed to the parent constructor.

     Raises:
         Exception: if len(scale) is not 1, 2 or 3.
     """
     super(Upscale, self).__init__(scope)
     dim = len(scale)
     # NOTE(review): tf.keras UpSampling layers take `size=`/`interpolation=`,
     # not `scale_factor=`/`mode=` — presumably `layers` here is a project
     # wrapper module; verify against its definition.
     if dim == 1:
         self.fn = layers.UpSampling1D(scale_factor=scale, mode=TF_INTERP[interp])
     elif dim == 2:
         self.fn = layers.UpSampling2D(scale_factor=scale, mode=TF_INTERP[interp])
     elif dim == 3:
         self.fn = layers.UpSampling3D(scale_factor=scale, mode=TF_INTERP[interp])
     else:
         raise Exception('NEBULAE ERROR ⨷ %d-d upscaling is not supported.' % dim)
Пример #11
0
def upsample_type(ch, mode="upsample"):
    """Return an upsampling block for `ch` channels, or None for an unknown mode."""
    if mode == "upsample":
        # Fixed nearest-neighbour 2x upsampling followed by a conv block.
        return Sequential([
            layers.UpSampling3D(2),
            BasicBlock(ch, bn=True, act=True)
        ])
    if mode == "transpose":
        # Learned upsampling via a transposed convolution block.
        return BasicBlock(ch, transpose=True, bn=True, act=True)
    return None
def get_model_unet(input_shape=input_shape,
                   conv_filt=32,
                   kernel_size=3,
                   activation="relu",
                   padding="same",
                   pool_size=pool_size):
    """Two-level 3D U-Net with a linear single-channel output head.

    Defaults for `input_shape` and `pool_size` come from module-level
    variables of the same name.
    """
    conv_args = {
        "activation": activation,
        "padding": padding,
        "kernel_size": kernel_size
    }

    def conv(filters, tensor):
        # Shared conv helper so every conv uses the same hyper-parameters.
        return layers.Conv3D(filters=filters, **conv_args)(tensor)

    inputs = layers.Input(shape=input_shape)

    # Encoder level 1
    enc1 = conv(conv_filt, conv(conv_filt, inputs))
    pool1 = layers.MaxPooling3D(pool_size=pool_size)(enc1)

    # Encoder level 2
    enc2 = conv(2 * conv_filt, conv(2 * conv_filt, pool1))
    pool2 = layers.MaxPooling3D(pool_size=pool_size)(enc2)

    # Bottleneck, widened then narrowed before upsampling.
    bottom = conv(2 * conv_filt, conv(4 * conv_filt, pool2))
    up1 = layers.UpSampling3D(size=(2, 2, 2))(bottom)

    # Decoder level 2: merge with the level-2 skip connection.
    dec2 = conv(conv_filt, conv(2 * conv_filt, layers.Concatenate()([enc2, up1])))
    up2 = layers.UpSampling3D(size=(2, 2, 2))(dec2)

    # Decoder level 1: merge with the level-1 skip connection.
    dec1 = conv(conv_filt, conv(conv_filt, layers.Concatenate()([enc1, up2])))

    # Linear (no activation) single-channel output.
    output = layers.Conv3D(filters=1, kernel_size=1, activation=None)(dec1)

    return keras.Model(inputs=[inputs], outputs=[output])
Пример #13
0
    def __init__(self,
                 num_channels,
                 num_classes,
                 use_2d=True,
                 num_conv_layers=2,
                 kernel_size=(3, 3),
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):
        """Build the nested (UNet++-style) grid of conv blocks.

        Args:
            num_channels: per-level output channel counts; level i gets
                len(num_channels) - i conv blocks.
            num_classes: output classes; 1 selects a sigmoid head, else softmax.
            use_2d: build 2D layers if True, otherwise 3D.
            num_conv_layers, kernel_size, nonlinearity, use_batchnorm,
            use_bias, data_format: forwarded to each Conv_Block.
            **kwargs: forwarded to the parent constructor.
        """

        super(Nested_UNet, self).__init__(**kwargs)

        # BUG FIX: the activation choice below read self.num_classes, but the
        # attribute was never assigned (only the local parameter existed),
        # raising AttributeError — store it before use.
        self.num_classes = num_classes

        self.conv_block_lists = []
        self.pool = tfkl.MaxPooling2D() if use_2d else tfkl.MaxPooling3D()
        self.up = tfkl.UpSampling2D() if use_2d else tfkl.UpSampling3D()

        for i in range(len(num_channels)):
            output_ch = num_channels[i]
            conv_layer_lists = []
            # Each deeper level of the nested grid has one fewer block.
            num_conv_blocks = len(num_channels) - i

            for _ in range(num_conv_blocks):
                conv_layer_lists.append(
                    Conv_Block(num_channels=output_ch,
                               use_2d=use_2d,
                               num_conv_layers=num_conv_layers,
                               kernel_size=kernel_size,
                               nonlinearity=nonlinearity,
                               use_batchnorm=use_batchnorm,
                               use_bias=use_bias,
                               data_format=data_format))

            self.conv_block_lists.append(conv_layer_lists)

        # Final 1x1 head: sigmoid for binary, softmax for multi-class.
        if use_2d:
            self.conv_1x1 = tfkl.Conv2D(
                num_classes, (1, 1),
                activation='sigmoid' if self.num_classes == 1 else 'softmax',
                padding='same',
                data_format=data_format)
        else:
            self.conv_1x1 = tfkl.Conv3D(
                num_classes, (1, 1, 1),
                activation='sigmoid' if self.num_classes == 1 else 'softmax',
                padding='same',
                data_format=data_format)
Пример #14
0
    def generator_model(self, out_size, start_size=8, start_filters=512):
        """Build the progressive-growing video generator.

        Args:
            out_size: target in-plane resolution (width/height).
            start_size: in-plane resolution of the first stage (temporal depth
                is start_size / 2).
            start_filters: channel count of the first stage.

        Returns:
            tf.keras.Model taking [z, fade] and producing a 3-channel volume.
        """

        # Fading function: blend two resolutions, weighted by `alpha`.
        def blend_resolutions(upper, lower, alpha):
            upper = tf.multiply(upper, alpha)
            # BUG FIX: use a float constant so the dtype matches the float
            # `alpha` input tensor (tf.subtract(1, alpha) mixes int and float).
            lower = tf.multiply(lower, tf.subtract(1.0, alpha))
            return kl.Add()([upper, lower])

        # For now we start at 2x4x4 and upsample by 2x each time, e.g. 4x8x8 is next, followed by 8x16x16
        conv_loop = int(np.log2(out_size / start_size))

        z = kl.Input(shape=(self.z_dim,))
        fade = kl.Input(shape=(1,))

        # First resolution (2 x 4 x 4)
        # BUG FIX: integer division — Dense units must be an int, and `/2`
        # produced a float.
        x = kl.Dense(start_filters * start_size ** 3 // 2,
                     kernel_initializer=tf.keras.initializers.random_normal(stddev=0.01),
                     name='dense')(z)
        x = kl.Reshape((start_size // 2, start_size, start_size, start_filters))(x)
        x = kl.BatchNormalization()(x)
        x = kl.ReLU()(x)

        lower_res = None
        for resolution in range(conv_loop):
            filters = max(start_filters // 2 ** (resolution + 1), 4)
            x = kl.Conv3DTranspose(filters=filters, kernel_size=4, strides=2, padding='same',
                                   kernel_initializer=self.conv_init, use_bias=True,
                                   name='conv_' + str(2 ** (resolution + 1)))(x)
            x = kl.BatchNormalization()(x)
            x = kl.ReLU()(x)
            if resolution == conv_loop - 1 and conv_loop > 1:
                # NOTE(review): this captures the FINAL stage's output, which
                # is already at full resolution before being upsampled again
                # below — presumably `conv_loop - 2` was intended; confirm.
                lower_res = x

        # Conversion to 3-channel color
        # This is explicitly defined so we can reuse it for the upsampled lower-resolution frames as well
        convert_to_image = kl.Conv3DTranspose(filters=3, kernel_size=1, strides=1, padding='same',
                                              kernel_initializer=self.conv_init, use_bias=True, activation='tanh',
                                              name='conv_to_img_' + str(x.get_shape().as_list()[-1]))
        x = convert_to_image(x)

        # Fade output of previous resolution stage into final resolution stage
        # BUG FIXES: (1) symbolic tensors must not be used in a boolean
        # context — test `is not None`; (2) kl.Lambda passes its input list as
        # a SINGLE argument, so the 3-arg lambda raised TypeError — unpack.
        if self.fade and lower_res is not None:
            lower_upsampled = kl.UpSampling3D()(lower_res)
            lower_upsampled = convert_to_image(lower_upsampled)
            x = kl.Lambda(lambda t: blend_resolutions(t[0], t[1], t[2]))([x, lower_upsampled, fade])

        return tf.keras.models.Model(inputs=[z, fade], outputs=x, name='generator')
Пример #15
0
def _upconv_bn_relu(layer, filters, kernel_size=2):
    """Upsample 2x, convolve, then batch-normalize and ReLU.

    Chooses the 2D or 3D layer family from the input tensor's rank
    (4 -> 2D, 5 -> 3D); other ranks pass straight to BN + ReLU.
    """
    rank = ndim(layer)
    if rank == 4:
        upsampled = layers.UpSampling2D(size=(2, 2))(layer)
        layer = layers.Conv2D(filters=filters,
                              kernel_size=kernel_size,
                              padding='same',
                              kernel_initializer='he_uniform')(upsampled)
    elif rank == 5:
        upsampled = layers.UpSampling3D(size=(2, 2, 2))(layer)
        layer = layers.Conv3D(filters=filters,
                              kernel_size=kernel_size,
                              padding='same',
                              kernel_initializer='he_uniform')(upsampled)
    layer = layers.BatchNormalization()(layer)
    return layers.Activation('relu')(layer)
Пример #16
0
def unet3d(
    input_size,
    pretrained_weights=None,
    three_layers=False,
):
    """Constructs a 3D U-Net manually, as opposed to unet3d_simply() where the
    model is constructed with standardised components.

    Args:
        input_size: Input size passed to keras.layers.Input()
            as (frames, height, width, channels).

        pretrained_weights (optional):
            Passed to model.load_weights().
            Defaults to None.

        three_layers (bool, optional):
            Whether to create a U-Net with 3 levels of depth. If False, 4
            levels are used.
            Defaults to False.


    Returns:
        model: compiled keras model
    """

    inputs = layers.Input(input_size)
    frames, height, width, _ = input_size

    # Down 1
    conv1 = layers.Conv3D(64,
                          3,
                          activation="relu",
                          padding="same",
                          kernel_initializer="he_normal")(inputs)
    conv1 = layers.Conv3D(64,
                          3,
                          activation="relu",
                          padding="same",
                          kernel_initializer="he_normal")(conv1)
    pool1 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv1)

    # Down 2
    conv2 = layers.Conv3D(128,
                          3,
                          activation="relu",
                          padding="same",
                          kernel_initializer="he_normal")(pool1)
    conv2 = layers.Conv3D(128,
                          3,
                          activation="relu",
                          padding="same",
                          kernel_initializer="he_normal")(conv2)
    pool2 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv2)

    # Down 3
    conv3 = layers.Conv3D(256,
                          3,
                          activation="relu",
                          padding="same",
                          kernel_initializer="he_normal")(pool2)
    conv3 = layers.Conv3D(256,
                          3,
                          activation="relu",
                          padding="same",
                          kernel_initializer="he_normal")(conv3)

    #  -------------------------- 4 levels version -------------------------------

    if not three_layers:
        pool3 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv3)

        # Down 4
        conv4 = layers.Conv3D(512,
                              3,
                              activation="relu",
                              padding="same",
                              kernel_initializer="he_normal")(pool3)
        conv4 = layers.Conv3D(512,
                              3,
                              activation="relu",
                              padding="same",
                              kernel_initializer="he_normal")(conv4)
        drop4 = layers.Dropout(0.5)(conv4)
        pool4 = layers.MaxPooling3D(pool_size=(2, 2, 2))(drop4)

        # Bottom
        conv_bottom = layers.Conv3D(1024,
                                    3,
                                    activation="relu",
                                    padding="same",
                                    kernel_initializer="he_normal")(pool4)
        conv_bottom = layers.Conv3D(
            1024,
            3,
            activation="relu",
            padding="same",
            kernel_initializer="he_normal")(conv_bottom)
        drop_bottom = layers.Dropout(0.5)(conv_bottom)

        # Up 4
        # BUG FIX: the original passed an extra positional argument
        # (Conv3D(512, 2, 2, ...)), making this a stride-2 conv that undid
        # the 2x upsampling, so `up4` could not concatenate with `drop4`.
        # Stride 1 matches the Up 3/2/1 blocks below.
        up4 = layers.Conv3D(512,
                            2,
                            activation="relu",
                            padding="same",
                            kernel_initializer="he_normal")(
                                layers.UpSampling3D(size=(2, 2,
                                                          2))(drop_bottom))
        merge4 = layers.concatenate([drop4, up4], axis=4)
        convup4 = layers.Conv3D(512,
                                3,
                                activation="relu",
                                padding="same",
                                kernel_initializer="he_normal")(merge4)
        convup4 = layers.Conv3D(512,
                                3,
                                activation="relu",
                                padding="same",
                                kernel_initializer="he_normal")(convup4)

        to_up3 = convup4

    #  -------------------------- 4 levels version -------------------------------

    # --------------------------- 3 levels version -------------------------------
    else:
        conv3 = layers.Dropout(0.5)(conv3)
        pool3 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv3)

        conv_bottom = layers.Conv3D(512,
                                    3,
                                    activation="relu",
                                    padding="same",
                                    kernel_initializer="he_normal")(pool3)
        conv_bottom = layers.Conv3D(
            512,
            3,
            activation="relu",
            padding="same",
            kernel_initializer="he_normal")(conv_bottom)
        drop_bottom = layers.Dropout(0.5)(conv_bottom)

        to_up3 = drop_bottom
    # --------------------------- 3 levels version -------------------------------

    # Up 3
    up3 = layers.Conv3D(256,
                        2,
                        activation="relu",
                        padding="same",
                        kernel_initializer="he_normal")(
                            layers.UpSampling3D(size=(2, 2, 2))(to_up3))
    merge3 = layers.concatenate([conv3, up3], axis=4)
    convup3 = layers.Conv3D(256,
                            3,
                            activation="relu",
                            padding="same",
                            kernel_initializer="he_normal")(merge3)
    convup3 = layers.Conv3D(256,
                            3,
                            activation="relu",
                            padding="same",
                            kernel_initializer="he_normal")(convup3)

    # Up 2
    up2 = layers.Conv3D(128,
                        2,
                        activation="relu",
                        padding="same",
                        kernel_initializer="he_normal")(
                            layers.UpSampling3D(size=(2, 2, 2))(convup3))
    merge2 = layers.concatenate([conv2, up2], axis=4)
    convup2 = layers.Conv3D(128,
                            3,
                            activation="relu",
                            padding="same",
                            kernel_initializer="he_normal")(merge2)
    convup2 = layers.Conv3D(128,
                            3,
                            activation="relu",
                            padding="same",
                            kernel_initializer="he_normal")(convup2)

    # Up 1
    up1 = layers.Conv3D(64,
                        2,
                        activation="relu",
                        padding="same",
                        kernel_initializer="he_normal")(
                            layers.UpSampling3D(size=(2, 2, 2))(convup2))
    merge1 = layers.concatenate([conv1, up1], axis=4)
    convup1 = layers.Conv3D(64,
                            3,
                            activation="relu",
                            padding="same",
                            kernel_initializer="he_normal")(merge1)
    convup1 = layers.Conv3D(64,
                            3,
                            activation="relu",
                            padding="same",
                            kernel_initializer="he_normal")(convup1)

    # Output head: collapse the temporal axis with a (frames, 1, 1) valid
    # conv, then drop it via Reshape to produce a 2D single-channel map.
    convout = layers.Conv3D(2, (1, 3, 3),
                            activation="relu",
                            padding="same",
                            kernel_initializer="he_normal")(convup1)
    convout = layers.Conv3D(
        2,
        (frames, 1, 1),
        activation="relu",
        padding="valid",
        kernel_initializer="he_normal",
    )(convout)
    convout = layers.Reshape((height, width, 1))(convout)

    model = keras.Model(inputs=inputs, outputs=convout)

    model.compile(
        optimizer=keras.optimizers.Adam(lr=1e-4),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model
Пример #17
0
def get_model(img_size):
    """Build a 3D U-Net for single-channel (binary) volumetric segmentation.

    Encoder: four Conv3D + MaxPooling3D stages (8 -> 64 filters), a
    128-filter bottleneck, then a symmetric decoder that upsamples and
    concatenates the matching encoder feature map at each stage (skip
    connections).

    Args:
        img_size: input volume shape tuple passed to ``keras.Input``;
            spatial dims should be divisible by 16 (four 2x poolings) —
            assumed, not enforced here.

    Returns:
        An uncompiled ``keras.Model`` mapping the input volume to a
        same-resolution, single-channel sigmoid probability map.
    """
    num_classes = 1
    inputs = keras.Input(shape=img_size)

    # --- Encoder ---
    conv1 = layers.Conv3D(8,
                          3,
                          activation='relu',
                          padding="same",
                          data_format="channels_last")(inputs)
    conv1 = layers.Conv3D(8, 3, activation='relu', padding="same")(conv1)
    pool1 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv1)

    conv2 = layers.Conv3D(16, 3, activation='relu', padding="same")(pool1)
    pool2 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv2)

    conv3 = layers.Conv3D(32, 3, activation='relu', padding="same")(pool2)
    pool3 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv3)

    conv4 = layers.Conv3D(64, 3, activation='relu', padding="same")(pool3)
    pool4 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv4)

    # --- Bottleneck ---
    conv5 = layers.Conv3D(128, 3, activation='relu', padding="same")(pool4)
    conv5 = layers.Conv3D(128, 3, activation='relu', padding="same")(conv5)

    # --- Decoder with skip connections ---
    up6 = layers.Conv3D(64, 2, activation='relu',
                        padding="same")(layers.UpSampling3D(2)(conv5))
    merge6 = layers.concatenate([conv4, up6], axis=-1)
    conv6 = layers.Conv3D(64, 3, activation='relu', padding="same")(merge6)
    conv6 = layers.Conv3D(64, 3, activation='relu', padding="same")(conv6)

    up7 = layers.Conv3D(32, 2, activation='relu',
                        padding="same")(layers.UpSampling3D(2)(conv6))
    merge7 = layers.concatenate([conv3, up7], axis=-1)
    conv7 = layers.Conv3D(32, 3, activation='relu', padding="same")(merge7)
    conv7 = layers.Conv3D(32, 3, activation='relu', padding="same")(conv7)

    up8 = layers.Conv3D(16, 2, activation='relu',
                        padding="same")(layers.UpSampling3D(2)(conv7))
    merge8 = layers.concatenate([conv2, up8], axis=-1)
    conv8 = layers.Conv3D(16, 3, activation='relu', padding="same")(merge8)
    conv8 = layers.Conv3D(16, 3, activation='relu', padding="same")(conv8)

    up9 = layers.Conv3D(8, 2, activation='relu',
                        padding="same")(layers.UpSampling3D(2)(conv8))
    merge9 = layers.concatenate([conv1, up9], axis=-1)
    conv9 = layers.Conv3D(8, 3, activation='relu', padding="same")(merge9)
    conv9 = layers.Conv3D(8, 3, activation='relu', padding="same")(conv9)

    # (Removed a large dead block of commented-out residual-network code that
    # referenced undefined variables.)

    # Per-voxel binary probability.
    outputs = layers.Conv3D(num_classes, 1, activation="sigmoid")(conv9)
    model = keras.Model(inputs, outputs)
    return model
Пример #18
0
def up_unit(conv4_1, conv3_1, stage, nb_filter, bn_axis=-1, layer_act='relu'):
    """Decoder step: upsample `conv4_1`, concatenate the skip tensor
    `conv3_1` along `bn_axis`, then apply a `standard_unit` conv block.

    Layer names are derived from `stage` so each decoder level is unique.
    """
    tag = str(stage)
    upsampled = layers.UpSampling3D(name='up_' + tag)(conv4_1)
    merged = layers.concatenate([upsampled, conv3_1],
                                name='merge_' + tag,
                                axis=bn_axis)
    return standard_unit(merged,
                         stage='up_last_' + tag,
                         nb_filter=nb_filter,
                         layer_act=layer_act)
Пример #19
0
def create_model(image_shape=(1280, 720)):
    """Build and compile a sequential CNN with input augmentation.

    NOTE(review): as written this model cannot build — see the inline
    notes on the UpSampling3D layer and the input shape. Confirm the
    intended output before use.
    """
    # Train-time-only random flip/rotation/zoom augmentation.
    data_augmentation = keras.Sequential([
        layers.experimental.preprocessing.RandomFlip('horizontal',
                                                     input_shape=image_shape),
        layers.experimental.preprocessing.RandomRotation(0.1),
        layers.experimental.preprocessing.RandomZoom(0.1)
    ])
    # NOTE(review): input_shape is declared twice (RandomFlip above and
    # Rescaling below), and (1280, 720) carries no channel dimension even
    # though Conv2D requires one — verify the intended input rank.
    model = Sequential([
        data_augmentation,
        layers.experimental.preprocessing.Rescaling(1. / 255,
                                                    input_shape=image_shape),
        # Block 1
        layers.Conv2D(filters=64,
                      kernel_size=(4, 4),
                      strides=4,
                      kernel_initializer="glorot_uniform"),
        layers.BatchNormalization(axis=3),
        layers.Activation('relu'),
        # Block 2
        layers.Conv2D(filters=256,
                      kernel_size=(2, 2),
                      strides=2,
                      kernel_initializer="glorot_uniform"),
        layers.BatchNormalization(axis=3),
        layers.Activation('relu'),
        # Block 3
        layers.Conv2D(filters=512,
                      kernel_size=(1, 1),
                      strides=1,
                      kernel_initializer="glorot_uniform"),
        layers.BatchNormalization(axis=3),
        layers.Activation('relu'),
        # Block 4
        layers.Conv2D(filters=1024,
                      kernel_size=(5, 5),
                      strides=5,
                      kernel_initializer="glorot_uniform"),
        layers.BatchNormalization(axis=3),
        layers.Activation('relu'),
        # Block 5
        layers.Conv2D(filters=2048,
                      kernel_size=(2, 2),
                      strides=2,
                      kernel_initializer="glorot_uniform"),
        layers.BatchNormalization(axis=3),
        layers.Activation('relu'),
        # Block 6
        layers.Conv2D(filters=2048,
                      kernel_size=(5, 5),
                      strides=1,
                      kernel_initializer="glorot_uniform"),
        layers.BatchNormalization(axis=3),
        layers.Activation('relu'),
        layers.Flatten(),
        # NOTE(review): UpSampling3D requires a rank-5 input but Flatten
        # produces rank-2, and size=(1920, 1080, 3) is a per-axis repeat
        # factor, not a target shape — presumably the commented-out
        # Dense + Reshape below was the intent; confirm and fix.
        layers.UpSampling3D(size=(1920, 1080, 3)),
        # layers.Dense(units=, activation=None),
        # layers.Reshape((1920, 1080, 3))
    ])
    model.summary()
    # NOTE(review): sparse categorical cross-entropy against an image-sized
    # output is unusual — confirm the label format this is trained on.
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])

    return model
Пример #20
0
def retinanet_resnet(inputs, K, A):
    """RetinaNet with a ResNet-50 backbone, built from 3D convolutions whose
    kernel depth is 1 — i.e. the network operates slice-wise on (1, H, W)
    feature maps. The classification and regression subnets are created once
    and shared across all feature-pyramid levels.

    Args:
        inputs: dict of input tensors; the image tensor is read from
            inputs['dat'] (presumably rank-5, depth 1 — confirm upstream).
        K: forwarded to classification_head (class count).
        A: forwarded to classification_head / regression_head (anchor count).

    Returns:
        Model producing a dict with 'cls-c3'..'cls-c5' classification and
        'reg-c3'..'reg-c5' regression outputs, one pair per pyramid level.
    """
    # --- Define kwargs dictionary
    kwargs1 = {
        'kernel_size': (1, 1, 1),
        'padding': 'valid',
    }
    kwargs3 = {
        'kernel_size': (1, 3, 3),
        'padding': 'same',
    }
    kwargs7 = {
        'kernel_size': (1, 7, 7),
        'padding': 'valid',
    }
    # --- Define block components
    # conv1/conv3/conv7 are 1x1, 3x3 and 7x7 (in-plane) convolution builders.
    conv1 = lambda x, filters, strides: layers.Conv3D(
        filters=filters, strides=strides, **kwargs1)(x)
    conv3 = lambda x, filters, strides: layers.Conv3D(
        filters=filters, strides=strides, **kwargs3)(x)
    # NOTE: despite the name, this is LeakyReLU, not plain ReLU.
    relu = lambda x: layers.LeakyReLU()(x)
    conv7 = lambda x, filters, strides: layers.Conv3D(
        filters=filters, strides=strides, **kwargs7)(x)
    max_pool = lambda x, pool_size, strides: layers.MaxPooling3D(
        pool_size=pool_size, strides=strides, padding='valid')(x)
    norm = lambda x: layers.BatchNormalization()(x)
    add = lambda x, y: layers.Add()([x, y])
    zeropad = lambda x, padding: layers.ZeroPadding3D(padding=padding)(x)
    # 2x in-plane (H, W) upsampling; depth axis untouched.
    upsamp2x = lambda x: layers.UpSampling3D(size=(1, 2, 2))(x)
    # --- Define stride-1, stride-2 blocks
    # conv1 = lambda filters, x : relu(conv(x, filters, strides=1))
    # conv2 = lambda filters, x : relu(conv(x, filters, strides=(2, 2)))
    # --- Residual blocks
    # conv blocks: ResNet bottleneck (1x1 reduce -> 3x3 -> 1x1 expand) with a
    # projection shortcut (conv_sc) so input/output shapes match before add.
    conv_1 = lambda filters, x, strides: relu(
        norm(conv1(x, filters, strides=strides)))
    conv_2 = lambda filters, x: relu(norm(conv3(x, filters, strides=1)))
    conv_3 = lambda filters, x: norm(conv1(x, filters, strides=1))
    conv_sc = lambda filters, x, strides: norm(
        conv1(x, filters, strides=strides))
    conv_block = lambda filters1, filters2, x, strides: relu(
        add(conv_3(filters2, conv_2(filters1, conv_1(filters1, x, strides))),
            conv_sc(filters2, x, strides)))
    # identity blocks: same bottleneck, but the shortcut is the raw input.
    identity_1 = lambda filters, x: relu(norm(conv1(x, filters, strides=1)))
    identity_2 = lambda filters, x: relu(norm(conv3(x, filters, strides=1)))
    identity_3 = lambda filters, x: norm(conv1(x, filters, strides=1))
    identity_block = lambda filters1, filters2, x: relu(
        add(
            identity_3(filters2, identity_2(filters1, identity_1(filters1, x))
                       ), x))
    # --- feature pyramid blocks: upsample the coarser level and add the
    # laterally 1x1-projected backbone feature (FPN top-down pathway).
    fp_block = lambda x, y: add(upsamp2x(x), conv1(y, 256, strides=1))
    # --- classification head
    class_subnet = classification_head(K, A)
    # --- regression head
    box_subnet = regression_head(A)
    # --- ResNet-50 backbone
    # stage 1 c2 1/4
    res1 = max_pool(zeropad(
        relu(
            norm(
                conv7(zeropad(inputs['dat'], (0, 3, 3)), 64,
                      strides=(1, 2, 2)))), (0, 1, 1)), (1, 3, 3),
                    strides=(1, 2, 2))
    # stage 2 c2 1/4
    res2 = identity_block(
        64, 256, identity_block(64, 256, conv_block(64, 256, res1, strides=1)))
    # stage 3 c3 1/8
    res3 = identity_block(
        128, 512,
        identity_block(
            128, 512,
            identity_block(128, 512,
                           conv_block(128, 512, res2, strides=(1, 2, 2)))))
    # stage 4 c4 1/16
    res4 = identity_block(
        256, 1024,
        identity_block(
            256, 1024,
            identity_block(
                256, 1024,
                identity_block(
                    256, 1024,
                    identity_block(
                        256, 1024,
                        conv_block(256, 1024, res3, strides=(1, 2, 2)))))))
    # stage 5 c5 1/32
    res5 = identity_block(
        512, 2048,
        identity_block(512, 2048, conv_block(512,
                                             2048,
                                             res4,
                                             strides=(1, 2, 2))))
    # --- Feature Pyramid Network architecture
    # p5 1/32
    fp5 = conv1(res5, 256, strides=1)
    # p4 1/16
    fp4 = fp_block(fp5, res4)
    p4 = conv3(fp4, 256, strides=1)
    # p3 1/8
    fp3 = fp_block(fp4, res3)
    p3 = conv3(fp3, 256, strides=1)
    # p6 1/4
    # p6 = conv3(fp5, 256, strides=(2, 2))
    # p7 1/2
    # p7 = conv3(relu(p6), 256, strides=(2, 2))
    # NOTE(review): the c5 level uses fp5 directly, without the 3x3 smoothing
    # conv applied to p3/p4; also fp3 is built from fp4 (pre-smoothing), which
    # matches standard FPN but is worth confirming against the paper.
    feature_pyramid = [p3, p4, fp5]
    # lambda layer that allows multiple outputs from a shared model to have specific names
    # layers.Lambda(lambda x:x, name=name)()
    # --- Class subnet
    class_outputs = [class_subnet(features) for features in feature_pyramid]
    # --- Box subnet
    box_outputs = [box_subnet(features) for features in feature_pyramid]
    # --- put class and box outputs in dictionary
    logits = {
        'cls-c3': layers.Lambda(lambda x: x, name='cls-c3')(class_outputs[0]),
        'reg-c3': layers.Lambda(lambda x: x, name='reg-c3')(box_outputs[0]),
        'cls-c4': layers.Lambda(lambda x: x, name='cls-c4')(class_outputs[1]),
        'reg-c4': layers.Lambda(lambda x: x, name='reg-c4')(box_outputs[1]),
        'cls-c5': layers.Lambda(lambda x: x, name='cls-c5')(class_outputs[2]),
        'reg-c5': layers.Lambda(lambda x: x, name='reg-c5')(box_outputs[2])
    }

    model = Model(inputs=inputs, outputs=logits)
    return model
Пример #21
0
# Decoder section of a Bottleneck_JQ-based network. Each stage halves the
# channel count via mode="UP" bottlenecks, concatenates an encoder skip
# tensor (x25 / x17 / x12, defined earlier in the script), then upsamples.
# 14X14 Stage
x46 = Bottleneck_JQ(512, 256, mode="UP", k=2)(x45)  # no shortcut, uses channel reduction function.
x47 = Bottleneck_JQ(256, 256)(x46)
x48 = layers.Concatenate()([x47, x25])  # Concatenation does not happen at the beginning of the stage?
x49 = Bottleneck_JQ(768, 384, mode="UP", k=2)(x48)
x50 = Bottleneck_JQ(384, 384)(x49)


# upsampling
x51 = layers.UpSampling2D((2, 2))(x50)

# 28x28 Stage
x52 = Bottleneck_JQ(384, 128, mode="UP", k=3)(x51)
x53 = Bottleneck_JQ(128, 128)(x52)
# NOTE(review): UpSampling3D requires a rank-5 input; using it to "upsample
# the channels" of what looks like a 2D feature map only works if x53 is
# actually 5D here — confirm the tensor rank.
x53a = layers.UpSampling3D((1, 1, 2))(x53)  # upsampled the channels in order to make the numbers work
x54 = layers.Concatenate()([x53a, x17])  # Concatenation does not happen at the beginning of the stage?
x55 = Bottleneck_JQ(512, 256, mode="UP", k=2)(x54)
x56 = Bottleneck_JQ(256, 256)(x55)

# upsampling
x57 = layers.UpSampling2D((2, 2))(x56)

# 56x56 Stage
x58 = Bottleneck_JQ(256, 64, mode="UP", k=4)(x57)
x59 = Bottleneck_JQ(64, 64)(x58)
# NOTE(review): same rank concern as x53a above.
x59a = layers.UpSampling3D((1, 1, 3))(x59)  # another weird upsampling situation to make the numbers work
x60 = layers.Concatenate()([x59a, x12])  # Concatenation does not happen at the beginning of the stage?
x61 = Bottleneck_JQ(320, 320)(x60)  # THERE IS SOME SUPREME WEIRDNESS GOING ON HERE
x62 = Bottleneck_JQ(320, 320)(x61)
Пример #22
0
# --- Conditional VAE assembly (continues from encoder convs defined above) ---
#x = layers.Conv3D(256, (3, 3, 3), activation="relu",  padding="same")(x)
x = layers.Flatten()(x) # to feed into sampling function
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
# NOTE(review): the variable is named z_log_sigma but the layer is named
# "z_log_var" — confirm which quantity the `sampling` function expects.
z_log_sigma = layers.Dense(latent_dim, name="z_log_var")(x)
# Reparameterization trick: z drawn from N(z_mean, exp(...)) via `sampling`.
z = layers.Lambda(sampling, name='z')([z_mean, z_log_sigma])
# Condition the latent code on the class label.
z_label = concatenate([z, label], name='encoded') 
## initiating the encoder, it ouputs the latent dim dimensions
encoder = keras.Model([encoder_inputs, label], [z_mean, z_log_sigma, z_label, z], name="encoder")
encoder.summary()

#### Make the decoder, takes the latent keras
# Decoder input is the label-conditioned code: latent_dim + n_y values.
latent_inputs = keras.Input(shape=(latent_dim + n_y),) # changes based on depth 
# Project to a (2, 10, 10, 128) volume, then upsample 2x three times
# (final spatial size 16 x 80 x 80).
x =  layers.Dense(2*10*10*128, activation='relu')(latent_inputs)
x = layers.Reshape((2, 10, 10, 128))(x)
#x = layers.Conv3DTranspose(128, (3, 3, 3), activation="relu", padding="same")(x)
x = layers.UpSampling3D((2,2,2))(x)
#x = layers.SpatialDropout3D(0.3)(x)
x = layers.Conv3DTranspose(64, (3, 3, 3), activation="relu",  padding="same")(x)
x = layers.UpSampling3D((2,2,2))(x)
#x = layers.SpatialDropout3D(0.2)(x)
x = layers.Conv3DTranspose(32, (3, 3, 3), activation="relu",  padding="same")(x)
x = layers.UpSampling3D((2,2,2))(x)
#x = layers.SpatialDropout3D(0.3)(x)
# Single-channel sigmoid reconstruction.
decoder_outputs = layers.Conv3DTranspose(1, 3, activation="sigmoid", padding="same")(x)
# Initiate decoder
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

# Instantiate and fit VAE model, outputs only z_label
# encoder(...)[2] is z_label (the label-conditioned latent code).
outputs = decoder(encoder([encoder_inputs, label])[2])
cvae = keras.Model([encoder_inputs, label], outputs, name='cvae')
Пример #23
0
def unet3d_keras():
    """Build a 3D U-Net over 4-channel MRI patches.

    Encoder/decoder with four 2x poolings/upsamplings and skip connections;
    filter counts scale from config.BASE_FILTER up to 16x at the bottleneck.
    The final 1x1x1 Conv3D emits raw per-voxel logits for
    config.num_classes (no activation).

    Returns:
        An uncompiled Keras Model taking a (config.patch_size + [4]) volume.
    """
    def ConvolutionBlock(x, name, fms, params):
        # Two (Conv3D -> BatchNorm -> ReLU) stages. The second activation
        # carries the bare block name so later layers can reference it.
        x = layers.Conv3D(filters=fms, **params, name=name + "_conv0")(x)
        x = layers.BatchNormalization(name=name + "_bn0")(x)
        x = layers.Activation("relu", name=name + "_relu0")(x)

        x = layers.Conv3D(filters=fms, **params, name=name + "_conv1")(x)
        x = layers.BatchNormalization(name=name + "_bn1")(x)
        x = layers.Activation("relu", name=name)(x)

        return x

    inputs = layers.Input(shape=config.patch_size + [4], name="MRImages")

    # Shared Conv3D hyper-parameters; activation is applied separately
    # after batch norm, hence activation=None here.
    params = dict(kernel_size=(3, 3, 3),
                  activation=None,
                  padding="same",
                  data_format="channels_last",
                  kernel_initializer="he_uniform")

    # (Removed an unused transposed-convolution params dict — the decoder
    # uses UpSampling3D, not Conv3DTranspose.)

    # BEGIN - Encoding path
    encodeA = ConvolutionBlock(inputs, "encodeA", config.BASE_FILTER, params)
    poolA = layers.MaxPooling3D(name="poolA", pool_size=(2, 2, 2))(encodeA)

    encodeB = ConvolutionBlock(poolA, "encodeB", config.BASE_FILTER * 2,
                               params)
    poolB = layers.MaxPooling3D(name="poolB", pool_size=(2, 2, 2))(encodeB)

    encodeC = ConvolutionBlock(poolB, "encodeC", config.BASE_FILTER * 4,
                               params)
    poolC = layers.MaxPooling3D(name="poolC", pool_size=(2, 2, 2))(encodeC)

    encodeD = ConvolutionBlock(poolC, "encodeD", config.BASE_FILTER * 8,
                               params)
    poolD = layers.MaxPooling3D(name="poolD", pool_size=(2, 2, 2))(encodeD)

    encodeE = ConvolutionBlock(poolD, "encodeE", config.BASE_FILTER * 16,
                               params)
    # END - Encoding path

    # BEGIN - Decoding path
    up = layers.UpSampling3D(name="upE", size=(2, 2, 2))(encodeE)

    concatD = layers.concatenate([up, encodeD], axis=-1, name="concatD")
    decodeC = ConvolutionBlock(concatD, "decodeC", config.BASE_FILTER * 8,
                               params)

    up = layers.UpSampling3D(name="upC", size=(2, 2, 2))(decodeC)
    concatC = layers.concatenate([up, encodeC], axis=-1, name="concatC")
    decodeB = ConvolutionBlock(concatC, "decodeB", config.BASE_FILTER * 4,
                               params)

    up = layers.UpSampling3D(name="upB", size=(2, 2, 2))(decodeB)
    concatB = layers.concatenate([up, encodeB], axis=-1, name="concatB")
    decodeA = ConvolutionBlock(concatB, "decodeA", config.BASE_FILTER * 2,
                               params)

    up = layers.UpSampling3D(name="upA", size=(2, 2, 2))(decodeA)
    concatA = layers.concatenate([up, encodeA], axis=-1, name="concatA")

    # END - Decoding path
    convOut = ConvolutionBlock(concatA, "convOut", config.BASE_FILTER, params)
    logits = layers.Conv3D(name="PredictionMask",
                           filters=config.num_classes,
                           kernel_size=(1, 1, 1),
                           data_format="channels_last")(convOut)

    model = models.Model(inputs=[inputs], outputs=[logits])

    # summary() prints the table itself and returns None; wrapping it in
    # print() (as before) emitted a spurious trailing "None".
    model.summary()

    return model
    def __init__(self,
                 num_channels,
                 use_2d=True,
                 kernel_size=2,
                 activation='relu',
                 use_attention=False,
                 use_batchnorm=False,
                 use_transpose=False,
                 use_bias=True,
                 strides=2,
                 data_format='channels_last',
                 name="upsampling_conv_block",
                 **kwargs):
        """Upsampling + convolution decoder block (2D or 3D).

        Upsamples via a transposed convolution (use_transpose=True) or a
        parameter-free UpSampling layer, optionally wires in an attention
        gate, then applies a single-layer conv followed by a two-layer
        conv block.

        Args:
            num_channels: output channel count for the conv layers.
            use_2d: build 2D layers when True, 3D otherwise.
            kernel_size: kernel size for the upconv and the single conv.
            activation: activation passed to the conv blocks.
            use_attention: if True, also build an Attention_Gate.
            use_batchnorm: forwarded to the conv blocks.
            use_transpose: choose Conv{2,3}DTranspose over UpSampling{2,3}D.
            use_bias: forwarded to the attention gate only.
            strides: upsampling factor / transposed-conv strides.
            data_format: Keras data format string.
            name: layer name.
            **kwargs: accepted but unused here — presumably for parent-class
                compatibility; note they are NOT forwarded to super().
        """
        super(Up_Conv, self).__init__(name=name)

        self.data_format = data_format
        self.use_attention = use_attention

        # Choose the upsampling operator: learned (transposed conv) vs.
        # parameter-free interpolation, each in 2D or 3D flavor.
        if use_transpose:
            if use_2d:
                self.upconv_layer = tfkl.Conv2DTranspose(
                    num_channels,
                    kernel_size,
                    padding='same',
                    strides=strides,
                    data_format=self.data_format)
            else:
                self.upconv_layer = tfkl.Conv3DTranspose(
                    num_channels,
                    kernel_size,
                    padding='same',
                    strides=strides,
                    data_format=self.data_format)
        else:
            if use_2d:
                self.upconv_layer = tfkl.UpSampling2D(size=strides)
            else:
                self.upconv_layer = tfkl.UpSampling3D(size=strides)

        if self.use_attention:
            self.attention = Attention_Gate(num_channels=num_channels,
                                            use_2d=use_2d,
                                            kernel_size=1,
                                            activation=activation,
                                            padding='same',
                                            strides=strides,
                                            use_bias=use_bias,
                                            data_format=self.data_format)

        # Single conv applied right after upsampling.
        self.conv = Conv_Block(num_channels=num_channels,
                               use_2d=use_2d,
                               num_conv_layers=1,
                               kernel_size=kernel_size,
                               activation=activation,
                               use_batchnorm=use_batchnorm,
                               use_dropout=False,
                               data_format=self.data_format)

        # Two-layer 3x3 conv block — presumably applied after the skip
        # concatenation in call(); confirm against the (unseen) call method.
        self.conv_block = Conv_Block(num_channels=num_channels,
                                     use_2d=use_2d,
                                     num_conv_layers=2,
                                     kernel_size=3,
                                     activation=activation,
                                     use_batchnorm=use_batchnorm,
                                     use_dropout=False,
                                     data_format=self.data_format)