def create_and_check_xlnet_base_model(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
    ):
        """Build a ``TFXLNetModel`` and check its outputs for several input formats.

        The model is run on a dict of inputs and on a positional list of
        inputs; for both, the ``last_hidden_state`` shape and the shape of
        every entry in ``mems`` are asserted. A second model is then built
        with ``use_mems_eval = False`` and its output is asserted to contain a
        single field.

        Args:
            config: configuration used to build ``TFXLNetModel``. It is not
                mutated; the no-mems check works on a deep copy.
            input_ids_1: token ids fed to the model.
            input_mask: passed as the model's ``input_mask`` input.
            segment_ids: passed as ``token_type_ids`` in the dict-input call.
            input_ids_2, input_ids_q, perm_mask, target_mapping, lm_labels,
            sequence_labels, is_impossible_labels: not used by this check.
        """
        import copy

        expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
        expected_mem_shapes = [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers

        def check_base_output(output):
            # Shared checks so every input format is verified, not only the last one run.
            self.parent.assertEqual(output.last_hidden_state.shape, expected_hidden_shape)
            self.parent.assertListEqual([mem.shape for mem in output.mems], expected_mem_shapes)

        model = TFXLNetModel(config)

        # Dict inputs, including token type ids.
        inputs = {"input_ids": input_ids_1, "input_mask": input_mask, "token_type_ids": segment_ids}
        check_base_output(model(inputs))

        # Positional list inputs.
        inputs = [input_ids_1, input_mask]
        check_base_output(model(inputs))

        # Disable memory outputs on a copy so the caller's config object is left untouched.
        no_mems_config = copy.deepcopy(config)
        no_mems_config.use_mems_eval = False
        model = TFXLNetModel(no_mems_config)
        no_mems_outputs = model(inputs)
        # With mems disabled the output is expected to hold exactly one field.
        self.parent.assertEqual(len(no_mems_outputs), 1)
# Example #2
# 0
 def test_model_from_pretrained(self):
     """Smoke-test ``TFXLNetModel.from_pretrained`` on the first archived checkpoint.

     Only the first entry of ``TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST`` is
     loaded; the check is simply that a model object is returned.
     """
     checkpoints_to_try = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
     for checkpoint in checkpoints_to_try:
         loaded_model = TFXLNetModel.from_pretrained(checkpoint)
         self.assertIsNotNone(loaded_model)