def create_and_check_xlnet_token_classif(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Check XLNetForTokenClassification output shapes with and without labels.

    Only ``config``, ``input_ids_1`` and ``token_labels`` are used. The other
    arguments are accepted but ignored (presumably so that every
    ``create_and_check_*`` helper shares one call signature; confirm against
    the caller).

    Both forward passes must return per-token ``logits`` of shape
    ``(batch_size, seq_length, type_sequence_label_size)`` and ``mems`` holding
    one tensor of shape ``(seq_length, batch_size, hidden_size)`` per hidden
    layer. The labelled pass must also return a 0-d (scalar) ``loss``.
    """
    model = XLNetForTokenClassification(config)
    model.to(torch_device)
    model.eval()

    # NOTE(review): type_sequence_label_size is used as the number of token
    # classes, so presumably config.num_labels is set to it; confirm in the tester.
    expected_logits_shape = (self.batch_size, self.seq_length, self.type_sequence_label_size)
    expected_mems_shapes = [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers

    # Label-free pass: verify its outputs instead of discarding them, so
    # regressions on the no-labels code path are caught.
    result = model(input_ids_1)
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertListEqual([mem.shape for mem in result.mems], expected_mems_shapes)

    # Labelled pass: same logits/mems shapes, plus a scalar loss.
    result = model(input_ids_1, labels=token_labels)
    self.parent.assertEqual(result.loss.shape, ())
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertListEqual([mem.shape for mem in result.mems], expected_mems_shapes)
def create_and_check_xlnet_token_classif(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Check XLNetForTokenClassification tuple outputs with and without labels.

    Only ``config``, ``input_ids_1`` and ``token_labels`` are used. The other
    arguments are accepted but ignored (presumably so that every
    ``create_and_check_*`` helper shares one call signature; confirm against
    the caller).

    The model is called twice:
    * without labels, it is unpacked as ``(logits, mems)``;
    * with labels, it is unpacked as ``(loss, logits, mems)``.
    In both cases ``logits`` must have size
    ``[batch_size, seq_length, type_sequence_label_size]`` and ``mems`` must
    hold one ``[seq_length, batch_size, hidden_size]`` tensor per hidden layer.
    ``loss`` must be a 0-d (scalar) tensor.
    """
    model = XLNetForTokenClassification(config)
    model.to(torch_device)
    model.eval()

    # NOTE(review): type_sequence_label_size is used as the number of token
    # classes, so presumably config.num_labels is set to it; confirm in the tester.
    expected_logits_size = [self.batch_size, self.seq_length, self.type_sequence_label_size]
    expected_mems_sizes = [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers

    # Label-free pass: these outputs were previously bound and then overwritten
    # unchecked. Assert on them so the no-labels path is actually tested.
    logits, mems = model(input_ids_1)
    self.parent.assertListEqual(list(logits.size()), expected_logits_size)
    self.parent.assertListEqual([list(mem.size()) for mem in mems], expected_mems_sizes)

    # Labelled pass: a scalar loss is prepended to the output tuple.
    loss, logits, mems = model(input_ids_1, labels=token_labels)
    self.parent.assertListEqual(list(loss.size()), [])
    self.parent.assertListEqual(list(logits.size()), expected_logits_size)
    self.parent.assertListEqual([list(mem.size()) for mem in mems], expected_mems_sizes)