def testFromLibraryMissingFuncDef(self):
        """_from_library raises ValueError when a referenced FunctionDef is absent.

        A single GradientDef declares G1 as the gradient of F1. Two invalid
        libraries are built around it: one contains only F1's FunctionDef
        (G1 missing), the other only G1's (F1 missing). Each must raise a
        ValueError whose message names the missing function.
        """
        @function.Defun(dtypes.float32, dtypes.float32)
        def G1(x, dy):
            return x * dy

        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        # Declare G1 as the gradient function of F1.
        gradient = function_pb2.GradientDef()
        gradient.function_name = F1.name
        gradient.gradient_func = G1.name

        # Invalid library: the gradient references G1, but G1's FunctionDef
        # is not included.
        library = function_pb2.FunctionDefLibrary()
        library.gradient.extend([gradient])
        library.function.extend([F1.definition])

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
        # (removed in Python 3.12). Each '.' in the pattern matches one
        # character of the name suffix after 'G1_' -- presumably a generated
        # fixed-length hash; confirm against Defun's naming scheme.
        with self.assertRaisesRegex(
                ValueError,
                "FunctionDefLibrary missing 'G1_........' FunctionDef"):
            function._from_library(library)

        # Invalid library: the gradient references F1, but F1's FunctionDef
        # is not included.
        library = function_pb2.FunctionDefLibrary()
        library.gradient.extend([gradient])
        library.function.extend([G1.definition])

        with self.assertRaisesRegex(
                ValueError,
                "FunctionDefLibrary missing 'F1_........' FunctionDef"):
            function._from_library(library)
# Example #2
# 0
  def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self):
    """graph_defs_equal ignores the order of node_defs inside a FunctionDef.

    Two GraphDefs are built whose libraries each hold F1's FunctionDef. In the
    second, the FunctionDef's node_def list is reversed. They must compare equal.
    """

    @function.Defun(dtypes.float32)
    def F1(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    original_fdef = F1.definition

    # Copy the FunctionDef, then replace its node_def entries with the same
    # nodes in reverse order.
    permuted_fdef = function_pb2.FunctionDef()
    permuted_fdef.CopyFrom(original_fdef)
    del permuted_fdef.node_def[:]
    permuted_fdef.node_def.extend(reversed(original_fdef.node_def))

    def graph_def_with(fdef):
      # Wrap a single FunctionDef in a library inside a fresh GraphDef.
      lib = function_pb2.FunctionDefLibrary()
      lib.function.extend([fdef])
      gdef = graph_pb2.GraphDef()
      gdef.library.CopyFrom(lib)
      return gdef

    self.assertTrue(
        graph_util.graph_defs_equal(
            graph_def_with(original_fdef), graph_def_with(permuted_fdef)))
    def testFromLibraryCyclicGradFuncs(self):
        """_from_library rejects a library whose gradient functions form a cycle.

        F1's gradient is F2 and F2's gradient is F1. Importing that library
        must raise ValueError.
        """
        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        @function.Defun(dtypes.float32)
        def F2(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        # Create an invalid function def library where F1 has gradient
        # function F2 and F2 has gradient function F1.
        library = function_pb2.FunctionDefLibrary()
        library.function.extend([F1.definition, F2.definition])

        gradient1 = function_pb2.GradientDef()
        gradient1.function_name = F1.name
        gradient1.gradient_func = F2.name

        gradient2 = function_pb2.GradientDef()
        gradient2.function_name = F2.name
        gradient2.gradient_func = F1.name

        library.gradient.extend([gradient1, gradient2])

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
        # (removed in Python 3.12).
        with self.assertRaisesRegex(
                ValueError,
                "FunctionDefLibrary contains cyclic gradient functions!"):
            function._from_library(library)
# Example #4
# 0
    def testGraphDefsWithPermutedFunctionsCompareEqual(self):
        """graph_defs_equal ignores the order of functions in the library.

        One library lists F1 then F2, and the other lists F2 then F1. GraphDefs
        wrapping each library must compare equal.
        """
        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        @function.Defun(dtypes.float32)
        def F2(x):
            return math_ops.exp(x)

        fdefs = [F1.definition, F2.definition]

        def graph_def_from(ordered_fdefs):
            # Build a GraphDef whose library holds ordered_fdefs in order.
            lib = function_pb2.FunctionDefLibrary()
            lib.function.extend(ordered_fdefs)
            gdef = graph_pb2.GraphDef()
            gdef.library.CopyFrom(lib)
            return gdef

        forward = graph_def_from(fdefs)
        backward = graph_def_from(fdefs[::-1])
        self.assertTrue(graph_util.graph_defs_equal(forward, backward))
# Example #5
# 0
    def testGraphDefsWithFunctionLibsCompareEqual(self):
        """Two GraphDefs holding copies of the same library compare equal."""
        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        shared_lib = function_pb2.FunctionDefLibrary()
        shared_lib.function.extend([F1.definition])

        # Both GraphDefs receive an independent copy of the same library.
        first, second = graph_pb2.GraphDef(), graph_pb2.GraphDef()
        for gdef in (first, second):
            gdef.library.CopyFrom(shared_lib)

        self.assertTrue(graph_util.graph_defs_equal(first, second))
# Example #6
# 0
    def test_meta_graph_transform(self):
        """meta_graph_transform with strip_unused_nodes prunes the base graph.

        The base graph has placeholders a, b and c and the products a*b and
        b*c. The transform uses inputs ['a', 'b'], output 'mul:0' and tag
        'tag_ab'. Its result must equal a freshly exported graph that holds
        only a, b and a*b and carries that tag.
        """

        with ops.Graph().as_default():
            with tf_session.Session(''):
                a = array_ops.placeholder(dtypes.int64, [1], name='a')
                b = array_ops.placeholder(dtypes.int64, [1], name='b')
                c = array_ops.placeholder(dtypes.int64, [1], name='c')
                _ = a * b
                _ = b * c
                base_meta_graph_def = saver.export_meta_graph()

        with ops.Graph().as_default():
            with tf_session.Session(''):
                a = array_ops.placeholder(dtypes.int64, [1], name='a')
                b = array_ops.placeholder(dtypes.int64, [1], name='b')
                _ = a * b
                meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
                meta_info_def.tags.append('tag_ab')

                expected_meta_graph_def = saver.export_meta_graph(
                    meta_info_def=meta_info_def)
                # Graph rewriter clears versions field, so we expect that.
                expected_meta_graph_def.graph_def.ClearField('versions')
                # Graph rewriter adds an empty library field, so we expect that.
                expected_meta_graph_def.graph_def.library.CopyFrom(
                    function_pb2.FunctionDefLibrary())

        input_names = ['a', 'b']
        output_names = ['mul:0']
        transforms = ['strip_unused_nodes']
        tags = ['tag_ab']
        transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
            base_meta_graph_def, input_names, output_names, transforms, tags)

        self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
 def testFromLibraryEmptyLib(self):
     """Importing an empty FunctionDefLibrary returns an empty collection."""
     imported = function._from_library(function_pb2.FunctionDefLibrary())
     self.assertEqual(len(imported), 0)