def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(
      checkpoint_dir, latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if gfile.Exists(coord_checkpoint_filename):
      f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
      ckpt = CheckpointState()
      text_format.Merge(f.read(), ckpt)
  except gfile.FileError:
    # It's ok if the file cannot be read
    return None
  except text_format.ParseError as e:
    logging.warning(str(e))
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt

def _MaybeDeleteOldCheckpoints(self, latest_save_path):
  """Deletes old checkpoints if necessary.

  Always keep the last max_to_keep checkpoints.  If
  keep_checkpoint_every_n_hours was specified, keep an additional checkpoint
  every N hours. For example, if N is 0.5, an additional checkpoint is kept
  for every 0.5 hours of training; if N is 10, an additional checkpoint is
  kept for every 10 hours of training.

  Args:
    latest_save_path: Name including path of checkpoint file to save.
  """
  if not self._max_to_keep:
    return
  # Remove first from list if the same name was used before.
  for p in self._last_checkpoints:
    if latest_save_path == self._CheckpointFilename(p):
      self._last_checkpoints.remove(p)
  # Append new path to list
  self._last_checkpoints.append((latest_save_path, time.time()))
  # If more than max_to_keep, remove oldest.
  if len(self._last_checkpoints) > self._max_to_keep:
    p = self._last_checkpoints.pop(0)
    # Do not delete the file if keep_checkpoint_every_n_hours is set and we
    # have reached N hours of training.
    should_keep = p[1] > self._next_checkpoint_time
    if should_keep:
      self._next_checkpoint_time += self._keep_checkpoint_every_n_hours * 3600
      return
    # Otherwise delete the files.
    for f in gfile.Glob(self._CheckpointFilename(p)):
      try:
        gfile.Remove(f)
      except gfile.GOSError as e:
        logging.warning("Ignoring: %s", str(e))

def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: list of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Raises:
    RuntimeError: If the save paths conflict.
  """
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []
  if all_model_checkpoint_paths and all_model_checkpoint_paths[-1] != model_checkpoint_path:
    logging.warning(
        "%s is not in all_model_checkpoint_paths! Manually adding it.",
        model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  # Relative paths need to be rewritten to be relative to the "save_dir".
  if not os.path.isabs(model_checkpoint_path):
    model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    all_model_checkpoint_paths = [
        os.path.relpath(p, save_dir)
        for p in all_model_checkpoint_paths if not os.path.isabs(p)
    ]
  if coord_checkpoint_filename == model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)
  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths)
  f = gfile.FastGFile(coord_checkpoint_filename, mode="w")
  f.write(text_format.MessageToString(coord_checkpoint_proto))
  f.close()

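# A minimal usage sketch, not part of the original source: assuming this
# helper is exposed as `tf.train.update_checkpoint_state`, a caller can point
# the "checkpoint" file at checkpoints it wrote itself. The directory and
# file names below are hypothetical.
import tensorflow as tf

tf.train.update_checkpoint_state(
    save_dir="/tmp/train_logs",
    model_checkpoint_path="/tmp/train_logs/model.ckpt-2000",
    all_model_checkpoint_paths=["/tmp/train_logs/model.ckpt-1000",
                                "/tmp/train_logs/model.ckpt-2000"])
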
def _default_global_step_tensor(self):
  try:
    gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
    if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
      return gs
    else:
      logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
      return None
  except KeyError:
    return None

def add_graph(self, graph, global_step=None, graph_def=None):
  """Adds a `Graph` to the event file.

  The graph described by the protocol buffer will be displayed by
  TensorBoard. Most users pass a graph in the constructor instead.

  Args:
    graph: A `Graph` object, such as `sess.graph`.
    global_step: Number. Optional global step counter to record with the
      graph.
    graph_def: DEPRECATED. Use the `graph` parameter instead.

  Raises:
    ValueError: If both graph and graph_def are passed to the method.
  """
  if graph is not None and graph_def is not None:
    raise ValueError("Please pass only graph, or graph_def (deprecated), "
                     "but not both.")

  if isinstance(graph, ops.Graph) or isinstance(graph_def, ops.Graph):
    # The user passed a `Graph`.

    # Check if the user passed it via the graph or the graph_def argument and
    # correct for that.
    if not isinstance(graph, ops.Graph):
      logging.warning("When passing a `Graph` object, please use the `graph`"
                      " named argument instead of `graph_def`.")
      graph = graph_def

    # Serialize the graph with additional info.
    true_graph_def = graph.as_graph_def(add_shapes=True)
  elif (isinstance(graph, graph_pb2.GraphDef) or
        isinstance(graph_def, graph_pb2.GraphDef)):
    # The user passed a `GraphDef`.
    logging.warning("Passing a `GraphDef` to the SummaryWriter is deprecated."
                    " Pass a `Graph` object instead, such as `sess.graph`.")

    # Check if the user passed it via the graph or the graph_def argument and
    # correct for that.
    if isinstance(graph, graph_pb2.GraphDef):
      true_graph_def = graph
    else:
      true_graph_def = graph_def
  else:
    # The user passed neither `Graph`, nor `GraphDef`.
    raise TypeError("The passed graph must be an instance of `Graph` "
                    "or the deprecated `GraphDef`")
  # Finally, add the graph_def to the summary writer.
  self._add_graph_def(true_graph_def, global_step)

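# A minimal usage sketch, not part of the original source: assuming the
# method above belongs to the pre-1.0 `tf.train.SummaryWriter`, the graph is
# usually passed to the constructor, but `add_graph` can also be called
# explicitly after construction. The log directory below is hypothetical.
import tensorflow as tf

sess = tf.Session()
writer = tf.train.SummaryWriter("/tmp/train_logs")
writer.add_graph(sess.graph, global_step=0)
writer.flush()
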
def start_standard_services(self, sess):
  """Start the standard services for 'sess'.

  This starts services in the background.  The services started depend
  on the parameters to the constructor and may include:

    - A Summary thread computing summaries every save_summaries_secs.
    - A Checkpoint thread saving the model every save_model_secs.
    - A StepCounter thread measuring step time.

  Args:
    sess: A Session.

  Returns:
    A list of threads that are running the standard services.  You can use
    the Supervisor's Coordinator to join these threads with:
      sv.coord.Join(<list of threads>)

  Raises:
    RuntimeError: If called with a non-chief Supervisor.
    ValueError: If no `logdir` was passed to the constructor as the services
      need a log directory.
  """
  if not self._is_chief:
    raise RuntimeError("Only chief supervisor can start standard services, "
                       "because only chief supervisors can write events.")
  if not self._logdir:
    logging.warning("Standard services need a 'logdir' "
                    "passed to the SessionManager")
    return

  if self._global_step is not None and self._summary_writer:
    # Only add the session log if we keep track of global step.
    # TensorBoard cannot use START message for purging expired events
    # if there is no step value.
    current_step = training_util.global_step(sess, self._global_step)
    self._summary_writer.add_session_log(
        SessionLog(status=SessionLog.START), current_step)

  threads = []
  if self._save_summaries_secs and self._summary_writer:
    if self._summary_op is not None:
      threads.append(SVSummaryThread(self, sess))
    if self._global_step is not None:
      threads.append(SVStepCounterThread(self, sess))
  if self.saver and self._save_model_secs:
    threads.append(SVTimerCheckpointThread(self, sess))
  for t in threads:
    t.start()
  self._started_threads.extend(threads)

  return threads

def main(unused_argv=None):
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
    logging.info('TensorBoard is in debug mode.')

  if not FLAGS.logdir:
    logging.error('A logdir must be specified. Run `tensorboard --help` for '
                  'details and examples.')
    return -1

  logging.info('Starting TensorBoard in directory %s', os.getcwd())
  path_to_run = ParseEventFilesFlag(FLAGS.logdir)
  logging.info('TensorBoard path_to_run is: %s', path_to_run)
  multiplexer = event_multiplexer.EventMultiplexer(
      size_guidance=TENSORBOARD_SIZE_GUIDANCE)

  def _Load():
    start = time.time()
    for (path, name) in six.iteritems(path_to_run):
      multiplexer.AddRunsFromDirectory(path, name)
    multiplexer.Reload()
    duration = time.time() - start
    logging.info('Multiplexer done loading. Load took %0.1f secs', duration)
    t = threading.Timer(LOAD_INTERVAL, _Load)
    t.daemon = True
    t.start()

  t = threading.Timer(0, _Load)
  t.daemon = True
  t.start()

  factory = functools.partial(tensorboard_handler.TensorboardHandler,
                              multiplexer)
  try:
    server = ThreadedHTTPServer((FLAGS.host, FLAGS.port), factory)
  except socket.error:
    logging.error('Tried to connect to port %d, but that address is in use.',
                  FLAGS.port)
    return -2

  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.warning('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
  print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
  server.serve_forever()

def main(unused_argv=None):
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
    logging.info('TensorBoard is in debug mode.')

  if not FLAGS.logdir:
    logging.error('A logdir must be specified. Run `tensorboard --help` for '
                  'details and examples.')
    return -1

  logging.info('Starting TensorBoard in directory %s', os.getcwd())
  path_to_run = ParseEventFilesFlag(FLAGS.logdir)
  logging.info('TensorBoard path_to_run is: %s', path_to_run)
  multiplexer = event_multiplexer.EventMultiplexer(
      size_guidance=TENSORBOARD_SIZE_GUIDANCE)
  # Ensure the Multiplexer initializes in a loaded state before it adds runs,
  # so it can handle HTTP requests while runs are loading.
  multiplexer.Reload()

  def _Load():
    start = time.time()
    for (path, name) in six.iteritems(path_to_run):
      multiplexer.AddRunsFromDirectory(path, name)
    multiplexer.Reload()
    duration = time.time() - start
    logging.info('Multiplexer done loading. Load took %0.1f secs', duration)
    t = threading.Timer(LOAD_INTERVAL, _Load)
    t.daemon = True
    t.start()

  t = threading.Timer(0, _Load)
  t.daemon = True
  t.start()

  factory = functools.partial(tensorboard_handler.TensorboardHandler,
                              multiplexer)
  try:
    server = ThreadedHTTPServer((FLAGS.host, FLAGS.port), factory)
  except socket.error:
    logging.error('Tried to connect to port %d, but that address is in use.',
                  FLAGS.port)
    return -2

  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.warning('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
  print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
  server.serve_forever()

def _default_global_step_tensor(self):
  """Returns the global_step from the default graph.

  Returns:
    The global step `Tensor` or `None`.
  """
  try:
    gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
    if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
      return gs
    else:
      logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
      return None
  except KeyError:
    return None

def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if gfile.Exists(coord_checkpoint_filename):
      f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
      ckpt = CheckpointState()
      text_format.Merge(f.read(), ckpt)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(checkpoint_dir):
        if not os.path.isabs(ckpt.model_checkpoint_path):
          ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                    ckpt.model_checkpoint_path)
        for i in range(len(ckpt.all_model_checkpoint_paths)):
          p = ckpt.all_model_checkpoint_paths[i]
          if not os.path.isabs(p):
            ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except IOError:
    # It's ok if the file cannot be read
    return None
  except text_format.ParseError as e:
    logging.warning(str(e))
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt

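# A minimal usage sketch, not part of the original source: assuming the
# function above is exposed as `tf.train.get_checkpoint_state`, callers
# typically use it to locate the most recent checkpoint before restoring.
# The directory below is hypothetical.
import tensorflow as tf

ckpt = tf.train.get_checkpoint_state("/tmp/train_logs")
if ckpt and ckpt.model_checkpoint_path:
  print("Latest checkpoint:", ckpt.model_checkpoint_path)
else:
  print("No checkpoint file found.")
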
def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto."""
  # Args:
  #   v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
  #   arg_name: String, for error messages.

  # Returns:
  #   A TensorShapeProto.
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    for d in v.dim:
      if d.name:
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
        break
    return v
  return tensor_shape.as_shape(v).as_proto()

def main(unused_argv=None):
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
    logging.info('TensorBoard is in debug mode.')

  if not FLAGS.logdir:
    msg = ('A logdir must be specified. Run `tensorboard --help` for '
           'details and examples.')
    logging.error(msg)
    print(msg)
    return -1

  logging.info('Starting TensorBoard in directory %s', os.getcwd())
  path_to_run = server.ParseEventFilesSpec(FLAGS.logdir)
  logging.info('TensorBoard path_to_run is: %s', path_to_run)

  multiplexer = event_multiplexer.EventMultiplexer(
      size_guidance=server.TENSORBOARD_SIZE_GUIDANCE,
      purge_orphaned_data=FLAGS.purge_orphaned_data)
  server.StartMultiplexerReloadingThread(multiplexer, path_to_run,
                                         FLAGS.reload_interval)
  try:
    tb_server = server.BuildServer(multiplexer, FLAGS.host, FLAGS.port)
  except socket.error:
    if FLAGS.port == 0:
      msg = 'Unable to find any open ports.'
      logging.error(msg)
      print(msg)
      return -2
    else:
      msg = 'Tried to connect to port %d, but address is in use.' % FLAGS.port
      logging.error(msg)
      print(msg)
      return -3

  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.warning('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
  print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
  tb_server.serve_forever()

def _show_compute(self, show_dataflow):
  """Visualize the computation activity."""
  for dev_stats in self._step_stats.dev_stats:
    device_pid = self._device_pids[dev_stats.device]
    for node_stats in dev_stats.node_stats:
      tid = node_stats.thread_id
      start_time = node_stats.all_start_micros
      end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
      _, _, inputs = self._parse_op_label(node_stats.timeline_label)
      self._emit_op(node_stats, device_pid)
      for input_name in inputs:
        if input_name not in self._tensors:
          # This can happen when partitioning has inserted a Send/Recv.
          # We remove the numeric suffix so that the dataflow appears to
          # come from the original node.  Ideally, the StepStats would
          # contain logging for the Send and Recv nodes.
          index = input_name.rfind('/_')
          if index > 0:
            input_name = input_name[:index]

        if input_name in self._tensors:
          tensor = self._tensors[input_name]
          tensor.add_ref(start_time)
          tensor.add_unref(end_time - 1)

          if show_dataflow:
            # We use a different flow ID for every graph edge.
            create_time, create_pid, create_tid = self._flow_starts[input_name]
            # Don't add flows when producer and consumer ops are on the same
            # pid/tid since the horizontal arrows clutter the visualization.
            if create_pid != device_pid or create_tid != tid:
              flow_id = self._alloc_flow_id()
              self._chrome_trace.emit_flow_start(input_name, create_time,
                                                 create_pid, create_tid,
                                                 flow_id)
              self._chrome_trace.emit_flow_end(input_name, start_time,
                                               device_pid, tid, flow_id)
        else:
          logging.warning('Can\'t find tensor %s', input_name)

def _add_collection_def(meta_graph_def, key):
  """Adds a collection to MetaGraphDef protocol buffer.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s" % type(key))
    return
  collection_list = ops.get_collection(key)
  if not collection_list:
    return
  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x)
        assert isinstance(proto, proto_type)
        getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        getattr(col_def, kind).value.extend([x.name for x in collection_list])
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s" % (key, str(e)))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return

def load_resource(path):
  """Load the resource at given path, where path is relative to tensorflow/.

  Args:
    path: a string resource path relative to tensorflow/.

  Returns:
    The contents of that resource.

  Raises:
    IOError: If the path is not found, or the resource can't be opened.
  """
  path = os.path.join('tensorflow', path)
  path = os.path.abspath(path)
  try:
    with open(path, 'rb') as f:
      return f.read()
  except IOError as e:
    logging.warning('IOError %s on path %s' % (e, path))

def load_resource(path):
  """Load the resource at given path, where path is relative to tensorflow/.

  Args:
    path: a string resource path relative to tensorflow/.

  Returns:
    The contents of that resource.

  Raises:
    IOError: If the path is not found, or the resource can't be opened.
  """
  path = os.path.join('tensorflow', path)
  path = os.path.abspath(path)
  try:
    with open(path, 'rb') as f:
      return f.read()
  except IOError as e:
    logging.warning('IOError %s on path %s', e, path)

def AddRun(self, path, name=None):
  """Add a run to the multiplexer.

  If the name is not specified, it is the same as the path.

  If a run by that name exists, and we are already watching the right path,
  do nothing. If we are watching a different path, replace the event
  accumulator.

  If `AutoUpdate` or `Reload` have been called, it will `AutoUpdate` or
  `Reload` the newly created accumulators. This maintains the invariant that
  once the Multiplexer was activated, all of its accumulators are active.

  Args:
    path: Path to the event files (or event directory) for given run.
    name: Name of the run to add. If not provided, is set to path.

  Returns:
    The `EventMultiplexer`.
  """
  if name is None or name == '':
    name = path
  accumulator = None
  with self._accumulators_mutex:
    if name not in self._accumulators or self._paths[name] != path:
      if name in self._paths and self._paths[name] != path:
        # TODO(danmane) - Make it impossible to overwrite an old path with
        # a new path (just give the new path a distinct name)
        logging.warning('Conflict for name %s: old path %s, new path %s',
                        name, self._paths[name], path)
      logging.info('Constructing EventAccumulator for %s', path)
      accumulator = event_accumulator.EventAccumulator(path,
                                                       self._size_guidance)
      self._accumulators[name] = accumulator
      self._paths[name] = path
  if accumulator:
    if self._reload_called:
      accumulator.Reload()
    if self._autoupdate_called:
      accumulator.AutoUpdate(self._autoupdate_interval)
  return self

def main(unused_argv=None):
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
    logging.info('TensorBoard is in debug mode.')

  if not FLAGS.logdir:
    logging.error('A logdir must be specified. Run `tensorboard --help` for '
                  'details and examples.')
    return -1

  if FLAGS.debug:
    logging.info('Starting TensorBoard in directory %s', os.getcwd())

  path_to_run = ParseEventFilesFlag(FLAGS.logdir)
  multiplexer = event_multiplexer.AutoloadingMultiplexer(
      path_to_run=path_to_run, interval_secs=60,
      size_guidance=TENSORBOARD_SIZE_GUIDANCE)

  multiplexer.AutoUpdate(interval=30)

  factory = functools.partial(tensorboard_handler.TensorboardHandler,
                              multiplexer)
  try:
    server = ThreadedHTTPServer((FLAGS.host, FLAGS.port), factory)
  except socket.error:
    logging.error('Tried to connect to port %d, but that address is in use.',
                  FLAGS.port)
    return -2

  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.warning('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
  print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
  server.serve_forever()

def _model_not_ready(self, sess):
  """Checks if the model is ready or not.

  Args:
    sess: A `Session`.

  Returns:
    `None` if the model is ready, a `String` with the reason why it is not
    ready otherwise.
  """
  if self._ready_op is None:
    return None
  else:
    try:
      sess.run(self._ready_op)
      return None
    except errors.FailedPreconditionError as e:
      if "uninitialized" not in str(e):
        logging.warning("Model not ready raised: %s", str(e))
        raise e
      return str(e)

def _MaybeDeleteOldCheckpoints(self, latest_save_path,
                               meta_graph_suffix="meta"):
  """Deletes old checkpoints if necessary.

  Always keep the last `max_to_keep` checkpoints.  If
  `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
  every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
  kept for every 0.5 hours of training; if `N` is 10, an additional
  checkpoint is kept for every 10 hours of training.

  Args:
    latest_save_path: Name including path of checkpoint file to save.
    meta_graph_suffix: Suffix for MetaGraphDef file. Defaults to 'meta'.
  """
  if not self.saver_def.max_to_keep:
    return
  # Remove first from list if the same name was used before.
  for p in self._last_checkpoints:
    if latest_save_path == self._CheckpointFilename(p):
      self._last_checkpoints.remove(p)
  # Append new path to list
  self._last_checkpoints.append((latest_save_path, time.time()))
  # If more than max_to_keep, remove oldest.
  if len(self._last_checkpoints) > self.saver_def.max_to_keep:
    p = self._last_checkpoints.pop(0)
    # Do not delete the file if keep_checkpoint_every_n_hours is set and we
    # have reached N hours of training.
    should_keep = p[1] > self._next_checkpoint_time
    if should_keep:
      self._next_checkpoint_time += (
          self.saver_def.keep_checkpoint_every_n_hours * 3600)
      return
    # Otherwise delete the files.
    for f in gfile.Glob(self._CheckpointFilename(p)):
      try:
        gfile.Remove(f)
        gfile.Remove(".".join([f, meta_graph_suffix]))
      except OSError as e:
        logging.warning("Ignoring: %s", str(e))

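# A minimal sketch, not part of the original source, of how the retention
# policy above is configured from user code: `max_to_keep` and
# `keep_checkpoint_every_n_hours` are constructor arguments of
# `tf.train.Saver`, and (in this version) `save()` invokes
# `_MaybeDeleteOldCheckpoints` after writing each checkpoint. The variable,
# session, and paths below are hypothetical.
import tensorflow as tf

v = tf.Variable(0, name="counter")
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2.0)
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  saver.save(sess, "/tmp/train_logs/model.ckpt", global_step=0)
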
def main(unused_argv=None):
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
    logging.info('TensorBoard is in debug mode.')

  if not FLAGS.logdir:
    logging.error('A logdir must be specified. Run `tensorboard --help` for '
                  'details and examples.')
    return -1

  if FLAGS.debug:
    logging.info('Starting TensorBoard in directory %s', os.getcwd())

  path_to_run = ParseEventFilesFlag(FLAGS.logdir)
  multiplexer = event_multiplexer.AutoloadingMultiplexer(
      path_to_run=path_to_run, interval_secs=60,
      size_guidance=TENSORBOARD_SIZE_GUIDANCE)

  multiplexer.AutoUpdate(interval=30)

  factory = functools.partial(tensorboard_handler.TensorboardHandler,
                              multiplexer)
  try:
    server = ThreadedHTTPServer((FLAGS.host, FLAGS.port), factory)
  except socket.error:
    logging.error('Tried to connect to port %d, but that address is in use.',
                  FLAGS.port)
    return -2

  try:
    tag = resource_loader.load_resource('tensorboard/TAG').strip()
    logging.info('TensorBoard is tag: %s', tag)
  except IOError:
    logging.warning('Unable to read TensorBoard tag')
    tag = ''

  status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
  print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
  print('(You can navigate to http://localhost:%d)' % FLAGS.port)
  server.serve_forever()

def load_resource(path):
  """Load the resource at given path, where path is relative to tensorflow/.

  Args:
    path: a string resource path relative to tensorflow/.

  Returns:
    The contents of that resource.

  Raises:
    IOError: If the path is not found, or the resource can't be opened.
  """
  tensorflow_root = (os.path.join(os.path.dirname(__file__),
                                  os.pardir, os.pardir))
  path = os.path.join(tensorflow_root, path)
  path = os.path.abspath(path)
  try:
    with open(path, 'rb') as f:
      return f.read()
  except IOError as e:
    logging.warning('IOError %s on path %s', e, path)
    raise e

def _add_collection_def(meta_graph_def, key):
  """Adds a collection to MetaGraphDef protocol buffer.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  # NOTE: `unicode` only exists in Python 2; this is the Python 2 variant of
  # this function.
  if not isinstance(key, (str, bytes, unicode)):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s" % type(key))
    return
  collection_list = ops.get_collection(key)
  if not collection_list:
    return
  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x)
        assert isinstance(proto, proto_type)
        getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        getattr(col_def, kind).value.extend([x.name for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s" % str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return

def _MaybeDeleteOldCheckpoints(self, latest_save_path):
  """Deletes old checkpoints if necessary.

  Always keep the last `max_to_keep` checkpoints.  If
  `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
  every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
  kept for every 0.5 hours of training; if `N` is 10, an additional
  checkpoint is kept for every 10 hours of training.

  Args:
    latest_save_path: Name including path of checkpoint file to save.
  """
  if not self._max_to_keep:
    return
  # Remove first from list if the same name was used before.
  for p in self._last_checkpoints:
    if latest_save_path == self._CheckpointFilename(p):
      self._last_checkpoints.remove(p)
  # Append new path to list
  self._last_checkpoints.append((latest_save_path, time.time()))
  # If more than max_to_keep, remove oldest.
  if len(self._last_checkpoints) > self._max_to_keep:
    p = self._last_checkpoints.pop(0)
    # Do not delete the file if keep_checkpoint_every_n_hours is set and we
    # have reached N hours of training.
    should_keep = p[1] > self._next_checkpoint_time
    if should_keep:
      self._next_checkpoint_time += (
          self._keep_checkpoint_every_n_hours * 3600)
      return
    # Otherwise delete the files.
    for f in gfile.Glob(self._CheckpointFilename(p)):
      try:
        gfile.Remove(f)
      except OSError as e:
        logging.warning("Ignoring: %s", str(e))

def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
                          worker_device="/job:worker", merge_devices=True,
                          cluster=None, ps_ops=None):
  """Return a `device function` to use when building a Graph for replicas.

  Device Functions are used in `with tf.device(device_function):` statements
  to automatically assign devices to `Operation` objects as they are
  constructed. Device constraints are added from the inner-most context
  first, working outwards. The merging behavior adds constraints to fields
  that are yet unset by a more inner context. Currently the fields are
  (job, task, cpu/gpu).

  If `cluster` is `None`, and `ps_tasks` is 0, the returned function is a
  no-op.

  For example,

  ```python
  # To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker
  # jobs on hosts worker0, worker1 and worker2.
  cluster_spec = {
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
  with tf.device(tf.replica_device_setter(cluster=cluster_spec)):
    # Build your graph
    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
    v2 = tf.Variable(...)  # assigned to /job:ps/task:1
    v3 = tf.Variable(...)  # assigned to /job:ps/task:0
  # Run compute
  ```

  Args:
    ps_tasks: Number of tasks in the `ps` job.
    ps_device: String.  Device of the `ps` job.  If empty no `ps` job is used.
      Defaults to `ps`.
    worker_device: String.  Device of the `worker` job.  If empty no `worker`
      job is used.
    merge_devices: `Boolean`. If `True`, merges or only sets a device if the
      device constraint is completely unset. Merges device specifications
      rather than overriding them.
    cluster: `ClusterDef` proto or `ClusterSpec`.
    ps_ops: List of `Operation` objects that need to be placed on `ps` devices.

  Returns:
    A function to pass to `tf.device()`.

  Raises:
    TypeError if `cluster` is not a dictionary or `ClusterDef` protocol buffer.
  """
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_cluster_spec()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_cluster_spec()
    # Get ps_job_name from ps_device by stripping "/job:".
    ps_job_name = ps_device.lstrip("/job:")
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None
  else:
    if not merge_devices:
      logging.warning(
          "DEPRECATION: It is recommended to set merge_devices=true in "
          "replica_device_setter")
    chooser = _ReplicaDeviceChooser(
        ps_tasks, ps_device, worker_device, merge_devices, ps_ops)
    return chooser.device_function