from functools import partial

import torch

# OptimWrapper and get_layer_groups are assumed to be provided by this
# project's fastai-style optimizer utilities (imported elsewhere in the module).


def build_optimizer(optimizer_config, net, name=None, mixed=False, loss_scale=512.0):
    """Create an optimizer based on config.

    Args:
        optimizer_config: An Optimizer proto message.
        net: The network whose parameters will be optimized.
        name: Optional name assigned to the optimizer for the checkpoint system.
        mixed, loss_scale: Currently unused.

    Returns:
        An optimizer wrapped in OptimWrapper.

    Raises:
        ValueError: when an unsupported optimizer type or option is given.
    """
    optimizer_type = optimizer_config.TYPE
    config = optimizer_config.VALUE

    if optimizer_type == "rms_prop_optimizer":
        optimizer_func = partial(
            torch.optim.RMSprop,
            alpha=config.decay,
            momentum=config.momentum_optimizer_value,
            eps=config.epsilon,
        )
    elif optimizer_type == "momentum_optimizer":
        # torch.optim.SGD has no `eps` argument; forwarding config.epsilon here
        # would raise a TypeError, so only the momentum value is passed through.
        optimizer_func = partial(
            torch.optim.SGD,
            momentum=config.momentum_optimizer_value,
        )
    elif optimizer_type == "adam":
        if optimizer_config.FIXED_WD:
            # Adam with decoupled (fixed) weight decay, applied by OptimWrapper.
            optimizer_func = partial(
                torch.optim.Adam, betas=(0.9, 0.99), amsgrad=config.amsgrad
            )
        else:
            # Regular Adam; weight decay acts as L2 regularization.
            optimizer_func = partial(torch.optim.Adam, amsgrad=config.amsgrad)
    else:
        raise ValueError("Optimizer %s not supported." % optimizer_type)

    optimizer = OptimWrapper.create(
        optimizer_func,
        3e-3,
        get_layer_groups(net),
        wd=config.WD,
        true_wd=optimizer_config.FIXED_WD,
        bn_wd=True,
    )

    if optimizer_config.MOVING_AVERAGE:
        raise ValueError("torch does not support moving average")

    if name is None:
        # Assign a name to the optimizer for the checkpoint system.
        optimizer.name = optimizer_type
    else:
        optimizer.name = name

    return optimizer
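# Illustrative usage (a minimal sketch, not part of the original code): the real
# optimizer_config is a proto message; the SimpleNamespace below only mimics the
# fields build_optimizer actually reads (TYPE, VALUE.amsgrad, VALUE.WD, FIXED_WD,
# MOVING_AVERAGE), so its exact shape is an assumption. Running this still requires
# the project's OptimWrapper and get_layer_groups to be importable in this module.
def _example_build_optimizer():
    """Hypothetical helper, for illustration only."""
    import torch.nn as nn
    from types import SimpleNamespace

    net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    adam_cfg = SimpleNamespace(
        TYPE="adam",
        VALUE=SimpleNamespace(amsgrad=False, WD=0.01),
        FIXED_WD=True,
        MOVING_AVERAGE=False,
    )
    optimizer = build_optimizer(adam_cfg, net)
    return optimizer  # optimizer.name == "adam"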
def build_one_cycle_optimizer(model, optimizer_config):
    """Create an Adam optimizer wrapped in OptimWrapper for one-cycle training."""
    if optimizer_config.fixed_wd:
        # Adam with decoupled (fixed) weight decay, applied by OptimWrapper.
        optimizer_func = partial(
            torch.optim.Adam, betas=(0.9, 0.99), amsgrad=optimizer_config.amsgrad
        )
    else:
        # Regular Adam; weight decay acts as L2 regularization.
        optimizer_func = partial(torch.optim.Adam, amsgrad=optimizer_config.amsgrad)

    optimizer = OptimWrapper.create(
        optimizer_func,
        3e-3,
        get_layer_groups(model),
        wd=optimizer_config.wd,
        true_wd=optimizer_config.fixed_wd,
        bn_wd=True,
    )
    return optimizer
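# Illustrative usage (a minimal sketch, not part of the original code): the one-cycle
# config is assumed to expose fixed_wd, amsgrad and wd; in practice this optimizer is
# paired with a one-cycle learning-rate schedule defined elsewhere in the codebase.
def _example_build_one_cycle_optimizer():
    """Hypothetical helper, for illustration only."""
    import torch.nn as nn
    from types import SimpleNamespace

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    one_cycle_cfg = SimpleNamespace(fixed_wd=True, amsgrad=False, wd=0.01)
    return build_one_cycle_optimizer(model, one_cycle_cfg)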