def testFromLibraryMissingFuncDef(self):
  """Checks that `_from_library` rejects a library lacking a referenced function.

  A single `GradientDef` ties F1 to G1. Two libraries are built from it, each
  carrying only one of the two `FunctionDef`s, and each is expected to raise a
  `ValueError` naming the function that is absent.
  """

  @function.Defun(dtypes.float32, dtypes.float32)
  def G1(x, dy):
    return x * dy

  @function.Defun(dtypes.float32)
  def F1(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  grad_link = function_pb2.GradientDef()
  grad_link.function_name = F1.name
  grad_link.gradient_func = G1.name

  def _expect_missing(present_def, error_regex):
    # Build a library holding the gradient link but only one FunctionDef,
    # then verify that loading it fails with the given message.
    incomplete = function_pb2.FunctionDefLibrary()
    incomplete.gradient.extend([grad_link])
    incomplete.function.extend([present_def])
    with self.assertRaisesRegexp(ValueError, error_regex):
      function._from_library(incomplete)

  # Invalid library: only F1 is defined, so G1 is the missing one.
  _expect_missing(F1.definition,
                  "FunctionDefLibrary missing 'G1_........' FunctionDef")

  # Invalid library: only G1 is defined, so F1 is the missing one.
  _expect_missing(G1.definition,
                  "FunctionDefLibrary missing 'F1_........' FunctionDef")
def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self):
  """Checks `graph_defs_equal` ignores node order inside a library function.

  Two `GraphDef`s are built whose libraries hold the same function, except
  that the second lists the function's `node_def` entries in reverse.
  """

  @function.Defun(dtypes.float32)
  def F1(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  original_fn = F1.definition

  # Duplicate the function, then swap in its nodes back to front.
  permuted_fn = function_pb2.FunctionDef()
  permuted_fn.CopyFrom(original_fn)
  permuted_fn.ClearField("node_def")
  permuted_fn.node_def.extend(reversed(original_fn.node_def))

  forward_graph = graph_pb2.GraphDef()
  forward_graph.library.function.extend([original_fn])

  backward_graph = graph_pb2.GraphDef()
  backward_graph.library.function.extend([permuted_fn])

  are_equal = graph_util.graph_defs_equal(forward_graph, backward_graph)
  self.assertTrue(are_equal)
def testFromLibraryCyclicGradFuncs(self):
  """Checks that `_from_library` rejects mutually-referencing gradients.

  The library declares F2 as the gradient of F1 and F1 as the gradient of
  F2; loading it is expected to raise a `ValueError`.
  """

  @function.Defun(dtypes.float32)
  def F1(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  @function.Defun(dtypes.float32)
  def F2(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  cyclic_library = function_pb2.FunctionDefLibrary()
  cyclic_library.function.extend([F1.definition, F2.definition])

  # Invalid on purpose: each function names the other as its gradient.
  for primal, grad in ((F1, F2), (F2, F1)):
    link = cyclic_library.gradient.add()
    link.function_name = primal.name
    link.gradient_func = grad.name

  expected_error = "FunctionDefLibrary contains cyclic gradient functions!"
  with self.assertRaisesRegexp(ValueError, expected_error):
    function._from_library(cyclic_library)
def testGraphDefsWithPermutedFunctionsCompareEqual(self):
  """Checks `graph_defs_equal` ignores the order of functions in a library.

  Both `GraphDef`s carry the same two function definitions; the second lists
  them in the opposite order.
  """

  @function.Defun(dtypes.float32)
  def F1(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  @function.Defun(dtypes.float32)
  def F2(x):
    return math_ops.exp(x)

  fn_defs = [F1.definition, F2.definition]

  in_order_graph = graph_pb2.GraphDef()
  in_order_graph.library.function.extend(fn_defs)

  # Same definitions, listed last-to-first.
  swapped_graph = graph_pb2.GraphDef()
  swapped_graph.library.function.extend(fn_defs[::-1])

  are_equal = graph_util.graph_defs_equal(in_order_graph, swapped_graph)
  self.assertTrue(are_equal)
def testGraphDefsWithFunctionLibsCompareEqual(self):
  """Checks `graph_defs_equal` is true for graphs with identical libraries.

  Two separate `GraphDef`s each receive a copy of the same single-function
  library and are then compared.
  """

  @function.Defun(dtypes.float32)
  def F1(x):
    return math_ops.exp(x) - math_ops.exp(-x)

  shared_library = function_pb2.FunctionDefLibrary()
  shared_library.function.extend([F1.definition])

  first_graph = graph_pb2.GraphDef()
  second_graph = graph_pb2.GraphDef()
  for graph_def in (first_graph, second_graph):
    # Each graph gets its own copy of the library contents.
    graph_def.library.CopyFrom(shared_library)

  are_equal = graph_util.graph_defs_equal(first_graph, second_graph)
  self.assertTrue(are_equal)
def test_meta_graph_transform(self):
  """Checks `meta_graph_transform` with the 'strip_unused_nodes' transform.

  The base graph has placeholders a, b and c plus the products a * b and
  b * c. The expected graph has only a, b and a * b, and is exported with the
  tag 'tag_ab'. The base MetaGraphDef is transformed with inputs ['a', 'b']
  and output 'mul:0', and the result must equal the expected MetaGraphDef.
  """
  # Base graph: contains the extra placeholder c and the product b * c.
  with ops.Graph().as_default():
    with tf_session.Session(''):
      a = array_ops.placeholder(dtypes.int64, [1], name='a')
      b = array_ops.placeholder(dtypes.int64, [1], name='b')
      c = array_ops.placeholder(dtypes.int64, [1], name='c')
      _ = a * b
      _ = b * c
      base_meta_graph_def = saver.export_meta_graph()

  # Expected graph: only a, b and their product, exported with the tag.
  with ops.Graph().as_default():
    with tf_session.Session(''):
      a = array_ops.placeholder(dtypes.int64, [1], name='a')
      b = array_ops.placeholder(dtypes.int64, [1], name='b')
      _ = a * b
      meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
      meta_info_def.tags.append('tag_ab')
      expected_meta_graph_def = saver.export_meta_graph(
          meta_info_def=meta_info_def)
      # Graph rewriter clears versions field, so we expect that.
      expected_meta_graph_def.graph_def.ClearField('versions')
      # Graph rewriter adds an empty library field, so we expect that.
      expected_meta_graph_def.graph_def.library.CopyFrom(
          function_pb2.FunctionDefLibrary())

  input_names = ['a', 'b']
  output_names = ['mul:0']
  transforms = ['strip_unused_nodes']
  tags = ['tag_ab']
  # The leftover debug print of the whole base MetaGraphDef was removed here;
  # it only added noise to the test output.
  transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
      base_meta_graph_def, input_names, output_names, transforms, tags)

  self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
def testFromLibraryEmptyLib(self):
  """Checks that `_from_library` returns nothing for an empty library."""
  empty_library = function_pb2.FunctionDefLibrary()
  recovered = function._from_library(empty_library)
  self.assertEqual(len(recovered), 0)