def testCaptureOrdering(self):
  """Lifting a traced function repeatedly must keep capture order stable."""
  a, b, c = [resource_variable_ops.ResourceVariable(value)
             for value in (1.0, 2.0, 3.0)]

  @def_function.function
  def fn():
    return a + b + c

  traced = fn.get_concrete_function()
  source_captures = traced.graph.captures
  source_outputs = traced.graph.outputs

  # Repeat the lift many times; presumably an ordering regression would only
  # show up intermittently.
  for _ in range(100):
    target = func_graph.FuncGraph('lifted')
    lift_to_graph.lift_to_graph(
        source_outputs, target, add_sources=True, handle_captures=True)
    target_captures = target.captures
    self.assertLen(target_captures, 3)
    for expected, actual in zip(source_captures.values(),
                                target_captures.values()):
      self.assertEqual(expected.name, actual.name)
def prune(self, feeds, fetches, name=None, input_signature=None):
  """Extracts the subgraph from `feeds` to `fetches` as a `WrappedFunction`.

  Args:
    feeds: Nested structure of tensors (or names resolvable through
      `self.graph.as_graph_element`) that become the pruned function's
      inputs. Feeds that are internal captures of this function are dropped,
      so `wrapped_func.inputs` can be passed even when the function uses
      variables.
    fetches: Nested structure of `Tensor`s (data outputs) and/or
      `Operation`s (control outputs), or names resolvable to them.
    name: Optional name of the new `FuncGraph`; defaults to "pruned".
    input_signature: Optional value stored as the new graph's
      `structured_input_signature`.

  Returns:
    A new `WrappedFunction` sharing this function's variable holder.

  Raises:
    ValueError: if a feed is not a `Tensor`, a fetch is neither a `Tensor`
      nor an `Operation`, or a feed/fetch comes from a different graph.
  """
  # TODO(b/129646028): Add support for CompositeTensors.
  name = name or "pruned"
  feeds = nest.map_structure(self.graph.as_graph_element, feeds)
  fetches = nest.map_structure(self.graph.as_graph_element, fetches)
  flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
  for f in flat_feeds:
    if not isinstance(f, ops.Tensor):
      raise ValueError("Feeds must be tensors.")

  # Ignoring all feeds that are captures allows prune to be called
  # using wrapped_func.inputs even when it uses variables.
  # Membership is tested by object identity: `f in some_list` falls back to
  # `Tensor.__eq__`, which with TF2 tensor equality produces an elementwise
  # comparison instead of a Python bool.
  internal_captures = self.graph.internal_captures
  internal_capture_ids = {id(t) for t in internal_captures}
  flat_feeds = [f for f in flat_feeds if id(f) not in internal_capture_ids]

  operation_fetches = []
  for f in flat_fetches:
    if isinstance(f, ops.Operation):
      operation_fetches.append(f)
    elif not isinstance(f, ops.Tensor):
      raise ValueError("Fetches must be tensors or operations.")
  for f in flat_feeds + flat_fetches:
    if f.graph is not self._func_graph:
      raise ValueError(
          "Can only prune function whose feeds and fetches "
          "are from this graph (%s). Tensor %s from graph %s" %
          (self._func_graph, f, f.graph))
  with self._func_graph.as_default():
    pruned_graph = func_graph.FuncGraph(name)
  lift_map = lift_to_graph.lift_to_graph(
      flat_fetches, pruned_graph, sources=flat_feeds + internal_captures)
  pruned_graph.outputs.extend(
      lift_map[x] for x in flat_fetches if isinstance(x, ops.Tensor))
  pruned_graph.control_outputs.extend(
      [lift_map[operation] for operation in operation_fetches])
  for external_capture, internal_capture in self.graph.captures.items():
    pruned_graph.captures[external_capture] = lift_map[internal_capture]
  # Inputs are the (non-capture) feeds followed by the captured values.
  pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
  pruned_graph.inputs.extend(pruned_graph.captures.values())
  pruned_graph.variables = self.graph.variables

  def _structured_output_mapping(fetched):
    # Operations have no value, so they appear as None in structured outputs.
    lifted = lift_map[fetched]
    if isinstance(lifted, ops.Operation):
      return None
    return lifted

  pruned_graph.structured_outputs = nest.map_structure(
      _structured_output_mapping, fetches)
  pruned_graph.structured_input_signature = input_signature
  pruned_fn = WrappedFunction(
      pruned_graph, variable_holder=self._variable_holder)
  pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
  # TODO(kathywu): Enable keyword arguments if an input signature is specified
  pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
  return pruned_fn
def prune(self, feeds, fetches):
  """Builds a `WrappedFunction` over the subgraph from `feeds` to `fetches`.

  Args:
    feeds: Nested structure of `Tensor`s from this function's graph; they
      become the positional inputs of the result.
    fetches: Nested structure of `Tensor`s from this function's graph; they
      become the outputs of the result.

  Returns:
    A `WrappedFunction` over a new `FuncGraph` named "pruned" that shares
    this function's variable holder.

  Raises:
    ValueError: if any feed or fetch is not a `Tensor` or belongs to another
      graph (checked element by element, type first).
  """
  flat_feeds = nest.flatten(feeds)
  flat_fetches = nest.flatten(fetches)
  source_graph = self._func_graph

  for candidate in flat_feeds + flat_fetches:
    if not isinstance(candidate, ops.Tensor):
      raise ValueError("Feeds and fetches must be tensors.")
    if candidate.graph is not source_graph:
      raise ValueError(
          "Can only prune function whose feeds and fetches "
          "are from this graph (%s). Tensor %s from graph %s" %
          (source_graph, candidate, candidate.graph))

  with source_graph.as_default():
    pruned_graph = func_graph.FuncGraph("pruned")
    # identity_n depends on every fetch, so lifting its first output copies
    # the whole fetch subgraph in one call.
    sink_tensor = array_ops.identity_n(flat_fetches)[0]
  lift_map = lift_to_graph.lift_to_graph(
      sink_tensor, pruned_graph,
      sources=flat_feeds + self.graph.internal_captures)
  lookup = lift_map.__getitem__

  pruned_graph.outputs.extend([lookup(t) for t in flat_fetches])
  for outer_capture, inner_capture in self.graph.captures.items():
    pruned_graph.captures[outer_capture] = lookup(inner_capture)
  pruned_graph.inputs.extend([lookup(t) for t in flat_feeds])
  pruned_graph.inputs.extend(pruned_graph.captures.values())
  pruned_graph.structured_outputs = nest.map_structure(lookup, fetches)

  result = WrappedFunction(pruned_graph,
                           variable_holder=self._variable_holder)
  # pylint: disable=protected-access
  result._num_positional_args = len(flat_feeds)
  result._arg_keywords = []
  # pylint: enable=protected-access
  return result
def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  To build the wrapper, the entries of `concrete_function.graph._captures`
  that refer to cached reads are temporarily replaced with fresh reads
  created in the wrapper's outer graph. They are always put back before this
  function returns or raises.

  Args:
    concrete_function: A Concrete function that maybe captures cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object
    is returned.
  """
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(concrete_function.graph.name))
  captures = concrete_function.graph._captures  # pylint: disable=protected-access
  # Original `_captures` entries (keyed by id of the external capture) that
  # are overwritten below; they are restored in the `finally` clause.
  remapped_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      # Snapshot first: entries of `captures` are reassigned inside the loop.
      for capture, placeholder in list(concrete_function.graph.captures):
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # `_cached_variable` is called to obtain the variable (presumably a
        # weak reference -- confirm).
        cached_variable = cached_variable()
        new_cached_value = cached_variable.read_value()
        remapped_captures[id(capture)] = captures[id(capture)]
        captures[id(capture)] = (new_cached_value, placeholder)

    if not remapped_captures:
      return concrete_function

    # Built while the remapped captures are in place, so its captured inputs
    # are the fresh outer-graph reads.
    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)
    fn = defun.ConcreteFunction(
        outer_graph, function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
    fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
    fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access
    return fn
  finally:
    # Return the captures to their original values -- also when wrapping
    # failed -- so `concrete_function` never keeps pointing at tensors from
    # the discarded `outer_graph`.
    for key, capture in remapped_captures.items():
      captures[key] = capture
def test_to_placeholder(self, shape, batch_size, ragged_rank):
  """A ragged Keras input and its placeholder report the same rank/shape."""
  keras_input = layers.Input(shape=shape, batch_size=batch_size, ragged=True)
  expected_shape = [batch_size] + list(shape)

  def check_ragged(value):
    self.assertEqual(value.ragged_rank, ragged_rank)
    self.assertAllEqual(value.shape, expected_shape)

  check_ragged(keras_input)
  with func_graph.FuncGraph('test').as_default():
    check_ragged(keras_input._to_placeholder())  # pylint: disable=protected-access
def _try_handling_undefineds(
    body, get_state, set_state, init_vars, nulls, symbol_names):
  """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
      iteration)
   2. these types could be replaced by placeholders (e.g. zero values, for
      tensors).

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop
        var is None or undefined.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars). success is a boolean flag indicating
    whether types could be successfully inferred (step 1 above). new_init_vars
    contains the loop vars, with None or undefined values replaced by
    placeholders, where possible (step 2 above).
  """
  # Only restore the Python state if set_state actually ran.
  state_modified = False

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.
      set_state(init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()
  except (UnboundLocalError, TypeError, ValueError, KeyError):
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    first_iter_vars = None
  finally:
    # Undo whatever the staged iteration wrote back into the Python state.
    if state_modified:
      set_state(init_vars)

  if first_iter_vars is not None:
    # Note: the actual placeholder value doesn't matter, because as the staging
    # proved, it will be replaced by an actual value before being read.
    init_vars = tuple(
        (_placeholder_value(iv, v) if n else v)
        for v, n, iv in zip(init_vars, nulls, first_iter_vars))
    success = True
  else:
    success = False

  # This check runs regardless, in case we captured non-Tensor inputs.
  _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars)

  return success, init_vars
def prune(self, feeds, fetches):
  """Extracts the subgraph between `feeds` and `fetches` as a new function.

  Args:
    feeds: Nested structure of `Tensor`s from this function's graph that
      become the inputs of the pruned function.
    fetches: Nested structure of `Tensor`s (data outputs) and/or
      `Operation`s (control outputs) from this function's graph.

  Returns:
    A `WrappedFunction` over a new `FuncGraph` named "pruned" that shares
    this function's variable holder.

  Raises:
    ValueError: if a feed is not a `Tensor`, a fetch is neither a `Tensor`
      nor an `Operation`, or any feed/fetch belongs to another graph.
  """
  flat_feeds = nest.flatten(feeds)
  flat_fetches = nest.flatten(fetches)

  if not all(isinstance(feed, ops.Tensor) for feed in flat_feeds):
    raise ValueError("Feeds must be tensors.")

  tensor_fetches = []
  operation_fetches = []
  for fetch in flat_fetches:
    if isinstance(fetch, ops.Tensor):
      bucket = tensor_fetches
    elif isinstance(fetch, ops.Operation):
      bucket = operation_fetches
    else:
      raise ValueError("Fetches must be tensors or operations.")
    bucket.append(fetch)

  home_graph = self._func_graph
  for element in flat_feeds + flat_fetches:
    if element.graph is not home_graph:
      raise ValueError(
          "Can only prune function whose feeds and fetches "
          "are from this graph (%s). Tensor %s from graph %s" %
          (home_graph, element, element.graph))

  with home_graph.as_default():
    pruned_graph = func_graph.FuncGraph("pruned")
    # Funnel every fetch into one sink node. The control dependencies make the
    # operation fetches ancestors of that sink, so a single lift reaches them.
    with ops.control_dependencies(operation_fetches):
      if not tensor_fetches:
        identity_fetches = []
        sink_tensor = control_flow_ops.no_op()
      else:
        identity_fetches = array_ops.identity_n(tensor_fetches)
        sink_tensor = identity_fetches[0]

  lift_map = lift_to_graph.lift_to_graph(
      sink_tensor, pruned_graph,
      sources=flat_feeds + self.graph.internal_captures)
  # Each original tensor fetch is represented by its lifted identity copy.
  for original, identity in zip(tensor_fetches, identity_fetches):
    lift_map[original] = lift_map[identity]

  pruned_graph.outputs.extend(lift_map[t] for t in tensor_fetches)
  for outer_capture, inner_capture in self.graph.captures.items():
    pruned_graph.captures[outer_capture] = lift_map[inner_capture]
  pruned_graph.inputs.extend(lift_map[t] for t in flat_feeds)
  pruned_graph.inputs.extend(pruned_graph.captures.values())

  def _to_structured_output(fetched):
    # Operations have no value, so they map to None.
    lifted = lift_map[fetched]
    return None if isinstance(lifted, ops.Operation) else lifted

  pruned_graph.structured_outputs = nest.map_structure(
      _to_structured_output, fetches)

  wrapped = WrappedFunction(pruned_graph,
                            variable_holder=self._variable_holder)
  # pylint: disable=protected-access
  wrapped._num_positional_args = len(flat_feeds)
  wrapped._arg_keywords = []
  # pylint: enable=protected-access
  return wrapped
def _construct_concrete_function(input_func, graph_def, converted_handles):
  """Creates a ConcreteFunction from the input function and frozen graph.

  Args:
    input_func: ConcreteFunction.
    graph_def: TensorFlow GraphDef.
    converted_handles: a set of handles of the variables in input_func that
      were converted to constant in `graph_def`.

  Returns:
    ConcreteFunction containing the graph_def.

  Raises:
    ValueError: if a handle in `converted_handles` is not one of
      `input_func.captured_inputs`.
  """
  # The trailing `len(captured_inputs)` graph inputs receive the captured
  # values. Use an explicit (clamped) start index: with no captures,
  # `inputs[-0:]` would wrongly select *all* inputs.
  num_captured = len(input_func.captured_inputs)
  captured_inputs = input_func.inputs[
      max(0, len(input_func.inputs) - num_captured):]

  # Tensors are matched by object identity. TF2 tensor equality can make
  # `==` elementwise and tensors unhashable, so `list.index`/`set` on them
  # are unreliable.
  capture_positions = {}
  for position, handle in enumerate(input_func.captured_inputs):
    capture_positions.setdefault(id(handle), position)

  converted_input_ids = set()
  for handle in converted_handles:
    position = capture_positions.get(id(handle))
    if position is None:
      raise ValueError(
          "Converted handle %s is not captured by the function." % (handle,))
    converted_input_ids.add(id(captured_inputs[position]))

  # Preserve the original input order: the new graph's inputs must line up
  # positionally with the function arguments and the remaining captures.
  not_converted_inputs = [
      tensor for tensor in input_func.inputs
      if id(tensor) not in converted_input_ids
  ]

  output_graph = func_graph.FuncGraph(input_func.graph.name)
  with output_graph.as_default():
    importer.import_graph_def(graph_def, name="")
    output_graph.inputs = _get_tensors_from_graph(output_graph,
                                                  not_converted_inputs)
    output_graph.outputs = _get_tensors_from_graph(output_graph,
                                                   input_func.outputs)

  output_graph.structured_outputs = input_func.graph.structured_outputs
  output_graph.structured_input_signature = (
      input_func.graph.structured_input_signature)

  # pylint: disable=protected-access
  # Create the ConcreteFunction and add it to the global context.
  output_func = function.ConcreteFunction(
      output_graph, attrs=input_func._attrs, signature=input_func._signature)
  output_func.add_to_graph()

  # Inject the captured inputs into the ConcreteFunction.
  output_func._captured_inputs = [
      handle for handle in input_func.captured_inputs
      if handle not in converted_handles
  ]
  output_func.graph.variables = [
      var for var in input_func.graph.variables
      if var.handle not in converted_handles
  ]
  output_func._arg_keywords = input_func._arg_keywords
  output_func._num_positional_args = input_func._num_positional_args
  # pylint: enable=protected-access

  # Register the gradients in the current root context.
  with ops.init_scope():
    output_func._register_gradient()  # pylint: disable=protected-access
  return output_func
def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Guesses the `TensorSpec` that `arg` would have as a tensor.

  Args:
    arg: Value to trial-convert with `ops.convert_to_tensor`.
    dtype_hint: Forwarded to `ops.convert_to_tensor`.

  Returns:
    A `TensorSpec` with the converted tensor's shape and dtype, or `None` if
    the conversion raised `TypeError` or `ValueError`.
  """
  spec = None
  try:
    # A throwaway FuncGraph keeps the trial conversion out of the caller's
    # current context.
    scratch = func_graph_lib.FuncGraph(name="guess_conversion")
    with scratch.as_default():
      converted = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      spec = tensor_spec.TensorSpec(shape=converted.shape,
                                    dtype=converted.dtype)
  except (TypeError, ValueError):
    spec = None
  return spec
def add_op_to_graph(num_ops):
  """Stages an int32 resource variable plus `num_ops` reads of it.

  Everything is built inside a fresh `FuncGraph` named "resource": a variable
  handle, one assignment of the constant 1, then `num_ops` reads. Nothing is
  returned.
  """
  int32 = dtypes.int32
  with func_graph.FuncGraph("resource").as_default():
    var_handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
    one = constant_op.constant(1, dtype=int32)
    resource_variable_ops.assign_variable_op(var_handle, one)
    for _ in range(num_ops):
      gen_resource_variable_ops.read_variable_op(var_handle, dtype=int32)
def while_stmt(
    test,
    body,
    get_state,
    set_state,
    init_vars,
    basic_symbol_names=None,
    composite_symbol_names=None,
    opts=None,
):
  """Functional form of a while statement.

  The loop state is the set of symbols whose values change between
  iterations. Below it is passed around either as a tuple or as positional
  arguments of the matching types.

  The loop kind is chosen by evaluating `test` once on `init_vars` inside a
  scratch `FuncGraph`:
    * a dense-tensor result selects the TensorFlow loop;
    * anything else selects a plain Python loop.

  Args:
    test: Callable taking the state as arguments and returning a boolean; the
      loop condition.
    body: Callable taking the state as arguments and returning the new state;
      the loop body.
    get_state: Additional callable that can capture extra state (such as the
      values of composite symbols). Only useful when staging the loop.
    set_state: Additional callable that writes values captured by `get_state`
      back into the Python environment. Only useful when staging the loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  # Probe the condition once, isolated in a throwaway graph to minimize
  # unwanted side effects, to decide how to dispatch.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  with func_graph.FuncGraph('tmp').as_default():
    first_test = test(*init_vars)

  if not tensors.is_dense_tensor(first_test):
    # Plain Python: `test` has already been consumed once, so unroll one
    # iteration here to stay consistent before handing off to the Python loop.
    # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
    if not first_test:
      return init_vars
    unrolled_vars = body(*init_vars)
    return _py_while_stmt(test, body, get_state, set_state, unrolled_vars,
                          opts)

  # TensorFlow: re-evaluating `test` inside `_tf_while_stmt` is acceptable.
  return _tf_while_stmt(test, body, get_state, set_state, init_vars,
                        basic_symbol_names, composite_symbol_names, opts)
def _basic_cond(self, true_value, false_value):
  """Stages `if_stmt` on a constant-True condition inside a scratch graph."""

  def no_state():
    return ()

  def ignore_state(_):
    return None

  # Graph mode only: eager cond has different semantics, not tested here.
  with func_graph.FuncGraph('tmp').as_default():
    always_true = constant_op.constant(True)
    return control_flow.if_stmt(
        cond=always_true,
        body=true_value,
        orelse=false_value,
        get_state=no_state,
        set_state=ignore_state,
        basic_symbol_names=('s',),
        composite_symbol_names=())
def testVariableInFuncGraph(self, distribution):
  """Checks variables created in a FuncGraph under `distribution.scope()`."""

  def _identity(value):
    return value

  def model_fn():
    created = variable_scope.variable(2.0, name="bar")
    ds_context.get_replica_context().merge_call(_identity)
    return created

  with func_graph.FuncGraph("fg").as_default(), distribution.scope():
    outer_var = variable_scope.variable(1.0, name="foo")
    replica_var = distribution.extended.call_for_each_replica(model_fn)

  for variable, expected_name in ((outer_var, "foo:0"),
                                  (replica_var, "bar:0")):
    self._test_mv_properties(variable, expected_name, distribution)
def __init__(self, variable_holder=None, **kwargs):
  """Creates an empty wrapped-function graph.

  Args:
    variable_holder: Optional `VariableHolder`. When falsy, a new
      `VariableHolder(share_variables=True)` is created.
    **kwargs: Forwarded to `func_graph.FuncGraph`. `name` defaults to
      "wrapped_function_graph" and `collections` defaults to `{}`.
  """
  self._variable_holder = (
      variable_holder if variable_holder
      else VariableHolder(share_variables=True))

  graph_name = kwargs.pop("name", "wrapped_function_graph")
  # Start from empty collections unless told otherwise; passing
  # `collections=None` explicitly copies the outer graph's collections.
  initial_collections = kwargs.pop("collections", {})
  self.graph = func_graph.FuncGraph(
      graph_name, collections=initial_collections, **kwargs)

  self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
  self._functions = {}
def _convert_saved_model_v2(self):
  """Convert the input SavedModel in 2.0 format.

  Loads the SavedModel and freezes the selected signature. It then exports
  the frozen graph for Grappler, runs the conversion, and rebuilds a
  ConcreteFunction around `self._converted_graph_def`, stored in
  `self._converted_func`.
  """
  self._saved_model = load.load(self._input_saved_model_dir,
                                self._input_saved_model_tags)
  signature_key = self._input_saved_model_signature_key
  func = self._saved_model.signatures[signature_key]
  frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
  frozen_graph = frozen_func.graph
  self._grappler_meta_graph_def = saver.export_meta_graph(
      graph_def=frozen_graph.as_graph_def(), graph=frozen_graph)

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  fetch_collection.node_list.value.extend(
      [endpoint.name for endpoint in func.inputs + func.outputs])
  self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
      fetch_collection)

  # Run TRT optimizer in Grappler to convert the graph.
  self._run_conversion()

  def _rebind_by_name(graph, originals):
    # Looks up each original tensor by name in `graph` and copies its static
    # shape onto the match.
    rebound = []
    for original in originals:
      replacement = graph.get_tensor_by_name(original.name)
      replacement.set_shape(original.shape)
      rebound.append(replacement)
    return rebound

  # TODO(laigd): do we need to use different name e.g. "trt_func_graph"?
  source_graph = func.graph
  converted_graph = func_graph.FuncGraph(source_graph.name)
  with converted_graph.as_default():
    importer.import_graph_def(self._converted_graph_def, name="")

  converted_graph.inputs = _rebind_by_name(converted_graph,
                                           source_graph.inputs)
  converted_graph.outputs = _rebind_by_name(converted_graph,
                                            source_graph.outputs)
  converted_graph.structured_outputs = source_graph.structured_outputs
  converted_graph.structured_input_signature = (
      source_graph.structured_input_signature)

  # pylint: disable=protected-access
  # TODO(laigd): should we set up the signature as well?
  converted_func = function.ConcreteFunction(
      converted_graph, attrs=None, signature=None)
  self._converted_func = converted_func
  converted_func.add_to_graph()
  converted_func._arg_keywords = func._arg_keywords
  converted_func._num_positional_args = func._num_positional_args
  converted_func._captured_inputs = func._captured_inputs
  converted_func.graph.variables = func.graph.variables
  # pylint: enable=protected-access
def test_unique_graph_func_graph(self):
  """Test for get_unique_graph with FuncGraph."""
  outer_graph = ops_lib.Graph()
  with outer_graph.as_default():
    outer_const = constant_op.constant(1)
    inner_graph = func_graph.FuncGraph("inner")
    # Share the outer graph's key; presumably this is what lets
    # get_unique_graph accept tensors from both graphs.
    inner_graph._graph_key = outer_graph._graph_key  # pylint: disable=protected-access
    with inner_graph.as_default():
      inner_const = constant_op.constant(2)

  found = op_selector.get_unique_graph([outer_const, inner_const])
  self.assertEqual(found._graph_key, inner_graph._graph_key)  # pylint: disable=protected-access
def _initialize(self, args, kwds, add_initializers_to=None):
  """Initializes, on the first call.

  Creates two `Function`s, one that will allow creation of variables
  and one that won't.

  Additionally runs a trace for the `Function` that allows creation
  of variables.

  Args:
    args: Arguments to the underlying python callable.
    kwds: Keyword arguments to the python callable.
    add_initializers_to: Where to collect variable initializers, if not None.
  """

  # State shared with every UnliftedInitializerVariable created during the
  # trace below. `lifted_all_initializers` is a one-element list so code
  # holding a reference to it can update it in place; its value is read back
  # after tracing.
  created_variables = []
  lifted_initializer_graph = func_graph_module.FuncGraph("initializer")
  lifted_all_initializers = [True]
  lifted_placeholders = []

  # NOTE(review): `kwds` here shadows the outer `kwds` argument; presumably it
  # holds the variable-creation keyword arguments -- confirm.
  def variable_capturing_scope(unused_next_creator, **kwds):
    """Creates UnliftedInitializerVariables and saves references to them."""
    v = UnliftedInitializerVariable(
        add_initializers_to=add_initializers_to,
        lifted_initializer_graph=lifted_initializer_graph,
        lifted_all_initializers=lifted_all_initializers,
        lifted_placeholders=lifted_placeholders,
        **kwds)
    # Only a weak reference is recorded, so this list does not keep the
    # variable alive.
    created_variables.append(weakref.ref(v))
    return v

  self._created_variables = created_variables
  self._stateful_fn = self._defun_with_scope(variable_capturing_scope)
  self._stateful_fn._name = self._name  # pylint: disable=protected-access
  # Force the definition of the function for these arguments
  self._lifted_initializer_graph = lifted_initializer_graph
  self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
  self._lifted_placeholders = lifted_placeholders
  self._concrete_stateful_fn = (
      self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
          *args, **kwds))
  # Read only after the trace above, which may have updated the flag.
  self._lifted_all_initializers = lifted_all_initializers[0]

  def invalid_creator_scope(*unused_args, **unused_kwds):
    """Disables variable creation."""
    raise ValueError(
        "tf.function-decorated function tried to create "
        "variables on non-first call.")

  # Second Function: same callable, but any variable creation raises.
  self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
  self._stateless_fn._name = self._name  # pylint: disable=protected-access
def testExecutingEagerlyOutsideFunction(self, distribution):
  """Verify we preserve the value of executing_eagerly_outside_functions()."""
  def model_fn():
    return ops.executing_eagerly_outside_functions()

  def check_preserved():
    in_scope = ops.executing_eagerly_outside_functions()
    per_replica = distribution.extended.call_for_each_replica(model_fn)
    local_values = distribution.experimental_local_results(per_replica)
    self.assertEqual(in_scope, local_values[0])
    self.assertEqual(in_scope, originally)

  originally = ops.executing_eagerly_outside_functions()
  with distribution.scope():
    check_preserved()

  # Verify this all again, but this time in a FuncGraph.
  with func_graph.FuncGraph("fg").as_default(), distribution.scope():
    check_preserved()
def _construct_concrete_function(input_func, graph_def):
  """Rebuilds `input_func` as a ConcreteFunction around a (frozen) GraphDef.

  Args:
    input_func: ConcreteFunction whose endpoints, structure, captures and
      call metadata are copied onto the result.
    graph_def: TensorFlow GraphDef imported as the new function's body.

  Returns:
    A ConcreteFunction over `graph_def`, added to the current context, with
    its gradient registered under `ops.init_scope()`.
  """
  source_graph = input_func.graph
  output_graph = func_graph.FuncGraph(source_graph.name)
  with output_graph.as_default():
    importer.import_graph_def(graph_def, name="")
    # Resolve the original inputs/outputs inside the imported graph.
    output_graph.inputs = _get_tensors_from_graph(output_graph,
                                                  input_func.inputs)
    output_graph.outputs = _get_tensors_from_graph(output_graph,
                                                   input_func.outputs)

  output_graph.structured_outputs = source_graph.structured_outputs
  output_graph.structured_input_signature = (
      source_graph.structured_input_signature)

  # pylint: disable=protected-access
  output_func = function.ConcreteFunction(
      output_graph, attrs=input_func._attrs, signature=input_func._signature)
  # Make the new function available in the global context.
  output_func.add_to_graph()

  # Carry the captured inputs and call-convention metadata across.
  output_func._captured_inputs = input_func.captured_inputs
  output_func.graph.variables = source_graph.variables
  for attr in ("_arg_keywords", "_num_positional_args"):
    setattr(output_func, attr, getattr(input_func, attr))
  # pylint: enable=protected-access

  # Register the gradients in the current root context.
  with ops.init_scope():
    output_func._register_gradient()  # pylint: disable=protected-access
  return output_func
def testClassAttrsRemoved(self):
  """Tests that _class attrs (from colocate_with()) are removed."""

  @def_function.function
  def fn():
    two = constant_op.constant(2.0, name='two')
    ten = constant_op.constant(10.0, name='ten')
    twenty = math_ops.multiply(two, ten, name='twenty')
    three = constant_op.constant(3.0, name='three')
    with framework_ops.colocate_with(twenty):
      thirty = math_ops.multiply(three, ten, name='thirty')
    return ten, twenty, thirty

  concrete_fn = fn.get_concrete_function()
  source_graph = concrete_fn.graph
  # Before lifting, 'thirty' is colocated with 'twenty'.
  self.assertItemsEqual(
      source_graph.get_operation_by_name('thirty').colocation_groups(),
      [compat.as_bytes('loc:@twenty')])
  thirty_out = source_graph.outputs[2]

  lifted = func_graph.FuncGraph('lifted')
  lift_to_graph.lift_to_graph([thirty_out], lifted)

  # After lifting, colocation attrs are gone.
  lifted_ops = lifted.get_operations()
  expected_names = [
      'three', 'ten', 'thirty',  # Lifted from `fn` body.
      thirty_out.op.name,  # Wrapper for output.
  ]
  self.assertItemsEqual([op.name for op in lifted_ops], expected_names)
  for op in lifted_ops:
    with self.assertRaises(ValueError):
      class_attr = op.get_attr('_class')  # Expected not to exist.
      print('Unexpected class_attr', class_attr, 'on', op.name)
    # Each op should only carry the default self-referencing group.
    self.assertItemsEqual(op.colocation_groups(),
                          [compat.as_bytes('loc:@%s' % op.name)])
def _basic_cond(self, body_fn, else_fn):
  """Stages `if_stmt(True, body_fn, else_fn)` and returns the branch result.

  The single cond variable `x` starts at 0 and is held in a one-element list
  so the branch and state callbacks can update it.
  """
  box = [0]

  def run_then():
    box[0] = body_fn()

  def run_else():
    box[0] = else_fn()

  def read_state():
    return (box[0],)

  def write_state(cond_vars):
    box[0], = cond_vars

  # Eager cond had different semantics, we don't test those here.
  with func_graph.FuncGraph('tmp').as_default():
    control_flow.if_stmt(
        cond=constant_op.constant(True),
        body=run_then,
        orelse=run_else,
        get_state=read_state,
        set_state=write_state,
        symbol_names=('x',),
        nouts=1)
  return box[0]
def prune(self, feeds, fetches, name=None, input_signature=None):
  """Extract a subgraph of this function's underlying graph.

  Wraps the subgraph in a new `WrappedFunction` object.

  Args:
    feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
    fetches: Possibly-nested Python data structure containing information
      about outputs of the target subgraph. Each entry can either be a
      `Tensor` object (for data outputs), an `Operation` object (for control
      outputs), or a `TensorInfo` proto. Any additional shape/dtype
      information provided in a `TensorInfo` and not present in the original
      graph will be added to the returned subgraph.
    name: (optional) Name to give to the underlying `FuncGraph` of the
      returned object. If no name is provided, the graph's name will be
      `"pruned"`.
    input_signature: (optional) possibly-nested Python data structure
      containing `TensorSpec` objects, with which to populate the returned
      function's `FuncGraph`'s `structured_input_signature` field.

  Returns:
    A new `WrappedFunction` object containing a copy of the portion of this
      object's graph that goes from `feeds` to `fetches`.

  Raises:
    ValueError: if a feed is not a tensor, or a feed/fetch is not from this
      function's graph.
  """
  # TODO(b/129646028): Add support for CompositeTensors.
  name = name or "pruned"
  flat_feeds = nest.flatten(feeds, expand_composites=True)
  flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
  for f in flat_feeds:
    if not isinstance(f, ops.Tensor):
      raise ValueError("Feeds must be tensors.")

  # Ignoring all feeds that are captures allows prune to be called
  # using wrapped_func.inputs even when it uses variables
  # (ObjectIdentitySet: membership is by object identity, not tensor `==`.)
  internal_captures = object_identity.ObjectIdentitySet(
      self.graph.internal_captures)
  flat_feeds = [f for f in flat_feeds if f not in internal_captures]

  # Filled in by `_fetch_preprocesing_callback` as `fetches` is traversed.
  operation_fetches = []
  tensor_fetches = []
  tensor_infos = []

  def _fetch_preprocesing_callback(fetch):
    """Extract out lists of ops, tensors, and tensor type info.

    Turns TensorInfos into Tensors in the original `fetches` structure.
    Also extracts ops from `fetches`.

    Args:
      fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
        string identifying a Tensor or Operation.

    Returns:
      `fetch` converted to a Tensor (or left as an Operation).
    """
    if isinstance(fetch, ops.Operation):
      operation_fetches.append(fetch)
      return fetch
    elif isinstance(fetch, meta_graph_pb2.TensorInfo):
      tensor_infos.append(fetch)
      decoded = _get_element_from_tensor_info(fetch, self._func_graph)
      if (tensor_util.is_tensor(decoded) or
          isinstance(decoded, composite_tensor.CompositeTensor)):
        tensor_fetches.append(decoded)
      else:
        operation_fetches.append(decoded)
      return decoded
    elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
      tensor_fetches.append(fetch)
      return fetch
    else:
      # Anything else (e.g. a name) is resolved in this graph and then
      # classified by recursing once.
      graph_element = self.graph.as_graph_element(fetch)
      return _fetch_preprocesing_callback(graph_element)

  fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)

  # Expand composite tensors into their component dense Tensors.
  tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)

  for f in (flat_feeds + tensor_fetches + operation_fetches):
    if f.graph is not self._func_graph:
      raise ValueError("Can only prune function whose feeds and fetches "
                       "are from this graph (%s). Input %s is from graph %s" %
                       (self._func_graph, f, f.graph))
  with self._func_graph.as_default():
    pruned_graph = func_graph.FuncGraph(name)
  # Copy everything the fetches depend on, stopping at feeds and captures.
  lift_map = lift_to_graph.lift_to_graph(
      operation_fetches + tensor_fetches,
      pruned_graph,
      sources=flat_feeds + self.graph.internal_captures)

  # Note that we add the component tensors of any composite tensors to the
  # returned function's outputs list; the list must contain these component
  # tensors, or the function's sparse outputs won't work properly.
  pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)

  pruned_graph.control_outputs.extend(
      [lift_map[operation] for operation in operation_fetches])
  pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
  for external_capture, internal_capture in self.graph.captures:
    pruned_graph.add_capture(external_capture, lift_map[internal_capture])
  # Apply the static shapes carried by name-encoded TensorInfo fetches.
  for ti in tensor_infos:
    if ti.WhichOneof("encoding") == "name":  # Dense tensors only
      t = pruned_graph.as_graph_element(ti.name)
      if tensor_util.is_tensor(t):
        t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
  # Copy this graph's function definitions into the pruned graph.
  # pylint: disable=protected-access
  for f in self.graph._functions.values():
    pruned_graph._add_function(f)
  # pylint: enable=protected-access

  pruned_graph.variables = self.graph.variables

  def _structured_output_mapping(fetched):
    """callback for `nest.map_structure()`"""
    # Operations have no value, so they appear as None.
    lifted = lift_map[fetched]
    if isinstance(lifted, ops.Operation):
      return None
    return lifted

  # expand_composites=True here causes composite tensors to be expanded
  # into their component dense Tensors, mapped to the new graph, and then
  # reconstituted into their original composite form.
  pruned_graph.structured_outputs = nest.map_structure(
      _structured_output_mapping, fetches, expand_composites=True)
  pruned_graph.structured_input_signature = input_signature
  pruned_fn = WrappedFunction(
      pruned_graph, variable_holder=self._variable_holder)
  pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
  # TODO(kathywu): Enable keyword arguments if an input signature is specified
  pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
  return pruned_fn
def prune(self, feeds, fetches, name=None):
  """Returns a WrappedFunction that computes only `fetches` from `feeds`.

  Args:
    feeds: a nested structure of `ops.Tensor`s from this function's graph
      that become the pruned function's positional inputs. Feeds that are
      internal captures of this graph are dropped, so `wrapped_func.inputs`
      may be passed directly even when the function uses variables.
    fetches: a nested structure of `ops.Tensor`s and/or `ops.Operation`s from
      this function's graph. Operations become control outputs of the pruned
      graph and map to `None` in its structured outputs.
    name: optional name for the pruned graph; defaults to "pruned".

  Returns:
    A `WrappedFunction` over the pruned graph that shares this function's
    variable holder and variables.

  Raises:
    ValueError: if a feed is not a Tensor, a fetch is neither a Tensor nor an
      Operation, or any feed/fetch is not from this function's graph.
  """
  name = name or "pruned"
  flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
  for f in flat_feeds:
    if not isinstance(f, ops.Tensor):
      raise ValueError("Feeds must be tensors.")

  # Ignoring all feeds that are captures allows prune to be called
  # using wrapped_func.inputs even when it uses variables.
  internal_captures = self.graph.internal_captures
  # Membership is tested by object identity: a list `in` check falls back to
  # `==`, and `Tensor.__eq__` may be elementwise (not identity) when tensor
  # equality is enabled. A set of ids is also O(1) per lookup.
  internal_capture_ids = {id(t) for t in internal_captures}
  flat_feeds = [f for f in flat_feeds if id(f) not in internal_capture_ids]

  tensor_fetches = []
  operation_fetches = []
  for f in flat_fetches:
    if isinstance(f, ops.Tensor):
      tensor_fetches.append(f)
    elif isinstance(f, ops.Operation):
      operation_fetches.append(f)
    else:
      raise ValueError("Fetches must be tensors or operations.")
  for f in flat_feeds + flat_fetches:
    if f.graph is not self._func_graph:
      raise ValueError(
          "Can only prune function whose feeds and fetches "
          "are from this graph (%s). Tensor %s from graph %s" %
          (self._func_graph, f, f.graph))
  with self._func_graph.as_default():
    pruned_graph = func_graph.FuncGraph(name)
    # Funnel all fetches into a single sink tensor (control-dependent on the
    # fetched operations) so that one lift_to_graph call on the sink is
    # relied upon to pull every fetch into the pruned graph.
    with ops.control_dependencies(operation_fetches):
      if tensor_fetches:
        identity_fetches = array_ops.identity_n(tensor_fetches)
        sink_tensor = identity_fetches[0]
      else:
        # No tensor fetches: a dummy scalar carries the control dependencies.
        identity_fetches = []
        sink_tensor = array_ops.zeros([])
  lift_map = lift_to_graph.lift_to_graph(
      [sink_tensor], pruned_graph, sources=flat_feeds + internal_captures)
  # Each original tensor fetch maps to the lifted copy of its identity.
  for original_fetch, identity_fetch in zip(tensor_fetches, identity_fetches):
    lift_map[original_fetch] = lift_map[identity_fetch]
  pruned_graph.outputs.extend(
      lift_map[x] for x in flat_fetches if isinstance(x, ops.Tensor))
  pruned_graph.control_outputs.extend(
      [lift_map[operation] for operation in operation_fetches])
  if not tensor_fetches:
    pruned_graph.outputs.append(lift_map[sink_tensor])
  for external_capture, internal_capture in self.graph.captures.items():
    pruned_graph.captures[external_capture] = lift_map[internal_capture]
  pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
  pruned_graph.inputs.extend(pruned_graph.captures.values())

  pruned_graph.variables = self.graph.variables

  def _structured_output_mapping(fetched):
    """Maps a fetch to its lifted tensor, or None for operations."""
    lifted = lift_map[fetched]
    if isinstance(lifted, ops.Operation):
      return None
    return lifted

  pruned_graph.structured_outputs = nest.map_structure(
      _structured_output_mapping, fetches)
  pruned_fn = WrappedFunction(
      pruned_graph, variable_holder=self._variable_holder)
  pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
  # No keyword names are recorded for the feeds.
  pruned_fn._arg_keywords = []  # pylint: disable=protected-access
  return pruned_fn
def add_op_to_graph(num_ops):
  """Stages `num_ops` float32 `add` ops inside a throwaway FuncGraph.

  Every add uses the same two float32 placeholders as operands. The graph is
  not returned; nothing is returned.

  Args:
    num_ops: number of add ops to create.
  """
  scratch_graph = func_graph.FuncGraph("add")
  with scratch_graph.as_default():
    # Created in order: first the left operand, then the right operand.
    lhs, rhs = [gen_array_ops.placeholder(dtypes.float32) for _ in range(2)]
    for _ in range(num_ops):
      gen_math_ops.add(lhs, rhs)
def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
                             shape_invariants, symbol_names):
  """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
      iteration)
   2. these types could be replaced by placeholders (e.g. zero values, for
      tensors).

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop
      var is None or undefined.
    shape_invariants: user-specified shape invariant for each loop variable.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars, extra_shape_invariants):
     * success is a boolean flag indicating whether types could be
       successfully inferred (step 1 above)
     * new_init_vars contains the loop vars, with None or undefined values
       replaced by default values, where possible (step 2 above); when
       staging failed, these are the original init_vars
     * extra_shape_invariants contains shape invariants that would be needed
       by while_stmt, for instance if the placeholder values had a shape
       different from the corresponding loop outputs; None when staging
       failed
  """
  state_modified = False
  first_iter_vars = None
  # Only computed when staging succeeds; bound up front so the return below
  # cannot raise UnboundLocalError on the failure path.
  extra_shape_invariants = None
  failure_message = None

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # Another complication is that non_tensor values will be autocast to
      # Tensor by while_loop, and their static value lost. So we need to
      # account for that here.
      def autocast_to_tensor(v):
        if isinstance(
            v, (int, float, bool, str, list, tuple, np.ndarray, np.generic)):
          init_val = ops.convert_to_tensor_v2(v)
          return array_ops.placeholder(init_val.dtype, init_val.shape)
        return v

      autocast_init_vars = nest.map_structure(autocast_to_tensor, init_vars)

      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.
      set_state(autocast_init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()
  except (UnboundLocalError, TypeError, ValueError, KeyError):
    ag_logging.log(1, 'Caught error while staging loop body', exc_info=True)
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    failure_message = ('Note: AutoGraph tried to determine initial values, but '
                       'ran into an error and gave up:\n\t' +
                       '\t'.join(traceback.format_exception(*sys.exc_info())))
    first_iter_vars = None
  finally:
    # Always restore the caller's loop state after the trial staging.
    if state_modified:
      set_state(init_vars)

  if first_iter_vars is not None:
    # Note: the actual placeholder value doesn't matter, because as the staging
    # proved, it will be replaced by an actual value before being read.
    inits_and_invariants = tuple(
        (_placeholder_value(iv, i, v) if n else (v, None))
        for v, n, iv, i in zip(init_vars, nulls, first_iter_vars,
                               shape_invariants))
    init_vars, extra_shape_invariants = zip(*inits_and_invariants)
    success = True
  else:
    success = False

  # This check runs regardless, in case we captured non-Tensor inputs.
  _verify_loop_init_vars(
      init_vars, symbol_names, first_iter_vars, extra_message=failure_message)

  return success, init_vars, extra_shape_invariants