コード例 #1
0
def test_add_function_name_override():
    """Check that ``add(..., name=...)`` registers the factory under that name.

    ``foo`` is added with ``name="bar"``; the registry must then hold a
    ``"bar"`` entry.
    """
    registry = Registry("")
    registry.add(foo, name="bar")
    assert "bar" in registry._factories
コード例 #2
0
def test_kwargs():
    """Check that a factory passed as a keyword argument is registered by key.

    ``add(bar=foo)`` is followed by a ``get("bar")`` lookup; the returned
    value is not inspected, the test only requires that the lookup succeeds
    without raising.
    """
    registry = Registry("")
    registry.add(bar=foo)
    registry.get("bar")
コード例 #3
0
def test_add_lambda_override():
    """Check that a lambda can be registered when an explicit name is given.

    The anonymous function is added with ``name="bar"`` and must be stored
    under the ``"bar"`` key.
    """
    registry = Registry("")
    registry.add(lambda x: x, name="bar")
    assert "bar" in registry._factories
コード例 #4
0
def test_add_function():
    """Check that ``add(foo)`` without a name registers ``foo`` as ``"foo"``."""
    registry = Registry("")
    registry.add(foo)
    assert "foo" in registry._factories
コード例 #5
0
def test_double_add_same_nofail():
    """Check that registering the very same factory twice is allowed.

    Adding an identical object a second time must not raise, and ``foo``
    must still be registered afterwards.
    """
    r = Registry("")
    r.add(foo)
    # It's ok to add same twice, forced by python relative import
    # implementation
    # https://github.com/catalyst-team/catalyst/issues/135
    r.add(foo)

    # The duplicate add must leave the original registration in place.
    assert "foo" in r._factories
コード例 #6
0
def test_decorator():
    """Check that ``Registry.add`` can be applied as a bare decorator.

    The decorated function is then looked up by its name; the test only
    requires that ``get("bar")`` succeeds without raising.
    """
    registry = Registry("")

    @registry.add
    def bar():
        pass

    registry.get("bar")
コード例 #7
0
def test_fail_instantiation():
    """Check that a failing factory call surfaces as a chained RegistryException.

    ``get_instance("foo", c=1)`` is expected to fail inside the factory call;
    the registry must raise ``RegistryException`` and keep the underlying
    error attached as ``__cause__``.
    """
    r = Registry("")

    r.add(foo)

    with pytest.raises(RegistryException) as exc_info:
        r.get_instance("foo", c=1)

    # Every exception object has a ``__cause__`` attribute (``None`` unless
    # raised with ``raise ... from ...``), so ``hasattr`` would always pass.
    # Check that the original error was actually chained.
    assert exc_info.value.__cause__ is not None
コード例 #8
0
def test_add_module():
    """Check ``add_from_module`` using the ``module`` test helper.

    After loading the module, looking up ``"foo"`` must succeed, while
    ``get_instance("bar")`` must raise ``RegistryException``.
    """
    registry = Registry("")
    registry.add_from_module(module)
    registry.get("foo")

    with pytest.raises(RegistryException):
        registry.get_instance("bar")
コード例 #9
0
ファイル: test_registry.py プロジェクト: zkid18/catalyst
def test_name_key():
    """Check that ``name_key`` decides which params entry selects the factory.

    With ``name_key="_key_"`` the ``"_key_"`` entry resolves ``foo``; the
    returned object is called and must yield the remaining params. A
    ``"_target_"`` entry has no special meaning for this registry, so those
    params come back unchanged.
    """
    registry = Registry(name_key="_key_")
    registry.add(foo)

    params = {"a": 1, "b": 2}

    built = registry.get_from_params(**{"_key_": "foo", **params})()
    assert built == {"a": 1, "b": 2}

    untouched = registry.get_from_params(**{"_target_": "foo", **params})
    assert untouched == {"_target_": "foo", "a": 1, "b": 2}
コード例 #10
0
def test_from_config():
    """Check ``get_from_params`` keyed by the name given to the constructor.

    ``{"obj": "foo", ...}`` must build ``foo`` from the remaining params, and
    an empty params mapping must yield ``None``.
    """
    registry = Registry("obj")
    registry.add(foo)

    built = registry.get_from_params(**{"obj": "foo", "a": 1, "b": 2})
    assert built == {"a": 1, "b": 2}

    nothing = registry.get_from_params(**{})
    assert nothing is None
コード例 #11
0
def test_fail_double_add_different():
    """Check that an existing name cannot be rebound to a different factory.

    ``foo`` is registered first; registering another function under the same
    name ``"foo"`` must raise ``RegistryException``.
    """
    r = Registry("")
    r.add(foo)

    # Define the replacement outside ``pytest.raises`` so that only the
    # conflicting ``add`` call can satisfy the expected exception.
    def bar():
        pass

    with pytest.raises(RegistryException):
        r.add(foo=bar)
コード例 #12
0
def test_instantiations():
    """Check ``get_instance`` with positional, mixed and keyword arguments.

    Each call form must produce the same ``{"a": 1, "b": 2}`` result.
    """
    registry = Registry("")
    registry.add(foo)

    expected = {"a": 1, "b": 2}
    assert registry.get_instance("foo", 1, 2) == expected
    assert registry.get_instance("foo", 1, b=2) == expected
    assert registry.get_instance("foo", a=1, b=2) == expected
コード例 #13
0
ファイル: test_registry.py プロジェクト: zkid18/catalyst
def test_instantiations():
    """Check ``get_instance`` argument forms and lookup by dotted path.

    In this registry the object returned by ``get_instance`` is called to get
    the result. Positional, mixed and keyword arguments must all produce
    ``{"a": 1, "b": 2}``, and so must a fully qualified path to ``foo``.
    """
    registry = Registry()
    registry.add(foo)

    expected = {"a": 1, "b": 2}
    assert registry.get_instance("foo", 1, 2)() == expected
    assert registry.get_instance("foo", 1, b=2)() == expected
    assert registry.get_instance("foo", a=1, b=2)() == expected

    dotted_path = "tests.catalyst.tools.registery_foo.foo"
    assert registry.get_instance(dotted_path, a=1, b=2)() == expected
コード例 #14
0
def test_meta_factory():
    """Check the registry-level meta factory and a per-call override.

    The registry is built with a meta factory that returns the factory
    untouched; passing ``meta_factory=`` to ``get_from_params`` uses the given
    one instead for that call.
    """  # noqa: D202

    def return_factory(fn, args, kwargs):
        return fn

    def return_one(fn, args, kwargs):
        return 1

    registry = Registry("obj", return_factory)
    registry.add(foo)

    params = {"obj": "foo"}
    assert registry.get_from_params(**params) == foo
    assert registry.get_from_params(**params, meta_factory=return_one) == 1
コード例 #15
0
    SAMPLERS,
    Scheduler,
    SCHEDULERS,
    Transform,
    TRANSFORMS,
)
from catalyst.tools.registry import Registry


def _callbacks_loader(r: Registry):
    """Load ``catalyst.core.callbacks`` into registry ``r`` via ``add_from_module``.

    The import lives inside the function so it only executes when the loader
    itself is invoked.
    """
    from catalyst.core import callbacks as callbacks_module

    r.add_from_module(callbacks_module)


# Registry of callbacks, constructed with "callback" as its name argument.
CALLBACKS = Registry("callback")
# Hand the built-in callbacks over to ``_callbacks_loader``; presumably
# ``late_add`` defers running it until the registry is first used — confirm
# in Registry.
CALLBACKS.late_add(_callbacks_loader)
# Public shortcut for ``CALLBACKS.add``; since ``add`` also works as a bare
# decorator, ``@Callback`` can register user-defined callbacks.
Callback = CALLBACKS.add

__all__ = [
    "Callback",
    "Criterion",
    "Optimizer",
    "Scheduler",
    "Module",
    "Model",
    "Sampler",
    "Transform",
    "CALLBACKS",
    "CRITERIONS",
    "GRAD_CLIPPERS",
コード例 #16
0
ファイル: test_registry.py プロジェクト: zkid18/catalyst
def test_recursive_get_from_config():
    """Check that ``get_from_params`` builds nested ``_target_`` configs.

    The meta factory below simply calls the factory, so each built ``foo`` is
    a plain dict of its arguments. Covered cases:

    * a flat config;
    * configs nested inside dicts and lists, including deep list nesting;
    * params without a top-level ``_target_`` whose values are configs;
    * plain nested dicts without ``_target_``, which pass through unchanged;
    * ``shared_params``, which supply a missing argument but do not override
      a value given explicitly in a nested config.
    """  # noqa: D202

    def meta_factory(factory, args, kwargs):
        return factory(*args, **kwargs)

    r = Registry(meta_factory=meta_factory)

    r.add(foo)

    res = r.get_from_params(**{"_target_": "foo", "a": 1, "b": 2})
    assert res == {"a": 1, "b": 2}

    # configs nested inside dicts and lists are built recursively
    res = r.get_from_params(
        **{
            "_target_": "foo",
            "a": {
                "_target_": "foo",
                "a": {"_target_": "foo", "a": 1, "b": 2},
                "b": 2,
            },
            "b": [
                {"_target_": "foo", "a": 1, "b": 2},
                {"_target_": "foo", "a": 1, "b": 2},
            ],
        }
    )
    assert res == {
        "a": {"a": {"a": 1, "b": 2}, "b": 2},
        "b": [{"a": 1, "b": 2}, {"a": 1, "b": 2}],
    }

    # without a top-level ``_target_`` the nested configs are still built
    # (a verbatim duplicate of this check was removed)
    res = r.get_from_params(
        **{
            "a": {"_target_": "foo", "a": 1, "b": 2},
            "b": {"_target_": "foo", "a": 1, "b": 2},
        }
    )
    assert res == {"a": {"a": 1, "b": 2}, "b": {"a": 1, "b": 2}}

    # check nested dicts support
    res = r.get_from_params(
        **{
            "_target_": "foo",
            "a": {"c": {"_target_": "foo", "a": 1, "b": 2}},
            "b": 2,
        }
    )
    assert res == {"a": {"c": {"a": 1, "b": 2}}, "b": 2}

    # check nested lists support
    res = r.get_from_params(
        **{
            "_target_": "foo",
            "a": [[[{"_target_": "foo", "a": 1, "b": 2}]]],
            "b": {"c": 3, "d": {"e": 4}},
        }
    )
    assert res == {"a": [[[{"a": 1, "b": 2}]]], "b": {"c": 3, "d": {"e": 4}}}

    # check shared_params: they fill in the outer ``b`` but do not override
    # the ``b`` given explicitly in the nested config
    res = r.get_from_params(
        **{"_target_": "foo", "a": {"_target_": "foo", "a": 1, "b": 3}},
        shared_params={"b": 2},
    )
    assert res == {"a": {"a": 1, "b": 3}, "b": 2}
コード例 #17
0
def test_fail_multiple_with_name():
    """Check that ``name`` is rejected when several factories are passed.

    ``add(foo, foo, name="bar")`` must raise ``RegistryException``.
    """
    registry = Registry("")
    with pytest.raises(RegistryException):
        registry.add(foo, foo, name="bar")
コード例 #18
0
ファイル: registry.py プロジェクト: zafariqballevi2/catalyst
        from albumentations import pytorch as p

        r.add_from_module(p, prefix=["A.", "albu.", "albumentations."])

        from catalyst.contrib.data.cv import transforms as t

        r.add_from_module(t, prefix=["catalyst.", "C."])
    except ImportError as ex:
        if settings.albumentations_required:
            logger.warning(
                "albumentations not available, to install albumentations, "
                "run `pip install albumentations`.")
            raise ex


# Registry of data transforms, constructed with "transform" as its name argument.
TRANSFORMS = Registry("transform")
# Hand the built-in transforms (albumentations and catalyst transform modules)
# over to ``_transforms_loader``; presumably ``late_add`` defers running it
# until the registry is first used — confirm in Registry.
TRANSFORMS.late_add(_transforms_loader)
# Public shortcut for ``TRANSFORMS.add``; since ``add`` also works as a bare
# decorator, ``@Transform`` can register user-defined transforms.
Transform = TRANSFORMS.add


def _samplers_loader(r: Registry):
    """Register PyTorch's and catalyst's samplers into registry ``r``.

    Every attribute of ``torch.utils.data.sampler`` whose name contains
    ``"Sampler"`` (except ``Sampler`` itself) is added under its attribute
    name. Then ``catalyst.data.sampler`` is loaded via ``add_from_module``.
    """
    from torch.utils.data import sampler as torch_sampler

    torch_factories = {}
    for attr_name, attr_value in vars(torch_sampler).items():
        # Skip the bare ``Sampler`` name; keep every other *Sampler* attribute.
        if attr_name != "Sampler" and "Sampler" in attr_name:
            torch_factories[attr_name] = attr_value
    r.add(**torch_factories)

    from catalyst.data import sampler as catalyst_sampler

    r.add_from_module(catalyst_sampler)
コード例 #19
0
def test_add_lambda_fail():
    """Check that a lambda cannot be registered without an explicit ``name``."""
    registry = Registry("")
    with pytest.raises(RegistryException):
        registry.add(lambda x: x)