  def _read_vars(self, model_dir):
    """Returns (global_step, latest_feature)."""
    with ops.Graph().as_default() as g:
      ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
      meta_filename = ckpt_path + '.meta'
      saver_lib.import_meta_graph(meta_filename)
      saver = saver_lib.Saver()
      with self.test_session(graph=g) as sess:
        saver.restore(sess, ckpt_path)
        return sess.run(ops.get_collection('my_vars'))
  def testMetaGraphSaveLoad(self):
    save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_graph = ops.Graph()
    with save_graph.as_default(), self.test_session(
        graph=save_graph) as session:
      partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
      with variable_scope.variable_scope("root", partitioner=partitioner):
        v0 = variable_scope.get_variable(
            "v0", dtype=dtypes.float32, shape=(10, 10))
        v0_list = v0._get_variable_list()
        v0_part = v0._get_partitions()
        self.assertEqual(len(v0_list), 5)
        self.assertAllEqual(v0_part, (5, 1))
        variables.global_variables_initializer().run()

        save_graph.get_collection_ref("partvar").append(v0)
        saver = saver_lib.Saver()
        save_graph.finalize()
        save_path = saver.save(sess=session, save_path=save_prefix)
        previous_value = session.run(
            save_graph.get_tensor_by_name(v0.name + ":0"))

    restore_graph = ops.Graph()
    with restore_graph.as_default(), self.test_session(
        graph=restore_graph) as session:
      saver = saver_lib.import_meta_graph(save_path + ".meta")
      saver.restore(sess=session, save_path=save_path)
      v0, = save_graph.get_collection_ref("partvar")
      self.assertIsInstance(v0, variables.PartitionedVariable)
      self.assertAllEqual(
          previous_value,
          session.run(restore_graph.get_tensor_by_name(v0.name + ":0")))
def graph_def_from_checkpoint(checkpoint_dir, output_node_names):
  """Converts checkpoint data to GraphDef.

  Reads the latest checkpoint data and produces a GraphDef in which the
  variables have been converted to constants.

  Args:
    checkpoint_dir: Path to the checkpoints.
    output_node_names: List of name strings for the result nodes of the graph.

  Returns:
    A GraphDef from the latest checkpoint.

  Raises:
    ValueError: If no checkpoint is found.
  """
  checkpoint_path = saver_lib.latest_checkpoint(checkpoint_dir)
  if checkpoint_path is None:
    raise ValueError('Could not find a checkpoint at: {0}.'
                     .format(checkpoint_dir))

  saver_for_restore = saver_lib.import_meta_graph(
      checkpoint_path + '.meta', clear_devices=True)
  with session.Session() as sess:
    saver_for_restore.restore(sess, checkpoint_path)
    graph_def = ops.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, output_node_names)

  return output_graph_def
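A minimal usage sketch for the helper above; the checkpoint directory, output file, and the 'softmax' node name are assumptions for illustration, not part of the original code.

from tensorflow.python.platform import gfile

# Hypothetical usage: paths and the output node name are placeholders.
frozen_graph_def = graph_def_from_checkpoint('/tmp/my_model_dir', ['softmax'])
with gfile.GFile('/tmp/frozen_model.pb', 'wb') as f:
  f.write(frozen_graph_def.SerializeToString())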
  def _ExportAndImportGraph(self, graph):
    """Export and import graph into a new graph."""
    meta_graph = saver_lib.export_meta_graph(
        graph=graph, collection_list=graph.get_all_collection_keys())
    graph_copy = ops.Graph()
    with graph_copy.as_default():
      _ = saver_lib.import_meta_graph(meta_graph)
    return graph_copy
  def _CopyGraph(self, graph):
    """Return a copy of graph."""
    meta_graph = saver_lib.export_meta_graph(
        graph=graph, collection_list=graph.get_all_collection_keys())
    graph_copy = ops.Graph()
    with graph_copy.as_default():
      _ = saver_lib.import_meta_graph(meta_graph)
    return graph_copy
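Both helpers above perform the same MetaGraphDef round trip. A self-contained sketch of the pattern outside a test class (the graph contents here are an assumption for illustration):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.training import saver as saver_lib

original = ops.Graph()
with original.as_default():
  constant_op.constant(1.0, name="c")  # any graph contents would do

# Round-trip through a MetaGraphDef to obtain an independent copy.
meta_graph = saver_lib.export_meta_graph(
    graph=original, collection_list=original.get_all_collection_keys())
graph_copy = ops.Graph()
with graph_copy.as_default():
  saver_lib.import_meta_graph(meta_graph)

# The copy contains the same operations as the original.
assert ([op.name for op in graph_copy.get_operations()] ==
        [op.name for op in original.get_operations()])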
Example #6
def load(sess, tags, export_dir):
  """Loads the model from a SavedModel as specified by tags.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  # Build the SavedModel protocol buffer and find the requested meta graph def.
  saved_model = _parse_saved_model(export_dir)
  found_match = False
  for meta_graph_def in saved_model.meta_graphs:
    if set(meta_graph_def.meta_info_def.tags) == set(tags):
      meta_graph_def_to_load = meta_graph_def
      found_match = True
      break

  if not found_match:
    raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip(
        "[]") + " could not be found in SavedModel")

  # Build a saver by importing the meta graph def to load.
  saver = tf_saver.import_meta_graph(meta_graph_def_to_load)

  # Build the checkpoint path where the variables are located.
  variables_path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.VARIABLES_DIRECTORY),
      compat.as_bytes(constants.VARIABLES_FILENAME))

  # Restore the variables using the built saver in the provided session.
  saver.restore(sess, variables_path)

  # Get asset tensors, if any.
  asset_tensors_dictionary = _get_asset_tensors(export_dir,
                                                meta_graph_def_to_load)

  main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
  if main_op_tensor is not None:
    sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
  else:
    legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
    if legacy_init_op_tensor is not None:
      sess.run(fetches=[legacy_init_op_tensor],
               feed_dict=asset_tensors_dictionary)

  return meta_graph_def_to_load
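A hedged usage sketch for load(); the export directory and the 'serve' tag are assumptions:

from tensorflow.python.client import session
from tensorflow.python.framework import ops

# Hypothetical: a SavedModel previously written to /tmp/saved_model with
# the 'serve' tag.
with session.Session(graph=ops.Graph()) as sess:
  meta_graph_def = load(sess, {"serve"}, "/tmp/saved_model")
  # The returned proto exposes signature-defs and collection-defs.
  print(list(meta_graph_def.signature_def.keys()))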
  def testMetagraph(self):
    with ops.Graph().as_default():
      with variable_scope.variable_scope("foo", use_resource=True):
        a = variable_scope.get_variable("a", initializer=10.0)

      momentum.MomentumOptimizer(
          learning_rate=0.001, momentum=0.1).minimize(
              a,
              colocate_gradients_with_ops=True,
              global_step=training_util.get_or_create_global_step())

      graph = ops.get_default_graph()
      meta_graph_def = saver.export_meta_graph(graph=graph)

    with ops.Graph().as_default():
      saver.import_meta_graph(meta_graph_def, import_scope="")
      meta_graph_two = saver.export_meta_graph(graph=graph)
    self.assertEqual(meta_graph_def, meta_graph_two)
Example #8
  def testGradientOfDeserializedCond(self):
    with ops.Graph().as_default():
      pred = array_ops.placeholder(dtypes.bool, name="pred")
      x = constant_op.constant(3.0, name="x")
      ops.add_to_collection("x", x)

      def true_fn():
        return math_ops.pow(x, 3)

      def false_fn():
        return x

      ops.add_to_collection("pred", pred)
      cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
      for c in cond:
        ops.add_to_collection("cond", c)
      meta_graph = saver.export_meta_graph()

    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        saver.import_meta_graph(meta_graph)
        x = ops.get_collection("x")[0]
        pred = ops.get_collection("pred")[0]
        cond = ops.get_collection("cond")
        cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad")
        cond_grad_grad = gradients_impl.gradients(
            cond_grad, [x], name="cond_grad_grad")
        # d[x^3]/dx = 3x^2
        true_val = sess.run(cond_grad, {pred: True})
        self.assertEqual(true_val, [27.0])
        # d[x]/dx = 1
        false_val = sess.run(cond_grad, {pred: False})
        self.assertEqual(false_val, [1.0])

        true_val = sess.run(cond_grad_grad, {pred: True})
        # d2[x^3]/dx2 = 6x
        self.assertEqual(true_val, [18.0])
        false_val = sess.run(cond_grad_grad, {pred: False})
        # d2[x]/dx2 = 0
        self.assertEqual(false_val, [0.0])
Example #9
  def _get_default_signature(self, export_meta_filename):
    """ Gets the default signature from the export.meta file. """
    with session.Session():
      save = saver.import_meta_graph(export_meta_filename)
      meta_graph_def = save.export_meta_graph()
      collection_def = meta_graph_def.collection_def

      signatures_any = collection_def['serving_signatures'].any_list.value
      self.assertEquals(len(signatures_any), 1)
      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      default_signature = signatures.default_signature
      return default_signature
Example #10
  def testNoVariables(self):
    test_dir = _TestDir("no_variables")
    filename = os.path.join(test_dir, "metafile")

    input_feed_value = -10  # Arbitrary input value for feed_dict.

    orig_graph = tf.Graph()
    with self.test_session(graph=orig_graph) as sess:
      # Create a minimal graph with zero variables.
      input_tensor = tf.placeholder(tf.float32, shape=[], name="input")
      offset = tf.constant(42, dtype=tf.float32, name="offset")
      output_tensor = tf.add(input_tensor, offset, name="add_offset")

      # Add input and output tensors to graph collections.
      tf.add_to_collection("input_tensor", input_tensor)
      tf.add_to_collection("output_tensor", output_tensor)

      output_value = sess.run(output_tensor, {input_tensor: input_feed_value})
      self.assertEqual(output_value, 32)

      # Generates MetaGraphDef.
      #
      # Note that this is calling the saver *module-level* export_meta_graph and
      # not the Saver.export_meta_graph instance-level method.
      meta_graph_def = saver_module.export_meta_graph(
          filename=filename,
          graph_def=tf.get_default_graph().as_graph_def(),
          collection_list=["input_tensor", "output_tensor"],
          saver_def=None,
      )

    # Create a clean graph and import the MetaGraphDef nodes.
    new_graph = tf.Graph()
    with self.test_session(graph=new_graph) as sess:
      # Import the previously exported meta graph.
      saver_instance = saver_module.import_meta_graph(filename)
      # The saver instance should be None since there are no graph variables
      # to be restored in this case.
      self.assertIsNone(saver_instance)

      # Re-exports the current graph state for comparison to the original.
      new_meta_graph_def = saver_module.export_meta_graph(filename + "_new")
      self.assertProtoEquals(meta_graph_def, new_meta_graph_def)

      # Ensures that we can still get a reference to our graph collections.
      new_input_tensor = tf.get_collection("input_tensor")[0]
      new_output_tensor = tf.get_collection("output_tensor")[0]
      # Verifies that the new graph computes the same result as the original.
      new_output_value = sess.run(
          new_output_tensor, {new_input_tensor: input_feed_value})
      self.assertEqual(new_output_value, output_value)
  def testRestoreFromMetaGraph(self):
    logdir = self._test_dir("restore_from_meta_graph")
    with ops.Graph().as_default():
      variables.VariableV1(1, name="v0")
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session("")
      filename = sv.saver.save(sess, sv.save_path)
      sv.stop()
    # Create a new Graph and Supervisor and recover.
    with ops.Graph().as_default():
      new_saver = saver_lib.import_meta_graph(".".join([filename, "meta"]))
      self.assertIsNotNone(new_saver)
      sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
      sess = sv2.prepare_or_wait_for_session("")
      self.assertEqual(1, sess.run("v0:0"))
      sv2.saver.save(sess, sv2.save_path)
      sv2.stop()
  def _testSaveRestoreUtility(self, start, break_range, stop):
    path = self._iterator_checkpoint_prefix()
    step = 0
    meta_filename = path + "-%d.meta" % step

    input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
        np.array([[12], [13], [14], [15]]), 4))
    to_concatenate_components = (np.tile(
        np.array([[5], [6], [7], [8], [9]]), 20), np.tile(
            np.array([[16], [17], [18], [19], [20]]), 15))

    with ops.Graph().as_default() as g:
      init_op, get_next = self._build_graph(input_components,
                                            to_concatenate_components)
      saver = saver_lib.Saver()
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        for i in range(start, break_range):
          result = sess.run(get_next)
          if i < 4:
            for component, result_component in zip(input_components, result):
              self.assertAllEqual(component[i], result_component)
          else:
            for component, result_component in zip(to_concatenate_components,
                                                   result):
              self.assertAllEqual(component[i - 4], result_component)
        saver.save(sess, path, step)

    with ops.Graph().as_default() as g:
      saver = saver_lib.import_meta_graph(meta_filename)
      with self.test_session(graph=g) as sess:
        get_next = nest.pack_sequence_as(("a", "b"),
                                         ops.get_collection("get_next"))
        saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
        for i in range(break_range, stop):
          result = sess.run(get_next)
          if i < 4:
            for component, result_component in zip(input_components, result):
              self.assertAllEqual(component[i], result_component)
          else:
            for component, result_component in zip(to_concatenate_components,
                                                   result):
              self.assertAllEqual(component[i - 4], result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)
  def testSaveRestoreUsingSaverFromMetaGraph(self):

    def _build_graph(start, stop):
      iterator = dataset_ops.Dataset.range(start,
                                           stop).make_initializable_iterator()
      init_op = iterator.initializer
      get_next = iterator.get_next()
      ops.add_to_collection("iterator_ops", init_op)
      ops.add_to_collection("iterator_ops", get_next)
      saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator)
      # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection
      # so that it can be automatically picked up by the Saver.
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
      saver = saver_lib.Saver()
      return init_op, get_next, saver

    start = 2
    stop = 10
    break_point = 5
    path = self._iterator_checkpoint_prefix()
    meta_filename = path + ".meta"

    # Execute input pipeline for a few steps and save iterator state.
    with ops.Graph().as_default() as g:
      init_op, get_next, saver = _build_graph(start, stop)
      with self.test_session(graph=g) as sess:
        sess.run(variables.global_variables_initializer())
        sess.run(init_op)
        for i in range(start, break_point):
          self.assertEqual(i, sess.run(get_next))
        saver.save(sess, path)

    # Build the saver from the MetaGraph using import_meta_graph and
    # check that the iterator state is restored.
    with ops.Graph().as_default() as g:
      saver = saver_lib.import_meta_graph(meta_filename)
      init_op, get_next = ops.get_collection("iterator_ops")
      with self.test_session(graph=g) as sess:
        saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
        for i in range(break_point, stop):
          self.assertEqual(i, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)
Example #14
  def doBasicsOneExportPath(self,
                            export_path,
                            clear_devices=False,
                            global_step=GLOBAL_STEP,
                            sharded=True,
                            export_count=1):
    # Build a graph with 2 parameter nodes on different devices.
    ops.reset_default_graph()
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      # v2 is an unsaved variable derived from v0 and v1.  It is used to
      # exercise the ability to run an init op when restoring a graph.
      with sess.graph.device("/cpu:0"):
        v0 = variables.VariableV1(10, name="v0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.VariableV1(20, name="v1")
      v2 = variables.VariableV1(1, name="v2", trainable=False, collections=[])
      assign_v2 = state_ops.assign(v2, math_ops.add(v0, v1))
      init_op = control_flow_ops.group(assign_v2, name="init_op")

      ops.add_to_collection("v", v0)
      ops.add_to_collection("v", v1)
      ops.add_to_collection("v", v2)

      named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
      signatures = {
          "foo":
              exporter.regression_signature(
                  input_tensor=v0, output_tensor=v1),
          "generic":
              exporter.generic_signature(named_tensor_bindings)
      }

      asset_filepath_orig = os.path.join(test.get_temp_dir(), "hello42.txt")
      asset_file = constant_op.constant(asset_filepath_orig, name="filename42")
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file)

      with gfile.FastGFile(asset_filepath_orig, "w") as f:
        f.write("your data here")
      assets_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)

      ignored_asset = os.path.join(test.get_temp_dir(), "ignored.txt")
      with gfile.FastGFile(ignored_asset, "w") as f:
        f.write("additional data here")

      variables.global_variables_initializer().run()

      # Run an export.
      save = saver.Saver(
          {
              "v0": v0,
              "v1": v1
          },
          restore_sequentially=True,
          sharded=sharded,
          write_version=saver_pb2.SaverDef.V1)
      export = exporter.Exporter(save)
      compare_def = ops.get_default_graph().as_graph_def()
      export.init(
          compare_def,
          init_op=init_op,
          clear_devices=clear_devices,
          default_graph_signature=exporter.classification_signature(
              input_tensor=v0),
          named_graph_signatures=signatures,
          assets_collection=assets_collection)

      for x in range(export_count):
        export.export(
            export_path,
            constant_op.constant(global_step + x),
            sess,
            exports_to_keep=gc.largest_export_versions(2))
      # Set global_step to the last exported version, as the rest of the test
      # uses it to construct the model export path, load the model from it,
      # and do verifications. We want to make sure to always use the last
      # exported version, as old ones may have been garbage-collected.
      global_step += export_count - 1

    # Restore graph.
    ops.reset_default_graph()
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      save = saver.import_meta_graph(
          os.path.join(export_path, constants.VERSION_FORMAT_SPECIFIER %
                       global_step, constants.META_GRAPH_DEF_FILENAME))
      self.assertIsNotNone(save)
      meta_graph_def = save.export_meta_graph()
      collection_def = meta_graph_def.collection_def

      # Validate custom graph_def.
      graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
      self.assertEquals(len(graph_def_any), 1)
      graph_def = graph_pb2.GraphDef()
      graph_def_any[0].Unpack(graph_def)
      if clear_devices:
        for node in compare_def.node:
          node.device = ""
      self.assertProtoEquals(compare_def, graph_def)

      # Validate init_op.
      init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
      self.assertEquals(len(init_ops), 1)
      self.assertEquals(init_ops[0], "init_op")

      # Validate signatures.
      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      self.assertEquals(len(signatures_any), 1)
      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      default_signature = signatures.default_signature
      self.assertEqual(
          default_signature.classification_signature.input.tensor_name, "v0:0")
      bindings = signatures.named_signatures["generic"].generic_signature.map
      self.assertEquals(bindings["logical_input_A"].tensor_name, "v0:0")
      self.assertEquals(bindings["logical_input_B"].tensor_name, "v1:0")
      read_foo_signature = (
          signatures.named_signatures["foo"].regression_signature)
      self.assertEquals(read_foo_signature.input.tensor_name, "v0:0")
      self.assertEquals(read_foo_signature.output.tensor_name, "v1:0")

      # Validate the assets.
      assets_any = collection_def[constants.ASSETS_KEY].any_list.value
      self.assertEquals(len(assets_any), 1)
      asset = manifest_pb2.AssetFile()
      assets_any[0].Unpack(asset)
      assets_path = os.path.join(export_path,
                                 constants.VERSION_FORMAT_SPECIFIER %
                                 global_step, constants.ASSETS_DIRECTORY,
                                 "hello42.txt")
      asset_contents = gfile.GFile(assets_path).read()
      self.assertEqual(asset_contents, "your data here")
      self.assertEquals("hello42.txt", asset.filename)
      self.assertEquals("filename42:0", asset.tensor_binding.tensor_name)
      ignored_asset_path = os.path.join(export_path,
                                        constants.VERSION_FORMAT_SPECIFIER %
                                        global_step, constants.ASSETS_DIRECTORY,
                                        "ignored.txt")
      self.assertFalse(gfile.Exists(ignored_asset_path))

      # Validate graph restoration.
      if sharded:
        save.restore(sess,
                     os.path.join(export_path,
                                  constants.VERSION_FORMAT_SPECIFIER %
                                  global_step,
                                  constants.VARIABLES_FILENAME_PATTERN))
      else:
        save.restore(sess,
                     os.path.join(export_path,
                                  constants.VERSION_FORMAT_SPECIFIER %
                                  global_step, constants.VARIABLES_FILENAME))
      self.assertEqual(10, ops.get_collection("v")[0].eval())
      self.assertEqual(20, ops.get_collection("v")[1].eval())
      ops.get_collection(constants.INIT_OP_KEY)[0].run()
      self.assertEqual(30, ops.get_collection("v")[2].eval())
Example #15
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).

  Returns:
    Location of the output_graph_def.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is
      # heuristic-based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        # Models that have been frozen previously do not contain Variables.
        elif _has_no_variables(sess):
          print("No variables were found in this model. It is likely the model "
                "was frozen previously. You cannot freeze a graph twice.")
          return 0
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
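A sketch of how this function might be invoked; the file paths and node name are assumptions, and the two "Unused" arguments can be passed as empty strings:

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile

# Hypothetical inputs: model.pbtxt, model.ckpt-1000 and the 'softmax' node
# are placeholders for illustration.
input_graph_def = graph_pb2.GraphDef()
with gfile.GFile('/tmp/model.pbtxt') as f:
  text_format.Merge(f.read(), input_graph_def)

frozen = freeze_graph_with_def_protos(
    input_graph_def=input_graph_def,
    input_saver_def=None,
    input_checkpoint='/tmp/model.ckpt-1000',
    output_node_names='softmax',
    restore_op_name='',        # unused
    filename_tensor_name='',   # unused
    output_graph='/tmp/frozen.pb',
    clear_devices=True,
    initializer_nodes='')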
Example #16
    def testDeployCheckpoint(self):
        input_meta_name = "original_meta.meta"
        input_meta_path = os.path.join(self.get_temp_dir(), input_meta_name)
        q_config, _ = self._compose_config()
        with ops.Graph().as_default():
            self._build_graph(is_freezed=False)
            graph_def = ops.get_default_graph().as_graph_def()
            saver_lib.export_meta_graph(filename=input_meta_path)

        original_meta_graph_def = MetaGraphDef()
        original_meta_graph_def = self._parse_def_from_file(
            original_meta_graph_def, input_meta_path)
        decent_q.quantize_train(original_meta_graph_def, q_config)

        quant_train_meta_graph_def = MetaGraphDef()
        quant_train_meta_graph_path = os.path.join(
            self.get_temp_dir(), "quantize_train/quantize_train.ckpt.meta")
        quant_train_meta_graph_def = self._parse_def_from_file(
            quant_train_meta_graph_def, quant_train_meta_graph_path)
        with ops.Graph().as_default():
            new_saver = saver_lib.import_meta_graph(quant_train_meta_graph_def)
            with session.Session() as sess:
                w_t = sess.graph.get_tensor_by_name("w/read/wquant:0")
                b_t = sess.graph.get_tensor_by_name("b/read/wquant:0")
                relu_t = sess.graph.get_tensor_by_name("relu/aquant:0")
                input_fn = self._mock_input_fn("input:0", [1, 4, 4, 3])
                init = variables.global_variables_initializer()
                sess.run(init)
                eval_relu, eval_w, eval_b = sess.run([relu_t, w_t, b_t],
                                                     feed_dict=input_fn(1))

                checkpoint_prefix = os.path.join(self.get_temp_dir(),
                                                 "ckpt/saved_checkpoint")
                checkpoint_state_name = "checkpoint_state"
                checkpoint_path = new_saver.save(
                    sess,
                    checkpoint_prefix,
                    global_step=0,
                    latest_filename=checkpoint_state_name)
        q_config.output_nodes = ["relu/aquant"]
        decent_q.quantize_evaluate(quant_train_meta_graph_def, q_config)
        quant_eval_meta_graph_def = MetaGraphDef()
        quant_eval_meta_graph_path = os.path.join(
            self.get_temp_dir(), "quantize_eval/quantize_eval.ckpt.meta")
        quant_eval_meta_graph_def = self._parse_def_from_file(
            quant_eval_meta_graph_def, quant_eval_meta_graph_path)
        sess.close()
        decent_q.deploy_checkpoint(quant_eval_meta_graph_def, checkpoint_path,
                                   q_config)
        deploy_graph_def = graph_pb2.GraphDef()
        deploy_graph_path = os.path.join(self.get_temp_dir(),
                                         "deploy/deploy_model.pb")
        deploy_graph_def = self._parse_def_from_file(deploy_graph_def,
                                                     deploy_graph_path)
        for node in deploy_graph_def.node:
            if node.name == "conv2d":
                # need to equal with quantize pos in quantize_eval_model.pb
                self.assertAllEqual(node.attr['ipos'].list.i, [8, 6])
                self.assertAllEqual(node.attr['wpos'].list.i, [8, 7])
                self.assertAllEqual(node.attr['bpos'].list.i, [8, 8])
                self.assertAllEqual(node.attr['opos'].list.i, [8, 4])
                deploy_w = tensor_util.MakeNdarray(node.attr['weights'].tensor)
                deploy_b = tensor_util.MakeNdarray(node.attr['bias'].tensor)
                self.assertNDArrayNear(deploy_w, eval_w, 1e-6)
                self.assertNDArrayNear(deploy_b, eval_b, 1e-6)
  def _import_meta_graph(self):
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)
Example #18
def load_session_bundle_from_path(export_dir,
                                  target="",
                                  config=None,
                                  meta_graph_def=None):
    """Load session bundle from the given path.

    The function reads input from the export_dir, constructs the graph data
    into the default graph, and restores the parameters for the session it
    creates.

    Args:
      export_dir: the directory that contains files exported by exporter.
      target: The execution engine to connect to. See target in
        tf.compat.v1.Session().
      config: A ConfigProto proto with configuration options. See config in
        tf.compat.v1.Session().
      meta_graph_def: optional object of type MetaGraphDef. If this object is
        present, then it is used instead of parsing MetaGraphDef from
        export_dir.

    Returns:
      session: a tensorflow session created from the variable files.
      meta_graph: a meta graph proto saved in the exporter directory.

    Raises:
      RuntimeError: if the required files are missing or contain
        unrecognizable fields, i.e. the exported model is invalid.
    """
    if not meta_graph_def:
        meta_graph_filename = os.path.join(export_dir,
                                           constants.META_GRAPH_DEF_FILENAME)
        if not file_io.file_exists(meta_graph_filename):
            raise RuntimeError("Expected meta graph file missing %s" %
                               meta_graph_filename)
        # Reads meta graph file.
        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        meta_graph_def.ParseFromString(
            file_io.read_file_to_string(meta_graph_filename, binary_mode=True))

    variables_filename = ""
    variables_filename_list = []
    checkpoint_sharded = False

    variables_index_filename = os.path.join(
        export_dir, constants.VARIABLES_INDEX_FILENAME_V2)
    checkpoint_v2 = file_io.file_exists(variables_index_filename)

    # Find matching checkpoint files.
    if checkpoint_v2:
        # The checkpoint is in v2 format.
        variables_filename_pattern = os.path.join(
            export_dir, constants.VARIABLES_FILENAME_PATTERN_V2)
        variables_filename_list = file_io.get_matching_files(
            variables_filename_pattern)
        checkpoint_sharded = True
    else:
        variables_filename = os.path.join(export_dir,
                                          constants.VARIABLES_FILENAME)
        if file_io.file_exists(variables_filename):
            variables_filename_list = [variables_filename]
        else:
            variables_filename = os.path.join(
                export_dir, constants.VARIABLES_FILENAME_PATTERN)
            variables_filename_list = file_io.get_matching_files(
                variables_filename)
            checkpoint_sharded = True

    # Prepare the files to restore a session.
    if not variables_filename_list:
        restore_files = ""
    elif checkpoint_v2 or not checkpoint_sharded:
        # For checkpoint v2 or v1 with non-sharded files, use "export" to restore
        # the session.
        restore_files = constants.VARIABLES_FILENAME
    else:
        restore_files = constants.VARIABLES_FILENAME_PATTERN

    assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

    collection_def = meta_graph_def.collection_def
    graph_def = graph_pb2.GraphDef()
    if constants.GRAPH_KEY in collection_def:
        # Use serving graph_def in MetaGraphDef collection_def if exists
        graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
        if len(graph_def_any) != 1:
            raise RuntimeError(
                "Expected exactly one serving GraphDef in : %s" %
                meta_graph_def)
        else:
            graph_def_any[0].Unpack(graph_def)
            # Replace the graph def in meta graph proto.
            meta_graph_def.graph_def.CopyFrom(graph_def)

    ops.reset_default_graph()
    sess = session.Session(target, graph=None, config=config)
    # Import the graph.
    saver = saver_lib.import_meta_graph(meta_graph_def)
    # Restore the session.
    if restore_files:
        saver.restore(sess, os.path.join(export_dir, restore_files))

    init_op_tensor = None
    if constants.INIT_OP_KEY in collection_def:
        init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
        if len(init_ops) != 1:
            raise RuntimeError("Expected exactly one serving init op in : %s" %
                               meta_graph_def)
        init_op_tensor = ops.get_collection(constants.INIT_OP_KEY)[0]

    # Create asset input tensor list.
    asset_tensor_dict = {}
    if constants.ASSETS_KEY in collection_def:
        assets_any = collection_def[constants.ASSETS_KEY].any_list.value
        for asset in assets_any:
            asset_pb = manifest_pb2.AssetFile()
            asset.Unpack(asset_pb)
            asset_tensor_dict[
                asset_pb.tensor_binding.tensor_name] = os.path.join(
                    assets_dir, asset_pb.filename)

    if init_op_tensor:
        # Run the init op.
        sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)

    return sess, meta_graph_def
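A short usage sketch, assuming a session-bundle export at a hypothetical path:

# Hypothetical usage; the export directory is an assumption.
sess, meta_graph = load_session_bundle_from_path('/tmp/exported_model')
try:
  # Collections restored by import_meta_graph are now available.
  print(list(meta_graph.collection_def.keys()))
finally:
  sess.close()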
Example #19
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(saver_def=input_saver_def,
                              write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list,
                              write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))

    variable_names_whitelist = (variable_names_whitelist.split(",")
                                if variable_names_whitelist else None)
    variable_names_blacklist = (variable_names_blacklist.split(",")
                                if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
Example #20
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not saver_lib.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.replace(
            " ", "").split(",") if variable_names_blacklist else None)

        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
Example #21
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.
  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional),
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)
  Returns:
    Location of the output_graph_def.
  """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not checkpoint_management.checkpoint_exists(input_checkpoint)):
        raise ValueError("Input checkpoint '" + input_checkpoint +
                         "' doesn't exist!")

    if not output_node_names:
        raise ValueError(
            "You need to supply the name of a node to --output_node_names.")

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is
            # heuristic-based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    raise ValueError(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                # Models that have been frozen previously do not contain Variables.
                elif _has_no_variables(sess):
                    raise ValueError(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice."
                    )
                else:
                    raise e

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.replace(
            " ", "").split(",") if variable_names_blacklist else None)

        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
Example #22
def importer():
  saver_lib.import_meta_graph(save_prefix + '.meta')
  return ops.get_default_graph().as_graph_element('output:0')
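On its own this closure is hard to read; below is a hedged sketch of the kind of context it presupposes. The graph contents, the 'output' tensor, and the temp-dir prefix are assumptions for illustration.

import os
import tempfile

from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import saver as saver_lib

# Hypothetical context: build a graph with an 'output' tensor, save it so a
# '.meta' file exists next to the checkpoint, then rebuild it via importer().
save_prefix = os.path.join(tempfile.mkdtemp(), 'ckpt')

with ops.Graph().as_default():
  v = variables.VariableV1(1.0, name='v')
  math_ops.add(v, 1.0, name='output')
  with session.Session() as sess:
    sess.run(v.initializer)
    saver_lib.Saver().save(sess, save_prefix)  # also writes 'ckpt.meta'

with ops.Graph().as_default():
  output = importer()  # re-imports the meta graph and returns 'output:0'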
Example #23
ops.reset_default_graph()
sess = tf.compat.v1.InteractiveSession()
# Read the dictionary mapping integer class ids to vocabulary words.
with open(voc_file, 'r') as dict_file:
    dict_list = dict_file.read().splitlines()
int2word = dict()
for word in dict_list:
    word_idx = len(int2word)
    int2word[word_idx] = word

######### THIS SECTION RESTORES THE IMAGE RECOGNITION MODEL FROM Semantic-Model/semantic_model.meta #########
# Restore weights
tf.compat.v1.disable_eager_execution()
saver = saver_lib.import_meta_graph(model)
saver.restore(sess, model[:-5])
#######################################

graph = tf.compat.v1.get_default_graph()

input = graph.get_tensor_by_name("model_input:0")
seq_len = graph.get_tensor_by_name("seq_lengths:0")
rnn_keep_prob = graph.get_tensor_by_name("keep_prob:0")
height_tensor = graph.get_tensor_by_name("input_height:0")
width_reduction_tensor = graph.get_tensor_by_name("width_reduction:0")
logits = tf.compat.v1.get_collection("logits")[0]
# Constants that are saved inside the model itself
WIDTH_REDUCTION, HEIGHT = sess.run([width_reduction_tensor, height_tensor])

decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)
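The script stops after building the greedy decoder. A hedged sketch of how a prediction step might use it, mapping the decoded label indices back through int2word (the input preprocessing is an assumption, not part of the original):

import numpy as np

def predict(image):
    # Hypothetical: 'image' is a float32 np.ndarray of shape [HEIGHT, width],
    # already normalized the way the model expects.
    image = image[np.newaxis, :, :, np.newaxis]  # add batch and channel dims
    seq_lengths = [image.shape[2] // WIDTH_REDUCTION]
    prediction = sess.run(decoded, feed_dict={
        input: image,
        seq_len: seq_lengths,
        rnn_keep_prob: 1.0,
    })
    # decoded[0] is a SparseTensor of label indices; map each to its word.
    return [int2word[idx] for idx in prediction[0].values]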
Example #24
 def _import_meta_graph(self):
     meta_file_path = self._ckpt_path() + ".meta"
     return saver_lib.import_meta_graph(meta_file_path)
Example #25
def load(sess, tags, export_dir, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  with sess.graph.as_default():
    # Build the SavedModel protocol buffer and find requested meta graph def.
    saved_model = _parse_saved_model(export_dir)
    found_match = False
    for meta_graph_def in saved_model.meta_graphs:
      if set(meta_graph_def.meta_info_def.tags) == set(tags):
        meta_graph_def_to_load = meta_graph_def
        found_match = True
        break

    if not found_match:
      raise RuntimeError(
          "MetaGraphDef associated with tags " + str(tags).strip("[]") +
          " could not be found in SavedModel. To inspect available tag-sets in"
          " the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
      )

    # Build a saver by importing the meta graph def to load.
    saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)

    if saver:
      # Build the checkpoint path where the variables are located.
      variables_path = os.path.join(
          compat.as_bytes(export_dir),
          compat.as_bytes(constants.VARIABLES_DIRECTORY),
          compat.as_bytes(constants.VARIABLES_FILENAME))

      # Restore the variables using the built saver in the provided session.
      saver.restore(sess, variables_path)
    else:
      tf_logging.info("The specified SavedModel has no variables; no "
                      "checkpoints were restored.")

    # Get asset tensors, if any.
    asset_tensors_dictionary = _get_asset_tensors(export_dir,
                                                  meta_graph_def_to_load)

    main_op_tensor = (
        _get_main_op_tensor(meta_graph_def_to_load) or
        (_get_legacy_init_op_tensor(meta_graph_def_to_load)))
    if main_op_tensor is not None:
      sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)

    return meta_graph_def_to_load
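
A hypothetical call site for this loader, assuming the SavedModel was
exported with the tag 'serve'; the export directory is a placeholder:

# Restore a SavedModel into a fresh session and list its signature keys.
with session.Session(graph=ops.Graph()) as sess:
  meta_graph_def = load(sess, tags={"serve"}, export_dir="/tmp/my_export")
  for key in meta_graph_def.signature_def:
    print("signature:", key)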
Example #26
def partially_apply_saved_transform(saved_model_dir,
                                    logical_input_map,
                                    tensor_value_map=None):
    """Apply a transform graph, represented as a SavedModel, to existing Tensors.

  This adds nodes to a graph that already contains Tensors representing the
  inputs.  These input Tensors may be placeholders that will be fed when the
  graph is executed, or may be the outputs of some Ops.  Most typically, the
  input Tensors are reading and/or parsing Ops, but they could be anything--
  including the outputs of a prior application of this function using another
  transform graph.

  This function operates on the default Graph in the default Session, and so
  must be called within a context where these are provided.

  Args:
    saved_model_dir: A SavedModel directory providing a transform
      graph.  The MetaGraphDef and signature are selected from the SavedModel
      using keys defined in `../constants.py` ('transform' and
      'transform_signature', respectively).
    logical_input_map: a dict of logical name to Tensor.  The logical names must
      be a subset of those in the input signature of the transform graph, and
      the corresponding Tensors must have the expected types and shapes.
    tensor_value_map: a dict of tensor names to values.

  Returns:
    A pair of (unbound_inputs, outputs) where unbound_inputs is a dict of
    logical name to Tensors that are yet to be mapped or fed, and outputs is
    a dict of logical name to Tensor, as provided by the output signature
    of the transform graph

  Raises:
    ValueError: if the provided input_tensors dict has keys that are not part
      of the input signature, or any of the provided inputs have the wrong
      type or shape.
    RuntimeError: if there is no default graph available to which to apply the
      transform.
  """
    decomposed_input_tensors = _decompose_sparse_tensors(logical_input_map)

    meta_graph_def, input_signature, output_signature = (
        _load_transform_saved_model(saved_model_dir))

    # Check for inputs that were not part of the input signature.
    unexpected_inputs = (set(six.iterkeys(decomposed_input_tensors)) -
                         set(six.iterkeys(input_signature)))
    if unexpected_inputs:
        raise ValueError('Unexpected inputs '
                         'to transform: {}'.format(unexpected_inputs))

    # Create a map from tensor names in the graph to be imported, to the tensors
    # specified in `input_tensors`.
    input_map = {
        input_signature[decomposed_logical_name]:
        decomposed_input_tensors[decomposed_logical_name]
        for decomposed_logical_name in decomposed_input_tensors
    }
    if tensor_value_map:
        input_map.update({
            name: tf.constant(value)
            for name, value in six.iteritems(tensor_value_map)
        })

    graph = tf.get_default_graph()
    if graph is None:
        raise RuntimeError('apply_saved_transform() requires a default graph.')

    # unique_name may produce e.g. transform_5.  The result has no trailing slash.
    scope = graph.unique_name('transform', mark_as_used=False)

    # Load the transform graph, applying it to existing Tensors via input_map.
    # Throws ValueError if the input_map gives mismatched types or shapes.
    saver = tf_saver.import_meta_graph(meta_graph_def,
                                       import_scope=scope,
                                       input_map=input_map)
    if saver:
        tf.logging.warn(
            'Transform graphs should not have saved Variables, but this '
            'one does.  Variable values will *not* be restored.')

    # Add computed output tensors to the output.  There are two cases.  When the
    # output is not in the input_map, then we look up the tensor in the imported
    # graph by prepending the import scope and looking up the tensor by name.
    # This will fail if the expected output tensor is not now in the graph
    # under the expected name scope.  When the output is in the input map, then
    # that tensor will have been re-mapped so we use the tensor given in the
    # input_map.
    def lookup_remapped_tensor(tensor_name):
        if tensor_name in input_map:
            return input_map[tensor_name]
        else:
            return graph.get_tensor_by_name(
                ops.prepend_name_scope(tensor_name, scope))

    decomposed_output_tensors = {
        decomposed_logical_name: lookup_remapped_tensor(tensor_name)
        for decomposed_logical_name, tensor_name in six.iteritems(
            output_signature)
    }
    # Do the same for input tensors, where we assume such tensors are not in the
    # input_map since identical tensors in an input_map would be an error.
    decomposed_unbound_input_tensors = {
        decomposed_logical_name:
        graph.get_tensor_by_name(ops.prepend_name_scope(tensor_name, scope))
        for decomposed_logical_name, tensor_name in six.iteritems(
            input_signature)
        if decomposed_logical_name not in decomposed_input_tensors
    }

    outputs = _recompose_sparse_tensors(decomposed_output_tensors)
    unbound_inputs = _recompose_sparse_tensors(
        decomposed_unbound_input_tensors)
    return unbound_inputs, outputs
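
A hedged usage sketch, e.g. from inside an input_fn: the transform
directory, the logical name 'x', and the placeholder dtype/shape are all
assumptions:

# Hypothetical application of a saved transform to a fed placeholder.
with tf.Graph().as_default():
    features = {'x': tf.placeholder(tf.float32, [None])}
    unbound_inputs, outputs = partially_apply_saved_transform(
        '/tmp/transform_fn_dir', features)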
Example #27
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is
      # heuristic-based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
Example #28
  def doBasicsOneExportPath(self,
                            export_path,
                            clear_devices=False,
                            global_step=GLOBAL_STEP,
                            sharded=True,
                            export_count=1):
    # Build a graph with 2 parameter nodes on different devices.
    ops.reset_default_graph()
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      # v2 is an unsaved variable derived from v0 and v1.  It is used to
      # exercise the ability to run an init op when restoring a graph.
      with sess.graph.device("/cpu:0"):
        v0 = variables.VariableV1(10, name="v0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.VariableV1(20, name="v1")
      v2 = variables.VariableV1(1, name="v2", trainable=False, collections=[])
      assign_v2 = state_ops.assign(v2, math_ops.add(v0, v1))
      init_op = control_flow_ops.group(assign_v2, name="init_op")

      ops.add_to_collection("v", v0)
      ops.add_to_collection("v", v1)
      ops.add_to_collection("v", v2)

      named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
      signatures = {
          "foo":
              exporter.regression_signature(
                  input_tensor=v0, output_tensor=v1),
          "generic":
              exporter.generic_signature(named_tensor_bindings)
      }

      asset_filepath_orig = os.path.join(test.get_temp_dir(), "hello42.txt")
      asset_file = constant_op.constant(asset_filepath_orig, name="filename42")
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file)

      with gfile.GFile(asset_filepath_orig, "w") as f:
        f.write("your data here")
      assets_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)

      ignored_asset = os.path.join(test.get_temp_dir(), "ignored.txt")
      with gfile.GFile(ignored_asset, "w") as f:
        f.write("additional data here")

      variables.global_variables_initializer().run()

      # Run an export.
      save = saver.Saver(
          {
              "v0": v0,
              "v1": v1
          },
          restore_sequentially=True,
          sharded=sharded,
          write_version=saver_pb2.SaverDef.V1)
      export = exporter.Exporter(save)
      compare_def = ops.get_default_graph().as_graph_def()
      export.init(
          compare_def,
          init_op=init_op,
          clear_devices=clear_devices,
          default_graph_signature=exporter.classification_signature(
              input_tensor=v0),
          named_graph_signatures=signatures,
          assets_collection=assets_collection)

      for x in range(export_count):
        export.export(
            export_path,
            constant_op.constant(global_step + x),
            sess,
            exports_to_keep=gc.largest_export_versions(2))
      # Set global_step to the last exported version, as the rest of the test
      # uses it to construct model export path, loads model from it, and does
      # verifications. We want to make sure to always use the last exported
      # version, as old ones may have been garbage-collected.
      global_step += export_count - 1

    # Restore graph.
    ops.reset_default_graph()
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      save = saver.import_meta_graph(
          os.path.join(export_path, constants.VERSION_FORMAT_SPECIFIER %
                       global_step, constants.META_GRAPH_DEF_FILENAME))
      self.assertIsNotNone(save)
      meta_graph_def = save.export_meta_graph()
      collection_def = meta_graph_def.collection_def

      # Validate custom graph_def.
      graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
      self.assertEqual(len(graph_def_any), 1)
      graph_def = graph_pb2.GraphDef()
      graph_def_any[0].Unpack(graph_def)
      if clear_devices:
        for node in compare_def.node:
          node.device = ""
      self.assertProtoEquals(compare_def, graph_def)

      # Validate init_op.
      init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
      self.assertEqual(len(init_ops), 1)
      self.assertEqual(init_ops[0], "init_op")

      # Validate signatures.
      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      self.assertEqual(len(signatures_any), 1)
      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      default_signature = signatures.default_signature
      self.assertEqual(
          default_signature.classification_signature.input.tensor_name, "v0:0")
      bindings = signatures.named_signatures["generic"].generic_signature.map
      self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
      self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
      read_foo_signature = (
          signatures.named_signatures["foo"].regression_signature)
      self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
      self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

      # Validate the assets.
      assets_any = collection_def[constants.ASSETS_KEY].any_list.value
      self.assertEqual(len(assets_any), 1)
      asset = manifest_pb2.AssetFile()
      assets_any[0].Unpack(asset)
      assets_path = os.path.join(export_path,
                                 constants.VERSION_FORMAT_SPECIFIER %
                                 global_step, constants.ASSETS_DIRECTORY,
                                 "hello42.txt")
      asset_contents = gfile.GFile(assets_path).read()
      self.assertEqual(asset_contents, "your data here")
      self.assertEquals("hello42.txt", asset.filename)
      self.assertEquals("filename42:0", asset.tensor_binding.tensor_name)
      ignored_asset_path = os.path.join(export_path,
                                        constants.VERSION_FORMAT_SPECIFIER %
                                        global_step, constants.ASSETS_DIRECTORY,
                                        "ignored.txt")
      self.assertFalse(gfile.Exists(ignored_asset_path))

      # Validate graph restoration.
      if sharded:
        save.restore(sess,
                     os.path.join(export_path,
                                  constants.VERSION_FORMAT_SPECIFIER %
                                  global_step,
                                  constants.VARIABLES_FILENAME_PATTERN))
      else:
        save.restore(sess,
                     os.path.join(export_path,
                                  constants.VERSION_FORMAT_SPECIFIER %
                                  global_step, constants.VARIABLES_FILENAME))
      self.assertEqual(10, ops.get_collection("v")[0].eval())
      self.assertEqual(20, ops.get_collection("v")[1].eval())
      ops.get_collection(constants.INIT_OP_KEY)[0].run()
      self.assertEqual(30, ops.get_collection("v")[2].eval())
Example #29
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is
            # heuristic-based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    print(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                    return -1
                else:
                    raise e

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.replace(
            " ", "").split(",") if variable_names_blacklist else None)

        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
Example #30
def load_session_bundle_from_path(export_dir,
                                  target="",
                                  config=None,
                                  meta_graph_def=None):
  """Load session bundle from the given path.

  The function reads input from the export_dir, imports the graph data into
  the default graph, and restores the parameters into the session it creates.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in
      tf.compat.v1.Session()
    config: A ConfigProto proto with configuration options. See config in
      tf.compat.v1.Session()
    meta_graph_def: optional object of type MetaGraphDef. If this object is
      present, then it is used instead of parsing MetaGraphDef from export_dir.

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
    fields, i.e. the exported model is invalid.
  """
  if not meta_graph_def:
    meta_graph_filename = os.path.join(export_dir,
                                       constants.META_GRAPH_DEF_FILENAME)
    if not file_io.file_exists(meta_graph_filename):
      raise RuntimeError("Expected meta graph file missing %s" %
                         meta_graph_filename)
    # Reads meta graph file.
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    meta_graph_def.ParseFromString(
        file_io.read_file_to_string(meta_graph_filename, binary_mode=True))

  variables_filename = ""
  variables_filename_list = []
  checkpoint_sharded = False

  variables_index_filename = os.path.join(export_dir,
                                          constants.VARIABLES_INDEX_FILENAME_V2)
  checkpoint_v2 = file_io.file_exists(variables_index_filename)

  # Find matching checkpoint files.
  if checkpoint_v2:
    # The checkpoint is in v2 format.
    variables_filename_pattern = os.path.join(
        export_dir, constants.VARIABLES_FILENAME_PATTERN_V2)
    variables_filename_list = file_io.get_matching_files(
        variables_filename_pattern)
    checkpoint_sharded = True
  else:
    variables_filename = os.path.join(export_dir, constants.VARIABLES_FILENAME)
    if file_io.file_exists(variables_filename):
      variables_filename_list = [variables_filename]
    else:
      variables_filename = os.path.join(export_dir,
                                        constants.VARIABLES_FILENAME_PATTERN)
      variables_filename_list = file_io.get_matching_files(variables_filename)
      checkpoint_sharded = True

  # Prepare the files to restore a session.
  if not variables_filename_list:
    restore_files = ""
  elif checkpoint_v2 or not checkpoint_sharded:
    # For checkpoint v2 or v1 with non-sharded files, use "export" to restore
    # the session.
    restore_files = constants.VARIABLES_FILENAME
  else:
    restore_files = constants.VARIABLES_FILENAME_PATTERN

  assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

  collection_def = meta_graph_def.collection_def
  graph_def = graph_pb2.GraphDef()
  if constants.GRAPH_KEY in collection_def:
    # Use serving graph_def in MetaGraphDef collection_def if exists
    graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
    if len(graph_def_any) != 1:
      raise RuntimeError("Expected exactly one serving GraphDef in : %s" %
                         meta_graph_def)
    else:
      graph_def_any[0].Unpack(graph_def)
      # Replace the graph def in meta graph proto.
      meta_graph_def.graph_def.CopyFrom(graph_def)

  ops.reset_default_graph()
  sess = session.Session(target, graph=None, config=config)
  # Import the graph.
  saver = saver_lib.import_meta_graph(meta_graph_def)
  # Restore the session.
  if restore_files:
    saver.restore(sess, os.path.join(export_dir, restore_files))

  init_op_tensor = None
  if constants.INIT_OP_KEY in collection_def:
    init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
    if len(init_ops) != 1:
      raise RuntimeError("Expected exactly one serving init op in : %s" %
                         meta_graph_def)
    init_op_tensor = ops.get_collection(constants.INIT_OP_KEY)[0]

  # Create asset input tensor list.
  asset_tensor_dict = {}
  if constants.ASSETS_KEY in collection_def:
    assets_any = collection_def[constants.ASSETS_KEY].any_list.value
    for asset in assets_any:
      asset_pb = manifest_pb2.AssetFile()
      asset.Unpack(asset_pb)
      asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = os.path.join(
          assets_dir, asset_pb.filename)

  if init_op_tensor:
    # Run the init op.
    sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)

  return sess, meta_graph_def
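
A hypothetical loading sketch for a session-bundle export; the directory is
a placeholder:

# Load the bundle, inspect what the meta graph carried, then clean up.
sess, meta_graph = load_session_bundle_from_path("/tmp/session_bundle_export")
print(sorted(meta_graph.collection_def.keys()))
sess.close()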
Example #31
 def importer():
   saver_lib.import_meta_graph(save_prefix + '.meta')
   return ops.get_default_graph().as_graph_element('output:0')
Example #32
def _partially_apply_saved_transform_impl(saved_model_dir,
                                          logical_input_map,
                                          tensor_replacement_map=None,
                                          fetch_tensor_names=None):
    """Shared code for partially_apply_saved_transform and fetch_tensor_values.

  This adds nodes to a graph that already contains Tensors representing the
  inputs.  These input Tensors may be placeholders that will be fed when the
  graph is executed, or may be the outputs of some Ops.  Most typically, the
  input Tensors are reading and/or parsing Ops, but they could be anything--
  including the outputs of a prior application of this function using another
  transform graph.

  This function operates on the default Graph in the default Session, and so
  must be called within a context where these are provided.

  Args:
    saved_model_dir: A SavedModel directory providing a transform
      graph.  The MetaGraphDef and signature are selected from the SavedModel
      using keys defined in `../constants.py` ('transform' and
      'transform_signature', respectively).
    logical_input_map: a dict of logical name to Tensor.  The logical names must
      be a subset of those in the input signature of the transform graph, and
      the corresponding Tensors must have the expected types and shapes.
    tensor_replacement_map: a dict of tensor names to `Tensors`.
    fetch_tensor_names: a list of tensor names.

  Returns:
    A tuple of (unbound_inputs, outputs, fetched_tensors) where unbound_inputs
    is a dict of logical name to Tensors that are yet to be mapped or fed,
    outputs is a dict of logical name to Tensor, as provided by the output
    signature of the transform graph, and fetched_tensors is a dict of tensor
    names to `Tensor`s where the tensor names are the names given by
    `fetched_tensor_names`.

  Raises:
    ValueError: if the provided input_tensors dict has keys that are not part
      of the input signature, or any of the provided inputs have the wrong
      type or shape.
    RuntimeError: if there is no default graph available to which to apply the
      transform.
  """
    graph = tf.get_default_graph()
    if graph is None:
        raise RuntimeError('apply_saved_transform() requires a default graph.')

    decomposed_input_tensors = _decompose_sparse_tensors(logical_input_map)

    meta_graph_def, input_signature, output_signature, asset_path_dict = (
        _load_transform_saved_model(saved_model_dir))
    asset_tensor_dict = {
        k: ops.convert_to_tensor(v)
        for k, v in asset_path_dict.items()
    }

    # Check for inputs that were not part of the input signature.
    unexpected_inputs = (set(six.iterkeys(decomposed_input_tensors)) -
                         set(six.iterkeys(input_signature)))
    if unexpected_inputs:
        raise ValueError('Unexpected inputs '
                         'to transform: {}'.format(unexpected_inputs))

    # Create a map from tensor names in the graph to be imported, to the tensors
    # specified in `input_tensors`.
    input_map = {
        input_signature[decomposed_logical_name]:
        decomposed_input_tensors[decomposed_logical_name]
        for decomposed_logical_name in decomposed_input_tensors
    }
    input_map.update(asset_tensor_dict)
    if tensor_replacement_map:
        input_map.update(tensor_replacement_map)

    # unique_name may produce e.g. transform_5.  The result has no trailing slash.
    scope = graph.unique_name('transform', mark_as_used=False)

    # unique_name returns an "absolute" name while we want a name relative to the
    # current scope.  Therefore, we check if the current name stack is non-empty,
    # and if so, strip out the existing name scope.
    if graph.get_name_scope():
        current_name_scope = graph.get_name_scope() + '/'
        assert scope.startswith(current_name_scope)
        import_scope = scope[len(current_name_scope):]
    else:
        import_scope = scope

    # Save the ASSET_FILEPATHS before importing the MetaGraphDef
    current_assets = graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)

    # Load the transform graph, applying it to existing Tensors via input_map.
    # Throws ValueError if the input_map gives mismatched types or shapes.
    saver = tf_saver.import_meta_graph(meta_graph_def,
                                       import_scope=import_scope,
                                       input_map=input_map)

    # Wipe out AssetFileDef collection; it is obsolete after loading
    graph.clear_collection(tf.saved_model.constants.ASSETS_KEY)

    # The import may have added Tensors to the ASSET_FILEPATHS collection that
    # were substituted via input_map.  To account for this, wipe out the
    # collection, restore the preexisting collection values, and then write in
    # the new substituted Tensors.
    graph.clear_collection(ops.GraphKeys.ASSET_FILEPATHS)
    for asset_path_tensor in current_assets:
        graph.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS,
                                asset_path_tensor)
    for asset_path_tensor in asset_tensor_dict.values():
        graph.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS,
                                asset_path_tensor)

    if saver:
        tf.logging.warn(
            'Transform graphs should not have saved Variables, but this '
            'one does.  Variable values will *not* be restored.')

    # Add computed output tensors to the output.  There are two cases.  When the
    # output is not in the input_map, then we look up the tensor in the imported
    # graph by prepending the import scope and looking up the tensor by name.
    # This will fail if the expected output tensor is not now in the graph
    # under the expected name scope.  When the output is in the input map, then
    # that tensor will have been re-mapped so we use the tensor given in the
    # input_map.
    def lookup_remapped_tensor(tensor_name):
        if tensor_name in input_map:
            return input_map[tensor_name]
        else:
            return graph.get_tensor_by_name(
                ops.prepend_name_scope(tensor_name, scope))

    decomposed_output_tensors = {
        decomposed_logical_name: lookup_remapped_tensor(tensor_name)
        for decomposed_logical_name, tensor_name in six.iteritems(
            output_signature)
    }
    # Do the same for input tensors, where we assume such tensors are not in the
    # input_map since identical tensors in an input_map would be an error.
    decomposed_unbound_input_tensors = {
        decomposed_logical_name:
        graph.get_tensor_by_name(ops.prepend_name_scope(tensor_name, scope))
        for decomposed_logical_name, tensor_name in six.iteritems(
            input_signature)
        if decomposed_logical_name not in decomposed_input_tensors
    }
    if fetch_tensor_names is None:
        fetch_tensor_names = []
    fetched_tensors = {
        name: lookup_remapped_tensor(name)
        for name in fetch_tensor_names
    }

    outputs = _recompose_sparse_tensors(decomposed_output_tensors)
    unbound_inputs = _recompose_sparse_tensors(
        decomposed_unbound_input_tensors)
    return unbound_inputs, outputs, fetched_tensors
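
The fetch-enabled variant above can also pull internal tensors of the
transform graph out by name; a hedged sketch where the directory, logical
input, and fetched tensor name are placeholders:

# Hypothetical: apply the transform and also fetch an internal tensor.
with tf.Graph().as_default():
    features = {'x': tf.placeholder(tf.float32, [None])}
    unbound, outputs, fetched = _partially_apply_saved_transform_impl(
        '/tmp/transform_fn_dir', features,
        fetch_tensor_names=['some_internal_tensor:0'])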
            params_ori += params_layer
            logging.info(
                'Flops: {:}, Params: {:}, Input Shape: {:}, Output Shape: {:}, Kernel Shape: {:} of layer: {:}'
                .format(flops_layer, params_layer, op.inputs[0].shape,
                        op.outputs[0].shape, op.inputs[1].shape, op.name))

        elif op.type in ['MaxPool']:
            pool_size = 3
            flops_layer = op.outputs[0].shape[1] * op.outputs[0].shape[2] * \
                          op.outputs[0].shape[3] * pool_size * pool_size
            flops_ori += flops_layer
            logging.info('Flops: {:} of layer: {:}'.format(
                flops_layer, op.name))
    logging.info('Total flops: {:}, and total params: {:}'.format(
        flops_ori, params_ori))


if __name__ == '__main__':
    logging.info('Graph stat starts!')
    meta_graph = saver.import_meta_graph(args.model + '.meta',
                                         clear_devices=True)
    graph = ops.get_default_graph()

    for op in graph.get_operations():
        if op.type == 'FusedBatchNorm' and 'tower' not in op.name:
            logging.info('Node: {}, shape: {}'.format(
                op.name, op.outputs[0].shape.as_list()))

    flops_stat(graph, logging)
    logging.info('Graph stat ends!')
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   target_cpu,
                                   variables_to_feed=(),
                                   multithreading=False):
    """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files.  Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  If `variables_to_feed` is left at its default (an empty tuple), all
  variables are embedded as constants into the graph and binary objects.
  Otherwise, the listed variable values become inputs and outputs of the
  compiled class and the C++ caller must set these values manually.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    target_cpu: String, LLVM target cpu name.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
    multithreading: Whether to enable multithreading in the compiled
      computation.  Note that if using this option, the resulting object files
      may have external dependencies on multithreading libraries like nsync.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
    if _pywrap_tfcompile_import_error:
        raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

    else:
        # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
        # so that we can set these directly instead of relying on env vars.
        xla_flags = os.environ.get('XLA_FLAGS')
        if not xla_flags:
            xla_flags = '--xla_cpu_multi_thread_eigen={}'.format(
                'true' if multithreading else 'false')
        else:
            xla_flags += ' --xla_cpu_multi_thread_eigen={}'.format(
                'true' if multithreading else 'false')
        os.environ['XLA_FLAGS'] = xla_flags

    signature_def_map = meta_graph_def.signature_def
    if signature_def_key not in signature_def_map:
        raise ValueError(
            'Unable to find signature_def key \'{}\' in signature def map.  '
            'Available keys: {}'.format(signature_def_key,
                                        list(signature_def_map.keys())))
    signature_def = signature_def_map[signature_def_key]
    if not signature_def.outputs:
        raise ValueError(
            'Signature key {} must have outputs, but saw none:\n{}'.format(
                signature_def_key, str(signature_def)))

    temp_dir = test.get_temp_dir()
    file_io.recursive_create_dir(temp_dir)
    if logging.get_verbosity() >= logging.INFO:
        original_graph_def_location = os.path.join(temp_dir,
                                                   'original_graph.pb')
        with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
            graph_writer.write(meta_graph_def.graph_def.SerializeToString())

    # This updates graph_def in place.
    _replace_input_placeholders_with_default_values(meta_graph_def.graph_def,
                                                    signature_def)

    graph_def = _optimize_graph(meta_graph_def, signature_def)

    all_variables = _get_variable_nodes_from_graph_def(graph_def)
    if variables_to_feed is None:
        variable_nodes_to_feed = list(all_variables.values())
    else:
        not_in_graph = set(variables_to_feed).difference(list(all_variables))
        if not_in_graph:
            raise ValueError(
                'Asked to feed variables that were not found in graph: {}.  '
                'Variables contained in the graph: {}'.format(
                    not_in_graph, list(all_variables)))
        variable_nodes_to_feed = [
            all_variables[name] for name in variables_to_feed
        ]

    if logging.get_verbosity() >= logging.INFO:
        prefrozen_graph_def_location = os.path.join(temp_dir,
                                                    'prefrozen_graph.pb')
        with file_io.FileIO(prefrozen_graph_def_location,
                            'wb') as graph_writer:
            graph_writer.write(graph_def.SerializeToString())

    # Load the Variables so that we can freeze the graph.
    with session.Session(graph=ops_lib.Graph()) as sess:
        restorer = saver_lib.import_meta_graph(meta_graph_def,
                                               clear_devices=True)
        if restorer is not None:
            restorer.restore(sess, checkpoint_path)
        graph_def.CopyFrom(
            graph_util.convert_variables_to_constants(
                sess,
                graph_def,
                output_node_names=[
                    _parse_tensor_name(n.name)[0]
                    for n in signature_def.outputs.values()
                ],
                variable_names_blacklist=[
                    n.name for n, _ in variable_nodes_to_feed
                ],
            ))

    signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

    frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
    config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
    logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
    with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
        graph_writer.write(graph_def.SerializeToString())
    config = _signature_to_tf2xla_config(
        signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
    logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
    with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
        config_writer.write(str(config))

    output_dir = os.path.dirname(output_prefix)
    file_io.recursive_create_dir(output_dir)

    entry_point = re.sub('[^0-9a-zA-Z]+', '_',
                         '__xla_' + output_prefix + '__' + cpp_class)

    logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

    makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
    with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
        makefile_writer.write(_xla_makefile_string(output_prefix))

    output_prefix = _shlex_quote(output_prefix)

    _pywrap_tfcompile.Compile(
        graph=frozen_graph_def_location,
        config=config_pbtxt_location,
        cpp_class=cpp_class,
        target_triple=target_triple,
        target_cpu=target_cpu,
        entry_point=entry_point,
        out_function_object='{}.o'.format(output_prefix),
        out_header='{}.h'.format(output_prefix),
        out_metadata_object='{}_metadata.o'.format(output_prefix),
        gen_name_to_index=True,
        # ProgramShape isn't uniquefied by entry_point.
        gen_program_shape=False)
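
A hypothetical invocation, assuming a SavedModel on disk and a TensorFlow
build with XLA; `parse_saved_model` is used here only to obtain a
`MetaGraphDef`, and all paths, keys, and target flags are placeholders:

from tensorflow.python.saved_model import loader_impl

# Hypothetical AOT compile of a SavedModel's first meta graph.
saved_model = loader_impl.parse_saved_model('/tmp/my_saved_model')
aot_compile_cpu_meta_graph_def(
    checkpoint_path='/tmp/my_saved_model/variables/variables',
    meta_graph_def=saved_model.meta_graphs[0],
    output_prefix='/tmp/aot/my_model',
    signature_def_key='serving_default',
    cpp_class='MyModel',
    target_triple='x86_64-pc-linux',
    target_cpu='')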
def freeze_model(checkpoint_path: str,
                 meta_graph_def: meta_graph_pb2.MetaGraphDef,
                 output_prefix: str, signature_def_key: str,
                 variables_to_feed: List[str]) -> Tuple[str, str]:
    """Freeze a `MetaGraphDef` in preparation for tfcompile`.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
  Returns:
    a pair containing the path to the frozen model and the path to the config.
  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
    if _pywrap_tfcompile_import_error:
        raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

    signature_def_map = meta_graph_def.signature_def
    if signature_def_key not in signature_def_map:
        raise ValueError(
            f"Unable to find signature_def_key '{signature_def_key}' in signature "
            'def map of `meta_graph_def`. Available keys: '
            f'{list(signature_def_map.keys())}')
    signature_def = signature_def_map[signature_def_key]
    if not signature_def.outputs:
        raise ValueError(
            f'Signature key {signature_def_key} must have outputs, but saw none:\n'
            f'{str(signature_def)}')

    file_io.recursive_create_dir(output_prefix)
    if logging.get_verbosity() >= logging.INFO:
        original_graph_def_location = os.path.join(output_prefix,
                                                   'original_graph.pb')
        with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
            graph_writer.write(meta_graph_def.graph_def.SerializeToString())

    # This updates graph_def in place.
    _replace_input_placeholders_with_default_values(meta_graph_def.graph_def,
                                                    signature_def)

    graph_def = _optimize_graph(meta_graph_def, signature_def)

    all_variables = _get_variable_nodes_from_graph_def(graph_def)
    if variables_to_feed is None:
        variable_nodes_to_feed = list(all_variables.values())
    else:
        not_in_graph = set(variables_to_feed).difference(list(all_variables))
        if not_in_graph:
            raise ValueError(
                'Asked to feed variables that were not found in graph: '
                f'{not_in_graph}. Variables contained in the graph: '
                f'{list(all_variables)}')
        variable_nodes_to_feed = [
            all_variables[name] for name in variables_to_feed
        ]

    if logging.get_verbosity() >= logging.INFO:
        prefrozen_graph_def_location = os.path.join(output_prefix,
                                                    'prefrozen_graph.pb')
        with file_io.FileIO(prefrozen_graph_def_location,
                            'wb') as graph_writer:
            graph_writer.write(graph_def.SerializeToString())

    # Load the Variables so that we can freeze the graph.
    with session.Session(graph=ops_lib.Graph()) as sess:
        restorer = saver_lib.import_meta_graph(meta_graph_def,
                                               clear_devices=True)
        if restorer is not None:
            restorer.restore(sess, checkpoint_path)
        graph_def.CopyFrom(
            graph_util.convert_variables_to_constants(
                sess,
                graph_def,
                output_node_names=[
                    _parse_tensor_name(n.name)[0]
                    for n in signature_def.outputs.values()
                ],
                variable_names_blacklist=[
                    n.name for n, _ in variable_nodes_to_feed
                ],
            ))

    signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

    frozen_graph_def_location = os.path.join(output_prefix, 'frozen_graph.pb')
    config_pbtxt_location = os.path.join(output_prefix, 'config.pbtxt')
    logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
    with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
        graph_writer.write(graph_def.SerializeToString())
    config = _signature_to_tf2xla_config(
        signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
    logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
    with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
        config_writer.write(str(config))
    return frozen_graph_def_location, config_pbtxt_location
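
A hypothetical freeze ahead of `tfcompile`, under the same placeholder
assumptions as the previous sketch:

from tensorflow.python.saved_model import loader_impl

# Hypothetical: freeze the first meta graph of a SavedModel for tfcompile.
saved_model = loader_impl.parse_saved_model('/tmp/my_saved_model')
frozen_path, config_path = freeze_model(
    checkpoint_path='/tmp/my_saved_model/variables/variables',
    meta_graph_def=saved_model.meta_graphs[0],
    output_prefix='/tmp/aot_out',
    signature_def_key='serving_default',
    variables_to_feed=[])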
def _partially_apply_saved_transform_impl(saved_model_dir,
                                          logical_input_map,
                                          tensor_replacement_map=None):
    """Shared code for partially_apply_saved_transform and fetch_tensor_values.

  This adds nodes to a graph that already contains Tensors representing the
  inputs.  These input Tensors may be placeholders that will be fed when the
  graph is executed, or may be the outputs of some Ops.  Most typically, the
  input Tensors are reading and/or parsing Ops, but they could be anything--
  including the outputs of a prior application of this function using another
  transform graph.

  This function operates on the default Graph in the default Session, and so
  must be called within a context where these are provided.

  Args:
    saved_model_dir: A SavedModel directory providing a transform
      graph.  The MetaGraphDef and signature are selected from the SavedModel
      using keys defined in `../constants.py` ('transform' and
      'transform_signature', respectively).
    logical_input_map: a dict of logical name to Tensor.  The logical names must
      be a subset of those in the input signature of the transform graph, and
      the corresponding Tensors must have the expected types and shapes.
    tensor_replacement_map: a dict of tensor names to `Tensors`.

  Returns:
    A tuple of (unbound_inputs, outputs, assets_dict) where
      * unbound_inputs is a dict of logical name to Tensors that are yet to be
        mapped or fed
      * outputs is a dict of logical name to Tensor, as provided by the output
        signature of the transform graph

  Raises:
    ValueError: if the provided input_tensors dict has keys that are not part
      of the input signature, or any of the provided inputs have the wrong
      type or shape.
    RuntimeError: if there is no default graph available to which to apply the
      transform.
  """
    _maybe_register_addon_ops()
    graph = tf.compat.v1.get_default_graph()
    if graph is None:
        raise RuntimeError('apply_saved_transform() requires a default graph.')

    meta_graph_def, input_signature, output_signature, asset_path_dict = (
        _load_transform_saved_model(saved_model_dir))
    asset_tensor_dict = {
        k: tf.convert_to_tensor(v)
        for k, v in asset_path_dict.items()
    }

    # Check for inputs that were not part of the input signature.
    unexpected_inputs = (set(six.iterkeys(logical_input_map)) -
                         set(six.iterkeys(input_signature)))
    if unexpected_inputs:
        raise ValueError('Unexpected inputs '
                         'to transform: {}'.format(unexpected_inputs))

    # Create a map from tensor names in the graph to be imported, to the tensors
    # specified in `input_tensors`.
    input_map = _expand_input_map(logical_input_map, input_signature)

    input_map.update(asset_tensor_dict)
    if tensor_replacement_map:
        input_map.update(tensor_replacement_map)

    # unique_name may produce e.g. transform_5.  The result has no trailing slash.
    scope = graph.unique_name('transform', mark_as_used=False)

    # unique_name returns an "absolute" name while we want a name relative to the
    # current scope.  Therefore, we check if the current name stack is non-empty,
    # and if so, strip out the existing name scope.
    if graph.get_name_scope():
        current_name_scope = graph.get_name_scope() + '/'
        assert scope.startswith(current_name_scope)
        import_scope = scope[len(current_name_scope):]
    else:
        import_scope = scope

    # If the saved_model contained py_funcs, reinsert them into the graph
    # here and update their associated tokens in the model.
    pyfunc_helper.register_pyfuncs_from_saved_transform(graph, meta_graph_def)

    # Save the ASSET_FILEPATHS before importing the MetaGraphDef
    current_assets = graph.get_collection(
        tf.compat.v1.GraphKeys.ASSET_FILEPATHS)

    # Warn user if meta_graph_def has saved variables
    if tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
        trainable_vars = meta_graph_def.collection_def[
            tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES].bytes_list.value
        if trainable_vars:
            raise ValueError(
                'The SavedModel contained trainable variables {}.  Because this '
                'function is typically called in the input_fn, trainable variables '
                'are disallowed'.format(trainable_vars))

    # Load the transform graph, applying it to existing Tensors via input_map.
    # Throws ValueError if the input_map gives mismatched types or shapes.
    saver = tf_saver.import_meta_graph(meta_graph_def,
                                       import_scope=import_scope,
                                       input_map=input_map)
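  # Note: import_meta_graph with an input_map rebinds the named tensors of the
  # imported graph to the caller's pre-existing Tensors, so the transform's
  # inputs and asset paths are wired into this graph rather than recreated.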

  # Wipe out the AssetFileDef collection; it is obsolete after loading.
  graph.clear_collection(tf.saved_model.ASSETS_KEY)

  # The import may have added Tensors to the ASSET_FILEPATHS collection that
  # were substituted via input_map.  To account for this, wipe out the
  # collection, restore the preexisting collection values, and then write in
  # the new substituted Tensors.
  graph.clear_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
  for asset_path_tensor in current_assets:
    graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS,
                            asset_path_tensor)
  for asset_path_tensor in asset_tensor_dict.values():
    graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS,
                            asset_path_tensor)

  if saver:
    checkpoint_path = os.path.join(
        tf.compat.as_bytes(saved_model_dir),
        tf.compat.as_bytes(tf.saved_model.VARIABLES_DIRECTORY),
        tf.compat.as_bytes(tf.saved_model.VARIABLES_FILENAME))
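    # i.e. b'<saved_model_dir>/variables/variables', the standard SavedModel
    # checkpoint prefix.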

    # We can't use the scope rename from init_from_checkpoint because it
    # relies on var scopes not rebuilt by import_meta_graph.  So we need to
    # construct the mapping explicitly by iterating over the variables.
    # TODO(b/78624684): remove this workaround.
    var_map = {}
    for var in tf.compat.v1.global_variables():
      var_name = var.op.name
      if not var_name.startswith(scope + '/'):
        continue

      # Recover the original name, i.e. the name before importing into scope.
      original_var_name = var_name[len(scope) + 1:]

      match = _PARTITIONED_VARIABLE_NAME_RE.match(original_var_name)
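      # For illustration, assuming the usual '<name>/part_<k>' convention for
      # partitioned variables: 'dense/kernel/part_2' yields base name
      # 'dense/kernel' and partition index 2.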
      if match:
        # If the variable is partitioned, extract the base variable name and
        # the index in the partition, then update var_map[base_name] so that
        # var_map[base_name][partition_index] = var.
        base_name = match.group(1)
        partition_index = int(match.group(2))
        if base_name not in var_map:
          var_map[base_name] = []
        while len(var_map[base_name]) <= partition_index:
          var_map[base_name].append(None)
        assert var_map[base_name][partition_index] is None
        var_map[base_name][partition_index] = var
      else:
        var_map[original_var_name] = var

    if var_map:
      tf.compat.v1.train.init_from_checkpoint(checkpoint_path, var_map)

  # Add computed output tensors to the output.  There are two cases.  When
  # the output is not in the input_map, we look up the tensor in the imported
  # graph by prepending the import scope and looking up the tensor by name.
  # This will fail if the expected output tensor is no longer in the graph
  # under the expected name scope.  When the output is in the input_map, that
  # tensor will have been re-mapped, so we use the tensor given in the
  # input_map.
  def lookup_remapped_tensor(tensor_name):
    if tensor_name in input_map:
      return input_map[tensor_name]
    else:
      return graph.get_tensor_by_name(
          ops.prepend_name_scope(tensor_name, scope))

  def lookup_tensor_or_sparse_or_composite_tensor(tensor_info):
    """Returns the remapped tensor corresponding to TensorInfo."""
    encoding = tensor_info.WhichOneof('encoding')
    if encoding == 'coo_sparse':
      return tf.SparseTensor(
          lookup_remapped_tensor(tensor_info.coo_sparse.indices_tensor_name),
          lookup_remapped_tensor(tensor_info.coo_sparse.values_tensor_name),
          lookup_remapped_tensor(
              tensor_info.coo_sparse.dense_shape_tensor_name))
    elif encoding == 'composite_tensor':
      components = [
          lookup_remapped_tensor(info.name)
          for info in tensor_info.composite_tensor.components
      ]
      struct_coder = nested_structure_coder.StructureCoder()
      spec_proto = struct_pb2.StructuredValue(
          type_spec_value=tensor_info.composite_tensor.type_spec)
      spec = struct_coder.decode_proto(spec_proto)
      return spec._from_components(components)  # pylint: disable=protected-access
    elif encoding == 'name':
      return lookup_remapped_tensor(tensor_info.name)
    else:
      raise ValueError('Unsupported TensorInfo encoding %s' % encoding)

  outputs = {
      logical_name: lookup_tensor_or_sparse_or_composite_tensor(tensor_info)
      for logical_name, tensor_info in six.iteritems(output_signature)
  }
  # Do the same for input tensors, although such tensors should never be in
  # the input_map since identical tensors in an input_map would be an error.
  unbound_inputs = {
      logical_name: lookup_tensor_or_sparse_or_composite_tensor(tensor_info)
      for logical_name, tensor_info in six.iteritems(input_signature)
      if logical_name not in logical_input_map
  }

  return unbound_inputs, outputs
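

# Example use of the function above (a minimal sketch; the entry-point name
# `apply_saved_transform` and the paths are assumptions for illustration):
#
#   with tf.compat.v1.Graph().as_default():
#     raw_x = tf.compat.v1.placeholder(tf.float32, shape=[None])
#     unbound_inputs, outputs = apply_saved_transform(
#         '/tmp/transform_fn', {'x': raw_x})
#     # `outputs` maps logical output names to transformed Tensors; inputs of
#     # the transform graph that were not fed remain in `unbound_inputs`.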
def _restore_and_dump_variables(input_checkpoint):
    """Restores checkpoint variables present in the default graph and dumps
    them to 'model.txt' via save_variables_1.  (The reader-loop prelude and
    function boundary are reconstructed assumptions.)"""
    with session.Session() as sess:
        var_list = {}
        # Assumes: from tensorflow.python import pywrap_tensorflow
        reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
        for key in reader.get_variable_to_shape_map():
            try:
                tensor = sess.graph.get_tensor_by_name(key + ':0')
            except KeyError:
                # This tensor doesn't exist in the graph (for example it's
                # 'global_step' or a similar housekeeping element) so skip it.
                continue
            var_list[key] = tensor
        saver = saver_lib.Saver(var_list=var_list)
        saver.restore(sess, input_checkpoint)

        param = sess.run(var_list)
        save_variables_1(param, 'model.txt')
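# For example, once saver_lib.import_meta_graph(...) has populated the
# default graph:
#   _restore_and_dump_variables('/tmp/model.ckpt-1234')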


if __name__ == '__main__':

    input_meta_graph = sys.argv[1]
    input_checkpoint = sys.argv[2]
    output = sys.argv[3]
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, True)
    # Scrub device placements from the proto so the graph can be imported on
    # any machine (clear_devices=True below does the same during import).
    for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    with session.Session() as sess:
        restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                               clear_devices=True)
        restorer.restore(sess, input_checkpoint)
        # Fetch the values of all trainable variables and write them to the
        # requested output file.
        variables = trainable_variables()
        param = sess.run(variables)
        save_variables(variables, param, output)
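# Usage (script name hypothetical):
#   python dump_variables.py model.meta model.ckpt-1234 variables.txt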