def CapsNet(input_shape, n_class, routings):
    """
    A Capsule Network on MNIST.

    :param input_shape: data shape, 3d, [width, height, channels]
    :param n_class: number of classes
    :param routings: number of routing iterations
    :return: Three Keras Models ``(train_model, eval_model, digit_caps_model)``:
            - ``train_model``: inputs ``[x, y]`` (image and one-hot label); outputs the
              capsule lengths and the reconstruction decoded from the label-masked capsules.
            - ``eval_model``: input ``x``; outputs the capsule lengths and the
              reconstruction decoded from ``Mask()`` applied without labels.
              `eval_model` can also be used for training.
            - ``digit_caps_model``: input ``x``; outputs the raw ``digitcaps`` tensor.

    NOTE(review): ``digitcap_dim`` (the atom count of each digit capsule) is not defined
    in this function; it is looked up in an enclosing/module scope and must exist
    before this function is called.
    """
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)

    # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]
    primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')

    # Layer 3: Capsule layer. Routing algorithm works here.
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=digitcap_dim, routings=routings,
                             name='digitcaps')(primarycaps)

    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    # If using tensorflow, this will not be necessary. :)
    out_caps = Length(name='capsnet')(digitcaps)

    # Decoder network.
    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([digitcaps, y])  # The true label is used to mask the output of capsule layer. For training
    masked = Mask()(digitcaps)  # Mask using the capsule with maximal length. For prediction

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    # The decoder consumes all n_class capsules (digitcap_dim atoms each), flattened.
    decoder.add(layers.Dense(512, activation='relu', input_dim=digitcap_dim*n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction)
    train_model = models.Model([x, y], [out_caps, decoder(masked_by_y)])
    eval_model = models.Model(x, [out_caps, decoder(masked)])
    # Exposes the raw digit capsules (not only their lengths) for inspection.
    digit_caps_model = models.Model(x, digitcaps)

    return train_model, eval_model, digit_caps_model
# --- Example no. 2 (source listing marker: 0) ---
def CapsNetR3(input_shape, n_class=2):
    """
    Build and compile a SegCaps-style (R3) capsule segmentation network.

    The encoder is a Conv2D layer followed by convolutional capsule layers with
    three stride-2 stages. The up-path uses deconvolutional capsule layers
    (``scaling=2``) with skip connections back to the encoder, ending in a 1x1
    ``seg_caps`` capsule layer. A 1x1-Conv2D reconstruction head decodes the
    ``seg_caps`` output masked by the ground-truth mask ``y``.

    :param input_shape: shape of one input image, as a tuple (its leading
        dimensions are concatenated with ``(1,)`` to form the mask-input shape).
    :param n_class: number of classes, forwarded to ``Length``.
    :return: the compiled training model with inputs ``[x, y]`` and outputs
        ``[out_seg, out_recon]``, compiled with the ``rmsprop`` optimizer,
        ``binary_crossentropy`` loss and the ``mean_iou`` metric.
    """
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=16,
                          kernel_size=5,
                          strides=1,
                          padding='same',
                          activation='relu',
                          name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms
    _, H, W, C = conv1.get_shape()
    conv1_reshaped = layers.Reshape((H.value, W.value, 1, C.value))(conv1)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=2,
                                    num_atoms=16,
                                    strides=2,
                                    padding='same',
                                    routings=1,
                                    name='primarycaps')(conv1_reshaped)

    # Layer 2: Convolutional Capsule
    conv_cap_2_1 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=4,
                                    num_atoms=16,
                                    strides=1,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_2_1')(primary_caps)

    # Layer 2: Convolutional Capsule
    conv_cap_2_2 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=4,
                                    num_atoms=32,
                                    strides=2,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_2_2')(conv_cap_2_1)

    # Layer 3: Convolutional Capsule
    conv_cap_3_1 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=8,
                                    num_atoms=32,
                                    strides=1,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_3_1')(conv_cap_2_2)

    # Layer 3: Convolutional Capsule
    conv_cap_3_2 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=8,
                                    num_atoms=64,
                                    strides=2,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_3_2')(conv_cap_3_1)

    # Layer 4: Convolutional Capsule
    conv_cap_4_1 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=8,
                                    num_atoms=32,
                                    strides=1,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_4_1')(conv_cap_3_2)

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_1 = DeconvCapsuleLayer(kernel_size=4,
                                        num_capsule=8,
                                        num_atoms=32,
                                        upsamp_type='deconv',
                                        scaling=2,
                                        padding='same',
                                        routings=3,
                                        name='deconv_cap_1_1')(conv_cap_4_1)

    # Skip connection (concatenate along the capsule axis)
    up_1 = layers.Concatenate(axis=-2,
                              name='up_1')([deconv_cap_1_1, conv_cap_3_1])

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_2 = ConvCapsuleLayer(kernel_size=5,
                                      num_capsule=4,
                                      num_atoms=32,
                                      strides=1,
                                      padding='same',
                                      routings=3,
                                      name='deconv_cap_1_2')(up_1)

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_1 = DeconvCapsuleLayer(kernel_size=4,
                                        num_capsule=4,
                                        num_atoms=16,
                                        upsamp_type='deconv',
                                        scaling=2,
                                        padding='same',
                                        routings=3,
                                        name='deconv_cap_2_1')(deconv_cap_1_2)

    # Skip connection
    up_2 = layers.Concatenate(axis=-2,
                              name='up_2')([deconv_cap_2_1, conv_cap_2_1])

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_2 = ConvCapsuleLayer(kernel_size=5,
                                      num_capsule=4,
                                      num_atoms=16,
                                      strides=1,
                                      padding='same',
                                      routings=3,
                                      name='deconv_cap_2_2')(up_2)

    # Layer 3 Up: Deconvolutional Capsule
    deconv_cap_3_1 = DeconvCapsuleLayer(kernel_size=4,
                                        num_capsule=2,
                                        num_atoms=16,
                                        upsamp_type='deconv',
                                        scaling=2,
                                        padding='same',
                                        routings=3,
                                        name='deconv_cap_3_1')(deconv_cap_2_2)

    # Skip connection
    up_3 = layers.Concatenate(axis=-2,
                              name='up_3')([deconv_cap_3_1, conv1_reshaped])

    # Layer 4: Convolutional Capsule: 1x1
    seg_caps = ConvCapsuleLayer(kernel_size=1,
                                num_capsule=1,
                                num_atoms=16,
                                strides=1,
                                padding='same',
                                routings=3,
                                name='seg_caps')(up_3)

    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps)

    # Decoder network. Only the spatial size and atom count are needed to
    # flatten the single seg capsule back into a feature map.
    _, H, W, _, A = seg_caps.get_shape()
    y = layers.Input(shape=input_shape[:-1] + (1, ))
    # The true label is used to mask the output of capsule layer. For training
    masked_by_y = Mask()([seg_caps, y])

    def shared_decoder(mask_layer):
        """Build a 1x1-conv reconstruction head on ``mask_layer``.

        Every call creates new layers (named ``recon_1``, ``recon_2``,
        ``out_recon``).
        """
        recon_remove_dim = layers.Reshape(
            (H.value, W.value, A.value))(mask_layer)

        recon_1 = layers.Conv2D(filters=64,
                                kernel_size=1,
                                padding='same',
                                kernel_initializer='he_normal',
                                activation='relu',
                                name='recon_1')(recon_remove_dim)

        recon_2 = layers.Conv2D(filters=128,
                                kernel_size=1,
                                padding='same',
                                kernel_initializer='he_normal',
                                activation='relu',
                                name='recon_2')(recon_1)

        out_recon = layers.Conv2D(filters=1,
                                  kernel_size=1,
                                  padding='same',
                                  kernel_initializer='he_normal',
                                  activation='sigmoid',
                                  name='out_recon')(recon_2)

        return out_recon

    # Only the training model is returned, so no evaluation/manipulation
    # models (and no label-free Mask for them) are built here.
    train_model = models.Model(inputs=[x, y],
                               outputs=[out_seg,
                                        shared_decoder(masked_by_y)])
    # `mean_iou` is not defined in this function; it must exist at module scope.
    train_model.compile(optimizer='rmsprop',
                        loss='binary_crossentropy',
                        metrics=[mean_iou])
    return train_model
# --- Example no. 3 (source listing marker: 0) ---
def _select_capsule(capsules, idx):
    """Return capsule type ``idx`` of a 5-D capsule tensor: ``capsules[:, :, :, idx, :]``.

    Used through ``layers.Lambda(..., arguments={'idx': i})`` so the index is
    stored with each layer instead of being captured by a loop closure.
    """
    return capsules[:, :, :, idx, :]


def _recon_grid_shape(input_shape, factor=2**6):
    """Return the ``(rows, cols)`` seed grid used by the reconstruction decoder.

    Each spatial dimension of ``input_shape`` is floor-divided by ``factor``
    separately, so ``rows * cols`` is exactly the element count of the
    ``(rows, cols, 1)`` Reshape that follows the decoder's Dense layer.
    """
    return input_shape[0] // factor, input_shape[1] // factor


def DiagnosisCapsules(input_shape,
                      n_class=2,
                      k_size=5,
                      output_atoms=16,
                      routings1=3,
                      routings2=3):
    """
    A Capsule Network on Medical Image Diagnosis.

    :param input_shape: data shape; ``input_shape[0]`` and ``input_shape[1]``
        are used as the spatial size for the reconstruction decoder.
    :param n_class: number of classes (2 is treated as a single binary output capsule)
    :param k_size: kernel size for convolutional capsules
    :param output_atoms: number of atoms in D-Caps layer
    :param routings1: number of routing iterations when stride is 1
    :param routings2: number of routing iterations when stride is > 1
    :return: Three Keras Models ``(train_model, eval_model, manipulate_model)``:
            ``train_model`` maps ``[x, y]`` to ``[out_caps, reconstruction]`` using the
            label-masked class capsules; ``eval_model`` maps ``x`` to the same outputs
            using ``Mask()`` without labels (`eval_model` can also be used for training);
            ``manipulate_model`` maps ``[x, y, noise]`` to the reconstruction of the
            noise-perturbed, label-masked class capsules.
    """
    if n_class == 2:
        n_class = 1  # binary output

    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=16,
                          kernel_size=k_size,
                          strides=2,
                          padding='same',
                          activation='relu',
                          name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms
    conv1_reshaped = ExpandDim(name='expand_dim')(conv1)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=k_size,
                                    num_capsule=2,
                                    num_atoms=16,
                                    strides=2,
                                    padding='same',
                                    routings=1,
                                    name='primarycaps')(conv1_reshaped)

    # Layer 2: Convolutional Capsule
    conv_cap_2_1 = ConvCapsuleLayer(kernel_size=k_size,
                                    num_capsule=4,
                                    num_atoms=16,
                                    strides=1,
                                    padding='same',
                                    routings=routings1,
                                    name='conv_cap_2_1')(primary_caps)

    # Layer 2: Convolutional Capsule
    conv_cap_2_2 = ConvCapsuleLayer(kernel_size=k_size,
                                    num_capsule=4,
                                    num_atoms=32,
                                    strides=2,
                                    padding='same',
                                    routings=routings2,
                                    name='conv_cap_2_2')(conv_cap_2_1)

    # Layer 3: Convolutional Capsule
    conv_cap_3_1 = ConvCapsuleLayer(kernel_size=k_size,
                                    num_capsule=8,
                                    num_atoms=32,
                                    strides=1,
                                    padding='same',
                                    routings=routings1,
                                    name='conv_cap_3_1')(conv_cap_2_2)

    # Layer 3: Convolutional Capsule
    conv_cap_3_2 = ConvCapsuleLayer(kernel_size=k_size,
                                    num_capsule=8,
                                    num_atoms=64,
                                    strides=2,
                                    padding='same',
                                    routings=routings2,
                                    name='conv_cap_3_2')(conv_cap_3_1)

    # Layer 4: Convolutional Capsule
    conv_cap_4_1 = ConvCapsuleLayer(kernel_size=k_size,
                                    num_capsule=8,
                                    num_atoms=32,
                                    strides=1,
                                    padding='same',
                                    routings=routings1,
                                    name='conv_cap_4_1')(conv_cap_3_2)

    # Layer 4: Convolutional Capsule (class capsules)
    conv_cap_4_2 = ConvCapsuleLayer(kernel_size=k_size,
                                    num_capsule=n_class,
                                    num_atoms=output_atoms,
                                    strides=2,
                                    padding='same',
                                    routings=routings2,
                                    name='conv_cap_4_2')(conv_cap_4_1)

    if n_class > 1:
        # Perform GAP on each capsule type.
        in_shape = conv_cap_4_2.get_shape().as_list()
        class_caps_list = []
        for i in range(n_class):
            # The index is passed via `arguments` rather than a closure: a
            # `lambda` referring to the loop variable `i` late-binds it and
            # would see the last value of `i` if invoked after the loop.
            one_class_capsule = layers.Lambda(_select_capsule,
                                              output_shape=in_shape[1:3] +
                                              in_shape[4:],
                                              arguments={'idx': i})(conv_cap_4_2)
            gap = layers.GlobalAveragePooling2D(
                name='gap_{}'.format(i))(one_class_capsule)

            # Put capsule dimension back for length and recon
            class_caps_list.append(
                ExpandDim(name='expand_gap_{}'.format(i))(gap))

        class_caps = layers.Concatenate(axis=-2,
                                        name='class_caps')(class_caps_list)
    else:
        # Remove capsule dim, perform GAP, put capsule dim back
        conv_cap_4_2_reshaped = RemoveDim(
            name='conv_cap_4_2_reshaped')(conv_cap_4_2)
        gap = layers.GlobalAveragePooling2D(name='gap')(conv_cap_4_2_reshaped)
        class_caps = ExpandDim(name='expand_gap')(gap)

    # Output layer which predicts classes
    out_caps = Length(num_classes=n_class, name='out_caps')(class_caps)

    # Decoder network.
    _, C, A = class_caps.get_shape()
    y = layers.Input(shape=(n_class, ))
    masked_by_y = Mask()(
        [class_caps, y]
    )  # The true label is used to mask the output of capsule layer. For training
    masked = Mask(
    )(class_caps)  # Mask using the capsule with maximal length. For prediction

    # Seed grid of the decoder; computed per axis so that the Dense unit count
    # always equals the element count of the Reshape target below.
    grid_rows, grid_cols = _recon_grid_shape(input_shape)

    def shared_reconstructor(mask_layer):
        """Build a reconstruction decoder on ``mask_layer``.

        Every call creates new layers (``recon_1`` .. ``out_recon``). The two
        stride-8 transposed convolutions upsample the seed grid by 64 in total.
        """
        recon_1 = layers.Dense(grid_rows * grid_cols,
                               kernel_initializer='he_normal',
                               activation='relu',
                               name='recon_1',
                               input_shape=(A.value, ))(mask_layer)

        recon_1a = layers.Reshape(
            (grid_rows, grid_cols, 1),
            name='recon_1a')(recon_1)

        recon_2 = layers.Conv2DTranspose(filters=128,
                                         kernel_size=5,
                                         strides=(8, 8),
                                         padding='same',
                                         kernel_initializer='he_normal',
                                         activation='relu',
                                         name='recon_2')(recon_1a)

        recon_3 = layers.Conv2DTranspose(filters=64,
                                         kernel_size=5,
                                         strides=(8, 8),
                                         padding='same',
                                         kernel_initializer='he_normal',
                                         activation='relu',
                                         name='recon_3')(recon_2)

        out_recon = layers.Conv2D(filters=3,
                                  kernel_size=3,
                                  padding='same',
                                  kernel_initializer='he_normal',
                                  activation='tanh',
                                  name='out_recon')(recon_3)

        return out_recon

    # Models for training and evaluation (prediction)
    train_model = models.Model(
        inputs=[x, y], outputs=[out_caps,
                                shared_reconstructor(masked_by_y)])
    eval_model = models.Model(inputs=x,
                              outputs=[out_caps,
                                       shared_reconstructor(masked)])

    # manipulate model: perturb the class capsules with an additive noise input
    noise = layers.Input(shape=(C.value, A.value))
    noised_class_caps = layers.Add()([class_caps, noise])
    masked_noised_y = Mask()([noised_class_caps, y])
    manipulate_model = models.Model(
        inputs=[x, y, noise], outputs=shared_reconstructor(masked_noised_y))

    return train_model, eval_model, manipulate_model
# --- Example no. 4 (source listing marker: 0) ---
def CapsNetBasic(input_shape, n_class=2):
    """
    Build a minimal capsule segmentation network and its companion models.

    Pipeline: Conv2D (256 filters, 5x5) -> reshape every pixel into one capsule
    of 256 atoms -> convolutional primary capsules (8 x 32 atoms, routing 1)
    -> 1x1 ``seg_caps`` capsule layer (1 x 16 atoms, routing 3). A 1x1-conv
    decoder reconstructs a single-channel map from the masked ``seg_caps``.

    :param input_shape: shape of one input image, as a tuple (its leading
        dimensions are reused for the mask input ``y``).
    :param n_class: number of classes, forwarded to ``Length``.
    :return: ``(train_model, eval_model, manipulate_model)``:
        ``train_model`` maps ``[x, y]`` to ``[out_seg, reconstruction]`` using
        the label-masked capsules; ``eval_model`` maps ``x`` to the same outputs
        using ``Mask()`` without labels; ``manipulate_model`` maps
        ``[x, y, noise]`` to the reconstruction of noise-perturbed capsules.
    """
    image = layers.Input(shape=input_shape)

    features = layers.Conv2D(filters=256, kernel_size=5, strides=1,
                             padding='same', activation='relu',
                             name='conv1')(image)

    # View each pixel's feature vector as a single capsule: (H, W, 1, channels).
    _, feat_h, feat_w, feat_c = features.get_shape()
    features_as_caps = layers.Reshape(
        (feat_h.value, feat_w.value, 1, feat_c.value))(features)

    primary = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32,
                               strides=1, padding='same', routings=1,
                               name='primarycaps')(features_as_caps)

    seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=1, num_atoms=16,
                                strides=1, padding='same', routings=3,
                                name='seg_caps')(primary)

    # Capsule lengths serve as the segmentation output.
    out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps)

    _, cap_h, cap_w, cap_n, cap_atoms = seg_caps.get_shape()
    mask_input = layers.Input(shape=input_shape[:-1] + (1, ))
    # Training path: capsules masked by the ground-truth label.
    caps_by_label = Mask()([seg_caps, mask_input])
    # Prediction path: capsules masked without labels.
    caps_unlabelled = Mask()(seg_caps)

    def decode(masked_caps):
        """Create a new 1x1-conv decoder stack and apply it to ``masked_caps``."""
        hidden = layers.Reshape(
            (cap_h.value, cap_w.value, cap_atoms.value))(masked_caps)
        for width, act, label in ((64, 'relu', 'recon_1'),
                                  (128, 'relu', 'recon_2'),
                                  (1, 'sigmoid', 'out_recon')):
            hidden = layers.Conv2D(filters=width,
                                   kernel_size=1,
                                   padding='same',
                                   kernel_initializer='he_normal',
                                   activation=act,
                                   name=label)(hidden)
        return hidden

    train_model = models.Model(inputs=[image, mask_input],
                               outputs=[out_seg, decode(caps_by_label)])
    eval_model = models.Model(inputs=image,
                              outputs=[out_seg, decode(caps_unlabelled)])

    # Additive noise on the capsules, for probing what each atom encodes.
    noise_shape = (cap_h.value, cap_w.value, cap_n.value, cap_atoms.value)
    noise = layers.Input(shape=noise_shape)
    noisy_caps = layers.Add()([seg_caps, noise])
    noisy_by_label = Mask()([noisy_caps, mask_input])
    manipulate_model = models.Model(inputs=[image, mask_input, noise],
                                    outputs=decode(noisy_by_label))

    return train_model, eval_model, manipulate_model
# --- Example no. 5 (source listing marker: 0) ---
def CapsNet(input_shape, n_class=5, routings=3, noactiv=False):
    """
    Build a capsule classifier (``malcaps``) with a dense reconstruction decoder.

    Pipeline: Conv2D (256 filters, 9x9, 'valid') -> ExpandDim -> convolutional
    primary capsules (32 capsules x 8 atoms, stride 2, routing 1) ->
    ``FullCapsuleLayer`` with ``n_class`` capsules of 16 atoms.

    :param input_shape: shape of one input sample; the decoder's last Dense
        layer has ``np.prod(input_shape)`` units and is reshaped back to it.
    :param n_class: number of output capsules / classes.
    :param routings: routing iterations for the ``malcaps`` layer.
    :param noactiv: when True, ``out_mal`` is the ``Length`` output itself;
        when False, a softmax ``Activation`` named ``out_mal`` is applied on
        top of a ``Length`` layer named ``mal_mag``.
    :return: ``(train_model, eval_model, manipulate_model)``:
        ``train_model`` maps ``[x, y]`` to ``[out_mal, reconstruction]`` using
        the label-masked capsules; ``eval_model`` maps ``x`` to the same outputs
        using ``Mask`` without labels; ``manipulate_model`` maps
        ``[x, y, noise]`` to ``[out_mal, reconstruction]`` of noise-perturbed
        capsules.
    """
    # Atoms per malcaps capsule; the decoder input and the noise input depend on it.
    atoms = 16

    image = layers.Input(shape=input_shape)
    features = layers.Conv2D(filters=256, kernel_size=9, strides=1,
                             padding='valid', activation='relu',
                             name='conv1')(image)
    # Treat each pixel's feature vector as a single capsule.
    features_as_caps = ExpandDim(name='expand_dim')(features)

    primary = ConvCapsuleLayer(kernel_size=9, num_capsule=32, num_atoms=8,
                               strides=2, padding='same', routings=1,
                               name='primary_caps')(features_as_caps)

    # Dynamic routing happens in this fully-connected capsule layer.
    malcaps = FullCapsuleLayer(num_capsule=n_class, num_atoms=atoms,
                               routings=routings, name='malcaps')(primary)

    if not noactiv:
        magnitudes = Length(num_classes=n_class, name='mal_mag')(malcaps)
        out_mal = layers.Activation('softmax', name='out_mal')(magnitudes)
    else:
        out_mal = Length(num_classes=n_class, name='out_mal')(malcaps)

    label = layers.Input(shape=(n_class, ))
    # Training path: keep the capsule selected by the true label.
    caps_by_label = Mask(n_class)([malcaps, label])
    # Prediction path: mask using the capsule with maximal length.
    caps_by_length = Mask(n_class)(malcaps)

    # One decoder instance, shared by all three models.
    decoder = models.Sequential([
        layers.Dense(512, activation='relu', input_dim=atoms * n_class),
        layers.Dense(1024, activation='relu'),
        layers.Dense(np.prod(input_shape), activation='sigmoid'),
        layers.Reshape(target_shape=input_shape, name='out_recon'),
    ], name='out_recon')

    train_model = models.Model([image, label],
                               [out_mal, decoder(caps_by_label)])
    eval_model = models.Model(image, [out_mal, decoder(caps_by_length)])

    noise = layers.Input(shape=(n_class, atoms))
    noisy_caps = layers.Add()([malcaps, noise])
    noisy_by_label = Mask(n_class)([noisy_caps, label])
    manipulate_model = models.Model([image, label, noise],
                                    [out_mal, decoder(noisy_by_label)])

    return train_model, eval_model, manipulate_model
# --- Example no. 6 (source listing marker: 0) ---
def CapsNetR3(input_shape, modalities=1, n_class=2):
    """
    Build a SegCaps-style (R3) capsule segmentation network and its companion models.

    Capsule counts and atom sizes are derived from ``capsules_base`` (2) and
    ``atoms_base`` (16 * ``filter_multiplier``). The encoder uses three stride-2
    capsule stages, and the up-path uses three ``DeconvCapsuleLayer`` stages
    (``scaling=2``) with skip connections. The final 1x1 ``seg_caps`` layer has
    ``n_class`` capsules.

    :param input_shape: shape of one input image, as a tuple; ``input_shape[0]``
        and ``input_shape[1]`` are used as the spatial size of the decoder input.
    :param modalities: number of channels produced by the ``out_recon`` layer.
    :param n_class: number of segmentation capsules / classes.
    :return: ``(train_model, eval_model, manipulate_model)``:
        ``train_model`` maps ``[x, y]`` to ``[out_seg, reconstruction]`` using
        the label-masked capsules; ``eval_model`` maps ``x`` to the same outputs
        using ``Mask()`` without labels; ``manipulate_model`` maps
        ``[x, y, noise]`` to the reconstruction of noise-perturbed capsules.
    """
    capsules_base = 2
    filter_multiplier = 1
    atoms_base = 16 * filter_multiplier
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=16 * filter_multiplier,
                          kernel_size=5,
                          strides=1,
                          padding='same',
                          activation='relu',
                          name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms
    _, H, W, C = conv1.get_shape()
    conv1_reshaped = layers.Reshape((H.value, W.value, 1, C.value))(conv1)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base,
                                    num_atoms=atoms_base,
                                    strides=2,
                                    padding='same',
                                    routings=1,
                                    name='primarycaps')(conv1_reshaped)

    # Layer 2: Convolutional Capsule
    conv_cap_2_1 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 2,
                                    num_atoms=atoms_base,
                                    strides=1,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_2_1')(primary_caps)

    # Layer 2: Convolutional Capsule
    conv_cap_2_2 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 2,
                                    num_atoms=atoms_base * 2,
                                    strides=2,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_2_2')(conv_cap_2_1)

    # Layer 3: Convolutional Capsule
    conv_cap_3_1 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 4,
                                    num_atoms=atoms_base * 2,
                                    strides=1,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_3_1')(conv_cap_2_2)

    # Layer 3: Convolutional Capsule
    conv_cap_3_2 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 4,
                                    num_atoms=atoms_base * 4,
                                    strides=2,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_3_2')(conv_cap_3_1)

    # Layer 4: Convolutional Capsule
    conv_cap_4_1 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 4,
                                    num_atoms=atoms_base * 2,
                                    strides=1,
                                    padding='same',
                                    routings=3,
                                    name='conv_cap_4_1')(conv_cap_3_2)

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_1 = DeconvCapsuleLayer(kernel_size=4,
                                        num_capsule=capsules_base * 4,
                                        num_atoms=atoms_base * 2,
                                        upsamp_type='deconv',
                                        scaling=2,
                                        padding='same',
                                        routings=3,
                                        name='deconv_cap_1_1')(conv_cap_4_1)

    # Skip connection (concatenate along the capsule axis)
    up_1 = layers.Concatenate(axis=-2,
                              name='up_1')([deconv_cap_1_1, conv_cap_3_1])

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_2 = ConvCapsuleLayer(kernel_size=5,
                                      num_capsule=capsules_base * 2,
                                      num_atoms=atoms_base * 2,
                                      strides=1,
                                      padding='same',
                                      routings=3,
                                      name='deconv_cap_1_2')(up_1)

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_1 = DeconvCapsuleLayer(kernel_size=4,
                                        num_capsule=capsules_base * 2,
                                        num_atoms=atoms_base,
                                        upsamp_type='deconv',
                                        scaling=2,
                                        padding='same',
                                        routings=3,
                                        name='deconv_cap_2_1')(deconv_cap_1_2)

    # Skip connection
    up_2 = layers.Concatenate(axis=-2,
                              name='up_2')([deconv_cap_2_1, conv_cap_2_1])

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_2 = ConvCapsuleLayer(kernel_size=5,
                                      num_capsule=capsules_base * 2,
                                      num_atoms=atoms_base,
                                      strides=1,
                                      padding='same',
                                      routings=3,
                                      name='deconv_cap_2_2')(up_2)

    # Layer 3 Up: Deconvolutional Capsule
    deconv_cap_3_1 = DeconvCapsuleLayer(kernel_size=4,
                                        num_capsule=capsules_base * 2,
                                        num_atoms=atoms_base,
                                        upsamp_type='deconv',
                                        scaling=2,
                                        padding='same',
                                        routings=3,
                                        name='deconv_cap_3_1')(deconv_cap_2_2)

    # Skip connection
    up_3 = layers.Concatenate(axis=-2,
                              name='up_3')([deconv_cap_3_1, conv1_reshaped])

    seg_caps = ConvCapsuleLayer(kernel_size=1,
                                num_capsule=n_class,
                                num_atoms=atoms_base,
                                strides=1,
                                padding='same',
                                routings=3,
                                name='seg_caps')(up_3)

    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps)

    # Decoder network.
    _, H, W, C, A = seg_caps.get_shape()

    y = layers.Input(shape=input_shape[:-1] + (1, ), name='recon_input')

    # The true label is used to mask the output of capsule layer. For training
    masked_by_y = Mask()([seg_caps, y])
    # Mask using the capsule with maximal length. For prediction
    masked = Mask()(seg_caps)

    def shared_decoder(mask_layer):
        """Build a 1x1-conv reconstruction head on ``mask_layer``.

        The masked capsules are flattened to ``n_class * atoms_base`` channels
        per pixel. Every call creates new layers (``recon_1``, ``recon_2``,
        ``out_recon``).
        """
        recon_remove_dim = layers.Reshape(
            (input_shape[0], input_shape[1], n_class * atoms_base))(mask_layer)

        recon_1 = layers.Conv2D(filters=64,
                                kernel_size=1,
                                padding='same',
                                kernel_initializer='he_normal',
                                activation='relu',
                                name='recon_1')(recon_remove_dim)

        recon_2 = layers.Conv2D(filters=128,
                                kernel_size=1,
                                padding='same',
                                kernel_initializer='he_normal',
                                activation='relu',
                                name='recon_2')(recon_1)

        out_recon = layers.Conv2D(filters=modalities,
                                  kernel_size=1,
                                  padding='same',
                                  kernel_initializer='he_normal',
                                  activation='sigmoid',
                                  name='out_recon')(recon_2)

        return out_recon

    # Models for training and evaluation (prediction)
    train_model = models.Model(inputs=[x, y],
                               outputs=[out_seg, shared_decoder(masked_by_y)])
    # TODO: check whether masking by y should be used for testing.
    eval_model = models.Model(inputs=x,
                              outputs=[out_seg, shared_decoder(masked)])

    # manipulate model: perturb the seg capsules with an additive noise input
    noise = layers.Input(shape=(H.value, W.value, C.value, A.value))
    noised_seg_caps = layers.Add()([seg_caps, noise])
    masked_noised_y = Mask()([noised_seg_caps, y])
    manipulate_model = models.Model(inputs=[x, y, noise],
                                    outputs=shared_decoder(masked_noised_y))

    return train_model, eval_model, manipulate_model