Example #1
  def testFromLibraryCyclicGradFuncs(self):

    @function.Defun(dtypes.float32)
    def F1(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    @function.Defun(dtypes.float32)
    def F2(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    # Create invalid function def library where F1 has gradient function F2 and
    # F2 has gradient function F1
    library = function_pb2.FunctionDefLibrary()
    library.function.extend([F1.definition, F2.definition])

    gradient1 = function_pb2.GradientDef()
    gradient1.function_name = F1.name
    gradient1.gradient_func = F2.name

    gradient2 = function_pb2.GradientDef()
    gradient2.function_name = F2.name
    gradient2.gradient_func = F1.name

    library.gradient.extend([gradient1, gradient2])

    with self.assertRaisesRegexp(
        ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
      function._from_library(library)
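
For contrast, here is a minimal sketch of a library whose gradient edges are acyclic, which `function._from_library` should accept rather than reject. It assumes the same TF 1.x test imports used above (`function`, `function_pb2`, `dtypes`, `math_ops`); the names `F` and `G` are illustrative.

@function.Defun(dtypes.float32, dtypes.float32)
def G(x, dy):
  # Hypothetical gradient function for F.
  return x * dy

@function.Defun(dtypes.float32)
def F(x):
  return math_ops.exp(x) - math_ops.exp(-x)

library = function_pb2.FunctionDefLibrary()
library.function.extend([F.definition, G.definition])

gradient = function_pb2.GradientDef()
gradient.function_name = F.name
gradient.gradient_func = G.name
library.gradient.extend([gradient])

# F -> G contains no cycle, so this should return one _DefinedFunction per
# FunctionDef instead of raising.
funcs = function._from_library(library)
assert len(funcs) == 2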
Example #2
    def testFromLibraryCyclicGradFuncs(self):
        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        @function.Defun(dtypes.float32)
        def F2(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        # Create invalid function def library where F1 has gradient function F2 and
        # F2 has gradient function F1
        library = function_pb2.FunctionDefLibrary()
        library.function.extend([F1.definition, F2.definition])

        gradient1 = function_pb2.GradientDef()
        gradient1.function_name = F1.name
        gradient1.gradient_func = F2.name

        gradient2 = function_pb2.GradientDef()
        gradient2.function_name = F2.name
        gradient2.gradient_func = F1.name

        library.gradient.extend([gradient1, gradient2])

        with self.assertRaisesRegexp(
                ValueError,
                "FunctionDefLibrary contains cyclic gradient functions!"):
            function._from_library(library)
Example #3
  def testFromLibraryMissingFuncDef(self):

    @function.Defun(dtypes.float32, dtypes.float32)
    def G1(x, dy):
      return x * dy

    @function.Defun(dtypes.float32)
    def F1(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    gradient = function_pb2.GradientDef()
    gradient.function_name = F1.name
    gradient.gradient_func = G1.name

    # Create invalid function def that is missing G1 function def
    library = function_pb2.FunctionDefLibrary()
    library.gradient.extend([gradient])
    library.function.extend([F1.definition])

    with self.assertRaisesRegexp(
        ValueError,
        "FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"):
      function._from_library(library)

    # Create invalid function def that is missing F1 function def
    library = function_pb2.FunctionDefLibrary()
    library.gradient.extend([gradient])
    library.function.extend([G1.definition])

    with self.assertRaisesRegexp(
        ValueError,
        "FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"):
      function._from_library(library)
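
The regular expressions above match names like `F1_[0-9a-zA-Z]{8,11}` because `@function.Defun` appends a uniquifying, body-derived suffix to the Python function's name. A quick hedged illustration (the exact suffix format varies across TF versions):

@function.Defun(dtypes.float32)
def F1(x):
  return math_ops.exp(x) - math_ops.exp(-x)

# Prints something like 'F1_abc123de'; only the 'F1_' prefix is stable,
# which is why the test matches the suffix with a character-class pattern.
print(F1.name)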
Example #4
    def testFromLibraryMissingFuncDef(self):
        @function.Defun(dtypes.float32, dtypes.float32)
        def G1(x, dy):
            return x * dy

        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        gradient = function_pb2.GradientDef()
        gradient.function_name = F1.name
        gradient.gradient_func = G1.name

        # Create invalid function def that is missing G1 function def
        library = function_pb2.FunctionDefLibrary()
        library.gradient.extend([gradient])
        library.function.extend([F1.definition])

        with self.assertRaisesRegexp(
                ValueError,
                "FunctionDefLibrary missing 'G1_........' FunctionDef"):
            function._from_library(library)

        # Create invalid function def that is missing F1 function def
        library = function_pb2.FunctionDefLibrary()
        library.gradient.extend([gradient])
        library.function.extend([G1.definition])

        with self.assertRaisesRegexp(
                ValueError,
                "FunctionDefLibrary missing 'F1_........' FunctionDef"):
            function._from_library(library)
Example #5
  def testFromLibrary(self):
    # Define some functions with different gradient functions. Note that many of
    # the below functions are identical since function bodies don't matter for
    # this test.

    @function.Defun(dtypes.float32, dtypes.float32)
    def G1(x, dy):
      return x * dy

    @function.Defun(dtypes.float32, dtypes.float32)
    def G2(x, dy):
      return x * dy

    # F1 and F2 have the same gradient function
    @function.Defun(dtypes.float32, grad_func=G1)
    def F1(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    @function.Defun(dtypes.float32, grad_func=G1)
    def F2(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    # F3 has a different gradient function
    @function.Defun(dtypes.float32, grad_func=G2)
    def F3(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    # F4 has no gradient function
    @function.Defun(dtypes.float32)
    def F4(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    # Instantiate all functions
    g = ops.Graph()
    with g.as_default():
      c = constant_op.constant(1.0, dtypes.float32)
      f1 = F1(c)
      f2 = F2(c)
      f3 = F3(c)
      f4 = F4(c)
      gradients_impl.gradients([f1, f2, f3, f4], c)

    library = g.as_graph_def().library
    new_funcs = function._from_library(library)

    def CheckNewFunc(func):
      new_func = [f for f in new_funcs if f.name == func.name]
      self.assertEqual(len(new_func), 1)
      self.expectFunctionsEqual(func, new_func=new_func[0])

    CheckNewFunc(G1)
    CheckNewFunc(G2)
    CheckNewFunc(F1)
    CheckNewFunc(F2)
    CheckNewFunc(F3)
    CheckNewFunc(F4)
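
As a follow-on sketch (assuming the `new_funcs` list produced above), the reconstructed `_DefinedFunction` objects can be re-registered in a fresh graph with `add_to_graph`, and the resulting library should carry the same function names:

g2 = ops.Graph()
with g2.as_default():
  for f in new_funcs:
    f.add_to_graph(g2)

# The round-tripped library should expose the same function names.
names = {fdef.signature.name for fdef in g2.as_graph_def().library.function}
assert names == {f.name for f in new_funcs}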
Example #6
    def testFromLibrary(self):
        # Define some functions with different gradient functions. Note that many of
        # the below functions are identical since function bodies don't matter for
        # this test.

        @function.Defun(dtypes.float32, dtypes.float32)
        def G1(x, dy):
            return x * dy

        @function.Defun(dtypes.float32, dtypes.float32)
        def G2(x, dy):
            return x * dy

        # F1 and F2 have the same gradient function
        @function.Defun(dtypes.float32, grad_func=G1)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        @function.Defun(dtypes.float32, grad_func=G1)
        def F2(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        # F3 has a different gradient function
        @function.Defun(dtypes.float32, grad_func=G2)
        def F3(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        # F4 has no gradient function
        @function.Defun(dtypes.float32)
        def F4(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        # Instantiate all functions
        g = ops.Graph()
        with g.as_default():
            c = constant_op.constant(1.0, dtypes.float32)
            f1 = F1(c)
            f2 = F2(c)
            f3 = F3(c)
            f4 = F4(c)
            gradients_impl.gradients([f1, f2, f3, f4], c)

        library = g.as_graph_def().library
        new_funcs = function._from_library(library)

        def CheckNewFunc(func):
            new_func = [f for f in new_funcs if f.name == func.name]
            self.assertEqual(len(new_func), 1)
            self.expectFunctionsEqual(func, new_func=new_func[0])

        CheckNewFunc(G1)
        CheckNewFunc(G2)
        CheckNewFunc(F1)
        CheckNewFunc(F2)
        CheckNewFunc(F3)
        CheckNewFunc(F4)
Example #7
  def testFromLibraryEmptyLib(self):
    library = function_pb2.FunctionDefLibrary()
    self.assertEqual(len(function._from_library(library)), 0)
Example #8
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` and that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImportGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        del op.node_def.attr['_output_shapes']

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
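
A hedged usage sketch of the public `tf.import_graph_def` wrapper that this implementation backs, showing `input_map` and `return_elements` together (TF 1.x API; the names 'x' and 'y' are illustrative):

import tensorflow as tf

with tf.Graph().as_default() as g1:
  x = tf.placeholder(tf.float32, name='x')
  y = tf.add(x, 1.0, name='y')
graph_def = g1.as_graph_def()

with tf.Graph().as_default():
  new_x = tf.constant(41.0)
  # 'x:0' in graph_def is replaced by new_x; 'y:0' is returned as a Tensor.
  y_imported, = tf.import_graph_def(graph_def,
                                    input_map={'x:0': new_x},
                                    return_elements=['y:0'])
  with tf.Session() as sess:
    print(sess.run(y_imported))  # 42.0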
Example #9
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` and that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(g)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImportGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Check to see if this op's name matches a previously seen op
      if node.name in name_to_op:
        raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # Maps from a node to the ops it is colocated with, if colocation
    # is specified in the attributes.
    colocation_pairs = collections.defaultdict(list)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)
      apply_device_function = True

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
              if op_to_bind_to != node.name:
                # Keep track of this mapping for a later phase.
                colocation_pairs[op].append(original_op)
                # Don't apply this op's device function,
                # the colocation constraint will ensure
                # the proper device gets assigned at runtime.
                apply_device_function = False

            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'LMDBReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        del op.node_def.attr['_output_shapes']

      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      if apply_device_function:
        with _MaybeDevice(node.device):
          g._apply_device_functions(op)  # pylint: disable=protected-access

    # The following loop populates the device field of ops that are
    # colocated with another op.  This is implied by the colocation
    # attribute, but we propagate the device field for completeness.
    for op, coloc_op_list in colocation_pairs.items():
      coloc_device = None
      # Find any device in the list of colocated ops that have a
      # device, if it exists.  We assume that if multiple ops
      # have devices, they refer to the same device.  Otherwise, a
      # runtime error will occur since the colocation property
      # cannot be guaranteed.
      #
      # One possible improvement is to try to check for compatibility
      # of all devices in this list at import time here, which would
      # require implementing a compatibility function for device specs
      # in python.
      for coloc_op in coloc_op_list:
        if coloc_op.device:
          coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
          break
      if coloc_device:
        op._set_device(coloc_device)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
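
This revision adds a duplicate-name check at the top of the first pass. A minimal sketch of what it rejects (assumes `graph_pb2` and `dtypes` as imported elsewhere in this module; the node name 'dup' is illustrative):

graph_def = graph_pb2.GraphDef()
for _ in range(2):
  node = graph_def.node.add()
  node.name = 'dup'
  node.op = 'Placeholder'
  node.attr['dtype'].type = dtypes.float32.as_datatype_enum

# The second 'dup' node trips the check:
# ValueError: Duplicate name 'dup' in GraphDef.
import_graph_def(graph_def)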
Example #10
    def testFromLibraryEmptyLib(self):
        library = function_pb2.FunctionDefLibrary()
        self.assertEqual(len(function._from_library(library)), 0)
Example #11
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [compat.as_str(s)
                                   for s in missing_unused_input_keys]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results)

  else:
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImportGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        # Set any default attr values that aren't present.
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]
        for attr_def in op_def.attr:
          key = attr_def.name
          if attr_def.HasField('default_value'):
            value = node.attr[key]
            if value is None or value.WhichOneof('value') is None:
              node.attr[key].CopyFrom(attr_def.default_value)

        output_types = _OutputTypes(node, op_dict)
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False, compute_device=False,
            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError('Specified colocation to an op that '
                                   'does not exist during import: %s in %s' % (
                                       op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(compat.as_bytes(
                    'loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
                s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.'
                      % (input_name,)))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(_InvalidNodeMessage(
                  node, 'More inputs specified (%r) than the op expects.'
                  % (input_name,)))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.'
                        % (input_name,)))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(_InvalidNodeMessage(
                  node, 'Input tensor %r %s' % (input_name, te)))

        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)'
                  % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                     ', '.join(x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else
                [dim.size if dim.size >= 0 else None for dim in dims.dim])

            try:
              output.set_shape(output_shape)
            except ValueError as e:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape.  This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                             'FIFOQueue', 'PriorityQueue', 'QueueSize',
                             'Stack', 'Barrier', 'BarrierReadySize',
                             'BarrierIncompleteSize', 'HashTable',
                             'MutableHashTable',
                             'MutableHashTableOfTensors', 'Mutex',
                             'CuckooTable', 'IndexTable',
                             'WholeFileReader', 'TextLineReader',
                             'FixedLengthRecordReader',
                             'TFRecordReader', 'IdentityReader',
                             'LMDBReader',
                             'RefSwitch', 'RefEnter', 'RefNextIteration',
                             'RefMerge', 'RefIdentity']:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                raise e

          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op.  This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a
        # device, if it exists.  We assume that if multiple ops
        # have devices, they refer to the same device.  Otherwise, a
        # runtime error will occur since the colocation property
        # cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      def _IsImportedNodeOutput(tensor_name):
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(name_to_op[operation_name].outputs)
        except KeyError:
          return False
      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(used_input_keys)
          if not _IsImportedNodeOutput(k)]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]'
            % ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        for name in return_elements:
          name = compat.as_str(name)
          if ':' in name:
            try:
              operation_name, output_index = _ParseTensorName(name)
              ret.append(name_to_op[operation_name].outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
          else:
            try:
              ret.append(name_to_op[name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
        return ret
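# A minimal usage sketch for the importer above, assuming a TF 1.x build where
# `tf.import_graph_def` exposes this `input_map`/`return_elements` signature.
# The graph and the op names ('a', 'b', 'replacement') are illustrative only,
# not taken from the example.
import tensorflow as tf

with tf.Graph().as_default() as source:
    a = tf.placeholder(tf.float32, name='a')
    b = tf.multiply(a, 2.0, name='b')
graph_def = source.as_graph_def()

with tf.Graph().as_default():
    # Remap the placeholder 'a:0' to a constant and fetch 'b:0' back out.
    replacement = tf.constant(3.0, name='replacement')
    b_imported, = tf.import_graph_def(
        graph_def, input_map={'a:0': replacement}, return_elements=['b:0'])
    with tf.Session() as sess:
        print(sess.run(b_imported))  # 6.0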
Exemplo n.º 12
0
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This allows some
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    op_dict = op_def_registry.get_registered_ops()

    graph_def = _ProcessGraphDefParam(graph_def, op_dict)
    input_map = _ProcessInputMapParam(input_map)
    return_elements = _ProcessReturnElementsParam(return_elements)

    if producer_op_list is not None:
        # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
        _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

    graph = ops.get_default_graph()

    if graph._c_graph:  # pylint: disable=protected-access
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # Save unique prefix generated by name_scope
            if scope:
                assert scope.endswith('/')
                prefix = scope[:-1]
            else:
                prefix = ''

            # Generate any input map tensors inside name scope
            input_map = _ConvertInputMapValues(name, input_map)

        scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
        options = scoped_options.options
        _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                         return_elements)

        # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
        # call cannot occur between creating the TF_Operations in the
        # TF_GraphImportGraphDefWithResults call and mutating them in
        # _ProcessNewOps.
        with graph._lock:  # pylint: disable=protected-access
            with c_api_util.tf_buffer(
                    graph_def.SerializeToString()) as serialized:
                try:
                    with errors.raise_exception_on_not_ok_status() as status:
                        results = c_api.TF_GraphImportGraphDefWithResults(
                            graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
                except errors.InvalidArgumentError as e:
                    # Convert to ValueError for backwards compatibility.
                    raise ValueError(str(e))

            _ProcessNewOps(graph)

        # Create _DefinedFunctions for any imported functions.
        #
        # We do this by creating _DefinedFunctions directly from `graph_def`, and
        # adding them to `graph`. Adding an existing function to a TF_Graph is a
        # no-op, so this only has the effect of updating the Python state (usually
        # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
        #
        # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
        # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
        if graph_def.library and graph_def.library.function:
            # pylint: disable=protected-access
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(graph)
            # pylint: enable=protected-access

        # Treat input mappings that don't appear in the graph as an error, because
        # they are likely to be due to a typo.
        missing_unused_input_keys = (
            c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
                results))
        if missing_unused_input_keys:
            missing_unused_input_keys = [
                compat.as_str(s) for s in missing_unused_input_keys
            ]
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(missing_unused_input_keys))

        if return_elements is None:
            return None
        else:
            return _GatherReturnElements(return_elements, graph, results)

    else:
        g = graph

        # Use a canonical representation for all tensor names.
        input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
        used_input_keys = set()
        name_to_op = {}

        # Add any functions defined in `graph_def` to `g`
        if graph_def.library and graph_def.library.function:
            # Copy op_dict so we don't clobber the original
            op_dict = copy.copy(op_dict)
            # pylint: disable=protected-access
            # Note that we do not prepend `name` to the function name. The reasoning
            # is that function names are similar to op definition names, which
            # currently do not have a scoped name or namespace scheme.
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(g)
                op_dict[f.name] = f.definition.signature
            # pylint: enable=protected-access

        # LINT.IfChange
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # TODO(ashankar): Should this just copy over or should it do some
            # more nuanced merging? For example, the graph may already have some
            # marked "bad versions" and we don't want to lose those because of
            # what's in graph_def.versions? The C++ ImportGraphDef does something
            # more nuanced.
            g.graph_def_versions.CopyFrom(graph_def.versions)

            input_map = _ConvertInputMapValues(name, input_map)

            # NOTE(mrry): We do this in two passes, because there may be a cycle in
            # `graph_def`.

            # 1. Add operations without their inputs.
            for node in graph_def.node:
                # Check to see if this op's name matches a previously seen op
                if node.name in name_to_op:
                    raise ValueError('Duplicate name \'%s\' in GraphDef.' %
                                     node.name)
                if node.op not in op_dict:
                    raise ValueError('No op named %s in defined operations.' %
                                     node.op)
                op_def = op_dict[node.op]

                output_types = _OutputTypes(node, op_dict)
                name_to_op[node.name] = g.create_op(node.op, [],
                                                    output_types,
                                                    name=node.name,
                                                    attrs=node.attr,
                                                    compute_shapes=False,
                                                    compute_device=False,
                                                    op_def=op_def)

            # Maps from a node to the ops it is colocated with, if colocation
            # is specified in the attributes.
            colocation_pairs = collections.defaultdict(list)

            # 2. Add inputs to the operations.
            for node in graph_def.node:
                op = name_to_op[node.name]
                input_types = _InputTypes(node, op_dict)
                apply_device_function = True

                # Rewrite the colocation attributes in the graph, since the
                # names of new ops may have changed.
                for key, value in op.node_def.attr.items():
                    if key == '_class':
                        class_values = value.list
                        new_class_values = []
                        for class_value in class_values.s:
                            if class_value.startswith(b'loc:@'):
                                op_to_bind_to = class_value[5:].decode()
                                # Find the op by its original name.
                                if op_to_bind_to not in name_to_op:
                                    raise ValueError(
                                        'Specified colocation to an op that '
                                        'does not exist during import: %s in %s'
                                        % (op_to_bind_to, node.name))
                                original_op = name_to_op[op_to_bind_to]
                                new_class_values.append(
                                    compat.as_bytes('loc:@' +
                                                    original_op.name))
                                if op_to_bind_to != node.name:
                                    # Keep track of this mapping for a later phase.
                                    colocation_pairs[op].append(original_op)
                                    # Don't apply this op's device function,
                                    # the colocation constraint will ensure
                                    # the proper device gets assigned at runtime.
                                    apply_device_function = False

                            else:
                                new_class_values.append(class_value)
                        value.list.CopyFrom(
                            attr_value_pb2.AttrValue.ListValue(
                                s=new_class_values))

                # NOTE(mrry): We cannot use zip here because control inputs do not
                # appear in the list of input_types.
                for i, input_name in enumerate(
                    [_CanonicalInputName(x) for x in node.input]):

                    if _IsControlInput(input_name):
                        # (a) Input is a control input that should be taken from an op
                        #     in "graph_def".
                        try:
                            source_op = name_to_op[input_name[1:]]
                        except KeyError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Control input %r not found in graph_def.'
                                    % (input_name, )))
                        # pylint: disable=protected-access
                        op._add_control_input(source_op)
                        # pylint: enable=protected-access

                    else:
                        try:
                            input_type = input_types[i]
                        except IndexError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'More inputs specified (%r) than the op expects.'
                                    % (input_name, )))

                        if input_name in input_map:
                            # (b) Input should be replaced by a tensor from the caller.
                            source_tensor = input_map[input_name]
                            used_input_keys.add(input_name)

                        else:
                            # (c) Input should be taken from an op in `graph_def`.
                            operation_name, output_index = _ParseTensorName(
                                input_name)
                            try:
                                source_op = name_to_op[operation_name]
                                source_tensor = list(
                                    source_op.values())[output_index]
                            except (KeyError, IndexError):
                                raise ValueError(
                                    _InvalidNodeMessage(
                                        node,
                                        'Input tensor %r not found in graph_def.'
                                        % (input_name, )))

                        try:
                            # pylint: disable=protected-access
                            op._add_input(source_tensor, dtype=input_type)
                            # pylint: enable=protected-access
                        except TypeError as te:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r %s' % (input_name, te)))

                # pylint: disable=protected-access
                if op._input_types != input_types:
                    raise ValueError(
                        _InvalidNodeMessage(
                            node,
                            'Input types mismatch (expected %r but got %r)' %
                            (', '.join(
                                dtypes.as_dtype(x).name
                                for x in input_types), ', '.join(
                                    x.name for x in op._input_types))))
                # pylint: enable=protected-access

                if not g._is_function(op.type):  # pylint: disable=protected-access
                    # Execute shape inference for this op.
                    # NOTE(mrry): If the graph contains a cycle, the full shape
                    # information may not be available for this op's inputs.
                    ops.set_shapes_for_outputs(op)
                # For nodes with _output_shapes set, set the output shapes.
                if '_output_shapes' in op.node_def.attr:
                    for i, output in enumerate(op.outputs):
                        dims = op.node_def.attr['_output_shapes'].list.shape[i]
                        output_shape = tensor_shape.TensorShape(
                            None if dims.unknown_rank else [
                                dim.size if dim.size >= 0 else None
                                for dim in dims.dim
                            ])

                        try:
                            output.set_shape(output_shape)
                        except ValueError as e:
                            # If the output shape is incompatible with what is inferred
                            # by the graph for a very specific whitelist of ops, then we
                            # ignore this output shape.  This can happen if there is a
                            # bug in the shape function for some operation, and the
                            # serialized graph def has the incorrect shape set when
                            # running on a newer binary with the fixed shape function.
                            # This is an escape hatch that allows us to correct shape
                            # functions that are not critical to correct execution but
                            # would cause graphs to fail if imported after correcting.
                            #
                            # This can be removed after 2017/03/08.
                            if op.type in [
                                    'RandomShuffleQueue', 'PaddingFIFOQueue',
                                    'FIFOQueue', 'PriorityQueue', 'QueueSize',
                                    'Stack', 'Barrier', 'BarrierReadySize',
                                    'BarrierIncompleteSize', 'HashTable',
                                    'MutableHashTable',
                                    'MutableHashTableOfTensors', 'Mutex',
                                    'CuckooTable', 'IndexTable',
                                    'WholeFileReader', 'TextLineReader',
                                    'FixedLengthRecordReader',
                                    'TFRecordReader', 'IdentityReader',
                                    'LMDBReader', 'RefSwitch', 'RefEnter',
                                    'RefNextIteration', 'RefMerge',
                                    'RefIdentity'
                            ]:
                                pass
                            elif op.type in [
                                    'ConditionalAccumulator',
                                    'SparseConditionalAccumulator', 'Table'
                            ]:
                                # This can be removed after 2017/04/24.
                                pass
                            else:
                                raise e

                    del op.node_def.attr['_output_shapes']

                # NOTE(mrry): We do this after configuring the inputs, because
                # the result of the device functions may depend on the inputs.
                if apply_device_function:
                    with _MaybeDevice(node.device):
                        g._apply_device_functions(op)  # pylint: disable=protected-access

            # The following loop populates the device field of ops that are
            # colocated with another op.  This is implied by the colocation
            # attribute, but we propagate the device field for completeness.
            for op, coloc_op_list in colocation_pairs.items():
                coloc_device = None
                # Find any device in the list of colocated ops that has a
                # device, if one exists.  We assume that if multiple ops
                # have devices, they refer to the same device.  Otherwise, a
                # runtime error will occur since the colocation property
                # cannot be guaranteed.
                #
                # One possible improvement is to try to check for compatibility
                # of all devices in this list at import time here, which would
                # require implementing a compatibility function for device specs
                # in python.
                for coloc_op in coloc_op_list:
                    if coloc_op.device:
                        coloc_device = pydev.DeviceSpec.from_string(
                            coloc_op.device)
                        break
                if coloc_device:
                    op._set_device(coloc_device)  # pylint: disable=protected-access

            # Treat input mappings that don't appear in the graph as an error,
            # because they are likely to be due to a typo.
            def _IsImportedNodeOutput(tensor_name):
                operation_name, output_index = _ParseTensorName(tensor_name)
                try:
                    return output_index < len(
                        name_to_op[operation_name].outputs)
                except KeyError:
                    return False

            absent_input_keys = [
                k for k in frozenset(input_map.keys()).difference(
                    used_input_keys) if not _IsImportedNodeOutput(k)
            ]
            if absent_input_keys:
                raise ValueError(
                    'Attempted to map inputs that were not found in graph_def: [%s]'
                    % ', '.join(absent_input_keys))

            if return_elements is None:
                return None
            else:
                ret = []
                for name in return_elements:
                    name = compat.as_str(name)
                    if ':' in name:
                        try:
                            operation_name, output_index = _ParseTensorName(
                                name)
                            ret.append(name_to_op[operation_name].
                                       outputs[output_index])
                        except (ValueError, KeyError, IndexError):
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                    else:
                        try:
                            ret.append(name_to_op[name])
                        except KeyError:
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                return ret
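# A short sketch of the input-map validation in the example above, again
# assuming a TF 1.x build: mapping a name that never appears in graph_def
# raises the "Attempted to map inputs" ValueError. The names are illustrative.
import tensorflow as tf

with tf.Graph().as_default() as source:
    tf.constant(1.0, name='c')
graph_def = source.as_graph_def()

with tf.Graph().as_default():
    try:
        tf.import_graph_def(
            graph_def, input_map={'does_not_exist:0': tf.constant(0.0)})
    except ValueError as e:
        print(e)  # lists the input-map keys that were never used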
Exemplo n.º 13
0
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This allows some
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        results = c_api.TF_GraphImportGraphDefWithResults(
            graph._c_graph, serialized, options)  # pylint: disable=protected-access
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
    # _USE_C_SHAPES is removed.
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    _ProcessNewOps(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: [%s]' %
        ', '.join(missing_unused_input_keys))

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)
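# A sketch of the name_scope prefixing in the example above, assuming TF 1.x:
# repeated imports are uniquified ('import/', 'import_1/', ...), and name=''
# disables the prefix entirely. The op name 'c' is illustrative.
import tensorflow as tf

with tf.Graph().as_default() as source:
    tf.constant(1.0, name='c')
graph_def = source.as_graph_def()

with tf.Graph().as_default() as dest:
    tf.import_graph_def(graph_def)            # ops land under 'import/'
    tf.import_graph_def(graph_def)            # scope uniquified to 'import_1/'
    tf.import_graph_def(graph_def, name='')   # no prefix at all
    print([op.name for op in dest.get_operations()])
    # ['import/c', 'import_1/c', 'c']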
Exemplo n.º 14
0
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This allows some
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    op_dict = op_def_registry.get_registered_ops()

    graph_def = _ProcessGraphDefParam(graph_def, op_dict)
    input_map = _ProcessInputMapParam(input_map)
    return_elements = _ProcessReturnElementsParam(return_elements)

    if producer_op_list is not None:
        # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
        _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

    graph = ops.get_default_graph()
    with ops.name_scope(name, 'import', input_map.values()) as scope:
        # Save unique prefix generated by name_scope
        if scope:
            assert scope.endswith('/')
            prefix = scope[:-1]
        else:
            prefix = ''

        # Generate any input map tensors inside name scope
        input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
    # Session.run call cannot occur between creating the TF_Operations in the
    # TF_GraphImportGraphDefWithResults call and mutating them in
    # _ProcessNewOps.
    with graph._mutation_lock():  # pylint: disable=protected-access
        with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
            try:
                results = c_api.TF_GraphImportGraphDefWithResults(
                    graph._c_graph, serialized, options)  # pylint: disable=protected-access
                results = c_api_util.ScopedTFImportGraphDefResults(results)
            except errors.InvalidArgumentError as e:
                # Convert to ValueError for backwards compatibility.
                raise ValueError(str(e))

        # Create _DefinedFunctions for any imported functions.
        #
        # We do this by creating _DefinedFunctions directly from `graph_def`, and
        # adding them to `graph`. Adding an existing function to a TF_Graph is a
        # no-op, so this only has the effect of updating the Python state (usually
        # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
        #
        # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
        # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
        # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
        # _USE_C_SHAPES is removed.
        if graph_def.library and graph_def.library.function:
            # pylint: disable=protected-access
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(graph)
            # pylint: enable=protected-access

        _ProcessNewOps(graph)

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results.results))
    if missing_unused_input_keys:
        missing_unused_input_keys = [
            compat.as_str(s) for s in missing_unused_input_keys
        ]
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]' %
            ', '.join(missing_unused_input_keys))

    if return_elements is None:
        return None
    else:
        return _GatherReturnElements(return_elements, graph, results.results)
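# A closing sketch, assuming TF 1.x: per the docstring above, return_elements
# accepts both operation names (returning Operation objects) and tensor names
# like 'y:0' (returning Tensor objects). The names 'x'/'y' are illustrative.
import tensorflow as tf

with tf.Graph().as_default() as source:
    x = tf.placeholder(tf.float32, name='x')
    tf.square(x, name='y')
graph_def = source.as_graph_def()

with tf.Graph().as_default():
    y_op, y_tensor = tf.import_graph_def(
        graph_def, return_elements=['y', 'y:0'])
    print(type(y_op).__name__, type(y_tensor).__name__)  # Operation Tensor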