def build_optimizer(self, clip=15.0, lr=5e-4, warmup=2000,
                    cosine_decay_steps=None,
                    optimizer_name="adabelief") -> GradientTransformation:
    chain = []

    # Base update rule.
    if optimizer_name == "adabelief":
        chain.append(util.scale_by_belief())
    elif optimizer_name == "adam":
        chain.append(optax.scale_by_adam())
    else:
        assert 0, f"Unknown optimizer: {optimizer_name}"

    # Make sure to use the negative learning rate so that we minimize.
    if warmup and warmup > 0:
        warmup_schedule = partial(util.linear_warmup_lr_schedule,
                                  warmup=warmup, lr_decay=1.0, lr=-lr)
        chain.append(optax.scale_by_schedule(warmup_schedule))
    else:
        chain.append(optax.scale(-lr))

    # Optional cosine decay applied on top of the (negative) learning rate.
    if cosine_decay_steps and cosine_decay_steps > 0:
        cosine_lr = optax.cosine_decay_schedule(
            init_value=1.0, decay_steps=cosine_decay_steps, alpha=1e-1)
        chain.append(optax.scale_by_schedule(cosine_lr))

    # Element-wise clipping of the final updates.
    if clip and clip > 0:
        chain.append(optax.clip(clip))

    return optax.chain(*chain)
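# A minimal usage sketch (not from the original module; `model_params` and `grads`
# are assumed to be matching pytrees): the GradientTransformation returned above is
# driven through the standard optax init/update/apply_updates loop. The chain below
# mirrors the "adam" branch with warmup and cosine decay disabled: scale_by_adam,
# then the negative learning rate, then element-wise clipping.
import jax.numpy as jnp
import optax

model_params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros((2,))}
grads = {"w": jnp.ones((3, 2)), "b": jnp.ones((2,))}

tx = optax.chain(optax.scale_by_adam(), optax.scale(-5e-4), optax.clip(15.0))
opt_state = tx.init(model_params)
updates, opt_state = tx.update(grads, opt_state, model_params)
model_params = optax.apply_updates(model_params, updates)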
def test_optimizer_chain(self):
    optimizer = elegy.Optimizer(
        optax.sgd(0.1),
        optax.clip(0.5),
    )

    params = np.zeros(shape=(3, 4))
    grads = np.ones(shape=(3, 4)) * 100_000
    rng = elegy.RNGSeq(42)

    optimizer_states = optimizer.init(
        rng=rng,
        net_params=params,
    )

    params, optimizer_states = optimizer.apply(params, grads, optimizer_states, rng)

    assert np.all(-0.5 <= params) and np.all(params <= 0.5)
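# The same bound checked with plain optax (a sketch, independent of elegy): sgd(0.1)
# already scales by the negative learning rate, and clip(0.5) then bounds each update
# element to [-0.5, 0.5], so parameters starting at zero stay within that range.
import numpy as np
import optax

tx = optax.chain(optax.sgd(0.1), optax.clip(0.5))
params = np.zeros((3, 4))
grads = np.ones((3, 4)) * 100_000
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
assert np.all(-0.5 <= params) and np.all(params <= 0.5)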
def create_train_state(config, rng, learning_rate_fn, example_batch):
    """Create and initialize the model.

    Args:
        config: Configuration for model.
        rng: JAX PRNG Key.
        learning_rate_fn: Learning rate function.
        example_batch: An example batch, used for model initialization.

    Returns:
        The initialized TrainState with the optimizer.
    """
    model, variables = create_model(config, rng, example_batch)
    params = variables['params']
    parameter_overview.log_parameter_overview(params)

    optimizer = optax.adamw(
        learning_rate=learning_rate_fn,
        b1=0.9,
        b2=0.98,
        eps=1e-9,
        weight_decay=config.train.weight_decay)

    # Optionally clip gradients, either by global norm or by absolute value.
    if config.train.grad_max_norm > 0:
        tx = optax.chain(
            optax.clip_by_global_norm(config.train.grad_max_norm), optimizer)
    elif config.train.grad_max_val > 0:
        tx = optax.chain(optax.clip(config.train.grad_max_val), optimizer)
    else:
        tx = optimizer

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=tx,
    )
    return model, state
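# A hypothetical train step (a sketch, not part of the original module): `loss_fn`
# and the batch structure are assumptions, but it shows where the clipped `tx` acts:
# TrainState.apply_gradients runs tx.update (clipping first, then adamw when clipping
# is enabled) and applies the resulting updates to the parameters.
import jax
import jax.numpy as jnp

@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, batch['inputs'])
        return jnp.mean((preds - batch['targets']) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss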
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, RenyiELBO, Trace_ELBO

try:
    import optax

    from numpyro.contrib.optim import optax_to_numpyro

    # The optimizer test is parameterized by different optax optimizers, but we have
    # to define them here to ensure that `optax` is defined. pytest.mark.parametrize
    # decorators are run even if tests are skipped at the top of the file.
    optimizers = [
        (optax.adam, (1e-2,), {}),
        # clipped adam
        (optax.chain, (optax.clip(10.0), optax.adam(1e-2)), {}),
        (optax.adagrad, (1e-1,), {}),
        # SGD with momentum
        (optax.sgd, (1e-2,), {"momentum": 0.9}),
        (optax.rmsprop, (1e-2,), {"decay": 0.95}),
        # RMSProp with momentum
        (optax.rmsprop, (1e-4,), {"decay": 0.9, "momentum": 0.9}),
        (optax.sgd, (1e-2,), {}),
    ]
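# Sketch of how one of these (factory, args, kwargs) triples is presumably consumed:
# instantiate the optax transformation and wrap it so numpyro's SVI can drive it.
# The `model`/`guide` names below are placeholders, not defined here.
factory, args, kwargs = optimizers[1]  # the clipped-adam entry
tx = factory(*args, **kwargs)          # optax.chain(optax.clip(10.0), optax.adam(1e-2))
numpyro_optim = optax_to_numpyro(tx)
# svi = SVI(model, guide, numpyro_optim, Trace_ELBO())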
        (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
            state.params, batch, labels)
        state = state.apply_gradients(grads=grads)
        return state, loss_val, accuracy

    for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
        state, loss_val, accuracy = step(state, batch, labels)
        if i % 100 == 0:
            print(
                f"step {i}/{nb_steps} | loss: {loss_val:.5f} | accuracy: {accuracy*100:.2f}%"
            )

    return params


params = net.init(jax.random.PRNGKey(0), train_data[0])
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1.0,
    warmup_steps=50,
    decay_steps=5_000,
    end_value=0.0,
)
opt = optax.chain(
    optax.clip(1.0),
    # optax.adamw(learning_rate=1e-3)
    optax.adamw(learning_rate=schedule))
params = fit(params, opt)
def create_optimizer(config):
    """Creates the optimizer associated to a config."""
    ops = []

    # Gradient clipping either by norm `gradient_norm_clip` or by absolute value
    # `gradient_value_clip`.
    if "gradient_clip" in config:
        raise ValueError("'gradient_clip' is deprecated, please use "
                         "'gradient_norm_clip'.")
    assert not ("gradient_norm_clip" in config and
                "gradient_value_clip" in config), (
        "Gradient clipping by norm and by value are exclusive.")
    if "gradient_norm_clip" in config:
        ops.append(optax.clip_by_global_norm(config.gradient_norm_clip))
    if "gradient_value_clip" in config:
        ops.append(optax.clip(config.gradient_value_clip))

    # Define the learning rate schedule.
    schedule_fn = utils.get_optax_schedule_fn(
        warmup_ratio=config.get("warmup_ratio", 0.),
        num_train_steps=config.num_train_steps,
        decay=config.get("learning_rate_step_decay", 1.0),
        decay_at_steps=config.get("learning_rate_decay_at_steps", []),
        cosine_decay_schedule=config.get("cosine_decay", False))
    schedule_ops = [optax.scale_by_schedule(schedule_fn)]

    # Scale some parameters matching a regex by a multiplier. Config field
    # `scaling_learning_rate_by_regex` is a list of pairs
    # (regex: str, multiplier: float).
    scaling_by_regex = config.get("scaling_learning_rate_by_regex", [])
    for regex, multiplier in scaling_by_regex:
        logging.info(
            "Learning rate is scaled by %f for parameters matching '%s'",
            multiplier, regex)
        schedule_ops.append(utils.scale_selected_parameters(regex, multiplier))
    schedule_optimizer = optax.chain(*schedule_ops)

    if config.optimizer.lower() == "adam":
        optimizer = optax.adam(config.learning_rate)
        ops.append(optimizer)
        ops.append(schedule_optimizer)
    elif config.optimizer.lower() == "sgd":
        ops.append(schedule_optimizer)
        optimizer = optax.sgd(config.learning_rate, momentum=config.momentum)
        ops.append(optimizer)
    else:
        raise NotImplementedError("Invalid optimizer: {}".format(
            config.optimizer))

    if "weight_decay" in config and config.weight_decay > 0.:
        ops.append(
            utils.decoupled_weight_decay(
                decay=config.weight_decay, step_size_fn=schedule_fn))

    # Freeze parameters that match the given regexes (if any).
    freeze_weights_regexes = config.get("freeze_weights_regex", []) or []
    if isinstance(freeze_weights_regexes, str):
        freeze_weights_regexes = [freeze_weights_regexes]
    for reg in freeze_weights_regexes:
        ops.append(utils.freeze(reg))

    return optax.chain(*ops)
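# A hedged example config (assumed, not from the original repo), covering the fields
# create_optimizer reads; an ml_collections.ConfigDict (or any mapping with attribute
# access and .get()) matches the `in` / `.get()` / attribute usage above.
import ml_collections

config = ml_collections.ConfigDict({
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "num_train_steps": 10_000,
    "warmup_ratio": 0.05,
    "gradient_norm_clip": 1.0,  # exclusive with "gradient_value_clip"
    "weight_decay": 1e-4,
})
# tx = create_optimizer(config)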
            state, loss_val, accuracy = train_step(state, batch, labels)
            train_loss += loss_val
            train_accuracy += accuracy

        test_loss, test_accuracy = 0.0, 0.0
        for batch, labels in test_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            loss_val, accuracy = eval_step(state.params, batch, labels)
            test_loss += loss_val
            test_accuracy += accuracy

        train_loss /= len(train_loader)
        train_accuracy /= len(train_loader)
        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader)

        print(
            f"epoch {i+1}/{nb_epochs} | train: {train_loss:.5f} [{train_accuracy*100:.2f}%]"
            f" | eval: {test_loss:.5f} [{test_accuracy*100:.2f}%]"
        )

    return params


params = net.init(jax.random.PRNGKey(0),
                  jnp.ones((1, *image_shape)).astype(jnp.float32))
opt = optax.chain(optax.clip(1.0), optax.adamw(learning_rate=1e-4))
params = fit(params, opt)
        (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
            params, batch, labels)
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_val, accuracy

    for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
        params, opt_state, loss_val, accuracy = step(params, opt_state, batch, labels)
        if i % 100 == 0:
            print(
                f"step {i}/{nb_steps} | loss: {loss_val:.5f} | accuracy: {accuracy*100:.2f}%"
            )

    return params


# opt = optax.adam(learning_rate=1e-2)
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1.0,
    warmup_steps=50,
    decay_steps=5_000,
    end_value=0.0,
)
opt = optax.chain(optax.clip(1.0), optax.adamw(learning_rate=schedule))
params = fit(params, opt)