Esempio n. 1
0
 def test_get_node_def_from_graph(self):
   """Checks lookup of a NodeDef by name inside a GraphDef.

   Looking up an existing name must return the very same NodeDef object that
   was added to the graph; looking up an unknown name must return None.
   """
   graph = graph_pb2.GraphDef()
   added_node = graph.node.add()
   added_node.name = "foo"

   found = test_util.get_node_def_from_graph("foo", graph)
   self.assertIs(found, added_node)

   missing = test_util.get_node_def_from_graph("bar", graph)
   self.assertIsNone(missing)
Esempio n. 2
0
  def testDefaultAttrStripping(self):
    """Verifies that default attributes are stripped from a graph def.

    The Complex op declares two attrs with defaults: "T" (float32) and "Tout"
    (complex64). With float32 inputs both attrs equal their defaults, so they
    are omitted when strip_default_attrs=True and kept when it is False. With
    float64 inputs ("T"=float64, "Tout"=complex128) they differ from the
    defaults and must survive stripping.
    """

    def export_complex_node(strip):
      # Export the current default graph and return the resulting meta graph
      # together with the NodeDef of the node named "complex".
      exported, _ = meta_graph.export_scoped_meta_graph(
          graph_def=ops.get_default_graph().as_graph_def(),
          strip_default_attrs=strip)
      node = test_util.get_node_def_from_graph("complex", exported.graph_def)
      return exported, node

    # float32 inputs: "T"/"Tout" match the op defaults.
    with self.cached_session():
      real_part = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
      imag_part = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_part, imag_part, name="complex")

      # Stripping enabled: default-valued attrs are dropped.
      exported, node = export_complex_node(strip=True)
      for attr_name in ("T", "Tout"):
        self.assertNotIn(attr_name, node.attr)
      self.assertTrue(exported.meta_info_def.stripped_default_attrs)

      # Stripping disabled: all attrs are kept.
      exported, node = export_complex_node(strip=False)
      for attr_name in ("T", "Tout"):
        self.assertIn(attr_name, node.attr)
      self.assertFalse(exported.meta_info_def.stripped_default_attrs)

    # float64 inputs: attrs are not defaults, so stripping must keep them.
    with self.session(graph=ops.Graph()):
      real_part = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
      imag_part = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
      math_ops.complex(real_part, imag_part, name="complex")
      exported, node = export_complex_node(strip=True)
      self.assertEqual(node.attr["T"].type, dtypes.float64)
      self.assertEqual(node.attr["Tout"].type, dtypes.complex128)
      self.assertTrue(exported.meta_info_def.stripped_default_attrs)
  def testDefaultAttrStripping(self):
    """Verifies default-valued attributes are removed on meta graph export.

    Complex has two attrs with defaults: "T" (float32) and "Tout" (complex64).
    """
    default_attrs = ("T", "Tout")

    # Case 1: float32 inputs, so "T"/"Tout" equal their defaults. They are
    # stripped only when strip_default_attrs is enabled.
    with self.cached_session():
      re_val = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
      im_val = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(re_val, im_val, name="complex")

      # (strip flag, attr-presence check, stripped_default_attrs check)
      expectations = (
          (True, self.assertNotIn, self.assertTrue),
          (False, self.assertIn, self.assertFalse),
      )
      for strip, check_attr, check_flag in expectations:
        mg_def, _ = meta_graph.export_scoped_meta_graph(
            graph_def=ops.get_default_graph().as_graph_def(),
            strip_default_attrs=strip)
        complex_def = test_util.get_node_def_from_graph("complex",
                                                        mg_def.graph_def)
        for attr_name in default_attrs:
          check_attr(attr_name, complex_def.attr)
        check_flag(mg_def.meta_info_def.stripped_default_attrs)

    # Case 2: float64 inputs, so "T" is float64 and "Tout" is complex128.
    # These are not the defaults and must survive stripping.
    with self.session(graph=ops.Graph()):
      re_val = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
      im_val = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
      math_ops.complex(re_val, im_val, name="complex")
      mg_def, _ = meta_graph.export_scoped_meta_graph(
          graph_def=ops.get_default_graph().as_graph_def(),
          strip_default_attrs=True)
      complex_def = test_util.get_node_def_from_graph("complex",
                                                      mg_def.graph_def)
      expected_types = (("T", dtypes.float64), ("Tout", dtypes.complex128))
      for attr_name, expected in expected_types:
        self.assertEqual(complex_def.attr[attr_name].type, expected)
      self.assertTrue(mg_def.meta_info_def.stripped_default_attrs)
Esempio n. 4
0
  def testDefaultAttrStrippingUnregisteredOps(self):
    """Verifies that nodes with un-registered ops are not stripped.

    The graph holds a single node using op "unreg_op" with an explicit attr
    "attr_1"=1. Building the meta graph with strip_default_attrs=True must
    leave that attr untouched while still flagging the meta info as stripped.
    """
    gdef = graph_pb2.GraphDef()
    unreg_node = gdef.node.add()
    unreg_node.name = "node_with_unreg_op"
    unreg_node.op = "unreg_op"
    unreg_node.attr["attr_1"].i = 1

    # Add one default-initialized entry to the stripped op list.
    info = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
    info.stripped_op_list.op.add()

    with self.cached_session():
      exported = meta_graph.create_meta_graph_def(
          meta_info_def=info,
          graph_def=gdef,
          strip_default_attrs=True)
      found = test_util.get_node_def_from_graph("node_with_unreg_op",
                                                exported.graph_def)
      self.assertEqual(found.attr["attr_1"].i, 1)
      self.assertTrue(exported.meta_info_def.stripped_default_attrs)
  def testDefaultAttrStrippingUnregisteredOps(self):
    """Attributes of nodes with un-registered ops survive default stripping."""
    node_name = "node_with_unreg_op"

    graph_def = graph_pb2.GraphDef()
    unregistered = graph_def.node.add()
    unregistered.name = node_name
    unregistered.op = "unreg_op"
    unregistered.attr["attr_1"].i = 1

    meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
    meta_info_def.stripped_op_list.op.add()

    with self.cached_session():
      result = meta_graph.create_meta_graph_def(
          meta_info_def=meta_info_def,
          graph_def=graph_def,
          strip_default_attrs=True)
      node_def = test_util.get_node_def_from_graph(node_name, result.graph_def)
      # The explicitly set attr value must be left untouched...
      self.assertEqual(node_def.attr["attr_1"].i, 1)
      # ...even though stripping was requested and recorded in the meta info.
      self.assertTrue(result.meta_info_def.stripped_default_attrs)
Esempio n. 6
0
 def test_get_node_def_from_graph(self):
   """Verifies get_node_def_from_graph returns the added node, else None."""
   gdef = graph_pb2.GraphDef()
   foo_node = gdef.node.add()
   foo_node.name = "foo"

   def lookup(name):
     # Thin wrapper so both checks below query the same GraphDef.
     return test_util.get_node_def_from_graph(name, gdef)

   self.assertIs(lookup("foo"), foo_node)
   self.assertIsNone(lookup("bar"))
  def testStripDefaultAttrs(self):
    """Checks strip_default_attrs handling across SavedModel save and load.

    Two meta graphs holding the same float32 Complex op are saved: "foo" with
    strip_default_attrs=True and "bar" with strip_default_attrs=False. Loading
    "foo" through the loader must restore the stripped defaults. The raw
    on-disk protos must show them stripped in "foo" and present in "bar".
    """
    export_dir = self._get_export_dir("test_strip_default_attrs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled.
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], strip_default_attrs=True)

    # Add a graph with the same float32 variables and a Complex Op composing
    # them with strip_default_attrs disabled.
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph(["bar"], strip_default_attrs=False)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Loading graph "foo" via the loader must restore the defaults for the
    # "Complex" node based on the "Complex" OpDef in the Op registry. The
    # session is only needed for the load itself; close it instead of leaking
    # it for the rest of the test.
    with session.Session(graph=ops.Graph()) as sess:
      loaded_meta_graph_def = loader.load(sess, ["foo"], export_dir)
    complex_node = test_util.get_node_def_from_graph(
        "complex", loaded_meta_graph_def.graph_def)
    self.assertIn("T", complex_node.attr)
    self.assertIn("Tout", complex_node.attr)

    # Parse the SavedModel from disk as-is to verify default attrs are
    # stripped.
    # pylint: disable=protected-access
    saved_model_pb = loader_impl._parse_saved_model(export_dir)
    # pylint: enable=protected-access
    self.assertIsNotNone(saved_model_pb)

    # Pick out the "foo" and "bar" meta graphs by their exact tag sets.
    meta_graph_foo_def = None
    meta_graph_bar_def = None
    for candidate in saved_model_pb.meta_graphs:
      tags = set(candidate.meta_info_def.tags)
      if tags == {"foo"}:
        meta_graph_foo_def = candidate
      elif tags == {"bar"}:
        meta_graph_bar_def = candidate

    self.assertIsNotNone(meta_graph_foo_def)
    self.assertIsNotNone(meta_graph_bar_def)

    # "Complex" Op has 2 attributes with defaults:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)

    # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
    # Graph "foo" was saved with strip_default_attrs set to True.
    node_def = test_util.get_node_def_from_graph("complex",
                                                 meta_graph_foo_def.graph_def)
    self.assertNotIn("T", node_def.attr)
    self.assertNotIn("Tout", node_def.attr)

    # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
    # Graph "bar" was saved with strip_default_attrs set to False.
    node_def = test_util.get_node_def_from_graph("complex",
                                                 meta_graph_bar_def.graph_def)
    self.assertIn("T", node_def.attr)
    self.assertIn("Tout", node_def.attr)
  def testStripDefaultAttrs(self):
    """Checks strip_default_attrs handling across SavedModel save and load.

    Two meta graphs holding the same float32 Complex op are saved: "foo" with
    strip_default_attrs=True and "bar" with strip_default_attrs=False. Loading
    "foo" through the loader must restore the stripped defaults. The raw
    on-disk protos must show them stripped in "foo" and present in "bar".
    """
    export_dir = self._get_export_dir("test_strip_default_attrs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled.
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], strip_default_attrs=True)

    # Add a graph with the same float32 variables and a Complex Op composing
    # them with strip_default_attrs disabled.
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph(["bar"], strip_default_attrs=False)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Loading graph "foo" via the loader must restore the defaults for the
    # "Complex" node based on the "Complex" OpDef in the Op registry. The
    # session is only needed for the load itself; close it instead of leaking
    # it for the rest of the test.
    with session.Session(graph=ops.Graph()) as sess:
      loaded_meta_graph_def = loader.load(sess, ["foo"], export_dir)
    complex_node = test_util.get_node_def_from_graph(
        "complex", loaded_meta_graph_def.graph_def)
    self.assertIn("T", complex_node.attr)
    self.assertIn("Tout", complex_node.attr)

    # Parse the SavedModel from disk as-is to verify default attrs are
    # stripped.
    # pylint: disable=protected-access
    saved_model_pb = loader_impl._parse_saved_model(export_dir)
    # pylint: enable=protected-access
    self.assertIsNotNone(saved_model_pb)

    # Pick out the "foo" and "bar" meta graphs by their exact tag sets.
    meta_graph_foo_def = None
    meta_graph_bar_def = None
    for candidate in saved_model_pb.meta_graphs:
      tags = set(candidate.meta_info_def.tags)
      if tags == {"foo"}:
        meta_graph_foo_def = candidate
      elif tags == {"bar"}:
        meta_graph_bar_def = candidate

    self.assertIsNotNone(meta_graph_foo_def)
    self.assertIsNotNone(meta_graph_bar_def)

    # "Complex" Op has 2 attributes with defaults:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)

    # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
    # Graph "foo" was saved with strip_default_attrs set to True.
    node_def = test_util.get_node_def_from_graph("complex",
                                                 meta_graph_foo_def.graph_def)
    self.assertNotIn("T", node_def.attr)
    self.assertNotIn("Tout", node_def.attr)

    # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
    # Graph "bar" was saved with strip_default_attrs set to False.
    node_def = test_util.get_node_def_from_graph("complex",
                                                 meta_graph_bar_def.graph_def)
    self.assertIn("T", node_def.attr)
    self.assertIn("Tout", node_def.attr)