def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  The capture table of `concrete_function.graph` is temporarily rewritten
  while the wrapper is traced. It is always restored before this function
  returns, including when tracing raises.

  Args:
    concrete_function: A Concrete function that maybe captures cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(concrete_function.graph.name))
  captures = concrete_function.graph._captures  # pylint: disable=protected-access
  # Maps a key of `captures` to the entry it held before being rewritten, so
  # that the original capture table can be put back afterwards.
  remapped_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      # Iterate over a snapshot because entries of `captures` are reassigned
      # inside the loop.
      for capture, placeholder in list(concrete_function.graph.captures):
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # `_cached_variable` is called to obtain the variable (presumably a
        # weak reference -- TODO confirm).
        cached_variable = cached_variable()
        new_cached_value = cached_variable.read_value()
        remapped_captures[id(capture)] = captures[id(capture)]
        captures[id(capture)] = (new_cached_value, placeholder)

    if not remapped_captures:
      # Nothing cached was captured: no wrapping needed.
      return concrete_function

    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)
    fn = defun.ConcreteFunction(
        outer_graph, function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
    fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
    fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access
    return fn
  finally:
    # Return the captures to their original values. Doing this in `finally`
    # keeps the caller's function intact even if reading a variable or
    # tracing the wrapper raised part-way through.
    for key, capture in remapped_captures.items():
      captures[key] = capture
def load_function_def_library(library):
  """Loads each FunctionDef of `library` as a capture-free ConcreteFunction.

  While loading, function names are rewritten so that they cannot clash with
  functions that were created before.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict from the original function name in the library to the
    `ConcreteFunction` (without captured inputs) built for it.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  loaded = {}

  # `_sort_function_defs` presumably yields callees before their callers;
  # the `loaded[callee_name]` lookup below relies on that ordering.
  for function_def in _sort_function_defs(library):
    renamed_def = _fix_fdef(function_def, loaded)
    graph_of_function = function_def_lib.function_def_to_graph(renamed_def)
    for callee_name in _list_function_deps(function_def):
      loaded[callee_name].add_to_graph(graph_of_function)

    concrete = function_lib.ConcreteFunction(graph_of_function)
    concrete.add_to_graph()
    loaded[function_def.signature.name] = concrete

    # The gradient is registered from the outermost (root) context as well.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access

  return loaded
def load_function_def_library(library):
  """Loads each FunctionDef of `library` as a capture-free ConcreteFunction.

  While loading, function names are rewritten so that they cannot clash with
  functions that were created before.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict from the original function name in the library to the
    `ConcreteFunction` (without captured inputs) built for it.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  # TODO(andresp): Look into restoring gradient function information.
  loaded = {}
  # Original FunctionDef name -> name of the ConcreteFunction that replaced it.
  original_to_new_name = {}

  # Every function is built inside a fresh, private graph. Registering the
  # functions with a graph lets function_def_to_graph validate that they load
  # correctly, which cannot be checked in pure eager mode (no Python API tells
  # whether a function is registered in eager). The resulting func_graphs are
  # still usable eagerly or from other graphs.
  import_graph = ops.Graph()
  with import_graph.as_default():
    for function_def in _sort_function_defs(library):
      renamed_def = _fix_fdef(function_def, original_to_new_name)
      graph_of_function = function_def_lib.function_def_to_graph(renamed_def)

      concrete = function_lib.ConcreteFunction(graph_of_function)
      concrete.add_to_graph(import_graph)

      original_name = function_def.signature.name
      original_to_new_name[original_name] = concrete.name
      loaded[original_name] = concrete

  return loaded
def _construct_concrete_function(input_func, graph_def, converted_handles):
  """Creates a ConcreteFunction from the input function and frozen graph.

  Inputs of `input_func` that correspond to converted variable handles are
  dropped. All other inputs keep their original relative order, which
  defines the calling convention of the new function.

  Args:
    input_func: ConcreteFunction.
    graph_def: TensorFlow GraphDef.
    converted_handles: a set of handles of the variables in input_func that
      were converted to constant in `graph_def`.

  Returns:
    ConcreteFunction containing the graph_def.
  """
  # The captured inputs are the trailing entries of `input_func.inputs`. Use
  # an explicit start index: `inputs[-0:]` would return *all* inputs when the
  # function captures nothing.
  num_captured = len(input_func.captured_inputs)
  captured_inputs = input_func.inputs[len(input_func.inputs) - num_captured:]
  captured_input_indices = [
      input_func.captured_inputs.index(handle) for handle in converted_handles
  ]
  converted_inputs = {captured_inputs[index] for index in captured_input_indices}
  # Keep a list, not a set: the order of the graph inputs is the order in
  # which the new function expects its arguments.
  not_converted_inputs = [
      tensor for tensor in input_func.inputs if tensor not in converted_inputs
  ]

  output_graph = func_graph.FuncGraph(input_func.graph.name)
  with output_graph.as_default():
    importer.import_graph_def(graph_def, name="")
    output_graph.inputs = _get_tensors_from_graph(output_graph,
                                                  not_converted_inputs)
    output_graph.outputs = _get_tensors_from_graph(output_graph,
                                                   input_func.outputs)

  output_graph.structured_outputs = input_func.graph.structured_outputs
  output_graph.structured_input_signature = (
      input_func.graph.structured_input_signature)

  # pylint: disable=protected-access
  # Create the ConcreteFunction and add it to the global context.
  output_func = function.ConcreteFunction(
      output_graph, attrs=input_func._attrs, signature=input_func._signature)
  output_func.add_to_graph()

  # Inject the captured inputs into the ConcreteFunction, leaving out the
  # handles whose variables were folded into constants.
  output_func._captured_inputs = [
      handle for handle in input_func.captured_inputs
      if handle not in converted_handles
  ]
  output_func.graph.variables = [
      var for var in input_func.graph.variables
      if var.handle not in converted_handles
  ]
  output_func._arg_keywords = input_func._arg_keywords
  output_func._num_positional_args = input_func._num_positional_args
  # pylint: enable=protected-access

  # Register the gradients in the current root context.
  with ops.init_scope():
    output_func._register_gradient()  # pylint: disable=protected-access
  return output_func
def _convert_saved_model_v2(self):
  """Converts the input SavedModel (2.0 format) with Grappler's TRT optimizer.

  Loads the SavedModel, freezes the selected signature into constants,
  exports it as a MetaGraphDef for Grappler and runs the conversion. The
  converted GraphDef (read from `self._converted_graph_def`, presumably set by
  `_run_conversion` -- TODO confirm) is then wrapped as a ConcreteFunction
  stored in `self._converted_func`.
  """
  self._saved_model = load.load(self._input_saved_model_dir,
                                self._input_saved_model_tags)
  signature_func = self._saved_model.signatures[
      self._input_saved_model_signature_key]
  frozen_func = convert_to_constants.convert_variables_to_constants_v2(
      signature_func)
  frozen_graph = frozen_func.graph
  self._grappler_meta_graph_def = saver.export_meta_graph(
      graph_def=frozen_graph.as_graph_def(), graph=frozen_graph)

  # Grappler learns which tensors to keep from the 'train_op' collection, so
  # record every input and output tensor name there.
  train_op_collection = meta_graph_pb2.CollectionDef()
  for endpoint in signature_func.inputs + signature_func.outputs:
    train_op_collection.node_list.value.append(endpoint.name)
  self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
      train_op_collection)

  # Run TRT optimizer in Grappler to convert the graph.
  self._run_conversion()

  def _tensors_named_like(graph, reference_tensors):
    """Looks up each reference tensor in `graph` by name and copies its shape."""
    matched = []
    for reference in reference_tensors:
      imported = graph.get_tensor_by_name(reference.name)
      imported.set_shape(reference.shape)
      matched.append(imported)
    return matched

  source_graph = signature_func.graph
  # TODO(laigd): do we need to use different name e.g. "trt_func_graph"?
  converted_graph = func_graph.FuncGraph(source_graph.name)
  with converted_graph.as_default():
    importer.import_graph_def(self._converted_graph_def, name="")

  converted_graph.inputs = _tensors_named_like(converted_graph,
                                               source_graph.inputs)
  converted_graph.outputs = _tensors_named_like(converted_graph,
                                                source_graph.outputs)
  converted_graph.structured_outputs = source_graph.structured_outputs
  converted_graph.structured_input_signature = (
      source_graph.structured_input_signature)

  # pylint: disable=protected-access
  # TODO(laigd): should we set up the signature as well?
  converted = function.ConcreteFunction(
      converted_graph, attrs=None, signature=None)
  self._converted_func = converted
  converted.add_to_graph()
  # Mirror the calling convention and captures of the original signature.
  converted._arg_keywords = signature_func._arg_keywords
  converted._num_positional_args = signature_func._num_positional_args
  converted._captured_inputs = signature_func._captured_inputs
  converted.graph.variables = source_graph.variables
  # pylint: enable=protected-access
def load_function_def_library(library, load_shared_name_suffix=None):
  """Loads each FunctionDef of `library` as a capture-free ConcreteFunction.

  While loading, function names are rewritten so that they cannot clash with
  functions that were created before.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise a unique name is generated.

  Returns:
    Dict from the original function name in the library to the
    `ConcreteFunction` (without captured inputs) built for it.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  library_function_names = {
      function_def.signature.name for function_def in library.function
  }
  loaded = {}
  suffix = load_shared_name_suffix
  if suffix is None:
    suffix = "_load_{}".format(ops.uid())

  for function_def in _sort_function_defs(library, library_function_names):
    renamed_def = _fix_fdef(function_def, loaded, suffix)
    # Functions are deliberately not copied into each function graph: doing
    # so costs O(n^2) memory, and is redundant because the topological sort
    # has already loaded every callee, which is attached explicitly below.
    graph_of_function = function_def_lib.function_def_to_graph(
        renamed_def, copy_functions=False)
    for callee_name in _list_function_deps(function_def,
                                           library_function_names):
      loaded[callee_name].add_to_graph(graph_of_function)

    concrete = function_lib.ConcreteFunction(graph_of_function)
    concrete.add_to_graph()
    if context.executing_eagerly():
      # In eager mode, also register the function with the default graph.
      concrete.add_to_graph(ops.get_default_graph())
    loaded[function_def.signature.name] = concrete

    # The gradient is registered from the outermost (root) context as well.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access

  return loaded
def load_function_def_library(library):
  """Loads each FunctionDef of `library` as a capture-free ConcreteFunction.

  While loading, function names are rewritten so that they cannot clash with
  functions that were created before; shared names get a per-call
  "_load_<uid>" suffix.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict from the original function name in the library to the
    `ConcreteFunction` (without captured inputs) built for it.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  loaded = {}
  shared_name_suffix = "_load_{}".format(ops.uid())

  def _build(function_def):
    """Creates the ConcreteFunction for one FunctionDef and attaches its deps."""
    renamed_def = _fix_fdef(function_def, loaded, shared_name_suffix)
    # Functions are deliberately not copied into the function graph: that
    # costs O(n^2) memory, and the topological sort has already loaded every
    # callee, which is attached explicitly below.
    graph_of_function = function_def_lib.function_def_to_graph(
        renamed_def, copy_functions=False)
    for callee_name in _list_function_deps(function_def):
      loaded[callee_name].add_to_graph(graph_of_function)
    concrete = function_lib.ConcreteFunction(graph_of_function)
    concrete.add_to_graph()
    return concrete

  for function_def in _sort_function_defs(library):
    concrete = _build(function_def)
    loaded[function_def.signature.name] = concrete
    # The gradient is registered from the outermost (root) context as well.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access

  return loaded
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
  """Traces the Python function and wraps the trace in a `ConcreteFunction`.

  Increments `self.tracing_count` on every call.

  Args:
    args: Positional arguments to trace with.
    kwargs: Keyword arguments to trace with.
    override_flat_arg_shapes: Forwarded unchanged to `func_graph_from_py_func`.

  Returns:
    A `ConcreteFunction` that owns (and cleans up) its traced FuncGraph.
  """
  self.tracing_count += 1

  spec = self._function_spec
  if self.input_signature is not None:
    num_args = len(self.input_signature)
  else:
    num_args = len(args)

  declared_names = spec.arg_names
  vararg = spec.vararg_name
  # Arguments beyond the declared names are named after the *args parameter:
  # "<vararg>_0", "<vararg>_1", ... (none when nothing is missing).
  num_missing = num_args - len(declared_names)
  arg_names = declared_names[:num_args] + [
      "%s_%d" % (vararg, index) for index in range(num_missing)
  ]

  traced_graph = func_graph_module.func_graph_from_py_func(
      self._name,
      self._python_function,
      args,
      kwargs,
      self.input_signature,
      autograph=self._autograph,
      autograph_options=self._autograph_options,
      arg_names=arg_names,
      override_flat_arg_shapes=override_flat_arg_shapes,
      capture_by_value=self._capture_by_value,
      add_control_dependencies=False,
  )
  return _function.ConcreteFunction(
      traced_graph,
      self._function_attributes,
      # The ConcreteFunction cleans up its graph when it goes out of scope.
      # That is not the default, because some callers (like Keras) keep the
      # FuncGraph alive longer than the ConcreteFunction.
      shared_func_graph=False,
  )
def load_function_def_library(library):
  """Loads each FunctionDef of `library` as a capture-free ConcreteFunction.

  While loading, function names are rewritten so that they cannot clash with
  functions that were created before.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Dict from the original function name in the library to the
    `ConcreteFunction` (without captured inputs) built for it.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  loaded = {}

  # FunctionDefs are converted inside a private graph so that
  # function_def_to_graph can validate that they load correctly. Pure eager
  # mode cannot do this (no Python API reports whether a function is
  # registered in eager). The func_graphs produced here remain usable in
  # eager mode or in other graphs.
  import_graph = ops.Graph()

  for function_def in _sort_function_defs(library):
    with import_graph.as_default():
      renamed_def = _fix_fdef(function_def, loaded)
      graph_of_function = function_def_lib.function_def_to_graph(renamed_def)

    concrete = function_lib.ConcreteFunction(graph_of_function)
    concrete.add_to_graph(import_graph)
    loaded[function_def.signature.name] = concrete

    # The gradient is registered from the outermost (root) context as well.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access

  return loaded
def _construct_concrete_function(input_func, graph_def):
  """Builds a ConcreteFunction that executes the frozen `graph_def`.

  The inputs and outputs of the imported graph are looked up (via
  `_get_tensors_from_graph`) from those of `input_func`. Its structured
  signature, attributes, captured inputs and variables are copied from
  `input_func`.

  Args:
    input_func: ConcreteFunction.
    graph_def: TensorFlow GraphDef.

  Returns:
    ConcreteFunction containing the graph_def.
  """
  source_graph = input_func.graph
  frozen_graph = func_graph.FuncGraph(source_graph.name)
  with frozen_graph.as_default():
    importer.import_graph_def(graph_def, name="")
    frozen_graph.inputs = _get_tensors_from_graph(frozen_graph,
                                                  input_func.inputs)
    frozen_graph.outputs = _get_tensors_from_graph(frozen_graph,
                                                   input_func.outputs)

  frozen_graph.structured_outputs = source_graph.structured_outputs
  frozen_graph.structured_input_signature = (
      source_graph.structured_input_signature)

  # pylint: disable=protected-access
  # The new function is added to the global context immediately.
  frozen_func = function.ConcreteFunction(
      frozen_graph, attrs=input_func._attrs, signature=input_func._signature)
  frozen_func.add_to_graph()

  # Reuse the original function's captured inputs, variables and calling
  # convention unchanged.
  frozen_func._captured_inputs = input_func.captured_inputs
  frozen_func.graph.variables = source_graph.variables
  frozen_func._arg_keywords = input_func._arg_keywords
  frozen_func._num_positional_args = input_func._num_positional_args
  # pylint: enable=protected-access

  # Register the gradients in the current root context.
  with ops.init_scope():
    frozen_func._register_gradient()  # pylint: disable=protected-access
  return frozen_func
def load_function_def_library(library, load_shared_name_suffix=None):
  """Loads each FunctionDef of `library` as a capture-free ConcreteFunction.

  While loading, function names are rewritten so that they cannot clash with
  functions that were created before.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise, a unique name is generated.

  Returns:
    Dict from the original function name in the library to the
    `ConcreteFunction` (without captured inputs) built for it.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  library_function_names = {
      function_def.signature.name for function_def in library.function
  }
  loaded = {}
  # New (renamed) function name -> ConcreteFunction.
  loaded_by_new_name = {}

  # Importing functions through the op-name-is-function-name calling
  # convention currently requires registering them with some tf.Graph. When
  # executing eagerly a throwaway Graph is used, so nothing leaks into the
  # global default graph.
  #
  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
  # fixing function_def_to_graph_def.
  graph = (ops.Graph() if ops.executing_eagerly_outside_functions()
           else ops.get_default_graph())

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  for function_def in _sort_function_defs(library, library_function_names):
    renamed_def = _fix_fdef(function_def, loaded, load_shared_name_suffix)

    with graph.as_default():
      graph_of_function = function_def_lib.function_def_to_graph(renamed_def)
    _restore_gradient_functions(graph_of_function, loaded_by_new_name)

    # Rather than copying every library function into each function graph
    # (O(n^2) memory), only the direct callees are attached here; the
    # topological sort guarantees they were loaded already.
    for callee_name in _list_function_deps(function_def,
                                           library_function_names):
      loaded[callee_name].add_to_graph(graph_of_function)

    # The new ConcreteFunction's function_spec and arg_keywords (used to parse
    # the structured and flat signatures, respectively) are intentionally left
    # unset here. recreate_function() fills them in for ConcreteFunctions that
    # belong to a saved function; setup_bare_concrete_function() does so for
    # bare ones.
    concrete = function_lib.ConcreteFunction(graph_of_function)
    concrete.add_to_graph(graph)

    loaded[function_def.signature.name] = concrete
    loaded_by_new_name[concrete.name] = concrete
    if any(op.type == "TRTEngineOp"
           for op in graph_of_function.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      concrete.add_to_graph(ops.get_default_graph())

  return loaded
def load_function_def_library(library,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
  """Loads each FunctionDef of `library` as a capture-free ConcreteFunction.

  While loading, function names are rewritten so that they cannot clash with
  functions that were created before. Custom gradients are re-registered
  under freshly generated op types. FunctionDefs that refer to those
  gradients are rewritten to use the new op types.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise, a unique name is generated.
    wrapper_function: Optional callable applied to every newly created
      ConcreteFunction. Its result replaces the function (it is what gets
      added to the graph and returned).

  Returns:
    Dict from the original function name in the library to the
    `ConcreteFunction` (without captured inputs) built for it.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  library_function_names = {
      function_def.signature.name for function_def in library.function
  }
  loaded = {}
  # New (renamed) function name -> ConcreteFunction.
  loaded_by_new_name = {}

  # Importing functions through the op-name-is-function-name calling
  # convention currently requires registering them with some tf.Graph. When
  # executing eagerly a throwaway Graph is used, so nothing leaks into the
  # global default graph.
  #
  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
  # fixing function_def_to_graph_def.
  graph = (ops.Graph() if ops.executing_eagerly_outside_functions()
           else ops.get_default_graph())

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  # Custom gradient functions must be re-registered under new UIDs.
  library_gradient_names = {}  # old op type (bytes) -> old gradient fn name
  new_gradient_op_types = {}  # old op type (bytes) -> new op type
  gradients_to_register = {}  # old gradient fn name -> new op type
  for gradient_def in library.registered_gradients:
    if not gradient_def.registered_op_type:
      continue
    new_op_type = custom_gradient.generate_name()
    old_op_type = compat.as_bytes(gradient_def.registered_op_type)
    library_gradient_names[old_op_type] = gradient_def.gradient_func
    new_gradient_op_types[old_op_type] = new_op_type
    gradients_to_register[gradient_def.gradient_func] = new_op_type

  function_deps = {
      function_def.signature.name: _list_function_deps(
          function_def, library_function_names, library_gradient_names)
      for function_def in library.function
  }

  # Encoded new gradient op type -> ConcreteFunction implementing it.
  loaded_gradients = {}
  for function_def in _sort_function_defs(library, function_deps):
    original_name = function_def.signature.name
    renamed_def = _fix_fdef(function_def, loaded, load_shared_name_suffix,
                            new_gradient_op_types)

    with graph.as_default():
      graph_of_function = function_def_lib.function_def_to_graph(renamed_def)
    # Restores gradients for function-call ops (not the same as ops that use
    # custom gradients).
    _restore_gradient_functions(graph_of_function, loaded_by_new_name,
                                loaded_gradients)

    # Rather than copying every library function into each function graph
    # (O(n^2) memory), only the direct dependencies are attached here; the
    # topological sort guarantees they were loaded already.
    for callee_name in function_deps[original_name]:
      loaded[callee_name].add_to_graph(graph_of_function)

    # function_spec / arg_keywords are intentionally not set here: they are
    # filled in later by recreate_function() (for functions that belong to a
    # saved function) or setup_bare_concrete_function() (for bare ones). The
    # FunctionDef attributes are copied over, except "_input_shapes", which
    # may cause an error during input shape initialization later on.
    if "_input_shapes" in renamed_def.attr:
      del renamed_def.attr["_input_shapes"]
    concrete = function_lib.ConcreteFunction(graph_of_function,
                                             attrs=renamed_def.attr)
    if wrapper_function:
      concrete = wrapper_function(concrete)
    concrete.add_to_graph(graph)

    loaded[original_name] = concrete
    loaded_by_new_name[concrete.name] = concrete
    if any(op.type == "TRTEngineOp"
           for op in graph_of_function.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      concrete.add_to_graph(ops.get_default_graph())

    # If this function is a custom gradient, register it under its new type.
    if original_name in gradients_to_register:
      gradient_op_type = gradients_to_register[original_name]
      loaded_gradients[compat.as_bytes(gradient_op_type)] = concrete
      ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(concrete))

  return loaded