Example #1
    def build_optimizer(self,
                        clip=15.0,
                        lr=5e-4,
                        warmup=2000,
                        cosine_decay_steps=None,
                        optimizer_name="adabelief") -> GradientTransformation:
        chain = []
        if optimizer_name == "adabelief":
            chain.append(util.scale_by_belief())
        elif optimizer_name == "adam":
            chain.append(optax.scale_by_adam())
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")

        # Make sure to use the negative learning rate so that we minimize
        if warmup and warmup > 0:
            warmup_schedule = partial(util.linear_warmup_lr_schedule,
                                      warmup=warmup,
                                      lr_decay=1.0,
                                      lr=-lr)
            chain.append(optax.scale_by_schedule(warmup_schedule))
        else:
            chain.append(optax.scale(-lr))

        if cosine_decay_steps and cosine_decay_steps > 0:
            cosine_lr = optax.cosine_decay_schedule(
                init_value=1.0, decay_steps=cosine_decay_steps, alpha=1e-1)
            chain.append(optax.scale_by_schedule(cosine_lr))

        if clip and clip > 0:
            chain.append(optax.clip(clip))

        return optax.chain(*chain)
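
The method returns a bare GradientTransformation, so the caller drives it explicitly. Below is a minimal usage sketch with plain optax pieces standing in for the project-specific `util.scale_by_belief` and `util.linear_warmup_lr_schedule` helpers (those stand-ins are assumptions). Note that the chain negates the learning rate because optax applies updates additively, and that it clips the final, already-scaled update rather than the raw gradient.

import jax.numpy as jnp
import optax

# Minimal sketch of driving a chain like the one built above; plain optax
# components replace the project-specific `util` helpers.
lr, clip_val = 5e-4, 15.0
tx = optax.chain(
    optax.scale_by_adam(),  # Adam-style preconditioning of the raw gradients
    optax.scale(-lr),       # negative learning rate so updates minimize the loss
    optax.clip(clip_val),   # element-wise clip of the final, scaled update
)

params = {"w": jnp.zeros((3, 4))}
grads = {"w": jnp.ones((3, 4))}

opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
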
Example #2
    def test_optimizer_chain(self):

        optimizer = elegy.Optimizer(
            optax.sgd(0.1),
            optax.clip(0.5),
        )

        params = np.zeros(shape=(3, 4))
        grads = np.ones(shape=(3, 4)) * 100_000
        rng = elegy.RNGSeq(42)

        optimizer_states = optimizer.init(
            rng=rng,
            net_params=params,
        )

        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)

        assert np.all(-0.5 <= params) and np.all(params <= 0.5)
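
Since elegy.Optimizer combines the given transformations via optax.chain (as the test name suggests), the assertion can be reproduced with plain optax. A sketch of the equivalent check, where optax.clip(0.5) bounds every element of the scaled update to [-0.5, 0.5] no matter how large the raw gradient is:

import numpy as np
import optax

# Equivalent plain-optax check: sgd(0.1) scales the huge gradients, then
# clip(0.5) bounds each element of the resulting update to [-0.5, 0.5].
tx = optax.chain(optax.sgd(0.1), optax.clip(0.5))
params = np.zeros((3, 4))
grads = np.ones((3, 4)) * 100_000

opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)

assert np.all(np.abs(params) <= 0.5)
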
Example #3
def create_train_state(config, rng, learning_rate_fn, example_batch):
    """Create and initialize the model.

  Args:
    config: Configuration for model.
    rng: JAX PRNG Key.
    learning_rate_fn: learning rate function
    example_batch: for model intialization

  Returns:
    The initialized TrainState with the optimizer.
  """
    model, variables = create_model(config, rng, example_batch)
    params = variables['params']
    parameter_overview.log_parameter_overview(params)

    optimizer = optax.adamw(learning_rate=learning_rate_fn,
                            b1=0.9,
                            b2=.98,
                            eps=1e-9,
                            weight_decay=config.train.weight_decay)

    if config.train.grad_max_norm > 0:
        tx = optax.chain(optax.clip_by_global_norm(config.train.grad_max_norm),
                         optimizer)
    elif config.train.grad_max_val > 1:
        tx = optax.chain(optax.clip(config.train.grad_max_val), optimizer)
    else:
        tx = optimizer

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=tx,
    )
    return model, state
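
A sketch of how the returned TrainState is typically consumed in a train step; `loss_fn(params, batch)` is a hypothetical scalar loss, not part of the code above:

import jax

# Hypothetical train step consuming the TrainState created above. apply_gradients
# runs `tx` (adamw, optionally preceded by clip_by_global_norm or clip) and
# returns a new state with updated params and optimizer state.
@jax.jit
def train_step(state, batch):
    grads = jax.grad(loss_fn)(state.params, batch)
    return state.apply_gradients(grads=grads)
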
Example #4
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, RenyiELBO, Trace_ELBO

try:
    import optax

    from numpyro.contrib.optim import optax_to_numpyro

    # The optimizer test is parametrized by different optax optimizers, but we
    # have to define them here to ensure that `optax` is defined:
    # pytest.mark.parametrize decorators are evaluated even if the tests are
    # skipped at the top of the file.
    optimizers = [
        (optax.adam, (1e-2,), {}),
        # clipped adam
        (optax.chain, (optax.clip(10.0), optax.adam(1e-2)), {}),
        (optax.adagrad, (1e-1,), {}),
        # SGD with momentum
        (optax.sgd, (1e-2,), {"momentum": 0.9}),
        (optax.rmsprop, (1e-2,), {"decay": 0.95}),
        # RMSProp with momentum
        (optax.rmsprop, (1e-4,), {"decay": 0.9, "momentum": 0.9}),
        (optax.sgd, (1e-2,), {}),
    ]
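
Each entry is an (optimizer constructor, positional args, keyword args) triple. A sketch of how one entry would be instantiated and handed to SVI via optax_to_numpyro (`model` and `guide` are hypothetical placeholders; the real parametrized test presumably does something similar):

# Instantiate the clipped-adam entry and wrap it for numpyro's SVI.
optim_class, args, kwargs = optimizers[1]
tx = optim_class(*args, **kwargs)  # optax.chain(optax.clip(10.0), optax.adam(1e-2))
svi = SVI(model, guide, optax_to_numpyro(tx), loss=Trace_ELBO())
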
Example #5
        (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
            state.params, batch, labels)
        state = state.apply_gradients(grads=grads)
        return state, loss_val, accuracy

    for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
        state, loss_val, accuracy = step(state, batch, labels)
        if i % 100 == 0:
            print(
                f"step {i}/{nb_steps} | loss: {loss_val:.5f} | accuracy: {accuracy*100:.2f}%"
            )

    return params


params = net.init(jax.random.PRNGKey(0), train_data[0])
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1.0,
    warmup_steps=50,
    decay_steps=5_000,
    end_value=0.0,
)
opt = optax.chain(
    optax.clip(1.0),
    # optax.adamw(learning_rate=1e-3)
    optax.adamw(learning_rate=schedule))

params = fit(params, opt)
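
Two details worth noting: optax schedules are plain callables from a step count to a scale factor, and in this chain optax.clip(1.0) runs on the raw gradients before adamw (unlike Example #1, where the clip comes last). A small sanity-check sketch of the schedule:

# The schedule maps a step count to a multiplier.
print(schedule(0))      # init_value: 0.0 at the start of warmup
print(schedule(50))     # peak_value: 1.0 once warmup_steps is reached
print(schedule(5_000))  # end_value: back to ~0.0 after decay_steps
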
Example #6
def create_optimizer(config):
    """Creates the optimizer associated to a config."""
    ops = []

    # Gradient clipping either by norm `gradient_norm_clip` or by absolute value
    # `gradient_value_clip`.
    if "gradient_clip" in config:
        raise ValueError("'gradient_clip' is deprecated, please use "
                         "'gradient_norm_clip'.")
    assert not ("gradient_norm_clip" in config
                and "gradient_value_clip" in config), (
                    "Gradient clipping by norm and by value are exclusive.")

    if "gradient_norm_clip" in config:
        ops.append(optax.clip_by_global_norm(config.gradient_norm_clip))
    if "gradient_value_clip" in config:
        ops.append(optax.clip(config.gradient_value_clip))

    # Define the learning rate schedule.
    schedule_fn = utils.get_optax_schedule_fn(
        warmup_ratio=config.get("warmup_ratio", 0.),
        num_train_steps=config.num_train_steps,
        decay=config.get("learning_rate_step_decay", 1.0),
        decay_at_steps=config.get("learning_rate_decay_at_steps", []),
        cosine_decay_schedule=config.get("cosine_decay", False))

    schedule_ops = [optax.scale_by_schedule(schedule_fn)]

    # Scale some parameters matching a regex by a multiplier. Config field
    # `scaling_by_regex` is a list of pairs (regex: str, multiplier: float).
    scaling_by_regex = config.get("scaling_learning_rate_by_regex", [])
    for regex, multiplier in scaling_by_regex:
        logging.info(
            "Learning rate is scaled by %f for parameters matching '%s'",
            multiplier, regex)
        schedule_ops.append(utils.scale_selected_parameters(regex, multiplier))
    schedule_optimizer = optax.chain(*schedule_ops)

    if config.optimizer.lower() == "adam":
        optimizer = optax.adam(config.learning_rate)
        ops.append(optimizer)
        ops.append(schedule_optimizer)
    elif config.optimizer.lower() == "sgd":
        ops.append(schedule_optimizer)
        optimizer = optax.sgd(config.learning_rate, momentum=config.momentum)
        ops.append(optimizer)
    else:
        raise NotImplementedError("Invalid optimizer: {}".format(
            config.optimizer))

    if "weight_decay" in config and config.weight_decay > 0.:
        ops.append(
            utils.decoupled_weight_decay(decay=config.weight_decay,
                                         step_size_fn=schedule_fn))

    # Freeze parameters that match the given regexes (if any).
    freeze_weights_regexes = config.get("freeze_weights_regex", []) or []
    if isinstance(freeze_weights_regexes, str):
        freeze_weights_regexes = [freeze_weights_regexes]
    for reg in freeze_weights_regexes:
        ops.append(utils.freeze(reg))

    return optax.chain(*ops)
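
create_optimizer only reads fields from `config`, so a minimal compatible configuration can be sketched as follows (assuming an ml_collections.ConfigDict, which supports both attribute access and .get(); actually running it still requires the project-specific `utils` helpers):

import ml_collections

# Minimal config sketch; field names are taken from the reads in create_optimizer.
config = ml_collections.ConfigDict({
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "gradient_value_clip": 1.0,   # appended as optax.clip(1.0)
    "num_train_steps": 10_000,
    "warmup_ratio": 0.05,
})
tx = create_optimizer(config)     # an optax GradientTransformation
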
Example #7
    for i in range(nb_epochs):
        train_loss, train_accuracy = 0.0, 0.0
        for batch, labels in train_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            state, loss_val, accuracy = train_step(state, batch, labels)

            train_loss += loss_val
            train_accuracy += accuracy

        test_loss, test_accuracy = 0.0, 0.0
        for batch, labels in test_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            loss_val, accuracy = eval_step(state.params, batch, labels)

            test_loss += loss_val
            test_accuracy += accuracy

        train_loss /= len(train_loader)
        train_accuracy /= len(train_loader)

        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader)

        print(
            f"epoch {i+1}/{nb_epochs} | train: {train_loss:.5f} [{train_accuracy*100:.2f}%] | eval: {test_loss:.5f} [{test_accuracy*100:.2f}%]"
        )

    return params


params = net.init(jax.random.PRNGKey(0),
                  jnp.ones((1, *image_shape)).astype(jnp.float32))
opt = optax.chain(optax.clip(1.0), optax.adamw(learning_rate=1e-4))
params = fit(params, opt)
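
The dummy batch of ones passed to net.init only fixes the input shape; a quick sketch for inspecting the resulting parameter pytree:

import jax

# Print the shape of every array in the freshly initialized parameter pytree.
print(jax.tree_util.tree_map(lambda p: p.shape, params))
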
Example #8
        (loss_val, accuracy), grads = jax.value_and_grad(loss,
                                                         has_aux=True)(params,
                                                                       batch,
                                                                       labels)
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_val, accuracy

    for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
        params, opt_state, loss_val, accuracy = step(params, opt_state, batch,
                                                     labels)
        if i % 100 == 0:
            print(
                f"step {i}/{nb_steps} | loss: {loss_val:.5f} | accuracy: {accuracy*100:.2f}%"
            )

    return params


# opt = optax.adam(learning_rate=1e-2)
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1.0,
    warmup_steps=50,
    decay_steps=5_000,
    end_value=0.0,
)
opt = optax.chain(optax.clip(1.0), optax.adamw(learning_rate=schedule))

params = fit(params, opt)
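
Outside a framework like flax, the chained transformation is driven manually, exactly as the `step` function above does: initialize the optimizer state once from the parameter pytree, then repeatedly turn raw gradients into updates and apply them. A minimal sketch of that pattern (the helper name is illustrative):

# opt_state is created once and threaded through every update.
opt_state = opt.init(params)

def apply_grads(params, opt_state, grads):
    # adamw's decoupled weight decay needs the current params, hence the third argument.
    updates, opt_state = opt.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state
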