def get_default_constructors(config):
    """Assemble the default constructor pipeline described by *config*.

    Reads the ``device``/``device_ids``/``arch_desc`` entries, builds an
    ordered mapping of named sub-constructors (init, model, mixed_op,
    arch_desc, device, chkpt), overlays any explicit ``construct`` section,
    and returns a partially-applied ``default_constructor``.
    """
    device_conf = config.get('device', {})
    device_ids = config.get('device_ids', None)
    arch_desc = config.get('arch_desc', None)
    # An explicit device_ids wins; otherwise fall back to the device section.
    if device_ids is None:
        device_ids = device_conf.get('device', device_ids)
    else:
        device_conf['device'] = device_ids
    con_config = OrderedDict()
    con_config['init'] = get_init_constructor(config.get('init', {}), device_ids)
    if 'ops' in config:
        con_config['init']['args']['ops_conf'] = config['ops']
    if 'model' in config:
        con_config['model'] = get_model_constructor(config['model'])
    if 'mixed_op' in config:
        con_config['mixed_op'] = get_mixed_op_constructor(config['mixed_op'])
    if arch_desc is not None:
        con_config['arch_desc'] = get_arch_desc_constructor(arch_desc)
    # Explicit 'construct' entries override the defaults assembled above.
    con_config = utils.merge_config(con_config, config.get('construct', {}))
    if be.is_backend('torch'):
        con_config['device'] = {'type': 'TorchToDevice', 'args': device_conf}
    if config.get('chkpt'):
        con_config['chkpt'] = get_chkpt_constructor(config['chkpt'])
    return partial(default_constructor, construct_config=con_config, arch_desc=arch_desc)
def run(*args, parse=False, **kwargs):
    """Run routine.

    Parses command-line routine arguments when *parse* is requested or when
    called with no arguments at all; explicit keyword arguments override the
    parsed ones.
    """
    should_parse = parse or not (args or kwargs)
    if should_parse:
        parsed_kwargs = utils.merge_config(parse_routine_args(), kwargs)
    else:
        parsed_kwargs = kwargs
    return run_default(*args, **parsed_kwargs)
def load_config(conf):
    """Load one or more configuration sources and merge them in order.

    *conf* may be a single source or a list; later entries override earlier
    ones via ``utils.merge_config``.
    """
    sources = conf if isinstance(conf, list) else [conf]
    merged = None
    for source in sources:
        loaded = Config.load(source)
        merged = loaded if merged is None else utils.merge_config(merged, loaded)
    return merged
def get_data(configs):
    """Return a new dataset built from the merged non-None configs.

    Returns None when every entry in *configs* is None.
    """
    merged = None
    for cfg in configs:
        if cfg is None:
            continue
        merged = cfg if merged is None else merge_config(merged, cfg)
    return None if merged is None else build_dataset(merged)
def init(self):
    """Initialize ModularNAS components and Vega Trainer."""
    trainer = self.trainer
    # Substitute trainer path placeholders into the raw config.
    cfg = _patch_fmt_config(
        self.config, {
            'local_worker_path': trainer.get_local_worker_path(),
            'local_base_path': trainer.local_base_path,
            'local_output_path': trainer.local_output_path,
        })
    # Fill in defaults without clobbering user-provided values.
    for key, default in (('name', 'default'), ('routine', 'search'), ('expman', {})):
        cfg[key] = cfg.get(key, default)
    cfg['expman']['root_dir'] = FileOps.join_path(trainer.get_local_worker_path(), 'exp')
    self.config = merge_config(cfg, self.model.config)
    # init_all returns a context dict (incl. 'estims') merged into this object.
    ctx = init_all(self.config, model=self.model.net)
    self.__dict__.update(ctx)
    self.model.net = list(self.estims.values())[0].model
    if self.optim:
        self.search_alg.set_optim(self.optim)
    self.wrp_trainer = VegaTrainerWrapper(trainer)
    self.wrp_trainer.init()
def dispatch_all(self, merge_ret=False, chain_ret=True, fret=None, is_ret=False):
    """Trigger all delayed event handlers.

    Args:
        merge_ret: when True, merge all non-None handler returns via
            ``merge_config`` instead of keeping only the last one.
        chain_ret: together with *is_ret*, feed each handler's return value
            (or *fret* when it returned None) as the first positional
            argument of the next handler.
        fret: fallback first argument used for chaining when a handler
            returns None.
        is_ret: enables return-value chaining together with *chain_ret*.

    Returns:
        The result for the single queued event, or a dict mapping each event
        to its result when more than one event was queued.
    """
    rets = {}
    # Swap the queue out first so handlers may safely enqueue new events.
    self.event_queue, ev_queue = [], self.event_queue
    for ev, args, kwargs, callback in ev_queue:
        ret = None
        for handler in self.get_handlers(ev):
            hret = handler(*args, **kwargs)
            logger.debug('handler: %s %s' % (handler, hret))
            if chain_ret and is_ret:
                args = (fret if hret is None else hret, ) + args[1:]
            if merge_ret:
                # Fix: a None return previously reset the accumulated merge
                # (the old one-liner fell through to `ret = hret`); now None
                # returns are simply skipped while merging.
                if hret is not None:
                    ret = merge_config(ret, hret)
            else:
                ret = hret
        if callback is not None:
            callback(ret)
        if len(ev_queue) == 1:
            return ret
        rets[ev] = ret
    return rets