def LvqCapsNet(input_shape, n_class):
    """
    A LVQ-Capsule network on MNIST with Minkowski distance, neg_exp as probability transformation and pre-training.
    """
    inputs = layers.Input(shape=input_shape)

    # Two plain convolutions extract the primary feature stack.
    features = layers.Conv2D(filters=256, kernel_size=9, strides=1,
                             padding='valid', activation='relu')(inputs)
    features = layers.Conv2D(filters=8 * 32, kernel_size=9, strides=2,
                             padding='valid')(features)

    # LVQ capsule pipeline: input module -> local linear map ->
    # Minkowski distance -> Gibbs routing -> neg_exp classification.
    lvq_capsule = Capsule(name='digitcaps', prototype_distribution=(1, n_class))
    lvq_capsule.add(InputModule(signal_shape=8,
                                dissimilarity_initializer='zeros'))
    lvq_capsule.add(LinearTransformation(output_dim=16, scope='local'))
    distance = MinkowskiDistance(prototype_initializer='zeros',
                                 squared_dissimilarity=False,
                                 name='minkowski_distance')
    lvq_capsule.add(distance)
    lvq_capsule.add(GibbsRouting(norm_axis='channels',
                                 trainable=False,
                                 diss_regularizer='max_distance'))
    lvq_capsule.add(Classification(probability_transformation='neg_exp',
                                   name='lvq_caps'))

    caps_out = lvq_capsule(features)

    # Decoder inputs: mask by the true label (training) or by the
    # predicted capsule (evaluation).
    y = layers.Input(shape=(n_class, ))
    masked_by_y = Mask()([caps_out[0], y])
    masked = Mask()([caps_out[0], caps_out[2]])

    # Shared decoder model in training and prediction.
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16 * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction).
    train_model = models.Model([inputs, y],
                               [caps_out[2], decoder(masked_by_y)])
    eval_model = models.Model(inputs, [caps_out[2], decoder(masked)])

    # Manipulation model: additive noise on the capsule poses.
    noise = layers.Input(shape=(n_class, 16))
    noised_caps = layers.Add()([caps_out[0], noise])
    masked_noised_y = Mask()([noised_caps, y])
    manipulate_model = models.Model([inputs, y, noise],
                                    decoder(masked_noised_y))
    return train_model, eval_model, manipulate_model, decoder
Example #2
0
def CapsNet(input_shape, n_class, routings):
    """ Initialize the CapsNet"""
    inputs = layers.Input(shape=input_shape)

    # Layer 1: a conventional Conv2D feature extractor.
    conv1_out = layers.Conv2D(filters=256, kernel_size=9, strides=1,
                              padding='valid', activation='relu',
                              name='conv1')(inputs)

    # Layer 2: Conv2D with `squash` activation, reshaped to
    # [None, num_capsule, dim_capsule].
    primary_caps = Capsule(name='PrimaryCaps')
    primary_caps.add(layers.Conv2D(filters=8 * 32, kernel_size=9, strides=2,
                                   padding='valid', name='primarycap_conv2d'))
    primary_caps.add(layers.Reshape(target_shape=[-1, 8],
                                    name='primarycap_reshape'))
    primary_caps.add(layers.Lambda(squash, name='primarycap_squash'))
    primary_out = primary_caps(conv1_out)

    # Layer 3: capsule layer — the routing algorithm works here.
    digit_caps = Capsule(name='digitcaps', prototype_distribution=(1, n_class))
    digit_caps.add(InputModule(signal_shape=None,
                               init_diss_initializer='zeros',
                               trainable=False))
    digit_caps.add(LinearTransformation(output_dim=16, scope='local'))
    digit_caps.add(DynamicRouting(iterations=routings, name='capsnet'))
    caps_out = digit_caps(primary_out)

    # Decoder inputs. Training masks with the true label; prediction masks
    # with the capsule of maximal length.
    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([caps_out[0], y])
    masked = Mask()([caps_out[0], caps_out[2]])

    # Shared decoder model in training and prediction.
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16 * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction).
    train_model = models.Model([inputs, y],
                               [caps_out[2], decoder(masked_by_y)])
    eval_model = models.Model(inputs, [caps_out[2], decoder(masked)])

    # Manipulation model: additive noise on the capsule poses.
    noise = layers.Input(shape=(n_class, 16))
    noised_caps = layers.Add()([caps_out[0], noise])
    masked_noised_y = Mask()([noised_caps, y])
    manipulate_model = models.Model([inputs, y, noise],
                                    decoder(masked_noised_y))
    return train_model, eval_model, manipulate_model
# Script fragment (one scraped example): build a matrix-LVQ capsule
# classifier with a global Omega distance and pre-train its prototypes.
# NOTE(review): `x_train`, `y_train`, `num_classes`, `batch_size`, `lr`,
# `glvq_loss` and the imports are defined outside this chunk — confirm
# against the full script.
x_train = x_train.astype('float32')
y_train = to_categorical(y_train, num_classes).astype('float32')

inputs = Input((x_train.shape[1], ))

# define MLVQ prototypes
# Global Omega matrix, normalized via the 'OmegaNormalization' constraint;
# the distance also forwards its input signals ('signals') downstream.
diss = OmegaDistance(linear_factor=None,
                     squared_dissimilarity=True,
                     matrix_scope='global',
                     matrix_constraint='OmegaNormalization',
                     signal_output='signals')

# One prototype per class.
caps = Capsule(prototype_distribution=(1, num_classes))
caps.add(
    InputModule(signal_shape=x_train.shape[1],
                trainable=False,
                dissimilarity_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(Classification(probability_transformation='flip', name='lvq_capsule'))

outputs = caps(inputs)

# pre-train the model
# A helper model exposes the distance layer's input so the prototypes can
# be initialized directly from the data before full training.
pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
diss_input = pre_train_model.predict(x_train, batch_size=batch_size)
diss.pre_training(diss_input, y_train, capsule_inputs_are_equal=True)

# outputs[2] is the classification head output of the capsule.
model = Model(inputs, outputs[2])
model.compile(loss=glvq_loss,
              optimizer=optimizers.Adam(lr=lr),
Example #4
0
def LvqCapsNet(input_shape):
    """Build an LVQ-CapsNet for 10-class image data.

    Seven convolutional blocks (each with one extra channel that Block 7's
    SplitModule separates from the feature channels) feed three capsule
    layers using tangent / restricted-tangent distances with Gibbs routing.

    Returns:
        (model, decoder, model_vis_caps2): the training model (image + label
        inputs, classification + reconstruction outputs), the shared decoder,
        and an intermediate model used for visualizations.
    """
    input_img = Input(shape=input_shape)

    # Block 1
    # 32 + 1 filters: the extra channel is presumably the dissimilarity
    # channel split off in Block 7 — confirm with SplitModule semantics.
    caps0 = Capsule()
    caps0.add(Conv2D(32 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4
    # Strided convolution halves the spatial resolution.
    caps3 = Capsule(prototype_distribution=32)
    caps3.add(Conv2D(64 + 1, (5, 5), strides=2, padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5
    caps4 = Capsule()
    caps4.add(Conv2D(32 + 1, (3, 3), padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6
    caps5 = Capsule()
    caps5.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7
    # SplitModule separates the output; scope_keys=1 applies relu/flatten
    # only to the split-off part (presumably the dissimilarity channel).
    caps6 = Capsule()
    caps6.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps6.add(SplitModule())
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1
    # 8x8 capsules, tangent distance on 64-dim signals.
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(InputModule(signal_shape=(-1, 64), init_diss_initializer=None, trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False, epsilon=1.e-12, linear_factor=0.66, projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2
    # Reshape back to an 8x8 grid, refine with two convolutions, then route
    # into 4x4 capsules; 'signals' forwards the distance's input signals.
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(InputModule(signal_shape=(8 * 8, 64), init_diss_initializer=None, trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16, squared_dissimilarity=False,
                            epsilon=1.e-12, linear_factor=0.66, signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3
    # Final 10-class capsule with a restricted tangent distance and a
    # max-value regularizer on the dissimilarities.
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128, init_diss_initializer=None, trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16, epsilon=1.e-12, squared_dissimilarity=False,
                                     linear_factor=0.66, signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels', trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(DissimilarityTransformation(probability_transformation='neg_softmax', name='lvq_caps'))

    digitcaps = digit_caps(x)

    # intermediate model for Caps2; used for visualizations
    # NOTE(review): despite the name, this wires a (4, 4, 64) + (16,) input
    # pair into digit_caps (Caps3), i.e. it starts FROM Caps2's output —
    # confirm that feeding digit_caps (not caps8) is intended.
    input_diss8 = [Input((4, 4, 64)), Input((16,))]
    model_vis_caps2 = models.Model(input_diss8, digit_caps(list_to_dict(input_diss8)))

    # Decoder network.
    # Masks the capsule output with the true label (training-time decoder).
    y = layers.Input(shape=(10,))
    masked_by_y = Mask()([digitcaps[0], y])

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=128 * 10))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod((28, 28, 1)), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=(28, 28, 1), name='out_recon'))

    # Models for training and evaluation (prediction)
    model = models.Model([input_img, y], [digitcaps[2], decoder(masked_by_y)])

    return model, decoder,  model_vis_caps2
Example #5
0
# Script fragment (one scraped example): build a GLVQ-style capsule
# classifier with Minkowski-distance prototypes and pre-train the
# prototypes on a random subset of the training data.
# NOTE(review): `y_test`, `save_dir`, `x_train`, `y_train`,
# `protos_per_class`, `num_classes`, `batch_size` and the imports are
# defined outside this chunk — confirm against the full script.
y_test = to_categorical(y_test.astype('float32'))

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

inputs = Input(x_train.shape[1:])

# define MLVQ prototypes
# The distance also forwards its input signals ('signals') downstream.
diss = MinkowskiDistance(linear_factor=None,
                         squared_dissimilarity=True,
                         signal_output='signals')

# `protos_per_class` prototypes for each of the `num_classes` classes.
caps = Capsule(prototype_distribution=(protos_per_class, num_classes))
caps.add(
    InputModule(signal_shape=(1, ) + x_train.shape[1:],
                trainable=False,
                init_diss_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(NearestCompetition(use_for_loop=False))
caps.add(
    DissimilarityTransformation(probability_transformation='flip',
                                name='lvq_capsule'))

outputs = caps(inputs)

# pre-train the model over 10000 random digits
idx = np.random.randint(0, len(x_train) - 1, (min(10000, len(x_train)), ))
# A helper model exposes the distance layer's input so the prototypes can
# be initialized directly from (a sample of) the data.
pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
diss_input = pre_train_model.predict(x_train[idx, :], batch_size=batch_size)
diss.pre_training(diss_input, y_train[idx], capsule_inputs_are_equal=True)
Example #6
0
def LvqCapsNet(input_shape):
    """Build an LVQ-CapsNet classifier for 5 classes.

    Seven convolutional blocks (each carrying 2 extra channels that
    Block 7's SplitModule separates from the feature channels) feed three
    capsule layers using tangent / restricted-tangent distances with
    Gibbs routing and a neg_softmax classification head.

    Args:
        input_shape: shape of one input sample, e.g. ``(H, W, C)``.

    Returns:
        A Model mapping the image input to the capsule classification
        output (``digitcaps[2]``).
    """
    input_img = layers.Input(shape=input_shape)

    # Block 1
    caps0 = Capsule()
    caps0.add(
        Conv2D(32 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal(),
               # Fix: use the function's `input_shape` argument instead of
               # the module-level global `x_train`, so the builder does not
               # depend on hidden module state.
               input_shape=input_shape))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4 — strided convolution halves the spatial resolution.
    caps3 = Capsule()
    caps3.add(
        Conv2D(64 + 2, (5, 5),
               strides=2,
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5
    caps4 = Capsule()
    caps4.add(
        Conv2D(32 + 2, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6
    caps5 = Capsule()
    caps5.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7 — SplitModule(index=-2) separates the output; scope_keys=1
    # applies relu/flatten only to the split-off part.
    caps6 = Capsule()
    caps6.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps6.add(SplitModule(index=-2))
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1 — 8x8 capsules, global tangent distance on 32-dim signals.
    caps7 = Capsule(prototype_distribution=(1, 8 * 8), sparse_signal=True)
    caps7.add(
        InputModule(signal_shape=(-1, 32),
                    init_diss_initializer=None,
                    trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16,
                            matrix_scope='global')
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2 — reshape to an 8x8 grid, refine with two convolutions,
    # route into 4x4 capsules.
    caps8 = Capsule(prototype_distribution=(1, 4 * 4), sparse_signal=True)
    caps8.add(Reshape((8, 8, 32)))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        InputModule(signal_shape=(8 * 8, 64),
                    init_diss_initializer=None,
                    trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals',
                            matrix_scope='global')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3 — final 5-class capsule with a restricted tangent distance,
    # a max-value dissimilarity regularizer and neg_softmax probabilities.
    digit_caps = Capsule(prototype_distribution=(1, 5))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(
        Conv2D(128, (3, 3), padding='same',
               kernel_initializer=glorot_normal()))
    digit_caps.add(
        InputModule(signal_shape=128,
                    init_diss_initializer=None,
                    trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(
        GibbsRouting(norm_axis='channels',
                     trainable=False,
                     diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(
        DissimilarityTransformation(probability_transformation='neg_softmax',
                                    name='lvq_caps'))

    digitcaps = digit_caps(x)

    # digitcaps[2] is the classification output of the capsule.
    model = models.Model(input_img, digitcaps[2])

    return model
Example #7
0
def LvqCapsNet(input_shape):
    """Build an LVQ-CapsNet classifier for 10 classes.

    Seven convolutional stages (each with one extra channel, split off
    manually after Block 7) feed three capsule layers using tangent /
    restricted-tangent distances with Gibbs routing and a neg_softmax
    classification head.

    Args:
        input_shape: shape of one input sample, e.g. ``(H, W, C)``.

    Returns:
        A Model mapping the image input to the capsule classification
        output (``digitcaps[2]``).
    """
    input_img = layers.Input(shape=input_shape)

    # Block 1
    # 32 + 1 filters: the extra channel is split off after Block 7 —
    # presumably the dissimilarity channel; confirm with `crop` semantics.
    caps1 = Capsule()
    caps1.add(
        Conv2D(32 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(input_img)

    # Block 2
    caps2 = Capsule()
    caps2.add(
        Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 3
    caps3 = Capsule()
    caps3.add(
        Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 4 — strided convolution halves the spatial resolution.
    caps4 = Capsule(prototype_distribution=32)
    caps4.add(
        Conv2D(64 + 1, (5, 5),
               strides=2,
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(BatchNormalization())
    caps4.add(Activation('relu'))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 5
    caps5 = Capsule()
    caps5.add(
        Conv2D(32 + 1, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 6
    caps6 = Capsule()
    caps6.add(
        Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps6.add(Dropout(0.25))
    x = caps6(x)

    # Block 7 — plain Conv2D, then a manual split instead of SplitModule:
    # crop(3, 0, 64) / crop(3, 64, 65) presumably slice axis 3 into the
    # first 64 feature channels and the final extra channel — TODO confirm
    # against the `crop` helper's definition.
    x = Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal())(x)
    x = [crop(3, 0, 64)(x), crop(3, 64, 65)(x)]
    x[1] = Activation('relu')(x[1])
    x[1] = Flatten()(x[1])

    # Caps1 — 8x8 capsules, tangent distance on 16x16 grids of 64-dim signals.
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(
        InputModule(signal_shape=(16 * 16, 64),
                    dissimilarity_initializer=None,
                    trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    # `list_to_dict` adapts the [signal, dissimilarity] list to the
    # capsule's expected input format (helper defined elsewhere).
    x = caps7(list_to_dict(x))

    # Caps2 — reshape to an 8x8 grid, refine with two convolutions,
    # route into 4x4 capsules.
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        InputModule(signal_shape=(8 * 8, 64),
                    dissimilarity_initializer=None,
                    trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3 — final 10-class capsule with a restricted tangent distance,
    # a max-distance dissimilarity regularizer and neg_softmax output.
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(
        Conv2D(128, (3, 3), padding='same',
               kernel_initializer=glorot_normal()))
    digit_caps.add(
        InputModule(signal_shape=128,
                    dissimilarity_initializer=None,
                    trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(
        GibbsRouting(norm_axis='channels',
                     trainable=False,
                     diss_regularizer=MaxDistance(alpha=0.0001)))
    digit_caps.add(
        Classification(probability_transformation='neg_softmax',
                       name='lvq_caps'))

    digitcaps = digit_caps(x)

    # digitcaps[2] is the classification output of the capsule.
    model = models.Model(input_img, digitcaps[2])

    return model
# Script fragment (one scraped example): train a matrix-LVQ capsule
# classifier on the Tecator dataset with a global Omega distance and
# prototype pre-training.
# NOTE(review): `num_classes`, `batch_size`, `lr`, `glvq_loss` and the
# imports are defined outside this chunk — confirm against the full script.
x_train, y_train = tecator.load_data()

x_train = x_train.astype('float32')
y_train = to_categorical(y_train, num_classes).astype('float32')

inputs = Input((x_train.shape[1], ))

# define MLVQ prototypes
# Global Omega matrix, normalized via the 'OmegaNormalization' constraint;
# the distance also forwards its input signals ('signals') downstream.
diss = OmegaDistance(linear_factor=None,
                     squared_dissimilarity=True,
                     matrix_scope='global',
                     matrix_constraint='OmegaNormalization',
                     signal_output='signals')

# One prototype per class.
caps = Capsule(prototype_distribution=(1, num_classes))
caps.add(InputModule(signal_shape=x_train.shape[1], trainable=False, init_diss_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(DissimilarityTransformation(probability_transformation='flip', name='lvq_capsule'))

outputs = caps(inputs)

# pre-train the model
# A helper model exposes the distance layer's input so the prototypes can
# be initialized directly from the data before full training.
pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
diss_input = pre_train_model.predict(x_train, batch_size=batch_size)
diss.pre_training(diss_input, y_train, capsule_inputs_are_equal=True)

# outputs[2] is the classification head output; train with the GLVQ loss.
model = Model(inputs, outputs[2])
model.compile(loss=glvq_loss, optimizer=optimizers.Adam(lr=lr), metrics={'lvq_capsule': 'accuracy'})

model.summary()