def routing(self, inputs: tf.Tensor, num_iterations: int = 3) -> tf.Tensor:
    """
    Dynamic Routing

    Iteratively refines the coupling between input and output capsules.
    Each iteration turns the routing logits into coupling coefficients with
    a softmax, forms the weighted sum of the prediction vectors, squashes
    it, and (except on the last iteration) adds the agreement between the
    predictions and the squashed output back onto the logits.

    The logits start from ``self.b`` but are tracked in a local variable, so
    ``self.b`` is left untouched and repeated calls do not compound.

    Parameters
    ----------
    inputs: tf.Tensor
        shape must be [batch, in_caps, out_caps, out_len]
    num_iterations: int
        number of routing iterations, must be >= 1 (default 3)

    Returns
    -------
    tf.Tensor
        outputs, shape [batch, out_caps, out_len]

    Raises
    ------
    ValueError
        if ``num_iterations`` is smaller than 1
    """
    if num_iterations < 1:
        raise ValueError(
            'num_iterations must be >= 1, got {}'.format(num_iterations))

    # Keep the evolving logits local; rebinding self.b would replace the
    # layer attribute with a graph tensor as a side effect of the call.
    b = self.b  # b [in_caps,out_caps]
    v = None
    with tf.variable_scope('routing'):
        for step in range(1, num_iterations + 1):
            is_last = step == num_iterations
            with tf.variable_scope('iter_{}'.format(step)):
                # softmax runs over the last axis, i.e. over out_caps
                c = tf.nn.softmax(b)  # c [in_caps,out_caps]
                # s [batch,out_caps,out_len]
                s = tf.einsum('jk,ijkq->ikq', c, inputs)
                # The bias is only added before the final squash.
                if is_last and self.use_bias:
                    s = tf.add(s, self.bias)
                v = squash(s)  # v [batch,out_caps,out_len]
                if not is_last:
                    # Agreement <u_hat, v>, summed over batch and out_len.
                    agreement = tf.einsum('iqjk,ijk->qj', inputs, v)
                    b = tf.add(b, agreement)
    return v
def call(self, inputs):
    """
    Forward pass of the capsule layer.

    With ``self.use_routing`` the per-pair prediction vectors are computed
    from ``self.W`` and handed to ``self.routing``. Otherwise the
    predictions are summed over the input capsules directly, then the
    optional bias and activation are applied.

    Parameters
    ----------
    inputs: tf.Tensor
        shape [batch, caps, vec_len]

    Returns
    -------
    tf.Tensor
        outputs, shape [batch, out_caps, out_len]
    """
    if self.use_routing:
        # u_hat [batch,in_caps,out_caps,out_len]
        u_hat = tf.einsum('jpkq,ijk->ijpq', self.W, inputs)
        return self.routing(u_hat)

    # NOTE(review): the original nesting was lost; bias and activation are
    # applied on this path only, since routing() already adds the bias and
    # squashes its result - confirm this matches the intended behaviour.
    # outputs [batch,out_caps,out_len]
    outputs = tf.einsum('jpkq,ijk->ipq', self.W, inputs)
    if self.use_bias:
        outputs = tf.add(outputs, self.bias)
    if self.activation is not None:
        # Activate the layer output (u_hat does not exist on this path).
        outputs = self.activation(outputs)
    return outputs
def decoder(caps_out, y):
    """
    Reconstruction decoder.

    Contracts ``caps_out`` with ``y`` over the capsule axis
    (``'ijk,ij->ik'``) and feeds the result through three dense layers of
    512, 1024 and 784 units; only the last one has an activation (sigmoid).

    Parameters
    ----------
    caps_out:
        capsule outputs, indexed [i, j, k] by the einsum
    y:
        per-capsule weights, indexed [i, j] by the einsum

    Returns
    -------
    the sigmoid output of the final 784-unit dense layer
    """
    with tf.variable_scope('Decoder'):
        hidden = tf.einsum('ijk,ij->ik', caps_out, y)
        # NOTE(review): no activation is passed to the hidden layers, so
        # they are purely linear.
        for width in (512, 1024):
            hidden = keras.layers.Dense(units=width)(hidden)
        output_layer = keras.layers.Dense(units=784,
                                          activation=tf.nn.sigmoid)
        return output_layer(hidden)
def test_u_hat_c():
    """Check the shape of the coupling-weighted sum of prediction vectors.

    Contracting c [in_caps, out_caps] with u_hat
    [batch, in_caps, out_caps, out_len] must give [batch, out_caps, out_len].
    """
    batch = 16
    # Build in a private graph: tf.get_variable registers names in the
    # current graph, and other tests create variables with the same names.
    with tf.Graph().as_default():
        tf.set_random_seed(1)
        b = tf.get_variable('b', shape=(1152, 10), dtype=tf.float32,
                            initializer=tf.initializers.random_normal())
        c = tf.nn.softmax(b, name='c')
        u_hat = tf.get_variable('u_hat', shape=(batch, 1152, 10, 16),
                                dtype=tf.float32,
                                initializer=tf.initializers.random_normal())
        s = tf.einsum('jk,ijkq->ikq', c, u_hat)
        print(s.shape)
        assert s.shape.as_list() == [batch, 10, 16]
def test_u_hat():
    """Check the shape of the prediction vectors u_hat.

    Contracting W [in_caps, out_caps, in_len, out_len] with u
    [batch, in_caps, in_len] must give [batch, in_caps, out_caps, out_len].
    """
    batch = 16
    # Build in a private graph so the variables do not leak into the default
    # graph; otherwise a second run in the same process fails because the
    # names already exist.
    with tf.Graph().as_default():
        tf.set_random_seed(1)
        W = tf.get_variable('W', shape=(1152, 10, 8, 16), dtype=tf.float32,
                            initializer=tf.initializers.random_normal())
        u = tf.get_variable('u', shape=(batch, 1152, 8), dtype=tf.float32,
                            initializer=tf.initializers.random_normal())
        u_hat = tf.einsum('jpkq,ijk->ijpq', W, u)
        print(u_hat.shape)
        assert u_hat.shape.as_list() == [batch, 1152, 10, 16]
def test_update_b():
    """Check that the routing-logit update keeps the shape of b.

    The agreement between u_hat [batch, in_caps, out_caps, out_len] and
    v [batch, out_caps, out_len] is reduced to [in_caps, out_caps] and added
    to b.
    """
    batch = 16
    # Build in a private graph: tf.get_variable registers names in the
    # current graph, and other tests create variables with the same names.
    with tf.Graph().as_default():
        tf.set_random_seed(1)
        b = tf.get_variable('b', shape=(1152, 10), dtype=tf.float32,
                            initializer=tf.initializers.random_normal())
        u_hat = tf.get_variable('u_hat', shape=(batch, 1152, 10, 16),
                                dtype=tf.float32,
                                initializer=tf.initializers.random_normal())
        v = tf.get_variable('v', shape=(batch, 10, 16), dtype=tf.float32,
                            initializer=tf.initializers.random_normal())
        delta = tf.einsum('iqjk,ijk->qj', u_hat, v)
        b = b + delta
        print(b.shape)
        assert b.shape.as_list() == [1152, 10]
def test_einsum(W: tf.Tensor, u: tf.Tensor) -> tf.Tensor:
    """Contract ``W`` (indices ``ij``) with ``u`` (indices ``aki``) over ``i``.

    Returns the tensor indexed ``akj``.
    """
    equation = 'ij,aki->akj'
    contracted = tf.einsum(equation, W, u)
    return contracted