def test_multi_element_dict(self):
    """A dict of modules is reachable both by position and by its keys."""
    layers = {'relu': nn.ReLU(), 'tanh': nn.Tanh()}
    coll = bd.Sequential(layers)
    assert len(coll) == 2
    expected = [('relu', nn.ReLU), ('tanh', nn.Tanh)]
    # Positional access follows the dict's insertion order.
    for position, (_, klass) in enumerate(expected):
        assert isinstance(coll[position], klass)
    # Each dict key is also exposed as an attribute.
    for attr_name, klass in expected:
        assert isinstance(getattr(coll, attr_name), klass)
def test_single_element(self):
    """A single positional module sits at index 0 under the attribute '0'."""
    seq = bd.Sequential(nn.ReLU())
    assert len(seq) == 1
    first = seq[0]
    assert isinstance(first, nn.ReLU)
    assert isinstance(getattr(seq, '0'), nn.ReLU)
def test_can_accept_no_element(self):
    """Constructing a Sequential with no arguments succeeds and is empty."""
    coll = bd.Sequential()
    # Check the result, not just that construction does not raise.
    assert len(coll) == 0
def test_single_element_dict(self):
    """A one-entry dict yields a module reachable by index and by key."""
    layers = {'relu': nn.ReLU()}
    seq = bd.Sequential(layers)
    assert len(seq) == 1
    by_index = seq[0]
    assert isinstance(by_index, nn.ReLU)
    by_name = getattr(seq, 'relu')
    assert isinstance(by_name, nn.ReLU)
def _build_magic_module(cfg):
    """Build one module (or plain container) from a single config node.

    Accepted forms of ``cfg``:

    * an ``nn.Module`` instance -- returned unchanged;
    * an empty list/tuple -- ``nn.Identity()``;
    * ``['list', *items]`` / ``['tuple', *items]`` -- ``items`` returned as a
      plain ``list`` / ``tuple`` (the items are not built by this function);
    * ``[layers_dict]`` or ``[layers_dict, seq_kwargs]`` --
      ``bd.Sequential(layers_dict, **seq_kwargs)``, where every key of
      ``seq_kwargs`` must be a keyword-only parameter of ``bd.Sequential``;
    * ``['name', *args]`` -- the class registered in ``Module._registry``
      under ``'name'.lower()``, instantiated with ``args``;
    * a list/tuple whose first element is a list, tuple or ``Module`` -- the
      whole sequence is passed as arguments to the class registered as
      ``'sequential'``.

    For the registry-based forms, a trailing dict argument is interpreted as:

    * constructor keyword arguments, when every key names a
      positional-or-keyword or keyword-only parameter of the class;
    * otherwise an "argdict", when every key is in ``_ARGDICT_KEYS``, of the
      form ``{'kwargs': dict, 'apply_fn': callable, 'add_members': dict}``:
      the module is built with ``kwargs``, then ``module.apply(apply_fn)``
      is called, then each ``add_members`` item is set as an attribute;
    * otherwise just another positional argument.

    Args:
        cfg: the config node described above.

    Returns:
        The built module, ``cfg`` itself, or a plain list/tuple.

    Raises:
        RuntimeError: if ``cfg`` or its first element has an unsupported
            type, a dict-form config is malformed, or the module name is not
            registered.
    """
    if isinstance(cfg, nn.Module):
        return cfg
    if not isinstance(cfg, (tuple, list)):
        raise RuntimeError(
            '[model_cfg] '
            'Model config must be composed of lists, tuples, and Modules. '
            f'Config: {cfg} is of {type(cfg)}')
    if not cfg:
        return nn.Identity()

    head = cfg[0]

    # Compare against the container markers only when head really is a
    # string: ``==`` on array-like objects need not return a plain bool.
    if isinstance(head, str):
        if head == 'list':
            return list(cfg[1:])
        if head == 'tuple':
            return tuple(cfg[1:])

    # A leading dict always means a Sequential: either alone, or followed by
    # exactly one dict of bd.Sequential keyword-only arguments.
    if isinstance(head, dict):
        if len(cfg) == 1:
            return bd.Sequential(head)
        if len(cfg) != 2 or not isinstance(cfg[1], dict):
            raise RuntimeError(f'[model_cfg] Invalid module {cfg}')
        seq_argnames = inspect.getfullargspec(bd.Sequential).kwonlyargs
        if not all(x in seq_argnames for x in cfg[1]):
            raise RuntimeError(
                f'[model_cfg] Invalid kwargs in config {cfg}')
        return bd.Sequential(head, **cfg[1])

    # NOTE(review): this checks against ``Module`` while the top of the
    # function checks ``nn.Module``; confirm which base class is intended.
    if not isinstance(head, (str, list, tuple, Module)):
        raise RuntimeError(
            '[model_cfg] '
            'First elements of config lists/tuples must be modules, strings or '
            'list/tuples/dicts in the case of sequential modules. '
            f'Got type {type(head)} for {head}.')

    if isinstance(head, str):
        module_name = head.lower()
        rest = cfg[1:]
    else:
        # No explicit name: the whole sequence becomes Sequential's args.
        module_name = 'sequential'
        rest = cfg

    if module_name not in Module._registry:
        raise RuntimeError(
            f'[model_cfg] Unsupported module {head} found in configuration.')
    module_cls = Module._registry[module_name]

    spec = inspect.getfullargspec(module_cls)
    # ``args[0]`` is ``self``. Keyword-only parameters are legitimate kwargs
    # as well; omitting them would make such a dict be passed positionally.
    module_argnames = set(spec.args[1:]) | set(spec.kwonlyargs)

    if not rest:
        return module_cls()

    last = rest[-1]
    if isinstance(last, dict):
        # An empty trailing dict matches here vacuously and is consumed as
        # "no kwargs".
        if all(key in module_argnames for key in last):
            return module_cls(*rest[:-1], **last)
        if all(key in _ARGDICT_KEYS for key in last):
            module = module_cls(*rest[:-1], **last.get('kwargs', {}))
            if 'apply_fn' in last:
                module.apply(last['apply_fn'])
            for key, val in last.get('add_members', {}).items():
                setattr(module, key, val)
            return module

    # The trailing item is not a recognised dict: pass everything positionally.
    return module_cls(*rest)
def test_can_accept_module_and_returns_it(self):
    """magic_module passes an already-built module through by identity."""
    candidates = [nn.ReLU(), nn.Tanh(), bd.Sequential()]
    for candidate in candidates:
        built = bd.magic_module(candidate)
        assert built is candidate