Example #1
File: saver.py  Project: ange3/deepcode
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(
      checkpoint_dir, latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if gfile.Exists(coord_checkpoint_filename):
      f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
      ckpt = CheckpointState()
      text_format.Merge(f.read(), ckpt)
  except gfile.FileError:
    # It's ok if the file cannot be read
    return None
  except text_format.ParseError as e:
    logging.warning(str(e))
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt
Example #2
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(
      checkpoint_dir, latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if gfile.Exists(coord_checkpoint_filename):
      f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
      ckpt = CheckpointState()
      text_format.Merge(f.read(), ckpt)
  except gfile.FileError:
    # It's ok if the file cannot be read
    return None
  except text_format.ParseError as e:
    logging.warning(str(e))
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt
Example #3
    def _MaybeDeleteOldCheckpoints(self, latest_save_path):
        """Deletes old checkpoints if necessary.

    Always keep the last max_to_keep checkpoints.  If
    keep_checkpoint_every_n_hours was specified, keep an additional checkpoint
    every N hours. For example, if N is 0.5, an additional checkpoint is kept
    for every 0.5 hours of training; if N is 10, an additional checkpoint is
    kept for every 10 hours of training.

    Args:
      latest_save_path: Name including path of checkpoint file to save.
    """
        if not self._max_to_keep:
            return
        # Remove first from list if the same name was used before.
        for p in self._last_checkpoints:
            if latest_save_path == self._CheckpointFilename(p):
                self._last_checkpoints.remove(p)
        # Append new path to list
        self._last_checkpoints.append((latest_save_path, time.time()))
        # If more than max_to_keep, remove oldest.
        if len(self._last_checkpoints) > self._max_to_keep:
            p = self._last_checkpoints.pop(0)
            # Do not delete the file if keep_checkpoint_every_n_hours is set
            # and we have reached N hours of training.
            should_keep = p[1] > self._next_checkpoint_time
            if should_keep:
                self._next_checkpoint_time += self._keep_checkpoint_every_n_hours * 3600
                return
            # Otherwise delete the files.
            for f in gfile.Glob(self._CheckpointFilename(p)):
                try:
                    gfile.Remove(f)
                except gfile.GOSError as e:
                    logging.warning("Ignoring: %s", str(e))
Example #4
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
    """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: list of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Raises:
    RuntimeError: If the save paths conflict.
  """
    if all_model_checkpoint_paths is None:
        all_model_checkpoint_paths = []

    if all_model_checkpoint_paths and all_model_checkpoint_paths[
            -1] != model_checkpoint_path:
        logging.warning(
            "%s is not in all_model_checkpoint_paths! Manually adding it.",
            model_checkpoint_path)
        all_model_checkpoint_paths.append(model_checkpoint_path)
    # Writes the "checkpoint" file for the coordinator for later restoration.
    coord_checkpoint_filename = _GetCheckpointFilename(save_dir,
                                                       latest_filename)

    # Relative paths need to be rewritten to be relative to the "save_dir".
    if not os.path.isabs(model_checkpoint_path):
        model_checkpoint_path = os.path.relpath(model_checkpoint_path,
                                                save_dir)

    all_model_checkpoint_paths = [
        os.path.relpath(p, save_dir) for p in all_model_checkpoint_paths
        if not os.path.isabs(p)
    ]

    if coord_checkpoint_filename == model_checkpoint_path:
        raise RuntimeError(
            "Save path '%s' conflicts with path used for "
            "checkpoint state.  Please use a different save path." %
            model_checkpoint_path)
    coord_checkpoint_proto = CheckpointState(
        model_checkpoint_path=model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)
    f = gfile.FastGFile(coord_checkpoint_filename, mode="w")
    f.write(text_format.MessageToString(coord_checkpoint_proto))
    f.close()
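A minimal usage sketch for the function above, assuming it is exposed as tf.train.update_checkpoint_state alongside tf.train.get_checkpoint_state in the same release; the /tmp/train paths are placeholders:

import tensorflow as tf

# Record model.ckpt-2000 as the latest checkpoint in /tmp/train/checkpoint.
tf.train.update_checkpoint_state(
    save_dir='/tmp/train',
    model_checkpoint_path='/tmp/train/model.ckpt-2000',
    all_model_checkpoint_paths=['/tmp/train/model.ckpt-1000',
                                '/tmp/train/model.ckpt-2000'])

# Read the state back; returns a CheckpointState proto or None.
state = tf.train.get_checkpoint_state('/tmp/train')
if state:
    print(state.model_checkpoint_path)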
Example #5
  def _default_global_step_tensor(self):
    try:
      gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
      if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
        return gs
      else:
        logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
        return None
    except KeyError:
      return None
Example #6
  def _default_global_step_tensor(self):
    try:
      gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
      if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
        return gs
      else:
        logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
        return None
    except KeyError:
      return None
Example #7
    def add_graph(self, graph, global_step=None, graph_def=None):
        """Adds a `Graph` to the event file.

    The graph described by the protocol buffer will be displayed by
    TensorBoard. Most users pass a graph in the constructor instead.

    Args:
      graph: A `Graph` object, such as `sess.graph`.
      global_step: Number. Optional global step counter to record with the
        graph.
      graph_def: DEPRECATED. Use the `graph` parameter instead.

    Raises:
      ValueError: If both graph and graph_def are passed to the method.
    """

        if graph is not None and graph_def is not None:
            raise ValueError(
                "Please pass only graph, or graph_def (deprecated), "
                "but not both.")

        if isinstance(graph, ops.Graph) or isinstance(graph_def, ops.Graph):
            # The user passed a `Graph`.

            # Check if the user passed it via the graph or the graph_def argument and
            # correct for that.
            if not isinstance(graph, ops.Graph):
                logging.warning(
                    "When passing a `Graph` object, please use the `graph`"
                    " named argument instead of `graph_def`.")
                graph = graph_def

            # Serialize the graph with additional info.
            true_graph_def = graph.as_graph_def(add_shapes=True)
        elif (isinstance(graph, graph_pb2.GraphDef)
              or isinstance(graph_def, graph_pb2.GraphDef)):
            # The user passed a `GraphDef`.
            logging.warning(
                "Passing a `GraphDef` to the SummaryWriter is deprecated."
                " Pass a `Graph` object instead, such as `sess.graph`.")

            # Check if the user passed it via the graph or the graph_def argument and
            # correct for that.
            if isinstance(graph, graph_pb2.GraphDef):
                true_graph_def = graph
            else:
                true_graph_def = graph_def

        else:
            # The user passed neither `Graph`, nor `GraphDef`.
            raise TypeError("The passed graph must be an instance of `Graph` "
                            "or the deprecated `GraphDef`")
        # Finally, add the graph_def to the summary writer.
        self._add_graph_def(true_graph_def, global_step)
Example #8
  def start_standard_services(self, sess):
    """Start the standard services for 'sess'.

    This starts services in the background.  The services started depend
    on the parameters to the constructor and may include:

      - A Summary thread computing summaries every save_summaries_secs.
      - A Checkpoint thread saving the model every save_model_secs.
      - A StepCounter thread measuring step time.

    Args:
      sess: A Session.

    Returns:
      A list of threads that are running the standard services.  You can use
      the Supervisor's Coordinator to join these threads with:
        sv.coord.Join(<list of threads>)

    Raises:
      RuntimeError: If called with a non-chief Supervisor.
      ValueError: If no `logdir` was passed to the constructor, as the
        services need a log directory.
    """
    if not self._is_chief:
      raise RuntimeError("Only the chief supervisor can start standard "
                         "services, because only chief supervisors can "
                         "write events.")

    if not self._logdir:
      logging.warning("Standard services need a 'logdir' "
                      "passed to the SessionManager")
      return

    if self._global_step is not None and self._summary_writer:
      # Only add the session log if we keep track of global step.
      # TensorBoard cannot use START message for purging expired events
      # if there is no step value.
      current_step = training_util.global_step(sess, self._global_step)
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START),
          current_step)

    threads = []
    if self._save_summaries_secs and self._summary_writer:
      if self._summary_op is not None:
        threads.append(SVSummaryThread(self, sess))
      if self._global_step is not None:
        threads.append(SVStepCounterThread(self, sess))
    if self.saver and self._save_model_secs:
      threads.append(SVTimerCheckpointThread(self, sess))
    for t in threads:
      t.start()
    self._started_threads.extend(threads)

    return threads
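A hedged sketch of how these services are typically started, assuming the method above belongs to tf.train.Supervisor as in the same codebase; the logdir and the training loop are placeholders:

import tensorflow as tf

sv = tf.train.Supervisor(logdir='/tmp/train_logs', is_chief=True,
                         save_summaries_secs=120, save_model_secs=600)
# Create the session without auto-starting services, then start them explicitly.
sess = sv.prepare_or_wait_for_session(start_standard_services=False)
threads = sv.start_standard_services(sess)  # summary, checkpoint, step counter
# ... run training steps with sess here ...
sv.coord.join(threads)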
Example #9
def main(unused_argv=None):
    if FLAGS.debug:
        logging.set_verbosity(logging.DEBUG)
        logging.info('TensorBoard is in debug mode.')

    if not FLAGS.logdir:
        logging.error(
            'A logdir must be specified. Run `tensorboard --help` for '
            'details and examples.')
        return -1

    logging.info('Starting TensorBoard in directory %s', os.getcwd())

    path_to_run = ParseEventFilesFlag(FLAGS.logdir)
    logging.info('TensorBoard path_to_run is: %s', path_to_run)
    multiplexer = event_multiplexer.EventMultiplexer(
        size_guidance=TENSORBOARD_SIZE_GUIDANCE)

    def _Load():
        start = time.time()
        for (path, name) in six.iteritems(path_to_run):
            multiplexer.AddRunsFromDirectory(path, name)
        multiplexer.Reload()
        duration = time.time() - start
        logging.info('Multiplexer done loading. Load took %0.1f secs',
                     duration)
        t = threading.Timer(LOAD_INTERVAL, _Load)
        t.daemon = True
        t.start()

    t = threading.Timer(0, _Load)
    t.daemon = True
    t.start()

    factory = functools.partial(tensorboard_handler.TensorboardHandler,
                                multiplexer)
    try:
        server = ThreadedHTTPServer((FLAGS.host, FLAGS.port), factory)
    except socket.error:
        logging.error(
            'Tried to connect to port %d, but that address is in use.',
            FLAGS.port)
        return -2
    try:
        tag = resource_loader.load_resource('tensorboard/TAG').strip()
        logging.info('TensorBoard is tag: %s', tag)
    except IOError:
        logging.warning('Unable to read TensorBoard tag')
        tag = ''

    status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
    print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
    print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
    server.serve_forever()
Example #10
def main(unused_argv=None):
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
    logging.info('TensorBoard is in debug mode.')

  if not FLAGS.logdir:
    logging.error('A logdir must be specified. Run `tensorboard --help` for '
                  'details and examples.')
    return -1

  logging.info('Starting TensorBoard in directory %s', os.getcwd())

  path_to_run = ParseEventFilesFlag(FLAGS.logdir)
  logging.info('TensorBoard path_to_run is: %s', path_to_run)
  multiplexer = event_multiplexer.EventMultiplexer(
      size_guidance=TENSORBOARD_SIZE_GUIDANCE)
  # Ensure the Multiplexer initializes in a loaded state before it adds runs,
  # so it can handle HTTP requests while runs are loading.

  multiplexer.Reload()
  def _Load():
    start = time.time()
    for (path, name) in six.iteritems(path_to_run):
      multiplexer.AddRunsFromDirectory(path, name)
    multiplexer.Reload()
    duration = time.time() - start
    logging.info('Multiplexer done loading. Load took %0.1f secs', duration)
    t = threading.Timer(LOAD_INTERVAL, _Load)
    t.daemon = True
    t.start()
  t = threading.Timer(0, _Load)
  t.daemon = True
  t.start()

  factory = functools.partial(tensorboard_handler.TensorboardHandler,
                              multiplexer)
  try:
    server = ThreadedHTTPServer((FLAGS.host, FLAGS.port), factory)
  except socket.error:
    logging.error('Tried to connect to port %d, but that address is in use.',
                  FLAGS.port)
    return -2
  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.warning('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
  print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
  server.serve_forever()
Example #11
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: list of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Raises:
    RuntimeError: If the save paths conflict.
  """
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []

  if all_model_checkpoint_paths and all_model_checkpoint_paths[-1] != model_checkpoint_path:
    logging.warning(
        "%s is not in all_model_checkpoint_paths! Manually adding it.",
        model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)

  # Relative paths need to be rewritten to be relative to the "save_dir".
  if not os.path.isabs(model_checkpoint_path):
    model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)

  all_model_checkpoint_paths = [
      os.path.relpath(p, save_dir) for p in all_model_checkpoint_paths
      if not os.path.isabs(p)
  ]

  if coord_checkpoint_filename == model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)
  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths)
  f = gfile.FastGFile(coord_checkpoint_filename, mode="w")
  f.write(text_format.MessageToString(coord_checkpoint_proto))
  f.close()
Example #12
  def add_graph(self, graph, global_step=None, graph_def=None):
    """Adds a `Graph` to the event file.

    The graph described by the protocol buffer will be displayed by
    TensorBoard. Most users pass a graph in the constructor instead.

    Args:
      graph: A `Graph` object, such as `sess.graph`.
      global_step: Number. Optional global step counter to record with the
        graph.
      graph_def: DEPRECATED. Use the `graph` parameter instead.

    Raises:
      ValueError: If both graph and graph_def are passed to the method.
    """

    if graph is not None and graph_def is not None:
      raise ValueError("Please pass only graph, or graph_def (deprecated), "
                       "but not both.")

    if isinstance(graph, ops.Graph) or isinstance(graph_def, ops.Graph):
      # The user passed a `Graph`.

      # Check if the user passed it via the graph or the graph_def argument and
      # correct for that.
      if not isinstance(graph, ops.Graph):
        logging.warning("When passing a `Graph` object, please use the `graph`"
                        " named argument instead of `graph_def`.")
        graph = graph_def

      # Serialize the graph with additional info.
      true_graph_def = graph.as_graph_def(add_shapes=True)
    elif (isinstance(graph, graph_pb2.GraphDef)
          or isinstance(graph_def, graph_pb2.GraphDef)):
      # The user passed a `GraphDef`.
      logging.warning("Passing a `GraphDef` to the SummaryWriter is deprecated."
                      " Pass a `Graph` object instead, such as `sess.graph`.")

      # Check if the user passed it via the graph or the graph_def argument and
      # correct for that.
      if isinstance(graph, graph_pb2.GraphDef):
        true_graph_def = graph
      else:
        true_graph_def = graph_def

    else:
      # The user passed neither `Graph`, nor `GraphDef`.
      raise TypeError("The passed graph must be an instance of `Graph` "
                      "or the deprecated `GraphDef`")
    # Finally, add the graph_def to the summary writer.
    self._add_graph_def(true_graph_def, global_step)
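A short usage sketch, assuming the class that defines add_graph above is the era's tf.train.SummaryWriter; the log directory is a placeholder:

import tensorflow as tf

writer = tf.train.SummaryWriter('/tmp/tb_logs')
with tf.Session() as sess:
    # Preferred form: pass the `Graph` object itself.
    writer.add_graph(sess.graph, global_step=0)
    # Deprecated form that triggers the warning above:
    # writer.add_graph(sess.graph.as_graph_def())
writer.close()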
Example #13
  def _default_global_step_tensor(self):
    """Returns the global_step from the default graph.

    Returns:
      The global step `Tensor` or `None`.
    """
    try:
      gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
      if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
        return gs
      else:
        logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
        return None
    except KeyError:
      return None
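The lookup above only succeeds if the default graph contains an integer tensor literally named "global_step:0"; a minimal sketch of creating one (the initial value is a placeholder):

import tensorflow as tf

# An int32 variable named "global_step" yields the tensor "global_step:0".
global_step = tf.Variable(0, name='global_step', trainable=False)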
Example #14
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.
  """
    ckpt = None
    coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                       latest_filename)
    f = None
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if gfile.Exists(coord_checkpoint_filename):
            f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
            ckpt = CheckpointState()
            text_format.Merge(f.read(), ckpt)
            # For relative model_checkpoint_path and all_model_checkpoint_paths,
            # prepend checkpoint_dir.
            if not os.path.isabs(checkpoint_dir):
                if not os.path.isabs(ckpt.model_checkpoint_path):
                    ckpt.model_checkpoint_path = os.path.join(
                        checkpoint_dir, ckpt.model_checkpoint_path)
                for i in range(len(ckpt.all_model_checkpoint_paths)):
                    p = ckpt.all_model_checkpoint_paths[i]
                    if not os.path.isabs(p):
                        ckpt.all_model_checkpoint_paths[i] = os.path.join(
                            checkpoint_dir, p)
    except IOError:
        # It's ok if the file cannot be read
        return None
    except text_format.ParseError as e:
        logging.warning(str(e))
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    finally:
        if f:
            f.close()
    return ckpt
Example #15
def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto."""
  # Args:
  #   v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
  #   arg_name: String, for error messages.

  # Returns:
  #   A TensorShapeProto.
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    for d in v.dim:
      if d.name:
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
        break
    return v
  return tensor_shape.as_shape(v).as_proto()
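For illustration, the same conversion is available through the public TensorShape API; a hedged equivalent of the helper above, with arbitrary shape values:

import tensorflow as tf
from tensorflow.core.framework import tensor_shape_pb2

# A plain list of ints or a TensorShape both convert to a TensorShapeProto.
proto = tf.TensorShape([2, 3]).as_proto()
assert isinstance(proto, tensor_shape_pb2.TensorShapeProto)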
Example #16
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(
      checkpoint_dir, latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if gfile.Exists(coord_checkpoint_filename):
      f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
      ckpt = CheckpointState()
      text_format.Merge(f.read(), ckpt)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(checkpoint_dir):
        if not os.path.isabs(ckpt.model_checkpoint_path):
          ckpt.model_checkpoint_path = os.path.join(
              checkpoint_dir, ckpt.model_checkpoint_path)
        for i in range(len(ckpt.all_model_checkpoint_paths)):
          p = ckpt.all_model_checkpoint_paths[i]
          if not os.path.isabs(p):
            ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except IOError:
    # It's ok if the file cannot be read
    return None
  except text_format.ParseError as e:
    logging.warning(str(e))
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt
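A typical restore flow built on the function above, assuming the public tf.train.get_checkpoint_state and tf.train.Saver wrappers; the checkpoint directory is a placeholder:

import tensorflow as tf

# Assumes the model's variables have already been built in this graph.
ckpt = tf.train.get_checkpoint_state('/tmp/train')
if ckpt and ckpt.model_checkpoint_path:
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)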
Example #17
def main(unused_argv=None):
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
    logging.info('TensorBoard is in debug mode.')

  if not FLAGS.logdir:
    msg = ('A logdir must be specified. Run `tensorboard --help` for '
           'details and examples.')
    logging.error(msg)
    print(msg)
    return -1

  logging.info('Starting TensorBoard in directory %s', os.getcwd())
  path_to_run = server.ParseEventFilesSpec(FLAGS.logdir)
  logging.info('TensorBoard path_to_run is: %s', path_to_run)

  multiplexer = event_multiplexer.EventMultiplexer(
      size_guidance=server.TENSORBOARD_SIZE_GUIDANCE,
      purge_orphaned_data=FLAGS.purge_orphaned_data)
  server.StartMultiplexerReloadingThread(multiplexer, path_to_run,
                                         FLAGS.reload_interval)
  try:
    tb_server = server.BuildServer(multiplexer, FLAGS.host, FLAGS.port)
  except socket.error:
    if FLAGS.port == 0:
      msg = 'Unable to find any open ports.'
      logging.error(msg)
      print(msg)
      return -2
    else:
      msg = 'Tried to connect to port %d, but address is in use.' % FLAGS.port
      logging.error(msg)
      print(msg)
      return -3

  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.warning('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
  print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
  tb_server.serve_forever()
Example #18
def main(unused_argv=None):
    if FLAGS.debug:
        logging.set_verbosity(logging.DEBUG)
        logging.info('TensorBoard is in debug mode.')

    if not FLAGS.logdir:
        msg = ('A logdir must be specified. Run `tensorboard --help` for '
               'details and examples.')
        logging.error(msg)
        print(msg)
        return -1

    logging.info('Starting TensorBoard in directory %s', os.getcwd())
    path_to_run = server.ParseEventFilesSpec(FLAGS.logdir)
    logging.info('TensorBoard path_to_run is: %s', path_to_run)

    multiplexer = event_multiplexer.EventMultiplexer(
        size_guidance=server.TENSORBOARD_SIZE_GUIDANCE,
        purge_orphaned_data=FLAGS.purge_orphaned_data)
    server.StartMultiplexerReloadingThread(multiplexer, path_to_run,
                                           FLAGS.reload_interval)
    try:
        tb_server = server.BuildServer(multiplexer, FLAGS.host, FLAGS.port)
    except socket.error:
        if FLAGS.port == 0:
            msg = 'Unable to find any open ports.'
            logging.error(msg)
            print(msg)
            return -2
        else:
            msg = 'Tried to connect to port %d, but address is in use.' % FLAGS.port
            logging.error(msg)
            print(msg)
            return -3

    try:
        tag = resource_loader.load_resource('tensorboard/TAG').strip()
        logging.info('TensorBoard is tag: %s', tag)
    except IOError:
        logging.warning('Unable to read TensorBoard tag')
        tag = ''

    status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
    print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
    print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
    tb_server.serve_forever()
Example #19
def _MakeShape(v, arg_name):
    """Convert v into a TensorShapeProto."""
    # Args:
    #   v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    #   arg_name: String, for error messages.

    # Returns:
    #   A TensorShapeProto.
    if isinstance(v, tensor_shape_pb2.TensorShapeProto):
        for d in v.dim:
            if d.name:
                logging.warning(
                    "Warning: TensorShapeProto with a named dimension: %s",
                    str(v))
                break
        return v
    return tensor_shape.as_shape(v).as_proto()
Example #20
    def _show_compute(self, show_dataflow):
        """Visualize the computation activity."""
        for dev_stats in self._step_stats.dev_stats:
            device_pid = self._device_pids[dev_stats.device]

            for node_stats in dev_stats.node_stats:
                tid = node_stats.thread_id
                start_time = node_stats.all_start_micros
                end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
                _, _, inputs = self._parse_op_label(node_stats.timeline_label)

                self._emit_op(node_stats, device_pid)

                for input_name in inputs:
                    if input_name not in self._tensors:
                        # This can happen when partitioning has inserted a Send/Recv.
                        # We remove the numeric suffix so that the dataflow appears to
                        # come from the original node.  Ideally, the StepStats would
                        # contain logging for the Send and Recv nodes.
                        index = input_name.rfind('/_')
                        if index > 0:
                            input_name = input_name[:index]

                    if input_name in self._tensors:
                        tensor = self._tensors[input_name]
                        tensor.add_ref(start_time)
                        tensor.add_unref(end_time - 1)

                        if show_dataflow:
                            # We use a different flow ID for every graph edge.
                            create_time, create_pid, create_tid = self._flow_starts[
                                input_name]
                            # Don't add flows when producer and consumer ops are on the same
                            # pid/tid since the horizontal arrows clutter the visualization.
                            if create_pid != device_pid or create_tid != tid:
                                flow_id = self._alloc_flow_id()
                                self._chrome_trace.emit_flow_start(
                                    input_name, create_time, create_pid,
                                    create_tid, flow_id)
                                self._chrome_trace.emit_flow_end(
                                    input_name, start_time, device_pid, tid,
                                    flow_id)
                    else:
                        logging.warning('Can\'t find tensor %s', input_name)
Example #21
def _add_collection_def(meta_graph_def, key):
  """Adds a collection to MetaGraphDef protocol buffer.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s" % type(key))
    return
  collection_list = ops.get_collection(key)
  if not collection_list:
    return
  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x)
        assert isinstance(proto, proto_type)
        getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        getattr(col_def, kind).value.extend([x.name for x in collection_list])
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s" % (key, str(e)))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
Example #22
  def _show_compute(self, show_dataflow):
    """Visualize the computation activity."""
    for dev_stats in self._step_stats.dev_stats:
      device_pid = self._device_pids[dev_stats.device]

      for node_stats in dev_stats.node_stats:
        tid = node_stats.thread_id
        start_time = node_stats.all_start_micros
        end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
        _, _, inputs = self._parse_op_label(node_stats.timeline_label)

        self._emit_op(node_stats, device_pid)

        for input_name in inputs:
          if input_name not in self._tensors:
            # This can happen when partitioning has inserted a Send/Recv.
            # We remove the numeric suffix so that the dataflow appears to
            # come from the original node.  Ideally, the StepStats would
            # contain logging for the Send and Recv nodes.
            index = input_name.rfind('/_')
            if index > 0:
              input_name = input_name[:index]

          if input_name in self._tensors:
            tensor = self._tensors[input_name]
            tensor.add_ref(start_time)
            tensor.add_unref(end_time - 1)

            if show_dataflow:
              # We use a different flow ID for every graph edge.
              create_time, create_pid, create_tid = self._flow_starts[
                  input_name]
              # Don't add flows when producer and consumer ops are on the same
              # pid/tid since the horizontal arrows clutter the visualization.
              if create_pid != device_pid or create_tid != tid:
                flow_id = self._alloc_flow_id()
                self._chrome_trace.emit_flow_start(input_name, create_time,
                                                   create_pid, create_tid,
                                                   flow_id)
                self._chrome_trace.emit_flow_end(input_name, start_time,
                                                 device_pid, tid, flow_id)
          else:
            logging.warning('Can\'t find tensor %s', input_name)
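A hedged sketch of where the step stats visualized above usually come from, using the public tracing hooks of the same era; sess and train_op stand in for an existing session and training op:

import tensorflow as tf
from tensorflow.python.client import timeline

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(train_op, options=run_options, run_metadata=run_metadata)

# Timeline wraps the collected StepStats and emits a Chrome trace; dataflow
# arrows come from logic like _show_compute above.
tl = timeline.Timeline(run_metadata.step_stats)
trace_json = tl.generate_chrome_trace_format(show_dataflow=True)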
Example #23
def load_resource(path):
  """Load the resource at given path, where path is relative to tensorflow/.

  Args:
    path: a string resource path relative to tensorflow/.

  Returns:
    The contents of that resource.

  Raises:
    IOError: If the path is not found, or the resource can't be opened.
  """
  path = os.path.join('tensorflow', path)
  path = os.path.abspath(path)
  try:
    with open(path, 'rb') as f:
      return f.read()
  except IOError as e:
    logging.warning('IOError %s on path %s' % (e, path))
Example #24
def load_resource(path):
  """Load the resource at given path, where path is relative to tensorflow/.

  Args:
    path: a string resource path relative to tensorflow/.

  Returns:
    The contents of that resource.

  Raises:
    IOError: If the path is not found, or the resource can't be opened.
  """
  path = os.path.join('tensorflow', path)
  path = os.path.abspath(path)
  try:
    with open(path, 'rb') as f:
      return f.read()
  except IOError as e:
    logging.warning('IOError %s on path %s', e, path)
Example #25
    def AddRun(self, path, name=None):
        """Add a run to the multiplexer.

    If the name is not specified, it is the same as the path.

    If a run by that name exists, and we are already watching the right path,
      do nothing. If we are watching a different path, replace the event
      accumulator.

    If `AutoUpdate` or `Reload` have been called, it will `AutoUpdate` or
    `Reload` the newly created accumulators. This maintains the invariant that
    once the Multiplexer was activated, all of its accumulators are active.

    Args:
      path: Path to the event files (or event directory) for given run.
      name: Name of the run to add. If not provided, is set to path.

    Returns:
      The `EventMultiplexer`.
    """
        if name is None or name == '':
            name = path
        accumulator = None
        with self._accumulators_mutex:
            if name not in self._accumulators or self._paths[name] != path:
                if name in self._paths and self._paths[name] != path:
                    # TODO(danmane) - Make it impossible to overwrite an old path with
                    # a new path (just give the new path a distinct name)
                    logging.warning(
                        'Conflict for name %s: old path %s, new path %s', name,
                        self._paths[name], path)
                logging.info('Constructing EventAccumulator for %s', path)
                accumulator = event_accumulator.EventAccumulator(
                    path, self._size_guidance)
                self._accumulators[name] = accumulator
                self._paths[name] = path
        if accumulator:
            if self._reload_called:
                accumulator.Reload()
            if self._autoupdate_called:
                accumulator.AutoUpdate(self._autoupdate_interval)
        return self
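A short usage sketch of the method above, reusing the construction pattern shown in the main() examples of this listing; the run paths and names are placeholders:

multiplexer = event_multiplexer.EventMultiplexer(
    size_guidance=TENSORBOARD_SIZE_GUIDANCE)
multiplexer.AddRun('/tmp/logs/run1', name='train')
multiplexer.AddRun('/tmp/logs/run2', name='eval')
multiplexer.Reload()  # load events for every accumulator that was added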
Example #26
def main(unused_argv=None):
    if FLAGS.debug:
        logging.set_verbosity(logging.DEBUG)
        logging.info('TensorBoard is in debug mode.')

    if not FLAGS.logdir:
        logging.error(
            'A logdir must be specified. Run `tensorboard --help` for '
            'details and examples.')
        return -1

    if FLAGS.debug:
        logging.info('Starting TensorBoard in directory %s', os.getcwd())

    path_to_run = ParseEventFilesFlag(FLAGS.logdir)
    multiplexer = event_multiplexer.AutoloadingMultiplexer(
        path_to_run=path_to_run,
        interval_secs=60,
        size_guidance=TENSORBOARD_SIZE_GUIDANCE)

    multiplexer.AutoUpdate(interval=30)

    factory = functools.partial(tensorboard_handler.TensorboardHandler,
                                multiplexer)
    try:
        server = ThreadedHTTPServer((FLAGS.host, FLAGS.port), factory)
    except socket.error:
        logging.error(
            'Tried to connect to port %d, but that address is in use.',
            FLAGS.port)
        return -2
    try:
        tag = resource_loader.load_resource('tensorboard/TAG').strip()
        logging.info('TensorBoard is tag: %s', tag)
    except IOError:
        logging.warning('Unable to read TensorBoard tag')
        tag = ''

    status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
    print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
    print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
    server.serve_forever()
Example #27
  def AddRun(self, path, name=None):
    """Add a run to the multiplexer.

    If the name is not specified, it is the same as the path.

    If a run by that name exists, and we are already watching the right path,
      do nothing. If we are watching a different path, replace the event
      accumulator.

    If `AutoUpdate` or `Reload` have been called, it will `AutoUpdate` or
    `Reload` the newly created accumulators. This maintains the invariant that
    once the Multiplexer was activated, all of its accumulators are active.

    Args:
      path: Path to the event files (or event directory) for given run.
      name: Name of the run to add. If not provided, is set to path.

    Returns:
      The `EventMultiplexer`.
    """
    if name is None or name == '':
      name = path
    accumulator = None
    with self._accumulators_mutex:
      if name not in self._accumulators or self._paths[name] != path:
        if name in self._paths and self._paths[name] != path:
          # TODO(danmane) - Make it impossible to overwrite an old path with
          # a new path (just give the new path a distinct name)
          logging.warning('Conflict for name %s: old path %s, new path %s',
                          name, self._paths[name], path)
        logging.info('Constructing EventAccumulator for %s', path)
        accumulator = event_accumulator.EventAccumulator(path,
                                                         self._size_guidance)
        self._accumulators[name] = accumulator
        self._paths[name] = path
    if accumulator:
      if self._reload_called:
        accumulator.Reload()
      if self._autoupdate_called:
        accumulator.AutoUpdate(self._autoupdate_interval)
    return self
Example #28
    def _model_not_ready(self, sess):
        """Checks if the model is ready or not.

    Args:
      sess: A `Session`.

    Returns:
      `None` if the model is ready, a `String` with the reason why it is not
      ready otherwise.
    """
        if self._ready_op is None:
            return None
        else:
            try:
                sess.run(self._ready_op)
                return None
            except errors.FailedPreconditionError as e:
                if "uninitialized" not in str(e):
                    logging.warning("Model not ready raised: %s", str(e))
                    raise e
                return str(e)
Example #29
  def _model_not_ready(self, sess):
    """Checks if the model is ready or not.

    Args:
      sess: A `Session`.

    Returns:
      `None` if the model is ready, a `String` with the reason why it is not
      ready otherwise.
    """
    if self._ready_op is None:
      return None
    else:
      try:
        sess.run(self._ready_op)
        return None
      except errors.FailedPreconditionError as e:
        if "uninitialized" not in str(e):
          logging.warning("Model not ready raised: %s", str(e))
          raise e
        return str(e)
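For context, a hedged sketch of one possible _ready_op for the check above, using tf.assert_variables_initialized() from the same era (the variable is a placeholder): running the op raises FailedPreconditionError containing "uninitialized" until every variable is initialized, which the method turns into a not-ready string.

import tensorflow as tf

v = tf.Variable(0, name='v')
ready_op = tf.assert_variables_initialized()
# Supervisor-style usage: sess.run(ready_op) fails with FailedPreconditionError
# while v is uninitialized, and succeeds after v has been initialized.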
Example #30
File: saver.py  Project: hdzz/tensorflow
  def _MaybeDeleteOldCheckpoints(self, latest_save_path,
                                 meta_graph_suffix="meta"):
    """Deletes old checkpoints if necessary.

    Always keep the last `max_to_keep` checkpoints.  If
    `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
    every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
    kept for every 0.5 hours of training; if `N` is 10, an additional
    checkpoint is kept for every 10 hours of training.

    Args:
      latest_save_path: Name including path of checkpoint file to save.
      meta_graph_suffix: Suffix for MetaGraphDef file. Defaults to 'meta'.
    """
    if not self.saver_def.max_to_keep:
      return
    # Remove first from list if the same name was used before.
    for p in self._last_checkpoints:
      if latest_save_path == self._CheckpointFilename(p):
        self._last_checkpoints.remove(p)
    # Append new path to list
    self._last_checkpoints.append((latest_save_path, time.time()))
    # If more than max_to_keep, remove oldest.
    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
      p = self._last_checkpoints.pop(0)
      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
      # have reached N hours of training.
      should_keep = p[1] > self._next_checkpoint_time
      if should_keep:
        self._next_checkpoint_time += (
            self.saver_def.keep_checkpoint_every_n_hours * 3600)
        return
      # Otherwise delete the files.
      for f in gfile.Glob(self._CheckpointFilename(p)):
        try:
          gfile.Remove(f)
          gfile.Remove(".".join([f, meta_graph_suffix]))
        except OSError as e:
          logging.warning("Ignoring: %s", str(e))
Example #31
def main(unused_argv=None):
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
    logging.info('TensorBoard is in debug mode.')

  if not FLAGS.logdir:
    logging.error('A logdir must be specified. Run `tensorboard --help` for '
                  'details and examples.')
    return -1

  if FLAGS.debug:
    logging.info('Starting TensorBoard in directory %s', os.getcwd())

  path_to_run = ParseEventFilesFlag(FLAGS.logdir)
  multiplexer = event_multiplexer.AutoloadingMultiplexer(
      path_to_run=path_to_run, interval_secs=60,
      size_guidance=TENSORBOARD_SIZE_GUIDANCE)

  multiplexer.AutoUpdate(interval=30)

  factory = functools.partial(tensorboard_handler.TensorboardHandler,
                              multiplexer)
  try:
    server = ThreadedHTTPServer((FLAGS.host, FLAGS.port), factory)
  except socket.error:
    logging.error('Tried to connect to port %d, but that address is in use.',
                  FLAGS.port)
    return -2
  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.warning('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
  print('(You can navigate to http://localhost:%d)' % FLAGS.port)
  server.serve_forever()
Example #32
def load_resource(path):
    """Load the resource at given path, where path is relative to tensorflow/.

  Args:
    path: a string resource path relative to tensorflow/.

  Returns:
    The contents of that resource.

  Raises:
    IOError: If the path is not found, or the resource can't be opened.
  """
    tensorflow_root = (os.path.join(os.path.dirname(__file__), os.pardir,
                                    os.pardir))
    path = os.path.join(tensorflow_root, path)
    path = os.path.abspath(path)
    try:
        with open(path, 'rb') as f:
            return f.read()
    except IOError as e:
        logging.warning('IOError %s on path %s', e, path)
        raise e
Example #33
File: saver.py  Project: hdzz/tensorflow
def _add_collection_def(meta_graph_def, key):
  """Adds a collection to MetaGraphDef protocol buffer.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  if not isinstance(key, (str, bytes, unicode)):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s" % type(key))
    return
  collection_list = ops.get_collection(key)
  if not collection_list:
    return
  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x)
        assert isinstance(proto, proto_type)
        getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        getattr(col_def, kind).value.extend([x.name for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s" % str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
Example #34
    def _MaybeDeleteOldCheckpoints(self, latest_save_path):
        """Deletes old checkpoints if necessary.

    Always keep the last `max_to_keep` checkpoints.  If
    `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
    every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
    kept for every 0.5 hours of training; if `N` is 10, an additional
    checkpoint is kept for every 10 hours of training.

    Args:
      latest_save_path: Name including path of checkpoint file to save.
    """
        if not self._max_to_keep:
            return
        # Remove first from list if the same name was used before.
        for p in self._last_checkpoints:
            if latest_save_path == self._CheckpointFilename(p):
                self._last_checkpoints.remove(p)
        # Append new path to list
        self._last_checkpoints.append((latest_save_path, time.time()))
        # If more than max_to_keep, remove oldest.
        if len(self._last_checkpoints) > self._max_to_keep:
            p = self._last_checkpoints.pop(0)
            # Do not delete the file if keep_checkpoint_every_n_hours is set
            # and we have reached N hours of training.
            should_keep = p[1] > self._next_checkpoint_time
            if should_keep:
                self._next_checkpoint_time += (
                    self._keep_checkpoint_every_n_hours * 3600)
                return
            # Otherwise delete the files.
            for f in gfile.Glob(self._CheckpointFilename(p)):
                try:
                    gfile.Remove(f)
                except OSError as e:
                    logging.warning("Ignoring: %s", str(e))
Example #35
def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
                          worker_device="/job:worker", merge_devices=True,
                          cluster=None, ps_ops=None):
  """Return a `device function` to use when building a Graph for replicas.

  Device Functions are used in `with tf.device(device_function):` statements to
  automatically assign devices to `Operation` objects as they are constructed.
  Device constraints are added from the inner-most context first, working
  outwards. The merging behavior adds constraints to fields that are yet unset
  by a more inner context. Currently the fields are (job, task, cpu/gpu).

  If `cluster` is `None`, and `ps_tasks` is 0, the returned function is a no-op.

  For example,

  ```python
  # To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker
  # jobs on hosts worker0, worker1 and worker2.
  cluster_spec = {
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
  with tf.device(tf.replica_device_setter(cluster=cluster_spec)):
    # Build your graph
    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
    v2 = tf.Variable(...)  # assigned to /job:ps/task:1
    v3 = tf.Variable(...)  # assigned to /job:ps/task:0
  # Run compute
  ```

  Args:
    ps_tasks: Number of tasks in the `ps` job.
    ps_device: String.  Device of the `ps` job.  If empty no `ps` job is used.
      Defaults to `ps`.
    worker_device: String.  Device of the `worker` job.  If empty no `worker`
      job is used.
    merge_devices: `Boolean`. If `True`, merges device specifications rather
      than overriding them, and only sets a device field if its constraint is
      completely unset.
    cluster: `ClusterDef` proto or `ClusterSpec`.
    ps_ops: List of `Operation` objects that need to be placed on `ps` devices.

  Returns:
    A function to pass to `tf.device()`.

  Raises:
    TypeError: If `cluster` is not a dictionary or `ClusterDef` protocol buffer.
  """
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_cluster_spec()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_cluster_spec()
    # Get ps_job_name from ps_device by stripping "/job:".
    ps_job_name = ps_device.lstrip("/job:")
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None
  else:
    if not merge_devices:
      logging.warning(
          "DEPRECATION: It is recommended to set merge_devices=true in "
          "replica_device_setter")
    chooser = _ReplicaDeviceChooser(
        ps_tasks, ps_device, worker_device, merge_devices, ps_ops)
    return chooser.device_function
Example #36
def replica_device_setter(ps_tasks=0,
                          ps_device="/job:ps",
                          worker_device="/job:worker",
                          merge_devices=True,
                          cluster=None,
                          ps_ops=None):
    """Return a `device function` to use when building a Graph for replicas.

  Device Functions are used in `with tf.device(device_function):` statements to
  automatically assign devices to `Operation` objects as they are constructed.
  Device constraints are added from the inner-most context first, working
  outwards. The merging behavior adds constraints to fields that are yet unset
  by a more inner context. Currently the fields are (job, task, cpu/gpu).

  If `cluster` is `None`, and `ps_tasks` is 0, the returned function is a no-op.

  For example,

  ```python
  # To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker
  # jobs on hosts worker0, worker1 and worker2.
  cluster_spec = {
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
  with tf.device(tf.replica_device_setter(cluster=cluster_spec)):
    # Build your graph
    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
    v2 = tf.Variable(...)  # assigned to /job:ps/task:1
    v3 = tf.Variable(...)  # assigned to /job:ps/task:0
  # Run compute
  ```

  Args:
    ps_tasks: Number of tasks in the `ps` job.
    ps_device: String.  Device of the `ps` job.  If empty no `ps` job is used.
      Defaults to `ps`.
    worker_device: String.  Device of the `worker` job.  If empty no `worker`
      job is used.
    merge_devices: `Boolean`. If `True`, merges device specifications rather
      than overriding them, and only sets a device field if its constraint is
      completely unset.
    cluster: `ClusterDef` proto or `ClusterSpec`.
    ps_ops: List of `Operation` objects that need to be placed on `ps` devices.

  Returns:
    A function to pass to `tf.device()`.

  Raises:
    TypeError: If `cluster` is not a dictionary or `ClusterDef` protocol buffer.
  """
    if cluster is not None:
        if isinstance(cluster, server_lib.ClusterSpec):
            cluster_spec = cluster.as_cluster_spec()
        else:
            cluster_spec = server_lib.ClusterSpec(cluster).as_cluster_spec()
        # Get ps_job_name from ps_device by stripping "/job:".
        ps_job_name = ps_device.lstrip("/job:")
        if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
            return None
        ps_tasks = len(cluster_spec[ps_job_name])

    if ps_tasks == 0:
        return None
    else:
        if not merge_devices:
            logging.warning(
                "DEPRECATION: It is recommended to set merge_devices=true in "
                "replica_device_setter")
        chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
                                        merge_devices, ps_ops)
        return chooser.device_function