def test_registry_raises():
    """Check that ``FlashRegistry`` rejects misuse with the expected errors.

    Each step triggers one failure mode and pins both the exception type and
    a fragment of its message.
    """
    registry = FlashRegistry("backbones")

    @registry
    def my_model(nc_input=5, nc_output=6):
        return nn.Linear(nc_input, nc_output), nc_input, nc_output

    # Passing a module instance instead of a function must be refused.
    with pytest.raises(
        MisconfigurationException,
        match="You can only register a function, found: Linear",
    ):
        registry(nn.Linear(1, 1), name="cho")

    # Seed the "cho" entry so the duplicate registration below has
    # something to collide with.
    registry(my_model, name="cho", override=True)

    # Same name and (empty) metadata without override is an error.
    with pytest.raises(
        MisconfigurationException,
        match="Function with name: cho and metadata: {}",
    ):
        registry(my_model, name="cho", override=False)

    # The key exists, but no entry carries the requested metadata.
    with pytest.raises(KeyError, match="Found no matches"):
        registry.get("cho", foo="bar")

    # Once removed, the key itself is reported as missing.
    registry.remove("cho")
    with pytest.raises(KeyError, match="Key: cho is not in FlashRegistry"):
        registry.get("cho")

    # ``name`` has to be a string; a type object is rejected.
    with pytest.raises(TypeError, match="name` must be a str"):
        registry(name=float)  # noqa
def test_registry():
    """Exercise registration, override, metadata filtering and key listing."""
    registry = FlashRegistry("backbones")

    @registry
    def my_model(nc_input=5, nc_output=6):
        return nn.Linear(nc_input, nc_output), nc_input, nc_output

    # The decorated function is retrievable under its own name and still
    # accepts keyword overrides when called.
    builder = registry.get("my_model")
    layer, nc_input, nc_output = builder(nc_output=7)
    assert nc_input == 5
    assert nc_output == 7
    assert layer.weight.shape == (7, 5)

    # basic get
    registry(my_model, name="cho")
    assert registry.get("cho")

    # test override: re-registering with override=True must not add an entry
    registry(my_model, name="cho", override=True)
    matches = registry.get("cho", strict=False)
    assert len(matches) == 1

    # test metadata filtering: same name, distinct metadata per entry
    metadata_variants = (
        ("timm", "resnet"),
        ("torchvision", "resnet"),
        ("timm", "densenet"),
        ("timm", "alexnet"),
    )
    for namespace, kind in metadata_variants:
        registry(my_model, name="cho", namespace=namespace, type=kind)

    record = registry.get("cho", with_metadata=True, type="resnet", namespace="timm")
    assert record["name"] == "cho"
    assert record["metadata"] == {"namespace": "timm", "type": "resnet"}

    # test strict=False and with_metadata=False: three entries carry
    # namespace="timm", and each one comes back as a plain callable
    matches = registry.get("cho", namespace="timm", strict=False)
    assert len(matches) == 3
    for candidate in matches:
        assert callable(candidate)

    # test available keys: one plain "cho" entry plus the four metadata
    # variants, and the decorated function under its own name
    expected_keys = ['cho'] * 5 + ['my_model']
    assert registry.available_keys() == expected_keys