Example 1
def wrap_cached_variables(concrete_function):
    """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  Args:
    concrete_function: A ConcreteFunction that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
    outer_graph = func_graph_module.FuncGraph("{}_no_cache".format(
        concrete_function.graph.name))
    captures = concrete_function.graph._captures  # pylint: disable=protected-access
    mapped_captures = False
    remapped_captures = {}

    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
        for capture, placeholder in concrete_function.graph.captures:
            cached_variable = getattr(capture, "_cached_variable", None)
            if cached_variable is None:
                continue
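            # `_cached_variable` appears to be stored as a weakref (hence the
            # call below that dereferences it back into a variable).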
            cached_variable = cached_variable()
            new_cached_value = cached_variable.read_value()
            remapped_captures[id(capture)] = captures[id(capture)]
            captures[id(capture)] = (new_cached_value, placeholder)
            mapped_captures = True

    if not mapped_captures:
        return concrete_function

    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
        return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(None,
                                              wrap_function,
                                              args=tuple(args),
                                              kwargs={},
                                              func_graph=outer_graph)
    fn = defun.ConcreteFunction(outer_graph,
                                function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
    fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
    fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access

    # Restore the captures to their original values.
    for key, capture in remapped_captures.items():
        captures[key] = capture
    return fn
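Example 1 rewrites a concrete function's (external capture, placeholder) pairs so that fresh variable reads replace stale cached read tensors. A minimal sketch of the capture structure it manipulates, using only the public tf.function API (the names below are illustrative):

import tensorflow as tf

v = tf.Variable(3.0)

@tf.function
def scale(x):
    return x * v  # `v` is captured implicitly rather than passed as an argument

cf = scale.get_concrete_function(tf.TensorSpec([], tf.float32))

# Each entry in `graph.captures` pairs an external tensor (e.g. a cached read
# of `v`) with the placeholder standing in for it inside the graph; Example 1
# rewrites exactly these pairs.
for external, placeholder in cf.graph.captures:
    print(external.dtype, "->", placeholder.name)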
Example 2
def load_function_def_library(library):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  functions = {}

  for fdef in _sort_function_defs(library):
    copy = _fix_fdef(fdef, functions)

    func_graph = function_def_lib.function_def_to_graph(copy)
    for dep in _list_function_deps(fdef):
      functions[dep].add_to_graph(func_graph)
    func = function_lib.ConcreteFunction(func_graph)
    func.add_to_graph()

    functions[fdef.signature.name] = func

    # Also register the gradients in the current root context.
    with ops.init_scope():
      func._register_gradient()  # pylint: disable=protected-access

  return functions
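A hedged usage sketch for this loader: trace nested tf.functions to obtain a FunctionDefLibrary proto, then feed it back in. The import path below is the internal TensorFlow module where this helper typically lives; it is not a supported public API and may differ across versions.

import tensorflow as tf
from tensorflow.python.saved_model import function_deserialization  # internal

@tf.function
def inner(x):
    return x * x

@tf.function
def outer(x):
    return inner(x) + 1.0

cf = outer.get_concrete_function(tf.TensorSpec([], tf.float32))
library = cf.graph.as_graph_def().library  # FunctionDefLibrary proto
loaded = function_deserialization.load_function_def_library(library)
print(list(loaded.keys()))  # original library names -> new ConcreteFunctions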
Example 3
def load_function_def_library(library):
    """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
    # TODO(andresp): Look into restoring gradient function information.
    functions = {}
    name_mapping = {}
    # Note: Use a new graph so that function_def_to_graph can help validate
    # that the functions load correctly. This is not possible in pure eager
    # mode, as there is no Python API to check whether a function has been
    # registered there. Note also that the created func_graphs can still be
    # used in eager mode or in other graphs.
    with ops.Graph().as_default() as import_graph:
        for fdef in _sort_function_defs(library):
            copy = _fix_fdef(fdef, name_mapping)

            func_graph = function_def_lib.function_def_to_graph(copy)
            func = function_lib.ConcreteFunction(func_graph)
            func.add_to_graph(import_graph)

            name_mapping[fdef.signature.name] = func.name
            functions[fdef.signature.name] = func
    return functions
Example 4
def _construct_concrete_function(input_func, graph_def, converted_handles):
    """Creates a ConcreteFunction from the input function and frozen graph.

  Args:
    input_func: ConcreteFunction.
    graph_def: TensorFlow GraphDef.
    converted_handles: a set of handles of the variables in input_func that
      were converted to constants in `graph_def`.

  Returns:
    ConcreteFunction containing the graph_def.
  """
    captured_inputs = input_func.inputs[-len(input_func.captured_inputs):]
    captured_input_indices = [
        input_func.captured_inputs.index(handle)
        for handle in converted_handles
    ]
    converted_inputs = set(
        [captured_inputs[index] for index in captured_input_indices])
    not_converted_inputs = set(input_func.inputs).difference(converted_inputs)

    output_graph = func_graph.FuncGraph(input_func.graph.name)
    with output_graph.as_default():
        importer.import_graph_def(graph_def, name="")
        output_graph.inputs = _get_tensors_from_graph(output_graph,
                                                      not_converted_inputs)
        output_graph.outputs = _get_tensors_from_graph(output_graph,
                                                       input_func.outputs)

    output_graph.structured_outputs = input_func.graph.structured_outputs
    output_graph.structured_input_signature = (
        input_func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # Create the ConcreteFunction and add it to the global context.
    output_func = function.ConcreteFunction(output_graph,
                                            attrs=input_func._attrs,
                                            signature=input_func._signature)
    output_func.add_to_graph()

    # Inject the captured inputs into the ConcreteFunction.
    output_func._captured_inputs = [
        handle for handle in input_func.captured_inputs
        if handle not in converted_handles
    ]
    output_func.graph.variables = [
        var for var in input_func.graph.variables
        if var.handle not in converted_handles
    ]
    output_func._arg_keywords = input_func._arg_keywords
    output_func._num_positional_args = input_func._num_positional_args
    # pylint: enable=protected-access

    # Register the gradients in the current root context.
    with ops.init_scope():
        output_func._register_gradient()  # pylint: disable=protected-access
    return output_func
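The frozen `graph_def` consumed by this helper is typically produced by folding a concrete function's captured variables into constants. A small sketch of that step using the `convert_variables_to_constants_v2` helper from the same codebase:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2)

v = tf.Variable([1.0, 2.0])

@tf.function
def add_bias(x):
    return x + v

cf = add_bias.get_concrete_function(tf.TensorSpec([2], tf.float32))
# The captured variable is folded into a constant; the result is a new
# ConcreteFunction with no variable captures.
frozen = convert_variables_to_constants_v2(cf)
print(frozen(tf.constant([0.0, 0.0])))  # [1.0, 2.0]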
Example 5
    def _convert_saved_model_v2(self):
        """Convert the input SavedModel in 2.0 format."""
        self._saved_model = load.load(self._input_saved_model_dir,
                                      self._input_saved_model_tags)
        func = self._saved_model.signatures[
            self._input_saved_model_signature_key]
        frozen_func = convert_to_constants.convert_variables_to_constants_v2(
            func)
        self._grappler_meta_graph_def = saver.export_meta_graph(
            graph_def=frozen_func.graph.as_graph_def(),
            graph=frozen_func.graph)

        # Add a collection 'train_op' so that Grappler knows the outputs.
        fetch_collection = meta_graph_pb2.CollectionDef()
        for array in func.inputs + func.outputs:
            fetch_collection.node_list.value.append(array.name)
        self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
            fetch_collection)

        # Run TRT optimizer in Grappler to convert the graph.
        self._run_conversion()

        def _get_tensor(graph, tensors):
            new_tensors = []
            for tensor in tensors:
                new_tensor = graph.get_tensor_by_name(tensor.name)
                new_tensor.set_shape(tensor.shape)
                new_tensors.append(new_tensor)
            return new_tensors

        # TODO(laigd): do we need to use different name e.g. "trt_func_graph"?
        converted_graph = func_graph.FuncGraph(func.graph.name)
        with converted_graph.as_default():
            importer.import_graph_def(self._converted_graph_def, name="")

        converted_graph.inputs = _get_tensor(converted_graph,
                                             func.graph.inputs)
        converted_graph.outputs = _get_tensor(converted_graph,
                                              func.graph.outputs)
        converted_graph.structured_outputs = func.graph.structured_outputs
        converted_graph.structured_input_signature = (
            func.graph.structured_input_signature)

        # pylint: disable=protected-access
        # TODO(laigd): should we set up the signature as well?
        self._converted_func = function.ConcreteFunction(converted_graph,
                                                         attrs=None,
                                                         signature=None)
        self._converted_func.add_to_graph()
        self._converted_func._arg_keywords = func._arg_keywords
        self._converted_func._num_positional_args = func._num_positional_args
        self._converted_func._captured_inputs = func._captured_inputs
        self._converted_func.graph.variables = func.graph.variables
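For comparison, the supported entry point that wraps this SavedModel-to-TensorRT flow in recent TensorFlow releases is `TrtGraphConverterV2`. A minimal sketch (requires a TensorRT-enabled build; the paths are placeholders):

from tensorflow.python.compiler.tensorrt import trt_convert

converter = trt_convert.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/saved_model")  # placeholder path
converter.convert()  # runs the TRT Grappler conversion shown above
converter.save("/tmp/saved_model_trt")  # placeholder path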
Example 6
def load_function_def_library(library, load_shared_name_suffix=None):
    """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise a unique name is generated.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
    library_function_names = set(fdef.signature.name
                                 for fdef in library.function)
    functions = {}

    if load_shared_name_suffix is None:
        load_shared_name_suffix = "_load_{}".format(ops.uid())
    for fdef in _sort_function_defs(library, library_function_names):
        copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

        # There is no need to copy all functions into the function def graph:
        # doing so leads to an O(n^2) increase in memory when importing
        # functions, and the extra function definitions are a no-op since they
        # were already imported as functions earlier and are passed in
        # explicitly (thanks to the topological-sort import order).
        func_graph = function_def_lib.function_def_to_graph(
            copy, copy_functions=False)

        for dep in _list_function_deps(fdef, library_function_names):
            functions[dep].add_to_graph(func_graph)
        func = function_lib.ConcreteFunction(func_graph)
        func.add_to_graph()
        if context.executing_eagerly():
            func.add_to_graph(ops.get_default_graph())

        functions[fdef.signature.name] = func

        # Also register the gradients in the current root context.
        with ops.init_scope():
            func._register_gradient()  # pylint: disable=protected-access

    return functions
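`load_shared_name_suffix` exists because stateful ops carry a `shared_name` attr that must not collide when the same library is loaded more than once. A plausible sketch of the uniquification that `_fix_fdef` performs; the helper below is illustrative, not the library's actual code:

from tensorflow.core.framework import function_pb2

def _uniquify_shared_names(fdef, suffix):
    """Append `suffix` to every shared_name attr in `fdef` (sketch)."""
    for node in fdef.node_def:
        if "shared_name" in node.attr:
            node.attr["shared_name"].s += suffix.encode("utf-8")

fdef = function_pb2.FunctionDef()
node = fdef.node_def.add(name="table", op="HashTableV2")
node.attr["shared_name"].s = b"table_1"
_uniquify_shared_names(fdef, "_load_42")
print(fdef.node_def[0].attr["shared_name"].s)  # b'table_1_load_42'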
Example 7
def load_function_def_library(library):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  functions = {}

  load_shared_name_suffix = "_load_{}".format(ops.uid())
  for fdef in _sort_function_defs(library):
    copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

    # There is no need to copy functions into the function def graph:
    # doing so leads to an O(n^2) increase in memory when importing
    # functions, and the extra definitions are a no-op since they were
    # already imported earlier (thanks to the topological-sort import).
    func_graph = function_def_lib.function_def_to_graph(
        copy, copy_functions=False)

    for dep in _list_function_deps(fdef):
      functions[dep].add_to_graph(func_graph)
    func = function_lib.ConcreteFunction(func_graph)
    func.add_to_graph()

    functions[fdef.signature.name] = func

    # Also register the gradients in the current root context.
    with ops.init_scope():
      func._register_gradient()  # pylint: disable=protected-access

  return functions
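`_sort_function_defs` must emit each FunctionDef only after every function it calls, i.e. a topological sort, and raise on cycles. A self-contained sketch of one plausible shape for it, using Kahn's algorithm (`deps_of` stands in for `_list_function_deps`; both names are assumptions, and deps are assumed to be unique, in-library names):

from collections import deque

def topological_sort(fdefs, deps_of):
    """Yield FunctionDefs so that dependencies precede their dependents."""
    by_name = {f.signature.name: f for f in fdefs}
    in_deg = {name: 0 for name in by_name}
    dependents = {name: [] for name in by_name}
    for name, fdef in by_name.items():
        for dep in deps_of(fdef):
            in_deg[name] += 1
            dependents[dep].append(name)
    ready = deque(name for name, deg in in_deg.items() if deg == 0)
    order = []
    while ready:
        name = ready.popleft()
        order.append(by_name[name])
        for nxt in dependents[name]:
            in_deg[nxt] -= 1
            if in_deg[nxt] == 0:
                ready.append(nxt)
    if len(order) != len(by_name):
        raise ValueError("Function dependencies contain a cycle.")
    return order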
Example 8
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
    """Create a `ConcreteFunction` from `args` and `kwargs`."""

    self.tracing_count += 1
    if self.input_signature is None:
        arglen = len(args)
    else:
        arglen = len(self.input_signature)
    base_arg_names = self._function_spec.arg_names[:arglen]
    num_missing_args = arglen - len(self._function_spec.arg_names)
    missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
    # Produce names for the missing args of the form ["args_0", "args_1", ...],
    # where the base name comes from self._function_spec.vararg_name.
    missing_arg_names = [
        "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
    ]
    arg_names = base_arg_names + missing_arg_names

    graph_function = _function.ConcreteFunction(
        func_graph_module.func_graph_from_py_func(
            self._name,
            self._python_function,
            args,
            kwargs,
            self.input_signature,
            autograph=self._autograph,
            autograph_options=self._autograph_options,
            arg_names=arg_names,
            override_flat_arg_shapes=override_flat_arg_shapes,
            capture_by_value=self._capture_by_value,
            add_control_dependencies=False,
        ),
        self._function_attributes,
        # Tell the ConcreteFunction to clean up its graph once it goes out of
        # scope. This is not the default behavior since it gets used in some
        # places (like Keras) where the FuncGraph lives longer than the
        # ConcreteFunction.
        shared_func_graph=False,
    )
    return graph_function
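The missing-argument naming above can be reproduced with plain Python. Assuming a function `def f(x, y, *args)` traced with four positional inputs, the extra placeholders come out as `args_0` and `args_1`:

arg_names = ["x", "y"]  # named positional parameters from the FunctionSpec
vararg_name = "args"    # name of the *args parameter
arglen = 4              # number of positional inputs at trace time

base_arg_names = arg_names[:arglen]                 # ["x", "y"]
num_missing_args = arglen - len(arg_names)          # 2
missing_arg_names = [vararg_name] * num_missing_args
missing_arg_names = [
    "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)]
print(base_arg_names + missing_arg_names)  # ['x', 'y', 'args_0', 'args_1']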
Example 9
def load_function_def_library(library):
    """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
    functions = {}

    # Note: Use a new graph so that function_def_to_graph can help validate
    # that the functions load correctly. This is not possible in pure eager
    # mode, as there is no Python API to check whether a function has been
    # registered there. Note also that the created func_graphs can still be
    # used in eager mode or in other graphs.
    import_graph = ops.Graph()

    for fdef in _sort_function_defs(library):
        with import_graph.as_default():
            copy = _fix_fdef(fdef, functions)

            func_graph = function_def_lib.function_def_to_graph(copy)
            func = function_lib.ConcreteFunction(func_graph)
            func.add_to_graph(import_graph)

            functions[fdef.signature.name] = func

        # Also register the gradients in the current root context.
        with ops.init_scope():
            func._register_gradient()  # pylint: disable=protected-access

    return functions
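The `ops.init_scope()` block used by these loaders lifts execution out of whatever graph is currently being built and into the outermost (usually eager) context, so the gradient registration outlives the import graph. A small demonstration with the public `tf.init_scope`:

import tensorflow as tf

@tf.function
def f():
    # Tracing builds a graph here...
    with tf.init_scope():
        # ...but inside init_scope we are back in the outermost context, so
        # anything created or registered here lives in the root/eager scope.
        print("eager inside init_scope:", tf.executing_eagerly())
    return tf.constant(1)

f()  # prints at trace time: eager inside init_scope: True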
Example 10
def _construct_concrete_function(input_func, graph_def):
    """Creates a ConcreteFunction from the input function and frozen graph.

  Args:
    input_func: ConcreteFunction.
    graph_def: TensorFlow GraphDef.

  Returns:
    ConcreteFunction containing the graph_def.
  """
    output_graph = func_graph.FuncGraph(input_func.graph.name)
    with output_graph.as_default():
        importer.import_graph_def(graph_def, name="")
        output_graph.inputs = _get_tensors_from_graph(output_graph,
                                                      input_func.inputs)
        output_graph.outputs = _get_tensors_from_graph(output_graph,
                                                       input_func.outputs)

    output_graph.structured_outputs = input_func.graph.structured_outputs
    output_graph.structured_input_signature = (
        input_func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # Create the ConcreteFunction and add it to the global context.
    output_func = function.ConcreteFunction(output_graph,
                                            attrs=input_func._attrs,
                                            signature=input_func._signature)
    output_func.add_to_graph()

    # Inject the captured inputs into the ConcreteFunction.
    output_func._captured_inputs = input_func.captured_inputs
    output_func.graph.variables = input_func.graph.variables
    output_func._arg_keywords = input_func._arg_keywords
    output_func._num_positional_args = input_func._num_positional_args
    # pylint: enable=protected-access

    # Register the gradients in the current root context.
    with ops.init_scope():
        output_func._register_gradient()  # pylint: disable=protected-access
    return output_func
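Both `_construct_concrete_function` variants rely on a `_get_tensors_from_graph` helper that is not shown in these examples. A plausible reconstruction, assuming it simply resolves each tensor by name in the freshly imported graph and skips anything that is not a plain tensor:

import tensorflow as tf

def _get_tensors_from_graph(graph, tensors):
    """Hypothetical sketch: map `tensors` to their namesakes in `graph`."""
    new_tensors = []
    for tensor in tensors:
        if isinstance(tensor, tf.Tensor):
            new_tensors.append(graph.get_tensor_by_name(tensor.name))
    return new_tensors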
Example 11
def load_function_def_library(library, load_shared_name_suffix=None):
    """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise, a unique name is generated.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
    library_function_names = set(fdef.signature.name
                                 for fdef in library.function)
    functions = {}
    renamed_functions = {}

    # Our graph building code currently requires functions to be registered with
    # some tf.Graph in order to import functions using the
    # op-name-is-function-name calling convention. To avoid leaking memory into
    # the global default graph when executing eagerly, we create a temporary
    # Graph.
    #
    # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
    # fixing function_def_to_graph_def.
    if ops.executing_eagerly_outside_functions():
        graph = ops.Graph()
    else:
        graph = ops.get_default_graph()

    if load_shared_name_suffix is None:
        load_shared_name_suffix = "_load_{}".format(ops.uid())
    for fdef in _sort_function_defs(library, library_function_names):
        copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

        # There is no need to copy all functions into the function def graph:
        # doing so leads to an O(n^2) increase in memory when importing
        # functions, and the extra function definitions are a no-op since they
        # were already imported as functions earlier and are passed in
        # explicitly (thanks to the topological-sort import order).
        with graph.as_default():
            func_graph = function_def_lib.function_def_to_graph(copy)
        _restore_gradient_functions(func_graph, renamed_functions)

        for dep in _list_function_deps(fdef, library_function_names):
            functions[dep].add_to_graph(func_graph)

        # We do not initialize the new ConcreteFunction's function_spec or
        # arg_keywords here (these are used to parse the structured and flat
        # signatures, respectively). ConcreteFunctions that are part of a saved
        # function are set up later by recreate_function(); bare
        # ConcreteFunctions are set up by setup_bare_concrete_function().
        func = function_lib.ConcreteFunction(func_graph)
        func.add_to_graph(graph)

        functions[fdef.signature.name] = func
        renamed_functions[func.name] = func
        if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
            # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
            # is fixed. Currently it's leaking memory to maintain bug compatibility
            # with previous behavior.
            func.add_to_graph(ops.get_default_graph())

    return functions
Example 12
def load_function_def_library(library,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
    """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Gradients are re-registered under new names. Ops that reference the gradients
  are updated to reflect the new registered names.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise, a unique name is generated.
    wrapper_function: A callable applied to each newly created function; its
      return value is used in place of the original ConcreteFunction.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
    library_function_names = set(fdef.signature.name
                                 for fdef in library.function)
    functions = {}
    renamed_functions = {}

    # Our graph building code currently requires functions to be registered with
    # some tf.Graph in order to import functions using the
    # op-name-is-function-name calling convention. To avoid leaking memory into
    # the global default graph when executing eagerly, we create a temporary
    # Graph.
    #
    # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
    # fixing function_def_to_graph_def.
    if ops.executing_eagerly_outside_functions():
        graph = ops.Graph()
    else:
        graph = ops.get_default_graph()

    if load_shared_name_suffix is None:
        load_shared_name_suffix = "_load_{}".format(ops.uid())

    # Custom gradient functions must be re-registered under new UIDs.
    library_gradient_names = {}  # Maps old op type to old function name
    new_gradient_op_types = {}  # Maps old gradient op type to new op type.
    gradients_to_register = {}  # Maps old function name to new op type
    for gdef in library.registered_gradients:
        if gdef.registered_op_type:
            new_op_type = custom_gradient.generate_name()
            old_op_type = compat.as_bytes(gdef.registered_op_type)

            library_gradient_names[old_op_type] = gdef.gradient_func
            new_gradient_op_types[old_op_type] = new_op_type
            gradients_to_register[gdef.gradient_func] = new_op_type

    function_deps = {}
    for fdef in library.function:
        function_deps[fdef.signature.name] = _list_function_deps(
            fdef, library_function_names, library_gradient_names)

    loaded_gradients = {}
    for fdef in _sort_function_defs(library, function_deps):
        copy = _fix_fdef(fdef, functions, load_shared_name_suffix,
                         new_gradient_op_types)

        # There is no need to copy all functions into the function def graph:
        # doing so leads to an O(n^2) increase in memory when importing
        # functions, and the extra function definitions are a no-op since they
        # were already imported as functions earlier and are passed in
        # explicitly (thanks to the topological-sort import order).
        with graph.as_default():
            func_graph = function_def_lib.function_def_to_graph(copy)
        # Restores gradients for function-call ops (not the same as ops that use
        # custom gradients)
        _restore_gradient_functions(func_graph, renamed_functions,
                                    loaded_gradients)

        for dep in function_deps[fdef.signature.name]:
            functions[dep].add_to_graph(func_graph)

        # We do not initialize the new ConcreteFunction's function_spec or
        # arg_keywords here (these are used to parse the structured and flat
        # signatures, respectively). ConcreteFunctions that are part of a saved
        # function are set up later by recreate_function(); bare
        # ConcreteFunctions are set up by setup_bare_concrete_function().
        # However, we do copy the FunctionDef attributes to the new
        # ConcreteFunction, excluding "_input_shapes", which could cause an
        # error during input-shape initialization at a later stage.
        if "_input_shapes" in copy.attr:
            del copy.attr["_input_shapes"]
        func = function_lib.ConcreteFunction(func_graph, attrs=copy.attr)
        if wrapper_function:
            func = wrapper_function(func)
        func.add_to_graph(graph)

        functions[fdef.signature.name] = func
        renamed_functions[func.name] = func
        if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
            # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
            # is fixed. Currently it's leaking memory to maintain bug compatibility
            # with previous behavior.
            func.add_to_graph(ops.get_default_graph())

        if fdef.signature.name in gradients_to_register:
            gradient_op_type = gradients_to_register[fdef.signature.name]
            loaded_gradients[compat.as_bytes(gradient_op_type)] = func
            ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

    return functions
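The gradient re-registration at the end works because TensorFlow's gradient registry is process-wide and rejects duplicate op types, which is why each load mints a fresh op type via `custom_gradient.generate_name()`. A minimal illustration with the public registration API (the op-type string below is made up):

import tensorflow as tf

@tf.RegisterGradient("IdentityWithGrad_load_0")  # load-unique op type
def _identity_grad(op, grad):
    return grad

# Registering the same op type a second time raises KeyError, so a loader
# must generate a new UID-suffixed name instead of reusing the stored one.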