def test_xlnet_model(self):
    """Runs a small XLNetBase forward pass and checks the first output's shape.

    A 2-layer model (attention_type="bi", bi_data=True, two_stream=False) is
    built, fed inputs produced by ``self._generate_data``, and the first
    element of its output must have shape
    ``(batch_size, seq_length, hidden_size)``.
    """
    batch_size, seq_length, num_predictions, hidden_size = 2, 8, 2, 4

    # Constructor arguments collected up front so the model call stays short.
    model_kwargs = dict(
        vocab_size=32000,
        num_layers=2,
        hidden_size=hidden_size,
        num_attention_heads=2,
        head_size=2,
        inner_size=2,
        dropout_rate=0.,
        attention_dropout_rate=0.,
        attention_type="bi",
        bi_data=True,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        two_stream=False,
        tie_attention_biases=True,
        reuse_length=0,
        inner_activation="relu",
    )
    model = xlnet_base.XLNetBase(**model_kwargs)

    inputs = self._generate_data(
        batch_size=batch_size,
        seq_length=seq_length,
        num_predictions=num_predictions,
    )
    outputs = model(**inputs)

    expected_shape = (batch_size, seq_length, hidden_size)
    self.assertEqual(outputs[0].shape, expected_shape)
def test_get_config(self):
    """Checks that XLNetBase survives a get_config/from_config round trip.

    A model is built, its config is used to rebuild a second model via
    ``XLNetBase.from_config``, and the rebuilt model's config must equal
    the original config.
    """
    model_kwargs = dict(
        vocab_size=32000,
        num_layers=12,
        hidden_size=36,
        num_attention_heads=12,
        head_size=12,
        inner_size=12,
        dropout_rate=0.,
        attention_dropout_rate=0.,
        attention_type="bi",
        bi_data=True,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        two_stream=False,
        tie_attention_biases=True,
        memory_length=0,
        reuse_length=0,
        inner_activation="relu",
    )
    original = xlnet_base.XLNetBase(**model_kwargs)

    original_config = original.get_config()
    rebuilt = xlnet_base.XLNetBase.from_config(original_config)
    rebuilt_config = rebuilt.get_config()

    self.assertEqual(original_config, rebuilt_config)