def testPyFunctionAsync(self):

  def simple_fn(v):
    one = constant_op.constant(1.)
    return v + one

  @def_function.function
  def test_fn(v):
    return script_ops.eager_py_func(simple_fn, [v], dtypes.float32)

  async_executor = executor.new_executor(enable_async=True)
  with context.executor_scope(async_executor):
    test_var = variables.Variable(2.)
    self.assertAllEqual(test_fn(test_var), 3.0)

  async_executor.wait()
def testCancelGetNextWithDevice(self, cls):
  ping = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)
  pong = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)

  @def_function.function
  def map_fn(v):
    ball = ping.dequeue()
    with ops.control_dependencies([pong.enqueue(ball)]):
      return v + ping.dequeue()

  dataset = dataset_ops.Dataset.range(10)
  dataset = dataset.map(map_fn)

  # We need to set prefetch_buffer_size=0 so that we can cancel the
  # MultiDeviceIteratorGetNextFromShardOp from eager. If
  # prefetch_buffer_size>0, that op runs in the background threads of the
  # prefetch and can only be cancelled by deleting the iterator.
  multi_device_iterator = cls(
      dataset, [self._devices[1], self._devices[2]], prefetch_buffer_size=0)

  @def_function.function
  def get_next_device1():
    return multi_device_iterator.get_next(self._devices[1])

  async_executor = executor.new_executor(enable_async=True)
  with context.executor_scope(async_executor):
    cancel_mgr = cancellation.CancellationManager()
    cancel_mgr.get_cancelable_function(
        get_next_device1.get_concrete_function())()
  # Make sure we cancel in the middle of get_next.
  ping.enqueue(0)
  pong.dequeue()
  cancel_mgr.start_cancel()
  with self.assertRaises(errors.CancelledError):
    async_executor.wait()

  # Note that fetching from upstream iterator is not cancelled with the
  # cancellation of get_next.
  ping.enqueue(0)

  # Cancelling a get_next on one device shouldn't cancel the
  # multi_device_iterator and iterators on other devices.
  ping.enqueue(0)
  ping.enqueue(0)
  self.assertEqual(
      1, multi_device_iterator.get_next(self._devices[2]).numpy())
def testRemoteFunctionCancellation(self):
  context._reset_context()
  logical_devices = []
  logical_devices.append(context.LogicalDeviceConfiguration())
  logical_devices.append(context.LogicalDeviceConfiguration())
  framework_config.set_logical_device_configuration(
      framework_config.list_physical_devices("CPU")[0], logical_devices)

  @function.Defun(dtypes.float32)
  def _remote_fn(v):
    # We run two collectives here to make sure we cancel in the middle of the
    # RemoteCall. The second one should never finish.
    anchor = collective_ops.all_reduce_v2(
        v, group_size=2, group_key=1, instance_key=1)
    with ops.control_dependencies([anchor]):
      return collective_ops.all_reduce_v2(
          v, group_size=2, group_key=1, instance_key=2)

  @eager_def_function.function
  def run():
    with ops.device("/cpu:0"):
      return functional_ops.remote_call(
          args=[constant_op.constant([1.])],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/cpu:1")[0]

  async_executor = executor.new_executor(enable_async=True)
  cancel_mgr = cancellation.CancellationManager()
  with context.executor_scope(async_executor):
    # This should never finish.
    cancel_mgr.get_cancelable_function(run.get_concrete_function())()
  with ops.device("/cpu:0"):
    collective_ops.all_reduce_v2(
        [1.], group_size=2, group_key=1, instance_key=1)
  cancel_mgr.start_cancel()
  with self.assertRaises(errors.CancelledError):
    async_executor.wait()
def thread_fn(executor_obj, device, results):
  with self._coord.stop_on_exception():
    for i in range(num_calls):
      with context.executor_scope(executor_obj):
        with ops.device(device):
          results[i] = worker_fn()
def build_collective_reduce(input_tensors,
                            devices,
                            group_size,
                            collective_keys,
                            reduction_op='Add',
                            unary_op='Id',
                            communication_hint='AUTO',
                            control_inputs=None,
                            executors=None):
  """Build a subgraph that does one full all-reduce, using the collective Op.

  If called in eager mode, a list of async executors must be supplied, one per
  input tensor.

  Args:
    input_tensors: tensors within a single worker graph that are to be reduced
      together; must be one per device.
    devices: a list of device strings to run the collective on.
    group_size: total number of devices globally that will be doing this same
      reduction. The reduction will actually include the corresponding tensors
      at all these workers.
    collective_keys: a CollectiveKeys object.
    reduction_op: string naming the reduction op.
    unary_op: string naming the unary final op.
    communication_hint: string providing hint to runtime for choosing
      collective implementation.
    control_inputs: if not None, add control edges between control_inputs and
      (index-wise) corresponding collective_reduce tensors.
    executors: a list of async executors. Required for eager execution.

  Returns:
    An array of final tensors, one per device, computed by the full reduction.

  Raises:
    ValueError: There must be at least two tensors over all the workers.
  """
  if context.executing_eagerly():
    if (not executors or len(executors) != len(input_tensors) or
        not all(e.is_async() for e in executors)):
      raise ValueError(
          'collectives requires async executors for each device in eager mode')
  if len(input_tensors) != len(devices):
    raise ValueError(
        'collective requires one input tensor for each device, '
        'len(input_tensors) = %d, len(devices) = %d' %
        (len(input_tensors), len(devices)))

  if group_size < 2:
    return input_tensors
  group_key = collective_keys.get_group_key(devices)
  instance_key = collective_keys.get_op_instance_key()
  subdiv_offsets = [0]  # TODO(tucker): maybe support non-default subdiv spec

  out_tensors = []
  for idx, input_tensor in enumerate(input_tensors):
    if context.executing_eagerly():
      executor_scope = context.executor_scope(executors[idx])
    else:
      executor_scope = ops.NullContextmanager()
    with executor_scope, \
         ops.device(devices[idx]), \
         ops.control_dependencies(
             _control_input(devices, control_inputs, idx)):
      out_tensor = collective_ops.all_reduce(input_tensor, group_size,
                                             group_key, instance_key,
                                             reduction_op, unary_op,
                                             subdiv_offsets,
                                             communication_hint)
    out_tensors.append(out_tensor)
  return out_tensors
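# A minimal usage sketch (not part of the original module) illustrating the
# eager-mode contract documented above: one dedicated async executor per input
# tensor. `make_collective_keys` is a hypothetical factory standing in for
# however the caller obtains a CollectiveKeys object, and the device names are
# placeholders.
def _example_eager_all_reduce(make_collective_keys):
  devices = ['/cpu:0', '/cpu:1']
  input_tensors = [constant_op.constant([1.]), constant_op.constant([2.])]
  # One async executor per device, as build_collective_reduce requires in
  # eager mode.
  executors = [executor.new_executor(enable_async=True) for _ in devices]
  reduced = build_collective_reduce(
      input_tensors,
      devices,
      group_size=len(devices),
      collective_keys=make_collective_keys(),
      executors=executors)
  # The collectives are queued asynchronously; block until they all finish.
  for e in executors:
    e.wait()
  return reduced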
def _executor_scope(self):
  if context.executing_eagerly() and not self._executor:
    raise ValueError('collectives requires an async executor in eager mode')
  if context.executing_eagerly():
    return context.executor_scope(self._executor)
  return ops.NullContextmanager()
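# A hedged illustration (not part of the original class) of how _executor_scope
# might be consumed by a sibling method: the collective launch happens inside
# the returned context manager, so in eager mode the op is queued on the
# dedicated async executor, while graph mode falls through to the no-op
# NullContextmanager. The reduction parameters below are placeholders.
def _example_launch(self, input_tensor, group_size, group_key, instance_key):
  with self._executor_scope():
    return collective_ops.all_reduce(input_tensor, group_size, group_key,
                                     instance_key, 'Add', 'Id')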