예제 #1
0
    def testQuantizeTrain(self):
        """Check that quantize_train only adds FixNeuron nodes to the graph.

        Exports the un-frozen test graph as a meta graph, runs
        ``decent_q.quantize_train`` on it, then reads back the meta graph
        found under ``quantize_train/`` in the temp dir and asserts that
        every node absent from the original graph has op ``FixNeuron``.
        """
        input_meta_path = os.path.join(self.get_temp_dir(), "input_meta.meta")
        with ops.Graph().as_default():
            self._build_graph(is_freezed=False)
            graph_def = ops.get_default_graph().as_graph_def()
            saver_lib.export_meta_graph(filename=input_meta_path)
            # A set keeps the per-node membership test below O(1).
            original_node_names = {node.name for node in graph_def.node}

        meta_graph_def = self._parse_def_from_file(MetaGraphDef(),
                                                   input_meta_path)
        q_config, _ = self._compose_config()
        decent_q.quantize_train(meta_graph_def, q_config)

        output_meta_graph_path = os.path.join(
            self.get_temp_dir(), "quantize_train/quantize_train.ckpt.meta")
        output_meta_graph_def = self._parse_def_from_file(
            MetaGraphDef(), output_meta_graph_path)

        added_nodes = [
            node for node in output_meta_graph_def.graph_def.node
            if node.name not in original_node_names
        ]
        # Without this check the loop below would pass vacuously if the
        # quantized graph contained no new nodes at all.
        self.assertTrue(added_nodes, "quantize_train did not add any nodes")
        for node in added_nodes:
            self.assertEqual(node.op, "FixNeuron")
예제 #2
0
def _parse_input_meta_graph_proto(input_graph, input_binary):
    """Parse a TensorFlow meta graph file into a MetaGraphDef proto.

    Args:
        input_graph: Path of the meta graph file to load.
        input_binary: If true, the file is opened in binary mode and parsed
            with ``ParseFromString``; otherwise it is opened in text mode
            and merged with ``text_format.Merge``.

    Returns:
        The populated ``MetaGraphDef``, or ``-1`` if ``input_graph`` does
        not exist.
    """
    if not gfile.Exists(input_graph):
        print("Input meta graph file '" + input_graph + "' does not exist!")
        # NOTE(review): the -1 sentinel is kept so existing callers keep
        # working; callers must check for it before using the result.
        return -1

    input_meta_graph_def = MetaGraphDef()
    mode = "rb" if input_binary else "r"
    # GFile is the supported class; FastGFile is a deprecated equivalent.
    with gfile.GFile(input_graph, mode) as f:
        if input_binary:
            input_meta_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_meta_graph_def)

    # The closing quote was missing from this message.
    print("Loaded meta graph file '" + input_graph + "'")
    return input_meta_graph_def
예제 #3
0
def _parse_input_meta_graph(input_meta_graph):
    """Load the binary-serialized MetaGraphDef stored at `input_meta_graph`.

    Args:
        input_meta_graph: Path of the meta graph file to read.

    Returns:
        A ``MetaGraphDef`` parsed from the file's bytes.

    Raises:
        ValueError: If `input_meta_graph` does not exist.
    """
    if gfile.Exists(input_meta_graph):
        with gfile.GFile(input_meta_graph, "rb") as meta_file:
            serialized = meta_file.read()
        parsed_meta_graph = MetaGraphDef()
        parsed_meta_graph.ParseFromString(serialized)
        return parsed_meta_graph

    raise ValueError("Input meta graph file '" + input_meta_graph +
                     "' does not exist.")
예제 #4
0
    def testQuantizeEval(self):
        """Check the quantize positions produced by quantize_evaluate.

        Steps: export the un-frozen test graph, run
        ``decent_q.quantize_train``, import the resulting meta graph and run
        its ``relu/aquant:0`` tensor once on mock input, then run
        ``decent_q.quantize_evaluate`` and compare the ``quantize_pos``
        attribute of every ``FixNeuron`` node in the evaluation graph with
        the expected values.
        """
        input_meta_path = os.path.join(self.get_temp_dir(),
                                       "original_meta.meta")
        with ops.Graph().as_default():
            self._build_graph(is_freezed=False)
            saver_lib.export_meta_graph(filename=input_meta_path)

        original_meta_graph_def = self._parse_def_from_file(
            MetaGraphDef(), input_meta_path)
        q_config, _ = self._compose_config()
        decent_q.quantize_train(original_meta_graph_def, q_config)

        quant_train_meta_graph_path = os.path.join(
            self.get_temp_dir(), "quantize_train/quantize_train.ckpt.meta")
        quant_train_meta_graph_def = self._parse_def_from_file(
            MetaGraphDef(), quant_train_meta_graph_path)

        # Import into a private graph (as testDeployCheckpoint does) so the
        # tensor lookups below cannot collide with nodes that may already
        # exist in the ambient default graph.
        with ops.Graph().as_default():
            saver_lib.import_meta_graph(quant_train_meta_graph_def)
            with session.Session() as sess:
                relu = sess.graph.get_tensor_by_name("relu/aquant:0")
                input_fn = self._mock_input_fn("input:0", [1, 4, 4, 3])
                sess.run(variables.global_variables_initializer())
                # The fetched value is not inspected by this test.
                sess.run([relu], feed_dict=input_fn(1))

        decent_q.quantize_evaluate(quant_train_meta_graph_def, q_config)
        quant_eval_meta_graph_path = os.path.join(
            self.get_temp_dir(), "quantize_eval/quantize_eval.ckpt.meta")
        quant_eval_meta_graph_def = self._parse_def_from_file(
            MetaGraphDef(), quant_eval_meta_graph_path)
        eval_quant_pos = [
            node.attr["quantize_pos"].i
            for node in quant_eval_meta_graph_def.graph_def.node
            if node.op == "FixNeuron"
        ]
        self.assertAllEqual([8, 7, 6, 4], eval_quant_pos)
예제 #5
0
    def testDeployCheckpoint(self):
        """Check that deploy_checkpoint emits the expected conv2d node.

        Steps: export the un-frozen test graph, run
        ``decent_q.quantize_train``, import the resulting meta graph, run it
        once on mock input and save a checkpoint, run
        ``decent_q.quantize_evaluate`` with ``relu/aquant`` as output node,
        then run ``decent_q.deploy_checkpoint`` and verify the ``conv2d``
        node of the deployed graph: its ``ipos``/``wpos``/``bpos``/``opos``
        attributes and that its weights and bias match the values fetched
        from the quantize-train graph.
        """
        input_meta_path = os.path.join(self.get_temp_dir(),
                                       "original_meta.meta")
        q_config, _ = self._compose_config()
        with ops.Graph().as_default():
            self._build_graph(is_freezed=False)
            saver_lib.export_meta_graph(filename=input_meta_path)

        original_meta_graph_def = self._parse_def_from_file(
            MetaGraphDef(), input_meta_path)
        decent_q.quantize_train(original_meta_graph_def, q_config)

        quant_train_meta_graph_path = os.path.join(
            self.get_temp_dir(), "quantize_train/quantize_train.ckpt.meta")
        quant_train_meta_graph_def = self._parse_def_from_file(
            MetaGraphDef(), quant_train_meta_graph_path)
        with ops.Graph().as_default():
            new_saver = saver_lib.import_meta_graph(quant_train_meta_graph_def)
            # The session is closed when this `with` block exits, so no
            # explicit close() is needed afterwards.
            with session.Session() as sess:
                w_t = sess.graph.get_tensor_by_name("w/read/wquant:0")
                b_t = sess.graph.get_tensor_by_name("b/read/wquant:0")
                relu_t = sess.graph.get_tensor_by_name("relu/aquant:0")
                input_fn = self._mock_input_fn("input:0", [1, 4, 4, 3])
                sess.run(variables.global_variables_initializer())
                # relu_t is still fetched so the same ops run as before;
                # only the weight and bias values are compared later.
                _, eval_w, eval_b = sess.run([relu_t, w_t, b_t],
                                             feed_dict=input_fn(1))

                checkpoint_prefix = os.path.join(self.get_temp_dir(),
                                                 "ckpt/saved_checkpoint")
                checkpoint_path = new_saver.save(
                    sess,
                    checkpoint_prefix,
                    global_step=0,
                    latest_filename="checkpoint_state")

        q_config.output_nodes = ["relu/aquant"]
        decent_q.quantize_evaluate(quant_train_meta_graph_def, q_config)
        quant_eval_meta_graph_path = os.path.join(
            self.get_temp_dir(), "quantize_eval/quantize_eval.ckpt.meta")
        quant_eval_meta_graph_def = self._parse_def_from_file(
            MetaGraphDef(), quant_eval_meta_graph_path)

        decent_q.deploy_checkpoint(quant_eval_meta_graph_def, checkpoint_path,
                                   q_config)
        deploy_graph_path = os.path.join(self.get_temp_dir(),
                                         "deploy/deploy_model.pb")
        deploy_graph_def = self._parse_def_from_file(graph_pb2.GraphDef(),
                                                     deploy_graph_path)

        conv_nodes = [
            node for node in deploy_graph_def.node if node.name == "conv2d"
        ]
        # Without this check the assertions below would be skipped silently
        # if the deployed graph had no node named "conv2d".
        self.assertTrue(conv_nodes, "deploy graph has no 'conv2d' node")
        for node in conv_nodes:
            # Must equal the quantize positions in quantize_eval_model.pb.
            self.assertAllEqual(node.attr['ipos'].list.i, [8, 6])
            self.assertAllEqual(node.attr['wpos'].list.i, [8, 7])
            self.assertAllEqual(node.attr['bpos'].list.i, [8, 8])
            self.assertAllEqual(node.attr['opos'].list.i, [8, 4])
            deploy_w = tensor_util.MakeNdarray(node.attr['weights'].tensor)
            deploy_b = tensor_util.MakeNdarray(node.attr['bias'].tensor)
            self.assertNDArrayNear(deploy_w, eval_w, 1e-6)
            self.assertNDArrayNear(deploy_b, eval_b, 1e-6)