Example #1
    def _reduce(self, aggregation, value, destinations):
        assert not isinstance(value, values.Mirrored)
        if not isinstance(value, values.PerDevice):
            if value == 0:
                return 0
            if aggregation == variable_scope.VariableAggregation.MEAN:
                return self._broadcast(value, destinations)

            cross_tower_ops_lib.validate_destinations(destinations)
            if len(self._devices) == 1:
                if destinations:
                    # TODO(anjalisridhar): Move these methods to a device utility file?
                    devices = cross_tower_ops_lib.get_devices_from(
                        destinations)
                    if len(devices) == 1:
                        with ops.device(devices[0]):
                            return array_ops.identity(value)
                    else:
                        value_updates = {}
                        for d in devices:
                            with ops.device(d):
                                value_updates[d] = array_ops.identity(value)
                        return values.Mirrored(value_updates)
            raise ValueError(
                "A non-PerDevice value cannot be reduced with the given "
                "aggregation.")

        return self._get_cross_tower_ops().reduce(aggregation,
                                                  value,
                                                  destinations=destinations)
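
The interesting branch here is the fallback for plain (non-PerDevice) values: under MEAN the value is broadcast, otherwise one identity-pinned copy is placed on every destination device and wrapped in a Mirrored map. A minimal plain-Python sketch of that mirroring pattern, using a hypothetical Mirrored stand-in for values.Mirrored:

class Mirrored(object):
    """Hypothetical stand-in for values.Mirrored: a device -> value map."""
    def __init__(self, index):
        self.index = index

def mirror_value(value, devices):
    # One copy of `value` per destination; the real code also pins each
    # copy with ops.device(d) and array_ops.identity(value).
    return Mirrored({d: value for d in devices})

m = mirror_value(3.0, ["/device:GPU:0", "/device:GPU:1"])
assert set(m.index) == {"/device:GPU:0", "/device:GPU:1"}
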
Example #2
  def _reduce(self, aggregation, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if aggregation == vs.VariableAggregation.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self.num_towers)
      elif aggregation != vs.VariableAggregation.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    # Validate that the destination is the same as the host device.
    # Note we don't do this when in a replicate context as the reduction is
    # performed on the TPU device itself.
    devices = cross_tower_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
      assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
          self.get_host_cpu_device(0))
    else:
      raise ValueError('Multiple devices are not supported for TPUStrategy')

    if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
      return value[0]
    output = math_ops.add_n(value)
    if aggregation == vs.VariableAggregation.MEAN:
      return output * (1. / len(value))
    return output
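
The MEAN case above works by pre-scaling: each replica multiplies its value by 1/num_towers before the cross-replica SUM, and sum_i(v_i / n) equals (sum_i v_i) / n. A plain-Python sketch, with a stand-in for tpu_ops.cross_replica_sum:

def cross_replica_sum(per_replica_values):
    # Stand-in for tpu_ops.cross_replica_sum, which adds a value
    # across all TPU replicas.
    return sum(per_replica_values)

replica_values = [2.0, 4.0, 6.0, 8.0]
n = len(replica_values)
scaled = [v * (1. / n) for v in replica_values]
assert cross_replica_sum(scaled) == sum(replica_values) / n  # mean == 5.0
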
Example #3
  def _reduce(self, method_string, value, destinations):
    assert not isinstance(value, values.Mirrored)
    if not isinstance(value, values.PerDevice):
      if value == 0:
        return 0
      if method_string == "mean":
        return self._broadcast(value, destinations)

      cross_tower_ops_lib.validate_destinations(destinations)
      if len(self._devices) == 1:
        if destinations:
          # TODO(anjalisridhar): Move these methods to a device utility file?
          devices = cross_tower_ops_lib.get_devices_from(destinations)
          if len(devices) == 1:
            with ops.device(devices[0]):
              return array_ops.identity(value)
          else:
            value_updates = {}
            for d in devices:
              with ops.device(d):
                value_updates[d] = array_ops.identity(value)
            return values.Mirrored(value_updates)
      raise ValueError("A non PerDevice value cannot be reduced with the given "
                       "method_string.")

    return self._get_cross_tower_ops().reduce(
        method_string, value, destinations=destinations)
Example #4
def _reduce_non_distributed_value(distribution, aggregation, value,
                                  destinations):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, values.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`_reduce_non_distributed_value`, which is not allowed.")

  if value == 0:
    return 0
  if aggregation == variable_scope.VariableAggregation.MEAN:
    return distribution.broadcast(value, destinations)

  cross_tower_ops_lib.validate_destinations(destinations)
  if (len(distribution.worker_devices) != 1 or
      not cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError("A non-DistributedValues value cannot be reduced with the "
                     "given aggregation.")
  # TODO(anjalisridhar): Move these methods to a device utility file?
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    with ops.device(devices[0]):
      return array_ops.identity(value)
  else:
    value_updates = {}
    for d in devices:
      with ops.device(d):
        value_updates[d] = array_ops.identity(value)
    return values.Mirrored(value_updates)
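
The MEAN path above can delegate to distribution.broadcast because a non-distributed value is, by construction, identical everywhere: its mean is itself, so "reducing" it just means placing a copy on each destination. A sketch with a hypothetical broadcast stand-in:

def broadcast(value, destinations):
    # Hypothetical stand-in for distribution.broadcast: one copy of
    # `value` per destination device.
    return {d: value for d in destinations}

assert broadcast(0.5, ["/cpu:0", "/gpu:0"]) == {"/cpu:0": 0.5, "/gpu:0": 0.5}
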
Example #5
  def _reduce(self, aggregation, value, destinations):
    graph = ops.get_default_graph()
    cf_context = graph._get_control_flow_context()  # pylint: disable=protected-access
    # If we're inside the ReplicateContext, reduction should be done using
    # CrossReplicaSum while outside we can directly use an add_n op.
    while cf_context:
      if isinstance(cf_context, tpu.TPUReplicateContext):
        if aggregation == vs.VariableAggregation.MEAN:
          # TODO(jhseu):  Revisit once we support model-parallelism.
          value *= (1. / self.num_towers)
        return tpu_ops.cross_replica_sum(value)
      cf_context = cf_context.outer_context

    # Validate that the destination is the same as the host device.
    # Note we don't do this when in a replicate context as the reduction is
    # performed on the TPU device itself.
    devices = cross_tower_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
      assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
          self._host)
    else:
      raise ValueError('Multiple devices are not supported for TPUStrategy')

    output = math_ops.add_n(value)
    if aggregation == vs.VariableAggregation.MEAN:
      return output * (1. / len(value))
    return output
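
The loop above climbs the chain of control-flow contexts through outer_context until it either finds a TPUReplicateContext or runs off the top. A plain-Python sketch of that walk, with a stand-in context class:

class Ctx(object):
    """Stand-in for a control-flow context with an outer_context link."""
    def __init__(self, outer=None, is_tpu_replicate=False):
        self.outer_context = outer
        self.is_tpu_replicate = is_tpu_replicate

def inside_tpu_replicate(ctx):
    while ctx:
        if ctx.is_tpu_replicate:  # isinstance(ctx, tpu.TPUReplicateContext)
            return True
        ctx = ctx.outer_context
    return False

assert inside_tpu_replicate(Ctx(outer=Ctx(is_tpu_replicate=True)))
assert not inside_tpu_replicate(Ctx())
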
Example #6
    def _reduce(self, aggregation, value, destinations):
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if aggregation == vs.VariableAggregation.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self.num_towers)
            elif aggregation != vs.VariableAggregation.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        # Validate that the destination is the same as the host device.
        # Note we don't do this when in a replicate context as the reduction is
        # performed on the TPU device itself.
        devices = cross_tower_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert device_util.canonicalize(
                devices[0]) == device_util.canonicalize(
                    self.get_host_cpu_device(0))
        else:
            raise ValueError(
                'Multiple devices are not supported for TPUStrategy')

        if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
            return value[0]
        output = math_ops.add_n(value)
        if aggregation == vs.VariableAggregation.MEAN:
            return output * (1. / len(value))
        return output
Example #7
    def _reduce_to(self, reduce_op, value, destinations):
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        # Validate that the destination is the same as the host device.
        # Note we don't do this when in a replicate context as the reduction is
        # performed on the TPU device itself.
        devices = cross_tower_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert device_util.canonicalize(
                devices[0]) == device_util.canonicalize(self._host_device)
        else:
            raise ValueError(
                "Multiple devices are not supported for TPUStrategy")

        output = math_ops.add_n(value)
        if reduce_op == reduce_util.ReduceOp.MEAN:
            return output * (1. / len(value))
        return output
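
Outside the TPU context, the reduction falls back to a plain add_n over the list of per-replica tensors, and MEAN is recovered by rescaling the sum. A numeric sketch, with Python's sum standing in for math_ops.add_n:

def add_n(tensors):
    # Stand-in for math_ops.add_n.
    return sum(tensors)

value = [1.0, 3.0, 5.0, 7.0]
output = add_n(value)
assert output == 16.0                     # ReduceOp.SUM
assert output * (1. / len(value)) == 4.0  # ReduceOp.MEAN
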
Example #8
def _reduce_non_distributed_value(extended, reduce_op, value, destinations):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, values.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`_reduce_non_distributed_value`, which is not allowed.")

  # If the same value is present on all replicas then the PerReplica value will
  # be a single value. We also handle the case when `value` is a single value
  # and equal to 0.
  if value == 0:
    return 0
  # If there is only a single value and the reduce op is MEAN,
  # that value should be on all destinations.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value

  cross_tower_ops_lib.validate_destinations(destinations)
  # We do not support a reduce op of SUM when the value is the same across
  # all replicas. This path is exercised by the assign functions of
  # MirroredVariables, and summing identical values across replicas is not
  # clearly defined.
  if (len(extended.worker_devices) != 1 or
      not cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  # TODO(anjalisridhar): Move these methods to a device utility file?
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    with ops.device(devices[0]):
      return array_ops.identity(value)
  else:
    value_updates = {}
    for d in devices:
      with ops.device(d):
        value_updates[d] = array_ops.identity(value)
    return values.Mirrored(value_updates)
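
The comments above capture the key invariant: a non-distributed value reaching this function is identical on every replica, so MEAN leaves it unchanged, while SUM would silently scale it by the replica count with no clearly defined meaning. A one-line numeric illustration:

identical = [5.0, 5.0, 5.0]                    # same value on every replica
assert sum(identical) / len(identical) == 5.0  # MEAN: unchanged
assert sum(identical) == 15.0                  # SUM: scales by the replica count
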
Example #9
    def _reduce(self, aggregation, value, destinations):
        graph = ops.get_default_graph()
        cf_context = graph._get_control_flow_context()  # pylint: disable=protected-access
        # If we're inside the ReplicateContext, reduction should be done using
        # CrossReplicaSum while outside we can directly use an add_n op.
        while cf_context:
            if isinstance(cf_context, tpu.TPUReplicateContext):
                if aggregation == vs.VariableAggregation.MEAN:
                    # TODO(jhseu):  Revisit once we support model-parallelism.
                    value *= (1. / self.num_towers)
                return tpu_ops.cross_replica_sum(value)
            cf_context = cf_context.outer_context

        # Validate that the destination is the same as the host device.
        # Note we don't do this when in a replicate context as the reduction is
        # performed on the TPU device itself.
        devices = cross_tower_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert device_util.canonicalize(
                devices[0]) == device_util.canonicalize(self._host)
        else:
            raise ValueError(
                'Multiple devices are not supported for TPUStrategy')

        output = math_ops.add_n(value)
        if aggregation == vs.VariableAggregation.MEAN:
            return output * (1. / len(value))
        return output
Example #10
def _reduce_non_distributed_value(distribution, aggregation, value,
                                  destinations):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, values.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`_reduce_non_distributed_value`, which is not allowed.")

  if value == 0:
    return 0
  if aggregation == variable_scope.VariableAggregation.MEAN:
    return distribution.broadcast(value, destinations)

  cross_tower_ops_lib.validate_destinations(destinations)
  if (len(distribution.worker_devices) != 1 or
      not cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError("A non-DistributedValues value cannot be reduced with the "
                     "given aggregation.")
  # TODO(anjalisridhar): Move these methods to a device utility file?
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    with ops.device(devices[0]):
      return array_ops.identity(value)
  else:
    value_updates = {}
    for d in devices:
      with ops.device(d):
        value_updates[d] = array_ops.identity(value)
    return values.Mirrored(value_updates)
Example #11
def _make_per_device(values, devices):
    devices = cross_tower_ops_lib.get_devices_from(devices)
    assert len(values) == len(devices)
    index = {}
    for d, v in zip(devices, values):
        with ops.device(d):
            placed_v = array_ops.identity(v)
        index[d] = placed_v
    return value_lib.PerDevice(index)
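
A hypothetical use of the helper above (the device strings are illustrative, and the internal-module imports are assumptions based on a TF 1.x-era layout). Each value is identity-pinned onto its device and collected into a PerDevice map:

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

with ops.Graph().as_default():
    per_device = _make_per_device(
        [constant_op.constant(1.0), constant_op.constant(2.0)],
        ["/device:GPU:0", "/device:GPU:1"])
    # per_device wraps a dict mapping each device string to the
    # identity op placed on that device.
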
Example #12
def _fake_mirrored(value, devices):
  """Create a faked Mirrored object for testing.

  All components of the returned Mirrored have the same objects, which is not
  true in reality.
  """
  devices = cross_tower_ops_lib.get_devices_from(devices)
  return value_lib.Mirrored(
      {d: v for d, v in zip(devices, [value] * len(devices))})
Example #13
def _make_per_device(values, devices):
  devices = cross_tower_ops_lib.get_devices_from(devices)
  assert len(values) == len(devices)
  index = {}
  for d, v in zip(devices, values):
    with ops.device(d):
      placed_v = array_ops.identity(v)
    index[d] = placed_v
  return value_lib.PerDevice(index)
Example #14
 def _verify_destinations_not_different_worker(self, destinations):
   if destinations is None:
     return
   for d in cross_tower_ops_lib.get_devices_from(destinations):
     d_spec = tf_device.DeviceSpec.from_string(d)
     if d_spec.job == self._task_type and d_spec.task != self._task_id:
       raise ValueError(
           "Cannot reduce to another worker: %r, current worker is %r" %
           (d, self._worker_device))
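
The check above hinges on the job and task fields parsed from a device string. tf.DeviceSpec, the public counterpart of the internal tf_device.DeviceSpec used here, shows what those fields look like:

import tensorflow as tf

spec = tf.DeviceSpec.from_string("/job:worker/task:1/device:GPU:0")
assert spec.job == "worker" and spec.task == 1
# The guard raises only when spec.job matches this worker's job but
# spec.task points at a different worker.
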
Example #15
 def _verify_destinations_not_different_worker(self, destinations):
     if destinations is None:
         return
     for d in cross_tower_ops_lib.get_devices_from(destinations):
         d_spec = tf_device.DeviceSpec.from_string(d)
         if d_spec.job == self._task_type and d_spec.task != self._task_id:
             raise ValueError(
                 "Cannot reduce to another worker: %r, current worker is %r"
                 % (d, self._worker_device))
Example #16
def _make_per_replica(values, devices, regroup=False):
  devices = cross_tower_ops_lib.get_devices_from(devices)
  assert len(values) == len(devices)

  # We simulate the result of regroup called on PerReplica which strips the
  # PerReplica wrapper if it has only one value.
  if len(values) == 1 and regroup:
    with ops.device(devices[0]):
      placed_v = array_ops.identity(values[0])
    return placed_v

  index = {}
  for d, v in zip(devices, values):
    with ops.device(d):
      placed_v = array_ops.identity(v)
    index[d] = placed_v
  return value_lib.PerReplica(index)
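
The regroup special case above returns the bare tensor rather than a one-entry PerReplica wrapper. A plain-Python sketch of that unwrapping:

def regroup_like(values, devices):
    if len(values) == 1:
        return values[0]               # wrapper stripped
    return dict(zip(devices, values))  # stand-in for PerReplica(index)

assert regroup_like([7], ["/device:GPU:0"]) == 7
assert regroup_like([7, 8], ["/device:GPU:0", "/device:GPU:1"]) == {
    "/device:GPU:0": 7, "/device:GPU:1": 8}
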
Example #17
def _make_per_device(values, devices, regroup=False):
    devices = cross_tower_ops_lib.get_devices_from(devices)
    assert len(values) == len(devices)

    # We simulate the result of regroup called on PerDevice which strips the
    # PerDevice wrapper if it has only one value.
    if len(values) == 1 and regroup:
        with ops.device(devices[0]):
            placed_v = array_ops.identity(values[0])
        return placed_v

    index = {}
    for d, v in zip(devices, values):
        with ops.device(d):
            placed_v = array_ops.identity(v)
        index[d] = placed_v
    return value_lib.PerDevice(index)
Example #18
def _reduce_non_distributed_value(distribution, aggregation, value,
                                  destinations):
    """Reduce a non-DistributedValue `value` to `destinations`."""
    if isinstance(value, values.DistributedValues):
        raise ValueError(
            "You are passing a `DistributedValue` to "
            "`_reduce_non_distributed_value`, which is not allowed.")

    # If the same value is present on all towers then the PerDevice value will
    # be a single value. We also handle the case when `value` is a single value
    # and equal to 0.
    if value == 0:
        return 0
    # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this
    # essentially means that the same value should be on all destinations.
    if aggregation in (variable_scope.VariableAggregation.MEAN,
                       variable_scope.VariableAggregation.ONLY_FIRST_TOWER):
        return value

    cross_tower_ops_lib.validate_destinations(destinations)
    # We do not support an aggregation type of SUM when the value is the same
    # across all towers. This path is exercised by the assign functions of
    # MirroredVariables, and summing identical values across towers is not
    # clearly defined.
    if (len(distribution.worker_devices) != 1
            or not cross_tower_ops_lib.check_destinations(destinations)):
        raise ValueError(
            "A non-DistributedValues value %s cannot be reduced with "
            "the given aggregation %s." % (value, aggregation))
    # TODO(anjalisridhar): Move these methods to a device utility file?
    devices = cross_tower_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
        with ops.device(devices[0]):
            return array_ops.identity(value)
    else:
        value_updates = {}
        for d in devices:
            with ops.device(d):
                value_updates[d] = array_ops.identity(value)
        return values.Mirrored(value_updates)
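
This variant also short-circuits ONLY_FIRST_TOWER alongside MEAN: since the value is identical on every tower, the first tower's copy and the mean are both the value itself. A minimal numeric check:

identical = [5.0, 5.0, 5.0]  # same value on every tower
assert identical[0] == sum(identical) / len(identical) == 5.0
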
Example #19
def _reduce_non_distributed_value(distribution, aggregation, value,
                                  destinations):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, values.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`_reduce_non_distributed_value`, which is not allowed.")

  # If the same value is present on all towers then the PerDevice value will
  # be a single value. We also handle the case when `value` is a single value
  # and equal to 0.
  if value == 0:
    return 0
  # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this
  # essentially means that the same value should be on all destinations.
  if aggregation in (
      variable_scope.VariableAggregation.MEAN,
      variable_scope.VariableAggregation.ONLY_FIRST_TOWER):
    return value

  cross_tower_ops_lib.validate_destinations(destinations)
  # We do not support an aggregation type of SUM when the value is the same
  # across all towers. This path is exercised by the assign functions of
  # MirroredVariables, and summing identical values across towers is not
  # clearly defined.
  if (len(distribution.worker_devices) != 1 or
      not cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given aggregation %s." % (value, aggregation))
  # TODO(anjalisridhar): Move these methods to a device utility file?
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    with ops.device(devices[0]):
      return array_ops.identity(value)
  else:
    value_updates = {}
    for d in devices:
      with ops.device(d):
        value_updates[d] = array_ops.identity(value)
    return values.Mirrored(value_updates)
Example #20
 def _get_devices_from(self, colocate_with=None):
     if colocate_with is None:
         return self._devices
     else:
         return cross_tower_ops_lib.get_devices_from(colocate_with)
Example #21
 def _get_devices_from(self, colocate_with=None):
   if colocate_with is None:
     return self._devices
   else:
     return cross_tower_ops_lib.get_devices_from(colocate_with)