Example #1
from functools import partial

import torch

# OptimWrapper and get_layer_groups are the project's fastai-style optimizer
# utilities; they are assumed to be importable alongside this function.


def build_optimizer(optimizer_config, net, name=None, mixed=False, loss_scale=512.0):
    """Create optimizer based on config.

    Args:
        optimizer_config: An Optimizer proto message.

    Returns:
        An OptimWrapper-wrapped optimizer.

    Raises:
        ValueError: when using an unsupported input data type.
    """
    optimizer_type = optimizer_config.TYPE
    config = optimizer_config.VALUE

    if optimizer_type == "rms_prop_optimizer":
        optimizer_func = partial(
            torch.optim.RMSprop,
            alpha=config.decay,
            momentum=config.momentum_optimizer_value,
            eps=config.epsilon,
        )
    elif optimizer_type == "momentum_optimizer":
        optimizer_func = partial(
            torch.optim.SGD,
            momentum=config.momentum_optimizer_value,
            eps=config.epsilon,
        )
    elif optimizer_type == "adam":
        if optimizer_config.FIXED_WD:
            optimizer_func = partial(
                torch.optim.Adam, betas=(0.9, 0.99), amsgrad=config.amsgrad
            )
        else:
            # regular adam
            optimizer_func = partial(torch.optim.Adam, amsgrad=config.amsgrad)

    optimizer = OptimWrapper.create(
        optimizer_func,
        3e-3,  # placeholder initial LR; typically overridden by the LR schedule
        get_layer_groups(net),
        wd=config.WD,
        true_wd=optimizer_config.FIXED_WD,
        bn_wd=True,
    )

    if optimizer_config.MOVING_AVERAGE:
        raise ValueError("Moving average is not supported in the torch implementation.")

    if name is None:
        # assign a name to the optimizer for the checkpoint system
        optimizer.name = optimizer_type
    else:
        optimizer.name = name

    return optimizer
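
As a rough usage sketch (not part of the original code), the stand-in config below mirrors only the attributes that build_optimizer reads above (TYPE, VALUE, FIXED_WD, MOVING_AVERAGE, and WD on VALUE) and exercises the "adam" branch. The real proto message, OptimWrapper, and get_layer_groups come from the surrounding project and are assumed to be importable.

from types import SimpleNamespace

import torch.nn as nn

# Hypothetical stand-in for the Optimizer proto message; field names mirror
# exactly what build_optimizer reads, nothing more.
adam_cfg = SimpleNamespace(
    TYPE="adam",
    VALUE=SimpleNamespace(amsgrad=False, WD=0.01),
    FIXED_WD=True,
    MOVING_AVERAGE=False,
)

net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
optimizer = build_optimizer(adam_cfg, net)  # requires OptimWrapper / get_layer_groups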
Example #2
def build_one_cycle_optimizer(model, optimizer_config):
    """Create an OptimWrapper-wrapped Adam optimizer for one-cycle training."""
    if optimizer_config.fixed_wd:
        # Adam variant used with decoupled (true_wd) weight decay below.
        optimizer_func = partial(torch.optim.Adam,
                                 betas=(0.9, 0.99),
                                 amsgrad=optimizer_config.amsgrad)
    else:
        # regular Adam with the default betas
        optimizer_func = partial(torch.optim.Adam,
                                 amsgrad=optimizer_config.amsgrad)
    optimizer = OptimWrapper.create(
        optimizer_func,
        3e-3,  # placeholder initial LR; typically overridden by the one-cycle LR schedule
        get_layer_groups(model),
        wd=optimizer_config.wd,
        true_wd=optimizer_config.fixed_wd,
        bn_wd=True,
    )
    return optimizer
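
The one-cycle behaviour itself comes from a separate LR scheduler in the training loop, which is not shown here. As a plain-PyTorch analogue (an assumption, not the project's own scheduler), Adam paired with torch.optim.lr_scheduler.OneCycleLR follows the same pattern:

import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3,
                             betas=(0.9, 0.99), weight_decay=0.01)
# One-cycle schedule: the LR ramps up toward max_lr, then anneals back down.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=3e-3,
                                                total_steps=1000)

for step in range(1000):
    loss = model(torch.randn(4, 8)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the one-cycle LR schedule once per optimizer step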