コード例 #1
0
ファイル: writer.py プロジェクト: zghzdxs/pytorch
class FileWriter(object):
    """Writes protocol buffers to event files to be consumed by TensorBoard.

    The `FileWriter` class provides a mechanism to create an event file in a
    given directory and add summaries and events to it. The class updates the
    file contents asynchronously. This allows a training program to call methods
    to add data to the file directly from the training loop, without slowing down
    training.
    """
    def __init__(self,
                 log_dir,
                 max_queue=10,
                 flush_secs=120,
                 filename_suffix=''):
        """Creates a `FileWriter` and an event file.

        On construction the writer creates a new event file in `log_dir`.
        The other arguments to the constructor control the asynchronous writes to
        the event file.

        Args:
          log_dir: A string, or an object (e.g. a `pathlib.Path`) whose `str()`
            is the directory where the event file will be written.
          max_queue: Integer. Size of the queue for pending events and
            summaries before one of the 'add' calls forces a flush to disk.
            Default is ten items.
          flush_secs: Number. How often, in seconds, to flush the
            pending events and summaries to disk. Default is every two minutes.
          filename_suffix: A string. Suffix added to all event filenames
            in the log_dir directory. More details on filename construction in
            tensorboard.summary.writer.event_file_writer.EventFileWriter.
        """
        # Sometimes PosixPath is passed in and we need to coerce it to
        # a string in all cases
        # TODO: See if we can remove this in the future if we are
        # actually the ones passing in a PosixPath
        log_dir = str(log_dir)
        self.event_writer = EventFileWriter(log_dir, max_queue, flush_secs,
                                            filename_suffix)

    def get_logdir(self):
        """Returns the directory where event file will be written."""
        return self.event_writer.get_logdir()

    def add_event(self, event, step=None, walltime=None):
        """Adds an event to the event file.

        The event is modified in place before it is handed to the underlying
        event writer: its `wall_time` is always set, and its `step` is set
        only when `step` is given.

        Args:
          event: An `Event` protocol buffer.
          step: Number. Optional global step value for training process
            to record with the event. Converted with `int()` before storing.
          walltime: float. Optional walltime to override the default (current)
            walltime (from time.time()) seconds after epoch
        """
        event.wall_time = time.time() if walltime is None else walltime
        if step is not None:
            # Make sure step is converted from numpy or other formats
            # since protobuf might not convert depending on version
            event.step = int(step)
        self.event_writer.add_event(event)

    def add_summary(self, summary, global_step=None, walltime=None):
        """Adds a `Summary` protocol buffer to the event file.

        This method wraps the provided summary in an `Event` protocol buffer
        and adds it to the event file.

        Args:
          summary: A `Summary` protocol buffer.
          global_step: Number. Optional global step value for training process
            to record with the summary.
          walltime: float. Optional walltime to override the default (current)
            walltime (from time.time()) seconds after epoch
        """
        event = event_pb2.Event(summary=summary)
        self.add_event(event, global_step, walltime)

    def add_graph(self, graph_profile, walltime=None):
        """Adds a `Graph` and step stats protocol buffer to the event file.

        Two events are written. The first carries the serialized graph in
        `graph_def`. The second carries the serialized step stats in a
        `TaggedRunMetadata` tagged 'step1'. Neither event has a step.

        Args:
          graph_profile: A sequence whose first item is a `Graph` protocol
            buffer and whose second item is a step stats protocol buffer.
          walltime: float. Optional walltime to override the default (current)
            walltime (from time.time()) seconds after epoch
        """
        graph, stepstats = graph_profile[0], graph_profile[1]
        event = event_pb2.Event(graph_def=graph.SerializeToString())
        self.add_event(event, None, walltime)

        # The step stats are recorded as a separate run-metadata event.
        trm = event_pb2.TaggedRunMetadata(
            tag='step1', run_metadata=stepstats.SerializeToString())
        event = event_pb2.Event(tagged_run_metadata=trm)
        self.add_event(event, None, walltime)

    def add_onnx_graph(self, graph, walltime=None):
        """Adds a `Graph` protocol buffer to the event file.

        Args:
          graph: A `Graph` protocol buffer. It is serialized into the event's
            `graph_def` field.
          walltime: float. Optional walltime to override the default (current)
            walltime (from time.time()) seconds after epoch
        """
        event = event_pb2.Event(graph_def=graph.SerializeToString())
        self.add_event(event, None, walltime)

    def flush(self):
        """Flushes the event file to disk.

        Call this method to make sure that all pending events have been written to
        disk.
        """
        self.event_writer.flush()

    def close(self):
        """Flushes the event file to disk and closes the file.

        Call this method when you do not need the summary writer anymore.
        """
        self.event_writer.close()

    def reopen(self):
        """Reopens the EventFileWriter.

        Can be called after `close()` to add more events in the same directory.
        The events will go into a new events file.
        Does nothing if the EventFileWriter was not closed.
        """
        self.event_writer.reopen()
コード例 #2
0
class FileWriter(object):

    r"""Write protocol buffers to event files.

    Args:
        log_dir (str): Directory where event file will be written. Non-string
            values (e.g. a ``pathlib.Path``) are converted with ``str()``.
        max_queue (int, optional): Size of the queue for pending events and summaries before one of the 'add' calls
            forces a flush to disk. Defaults to 10.
        flush_secs (int, optional): How often, in seconds, to flush the pending events and summaries to disk. Defaults
            to every two minutes (120s).
        filename_suffix (str, optional): Suffix added to all event filenames in the log_dir directory.
    """

    def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix=''):
        # Path objects may be passed in; coerce to a plain string before
        # handing the directory to the underlying EventFileWriter.
        log_dir = str(log_dir)
        self.event_writer = EventFileWriter(
            log_dir, max_queue, flush_secs, filename_suffix)

    def get_logdir(self):
        r"""Returns the directory where event file will be written."""
        return self.event_writer.get_logdir()

    def add_event(self, event, step=None, walltime=None):
        r"""Adds an event to the event file.

        The event is modified in place: its ``wall_time`` is always set, and
        its ``step`` is set only when ``step`` is given.

        Args:
            event: An `Event` protocol buffer.
            step (int, optional): Optional global step value for training process to record with the
                event. Converted with ``int()`` before storing.
            walltime (float, optional): Optional walltime to override the default (current) walltime
                (from time.time()) seconds after epoch.
        """
        event.wall_time = time.time() if walltime is None else walltime
        if step is not None:
            # Convert numpy scalars and other numeric types explicitly, since
            # protobuf may not accept them depending on its version.
            event.step = int(step)
        self.event_writer.add_event(event)

    def add_summary(self, summary, global_step=None, walltime=None):
        r"""Adds a `Summary` protocol buffer to the event file.

        The summary is wrapped in an `Event` protocol buffer before being added.

        Args:
            summary: A `Summary` protocol buffer.
            global_step (int, optional): Optional global step value for training process to record
                with the summary.
            walltime (float, optional): Optional walltime to override the default (current) walltime
                (from time.time()) seconds after epoch.
        """
        event = event_pb2.Event(summary=summary)
        self.add_event(event, global_step, walltime)

    def add_graph(self, graph_profile, walltime=None):
        r"""Adds a `Graph` and step stats protocol buffer to the event file.

        Two events are written. The first carries the serialized graph in
        ``graph_def``. The second carries the serialized step stats in a
        ``TaggedRunMetadata`` tagged ``'step1'``.

        Args:
            graph_profile: A sequence whose first item is a `Graph` protocol buffer and whose
                second item is a step stats protocol buffer.
            walltime (float, optional): Optional walltime to override the default (current) walltime
                (from time.time()) seconds after epoch.
        """
        graph, stepstats = graph_profile[0], graph_profile[1]
        event = event_pb2.Event(graph_def=graph.SerializeToString())
        self.add_event(event, None, walltime)

        # The step stats are recorded as a separate run-metadata event.
        trm = event_pb2.TaggedRunMetadata(
            tag='step1', run_metadata=stepstats.SerializeToString())

        event = event_pb2.Event(tagged_run_metadata=trm)
        self.add_event(event, None, walltime)

    def flush(self):
        r"""Flushes the event file to disk.

        Call this to make sure all pending events have been written to disk.
        """
        self.event_writer.flush()

    def close(self):
        r"""Flushes the event file to disk and closes the file.

        Call this when the writer is no longer needed.
        """
        self.event_writer.close()

    def reopen(self):
        r"""Reopens the EventFileWriter.

        Can be called after ``close()`` to add more events in the same directory.
        """
        self.event_writer.reopen()