def test_save_load_pretrained_default(self):
    """Saved processors reload with the same tokenizer vocab and feature-extractor config."""
    slow_tokenizer = self.get_tokenizer()
    fast_tokenizer = self.get_rust_tokenizer()
    extractor = self.get_feature_extractor()

    # Slow path: save with a slow tokenizer, reload with use_fast=False.
    processor_slow = OwlViTProcessor(tokenizer=slow_tokenizer, feature_extractor=extractor)
    processor_slow.save_pretrained(self.tmpdirname)
    processor_slow = OwlViTProcessor.from_pretrained(self.tmpdirname, use_fast=False)

    # Fast path: save with a fast tokenizer, reload with the default (fast) loader.
    processor_fast = OwlViTProcessor(tokenizer=fast_tokenizer, feature_extractor=extractor)
    processor_fast.save_pretrained(self.tmpdirname)
    processor_fast = OwlViTProcessor.from_pretrained(self.tmpdirname)

    # Vocabularies must survive the round trip and agree between slow/fast.
    self.assertEqual(processor_slow.tokenizer.get_vocab(), slow_tokenizer.get_vocab())
    self.assertEqual(processor_fast.tokenizer.get_vocab(), fast_tokenizer.get_vocab())
    self.assertEqual(slow_tokenizer.get_vocab(), fast_tokenizer.get_vocab())
    self.assertIsInstance(processor_slow.tokenizer, CLIPTokenizer)
    self.assertIsInstance(processor_fast.tokenizer, CLIPTokenizerFast)

    # The feature-extractor configuration must also survive the round trip.
    self.assertEqual(processor_slow.feature_extractor.to_json_string(), extractor.to_json_string())
    self.assertEqual(processor_fast.feature_extractor.to_json_string(), extractor.to_json_string())
    self.assertIsInstance(processor_slow.feature_extractor, OwlViTFeatureExtractor)
    self.assertIsInstance(processor_fast.feature_extractor, OwlViTFeatureExtractor)
# NOTE(review): stray scraper text ("Esempio n. 2" plus a vote count) converted to a
# comment — as bare text it was a syntax error and carried no program meaning.
def convert_owlvit_checkpoint(pt_backbone, flax_params, attn_params, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.

    Ports an original OWL-ViT checkpoint (PyTorch CLIP-style backbone plus Flax
    detection-head parameters) into HF ``OwlViTForObjectDetection``, then saves
    and pushes the model, feature extractor and processor to the Hub repository
    named after ``pytorch_dump_folder_path``.
    """
    # Clone (or open) the target Hub repo and make sure it is up to date.
    hub_repo = Repository(pytorch_dump_folder_path, clone_from=f"google/{pytorch_dump_folder_path}")
    hub_repo.git_pull()

    config = OwlViTConfig.from_pretrained(config_path) if config_path is not None else OwlViTConfig()

    backbone = OwlViTModel(config).eval()
    detection_model = OwlViTForObjectDetection(config).eval()

    # Port the backbone: text/vision towers, projections, scale and attention.
    copy_text_model_and_projection(backbone, pt_backbone)
    copy_vision_model_and_projection(backbone, pt_backbone)
    backbone.logit_scale = pt_backbone.logit_scale
    copy_flax_attn_params(backbone, attn_params)

    # Attach the converted backbone and port the detection-specific heads.
    detection_model.owlvit = backbone
    copy_class_merge_token(detection_model, flax_params)
    copy_class_box_heads(detection_model, flax_params)

    # Save HF model into the local clone of the repo.
    detection_model.save_pretrained(hub_repo.local_dir)

    # Initialize feature extractor sized to the vision tower's input resolution.
    image_size = config.vision_config.image_size
    extractor = OwlViTFeatureExtractor(size=image_size, crop_size=image_size)
    # Initialize tokenizer (OWL-ViT uses "!" padding and 16-token queries).
    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", pad_token="!", model_max_length=16)

    # Initialize processor and save both it and the bare extractor.
    owlvit_processor = OwlViTProcessor(feature_extractor=extractor, tokenizer=clip_tokenizer)
    extractor.save_pretrained(hub_repo.local_dir)
    owlvit_processor.save_pretrained(hub_repo.local_dir)

    # Commit and push everything to the Hub.
    hub_repo.git_add()
    hub_repo.git_commit("Upload model and processor")
    hub_repo.git_push()
    def test_save_load_pretrained_additional_features(self):
        """Extra kwargs passed to from_pretrained must override saved component settings."""
        base_processor = OwlViTProcessor(
            tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor()
        )
        base_processor.save_pretrained(self.tmpdirname)

        # Reference components built directly with the overriding kwargs.
        expected_tokenizer = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        expected_extractor = self.get_feature_extractor(do_normalize=False)

        processor = OwlViTProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False
        )

        # Tokenizer picked up the kwargs and defaults to the fast implementation.
        self.assertEqual(processor.tokenizer.get_vocab(), expected_tokenizer.get_vocab())
        self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast)

        # Feature extractor picked up do_normalize=False.
        self.assertEqual(processor.feature_extractor.to_json_string(), expected_extractor.to_json_string())
        self.assertIsInstance(processor.feature_extractor, OwlViTFeatureExtractor)