def test_registry_raises():
    """Check the error paths of ``FlashRegistry`` registration, lookup and removal."""
    backbones = FlashRegistry("backbones")

    @backbones
    def my_model(nc_input=5, nc_output=6):
        return nn.Linear(nc_input, nc_output), nc_input, nc_output

    # Registering something that is not a function (here an ``nn.Linear``
    # instance) is rejected.
    with pytest.raises(
            MisconfigurationException,
            match="You can only register a function, found: Linear"):
        backbones(nn.Linear(1, 1), name="cho")

    backbones(my_model, name="cho", override=True)

    # Re-registering the same name with the same (empty) metadata without
    # ``override=True`` is rejected.  ``match`` is a regular expression fed to
    # ``re.search``, so the braces are escaped to match ``{}`` literally instead
    # of relying on the regex engine's leniency toward an invalid quantifier.
    with pytest.raises(MisconfigurationException,
                       match=r"Function with name: cho and metadata: \{\}"):
        backbones(my_model, name="cho", override=False)

    # A name that exists but whose metadata does not match the filter.
    with pytest.raises(KeyError, match="Found no matches"):
        backbones.get("cho", foo="bar")

    # After removal the name is no longer known to the registry.
    backbones.remove("cho")
    with pytest.raises(KeyError, match="Key: cho is not in FlashRegistry"):
        backbones.get("cho")

    # ``name`` must be a string.
    with pytest.raises(TypeError, match="name` must be a str"):
        backbones(name=float)  # noqa


def test_registry():
    """Exercise registration, lookup, override and metadata filtering of ``FlashRegistry``."""
    registry = FlashRegistry("backbones")

    # The function name is used as the registry key, so it must stay ``my_model``.
    @registry
    def my_model(nc_input=5, nc_output=6):
        return nn.Linear(nc_input, nc_output), nc_input, nc_output

    # A decorator-registered function is retrievable under its own name.
    builder = registry.get("my_model")
    layer, n_in, n_out = builder(nc_output=7)
    assert (n_in, n_out) == (5, 7)
    assert layer.weight.shape == (7, 5)

    # Registering under an explicit name makes it retrievable by that name.
    registry(my_model, name="cho")
    assert registry.get("cho")

    # With override=True the registry still holds a single "cho" entry.
    registry(my_model, name="cho", override=True)
    assert len(registry.get("cho", strict=False)) == 1

    # Add four more "cho" entries that differ only in their metadata.
    variants = (
        ("timm", "resnet"),
        ("torchvision", "resnet"),
        ("timm", "densenet"),
        ("timm", "alexnet"),
    )
    for namespace, kind in variants:
        registry(my_model, name="cho", namespace=namespace, type=kind)

    # Filtering on both metadata fields narrows the lookup to one entry.
    entry = registry.get("cho",
                         with_metadata=True,
                         type="resnet",
                         namespace="timm")
    assert entry["name"] == "cho"
    assert entry["metadata"] == {"namespace": "timm", "type": "resnet"}

    # With strict=False and no with_metadata, every matching callable is returned.
    timm_builders = registry.get("cho", namespace="timm", strict=False)
    assert len(timm_builders) == 3
    assert all(map(callable, timm_builders))

    # Every registration contributes a key, duplicates included.
    assert registry.available_keys() == ["cho"] * 5 + ["my_model"]