def test_serialize_deserialize(self):
    """Checks that an XLNetSpanLabeler survives a get_config/from_config round trip.

    A second labeler is rebuilt from the first one's config. The rebuilt
    model must be exportable via `to_json()`, and both models must report
    equal configs.
    """
    # A small XLNet base network backs the span labeler under test.
    base_network = _get_xlnet_base()
    labeler = xlnet.XLNetSpanLabeler(
        network=base_network,
        start_n_top=2,
        end_n_top=2,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        span_labeling_activation='tanh',
        dropout_rate=0.1,
    )

    # Rebuild a second labeler purely from the first one's config.
    restored = xlnet.XLNetSpanLabeler.from_config(labeler.get_config())

    # The rebuilt model's config must be exportable as JSON...
    _ = restored.to_json()

    # ...and the two models must report matching configurations.
    self.assertAllEqual(labeler.get_config(), restored.get_config())
def test_xlnet_trainer(self, top_n):
    """Validates XLNetSpanLabeler construction and its call outputs.

    The labeler is called three times:
    - on symbolic Keras `Input` tensors;
    - on a concrete batch in inference mode;
    - on the same batch with `training=True`.

    Every call must return a dict. The concrete calls must also expose
    exactly the expected output keys for their mode.

    Args:
      top_n: Value used for both `start_n_top` and `end_n_top`.
    """
    seq_length = 4
    # Build a simple XLNet based network to use with the XLNet trainer.
    xlnet_base = _get_xlnet_base()

    # Create an XLNet trainer with the created network.
    xlnet_trainer_model = xlnet.XLNetSpanLabeler(
        network=xlnet_base,
        start_n_top=top_n,
        end_n_top=top_n,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        span_labeling_activation='tanh',
        dropout_rate=0.1)

    # Symbolic call: the model must accept Keras Input tensors.
    inputs = dict(
        input_ids=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
        segment_ids=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
        input_mask=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.float32, name='input_mask'),
        position_mask=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.float32, name='position_mask'),
        class_index=tf.keras.layers.Input(
            shape=(), dtype=tf.int32, name='class_index'),
        start_positions=tf.keras.layers.Input(
            shape=(), dtype=tf.int32, name='start_positions'))
    # The model returns a pair; only the first element is checked here.
    outputs, _ = xlnet_trainer_model(inputs)
    self.assertIsInstance(outputs, dict)

    # Test tensor value calls for the created model.
    batch_size = 2
    sequence_shape = (batch_size, seq_length)
    inputs = dict(
        input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
        segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
        input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
        # All-zero position mask.
        position_mask=np.zeros(sequence_shape, dtype='float32'),
        # Class index 0 for every example; int32 to match the symbolic
        # `class_index` Input declared above.
        class_index=np.zeros(batch_size, dtype='int32'),
        # Start positions must index into the sequence, i.e. lie in
        # [0, seq_length); tf.random.uniform's integer maxval is exclusive.
        start_positions=tf.random.uniform(
            shape=(batch_size,), maxval=seq_length, dtype=tf.int32))

    # Inference mode exposes top-k start/end results plus class logits.
    outputs, _ = xlnet_trainer_model(inputs)
    self.assertIsInstance(outputs, dict)
    expected_inference_keys = {
        'start_top_log_probs',
        'end_top_log_probs',
        'class_logits',
        'start_top_index',
        'end_top_index',
    }
    self.assertSetEqual(expected_inference_keys, set(outputs.keys()))

    # Training mode exposes full start/end log-probs plus class logits.
    outputs, _ = xlnet_trainer_model(inputs, training=True)
    self.assertIsInstance(outputs, dict)
    expected_train_keys = {
        'start_log_probs', 'end_log_probs', 'class_logits'
    }
    self.assertSetEqual(expected_train_keys, set(outputs.keys()))