def update_fn(opt, lr, images, labels, rng):
  """Update step."""
  measurements = {}

  # Get device-specific loss rng.
  rng, rng_model = jax.random.split(rng, 2)
  rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))
  rng_model_local, diag_noise_rng, standard_noise_rng = jax.random.split(
      rng_model_local, num=3)

  def loss_fn(params, images, labels):
    logits, _ = model.apply(
        {'params': flax.core.freeze(params)},
        images,
        train=True,
        rngs={
            'dropout': rng_model_local,
            'diag_noise_samples': diag_noise_rng,
            'standard_norm_noise_samples': standard_noise_rng
        })
    label_indices = config.get('label_indices')
    if label_indices:
      logits = logits[:, label_indices]
    return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
        logits=logits, labels=labels)

  # Implementation considerations compared and summarized at
  # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
  l, g = train_utils.accumulate_gradient(
      jax.value_and_grad(loss_fn), opt.target, images, labels,
      config.get('grad_accum_steps'))
  l, g = jax.lax.pmean((l, g), axis_name='batch')

  # Log the gradient norm only if we need to compute it anyways (clipping)
  # or if we don't use grad_accum_steps, as they interact badly.
  if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
    grads, _ = jax.tree_flatten(g)
    l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
    measurements['l2_grads'] = l2_g

  # Optionally resize the global gradient to a maximum norm. We found this
  # useful in some cases across optimizers, hence it's in the main loop.
  if config.get('grad_clip_norm'):
    g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
    g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
  opt = opt.apply_gradient(g, learning_rate=lr)
  opt = opt.replace(target=weight_decay_fn(opt.target, lr))

  params, _ = jax.tree_flatten(opt.target)
  measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

  return opt, l, rng, measurements
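
# The exact behavior of train_utils.accumulate_gradient is repo-specific; the
# sketch below is an assumption inferred from the call site above, not the
# actual helper. It splits the per-device batch into `accum_steps` equal
# chunks, runs the value-and-grad function on each chunk, and averages both
# the losses and the gradients. It only covers the plain (no `has_aux`)
# signature used directly above; `accumulate_gradient_sketch` is a
# hypothetical name.
def accumulate_gradient_sketch(loss_and_grad_fn, params, images, labels,
                               accum_steps):
  if not accum_steps or accum_steps <= 1:
    # No accumulation requested: a single full-batch evaluation.
    return loss_and_grad_fn(params, images, labels)
  chunk = images.shape[0] // accum_steps
  l, g = loss_and_grad_fn(params, images[:chunk], labels[:chunk])
  for i in range(1, accum_steps):
    l_i, g_i = loss_and_grad_fn(params,
                                images[i * chunk:(i + 1) * chunk],
                                labels[i * chunk:(i + 1) * chunk])
    l = l + l_i
    g = jax.tree_util.tree_map(lambda a, b: a + b, g, g_i)
  # Average so the result matches a single pass over the full batch.
  return (l / accum_steps,
          jax.tree_util.tree_map(lambda x: x / accum_steps, g))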
def update_fn(opt, lr, images, labels, rng):
  """Update step. Copy to deterministic_utils.py whenever changes are made!"""
  measurements = {}

  # Split rng and return next_rng for the following step.
  rng, next_rng = jax.random.split(rng, 2)
  rng_local = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

  def loss_fn(params, images, labels):
    logits, _ = model.apply(
        {'params': flax.core.freeze(params)},
        images,
        train=True,
        rngs={'dropout': rng_local})
    label_indices = config.get('label_indices')
    if label_indices:
      logits = logits[:, label_indices]
    loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
        logits=logits, labels=labels)
    return loss, logits

  # Implementation considerations compared and summarized at
  # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
  (l, logits), g = train_utils.accumulate_gradient(
      jax.value_and_grad(loss_fn, has_aux=True), opt.target, images, labels,
      config.get('grad_accum_steps'))
  l, g = jax.lax.pmean((l, g), axis_name='batch')
  measurements['training_loss'] = l

  # Log the gradient norm only if we need to compute it anyways (clipping)
  # or if we don't use grad_accum_steps, as they interact badly.
  if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
    grads, _ = jax.tree_flatten(g)
    l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
    measurements['l2_grads'] = l2_g

  # Optionally resize the global gradient to a maximum norm. We found this
  # useful in some cases across optimizers, hence it's in the main loop.
  if config.get('grad_clip_norm'):
    g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
    g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
  opt = opt.apply_gradient(g, learning_rate=lr)
  opt = opt.replace(target=weight_decay_fn(opt.target, lr))

  params, _ = jax.tree_flatten(opt.target)
  measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

  top1_idx = jnp.argmax(logits, axis=1)
  top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
  prec1 = jax.lax.psum(jnp.sum(top1_correct), axis_name='batch') / batch_size
  measurements['training_prec@1'] = prec1
  measurements['learning_rate'] = lr
  return opt, next_rng, measurements
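
# A minimal usage sketch, not part of the original file: it assumes `opt_repl`
# and `rng_repl` have already been replicated across local devices, that each
# batch is shaped [num_local_devices, device_batch_size, ...], and that
# `total_steps`, `train_iter`, `lr_fn`, and the 'image'/'labels' batch keys are
# hypothetical names standing in for the surrounding training loop.
pmapped_update_fn = jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))
for step, batch in zip(range(1, total_steps + 1), train_iter):
  # Broadcast the scalar learning rate so every device receives a copy.
  lr_repl = jnp.full((jax.local_device_count(),), lr_fn(step))
  opt_repl, rng_repl, measurements = pmapped_update_fn(
      opt_repl, lr_repl, batch['image'], batch['labels'], rng_repl)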
def update_fn(opt, lr, images, labels, rng):
  """Update step."""
  measurements = {}

  # Get device-specific loss rng.
  rng, rng_model = jax.random.split(rng, 2)
  rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))

  def loss_fn(params, images, labels):
    logits, _ = model.apply(
        {'params': flax.core.freeze(params)},
        images,
        train=True,
        rngs={'dropout': rng_model_local})
    label_indices = config.get('label_indices')
    if label_indices:
      logits = logits[:, label_indices]
    return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
        logits=logits, labels=labels)

  # Implementation considerations compared and summarized at
  # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
  l, g = train_utils.accumulate_gradient(
      jax.value_and_grad(loss_fn), opt.target, images, labels,
      config.get('grad_accum_steps'))
  l, g = jax.lax.pmean((l, g), axis_name='batch')

  # Log the gradient norm only if we need to compute it anyways (clipping)
  # or if we don't use grad_accum_steps, as they interact badly.
  # `grad_clip_norm` and `weight_decay` below come from the enclosing scope.
  if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
    grads, _ = jax.tree_flatten(g)
    l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
    measurements['l2_grads'] = l2_g

  # Optionally resize the global gradient to a maximum norm. We found this
  # useful in some cases across optimizers, hence it's in the main loop.
  if grad_clip_norm is not None:
    g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
    g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
  opt = opt.apply_gradient(g, learning_rate=lr)

  # Apply weight decay: a single number decays all kernels; otherwise
  # `weight_decay` is a list of (regex, decay) rules.
  decay_rules = weight_decay or []
  if isinstance(decay_rules, numbers.Number):
    decay_rules = [('.*kernel.*', decay_rules)]
  sched_m = lr / config.lr.base if config.get('weight_decay_decouple') else lr

  def decay_fn(v, wd):
    return (1.0 - sched_m * wd) * v

  opt = opt.replace(target=train_utils.tree_map_with_regex(
      decay_fn, opt.target, decay_rules))

  params, _ = jax.tree_flatten(opt.target)
  measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

  return opt, l, rng, measurements
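
# The sketch below is an assumption about what a helper like
# train_utils.tree_map_with_regex could look like, inferred only from the call
# site above (it is not the repo's implementation): flatten the parameter tree
# to '/'-joined names, apply `fn(value, wd)` to every leaf whose name matches
# one of the (regex, wd) rules, and leave all other leaves untouched.
# `tree_map_with_regex_sketch` is a hypothetical name.
import re

import flax.traverse_util


def tree_map_with_regex_sketch(fn, params, rules):
  flat = flax.traverse_util.flatten_dict(params, sep='/')
  for name, value in flat.items():
    for pattern, wd in rules:
      if re.fullmatch(pattern, name):
        flat[name] = fn(value, wd)
        break  # Apply at most one rule per parameter.
  return flax.traverse_util.unflatten_dict(
      {tuple(k.split('/')): v for k, v in flat.items()})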