def test_positional_encoding_stream(self):
  """Frame-by-frame (streaming) outputs match the full-sequence output.

  Feeds the same input both in one shot and split into 1, 2, and 4 temporal
  chunks with states threaded between calls, then checks that concatenating
  the streamed outputs reproduces the full-sequence result exactly.
  """
  layer = nn_layers.PositionalEncoding(
      initializer='ones', cache_encoding=False)

  # Build a [1, 4, 1, 1, 3] input whose value at time step t is t + 1,
  # broadcast across the channel dimension.
  data = tf.range(4, dtype=tf.float32) + 1.
  data = tf.reshape(data, [1, 4, 1, 1, 1])
  data = tf.tile(data, [1, 1, 1, 1, 3])

  # Reference: process the whole sequence in a single call.
  expected, _ = layer(data)

  for num_splits in [1, 2, 4]:
    chunks = tf.split(data, num_splits, axis=1)
    carry = {}
    pieces = []
    for chunk in chunks:
      out, carry = layer(chunk, states=carry)
      pieces.append(out)
    streamed = tf.concat(pieces, axis=1)

    self.assertEqual(streamed.shape, expected.shape)
    self.assertAllClose(streamed, expected)
    self.assertAllClose(
        streamed,
        [[[[[1.0000000, 1.0000000, 2.0000000]]],
          [[[2.8414710, 2.0021544, 2.5403023]]],
          [[[3.9092975, 3.0043090, 2.5838532]]],
          [[[4.1411200, 4.0064630, 3.0100074]]]]])
def test_positional_encoding_bfloat16(self):
  """Positional encoding applied to bfloat16 inputs matches known values."""
  layer = nn_layers.PositionalEncoding(initializer='ones')

  # All-ones input in bfloat16; the layer should add the sinusoidal
  # encoding on top of it.
  ones = tf.ones([1, 4, 1, 1, 3], dtype=tf.bfloat16)
  result, _ = layer(ones)

  # Golden values: ones plus the sinusoidal positional encoding.
  golden = tf.constant(
      [[[[[1.0000000, 1.0000000, 2.0000000]]],
        [[[1.8414710, 1.0021545, 1.5403023]]],
        [[[1.9092975, 1.0043088, 0.5838531]]],
        [[[1.1411200, 1.0064633, 0.0100075]]]]])

  self.assertEqual(result.shape, golden.shape)
  self.assertAllClose(result, golden)
def test_positional_encoding(self):
  """Cached and uncached positional encodings agree and match golden values.

  Also re-invokes the uncached layer with a longer sequence to verify it
  handles a changed temporal length after being built.
  """
  uncached = nn_layers.PositionalEncoding(
      initializer='ones', cache_encoding=False)
  cached = nn_layers.PositionalEncoding(
      initializer='ones', cache_encoding=True)

  ones = tf.ones([1, 4, 1, 1, 3])
  result, _ = uncached(ones)
  result_cached, _ = cached(ones)

  # Golden values: ones plus the sinusoidal positional encoding.
  golden = tf.constant(
      [[[[[1.0000000, 1.0000000, 2.0000000]]],
        [[[1.8414710, 1.0021545, 1.5403023]]],
        [[[1.9092975, 1.0043088, 0.5838531]]],
        [[[1.1411200, 1.0064633, 0.0100075]]]]])

  self.assertEqual(result.shape, golden.shape)
  self.assertAllClose(result, golden)

  # Caching the encoding must not change the output.
  self.assertEqual(result.shape, result_cached.shape)
  self.assertAllClose(result, result_cached)

  # The uncached layer should accept a different sequence length.
  longer = tf.ones([1, 5, 1, 1, 3])
  _ = uncached(longer)