Example #1
0
def multi_caps_layer(inputs, out_caps, pool, iter_num, log):
    """Stack several capsule layers, each followed by a routing/pooling step.

    Args:
        inputs: Input capsules. The original author annotated the layout as
            [bs, caps_in, atoms]; this function does not check it.
        out_caps: Sequence giving the number of output capsules of each
            stacked layer, in order.
        pool: Routing/pooling method, one of 'dynamic', 'EM' or 'FM'.
        iter_num: Number of routing iterations, passed to DynamicRouting or
            EMRouting (not used by the 'FM' branch).
        log: Object exposing ``add_hist(name, tensor)``. It receives one
            'prediction_capsK' and one 'probK' entry per layer (K from 1).

    Returns:
        (poses, probs) from the last layer's pooling step, or the squashed
        inputs when ``out_caps`` is empty.

    Raises:
        ValueError: if ``pool`` is not one of the supported methods.
    """
    # Validate before building any layer. Previously an unknown `pool`
    # silently skipped routing: the freshly built transform was discarded and
    # the previous layer's poses/probs were returned as if they were routed.
    if pool not in ('dynamic', 'EM', 'FM'):
        raise ValueError(
            "Unsupported pool {!r}; expected 'dynamic', 'EM' or 'FM'".format(pool))

    # inputs [bs, caps_in, atoms]
    poses, probs = layers.Activation('squash', with_prob=True)(inputs)
    for i, out_num in enumerate(out_caps):
        # NOTE(review): `kernel_regularizer` is not defined in this function;
        # presumably a module-level name -- confirm.
        prediction_caps = layers.CapsuleTransformDense(
            num_out=out_num,
            matrix=True,
            out_atom=0,
            share_weights=False,
            regularizer=kernel_regularizer)(poses)
        prediction_caps = keras.layers.BatchNormalization()(prediction_caps)
        log.add_hist('prediction_caps{}'.format(i + 1), prediction_caps)
        if pool == 'dynamic':
            poses, probs = layers.DynamicRouting(
                num_routing=iter_num,
                softmax_in=False,
                temper=1,
                activation='squash',
                pooling=False)(prediction_caps)
        elif pool == 'EM':
            # EM routing also consumes the previous layer's probabilities.
            poses, probs = layers.EMRouting(num_routing=iter_num)(
                (prediction_caps, probs))
        elif pool == 'FM':
            prediction_caps = layers.Activation('norm')(prediction_caps)
            poses, probs = layers.LastFMPool(
                axis=-3,
                activation='squash',
                shrink=False,
                stable=False,
                regularize=True,
                # Only the final stacked layer normalizes its poses.
                norm_pose=(i == len(out_caps) - 1),
                log=None)(prediction_caps)

        log.add_hist('prob{}'.format(i + 1), probs)
    return poses, probs
Example #2
0
def multi_caps_layer(inputs, out_caps, log):
    """Stack capsule layers, each pooled by LastFMPool in 'accumulate' mode.

    Args:
        inputs: Input capsules. They were originally annotated as
            [bs, caps_in, atoms]; the layout is not checked here.
        out_caps: Number of output capsules for each stacked layer, in order.
        log: Object exposing ``add_hist(name, tensor)``. It receives one
            'prediction_capsK' and one 'probK' entry per layer (K from 1).

    Returns:
        (poses, probs) from the final layer's pooling, or the squashed inputs
        when ``out_caps`` is empty.
    """
    poses, probs = layers.Activation('squash', with_prob=True)(inputs)

    for layer_no, out_num in enumerate(out_caps, start=1):
        # NOTE(review): `kernel_regularizer` is a free name here; presumably
        # defined at module level -- confirm.
        transform = layers.CapsuleTransformDense(
            num_out=out_num,
            matrix=True,
            out_atom=0,
            share_weights=False,
            regularizer=kernel_regularizer,
        )
        transformed = transform(poses)
        prediction_caps = keras.layers.BatchNormalization()(transformed)
        log.add_hist('prediction_caps{}'.format(layer_no), prediction_caps)

        prediction_caps = layers.Activation('norm')(prediction_caps)
        fm_pool = layers.LastFMPool(
            axis=-3,
            activation='accumulate',
            shrink=False,
            stable=False,
            regularize=True,
            # Only the last stacked layer normalizes its poses.
            norm_pose=(layer_no == len(out_caps)),
            log=None,
        )
        poses, probs = fm_pool(prediction_caps)

        log.add_hist('prob{}'.format(layer_no), probs)

    return poses, probs
Example #3
0
def build(inputs, num_out, iter_num, pool, in_norm_fn):
    """Build a conv -> primary-capsule -> capsule-transform -> routing head.

    Args:
        inputs: Input tensor fed to a 9x9, 64-filter 'valid' Conv2D with ReLU.
        num_out: Number of output capsules of the dense capsule transform
            (each with 16 output atoms).
        iter_num: Routing iterations for DynamicRouting / EMRouting (not used
            by the 'FM' branch).
        pool: Routing/pooling method, one of 'dynamic', 'EM' or 'FM'.
        in_norm_fn: Activation passed to the PrimaryCapsule layer.

    Returns:
        (pose, prob, log): the routing step's outputs and the TensorLog that
        recorded the 'prob' histogram. The FM pool also receives this log.

    Raises:
        ValueError: if ``pool`` is not one of the supported methods.
    """
    # Validate before building any layer. Previously an unknown `pool`
    # silently discarded `transformed_caps` and returned the PrimaryCapsule
    # pose/prob as if they were routed.
    if pool not in ('dynamic', 'EM', 'FM'):
        raise ValueError(
            "Unsupported pool {!r}; expected 'dynamic', 'EM' or 'FM'".format(pool))

    log = utils.TensorLog()
    conv1 = keras.layers.Conv2D(filters=64,
                                kernel_size=9,
                                strides=1,
                                padding='valid',
                                activation='relu')(inputs)
    pose, prob = layers.PrimaryCapsule(
        kernel_size=9,
        strides=2,
        padding='valid',
        groups=8,
        use_bias=True,
        atoms=8,
        activation=in_norm_fn,
        kernel_initializer=keras.initializers.he_normal())(conv1)
    transformed_caps = layers.CapsuleTransformDense(
        num_out=num_out,
        out_atom=16,
        share_weights=False,
        initializer=keras.initializers.glorot_normal())(pose)
    if pool == 'dynamic':
        pose, prob = layers.DynamicRouting(num_routing=iter_num,
                                           softmax_in=False,
                                           temper=1,
                                           activation='squash',
                                           pooling=False)(transformed_caps)
    elif pool == 'EM':
        # EM routing also consumes the primary capsules' probabilities.
        pose, prob = layers.EMRouting(num_routing=iter_num)(
            (transformed_caps, prob))
    elif pool == 'FM':
        pose, prob = layers.LastFMPool(axis=-3,
                                       activation='accumulate',
                                       shrink=True,
                                       stable=False,
                                       log=log)(transformed_caps)

    log.add_hist('prob', prob)
    return pose, prob, log