Пример #1
0
    def test_add_pruned_signature(self):
        """Signatures that mention a removed op are dropped; others are copied.

        `test_signature_keep` only references `input_1` / `output_1`, neither of
        which is in `removed_op_names`, so it must be copied unchanged.
        `test_signature_remove` takes `node2` as an input, and `node2` is
        removed, so it must not appear in the pruned MetaGraphDef.
        """
        base_meta_graph_def = meta_graph_pb2.MetaGraphDef()

        signature_name_keep = 'test_signature_keep'
        base_sig_keep = base_meta_graph_def.signature_def[signature_name_keep]
        base_sig_keep.inputs['input_1'].name = 'input_1'
        base_sig_keep.outputs['output_1'].name = 'output_1'

        signature_name_remove = 'test_signature_remove'
        base_sig_remove = base_meta_graph_def.signature_def[
            signature_name_remove]
        base_sig_remove.inputs['node2'].name = 'node2'
        base_sig_remove.outputs['output_1'].name = 'output_1'

        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        removed_op_names = ['node2', 'node4', 'node5']
        meta_graph_transform._add_pruned_signature(base_meta_graph_def,
                                                   meta_graph_def,
                                                   signature_name_keep,
                                                   removed_op_names)
        meta_graph_transform._add_pruned_signature(base_meta_graph_def,
                                                   meta_graph_def,
                                                   signature_name_remove,
                                                   removed_op_names)

        # assertIn / assertNotIn report the container contents on failure,
        # unlike assertTrue(x in y) which only says "False is not true".
        self.assertIn(signature_name_keep, meta_graph_def.signature_def)
        sig_keep = meta_graph_def.signature_def[signature_name_keep]
        self.assertEqual(base_sig_keep, sig_keep)

        self.assertNotIn(signature_name_remove, meta_graph_def.signature_def)
Пример #2
0
    def testAssertProtoEqualsAny(self):
        """assertProtoEquals handles protos that contain a `protobuf.Any` field.

        Covers text-proto vs. message and message vs. message comparisons, and
        checks that a failing comparison's message includes the packed inner
        proto's content.
        """
        # Test assertProtoEquals with a protobuf.Any field.
        meta_graph_def_str = """
    meta_info_def {
      meta_graph_version: "outer"
      any_info {
        [type.googleapis.com/tensorflow.MetaGraphDef] {
          meta_info_def {
            meta_graph_version: "inner"
          }
        }
      }
    }
    """
        meta_graph_def_outer = meta_graph_pb2.MetaGraphDef()
        meta_graph_def_outer.meta_info_def.meta_graph_version = "outer"
        meta_graph_def_inner = meta_graph_pb2.MetaGraphDef()
        meta_graph_def_inner.meta_info_def.meta_graph_version = "inner"
        meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner)
        self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer)
        self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer)

        # Check if the assertion failure message contains the content of
        # the inner proto. (assertRaisesRegexp is a deprecated alias that was
        # removed in Python 3.12; assertRaisesRegex is the supported name.)
        with self.assertRaisesRegex(AssertionError,
                                    r'meta_graph_version: "inner"'):
            self.assertProtoEquals("", meta_graph_def_outer)
Пример #3
0
    def test_add_pruned_collection_proto_in_any_list(self):
        """Pruning an any_list collection drops entries naming removed ops.

        Also exercises `_is_removed_mentioned()`: removing '/a' must NOT drop
        the '/a/a_1' entry, while the exact name '/b/b_1' is dropped.
        """
        collection_name = 'proto_collection'
        base_names = ['node1', 'node2', 'node3', 'node4', '/a/a_1', '/b/b_1']
        kept_names = ('node1', 'node3', '/a/a_1')
        removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1']

        base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
        base_values = base_meta_graph_def.collection_def[
            collection_name].any_list.value
        base_values.extend(
            [_make_asset_file_def_any(asset) for asset in base_names])

        pruned_meta_graph_def = meta_graph_pb2.MetaGraphDef()
        meta_graph_transform._add_pruned_collection(base_meta_graph_def,
                                                    pruned_meta_graph_def,
                                                    collection_name,
                                                    removed_op_names)

        pruned_values = pruned_meta_graph_def.collection_def[
            collection_name].any_list.value
        self.assertEqual(
            [_make_asset_file_def_any(asset) for asset in kept_names],
            list(pruned_values))
Пример #4
0
  def test_add_pruned_collection_node(self):
    """Removed op names are filtered out of a node_list collection."""
    collection_name = 'node_collection'
    removed_op_names = ['node2', 'node4', 'node5']

    base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    base_nodes = base_meta_graph_def.collection_def[
        collection_name].node_list.value
    base_nodes.extend(['node1', 'node2', 'node3', 'node4'])

    pruned_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    meta_graph_transform._add_pruned_collection(
        base_meta_graph_def, pruned_meta_graph_def, collection_name,
        removed_op_names)

    pruned_nodes = pruned_meta_graph_def.collection_def[
        collection_name].node_list.value
    self.assertEqual(['node1', 'node3'], pruned_nodes)
Пример #5
0
  def test_add_pruned_saver(self):
    """The SaverDef ends up in the pruned MetaGraphDef unchanged."""
    base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    base_saver = base_meta_graph_def.saver_def
    base_saver.filename_tensor_name = 'node1'
    base_saver.save_tensor_name = 'node3'
    base_saver.restore_op_name = 'node6'

    pruned_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    meta_graph_transform._add_pruned_saver(
        base_meta_graph_def, pruned_meta_graph_def,
        ['node2', 'node4', 'node5'])

    # TODO(b/63447631): For now the saver is just copied unchanged
    self.assertEqual(base_meta_graph_def.saver_def,
                     pruned_meta_graph_def.saver_def)
Пример #6
0
  def test_add_pruned_collection_int(self):
    """An int64_list collection is copied unchanged regardless of removed ops."""
    collection_name = 'int_collection'
    removed_op_names = ['node2', 'node4', 'node5']

    base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    base_ints = base_meta_graph_def.collection_def[
        collection_name].int64_list.value
    base_ints.extend([10, 20, 30, 40])

    pruned_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    meta_graph_transform._add_pruned_collection(
        base_meta_graph_def, pruned_meta_graph_def, collection_name,
        removed_op_names)

    pruned_ints = pruned_meta_graph_def.collection_def[
        collection_name].int64_list.value
    self.assertEqual([10, 20, 30, 40], pruned_ints)
Пример #7
0
 def testGetNodeMapMultiple(self):
     """get_node_map groups several per-key collections under each decoded key.

     `bytes_list` fields only accept bytes under Python 3, so the stored values
     (and therefore the values expected back) are bytes literals, the same
     convention testGetNodeMapBasic uses.
     """
     meta_graph_def = meta_graph_pb2.MetaGraphDef()
     # A real list (not a lazy `map` iterator) is assigned to the repeated field.
     meta_graph_def.collection_def[
         'my_collection/%s' %
         encoding.KEY_SUFFIX].bytes_list.value[:] = [
             encoding.encode_key(key) for key in ['alpha', 'bravo', 'charlie']
         ]
     meta_graph_def.collection_def[
         'my_collection/fruits'].bytes_list.value[:] = [
             b'apple', b'banana', b'cherry'
         ]
     meta_graph_def.collection_def[
         'my_collection/animals'].bytes_list.value[:] = [
             b'aardvark', b'badger', b'camel'
         ]
     expected = {
         'alpha': {
             'fruits': b'apple',
             'animals': b'aardvark'
         },
         'bravo': {
             'fruits': b'banana',
             'animals': b'badger'
         },
         'charlie': {
             'fruits': b'cherry',
             'animals': b'camel'
         }
     }
     self.assertDictEqual(
         expected,
         graph_ref.get_node_map(meta_graph_def, 'my_collection',
                                ['fruits', 'animals']))
Пример #8
0
    def create_apply_graph(self, signature, input_tensors, name):
        """See `ModuleImpl.create_apply_graph`."""
        signature_def = self._meta_graph.signature_def.get(signature)

        # Start from the state map and layer the signature inputs on top, so an
        # input tensor can take the place of a tensor from the state graph.
        feed_map = dict(self._state_map)
        feed_map.update(
            tensor_info.build_input_map(signature_def.inputs, input_tensors))

        # If we are inside a control-flow context (e.g. a while_loop body), pass
        # every fed tensor through that context so the module can be applied
        # there. Iterating over a sorted snapshot keeps the order deterministic
        # and makes it safe to overwrite entries of feed_map.
        control_flow = self._graph._get_control_flow_context()  # pylint: disable=protected-access
        if control_flow:
            feed_map.update(
                (key, control_flow.AddValue(tensor))
                for key, tensor in sorted(feed_map.items()))

        # Only peek at the scope name here; import_scoped_meta_graph is what
        # actually starts using it.
        absolute_scope_name = self._graph.unique_name(name, mark_as_used=False)
        relative_scope_name = absolute_scope_name.rsplit("/", 1)[-1]

        # ASSET_FILEPATHS is imported because a tensor holding an asset filename
        # may be consumed anywhere in the graph, not only by TABLE_INITIALIZERS
        # ops; keeping the annotation lets later re-exports recognize it.
        import_collections = [
            tf.GraphKeys.ASSET_FILEPATHS,
            tf.GraphKeys.COND_CONTEXT,
            tf.GraphKeys.WHILE_CONTEXT,
        ]
        if self._trainable:
            import_collections.append(tf.GraphKeys.UPDATE_OPS)

        # Operate on a copy so the filtering and prefixing below leave
        # self._meta_graph untouched.
        meta_graph = meta_graph_pb2.MetaGraphDef()
        meta_graph.CopyFrom(self._meta_graph)
        meta_graph_lib.filter_collections(meta_graph, import_collections)
        meta_graph_lib.prefix_shared_name_attributes(meta_graph,
                                                     absolute_scope_name)

        tf.train.import_meta_graph(meta_graph,
                                   input_map=feed_map,
                                   import_scope=relative_scope_name)
        fix_colocation_after_import(input_map=feed_map,
                                    absolute_import_scope=absolute_scope_name)

        def get_tensor(tensor_name):
            # An output that is just a fed-through input creates no node inside
            # the apply scope, so such names are resolved via the feed map.
            if tensor_name in feed_map:
                return feed_map[tensor_name]
            scoped_name = meta_graph_lib.prepend_name_scope(
                tensor_name, import_scope=absolute_scope_name)
            return self._graph.get_tensor_by_name(scoped_name)

        return tensor_info.build_output_map(signature_def.outputs, get_tensor)
Пример #9
0
def _build_meta_graph(obj,
                      export_dir,
                      signatures,
                      options,
                      meta_graph_def=None):
    """Creates a MetaGraph containing the resources and functions of an object.

    Args:
      obj: The `Trackable` object to export.
      export_dir: Export directory; used here only to locate the debug-info
        directory when `options.save_debug_info` is set.
      signatures: Signatures to export, or None to discover them from `obj`.
      options: Save options; this function reads `namespace_whitelist`,
        `function_aliases` and `save_debug_info`.
      meta_graph_def: Optional `MetaGraphDef` to fill in. When it is not given
        (falsy), a fresh `MetaGraphDef` is created.

    Returns:
      A tuple `(meta_graph_def, exported_graph, object_saver, asset_info)`.

    Raises:
      AssertionError: If called inside a traced `tf.function`.
      ValueError: If `obj` is not a `Trackable`.
    """
    if ops.inside_function():
        raise AssertionError(
            "tf.saved_model.save is not supported inside a traced "
            "@tf.function. Move the call to the outer eagerly-executed "
            "context.")
    if not isinstance(obj, base.Trackable):
        raise ValueError(
            "Expected a Trackable object for export, got {}.".format(obj))
    if not meta_graph_def:
        meta_graph_def = meta_graph_pb2.MetaGraphDef()

    checkpoint_graph_view = _AugmentedGraphView(obj)
    if signatures is None:
        signatures = signature_serialization.find_function_to_export(
            checkpoint_graph_view)

    signatures, wrapped_functions = (
        signature_serialization.canonicalize_signatures(signatures))
    signature_serialization.validate_saveable_view(checkpoint_graph_view)
    signature_map = signature_serialization.create_signature_map(signatures)
    checkpoint_graph_view.add_object(
        parent_node=checkpoint_graph_view.root,
        name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
        subgraph_root=signature_map)

    # _SaveableView gives a frozen listing of properties and functions. The
    # first construction exists only for its side effects (building the view
    # can create variables); the second one is the view actually used.
    _ = _SaveableView(checkpoint_graph_view)
    saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
    object_saver = util.TrackableSaver(checkpoint_graph_view)
    asset_info, exported_graph = _fill_meta_graph_def(
        meta_graph_def, saveable_view, signatures, options.namespace_whitelist)

    if options.function_aliases:
        alias_map = meta_graph_def.meta_info_def.function_aliases
        for alias, func in options.function_aliases.items():
            # pylint: disable=protected-access
            for concrete_fn in func._stateful_fn._function_cache.all_values():
                alias_map[concrete_fn.name] = alias
            for concrete_fn in func._stateless_fn._function_cache.all_values():
                alias_map[concrete_fn.name] = alias
            # pylint: enable=protected-access

    object_graph_proto = _serialize_object_graph(saveable_view,
                                                 asset_info.asset_index)
    meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)

    if options.save_debug_info:
        # Debug info is written next to the SavedModel, not into the MetaGraph.
        graph_debug_info = _export_debug_info(exported_graph)
        debug_info_path = os.path.join(
            utils_impl.get_or_create_debug_dir(export_dir),
            constants.DEBUG_INFO_FILENAME_PB)
        file_io.atomic_write_string_to_file(
            debug_info_path,
            graph_debug_info.SerializeToString(deterministic=True))

    return meta_graph_def, exported_graph, object_saver, asset_info
Пример #10
0
def main(_):
    """Loads a graph from the flags, optionally rewrites it, and prints costs.

    Input is either `--metagraphdef` or `--graphdef`. A file ending in ".pbtxt"
    is parsed as a text proto; anything else is parsed as a binary proto. In
    the GraphDef case, the op named by `--fetch` is added to the "train_op"
    collection of the exported MetaGraphDef. If `--rewriter_config` is set, the
    graph is first run through the Grappler optimizer with that config.
    """
    if FLAGS.metagraphdef:
        metagraph = meta_graph_pb2.MetaGraphDef()
        if FLAGS.metagraphdef.endswith(".pbtxt"):
            with gfile.GFile(FLAGS.metagraphdef, "r") as meta_file:
                text_format.Merge(meta_file.read(), metagraph)
        else:
            # Binary protos must be read as bytes: in text mode GFile decodes
            # the content as UTF-8, which breaks ParseFromString.
            with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
                metagraph.ParseFromString(meta_file.read())
    else:
        graph_def = graph_pb2.GraphDef()
        if FLAGS.graphdef.endswith(".pbtxt"):
            with gfile.GFile(FLAGS.graphdef, "r") as graph_file:
                text_format.Merge(graph_file.read(), graph_def)
        else:
            with gfile.GFile(FLAGS.graphdef, "rb") as graph_file:
                graph_def.ParseFromString(graph_file.read())
        importer.import_graph_def(graph_def, name="")
        graph = ops.get_default_graph()
        fetch = graph.get_operation_by_name(FLAGS.fetch)
        graph.add_to_collection("train_op", fetch)
        metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def(),
                                            graph=graph)

    if FLAGS.rewriter_config is not None:
        rewriter_config = rewriter_config_pb2.RewriterConfig()
        text_format.Merge(FLAGS.rewriter_config, rewriter_config)
        optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config,
                                                     metagraph)
        metagraph.graph_def.CopyFrom(optimized_graph)

    report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
    print(report)
Пример #11
0
    def testGetSignatureDefByKeyRegression(self):
        """A regression SignatureDef can be looked up by the key it was saved under."""
        input1 = constant_op.constant("a", name="input-1")
        output1 = constant_op.constant("b", name="output-1")

        signature_key = "my_regression"
        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        self._add_to_signature_def_map(meta_graph_def, {
            signature_key:
                signature_def_utils.regression_signature_def(input1, output1),
        })

        signature_def = signature_def_contrib_utils.get_signature_def_by_key(
            meta_graph_def, signature_key)

        # The method name must be the regression constant.
        self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
                         signature_def.method_name)

        # Exactly one input and one output, each bound to the expected tensor.
        expectations = (
            (signature_def.inputs, signature_constants.REGRESS_INPUTS,
             "input-1:0"),
            (signature_def.outputs, signature_constants.REGRESS_OUTPUTS,
             "output-1:0"),
        )
        for tensor_map, tensor_key, tensor_name in expectations:
            self.assertEqual(1, len(tensor_map))
            self._check_tensor_info(tensor_map, tensor_key, tensor_name)
Пример #12
0
    def _GenerateTestData(self):
        """Generates the test data directory.

        The test data has a single run named run1 which contains a small graph
        (nodes 'a' and 'b', with a 2 KB attribute on 'b'). If
        `self._only_use_meta_graph` is set, the graph is written wrapped in a
        MetaGraphDef; otherwise it is written as a bare GraphDef.

        Returns:
          temp_dir: The directory the test data is generated under.
        """
        # `dir=` creates the directory inside the test's temp dir. Passing that
        # path as `prefix` instead only yields a sibling whose name starts with it.
        temp_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        self.addCleanup(shutil.rmtree, temp_dir)
        run1_path = os.path.join(temp_dir, 'run1')
        os.makedirs(run1_path)
        writer = tf.summary.FileWriter(run1_path)
        try:
            # Add a simple graph event.
            graph_def = tf.GraphDef()
            node1 = graph_def.node.add()
            node1.name = 'a'
            node2 = graph_def.node.add()
            node2.name = 'b'
            node2.attr['very_large_attr'].s = b'a' * 2048  # 2 KB attribute

            meta_graph_def = meta_graph_pb2.MetaGraphDef(graph_def=graph_def)

            if self._only_use_meta_graph:
                writer.add_meta_graph(meta_graph_def)
            else:
                writer.add_graph(graph_def)

            writer.flush()
        finally:
            # Always release the writer, even if building/writing fails.
            writer.close()

        return temp_dir
Пример #13
0
def read_meta_graph_file(filename):
    """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

    The content is first parsed as a binary-serialized proto; if that fails,
    it is parsed as a UTF-8 text-format proto.

    Args:
      filename: `meta_graph_def` filename including the path.

    Returns:
      A `MetaGraphDef` protocol buffer.

    Raises:
      IOError: If the file doesn't exist, or cannot be successfully parsed.
    """
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    if not file_io.file_exists(filename):
        raise IOError(f"File does not exist. Received: {filename}.")
    # First try to read it as a binary file.
    with file_io.FileIO(filename, "rb") as f:
        file_content = f.read()
    try:
        meta_graph_def.ParseFromString(file_content)
        return meta_graph_def
    except Exception:  # pylint: disable=broad-except
        # Deliberate best-effort: anything that isn't a valid binary proto
        # falls through to the text-format attempt below.
        pass

    # A failed binary parse may leave fields partially populated; start from a
    # clean message so the text parse doesn't merge on top of leftovers.
    meta_graph_def.Clear()

    # Next try to read it as a text file. Non-UTF-8 content is also a parse
    # failure, so it is reported as the documented IOError.
    try:
        text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
    except (text_format.ParseError, UnicodeDecodeError) as e:
        raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e

    return meta_graph_def
Пример #14
0
 def testGetNodeMapBasic(self):
     """get_node_map maps each decoded key to its value from one collection."""
     keys = ['alpha', 'bravo', 'charlie']
     fruits = [b'apple', b'banana', b'cherry']

     meta_graph_def = meta_graph_pb2.MetaGraphDef()
     key_collection_name = 'my_collection/%s' % encoding.KEY_SUFFIX
     meta_graph_def.collection_def[
         key_collection_name].bytes_list.value[:] = map(
             encoding.encode_key, keys)
     meta_graph_def.collection_def[
         'my_collection/fruits'].bytes_list.value[:] = fruits

     expected = {key: {'fruits': fruit} for key, fruit in zip(keys, fruits)}
     actual = graph_ref.get_node_map(meta_graph_def, 'my_collection',
                                     ['fruits'])
     self.assertDictEqual(expected, actual)
Пример #15
0
    def testGetSignatureDefByKeyClassification(self):
        """A classification SignatureDef can be looked up by its saved key."""
        input1 = constant_op.constant("a", name="input-1")
        output1 = constant_op.constant("b", name="output-1")
        output2 = constant_op.constant("c", name="output-2")

        signature_key = "my_classification"
        classification_def = signature_def_utils.classification_signature_def(
            input1, output1, output2)
        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        self._add_to_signature_def_map(meta_graph_def,
                                       {signature_key: classification_def})

        signature_def = signature_def_contrib_utils.get_signature_def_by_key(
            meta_graph_def, signature_key)

        # The method name must be the classification constant.
        self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
                         signature_def.method_name)

        # A single input tensor.
        self.assertEqual(1, len(signature_def.inputs))
        self._check_tensor_info(signature_def.inputs,
                                signature_constants.CLASSIFY_INPUTS,
                                "input-1:0")

        # Two outputs: classes first, then scores.
        self.assertEqual(2, len(signature_def.outputs))
        expected_outputs = (
            (signature_constants.CLASSIFY_OUTPUT_CLASSES, "output-1:0"),
            (signature_constants.CLASSIFY_OUTPUT_SCORES, "output-2:0"),
        )
        for output_key, tensor_name in expected_outputs:
            self._check_tensor_info(signature_def.outputs, output_key,
                                    tensor_name)
Пример #16
0
def get_metagraph():
    """Constructs and returns a MetaGraphDef from the input file.

    Reads `--metagraphdef` if set, otherwise `--graphdef`. Files ending in
    ".pbtxt" are parsed as text protos; anything else is parsed as a binary
    proto. `--fetch` is a comma-separated list of op names placed in the
    "train_op" collection. For a MetaGraphDef it replaces that collection only
    when given; for a GraphDef it is required.

    Returns:
      The loaded (or exported) `MetaGraphDef`.

    Raises:
      ValueError: If `--graphdef` is used without `--fetch`.
    """
    if FLAGS.metagraphdef:
        metagraph = meta_graph_pb2.MetaGraphDef()
        if FLAGS.metagraphdef.endswith(".pbtxt"):
            with gfile.GFile(FLAGS.metagraphdef, "r") as meta_file:
                text_format.Merge(meta_file.read(), metagraph)
        else:
            # Binary protos must be read as bytes: in text mode GFile decodes
            # the content as UTF-8, which breaks ParseFromString.
            with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
                metagraph.ParseFromString(meta_file.read())
        if FLAGS.fetch is not None:
            fetch_collection = meta_graph_pb2.CollectionDef()
            for fetch in FLAGS.fetch.split(","):
                fetch_collection.node_list.value.append(fetch)
            metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
    else:
        # Fail fast with a clear message instead of an AttributeError on None,
        # and before anything is imported into the default graph.
        if FLAGS.fetch is None:
            raise ValueError("--fetch is required when --graphdef is used.")
        graph_def = graph_pb2.GraphDef()
        if FLAGS.graphdef.endswith(".pbtxt"):
            with gfile.GFile(FLAGS.graphdef, "r") as graph_file:
                text_format.Merge(graph_file.read(), graph_def)
        else:
            with gfile.GFile(FLAGS.graphdef, "rb") as graph_file:
                graph_def.ParseFromString(graph_file.read())
        importer.import_graph_def(graph_def, name="")
        graph = ops.get_default_graph()
        for fetch in FLAGS.fetch.split(","):
            fetch_op = graph.get_operation_by_name(fetch)
            graph.add_to_collection("train_op", fetch_op)
        metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def(),
                                            graph=graph)
    return metagraph
Пример #17
0
    def _assertEventsWithGraph(self, test_dir, g, has_shapes):
        """Asserts the events in `test_dir` are exactly version, graph, metagraph.

        The expected sequence is: the file_version event, a GraphDef event for
        `g` at step 0, a MetaGraphDef event for `g` at step 0, and then the end
        of the stream.

        Args:
          test_dir: Directory whose events file is read.
          g: The graph that was written.
          has_shapes: Passed as `add_shapes` to `g.as_graph_def()` when
            building the expected protos.
        """
        meta_graph_def = meta_graph.create_meta_graph_def(
            graph_def=g.as_graph_def(add_shapes=has_shapes))

        rr = self._EventsReader(test_dir)

        # The first event should list the file_version.
        # (assertEquals is a deprecated alias removed in Python 3.12.)
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEqual("brain.Event:2", ev.file_version)

        # The next event should have the graph.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEqual(0, ev.step)
        ev_graph = graph_pb2.GraphDef()
        ev_graph.ParseFromString(ev.graph_def)
        self.assertProtoEquals(g.as_graph_def(add_shapes=has_shapes), ev_graph)

        # The next event should have the metagraph.
        ev = next(rr)
        self._assertRecent(ev.wall_time)
        self.assertEqual(0, ev.step)
        ev_meta_graph = meta_graph_pb2.MetaGraphDef()
        ev_meta_graph.ParseFromString(ev.meta_graph_def)
        self.assertProtoEquals(meta_graph_def, ev_meta_graph)

        # We should be done.
        self.assertRaises(StopIteration, lambda: next(rr))
Пример #18
0
def meta_graph_transform(base_meta_graph_def,
                         input_names,
                         output_names,
                         transforms,
                         tags,
                         checkpoint_path=None):
    """Apply the Graph Transform tool to a MetaGraphDef.

    The transformed GraphDef replaces the original graph. meta_info_def is
    copied with its tags replaced by `tags`. The saver, every collection and
    every signature are then copied over, with any op that the transforms
    removed pruned out.

    Args:
      base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
      input_names: Names of input nodes.
      output_names: Names of output nodes.
      transforms: A list of strings naming the graph transforms to be applied
        in order. These are exactly the names supported by the Graph Transform
        Tool, plus the extra 'freeze_graph' transform.
      tags: A list of tags with which to annotate the transformed MetaGraphDef.
      checkpoint_path: A path to a checkpoint to restore during freezing,
        if needed (default None).

    Returns:
      A new transformed MetaGraphDef protocol buffer.
    """
    mandatory_retain_ops = _find_all_mandatory_retain_ops(base_meta_graph_def)

    transformed_graph_def = _do_transforms(
        base_meta_graph_def.graph_def, input_names, output_names,
        mandatory_retain_ops, transforms, base_meta_graph_def.saver_def,
        checkpoint_path)

    meta_graph_def = _meta_graph_pb2.MetaGraphDef()
    meta_graph_def.graph_def.CopyFrom(transformed_graph_def)

    # Keep the original meta info, but swap in the caller-supplied tags.
    meta_info = meta_graph_def.meta_info_def
    meta_info.CopyFrom(base_meta_graph_def.meta_info_def)
    meta_info.ClearField('tags')
    for new_tag in tags:
        meta_info.tags.append(new_tag)

    # Ops present before the transforms but absent afterwards.
    original_ops = {
        compat.as_str(node.name) for node in base_meta_graph_def.graph_def.node
    }
    surviving_ops = {
        compat.as_str(node.name) for node in meta_graph_def.graph_def.node
    }
    removed_op_names = original_ops - surviving_ops

    # Saver, collections and signatures are each copied minus pruned nodes.
    _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names)
    for collection_name in base_meta_graph_def.collection_def:
        _add_pruned_collection(base_meta_graph_def, meta_graph_def,
                               collection_name, removed_op_names)
    for signature_name in base_meta_graph_def.signature_def:
        _add_pruned_signature(base_meta_graph_def, meta_graph_def,
                              signature_name, removed_op_names)

    return meta_graph_def
Пример #19
0
    def testNoLogdirButExplicitSummaryWriter(self):
        """With logdir="" the Supervisor writes through the provided writer.

        Checks that the writer's directory receives, in order: the file
        version, the graph, the MetaGraphDef, the computed summary, a STOP
        session log, and nothing more.
        """
        logdir = self._test_dir("explicit_summary_writer")
        with ops.Graph().as_default():
            summary.scalar("c1", constant_op.constant(1))
            summary.scalar("c2", constant_op.constant(2))
            summary.scalar("c3", constant_op.constant(3))
            summ = summary.merge_all()
            sw = writer.FileWriter(logdir)
            sv = supervisor.Supervisor(logdir="",
                                       summary_op=None,
                                       summary_writer=sw)
            meta_graph_def = meta_graph.create_meta_graph_def()
            sess = sv.prepare_or_wait_for_session("")
            sv.summary_computed(sess, sess.run(summ))
            sess.close()
            # Wait to make sure everything is written to file before stopping.
            time.sleep(1)
            sv.stop()

        # Check the summary was written to 'logdir'
        rr = _summary_iterator(logdir)

        # The first event should list the file_version.
        # (assertEquals is a deprecated alias removed in Python 3.12.)
        ev = next(rr)
        self.assertEqual("brain.Event:2", ev.file_version)

        # The next one has the graph.
        ev = next(rr)
        ev_graph = graph_pb2.GraphDef()
        ev_graph.ParseFromString(ev.graph_def)
        self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True),
                               ev_graph)

        # Stored MetaGraphDef
        ev = next(rr)
        ev_meta_graph = meta_graph_pb2.MetaGraphDef()
        ev_meta_graph.ParseFromString(ev.meta_graph_def)
        self.assertProtoEquals(meta_graph_def, ev_meta_graph)
        self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True),
                               ev_meta_graph.graph_def)

        # The next one should have the values from the summary.
        ev = next(rr)
        self.assertProtoEquals(
            """
      value { tag: 'c1' simple_value: 1.0 }
      value { tag: 'c2' simple_value: 2.0 }
      value { tag: 'c3' simple_value: 3.0 }
      """, ev.summary)

        # The next one should be a stop message if we closed cleanly.
        ev = next(rr)
        self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

        # We should be done.
        self.assertRaises(StopIteration, lambda: next(rr))
Пример #20
0
 def testNoModuleAttachments(self):
     """Exporting with no attach_bytes calls yields an empty attachment map."""
     meta_graph = meta_graph_pb2.MetaGraphDef()
     with tf.Graph().as_default():
         # No calls to attach_bytes.
         saved_model_lib._export_module_attachments(meta_graph)
     actual = saved_model_lib.get_attached_bytes_map(meta_graph)
     self.assertDictEqual({}, actual)
     # Check there were no unwarranted subscript operations: subscripting a
     # proto map would have created the collection entry as a side effect.
     self.assertNotIn(saved_model_lib.ATTACHMENT_COLLECTION_SAVED,
                      meta_graph.collection_def)
Пример #21
0
def main(unused_argv=None):
    """Log the contents of an event file, one event at a time.

    The path is taken from ``FLAGS.event_file`` (with ``~`` expanded) and read
    with ``tf.train.summary_iterator``. GraphDef, TaggedRunMetadata and
    MetaGraphDef payloads are decoded and logged in text form. Summary,
    log_message and session_log payloads are only reported as not yet
    implemented.

    Args:
      unused_argv: Ignored; accepted so this can serve as an app entry point.

    Returns:
      -1 if no event file was specified, otherwise None.

    Raises:
      ValueError: If the file_version event is not "brain.Event:2".
      NotImplementedError: If an event carries a payload kind not handled here.
    """
    # Validate the raw flag before expanding it: os.path.expanduser(None)
    # raises TypeError, which would hide the usage message below.
    if not FLAGS.event_file:
        msg = ('The path to an event_file must be specified. '
               'Run `inspect_summary.py --help` for usage instructions.')
        logging.error(msg)
        return -1
    event_file = os.path.expanduser(FLAGS.event_file)

    def _log_framed(proto, header_fmt, *header_args):
        """Log a header line and the text form of `proto` between rulers."""
        logging.info('=' * 80)
        logging.info(header_fmt, *header_args)
        logging.info(str(proto))
        logging.info('=' * 80)

    for event in tf.train.summary_iterator(event_file):
        # Yields a sequence of `tensorflow.core.util.event_pb2.Event`
        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/event.proto
        what = event.WhichOneof('what')
        if what == 'file_version':
            # Explicit check rather than `assert`, which `python -O` strips.
            if event.file_version != "brain.Event:2":
                raise ValueError('Unsupported event file version: %r' %
                                 event.file_version)
        elif what == 'graph_def':
            # An encoded version of a GraphDef.
            graph = graph_pb2.GraphDef()
            graph.ParseFromString(event.graph_def)
            _log_framed(graph, 'GraphDef[wall_time=%s, step=%s]',
                        event.wall_time, event.step)
        elif what == 'summary':
            logging.warning('NYI: summary')
        elif what == 'log_message':
            logging.warning('NYI: log_message')  # user output a log message
        elif what == 'session_log':
            logging.warning('NYI: session_log')
        elif what == 'tagged_run_metadata':
            trn = event.tagged_run_metadata
            # Lazy deserialization from a byte buffer
            run_meta = config_pb2.RunMetadata()
            run_meta.ParseFromString(trn.run_metadata)
            _log_framed(run_meta,
                        'TaggedRunMetadata [wall_time=%s, step=%s, tag=%s]',
                        event.wall_time, event.step, trn.tag)
        elif what == 'meta_graph_def':
            # An encoded version of a MetaGraphDef.
            meta_graph = meta_graph_pb2.MetaGraphDef()
            meta_graph.ParseFromString(event.meta_graph_def)
            _log_framed(meta_graph, 'MetaGraphDef [wall_time=%s, step=%s]',
                        event.wall_time, event.step)
        else:
            # `NotImplemented` is a constant, not an exception class: calling
            # and raising it fails with an unrelated TypeError.
            raise NotImplementedError('Unhandled event payload: %r' % what)
Пример #22
0
    def _convert_saved_model(self):
        """Convert the input SavedModel.

        Loads the SavedModel into a fresh graph and freezes its variables into
        constants. Freezing keeps the nodes referenced by the selected
        SignatureDef and by the collections chosen by `_collections_to_keep`.
        The method then builds `self._grappler_meta_graph_def` from the frozen
        graph plus the kept collections, the meta info and that signature, and
        finally runs the conversion.
        """
        graph = ops.Graph()
        with session.Session(graph=graph, config=self._session_config) as sess:
            input_meta_graph_def = loader.load(sess,
                                               self._input_saved_model_tags,
                                               self._input_saved_model_dir)
            input_signature_def = input_meta_graph_def.signature_def[
                self._input_saved_model_signature_key]

            def _gather_names(tensor_info):
                """Get the node names (the part before ':') from a TensorInfo map."""
                return {
                    tensor_info[key].name.split(":")[0] for key in tensor_info
                }

            # Keep the input and output nodes of the selected SignatureDef.
            output_node_names = _gather_names(
                input_signature_def.inputs).union(
                    _gather_names(input_signature_def.outputs))

            # Compute the kept collections once. They are used both to pin
            # nodes during freezing and to copy collections afterwards.
            collections_to_keep = list(
                self._collections_to_keep(input_meta_graph_def.collection_def))

            # Preserve nodes in collection
            for collection_key in collections_to_keep:
                for op in sess.graph.get_collection(collection_key):
                    if isinstance(op, ops.Operation):
                        output_node_names.add(op.name.split(":")[0])

            # Freeze the variables in the SavedModel graph and copy the frozen
            # graph over.
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess, sess.graph.as_graph_def(add_shapes=True),
                list(output_node_names))
            self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
            self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

            # Copy the collections that are not variables.
            for collection_key in collections_to_keep:
                self._grappler_meta_graph_def.collection_def[
                    collection_key].CopyFrom(
                        input_meta_graph_def.collection_def[collection_key])

            self._add_nodes_blacklist()

            # Copy other information.
            self._grappler_meta_graph_def.meta_info_def.CopyFrom(
                input_meta_graph_def.meta_info_def)
            self._grappler_meta_graph_def.signature_def[
                self._input_saved_model_signature_key].CopyFrom(
                    input_signature_def)
            # TODO(laigd): maybe add back AssetFileDef.

        self._run_conversion()
Пример #23
0
  def testGetNodeInGraph(self):
    """An encoded node in a collection resolves back to the original node."""
    graph = tf.Graph()
    with graph.as_default():
      fruit = tf.constant(1.0)

    encoded_fruit = encoding.encode_tensor_node(fruit)
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    fruit_collection = meta_graph_def.collection_def['fruit_node']
    fruit_collection.any_list.value.extend([encoded_fruit])

    resolved = graph_ref.get_node_in_graph(meta_graph_def, 'fruit_node', graph)
    self.assertEqual(fruit, resolved)
Пример #24
0
 def testModuleAttachments(self):
   """Attachments are exported, and a repeated key keeps its latest value."""
   meta_graph = meta_graph_pb2.MetaGraphDef()
   # "key1" is attached twice; only the second value should survive.
   attachments = [("key1", "oops"), ("key2", "value2"), ("key1", "value1")]
   with tf.Graph().as_default():
     for key, text in attachments:
       saved_model_lib.attach_bytes(key, tf.compat.as_bytes(text))
     saved_model_lib._export_module_attachments(meta_graph)
   exported = saved_model_lib.get_attached_bytes_map(meta_graph)
   self.assertDictEqual(
       {
           "key1": tf.compat.as_bytes("value1"),
           "key2": tf.compat.as_bytes("value2"),
       },
       exported)
Пример #25
0
  def test_add_pruned_collection_proto_in_bytes_list(self):
    """Pruning drops serialized entries of removed nodes and keeps order."""

    def _serialized_asset(node_name):
      # Serialized asset-file Any for `node_name`, as stored in a bytes_list.
      return compat.as_bytes(
          compat.as_str_any(_make_asset_file_def_any(node_name)))

    collection_name = 'proto_collection'
    base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    base_collection = base_meta_graph_def.collection_def[collection_name]
    base_collection.bytes_list.value.extend(
        [_serialized_asset(name)
         for name in ('node1', 'node2', 'node3', 'node4')])

    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    removed_op_names = ['node2', 'node4', 'node5']
    meta_graph_transform._add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)

    pruned = meta_graph_def.collection_def[collection_name]
    expected_values = [_serialized_asset(name) for name in ('node1', 'node3')]
    self.assertEqual(expected_values, list(pruned.bytes_list.value))
Пример #26
0
    def testGetSignatureDefByKey(self):
        """get_signature_def_by_key returns stored defs and rejects unknown keys."""
        x = array_ops.placeholder(dtypes.float32, 1, name="x")
        x_tensor_info = utils.build_tensor_info(x)

        y = array_ops.placeholder(dtypes.float32, name="y")
        y_tensor_info = utils.build_tensor_info(y)

        foo_signature_def = signature_def_utils.build_signature_def(
            {"foo-input": x_tensor_info}, {"foo-output": y_tensor_info},
            "foo-method-name")
        bar_signature_def = signature_def_utils.build_signature_def(
            {"bar-input": x_tensor_info}, {"bar-output": y_tensor_info},
            "bar-method-name")
        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        self._add_to_signature_def_map(meta_graph_def, {
            "foo": foo_signature_def,
            "bar": bar_signature_def
        })

        # Look up a key that does not exist in the SignatureDefMap.
        missing_key = "missing-key"
        with self.assertRaisesRegexp(
                ValueError,
                "No SignatureDef with key '%s' found in MetaGraphDef" %
                missing_key):
            signature_def_contrib_utils.get_signature_def_by_key(
                meta_graph_def, missing_key)

        # Look up the key, `foo` which exists in the SignatureDefMap.
        found_foo = signature_def_contrib_utils.get_signature_def_by_key(
            meta_graph_def, "foo")
        # assertEqual, not assertTrue: assertTrue(a, b) treats `b` as the
        # failure message and passes for any non-empty string `a`.
        self.assertEqual("foo-method-name", found_foo.method_name)

        # Check inputs in signature def.
        self.assertEqual(1, len(found_foo.inputs))
        self._check_tensor_info(found_foo.inputs, "foo-input", "x:0")

        # Check outputs in signature def.
        self.assertEqual(1, len(found_foo.outputs))
        self._check_tensor_info(found_foo.outputs, "foo-output", "y:0")

        # Look up the key, `bar` which exists in the SignatureDefMap.
        found_bar = signature_def_contrib_utils.get_signature_def_by_key(
            meta_graph_def, "bar")
        self.assertEqual("bar-method-name", found_bar.method_name)

        # Check inputs in signature def.
        self.assertEqual(1, len(found_bar.inputs))
        self._check_tensor_info(found_bar.inputs, "bar-input", "x:0")

        # Check outputs in signature def.
        self.assertEqual(1, len(found_bar.outputs))
        self._check_tensor_info(found_bar.outputs, "bar-output", "y:0")
Пример #27
0
    def _convert_saved_model(self):
        """Convert the input SavedModel.

        Loads the SavedModel into a fresh graph and freezes its variables into
        constants, keeping the nodes referenced by the selected SignatureDef.
        The method then builds `self._grappler_meta_graph_def` from the frozen
        graph plus the non-variable collections, the meta info and that
        signature, and finally runs the conversion.
        """
        graph = ops.Graph()
        with session.Session(graph=graph, config=self._session_config) as sess:
            input_meta_graph_def = loader.load(sess,
                                               self._input_saved_model_tags,
                                               self._input_saved_model_dir)
            input_signature_def = input_meta_graph_def.signature_def[
                self._input_saved_model_signature_key]

            def _gather_names(tensor_info):
                """Get the node names (the part before ':') from a TensorInfo map."""
                return {
                    tensor_info[key].name.split(":")[0] for key in tensor_info
                }

            # Keep the input and output nodes of the selected SignatureDef.
            output_node_names = _gather_names(
                input_signature_def.inputs).union(
                    _gather_names(input_signature_def.outputs))

            # Freeze the variables in the SavedModel graph and copy the frozen
            # graph over.
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess, sess.graph.as_graph_def(add_shapes=True),
                list(output_node_names))
            self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
            self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

            # Copy the collections that are not variables.
            # TODO(laigd): currently we use the collection key to filter out
            # collections that depend on variable ops, but this may miss some
            # other user-defined collections. A better way would be to use
            # CollectionDef::NodeList for the filtering.
            excluded_collection_keys = frozenset([
                "variables", "local_variables", "model_variables",
                "trainable_variables", "train_op", "table_initializer"
            ])
            for key in input_meta_graph_def.collection_def:
                if key not in excluded_collection_keys:
                    self._grappler_meta_graph_def.collection_def[key].CopyFrom(
                        input_meta_graph_def.collection_def[key])

            self._add_nodes_blacklist()

            # Copy other information.
            self._grappler_meta_graph_def.meta_info_def.CopyFrom(
                input_meta_graph_def.meta_info_def)
            self._grappler_meta_graph_def.signature_def[
                self._input_saved_model_signature_key].CopyFrom(
                    input_signature_def)
            # TODO(laigd): maybe add back AssetFileDef.

        self._run_conversion()
Пример #28
0
    def testChiefCanWriteEvents(self):
        """A chief Supervisor writes version, graph, meta graph, summary, stop."""
        logdir = _test_dir("can_write")
        with tf.Graph().as_default():
            tf.summary.scalar("c1", tf.constant(1))
            tf.summary.scalar("c2", tf.constant(2))
            tf.summary.scalar("c3", tf.constant(3))
            summ = tf.summary.merge_all()
            sv = tf.train.Supervisor(is_chief=True,
                                     logdir=logdir,
                                     summary_op=None)
            meta_graph_def = meta_graph.create_meta_graph_def()
            sess = sv.prepare_or_wait_for_session("")
            sv.summary_computed(sess, sess.run(summ))
            sess.close()
            # Wait to make sure everything is written to file before stopping.
            time.sleep(1)
            sv.stop()

        rr = _summary_iterator(logdir)

        # The first event should list the file_version.
        ev = next(rr)
        self.assertEqual("brain.Event:2", ev.file_version)

        # The next one has the graph.
        ev = next(rr)
        ev_graph = tf.GraphDef()
        ev_graph.ParseFromString(ev.graph_def)
        self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True),
                               ev_graph)

        # Stored MetaGraphDef
        ev = next(rr)
        ev_meta_graph = meta_graph_pb2.MetaGraphDef()
        ev_meta_graph.ParseFromString(ev.meta_graph_def)
        self.assertProtoEquals(meta_graph_def, ev_meta_graph)
        self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True),
                               ev_meta_graph.graph_def)
        # The next one should have the values from the summary.
        ev = next(rr)
        self.assertProtoEquals(
            """
      value { tag: 'c1' simple_value: 1.0 }
      value { tag: 'c2' simple_value: 2.0 }
      value { tag: 'c3' simple_value: 3.0 }
      """, ev.summary)

        # The next one should be a stop message if we closed cleanly.
        ev = next(rr)
        self.assertEqual(tf.SessionLog.STOP, ev.session_log.status)

        # We should be done.
        self.assertRaises(StopIteration, lambda: next(rr))
Пример #29
0
def main(_):
  """Print a cost report for a serialized MetaGraphDef.

  Reads a binary MetaGraphDef from FLAGS.input. If FLAGS.rewriter_config (a
  text-format RewriterConfig) is set, the graph is first rewritten with
  tf_optimizer.OptimizeGraph. The function then prints the report from
  cost_analyzer.GenerateCostReport.
  """
  # ParseFromString expects the binary wire format, so open in binary mode;
  # text mode would try to decode the raw bytes as text.
  with gfile.GFile(FLAGS.input, "rb") as input_file:
    metagraph = meta_graph_pb2.MetaGraphDef()
    metagraph.ParseFromString(input_file.read())

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
Пример #30
0
    def testStandardServicesWithoutGlobalStep(self):
        """Supervisor services write events and a checkpoint without global_step."""
        logdir = _test_dir("standard_services_without_global_step")
        # Create a checkpoint.
        with tf.Graph().as_default():
            v = tf.Variable([1.0], name="foo")
            tf.scalar_summary(["v"], v)
            sv = tf.train.Supervisor(logdir=logdir)
            meta_graph_def = meta_graph.create_meta_graph_def(
                saver_def=sv.saver.saver_def)
            sess = sv.prepare_or_wait_for_session("")
            save_path = sv.save_path
            self._wait_for_glob(save_path, 3.0)
            self._wait_for_glob(os.path.join(logdir, "*events*"),
                                3.0,
                                for_checkpoint=False)
            # Wait to make sure everything is written to file before stopping.
            time.sleep(1)
            sv.stop()
        # There should be an event file with a version number.
        rr = _summary_iterator(logdir)
        ev = next(rr)
        self.assertEqual("brain.Event:2", ev.file_version)
        ev = next(rr)
        ev_graph = tf.GraphDef()
        ev_graph.ParseFromString(ev.graph_def)
        self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True),
                               ev_graph)

        # Stored MetaGraphDef
        ev = next(rr)
        ev_meta_graph = meta_graph_pb2.MetaGraphDef()
        ev_meta_graph.ParseFromString(ev.meta_graph_def)
        self.assertProtoEquals(meta_graph_def, ev_meta_graph)
        self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True),
                               ev_meta_graph.graph_def)

        ev = next(rr)
        self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }",
                               ev.summary)

        ev = next(rr)
        self.assertEqual(tf.SessionLog.STOP, ev.session_log.status)

        self.assertRaises(StopIteration, lambda: next(rr))
        # There should be a checkpoint file with the variable "foo"
        with tf.Graph().as_default(), self.test_session() as sess:
            v = tf.Variable([10.10], name="foo")
            sav = tf.train.Saver([v])
            sav.restore(sess, save_path)
            self.assertEqual(1.0, v.eval()[0])