import copy import torch import inspect from Dehaze.core.Registry import Registry, build_from_cfg from torch.optim import lr_scheduler SCHEDULER = Registry('scheduler') SCHEDULER_BUILDERS = Registry('scheduler builder') def build_scheduler(cfg, default_args=None): scheduler = build_from_cfg(cfg, SCHEDULER, default_args) return scheduler @SCHEDULER.register_module() class Epoch(object): ''' examples: >>> lr_config = dict(type='Epoch', # Epoch or Iter >>> warmup='exp', # liner, step, exp, >>> step=[10, 20], # start with 1 >>> liner_end=0.00001, >>> step_gamma=0.1, >>> exp_gamma=0.9) or >>>lr_config = dict(type='Epoch', # Epoch or Iter >>> warmup='linear', # liner, step, exp, >>> step=[10, 20], # start with 1 >>> liner_end=0.00001, >>> step_gamma=0.1, >>> exp_gamma=0.9)
import copy import inspect import sys import torch from Dehaze.core.Registry import Registry, build_from_cfg OPTIMIZERS = Registry('optimizer') OPTIMIZER_BUILDERS = Registry('optimizer builder') def register_torch_optimizers(): torch_optimizers = [] for module_name in dir(torch.optim): if module_name.startswith('__'): continue _optim = getattr(torch.optim, module_name) if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer): OPTIMIZERS.register_module()(_optim) torch_optimizers.append(module_name) return torch_optimizers TORCH_OPTIMIZERS = register_torch_optimizers() def build_optimizer_constructor(cfg): return build_from_cfg(cfg, OPTIMIZER_BUILDERS) def build_optimizer(model, cfg):
from Dehaze.core.Registry import Registry, build_from_cfg

LOSSES = Registry('losses')


def build_loss(cfg, default_args=None):
    """Build a loss object from its config via the ``LOSSES`` registry.

    Args:
        cfg (dict): Loss config. Its ``type`` entry names (or is) a class
            registered in ``LOSSES``; the remaining entries are passed to
            that class.
        default_args (dict, optional): Extra arguments used to fill in keys
            that ``cfg`` does not set. This mirrors the other builders
            (``build_scheduler``, ``build_dataset``, ``build``). Defaults to
            None, which matches the previous single-argument behaviour.

    Returns:
        The object built by ``build_from_cfg``.
    """
    return build_from_cfg(cfg, LOSSES, default_args)
from Dehaze.core.Registry import Registry, build_from_cfg from Dehaze.core.Datasets.GroupSampler import GroupSampler, collate import random from functools import partial import numpy as np from torch.utils.data import DataLoader DATASETS = Registry('dataset') PIPELINES = Registry('pipeline') def build_dataset(cfg, default_args=None): dataset = build_from_cfg(cfg, DATASETS, default_args) return dataset def build_dataloader(dataset, samples_per_gpu, workers_per_gpu, num_gpus=1, shuffle=True, seed=None, **kwargs): """Build PyTorch DataLoader. In distributed training, each GPU/process has a dataloader. In non-distributed training, there is only one dataloader for all GPUs. Args: dataset (Dataset): A PyTorch dataset.
from Dehaze.core.Registry import Registry, build_from_cfg
from torch import nn

NETWORK = Registry('network')
BACKBONES = Registry('backbone')


def build(cfg, registry, default_args=None):
    """Instantiate one module, or a chain of modules, from config.

    Args:
        cfg (dict, list[dict]): A single module config, or a list of module
            configs.
        registry (:obj:`Registry`): Registry the module type(s) are looked
            up in.
        default_args (dict, optional): Default arguments forwarded to every
            ``build_from_cfg`` call (shared by all entries of a list
            ``cfg``). Defaults to None.

    Returns:
        nn.Module: For a dict ``cfg``, the object returned by
        ``build_from_cfg``; for a list ``cfg``, an ``nn.Sequential`` that
        wraps the built modules in list order.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    # A list of configs becomes a sequential pipeline; every element is
    # built before the container is assembled.
    built = []
    for sub_cfg in cfg:
        built.append(build_from_cfg(sub_cfg, registry, default_args))
    return nn.Sequential(*built)


def build_network(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
def test_registry():
    """Exercise the ``Registry`` API on a throwaway ``CATS`` registry.

    Covers registration as a decorator and as a plain call, duplicate-name
    rejection and ``force=True`` overrides, custom and multiple aliases via
    ``name=``, ``get``/``in`` lookups, the ``repr`` format, invalid
    arguments, and the deprecated registration APIs (which must emit a
    ``UserWarning``).
    """
    CATS = Registry('cat')
    assert CATS.name == 'cat'
    assert CATS.module_dict == {}
    assert len(CATS) == 0

    # Decorator form registers the class under its own __name__.
    @CATS.register_module()
    class BritishShorthair:
        pass

    assert len(CATS) == 1
    assert CATS.get('BritishShorthair') is BritishShorthair

    # Plain-call form.
    class Munchkin:
        pass

    CATS.register_module(Munchkin)
    assert len(CATS) == 2
    assert CATS.get('Munchkin') is Munchkin
    assert 'Munchkin' in CATS

    # Registering an existing name again is rejected...
    with pytest.raises(KeyError):
        CATS.register_module(Munchkin)

    # ...unless force=True, which replaces it without growing the registry.
    CATS.register_module(Munchkin, force=True)
    assert len(CATS) == 2

    # Same rule for the decorator form: without force (the default is
    # force=False) a duplicate name raises KeyError.
    with pytest.raises(KeyError):

        @CATS.register_module()
        class BritishShorthair:
            pass

    @CATS.register_module(force=True)
    class BritishShorthair:
        pass

    assert len(CATS) == 2

    # Unknown names: get() returns None and membership is False.
    assert CATS.get('PersianCat') is None
    assert 'PersianCat' not in CATS

    # name= registers under an alias instead of the class name.
    @CATS.register_module(name='Siamese')
    class SiameseCat:
        pass

    assert CATS.get('Siamese').__name__ == 'SiameseCat'

    class SphynxCat:
        pass

    CATS.register_module(name='Sphynx', module=SphynxCat)
    assert CATS.get('Sphynx') is SphynxCat

    # A list of names registers the same class under each alias.
    CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
    assert CATS.get('Sphynx2') is SphynxCat

    # The expected repr lists entries in the order they were registered
    # above. It embeds qualnames such as
    # 'test_registry.test_registry.<locals>.X', so it only holds when this
    # function lives in a module named test_registry and the registrations
    # above keep their order and class names.
    repr_str = 'Registry(name=cat, items={'
    repr_str += ("'BritishShorthair': <class 'test_registry.test_registry."
                 "<locals>.BritishShorthair'>, ")
    repr_str += ("'Munchkin': <class 'test_registry.test_registry."
                 "<locals>.Munchkin'>, ")
    repr_str += ("'Siamese': <class 'test_registry.test_registry."
                 "<locals>.SiameseCat'>, ")
    repr_str += ("'Sphynx': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>, ")
    repr_str += ("'Sphynx1': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>, ")
    repr_str += ("'Sphynx2': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>")
    repr_str += '})'
    assert repr(CATS) == repr_str

    # name must be a str (or list of str); an int trips an assertion.
    with pytest.raises(AssertionError):
        CATS.register_module(name=7474741, module=SphynxCat)

    # the registered module should be a class
    with pytest.raises(TypeError):
        CATS.register_module(0)

    # can only decorate a class
    with pytest.raises(TypeError):

        @CATS.register_module()
        def some_method():
            pass

    # begin: test old APIs -- each deprecated usage must still register the
    # class and emit a UserWarning.
    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat, force=True)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    # Bare decorator (no parentheses) is the deprecated form.
    with pytest.warns(UserWarning):

        @CATS.register_module
        class NewCat:
            pass

        assert CATS.get('NewCat').__name__ == 'NewCat'

    with pytest.warns(UserWarning):
        CATS.deprecated_register_module(SphynxCat, force=True)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module
        class CuteCat:
            pass

        assert CATS.get('CuteCat').__name__ == 'CuteCat'

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module(force=True)
        class NewCat2:
            pass

        assert CATS.get('NewCat2').__name__ == 'NewCat2'
    # end: test old APIs
def test_build_from_cfg():
    """Check that ``build_from_cfg`` instantiates registered classes.

    Uses a local ``BACKBONES`` registry holding two toy classes. Covers:

    * lookup by registered name and by class object;
    * filling missing arguments (including ``type``) from ``default_args``;
    * each rejected input exercised below.

    Every ``pytest.raises`` block wraps only the call under test, so an
    exception raised while preparing the arguments cannot satisfy it.
    """
    BACKBONES = Registry('backbone')

    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # type given by registered name; unspecified args keep class defaults
    cfg = dict(type='ResNet', depth=50)
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # default_args supplies a key that cfg does not set
    cfg = dict(type='ResNet', depth=50)
    model = build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    cfg = dict(type='ResNeXt', depth=50, stages=3)
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # type given as the class object itself
    cfg = dict(type=ResNet, depth=50)
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # type defined using default_args
    cfg = dict(depth=50)
    model = build_from_cfg(cfg, BACKBONES, default_args=dict(type='ResNet'))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    cfg = dict(depth=50)
    model = build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # not a registry
    cfg = dict(type='VGG')
    with pytest.raises(TypeError):
        build_from_cfg(cfg, 'BACKBONES')

    # non-registered class
    cfg = dict(type='VGG')
    with pytest.raises(KeyError):
        build_from_cfg(cfg, BACKBONES)

    # default_args must be a dict or None
    cfg = dict(type='ResNet', depth=50)
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES, default_args=1)

    # cfg['type'] should be a str or class
    cfg = dict(type=1000)
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES)

    # cfg should contain the key "type"
    cfg = dict(depth=50, stages=4)
    with pytest.raises(KeyError, match='must contain the key "type"'):
        build_from_cfg(cfg, BACKBONES)

    # cfg or default_args should contain the key "type"
    cfg = dict(depth=50)
    with pytest.raises(KeyError, match='must contain the key "type"'):
        build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4))

    # incorrect registry type (registered type name this time)
    cfg = dict(type='ResNet', depth=50)
    with pytest.raises(TypeError):
        build_from_cfg(cfg, 'BACKBONES')

    # incorrect default_args type: 0 is falsy but still neither dict nor None
    cfg = dict(type='ResNet', depth=50)
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES, default_args=0)