def _simple_broadcast(tensor, destinations):
  """Copies `tensor` onto every device resolved from `destinations`.

  Args:
    tensor: the value to copy.
    destinations: anything accepted by `_get_devices_from`.

  Returns:
    A `value_lib.Mirrored` mapping each resolved device to an
    `array_ops.identity` of `tensor` created under that device's scope.
  """
  target_devices = _get_devices_from(destinations)

  def _copy_onto(device):
    # The identity op is created inside the device scope so it is placed there.
    with ops.device(device):
      return array_ops.identity(tensor)

  return value_lib.Mirrored(
      {device: _copy_onto(device) for device in target_devices})
def _reduce(self, aggregation, value, destinations):
  """Reduces `value` according to `aggregation` and places it on `destinations`.

  PerDevice values are delegated to the strategy's cross-tower ops. A plain
  (non-PerDevice) value is handled locally:
    * `value == 0` returns 0;
    * MEAN aggregation broadcasts `value` to `destinations`;
    * otherwise it is only allowed when `self._devices` has exactly one entry
      and `destinations` is truthy. The result is then an identity copy on the
      single destination device, or a `values.Mirrored` of copies when there
      are several destination devices.

  Raises:
    AssertionError: if `value` is already a `values.Mirrored` (via `assert`).
    ValueError: if a non-PerDevice value cannot be reduced as described above.
  """
  assert not isinstance(value, values.Mirrored)
  if isinstance(value, values.PerDevice):
    return self._get_cross_tower_ops().reduce(
        aggregation, value, destinations=destinations)

  # From here on `value` is a plain, non-PerDevice value.
  if value == 0:
    return 0
  if aggregation == variable_scope.VariableAggregation.MEAN:
    return self._broadcast(value, destinations)

  cross_tower_ops_lib.validate_destinations(destinations)
  if len(self._devices) != 1 or not destinations:
    raise ValueError(
        "A non PerDevice value cannot be reduced with the given "
        "aggregation.")

  # TODO(anjalisridhar): Moves these methods to a device utility file?
  target_devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(target_devices) != 1:
    copies = {}
    for device in target_devices:
      with ops.device(device):
        copies[device] = array_ops.identity(value)
    return values.Mirrored(copies)
  with ops.device(target_devices[0]):
    return array_ops.identity(value)
def _fake_mirrored(value, devices):
  """Creates a fake Mirrored object for testing.

  Every component of the returned Mirrored is the very same `value` object.
  A real Mirrored value does not share one object across its components.

  Args:
    value: the object stored for every device.
    devices: anything accepted by `cross_tower_ops_lib._get_devices_from`.

  Returns:
    A `value_lib.Mirrored` mapping each resolved device to `value`.
  """
  devices = cross_tower_ops_lib._get_devices_from(devices)  # pylint: disable=protected-access
  # A plain comprehension pairs every device with `value`; no need to build
  # `[value] * len(devices)` and zip it.
  return value_lib.Mirrored({d: value for d in devices})
def _reduce(self, aggregation, per_device_value, destinations):
  """All-reduces `per_device_value` and returns the result on `destinations`.

  The reduction is done by `self._batch_all_reduce` on a one-element batch.
  If `destinations` is None or `_devices_match` reports that it names the
  same devices as `per_device_value`, the all-reduced value is returned
  as-is. Otherwise a new `value_lib.Mirrored` is built:
    * a destination device that already holds a reduced component reuses it;
    * any other destination device gets an `array_ops.identity` of the first
      reduced component, created on that device with control dependencies on
      every reduced component.

  Args:
    aggregation: how values are combined; passed to `_batch_all_reduce`.
    per_device_value: the per-device value to reduce.
    destinations: target devices, or None to keep the reduced placement.

  Returns:
    The all-reduced value, or a `value_lib.Mirrored` placed on `destinations`.
  """
  all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
  if destinations is None or _devices_match(per_device_value, destinations):
    return all_reduced

  reduced_index = all_reduced._index  # pylint: disable=protected-access
  # Loop-invariant: the reduced components do not change while the new index
  # is built, so materialize them once instead of on every iteration.
  reduced_values = list(reduced_index.values())
  index = {}
  for d in get_devices_from(destinations):
    if d in reduced_index:
      index[d] = reduced_index[d]
    else:
      with ops.control_dependencies(reduced_values), ops.device(d):
        index[d] = array_ops.identity(reduced_values[0])
  return value_lib.Mirrored(index)
def _reduce_non_distributed_value(distribution, reduce_op, value,
                                  destinations):
  """Reduce a non-DistributedValue `value` to `destinations`.

  Args:
    distribution: the distribution strategy; its `worker_devices` is read.
    reduce_op: a `reduce_util.ReduceOp`.
    value: a plain (non-DistributedValues) value.
    destinations: where the result should be placed.

  Returns:
    0 when `value == 0`; `value` unchanged for MEAN or ONLY_FIRST_REPLICA;
    otherwise an identity of `value` on the single destination device, or a
    `values.Mirrored` of identities when there are several destination
    devices.

  Raises:
    ValueError: if `value` is a DistributedValues, or if any other reduce op
      is requested while there is not exactly one worker device or
      `check_destinations` rejects `destinations`. Errors raised by
      `validate_destinations` propagate as well.
  """
  if isinstance(value, values.DistributedValues):
    raise ValueError(
        "You are passing a `DistributedValue` to "
        "`_reduce_non_distributed_value`, which is not allowed.")

  # A PerReplica value that is the same on every replica arrives here as one
  # plain value; a plain 0 is returned unchanged.
  if value == 0:
    return 0

  # MEAN / ONLY_FIRST_REPLICA of an identical value is the value itself.
  if reduce_op in (reduce_util.ReduceOp.MEAN,
                   reduce_util.ReduceOp.ONLY_FIRST_REPLICA):
    return value

  cross_tower_ops_lib.validate_destinations(destinations)
  # Summing an identical value across replicas is not clearly defined (this
  # path is used by MirroredVariable assign functions), so only proceed with
  # exactly one worker device and destinations accepted by check_destinations.
  has_single_worker = len(distribution.worker_devices) == 1
  if not (has_single_worker and
          cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError(
        "A non-DistributedValues value %s cannot be reduced with "
        "the given reduce op %s." % (value, reduce_op))

  # TODO(anjalisridhar): Moves these methods to a device utility file?
  target_devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(target_devices) != 1:
    copies = {}
    for device in target_devices:
      with ops.device(device):
        copies[device] = array_ops.identity(value)
    return values.Mirrored(copies)
  with ops.device(target_devices[0]):
    return array_ops.identity(value)
def _reduce_non_distributed_value(distribution, aggregation, value,
                                  destinations):
  """Reduce a non-DistributedValue `value` to `destinations`.

  Args:
    distribution: the distribution strategy; its `worker_devices` and
      `broadcast` are used.
    aggregation: a `variable_scope.VariableAggregation`.
    value: a plain (non-DistributedValues) value.
    destinations: where the result should be placed.

  Returns:
    0 when `value == 0`; `distribution.broadcast(value, destinations)` for
    MEAN; otherwise an identity of `value` on the single destination device,
    or a `values.Mirrored` of identities when there are several destination
    devices.

  Raises:
    ValueError: if `value` is a DistributedValues, or if a non-MEAN
      aggregation is requested while there is not exactly one worker device
      or `check_destinations` rejects `destinations`. Errors raised by
      `validate_destinations` propagate as well.
  """
  if isinstance(value, values.DistributedValues):
    raise ValueError(
        "You are passing a `DistributedValue` to "
        "`_reduce_non_distributed_value`, which is not allowed.")

  # A PerDevice value that is the same on every tower arrives here as one
  # plain value; a plain 0 is returned unchanged.
  if value == 0:
    return 0

  # MEAN of an identical value is that value, so it only has to be made
  # available on every destination.
  if aggregation == variable_scope.VariableAggregation.MEAN:
    return distribution.broadcast(value, destinations)

  cross_tower_ops_lib.validate_destinations(destinations)
  # Summing an identical value across towers is not clearly defined (this
  # path is used by MirroredVariable assign functions), so only proceed with
  # exactly one worker device and destinations accepted by check_destinations.
  has_single_worker = len(distribution.worker_devices) == 1
  if not (has_single_worker and
          cross_tower_ops_lib.check_destinations(destinations)):
    raise ValueError(
        "A non-DistributedValues value cannot be reduced with the "
        "given aggregation.")

  # TODO(anjalisridhar): Moves these methods to a device utility file?
  target_devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(target_devices) != 1:
    copies = {}
    for device in target_devices:
      with ops.device(device):
        copies[device] = array_ops.identity(value)
    return values.Mirrored(copies)
  with ops.device(target_devices[0]):
    return array_ops.identity(value)
def _reduce(self, aggregation, per_replica_value, destinations):
  """Collective all-reduce of `per_replica_value`, placed on `destinations`.

  The reduction is done by `self._batch_all_reduce` on a one-element batch.
  If `_devices_match` reports that `destinations` names the same devices as
  `per_replica_value`, the all-reduced value is returned as-is. Otherwise a
  new `value_lib.Mirrored` is built:
    * a destination device that already holds a reduced component reuses it;
    * any other destination device gets an `array_ops.identity` of the first
      reduced component, created on that device with control dependencies on
      every reduced component.

  Args:
    aggregation: how values are combined; passed to `_batch_all_reduce`.
    per_replica_value: the per-replica value to reduce.
    destinations: target devices for the result.

  Returns:
    The all-reduced value, or a `value_lib.Mirrored` placed on `destinations`.

  Raises:
    ValueError: if `per_replica_value` contains IndexedSlices, or if running
      eagerly.
  """
  if cross_tower_utils.contains_indexed_slices(per_replica_value):
    raise ValueError(
        "`IndexedSlices` is not supported for Collective All-Reduce.")
  if context.executing_eagerly():
    raise ValueError(
        "Eager execution is not supported for Collective All-Reduce")

  all_reduced = self._batch_all_reduce(aggregation, [per_replica_value])[0]
  if _devices_match(per_replica_value, destinations):
    return all_reduced

  reduced_index = all_reduced._index  # pylint: disable=protected-access
  # Loop-invariant: the reduced components do not change while the new index
  # is built, so materialize them once instead of on every iteration.
  reduced_values = list(reduced_index.values())
  index = {}
  for d in get_devices_from(destinations):
    if d in reduced_index:
      index[d] = reduced_index[d]
    else:
      with ops.control_dependencies(reduced_values), ops.device(d):
        index[d] = array_ops.identity(reduced_values[0])
  return value_lib.Mirrored(index)
def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
  """Ungroup results from all-reduce and make Mirrored objects.

  Each all-reduce result will be divided by the number of destinations before
  Mirrored objects are created if method_string is "mean".

  Args:
    grouped_reduced: a list of per-device lists. `grouped_reduced[d][i]` is a
      pair whose first element is the reduced value of the i-th item on device
      `destinations[d]`. The second element (documented as None) is ignored.
      It is the result from cross_tower_utils.aggregate_gradients_using*.
    destinations: a list of device strings for returned Mirrored objects, in
      the same order as the outer list of `grouped_reduced`.
    method_string: "mean" or "sum". Any value other than "mean" leaves the
      reduced values undivided.

  Returns:
    a list of Mirrored objects, one per item (the length of
    `grouped_reduced[0]`); an empty list if `grouped_reduced` is empty.
  """
  # Without this guard, `grouped_reduced[0]` below raises IndexError.
  if not grouped_reduced:
    return []

  is_mean = method_string == "mean"
  index = [{} for _ in range(len(grouped_reduced[0]))]
  for d, per_device_reduced in enumerate(grouped_reduced):
    for i, (v, _) in enumerate(per_device_reduced):
      index[i][destinations[d]] = v / len(destinations) if is_mean else v
  return [value_lib.Mirrored(v) for v in index]
def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
  """Builds a Mirrored holding one `_make_indexed_slices` result per device.

  Args:
    devices: the devices to create components for.
    values: forwarded to `_make_indexed_slices`.
    indices: forwarded to `_make_indexed_slices`.
    dense_shape: forwarded to `_make_indexed_slices`.

  Returns:
    A `value_lib.Mirrored` mapping each device in `devices` to
    `_make_indexed_slices(values, indices, dense_shape, device)`.
  """
  per_device = {}
  for device in devices:
    per_device[device] = _make_indexed_slices(values, indices, dense_shape,
                                              device)
  return value_lib.Mirrored(per_device)