def LvqCapsNet(input_shape, n_class):
    """A LVQ-Capsule network on MNIST with Minkowski distance, neg_exp as
    probability transformation and pre-training.

    Args:
        input_shape: shape tuple of a single input image (no batch dim).
        n_class: number of output classes (one prototype per class).

    Returns:
        (train_model, eval_model, manipulate_model, decoder) — the training
        model takes [image, label], the evaluation model only the image, and
        the manipulation model additionally takes an additive capsule noise.
    """
    img = layers.Input(shape=input_shape)

    # Two conventional convolutions produce the primary-capsule features.
    features = layers.Conv2D(filters=256,
                             kernel_size=9,
                             strides=1,
                             padding='valid',
                             activation='relu')(img)
    features = layers.Conv2D(filters=8 * 32,
                             kernel_size=9,
                             strides=2,
                             padding='valid')(features)

    # LVQ-Capsule: local linear transform into 16-D, Minkowski distance to
    # per-class prototypes, Gibbs routing, neg_exp class probabilities.
    caps = Capsule(name='digitcaps', prototype_distribution=(1, n_class))
    caps.add(InputModule(signal_shape=8, dissimilarity_initializer='zeros'))
    caps.add(LinearTransformation(output_dim=16, scope='local'))
    caps.add(MinkowskiDistance(prototype_initializer='zeros',
                               squared_dissimilarity=False,
                               name='minkowski_distance'))
    caps.add(GibbsRouting(norm_axis='channels',
                          trainable=False,
                          diss_regularizer='max_distance'))
    caps.add(Classification(probability_transformation='neg_exp',
                            name='lvq_caps'))
    caps_out = caps(features)

    # Decoder network.
    labels = layers.Input(shape=(n_class, ))
    # Mask by the true label during training, by the prediction otherwise.
    masked_by_label = Mask()([caps_out[0], labels])
    masked_by_pred = Mask()([caps_out[0], caps_out[2]])

    # Shared Decoder model in training and prediction.
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16 * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction).
    train_model = models.Model([img, labels],
                               [caps_out[2], decoder(masked_by_label)])
    eval_model = models.Model(img, [caps_out[2], decoder(masked_by_pred)])

    # Manipulation model: perturb the capsule signals with explicit noise.
    noise = layers.Input(shape=(n_class, 16))
    masked_noised_y = Mask()([layers.Add()([caps_out[0], noise]), labels])
    manipulate_model = models.Model([img, labels, noise],
                                    decoder(masked_noised_y))

    return train_model, eval_model, manipulate_model, decoder
def LvqCapsNet(input_shape):
    """Build an LVQ-CapsNet classifier with tangent-distance capsules.

    Six convolutional capsule blocks extract features, a split block
    separates signal/dissimilarity channels, and three LVQ-capsule stages
    (tangent and restricted-tangent distances with Gibbs routing) produce
    10-class probabilities via 'neg_softmax'.

    Returns:
        (model, decoder, model_vis_caps2): the training model mapping
        [image, label] to [class probabilities, reconstruction], the shared
        decoder, and an auxiliary visualization model that re-applies the
        final (shared-weight) capsule to explicit inputs.
    """
    input_img = Input(shape=input_shape)
    # Block 1 — Conv/BN/ReLU/Dropout wrapped in a Capsule container.
    caps0 = Capsule()
    caps0.add(Conv2D(32 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)
    # Block 2
    caps1 = Capsule()
    caps1.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)
    # Block 3
    caps2 = Capsule()
    caps2.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)
    # Block 4 — stride-2 downsampling; 32 prototypes for this capsule.
    caps3 = Capsule(prototype_distribution=32)
    caps3.add(Conv2D(64 + 1, (5, 5), strides=2, padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)
    # Block 5 — conv + dropout only (no BN/activation).
    caps4 = Capsule()
    caps4.add(Conv2D(32 + 1, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)
    # Block 6
    caps5 = Capsule()
    caps5.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)
    # Block 7 — SplitModule separates the extra (+1) channel; the split-off
    # part (scope_keys=1) is activated and flattened. NOTE(review): the
    # +1 channel presumably carries dissimilarities — confirm with
    # SplitModule's definition.
    caps6 = Capsule()
    caps6.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps6.add(SplitModule())
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)
    # Caps1 — 8x8 prototype grid with (unsquared) tangent distance.
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(InputModule(signal_shape=(-1, 64),
                          init_diss_initializer=None,
                          trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)
    # Caps2 — reshape back to a spatial grid, refine with two convs, 4x4
    # prototypes; signal_output='signals' forwards signals downstream.
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(InputModule(signal_shape=(8 * 8, 64),
                          init_diss_initializer=None,
                          trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)
    # Caps3 — one prototype per class (10 classes), restricted tangent
    # distance, regularized routing, neg_softmax class probabilities.
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128,
                               init_diss_initializer=None,
                               trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels',
                                trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(DissimilarityTransformation(
        probability_transformation='neg_softmax', name='lvq_caps'))
    digitcaps = digit_caps(x)
    # intermediate model for Caps2; used for visualizations.
    # NOTE(review): re-calling digit_caps shares its weights with the main
    # graph; inputs are the (4,4,64) signal grid plus a 16-vector —
    # presumably the dissimilarities; verify against list_to_dict.
    input_diss8 = [Input((4, 4, 64)), Input((16,))]
    model_vis_caps2 = models.Model(input_diss8,
                                   digit_caps(list_to_dict(input_diss8)))
    # Decoder network.
    y = layers.Input(shape=(10,))
    # Mask the capsule signals (index 0) by the true label.
    masked_by_y = Mask()([digitcaps[0], y])
    # Shared Decoder model in training and prediction; reconstructs a
    # 28x28x1 image from the masked 128-D signals of the 10 capsules.
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=128 * 10))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod((28, 28, 1)), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=(28, 28, 1), name='out_recon'))
    # Models for training and evaluation (prediction). digitcaps[2] is the
    # output of the 'lvq_caps' transformation (class probabilities).
    model = models.Model([input_img, y], [digitcaps[2], decoder(masked_by_y)])
    return model, decoder, model_vis_caps2
def LvqCapsNet(input_shape):
    """Build an LVQ-CapsNet classifier for 5 classes.

    Six convolutional capsule blocks extract features, a split block
    separates signal/dissimilarity channels, and three LVQ-capsule stages
    (global tangent and restricted-tangent distances with Gibbs routing)
    produce class probabilities via 'neg_softmax'.

    Args:
        input_shape: shape tuple of a single input sample (no batch dim).

    Returns:
        A Keras Model mapping the input image to class probabilities.
    """
    input_img = layers.Input(shape=input_shape)

    # Block 1
    caps0 = Capsule()
    # BUGFIX: the Conv2D previously received input_shape=x_train.shape[1:],
    # reading a global `x_train` that is not defined in this scope (NameError
    # unless a training script happens to leak it). The function's own
    # `input_shape` parameter is the correct, self-contained source.
    caps0.add(
        Conv2D(32 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal(),
               input_shape=input_shape))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4 — stride-2 spatial downsampling.
    caps3 = Capsule()
    caps3.add(
        Conv2D(64 + 2, (5, 5),
               strides=2,
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5 — conv + dropout only (no BN/activation).
    caps4 = Capsule()
    caps4.add(
        Conv2D(32 + 2, (3, 3),
               padding='same',
               kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6
    caps5 = Capsule()
    caps5.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7 — split off the two extra (+2) channels at index -2; the
    # split-off part (scope_keys=1) is activated and flattened.
    caps6 = Capsule()
    caps6.add(
        Conv2D(64 + 2, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps6.add(SplitModule(index=-2))
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1 — 8x8 prototype grid, globally-shared tangent distance.
    caps7 = Capsule(prototype_distribution=(1, 8 * 8), sparse_signal=True)
    caps7.add(
        InputModule(signal_shape=(-1, 32),
                    init_diss_initializer=None,
                    trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16,
                            matrix_scope='global')
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2 — back to a spatial grid, refine with two convs, 4x4 prototypes;
    # signal_output='signals' forwards signals to the next capsule.
    caps8 = Capsule(prototype_distribution=(1, 4 * 4), sparse_signal=True)
    caps8.add(Reshape((8, 8, 32)))
    caps8.add(
        Conv2D(64, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps8.add(
        Conv2D(64, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    caps8.add(
        InputModule(signal_shape=(8 * 8, 64),
                    init_diss_initializer=None,
                    trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals',
                            matrix_scope='global')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3 — one prototype per class (5 classes), restricted tangent
    # distance, regularized routing, neg_softmax class probabilities.
    digit_caps = Capsule(prototype_distribution=(1, 5))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(
        Conv2D(128, (3, 3),
               padding='same',
               kernel_initializer=glorot_normal()))
    digit_caps.add(
        InputModule(signal_shape=128,
                    init_diss_initializer=None,
                    trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(
        GibbsRouting(norm_axis='channels',
                     trainable=False,
                     diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(
        DissimilarityTransformation(probability_transformation='neg_softmax',
                                    name='lvq_caps'))
    digitcaps = digit_caps(x)

    # digitcaps[2] is the output of the 'lvq_caps' transformation
    # (class probabilities).
    model = models.Model(input_img, digitcaps[2])
    return model
def LvqCapsNet(input_shape):
    """Build an LVQ-CapsNet classifier for 10 classes.

    Seven convolutional feature blocks are followed by three LVQ-capsule
    stages (tangent / restricted-tangent distances with Gibbs routing) that
    emit class probabilities via 'neg_softmax'.

    Args:
        input_shape: shape tuple of a single input sample (no batch dim).

    Returns:
        A Keras Model mapping the input image to class probabilities.
    """
    input_img = layers.Input(shape=input_shape)

    def _conv_caps(tensor, filters, kernel, init, strides=1,
                   proto=None, full=True):
        # One convolutional capsule block: Conv2D (+ BN + ReLU) + Dropout.
        cap = Capsule() if proto is None else \
            Capsule(prototype_distribution=proto)
        cap.add(Conv2D(filters, kernel, strides=strides, padding='same',
                       kernel_initializer=init))
        if full:
            cap.add(BatchNormalization())
            cap.add(Activation('relu'))
        cap.add(Dropout(0.25))
        return cap(tensor)

    # Blocks 1-4: full conv/BN/ReLU/dropout stacks; block 4 downsamples
    # with stride 2 and carries 32 prototypes.
    out = _conv_caps(input_img, 32 + 1, (3, 3), glorot_normal())
    out = _conv_caps(out, 64 + 1, (3, 3), glorot_normal())
    out = _conv_caps(out, 64 + 1, (3, 3), RandomNormal(stddev=0.01))
    out = _conv_caps(out, 64 + 1, (5, 5), RandomNormal(stddev=0.01),
                     strides=2, proto=32)

    # Blocks 5-6: conv + dropout only (no BN/activation).
    out = _conv_caps(out, 32 + 1, (3, 3), RandomNormal(stddev=0.01),
                     full=False)
    out = _conv_caps(out, 64 + 1, (3, 3), glorot_normal(), full=False)

    # Block 7: plain Conv2D, then split the extra (+1) channel off along
    # axis 3; that part is activated and flattened.
    out = Conv2D(64 + 1, (3, 3),
                 padding='same',
                 kernel_initializer=glorot_normal())(out)
    pair = [crop(3, 0, 64)(out), crop(3, 64, 65)(out)]
    pair[1] = Activation('relu')(pair[1])
    pair[1] = Flatten()(pair[1])

    # Caps1: 8x8 prototype grid, unsquared tangent distance.
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(
        InputModule(signal_shape=(16 * 16, 64),
                    dissimilarity_initializer=None,
                    trainable=False))
    caps7.add(TangentDistance(squared_dissimilarity=False,
                              epsilon=1.e-12,
                              linear_factor=0.66,
                              projected_atom_shape=16))
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    routed = caps7(list_to_dict(pair))

    # Caps2: back to a spatial grid, two refinement convs, 4x4 prototypes;
    # signal_output='signals' forwards signals downstream.
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(Conv2D(64, (3, 3),
                     padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3),
                     padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(
        InputModule(signal_shape=(8 * 8, 64),
                    dissimilarity_initializer=None,
                    trainable=False))
    caps8.add(TangentDistance(projected_atom_shape=16,
                              squared_dissimilarity=False,
                              epsilon=1.e-12,
                              linear_factor=0.66,
                              signal_output='signals'))
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    routed = caps8(routed)

    # Caps3: one prototype per class (10 classes), restricted tangent
    # distance, regularized routing, neg_softmax classification.
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3),
                          padding='same',
                          kernel_initializer=glorot_normal()))
    digit_caps.add(
        InputModule(signal_shape=128,
                    dissimilarity_initializer=None,
                    trainable=False))
    digit_caps.add(RestrictedTangentDistance(projected_atom_shape=16,
                                             epsilon=1.e-12,
                                             squared_dissimilarity=False,
                                             linear_factor=0.66,
                                             signal_output='signals'))
    digit_caps.add(
        GibbsRouting(norm_axis='channels',
                     trainable=False,
                     diss_regularizer=MaxDistance(alpha=0.0001)))
    digit_caps.add(
        Classification(probability_transformation='neg_softmax',
                       name='lvq_caps'))
    digitcaps = digit_caps(routed)

    # digitcaps[2] holds the 'lvq_caps' class probabilities.
    return models.Model(input_img, digitcaps[2])