Exemple #1
0
    def test_xlnet_tensor_call(self, num_classes):
        """Checks that an `XLNetClassifier` can be called on NumPy inputs.

        The classifier is built on top of the network returned by
        `_get_xlnet_base()` and then invoked once on a dict of randomly
        generated integer arrays. The result of the call is not inspected;
        the test passes as long as construction and invocation do not raise.

        Args:
            num_classes: Value forwarded as the classifier's `num_classes`.
        """
        batch_size, seq_length = 2, 4
        token_shape = (batch_size, seq_length)
        pairwise_shape = (batch_size, seq_length, seq_length)

        # Wrap the small base network in the classifier under test.
        classifier = xlnet.XLNetClassifier(
            network=_get_xlnet_base(),
            num_classes=num_classes,
            initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
            summary_type='last',
            dropout_rate=0.1)

        # Random int32 features: values drawn from [0, 10) for the id-like
        # entries and from {0, 1} for the type ids and the two masks.
        features = {
            'input_word_ids': np.random.randint(
                10, size=token_shape, dtype='int32'),
            'input_type_ids': np.random.randint(
                2, size=token_shape, dtype='int32'),
            'input_mask': np.random.randint(
                2, size=token_shape).astype('int32'),
            'permutation_mask': np.random.randint(
                2, size=pairwise_shape).astype('int32'),
            'masked_tokens': np.random.randint(
                10, size=token_shape, dtype='int32'),
        }

        # Only the ability to run the forward call is being validated.
        classifier(features)
  def test_xlnet_trainer(self):
    """Checks the logits shape produced by an `XLNetClassifier`.

    The classifier is built on the network returned by `_get_xlnet_base()`
    and called on a dict of symbolic Keras inputs. The first of the two
    values it returns is expected to have shape `[None, num_classes]`.
    """
    seq_length = 4
    num_classes = 2

    def _placeholder(shape, dtype, name):
      # Thin wrapper so every symbolic input is declared the same way.
      return tf.keras.layers.Input(shape=shape, dtype=dtype, name=name)

    # Wrap the small base network in the classifier under test.
    classifier = xlnet.XLNetClassifier(
        network=_get_xlnet_base(),
        num_classes=num_classes,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        summary_type='last',
        dropout_rate=0.1)

    # NOTE(review): the `input_ids` key differs from its Input layer name
    # (`input_word_ids`); presumably intentional for this model version --
    # confirm against the classifier's expected input keys.
    per_token = (seq_length,)
    symbolic_inputs = {
        'input_ids': _placeholder(per_token, tf.int32, 'input_word_ids'),
        'segment_ids': _placeholder(per_token, tf.int32, 'segment_ids'),
        'input_mask': _placeholder(per_token, tf.float32, 'input_mask'),
        'permutation_mask': _placeholder(
            (seq_length, seq_length), tf.float32, 'permutation_mask'),
        'masked_tokens': _placeholder(
            per_token, tf.float32, 'masked_tokens'),
    }

    # The call yields two values; only the first (the logits) is checked.
    logits, _unused = classifier(symbolic_inputs)

    # The leading dimension is None because the batch size is unspecified.
    self.assertAllEqual([None, num_classes], logits.shape.as_list())
Exemple #3
0
    def test_serialize_deserialize(self):
        """Checks that an `XLNetClassifier` survives a config round trip.

        A classifier is built, rebuilt from its own `get_config()` output via
        `from_config`, and the two configs are compared for equality. The
        rebuilt model is also asked for its JSON form to confirm that the
        call does not raise.
        """
        # Wrap the small base network in the classifier under test.
        original = xlnet.XLNetClassifier(
            network=_get_xlnet_base(),
            num_classes=2,
            initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
            summary_type='last',
            dropout_rate=0.1)

        # Round trip: config -> new instance.
        restored = xlnet.XLNetClassifier.from_config(original.get_config())

        # The JSON string itself is not examined; the call must simply work.
        restored.to_json()

        # A faithful round trip leaves the configuration unchanged.
        self.assertAllEqual(original.get_config(), restored.get_config())