def convert_weight_and_push(
    name: str,
    config: VanConfig,
    checkpoint: str,
    from_model: nn.Module,
    save_directory: Path,
    push_to_hub: bool = True,
):
    """Convert an original VAN checkpoint into a ``VanForImageClassification`` model.

    Downloads ``checkpoint``, loads its ``"state_dict"`` entry into ``from_model``,
    transfers the weights into a freshly built ``VanForImageClassification(config)``
    and verifies that both models produce matching logits on a random input.
    When ``push_to_hub`` is true, the converted model and a ConvNeXt feature
    extractor are pushed to ``save_directory / name``.

    Args:
        name: Checkpoint name; also used as the repository sub-path when pushing.
        config: Configuration used to build the target model.
        checkpoint: Location handed to ``cached_download`` to fetch the weights.
        from_model: Original-implementation module whose state dict the
            checkpoint holds.
        save_directory: Base path the hub repository path is built from.
        push_to_hub: Whether to push the converted model and feature extractor.

    Raises:
        ValueError: If the converted model's logits do not match the original's.
    """
    print(f"Downloading weights for {name}...")
    checkpoint_path = cached_download(checkpoint)
    print(f"Converting {name}...")
    # map_location="cpu": the checkpoint may contain tensors saved from a GPU, while
    # everything below runs on CPU (the probe input is a CPU tensor). Loading onto
    # CPU avoids a hard failure on machines without CUDA.
    from_state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    from_model.load_state_dict(from_state_dict)
    from_model.eval()
    with torch.no_grad():
        our_model = VanForImageClassification(config).eval()
        module_transfer = ModuleTransfer(src=from_model, dest=our_model)
        x = torch.randn((1, 3, 224, 224))
        module_transfer(x)
        our_model = copy_parameters(from_model, our_model)
        # Explicit check rather than `assert`: asserts are removed under `python -O`,
        # which would silently skip validating the conversion. Running it inside
        # no_grad also avoids building an autograd graph just for the comparison.
        if not torch.allclose(from_model(x), our_model(x).logits):
            raise ValueError("The model logits don't match the original one.")

    checkpoint_name = name
    print(checkpoint_name)

    if push_to_hub:
        our_model.push_to_hub(
            repo_path_or_name=save_directory / checkpoint_name,
            commit_message="Add model",
            use_temp_dir=True,
        )

        # we can use the convnext one
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k")
        feature_extractor.push_to_hub(
            repo_path_or_name=save_directory / checkpoint_name,
            commit_message="Add feature extractor",
            use_temp_dir=True,
        )

        print(f"Pushed {checkpoint_name}")
def create_and_check_for_image_classification(self, config, pixel_values, labels):
    """Check the image-classification head's output shape.

    Builds ``VanForImageClassification(config)`` on ``torch_device`` in eval mode,
    runs it on ``pixel_values`` with ``labels``, and asserts (via ``self.parent``)
    that the logits have shape ``(self.batch_size, self.num_labels)``.
    """
    # nn.Module.to() and .eval() both return the module itself, so they can be chained.
    classifier = VanForImageClassification(config).to(torch_device).eval()
    output = classifier(pixel_values, labels=labels)
    expected_shape = (self.batch_size, self.num_labels)
    self.parent.assertEqual(output.logits.shape, expected_shape)
def test_inference_image_classification_head(self):
    """Run the first pretrained VAN checkpoint on the test image and check its logits.

    Asserts that the logits have shape (1, 1000) and that the first three logits
    match hard-coded reference values within an absolute tolerance of 1e-4.
    """
    pretrained_id = VAN_PRETRAINED_MODEL_ARCHIVE_LIST[0]
    classifier = VanForImageClassification.from_pretrained(pretrained_id).to(torch_device)
    extractor = self.default_feature_extractor
    encoded = extractor(images=prepare_img(), return_tensors="pt").to(torch_device)

    # forward pass without tracking gradients
    with torch.no_grad():
        logits = classifier(**encoded).logits

    # shape check, then a spot check of the leading logits
    self.assertEqual(logits.shape, torch.Size((1, 1000)))
    reference = torch.tensor([0.1029, -0.0904, -0.6365]).to(torch_device)
    self.assertTrue(torch.allclose(logits[0, :3], reference, atol=1e-4))