def _wrapped_fn(theta, per_stage_inputs, *per_stage_args, **per_stage_kwargs):
  with base_layer.JaxContext.new_context(
      prng_key=prng_key, global_step=global_step) as jax_ctx:
    jax_ctx.bind(self.body, self.body.vars_to_flax_vars(theta),
                 [base_layer.SCOPE_AUX_LOSS])
    res = self.body.fprop(per_stage_inputs, *per_stage_args,
                          **per_stage_kwargs)
    summaries = base_layer.all_summaries()
    return res, summaries
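
# A per-stage wrapper like _wrapped_fn above is typically mapped over the
# stage dimension of stacked weights and inputs. The following is a minimal,
# self-contained sketch of that pattern using plain jax.vmap; the stage
# count, shapes, and the toy matmul body are assumptions for illustration
# only, not the library's actual pipelining code.
import jax
import jax.numpy as jnp

num_stages, batch, dim = 4, 8, 16
stacked_theta = jnp.ones((num_stages, dim, dim))     # [stage, dim, dim]
stacked_inputs = jnp.ones((num_stages, batch, dim))  # [stage, batch, dim]

def per_stage_fn(theta, inputs):
  # Stand-in for self.body.fprop(per_stage_inputs, ...).
  return jnp.tanh(inputs @ theta)

# vmap applies per_stage_fn independently to each stage's slice.
per_stage_out = jax.vmap(per_stage_fn)(stacked_theta, stacked_inputs)
print(per_stage_out.shape)  # (4, 8, 16)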
def fn_wrap(carry, xs_t):
  # carry is augmented with three additional tensors (time_step, prng_key and
  # global_step) to make fn_wrap fully functional.
  # Start a new prng_key branch that also depends on the time step.
  prng_key_t = jax.random.fold_in(carry.prng_key, carry.time_step)

  with base_layer.JaxContext.new_context(
      prng_key=prng_key_t, global_step=carry.global_step):
    carry_new, ys_t = fn(carry, xs_t)
    carry_new.time_step = carry.time_step + 1
    # Copy over prng_key and global_step unchanged.
    carry_new.prng_key = carry.prng_key
    carry_new.global_step = carry.global_step
    tf.nest.assert_same_structure(carry_new, carry)
    summaries = base_layer.all_summaries()
    return carry_new, (ys_t, summaries)
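
# fn_wrap keeps the scanned body purely functional by carrying the time step
# and PRNG key through the loop. A minimal sketch of the same pattern with a
# hypothetical plain-JAX carry (a dict) driven by jax.lax.scan; the dict
# fields, shapes, and toy noise body are assumptions for illustration.
import jax
import jax.numpy as jnp

def body(carry, x_t):
  # Derive a per-step key from the carried key and the time step, mirroring
  # the jax.random.fold_in call in fn_wrap above.
  step_key = jax.random.fold_in(carry['prng_key'], carry['time_step'])
  y_t = x_t + jax.random.normal(step_key, x_t.shape)  # toy per-step randomness
  new_carry = dict(
      time_step=carry['time_step'] + 1,  # advance the time step
      prng_key=carry['prng_key'])        # the key itself is carried unchanged
  return new_carry, y_t

xs = jnp.ones((5, 3))  # [time, dim], assumed toy inputs
init_carry = dict(time_step=jnp.array(0), prng_key=jax.random.PRNGKey(0))
final_carry, ys = jax.lax.scan(body, init_carry, xs)
print(final_carry['time_step'], ys.shape)  # 5 (5, 3)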
def Comp(theta, prng_key, global_step, inputs, paddings):
  with base_layer.JaxContext.new_context(
      global_step=global_step, prng_key=prng_key) as jax_context:
    jax_context.bind(layer, layer.vars_to_flax_vars(theta),
                     [base_layer.SCOPE_VARS])
    per_step_prng_key = jax.random.fold_in(prng_key, global_step)
    base_layer.reset_prng_key(per_step_prng_key, global_step)
    output = layer.fprop(inputs, paddings)
    forward_updated_theta = layer.updated_vars

    def UpdateParam(old, new):
      if new is not None:
        return new
      else:
        return old

    # Get the new variables.
    new_theta = tf.nest.map_structure(UpdateParam, theta,
                                      forward_updated_theta)
    # Fetch summaries.
    summaries = base_layer.all_summaries()
    return new_theta, output, summaries
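
# Comp is a pure function of (theta, prng_key, global_step, inputs, paddings),
# so it can be jit-compiled directly. A minimal usage sketch; `theta`,
# `global_step`, `inputs` and `paddings` are assumed to be arrays/pytrees
# compatible with the `layer` captured by Comp above.
import jax

comp_jit = jax.jit(Comp)
new_theta, output, summaries = comp_jit(
    theta, jax.random.PRNGKey(1234), global_step, inputs, paddings)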
def eval_step_single_learner(
    jax_task: base_task.SingleTask,
    mdl_vars: NestedJTensor,
    prng_key: JTensor,
    global_step: JTensor,
    inputs: Union[JTensor, NestedMap],
    data_parallel_axis_name: Optional[str] = 'batch',
    fprop_dtype: jnp.dtype = jnp.float32
) -> Tuple[Any, Any, Any, SummaryDict]:
  """Evaluates a model for a single step.

  This utility is specialized for the single learner case.

  Args:
    jax_task: An instance of base_task.SingleTask.
    mdl_vars: Model variables to be used during eval.
    prng_key: A prng seed, of shape [2], of type np.uint32.
    global_step: A global step tensor indicating how many steps the model has
      been trained for.
    inputs: Inputs to the mdl.fprop() function.
    data_parallel_axis_name: A string, the device axis to aggregate gradients
      over.
    fprop_dtype: Fprop datatype, either jnp.float32 or jnp.bfloat16.

  Returns:
    A tuple of the following elements.
    loss - loss as computed by mdl.fprop.
    mean_metrics - a dict of metrics. Each element of the dict is a pair
      (metric, weight).
    per_example_out - auxiliary per-example output as computed in mdl.fprop.
    summary_tensors - A nested map or dict of summary tensors computed in the
      forward as well as the backward pass.
  """
  context_p = base_layer.JaxContext.Params().Set(do_eval=True)
  # Fold in global_step as part of the random seed key, so that random
  # numbers depend on the global step.
  prng_key = jax.random.fold_in(prng_key, global_step)
  model = jax_task.model

  if fprop_dtype == jnp.float32:
    pass
  elif fprop_dtype == jnp.bfloat16:
    mdl_vars = jax.tree_map(_maybe_to_bfloat16, mdl_vars)
    inputs = jax.tree_map(_maybe_to_bfloat16, inputs)
  else:
    raise NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

  with base_layer.JaxContext.new_context(
      params=context_p, prng_key=prng_key,
      global_step=global_step) as jax_context:
    # Prepares mdl for fprop. This clears all forward-updated vars that are
    # kept locally in mdl.
    jax_context.bind(model, model.vars_to_flax_vars(mdl_vars),
                     [base_layer.SCOPE_AUX_LOSS])

    # Only a single learner is supported for now.
    assert len(jax_task.learners) == 1
    learner = jax_task.learners[0]

    metrics, per_example_out = model.fprop(inputs)
    loss_name = learner.loss_name
    assert loss_name in metrics
    loss, loss_weight = metrics[loss_name]
    assert loss.ndim == 0, 'loss has to be a scalar.'
    assert loss_weight.ndim == 0, 'loss_weight has to be a scalar.'

    summary_tensors = base_layer.all_summaries()

    if data_parallel_axis_name:
      # This is simple data-parallel training.
      # Renormalize loss weight by the total weight across all replicas.
      sum_loss = jax.lax.psum(
          loss * loss_weight, axis_name=data_parallel_axis_name)
      sum_loss_weight = jax.lax.psum(
          loss_weight, axis_name=data_parallel_axis_name)
      mean_loss = sum_loss / (sum_loss_weight + 1e-8)

      mean_metrics = type(metrics)()
      for key in metrics:
        value, weight = metrics[key]
        sum_value = jax.lax.psum(
            value * weight, axis_name=data_parallel_axis_name)
        sum_weight = jax.lax.psum(weight, axis_name=data_parallel_axis_name)
        mean_metrics[key] = (sum_value / (sum_weight + 1e-8), sum_weight)

      summary_tensors = summary_utils.aggregate_per_replica_summaries(
          summary_tensors, data_parallel_axis_name)
    else:
      # No data_parallel_axis_name is specified; most likely this is
      # evaluating an SPMD model.
      mean_metrics = metrics
      mean_loss = loss

  def _maybe_to_float32(x):
    if x.dtype == jnp.bfloat16:
      return x.astype(jnp.float32)
    else:
      return x

  if fprop_dtype == jnp.bfloat16:
    mean_loss, mean_metrics, per_example_out, summary_tensors = jax.tree_map(
        _maybe_to_float32,
        (mean_loss, mean_metrics, per_example_out, summary_tensors))

  return mean_loss, mean_metrics, per_example_out, summary_tensors
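
# For data-parallel eval, a step function like eval_step_single_learner is
# typically wrapped with jax.pmap so the psum collectives above resolve
# against the named axis. A minimal sketch; `jax_task`, per-replica-sharded
# `mdl_vars`, `prng_keys`, `global_steps` and `batched_inputs` (all with a
# leading device axis) are assumptions about the caller's setup.
import functools
import jax

p_eval_step = jax.pmap(
    functools.partial(
        eval_step_single_learner, jax_task,
        data_parallel_axis_name='batch'),
    axis_name='batch')
loss, metrics, per_example_out, summaries = p_eval_step(
    mdl_vars, prng_keys, global_steps, batched_inputs)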
def train_step_single_learner(
    jax_task: base_task.SingleTask,
    states: TrainState,
    prng_key: JTensor,
    inputs: Union[JTensor, NestedMap],
    data_parallel_axis_name: Optional[str] = 'batch',
    fprop_dtype: jnp.dtype = jnp.float32
) -> Tuple[TrainState, Any, Any, Any, SummaryDict]:
  """Trains a model for a single step.

  This function works for both a pmap-ed model and a pjit-ed model. When this
  function is called from a pmap-ed trainer, data_parallel_axis_name has to be
  set. Otherwise, data_parallel_axis_name should be either an empty string or
  None.

  TODO(yonghui): Maybe refactor pmap and pjit into two functions.

  This utility is specialized for the single learner case.

  Args:
    jax_task: An instance of base_task.SingleTask.
    states: An instance of model.TrainState.
    prng_key: A PRNGKey, of shape [2], of type np.uint32.
    inputs: Inputs to the mdl.fprop() function.
    data_parallel_axis_name: A string, the device axis to aggregate gradients
      over.
    fprop_dtype: Fprop datatype, either jnp.float32 or jnp.bfloat16.

  Returns:
    A tuple of the following elements.
    updated_states - updated states.
    loss - loss as computed by mdl.fprop.
    mean_metrics - a dict of metrics. Each element of the dict is a pair
      (metric, weight).
    per_example_out - auxiliary per-example output as computed in mdl.fprop.
    summary_tensors - A dict or nested map of summary tensors computed in the
      forward as well as the backward pass.
  """
  assert len(jax_task.learners) == 1
  learner = jax_task.learners[0]
  model = jax_task.model

  context_p = base_layer.JaxContext.Params().Set(do_eval=False)
  # Fold in global_step as part of the random seed key, so that random
  # numbers depend on the global step.
  prng_key = jax.random.fold_in(prng_key, states.step)

  if data_parallel_axis_name:
    in_pmap = True
  else:
    in_pmap = False

  prng_key, subkey = jax.random.split(prng_key)

  def _loss_fn(
      mdl_vars: NestedJTensor, inputs: NestedMap
  ) -> Tuple[JTensor, Tuple[Any, NestedMap, SummaryDict, SummaryDict]]:
    """Computes the loss as well as other auxiliary outputs."""
    if fprop_dtype == jnp.float32:
      pass
    elif fprop_dtype == jnp.bfloat16:
      mdl_vars = jax.tree_map(_maybe_to_bfloat16, mdl_vars)
      inputs = jax.tree_map(_maybe_to_bfloat16, inputs)
    else:
      raise NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

    with base_layer.JaxContext.new_context(
        params=context_p, prng_key=subkey,
        global_step=states.step) as jax_context:
      jax_context.bind(model, model.vars_to_flax_vars(mdl_vars),
                       [base_layer.SCOPE_VARS, base_layer.SCOPE_AUX_LOSS])

      metrics, per_example_output = model.fprop(inputs)
      loss_name = learner.loss_name
      assert loss_name in metrics
      loss, loss_weight = metrics[loss_name]
      assert loss.ndim == 0, 'loss has to be a scalar.'
      assert loss_weight.ndim == 0, 'loss_weight has to be a scalar.'
      loss_weight = jax.lax.stop_gradient(loss_weight)

      if in_pmap:
        # Renormalize loss weight by the total weight across all replicas.
        # This also takes care of dividing the loss by the number of data
        # parallel replicas.
        loss_weight /= jax.lax.psum(
            loss_weight, axis_name=data_parallel_axis_name)
      else:
        # loss_weight == 1 in spmd.
        loss_weight /= jnp.sum(loss_weight)

      weighted_loss = loss * loss_weight
      # Fetch forward-updated vars, which often include batch norm vars and
      # other misc stats.
      forward_updated_vars = model.updated_vars
      # Finally, fetch all the summary tensors.
      summary_tensors = base_layer.all_summaries()
      if in_pmap:
        summary_tensors = summary_utils.aggregate_per_replica_summaries(
            summary_tensors, data_parallel_axis_name)

    if fprop_dtype == jnp.bfloat16 and weighted_loss.dtype == fprop_dtype:
      weighted_loss = weighted_loss.astype(jnp.float32)
    return weighted_loss, (metrics, forward_updated_vars, summary_tensors,
                           per_example_output)

  grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)

  (weighted_loss, (metrics, fwd_updated_vars, fwd_summary_tensors,
                   per_example_out)), grads = grad_fn(states.mdl_vars, inputs)

  if in_pmap:
    # Scale weighted_loss back after gradient computation.
    # This is the average loss across all replicas.
    weighted_loss = jax.lax.psum(
        weighted_loss, axis_name=data_parallel_axis_name)
  else:
    # No sum of loss over all the replicas.
    pass

  if in_pmap:
    mean_metrics = type(metrics)()
    for key in metrics:
      value, weight = metrics[key]
      sum_value = jax.lax.psum(
          value * weight, axis_name=data_parallel_axis_name)
      sum_weight = jax.lax.psum(weight, axis_name=data_parallel_axis_name)
      mean_metrics[key] = (sum_value / (sum_weight + 1e-8), sum_weight)
  else:
    # No aggregation is needed.
    mean_metrics = metrics

  var_weight_params = model.vars
  tf.nest.assert_same_structure(states.mdl_vars, var_weight_params)
  tf.nest.assert_same_structure(states.mdl_vars, grads)
  tf.nest.assert_same_structure(states.mdl_vars, fwd_updated_vars)
  var_is_learnable = tf.nest.map_structure(
      lambda x: not base_layer.var_not_trainable(x), var_weight_params)
  tf.nest.assert_same_structure(states.mdl_vars, var_is_learnable)

  def _maybe_zero_out_grad_fn(var_grad, var, var_learnable):
    if var_learnable:
      return var_grad
    else:
      return jnp.zeros_like(var)

  # Zero out gradients for non-learnable vars.
  grads = tf.nest.map_structure(_maybe_zero_out_grad_fn, grads,
                                states.mdl_vars, var_is_learnable)

  if in_pmap:
    # Aggregate grads across different model replicas.
    grads = jax.lax.psum(grads, axis_name=data_parallel_axis_name)
  else:
    # No gradient aggregation is needed.
    pass

  # Carry out the backward computation under a JaxContext.
  prng_key, subkey = jax.random.split(prng_key)
  with base_layer.JaxContext.new_context(
      params=context_p, prng_key=subkey,
      global_step=states.step) as jax_context:
    # Nothing is allowed to change, except for summaries.
    jax_context.bind(model, model.vars_to_flax_vars(states.mdl_vars),
                     [base_layer.SCOPE_AUX_LOSS])

    # Add a summary for the learning rate.
    learning_rate = learner.optimizer.get_learning_rate(states.step)
    base_layer.add_summary('learning_rate', learning_rate)

    # Apply gradient transformations.
    transformed_grads, new_opt_states = learner.update_states(
        grads, states.opt_states[0], states.mdl_vars, var_weight_params)

    # Gradient descent on learnable vars.
    mdl_vars = learner.apply_gradient(states.mdl_vars, transformed_grads,
                                      var_is_learnable)

    def _synchronize_vars_using_mean(new_var: NestedMap,
                                     old_var: NestedMap) -> NestedMap:
      """Synchronizes variables across replicas by averaging."""
      delta = new_var - old_var
      delta_mean = jax.lax.pmean(delta, axis_name=data_parallel_axis_name)
      updated_var = old_var + delta_mean
      return updated_var

    def _update_non_learnable_var(old_var: NestedMap, new_var: NestedMap,
                                  var_params: ParamsT) -> NestedMap:
      """Updates non-trainable variables, using cross-replica synchronization.

      Args:
        old_var: Nested map of old variables.
        new_var: Nested map of new variables.
        var_params: Nested map of var param attributes, such as whether a
          variable is trainable or requires synchronization across replicas.

      Returns:
        Updated variables.

      Raises:
        ValueError: If no synchronization method is provided for non-trainable
          variables.
      """
      if not base_layer.var_not_trainable(var_params):
        assert new_var is None
        return old_var
      elif not in_pmap:
        # No aggregation is needed.
        assert new_var is not None
        return new_var
      elif base_layer.var_requires_mean_sync(var_params):
        assert new_var is not None
        return _synchronize_vars_using_mean(new_var, old_var)
      else:
        raise ValueError('Non-trainable variables must have a cross-replica '
                         'synchronization method specified.')

    var_weight_params = model.vars
    tf.nest.assert_same_structure(mdl_vars, var_weight_params)
    mdl_vars = tf.nest.map_structure(_update_non_learnable_var, mdl_vars,
                                     fwd_updated_vars, var_weight_params)

    new_states = states.new_state(
        mdl_vars=mdl_vars, opt_states=[new_opt_states])
    # Finally, fetch all backward summary tensors. We do not aggregate the
    # scalar summaries with pmean because the grads are already psum-ed.
    bwd_summary_tensors = base_layer.all_summaries()

  summary_tensors = NestedMap(
      fwd_summary_tensors=fwd_summary_tensors,
      bwd_summary_tensors=bwd_summary_tensors)

  return (new_states, weighted_loss, mean_metrics, per_example_out,
          summary_tensors)
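
# In the pmap case, train_step_single_learner would typically be compiled
# once with jax.pmap, with the axis name matching data_parallel_axis_name.
# A minimal sketch; `jax_task`, replicated `train_states`, per-replica
# `prng_keys` and `batched_inputs` (all with a leading device axis) are
# assumptions about the caller's setup. donate_argnums lets XLA reuse the old
# state buffers for the updated state.
import functools
import jax

p_train_step = jax.pmap(
    functools.partial(
        train_step_single_learner, jax_task,
        data_parallel_axis_name='batch'),
    axis_name='batch',
    donate_argnums=(0,))
train_states, loss, metrics, per_example_out, summaries = p_train_step(
    train_states, prng_keys, batched_inputs)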