Example #1
def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key):
  """Prints input and output TensorInfos.

  Prints the details of input and output TensorInfos for the SignatureDef mapped
  by the given signature_def_key.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect.
    tag_set: Group of tag(s) of the MetaGraphDef, in string format, separated by
        ','. If the tag-set contains multiple tags, all tags must be passed in.
    signature_def_key: A SignatureDef key string.
  """
  meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir,
                                                        tag_set)
  inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)
  outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)

  print('The given SavedModel SignatureDef contains the following input(s):')
  for input_key, input_tensor in sorted(inputs_tensor_info.items()):
    print('inputs[\'%s\'] tensor_info:' % input_key)
    _print_tensor_info(input_tensor)

  print('The given SavedModel SignatureDef contains the following output(s):')
  for output_key, output_tensor in sorted(outputs_tensor_info.items()):
    print('outputs[\'%s\'] tensor_info:' % output_key)
    _print_tensor_info(output_tensor)

  print('Method name is: %s' %
        meta_graph_def.signature_def[signature_def_key].method_name)
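
For reference, a minimal sketch of the same inspection done directly against the signature map; the path './my_model', the 'serve' tag, and the 'serving_default' key are assumptions, not part of the example above:

from tensorflow.python.tools import saved_model_utils

# Hypothetical SavedModel location and tag-set; adjust to your model.
meta_graph_def = saved_model_utils.get_meta_graph_def('./my_model', 'serve')
signature = meta_graph_def.signature_def['serving_default']
for key, tensor_info in sorted(signature.inputs.items()):
    print('input %r is tensor %r' % (key, tensor_info.name))
for key, tensor_info in sorted(signature.outputs.items()):
    print('output %r is tensor %r' % (key, tensor_info.name))
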
Example #2
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING):
  """Converts all variables in a graph and checkpoint into constants."""
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
  freeze_graph_with_def_protos(
      input_graph_def, input_saver_def, input_checkpoint, output_node_names,
      restore_op_name, filename_tensor_name, output_graph, clear_devices,
      initializer_nodes, variable_names_whitelist, variable_names_blacklist,
      input_meta_graph_def, input_saved_model_dir, saved_model_tags.split(","))
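
A hedged sketch of one way this freeze_graph variant might be invoked for a SavedModel; every path and the output node name 'output' are assumptions:

# Hypothetical invocation: with input_saved_model_dir set, the graph, saver
# and checkpoint arguments can be left empty.
freeze_graph(input_graph="", input_saver="", input_binary=False,
             input_checkpoint="", output_node_names="output",
             restore_op_name="", filename_tensor_name="",
             output_graph="/tmp/frozen_graph.pb", clear_devices=True,
             initializer_nodes="", input_saved_model_dir="./my_model",
             saved_model_tags="serve")
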
Example #3
  def _TestCreateInferenceGraph(self,
                                input_saved_model_dir=None,
                                output_saved_model_dir=None):
    """General method to test trt_convert.create_inference_graph()."""
    input_graph_def = None if input_saved_model_dir else self._GetGraphDef()
    output_graph_def = trt_convert.create_inference_graph(
        input_graph_def, ["output"],
        max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        session_config=self._GetConfigProto())
    graph_defs_to_verify = [output_graph_def]
    if output_saved_model_dir is not None:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {node.name: node.op for node in graph_def.node}
      self.assertEqual({
          "input": "Placeholder",
          "TRTEngineOp_0": "TRTEngineOp",
          "output": "Identity"
      }, node_name_to_op)
Example #4
  def _TestTrtGraphConverter(self,
                             input_saved_model_dir=None,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter()."""
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        use_function_backup=need_calibration)
    graph_defs_to_verify = [output_graph_def]

    if output_saved_model_dir:
      if context.executing_eagerly():
        root = load.load(output_saved_model_dir)
        saved_model_graph_def = root.signatures[
            signature_constants
            .DEFAULT_SERVING_SIGNATURE_DEF_KEY].graph.as_graph_def()
      else:
        saved_model_graph_def = saved_model_utils.get_meta_graph_def(
            output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {node.name: node.op for node in graph_def.node}
      if context.executing_eagerly():
        # In V2 the actual graph could be inside a function.
        for func in graph_def.library.function:
          node_name_to_op.update({node.name: node.op for node in func.node_def})
        self.assertIn("TRTEngineOp_0", node_name_to_op)
        self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
      else:
        self.assertEqual({
            "input": "Placeholder",
            "TRTEngineOp_0": "TRTEngineOp",
            "output": "Identity"
        }, node_name_to_op)

      if need_calibration:
        trt_engine_nodes = [
            node for node in graph_def.node if node.op == "TRTEngineOp"
        ]
        self.assertNotEmpty(trt_engine_nodes)
        for node in trt_engine_nodes:
          self.assertTrue(len(node.attr["calibration_data"].s))
        # Run the calibrated graph.
        # TODO(laigd): consider having some input where the answer is different.
        with ops.Graph().as_default():
          importer.import_graph_def(graph_def, name="")
          with self.session(config=self._GetConfigProto()) as sess:
            for test_data in range(10):
              self.assertEqual((test_data + 1.0)**2,
                               sess.run(
                                   "output:0",
                                   feed_dict={"input:0": [[[test_data]]]}))
Example #5
def scan(args):
  """Function triggered by scan command.

  Args:
    args: A namespace parsed from command line.
  """
  if args.tag_set:
    scan_meta_graph_def(
        saved_model_utils.get_meta_graph_def(args.dir, args.tag_set))
  else:
    saved_model = reader.read_saved_model(args.dir)
    for meta_graph_def in saved_model.meta_graphs:
      scan_meta_graph_def(meta_graph_def)
Example #6
def get_signature_def_map(saved_model_dir, tag_set):
  """Gets SignatureDef map from a MetaGraphDef in a SavedModel.

  Returns the SignatureDef map for the given tag-set in the SavedModel
  directory.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect or execute.
    tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map, in
        string format, separated by ','. If the tag-set contains multiple
        tags, all tags must be passed in.

  Returns:
    A SignatureDef map that maps from string keys to SignatureDefs.
  """
  meta_graph = saved_model_utils.get_meta_graph_def(saved_model_dir, tag_set)
  return meta_graph.signature_def
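
A short usage sketch for the helper above; the model path and tag are assumptions:

# Assumes the module-level import used by the helper:
from tensorflow.python.tools import saved_model_utils

signature_def_map = get_signature_def_map('./my_model', 'serve')
for key, signature_def in signature_def_map.items():
    print('%s -> %s' % (key, signature_def.method_name))
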
Example #7
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants."""
    input_graph_def = None
    if input_saved_model_dir:
        input_graph_def = saved_model_utils.get_meta_graph_def(
            input_saved_model_dir, saved_model_tags).graph_def
    elif input_graph:
        input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
    input_meta_graph_def = None
    if input_meta_graph:
        input_meta_graph_def = _parse_input_meta_graph_proto(
            input_meta_graph, input_binary)
    input_saver_def = None
    if input_saver:
        input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
    freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist,
                                 variable_names_blacklist,
                                 input_meta_graph_def,
                                 input_saved_model_dir,
                                 saved_model_tags.split(","),
                                 checkpoint_version=checkpoint_version)
Example #8
  def _TestTrtGraphConverter(self,
                             input_saved_model_dir=None,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter()."""
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        use_function_backup=need_calibration)
    graph_defs_to_verify = [output_graph_def]

    if output_saved_model_dir:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {node.name: node.op for node in graph_def.node}
      self.assertEqual(
          {
              "input": "Placeholder",
              "TRTEngineOp_0": "TRTEngineOp",
              "output": "Identity"
          }, node_name_to_op)

      if need_calibration:
        trt_engine_nodes = [
            node for node in graph_def.node if node.op == "TRTEngineOp"
        ]
        self.assertNotEmpty(trt_engine_nodes)
        for node in trt_engine_nodes:
          self.assertTrue(len(node.attr["calibration_data"].s))
        # Run the calibrated graph.
        # TODO(laigd): consider having some input where the answer is different.
        with ops.Graph().as_default():
          importer.import_graph_def(graph_def, name="")
          with self.session(config=self._GetConfigProto()) as sess:
            for test_data in range(10):
              self.assertEqual(
                  (test_data + 1.0)**2,
                  sess.run("output:0", feed_dict={"input:0": [[[test_data]]]}))
Example #9
    def __init__(self, model_dir, output_columns=None):
        """
        :param model_dir: The directory where the model is saved.
        :param output_columns: List of column names for the output.
                The saved models typically return a single output called
                "predictions" containing all the predictions as a 2-D numpy
                array whose number of columns is the expected number of
                outputs for each input tuple. 'output_columns' provides names
                for each of the columns. If it is None, default names are
                assigned.
        """
        self.model_dir = model_dir
        self.output_columns = output_columns

        self.meta_graph_def = saved_model_utils.get_meta_graph_def(
            self.model_dir, tag_set=DEFAULT_TAG)
        signature_def = self.meta_graph_def.signature_def[
            DEFAULT_SIGNATURE_DEF_KEY]
        self.input_tensors = signature_def.inputs

        output_tensors = signature_def.outputs

        # Output is expected to be single "predictions" ndarray. Enforce that.
        if output_tensors.keys() != {DEFAULT_OUTPUT_KEY}:
            raise RuntimeError(
                'Expected single output named "{}", but found [{}]'.format(
                    DEFAULT_OUTPUT_KEY, ','.join(output_tensors.keys())))

        self.output_tensor = output_tensors[DEFAULT_OUTPUT_KEY]
        output_shape = tuple(d.size
                             for d in self.output_tensor.tensor_shape.dim)

        # Ensure that the output_tensor shape matches (-1, len(output_columns))
        if self.output_columns:
            expected_shape = (-1, len(self.output_columns))
            if expected_shape != output_shape:
                raise RuntimeError(
                    'Shape of prediction does not match the output columns. '
                    'Expected shape is {}, but found {}.'.format(
                        expected_shape, output_shape))
        else:
            self.output_columns = [
                f'prediction_{i}' for i in range(output_shape[1])
            ]

        self.sess = session.Session(None, graph=ops_lib.Graph())
        loader.load(self.sess, [DEFAULT_TAG], self.model_dir)
Example #10
def verify_outputs(args, onnx_model):
    tag_sets = saved_model_utils.get_saved_model_tag_sets(args.saved_model)
    for tag_set in tag_sets:
        tag_set = ','.join(tag_set)
        meta_graph_def = saved_model_utils.get_meta_graph_def(
            args.saved_model, tag_set)
        signature_def_map = meta_graph_def.signature_def
        for signature_def_key in signature_def_map.keys():
            outputs_tensor_info = signature_def_map[signature_def_key].outputs
            for output_key, output_tensor in outputs_tensor_info.items():
                rename_output(onnx_model, output_key, output_tensor)

    print("Inputs in model: {}".format(", ".join([
        "'{}'".format(o.name) for o in onnx_model.graph.input
        if not has_initializer(onnx_model, o.name)
    ])))
    print("Outputs in model: {}".format(", ".join(
        ["'{}'".format(o.name) for o in onnx_model.graph.output])))
Example #11
def scan_graph(input_checkpoint=None,
               input_saved_model_dir=None,
               saved_model_tags=tag_constants.SERVING):
    """extract the graph to scan from a model file."""

    if not input_saved_model_dir and not input_checkpoint:
        print("Please specify a checkpoint or 'SavedModel' file!")
        return -1
    if input_saved_model_dir and input_checkpoint:
        print("Please specify only *one* model file type: "
              "checkpoint or 'SavedModel'!")
        return -1

    input_graph_def = None
    if input_checkpoint:
        # We don't currently use the variables file, but still check it for completeness.
        if not saver_lib.checkpoint_exists(input_checkpoint):
            print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
            return -1
        # Build meta file path for a checkpoint
        meta_file = input_checkpoint + ".meta"
        if not gfile.Exists(meta_file):
            print("Input checkpoint meta file '" + meta_file +
                  "' doesn't exist!")
            return -1
        try:
            input_graph_def = _parse_input_meta_graph_proto(meta_file,
                                                            True).graph_def
        except Exception:
            exctype, value = sys.exc_info()[:2]
            print("Parse checkpoint meta-graph file '%s' failed: %s(%s)" %
                  (meta_file, exctype, value))
            return -1
    if input_saved_model_dir:
        try:
            input_graph_def = saved_model_utils.get_meta_graph_def(
                input_saved_model_dir, saved_model_tags).graph_def
        except Exception:
            exctype, value = sys.exc_info()[:2]
            print("Parse SavedModel '%s' meta-graph file failed: %s(%s)" %
                  (input_saved_model_dir, exctype, value))
            return -1

    return detect_ops(input_graph_def)
Example #12
def get_input_and_output_names(saved_model_dir, tag_set, signature_def_key):

    meta_graph_def = saved_model_utils.get_meta_graph_def(
        saved_model_dir, tag_set)
    inputs_tensor_info = get_inputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)
    outputs_tensor_info = get_outputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)

    inputs = {
        input_key: input_tensor.name
        for input_key, input_tensor in inputs_tensor_info.items()
    }
    outputs = {
        output_key: output_tensor.name
        for output_key, output_tensor in outputs_tensor_info.items()
    }

    return inputs, outputs
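
A hedged sketch of feeding the returned name maps into a session; the model path, tag, signature key, input key 'x', and input shape are all assumptions:

import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader

inputs, outputs = get_input_and_output_names('./my_model', 'serve',
                                             'serving_default')
with session.Session(graph=ops.Graph()) as sess:
    loader.load(sess, ['serve'], './my_model')
    # Feed and fetch by tensor name, as session.run expects.
    results = sess.run(list(outputs.values()),
                       feed_dict={inputs['x']: np.zeros((1, 4), np.float32)})
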
Example #13
 def _GetGraphDef(self, run_params, gdef_or_saved_model_dir):
     if isinstance(gdef_or_saved_model_dir, str):
         if run_params.is_v2:
             root = load.load(gdef_or_saved_model_dir)
             func = root.signatures[
                 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
             gdef = func.graph.as_graph_def()
             # Manually unref the loaded saved model and force GC to destroy the TRT
             # engine cache after load(). There is currently a reference cycle in 2.0
             # which prevents auto deletion of the resource.
             # TODO(laigd): fix this.
             del func
             del root
             gc.collect()
             return gdef
         return saved_model_utils.get_meta_graph_def(
             gdef_or_saved_model_dir, tag_constants.SERVING).graph_def
     assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef)
     return gdef_or_saved_model_dir
Example #14
def _show_inputs_outputs(saved_model_dir,
                         tag_set,
                         signature_def_key,
                         indent=0):
    """Prints input and output TensorInfos.

  Prints the details of input and output TensorInfos for the SignatureDef mapped
  by the given signature_def_key.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect.
    tag_set: Group of tag(s) of the MetaGraphDef, in string format, separated by
        ','. If the tag-set contains multiple tags, all tags must be passed in.
    signature_def_key: A SignatureDef key string.
    indent: How far (in increments of 2 spaces) to indent each line of output.
  """
    meta_graph_def = saved_model_utils.get_meta_graph_def(
        saved_model_dir, tag_set)
    inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)
    outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)

    indent_str = "  " * indent

    def in_print(s):
        print(indent_str + s)

    in_print(
        'The given SavedModel SignatureDef contains the following input(s):')
    for input_key, input_tensor in sorted(inputs_tensor_info.items()):
        in_print('  inputs[\'%s\'] tensor_info:' % input_key)
        _print_tensor_info(input_tensor, indent + 1)

    in_print('The given SavedModel SignatureDef contains the following '
             'output(s):')
    for output_key, output_tensor in sorted(outputs_tensor_info.items()):
        in_print('  outputs[\'%s\'] tensor_info:' % output_key)
        _print_tensor_info(output_tensor, indent + 1)

    in_print('Method name is: %s' %
             meta_graph_def.signature_def[signature_def_key].method_name)
Example #15
def get_meta_graph_def(saved_model_dir, tag_set):
  """DEPRECATED: Use saved_model_utils.get_meta_graph_def instead.

  Gets MetaGraphDef from SavedModel. Returns the MetaGraphDef for the given
  tag-set and SavedModel directory.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect or execute.
    tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
        separated by ','. If the tag-set contains multiple tags, all tags must
        be passed in.

  Raises:
    RuntimeError: An error when the given tag-set does not exist in the
        SavedModel.

  Returns:
    A MetaGraphDef corresponding to the tag-set.
  """
  return saved_model_utils.get_meta_graph_def(saved_model_dir, tag_set)
Example #16
def get_meta_graph_def(saved_model_dir, tag_set):
    """DEPRECATED: Use saved_model_utils.get_meta_graph_def instead.

  Gets MetaGraphDef from SavedModel. Returns the MetaGraphDef for the given
  tag-set and SavedModel directory.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect or execute.
    tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
        separated by ','. If the tag-set contains multiple tags, all tags must
        be passed in.

  Raises:
    RuntimeError: An error when the given tag-set does not exist in the
        SavedModel.

  Returns:
    A MetaGraphDef corresponding to the tag-set.
  """
    return saved_model_utils.get_meta_graph_def(saved_model_dir, tag_set)
Example #17
def import_to_tensorboard(model_dir, log_dir):
    """View an imported protobuf model (`.pb` file) as a graph in Tensorboard.

    Args:
      model_dir: The location of the protobuf (`pb`) model to visualize
      log_dir: The location for the Tensorboard log to begin visualization from.

    Usage:
      Call this function with your model location and desired log directory.
      Launch Tensorboard by pointing it to the log directory.
      View your imported SavedModel as a graph.
    """
    with session.Session(graph=ops.Graph()) as sess:
        input_graph_def = saved_model_utils.get_meta_graph_def(
            model_dir, 'serve').graph_def
        importer.import_graph_def(input_graph_def)

        pb_visual_writer = summary.FileWriter(log_dir)
        pb_visual_writer.add_graph(sess.graph)
        print("Model Imported. Visualize by running: "
              "tensorboard --logdir={}".format(log_dir))
Example #18
def import_to_tensorboard(model_dir, log_dir, tag_set):
    """View an SavedModel as a graph in Tensorboard.
  Args:
    model_dir: The directory containing the SavedModel to import.
    log_dir: The location for the Tensorboard log to begin visualization from.
    tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
      separated by ','. If the tag-set contains multiple tags, all tags must be
      passed in.
  Usage: Call this function with your SavedModel location and desired log
    directory. Launch Tensorboard by pointing it to the log directory. View your
    imported SavedModel as a graph.
  """
    with session.Session(graph=ops.Graph()) as sess:
        input_graph_def = saved_model_utils.get_meta_graph_def(
            model_dir, tag_set).graph_def
        importer.import_graph_def(input_graph_def)

        pb_visual_writer = summary.FileWriter(log_dir)
        pb_visual_writer.add_graph(sess.graph)
        print("Model Imported. Visualize by running: "
              "tensorboard --logdir={}".format(log_dir))
Example #19
    def __init__(self, model_dir):
        # model_dir = os.path.abspath(model_dir)
        assert os.path.exists(model_dir), 'model_dir {} does not exist'.format(
            model_dir)
        assert os.path.isdir(
            model_dir), 'model_dir {} is not a directory'.format(model_dir)

        meta_graph_def = saved_model_utils.get_meta_graph_def(
            model_dir, 'serve')
        self.inputs_tensor_info = get_signature_def_by_key(
            meta_graph_def, 'predict').inputs
        outputs_tensor_info = get_signature_def_by_key(meta_graph_def,
                                                       'predict').outputs
        # Sort to preserve order because we need to go from value to key later.
        self.output_tensor_keys = sorted(outputs_tensor_info.keys())
        self.output_tensor_names = [
            outputs_tensor_info[tensor_key].name
            for tensor_key in self.output_tensor_keys
        ]

        self.session = tf.Session(graph=tf.Graph())
        loader.load(self.session, ['serve'], model_dir)
Example #20
def freeze_model(args):
  """Function triggered by freeze_model command.

  Args:
    args: A namespace parsed from command line.
  """
  checkpoint_path = (
      args.checkpoint_path
      or os.path.join(args.dir, 'variables/variables'))
  if not args.variables_to_feed:
    variables_to_feed = []
  elif args.variables_to_feed.lower() == 'all':
    variables_to_feed = None  # We will identify them after.
  else:
    variables_to_feed = args.variables_to_feed.split(',')

  saved_model_aot_compile.freeze_model(
      checkpoint_path=checkpoint_path,
      meta_graph_def=saved_model_utils.get_meta_graph_def(
          args.dir, args.tag_set),
      signature_def_key=args.signature_def_key,
      variables_to_feed=variables_to_feed,
      output_prefix=args.output_prefix)
Example #21
  def _TestCreateInferenceGraph(self,
                                input_saved_model_dir=None,
                                output_saved_model_dir=None):
    """General method to test trt_convert.create_inference_graph()."""
    input_graph_def = None if input_saved_model_dir else self._GetGraphDef()
    output_graph_def = trt_convert.create_inference_graph(
        input_graph_def, ["output"],
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        session_config=self._GetConfigProto())
    graph_defs_to_verify = [output_graph_def]
    if output_saved_model_dir is not None:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {node.name: node.op for node in graph_def.node}
      self.assertEqual({
          "input": "Placeholder",
          "TRTEngineOp_0": "TRTEngineOp",
          "output": "Identity"
      }, node_name_to_op)
Example #22
# First, print the SavedModel's node info to determine its inputs and outputs:
# python /xxx/tensorflow/tensorflow/python/tools/saved_model_cli.py show --dir /xxx/model/1/ --all

ctpn_path = './ctpn_savemodel/1'
desenet_path = './densenet_savemodel/1/'
ocr_model_path = './ctpn_densenet/1'

if os.path.exists(ocr_model_path):
    shutil.rmtree(ocr_model_path)

if __name__ == "__main__":

    with tf.Graph().as_default() as g1:
        with tf.Session(graph=g1) as sess1:
            input_graph_def1 = saved_model_utils.get_meta_graph_def(
                ctpn_path, tf.saved_model.tag_constants.SERVING).graph_def
            tf.saved_model.loader.load(sess1, ["serve"], ctpn_path)
            g1def = convert_variables_to_constants(
                sess1, input_graph_def1,
                output_node_names=['strided_slice_91'],
                variable_names_whitelist=None,
                variable_names_blacklist=None)

    with tf.Graph().as_default() as g2:
        with tf.Session(graph=g2) as sess2:
            input_graph_def2 = saved_model_utils.get_meta_graph_def(
                desenet_path, tf.saved_model.tag_constants.SERVING).graph_def
            tf.saved_model.loader.load(sess2, ["serve"], desenet_path)
            g2def = convert_variables_to_constants(
                sess2, input_graph_def2,
                output_node_names=['ArgMax', 'Max'],
                variable_names_whitelist=None,
                variable_names_blacklist=None)

    with tf.Graph().as_default() as g_combined:
Example #23
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_denylist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.
  Args:
    input_graph: A `GraphDef` file to load.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional;
                              by default, all variables are converted).
    variable_names_denylist: The set of variable names to omit converting
                             to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
                           variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).
  Returns:
    String that is the location of frozen GraphDef.
  """
    input_graph_def = None
    if input_saved_model_dir:
        input_graph_def = saved_model_utils.get_meta_graph_def(
            input_saved_model_dir, saved_model_tags).graph_def
    elif input_graph:
        input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
    input_meta_graph_def = None
    if input_meta_graph:
        input_meta_graph_def = _parse_input_meta_graph_proto(
            input_meta_graph, input_binary)
    input_saver_def = None
    if input_saver:
        input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
    return freeze_graph_with_def_protos(
        input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph,
        clear_devices,
        initializer_nodes,
        variable_names_whitelist,
        variable_names_denylist,
        input_meta_graph_def,
        input_saved_model_dir,
        [tag for tag in saved_model_tags.replace(" ", "").split(",") if tag],
        checkpoint_version=checkpoint_version)
Example #24
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
                                   input_tensor_key_feed_dict, outdir,
                                   overwrite_flag, tf_debug=False):
  """Runs SavedModel and fetch all outputs.

  Runs the input dictionary through the MetaGraphDef within a SavedModel
  specified by the given tag_set and SignatureDef. Also saves the outputs to
  file if outdir is not None.

  Args:
    saved_model_dir: Directory containing the SavedModel to execute.
    tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map, in
        string format, separated by ','. If the tag-set contains multiple
        tags, all tags must be passed in.
    signature_def_key: A SignatureDef key string.
    input_tensor_key_feed_dict: A dictionary mapping input keys to numpy
        ndarrays.
    outdir: A directory to save the outputs to. If the directory doesn't exist,
        it will be created.
    overwrite_flag: A boolean flag to allow overwriting the output file if a
        file with the same name exists.
    tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
        intermediate Tensor values and runtime GraphDefs while running the
        SavedModel.

  Raises:
    ValueError: When any of the input tensor keys is not valid.
    RuntimeError: An error when an output file already exists and overwrite is
        not enabled.
  """
  # Get a list of output tensor names.
  meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir,
                                                        tag_set)

  # Re-create feed_dict based on input tensor name instead of key as session.run
  # uses tensor name.
  inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)

  # Check if input tensor keys are valid.
  for input_key_name in input_tensor_key_feed_dict.keys():
    if input_key_name not in inputs_tensor_info.keys():
      raise ValueError(
          '"%s" is not a valid input key. Please choose from %s, or use '
          '--show option.' %
          (input_key_name, '"' + '", "'.join(inputs_tensor_info.keys()) + '"'))

  inputs_feed_dict = {
      inputs_tensor_info[key].name: tensor
      for key, tensor in input_tensor_key_feed_dict.items()
  }
  # Get outputs
  outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)
  # Sort to preserve order because we need to go from value to key later.
  output_tensor_keys_sorted = sorted(outputs_tensor_info.keys())
  output_tensor_names_sorted = [
      outputs_tensor_info[tensor_key].name
      for tensor_key in output_tensor_keys_sorted
  ]

  with session.Session(graph=ops_lib.Graph()) as sess:
    loader.load(sess, tag_set.split(','), saved_model_dir)

    if tf_debug:
      sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)

    outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict)

    for i, output in enumerate(outputs):
      output_tensor_key = output_tensor_keys_sorted[i]
      print('Result for output key %s:\n%s' % (output_tensor_key, output))

      # Only save if outdir is specified.
      if outdir:
        # Create directory if outdir does not exist
        if not os.path.isdir(outdir):
          os.makedirs(outdir)
        output_full_path = os.path.join(outdir, output_tensor_key + '.npy')

        # If overwrite not enabled and file already exist, error out
        if not overwrite_flag and os.path.exists(output_full_path):
          raise RuntimeError(
              'Output file %s already exists. Add \"--overwrite\" to overwrite'
              ' the existing output files.' % output_full_path)

        np.save(output_full_path, output)
        print('Output %s is saved to %s' % (output_tensor_key,
                                            output_full_path))
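
A hedged usage sketch for the runner above; the model path, tag-set, signature key, input key 'x', and input shape are all assumptions:

import numpy as np

run_saved_model_with_feed_dict(
    './my_model', 'serve', 'serving_default',
    {'x': np.zeros((1, 4), np.float32)},
    outdir=None,  # print only; pass a directory to also save .npy files
    overwrite_flag=False)
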
Example #25
# -*- coding: utf-8 -*-

import pandas as pd
import tensorflow as tf
from tensorflow.python.tools import saved_model_utils
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
import common

tag = tf.saved_model.tag_constants.SERVING
signature_def = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

with tf.Session() as sess:
    export_dir = common.get_export_dir()
    tf.saved_model.loader.load(sess, [tag], export_dir)

    meta_graph_def = saved_model_utils.get_meta_graph_def(export_dir, tag)
    predict_signature_def = signature_def_utils.get_signature_def_by_key(
        meta_graph_def, signature_def)

    inputs = common.get_test_inputs()
    examples = common.create_examples(inputs)

    fetches = [
        predict_signature_def.outputs[key].name
        for key in ['classes', 'scores']
    ]
    feed_dict = {predict_signature_def.inputs['inputs'].name: examples}

    outputs = sess.run(fetches, feed_dict=feed_dict)
    predictions = {
        'classes': outputs[0],
Example #26
 def _GetGraphDef(self, gdef_or_saved_model_dir):
     if isinstance(gdef_or_saved_model_dir, str):
         return saved_model_utils.get_meta_graph_def(
             gdef_or_saved_model_dir, tag_constants.SERVING).graph_def
     assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef)
     return gdef_or_saved_model_dir
Example #27
    def _TestTrtGraphConverter(self,
                               device,
                               output_saved_model_dir=None,
                               need_calibration=False,
                               is_dynamic_op=False):
        """General method to test trt_convert.TrtGraphConverter()."""
        output_graph_def = self._ConvertGraphV1(
            output_saved_model_dir=output_saved_model_dir,
            need_calibration=need_calibration,
            is_dynamic_op=is_dynamic_op,
            device=device)
        graph_defs_to_verify = [output_graph_def]

        if output_saved_model_dir:
            saved_model_graph_def = saved_model_utils.get_meta_graph_def(
                output_saved_model_dir, tag_constants.SERVING).graph_def
            self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
            graph_defs_to_verify.append(saved_model_graph_def)

        for graph_def in graph_defs_to_verify:
            node_name_to_op = {
                self._MayRemoveGraphSequenceNumber(node.name): node.op
                for node in graph_def.node
            }
            if device is not None and device.startswith("/CPU:"):
                self.assertEqual(
                    {
                        "add": "AddV2",
                        "add/ReadVariableOp": "Const",
                        "add_1": "AddV2",
                        "add_2": "AddV2",
                        "input1": "Placeholder",
                        "input2": "Placeholder",
                        "mul": "Mul",
                        "output": "Identity"
                    }, node_name_to_op)
            else:
                self.assertEqual(
                    {
                        "input1": "Placeholder",
                        "input2": "Placeholder",
                        "TRTEngineOp_0": "TRTEngineOp",
                        "output": "Identity"
                    }, node_name_to_op)

            if need_calibration:
                trt_engine_nodes = [
                    node for node in graph_def.node if node.op == "TRTEngineOp"
                ]
                if device is not None and device.startswith("/CPU:"):
                    self.assertEmpty(trt_engine_nodes)
                    return

                self.assertNotEmpty(trt_engine_nodes)
                for node in trt_engine_nodes:
                    self.assertTrue(len(node.attr["calibration_data"].s))
                # Run the calibrated graph.
                # TODO(laigd): consider having some input where the answer is different.
                with ops.Graph().as_default():
                    importer.import_graph_def(graph_def, name="")
                    with self.session(config=self._GetConfigProto()) as sess:
                        for test_data in range(10):
                            self.assertEqual(
                                (test_data + 1.0)**2 + test_data,
                                sess.run("output:0",
                                         feed_dict={
                                             "input1:0": [[[test_data]]],
                                             "input2:0": [[[test_data]]]
                                         }))
Example #28
    def _TestTrtGraphConverter(self,
                               input_saved_model_dir=None,
                               output_saved_model_dir=None,
                               need_calibration=False,
                               is_dynamic_op=False):
        """General method to test trt_convert.TrtGraphConverter()."""
        output_graph_def = self._ConvertGraph(
            input_saved_model_dir=input_saved_model_dir,
            output_saved_model_dir=output_saved_model_dir,
            need_calibration=need_calibration,
            is_dynamic_op=is_dynamic_op)
        graph_defs_to_verify = [output_graph_def]

        if output_saved_model_dir:
            if context.executing_eagerly():
                root = load.load(output_saved_model_dir)
                saved_model_graph_def = root.signatures[
                    signature_constants.
                    DEFAULT_SERVING_SIGNATURE_DEF_KEY].graph.as_graph_def()
            else:
                saved_model_graph_def = saved_model_utils.get_meta_graph_def(
                    output_saved_model_dir, tag_constants.SERVING).graph_def
            self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
            graph_defs_to_verify.append(saved_model_graph_def)

        for graph_def in graph_defs_to_verify:
            node_name_to_op = {node.name: node.op for node in graph_def.node}
            if context.executing_eagerly():
                # In V2 the actual graph could be inside a function.
                for func in graph_def.library.function:
                    node_name_to_op.update(
                        {node.name: node.op
                         for node in func.node_def})
                self.assertIn("TRTEngineOp_0", node_name_to_op)
                self.assertEqual("TRTEngineOp",
                                 node_name_to_op["TRTEngineOp_0"])
            else:
                self.assertEqual(
                    {
                        "input": "Placeholder",
                        "TRTEngineOp_0": "TRTEngineOp",
                        "output": "Identity"
                    }, node_name_to_op)

            if need_calibration:
                trt_engine_nodes = [
                    node for node in graph_def.node if node.op == "TRTEngineOp"
                ]
                self.assertNotEmpty(trt_engine_nodes)
                for node in trt_engine_nodes:
                    self.assertTrue(len(node.attr["calibration_data"].s))
                # Run the calibrated graph.
                # TODO(laigd): consider having some input where the answer is different.
                with ops.Graph().as_default():
                    importer.import_graph_def(graph_def, name="")
                    with self.session(config=self._GetConfigProto()) as sess:
                        for test_data in range(10):
                            self.assertEqual(
                                (test_data + 1.0)**2,
                                sess.run(
                                    "output:0",
                                    feed_dict={"input:0": [[[test_data]]]}))
Example #29
from tensorflow.python.tools import saved_model_utils

meta_graph_def = saved_model_utils.get_meta_graph_def(config.MODEL_DIR,
                                                      'serve')
inputs = meta_graph_def.signature_def['serving_default'].inputs
outputs = meta_graph_def.signature_def['serving_default'].outputs

# Just get the first thing(s) from the serving signature def.  i.e. this
# model only has a single input and a single output.
input_name = None
for k, v in inputs.items():
    input_name = v.name
    break

output_name = None
for k, v in outputs.items():
    output_name = v.name
    break

# Make a dictionary that maps Earth Engine outputs and inputs to
# AI Platform inputs and outputs, respectively.
import json
input_dict = "'" + json.dumps({input_name: "array"}) + "'"
output_dict = "'" + json.dumps({output_name: "impervious"}) + "'"
Example #30
def run_main(unused_args):

  input_model_dir = FLAGS.input_saved_model_dir
  output_model_dir = FLAGS.output_saved_model_dir
  sig_key = FLAGS.signature_key
  inp_tags = FLAGS.saved_model_tags
  if FLAGS.saved_model_tags == "":
    tag_set = []
  else:
    tag_set = inp_tags.split(",")
    avail_tags = saved_model_utils.get_saved_model_tag_sets(input_model_dir)
    found = False
    for tag in tag_set:
      if [tag] in avail_tags:
        found = True
      else:
        found = False
        break
    if not found:
      print ("Supplied tags", tag_set, "is not in available tag set,\
                    please use one or more of these", avail_tags, "Using --saved_model_tags")
      exit(1)


  sig_def = saved_model_utils.get_meta_graph_def(input_model_dir, inp_tags)
  pretrained_model = load.load(input_model_dir, tag_set)
  if sig_key not in pretrained_model.signatures:
    print(sig_key, "is not in", list(pretrained_model.signatures.keys()),
          "- provide one of those using --signature_key")
    exit(1)

  infer = pretrained_model.signatures[sig_key]
  frozen_func = convert_to_constants.convert_variables_to_constants_v2(
      infer, lower_control_flow=True)

  frozen_func.graph.structured_outputs = nest.pack_sequence_as(
        infer.graph.structured_outputs,
        frozen_func.graph.structured_outputs)
  souts = frozen_func.graph.structured_outputs
  inputs = frozen_func.inputs
  input_nodes = [(tensor.name.split(":"))[0] for tensor in inputs]
  output_nodes = [(souts[name].name.split(":"))[0] for name in souts]

  gdef = frozen_func.graph.as_graph_def()
  opt_graph = optimize_for_inference_lib.optimize_for_inference(
      gdef, input_nodes, output_nodes,
      [tensor.dtype.as_datatype_enum for tensor in inputs])

  with session.Session() as sess:
    graph = importer.import_graph_def(opt_graph, name="")

    signature_inputs = {(tensor.name.split(":"))[0]: model_utils.build_tensor_info(tensor)
                        for tensor in inputs}
    signature_outputs = {name: model_utils.build_tensor_info(souts[name])
                         for name in souts}
    signature_def = signature_def_utils.build_signature_def(
        signature_inputs, signature_outputs,
        signature_constants.PREDICT_METHOD_NAME)
    signature_def_map = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
    }
    builder = saved_model_builder.SavedModelBuilder(output_model_dir)
    builder.add_meta_graph_and_variables(
        sess, tags=[tag_constants.SERVING],
        signature_def_map=signature_def_map)
    builder.save()
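
The core freezing step above can be sketched on its own; this minimal example, with a hypothetical model path, loads a TF2 SavedModel and freezes its serving signature:

from tensorflow.python.framework import convert_to_constants
from tensorflow.python.saved_model import load

root = load.load('./my_model')  # hypothetical path
concrete_func = root.signatures['serving_default']
frozen_func = convert_to_constants.convert_variables_to_constants_v2(
    concrete_func)
print(len(frozen_func.graph.as_graph_def().node), 'nodes after freezing')
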
Example #31
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: A `GraphDef` file to load.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional;
                              by default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
                           variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).
  Returns:
    String that is the location of frozen GraphDef.
  """
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
  freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_blacklist,
      input_meta_graph_def,
      input_saved_model_dir,
      saved_model_tags.replace(" ", "").split(","),
      checkpoint_version=checkpoint_version)
Example #32
def run_saved_model_with_feed_dict(saved_model_dir,
                                   tag_set,
                                   signature_def_key,
                                   input_tensor_key_feed_dict,
                                   outdir,
                                   overwrite_flag,
                                   tf_debug=False):
    """Runs SavedModel and fetch all outputs.

  Runs the input dictionary through the MetaGraphDef within a SavedModel
  specified by the given tag_set and SignatureDef. Also saves the outputs to
  file if outdir is not None.

  Args:
    saved_model_dir: Directory containing the SavedModel to execute.
    tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map, in
        string format, separated by ','. If the tag-set contains multiple
        tags, all tags must be passed in.
    signature_def_key: A SignatureDef key string.
    input_tensor_key_feed_dict: A dictionary mapping input keys to numpy
        ndarrays.
    outdir: A directory to save the outputs to. If the directory doesn't exist,
        it will be created.
    overwrite_flag: A boolean flag to allow overwriting the output file if a
        file with the same name exists.
    tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
        intermediate Tensor values and runtime GraphDefs while running the
        SavedModel.

  Raises:
    ValueError: When any of the input tensor keys is not valid.
    RuntimeError: An error when an output file already exists and overwrite is
        not enabled.
  """
    # Get a list of output tensor names.
    meta_graph_def = saved_model_utils.get_meta_graph_def(
        saved_model_dir, tag_set)

    # Re-create feed_dict based on input tensor name instead of key as session.run
    # uses tensor name.
    inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)

    # Check if input tensor keys are valid.
    for input_key_name in input_tensor_key_feed_dict.keys():
        if input_key_name not in inputs_tensor_info.keys():
            raise ValueError(
                '"%s" is not a valid input key. Please choose from %s, or use '
                '--show option.' %
                (input_key_name,
                 '"' + '", "'.join(inputs_tensor_info.keys()) + '"'))

    inputs_feed_dict = {
        inputs_tensor_info[key].name: tensor
        for key, tensor in input_tensor_key_feed_dict.items()
    }
    # Get outputs
    outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)
    # Sort to preserve order because we need to go from value to key later.
    output_tensor_keys_sorted = sorted(outputs_tensor_info.keys())
    output_tensor_names_sorted = [
        outputs_tensor_info[tensor_key].name
        for tensor_key in output_tensor_keys_sorted
    ]

    with session.Session(graph=ops_lib.Graph()) as sess:
        loader.load(sess, tag_set.split(','), saved_model_dir)

        if tf_debug:
            sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)

        outputs = sess.run(output_tensor_names_sorted,
                           feed_dict=inputs_feed_dict)

        for i, output in enumerate(outputs):
            output_tensor_key = output_tensor_keys_sorted[i]
            print('Result for output key %s:\n%s' %
                  (output_tensor_key, output))

            # Only save if outdir is specified.
            if outdir:
                # Create directory if outdir does not exist
                if not os.path.isdir(outdir):
                    os.makedirs(outdir)
                output_full_path = os.path.join(outdir,
                                                output_tensor_key + '.npy')

                # If overwrite not enabled and file already exist, error out
                if not overwrite_flag and os.path.exists(output_full_path):
                    raise RuntimeError(
                        'Output file %s already exists. Add \"--overwrite\" to overwrite'
                        ' the existing output files.' % output_full_path)

                np.save(output_full_path, output)
                print('Output %s is saved to %s' %
                      (output_tensor_key, output_full_path))
Example #33
                                                        tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  resize_shape = tf.stack([28, 28])
  resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
  resized_image = tf.image.resize_bilinear(decoded_image_4d,
                                           resize_shape_as_int)
  # Flatten to a 1-D array
  resized_image_1d = tf.reshape(resized_image, (-1, 28 * 28))
  print(resized_image_1d.shape)
  tf.identity(resized_image_1d, name="DecodeJPGOutput")

g1def = g1.as_graph_def()

with tf.Graph().as_default() as g2:
  with tf.Session(graph=g2) as sess:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        "./model", tag_constants.SERVING).graph_def

    tf.saved_model.loader.load(sess, ["serve"], "./model")

    g2def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        ["myOutput"],
        variable_names_whitelist=None,
        variable_names_blacklist=None)

with tf.Graph().as_default() as g_combined:
  with tf.Session(graph=g_combined) as sess:

    x = tf.placeholder(tf.string, name="base64_input")