Example #1
 def _wrapped_fn(theta, per_stage_inputs, *per_stage_args,
                 **per_stage_kwargs):
     with base_layer.JaxContext.new_context(
             prng_key=prng_key, global_step=global_step) as jax_ctx:
         jax_ctx.bind(self.body, self.body.vars_to_flax_vars(theta),
                      [base_layer.SCOPE_AUX_LOSS])
         res = self.body.fprop(per_stage_inputs, *per_stage_args,
                               **per_stage_kwargs)
         summaries = base_layer.all_summaries()
         return res, summaries
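This wrapper shows the pattern repeated throughout these examples: bind the flattened variables theta to the layer under a fresh JaxContext, run fprop, and return the result together with the summaries collected inside that context. As a hedged sketch only, such a per-stage wrapper could be mapped over a leading stage dimension; the names stacked_theta and stacked_inputs, and the use of jax.vmap here, are illustrative assumptions rather than the library's actual pipelining code.

# Hypothetical sketch: map the wrapper over a leading [num_stages, ...] axis.
# stacked_theta and stacked_inputs are assumed stacked per-stage pytrees.
per_stage_res, per_stage_summaries = jax.vmap(_wrapped_fn)(
    stacked_theta, stacked_inputs)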
Example #2
  def _loss_fn(
      mdl_vars: NestedJTensor, inputs: NestedMap
  ) -> Tuple[JTensor, Tuple[Any, NestedMap, SummaryDict, SummaryDict]]:
    """Computes loss as well as other auxiliary outputs."""
    if fprop_dtype == jnp.float32:
      pass
    elif fprop_dtype == jnp.bfloat16:
      mdl_vars = jax.tree_map(_maybe_to_bfloat16, mdl_vars)
      inputs = jax.tree_map(_maybe_to_bfloat16, inputs)
    else:
      raise NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

    with base_layer.JaxContext.new_context(
        params=context_p, prng_key=subkey,
        global_step=states.step) as jax_context:
      jax_context.bind(model, model.vars_to_flax_vars(mdl_vars),
                       [base_layer.SCOPE_VARS, base_layer.SCOPE_AUX_LOSS])

      metrics, per_example_output = model.fprop(inputs)
      loss_name = learner.loss_name
      assert loss_name in metrics
      loss, loss_weight = metrics[loss_name]
      assert loss.ndim == 0, 'loss has to be a scalar.'
      assert loss_weight.ndim == 0, 'loss_weight has to be a scalar'
      loss_weight = jax.lax.stop_gradient(loss_weight)
      if in_pmap:
        # Renormalize loss weight by the total weight across all replicas.
        # This also takes care of dividing loss by num of data parallel
        # replicas.
        loss_weight /= jax.lax.psum(
            loss_weight, axis_name=data_parallel_axis_name)
      else:
        # loss_weight == 1 in spmd.
        loss_weight /= jnp.sum(loss_weight)
      weighted_loss = loss * loss_weight
      # Fetch forward-updated vars, which often include batch norm vars, other
      # misc stats, etc.
      forward_updated_vars = model.updated_vars
      # Finally, fetch all the summary tensors.
      summary_tensors = base_layer.all_summaries()
      if in_pmap:
        summary_tensors = summary_utils.aggregate_per_replica_summaries(
            summary_tensors, data_parallel_axis_name)
    if fprop_dtype == jnp.bfloat16 and weighted_loss.dtype == fprop_dtype:
      weighted_loss = weighted_loss.astype(jnp.float32)
    return weighted_loss, (metrics, forward_updated_vars, summary_tensors,
                           per_example_output)
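In the full training step (Example #6 below), this loss function is consumed with jax.value_and_grad(..., has_aux=True): the scalar weighted_loss is differentiated with respect to mdl_vars, while the auxiliary tuple of metrics, forward-updated vars, summaries, and per-example output is passed through unchanged.

grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(weighted_loss, (metrics, fwd_updated_vars, fwd_summary_tensors,
                 per_example_out)), grads = grad_fn(states.mdl_vars, inputs)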
Example #3
    def fn_wrap(carry, xs_t):
        # The carry is augmented with three additional tensors (time_step,
        # prng_key, global_step) to make fn_wrap fully functional.
        # Start a new prng_key branch that also depends on the time step.
        prng_key_t = jax.random.fold_in(carry.prng_key, carry.time_step)
        with base_layer.JaxContext.new_context(prng_key=prng_key_t,
                                               global_step=carry.global_step):

            carry_new, ys_t = fn(carry, xs_t)
            carry_new.time_step = carry.time_step + 1
            # copy over prng_key and global_step
            carry_new.prng_key = carry.prng_key
            carry_new.global_step = carry.global_step

            tf.nest.assert_same_structure(carry_new, carry)
            summaries = base_layer.all_summaries()

        return carry_new, (ys_t, summaries)
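The (carry, xs_t) -> (carry_new, (ys_t, summaries)) signature is exactly what jax.lax.scan expects. Below is a minimal sketch, under the assumption that carry0 is a NestedMap accepted by fn and that every leaf of xs has a leading time axis; the carry setup and variable names are illustrative, not the library's actual scan helper.

# Hypothetical sketch: augment the initial carry, then scan over the time axis.
carry0.time_step = jnp.array(0, dtype=jnp.int32)
carry0.prng_key = prng_key
carry0.global_step = global_step
final_carry, (ys, summaries) = jax.lax.scan(fn_wrap, carry0, xs)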
Example #4
        def Comp(theta, prng_key, global_step, inputs, paddings):
            with base_layer.JaxContext.new_context(
                    global_step=global_step, prng_key=prng_key) as jax_context:
                jax_context.bind(layer, layer.vars_to_flax_vars(theta),
                                 [base_layer.SCOPE_VARS])
                per_step_prng_key = jax.random.fold_in(prng_key, global_step)
                base_layer.reset_prng_key(per_step_prng_key, global_step)
                output = layer.fprop(inputs, paddings)
                forward_updated_theta = layer.updated_vars

                def UpdateParam(old, new):
                    if new is not None:
                        return new
                    else:
                        return old

                # Get the new variables.
                new_theta = tf.nest.map_structure(UpdateParam, theta,
                                                  forward_updated_theta)
                # Fetch summaries.
                summaries = base_layer.all_summaries()

                return new_theta, output, summaries
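Comp is a pure function of its arguments, so it can be jit-compiled and invoked directly. The call below is a hedged usage sketch; layer_theta, inputs, and paddings are placeholder names, and passing the step as an int32 scalar is an assumption.

# Hypothetical usage sketch; argument names are illustrative only.
comp_fn = jax.jit(Comp)
new_theta, output, summaries = comp_fn(
    layer_theta, jax.random.PRNGKey(1234),
    jnp.array(0, dtype=jnp.int32), inputs, paddings)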
Example #5
def eval_step_single_learner(
    jax_task: base_task.SingleTask,
    mdl_vars: NestedJTensor,
    prng_key: JTensor,
    global_step: JTensor,
    inputs: Union[JTensor, NestedMap],
    data_parallel_axis_name: Optional[str] = 'batch',
    fprop_dtype: jnp.dtype = jnp.float32) -> Tuple[Any, Any, Any, SummaryDict]:
  """Evaluates a model for a single step.

  This utility is specialized for the single learner case.

  Args:
    jax_task: An instance of base_task.SingleTask.
    mdl_vars: model variables to be used during eval.
    prng_key: A prng seed, of shape [2], of type np.uint32.
    global_step: A global step tensor indicating how many steps a model has been
      trained.
    inputs: Inputs to the mdl.fprop() function.
    data_parallel_axis_name: a string, the device axis to aggregate metrics
      and summaries over.
    fprop_dtype: fprop datatype, can be either jnp.float32 or jnp.bfloat16.

  Returns:
    A tuple of the following elements.
    loss - loss as computed by mdl.fprop.
    mean_metrics - a dict of metrics. Each element of the dict is a pair
      (metric, weight).
    per_example_out - auxiliary per-example output as computed in mdl.fprop.
    summary_tensors - A nested map or dict of summary tensors computed in the
      forward pass.
  """
  context_p = base_layer.JaxContext.Params().Set(do_eval=True)
  # Fold in global_step as part of the random seed key, so that random
  # numbers depend on the global step.
  prng_key = jax.random.fold_in(prng_key, global_step)
  model = jax_task.model

  if fprop_dtype == jnp.float32:
    pass
  elif fprop_dtype == jnp.bfloat16:
    mdl_vars = jax.tree_map(_maybe_to_bfloat16, mdl_vars)
    inputs = jax.tree_map(_maybe_to_bfloat16, inputs)
  else:
    raise NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

  with base_layer.JaxContext.new_context(
      params=context_p, prng_key=prng_key,
      global_step=global_step) as jax_context:
    # Prepares mdl for fprop. This clears all forward-updated vars that are
    # kept locally in mdl.
    jax_context.bind(model, model.vars_to_flax_vars(mdl_vars),
                     [base_layer.SCOPE_AUX_LOSS])

    # Only a single learner is supported for now.
    assert len(jax_task.learners) == 1
    learner = jax_task.learners[0]

    metrics, per_example_out = model.fprop(inputs)
    loss_name = learner.loss_name
    assert loss_name in metrics
    loss, loss_weight = metrics[loss_name]
    assert loss.ndim == 0, 'loss has to be a scalar.'
    assert loss_weight.ndim == 0, 'loss_weight has to be a scalar.'

    summary_tensors = base_layer.all_summaries()

    if data_parallel_axis_name:
      # This is simple data-parallel evaluation.
      # Renormalize loss weight by the total weight across all replicas.
      sum_loss = jax.lax.psum(
          loss * loss_weight, axis_name=data_parallel_axis_name)
      sum_loss_weight = jax.lax.psum(
          loss_weight, axis_name=data_parallel_axis_name)
      mean_loss = sum_loss / (sum_loss_weight + 1e-8)

      mean_metrics = type(metrics)()
      for key in metrics:
        value, weight = metrics[key]
        sum_value = jax.lax.psum(
            value * weight, axis_name=data_parallel_axis_name)
        sum_weight = jax.lax.psum(weight, axis_name=data_parallel_axis_name)
        mean_metrics[key] = (sum_value / (sum_weight + 1e-8), sum_weight)

      summary_tensors = summary_utils.aggregate_per_replica_summaries(
          summary_tensors, data_parallel_axis_name)
    else:
      # No data_parallel_axis_name is specified; most likely this is
      # evaluating an SPMD model.
      mean_metrics = metrics
      mean_loss = loss

  def _maybe_to_float32(x):
    if x.dtype == jnp.bfloat16:
      return x.astype(jnp.float32)
    else:
      return x

  if fprop_dtype == jnp.bfloat16:
    mean_loss, mean_metrics, per_example_out, summary_tensors = jax.tree_map(
        _maybe_to_float32,
        (mean_loss, mean_metrics, per_example_out, summary_tensors))

  return mean_loss, mean_metrics, per_example_out, summary_tensors
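A hedged sketch of how this eval step might be wrapped for data-parallel evaluation with jax.pmap; binding jax_task via functools.partial and the replicated argument names below are assumptions, not the library's actual trainer code.

import functools

# Hypothetical pmap wrapper; every argument is assumed to carry a leading
# per-device axis, and axis_name matches the default data_parallel_axis_name.
p_eval_step = jax.pmap(
    functools.partial(eval_step_single_learner, jax_task),
    axis_name='batch')
mean_loss, mean_metrics, per_example_out, summary_tensors = p_eval_step(
    replicated_mdl_vars, per_replica_prng_keys, replicated_global_steps,
    per_replica_inputs)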
Example #6
def train_step_single_learner(
    jax_task: base_task.SingleTask,
    states: TrainState,
    prng_key: JTensor,
    inputs: Union[JTensor, NestedMap],
    data_parallel_axis_name: Optional[str] = 'batch',
    fprop_dtype: jnp.dtype = jnp.float32
) -> Tuple[TrainState, Any, Any, Any, SummaryDict]:
  """Trains a model for a single step.

  This function works for both pmap-ed and pjit-ed models. When this function
  is called from a pmap-ed trainer, data_parallel_axis_name has to be set.
  Otherwise, data_parallel_axis_name should be either an empty string or None.

  TODO(yonghui): Maybe refactor pmap and pjit into two functions.

  This utility is specialized for the single learner case.

  Args:
    jax_task: An instance of base_task.SingleTask.
    states: An instance of model.TrainState.
    prng_key: A PRNGKey, of shape [2], of type np.uint32.
    inputs: Inputs to the mdl.fprop() function.
    data_parallel_axis_name: a string, the device axis to aggregate gradients
      over.
    fprop_dtype: fprop datatype, can be either jnp.float32 or jnp.bfloat16.

  Returns:
    A tuple of the following elements.
    updated_states - updated states.
    loss - loss as computed by mdl.fprop.
    mean_metrics - a dict of metrics. Each element of the dict is a pair
      (metric, weight).
    per_example_out - auxiliary per-example output as computed in mdl.fprop.
    summary_tensors - A dict or nested map of summary tensors computed in the
      forward as well as the backward pass.
  """
  assert len(jax_task.learners) == 1
  learner = jax_task.learners[0]
  model = jax_task.model

  context_p = base_layer.JaxContext.Params().Set(do_eval=False)
  # Fold in global_step as part of the random seed key, so that random
  # numbers depend on the global step.
  prng_key = jax.random.fold_in(prng_key, states.step)

  if data_parallel_axis_name:
    in_pmap = True
  else:
    in_pmap = False

  prng_key, subkey = jax.random.split(prng_key)

  def _loss_fn(
      mdl_vars: NestedJTensor, inputs: NestedMap
  ) -> Tuple[JTensor, Tuple[Any, NestedMap, SummaryDict, SummaryDict]]:
    """Computes loss as well as other auxiliary outputs."""
    if fprop_dtype == jnp.float32:
      pass
    elif fprop_dtype == jnp.bfloat16:
      mdl_vars = jax.tree_map(_maybe_to_bfloat16, mdl_vars)
      inputs = jax.tree_map(_maybe_to_bfloat16, inputs)
    else:
      raise NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

    with base_layer.JaxContext.new_context(
        params=context_p, prng_key=subkey,
        global_step=states.step) as jax_context:
      jax_context.bind(model, model.vars_to_flax_vars(mdl_vars),
                       [base_layer.SCOPE_VARS, base_layer.SCOPE_AUX_LOSS])

      metrics, per_example_output = model.fprop(inputs)
      loss_name = learner.loss_name
      assert loss_name in metrics
      loss, loss_weight = metrics[loss_name]
      assert loss.ndim == 0, 'loss has to be a scalar.'
      assert loss_weight.ndim == 0, 'loss_weight has to be a scalar'
      loss_weight = jax.lax.stop_gradient(loss_weight)
      if in_pmap:
        # Renormalize loss weight by the total weight across all replicas.
        # This also takes care of dividing loss by num of data parallel
        # replicas.
        loss_weight /= jax.lax.psum(
            loss_weight, axis_name=data_parallel_axis_name)
      else:
        # loss_weight == 1 in spmd.
        loss_weight /= jnp.sum(loss_weight)
      weighted_loss = loss * loss_weight
      # Fetch forward-updated vars, which often include batch norm vars, other
      # misc stats, etc.
      forward_updated_vars = model.updated_vars
      # Finally, fetch all the summary tensors.
      summary_tensors = base_layer.all_summaries()
      if in_pmap:
        summary_tensors = summary_utils.aggregate_per_replica_summaries(
            summary_tensors, data_parallel_axis_name)
    if fprop_dtype == jnp.bfloat16 and weighted_loss.dtype == fprop_dtype:
      weighted_loss = weighted_loss.astype(jnp.float32)
    return weighted_loss, (metrics, forward_updated_vars, summary_tensors,
                           per_example_output)

  grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)

  (weighted_loss, (metrics, fwd_updated_vars, fwd_summary_tensors,
                   per_example_out)), grads = grad_fn(states.mdl_vars, inputs)

  if in_pmap:
    # Scale weighted_loss back after gradient computation.
    # This is the average loss across all replicas.
    weighted_loss = jax.lax.psum(
        weighted_loss, axis_name=data_parallel_axis_name)
  else:
    # No sum of loss over all the replicas.
    pass

  if in_pmap:
    mean_metrics = type(metrics)()
    for key in metrics:
      value, weight = metrics[key]
      sum_value = jax.lax.psum(
          value * weight, axis_name=data_parallel_axis_name)
      sum_weight = jax.lax.psum(weight, axis_name=data_parallel_axis_name)
      mean_metrics[key] = (sum_value / (sum_weight + 1e-8), sum_weight)
  else:
    # No aggregation is needed.
    mean_metrics = metrics

  var_weight_params = model.vars
  tf.nest.assert_same_structure(states.mdl_vars, var_weight_params)
  tf.nest.assert_same_structure(states.mdl_vars, grads)
  tf.nest.assert_same_structure(states.mdl_vars, fwd_updated_vars)
  var_is_learnable = tf.nest.map_structure(
      lambda x: not base_layer.var_not_trainable(x), var_weight_params)
  tf.nest.assert_same_structure(states.mdl_vars, var_is_learnable)

  def _maybe_zero_out_grad_fn(var_grad, var, var_learnable):
    if var_learnable:
      return var_grad
    else:
      return jnp.zeros_like(var)

  # Zero-out gradient for non-learnable vars.
  grads = tf.nest.map_structure(_maybe_zero_out_grad_fn, grads, states.mdl_vars,
                                var_is_learnable)

  if in_pmap:
    # Aggregate grads across different model replicas.
    grads = jax.lax.psum(grads, axis_name=data_parallel_axis_name)
  else:
    # No gradient aggregation is needed.
    pass

  # Carry out backward computation under a JaxContext.
  prng_key, subkey = jax.random.split(prng_key)
  with base_layer.JaxContext.new_context(
      params=context_p, prng_key=subkey,
      global_step=states.step) as jax_context:
    # Nothing is allowed to change, except for summaries.
    jax_context.bind(model, model.vars_to_flax_vars(states.mdl_vars),
                     [base_layer.SCOPE_AUX_LOSS])

    # Add a summary for learning rate
    learning_rate = learner.optimizer.get_learning_rate(states.step)
    base_layer.add_summary('learning_rate', learning_rate)

    # Apply gradient transformations.
    transformed_grads, new_opt_states = learner.update_states(
        grads, states.opt_states[0], states.mdl_vars, var_weight_params)

    # Gradient descent on learnable vars.
    mdl_vars = learner.apply_gradient(states.mdl_vars, transformed_grads,
                                      var_is_learnable)

    def _synchronize_vars_using_mean(new_var: NestedMap,
                                     old_var: NestedMap) -> NestedMap:
      """Synchronize variables across a replica by averaging."""
      delta = new_var - old_var
      delta_mean = jax.lax.pmean(delta, axis_name=data_parallel_axis_name)
      updated_var = old_var + delta_mean
      return updated_var

    def _update_non_learnable_var(old_var: NestedMap, new_var: NestedMap,
                                  var_params: ParamsT) -> NestedMap:
      """Update non-trainable variables, using cross-replica synchronization.

      Args:
        old_var: Nested map of old variables.
        new_var: Nested map of new variables.
        var_params: Nested map of var param attributes such as whether a
          variable is trainable or requires synchronization across replicas.

      Returns:
        Updated variables.

      Raises:
        ValueError if no synchronization method is provided for non-trainable
        variables.
      """
      if not base_layer.var_not_trainable(var_params):
        assert new_var is None
        return old_var
      elif not in_pmap:
        # No aggregation is needed.
        assert new_var is not None
        return new_var
      elif base_layer.var_requires_mean_sync(var_params):
        assert new_var is not None
        return _synchronize_vars_using_mean(new_var, old_var)
      else:
        raise ValueError('Non-trainable variables must have a cross-replica '
                         'synchronization method specified.')

    var_weight_params = model.vars
    tf.nest.assert_same_structure(mdl_vars, var_weight_params)
    mdl_vars = tf.nest.map_structure(_update_non_learnable_var, mdl_vars,
                                     fwd_updated_vars, var_weight_params)

    new_states = states.new_state(
        mdl_vars=mdl_vars, opt_states=[new_opt_states])
    # Finally fetch all backward summary tensors. We do not aggregate the scalar
    # summaries with pmean because the grads are already psum-ed.
    bwd_summary_tensors = base_layer.all_summaries()

  summary_tensors = NestedMap(
      fwd_summary_tensors=fwd_summary_tensors,
      bwd_summary_tensors=bwd_summary_tensors)

  return (new_states, weighted_loss, mean_metrics, per_example_out,
          summary_tensors)
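A corresponding hedged sketch for the training step under jax.pmap; donating the old TrainState buffers and the replicated argument names are assumptions for illustration only.

import functools

# Hypothetical pmap wrapper; states, prng keys, and inputs are assumed to
# carry a leading per-device axis. donate_argnums=(0,) lets XLA reuse the old
# TrainState buffers (argument 0 once jax_task is bound) for the new state.
p_train_step = jax.pmap(
    functools.partial(train_step_single_learner, jax_task),
    axis_name='batch',
    donate_argnums=(0,))
(new_states, weighted_loss, mean_metrics, per_example_out,
 summary_tensors) = p_train_step(replicated_states, per_replica_prng_keys,
                                 per_replica_inputs)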