def create_and_check_xlnet_token_classif(
        self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
        input_mask, target_mapping, segment_ids, lm_labels,
        sequence_labels, is_impossible_labels, token_labels):
        """Check the output shapes of ``XLNetForTokenClassification``.

        Builds the model from ``config``, moves it to ``torch_device``,
        switches it to eval mode and runs it on ``input_ids_1`` twice: once
        without labels and once with ``token_labels``. The labelled output
        is then checked for:

        * a loss whose shape is the empty tuple,
        * logits of shape
          ``(batch_size, seq_length, type_sequence_label_size)``,
        * one ``mems`` entry per hidden layer, each of shape
          ``(seq_length, batch_size, hidden_size)``.

        Only ``config``, ``input_ids_1`` and ``token_labels`` are read here;
        the remaining parameters are not referenced in this body
        (presumably kept so every check shares one signature -- confirm
        against the caller).
        """
        classifier = XLNetForTokenClassification(config)
        classifier.to(torch_device)
        classifier.eval()

        # Label-free forward pass: the output is discarded, so this only
        # verifies that the call itself completes.
        classifier(input_ids_1)
        outputs = classifier(input_ids_1, labels=token_labels)

        # The loss is expected to have an empty shape.
        self.parent.assertEqual(outputs.loss.shape, ())

        logits_shape = outputs.logits.shape
        expected_logits_shape = (
            self.batch_size,
            self.seq_length,
            self.type_sequence_label_size,
        )
        self.parent.assertEqual(logits_shape, expected_logits_shape)

        mem_shapes = [layer_mem.shape for layer_mem in outputs.mems]
        expected_mem_shape = (
            self.seq_length,
            self.batch_size,
            self.hidden_size,
        )
        self.parent.assertListEqual(
            mem_shapes,
            [expected_mem_shape] * self.num_hidden_layers,
        )
示例#2
0
        def create_and_check_xlnet_token_classif(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
            token_labels,
        ):
            """Check the output sizes of ``XLNetForTokenClassification``.

            Builds the model from ``config``, moves it to ``torch_device``,
            switches it to eval mode and runs it on ``input_ids_1`` twice.
            The label-free call must unpack into exactly two values and the
            call with ``token_labels`` into exactly three
            ``(loss, logits, mems)``. The labelled outputs are then checked
            for:

            * a loss whose size is empty,
            * logits of size
              ``[batch_size, seq_length, type_sequence_label_size]``,
            * one mems entry per hidden layer, each of size
              ``[seq_length, batch_size, hidden_size]``.

            Only ``config``, ``input_ids_1`` and ``token_labels`` are read
            here; the remaining parameters are not referenced in this body
            (presumably kept so every check shares one signature -- confirm
            against the caller).
            """
            classifier = XLNetForTokenClassification(config)
            classifier.to(torch_device)
            classifier.eval()

            # The first unpacking is kept even though its values are
            # overwritten: it fails if the label-free output does not have
            # exactly two elements.
            logits, mems_1 = classifier(input_ids_1)
            loss, logits, mems_1 = classifier(input_ids_1, labels=token_labels)

            result = {
                "loss": loss,
                "mems_1": mems_1,
                "logits": logits,
            }

            # An empty size list means the loss has no dimensions.
            self.parent.assertListEqual(list(result["loss"].size()), [])

            logits_dims = list(result["logits"].size())
            expected_logits_dims = [
                self.batch_size,
                self.seq_length,
                self.type_sequence_label_size,
            ]
            self.parent.assertListEqual(logits_dims, expected_logits_dims)

            mem_dims = [list(mem.size()) for mem in result["mems_1"]]
            expected_mem_dims = [
                self.seq_length,
                self.batch_size,
                self.hidden_size,
            ]
            self.parent.assertListEqual(
                mem_dims,
                [expected_mem_dims] * self.num_hidden_layers,
            )