def test_classy_model_adapter(self):
    """Check that ClassyModel.from_model correctly adapts a plain nn.Module.

    Covers the returned object's types, ``forward``, ``extract_features``
    and the ``get_classy_state`` / ``set_classy_state`` round trip.
    """
    model = TestModel()
    classy_model = ClassyModel.from_model(model)

    # The returned object is a ClassyModel, and specifically the adapter type.
    self.assertIsInstance(classy_model, ClassyModel)
    self.assertIsInstance(classy_model, _ClassyModelAdapter)

    # Named `batch` rather than `input` to avoid shadowing the builtin. One
    # tensor is shared by both calls because only output shapes are checked.
    batch = torch.zeros((100, 10))

    # forward maps the (100, 10) batch to a (100, 5) output.
    output = classy_model(batch)
    self.assertEqual(output.shape, (100, 5))

    # extract_features maps the same batch to (100, 20) features.
    features = classy_model.extract_features(batch)
    self.assertEqual(features.shape, (100, 20))

    # State round trip: set known weights, snapshot them, overwrite them with
    # zeros, then restore from the snapshot. If set_classy_state did nothing,
    # the weights would stay 0 and the comparison below would fail.
    nn.init.constant_(classy_model.model.linear.weight, 1)
    expected_weights = copy.deepcopy(classy_model.model.linear.weight.data)
    state_dict = classy_model.get_classy_state(deep_copy=True)
    nn.init.constant_(classy_model.model.linear.weight, 0)
    classy_model.set_classy_state(state_dict)
    self.assertTrue(
        torch.allclose(expected_weights, classy_model.model.linear.weight.data)
    )
def test_classy_model_adapter_properties(self):
    """Check that properties passed to from_model are exposed by the adapter."""
    model = TestModel()
    input_shape = (10,)
    model_depth = 1
    classy_model = ClassyModel.from_model(
        model, input_shape=input_shape, model_depth=model_depth
    )
    self.assertEqual(classy_model.input_shape, input_shape)
    # model_depth is passed to from_model as well, so it must round-trip too;
    # previously only input_shape was verified.
    self.assertEqual(classy_model.model_depth, model_depth)
def test_heads(self):
    """Check that a head attached with set_heads determines the output shape."""
    model = models.resnet50(pretrained=False)
    classy_model = ClassyModel.from_model(model)
    num_classes = 5
    # NOTE(review): in_plane=2048 presumably matches the channel count of the
    # ResNet-50 "layer4" output the head is attached to; confirm if the
    # backbone changes.
    head = FullyConnectedHead(
        unique_id="default", in_plane=2048, num_classes=num_classes
    )
    classy_model.set_heads({"layer4": [head]})
    # Named `batch` rather than `input` to avoid shadowing the builtin.
    batch = torch.ones((1, 3, 224, 224))
    self.assertEqual(classy_model(batch).shape, (1, num_classes))
def test_train_step(self):
    """Smoke test: a wrapped torchvision ResNet-34 can be trained by a task.

    There is no explicit assertion; the test passes as long as
    ``LocalTrainer.train`` completes without raising.
    """
    # The model is built before the task, matching the original ordering.
    wrapped = ClassyModel.from_model(models.resnet34(pretrained=False))
    task = build_task(get_fast_test_task_config())
    task.set_model(wrapped)
    LocalTrainer().train(task)
def from_model(cls, model: Union[nn.Module, ClassyModel]) -> "ClassyHubInterface":
    """Build a ClassyHubInterface around ``model``.

    A model that is already a ClassyModel is used as-is. Any other
    ``nn.Module`` is first wrapped with ``ClassyModel.from_model``.

    Args:
        model: torchhub model, either a plain ``nn.Module`` or a ClassyModel.

    Returns:
        A new ``cls`` instance constructed with the resulting ClassyModel.
    """
    classy_model = (
        model if isinstance(model, ClassyModel) else ClassyModel.from_model(model)
    )
    return cls(model=classy_model)
def __init__(self):
    """Create the submodules: a wrapped ResNet-50, a ReLU and a linear layer.

    The linear layer maps 1000 input features to 8 outputs. 1000 presumably
    matches the default ResNet-50 output size; confirm against forward.
    """
    super().__init__()
    # Submodules are registered in the same order as before (resnet, relu,
    # linear) so parameter initialization order and state_dict keys match.
    backbone = resnet50()
    self.resnet = ClassyModel.from_model(backbone)
    self.relu = nn.ReLU()
    self.linear = nn.Linear(1000, 8)