# Example 1
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None,
                                   use_fp16=False):
    """Freeze a TensorFlow graph: replace variable nodes with Const nodes.

    Extracts the subgraph needed for ``output_node_names``, fetches the
    current values of all (non-filtered) variables from ``sess``, and emits
    a new GraphDef in which those variables are baked in as constants.
    When ``use_fp16`` is set, DT_FLOAT dtypes/values are additionally
    converted to DT_HALF.

    :param sess: live session used to read current variable values
    :param input_graph_def: GraphDef containing the graph to freeze
    :param output_node_names: node names whose subgraph is kept
    :param variable_names_whitelist: if not None, only variables in this
        collection are frozen; others are left untouched
    :param variable_names_blacklist: if not None, variables in this
        collection are never frozen
    :param use_fp16: also downcast float32 weights and float attrs to fp16
    :return: a new, frozen GraphDef
    """
    # Imports kept function-local so the module can be imported without TF.
    from tensorflow.python.framework.graph_util_impl import extract_sub_graph
    from tensorflow.core.framework import graph_pb2
    from tensorflow.core.framework import node_def_pb2
    from tensorflow.core.framework import attr_value_pb2
    from tensorflow.core.framework import types_pb2
    from tensorflow.python.framework import tensor_util

    def patch_dtype(input_node, field_name, output_node):
        # Rewrite a single dtype-carrying attr from DT_FLOAT to DT_HALF.
        # No-op unless use_fp16 is set and the attr is present and float32.
        if use_fp16 and (field_name in input_node.attr) and (
                input_node.attr[field_name].type == types_pb2.DT_FLOAT):
            output_node.attr[field_name].CopyFrom(
                attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))

    # Keep only the nodes reachable from the requested outputs.
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    # Collect the tensor names to fetch (variable_names) and the node names
    # used as dict keys (variable_dict_names); the two differ for VarHandleOp.
    variable_names = []
    variable_dict_names = []
    for node in inference_graph.node:
        if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
            variable_name = node.name
            # Skip variables excluded by the whitelist/blacklist filters.
            if ((variable_names_whitelist is not None
                 and variable_name not in variable_names_whitelist)
                    or (variable_names_blacklist is not None
                        and variable_name in variable_names_blacklist)):
                continue
            variable_dict_names.append(variable_name)
            if node.op == "VarHandleOp":
                # Resource variables are read through their ReadVariableOp.
                variable_names.append(variable_name + "/Read/ReadVariableOp:0")
            else:
                variable_names.append(variable_name + ":0")
    if variable_names:
        # One sess.run fetches every variable value at once.
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    # Map node name -> ndarray value for every variable to be frozen.
    found_variables = dict(zip(variable_dict_names, returned_variables))

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in found_variables:
            # Variable node -> Const node holding the fetched value.
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]

            if use_fp16 and dtype.type == types_pb2.DT_FLOAT:
                # Downcast the weight data itself; the "dtype" attr is
                # patched later by patch_dtype below.
                output_node.attr["value"].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(
                            data.astype('float16'),
                            dtype=types_pb2.DT_HALF,
                            shape=data.shape)))
            else:
                output_node.attr["dtype"].CopyFrom(dtype)
                output_node.attr["value"].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(
                            data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        elif input_node.op == "ReadVariableOp" and (input_node.input[0]
                                                    in found_variables):
            # The resource read becomes a plain Identity of the new Const.
            # placeholder nodes
            # print('- %s | %s ' % (input_node.name, input_node.attr["dtype"]))
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                # Preserve colocation constraints if present.
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        else:
            # mostly op nodes
            output_node.CopyFrom(input_node)

        # Patch every common dtype-carrying attr for the fp16 path.
        patch_dtype(input_node, 'dtype', output_node)
        patch_dtype(input_node, 'T', output_node)
        patch_dtype(input_node, 'DstT', output_node)
        patch_dtype(input_node, 'SrcT', output_node)
        patch_dtype(input_node, 'Tparams', output_node)

        if use_fp16 and ('value' in output_node.attr) and (
                output_node.attr['value'].tensor.dtype == types_pb2.DT_FLOAT):
            # hard-coded value need to be converted as well
            # NOTE(review): only float_val[0] is carried over, so a
            # non-scalar float Const reaching this branch would lose data
            # and its shape; presumably only scalar inline constants end
            # up here (frozen weights were handled above) — confirm.
            output_node.attr['value'].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    output_node.attr['value'].tensor.float_val[0],
                    dtype=types_pb2.DT_HALF)))

        output_graph_def.node.extend([output_node])

    # Carry over any function library defined in the original graph.
    output_graph_def.library.CopyFrom(inference_graph.library)
    return output_graph_def
# Example 2
def quantize_graph_def(graph_def,
                       skip=None,
                       output_nodes=None,
                       rel_tol=None,
                       only=None):
    """Split large constants out of *graph_def* into quantized items.

    Large Const tensors are replaced by Placeholder nodes of the same
    name/dtype/shape, and their data is emitted separately as
    QuantizedItem entries. Uniform tensors are folded to a one-element
    constant; small or non-constant nodes pass through unchanged.

  :type graph_def: GraphDef
  :type skip: set|list
  :type output_nodes: list
  :type rel_tol: float
  :type only: str
  :return: QuantizedGraph
  """
    if output_nodes is not None and len(output_nodes) > 0:
        graph_def = extract_sub_graph(graph_def, output_nodes)

    kept_nodes = []
    quantized_items = []
    for node in graph_def.node:
        # caller-requested exclusions pass through untouched
        if should_skip(node, skip):
            kept_nodes.append(node)
            continue

        # non-constant nodes raise TypeError when extracting the tensor
        try:
            tensor = MakeNdarray(node.attr['value'].tensor)  # type: np.ndarray
        except TypeError:
            kept_nodes.append(node)
            continue

        # a tensor whose entries are all (approximately) equal collapses
        # to a single-element constant of the same dtype/shape
        uniform = all_same_value(tensor, rel_tol)
        if uniform is not None:
            kept_nodes.append(
                const_node(node.attr['dtype'].type,
                           np.array([uniform], dtype=tensor.dtype),
                           tensor.shape))
            continue

        # small tensors are not worth quantizing
        if tensor.size < 4096:
            kept_nodes.append(node)
            continue

        # replace the constant by a placeholder with identical metadata...
        placeholder = NodeDef()
        placeholder.name = node.name
        placeholder.op = 'Placeholder'
        placeholder.attr['dtype'].type = node.attr['dtype'].type
        placeholder.attr['shape'].shape.CopyFrom(
            as_shape(tensor.shape).as_proto())
        kept_nodes.append(placeholder)

        # ...and ship the data separately as a quantized item
        item = QuantizedItem()
        item.name = node.name
        item.dtype = node.attr['dtype'].type
        item.shape.extend(tensor.shape)
        print('quantize {}'.format(node.name))
        _fill(item, tensor, only=only)
        quantized_items.append(item)

    result = QuantizedGraph()
    result.graph.versions.CopyFrom(graph_def.versions)
    result.graph.library.CopyFrom(graph_def.library)
    result.graph.node.extend(kept_nodes)
    result.items.extend(quantized_items)
    return result