示例#1
0
import sys
sys.path.append('..')
from registry import Registry

# Registry of meta-architectures (whole-model builders); build_model looks
# entries up in it by name.
META_ARCH_REGISTRY = Registry("META_ARCH")
META_ARCH_REGISTRY.__doc__ = """
Registry for meta-architectures, i.e. the whole model.

The registered object will be called with `obj(cfg)`
and expected to return a `nn.Module` object.
"""

# Public API of this module.
__all__ = ['build_model', 'META_ARCH_REGISTRY']

def build_model(cfg):
    """
    Build the whole model, selected by `cfg.model.name`.

    The name is looked up in `META_ARCH_REGISTRY` and the registered object
    is called with the full `cfg`.

    Args:
        cfg: configuration object. Only `cfg.model.name` is read here; the
            whole object is forwarded to the registered builder.

    Returns:
        Whatever the registered object returns for `obj(cfg)`. Per the
        registry's documentation, this is expected to be an `nn.Module`.
    """
    meta_arch = cfg.model.name
    return META_ARCH_REGISTRY.get(meta_arch)(cfg)
    
示例#2
0
#coding=utf-8
from torchvision import transforms
from . import autoaugment

from registry import Registry
from utils import MyLogger

# Registry of data transform builders; build_transforms resolves entries here.
TRANSFORMS_REGISTRY = Registry('TRANSFORMS')
TRANSFORMS_REGISTRY.__doc__ = """
Registry for data transform functions, i.e. torchvision.transforms

The registered object will be called with `obj(cfg)`
"""

# Separate registry for label transforms, kept apart from the data transforms.
LABEL_TRANSFORMS_REGISTRY = Registry('LABEL_TRANSFORMS')
LABEL_TRANSFORMS_REGISTRY.__doc__ = """
Registry for label transform functions, i.e. torchvision.transforms

The registered object will be called with `obj(cfg)`
"""

# Public API of this module.
# NOTE(review): build_label_transforms, DefaultTransforms and BaseTransforms
# are not defined in the visible text -- confirm they exist in this module,
# otherwise `from ... import *` will raise AttributeError.
__all__ = [
    'build_transforms',
    'build_label_transforms',
    'TRANSFORMS_REGISTRY',
    'LABEL_TRANSFORMS_REGISTRY',
    'DefaultTransforms',
    'BaseTransforms',
]


def build_transforms(cfg):
    """
    Build the data transforms (intended to be selected by `cfg.transforms.name`).

    NOTE(review): this function has no body beyond its docstring, so as
    written it ignores `cfg` and returns None -- confirm the implementation
    (presumably a `TRANSFORMS_REGISTRY` lookup) was not lost.
    """
示例#3
0
import sys
sys.path.append('..')
from registry import Registry

# Registry of loss-function builders; build_loss_fn resolves entries here.
LOSS_FN_REGISTRY = Registry("LOSS_FN")
LOSS_FN_REGISTRY.__doc__ = """
Registry for loss function, e.g. cross entropy loss.

The registered object will be called with `obj(cfg)`
"""

# Public API of this module.
__all__ = [
    'build_loss_fn',
    'LOSS_FN_REGISTRY',
]


def build_loss_fn(cfg):
    """
    Build the loss function registered under `cfg.loss.name`.

    Args:
        cfg: configuration object. `cfg.loss.name` selects the registry
            entry, and the whole `cfg` is passed to it.

    Returns:
        The result of calling the registered object with `cfg`.
    """
    loss_key = cfg.loss.name
    loss_factory = LOSS_FN_REGISTRY.get(loss_key)
    return loss_factory(cfg)
示例#4
0
import sys

sys.path.append('..')
from registry import Registry

# Registry of trainer builders; build_trainer resolves entries here.
# NOTE(review): the "`nn.Module`" return expectation in the doc below looks
# copied from the meta-architecture registry -- confirm it applies to trainers.
TRAINER_REGISTRY = Registry("TRAINER")
TRAINER_REGISTRY.__doc__ = """
Registry for trainer, i.e. the OnehotTrainer.

The registered object will be called with `obj(cfg)`
and expected to return a `nn.Module` object.
"""

# Public API of this module.
__all__ = ['build_trainer', 'TRAINER_REGISTRY']


def build_trainer(cfg):
    """
    Build the trainer registered under `cfg.trainer.name`.

    Args:
        cfg: configuration object. `cfg.trainer.name` selects the registry
            entry, and the whole `cfg` is forwarded to it.

    Returns:
        Whatever the registered object returns when called with `cfg`.
    """
    trainer_key = cfg.trainer.name
    trainer_factory = TRAINER_REGISTRY.get(trainer_key)
    return trainer_factory(cfg)
示例#5
0
import sys
sys.path.append('..')
from registry import Registry

# Registry of evaluator builders; build_evaluator resolves entries here.
# Named "EVALUATOR" so it is not confused with the separate "TRAINER"
# registry (it was previously constructed with the trainer's name).
EVALUATOR_REGISTRY = Registry("EVALUATOR")
EVALUATOR_REGISTRY.__doc__ = """
Registry for evaluator, i.e. the DefaultEvaluator.

The registered object will be called with `obj(cfg)`
"""

# Public API of this module.
__all__ = [
    'build_evaluator',
    'EVALUATOR_REGISTRY',
]

def build_evaluator(cfg):
    """
    Build the evaluator, selected by `cfg.evaluator.name`.

    The name is looked up in `EVALUATOR_REGISTRY` and the registered object
    is called with the full `cfg`.

    Args:
        cfg: configuration object. Only `cfg.evaluator.name` is read here;
            the whole object is forwarded to the registered builder.

    Returns:
        Whatever the registered object returns for `obj(cfg)`.
    """
    evaluator = cfg.evaluator.name
    return EVALUATOR_REGISTRY.get(evaluator)(cfg)
    
示例#6
0
import sys
sys.path.append('..')
from registry import Registry

# Registry of mutator builders; build_mutator resolves entries here.
MUTATOR_REGISTRY = Registry("MUTATOR")
# build_mutator calls registered objects with two positional arguments,
# so the documented call signature must say so.
MUTATOR_REGISTRY.__doc__ = """
Registry for mutator.

The registered object will be called with `obj(model, cfg)`
"""

# Public API of this module.
__all__ = [
    'MUTATOR_REGISTRY',
    'build_mutator',
]


def build_mutator(model, cfg):
    """
    Build the mutator registered under `cfg.mutator.name`.

    Args:
        model: passed unchanged as the first argument to the registered
            object (presumably the model to be mutated -- confirm).
        cfg: configuration object. `cfg.mutator.name` selects the registry
            entry, and the whole `cfg` is passed as the second argument.

    Returns:
        The result of calling the registered object with `(model, cfg)`.
    """
    mutator_key = cfg.mutator.name
    mutator_factory = MUTATOR_REGISTRY.get(mutator_key)
    return mutator_factory(model, cfg)
示例#7
0
import sys
sys.path.append('..')
from registry import Registry

# Registry of dataset builders; build_dataset resolves entries here.
DATASET_REGISTRY = Registry("DATASET")
DATASET_REGISTRY.__doc__ = """
Registry for dataset, i.e. torch.utils.data.Dataset.

The registered object will be called with `obj(cfg)`
"""

# Public API of this module.
__all__ = ['build_dataset', 'DATASET_REGISTRY']

def build_dataset(cfg):
    """
    Build the dataset registered under `cfg.dataset.name`.

    Args:
        cfg: configuration object. `cfg.dataset.name` selects the registry
            entry, and the whole `cfg` is forwarded to it.

    Returns:
        Whatever the registered object returns when called with `cfg`.
    """
    dataset_key = cfg.dataset.name
    dataset_factory = DATASET_REGISTRY.get(dataset_key)
    return dataset_factory(cfg)