def build(inputs, num_out, atoms):
    """Build a ResNet-v2-backed capsule head.

    Returns (poses, probs, log): per-class capsule poses, their activation
    probabilities, and the TensorLog collecting histograms.
    """
    tensor_log = utils.TensorLog()
    # ResNet-v2 feature extractor: three stages of 8 residual blocks,
    # CIFAR layout, starting at 16 filters.
    features = res_blocks.build_resnet_backbone(
        inputs=inputs,
        layer_num=0,
        repetitions=[8, 8, 8],
        start_filters=16,
        arch='cifar',
        use_bias=False,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        bn_axis=-1,
        momentum=BATCH_NORM_DECAY,
        epsilon=BATCH_NORM_EPSILON,
        version='v2')
    tensor_log.add_hist('backbone', features)
    # Turn backbone features into primary capsules, then batch-normalize.
    primary_caps = layers.PrimaryCapsule(
        kernel_size=5,
        strides=2,
        padding='same',
        groups=4,
        atoms=atoms,
        activation=None,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer)(features)
    primary_caps = keras.layers.BatchNormalization()(primary_caps)
    # Capsule stack narrowing 32 -> 16 -> num_out output capsules.
    capsule_poses, capsule_probs = multi_caps_layer(
        primary_caps, [32, 16, num_out], tensor_log)
    return capsule_poses, capsule_probs, tensor_log
def build(inputs, num_out, arch, method, m, iter):
    """Build a VGG-style classifier with a selectable normalization layer.

    `arch` is a list mixing 'M' (max-pool stage) with integer conv filter
    counts; `method` selects 'bn', 'zca', or 'iter_norm' normalization.
    Returns (logits, log).

    NOTE(review): parameter `iter` shadows the builtin; kept as-is because
    renaming would break keyword callers.
    """
    log = utils.TensorLog()
    feature = inputs
    for idx, spec in enumerate(arch):
        if spec == 'M':
            # 'M' entries insert a 2x2 max-pooling stage.
            feature = keras.layers.MaxPool2D(pool_size=(2, 2),
                                             strides=(2, 2))(feature)
            continue
        # Otherwise `spec` is the conv filter count for a 3x3 conv block.
        feature = keras.layers.Conv2D(filters=spec,
                                      kernel_size=3,
                                      strides=1,
                                      padding='same',
                                      use_bias=False,
                                      kernel_initializer=kernel_initializer,
                                      kernel_regularizer=kernel_regularizer)(feature)
        if method == 'bn':
            feature = keras.layers.BatchNormalization()(feature)
        elif method == 'zca':
            feature = normalization.DecorelationNormalization(
                m_per_group=m, decomposition='zca_wm')(feature)
        elif method == 'iter_norm':
            feature = normalization.DecorelationNormalization(
                m_per_group=m,
                decomposition='iter_norm_wm',
                iter_num=iter)(feature)
        log.add_hist('bn{}'.format(idx + 1), feature)
        feature = keras.layers.ReLU()(feature)
    # Global pooling head followed by a linear classifier.
    feature = keras.layers.AveragePooling2D(pool_size=(2, 2))(feature)
    feature = keras.layers.Flatten()(feature)
    output = keras.layers.Dense(num_out)(feature)
    return output, log
def build(inputs, num_out, iter_num, pool, in_norm_fn):
    """CapsNet encoder: conv stem -> primary capsules -> transform -> routing.

    `pool` selects the routing scheme ('dynamic', 'EM', or 'FM').
    Returns (pose, prob, log).
    """
    log = utils.TensorLog()
    stem = keras.layers.Conv2D(filters=64,
                               kernel_size=9,
                               strides=1,
                               padding='valid',
                               activation='relu')(inputs)
    pose, prob = layers.PrimaryCapsule(
        kernel_size=9,
        strides=2,
        padding='valid',
        groups=8,
        use_bias=True,
        atoms=8,
        activation=in_norm_fn,
        kernel_initializer=keras.initializers.he_normal())(stem)
    # Dense capsule transform producing the per-class votes.
    votes = layers.CapsuleTransformDense(
        num_out=num_out,
        out_atom=16,
        share_weights=False,
        initializer=keras.initializers.glorot_normal())(pose)
    # Routing/pooling dispatch; any unrecognized `pool` value leaves the
    # primary-capsule (pose, prob) pair unchanged.
    if pool == 'dynamic':
        pose, prob = layers.DynamicRouting(num_routing=iter_num,
                                           softmax_in=False,
                                           temper=1,
                                           activation='squash',
                                           pooling=False)(votes)
    elif pool == 'EM':
        pose, prob = layers.EMRouting(num_routing=iter_num)((votes, prob))
    elif pool == 'FM':
        pose, prob = layers.LastFMPool(axis=-3,
                                       activation='accumulate',
                                       shrink=True,
                                       stable=False,
                                       log=log)(votes)
    log.add_hist('prob', prob)
    return pose, prob, log
def build_model(shape, num_out, params):
    """Assemble the capsule encoder and reconstruction decoder into a model.

    Returns (model, model_log, encoder, decoder).

    NOTE(review): mixes `params.caps.atoms`/`params.recons.*` with the
    module-level `cfg.caps.atoms`/`cfg.recons.*` — confirm both configs
    always agree, otherwise decoder input shapes can diverge from the
    encoder's output.
    """
    optimizer = keras.optimizers.Adam(0.0001)
    inputs = keras.Input(shape=shape)
    model_name = build_model_name(params)
    model_log = utils.TensorLog()
    # Capsule encoder: per-class pose vectors plus activation probabilities.
    pose, prob = build_encoder(inputs, num_out, params.caps.atoms,
                               params.routing.iter_num, params.model.pool,
                               model_log)
    encoder = keras.Model(inputs=inputs, outputs=(pose, prob), name='encoder')
    encoder.compile(optimizer=optimizer, metrics=[])
    encoder.summary()
    # Placeholder inputs so the decoder can be built (and later used) as a
    # stand-alone model, independent of the encoder graph.
    labels = keras.Input(shape=(num_out, ))
    in_pose = keras.Input(shape=(num_out, cfg.caps.atoms))
    in_prob = keras.Input(shape=(num_out, ))
    inputs_shape = inputs.get_shape().as_list()
    # Mask keeps only the capsule selected by the label (order=0).
    active_cap = layers.Mask(order=0, share=cfg.recons.share,
                             out_num=num_out)((in_pose, in_prob, labels))
    if cfg.recons.conv:
        decoder_layer = layers.DecoderConv(
            height=inputs_shape[1],
            width=inputs_shape[2],
            channel=inputs_shape[3],
            balance_factor=params.recons.balance_factor,
            base=10)
    else:
        decoder_layer = layers.Decoder(
            height=inputs_shape[1],
            width=inputs_shape[2],
            channel=inputs_shape[3],
            balance_factor=params.recons.balance_factor,
            layers=[512, 1024])
    recons_loss, recons_img = decoder_layer((active_cap, inputs))
    decoder = keras.Model(inputs=(in_pose, in_prob, inputs, labels),
                          outputs=recons_img, name='decoder')
    decoder.compile(optimizer=optimizer, metrics=[])
    decoder.summary()
    # Re-wire the SAME decoder layer (shared weights) onto the live encoder
    # outputs for the end-to-end training graph.
    active_cap = layers.Mask(order=0, share=cfg.recons.share,
                             out_num=num_out)((pose, prob, labels))
    recons_loss, recons_img = decoder_layer((active_cap, inputs))
    model_log.add_scalar('reconstruction_loss', recons_loss)
    # Log input and reconstruction stacked along axis 1 for visual comparison.
    image_out = tf.concat([inputs, recons_img], 1)
    model_log.add_image('recons_img', image_out)
    model = keras.Model(inputs=(inputs, labels), outputs=(prob, recons_img),
                        name=model_name)
    model.compile(optimizer=optimizer,
                  loss=keras.losses.CategoricalCrossentropy(from_logits=False),
                  metrics=[])
    model.summary()
    # No LR-scheduler callbacks are attached by default.
    model.callbacks = []
    # Side model exposing every logged tensor for summary writing.
    log_model = keras.Model(inputs=(inputs, labels),
                            outputs=model_log.get_outputs(),
                            name='model_log')
    model_log.set_model(log_model)
    return model, model_log, encoder, decoder
def build_model(shape, num_out, params):
    """Assemble a two-image capsule model with a shared reconstruction decoder.

    The encoder runs on `inputs`; two additional image/label pairs drive two
    maskings of the same capsule output, each reconstructed by one shared
    decoder layer. Returns (model, model_log, encoder, decoder).

    NOTE(review): mixes `params.*` with the module-level `cfg.*` — confirm
    the two configs always agree.
    """
    optimizer = keras.optimizers.Adam(0.0001)
    inputs = keras.Input(shape=shape)
    model_name = build_model_name(params)
    model_log = utils.TensorLog()
    pose, prob = build_encoder(inputs, num_out, params.caps.atoms, model_log)
    encoder = keras.Model(inputs=inputs, outputs=(pose, prob), name='encoder')
    encoder.compile(optimizer=optimizer, metrics=[])
    encoder.summary()
    # Placeholder inputs for the stand-alone decoder graph: two images with
    # their one-hot labels, plus externally supplied capsule pose/prob.
    image1 = keras.Input(shape=shape)
    image2 = keras.Input(shape=shape)
    label1 = keras.Input(shape=(num_out,))
    label2 = keras.Input(shape=(num_out,))
    in_pose = keras.Input(shape=(num_out, cfg.caps.atoms))
    in_prob = keras.Input(shape=(num_out,))
    inputs_shape = image1.get_shape().as_list()
    # order=0 / order=1 select the capsule for the first / second label.
    active_cap1 = layers.Mask(order=0,
                              share=cfg.recons.share)((in_pose, in_prob, label1))
    active_cap2 = layers.Mask(order=1,
                              share=cfg.recons.share)((in_pose, in_prob, label2))
    if cfg.recons.conv:
        decoder_layer = layers.DecoderConv(
            height=inputs_shape[1],
            width=inputs_shape[2],
            channel=inputs_shape[3],
            balance_factor=params.recons.balance_factor,
            base=9,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer)
    else:
        decoder_layer = layers.Decoder(
            height=inputs_shape[1],
            width=inputs_shape[2],
            channel=inputs_shape[3],
            balance_factor=params.recons.balance_factor,
            layers=[512, 1024])
    # One decoder layer reconstructs both images (shared weights).
    recons_loss1, recons_img1 = decoder_layer((active_cap1, image1))
    recons_loss2, recons_img2 = decoder_layer((active_cap2, image2))
    decoder = keras.Model(
        inputs=(in_pose, in_prob, image1, image2, label1, label2),
        outputs=(recons_img1, recons_img2), name='decoder')
    decoder.compile(optimizer=optimizer, metrics=[])
    decoder.summary()
    # Re-wire the shared decoder onto the live encoder outputs for training.
    active_cap1 = layers.Mask(order=0,
                              share=cfg.recons.share)((pose, prob, label1))
    active_cap2 = layers.Mask(order=1,
                              share=cfg.recons.share)((pose, prob, label2))
    recons_loss1, recons_img1 = decoder_layer((active_cap1, image1))
    recons_loss2, recons_img2 = decoder_layer((active_cap2, image2))
    recons_loss = recons_loss1 + recons_loss2
    model_log.add_scalar('reconstruction_loss', recons_loss)
    # Visualization: put the two reconstructions in separate channels and
    # stack under the (tiled) original. Assumes single-channel inputs so the
    # 1->3 channel tile matches — TODO confirm against the dataset.
    image_recons = tf.concat([tf.zeros_like(recons_img1), recons_img1,
                              recons_img2], axis=-1)
    image_merge_ori = tf.tile(inputs, multiples=[1, 1, 1, 3])
    image_merge = tf.concat([image_merge_ori, image_recons], axis=1)
    model_log.add_image('recons_img', image_merge)
    model = keras.Model(inputs=(inputs, image1, image2, label1, label2),
                        outputs=(prob, recons_img1, recons_img2),
                        name=model_name)
    model.compile(optimizer=optimizer,
                  loss=losses.MarginLoss(False, 0.9, 0.1, 0.5),
                  metrics=[])
    model.summary()
    # No LR-scheduler callbacks are attached by default.
    model.callbacks = []
    # Side model exposing every logged tensor for summary writing.
    log_model = keras.Model(inputs=(inputs, image1, image2, label1, label2),
                            outputs=model_log.get_outputs(),
                            name='model_log')
    model_log.set_model(log_model)
    return model, model_log, encoder, decoder
def test_build():
    """Smoke-test the encoder graph on a random CIFAR-sized batch."""
    tf.keras.backend.set_learning_phase(1)
    images = tf.random.normal([128, 32, 32, 1])
    # One-hot labels are built (consuming the same RNG stream as before)
    # but not fed to the encoder.
    class_ids = tf.random.uniform([128, ], 0, 5, tf.int32)
    one_hot_labels = tf.one_hot(class_ids, 5)
    # 5 classes, 16 atoms, 3 routing iterations, FM pooling.
    _ = build_encoder(images, 5, 16, 3, 'FM', utils.TensorLog())