Exemple #1
0
    def test_xlnet_tensor_call(self):
        """Validates that the Keras object can be invoked on concrete arrays.

        Builds an `XLNetPretrainer` around the network returned by
        `_get_xlnet_base()` and calls it eagerly with random int32 NumPy
        inputs. Every input carries a leading `batch_size` axis; the trailing
        axes mirror the per-example shapes used for the symbolic Keras inputs.
        """
        seq_length = 4
        batch_size = 2
        num_predictions = 2
        # Build a simple XLNet based network to use with the XLNet trainer.
        xlnet_base = _get_xlnet_base()

        # Create an XLNet trainer with the created network.
        xlnet_trainer_model = xlnet.XLNetPretrainer(network=xlnet_base)

        sequence_shape = (batch_size, seq_length)
        inputs = dict(
            input_word_ids=np.random.randint(
                10, size=sequence_shape, dtype='int32'),
            input_type_ids=np.random.randint(
                2, size=sequence_shape, dtype='int32'),
            input_mask=np.random.randint(
                2, size=sequence_shape, dtype='int32'),
            permutation_mask=np.random.randint(
                2, size=(batch_size, seq_length, seq_length), dtype='int32'),
            # One (num_predictions, seq_length) mapping per example, so the
            # batch axis must come first. Leaving it out only appears to work
            # while batch_size happens to equal num_predictions.
            target_mapping=np.random.randint(
                10,
                size=(batch_size, num_predictions, seq_length),
                dtype='int32'),
            masked_tokens=np.random.randint(
                10, size=sequence_shape, dtype='int32'))
        # The test passes as long as the call itself does not raise.
        xlnet_trainer_model(inputs)
Exemple #2
0
    def test_serialize_deserialize(self):
        """Validates that the XLNet trainer can be serialized and deserialized.

        Round-trips an `XLNetPretrainer` through `get_config()` /
        `from_config()`, checks that the rebuilt model can be rendered with
        `to_json()`, and compares the configs of the two models.
        """
        # A trainer wrapping the small test base network, with an explicit
        # MLM activation and initializer.
        original_model = xlnet.XLNetPretrainer(
            network=_get_xlnet_base(),
            mlm_activation='gelu',
            mlm_initializer='random_normal')

        # Rebuild a second trainer purely from the first one's config.
        restored_model = xlnet.XLNetPretrainer.from_config(
            original_model.get_config())

        # The rebuilt model's config must be convertible to JSON.
        restored_model.to_json()

        # Fetch each config fresh; a lossless round trip yields equal configs.
        self.assertAllEqual(original_model.get_config(),
                            restored_model.get_config())
Exemple #3
0
    def test_xlnet_trainer(self):
        """Validates that the Keras object can be created.

        Calls an `XLNetPretrainer` on symbolic `tf.keras.layers.Input` tensors
        (batch axis left unspecified) and checks the static shape of the
        returned logits.
        """
        seq_length = 4
        num_predictions = 2
        # Build a simple XLNet based network to use with the XLNet trainer.
        xlnet_base = _get_xlnet_base()

        # Create an XLNet trainer with the created network.
        xlnet_trainer_model = xlnet.XLNetPretrainer(network=xlnet_base)

        def _int32_input(name, shape):
            # Symbolic int32 Keras input; `shape` excludes the batch axis.
            return tf.keras.layers.Input(shape=shape, dtype=tf.int32, name=name)

        inputs = dict(
            input_word_ids=_int32_input('input_word_ids', (seq_length,)),
            input_type_ids=_int32_input('input_type_ids', (seq_length,)),
            input_mask=_int32_input('input_mask', (seq_length,)),
            permutation_mask=_int32_input('permutation_mask',
                                          (seq_length, seq_length)),
            target_mapping=_int32_input('target_mapping',
                                        (num_predictions, seq_length)),
            masked_tokens=_int32_input('masked_tokens', (seq_length,)))
        logits, _ = xlnet_trainer_model(inputs)

        # [batch_size (None), seq_length, vocab_size]: per-token LM logits.
        # NOTE(review): 100 is presumably the vocab size configured in
        # _get_xlnet_base() — confirm there.
        expected_output_shape = [None, seq_length, 100]
        self.assertAllEqual(expected_output_shape, logits.shape.as_list())