Code example #1
import torch
from pathlib import Path

from torch import nn

from huggingface_hub import cached_download
from transformers import AutoFeatureExtractor, VanConfig, VanForImageClassification

# ModuleTransfer and copy_parameters are helper utilities defined elsewhere
# in the conversion script; they are not importable from a library.


def convert_weight_and_push(
    name: str,
    config: VanConfig,
    checkpoint: str,
    from_model: nn.Module,
    save_directory: Path,
    push_to_hub: bool = True,
):
    print(f"Downloading weights for {name}...")
    checkpoint_path = cached_download(checkpoint)
    print(f"Converting {name}...")
    from_state_dict = torch.load(checkpoint_path)["state_dict"]
    from_model.load_state_dict(from_state_dict)
    from_model.eval()
    with torch.no_grad():
        our_model = VanForImageClassification(config).eval()
        # Trace both models with the same dummy input so weights can be
        # matched module by module, then copy the remaining parameters.
        module_transfer = ModuleTransfer(src=from_model, dest=our_model)
        x = torch.randn((1, 3, 224, 224))
        module_transfer(x)
        our_model = copy_parameters(from_model, our_model)

    # Sanity check: the converted model must reproduce the original logits.
    assert torch.allclose(
        from_model(x), our_model(x).logits
    ), "The model logits don't match the original one."

    checkpoint_name = name
    print(checkpoint_name)

    if push_to_hub:
        our_model.push_to_hub(
            repo_path_or_name=save_directory / checkpoint_name,
            commit_message="Add model",
            use_temp_dir=True,
        )

        # we can use the convnext one
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            "facebook/convnext-base-224-22k-1k")
        feature_extractor.push_to_hub(
            repo_path_or_name=save_directory / checkpoint_name,
            commit_message="Add feature extractor",
            use_temp_dir=True,
        )

        print(f"Pushed {checkpoint_name}")
Code example #2
    def create_and_check_for_image_classification(self, config, pixel_values,
                                                  labels):
        model = VanForImageClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values, labels=labels)
        self.parent.assertEqual(result.logits.shape,
                                (self.batch_size, self.num_labels))
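
In the transformers test suite this checker conventionally lives on a small ModelTester helper that the actual unittest case delegates to. A minimal sketch of that wiring is below; the prepare_config_and_inputs helper, assumed to return a config plus random pixel_values and labels tensors, is an illustration of the pattern rather than the exact upstream code.

    # Hypothetical wiring: the test case delegates to the checker above via a
    # `model_tester` object; prepare_config_and_inputs is assumed to return
    # (config, pixel_values, labels) with matching batch size.
    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)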
Code example #3
    def test_inference_image_classification_head(self):
        model = VanForImageClassification.from_pretrained(VAN_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(torch_device)

        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = torch.tensor([0.1029, -0.0904, -0.6365]).to(torch_device)

        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
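
The test above leans on two fixtures defined elsewhere in the test module, prepare_img and default_feature_extractor. The sketch below shows what they conventionally look like in transformers vision tests; treat the fixture path and the checkpoint reuse as assumptions rather than the exact upstream definitions.

# Assumed fixtures, sketched from the usual pattern in transformers vision
# tests; the exact upstream definitions may differ.
import unittest
from functools import cached_property

from PIL import Image
from transformers import AutoFeatureExtractor


def prepare_img():
    # The conventional fixture is a local COCO sample image.
    return Image.open("./tests/fixtures/tests_samples/COCO/000000039769.jpg")


class VanModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_feature_extractor(self):
        # Reuse the feature extractor published alongside the checkpoint.
        return AutoFeatureExtractor.from_pretrained(
            VAN_PRETRAINED_MODEL_ARCHIVE_LIST[0]
        )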