Exemple #1
0
        def check_hidden_states_output(inputs_dict, config, model_class):
            """Run one forward pass of ``model_class`` and validate its hidden states.

            Builds the model from ``config`` and moves it to ``torch_device`` in eval mode.
            It is called under ``torch.no_grad()`` with the inputs produced by
            ``self._prepare_for_class``. The test then checks:

            * the number of hidden states (``expected_num_hidden_layers`` on the model
              tester, or ``len(depths) + 1`` when that attribute is absent);
            * the trailing two dims of the first hidden state, which must equal
              ``[num_patches, embed_dim]``.

            Args:
                inputs_dict: keyword inputs, forwarded through ``self._prepare_for_class``.
                config: configuration used to instantiate the model.
                model_class: the model class under test.
            """
            net = model_class(config)
            net.to(torch_device)
            net.eval()

            with torch.no_grad():
                result = net(**self._prepare_for_class(inputs_dict, model_class))

            if config.is_encoder_decoder:
                states = result.encoder_hidden_states
            else:
                states = result.hidden_states

            # Default expectation: one entry per stage listed in ``depths``, plus one.
            default_count = len(self.model_tester.depths) + 1
            wanted_count = getattr(self.model_tester, "expected_num_hidden_layers", default_count)
            self.assertEqual(len(states), wanted_count)

            # Swin has a different seq_length: it is the number of non-overlapping patches.
            img_hw = to_2tuple(self.model_tester.image_size)
            patch_hw = to_2tuple(self.model_tester.patch_size)
            patch_count = (img_hw[0] // patch_hw[0]) * (img_hw[1] // patch_hw[1])

            expected_tail = [patch_count, self.model_tester.embed_dim]
            self.assertListEqual(list(states[0].shape[-2:]), expected_tail)
        def check_hidden_states_output(inputs_dict, config, model_class):
            """Run one forward pass of ``model_class`` and validate both hidden-state outputs.

            The model is built from ``config``, moved to ``torch_device``, put in eval
            mode and called under ``torch.no_grad()``. The test checks:

            * ``outputs.hidden_states``: the entry count, and that the trailing two dims
              of the first entry are ``[num_patches, embed_dim]``;
            * ``outputs.reshaped_hidden_states``: the same entry count, and that the first
              entry, unpacked as ``(batch, channels, height, width)``, becomes
              ``[num_patches, embed_dim]`` once flattened and transposed.

            The expected entry count is ``expected_num_hidden_layers`` from the model
            tester, or ``len(depths) + 1`` when that attribute is absent.

            Args:
                inputs_dict: keyword inputs, forwarded through ``self._prepare_for_class``.
                config: configuration used to instantiate the model.
                model_class: the model class under test.
            """
            net = model_class(config)
            net.to(torch_device)
            net.eval()

            with torch.no_grad():
                result = net(**self._prepare_for_class(inputs_dict, model_class))

            flat_states = result.hidden_states

            default_count = len(self.model_tester.depths) + 1
            wanted_count = getattr(self.model_tester, "expected_num_hidden_layers", default_count)
            self.assertEqual(len(flat_states), wanted_count)

            # Swin has a different seq_length: it is the number of non-overlapping patches.
            img_hw = to_2tuple(self.model_tester.image_size)
            patch_hw = to_2tuple(self.model_tester.patch_size)
            patch_count = (img_hw[0] // patch_hw[0]) * (img_hw[1] // patch_hw[1])
            expected_tail = [patch_count, self.model_tester.embed_dim]

            self.assertListEqual(list(flat_states[0].shape[-2:]), expected_tail)

            spatial_states = result.reshaped_hidden_states
            self.assertEqual(len(spatial_states), wanted_count)

            first_spatial = spatial_states[0]
            n_batch, n_chan, n_rows, n_cols = first_spatial.shape
            # Merge the two spatial dims and move channels last, so the first entry is
            # comparable with the [num_patches, embed_dim] tail checked above.
            as_sequence = first_spatial.view(n_batch, n_chan, n_rows * n_cols).permute(0, 2, 1)
            self.assertListEqual(list(as_sequence.shape[-2:]), expected_tail)
Exemple #3
0
    def test_hidden_states_output_with_padding(self):
        """Check hidden-state outputs with ``config.patch_size`` forced to 3.

        Each image dimension is rounded up to the next multiple of the patch size.
        That padded ``(height, width)`` is passed to ``self.check_hidden_states_output``
        for every model class twice: once with ``output_hidden_states`` supplied as a
        forward input, and once with it enabled only through the config.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.patch_size = 3

        image_size = to_2tuple(self.model_tester.image_size)
        patch_size = to_2tuple(config.patch_size)

        # Round each dimension up to the next multiple of the patch size.
        # In Python, ``(-x) % p`` lies in [0, p) and is 0 when x is already a multiple
        # of p, so no spare patch is added in that case. The ``x + p - x % p`` form
        # would add a full extra patch there. Both forms agree whenever x % p != 0.
        padded_height = image_size[0] + (-image_size[0]) % patch_size[0]
        padded_width = image_size[1] + (-image_size[1]) % patch_size[1]

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            self.check_hidden_states_output(inputs_dict, config, model_class,
                                            (padded_height, padded_width))

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True
            self.check_hidden_states_output(inputs_dict, config, model_class,
                                            (padded_height, padded_width))
Exemple #4
0
    def test_hidden_states_output(self):
        """Validate hidden states for every model class at the tester's image size.

        For each class, hidden states are first requested through the
        ``output_hidden_states`` forward input. That input is then removed and the
        flag is set on ``config`` instead. ``self.check_hidden_states_output`` runs
        after each step.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        size_hw = to_2tuple(self.model_tester.image_size)

        for model_class in self.all_model_classes:
            for through_config in (False, True):
                if through_config:
                    # check that output_hidden_states also work using config
                    del inputs_dict["output_hidden_states"]
                    config.output_hidden_states = True
                else:
                    inputs_dict["output_hidden_states"] = True
                self.check_hidden_states_output(inputs_dict, config, model_class, size_hw)
Exemple #5
0
    def test_attention_outputs(self):
        """Check the attention outputs of every model class.

        Each model class is run three times, each time from a freshly built model:

        1. ``output_attentions`` is passed as a forward input. The test checks that one
           attention tensor is returned per entry in ``self.model_tester.depths``.
        2. ``output_attentions`` is enabled only via ``config``. The test checks the same
           count plus the trailing shape of the first attention tensor.
        3. Both attentions and hidden states are requested. The test checks that the
           output gained the expected number of extra fields, then repeats the
           count/shape checks.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        chunk_length = getattr(self.model_tester, "chunk_length", None)
        # ``config.window_size`` is not modified in the loop, so compute this once.
        window_size_squared = config.window_size**2
        num_heads = self.model_tester.num_heads[0]
        if chunk_length is not None:
            expected_shape = [num_heads, window_size_squared, chunk_length, window_size_squared]
        else:
            expected_shape = [num_heads, window_size_squared, window_size_squared]

        def run_model(model_class):
            # Build a fresh model from the current ``config`` and run one no-grad forward pass.
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                return model(**self._prepare_for_class(inputs_dict, model_class))

        def get_attentions(outputs):
            return outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

        def check_attentions(attentions):
            # One attention tensor per stage; the first one has the expected trailing dims.
            self.assertEqual(len(attentions), len(self.model_tester.depths))
            self.assertListEqual(list(attentions[0].shape[-len(expected_shape):]), expected_shape)

        for model_class in self.all_model_classes:
            # 1) Request attentions through the forward inputs.
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            attentions = get_attentions(run_model(model_class))
            self.assertEqual(len(attentions), len(self.model_tester.depths))

            # 2) Check that output_attentions also works using the config.
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            outputs = run_model(model_class)
            check_attentions(get_attentions(outputs))
            out_len = len(outputs)

            # 3) Requesting hidden states as well should add ``added_hidden_states``
            #    output fields and leave the attentions unchanged in count and shape.
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            outputs = run_model(model_class)

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            # NOTE(review): this reads ``self.is_encoder_decoder``, while attentions are
            # selected via ``config.is_encoder_decoder``; confirm the two always agree.
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            check_attentions(get_attentions(outputs))