def testNoVariables(self):
    """Round-trips a variable-free graph through an on-disk MetaGraphDef."""
    out_dir = _TestDir("no_variables")
    meta_path = os.path.join(out_dir, "metafile")

    feed_value = -10  # Arbitrary scalar fed to the input placeholder.

    source_graph = ops.Graph()
    with self.session(graph=source_graph) as sess:
        # Build a tiny graph with zero variables: output = input + 42.
        inp = array_ops.placeholder(dtypes.float32, shape=[], name="input")
        bias = constant_op.constant(42, dtype=dtypes.float32, name="offset")
        out = math_ops.add(inp, bias, name="add_offset")

        # Register the endpoints in named collections so the importer can
        # locate them after the round trip.
        ops.add_to_collection("input_tensor", inp)
        ops.add_to_collection("output_tensor", out)

        baseline = sess.run(out, {inp: feed_value})
        self.assertEqual(baseline, 32)

        # Export the graph (with shapes and collections) as a MetaGraphDef.
        meta_graph_def, var_list = meta_graph.export_scoped_meta_graph(
            filename=meta_path,
            graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
            collection_list=["input_tensor", "output_tensor"],
            saver_def=None)
        self.assertTrue(meta_graph_def.HasField("meta_info_def"))
        self.assertNotEqual(
            meta_graph_def.meta_info_def.tensorflow_version, "")
        self.assertNotEqual(
            meta_graph_def.meta_info_def.tensorflow_git_version, "")
        # No variables were created, so the exported var_list must be empty.
        self.assertEqual({}, var_list)

    # Import the MetaGraphDef into a brand-new graph.
    restored_graph = ops.Graph()
    with self.session(graph=restored_graph) as sess:
        meta_graph.import_scoped_meta_graph(meta_path)

        # Re-export the restored state and verify the round trip is lossless.
        new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
            meta_path + "_new")
        test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
                                                 new_meta_graph_def)

        # The collections must survive the import...
        restored_in = ops.get_collection("input_tensor")[0]
        restored_out = ops.get_collection("output_tensor")[0]
        # ...and the restored graph must compute the same value as before.
        restored_value = sess.run(restored_out, {restored_in: feed_value})
        self.assertEqual(restored_value, baseline)
 def testScopedWithQueue(self):
     """Round-trips a queue-bearing graph through scoped export/import."""
     test_dir = _TestDir("scoped_with_queue")
     # Export a graph containing a queue, then import it back again.
     exported = self._testScopedExportWithQueue(test_dir,
                                                "exported_queue1.pbtxt")
     reimported = self._testScopedImportWithQueue(
         test_dir, "exported_queue1.pbtxt", "exported_new_queue1.pbtxt")
     # The two MetaGraphDefs must be identical.
     test_util.assert_meta_graph_protos_equal(self, exported, reimported)
# Example #3
 def testScopedWithQueue(self):
   """Verifies scoped export then import of a graph that contains a queue."""
   scratch_dir = _TestDir("scoped_with_queue")
   # First export, then re-import under a fresh filename.
   before = self._testScopedExportWithQueue(scratch_dir,
                                            "exported_queue1.pbtxt")
   after = self._testScopedImportWithQueue(
       scratch_dir, "exported_queue1.pbtxt", "exported_new_queue1.pbtxt")
   # Nothing may be lost or altered by the round trip.
   test_util.assert_meta_graph_protos_equal(self, before, after)
    def testWhileLoopGradients(self):
        """Checks that while-loop gradients survive scoped export/import."""
        # Build a simple while loop under an "export" scope.
        with ops.Graph().as_default():
            with ops.name_scope("export"):
                acc_var = variables.Variable(0.)
                var_name = acc_var.name
                _, loop_out = control_flow_ops.while_loop(
                    lambda i, x: i < 5,
                    lambda i, x: (i + 1, x + math_ops.cast(i, dtypes.float32)),
                    [0, acc_var])
                output_name = loop_out.name

            # Export only the "export" scope as a MetaGraphDef.
            meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
                export_scope="export")

            # Compute a reference gradient value in the original graph; the
            # imported graph must reproduce it below.
            init_op = variables.global_variables_initializer()
            grad = gradients_impl.gradients([loop_out], [acc_var])
            with session.Session() as sess:
                sess.run(init_op)
                expected_grad_value = sess.run(grad)

        # Restore the MetaGraphDef into a new graph under an "import" scope.
        with ops.Graph().as_default():
            meta_graph.import_scoped_meta_graph(meta_graph_def,
                                                import_scope="import")

            # Re-exporting the imported scope must yield the same proto.
            new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
                export_scope="import")
            test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
                                                     new_meta_graph_def)

            def new_name(tensor_name):
                # Map a tensor name from the export scope to the import scope.
                return "import/" + tensor_name.replace("export/", "")

            graph = ops.get_default_graph()
            imported_var = graph.get_tensor_by_name(new_name(var_name))
            imported_out = graph.get_tensor_by_name(new_name(output_name))
            # Gradients built against the imported graph must agree with the
            # reference value computed above.
            grad = gradients_impl.gradients([imported_out], [imported_var])

            init_op = variables.global_variables_initializer()

            with session.Session() as sess:
                sess.run(init_op)
                actual_grad_value = sess.run(grad)
                self.assertEqual(expected_grad_value, actual_grad_value)
# Example #5
  def testNoVariables(self):
    """Exports a graph with no variables and verifies a lossless re-import."""
    base_dir = _TestDir("no_variables")
    meta_filename = os.path.join(base_dir, "metafile")

    feed_val = -10  # Arbitrary value fed through feed_dict.

    first_graph = ops.Graph()
    with self.session(graph=first_graph) as sess:
      # Minimal variable-free graph: add a constant offset to a scalar input.
      placeholder = array_ops.placeholder(
          dtypes.float32, shape=[], name="input")
      const_offset = constant_op.constant(
          42, dtype=dtypes.float32, name="offset")
      summed = math_ops.add(placeholder, const_offset, name="add_offset")

      # Stash both endpoints in collections for retrieval after import.
      ops.add_to_collection("input_tensor", placeholder)
      ops.add_to_collection("output_tensor", summed)

      reference = sess.run(summed, {placeholder: feed_val})
      self.assertEqual(reference, 32)

      # Write the MetaGraphDef to disk.
      meta_graph_def, var_list = meta_graph.export_scoped_meta_graph(
          filename=meta_filename,
          graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
          collection_list=["input_tensor", "output_tensor"],
          saver_def=None)
      self.assertTrue(meta_graph_def.HasField("meta_info_def"))
      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")
      self.assertNotEqual(
          meta_graph_def.meta_info_def.tensorflow_git_version, "")
      # There are no variables, so the returned var_list must be empty.
      self.assertEqual({}, var_list)

    # Import the previously exported meta graph into a clean graph.
    second_graph = ops.Graph()
    with self.session(graph=second_graph) as sess:
      meta_graph.import_scoped_meta_graph(meta_filename)

      # Re-export the current state and compare against the original proto.
      new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          meta_filename + "_new")
      test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
                                               new_meta_graph_def)

      # The collections must still resolve after the import...
      in_tensor = ops.get_collection("input_tensor")[0]
      out_tensor = ops.get_collection("output_tensor")[0]
      # ...and the imported graph must compute the original result.
      self.assertEqual(
          sess.run(out_tensor, {in_tensor: feed_val}), reference)
 def testScopedExportAndImport(self):
   """Round-trips three scoped layer graphs through export and import."""
   test_dir = _TestDir("scoped_export_import")
   filenames = [
       "exported_hidden1.pbtxt", "exported_hidden2.pbtxt",
       "exported_softmax_linear.pbtxt"
   ]
   before = self._testScopedExport(test_dir, filenames)
   after = self._testScopedImport(test_dir, filenames)
   # unbound_inputs differs in representation; drop it so the remaining
   # protos can be compared directly.
   del before[0].collection_def["unbound_inputs"]
   del after[0].collection_def["unbound_inputs"]
   for exported, imported in zip(before, after):
     test_util.assert_meta_graph_protos_equal(self, exported, imported)
# Example #7
 def testScopedExportAndImport(self):
     """Checks scoped export followed by scoped import is lossless."""
     scratch = _TestDir("scoped_export_import")
     proto_files = [
         "exported_hidden1.pbtxt", "exported_hidden2.pbtxt",
         "exported_softmax_linear.pbtxt"
     ]
     originals = self._testScopedExport(scratch, proto_files)
     reimports = self._testScopedImport(scratch, proto_files)
     # Strip unbound_inputs from the first pair so a direct proto
     # comparison is possible.
     del originals[0].collection_def["unbound_inputs"]
     del reimports[0].collection_def["unbound_inputs"]
     for lhs, rhs in zip(originals, reimports):
         test_util.assert_meta_graph_protos_equal(self, lhs, rhs)
# Example #8
  def testWhileLoopGradients(self):
    """Ensures gradients of a while loop match after scoped export/import."""
    # Construct a small while loop inside an "export" name scope.
    with ops.Graph().as_default():
      with ops.name_scope("export"):
        state_var = variables.Variable(0.)
        var_name = state_var.name
        _, final_x = control_flow_ops.while_loop(
            lambda i, x: i < 5,
            lambda i, x: (i + 1, x + math_ops.cast(i, dtypes.float32)),
            [0, state_var])
        output_name = final_x.name

      # Serialize just the "export" scope.
      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          export_scope="export")

      # Record the gradient value from the original graph as ground truth.
      init_op = variables.global_variables_initializer()
      grad = gradients_impl.gradients([final_x], [state_var])
      with session.Session() as sess:
        self.evaluate(init_op)
        expected_grad_value = self.evaluate(grad)

    # Bring the MetaGraphDef back into a fresh graph under "import".
    with ops.Graph().as_default():
      meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope="import")

      # Re-exporting the imported scope must reproduce the original proto.
      new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          export_scope="import")
      test_util.assert_meta_graph_protos_equal(
          self, meta_graph_def, new_meta_graph_def)

      def new_name(tensor_name):
        # Rewrite an "export/"-scoped tensor name into the "import/" scope.
        return "import/" + tensor_name.replace("export/", "")

      default_graph = ops.get_default_graph()
      imported_var = default_graph.get_tensor_by_name(new_name(var_name))
      imported_out = default_graph.get_tensor_by_name(new_name(output_name))
      # Gradients built on the imported graph must match the ground truth.
      grad = gradients_impl.gradients([imported_out], [imported_var])

      init_op = variables.global_variables_initializer()

      with session.Session() as sess:
        self.evaluate(init_op)
        actual_grad_value = self.evaluate(grad)
        self.assertEqual(expected_grad_value, actual_grad_value)
    def _testExportImportAcrossScopes(self, graph_fn, use_resource):
        """Tests export and importing a graph across scopes.

        Args:
          graph_fn: A closure that creates a graph on the current scope.
          use_resource: A bool indicating whether or not to use
            ResourceVariables.
        """
        # Build the source graph under nested scopes; only the innermost
        # ("keepA") part should survive the scoped export below.
        with ops.Graph().as_default() as original_graph:
            with variable_scope.variable_scope("dropA/dropB/keepA"):
                graph_fn(use_resource=use_resource)
        exported = meta_graph.export_scoped_meta_graph(
            graph=original_graph, export_scope="dropA/dropB")[0]

        with ops.Graph().as_default() as imported_graph:
            meta_graph.import_scoped_meta_graph(exported,
                                                import_scope="importA")

        # Build what the imported graph should look like, from scratch.
        with ops.Graph().as_default() as expected_graph:
            with variable_scope.variable_scope("importA/keepA"):
                graph_fn(use_resource=use_resource)

            if use_resource:
                # Importing collections that contain ResourceVariables adds
                # read ops to the graph the first time each variable is seen,
                # so replay that behavior here (once per unique variable).
                already_read = set()
                for collection_key in sorted([
                        ops.GraphKeys.GLOBAL_VARIABLES,
                        ops.GraphKeys.TRAINABLE_VARIABLES,
                ]):
                    for var in expected_graph.get_collection(collection_key):
                        if var not in already_read:
                            var._read_variable_op()
                            already_read.add(var)

        result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
        expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]

        if use_resource:
            # shared_name attrs are orthogonal to scoping and are not
            # rewritten on export/import, so blank them before comparing.
            for proto in (result, expected):
                for node in proto.graph_def.node:
                    attr_value = node.attr.get("shared_name", None)
                    if attr_value and attr_value.HasField("s") and attr_value.s:
                        node.attr["shared_name"].s = b""

        test_util.assert_meta_graph_protos_equal(self, expected, result)
# Example #10
  def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests export and importing a graph across scopes.

    Args:
      graph_fn: A closure that creates a graph on the current scope.
      use_resource: A bool indicating whether or not to use ResourceVariables.
    """
    # Create the original graph under a deep scope; export strips "dropA/dropB".
    with ops.Graph().as_default() as original_graph:
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn(use_resource=use_resource)
    exported = meta_graph.export_scoped_meta_graph(
        graph=original_graph, export_scope="dropA/dropB")[0]

    with ops.Graph().as_default() as imported_graph:
      meta_graph.import_scoped_meta_graph(exported, import_scope="importA")

    # Rebuild the graph the import is expected to have produced.
    with ops.Graph().as_default() as expected_graph:
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn(use_resource=use_resource)

      if use_resource:
        # Bringing in a collection that holds ResourceVariables adds read ops
        # to the graph, so mimic that here.
        for key in sorted([
            ops.GraphKeys.GLOBAL_VARIABLES,
            ops.GraphKeys.TRAINABLE_VARIABLES,
        ]):
          for var in expected_graph.get_collection(key):
            var._read_variable_op()

    result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
    expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]

    if use_resource:
      # shared_name attrs are meant to be orthogonal to scopes; clear them
      # everywhere before comparing the two protos.
      for proto in (result, expected):
        for node in proto.graph_def.node:
          attr_value = node.attr.get("shared_name", None)
          if attr_value and attr_value.HasField("s") and attr_value.s:
            node.attr["shared_name"].s = b""

    test_util.assert_meta_graph_protos_equal(self, expected, result)
# Example #11
 def testScopedExportAndImport(self):
     """Scoped export/import of three layer graphs must be lossless."""
     workdir = _TestDir("scoped_export_import")
     proto_names = [
         "exported_hidden1.pbtxt", "exported_hidden2.pbtxt",
         "exported_softmax_linear.pbtxt"
     ]
     exported_protos = self._testScopedExport(workdir, proto_names)
     imported_protos = self._testScopedImport(workdir, proto_names)
     for before, after in zip(exported_protos, imported_protos):
         # The unbound input strings are slightly different with the C API
         # enabled ("images" vs "images:0") due to the original
         # import_graph_def code vs. ImportGraphDef in C++, so drop that
         # collection before comparing.
         # TODO(skyewm): update the pbtxts once _USE_C_API is removed.
         del before.collection_def["unbound_inputs"]
         del after.collection_def["unbound_inputs"]
         test_util.assert_meta_graph_protos_equal(self, before, after)
# Example #12
 def testScopedExportAndImport(self):
   """Verifies scoped export then import reproduces each MetaGraphDef."""
   tmp_dir = _TestDir("scoped_export_import")
   files = [
       "exported_hidden1.pbtxt", "exported_hidden2.pbtxt",
       "exported_softmax_linear.pbtxt"
   ]
   source_protos = self._testScopedExport(tmp_dir, files)
   roundtrip_protos = self._testScopedImport(tmp_dir, files)
   for src, dst in zip(source_protos, roundtrip_protos):
     # The unbound input strings differ with the C API enabled ("images" vs
     # "images:0") because of the original import_graph_def code vs.
     # ImportGraphDef in C++; remove them so the protos compare equal.
     # TODO(skyewm): update the pbtxts once _USE_C_API is removed.
     del src.collection_def["unbound_inputs"]
     del dst.collection_def["unbound_inputs"]
     test_util.assert_meta_graph_protos_equal(self, src, dst)
# Example #13
 def assert_summaries(self,
                      test_case,
                      expected_logdir=None,
                      expected_graph=None,
                      expected_summaries=None,
                      expected_added_graphs=None,
                      expected_added_meta_graphs=None,
                      expected_session_logs=None):
     """Assert that the expected items were recorded by this summary writer.

     Every check is optional: a ``None`` argument skips that assertion.

     Args:
       test_case: Test case used to raise assertion failures.
       expected_logdir: Expected log directory, if any.
       expected_graph: Graph object the writer must hold (identity check).
       expected_summaries: Dict mapping step -> {tag: simple_value}.
       expected_added_graphs: Expected list of graphs added to the writer.
       expected_added_meta_graphs: Expected list of added MetaGraphDefs.
       expected_session_logs: Expected list of recorded session logs.
     """
     if expected_logdir is not None:
         test_case.assertEqual(expected_logdir, self._logdir)
     if expected_graph is not None:
         test_case.assertTrue(expected_graph is self._graph)
     expected_summaries = expected_summaries or {}
     for step in expected_summaries:
         test_case.assertTrue(
             step in self._summaries,
             msg='Missing step %s from %s.' % (step, self._summaries.keys()))
         # global_step/sec is written by the Supervisor in a separate thread,
         # so how many of those land here is non-deterministic — skip the tag.
         observed_values = {
             v.tag: v.simple_value
             for step_summary in self._summaries[step]
             for v in step_summary.value
             if 'global_step/sec' != v.tag
         }
         test_case.assertEqual(expected_summaries[step], observed_values)
     if expected_added_graphs is not None:
         test_case.assertEqual(expected_added_graphs, self._added_graphs)
     if expected_added_meta_graphs is not None:
         test_case.assertEqual(len(expected_added_meta_graphs),
                               len(self._added_meta_graphs))
         for want, got in zip(expected_added_meta_graphs,
                              self._added_meta_graphs):
             test_util.assert_meta_graph_protos_equal(test_case, want, got)
     if expected_session_logs is not None:
         test_case.assertEqual(expected_session_logs,
                               self._added_session_logs)
# Example #14
  def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests export and importing a graph across scopes.

    Args:
      graph_fn: A closure that creates a graph on the current scope.
      use_resource: A bool indicating whether or not to use ResourceVariables.
    """
    # Build under nested scopes; exporting with scope "dropA/dropB" keeps
    # only the inner "keepA" portion.
    with ops.Graph().as_default() as original_graph:
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn(use_resource=use_resource)
    exported = meta_graph.export_scoped_meta_graph(
        graph=original_graph, export_scope="dropA/dropB")[0]

    with ops.Graph().as_default() as imported_graph:
      meta_graph.import_scoped_meta_graph(exported, import_scope="importA")

    # Rebuild, from scratch, the graph the import should have produced.
    with ops.Graph().as_default() as expected_graph:
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn(use_resource=use_resource)

    result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
    expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]

    if use_resource:
      # shared_name attrs are orthogonal to scopes and are not updated on
      # export/import, so clear them everywhere before comparing.
      for proto in (result, expected):
        for node in proto.graph_def.node:
          attr_value = node.attr.get("shared_name", None)
          if attr_value and attr_value.HasField("s") and attr_value.s:
            node.attr["shared_name"].s = b""

    test_util.assert_meta_graph_protos_equal(self, expected, result)
  def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests export and importing a graph across scopes.

    Args:
      graph_fn: A closure that creates a graph on the current scope.
      use_resource: A bool indicating whether or not to use ResourceVariables.
    """
    with ops.Graph().as_default() as source_graph:
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn(use_resource=use_resource)
    # Export strips the "dropA/dropB" prefix, leaving only "keepA".
    scoped_proto = meta_graph.export_scoped_meta_graph(
        graph=source_graph, export_scope="dropA/dropB")[0]

    with ops.Graph().as_default() as roundtrip_graph:
      meta_graph.import_scoped_meta_graph(scoped_proto,
                                          import_scope="importA")

    # Independently construct the graph the import is expected to match.
    with ops.Graph().as_default() as reference_graph:
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn(use_resource=use_resource)

    result = meta_graph.export_scoped_meta_graph(graph=roundtrip_graph)[0]
    expected = meta_graph.export_scoped_meta_graph(graph=reference_graph)[0]

    if use_resource:
      # shared_name attributes are not rewritten by export/import (they are
      # orthogonal to scoping), so blank out any non-empty values first.
      for proto in (result, expected):
        for node in proto.graph_def.node:
          attr_value = node.attr.get("shared_name", None)
          if attr_value and attr_value.HasField("s") and attr_value.s:
            node.attr["shared_name"].s = b""

    test_util.assert_meta_graph_protos_equal(self, expected, result)
# Example #16
 def assert_summaries(self,
                      test_case,
                      expected_logdir=None,
                      expected_graph=None,
                      expected_summaries=None,
                      expected_added_graphs=None,
                      expected_added_meta_graphs=None,
                      expected_session_logs=None):
   """Assert that expected items have been added to this summary writer.

   Any argument left as ``None`` disables the corresponding check.

   Args:
     test_case: Test case through which assertion failures are raised.
     expected_logdir: Expected log directory, if checked.
     expected_graph: Graph the writer must reference (identity comparison).
     expected_summaries: Dict of step -> {tag: simple_value}.
     expected_added_graphs: Expected list of graphs added to the writer.
     expected_added_meta_graphs: Expected list of added MetaGraphDefs.
     expected_session_logs: Expected list of recorded session logs.
   """
   if expected_logdir is not None:
     test_case.assertEqual(expected_logdir, self._logdir)
   if expected_graph is not None:
     test_case.assertTrue(expected_graph is self._graph)
   expected_summaries = expected_summaries or {}
   for step in expected_summaries:
     test_case.assertTrue(
         step in self._summaries,
         msg='Missing step %s from %s.' % (step, self._summaries.keys()))
     # Skip global_step/sec: the Supervisor writes it from another thread,
     # so the number of occurrences is non-deterministic.
     simple_values = {
         v.tag: v.simple_value
         for step_summary in self._summaries[step]
         for v in step_summary.value
         if 'global_step/sec' != v.tag
     }
     test_case.assertEqual(expected_summaries[step], simple_values)
   if expected_added_graphs is not None:
     test_case.assertEqual(expected_added_graphs, self._added_graphs)
   if expected_added_meta_graphs is not None:
     test_case.assertEqual(len(expected_added_meta_graphs),
                           len(self._added_meta_graphs))
     for want, got in zip(expected_added_meta_graphs,
                          self._added_meta_graphs):
       test_util.assert_meta_graph_protos_equal(test_case, want, got)
   if expected_session_logs is not None:
     test_case.assertEqual(expected_session_logs, self._added_session_logs)