Example #1
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20,
                           precision_mode="FP32",
                           minimum_segment_size=3,
                           is_dynamic_op=False,
                           maximum_cached_engines=1,
                           cached_engine_batches=[]):
    """Python wrapper for the TRT transformation.

    Args:
      input_graph_def: GraphDef object containing a model to be transformed.
      outputs: list of tensors or node names for the model outputs.
      max_batch_size: max size for the input batch
      max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
      precision_mode: one of 'FP32', 'FP16' and 'INT8'
      minimum_segment_size: the minimum number of nodes required for a subgraph to
        be replaced by TRTEngineOp.
      is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
        network and engine at run time.
      maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
      cached_engine_batches: batch sizes used to pre-create cached engines.

    Returns:
      New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

    Raises:
      ValueError: if the provided precision mode is invalid.
      RuntimeError: if the returned status message is malformed.
    """
    supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
    if precision_mode.upper() not in supported_precision_modes:
        raise ValueError(
            ("precision mode '{}' is not supported."
             "It should be one of {}").format(precision_mode,
                                              "{'FP32', 'FP16', 'INT8'}"))
    mode = supported_precision_modes[precision_mode.upper()]
    compiled_version = get_linked_tensorrt_version()
    loaded_version = get_loaded_tensorrt_version()
    version_mismatch = False
    if loaded_version[0] < compiled_version[0]:
        tf_logging.error(
            "TensorRT version mismatch. Tensorflow was compiled against " +
            "TensorRT %s but library loaded from environment is TensorRT %s" %
            (".".join([str(x) for x in compiled_version]),
             ".".join([str(x) for x in loaded_version])) +
            ". Please make sure that correct version of TensorRT " +
            "is available in the system and added to ldconfig or LD_LIBRARY_PATH"
        )
        raise RuntimeError("Incompatible TensorRT library version")
    for i in zip(loaded_version, compiled_version):
        if i[0] != i[1]:
            tf_logging.warn("TensorRT mismatch. Compiled against version " +
                            "%s, but loaded %s. Things may not work" %
                            (".".join([str(x) for x in compiled_version]),
                             ".".join([str(x) for x in loaded_version])))
            version_mismatch = True
            break
    if not version_mismatch:
        tf_logging.info("Running against TensorRT version %s" %
                        ".".join([str(x) for x in loaded_version]))

    def py2bytes(inp):
        return inp

    def py3bytes(inp):
        return inp.encode("utf-8", errors="surrogateescape")

    def py2string(inp):
        return inp

    def py3string(inp):
        return inp.decode("utf-8")

    if _six.PY2:
        to_bytes = py2bytes
        to_string = py2string
    else:
        to_bytes = py3bytes
        to_string = py3string

    out_names = []
    for i in outputs:
        if isinstance(i, ops.Tensor):
            out_names.append(to_bytes(i.name))
        else:
            out_names.append(to_bytes(i))

    input_graph_def_str = input_graph_def.SerializeToString()

    # TODO(sami): Fix this when we can return status from C++ library
    # There is a problem with the TF internal library setup that doesn't
    # allow us to return a status object from C++.  Thus we return a pair
    # of strings where the first one is the encoded status and the second
    # one is the transformed graph's protobuf string.
    out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                      max_workspace_size_bytes, mode, minimum_segment_size,
                      is_dynamic_op, maximum_cached_engines,
                      cached_engine_batches)
    status = to_string(out[0])
    output_graph_def_string = out[1]
    del input_graph_def_str  # Save some memory
    if len(status) < 2:
        raise _impl.UnknownError(None, None, status)
    if status[:2] != "OK":
        msg = status.split(";")
        if len(msg) == 1:
            raise RuntimeError("Status message is malformed {}".format(status))
        # pylint: disable=protected-access
        raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                             int(msg[0]))
        # pylint: enable=protected-access
    output_graph_def = graph_pb2.GraphDef()
    output_graph_def.ParseFromString(output_graph_def_string)
    del output_graph_def_string  # Save some memory
    return output_graph_def
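
A minimal usage sketch for the wrapper above, assuming a frozen model at the
hypothetical path frozen_model.pb with an output node named "logits"
(substitute your own file and node names):

from tensorflow.core.framework import graph_pb2

frozen_gdef = graph_pb2.GraphDef()
with open("frozen_model.pb", "rb") as f:   # hypothetical path
    frozen_gdef.ParseFromString(f.read())

trt_gdef = create_inference_graph(
    input_graph_def=frozen_gdef,
    outputs=["logits"],                  # hypothetical output node name
    max_batch_size=8,
    max_workspace_size_bytes=1 << 30,    # 1 GiB of TensorRT workspace
    precision_mode="FP16")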
Example #2

def _get_dummy_graphdef():
    dummy_graphdef = graph_pb2.GraphDef()
    text_format.Merge(graphdef_string, dummy_graphdef)
    return dummy_graphdef
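
graphdef_string is referenced above but not defined in this snippet; it is
presumably a module-level pbtxt constant. A minimal value that
text_format.Merge would accept looks like this (illustrative only):

# Assumed module-level constant; any valid GraphDef pbtxt works here.
graphdef_string = """
node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "dtype"
    value { type: DT_FLOAT }
  }
}
"""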
Example #3
def _convert_variables_to_constants_v2_impl(func,
                                            lower_control_flow=True,
                                            aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops holding
  the same values. This makes it possible to describe the network fully with a
  single GraphDef file, and allows the removal of a lot of ops related to
  loading and saving the variables. This function runs Grappler's function
  inlining optimization in order to return a single subgraph.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

  Note that the NodeDefs in the returned GraphDef retain the original node
  names when they were created by the graph optimization. Converting the
  GraphDef to a concrete function will lose this debug information.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Inlining functions with stateful ops might lead to
      undefined execution if the function call doesn't have an outgoing
      control edge and control outputs (these should be added automatically
      in TFv2). Aggressive mode disables safety checks in the Grappler
      function optimizer.

  Returns:
    A tuple of the GraphDef containing a simplified version of the original
    graph, and the set of input indices that were converted to constants.
  """
  # Inline the graph in order to remove functions when possible.
  graph_def = _run_inline_graph_optimization(func, lower_control_flow,
                                             aggressive_inlining)

  # Get the list of all node defs, including those in the library.
  node_defs = _get_node_defs_list(graph_def)

  # Get mapping from node name to node.
  name_to_node = {_get_tensor_name(node.name): node for node in node_defs}

  # Get mapping from node name to variable value.
  tensor_data = _get_tensor_data(func)

  # Get mapping from function name to argument types.
  function_data = _get_control_flow_function_data(
      node_defs, tensor_data, name_to_node)

  # Get variable data for all nodes in `node_defs`.
  reference_variables = {}
  resource_identities = {}
  placeholders = {}
  converted_input_indices = set()

  def _save_placeholder(node_name, dtype):
    placeholders[node_name] = {
        "dtype": dtype,
        "data": tensor_data[node_name]["data"],
    }
    converted_input_indices.add(tensor_data[node_name]["index"])

  for node in node_defs:
    if node.op in _CONDITIONAL_OPS:
      # Get dtype and data for resource Placeholders.
      then_func = node.attr["then_branch"].func.name
      arg_types = function_data[then_func]["types"]
      for idx, input_tensor in enumerate(node.input[1:]):
        input_name = _get_tensor_name(input_tensor)
        if input_name in tensor_data:
          dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
          _save_placeholder(_get_tensor_name(input_tensor), dtype)
    elif node.op in _LOOP_OPS:
      # Get dtype and data for resource Placeholders.
      cond_func = node.attr["cond"].func.name
      arg_types = function_data[cond_func]["types"]
      for idx, input_tensor in enumerate(node.input):
        input_name = _get_tensor_name(input_tensor)
        if input_name in tensor_data:
          dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
          _save_placeholder(_get_tensor_name(input_tensor), dtype)
    elif (node.op == "Identity" and node.attr["T"].type == dtypes.resource and
          name_to_node[_get_tensor_name(node.input[0])].op in _LOOP_OPS):
      # Store the dtype for Identity resource ops that are outputs of While ops.
      while_node = name_to_node[_get_tensor_name(node.input[0])]
      body_func = while_node.attr["body"].func.name
      input_data = node.input[0].split(":")
      idx = 0 if len(input_data) == 1 else int(input_data[1])

      dtype = attr_value_pb2.AttrValue(
          type=function_data[body_func]["types"][idx])
      resource_identities[node.name] = dtype
    elif node.op == "VariableV2":
      # Get data for VariableV2 ops (reference variables) that cannot be lifted.
      with func.graph.as_default():
        identity_node = array_ops.identity(
            func.graph.as_graph_element(node.name + ":0"))
      reference_variables[node.name] = (
          func.prune([], [identity_node.name])()[0])
    elif node.name in tensor_data and not tensor_data[node.name]["is_variable"]:
      # Get dtype and data for non-variable Placeholders (ex. values for 1.X
      # Const ops that are loaded as Placeholders in 2.0)
      _save_placeholder(node.name, node.attr["dtype"])
    elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd"]:
      # Get dtype and data for Placeholder ops associated with ReadVariableOp
      # and ResourceGather ops. There can be an Identity in between the
      # resource op and Placeholder. Store the dtype for the Identity ops.
      input_name = _get_tensor_name(node.input[0])
      while name_to_node[input_name].op == "Identity":
        resource_identities[input_name] = node.attr["dtype"]
        input_name = _get_tensor_name(name_to_node[input_name].input[0])
      if name_to_node[input_name].op != "Placeholder":
        raise ValueError("Cannot find the Placeholder op that is an input "
                         "to the ReadVariableOp.")
      _save_placeholder(input_name, node.attr["dtype"])

  # Reconstruct the graph with constants in place of variables.
  output_graph_def = graph_pb2.GraphDef()

  for input_node in graph_def.node:
    output_node = output_graph_def.node.add()
    # Convert VariableV2 ops to Const ops.
    if input_node.name in reference_variables:
      data = reference_variables[input_node.name]
      dtype = attr_value_pb2.AttrValue(type=data.dtype.as_datatype_enum)
      _populate_const_op(output_node, input_node.name, dtype, data.numpy(),
                         data.shape)
    # Convert Placeholder ops to Const ops.
    elif input_node.name in placeholders:
      data = placeholders[input_node.name]["data"]
      dtype = placeholders[input_node.name]["dtype"]
      _populate_const_op(output_node, input_node.name, dtype, data, data.shape)
    # Update the dtype for Identity ops that are inputs to ReadVariableOps.
    elif input_node.name in resource_identities:
      output_node.CopyFrom(input_node)
      output_node.attr["T"].CopyFrom(resource_identities[input_node.name])
    # Convert ReadVariableOps to Identity ops.
    elif input_node.op == "ReadVariableOp":
      _populate_identity_op(output_node, input_node)
    # Convert ResourceGather to Gather ops with a Const axis feeding into it.
    elif input_node.op == "ResourceGather":
      if input_node.attr["batch_dims"].i != 0:
        raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
      output_axis_node = output_graph_def.node.add()
      axis_node_name = input_node.name + "/axis"
      axis_dtype = input_node.attr["Tindices"]
      axis_data = np.array(input_node.attr["batch_dims"].i)
      _populate_const_op(output_axis_node, axis_node_name, axis_dtype,
                         axis_data, axis_data.shape)

      output_node.op = "GatherV2"
      output_node.name = input_node.name
      output_node.input.extend(
          [input_node.input[0], input_node.input[1], axis_node_name])
      output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
      output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
      output_node.attr["Taxis"].CopyFrom(axis_dtype)
      if "_class" in input_node.attr:
        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
    elif input_node.op == "ResourceGatherNd":
      output_node.op = "GatherNd"
      output_node.name = input_node.name
      output_node.input.extend(
          [input_node.input[0], input_node.input[1]])
      output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
      output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
      if "_class" in input_node.attr:
        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
    # Update the function names and argument types for the conditional ops.
    elif input_node.op in _CONDITIONAL_OPS:
      _populate_if_op(output_node, input_node, function_data)
    elif input_node.op in _LOOP_OPS:
      _populate_while_op(output_node, input_node, function_data)
    else:
      output_node.CopyFrom(input_node)

  # Add functions to reconstructed graph.
  if graph_def.library:
    library = output_graph_def.library

    for input_library_func in graph_def.library.function:
      orig_func_name = input_library_func.signature.name
      new_func_name = _get_new_function_name(orig_func_name)

      # Do not copy any functions that aren't being used in the graph. Any
      # functions that are not used by control flow should have been inlined.
      if orig_func_name not in function_data:
        continue

      output_library_func = library.function.add()
      for key, value in input_library_func.ret.items():
        output_library_func.ret[key] = value
      for key, value in input_library_func.control_ret.items():
        output_library_func.control_ret[key] = value

      # Update the input types in the function signature. Update the output
      # types for functions that are while loop bodies.
      output_library_func.signature.CopyFrom(input_library_func.signature)
      output_library_func.signature.name = new_func_name
      for dtype, arg in zip(function_data[orig_func_name]["types"],
                            output_library_func.signature.input_arg):
        arg.type = dtype
      if function_data[orig_func_name]["is_also_output_type"]:
        for dtype, arg in zip(function_data[orig_func_name]["types"],
                              output_library_func.signature.output_arg):
          arg.type = dtype

      # Update the NodeDefs.
      func_variables = {
          node.name: node.input[0]
          for node in input_library_func.node_def
          if node.op == "ReadVariableOp"
      }

      for input_node in input_library_func.node_def:
        output_node = output_library_func.node_def.add()
        # Convert ReadVariableOps to Identity ops.
        if input_node.op == "ReadVariableOp":
          _populate_identity_op(output_node, input_node)
        # Update the function names and argument types for the conditional ops.
        elif input_node.op in _CONDITIONAL_OPS:
          _populate_if_op(output_node, input_node, function_data)
        elif input_node.op in _LOOP_OPS:
          _populate_while_op(output_node, input_node, function_data)
        else:
          output_node.CopyFrom(input_node)
          # Convert :value to :output for ops that use the ReadVariableOp.
          for idx, full_name in enumerate(input_node.input):
            input_name = _get_tensor_name(full_name)
            if input_name in func_variables:
              full_name_parts = full_name.split(":")
              full_name_parts[1] = "output"
              input_name = ":".join(full_name_parts)
              output_node.input[idx] = input_name

  output_graph_def.versions.CopyFrom(graph_def.versions)
  return (output_graph_def, converted_input_indices)
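
This private helper underlies TensorFlow 2's public
convert_variables_to_constants_v2. A short sketch of the public entry point
on a toy Keras model (the model shape and names are illustrative):

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2)

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
run_model = tf.function(lambda x: model(x))
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec([None, 8], tf.float32))

# Variables become Const nodes in the frozen function's graph.
frozen_func = convert_variables_to_constants_v2(concrete_func)
frozen_gdef = frozen_func.graph.as_graph_def()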
Example #4
    def create_test_graph(self):
        input_node = node_def_pb2.NodeDef()
        input_node.name = "input"
        input_node.op = "Placeholder"
        input_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(
            type=dtypes.float32.as_datatype_enum))
        
        conv1_weight_node = node_def_pb2.NodeDef()
        conv1_weight_node.name = "conv1_weights"
        conv1_weight_node.op = "Const"
        conv1_weight_value = np.float32(np.abs(np.random.randn(3, 3, 3, 32)))
        conv1_weight_node.attr['dtype'].CopyFrom(
            attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
        conv1_weight_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(
                conv1_weight_value, conv1_weight_value.dtype.type,
                conv1_weight_value.shape)))
        
        conv1_node = node_def_pb2.NodeDef()
        conv1_node.name = "conv1"
        conv1_node.op = "Conv2D"
        conv1_node.attr['T'].CopyFrom(attr_value_pb2.AttrValue(
            type=dtypes.float32.as_datatype_enum))
        conv1_node.input.extend([input_node.name, conv1_weight_node.name])
        conv1_node.attr['strides'].CopyFrom(attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(i=[1,2,2,1])))
        conv1_node.attr['dilations'].CopyFrom(attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(i=[1,1,1,1])))
        conv1_node.attr['padding'].CopyFrom(attr_value_pb2.AttrValue(s=b'SAME'))
        conv1_node.attr['data_format'].CopyFrom(attr_value_pb2.AttrValue(s=b'NHWC'))
        
        bias_node = node_def_pb2.NodeDef()
        bias_node.name = "conv1_bias"
        bias_node.op = "Const"
        bias_value = np.float32(np.abs(np.random.randn(32)))
        bias_node.attr['dtype'].CopyFrom(attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
        bias_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            bias_value, bias_value.dtype.type, bias_value.shape)))
        
        bias_add_node = node_def_pb2.NodeDef()
        bias_add_node.name = "conv1_bias_add"
        bias_add_node.op = "BiasAdd"
        bias_add_node.attr['T'].CopyFrom(attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
        bias_add_node.input.extend([conv1_node.name, bias_node.name])
        bias_add_node.attr['data_format'].CopyFrom(attr_value_pb2.AttrValue(s=b'NHWC'))
        
        relu_node = node_def_pb2.NodeDef()
        relu_node.op = "Relu"
        relu_node.name = "relu"
        relu_node.input.extend([bias_add_node.name])
        
        conv2_weight_node = node_def_pb2.NodeDef()
        conv2_weight_node.name = "conv2_weights"
        conv2_weight_node.op = "Const"
        # conv1 produces 32 channels, so conv2's kernel needs 32 input channels.
        conv2_weight_value = np.float32(np.abs(np.random.randn(3, 3, 32, 32)))
        conv2_weight_node.attr['dtype'].CopyFrom(
            attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
        conv2_weight_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(
                conv2_weight_value, conv2_weight_value.dtype.type,
                conv2_weight_value.shape)))
        
        conv2_node = node_def_pb2.NodeDef()
        conv2_node.name = "conv2"
        conv2_node.op = "Conv2D"
        conv2_node.attr['T'].CopyFrom(attr_value_pb2.AttrValue(
            type=dtypes.float32.as_datatype_enum))
        conv2_node.input.extend([relu_node.name, conv2_weight_node.name])
        conv2_node.attr['strides'].CopyFrom(attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(i=[1,2,2,1])))
        conv2_node.attr['dilations'].CopyFrom(attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(i=[1,1,1,1])))
        conv2_node.attr['padding'].CopyFrom(attr_value_pb2.AttrValue(s=b'SAME'))
        conv2_node.attr['data_format'].CopyFrom(attr_value_pb2.AttrValue(s=b'NHWC'))

        bias_node2 = node_def_pb2.NodeDef()
        bias_node2.name = "conv2_bias"
        bias_node2.op = "Const"
        bias_value2 = np.float32(np.abs(np.random.randn(32)))
        bias_node2.attr['dtype'].CopyFrom(attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
        bias_node2.attr['value'].CopyFrom(attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            bias_value2, bias_value2.dtype.type, bias_value2.shape)))
        
        bias_add_node2 = node_def_pb2.NodeDef()
        bias_add_node2.name = "conv2_bias_add"
        bias_add_node2.op = "BiasAdd"
        bias_add_node2.attr['T'].CopyFrom(attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
        bias_add_node2.input.extend([conv2_node.name, bias_node2.name])
        bias_add_node2.attr['data_format'].CopyFrom(attr_value_pb2.AttrValue(s=b'NHWC'))
        
        relu_node2 = node_def_pb2.NodeDef()
        relu_node2.op = "Relu"
        relu_node2.name = "relu2"
        relu_node2.input.extend([bias_add_node2.name])
        
        conv3_weight_node = node_def_pb2.NodeDef()
        conv3_weight_node.name = "conv3_weights"
        conv3_weight_node.op = "Const"
        # relu2 also carries 32 channels, so conv3's kernel needs 32 input channels.
        conv3_weight_value = np.float32(np.abs(np.random.randn(3, 3, 32, 32)))
        conv3_weight_node.attr['dtype'].CopyFrom(
            attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
        conv3_weight_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(
                conv3_weight_value, conv3_weight_value.dtype.type,
                conv3_weight_value.shape)))
        
        conv3_node = node_def_pb2.NodeDef()
        conv3_node.name = "conv3"
        conv3_node.op = "Conv2D"
        conv3_node.attr['T'].CopyFrom(attr_value_pb2.AttrValue(
            type=dtypes.float32.as_datatype_enum))
        conv3_node.input.extend([relu_node2.name, conv3_weight_node.name])
        conv3_node.attr['strides'].CopyFrom(attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(i=[1,2,2,1])))
        conv3_node.attr['dilations'].CopyFrom(attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(i=[1,1,1,1])))
        conv3_node.attr['padding'].CopyFrom(attr_value_pb2.AttrValue(s=b'SAME'))
        conv3_node.attr['data_format'].CopyFrom(attr_value_pb2.AttrValue(s=b'NHWC'))

        self.test_graph = graph_pb2.GraphDef()

        self.test_graph.node.extend([input_node, 
                                     conv1_weight_node, 
                                     conv1_node, 
                                     bias_node, 
                                     bias_add_node, 
                                     relu_node,
                                     conv2_weight_node, 
                                     conv2_node, 
                                     bias_node2, 
                                     bias_add_node2, 
                                     relu_node2,
                                     conv3_weight_node, 
                                     conv3_node, 
                                    ])
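
A quick way to sanity-check a hand-built GraphDef like the one above is to
import it into a fresh graph, which forces TensorFlow to validate node wiring,
dtypes, and attributes (a sketch, assuming the TF 1.x-style APIs used in
these examples):

import tensorflow as tf

def validate_graphdef(graph_def):
    # Importing raises if any node has bad inputs, attrs, or shapes.
    with tf.Graph().as_default():
        tf.import_graph_def(graph_def, name="")

# e.g. after create_test_graph(): validate_graphdef(self.test_graph)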
Example #5
    def load_frozenmodel(self):
        """
        Loads the graph from a frozen model file.
        """
        print('> Loading frozen model into memory')
        if (self.config.MODEL_TYPE == 'od' and self.config.SPLIT_MODEL):
            # load a frozen Model and split it into GPU and CPU graphs
            # Hardcoded split points for ssd_mobilenet
            input_graph = tf.Graph()
            with tf.Session(graph=input_graph, config=self._tf_config):
                if self.config.SSD_SHAPE == 600:
                    shape = 7326
                else:
                    shape = 1917
                self.score = tf.placeholder(tf.float32,
                                            shape=(None, shape,
                                                   self.config.NUM_CLASSES),
                                            name=self.config.SPLIT_NODES[0])
                self.expand = tf.placeholder(tf.float32,
                                             shape=(None, shape, 1, 4),
                                             name=self.config.SPLIT_NODES[1])
                for node in input_graph.as_graph_def().node:
                    if node.name == self.config.SPLIT_NODES[0]:
                        score_def = node
                    if node.name == self.config.SPLIT_NODES[1]:
                        expand_def = node

            with self.detection_graph.as_default():
                od_graph_def = tf.GraphDef()
                with tf.gfile.GFile(self.config.MODEL_PATH, 'rb') as fid:
                    serialized_graph = fid.read()
                    od_graph_def.ParseFromString(serialized_graph)

                    edges = {}
                    name_to_node_map = {}
                    node_seq = {}
                    seq = 0
                    for node in od_graph_def.node:
                        n = self._node_name(node.name)
                        name_to_node_map[n] = node
                        edges[n] = [self._node_name(x) for x in node.input]
                        node_seq[n] = seq
                        seq += 1
                    for d in self.config.SPLIT_NODES:
                        assert d in name_to_node_map, "%s is not in graph" % d

                    nodes_to_keep = set()
                    next_to_visit = self.config.SPLIT_NODES[:]

                    while next_to_visit:
                        n = next_to_visit[0]
                        del next_to_visit[0]
                        if n in nodes_to_keep: continue
                        nodes_to_keep.add(n)
                        next_to_visit += edges[n]

                    nodes_to_keep_list = sorted(list(nodes_to_keep),
                                                key=lambda n: node_seq[n])
                    nodes_to_remove = set()

                    for n in node_seq:
                        if n in nodes_to_keep_list: continue
                        nodes_to_remove.add(n)
                    nodes_to_remove_list = sorted(list(nodes_to_remove),
                                                  key=lambda n: node_seq[n])

                    keep = graph_pb2.GraphDef()
                    for n in nodes_to_keep_list:
                        keep.node.extend([copy.deepcopy(name_to_node_map[n])])

                    remove = graph_pb2.GraphDef()
                    remove.node.extend([score_def])
                    remove.node.extend([expand_def])
                    for n in nodes_to_remove_list:
                        remove.node.extend(
                            [copy.deepcopy(name_to_node_map[n])])

                    with tf.device('/gpu:0'):
                        tf.import_graph_def(keep, name='')
                    with tf.device('/cpu:0'):
                        tf.import_graph_def(remove, name='')
        else:
            # default model loading procedure
            with self.detection_graph.as_default():
                od_graph_def = tf.GraphDef()
                with tf.gfile.GFile(self.config.MODEL_PATH, 'rb') as fid:
                    serialized_graph = fid.read()
                    od_graph_def.ParseFromString(serialized_graph)
                    tf.import_graph_def(od_graph_def, name='')
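
self._node_name is used above but not shown; a plausible sketch, following
the usual convention of stripping the control-edge marker and any output
index:

def _node_name(self, name):
    # "^foo" marks a control input; "foo:1" selects an output slot.
    if name.startswith("^"):
        name = name[1:]
    return name.split(":")[0]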
Example #6
# read frozen graph and display nodes
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

graph = tf.get_default_graph().as_graph_def(add_shapes=True)
with tf.gfile.Open('model/original/hep_frozen_bs_32.pb', 'rb') as f:
    data = f.read()
    graph.ParseFromString(data)

    #full graph
    #display_nodes(graph.node)

    #prune graph

    #first step
    graph.node[25].input[0] = 'Relu'
    nodes = graph.node[:12] + graph.node[23:]
    graph = graph_pb2.GraphDef()
    graph.node.extend(nodes)

    #next step
    graph.node[33].input[0] = 'Relu_1'
    nodes = graph.node[:20] + graph.node[31:]
    graph = graph_pb2.GraphDef()
    graph.node.extend(nodes)

    #next step
    graph.node[41].input[0] = 'Relu_2'
    nodes = graph.node[:28] + graph.node[39:]
    graph = graph_pb2.GraphDef()
    graph.node.extend(nodes)

    #next step
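
display_nodes is called (commented out) above but not defined; a minimal
sketch that prints each node's index, name, op, and inputs:

def display_nodes(nodes):
    for i, node in enumerate(nodes):
        print('%d %s (%s)' % (i, node.name, node.op))
        for j, inp in enumerate(node.input):
            print('    %d %s' % (j, inp))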
Example #7
import argparse
import time

import tensorflow as tf

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Tensorflow Openpose Inference')
    parser.add_argument('--imgpath', type=str, default='./images/wywh.jpg')
    parser.add_argument('--input-width', type=int, default=656)
    parser.add_argument('--input-height', type=int, default=368)
    args = parser.parse_args()

    t0 = time.time()

    tf.reset_default_graph()
    
    from tensorflow.core.framework import graph_pb2
    graph_def = graph_pb2.GraphDef()
    # Download model from https://www.dropbox.com/s/2dw1oz9l9hi9avg/optimized_openpose.pb
    with open('models/optimized_openpose.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    t1 = time.time()
    print(t1 - t0)

    inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
    heatmaps_tensor = tf.get_default_graph().get_tensor_by_name('Mconv7_stage6_L2/BiasAdd:0')
    pafs_tensor = tf.get_default_graph().get_tensor_by_name('Mconv7_stage6_L1/BiasAdd:0')

    t2 = time.time()
    print(t2 - t1)
Example #8
def fold_batch_norms(input_graph_def):
  """Removes batch normalization ops by folding them into convolutions.

  Batch normalization during training has multiple dynamic parameters that are
  updated, but once the graph is finalized these become constants. That means
  there's an opportunity to reduce the computations down to a scale and
  addition, rather than the more expensive multiple ops, and even bake the
  scaling into the convolution weights. This function identifies the typical
  pattern of batch normalization subgraphs, and performs the transformation to
  fold the computations down into a simpler form. It currently only spots batch
  normalization that's performed by the BatchNormWithGlobalNormalization and
  FusedBatchNorm ops, and will need to be extended in the future to handle the
  newer style.

  Args:
    input_graph_def: A GraphDef containing a model.

  Returns:
    Modified graph with BN ops removed, and modified weights.

  Raises:
    ValueError: If the graph is badly formed with duplicate node names.
  """
  input_node_map = {}
  for node in input_graph_def.node:
    if node.name not in input_node_map:
      input_node_map[node.name] = node
    else:
      raise ValueError("Duplicate node names detected for ", node.name)

  nodes_to_skip = {}
  new_ops = []
  for node in input_graph_def.node:
    if node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm"):
      continue

    conv_op = node_from_map(input_node_map,
                            node.input[INPUT_ORDER[node.op].index("conv_op")])
    if conv_op.op != "Conv2D":
      tf_logging.warning(
          "Didn't find expected Conv2D input to '%s'" % node.name)
      continue

    weights_op = node_from_map(input_node_map, conv_op.input[1])
    if weights_op.op != "Const":
      tf_logging.warning("Didn't find expected conv Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (conv_op.name, weights_op))
      continue
    weights = values_from_const(weights_op)
    channel_count = weights.shape[3]

    mean_op = node_from_map(input_node_map,
                            node.input[INPUT_ORDER[node.op].index("mean_op")])
    if mean_op.op != "Const":
      tf_logging.warning("Didn't find expected mean Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (node.name, mean_op))
      continue
    mean_value = values_from_const(mean_op)
    if mean_value.shape != (channel_count,):
      tf_logging.warning("Incorrect shape for mean, found %s, expected %s,"
                         " for node %s" % (str(mean_value.shape), str(
                             (channel_count,)), node.name))
      continue

    var_op = node_from_map(input_node_map,
                           node.input[INPUT_ORDER[node.op].index("var_op")])
    if var_op.op != "Const":
      tf_logging.warning("Didn't find expected var Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (node.name, var_op))
      continue
    var_value = values_from_const(var_op)
    if var_value.shape != (channel_count,):
      tf_logging.warning("Incorrect shape for var, found %s, expected %s,"
                         " for node %s" % (str(var_value.shape), str(
                             (channel_count,)), node.name))
      continue

    beta_op = node_from_map(input_node_map,
                            node.input[INPUT_ORDER[node.op].index("beta_op")])
    if beta_op.op != "Const":
      tf_logging.warning("Didn't find expected beta Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (node.name, beta_op))
      continue
    beta_value = values_from_const(beta_op)
    if beta_value.shape != (channel_count,):
      tf_logging.warning("Incorrect shape for beta, found %s, expected %s,"
                         " for node %s" % (str(beta_value.shape), str(
                             (channel_count,)), node.name))
      continue

    gamma_op = node_from_map(input_node_map,
                             node.input[INPUT_ORDER[node.op].index("gamma_op")])
    if gamma_op.op != "Const":
      tf_logging.warning("Didn't find expected gamma Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (node.name, gamma_op))
      continue
    gamma_value = values_from_const(gamma_op)
    if gamma_value.shape != (channel_count,):
      tf_logging.warning("Incorrect shape for gamma, found %s, expected %s,"
                         " for node %s" % (str(gamma_value.shape), str(
                             (channel_count,)), node.name))
      continue

    variance_epsilon_value = node.attr[EPSILON_ATTR[node.op]].f
    nodes_to_skip[node.name] = True
    nodes_to_skip[weights_op.name] = True
    nodes_to_skip[mean_op.name] = True
    nodes_to_skip[var_op.name] = True
    nodes_to_skip[beta_op.name] = True
    nodes_to_skip[gamma_op.name] = True
    nodes_to_skip[conv_op.name] = True

    if scale_after_normalization(node):
      scale_value = (
          (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) *
          gamma_value)
    else:
      scale_value = (
          1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value))
    offset_value = (-mean_value * scale_value) + beta_value
    scaled_weights = np.copy(weights)
    it = np.nditer(
        scaled_weights, flags=["multi_index"], op_flags=["readwrite"])
    while not it.finished:
      current_scale = scale_value[it.multi_index[3]]
      it[0] *= current_scale
      it.iternext()
    scaled_weights_op = node_def_pb2.NodeDef()
    scaled_weights_op.op = "Const"
    scaled_weights_op.name = weights_op.name
    scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"])
    scaled_weights_op.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            scaled_weights, weights.dtype.type, weights.shape)))
    new_conv_op = node_def_pb2.NodeDef()
    new_conv_op.CopyFrom(conv_op)
    offset_op = node_def_pb2.NodeDef()
    offset_op.op = "Const"
    offset_op.name = conv_op.name + "_bn_offset"
    offset_op.attr["dtype"].CopyFrom(mean_op.attr["dtype"])
    offset_op.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            offset_value, mean_value.dtype.type, offset_value.shape)))
    bias_add_op = node_def_pb2.NodeDef()
    bias_add_op.op = "BiasAdd"
    bias_add_op.name = node.name
    bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"])
    bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"])
    bias_add_op.input.extend([new_conv_op.name, offset_op.name])
    new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op])

  result_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name in nodes_to_skip:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    result_graph_def.node.extend([new_node])

  result_graph_def.node.extend(new_ops)
  return result_graph_def
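
The folding arithmetic above computes scale = gamma / sqrt(var + eps) (gamma
omitted when scale_after_normalization is false) and offset = beta - mean *
scale; because convolution is linear, scaling the weights per output channel
is equivalent to scaling the conv output. A small numpy check of that
identity with illustrative values:

import numpy as np

rng = np.random.RandomState(0)
c = 4                                  # channel count
gamma, beta = rng.randn(c), rng.randn(c)
mean, var = rng.randn(c), rng.rand(c) + 1e-3
eps = 1e-5

scale = gamma / np.sqrt(var + eps)
offset = beta - mean * scale

y = rng.randn(c)                       # stand-in for per-channel conv output
bn_out = (y - mean) / np.sqrt(var + eps) * gamma + beta
assert np.allclose(bn_out, y * scale + offset)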
Example #9
def fuse_resize_and_conv(input_graph_def, output_node_names):
  """Merges preceding resize and mirror pad ops into a specialized convolution.

  There's a common pattern of enlarging the input to a convolution using a
  resize operation, and also using MirrorPad to extend the boundaries to that
  zero edge pixels don't bleed inwards when convolving. This routine looks for
  that pattern of operations, and fuses them together into a Conv2DWithResizeOp.

  Args:
    input_graph_def: A GraphDef containing a model.
    output_node_names: A list of names of the nodes that produce the final
      results.

  Returns:
    Modified graph with resize and pad ops merged.

  Raises:
    ValueError: If the graph is badly formed with duplicate node names.
  """

  input_node_map = {}
  for node in input_graph_def.node:
    if node.name not in input_node_map:
      input_node_map[node.name] = node
    else:
      raise ValueError("Duplicate node names detected for ", node.name)

  node_reference_count = collections.defaultdict(int)
  for node in input_graph_def.node:
    for input_name in node.input:
      stripped_name = node_name_from_input(input_name)
      node_reference_count[stripped_name] += 1
  for output_name in output_node_names:
    node_reference_count[output_name] += 1

  new_ops = []
  for node in input_graph_def.node:

    if node.op != "Conv2D":
      continue
    conv_op = node

    input_op = node_from_map(input_node_map, conv_op.input[0])
    if input_op.op == "MirrorPad":
      mirror_pad_op = input_op
      resize_op = node_from_map(input_node_map, mirror_pad_op.input[0])
      if resize_op.op != "ResizeBilinear":
        resize_op = None
    else:
      mirror_pad_op = None
      if input_op.op == "ResizeBilinear":
        resize_op = input_op
      else:
        resize_op = None

    # There are no ops to be fused into the conv, so skip replacing this one.
    if not mirror_pad_op and not resize_op:
      continue

    # We're replacing this node, so make sure the old one is removed.
    node_reference_count[conv_op.name] = 0
    if mirror_pad_op:
      node_reference_count[mirror_pad_op.name] -= 1
    if resize_op:
      node_reference_count[resize_op.name] -= 1

    fused_conv_op = node_def_pb2.NodeDef()
    if resize_op:
      fused_conv_op.op = "FusedResizeAndPadConv2D"
    else:
      fused_conv_op.op = "FusedPadConv2D"
    fused_conv_op.name = conv_op.name
    if mirror_pad_op:
      mirror_paddings_name = mirror_pad_op.input[1]
      mirror_paddings_mode = mirror_pad_op.attr["mode"]
    else:
      # If there was no MirrorPad op, then create settings that make the padding
      # stage of the fused operation a no-op.
      paddings_op = node_def_pb2.NodeDef()
      paddings_op.op = "Const"
      paddings_op.name = conv_op.name + "_dummy_paddings"
      paddings_op.attr["dtype"].CopyFrom(
          attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum))
      paddings_op.attr["value"].CopyFrom(
          attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
              [0, 0, 0, 0, 0, 0, 0, 0], dtypes.int32, [4, 2])))
      new_ops.extend([paddings_op])
      mirror_paddings_name = paddings_op.name
      mirror_paddings_mode = attr_value_pb2.AttrValue(s=b"REFLECT")
    if resize_op:
      fused_conv_op.input.extend([
          resize_op.input[0], resize_op.input[1], mirror_paddings_name,
          conv_op.input[1]
      ])
      fused_conv_op.attr["resize_align_corners"].CopyFrom(
          resize_op.attr["align_corners"])
    else:
      fused_conv_op.input.extend(
          [mirror_pad_op.input[0], mirror_paddings_name, conv_op.input[1]])
    fused_conv_op.attr["T"].CopyFrom(conv_op.attr["T"])
    fused_conv_op.attr["mode"].CopyFrom(mirror_paddings_mode)
    fused_conv_op.attr["strides"].CopyFrom(conv_op.attr["strides"])
    fused_conv_op.attr["padding"].CopyFrom(conv_op.attr["padding"])
    new_ops.extend([fused_conv_op])

  result_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node_reference_count[node.name] < 1:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    result_graph_def.node.extend([new_node])

  result_graph_def.node.extend(new_ops)
  return result_graph_def
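
A hedged usage sketch chaining this routine with fold_batch_norms from
Example #8 (the file path and output node name are hypothetical):

from tensorflow.core.framework import graph_pb2

gdef = graph_pb2.GraphDef()
with open("frozen_model.pb", "rb") as f:   # hypothetical path
    gdef.ParseFromString(f.read())

gdef = fold_batch_norms(gdef)
gdef = fuse_resize_and_conv(gdef, ["logits"])  # hypothetical output name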
Example #10

    def test_remove_redundant_quantization(self):
        a_constant_name = "a_constant"
        a_constant_min_name = "a_constant_min"
        a_constant_max_name = "a_constant_max"
        a_dequantize_name = "a_dequantize"
        a_quantize_name = "a_quantize"
        b_constant_name = "b_constant"
        b_constant_min_name = "b_constant_min"
        b_constant_max_name = "b_constant_max"
        b_dequantize_name = "b_dequantize"
        b_quantize_name = "b_quantize"
        mat_mul_name = "mat_mul"
        graph_def = graph_pb2.GraphDef()
        a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                         value=(0, ),
                                                         dtype=dtypes.quint8,
                                                         shape=[])
        graph_def.node.extend([a_constant])
        a_constant_min = quantize_graph.create_constant_node(
            a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
        graph_def.node.extend([a_constant_min])
        a_constant_max = quantize_graph.create_constant_node(
            a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
        graph_def.node.extend([a_constant_max])
        a_dequantize_node = quantize_graph.create_node(
            "Dequantize", a_dequantize_name,
            [a_constant_name, a_constant_min_name, a_constant_max_name])
        quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
        graph_def.node.extend([a_dequantize_node])
        a_quantize_node = quantize_graph.create_node(
            "QuantizeV2", a_quantize_name, [
                a_dequantize_name, a_dequantize_name + ":1",
                a_dequantize_name + ":2"
            ])
        quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
        graph_def.node.extend([a_quantize_node])
        b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                         value=(0, ),
                                                         dtype=dtypes.quint8,
                                                         shape=[])
        graph_def.node.extend([b_constant])
        b_constant_min = quantize_graph.create_constant_node(
            b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
        graph_def.node.extend([b_constant_min])
        b_constant_max = quantize_graph.create_constant_node(
            b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
        graph_def.node.extend([b_constant_max])
        b_dequantize_node = quantize_graph.create_node(
            "Dequantize", b_dequantize_name,
            [b_constant_name, b_constant_min_name, b_constant_max_name])
        quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
        graph_def.node.extend([b_dequantize_node])
        b_quantize_node = quantize_graph.create_node(
            "QuantizeV2", b_quantize_name, [
                b_dequantize_name, b_dequantize_name + ":1",
                b_dequantize_name + ":2"
            ])
        quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
        graph_def.node.extend([b_quantize_node])
        mat_mul_node = quantize_graph.create_node(
            "QuantizedMatMul", mat_mul_name, [
                a_quantize_name, b_quantize_name, a_quantize_name + ":1",
                a_quantize_name + ":2", b_quantize_name + ":1",
                b_quantize_name + ":2"
            ])
        quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
        quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
        graph_def.node.extend([mat_mul_node])

        expected_output = graph_pb2.GraphDef()
        a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                         value=(0, ),
                                                         dtype=dtypes.quint8,
                                                         shape=[])
        expected_output.node.extend([a_constant])
        a_constant_min = quantize_graph.create_constant_node(
            a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
        expected_output.node.extend([a_constant_min])
        a_constant_max = quantize_graph.create_constant_node(
            a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
        expected_output.node.extend([a_constant_max])
        b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                         value=(0, ),
                                                         dtype=dtypes.quint8,
                                                         shape=[])
        expected_output.node.extend([b_constant])
        b_constant_min = quantize_graph.create_constant_node(
            b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
        expected_output.node.extend([b_constant_min])
        b_constant_max = quantize_graph.create_constant_node(
            b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
        expected_output.node.extend([b_constant_max])
        mat_mul_node = quantize_graph.create_node(
            "QuantizedMatMul", mat_mul_name, [
                a_constant_name, b_constant_name, a_constant_min_name,
                a_constant_max_name, b_constant_min_name, b_constant_max_name
            ])
        quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
        quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
        expected_output.node.extend([mat_mul_node])
        expected_output.versions.CopyFrom(graph_def.versions)
        expected_output.library.CopyFrom(graph_def.library)

        rewriter = quantize_graph.GraphRewriter(graph_def, [mat_mul_name],
                                                quantized_input_range=None)
        output = rewriter.remove_redundant_quantization(graph_def)
        stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
        self.assertProtoEquals(expected_output, stripped_output)
Example #11
    def do_transformation(self):
        GraphAnalyzer().graph = self.model

        graph_info = GraphAnalyzer().parse_graph()
        node_hash_info = {}
        loc_attr_node = []

        for _, v in graph_info.items():
            if '_class' in v.node.attr:
                loc_attr_node.append(
                    v.node.attr['_class'].list.s[0].decode().split(':@')[-1])

        for node_name, i in graph_info.items():
            if node_name in loc_attr_node or i.node.op not in ('QuantizeV2',
                                                               "Const"):
                continue

            hash_value = self._gen_node_hash(graph_info, i.node)

            if hash_value not in node_hash_info:
                node_hash_info[hash_value] = [node_name]

            if node_name not in node_hash_info[hash_value]:
                node_hash_info[hash_value].append(node_name)

        for _, v in node_hash_info.items():
            if len(v) == 1 or v[0] not in graph_info:
                continue
            node_type = graph_info[v[0]].node.op
            for j in v[1:]:
                if node_type == 'Const' and j in graph_info:
                    output_op_types = [
                        graph_info[out_name].node.op in self.control_op_types
                        for out_name in graph_info[j].outputs
                    ]
                    if any(output_op_types):
                        continue

                    next_node = graph_info[j].outputs[0]
                    matched_index = 0
                    for index, origin_input in enumerate(
                            graph_info[next_node].node.input):
                        if origin_input == j:
                            matched_index = index
                            break

                    graph_info[next_node].node.input[matched_index] = v[0]
                    graph_info[v[0]].outputs.append(j)
                    graph_info.pop(j)

                elif node_type == 'QuantizeV2':
                    next_node = graph_info[j].outputs[0]
                    quantize_v2_output_names = (j, j + ':1', j + ':2')

                    replace_index = [
                        list(graph_info[next_node].node.input).index(i)
                        for i in quantize_v2_output_names
                    ]

                    graph_info[next_node].node.input[replace_index[0]] = v[0]

                    graph_info[next_node].node.input[
                        replace_index[1]] = v[0] + ':1'
                    graph_info[next_node].node.input[
                        replace_index[2]] = v[0] + ':2'

                    graph_info[v[0]].outputs.append(j)
                    graph_info.pop(graph_info[j].node.input[1])
                    graph_info.pop(graph_info[j].node.input[2])

                    graph_info.pop(j)
                else:
                    self.logger.debug('Unknown Op type {}'.format(node_type))

        output_graph_def = graph_pb2.GraphDef()

        for _, node_info in graph_info.items():
            output_graph_def.node.extend([node_info.node])

        return output_graph_def
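
self._gen_node_hash is not shown; a plausible sketch fingerprints the fields
that make two Const or QuantizeV2 nodes interchangeable: the op, its
attributes, and its inputs (graph_info is accepted only to match the call
site):

import hashlib

def _gen_node_hash(self, graph_info, node):
    # Nodes with identical op, attrs, and inputs compute identical outputs.
    md5 = hashlib.md5()
    md5.update(node.op.encode())
    for key in sorted(node.attr):
        md5.update(key.encode())
        md5.update(node.attr[key].SerializeToString())
    for input_name in node.input:
        md5.update(input_name.encode())
    return md5.hexdigest()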
Example #12

    def test_keep_control_edges(self):
        no_op_name = "no_op"
        a_constant_name = "a_constant"
        b_constant_name = "b_constant"
        a_check_name = "a_check"
        b_check_name = "b_check"
        a_identity_name = "a_identity"
        b_identity_name = "b_identity"
        add_name = "add"
        graph_def = graph_pb2.GraphDef()
        no_op = quantize_graph.create_node("NoOp", no_op_name, [])
        graph_def.node.extend([no_op])
        a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                         value=1,
                                                         dtype=dtypes.float32,
                                                         shape=[])
        graph_def.node.extend([a_constant])
        a_check_node = quantize_graph.create_node("CheckNumerics",
                                                  a_check_name,
                                                  [a_constant_name])
        graph_def.node.extend([a_check_node])
        a_identity_node = quantize_graph.create_node(
            "Identity", a_identity_name,
            [a_constant_name, "^" + a_check_name, "^" + no_op_name])
        graph_def.node.extend([a_identity_node])
        b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                         value=1,
                                                         dtype=dtypes.float32,
                                                         shape=[])
        graph_def.node.extend([b_constant])
        b_check_node = quantize_graph.create_node("CheckNumerics",
                                                  b_check_name,
                                                  [b_constant_name])
        graph_def.node.extend([b_check_node])
        b_identity_node = quantize_graph.create_node(
            "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
        graph_def.node.extend([b_identity_node])
        add_node = quantize_graph.create_node(
            "Add", add_name, [a_identity_name, b_identity_name])
        quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
        graph_def.node.extend([add_node])

        expected_output = graph_pb2.GraphDef()
        no_op = quantize_graph.create_node("NoOp", no_op_name, [])
        expected_output.node.extend([no_op])
        a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                         value=1,
                                                         dtype=dtypes.float32,
                                                         shape=[])
        expected_output.node.extend([a_constant])
        a_identity_node = quantize_graph.create_node(
            "Identity", a_identity_name, [a_constant_name, "^" + no_op_name])
        expected_output.node.extend([a_identity_node])
        b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                         value=1,
                                                         dtype=dtypes.float32,
                                                         shape=[])
        expected_output.node.extend([b_constant])
        add_node = quantize_graph.create_node(
            "Add", add_name, [a_identity_name, b_constant_name])
        quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
        expected_output.node.extend([add_node])
        expected_output.versions.CopyFrom(graph_def.versions)
        expected_output.library.CopyFrom(graph_def.library)

        output = graph_util.remove_training_nodes(graph_def)
        stripped_output = graph_util.extract_sub_graph(output, [add_name])
        self.assertProtoEquals(expected_output, stripped_output)
Example #13
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

    Args:
      flags_obj: An object containing parsed flags. See define_resnet_flags()
        for details.
      model_function: the function that instantiates the Model and builds the
        ops for train/eval. This will be passed directly into the estimator.
      input_function: the function that processes the dataset and returns a
        dataset that the estimator can train on. This will be wrapped with
        all the relevant flags for running and passed to estimator.
      dataset_name: the name of the dataset for training and evaluation. This
        is used for logging purposes.
      shape: list of ints representing the shape of the images used for
        training. This is only used if flags_obj.export_dir is passed.
    """

    model_helpers.apply_clean(flags.FLAGS)

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj)
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train():
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=flags_obj.epochs_between_evals,
            num_gpus=flags_core.get_num_gpus(flags_obj))

    def input_fn_eval():
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1)

    total_training_cycle = (flags_obj.train_epochs //
                            flags_obj.epochs_between_evals)
    profiler_hook = tf.train.ProfilerHook(save_steps=100,
                                          save_secs=None,
                                          output_dir="profs",
                                          show_memory=True,
                                          show_dataflow=True)

    # DEBUG: dump estimated per-op output sizes for the serialized graph.
    # gpb and pbtf are assumed module-level aliases for graph_pb2 and
    # google.protobuf.text_format respectively.
    gdef = gpb.GraphDef()

    with open('/tmp/cifar10_model/graph.pbtxt', 'r') as fh:
        graph_str = fh.read()

    pbtf.Parse(graph_str, gdef)
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(gdef)

        # Map op name -> estimated size in bytes of its first output
        # (-1 when the op has no outputs or an unknown rank).
        operations_tensors = {}
        unknown_shape_count = 0
        no_output_count = 0
        for operation in graph.get_operations():
            outputs = operation.values()
            if len(outputs) > 0:
                if outputs[0].shape.ndims is not None:
                    operation_shape = outputs[0].shape.as_list()
                    operation_dtype_size = outputs[0].dtype.size
                    if operation_dtype_size is not None:
                        operation_no_of_elements = 1
                        for dim in operation_shape:
                            if dim is not None:
                                operation_no_of_elements *= dim
                        operations_tensors[operation.name] = (
                            operation_no_of_elements * operation_dtype_size)
                    else:
                        unknown_shape_count += 1
                else:
                    unknown_shape_count += 1
                    operations_tensors[operation.name] = -1
            else:
                no_output_count += 1
                operations_tensors[operation.name] = -1

        print(unknown_shape_count)
        print(no_output_count)

        with open('tensors_sz.json', 'w') as f:
            json.dump(operations_tensors, f)

    for cycle_index in range(total_training_cycle):
        tf.logging.info('Starting a training cycle: %d/%d', cycle_index,
                        total_training_cycle)

        classifier.train(input_fn=input_fn_train,
                         hooks=[profiler_hook],
                         max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size)
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
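
# Standalone sketch of the debug block inside resnet_main above: estimate
# each op's first-output size in bytes for an arbitrary TF 1.x graph and
# dump the map to JSON. The toy graph here is illustrative only.
import json

import tensorflow as tf

with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.float32, shape=[1, 4], name="x")
    tf.matmul(x, tf.ones([4, 2]), name="y")

sizes = {}
for op in graph.get_operations():
    outputs = op.values()
    if outputs and outputs[0].shape.ndims is not None:
        n_elems = 1
        for dim in outputs[0].shape.as_list():
            if dim is not None:
                n_elems *= dim
        sizes[op.name] = n_elems * outputs[0].dtype.size
    else:
        sizes[op.name] = -1  # no outputs or unknown rank

with open('tensors_sz.json', 'w') as f:
    json.dump(sizes, f)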
Example No. 14
0
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
    """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: input nodes to the subgraph to be fused.
    output_nodes: output nodes to the subgraph to be fused.
    output_dtypes: A list of output datatypes for the custom op
    output_quantized: A boolean flag that indicates if output is quantized
    op_name: fused op name.
    op_type: fused op type.
  Returns:
    The GraphDef of the new graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
  """

    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

    if isinstance(input_nodes, six.string_types):
        raise TypeError("input_nodes must be a list.")

    if isinstance(output_nodes, six.string_types):
        raise TypeError("output_nodes must be a list.")

    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

    # Nodes up to and including input_nodes
    reachable_by_input = _bfs_for_reachable_nodes(input_nodes,
                                                  name_to_input_name)
    # Nodes up to and including output_nodes
    reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                   name_to_input_name)

    # Set of nodes in the list input_nodes
    input_nodes_set = set(input_nodes)

    # Set of nodes in the list output_nodes
    output_nodes_set = set(output_nodes)

    nodes_post_output = []
    for node in graph_def.node:
        n = _node_name(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is between input and output, i.e., part of the fused op
                next_to_visit = [n]
                while next_to_visit:
                    cur_node = next_to_visit[0]
                    del next_to_visit[0]
                    if cur_node in reachable_by_input and cur_node not in input_nodes_set:
                        raise TypeError(
                            "Node %s uses input %s not in input_nodes." %
                            (n, cur_node))
                    if cur_node not in input_nodes_set:
                        next_to_visit += name_to_input_name[cur_node]
        elif n not in reachable_by_input:
            nodes_post_output.append(n)

    # Add all nodes up to the input nodes
    out = graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(list(reachable_by_input),
                                       key=lambda n: name_to_seq_num[n])
    for node in reachable_by_input_sorted:
        out.node.extend([copy.deepcopy(name_to_node[node])])

    # Add the custom op
    new_node = node_def_pb2.NodeDef()
    for node in input_nodes:
        new_node.input.append(node)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    new_node.attr["_output_quantized"].b = output_quantized
    new_node.op = op_type
    new_node.name = op_name
    out.node.extend([new_node])

    # Add the nodes in the output of the custom op
    for index, n in enumerate(output_nodes):
        assert len(name_to_node[n].input) == 1
        new_node = copy.deepcopy(name_to_node[n])
        del new_node.input[:]
        new_node.input.append(op_name +
                              (":" + str(index) if index != 0 else ""))
        out.node.extend([new_node])

    # Add the nodes post output_nodes
    for n in nodes_post_output:
        out.node.extend([copy.deepcopy(name_to_node[n])])

    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    return out
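
# Hedged usage sketch for fuse_op: collapse the ops strictly between
# "inp" and "out" into one custom op. Assumes TF 1.x and that the helper
# functions used by fuse_op are in scope; "MyCustomOp" is a hypothetical
# op type registered elsewhere.
import tensorflow as tf
from tensorflow.core.framework import types_pb2

with tf.Graph().as_default() as g:
    inp = tf.placeholder(tf.float32, [1], name="inp")
    mid = tf.nn.relu(inp, name="mid")
    tf.identity(mid, name="out")

fused = fuse_op(g.as_graph_def(),
                input_nodes=["inp"],
                output_nodes=["out"],
                output_dtypes=[types_pb2.DT_FLOAT],
                output_quantized=False,
                op_name="fused_block",
                op_type="MyCustomOp")
print([n.op for n in fused.node])  # ['Placeholder', 'MyCustomOp', 'Identity']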
Example No. 15
0
 def _StripGraph(self, gd):
   """Copy gd keeping only, node.name, node.op, node.input, and node.device."""
   return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])
Example No. 16
0
    def _testFreezeGraph(self, saver_write_version):

        checkpoint_prefix = os.path.join(self.get_temp_dir(),
                                         "saved_checkpoint")
        checkpoint_meta_graph_file = os.path.join(self.get_temp_dir(),
                                                  "saved_checkpoint.meta")
        checkpoint_state_name = "checkpoint_state"
        input_graph_name = "input_graph.pb"
        output_graph_name = "output_graph.pb"

        # We'll create an input graph that has a single variable containing 1.0,
        # and that then multiplies it by 2.
        with ops.Graph().as_default():
            variable_node = variables.VariableV1(1.0, name="variable_node")
            output_node = math_ops.multiply(variable_node,
                                            2.0,
                                            name="output_node")
            sess = session.Session()
            init = variables.global_variables_initializer()
            sess.run(init)
            output = sess.run(output_node)
            self.assertNear(2.0, output, 0.00001)
            saver = saver_lib.Saver(write_version=saver_write_version)
            checkpoint_path = saver.save(sess,
                                         checkpoint_prefix,
                                         global_step=0,
                                         latest_filename=checkpoint_state_name)
            graph_io.write_graph(sess.graph, self.get_temp_dir(),
                                 input_graph_name)

        # We save out the graph to disk, and then call the const conversion
        # routine.
        input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
        input_saver_def_path = ""
        input_binary = False
        output_node_names = "output_node"
        restore_op_name = "save/restore_all"
        filename_tensor_name = "save/Const:0"
        output_graph_path = os.path.join(self.get_temp_dir(),
                                         output_graph_name)
        clear_devices = False
        input_meta_graph = checkpoint_meta_graph_file

        freeze_graph.freeze_graph(input_graph_path,
                                  input_saver_def_path,
                                  input_binary,
                                  checkpoint_path,
                                  output_node_names,
                                  restore_op_name,
                                  filename_tensor_name,
                                  output_graph_path,
                                  clear_devices,
                                  "",
                                  "",
                                  input_meta_graph,
                                  checkpoint_version=saver_write_version)

        # Now we make sure the variable is now a constant, and that the graph still
        # produces the expected result.
        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                _ = importer.import_graph_def(output_graph_def, name="")

            self.assertEqual(4, len(output_graph_def.node))
            for node in output_graph_def.node:
                self.assertNotEqual("VariableV2", node.op)
                self.assertNotEqual("Variable", node.op)

            with session.Session() as sess:
                output_node = sess.graph.get_tensor_by_name("output_node:0")
                output = sess.run(output_node)
                self.assertNear(2.0, output, 0.00001)
Example No. 17
0
    def generate_output_graph(self, input_graph_def, input_node_map,
                              fuse_op_list):
        output_graph_def = graph_pb2.GraphDef()
        skip_list = []
        skip_node_name = []
        float32_type = dtypes.float32.as_datatype_enum
        for index, node in enumerate(input_graph_def.node):
            if index in fuse_op_list:
                input_node = input_node_map[node.input[0]]
                if input_node.op == 'QuantizeV2':
                    new_node = node_def_pb2.NodeDef()

                    new_node.op = node.op + "AndDequantize"
                    for _, value in enumerate(node.input):
                        new_node.input.append(value)

                    # Relies on the quantizer's fixed emission order: the
                    # frozen min/max nodes sit at index + 1 / index + 2 and
                    # the Dequantize node at index + 4.
                    dequantize_node = input_graph_def.node[index + 4]
                    frozen_max_node = input_graph_def.node[index + 2]
                    frozen_min_node = input_graph_def.node[index + 1]

                    new_node.name = dequantize_node.name

                    new_node.input.append(frozen_min_node.name)
                    new_node.input.append(frozen_max_node.name)

                    new_node.attr["T1"].CopyFrom(node.attr['T1'])
                    new_node.attr["T2"].CopyFrom(node.attr['T2'])

                    new_node.attr["Tbias"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))
                    new_node.attr["Toutput"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))

                    skip_list.append(index + 1)
                    skip_list.append(index + 2)
                    skip_list.append(index + 3)
                    skip_list.append(index + 4)
                    output_graph_def.node.extend(
                        [new_node, frozen_max_node, frozen_min_node])
                elif input_node.op == "Requantize":
                    new_node = node_def_pb2.NodeDef()
                    new_node.op = node.op + "AndDequantize"
                    for _, value in enumerate(node.input):
                        new_node.input.append(value)

                    dequantize_node = input_graph_def.node[index + 4]
                    frozen_max_node = input_graph_def.node[index + 2]
                    frozen_min_node = input_graph_def.node[index + 1]
                    new_node.name = dequantize_node.name
                    skip_list.append(index + 1)
                    skip_list.append(index + 2)
                    skip_list.append(index + 3)
                    skip_list.append(index + 4)
                    new_node.input.append(frozen_min_node.name)
                    new_node.input.append(frozen_max_node.name)

                    new_node.attr["T1"].CopyFrom(node.attr['T1'])
                    new_node.attr["T2"].CopyFrom(node.attr['T2'])

                    new_node.attr["Tbias"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))
                    new_node.attr["Toutput"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))

                    output_graph_def.node.extend(
                        [new_node, frozen_max_node, frozen_min_node])
                else:
                    new_node = node_def_pb2.NodeDef()
                    new_node.CopyFrom(node)
                    output_graph_def.node.extend([new_node])

            elif index in skip_list or node.name in skip_node_name:
                continue
            else:
                new_node = node_def_pb2.NodeDef()
                new_node.CopyFrom(node)
                output_graph_def.node.extend([new_node])
        return output_graph_def
Example No. 18
0
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants."""

    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes)

        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
Example No. 19
0
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
    """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
    strip_default_attrs: Set to true if default-valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which is in the same directory as `filename` and has `_debug` added
      before the file extension.
    **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
    if context.executing_eagerly() and (graph_def is None or graph is None):
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "Eager Execution is enabled.")
    graph = graph or ops.get_default_graph()

    exclude_nodes = None
    unbound_inputs = []
    if export_scope or clear_extraneous_savers or clear_devices:
        if graph_def:
            new_graph_def = graph_pb2.GraphDef()
            new_graph_def.versions.CopyFrom(graph_def.versions)
            new_graph_def.library.CopyFrom(graph_def.library)

            if clear_extraneous_savers:
                exclude_nodes = _find_extraneous_saver_nodes(
                    graph_def, saver_def)

            for node_def in graph_def.node:
                if _should_include_node(node_def.name, export_scope,
                                        exclude_nodes):
                    new_node_def = _node_def(node_def,
                                             export_scope,
                                             unbound_inputs,
                                             clear_devices=clear_devices)
                    new_graph_def.node.extend([new_node_def])
            graph_def = new_graph_def
        else:
            # Only do this complicated work if we want to remove a name scope.
            graph_def = graph_pb2.GraphDef()
            # pylint: disable=protected-access
            graph_def.versions.CopyFrom(graph.graph_def_versions)
            bytesize = 0

            if clear_extraneous_savers:
                exclude_nodes = _find_extraneous_saver_nodes(
                    graph.as_graph_def(), saver_def)

            for key in sorted(graph._nodes_by_id):
                if _should_include_node(graph._nodes_by_id[key].name,
                                        export_scope, exclude_nodes):
                    value = graph._nodes_by_id[key]
                    # pylint: enable=protected-access
                    node_def = _node_def(value.node_def,
                                         export_scope,
                                         unbound_inputs,
                                         clear_devices=clear_devices)
                    graph_def.node.extend([node_def])
                    if value.outputs:
                        assert "_output_shapes" not in graph_def.node[-1].attr
                        graph_def.node[-1].attr[
                            "_output_shapes"].list.shape.extend([
                                output.get_shape().as_proto()
                                for output in value.outputs
                            ])
                    bytesize += value.node_def.ByteSize()
                    if bytesize >= (1 << 31) or bytesize < 0:
                        raise ValueError("GraphDef cannot be larger than 2GB.")

            graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

        # It's possible that not all the inputs are in the export_scope.
        # If we would like such information included in the exported meta_graph,
        # add them to a special unbound_inputs collection.
        if unbound_inputs_col_name:
            # Clears the unbound_inputs collections.
            graph.clear_collection(unbound_inputs_col_name)
            for k in unbound_inputs:
                graph.add_to_collection(unbound_inputs_col_name, k)

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=export_scope)
    for v in variables:
        if _should_include_node(v, export_scope, exclude_nodes):
            var_list[ops.strip_name_scope(v.name, export_scope)] = v

    scoped_meta_graph_def = create_meta_graph_def(
        graph_def=graph_def,
        graph=graph,
        export_scope=export_scope,
        exclude_nodes=exclude_nodes,
        clear_extraneous_savers=clear_extraneous_savers,
        saver_def=saver_def,
        strip_default_attrs=strip_default_attrs,
        **kwargs)

    if filename:
        graph_io.write_graph(scoped_meta_graph_def,
                             os.path.dirname(filename),
                             os.path.basename(filename),
                             as_text=as_text)
        if save_debug_info:
            name, _ = os.path.splitext(filename)
            debug_filename = "{name}{ext}".format(name=name, ext=".debug")

            # Gets the operations from the graph by name. Excludes variable
            # nodes, so only the nodes in the frozen model are included.
            # TODO(liufengdb): fix this for functions.
            ops_to_export = []
            for node in scoped_meta_graph_def.graph_def.node:
                scoped_op_name = ops.prepend_name_scope(
                    node.name, export_scope)
                ops_to_export.append(
                    ("", graph.get_operation_by_name(scoped_op_name)))

            graph_debug_info = error_interpolation.create_graph_debug_info_def(
                ops_to_export)

            graph_io.write_graph(graph_debug_info,
                                 os.path.dirname(debug_filename),
                                 os.path.basename(debug_filename),
                                 as_text=as_text)

    return scoped_meta_graph_def, var_list
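
# Minimal sketch of exporting a name scope with the function above
# (TF 1.x internals assumed; the output path is a placeholder).
import tensorflow as tf

with tf.Graph().as_default():
    v = tf.Variable(1.0, name="scope/v")
    meta_graph_def, var_list = export_scoped_meta_graph(
        filename="/tmp/scoped.meta",
        export_scope="scope",
        clear_devices=True)
    print(sorted(var_list))  # variable names with "scope/" stripped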
Example No. 20
0
import copy

import uff
import tensorflow as tf
from tensorflow.core.framework import graph_pb2, node_def_pb2
from tensorflow.python.platform import gfile

graph_filename = 'mrt_graph_1.pb'
graph_filename_converted = 'mrt_graph_2.pb'

# load the frozen graph from the pb file
graph_def = tf.GraphDef()
with gfile.FastGFile(graph_filename, 'rb') as f:
    graph_def.ParseFromString(f.read())

# define a new, empty graph for the modified copy
modified_graph_def = graph_pb2.GraphDef()

# pre-define an empty image placeholder node
image_placeholder_node = node_def_pb2.NodeDef()

# inspect the node whose dtype attribute is to be changed to int32
for node in graph_def.node:
    if node.name == 'vars/Cast':
        print(node)
Example No. 21
0
    def testDeployCheckpoint(self):
        input_meta_name = "original_meta.meta"
        input_meta_path = os.path.join(self.get_temp_dir(), input_meta_name)
        q_config, _ = self._compose_config()
        with ops.Graph().as_default():
            self._build_graph(is_freezed=False)
            graph_def = ops.get_default_graph().as_graph_def()
            saver_lib.export_meta_graph(filename=input_meta_path)

        original_meta_graph_def = MetaGraphDef()
        original_meta_graph_def = self._parse_def_from_file(
            original_meta_graph_def, input_meta_path)
        decent_q.quantize_train(original_meta_graph_def, q_config)

        quant_train_meta_graph_def = MetaGraphDef()
        quant_train_meta_graph_path = os.path.join(
            self.get_temp_dir(), "quantize_train/quantize_train.ckpt.meta")
        quant_train_meta_graph_def = self._parse_def_from_file(
            quant_train_meta_graph_def, quant_train_meta_graph_path)
        with ops.Graph().as_default():
            new_saver = saver_lib.import_meta_graph(quant_train_meta_graph_def)
            with session.Session() as sess:
                w_t = sess.graph.get_tensor_by_name("w/read/wquant:0")
                b_t = sess.graph.get_tensor_by_name("b/read/wquant:0")
                relu_t = sess.graph.get_tensor_by_name("relu/aquant:0")
                input_fn = self._mock_input_fn("input:0", [1, 4, 4, 3])
                init = variables.global_variables_initializer()
                sess.run(init)
                eval_relu, eval_w, eval_b = sess.run([relu_t, w_t, b_t],
                                                     feed_dict=input_fn(1))

                checkpoint_prefix = os.path.join(self.get_temp_dir(),
                                                 "ckpt/saved_checkpoint")
                checkpoint_state_name = "checkpoint_state"
                checkpoint_path = new_saver.save(
                    sess,
                    checkpoint_prefix,
                    global_step=0,
                    latest_filename=checkpoint_state_name)
        q_config.output_nodes = ["relu/aquant"]
        decent_q.quantize_evaluate(quant_train_meta_graph_def, q_config)
        quant_eval_meta_graph_def = MetaGraphDef()
        quant_eval_meta_graph_path = os.path.join(
            self.get_temp_dir(), "quantize_eval/quantize_eval.ckpt.meta")
        quant_eval_meta_graph_def = self._parse_def_from_file(
            quant_eval_meta_graph_def, quant_eval_meta_graph_path)
        sess.close()
        decent_q.deploy_checkpoint(quant_eval_meta_graph_def, checkpoint_path,
                                   q_config)
        deploy_graph_def = graph_pb2.GraphDef()
        deploy_graph_path = os.path.join(self.get_temp_dir(),
                                         "deploy/deploy_model.pb")
        deploy_graph_def = self._parse_def_from_file(deploy_graph_def,
                                                     deploy_graph_path)
        for node in deploy_graph_def.node:
            if node.name == "conv2d":
                # need to equal with quantize pos in quantize_eval_model.pb
                self.assertAllEqual(node.attr['ipos'].list.i, [8, 6])
                self.assertAllEqual(node.attr['wpos'].list.i, [8, 7])
                self.assertAllEqual(node.attr['bpos'].list.i, [8, 8])
                self.assertAllEqual(node.attr['opos'].list.i, [8, 4])
                deploy_w = tensor_util.MakeNdarray(node.attr['weights'].tensor)
                deploy_b = tensor_util.MakeNdarray(node.attr['bias'].tensor)
                self.assertNDArrayNear(deploy_w, eval_w, 1e-6)
                self.assertNDArrayNear(deploy_b, eval_b, 1e-6)
Example No. 22
0
def convert_variables_to_constants_v2(func, lower_control_flow=True):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops holding
  the same values. This makes it possible to describe the network fully with a
  single GraphDef file, and allows the removal of a lot of ops related to
  loading and saving the variables. This function runs Grappler's function
  inlining optimization in order to return a single subgraph.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)

  Returns:
    ConcreteFunction containing a simplified version of the original.
  """
  # TODO(nupurgarg): Replace ResourceGather with Gather.
  # Inline the graph in order to remove functions when possible.
  graph_def = _run_inline_graph_optimization(func, lower_control_flow)

  # Gets a list of all node defs, including those in the library.
  node_defs = _get_node_defs_list(graph_def)

  # Get mapping from node name to node.
  name_to_node = {_get_tensor_name(node.name): node for node in node_defs}

  # Get mapping from node name to variable value.
  tensor_data = _get_tensor_data(func)

  # Get mapping from function name to argument types.
  function_types = _get_control_flow_function_types(node_defs, tensor_data)

  # Get variable data for all nodes in `node_defs`.
  reference_variables = {}
  resource_identities = {}
  placeholders = {}
  converted_input_indices = set()

  def _save_placeholder(node_name, dtype):
    placeholders[node_name] = {
        "dtype": dtype,
        "data": tensor_data[node_name]["data"],
    }
    converted_input_indices.add(tensor_data[node_name]["index"])

  for node in node_defs:
    if node.op == "If":
      # Get dtype and data for resource Placeholders.
      then_func = node.attr["then_branch"].func.name
      arg_types = function_types[then_func]
      for idx, input_tensor in enumerate(node.input[1:]):
        input_name = _get_tensor_name(input_tensor)
        if input_name in tensor_data:
          dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
          _save_placeholder(_get_tensor_name(input_tensor), dtype)
    if node.op == "VariableV2":
      # Get data for VariableV2 ops (reference variables) that cannot be lifted.
      with func.graph.as_default():
        identity_node = array_ops.identity(
            func.graph.as_graph_element(node.name + ":0"))
      reference_variables[node.name] = (
          func.prune([], [identity_node.name])()[0])
    elif node.name in tensor_data and not tensor_data[node.name]["is_variable"]:
      # Get dtype and data for non-variable Placeholders (ex. values for 1.X
      # Const ops that are loaded as Placeholders in 2.0)
      _save_placeholder(node.name, node.attr["dtype"])
    elif node.op == "ReadVariableOp":
      # Get dtype and data for Placeholder ops associated with ReadVariableOp.
      # There can be an Identity in between the ReadVariableOp and Placeholder.
      # Store the dtype for the Identity ops.
      input_name = _get_tensor_name(node.input[0])
      while name_to_node[input_name].op == "Identity":
        resource_identities[input_name] = node.attr["dtype"]
        input_name = _get_tensor_name(name_to_node[input_name].input[0])
      if name_to_node[input_name].op != "Placeholder":
        raise ValueError("Cannot find the Placeholder op that is an input "
                         "to the ReadVariableOp.")
      _save_placeholder(input_name, node.attr["dtype"])

  # Reconstruct the graph with constants in place of variables.
  output_graph_def = graph_pb2.GraphDef()

  for input_node in graph_def.node:
    output_node = output_graph_def.node.add()
    # Convert VariableV2 ops to Const ops.
    if input_node.name in reference_variables:
      data = reference_variables[input_node.name]
      dtype = attr_value_pb2.AttrValue(type=data.dtype.as_datatype_enum)
      _populate_const_op(output_node, input_node.name, dtype, data.numpy(),
                         data.shape)
    # Convert Placeholder ops to Const ops.
    elif input_node.name in placeholders:
      data = placeholders[input_node.name]["data"]
      dtype = placeholders[input_node.name]["dtype"]
      _populate_const_op(output_node, input_node.name, dtype, data, data.shape)
    # Update the dtype for Identity ops that are inputs to ReadVariableOps.
    elif input_node.name in resource_identities:
      output_node.CopyFrom(input_node)
      output_node.attr["T"].CopyFrom(resource_identities[input_node.name])
    # Convert ReadVariableOps to Identity ops.
    elif input_node.op == "ReadVariableOp":
      _populate_identity_op(output_node, input_node)
    # Update the function names and function's arguments types for the If ops.
    elif input_node.op == "If":
      _populate_if_op(output_node, input_node, function_types)
    else:
      output_node.CopyFrom(input_node)

  # Add functions to reconstructed graph.
  if graph_def.library:
    library = output_graph_def.library

    for input_library_func in graph_def.library.function:
      orig_func_name = input_library_func.signature.name
      new_func_name = _get_new_function_name(orig_func_name)

      # Do not copy any functions that aren't being used in the graph. Any
      # functions that are not used by control flow should have been inlined.
      if orig_func_name not in function_types:
        continue

      output_library_func = library.function.add()
      for key, value in input_library_func.ret.items():
        output_library_func.ret[key] = value
      for key, value in input_library_func.control_ret.items():
        output_library_func.control_ret[key] = value

      # Update the input types in the function signature.
      output_library_func.signature.CopyFrom(input_library_func.signature)
      output_library_func.signature.name = new_func_name
      for dtype, arg in zip(function_types[orig_func_name],
                            output_library_func.signature.input_arg):
        arg.type = dtype

      # Update the NodeDefs.
      func_variables = {
          node.name: node.input[0]
          for node in input_library_func.node_def
          if node.op == "ReadVariableOp"
      }

      for input_node in input_library_func.node_def:
        output_node = output_library_func.node_def.add()
        # Convert ReadVariableOps to Identity ops.
        if input_node.op == "ReadVariableOp":
          _populate_identity_op(output_node, input_node)
        elif input_node.op == "If":
          _populate_if_op(output_node, input_node, function_types)
        else:
          output_node.CopyFrom(input_node)
          # Convert :value to :output for ops that use the ReadVariableOp.
          for idx, full_name in enumerate(input_node.input):
            input_name = _get_tensor_name(full_name)
            if input_name in func_variables:
              full_name_parts = full_name.split(":")
              full_name_parts[1] = "output"
              input_name = ":".join(full_name_parts)
              output_node.input[idx] = input_name

  output_graph_def.versions.CopyFrom(graph_def.versions)
  return _construct_concrete_function(func, output_graph_def,
                                      converted_input_indices)
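
# TF 2.x usage sketch for the converter above: freeze a captured
# variable into a constant inside a ConcreteFunction.
import tensorflow as tf

v = tf.Variable(3.0)

@tf.function
def scale(x):
  return v * x

concrete = scale.get_concrete_function(tf.TensorSpec([], tf.float32))
frozen = convert_variables_to_constants_v2(concrete)
print(frozen(tf.constant(2.0)))  # 6.0, with v baked in as a Const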
Example No. 23
0
def strip_unused(input_graph_def, input_tensor_names, output_tensor_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_tensor_names: A list of the nodes we use as inputs.
    output_tensor_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    A `GraphDef` with all unnecessary ops removed, and a map from the old
    input names to the new input names.

  Raises:
    ValueError: If any element in `input_tensor_names` refers to an operation
      instead of a tensor.
    KeyError: If any element in `input_tensor_names` is not found in the graph.
  """
    for name in input_tensor_names:
        if ":" not in name:
            raise ValueError("Input '%s' appears to refer to a Operation, "
                             "not a Tensor." % name)

    old2new = {}

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    not_found = set(input_tensor_names)
    input_node_names = {name.split(":")[0] for name in input_tensor_names}
    output_node_names = list(
        {name.split(":")[0]
         for name in output_tensor_names})
    inputs_replaced_graph_def = graph_pb2.GraphDef()
    for node in input_graph_def.node:
        if node.name not in input_node_names:
            for i in range(len(node.input)):
                if _append_port(node.input[i]) in input_tensor_names:
                    old_name = _append_port(node.input[i])
                    not_found.remove(old_name)
                    new_input_name = node.input[i].replace(":", "_")
                    placeholder_node = node_def_pb2.NodeDef()
                    placeholder_node.op = "Placeholder"
                    placeholder_node.name = new_input_name
                    if isinstance(placeholder_type_enum, list):
                        input_node_index = input_tensor_names.index(old_name)
                        placeholder_node.attr["dtype"].CopyFrom(
                            attr_value_pb2.AttrValue(
                                type=placeholder_type_enum[input_node_index]))
                    else:
                        placeholder_node.attr["dtype"].CopyFrom(
                            attr_value_pb2.AttrValue(
                                type=placeholder_type_enum))
                    if "_output_shapes" in node.attr:
                        placeholder_node.attr["_output_shapes"].CopyFrom(
                            node.attr["_output_shapes"])
                    node.input[i] = new_input_name
                    old2new[old_name] = new_input_name + ":0"
                    inputs_replaced_graph_def.node.extend([placeholder_node])
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

    if not_found:
        raise KeyError("The following input nodes were not found: %s\n" %
                       not_found)

    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def, old2new
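
# Sketch of strip_unused on a toy graph, assuming the module helpers it
# relies on (_append_port, graph_util) are in scope: "c" is rewired to a
# Placeholder, so anything that only fed "c" would be pruned away.
import tensorflow as tf
from tensorflow.core.framework import types_pb2

with tf.Graph().as_default() as g:
    c = tf.constant(2.0, name="c")
    d = tf.constant(3.0, name="d")
    tf.multiply(c, d, name="mul")

stripped, old2new = strip_unused(g.as_graph_def(),
                                 input_tensor_names=["c:0"],
                                 output_tensor_names=["mul:0"],
                                 placeholder_type_enum=types_pb2.DT_FLOAT)
print(old2new)  # {'c:0': 'c:0'}; "c" is now a Placeholder node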
Example No. 24
0
def load_frozenmodel(model_path):
    print('> Loading frozen model into memory')
    num_classes = 90
    # load a frozen Model and split it into GPU and CPU graphs
    # Hardcoded for ssd_mobilenet
    input_graph = tf.Graph()
    with tf.Session(graph=input_graph):
        shape = 1917
        score = tf.placeholder(tf.float32,
                               shape=(None, shape, num_classes),
                               name="Postprocessor/convert_scores")
        expand = tf.placeholder(tf.float32,
                                shape=(None, shape, 1, 4),
                                name="Postprocessor/ExpandDims_1")
        for node in input_graph.as_graph_def().node:
            if node.name == "Postprocessor/convert_scores":
                score_def = node
            if node.name == "Postprocessor/ExpandDims_1":
                expand_def = node

    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        # model_path = 'frozen_inference_graph.pb'
        with tf.gfile.GFile(model_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            dest_nodes = [
                'Postprocessor/convert_scores', 'Postprocessor/ExpandDims_1'
            ]

            edges = {}
            name_to_node_map = {}
            node_seq = {}
            seq = 0
            for node in od_graph_def.node:
                n = _node_name(node.name)
                name_to_node_map[n] = node
                edges[n] = [_node_name(x) for x in node.input]
                node_seq[n] = seq
                seq += 1
            for d in dest_nodes:
                assert d in name_to_node_map, "%s is not in graph" % d

            nodes_to_keep = set()
            next_to_visit = dest_nodes[:]

            while next_to_visit:
                n = next_to_visit[0]
                del next_to_visit[0]
                if n in nodes_to_keep: continue
                nodes_to_keep.add(n)
                next_to_visit += edges[n]

            nodes_to_keep_list = sorted(list(nodes_to_keep),
                                        key=lambda n: node_seq[n])
            nodes_to_remove = set()

            for n in node_seq:
                if n in nodes_to_keep: continue
                nodes_to_remove.add(n)
            nodes_to_remove_list = sorted(list(nodes_to_remove),
                                          key=lambda n: node_seq[n])

            keep = graph_pb2.GraphDef()
            for n in nodes_to_keep_list:
                keep.node.extend([copy.deepcopy(name_to_node_map[n])])

            remove = graph_pb2.GraphDef()
            remove.node.extend([score_def])
            remove.node.extend([expand_def])
            for n in nodes_to_remove_list:
                remove.node.extend([copy.deepcopy(name_to_node_map[n])])

            with tf.device('/gpu:0'):
                tf.import_graph_def(keep, name='')
            with tf.device('/cpu:0'):
                tf.import_graph_def(remove, name='')

    return detection_graph, score, expand
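
# Hypothetical usage: the .pb path is a placeholder for a frozen
# SSD-Mobilenet export. The returned graph holds both halves; the score
# and expand placeholders bridge the GPU half into the CPU postprocessor.
import tensorflow as tf

detection_graph, score, expand = load_frozenmodel("frozen_inference_graph.pb")
with tf.Session(graph=detection_graph) as sess:
    gpu_score = detection_graph.get_tensor_by_name(
        "Postprocessor/convert_scores:0")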
Example No. 25
0
 def converted_self(self):
     if self._converted_self is None:
         copied_graph = graph_pb2.GraphDef()
         copied_graph.CopyFrom(self._graph_def)
         self._converted_self = _GraphDef(copied_graph)
     return self._converted_self
Example No. 26
0
def create_subgraph(tf_graph, node_list, sess, dst_scope=None):
    """
    Create a tf subgraph from the node list.
    :param tf_graph:
    :param node_list:
    :param sess:
    :param dst_scope:
    :return:
    """
    variable_dict_names = []
    variable_names = []
    tensor_op_names = []
    for n_ in node_list:  # type: tf.Operation
        tensor_op_names.extend([ts_.op.name for ts_ in n_.inputs])
        if n_.type in ["Variable", "VariableV2", "VarHandleOp"]:
            variable_name = n_.name
            variable_dict_names.append(variable_name)

            if n_.type == "VarHandleOp":
                variable_names.append(variable_name + "/Read/ReadVariableOp:0")
            else:
                variable_names.append(variable_name + ":0")
    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    found_variables = dict(zip(variable_dict_names, returned_variables))
    all_op_names = set([n_.name for n_ in node_list])
    missing_ops = set(tensor_op_names) - all_op_names

    replacement = {}
    tf_graph_def = tf_graph.as_graph_def()
    subgraph_def = _extract_sub_graph(tf_graph_def, [n_.name for n_ in node_list], missing_ops)

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in subgraph_def.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(
                    tensor=tensor_util.make_tensor_proto(
                        data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        elif input_node.op == "ReadVariableOp" and (
                input_node.input[0] in found_variables):
            # The preceding branch converts all VarHandleOps of ResourceVariables to
            # constants, so we need to convert the associated ReadVariableOps to
            # Identity ops.
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        elif input_node.name not in missing_ops:
            output_node.CopyFrom(input_node)
        else:
            output_node = None
        if output_node is not None:
            output_graph_def.node.extend([output_node])

    for input_node in tf_graph_def.node:
        if input_node.name in missing_ops:
            output_node = node_def_pb2.NodeDef()
            output_node.op = "Placeholder"
            output_node.name = input_node.name
            replacement[input_node.name] = input_node.name
            if str(input_node.attr["dtype"]):
                output_node.attr["dtype"].CopyFrom(input_node.attr["dtype"])
            elif str(input_node.attr["T"]):
                output_node.attr["dtype"].CopyFrom(input_node.attr["T"])
            else:
                if input_node.op == 'All':
                    # AttrValue.type is an enum field; pass the numeric
                    # DataType value rather than the string name.
                    output_node.attr["dtype"].CopyFrom(
                        attr_value_pb2.AttrValue(type=tf.bool.as_datatype_enum))
                elif input_node.op == 'Cast':
                    output_node.attr["dtype"].CopyFrom(input_node.attr["DstT"])
                else:
                    raise RuntimeError("Can't get the node data type for %s" % input_node.name)
            ts_shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, input_node.name)
            output_node.attr["shape"].CopyFrom(
                attr_value_pb2.AttrValue(shape=ts_shape.as_proto()))
            output_graph_def.node.extend([output_node])

    output_graph_def.library.CopyFrom(subgraph_def.library)
    with tf.Graph().as_default() as sub_graph:
        im_scope = "" if dst_scope is None else dst_scope
        tf.import_graph_def(output_graph_def, name=im_scope)
        if im_scope:
            replacement = {k_: im_scope + '/' + k_ for k_ in replacement}

    return sub_graph, replacement
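
# Hedged sketch of create_subgraph, assuming it and its private helpers
# (e.g. _extract_sub_graph) are importable: carve out just the "y" op;
# every op it needs but which is not in node_list comes back as a
# Placeholder, reported in `replacement`.
import tensorflow as tf

with tf.Graph().as_default() as g, tf.Session() as sess:
    x = tf.placeholder(tf.float32, [2], name="x")
    w = tf.Variable([1.0, 2.0], name="w")
    y = tf.multiply(x, w, name="y")
    sess.run(tf.global_variables_initializer())

    sub_graph, replacement = create_subgraph(g, [y.op], sess)
    print(replacement)  # e.g. {'x': 'x', 'w/read': 'w/read'}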
Example No. 27
0
 def __init__(self, model_dir, outputs=None):
     from tensorflow.core.framework import graph_pb2
     self._tmp_dir = util.tempdir()
     self._model_dir = model_dir
     self._graph = graph_pb2.GraphDef()
     self._outputs = outputs or []
Example No. 28
0
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TocoConverter class from a file containing a frozen GraphDef.

    Args:
      graph_def_file: Full filepath of file containing frozen GraphDef.
      input_arrays: List of input tensors to freeze graph with.
      output_arrays: List of output tensors to freeze graph with.
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)

    Returns:
      TocoConverter class.

    Raises:
      ValueError:
        Unable to parse input file.
        The graph is not frozen.
        input_arrays or output_arrays contains an invalid tensor name.
    """
    with _session.Session() as sess:
      sess.run(_global_variables_initializer())

      # Read GraphDef from file.
      graph_def = _graph_pb2.GraphDef()
      with open(graph_def_file, "rb") as f:
        file_content = f.read()
      try:
        graph_def.ParseFromString(file_content)
      except (_text_format.ParseError, DecodeError):
        try:
          print("Ignore 'tcmalloc: large alloc' warnings.")

          if not isinstance(file_content, str):
            if PY3:
              file_content = file_content.decode('utf-8')
            else:
              file_content = file_content.encode('utf-8')
          _text_format.Merge(file_content, graph_def)
        except (_text_format.ParseError, DecodeError):
          raise ValueError(
              "Unable to parse input file '{}'.".format(graph_def_file))
      sess.graph.as_default()
      _import_graph_def(graph_def, name="")

      # Get input and output tensors.
      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
      _set_tensor_shapes(input_tensors, input_shapes)

      # Check if graph is frozen.
      if not _is_frozen_graph(sess):
        raise ValueError("Please freeze the graph using freeze_graph.py.")

      # Create TocoConverter class.
      return cls(sess.graph_def, input_tensors, output_tensors)
Example No. 29
0
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TocoConverter class from a file containing a frozen GraphDef.

    Args:
      graph_def_file: Full filepath of file containing frozen GraphDef.
      input_arrays: List of input tensors to freeze graph with.
      output_arrays: List of output tensors to freeze graph with.
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)

    Returns:
      TocoConverter class.

    Raises:
      ValueError:
        Unable to parse input file.
        The graph is not frozen.
        input_arrays or output_arrays contains an invalid tensor name.
        input_shapes is not correctly defined when required
    """
    with _ops.Graph().as_default():
      with _session.Session() as sess:
        # Read GraphDef from file.
        graph_def = _graph_pb2.GraphDef()
        with open(graph_def_file, "rb") as f:
          file_content = f.read()
        try:
          graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
          try:
            print("Ignore 'tcmalloc: large alloc' warnings.")

            if not isinstance(file_content, str):
              if PY3:
                file_content = file_content.decode("utf-8")
              else:
                file_content = file_content.encode("utf-8")
            _text_format.Merge(file_content, graph_def)
          except (_text_format.ParseError, DecodeError):
            raise ValueError(
                "Unable to parse input file '{}'.".format(graph_def_file))

        # Handles models with custom TFLite ops that cannot be resolved in
        # TensorFlow.
        load_model_in_session = True
        try:
          _import_graph_def(graph_def, name="")
        except _NotFoundError:
          load_model_in_session = False

        if load_model_in_session:
          # Check if graph is frozen.
          if not _is_frozen_graph(sess):
            raise ValueError("Please freeze the graph using freeze_graph.py.")

          # Get input and output tensors.
          input_tensors = _get_tensors_from_tensor_names(
              sess.graph, input_arrays)
          output_tensors = _get_tensors_from_tensor_names(
              sess.graph, output_arrays)
          _set_tensor_shapes(input_tensors, input_shapes)

          return cls(sess.graph_def, input_tensors, output_tensors)
        else:
          if not input_shapes:
            raise ValueError("input_shapes must be defined for this model.")
          if set(input_arrays) != set(input_shapes.keys()):
            raise ValueError("input_shapes must contain a value for each item "
                             "in input_array.")

          input_arrays_with_shape = [
              (name, input_shapes[name]) for name in input_arrays
          ]
          return cls(
              graph_def,
              input_tensors=None,
              output_tensors=None,
              input_arrays_with_shape=input_arrays_with_shape,
              output_arrays=output_arrays)
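A minimal usage sketch of the classmethod above (the file path, tensor names, and shapes are hypothetical):

# Hypothetical model: one input "input" with shape [1, 224, 224, 3] and one
# output "softmax", stored as a frozen GraphDef at /tmp/frozen_graph.pb.
converter = TocoConverter.from_frozen_graph(
    graph_def_file="/tmp/frozen_graph.pb",
    input_arrays=["input"],
    output_arrays=["softmax"],
    input_shapes={"input": [1, 224, 224, 3]})
tflite_model = converter.convert()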
Example no. 30
def _static_range_quantize(saved_model_path: str,
                           signature_keys=None,
                           tags=None,
                           output_directory=None,
                           representative_dataset=None):
    """Quantizes the given SavedModel via static range quantization.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: List of keys identifying SignatureDef containing inputs and
      outputs.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze.
    output_directory: The path to save the output SavedModel (must be an empty
      directory).
    representative_dataset: A generator that yields either a dictionary in
      {input_name: input_tensor} format, or a tuple of a signature key and
      such a dictionary, providing calibration data for quantizing the model.
      This should be provided when the model is not a QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for a non-QAT model.
  """
    is_qat_saved_model = _is_qat_saved_model(saved_model_path)
    signatures = _get_signatures_from_saved_model(saved_model_path,
                                                  signature_keys, tags)

    # A representative dataset is required unless the model was trained with QAT.
    if representative_dataset is None and not is_qat_saved_model:
        raise ValueError(
            'When `representative_dataset` is not provided, the model should be '
            'trained with quantization-aware training (QAT).')

    if is_qat_saved_model:
        # The model is QAT-trained, so it can be quantized directly.
        graph_def_serialized = (quantize_model_wrapper.quantize_qat_model(
            saved_model_path, ','.join(signature_keys), ','.join(tags)))
    else:
        # The model needs post-training quantization (PTQ) with calibration.
        graph_def_serialized = (
            quantize_model_wrapper.quantize_ptq_model_pre_calibration(
                saved_model_path, ','.join(signature_keys), ','.join(tags)))

        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(graph_def_serialized)

        float_model_dir = tempfile.mkdtemp()
        v1_builder = builder.SavedModelBuilder(float_model_dir)

        with session.Session(graph=ops.Graph()) as sess:
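            # Tag each CustomAggregator op with a unique id so its calibration
            # statistics can be looked up by id after the calibration run.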
            for function_def in graph_def.library.function:
                for node_def in function_def.node_def:
                    if node_def.op == 'CustomAggregator':
                        node_def.attr['id'].s = uuid.uuid4().hex.encode(
                            'ascii')

            importer.import_graph_def(graph_def, name='')
            working_graph = ops.get_default_graph()
            graph_def = working_graph.as_graph_def()

            signatures = _fix_tensor_names(signatures, working_graph)
            if signatures is None:
                raise ValueError(
                    "The input SavedModel doesn't contain a valid signature")

            v1_builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING], signature_def_map=signatures)

        v1_builder.save()

        float_model = saved_model_load(float_model_dir)

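        # Run the representative dataset through the instrumented model so the
        # CustomAggregator ops observe and record activation ranges.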
        for sample in representative_dataset():
            # TODO(b/214311251): Add a test case with multiple signatures.
            if isinstance(sample, tuple):
                if not isinstance(sample[1], dict):
                    raise ValueError(
                        'The second element of the tuple must be a '
                        'dictionary mapping input names to values.')
                signature_key = sample[0]
                input_data_map = sample[1]
            elif isinstance(sample, dict):
                if len(signature_keys) > 1:
                    raise ValueError(
                        'When the model has multiple signatures, provide a '
                        'tuple of a signature key and a dictionary mapping '
                        'input names to values.')
                signature_key = signature_keys[0]
                input_data_map = sample
            else:
                raise ValueError(
                    'Each sample must be either a dictionary mapping input '
                    'names to values, or a tuple of a signature key and such '
                    'a dictionary.')
            func = float_model.signatures[signature_key]
            func(**input_data_map)

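        # Copy the recorded min/max ranges from the calibrator into the
        # corresponding CustomAggregator node attributes.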
        for function_def in graph_def.library.function:
            for node_def in function_def.node_def:
                if node_def.op == 'CustomAggregator':
                    node_id = node_def.attr['id'].s
                    try:
                        min_val = quantize_model_wrapper.get_min_from_calibrator(
                            node_id)
                        max_val = quantize_model_wrapper.get_max_from_calibrator(
                            node_id)
                        quantize_model_wrapper.clear_data_from_calibrator(
                            node_id)
                        node_def.attr['min'].f = float(min_val)
                        node_def.attr['max'].f = float(max_val)
                    except ValueError:
                        warnings.warn(
                            f'CustomAggregator id "{node_id.decode("utf-8")}" from '
                            f'FunctionDef "{function_def.signature.name}" does not have '
                            'min or max values. This function may not be quantized.'
                        )

        calibrated_model_dir = tempfile.mkdtemp()
        v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

        with session.Session(graph=ops.Graph()) as sess:
            importer.import_graph_def(graph_def, name='')
            working_graph = ops.get_default_graph()
            graph_def = working_graph.as_graph_def()

            v1_builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING], signature_def_map=signatures)

        v1_builder.save()
        signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                      signature_keys, tags)

        graph_def_serialized = (
            quantize_model_wrapper.quantize_ptq_model_post_calibration(
                calibrated_model_dir,
                ','.join(signature_keys),
                ','.join(tags),
            ))

    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(graph_def_serialized)

    if output_directory is None:
        output_directory = tempfile.mkdtemp()
    v1_builder = builder.SavedModelBuilder(output_directory)

    with session.Session(graph=ops.Graph()) as sess:
        importer.import_graph_def(graph_def, name='')
        working_graph = ops.get_default_graph()

        signatures = _fix_tensor_names(signatures, working_graph)
        if signatures is None:
            raise ValueError(
                "The input SavedModel doesn't contain a valid signature")

        v1_builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING],
                                                signature_def_map=signatures)

    v1_builder.save()

    return saved_model_load(output_directory)
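A minimal calling sketch for the function above (paths, the signature key, the tag, and the input name are hypothetical; the generator yields dictionaries in the {input_name: input_tensor} format described in the docstring):

import numpy as np

# Hypothetical calibration data: 100 random batches keyed by input name.
def representative_dataset():
    for _ in range(100):
        yield {"input_tensor": np.random.uniform(
            low=0.0, high=1.0, size=(1, 28, 28, 1)).astype(np.float32)}

quantized = _static_range_quantize(
    saved_model_path="/tmp/float_model",
    signature_keys=["serving_default"],
    tags={"serve"},
    output_directory="/tmp/quantized_model",
    representative_dataset=representative_dataset)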