예제 #1
0
def LvqCapsNet(input_shape, n_class):
    """
    A LVQ-Capsule network on MNIST with Minkowski distance, neg_exp as probability transformation and pre-training.
    """
    img = layers.Input(shape=input_shape)

    # Two conventional Conv2D layers extract the primary feature maps.
    features = layers.Conv2D(filters=256, kernel_size=9, strides=1,
                             padding='valid', activation='relu')(img)
    features = layers.Conv2D(filters=8 * 32, kernel_size=9, strides=2,
                             padding='valid')(features)

    # LVQ-Capsule with one prototype per class.
    caps = Capsule(name='digitcaps', prototype_distribution=(1, n_class))
    caps.add(InputModule(signal_shape=8, dissimilarity_initializer='zeros'))
    caps.add(LinearTransformation(output_dim=16, scope='local'))
    caps.add(MinkowskiDistance(prototype_initializer='zeros',
                               squared_dissimilarity=False,
                               name='minkowski_distance'))
    caps.add(GibbsRouting(norm_axis='channels',
                          trainable=False,
                          diss_regularizer='max_distance'))
    caps.add(Classification(probability_transformation='neg_exp',
                            name='lvq_caps'))

    caps_out = caps(features)

    # Decoder inputs: mask the capsule signals either with the ground-truth
    # label (training) or with the capsule's own prediction (evaluation).
    labels = layers.Input(shape=(n_class, ))
    masked_by_label = Mask()([caps_out[0], labels])
    masked_by_pred = Mask()([caps_out[0], caps_out[2]])

    # Reconstruction decoder, shared between training and prediction.
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16 * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction).
    train_model = models.Model([img, labels],
                               [caps_out[2], decoder(masked_by_label)])
    eval_model = models.Model(img, [caps_out[2], decoder(masked_by_pred)])

    # Manipulation model: perturb the capsule signals with an additional
    # noise input before masking and decoding.
    noise = layers.Input(shape=(n_class, 16))
    noisy_caps = layers.Add()([caps_out[0], noise])
    masked_noisy = Mask()([noisy_caps, labels])
    manipulate_model = models.Model([img, labels, noise],
                                    decoder(masked_noisy))
    return train_model, eval_model, manipulate_model, decoder
예제 #2
0
def LvqCapsNet(input_shape):
    """Build an LVQ-Capsule network for 10-class recognition with a decoder.

    Returns the tuple ``(model, decoder, model_vis_caps2)``:
    ``model`` maps ``[image, one_hot_label]`` to ``[class probabilities,
    reconstruction]``, ``decoder`` is the shared reconstruction network
    (output reshaped to (28, 28, 1)), and ``model_vis_caps2`` feeds
    Caps2-shaped inputs directly into the digit capsule for visualizations.
    """
    input_img = Input(shape=input_shape)

    # Block 1
    # NOTE(review): the "+ 1" extra channel is presumably the dissimilarity
    # channel split off later by caps6's SplitModule -- confirm.
    caps0 = Capsule()
    caps0.add(Conv2D(32 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4: the only strided block (5x5, stride 2) -- halves the spatial
    # resolution; also the only feature capsule with an explicit
    # prototype_distribution.
    caps3 = Capsule(prototype_distribution=32)
    caps3.add(Conv2D(64 + 1, (5, 5), strides=2, padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5 (no batch norm / activation, unlike blocks 1-4)
    caps4 = Capsule()
    caps4.add(Conv2D(32 + 1, (3, 3), padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6 (no batch norm / activation)
    caps5 = Capsule()
    caps5.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7: SplitModule separates the 64 signal channels from the extra
    # channel; ReLU + Flatten are applied only to scope key 1 (presumably
    # the split-off dissimilarity part -- confirm against SplitModule docs).
    caps6 = Capsule()
    caps6.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps6.add(SplitModule())
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1: 8*8 prototypes, unsquared tangent distance with 16 projected
    # atoms; Gibbs routing with fixed (non-trainable) parameters.
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(InputModule(signal_shape=(-1, 64), init_diss_initializer=None, trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False, epsilon=1.e-12, linear_factor=0.66, projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2: reshape the 8*8 capsule grid back to an image, refine with two
    # convolutions, then route down to 4*4 prototypes.
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(InputModule(signal_shape=(8 * 8, 64), init_diss_initializer=None, trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16, squared_dissimilarity=False,
                            epsilon=1.e-12, linear_factor=0.66, signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3: one prototype per class (10), restricted tangent distance, and
    # a neg_softmax transformation of the dissimilarities into probabilities.
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128, init_diss_initializer=None, trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16, epsilon=1.e-12, squared_dissimilarity=False,
                                     linear_factor=0.66, signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels', trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(DissimilarityTransformation(probability_transformation='neg_softmax', name='lvq_caps'))

    digitcaps = digit_caps(x)

    # Auxiliary model: feeds Caps2-shaped signal/dissimilarity inputs
    # ((4, 4, 64) and (16,)) directly into the digit capsule; used for
    # visualizations of the final capsule layer.
    input_diss8 = [Input((4, 4, 64)), Input((16,))]
    model_vis_caps2 = models.Model(input_diss8, digit_caps(list_to_dict(input_diss8)))

    # Decoder network: reconstruct the image from the masked capsule signals
    # (one 128-dim signal per class -> input_dim = 128 * 10).
    y = layers.Input(shape=(10,))
    masked_by_y = Mask()([digitcaps[0], y])

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=128 * 10))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod((28, 28, 1)), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=(28, 28, 1), name='out_recon'))

    # Models for training and evaluation (prediction)
    model = models.Model([input_img, y], [digitcaps[2], decoder(masked_by_y)])

    return model, decoder,  model_vis_caps2
예제 #3
0
def LvqCapsNet(input_shape):
    """Build an LVQ-Capsule classifier for 5 classes.

    Seven convolutional feature Capsules are followed by three LVQ capsule
    layers (tangent / restricted tangent distances with Gibbs routing).
    Returns a Keras model mapping an image of ``input_shape`` to the class
    probabilities produced by the final ``neg_softmax`` transformation.
    """
    input_img = layers.Input(shape=input_shape)

    # Block 1
    # NOTE(review): the "+ 2" extra channels are presumably split off later
    # by caps6's SplitModule(index=-2) -- confirm.
    caps0 = Capsule()
    caps0.add(
        Conv2D(32 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal(),
               # Fixed: use the function argument instead of the global
               # `x_train`, which made the model definition depend on an
               # externally defined training array.
               input_shape=input_shape))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4: the only strided block (5x5, stride 2) -- halves the
    # spatial resolution.
    caps3 = Capsule()
    caps3.add(
        Conv2D(64 + 2, (5, 5),
               strides=2,
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5 (no batch norm / activation, unlike blocks 1-4)
    caps4 = Capsule()
    caps4.add(
        Conv2D(32 + 2, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6 (no batch norm / activation)
    caps5 = Capsule()
    caps5.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7: split at index -2; ReLU + Flatten are applied only to scope
    # key 1 (presumably the split-off dissimilarity part -- confirm).
    caps6 = Capsule()
    caps6.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps6.add(SplitModule(index=-2))
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1: 8*8 prototypes, globally shared tangent-distance matrix,
    # sparse signal representation.
    caps7 = Capsule(prototype_distribution=(1, 8 * 8), sparse_signal=True)
    caps7.add(
        InputModule(signal_shape=(-1, 32),
                    init_diss_initializer=None,
                    trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16,
                            matrix_scope='global')
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2: reshape the 8*8 capsule grid back to an image, refine with two
    # convolutions, then route down to 4*4 prototypes.
    caps8 = Capsule(prototype_distribution=(1, 4 * 4), sparse_signal=True)
    caps8.add(Reshape((8, 8, 32)))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        InputModule(signal_shape=(8 * 8, 64),
                    init_diss_initializer=None,
                    trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals',
                            matrix_scope='global')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3: one prototype per class (5), restricted tangent distance and a
    # neg_softmax transformation of the dissimilarities into probabilities.
    digit_caps = Capsule(prototype_distribution=(1, 5))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(
        Conv2D(128, (3, 3), padding='same',
               kernel_initializer=glorot_normal()))
    digit_caps.add(
        InputModule(signal_shape=128,
                    init_diss_initializer=None,
                    trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(
        GibbsRouting(norm_axis='channels',
                     trainable=False,
                     diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(
        DissimilarityTransformation(probability_transformation='neg_softmax',
                                    name='lvq_caps'))

    digitcaps = digit_caps(x)

    model = models.Model(input_img, digitcaps[2])

    return model
예제 #4
0
def LvqCapsNet(input_shape):
    """Construct the LVQ-CapsNet classifier for 10-class recognition.

    Six convolutional feature capsules are followed by a split into signal
    and dissimilarity channels and three LVQ capsule layers with tangent
    distances and Gibbs routing.  Returns a Keras model that maps an image
    of ``input_shape`` to the class probability output of the final capsule.
    """
    def _feature_caps(tensor, filters, initializer, kernel=(3, 3), strides=1,
                      with_bn=True, **capsule_kwargs):
        # Conv2D (+ optional BatchNorm/ReLU) + Dropout packed into a Capsule.
        block = Capsule(**capsule_kwargs)
        block.add(Conv2D(filters, kernel,
                         strides=strides,
                         padding='same',
                         kernel_initializer=initializer))
        if with_bn:
            block.add(BatchNormalization())
            block.add(Activation('relu'))
        block.add(Dropout(0.25))
        return block(tensor)

    input_img = layers.Input(shape=input_shape)

    # Blocks 1-6: convolutional feature extraction; the "+ 1" reserves one
    # extra channel that is split off below as the dissimilarity channel.
    x = _feature_caps(input_img, 32 + 1, glorot_normal())
    x = _feature_caps(x, 64 + 1, glorot_normal())
    x = _feature_caps(x, 64 + 1, RandomNormal(stddev=0.01))
    x = _feature_caps(x, 64 + 1, RandomNormal(stddev=0.01),
                      kernel=(5, 5), strides=2, prototype_distribution=32)
    x = _feature_caps(x, 32 + 1, RandomNormal(stddev=0.01), with_bn=False)
    x = _feature_caps(x, 64 + 1, glorot_normal(), with_bn=False)

    # Block 7: final plain convolution, then crop the 64 signal channels
    # apart from the extra channel, which becomes the flattened
    # dissimilarity input.
    x = Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal())(x)
    signals = crop(3, 0, 64)(x)
    diss_input = Flatten()(Activation('relu')(crop(3, 64, 65)(x)))

    # Caps1: 8*8 prototypes, unsquared tangent distance, fixed Gibbs routing.
    caps_a = Capsule(prototype_distribution=(1, 8 * 8))
    caps_a.add(InputModule(signal_shape=(16 * 16, 64),
                           dissimilarity_initializer=None,
                           trainable=False))
    caps_a.add(TangentDistance(squared_dissimilarity=False,
                               epsilon=1.e-12,
                               linear_factor=0.66,
                               projected_atom_shape=16))
    caps_a.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps_a(list_to_dict([signals, diss_input]))

    # Caps2: reshape the 8*8 capsule grid back to an image, refine with two
    # convolutions, then route down to 4*4 prototypes.
    caps_b = Capsule(prototype_distribution=(1, 4 * 4))
    caps_b.add(Reshape((8, 8, 64)))
    caps_b.add(Conv2D(64, (3, 3), padding='same',
                      kernel_initializer=glorot_normal()))
    caps_b.add(Conv2D(64, (3, 3), padding='same',
                      kernel_initializer=glorot_normal()))
    caps_b.add(InputModule(signal_shape=(8 * 8, 64),
                           dissimilarity_initializer=None,
                           trainable=False))
    caps_b.add(TangentDistance(projected_atom_shape=16,
                               squared_dissimilarity=False,
                               epsilon=1.e-12,
                               linear_factor=0.66,
                               signal_output='signals'))
    caps_b.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps_b(x)

    # Caps3: one prototype per class (10), restricted tangent distance, and
    # a neg_softmax transformation of the dissimilarities into probabilities.
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128,
                               dissimilarity_initializer=None,
                               trainable=False))
    digit_caps.add(
        RestrictedTangentDistance(projected_atom_shape=16,
                                  epsilon=1.e-12,
                                  squared_dissimilarity=False,
                                  linear_factor=0.66,
                                  signal_output='signals'))
    digit_caps.add(GibbsRouting(norm_axis='channels',
                                trainable=False,
                                diss_regularizer=MaxDistance(alpha=0.0001)))
    digit_caps.add(Classification(probability_transformation='neg_softmax',
                                  name='lvq_caps'))

    caps_out = digit_caps(x)

    return models.Model(input_img, caps_out[2])