Exemple #1
0
  def testSimpleGraphdefsCompareEqual(self):
    """Two GraphDefs built from the same three nodes must compare equal."""

    def build_graph():
      # Each call produces a fresh GraphDef with an identical node list.
      graph = graph_pb2.GraphDef()
      nodes = [
          self.create_constant_node_def(
              "C", 1, dtypes.float32, inputs=["^I"]),
          self.create_node_def("Identity", "I", ["Base"]),
          self.create_node_def("BaseOp", "Base", []),
      ]
      graph.node.extend(nodes)
      return graph

    lhs = build_graph()
    rhs = build_graph()

    self.assertTrue(graph_util.graph_defs_equal(lhs, rhs))
Exemple #2
0
  def testGraphdefsWithNanCompareNonEqual(self):
    """Identically built GraphDefs holding a NaN constant are not equal.

    graph_defs_equal is called with its default arguments here and is
    expected to report the two graphs as different.
    """
    graphs = []
    for _ in range(2):
      # Both iterations build the same node list around a NaN constant.
      graph = graph_pb2.GraphDef()
      graph.node.extend([
          self.create_constant_node_def(
              "C", float("nan"), dtypes.float32, inputs=["^I"]),
          self.create_node_def("Identity", "I", ["Base"]),
          self.create_node_def("BaseOp", "Base", []),
      ])
      graphs.append(graph)

    first, second = graphs
    self.assertFalse(graph_util.graph_defs_equal(first, second))
Exemple #3
0
  def testSimpleGraphdefEqualityWithNansEqual(self):
    """GraphDefs with a NaN constant are equal under treat_nan_as_equal=True."""

    def nan_graph():
      # Build a three-node graph whose constant node carries a NaN value.
      graph = graph_pb2.GraphDef()
      nan_const = self.create_constant_node_def(
          "C", float("nan"), dtypes.float32, inputs=["^I"])
      identity = self.create_node_def("Identity", "I", ["Base"])
      base = self.create_node_def("BaseOp", "Base", [])
      graph.node.extend([nan_const, identity, base])
      return graph

    lhs = nan_graph()
    rhs = nan_graph()

    are_equal = graph_util.graph_defs_equal(
        lhs, rhs, treat_nan_as_equal=True)
    self.assertTrue(are_equal)
Exemple #4
0
    def testGraphDefsWithPermutedFunctionsCompareEqual(self):
        """The order of functions in a graph's library must not affect equality.

        Two GraphDefs are built whose libraries hold the same two function
        definitions in opposite orders; graph_defs_equal is expected to
        report them as equal.
        """

        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        @function.Defun(dtypes.float32)
        def F2(x):
            return math_ops.exp(x)

        def graph_with_functions(function_defs):
            # Wrap the definitions in a library and copy it into a new graph.
            lib = function_pb2.FunctionDefLibrary()
            lib.function.extend(function_defs)
            graph = graph_pb2.GraphDef()
            graph.library.CopyFrom(lib)
            return graph

        first_def = F1.definition
        second_def = F2.definition

        forward_graph = graph_with_functions([first_def, second_def])
        backward_graph = graph_with_functions([second_def, first_def])

        self.assertTrue(
            graph_util.graph_defs_equal(forward_graph, backward_graph))
Exemple #5
0
    def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self):
        """The order of nodes inside a library function must not affect equality.

        One GraphDef holds the function as defined; the other holds a copy of
        it whose node_def entries are in reverse order. graph_defs_equal is
        expected to report the two graphs as equal.
        """

        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        def graph_with_function(function_def):
            # Wrap the single definition in a library inside a new graph.
            lib = function_pb2.FunctionDefLibrary()
            lib.function.extend([function_def])
            graph = graph_pb2.GraphDef()
            graph.library.CopyFrom(lib)
            return graph

        original_fn = F1.definition
        original_graph = graph_with_function(original_fn)

        # Clone the function, then replace its nodes with the same nodes in
        # reverse order.
        permuted_fn = function_pb2.FunctionDef()
        permuted_fn.CopyFrom(original_fn)
        del permuted_fn.node_def[:]
        permuted_fn.node_def.extend(reversed(original_fn.node_def))
        permuted_graph = graph_with_function(permuted_fn)

        self.assertTrue(
            graph_util.graph_defs_equal(original_graph, permuted_graph))