Example no. 1
def _reduce_non_distributed_value(extended, reduce_op, value, destinations):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, values.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`_reduce_non_distributed_value`, which is not allowed.")

  # If the same value is present on all replicas then the PerReplica value will
  # be a single value. We also handle the case when `value` is a single value
  # and equal to 0.
  if value == 0:
    return 0
  # If there is only a single value and the reduce op is MEAN,
  # that value should be on all destinations.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value

  cross_device_ops_lib.validate_destinations(destinations)
  # We do not support a reduce op of SUM if the value is the same across
  # all replicas. We call this as part of assign functions for MirroredVariables
  # and summing up identical values across replicas is not clearly defined.
  if (len(extended.worker_devices) != 1 or
      not cross_device_ops_lib.check_destinations(destinations)):
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  # TODO(anjalisridhar): Moves these methods to a device utility file?
  devices = cross_device_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    with ops.device(devices[0]):
      return array_ops.identity(value)
  else:
    value_updates = {}
    for d in devices:
      with ops.device(d):
        value_updates[d] = array_ops.identity(value)
    return values.Mirrored(value_updates)
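The last branch above is the core idea: a plain (non-distributed) value is broadcast by pinning a `tf.identity` copy of it to each destination device. A minimal standalone sketch of that pattern, using only public TensorFlow APIs and example device strings rather than the library's internal helpers:

import tensorflow as tf

def mirror_plain_value(value, devices):
  """Illustrative only: copy a plain value onto every device in `devices`."""
  copies = {}
  for d in devices:
    with tf.device(d):              # pin this copy to device d
      copies[d] = tf.identity(value)
  return copies

# Example: mirror_plain_value(tf.constant(3.0), ["/cpu:0"])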
Example no. 2
  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)

    # Validate that the destination is the same as the host device.
    # Note that we don't do this when in a replicate context, as the reduction
    # is performed on the TPU device itself.
    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
      assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
          self._host_device)
    else:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    output = math_ops.add_n(value)
    if reduce_op == reduce_util.ReduceOp.MEAN:
      return output * (1. / len(value))
    return output
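Inside the TPU replicate context the snippet expresses MEAN as a pre-scaled SUM: each replica's value is multiplied by 1/num_replicas before the cross-replica sum. A hedged, library-independent sketch of that arithmetic, with plain tensors standing in for the per-replica values:

import tensorflow as tf

def mean_via_scaled_sum(per_replica_tensors):
  # Illustrative helper, not library code: scale every contribution by 1/n,
  # then sum. The result equals the mean of the contributions.
  n = len(per_replica_tensors)
  return tf.add_n([t * (1. / n) for t in per_replica_tensors])

# mean_via_scaled_sum([tf.constant(2.0), tf.constant(4.0)])  ->  3.0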
Example no. 3
def _fake_mirrored(value, devices):
  """Create a faked Mirrored object for testing.

  All components of the returned Mirrored have the same objects, which is not
  true in reality.
  """
  devices = cross_device_ops_lib.get_devices_from(devices)
  return value_lib.Mirrored(
      {d: v for d, v in zip(devices, [value] * len(devices))})
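Because every component of the fake is the very same object, the zip-based dict above can be written as a plain dict comprehension. A test-only sketch of the equivalent form (the function name is hypothetical; the identifiers are the ones the snippet already imports):

def _fake_mirrored_simpler(value, devices):
  devices = cross_device_ops_lib.get_devices_from(devices)
  # Every device maps to the same object, which is exactly what the fake needs.
  return value_lib.Mirrored({d: value for d in devices})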
Example no. 4
  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._worker_device))
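The check works by parsing each destination device string into its job and task components. A small self-contained illustration with the public tf.DeviceSpec API and a made-up device string:

import tensorflow as tf

spec = tf.DeviceSpec.from_string("/job:worker/task:1/device:GPU:0")
print(spec.job, spec.task)   # -> worker 1
# If the current worker were /job:worker/task:0, a destination like the one
# above would trip the ValueError in the method, since the task ids differ.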
Example no. 5
def _make_per_replica(values, devices, regroup=False):
  devices = cross_device_ops_lib.get_devices_from(devices)
  assert len(values) == len(devices)

  # We simulate the result of regroup called on PerReplica which strips the
  # PerReplica wrapper if it has only one value.
  if len(values) == 1 and regroup:
    with ops.device(devices[0]):
      placed_v = array_ops.identity(values[0])
    return placed_v

  index = {}
  for d, v in zip(devices, values):
    with ops.device(d):
      placed_v = array_ops.identity(v)
    index[d] = placed_v
  return value_lib.PerReplica(index)
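The regroup special case follows a common convention: a container holding a single placed value is unwrapped to the bare value instead of being returned inside a PerReplica. A tiny, library-independent sketch of that convention (the helper name is illustrative):

def unwrap_if_single(placed):
  """Return the only value when there is exactly one, else the whole dict."""
  if len(placed) == 1:
    return next(iter(placed.values()))
  return placed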
Example no. 6
  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)

    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) != 1:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    # Always performs the reduction on the TPU host.
    with ops.device(self._host_device):
      output = math_ops.add_n(value.values)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        output *= (1. / len(value.values))

    # If necessary, copy to requested destination.
    dest_canonical = device_util.canonicalize(devices[0])
    host_canonical = device_util.canonicalize(self._host_device)

    if dest_canonical != host_canonical:
      with ops.device(devices[0]):
        output = array_ops.identity(output)

    return output
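This version always reduces on the TPU host and only afterwards copies the result to the requested destination when that destination is not the host itself. A hedged sketch of the same control flow with plain tensors and example device strings (a plain string comparison stands in for the internal device canonicalization):

import tensorflow as tf

def reduce_on_host(per_replica_tensors, mean, host="/cpu:0", dest="/cpu:0"):
  # Illustrative sketch, not library code.
  with tf.device(host):
    out = tf.add_n(per_replica_tensors)       # sum all replica components
    if mean:
      out *= (1. / len(per_replica_tensors))
  if dest != host:                            # copy only if the destination differs
    with tf.device(dest):
      out = tf.identity(out)
  return out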
Example no. 9
  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    # Validate that the destination is the same as the host device.
    # Note that we don't do this when in a replicate context, as the reduction
    # is performed on the TPU device itself.
    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
      assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
          self._host_device)
    else:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    output = math_ops.add_n(value)
    if reduce_op == reduce_util.ReduceOp.MEAN:
      return output * (1. / len(value))
    return output
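The assertion compares canonicalized device strings, because the same device can be spelled in several equivalent ways. A rough illustration of that idea with the public tf.DeviceSpec API (the same normalization idea, not the internal device_util.canonicalize):

import tensorflow as tf

a = tf.DeviceSpec.from_string("/cpu:0")
b = tf.DeviceSpec.from_string("/device:CPU:0")
# Both parse to the same normalized device string, so comparing the
# normalized forms succeeds where a raw string comparison would not.
print(a.to_string() == b.to_string())   # expected: True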
Example no. 11
  def _get_devices_from(self, colocate_with=None):
    if colocate_with is None:
      return self._devices
    else:
      return cross_device_ops_lib.get_devices_from(colocate_with)