def convert_weight_and_push(name: str, config: ResNetConfig, save_directory: Path, push_to_hub: bool = True):
    """Convert the timm model ``name`` to ``ResNetForImageClassification`` and optionally push it.

    Both the timm model (created with ``pretrained=True``) and a freshly built
    ``ResNetForImageClassification(config)`` are put in eval mode. Then
    ``ModuleTransfer(src=..., dest=...)`` is called once with a random
    ``(1, 3, 224, 224)`` tensor. Presumably it copies the weights while tracing
    that forward pass; confirm this against ``ModuleTransfer``. Finally, the
    logits of the two models are compared on the same input.

    Args:
        name: timm model name. It is also used to derive the checkpoint name by
            splitting on ``"resnet"`` and re-joining with ``"-"``
            (``"resnet50"`` -> ``"resnet-50"``).
        config: configuration used to build the target model.
        save_directory: combined with the checkpoint name to form the
            ``repo_path_or_name`` passed to ``push_to_hub``.
        push_to_hub: when True, push the converted model and a feature
            extractor (taken from ``facebook/convnext-base-224-22k-1k``).

    Raises:
        AssertionError: if the converted model's logits do not match the
            original model's logits on the random input.
    """
    print(f"Converting {name}...")
    # Conversion and verification are pure inference, so skip building an
    # autograd graph for any of these forward passes.
    with torch.no_grad():
        from_model = timm.create_model(name, pretrained=True).eval()
        our_model = ResNetForImageClassification(config).eval()
        module_transfer = ModuleTransfer(src=from_model, dest=our_model)
        x = torch.randn((1, 3, 224, 224))
        module_transfer(x)

        # Explicit raise instead of `assert`: `python -O` strips asserts, which
        # would let a broken conversion pass silently.
        if not torch.allclose(from_model(x), our_model(x).logits):
            raise AssertionError("The model logits don't match the original one.")

    checkpoint_name = f"resnet{'-'.join(name.split('resnet'))}"
    print(checkpoint_name)

    if not push_to_hub:
        return

    repo_path = save_directory / checkpoint_name
    our_model.push_to_hub(
        repo_path_or_name=repo_path,
        commit_message="Add model",
        use_temp_dir=True,
    )

    # we can use the convnext one
    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k")
    feature_extractor.push_to_hub(
        repo_path_or_name=repo_path,
        commit_message="Add feature extractor",
        use_temp_dir=True,
    )

    print(f"Pushed {checkpoint_name}")
 def create_and_check_for_image_classification(self, config, pixel_values, labels):
     """Build a classification model from ``config`` and check its logits shape.

     ``config.num_labels`` is overwritten with ``self.num_labels`` before the
     model is built. The model is run on ``pixel_values`` with ``labels`` in eval
     mode, and the logits must have shape ``(self.batch_size, self.num_labels)``.
     """
     config.num_labels = self.num_labels
     classifier = ResNetForImageClassification(config).to(torch_device)
     classifier.eval()

     outputs = classifier(pixel_values, labels=labels)
     wanted_shape = (self.batch_size, self.num_labels)
     self.parent.assertEqual(outputs.logits.shape, wanted_shape)
    def test_inference_image_classification_head(self):
        """Run the first archived ResNet checkpoint on the fixture image and check its logits.

        The logits must have shape ``(1, 1000)``. The first three values must
        match hard-coded reference numbers within ``atol=1e-4``.
        """
        checkpoint = RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]
        model = ResNetForImageClassification.from_pretrained(checkpoint).to(torch_device)

        encoded = self.default_feature_extractor(images=prepare_img(), return_tensors="pt")
        inputs = encoded.to(torch_device)

        # forward pass without gradient tracking
        with torch.no_grad():
            logits = model(**inputs).logits

        # shape check, then a spot check on the first three class scores
        self.assertEqual(logits.shape, torch.Size((1, 1000)))

        reference = torch.tensor([-11.1069, -9.7877, -8.3777]).to(torch_device)
        self.assertTrue(torch.allclose(logits[0, :3], reference, atol=1e-4))