def test_transformer_block(self):
    """A single TransformerBlock maps (batch, len, hidden) inputs to the same shape."""
    hidden_size = 128
    filter_size = 128
    num_heads = 2
    dropout = 0.1
    batch_size = 3
    input_len = 10

    block = transformer_block.TransformerBlock(
        hidden_size, filter_size, num_heads, dropout)

    inputs = tf.ones((batch_size, input_len, hidden_size))
    mask = tf.ones((batch_size, 1, input_len))
    # First positional arg is the training flag; the two trailing Nones are the
    # optional cross-attention inputs, unused in this self-attention-only check.
    output = block(True, inputs, mask, None, None)

    self.assertEqual(output.shape, [batch_size, input_len, hidden_size])
def __init__(self, vocab_size, hidden_size, filter_size, num_heads,
             num_encoder_layers, num_decoder_layers, label_smoothing, dropout):
    """Builds the embedding layer, encoder/decoder stacks, and input dropout.

    Args:
      vocab_size: Size of the token vocabulary for the embedding layer.
      hidden_size: Width of the model's hidden representations.
      filter_size: Width of each transformer block's feed-forward filter.
      num_heads: Number of attention heads per transformer block.
      num_encoder_layers: How many TransformerBlocks to stack in the encoder.
      num_decoder_layers: How many TransformerBlocks to stack in the decoder.
      label_smoothing: Label-smoothing coefficient, stored for later use.
      dropout: Dropout rate used both inside blocks and on embedded inputs.
    """
    self._dtype = tf.float32
    self._embedding_layer = embedding.Embedding(
        vocab_size, hidden_size, "weights", self._dtype)

    def make_block():
        # Called once per layer so every block gets its own (unshared) weights.
        return transformer_block.TransformerBlock(
            hidden_size, filter_size, num_heads, dropout)

    self._encoder_layers = [make_block() for _ in range(num_encoder_layers)]
    self._decoder_layers = [make_block() for _ in range(num_decoder_layers)]

    def input_dropout(x, training):
        if not training:
            return x
        # noise_shape of [batch, 1, features] reuses one dropout mask across
        # every sequence position (axis 1 broadcasts).
        return tf.compat.v2.nn.dropout(
            x, dropout, noise_shape=[x.shape[0], 1, x.shape[2]])

    self._dropout_fn = input_dropout
    self._vocab_size = vocab_size
    self._num_heads = num_heads
    self._label_smoothing = label_smoothing
    self._decoder_scope_name = "decoder"