Example 1

import json
import os
import shutil
import tempfile
import unittest
from typing import List

from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.models.layoutlmv2 import LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast
from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_pytesseract, require_tokenizers, require_torch, slow
from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available

if is_pytesseract_available():
    from PIL import Image

    from transformers import LayoutLMv2FeatureExtractor, LayoutLMv2Processor


@require_pytesseract
@require_tokenizers
class LayoutLMv2ProcessorTest(unittest.TestCase):
    tokenizer_class = LayoutLMv2Tokenizer
    rust_tokenizer_class = LayoutLMv2TokenizerFast

    def setUp(self):
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
Example 2
import unittest

import numpy as np

from transformers.testing_utils import require_pytesseract, require_torch
from transformers.utils import is_pytesseract_available, is_torch_available

from .test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs

if is_torch_available():
    import torch

if is_pytesseract_available():
    from PIL import Image

    from transformers import LayoutLMv2FeatureExtractor


# LayoutLMv2FeatureExtractionTester (the config helper used in setUp below) is
# defined earlier in the original test module and is not part of this excerpt.
@require_torch
@require_pytesseract
class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):

    feature_extraction_class = LayoutLMv2FeatureExtractor if is_pytesseract_available() else None

    def setUp(self):
        self.feature_extract_tester = LayoutLMv2FeatureExtractionTester(self)

    @property
    def feat_extract_dict(self):
        return self.feature_extract_tester.prepare_feat_extract_dict()

    def test_feat_extract_properties(self):
        feature_extractor = self.feature_extraction_class(
            **self.feat_extract_dict)
        self.assertTrue(hasattr(feature_extractor, "do_resize"))
        self.assertTrue(hasattr(feature_extractor, "size"))
        self.assertTrue(hasattr(feature_extractor, "apply_ocr"))

    def test_batch_feature(self):
        pass

    def test_call_pil(self):
        # Initialize feature_extractor
        feature_extractor = self.feature_extraction_class(
            **self.feat_extract_dict)
        # create random PIL images
        image_inputs = prepare_image_inputs(self.feature_extract_tester,
                                            equal_resolution=False)
        for image in image_inputs:
            self.assertIsInstance(image, Image.Image)

        # Test not batched input
        encoding = feature_extractor(image_inputs[0], return_tensors="pt")
        self.assertEqual(
            encoding.pixel_values.shape,
            (
                1,
                self.feature_extract_tester.num_channels,
                self.feature_extract_tester.size,
                self.feature_extract_tester.size,
            ),
        )

        self.assertIsInstance(encoding.words, list)
        self.assertIsInstance(encoding.boxes, list)

        # Test batched
        encoded_images = feature_extractor(image_inputs,
                                           return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                self.feature_extract_tester.batch_size,
                self.feature_extract_tester.num_channels,
                self.feature_extract_tester.size,
                self.feature_extract_tester.size,
            ),
        )

    def test_call_numpy(self):
        # Initialize feature_extractor
        feature_extractor = self.feature_extraction_class(
            **self.feat_extract_dict)
        # create random numpy tensors
        image_inputs = prepare_image_inputs(self.feature_extract_tester,
                                            equal_resolution=False,
                                            numpify=True)
        for image in image_inputs:
            self.assertIsInstance(image, np.ndarray)

        # Test not batched input
        encoded_images = feature_extractor(image_inputs[0],
                                           return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                1,
                self.feature_extract_tester.num_channels,
                self.feature_extract_tester.size,
                self.feature_extract_tester.size,
            ),
        )

        # Test batched
        encoded_images = feature_extractor(image_inputs,
                                           return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                self.feature_extract_tester.batch_size,
                self.feature_extract_tester.num_channels,
                self.feature_extract_tester.size,
                self.feature_extract_tester.size,
            ),
        )

    def test_call_pytorch(self):
        # Initialize feature_extractor
        feature_extractor = self.feature_extraction_class(
            **self.feat_extract_dict)
        # create random PyTorch tensors
        image_inputs = prepare_image_inputs(self.feature_extract_tester,
                                            equal_resolution=False,
                                            torchify=True)
        for image in image_inputs:
            self.assertIsInstance(image, torch.Tensor)

        # Test not batched input
        encoded_images = feature_extractor(image_inputs[0],
                                           return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                1,
                self.feature_extract_tester.num_channels,
                self.feature_extract_tester.size,
                self.feature_extract_tester.size,
            ),
        )

        # Test batched
        encoded_images = feature_extractor(image_inputs,
                                           return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                self.feature_extract_tester.batch_size,
                self.feature_extract_tester.num_channels,
                self.feature_extract_tester.size,
                self.feature_extract_tester.size,
            ),
        )
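
    # NOTE: the three test_call_* methods above submit the same random inputs
    # as PIL images, NumPy arrays and PyTorch tensors; in each case the
    # extractor must return pixel_values shaped
    # (batch_size, num_channels, size, size) per the tester's configuration.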

    def test_layoutlmv2_integration_test(self):
        # with apply_ocr = True
        feature_extractor = LayoutLMv2FeatureExtractor()

        from datasets import load_dataset

        ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test")

        image = Image.open(ds[0]["file"]).convert("RGB")

        encoding = feature_extractor(image, return_tensors="pt")

        self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
        self.assertEqual(len(encoding.words), len(encoding.boxes))

        # fmt: off
        # the words and boxes were obtained with Tesseract 4.1.1
        expected_words = [[
            '11:14', 'to', '11:39', 'a.m', '11:39', 'to', '11:44', 'a.m.',
            '11:44', 'a.m.', 'to', '12:25', 'p.m.', '12:25', 'to', '12:58',
            'p.m.', '12:58', 'to', '4:00', 'p.m.', '2:00', 'to', '5:00',
            'p.m.', 'Coffee', 'Break', 'Coffee', 'will', 'be', 'served', 'for',
            'men', 'and', 'women', 'in', 'the', 'lobby', 'adjacent', 'to',
            'exhibit', 'area.', 'Please', 'move', 'into', 'exhibit', 'area.',
            '(Exhibits', 'Open)', 'TRRF', 'GENERAL', 'SESSION', '(PART', '|)',
            'Presiding:', 'Lee', 'A.', 'Waller', 'TRRF', 'Vice', 'President',
            '“Introductory', 'Remarks”', 'Lee', 'A.', 'Waller,', 'TRRF',
            'Vice', 'Presi-', 'dent', 'Individual', 'Interviews', 'with',
            'TRRF', 'Public', 'Board', 'Members', 'and', 'Sci-', 'entific',
            'Advisory', 'Council', 'Mem-', 'bers', 'Conducted', 'by', 'TRRF',
            'Treasurer', 'Philip', 'G.', 'Kuehn', 'to', 'get', 'answers',
            'which', 'the', 'public', 'refrigerated', 'warehousing',
            'industry', 'is', 'looking', 'for.', 'Plus', 'questions', 'from',
            'the', 'floor.', 'Dr.', 'Emil', 'M.', 'Mrak,', 'University', 'of',
            'Cal-', 'ifornia,', 'Chairman,', 'TRRF', 'Board;', 'Sam', 'R.',
            'Cecil,', 'University', 'of', 'Georgia', 'College', 'of',
            'Agriculture;', 'Dr.', 'Stanley', 'Charm,', 'Tufts', 'University',
            'School', 'of', 'Medicine;', 'Dr.', 'Robert', 'H.', 'Cotton,',
            'ITT', 'Continental', 'Baking', 'Company;', 'Dr.', 'Owen',
            'Fennema,', 'University', 'of', 'Wis-', 'consin;', 'Dr.', 'Robert',
            'E.', 'Hardenburg,', 'USDA.', 'Questions', 'and', 'Answers',
            'Exhibits', 'Open', 'Capt.', 'Jack', 'Stoney', 'Room', 'TRRF',
            'Scientific', 'Advisory', 'Council', 'Meeting', 'Ballroom', 'Foyer'
        ]]  # noqa: E231
        expected_boxes = [[[141, 57, 214, 69], [228, 58, 252, 69],
                           [141, 75, 216, 88], [230, 79, 280, 88],
                           [142, 260, 218, 273], [230, 261, 255, 273],
                           [143, 279, 218, 290], [231, 282, 290, 291],
                           [143, 342, 218, 354], [231, 345, 289, 355],
                           [202, 362, 227, 373], [143, 379, 220, 392],
                           [231, 382, 291, 394], [144, 714, 220, 726],
                           [231, 715, 256, 726], [144, 732, 220, 745],
                           [232, 736, 291, 747], [144, 769, 218, 782],
                           [231, 770, 256, 782], [141, 788, 202, 801],
                           [215, 791, 274, 804], [143, 826, 204, 838],
                           [215, 826, 240, 838], [142, 844, 202, 857],
                           [215, 847, 274, 859], [334, 57, 427, 69],
                           [440, 57, 522, 69], [369, 75, 461, 88],
                           [469, 75, 516, 88], [528, 76, 562, 88],
                           [570, 76, 667, 88], [675, 75, 711, 87],
                           [721, 79, 778, 88], [789, 75, 840, 88],
                           [369, 97, 470, 107], [484, 94, 507, 106],
                           [518, 94, 562, 107], [576, 94, 655, 110],
                           [668, 94, 792, 109], [804, 95, 829, 107],
                           [369, 113, 465, 125], [477, 116, 547, 125],
                           [562, 113, 658, 125], [671, 116, 748, 125],
                           [761, 113, 811, 125], [369, 131, 465, 143],
                           [477, 133, 548, 143], [563, 130, 698, 145],
                           [710, 130, 802, 146], [336, 171, 412, 183],
                           [423, 171, 572, 183], [582, 170, 716, 184],
                           [728, 171, 817, 187], [829, 171, 844, 186],
                           [338, 197, 482, 212], [507, 196, 557, 209],
                           [569, 196, 595, 208], [610, 196, 702, 209],
                           [505, 214, 583, 226], [595, 214, 656, 227],
                           [670, 215, 807, 227], [335, 259, 543, 274],
                           [556, 259, 708, 272], [372, 279, 422, 291],
                           [435, 279, 460, 291], [474, 279, 574, 292],
                           [587, 278, 664, 291], [676, 278, 738, 291],
                           [751, 279, 834, 291], [372, 298, 434, 310],
                           [335, 341, 483, 354], [497, 341, 655, 354],
                           [667, 341, 728, 354], [740, 341, 825, 354],
                           [335, 360, 430, 372], [442, 360, 534, 372],
                           [545, 359, 687, 372], [697, 360, 754, 372],
                           [765, 360, 823, 373], [334, 378, 428, 391],
                           [440, 378, 577, 394], [590, 378, 705, 391],
                           [720, 378, 801, 391], [334, 397, 400, 409],
                           [370, 416, 529, 429], [544, 416, 576, 432],
                           [587, 416, 665, 428], [677, 416, 814, 429],
                           [372, 435, 452, 450], [465, 434, 495, 447],
                           [511, 434, 600, 447], [611, 436, 637, 447],
                           [649, 436, 694, 451], [705, 438, 824, 447],
                           [369, 453, 452, 466], [464, 454, 509, 466],
                           [522, 453, 611, 469], [625, 453, 792, 469],
                           [370, 472, 556, 488], [570, 472, 684, 487],
                           [697, 472, 718, 485], [732, 472, 835, 488],
                           [369, 490, 411, 503], [425, 490, 484, 503],
                           [496, 490, 635, 506], [645, 490, 707, 503],
                           [718, 491, 761, 503], [771, 490, 840, 503],
                           [336, 510, 374, 521], [388, 510, 447, 522],
                           [460, 510, 489, 521], [503, 510, 580, 522],
                           [592, 509, 736, 525], [745, 509, 770, 522],
                           [781, 509, 840, 522], [338, 528, 434, 541],
                           [448, 528, 596, 541], [609, 527, 687, 540],
                           [700, 528, 792, 541], [336, 546, 397, 559],
                           [407, 546, 431, 559], [443, 546, 525, 560],
                           [537, 546, 680, 562], [688, 546, 714, 559],
                           [722, 546, 837, 562], [336, 565, 449, 581],
                           [461, 565, 485, 577], [497, 565, 665, 581],
                           [681, 565, 718, 577], [732, 565, 837, 580],
                           [337, 584, 438, 597], [452, 583, 521, 596],
                           [535, 584, 677, 599], [690, 583, 787, 596],
                           [801, 583, 825, 596], [338, 602, 478, 615],
                           [492, 602, 530, 614], [543, 602, 638, 615],
                           [650, 602, 676, 614], [688, 602, 788, 615],
                           [802, 602, 843, 614], [337, 621, 502, 633],
                           [516, 621, 615, 637], [629, 621, 774, 636],
                           [789, 621, 827, 633], [337, 639, 418, 652],
                           [432, 640, 571, 653], [587, 639, 731, 655],
                           [743, 639, 769, 652], [780, 639, 841, 652],
                           [338, 658, 440, 673], [455, 658, 491, 670],
                           [508, 658, 602, 671], [616, 658, 638, 670],
                           [654, 658, 835, 674], [337, 677, 429, 689],
                           [337, 714, 482, 726], [495, 714, 548, 726],
                           [561, 714, 683, 726], [338, 770, 461, 782],
                           [474, 769, 554, 785], [489, 788, 562, 803],
                           [576, 788, 643, 801], [656, 787, 751, 804],
                           [764, 788, 844, 801], [334, 825, 421, 838],
                           [430, 824, 574, 838], [584, 824, 723, 841],
                           [335, 844, 450, 857], [464, 843, 583, 860],
                           [628, 862, 755, 875], [769, 861, 848, 878]]]  # noqa: E231
        # fmt: on

        self.assertListEqual(encoding.words, expected_words)
        self.assertListEqual(encoding.boxes, expected_boxes)

        # with apply_ocr = False
        feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)

        encoding = feature_extractor(image, return_tensors="pt")

        self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
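
The expected boxes asserted above are on the 0-1000 normalized scale that LayoutLMv2 works with: the feature extractor rescales Tesseract's pixel coordinates by the image width and height. A minimal sketch of that rescaling, plus the apply_ocr=False workflow in which the caller supplies words and boxes to the tokenizer instead (the helper name and the sample data are illustrative assumptions; transformers performs an equivalent rescaling internally):

def normalize_box(box, width, height):
    # Map pixel coordinates (x0, y0, x1, y1) onto the 0-1000 grid.
    x0, y0, x1, y1 = box
    return [
        int(1000 * x0 / width),
        int(1000 * y0 / height),
        int(1000 * x1 / width),
        int(1000 * y1 / height),
    ]

# A 100x40-pixel word box on a 1000x800 page:
assert normalize_box([100, 40, 200, 80], 1000, 800) == [100, 50, 200, 100]

# With apply_ocr=False the extractor only resizes the image (hence only
# pixel_values above), so words and pre-normalized boxes go to the tokenizer:
# tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
# encoding = tokenizer(["hello", "world"], boxes=[[1, 2, 3, 4], [5, 6, 7, 8]], return_tensors="pt")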