예제 #1
0
 def test_multi_element_dict(self):
     """A two-entry dict yields two modules, reachable by index and by key."""
     expected = (('relu', nn.ReLU), ('tanh', nn.Tanh))
     seq = bd.Sequential({name: cls() for name, cls in expected})
     assert len(seq) == 2
     # Integer indexing follows the insertion order of the dict.
     for index, (_, cls) in enumerate(expected):
         assert isinstance(seq[index], cls)
     # Every dict key is also exposed as an attribute of the container.
     for name, cls in expected:
         assert isinstance(getattr(seq, name), cls)
예제 #2
0
 def test_single_element(self):
     """A single positional module is stored at index 0 and attribute '0'."""
     seq = bd.Sequential(nn.ReLU())
     assert len(seq) == 1
     by_index = seq[0]
     # '0' is not a valid identifier, so the attribute is read via getattr.
     by_name = getattr(seq, '0')
     assert isinstance(by_index, nn.ReLU)
     assert isinstance(by_name, nn.ReLU)
예제 #3
0
 def test_can_accept_no_element(self):
     """Constructing ``bd.Sequential`` with no arguments must not raise."""
     # Only construction is exercised; the result is intentionally unused.
     _ = bd.Sequential()
예제 #4
0
 def test_single_element_dict(self):
     """A one-entry dict yields one module, reachable by index and by key."""
     name, cls = 'relu', nn.ReLU
     seq = bd.Sequential({name: cls()})
     assert len(seq) == 1
     assert isinstance(seq[0], cls)
     # The dict key doubles as the attribute name on the container.
     assert isinstance(getattr(seq, name), cls)
예제 #5
0
def _build_magic_module(cfg):
    """Build a module (or a plain list/tuple) from a configuration object.

    ``cfg`` is interpreted as follows, checked in this order:

    * an ``nn.Module`` instance is returned unchanged;
    * an empty list/tuple yields a new ``nn.Identity()``;
    * ``['list', ...]`` / ``['tuple', ...]`` yields the remaining items as a
      plain ``list`` / ``tuple`` (the items are not processed here);
    * ``[dict]`` or ``[dict, kwargs]`` yields ``bd.Sequential(dict)``, where
      the optional ``kwargs`` dict may only contain keyword-only arguments
      of ``bd.Sequential``;
    * ``[name, *args]`` with a string ``name`` instantiates the class
      registered under ``name.lower()`` in ``Module._registry``;
    * a list/tuple whose first item is a list, tuple or ``Module``
      instantiates the class registered as ``'sequential'`` with all items.

    For the last two forms a trailing dict gets special treatment:

    1. if every key is a constructor parameter name (positional-or-keyword
       after ``self``, or keyword-only) it is expanded as ``**kwargs``;
    2. otherwise, if every key is in ``_ARGDICT_KEYS`` it is read as
       ``{'kwargs': dict, 'apply_fn': callable, 'add_members': dict}``;
    3. otherwise it is passed along as an ordinary positional argument.

    Args:
        cfg: An ``nn.Module``, or a list/tuple in one of the forms above.

    Returns:
        The given module, a newly built module, or a plain list/tuple for
        the ``'list'`` / ``'tuple'`` forms.

    Raises:
        RuntimeError: If ``cfg`` is neither a module nor a list/tuple, if the
            dict-first form is malformed, if the first element has an
            unsupported type, or if the module name is not registered.
    """
    # cfg can be Module, tuple, list
    if isinstance(cfg, nn.Module):
        return cfg

    if not isinstance(cfg, (tuple, list)):
        raise RuntimeError(
            '[model_cfg] '
            'Model config must be composed of lists, tuples, and Modules. '
            f'Config: {cfg} is of {type(cfg)}')

    # cfg is list or tuple now

    # If empty, return Identity
    if not cfg:
        return nn.Identity()

    head = cfg[0]

    # If first element is 'list' or 'tuple', return corresponding
    if head == 'list':
        return list(cfg[1:])
    if head == 'tuple':
        return tuple(cfg[1:])

    # If the first element is a dict then:
    # 1. It is a sequential
    # 2. It must be the only element OR can have another
    #     element that must be a dict with the kwargs
    #     of bd.Sequential
    if isinstance(head, dict):
        if len(cfg) == 1:
            return bd.Sequential(head)
        if len(cfg) != 2 or not isinstance(cfg[1], dict):
            raise RuntimeError(f'[model_cfg] Invalid module {cfg}')
        seq_argnames = inspect.getfullargspec(bd.Sequential).kwonlyargs
        if not all(x in seq_argnames for x in cfg[1]):
            raise RuntimeError(
                f'[model_cfg] Invalid kwargs in config {cfg}')
        return bd.Sequential(head, **cfg[1])

    # cfg is a list or tuple now and first element is NOT dict

    # If the first element is not allowed raise error
    if not isinstance(head, (str, list, tuple, Module)):
        raise RuntimeError(
            '[model_cfg] '
            'First elements of config lists/tuples must be modules, strings or '
            'list/tuples/dicts in the case of sequential modules. '
            f'Got type {type(head)} for {head}.')

    if isinstance(head, str):
        # Named module: the name is consumed, the rest are its arguments.
        module_name = head.lower()
        rest = cfg[1:]
    else:
        # Implicit sequential: every item (including the first) is an argument.
        module_name = 'sequential'
        rest = cfg

    if module_name not in Module._registry:
        raise RuntimeError(
            f'[model_cfg] Unsupported module {cfg[0]} found in configuration.')

    module_cls = Module._registry[module_name]
    # Accept keyword-only parameters as well as regular ones; otherwise a
    # trailing dict of keyword-only arguments would be silently forwarded as
    # a positional argument. args[0] is skipped because it is ``self``.
    argspec = inspect.getfullargspec(module_cls)
    module_argnames = argspec.args[1:] + argspec.kwonlyargs

    # Build modules
    if not rest:
        # No arguments
        return module_cls()

    tail = rest[-1]
    if not isinstance(tail, dict):
        # Pass all rest items as *args
        return module_cls(*rest)

    # NOTE: an empty trailing dict satisfies this check and is dropped.
    if all(key in module_argnames for key in tail):
        # Final dict keys are ALL module kwargs
        return module_cls(*rest[:-1], **tail)

    if all(key in _ARGDICT_KEYS for key in tail):
        # Final dict is of the form:
        # {'kwargs': dict, 'apply_fn': callable, 'add_members': dict}
        if 'kwargs' in tail:
            module = module_cls(*rest[:-1], **tail['kwargs'])
        else:
            module = module_cls(*rest[:-1])
        if 'apply_fn' in tail:
            module.apply(tail['apply_fn'])
        if 'add_members' in tail:
            for key, val in tail['add_members'].items():
                setattr(module, key, val)
        return module

    # Trailing dict is neither kwargs nor an argdict: treat it as a
    # positional argument like everything else.
    return module_cls(*rest)
예제 #6
0
 def test_can_accept_module_and_returns_it(self):
     """``bd.magic_module`` hands back the very same object for a module."""
     candidates = (nn.ReLU(), nn.Tanh(), bd.Sequential())
     for module in candidates:
         returned = bd.magic_module(module)
         # Identity, not equality: no copy or wrapper may be created.
         assert returned is module