# Example 1
 def testLocalUnmaskedAttention2D(self):
     """Local 2D attention (old block_length/filter_flange API) keeps shape."""
     query_np = np.random.rand(5, 4, 25, 25, 16)
     memory_np = np.random.rand(5, 4, 25, 25, 16)
     with self.test_session() as session:
         attn = common_attention.local_attention_2d(
             tf.constant(query_np, dtype=tf.float32),
             tf.constant(memory_np, dtype=tf.float32),
             tf.constant(memory_np, dtype=tf.float32),
             block_length=4,
             filter_flange=3)
         session.run(tf.global_variables_initializer())
         result = session.run(attn)
     self.assertEqual(result.shape, (5, 4, 25, 25, 16))
 def testLocalUnmaskedAttention2DMatchingBlockLength(self):
     """5x5 query blocks evenly tile the 25x25 input; output shape is kept."""
     inputs_np = np.random.rand(5, 4, 25, 25, 16)
     memory_np = np.random.rand(5, 4, 25, 25, 16)
     with self.test_session() as session:
         attention = common_attention.local_attention_2d(
             tf.constant(inputs_np, dtype=tf.float32),
             tf.constant(memory_np, dtype=tf.float32),
             tf.constant(memory_np, dtype=tf.float32),
             query_shape=(5, 5),
             memory_flange=(3, 3))
         session.run(tf.global_variables_initializer())
         out = session.run(attention)
     self.assertEqual(out.shape, (5, 4, 25, 25, 16))
 def testLocalUnmaskedAttention2D(self):
   """4x4 query blocks with a (3, 3) memory flange preserve the input shape."""
   q_np = np.random.rand(5, 4, 25, 25, 16)
   kv_np = np.random.rand(5, 4, 25, 25, 16)
   with self.test_session() as session:
     attn_t = common_attention.local_attention_2d(
         tf.constant(q_np, dtype=tf.float32),
         tf.constant(kv_np, dtype=tf.float32),
         tf.constant(kv_np, dtype=tf.float32),
         query_shape=(4, 4),
         memory_flange=(3, 3))
     session.run(tf.global_variables_initializer())
     res = session.run(attn_t)
   self.assertEqual(res.shape, (5, 4, 25, 25, 16))
# Example 4
 def testLocalUnmaskedAttention2DMatchingBlockLength(self):
     """Shape is preserved when 5x5 query blocks tile the 25x25 input exactly."""
     q_np = np.random.rand(5, 4, 25, 25, 16)
     kv_np = np.random.rand(5, 4, 25, 25, 16)
     q_t = tf.constant(q_np, dtype=tf.float32)
     k_t = tf.constant(kv_np, dtype=tf.float32)
     v_t = tf.constant(kv_np, dtype=tf.float32)
     attn = common_attention.local_attention_2d(
         q_t, k_t, v_t, query_shape=(5, 5), memory_flange=(3, 3))
     res = self.evaluate(attn)
     self.assertEqual(res.shape, (5, 4, 25, 25, 16))
    def testLocalUnmaskedAttention2D(self, batch, heads, length, depth_k,
                                     depth_v, query_shape):
        """Output shape is [batch, heads, length, length, depth_v]."""
        # A None batch exercises the dynamic (tensor-valued) batch-size path.
        if batch is None:
            batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)
        queries = tf.random_normal([batch, heads, length, length, depth_k])
        keys = tf.random_normal([batch, heads, length, length, depth_k])
        values = tf.random_normal([batch, heads, length, length, depth_v])
        attn = common_attention.local_attention_2d(
            queries,
            keys,
            values,
            query_shape=query_shape,
            memory_flange=(3, 3))
        # When batch is a tensor, evaluate it alongside the output so the
        # shape assertion below can compare against a concrete integer.
        if isinstance(batch, tf.Tensor):
            batch, res = self.evaluate([batch, attn])
        else:
            res = self.evaluate(attn)
        self.assertEqual(res.shape, (batch, heads, length, length, depth_v))
# Example 6
def multihead_attention(x,
                        out_channel=64,
                        d_model=32,
                        n_heads=8,
                        query_shape=(128, 24),
                        memory_flange=(8, 8)):
    """Multi-head local 2D self-attention over a feature map.

    Projects `x` to queries, keys and values with 3x3 convolutions, splits
    each into `n_heads` heads, applies local 2D attention where each
    `query_shape` block attends to itself plus a surrounding `memory_flange`
    halo, recombines the heads, and projects back to `out_channel` channels.

    Args:
        x: input feature map; assumed a channels-last 4-D tensor
            (batch, height, width, channels) — TODO confirm against callers.
        out_channel: number of channels in the final output projection.
        d_model: total depth of the q/k/v projections, split across heads;
            must be divisible by `n_heads`.
        n_heads: number of attention heads.
        query_shape: (height, width) of each local query block.
        memory_flange: (height, width) halo around each query block that
            keys/values may attend over.

    Returns:
        A tensor with `out_channel` channels produced by the final Conv2D.
    """
    q = Conv2D(d_model, (3, 3),
               strides=(1, 1),
               padding="same",
               name="gen_q_conv")(x)
    k = Conv2D(d_model, (3, 3),
               strides=(1, 1),
               padding="same",
               name="gen_k_conv")(x)
    v = Conv2D(d_model, (3, 3),
               strides=(1, 1),
               padding="same",
               name="gen_v_conv")(x)

    q = split_heads_2d(q, n_heads)
    k = split_heads_2d(k, n_heads)
    v = split_heads_2d(v, n_heads)

    # Scaled dot-product attention: scale queries by 1/sqrt(depth per head)
    # so logits keep a stable variance regardless of head depth.
    k_depth_per_head = d_model // n_heads
    q *= k_depth_per_head**-0.5

    output = local_attention_2d(q,
                                k,
                                v,
                                query_shape=query_shape,
                                memory_flange=memory_flange)

    output = combine_heads_2d(output)
    output = Conv2D(out_channel, (3, 3),
                    strides=(1, 1),
                    padding="same",
                    use_bias=False)(output)

    return output