예제 #1
0
import copy
import torch
import inspect
from UW.core.Registry import Registry, build_from_cfg
from torch.optim import lr_scheduler

# Registry of scheduler classes: classes are added with
# ``@SCHEDULER.register_module()`` and looked up by ``build_scheduler``.
SCHEDULER = Registry('scheduler')
# Registry named 'scheduler builder'; not referenced by the code visible here.
SCHEDULER_BUILDERS = Registry('scheduler builder')


def build_scheduler(cfg, default_args=None):
    """Instantiate a scheduler from its config via the ``SCHEDULER`` registry.

    Args:
        cfg (dict): Config whose ``type`` key selects a registered scheduler
            (by name or class); the remaining keys are constructor arguments.
        default_args (dict, optional): Fallback constructor arguments for keys
            that ``cfg`` does not set. Defaults to None.

    Returns:
        The scheduler object produced by ``build_from_cfg``.
    """
    return build_from_cfg(cfg, SCHEDULER, default_args)


@SCHEDULER.register_module()
class Epoch(object):
    '''
    examples:
    >>> lr_config = dict(type='Epoch',          # Epoch or Iter
    >>>                  warmup='exp',       # liner, step, exp,
    >>>                  step=[10, 20],          # start with 1
    >>>                  liner_end=0.00001,
    >>>                  step_gamma=0.1,
    >>>                  exp_gamma=0.9)
    or
    >>>lr_config = dict(type='Epoch',          # Epoch or Iter
    >>>                 warmup='linear',       # liner, step, exp,
    >>>                 step=[10, 20],          # start with 1
    >>>                 liner_end=0.00001,
    >>>                 step_gamma=0.1,
예제 #2
0
from UW.core.Registry import Registry, build_from_cfg
from torch.utils.data import DataLoader

# Registry of dataset classes, looked up by ``build_dataset``.
DATASETS = Registry('dataset')
# Registry named 'pipeline'; not referenced by the code visible here.
PIPELINES = Registry('pipeline')


def build_dataset(cfg, default_args=None):
    """Instantiate a dataset from its config via the ``DATASETS`` registry.

    Args:
        cfg (dict): Config whose ``type`` key selects a registered dataset
            (by name or class); the remaining keys are constructor arguments.
        default_args (dict, optional): Fallback constructor arguments for keys
            that ``cfg`` does not set. Defaults to None.

    Returns:
        The dataset object produced by ``build_from_cfg``.
    """
    return build_from_cfg(cfg, DATASETS, default_args)


def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     shuffle=True,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
예제 #3
0
from UW.core.Registry import Registry, build_from_cfg
from torch import nn

# Registry named 'network'; presumably consumed by ``build_network`` (its body
# is not visible here) — TODO confirm.
NETWORK = Registry('network')
# Registry named 'backbone'; not referenced by the code visible here.
BACKBONES = Registry('backbone')


def build(cfg, registry, default_args=None):
    """Build one module, or a sequential stack of modules, from config.

    Args:
        cfg (dict | list[dict]): Either a single module config, or a list of
            configs that are built in order.
        registry (:obj:`Registry`): Registry in which each config's ``type``
            is looked up.
        default_args (dict, optional): Default arguments passed to every
            ``build_from_cfg`` call. Defaults to None.

    Returns:
        For a single config, whatever ``build_from_cfg`` returns; for a list,
        an ``nn.Sequential`` containing the built modules in list order (an
        empty list yields an empty ``nn.Sequential``).
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    # Build each entry in order, then chain them into a single container.
    return nn.Sequential(
        *(build_from_cfg(item_cfg, registry, default_args) for item_cfg in cfg))


def build_network(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
예제 #4
0
from UW.core.Registry import Registry, build_from_cfg

# Registry of loss classes, looked up by ``build_loss``.
LOSSES = Registry('losses')


def build_loss(cfg, default_args=None):
    """Build a loss from its config via the ``LOSSES`` registry.

    Args:
        cfg (dict): Config whose ``type`` key selects a registered loss (by
            name or class); the remaining keys are constructor arguments.
        default_args (dict, optional): Fallback constructor arguments for keys
            that ``cfg`` does not set. Defaults to None, which builds from
            ``cfg`` alone exactly as before.

    Returns:
        The loss object produced by ``build_from_cfg``.
    """
    # ``default_args`` gives this helper the same signature as the other
    # registry ``build_*`` helpers, so callers can inject shared defaults.
    return build_from_cfg(cfg, LOSSES, default_args)
예제 #5
0
def test_registry():
    """Exercise registration, lookup, aliasing and ``repr`` of ``Registry``.

    Every step shares one ``CATS`` registry and builds on the state left by
    earlier steps: entry counts and the expected ``repr`` depend on what was
    registered before. The order of the steps therefore matters.

    NOTE(review): ``pytest`` and ``Registry`` are not imported in the code
    visible here; presumably they are imported at module level.
    """
    CATS = Registry('cat')
    # A fresh registry exposes its name and starts empty.
    assert CATS.name == 'cat'
    assert CATS.module_dict == {}
    assert len(CATS) == 0

    # Decorator form: the class is registered under its own __name__.
    @CATS.register_module()
    class BritishShorthair:
        pass

    assert len(CATS) == 1
    assert CATS.get('BritishShorthair') is BritishShorthair

    class Munchkin:
        pass

    # Direct-call form, with the class passed positionally.
    CATS.register_module(Munchkin)
    assert len(CATS) == 2
    assert CATS.get('Munchkin') is Munchkin
    assert 'Munchkin' in CATS

    # Registering an already-used name without force raises KeyError.
    with pytest.raises(KeyError):
        CATS.register_module(Munchkin)

    # force=True replaces the existing entry; the count does not grow.
    CATS.register_module(Munchkin, force=True)
    assert len(CATS) == 2

    # force=False (the default): redefining a registered name via the
    # decorator also raises KeyError.
    with pytest.raises(KeyError):

        @CATS.register_module()
        class BritishShorthair:
            pass

    @CATS.register_module(force=True)
    class BritishShorthair:
        pass

    assert len(CATS) == 2

    # Unknown names: get() returns None and membership is False.
    assert CATS.get('PersianCat') is None
    assert 'PersianCat' not in CATS

    # name= registers the class under an explicit alias instead of __name__.
    @CATS.register_module(name='Siamese')
    class SiameseCat:
        pass

    assert CATS.get('Siamese').__name__ == 'SiameseCat'

    class SphynxCat:
        pass

    # Keyword call form: explicit name plus module=.
    CATS.register_module(name='Sphynx', module=SphynxCat)
    assert CATS.get('Sphynx') is SphynxCat

    # A list of names registers the same class once per name.
    CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
    assert CATS.get('Sphynx2') is SphynxCat

    # The expected repr hard-codes the module name ``test_registry`` and the
    # ``<locals>`` qualnames of the classes defined above. It therefore only
    # matches when this function lives in a module named ``test_registry``.
    repr_str = 'Registry(name=cat, items={'
    repr_str += ("'BritishShorthair': <class 'test_registry.test_registry."
                 "<locals>.BritishShorthair'>, ")
    repr_str += ("'Munchkin': <class 'test_registry.test_registry."
                 "<locals>.Munchkin'>, ")
    repr_str += ("'Siamese': <class 'test_registry.test_registry."
                 "<locals>.SiameseCat'>, ")
    repr_str += ("'Sphynx': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>, ")
    repr_str += ("'Sphynx1': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>, ")
    repr_str += ("'Sphynx2': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>")
    repr_str += '})'
    assert repr(CATS) == repr_str

    # name type: a non-string name is rejected with AssertionError
    with pytest.raises(AssertionError):
        CATS.register_module(name=7474741, module=SphynxCat)

    # the registered module should be a class
    with pytest.raises(TypeError):
        CATS.register_module(0)

    # can only decorate a class
    with pytest.raises(TypeError):

        @CATS.register_module()
        def some_method():
            pass

    # begin: test old APIs. Each legacy call form must still register the
    # class, but it must emit a UserWarning while doing so.
    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat, force=True)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    # Bare decorator without parentheses.
    with pytest.warns(UserWarning):

        @CATS.register_module
        class NewCat:
            pass

        assert CATS.get('NewCat').__name__ == 'NewCat'

    # Explicit deprecated_register_module entry point, in call, bare-decorator
    # and decorator-factory forms.
    with pytest.warns(UserWarning):
        CATS.deprecated_register_module(SphynxCat, force=True)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module
        class CuteCat:
            pass

        assert CATS.get('CuteCat').__name__ == 'CuteCat'

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module(force=True)
        class NewCat2:
            pass

        assert CATS.get('NewCat2').__name__ == 'NewCat2'
예제 #6
0
def test_build_from_cfg():
    """Exercise ``build_from_cfg`` against a local 'backbone' registry.

    Covers:
    - ``type`` given as a registered name or as a class;
    - ``type`` supplied through ``default_args``;
    - defaults taken from ``default_args``;
    - the ``TypeError``/``KeyError`` raised for invalid arguments.

    NOTE(review): ``pytest``, ``Registry`` and ``build_from_cfg`` are not
    imported in the code visible here; presumably they are imported at
    module level.
    """
    BACKBONES = Registry('backbone')

    # Two minimal registered classes. ``stages`` defaults to 4, so each case
    # can tell whether the cfg or ``default_args`` overrode it.
    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # ``type`` as a registered name; the other keys become constructor args.
    cfg = dict(type='ResNet', depth=50)
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # ``default_args`` supplies a value the cfg does not set.
    cfg = dict(type='ResNet', depth=50)
    model = build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    # A value set in the cfg itself is passed to the constructor.
    cfg = dict(type='ResNeXt', depth=50, stages=3)
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # ``type`` may also be the class object itself.
    cfg = dict(type=ResNet, depth=50)
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # type defined using default_args
    cfg = dict(depth=50)
    model = build_from_cfg(
        cfg, BACKBONES, default_args=dict(type='ResNet'))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    cfg = dict(depth=50)
    model = build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # not a registry
    with pytest.raises(TypeError):
        cfg = dict(type='VGG')
        model = build_from_cfg(cfg, 'BACKBONES')

    # non-registered class
    with pytest.raises(KeyError):
        cfg = dict(type='VGG')
        model = build_from_cfg(cfg, BACKBONES)

    # default_args must be a dict or None
    with pytest.raises(TypeError):
        cfg = dict(type='ResNet', depth=50)
        model = build_from_cfg(cfg, BACKBONES, default_args=1)

    # cfg['type'] should be a str or class
    with pytest.raises(TypeError):
        cfg = dict(type=1000)
        model = build_from_cfg(cfg, BACKBONES)

    # cfg should contain the key "type"
    with pytest.raises(KeyError, match='must contain the key "type"'):
        cfg = dict(depth=50, stages=4)
        model = build_from_cfg(cfg, BACKBONES)

    # cfg or default_args should contain the key "type"
    with pytest.raises(KeyError, match='must contain the key "type"'):
        cfg = dict(depth=50)
        model = build_from_cfg(
            cfg, BACKBONES, default_args=dict(stages=4))

    # incorrect registry type
    # NOTE(review): mirrors the 'not a registry' case above, but uses a type
    # name that is registered in BACKBONES.
    with pytest.raises(TypeError):
        cfg = dict(type='ResNet', depth=50)
        model = build_from_cfg(cfg, 'BACKBONES')

    # incorrect default_args type
    # NOTE(review): like the default_args=1 case above, but with the falsy
    # value 0 — presumably to check that falsy non-dicts are also rejected.
    with pytest.raises(TypeError):
        cfg = dict(type='ResNet', depth=50)
        model = build_from_cfg(cfg, BACKBONES, default_args=0)
예제 #7
0
import copy
import inspect
import sys
import torch
from UW.core.Registry import Registry, build_from_cfg

# Registry filled with torch.optim optimizer classes by
# ``register_torch_optimizers`` when this module is imported.
OPTIMIZERS = Registry('optimizer')
# Registry looked up by ``build_optimizer_constructor``.
OPTIMIZER_BUILDERS = Registry('optimizer builder')


def register_torch_optimizers():
    """Register every optimizer class exposed by ``torch.optim`` into ``OPTIMIZERS``.

    Scans every attribute of ``torch.optim`` whose name does not start with a
    double underscore. Each attribute that is a class deriving from
    ``torch.optim.Optimizer`` is registered. This includes
    ``torch.optim.Optimizer`` itself, since ``issubclass`` treats a class as
    a subclass of itself.

    Returns:
        list[str]: Attribute names of the registered classes, in ``dir()``
        order.
    """
    registered_names = []
    for attr_name in dir(torch.optim):
        if attr_name.startswith('__'):
            continue
        candidate = getattr(torch.optim, attr_name)
        # issubclass() raises TypeError for non-classes, so filter them first.
        if not inspect.isclass(candidate):
            continue
        if not issubclass(candidate, torch.optim.Optimizer):
            continue
        OPTIMIZERS.register_module()(candidate)
        registered_names.append(attr_name)
    return registered_names


# Runs at import time: registers torch's built-in optimizer classes into
# OPTIMIZERS and keeps the list of their torch.optim attribute names.
TORCH_OPTIMIZERS = register_torch_optimizers()


def build_optimizer_constructor(cfg, default_args=None):
    """Build an optimizer constructor via the ``OPTIMIZER_BUILDERS`` registry.

    Args:
        cfg (dict): Config whose ``type`` key selects a registered
            constructor (by name or class); the remaining keys are its
            constructor arguments.
        default_args (dict, optional): Fallback constructor arguments for keys
            that ``cfg`` does not set. Defaults to None, which builds from
            ``cfg`` alone exactly as before.

    Returns:
        The object produced by ``build_from_cfg``.
    """
    # ``default_args`` gives this helper the same signature as the other
    # registry ``build_*`` helpers, so callers can inject shared defaults.
    return build_from_cfg(cfg, OPTIMIZER_BUILDERS, default_args)


def build_optimizer(model, cfg):