Code example #1
 def _to_auto(self, val):
   if self._maybe_partial_manual:
     return xla_sharding.manual_to_auto_spmd_partition(
         val, self._sharding, self._var.shape, single_dim=0)
   else:
     return xla_sharding.manual_to_auto_spmd_partition(val, self._sharding,
                                                       self._var.shape)
Code example #2
 def assign(self, value, use_locking=False, name=None, read_value=True):
   """Implements the interface of tf.Variable.assign.

   Args:
     value: A manually sharded tensor that has the shape of the individual
       elements of the stacked variable (shard shape with the stacking
       dimension collapsed).
     use_locking: See tf.Variable.assign.
     name: See tf.Variable.assign.
     read_value: See tf.Variable.assign. If True, the returned value will be
       manually sharded.

   Returns:
     See tf.Variable.assign. If read_value is True, returns the updated value
     in the shard shape of the individual elements of the stacked variable
     (shard shape with the stacking dimension collapsed).
   """
   value = tf.expand_dims(value, 0)
   value = xla_sharding.manual_to_auto_spmd_partition(
       value, self._sharding, self._var.shape)
   res = self._var.assign(value, use_locking, name, read_value)
   if read_value:
     res = xla_sharding.auto_to_manual_spmd_partition(res, self._sharding)
     res = tf.squeeze(res, 0)
   return res
Code example #3
 def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
     """Implements the interface of tf.Variable.assign_sub."""
     delta = tf.expand_dims(delta, 0)
     delta = xla_sharding.manual_to_auto_spmd_partition(
         delta, self._sharding, self._var.shape)
     res = self._var.assign_sub(delta, use_locking, name, read_value)
     if read_value:
         res = xla_sharding.auto_to_manual_spmd_partition(
             res, self._sharding)
         res = tf.squeeze(res, 0)
     return res
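
Code examples #2 and #3 follow the same round trip: re-add the collapsed stacking dimension with tf.expand_dims, convert the manually sharded value to auto-partitioned form so it matches the full shape of self._var, apply the underlying tf.Variable op, then convert the read value back to manual form and squeeze the stacking dimension away. As a minimal sketch only (it assumes the same wrapper attributes self._var and self._sharding shown above and is not code from the source project), an assign_add would use identical bookkeeping:

 def assign_add(self, delta, use_locking=False, name=None, read_value=True):
   """Sketch only: mirrors assign/assign_sub above, using tf.Variable.assign_add."""
   # Re-add the stacking dimension and map the shard back to the full shape.
   delta = tf.expand_dims(delta, 0)
   delta = xla_sharding.manual_to_auto_spmd_partition(
       delta, self._sharding, self._var.shape)
   res = self._var.assign_add(delta, use_locking, name, read_value)
   if read_value:
     # Return the updated value in the manually sharded, collapsed shape.
     res = xla_sharding.auto_to_manual_spmd_partition(res, self._sharding)
     res = tf.squeeze(res, 0)
   return res

The expand_dims/squeeze pair exists because the wrapper exposes per-shard values with the stacking dimension collapsed, while the backing variable keeps that dimension.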
Code example #4
File: gshard_utils.py  Project: Mddct/lingvo
 def ManualToAutoPartitioning(self, tensor: tf.Tensor) -> tf.Tensor:
   """Converts manually sharded tensor to full-size for auto partitioning."""
   full_shape = list(tensor.shape)
   if not self.is_replicated:
     for i in range(len(self._split_dims_mapping)):
       if self._split_dims_mapping[i] >= 0:
         full_shape[i] *= self._device_mesh.shape[self._split_dims_mapping[i]]
       if self._uneven_padding is not None and self._uneven_padding[i] > 0:
         full_shape[i] -= self._uneven_padding[i]
   return xla_sharding.manual_to_auto_spmd_partition(
       tensor,
       self.ToXlaOpSharding().SerializeToString(), full_shape)
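
A small worked example of the full-shape arithmetic in ManualToAutoPartitioning. The numbers below are illustrative only (not from the source): a per-device shard of shape [4, 128] on a [2, 8] device mesh, where tensor dim i is split over mesh dim split_dims_mapping[i] and 24 elements of uneven padding were added on the last dim:

# Hypothetical values chosen to trace the loop above; not from the source.
shard_shape = [4, 128]        # shape of the manually sharded (per-device) tensor
device_mesh_shape = [2, 8]    # shape of the device mesh used by the sharding
split_dims_mapping = [0, 1]   # tensor dim i is split over mesh dim split_dims_mapping[i]
uneven_padding = [0, 24]      # padding that made the split even and must be removed

full_shape = list(shard_shape)
for i in range(len(split_dims_mapping)):
  if split_dims_mapping[i] >= 0:
    full_shape[i] *= device_mesh_shape[split_dims_mapping[i]]
  if uneven_padding[i] > 0:
    full_shape[i] -= uneven_padding[i]

print(full_shape)  # [8, 1000]: 4 * 2 = 8 and 128 * 8 - 24 = 1000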
Code example #5
File: moe_layers.py  Project: apoorvakumar2306/lingvo
def GatherK(selected_pos, values, k, num_devices=1):
  """Gather up to k elements from given tensors at selected pos under SPMD.

  Example::

    # Input
    k = 3

    selected_pos = [
        [0, 0, 1, 1],
        [0, 1, 1, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],  # topk(k=3) largest indices are selected in this row.
    ]

    value_2d = [
        [1, 3, 5, 7],
        [9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31],
        [33, 35, 37, 39],
    ]

    # Output:
    output = [
        [0, 5, 7],
        [0, 11, 13],
        [0, 0, 0],
        [25, 27, 29],
        [35, 37, 39],
    ]

    # Output padding:
    output_padding = [
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 1],
        [0, 0, 0],
        [0, 0, 0],
    ]

  Args:
    selected_pos: a 0/1 2D tf.int32 tensor of shape [batch, time].
    values: a list of tensors, the rank of each is at least rank=2. [batch,
      time, ...].
    k: a scalar tf.int32 tensor or a Python int. On TPU, k must be a
      compile-time constant.
    num_devices: number of TPU devices used in xla_sharding SPMD.

  Returns:
    A tuple (output, padding).

    - output: a list of tensors of shape [batch, k, ...].
    - padding: a 2D 0/1 tensor of shape [batch, k], '1's are padded locations.
  """
  global_batch, seq_len = py_utils.GetShape(selected_pos, 2)
  if num_devices:
    device_batch = global_batch // num_devices
  else:
    device_batch = global_batch

  for i in range(len(values)):
    # Assert that the first two dims of values[i] are [global_batch, seq_len].
    values[i] = py_utils.HasShape(values[i], [global_batch, seq_len], 2)
  # indices are 1-based for now, to distinguish between padding and selected
  # locations.
  indices = 1 + tf.range(tf.shape(values[0])[1], dtype=tf.int32)
  # [1, seq_len]
  indices = tf.expand_dims(indices, axis=0)

  # If 0, the position is not selected.
  # [1, seq_len] * [global_batch, seq_len] => [global_batch, seq_len]
  # -- top_k --> [global_batch, k]
  topk_indices, _ = tf.math.top_k(
      indices * tf.cast(selected_pos, indices.dtype), k)

  # [global_batch, k], sorted in ascending order.
  indices = tf.reverse(topk_indices, [-1])
  # [global_batch, k], padded positions are '1's.
  padding = tf.cast(tf.equal(indices, 0), values[0].dtype)
  padding = Split(padding, 0, num_devices)

  # [global_batch, k], zero_based_indices
  mp_idx = tf.maximum(0, indices - 1)
  mp_idx = Split(mp_idx, 0, num_devices)

  # [device_batch, k]
  if num_devices > 1 and py_utils.use_tpu():
    mp_idx = xla_sharding.auto_to_manual_spmd_partition(
        mp_idx, xla_sharding.get_op_sharding(mp_idx.op))
  # [device_batch, k, 1]
  mp_idx = tf.expand_dims(mp_idx, -1)

  # [device_batch]
  batch_ids = tf.range(device_batch, dtype=tf.int32)
  # [device_batch, 1, 1]
  batch_ids = tf.reshape(batch_ids, [device_batch, 1, 1])
  # [device_batch, k, 1]
  batch_ids = tf.broadcast_to(batch_ids, [device_batch, k, 1])

  # [device_batch, k, 2]
  final_indices = tf.concat([batch_ids, mp_idx], axis=-1)

  output = []
  for v in values:
    # Begin manually partition gather.
    v = Split(v, 0, num_devices)
    v_shape = v.shape.as_list()
    if num_devices > 1 and py_utils.use_tpu():
      op_sharding = xla_sharding.get_op_sharding(v.op)
      v = xla_sharding.auto_to_manual_spmd_partition(v, op_sharding)
    # Returns [global_batch, k, ...]
    v_out = tf.gather_nd(v, final_indices)

    if num_devices > 1 and py_utils.use_tpu():
      v_shape[1] = k
      v_out = xla_sharding.manual_to_auto_spmd_partition(
          v_out, op_sharding, full_shape=tf.TensorShape(v_shape))
    output.append(v_out)

  return output, padding
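
A minimal single-device sketch of the indexing trick used by GatherK: 1-based positions are masked by selected_pos, the k largest survive top_k, a reverse restores ascending order, and a gather_nd with prepended batch ids pulls out the values. This is plain TF2 eager code for the num_devices=1 path (so no xla_sharding or py_utils), run on the values from the docstring; unlike the function above, padded slots are explicitly zeroed here with the padding mask so the printed output matches the docstring's example.

import tensorflow as tf

k = 3
selected_pos = tf.constant([[0, 0, 1, 1],
                            [0, 1, 1, 0],
                            [0, 0, 0, 0],
                            [1, 1, 1, 0],
                            [1, 1, 1, 1]], dtype=tf.int32)
value_2d = tf.constant([[1, 3, 5, 7],
                        [9, 11, 13, 15],
                        [17, 19, 21, 23],
                        [25, 27, 29, 31],
                        [33, 35, 37, 39]], dtype=tf.int32)

# 1-based positions, so that 0 can mean "not selected / padding".
positions = 1 + tf.range(tf.shape(value_2d)[1], dtype=tf.int32)         # [seq_len]
masked = tf.expand_dims(positions, 0) * selected_pos                    # [batch, seq_len]
topk_positions, _ = tf.math.top_k(masked, k)                            # k largest positions
ascending = tf.reverse(topk_positions, [-1])                            # ascending order
padding = tf.cast(tf.equal(ascending, 0), value_2d.dtype)               # 1 where padded
mp_idx = tf.maximum(0, ascending - 1)                                   # back to 0-based

batch = value_2d.shape[0]
batch_ids = tf.broadcast_to(tf.reshape(tf.range(batch), [-1, 1, 1]), [batch, k, 1])
final_indices = tf.concat([batch_ids, tf.expand_dims(mp_idx, -1)], -1)  # [batch, k, 2]
output = tf.gather_nd(value_2d, final_indices) * (1 - padding)          # zero out padded slots

print(output.numpy())   # matches `output` in the GatherK docstring
print(padding.numpy())  # matches `output_padding` in the docstring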