コード例 #1
0
ファイル: registry.py プロジェクト: rhololkeolke/catalyst-rl
from catalyst_rl.contrib.registry import (Criterion, CRITERIONS, GRAD_CLIPPERS,
                                          Model, MODELS, Module, MODULES,
                                          Optimizer, OPTIMIZERS, Sampler,
                                          SAMPLERS, Scheduler, SCHEDULERS,
                                          Transform, TRANSFORMS)
from catalyst_rl.utils.tools.registry import Registry


def _callbacks_loader(r: Registry):
    """Hand the ``catalyst_rl.core.callbacks`` module to ``r``.

    The module is imported inside the function body, so the import only
    happens when this loader is actually invoked.

    Args:
        r: registry whose ``add_from_module`` receives the callbacks module.
    """
    from catalyst_rl.core import callbacks as callbacks_module

    r.add_from_module(callbacks_module)


# Registry instance for callbacks, constructed with the name "callback".
CALLBACKS = Registry("callback")
# The loader is passed to `late_add` instead of being called directly here.
# NOTE(review): presumably this defers the loader's import until the registry
# is first queried -- confirm against `Registry.late_add`.
CALLBACKS.late_add(_callbacks_loader)
# Module-level alias for the registry's bound `add` method.
Callback = CALLBACKS.add

__all__ = [
    "Callback",
    "Criterion",
    "Optimizer",
    "Scheduler",
    "Module",
    "Model",
    "Sampler",
    "Transform",
    "CALLBACKS",
    "CRITERIONS",
    "GRAD_CLIPPERS",
    "MODELS",
    "MODULES",
コード例 #2
0
from catalyst_rl.contrib.registry import (Criterion, CRITERIONS, GRAD_CLIPPERS,
                                          Model, MODELS, Module, MODULES,
                                          Optimizer, OPTIMIZERS, Sampler,
                                          SAMPLERS, Scheduler, SCHEDULERS,
                                          Transform, TRANSFORMS)
from catalyst_rl.core.registry import Callback, CALLBACKS
from catalyst_rl.utils.tools.registry import Registry


def _dbs_late_add(r: Registry):
    """Hand the sibling ``db`` module of this package to ``r``.

    The relative import is executed inside the function body, i.e. only when
    this loader is invoked.

    Args:
        r: registry whose ``add_from_module`` receives the ``db`` module.
    """
    from . import db as db_module

    r.add_from_module(db_module)


# Registry instance for databases, constructed with the name "db".
DATABASES = Registry("db")
# The loader is passed to `late_add` instead of being called directly here.
# NOTE(review): presumably this defers the `db` import until the registry is
# first queried -- confirm against `Registry.late_add`.
DATABASES.late_add(_dbs_late_add)
# Module-level alias for the registry's bound `add` method.
Database = DATABASES.add


def _agents_late_add(r: Registry):
    """Hand the sibling ``agent`` module of this package to ``r``.

    The relative import is executed inside the function body, i.e. only when
    this loader is invoked.

    Args:
        r: registry whose ``add_from_module`` receives the ``agent`` module.
    """
    from . import agent as agent_module

    r.add_from_module(agent_module)


# Registry instance for agents, constructed with the name "agent".
AGENTS = Registry("agent")
# The loader is passed to `late_add` instead of being called directly here.
# NOTE(review): presumably this defers the `agent` import until the registry
# is first queried -- confirm against `Registry.late_add`.
AGENTS.late_add(_agents_late_add)
# Module-level alias for the registry's bound `add` method.
Agent = AGENTS.add


def _offpolicy_algorithms_late_add(r: Registry):
    from .offpolicy import algorithms as m
コード例 #3
0
        from albumentations import pytorch as p
        r.add_from_module(p, prefix=["A.", "albu.", "albumentations."])

        from catalyst_rl.contrib.data.cv import transforms as t
        r.add_from_module(t, prefix=["catalyst_rl.", "C."])
    except ImportError as ex:
        if os.environ.get("USE_ALBUMENTATIONS", "0") == "1":
            logger.warning(
                "albumentations not available, to install albumentations, "
                "run `pip install albumentations`.")
            raise ex


# Registry instance for transforms, constructed with the name "transform".
TRANSFORMS = Registry("transform")
# The loader is passed to `late_add` instead of being called directly here.
# NOTE(review): presumably this defers the loader's imports until the registry
# is first queried -- confirm against `Registry.late_add`.
TRANSFORMS.late_add(_transforms_loader)
# Module-level alias for the registry's bound `add` method.
Transform = TRANSFORMS.add


def _samplers_loader(r: Registry):
    """Feed sampler factories from torch and catalyst_rl into ``r``.

    First, every attribute of ``torch.utils.data.sampler`` whose name contains
    "Sampler" (except the attribute named exactly "Sampler") is passed to
    ``r.add`` as a keyword argument. Afterwards the ``catalyst_rl.data.sampler``
    module is handed to ``r.add_from_module``.

    Args:
        r: registry that receives the sampler factories.
    """
    from torch.utils.data import sampler as torch_sampler

    torch_factories = {}
    for attr_name, attr_value in vars(torch_sampler).items():
        # Skip unrelated attributes and the attribute named exactly "Sampler".
        if attr_name == "Sampler" or "Sampler" not in attr_name:
            continue
        torch_factories[attr_name] = attr_value
    r.add(**torch_factories)

    # Imported only after the torch factories were added, as in the
    # original statement order.
    from catalyst_rl.data import sampler as catalyst_sampler

    r.add_from_module(catalyst_sampler)


# Registry instance for samplers, constructed with the name "sampler".
SAMPLERS = Registry("sampler")