def create_optimizer(model, learning_rate, weight_decay, layers=None):
  """Instantiates an Adam optimizer, optionally one per layer group.

  Args:
    model: model (or parameter pytree) to wrap in the optimizer.
    layers: optional sequence of layer-path substrings. When given, a separate
      Adam optimizer is created per layer and combined via MultiOptimizer;
      `learning_rate`/`weight_decay` must then be equal-length sequences.
    learning_rate: float, or per-layer sequence of floats.
    weight_decay: float, or per-layer sequence of floats.

  Returns:
    A flax optimizer wrapping `model`.
  """
  if layers is None:
    # `isinstance` instead of `type(x) == float` (idiomatic, accepts subclasses).
    assert isinstance(learning_rate, float) and isinstance(weight_decay, float), (
        'Specify float values for moded learning rate and weight decay!')
    optimizer_def = optim.Adam(learning_rate=learning_rate,
                               weight_decay=weight_decay)
    optimizer = optimizer_def.create(model)
  else:
    assert (
        len(learning_rate) == len(weight_decay) == len(layers)
    ), 'Number of specified learning rates, weight decays, and layers must be equal!'
    optimizers = []
    for lr, wd, layer in zip(learning_rate, weight_decay, layers):
      # Layers with a non-positive learning rate are frozen (no optimizer).
      if lr > 0:
        opt = optim.Adam(learning_rate=lr, weight_decay=wd)
        filter_fn = functools.partial(path_inclusion_filter_fn, layer=layer)
        traversal = optim.ModelParamTraversal(filter_fn)
        optimizers.append((traversal, opt))
    optimizer_def = optim.MultiOptimizer(*optimizers)
    optimizer = optimizer_def.create(model)
  return optimizer
def create_optimizer(model, learning_rate=1e-4):
  """Create optimizer used for training model.

  A MultiOptimizer applies Adam with weight decay (0.01) to all parameters
  except `layer_norm` and `bias` parameters, which get Adam without weight
  decay.

  Args:
    model: JAX model to add optimizer to
    learning_rate: base learning rate used for initializing optimizer

  Returns:
    optimizer: model with Adam Optimizer to be used for training
  """
  def _needs_decay(path, _):
    # Everything except layer-norm scales/offsets and biases is decayed.
    return 'layer_norm' not in path and 'bias' not in path

  def _skips_decay(path, _):
    return 'layer_norm' in path or 'bias' in path

  decay_adam = optim.Adam(
      learning_rate=learning_rate, eps=1e-6, weight_decay=0.01)
  plain_adam = optim.Adam(
      learning_rate=learning_rate, eps=1e-6, weight_decay=0.0)
  optimizer_def = optim.MultiOptimizer(
      (optim.ModelParamTraversal(_needs_decay), decay_adam),
      (optim.ModelParamTraversal(_skips_decay), plain_adam))
  optimizer = optimizer_def.create(model).replicate()
  del model  # Parameters now live inside the (replicated) optimizer.
  return optimizer
class FlaxOptimizersEquivalenceTest(chex.TestCase):
  """Checks optax optimizers produce the same params as their flax.optim twins."""

  def setUp(self):
    super().setUp()
    # Two-leaf param pytree with deliberately large per-step updates so that
    # implementation differences surface within a few steps.
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd',
       alias.sgd(LR),
       optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop',
       alias.rmsprop(LR),
       optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam',
       alias.adam(LR),
       optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb',
       alias.lamb(LR),
       optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):
    """Runs STEPS updates through both implementations and compares params."""
    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(
          self.per_step_updates)
    flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      # NOTE(review): `update` is presumably optax's `update` module
      # (update.apply_updates) -- confirm the import at the top of the file.
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
def __init__(
    self,
    state_dim,
    action_dim,
    max_action,
    lr=3e-4,
    discount=0.99,
    tau=0.005,
    policy_noise=0.2,
    expl_noise=0.1,
    noise_clip=0.5,
    policy_freq=2,
    seed=0,
):
  """TD3 agent state: actor/critic models, target copies, Adam optimizers.

  Args:
    state_dim: dimensionality of the observation vector.
    action_dim: dimensionality of the action vector.
    max_action: scale used to bound actor outputs.
    lr: Adam learning rate shared by actor and critic.
    discount: reward discount factor (gamma).
    tau: Polyak averaging coefficient for target-network updates.
    policy_noise: std of smoothing noise added to target-policy actions.
    expl_noise: std of exploration noise.
    noise_clip: clipping range for the target-policy noise.
    policy_freq: delayed policy update period (TD3's "twin delay").
    seed: PRNG seed.
  """
  self.rng = PRNGSequence(seed)
  # Dummy batch-of-one input spec used for shape-based initialization.
  actor_input_dim = [((1, state_dim), jnp.float32)]
  init_rng = next(self.rng)
  # Actor and target actor are built with the SAME rng, so they start with
  # identical weights.
  actor = build_td3_actor_model(actor_input_dim, action_dim, max_action,
                                init_rng)
  self.actor_target = build_td3_actor_model(actor_input_dim, action_dim,
                                            max_action, init_rng)
  actor_optimizer = optim.Adam(learning_rate=lr).create(actor)
  self.actor_optimizer = jax.device_put(actor_optimizer)
  init_rng = next(self.rng)
  critic_input_dim = [
      ((1, state_dim), jnp.float32),
      ((1, action_dim), jnp.float32),
  ]
  # Same rng trick: critic and its target begin identical.
  critic = build_td3_critic_model(critic_input_dim, init_rng)
  self.critic_target = build_td3_critic_model(critic_input_dim, init_rng)
  critic_optimizer = optim.Adam(learning_rate=lr).create(critic)
  self.critic_optimizer = jax.device_put(critic_optimizer)
  self.max_action = max_action
  self.discount = discount
  self.tau = tau
  self.policy_noise = policy_noise
  self.expl_noise = expl_noise
  self.noise_clip = noise_clip
  self.policy_freq = policy_freq
  self.total_it = 0  # number of training iterations performed so far
def create_optimizer(model, model_kwargs, learning_rate=1e-4):
  """Create optimizer used for training model.

  MultiOpt is used to apply Adam/LAMB Optimizer with weight decay to all
  parameters except layer_norm and bias and Adam/LAMB Optimizer without
  weight decay for layer_norm and bias params.

  Args:
    model: JAX model to add optimizer to
    model_kwargs: Bert model config parameter dictionary.
    learning_rate: base learning rate used for initializing optimizer

  Returns:
    optimizer: model with Adam/LAMB Optimizer to be used for training
  """
  if FLAGS.use_lamb:
    weight_decay_def = bert_lamb.BertLAMB(
        learning_rate=learning_rate,
        beta1=FLAGS.lamb_beta_1,
        beta2=FLAGS.lamb_beta_2,
        eps=10**FLAGS.log_epsilon,
        weight_decay=FLAGS.lamb_weight_decay,
        num_layers=model_kwargs['num_layers'])
    no_decay_def = bert_lamb.BertLAMB(
        learning_rate=learning_rate,
        beta1=FLAGS.lamb_beta_1,
        beta2=FLAGS.lamb_beta_2,
        eps=10**FLAGS.log_epsilon,
        weight_decay=0.0,
        num_layers=model_kwargs['num_layers'])
  else:
    # NOTE(review): the Adam branch also reads FLAGS.lamb_weight_decay for its
    # decay strength -- confirm this flag reuse is intentional.
    weight_decay_def = optim.Adam(
        learning_rate=learning_rate,
        eps=1e-6,
        weight_decay=FLAGS.lamb_weight_decay)
    no_decay_def = optim.Adam(
        learning_rate=learning_rate, eps=1e-6, weight_decay=0.0)

  def filter_weight_decay(key, _):
    # Decay everything except layer-norm scales/offsets and biases; both
    # 'layer_norm' and 'layernorm' spellings appear in parameter paths.
    return 'layer_norm' not in key and 'bias' not in key and 'layernorm' not in key

  def filter_other(key, _):
    return 'layer_norm' in key or 'bias' in key or 'layernorm' in key

  weight_decay_traversal = optim.ModelParamTraversal(filter_weight_decay)
  no_decay_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer(
      (weight_decay_traversal, weight_decay_def),
      (no_decay_traversal, no_decay_def))
  optimizer = optimizer_def.create(model)
  # Replicate across local devices for pmapped training.
  optimizer = jax_utils.replicate(optimizer)
  del model
  return optimizer
def __init__(
    self,
    state_dim: int,
    action_dim: int,
    max_action: float,
    lr: float = 3e-4,
    discount: float = 0.99,
    tau: float = 0.005,
    policy_noise: float = 0.2,
    expl_noise: float = 0.1,
    noise_clip: float = 0.5,
    policy_freq: int = 2,
    seed: int = 0,
):
  """TD3 agent: actor/critic parameters, target copies, Adam optimizers.

  Args:
    state_dim: observation dimensionality.
    action_dim: action dimensionality.
    max_action: scale used to bound actor outputs.
    lr: Adam learning rate shared by actor and critic.
    discount: reward discount factor (gamma).
    tau: Polyak averaging coefficient for target-network updates.
    policy_noise: std of smoothing noise added to target-policy actions.
    expl_noise: std of exploration noise.
    noise_clip: clipping range for the target-policy noise.
    policy_freq: delayed policy update period.
    seed: PRNG seed.
  """
  self.rng = PRNGSequence(seed)
  # Batch-of-one dummy shape used for shape-based parameter initialization.
  actor_input_dim = (1, state_dim)
  init_rng = next(self.rng)
  # Actor and target actor are built from the same rng, so they start equal.
  actor_params = build_td3_actor_model(
      actor_input_dim, action_dim, max_action, init_rng
  )
  self.actor_target_params = build_td3_actor_model(
      actor_input_dim, action_dim, max_action, init_rng
  )
  actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
  self.actor_optimizer = jax.device_put(actor_optimizer)
  init_rng = next(self.rng)
  critic_input_dim = [(1, state_dim), (1, action_dim)]
  # Same rng trick: critic and its target begin identical.
  critic_params = build_td3_critic_model(critic_input_dim, init_rng)
  self.critic_target_params = build_td3_critic_model(critic_input_dim, init_rng)
  critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
  self.critic_optimizer = jax.device_put(critic_optimizer)
  self.max_action = max_action
  self.discount = discount
  self.tau = tau
  self.policy_noise = policy_noise
  self.expl_noise = expl_noise
  self.noise_clip = noise_clip
  self.policy_freq = policy_freq
  self.action_dim = action_dim
  self.total_it = 0  # number of training iterations performed so far
def create_optimizer(config, model):
  """Builds a MultiOptimizer: Adam with weight decay for everything but biases.

  Two Adam definitions (decay 0.01 and 0.0) are combined, split on whether
  'bias' appears in the parameter path.
  """
  adam_kwargs = dict(
      learning_rate=config.learning_rate,
      beta1=0.9,
      beta2=0.999,
      eps=1e-6,
  )
  decayed = optim.Adam(weight_decay=0.01, **adam_kwargs)
  undecayed = optim.Adam(weight_decay=0.0, **adam_kwargs)
  bias_free = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
  bias_only = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
  multi_def = optim.MultiOptimizer((bias_free, decayed),
                                   (bias_only, undecayed))
  return multi_def.create(model)
def create_optimizer(name='adam', learning_rate=6.25e-5, beta1=0.9,
                     beta2=0.999, eps=1.5e-4):
  """Returns a flax Adam optimizer definition; only 'adam' is supported.

  Raises:
    ValueError: if `name` is not 'adam'.
  """
  if name != 'adam':
    raise ValueError(f'Unknown optimizer {name}')
  return optim.Adam(
      learning_rate=learning_rate, beta1=beta1, beta2=beta2, eps=eps)
def main(argv):
  """Trains a VAE on MNIST, saving reconstruction/sample images per epoch."""
  key = random.PRNGKey(0)
  train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.cache().shuffle(1000).batch(FLAGS.batch_size)
  # Whole test split as a single in-memory batch (batch_size=-1).
  test_ds = tfds.as_numpy(
      tfds.load('mnist', split=tfds.Split.TEST, batch_size=-1))
  _, params = VAE.init_by_shape(key, [((1, 784), jnp.float32)])
  vae = nn.Model(VAE, params)
  optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)
  for epoch in range(FLAGS.num_epochs):
    for batch in tfds.as_numpy(train_ds):
      # Flatten 28x28 images and scale pixel values to [0, 1].
      batch['image'] = batch['image'].reshape(-1, 784) / 255.0
      optimizer = train_step(optimizer, batch)
    # 64 latent draws for qualitative samples; 20 presumably matches the VAE
    # latent dimension -- TODO confirm against the model definition.
    z = np.random.normal(size=(64, 20))
    # NOTE: `eval` is a project-level function shadowing the builtin.
    metrics, comparison, sample = eval(optimizer.target, test_ds, z)
    save_image(comparison, 'results/reconstruction_' + str(epoch) + '.png',
               nrow=8)
    save_image(sample, 'results/sample_' + str(epoch) + '.png', nrow=8)
    print("eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}".format(
        epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))
def train(rho_g, nn_params):
  """Runs the Adam training loop for `n_epochs` x `n_batches` steps.

  Relies on names from the enclosing scope: `lr`, `w_decay`, `n_epochs`,
  `n_batches`, `batches`, `init_params`, `train_step`.

  Args:
    rho_g: data object forwarded to `train_step` (semantics defined there).
    nn_params: initial network parameters to optimize.

  Returns:
    Tuple of (final params, last (loss, grad) pair, list of per-step losses).
  """
  optimizer = optim.Adam(learning_rate=lr, weight_decay=w_decay).create(nn_params)
  optimizer = jax.device_put(optimizer)
  train_loss = []
  # NOTE(review): loss0, loss0_tot, itercount and f_params appear unused in
  # this function -- possibly leftovers from a previous version.
  loss0 = 1E16
  loss0_tot = 1E16
  itercount = itertools.count()
  f_params = init_params
  for epoch in range(n_epochs):
    for _ in range(n_batches):
      optimizer, loss_and_grad = train_step(optimizer, rho_g, next(batches))
      loss, grad = loss_and_grad
      # f = open(f_out,'a+')
      # print(i,loss,file=f)
      # f.close()
      train_loss.append(loss)
      # params = optimizer.target
      # loss_tot = f_validation(params)
  # NOTE(review): `loss_and_grad` is unbound if n_epochs or n_batches is 0.
  nn_params = optimizer.target
  return nn_params, loss_and_grad, train_loss
def train():
  """Run main training loop.

  Trains a GNN node classifier on the karate-club graph with Adam (lr=0.01)
  for 20 iterations, printing loss and accuracy each step.
  """
  rng = random.PRNGKey(0)
  # Get Zachary's karate club graph dataset.
  node_feats, node_labels, sources, targets = get_karate_club_data()
  # Create model and optimizer.
  _, initial_params = GNN.init(
      rng, node_x=node_feats, edge_x=None, sources=sources, targets=targets)
  model = nn.Model(GNN, initial_params)
  optimizer = optim.Adam(learning_rate=0.01).create(model)
  # Train for 20 iterations.
  for iteration in range(20):
    optimizer, loss = train_step(optimizer, node_feats, sources, targets)
    accuracy = eval_step(  # Model is stored in `optimizer.target`.
        optimizer.target, node_feats, sources, targets, node_labels)
    print('iteration: %d, loss: %.4f, accuracy: %.2f'
          % (iteration + 1, loss, accuracy * 100))
def create_optimizer(name='adam',
                     learning_rate=6.25e-5,
                     beta1=0.9,
                     beta2=0.999,
                     eps=1.5e-4):
  """Builds a flax optimizer definition for training.

  Supports Adam and RMSProp (the original docstring claimed Adam only).

  Args:
    name: str, optimizer to create ('adam' or 'rmsprop').
    learning_rate: float, learning rate to use in the optimizer.
    beta1: float, beta1 parameter (Adam only).
    beta2: float, beta2 parameter.
    eps: float, epsilon parameter.

  Returns:
    A flax optimizer definition.

  Raises:
    ValueError: if `name` is not a supported optimizer.
  """
  if name == 'adam':
    logging.info(
        'Creating Adam optimizer with settings lr=%f, beta1=%f, '
        'beta2=%f, eps=%f', learning_rate, beta1, beta2, eps)
    return optim.Adam(
        learning_rate=learning_rate, beta1=beta1, beta2=beta2, eps=eps)
  if name == 'rmsprop':
    logging.info(
        'Creating RMSProp optimizer with settings lr=%f, beta2=%f, '
        'eps=%f', learning_rate, beta2, eps)
    return optim.RMSProp(learning_rate=learning_rate, beta2=beta2, eps=eps)
  raise ValueError('Unsupported optimizer {}'.format(name))
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparamters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  """Creates a device-replicated Adam optimizer plus a pmapped train step.

  Note: `spec.Hyperparamters` [sic] is the type name as declared by the spec
  module.

  Returns:
    (optimizer, p_train_step): the replicated flax optimizer and the compiled
    multi-device train-step function.
  """
  del model_state
  del rng
  del workload
  optimizer_def = optim.Adam(
      learning_rate=hyperparameters.learning_rate,
      # beta1 is parameterized as (1 - one_minus_beta_1) in the hyperparams.
      beta1=1.0 - hyperparameters.one_minus_beta_1,
      beta2=0.98,
      eps=hyperparameters.epsilon)
  optimizer = optimizer_def.create(model_params)
  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)
  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=hyperparameters.learning_rate, warmup_steps=1000)
  # compile multidevice versions of train.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=models.TransformerConfig(
              dropout_rate=hyperparameters.dropout_rate,
              attention_dropout_rate=hyperparameters.attention_dropout_rate),
          learning_rate_fn=learning_rate_fn),
      axis_name="batch",
      # Donate the optimizer buffers to avoid a copy per step.
      donate_argnums=(0,))
  return optimizer, p_train_step
def load_model(dataset_name, attention_mask_type, use_relative_attention,
               bos_special_attention, predict_config):
  """Loads a checkpoint.

  Builds a DecomposeAttentionTransformer, initializes dummy parameters, then
  restores a checkpoint whose directory name encodes the training
  hyper-parameters, and replicates the result for multi-device use.

  Returns:
    The restored, replicated optimizer (parameters in `optimizer.target`).
  """
  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  m = models.DecomposeAttentionTransformer(predict_config)
  # Shape-based init with dummy inputs; `io_shape` and `program_shape` come
  # from the enclosing module scope.
  initial_variables = jax.jit(m.init)({
      'params': init_rng,
      'dropout': init_rng
  }, jnp.ones(io_shape, jnp.float32),
                                      jnp.ones(io_shape, jnp.float32),
                                      jnp.ones(program_shape, jnp.float32))
  # Optimizer hyper-parameters only need to match training for the restore to
  # preserve optimizer slot variables.
  optimizer_def = optim.Adam(
      1e-3, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=1e-1)
  optimizer = optimizer_def.create(initial_variables['params'])
  # Directory name encodes training settings (amt/bsa/ura etc.).
  checkpoint_fname = os.path.join(
      FLAGS.train_directory, 'train-{}/checkpoints/'
      'amt={},bsa={},ed=256,hd=512,l=0.001,nh=4,nl=3,s=0,ura={}/'.format(
          dataset_name, attention_mask_type, bos_special_attention,
          use_relative_attention))
  logging.info('Loading checkpoint: %s', checkpoint_fname)
  optimizer = checkpoints.restore_checkpoint(checkpoint_fname, optimizer)
  checkpoint_num_trained_steps = int(optimizer.state.step)
  logging.info('Found model checkpointed at step %s.',
               checkpoint_num_trained_steps)
  optimizer = jax_utils.replicate(optimizer)
  return optimizer
def __init__(
    self,
    state_dim: int,
    action_dim: int,
    max_action: float,
    discount: float = 0.99,
    tau: float = 0.005,
    policy_freq: int = 2,
    lr: float = 3e-4,
    entropy_tune: bool = True,
    seed: int = 0,
):
  """SAC agent: gaussian actor, double critic (+target), learnable alpha.

  Args:
    state_dim: observation dimensionality.
    action_dim: action dimensionality.
    max_action: actor output scale.
    discount: reward discount factor (gamma).
    tau: Polyak averaging coefficient for the target critic.
    policy_freq: actor/target update period in critic steps.
    lr: Adam learning rate shared by actor, critic, and log-alpha.
    entropy_tune: whether to adapt the entropy temperature alpha.
    seed: PRNG seed.
  """
  self.rng = PRNGSequence(seed)
  # Batch-of-one dummy shape for shape-based initialization.
  actor_input_dim = (1, state_dim)
  actor_params = build_gaussian_policy_model(actor_input_dim, action_dim,
                                             max_action, next(self.rng))
  actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
  self.actor_optimizer = jax.device_put(actor_optimizer)
  init_rng = next(self.rng)
  critic_input_dim = [(1, state_dim), (1, action_dim)]
  # Critic and its target are built from the same rng so they start identical.
  critic_params = build_double_critic_model(critic_input_dim, init_rng)
  self.critic_target_params = build_double_critic_model(
      critic_input_dim, init_rng)
  critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
  self.critic_optimizer = jax.device_put(critic_optimizer)
  self.entropy_tune = entropy_tune
  # log(alpha) modeled as a trainable constant, initialized to -3.5.
  log_alpha_params = build_constant_model(-3.5, next(self.rng))
  log_alpha_optimizer = optim.Adam(
      learning_rate=lr).create(log_alpha_params)
  self.log_alpha_optimizer = jax.device_put(log_alpha_optimizer)
  # Standard SAC target entropy heuristic: -|A|.
  self.target_entropy = -action_dim
  self.max_action = max_action
  self.discount = discount
  self.tau = tau
  self.policy_freq = policy_freq
  self.action_dim = action_dim
  self.total_it = 0  # number of training iterations performed so far
def create_optimizer(model, learning_rate, weight_decay):
  """Wraps `model` in Adam with transformer-style betas (0.9/0.98), eps 1e-9."""
  adam_def = optim.Adam(
      learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=weight_decay)
  return adam_def.create(model)
def create_optimizer(config, model):
  """Creates an Adam optimizer (no weight decay) around `model`."""
  adam_def = optim.Adam(
      learning_rate=config.learning_rate,
      beta1=0.9,
      beta2=0.999,
      eps=1e-6,
      weight_decay=0.0)
  return adam_def.create(model)
def create_optimizer(model, learning_rate, weight_decay):
  """Returns `model` wrapped in an Adam optimizer with the given settings."""
  adam_def = optim.Adam(
      learning_rate=learning_rate, weight_decay=weight_decay)
  return adam_def.create(model)
def create_optimizer(model, learning_rate):
  """Adam (betas 0.9/0.98, eps 1e-9, FLAGS.weight_decay), device-replicated."""
  adam_def = optim.Adam(
      learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  return jax_utils.replicate(adam_def.create(model))
def init_fn(seed):
  """Builds a fresh MLP classifier and wraps it in an Adam(1e-4) optimizer.

  Uses `n_classes` and `input_shape` from the enclosing scope; the dummy
  init batch size is 128.
  """
  prng = random.PRNGKey(seed)
  classifier = MLPClassifier.partial(
      hidden_layers=2, hidden_dim=512, n_classes=n_classes)
  _, initial_params = classifier.init_by_shape(prng, [(128, *input_shape)])
  return optim.Adam(1e-4).create(nn.Model(classifier, initial_params))
def load_parameters(logdir, init_params):
  """Restores parameters from the latest checkpoint in `logdir`, if any.

  Returns the restored parameter pytree, or None when no checkpoint exists.
  """
  if not has_checkpoint(logdir):
    print("No checkpoint found in %s" % logdir)
    return None
  print("Loading checkpoint from %s" % logdir)
  # An Adam wrapper is only used here as the container checkpoints were
  # saved with; its hyper-parameters are irrelevant for restore.
  optimizer = optim.Adam().create(init_params)
  optimizer = checkpoints.restore_checkpoint(logdir, optimizer)
  print("Checkpoint loaded from step %d" % optimizer.state.step)
  return optimizer.target
def create_model_optimizer(n_bins):
  """Builds a 1-D ResNet-50-style model and its Adam optimizer.

  Relies on module-level `training_data` (for the input length) and
  `learning_rate`.

  Args:
    n_bins: number of output bins passed to the ResNet module.

  Returns:
    (model, optimizer) pair.
  """
  ResNet50 = ResNet.partial(stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)
  module = ResNet50.partial(n_bins=n_bins, dtype=jnp.float32)
  # One example, `training_data.shape[1]` samples, a single channel.
  input_shape = (1, training_data.shape[1], 1)
  # nn.stateful captures mutable collections (e.g. batch stats); the captured
  # `init_state` is not returned -- NOTE(review): confirm that is intentional.
  with nn.stateful() as init_state:
    _, initial_params = module.init_by_shape(
        jax.random.PRNGKey(0), [(input_shape, jnp.float32)]
    )
  model = nn.Model(module, initial_params)
  optimizer = optim.Adam(learning_rate=learning_rate).create(model)
  return model, optimizer
def create_optimizer(config, model, initial_params):
  """Create a model, starting with a pre-trained checkpoint.

  NOTE(review): `initial_params` is unused in this function -- confirm.
  """
  common_kwargs = dict(
      learning_rate=config.learning_rate,
      beta1=0.9,
      beta2=0.999,
      eps=1e-6,
  )
  optimizer_decay_def = optim.Adam(weight_decay=0.01, **common_kwargs)
  optimizer_no_decay_def = optim.Adam(weight_decay=0.0, **common_kwargs)
  decay = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
  no_decay = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
  optimizer_def = optim.MultiOptimizer((decay, optimizer_decay_def),
                                       (no_decay, optimizer_no_decay_def))
  # TODO(marcvanzee): MultiOptimizer triggers double XLA compilation on TPU so
  # we use Adam here, but we should investigate why this happens.
  # (The MultiOptimizer constructed above is therefore dead code for now.)
  optimizer_def = optim.Adam(learning_rate=config.learning_rate)
  optimizer = optimizer_def.create(model)
  optimizer = optimizer.replicate()
  del model  # don't keep a copy of the initial model
  return optimizer
def create_train_fn(model, model_dir, duration, batch, train_steps,
                    learning_rate):
  """Builds a `train(data_path)` closure: restore, pmapped train, checkpoint.

  Args:
    model: flax model whose output is regressed onto per-step rewards.
    model_dir: directory used for checkpoint restore/save.
    duration: sequence length passed to the dataset loader.
    batch: global batch size (reshaped across local devices).
    train_steps: number of pmapped steps per `train()` call.
    learning_rate: learning rate passed to `apply_gradient` every step.

  Returns:
    The `train` function.
  """
  optimizer = optim.Adam()
  opt = optimizer.create(model)
  state = TrainState(optimizer=opt, step=0)  # pytype:disable=wrong-keyword-args
  state = checkpoints.restore_checkpoint(model_dir, state)
  state = jax_utils.replicate(state)
  iterator = None  # Dataset iterator is created lazily on the first call.

  @functools.partial(jax.pmap, axis_name="batch")
  def train_step(obs, state):
    # One synchronized SGD step: MSE loss, cross-device gradient mean.
    actions = obs["action"]
    rewards = obs["reward"]
    step = state.step
    optimizer = state.optimizer

    def loss(model):
      predictions = model(actions)
      l = (rewards - predictions)**2
      l = jnp.mean(l)
      return l

    grad_fn = jax.value_and_grad(loss)
    l, grads = grad_fn(state.optimizer.target)
    # Average gradients across devices before applying.
    grads = lax.pmean(grads, axis_name="batch")
    new_optimizer = optimizer.apply_gradient(grads,
                                             learning_rate=learning_rate)
    new_state = state.replace(step=step + 1, optimizer=new_optimizer)
    return new_state, l

  def train(data_path):
    nonlocal iterator
    nonlocal state
    if iterator is None:
      dataset = npz.load_dataset_from_directory(data_path, duration, batch)
      iterator = dataset.make_one_shot_iterator()
      # Reshape each array to (local_device_count, per_device_batch, ...)
      # as required by pmap.
      iterator = map(
          lambda x: jax.tree_map(
              lambda x: np.reshape(x, (jax.local_device_count(), -1) + x.
                                   numpy().shape[1:]), x), iterator)
      iterator = jax_utils.prefetch_to_device(iterator, 2)
    for _ in range(train_steps):
      obs = next(iterator)
      state, l = train_step(obs, state)
    # Save an unreplicated (single-device) copy of the state.
    local_state = get_first_device(state)
    l = get_first_device(l)
    checkpoints.save_checkpoint(model_dir, local_state, local_state.step)

  return train
def make_optimizer(optimizer: str, learning_rate: float,
                   weight_decay: float = 5.0e-4):
  """Creates a flax optimizer definition by name.

  Args:
    optimizer: one of 'SGD', 'Momentum', 'Adam'.
    learning_rate: learning rate for the optimizer.
    weight_decay: weight decay (Momentum and Adam only).

  Returns:
    A flax optimizer definition.

  Raises:
    ValueError: on an unrecognized optimizer name.
  """
  if optimizer == 'SGD':
    # Bug fix: was `optim.Optimizer(lerning_rate=...)` -- `optim.Optimizer`
    # is the generic wrapper class (not an optimizer definition) and the
    # keyword was misspelled; plain SGD in flax.optim is GradientDescent.
    return optim.GradientDescent(learning_rate=learning_rate)
  elif optimizer == 'Momentum':
    return optim.Momentum(learning_rate=learning_rate,
                          weight_decay=weight_decay)
  elif optimizer == 'Adam':
    return optim.Adam(learning_rate=learning_rate, weight_decay=weight_decay)
  else:
    raise ValueError('Unknown optimizer spec.')
def test_init_state(self):
  """Adam stores its hyper-params and starts with zeroed moment estimates."""
  params = onp.zeros((1,))
  optimizer_def = optim.Adam(learning_rate=0.1,
                             beta1=0.2,
                             beta2=0.9,
                             eps=0.01,
                             weight_decay=0.0)
  state = optimizer_def.init_state(params)
  # Hyper-param order: (learning_rate, beta1, beta2, eps, weight_decay).
  expected_hyper_params = _AdamHyperParams(0.1, 0.2, 0.9, 0.01, 0.0)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  # Step count 0 with first/second moments initialized to zeros like params.
  expected_state = optim.OptimizerState(
      0, _AdamParamState(onp.zeros((1,)), onp.zeros((1,))))
  self.assertEqual(state, expected_state)
def main(argv): del argv # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') rng = random.PRNGKey(0) rng, key = random.split(rng) ds_builder = tfds.builder('binarized_mnist') ds_builder.download_and_prepare() train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN) train_ds = train_ds.map(prepare_image) train_ds = train_ds.cache() train_ds = train_ds.repeat() train_ds = train_ds.shuffle(50000) train_ds = train_ds.batch(FLAGS.batch_size) train_ds = iter(tfds.as_numpy(train_ds)) test_ds = ds_builder.as_dataset(split=tfds.Split.TEST) test_ds = test_ds.map(prepare_image).batch(10000) test_ds = np.array(list(test_ds)[0]) test_ds = jax.device_put(test_ds) module = VAE.partial(latents=FLAGS.latents) _, params = module.init_by_shape(key, [(FLAGS.batch_size, 784)], z_rng=random.PRNGKey(0)) vae = nn.Model(module, params) optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae) optimizer = jax.device_put(optimizer) rng, z_key, eval_rng = random.split(rng, 3) z = random.normal(z_key, (64, FLAGS.latents)) steps_per_epoch = 50000 // FLAGS.batch_size for epoch in range(FLAGS.num_epochs): for _ in range(steps_per_epoch): batch = next(train_ds) rng, key = random.split(rng) optimizer = train_step(optimizer, batch, key) metrics, comparison, sample = eval(optimizer.target, test_ds, z, eval_rng) save_image(comparison, f'results/reconstruction_{epoch}.png', nrow=8) save_image(sample, f'results/sample_{epoch}.png', nrow=8) print('eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format( epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))
def train_model():
  """Train for a fixed number of steps and decode during training.

  Returns the trained model (`optimizer.target`).
  """
  params = get_param(jax.random.PRNGKey(0))
  optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(params)
  rng = jax.random.PRNGKey(0)
  for step in range(FLAGS.num_train_steps):
    rng, lstm_key = jax.random.split(rng)
    batch = get_batch(FLAGS.batch_size)
    optimizer, metrics = train_step(optimizer, batch, lstm_key)
    # Periodically log metrics and decode a few samples.
    if step % FLAGS.decode_frequency == 0:
      rng, decode_key = jax.random.split(rng)
      logging.info('train step: %d, loss: %.4f, accuracy: %.2f',
                   step, metrics['loss'], metrics['accuracy'] * 100)
      decode_batch(optimizer.target, 5, decode_key)
  return optimizer.target
def create_optimizer(model, learning_rate):
  """Wraps `model` in a plain Adam optimizer.

  NOTE: an earlier MultiOptimizer variant (decoupled weight decay for all
  non-bias parameters) was disabled here; plain Adam is used instead.
  """
  return optim.Adam(learning_rate=learning_rate).create(model)
def test_apply_gradient(self):
  """One Adam step matches hand-computed moment updates and parameter value."""
  optimizer_def = optim.Adam(learning_rate=0.1,
                             beta1=0.2,
                             beta2=0.9,
                             eps=0.01,
                             weight_decay=0.0)
  params = onp.array([1.])
  # Start from step 1 with non-zero moment estimates (0.1, 0.9).
  state = optim.OptimizerState(
      1, _AdamParamState(onp.array([0.1]), onp.array([0.9])))
  grads = onp.array([4.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  # First moment:  0.2 * 0.1 + 0.8 * 4  = 3.22
  # Second moment: 0.9 * 0.9 + 0.1 * 16 = 2.41
  expected_new_state = optim.OptimizerState(
      2, _AdamParamState(onp.array([3.22]), onp.array([2.41])))
  expected_new_params = onp.array([0.906085])
  onp.testing.assert_allclose(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)