Ejemplo n.º 1
0
    def test_load_graph(self):
        """Round-trip a concrete function through ``save_graph`` / ``load_graph``.

        The function is written once to a ``.pb.txt`` file and once to a ``.pb`` file;
        both must load back as a ``tf.Graph``. Loading either file with the wrong
        ``as_text`` flag must raise the matching protobuf parse error, i.e. the
        ``.pb.txt`` file is expected to hold text and the ``.pb`` file binary data.
        """
        # Import the submodules explicitly: ``import google.protobuf`` alone does not
        # guarantee that ``text_format`` / ``message`` are bound as package attributes.
        from google.protobuf import message, text_format

        concrete_func = self.create_tf_function(concrete=True)

        with tmp_file(suffix=".pb") as path_pb, tmp_file(suffix=".pb.txt") as path_txt:
            cmsml.tensorflow.save_graph(path_txt, concrete_func, variables_to_constants=True)
            cmsml.tensorflow.save_graph(path_pb, concrete_func, variables_to_constants=False)

            self.assertTrue(os.path.exists(path_pb))
            self.assertTrue(os.path.exists(path_txt))

            graph = cmsml.tensorflow.load_graph(path_txt)
            self.assertIsInstance(graph, self.tf.Graph)

            graph = cmsml.tensorflow.load_graph(path_pb)
            self.assertIsInstance(graph, self.tf.Graph)

            # Forcing the wrong parser on each file must fail loudly.
            with self.assertRaises(text_format.ParseError):
                cmsml.tensorflow.load_graph(path_pb, as_text=True)
            with self.assertRaises(message.DecodeError):
                cmsml.tensorflow.load_graph(path_txt, as_text=False)
Ejemplo n.º 2
0
    def test_save_keras_model_v1(self):
        """Save a Keras model created via ``self.tf1`` in several configurations.

        Covers ``.pb`` and ``.pb.txt`` output, with and without converting variables
        to constants, and saving from the Keras backend session instead of the model.
        """
        if self.tf1 is None:
            # self.tf1 may be None in this test class; without it the model cannot be
            # built and the session lookup below would fail with AttributeError.
            self.skipTest("TensorFlow v1 module (self.tf1) is None")

        model = self.create_keras_model(self.tf1)

        # (suffix, variables_to_constants) combinations saved directly from the model.
        for suffix, to_constants in ((".pb", False), (".pb.txt", False), (".pb", True)):
            with self.subTest(suffix=suffix, variables_to_constants=to_constants):
                with tmp_file(suffix=suffix) as path:
                    cmsml.tensorflow.save_graph(path, model,
                                                variables_to_constants=to_constants)
                    self.assertTrue(os.path.exists(path))

        # Saving from the Keras backend session rather than from the model object.
        with tmp_file(suffix=".pb") as path:
            session = self.tf1.keras.backend.get_session()
            cmsml.tensorflow.save_graph(path, session, variables_to_constants=False)
            self.assertTrue(os.path.exists(path))
Ejemplo n.º 3
0
    def test_save_keras_model_v2(self):
        """Save a Keras model built via ``self.tf`` to a ``.pb`` file, unfrozen then frozen."""
        keras_model = self.create_keras_model(self.tf)

        for freeze in (False, True):
            with tmp_file(suffix=".pb") as pb_path:
                cmsml.tensorflow.save_graph(pb_path, keras_model, variables_to_constants=freeze)
                self.assertTrue(os.path.exists(pb_path))
Ejemplo n.º 4
0
    def test_save_empty_polymorphic_function(self):
        """A polymorphic function without inputs can be saved, frozen or not."""
        func = self.create_tf_function(no_input=True)

        for freeze in (False, True):
            with tmp_file(suffix=".pb") as pb_path:
                cmsml.tensorflow.save_graph(pb_path, func, variables_to_constants=freeze)
                self.assertTrue(os.path.exists(pb_path))
Ejemplo n.º 5
0
    def test_save_polymorphic_function_error(self):
        """Saving the default ``create_tf_function()`` result must raise ValueError.

        Checked for both values of ``variables_to_constants``.
        """
        func = self.create_tf_function()

        for freeze in (False, True):
            with self.assertRaises(ValueError), tmp_file(suffix=".pb") as pb_path:
                cmsml.tensorflow.save_graph(pb_path, func, variables_to_constants=freeze)
Ejemplo n.º 6
0
    def test_save_concrete_function(self):
        """A concrete function can be saved to ``.pb`` / ``.pb.txt``, frozen or not."""
        func = self.create_tf_function(concrete=True)

        cases = (
            (".pb", False),
            (".pb.txt", False),
            (".pb", True),
        )
        for suffix, freeze in cases:
            with tmp_file(suffix=suffix) as out_path:
                cmsml.tensorflow.save_graph(out_path, func, variables_to_constants=freeze)
                self.assertTrue(os.path.exists(out_path))
Ejemplo n.º 7
0
    def test_save_graph(self):
        """Save the graph, its GraphDef and its session from ``create_tf1_graph()``.

        Converting variables to constants from a session succeeds when ``output_names``
        is given; without it ``save_graph`` must raise ValueError and leave no file.
        """
        graph, session = self.create_tf1_graph()
        if graph is None or session is None:
            # Report the test as skipped instead of letting it silently pass.
            self.skipTest("create_tf1_graph() returned no graph/session")

        with tmp_file(suffix=".pb") as path:
            cmsml.tensorflow.save_graph(path, graph, variables_to_constants=False)
            self.assertTrue(os.path.exists(path))

        with tmp_file(suffix=".pb.txt") as path:
            cmsml.tensorflow.save_graph(path, graph, variables_to_constants=False)
            self.assertTrue(os.path.exists(path))

        with tmp_file(suffix=".pb") as path:
            cmsml.tensorflow.save_graph(path, graph.as_graph_def(), variables_to_constants=False)
            self.assertTrue(os.path.exists(path))

        with tmp_file(suffix=".pb") as path:
            cmsml.tensorflow.save_graph(path, session, variables_to_constants=False)
            self.assertTrue(os.path.exists(path))

        with tmp_file(suffix=".pb") as path:
            cmsml.tensorflow.save_graph(path, session, variables_to_constants=True,
                                        output_names=["output"])
            self.assertTrue(os.path.exists(path))

        # Missing output_names when freezing a session is an error, and no file is written.
        with tmp_file(suffix=".pb") as path:
            with self.assertRaises(ValueError):
                cmsml.tensorflow.save_graph(path, session, variables_to_constants=True)
            self.assertFalse(os.path.exists(path))
Ejemplo n.º 8
0
    def test_write_summary(self):
        """``write_graph_summary`` must produce a non-empty output directory.

        Exercised twice: once from an in-memory graph (``concrete_func.graph``) and
        once from a graph file written to disk by ``save_graph``.
        """
        concrete_func = self.create_tf_function(concrete=True)

        # Summary from an in-memory graph object.
        with tmp_dir(create=False) as path:
            cmsml.tensorflow.write_graph_summary(concrete_func.graph, path)
            self.assertTrue(os.path.exists(path))
            self.assertGreater(len(os.listdir(path)), 0)

        # Summary from a graph file path.
        with tmp_file(suffix=".pb") as graph_path:
            cmsml.tensorflow.save_graph(graph_path, concrete_func)
            self.assertTrue(os.path.exists(graph_path))
            with tmp_dir(create=False) as path:
                cmsml.tensorflow.write_graph_summary(graph_path, path)
                # presumably tmp_dir(create=False) does not create the directory itself,
                # so its existence is due to write_graph_summary — TODO confirm.
                self.assertTrue(os.path.exists(path))
                self.assertGreater(len(os.listdir(path)), 0)
Ejemplo n.º 9
0
    def test_load_graph_and_run(self):
        """Freeze a TF1 session to a ``.pb.txt`` graph, reload it and evaluate it.

        The reloaded graph is fed a (2, 10) array of ones through ``input:0``. The
        ``output:0`` result must have shape (2, 1) and equal 1.0 in every row.
        """
        import numpy as np

        if self.tf1 is None:
            # Report the test as skipped instead of letting it silently pass.
            self.skipTest("TensorFlow v1 module (self.tf1) is None")

        _, session = self.create_tf1_graph()
        if session is None:
            self.skipTest("create_tf1_graph() returned no session")

        with tmp_file(suffix=".pb.txt") as path:
            # Converting variables to constants from a session needs explicit output names.
            cmsml.tensorflow.save_graph(path, session, variables_to_constants=True,
                                        output_names=["output"])
            graph = cmsml.tensorflow.load_graph(path)

        # Fresh session bound to the reloaded graph.
        session = self.create_tf1_session(graph)
        with graph.as_default():
            x = graph.get_tensor_by_name("input:0")
            y = graph.get_tensor_by_name("output:0")
            out = session.run(y, {x: np.ones((2, 10))})

        self.assertEqual(out.shape, (2, 1))
        self.assertEqual(tuple(out[..., 0]), (1., 1.))