コード例 #1
0
    def output_capsules(self, rc_capsules, contrib):
        '''compute the output capsules

        Passes the rate coded capsules through self.conf['num_rc_layers']
        capsule layers. Each layer divides the number of capsules by
        self.conf['capsule_ratio'] (floor division) and multiplies the capsule
        dimension by the same ratio. The per-timestep contributions are pushed
        through every layer as well, weighted by that layer's final routing
        weights and squash factor. At the end they are compared (dot product
        over the capsule dimension) with the output capsules to produce the
        alignment.

        args:
            rc_capsules: the rate coded capsules
                [batch_size x num_capsules x capsule_dim]
            contrib: the contribution of each timestep in the rc capsules
                [batch_size x time x num_capsules x capsule_dim]

        returns:
            the output_capsules [batch_size x num_capsules x capsule_dim]
            the alignment of the output capsules to the timesteps
                [batch_size x time x num_capsules]
        '''

        with tf.variable_scope('output_capsules'):

            capsules = tf.identity(rc_capsules, 'rc_capsules')
            contrib = tf.identity(contrib, 'contrib')
            num_capsules = capsules.shape[1].value
            capsule_dim = capsules.shape[2].value

            for layer_index in range(int(self.conf['num_rc_layers'])):
                with tf.variable_scope('layer%d' % layer_index):

                    ratio = int(self.conf['capsule_ratio'])
                    # floor division keeps num_capsules an integer; `/=`
                    # yields a float under Python 3, which would then be
                    # passed to layers.Capsule as the capsule count
                    num_capsules //= ratio
                    capsule_dim *= ratio

                    layer = layers.Capsule(
                        num_capsules=num_capsules,
                        capsule_dim=capsule_dim,
                        routing_iters=int(self.conf['routing_iters']))

                    capsules = layer(capsules)

                    # get the predictions for the contributions
                    contrib_predict = layer.predict(contrib)

                    # NOTE(review): the two lookups below read internal
                    # tensors of layers.Capsule by hard-coded graph names;
                    # they break if that layer's routing implementation
                    # changes its scope/op structure.
                    graph = tf.get_default_graph()

                    # get the final routing logits
                    logits = graph.get_tensor_by_name(
                        layer.scope_name + '/cluster/while/Exit_1:0')

                    # get the final squash factor
                    sf = graph.get_tensor_by_name(
                        layer.scope_name + '/cluster/squash/div_1:0')

                    # routing weights, scaled by the (transposed) squash factor
                    weights = layer.probability_fn(logits)
                    weights *= tf.transpose(sf, [0, 2, 1])

                    # insert singleton axes 1 and 4 so the weights broadcast
                    # against contrib_predict — presumably shaped
                    # [batch x time x in_caps x out_caps x out_dim];
                    # TODO confirm against layers.Capsule.predict
                    weights = tf.expand_dims(tf.expand_dims(weights, 1), 4)

                    # sum out axis 2 (presumably the input-capsule axis)
                    contrib = tf.reduce_sum(contrib_predict * weights, 2)

            # dot product over the capsule dimension between each timestep's
            # contribution and the final output capsule
            alignment = tf.reduce_sum(contrib * tf.expand_dims(capsules, 1),
                                      3,
                                      name='alignment')
            capsules = tf.identity(capsules, 'output_capsules')

        return capsules, alignment
コード例 #2
0
ファイル: simple_model.py プロジェクト: zx-/CapsNet
def simple_caps_net(inputs):
    """Build a small capsule network on top of `inputs`.

    Pipeline: add a trailing channel axis -> 9x9 Conv2D (256 filters, ReLU,
    'valid' padding, channels_last) -> PrimaryCaps (8 conv units, 32
    channels, kernel 9, stride 2) -> Capsule layer with 10 capsules of 16
    units each.

    NOTE(review): `inputs` presumably has no channel axis (e.g.
    [batch x height x width]) since one is appended here — confirm at callers.

    Returns:
        The output of the final caps_layers.Capsule layer.
    """
    # Conv2D with channels_last needs an explicit channel dimension.
    images = tf.expand_dims(inputs, -1)

    conv = tf.keras.layers.Conv2D(
        filters=256,
        kernel_size=9,
        padding='valid',
        data_format='channels_last',
        activation='relu',
    )
    feature_maps = conv(images)

    primary = caps_layers.PrimaryCaps(
        conv_units=8,
        channels=32,
        kernel_size=9,
        strides=2,
    )
    primary_capsules = primary(feature_maps)

    output_caps = caps_layers.Capsule(capsules=10, capsule_units=16)
    return output_caps(primary_capsules)