示例#1
0
 def testReadSavedModelInvalid(self):
     """Reading a SavedModel from a non-existent directory raises IOError.

     The expected message embeds the temp-dir path, which may contain regex
     metacharacters (backslashes on Windows, dots, ...), so the whole message
     is escaped before being used as a search pattern.
     """
     import re  # local import: only needed to escape the path here

     saved_model_dir = os.path.join(test.get_temp_dir(),
                                    "invalid_saved_model")
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(
             IOError,
             re.escape("SavedModel file does not exist at: %s" %
                       saved_model_dir)):
         reader.read_saved_model(saved_model_dir)
示例#2
0
def get_serving_meta_graph_def(savedmodel_dir):
  """Extract the SERVING MetaGraphDef from a SavedModel directory.

  Args:
    savedmodel_dir: the string path to the directory containing the .pb
      and variables for a SavedModel. This is equivalent to the subdirectory
      that is created under the directory specified by --export_dir when
      running an Official Model.

  Returns:
    MetaGraphDef that should be used for tag_constants.SERVING mode, i.e. the
    first MetaGraphDef whose tag set is exactly {SERVING}.

  Raises:
    ValueError: if a MetaGraphDef matching tag_constants.SERVING is not found.
  """
  # We only care about the serving graph def. The tag sets must be equal, so a
  # MetaGraphDef tagged with SERVING plus other tags does not qualify.
  serving_tags = {tf.saved_model.tag_constants.SERVING}
  saved_model = reader.read_saved_model(savedmodel_dir)
  for meta_graph_def in saved_model.meta_graphs:
    if set(meta_graph_def.meta_info_def.tags) == serving_tags:
      return meta_graph_def
  raise ValueError("No MetaGraphDef found for tag_constants.SERVING. "
                   "Please make sure the SavedModel includes a SERVING def.")
示例#3
0
def get_meta_graph_def(saved_model_dir, tag_set):
    """Gets MetaGraphDef from SavedModel.

    Returns the MetaGraphDef for the given tag-set and SavedModel directory.

    Args:
      saved_model_dir: Directory containing the SavedModel to inspect or
        execute.
      tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
        separated by ','. If the tag-set contains multiple tags, all tags must
        be passed in. Whitespace around each tag is ignored and empty entries
        are dropped, so "serve, gpu" equals "serve,gpu" and "" selects a
        MetaGraphDef with no tags.

    Raises:
      RuntimeError: An error when the given tag-set does not exist in the
        SavedModel. The message lists the tag-sets that do exist.

    Returns:
      A MetaGraphDef corresponding to the tag-set.
    """
    saved_model = reader.read_saved_model(saved_model_dir)
    # Normalise user input such as "serve, gpu" or "serve," before comparing.
    set_of_tags = {tag.strip() for tag in tag_set.split(',')} - {''}
    available_tag_sets = []
    for meta_graph_def in saved_model.meta_graphs:
        meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
        if meta_graph_tags == set_of_tags:
            return meta_graph_def
        available_tag_sets.append(','.join(sorted(meta_graph_tags)))

    raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
                       ' could not be found in SavedModel. Available '
                       'tag-sets: ' + str(available_tag_sets))
示例#4
0
def get_meta_graph_def(saved_model_dir, tag_set):
  """Gets MetaGraphDef from SavedModel.

  Returns the MetaGraphDef for the given tag-set and SavedModel directory.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect or execute.
    tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
        separated by ','. If the tag-set contains multiple tags, all tags must
        be passed in. Whitespace around each tag is ignored and empty entries
        are dropped, so "serve, gpu" equals "serve,gpu" and "" selects a
        MetaGraphDef with no tags.

  Raises:
    RuntimeError: An error when the given tag-set does not exist in the
        SavedModel. The message lists the tag-sets that do exist.

  Returns:
    A MetaGraphDef corresponding to the tag-set.
  """
  saved_model = reader.read_saved_model(saved_model_dir)
  # Normalise user input such as "serve, gpu" or "serve," before comparing.
  set_of_tags = {tag.strip() for tag in tag_set.split(',')} - {''}
  available_tag_sets = []
  for meta_graph_def in saved_model.meta_graphs:
    meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
    if meta_graph_tags == set_of_tags:
      return meta_graph_def
    available_tag_sets.append(','.join(sorted(meta_graph_tags)))

  raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
                     ' could not be found in SavedModel. Available '
                     'tag-sets: ' + str(available_tag_sets))
示例#5
0
def _get_meta_graph_def(saved_model_dir, tag_set):
    """Read a SavedModel and return the MetaGraphDef whose tags equal `tag_set`.

    Every tag-set present in the SavedModel is logged via logging.info first,
    which helps diagnose a wrong `tag_set`.

    Args:
      saved_model_dir: saved_model path to convert.
      tag_set: Set of tag(s) of the MetaGraphDef to load; compared with `==`
        against the set of each MetaGraphDef's tags.

    Returns:
      The meta_graph_def used for tflite conversion. If several MetaGraphDefs
      carry exactly `tag_set`, the last one in the SavedModel is returned.

    Raises:
      ValueError: No valid MetaGraphDef for given tag_set.
    """
    meta_graphs = reader.read_saved_model(saved_model_dir).meta_graphs
    tag_sets = [set(mgd.meta_info_def.tags) for mgd in meta_graphs]
    logging.info("The given saved_model contains the following tags: %s",
                 tag_sets)
    matching = [mgd for mgd, tags in zip(meta_graphs, tag_sets)
                if tags == tag_set]
    if not matching:
        raise ValueError(
            "No valid MetaGraphDef for this tag_set '{}'. Possible "
            "values are '{}'. ".format(tag_set, tag_sets))
    return matching[-1]
示例#6
0
    def _read_sessions(self, model_dirs, tag):
        """Load every exported model into its own graph and session.

        For each directory, the first MetaGraphDef carrying `tag` is located
        and its "probs" signature is recorded. The model is then loaded into a
        fresh graph owned by a new session, and the graph is finalized.

        Args:
            model_dirs: Model dirs saved by Estimator
            tag: Serving tag e.g. serve

        Raises:
            ValueError: if a directory has no MetaGraphDef carrying `tag`.
        """
        for export_dir in model_dirs:
            meta_graphs = reader.read_saved_model(export_dir).meta_graphs
            matched = next(
                (mgd for mgd in meta_graphs if tag in mgd.meta_info_def.tags),
                None)
            if matched is None:
                raise ValueError("Cannot find saved_model with tag: " + tag)
            self.signature_def_list.append(
                get_signature_def_by_key(matched, "probs"))

            session_config = tf.ConfigProto(
                gpu_options=tf.GPUOptions(
                    allow_growth=True,
                    visible_device_list=self.config.train.visible_device_list))
            model_graph = tf.Graph()
            self.graph_list.append(model_graph)
            model_session = tf.Session(graph=model_graph, config=session_config)
            self.sess_model_list.append(model_session)
            with model_session.as_default(), model_graph.as_default():
                loader.load(model_session, [tag], export_dir)
                model_graph.finalize()
示例#7
0
def load_meta_graph(model_path, tags, graph, signature_def_key=None):
    """Resolve the feed and fetch tensors of a SavedModel signature in `graph`.

    The SavedModel at `model_path` is only read to find the signature. The
    tensors themselves are looked up by name in `graph`, so the caller
    presumably has already loaded the model into `graph` (TODO confirm).

    Args:
      model_path: Directory containing the SavedModel.
      tags: Iterable of tags. The MetaGraphDef whose tags equal these
        (ignoring order) is used; if several match, the last one wins.
      graph: Graph in which the signature's tensor names are resolved.
      signature_def_key: Signature to use. Defaults to the default serving
        signature key. If that key is missing, the
        'default_input_alternative:<key>' form is tried as well.

    Returns:
      A (feed_tensors, fetch_tensors) pair of dicts mapping the signature's
      input/output keys to tensors in `graph`.

    Raises:
      ValueError: if no MetaGraphDef matches `tags`, or neither signature key
        exists in the matching MetaGraphDef.
    """
    saved_model = reader.read_saved_model(model_path)
    # Sort once up front: re-sorting `tags` per iteration would exhaust a
    # one-shot iterator after the first MetaGraphDef.
    wanted_tags = sorted(tags)
    the_meta_graph = None
    available_tags = []
    for meta_graph_def in saved_model.meta_graphs:
        meta_graph_tags = sorted(meta_graph_def.meta_info_def.tags)
        available_tags.append(meta_graph_tags)
        if meta_graph_tags == wanted_tags:
            the_meta_graph = meta_graph_def
    if the_meta_graph is None:
        # Fail here with a clear message instead of handing None to
        # get_signature_def_by_key.
        raise ValueError(
            'No MetaGraphDef with tags {} found in {}. Available tag sets '
            'are {}.'.format(wanted_tags, model_path, available_tags))

    signature_def_key = (
        signature_def_key or
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    try:
        signature_def = signature_def_utils.get_signature_def_by_key(
            the_meta_graph, signature_def_key)
    except ValueError as ex:
        # Fall back to the 'default_input_alternative:<key>' naming.
        formatted_key = 'default_input_alternative:{}'.format(
            signature_def_key)
        try:
            signature_def = signature_def_utils.get_signature_def_by_key(
                the_meta_graph, formatted_key)
        except ValueError:
            raise ValueError(
                'Got signature_def_key "{}". Available signatures are {}. '
                'Original error:\n{}'.format(
                    signature_def_key, list(the_meta_graph.signature_def), ex)
            )
    feed_tensors = {key: graph.get_tensor_by_name(info.name)
                    for key, info in signature_def.inputs.items()}
    fetch_tensors = {key: graph.get_tensor_by_name(info.name)
                     for key, info in signature_def.outputs.items()}
    return feed_tensors, fetch_tensors
示例#8
0
def get_serving_meta_graph_def(savedmodel_dir):
    """Extract the SERVING MetaGraphDef from a SavedModel directory.

    Args:
      savedmodel_dir: the string path to the directory containing the .pb
        and variables for a SavedModel. This is equivalent to the subdirectory
        that is created under the directory specified by --export_dir when
        running an Official Model.

    Returns:
      MetaGraphDef that should be used for tag_constants.SERVING mode, i.e.
      the first MetaGraphDef whose tag set is exactly {SERVING}.

    Raises:
      ValueError: if a MetaGraphDef matching tag_constants.SERVING is not found.
    """
    # We only care about the serving graph def. The tag sets must be equal, so
    # a MetaGraphDef tagged with SERVING plus other tags does not qualify.
    serving_tags = {tf.saved_model.tag_constants.SERVING}
    saved_model = reader.read_saved_model(savedmodel_dir)
    for meta_graph_def in saved_model.meta_graphs:
        if set(meta_graph_def.meta_info_def.tags) == serving_tags:
            return meta_graph_def
    raise ValueError(
        "No MetaGraphDef found for tag_constants.SERVING. "
        "Please make sure the SavedModel includes a SERVING def.")
def _get_meta_graph_def(saved_model_dir, tag_set):
  """Read a SavedModel and return the MetaGraphDef whose tags equal `tag_set`.

  Every tag-set present in the SavedModel is logged via logging.info first,
  which helps diagnose a wrong `tag_set`.

  Args:
    saved_model_dir: saved_model path to convert.
    tag_set: Set of tag(s) of the MetaGraphDef to load; compared with `==`
      against the set of each MetaGraphDef's tags.

  Returns:
    The meta_graph_def used for tflite conversion. If several MetaGraphDefs
    carry exactly `tag_set`, the last one in the SavedModel is returned.

  Raises:
    ValueError: No valid MetaGraphDef for given tag_set.
  """
  meta_graphs = reader.read_saved_model(saved_model_dir).meta_graphs
  tag_sets = [set(mgd.meta_info_def.tags) for mgd in meta_graphs]
  logging.info("The given saved_model contains the following tags: %s",
               tag_sets)
  matching = [mgd for mgd, tags in zip(meta_graphs, tag_sets)
              if tags == tag_set]
  if not matching:
    raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible "
                     "values are '{}'. ".format(tag_set, tag_sets))
  return matching[-1]
示例#10
0
def RunModel(saved_model_dir, signature_def_key, tag, text, ngrams_list=None):
    """Run one text through a SavedModel signature and return its output.

    Args:
      saved_model_dir: Directory containing the SavedModel.
      signature_def_key: Either "proba" (fetches the "scores" output) or
        "embedding" (fetches the "outputs" output).
      tag: A tag that the MetaGraphDef to use must carry; the first such
        MetaGraphDef is used.
      text: Input text, tokenized with text_utils.TokenizeText.
      ngrams_list: Optional n-gram options, parsed by
        text_utils.ParseNgramsOpts. When given, n-grams are added to the
        example.

    Returns:
      The evaluated output tensor for the single serialized example.

    Raises:
      ValueError: if no MetaGraphDef carries `tag`, or `signature_def_key` is
        not recognised.
    """
    saved_model = reader.read_saved_model(saved_model_dir)
    meta_graph = None
    for meta_graph_def in saved_model.meta_graphs:
        if tag in meta_graph_def.meta_info_def.tags:
            meta_graph = meta_graph_def
            break
    # Test the selected graph: the loop variable itself is never None, so
    # checking it would let a missing tag slip through.
    if meta_graph is None:
        raise ValueError("Cannot find saved_model with tag: " + tag)
    signature_def = signature_def_utils.get_signature_def_by_key(
        meta_graph, signature_def_key)
    tokens = text_utils.TokenizeText(text)
    ngrams = None
    if ngrams_list is not None:
        ngrams = text_utils.GenerateNgrams(
            tokens, text_utils.ParseNgramsOpts(ngrams_list))
    example = inputs.BuildTextExample(tokens, ngrams=ngrams)
    inputs_feed_dict = {
        signature_def.inputs["inputs"].name: [example.SerializeToString()],
    }
    # Which output tensor to fetch for each supported signature.
    output_keys = {"proba": "scores", "embedding": "outputs"}
    if signature_def_key not in output_keys:
        raise ValueError("Unrecognised signature_def %s" % (signature_def_key))
    output_tensor = signature_def.outputs[output_keys[signature_def_key]].name
    with tf.Session() as sess:
        loader.load(sess, [tag], saved_model_dir)
        return sess.run(output_tensor, feed_dict=inputs_feed_dict)
示例#11
0
def _get_meta_graph_def(saved_model_dir, tag_set):
    """Return the MetaGraphDef of a SavedModel matching a comma-separated tag-set.

    Args:
      saved_model_dir: Directory containing the SavedModel.
      tag_set: Tags of the MetaGraphDef to load, as a ',' separated string.
        All tags of the MetaGraphDef must be given. Whitespace around each tag
        is ignored and empty entries are dropped.

    Returns:
      The first MetaGraphDef whose tag set equals the requested one.

    Raises:
      RuntimeError: if no MetaGraphDef has exactly these tags. The message
        lists the tag-sets that do exist.
    """
    saved_model = reader.read_saved_model(saved_model_dir)
    # Normalise user input such as "serve, gpu" or "serve," before comparing.
    set_of_tags = {tag.strip() for tag in tag_set.split(',')} - {''}
    available_tag_sets = []
    for meta_graph_def in saved_model.meta_graphs:
        meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
        if meta_graph_tags == set_of_tags:
            return meta_graph_def
        available_tag_sets.append(','.join(sorted(meta_graph_tags)))

    raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
                       ' could not be found in SavedModel. Available '
                       'tag-sets: ' + str(available_tag_sets))
示例#12
0
def main(args):
    """Evaluate a SavedModel classifier on data/test-10.csv and print results.

    Each preprocessed row is fed to the model as one serialized tf.train.Example
    with its tokens under the "text" feature. The predicted label (argmax of
    the "scores" output) is printed next to the expected label, prefixed with
    "benar" (correct) or "salah" (wrong). Output labels are Indonesian:
    "harusnya" means "should be".

    Args:
      args: Unused; configuration is read from FLAGS.
    """
    bf = BaseFilter()
    tokenizer = Tokenizer()

    data_test = get_data("data/test-10.csv")
    data_test = preprocessing(bf, tokenizer, data_test)

    # Use a context manager so the label file handle is closed promptly.
    with tf.gfile.Open(FLAGS.label_file) as label_file:
        label = [v.strip() for v in label_file.readlines()]

    tag = FLAGS.tag
    output_key = "scores"

    saved_model = reader.read_saved_model(FLAGS.saved_model)

    meta_graph = None
    for meta_graph_def in saved_model.meta_graphs:
        if tag in meta_graph_def.meta_info_def.tags:
            meta_graph = meta_graph_def
            break

    if meta_graph is None:
        raise ValueError("Cannot find saved_model with tag: " + tag)

    signature_def = meta_graph.signature_def[FLAGS.signature_def]
    input_name = signature_def.inputs["inputs"].name
    print(input_name)

    output_tensor = signature_def.outputs[output_key].name

    with tf.Session() as sess:
        loader.load(sess, [tag], FLAGS.saved_model)

        for data in data_test:
            tokens = data.get("tokens")
            text = [tf.compat.as_bytes(x) for x in tokens]

            record = tf.train.Example()
            record.features.feature["text"].bytes_list.value.extend(text)

            outputs = sess.run(
                output_tensor,
                feed_dict={input_name: [record.SerializeToString()]})

            index = np.argmax(outputs)
            expected = data.get("label")
            print(" ".join(tokens))
            print("benar" if label[index] == expected else "salah",
                  "predict: ", label[index], "harusnya:", expected)
            print()
示例#13
0
def get_meta_graph_def(saved_model_dir, tag_set):
  """Utility function to read a meta_graph_def from disk.

  From https://github.com/tensorflow/tensorflow/blob/8e0e8d41a3a8f2d4a6100c2ea1dc9d6c6c4ad382/tensorflow/python/tools/saved_model_cli.py#L186

  Args:
    saved_model_dir: Directory containing the SavedModel.
    tag_set: Tags of the MetaGraphDef to load, as a ',' separated string.
      All tags of the MetaGraphDef must be given. Whitespace around each tag
      is ignored and empty entries are dropped.

  Returns:
    The first MetaGraphDef whose tag set equals the requested one.

  Raises:
    RuntimeError: if no MetaGraphDef has exactly these tags. The message lists
      the tag-sets that do exist.
  """
  saved_model = reader.read_saved_model(saved_model_dir)
  # Normalise user input such as "serve, gpu" or "serve," before comparing.
  set_of_tags = {tag.strip() for tag in tag_set.split(',')} - {''}
  available_tag_sets = []
  for meta_graph_def in saved_model.meta_graphs:
    meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
    if meta_graph_tags == set_of_tags:
      return meta_graph_def
    available_tag_sets.append(','.join(sorted(meta_graph_tags)))
  raise RuntimeError(
      "MetaGraphDef associated with tag-set {0} could not be found in "
      "SavedModel. Available tag-sets: {1}".format(tag_set, available_tag_sets))
示例#14
0
def scan(args):
    """Handle the `scan` command.

    When `args.tag_set` is set, only the MetaGraphDef selected by it is
    scanned. Otherwise every MetaGraphDef in the SavedModel at `args.dir` is
    scanned, in order.

    Args:
      args: A namespace parsed from command line.
    """
    if not args.tag_set:
        for meta_graph_def in reader.read_saved_model(args.dir).meta_graphs:
            scan_meta_graph_def(meta_graph_def)
        return
    selected = saved_model_utils.get_meta_graph_def(args.dir, args.tag_set)
    scan_meta_graph_def(selected)
示例#15
0
def scan(args):
  """Handle the `scan` command.

  When `args.tag_set` is set, only the MetaGraphDef selected by it is scanned.
  Otherwise every MetaGraphDef in the SavedModel at `args.dir` is scanned, in
  order.

  Args:
    args: A namespace parsed from command line.
  """
  if not args.tag_set:
    for meta_graph_def in reader.read_saved_model(args.dir).meta_graphs:
      scan_meta_graph_def(meta_graph_def)
    return
  selected = saved_model_utils.get_meta_graph_def(args.dir, args.tag_set)
  scan_meta_graph_def(selected)
  def testReadSavedModelValid(self):
    """A SavedModel saved under the TRAINING tag reads back with that tag.

    The model is built via _init_and_validate_variable(sess, "v", 42) and
    saved with exactly one MetaGraphDef. The test checks that
    reader.read_saved_model returns that single MetaGraphDef with its single
    tag.
    """
    saved_model_dir = os.path.join(test.get_temp_dir(), "valid_saved_model")
    builder = saved_model_builder.SavedModelBuilder(saved_model_dir)
    # `test_session` is deprecated in tf.test.TestCase; `session` replaces it.
    with self.session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
    builder.save()

    actual_saved_model_pb = reader.read_saved_model(saved_model_dir)
    self.assertEqual(len(actual_saved_model_pb.meta_graphs), 1)
    self.assertEqual(
        len(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags), 1)
    self.assertEqual(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags[0],
                     tag_constants.TRAINING)
示例#17
0
  def testReadSavedModelValid(self):
    """A SavedModel saved under the TRAINING tag reads back with that tag.

    The model is built via _init_and_validate_variable(sess, "v", 42) and saved
    with one MetaGraphDef. The read-back proto must contain exactly that
    MetaGraphDef, carrying exactly the TRAINING tag.
    """
    export_dir = os.path.join(test.get_temp_dir(), "valid_saved_model")
    builder = saved_model_builder.SavedModelBuilder(export_dir)
    with self.session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
    builder.save()

    meta_graphs = reader.read_saved_model(export_dir).meta_graphs
    self.assertEqual(len(meta_graphs), 1)
    only_tags = meta_graphs[0].meta_info_def.tags
    self.assertEqual(len(only_tags), 1)
    self.assertEqual(only_tags[0], tag_constants.TRAINING)
示例#18
0
def get_meta_graph_def(saved_model_dir, tags):
    """Gets `MetaGraphDef` from a directory containing a `SavedModel`.

    Returns the `MetaGraphDef` for the given tag-set and SavedModel directory.

    Args:
      saved_model_dir: Directory containing the SavedModel.
      tags: Comma separated list of tags used to identify the correct
        `MetaGraphDef`. Whitespace around each tag is ignored and empty
        entries are dropped.

    Raises:
      ValueError: An error when the given tags cannot be found. The message
        lists the tag-sets that do exist.

    Returns:
      A `MetaGraphDef` corresponding to the given tags.
    """
    saved_model = reader.read_saved_model(saved_model_dir)
    # Drop empty entries so inputs like "serve," still match {"serve"}.
    set_of_tags = {tag.strip() for tag in tags.split(',')} - {''}
    available_tag_sets = []
    for meta_graph_def in saved_model.meta_graphs:
        meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
        if meta_graph_tags == set_of_tags:
            return meta_graph_def
        available_tag_sets.append(','.join(sorted(meta_graph_tags)))
    raise ValueError(
        'Could not find MetaGraphDef with tags {}. Available tag-sets: '
        '{}'.format(tags, available_tag_sets))
示例#19
0
def get_meta_graph_def(saved_model_dir, tag_set):
  """Utility function to read a meta_graph_def from disk.

  From `saved_model_cli.py <https://github.com/tensorflow/tensorflow/blob/8e0e8d41a3a8f2d4a6100c2ea1dc9d6c6c4ad382/tensorflow/python/tools/saved_model_cli.py#L186>`_

  Args:
    :saved_model_dir: path to saved_model.
    :tag_set: comma-separated string of tags identifying the TensorFlow graph
      within the saved_model (e.g. "serve" or "serve,gpu"). Whitespace around
      each tag is ignored and empty entries are dropped.

  Returns:
    The first TensorFlow meta_graph_def whose tag set equals the requested one.

  Raises:
    RuntimeError: if no meta_graph_def has exactly these tags. The message lists
      the tag-sets that do exist.
  """
  saved_model = reader.read_saved_model(saved_model_dir)
  # Normalise user input such as "serve, gpu" or "serve," before comparing.
  set_of_tags = {tag.strip() for tag in tag_set.split(',')} - {''}
  available_tag_sets = []
  for meta_graph_def in saved_model.meta_graphs:
    meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
    if meta_graph_tags == set_of_tags:
      return meta_graph_def
    available_tag_sets.append(','.join(sorted(meta_graph_tags)))
  raise RuntimeError(
      "MetaGraphDef associated with tag-set {0} could not be found in "
      "SavedModel. Available tag-sets: {1}".format(tag_set, available_tag_sets))
示例#20
0
def RunModel(saved_model_dir, signature_def_key, feature_text, feature_map):
    """Run space-separated feature text through a SavedModel signature.

    Args:
      saved_model_dir: Directory containing the SavedModel.
      signature_def_key: Either "proba" (fetches the "scores" output) or
        "embedding" (fetches the "outputs" output).
      feature_text: Space-separated features, converted with
        pi.get_feature_list.
      feature_map: Mapping handed to pi.get_feature_list (its schema is not
        visible here; TODO confirm).

    Returns:
      The evaluated output tensor.

    Raises:
      ValueError: if no MetaGraphDef carries the module-level tag `_TAG`, or
        `signature_def_key` is not recognised.
    """
    saved_model = reader.read_saved_model(saved_model_dir)
    meta_graph = None
    for meta_graph_def in saved_model.meta_graphs:
        # `_TAG` is a single tag (it is passed to loader.load as [_TAG]).
        # Comparing the repeated `tags` field to it with `==` never matches,
        # so test membership instead.
        if _TAG in meta_graph_def.meta_info_def.tags:
            meta_graph = meta_graph_def
            break
    if meta_graph is None:
        raise ValueError("Cannot find saved_model with tag: {}".format(_TAG))
    signature_def = signature_def_utils.get_signature_def_by_key(
        meta_graph, signature_def_key)
    features = pi.get_feature_list(feature_map, feature_text.split(" "))
    inputs_feed_dict = {signature_def.inputs["inputs"].name: features}
    # Which output tensor to fetch for each supported signature.
    output_keys = {"proba": "scores", "embedding": "outputs"}
    if signature_def_key not in output_keys:
        raise ValueError("Unrecognised signature_def %s" % (signature_def_key))
    output_tensor = signature_def.outputs[output_keys[signature_def_key]].name
    with tf.Session() as sess:
        loader.load(sess, [_TAG], saved_model_dir)
        return sess.run(output_tensor, feed_dict=inputs_feed_dict)
def get_meta_graph_def(saved_model_dir, tags):
  """Gets `MetaGraphDef` from a directory containing a `SavedModel`.

  Returns the `MetaGraphDef` for the given tag-set and SavedModel directory.

  Args:
    saved_model_dir: Directory containing the SavedModel.
    tags: Comma separated list of tags used to identify the correct
      `MetaGraphDef`. Whitespace around each tag is ignored and empty entries
      are dropped.

  Raises:
    ValueError: An error when the given tags cannot be found. The message
      lists the tag-sets that do exist.

  Returns:
    A `MetaGraphDef` corresponding to the given tags.
  """
  saved_model = reader.read_saved_model(saved_model_dir)
  # Drop empty entries so inputs like "serve," still match {"serve"}.
  set_of_tags = {tag.strip() for tag in tags.split(',')} - {''}
  available_tag_sets = []
  for meta_graph_def in saved_model.meta_graphs:
    meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
    if meta_graph_tags == set_of_tags:
      return meta_graph_def
    available_tag_sets.append(','.join(sorted(meta_graph_tags)))
  raise ValueError(
      'Could not find MetaGraphDef with tags {}. Available tag-sets: '
      '{}'.format(tags, available_tag_sets))
示例#22
0
 def testReadSavedModelInvalid(self):
   """Reading a SavedModel from a non-existent directory raises IOError.

   The expected message embeds the temp-dir path, which may contain regex
   metacharacters (backslashes on Windows, dots, ...), so the whole message is
   escaped before being used as a search pattern.
   """
   import re  # local import: only needed to escape the path here

   saved_model_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model")
   # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
   with self.assertRaisesRegex(
       IOError,
       re.escape("SavedModel file does not exist at: %s" % saved_model_dir)):
     reader.read_saved_model(saved_model_dir)