def get_lr_scheduler(optimizer, config, trainer_config=None): """Return a new LR Scheduler.""" trainer_config = trainer_config or {} lr_type, lr_args = parse_spec(config) if lr_type == 'CosineAnnealingLR': if 'T_max' not in lr_args and 'epochs' in trainer_config: lr_args['T_max'] = trainer_config['epochs'] return build(lr_type, optimizer, **lr_args)
def default_constructor(model, construct_config=None, arch_desc=None): """Apply all constructors on model.""" if arch_desc: reg_id, args = parse_spec(construct_config['arch_desc']) args['arch_desc'] = arch_desc construct_config['arch_desc'] = to_spec(reg_id, args) construct_fn = build_constructor_all(construct_config or {}) for name, con_fn in construct_fn.items(): logger.info('Running constructor: {} type: {}'.format(name, con_fn.__class__.__name__)) model = con_fn(model) return model
def get_optimizer(params, config, trainer_config=None): """Return a new Optimizer.""" trainer_config = trainer_config or {} optim_type, optim_args = parse_spec(config) device_ids = trainer_config.get('device', [None]) n_parallel = len(device_ids) if trainer_config.get('scale_lr', True) and 'lr' in optim_args: optim_args['lr'] *= n_parallel optimizer = build(optim_type, params, **optim_args) if n_parallel > 1: optimizer = torch.nn.DataParallel(optimizer, device_ids=device_ids).module return optimizer
def convert(self, slot): """Convert Slot to mixed operator.""" arch_params = self.param_map.get(slot.name, None) old_conf = self.mixed_op_conf reg_id, args = parse_spec(old_conf) if arch_params is not None: args['arch_param'] = arch_params self.mixed_op_conf = to_spec(reg_id, args) ent = super().convert(slot) self.mixed_op_conf = old_conf args.pop('arch_param', None) if slot.name not in self.param_map: self.param_map[slot.name] = ent.arch_param return ent
def run(self, optim): """Run Estimator routine.""" del optim logger = self.logger config = self.config pipeconf = config.pipeline pending = queue.Queue() for pn in pipeconf.keys(): pending.put(pn) finished = set() ret_values, ret = dict(), None while not pending.empty(): pname = pending.get() pconf = pipeconf.get(pname) dep_sat = True for dep in pconf.get('depends', []): if dep not in finished: dep_sat = False break if not dep_sat: pending.put(pname) continue ptype, pargs = parse_spec(pconf) pargs['name'] = pargs.get('name', pname) for inp_kw, inp_idx in pconf.get('inputs', {}).items(): keys = inp_idx.split('.') inp_val = ret_values for k in keys: if not inp_val or k not in inp_val: raise RuntimeError( 'input key {} not found in return {}'.format( inp_idx, ret_values)) inp_val = inp_val[k] pargs[inp_kw] = inp_val logger.info('pipeline: running {}, type={}'.format(pname, ptype)) ret = self.step(pconf) ret_values[pname] = ret logger.info('pipeline: finished {}, results={}'.format(pname, ret)) finished.add(pname) ret_values['final'] = ret logger.info('pipeline: all finished') return ret_values