Code example #1
    def test_audio(self):
        with tf.compat.v1.Graph().as_default():
            audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2))
            old_op = tf.compat.v1.summary.audio("k488", audio, 44100)
            old_value = self._value_from_op(old_op)
        assert old_value.HasField("audio"), old_value
        new_value = data_compat.migrate_value(old_value)

        self.assertEqual("k488/audio/0", new_value.tag)
        expected_metadata = audio_metadata.create_summary_metadata(
            display_name="k488/audio/0",
            description="",
            encoding=audio_metadata.Encoding.Value("WAV"),
            converted_to_tensor=True,
        )

        # Check serialized submessages...
        plugin_content = audio_metadata.parse_plugin_metadata(
            new_value.metadata.plugin_data.content)
        expected_content = audio_metadata.parse_plugin_metadata(
            expected_metadata.plugin_data.content)
        self.assertEqual(plugin_content, expected_content)
        # ...then check full metadata except plugin content, since
        # serialized forms need not be identical.
        new_value.metadata.plugin_data.content = (
            expected_metadata.plugin_data.content)
        self.assertEqual(expected_metadata, new_value.metadata)

        self.assertTrue(new_value.HasField("tensor"))
        data = tensor_util.make_ndarray(new_value.tensor)
        self.assertEqual((1, 2), data.shape)
        self.assertEqual(
            tf.compat.as_bytes(old_value.audio.encoded_audio_string),
            data[0][0])
        self.assertEqual(b"", data[0][1])  # empty label
Code example #2
    def send_requests(self, run_to_events):
        """Accepts a stream of TF events and sends batched write RPCs.

        Each sent request will be at most `_MAX_REQUEST_LENGTH_BYTES`
        bytes long.

        Args:
          run_to_events: Mapping from run name to generator of `tf.Event`
            values, as returned by `LogdirLoader.get_run_events`.

        Raises:
          RuntimeError: If no progress can be made because even a single
          point is too large (say, due to a gigabyte-long tag name).
        """

        for (run_name, event, orig_value) in self._run_values(run_to_events):
            value = data_compat.migrate_value(orig_value)
            time_series_key = (run_name, value.tag)

            # The metadata for a time series is memorized on the first event.
            # If later events arrive with a mismatching plugin_name, they are
            # ignored with a warning.
            metadata = self._tag_metadata.get(time_series_key)
            first_in_time_series = False
            if metadata is None:
                first_in_time_series = True
                metadata = value.metadata
                self._tag_metadata[time_series_key] = metadata

            plugin_name = metadata.plugin_data.plugin_name
            if value.HasField("metadata") and (
                plugin_name != value.metadata.plugin_data.plugin_name
            ):
                logger.warning(
                    "Mismatching plugin names for %s.  Expected %s, found %s.",
                    time_series_key,
                    metadata.plugin_data.plugin_name,
                    value.metadata.plugin_data.plugin_name,
                )
                continue
            if plugin_name not in self._allowed_plugins:
                if first_in_time_series:
                    logger.info(
                        "Skipping time series %r with unsupported plugin name %r",
                        time_series_key,
                        plugin_name,
                    )
                continue
            if plugin_name == scalar_metadata.PLUGIN_NAME:
                self._scalar_request_sender.add_event(
                    run_name, event, value, metadata
                )
            # TODO(nielsene): add Tensor plugin cases here
            # TODO(soergel): add Graphs blob case here

        self._scalar_request_sender.flush()
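A hedged sketch of how a caller might drive `send_requests`: the docstring says the events mapping comes from `LogdirLoader.get_run_events`, so a polling loop like the following is plausible (the `synchronize_runs` call and the loop structure are assumptions, not verified API):

import time

def upload_forever(logdir_loader, sender, poll_interval_secs=5):
    # Hypothetical driver loop: refresh the logdir view, then forward each
    # run's events to the sender, which batches them into write RPCs.
    while True:
        logdir_loader.synchronize_runs()  # assumed LogdirLoader method
        sender.send_requests(logdir_loader.get_run_events())
        time.sleep(poll_interval_secs)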
Code example #3
    def test_empty_histogram(self):
        with tf.compat.v1.Graph().as_default():
            old_op = tf.compat.v1.summary.histogram("empty_yet_important",
                                                    tf.constant([]))
            old_value = self._value_from_op(old_op)
        assert old_value.HasField("histo"), old_value
        new_value = data_compat.migrate_value(old_value)

        self.assertEqual("empty_yet_important", new_value.tag)
        expected_metadata = histogram_metadata.create_summary_metadata(
            display_name="empty_yet_important", description="")
        self.assertEqual(expected_metadata, new_value.metadata)
        self.assertTrue(new_value.HasField("tensor"))
        buckets = tensor_util.make_ndarray(new_value.tensor)
        self.assertEmpty(buckets)
Code example #4
  def test_histogram(self):
    old_op = tf.summary.histogram('important_data',
                                  tf.random_normal(shape=[23, 45]))
    old_value = self._value_from_op(old_op)
    assert old_value.HasField('histo'), old_value
    new_value = data_compat.migrate_value(old_value)

    self.assertEqual('important_data', new_value.tag)
    expected_metadata = histogram_metadata.create_summary_metadata(
        display_name='important_data', description='')
    self.assertEqual(expected_metadata, new_value.metadata)
    self.assertTrue(new_value.HasField('tensor'))
    buckets = tf.make_ndarray(new_value.tensor)
    self.assertEqual(old_value.histo.min, buckets[0][0])
    self.assertEqual(old_value.histo.max, buckets[-1][1])
    self.assertEqual(23 * 45, buckets[:, 2].astype(int).sum())
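The assertions above pin down the layout of the migrated histogram tensor: rank 2, one row per bucket, each row being `(left_edge, right_edge, count)`. An illustrative check with made-up bucket values:

import numpy as np

# Illustrative buckets; rows are (left_edge, right_edge, count).
buckets = np.array([
    [-3.0, -1.0, 130.0],
    [-1.0,  1.0, 750.0],
    [ 1.0,  3.0, 155.0],
])
assert buckets[0][0] == -3.0                    # old_value.histo.min
assert buckets[-1][1] == 3.0                    # old_value.histo.max
assert buckets[:, 2].astype(int).sum() == 1035  # 23 * 45 samples in total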
Code example #5
  def test_scalar(self):
    old_op = tf.summary.scalar('important_constants', tf.constant(0x5f3759df))
    old_value = self._value_from_op(old_op)
    assert old_value.HasField('simple_value'), old_value
    new_value = data_compat.migrate_value(old_value)

    self.assertEqual('important_constants', new_value.tag)
    expected_metadata = scalar_metadata.create_summary_metadata(
        display_name='important_constants',
        description='')
    self.assertEqual(expected_metadata, new_value.metadata)
    self.assertTrue(new_value.HasField('tensor'))
    data = tf.make_ndarray(new_value.tensor)
    self.assertEqual((), data.shape)
    low_precision_value = np.array(0x5f3759df).astype('float32').item()
    self.assertEqual(low_precision_value, data.item())
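The `low_precision_value` step exists because `0x5f3759df` (1,597,463,007) does not fit in float32's 24-bit significand, so the value comes back rounded from the scalar summary:

import numpy as np

assert 0x5f3759df == 1597463007
# float32 spacing near 2**30 is 2**(30 - 23) == 128, so the integer is
# rounded to the nearest multiple of 128:
assert np.float32(1597463007).item() == 1597463040.0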
Code example #6
    def test_scalar(self):
        with tf.compat.v1.Graph().as_default():
            old_op = tf.compat.v1.summary.scalar("important_constants",
                                                 tf.constant(0x5F3759DF))
            old_value = self._value_from_op(old_op)
        assert old_value.HasField("simple_value"), old_value
        new_value = data_compat.migrate_value(old_value)

        self.assertEqual("important_constants", new_value.tag)
        expected_metadata = scalar_metadata.create_summary_metadata(
            display_name="important_constants", description="")
        self.assertEqual(expected_metadata, new_value.metadata)
        self.assertTrue(new_value.HasField("tensor"))
        data = tensor_util.make_ndarray(new_value.tensor)
        self.assertEqual((), data.shape)
        low_precision_value = np.array(0x5F3759DF).astype("float32").item()
        self.assertEqual(low_precision_value, data.item())
Code example #7
    def test_histogram(self):
        with tf.compat.v1.Graph().as_default():
            old_op = tf.compat.v1.summary.histogram(
                "important_data", tf.random.normal(shape=[23, 45]))
            old_value = self._value_from_op(old_op)
        assert old_value.HasField("histo"), old_value
        new_value = data_compat.migrate_value(old_value)

        self.assertEqual("important_data", new_value.tag)
        expected_metadata = histogram_metadata.create_summary_metadata(
            display_name="important_data", description="")
        self.assertEqual(expected_metadata, new_value.metadata)
        self.assertTrue(new_value.HasField("tensor"))
        buckets = tensor_util.make_ndarray(new_value.tensor)
        self.assertEqual(old_value.histo.min, buckets[0][0])
        self.assertEqual(old_value.histo.max, buckets[-1][1])
        self.assertEqual(23 * 45, buckets[:, 2].astype(int).sum())
Code example #8
File: data_compat_test.py  Project: jlewi/tensorboard
  def test_image(self):
    old_op = tf.summary.image('mona_lisa',
                              tf.cast(tf.random_normal(shape=[1, 400, 200, 3]),
                                      tf.uint8))
    old_value = self._value_from_op(old_op)
    assert old_value.HasField('image'), old_value
    new_value = data_compat.migrate_value(old_value)

    self.assertEqual('mona_lisa/image/0', new_value.tag)
    expected_metadata = image_metadata.create_summary_metadata(
        display_name='mona_lisa/image/0', description='')
    self.assertEqual(expected_metadata, new_value.metadata)
    self.assertTrue(new_value.HasField('tensor'))
    (width, height, data) = tf.make_ndarray(new_value.tensor)
    self.assertEqual(b'200', width)
    self.assertEqual(b'400', height)
    self.assertEqual(
        tf.compat.as_bytes(old_value.image.encoded_image_string), data)
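As the unpacking shows, a migrated image value is a rank-1 string tensor laid out as `[width, height, encoded_image]`, with the dimensions serialized as ASCII bytes. A tiny illustration with placeholder bytes:

import numpy as np

# Illustrative migrated image tensor (string dtype).
migrated = np.array([b'200', b'400', b'<encoded image bytes>'], dtype=object)
width, height, data = migrated
assert (width, height) == (b'200', b'400')  # dimensions as ASCII bytes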
Code example #9
    def test_image(self):
        with tf.compat.v1.Graph().as_default():
            old_op = tf.compat.v1.summary.image(
                "mona_lisa",
                tf.image.convert_image_dtype(
                    tf.random.normal(shape=[1, 400, 200, 3]),
                    tf.uint8,
                    saturate=True,
                ),
            )
            old_value = self._value_from_op(old_op)
        assert old_value.HasField("image"), old_value
        new_value = data_compat.migrate_value(old_value)

        self.assertEqual("mona_lisa/image/0", new_value.tag)
        expected_metadata = image_metadata.create_summary_metadata(
            display_name="mona_lisa/image/0",
            description="",
            converted_to_tensor=True,
        )

        # Check serialized submessages...
        plugin_content = image_metadata.parse_plugin_metadata(
            new_value.metadata.plugin_data.content
        )
        expected_content = image_metadata.parse_plugin_metadata(
            expected_metadata.plugin_data.content
        )
        self.assertEqual(plugin_content, expected_content)
        # ...then check full metadata except plugin content, since
        # serialized forms need not be identical.
        new_value.metadata.plugin_data.content = (
            expected_metadata.plugin_data.content
        )
        self.assertEqual(expected_metadata, new_value.metadata)

        self.assertTrue(new_value.HasField("tensor"))
        (width, height, data) = tensor_util.make_ndarray(new_value.tensor)
        self.assertEqual(b"200", width)
        self.assertEqual(b"400", height)
        self.assertEqual(
            tf.compat.as_bytes(old_value.image.encoded_image_string), data
        )
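The parse-compare-then-patch dance above (also used in the audio test in code example #1) exists because equal protos need not serialize to identical bytes. A sketch of the pattern factored into a reusable helper:

def assert_metadata_equal(test, expected, actual, parse_plugin_metadata):
    """Sketch: compare SummaryMetadata while ignoring plugin-content bytes."""
    test.assertEqual(
        parse_plugin_metadata(expected.plugin_data.content),
        parse_plugin_metadata(actual.plugin_data.content))
    # Normalize the serialized content so the full-message comparison
    # does not depend on serialization details.
    actual.plugin_data.content = expected.plugin_data.content
    test.assertEqual(expected, actual)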
Code example #10
  def test_audio(self):
    audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2))
    old_op = tf.summary.audio('k488', audio, 44100)
    old_value = self._value_from_op(old_op)
    assert old_value.HasField('audio'), old_value
    new_value = data_compat.migrate_value(old_value)

    self.assertEqual('k488/audio/0', new_value.tag)
    expected_metadata = audio_metadata.create_summary_metadata(
        display_name='k488/audio/0',
        description='',
        encoding=audio_metadata.Encoding.Value('WAV'))
    self.assertEqual(expected_metadata, new_value.metadata)
    self.assertTrue(new_value.HasField('tensor'))
    data = tf.make_ndarray(new_value.tensor)
    self.assertEqual((1, 2), data.shape)
    self.assertEqual(tf.compat.as_bytes(old_value.audio.encoded_audio_string),
                     data[0][0])
    self.assertEqual(b'', data[0][1])  # empty label
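A migrated audio value is a string tensor of shape `(num_clips, 2)`: each row pairs an encoded clip with a (here empty) label. A tiny illustration with placeholder bytes:

import numpy as np

# Illustrative migrated audio tensor of shape (num_clips, 2).
migrated = np.array([[b'<encoded wav bytes>', b'']], dtype=object)
assert migrated.shape == (1, 2)
encoded_audio, label = migrated[0]
assert label == b''  # migrated legacy audio summaries carry no labels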
Code example #11
    def test_single_value_histogram(self):
        with tf.compat.v1.Graph().as_default():
            old_op = tf.compat.v1.summary.histogram("single_value_data",
                                                    tf.constant([1] * 1024))
            old_value = self._value_from_op(old_op)
        assert old_value.HasField("histo"), old_value
        new_value = data_compat.migrate_value(old_value)

        self.assertEqual("single_value_data", new_value.tag)
        expected_metadata = histogram_metadata.create_summary_metadata(
            display_name="single_value_data", description="")
        self.assertEqual(expected_metadata, new_value.metadata)
        self.assertTrue(new_value.HasField("tensor"))
        buckets = tensor_util.make_ndarray(new_value.tensor)
        # Only one bucket is kept.
        self.assertEqual((1, 3), buckets.shape)
        self.assertEqual(1, buckets[0][0])
        self.assertEqual(1, buckets[-1][1])
        self.assertEqual(1024, buckets[0][2])
Code example #12
    def test_histogram_with_extremal_values(self):
        with tf.compat.v1.Graph().as_default():
            old_op = tf.compat.v1.summary.histogram("extremal_values",
                                                    tf.constant([-1e20, 1e20]))
            old_value = self._value_from_op(old_op)
        assert old_value.HasField("histo"), old_value
        new_value = data_compat.migrate_value(old_value)

        self.assertEqual("extremal_values", new_value.tag)
        expected_metadata = histogram_metadata.create_summary_metadata(
            display_name="extremal_values", description="")
        self.assertEqual(expected_metadata, new_value.metadata)
        self.assertTrue(new_value.HasField("tensor"))
        buckets = tensor_util.make_ndarray(new_value.tensor)
        for bucket in buckets:
            # No `backwards` buckets.
            self.assertLessEqual(bucket[0], bucket[1])
        self.assertEqual(old_value.histo.min, buckets[0][0])
        self.assertEqual(old_value.histo.max, buckets[-1][1])
        self.assertEqual(2, buckets[:, 2].astype(int).sum())
Code example #13
  def _process_event(self, event, tagged_data):
    """Processes a single tf.Event and records it in tagged_data."""
    event_type = event.WhichOneof('what')
    # Handle the most common case first.
    if event_type == 'summary':
      for value in event.summary.value:
        value = data_compat.migrate_value(value)
        tag, metadata, values = tagged_data.get(value.tag, (None, None, []))
        values.append((event.step, event.wall_time, value.tensor))
        if tag is None:
          # Store metadata only from the first event.
          tagged_data[value.tag] = sqlite_writer.TagData(
              value.tag, value.metadata, values)
    elif event_type == 'file_version':
      pass  # TODO: reject file version < 2 (at loader level)
    elif event_type == 'session_log':
      if event.session_log.status == event_pb2.SessionLog.START:
        pass  # TODO: implement purging via sqlite writer truncation method
    elif event_type in ('graph_def', 'meta_graph_def'):
      pass  # TODO: support graphs
    elif event_type == 'tagged_run_metadata':
      pass  # TODO: support run metadata
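The `tagged_data.get(value.tag, (None, None, []))` idiom works because the default list is stored inside the `TagData` entry on the first event and then mutated in place by every later `append`. A minimal standalone illustration of that aliasing:

tagged_data = {}

def record(tag_name, point):
    tag, metadata, values = tagged_data.get(tag_name, (None, None, []))
    values.append(point)
    if tag is None:
        # Stores the *same* list object, so later appends stay visible.
        tagged_data[tag_name] = (tag_name, 'metadata', values)

record('loss', (0, 0.9))
record('loss', (1, 0.5))
assert tagged_data['loss'][2] == [(0, 0.9), (1, 0.5)]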
Code example #14
  def _process_event(self, event, tagged_data):
    """Processes a single tf.Event and records it in tagged_data."""
    event_type = event.WhichOneof('what')
    # Handle the most common case first.
    if event_type == 'summary':
      for value in event.summary.value:
        value = data_compat.migrate_value(value)
        tag, metadata, values = tagged_data.get(value.tag, (None, None, []))
        values.append((event.step, event.wall_time, value.tensor))
        if tag is None:
          # Store metadata only from the first event.
          tagged_data[value.tag] = sqlite_writer.TagData(
              value.tag, value.metadata, values)
    elif event_type == 'file_version':
      pass  # TODO: reject file version < 2 (at loader level)
    elif event_type == 'session_log':
      if event.session_log.status == tf.SessionLog.START:
        pass  # TODO: implement purging via sqlite writer truncation method
    elif event_type in ('graph_def', 'meta_graph_def'):
      pass  # TODO: support graphs
    elif event_type == 'tagged_run_metadata':
      pass  # TODO: support run metadata
Code example #15
    def test_audio(self):
        with tf.compat.v1.Graph().as_default():
            audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2))
            old_op = tf.compat.v1.summary.audio("k488", audio, 44100)
            old_value = self._value_from_op(old_op)
        assert old_value.HasField("audio"), old_value
        new_value = data_compat.migrate_value(old_value)

        self.assertEqual("k488/audio/0", new_value.tag)
        expected_metadata = audio_metadata.create_summary_metadata(
            display_name="k488/audio/0",
            description="",
            encoding=audio_metadata.Encoding.Value("WAV"),
        )
        self.assertEqual(expected_metadata, new_value.metadata)
        self.assertTrue(new_value.HasField("tensor"))
        data = tensor_util.make_ndarray(new_value.tensor)
        self.assertEqual((1, 2), data.shape)
        self.assertEqual(
            tf.compat.as_bytes(old_value.audio.encoded_audio_string),
            data[0][0])
        self.assertEqual(b"", data[0][1])  # empty label
Code example #16
    def test_image(self):
        with tf.compat.v1.Graph().as_default():
            old_op = tf.compat.v1.summary.image(
                "mona_lisa",
                tf.image.convert_image_dtype(
                    tf.random.normal(shape=[1, 400, 200, 3]),
                    tf.uint8,
                    saturate=True,
                ),
            )
            old_value = self._value_from_op(old_op)
        assert old_value.HasField("image"), old_value
        new_value = data_compat.migrate_value(old_value)

        self.assertEqual("mona_lisa/image/0", new_value.tag)
        expected_metadata = image_metadata.create_summary_metadata(
            display_name="mona_lisa/image/0", description="")
        self.assertEqual(expected_metadata, new_value.metadata)
        self.assertTrue(new_value.HasField("tensor"))
        (width, height, data) = tensor_util.make_ndarray(new_value.tensor)
        self.assertEqual(b"200", width)
        self.assertEqual(b"400", height)
        self.assertEqual(
            tf.compat.as_bytes(old_value.image.encoded_image_string), data)
Code example #17
    def _ProcessEvent(self, event):
        """Called whenever an event is loaded."""
        if self._first_event_timestamp is None:
            self._first_event_timestamp = event.wall_time

        if event.HasField('file_version'):
            new_file_version = _ParseFileVersion(event.file_version)
            if self.file_version and self.file_version != new_file_version:
                ## This should not happen.
                tf.logging.warn(
                    ('Found new file_version for event.proto. This will '
                     'affect purging logic for TensorFlow restarts. '
                     'Old: {0} New: {1}').format(self.file_version,
                                                 new_file_version))
            self.file_version = new_file_version

        self._MaybePurgeOrphanedData(event)

        ## Process the event.
        # GraphDef and MetaGraphDef are handled in a special way:
        # If no graph_def Event is available, but a meta_graph_def is, and it
        # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
        # If a graph_def Event is available, always prefer it to the graph_def
        # inside the meta_graph_def.
        if event.HasField('graph_def'):
            if self._graph is not None:
                tf.logging.warn(
                    ('Found more than one graph event per run, or there was '
                     'a metagraph containing a graph_def, as well as one or '
                     'more graph events.  Overwriting the graph with the '
                     'newest event.'))
            self._graph = event.graph_def
            self._graph_from_metagraph = False
        elif event.HasField('meta_graph_def'):
            if self._meta_graph is not None:
                tf.logging.warn(
                    ('Found more than one metagraph event per run. '
                     'Overwriting the metagraph with the newest event.'))
            self._meta_graph = event.meta_graph_def
            if self._graph is None or self._graph_from_metagraph:
                # We may have a graph_def in the metagraph.  If so, and no
                # graph_def is directly available, use this one instead.
                meta_graph = tf.MetaGraphDef()
                meta_graph.ParseFromString(self._meta_graph)
                if meta_graph.graph_def:
                    if self._graph is not None:
                        tf.logging.warn((
                            'Found multiple metagraphs containing graph_defs, '
                            'but did not find any graph events.  Overwriting the '
                            'graph with the newest metagraph version.'))
                    self._graph_from_metagraph = True
                    self._graph = meta_graph.graph_def.SerializeToString()
        elif event.HasField('tagged_run_metadata'):
            tag = event.tagged_run_metadata.tag
            if tag in self._tagged_metadata:
                tf.logging.warn(
                    'Found more than one "run metadata" event with tag ' +
                    tag + '. Overwriting it with the newest event.')
            self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
        elif event.HasField('summary'):
            for value in event.summary.value:
                value = data_compat.migrate_value(value)

                if value.HasField('metadata'):
                    tag = value.tag
                    # We only store the first instance of the metadata. This check
                    # is important: the `FileWriter` does strip metadata from all
                    # values except the first one for each tag, but a new
                    # `FileWriter` is created every time a training job stops and
                    # restarts. Hence, we must also ignore non-initial metadata in
                    # this logic.
                    if tag not in self.summary_metadata:
                        self.summary_metadata[tag] = value.metadata
                        plugin_data = value.metadata.plugin_data
                        if plugin_data.plugin_name:
                            self._plugin_to_tag_to_content[
                                plugin_data.plugin_name][tag] = (
                                    plugin_data.content)
                        else:
                            tf.logging.warn((
                                'This summary with tag %r is oddly not associated with a '
                                'plugin.'), tag)

                for summary_type, summary_func in SUMMARY_TYPES.items():
                    if value.HasField(summary_type):
                        datum = getattr(value, summary_type)
                        tag = value.tag
                        if summary_type == 'tensor' and not tag:
                            # This tensor summary was created using the old method that used
                            # plugin assets. We must still continue to support it.
                            tag = value.node_name
                        getattr(self, summary_func)(tag, event.wall_time,
                                                    event.step, datum)
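The final loop dispatches through `SUMMARY_TYPES`, a module-level mapping from `Summary.Value` field names to the names of accumulator methods. Its contents are not shown in this listing; a plausible shape, stated as an assumption:

# Hypothetical shape of SUMMARY_TYPES (field name -> handler method name);
# the real mapping is defined alongside _ProcessEvent.
SUMMARY_TYPES = {
    'simple_value': '_ProcessScalar',
    'histo': '_ProcessHistogram',
    'image': '_ProcessImage',
    'audio': '_ProcessAudio',
    'tensor': '_ProcessTensor',
}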
Code example #18
  def _assert_noop(self, value):
    original_pbtxt = value.SerializeToString()
    result = data_compat.migrate_value(value)
    self.assertEqual(value, result)
    self.assertEqual(original_pbtxt, value.SerializeToString())
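A hedged usage sketch: a value that already carries a `tensor` field is already in the new format, so `migrate_value` should leave it untouched, which also makes migration idempotent (the proto construction below is an assumption):

# Sketch: an already-migrated value should round-trip unchanged.
value = summary_pb2.Summary.Value(
    tag='loss',
    tensor=tensor_util.make_tensor_proto(0.25),  # assumed helper signature
)
self._assert_noop(value)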
Code example #19
File: uploader.py  Project: zimaxeg/tensorboard
    def build_requests(self, run_to_events):
        """Converts a stream of TF events to a stream of outgoing requests.

        Each yielded request will be at most `_MAX_REQUEST_LENGTH_BYTES`
        bytes long.

        Args:
          run_to_events: Mapping from run name to generator of `tf.Event`
            values, as returned by `LogdirLoader.get_run_events`.

        Yields:
          A finite stream of `WriteScalarRequest` objects.

        Raises:
          RuntimeError: If no progress can be made because even a single
          point is too large (say, due to a gigabyte-long tag name).
        """

        self._new_request()
        runs = {}  # cache: map from run name to `Run` proto in request
        tags = {}  # cache: map from `(run, tag)` to `Tag` proto in run in request
        work_items = peekable_iterator.PeekableIterator(
            self._run_values(run_to_events))

        while work_items.has_next():
            (run_name, event, orig_value) = work_items.peek()
            value = data_compat.migrate_value(orig_value)
            time_series_key = (run_name, value.tag)

            metadata = self._tag_metadata.get(time_series_key)
            if metadata is None:
                plugin_name = value.metadata.plugin_data.plugin_name
                if plugin_name == scalar_metadata.PLUGIN_NAME:
                    metadata = value.metadata
                else:
                    metadata = _RequestBuilder._NON_SCALAR_TIME_SERIES
                self._tag_metadata[time_series_key] = metadata
            if metadata is _RequestBuilder._NON_SCALAR_TIME_SERIES:
                next(work_items)
                continue
            try:
                run_proto = runs.get(run_name)
                if run_proto is None:
                    run_proto = self._create_run(run_name)
                    runs[run_name] = run_proto
                tag_proto = tags.get((run_name, value.tag))
                if tag_proto is None:
                    tag_proto = self._create_tag(run_proto, value.tag,
                                                 metadata)
                    tags[(run_name, value.tag)] = tag_proto
                self._create_point(tag_proto, event, value)
                next(work_items)
            except _OutOfSpaceError:
                # Flush request and start a new one.
                request_to_emit = self._prune_request()
                if request_to_emit is None:
                    raise RuntimeError(
                        "Could not make progress uploading data")
                self._new_request()
                runs.clear()
                tags.clear()
                yield request_to_emit

        final_request = self._prune_request()
        if final_request is not None:
            yield final_request
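A hedged sketch of consuming the generator: each yielded `WriteScalarRequest` would be handed to the write service, with the stub name and call style being assumptions rather than verified API:

# Hypothetical consumer: send each request as it is built, so at most one
# request of up to `_MAX_REQUEST_LENGTH_BYTES` is held in memory at a time.
for request in builder.build_requests(logdir_loader.get_run_events()):
    write_service_stub.WriteScalar(request)  # assumed gRPC stub method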