Example #1
def Save_pb(sess, net, pbName, clear_devices=True):
    output_node_names = net.name.split(":")[0]
    
    input_graph_def = net.graph.as_graph_def()
    
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    
    output_graph_def = graph_util.convert_variables_to_constants( 
        sess,
        input_graph_def,
        output_node_names.split(",")
    )
    
    output_graph_def = graph_util.remove_training_nodes(output_graph_def)
    tf.train.write_graph(output_graph_def, './', pbName + "_MobileDeploy.pbtxt")
    print("pbtxt file generated ^____^")
    
    with tf.gfile.GFile(pbName+"_frozen.pb", "wb") as f: 
        f.write(output_graph_def.SerializeToString())
    print("pb file generated ^____^")
    print("%d ops in the final graph." % len(output_graph_def.node))
    def test_conv_biasadd_relu_fusion(self):
        tf.compat.v1.disable_eager_execution()

        self._tmp_graph_def = graph_util.remove_training_nodes(
            self.input_graph, self.outputs)

        self._tmp_graph_def = StripUnusedNodesOptimizer(
            self._tmp_graph_def, self.inputs,
            self.outputs).do_transformation()

        self._tmp_graph_def = FoldBatchNormNodesOptimizer(
            self._tmp_graph_def).do_transformation()
        op_wise_sequences = TensorflowQuery(local_config_file=os.path.join(
            os.path.dirname(__file__),
            "../lpot/adaptor/tensorflow.yaml")).get_eightbit_patterns()

        output_graph = QuantizeGraphForIntel(self._tmp_graph_def, self.outputs,
                                             self.op_wise_config,
                                             op_wise_sequences,
                                             'cpu').do_transform()

        node_name_type_mapping = {}
        for i in output_graph.node:
            node_name_type_mapping[i.name] = i.op

        should_disable_sum_node_name = 'v0/resnet_v17/conv27/conv2d/Conv2D_eightbit_quantized_conv'
        should_enable_sum_node_name = 'v0/resnet_v13/conv11/conv2d/Conv2D_eightbit_quantized_conv'
        should_disable_sum_flag = should_disable_sum_node_name in node_name_type_mapping and node_name_type_mapping[
            should_disable_sum_node_name] == 'QuantizedConv2DWithBias'
        should_enable_sum_flag = should_enable_sum_node_name in node_name_type_mapping and node_name_type_mapping[
            should_enable_sum_node_name] == 'QuantizedConv2DWithBiasSumAndRelu'
        self.assertEqual(should_enable_sum_flag, True)
        self.assertEqual(should_disable_sum_flag, True)
    def __freeze_session(
        self, session, keep_var_names=None, output_names=None, clear_devices=True
    ):

        from tensorflow.python.framework.graph_util import (
            convert_variables_to_constants,
            remove_training_nodes,
        )

        graph = session.graph
        with graph.as_default():
            freeze_var_names = list(
                set(v.op.name for v in tf.global_variables()).difference(
                    keep_var_names or []
                )
            )
            output_names = output_names or []
            output_names += [v.op.name for v in tf.global_variables()]
            # Graph -> GraphDef ProtoBuf
            input_graph_def = graph.as_graph_def()
            Utils.print_nodes(input_graph_def)
            if clear_devices:
                for node in input_graph_def.node:
                    node.device = ""
            frozen_graph = convert_variables_to_constants(
                session, input_graph_def, output_names, freeze_var_names
            )
            frozen_graph = remove_training_nodes(frozen_graph)
            return frozen_graph
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.
    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = remove_training_nodes(graph.as_graph_def())
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)

        return frozen_graph
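For example, freezing a TF 1.x Keras model (a sketch; the model is a placeholder, and convert_variables_to_constants / remove_training_nodes are assumed to be imported from tensorflow.python.framework.graph_util as in the previous snippet):

from tensorflow.python.keras import backend as K

# Hypothetical one-layer model standing in for a real trained network.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(784,))])
frozen = freeze_session(K.get_session(),
                        output_names=[t.op.name for t in model.outputs])
tf.train.write_graph(frozen, ".", "frozen_model.pb", as_text=False)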
    def test_conv_biasadd_relu_fusion(self):
        tf.compat.v1.disable_eager_execution()

        self._tmp_graph_def = graph_util.remove_training_nodes(self.input_graph, self.outputs)

        self._tmp_graph_def = StripUnusedNodes(self._tmp_graph_def,
                                                self.inputs, self.outputs).do_transform()

        self._tmp_graph_def = FoldBatchNormNodes(self._tmp_graph_def).do_transform()

        output_graph = QuantizeGraphForIntel(self._tmp_graph_def, self.outputs,
                                             self.op_wise_config,
                                             'cpu').do_transform()

        node_name_type_mapping = {}
        for i in output_graph.node:
            node_name_type_mapping[i.name] = i.op

        should_disable_sum_node_name = 'v0/resnet_v17/conv27/conv2d/Conv2D_eightbit_quantized_conv'
        should_enable_sum_node_name = 'v0/resnet_v13/conv11/conv2d/Conv2D_eightbit_quantized_conv'
        should_disable_sum_flag = should_disable_sum_node_name in node_name_type_mapping and node_name_type_mapping[
            should_disable_sum_node_name] == 'QuantizedConv2DWithBias'
        should_enable_sum_flag = should_enable_sum_node_name in node_name_type_mapping and node_name_type_mapping[
            should_enable_sum_node_name] == 'QuantizedConv2DWithBiasSumAndRelu'
        self.assertEqual(should_enable_sum_flag, True)
        self.assertEqual(should_disable_sum_flag, True)
Example #6
def save_model(fname, sess, graph=None):
    def save(fname, graph_def):
        with tf.Graph().as_default() as g:
            tf.import_graph_def(graph_def, name='')
            graph_def = g.as_graph_def(add_shapes=True)
        tf.train.write_graph(graph_def, ".", fname, as_text=False)

    if graph is None:
        graph_def = sess.graph_def
    else:
        graph_def = graph.as_graph_def(add_shapes=True)

    input_nodes = [
        'IteratorGetNext:0', 'IteratorGetNext:1', 'IteratorGetNext:2'
    ]
    output_nodes = ['logits']

    graph_def = graph_util.convert_variables_to_constants(
        sess=sess, input_graph_def=graph_def, output_node_names=output_nodes)
    graph_def = graph_util.remove_training_nodes(graph_def,
                                                 protected_nodes=output_nodes)
    graph_def = optimize_for_inference_lib.optimize_for_inference(
        graph_def, [], output_nodes, dtypes.float32.as_datatype_enum)

    transforms = [
        'remove_nodes(op=Identity, op=StopGradient)',
        'fold_batch_norms',
        'fold_old_batch_norms',
    ]
    graph_def = TransformGraph(graph_def, input_nodes, output_nodes,
                               transforms)
    save("build/data/bert_tf_v1_1_large_fp32_384_v2/model.pb", graph_def)
Example #7
def ckpt2pb():
    input_size = opts["test"]["input_size"]
    use_moving_avg = opts["yolo"]["use_moving_avg"]
    moving_avg_decay = opts["yolo"]["moving_avg_decay"]
    precision = tf.float16 if opts["yolo"][
        "precision"] == "fp16" else tf.float32
    with tf.name_scope("input"):
        input_data = tf.placeholder(dtype=precision,
                                    shape=(1, input_size, input_size, 3),
                                    name="input_data")
    output = inference(input_data)
    output_names = []
    for tensor in output:
        output_names.append(tensor.name.split(":")[0])

    sess = tf.InteractiveSession()

    sess.run(tf.global_variables_initializer())
    if use_moving_avg:
        with tf.name_scope("ema"):
            ema_obj = tf.train.ExponentialMovingAverage(moving_avg_decay)
        saver = tf.train.Saver(ema_obj.variables_to_restore())
    else:
        saver = tf.train.Saver()
    saver.restore(sess, arguments.ckpt_path)

    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_names)
    constant_graph = graph_util.remove_training_nodes(constant_graph)

    with tf.gfile.GFile(arguments.pb_path, mode='wb') as f:
        f.write(constant_graph.SerializeToString())
Example #8
def freeze_graph(model_folder, output_node_names):
    """Takes a model folder and creates a frozen model file."""
	
    try:
        checkpoint = tf.train.get_checkpoint_state(model_folder)
        input_checkpoint = checkpoint.model_checkpoint_path

        absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
        print "absolute_model_folder: %s" % absolute_model_folder

        output_graph = absolute_model_folder + "/frozen_model.pb"
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

        graph = tf.get_default_graph()
        input_graph_def = graph.as_graph_def()

        with tf.Session() as sess:
            saver.restore(sess, input_checkpoint)
            output_graph_def = graph_util.convert_variables_to_constants(
              sess=sess,
              input_graph_def=input_graph_def,
              output_node_names=output_node_names.split(",")
            )

            output_graph_def = graph_util.remove_training_nodes(output_graph_def)

            with tf.gfile.GFile(output_graph, "wb") as f:
                f.write(output_graph_def.SerializeToString())
            print("%d ops in the final graph." % len(output_graph_def.node))

    except Exception as e:
        print(e)
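Called, for example, with a checkpoint directory and comma-separated output node names (both hypothetical):

freeze_graph("./checkpoints/my_model", "logits,probabilities")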
Example #9
def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
                           placeholder_type_enum):
  """Applies a series of inference optimizations on the input graph.

  Args:
    input_graph_def: A GraphDef containing a training model.
    input_node_names: A list of names of the nodes that are fed inputs during
      inference.
    output_node_names: A list of names of the nodes that produce the final
      results.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    An optimized version of the input graph.
  """
  ensure_graph_is_valid(input_graph_def)
  optimized_graph_def = input_graph_def
  optimized_graph_def = strip_unused_lib.strip_unused(optimized_graph_def,
                                                      input_node_names,
                                                      output_node_names,
                                                      placeholder_type_enum)
  optimized_graph_def = graph_util.remove_training_nodes(
      optimized_graph_def, output_node_names)
  optimized_graph_def = fold_batch_norms(optimized_graph_def)
  optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
                                             output_node_names)
  ensure_graph_is_valid(optimized_graph_def)
  return optimized_graph_def
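A sketch of driving this helper on a frozen GraphDef loaded from disk; the file and node names are hypothetical:

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes

graph_def = graph_pb2.GraphDef()
with open("frozen_model.pb", "rb") as f:   # hypothetical frozen model
    graph_def.ParseFromString(f.read())

optimized = optimize_for_inference(graph_def,
                                   ["input"],   # nodes fed at inference time
                                   ["logits"],  # nodes producing final results
                                   dtypes.float32.as_datatype_enum)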
Example #10
    def rewrite(self, output_node_names):
        """Triggers rewriting of the float graph.

        Args:
          output_node_names: A list of names of the nodes that produce the final
            results.

        Returns:
          A quantized version of the float graph.
        """
        self.output_graph = graph_pb2.GraphDef()
        output_nodes = [
            self.nodes_map[output_node_name]
            for output_node_name in output_node_names
        ]
        # output_nodes = []
        # for output_node_name in output_node_names:
        #   output_nodes.append(self.nodes_map[output_node_name])

        # When function graph_util.remove_training_nodes remove
        # "Identity" ops in the graph, it does not replace the
        # control input properly, so the control input becomes
        # the regular input. Disable this function until the
        # the bug is fixed.

        self.set_input_graph(
            graph_util.remove_training_nodes(
                self.input_graph, protected_nodes=output_node_names))

        output_nodes = [
            self.nodes_map[output_node_name]
            for output_node_name in output_node_names
        ]

        # output_nodes=[]
        # for output_node_name in output_node_names:
        #   output_nodes.append(self.nodes_map[output_node_name])

        self.state = EightbitizeRecursionState(already_visited={},
                                               output_node_stack=[],
                                               merged_with_fake_quant={})

        # TODO(intel-tf): Enables fused quantized node for intel cpu.
        for output_node in output_nodes:
            # Intiailize output_node_stack with output node.
            # Each element in the stack is a mutable list containing
            # [parent_node, index_to_parent, quantization_flag, fusion_flag].
            # In case of root node, make self as parent.
            self.state.output_node_stack.append(
                [output_node, None, False, False])
            self.intel_cpu_eightbitize_nodes_recursively(output_node)
            self.state.output_node_stack.pop()

        self.state = None
        if strip_redundant_quantization:
            self.output_graph = self.remove_redundant_quantization(
                self.output_graph)
            self.remove_dead_nodes(output_node_names)
        self.apply_final_node_renames()
        return self.output_graph
Example #11
def export_keras_to_tf(input_model='model.h5',
                       output_model='model.pb',
                       num_output=1):
    print('Loading Keras model: ', input_model)

    keras_model = load_model(input_model,
                             custom_objects={'Normalization': Normalization()})

    print(keras_model.summary())

    predictions = [None] * num_output
    prediction_node_names = [None] * num_output

    for i in range(num_output):
        prediction_node_names[i] = 'output_node' + str(i)
        predictions[i] = tf.identity(keras_model.outputs[i],
                                     name=prediction_node_names[i])

    session = K.get_session()

    constant_graph = graph_util.convert_variables_to_constants(
        session, session.graph.as_graph_def(), prediction_node_names)
    infer_graph = graph_util.remove_training_nodes(constant_graph)

    graph_io.write_graph(infer_graph, '.', output_model, as_text=False)
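For instance, assuming a Keras HDF5 file that contains the custom Normalization layer referenced above (file names hypothetical):

export_keras_to_tf(input_model='model.h5',
                   output_model='model.pb',
                   num_output=1)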
  def test_keep_control_edges(self):
    no_op_name = "no_op"
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    graph_def = graph_pb2.GraphDef()
    no_op = quantize_graph.create_node("NoOp", no_op_name, [])
    graph_def.node.extend([no_op])
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
                                              [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = quantize_graph.create_node(
        "Identity", a_identity_name,
        [a_constant_name, "^" + a_check_name, "^" + no_op_name])
    graph_def.node.extend([a_identity_node])
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
                                              [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = quantize_graph.create_node(
        "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_identity_name, b_identity_name])
    quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
    graph_def.node.extend([add_node])

    expected_output = graph_pb2.GraphDef()
    no_op = quantize_graph.create_node("NoOp", no_op_name, [])
    expected_output.node.extend([no_op])
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant])
    a_identity_node = quantize_graph.create_node(
        "Identity", a_identity_name, [a_constant_name, "^" + no_op_name])
    expected_output.node.extend([a_identity_node])
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_identity_name, b_constant_name])
    quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
    expected_output.node.extend([add_node])
    expected_output.versions.CopyFrom(graph_def.versions)
    expected_output.library.CopyFrom(graph_def.library)

    output = graph_util.remove_training_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
Example #14
  def testRemoveTrainingNodes(self):
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    graph_def = tf.GraphDef()
    a_constant = self.create_constant_node_def(a_constant_name,
                                               value=1,
                                               dtype=tf.float32,
                                               shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = self.create_node_def("CheckNumerics", a_check_name,
                                        [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = self.create_node_def("Identity", a_identity_name,
                                           [a_constant_name,
                                            "^" + a_check_name])
    graph_def.node.extend([a_identity_node])
    b_constant = self.create_constant_node_def(b_constant_name,
                                               value=1,
                                               dtype=tf.float32,
                                               shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = self.create_node_def("CheckNumerics", b_check_name,
                                        [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = self.create_node_def("Identity", b_identity_name,
                                           [b_constant_name,
                                            "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = self.create_node_def("Add", add_name,
                                    [a_identity_name,
                                     b_identity_name])
    self.set_attr_dtype(add_node, "T", tf.float32)
    graph_def.node.extend([add_node])

    expected_output = tf.GraphDef()
    a_constant = self.create_constant_node_def(a_constant_name,
                                               value=1,
                                               dtype=tf.float32,
                                               shape=[])
    expected_output.node.extend([a_constant])
    b_constant = self.create_constant_node_def(b_constant_name,
                                               value=1,
                                               dtype=tf.float32,
                                               shape=[])
    expected_output.node.extend([b_constant])
    add_node = self.create_node_def("Add", add_name,
                                    [a_constant_name,
                                     b_constant_name])
    self.set_attr_dtype(add_node, "T", tf.float32)
    expected_output.node.extend([add_node])

    output = graph_util.remove_training_nodes(graph_def)
    self.assertProtoEquals(expected_output, output)
    def save_frozen_model(self, model_filename, input_shape):
        """
        Save frozen TensorFlow formatted model protobuf
        """
        model = self.load_model(model_filename)

        # Change filename to protobuf extension
        base = os.path.basename(model_filename)
        output_model = os.path.splitext(base)[0] + ".pb"

        # Set Keras to inference
        K.backend._LEARNING_PHASE = tf.constant(0)
        K.backend.set_learning_phase(False)
        K.backend.set_learning_phase(0)
        K.backend.set_image_data_format("channels_last")

        num_output = len(model.outputs)
        predictions = [None] * num_output
        prediction_node_names = [None] * num_output

        for i in range(num_output):
            prediction_node_names[i] = "output_node" + str(i)
            predictions[i] = tf.identity(model.outputs[i],
                                         name=prediction_node_names[i])

        sess = K.backend.get_session()

        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), prediction_node_names)
        infer_graph = graph_util.remove_training_nodes(constant_graph)

        # Write protobuf of frozen model
        frozen_dir = "./frozen_model/"
        shutil.rmtree(frozen_dir,
                      ignore_errors=True)  # Remove existing directory
        graph_io.write_graph(infer_graph,
                             frozen_dir,
                             output_model,
                             as_text=False)

        pb_filename = os.path.join(frozen_dir, output_model)
        print("Frozen TensorFlow model written to: {}".format(pb_filename))
        print("Convert this to OpenVINO by running:\n")
        print("source /opt/intel/openvino/bin/setupvars.sh")
        print(
            "python $INTEL_OPENVINO_DIR/deployment_tools/model_optimizer/mo_tf.py \\"
        )
        print("       --input_model {} \\".format(pb_filename))

        shape_string = "[1"
        for idx in range(len(input_shape[1:])):
            shape_string += ",{}".format(input_shape[idx + 1])
        shape_string += "]"

        print("       --input_shape {} \\".format(shape_string))
        print("       --output_dir openvino_models/FP32/ \\")
        print("       --data_type FP32\n\n")
Example #16
def _convert_op_hints_if_present(sess, output_tensors):
  if is_frozen_graph(sess):
    raise ValueError("Try to convert op hints, needs unfrozen graph.")
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
  graph_def = tf_graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_arrays + hinted_outputs_nodes)
  graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  graph_def = tf_graph_util.remove_training_nodes(graph_def)
  return graph_def
Example #17
def export_to_pb(sess, x, filename):
    pred_names = ['output']
    tf.identity(x, name=pred_names[0])

    graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_names)

    graph = graph_util.remove_training_nodes(graph)
    path = graph_io.write_graph(graph, ".", filename, as_text=False)
    print('saved the frozen graph (ready for inference) at: ', path)
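For instance (a sketch; the toy graph and file name are hypothetical):

x_in = tf.placeholder(tf.float32, shape=(None, 8), name="input")
logits = tf.layers.dense(x_in, 2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    export_to_pb(sess, logits, "model_frozen.pb")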
def linear_regression_model(x,y,test_x=None,test_y=None,training_epochs=1000,learning_rate=0.01):
	# getting placeholders (our input values)
	X = tf.placeholder(np.float32,[None,3],name="X")
	Y = tf.placeholder(np.float32,name="Y") 
	# declaring weights and bias
	W = tf.Variable(np.random.randn(3,1), name = "W",dtype=np.float32) 
	b = tf.Variable(np.random.randn(), name = "b",dtype=np.float32)
	# Hypothesis 
	y_pred = tf.add(tf.matmul(X, W), b,name="y_pred")
	# Mean Squared Error Cost Function
	cost = tf.reduce_sum(tf.pow(y_pred-Y, 2)) / (2 * len(x)) 
	# Gradient Descent Optimizer 
	optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) 
	# Global Variables Initializer 
	init = tf.global_variables_initializer()
	# Starting the Tensorflow Session 
	with tf.Session() as sess:
		# Initializing the Variables 
		sess.run(init)
		saver = tf.train.Saver()
		# Iterating through all the epochs 
		for epoch in range(training_epochs):
			# Feeding each data point into the optimizer using Feed Dictionary
			for i in range(len(x)):
				sess.run(optimizer, feed_dict = {X: [x[i]], Y:[y[i]]})
			# Displaying the result after every 50 epochs 
			if (epoch + 1) % 50 == 0: 
				# Calculating the cost at every epoch
				c = sess.run(cost, feed_dict = {X: x, Y:y}) 
				print("Epoch", (epoch + 1), ": cost =", c, "W =", sess.run(W), "b =", sess.run(b)) 
		
		# Storing necessary values to be used outside the Session 
		training_cost = sess.run(cost, feed_dict={X: x, Y: y})
		weight = sess.run(W) 
		bias = sess.run(b)
		#saving our model
		saver.save(sess, "./chkps/mnist_model")
		out_nodes = [y_pred.op.name]
		print("output_nodes: ",out_nodes)
		sub_graph_def = remove_training_nodes(sess.graph_def)
		sub_graph_def = gu.convert_variables_to_constants(sess, sub_graph_def, out_nodes)
		graph_path = tf.train.write_graph(sub_graph_def,
                                  "./predictor_model",
                                  "predictor.pb",
                                  as_text=False)

		print('written graph to: %s' % graph_path)

	# testing the model
	if test_x is not None and test_y is not None:
		for i in range(len(test_x)):
			# Calculating the predictions
			prediction = np.matmul(test_x[i], weight).flatten() + bias
			print("Training cost =", training_cost, "Weight =", weight, "bias =", bias, "prediction =", prediction, "value =", test_y[i], '\n')
Example #21
  def testRemoveIdentityUsedAsControlInputInConst(self):
    """Check that Identity nodes used as control inputs are not removed."""
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
        self.create_node_def("Identity", "I", ["Base"]),
        self.create_node_def("BaseOp", "Base", [])
    ])

    self.assertProtoEquals(graph_def,
                           graph_util.remove_training_nodes(graph_def))
Example #22
    def rewrite(self, output_node_names):
        """Triggers rewriting of the float graph.

    Args:
      output_node_names: A list of names of the nodes that produce the final
        results.

    Returns:
      A quantized version of the float graph.
    """
        self.output_graph = tf.GraphDef()
        output_nodes = [
            self.nodes_map[output_node_name]
            for output_node_name in output_node_names
        ]
        if self.mode == "round":
            self.already_visited = {}
            for output_node in output_nodes:
                self.round_nodes_recursively(output_node)
        elif self.mode == "quantize":
            self.already_visited = {}
            self.already_quantized = {}
            for output_node in output_nodes:
                self.quantize_nodes_recursively(output_node)
        elif self.mode == "eightbit":
            self.set_input_graph(
                graph_util.remove_training_nodes(self.input_graph))
            output_nodes = [
                self.nodes_map[output_node_name]
                for output_node_name in output_node_names
            ]
            self.already_visited = {}
            self.layers_eightbitized = []
            for output_node in output_nodes:
                self.eightbitize_nodes_recursively(output_node)
            self.output_graph = self.quantize_weights(self.output_graph,
                                                      b"MIN_FIRST")
            if FLAGS.strip_redundant_quantization:
                self.output_graph = self.remove_redundant_quantization(
                    self.output_graph)
                self.remove_dead_nodes(output_node_names)
        elif self.mode == "weights":
            self.output_graph = self.quantize_weights(self.input_graph,
                                                      b"MIN_COMBINED")
            self.remove_dead_nodes(output_node_names)
        elif self.mode == "weights_rounded":
            self.output_graph = self.quantize_weights(self.input_graph,
                                                      self.mode)
            self.remove_dead_nodes(output_node_names)
        else:
            print("Bad mode - " + self.mode + ".")
        return self.output_graph
Example #23
  def rewrite(self, output_node_names):
    """Triggers rewriting of the float graph.

    Args:
      output_node_names: A list of names of the nodes that produce the final
        results.

    Returns:
      A quantized version of the float graph.
    """
    self.output_graph = tf.GraphDef()
    output_nodes = [self.nodes_map[output_node_name]
                    for output_node_name in output_node_names]
    if self.mode == "round":
      self.already_visited = {}
      for output_node in output_nodes:
        self.round_nodes_recursively(output_node)
    elif self.mode == "quantize":
      self.already_visited = {}
      self.already_quantized = {}
      for output_node in output_nodes:
        self.quantize_nodes_recursively(output_node)
    elif self.mode == "eightbit":
      self.set_input_graph(graph_util.remove_training_nodes(self.input_graph))
      output_nodes = [self.nodes_map[output_node_name]
                      for output_node_name in output_node_names]

      self.already_visited = {}
      self.layers_eightbitized = []
      for output_node in output_nodes:
        self.eightbitize_nodes_recursively(output_node)
      self.output_graph = self.quantize_weights(self.output_graph, b"MIN_FIRST")
      if self.input_range:
        self.add_output_graph_node(create_constant_node(
            "quantized_input_min_value", self.input_range[0], tf.float32, []))
        self.add_output_graph_node(create_constant_node(
            "quantized_input_max_value", self.input_range[1], tf.float32, []))
      if FLAGS.strip_redundant_quantization:
        self.output_graph = self.remove_redundant_quantization(
            self.output_graph)
        self.remove_dead_nodes(output_node_names)
      self.apply_final_node_renames()
    elif self.mode == "weights":
      self.output_graph = self.quantize_weights(self.input_graph,
                                                b"MIN_COMBINED")
      self.remove_dead_nodes(output_node_names)
    elif self.mode == "weights_rounded":
      self.output_graph = self.quantize_weights(self.input_graph, self.mode)
      self.remove_dead_nodes(output_node_names)
    else:
      print("Bad mode - " + self.mode + ".")
    return self.output_graph
Example #24
    def _load_saved_model(self):
        """Load the tensorflow saved model."""
        try:
            from tensorflow.python.tools import freeze_graph
            from tensorflow.python.framework import ops
            from tensorflow.python.framework import graph_util
            from tensorflow.core.framework import graph_pb2
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.")

        saved_model_dir = self._model_dir
        output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
        input_saved_model_dir = saved_model_dir
        output_node_names = self._get_output_names()

        input_binary = False
        input_saver_def_path = False
        restore_op_name = None
        filename_tensor_name = None
        clear_devices = True
        input_meta_graph = False
        checkpoint_path = None
        input_graph_filename = None
        saved_model_tags = ",".join(self._get_tag_set())

        freeze_graph.freeze_graph(
            input_graph_filename,
            input_saver_def_path,
            input_binary,
            checkpoint_path,
            output_node_names,
            restore_op_name,
            filename_tensor_name,
            output_graph_filename,
            clear_devices,
            "",
            "",
            "",
            input_meta_graph,
            input_saved_model_dir,
            saved_model_tags,
        )

        with ops.Graph().as_default():  # pylint: disable=not-context-manager
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_filename, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            output_graph_def = graph_util.remove_training_nodes(
                output_graph_def, protected_nodes=self._outputs)
            return output_graph_def
Example #25
def export(x: tf.Tensor, filename: str):
    with tf.Session() as sess:
        pred_node_names = ["output"]
        tf.identity(x, name=pred_node_names[0])

        graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names)

        graph = graph_util.remove_training_nodes(graph)

        path = graph_io.write_graph(graph, ".", filename, as_text=False)

    return path
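Because export() opens a fresh tf.Session internally, a variable-free graph (which freezes cleanly without initialization) is the simplest usage; a sketch with hypothetical names:

a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.matmul(a, a)
export(y, "matmul.pb")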
Example #26
def freeze_graph(checkpoints_path, output_graph):
    """
    :param checkpoints_path: path to the checkpoint (ckpt) files
    :param output_graph: path where the frozen .pb model is saved
    :return:
    """
    with tf.Graph().as_default():
        image = tf.placeholder(shape=[None, 608, 608, 3],
                               dtype=tf.float32,
                               name='inputs')

        # Specify the output node names; they must be nodes that exist in the original model
        output_node_names = "reorg_layer/obj_probs,reorg_layer/class_probs,reorg_layer/bboxes_probs"

        # Build the network structure from the model code
        Model = network.Network(is_train=False)
        logits = Model.build_network(image)
        output = Model.reorg_layer(logits, model_params['anchors'])

        # Alternatively, restore the structure from the .meta file:
        #saver = tf.train.import_meta_graph(checkpoints_path + '.meta', clear_devices=True)

        # Get the default graph
        graph = tf.get_default_graph()

        # Return a serialized GraphDef representing the current graph
        input_graph_def = graph.as_graph_def()

        with tf.Session() as sess:
            saver = tf.train.Saver()
            # Restore the graph and load the trained weights
            saver.restore(sess, checkpoints_path)

            # Persist the model: freeze variable values into constants
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                output_node_names=output_node_names.split(","))

            # Remove training nodes, keeping only the inference backbone
            output_graph_def = graph_util.remove_training_nodes(
                output_graph_def)

            # Save the model
            with tf.gfile.GFile(output_graph, "wb") as f:

                # Serialize and write the output
                f.write(output_graph_def.SerializeToString())

            # Report how many op nodes are in the final graph
            print("%d ops in the final graph." % len(output_graph_def.node))
    def __init__(self, model, session=None):
        """
		This constructor takes a reference to a TensorFlow Operation or Tensor or Keras model and then applies the two TensorFlow functions
		graph_util.convert_variables_to_constants and graph_util.remove_training_nodes to cleanse the graph of any nodes that are linked to training. This leaves us with 
		the nodes you need for inference. 
		In the resulting graph there should only be tf.Operations left that have one of the following types [Const, MatMul, Add, BiasAdd, Conv2D, Reshape, MaxPool, AveragePool, Placeholder, Relu, Sigmoid, Tanh]
		If the input should be a Keras model we will ignore operations with type Pack, Shape, StridedSlice, and Prod such that the Flatten layer can be used.
		
		Arguments
		---------
		model : tensorflow.Tensor or tensorflow.Operation or tensorflow.python.keras.engine.sequential.Sequential or keras.engine.sequential.Sequential
		    if tensorflow.Tensor: model.op will be treated as the output node of the TensorFlow model. Make sure that the graph only contains supported operations after applying
		                          graph_util.convert_variables_to_constants and graph_util.remove_training_nodes with [model.op.name] as output_node_names
		    if tensorflow.Operation: model will be treated as the output of the TensorFlow model. Make sure that the graph only contains supported operations after applying
		                          graph_util.convert_variables_to_constants and graph_util.remove_training_nodes with [model.op.name] as output_node_names
		    if tensorflow.python.keras.engine.sequential.Sequential: x = model.layers[-1].output.op.inputs[0].op will be treated as the output node of the Keras model. Make sure that the graph only
		                          contains supported operations after applying graph_util.convert_variables_to_constants and graph_util.remove_training_nodes with [x.name] as
		                          output_node_names
		    if keras.engine.sequential.Sequential: x = model.layers[-1].output.op.inputs[0].op will be treated as the output node of the Keras model. Make sure that the graph only
		                          contains supported operations after applying graph_util.convert_variables_to_constants and graph_util.remove_training_nodes with [x.name] as
		                          output_node_names
		session : tf.Session
		    session which contains the information about the trained variables. If None the code will take the Session from tf.get_default_session(). If you pass a keras model you don't have to
		    provide a session, this function will automatically get it.
		"""
        output_names = None
        if issubclass(model.__class__, tf.Tensor):
            output_names = [model.op.name]
        elif issubclass(model.__class__, tf.Operation):
            output_names = [model.name]
        elif issubclass(model.__class__, Sequential):
            session = tf.keras.backend.get_session()
            output_names = [model.layers[-1].output.op.inputs[0].op.name]
            model = model.layers[-1].output.op
        elif issubclass(model.__class__, onnx.ModelProto):
            assert 0, 'not tensorflow model'
        else:
            import keras
            if issubclass(model.__class__, keras.engine.sequential.Sequential):
                session = keras.backend.get_session()
                output_names = [model.layers[-1].output.op.inputs[0].op.name]
                model = model.layers[-1].output.op
            else:
                assert 0, "ERAN can't recognize this input"

        if session is None:
            session = tf.get_default_session()

        tmp = graph_util.convert_variables_to_constants(
            session, model.graph.as_graph_def(), output_names)
        self.graph_def = graph_util.remove_training_nodes(tmp)
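A usage sketch with a TF 1.x Keras Sequential model; ModelWrapper stands in for the enclosing class, whose name the snippet does not show, and the layer sizes are placeholders:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(784,)),
    tf.keras.layers.Dense(10),
])
wrapper = ModelWrapper(model)  # session is fetched automatically for Keras models
print(len(wrapper.graph_def.node))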
Example #28
def export_keras_to_tf(input_model, output_model):
    print("Loading Keras model: ", input_model)

    keras_model = load_model(input_model, compile=False)

    print(keras_model.summary())

    num_output = len(keras_model.outputs)
    predictions = [None] * num_output
    prediction_node_names = [None] * num_output

    for i in range(num_output):
        prediction_node_names[i] = "output_node" + str(i)
        predictions[i] = tf.identity(keras_model.outputs[i],
                                     name=prediction_node_names[i])

    sess = K.get_session()

    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), prediction_node_names)
    infer_graph = graph_util.remove_training_nodes(constant_graph)

    # Write protobuf of frozen model
    frozen_dir = "./tf_protobuf/"
    shutil.rmtree(frozen_dir, ignore_errors=True)  # Remove existing directory
    graph_io.write_graph(infer_graph, frozen_dir, output_model, as_text=False)

    print("Input shape {}".format(keras_model.inputs[0].shape))

    pb_filename = os.path.join(frozen_dir, output_model)
    print("Frozen TensorFlow model written to: {}".format(pb_filename))
    print("Convert this to OpenVINO by running:\n")
    print("source /opt/intel/openvino/bin/setupvars.sh")
    print(
        "python $INTEL_OPENVINO_DIR/deployment_tools/model_optimizer/mo_tf.py \\"
    )
    print("       --input_model {} \\".format(pb_filename))

    shape_string = "[1"
    for idx in range(len(keras_model.inputs[0].shape[1:])):
        shape_string += ",{}".format(keras_model.inputs[0].shape[idx + 1])
    shape_string += "]"

    print("       --input_shape {} \\".format(shape_string))
    print("       --output_dir openvino_models/FP32/ \\")
    print("       --data_type FP32")

    return prediction_node_names
Example #29
    def _fuse_requantize_with_fused_quantized_conv(self):
        self._tmp_graph_def = fuse_quantized_conv_and_requantize(
            self._tmp_graph_def)
        # strip_unused_nodes with optimize_for_inference
        dtypes = self._get_dtypes(self._tmp_graph_def)
        # self._tmp_graph_def = optimize_for_inference(self._tmp_graph_def, self.inputs, self.outputs, dtypes, False)
        self._tmp_graph_def = StripUnusedNodes(self._tmp_graph_def,
                                               self.inputs, self.outputs,
                                               dtypes).do_transform()
        self._tmp_graph_def = graph_util.remove_training_nodes(
            self._tmp_graph_def, self.outputs)
        self._tmp_graph_def = FoldBatchNormNodes(
            self._tmp_graph_def).do_transform()
        RerangeQuantizedConcat(self._tmp_graph_def).do_transformation()
        write_graph(self._tmp_graph_def, self.output_graph)
        logging.info('Converted graph file is saved to: %s', self.output_graph)
Example #30
def ckpt2pb():
    with tf.Graph().as_default() as graph_old:
        isess = tf.compat.v1.InteractiveSession()
        # isess = tf.Session()
        # ckpt_filename = './model.ckpt'
        # import_meta_graph expects the checkpoint *prefix* plus '.meta',
        # not a .data-00000-of-00002 shard file.
        ckpt_prefix = './tf_ckpts_singlenet/ckpt-30'
        isess.run(tf.compat.v1.global_variables_initializer())
        saver = tf.compat.v1.train.import_meta_graph(ckpt_prefix + '.meta',
                                                     clear_devices=True)
        saver.restore(isess, ckpt_prefix)

        constant_graph = graph_util.convert_variables_to_constants(
            isess, isess.graph_def, ["Cls/fc/biases"])
        constant_graph = graph_util.remove_training_nodes(constant_graph)

        with tf.gfile.GFile('./pb_model/model.pb', mode='wb') as f:
            f.write(constant_graph.SerializeToString())
Example #31
    def _optimize_frozen_fp32_graph(self):
        """Optimize fp32 frozen graph."""

        self._tmp_graph_def = read_graph(self.input_graph,
                                         self.input_graph_binary_flag)
        dtypes = self._get_dtypes(self._tmp_graph_def)
        # self._tmp_graph_def = optimize_for_inference(self._tmp_graph_def, self.inputs, self.outputs, dtypes, False)
        self._tmp_graph_def = FuseColumnWiseMul(
            self._tmp_graph_def).do_transformation()
        self._tmp_graph_def = StripUnusedNodes(self._tmp_graph_def,
                                               self.inputs, self.outputs,
                                               dtypes).do_transform()
        self._tmp_graph_def = graph_util.remove_training_nodes(
            self._tmp_graph_def, self.outputs)
        self._tmp_graph_def = FoldBatchNormNodes(
            self._tmp_graph_def).do_transform()
        write_graph(self._tmp_graph_def, self._fp32_optimized_graph)
Example #32
def export_cnn() -> None:
    input = tf.placeholder(tf.float32, shape=(1, 1, 3, 3))
    filter = tf.constant(np.ones((3, 3, 1, 1)), dtype=tf.float32)
    x = tf.nn.conv2d(input, filter, (1, 1, 1, 1), "SAME", data_format='NCHW')
    x = tf.nn.sigmoid(x)
    x = tf.nn.relu(x)

    pred_node_names = ["output"]
    tf.identity(x, name=pred_node_names[0])

    with tf.Session() as sess:
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names)

    frozen = graph_util.remove_training_nodes(constant_graph)

    output = "cnn.pb"
    graph_io.write_graph(frozen, ".", output, as_text=False)
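The written cnn.pb can then be inspected by importing it into a fresh graph (a sketch):

loaded = tf.GraphDef()
with tf.gfile.GFile("cnn.pb", "rb") as f:
    loaded.ParseFromString(f.read())

with tf.Graph().as_default() as g:
    tf.import_graph_def(loaded, name="")
    print([op.name for op in g.get_operations()])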
Example #33
def export(x: tf.Tensor, filename: str, sess=None):
    should_close = False
    if sess is None:
        should_close = True
        sess = tf.Session()

    pred_node_names = ["output"]
    tf.identity(x, name=pred_node_names[0])
    graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_node_names)

    graph = graph_util.remove_training_nodes(graph)

    path = graph_io.write_graph(graph, ".", filename, as_text=False)

    if should_close:
        sess.close()

    return path
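Passing an existing session lets trained variable values be captured in the frozen graph (a sketch; names are hypothetical):

with tf.Session() as sess:
    w = tf.Variable(tf.ones((3, 3)))
    y = tf.matmul(w, w)
    sess.run(tf.global_variables_initializer())
    export(y, "matmul_frozen.pb", sess=sess)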
def optimize_for_inference(input_graph_def, input_node_names,
                           output_node_names, placeholder_type_enum):
  """Applies a series of inference optimizations on the input graph.

  Args:
    input_graph_def: A GraphDef containing a training model.
    input_node_names: A list of names of the nodes that are fed inputs during
      inference.
    output_node_names: A list of names of the nodes that produce the final
      results.
    placeholder_type_enum: Data type of the placeholders used for inputs.

  Returns:
    An optimized version of the input graph.
  """
  stripped_graph_def = strip_unused_lib.strip_unused(input_graph_def,
                                                     input_node_names,
                                                     output_node_names,
                                                     placeholder_type_enum)
  detrained_graph_def = graph_util.remove_training_nodes(stripped_graph_def)
  folded_graph_def = fold_batch_norms(detrained_graph_def)
  return folded_graph_def
  def testRemoveIdentityChains(self):
    """Check that chains of Identity nodes are correctly pruned.

    Create a chain of four nodes, A, B, C, and D where A inputs B, B inputs C,
    and C inputs D. Nodes B and C are "Identity" and should be pruned, resulting
    in the nodes A and D, where A inputs D.
    """
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_node_def("Aop", "A", ["B"]), self.create_node_def(
            "Identity", "B", ["C"]), self.create_node_def(
                "Identity", "C", ["D"]), self.create_node_def("Dop", "D", [])
    ])

    expected_graph_def = graph_pb2.GraphDef()
    expected_graph_def.node.extend([
        self.create_node_def("Aop", "A", ["D"]), self.create_node_def(
            "Dop", "D", [])
    ])

    self.assertProtoEquals(expected_graph_def,
                           graph_util.remove_training_nodes(graph_def))
Example #36
    def _load_saved_model(self):
        """Load the tensorflow saved model."""
        try:
            from tensorflow.python.tools import freeze_graph
            from tensorflow.python.framework import ops
            from tensorflow.python.framework import graph_util
            from tensorflow.core.framework import graph_pb2
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.")

        saved_model_dir = self._model_dir
        output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
        input_saved_model_dir = saved_model_dir
        output_node_names = self._get_output_names()

        input_binary = False
        input_saver_def_path = False
        restore_op_name = None
        filename_tensor_name = None
        clear_devices = True
        input_meta_graph = False
        checkpoint_path = None
        input_graph_filename = None
        saved_model_tags = ",".join(self._get_tag_set())

        freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
                                  input_binary, checkpoint_path, output_node_names,
                                  restore_op_name, filename_tensor_name,
                                  output_graph_filename, clear_devices, "", "", "",
                                  input_meta_graph, input_saved_model_dir,
                                  saved_model_tags)

        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_filename, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            output_graph_def = graph_util.remove_training_nodes(output_graph_def)
            return output_graph_def
def main(_):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

  # Specify inputs, outputs, and a cost function
  # placeholders
  x = tf.placeholder(tf.float32, [None, 784], name="x")
  y_ = tf.placeholder(tf.float32, [None, 10], name="y")

  # Build the graph for the deep net
  y_pred, logits = deepnn(x)

  with tf.name_scope("Loss"):
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_,
                                                               logits=logits)
    loss = tf.reduce_mean(cross_entropy, name="cross_entropy_loss")
  train_step = tf.train.AdamOptimizer(1e-4).minimize(loss, name="train_step")

  with tf.name_scope("Prediction"):
    correct_prediction = tf.equal(y_pred,
                                  tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")

  # Start training session
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    # SGD
    for i in range(1, FLAGS.num_iter + 1):
      batch_images, batch_labels = mnist.train.next_batch(FLAGS.batch_size)
      feed_dict = {x: batch_images, y_: batch_labels}
      train_step.run(feed_dict=feed_dict)
      if i % FLAGS.log_iter == 0:
        train_accuracy = accuracy.eval(feed_dict=feed_dict)
        print('step %d, training accuracy %g' % (i, train_accuracy))

    print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images,
                                                        y_: mnist.test.labels}))
    # Saving checkpoint and serialize the graph
    ckpt_path = saver.save(sess, FLAGS.chkp)
    print('saving checkpoint: %s' % ckpt_path)
    out_nodes = [y_pred.op.name]
    # Freeze graph and remove training nodes
    sub_graph_def = gu.remove_training_nodes(sess.graph_def)
    sub_graph_def = gu.convert_variables_to_constants(sess, sub_graph_def, out_nodes)
    if FLAGS.no_quant:
      graph_path = tf.train.write_graph(sub_graph_def,
                                        FLAGS.output_dir,
                                        FLAGS.pb_fname,
                                        as_text=False)
    else:
      # # quantize the graph
      # quant_graph_def = TransformGraph(sub_graph_def,
      #                                  [],
      #                                  out_nodes,
      #                                  ["quantize_weights", "quantize_nodes"])
      graph_path = tf.train.write_graph(sub_graph_def,
                                        FLAGS.output_dir,
                                        FLAGS.pb_fname,
                                        as_text=False)
    print('written graph to: %s' % graph_path)
    print('the output nodes: {!s}'.format(out_nodes))
def freeze_graph(model_folder, output_nodes):
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    print(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_folder + "/frozen_model.pb"
    # In order to freeze the graph, we need to tell TF which nodes will be used at output.
    # We can get the node names from our previous Keras model definition.
    output_node_name = output_nodes

    # If we trained the model on GPU, we want to be sure there are no explicit GPU directives on the graph nodes.
    clear_devices = True
    new_saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess2:
        print(input_checkpoint)
        new_saver.restore(sess2, input_checkpoint)

        # remove_training_nodes returns a new GraphDef; keep the result.
        input_graph_def = graph_util.remove_training_nodes(
            input_graph_def,
            protected_nodes=None
        )
        # Since we are freezing the model, we want to turn all the trainable variables into constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess2,             # This is used to retrieve the weights
            input_graph_def,   # The graph def is used to retrieve the nodes
            output_node_name)  # The output node names are used to select the useful nodes
        # We can now freeze the graph and save it
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph" % len(output_graph_def.node))