Example 1
def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.

  Raises:
    ValueError: If an unregistered op is used.
  """
    # This is the Python equivalent of StrippedOpListForGraph in C++.
    # Unfortunately, since the Python op registry can differ from that in C++, we
    # can't remove the duplication using swig (at least naively).
    # TODO(irving): Support taking graphs directly.

    used_ops = ops_used_by_graph_def(graph_def)

    # Verify that all used ops are registered.
    registered_ops = op_def_registry.get_registered_ops()
    # These internal ops used by functions are not registered, so we need to
    # whitelist them.  # TODO(irving): Do something better here.
    op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    for op in used_ops:
        if op not in registered_ops and op not in op_whitelist:
            raise ValueError("Op %s is used by the graph, but is not registered" % op)

    # Build the stripped op list in sorted order.
    return op_def_pb2.OpList(op=[
        registered_ops[op] for op in sorted(used_ops) if op in registered_ops
    ])
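A minimal usage sketch (hedged: assumes a TF 1.x graph-mode environment and the surrounding module's imports, where this function is importable, e.g. from tensorflow.python.framework.meta_graph):

# Usage sketch; assumes TF 1.x graph-mode APIs.
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    a = tf.constant(1.0)
    b = tf.constant(2.0)
    tf.add(a, b, name="sum")

op_list = stripped_op_list_for_graph(graph.as_graph_def())
print([op.name for op in op_list.op])  # ['Add', 'Const'] (sorted)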
Example 2
def _add_op_node(op, func):
    """Converts an op to a function def node and add it to `func`."""
    node = function_pb2.FunctionDef.Node()
    node.op = op.type
    # pylint: disable=protected-access
    if hasattr(op, "_sig"):
        op_def = getattr(op, "_sig")
    else:
        op_def = op_def_registry.get_registered_ops()[op.type]
    # pylint: enable=protected-access
    attrs = _get_node_def_attr(op)
    if not op_def.output_arg:
        node.ret.append(_make_argname_from_tensor_name(op.name))
    else:
        out_index = 0
        for arg_def in op_def.output_arg:
            if arg_def.number_attr:
                dtype = arg_def.type or attrs[arg_def.type_attr].type
                num = attrs[arg_def.number_attr].i
                node.ret.append(
                    _add_output_array(op, out_index, out_index + num, dtype,
                                      func))
                out_index += num
            elif arg_def.type_list_attr:
                dtype_lst = attrs[arg_def.type_list_attr].list.type
                num = len(dtype_lst)
                node.ret.append(
                    _add_output_list(op, out_index, out_index + num, dtype_lst,
                                     func))
                out_index += num
            else:
                node.ret.append(
                    _make_argname_from_tensor_name(op.outputs[out_index].name))
                out_index += 1
    inp_index = 0
    for arg_def in op_def.input_arg:
        if arg_def.number_attr:
            dtype = arg_def.type or attrs[arg_def.type_attr].type
            num = attrs[arg_def.number_attr].i
            node.arg.append(
                _add_input_array(op, inp_index, inp_index + num, dtype, func))
            inp_index += num
        elif arg_def.type_list_attr:
            num = len(attrs[arg_def.type_list_attr].list.type)
            node.arg.extend([
                _make_argname_from_tensor_name(op.inputs[i].name)
                for i in range(inp_index, inp_index + num)
            ])
            inp_index += num
        else:
            node.arg.append(
                _make_argname_from_tensor_name(op.inputs[inp_index].name))
            inp_index += 1
    node.dep.extend(
        [_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
    for k, v in _get_node_def_attr(op).items():
        node.attr[k].CopyFrom(v)
    func.node.extend([node])
Example 3
  def testStripDefaultAttrsInconsistentConsumerDefaults(self):
    if ops._USE_C_API: return  # TODO(skyewm): get this working

    export_dir = self._get_export_dir(
        "test_strip_default_attrs_no_consumer_defaults")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled. This must remove the following
    # defaults for the "Complex" Op:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], strip_default_attrs=True)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Update the Op registry to remove defaults for all attrs ("T", "Tout") from
    # the "Complex" OpDef.
    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
    original_complex_op_def = op_def_pb2.OpDef()
    original_complex_op_def.CopyFrom(complex_op_def)
    for attr_def in complex_op_def.attr:
      attr_def.ClearField("default_value")

    # Loading the SavedModel via the loader must fail because the SavedModel
    # does not have any attr values for the "Complex" node and the current
    # op registry does not have any default values for the "Complex" op.
    sess = session.Session(graph=ops.Graph())
    with self.assertRaisesRegexp(
        ValueError,
        "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
      loader.load(sess, ["foo"], export_dir)

    # Update the Op registry to change the defaults for attr "Tout"
    # (complex64 -> complex128).
    complex_op_def.CopyFrom(original_complex_op_def)
    for attr_def in complex_op_def.attr:
      if attr_def.name == "Tout":
        attr_def.default_value.type = types_pb2.DT_COMPLEX128

    # Loading the SavedModel via the loader must set "Tout" attr_value for the
    # "Complex" node according to the latest defaults (complex128). This is
    # expected to fail the model import as there is no OpKernel registered to
    # handle attrs "T" (float32) and "Tout" (complex128).
    sess = session.Session(graph=ops.Graph())
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        ".*No OpKernel was registered to support Op \'Complex\' with these "
        "attrs..*"):
      loader.load(sess, ["foo"], export_dir)
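Since the test mutates the shared op registry's `OpDef` in place, a cleanup step along these lines at the end of the test (a sketch, not part of the original test) would keep later tests isolated:

    # Sketch: restore the saved copy so other tests see the original defaults.
    complex_op_def.CopyFrom(original_complex_op_def)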
Example 4
def _is_array_type_input(op, i):
    registered_ops = op_def_registry.get_registered_ops()
    if op not in registered_ops:
        return False
    op_def = registered_ops[op]
    if i not in xrange(len(op_def.input_arg)):
        raise TypeError("Expected arg index "
                        "to be in [0, %d)" % len(op_def.input_arg))
    input_arg = op_def.input_arg[i]
    return bool(input_arg.number_attr)
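For example (hedged: the result depends on the registered OpDefs; "Concat" declares its `values` input with `number_attr="N"`):

# "Concat" input 0 is the scalar concat_dim; input 1 is the N-tensor list.
assert not _is_array_type_input("Concat", 0)
assert _is_array_type_input("Concat", 1)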
Example 5
def _register_function_ops(func_list):
  """Registers custom ops in the default graph. This is needed
  Because our checkpoint is saved with ops that are not part of Tensorflow."""
  op_dict = op_def_registry.get_registered_ops()
  for func in func_list:
    #pylint: disable=W0212
    func._create_definition_if_needed()
    op_def = func._definition.signature
    op_dict[op_def.name] = op_def
    RegisterShape(op_def.name)(common_shapes.unknown_shape)
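A hedged usage sketch, assuming TF 1.x `Defun`-style functions (the decorator, dtype, and function name are illustrative, not from the original source):

# Sketch: register the signature of a Defun-wrapped function so a
# checkpoint that references it can be imported.
import tensorflow as tf
from tensorflow.python.framework import function

@function.Defun(tf.float32)
def scale(x):
    return x * 2.0

_register_function_ops([scale])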
Example 6
def _stripped_op_list_for_graph(graph_def):
  """Returns OpDefs of ops used in graph_def."""
  op_set = set()
  registered_ops = op_def_registry.get_registered_ops()
  for n in graph_def.node:
    if n.op in registered_ops:
      op_set.add(n.op)
  for func in graph_def.library.function:
    for n in func.node:
      if n.op in registered_ops:
        op_set.add(n.op)
  return op_def_pb2.OpList(op=[registered_ops[x] for x in sorted(op_set)])
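Note that, unlike Example 1, this variant silently skips unregistered ops rather than raising. A quick sketch of the difference (hedged: "MyCustomOp" is a hypothetical unregistered op):

from tensorflow.core.framework import graph_pb2

gd = graph_pb2.GraphDef()
gd.node.add(op="MyCustomOp", name="x")
print(len(_stripped_op_list_for_graph(gd).op))  # 0 -- the op is skipped
# stripped_op_list_for_graph(gd) from Example 1 would raise ValueError here.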
Example 7
  def sync():
    p_buffer = c_api.TF_GetAllOpList()
    cpp_op_list = op_def_pb2.OpList()
    cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))

    registered_ops = op_def_registry.get_registered_ops()
    for op_def in cpp_op_list.op:
      # If an OpList is registered from a gen_*_ops.py, it does not have any
      # descriptions. Strip them here as well to satisfy validation in
      # register_op_list.
      _remove_non_deprecated_descriptions(op_def)
      registered_ops[op_def.name] = op_def
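`_remove_non_deprecated_descriptions` is not shown above; here is a sketch of what it plausibly does, reconstructed from the comment (hedged: not the actual implementation):

def _remove_non_deprecated_descriptions(op_def):
    """Sketch: clear free-text descriptions so these OpDefs compare equal
    to the description-less ones registered by gen_*_ops.py modules."""
    for input_arg in op_def.input_arg:
        input_arg.description = ""
    for output_arg in op_def.output_arg:
        output_arg.description = ""
    for attr in op_def.attr:
        attr.description = ""
    if not op_def.deprecation.explanation:
        op_def.summary = ""
        op_def.description = ""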
Example 8
def list_registered_stateful_ops_without_inputs():
    """Returns set of registered stateful ops that do not expect inputs.

  This list is used to identify the ops to be included in the state-graph and
  that are subsequently fed into the apply-graphs.

  Returns:
    A set of strings.
  """
    return set([
        name for name, op in op_def_registry.get_registered_ops().items()
        if op.is_stateful and not op.input_arg
    ])
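Usage sketch (hedged: the exact contents depend on which ops are registered in the running binary):

stateful_sources = list_registered_stateful_ops_without_inputs()
# In a stock TF 1.x registry this typically includes variable and reader
# ops such as "VariableV2" or "WholeFileReader".
print(sorted(stateful_sources)[:5])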
Example 9
  def _get_ref_args(self, node):
    """Determine whether an input of an op is ref-type.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of the arg names (as strs) that are ref-type.
    """
    op_def = op_def_registry.get_registered_ops().get(node.op)
    ref_args = []
    if op_def:
      for i, output_arg in enumerate(op_def.output_arg):
        if output_arg.is_ref:
          arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
          ref_args.append(arg_name)
    return ref_args
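A usage sketch (hedged: `dbg` stands in for an instance of the enclosing test class, and "VariableV2" is chosen because its single output is ref-typed):

from tensorflow.core.framework import node_def_pb2

node = node_def_pb2.NodeDef(name="my_var", op="VariableV2")
print(dbg._get_ref_args(node))  # ['my_var']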
Example 10
def _strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  This method also sets `meta_info_def.stripped_default_attrs` in the given
  `MetaGraphDef` proto to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  # Map function op names to their function definitions.
  op_name_to_function = {}
  for function_def in meta_graph_def.graph_def.library.function:
    op_name_to_function[function_def.signature.name] = function_def

  # Get all registered ops.
  registered_ops = op_def_registry.get_registered_ops()

  def _strip_node_default_valued_attrs(node_def):
    """Removes default valued attributes from a single node def."""
    if node_def.op in op_name_to_function or node_def.op not in registered_ops:
      return
    op_def = registered_ops[node_def.op]

    attrs_to_strip = set()
    for attr_name, attr_value in node_def.attr.items():
      if _is_default_attr_value(op_def, attr_name, attr_value):
        attrs_to_strip.add(attr_name)

    for attr in attrs_to_strip:
      del node_def.attr[attr]

  # Process all NodeDef instances in graph_def.
  for node_def in meta_graph_def.graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Process all NodeDef instances in graph_def.library.function.
  for function_def in meta_graph_def.graph_def.library.function:
    for function_node_def in function_def.node_def:
      _strip_node_default_valued_attrs(function_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
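`_is_default_attr_value` is not shown; here is a sketch of the check it needs to perform (hedged: the real TensorFlow helper delegates to a C API comparison, this is a simplified proto-level version):

def _is_default_attr_value(op_def, attr_name, attr_value):
    """Sketch: True iff `attr_value` equals the OpDef's declared default."""
    for attr_def in op_def.attr:
        if attr_def.name == attr_name:
            if not attr_def.HasField("default_value"):
                return False
            # Proto-level comparison; AttrValue has no map fields, so
            # serialized equality is a reasonable approximation here.
            return (attr_value.SerializeToString() ==
                    attr_def.default_value.SerializeToString())
    return False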
Example 11
def _create_op_def_library(op_protos):
    op_def_lib = _op_def_library.OpDefLibrary()
    ops_proto = _op_def_pb2.OpList()
    registered_ops = _registry.get_registered_ops()
    for op_proto in op_protos:
        if op_proto.name not in registered_ops:
            raise LookupError("Op with name {0} not registered".format(
                op_proto.name))
        ops_proto.op.extend([op_proto])

    # Fails if the interfaces ("op schemas") don't match between the
    # previously registered op and this one.
    _registry.register_op_list(ops_proto)

    op_def_lib.add_op_list(ops_proto)

    return op_def_lib
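Usage sketch (hedged: "Add" is just an example of a registered op with a known OpDef):

registered = _registry.get_registered_ops()
op_def_lib = _create_op_def_library([registered["Add"]])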
Example 12
def register_ops_if_needed(graph_ops):
  """Register graph ops absent in op_def_registry, if present in c++ registry.

  Args:
    graph_ops: set with graph op names to register.

  Raises:
    RuntimeError: if `graph_ops` contains ops that are not in either python or
      c++ registry.
  """
  missing_ops = graph_ops - set(op_def_registry.get_registered_ops().keys())

  if not missing_ops:
    return

  p_buffer = c_api.TF_GetAllOpList()
  cpp_op_list = op_def_pb2.OpList()
  cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
  cpp_registry_ops = {op.name: op for op in cpp_op_list.op}

  missing_op_list = op_def_pb2.OpList()
  for missing_op in missing_ops:
    if missing_op not in cpp_registry_ops:
      tf.logging.info(
          "Op %s is missing from both the python and C++ registry.",
          missing_op)
    else:
      missing_op_list.op.extend([cpp_registry_ops[missing_op]])
      tf.logging.info(
          "Adding op %s from c++ registry to python registry.",
          missing_op)

  op_def_registry.register_op_list(missing_op_list)

  # Note: Only raise the missing-op RuntimeError after trying to load ops.
  # This allows the test to exercise all the calls into TensorFlow
  # without having to write a combined C++ and Python test.
  if not missing_ops <= set(cpp_registry_ops.keys()):
    raise RuntimeError(
        "Graph ops missing from the python registry (%s) are also absent from "
        "the c++ registry."
        % missing_ops.difference(set(cpp_registry_ops.keys())))
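Usage sketch (hedged: `graph_def` is assumed to be a `GraphDef` loaded elsewhere):

# Register any ops the graph needs that the Python registry lacks,
# then import as usual.
graph_op_names = {node.op for node in graph_def.node}
register_ops_if_needed(graph_op_names)
tf.import_graph_def(graph_def, name="")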
Example 13
    def testStripDefaultAttrsInconsistentConsumerDefaults(self):
        export_dir = os.path.join(
            test.get_temp_dir(),
            "test_strip_default_attrs_no_consumer_defaults")
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        # Add a graph with two float32 variables and a Complex Op composing them
        # with strip_default_attrs enabled. This must remove the following
        # defaults for the "Complex" Op:
        #   o "T"    : float32.   (input type)
        #   o "Tout" : complex64. (output type)
        with session.Session(graph=ops.Graph()) as sess:
            real_num = variables.Variable(1.0,
                                          dtype=dtypes.float32,
                                          name="real")
            imag_num = variables.Variable(2.0,
                                          dtype=dtypes.float32,
                                          name="imag")
            math_ops.complex(real_num, imag_num, name="complex")
            sess.run(variables.global_variables_initializer())
            builder.add_meta_graph_and_variables(sess, ["foo"],
                                                 strip_default_attrs=True)

        # Save the SavedModel to disk in text format.
        builder.save(as_text=True)

        # Update the Op registry to remove defaults for all attrs ("T", "Tout")
        # from the "Complex" OpDef.
        # the "Complex" OpDef.
        complex_op_def = op_def_registry.get_registered_ops()["Complex"]
        original_complex_op_def = op_def_pb2.OpDef()
        original_complex_op_def.CopyFrom(complex_op_def)
        for attr_def in complex_op_def.attr:
            attr_def.ClearField("default_value")

        # Loading the SavedModel via the loader must fail because the SavedModel
        # does not have any attr values for the "Complex" node and the current
        # op registry does not have any default values for the "Complex" op.
        sess = session.Session(graph=ops.Graph())
        with self.assertRaisesRegexp(
                ValueError,
                "Expected one attr with name .*T(out)?.* in name: \"complex\".*"
        ):
            loader.load(sess, ["foo"], export_dir)

        # Update the Op registry to change the defaults for attr "Tout"
        # (complex64 -> complex128).
        complex_op_def.CopyFrom(original_complex_op_def)
        for attr_def in complex_op_def.attr:
            if attr_def.name == "Tout":
                attr_def.default_value.type = types_pb2.DT_COMPLEX128

        # Loading the SavedModel via the loader must set "Tout" attr_value for the
        # "Complex" node according to the latest defaults (complex128). This is
        # expected to fail the model import as there is no OpKernel registered to
        # handle attrs "T" (float32) and "Tout" (complex128).
        sess = session.Session(graph=ops.Graph())
        with self.assertRaisesRegexp(
                errors.InvalidArgumentError,
                ".*No OpKernel was registered to support Op \'Complex\' with these "
                "attrs..*"):
            loader.load(sess, ["foo"], export_dir)
Example 14
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None):
    """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    # Type checks for inputs.
    if not isinstance(graph_def, graph_pb2.GraphDef):
        # `graph_def` could be a dynamically-created message, so try a duck-typed
        # approach
        try:
            old_graph_def = graph_def
            graph_def = graph_pb2.GraphDef()
            graph_def.MergeFrom(old_graph_def)
        except TypeError:
            raise TypeError('graph_def must be a GraphDef proto.')
    if input_map is None:
        input_map = {}
    else:
        if not (isinstance(input_map, dict) and all(
                isinstance(k, compat.bytes_or_text_types)
                for k in input_map.keys())):
            raise TypeError(
                'input_map must be a dictionary mapping strings to '
                'Tensor objects.')
    if return_elements is not None:
        return_elements = tuple(return_elements)
        if not all(
                isinstance(x, compat.bytes_or_text_types)
                for x in return_elements):
            raise TypeError('return_elements must be a list of strings.')

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()

    name_to_op = {}

    if op_dict is None:
        op_dict = op_def_registry.get_registered_ops()

    with ops.op_scope(input_map.values(), name, 'import'):
        g = ops.get_default_graph()
        g.graph_def_versions.CopyFrom(graph_def.versions)

        with ops.name_scope('_inputs'):
            input_map = {
                k: ops.convert_to_tensor(v)
                for k, v in input_map.items()
            }

        # NOTE(mrry): We do this in two passes, because there may be a cycle in
        # `graph_def`.

        # 1. Add operations without their inputs.
        for node in graph_def.node:
            # Set any default attr values that aren't present.
            op_def = op_dict[node.op]
            for attr_def in op_def.attr:
                key = attr_def.name
                if attr_def.HasField('default_value'):
                    value = node.attr[key]
                    if value is None or value.WhichOneof('value') is None:
                        node.attr[key].CopyFrom(attr_def.default_value)

            output_types = _OutputTypes(node, op_dict)
            name_to_op[node.name] = g.create_op(node.op, [],
                                                output_types,
                                                name=node.name,
                                                attrs=node.attr,
                                                compute_shapes=False,
                                                compute_device=False)

        # 2. Add inputs to the operations.
        for node in graph_def.node:
            op = name_to_op[node.name]
            input_types = _InputTypes(node, op_dict)

            # NOTE(mrry): We cannot use zip here because control inputs do not appear
            # in the list of input_types.
            for i, input_name in enumerate(
                [_CanonicalInputName(x) for x in node.input]):

                if _IsControlInput(input_name):
                    # (a) Input is a control input that should be taken from an op
                    #     in "graph_def".
                    try:
                        source_op = name_to_op[input_name[1:]]
                    except KeyError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'Control input %r not found in graph_def.' %
                                (input_name, )))
                    # pylint: disable=protected-access
                    op._add_control_input(source_op)
                    # pylint: enable=protected-access

                else:
                    try:
                        input_type = input_types[i]
                    except IndexError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'More inputs specified (%r) than the op expects.'
                                % (input_name, )))

                    if input_name in input_map:
                        # (b) Input should be replaced by a tensor from the caller.
                        source_tensor = input_map[input_name]
                        used_input_keys.add(input_name)

                    else:
                        # (c) Input should be taken from an op in `graph_def`.
                        operation_name, output_index = _ParseTensorName(
                            input_name)
                        try:
                            source_op = name_to_op[operation_name]
                            source_tensor = list(
                                source_op.values())[output_index]
                        except (KeyError, IndexError):
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r not found in graph_def.' %
                                    (input_name, )))

                    try:
                        # pylint: disable=protected-access
                        op._add_input(source_tensor, dtype=input_type)
                        # pylint: enable=protected-access
                    except TypeError as te:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node, 'Input tensor %r %s' % (input_name, te)))

            # pylint: disable=protected-access
            if op._input_dtypes != input_types:
                raise ValueError(
                    _InvalidNodeMessage(
                        node, 'Input types mismatch (expected %r but got %r)' %
                        (", ".join(
                            dtypes.as_dtype(x).name
                            for x in input_types), ", ".join(
                                x.name for x in op._input_dtypes))))
            # pylint: enable=protected-access

            # Execute shape inference for this op.
            # NOTE(mrry): If the graph contains a cycle, the full shape information
            # may not be available for this op's inputs.
            ops.set_shapes_for_outputs(op)

            # Apply device functions for this op.
            # NOTE(mrry): We do this after configuring the inputs, because
            # the result of the device functions may depend on the inputs.
            with _MaybeDevice(node.device):
                g._apply_device_functions(op)  # pylint: disable=protected-access

        # Treat unused input mappings as an error, because they are likely to be
        # due to a typo.
        unused_input_keys = frozenset(
            input_map.keys()).difference(used_input_keys)
        if unused_input_keys:
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(unused_input_keys))

        if return_elements is None:
            return None
        else:
            ret = []
            for name in return_elements:
                name = compat.as_str(name)
                if ':' in name:
                    try:
                        operation_name, output_index = _ParseTensorName(name)
                        ret.append(
                            name_to_op[operation_name].outputs[output_index])
                    except (ValueError, KeyError, IndexError):
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % name)
                else:
                    try:
                        ret.append(name_to_op[name])
                    except KeyError:
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % name)
            return ret
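A typical call, sketched (hedged: the tensor names "input:0" and "output:0" are hypothetical and must match names in the serialized graph):

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 28], name="x")
    (y,) = tf.import_graph_def(
        graph_def,                       # a GraphDef produced elsewhere
        input_map={"input:0": x},
        return_elements=["output:0"],
        name="imported")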
Example 15
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [compat.as_str(s)
                                   for s in missing_unused_input_keys]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results)

  else:
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImportGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        # Set any default attr values that aren't present.
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]
        for attr_def in op_def.attr:
          key = attr_def.name
          if attr_def.HasField('default_value'):
            value = node.attr[key]
            if value is None or value.WhichOneof('value') is None:
              node.attr[key].CopyFrom(attr_def.default_value)

        output_types = _OutputTypes(node, op_dict)
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False, compute_device=False,
            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError('Specified colocation to an op that '
                                   'does not exist during import: %s in %s' % (
                                       op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(compat.as_bytes(
                    'loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
                s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.'
                      % (input_name,)))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(_InvalidNodeMessage(
                  node, 'More inputs specified (%r) than the op expects.'
                  % (input_name,)))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.'
                        % (input_name,)))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(_InvalidNodeMessage(
                  node, 'Input tensor %r %s' % (input_name, te)))

        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)'
                  % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                     ', '.join(x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else
                [dim.size if dim.size >= 0 else None for dim in dims.dim])

            try:
              output.set_shape(output_shape)
            except ValueError as e:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape.  This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                             'FIFOQueue', 'PriorityQueue', 'QueueSize',
                             'Stack', 'Barrier', 'BarrierReadySize',
                             'BarrierIncompleteSize', 'HashTable',
                             'MutableHashTable',
                             'MutableHashTableOfTensors', 'Mutex',
                             'CuckooTable', 'IndexTable',
                             'WholeFileReader', 'TextLineReader',
                             'FixedLengthRecordReader',
                             'TFRecordReader', 'IdentityReader',
                             'LMDBReader',
                             'RefSwitch', 'RefEnter', 'RefNextIteration',
                             'RefMerge', 'RefIdentity']:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                raise e

          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op.  This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a
        # device, if it exists.  We assume that if multiple ops
        # have devices, they refer to the same device.  Otherwise, a
        # runtime error will occur since the colocation property
        # cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      def _IsImportedNodeOutput(tensor_name):
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(name_to_op[operation_name].outputs)
        except KeyError:
          return False
      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(used_input_keys)
          if not _IsImportedNodeOutput(k)]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]'
            % ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        for name in return_elements:
          name = compat.as_str(name)
          if ':' in name:
            try:
              operation_name, output_index = _ParseTensorName(name)
              ret.append(name_to_op[operation_name].outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
          else:
            try:
              ret.append(name_to_op[name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
        return ret
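The `producer_op_list` path ties back to Example 1: a consumer can pass the producer's stripped op list so that attrs added by newer binaries, still at their default values, are dropped on import. A sketch (hedged: assumes a `MetaGraphDef` with `stripped_op_list` populated):

tf.import_graph_def(
    meta_graph_def.graph_def,
    producer_op_list=meta_graph_def.meta_info_def.stripped_op_list,
    name="")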
Example 16
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    op_dict = op_def_registry.get_registered_ops()

    graph_def = _ProcessGraphDefParam(graph_def, op_dict)
    input_map = _ProcessInputMapParam(input_map)
    return_elements = _ProcessReturnElementsParam(return_elements)

    if producer_op_list is not None:
        # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
        _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

    graph = ops.get_default_graph()

    if graph._c_graph:  # pylint: disable=protected-access
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # Save unique prefix generated by name_scope
            if scope:
                assert scope.endswith('/')
                prefix = scope[:-1]
            else:
                prefix = ''

            # Generate any input map tensors inside name scope
            input_map = _ConvertInputMapValues(name, input_map)

        scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
        options = scoped_options.options
        _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                         return_elements)

        # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
        # call cannot occur between creating the TF_Operations in the
        # TF_GraphImportGraphDefWithResults call and mutating them in
        # _ProcessNewOps.
        with graph._lock:  # pylint: disable=protected-access
            with c_api_util.tf_buffer(
                    graph_def.SerializeToString()) as serialized:
                try:
                    with errors.raise_exception_on_not_ok_status() as status:
                        results = c_api.TF_GraphImportGraphDefWithResults(
                            graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
                except errors.InvalidArgumentError as e:
                    # Convert to ValueError for backwards compatibility.
                    raise ValueError(str(e))

            _ProcessNewOps(graph)

        # Create _DefinedFunctions for any imported functions.
        #
        # We do this by creating _DefinedFunctions directly from `graph_def`, and
        # adding them to `graph`. Adding an existing function to a TF_Graph is a
        # no-op, so this only has the effect of updating the Python state (usually
        # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
        #
        # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
        # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
        if graph_def.library and graph_def.library.function:
            # pylint: disable=protected-access
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(graph)
            # pylint: enable=protected-access

        # Treat input mappings that don't appear in the graph as an error, because
        # they are likely to be due to a typo.
        missing_unused_input_keys = (
            c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
                results))
        if missing_unused_input_keys:
            missing_unused_input_keys = [
                compat.as_str(s) for s in missing_unused_input_keys
            ]
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(missing_unused_input_keys))

        if return_elements is None:
            return None
        else:
            return _GatherReturnElements(return_elements, graph, results)

    else:
        g = graph

        # Use a canonical representation for all tensor names.
        input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
        used_input_keys = set()
        name_to_op = {}

        # Add any functions defined in `graph_def` to `g`
        if graph_def.library and graph_def.library.function:
            # Copy op_dict so we don't clobber the original
            op_dict = copy.copy(op_dict)
            # pylint: disable=protected-access
            # Note that we do not prepend `name` to the function name. The reasoning
            # is that function names are similar to op definition names, which
            # currently do not have a scoped name or namespace scheme.
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(g)
                op_dict[f.name] = f.definition.signature
            # pylint: enable=protected-access

        # LINT.IfChange
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # TODO(ashankar): Should this just copy over or should it do some
            # more nuanced merging? For example, the graph may already have some
            # marked "bad versions" and we don't want to lose those because of
            # what's in graph_def.versions? The C++ ImportGraphDef does something
            # more nuanced.
            g.graph_def_versions.CopyFrom(graph_def.versions)

            input_map = _ConvertInputMapValues(name, input_map)

            # NOTE(mrry): We do this in two passes, because there may be a cycle in
            # `graph_def`.

            # 1. Add operations without their inputs.
            for node in graph_def.node:
                # Check to see if this op's name matches a previously seen op
                if node.name in name_to_op:
                    raise ValueError('Duplicate name \'%s\' in GraphDef.' %
                                     node.name)
                if node.op not in op_dict:
                    raise ValueError('No op named %s in defined operations.' %
                                     node.op)
                op_def = op_dict[node.op]

                output_types = _OutputTypes(node, op_dict)
                name_to_op[node.name] = g.create_op(node.op, [],
                                                    output_types,
                                                    name=node.name,
                                                    attrs=node.attr,
                                                    compute_shapes=False,
                                                    compute_device=False,
                                                    op_def=op_def)

            # Maps from a node to the ops it is colocated with, if colocation
            # is specified in the attributes.
            colocation_pairs = collections.defaultdict(list)

            # 2. Add inputs to the operations.
            for node in graph_def.node:
                op = name_to_op[node.name]
                input_types = _InputTypes(node, op_dict)
                apply_device_function = True

                # Rewrite the colocation attributes in the graph, since the
                # names of new ops may have changed.
                for key, value in op.node_def.attr.items():
                    if key == '_class':
                        class_values = value.list
                        new_class_values = []
                        for class_value in class_values.s:
                            if class_value.startswith(b'loc:@'):
                                op_to_bind_to = class_value[5:].decode()
                                # Find the op by its original name.
                                if op_to_bind_to not in name_to_op:
                                    raise ValueError(
                                        'Specified colocation to an op that '
                                        'does not exist during import: %s in %s'
                                        % (op_to_bind_to, node.name))
                                original_op = name_to_op[op_to_bind_to]
                                new_class_values.append(
                                    compat.as_bytes('loc:@' +
                                                    original_op.name))
                                if op_to_bind_to != node.name:
                                    # Keep track of this mapping for a later phase.
                                    colocation_pairs[op].append(original_op)
                                    # Don't apply this op's device function,
                                    # the colocation constraint will ensure
                                    # the proper device gets assigned at runtime.
                                    apply_device_function = False

                            else:
                                new_class_values.append(class_value)
                        value.list.CopyFrom(
                            attr_value_pb2.AttrValue.ListValue(
                                s=new_class_values))

                # NOTE(mrry): We cannot use zip here because control inputs do not
                # appear in the list of input_types.
                for i, input_name in enumerate(
                    [_CanonicalInputName(x) for x in node.input]):

                    if _IsControlInput(input_name):
                        # (a) Input is a control input that should be taken from an op
                        #     in "graph_def".
                        try:
                            source_op = name_to_op[input_name[1:]]
                        except KeyError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Control input %r not found in graph_def.'
                                    % (input_name, )))
                        # pylint: disable=protected-access
                        op._add_control_input(source_op)
                        # pylint: enable=protected-access

                    else:
                        try:
                            input_type = input_types[i]
                        except IndexError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'More inputs specified (%r) than the op expects.'
                                    % (input_name, )))

                        if input_name in input_map:
                            # (b) Input should be replaced by a tensor from the caller.
                            source_tensor = input_map[input_name]
                            used_input_keys.add(input_name)

                        else:
                            # (c) Input should be taken from an op in `graph_def`.
                            operation_name, output_index = _ParseTensorName(
                                input_name)
                            try:
                                source_op = name_to_op[operation_name]
                                source_tensor = list(
                                    source_op.values())[output_index]
                            except (KeyError, IndexError):
                                raise ValueError(
                                    _InvalidNodeMessage(
                                        node,
                                        'Input tensor %r not found in graph_def.'
                                        % (input_name, )))

                        try:
                            # pylint: disable=protected-access
                            op._add_input(source_tensor, dtype=input_type)
                            # pylint: enable=protected-access
                        except TypeError as te:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r %s' % (input_name, te)))

                # pylint: disable=protected-access
                if op._input_types != input_types:
                    raise ValueError(
                        _InvalidNodeMessage(
                            node,
                            'Input types mismatch (expected %r but got %r)' %
                            (', '.join(
                                dtypes.as_dtype(x).name
                                for x in input_types), ', '.join(
                                    x.name for x in op._input_types))))
                # pylint: enable=protected-access

                if not g._is_function(op.type):  # pylint: disable=protected-access
                    # Execute shape inference for this op.
                    # NOTE(mrry): If the graph contains a cycle, the full shape
                    # information may not be available for this op's inputs.
                    ops.set_shapes_for_outputs(op)
                # For nodes with _output_shapes set, set the output shapes.
                if '_output_shapes' in op.node_def.attr:
                    for i, output in enumerate(op.outputs):
                        dims = op.node_def.attr['_output_shapes'].list.shape[i]
                        output_shape = tensor_shape.TensorShape(
                            None if dims.unknown_rank else [
                                dim.size if dim.size >= 0 else None
                                for dim in dims.dim
                            ])

                        try:
                            output.set_shape(output_shape)
                        except ValueError as e:
                            # If the output shape is incompatible with what is inferred
                            # by the graph for a very specific whitelist of ops, then we
                            # ignore this output shape.  This can happen if there is a
                            # bug in the shape function for some operation, and the
                            # serialized graph def has the incorrect shape set when
                            # running on a newer binary with the fixed shape function.
                            # This is an escape hatch that allows us to correct shape
                            # functions that are not critical to correct execution but
                            # would cause graphs to fail if imported after correcting.
                            #
                            # This can be removed after 2017/03/08.
                            if op.type in [
                                    'RandomShuffleQueue', 'PaddingFIFOQueue',
                                    'FIFOQueue', 'PriorityQueue', 'QueueSize',
                                    'Stack', 'Barrier', 'BarrierReadySize',
                                    'BarrierIncompleteSize', 'HashTable',
                                    'MutableHashTable',
                                    'MutableHashTableOfTensors', 'Mutex',
                                    'CuckooTable', 'IndexTable',
                                    'WholeFileReader', 'TextLineReader',
                                    'FixedLengthRecordReader',
                                    'TFRecordReader', 'IdentityReader',
                                    'LMDBReader', 'RefSwitch', 'RefEnter',
                                    'RefNextIteration', 'RefMerge',
                                    'RefIdentity'
                            ]:
                                pass
                            elif op.type in [
                                    'ConditionalAccumulator',
                                    'SparseConditionalAccumulator', 'Table'
                            ]:
                                # This can be removed after 2017/04/24.
                                pass
                            else:
                                raise e

                    del op.node_def.attr['_output_shapes']

                # NOTE(mrry): We do this after configuring the inputs, because
                # the result of the device functions may depend on the inputs.
                if apply_device_function:
                    with _MaybeDevice(node.device):
                        g._apply_device_functions(op)  # pylint: disable=protected-access

            # The following loop populates the device field of ops that are
            # colocated with another op.  This is implied by the colocation
            # attribute, but we propagate the device field for completeness.
            for op, coloc_op_list in colocation_pairs.items():
                coloc_device = None
                # Find any device in the list of colocated ops that have a
                # device, if it exists.  We assume that if multiple ops
                # have devices, they refer to the same device.  Otherwise, a
                # runtime error will occur since the colocation property
                # cannot be guaranteed.
                #
                # One possible improvement is to try to check for compatibility
                # of all devices in this list at import time here, which would
                # require implementing a compatibility function for device specs
                # in python.
                for coloc_op in coloc_op_list:
                    if coloc_op.device:
                        coloc_device = pydev.DeviceSpec.from_string(
                            coloc_op.device)
                        break
                if coloc_device:
                    op._set_device(coloc_device)  # pylint: disable=protected-access

            # Treat input mappings that don't appear in the graph as an error,
            # because they are likely to be due to a typo.
            def _IsImportedNodeOutput(tensor_name):
                operation_name, output_index = _ParseTensorName(tensor_name)
                try:
                    return output_index < len(
                        name_to_op[operation_name].outputs)
                except KeyError:
                    return False

            absent_input_keys = [
                k for k in frozenset(input_map.keys()).difference(
                    used_input_keys) if not _IsImportedNodeOutput(k)
            ]
            if absent_input_keys:
                raise ValueError(
                    'Attempted to map inputs that were not found in graph_def: [%s]'
                    % ', '.join(absent_input_keys))

            if return_elements is None:
                return None
            else:
                ret = []
                for name in return_elements:
                    name = compat.as_str(name)
                    if ':' in name:
                        try:
                            operation_name, output_index = _ParseTensorName(
                                name)
                            ret.append(name_to_op[operation_name].
                                       outputs[output_index])
                        except (ValueError, KeyError, IndexError):
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                    else:
                        try:
                            ret.append(name_to_op[name])
                        except KeyError:
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                return ret
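
A minimal usage sketch for the API above (the two graphs and tensor names here are hypothetical, assuming TF 1.x graph mode):

import tensorflow as tf

# Producer: build and serialize a small graph.
with tf.Graph().as_default():
    a = tf.placeholder(tf.float32, name='a')
    tf.multiply(a, 2.0, name='b')
    graph_def = tf.get_default_graph().as_graph_def()

# Consumer: import it, remapping input 'a:0' and fetching 'b:0'.
with tf.Graph().as_default():
    x = tf.constant(3.0)
    b_imported, = tf.import_graph_def(
        graph_def, input_map={'a:0': x}, return_elements=['b:0'])
    # The default name prefix applies, so the returned tensor is
    # 'import/b:0'.
    assert b_imported.name == 'import/b:0'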
Example #25
def _get_op_def(op):
  # pylint: disable=protected-access
  if hasattr(op, "_sig"):
    return getattr(op, "_sig")
  else:
    return op_def_registry.get_registered_ops()[op.type]
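
For context, a sketch of what the registry lookup in `_get_op_def` returns (assuming the TF 1.x internal `op_def_registry` module):

from tensorflow.python.framework import op_def_registry

registered = op_def_registry.get_registered_ops()  # dict: op name -> OpDef
add_def = registered['Add']
# An OpDef describes an op's signature: its inputs, outputs, and attrs.
print(add_def.name)                              # 'Add'
print([arg.name for arg in add_def.input_arg])   # ['x', 'y']
print([arg.name for arg in add_def.output_arg])  # ['z']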
Example #26
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(g)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImportGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Check to see if this op's name matches a previously seen op
      if node.name in name_to_op:
        raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # Maps from a node to the ops it is colocated with, if colocation
    # is specified in the attributes.
    colocation_pairs = collections.defaultdict(list)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)
      apply_device_function = True

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
              if op_to_bind_to != node.name:
                # Keep track of this mapping for a later phase.
                colocation_pairs[op].append(original_op)
                # Don't apply this op's device function,
                # the colocation constraint will ensure
                # the proper device gets assigned at runtime.
                apply_device_function = False

            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'LMDBReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        del op.node_def.attr['_output_shapes']

      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      if apply_device_function:
        with _MaybeDevice(node.device):
          g._apply_device_functions(op)  # pylint: disable=protected-access

    # The following loop populates the device field of ops that are
    # colocated with another op.  This is implied by the colocation
    # attribute, but we propagate the device field for completeness.
    for op, coloc_op_list in colocation_pairs.items():
      coloc_device = None
      # Find any device in the list of colocated ops that have a
      # device, if it exists.  We assume that if multiple ops
      # have devices, they refer to the same device.  Otherwise, a
      # runtime error will occur since the colocation property
      # cannot be guaranteed.
      #
      # One possible improvement is to try to check for compatibility
      # of all devices in this list at import time here, which would
      # require implementing a compatibility function for device specs
      # in python.
      for coloc_op in coloc_op_list:
        if coloc_op.device:
          coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
          break
      if coloc_device:
        op._set_device(coloc_device)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
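
The `producer_op_dict` logic above is the consumer half of a forward-compatibility handshake; a hedged sketch of the round trip (the graph is hypothetical, assuming TF 1.x and the internal `meta_graph` module):

import tensorflow as tf
from tensorflow.python.framework import meta_graph

# Producer side: serialize a graph plus the stripped OpList of the
# OpDefs it uses.
with tf.Graph().as_default():
    tf.constant(1.0, name='c')
    graph_def = tf.get_default_graph().as_graph_def()
producer_op_list = meta_graph.stripped_op_list_for_graph(graph_def)

# Consumer side: attrs unknown to this binary that still carry the
# producer's default value are deleted during import, so GraphDefs from
# newer binaries can load on older ones.
with tf.Graph().as_default():
    tf.import_graph_def(graph_def,
                        producer_op_list=producer_op_list,
                        name='compat_import')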
Example #27
def _get_op_def(op):
    return op.op_def or op_def_registry.get_registered_ops()[op.type]
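
A short illustration of the fallback order in `_get_op_def` (the constant is hypothetical, assuming a TF 1.x graph-mode `Operation`):

import tensorflow as tf

with tf.Graph().as_default():
    x = tf.constant(1.0, name='x')
    # Ops built through the normal path carry their OpDef directly, so
    # the registry lookup above only serves as the fallback case.
    assert x.op.op_def.name == 'Const'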
Example #28
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    # Type checks for inputs.
    if not isinstance(graph_def, graph_pb2.GraphDef):
        # `graph_def` could be a dynamically-created message, so try a duck-typed
        # approach
        try:
            old_graph_def = graph_def
            graph_def = graph_pb2.GraphDef()
            graph_def.MergeFrom(old_graph_def)
        except TypeError:
            raise TypeError('graph_def must be a GraphDef proto.')
    if input_map is None:
        input_map = {}
    else:
        if not (isinstance(input_map, dict) and all(
                isinstance(k, compat.bytes_or_text_types)
                for k in input_map.keys())):
            raise TypeError(
                'input_map must be a dictionary mapping strings to '
                'Tensor objects.')
    if return_elements is not None:
        return_elements = tuple(return_elements)
        if not all(
                isinstance(x, compat.bytes_or_text_types)
                for x in return_elements):
            raise TypeError('return_elements must be a list of strings.')

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()

    name_to_op = {}

    if op_dict is None:
        op_dict = op_def_registry.get_registered_ops()

    if producer_op_list is None:
        producer_op_dict = None
    else:
        producer_op_dict = {op.name: op for op in producer_op_list.op}

    with ops.op_scope(input_map.values(), name, 'import'):
        g = ops.get_default_graph()
        g.graph_def_versions.CopyFrom(graph_def.versions)

        with ops.name_scope('_inputs'):
            input_map = {
                k: ops.convert_to_tensor(v)
                for k, v in input_map.items()
            }

        # NOTE(mrry): We do this in two passes, because there may be a cycle in
        # `graph_def`.

        # 1. Add operations without their inputs.
        for node in graph_def.node:
            # Set any default attr values that aren't present.
            op_def = op_dict[node.op]
            for attr_def in op_def.attr:
                key = attr_def.name
                if attr_def.HasField('default_value'):
                    value = node.attr[key]
                    if value is None or value.WhichOneof('value') is None:
                        node.attr[key].CopyFrom(attr_def.default_value)
            if producer_op_dict:
                # Remove any default attr values that aren't in op_def.
                if node.op in producer_op_dict:
                    producer_op_def = producer_op_dict[node.op]
                    # We make a copy of node.attr to iterate through since we
                    # may modify node.attr inside the loop.
                    for key in list(node.attr):
                        if _FindAttrInOpDef(key, op_def) is None:
                            # No attr_def in consumer, look in producer.
                            attr_def = _FindAttrInOpDef(key, producer_op_def)
                            if (attr_def and attr_def.HasField('default_value')
                                    and node.attr[key]
                                    == attr_def.default_value):
                                # Unknown attr had default value in producer, delete it
                                # so it can be understood by consumer.
                                del node.attr[key]

            output_types = _OutputTypes(node, op_dict)
            name_to_op[node.name] = g.create_op(node.op, [],
                                                output_types,
                                                name=node.name,
                                                attrs=node.attr,
                                                compute_shapes=False,
                                                compute_device=False,
                                                op_def=op_def)

        # 2. Add inputs to the operations.
        for node in graph_def.node:
            op = name_to_op[node.name]
            input_types = _InputTypes(node, op_dict)

            # Rewrite the colocation attributes in the graph, since the
            # names of new ops may have changed.
            for key, value in op.node_def.attr.items():
                if key == '_class':
                    class_values = value.list
                    new_class_values = []
                    for class_value in class_values.s:
                        if class_value.startswith(b'loc:@'):
                            op_to_bind_to = class_value[5:].decode()
                            # Find the op by its original name.
                            if op_to_bind_to not in name_to_op:
                                raise ValueError(
                                    'Specified colocation to an op that '
                                    'does not exist during import: %s in %s' %
                                    (op_to_bind_to, node.name))
                            original_op = name_to_op[op_to_bind_to]
                            new_class_values.append(
                                compat.as_bytes('loc:@' + original_op.name))
                        else:
                            new_class_values.append(class_value)
                    value.list.CopyFrom(
                        attr_value_pb2.AttrValue.ListValue(s=new_class_values))

            # NOTE(mrry): We cannot use zip here because control inputs do not appear
            # in the list of input_types.
            for i, input_name in enumerate(
                [_CanonicalInputName(x) for x in node.input]):

                if _IsControlInput(input_name):
                    # (a) Input is a control input that should be taken from an op
                    #     in "graph_def".
                    try:
                        source_op = name_to_op[input_name[1:]]
                    except KeyError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'Control input %r not found in graph_def.' %
                                (input_name, )))
                    # pylint: disable=protected-access
                    op._add_control_input(source_op)
                    # pylint: enable=protected-access

                else:
                    try:
                        input_type = input_types[i]
                    except IndexError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'More inputs specified (%r) than the op expects.'
                                % (input_name, )))

                    if input_name in input_map:
                        # (b) Input should be replaced by a tensor from the caller.
                        source_tensor = input_map[input_name]
                        used_input_keys.add(input_name)

                    else:
                        # (c) Input should be taken from an op in `graph_def`.
                        operation_name, output_index = _ParseTensorName(
                            input_name)
                        try:
                            source_op = name_to_op[operation_name]
                            source_tensor = list(
                                source_op.values())[output_index]
                        except (KeyError, IndexError):
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r not found in graph_def.' %
                                    (input_name, )))

                    try:
                        # pylint: disable=protected-access
                        op._add_input(source_tensor, dtype=input_type)
                        # pylint: enable=protected-access
                    except TypeError as te:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node, 'Input tensor %r %s' % (input_name, te)))

            # pylint: disable=protected-access
            if op._input_dtypes != input_types:
                raise ValueError(
                    _InvalidNodeMessage(
                        node, 'Input types mismatch (expected %r but got %r)' %
                        (', '.join(
                            dtypes.as_dtype(x).name
                            for x in input_types), ', '.join(
                                x.name for x in op._input_dtypes))))
            # pylint: enable=protected-access

            # Execute shape inference for this op.
            # NOTE(mrry): If the graph contains a cycle, the full shape information
            # may not be available for this op's inputs.
            ops.set_shapes_for_outputs(op)
            # For nodes with _output_shapes set, set the output shapes.
            if '_output_shapes' in op.node_def.attr:
                for i, output in enumerate(op.outputs):
                    dims = op.node_def.attr['_output_shapes'].list.shape[i]
                    output_shape = tensor_shape.TensorShape(
                        None if dims.unknown_rank else [
                            dim.size if dim.size >= 0 else None
                            for dim in dims.dim
                        ])
                    output.set_shape(output_shape)
                del op.node_def.attr['_output_shapes']

            # Apply device functions for this op.
            # NOTE(mrry): We do this after configuring the inputs, because
            # the result of the device functions may depend on the inputs.
            with _MaybeDevice(node.device):
                g._apply_device_functions(op)  # pylint: disable=protected-access

        # Treat unused input mappings as an error, because they are likely to be
        # due to a typo.
        unused_input_keys = frozenset(
            input_map.keys()).difference(used_input_keys)
        if unused_input_keys:
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(unused_input_keys))

        if return_elements is None:
            return None
        else:
            ret = []
            for name in return_elements:
                name = compat.as_str(name)
                if ':' in name:
                    try:
                        operation_name, output_index = _ParseTensorName(name)
                        ret.append(
                            name_to_op[operation_name].outputs[output_index])
                    except (ValueError, KeyError, IndexError):
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % name)
                else:
                    try:
                        ret.append(name_to_op[name])
                    except KeyError:
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % name)
            return ret
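
The `_class` rewriting pass above operates on colocation attrs of the following shape; a hedged sketch of building one by hand (the op name is made up):

from tensorflow.core.framework import attr_value_pb2

# A colocation constraint as stored on a NodeDef: the importer scans for
# the b'loc:@' prefix and rewrites the op name that follows it to the
# (possibly prefixed) name assigned during import.
colocation_attr = attr_value_pb2.AttrValue(
    list=attr_value_pb2.AttrValue.ListValue(s=[b'loc:@original_op']))
print(colocation_attr.list.s)  # [b'loc:@original_op']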
Example #29
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        results = c_api.TF_GraphImportGraphDefWithResults(
            graph._c_graph, serialized, options)  # pylint: disable=protected-access
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
    # _USE_C_SHAPES is removed.
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    _ProcessNewOps(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: [%s]' %
        ', '.join(missing_unused_input_keys))

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)
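
Since this version converts C-level `InvalidArgumentError`s into `ValueError`, a sketch of the failure mode callers observe (the element name is made up, assuming TF 1.x):

import tensorflow as tf

with tf.Graph().as_default():
    tf.constant(1.0, name='c')
    graph_def = tf.get_default_graph().as_graph_def()

with tf.Graph().as_default():
    try:
        tf.import_graph_def(graph_def, return_elements=['no_such_op:0'])
    except ValueError as e:
        # Malformed GraphDefs and bad return_elements both surface as
        # ValueError for backwards compatibility.
        print(e)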
Example #30
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, six.string_types) for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if (return_elements is not None
      and not (isinstance(return_elements, (list, tuple))
               and all(isinstance(x, six.string_types)
                       for x in return_elements))):
    raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      output_types = _OutputTypes(node, op_dict)
      with _MaybeDevice(node.device):
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(
                _InvalidNodeMessage(node, 'Input tensor %r %s'
                                    % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(types_lib.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
Example No. 31
def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match the length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of input_shapes must match the number of " +
                     "input_args. len(input_shapes): {} len(input_arg): {}".
                     format(len(input_shapes), len(fdef.signature.input_arg)))

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)

  for node_def in fdef.node_def:
    op_def = op_def_registry.get_registered_ops().get(node_def.op)
    if not op_def:
      # TODO(b/80470245): Support functions which refer to other functions.
      raise NotImplementedError(
          "No op registered for {},".format(node_def.op) +
          " it may be a function. function_def_to_graph_def " +
          "currently does not support converting functions with " +
          "references to other graph functions.")

    for attr in op_def.attr:
      if attr.type in ("func", "list(func)"):
        # TODO(b/80470245): Support functions which refer to other functions.
        raise NotImplementedError("Unsupported attr {}".format(attr.name) +
                                  " with type {}".format(attr.type) +
                                  " in op {}. ".format(op_def.name) +
                                  "function_def_to_graph_def currently does " +
                                  "not support converting functions with " +
                                  "references to other graph functions.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    for i in range(len(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

  return graph_def, nested_to_flat_tensor_name
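A hedged usage sketch for the converter above; `function.Defun` and the `.definition` property are assumed from the TF 1.x `tensorflow.python.framework.function` module, and the nested names shown are illustrative:

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function

@function.Defun(dtypes.float32, dtypes.float32)
def add_fn(x, y):
    return x + y

graph_def, name_map = function_def_to_graph_def(add_fn.definition)
# name_map translates FunctionDef-style tensor names such as 'add:z:0'
# into GraphDef-style names such as 'add:0'.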
Example No. 32
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top


_REGISTERED_OPS = op_def_registry.get_registered_ops()


def enable_jit_nonstateful(node_def):
  try:
    return not _REGISTERED_OPS[node_def.op].is_stateful
  except KeyError:
    raise ValueError("Unregistered op being created: %s" % node_def)


class JITTest(test.TestCase):

  def compute(self, use_jit, compute_fn):
    random_seed.set_random_seed(1234)
    with self.test_session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(use_jit):
Example No. 33
def _stripped_op_list_for_graph(graph_def):
  registered_ops = op_def_registry.get_registered_ops()
  used_ops = {n.op for n in graph_def.node}
  op_list = [registered_ops[op_name] for op_name in sorted(used_ops)]
  return op_def_pb2.OpList(op=op_list)
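A quick usage sketch with illustrative names; note that this minimal variant raises KeyError if the graph uses an op missing from the registry:

import tensorflow as tf

with tf.Graph().as_default() as g:
    tf.constant(1.0, name='c')
op_list = _stripped_op_list_for_graph(g.as_graph_def())
# op_list.op holds the registered OpDef for 'Const'.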
Example No. 34
def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match the length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of input_shapes must match the number of " +
                     "input_args. len(input_shapes): {} len(input_arg): {}".
                     format(len(input_shapes), len(fdef.signature.input_arg)))

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)

  for node_def in fdef.node_def:
    op_def = op_def_registry.get_registered_ops().get(node_def.op)
    if not op_def:
      # TODO(b/80470245): Support functions which refer to other functions.
      raise NotImplementedError(
          "No op registered for {},".format(node_def.op) +
          " it may be a function. function_def_to_graph_def " +
          "currently does not support converting functions with " +
          "references to other graph functions.")

    for attr in op_def.attr:
      if attr.type in ("func", "list(func)"):
        # TODO(b/80470245): Support functions which refer to other functions.
        raise NotImplementedError("Unsupported attr {}".format(attr.name) +
                                  " with type {}".format(attr.type) +
                                  " in op {}. ".format(op_def.name) +
                                  "function_def_to_graph_def currently does " +
                                  "not support converting functions with " +
                                  "references to other graph functions.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    for i in range(len(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

  return graph_def, nested_to_flat_tensor_name
Example No. 35
def Tensor(name=None):
    return ('?', name, util.is_tensor)


def Variable(name=None):
    return ('?', name, util.is_var)


def Const(name=None):
    return ('?', name, util.is_const)


def Placeholder(name=None):
    return ('?', name, util.is_placeholder)


_op_names = op_def_registry.get_registered_ops().keys()
util.import_ops_no_clobber(globals(), _op_names)
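# Hedged note: import_ops_no_clobber presumably binds each registered op name
# (e.g. Add, Mul, Rsqrt) in this module as a pattern object analogous to the
# ('?', name, predicate) helpers above, skipping names already defined here.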

# NOTE(mattjj): renamed in TF 1.0, but not registered as an op in 1.0.1
Unstack = util.make_op_pattern('Unpack')  # pylint: disable=invalid-name

## convenient compound patterns

# The op definitions are pulled in via the op_def_registry, which is
# why we disable the undefined variable check for e.g. Rsqrt, Mul, etc.
# Otherwise we would have to refer to them by name rather than object.
# pylint: disable=undefined-variable


def BatchNorm(in_pattern=Tensor('in'),
              scale_name='scale',
Example No. 36
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()
    g.graph_def_versions.CopyFrom(graph_def.versions)

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(dtypes.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
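One consequence of the colocation rewrite in the example above, with hypothetical names: a node imported with attr `_class = [b'loc:@a']` ends up carrying `b'loc:@import/a'` once node `a` is renamed under the default `import` name scope, so colocation groups remain self-consistent in the importing graph.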
Example No. 37
def _get_op_def(op):
  return op.op_def or op_def_registry.get_registered_ops()[op.type]
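
# Hedged note: `op.op_def` is populated when the Operation was created with an
# explicit OpDef (e.g. create_op(..., op_def=op_def) as in the importers
# above); the registry lookup covers the remaining ops.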
Example No. 38
def _get_op_def(op):
    # pylint: disable=protected-access
    if hasattr(op, "_sig"):
        return getattr(op, "_sig")
    else:
        return op_def_registry.get_registered_ops()[op.type]
Example No. 39
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` and that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImportGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        del op.node_def.attr['_output_shapes']

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
Example No. 40
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    op_dict = op_def_registry.get_registered_ops()

    graph_def = _ProcessGraphDefParam(graph_def, op_dict)
    input_map = _ProcessInputMapParam(input_map)
    return_elements = _ProcessReturnElementsParam(return_elements)

    if producer_op_list is not None:
        # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
        _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

    graph = ops.get_default_graph()
    with ops.name_scope(name, 'import', input_map.values()) as scope:
        # Save unique prefix generated by name_scope
        if scope:
            assert scope.endswith('/')
            prefix = scope[:-1]
        else:
            prefix = ''

        # Generate any input map tensors inside name scope
        input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
    # Session.run call cannot occur between creating the TF_Operations in the
    # TF_GraphImportGraphDefWithResults call and mutating them in
    # _ProcessNewOps.
    with graph._mutation_lock():  # pylint: disable=protected-access
        with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
            try:
                results = c_api.TF_GraphImportGraphDefWithResults(
                    graph._c_graph, serialized, options)  # pylint: disable=protected-access
                results = c_api_util.ScopedTFImportGraphDefResults(results)
            except errors.InvalidArgumentError as e:
                # Convert to ValueError for backwards compatibility.
                raise ValueError(str(e))

        # Create _DefinedFunctions for any imported functions.
        #
        # We do this by creating _DefinedFunctions directly from `graph_def`, and
        # adding them to `graph`. Adding an existing function to a TF_Graph is a
        # no-op, so this only has the effect of updating the Python state (usually
        # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
        #
        # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
        # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
        # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
        # _USE_C_SHAPES is removed.
        if graph_def.library and graph_def.library.function:
            # pylint: disable=protected-access
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(graph)
            # pylint: enable=protected-access

        _ProcessNewOps(graph)

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results.results))
    if missing_unused_input_keys:
        missing_unused_input_keys = [
            compat.as_str(s) for s in missing_unused_input_keys
        ]
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]' %
            ', '.join(missing_unused_input_keys))

    if return_elements is None:
        return None
    else:
        return _GatherReturnElements(return_elements, graph, results.results)
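A hedged sketch of the producer_op_list round trip: the producer strips its op list when exporting, and the consumer passes it at import time so attrs that merely carry producer-side defaults are dropped. `stripped_op_list_for_graph` is assumed from `tensorflow.python.framework.meta_graph`:

import tensorflow as tf
from tensorflow.python.framework import meta_graph

# Producer side: serialize the graph together with its stripped op list.
graph_def = tf.get_default_graph().as_graph_def()
producer_op_list = meta_graph.stripped_op_list_for_graph(graph_def)

# Consumer side (possibly an older binary):
tf.import_graph_def(graph_def, producer_op_list=producer_op_list)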
Example No. 41
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top


_REGISTERED_OPS = op_def_registry.get_registered_ops()


def enable_jit_nonstateful(node_def):
  try:
    return not _REGISTERED_OPS[node_def.op].is_stateful
  except KeyError:
    raise ValueError("Unregistered op being created: %s" % node_def)


class JITTest(test.TestCase):

  def compute(self, use_jit, compute_fn):
    random_seed.set_random_seed(1234)
    with self.test_session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(use_jit):
Example No. 42
def get(name):
  registered_ops = op_def_registry.get_registered_ops()
  return registered_ops.get(name)
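
# Hedged usage note (illustrative names): `get` returns the OpDef proto for a
# registered op, or None for an unknown name.
#   get('MatMul')   # -> op_def_pb2.OpDef describing MatMul
#   get('NoSuchOp') # -> None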