def test_peek_after_end(self):
    it = peekable_iterator.PeekableIterator([1, 2, 3])
    self.assertEqual(list(it), [1, 2, 3])
    with self.assertRaises(StopIteration):
        it.peek()
    with self.assertRaises(StopIteration):
        it.peek()

def test_simple_peek(self):
    it = peekable_iterator.PeekableIterator([1, 2, 3])
    self.assertEqual(it.peek(), 1)
    self.assertEqual(it.peek(), 1)
    self.assertEqual(next(it), 1)
    self.assertEqual(it.peek(), 2)
    self.assertEqual(next(it), 2)
    self.assertEqual(next(it), 3)
    self.assertEqual(list(it), [])

def test_simple_has_next(self):
    it = peekable_iterator.PeekableIterator([1, 2])
    self.assertTrue(it.has_next())
    self.assertEqual(it.peek(), 1)
    self.assertTrue(it.has_next())
    self.assertEqual(next(it), 1)
    self.assertEqual(it.peek(), 2)
    self.assertTrue(it.has_next())
    self.assertEqual(next(it), 2)
    self.assertFalse(it.has_next())
    self.assertFalse(it.has_next())
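These tests pin down three behaviors: peek() returns the upcoming value without consuming it, has_next() reports whether a value remains, and peeking past the end raises StopIteration. The peekable_iterator module itself is not shown on this page; a minimal sketch of a class that would satisfy these tests, using a one-slot sentinel buffer (an assumption, not the actual implementation), is:

class PeekableIterator:
    """Iterator wrapper supporting peek() and has_next().

    A minimal sketch matching the semantics exercised by the tests
    above; the real `peekable_iterator` module may differ.
    """

    _SENTINEL = object()  # marks "no value currently buffered"

    def __init__(self, iterable):
        self._iterator = iter(iterable)
        self._buffer = self._SENTINEL

    def __iter__(self):
        return self

    def __next__(self):
        # Serve the buffered value first, if peek() pulled one in.
        if self._buffer is not self._SENTINEL:
            value, self._buffer = self._buffer, self._SENTINEL
            return value
        return next(self._iterator)

    def peek(self):
        # Buffer one value; raises StopIteration if the source is exhausted.
        if self._buffer is self._SENTINEL:
            self._buffer = next(self._iterator)
        return self._buffer

    def has_next(self):
        try:
            self.peek()
        except StopIteration:
            return False
        return True

The sentinel lets the wrapper hold at most one value, so repeated peek() calls are idempotent until __next__() consumes the buffered item.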
Example 4
    def build_requests(self, run_to_events):
        """Converts a stream of TF events to a stream of outgoing requests.

        Each yielded request will be at most `_MAX_REQUEST_LENGTH_BYTES`
        bytes long.

        Args:
          run_to_events: Mapping from run name to generator of `tf.Event`
            values, as returned by `LogdirLoader.get_run_events`.

        Yields:
          A finite stream of `WriteScalarRequest` objects.

        Raises:
          RuntimeError: If no progress can be made because even a single
            point is too large (say, due to a gigabyte-long tag name).
        """

        self._new_request()
        runs = {}  # cache: map from run name to `Run` proto in request
        tags = {}  # cache: map from `(run, tag)` to `Tag` proto in run in request
        work_items = peekable_iterator.PeekableIterator(
            self._run_values(run_to_events))

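        # Peek-before-consume: a work item is consumed with next() only
        # after its point has been added to the current request, so an item
        # that overflows the request is retried against the next, empty one.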
        while work_items.has_next():
            (run_name, event, orig_value) = work_items.peek()
            value = data_compat.migrate_value(orig_value)
            time_series_key = (run_name, value.tag)

            metadata = self._tag_metadata.get(time_series_key)
            if metadata is None:
                plugin_name = value.metadata.plugin_data.plugin_name
                if plugin_name == scalar_metadata.PLUGIN_NAME:
                    metadata = value.metadata
                else:
                    metadata = _RequestBuilder._NON_SCALAR_TIME_SERIES
                self._tag_metadata[time_series_key] = metadata
            if metadata is _RequestBuilder._NON_SCALAR_TIME_SERIES:
                next(work_items)
                continue
            try:
                run_proto = runs.get(run_name)
                if run_proto is None:
                    run_proto = self._create_run(run_name)
                    runs[run_name] = run_proto
                tag_proto = tags.get((run_name, value.tag))
                if tag_proto is None:
                    tag_proto = self._create_tag(run_proto, value.tag,
                                                 metadata)
                    tags[(run_name, value.tag)] = tag_proto
                self._create_point(tag_proto, event, value)
                next(work_items)
            except _OutOfSpaceError:
                # Flush request and start a new one.
                request_to_emit = self._prune_request()
                if request_to_emit is None:
                    raise RuntimeError(
                        "Could not make progress uploading data")
                self._new_request()
                runs.clear()
                tags.clear()
                yield request_to_emit

        final_request = self._prune_request()
        if final_request is not None:
            yield final_request
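A hypothetical driver for build_requests, assuming a builder object exposing the method above, a logdir_loader matching the docstring, and a stub with a WriteScalar RPC (all three names are assumptions for illustration):

def upload_all(builder, logdir_loader, stub):
    # Hypothetical glue code: `builder`, `logdir_loader`, and
    # `stub.WriteScalar` are assumed names, not confirmed API.
    for request in builder.build_requests(logdir_loader.get_run_events()):
        # Each yielded request already fits within _MAX_REQUEST_LENGTH_BYTES.
        stub.WriteScalar(request)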
def test_normal_iteration(self):
    it = peekable_iterator.PeekableIterator([1, 2, 3])
    self.assertEqual(list(it), [1, 2, 3])

def test_empty_iteration(self):
    it = peekable_iterator.PeekableIterator([])
    self.assertEqual(list(it), [])