def test_serialize_deserialize(self):
    """Round-trips an XLNetSpanLabeler through get_config()/from_config().

    The rebuilt model must be convertible to JSON, and its config must equal
    the config of the model it was rebuilt from.
    """
    # Build a simple XLNet based network to use with the XLNet trainer.
    xlnet_base = _get_xlnet_base()

    # Create an XLNet trainer with the created network.
    xlnet_trainer_model = xlnet.XLNetSpanLabeler(
        network=xlnet_base,
        start_n_top=2,
        end_n_top=2,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        span_labeling_activation='tanh',
        dropout_rate=0.1)

    # Create another XLNet trainer via serialization and deserialization.
    config = xlnet_trainer_model.get_config()
    new_xlnet_trainer_model = xlnet.XLNetSpanLabeler.from_config(config)

    # Validate that the config can be serialized to JSON; only the absence of
    # an exception matters here, so the result is discarded.
    _ = new_xlnet_trainer_model.to_json()

    # If serialization was successful, then the new config should match the
    # old one. Keras configs are plain dicts, so compare them with assertEqual
    # (dispatches to assertDictEqual, which reports a per-key diff on failure)
    # instead of the array-oriented assertAllEqual.
    self.assertEqual(xlnet_trainer_model.get_config(),
                     new_xlnet_trainer_model.get_config())
  def test_xlnet_trainer(self, top_n):
    """Builds an XLNetSpanLabeler and checks its outputs in each call mode.

    The model is called on symbolic Keras inputs first, then on concrete
    inputs in inference mode (default call) and in training mode
    (`training=True`). Each concrete call must return a dict with exactly the
    expected set of keys for that mode.

    Args:
      top_n: Value passed as both `start_n_top` and `end_n_top`.
        NOTE(review): this parameter has no default and no parameterization
        decorator is visible in this view; confirm one is applied, otherwise
        the test runner will call this method without `top_n` and fail.
    """
    seq_length = 4
    # Build a simple XLNet based network to use with the XLNet trainer.
    xlnet_base = _get_xlnet_base()

    # Create an XLNet trainer with the created network.
    xlnet_trainer_model = xlnet.XLNetSpanLabeler(
        network=xlnet_base,
        start_n_top=top_n,
        end_n_top=top_n,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        span_labeling_activation='tanh',
        dropout_rate=0.1)

    # Call the model once on symbolic Keras inputs.
    symbolic_inputs = dict(
        input_ids=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
        segment_ids=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
        input_mask=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.float32, name='input_mask'),
        position_mask=tf.keras.layers.Input(
            shape=(seq_length,), dtype=tf.float32, name='position_mask'),
        class_index=tf.keras.layers.Input(
            shape=(), dtype=tf.int32, name='class_index'),
        start_positions=tf.keras.layers.Input(
            shape=(), dtype=tf.int32, name='start_positions'))
    outputs, _ = xlnet_trainer_model(symbolic_inputs)
    self.assertIsInstance(outputs, dict)

    # Test tensor value calls for the created model.
    batch_size = 2
    sequence_shape = (batch_size, seq_length)
    numpy_inputs = dict(
        input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
        segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
        input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
        # All-zero position mask, written out explicitly.
        position_mask=np.zeros(sequence_shape, dtype='float32'),
        # Class index 0 for every example; int32 to match the dtype declared
        # for the symbolic `class_index` input above.
        class_index=np.zeros((batch_size,), dtype='int32'),
        start_positions=tf.random.uniform(
            shape=(batch_size,), maxval=5, dtype=tf.int32))

    # Inference mode (default call): top-k start/end predictions.
    outputs, _ = xlnet_trainer_model(numpy_inputs)
    expected_inference_keys = {
        'start_top_log_probs', 'end_top_log_probs', 'class_logits',
        'start_top_index', 'end_top_index',
    }
    self.assertSetEqual(expected_inference_keys, set(outputs.keys()))

    # Training mode: full start/end log-probabilities.
    outputs, _ = xlnet_trainer_model(numpy_inputs, training=True)
    self.assertIsInstance(outputs, dict)
    expected_train_keys = {
        'start_log_probs', 'end_log_probs', 'class_logits'
    }
    self.assertSetEqual(expected_train_keys, set(outputs.keys()))