def test_save_load_pretrained_additional_features(self):
    """Check that keyword overrides given to ``from_pretrained`` are applied.

    A default processor is saved to ``self.tmpdirname`` and then reloaded
    twice with extra keyword arguments. Each reloaded processor is compared
    against a tokenizer and a feature extractor built directly with the
    same overrides: the vocabularies and the feature-extractor JSON must
    match, and the components must have the expected classes.
    """
    token_overrides = {"bos_token": "(BOS)", "eos_token": "(EOS)"}
    extractor_overrides = {"do_resize": False, "size": 30}

    default_extractor = self.get_feature_extractor()
    default_tokenizer = self.get_tokenizer()
    LayoutXLMProcessor(
        feature_extractor=default_extractor,
        tokenizer=default_tokenizer,
    ).save_pretrained(self.tmpdirname)

    # --- slow tokenizer ---
    expected_tokenizer = self.get_tokenizer(**token_overrides)
    expected_extractor = self.get_feature_extractor(**extractor_overrides)

    reloaded = LayoutXLMProcessor.from_pretrained(
        self.tmpdirname,
        use_fast=False,
        **token_overrides,
        **extractor_overrides,
    )

    self.assertEqual(
        reloaded.tokenizer.get_vocab(),
        expected_tokenizer.get_vocab(),
    )
    self.assertIsInstance(reloaded.tokenizer, LayoutXLMTokenizer)
    self.assertEqual(
        reloaded.feature_extractor.to_json_string(),
        expected_extractor.to_json_string(),
    )
    self.assertIsInstance(
        reloaded.feature_extractor, LayoutLMv2FeatureExtractor
    )

    # --- fast tokenizer ---
    expected_tokenizer = self.get_rust_tokenizer(**token_overrides)
    expected_extractor = self.get_feature_extractor(**extractor_overrides)

    # NOTE(review): `use_xlm` is kept exactly as the test was written; it
    # may have been intended as `use_fast=True` -- confirm before changing.
    reloaded = LayoutXLMProcessor.from_pretrained(
        self.tmpdirname,
        use_xlm=True,
        **token_overrides,
        **extractor_overrides,
    )

    self.assertEqual(
        reloaded.tokenizer.get_vocab(),
        expected_tokenizer.get_vocab(),
    )
    self.assertIsInstance(reloaded.tokenizer, LayoutXLMTokenizerFast)
    self.assertEqual(
        reloaded.feature_extractor.to_json_string(),
        expected_extractor.to_json_string(),
    )
    self.assertIsInstance(
        reloaded.feature_extractor, LayoutLMv2FeatureExtractor
    )
def test_save_load_pretrained_default(self):
    """Check that a processor survives a save/reload cycle unchanged.

    For every tokenizer returned by ``self.get_tokenizers()``, a processor
    is built around one shared feature extractor, saved to
    ``self.tmpdirname`` and loaded back with no extra arguments. The
    reloaded tokenizer vocabulary and feature-extractor JSON must equal
    those of the components that were saved.
    """
    source_extractor = self.get_feature_extractor()

    for source_tokenizer in self.get_tokenizers():
        LayoutXLMProcessor(
            feature_extractor=source_extractor,
            tokenizer=source_tokenizer,
        ).save_pretrained(self.tmpdirname)

        reloaded = LayoutXLMProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(
            reloaded.tokenizer.get_vocab(),
            source_tokenizer.get_vocab(),
        )
        # Either the slow or the fast tokenizer class is acceptable here.
        self.assertIsInstance(
            reloaded.tokenizer,
            (LayoutXLMTokenizer, LayoutXLMTokenizerFast),
        )
        self.assertEqual(
            reloaded.feature_extractor.to_json_string(),
            source_extractor.to_json_string(),
        )
        self.assertIsInstance(
            reloaded.feature_extractor, LayoutLMv2FeatureExtractor
        )