Example 1
def calculate_graph_metrics(graph_def, statistic_types, input_layer,
                            input_shape_override, batch_size):
    """Looks at the performance statistics of all nodes in the graph."""
    _ = tf.import_graph_def(graph_def, name="")
    total_stats = {}
    node_stats = {}
    for statistic_type in statistic_types:
        total_stats[statistic_type] = ops.OpStats(statistic_type)
        node_stats[statistic_type] = {}
    # Make sure we get pretty-printed numbers with separators.
    locale.setlocale(locale.LC_ALL, "")
    with tf.Session() as sess:
        input_tensor = sess.graph.get_tensor_by_name(input_layer)
        input_shape_tensor = input_tensor.get_shape()
        if input_shape_tensor:
            input_shape = input_shape_tensor.as_list()
        else:
            input_shape = None
        if input_shape_override:
            input_shape = input_shape_override
        if input_shape is None:
            raise ValueError(
                "Unknown shape for input %s; pass input_shape_override." %
                input_layer)
        input_shape[0] = batch_size
        input_tensor.set_shape(input_shape)
        for node in graph_def.node:
            # Ensure that the updated input shape has been fully-propagated before we
            # ask for the statistics, since they may depend on the output size.
            op = sess.graph.get_operation_by_name(node.name)
            ops.set_shapes_for_outputs(op)
            for statistic_type in statistic_types:
                current_stats = ops.get_stats_for_node_def(
                    sess.graph, node, statistic_type)
                node_stats[statistic_type][node.name] = current_stats
                total_stats[statistic_type] += current_stats
    return total_stats, node_stats
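
A minimal usage sketch for the function above (a hedged example, not from the original source): it assumes a frozen GraphDef on disk, a statistic type registered via ops.RegisterStatistics such as "flops", and hypothetical file/tensor names.

import tensorflow as tf

# Load a serialized GraphDef (hypothetical path).
with tf.gfile.GFile("frozen_model.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

total_stats, node_stats = calculate_graph_metrics(
    graph_def,
    statistic_types=["flops"],              # assumes "flops" is registered
    input_layer="input:0",                  # hypothetical input tensor name
    input_shape_override=[1, 224, 224, 3],  # hypothetical shape
    batch_size=1)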
Example 2
def calculate_graph_metrics(graph_def, statistic_types, input_layer,
                            input_shape_override, batch_size):
  """Looks at the performance statistics of all nodes in the graph."""
  _ = tf.import_graph_def(graph_def, name="")
  total_stats = {}
  node_stats = {}
  for statistic_type in statistic_types:
    total_stats[statistic_type] = ops.OpStats(statistic_type)
    node_stats[statistic_type] = {}
  # Make sure we get pretty-printed numbers with separators.
  locale.setlocale(locale.LC_ALL, "")
  with tf.Session() as sess:
    input_tensor = sess.graph.get_tensor_by_name(input_layer)
    input_shape_tensor = input_tensor.get_shape()
    if input_shape_tensor:
      input_shape = input_shape_tensor.as_list()
    else:
      input_shape = None
    if input_shape_override:
      input_shape = input_shape_override
    if input_shape is None:
      raise ValueError(
          "Unknown shape for input %s; pass input_shape_override." %
          input_layer)
    input_shape[0] = batch_size
    input_tensor.set_shape(input_shape)
    for node in graph_def.node:
      # Ensure that the updated input shape has been fully-propagated before we
      # ask for the statistics, since they may depend on the output size.
      op = sess.graph.get_operation_by_name(node.name)
      ops.set_shapes_for_outputs(op)
      for statistic_type in statistic_types:
        current_stats = ops.get_stats_for_node_def(sess.graph, node,
                                                   statistic_type)
        node_stats[statistic_type][node.name] = current_stats
        total_stats[statistic_type] += current_stats
  return total_stats, node_stats
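
Because the function calls locale.setlocale for grouped number formatting, the collected totals can be printed the way the graph_metrics tool does. A hedged sketch, assuming OpStats exposes a possibly-None value attribute as in this TF vintage:

import locale

for statistic_type, stat in total_stats.items():
    value = stat.value if stat.value is not None else 0
    print("%s: %s" % (statistic_type,
                      locale.format("%d", value, grouping=True)))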
Example 3
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` and that have their
      default value according to `producer_op_list` will be removed. This
      allows more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImportGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        del op.node_def.attr['_output_shapes']

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
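
A typical call pattern for the importer above: remap the graph's input to a caller-provided tensor and fetch an output by name. "input:0" and "output:0" are hypothetical names; since the input_map values here are already Tensors, an empty name avoids the default "import/" prefix.

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    new_input = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
    # Returns a list matching return_elements; names with ':' yield Tensors.
    output, = tf.import_graph_def(
        graph_def,
        input_map={"input:0": new_input},
        return_elements=["output:0"],
        name="")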
Example 4
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` and that have their
      default value according to `producer_op_list` will be removed. This
      allows more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(g)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImportGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Check to see if this op's name matches a previously seen op
      if node.name in name_to_op:
        raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # Maps from a node to the ops it is colocated with, if colocation
    # is specified in the attributes.
    colocation_pairs = collections.defaultdict(list)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)
      apply_device_function = True

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
              if op_to_bind_to != node.name:
                # Keep track of this mapping for a later phase.
                colocation_pairs[op].append(original_op)
                # Don't apply this op's device function,
                # the colocation constraint will ensure
                # the proper device gets assigned at runtime.
                apply_device_function = False

            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'LMDBReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        del op.node_def.attr['_output_shapes']

      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      if apply_device_function:
        with _MaybeDevice(node.device):
          g._apply_device_functions(op)  # pylint: disable=protected-access

    # The following loop populates the device field of ops that are
    # colocated with another op.  This is implied by the colocation
    # attribute, but we propagate the device field for completeness.
    for op, coloc_op_list in colocation_pairs.items():
      coloc_device = None
      # Find any device in the list of colocated ops that have a
      # device, if it exists.  We assume that if multiple ops
      # have devices, they refer to the same device.  Otherwise, a
      # runtime error will occur since the colocation property
      # cannot be guaranteed.
      #
      # One possible improvement is to try to check for compatibility
      # of all devices in this list at import time here, which would
      # require implementing a compatibility function for device specs
      # in python.
      for coloc_op in coloc_op_list:
        if coloc_op.device:
          coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
          break
      if coloc_device:
        op._set_device(coloc_device)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
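
The return_elements handling at the end distinguishes operation names from tensor names: a bare name yields a tf.Operation, while "name:index" yields a tf.Tensor. A short sketch with hypothetical names:

init_op, logits = tf.import_graph_def(
    graph_def,
    return_elements=["init",       # no ':'       -> tf.Operation
                     "logits:0"])  # 'name:index' -> tf.Tensor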
Example 5
def copy(org_instance,
         dict_swap=None,
         scope="copied",
         replace_itself=False,
         copy_q=False):
    """Build a new node in the TensorFlow graph from `org_instance`,
  where any of its ancestors existing in `dict_swap` are
  replaced with `dict_swap`'s corresponding value.

  The copying is done recursively, so any `Operation` whose output
  is required to evaluate `org_instance` is also copied (if it isn't
  already copied within the new scope). This is with the exception of
  `tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue`, which
  are reused and not newly copied.

  Parameters
  ----------
  org_instance : RandomVariable, tf.Variable, tf.Tensor, or tf.Operation
    Node to add in graph with replaced ancestors.
  dict_swap : dict, optional
    Random variables, variables, tensors, or operations to swap with.
    Its keys are what `org_instance` may depend on, and its values are
    the corresponding object (not necessarily of the same class
    instance, but must have the same type, e.g., float32) that is used
    in exchange.
  scope : str, optional
    A scope for the new node(s). This is used to avoid name
    conflicts with the original node(s).
  replace_itself : bool, optional
    Whether to replace `org_instance` itself if it exists in
    `dict_swap`. (This is used for the recursion.)
  copy_q : bool, optional
    Whether to copy the replaced tensors too (if not already
    copied within the new scope). Otherwise will reuse them.

  Returns
  -------
  RandomVariable, tf.Variable, tf.Tensor, or tf.Operation
    The copied node.

  Raises
  ------
  TypeError
    If `org_instance` is not one of the above types.

  Examples
  --------
  >>> x = tf.constant(2.0)
  >>> y = tf.constant(3.0)
  >>> z = x * y
  >>>
  >>> qx = tf.constant(4.0)
  >>> # The TensorFlow graph is currently
  >>> # `x` -> `z` <- `y`, `qx`
  >>>
  >>> # This adds a subgraph with newly copied nodes,
  >>> # `copied/qx` -> `copied/z` <- `copied/y`
  >>> z_new = copy(z, {x: qx})
  >>>
  >>> sess = tf.Session()
  >>> sess.run(z)
  6.0
  >>> sess.run(z_new)
  12.0
  """
    if not isinstance(org_instance,
                      (RandomVariable, tf.Variable, tf.Tensor, tf.Operation)):
        raise TypeError("Could not copy instance: " + str(org_instance))

    if dict_swap is None:
        dict_swap = {}

    # Swap instance if in dictionary.
    if org_instance in dict_swap and replace_itself:
        org_instance = dict_swap[org_instance]
        if not copy_q:
            return org_instance
    elif isinstance(org_instance, tf.Tensor) and replace_itself:
        # Deal with case when `org_instance` is the associated tensor
        # from the RandomVariable, e.g., `z.value()`. If
        # `dict_swap={z: qz}`, we aim to swap it with `qz.value()`.
        for key, value in six.iteritems(dict_swap):
            if isinstance(key, RandomVariable):
                if org_instance == key.value():
                    if isinstance(value, RandomVariable):
                        org_instance = value.value()
                    else:
                        org_instance = value

                    if not copy_q:
                        return org_instance
                    break

    graph = tf.get_default_graph()
    new_name = scope + '/' + org_instance.name

    # If an instance of the same name exists, return appropriately.
    # Do this for random variables.
    random_variables = {
        x.name: x
        for x in graph.get_collection('_random_variable_collection_')
    }
    if new_name in random_variables:
        return random_variables[new_name]

    # Do this for tensors and operations.
    try:
        already_present = graph.as_graph_element(new_name,
                                                 allow_tensor=True,
                                                 allow_operation=True)
        return already_present
    except (KeyError, ValueError):
        pass

    # If instance is a variable, return it; do not re-copy any.
    # Note we check variables via their name and not their type. This
    # is because if we get variables through an op's inputs, it has
    # type tf.Tensor: we can only tell it is a variable via its name.
    variables = {
        x.name: x
        for x in graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    }
    if org_instance.name in variables:
        return graph.get_tensor_by_name(variables[org_instance.name].name)

    # Do the same for placeholders. Determine via its op's type.
    if isinstance(org_instance, tf.Tensor) and \
            "Placeholder" in org_instance.op.type:
        return org_instance

    if isinstance(org_instance, RandomVariable):
        rv = org_instance

        # If it has copiable arguments, copy them.
        args = []
        for arg in rv._args:
            if isinstance(arg, (RandomVariable, tf.Variable,
                                tf.Tensor, tf.Operation)):
                arg = copy(arg, dict_swap, scope, True, copy_q)

            args.append(arg)

        kwargs = {}
        for key, value in six.iteritems(rv._kwargs):
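            # `copy_rv` is a companion helper (not shown in this snippet);
            # it is assumed to copy graph-type values via `copy` and to
            # return any other value unchanged.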
            if isinstance(value, list):
                kwargs[key] = [
                    copy_rv(v, dict_swap, scope, True, copy_q) for v in value
                ]
            else:
                kwargs[key] = copy_rv(value, dict_swap, scope, True, copy_q)

        kwargs['name'] = new_name
        # Create new random variable with copied arguments.
        new_rv = rv.__class__(*args, **kwargs)
        return new_rv
    elif isinstance(org_instance, tf.Tensor):
        tensor = org_instance

        # A tensor is one of the outputs of its underlying
        # op. Therefore copy the op itself.
        op = tensor.op
        new_op = copy(op, dict_swap, scope, True, copy_q)

        output_index = op.outputs.index(tensor)
        new_tensor = new_op.outputs[output_index]

        # Add copied tensor to collections that the original one is in.
        for name, collection in tensor.graph._collections.items():
            if tensor in collection:
                graph.add_to_collection(name, new_tensor)

        return new_tensor
    elif isinstance(org_instance, tf.Operation):
        op = org_instance

        # Do not copy queue operations
        if 'Queue' in op.type:
            return op

        # If it has an original op, copy it.
        if op._original_op is not None:
            new_original_op = copy(op._original_op, dict_swap, scope, True,
                                   copy_q)
        else:
            new_original_op = None

        # If it has control inputs, copy them.
        new_control_inputs = []
        for x in op.control_inputs:
            elem = copy(x, dict_swap, scope, True, copy_q)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)

            new_control_inputs += [elem]

        # If it has inputs, copy them.
        new_inputs = []
        for x in op.inputs:
            elem = copy(x, dict_swap, scope, True, copy_q)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)

            new_inputs += [elem]

        # Make a copy of the node def.
        # As an instance of tensorflow.core.framework.graph_pb2.NodeDef, it
        # stores string-based info such as name, device, and type of the op.
        # It is unique to every Operation instance.
        new_node_def = deepcopy(op.node_def)
        new_node_def.name = new_name

        # Copy the other inputs needed for initialization.
        output_types = op._output_types[:]
        input_types = op._input_types[:]

        # Make a copy of the op def.
        # It is unique to every Operation type.
        op_def = deepcopy(op.op_def)

        ret = tf.Operation(new_node_def, graph, new_inputs, output_types,
                           new_control_inputs, input_types, new_original_op,
                           op_def)

        # Use Graph's private methods to add the op, following
        # implementation of `tf.Graph().create_op()`.
        compute_shapes = True
        compute_device = True
        op_type = op.type  # the registered op type name, as in Graph.create_op

        if compute_shapes:
            set_shapes_for_outputs(ret)
        graph._add_op(ret)
        graph._record_op_seen_by_control_dependencies(ret)

        if compute_device:
            graph._apply_device_functions(ret)

        if graph._colocation_stack:
            all_colocation_groups = []
            for colocation_op in graph._colocation_stack:
                all_colocation_groups.extend(colocation_op.colocation_groups())
                if colocation_op.device:
                    # Make this device match the device of the colocated op, to
                    # provide consistency between the device and the colocation
                    # property.
                    if ret.device and ret.device != colocation_op.device:
                        logging.warning(
                            "Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
                            "Ignoring colocation property.", name,
                            colocation_op.name, ret.device,
                            colocation_op.device)
                    else:
                        ret._set_device(colocation_op.device)

            all_colocation_groups = sorted(set(all_colocation_groups))
            ret.node_def.attr["_class"].CopyFrom(
                attr_value_pb2.AttrValue(
                    list=attr_value_pb2.AttrValue.ListValue(
                        s=all_colocation_groups)))

        # Sets "container" attribute if
        # (1) graph._container is not None
        # (2) "is_stateful" is set in OpDef
        # (3) "container" attribute is in OpDef
        # (4) "container" attribute is None
        if (graph._container and op_type in graph._registered_ops
                and graph._registered_ops[op_type].is_stateful
                and "container" in ret.node_def.attr
                and not ret.node_def.attr["container"].s):
            ret.node_def.attr["container"].CopyFrom(
                attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container)))

        return ret
    else:
        raise TypeError("Could not copy instance: " + str(org_instance))
Example 6
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None):
    """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    # Type checks for inputs.
    if not isinstance(graph_def, graph_pb2.GraphDef):
        # `graph_def` could be a dynamically-created message, so try a duck-typed
        # approach
        try:
            old_graph_def = graph_def
            graph_def = graph_pb2.GraphDef()
            graph_def.MergeFrom(old_graph_def)
        except TypeError:
            raise TypeError('graph_def must be a GraphDef proto.')
    if input_map is None:
        input_map = {}
    else:
        if not (isinstance(input_map, dict) and all(
                isinstance(k, compat.bytes_or_text_types)
                for k in input_map.keys())):
            raise TypeError(
                'input_map must be a dictionary mapping strings to '
                'Tensor objects.')
    if return_elements is not None:
        return_elements = tuple(return_elements)
        if not all(
                isinstance(x, compat.bytes_or_text_types)
                for x in return_elements):
            raise TypeError('return_elements must be a list of strings.')

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()

    name_to_op = {}

    if op_dict is None:
        op_dict = op_def_registry.get_registered_ops()

    with ops.op_scope(input_map.values(), name, 'import'):
        g = ops.get_default_graph()
        g.graph_def_versions.CopyFrom(graph_def.versions)

        with ops.name_scope('_inputs'):
            input_map = {
                k: ops.convert_to_tensor(v)
                for k, v in input_map.items()
            }

        # NOTE(mrry): We do this in two passes, because there may be a cycle in
        # `graph_def`.

        # 1. Add operations without their inputs.
        for node in graph_def.node:
            # Set any default attr values that aren't present.
            op_def = op_dict[node.op]
            for attr_def in op_def.attr:
                key = attr_def.name
                if attr_def.HasField('default_value'):
                    value = node.attr[key]
                    if value is None or value.WhichOneof('value') is None:
                        node.attr[key].CopyFrom(attr_def.default_value)

            output_types = _OutputTypes(node, op_dict)
            name_to_op[node.name] = g.create_op(node.op, [],
                                                output_types,
                                                name=node.name,
                                                attrs=node.attr,
                                                compute_shapes=False,
                                                compute_device=False)

        # 2. Add inputs to the operations.
        for node in graph_def.node:
            op = name_to_op[node.name]
            input_types = _InputTypes(node, op_dict)

            # NOTE(mrry): We cannot use zip here because control inputs do not appear
            # in the list of input_types.
            for i, input_name in enumerate(
                [_CanonicalInputName(x) for x in node.input]):

                if _IsControlInput(input_name):
                    # (a) Input is a control input that should be taken from an op
                    #     in "graph_def".
                    try:
                        source_op = name_to_op[input_name[1:]]
                    except KeyError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'Control input %r not found in graph_def.' %
                                (input_name, )))
                    # pylint: disable=protected-access
                    op._add_control_input(source_op)
                    # pylint: enable=protected-access

                else:
                    try:
                        input_type = input_types[i]
                    except IndexError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'More inputs specified (%r) than the op expects.'
                                % (input_name, )))

                    if input_name in input_map:
                        # (b) Input should be replaced by a tensor from the caller.
                        source_tensor = input_map[input_name]
                        used_input_keys.add(input_name)

                    else:
                        # (c) Input should be taken from an op in `graph_def`.
                        operation_name, output_index = _ParseTensorName(
                            input_name)
                        try:
                            source_op = name_to_op[operation_name]
                            source_tensor = list(
                                source_op.values())[output_index]
                        except (KeyError, IndexError):
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r not found in graph_def.' %
                                    (input_name, )))

                    try:
                        # pylint: disable=protected-access
                        op._add_input(source_tensor, dtype=input_type)
                        # pylint: enable=protected-access
                    except TypeError as te:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node, 'Input tensor %r %s' % (input_name, te)))

            # pylint: disable=protected-access
            if op._input_dtypes != input_types:
                raise ValueError(
                    _InvalidNodeMessage(
                        node, 'Input types mismatch (expected %r but got %r)' %
                        (", ".join(
                            dtypes.as_dtype(x).name
                            for x in input_types), ", ".join(
                                x.name for x in op._input_dtypes))))
            # pylint: enable=protected-access

            # Execute shape inference for this op.
            # NOTE(mrry): If the graph contains a cycle, the full shape information
            # may not be available for this op's inputs.
            ops.set_shapes_for_outputs(op)

            # Apply device functions for this op.
            # NOTE(mrry): We do this after configuring the inputs, because
            # the result of the device functions may depend on the inputs.
            with _MaybeDevice(node.device):
                g._apply_device_functions(op)  # pylint: disable=protected-access

        # Treat unused input mappings as an error, because they are likely to be
        # due to a typo.
        unused_input_keys = frozenset(
            input_map.keys()).difference(used_input_keys)
        if unused_input_keys:
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(unused_input_keys))

        if return_elements is None:
            return None
        else:
            ret = []
            for name in return_elements:
                name = compat.as_str(name)
                if ':' in name:
                    try:
                        operation_name, output_index = _ParseTensorName(name)
                        ret.append(
                            name_to_op[operation_name].outputs[output_index])
                    except (ValueError, KeyError, IndexError):
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % name)
                else:
                    try:
                        ret.append(name_to_op[name])
                    except KeyError:
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % name)
            return ret
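
The unused-key check near the end means every input_map entry must match something in graph_def, so a misspelled mapping fails fast. A small sketch with a hypothetical name:

# Raises ValueError: 'Attempted to map inputs that were not found in graph_def'
tf.import_graph_def(
    graph_def,
    input_map={"no_such_tensor:0": tf.constant(1.0)})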
Example 7
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [compat.as_str(s)
                                   for s in missing_unused_input_keys]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results)

  else:
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImportGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        # Set any default attr values that aren't present.
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]
        for attr_def in op_def.attr:
          key = attr_def.name
          if attr_def.HasField('default_value'):
            value = node.attr[key]
            if value is None or value.WhichOneof('value') is None:
              node.attr[key].CopyFrom(attr_def.default_value)

        output_types = _OutputTypes(node, op_dict)
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False, compute_device=False,
            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError('Specified colocation to an op that '
                                   'does not exist during import: %s in %s' % (
                                       op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(compat.as_bytes(
                    'loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
                s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.'
                      % (input_name,)))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(_InvalidNodeMessage(
                  node, 'More inputs specified (%r) than the op expects.'
                  % (input_name,)))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.'
                        % (input_name,)))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(_InvalidNodeMessage(
                  node, 'Input tensor %r %s' % (input_name, te)))

        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)'
                  % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                     ', '.join(x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else
                [dim.size if dim.size >= 0 else None for dim in dims.dim])

            try:
              output.set_shape(output_shape)
            except ValueError as e:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape.  This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                             'FIFOQueue', 'PriorityQueue', 'QueueSize',
                             'Stack', 'Barrier', 'BarrierReadySize',
                             'BarrierIncompleteSize', 'HashTable',
                             'MutableHashTable',
                             'MutableHashTableOfTensors', 'Mutex',
                             'CuckooTable', 'IndexTable',
                             'WholeFileReader', 'TextLineReader',
                             'FixedLengthRecordReader',
                             'TFRecordReader', 'IdentityReader',
                             'LMDBReader',
                             'RefSwitch', 'RefEnter', 'RefNextIteration',
                             'RefMerge', 'RefIdentity']:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                raise e

          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op.  This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a
        # device, if it exists.  We assume that if multiple ops
        # have devices, they refer to the same device.  Otherwise, a
        # runtime error will occur since the colocation property
        # cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      def _IsImportedNodeOutput(tensor_name):
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(name_to_op[operation_name].outputs)
        except KeyError:
          return False
      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(used_input_keys)
          if not _IsImportedNodeOutput(k)]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]'
            % ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        for name in return_elements:
          name = compat.as_str(name)
          if ':' in name:
            try:
              operation_name, output_index = _ParseTensorName(name)
              ret.append(name_to_op[operation_name].outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
          else:
            try:
              ret.append(name_to_op[name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
        return ret
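
Branch (b) in the input loop above is what makes `input_map` work: when a canonicalized input name appears in the map, the caller's tensor is wired in instead of the one from `graph_def`, and the key is recorded in `used_input_keys` so that unmatched keys can be reported as errors afterwards. A hedged sketch of the observable behavior (TF 1.x assumed):

import tensorflow as tf

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[], name='x')
    y = tf.multiply(x, 3.0, name='y')
    graph_def = tf.get_default_graph().as_graph_def()

with tf.Graph().as_default():
    new_x = tf.constant(2.0, name='new_x')
    # Replace the imported placeholder's output 'x:0' with `new_x`.
    y_out, = tf.import_graph_def(
        graph_def, input_map={'x:0': new_x}, return_elements=['y:0'])
    with tf.Session() as sess:
        print(sess.run(y_out))  # 6.0 -- no feed needed; 'x:0' was remapped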
Example no. 8
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    op_dict = op_def_registry.get_registered_ops()

    graph_def = _ProcessGraphDefParam(graph_def, op_dict)
    input_map = _ProcessInputMapParam(input_map)
    return_elements = _ProcessReturnElementsParam(return_elements)

    if producer_op_list is not None:
        # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
        _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

    graph = ops.get_default_graph()

    if graph._c_graph:  # pylint: disable=protected-access
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # Save unique prefix generated by name_scope
            if scope:
                assert scope.endswith('/')
                prefix = scope[:-1]
            else:
                prefix = ''

            # Generate any input map tensors inside name scope
            input_map = _ConvertInputMapValues(name, input_map)

        scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
        options = scoped_options.options
        _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                         return_elements)

        # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
        # call cannot occur between creating the TF_Operations in the
        # TF_GraphImportGraphDefWithResults call and mutating them in
        # _ProcessNewOps.
        with graph._lock:  # pylint: disable=protected-access
            with c_api_util.tf_buffer(
                    graph_def.SerializeToString()) as serialized:
                try:
                    with errors.raise_exception_on_not_ok_status() as status:
                        results = c_api.TF_GraphImportGraphDefWithResults(
                            graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
                except errors.InvalidArgumentError as e:
                    # Convert to ValueError for backwards compatibility.
                    raise ValueError(str(e))

            _ProcessNewOps(graph)

        # Create _DefinedFunctions for any imported functions.
        #
        # We do this by creating _DefinedFunctions directly from `graph_def`, and
        # adding them to `graph`. Adding an existing function to a TF_Graph is a
        # no-op, so this only has the effect of updating the Python state (usually
        # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
        #
        # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
        # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
        if graph_def.library and graph_def.library.function:
            # pylint: disable=protected-access
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(graph)
            # pylint: enable=protected-access

        # Treat input mappings that don't appear in the graph as an error, because
        # they are likely to be due to a typo.
        missing_unused_input_keys = (
            c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
                results))
        if missing_unused_input_keys:
            missing_unused_input_keys = [
                compat.as_str(s) for s in missing_unused_input_keys
            ]
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(missing_unused_input_keys))

        if return_elements is None:
            return None
        else:
            return _GatherReturnElements(return_elements, graph, results)

    else:
        g = graph

        # Use a canonical representation for all tensor names.
        input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
        used_input_keys = set()
        name_to_op = {}

        # Add any functions defined in `graph_def` to `g`
        if graph_def.library and graph_def.library.function:
            # Copy op_dict so we don't clobber the original
            op_dict = copy.copy(op_dict)
            # pylint: disable=protected-access
            # Note that we do not prepend `name` to the function name. The reasoning
            # is that function names are similar to op definition names, which
            # currently do not have a scoped name or namespace scheme.
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(g)
                op_dict[f.name] = f.definition.signature
            # pylint: enable=protected-access

        # LINT.IfChange
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # TODO(ashankar): Should this just copy over or should it do some
            # more nuanced merging? For example, the graph may already have some
            # marked "bad versions" and we don't want to lose those because of
            # what's in graph_def.versions? The C++ ImportGraphDef does something
            # more nuanced.
            g.graph_def_versions.CopyFrom(graph_def.versions)

            input_map = _ConvertInputMapValues(name, input_map)

            # NOTE(mrry): We do this in two passes, because there may be a cycle in
            # `graph_def`.

            # 1. Add operations without their inputs.
            for node in graph_def.node:
                # Check to see if this op's name matches a previously seen op
                if node.name in name_to_op:
                    raise ValueError('Duplicate name \'%s\' in GraphDef.' %
                                     node.name)
                if node.op not in op_dict:
                    raise ValueError('No op named %s in defined operations.' %
                                     node.op)
                op_def = op_dict[node.op]

                output_types = _OutputTypes(node, op_dict)
                name_to_op[node.name] = g.create_op(node.op, [],
                                                    output_types,
                                                    name=node.name,
                                                    attrs=node.attr,
                                                    compute_shapes=False,
                                                    compute_device=False,
                                                    op_def=op_def)

            # Maps from a node to the ops it is colocated with, if colocation
            # is specified in the attributes.
            colocation_pairs = collections.defaultdict(list)

            # 2. Add inputs to the operations.
            for node in graph_def.node:
                op = name_to_op[node.name]
                input_types = _InputTypes(node, op_dict)
                apply_device_function = True

                # Rewrite the colocation attributes in the graph, since the
                # names of new ops may have changed.
                for key, value in op.node_def.attr.items():
                    if key == '_class':
                        class_values = value.list
                        new_class_values = []
                        for class_value in class_values.s:
                            if class_value.startswith(b'loc:@'):
                                op_to_bind_to = class_value[5:].decode()
                                # Find the op by its original name.
                                if op_to_bind_to not in name_to_op:
                                    raise ValueError(
                                        'Specified colocation to an op that '
                                        'does not exist during import: %s in %s'
                                        % (op_to_bind_to, node.name))
                                original_op = name_to_op[op_to_bind_to]
                                new_class_values.append(
                                    compat.as_bytes('loc:@' +
                                                    original_op.name))
                                if op_to_bind_to != node.name:
                                    # Keep track of this mapping for a later phase.
                                    colocation_pairs[op].append(original_op)
                                    # Don't apply this op's device function,
                                    # the colocation constraint will ensure
                                    # the proper device gets assigned at runtime.
                                    apply_device_function = False

                            else:
                                new_class_values.append(class_value)
                        value.list.CopyFrom(
                            attr_value_pb2.AttrValue.ListValue(
                                s=new_class_values))

                # NOTE(mrry): We cannot use zip here because control inputs do not
                # appear in the list of input_types.
                for i, input_name in enumerate(
                    [_CanonicalInputName(x) for x in node.input]):

                    if _IsControlInput(input_name):
                        # (a) Input is a control input that should be taken from an op
                        #     in "graph_def".
                        try:
                            source_op = name_to_op[input_name[1:]]
                        except KeyError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Control input %r not found in graph_def.'
                                    % (input_name, )))
                        # pylint: disable=protected-access
                        op._add_control_input(source_op)
                        # pylint: enable=protected-access

                    else:
                        try:
                            input_type = input_types[i]
                        except IndexError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'More inputs specified (%r) than the op expects.'
                                    % (input_name, )))

                        if input_name in input_map:
                            # (b) Input should be replaced by a tensor from the caller.
                            source_tensor = input_map[input_name]
                            used_input_keys.add(input_name)

                        else:
                            # (c) Input should be taken from an op in `graph_def`.
                            operation_name, output_index = _ParseTensorName(
                                input_name)
                            try:
                                source_op = name_to_op[operation_name]
                                source_tensor = list(
                                    source_op.values())[output_index]
                            except (KeyError, IndexError):
                                raise ValueError(
                                    _InvalidNodeMessage(
                                        node,
                                        'Input tensor %r not found in graph_def.'
                                        % (input_name, )))

                        try:
                            # pylint: disable=protected-access
                            op._add_input(source_tensor, dtype=input_type)
                            # pylint: enable=protected-access
                        except TypeError as te:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r %s' % (input_name, te)))

                # pylint: disable=protected-access
                if op._input_types != input_types:
                    raise ValueError(
                        _InvalidNodeMessage(
                            node,
                            'Input types mismatch (expected %r but got %r)' %
                            (', '.join(
                                dtypes.as_dtype(x).name
                                for x in input_types), ', '.join(
                                    x.name for x in op._input_types))))
                # pylint: enable=protected-access

                if not g._is_function(op.type):  # pylint: disable=protected-access
                    # Execute shape inference for this op.
                    # NOTE(mrry): If the graph contains a cycle, the full shape
                    # information may not be available for this op's inputs.
                    ops.set_shapes_for_outputs(op)
                # For nodes with _output_shapes set, set the output shapes.
                if '_output_shapes' in op.node_def.attr:
                    for i, output in enumerate(op.outputs):
                        dims = op.node_def.attr['_output_shapes'].list.shape[i]
                        output_shape = tensor_shape.TensorShape(
                            None if dims.unknown_rank else [
                                dim.size if dim.size >= 0 else None
                                for dim in dims.dim
                            ])

                        try:
                            output.set_shape(output_shape)
                        except ValueError as e:
                            # If the output shape is incompatible with what is inferred
                            # by the graph for a very specific whitelist of ops, then we
                            # ignore this output shape.  This can happen if there is a
                            # bug in the shape function for some operation, and the
                            # serialized graph def has the incorrect shape set when
                            # running on a newer binary with the fixed shape function.
                            # This is an escape hatch that allows us to correct shape
                            # functions that are not critical to correct execution but
                            # would cause graphs to fail if imported after correcting.
                            #
                            # This can be removed after 2017/03/08.
                            if op.type in [
                                    'RandomShuffleQueue', 'PaddingFIFOQueue',
                                    'FIFOQueue', 'PriorityQueue', 'QueueSize',
                                    'Stack', 'Barrier', 'BarrierReadySize',
                                    'BarrierIncompleteSize', 'HashTable',
                                    'MutableHashTable',
                                    'MutableHashTableOfTensors', 'Mutex',
                                    'CuckooTable', 'IndexTable',
                                    'WholeFileReader', 'TextLineReader',
                                    'FixedLengthRecordReader',
                                    'TFRecordReader', 'IdentityReader',
                                    'LMDBReader', 'RefSwitch', 'RefEnter',
                                    'RefNextIteration', 'RefMerge',
                                    'RefIdentity'
                            ]:
                                pass
                            elif op.type in [
                                    'ConditionalAccumulator',
                                    'SparseConditionalAccumulator', 'Table'
                            ]:
                                # This can be removed after 2017/04/24.
                                pass
                            else:
                                raise e

                    del op.node_def.attr['_output_shapes']

                # NOTE(mrry): We do this after configuring the inputs, because
                # the result of the device functions may depend on the inputs.
                if apply_device_function:
                    with _MaybeDevice(node.device):
                        g._apply_device_functions(op)  # pylint: disable=protected-access

            # The following loop populates the device field of ops that are
            # colocated with another op.  This is implied by the colocation
            # attribute, but we propagate the device field for completeness.
            for op, coloc_op_list in colocation_pairs.items():
                coloc_device = None
                # Find any device in the list of colocated ops that have a
                # device, if it exists.  We assume that if multiple ops
                # have devices, they refer to the same device.  Otherwise, a
                # runtime error will occur since the colocation property
                # cannot be guaranteed.
                #
                # One possible improvement is to try to check for compatibility
                # of all devices in this list at import time here, which would
                # require implementing a compatibility function for device specs
                # in python.
                for coloc_op in coloc_op_list:
                    if coloc_op.device:
                        coloc_device = pydev.DeviceSpec.from_string(
                            coloc_op.device)
                        break
                if coloc_device:
                    op._set_device(coloc_device)  # pylint: disable=protected-access

            # Treat input mappings that don't appear in the graph as an error,
            # because they are likely to be due to a typo.
            def _IsImportedNodeOutput(tensor_name):
                operation_name, output_index = _ParseTensorName(tensor_name)
                try:
                    return output_index < len(
                        name_to_op[operation_name].outputs)
                except KeyError:
                    return False

            absent_input_keys = [
                k for k in frozenset(input_map.keys()).difference(
                    used_input_keys) if not _IsImportedNodeOutput(k)
            ]
            if absent_input_keys:
                raise ValueError(
                    'Attempted to map inputs that were not found in graph_def: [%s]'
                    % ', '.join(absent_input_keys))

            if return_elements is None:
                return None
            else:
                ret = []
                for name in return_elements:
                    name = compat.as_str(name)
                    if ':' in name:
                        try:
                            operation_name, output_index = _ParseTensorName(
                                name)
                            ret.append(name_to_op[operation_name].
                                       outputs[output_index])
                        except (ValueError, KeyError, IndexError):
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                    else:
                        try:
                            ret.append(name_to_op[name])
                        except KeyError:
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                return ret
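
Both import paths above converge on the same typo guard: the C API branch collects `missing_unused_input_keys` via `TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper`, the pure-Python branch computes `absent_input_keys`, and either way an `input_map` key that never matched anything in the `GraphDef` raises `ValueError`. For example (sketch, TF 1.x):

import tensorflow as tf

with tf.Graph().as_default():
    tf.constant(0.0, name='c')
    graph_def = tf.get_default_graph().as_graph_def()

with tf.Graph().as_default():
    replacement = tf.constant(1.0)
    try:
        # 'no_such:0' does not name any tensor in graph_def.
        tf.import_graph_def(graph_def, input_map={'no_such:0': replacement})
    except ValueError as e:
        print(e)  # Attempted to map inputs that were not found in graph_def: ...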
Example no. 9
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()
    g.graph_def_versions.CopyFrom(graph_def.versions)

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(dtypes.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
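
The `_class` rewrite in the snippet above exists because importing under a name scope renames every op, so a serialized colocation constraint such as `loc:@a` would otherwise point at a name that no longer exists. A sketch of the observable effect (TF 1.x assumed; `tf.colocate_with` is what sets the `_class` attr at graph-construction time):

import tensorflow as tf

with tf.Graph().as_default():
    a = tf.constant(1.0, name='a')
    with tf.colocate_with(a.op):
        b = tf.identity(a, name='b')
    graph_def = tf.get_default_graph().as_graph_def()

with tf.Graph().as_default():
    b_op, = tf.import_graph_def(graph_def, return_elements=['b'])
    # The serialized constraint was loc:@a; after import it refers to the
    # renamed op, e.g. [b'loc:@import/a'].
    print(b_op.colocation_groups())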
Example no. 10
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
    for ops in `graph_def` that are not in `op_dict` and that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    # Type checks for inputs.
    if not isinstance(graph_def, graph_pb2.GraphDef):
        # `graph_def` could be a dynamically-created message, so try a duck-typed
        # approach
        try:
            old_graph_def = graph_def
            graph_def = graph_pb2.GraphDef()
            graph_def.MergeFrom(old_graph_def)
        except TypeError:
            raise TypeError('graph_def must be a GraphDef proto.')
    if input_map is None:
        input_map = {}
    else:
        if not (isinstance(input_map, dict) and all(
                isinstance(k, compat.bytes_or_text_types)
                for k in input_map.keys())):
            raise TypeError(
                'input_map must be a dictionary mapping strings to '
                'Tensor objects.')
    if return_elements is not None:
        return_elements = tuple(return_elements)
        if not all(
                isinstance(x, compat.bytes_or_text_types)
                for x in return_elements):
            raise TypeError('return_elements must be a list of strings.')

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()

    name_to_op = {}

    if op_dict is None:
        op_dict = op_def_registry.get_registered_ops()

    if producer_op_list is None:
        producer_op_dict = None
    else:
        producer_op_dict = {op.name: op for op in producer_op_list.op}

    with ops.op_scope(input_map.values(), name, 'import'):
        g = ops.get_default_graph()
        g.graph_def_versions.CopyFrom(graph_def.versions)

        with ops.name_scope('_inputs'):
            input_map = {
                k: ops.convert_to_tensor(v)
                for k, v in input_map.items()
            }

        # NOTE(mrry): We do this in two passes, because there may be a cycle in
        # `graph_def`.

        # 1. Add operations without their inputs.
        for node in graph_def.node:
            # Set any default attr values that aren't present.
            op_def = op_dict[node.op]
            for attr_def in op_def.attr:
                key = attr_def.name
                if attr_def.HasField('default_value'):
                    value = node.attr[key]
                    if value is None or value.WhichOneof('value') is None:
                        node.attr[key].CopyFrom(attr_def.default_value)
            if producer_op_dict:
                # Remove any default attr values that aren't in op_def.
                if node.op in producer_op_dict:
                    producer_op_def = producer_op_dict[node.op]
                    # We make a copy of node.attr to iterate through since we
                    # may modify node.attr inside the loop.
                    for key in list(node.attr):
                        if _FindAttrInOpDef(key, op_def) is None:
                            # No attr_def in consumer, look in producer.
                            attr_def = _FindAttrInOpDef(key, producer_op_def)
                            if (attr_def and attr_def.HasField('default_value')
                                    and node.attr[key]
                                    == attr_def.default_value):
                                # Unknown attr had default value in producer, delete it
                                # so it can be understood by consumer.
                                del node.attr[key]

            output_types = _OutputTypes(node, op_dict)
            name_to_op[node.name] = g.create_op(node.op, [],
                                                output_types,
                                                name=node.name,
                                                attrs=node.attr,
                                                compute_shapes=False,
                                                compute_device=False,
                                                op_def=op_def)

        # 2. Add inputs to the operations.
        for node in graph_def.node:
            op = name_to_op[node.name]
            input_types = _InputTypes(node, op_dict)

            # Rewrite the colocation attributes in the graph, since the
            # names of new ops may have changed.
            for key, value in op.node_def.attr.items():
                if key == '_class':
                    class_values = value.list
                    new_class_values = []
                    for class_value in class_values.s:
                        if class_value.startswith(b'loc:@'):
                            op_to_bind_to = class_value[5:].decode()
                            # Find the op by its original name.
                            if op_to_bind_to not in name_to_op:
                                raise ValueError(
                                    'Specified colocation to an op that '
                                    'does not exist during import: %s in %s' %
                                    (op_to_bind_to, node.name))
                            original_op = name_to_op[op_to_bind_to]
                            new_class_values.append(
                                compat.as_bytes('loc:@' + original_op.name))
                        else:
                            new_class_values.append(class_value)
                    value.list.CopyFrom(
                        attr_value_pb2.AttrValue.ListValue(s=new_class_values))

            # NOTE(mrry): We cannot use zip here because control inputs do not appear
            # in the list of input_types.
            for i, input_name in enumerate(
                [_CanonicalInputName(x) for x in node.input]):

                if _IsControlInput(input_name):
                    # (a) Input is a control input that should be taken from an op
                    #     in "graph_def".
                    try:
                        source_op = name_to_op[input_name[1:]]
                    except KeyError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'Control input %r not found in graph_def.' %
                                (input_name, )))
                    # pylint: disable=protected-access
                    op._add_control_input(source_op)
                    # pylint: enable=protected-access

                else:
                    try:
                        input_type = input_types[i]
                    except IndexError:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node,
                                'More inputs specified (%r) than the op expects.'
                                % (input_name, )))

                    if input_name in input_map:
                        # (b) Input should be replaced by a tensor from the caller.
                        source_tensor = input_map[input_name]
                        used_input_keys.add(input_name)

                    else:
                        # (c) Input should be taken from an op in `graph_def`.
                        operation_name, output_index = _ParseTensorName(
                            input_name)
                        try:
                            source_op = name_to_op[operation_name]
                            source_tensor = list(
                                source_op.values())[output_index]
                        except (KeyError, IndexError):
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r not found in graph_def.' %
                                    (input_name, )))

                    try:
                        # pylint: disable=protected-access
                        op._add_input(source_tensor, dtype=input_type)
                        # pylint: enable=protected-access
                    except TypeError as te:
                        raise ValueError(
                            _InvalidNodeMessage(
                                node, 'Input tensor %r %s' % (input_name, te)))

            # pylint: disable=protected-access
            if op._input_dtypes != input_types:
                raise ValueError(
                    _InvalidNodeMessage(
                        node, 'Input types mismatch (expected %r but got %r)' %
                        (', '.join(
                            dtypes.as_dtype(x).name
                            for x in input_types), ', '.join(
                                x.name for x in op._input_dtypes))))
            # pylint: enable=protected-access

            # Execute shape inference for this op.
            # NOTE(mrry): If the graph contains a cycle, the full shape information
            # may not be available for this op's inputs.
            ops.set_shapes_for_outputs(op)
            # For nodes with _output_shapes set, set the output shapes.
            if '_output_shapes' in op.node_def.attr:
                for i, output in enumerate(op.outputs):
                    dims = op.node_def.attr['_output_shapes'].list.shape[i]
                    output_shape = tensor_shape.TensorShape(
                        None if dims.unknown_rank else [
                            dim.size if dim.size >= 0 else None
                            for dim in dims.dim
                        ])
                    output.set_shape(output_shape)
                del op.node_def.attr['_output_shapes']

            # Apply device functions for this op.
            # NOTE(mrry): We do this after configuring the inputs, because
            # the result of the device functions may depend on the inputs.
            with _MaybeDevice(node.device):
                g._apply_device_functions(op)  # pylint: disable=protected-access

        # Treat unused input mappings as an error, because they are likely to be
        # due to a typo.
        unused_input_keys = frozenset(
            input_map.keys()).difference(used_input_keys)
        if unused_input_keys:
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(unused_input_keys))

        if return_elements is None:
            return None
        else:
            ret = []
            for name in return_elements:
                name = compat.as_str(name)
                if ':' in name:
                    try:
                        operation_name, output_index = _ParseTensorName(name)
                        ret.append(
                            name_to_op[operation_name].outputs[output_index])
                    except (ValueError, KeyError, IndexError):
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % name)
                else:
                    try:
                        ret.append(name_to_op[name])
                    except KeyError:
                        raise ValueError(
                            'Requested return_element %r not found in graph_def.'
                            % name)
            return ret
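
# A minimal usage sketch for the importer above (the file name, op names, and
# shapes are hypothetical). `input_map` re-wires the serialized graph's
# "input:0" onto a locally created placeholder, and `return_elements` pulls
# the remapped output tensor back out of the imported graph.
import tensorflow as tf

graph_def = tf.GraphDef()
with open("model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

new_input = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
output, = tf.import_graph_def(graph_def,
                              input_map={"input:0": new_input},
                              return_elements=["output:0"],
                              name="imported")
# `output` now depends on `new_input` instead of the original input op.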
Esempio n. 11
0
def copy(org_instance,
         dict_swap=None,
         scope="copied",
         replace_itself=False,
         copy_q=False,
         copy_parent_rvs=True):
    """Build a new node in the TensorFlow graph from `org_instance`,
  where any of its ancestors existing in `dict_swap` are
  replaced with `dict_swap`'s corresponding value.

  Copying is done recursively. Any `Operation` whose output is
  required to copy `org_instance` is also copied (if it isn't already
  copied within the new scope).

  `tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue` are
  always reused and not copied. In addition, `tf.Operation`s with
  operation-level seeds are copied with a new operation-level seed.

  Args:
    org_instance: RandomVariable, tf.Operation, tf.Tensor, or tf.Variable.
      Node to add in graph with replaced ancestors.
    dict_swap: dict, optional.
      Random variables, variables, tensors, or operations to swap with.
      Its keys are what `org_instance` may depend on, and its values are
      the corresponding object (not necessarily of the same class
      instance, but must have the same type, e.g., float32) that is used
      in exchange.
    scope: str, optional.
      A scope for the new node(s). This is used to avoid name
      conflicts with the original node(s).
    replace_itself: bool, optional.
      Whether to replace `org_instance` itself if it exists in
      `dict_swap`. (This is used for the recursion.)
    copy_q: bool, optional.
      Whether to copy the replaced tensors too (if not already
      copied within the new scope). Otherwise will reuse them.
    copy_parent_rvs: bool, optional.
      Whether to copy parent random variables `org_instance` depends
      on. Otherwise will copy only the sample tensors and not the
      random variable class itself.

  Returns:
    RandomVariable, tf.Variable, tf.Tensor, or tf.Operation.
    The copied node.

  Raises:
    TypeError: If `org_instance` is not one of the above types.

  #### Examples

  ```python
  x = tf.constant(2.0)
  y = tf.constant(3.0)
  z = x * y

  qx = tf.constant(4.0)
  # The TensorFlow graph is currently
  # `x` -> `z` <- `y`, `qx`

  # This adds a subgraph with newly copied nodes,
  # `qx` -> `copied/z` <- `copied/y`
  z_new = ed.copy(z, {x: qx})

  sess = tf.Session()
  sess.run(z)  # 6.0
  sess.run(z_new)  # 12.0
  ```
  """
    if not isinstance(org_instance,
                      (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)):
        raise TypeError("Could not copy instance: " + str(org_instance))

    if dict_swap is None:
        dict_swap = {}
    if scope[-1] != '/':
        scope += '/'

    # Swap instance if in dictionary.
    if org_instance in dict_swap and replace_itself:
        org_instance = dict_swap[org_instance]
        if not copy_q:
            return org_instance
    elif isinstance(org_instance, tf.Tensor) and replace_itself:
        # Deal with case when `org_instance` is the associated tensor
        # from the RandomVariable, e.g., `z.value()`. If
        # `dict_swap={z: qz}`, we aim to swap it with `qz.value()`.
        for key, value in six.iteritems(dict_swap):
            if isinstance(key, RandomVariable):
                if org_instance == key.value():
                    if isinstance(value, RandomVariable):
                        org_instance = value.value()
                    else:
                        org_instance = value
                    if not copy_q:
                        return org_instance
                    break

    # If the instance is a tf.Variable, return it; variables are never
    # copied. Note we check variables via their name: if we reach a
    # variable through an op's inputs, it has type tf.Tensor, not
    # tf.Variable.
    if isinstance(org_instance, (tf.Tensor, tf.Variable)):
        for variable in tf.global_variables():
            if org_instance.name == variable.name:
                if variable in dict_swap and replace_itself:
                    # Deal with case when `org_instance` is the associated _ref
                    # tensor for a tf.Variable.
                    org_instance = dict_swap[variable]
                    if not copy_q or isinstance(org_instance, tf.Variable):
                        return org_instance
                    for variable in tf.global_variables():
                        if org_instance.name == variable.name:
                            return variable
                    break
                else:
                    return variable

    graph = tf.get_default_graph()
    new_name = scope + org_instance.name

    # If an instance of the same name exists, return it.
    if isinstance(org_instance, RandomVariable):
        for rv in random_variables():
            if new_name == rv.name:
                return rv
    elif isinstance(org_instance, (tf.Tensor, tf.Operation)):
        try:
            return graph.as_graph_element(new_name,
                                          allow_tensor=True,
                                          allow_operation=True)
        except (KeyError, ValueError):
            pass

    # Preserve ordering of random variables. Random variables are always
    # copied first (from parent -> child) before any deterministic
    # operations that depend on them.
    if copy_parent_rvs and \
            isinstance(org_instance, (RandomVariable, tf.Tensor, tf.Variable)):
        for v in get_parents(org_instance):
            copy(v, dict_swap, scope, True, copy_q, True)

    if isinstance(org_instance, RandomVariable):
        rv = org_instance

        # If it has copiable arguments, copy them.
        args = [
            _copy_default(arg, dict_swap, scope, True, copy_q, False)
            for arg in rv._args
        ]

        kwargs = {}
        for key, value in six.iteritems(rv._kwargs):
            if isinstance(value, list):
                kwargs[key] = [
                    _copy_default(v, dict_swap, scope, True, copy_q, False)
                    for v in value
                ]
            else:
                kwargs[key] = _copy_default(value, dict_swap, scope, True,
                                            copy_q, False)

        kwargs['name'] = new_name
        # Create new random variable with copied arguments.
        try:
            new_rv = type(rv)(*args, **kwargs)
        except ValueError:
            # Handle the case where parameters are copied under an absolute
            # name scope. This can cause an error when creating a new random
            # variable, as tf.identity name ops are called on parameters ("op
            # with name already exists"). To avoid this, drop the trailing
            # slash so the name is no longer an absolute name scope.
            kwargs['name'] = new_name[:-1]
            new_rv = type(rv)(*args, **kwargs)
        return new_rv
    elif isinstance(org_instance, tf.Tensor):
        tensor = org_instance

        # Do not copy tf.placeholders.
        if 'Placeholder' in tensor.op.type:
            return tensor

        # A tensor is one of the outputs of its underlying
        # op. Therefore copy the op itself.
        op = tensor.op
        new_op = copy(op, dict_swap, scope, True, copy_q, False)

        output_index = op.outputs.index(tensor)
        new_tensor = new_op.outputs[output_index]

        # Add copied tensor to collections that the original one is in.
        for name, collection in six.iteritems(tensor.graph._collections):
            if tensor in collection:
                graph.add_to_collection(name, new_tensor)

        return new_tensor
    elif isinstance(org_instance, tf.Operation):
        op = org_instance

        # Do not copy queue operations.
        if 'Queue' in op.type:
            return op

        # Copy the node def.
        # It is unique to every Operation instance. Replace the name and
        # its operation-level seed if it has one.
        node_def = deepcopy(op.node_def)
        node_def.name = new_name
        if 'seed2' in node_def.attr and tf.get_seed(None)[1] is not None:
            node_def.attr['seed2'].i = tf.get_seed(None)[1]

        # Copy other arguments needed for initialization.
        output_types = op._output_types[:]

        # If it has an original op, copy it.
        if op._original_op is not None:
            original_op = copy(op._original_op, dict_swap, scope, True, copy_q,
                               False)
        else:
            original_op = None

        # Copy the op def.
        # It is unique to every Operation type.
        op_def = deepcopy(op.op_def)

        new_op = tf.Operation(
            node_def,
            graph,
            [],  # inputs; will add them afterwards
            output_types,
            [],  # control inputs; will add them afterwards
            [],  # input types; will add them afterwards
            original_op,
            op_def)

        # Advertise the op early to break recursion cycles.
        graph._add_op(new_op)

        # If it has control inputs, copy them.
        control_inputs = []
        for x in op.control_inputs:
            elem = copy(x, dict_swap, scope, True, copy_q, False)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)

            control_inputs.append(elem)

        new_op._add_control_inputs(control_inputs)

        # If it has inputs, copy them.
        for x in op.inputs:
            elem = copy(x, dict_swap, scope, True, copy_q, False)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)

            new_op._add_input(elem)

        # Use Graph's private methods to add the op, following
        # implementation of `tf.Graph().create_op()`.
        compute_shapes = True
        compute_device = True
        op_type = op.type  # Look up the OpDef by op type, not by the new name.

        if compute_shapes:
            set_shapes_for_outputs(new_op)
        graph._record_op_seen_by_control_dependencies(new_op)

        if compute_device:
            graph._apply_device_functions(new_op)

        if graph._colocation_stack:
            all_colocation_groups = []
            for colocation_op in graph._colocation_stack:
                all_colocation_groups.extend(colocation_op.colocation_groups())
                if colocation_op.device:
                    # Make this device match the device of the colocated op, to
                    # provide consistency between the device and the colocation
                    # property.
                    if new_op.device and new_op.device != colocation_op.device:
                        logging.warning(
                            "Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
                            "Ignoring colocation property.", name,
                            colocation_op.name, new_op.device,
                            colocation_op.device)

            all_colocation_groups = sorted(set(all_colocation_groups))
            new_op.node_def.attr["_class"].CopyFrom(
                attr_value_pb2.AttrValue(
                    list=attr_value_pb2.AttrValue.ListValue(
                        s=all_colocation_groups)))

        # Sets "container" attribute if
        # (1) graph._container is not None
        # (2) "is_stateful" is set in OpDef
        # (3) "container" attribute is in OpDef
        # (4) "container" attribute is None
        if (graph._container and op_type in graph._registered_ops
                and graph._registered_ops[op_type].is_stateful
                and "container" in new_op.node_def.attr
                and not new_op.node_def.attr["container"].s):
            new_op.node_def.attr["container"].CopyFrom(
                attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container)))

        return new_op
    else:
        raise TypeError("Could not copy instance: " + str(org_instance))
Esempio n. 12
0
def calculate_graph_metrics(graph_def, statistic_types, input_layer,
                            input_shape_override, batch_size):
    """Looks at the performance statistics of all nodes in the graph.

    Parameters
    ----------
    graph_def : GraphDef
        The graph to analyze, as a `GraphDef` protocol buffer.
    statistic_types : list of str
        Names of the statistics to gather, e.g. "flops".
    input_layer : str
        Name of the input tensor, e.g. "input:0".
    input_shape_override : list of int or None
        Shape to force on the input tensor, overriding its default shape.
    batch_size : int
        Batch size to substitute into the first input dimension.

    Returns
    -------
    (dict, dict)
        Totals for the whole graph and per-node statistics, both keyed
        by statistic type.

    Raises
    ------
    ValueError
        If no input shape is available from either the graph or the
        override.
    """
    tf.import_graph_def(graph_def, name="")
    total_stats = {}
    node_stats = {}
    for statistic_type in statistic_types:
        total_stats[statistic_type] = ops.OpStats(statistic_type)
        node_stats[statistic_type] = {}
    # Make sure we get pretty-printed numbers with separators.
    locale.setlocale(locale.LC_ALL, "")
    with tf.Session() as sess:
        input_tensor = sess.graph.get_tensor_by_name(input_layer)
        input_shape_tensor = input_tensor.get_shape()
        if input_shape_tensor:
            input_shape = input_shape_tensor.as_list()
        else:
            input_shape = None
        if input_shape_override:
            input_shape = input_shape_override
        if input_shape is None:
            raise ValueError(
                "No input shape was provided on the command line, and the "
                "input op itself had no default shape, so shape inference "
                "couldn't be performed. This is required for metrics "
                "calculations.")
        input_shape[0] = batch_size
        input_tensor.set_shape(input_shape)
        for node in graph_def.node:
            # Ensure that the updated input shape has been fully propagated before we
            # ask for the statistics, since they may depend on the output size.
            op = sess.graph.get_operation_by_name(node.name)
            ops.set_shapes_for_outputs(op)
            for statistic_type in statistic_types:
                current_stats = ops.get_stats_for_node_def(
                    sess.graph, node, statistic_type)
                node_stats[statistic_type][node.name] = current_stats
                total_stats[statistic_type] += current_stats
    return total_stats, node_stats
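
# A small driver sketch for the function above. The file name and layer
# name are hypothetical; "flops" and "weight_parameters" are statistic
# types TensorFlow registers for many common ops, and `OpStats.value`
# holds the accumulated count.
graph_def = tf.GraphDef()
with open("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

total_stats, node_stats = calculate_graph_metrics(
    graph_def, ["flops", "weight_parameters"],
    input_layer="input:0", input_shape_override=None, batch_size=1)
for statistic_type, stats in total_stats.items():
    print("%s: %s" % (statistic_type, stats.value))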
Esempio n. 13
0
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, six.string_types) for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if (return_elements is not None
      and not (isinstance(return_elements, (list, tuple))
               and all(isinstance(x, six.string_types)
                       for x in return_elements))):
    raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      output_types = _OutputTypes(node, op_dict)
      with _MaybeDevice(node.device):
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(
                _InvalidNodeMessage(node, 'Input tensor %r %s'
                                    % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(types_lib.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
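
# Plausible sketches of the private helpers both importer versions rely on,
# reconstructed from how they are called above; the real definitions ship
# alongside import_graph_def and may differ in detail.
def _IsControlInput(input_name):
  # Control inputs are marked with a leading '^' in a GraphDef.
  return input_name.startswith('^')

def _ParseTensorName(tensor_name):
  # "op:2" -> ("op", 2); a bare op name refers to its first output.
  if ':' in tensor_name:
    op_name, output_index = tensor_name.rsplit(':', 1)
    return op_name, int(output_index)
  return tensor_name, 0

def _CanonicalInputName(input_name):
  # Normalize "op" to "op:0" so `input_map` lookups are unambiguous;
  # control inputs ("^op") are returned unchanged.
  if _IsControlInput(input_name):
    return input_name
  op_name, output_index = _ParseTensorName(input_name)
  return '%s:%d' % (op_name, output_index)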
Esempio n. 14
0
def copy(org_instance, dict_swap=None, scope="copied",
         replace_itself=False, copy_q=False, copy_parent_rvs=True):
  """Build a new node in the TensorFlow graph from `org_instance`,
  where any of its ancestors existing in `dict_swap` are
  replaced with `dict_swap`'s corresponding value.

  Copying is done recursively. Any `Operation` whose output is
  required to copy `org_instance` is also copied (if it isn't already
  copied within the new scope).

  `tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue` are
  always reused and not copied. In addition, `tf.Operation`s with
  operation-level seeds are copied with a new operation-level seed.

  Args:
    org_instance: RandomVariable, tf.Operation, tf.Tensor, or tf.Variable.
      Node to add in graph with replaced ancestors.
    dict_swap: dict.
      Random variables, variables, tensors, or operations to swap with.
      Its keys are what `org_instance` may depend on, and its values are
      the corresponding object (not necessarily of the same class
      instance, but must have the same type, e.g., float32) that is used
      in exchange.
    scope: str.
      A scope for the new node(s). This is used to avoid name
      conflicts with the original node(s).
    replace_itself: bool.
      Whether to replace `org_instance` itself if it exists in
      `dict_swap`. (This is used for the recursion.)
    copy_q: bool.
      Whether to copy the replaced tensors too (if not already
      copied within the new scope). Otherwise will reuse them.
    copy_parent_rvs: bool.
      Whether to copy parent random variables `org_instance` depends
      on. Otherwise will copy only the sample tensors and not the
      random variable class itself.

  Returns:
    RandomVariable, tf.Variable, tf.Tensor, or tf.Operation.
    The copied node.

  Raises:
    TypeError: If `org_instance` is not one of the above types.

  #### Examples

  ```python
  x = tf.constant(2.0)
  y = tf.constant(3.0)
  z = x * y

  qx = tf.constant(4.0)
  # The TensorFlow graph is currently
  # `x` -> `z` <- `y`, `qx`

  # This adds a subgraph with newly copied nodes,
  # `qx` -> `copied/z` <- `copied/y`
  z_new = ed.copy(z, {x: qx})

  sess = tf.Session()
  sess.run(z)  # 6.0
  sess.run(z_new)  # 12.0
  ```
  """
  if not isinstance(org_instance,
                    (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)):
    raise TypeError("Could not copy instance: " + str(org_instance))

  if dict_swap is None:
    dict_swap = {}
  if scope[-1] != '/':
    scope += '/'

  # Swap instance if in dictionary.
  if org_instance in dict_swap and replace_itself:
    org_instance = dict_swap[org_instance]
    if not copy_q:
      return org_instance
  elif isinstance(org_instance, tf.Tensor) and replace_itself:
    # Deal with case when `org_instance` is the associated tensor
    # from the RandomVariable, e.g., `z.value()`. If
    # `dict_swap={z: qz}`, we aim to swap it with `qz.value()`.
    for key, value in six.iteritems(dict_swap):
      if isinstance(key, RandomVariable):
        if org_instance == key.value():
          if isinstance(value, RandomVariable):
            org_instance = value.value()
          else:
            org_instance = value
          if not copy_q:
            return org_instance
          break

  # If the instance is a tf.Variable, return it; variables are never
  # copied. Note we check variables via their name: if we reach a
  # variable through an op's inputs, it has type tf.Tensor, not
  # tf.Variable.
  if isinstance(org_instance, (tf.Tensor, tf.Variable)):
    for variable in tf.global_variables():
      if org_instance.name == variable.name:
        if variable in dict_swap and replace_itself:
          # Deal with case when `org_instance` is the associated _ref
          # tensor for a tf.Variable.
          org_instance = dict_swap[variable]
          if not copy_q or isinstance(org_instance, tf.Variable):
            return org_instance
          for variable in tf.global_variables():
            if org_instance.name == variable.name:
              return variable
          break
        else:
          return variable

  graph = tf.get_default_graph()
  new_name = scope + org_instance.name

  # If an instance of the same name exists, return it.
  if isinstance(org_instance, RandomVariable):
    for rv in random_variables():
      if new_name == rv.name:
        return rv
  elif isinstance(org_instance, (tf.Tensor, tf.Operation)):
    try:
      return graph.as_graph_element(new_name,
                                    allow_tensor=True,
                                    allow_operation=True)
    except (KeyError, ValueError):
      pass

  # Preserve ordering of random variables. Random variables are always
  # copied first (from parent -> child) before any deterministic
  # operations that depend on them.
  if copy_parent_rvs and \
          isinstance(org_instance, (RandomVariable, tf.Tensor, tf.Variable)):
    for v in get_parents(org_instance):
      copy(v, dict_swap, scope, True, copy_q, True)

  if isinstance(org_instance, RandomVariable):
    rv = org_instance

    # If it has copiable arguments, copy them.
    args = [_copy_default(arg, dict_swap, scope, True, copy_q, False)
            for arg in rv._args]

    kwargs = {}
    for key, value in six.iteritems(rv._kwargs):
      if isinstance(value, list):
        kwargs[key] = [_copy_default(v, dict_swap, scope, True, copy_q, False)
                       for v in value]
      else:
        kwargs[key] = _copy_default(
            value, dict_swap, scope, True, copy_q, False)

    kwargs['name'] = new_name
    # Create new random variable with copied arguments.
    try:
      new_rv = type(rv)(*args, **kwargs)
    except ValueError:
      # Handle the case where parameters are copied under an absolute
      # name scope. This can cause an error when creating a new random
      # variable, as tf.identity name ops are called on parameters ("op
      # with name already exists"). To avoid this, drop the trailing
      # slash so the name is no longer an absolute name scope.
      kwargs['name'] = new_name[:-1]
      new_rv = type(rv)(*args, **kwargs)
    return new_rv
  elif isinstance(org_instance, tf.Tensor):
    tensor = org_instance

    # Do not copy tf.placeholders.
    if 'Placeholder' in tensor.op.type:
      return tensor

    # A tensor is one of the outputs of its underlying
    # op. Therefore copy the op itself.
    op = tensor.op
    new_op = copy(op, dict_swap, scope, True, copy_q, False)

    output_index = op.outputs.index(tensor)
    new_tensor = new_op.outputs[output_index]

    # Add copied tensor to collections that the original one is in.
    for name, collection in six.iteritems(tensor.graph._collections):
      if tensor in collection:
        graph.add_to_collection(name, new_tensor)

    return new_tensor
  elif isinstance(org_instance, tf.Operation):
    op = org_instance

    # Do not copy queue operations.
    if 'Queue' in op.type:
      return op

    # Copy the node def.
    # It is unique to every Operation instance. Replace the name and
    # its operation-level seed if it has one.
    node_def = deepcopy(op.node_def)
    node_def.name = new_name

    # when copying control flow contexts,
    # we need to make sure frame definitions are copied
    if 'frame_name' in node_def.attr and node_def.attr['frame_name'].s != b'':
      node_def.attr['frame_name'].s = (scope.encode('utf-8') +
                                       node_def.attr['frame_name'].s)

    if 'seed2' in node_def.attr and tf.get_seed(None)[1] is not None:
      node_def.attr['seed2'].i = tf.get_seed(None)[1]

    # Copy other arguments needed for initialization.
    output_types = op._output_types[:]

    # If it has an original op, copy it.
    if op._original_op is not None:
      original_op = copy(op._original_op, dict_swap, scope, True, copy_q, False)
    else:
      original_op = None

    # Copy the op def.
    # It is unique to every Operation type.
    op_def = deepcopy(op.op_def)

    new_op = tf.Operation(node_def,
                          graph,
                          [],  # inputs; will add them afterwards
                          output_types,
                          [],  # control inputs; will add them afterwards
                          [],  # input types; will add them afterwards
                          original_op,
                          op_def)

    # Advertise the op early to break recursion cycles.
    graph._add_op(new_op)

    # If it has control inputs, copy them.
    control_inputs = []
    for x in op.control_inputs:
      elem = copy(x, dict_swap, scope, True, copy_q, False)
      if not isinstance(elem, tf.Operation):
        elem = tf.convert_to_tensor(elem)

      control_inputs.append(elem)

    new_op._add_control_inputs(control_inputs)

    # If it has inputs, copy them.
    for x in op.inputs:
      elem = copy(x, dict_swap, scope, True, copy_q, False)
      if not isinstance(elem, tf.Operation):
        elem = tf.convert_to_tensor(elem)

      new_op._add_input(elem)

    # Copy the control flow context.
    control_flow_context = _copy_context(op._get_control_flow_context(), {},
                                         dict_swap, scope, copy_q)
    new_op._set_control_flow_context(control_flow_context)

    # Use Graph's private methods to add the op, following
    # implementation of `tf.Graph().create_op()`.
    compute_shapes = True
    compute_device = True
    op_type = op.type  # Look up the OpDef by op type, not by the new name.

    if compute_shapes:
      set_shapes_for_outputs(new_op)
    graph._record_op_seen_by_control_dependencies(new_op)

    if compute_device:
      graph._apply_device_functions(new_op)

    if graph._colocation_stack:
      all_colocation_groups = []
      for colocation_op in graph._colocation_stack:
        all_colocation_groups.extend(colocation_op.colocation_groups())
        if colocation_op.device:
          # Make this device match the device of the colocated op, to
          # provide consistency between the device and the colocation
          # property.
          if new_op.device and new_op.device != colocation_op.device:
            logging.warning("Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
                            "Ignoring colocation property.",
                            new_op.name, colocation_op.name, new_op.device,
                            colocation_op.device)

      all_colocation_groups = sorted(set(all_colocation_groups))
      new_op.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))

    # Sets "container" attribute if
    # (1) graph._container is not None
    # (2) "is_stateful" is set in OpDef
    # (3) "container" attribute is in OpDef
    # (4) "container" attribute is None
    if (graph._container and
        op_type in graph._registered_ops and
        graph._registered_ops[op_type].is_stateful and
        "container" in new_op.node_def.attr and
            not new_op.node_def.attr["container"].s):
      new_op.node_def.attr["container"].CopyFrom(
          attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container)))

    return new_op
  else:
    raise TypeError("Could not copy instance: " + str(org_instance))
Esempio n. 15
0
def copy(org_instance, dict_swap=None, scope="copied",
         replace_itself=False, copy_q=False):
  """Build a new node in the TensorFlow graph from `org_instance`,
  where any of its ancestors existing in `dict_swap` are
  replaced with `dict_swap`'s corresponding value.

  The copying is done recursively, so any `Operation` whose output
  is required to evaluate `org_instance` is also copied (if it isn't
  already copied within the new scope). This is with the exception of
  `tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue`, which
  are reused and not newly copied.

  Parameters
  ----------
  org_instance : RandomVariable, tf.Variable, tf.Tensor, or tf.Operation
    Node to add in graph with replaced ancestors.
  dict_swap : dict, optional
    Random variables, variables, tensors, or operations to swap with.
    Its keys are what `org_instance` may depend on, and its values are
    the corresponding object (not necessarily of the same class
    instance, but must have the same type, e.g., float32) that is used
    in exchange.
  scope : str, optional
    A scope for the new node(s). This is used to avoid name
    conflicts with the original node(s).
  replace_itself : bool, optional
    Whether to replace `org_instance` itself if it exists in
    `dict_swap`. (This is used for the recursion.)
  copy_q : bool, optional
    Whether to copy the replaced tensors too (if not already
    copied within the new scope). Otherwise will reuse them.

  Returns
  -------
  RandomVariable, tf.Variable, tf.Tensor, or tf.Operation
    The copied node.

  Raises
  ------
  TypeError
    If `org_instance` is not one of the above types.

  Examples
  --------
  >>> x = tf.constant(2.0)
  >>> y = tf.constant(3.0)
  >>> z = x * y
  >>>
  >>> qx = tf.constant(4.0)
  >>> # The TensorFlow graph is currently
  >>> # `x` -> `z` <- `y`, `qx`
  >>>
  >>> # This adds a subgraph with newly copied nodes,
  >>> # `qx` -> `copied/z` <- `copied/y`
  >>> z_new = copy(z, {x: qx})
  >>>
  >>> sess = tf.Session()
  >>> sess.run(z)
  6.0
  >>> sess.run(z_new)
  12.0
  """
  if not isinstance(org_instance,
                    (RandomVariable, tf.Variable, tf.Tensor, tf.Operation)):
    raise TypeError("Could not copy instance: " + str(org_instance))

  if dict_swap is None:
    dict_swap = {}

  # Swap instance if in dictionary.
  if org_instance in dict_swap and replace_itself:
    org_instance = dict_swap[org_instance]
    if not copy_q:
      return org_instance
  elif isinstance(org_instance, tf.Tensor) and replace_itself:
    # Deal with case when `org_instance` is the associated tensor
    # from the RandomVariable, e.g., `z.value()`. If
    # `dict_swap={z: qz}`, we aim to swap it with `qz.value()`.
    for key, value in six.iteritems(dict_swap):
      if isinstance(key, RandomVariable):
        if org_instance == key.value():
          if isinstance(value, RandomVariable):
            org_instance = value.value()
          else:
            org_instance = value

          if not copy_q:
            return org_instance
          break

  graph = tf.get_default_graph()
  new_name = scope + '/' + org_instance.name

  # If an instance of the same name exists, return appropriately.
  # Do this for random variables.
  random_variables = {x.name: x for x in
                      graph.get_collection('_random_variable_collection_')}
  if new_name in random_variables:
    return random_variables[new_name]

  # Do this for tensors and operations.
  try:
    already_present = graph.as_graph_element(new_name,
                                             allow_tensor=True,
                                             allow_operation=True)
    return already_present
  except (KeyError, ValueError):
    pass

  # If instance is a variable, return it; do not re-copy any.
  # Note we check variables via their name and not their type. This
  # is because if we get variables through an op's inputs, it has
  # type tf.Tensor: we can only tell it is a variable via its name.
  variables = {x.name: x for x in graph.get_collection(tf.GraphKeys.VARIABLES)}
  if org_instance.name in variables:
    return graph.get_tensor_by_name(variables[org_instance.name].name)

  # Do the same for placeholders. Determine via its op's type.
  if isinstance(org_instance, tf.Tensor):
    if "Placeholder" in org_instance.op.type:
      return org_instance

  if isinstance(org_instance, RandomVariable):
    rv = org_instance

    # If it has copiable arguments, copy them.
    args = []
    for arg in rv._args:
      if isinstance(arg, (RandomVariable, tf.Variable,
                          tf.Tensor, tf.Operation)):
        arg = copy(arg, dict_swap, scope, True, copy_q)

      args.append(arg)

    kwargs = {}
    for key, value in six.iteritems(rv._kwargs):
      if isinstance(value, (RandomVariable, tf.Variable,
                            tf.Tensor, tf.Operation)):
        value = copy(value, dict_swap, scope, True, copy_q)

      kwargs[key] = value

    kwargs['name'] = new_name
    # Create new random variable with copied arguments.
    new_rv = rv.__class__(*args, **kwargs)
    return new_rv
  elif isinstance(org_instance, tf.Tensor):
    tensor = org_instance

    # A tensor is one of the outputs of its underlying
    # op. Therefore copy the op itself.
    op = tensor.op
    new_op = copy(op, dict_swap, scope, True, copy_q)

    output_index = op.outputs.index(tensor)
    new_tensor = new_op.outputs[output_index]

    # Add copied tensor to collections that the original one is in.
    for name, collection in tensor.graph._collections.items():
      if tensor in collection:
        graph.add_to_collection(name, new_tensor)

    return new_tensor
  else:  # tf.Operation
    op = org_instance

    # Do not copy queue operations
    if 'Queue' in op.type:
      return op

    # If it has an original op, copy it.
    if op._original_op is not None:
      new_original_op = copy(op._original_op, dict_swap, scope, True, copy_q)
    else:
      new_original_op = None

    # If it has control inputs, copy them.
    new_control_inputs = []
    for x in op.control_inputs:
      elem = copy(x, dict_swap, scope, True, copy_q)
      if not isinstance(elem, tf.Operation):
        elem = tf.convert_to_tensor(elem)

      new_control_inputs.append(elem)

    # If it has inputs, copy them.
    new_inputs = []
    for x in op.inputs:
      elem = copy(x, dict_swap, scope, True, copy_q)
      if not isinstance(elem, tf.Operation):
        elem = tf.convert_to_tensor(elem)

      new_inputs.append(elem)

    # Make a copy of the node def.
    # As an instance of tensorflow.core.framework.graph_pb2.NodeDef, it
    # stores string-based info such as name, device, and type of the op.
    # It is unique to every Operation instance.
    new_node_def = deepcopy(op.node_def)
    new_node_def.name = new_name

    # Copy the other inputs needed for initialization.
    output_types = op._output_types[:]
    input_types = op._input_types[:]

    # Make a copy of the op def.
    # It is unique to every Operation type.
    op_def = deepcopy(op.op_def)

    ret = tf.Operation(new_node_def,
                       graph,
                       new_inputs,
                       output_types,
                       new_control_inputs,
                       input_types,
                       new_original_op,
                       op_def)

    # Use Graph's private methods to add the op, following
    # implementation of `tf.Graph().create_op()`.
    compute_shapes = True
    compute_device = True
    op_type = op.type  # Look up the OpDef by op type, not by the new name.

    if compute_shapes:
      set_shapes_for_outputs(ret)
    graph._add_op(ret)
    graph._record_op_seen_by_control_dependencies(ret)

    if compute_device:
      graph._apply_device_functions(ret)

    if graph._colocation_stack:
      all_colocation_groups = []
      for colocation_op in graph._colocation_stack:
        all_colocation_groups.extend(colocation_op.colocation_groups())
        if colocation_op.device:
          # Make this device match the device of the colocated op, to
          # provide consistency between the device and the colocation
          # property.
          if ret.device and ret.device != colocation_op.device:
            logging.warning("Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
                            "Ignoring colocation property.",
                            ret.name, colocation_op.name, ret.device,
                            colocation_op.device)
          else:
            ret._set_device(colocation_op.device)

      all_colocation_groups = sorted(set(all_colocation_groups))
      ret.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))

    # Sets "container" attribute if
    # (1) graph._container is not None
    # (2) "is_stateful" is set in OpDef
    # (3) "container" attribute is in OpDef
    # (4) "container" attribute is None
    if (graph._container and
        op_type in graph._registered_ops and
        graph._registered_ops[op_type].is_stateful and
        "container" in ret.node_def.attr and
            not ret.node_def.attr["container"].s):
      ret.node_def.attr["container"].CopyFrom(
          attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container)))

    return ret
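
# A quick sketch of the reuse guarantees documented above, assuming the
# older copy() defined in this example: placeholders come back unchanged,
# while downstream tensors are re-created under the new scope.
import tensorflow as tf

p = tf.placeholder(tf.float32, name="p")
out = p * 2.0

out_copy = copy(out, scope="copied")
assert copy(p, scope="copied") is p          # placeholders are reused
assert out_copy.name.startswith("copied/")   # ops/tensors are duplicated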
Esempio n. 16
0
def calculate_graph_metrics(graph_def, statistic_types, input_layer,
                            input_shape_override, batch_size):
    """Looks at the performance statistics of all nodes in the graph.

    Parameters
    ----------
    graph_def : GraphDef
        The graph to analyze, as a `GraphDef` protocol buffer.
    statistic_types : list of str
        Names of the statistics to gather, e.g. "flops".
    input_layer : str
        Name of the input tensor, e.g. "input:0".
    input_shape_override : list of int or None
        Shape to force on the input tensor, overriding its default shape.
    batch_size : int
        Batch size to substitute into the first input dimension.

    Returns
    -------
    (dict, dict)
        Totals for the whole graph and per-node statistics, both keyed
        by statistic type.

    Raises
    ------
    ValueError
        If no input shape is available from either the graph or the
        override.
    """
    tf.import_graph_def(graph_def, name="")
    total_stats = {}
    node_stats = {}
    for statistic_type in statistic_types:
        total_stats[statistic_type] = ops.OpStats(statistic_type)
        node_stats[statistic_type] = {}
    # Make sure we get pretty-printed numbers with separators.
    locale.setlocale(locale.LC_ALL, "")
    with tf.Session() as sess:
        input_tensor = sess.graph.get_tensor_by_name(input_layer)
        input_shape_tensor = input_tensor.get_shape()
        if input_shape_tensor:
            input_shape = input_shape_tensor.as_list()
        else:
            input_shape = None
        if input_shape_override:
            input_shape = input_shape_override
        if input_shape is None:
            raise ValueError("""No input shape was provided on the command line,"""
                             """ and the input op itself had no default shape, so"""
                             """ shape inference couldn't be performed. This is"""
                             """ required for metrics calculations.""")
        input_shape[0] = batch_size
        input_tensor.set_shape(input_shape)
        for node in graph_def.node:
            # Ensure that the updated input shape has been fully propagated before we
            # ask for the statistics, since they may depend on the output size.
            op = sess.graph.get_operation_by_name(node.name)
            ops.set_shapes_for_outputs(op)
            for statistic_type in statistic_types:
                current_stats = ops.get_stats_for_node_def(sess.graph, node,
                                                           statistic_type)
                node_stats[statistic_type][node.name] = current_stats
                total_stats[statistic_type] += current_stats
    return total_stats, node_stats
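
# A sketch of how the gathered statistics might be reported, assuming the
# locale configured above so that large counts print with separators.
# `OpStats.value` may be None for nodes with no registered statistics
# function, so guard before formatting.
def print_graph_stats(total_stats, node_stats, statistic_type="flops"):
    total = total_stats[statistic_type].value or 0
    print("Total %s: %s" % (statistic_type,
                            locale.format("%d", total, grouping=True)))
    for node_name, stats in node_stats[statistic_type].items():
        if stats.value:
            print("  %s: %s" %
                  (node_name, locale.format("%d", stats.value, grouping=True)))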