예제 #1
0
def test_decorator():
    """``Registry.add`` can be applied as a decorator to register a function."""
    registry = Registry("")

    @registry.add
    def bar():
        pass

    # No explicit assert: the test relies on the lookup failing if "bar"
    # was not registered (assumed Registry behavior).
    registry.get("bar")
예제 #2
0
def _model_loader(r: Registry):
    """Register catalyst_rl contrib models, plus segmentation_models_pytorch if installed.

    Models from ``segmentation_models_pytorch`` are registered under the
    ``"smp."`` prefix. That package is optional: an ImportError raised while
    importing or registering it is ignored.
    """
    from catalyst_rl.contrib import models as contrib_models
    r.add_from_module(contrib_models)

    try:
        import segmentation_models_pytorch as smp_models
        r.add_from_module(smp_models, prefix="smp.")
    except ImportError:
        # Optional dependency missing: keep only the contrib models.
        pass
예제 #3
0
def _samplers_loader(r: Registry):
    """Register torch's samplers, then the catalyst_rl ``sampler`` module.

    Every attribute of ``torch.utils.data.sampler`` whose name contains
    "Sampler" is registered under that name. The bare name ``Sampler`` is
    excluded.
    """
    from torch.utils.data import sampler as torch_sampler

    torch_factories = {}
    for attr_name, attr_value in vars(torch_sampler).items():
        if attr_name != "Sampler" and "Sampler" in attr_name:
            torch_factories[attr_name] = attr_value
    r.add(**torch_factories)

    from catalyst_rl.data import sampler as catalyst_sampler
    r.add_from_module(catalyst_sampler)
예제 #4
0
def test_add_module():
    """``add_from_module`` makes ``foo`` from ``module`` retrievable; ``bar`` cannot be instantiated."""
    registry = Registry("")
    registry.add_from_module(module)

    registry.get("foo")

    # Expected to fail, presumably because ``module`` provides no usable "bar".
    with pytest.raises(RegistryException):
        registry.get_instance("bar")
예제 #5
0
def test_double_add_same_nofail():
    """Registering the very same factory twice must not raise.

    Duplicate registration is tolerated on purpose: Python's relative import
    machinery can execute registration code more than once. See
    https://github.com/catalyst-team/catalyst/issues/135
    """
    registry = Registry("")
    for _ in range(2):
        registry.add(foo)
예제 #6
0
def test_fail_instantiation():
    """A failing factory call surfaces as RegistryException with the original error chained.

    ``foo`` is called with the keyword ``c``. That call is expected to fail,
    presumably because ``foo`` only accepts ``a``/``b``. The registry should
    wrap the failure and keep the underlying error as ``__cause__``.
    """
    r = Registry("")

    r.add(foo)

    with pytest.raises(RegistryException) as e_info:
        r.get_instance("foo", c=1)

    # Every exception object has a ``__cause__`` attribute (None by default),
    # so ``hasattr`` was always true. Check that a cause was actually chained.
    assert e_info.value.__cause__ is not None
예제 #7
0
def test_instantiations():
    """``get_instance`` forwards positional, mixed and keyword arguments to the factory."""
    registry = Registry("")
    registry.add(foo)

    expected = {"a": 1, "b": 2}
    call_variants = [
        ((1, 2), {}),
        ((1,), {"b": 2}),
        ((), {"a": 1, "b": 2}),
    ]
    for call_args, call_kwargs in call_variants:
        assert registry.get_instance("foo", *call_args, **call_kwargs) == expected
예제 #8
0
def test_fail_double_add_different():
    """Registering a different factory under an already-used name must raise."""
    registry = Registry("")
    registry.add(foo)

    def bar():
        pass

    # Only the conflicting registration sits inside ``raises``.
    with pytest.raises(RegistryException):
        registry.add(foo=bar)
예제 #9
0
def test_from_config():
    """``get_from_params`` builds the factory named under the registry key.

    The remaining params are passed to that factory. With no params at all,
    the result is None.
    """
    registry = Registry("obj")
    registry.add(foo)

    params = {"obj": "foo", "a": 1, "b": 2}
    assert registry.get_from_params(**params) == {"a": 1, "b": 2}

    assert registry.get_from_params() is None
예제 #10
0
def test_meta_factory():
    """The registry's default meta factory is used unless one is passed per call."""

    def meta_1(fn, args, kwargs):
        # Hands back the factory object without calling it.
        return fn

    def meta_2(fn, args, kwargs):
        # Ignores the factory and yields a constant.
        return 1

    registry = Registry("obj", meta_1)
    registry.add(foo)
    params = {"obj": "foo"}

    # No ``meta_factory`` given: the registry-level meta_1 applies.
    assert registry.get_from_params(**params) == foo
    # An explicit ``meta_factory`` takes precedence.
    assert registry.get_from_params(**params, meta_factory=meta_2) == 1
예제 #11
0
def _transforms_loader(r: Registry):
    """Register albumentations transforms and catalyst_rl's CV transforms.

    ``albumentations`` and ``albumentations.pytorch`` are registered under the
    ``A.``, ``albu.`` and ``albumentations.`` prefixes. catalyst_rl's own
    transforms are registered under ``catalyst_rl.`` and ``C.``.

    Any ImportError raised while doing so is ignored, unless the
    ``USE_ALBUMENTATIONS`` environment variable is ``"1"``. In that case a
    warning is logged and the original ImportError propagates.
    """
    try:
        import albumentations as m
        r.add_from_module(m, prefix=["A.", "albu.", "albumentations."])

        from albumentations import pytorch as p
        r.add_from_module(p, prefix=["A.", "albu.", "albumentations."])

        from catalyst_rl.contrib.data.cv import transforms as t
        r.add_from_module(t, prefix=["catalyst_rl.", "C."])
    except ImportError:
        if os.environ.get("USE_ALBUMENTATIONS", "0") == "1":
            logger.warning(
                "albumentations not available, to install albumentations, "
                "run `pip install albumentations`.")
            # Bare ``raise`` re-raises the active ImportError unchanged.
            raise
예제 #12
0
def test_add_function_name_override():
    """``add(factory, name=...)`` registers the factory under the given name."""
    registry = Registry("")
    registry.add(foo, name="bar")
    assert "bar" in registry._factories
예제 #13
0
def test_add_function():
    """A plain function is registered under its own ``__name__``."""
    registry = Registry("")
    registry.add(foo)
    assert "foo" in registry._factories
예제 #14
0
def _modules_loader(r: Registry):
    """Register the contents of ``catalyst_rl.contrib.nn.modules`` in ``r``."""
    from catalyst_rl.contrib.nn import modules as contrib_modules

    r.add_from_module(contrib_modules)
예제 #15
0
def _grad_clip_loader(r: Registry):
    """Register the contents of torch's ``torch.nn.utils.clip_grad`` module in ``r``."""
    from torch.nn.utils import clip_grad

    r.add_from_module(clip_grad)
예제 #16
0
def _optimizers_loader(r: Registry):
    """Register the contents of ``catalyst_rl.contrib.nn.optimizers`` in ``r``."""
    from catalyst_rl.contrib.nn import optimizers as contrib_optimizers

    r.add_from_module(contrib_optimizers)
예제 #17
0
def _criterion_loader(r: Registry):
    """Register the contents of ``catalyst_rl.contrib.nn.criterion`` in ``r``."""
    from catalyst_rl.contrib.nn import criterion as contrib_criterion

    r.add_from_module(contrib_criterion)
예제 #18
0
def _exploration_late_add(r: Registry):
    """Register the contents of the sibling ``exploration`` module in ``r``."""
    from . import exploration

    r.add_from_module(exploration)
예제 #19
0
def _onpolicy_algorithms_late_add(r: Registry):
    """Register the contents of the sibling ``onpolicy.algorithms`` module in ``r``."""
    from .onpolicy import algorithms as onpolicy_algorithms

    r.add_from_module(onpolicy_algorithms)
예제 #20
0
from catalyst_rl.contrib.registry import (Criterion, CRITERIONS, GRAD_CLIPPERS,
                                          Model, MODELS, Module, MODULES,
                                          Optimizer, OPTIMIZERS, Sampler,
                                          SAMPLERS, Scheduler, SCHEDULERS,
                                          Transform, TRANSFORMS)
from catalyst_rl.core.registry import Callback, CALLBACKS
from catalyst_rl.utils.tools.registry import Registry


def _dbs_late_add(r: Registry):
    """Register the contents of the sibling ``db`` module in ``r``."""
    from . import db as db_module

    r.add_from_module(db_module)


# Registry of database implementations. "db" is the name given to the
# Registry (presumably the params key that holds the factory name; confirm
# against Registry). ``_dbs_late_add`` is handed to ``late_add`` instead of
# being called here, presumably so the ``db`` import is deferred until the
# registry is used. TODO confirm against Registry.late_add.
DATABASES = Registry("db")
DATABASES.late_add(_dbs_late_add)
# Decorator alias: ``@Database`` registers the decorated object in DATABASES.
Database = DATABASES.add


def _agents_late_add(r: Registry):
    """Register the contents of the sibling ``agent`` module in ``r``."""
    from . import agent as agent_module

    r.add_from_module(agent_module)


# Registry of agent implementations. "agent" is the name given to the
# Registry (presumably the params key that holds the factory name; confirm
# against Registry). ``_agents_late_add`` is handed to ``late_add`` rather
# than called here, presumably to defer the ``agent`` import until first use.
AGENTS = Registry("agent")
AGENTS.late_add(_agents_late_add)
# Decorator alias: ``@Agent`` registers the decorated object in AGENTS.
Agent = AGENTS.add


def _offpolicy_algorithms_late_add(r: Registry):
    """Register the contents of the sibling ``offpolicy.algorithms`` module in ``r``."""
    from .offpolicy import algorithms as m
    # Without this call the import has no effect and nothing gets registered.
    r.add_from_module(m)
예제 #21
0
from catalyst_rl.contrib.registry import (Criterion, CRITERIONS, GRAD_CLIPPERS,
                                          Model, MODELS, Module, MODULES,
                                          Optimizer, OPTIMIZERS, Sampler,
                                          SAMPLERS, Scheduler, SCHEDULERS,
                                          Transform, TRANSFORMS)
from catalyst_rl.utils.tools.registry import Registry


def _callbacks_loader(r: Registry):
    """Register the contents of ``catalyst_rl.core.callbacks`` in ``r``."""
    from catalyst_rl.core import callbacks as core_callbacks

    r.add_from_module(core_callbacks)


# Registry of callbacks. "callback" is the name given to the Registry
# (presumably the params key that holds the factory name; confirm against
# Registry). ``_callbacks_loader`` is handed to ``late_add`` rather than
# called here, presumably to defer importing the callbacks until first use.
CALLBACKS = Registry("callback")
CALLBACKS.late_add(_callbacks_loader)
# Decorator alias: ``@Callback`` registers the decorated object in CALLBACKS.
Callback = CALLBACKS.add

__all__ = [
    "Callback",
    "Criterion",
    "Optimizer",
    "Scheduler",
    "Module",
    "Model",
    "Sampler",
    "Transform",
    "CALLBACKS",
    "CRITERIONS",
    "GRAD_CLIPPERS",
    "MODELS",
    "MODULES",
예제 #22
0
def test_kwargs():
    """Keyword-style ``add(name=factory)`` registers the factory under that keyword."""
    registry = Registry("")
    registry.add(**{"bar": foo})
    registry.get("bar")
예제 #23
0
def _callbacks_loader(r: Registry):
    """Register the contents of ``catalyst_rl.core.callbacks`` in ``r``."""
    from catalyst_rl.core import callbacks as core_callbacks

    r.add_from_module(core_callbacks)
예제 #24
0
def test_add_lambda_fail():
    """Registering a lambda without an explicit ``name`` is rejected."""
    registry = Registry("")
    with pytest.raises(RegistryException):
        registry.add(lambda x: x)
예제 #25
0
def _agents_late_add(r: Registry):
    """Register the contents of the sibling ``agent`` module in ``r``."""
    from . import agent as agent_module

    r.add_from_module(agent_module)
예제 #26
0
def test_add_lambda_override():
    """A lambda can be registered once an explicit ``name`` is supplied."""
    registry = Registry("")
    registry.add(lambda x: x, name="bar")
    assert "bar" in registry._factories
예제 #27
0
def _env_late_add(r: Registry):
    """Register the contents of the sibling ``environment`` module in ``r``."""
    from . import environment as environment_module

    r.add_from_module(environment_module)
예제 #28
0
def test_fail_multiple_with_name():
    """``name=`` cannot be combined with more than one positional factory."""
    registry = Registry("")
    factories = (foo, foo)
    with pytest.raises(RegistryException):
        registry.add(*factories, name="bar")
예제 #29
0
def _dbs_late_add(r: Registry):
    """Register the contents of the sibling ``db`` module in ``r``."""
    from . import db as db_module

    r.add_from_module(db_module)
예제 #30
0
        r.add_from_module(m, prefix=["A.", "albu.", "albumentations."])

        from albumentations import pytorch as p
        r.add_from_module(p, prefix=["A.", "albu.", "albumentations."])

        from catalyst_rl.contrib.data.cv import transforms as t
        r.add_from_module(t, prefix=["catalyst_rl.", "C."])
    except ImportError as ex:
        if os.environ.get("USE_ALBUMENTATIONS", "0") == "1":
            logger.warning(
                "albumentations not available, to install albumentations, "
                "run `pip install albumentations`.")
            raise ex


# Registry of data transforms. "transform" is the name given to the Registry
# (presumably the params key that holds the factory name; confirm against
# Registry). ``_transforms_loader`` is handed to ``late_add`` rather than
# called here, presumably so the optional albumentations import is deferred
# until the registry is first used. TODO confirm against Registry.late_add.
TRANSFORMS = Registry("transform")
TRANSFORMS.late_add(_transforms_loader)
# Decorator alias: ``@Transform`` registers the decorated object in TRANSFORMS.
Transform = TRANSFORMS.add


def _samplers_loader(r: Registry):
    """Register torch's samplers, then the catalyst_rl ``sampler`` module.

    Every attribute of ``torch.utils.data.sampler`` whose name contains
    "Sampler" is registered under that name. The bare name ``Sampler`` is
    excluded.
    """
    from torch.utils.data import sampler as s
    # Keep every "*Sampler*" attribute except the bare ``Sampler`` name.
    factories = {
        k: v
        for k, v in s.__dict__.items() if "Sampler" in k and k != "Sampler"
    }
    r.add(**factories)
    # catalyst_rl's own samplers are registered after torch's.
    from catalyst_rl.data import sampler
    r.add_from_module(sampler)