Example #1
def get_model(protos_per_class=1):
    inputs = Input(shape=input_shape)
    diss = TangentDistance(linear_factor=None,
                           squared_dissimilarity=True,
                           projected_atom_shape=12,
                           signal_output='signals')

    caps = Capsule(prototype_distribution=(protos_per_class, 10))
    caps.add(
        InputModule(signal_shape=(-1, np.prod(input_shape)),
                    trainable=False,
                    init_diss_initializer='zeros'))
    caps.add(diss)
    caps.add(SqueezeRouting())
    caps.add(NearestCompetition())

    output = caps(inputs)[1]

    # pre-train the model over 10000 random digits
    idx = np.random.randint(0, len(x_train) - 1, (min(10000, len(x_train)), ))
    pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
    diss_input = pre_train_model.predict(x_train[idx, :],
                                         batch_size=batch_size)
    diss.pre_training(diss_input, y_train[idx], capsule_inputs_are_equal=True)

    # define model and return
    model = Model(inputs, output)

    return model
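
A minimal usage sketch for the function above, assuming input_shape, x_train, y_train, and batch_size are defined as in the surrounding script; the loss is a generic placeholder, since the original script presumably trains with an LVQ-specific loss:

# hedged usage sketch; optimizer, loss, and epochs are illustrative only
model = get_model(protos_per_class=1)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=batch_size, epochs=10)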
Example #2
def get_model():
    inputs = Input(shape=input_shape)

    # select the dissimilarity measure: GLVQ, GTLVQ, or GMLVQ
    if args.mode == 'glvq':
        diss = MinkowskiDistance(linear_factor=None,
                                 squared_dissimilarity=True,
                                 signal_output='signals')
    elif args.mode == 'gtlvq':
        diss = RestrictedTangentDistance(linear_factor=None,
                                         squared_dissimilarity=True,
                                         signal_output='signals',
                                         projected_atom_shape=1)
    elif args.mode == 'gmlvq':
        # use identity matrices for the matrix initialization, since we do
        # not use anysma's standard initialization routine here
        matrix_init = np.repeat(np.expand_dims(np.eye(2), 0),
                                repeats=np.sum(protos_per_class),
                                axis=0)
        matrix_init = matrix_init.astype(K.floatx())

        diss = OmegaDistance(linear_factor=None,
                             squared_dissimilarity=True,
                             signal_output='signals',
                             matrix_scope='local',
                             matrix_constraint='OmegaNormalization',
                             matrix_initializer=lambda x: matrix_init)

    # define capsule network
    caps = Capsule(prototype_distribution=protos_per_class)
    caps.add(
        InputModule(signal_shape=(-1, np.prod(input_shape)),
                    trainable=False,
                    init_diss_initializer='zeros'))
    caps.add(diss)
    caps.add(SqueezeRouting())
    caps.add(NearestCompetition())

    output = caps(inputs)[1]

    # pre-train the model; for GMLVQ, save the identity matrices first so
    # they can overwrite the matrices produced by pre-training
    if args.mode == 'gmlvq':
        _, matrices = diss.get_weights()

    pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
    diss_input = pre_train_model.predict(x_train, batch_size=batch_size)
    diss.pre_training(diss_input, y_train, capsule_inputs_are_equal=True)

    if args.mode == 'gmlvq':
        # set identity matrices
        centers, _ = diss.get_weights()
        diss.set_weights([centers, matrices])

    # define model and return
    model = Model(inputs, output)

    return model
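
The snippet relies on a command-line argument args.mode; a minimal sketch of the hypothetical parsing it assumes, with the flag name and choices inferred from the branches above:

import argparse

# hypothetical CLI parsing; the choices mirror the modes handled above
parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['glvq', 'gtlvq', 'gmlvq'],
                    default='glvq')
args = parser.parse_args()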
Example #3
def get_model(protos_per_class=1):
    inputs = Input(shape=input_shape)
    diss = OmegaDistance(linear_factor=None,
                         squared_dissimilarity=True,
                         matrix_scope='global',
                         matrix_constraint='OmegaNormalization',
                         signal_output='signals',
                         matrix_initializer='identity')

    caps = Capsule(prototype_distribution=(protos_per_class, 10))
    caps.add(InputModule(signal_shape=(-1, np.prod(input_shape)),
                         trainable=False,
                         init_diss_initializer='zeros'))
    caps.add(diss)
    caps.add(SqueezeRouting())
    caps.add(NearestCompetition())

    output = caps(inputs)[1]

    # pre-train the model over 10000 random digits; save the matrix first
    # so the SVD-based matrix from pre-training can be overwritten (GMLVQ)
    _, matrix = diss.get_weights()

    idx = np.random.randint(0, len(x_train) - 1, (min(10000, len(x_train)),))
    pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
    diss_input = pre_train_model.predict(x_train[idx, :],
                                         batch_size=batch_size)
    diss.pre_training(diss_input, y_train[idx], capsule_inputs_are_equal=True)

    # set identity matrices
    centers, _ = diss.get_weights()
    diss.set_weights([centers, matrix])

    # define model and return
    model = Model(inputs, output)

    return model
Example #4
def get_model():
    inputs = Input(shape=input_shape)
    diss = TangentDistance(linear_factor=None,
                           squared_dissimilarity=True,
                           projected_atom_shape=15,
                           signal_output='signals')

    # compute the prototype distribution
    proto_distrib = list(np.minimum(
        np.ceil(np.sum(y_train, 0) / 100).astype('int'), 5))
    # class 'Buildings-Grass-Trees-Drives'
    proto_distrib[-2] = 2
    # class 'Corn-mintill'
    proto_distrib[2] = 4
    print('proto_distrib: ' + str(proto_distrib))

    # define capsule network
    caps = Capsule(prototype_distribution=proto_distrib)
    caps.add(InputModule(signal_shape=(-1, np.prod(input_shape)),
                         trainable=False,
                         init_diss_initializer='zeros'))
    caps.add(diss)
    caps.add(SqueezeRouting())
    caps.add(NearestCompetition())

    output = caps(inputs)[1]

    # pre-train the model
    pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
    diss_input = pre_train_model.predict(x_train, batch_size=batch_size)
    diss.pre_training(diss_input, y_train, capsule_inputs_are_equal=True)

    # define model and return
    model = Model(inputs, output)

    return model
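
The prototype distribution rule above allocates one prototype per 100 training samples of a class, capped at 5; a toy check of that rule on synthetic one-hot labels:

import numpy as np

# three classes with 50, 250, and 400 one-hot labelled samples
y_demo = np.zeros((700, 3))
y_demo[:50, 0] = 1
y_demo[50:300, 1] = 1
y_demo[300:, 2] = 1
print(list(np.minimum(np.ceil(np.sum(y_demo, 0) / 100).astype('int'), 5)))
# -> [1, 3, 4]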
Example #5
def CapsNet(input_shape, n_class, routings):
    """ Initialize the CapsNet"""
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)

    # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]
    primary_caps = Capsule(name='PrimaryCaps')
    primary_caps.add(layers.Conv2D(filters=8 * 32, kernel_size=9, strides=2, padding='valid', name='primarycap_conv2d'))
    primary_caps.add(layers.Reshape(target_shape=[-1, 8], name='primarycap_reshape'))
    primary_caps.add(layers.Lambda(squash, name='primarycap_squash'))

    primarycaps = primary_caps(conv1)

    # Layer 3: Capsule layer. Routing algorithm works here.
    digit_caps = Capsule(name='digitcaps', prototype_distribution=(1, n_class))
    digit_caps.add(InputModule(signal_shape=None, init_diss_initializer='zeros', trainable=False))
    digit_caps.add(LinearTransformation(output_dim=16, scope='local'))
    digit_caps.add(DynamicRouting(iterations=routings, name='capsnet'))

    digitcaps = digit_caps(primarycaps)

    # Decoder network.
    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([digitcaps[0], y])  # The true label is used to mask the output of capsule layer. For training
    masked = Mask()([digitcaps[0], digitcaps[2]])  # Mask using the capsule with maximal length. For prediction

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16 * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction)
    train_model = models.Model([x, y], [digitcaps[2], decoder(masked_by_y)])
    eval_model = models.Model(x, [digitcaps[2], decoder(masked)])

    # manipulate model
    noise = layers.Input(shape=(n_class, 16))
    noised_digitcaps = layers.Add()([digitcaps[0], noise])
    masked_noised_y = Mask()([noised_digitcaps, y])
    manipulate_model = models.Model([x, y, noise], decoder(masked_noised_y))
    return train_model, eval_model, manipulate_model
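
A usage sketch for the three returned models, with the standard CapsNet margin loss (Sabour et al., 2017) written out as an assumption, since the original script may import its own implementation:

import keras.backend as K

def margin_loss(y_true, y_pred):
    # m+ = 0.9, m- = 0.1, lambda = 0.5, as in the CapsNet paper
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
    return K.mean(K.sum(L, 1))

train_model, eval_model, manipulate_model = CapsNet(input_shape=(28, 28, 1),
                                                    n_class=10, routings=3)
train_model.compile(optimizer='adam',
                    loss=[margin_loss, 'mse'],
                    loss_weights=[1., 0.392])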
Example #6
x_train, y_train = tecator.load_data()

x_train = x_train.astype('float32')
y_train = to_categorical(y_train, num_classes).astype('float32')

inputs = Input((x_train.shape[1], ))

# define MLVQ prototypes
diss = OmegaDistance(linear_factor=None,
                     squared_dissimilarity=True,
                     matrix_scope='global',
                     matrix_constraint='OmegaNormalization',
                     signal_output='signals')

caps = Capsule(prototype_distribution=(1, num_classes))
caps.add(
    InputModule(signal_shape=x_train.shape[1],
                trainable=False,
                dissimilarity_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(Classification(probability_transformation='flip', name='lvq_capsule'))

outputs = caps(inputs)

# pre-train the model
pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
diss_input = pre_train_model.predict(x_train, batch_size=batch_size)
diss.pre_training(diss_input, y_train, capsule_inputs_are_equal=True)
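
A hedged continuation sketch: as in the other examples, index 1 of the capsule output is assumed to hold the class predictions, and the loss is a placeholder:

model = Model(inputs, outputs[1])
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=batch_size, epochs=10)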
Example #7
def LvqCapsNet(input_shape):
    input_img = Input(shape=input_shape)

    # Block 1
    caps0 = Capsule()
    caps0.add(Conv2D(32 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4
    caps3 = Capsule(prototype_distribution=32)
    caps3.add(Conv2D(64 + 1, (5, 5), strides=2, padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5
    caps4 = Capsule()
    caps4.add(Conv2D(32 + 1, (3, 3), padding='same', kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6
    caps5 = Capsule()
    caps5.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7
    caps6 = Capsule()
    caps6.add(Conv2D(64 + 1, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps6.add(SplitModule())
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(InputModule(signal_shape=(-1, 64), init_diss_initializer=None, trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False, epsilon=1.e-12, linear_factor=0.66, projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(InputModule(signal_shape=(8 * 8, 64), init_diss_initializer=None, trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16, squared_dissimilarity=False,
                            epsilon=1.e-12, linear_factor=0.66, signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128, init_diss_initializer=None, trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16, epsilon=1.e-12, squared_dissimilarity=False,
                                     linear_factor=0.66, signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels', trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(DissimilarityTransformation(probability_transformation='neg_softmax', name='lvq_caps'))

    digitcaps = digit_caps(x)

    # intermediate model that applies the digit capsule to Caps2-shaped
    # inputs; used for visualizations
    input_diss8 = [Input((4, 4, 64)), Input((16,))]
    model_vis_caps2 = models.Model(input_diss8, digit_caps(list_to_dict(input_diss8)))

    # Decoder network.
    y = layers.Input(shape=(10,))
    masked_by_y = Mask()([digitcaps[0], y])

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=128 * 10))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod((28, 28, 1)), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=(28, 28, 1), name='out_recon'))

    # Models for training and evaluation (prediction)
    model = models.Model([input_img, y], [digitcaps[2], decoder(masked_by_y)])

    return model, decoder, model_vis_caps2
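
A usage sketch for the returned training model, which outputs class probabilities and a reconstruction; both losses and the weighting are illustrative placeholders:

model, decoder, model_vis_caps2 = LvqCapsNet(input_shape=(28, 28, 1))
model.compile(optimizer='adam',
              loss=['categorical_crossentropy', 'mse'],
              loss_weights=[1., 0.392])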
Example #8
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
y_train = to_categorical(y_train.astype('float32'))
y_test = to_categorical(y_test.astype('float32'))

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

inputs = Input(x_train.shape[1:])

# define MLVQ prototypes
diss = MinkowskiDistance(linear_factor=None,
                         squared_dissimilarity=True,
                         signal_output='signals')

caps = Capsule(prototype_distribution=(protos_per_class, num_classes))
caps.add(
    InputModule(signal_shape=(1, ) + x_train.shape[1:],
                trainable=False,
                init_diss_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(NearestCompetition(use_for_loop=False))
caps.add(
    DissimilarityTransformation(probability_transformation='flip',
                                name='lvq_capsule'))

outputs = caps(inputs)

# pre-train the model over 10000 random digits
idx = np.random.randint(0, len(x_train) - 1, (min(10000, len(x_train)), ))
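
The snippet is cut off after drawing idx; judging from the identical pattern in Examples 1 and 3, it presumably continues with the dissimilarity pre-training:

pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
diss_input = pre_train_model.predict(x_train[idx, :], batch_size=batch_size)
diss.pre_training(diss_input, y_train[idx], capsule_inputs_are_equal=True)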
Example #9
def LvqCapsNet(input_shape, n_class):
    """
    An LVQ capsule network for MNIST with Minkowski distance, a neg_exp
    probability transformation, and pre-training.
    """
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv = layers.Conv2D(filters=256,
                         kernel_size=9,
                         strides=1,
                         padding='valid',
                         activation='relu')(x)
    conv = layers.Conv2D(filters=8 * 32,
                         kernel_size=9,
                         strides=2,
                         padding='valid')(conv)

    # LVQ-Capsule
    digit_caps = Capsule(name='digitcaps', prototype_distribution=(1, n_class))
    digit_caps.add(
        InputModule(signal_shape=8, dissimilarity_initializer='zeros'))
    digit_caps.add(LinearTransformation(output_dim=16, scope='local'))
    diss = MinkowskiDistance(prototype_initializer='zeros',
                             squared_dissimilarity=False,
                             name='minkowski_distance')
    digit_caps.add(diss)
    digit_caps.add(
        GibbsRouting(norm_axis='channels',
                     trainable=False,
                     diss_regularizer='max_distance'))
    digit_caps.add(
        Classification(probability_transformation='neg_exp', name='lvq_caps'))

    digitcaps = digit_caps(conv)

    # Decoder network.
    y = layers.Input(shape=(n_class, ))
    masked_by_y = Mask()([digitcaps[0], y])
    masked = Mask()([digitcaps[0], digitcaps[2]])

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16 * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction)
    train_model = models.Model([x, y], [digitcaps[2], decoder(masked_by_y)])
    eval_model = models.Model(x, [digitcaps[2], decoder(masked)])

    # manipulate model
    noise = layers.Input(shape=(n_class, 16))
    noised_digitcaps = layers.Add()([digitcaps[0], noise])
    masked_noised_y = Mask()([noised_digitcaps, y])
    manipulate_model = models.Model([x, y, noise], decoder(masked_noised_y))
    return train_model, eval_model, manipulate_model, decoder
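
A short prediction sketch with the evaluation model, assuming x_test holds preprocessed MNIST images of shape (28, 28, 1):

train_model, eval_model, manipulate_model, decoder = LvqCapsNet(
    input_shape=(28, 28, 1), n_class=10)
y_pred, x_recon = eval_model.predict(x_test, batch_size=100)
print('predicted classes:', y_pred.argmax(-1)[:10])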
Example #10
def LvqCapsNet(input_shape):
    input_img = layers.Input(shape=input_shape)

    # Block 1
    caps0 = Capsule()
    caps0.add(
        Conv2D(32 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal(),
               input_shape=x_train.shape[1:]))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4
    caps3 = Capsule()
    caps3.add(
        Conv2D(64 + 2, (5, 5),
               strides=2,
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5
    caps4 = Capsule()
    caps4.add(
        Conv2D(32 + 2, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6
    caps5 = Capsule()
    caps5.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7
    caps6 = Capsule()
    caps6.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps6.add(SplitModule(index=-2))
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1
    caps7 = Capsule(prototype_distribution=(1, 8 * 8), sparse_signal=True)
    caps7.add(
        InputModule(signal_shape=(-1, 32),
                    init_diss_initializer=None,
                    trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16,
                            matrix_scope='global')
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2
    caps8 = Capsule(prototype_distribution=(1, 4 * 4), sparse_signal=True)
    caps8.add(Reshape((8, 8, 32)))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        InputModule(signal_shape=(8 * 8, 64),
                    init_diss_initializer=None,
                    trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals',
                            matrix_scope='global')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3
    digit_caps = Capsule(prototype_distribution=(1, 5))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(
        Conv2D(128, (3, 3), padding='same',
               kernel_initializer=glorot_normal()))
    digit_caps.add(
        InputModule(signal_shape=128,
                    init_diss_initializer=None,
                    trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(
        GibbsRouting(norm_axis='channels',
                     trainable=False,
                     diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(
        DissimilarityTransformation(probability_transformation='neg_softmax',
                                    name='lvq_caps'))

    digitcaps = digit_caps(x)

    model = models.Model(input_img, digitcaps[2])

    return model
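
A usage sketch; this variant classifies into 5 classes via neg_softmax probabilities, so labels are assumed one-hot with 5 columns and the loss is a placeholder:

model = LvqCapsNet(input_shape=x_train.shape[1:])
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])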
Example #11
def LvqCapsNet(input_shape):
    input_img = layers.Input(shape=input_shape)

    # Block 1
    caps1 = Capsule()
    caps1.add(
        Conv2D(32 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(input_img)

    # Block 2
    caps2 = Capsule()
    caps2.add(
        Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 3
    caps3 = Capsule()
    caps3.add(
        Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 4
    caps4 = Capsule(prototype_distribution=32)
    caps4.add(
        Conv2D(64 + 1, (5, 5),
               strides=2,
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(BatchNormalization())
    caps4.add(Activation('relu'))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 5
    caps5 = Capsule()
    caps5.add(
        Conv2D(32 + 1, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 6
    caps6 = Capsule()
    caps6.add(
        Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps6.add(Dropout(0.25))
    x = caps6(x)

    # Block 7
    x = Conv2D(64 + 1, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal())(x)
    x = [crop(3, 0, 64)(x), crop(3, 64, 65)(x)]
    x[1] = Activation('relu')(x[1])
    x[1] = Flatten()(x[1])

    # Caps1
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(
        InputModule(signal_shape=(16 * 16, 64),
                    dissimilarity_initializer=None,
                    trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(list_to_dict(x))

    # Caps2
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=glorot_normal()))
    caps8.add(
        InputModule(signal_shape=(8 * 8, 64),
                    dissimilarity_initializer=None,
                    trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(
        Conv2D(128, (3, 3), padding='same',
               kernel_initializer=glorot_normal()))
    digit_caps.add(
        InputModule(signal_shape=128,
                    dissimilarity_initializer=None,
                    trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(
        GibbsRouting(norm_axis='channels',
                     trainable=False,
                     diss_regularizer=MaxDistance(alpha=0.0001)))
    digit_caps.add(
        Classification(probability_transformation='neg_softmax',
                       name='lvq_caps'))

    digitcaps = digit_caps(x)

    model = models.Model(input_img, digitcaps[2])

    return model
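
Block 7 above uses a crop helper that the snippet does not define; a common Keras idiom for it, sketched under the assumption that it slices one dimension of a tensor with a Lambda layer:

from keras.layers import Lambda

def crop(dimension, start, end):
    # returns a layer that slices `dimension` of its input to [start:end];
    # only the channel case used above (dimension == 3) is sketched
    def func(x):
        if dimension == 3:
            return x[:, :, :, start:end]
        raise NotImplementedError
    return Lambda(func)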
Example #12
x_train = np.expand_dims(x_train.astype('float32') / 255., -1)
x_test = np.expand_dims(x_test.astype('float32') / 255., -1)
y_train = to_categorical(y_train.astype('float32'))
y_test = to_categorical(y_test.astype('float32'))

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

inputs = Input(x_train.shape[1:])

# define rTDLVQ prototypes
diss = RestrictedTangentDistance(projected_atom_shape=num_tangents,
                                 linear_factor=None,
                                 squared_dissimilarity=True,
                                 signal_output='signals')

caps = Capsule(prototype_distribution=(1, num_classes))
caps.add(
    InputModule(signal_shape=(1, ) + x_train.shape[1:4],
                trainable=False,
                dissimilarity_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(NearestCompetition(use_for_loop=False))
caps.add(
    Classification(probability_transformation='neg_softmax',
                   name='lvq_capsule'))

outputs = caps(inputs)

# pre-train the model over 10000 random digits
idx = np.random.randint(0, len(x_train) - 1, (min(10000, len(x_train)), ))
Example #13
    save_dir = save_dir_base + '/run_' + str(j)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    # iterate over number of tangents
    results = []
    for i in range(max_number_tangents + 1):
        number_tangents = i

        # define the GTLVQ network as a capsule
        diss = TangentDistance(projected_atom_shape=number_tangents,
                               linear_factor=None,
                               squared_dissimilarity=True,
                               signal_output='signals')

        caps = Capsule(prototype_distribution=(1, 10))
        caps.add(
            InputModule(signal_shape=(1, ) + x_train.shape[1:4],
                        trainable=False,
                        init_diss_initializer='zeros'))
        caps.add(diss)
        caps.add(SqueezeRouting())
        caps.add(NearestCompetition(use_for_loop=False))

        outputs = caps(inputs)[1]

        # pre-train the model over 10000 random digits
        idx = np.random.randint(0,
                                len(x_train) - 1, (min(10000, len(x_train)), ))
        pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
        diss_input = pre_train_model.predict(x_train[idx, :],
                                             batch_size=batch_size)