def get_model():
    """Build the LVQ capsule model for the dissimilarity selected by ``args.mode``.

    Supported modes: ``'glvq'`` (Minkowski), ``'gtlvq'`` (restricted tangent)
    and ``'gmlvq'`` (local omega matrices). For GMLVQ the omega matrices are
    pre-initialized to identity, the prototypes are pre-trained, and the
    identity matrices are restored afterwards.

    Relies on module-level globals: ``args``, ``input_shape``,
    ``protos_per_class``, ``x_train``, ``y_train``, ``batch_size``.

    Returns:
        A compiled-ready ``Model`` mapping ``inputs`` to the capsule output.
    """
    inputs = Input(shape=input_shape)

    # get the dissimilarity: either GLVQ, GTLVQ, or GMLVQ
    if args.mode == 'glvq':
        diss = MinkowskiDistance(linear_factor=None,
                                 squared_dissimilarity=True,
                                 signal_output='signals')
    elif args.mode == 'gtlvq':
        diss = RestrictedTangentDistance(linear_factor=None,
                                         squared_dissimilarity=True,
                                         signal_output='signals',
                                         projected_atom_shape=1)
    elif args.mode == 'gmlvq':
        # get identity matrices for the matrix initialization as we do not
        # use the standard routine of anysma for the initialization
        matrix_init = np.repeat(np.expand_dims(np.eye(2), 0),
                                repeats=np.sum(protos_per_class),
                                axis=0)
        # BUG FIX: ndarray.astype() returns a NEW array; the original call
        # discarded the result, leaving matrix_init in the default float64
        # dtype instead of the backend float type.
        matrix_init = matrix_init.astype(K.floatx())
        diss = OmegaDistance(linear_factor=None,
                             squared_dissimilarity=True,
                             signal_output='signals',
                             matrix_scope='local',
                             matrix_constraint='OmegaNormalization',
                             matrix_initializer=lambda x: matrix_init)

    # define capsule network
    caps = Capsule(prototype_distribution=protos_per_class)
    caps.add(InputModule(signal_shape=(-1, np.prod(input_shape)),
                         trainable=False,
                         init_diss_initializer='zeros'))
    caps.add(diss)
    caps.add(SqueezeRouting())
    caps.add(NearestCompetition())
    output = caps(inputs)[1]

    # pre-train the model and overwrite the standard initialization matrix
    # for GMLVQ
    if args.mode == 'gmlvq':
        # stash the identity matrices before pre-training overwrites them
        _, matrices = diss.get_weights()
    pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
    diss_input = pre_train_model.predict(x_train, batch_size=batch_size)
    diss.pre_training(diss_input, y_train, capsule_inputs_are_equal=True)
    if args.mode == 'gmlvq':
        # restore identity matrices; keep the pre-trained centers
        centers, _ = diss.get_weights()
        diss.set_weights([centers, matrices])

    # define model and return
    model = Model(inputs, output)
    return model
def LvqCapsNet(input_shape):
    """Assemble the LVQ-CapsNet with a reconstruction decoder.

    Returns:
        tuple: ``(model, decoder, model_vis_caps2)`` — the training model
        (image + true label in, class output + reconstruction out), the
        shared decoder network, and an intermediate model over the last
        capsule stage used for visualizations.
    """
    input_img = Input(shape=input_shape)

    # ---- convolutional feature extraction ------------------------------
    # Block 1
    block1 = Capsule()
    block1.add(Conv2D(32 + 1, (3, 3), padding='same',
                      kernel_initializer=glorot_normal()))
    block1.add(BatchNormalization())
    block1.add(Activation('relu'))
    block1.add(Dropout(0.25))
    feat = block1(input_img)

    # Block 2
    block2 = Capsule()
    block2.add(Conv2D(64 + 1, (3, 3), padding='same',
                      kernel_initializer=glorot_normal()))
    block2.add(BatchNormalization())
    block2.add(Activation('relu'))
    block2.add(Dropout(0.25))
    feat = block2(feat)

    # Block 3
    block3 = Capsule()
    block3.add(Conv2D(64 + 1, (3, 3), padding='same',
                      kernel_initializer=RandomNormal(stddev=0.01)))
    block3.add(BatchNormalization())
    block3.add(Activation('relu'))
    block3.add(Dropout(0.25))
    feat = block3(feat)

    # Block 4 — strided conv halves the spatial resolution
    block4 = Capsule(prototype_distribution=32)
    block4.add(Conv2D(64 + 1, (5, 5), strides=2, padding='same',
                      kernel_initializer=RandomNormal(stddev=0.01)))
    block4.add(BatchNormalization())
    block4.add(Activation('relu'))
    block4.add(Dropout(0.25))
    feat = block4(feat)

    # Block 5
    block5 = Capsule()
    block5.add(Conv2D(32 + 1, (3, 3), padding='same',
                      kernel_initializer=RandomNormal(stddev=0.01)))
    block5.add(Dropout(0.25))
    feat = block5(feat)

    # Block 6
    block6 = Capsule()
    block6.add(Conv2D(64 + 1, (3, 3), padding='same',
                      kernel_initializer=glorot_normal()))
    block6.add(Dropout(0.25))
    feat = block6(feat)

    # Block 7 — split the extra channel off as a flattened side signal
    block7 = Capsule()
    block7.add(Conv2D(64 + 1, (3, 3), padding='same',
                      kernel_initializer=glorot_normal()))
    block7.add(SplitModule())
    block7.add(Activation('relu'), scope_keys=1)
    block7.add(Flatten(), scope_keys=1)
    feat = block7(feat)

    # ---- LVQ capsule stages --------------------------------------------
    # Caps1
    lvq_stage1 = Capsule(prototype_distribution=(1, 8 * 8))
    lvq_stage1.add(InputModule(signal_shape=(-1, 64),
                               init_diss_initializer=None,
                               trainable=False))
    tangent1 = TangentDistance(squared_dissimilarity=False,
                               epsilon=1.e-12,
                               linear_factor=0.66,
                               projected_atom_shape=16)
    lvq_stage1.add(tangent1)
    lvq_stage1.add(GibbsRouting(norm_axis='channels', trainable=False))
    feat = lvq_stage1(feat)

    # Caps2
    lvq_stage2 = Capsule(prototype_distribution=(1, 4 * 4))
    lvq_stage2.add(Reshape((8, 8, 64)))
    lvq_stage2.add(Conv2D(64, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    lvq_stage2.add(Conv2D(64, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    lvq_stage2.add(InputModule(signal_shape=(8 * 8, 64),
                               init_diss_initializer=None,
                               trainable=False))
    tangent2 = TangentDistance(projected_atom_shape=16,
                               squared_dissimilarity=False,
                               epsilon=1.e-12,
                               linear_factor=0.66,
                               signal_output='signals')
    lvq_stage2.add(tangent2)
    lvq_stage2.add(GibbsRouting(norm_axis='channels', trainable=False))
    feat = lvq_stage2(feat)

    # Caps3 — one prototype per digit class
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128,
                               init_diss_initializer=None,
                               trainable=False))
    class_diss = RestrictedTangentDistance(projected_atom_shape=16,
                                           epsilon=1.e-12,
                                           squared_dissimilarity=False,
                                           linear_factor=0.66,
                                           signal_output='signals')
    digit_caps.add(class_diss)
    digit_caps.add(GibbsRouting(norm_axis='channels',
                                trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(
        DissimilarityTransformation(probability_transformation='neg_softmax',
                                    name='lvq_caps'))
    digitcaps = digit_caps(feat)

    # intermediate model for Caps2; used for visualizations
    vis_inputs = [Input((4, 4, 64)), Input((16,))]
    model_vis_caps2 = models.Model(vis_inputs,
                                   digit_caps(list_to_dict(vis_inputs)))

    # ---- decoder network -----------------------------------------------
    true_label = layers.Input(shape=(10,))
    masked = Mask()([digitcaps[0], true_label])

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=128 * 10))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod((28, 28, 1)), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=(28, 28, 1), name='out_recon'))

    # Models for training and evaluation (prediction)
    model = models.Model([input_img, true_label],
                         [digitcaps[2], decoder(masked)])
    return model, decoder, model_vis_caps2
def LvqCapsNet(input_shape):
    """Build the 5-class LVQ-CapsNet classifier.

    Args:
        input_shape: shape of one input image (height, width, channels).

    Returns:
        ``Model`` mapping the image input to the class output
        (``digitcaps[2]``).
    """
    input_img = layers.Input(shape=input_shape)

    # Block 1
    caps0 = Capsule()
    # FIX: use the function's own `input_shape` parameter instead of the
    # module-level global `x_train.shape[1:]` — removes a hidden global
    # dependency; identical whenever the model is built for the training data.
    caps0.add(Conv2D(32 + 2, (3, 3), padding='same',
                     kernel_initializer=glorot_normal(),
                     input_shape=input_shape))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(Conv2D(64 + 2, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(Conv2D(64 + 2, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4 — strided conv halves the spatial resolution
    caps3 = Capsule()
    caps3.add(Conv2D(64 + 2, (5, 5), strides=2, padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5
    caps4 = Capsule()
    caps4.add(Conv2D(32 + 2, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6
    caps5 = Capsule()
    caps5.add(Conv2D(64 + 2, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7 — split off the two extra channels as a flattened side signal
    caps6 = Capsule()
    caps6.add(Conv2D(64 + 2, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps6.add(SplitModule(index=-2))
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1
    caps7 = Capsule(prototype_distribution=(1, 8 * 8), sparse_signal=True)
    caps7.add(InputModule(signal_shape=(-1, 32),
                          init_diss_initializer=None,
                          trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16,
                            matrix_scope='global')
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2
    caps8 = Capsule(prototype_distribution=(1, 4 * 4), sparse_signal=True)
    caps8.add(Reshape((8, 8, 32)))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(InputModule(signal_shape=(8 * 8, 64),
                          init_diss_initializer=None,
                          trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals',
                            matrix_scope='global')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3 — one prototype per class (5 classes)
    digit_caps = Capsule(prototype_distribution=(1, 5))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128,
                               init_diss_initializer=None,
                               trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels',
                                trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(
        DissimilarityTransformation(probability_transformation='neg_softmax',
                                    name='lvq_caps'))
    digitcaps = digit_caps(x)

    model = models.Model(input_img, digitcaps[2])
    return model
def LvqCapsNet(input_shape):
    """Build the 10-class LVQ-CapsNet classifier (crop-based split variant).

    Returns:
        ``Model`` mapping the image input to the class output
        (``digitcaps[2]``).
    """
    input_img = layers.Input(shape=input_shape)

    # ---- convolutional feature extraction ------------------------------
    # Block 1
    conv1 = Capsule()
    conv1.add(Conv2D(32 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    conv1.add(BatchNormalization())
    conv1.add(Activation('relu'))
    conv1.add(Dropout(0.25))
    t = conv1(input_img)

    # Block 2
    conv2 = Capsule()
    conv2.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    conv2.add(BatchNormalization())
    conv2.add(Activation('relu'))
    conv2.add(Dropout(0.25))
    t = conv2(t)

    # Block 3
    conv3 = Capsule()
    conv3.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    conv3.add(BatchNormalization())
    conv3.add(Activation('relu'))
    conv3.add(Dropout(0.25))
    t = conv3(t)

    # Block 4 — strided conv halves the spatial resolution
    conv4 = Capsule(prototype_distribution=32)
    conv4.add(Conv2D(64 + 1, (5, 5), strides=2, padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    conv4.add(BatchNormalization())
    conv4.add(Activation('relu'))
    conv4.add(Dropout(0.25))
    t = conv4(t)

    # Block 5
    conv5 = Capsule()
    conv5.add(Conv2D(32 + 1, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    conv5.add(Dropout(0.25))
    t = conv5(t)

    # Block 6
    conv6 = Capsule()
    conv6.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    conv6.add(Dropout(0.25))
    t = conv6(t)

    # Block 7 — plain conv, then crop the last channel off as a
    # flattened side signal (instead of a SplitModule)
    conv_out = Conv2D(64 + 1, (3, 3), padding='same',
                      kernel_initializer=glorot_normal())(t)
    signals = crop(3, 0, 64)(conv_out)
    side = crop(3, 64, 65)(conv_out)
    side = Activation('relu')(side)
    side = Flatten()(side)
    pair = [signals, side]

    # ---- LVQ capsule stages --------------------------------------------
    # Caps1
    lvq1 = Capsule(prototype_distribution=(1, 8 * 8))
    lvq1.add(InputModule(signal_shape=(16 * 16, 64),
                         dissimilarity_initializer=None,
                         trainable=False))
    tangent1 = TangentDistance(squared_dissimilarity=False,
                               epsilon=1.e-12,
                               linear_factor=0.66,
                               projected_atom_shape=16)
    lvq1.add(tangent1)
    lvq1.add(GibbsRouting(norm_axis='channels', trainable=False))
    t = lvq1(list_to_dict(pair))

    # Caps2
    lvq2 = Capsule(prototype_distribution=(1, 4 * 4))
    lvq2.add(Reshape((8, 8, 64)))
    lvq2.add(Conv2D(64, (3, 3), padding='same',
                    kernel_initializer=glorot_normal()))
    lvq2.add(Conv2D(64, (3, 3), padding='same',
                    kernel_initializer=glorot_normal()))
    lvq2.add(InputModule(signal_shape=(8 * 8, 64),
                         dissimilarity_initializer=None,
                         trainable=False))
    tangent2 = TangentDistance(projected_atom_shape=16,
                               squared_dissimilarity=False,
                               epsilon=1.e-12,
                               linear_factor=0.66,
                               signal_output='signals')
    lvq2.add(tangent2)
    lvq2.add(GibbsRouting(norm_axis='channels', trainable=False))
    t = lvq2(t)

    # Caps3 — one prototype per digit class
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128,
                               dissimilarity_initializer=None,
                               trainable=False))
    class_diss = RestrictedTangentDistance(projected_atom_shape=16,
                                           epsilon=1.e-12,
                                           squared_dissimilarity=False,
                                           linear_factor=0.66,
                                           signal_output='signals')
    digit_caps.add(class_diss)
    digit_caps.add(GibbsRouting(norm_axis='channels',
                                trainable=False,
                                diss_regularizer=MaxDistance(alpha=0.0001)))
    digit_caps.add(
        Classification(probability_transformation='neg_softmax',
                       name='lvq_caps'))
    digitcaps = digit_caps(t)

    model = models.Model(input_img, digitcaps[2])
    return model
save_dir = './output'

# Load MNIST; scale pixels to [0, 1], add a trailing channel axis, and
# one-hot encode the labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = np.expand_dims(x_train.astype('float32') / 255., -1)
x_test = np.expand_dims(x_test.astype('float32') / 255., -1)
y_train = to_categorical(y_train.astype('float32'))
y_test = to_categorical(y_test.astype('float32'))

# makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair and also
# creates missing parent directories.
os.makedirs(save_dir, exist_ok=True)

inputs = Input(x_train.shape[1:])

# define rTDLVQ prototypes
diss = RestrictedTangentDistance(projected_atom_shape=num_tangents,
                                 linear_factor=None,
                                 squared_dissimilarity=True,
                                 signal_output='signals')
caps = Capsule(prototype_distribution=(1, num_classes))
caps.add(InputModule(signal_shape=(1, ) + x_train.shape[1:4],
                     trainable=False,
                     dissimilarity_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(NearestCompetition(use_for_loop=False))
caps.add(Classification(probability_transformation='neg_softmax',
                        name='lvq_capsule'))
outputs = caps(inputs)