def testSplitAndCombineHeads(self):
    """Round trip: combine_heads(split_heads(x)) must reproduce x exactly."""
    num_heads, depth = 8, 20
    batch_size = 3
    length = [5, 3, 7]
    max_time = max(length)
    full_depth = depth * num_heads

    inputs = tf.random.normal(
        [batch_size, max_time, full_depth], dtype=tf.float32)
    combined = transformer.combine_heads(
        transformer.split_heads(inputs, num_heads))

    # Evaluate both tensors in a single call: in graph mode, separate
    # fetches would re-run the random op and yield different inputs.
    expected, actual = self.evaluate([inputs, combined])
    self.assertAllEqual(expected, actual)
def testSplitAndCombineHeads(self):
    """Round trip: combine_heads(split_heads(x)) must reproduce x exactly."""
    num_heads, depth = 8, 20
    batch_size = 3
    length = [5, 3, 7]

    # Random float32 values of shape [batch, max(length), depth * num_heads].
    values = np.random.randn(batch_size, max(length), depth * num_heads)
    inputs = tf.convert_to_tensor(values.astype(np.float32))

    heads = transformer.split_heads(inputs, num_heads)
    combined = transformer.combine_heads(heads)

    expected, actual = self.evaluate([inputs, combined])
    self.assertAllEqual(expected, actual)
def testCombineHeads(self):
    """Check the static and evaluated output shape of combine_heads.

    The input is [batch, num_heads, max(length), depth]; the evaluated result
    is expected to be [batch, max(length), depth * num_heads].
    """
    num_heads, depth = 8, 20
    batch_size = 3
    length = [5, 3, 7]
    max_time = max(length)
    merged_depth = depth * num_heads

    inputs = tf.random.normal(
        [batch_size, num_heads, max_time, depth], dtype=tf.float32)
    outputs = transformer.combine_heads(inputs)

    # The last dimension must already be known statically, before evaluation.
    static_shape = outputs.shape
    self.assertEqual(merged_depth, static_shape[-1])

    outputs = self.evaluate(outputs)
    self.assertAllEqual([batch_size, max_time, merged_depth], outputs.shape)
def testSplitAndCombineHeads(self):
    """Round trip: combine_heads(split_heads(x)) must reproduce x exactly.

    The input is fed through a placeholder whose first two dimensions are
    statically unknown, so the ops are exercised without a fully defined
    static shape.
    """
    batch_size = 3
    length = [5, 3, 7]
    num_heads = 8
    depth = 20
    inputs = tf.placeholder_with_default(
        np.random.randn(
            batch_size, max(length), depth * num_heads).astype(np.float32),
        shape=(None, None, depth * num_heads))
    split = transformer.split_heads(inputs, num_heads)
    combined = transformer.combine_heads(split)
    # self.evaluate replaces the deprecated self.test_session() / sess.run
    # pair and matches how the sibling tests fetch their results.
    inputs, combined = self.evaluate([inputs, combined])
    self.assertAllEqual(inputs, combined)
def testCombineHeads(self):
    """Check the static and evaluated output shape of combine_heads.

    The input is [batch, num_heads, max(length), depth]; the evaluated result
    is expected to be [batch, max(length), depth * num_heads].
    """
    num_heads, depth = 8, 20
    batch_size = 3
    length = [5, 3, 7]
    merged_depth = depth * num_heads

    values = np.random.randn(batch_size, num_heads, max(length), depth)
    inputs = tf.convert_to_tensor(values.astype(np.float32))
    outputs = transformer.combine_heads(inputs)

    # The last dimension must already be known statically, before evaluation.
    static_shape = outputs.get_shape().as_list()
    self.assertEqual(merged_depth, static_shape[-1])

    outputs = self.evaluate(outputs)
    self.assertAllEqual(
        [batch_size, max(length), merged_depth], outputs.shape)
def testCombineHeads(self):
    """Check the static and evaluated output shape of combine_heads.

    The input is fed through a placeholder declared as
    (None, num_heads, None, depth), so only the head and depth dimensions are
    statically known. The last output dimension must still be known
    statically, and the evaluated result is expected to be
    [batch, max(length), depth * num_heads].
    """
    batch_size = 3
    length = [5, 3, 7]
    num_heads = 8
    depth = 20
    inputs = tf.placeholder_with_default(
        np.random.randn(
            batch_size, num_heads, max(length), depth).astype(np.float32),
        shape=(None, num_heads, None, depth))
    outputs = transformer.combine_heads(inputs)
    static_shape = outputs.get_shape().as_list()
    self.assertEqual(depth * num_heads, static_shape[-1])
    # self.evaluate replaces the deprecated self.test_session() / sess.run
    # pair and matches how the sibling tests fetch their results.
    outputs = self.evaluate(outputs)
    self.assertAllEqual(
        [batch_size, max(length), depth * num_heads], outputs.shape)