def test_from_model(self):
    """Wrap both a Classy and a non-Classy model in a hub interface.

    An interface created directly from a model has no task attached, and
    its ``model`` attribute must be a ``ClassyModel`` in both cases. The
    shared predict / extract-features checks are then run on each one.
    """
    candidates = [self._get_classy_model(), self._get_non_classy_model()]
    for candidate in candidates:
        interface = ClassyHubInterface.from_model(candidate)

        self.assertIsNone(interface.task)
        self.assertIsInstance(interface.model, ClassyModel)

        # with no task attached, the imagenet transform is expected here
        self._test_predict_and_extract_features(interface)
Example #2
0
    def test_from_task(self):
        """Wrap a task built from the test config in a hub interface.

        Checks that the interface exposes a ``ClassyTask`` and a
        ``ClassyModel`` and runs the shared predict / extract-features
        checks. It then attaches a ``TestTransform`` to the task's "test"
        dataset, rebuilds the interface, and verifies that a dataset
        created for that phase uses the attached transform.
        """
        task = build_task(get_test_task_config())
        interface = ClassyHubInterface.from_task(task)

        self.assertIsInstance(interface.task, ClassyTask)
        self.assertIsInstance(interface.model, ClassyModel)

        # the transform comes from the task's config in this case
        self._test_predict_and_extract_features(interface)

        # attach a recognisable transform and confirm it is the one picked up
        phase = "test"
        custom_transform = TestTransform()
        task.datasets[phase].transform = custom_transform
        rebuilt = ClassyHubInterface.from_task(task)
        created = rebuilt.create_image_dataset(
            image_files=[self.image_path],
            phase_type=phase,
        )
        self.assertIsInstance(created.transform, TestTransform)
 def _test_predict_and_extract_features(self, hub_interface: ClassyHubInterface):
     """Run prediction and feature extraction on a one-image test dataset.

     Builds a "test"-phase dataset from ``self.image_path`` and takes the
     first item from the interface's data iterator. It switches the
     interface to eval mode and asserts that both ``predict`` and
     ``extract_features`` return a non-None result.

     Args:
         hub_interface: the interface under test.
     """
     dataset = hub_interface.create_image_dataset(
         [self.image_path], phase_type="test"
     )
     data_iterator = hub_interface.get_data_iterator(dataset)
     # named ``batch`` rather than ``input`` to avoid shadowing the builtin
     batch = next(data_iterator)
     # set the model to eval mode before running inference
     hub_interface.eval()
     output = hub_interface.predict(batch)
     self.assertIsNotNone(output)
     # reduce the existing prediction to a class index; reusing ``output``
     # avoids a second, identical forward pass
     output.argmax().item()
     # check extract features
     features = hub_interface.extract_features(batch)
     self.assertIsNotNone(features)
def _create_interface_from_torchhub(github, *args, **kwargs):
    """Load a model with ``torch.hub.load`` and wrap it in a ClassyHubInterface.

    ``github`` and every extra positional / keyword argument are forwarded
    unchanged to ``torch.hub.load``. The loaded model is then passed to
    ``ClassyHubInterface.from_model``.
    """
    return ClassyHubInterface.from_model(torch.hub.load(github, *args, **kwargs))