Пример #1
0
    def testSplitAndCombineHeads(self):
        """Round-trip a random tensor through split_heads and combine_heads.

        Splitting the last dimension into ``num_heads`` heads and then
        combining them again is expected to reproduce the input exactly.
        """
        batch_size, num_heads, depth = 3, 8, 20
        lengths = [5, 3, 7]
        max_length = max(lengths)
        model_dim = depth * num_heads

        original = tf.random.normal(
            [batch_size, max_length, model_dim], dtype=tf.float32)
        restored = transformer.combine_heads(
            transformer.split_heads(original, num_heads))

        # Evaluate both tensors in one call so they share the same random draw.
        original_value, restored_value = self.evaluate([original, restored])
        self.assertAllEqual(original_value, restored_value)
Пример #2
0
    def testSplitAndCombineHeads(self):
        """Round-trip a numpy-seeded tensor through split_heads/combine_heads.

        The input is built from ``np.random.randn`` (cast to float32). Splitting
        it into ``num_heads`` heads and recombining must give back the input
        unchanged.
        """
        batch_size, num_heads, depth = 3, 8, 20
        lengths = [5, 3, 7]
        max_length = max(lengths)
        model_dim = depth * num_heads

        values = np.random.randn(batch_size, max_length, model_dim)
        original = tf.convert_to_tensor(values.astype(np.float32))
        restored = transformer.combine_heads(
            transformer.split_heads(original, num_heads))

        original_value, restored_value = self.evaluate([original, restored])
        self.assertAllEqual(original_value, restored_value)
Пример #3
0
    def testCombineHeads(self):
        """Check the output shape of combine_heads.

        A ``[batch, heads, length, depth]`` input is expected to come back as
        ``[batch, length, depth * heads]``. The last dimension is checked
        statically, and the full shape is checked on the evaluated result.
        """
        batch_size, num_heads, depth = 3, 8, 20
        lengths = [5, 3, 7]
        max_length = max(lengths)
        expected_shape = [batch_size, max_length, depth * num_heads]

        heads = tf.random.normal(
            [batch_size, num_heads, max_length, depth], dtype=tf.float32)
        merged = transformer.combine_heads(heads)

        self.assertEqual(expected_shape[-1], merged.shape[-1])
        merged_value = self.evaluate(merged)
        self.assertAllEqual(expected_shape, merged_value.shape)
Пример #4
0
  def testSplitAndCombineHeads(self):
    """Round-trip a partially-unknown-shape tensor through the heads ops.

    The input is fed through ``placeholder_with_default`` with shape
    ``(None, None, depth * num_heads)``, so only the last dimension is known
    statically. split_heads followed by combine_heads must still reproduce
    the input exactly.
    """
    batch_size, num_heads, depth = 3, 8, 20
    lengths = [5, 3, 7]
    model_dim = depth * num_heads

    default_value = np.random.randn(
        batch_size, max(lengths), model_dim).astype(np.float32)
    original = tf.placeholder_with_default(
        default_value, shape=(None, None, model_dim))
    restored = transformer.combine_heads(
        transformer.split_heads(original, num_heads))

    with self.test_session() as sess:
      original_value, restored_value = sess.run([original, restored])
      self.assertAllEqual(original_value, restored_value)
Пример #5
0
    def testCombineHeads(self):
        """Check the output shape of combine_heads on a numpy-seeded input.

        The ``[batch, heads, length, depth]`` input should be merged into a
        ``[batch, length, depth * heads]`` output. The static last dimension
        and the evaluated full shape are both asserted.
        """
        batch_size, num_heads, depth = 3, 8, 20
        lengths = [5, 3, 7]
        max_length = max(lengths)
        expected_shape = [batch_size, max_length, depth * num_heads]

        values = np.random.randn(batch_size, num_heads, max_length, depth)
        heads = tf.convert_to_tensor(values.astype(np.float32))
        merged = transformer.combine_heads(heads)

        static_dims = merged.get_shape().as_list()
        self.assertEqual(expected_shape[-1], static_dims[-1])
        merged_value = self.evaluate(merged)
        self.assertAllEqual(expected_shape, merged_value.shape)
Пример #6
0
  def testCombineHeads(self):
    """Check combine_heads shapes when batch and length are dynamic.

    The input placeholder has shape ``(None, num_heads, None, depth)``. The
    merged last dimension (``depth * num_heads``) must still be known
    statically, and the evaluated output must have shape
    ``[batch, length, depth * num_heads]``.
    """
    batch_size, num_heads, depth = 3, 8, 20
    lengths = [5, 3, 7]
    max_length = max(lengths)
    expected_shape = [batch_size, max_length, depth * num_heads]

    default_value = np.random.randn(
        batch_size, num_heads, max_length, depth).astype(np.float32)
    heads = tf.placeholder_with_default(
        default_value, shape=(None, num_heads, None, depth))
    merged = transformer.combine_heads(heads)

    static_dims = merged.get_shape().as_list()
    self.assertEqual(expected_shape[-1], static_dims[-1])

    with self.test_session() as sess:
      merged_value = sess.run(merged)
      self.assertAllEqual(expected_shape, merged_value.shape)