def test_add_function_name_override():
    """Adding a function with an explicit ``name`` stores it under that key."""
    registry = Registry("")
    registry.add(foo, name="bar")

    registered = registry._factories  # noqa: WPS437
    assert "bar" in registered
def test_add_lambda_override():
    """A lambda is accepted when an explicit ``name`` is supplied for it."""
    registry = Registry("")
    registry.add(lambda arg: arg, name="bar")

    registered = registry._factories  # noqa: WPS437
    assert "bar" in registered
def test_add_function():
    """Adding ``foo`` without a name makes it available under key ``"foo"``."""
    registry = Registry("")
    registry.add(foo)

    registered = registry._factories  # noqa: WPS437
    assert "foo" in registered
def test_decorator():
    """``Registry.add`` can be used as a decorator; the factory is then looked up by name."""
    registry = Registry("")

    @registry.add
    def bar():
        """Dummy factory registered through the decorator form."""

    registry.get("bar")
Exemple #5
0
def _samplers_loader(r: Registry):
    """Populate ``r`` with torch samplers followed by catalyst's samplers.

    Every attribute of ``torch.utils.data.sampler`` whose name contains
    ``"Sampler"``, except the plain ``Sampler`` name itself, is added as a
    keyword factory. Then everything from ``catalyst.data.sampler`` is added.
    """
    from torch.utils.data import sampler as torch_sampler

    torch_factories = {}
    for attr_name, attr_value in vars(torch_sampler).items():
        # Skip the ``Sampler`` name itself and anything unrelated to samplers.
        if attr_name == "Sampler" or "Sampler" not in attr_name:
            continue
        torch_factories[attr_name] = attr_value
    r.add(**torch_factories)

    from catalyst.data import sampler as catalyst_sampler

    r.add_from_module(catalyst_sampler)
Exemple #6
0
def _model_loader(r: Registry):
    """Populate ``r`` with catalyst contrib models and, optionally, SMP models.

    Models from ``segmentation_models_pytorch`` are added with the ``"smp."``
    prefix. An ``ImportError`` raised while doing so is ignored unless
    ``SETTINGS.segmentation_models_required`` is set. In that case an install
    hint is logged and the error is re-raised.
    """
    from catalyst.contrib import models as contrib_models

    r.add_from_module(contrib_models)

    try:
        import segmentation_models_pytorch as smp

        r.add_from_module(smp, prefix="smp.")
    except ImportError as err:
        if not SETTINGS.segmentation_models_required:
            return
        logger.warning("segmentation_models_pytorch not available,"
                       " to install segmentation_models_pytorch,"
                       " run `pip install segmentation-models-pytorch`.")
        raise err
Exemple #7
0
def _transforms_loader(r: Registry):
    """Populate ``r`` with catalyst, albumentations and kornia transforms.

    Catalyst torch transforms are always added, under the ``catalyst.`` /
    ``C.`` prefixes. Albumentations (core and ``albumentations.pytorch``) and
    catalyst's albumentations wrappers are added when importable. Kornia
    augmentations are added under ``kornia.`` when importable.

    Raises:
        ImportError: if albumentations or kornia is missing and the
            corresponding ``SETTINGS.*_required`` flag is set.
        UnsupportedNodeError: if importing kornia fails on torch scripting
            and ``SETTINGS.kornia_required`` is set.
    """
    from torch.jit.frontend import UnsupportedNodeError

    from catalyst.data.cv.transforms import torch as catalyst_torch_transforms

    r.add_from_module(catalyst_torch_transforms, prefix=["catalyst.", "C."])

    # Optional albumentations support: core package, its pytorch helpers,
    # and catalyst's own albumentations-based transforms.
    try:
        import albumentations as albu

        r.add_from_module(albu, prefix=["A.", "albu.", "albumentations."])

        from albumentations import pytorch as albu_pytorch

        r.add_from_module(
            albu_pytorch, prefix=["A.", "albu.", "albumentations."]
        )

        from catalyst.data.cv.transforms import (
            albumentations as catalyst_albu_transforms,
        )

        r.add_from_module(catalyst_albu_transforms, prefix=["catalyst.", "C."])
    except ImportError as err:
        if SETTINGS.albumentations_required:
            logger.warning(
                "albumentations not available, to install albumentations, "
                "run `pip install albumentations`."
            )
            raise err

    # Optional kornia support.
    try:
        from kornia import augmentation as kornia_augmentation

        r.add_from_module(kornia_augmentation, prefix=["kornia."])
    except ImportError as err:
        if SETTINGS.kornia_required:
            logger.warning(
                "kornia not available, to install kornia, "
                "run `pip install kornia`."
            )
            raise err
    except UnsupportedNodeError as err:
        # The warning is emitted regardless of the setting; the error only
        # propagates when kornia is mandatory.
        logger.warning(
            "kornia has requirement torch>=1.5.0, probably you have"
            " an old version of torch which is incompatible.\n"
            "To update pytorch, run `pip install -U 'torch>=1.5.0'`."
        )
        if SETTINGS.kornia_required:
            raise err
def test_add_module():
    """Adding a whole module makes ``foo`` available; instantiating ``bar`` fails."""
    registry = Registry("")
    registry.add_from_module(module)

    registry.get("foo")

    with pytest.raises(RegistryException):
        registry.get_instance("bar")
Exemple #9
0
def _experiments_loader(r: Registry):
    """Populate ``r`` with core interfaces plus dl and contrib experiments.

    Adds ``IExperiment`` and ``IStageBasedRunner`` from ``catalyst.core``,
    then every factory from ``catalyst.dl.experiment`` and
    ``catalyst.contrib.dl.experiment``, in that order.
    """
    from catalyst.core import IExperiment, IStageBasedRunner

    # NOTE(review): IStageBasedRunner is a runner and is also registered by
    # the runners loader; its presence here looks copy-pasted. Confirm no
    # config relies on it before removing.
    for base_cls in (IExperiment, IStageBasedRunner):
        r.add(base_cls)

    from catalyst.dl import experiment as dl_experiments

    r.add_from_module(dl_experiments)

    from catalyst.contrib.dl import experiment as contrib_experiments

    r.add_from_module(contrib_experiments)
Exemple #10
0
def _runners_loader(r: Registry):
    """Populate ``r`` with core runner interfaces plus dl and contrib runners.

    Adds ``IRunner`` and ``IStageBasedRunner``, then every factory from
    ``catalyst.dl.runner`` and ``catalyst.contrib.dl.runner``, in that order.
    """
    from catalyst.core import IRunner, IStageBasedRunner

    r.add(IRunner)
    r.add(IStageBasedRunner)

    from catalyst.dl import runner as dl_runners

    r.add_from_module(dl_runners)

    from catalyst.contrib.dl import runner as contrib_runners

    r.add_from_module(contrib_runners)
Exemple #11
0
def test_kwargs():
    """A factory passed as ``name=factory`` is retrievable under that keyword."""
    registry = Registry("")
    registry.add(bar=foo)
    registry.get("bar")
Exemple #12
0
def test_double_add_same_nofail():
    """Adding the very same factory twice must not raise.

    Repeated registration happens in practice because of how Python handles
    relative imports; see
    https://github.com/catalyst-team/catalyst/issues/135
    """
    registry = Registry("")
    for _ in range(2):
        registry.add(foo)
Exemple #13
0
def _callbacks_loader(r: Registry):
    """Populate ``r`` with the base callback classes and ``catalyst.callbacks``.

    ``Callback`` and ``CallbackWrapper`` are added first, then every factory
    from the ``catalyst.callbacks`` module.
    """
    from catalyst.core.callback import Callback, CallbackWrapper

    for base_cls in (Callback, CallbackWrapper):
        r.add(base_cls)

    from catalyst import callbacks as catalyst_callbacks

    r.add_from_module(catalyst_callbacks)
Exemple #14
0
def test_fail_instantiation():
    """A failing factory call is reported as a chained ``RegistryException``.

    ``foo`` is called with a keyword (``c=1``) that it presumably does not
    accept. The registry should wrap the resulting error and keep the
    original exception as ``__cause__``.
    """
    registry = Registry("")

    registry.add(foo)

    with pytest.raises(RegistryException) as exc_info:
        registry.get_instance("foo", c=1)

    # Every exception object has a ``__cause__`` attribute (``None`` unless
    # raised with ``from``), so ``hasattr`` is always true; check that the
    # chain is actually set.
    assert exc_info.value.__cause__ is not None
Exemple #15
0
def _experiments_loader(r: Registry):
    """Populate ``r`` with ``IExperiment`` and all of ``catalyst.experiments``.

    The experiments module is added exactly once. Repeating the call adds
    nothing new.
    """
    from catalyst.core.experiment import IExperiment

    r.add(IExperiment)

    from catalyst import experiments as experiments_module

    r.add_from_module(experiments_module)
Exemple #16
0
def test_instantiations():
    """``get_instance`` forwards positional, mixed and keyword arguments to ``foo``."""
    registry = Registry("")
    registry.add(foo)

    expected = {"a": 1, "b": 2}
    call_variants = (
        ((1, 2), {}),
        ((1,), {"b": 2}),
        ((), {"a": 1, "b": 2}),
    )
    for call_args, call_kwargs in call_variants:
        result = registry.get_instance("foo", *call_args, **call_kwargs)
        assert result == expected
Exemple #17
0
def test_from_config():
    """``get_from_params`` builds via the ``obj`` key and returns ``None`` without it."""
    registry = Registry("obj")
    registry.add(foo)

    built = registry.get_from_params(obj="foo", a=1, b=2)
    assert built == {"a": 1, "b": 2}

    nothing = registry.get_from_params()
    assert nothing is None
Exemple #18
0
def test_fail_double_add_different():
    """Registering a different factory under an existing name raises."""
    registry = Registry("")
    registry.add(foo)

    def bar():
        """A factory distinct from ``foo``."""

    # Only the conflicting call lives inside ``raises`` so that nothing else
    # can accidentally satisfy the expectation (flake8-pytest-style PT012).
    with pytest.raises(RegistryException):
        registry.add(foo=bar)
Exemple #19
0
def _callbacks_loader(r: Registry):
    """Populate ``r`` with core, dl and contrib callbacks, in that order."""
    from catalyst.core import callbacks as core_callbacks

    r.add_from_module(core_callbacks)

    from catalyst.dl import callbacks as dl_callbacks

    r.add_from_module(dl_callbacks)

    from catalyst.contrib.dl import callbacks as contrib_callbacks

    r.add_from_module(contrib_callbacks)
Exemple #20
0
def test_meta_factory():
    """A per-call ``meta_factory`` overrides the registry's default one."""
    def return_factory(fn, args, kwargs):
        return fn

    def return_one(fn, args, kwargs):
        return 1

    registry = Registry("obj", return_factory)
    registry.add(foo)

    assert registry.get_from_params(obj="foo") == foo
    assert registry.get_from_params(obj="foo", meta_factory=return_one) == 1
Exemple #21
0
        r.add_from_module(t, prefix=["catalyst.", "C."])
    except ImportError as ex:
        if SETTINGS.kornia_required:
            logger.warning("kornia not available, to install kornia, "
                           "run `pip install kornia`.")
            raise ex
    except UnsupportedNodeError as ex:
        logger.warning(
            "kornia has requirement torch>=1.5.0, probably you have"
            " an old version of torch which is incompatible.\n"
            "To update pytorch, run `pip install -U 'torch>=1.5.0'`.")
        if SETTINGS.kornia_required:
            raise ex


# Registry of data transforms, keyed by "transform".
TRANSFORM = Registry("transform")
# Public shorthand for registering transforms (usable as a decorator).
Transform = TRANSFORM.add
# Attach the transforms loader; presumably it runs lazily on first lookup
# (confirm against Registry.late_add).
TRANSFORM.late_add(_transforms_loader)


def _samplers_loader(r: Registry):
    """Populate ``r`` with torch samplers followed by catalyst's samplers.

    Torch attributes whose name contains ``"Sampler"`` (other than the plain
    ``Sampler`` name) are added as keyword factories. Then everything in
    ``catalyst.data.sampler`` is added.
    """
    from torch.utils.data import sampler as torch_sampler

    def _wanted(attr_name):
        return "Sampler" in attr_name and attr_name != "Sampler"

    r.add(**{
        attr_name: attr_value
        for attr_name, attr_value in torch_sampler.__dict__.items()
        if _wanted(attr_name)
    })

    from catalyst.data import sampler as catalyst_sampler

    r.add_from_module(catalyst_sampler)
Exemple #22
0
def test_fail_multiple_with_name():
    """A single ``name`` cannot be applied to several factories at once."""
    registry = Registry("")
    with pytest.raises(RegistryException):
        registry.add(foo, foo, name="bar")
Exemple #23
0
def _schedulers_loader(r: Registry):
    """Populate ``r`` with everything from ``catalyst.contrib.nn.schedulers``."""
    from catalyst.contrib.nn import schedulers

    r.add_from_module(schedulers)
Exemple #24
0
def _optimizers_loader(r: Registry):
    """Populate ``r`` with everything from ``catalyst.contrib.nn.optimizers``."""
    from catalyst.contrib.nn import optimizers

    r.add_from_module(optimizers)
Exemple #25
0
def _criterion_loader(r: Registry):
    """Populate ``r`` with everything from ``catalyst.contrib.nn.criterion``."""
    from catalyst.contrib.nn import criterion

    r.add_from_module(criterion)
Exemple #26
0
def _modules_loader(r: Registry):
    """Populate ``r`` with everything from ``catalyst.contrib.nn.modules``."""
    from catalyst.contrib.nn import modules

    r.add_from_module(modules)
Exemple #27
0
def _grad_clip_loader(r: Registry):
    """Populate ``r`` with everything from ``torch.nn.utils.clip_grad``."""
    from torch.nn.utils import clip_grad

    r.add_from_module(clip_grad)
Exemple #28
0
def test_add_lambda_fail():
    """Adding an anonymous lambda without an explicit ``name`` is rejected."""
    registry = Registry("")
    with pytest.raises(RegistryException):
        registry.add(lambda arg: arg)