コード例 #1
0
ファイル: model.py プロジェクト: zhen8838/Caps-Net-tf
    def routing(self, inputs: tf.Tensor) -> tf.Tensor:
        """ Dynamic Routing

        Parameters
        ----------
        inputs: tf.Tensor
            shape must be[batch, in_caps, out_cpas, out_len]

        Returns
        -------
        tf.Tensor
            outputs, shape[batch, out_caps, out_len]
        """
        # b = tf.constant(np.zeros(shape=(inputs.shape[1], inputs.shape[2])), name='b')  # b [in_caps,out_caps]
        with tf.variable_scope('routing'):
            with tf.variable_scope('iter_1'):
                c = tf.nn.softmax(self.b)  # c [in_caps,out_caps]
                s = tf.einsum('jk,ijkq->ikq', c,
                              inputs)  # s [batch,out_caps,out_len]
                v = squash(s)  # v [batch,out_caps,out_len]
                self.b = tf.add(self.b, tf.einsum('iqjk,ijk->qj', inputs, v))
            with tf.variable_scope('iter_2'):
                c = tf.nn.softmax(self.b)
                s = tf.einsum('jk,ijkq->ikq', c, inputs)
                v = squash(s)
                self.b = tf.add(self.b, tf.einsum('iqjk,ijk->qj', inputs, v))
            with tf.variable_scope('iter_3'):
                c = tf.nn.softmax(self.b)
                s = tf.einsum('jk,ijkq->ikq', c, inputs)
                if self.use_bias:
                    s = tf.add(s, self.bias)
                v = squash(s)
        return v
コード例 #2
0
ファイル: model.py プロジェクト: zhen8838/Caps-Net-tf
 def call(self, inputs):
     # inputs [batch,caps,vec_len]
     if self.use_routing:
         u_hat = tf.einsum('jpkq,ijk->ijpq', self.W,
                           inputs)  # u_hat [batch,in_caps,out_caps,out_len]
         outputs = self.routing(u_hat)
     else:
         outputs = tf.einsum('jpkq,ijk->ipq', W,
                             inputs)  # outputs [batch,out_caps,out_len]
         if self.use_bias:
             outputs = tf.add(outputs, bias)
         if self.activation is not None:
             outputs = self.activation(u_hat)
     return outputs
コード例 #3
0
ファイル: model.py プロジェクト: zhen8838/Caps-Net-tf
def decoder(caps_out, y):
    with tf.variable_scope('Decoder'):
        mask = tf.einsum('ijk,ij->ik', caps_out, y)
        fc1 = keras.layers.Dense(units=512)(mask)
        fc2 = keras.layers.Dense(units=1024)(fc1)
        decoded = keras.layers.Dense(units=784, activation=tf.nn.sigmoid)(fc2)
    return decoded
コード例 #4
0
ファイル: test_fuc.py プロジェクト: zhen8838/Caps-Net-tf
def test_u_hat_c():
    batch = 16
    tf.set_random_seed(1)
    b = tf.get_variable('b',
                        shape=(1152, 10),
                        dtype=tf.float32,
                        initializer=tf.initializers.random_normal())
    c = tf.nn.softmax(b, name='c')
    u_hat = tf.get_variable('u_hat',
                            shape=(batch, 1152, 10, 16),
                            dtype=tf.float32,
                            initializer=tf.initializers.random_normal())
    s = tf.einsum('jk,ijkq->ikq', c, u_hat)
    print(s.shape)
    assert s.shape == [16, 10, 16]
コード例 #5
0
ファイル: test_fuc.py プロジェクト: zhen8838/Caps-Net-tf
def test_u_hat():
    batch = 16
    tf.set_random_seed(1)
    W = tf.get_variable('W',
                        shape=(1152, 10, 8, 16),
                        dtype=tf.float32,
                        initializer=tf.initializers.random_normal())
    u = tf.get_variable('u',
                        shape=(batch, 1152, 8),
                        dtype=tf.float32,
                        initializer=tf.initializers.random_normal())

    u_hat = tf.einsum('jpkq,ijk->ijpq', W, u)
    print(u_hat.shape)
    assert u_hat.shape == [16, 1152, 10, 16]
コード例 #6
0
ファイル: test_fuc.py プロジェクト: zhen8838/Caps-Net-tf
def test_update_b():
    batch = 16
    tf.set_random_seed(1)
    b = tf.get_variable('b',
                        shape=(1152, 10),
                        dtype=tf.float32,
                        initializer=tf.initializers.random_normal())
    u_hat = tf.get_variable('u_hat',
                            shape=(batch, 1152, 10, 16),
                            dtype=tf.float32,
                            initializer=tf.initializers.random_normal())
    v = tf.get_variable('v',
                        shape=(batch, 10, 16),
                        dtype=tf.float32,
                        initializer=tf.initializers.random_normal())

    delat = tf.einsum('iqjk,ijk->qj', u_hat, v)
    b = b + delat
    print(b.shape)
    assert b.shape == [1152, 10]
コード例 #7
0
ファイル: test_fuc.py プロジェクト: zhen8838/Caps-Net-tf
def test_einsum(W: tf.Tensor, u: tf.Tensor) -> tf.Tensor:
    return tf.einsum('ij,aki->akj', W, u)