Example #1
0
 def create_and_check_xlnet_for_token_classification(
     self,
     config,
     input_ids_1,
     input_ids_2,
     input_ids_q,
     perm_mask,
     input_mask,
     target_mapping,
     segment_ids,
     lm_labels,
     sequence_labels,
     is_impossible_labels,
 ):
     """Check TFXLNetForTokenClassification output shapes (tuple-output API).

     Sets ``config.num_labels`` from ``input_ids_1.shape[1]``, builds the
     model and runs it on ``input_ids_1`` / ``input_mask``. Then, through
     ``self.parent``, it asserts:

     * logits shape == [batch_size, seq_length, num_labels]
     * each memory shape == [seq_length, batch_size, hidden_size], with one
       memory per hidden layer.

     Only ``config``, ``input_ids_1`` and ``input_mask`` are read. The other
     parameters are unused here, presumably kept so that every
     create_and_check_* helper shares one argument list (verify at caller).

     NOTE(review): in this file a method with the same name is defined again
     right after this one. If both live in the same class, the later
     definition replaces this one and this body never runs. Confirm which
     version is intended and delete the other.
     """
     # Mutates the passed-in config; the logits check below compares the
     # last dimension against this same value.
     config.num_labels = input_ids_1.shape[1]
     model = TFXLNetForTokenClassification(config)
     inputs = {
         "input_ids": input_ids_1,
         "attention_mask": input_mask,
         # segment_ids is not forwarded to the model as token_type_ids.
     }
     # This version of the model API is unpacked as a (logits, mems) tuple.
     logits, mems_1 = model(inputs)

     logits_shape = list(logits.numpy().shape)
     mem_shapes = [list(mem.numpy().shape) for mem in mems_1]

     self.parent.assertListEqual(
         logits_shape,
         [self.batch_size, self.seq_length, config.num_labels],
     )
     self.parent.assertListEqual(
         mem_shapes,
         [[self.seq_length, self.batch_size, self.hidden_size]]
         * self.num_hidden_layers,
     )
 def create_and_check_xlnet_for_token_classification(
     self,
     config,
     input_ids_1,
     input_ids_2,
     input_ids_q,
     perm_mask,
     input_mask,
     target_mapping,
     segment_ids,
     lm_labels,
     sequence_labels,
     is_impossible_labels,
 ):
     """Check TFXLNetForTokenClassification output shapes (attribute-output API).

     ``config.num_labels`` is overwritten with ``input_ids_1.shape[1]``. The
     model is then called on ``input_ids_1`` and ``input_mask``, and its
     ``logits`` / ``mems`` attributes are shape-checked via ``self.parent``.

     Parameters other than ``config``, ``input_ids_1`` and ``input_mask`` are
     accepted but not read here.
     """
     config.num_labels = input_ids_1.shape[1]
     model = TFXLNetForTokenClassification(config)

     # Only input_ids and attention_mask are fed; segment_ids is not passed
     # as token_type_ids.
     outputs = model({"input_ids": input_ids_1, "attention_mask": input_mask})

     expected_logits_shape = (
         self.batch_size,
         self.seq_length,
         config.num_labels,
     )
     self.parent.assertEqual(outputs.logits.shape, expected_logits_shape)

     # One memory tensor per hidden layer, all with the same shape.
     expected_mem_shape = (self.seq_length, self.batch_size, self.hidden_size)
     observed_mem_shapes = [mem.shape for mem in outputs.mems]
     self.parent.assertListEqual(
         observed_mem_shapes,
         [expected_mem_shape] * self.num_hidden_layers,
     )