def build_collective_reduce(input_tensors, devices, group_size, collective_keys, reduction_op='Add', unary_op='Id', communication_hint='AUTO', control_inputs=None, executors=None): """Build a subgraph that does one full all-reduce, using the collective Op. If called in eager mode, it's required to supply a list of async executors for each input Tensor. Args: input_tensors: tensors within a single worker graph that are to be reduced together; must be one per device. devices: a list of device strings to run the collective on. group_size: total number of devices globally that will be doing this same reduction. The reduction will actually include the corresponding tensors at all these workers. collective_keys: a CollectiveKeys object. reduction_op: string naming the reduction op. unary_op: string naming the unary final op. communication_hint: string providing hint to runtime for choosing collective implementation. control_inputs: if not None, add control edges between control_inputs and (index-wise) corresponding collective_reduce tensors executors: a list of async executor. Required for eager execution. Returns: An array of final tensors, one per device, computed by the full reduction. Raises: ValueError: There must be at least two tensors over all the workers. """ if context.executing_eagerly(): if (not executors or len(executors) != len(input_tensors) or not all(e.is_async() for e in executors)): raise ValueError( 'collectives requires async executors for each device in eager mode') if len(input_tensors) != len(devices): raise ValueError('collective requires one input tensor for each device, ' 'len(input_tensors) = %d, len(devices) = %d' % (len(input_tensors), len(devices))) if group_size < 2: return input_tensors group_key = collective_keys.get_group_key(devices) instance_key = collective_keys.get_op_instance_key() subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec out_tensors = [] for idx, input_tensor in enumerate(input_tensors): if context.executing_eagerly(): executor_scope = context.executor_scope(executors[idx]) else: executor_scope = ops.NullContextmanager() with executor_scope, \ ops.device(devices[idx]), \ ops.control_dependencies( _control_input(devices, control_inputs, idx)): out_tensor = collective_ops.all_reduce(input_tensor, group_size, group_key, instance_key, reduction_op, unary_op, subdiv_offsets, communication_hint) out_tensors.append(out_tensor) return out_tensors
def _control_input(self, control_input): if control_input is not None: return ops.control_dependencies([control_input]) return ops.NullContextmanager()
def _control_input(self, control_input): if control_input is not None and not self._use_ordering_token(): return ops.control_dependencies([control_input]) return ops.NullContextmanager()
def func_graph_from_py_func(name, python_func, args, kwargs, signature=None, func_graph=None, autograph=False, autograph_options=None, add_control_dependencies=True, arg_names=None, op_return_value=None, collections=None, capture_by_value=None, override_flat_arg_shapes=None): """Returns a `FuncGraph` generated from `python_func`. Args: name: an identifier for the function. python_func: the Python function to trace. args: the positional args with which the Python function should be called; ignored if a signature is provided. kwargs: the keyword args with which the Python function should be called; ignored if a signature is provided. signature: a possibly nested sequence of `TensorSpecs` specifying the shapes and dtypes of the arguments. When a signature is provided, `args` and `kwargs` are ignored, and `python_func` is traced with Tensors conforming to `signature`. If `None`, the shapes and dtypes are inferred from the inputs. func_graph: Optional. An instance of FuncGraph. If provided, we will use this graph else a new one is built and returned. autograph: whether to use autograph to compile `python_func`. See https://www.tensorflow.org/guide/autograph for more information. autograph_options: additional knobs to control when `autograph=True`. See https://www.tensorflow.org/guide/autograph for more information. add_control_dependencies: If True, automatically adds control dependencies to ensure program order matches execution order and stateful ops always execute. arg_names: Optional list of argument names, used to give input placeholders recognizable names. op_return_value: Optional. A Tensor. If set and `python_func` returns Operations, those return values will be replaced with this value. If not set, returning an Operation triggers an error. collections: a dictionary of collections this FuncGraph should start with. If not specified (None), the FuncGraph will read (but not write to) the outer graph's collections that are not whitelisted, and both read and write to the outer graph's collections that are whitelisted. The current whitelisted collections are the global variables, the local variables, and the trainable variables. Defaults to None. capture_by_value: An optional boolean. If True, the func graph will capture Variables by value instead of reference. By default inherit from outer graphs, and failing that will default to False. override_flat_arg_shapes: An optional list of instances that are either `None` or `TensorShape`. The length must match that of `nest.flatten((args, kwargs))`. The entries containing value `None` must match entries in flattened arguments containing non-tensors, while entries containing a `TensorShape` must match entries in the flattened arguments containing tensors. Returns: A FuncGraph. Raises: TypeError: If any of `python_func`'s return values is neither `None` nor a `Tensor`. ValueError: If both `signature` and `override_flat_arg_shapes` are passed in. """ if op_return_value is not None: assert isinstance(op_return_value, ops.Tensor), op_return_value if func_graph is None: func_graph = FuncGraph(name, collections=collections, capture_by_value=capture_by_value) assert isinstance(func_graph, FuncGraph) if add_control_dependencies: control_manager = AutomaticControlDependencies() else: control_manager = ops.NullContextmanager() with func_graph.as_default(), control_manager as a: current_scope = variable_scope.get_variable_scope() default_use_recource = current_scope.use_resource current_scope.set_use_resource(True) if signature is not None and override_flat_arg_shapes is not None: raise ValueError( "Passed both signature and override_flat_arg_shapes: %s and %s." % (signature, override_flat_arg_shapes)) if signature is not None: args = signature kwargs = {} # Creates and names placeholders for all arguments. if override_flat_arg_shapes is not None: flat_args = nest.flatten(args) arg_shapes = override_flat_arg_shapes[:len(flat_args)] kwarg_shapes = override_flat_arg_shapes[len(flat_args):] else: arg_shapes = None kwarg_shapes = None func_args = _get_defun_inputs_from_args( args, arg_names, flat_shapes=arg_shapes) func_kwargs = _get_defun_inputs_from_kwargs( kwargs, flat_shapes=kwarg_shapes) # Convert all Tensors into TensorSpecs before saving the structured inputs. # If storing pure concrete functions that are not called through polymorphic # functions, we don't have access to FunctionSpec, so we need to call the # TensorSpecs by their `arg_names` for later binding. func_graph.structured_input_signature = ( convert_structure_to_signature(func_args, arg_names), convert_structure_to_signature(func_kwargs)) flat_func_args = nest.flatten(func_args) flat_func_kwargs = nest.flatten(func_kwargs) # Temporarily set inputs to allow graph building code to inspect # them. Reassigned below. func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs if isinstance(arg, ops.Tensor)] # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. # Variables to help check whether mutation happens in calling the function # Copy the recursive list, tuple and map structure, but not base objects func_args_before = nest.pack_sequence_as(func_args, flat_func_args) func_kwargs_before = nest.pack_sequence_as( func_kwargs, flat_func_kwargs) def convert(x): """Converts a function output to a Tensor.""" if x is None: return None if op_return_value is not None and isinstance(x, ops.Operation): # TODO(b/79881896): we currently can't capture external control deps, so # this won't work if x needs to be captured (i.e. if python_func returns # captured Operations). with ops.control_dependencies([x]): x = array_ops.identity(op_return_value) elif not isinstance(x, tensor_array_ops.TensorArray): try: x = ops.convert_to_tensor_or_composite(x) except (ValueError, TypeError): raise TypeError( "To be compatible with tf.contrib.eager.defun, Python functions " "must return zero or more Tensors; in compilation of %s, found " "return value of type %s, which is not a Tensor." % (str(python_func), type(x))) if add_control_dependencies: x = a.mark_as_return(x) return x this_tape = tape.push_new_tape() try: if autograph: from tensorflow.python import autograph # pylint: disable=g-import-not-at-top _, original_func = tf_decorator.unwrap(python_func) def wrapper(*args, **kwargs): # Note: functions annotated with @tf.function should always be # converted even though they would meet autograph's whitelisting # criteria. # If this assumption is ever broken, converted_call will need to # handle the possibility of original_func still being a shim, e.g. # bound to WeakrefSelf. return autograph.converted_call( original_func, None, autograph.ConversionOptions( recursive=True, optional_features=autograph_options, force_conversion=True, ), args, kwargs) # Wrapping around a decorator allows checks like tf_inspect.getargspec # to be accurate. converted_func = tf_decorator.make_decorator(original_func, wrapper) python_func = tf_decorator.rewrap(python_func, original_func, converted_func) func_outputs = python_func(*func_args, **func_kwargs) # invariant: `func_outputs` contains only Tensors, IndexedSlices, # SparseTensors, TensorArrays and `None`s. func_outputs = nest.map_structure(convert, func_outputs) check_mutation(func_args_before, func_args) check_mutation(func_kwargs_before, func_kwargs) finally: tape.pop_tape(this_tape) current_scope.set_use_resource(default_use_recource) # Variables in `func_args`, `func_kwargs` should be explicit inputs # to the function, not captured inputs. tape_variables = this_tape.watched_variables() arg_variables = set() inputs = [] for arg in nest.flatten(func_args) + nest.flatten(func_kwargs): if isinstance(arg, resource_variable_ops.ResourceVariable): # Even if an argument variable was not used in the function, we've # already manually captured the resource Tensor when creating argument # placeholders. resource_placeholder = func_graph.captures.pop(arg.handle, None) if resource_placeholder is None: continue arg_variables.add(arg) inputs.append(resource_placeholder) elif isinstance(arg, ops.Tensor): inputs.append(arg) variables = [v for v in tape_variables if v not in arg_variables] func_graph.inputs = inputs + list(func_graph.captures.values()) func_graph.structured_outputs = func_outputs # Returning a closed-over tensor does not trigger convert_to_tensor. func_graph.outputs.extend( func_graph.capture(x) for x in flatten(func_graph.structured_outputs) if x is not None) func_graph.variables = variables if add_control_dependencies: func_graph.control_outputs.extend(control_manager.ops_which_must_run) # Register any other functions defined in the graph. with ops.init_scope(): if context.executing_eagerly(): for f in func_graph._functions.values(): # pylint: disable=protected-access # TODO(ashankar): What about the gradient registry? context.add_function(f._c_func.func) # pylint: disable=protected-access return func_graph
def scope(self): return ops.NullContextmanager()