def check_encoder_decoder_model_output_attentions(
    self,
    config,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    labels=None,
    pixel_values=None,
    **kwargs
):
    """Check the count and trailing shape of the encoder, decoder and cross attentions.

    Builds a ``VisionEncoderDecoderModel`` from the pair returned by
    ``self.get_encoder_decoder_model``, runs a single forward pass with
    ``output_attentions=True`` and asserts on the three attention outputs.

    Args:
        config: Encoder configuration; ``num_hidden_layers`` and ``num_attention_heads`` are read.
        decoder_config: Decoder configuration; ``num_attention_heads`` and either
            ``num_decoder_layers`` or ``num_hidden_layers`` are read.
        decoder_input_ids: Decoder inputs; sliced with ``[:, :-1]`` before the forward pass.
        decoder_attention_mask: Decoder mask; sliced the same way as ``decoder_input_ids``.
        labels: Accepted but not used here.
        pixel_values: Forwarded to the model as ``pixel_values``.
        **kwargs: Accepted and ignored.
    """
    # make the decoder inputs a different shape from the encoder inputs to harden the test
    trimmed_ids = decoder_input_ids[:, :-1]
    trimmed_mask = decoder_attention_mask[:, :-1]

    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    enc_dec_model.to(torch_device)
    outputs = enc_dec_model(
        pixel_values=pixel_values,
        decoder_input_ids=trimmed_ids,
        decoder_attention_mask=trimmed_mask,
        output_attentions=True,
    )

    # --- encoder self-attention -------------------------------------------------------
    encoder_attentions = outputs["encoder_attentions"]
    self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

    # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
    image_dims = to_2tuple(encoder_model.config.image_size)
    patch_dims = to_2tuple(encoder_model.config.patch_size)
    patch_count = (image_dims[1] // patch_dims[1]) * (image_dims[0] // patch_dims[0])
    seq_len = patch_count + 1
    expected_encoder_shape = (config.num_attention_heads, seq_len, seq_len)
    self.assertEqual(encoder_attentions[0].shape[-3:], expected_encoder_shape)

    # --- decoder self-attention -------------------------------------------------------
    decoder_attentions = outputs["decoder_attentions"]
    # Some decoder configs expose `num_decoder_layers`; fall back to `num_hidden_layers` otherwise.
    if hasattr(decoder_config, "num_decoder_layers"):
        expected_decoder_layers = decoder_config.num_decoder_layers
    else:
        expected_decoder_layers = decoder_config.num_hidden_layers
    self.assertEqual(len(decoder_attentions), expected_decoder_layers)

    decoder_len = trimmed_ids.shape[-1]
    expected_decoder_shape = (decoder_config.num_attention_heads, decoder_len, decoder_len)
    self.assertEqual(decoder_attentions[0].shape[-3:], expected_decoder_shape)

    # --- cross-attention: decoder positions attend over the encoder sequence ----------
    cross_attentions = outputs["cross_attentions"]
    self.assertEqual(len(cross_attentions), expected_decoder_layers)

    expected_cross_shape = (decoder_config.num_attention_heads, decoder_len, seq_len)
    self.assertEqual(cross_attentions[0].shape[-3:], expected_cross_shape)
def create_and_check_model(self, config, pixel_values, labels):
    """Run ``ViTModel`` on ``pixel_values`` and check the shape of ``last_hidden_state``.

    Args:
        config: Configuration used to build the model.
        pixel_values: Input passed positionally to the model.
        labels: Accepted but not used here.
    """
    vit = ViTModel(config=config)
    vit.to(torch_device)
    vit.eval()
    output = vit(pixel_values)

    # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
    image_dims = to_2tuple(self.image_size)
    patch_dims = to_2tuple(self.patch_size)
    patch_count = (image_dims[1] // patch_dims[1]) * (image_dims[0] // patch_dims[0])
    expected_shape = (self.batch_size, patch_count + 1, self.hidden_size)
    self.parent.assertEqual(output.last_hidden_state.shape, expected_shape)
def create_and_check_for_pretraining(self, config, pixel_values, labels):
    """Run ``ViTMAEForPreTraining`` on ``pixel_values`` and check the shape of ``logits``.

    The expected shape is ``(batch_size, num_patches, patch_area * num_channels)``, where
    ``num_patches`` and ``patch_area`` are both derived from ``self.image_size`` and
    ``self.patch_size`` after normalizing them with ``to_2tuple``.

    Args:
        config: Configuration used to build the model.
        pixel_values: Input passed positionally to the model.
        labels: Accepted but not used here.
    """
    model = ViTMAEForPreTraining(config)
    model.to(torch_device)
    model.eval()
    result = model(pixel_values)

    # expected sequence length = num_patches
    image_size = to_2tuple(self.image_size)
    patch_size = to_2tuple(self.patch_size)
    num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
    expected_seq_len = num_patches

    # Derive the patch area from the normalized pair rather than `self.patch_size ** 2`:
    # the latter raises TypeError when the patch size is given as a tuple, even though the
    # lines above already accept that form.
    # NOTE(review): assumes to_2tuple maps a scalar p to (p, p), so scalar patch sizes still
    # yield p ** 2 here -- confirm against its definition.
    patch_area = patch_size[0] * patch_size[1]
    expected_num_channels = patch_area * self.num_channels
    self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
def create_and_check_model(self, config, pixel_values, labels):
    """Run ``ViTMAEModel`` on ``pixel_values`` and check the shape of ``last_hidden_state``.

    Args:
        config: Configuration used to build the model; ``mask_ratio`` is read from it.
        pixel_values: Input passed positionally to the model.
        labels: Accepted but not used here.
    """
    mae = ViTMAEModel(config=config)
    mae.to(torch_device)
    mae.eval()
    output = mae(pixel_values)

    # expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
    # (we add 1 for the [CLS] token)
    image_dims = to_2tuple(self.image_size)
    patch_dims = to_2tuple(self.patch_size)
    patch_count = (image_dims[1] // patch_dims[1]) * (image_dims[0] // patch_dims[0])
    kept_fraction = 1 - config.mask_ratio
    expected_seq_len = int(math.ceil(kept_fraction * (patch_count + 1)))

    expected_shape = (self.batch_size, expected_seq_len, self.hidden_size)
    self.parent.assertEqual(output.last_hidden_state.shape, expected_shape)
def test_attention_outputs(self):
    """Check the number and trailing shape of the attention outputs for every model class.

    Each class in ``self.all_model_classes`` is instantiated and run three times:

    1. with ``output_attentions`` requested through the call arguments,
    2. with ``output_attentions`` requested through ``config`` only,
    3. with both attentions and hidden states requested, to check how many extra
       outputs are returned on top of the second run.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.return_dict = True

    # in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above
    image_dims = to_2tuple(self.model_tester.image_size)
    patch_dims = to_2tuple(self.model_tester.patch_size)
    patch_count = (image_dims[1] // patch_dims[1]) * (image_dims[0] // patch_dims[0])
    default_seq_len = int(math.ceil((1 - config.mask_ratio) * (patch_count + 1)))

    # The model tester may override the lengths computed above.
    encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", default_seq_len)
    encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
    chunk_length = getattr(self.model_tester, "chunk_length", None)
    if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
        encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

    def run_model(model_class):
        # Build a fresh model from the current `config` and run it in eval mode without
        # gradients on the current `inputs_dict`.
        model = model_class(config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            return model(**self._prepare_for_class(inputs_dict, model_class))

    def pick_attentions(model_outputs):
        # Encoder-decoder configs expose the attentions under `encoder_attentions`.
        if config.is_encoder_decoder:
            return model_outputs.encoder_attentions
        return model_outputs.attentions

    def check_first_layer_shape(attention_layers):
        # With a chunk length the last four dims are compared, otherwise the last three.
        if chunk_length is None:
            self.assertListEqual(
                list(attention_layers[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )
        else:
            self.assertListEqual(
                list(attention_layers[0].shape[-4:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
            )

    for model_class in self.all_model_classes:
        # 1) attentions requested through the call arguments
        inputs_dict["output_attentions"] = True
        inputs_dict["output_hidden_states"] = False
        config.return_dict = True
        outputs = run_model(model_class)
        attentions = pick_attentions(outputs)
        self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

        # 2) check that output_attentions also work using config
        # NOTE: `config.output_attentions` is not reset afterwards, so it stays True for
        # the following model classes as well.
        del inputs_dict["output_attentions"]
        config.output_attentions = True
        outputs = run_model(model_class)
        attentions = pick_attentions(outputs)
        self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
        check_first_layer_shape(attentions)
        out_len = len(outputs)

        # 3) Check attention is always last and order is fine
        inputs_dict["output_attentions"] = True
        inputs_dict["output_hidden_states"] = True
        outputs = run_model(model_class)

        # Number of outputs expected on top of the attention-only run.
        if hasattr(self.model_tester, "num_hidden_states_types"):
            added_hidden_states = self.model_tester.num_hidden_states_types
        elif self.is_encoder_decoder:
            added_hidden_states = 2
        else:
            added_hidden_states = 1
        self.assertEqual(out_len + added_hidden_states, len(outputs))

        self_attentions = pick_attentions(outputs)
        self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
        check_first_layer_shape(self_attentions)