Example #1
def eval_step(model, state, batch, config):
  """Evaluation step."""
  # Average the mutable model state (e.g. batch statistics) across replicas.
  state = train_functions.pmean(state, config)

  # Run the model in inference mode with the state collection read-only.
  with flax.deprecated.nn.stateful(state, mutable=False):
    logits, penalty, summaries = model(batch['image'], train=False)
  penalty = train_functions.pmean(penalty, config, 'batch')
  metrics = {'model_penalty': penalty}
  metrics.update(
      train_functions.compute_metrics(config, logits, batch['label']))
  metrics.update(summaries)
  return logits, batch['label'], metrics
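A minimal usage sketch, not part of the original example: eval_step is typically wrapped in jax.pmap with axis_name='batch' so that the pmean calls can average over devices. The name p_eval_step and the already-replicated model/state/batch inputs are assumptions.

import functools
import jax

# Hypothetical driver; assumes `model`, `state`, `batch` and `config` exist
# and are replicated/sharded per device as in the surrounding project.
p_eval_step = jax.pmap(
    functools.partial(eval_step, config=config), axis_name='batch')
logits, labels, metrics = p_eval_step(model, state, batch)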
Example #2
  def update_grad_vars(optimizer, state, batch, prng_key, values):
    """Computes gradient variances for the preconditioner."""
    grad_fn = jax.value_and_grad(
        functools.partial(loss_fn, state=state, batch=batch, prng_key=prng_key),
        has_aux=True)

    # Discard the (loss, aux) output and average gradients across devices.
    _, grad = grad_fn(optimizer.target)
    grad = train_functions.pmean(grad, config, 'batch')

    # Accumulate squared gradients leaf-wise into the running estimates.
    values = jax.tree_multimap(lambda v, g: v + jnp.square(g), values, grad)
    return values
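A hedged sketch of how the accumulator might be driven; the zero initialization, eval_batches, and step_keys are assumptions, not from the source. Squared gradients are summed over several batches to estimate per-parameter second moments for the preconditioner.

import jax
import jax.numpy as jnp

# Hypothetical accumulation loop; assumes `optimizer`, `state`, the sharded
# `eval_batches` and per-device `step_keys` are prepared by the caller.
p_update_grad_vars = jax.pmap(update_grad_vars, axis_name='batch')
grad_vars = jax.tree_map(jnp.zeros_like, optimizer.target)
for batch, step_key in zip(eval_batches, step_keys):
  grad_vars = p_update_grad_vars(optimizer, state, batch, step_key, grad_vars)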
Example #3
    def train_step(
        optimizer,
        state,
        batch,
        prng_key,
        opt_rng,
    ):
        """Computes gradients and applies one optimizer update."""
        step = optimizer.state.step
        lr = learning_rate_fn(step)
        temp = temperature_fn(step)
        step_size_factor = step_size_fn(step)

        grad_fn = jax.value_and_grad(functools.partial(loss_fn,
                                                       state=state,
                                                       batch=batch,
                                                       prng_key=prng_key),
                                     has_aux=True)

        (loss, (new_state, logits, prior_penalty, model_penalty,
                weight_penalty, summaries)), grad = grad_fn(optimizer.target)
        grad = train_functions.pmean(grad, config, 'batch')

        metrics = train_functions.compute_metrics(config, logits,
                                                  batch['label'])
        opt_kwargs = {}

        if config.optimizer in ['sym_euler']:
            opt_kwargs['temperature'] = temp
            opt_kwargs['step_size_factor'] = step_size_factor
            # NOTE: lr is ignored in favor of step_size_factor.
            metrics['temperature'] = temp
            metrics['step_factor'] = step_size_factor
        else:
            opt_kwargs['learning_rate'] = lr
            metrics['learning_rate'] = lr

        grad = jax.example_libraries.optimizers.clip_grads(
            grad, config.gradient_clipping)
        # Make opt_rng available to optimizers that draw noise (e.g. sym_euler).
        with flax.deprecated.nn.stochastic(opt_rng):
            new_optimizer = optimizer.apply_gradient(grad, **opt_kwargs)

        metrics['cum_loss'] = loss
        metrics['prior_penalty'] = prior_penalty
        metrics['model_penalty'] = model_penalty
        metrics['weight_penalty'] = weight_penalty
        metrics.update(summaries)
        return new_optimizer, new_state, metrics
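A minimal sketch of a training-loop body, with assumed names (p_train_step, step_rng, opt_rng) and assuming optimizer, state and batch are already replicated: the step is pmapped over the 'batch' axis so the gradient pmean synchronizes devices.

import jax

# Hypothetical loop body; PRNG keys are split so every device gets its own key.
p_train_step = jax.pmap(train_step, axis_name='batch')
step_keys = jax.random.split(step_rng, jax.local_device_count())
opt_keys = jax.random.split(opt_rng, jax.local_device_count())
optimizer, state, metrics = p_train_step(
    optimizer, state, batch, step_keys, opt_keys)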
Example #4
def predict_step(model, state, batch, config):
  """Prediction step: returns logits and labels for a single batch."""
  # Average the mutable model state (e.g. batch statistics) across replicas.
  state = train_functions.pmean(state, config)
  with flax.deprecated.nn.stateful(state, mutable=False):
    logits, _ = model(batch['image'], train=False)

  return logits, batch['label']
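As with the other step functions, a hedged usage sketch (p_predict_step and the sharded inputs are assumptions): replicate the step with jax.pmap and fold the device axis back into the batch axis on the host.

import functools
import jax

# Hypothetical inference driver; `model`, `state` and `batch` are assumed to
# be replicated/sharded per device.
p_predict_step = jax.pmap(
    functools.partial(predict_step, config=config), axis_name='batch')
logits, labels = p_predict_step(model, state, batch)
# Merge the leading device dimension back into the batch dimension.
logits = logits.reshape((-1,) + logits.shape[2:])
labels = labels.reshape((-1,) + labels.shape[2:])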