def create_and_check_xlnet_base_model(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
):
    """Run ``TFXLNetModel`` on dict and list inputs and check output shapes.

    The model is first called with a dict of inputs, then with a list of
    inputs. A second model is built after setting ``config.use_mems_eval``
    to ``False`` and its output is checked to have length 1. Finally the
    hidden-state and per-layer ``mems`` shapes of the list-input result are
    compared with the tester's configured sizes.

    Note: ``config`` is mutated in place (``use_mems_eval`` is set to
    ``False``). Several parameters are accepted but not used here.
    """
    base_model = TFXLNetModel(config)

    # Dict-style call: only the forward pass matters, the output is not
    # inspected because the list-style call below supplies ``result``.
    dict_inputs = {"input_ids": input_ids_1, "input_mask": input_mask, "token_type_ids": segment_ids}
    base_model(dict_inputs)

    # List-style call: this is the output the shape checks are run against.
    list_inputs = [input_ids_1, input_mask]
    result = base_model(list_inputs)

    # Rebuild the model with use_mems_eval disabled; the output is then
    # expected to hold a single element (presumably no mems - TODO confirm).
    config.use_mems_eval = False
    model_without_mems = TFXLNetModel(config)
    no_mems_outputs = model_without_mems(list_inputs)
    self.parent.assertEqual(len(no_mems_outputs), 1)

    hidden_shape = result.last_hidden_state.shape
    expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
    self.parent.assertEqual(hidden_shape, expected_hidden_shape)

    # One mems entry is expected per hidden layer, each with the sequence
    # length ahead of the batch size.
    mem_shapes = [mem.shape for mem in result.mems]
    single_mem_shape = (self.seq_length, self.batch_size, self.hidden_size)
    expected_mem_shapes = [single_mem_shape] * self.num_hidden_layers
    self.parent.assertListEqual(mem_shapes, expected_mem_shapes)
def test_model_from_pretrained(self):
    """Check that the first listed pretrained XLNet checkpoint loads.

    Only the first entry of ``TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST`` is
    tried; slicing (rather than indexing) means an empty list runs no
    iterations instead of raising.
    """
    checkpoints_to_try = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
    for checkpoint_name in checkpoints_to_try:
        loaded_model = TFXLNetModel.from_pretrained(checkpoint_name)
        self.assertIsNotNone(loaded_model)