def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: d4pg_networks.D4PGNetworks,
    dataset: Iterator[reverb.ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
  del environment_spec, replay_client

  policy_optimizer = optax.adam(self._config.learning_rate)
  critic_optimizer = optax.adam(self._config.learning_rate)

  if self._config.clipping:
    policy_optimizer = optax.chain(
        optax.clip_by_global_norm(40.), policy_optimizer)
    critic_optimizer = optax.chain(
        optax.clip_by_global_norm(40.), critic_optimizer)

  # The learner updates the parameters (and initializes them).
  return learning.D4PGLearner(
      policy_network=networks.policy_network,
      critic_network=networks.critic_network,
      random_key=random_key,
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      clipping=self._config.clipping,
      discount=self._config.discount,
      target_update_period=self._config.target_update_period,
      iterator=dataset,
      counter=counter,
      logger=logger_fn('learner'),
      num_sgd_steps_per_step=self._config.num_sgd_steps_per_step)
def make_optimizer(momentum=True, schedule_fn=lambda x: -1e-3):
  """SGD with optional momentum and a fixed lr (negative, to minimise)."""
  if momentum:
    return optax.chain(
        optax.trace(decay=0.9, nesterov=False),  # momentum
        optax.scale_by_schedule(schedule_fn))
  else:
    return optax.chain(
        optax.scale_by_schedule(schedule_fn))
def test_correctness(self):
  """Testing correctness via independent implementation."""

  def ema(decay, debias=True):

    def init_fn(params):
      del params
      return {'w': jnp.zeros((2,)), 'count': 0}

    def update_fn(updates, state, params=None):
      del params
      state['count'] += 1
      state['w'] = (1 - decay) * updates['w'] + decay * state['w']
      if debias:
        update = {'w': state['w'] / (1 - decay**state['count'])}
      else:
        update = {'w': state['w']}
      return update, state

    return optax.GradientTransformation(init_fn, update_fn)

  decay = 0.7
  learning_rate = 0.01
  true_ema = optax.chain(ema(decay), optax.scale(-1. * learning_rate))
  ks_ema = transform_chain(
      ['first_moment_ema'],
      [{
          'decay': decay,
          'debias': True,
      }],
      learning_rate=learning_rate)
  targets = _optimizer_loop(true_ema)
  results = _optimizer_loop(ks_ema)

  for target, result in zip(targets, results):
    chex.assert_trees_all_close(target, result)
def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: impala_networks.IMPALANetworks,
    dataset: Iterator[reverb.ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
  del environment_spec, replay_client

  optimizer = optax.chain(
      optax.clip_by_global_norm(self._config.max_gradient_norm),
      optax.adam(
          self._config.learning_rate,
          b1=self._config.adam_momentum_decay,
          b2=self._config.adam_variance_decay,
          eps=self._config.adam_eps,
          eps_root=self._config.adam_eps_root))

  return learning.IMPALALearner(
      networks=networks,
      iterator=dataset,
      optimizer=optimizer,
      random_key=random_key,
      discount=self._config.discount,
      entropy_cost=self._config.entropy_cost,
      baseline_cost=self._config.baseline_cost,
      max_abs_reward=self._config.max_abs_reward,
      counter=counter,
      logger=logger_fn('learner'),
  )
def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: td3_networks.TD3Networks,
    dataset: Iterator[reverb.ReplaySample],
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
  critic_optimizer = optax.adam(self._config.critic_learning_rate)
  twin_critic_optimizer = optax.adam(self._config.critic_learning_rate)
  policy_optimizer = optax.adam(self._config.policy_learning_rate)

  if self._config.policy_gradient_clipping is not None:
    policy_optimizer = optax.chain(
        optax.clip_by_global_norm(self._config.policy_gradient_clipping),
        policy_optimizer)

  return learning.TD3Learner(
      networks=networks,
      random_key=random_key,
      discount=self._config.discount,
      target_sigma=self._config.target_sigma,
      noise_clip=self._config.noise_clip,
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      twin_critic_optimizer=twin_critic_optimizer,
      num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
      bc_alpha=self._config.bc_alpha,
      iterator=dataset,
      logger=self._logger_fn(),
      counter=counter)
def __init__(self, model: SwarmModel, optimizer: optax.GradientTransformation,
             loss_scale: float, dataloader: Callable,
             precision: NetworkPrecision):
  self.model = model
  self.optimizer = optax.chain(optax.scale(1 / loss_scale), optimizer)
  self.dataloader = dataloader
  self.minibatches = 1
  self.loss_scale = loss_scale

  assert ray.is_initialized()  # needs a valid ray cluster to start

  example = self.dataloader()

  self.embedding = EmbeddingLayer.options(max_concurrency=8).remote(
      example["obs"], self.model.vocab, self.model.d_model, self.optimizer,
      precision)
  self.embedding.run.remote()

  x, _ = self.embedding.embed_forward.remote(example["obs"])

  self.proj = ProjLayer.options(max_concurrency=8).remote(
      x, self.model.vocab, self.model.d_model, self.optimizer,
      self.loss_scale, precision)
  self.proj.run.remote()

  self.layers = []
  for i in range(model.rev_layers):
    self.layers.append(
        ReversibleLayer.options(max_concurrency=8).remote(
            self.model.rev_init, i, x, self.optimizer, precision))

  for l in self.layers:
    l.run.remote()

  self.all_layers = [self.embedding] + self.layers + [self.proj]
def build_optimizer(self,
                    clip=15.0,
                    lr=5e-4,
                    warmup=2000,
                    cosine_decay_steps=None,
                    optimizer_name="adabelief") -> GradientTransformation:
  chain = []
  if optimizer_name == "adabelief":
    chain.append(util.scale_by_belief())
  elif optimizer_name == "adam":
    chain.append(optax.scale_by_adam())
  else:
    raise ValueError(f"Unknown optimizer name: {optimizer_name}")

  # Make sure to use the negative learning rate so that we minimize.
  if warmup and warmup > 0:
    warmup_schedule = partial(util.linear_warmup_lr_schedule,
                              warmup=warmup,
                              lr_decay=1.0,
                              lr=-lr)
    chain.append(optax.scale_by_schedule(warmup_schedule))
  else:
    chain.append(optax.scale(-lr))

  if cosine_decay_steps and cosine_decay_steps > 0:
    cosine_lr = optax.cosine_decay_schedule(
        init_value=1.0, decay_steps=cosine_decay_steps, alpha=1e-1)
    chain.append(optax.scale_by_schedule(cosine_lr))

  if clip and clip > 0:
    chain.append(optax.clip(clip))

  return optax.chain(*chain)
def main():
  # Create the dataset.
  train_dataset, vocab_size = load(batch_size, sequence_length)

  # Set up the model, loss, and updater.
  forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                                dropout_rate)
  forward_fn = hk.transform(forward_fn)
  loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

  optimizer = optax.chain(
      optax.clip_by_global_norm(grad_clip_value),
      optax.adam(learning_rate, b1=0.9, b2=0.99))

  updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

  # Initialize parameters.
  logging.info('Initializing parameters...')
  rng = jax.random.PRNGKey(428)
  data = next(train_dataset)
  state = updater.init(rng, data)

  logging.info('Starting train loop...')
  prev_time = time.time()
  for step in range(MAX_STEPS):
    data = next(train_dataset)
    state, metrics = updater.update(state, data)
def sgd_momentum(learning_rate_fn: optax.Schedule,
                 momentum: float = 0.,
                 nesterov: bool = False) -> optax.GradientTransformation:
  return optax.chain(
      optax.trace(decay=momentum, nesterov=nesterov),
      optax.scale_by_schedule(learning_rate_fn),
      optax.scale(-1.))
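# A minimal usage sketch (not from the original source): with a constant
# schedule, the helper above reduces to plain SGD with Nesterov momentum.
# `optax.constant_schedule` is a standard optax schedule; the values here are
# arbitrary illustration.
example_sgd = sgd_momentum(
    optax.constant_schedule(0.1), momentum=0.9, nesterov=True)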
def make_optimizer():
  """SGD with nesterov momentum and a custom lr schedule."""
  return optax.chain(
      optax.trace(
          decay=FLAGS.optimizer_momentum,
          nesterov=FLAGS.optimizer_use_nesterov),
      optax.scale_by_schedule(lr_schedule),
      optax.scale(-1))
def test_batch_overfit(train_dataset):
  vocab_size, d_model, num_heads, num_layers = 100, 32, 8, 1
  dropout_rate, grad_clip_value, learning_rate = 0.01, 0.25, 2e-2
  max_iter = 100

  # Set up the model, loss, and updater.
  forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                                dropout_rate, max_iter)
  forward_fn = hk.transform(forward_fn)
  loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

  optimizer = optax.chain(
      optax.clip_by_global_norm(grad_clip_value),
      optax.adam(learning_rate, b1=0.9, b2=0.99))

  updater = Updater(forward_fn.init, loss_fn, optimizer)

  # Initialize parameters.
  rng = jax.random.PRNGKey(428)
  data = next(train_dataset)
  state = updater.init(rng, data)

  for step in range(100):
    data = next(train_dataset)
    state, metrics = updater.update(state, data)

  assert metrics['loss'] < 0.1
  assert metrics['step'] == 99
def optimizer(hyperparameters):
  opt_init_fn, opt_update_fn = optax.chain(
      optax.scale_by_adam(
          b1=1.0 - hyperparameters.one_minus_beta_1,
          b2=0.999,
          eps=hyperparameters.epsilon),
      optax.scale(-hyperparameters.learning_rate))
  return opt_init_fn, opt_update_fn
def test_graph_conditioned_transformer_learns(self):
  graphs = jraph.GraphsTuple(
      nodes=np.ones((4, 3), dtype=np.float32),
      edges=np.ones((3, 1), dtype=np.float32),
      senders=np.array([0, 2, 3], dtype=np.int32),
      receivers=np.array([1, 3, 2], dtype=np.int32),
      n_node=np.array([2, 2], dtype=np.int32),
      n_edge=np.array([1, 2], dtype=np.int32),
      globals=None,
  )
  seqs = np.array([[1, 2, 2, 0],
                   [1, 3, 3, 3]], dtype=np.int32)
  vocab_size = seqs.max() + 1
  embed_dim = 8
  max_graph_size = graphs.n_node.max()

  logging.info('Training seqs: %r', seqs)

  x = seqs[:, :-1]
  y = seqs[:, 1:]

  def model_fn(vocab_size, embed_dim):
    return models.Graph2TextTransformer(
        vocab_size=vocab_size,
        emb_dim=embed_dim,
        num_layers=2,
        num_heads=4,
        cutoffs=[],
        gnn_embed_dim=embed_dim,
        gnn_num_layers=2)

  def forward(graphs, inputs, labels, max_graph_size):
    input_mask = (labels != 0).astype(jnp.float32)
    return model_fn(vocab_size, embed_dim).loss(
        graphs, max_graph_size, False, inputs, labels, mask=input_mask)

  init_fn, apply_fn = hk.transform_with_state(forward)
  rng = hk.PRNGSequence(8)
  params, state = init_fn(next(rng), graphs, x, y, max_graph_size)

  def apply(*args, **kwargs):
    out, state = apply_fn(*args, **kwargs)
    return out[0], (out[1], state)

  apply = jax.jit(apply, static_argnums=6)

  optimizer = optax.chain(
      optax.scale_by_adam(),
      optax.scale(-1e-3))
  opt_state = optimizer.init(params)
  for i in range(500):
    (loss, model_state), grad = jax.value_and_grad(apply, has_aux=True)(
        params, state, next(rng), graphs, x, y, max_graph_size)
    metrics, state = model_state
    updates, opt_state = optimizer.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    if (i + 1) % 100 == 0:
      logging.info(
          'Step %d, %r', i + 1, {k: float(v) for k, v in metrics.items()})

  logging.info('Loss: %.8f', loss)
  self.assertLess(loss, 1.0)
def _run_attack(
    max_objective_fn: Callable[[Tensor, PRNGKey], Tensor],
    projection_fn: Callable[[Tensor], Tensor],
    x_init: Tensor,
    prng_key: PRNGKey,
    num_steps: int,
    learning_rate: float,
):
  """Run attack."""
  opt = optax.chain(
      optax.scale(-1),  # maximization
      optax.adam(learning_rate))
  grad_fn = jax.grad(max_objective_fn)

  def body_fn(it, inputs):
    del it  # unused
    x, prng_in, opt_state = inputs
    prng_out, prng_used = jax.random.split(prng_in)
    grad_x = grad_fn(x, prng_used)
    updates, opt_state = opt.update(grad_x, opt_state, x)
    x = optax.apply_updates(x, updates)
    x = projection_fn(x)
    return x, prng_out, opt_state

  opt_state = opt.init(x_init)
  init_state = (x_init, prng_key, opt_state)
  x, prng_final, _ = jax.lax.fori_loop(0, num_steps, body_fn, init_state)
  return max_objective_fn(x, prng_final)
def main(_):
  # Create the dataset.
  train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                           FLAGS.sequence_length)

  # Set up the model, loss, and updater.
  forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                FLAGS.num_layers, FLAGS.dropout_rate)
  forward_fn = hk.transform(forward_fn)
  loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

  optimizer = optax.chain(
      optax.clip_by_global_norm(FLAGS.grad_clip_value),
      optax.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

  updater = Updater(forward_fn.init, loss_fn, optimizer)
  updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

  # Initialize parameters.
  logging.info('Initializing parameters...')
  rng = jax.random.PRNGKey(428)
  data = next(train_dataset)
  state = updater.init(rng, data)

  logging.info('Starting train loop...')
  prev_time = time.time()
  for step in range(MAX_STEPS):
    data = next(train_dataset)
    state, metrics = updater.update(state, data)
    # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
    # Using values from state/metrics too often will block the runahead and can
    # cause these overheads to become more prominent.
    if step % LOG_EVERY == 0:
      steps_per_sec = LOG_EVERY / (time.time() - prev_time)
      prev_time = time.time()
      metrics.update({'steps_per_sec': steps_per_sec})
      logging.info({k: float(v) for k, v in metrics.items()})
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks_lib.FeedForwardNetwork,
    config: DQNConfig,
):
  """Initialize the agent."""
  # Data is communicated via reverb replay.
  reverb_replay = replay.make_reverb_prioritized_nstep_replay(
      environment_spec=environment_spec,
      n_step=config.n_step,
      batch_size=config.batch_size,
      max_replay_size=config.max_replay_size,
      min_replay_size=config.min_replay_size,
      priority_exponent=config.priority_exponent,
      discount=config.discount,
  )
  self._server = reverb_replay.server

  optimizer = optax.chain(
      optax.clip_by_global_norm(config.max_gradient_norm),
      optax.adam(config.learning_rate),
  )
  key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      obs_spec=environment_spec.observations,
      random_key=key_learner,
      optimizer=optimizer,
      discount=config.discount,
      importance_sampling_exponent=config.importance_sampling_exponent,
      target_update_period=config.target_update_period,
      iterator=reverb_replay.data_iterator,
      replay_client=reverb_replay.client,
  )

  # The actor selects actions according to the policy.
  def policy(params: networks_lib.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)

  actor = actors.FeedForwardActor(
      policy=policy,
      rng=hk.PRNGSequence(key_actor),
      variable_client=variable_utils.VariableClient(learner, ''),
      adder=reverb_replay.adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(config.batch_size, config.min_replay_size),
      observations_per_step=config.batch_size / config.samples_per_insert,
  )
def train(*, data_folder, batch_size, epochs, learning_rate, weight_decay,
          seed, max_norm, text_vocab, text_dim, text_depth, text_heads,
          audio_dim, audio_depth, audio_heads):
  # rng
  rng_key = random.PRNGKey(seed)

  # data
  dataset = PairTextSpectrogramDataset(data_folder)
  dl = DataLoader(dataset,
                  batch_size=batch_size,
                  collate_fn=pair_text_spectrogram_dataset_collate_fn,
                  drop_last=True,
                  shuffle=True)

  # model
  model = CLAP(text_vocab=text_vocab,
               text_dim=text_dim,
               text_depth=text_depth,
               text_heads=text_heads,
               audio_dim=audio_dim,
               audio_depth=audio_depth,
               audio_heads=audio_heads)

  # optimizer: the mask applies weight decay only to non-1D (non-bias) params
  exclude_bias = lambda params: tree_util.tree_map(lambda x: x.ndim != 1, params)
  optim = chain(clip_by_global_norm(max_norm),
                scale_by_adam(eps=1e-4),
                add_decayed_weights(weight_decay, exclude_bias),
                scale(-learning_rate))

  # init
  audio, audio_mask, text, text_mask = next(iter(dl))
  params = model.init(rng_key, text, audio, text_mask, audio_mask)
  optim_state = optim.init(params)

  # loss function, for use with value_and_grad
  @jit
  @value_and_grad
  def loss_fn(params, text, audio, text_mask, audio_mask):
    return model.apply(params, text, audio, text_mask, audio_mask)

  # train loop
  for _ in range(epochs):
    for audio, audio_mask, text, text_mask in dl:
      loss, grads = loss_fn(params, text, audio, text_mask, audio_mask)
      updates, optim_state = optim.update(grads, optim_state, params)
      params = apply_updates(params, updates)
      print(f'loss: {loss}')
def create_train_state(rng, config: ml_collections.ConfigDict, model):
  """Create initial training state."""
  params = get_initial_params(rng, model)
  tx = optax.chain(
      optax.sgd(learning_rate=config.learning_rate, momentum=config.momentum),
      optax.additive_weight_decay(weight_decay=config.weight_decay))
  state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
  return state
def kitchen_sink(chains: List[optax.GradientTransformation],
                 scales: jnp.ndarray = None,
                 combinator: Union[Callable[[Any, Any], Any], str] = 'sum',
                 combinator_args: Dict[str, float] = None,
                 learning_rate: float = None) -> optax.GradientTransformation:
  """Runs a list of GradientTransformations in parallel and combines them.

  Args:
    chains: list of optax.GradientTransformations (typically from
      transform_chain).
    scales: a (len(chains),)-shaped jnp.ndarray of per-chain scales.
    combinator: a combinator that reduces a list of identical pytrees.
    combinator_args: a dictionary of keyword arguments to the combinator func.
    learning_rate: learning rate that gets injected.

  Returns:
    optax.GradientTransformation
  """
  if isinstance(combinator, str):
    combinator = _combinators.get(combinator, _sum_combinator)
  combinator_args = combinator_args or {}

  if scales is None:
    scales = jnp.ones(len(chains))

  chains = [
      optax.chain(chain, optax.scale(scale))
      for chain, scale in zip(chains, scales)
  ]

  def init_fn(params):
    return [chain.init(params) for chain in chains]

  def update_fn(updates, state, params=None):
    result = [
        chain.update(updates, chain_state, params)
        for chain, chain_state in zip(chains, state)
    ]
    new_updates, new_state = list(zip(*result))
    return combinator(*new_updates, **combinator_args), new_state

  transform = optax.GradientTransformation(init_fn, update_fn)

  if learning_rate is not None:
    transform = optax.chain(transform, scale_by_learning_rate(learning_rate))

  return transform
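# A hypothetical usage sketch (not from the original source), reusing the
# 'first_moment_ema' transform name and hyperparameter keys that appear in the
# tests above: two EMA chains are run in parallel and averaged via the default
# 'sum' combinator with equal scales, then a learning rate is injected.
ks_example = kitchen_sink(
    chains=[
        transform_chain(['first_moment_ema'], [{'decay': 0.9, 'debias': True}]),
        transform_chain(['first_moment_ema'], [{'decay': 0.99, 'debias': True}]),
    ],
    scales=jnp.array([0.5, 0.5]),
    combinator='sum',
    learning_rate=0.01)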
def __init__(
    self,
    *optimizer: optax.GradientTransformation,
    lr_schedule: tp.Optional[LRScheduler] = None,
    steps_per_epoch: tp.Union[int, jnp.ndarray, np.ndarray, None] = None,
    **kwargs,
):
  r"""
  Arguments:
      optimizer: An optax `GradientTransformation` object; if more than one is
          passed via `*args` then they are grouped using `optax.chain`.
      lr_schedule: An optional callable of the form
          `def lr_schedule(step: int, epoch: Optional[int]) -> float` that
          returns the learning rate at each time step. If `steps_per_epoch` is
          given then `epoch` is calculated, else `epoch` is None.
      steps_per_epoch: The number of steps in an epoch, needed to calculate
          `epoch` from `step`.
  """
  if len(optimizer) == 0:
    raise ValueError("Must pass at least 1 optimizer, got 0")

  elif lr_schedule is not None:
    # Keep a reference to the original schedule before it is reassigned below.
    base_schedule = lr_schedule

    def lr_schedule_(step: jnp.ndarray) -> jnp.ndarray:
      epoch: tp.Any = (
          step // steps_per_epoch if steps_per_epoch is not None else None)
      return base_schedule(step, epoch)

    optimizer = optax.chain(
        *optimizer,
        optax.scale_by_schedule(lr_schedule_),
    )
    lr_schedule = lr_schedule_

  elif len(optimizer) == 1:
    optimizer = optimizer[0]

  else:
    optimizer = optax.chain(*optimizer)

  self.optimizer = optimizer
  self.lr_schedule = lr_schedule
def test_regularized_training(self):
  """Test that adding a regularization penalty to the training loss works."""
  np.random.seed(0)

  # Set up the problem of recovering w given x and
  #   y = x . w + noise
  # with the a priori assumption that w is sparse. There are fewer examples
  # than dimensions (x is a wide matrix), so the problem is underdetermined
  # without the sparsity assumption.
  num_examples, num_dim = 8, 10
  x = np.random.randn(num_examples, num_dim).astype(np.float32)
  true_w = np.zeros((num_dim, 2), np.float32)
  true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
  true_w[[3, 5], 1] = [4.0, 5.0]
  y = np.dot(x, true_w) + 1e-3 * np.random.randn(num_examples, 2)

  # Get the least squares estimate for w. It isn't very accurate.
  least_squares_w = np.linalg.lstsq(x, y, rcond=None)[0]
  least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

  # Get a better estimate by solving the L1 regularized problem
  #   argmin_w ||x . w - y||_2^2 + c ||w||_1.
  w_regularizer = lambda w: 4.0 * hk_util.l1_loss(w)

  def model_fun(batch):
    x = batch['x']
    return hk_util.Linear(2, use_bias=False, w_regularizer=w_regularizer)(x)

  model = hk_util.transform(model_fun)

  def loss_fun(params, batch):
    """Training loss with L1 regularization penalty term."""
    y_predicted, penalties = model.apply(params, None, batch)
    return hk_util.l2_loss(y_predicted - batch['y']) + penalties

  batch = {'x': x, 'y': y}
  params = model.init(jax.random.PRNGKey(0), batch)

  # Gradient descent with decreasing learning rate.
  optimizer = optax.chain(
      optax.trace(decay=0.0, nesterov=False),
      optax.scale_by_schedule(lambda i: -0.05 / jnp.sqrt(1 + i)))
  opt_state = optimizer.init(params)

  @jax.jit
  def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fun)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

  for _ in range(1000):
    params, opt_state = train_step(params, opt_state, batch)

  l1_w = params['linear']['w']
  l1_w_error = hk_util.l2_loss(l1_w - true_w).item()

  # The L1-regularized estimate is much more accurate.
  self.assertGreater(least_squares_w_error, 4.0)
  self.assertLess(l1_w_error, 1.0)
def transform_chain(
    elements: List[str],
    hps: List[Dict[str, float]] = None,
    masks: List[Any] = None,
    learning_rate: float = None) -> optax.GradientTransformation:
  """Utility function for chaining GradientTransformations by string name.

  Args:
    elements: list of transform strings.
    hps: list of dicts of args for each transform.
    masks: list of masks for each transform.
    learning_rate: learning rate that gets injected.

  Returns:
    optax.GradientTransformation
  """
  hps = hps or [{}] * len(elements)
  masks = masks or [None] * len(elements)

  if len(hps) != len(elements):
    raise ValueError('Number of hps must equal number of elements.')

  if len(masks) != len(elements):
    raise ValueError('Number of masks must equal number of elements.')

  transforms = [_transformations[el](**hp) for el, hp in zip(elements, hps)]

  for i, (transform, mask) in enumerate(zip(transforms, masks)):
    if mask is not None:
      transforms[i] = optax.masked(transform, mask)

  if learning_rate is not None:
    transforms += [scale_by_learning_rate(learning_rate)]

  init_fn, update_fn = optax.chain(*transforms)

  # NOTE(dsuo): We use plain dicts internally due to this issue
  # https://github.com/deepmind/optax/issues/160.
  def wrapped_init_fn(params):
    return init_fn(flax.core.unfreeze(params))

  def wrapped_update_fn(updates, state, params=None):
    new_updates, state = update_fn(
        flax.core.unfreeze(updates), state,
        None if params is None else flax.core.unfreeze(params))

    if isinstance(updates, flax.core.FrozenDict):
      new_updates = flax.core.freeze(new_updates)

    return new_updates, state

  return optax.GradientTransformation(wrapped_init_fn, wrapped_update_fn)
def get(self) -> optax.GradientTransformation:
  if "adam" in self.optimizer:
    opt = optax.adam(self.base_learning_rate)
  elif "sgd" == self.optimizer and self.lr_schedule == "linear":
    lr_schedule = warm_up_polynomial_schedule(
        base_learning_rate=self.base_learning_rate,
        end_learning_rate=self.final_decay_factor * self.base_learning_rate,
        decay_steps=(self.n_batches * (self.epochs - self.lr_warmup_epochs)),
        warmup_steps=self.n_batches * self.lr_warmup_epochs,
        decay_power=1.0,
    )
    momentum = 1 - self.one_minus_momentum
    opt = optax.chain(
        optax.trace(decay=momentum, nesterov=True),
        optax.scale_by_schedule(lr_schedule),
        optax.scale(-1),
    )
  elif "sgd" in self.optimizer and self.lr_schedule == "step":
    lr_decay_epochs = [
        (int(start_epoch_str) * self.epochs) // DEFAULT_NUM_EPOCHS
        for start_epoch_str in self.lr_decay_epochs
    ]
    lr_schedule = warm_up_piecewise_constant_schedule(
        steps_per_epoch=self.n_batches,
        base_learning_rate=self.base_learning_rate,
        decay_ratio=self.lr_decay_ratio,
        decay_epochs=lr_decay_epochs,
        warmup_epochs=self.lr_warmup_epochs,
    )
    momentum = 1 - self.one_minus_momentum
    opt = optax.chain(
        optax.trace(decay=momentum, nesterov=True),
        optax.scale_by_schedule(lr_schedule),
        optax.scale(-1),
    )
  else:
    raise ValueError("No optimizer specified.")
  return opt
def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: ppo_networks.PPONetworks,
    dataset: Iterator[reverb.ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
  del environment_spec, replay_client

  if callable(self._config.learning_rate):
    optimizer = optax.chain(
        optax.clip_by_global_norm(self._config.max_gradient_norm),
        optax.scale_by_adam(eps=self._config.adam_epsilon),
        optax.scale_by_schedule(self._config.learning_rate),
        optax.scale(-1))
  else:
    optimizer = optax.chain(
        optax.clip_by_global_norm(self._config.max_gradient_norm),
        optax.scale_by_adam(eps=self._config.adam_epsilon),
        optax.scale(-self._config.learning_rate))

  return learning.PPOLearner(
      ppo_networks=networks,
      iterator=dataset,
      discount=self._config.discount,
      entropy_cost=self._config.entropy_cost,
      value_cost=self._config.value_cost,
      max_abs_reward=self._config.max_abs_reward,
      ppo_clipping_epsilon=self._config.ppo_clipping_epsilon,
      clip_value=self._config.clip_value,
      gae_lambda=self._config.gae_lambda,
      counter=counter,
      random_key=random_key,
      optimizer=optimizer,
      num_epochs=self._config.num_epochs,
      num_minibatches=self._config.num_minibatches,
      logger=logger_fn('learner'),
  )
def _create_jax_optimizer(self):
  import optax

  process = []
  if isinstance(self.learning_rate, LearningRateSchedule):
    scheduler = self.learning_rate._create_jax_schedule()
    process.append(optax.scale_by_schedule(scheduler))
    last_process = optax.scale(-1.0)
  else:
    lr = self.learning_rate
    last_process = optax.scale(-1.0 * lr)

  process.append(last_process)
  return optax.chain(*process)
def main(argv):
  del argv
  learning_rate = 1e-2
  batch_size = 64
  input_size = 8
  n_training_steps = 100

  # Random number generator sequence.
  key_seq = hk.PRNGSequence(1729)

  # A simple Linear function.
  def forward_pass(x):
    return hk.Linear(10)(x)

  network = hk.without_apply_rng(hk.transform(forward_pass))

  # Some arbitrary loss.
  def mean_square_loss(params, x):
    output = network.apply(params, x)
    loss = jnp.sum(output**2)
    return loss

  # Construct a simple Adam optimiser using the transforms in optax.
  # You could also just use the `optax.adam` alias, but we show here how
  # to do so manually so that you may construct your own `custom` optimiser.
  opt_init, opt_update = optax.chain(
      # Set the parameters of Adam. Note the learning_rate is not here.
      optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
      # Put a minus sign to *minimise* the loss.
      optax.scale(-learning_rate))

  # Initialise the model's parameters and the optimiser's state.
  # The `state` of an optimiser contains all statistics used by the
  # stateful transformations in the `chain` (in this case just `scale_by_adam`).
  params = network.init(next(key_seq), jnp.zeros([1, input_size]))
  opt_state = opt_init(params)

  # Minimise the loss.
  for step in range(n_training_steps):
    # Get input. Learn to minimize the input to 0.
    data = jax.random.normal(next(key_seq), [batch_size, input_size])

    # Compute gradient and loss.
    loss, grad = jax.value_and_grad(mean_square_loss)(params, data)
    print(f'Loss[{step}] = {loss}')

    # Transform the gradients using the optimiser.
    updates, opt_state = opt_update(grad, opt_state, params)

    # Update parameters.
    params = optax.apply_updates(params, updates)
def create_train_state(config, rng, learning_rate_fn, example_batch):
  """Create and initialize the model.

  Args:
    config: Configuration for the model.
    rng: JAX PRNG Key.
    learning_rate_fn: Learning rate function.
    example_batch: An example batch, used for model initialization.

  Returns:
    The initialized TrainState with the optimizer.
  """
  model, variables = create_model(config, rng, example_batch)
  params = variables['params']
  parameter_overview.log_parameter_overview(params)

  optimizer = optax.adamw(
      learning_rate=learning_rate_fn,
      b1=0.9,
      b2=.98,
      eps=1e-9,
      weight_decay=config.train.weight_decay)

  if config.train.grad_max_norm > 0:
    tx = optax.chain(
        optax.clip_by_global_norm(config.train.grad_max_norm), optimizer)
  elif config.train.grad_max_val > 1:
    tx = optax.chain(optax.clip(config.train.grad_max_val), optimizer)
  else:
    tx = optimizer

  state = train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=tx,
  )
  return model, state
def optimizer(self, learning_rate):
  """Construct optimizer."""
  clip = optax.clip_by_global_norm(self.config.optimizer.gradient_clip_norm)
  optimizer = getattr(optax, self.config.optimizer.name)(
      learning_rate,
      **self.config.optimizer.args,
  )
  optim_step = optax.chain(clip, optimizer)
  optim_step = optimizers.maybe_skip_gradient_update(
      optim_step,
      self.config.optimizer.gradient_skip_norm,
  )
  return optim_step
def _create_jax_optimizer(self):
  import optax

  process = []
  if isinstance(self.learning_rate, LearningRateSchedule):
    lr = self.learning_rate.initial_rate
    last_process = optax.scale(-1.0)
  else:
    lr = self.learning_rate
    last_process = optax.scale(-1.0 * lr)

  process.append(
      optax.scale_by_rss(
          initial_accumulator_value=self.initial_accumulator_value,
          eps=self.epsilon))
  process.append(last_process)
  return optax.chain(*process)
def test_bow_transformer_learns(self):
  bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1],
                  [0, 1, 0, 0, 1, 0, 1, 0],
                  [1, 0, 0, 0, 1, 0, 0, 1]], dtype=np.int32)
  seqs = np.array([[1, 2, 2, 3, 0, 0],
                   [1, 2, 4, 5, 6, 0],
                   [1, 3, 3, 5, 4, 2]], dtype=np.int32)
  x = seqs[:, :-1]
  y = seqs[:, 1:]
  vocab_size = seqs.max() + 1

  def model_fn():
    return models.Bow2TextTransformer(
        vocab_size=vocab_size,
        emb_dim=16,
        num_layers=2,
        num_heads=4,
        cutoffs=[])

  def loss_fn(bow, inputs, labels):
    mask = (labels != 0).astype(jnp.float32)
    return model_fn().loss(bow, inputs, labels, mask=mask)

  init_fn, apply_fn = hk.transform_with_state(loss_fn)
  key = hk.PRNGSequence(8)
  params, state = init_fn(next(key), bow, x, y)

  def apply(*args, **kwargs):
    out, state = apply_fn(*args, **kwargs)
    return out[0], (out[1], state)

  value_and_grad = jax.jit(jax.value_and_grad(apply, has_aux=True))

  optimizer = optax.chain(
      optax.scale_by_adam(),
      optax.scale(-1e-3))
  opt_state = optimizer.init(params)
  for i in range(800):
    (loss, model_state), grad = value_and_grad(
        params, state, next(key), bow, x, y)
    metrics, state = model_state
    updates, opt_state = optimizer.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    if (i + 1) % 100 == 0:
      logging.info('Step %d, %r', i + 1,
                   {k: float(v) for k, v in metrics.items()})

  logging.info('Loss: %.8f', loss)
  self.assertLess(loss, 0.1)