def test_xlnet_model(self):
     """Checks the shape of the first output of a small XLNetBase network.

     Constructs a 2-layer XLNetBase with attention_type "bi", bi_data
     enabled, two_stream disabled and both dropout rates set to 0. It then
     runs the network on the inputs returned by `self._generate_data` and
     asserts that the first element of the result has shape
     (batch, seq_len, hidden).
     """
     batch, seq_len, n_predictions, hidden = 2, 8, 2, 4

     # Collected up front so the model configuration reads as one unit;
     # num_attention_heads * head_size (2 * 2) equals `hidden` here.
     network_kwargs = dict(
         vocab_size=32000,
         num_layers=2,
         hidden_size=hidden,
         num_attention_heads=2,
         head_size=2,
         inner_size=2,
         dropout_rate=0.0,
         attention_dropout_rate=0.0,
         attention_type="bi",
         bi_data=True,
         initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
         two_stream=False,
         tie_attention_biases=True,
         reuse_length=0,
         inner_activation="relu",
     )
     network = xlnet_base.XLNetBase(**network_kwargs)

     inputs = self._generate_data(
         batch_size=batch,
         seq_length=seq_len,
         num_predictions=n_predictions)
     outputs = network(**inputs)

     first_output = outputs[0]
     expected_shape = (batch, seq_len, hidden)
     self.assertEqual(first_output.shape, expected_shape)
 def test_get_config(self):
     """Checks that an XLNetBase config survives a from_config round trip.

     The network is constructed but never called. The test only asserts
     that rebuilding it via `XLNetBase.from_config(get_config())` yields
     an equal config.
     """
     initializer = tf.keras.initializers.RandomNormal(stddev=0.1)
     original = xlnet_base.XLNetBase(
         vocab_size=32000,
         num_layers=12,
         hidden_size=36,
         num_attention_heads=12,
         head_size=12,
         inner_size=12,
         dropout_rate=0.0,
         attention_dropout_rate=0.0,
         attention_type="bi",
         bi_data=True,
         initializer=initializer,
         two_stream=False,
         tie_attention_biases=True,
         memory_length=0,
         reuse_length=0,
         inner_activation="relu")

     original_config = original.get_config()
     rebuilt = xlnet_base.XLNetBase.from_config(original_config)
     rebuilt_config = rebuilt.get_config()
     self.assertEqual(original_config, rebuilt_config)