def simple_broadcast(value, destinations, always_mirrored=False):
    """Place a copy of `value` on each device described by `destinations`.

    Args:
      value: the object to copy. It is handed to
        `cross_device_utils.copy_tensor_or_indexed_slices_to_device`, which
        (per the tests in this file) accepts dense tensors and IndexedSlices.
      destinations: resolved to a sequence of devices via `get_devices_from`.
      always_mirrored: when truthy, the result is regrouped with
        `value_lib.Mirrored` even if only one device was resolved.

    Returns:
      If exactly one device was resolved and `always_mirrored` is falsy, the
      single copy on that device. Otherwise, the per-device copies passed
      through `distribute_utils.regroup(..., wrap_class=value_lib.Mirrored)`.
    """
    target_devices = get_devices_from(destinations)
    # Only skip the Mirrored wrapper in the single-device case when the
    # caller has not asked for a mirrored result.
    if len(target_devices) != 1 or always_mirrored:
        per_device_copies = [
            cross_device_utils.copy_tensor_or_indexed_slices_to_device(
                value, device)
            for device in target_devices
        ]
        return distribute_utils.regroup(
            per_device_copies, wrap_class=value_lib.Mirrored)
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, target_devices[0])
def simple_broadcast(value, destinations, always_mirrored=False):
    """Place a copy of `value` on each actual device behind `destinations`.

    `destinations` is resolved with `get_device_map_from` into a device map
    and a logical device; the device map then yields the concrete devices.

    Args:
      value: the object handed to
        `cross_device_utils.copy_tensor_or_indexed_slices_to_device`.
      destinations: anything `get_device_map_from` accepts.
      always_mirrored: when truthy, a `value_lib.Mirrored` is returned even if
        only one actual device was resolved.

    Returns:
      If exactly one actual device was resolved and `always_mirrored` is falsy,
      the single copy on that device. Otherwise a
      `value_lib.Mirrored(device_map, <copies>, logical_device)` whose copies
      follow the order returned by `logical_to_actual_devices`.
    """
    device_map, logical_device = get_device_map_from(destinations)
    target_devices = device_map.logical_to_actual_devices(logical_device)
    if len(target_devices) != 1 or always_mirrored:
        per_device_copies = [
            cross_device_utils.copy_tensor_or_indexed_slices_to_device(
                value, device)
            for device in target_devices
        ]
        return value_lib.Mirrored(
            device_map, per_device_copies, logical_device)
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, target_devices[0])
def _simple_broadcast(value, destinations):
    """Copy `value` onto every device from `destinations`, keyed by device.

    Args:
      value: the object handed to
        `cross_device_utils.copy_tensor_or_indexed_slices_to_device`.
      destinations: resolved to devices via `get_devices_from`.

    Returns:
      A `value_lib.Mirrored` built from a dict that maps each resolved device
      to the copy of `value` placed on it. If a device appears more than once,
      the last copy made for it is the one kept.
    """
    copies_by_device = {
        device: cross_device_utils.copy_tensor_or_indexed_slices_to_device(
            value, device)
        for device in get_devices_from(destinations)
    }
    return value_lib.Mirrored(copies_by_device)
def testCopyTensor(self):
    """A CPU-built dense tensor copied to "/gpu:0" keeps its values and device.

    Checks that the copy compares equal to the source (via
    `_assert_values_equal`) and that its device resolves to the same device
    string as the requested destination.
    """
    target_device = "/gpu:0"
    with ops.device("/cpu:0"):
        source = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    copied = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        source, target_device)

    self._assert_values_equal(source, copied)
    expected_device = device_util.resolve(target_device)
    actual_device = device_util.resolve(copied.device)
    self.assertEqual(expected_device, actual_device)
def testCopyIndexedSlicesNoDenseShape(self):
    """Copying an IndexedSlices built without a dense shape preserves it.

    The source slices are created on "/cpu:0" and copied to "/gpu:0". The test
    checks that the copy is still an `IndexedSlices`, that its indices and
    values equal the source's, and that its device resolves to the destination.
    """
    target_device = "/gpu:0"
    with ops.device("/cpu:0"):
        # Build the components inside the CPU scope so the source lives there.
        slice_indices = array_ops.identity([0])
        slice_values = array_ops.identity([1.])
        source = indexed_slices.IndexedSlices(
            indices=slice_indices, values=slice_values)
    copied = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        source, target_device)

    self.assertIsInstance(copied, indexed_slices.IndexedSlices)
    for field_name in ("indices", "values"):
        self.assertAllEqual(
            getattr(source, field_name), getattr(copied, field_name))
    expected_device = device_util.resolve(target_device)
    actual_device = device_util.resolve(copied.device)
    self.assertEqual(expected_device, actual_device)