def test_add_lambda_override():
    """A lambda added with an explicit ``name`` is stored under that name."""
    registry = Registry("")

    registry.add(lambda x: x, name="bar")

    registered = registry._factories
    assert "bar" in registered
def test_add_function_name_override():
    """A function added with an explicit ``name`` is stored under that name."""
    registry = Registry("")

    registry.add(foo, name="bar")

    registered = registry._factories
    assert "bar" in registered
def test_add_function():
    """A function added without a name is expected under the key ``"foo"``."""
    registry = Registry("")

    registry.add(foo)

    registered = registry._factories
    assert "foo" in registered
def test_decorator():
    """``Registry.add`` used as a decorator registers the decorated function."""
    registry = Registry("")

    @registry.add
    def bar():
        pass

    # The lookup itself is the check: the test passes if ``get`` does not
    # raise for the decorated function's name.
    registry.get("bar")
def test_decorator():
    """Decorating a function with ``Registry.add`` makes it retrievable.

    No value is asserted; the test only requires that ``get("bar")``
    completes without raising.
    """
    registry = Registry("")

    @registry.add
    def bar():
        pass

    registry.get("bar")
def _samplers_loader(r: Registry):
    """Register torch and catalyst samplers in ``r``.

    Every attribute of ``torch.utils.data.sampler`` whose name contains
    ``"Sampler"`` (except the name ``"Sampler"`` itself) is passed to
    ``r.add`` as a keyword argument. After that ``catalyst.data.sampler``
    is registered through ``r.add_from_module``.

    Args:
        r: registry that receives the sampler factories.
    """
    from torch.utils.data import sampler as torch_sampler

    torch_factories = {}
    for attr_name, attr in vars(torch_sampler).items():
        # Skip unrelated names and the name "Sampler" itself.
        if attr_name == "Sampler" or "Sampler" not in attr_name:
            continue
        torch_factories[attr_name] = attr
    r.add(**torch_factories)

    from catalyst.data import sampler as catalyst_sampler

    r.add_from_module(catalyst_sampler)
def _model_loader(r: Registry):
    """Register model factories in ``r``.

    ``catalyst.contrib.models`` is always registered. The
    ``segmentation_models_pytorch`` package is registered with the
    ``"smp."`` prefix when it can be imported; an ``ImportError`` raised
    while doing so is ignored.

    Args:
        r: registry that receives the model factories.
    """
    from catalyst.contrib import models as contrib_models

    r.add_from_module(contrib_models)

    try:
        import segmentation_models_pytorch as smp

        r.add_from_module(smp, prefix="smp.")
    except ImportError:
        # Best effort: nothing is registered for smp when the import fails.
        pass
def test_add_module():
    """Factories registered from a module can be looked up by name.

    ``"foo"`` must be retrievable after ``add_from_module``, while
    ``get_instance("bar")`` is expected to raise ``RegistryException``.
    NOTE(review): why ``bar`` fails depends on ``module``, which is defined
    outside this block.
    """
    registry = Registry("")
    registry.add_from_module(module)

    registry.get("foo")

    with pytest.raises(RegistryException):
        registry.get_instance("bar")
def test_add_module():
    """``add_from_module`` exposes ``foo`` but not an instantiable ``bar``.

    The test passes when ``get("foo")`` does not raise and
    ``get_instance("bar")`` raises ``RegistryException``.
    """
    registry = Registry("")
    registry.add_from_module(module)

    # Must succeed without raising.
    registry.get("foo")

    # Must fail with the registry's own exception type.
    with pytest.raises(RegistryException):
        registry.get_instance("bar")
def test_double_add_same_nofail():
    """Adding the very same factory twice must not raise."""
    registry = Registry("")

    # It's ok to add same twice, forced by python relative import
    # implementation
    # https://github.com/catalyst-team/catalyst/issues/135
    for _ in range(2):
        registry.add(foo)
def test_kwargs():
    """A factory passed as a keyword argument is retrievable by that keyword.

    No value is asserted; the test only requires that ``get("bar")``
    completes without raising.
    """
    registry = Registry("")

    registry.add(bar=foo)

    registry.get("bar")
def test_fail_instantiation():
    """A failing factory call must raise a chained ``RegistryException``.

    ``foo`` is instantiated with the keyword ``c=1`` (presumably not
    accepted by ``foo``, whose signature is defined outside this block).
    The registry is expected to raise ``RegistryException`` and to keep
    the underlying error as its ``__cause__``.
    """
    registry = Registry("")
    registry.add(foo)

    with pytest.raises(RegistryException) as e_ifo:
        registry.get_instance("foo", c=1)

    # Every exception object has a ``__cause__`` attribute (``None`` when
    # unset), so ``hasattr`` could never fail; check that a cause is
    # actually attached instead.
    # NOTE(review): assumes Registry re-raises with ``raise ... from err``
    # - confirm against the Registry implementation.
    assert e_ifo.value.__cause__ is not None
def test_instantiations():
    """``get_instance`` forwards positional and keyword arguments to ``foo``.

    The same expected result is checked for three call shapes: all
    positional, mixed, and all keyword.
    """
    registry = Registry("")
    registry.add(foo)

    positional = registry.get_instance("foo", 1, 2)
    assert positional == {"a": 1, "b": 2}

    mixed = registry.get_instance("foo", 1, b=2)
    assert mixed == {"a": 1, "b": 2}

    keyword_only = registry.get_instance("foo", a=1, b=2)
    assert keyword_only == {"a": 1, "b": 2}
def test_fail_instantiation():
    """Instantiation errors must surface as a chained ``RegistryException``.

    ``get_instance("foo", c=1)`` is expected to fail (``c`` is presumably
    not a parameter of ``foo``; its signature is defined outside this
    block). The resulting ``RegistryException`` must carry the original
    error as ``__cause__``.
    """
    registry = Registry("")
    registry.add(foo)

    with pytest.raises(RegistryException) as e_ifo:
        registry.get_instance("foo", c=1)

    # ``hasattr(exc, "__cause__")`` holds for any exception, chained or
    # not, so it verified nothing. Require a real cause instead.
    # NOTE(review): assumes Registry re-raises with ``raise ... from err``
    # - confirm against the Registry implementation.
    assert e_ifo.value.__cause__ is not None
def test_fail_double_add_different():
    """Re-using a taken name for a different factory must raise."""

    def bar():
        pass

    registry = Registry("")
    registry.add(foo)

    # "foo" is already taken by ``foo``; binding it to ``bar`` must fail.
    with pytest.raises(RegistryException):
        registry.add(foo=bar)
def test_from_config():
    """``get_from_params`` builds from a params dict and tolerates none.

    With the registry created as ``Registry("obj")``, params containing
    ``"obj": "foo"`` are expected to produce ``foo``'s result; calling
    with no params is expected to return ``None``.
    """
    registry = Registry("obj")
    registry.add(foo)

    params = {"obj": "foo", "a": 1, "b": 2}
    built = registry.get_from_params(**params)
    assert built == {"a": 1, "b": 2}

    nothing = registry.get_from_params()
    assert nothing is None
def test_instantiations():
    """Positional, mixed and keyword calls through ``get_instance`` agree."""
    registry = Registry("")
    registry.add(foo)

    # (positional args, keyword args) for each call shape, in the order
    # they are exercised.
    call_variants = [
        ((1, 2), {}),
        ((1,), {"b": 2}),
        ((), {"a": 1, "b": 2}),
    ]
    for args, kwargs in call_variants:
        res = registry.get_instance("foo", *args, **kwargs)
        assert res == {"a": 1, "b": 2}
def test_from_config():
    """Check ``get_from_params`` with a full params dict and with no params.

    The first call must yield ``{"a": 1, "b": 2}``; the second, given no
    params at all, must yield ``None``.
    """
    registry = Registry("obj")
    registry.add(foo)

    config = {"obj": "foo", "a": 1, "b": 2}
    assert registry.get_from_params(**config) == {"a": 1, "b": 2}

    assert registry.get_from_params() is None
def test_fail_double_add_different():
    """Binding an already used name to another factory must be rejected.

    ``foo`` is registered first; adding ``bar`` under the key ``"foo"``
    is then expected to raise ``RegistryException``.
    """

    def bar():
        pass

    registry = Registry("")
    registry.add(foo)

    with pytest.raises(RegistryException):
        registry.add(foo=bar)
def test_meta_factory():
    """A meta factory decides what ``get_from_params`` returns.

    The registry-level meta factory returns the factory itself; the
    per-call ``meta_factory`` argument is expected to take precedence and
    returns the constant ``1``.
    """

    def return_factory(fn, args, kwargs):
        return fn

    def return_one(fn, args, kwargs):
        return 1

    registry = Registry("obj", return_factory)
    registry.add(foo)

    params = {"obj": "foo"}

    default_result = registry.get_from_params(**params)
    assert default_result == foo

    overridden = registry.get_from_params(**params, meta_factory=return_one)
    assert overridden == 1
def _transforms_loader(r: Registry):
    """Register albumentations and catalyst CV transforms in ``r``.

    The modules are imported inside the function and registered one after
    another, each with its own list of name prefixes. If an ``ImportError``
    is raised anywhere in that sequence it is swallowed, unless the
    ``USE_ALBUMENTATIONS`` environment variable equals ``"1"``; in that
    case a warning is logged and the error is re-raised.

    Args:
        r: registry that receives the transform factories.

    Raises:
        ImportError: when an import fails and ``USE_ALBUMENTATIONS`` is
            ``"1"``.
    """
    try:
        import albumentations as albu

        r.add_from_module(albu, prefix=["A.", "albu.", "albumentations."])

        from albumentations import pytorch as albu_pytorch

        r.add_from_module(
            albu_pytorch, prefix=["A.", "albu.", "albumentations."]
        )

        from catalyst.contrib.data.cv import transforms as cv_transforms

        r.add_from_module(cv_transforms, prefix=["catalyst.", "C."])
    except ImportError as ex:
        albumentations_required = (
            os.environ.get("USE_ALBUMENTATIONS", "0") == "1"
        )
        if albumentations_required:
            logger.warning(
                "albumentations not available, to install albumentations, "
                "run `pip install albumentations`."
            )
            raise ex
def test_meta_factory():
    """Registry-level and per-call meta factories control the result."""  # noqa: D202

    def passthrough(fn, args, kwargs):
        return fn

    def constant(fn, args, kwargs):
        return 1

    registry = Registry("obj", passthrough)
    registry.add(foo)

    # Registry-level meta factory: the factory object itself comes back.
    via_default = registry.get_from_params(**{"obj": "foo"})
    assert via_default == foo

    # Per-call meta factory: its return value is what the caller receives.
    via_override = registry.get_from_params(
        **{"obj": "foo"}, meta_factory=constant
    )
    assert via_override == 1
def _criterion_loader(r: Registry):
    """Register the contents of ``catalyst.contrib.nn.criterion`` in ``r``.

    The module is imported inside the function, so the import only
    happens when the loader is called.

    Args:
        r: registry that receives the criterion factories.
    """
    from catalyst.contrib.nn import criterion as criterion_module

    r.add_from_module(criterion_module)
def _modules_loader(r: Registry):
    """Register the contents of ``catalyst.contrib.nn.modules`` in ``r``.

    The module is imported inside the function, so the import only
    happens when the loader is called.

    Args:
        r: registry that receives the module factories.
    """
    from catalyst.contrib.nn import modules as nn_modules

    r.add_from_module(nn_modules)
def _grad_clip_loader(r: Registry):
    """Register the contents of ``torch.nn.utils.clip_grad`` in ``r``.

    The module is imported inside the function, so the import only
    happens when the loader is called.

    Args:
        r: registry that receives the gradient clipping functions.
    """
    from torch.nn.utils import clip_grad as clip_grad_module

    r.add_from_module(clip_grad_module)
def _schedulers_loader(r: Registry):
    """Register the contents of ``catalyst.contrib.nn.schedulers`` in ``r``.

    The module is imported inside the function, so the import only
    happens when the loader is called.

    Args:
        r: registry that receives the scheduler factories.
    """
    from catalyst.contrib.nn import schedulers as schedulers_module

    r.add_from_module(schedulers_module)
def _optimizers_loader(r: Registry):
    """Register the contents of ``catalyst.contrib.nn.optimizers`` in ``r``.

    The module is imported inside the function, so the import only
    happens when the loader is called.

    Args:
        r: registry that receives the optimizer factories.
    """
    from catalyst.contrib.nn import optimizers as optimizers_module

    r.add_from_module(optimizers_module)
r.add_from_module(m, prefix=["A.", "albu.", "albumentations."]) from albumentations import pytorch as p r.add_from_module(p, prefix=["A.", "albu.", "albumentations."]) from catalyst.contrib.data.cv import transforms as t r.add_from_module(t, prefix=["catalyst.", "C."]) except ImportError as ex: if os.environ.get("USE_ALBUMENTATIONS", "0") == "1": logger.warning( "albumentations not available, to install albumentations, " "run `pip install albumentations`.") raise ex TRANSFORMS = Registry("transform") TRANSFORMS.late_add(_transforms_loader) Transform = TRANSFORMS.add def _samplers_loader(r: Registry): from torch.utils.data import sampler as s factories = { k: v for k, v in s.__dict__.items() if "Sampler" in k and k != "Sampler" } r.add(**factories) from catalyst.data import sampler r.add_from_module(sampler)
def _callbacks_loader(r: Registry):
    """Register the contents of ``catalyst.core.callbacks`` in ``r``.

    The module is imported inside the function, so the import only
    happens when the loader is called.

    Args:
        r: registry that receives the callback factories.
    """
    from catalyst.core import callbacks as callbacks_module

    r.add_from_module(callbacks_module)
SAMPLERS, Scheduler, SCHEDULERS, Transform, TRANSFORMS, ) from catalyst.utils.tools.registry import Registry def _callbacks_loader(r: Registry): from catalyst.core import callbacks as m r.add_from_module(m) CALLBACKS = Registry("callback") CALLBACKS.late_add(_callbacks_loader) Callback = CALLBACKS.add __all__ = [ "Callback", "Criterion", "Optimizer", "Scheduler", "Module", "Model", "Sampler", "Transform", "CALLBACKS", "CRITERIONS", "GRAD_CLIPPERS",