# Example #1
def get_model():
    """Build an LVQ capsule model using the dissimilarity chosen by ``args.mode``.

    Supported modes: ``'glvq'`` (Minkowski distance), ``'gtlvq'`` (restricted
    tangent distance), ``'gmlvq'`` (omega/matrix distance).  The dissimilarity's
    prototypes are pre-trained on ``x_train``/``y_train`` before the final
    Keras ``Model`` is returned.

    NOTE(review): depends on module-level names defined elsewhere in this file
    (``args``, ``input_shape``, ``protos_per_class``, ``x_train``, ``y_train``,
    ``batch_size`` and the anysma/Keras imports).
    """
    inputs = Input(shape=input_shape)

    # get the dissimilarity either GLVQ, GTLVQ, or GMLVQ
    if args.mode == 'glvq':
        diss = MinkowskiDistance(linear_factor=None,
                                 squared_dissimilarity=True,
                                 signal_output='signals')
    elif args.mode == 'gtlvq':
        diss = RestrictedTangentDistance(linear_factor=None,
                                         squared_dissimilarity=True,
                                         signal_output='signals',
                                         projected_atom_shape=1)
    elif args.mode == 'gmlvq':
        # get identity matrices for the matrix initialization as we do not
        # use the standard routine of anysma for the initialization
        matrix_init = np.repeat(np.expand_dims(np.eye(2), 0),
                                repeats=np.sum(protos_per_class),
                                axis=0)
        # BUG FIX: ndarray.astype() returns a NEW array; the original code
        # discarded the result, so the initializer kept numpy's default dtype
        # instead of the backend float type.  Assign the cast back.
        matrix_init = matrix_init.astype(K.floatx())

        diss = OmegaDistance(linear_factor=None,
                             squared_dissimilarity=True,
                             signal_output='signals',
                             matrix_scope='local',
                             matrix_constraint='OmegaNormalization',
                             matrix_initializer=lambda x: matrix_init)
    else:
        # fail early with a clear message instead of a NameError on `diss`
        raise ValueError("unsupported mode: {!r}".format(args.mode))

    # define capsule network
    caps = Capsule(prototype_distribution=protos_per_class)
    caps.add(
        InputModule(signal_shape=(-1, np.prod(input_shape)),
                    trainable=False,
                    init_diss_initializer='zeros'))
    caps.add(diss)
    caps.add(SqueezeRouting())
    caps.add(NearestCompetition())

    # the capsule returns multiple tensors; element 1 is used as model output
    output = caps(inputs)[1]

    # remember the freshly initialized (identity) matrices for GMLVQ because
    # pre_training below overwrites all weights of the dissimilarity
    if args.mode == 'gmlvq':
        _, matrices = diss.get_weights()

    # pre-train the prototypes on the inputs that actually reach `diss`
    pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
    diss_input = pre_train_model.predict(x_train, batch_size=batch_size)
    diss.pre_training(diss_input, y_train, capsule_inputs_are_equal=True)

    if args.mode == 'gmlvq':
        # set identity matrices back; keep only the pre-trained centers
        centers, _ = diss.get_weights()
        diss.set_weights([centers, matrices])

    # define model and return
    model = Model(inputs, output)

    return model
# Example #2
def LvqCapsNet(input_shape):
    """Build the 10-class LVQ capsule network plus decoder and a helper model.

    Returns a tuple ``(model, decoder, model_vis_caps2)``:

    * ``model`` -- training/evaluation model taking ``[image, one_hot_label]``
      and producing ``[class_probabilities, reconstruction]``.
    * ``decoder`` -- the shared dense reconstruction network (28x28x1 output).
    * ``model_vis_caps2`` -- auxiliary model applying the final capsule to
      stand-alone inputs; used for visualizations.

    NOTE(review): relies on names imported elsewhere in this file (Capsule,
    Conv2D, TangentDistance, GibbsRouting, Mask, list_to_dict, ...).
    """
    input_img = Input(shape=input_shape)

    # Block 1
    # Conv -> BN -> ReLU -> Dropout wrapped in a Capsule container.
    # NOTE(review): the "+ 1" extra filter presumably carries the auxiliary
    # channel split off in Block 7 -- confirm against the library's docs.
    caps0 = Capsule()
    caps0.add(Conv2D(32 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4 (stride 2 halves the spatial resolution)
    caps3 = Capsule(prototype_distribution=32)
    caps3.add(Conv2D(64 + 1, (5, 5), strides=2, padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5 (no batch norm / activation from here on)
    caps4 = Capsule()
    caps4.add(Conv2D(32 + 1, (3, 3), padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6
    caps5 = Capsule()
    caps5.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7
    # SplitModule divides the signal; scope_keys=1 applies the following
    # modules (ReLU, Flatten) only to the second part of the split.
    caps6 = Capsule()
    caps6.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps6.add(SplitModule())
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1: 8*8 = 64 prototypes measured by (non-squared) tangent distance
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(InputModule(signal_shape=(-1, 64), init_diss_initializer=None, trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False, epsilon=1.e-12, linear_factor=0.66, projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2: reshape back to a feature map, refine with convs, then 16 prototypes
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(InputModule(signal_shape=(8 * 8, 64), init_diss_initializer=None, trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16, squared_dissimilarity=False,
                            epsilon=1.e-12, linear_factor=0.66, signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3: one restricted-tangent prototype per digit class
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128, init_diss_initializer=None, trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16, epsilon=1.e-12, squared_dissimilarity=False,
                                     linear_factor=0.66, signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels', trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(DissimilarityTransformation(probability_transformation='neg_softmax', name='lvq_caps'))

    digitcaps = digit_caps(x)

    # intermediate model for Caps2; used for visualizations
    # NOTE(review): despite the comment, this wraps digit_caps (not caps8) --
    # it re-applies the final capsule to stand-alone inputs; verify intent.
    input_diss8 = [Input((4, 4, 64)), Input((16,))]
    model_vis_caps2 = models.Model(input_diss8, digit_caps(list_to_dict(input_diss8)))

    # Decoder network.
    # The true label masks the capsule signals before reconstruction.
    y = layers.Input(shape=(10,))
    masked_by_y = Mask()([digitcaps[0], y])

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=128 * 10))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod((28, 28, 1)), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=(28, 28, 1), name='out_recon'))

    # Models for training and evaluation (prediction)
    # digitcaps[2] carries the class probabilities (output of 'lvq_caps').
    model = models.Model([input_img, y], [digitcaps[2], decoder(masked_by_y)])

    return model, decoder,  model_vis_caps2
# Example #3
def LvqCapsNet(input_shape):
    """Build the 5-class LVQ capsule classifier for inputs of ``input_shape``.

    Seven convolutional capsule blocks feed two tangent-distance capsule
    layers and a final restricted-tangent-distance capsule; the model output
    is the capsule's 'neg_softmax' probability tensor (output element 2).

    NOTE(review): relies on names imported elsewhere in this file (Capsule,
    Conv2D, TangentDistance, GibbsRouting, x_train, ...).
    """
    image = layers.Input(shape=input_shape)

    def _wrap(conv_layer, with_bn=True):
        # Capsule container: Conv -> (BN -> ReLU) -> Dropout.
        block = Capsule()
        block.add(conv_layer)
        if with_bn:
            block.add(BatchNormalization())
            block.add(Activation('relu'))
        block.add(Dropout(0.25))
        return block

    # Blocks 1-4: conv stacks with batch norm; block 4 downsamples (stride 2).
    feat = _wrap(Conv2D(32 + 2, (3, 3),
                        padding='same',
                        kernel_initializer=glorot_normal(),
                        input_shape=x_train.shape[1:]))(image)
    feat = _wrap(Conv2D(64 + 2, (3, 3),
                        padding='same',
                        kernel_initializer=glorot_normal()))(feat)
    feat = _wrap(Conv2D(64 + 2, (3, 3),
                        padding='same',
                        kernel_initializer=RandomNormal(stddev=0.01)))(feat)
    feat = _wrap(Conv2D(64 + 2, (5, 5),
                        strides=2,
                        padding='same',
                        kernel_initializer=RandomNormal(stddev=0.01)))(feat)

    # Blocks 5-6: conv + dropout only (no batch norm / activation).
    feat = _wrap(Conv2D(32 + 2, (3, 3),
                        padding='same',
                        kernel_initializer=RandomNormal(stddev=0.01)),
                 with_bn=False)(feat)
    feat = _wrap(Conv2D(64 + 2, (3, 3),
                        padding='same',
                        kernel_initializer=glorot_normal()),
                 with_bn=False)(feat)

    # Block 7: split the signal; ReLU + Flatten apply only to scope key 1.
    splitter = Capsule()
    splitter.add(Conv2D(64 + 2, (3, 3),
                        padding='same',
                        kernel_initializer=glorot_normal()))
    splitter.add(SplitModule(index=-2))
    splitter.add(Activation('relu'), scope_keys=1)
    splitter.add(Flatten(), scope_keys=1)
    feat = splitter(feat)

    # Caps1: 8*8 prototypes, global tangent distance, Gibbs routing.
    lvq_caps1 = Capsule(prototype_distribution=(1, 8 * 8), sparse_signal=True)
    lvq_caps1.add(InputModule(signal_shape=(-1, 32),
                              init_diss_initializer=None,
                              trainable=False))
    lvq_caps1.add(TangentDistance(squared_dissimilarity=False,
                                  epsilon=1.e-12,
                                  linear_factor=0.66,
                                  projected_atom_shape=16,
                                  matrix_scope='global'))
    lvq_caps1.add(GibbsRouting(norm_axis='channels', trainable=False))
    feat = lvq_caps1(feat)

    # Caps2: back to a feature map, two refinement convs, 4*4 prototypes.
    lvq_caps2 = Capsule(prototype_distribution=(1, 4 * 4), sparse_signal=True)
    lvq_caps2.add(Reshape((8, 8, 32)))
    lvq_caps2.add(Conv2D(64, (3, 3), padding='same',
                         kernel_initializer=glorot_normal()))
    lvq_caps2.add(Conv2D(64, (3, 3), padding='same',
                         kernel_initializer=glorot_normal()))
    lvq_caps2.add(InputModule(signal_shape=(8 * 8, 64),
                              init_diss_initializer=None,
                              trainable=False))
    lvq_caps2.add(TangentDistance(projected_atom_shape=16,
                                  squared_dissimilarity=False,
                                  epsilon=1.e-12,
                                  linear_factor=0.66,
                                  signal_output='signals',
                                  matrix_scope='global'))
    lvq_caps2.add(GibbsRouting(norm_axis='channels', trainable=False))
    feat = lvq_caps2(feat)

    # Caps3: one restricted-tangent prototype per class (5 classes).
    class_caps = Capsule(prototype_distribution=(1, 5))
    class_caps.add(Reshape((4, 4, 64)))
    class_caps.add(Conv2D(128, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    class_caps.add(InputModule(signal_shape=128,
                               init_diss_initializer=None,
                               trainable=False))
    class_caps.add(RestrictedTangentDistance(projected_atom_shape=16,
                                             epsilon=1.e-12,
                                             squared_dissimilarity=False,
                                             linear_factor=0.66,
                                             signal_output='signals'))
    class_caps.add(GibbsRouting(norm_axis='channels',
                                trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    class_caps.add(DissimilarityTransformation(
        probability_transformation='neg_softmax', name='lvq_caps'))

    predictions = class_caps(feat)

    # Output element 2 is the probability tensor produced by 'lvq_caps'.
    return models.Model(image, predictions[2])
# Example #4
def LvqCapsNet(input_shape):
    """Build the 10-class LVQ capsule classifier for inputs of ``input_shape``.

    Variant that splits the Block-7 feature map with an explicit ``crop``
    helper instead of a SplitModule, and ends in a ``Classification`` module.
    Returns a Keras ``Model`` mapping an image to class probabilities
    (capsule output element 2).

    NOTE(review): relies on names imported/defined elsewhere in this file
    (Capsule, Conv2D, TangentDistance, crop, list_to_dict, MaxDistance, ...).
    """
    input_img = layers.Input(shape=input_shape)

    # Block 1
    # Conv -> BN -> ReLU -> Dropout wrapped in a Capsule container.
    # NOTE(review): the "+ 1" filter presumably carries the auxiliary channel
    # split off in Block 7 -- confirm against the library's conventions.
    caps1 = Capsule()
    caps1.add(
        Conv2D(32 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(input_img)

    # Block 2
    caps2 = Capsule()
    caps2.add(
        Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 3
    caps3 = Capsule()
    caps3.add(
        Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 4 (stride 2 halves the spatial resolution)
    caps4 = Capsule(prototype_distribution=32)
    caps4.add(
        Conv2D(64 + 1, (5, 5),
               strides=2,
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(BatchNormalization())
    caps4.add(Activation('relu'))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 5 (no batch norm / activation from here on)
    caps5 = Capsule()
    caps5.add(
        Conv2D(32 + 1, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 6
    caps6 = Capsule()
    caps6.add(
        Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps6.add(Dropout(0.25))
    x = caps6(x)

    # Block 7
    # Plain Keras conv, then split the 65 channels into [0:64] (signals) and
    # [64:65] (auxiliary channel, activated and flattened).
    # NOTE(review): `crop(3, a, b)` presumably slices axis 3 (channels) over
    # [a, b) -- confirm against its definition elsewhere in this file.
    x = Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal())(x)
    x = [crop(3, 0, 64)(x), crop(3, 64, 65)(x)]
    x[1] = Activation('relu')(x[1])
    x[1] = Flatten()(x[1])

    # Caps1: 8*8 prototypes measured by (non-squared) tangent distance
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(
        InputModule(signal_shape=(16 * 16, 64),
                    dissimilarity_initializer=None,
                    trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(list_to_dict(x))

    # Caps2: reshape back to a feature map, refine with convs, 4*4 prototypes
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        InputModule(signal_shape=(8 * 8, 64),
                    dissimilarity_initializer=None,
                    trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3: one restricted-tangent prototype per digit class
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(
        Conv2D(128, (3, 3), padding='same',
               kernel_initializer=glorot_normal()))
    digit_caps.add(
        InputModule(signal_shape=128,
                    dissimilarity_initializer=None,
                    trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(
        GibbsRouting(norm_axis='channels',
                     trainable=False,
                     diss_regularizer=MaxDistance(alpha=0.0001)))
    digit_caps.add(
        Classification(probability_transformation='neg_softmax',
                       name='lvq_caps'))

    digitcaps = digit_caps(x)

    # digitcaps[2] carries the class probabilities (output of 'lvq_caps').
    model = models.Model(input_img, digitcaps[2])

    return model
# Example #5
# Output directory for artifacts produced by this script.
save_dir = './output'

# Load MNIST, scale pixels to [0, 1], add a trailing channel axis, and
# one-hot encode the labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = np.expand_dims(x_train.astype('float32') / 255., -1)
x_test = np.expand_dims(x_test.astype('float32') / 255., -1)
y_train = to_categorical(y_train.astype('float32'))
y_test = to_categorical(y_test.astype('float32'))

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

inputs = Input(x_train.shape[1:])

# define rTDLVQ prototypes
# NOTE(review): num_tangents and num_classes are defined elsewhere in this
# file (not visible in this chunk).
diss = RestrictedTangentDistance(projected_atom_shape=num_tangents,
                                 linear_factor=None,
                                 squared_dissimilarity=True,
                                 signal_output='signals')

# One capsule with one prototype per class; SqueezeRouting passes the signals
# through, and NearestCompetition + Classification produce 'neg_softmax'
# class probabilities from the dissimilarities.
caps = Capsule(prototype_distribution=(1, num_classes))
caps.add(
    InputModule(signal_shape=(1, ) + x_train.shape[1:4],
                trainable=False,
                dissimilarity_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(NearestCompetition(use_for_loop=False))
caps.add(
    Classification(probability_transformation='neg_softmax',
                   name='lvq_capsule'))

outputs = caps(inputs)