def testContainsIndexedSlices_PerReplica(self):
  """A PerReplica wrapping IndexedSlices is reported as containing them."""
  slices_a = math_ops._as_indexed_slices(
      constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
  slices_b = math_ops._as_indexed_slices(
      constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
  replica_value = value_lib.PerReplica(
      {"/gpu:0": slices_a, "/cpu:0": slices_b})
  self.assertTrue(
      cross_tower_utils.contains_indexed_slices(replica_value))
def _reduce(self, aggregation, per_device_value, destinations):
  """All-reduces `per_device_value` and mirrors the result to `destinations`.

  Args:
    aggregation: how the per-device values are combined.
    per_device_value: a per-device value holding one tensor per device.
    destinations: the devices that should receive the reduced value.

  Returns:
    A `Mirrored` value with one copy of the reduction per destination
    device.

  Raises:
    ValueError: if the input contains `IndexedSlices`, or if eager
      execution is enabled — neither is supported by Collective
      All-Reduce.
  """
  if cross_tower_utils.contains_indexed_slices(per_device_value):
    # Fixed typo in the message: "IndexSlices" -> "IndexedSlices".
    raise ValueError(
        "`IndexedSlices` is not supported for Collective All-Reduce.")
  if context.executing_eagerly():
    raise ValueError(
        "Eager execution is not supported for Collective All-Reduce")
  all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
  if _devices_match(per_device_value, destinations):
    return all_reduced
  index = {}
  for d in get_devices_from(destinations):
    # pylint: disable=protected-access
    if d in all_reduced._index:
      index[d] = all_reduced._index[d]
    else:
      # Destination was not part of the all-reduce: copy one of the
      # reduced tensors onto it, after the reduction is complete.
      with ops.control_dependencies(
          list(all_reduced._index.values())), ops.device(d):
        index[d] = array_ops.identity(
            list(all_reduced._index.values())[0])
  return value_lib.Mirrored(index)
def testContainsIndexedSlices_PerReplica(self):
  """Detection works through a PerReplica container of IndexedSlices."""
  first = math_ops._as_indexed_slices(
      constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
  second = math_ops._as_indexed_slices(
      constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
  container = value_lib.PerReplica({"/gpu:0": first, "/cpu:0": second})
  found = cross_tower_utils.contains_indexed_slices(container)
  self.assertTrue(found)
def testContainsIndexedSlices_PerDeviceMapOutput(self):
  """IndexedSlices nested inside MapOutput within a PerDevice are found."""
  slices_gpu = math_ops._as_indexed_slices(
      constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
  slices_cpu = math_ops._as_indexed_slices(
      constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
  nested = value_lib.PerDevice({
      "/gpu:0": value_lib.MapOutput([slices_gpu]),
      "/cpu:0": value_lib.MapOutput([slices_cpu])})
  self.assertTrue(cross_tower_utils.contains_indexed_slices(nested))
def _batch_reduce(self, aggregation, value_destination_pairs):
  """Reduces a batch of (value, destinations) pairs.

  Uses the efficient batched all-reduce only when every pair shares the
  same devices, execution is not eager, and no value is an IndexedSlices;
  otherwise falls back to reducing each pair individually.
  """
  devices_all_match = _all_devices_match(value_destination_pairs)
  has_indexed_slices = cross_tower_utils.contains_indexed_slices(
      value_destination_pairs)
  eager = context.executing_eagerly()
  if devices_all_match and not eager and not has_indexed_slices:
    values = [pair[0] for pair in value_destination_pairs]
    return self._batch_all_reduce(aggregation, values)
  if not devices_all_match:
    logging.warning("Efficient batch_reduce is not supported if "
                    "destinations are different.")
  return [self._reduce(aggregation, value, destinations=dests)
          for value, dests in value_destination_pairs]
def _batch_reduce(self, aggregation, value_destination_pairs):
  """Reduces several (value, destinations) pairs, batching when possible.

  The batched all-reduce path requires matching devices across all pairs,
  graph-mode execution, and no IndexedSlices anywhere in the inputs.
  """
  matching = _all_devices_match(value_destination_pairs)
  sparse_present = cross_tower_utils.contains_indexed_slices(
      value_destination_pairs)
  can_batch = (matching
               and not context.executing_eagerly()
               and not sparse_present)
  if can_batch:
    return self._batch_all_reduce(
        aggregation, [pair[0] for pair in value_destination_pairs])
  if not matching:
    logging.warning("Efficient batch_reduce is not supported if "
                    "destinations are different.")
  results = []
  for value, dests in value_destination_pairs:
    results.append(self._reduce(aggregation, value, destinations=dests))
  return results
def _reduce(self, aggregation, per_device_value, destinations):
  """Reduces `per_device_value`, preferring the efficient all-reduce path.

  Falls back to a simple reduce-then-broadcast when destinations differ
  from the source devices, execution is eager, or the value contains
  IndexedSlices.
  """
  sparse_input = cross_tower_utils.contains_indexed_slices(
      per_device_value)
  same_devices = (destinations is None
                  or _devices_match(per_device_value, destinations))
  if same_devices and not context.executing_eagerly() and not sparse_input:
    return self._batch_all_reduce(aggregation, [per_device_value])[0]
  if sparse_input:
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for IndexedSlices.", 10)
  target_devices = get_devices_from(destinations or per_device_value)
  # Reduce onto the first device, then broadcast everywhere.
  reduced = _simple_reduce(per_device_value, target_devices[0],
                           math_ops.add_n, aggregation)
  return self.broadcast(reduced, target_devices)
def _reduce(self, aggregation, per_device_value, destinations):
  """Reduces a per-device value to `destinations`.

  Uses the batched all-reduce when the destinations match the value's
  own devices, we are in graph mode, and the value holds no
  IndexedSlices; otherwise performs a simple reduce and broadcast.
  """
  with_slices = cross_tower_utils.contains_indexed_slices(
      per_device_value)
  fast_path_ok = (
      (destinations is None
       or _devices_match(per_device_value, destinations))
      and not context.executing_eagerly()
      and not with_slices)
  if fast_path_ok:
    return self._batch_all_reduce(aggregation, [per_device_value])[0]
  if with_slices:
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for IndexedSlices.", 10)
  devices = get_devices_from(destinations or per_device_value)
  first_device = devices[0]
  accumulated = _simple_reduce(per_device_value, first_device,
                               math_ops.add_n, aggregation)
  return self.broadcast(accumulated, devices)
def _batch_reduce(self, aggregation, value_destination_pairs):
  """Reduces a batch of (value, destinations) pairs via collective ops.

  Args:
    aggregation: how the per-device values are combined.
    value_destination_pairs: a list of (per-device value, destinations)
      tuples.

  Returns:
    A list of reduced values, one per input pair.

  Raises:
    ValueError: if the input contains `IndexedSlices`, or if eager
      execution is enabled — neither is supported by Collective
      All-Reduce.
  """
  if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
    # Fixed typo in the message: "IndexSlices" -> "IndexedSlices".
    raise ValueError(
        "`IndexedSlices` is not supported for Collective All-Reduce.")
  if context.executing_eagerly():
    raise ValueError(
        "Eager execution is not supported for Collective All-Reduce")
  if _all_devices_match(value_destination_pairs):
    return self._batch_all_reduce(aggregation,
                                  [v[0] for v in value_destination_pairs])
  # Removed a redundant re-check of `all_devices_match` here: on this
  # branch it is always False.
  logging.log_first_n(
      logging.WARN, "Efficient batch_reduce is not supported if "
      "destinations are different.", 10)
  return [
      self._reduce(aggregation, t, destinations=v)
      for t, v in value_destination_pairs
  ]
def _batch_reduce(self, aggregation, value_destination_pairs):
  """Batch-reduces (value, destinations) pairs with Collective All-Reduce.

  Args:
    aggregation: how the per-device values are combined.
    value_destination_pairs: a list of (per-device value, destinations)
      tuples.

  Returns:
    A list of reduced values, one per input pair.

  Raises:
    ValueError: if any value contains `IndexedSlices`, or under eager
      execution — both unsupported by Collective All-Reduce.
  """
  if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
    # Fixed typo in the message: "IndexSlices" -> "IndexedSlices".
    raise ValueError(
        "`IndexedSlices` is not supported for Collective All-Reduce.")
  if context.executing_eagerly():
    raise ValueError(
        "Eager execution is not supported for Collective All-Reduce")
  if _all_devices_match(value_destination_pairs):
    return self._batch_all_reduce(aggregation,
                                  [v[0] for v in value_destination_pairs])
  # On this branch the devices necessarily differ, so the original nested
  # `if not all_devices_match` check was redundant and has been dropped.
  logging.log_first_n(
      logging.WARN, "Efficient batch_reduce is not supported if "
      "destinations are different.", 10)
  return [
      self._reduce(aggregation, t, destinations=v)
      for t, v in value_destination_pairs
  ]
def _reduce(self, aggregation, per_device_value, destinations):
  """Collective all-reduce of `per_device_value`, mirrored to destinations.

  Args:
    aggregation: how the per-device values are combined.
    per_device_value: a per-device value holding one tensor per device.
    destinations: the devices that should receive the reduced value.

  Returns:
    A `Mirrored` value with a copy of the reduction on each destination.

  Raises:
    ValueError: if the input contains `IndexedSlices`, or if eager
      execution is enabled — neither is supported by Collective
      All-Reduce.
  """
  if cross_tower_utils.contains_indexed_slices(per_device_value):
    # Fixed typo in the message: "IndexSlices" -> "IndexedSlices".
    raise ValueError(
        "`IndexedSlices` is not supported for Collective All-Reduce.")
  if context.executing_eagerly():
    raise ValueError(
        "Eager execution is not supported for Collective All-Reduce")
  all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
  if _devices_match(per_device_value, destinations):
    return all_reduced
  index = {}
  for d in get_devices_from(destinations):
    # pylint: disable=protected-access
    if d in all_reduced._index:
      index[d] = all_reduced._index[d]
    else:
      # Device not covered by the all-reduce: materialize a copy there
      # once all reduced tensors are available.
      with ops.control_dependencies(list(
          all_reduced._index.values())), ops.device(d):
        index[d] = array_ops.identity(
            list(all_reduced._index.values())[0])
  return value_lib.Mirrored(index)
def testContainsIndexedSlices_Tuple(self):
  """A plain tuple of IndexedSlices is reported as containing them."""
  left = math_ops._as_indexed_slices(
      constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
  right = math_ops._as_indexed_slices(
      constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
  self.assertTrue(
      cross_tower_utils.contains_indexed_slices((left, right)))
def testIsIndexedSlices(self):
  """A bare IndexedSlices value is itself detected."""
  slices = math_ops._as_indexed_slices(
      constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
  self.assertTrue(cross_tower_utils.contains_indexed_slices(slices))
def testContainsIndexedSlices_Tuple(self):
  """Detection recurses into tuples of IndexedSlices."""
  pair = (
      math_ops._as_indexed_slices(
          constant_op.constant([[1., 2.], [0, 0], [3., 4.]])),
      math_ops._as_indexed_slices(
          constant_op.constant([[0., 0.], [5, 6], [7., 8.]])),
  )
  self.assertTrue(cross_tower_utils.contains_indexed_slices(pair))
def testIsIndexedSlices(self):
  """contains_indexed_slices is true for a single IndexedSlices input."""
  value = math_ops._as_indexed_slices(
      constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
  result = cross_tower_utils.contains_indexed_slices(value)
  self.assertTrue(result)