def shard_on_batch_dim_partition_spec(
    mesh_names: Sequence[str],
    x: jax.ShapeDtypeStruct) -> pjit.PartitionSpec:
  """Fully shards x on the batch dimension."""
  x_dim = len(x.shape)
  assert x_dim >= 1
  sharding = [-1] * x_dim
  # Assume the first dim is batch, and fully shard the batch dim over the
  # entire mesh.
  sharding[0] = tuple(mesh_names)
  return base_layer.to_partition_spec(sharding, mesh_names)
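

# A minimal usage sketch (hypothetical shapes and mesh names, not part of the
# library): for a rank-3 batch and a mesh with axes ('replica', 'data', 'mdl'),
# the returned spec shards dim 0 over the whole mesh and replicates the rest,
# i.e. roughly PartitionSpec(('replica', 'data', 'mdl'), None, None).
def _example_batch_dim_partition_spec():  # illustration only
  mesh_names = ('replica', 'data', 'mdl')  # assumed mesh axis names
  x = jax.ShapeDtypeStruct(shape=(128, 256, 1024), dtype=jax.numpy.float32)
  return shard_on_batch_dim_partition_spec(mesh_names, x)
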
def _gather(xs):

  def _gather_one(x, i):
    return x[i]

  if p.mesh_axis_names is not None:
    # When the stage dim is partitioned, we use xmap (with manual sharding
    # implementation) to make sure it's trivially partitioned on the stage
    # dim and work around some potential optimization problems in XLA.
    # TODO(yuanzx): Use xmap on the whole body fprop.
    mesh_axis = base_layer.to_partition_spec(
        p.weight_split_dims_mapping.stages, p.mesh_axis_names)[0]
    if mesh_axis is not None:
      axis_resources = {'num_stages': mesh_axis}
      return maps.xmap(
          _gather_one,
          # broadcast_inputs are replicated across stages, but IDs are
          # per-stage.
          in_axes=([...], ['num_stages', ...]),
          out_axes=['num_stages', ...],
          axis_resources=axis_resources)(xs, microbatch_ids)
  return xs[microbatch_ids]
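

# Illustrative sketch (assumes `xs` stacks microbatches along dim 0 and
# `microbatch_ids` holds one index per pipeline stage; the helper below is
# hypothetical). Ignoring the manual stage sharding, the xmap call above is
# semantically a per-stage gather, equivalent to fancy indexing or a vmap over
# the stage dim:
def _example_per_stage_gather(xs, microbatch_ids):  # illustration only
  gathered = xs[microbatch_ids]  # what the fallback path does
  per_stage = jax.vmap(lambda i: xs[i])(microbatch_ids)  # same gather, stage by stage
  return gathered, per_stage
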
def infer_partition_spec_based_on_rank_fn(
    mapping_dict: Dict[str, base_layer.SplitDimsMapping],
    mesh_names: Sequence[str],
    x: JTensor,
) -> Optional[pjit.PartitionSpec]:
  """Infers the PartitionSpec of an input from the rank of the JTensor.

  Args:
    mapping_dict: Dictionary which contains the split mapping for different
      shapes. For an n-d shape, it must have an entry f'map_{n}d' which tells
      us how to partition tensors of that rank.
    mesh_names: List of mesh axis names.
    x: JTensor to shard.

  Returns:
    PartitionSpec or None (if everything is replicated).
  """
  key = f'map_{len(x.shape)}d'
  if key not in mapping_dict:
    raise ValueError(f'Split mapping must be provided for {len(x.shape)}-d '
                     f'shapes in the form of key map_{len(x.shape)}d in '
                     f'{mapping_dict}.')
  if mapping_dict[key] is not None:
    return base_layer.to_partition_spec(mapping_dict[key], mesh_names)
  return None
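

# Usage sketch with hypothetical mapping values: a rank-2 input selects the
# 'map_2d' entry, which here shards dim 0 over both mesh axes and replicates
# dim 1 (-1 means replicated), giving roughly
# PartitionSpec(('replica', 'data'), None). A None entry for a rank means the
# whole tensor is replicated.
def _example_infer_partition_spec():  # illustration only
  mesh_names = ('replica', 'data')  # assumed mesh axis names
  mapping_dict = {
      'map_1d': (('replica', 'data'),),
      'map_2d': (('replica', 'data'), -1),
  }
  x = jax.numpy.zeros((64, 512))
  return infer_partition_spec_based_on_rank_fn(mapping_dict, mesh_names, x)
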
def get_partitioned_spmd_model_decode_fn(jax_task, init_key,
                                         partitioned_train_state,
                                         train_state_partition_specs,
                                         inputs_shape: NestedShapeDtypeStruct):
  """Returns a sharded decode step function and input partition spec.

  Args:
    jax_task: Task instance.
    init_key: PRNGKey for initializing the model variables.
    partitioned_train_state: TrainState that holds all the variables.
    train_state_partition_specs: A TrainState that contains PartitionSpecs for
      all the variables.
    inputs_shape: Shape of the inputs for use in pjit sharding.

  Returns:
    (decode_step_fn, inputs_partition_spec): The decode step function, and the
    input partition spec.
  """
  task_p = jax_task.params
  mesh_names = task_p.model.mesh_axis_names
  model = jax_task.model

  # Compute the input PartitionSpec from inputs_shape.
  inputs_partition_spec_fn = functools.partial(
      shard_on_batch_dim_partition_spec, mesh_names)
  reshard_inputs_fn = functools.partial(reshard_input_based_on_rank_fn,
                                        task_p.train.inputs_split_mapping,
                                        mesh_names)
  inputs_partition_spec = tf.nest.map_structure(inputs_partition_spec_fn,
                                                inputs_shape)

  # TODO(b/198356509): Fix this so that prng_key is no longer replicated, as
  # we want each core to not have identical random behavior.
  prng_key_partition_spec = base_layer.to_partition_spec((None,), mesh_names)

  eval_fn_in_partition_specs = (train_state_partition_specs.mdl_vars,
                                prng_key_partition_spec,
                                train_state_partition_specs.step,
                                inputs_partition_spec)
  train_state_unpadded_shapes = jax_task.create_train_state_unpadded_shapes(
      model.vars, is_eval=True)

  def _decode_step(mdl_vars, prng_key, global_step, inputs):
    inputs = jax.tree_map(reshard_inputs_fn, inputs)
    mdl_vars = jax.tree_map(_maybe_slice_uneven_sharding, mdl_vars,
                            train_state_partition_specs.mdl_vars,
                            train_state_unpadded_shapes.mdl_vars)
    # Right now we only pad the vars, and decode doesn't output vars so we do
    # not need to pad at the end.
    return decode_step(
        model,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        fprop_dtype=task_p.model.fprop_dtype)

  decode_out_shapes = jax.eval_shape(_decode_step,
                                     partitioned_train_state.mdl_vars,
                                     init_key, partitioned_train_state.step,
                                     inputs_shape)
  # Decoder outputs are always replicated at the moment.
  decode_fn_out_partition_specs = tf.nest.map_structure(lambda _: None,
                                                        decode_out_shapes)
  decode_step_fn = pjit.pjit(
      _decode_step,
      in_axis_resources=eval_fn_in_partition_specs,
      out_axis_resources=decode_fn_out_partition_specs)
  return decode_step_fn, inputs_partition_spec
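

# Usage sketch (hypothetical wiring): the returned decode_step_fn is
# pjit-compiled, so it has to be invoked inside a device-mesh context whose
# axis names match task_p.model.mesh_axis_names. `device_mesh` and
# `decode_inputs` below are assumed to come from the caller, and maps.Mesh is
# assumed to be the mesh context manager in this JAX version.
def _example_run_decode(jax_task, init_key, partitioned_train_state,
                        train_state_partition_specs, inputs_shape,
                        device_mesh, decode_inputs):  # illustration only
  decode_step_fn, _ = get_partitioned_spmd_model_decode_fn(
      jax_task, init_key, partitioned_train_state, train_state_partition_specs,
      inputs_shape)
  with maps.Mesh(device_mesh, jax_task.params.model.mesh_axis_names):
    return decode_step_fn(partitioned_train_state.mdl_vars, init_key,
                          partitioned_train_state.step, decode_inputs)
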
def partition_spmd_model(
    task_p: InstantiableParams,
    init_key: PRNGKey,
    inputs_shape: NestedShapeDtypeStruct,
) -> Tuple[TrainState, TrainState, TrainState, TrainStepFn, EvalStepFn, int]:
  """Sets up the SPMD model and returns sharded train and eval step functions.

  For partitioning inputs, it is assumed that `task_p.train` has a field
  `inputs_split_mapping` which further contains keys `map_1d`, `map_2d`, ...,
  etc., which specify how to shard inputs of the corresponding rank.

  Args:
    task_p: Task parameters of type NestedMap.
    init_key: PRNGKey for initializing the model variables.
    inputs_shape: Shape of the inputs for use in pjit sharding.

  Returns:
    (partitioned_train_state, train_state_partition_specs,
    inputs_partition_spec, train_step_fn, eval_step_fn, total_num_params):
    The partitioned TrainState, the corresponding partitioned TrainState specs,
    the partition spec for the inputs, the train step function, the eval step
    function and the total number of parameters.
  """
  model_p = task_p.model
  mesh_names = model_p.mesh_axis_names
  jax_task = task_p.Instantiate()
  model = jax_task.model
  reshard_inputs_fn = functools.partial(reshard_input_based_on_rank_fn,
                                        task_p.train.inputs_split_mapping,
                                        model_p.mesh_axis_names)
  inputs_partition_spec = get_input_partition_specs(model_p.mesh_axis_names,
                                                    inputs_shape)

  # Initialize the partitioned vars.
  (train_state_partition_specs, var_padded_shapes, var_shapes,
   partitioned_train_state) = (
       initialize_partitioned_model_states(jax_task, init_key))
  total_num_params = model.total_num_vars

  # TODO(bf-jax): prng_key is replicated. Would this be a problem?
  prng_key_partition_spec = base_layer.to_partition_spec((None,), mesh_names)

  def _train_step(state, prng_key, inputs):
    # Reshard inputs.
    inputs = jax.tree_map(reshard_inputs_fn, inputs)
    # Vars are padded at program entry/exit to avoid uneven sharding. We slice
    # the vars to remove padding before the step computation, and pad them
    # after the step computation to make user code independent of paddings.
    # Internal uneven sharding in the step computation is supported by XLA.
    state = jax.tree_map(_maybe_slice_uneven_sharding, state,
                         train_state_partition_specs, var_shapes)

    def _maybe_pad(x, pspec, shape):
      return _maybe_pad_uneven_sharding(x, pspec, shape,
                                        model_p.device_mesh.shape,
                                        model_p.mesh_axis_names)

    (new_states, weighted_loss, mean_metrics, per_example_out,
     summary_tensors) = train_step_single_learner(
         jax_task,
         state,
         prng_key,
         inputs,
         data_parallel_axis_name=None,
         fprop_dtype=model_p.fprop_dtype)
    new_states = jax.tree_map(_maybe_pad, new_states,
                              train_state_partition_specs, var_shapes)
    return (new_states, weighted_loss, mean_metrics, per_example_out,
            summary_tensors)

  def _eval_step(mdl_vars, prng_key, global_step, inputs):
    # Reshard inputs.
    inputs = jax.tree_map(reshard_inputs_fn, inputs)
    mdl_vars = jax.tree_map(_maybe_slice_uneven_sharding, mdl_vars,
                            train_state_partition_specs.mdl_vars,
                            var_shapes.mdl_vars)
    # Right now we only pad the vars, and eval doesn't output vars so we do
    # not need to pad at the end.
    return eval_step_single_learner(
        jax_task,
        mdl_vars,
        prng_key,
        global_step,
        inputs,
        data_parallel_axis_name=None,
        fprop_dtype=model_p.fprop_dtype)

  train_out_padded_shapes = jax.eval_shape(_train_step, var_padded_shapes,
                                           init_key, inputs_shape)
  eval_out_padded_shapes = jax.eval_shape(_eval_step,
                                          var_padded_shapes.mdl_vars, init_key,
                                          var_padded_shapes.step, inputs_shape)

  def _partition_spec_from_shape(x_shape):
    # Currently, all the outputs are fully replicated.
    # TODO(yonghui): Somehow fetch the output sharding spec from the _eval_step
    # fn.
    del x_shape
    return None

  train_fn_in_partition_specs = (train_state_partition_specs,
                                 prng_key_partition_spec,
                                 inputs_partition_spec)
  train_fn_out_replicated_specs = tf.nest.map_structure(
      _partition_spec_from_shape, train_out_padded_shapes)
  # Here we assume the first output is the train state. Except for that first
  # output, all other outputs are explicitly replicated.
  train_fn_out_partition_specs = tuple([train_state_partition_specs] +
                                       list(train_fn_out_replicated_specs[1:]))
  tf.nest.assert_same_structure(train_fn_out_replicated_specs,
                                train_fn_out_partition_specs)
  tf.nest.assert_same_structure(train_fn_out_partition_specs,
                                train_out_padded_shapes)

  eval_fn_in_partition_specs = (train_state_partition_specs.mdl_vars,
                                prng_key_partition_spec,
                                train_state_partition_specs.step,
                                inputs_partition_spec)
  eval_fn_out_partition_specs = tf.nest.map_structure(
      _partition_spec_from_shape, eval_out_padded_shapes)

  # pjit-ed train step function.
  train_step = pjit.pjit(
      _train_step,
      in_axis_resources=train_fn_in_partition_specs,
      out_axis_resources=train_fn_out_partition_specs,
      donate_argnums=(0,))

  # pjit-ed eval step function.
  eval_step = pjit.pjit(
      _eval_step,
      in_axis_resources=eval_fn_in_partition_specs,
      out_axis_resources=eval_fn_out_partition_specs)

  return (partitioned_train_state, train_state_partition_specs,
          inputs_partition_spec, train_step, eval_step, total_num_params)
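

# Usage sketch (hypothetical wiring): partition_spmd_model only builds the
# pjit-compiled step functions; the caller still enters a device-mesh context
# matching task_p.model.mesh_axis_names before invoking them. Because
# _train_step donates its first argument, the old train state buffers are
# consumed and the caller rebinds to the returned new state. `device_mesh`,
# `prng_key` and `batch` below are assumed to come from the caller, and
# maps.Mesh is assumed to be the mesh context manager in this JAX version.
def _example_train_one_step(task_p, init_key, inputs_shape, device_mesh,
                            prng_key, batch):  # illustration only
  (train_state, _, _, train_step_fn, _, _) = partition_spmd_model(
      task_p, init_key, inputs_shape)
  with maps.Mesh(device_mesh, task_p.model.mesh_axis_names):
    (train_state, weighted_loss, mean_metrics, per_example_out,
     summary_tensors) = train_step_fn(train_state, prng_key, batch)
  del mean_metrics, per_example_out, summary_tensors  # unused in this sketch
  return train_state, weighted_loss
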