def test_transformer_xl(self, two_stream, memory_length, reuse_length,
                        tie_attention_biases, state, mask, segment):
  batch_size, num_heads, head_size, seq_length = 2, 12, 64, 8
  hidden_size, num_predictions, inner_size = 24, 8, 12
  num_layers = 3
  data = create_mock_transformer_xl_data(
      include_biases=False,
      num_heads=num_heads,
      head_size=head_size,
      hidden_size=hidden_size,
      seq_length=seq_length,
      batch_size=batch_size,
      memory_length=memory_length,
      num_predictions=num_predictions,
      two_stream=two_stream,
      num_layers=num_layers,
      include_state=state,
      include_mask=mask,
      include_segment=segment)
  transformer_xl_layer = transformer_xl.TransformerXL(
      vocab_size=32000,
      num_layers=num_layers,
      head_size=head_size,
      hidden_size=hidden_size,
      num_attention_heads=num_heads,
      inner_size=inner_size,
      dropout_rate=0.,
      attention_dropout_rate=0.,
      initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
      two_stream=two_stream,
      tie_attention_biases=tie_attention_biases,
      memory_length=memory_length,
      reuse_length=reuse_length,
      inner_activation="relu")
  attention_output, cached_memory_states = transformer_xl_layer(**data)
  if two_stream:
    self.assertEqual(attention_output.shape,
                     [batch_size, num_predictions, hidden_size])
  else:
    self.assertEqual(attention_output.shape,
                     [batch_size, seq_length, hidden_size])
  self.assertEqual(len(cached_memory_states), num_layers)
def test_get_config(self):
  transformer_xl_layer = transformer_xl.TransformerXL(
      vocab_size=32000,
      num_layers=12,
      hidden_size=36,
      head_size=12,
      num_attention_heads=12,
      inner_size=12,
      dropout_rate=0.,
      attention_dropout_rate=0.,
      initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
      two_stream=False,
      tie_attention_biases=True,
      memory_length=0,
      reuse_length=0,
      inner_activation="relu")
  transformer_xl_config = transformer_xl_layer.get_config()
  new_transformer_xl = transformer_xl.TransformerXL.from_config(
      transformer_xl_config)
  self.assertEqual(transformer_xl_config, new_transformer_xl.get_config())
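# Minimal runner sketch for the tests above, assuming they live in a
# tf.test.TestCase subclass (e.g. a hypothetical `TransformerXLTest`) and that
# `transformer_xl` and `create_mock_transformer_xl_data` are imported in the
# surrounding module.
if __name__ == "__main__":
  tf.test.main()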
def __init__(self,
             vocab_size,
             num_layers,
             hidden_size,
             num_attention_heads,
             head_size,
             inner_size,
             dropout_rate,
             attention_dropout_rate,
             attention_type,
             bi_data,
             initializer,
             two_stream=False,
             tie_attention_biases=True,
             memory_length=None,
             clamp_length=-1,
             reuse_length=None,
             inner_activation="relu",
             use_cls_mask=False,
             embedding_width=None,
             **kwargs):
  super(XLNetBase, self).__init__(**kwargs)
  self._vocab_size = vocab_size
  self._initializer = initializer
  self._attention_type = attention_type
  self._num_layers = num_layers
  self._hidden_size = hidden_size
  self._num_attention_heads = num_attention_heads
  self._head_size = head_size
  self._inner_size = inner_size
  self._inner_activation = inner_activation
  self._dropout_rate = dropout_rate
  self._attention_dropout_rate = attention_dropout_rate
  self._tie_attention_biases = tie_attention_biases
  self._two_stream = two_stream
  self._memory_length = memory_length
  self._reuse_length = reuse_length
  self._bi_data = bi_data
  self._clamp_length = clamp_length
  self._use_cls_mask = use_cls_mask

  self._segment_embedding = None
  self._mask_embedding = None
  self._embedding_width = embedding_width

  # Fall back to the hidden size when no separate embedding width is given.
  if embedding_width is None:
    embedding_width = hidden_size

  self._embedding_layer = layers.OnDeviceEmbedding(
      vocab_size=self._vocab_size,
      embedding_width=embedding_width,
      initializer=self._initializer,
      dtype=tf.float32,
      name="word_embedding")
  self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
  self.embedding_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
  self.position_encoding = RelativePositionEncoding(self._hidden_size)

  # Underlying Transformer-XL stack shared by the XLNet encoder.
  self._transformer_xl = transformer_xl.TransformerXL(
      vocab_size=vocab_size,
      num_layers=num_layers,
      hidden_size=hidden_size,
      num_attention_heads=num_attention_heads,
      head_size=head_size,
      inner_size=inner_size,
      dropout_rate=dropout_rate,
      attention_dropout_rate=attention_dropout_rate,
      initializer=initializer,
      two_stream=two_stream,
      tie_attention_biases=tie_attention_biases,
      memory_length=memory_length,
      reuse_length=reuse_length,
      inner_activation=inner_activation,
      name="transformer_xl")
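# Illustrative construction sketch only: the hyperparameter values and the
# attention_type/bi_data settings below are assumptions, not taken from the
# original code. It simply exercises the constructor arguments shown above,
# assuming XLNetBase is importable from the surrounding module.
xlnet_encoder = XLNetBase(
    vocab_size=32000,
    num_layers=12,
    hidden_size=768,
    num_attention_heads=12,
    head_size=64,
    inner_size=3072,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    attention_type="bi",
    bi_data=False,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
    two_stream=False,
    tie_attention_biases=True,
    memory_length=0,
    reuse_length=0,
    inner_activation="relu")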