Example #1
    def testMaskedScaledDotAttention(self):
        batch_size = 3
        num_heads = 8
        queries_length = [8, 6, 10]
        depth = 20

        queries = tf.placeholder_with_default(
            np.random.randn(batch_size, num_heads, max(queries_length),
                            depth).astype(np.float32),
            shape=(None, num_heads, None, depth))

        # Future (causal) mask: each position may attend only to itself and
        # earlier positions.
        mask = transformer.build_future_mask(queries_length,
                                             num_heads=num_heads)
        context, attn = transformer.dot_product_attention(
            queries,
            queries,
            queries,
            tf.estimator.ModeKeys.PREDICT,
            mask=mask)

        with self.test_session() as sess:
            context, attn = sess.run([context, attn])
            # Attention to any future position (strict upper triangle) must be zero.
            illegal_connections = np.triu_indices(max(queries_length), 1)
            for i in range(batch_size):
                for h in range(num_heads):
                    self.assertEqual(0.0,
                                     np.sum(attn[i, h][illegal_connections]))
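
For reference, the property this test asserts can be checked outside TensorFlow. Below is a minimal NumPy sketch of scaled dot-product attention with a future (causal) mask; `causal_attention_reference` is an illustrative helper, not part of the `transformer` module under test. Setting future positions to -inf before the softmax makes their attention weights exactly zero, which is what the assertion above relies on.

import numpy as np

def causal_attention_reference(q, k, v):
    # q, k, v: [length, depth]. Scaled dot-product attention with a future mask.
    length, depth = q.shape
    scores = q @ k.T / np.sqrt(depth)                        # [length, length]
    future = np.triu(np.ones((length, length), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)               # hide future positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)           # row-wise softmax
    return weights @ v, weights

q = np.random.randn(10, 20).astype(np.float32)
_, weights = causal_attention_reference(q, q, q)
assert np.sum(weights[np.triu_indices(10, 1)]) == 0.0        # same check as the test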
Example #2
  def testScaledDotAttention(self):
    batch_size = 3
    num_heads = 8
    values_length = [5, 3, 7]
    queries_length = [8, 6, 10]
    depth = 20

    queries = tf.placeholder_with_default(
        np.random.randn(batch_size, num_heads, max(queries_length), depth).astype(np.float32),
        shape=(None, num_heads, None, depth))
    values = tf.placeholder_with_default(
        np.random.randn(batch_size, num_heads, max(values_length), depth).astype(np.float32),
        shape=(None, num_heads, None, depth))
    keys = values

    # Sequence mask: hide the padded positions of the keys/values.
    mask = transformer.build_sequence_mask(values_length, num_heads=num_heads)
    context, attn = transformer.dot_product_attention(
        queries,
        keys,
        values,
        tf.estimator.ModeKeys.PREDICT,
        mask=mask)

    with self.test_session() as sess:
      context, attn = sess.run([context, attn])
      self.assertTupleEqual(
          (batch_size, num_heads, max(queries_length), depth), context.shape)
      self.assertTupleEqual(
          (batch_size, num_heads, max(queries_length), max(values_length)), attn.shape)

      # Attention weights over padded value positions must be exactly zero.
      for i in range(batch_size):
        length = values_length[i]
        padding_length = max(values_length) - length
        if padding_length > 0:
          self.assertEqual(0.0, np.sum(attn[i, :, :, length:max(values_length)]))
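
The padding-mask property asserted above can likewise be reproduced with plain NumPy. This is a sketch only; `padded_attention_reference` and its `memory_length` argument are illustrative names, not the module's API. Key/value positions at or beyond the true sequence length are set to -inf before the softmax, so they receive exactly zero attention weight.

import numpy as np

def padded_attention_reference(q, k, v, memory_length):
    # q: [q_len, depth]; k, v: [kv_len, depth]. Keys past memory_length are masked.
    depth = q.shape[-1]
    scores = q @ k.T / np.sqrt(depth)                        # [q_len, kv_len]
    padding = np.arange(k.shape[0]) >= memory_length         # True on padded keys
    scores = np.where(padding[None, :], -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)           # row-wise softmax
    return weights @ v, weights

q = np.random.randn(8, 20).astype(np.float32)
kv = np.random.randn(7, 20).astype(np.float32)
_, weights = padded_attention_reference(q, kv, kv, memory_length=5)
assert np.sum(weights[:, 5:]) == 0.0                         # no mass on padded keys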