示例#1
0
    def test_classy_model_adapter(self):
        """Check that ClassyModel.from_model correctly wraps a plain nn.Module.

        Covers the returned object's type, the output shapes of forward and
        extract_features, and a get_classy_state / set_classy_state round trip.
        """
        model = TestModel()
        classy_model = ClassyModel.from_model(model)
        # the returned object is an instance of ClassyModel ...
        self.assertIsInstance(classy_model, ClassyModel)
        # ... and, more specifically, of the _ClassyModelAdapter wrapper
        self.assertIsInstance(classy_model, _ClassyModelAdapter)

        # One all-zeros batch is shared by both calls below; it is named
        # ``inputs`` so the builtin ``input`` is not shadowed.
        inputs = torch.zeros((100, 10))

        # forward works correctly
        output = classy_model(inputs)
        self.assertEqual(output.shape, (100, 5))

        # extract_features works correctly
        features = classy_model.extract_features(inputs)
        self.assertEqual(features.shape, (100, 20))

        # get_classy_state and set_classy_state work
        nn.init.constant_(classy_model.model.linear.weight, 1)
        weights = copy.deepcopy(classy_model.model.linear.weight.data)
        state_dict = classy_model.get_classy_state(deep_copy=True)
        nn.init.constant_(classy_model.model.linear.weight, 0)
        # Sanity check: the overwrite really changed the weights, so the final
        # assertion can only pass if set_classy_state restored them.
        self.assertFalse(
            torch.allclose(weights, classy_model.model.linear.weight.data)
        )
        classy_model.set_classy_state(state_dict)
        self.assertTrue(torch.allclose(weights, classy_model.model.linear.weight.data))
示例#2
0
 def test_classy_model_adapter_properties(self):
     """Check that properties passed to ClassyModel.from_model are exposed by the adapter."""
     model = TestModel()
     input_shape = (10,)
     model_depth = 1
     classy_model = ClassyModel.from_model(
         model, input_shape=input_shape, model_depth=model_depth
     )
     self.assertEqual(classy_model.input_shape, input_shape)
     # Assumes the adapter exposes ``model_depth`` as a property, in the same
     # way as ``input_shape`` -- TODO confirm against ClassyModel.
     self.assertEqual(classy_model.model_depth, model_depth)
示例#3
0
 def test_heads(self):
     """Check that a head attached at ``layer4`` determines the wrapped model's output shape."""
     model = models.resnet50(pretrained=False)
     classy_model = ClassyModel.from_model(model)
     num_classes = 5
     # in_plane=2048 presumably matches the channel count of ResNet-50's
     # layer4 output -- confirm if the backbone changes
     head = FullyConnectedHead(
         unique_id="default", in_plane=2048, num_classes=num_classes
     )
     classy_model.set_heads({"layer4": [head]})
     # named ``images`` so the builtin ``input`` is not shadowed
     images = torch.ones((1, 3, 224, 224))
     self.assertEqual(classy_model(images).shape, (1, num_classes))
示例#4
0
    def test_train_step(self):
        """Run a short training job on a torchvision ResNet-34 wrapped as a ClassyModel."""
        wrapped_resnet = ClassyModel.from_model(models.resnet34(pretrained=False))

        task = build_task(get_fast_test_task_config())
        task.set_model(wrapped_resnet)
        # There are no explicit assertions: the test passes as long as
        # training completes without raising.
        LocalTrainer().train(task)
    def from_model(cls, model: Union[nn.Module, ClassyModel]) -> "ClassyHubInterface":
        """Build a hub interface around ``model``.

        A model that is already a ``ClassyModel`` is used as-is. Any other
        ``nn.Module`` is first wrapped via ``ClassyModel.from_model``.

        Args:
            model: a ``ClassyModel`` or a plain ``nn.Module`` (for example, a
                model loaded through torch.hub).

        Returns:
            A new ``cls`` instance constructed with ``model=`` set to the
            (possibly wrapped) model.
        """
        classy_model = (
            model if isinstance(model, ClassyModel) else ClassyModel.from_model(model)
        )
        return cls(model=classy_model)
示例#6
0
 def __init__(self):
     """Compose a ClassyModel-wrapped ResNet-50 with a ReLU and a 1000 -> 8 linear layer."""
     super().__init__()
     # Submodules are assigned in this fixed order because nn.Module
     # registers children in assignment order.
     backbone = resnet50()
     self.resnet = ClassyModel.from_model(backbone)
     self.relu = nn.ReLU()
     # 1000 in-features presumably matches resnet50()'s default classifier
     # output size -- confirm if the backbone is constructed differently
     self.linear = nn.Linear(1000, 8)