def test_get_node_def_from_graph(self):
    """Checks NodeDef lookup by name for a present and an absent name."""
    graph = graph_pb2.GraphDef()
    expected = graph.node.add()
    expected.name = "foo"

    # A present name must yield the very same NodeDef object (identity check).
    self.assertIs(test_util.get_node_def_from_graph("foo", graph), expected)
    # A name that is not in the graph must yield None.
    self.assertIsNone(test_util.get_node_def_from_graph("bar", graph))
def testDefaultAttrStripping(self):
    """Verifies that default attributes are stripped from a graph def.

    Complex Op has 2 attributes with defaults:
      o "T"    : float32.
      o "Tout" : complex64.

    With float32 inputs both attrs equal their defaults, so they must be
    absent when exporting with strip_default_attrs=True and present when
    exporting with strip_default_attrs=False. With float64 inputs the attrs
    ("T"=float64, "Tout"=complex128) differ from the defaults, so they must
    survive even with strip_default_attrs=True.
    """
    # When inputs to the Complex Op are float32 instances, "T" maps to float32
    # and "Tout" maps to complex64. Since these attr values map to their
    # defaults, they must be stripped unless stripping of default attrs is
    # disabled.
    with self.cached_session():
        real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
        imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")

        # strip_default_attrs is enabled.
        meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
            graph_def=ops.get_default_graph().as_graph_def(),
            strip_default_attrs=True)
        node_def = test_util.get_node_def_from_graph(
            "complex", meta_graph_def.graph_def)
        # Report a missing node as an assertion failure, not an AttributeError.
        self.assertIsNotNone(node_def)
        self.assertNotIn("T", node_def.attr)
        self.assertNotIn("Tout", node_def.attr)
        self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)

        # strip_default_attrs is disabled.
        meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
            graph_def=ops.get_default_graph().as_graph_def(),
            strip_default_attrs=False)
        node_def = test_util.get_node_def_from_graph(
            "complex", meta_graph_def.graph_def)
        self.assertIsNotNone(node_def)
        self.assertIn("T", node_def.attr)
        self.assertIn("Tout", node_def.attr)
        self.assertFalse(meta_graph_def.meta_info_def.stripped_default_attrs)

    # When inputs to the Complex Op are float64 instances, "T" maps to float64
    # and "Tout" maps to complex128. Since these attr values don't map to their
    # defaults, they must not be stripped.
    with self.session(graph=ops.Graph()):
        real_num = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
        imag_num = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")

        meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
            graph_def=ops.get_default_graph().as_graph_def(),
            strip_default_attrs=True)
        node_def = test_util.get_node_def_from_graph(
            "complex", meta_graph_def.graph_def)
        self.assertIsNotNone(node_def)
        # Check presence before indexing: reading a missing key from a
        # protobuf message map inserts a default entry, which would mask a
        # wrongly-stripped attr behind a confusing type mismatch.
        self.assertIn("T", node_def.attr)
        self.assertIn("Tout", node_def.attr)
        # AttrValue.type holds the raw DataType enum, so compare against it.
        self.assertEqual(node_def.attr["T"].type,
                         dtypes.float64.as_datatype_enum)
        self.assertEqual(node_def.attr["Tout"].type,
                         dtypes.complex128.as_datatype_enum)
        self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
def testDefaultAttrStripping(self):
    """Verifies that default attributes are stripped from a graph def.

    Complex Op has 2 attributes with defaults:
      o "T"    : float32.
      o "Tout" : complex64.

    With float32 inputs both attrs equal their defaults, so they must be
    absent when exporting with strip_default_attrs=True and present when
    exporting with strip_default_attrs=False. With float64 inputs the attrs
    ("T"=float64, "Tout"=complex128) differ from the defaults, so they must
    survive even with strip_default_attrs=True.
    """
    # When inputs to the Complex Op are float32 instances, "T" maps to float32
    # and "Tout" maps to complex64. Since these attr values map to their
    # defaults, they must be stripped unless stripping of default attrs is
    # disabled.
    with self.cached_session():
        real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
        imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")

        # strip_default_attrs is enabled.
        meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
            graph_def=ops.get_default_graph().as_graph_def(),
            strip_default_attrs=True)
        node_def = test_util.get_node_def_from_graph(
            "complex", meta_graph_def.graph_def)
        # Report a missing node as an assertion failure, not an AttributeError.
        self.assertIsNotNone(node_def)
        self.assertNotIn("T", node_def.attr)
        self.assertNotIn("Tout", node_def.attr)
        self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)

        # strip_default_attrs is disabled.
        meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
            graph_def=ops.get_default_graph().as_graph_def(),
            strip_default_attrs=False)
        node_def = test_util.get_node_def_from_graph(
            "complex", meta_graph_def.graph_def)
        self.assertIsNotNone(node_def)
        self.assertIn("T", node_def.attr)
        self.assertIn("Tout", node_def.attr)
        self.assertFalse(meta_graph_def.meta_info_def.stripped_default_attrs)

    # When inputs to the Complex Op are float64 instances, "T" maps to float64
    # and "Tout" maps to complex128. Since these attr values don't map to their
    # defaults, they must not be stripped.
    with self.session(graph=ops.Graph()):
        real_num = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
        imag_num = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")

        meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
            graph_def=ops.get_default_graph().as_graph_def(),
            strip_default_attrs=True)
        node_def = test_util.get_node_def_from_graph(
            "complex", meta_graph_def.graph_def)
        self.assertIsNotNone(node_def)
        # Check presence before indexing: reading a missing key from a
        # protobuf message map inserts a default entry, which would mask a
        # wrongly-stripped attr behind a confusing type mismatch.
        self.assertIn("T", node_def.attr)
        self.assertIn("Tout", node_def.attr)
        # AttrValue.type holds the raw DataType enum, so compare against it.
        self.assertEqual(node_def.attr["T"].type,
                         dtypes.float64.as_datatype_enum)
        self.assertEqual(node_def.attr["Tout"].type,
                         dtypes.complex128.as_datatype_enum)
        self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
def testDefaultAttrStrippingUnregisteredOps(self):
    """Verifies that nodes with un-registered ops are not stripped.

    The graph holds one node of op type "unreg_op" carrying an explicitly set
    attr ("attr_1" = 1). After create_meta_graph_def runs with
    strip_default_attrs=True, that attr must still be present with its value,
    and the meta info must record that stripping was requested.
    """
    # Meta info whose stripped op list contains a single empty OpDef entry.
    meta_info = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
    meta_info.stripped_op_list.op.add()

    graph = graph_pb2.GraphDef()
    unreg_node = graph.node.add()
    unreg_node.name = "node_with_unreg_op"
    unreg_node.op = "unreg_op"
    unreg_node.attr["attr_1"].i = 1

    with self.cached_session():
        exported = meta_graph.create_meta_graph_def(
            meta_info_def=meta_info,
            graph_def=graph,
            strip_default_attrs=True)
        found = test_util.get_node_def_from_graph(
            "node_with_unreg_op", exported.graph_def)
        self.assertEqual(found.attr["attr_1"].i, 1)
        self.assertTrue(exported.meta_info_def.stripped_default_attrs)
def testDefaultAttrStrippingUnregisteredOps(self):
    """Verifies that nodes with un-registered ops are not stripped.

    The graph holds one node of op type "unreg_op" carrying an explicitly set
    attr ("attr_1" = 1). After create_meta_graph_def runs with
    strip_default_attrs=True, that attr must still be present with its value,
    and the meta info must record that stripping was requested.
    """
    # Meta info whose stripped op list contains a single empty OpDef entry.
    meta_info = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
    meta_info.stripped_op_list.op.add()

    graph = graph_pb2.GraphDef()
    unreg_node = graph.node.add()
    unreg_node.name = "node_with_unreg_op"
    unreg_node.op = "unreg_op"
    unreg_node.attr["attr_1"].i = 1

    with self.cached_session():
        exported = meta_graph.create_meta_graph_def(
            meta_info_def=meta_info,
            graph_def=graph,
            strip_default_attrs=True)
        found = test_util.get_node_def_from_graph(
            "node_with_unreg_op", exported.graph_def)
        self.assertEqual(found.attr["attr_1"].i, 1)
        self.assertTrue(exported.meta_info_def.stripped_default_attrs)
def test_get_node_def_from_graph(self):
    """Checks NodeDef lookup by name for a present and an absent name."""
    graph = graph_pb2.GraphDef()
    expected = graph.node.add()
    expected.name = "foo"

    # A present name must yield the very same NodeDef object (identity check).
    self.assertIs(test_util.get_node_def_from_graph("foo", graph), expected)
    # A name that is not in the graph must yield None.
    self.assertIsNone(test_util.get_node_def_from_graph("bar", graph))
def testStripDefaultAttrs(self):
    """Saves graphs with and without default-attr stripping and checks both.

    Graph "foo" is saved with strip_default_attrs=True and graph "bar" with
    strip_default_attrs=False. Loading "foo" through the loader must restore
    the "Complex" defaults, while the raw SavedModel proto on disk must show
    them stripped for "foo" and present for "bar".
    """
    export_dir = self._get_export_dir("test_strip_default_attrs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled.
    with session.Session(graph=ops.Graph()) as sess:
        real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
        imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")
        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph_and_variables(
            sess, ["foo"], strip_default_attrs=True)

    # Add a graph with the same float32 variables and a Complex Op composing
    # them with strip_default_attrs disabled.
    with session.Session(graph=ops.Graph()) as sess:
        real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
        imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")
        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph(["bar"], strip_default_attrs=False)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Loading graph "foo" via the loader must restore the defaults for the
    # "Complex" node based on the "Complex" OpDef in the Op registry. The
    # session is used as a context manager so it is closed instead of leaked.
    with session.Session(graph=ops.Graph()) as sess:
        meta_graph_def = loader.load(sess, ["foo"], export_dir)
    complex_node = test_util.get_node_def_from_graph(
        "complex", meta_graph_def.graph_def)
    self.assertIsNotNone(complex_node)
    self.assertIn("T", complex_node.attr)
    self.assertIn("Tout", complex_node.attr)

    # Load graph "foo" from disk as-is to verify default attrs are stripped.
    # pylint: disable=protected-access
    saved_model_pb = loader_impl._parse_saved_model(export_dir)
    # pylint: enable=protected-access
    self.assertIsNotNone(saved_model_pb)

    # Index the saved meta graphs by their tag set. If two share a tag set the
    # later one wins, matching a sequential scan.
    meta_graphs_by_tags = {
        frozenset(candidate.meta_info_def.tags): candidate
        for candidate in saved_model_pb.meta_graphs
    }
    meta_graph_foo_def = meta_graphs_by_tags.get(frozenset(["foo"]))
    meta_graph_bar_def = meta_graphs_by_tags.get(frozenset(["bar"]))
    self.assertIsNotNone(meta_graph_foo_def)
    self.assertIsNotNone(meta_graph_bar_def)

    # "Complex" Op has 2 attributes with defaults:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)

    # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
    # Graph "foo" was saved with strip_default_attrs set to True.
    node_def = test_util.get_node_def_from_graph(
        "complex", meta_graph_foo_def.graph_def)
    self.assertIsNotNone(node_def)
    self.assertNotIn("T", node_def.attr)
    self.assertNotIn("Tout", node_def.attr)

    # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
    # Graph "bar" was saved with strip_default_attrs set to False.
    node_def = test_util.get_node_def_from_graph(
        "complex", meta_graph_bar_def.graph_def)
    self.assertIsNotNone(node_def)
    self.assertIn("T", node_def.attr)
    self.assertIn("Tout", node_def.attr)
def testStripDefaultAttrs(self):
    """Saves graphs with and without default-attr stripping and checks both.

    Graph "foo" is saved with strip_default_attrs=True and graph "bar" with
    strip_default_attrs=False. Loading "foo" through the loader must restore
    the "Complex" defaults, while the raw SavedModel proto on disk must show
    them stripped for "foo" and present for "bar".
    """
    export_dir = self._get_export_dir("test_strip_default_attrs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled.
    with session.Session(graph=ops.Graph()) as sess:
        real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
        imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")
        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph_and_variables(
            sess, ["foo"], strip_default_attrs=True)

    # Add a graph with the same float32 variables and a Complex Op composing
    # them with strip_default_attrs disabled.
    with session.Session(graph=ops.Graph()) as sess:
        real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
        imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")
        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph(["bar"], strip_default_attrs=False)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Loading graph "foo" via the loader must restore the defaults for the
    # "Complex" node based on the "Complex" OpDef in the Op registry. The
    # session is used as a context manager so it is closed instead of leaked.
    with session.Session(graph=ops.Graph()) as sess:
        meta_graph_def = loader.load(sess, ["foo"], export_dir)
    complex_node = test_util.get_node_def_from_graph(
        "complex", meta_graph_def.graph_def)
    self.assertIsNotNone(complex_node)
    self.assertIn("T", complex_node.attr)
    self.assertIn("Tout", complex_node.attr)

    # Load graph "foo" from disk as-is to verify default attrs are stripped.
    # pylint: disable=protected-access
    saved_model_pb = loader_impl._parse_saved_model(export_dir)
    # pylint: enable=protected-access
    self.assertIsNotNone(saved_model_pb)

    # Index the saved meta graphs by their tag set. If two share a tag set the
    # later one wins, matching a sequential scan.
    meta_graphs_by_tags = {
        frozenset(candidate.meta_info_def.tags): candidate
        for candidate in saved_model_pb.meta_graphs
    }
    meta_graph_foo_def = meta_graphs_by_tags.get(frozenset(["foo"]))
    meta_graph_bar_def = meta_graphs_by_tags.get(frozenset(["bar"]))
    self.assertIsNotNone(meta_graph_foo_def)
    self.assertIsNotNone(meta_graph_bar_def)

    # "Complex" Op has 2 attributes with defaults:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)

    # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
    # Graph "foo" was saved with strip_default_attrs set to True.
    node_def = test_util.get_node_def_from_graph(
        "complex", meta_graph_foo_def.graph_def)
    self.assertIsNotNone(node_def)
    self.assertNotIn("T", node_def.attr)
    self.assertNotIn("Tout", node_def.attr)

    # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
    # Graph "bar" was saved with strip_default_attrs set to False.
    node_def = test_util.get_node_def_from_graph(
        "complex", meta_graph_bar_def.graph_def)
    self.assertIsNotNone(node_def)
    self.assertIn("T", node_def.attr)
    self.assertIn("Tout", node_def.attr)