Example #1
    def testExportMonitorRegressionSignature(self):
        def _regression_signature(examples, unused_features, predictions):
            signatures = {}
            signatures['regression'] = (exporter.regression_signature(
                examples, predictions))
            return signatures['regression'], signatures

        random.seed(42)
        x = np.random.rand(1000)
        y = 2 * x + 3
        cont_features = [feature_column.real_valued_column('', dimension=1)]
        regressor = learn.LinearRegressor(feature_columns=cont_features)
        export_dir = os.path.join(tempfile.mkdtemp(), 'export')
        export_monitor = learn.monitors.ExportMonitor(
            every_n_steps=1,
            export_dir=export_dir,
            exports_to_keep=1,
            signature_fn=_regression_signature)
        regressor.fit(x, y, steps=10, monitors=[export_monitor])

        self.assertTrue(gfile.Exists(export_dir))
        with self.assertRaises(errors.NotFoundError):
            saver.checkpoint_exists(
                os.path.join(export_dir, '00000000', 'export'))
        self.assertTrue(
            saver.checkpoint_exists(
                os.path.join(export_dir, '00000010', 'export')))
        # Validate the signature
        signature = self._get_default_signature(
            os.path.join(export_dir, '00000010', 'export.meta'))
        self.assertTrue(signature.HasField('regression_signature'))
Example #2
  def testExportMonitorRegressionSignature(self):

    def _regression_signature(examples, unused_features, predictions):
      signatures = {}
      signatures['regression'] = (
          exporter.regression_signature(examples, predictions))
      return signatures['regression'], signatures

    random.seed(42)
    x = np.random.rand(1000)
    y = 2 * x + 3
    cont_features = [feature_column.real_valued_column('', dimension=1)]
    regressor = learn.LinearRegressor(feature_columns=cont_features)
    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
    export_monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=export_dir,
        exports_to_keep=1,
        signature_fn=_regression_signature)
    regressor.fit(x, y, steps=10, monitors=[export_monitor])

    self.assertTrue(gfile.Exists(export_dir))
    with self.assertRaises(errors.NotFoundError):
      saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export'))
    self.assertTrue(
        saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
    # Validate the signature
    signature = self._get_default_signature(
        os.path.join(export_dir, '00000010', 'export.meta'))
    self.assertTrue(signature.HasField('regression_signature'))
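The export paths asserted in both tests apparently follow ExportMonitor's convention of naming each export subdirectory after the global step, zero-padded to eight digits; a quick standalone illustration:

print('%08d' % 10)  # -> 00000010, the directory checked above
print('%08d' % 0)   # -> 00000000, the directory expected to have been pruned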
Example #3
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    variable_names_blacklist=''):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    raise ValueError(
        'Input checkpoint "' + input_checkpoint + '" does not exist!')

  if not output_node_names:
    raise ValueError(
        'You must supply the name of a node to --output_node_names.')

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ''

  _ = importer.import_graph_def(input_graph_def, name='')

  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(saver_def=input_saver_def)
      saver.restore(sess, input_checkpoint)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ':0')
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes)

    variable_names_blacklist = (variable_names_blacklist.split(',') if
                                variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(','),
        variable_names_blacklist=variable_names_blacklist)

  return output_graph_def
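A minimal sketch of driving the function above, assuming TF 1.x; the graph path, checkpoint prefix, and 'logits' node name are placeholders, not from the original:

import tensorflow as tf
from google.protobuf import text_format

graph_def = tf.GraphDef()
with tf.gfile.GFile('model.pbtxt', 'r') as f:  # hypothetical path
    text_format.Merge(f.read(), graph_def)

frozen = freeze_graph_with_def_protos(
    input_graph_def=graph_def,
    input_saver_def=None,
    input_checkpoint='ckpt/model.ckpt-10',  # V2 prefix (hypothetical)
    output_node_names='logits',             # hypothetical node name
    restore_op_name=None,                   # unused by this implementation
    filename_tensor_name=None,              # unused by this implementation
    clear_devices=True,
    initializer_nodes='')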
Example #4
def create_model(session, state):
    """Create translation model and initialize or load parameters in session."""
    global _buckets
    # if FLAGS.qualvec or FLAGS.translate:
    #     _buckets = [(5, 10), (10, 15), (20, 25), (40, 50), (60, 70), (70, 80), (90, 100)]
    # else:
    _buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    model = paradet_model.QualVecModel(
        FLAGS.source_vocab_size,
        FLAGS.target_vocab_size,
        FLAGS.embedding_size,
        _buckets,
        FLAGS.size,
        FLAGS.maxout_size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.learning_rate_decay_factor,
        state=state,
        dtype=dtype)
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if ckpt and save_mod.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
    return model
Example #5
def create_model(session, state):
  """Create autoencoder model and initialize or load parameters in session."""
  # dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  dtype = tf.float32
  model = autoencode_model.AutoencodeModel(
      FLAGS.source_vocab_size,
      FLAGS.target_vocab_size,
      FLAGS.embedding_size,
      _buckets,
      FLAGS.size,
      FLAGS.num_layers,
      FLAGS.max_gradient_norm,
      FLAGS.batch_size,
      FLAGS.learning_rate,
      FLAGS.learning_rate_decay_factor,
      state=state,
      dtype=dtype)
  ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
  if ckpt and save_mod.checkpoint_exists(ckpt.model_checkpoint_path):
      print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
      model.saver.restore(session, ckpt.model_checkpoint_path)
  else:
      print("Created model with fresh parameters.")
      session.run(tf.global_variables_initializer())
  return model
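Both create_model variants above follow the same restore-or-initialize pattern. A hedged sketch of how such a factory is typically driven (FLAGS and the model module are assumed to be configured elsewhere; the 'state' argument is model-specific):

import tensorflow as tf

with tf.Session() as sess:
    model = create_model(sess, state=None)  # hypothetical 'state' value
    # ... training or decoding loop using `model` ...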
Example #6
    def _read_latest_config_files(self, run_path_pairs):
        """Reads and returns the projector config files in every run directory."""
        configs = {}
        config_fpaths = {}
        for run_name, logdir in run_path_pairs:
            config = ProjectorConfig()
            config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
            if file_io.file_exists(config_fpath):
                file_content = file_io.read_file_to_string(config_fpath)
                text_format.Merge(file_content, config)

            has_tensor_files = False
            for embedding in config.embeddings:
                if embedding.tensor_path:
                    has_tensor_files = True
                    break

            if not config.model_checkpoint_path:
                # See if you can find a checkpoint file in the logdir.
                ckpt_path = _find_latest_checkpoint(logdir)
                if not ckpt_path and not has_tensor_files:
                    continue
                if ckpt_path:
                    config.model_checkpoint_path = ckpt_path

            # Sanity check for the checkpoint file.
            if (config.model_checkpoint_path
                    and not checkpoint_exists(config.model_checkpoint_path)):
                logging.warning('Checkpoint file %s not found',
                                config.model_checkpoint_path)
                continue
            configs[run_name] = config
            config_fpaths[run_name] = config_fpath
        return configs, config_fpaths
Example #7
def freeze_graph(input_graph: str, input_checkpoint: str,
                 output_node_names: Iterable[str], output_graph: str):
    """
    Convert all variables in a graph and checkpoint into constants and save the new graph to the specified file.

    Additionally, the graph is pruned of all nodes that are not needed for the specified outputs.

    -------------------------------------------------------
    NOTE: this function creates new nodes in the default graph, you may want to wrap the call with the following code
    -------------------------------------------------------
    with tf.Graph().as_default():
        freeze_graph(...)
    -------------------------------------------------------

    :param input_graph: path to the input graph file
    :param input_checkpoint: path to the input checkpoint
    :param output_node_names: iterable collection of output node names
    :param output_graph: path to the output frozen graph file

    Raises:
        ValueError: if any of the specified files does not exist
    """

    if not gfile.Exists(input_graph):
        raise ValueError(
            'Input graph file `{}` does not exist!'.format(input_graph))

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            'Input checkpoint `{}` does not exist!'.format(input_checkpoint))

    # read the graph definition
    input_graph_def = graph_pb2.GraphDef()
    with gfile.FastGFile(input_graph, 'rb') as file:
        input_graph_def.ParseFromString(file.read())

    # remove all the explicit device specifications for this node. This helps to make the graph more portable.
    for node in input_graph_def.node:
        node.device = ''

    # restore the input graph and checkpoint
    _ = importer.import_graph_def(input_graph_def, name='')
    with session.Session() as sess:
        var_list = {}
        reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            tensor = sess.graph.get_tensor_by_name(key + ':0')
            var_list[key] = tensor
        saver = saver_lib.Saver(var_list=var_list)
        saver.restore(sess, input_checkpoint)

        # convert all the variables to constants
        with DisabledLogger('tensorflow'), DisabledPrint():
            output_graph_def = graph_util.convert_variables_to_constants(
                sess, input_graph_def, output_node_names)

    with gfile.GFile(output_graph, 'wb') as file:
        file.write(output_graph_def.SerializeToString())
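A sketch of calling the helper above the way its own docstring recommends; all paths and node names are placeholders:

import tensorflow as tf

with tf.Graph().as_default():
    freeze_graph(
        input_graph='export/graph.pb',             # hypothetical path
        input_checkpoint='export/model.ckpt',      # hypothetical prefix
        output_node_names=['output/predictions'],  # hypothetical node
        output_graph='export/frozen.pb')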
Example #8
  def _read_latest_config_files(self, run_path_pairs):
    """Reads and returns the projector config files in every run directory."""
    configs = {}
    config_fpaths = {}
    for run_name, assets_dir in run_path_pairs:
      config = projector_config_pb2.ProjectorConfig()
      config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME)
      if file_io.file_exists(config_fpath):
        file_content = file_io.read_file_to_string(config_fpath)
        text_format.Merge(file_content, config)
      has_tensor_files = False
      for embedding in config.embeddings:
        if embedding.tensor_path:
          has_tensor_files = True
          break

      if not config.model_checkpoint_path:
        # See if you can find a checkpoint file in the logdir.
        logdir = _assets_dir_to_logdir(assets_dir)
        ckpt_path = _find_latest_checkpoint(logdir)
        if not ckpt_path and not has_tensor_files:
          continue
        if ckpt_path:
          config.model_checkpoint_path = ckpt_path

      # Sanity check for the checkpoint file.
      if (config.model_checkpoint_path and
          not checkpoint_exists(config.model_checkpoint_path)):
        logging.warning('Checkpoint file "%s" not found',
                        config.model_checkpoint_path)
        continue
      configs[run_name] = config
      config_fpaths[run_name] = config_fpath
    return configs, config_fpaths
Example #9
def _assert_export(self, export_monitor, export_dir, expected_signature):
    self.assertTrue(gfile.Exists(export_dir))
    # Only the written checkpoints are exported.
    self.assertTrue(
        saver.checkpoint_exists(export_dir + '00000001/export'),
        'Exported checkpoint expected but not found: %s' %
        (export_dir + '00000001/export'))
    self.assertTrue(
        saver.checkpoint_exists(export_dir + '00000010/export'),
        'Exported checkpoint expected but not found: %s' %
        (export_dir + '00000010/export'))
    self.assertEquals(six.b(os.path.join(export_dir, '00000010')),
                      export_monitor.last_export_dir)
    # Validate the signature
    signature = self._get_default_signature(export_dir +
                                            '00000010/export.meta')
    self.assertTrue(signature.HasField(expected_signature))
Example #10
def _assert_export(self, export_monitor, export_dir, expected_signature):
  self.assertTrue(gfile.Exists(export_dir))
  # Only the written checkpoints are exported.
  self.assertTrue(
      saver.checkpoint_exists(export_dir + '00000001/export'),
      'Exported checkpoint expected but not found: %s' %
      (export_dir + '00000001/export'))
  self.assertTrue(
      saver.checkpoint_exists(export_dir + '00000010/export'),
      'Exported checkpoint expected but not found: %s' %
      (export_dir + '00000010/export'))
  self.assertEquals(
      six.b(os.path.join(export_dir, '00000010')),
      export_monitor.last_export_dir)
  # Validate the signature
  signature = self._get_default_signature(export_dir + '00000010/export.meta')
  self.assertTrue(signature.HasField(expected_signature))
Example #11
def freeze_graph(layer, unit, input_names, output_names, accuracy):
    frozen_model_path = "data/{}layer{}unit.pb".format(layer, unit)
    checkpoint_file = "data/{}layer{}unit.ckpt".format(layer, unit)
    if not saver_lib.checkpoint_exists(checkpoint_file):
        print("Checkpoint file '" + checkpoint_file + "' doesn't exist!")
        exit(-1)

    print("begin loading model")
    saver = tf.train.import_meta_graph(checkpoint_file + '.meta',
                                       clear_devices=True)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, checkpoint_file)
        print("model loaded")
        # export network weights and biases to text files
        if not __debug__:
            output_nodes = "w_in,b_in,w_out,b_out"
            rnn_nodes = [",rnn/multi_rnn_cell/cell_{}/basic_lstm_cell/weights," \
                         "rnn/multi_rnn_cell/cell_{}/basic_lstm_cell/biases".format(i, i) for i in range(args.layer)]
            weights = output_nodes + "".join(rnn_nodes)
            for name in weights.split(","):
                v = sess.run("{}:0".format(name))
                var_file_name = "data/{}.csv".format(name.replace("/", "_"))
                print("save {} to file: {}".format(name, var_file_name))
                np.savetxt(var_file_name, v, delimiter=",")

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_names=output_names.split(
                ","
            )  # The output node names are used to select the useful nodes
        )

        # optimize graph
        output_graph_def = opt_inference(output_graph_def,
                                         input_names.split(","),
                                         output_names.split(","),
                                         dtypes.float32.as_datatype_enum)

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(frozen_model_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())
            print("frozen graph binary saved to: {}".format(frozen_model_path))

        frozen_model_text_path = "data/{}layer{}unit{}.pb.txt".format(
            layer, unit, accuracy)
        with tf.gfile.FastGFile(frozen_model_text_path, "wb") as f:
            f.write(str(output_graph_def))
            print("frozen graph text saved to: {}".format(
                frozen_model_text_path))

        print("%d ops in the final graph." % len(output_graph_def.node))
Example #12
def freeze_graph(config):
    input_names = config.input_names
    output_names = config.output_names
    if output_names is None:
        output_names = [
            "eval_concat/yp", "eval_concat/yp2", "eval_concat/wy",
            "eval_concat/loss"
        ]

    if not os.path.exists(config.output_path):
        os.makedirs(config.output_path)

    frozen_model_path = os.path.join(config.output_path, "frozen_model.pb")
    checkpoint_file = config.input_path
    if not saver_lib.checkpoint_exists(checkpoint_file):
        print("Checkpoint file '" + checkpoint_file + "' doesn't exist!")
        exit(-1)

    print("begin loading model")
    saver = tf.train.import_meta_graph(checkpoint_file + '.meta',
                                       clear_devices=config.clear_device)

    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        saver.restore(sess, checkpoint_file)
        print("model loaded")
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names=output_names)

        if input_names is not None:
            output_graph_def = opt_inference(output_graph_def, input_names,
                                             output_names,
                                             dtypes.float32.as_datatype_enum)

        with tf.gfile.GFile(frozen_model_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())
            print("frozen graph binary saved to: {}".format(frozen_model_path))

        frozen_model_text_path = "{}.txt".format(frozen_model_path)
        with tf.gfile.FastGFile(frozen_model_text_path, "wb") as f:
            f.write(str(output_graph_def))
            print("frozen graph text saved to: {}".format(
                frozen_model_text_path))

        print("%d ops in the final graph." % len(output_graph_def.node))
Example #13
    def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
        """Wait for a checkpoint file to appear.

    Args:
      pattern: A string.
      timeout_secs: How long to wait for in seconds.
      for_checkpoint: whether we're globbing for checkpoints.
    """
        end_time = time.time() + timeout_secs
        while time.time() < end_time:
            if for_checkpoint:
                if saver_lib.checkpoint_exists(pattern):
                    return
            else:
                if len(gfile.Glob(pattern)) >= 1:
                    return
            time.sleep(0.05)
        self.assertFalse(True, "Glob never matched any file: %s" % pattern)
Example #14
  def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
    """Wait for a checkpoint file to appear.

    Args:
      pattern: A string.
      timeout_secs: How long to wait for in seconds.
      for_checkpoint: whether we're globbing for checkpoints.
    """
    end_time = time.time() + timeout_secs
    while time.time() < end_time:
      if for_checkpoint:
        if saver_lib.checkpoint_exists(pattern):
          return
      else:
        if len(gfile.Glob(pattern)) >= 1:
          return
      time.sleep(0.05)
    self.assertFalse(True, "Glob never matched any file: %s" % pattern)
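The same polling idea, lifted out of the test class as a standalone helper (a sketch; `pattern` is a checkpoint prefix such as '/tmp/model.ckpt-100', and the TF 1.x internal saver module is imported as in the examples above):

import time
from tensorflow.python.training import saver as saver_lib

def wait_for_checkpoint(pattern, timeout_secs, poll_secs=0.05):
    """Return True once checkpoint_exists(pattern) holds, False on timeout."""
    end_time = time.time() + timeout_secs
    while time.time() < end_time:
        if saver_lib.checkpoint_exists(pattern):
            return True
        time.sleep(poll_secs)
    return False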
Example #15
def scan_graph(input_checkpoint=None,
               input_saved_model_dir=None,
               saved_model_tags=tag_constants.SERVING):
    """extract the graph to scan from a model file."""

    if (not input_saved_model_dir and not input_checkpoint):
        print("Please specify a checkpoint or \'SavedModel\' file!")
        return -1
    if (input_saved_model_dir and input_checkpoint):
        print("Please specify only *One* model file type: \
checkpoint or \'SavedModel\'!")
        return -1

    input_graph_def = None
    if input_checkpoint:
        # We don't currently use the variables file, but still check it for completeness.
        if not saver_lib.checkpoint_exists(input_checkpoint):
            print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
            return -1
        # Build meta file path for a checkpoint
        meta_file = input_checkpoint + ".meta"
        if not gfile.Exists(meta_file):
            print("Input checkpoint meta file '" + meta_file +
                  "' doesn't exist!")
            return -1
        try:
            input_graph_def = _parse_input_meta_graph_proto(meta_file,
                                                            True).graph_def
        except Exception:
            exctype, value = sys.exc_info()[:2]
            print("Parse checkpoint meta-graph file '%s' failed: %s(%s)" %\
                  (meta_file, exctype, value))
            return -1
    if input_saved_model_dir:
        try:
            input_graph_def = saved_model_utils.get_meta_graph_def(
                input_saved_model_dir, saved_model_tags).graph_def
        except Exception:
            exctype, value = sys.exc_info()[:2]
            print("Parse SaveModel '%s' meta-graph file failed: %s(%s)" %\
                  (input_saved_model_dir, exctype, value))
            return -1

    return detect_ops(input_graph_def)
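Hypothetical calls to scan_graph above, one per supported input type (paths are placeholders; detect_ops is assumed to be defined alongside it):

scan_graph(input_checkpoint='train/model.ckpt-1000')
scan_graph(input_saved_model_dir='export/saved_model')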
Example #16
    def _read_config_files(self, run_paths, logdir):
        # If there are no summary event files, the projector can still work,
        # thus treating the `logdir` as the model checkpoint directory.
        if not run_paths:
            run_paths['.'] = logdir

        configs = {}
        config_fpaths = {}
        for run_name, logdir in run_paths.items():
            config = ProjectorConfig()
            config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
            if file_io.file_exists(config_fpath):
                file_content = file_io.read_file_to_string(
                    config_fpath).decode('utf-8')
                text_format.Merge(file_content, config)

            has_tensor_files = False
            for embedding in config.embeddings:
                if embedding.tensor_path:
                    has_tensor_files = True
                    break

            if not config.model_checkpoint_path:
                # See if you can find a checkpoint file in the logdir.
                ckpt_path = latest_checkpoint(logdir)
                if not ckpt_path:
                    # Or in the parent of logdir.
                    ckpt_path = latest_checkpoint(os.path.join('../', logdir))
                    if not ckpt_path and not has_tensor_files:
                        logging.warning('Cannot find model checkpoint in %s',
                                        logdir)
                        continue
                if ckpt_path:
                    config.model_checkpoint_path = ckpt_path

            # Sanity check for the checkpoint file.
            if (config.model_checkpoint_path
                    and not checkpoint_exists(config.model_checkpoint_path)):
                logging.warning('Checkpoint file %s not found',
                                config.model_checkpoint_path)
                continue
            configs[run_name] = config
            config_fpaths[run_name] = config_fpath
        return configs, config_fpaths
Example #17
  def _read_config_files(self, run_paths, logdir):
    # If there are no summary event files, the projector can still work,
    # thus treating the `logdir` as the model checkpoint directory.
    if not run_paths:
      run_paths['.'] = logdir

    configs = {}
    config_fpaths = {}
    for run_name, logdir in run_paths.items():
      config = ProjectorConfig()
      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
      if file_io.file_exists(config_fpath):
        file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
        text_format.Merge(file_content, config)

      has_tensor_files = False
      for embedding in config.embeddings:
        if embedding.tensor_path:
          has_tensor_files = True
          break

      if not config.model_checkpoint_path:
        # See if you can find a checkpoint file in the logdir.
        ckpt_path = latest_checkpoint(logdir)
        if not ckpt_path:
          # Or in the parent of logdir.
          ckpt_path = latest_checkpoint(os.path.join('../', logdir))
          if not ckpt_path and not has_tensor_files:
            logging.warning('Cannot find model checkpoint in %s', logdir)
            continue
        if ckpt_path:
          config.model_checkpoint_path = ckpt_path

      # Sanity check for the checkpoint file.
      if (config.model_checkpoint_path and
          not checkpoint_exists(config.model_checkpoint_path)):
        logging.warning('Checkpoint file %s not found',
                        config.model_checkpoint_path)
        continue
      configs[run_name] = config
      config_fpaths[run_name] = config_fpath
    return configs, config_fpaths
Example #18
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove explicit device placements so the graph stays portable.
    if clear_devices:
        for node in input_graph_def.node:  # input_graph_def holds the graph structure
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:  # input_saver_def describes how the variables are restored
Example #19
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(
          var_list=var_list, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
Example #20
def check_input_checkpoint(input_checkpoint):
    """Check if input_checkpoint is a valid path or path prefix."""
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '{}' doesn't exist!".format(input_checkpoint))
        exit(-1)
Example #21
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    output_graph,
    clear_devices,
    initializer_nodes,
    variable_names_blacklist=""):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""

  _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(saver_def=input_saver_def)
      saver.restore(sess, input_checkpoint)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes)

    variable_names_blacklist = (variable_names_blacklist.split(",") if
                                variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_blacklist=variable_names_blacklist)

  with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
Example #22
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 clear_devices,
                                 initializer_nodes,
                                 optimize_graph=True,
                                 variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError('Input checkpoint "' + input_checkpoint +
                         '" does not exist!')

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')

        if optimize_graph:
            logging.info('Graph Rewriter optimizations enabled')
            rewrite_options = rewriter_config_pb2.RewriterConfig(
                layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
            rewrite_options.optimizers.append('pruning')
            rewrite_options.optimizers.append('constfold')
            rewrite_options.optimizers.append('layout')
            graph_options = tf.GraphOptions(rewrite_options=rewrite_options,
                                            infer_shapes=True)
        else:
            logging.info('Graph Rewriter optimizations disabled')
            graph_options = tf.GraphOptions()
        config = tf.ConfigProto(graph_options=graph_options)
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example it's
                        # 'global_step' or a similar housekeeping element) so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)

            variable_names_blacklist = (variable_names_blacklist.split(',')
                                        if variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)

    return output_graph_def
Example #23
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
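A quick illustration of the partition-variable heuristic used above; the tensor names are made up, but they match the /part_N/ shape the regex looks for:

import re

assert re.search(r"/part_\d+/", "dense/kernel/part_0/read:0")
assert not re.search(r"/part_\d+/", "dense/kernel/read:0")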
Example #24
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(saver_def=input_saver_def,
                              write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list,
                              write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))

    variable_names_whitelist = (variable_names_whitelist.split(",")
                                if variable_names_whitelist else None)
    variable_names_blacklist = (variable_names_blacklist.split(",")
                                if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
Example #25
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
  """Converts all variables in a graph and checkpoint into constants."""

  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""

  _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver:
      with gfile.FastGFile(input_saver, mode) as f:
        saver_def = saver_pb2.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = saver_lib.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes)

    variable_names_blacklist = (variable_names_blacklist.split(",") if
                                variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_blacklist=variable_names_blacklist)

  with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
Example #26
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants."""

    # Unused by updated loading code.
    del restore_op_name, filename_tensor_name

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    print(key)
                    continue
                var_list[key] = tensor

            # """ Print ops name
            def _node_name(n):
                if n.startswith("^"):
                    return n[1:]
                else:
                    return n.split(":")[0]

            name_to_node_map = {}  # Keyed by node name.
            for node in input_graph_def.node:
                n = _node_name(node.name)
                name_to_node_map[n] = node

            # print(name_to_node_map.keys())
            # """
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes)

        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))
        # variable_names_blacklist=variable_names_blacklist)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
Example #27
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
  """Converts all variables in a graph and checkpoint into constants."""

  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read().decode("utf-8"), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver:
      with gfile.FastGFile(input_saver, mode) as f:
        saver_def = saver_pb2.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = saver_lib.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
      if initializer_nodes:
        sess.run(initializer_nodes)

    variable_names_blacklist = (variable_names_blacklist.split(",") if
                                variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_blacklist=variable_names_blacklist)

  with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
Example #28
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants."""

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
            if initializer_nodes:
                sess.run(initializer_nodes)

        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
Example #29
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    optimize_graph=False,
    variable_names_blacklist=''):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    raise ValueError(
        'Input checkpoint "' + input_checkpoint + '" does not exist!')

  if not output_node_names:
    raise ValueError(
        'You must supply the name of a node to --output_node_names.')

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ''

  with tf.Graph().as_default():
    tf.import_graph_def(input_graph_def, name='')

    if optimize_graph:
      logging.info('Graph Rewriter optimizations enabled')
      rewrite_options = rewriter_config_pb2.RewriterConfig(
          optimize_tensor_layout=True)
      rewrite_options.optimizers.append('pruning')
      rewrite_options.optimizers.append('constfold')
      rewrite_options.optimizers.append('layout')
      graph_options = tf.GraphOptions(
          rewrite_options=rewrite_options, infer_shapes=True)
    else:
      logging.info('Graph Rewriter optimizations disabled')
      graph_options = tf.GraphOptions()
    config = tf.ConfigProto(graph_options=graph_options)
    with session.Session(config=config) as sess:
      if input_saver_def:
        saver = saver_lib.Saver(saver_def=input_saver_def)
        saver.restore(sess, input_checkpoint)
      else:
        var_list = {}
        reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
          try:
            tensor = sess.graph.get_tensor_by_name(key + ':0')
          except KeyError:
            # This tensor doesn't exist in the graph (for example it's
            # 'global_step' or a similar housekeeping element) so skip it.
            continue
          var_list[key] = tensor
        saver = saver_lib.Saver(var_list=var_list)
        saver.restore(sess, input_checkpoint)
        if initializer_nodes:
          sess.run(initializer_nodes)

      variable_names_blacklist = (variable_names_blacklist.split(',') if
                                  variable_names_blacklist else None)
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.split(','),
          variable_names_blacklist=variable_names_blacklist)

  return output_graph_def