Beispiel #1
0
        def create_and_check_xlnet_lm_head(
            self,
            config,
            input_ids_1,
            input_ids_2,
            input_ids_q,
            perm_mask,
            input_mask,
            target_mapping,
            segment_ids,
            lm_labels,
            sequence_labels,
            is_impossible_labels,
            token_labels,
        ):
            """Check the output shapes produced by ``XLNetLMHeadModel``.

            The model is built from ``config``, moved to ``torch_device`` and
            switched to eval mode. It is then called three times:

            1. on ``input_ids_1`` with ``segment_ids`` and ``lm_labels``;
            2. on ``input_ids_2`` with the same segment ids and labels, plus
               the ``mems`` returned by the first call;
            3. on ``input_ids_q`` with ``perm_mask`` and ``target_mapping``
               (the result is unpacked into two values but not inspected).

            For the first two calls the loss must have an empty size, the
            logits must be sized ``[batch_size, seq_length, vocab_size]`` and
            there must be ``num_hidden_layers`` memory tensors, each sized
            ``[seq_length, batch_size, hidden_size]`` after the first call and
            ``[mem_len, batch_size, hidden_size]`` after the second.

            ``input_mask``, ``sequence_labels``, ``is_impossible_labels`` and
            ``token_labels`` are accepted but not used by this check.
            """
            lm_model = XLNetLMHeadModel(config)
            lm_model.to(torch_device)
            lm_model.eval()

            first_loss, first_logits, first_mems = lm_model(
                input_ids_1,
                token_type_ids=segment_ids,
                labels=lm_labels,
            )
            second_loss, second_logits, second_mems = lm_model(
                input_ids_2,
                token_type_ids=segment_ids,
                labels=lm_labels,
                mems=first_mems,
            )
            # Only the two-element unpacking is exercised for this call; the
            # values themselves are never checked.
            _query_logits, _unused = lm_model(
                input_ids_q,
                perm_mask=perm_mask,
                target_mapping=target_mapping,
            )

            assert_list_equal = self.parent.assertListEqual

            # First call, made without any incoming mems.
            assert_list_equal(list(first_loss.size()), [])
            assert_list_equal(
                list(first_logits.size()),
                [self.batch_size, self.seq_length, self.vocab_size],
            )
            first_mem_sizes = [list(mem.size()) for mem in first_mems]
            first_mem_expected = [self.seq_length, self.batch_size, self.hidden_size]
            assert_list_equal(
                first_mem_sizes,
                [first_mem_expected] * self.num_hidden_layers,
            )

            # Second call, fed with the mems of the first call.
            assert_list_equal(list(second_loss.size()), [])
            assert_list_equal(
                list(second_logits.size()),
                [self.batch_size, self.seq_length, self.vocab_size],
            )
            second_mem_sizes = [list(mem.size()) for mem in second_mems]
            second_mem_expected = [self.mem_len, self.batch_size, self.hidden_size]
            assert_list_equal(
                second_mem_sizes,
                [second_mem_expected] * self.num_hidden_layers,
            )
    def create_and_check_xlnet_lm_head(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """Check the output shapes produced by ``XLNetLMHeadModel``.

        The model is built from ``config``, moved to ``torch_device`` and
        switched to eval mode. It is then called three times:

        1. on ``input_ids_1`` with ``segment_ids`` and ``lm_labels``;
        2. on ``input_ids_2`` with the same segment ids and labels, plus the
           ``mems`` attribute of the first call's output;
        3. on ``input_ids_q`` with ``perm_mask`` and ``target_mapping`` (the
           output is discarded; the call only has to complete).

        For the first two outputs ``loss.shape`` must equal ``()``,
        ``logits.shape`` must equal ``(batch_size, seq_length, vocab_size)``
        and ``mems`` must hold ``num_hidden_layers`` tensors, each shaped
        ``(seq_length, batch_size, hidden_size)`` for the first output and
        ``(mem_len, batch_size, hidden_size)`` for the second.

        ``input_mask``, ``sequence_labels``, ``is_impossible_labels`` and
        ``token_labels`` are accepted but not used by this check.
        """
        lm_model = XLNetLMHeadModel(config)
        lm_model.to(torch_device)
        lm_model.eval()

        first_pass = lm_model(
            input_ids_1,
            token_type_ids=segment_ids,
            labels=lm_labels,
        )
        second_pass = lm_model(
            input_ids_2,
            token_type_ids=segment_ids,
            labels=lm_labels,
            mems=first_pass.mems,
        )
        # The query-stream call is run for its own sake; its output is ignored.
        lm_model(
            input_ids_q,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
        )

        check_equal = self.parent.assertEqual
        check_list_equal = self.parent.assertListEqual

        # First call, made without any incoming mems.
        check_equal(first_pass.loss.shape, ())
        check_equal(
            first_pass.logits.shape,
            (self.batch_size, self.seq_length, self.vocab_size),
        )
        first_mem_shapes = [mem.shape for mem in first_pass.mems]
        first_mem_expected = (self.seq_length, self.batch_size, self.hidden_size)
        check_list_equal(
            first_mem_shapes,
            [first_mem_expected] * self.num_hidden_layers,
        )

        # Second call, fed with the mems of the first call.
        check_equal(second_pass.loss.shape, ())
        check_equal(
            second_pass.logits.shape,
            (self.batch_size, self.seq_length, self.vocab_size),
        )
        second_mem_shapes = [mem.shape for mem in second_pass.mems]
        second_mem_expected = (self.mem_len, self.batch_size, self.hidden_size)
        check_list_equal(
            second_mem_shapes,
            [second_mem_expected] * self.num_hidden_layers,
        )