Exemplo n.º 1
0
def _reduce_non_distributed_value(distribution, aggregation, value,
                                  destinations):
  """Reduce a non-DistributedValue `value` to `destinations`.

  Args:
    distribution: the distribution strategy performing the reduction.
    aggregation: a `variable_scope.VariableAggregation` value.
    value: a plain (non-distributed) value.
    destinations: where the reduced result should be placed.

  Returns:
    The value placed on the destination device(s).

  Raises:
    ValueError: if `value` is a `DistributedValues`, or it cannot be reduced
      with the given aggregation on the given destinations.
  """
  if isinstance(value, values.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`_reduce_non_distributed_value`, which is not allowed.")

  # A literal zero passes straight through unchanged.
  if value == 0:
    return 0
  # With MEAN aggregation a lone value is simply broadcast to destinations.
  if aggregation == variable_scope.VariableAggregation.MEAN:
    return distribution.broadcast(value, destinations)

  cross_tower_ops_lib.validate_destinations(destinations)
  if (len(distribution.worker_devices) != 1 or
      not cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError("A non-DistributedValues value cannot be reduced with the "
                     "given aggregation.")
  # TODO(anjalisridhar): Moves these methods to a device utility file?
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    with ops.device(devices[0]):
      return array_ops.identity(value)
  # Mirror an identity copy of `value` onto every destination device.
  mirrored = {}
  for device in devices:
    with ops.device(device):
      mirrored[device] = array_ops.identity(value)
  return values.Mirrored(mirrored)
Exemplo n.º 2
0
def _reduce_non_distributed_value(extended, reduce_op, value, destinations):
  """Reduce a non-DistributedValue `value` to `destinations`.

  Args:
    extended: the extended strategy that owns the worker devices.
    reduce_op: a `reduce_util.ReduceOp` describing the reduction.
    value: a plain, non-distributed value.
    destinations: where the reduced result should be placed.

  Returns:
    The value mirrored or copied onto the destination device(s).

  Raises:
    ValueError: if `value` is a `DistributedValues`, or it cannot be reduced
      with `reduce_op` on the given destinations.
  """
  if isinstance(value, values.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`_reduce_non_distributed_value`, which is not allowed.")

  # A PerReplica value identical on every replica arrives here as one value;
  # a literal 0 is passed back untouched.
  if value == 0:
    return 0
  # MEAN of a single replicated value is that value, valid on every
  # destination.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value

  cross_tower_ops_lib.validate_destinations(destinations)
  # Summing identical values across replicas (e.g. from MirroredVariable
  # assign functions) has no clear meaning, so reject everything but the
  # single-worker case with known destinations.
  if (len(extended.worker_devices) != 1 or
      not cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  # TODO(anjalisridhar): Moves these methods to a device utility file?
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) != 1:
    replicated = {}
    for device in devices:
      with ops.device(device):
        replicated[device] = array_ops.identity(value)
    return values.Mirrored(replicated)
  with ops.device(devices[0]):
    return array_ops.identity(value)
Exemplo n.º 3
0
def _reduce_non_distributed_value(distribution, aggregation, value,
                                  destinations):
  """Reduce a non-DistributedValue `value` to `destinations`.

  Args:
    distribution: the distribution strategy performing the reduction.
    aggregation: a `variable_scope.VariableAggregation` value.
    value: a plain (non-distributed) value.
    destinations: where the reduced result should live.

  Returns:
    The value placed on the destination device(s).

  Raises:
    ValueError: for a `DistributedValues` input, or an aggregation that
      cannot be applied to a single value.
  """
  if isinstance(value, values.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`_reduce_non_distributed_value`, which is not allowed.")

  # Short-circuit: a plain zero needs no placement.
  if value == 0:
    return 0
  if aggregation == variable_scope.VariableAggregation.MEAN:
    # The mean of one value is itself; broadcasting suffices.
    return distribution.broadcast(value, destinations)

  cross_tower_ops_lib.validate_destinations(destinations)
  if (len(distribution.worker_devices) != 1 or
      not cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError("A non-DistributedValues value cannot be reduced with the "
                     "given aggregation.")
  # TODO(anjalisridhar): Moves these methods to a device utility file?
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    with ops.device(devices[0]):
      return array_ops.identity(value)

  # Place an identity copy of `value` on each destination device.
  def _copy_to(device):
    with ops.device(device):
      return array_ops.identity(value)

  return values.Mirrored({d: _copy_to(d) for d in devices})
def _reduce_non_distributed_value(distribution, aggregation, value,
                                  destinations):
    """Reduce a non-DistributedValue `value` to `destinations`.

    Args:
        distribution: the distribution strategy performing the reduction.
        aggregation: a `variable_scope.VariableAggregation` value.
        value: a plain (non-distributed) value.
        destinations: where the reduced result should be placed.

    Returns:
        The value placed on the destination device(s).

    Raises:
        ValueError: if `value` is a `DistributedValues`, or it cannot be
            reduced with the given aggregation on the given destinations.
    """
    if isinstance(value, values.DistributedValues):
        raise ValueError(
            "You are passing a `DistributedValue` to "
            "`_reduce_non_distributed_value`, which is not allowed.")

    # A PerDevice value identical on every tower shows up here as a single
    # value; a literal 0 is returned unchanged.
    if value == 0:
        return 0
    # For MEAN or ONLY_FIRST_TOWER the single value is already what every
    # destination should hold.
    if aggregation in (variable_scope.VariableAggregation.MEAN,
                       variable_scope.VariableAggregation.ONLY_FIRST_TOWER):
        return value

    cross_tower_ops_lib.validate_destinations(destinations)
    # SUM over identical per-tower values (hit via MirroredVariable assign
    # functions) is not clearly defined; only the single-worker case with
    # known destinations is accepted.
    if (len(distribution.worker_devices) != 1
            or not cross_tower_ops_lib.check_destinations(destinations)):
        raise ValueError(
            "A non-DistributedValues value %s cannot be reduced with "
            "the given aggregation %s." % (value, aggregation))
    # TODO(anjalisridhar): Moves these methods to a device utility file?
    devices = cross_tower_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
        with ops.device(devices[0]):
            return array_ops.identity(value)

    # Pin an identity copy of `value` onto a given device.
    def _identity_on(device):
        with ops.device(device):
            return array_ops.identity(value)

    return values.Mirrored({d: _identity_on(d) for d in devices})
Exemplo n.º 5
0
def _reduce_non_distributed_value(distribution, aggregation, value,
                                  destinations):
  """Reduce a non-DistributedValue `value` to `destinations`.

  Args:
    distribution: the distribution strategy performing the reduction.
    aggregation: a `variable_scope.VariableAggregation` value.
    value: a plain (non-distributed) value.
    destinations: where the reduced result should be placed.

  Returns:
    The value mirrored or copied onto the destination device(s).

  Raises:
    ValueError: if `value` is a `DistributedValues`, or it cannot be reduced
      with the given aggregation on the given destinations.
  """
  if isinstance(value, values.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`_reduce_non_distributed_value`, which is not allowed.")

  # A value identical on every tower arrives here as one value; a plain 0
  # is handed back unchanged.
  if value == 0:
    return 0
  # Under MEAN or ONLY_FIRST_TOWER the single value is already what each
  # destination should see.
  if aggregation in (
      variable_scope.VariableAggregation.MEAN,
      variable_scope.VariableAggregation.ONLY_FIRST_TOWER):
    return value

  cross_tower_ops_lib.validate_destinations(destinations)
  # SUM of identical per-tower values (reached via MirroredVariable assign
  # functions) is not clearly defined, so anything but the single-worker case
  # with known destinations is rejected.
  if (len(distribution.worker_devices) != 1 or
      not cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given aggregation %s." % (value, aggregation))
  # TODO(anjalisridhar): Moves these methods to a device utility file?
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) != 1:
    per_device = {}
    for device in devices:
      with ops.device(device):
        per_device[device] = array_ops.identity(value)
    return values.Mirrored(per_device)
  with ops.device(devices[0]):
    return array_ops.identity(value)
 def _broadcast(self, tensor, destinations):
     """Broadcast `tensor` to `destinations`, or to all compute devices."""
     # With no usable destinations, mirror onto every compute device.
     targets = (destinations
                if cross_tower_ops_lib.check_destinations(destinations)
                else self._compute_devices)
     return self._cross_tower_ops.broadcast(tensor, targets)
 def _broadcast(self, tensor, destinations):
   """Broadcast `tensor` to `destinations` (all compute devices if none)."""
   if cross_tower_ops_lib.check_destinations(destinations):
     return self._cross_tower_ops.broadcast(tensor, destinations)
   # Fall back to broadcasting across every compute device.
   return self._cross_tower_ops.broadcast(tensor, self._compute_devices)