Example #1
def get_lr_scheduler(optimizer, config, trainer_config=None):
    """Return a new LR Scheduler."""
    trainer_config = trainer_config or {}
    lr_type, lr_args = parse_spec(config)
    if lr_type == 'CosineAnnealingLR':
        # Default T_max to the trainer's epoch count when the spec omits it.
        if 'T_max' not in lr_args and 'epochs' in trainer_config:
            lr_args['T_max'] = trainer_config['epochs']
    return build(lr_type, optimizer, **lr_args)
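
The helpers `parse_spec` and `build` are used throughout these examples but are not shown. The following is a minimal sketch of what they might look like, assuming a spec is either a plain type name or a dict with 'type' and 'args' keys, and that `build` resolves the name against `torch.optim.lr_scheduler` (both are assumptions, not the library's actual implementation):

import torch


def parse_spec(config):
    # Hypothetical: split a spec into (type name, argument dict).
    if isinstance(config, str):
        return config, {}
    return config['type'], dict(config.get('args', {}))


def build(lr_type, optimizer, **lr_args):
    # Hypothetical: look the scheduler class up by name and instantiate it.
    cls = getattr(torch.optim.lr_scheduler, lr_type)
    return cls(optimizer, **lr_args)


# Usage sketch: T_max falls back to the trainer's epoch count.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = get_lr_scheduler(optimizer, {'type': 'CosineAnnealingLR', 'args': {}},
                             trainer_config={'epochs': 50})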
Example #2
def default_constructor(model, construct_config=None, arch_desc=None):
    """Apply all constructors on model."""
    if arch_desc:
        # Inject the given arch_desc into the corresponding constructor spec.
        reg_id, args = parse_spec(construct_config['arch_desc'])
        args['arch_desc'] = arch_desc
        construct_config['arch_desc'] = to_spec(reg_id, args)
    construct_fn = build_constructor_all(construct_config or {})
    for name, con_fn in construct_fn.items():
        # Each constructor takes the model and returns the (possibly replaced) model.
        logger.info('Running constructor: {} type: {}'.format(name, con_fn.__class__.__name__))
        model = con_fn(model)
    return model
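
Here `build_constructor_all`, `parse_spec`, and `to_spec` come from the surrounding library and are not shown; the essential pattern is a chain of named constructors, each taking the model and returning a (possibly replaced) model. A self-contained sketch of that chaining with toy constructors (all names below are illustrative, not part of the library):

import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('constructor')


def add_head(model):
    # Toy constructor: wrap the model and append a classification head.
    return torch.nn.Sequential(model, torch.nn.Linear(8, 2))


def freeze_all(model):
    # Toy constructor: freeze every parameter, leaving the model in place.
    for p in model.parameters():
        p.requires_grad_(False)
    return model


model = torch.nn.Linear(4, 8)
for name, con_fn in {'head': add_head, 'freeze': freeze_all}.items():
    logger.info('Running constructor: %s', name)
    model = con_fn(model)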
Example #3
def get_optimizer(params, config, trainer_config=None):
    """Return a new Optimizer."""
    trainer_config = trainer_config or {}
    optim_type, optim_args = parse_spec(config)
    device_ids = trainer_config.get('device', [None])
    n_parallel = len(device_ids)
    # Linearly scale the learning rate with the number of devices (enabled by default).
    if trainer_config.get('scale_lr', True) and 'lr' in optim_args:
        optim_args['lr'] *= n_parallel
    optimizer = build(optim_type, params, **optim_args)
    if n_parallel > 1:
        optimizer = torch.nn.DataParallel(optimizer,
                                          device_ids=device_ids).module
    return optimizer
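
A usage sketch of the learning-rate scaling, reusing the `parse_spec` sketch from Example #1 and assuming a `build` that resolves the name against `torch.optim` (again an assumption, not the library's real builder). With N configured devices the learning rate is multiplied by N; the single-device call below leaves it unchanged:

def build(optim_type, params, **optim_args):
    # Hypothetical: look the optimizer class up by name and instantiate it.
    return getattr(torch.optim, optim_type)(params, **optim_args)


model = torch.nn.Linear(4, 2)
optimizer = get_optimizer(model.parameters(), {'type': 'SGD', 'args': {'lr': 0.01}},
                          trainer_config={'device': [0]})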
Example #4
def convert(self, slot):
    """Convert Slot to mixed operator."""
    arch_params = self.param_map.get(slot.name, None)
    old_conf = self.mixed_op_conf
    reg_id, args = parse_spec(old_conf)
    if arch_params is not None:
        # Reuse the architecture parameter of an earlier slot with the same name.
        args['arch_param'] = arch_params
    self.mixed_op_conf = to_spec(reg_id, args)
    ent = super().convert(slot)
    # Restore the original spec so the injected arch_param does not leak
    # into later conversions.
    self.mixed_op_conf = old_conf
    args.pop('arch_param', None)
    if slot.name not in self.param_map:
        self.param_map[slot.name] = ent.arch_param
    return ent
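
The method swaps `self.mixed_op_conf` for a copy that carries the shared `arch_param`, delegates to the parent class's `convert`, then restores the original spec; slots that share a name therefore end up sharing one architecture parameter. A stripped-down, hypothetical sketch of that sharing pattern (class and attribute names below are illustrative only):

class SharedParamConverter:
    # Toy converter: slots converted under the same name share one parameter.

    def __init__(self):
        self.param_map = {}

    def build_mixed_op(self, slot_name, arch_param):
        # Stand-in for super().convert(slot): create the mixed operator.
        return {'slot': slot_name, 'arch_param': arch_param or object()}

    def convert(self, slot_name):
        shared = self.param_map.get(slot_name)
        ent = self.build_mixed_op(slot_name, shared)
        self.param_map.setdefault(slot_name, ent['arch_param'])
        return ent


conv = SharedParamConverter()
a = conv.convert('cell_0')
b = conv.convert('cell_0')
assert a['arch_param'] is b['arch_param']  # same name -> shared parameter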
Example #5
def run(self, optim):
    """Run Estimator routine."""
    del optim
    logger = self.logger
    config = self.config
    pipeconf = config.pipeline
    # Queue every pipeline step; steps whose dependencies are not yet
    # finished are re-queued and retried later.
    pending = queue.Queue()
    for pn in pipeconf.keys():
        pending.put(pn)
    finished = set()
    ret_values, ret = dict(), None
    while not pending.empty():
        pname = pending.get()
        pconf = pipeconf.get(pname)
        dep_sat = True
        for dep in pconf.get('depends', []):
            if dep not in finished:
                dep_sat = False
                break
        if not dep_sat:
            pending.put(pname)
            continue
        ptype, pargs = parse_spec(pconf)
        pargs['name'] = pargs.get('name', pname)
        # Resolve declared inputs from earlier steps' return values,
        # following dotted keys such as 'step.key'.
        for inp_kw, inp_idx in pconf.get('inputs', {}).items():
            keys = inp_idx.split('.')
            inp_val = ret_values
            for k in keys:
                if not inp_val or k not in inp_val:
                    raise RuntimeError(
                        'input key {} not found in return {}'.format(
                            inp_idx, ret_values))
                inp_val = inp_val[k]
            pargs[inp_kw] = inp_val
        logger.info('pipeline: running {}, type={}'.format(pname, ptype))
        # Rebuild the spec with the resolved inputs before running the step.
        ret = self.step(to_spec(ptype, pargs))
        ret_values[pname] = ret
        logger.info('pipeline: finished {}, results={}'.format(pname, ret))
        finished.add(pname)
    ret_values['final'] = ret
    logger.info('pipeline: all finished')
    return ret_values
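
The pipeline config drives this loop: each step may declare `depends` (step names that must finish first) and `inputs` (keyword arguments filled from earlier return values via dotted keys into `ret_values`). A hypothetical config illustrating the wiring, assuming the same 'type'/'args' spec convention as above (step and field names are made up for illustration):

pipeline = {
    'search': {
        'type': 'SearchEstim',          # hypothetical estimator type
        'args': {'epochs': 50},
    },
    'train': {
        'type': 'TrainEstim',           # hypothetical estimator type
        'depends': ['search'],          # runs only after 'search' finishes
        # 'search.best_arch' resolves to ret_values['search']['best_arch'],
        # so the 'search' step is expected to return a dict with that key.
        'inputs': {'arch_desc': 'search.best_arch'},
        'args': {'epochs': 200},
    },
}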