Example #1
def shard_on_batch_dim_partition_spec(
    mesh_names: Sequence[str], x: jax.ShapeDtypeStruct) -> pjit.PartitionSpec:
  """Fully shards x on the batch dimension."""
  x_dim = len(x.shape)
  assert x_dim >= 1
  sharding = [-1] * x_dim  # -1 marks a dim as replicated (not sharded).
  # Assume the first dim is batch, and fully shard the batch dim over the entire
  # mesh.
  sharding[0] = tuple(mesh_names)
  return base_layer.to_partition_spec(sharding, mesh_names)
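
A minimal usage sketch (the mesh axis names and input shapes below are hypothetical, not taken from the source): the function is typically mapped over a nested structure of jax.ShapeDtypeStruct to build per-input partition specs.

import functools
import jax
import jax.numpy as jnp

# Hypothetical mesh axis names and input shapes.
mesh_names = ('replica', 'data')
inputs_shape = {
    'ids': jax.ShapeDtypeStruct(shape=(128, 256), dtype=jnp.int32),
    'paddings': jax.ShapeDtypeStruct(shape=(128, 256), dtype=jnp.float32),
}
# Shard only the leading (batch) dim of every input over the whole mesh.
inputs_partition_spec = jax.tree_map(
    functools.partial(shard_on_batch_dim_partition_spec, mesh_names),
    inputs_shape)
# Expected result per 2-d input: PartitionSpec(('replica', 'data'), None).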
Example #2
      def _gather(xs):

        def _gather_one(x, i):
          return x[i]

        if p.mesh_axis_names is not None:
          # When the stage dim is partitioned, we use xmap (with manual sharding
          # implementation) to make sure it is trivially partitioned on the
          # stage dim and to work around some potential optimization problems
          # in XLA.
          # TODO(yuanzx): Use xmap on the whole body fprop.
          mesh_axis = base_layer.to_partition_spec(
              p.weight_split_dims_mapping.stages, p.mesh_axis_names)[0]
          if mesh_axis is not None:
            axis_resources = {'num_stages': mesh_axis}
            return maps.xmap(
                _gather_one,
                # broadcast_inputs are replicated across stages, but IDs are
                # per-stage.
                in_axes=([...], ['num_stages', ...]),
                out_axes=['num_stages', ...],
                axis_resources=axis_resources)(xs, microbatch_ids)
        return xs[microbatch_ids]
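
The fallback `xs[microbatch_ids]` and the xmap path compute the same per-stage gather: `xs` is replicated across stages while `microbatch_ids` carries one index per stage, so the result gains a leading stage dim. A minimal sketch of those semantics with plain vmap and indexing (hypothetical sizes, no mesh partitioning):

import jax
import jax.numpy as jnp

# Hypothetical sizes.
num_stages, num_microbatches, feature = 4, 8, 16
xs = jnp.arange(num_microbatches * feature, dtype=jnp.float32).reshape(
    num_microbatches, feature)
# One microbatch index per stage.
microbatch_ids = (jnp.arange(num_stages) + 1) % num_microbatches

# Stage i picks xs[microbatch_ids[i]]; vmap over the stage dim matches the
# advanced-indexing fallback.
gathered = jax.vmap(lambda i: xs[i])(microbatch_ids)
assert gathered.shape == (num_stages, feature)
assert bool((gathered == xs[microbatch_ids]).all())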
Example #3
def infer_partition_spec_based_on_rank_fn(
    mapping_dict: Dict[str, base_layer.SplitDimsMapping],
    mesh_names: Sequence[str],
    x: JTensor,
) -> Optional[pjit.PartitionSpec]:
  """Infers PartitionSpec of input from the rank of corresponding JTensors.

  Args:
    mapping_dict: Dictionary which contains the split mapping for different
      shapes. For an n-d shape, it must have an entry f'map_{n}d' which tells
      us how to partition tensors of that rank.
    mesh_names: List of mesh axis names.
    x: JTensor to shard.

  Returns:
    PartitionSpec or None (if everything is replicated).
  """
  key = f'map_{len(x.shape)}d'
  if key not in mapping_dict:
    raise ValueError(f'Split mapping must be provided for {len(x.shape)}-d '
                     f'tensors in the form of key {key} in '
                     f'{mapping_dict}.')
  if mapping_dict[key] is not None:
    return base_layer.to_partition_spec(mapping_dict[key], mesh_names)
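
A minimal usage sketch, with a hypothetical mapping_dict and mesh (not taken from the source); in the split dims mapping, -1 marks a replicated dim and a None entry for a rank means the whole tensor stays replicated.

import jax.numpy as jnp

# Hypothetical split mappings and mesh axis names.
mesh_names = ('replica', 'data')
inputs_split_mapping = {
    'map_0d': None,                        # scalars stay fully replicated
    'map_1d': (('replica', 'data'),),      # 1-d: shard the only dim
    'map_2d': (('replica', 'data'), -1),   # 2-d: shard batch, replicate rest
}
x = jnp.zeros((128, 256), dtype=jnp.float32)
spec = infer_partition_spec_based_on_rank_fn(inputs_split_mapping, mesh_names, x)
# Expected: PartitionSpec(('replica', 'data'), None).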
Example #4
def get_partitioned_spmd_model_decode_fn(jax_task, init_key,
                                         partitioned_train_state,
                                         train_state_partition_specs,
                                         inputs_shape: NestedShapeDtypeStruct):
  """Return sharded decode step function and input partition spec.

  Args:
    jax_task: Task instance.
    init_key: PRNGKey for initializing the model variables.
    partitioned_train_state: TrainState that holds all the variables.
    train_state_partition_specs: A TrainState that contains the PartitionSpecs
      for all the variables.
    inputs_shape: Shape of the inputs for use in pjit sharding.

  Returns:
    (decode_step_fn, inputs_partition_spec):
    The decode step function, and input partition spec.
  """
  task_p = jax_task.params
  mesh_names = task_p.model.mesh_axis_names
  model = jax_task.model

  # Compute inputs PartitionSpec from inputs_shape
  inputs_partition_spec_fn = functools.partial(
      shard_on_batch_dim_partition_spec, mesh_names)
  reshard_inputs_fn = functools.partial(reshard_input_based_on_rank_fn,
                                        task_p.train.inputs_split_mapping,
                                        mesh_names)

  inputs_partition_spec = tf.nest.map_structure(inputs_partition_spec_fn,
                                                inputs_shape)

  # TODO(b/198356509): Fix this so that prng_key is no longer replicated, as
  # we want each core to not have identical random behavior.
  prng_key_partition_spec = base_layer.to_partition_spec((None,), mesh_names)

  eval_fn_in_partition_specs = (train_state_partition_specs.mdl_vars,
                                prng_key_partition_spec,
                                train_state_partition_specs.step,
                                inputs_partition_spec)
  train_state_unpadded_shapes = jax_task.create_train_state_unpadded_shapes(
      model.vars, is_eval=True)
  def _decode_step(mdl_vars, prng_key, global_step, inputs):
    inputs = jax.tree_map(reshard_inputs_fn, inputs)
    mdl_vars = jax.tree_map(_maybe_slice_uneven_sharding, mdl_vars,
                            train_state_partition_specs.mdl_vars,
                            train_state_unpadded_shapes.mdl_vars)
    # Right now we only pad the vars, and decode doesn't output vars so we do
    # not need to pad at the end.
    return decode_step(
        model,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        fprop_dtype=task_p.model.fprop_dtype)

  decode_out_shapes = jax.eval_shape(_decode_step,
                                     partitioned_train_state.mdl_vars, init_key,
                                     partitioned_train_state.step, inputs_shape)

  # Decoder outputs are always replicated at the moment.
  decode_fn_out_partition_specs = tf.nest.map_structure(lambda _: None,
                                                        decode_out_shapes)
  decode_step_fn = pjit.pjit(
      _decode_step,
      in_axis_resources=eval_fn_in_partition_specs,
      out_axis_resources=decode_fn_out_partition_specs)

  return decode_step_fn, inputs_partition_spec
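
A sketch of a possible call site (the mesh construction and the jax_task, init_key, and decode_inputs values below are hypothetical): pjit-compiled functions such as the returned decode_step_fn must run inside a Mesh context whose axis names match the partition specs.

import numpy as np
import jax
from jax.experimental import maps

# Hypothetical: lay out all devices on a mesh matching
# task_p.model.mesh_axis_names (e.g. a 1 x num_devices grid for two axes).
device_mesh = np.array(jax.devices()).reshape(1, -1)
with maps.Mesh(device_mesh, jax_task.params.model.mesh_axis_names):
  decode_step_fn, inputs_partition_spec = get_partitioned_spmd_model_decode_fn(
      jax_task, init_key, partitioned_train_state, train_state_partition_specs,
      inputs_shape)
  decode_out = decode_step_fn(partitioned_train_state.mdl_vars, init_key,
                              partitioned_train_state.step, decode_inputs)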
Example #5
def partition_spmd_model(
    task_p: InstantiableParams,
    init_key: PRNGKey,
    inputs_shape: NestedShapeDtypeStruct,
) -> Tuple[TrainState, TrainState, TrainState, TrainStepFn, EvalStepFn, int]:
  """Setup the SPMD model and return sharded train and eval step function.

  For partitioning inputs, it is assumed that `task_p.train` has a field
  `inputs_split_mapping` which contains keys `map_1d`, `map_2d`, ..., etc.,
  specifying how to shard inputs of the corresponding rank.

  Args:
    task_p: Task parameters of type NestedMap.
    init_key: PRNGKey for initializing the model variables.
    inputs_shape: Shape of the inputs for use in pjit sharding.

  Returns:
    (partitioned_train_state, train_state_partition_specs,
    inputs_partition_spec, train_step_fn, eval_step_fn, total_num_params):
    The partitioned TrainState, the corresponding partitioned TrainState specs,
    the partition spec for the inputs, the train step function, eval step
    function and total number of parameters.
  """
  model_p = task_p.model
  mesh_names = model_p.mesh_axis_names
  jax_task = task_p.Instantiate()
  model = jax_task.model

  reshard_inputs_fn = functools.partial(reshard_input_based_on_rank_fn,
                                        task_p.train.inputs_split_mapping,
                                        model_p.mesh_axis_names)
  inputs_partition_spec = get_input_partition_specs(model_p.mesh_axis_names,
                                                    inputs_shape)

  # Initialize the partitioned vars.
  (train_state_partition_specs, var_padded_shapes, var_shapes,
   partitioned_train_state) = (
       initialize_partitioned_model_states(jax_task, init_key))
  total_num_params = model.total_num_vars

  # TODO(bf-jax): prng_key is replicated. Would this be a problem?
  prng_key_partition_spec = base_layer.to_partition_spec((None,), mesh_names)

  def _train_step(state, prng_key, inputs):
    # Reshard inputs.
    inputs = jax.tree_map(reshard_inputs_fn, inputs)
    # Vars are padded at program entry/exit to avoid uneven sharding. We slice
    # the vars to remove padding before the step computation, and pad them after
    # the step computation to make user code independent of paddings. Internal
    # uneven sharding in the step computation is supported by XLA.
    state = jax.tree_map(_maybe_slice_uneven_sharding, state,
                         train_state_partition_specs, var_shapes)

    def _maybe_pad(x, pspec, shape):
      return _maybe_pad_uneven_sharding(x, pspec, shape,
                                        model_p.device_mesh.shape,
                                        model_p.mesh_axis_names)

    (new_states, weighted_loss, mean_metrics, per_example_out,
     summary_tensors) = train_step_single_learner(
         jax_task,
         state,
         prng_key,
         inputs,
         data_parallel_axis_name=None,
         fprop_dtype=model_p.fprop_dtype)
    new_states = jax.tree_map(_maybe_pad, new_states,
                              train_state_partition_specs, var_shapes)
    return (new_states, weighted_loss, mean_metrics, per_example_out,
            summary_tensors)

  def _eval_step(mdl_vars, prng_key, global_step, inputs):
    # Reshard inputs.
    inputs = jax.tree_map(reshard_inputs_fn, inputs)
    mdl_vars = jax.tree_map(_maybe_slice_uneven_sharding, mdl_vars,
                            train_state_partition_specs.mdl_vars,
                            var_shapes.mdl_vars)
    # Right now we only pad the vars, and eval doesn't output vars so we do not
    # need to pad at the end.
    return eval_step_single_learner(
        jax_task,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        data_parallel_axis_name=None,
        fprop_dtype=model_p.fprop_dtype)

  train_out_padded_shapes = jax.eval_shape(_train_step, var_padded_shapes,
                                           init_key, inputs_shape)

  eval_out_padded_shapes = jax.eval_shape(_eval_step,
                                          var_padded_shapes.mdl_vars, init_key,
                                          var_padded_shapes.step, inputs_shape)

  def _partition_spec_from_shape(x_shape):
    # Currently, all the outputs are fully replicated.
    # TODO(yonghui): Somehow fetch the output sharding spec from _eval_step fn.
    del x_shape
    return None

  train_fn_in_partition_specs = (train_state_partition_specs,
                                 prng_key_partition_spec, inputs_partition_spec)
  train_fn_out_replicated_specs = tf.nest.map_structure(
      _partition_spec_from_shape, train_out_padded_shapes)
  # Here we assume the first output is the train state.
  # Except for the first train_state output, the other outputs are explicitly
  # replicated.
  train_fn_out_partition_specs = tuple([train_state_partition_specs] +
                                       list(train_fn_out_replicated_specs[1:]))
  tf.nest.assert_same_structure(train_fn_out_replicated_specs,
                                train_fn_out_partition_specs)
  tf.nest.assert_same_structure(train_fn_out_partition_specs,
                                train_out_padded_shapes)

  eval_fn_in_partition_specs = (train_state_partition_specs.mdl_vars,
                                prng_key_partition_spec,
                                train_state_partition_specs.step,
                                inputs_partition_spec)
  eval_fn_out_partition_specs = tf.nest.map_structure(
      _partition_spec_from_shape, eval_out_padded_shapes)

  # pjit-ed train step function.
  train_step = pjit.pjit(
      _train_step,
      in_axis_resources=train_fn_in_partition_specs,
      out_axis_resources=train_fn_out_partition_specs,
      donate_argnums=(0,))

  # pjit-ed eval step function.
  eval_step = pjit.pjit(
      _eval_step,
      in_axis_resources=eval_fn_in_partition_specs,
      out_axis_resources=eval_fn_out_partition_specs)

  return (partitioned_train_state, train_state_partition_specs,
          inputs_partition_spec, train_step, eval_step, total_num_params)
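
The pad/slice pair around the step functions keeps user code independent of padding added for uneven sharding. A minimal sketch of the idea, using hypothetical helper names (the real _maybe_pad_uneven_sharding and _maybe_slice_uneven_sharding additionally take partition specs and the device mesh shape into account): pad each sharded dim up to a multiple of its shard count at program boundaries, and slice the padding off before the step computation.

import jax.numpy as jnp

def pad_to_even_sharding(x, shards_per_dim):
  """Hypothetical: pads each dim of x up to a multiple of its shard count."""
  pads = []
  for size, shards in zip(x.shape, shards_per_dim):
    padded_size = -(-size // shards) * shards  # ceil to a multiple of shards
    pads.append((0, padded_size - size))
  return jnp.pad(x, pads)

def slice_to_unpadded(x, unpadded_shape):
  """Hypothetical: slices a padded var back to its original shape."""
  return x[tuple(slice(0, s) for s in unpadded_shape)]

x = jnp.ones((10, 3))  # 10 rows split over 4 shards is uneven
padded = pad_to_even_sharding(x, shards_per_dim=(4, 1))
assert padded.shape == (12, 3)
assert slice_to_unpadded(padded, x.shape).shape == (10, 3)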