def create_and_check_xlnet_base_model_with_att_output(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
        token_labels,
    ):
        """Check the attentions ``XLNetModel`` returns when a target mapping is given.

        Runs the base model on ``input_ids_1`` with ``target_mapping`` and
        ``output_attentions=True`` and verifies that:

        * one attention entry is returned per layer (``config.n_layer``),
        * the first layer's entry is a tuple with exactly two members,
        * both members of that tuple have the same shape.

        Only ``config``, ``input_ids_1`` and ``target_mapping`` are read here;
        the remaining parameters are accepted but unused (presumably so that
        every check shares one signature -- confirm against the caller).
        """
        model = XLNetModel(config)
        model.to(torch_device)
        model.eval()

        outputs = model(
            input_ids_1,
            target_mapping=target_mapping,
            output_attentions=True,
        )
        attentions = outputs["attentions"]

        self.parent.assertEqual(len(attentions), config.n_layer)
        self.parent.assertIsInstance(attentions[0], tuple)
        self.parent.assertEqual(len(attentions[0]), 2)
        # Use assertEqual on the two members: assertTrue(a, b) would treat
        # ``b`` as the failure message and only test the truthiness of ``a``,
        # so it could not detect a shape mismatch.
        # NOTE(review): assumes both attention streams share one shape when
        # target_mapping is supplied -- confirm against the model implementation.
        self.parent.assertEqual(attentions[0][0].shape, attentions[0][1].shape)
Ejemplo n.º 2
0
        def create_and_check_xlnet_base_model(
                self,
                config,
                input_ids_1,
                input_ids_2,
                input_ids_q,
                perm_mask,
                input_mask,
                target_mapping,
                segment_ids,
                lm_labels,
                sequence_labels,
                is_impossible_labels,
                token_labels,
        ):
            """Smoke-test the ``XLNetModel`` forward pass and check output shapes.

            The model is called once per optional keyword (``input_mask``,
            ``attention_mask``, ``token_type_ids``) and once with the ids only;
            every call must unpack into exactly two values. The model is then
            rebuilt with ``config.mem_len = 0`` and its output must have length 1.
            Finally the hidden states and the memories from the ids-only call
            are compared against the shapes derived from the tester attributes.

            Note: ``config.mem_len`` is overwritten in place.
            """
            model = XLNetModel(config)
            model.to(torch_device)
            model.eval()

            # Each optional input is exercised on its own; results are discarded,
            # but the two-value unpacking still has to succeed.
            optional_inputs = (
                {"input_mask": input_mask},
                {"attention_mask": input_mask},
                {"token_type_ids": segment_ids},
            )
            for extra_kwargs in optional_inputs:
                _, _ = model(input_ids_1, **extra_kwargs)
            hidden_states, mems = model(input_ids_1)

            # With memory disabled the model should return a single output.
            config.mem_len = 0
            memoryless_model = XLNetModel(config)
            memoryless_model.to(torch_device)
            memoryless_model.eval()
            memoryless_outputs = memoryless_model(input_ids_1)
            self.parent.assertEqual(len(memoryless_outputs), 1)

            expected_hidden_shape = [
                self.batch_size, self.seq_length, self.hidden_size]
            self.parent.assertListEqual(
                list(hidden_states.size()), expected_hidden_shape)

            # One memory tensor per layer, laid out as (seq, batch, hidden).
            expected_mem_shapes = [
                [self.seq_length, self.batch_size, self.hidden_size]
                for _ in range(self.num_hidden_layers)
            ]
            actual_mem_shapes = [list(mem.size()) for mem in mems]
            self.parent.assertListEqual(actual_mem_shapes, expected_mem_shapes)
Ejemplo n.º 3
0
    def create_and_check_xlnet_base_model(
            self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
            input_mask, target_mapping, segment_ids, lm_labels,
            sequence_labels, is_impossible_labels, token_labels):
        """Smoke-test the ``XLNetModel`` forward pass and check output shapes.

        The model is called once per optional keyword (``input_mask``,
        ``attention_mask``, ``token_type_ids``) and once with the ids only.
        It is then rebuilt with ``config.mem_len = 0`` and that model's output
        must have length 2. Finally ``last_hidden_state`` and ``mems`` from
        the ids-only call on the first model are compared against the shapes
        derived from the tester attributes.

        Note: ``config.mem_len`` is overwritten in place.
        """
        model = XLNetModel(config)
        model.to(torch_device)
        model.eval()

        # These calls only need to run without error; their results are unused.
        optional_inputs = (
            {"input_mask": input_mask},
            {"attention_mask": input_mask},
            {"token_type_ids": segment_ids},
        )
        for extra_kwargs in optional_inputs:
            model(input_ids_1, **extra_kwargs)
        result = model(input_ids_1)

        # Rebuild the model with mem_len set to 0 and check its output length.
        config.mem_len = 0
        zero_mem_model = XLNetModel(config)
        zero_mem_model.to(torch_device)
        zero_mem_model.eval()
        self.parent.assertEqual(len(zero_mem_model(input_ids_1)), 2)

        expected_hidden_shape = [
            self.batch_size, self.seq_length, self.hidden_size]
        self.parent.assertListEqual(
            list(result["last_hidden_state"].size()), expected_hidden_shape)

        # One memory tensor per layer, laid out as (seq, batch, hidden).
        expected_mem_shapes = [
            [self.seq_length, self.batch_size, self.hidden_size]
            for _ in range(self.num_hidden_layers)
        ]
        actual_mem_shapes = [list(mem.size()) for mem in result["mems"]]
        self.parent.assertListEqual(actual_mem_shapes, expected_mem_shapes)
    def create_and_check_xlnet_model_use_cache(
            self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
            input_mask, target_mapping, segment_ids, lm_labels,
            sequence_labels, is_impossible_labels, token_labels):
        """Check that reusing cached ``mems`` matches a full forward pass.

        First the model is run with ``use_cache=True``, ``use_cache=False`` and
        with its defaults; the cached output must be as long as the default
        output and one element longer than the uncached one. Then one extra
        token is appended to ``input_ids_1`` and the hidden state for that
        token is computed twice: from the full sequence, and from the single
        token plus the ``mems`` of the first pass. One randomly chosen hidden
        feature of both results must agree within ``atol=1e-3``.
        """
        batch, seq_len = input_ids_1.shape[0], input_ids_1.shape[1]

        def upper_triangular_mask(length):
            # Float mask of ones on and above the diagonal, one per batch row.
            ones = torch.ones(
                batch, length, length, dtype=torch.float, device=torch_device)
            return torch.triu(ones, diagonal=0)

        model = XLNetModel(config=config)
        model.to(torch_device)
        model.eval()

        # First forward pass: cached, uncached and default configuration.
        first_mask = upper_triangular_mask(seq_len)
        outputs_cache = model(
            input_ids_1, use_cache=True, perm_mask=first_mask)
        outputs_no_cache = model(
            input_ids_1, use_cache=False, perm_mask=first_mask)
        outputs_conf = model(input_ids_1)

        self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
        self.parent.assertTrue(len(outputs_cache) == len(outputs_no_cache) + 1)

        # The cached output must unpack into exactly two values; only the
        # memories are needed below.
        _, mems = outputs_cache.to_tuple()

        # Create a hypothetical next token and extend the input ids with it.
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
        next_input_ids = torch.cat([input_ids_1, next_tokens], dim=-1)

        # Masks for the extended sequence and for the single new token.
        extended_mask = upper_triangular_mask(seq_len + 1)
        single_mask = torch.ones(
            batch, 1, 1, dtype=torch.float, device=torch_device)

        # Second forward pass: full sequence versus new token plus memories.
        full_pass = model(next_input_ids, perm_mask=extended_mask)
        output_from_no_past = full_pass["last_hidden_state"]
        cached_pass = model(next_tokens, mems=mems, perm_mask=single_mask)
        output_from_past = cached_pass["last_hidden_state"]

        # Compare one randomly selected hidden feature at the new position.
        feature_idx = ids_tensor((1, ), output_from_past.shape[-1]).item()
        expected_slice = output_from_no_past[:, -1, feature_idx].detach()
        cached_slice = output_from_past[:, 0, feature_idx].detach()

        self.parent.assertTrue(
            torch.allclose(cached_slice, expected_slice, atol=1e-3))