def LvqCapsNet(input_shape, n_class):
    """An LVQ-Capsule network on MNIST with Minkowski distance, neg_exp as the
    probability transformation, and pre-training.
    """
    x = layers.Input(shape=input_shape)

    # Layer 1: just a conventional Conv2D layer
    digitcaps = layers.Conv2D(filters=256, kernel_size=9, strides=1,
                              padding='valid', activation='relu')(x)
    # Layer 2: Conv2D producing the primary capsule signals
    digitcaps = layers.Conv2D(filters=8 * 32, kernel_size=9, strides=2,
                              padding='valid')(digitcaps)

    # LVQ-Capsule
    digit_caps = Capsule(name='digitcaps', prototype_distribution=(1, n_class))
    digit_caps.add(InputModule(signal_shape=8,
                               dissimilarity_initializer='zeros'))
    digit_caps.add(LinearTransformation(output_dim=16, scope='local'))
    diss = MinkowskiDistance(prototype_initializer='zeros',
                             squared_dissimilarity=False,
                             name='minkowski_distance')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels', trainable=False,
                                diss_regularizer='max_distance'))
    digit_caps.add(Classification(probability_transformation='neg_exp',
                                  name='lvq_caps'))
    digitcaps = digit_caps(digitcaps)

    # Decoder network
    y = layers.Input(shape=(n_class,))
    # The true label is used to mask the output of the capsule layer (training).
    masked_by_y = Mask()([digitcaps[0], y])
    # Mask using the capsule with maximal probability (prediction).
    masked = Mask()([digitcaps[0], digitcaps[2]])

    # Shared decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16 * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction)
    train_model = models.Model([x, y], [digitcaps[2], decoder(masked_by_y)])
    eval_model = models.Model(x, [digitcaps[2], decoder(masked)])

    # Manipulate model
    noise = layers.Input(shape=(n_class, 16))
    noised_digitcaps = layers.Add()([digitcaps[0], noise])
    masked_noised_y = Mask()([noised_digitcaps, y])
    manipulate_model = models.Model([x, y, noise], decoder(masked_noised_y))

    return train_model, eval_model, manipulate_model, decoder
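# Usage sketch (not part of the original listing): compiling the returned
# models and probing the manipulate model. The loss pairing is an assumption --
# categorical crossentropy on the capsule probabilities plus a down-weighted
# MSE reconstruction loss is one standard choice; the 0.392 weight and the
# sample arrays below are hypothetical placeholders.
train_model, eval_model, manipulate_model, decoder = LvqCapsNet(
    input_shape=(28, 28, 1), n_class=10)
train_model.compile(optimizer=optimizers.Adam(lr=0.001),
                    loss=['categorical_crossentropy', 'mse'],
                    loss_weights=[1., 0.392])

# The manipulate model adds noise to one capsule's 16D pose, e.g. to visualize
# what each pose dimension encodes (x_sample: [1, 28, 28, 1], y_sample: [1, 10]):
noise = np.zeros((1, 10, 16))
noise[0, :, 3] = 0.25  # perturb pose dimension 3 of every capsule
recon = manipulate_model.predict([x_sample, y_sample, noise])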
def CapsNet(input_shape, n_class, routings):
    """Initialize the CapsNet."""
    x = layers.Input(shape=input_shape)

    # Layer 1: just a conventional Conv2D layer
    digitcaps = layers.Conv2D(filters=256, kernel_size=9, strides=1,
                              padding='valid', activation='relu',
                              name='conv1')(x)

    # Layer 2: Conv2D layer with `squash` activation, then reshape to
    # [None, num_capsule, dim_capsule]
    primary_caps = Capsule(name='PrimaryCaps')
    primary_caps.add(layers.Conv2D(filters=8 * 32, kernel_size=9, strides=2,
                                   padding='valid', name='primarycap_conv2d'))
    primary_caps.add(layers.Reshape(target_shape=[-1, 8],
                                    name='primarycap_reshape'))
    primary_caps.add(layers.Lambda(squash, name='primarycap_squash'))
    digitcaps = primary_caps(digitcaps)

    # Layer 3: capsule layer. The routing algorithm works here.
    digit_caps = Capsule(name='digitcaps', prototype_distribution=(1, n_class))
    digit_caps.add(InputModule(signal_shape=None,
                               init_diss_initializer='zeros',
                               trainable=False))
    digit_caps.add(LinearTransformation(output_dim=16, scope='local'))
    digit_caps.add(DynamicRouting(iterations=routings, name='capsnet'))
    digitcaps = digit_caps(digitcaps)

    # Decoder network
    y = layers.Input(shape=(n_class,))
    # The true label is used to mask the output of the capsule layer (training).
    masked_by_y = Mask()([digitcaps[0], y])
    # Mask using the capsule with maximal length (prediction).
    masked = Mask()([digitcaps[0], digitcaps[2]])

    # Shared decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16 * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction)
    train_model = models.Model([x, y], [digitcaps[2], decoder(masked_by_y)])
    eval_model = models.Model(x, [digitcaps[2], decoder(masked)])

    # Manipulate model
    noise = layers.Input(shape=(n_class, 16))
    noised_digitcaps = layers.Add()([digitcaps[0], noise])
    masked_noised_y = Mask()([noised_digitcaps, y])
    manipulate_model = models.Model([x, y, noise], decoder(masked_noised_y))

    return train_model, eval_model, manipulate_model
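# This CapsNet mirrors the reference dynamic-routing architecture, which is
# trained with the margin loss of Sabour et al. (2017):
#   L_k = T_k * max(0, 0.9 - ||v_k||)^2 + 0.5 * (1 - T_k) * max(0, ||v_k|| - 0.1)^2
# A minimal sketch of that loss and the compile step follows; the backend
# import and the 0.392 reconstruction weight are taken from the common
# reference implementation, not from this listing.
from keras import backend as K  # assumed import

def margin_loss(y_true, y_pred):
    """Margin loss on capsule lengths; y_pred has shape [None, n_class]."""
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) \
        + 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
    return K.mean(K.sum(L, axis=1))

train_model, eval_model, manipulate_model = CapsNet(
    input_shape=(28, 28, 1), n_class=10, routings=3)
train_model.compile(optimizer=optimizers.Adam(lr=0.001),
                    loss=[margin_loss, 'mse'],
                    loss_weights=[1., 0.392])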
x_train = x_train.astype('float32')
y_train = to_categorical(y_train, num_classes).astype('float32')

inputs = Input((x_train.shape[1],))

# define MLVQ prototypes
diss = OmegaDistance(linear_factor=None,
                     squared_dissimilarity=True,
                     matrix_scope='global',
                     matrix_constraint='OmegaNormalization',
                     signal_output='signals')

caps = Capsule(prototype_distribution=(1, num_classes))
caps.add(InputModule(signal_shape=x_train.shape[1],
                     trainable=False,
                     dissimilarity_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(Classification(probability_transformation='flip',
                        name='lvq_capsule'))
outputs = caps(inputs)

# pre-train the model
pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
diss_input = pre_train_model.predict(x_train, batch_size=batch_size)
diss.pre_training(diss_input, y_train, capsule_inputs_are_equal=True)

model = Model(inputs, outputs[2])
model.compile(loss=glvq_loss,
              optimizer=optimizers.Adam(lr=lr),
              metrics={'lvq_capsule': 'accuracy'})
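# For reference, `glvq_loss` implements the classical GLVQ cost
#   mu(x) = (d_plus - d_minus) / (d_plus + d_minus),
# where d_plus / d_minus are the dissimilarities to the best-matching correct /
# incorrect prototype, and mu < 0 means a correct classification. A minimal
# backend sketch over per-class dissimilarities (a hypothetical helper for the
# one-prototype-per-class case above, not the library function; assumes
# `from keras import backend as K`):
def glvq_loss_sketch(y_true, diss):
    """y_true: one-hot labels; diss: per-class dissimilarities [batch, n_class]."""
    d_plus = K.sum(y_true * diss, axis=-1)           # dissimilarity to the correct class
    d_minus = K.min(diss + y_true * 1e9, axis=-1)    # smallest wrong-class dissimilarity
    return (d_plus - d_minus) / (d_plus + d_minus + K.epsilon())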
def LvqCapsNet(input_shape):
    input_img = Input(shape=input_shape)

    # Block 1
    caps0 = Capsule()
    caps0.add(Conv2D(32 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4
    caps3 = Capsule(prototype_distribution=32)
    caps3.add(Conv2D(64 + 1, (5, 5), strides=2, padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5
    caps4 = Capsule()
    caps4.add(Conv2D(32 + 1, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6
    caps5 = Capsule()
    caps5.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7: split the 64 + 1 feature maps into signal and dissimilarity parts
    caps6 = Capsule()
    caps6.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps6.add(SplitModule())
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(InputModule(signal_shape=(-1, 64),
                          init_diss_initializer=None,
                          trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(InputModule(signal_shape=(8 * 8, 64),
                          init_diss_initializer=None,
                          trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128,
                               init_diss_initializer=None,
                               trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels', trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(
        DissimilarityTransformation(probability_transformation='neg_softmax',
                                    name='lvq_caps'))
    digitcaps = digit_caps(x)

    # intermediate model applied to Caps2 outputs; used for visualizations
    input_diss8 = [Input((4, 4, 64)), Input((16,))]
    model_vis_caps2 = models.Model(input_diss8,
                                   digit_caps(list_to_dict(input_diss8)))

    # Decoder network
    y = layers.Input(shape=(10,))
    # The true label is used to mask the output of the capsule layer.
    masked_by_y = Mask()([digitcaps[0], y])

    # Shared decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=128 * 10))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod((28, 28, 1)), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=(28, 28, 1), name='out_recon'))

    # Models for training and evaluation (prediction)
    model = models.Model([input_img, y], [digitcaps[2], decoder(masked_by_y)])

    return model, decoder, model_vis_caps2
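# Training sketch (assumed, not from the original): the decoder reshapes to
# (28, 28, 1), so the network expects MNIST-sized inputs. The model takes
# [image, label] and is trained against [label, image], since the decoder
# reconstructs its input; the loss pairing, weights, and hyperparameters are
# placeholders.
model, decoder, model_vis_caps2 = LvqCapsNet(input_shape=(28, 28, 1))
model.compile(optimizer=optimizers.Adam(lr=0.001),
              loss=['categorical_crossentropy', 'mse'],  # assumed pairing
              loss_weights=[1., 0.392])                  # assumed weight
model.fit([x_train, y_train], [y_train, x_train],
          batch_size=128, epochs=50)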
y_test = to_categorical(y_test.astype('float32'))

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

inputs = Input(x_train.shape[1:])

# define LVQ prototypes
diss = MinkowskiDistance(linear_factor=None,
                         squared_dissimilarity=True,
                         signal_output='signals')

caps = Capsule(prototype_distribution=(protos_per_class, num_classes))
caps.add(InputModule(signal_shape=(1,) + x_train.shape[1:],
                     trainable=False,
                     init_diss_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(NearestCompetition(use_for_loop=False))
caps.add(DissimilarityTransformation(probability_transformation='flip',
                                     name='lvq_capsule'))
outputs = caps(inputs)

# pre-train the model over 10000 random digits (randint's high bound is exclusive)
idx = np.random.randint(0, len(x_train), (min(10000, len(x_train)),))
pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
diss_input = pre_train_model.predict(x_train[idx, :], batch_size=batch_size)
diss.pre_training(diss_input, y_train[idx], capsule_inputs_are_equal=True)
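# Assumed continuation, mirroring the compile step of the other LVQ examples in
# this section (glvq_loss on the probability output, Adam optimizer); `epochs`
# is a hypothetical hyperparameter.
model = Model(inputs, outputs[2])
model.compile(loss=glvq_loss,
              optimizer=optimizers.Adam(lr=lr),
              metrics={'lvq_capsule': 'accuracy'})
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)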
def LvqCapsNet(input_shape):
    input_img = layers.Input(shape=input_shape)

    # Block 1
    caps0 = Capsule()
    caps0.add(Conv2D(32 + 2, (3, 3), padding='same',
                     kernel_initializer=glorot_normal(),
                     input_shape=input_shape))
    caps0.add(BatchNormalization())
    caps0.add(Activation('relu'))
    caps0.add(Dropout(0.25))
    x = caps0(input_img)

    # Block 2
    caps1 = Capsule()
    caps1.add(Conv2D(64 + 2, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(x)

    # Block 3
    caps2 = Capsule()
    caps2.add(Conv2D(64 + 2, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 4
    caps3 = Capsule()
    caps3.add(Conv2D(64 + 2, (5, 5), strides=2, padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 5
    caps4 = Capsule()
    caps4.add(Conv2D(32 + 2, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 6
    caps5 = Capsule()
    caps5.add(Conv2D(64 + 2, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 7
    caps6 = Capsule()
    caps6.add(Conv2D(64 + 2, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps6.add(SplitModule(index=-2))
    caps6.add(Activation('relu'), scope_keys=1)
    caps6.add(Flatten(), scope_keys=1)
    x = caps6(x)

    # Caps1
    caps7 = Capsule(prototype_distribution=(1, 8 * 8), sparse_signal=True)
    caps7.add(InputModule(signal_shape=(-1, 32),
                          init_diss_initializer=None,
                          trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16,
                            matrix_scope='global')
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(x)

    # Caps2
    caps8 = Capsule(prototype_distribution=(1, 4 * 4), sparse_signal=True)
    caps8.add(Reshape((8, 8, 32)))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(InputModule(signal_shape=(8 * 8, 64),
                          init_diss_initializer=None,
                          trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals',
                            matrix_scope='global')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3
    digit_caps = Capsule(prototype_distribution=(1, 5))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128,
                               init_diss_initializer=None,
                               trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels', trainable=False,
                                diss_regularizer=MaxValue(alpha=0.0001)))
    digit_caps.add(
        DissimilarityTransformation(probability_transformation='neg_softmax',
                                    name='lvq_caps'))
    digitcaps = digit_caps(x)

    model = models.Model(input_img, digitcaps[2])
    return model
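# Compile sketch (assumed; the input resolution and loss choice are guesses,
# e.g. 32x32 RGB inputs for the 5-class setup above):
model = LvqCapsNet(input_shape=(32, 32, 3))
model.compile(loss=glvq_loss,
              optimizer=optimizers.Adam(lr=0.001),
              metrics=['accuracy'])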
def LvqCapsNet(input_shape):
    input_img = layers.Input(shape=input_shape)

    # Block 1
    caps1 = Capsule()
    caps1.add(Conv2D(32 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps1.add(BatchNormalization())
    caps1.add(Activation('relu'))
    caps1.add(Dropout(0.25))
    x = caps1(input_img)

    # Block 2
    caps2 = Capsule()
    caps2.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps2.add(BatchNormalization())
    caps2.add(Activation('relu'))
    caps2.add(Dropout(0.25))
    x = caps2(x)

    # Block 3
    caps3 = Capsule()
    caps3.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps3.add(BatchNormalization())
    caps3.add(Activation('relu'))
    caps3.add(Dropout(0.25))
    x = caps3(x)

    # Block 4
    caps4 = Capsule(prototype_distribution=32)
    caps4.add(Conv2D(64 + 1, (5, 5), strides=2, padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps4.add(BatchNormalization())
    caps4.add(Activation('relu'))
    caps4.add(Dropout(0.25))
    x = caps4(x)

    # Block 5
    caps5 = Capsule()
    caps5.add(Conv2D(32 + 1, (3, 3), padding='same',
                     kernel_initializer=RandomNormal(stddev=0.01)))
    caps5.add(Dropout(0.25))
    x = caps5(x)

    # Block 6
    caps6 = Capsule()
    caps6.add(Conv2D(64 + 1, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps6.add(Dropout(0.25))
    x = caps6(x)

    # Block 7: split the 64 + 1 feature maps along the channel axis into the
    # 64 signal channels and one dissimilarity channel.
    x = Conv2D(64 + 1, (3, 3), padding='same',
               kernel_initializer=glorot_normal())(x)
    x = [crop(3, 0, 64)(x), crop(3, 64, 65)(x)]
    x[1] = Activation('relu')(x[1])
    x[1] = Flatten()(x[1])

    # Caps1
    caps7 = Capsule(prototype_distribution=(1, 8 * 8))
    caps7.add(InputModule(signal_shape=(16 * 16, 64),
                          dissimilarity_initializer=None,
                          trainable=False))
    diss7 = TangentDistance(squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            projected_atom_shape=16)
    caps7.add(diss7)
    caps7.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps7(list_to_dict(x))

    # Caps2
    caps8 = Capsule(prototype_distribution=(1, 4 * 4))
    caps8.add(Reshape((8, 8, 64)))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(Conv2D(64, (3, 3), padding='same',
                     kernel_initializer=glorot_normal()))
    caps8.add(InputModule(signal_shape=(8 * 8, 64),
                          dissimilarity_initializer=None,
                          trainable=False))
    diss8 = TangentDistance(projected_atom_shape=16,
                            squared_dissimilarity=False,
                            epsilon=1.e-12,
                            linear_factor=0.66,
                            signal_output='signals')
    caps8.add(diss8)
    caps8.add(GibbsRouting(norm_axis='channels', trainable=False))
    x = caps8(x)

    # Caps3
    digit_caps = Capsule(prototype_distribution=(1, 10))
    digit_caps.add(Reshape((4, 4, 64)))
    digit_caps.add(Conv2D(128, (3, 3), padding='same',
                          kernel_initializer=glorot_normal()))
    digit_caps.add(InputModule(signal_shape=128,
                               dissimilarity_initializer=None,
                               trainable=False))
    diss = RestrictedTangentDistance(projected_atom_shape=16,
                                     epsilon=1.e-12,
                                     squared_dissimilarity=False,
                                     linear_factor=0.66,
                                     signal_output='signals')
    digit_caps.add(diss)
    digit_caps.add(GibbsRouting(norm_axis='channels', trainable=False,
                                diss_regularizer=MaxDistance(alpha=0.0001)))
    digit_caps.add(Classification(probability_transformation='neg_softmax',
                                  name='lvq_caps'))
    digitcaps = digit_caps(x)

    model = models.Model(input_img, digitcaps[2])
    return model
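# The `crop` helper used in Block 7 is not part of this listing. A typical
# Keras idiom wraps tensor slicing in a Lambda layer; a minimal sketch assuming
# exactly that behavior (and `from keras.layers import Lambda`):
def crop(dimension, start, end):
    """Return a Lambda layer slicing its 4D input along `dimension` as [start:end]."""
    def func(x):
        slices = [slice(None)] * 4              # batch axis + 3 feature axes
        slices[dimension] = slice(start, end)
        return x[tuple(slices)]
    return Lambda(func)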
x_train, y_train = tecator.load_data()
x_train = x_train.astype('float32')
y_train = to_categorical(y_train, num_classes).astype('float32')

inputs = Input((x_train.shape[1],))

# define MLVQ prototypes
diss = OmegaDistance(linear_factor=None,
                     squared_dissimilarity=True,
                     matrix_scope='global',
                     matrix_constraint='OmegaNormalization',
                     signal_output='signals')

caps = Capsule(prototype_distribution=(1, num_classes))
caps.add(InputModule(signal_shape=x_train.shape[1],
                     trainable=False,
                     init_diss_initializer='zeros'))
caps.add(diss)
caps.add(SqueezeRouting())
caps.add(DissimilarityTransformation(probability_transformation='flip',
                                     name='lvq_capsule'))
outputs = caps(inputs)

# pre-train the model
pre_train_model = Model(inputs=inputs, outputs=diss.input[0])
diss_input = pre_train_model.predict(x_train, batch_size=batch_size)
diss.pre_training(diss_input, y_train, capsule_inputs_are_equal=True)

model = Model(inputs, outputs[2])
model.compile(loss=glvq_loss,
              optimizer=optimizers.Adam(lr=lr),
              metrics={'lvq_capsule': 'accuracy'})
model.summary()
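# Assumed continuation: fitting on the full Tecator set and reporting the
# training accuracy; `epochs` is a hypothetical hyperparameter.
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=2)
loss, acc = model.evaluate(x_train, y_train, batch_size=batch_size)
print('train accuracy: {:.4f}'.format(acc))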