def create_and_check_xlnet_for_token_classification(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
):
    """Instantiate TFXLNetForTokenClassification and verify the shapes of its logits and memory states."""
    # num_labels is tied to the sequence length of the fixture input here.
    config.num_labels = input_ids_1.shape[1]
    model = TFXLNetForTokenClassification(config)
    # token_type_ids is deliberately omitted from the feed dict.
    feed = {
        "input_ids": input_ids_1,
        "attention_mask": input_mask,
    }
    logits, mems_1 = model(feed)
    result = {
        "logits": logits.numpy(),
        "mems_1": [m.numpy() for m in mems_1],
    }
    # Logits must be (batch, seq, num_labels).
    self.parent.assertListEqual(
        list(result["logits"].shape),
        [self.batch_size, self.seq_length, config.num_labels],
    )
    # Each memory tensor must be (seq, batch, hidden), one per layer.
    expected_mem_shape = [self.seq_length, self.batch_size, self.hidden_size]
    self.parent.assertListEqual(
        [list(m.shape) for m in result["mems_1"]],
        [expected_mem_shape] * self.num_hidden_layers,
    )
def create_and_check_xlnet_for_token_classification(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
):
    """Run TFXLNetForTokenClassification on fixture inputs and assert the logits and mems shapes."""
    # num_labels is tied to the sequence length of the fixture input here.
    config.num_labels = input_ids_1.shape[1]
    model = TFXLNetForTokenClassification(config)
    # token_type_ids is intentionally left out of the inputs.
    result = model({"input_ids": input_ids_1, "attention_mask": input_mask})
    # Logits must be (batch, seq, num_labels).
    self.parent.assertEqual(
        result.logits.shape,
        (self.batch_size, self.seq_length, config.num_labels),
    )
    # Each memory tensor must be (seq, batch, hidden), one per layer.
    mem_shape = (self.seq_length, self.batch_size, self.hidden_size)
    self.parent.assertListEqual(
        [mem.shape for mem in result.mems],
        [mem_shape] * self.num_hidden_layers,
    )