 def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels=None):
     """Check the logits shape produced by ``BeitForMaskedImageModeling``.

     Builds the model from ``config``, moves it to ``torch_device``, switches it to
     eval mode, runs it on ``pixel_values`` and asserts that ``result.logits`` has
     shape ``(self.batch_size, num_patches, self.vocab_size)``, where
     ``num_patches`` is derived from ``self.image_size`` and ``self.patch_size``.

     Args:
         config: configuration used to construct ``BeitForMaskedImageModeling``.
         pixel_values: input passed to the model's forward call.
         labels: not used by this check; presumably accepted so the tester's
             prepared inputs can be splatted in directly (confirm against caller).
         pixel_labels: not used by this check; optional so that callers passing
             either three or four prepared inputs are both accepted.
     """
     model = BeitForMaskedImageModeling(config=config)
     model.to(torch_device)
     model.eval()
     result = model(pixel_values)
     # expected sequence length = num_patches, i.e. the number of non-overlapping
     # patch_size tiles along height times the number along width.
     image_size = to_2tuple(self.image_size)
     patch_size = to_2tuple(self.patch_size)
     num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
     self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))