Ejemplo n.º 1
0
def _datasets_loader(r: registry.Registry):
    """Populate ``r`` with dataset factories.

    The core datasets from ``catalyst.data.dataset`` are registered first,
    then the extra ones from ``catalyst.contrib.datasets``.

    Args:
        r: registry to fill.
    """
    # Imports are deferred to call time (presumably to keep module import
    # light / avoid cycles — confirm).
    from catalyst.data import dataset as core_datasets

    r.add_from_module(core_datasets)

    from catalyst.contrib import datasets as contrib_datasets

    r.add_from_module(contrib_datasets)
Ejemplo n.º 2
0
def _engines_loader(r: registry.Registry):
    """Register the engine interface and all concrete engines.

    ``IEngine`` itself is added first, followed by every factory exported by
    ``catalyst.engines``.

    Args:
        r: registry to fill.
    """
    from catalyst.core.engine import IEngine

    r.add(IEngine)

    from catalyst import engines

    r.add_from_module(engines)
Ejemplo n.º 3
0
def _runners_loader(r: registry.Registry):
    """Register the runner interface and all runners from ``catalyst.runners``.

    Args:
        r: registry to populate. ``IRunner`` is added once, then every
            factory exported by ``catalyst.runners``.
    """
    from catalyst.core.runner import IRunner

    # Register the base interface exactly once; adding the same factory a
    # second time is redundant and may clash if the registry rejects
    # duplicate names.
    r.add(IRunner)

    from catalyst import runners

    r.add_from_module(runners)
Ejemplo n.º 4
0
def _callbacks_loader(r: registry.Registry):
    """Register callback base classes and all callbacks from ``catalyst.callbacks``.

    Args:
        r: registry to fill. ``Callback`` and ``CallbackWrapper`` are added
            first (in that order), then the module's factories.
    """
    from catalyst.core.callback import Callback, CallbackWrapper

    for base_cls in (Callback, CallbackWrapper):
        r.add(base_cls)

    from catalyst import callbacks

    r.add_from_module(callbacks)
Ejemplo n.º 5
0
def _optimizers_loader(r: registry.Registry):
    """Register catalyst contrib optimizers and, optionally, fairscale ones.

    Args:
        r: registry to fill. Fairscale optimizers are added under the
            ``"fairscale."`` prefix only when ``SETTINGS.fairscale_required``
            is truthy.
    """
    from catalyst.contrib.nn import optimizers as contrib_optimizers

    r.add_from_module(contrib_optimizers)

    if not SETTINGS.fairscale_required:
        return

    from fairscale import optim as fairscale_optimizers

    r.add_from_module(fairscale_optimizers, prefix=["fairscale."])
Ejemplo n.º 6
0
def test_add_module():
    """Check lookups on a registry populated via ``add_from_module``.

    After registering ``module``, ``get("foo")`` must succeed and
    ``get_instance("bar")`` must raise ``RegistryException``.
    """
    registry_under_test = Registry("")
    registry_under_test.add_from_module(module)

    registry_under_test.get("foo")

    with pytest.raises(RegistryException):
        registry_under_test.get_instance("bar")
Ejemplo n.º 7
0
def _samplers_loader(r: Registry):
    """Register torch samplers and catalyst's own samplers in ``r``.

    Every entry of ``torch.utils.data.sampler`` whose name contains
    ``"Sampler"``, except the base ``Sampler`` itself, is added by name.
    All factories from ``catalyst.data.sampler`` are added after that.

    Args:
        r: registry to fill.
    """
    from torch.utils.data import sampler as torch_sampler

    torch_factories = {}
    for attr_name, attr_value in torch_sampler.__dict__.items():
        # Skip the base class and anything that is not a *Sampler name.
        if attr_name == "Sampler" or "Sampler" not in attr_name:
            continue
        torch_factories[attr_name] = attr_value
    r.add(**torch_factories)

    from catalyst.data import sampler

    r.add_from_module(sampler)
Ejemplo n.º 8
0
def _model_loader(r: Registry):
    """Register catalyst contrib models and, if installed, smp models.

    ``segmentation_models_pytorch`` is optional. If it cannot be imported,
    the error is ignored unless ``settings.segmentation_models_required`` is
    set. In that case an install hint is logged and the ImportError is
    re-raised.

    Args:
        r: registry to populate. smp factories get the ``"smp."`` prefix.

    Raises:
        ImportError: if smp is unavailable but required by settings.
    """
    from catalyst.contrib import models as m

    r.add_from_module(m)

    try:
        import segmentation_models_pytorch as smp

        r.add_from_module(smp, prefix="smp.")
    except ImportError:
        if settings.segmentation_models_required:
            logger.warning("segmentation_models_pytorch not available,"
                           " to install segmentation_models_pytorch,"
                           " run `pip install segmentation-models-pytorch`.")
            # Bare `raise` re-raises the active exception unchanged.
            raise
Ejemplo n.º 9
0
def _transforms_loader(r: registry.Registry):
    """Register catalyst transforms and, optionally, albumentations ones.

    Args:
        r: registry to fill. Catalyst transforms are namespaced under
            ``"transform."``. When ``SETTINGS.albu_required`` is truthy, the
            albumentations and ``albumentations.pytorch`` transforms are
            added under the ``"A."``, ``"albu."`` and ``"albumentations."``
            prefixes.
    """
    from catalyst.data import transforms as catalyst_transforms

    # The "transform." prefix avoids naming conflicts with other catalyst modules.
    r.add_from_module(catalyst_transforms, prefix=["transform."])

    if not SETTINGS.albu_required:
        return

    albu_prefixes = ("A.", "albu.", "albumentations.")

    import albumentations as albu

    r.add_from_module(albu, prefix=list(albu_prefixes))

    from albumentations import pytorch as albu_pytorch

    r.add_from_module(albu_pytorch, prefix=list(albu_prefixes))
Ejemplo n.º 10
0
def _transforms_loader(r: Registry):
    """Register optional third-party augmentation libraries in ``r``.

    Two optional backends are tried independently:

    * albumentations, ``albumentations.pytorch`` and catalyst's CV
      transforms. The albumentations ones use the ``"A."``, ``"albu."`` and
      ``"albumentations."`` prefixes; the catalyst ones use ``"catalyst."``
      and ``"C."``.
    * kornia augmentations, under the ``"kornia."`` prefix.

    A missing backend is skipped unless the matching ``settings.*_required``
    flag is set. In that case an install hint is logged and the error is
    re-raised.

    Args:
        r: registry to populate.

    Raises:
        ImportError: if a required backend cannot be imported.
        UnsupportedNodeError: if kornia is required but fails to load with
            the installed torch.
    """
    from torch.jit.frontend import UnsupportedNodeError

    albu_prefixes = ("A.", "albu.", "albumentations.")

    try:
        import albumentations as m

        r.add_from_module(m, prefix=list(albu_prefixes))

        from albumentations import pytorch as p

        r.add_from_module(p, prefix=list(albu_prefixes))

        from catalyst.contrib.data.cv import transforms as t

        r.add_from_module(t, prefix=["catalyst.", "C."])
    except ImportError:
        if settings.albumentations_required:
            logger.warning(
                "albumentations not available, to install albumentations, "
                "run `pip install albumentations`.")
            raise

    try:
        from kornia import augmentation as k

        r.add_from_module(k, prefix=["kornia."])
    except ImportError:
        if settings.kornia_required:
            logger.warning("kornia not available, to install kornia, "
                           "run `pip install kornia`.")
            raise
    except UnsupportedNodeError:
        # Presumably raised while kornia's TorchScript code is processed by
        # an older torch. The hint is logged even when kornia is optional.
        logger.warning(
            "kornia has requirement torch>=1.5.0, probably you have"
            " an old version of torch which is incompatible.\n"
            "To update pytorch, run `pip install -U 'torch>=1.5.0'`.")
        if settings.kornia_required:
            raise
Ejemplo n.º 11
0
def _transforms_loader(r: Registry):
    """Register albumentations-based transforms in ``r`` if available.

    The albumentations and ``albumentations.pytorch`` factories use the
    ``"A."``, ``"albu."`` and ``"albumentations."`` prefixes. Catalyst's CV
    transforms use ``"catalyst."`` and ``"C."``. An ImportError is ignored
    unless ``settings.albumentations_required`` is set. In that case an
    install hint is logged and the error is re-raised.

    Args:
        r: registry to populate.

    Raises:
        ImportError: if albumentations is unavailable but required.
    """
    albu_prefixes = ("A.", "albu.", "albumentations.")

    try:
        import albumentations as m

        r.add_from_module(m, prefix=list(albu_prefixes))

        from albumentations import pytorch as p

        r.add_from_module(p, prefix=list(albu_prefixes))

        from catalyst.contrib.data.cv import transforms as t

        r.add_from_module(t, prefix=["catalyst.", "C."])
    except ImportError:
        if settings.albumentations_required:
            logger.warning(
                "albumentations not available, to install albumentations, "
                "run `pip install albumentations`.")
            # Bare `raise` re-raises the active exception unchanged.
            raise
Ejemplo n.º 12
0
def _modules_loader(r: Registry):
    """Register every factory exported by ``catalyst.contrib.nn.modules``.

    Args:
        r: registry to fill.
    """
    from catalyst.contrib.nn import modules as contrib_modules

    r.add_from_module(contrib_modules)
Ejemplo n.º 13
0
def _grad_clip_loader(r: Registry):
    """Register the contents of ``torch.nn.utils.clip_grad`` (gradient clipping).

    Args:
        r: registry to fill.
    """
    from torch.nn.utils import clip_grad

    r.add_from_module(clip_grad)
Ejemplo n.º 14
0
def _loggers_loader(r: registry.Registry):
    """Register every factory exported by ``catalyst.loggers``.

    Args:
        r: registry to fill.
    """
    from catalyst import loggers

    r.add_from_module(loggers)
Ejemplo n.º 15
0
def _schedulers_loader(r: Registry):
    """Register every factory exported by ``catalyst.contrib.nn.schedulers``.

    Args:
        r: registry to fill.
    """
    from catalyst.contrib.nn import schedulers as contrib_schedulers

    r.add_from_module(contrib_schedulers)
Ejemplo n.º 16
0
def _optimizers_loader(r: Registry):
    """Register every factory exported by ``catalyst.contrib.nn.optimizers``.

    Args:
        r: registry to fill.
    """
    from catalyst.contrib.nn import optimizers as contrib_optimizers

    r.add_from_module(contrib_optimizers)
Ejemplo n.º 17
0
def _criterion_loader(r: Registry):
    """Register every factory exported by ``catalyst.contrib.nn.criterion``.

    Args:
        r: registry to fill.
    """
    from catalyst.contrib.nn import criterion as contrib_criterion

    r.add_from_module(contrib_criterion)
Ejemplo n.º 18
0
def _torch_functional_loader(r: registry.Registry):
    """Register ``torch.nn.functional`` under the ``"F."`` prefix.

    Args:
        r: registry to fill.
    """
    from torch.nn import functional

    functional_prefixes = ["F."]
    r.add_from_module(functional, functional_prefixes)
Ejemplo n.º 19
0
def _torch_loader(r: registry.Registry):
    """Register the top-level ``torch`` namespace under the ``"torch."`` prefix.

    ``ignore_all=True`` is forwarded to ``add_from_module``. Presumably this
    makes it look past ``torch.__all__``; confirm against the Registry
    implementation.

    Args:
        r: registry to fill.
    """
    import torch

    r.add_from_module(torch, ["torch."], ignore_all=True)
Ejemplo n.º 20
0
def _model_loader(r: registry.Registry):
    """Register every factory exported by ``catalyst.contrib.models``.

    Args:
        r: registry to fill.
    """
    from catalyst.contrib import models as contrib_models

    r.add_from_module(contrib_models)
Ejemplo n.º 21
0
def _callbacks_loader(r: Registry):
    """Register every factory exported by ``catalyst.dl.callbacks``.

    Args:
        r: registry to fill.
    """
    from catalyst.dl import callbacks as dl_callbacks

    r.add_from_module(dl_callbacks)
Ejemplo n.º 22
0
def _callbacks_loader(r: Registry):
    """Register every factory exported by ``catalyst.core.callbacks``.

    Args:
        r: registry to fill.
    """
    from catalyst.core import callbacks as core_callbacks

    r.add_from_module(core_callbacks)