Example #1
def MeshSplit(x,
              device_mesh,
              tensor_split_dims_mapping,
              use_sharding_op=True,
              unspecified_dims=None):
    """Wrapper of xla_sharding.mesh_split()."""
    if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None
            or device_mesh is None or device_mesh.size <= 1):
        return x
    # Prepend any dim-mapping prefix pushed by the enclosing context.
    tensor_split_dims_mapping = (_MESH_SPLIT_DIM_PREFIXES.stack +
                                 tensor_split_dims_mapping)
    # Entries of -1 leave the corresponding tensor dimension unsharded, so
    # only non-negative mapping entries contribute tiles.
    num_tiles = np.prod(
        [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
    if num_tiles <= 1:
        return x
    if _MANUAL_MESH_DIMS.stack or unspecified_dims:
        return xla_sharding.mesh_split(
            x,
            device_mesh,
            tensor_split_dims_mapping,
            use_sharding_op=use_sharding_op,
            manual_mesh_dims=_MANUAL_MESH_DIMS.stack,
            unspecified_dims=unspecified_dims)
    # Do not include manual_mesh_dims or unspecified_dims to support legacy TF
    # versions.
    return xla_sharding.mesh_split(x,
                                   device_mesh,
                                   tensor_split_dims_mapping,
                                   use_sharding_op=use_sharding_op)
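A minimal call-site sketch for the wrapper above (not part of the original source); the mesh layout, tensor shape, and variable names are illustrative assumptions.

import numpy as np
import tensorflow as tf

# Assumed setup: 8 devices arranged as a 2x4 logical mesh and a
# [batch, model_dim] activation. Mesh axis 0 shards the batch dimension and
# mesh axis 1 shards the model dimension; a mapping entry of -1 would leave
# that tensor dimension replicated.
device_mesh = np.arange(8).reshape([2, 4])
act = tf.zeros([16, 1024])
act = MeshSplit(
    act,
    device_mesh,
    tensor_split_dims_mapping=[0, 1],
    use_sharding_op=True)
# Off TPU, or when the mesh is trivial, the wrapper is a no-op and returns
# the input tensor unchanged.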
Example #2
def MeshSplit(x, device_mesh, tensor_split_dims_mapping, use_sharding_op=True):
    """Wrapper of xla_sharding.mesh_split()."""
    if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None
            or device_mesh is None or device_mesh.size <= 1):
        return x
    num_tiles = np.prod(
        [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
    if num_tiles <= 1:
        return x
    return xla_sharding.mesh_split(x,
                                   device_mesh,
                                   tensor_split_dims_mapping,
                                   use_sharding_op=use_sharding_op)
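A small self-contained sketch of the num_tiles guard used above, assuming a 4x2 device mesh; it shows why a mapping with no non-negative entries makes the wrapper a no-op.

import numpy as np

device_mesh = np.arange(8).reshape([4, 2])
mapping = [-1, 0]  # tensor dim 0 replicated, tensor dim 1 split over mesh axis 0
num_tiles = np.prod([device_mesh.shape[i] for i in mapping if i >= 0])
assert num_tiles == 4  # only mesh axis 0 (size 4) contributes tiles
# With mapping == [-1, -1] the product is empty, np.prod(...) == 1, and
# MeshSplit returns x without annotating it.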
Example #3
def MeshSplit(x, device_mesh, tensor_split_dims_mapping, use_sharding_op=True):
    """Wrapper of xla_sharding.mesh_split()."""
    if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None
            or device_mesh is None or device_mesh.size <= 1):
        return x
    # Prepend any dim-mapping prefix pushed by the enclosing context.
    tensor_split_dims_mapping = (_MESH_SPLIT_DIM_PREFIXES.stack +
                                 tensor_split_dims_mapping)
    num_tiles = np.prod(
        [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
    if num_tiles <= 1:
        return x
    return xla_sharding.mesh_split(x,
                                   device_mesh,
                                   tensor_split_dims_mapping,
                                   use_sharding_op=use_sharding_op)
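A brief sketch of how the prefix stack composes with the caller's mapping; the concrete prefix value below is an assumption, and the context manager that pushes it is not shown in this excerpt.

# Assume the enclosing context has pushed [0] onto
# _MESH_SPLIT_DIM_PREFIXES.stack and the caller passes [1, -1]. The effective
# mapping is the concatenation, so the prefix covers the leading tensor
# dimension(s) and the caller's entries apply to the dimensions that follow.
prefix = [0]        # stand-in for _MESH_SPLIT_DIM_PREFIXES.stack
mapping = [1, -1]   # caller-provided tensor_split_dims_mapping
effective = prefix + mapping
assert effective == [0, 1, -1]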
Example #4
def Split(x,
          split_dimension,
          num_devices,
          use_sharding_op=True,
          input_shape=None):
    """Wrapper for xla_sharding.split.

    Args:
      x: Tensor to annotate.
      split_dimension: xla_sharding.split arg.
      num_devices: xla_sharding.split arg.
      use_sharding_op: If true, adds a sharding op to set the sharding:
        tensor = gen_xla_ops.xla_sharding(tensor)

        hyouklee@: use_sharding_op=False
          "It adds the sharding attribute to the op itself. The outcome is
          that, that information could be lost by TF graph transformations.
          Also, directly attaching the sharding annotation to the op caused
          some compilation failures in the past (due to incompatible
          shardings), so the plan is to make use_sharding_op to be the
          default."

          "The only case I would set it to False today is when annotating
          weights. Weight annotation does some special handling, so there may
          be some changes needed in that logic if we add separate sharding
          op."
      input_shape: The shape of the original tensor.

    Returns:
      Tensor conditionally annotated with sharding.
    """
    if (not py_utils_flags.use_tpu() or num_devices is None
            or num_devices <= 1):
        return x
    return xla_sharding.split(
        x,
        split_dimension,
        num_devices,
        input_shape=input_shape,
        use_sharding_op=use_sharding_op,
    )
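A hedged call-site sketch for Split (not part of the original source); the tensor shape and device count are assumptions.

import tensorflow as tf

# Assumed setup: split the batch dimension (dim 0) of a [128, 512] activation
# across 8 devices. use_sharding_op defaults to True, so the annotation is
# routed through a separate sharding op, as recommended in the docstring
# above.
act = tf.zeros([128, 512])
act = Split(act, split_dimension=0, num_devices=8)
# With num_devices None or <= 1, or off TPU, the call returns x unchanged.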
Example #5
def Replicate(x, use_sharding_op=True):
    """Wrapper of xla_sharding.replicate."""
    if not py_utils_flags.use_tpu():
        return x
    return xla_sharding.replicate(x, use_sharding_op=use_sharding_op)
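A small call-site sketch for Replicate; the variable name and shape are illustrative assumptions.

import tensorflow as tf

# Mark a tensor as fully replicated across devices; off TPU the wrapper
# simply returns the tensor it was given.
table = tf.zeros([1000, 64])
table = Replicate(table)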