def create_and_check_xlnet_sequence_classif(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Check the output shapes of ``XLNetForSequenceClassification`` (tuple outputs).

    The model is built from ``config``, moved to ``torch_device`` and put in
    eval mode. It is called once without labels and once with
    ``sequence_labels``; the loss, logits and mems of the labelled call are
    checked against the sizes stored on ``self``.

    Only ``config``, ``input_ids_1`` and ``sequence_labels`` are read here;
    the remaining parameters are accepted but unused in this function.
    """
    classifier = XLNetForSequenceClassification(config)
    classifier.to(torch_device)
    classifier.eval()

    # Unpacking doubles as an arity check: two outputs without labels,
    # three (loss first) when labels are supplied.
    _logits, _mems = classifier(input_ids_1)
    loss, logits, mems = classifier(input_ids_1, labels=sequence_labels)

    # The loss is expected to have an empty size (no dimensions).
    self.parent.assertListEqual(list(loss.size()), [])

    self.parent.assertListEqual(
        list(logits.size()),
        [self.batch_size, self.type_sequence_label_size],
    )

    # One mem entry per hidden layer, all with the same size.
    per_layer_size = [self.seq_length, self.batch_size, self.hidden_size]
    self.parent.assertListEqual(
        [list(mem.size()) for mem in mems],
        [per_layer_size] * self.num_hidden_layers,
    )
def create_and_check_xlnet_sequence_classif(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Check the output shapes of ``XLNetForSequenceClassification``.

    The model is built from ``config``, moved to ``torch_device`` and put in
    eval mode. It is called once without labels and once with
    ``sequence_labels``; the ``loss``, ``logits`` and ``mems`` attributes of
    the labelled call's output are checked against the sizes stored on
    ``self``.

    Only ``config``, ``input_ids_1`` and ``sequence_labels`` are read here;
    the remaining parameters are accepted but unused in this function.
    """
    classifier = XLNetForSequenceClassification(config)
    classifier.to(torch_device)
    classifier.eval()

    # Forward pass without labels; its output is not inspected.
    classifier(input_ids_1)
    outputs = classifier(input_ids_1, labels=sequence_labels)

    # The loss is expected to have an empty shape (no dimensions).
    self.parent.assertEqual(outputs.loss.shape, ())

    expected_logits_shape = (self.batch_size, self.type_sequence_label_size)
    self.parent.assertEqual(outputs.logits.shape, expected_logits_shape)

    # One mem entry per hidden layer, all with the same shape.
    mem_shapes = [mem.shape for mem in outputs.mems]
    expected_mem_shape = (self.seq_length, self.batch_size, self.hidden_size)
    self.parent.assertListEqual(
        mem_shapes,
        [expected_mem_shape] * self.num_hidden_layers,
    )
def create_and_check_use_mems_train(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Check that train-mode forward passes yield a loss and no mems.

    The model is built from ``config``, moved to ``torch_device`` and put in
    train mode. ``input_ids_1`` and ``sequence_labels`` are split along their
    first dimension into consecutive, non-overlapping mini-batches of at most
    four rows. For each mini-batch the output must have ``mems`` equal to
    ``None`` and a ``loss`` that is not ``None``.

    Only ``config``, ``input_ids_1`` and ``sequence_labels`` are read here;
    the remaining parameters are accepted but unused in this function.
    """
    model = XLNetForSequenceClassification(config)
    model.to(torch_device)
    model.train()

    train_size = input_ids_1.shape[0]
    batch_size = 4

    # Advance the start index by a full batch each step. Slicing from `i`
    # instead of `i * batch_size` produced overlapping, growing slices, and
    # iterating `train_size // batch_size + 1` times could add an empty
    # trailing batch once the start offset is correct.
    for start in range(0, train_size, batch_size):
        stop = start + batch_size
        input_ids = input_ids_1[start:stop]
        labels = sequence_labels[start:stop]

        outputs = model(input_ids=input_ids, labels=labels, return_dict=True)

        self.parent.assertIsNone(outputs.mems)
        self.parent.assertIsNotNone(outputs.loss)