# Copyright (c) OpenMMLab. All rights reserved.
from torch.nn.parallel import DataParallel, DistributedDataParallel

from mmcv.utils import Registry

# Registry named 'module wrapper'. The two torch parallel wrapper classes
# imported above are registered into it at import time.
MODULE_WRAPPERS = Registry('module wrapper')

# Register in the same order as before: DataParallel first, then
# DistributedDataParallel.
for _wrapper_cls in (DataParallel, DistributedDataParallel):
    MODULE_WRAPPERS.register_module(module=_wrapper_cls)
# Drop the loop variable so the module namespace exposes exactly the same
# names as the unrolled form.
del _wrapper_cls
# Copyright (c) OpenMMLab. All rights reserved. import torch.nn as nn from mmcv.utils import Registry, build_from_cfg TRANSFORMER = Registry('Transformer') LINEAR_LAYERS = Registry('linear layers') def build_transformer(cfg, default_args=None): """Builder for Transformer.""" return build_from_cfg(cfg, TRANSFORMER, default_args) LINEAR_LAYERS.register_module('Linear', module=nn.Linear) def build_linear_layer(cfg, *args, **kwargs): """Build linear layer. Args: cfg (None or dict): The linear layer config, which should contain: - type (str): Layer type. - layer args: Args needed to instantiate an linear layer. args (argument list): Arguments passed to the `__init__` method of the corresponding linear layer. kwargs (keyword arguments): Keyword arguments passed to the `__init__` method of the corresponding linear layer. Returns: nn.Module: Created linear layer. """ if cfg is None: cfg_ = dict(type='Linear')