Example #1
# Standard-library imports used below. Beam-specific names (beam, timestamp,
# CacheManager, SafeFastPrimitivesCoder, StreamingCacheSink, StreamingCacheSource,
# TestStreamPayload, TestStreamFileHeader, TestStreamFileRecord, ...) are assumed
# to be imported elsewhere in the original module.
import os
import shutil
import tempfile
from collections import OrderedDict


class StreamingCache(CacheManager):
    """Abstraction that holds the logic for reading and writing to cache.
  """
    def __init__(self,
                 cache_dir,
                 is_cache_complete=None,
                 sample_resolution_sec=0.1):
        self._sample_resolution_sec = sample_resolution_sec
        self._is_cache_complete = is_cache_complete

        if cache_dir:
            self._cache_dir = cache_dir
        else:
            self._cache_dir = tempfile.mkdtemp(prefix='interactive-temp-',
                                               dir=os.environ.get(
                                                   'TEST_TMPDIR', None))

        # List of saved pcoders keyed by PCollection path. It is OK to keep this
        # list in memory because once the cache manager object is
        # destroyed/re-created it loses access to previously written cache
        # objects anyway, even if cache_dir still exists. In other words,
        # it is not possible to resume execution of a Beam pipeline from the
        # saved cache once the cache manager has been reset.
        #
        # However, to implement better cache persistence, one would need to
        # keep the cached PCollection consistent with its PCoder type.
        self._saved_pcoders = {}
        self._default_pcoder = SafeFastPrimitivesCoder()

        # The sinks to capture data from capturable sources.
        # Dict([str, StreamingCacheSink])
        self._capture_sinks = {}
        self._capture_keys = set()

    def size(self, *labels):
        if self.exists(*labels):
            return os.path.getsize(os.path.join(self._cache_dir, *labels))
        return 0

    @property
    def capture_size(self):
        return sum(
            sink.size_in_bytes for sink in self._capture_sinks.values())

    @property
    def capture_paths(self):
        return list(self._capture_sinks.keys())

    @property
    def capture_keys(self):
        return self._capture_keys

    def exists(self, *labels):
        path = os.path.join(self._cache_dir, *labels)
        return os.path.exists(path)

    # TODO(srohde): Modify this to return the correct version.
    def read(self, *labels, **args):
        """Returns a generator to read all records from file."""
        tail = args.pop('tail', False)

        # Return immediately only when the file doesn't exist and the user wants
        # a snapshot of the cache (i.e. tail is False).
        if not self.exists(*labels) and not tail:
            return iter([]), -1

        reader = StreamingCacheSource(self._cache_dir, labels,
                                      self._is_cache_complete).read(tail=tail)

        # Return an empty iterator if there is nothing in the file yet. This can
        # only happen when tail is False.
        try:
            header = next(reader)
        except StopIteration:
            return iter([]), -1
        return StreamingCache.Reader([header], [reader]).read(), 1
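
    # A minimal usage sketch for read() ('full_label' is a hypothetical label;
    # it assumes something was previously written under that label):
    #
    #   events, version = cache.read('full_label')             # snapshot, tail=False
    #   events, version = cache.read('full_label', tail=True)  # keep following the file
    #
    # With tail=False the returned generator is a point-in-time snapshot of the
    # file; with tail=True the underlying StreamingCacheSource keeps yielding
    # records until the is_cache_complete callback reports the cache as done.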

    def read_multiple(self, labels, tail=True):
        """Returns a generator to read all records from file.

    Does tail until the cache is complete. This is because it is used in the
    TestStreamServiceController to read from file which is only used during
    pipeline runtime which needs to block.
    """
        readers = [
            StreamingCacheSource(self._cache_dir, l,
                                 self._is_cache_complete).read(tail=tail)
            for l in labels
        ]
        headers = [next(r) for r in readers]
        return StreamingCache.Reader(headers, readers).read()

    def write(self, values, *labels):
        """Writes the given values to cache.
    """
        directory = os.path.join(self._cache_dir, *labels[:-1])
        filepath = os.path.join(directory, labels[-1])
        os.makedirs(directory, exist_ok=True)
        with open(filepath, 'ab') as f:
            for v in values:
                if isinstance(v, (TestStreamFileHeader, TestStreamFileRecord)):
                    val = v.SerializeToString()
                else:
                    val = v
                f.write(self._default_pcoder.encode(val) + b'\n')
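
    # A hedged sketch of what write() produces (the label and proto variables
    # here are hypothetical): each value is serialized (TestStreamFileHeader and
    # TestStreamFileRecord protos via SerializeToString, anything else passed
    # through as-is), encoded with the default pcoder, and appended to the file
    # as one newline-terminated line.
    #
    #   cache.write([header_proto], 'full_label')   # header first
    #   cache.write([record_proto], 'full_label')   # then records, appended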

    def clear(self, *labels):
        directory = os.path.join(self._cache_dir, *labels[:-1])
        filepath = os.path.join(directory, labels[-1])
        self._capture_keys.discard(labels[-1])
        if os.path.exists(filepath):
            os.remove(filepath)
            return True
        return False

    def source(self, *labels):
        """Returns the StreamingCacheManager source.

    This is beam.Impulse() because unbounded sources will be marked with this
    and then the PipelineInstrument will replace these with a TestStream.
    """
        return beam.Impulse()

    def sink(self, labels, is_capture=False):
        """Returns a StreamingCacheSink to write elements to file.

    Note that this is assumed to only work with the DirectRunner, as the
    underlying StreamingCacheSink relies on a single machine to guarantee
    correct element ordering.
    """
        filename = labels[-1]
        cache_dir = os.path.join(self._cache_dir, *labels[:-1])
        sink = StreamingCacheSink(cache_dir, filename,
                                  self._sample_resolution_sec)
        if is_capture:
            self._capture_sinks[sink.path] = sink
            self._capture_keys.add(filename)
        return sink
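
    # Illustrative usage (hypothetical label): writing through a capture sink
    # registers it with this cache, which is what capture_size, capture_paths
    # and capture_keys report on.
    #
    #   sink = cache.sink(['full_label'], is_capture=True)
    #   # ... pcoll | sink, inside a pipeline ...
    #   assert 'full_label' in cache.capture_keys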

    def save_pcoder(self, pcoder, *labels):
        self._saved_pcoders[os.path.join(*labels)] = pcoder

    def load_pcoder(self, *labels):
        return (self._default_pcoder if self._default_pcoder is not None else
                self._saved_pcoders[os.path.join(*labels)])

    def cleanup(self):
        if os.path.exists(self._cache_dir):
            shutil.rmtree(self._cache_dir)
        self._saved_pcoders = {}
        self._capture_sinks = {}
        self._capture_keys = set()

    class Reader(object):
        """Abstraction that reads from PCollection readers.

    This class is an Abstraction layer over multiple PCollection readers to be
    used for supplying a TestStream service with events.

    This class is also responsible for holding the state of the clock, injecting
    clock advancement events, and watermark advancement events.
    """
        def __init__(self, headers, readers):
            # This timestamp is used as the monotonic clock to order events in the
            # replay.
            self._monotonic_clock = timestamp.Timestamp.of(0)

            # The file headers that are metadata for that particular PCollection.
            # The header allows for metadata about an entire stream, so that the data
            # isn't copied per record.
            self._headers = {header.tag: header for header in headers}

            # The PCollection cache readers, keyed by tag.
            self._readers = OrderedDict(
                (h.tag, r) for h, r in zip(headers, readers))

            # The most recently read timestamp per tag.
            self._stream_times = {
                tag: timestamp.Timestamp(seconds=0)
                for tag in self._headers
            }

        def _test_stream_events_before_target(self, target_timestamp):
            """Reads the next iteration of elements from each stream.

      Retrieves an element from each stream iff the most recently read timestamp
      from that stream is less than the target_timestamp. Since the amount of
      events may not fit into memory, this StreamingCache reads at most one
      element from each stream at a time.
      """
            records = []
            for tag, r in self._readers.items():
                # The target_timestamp is the timestamp that the readers need to catch
                # up to. Some readers may still have elements before it. Thus, we skip
                # all readers whose most recently read element is already at or past
                # this timestamp so that we don't read everything into memory.
                if self._stream_times[tag] >= target_timestamp:
                    continue
                try:
                    record = next(r).recorded_event
                    if record.HasField('processing_time_event'):
                        self._stream_times[tag] += timestamp.Duration(
                            micros=record.processing_time_event.advance_duration)
                    records.append((tag, record, self._stream_times[tag]))
                except StopIteration:
                    pass
            return records
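
        # For example (illustrative numbers): with last-read stream times of
        # {'a': 3s, 'b': 1s} and target_timestamp = 3s, only stream 'b' is read
        # on this pass; stream 'a' is already caught up, so its next element
        # stays on disk instead of being pulled into memory.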

        def _merge_sort(self, previous_events, new_events):
            return sorted(previous_events + new_events,
                          key=lambda x: x[2],
                          reverse=True)

        def _min_timestamp_of(self, events):
            return events[-1][2] if events else timestamp.MAX_TIMESTAMP

        def _event_stream_caught_up_to_target(self, events, target_timestamp):
            empty_events = not events
            stream_is_past_target = self._min_timestamp_of(
                events) > target_timestamp
            return empty_events or stream_is_past_target

        def read(self):
            """Reads records from PCollection readers.
      """

            # The largest timestamp read from the different streams.
            target_timestamp = timestamp.MAX_TIMESTAMP

            # The events from last iteration that are past the target timestamp.
            unsent_events = []

            # Emit events until all events have been read.
            while True:
                # Read the next set of events. The read events will most likely be
                # out of order if there are multiple readers. Here we sort them into
                # a more manageable state.
                new_events = self._test_stream_events_before_target(
                    target_timestamp)
                events_to_send = self._merge_sort(unsent_events, new_events)
                if not events_to_send:
                    break

                # Get the next largest timestamp in the stream. This is used as the
                # timestamp for readers to "catch-up" to. This will only read from
                # readers with a timestamp less than this.
                target_timestamp = self._min_timestamp_of(events_to_send)

                # Loop through the elements with the correct timestamp.
                while not self._event_stream_caught_up_to_target(
                        events_to_send, target_timestamp):

                    # First advance the clock to match the time of the stream. This has
                    # a side-effect of also advancing this cache's clock.
                    tag, record, curr_timestamp = events_to_send.pop()
                    if curr_timestamp > self._monotonic_clock:
                        yield self._advance_processing_time(curr_timestamp)

                    # Then, send either a new element or watermark.
                    if record.HasField('element_event'):
                        record.element_event.tag = tag
                        yield record
                    elif record.HasField('watermark_event'):
                        record.watermark_event.tag = tag
                        yield record
                unsent_events = events_to_send
                target_timestamp = self._min_timestamp_of(unsent_events)

        def _advance_processing_time(self, new_timestamp):
            """Advances the internal clock and returns an AdvanceProcessingTime event.
      """
            advance_by = new_timestamp.micros - self._monotonic_clock.micros
            e = TestStreamPayload.Event(
                processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                    advance_duration=advance_by))
            self._monotonic_clock = new_timestamp
            return e
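

# Taken together, Reader.read() replays the cached streams in processing-time
# order: each iteration pulls at most one new record per stream, merge-sorts
# the pending records by stream time, emits every record at the current minimum
# timestamp (inserting AdvanceProcessingTime events whenever the monotonic
# clock has to move forward), and carries the rest over to the next iteration
# until all readers are exhausted.
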
  def test_read_and_write(self):
    """An integration test between the Sink and Source.

    This ensures that the sink and source speak the same language in terms of
    coders, protos, order, and units.
    """
    CACHED_RECORDS = repr(CacheKey('records', '', '', ''))

    # Units here are in seconds.
    test_stream = (
        TestStream(output_tags=(CACHED_RECORDS, ))
                   .advance_watermark_to(0, tag=CACHED_RECORDS)
                   .advance_processing_time(5)
                   .add_elements(['a', 'b', 'c'], tag=CACHED_RECORDS)
                   .advance_watermark_to(10, tag=CACHED_RECORDS)
                   .advance_processing_time(1)
                   .add_elements(
                       [
                           TimestampedValue('1', 15),
                           TimestampedValue('2', 15),
                           TimestampedValue('3', 15)
                       ],
                       tag=CACHED_RECORDS)) # yapf: disable

    coder = SafeFastPrimitivesCoder()
    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

    # Assert that there are no capture keys at first.
    self.assertEqual(cache.capture_keys, set())

    options = StandardOptions(streaming=True)
    with TestPipeline(options=options) as p:
      records = (p | test_stream)[CACHED_RECORDS]

      # pylint: disable=expression-not-assigned
      records | cache.sink([CACHED_RECORDS], is_capture=True)

    reader, _ = cache.read(CACHED_RECORDS)
    actual_events = list(reader)

    # Assert that the capture keys are forwarded correctly.
    self.assertEqual(cache.capture_keys, set([CACHED_RECORDS]))

    # Units here are in microseconds.
    expected_events = [
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=5 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=0, tag=CACHED_RECORDS)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('a'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('b'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('c'), timestamp=0),
                ],
                tag=CACHED_RECORDS)),
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=1 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=10 * 10**6, tag=CACHED_RECORDS)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('1'),
                        timestamp=15 * 10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('2'),
                        timestamp=15 * 10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('3'),
                        timestamp=15 * 10**6),
                ],
                tag=CACHED_RECORDS)),
    ]
    self.assertEqual(actual_events, expected_events)
  def test_read_and_write_multiple_outputs(self):
    """An integration test between the Sink and Source with multiple outputs.

    This tests that the StreamingCache reads from multiple files and combines
    them into a single sorted output.
    """
    LETTERS_TAG = repr(CacheKey('letters', '', '', ''))
    NUMBERS_TAG = repr(CacheKey('numbers', '', '', ''))

    # Units here are in seconds.
    test_stream = (TestStream()
                   .advance_watermark_to(0, tag=LETTERS_TAG)
                   .advance_processing_time(5)
                   .add_elements(['a', 'b', 'c'], tag=LETTERS_TAG)
                   .advance_watermark_to(10, tag=NUMBERS_TAG)
                   .advance_processing_time(1)
                   .add_elements(
                       [
                           TimestampedValue('1', 15),
                           TimestampedValue('2', 15),
                           TimestampedValue('3', 15)
                       ],
                       tag=NUMBERS_TAG)) # yapf: disable

    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

    coder = SafeFastPrimitivesCoder()

    options = StandardOptions(streaming=True)
    with TestPipeline(options=options) as p:
      # pylint: disable=expression-not-assigned
      events = p | test_stream
      events[LETTERS_TAG] | 'Letters sink' >> cache.sink([LETTERS_TAG])
      events[NUMBERS_TAG] | 'Numbers sink' >> cache.sink([NUMBERS_TAG])

    reader = cache.read_multiple([[LETTERS_TAG], [NUMBERS_TAG]])
    actual_events = list(reader)

    # Units here are in microseconds.
    expected_events = [
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=5 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=0, tag=LETTERS_TAG)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('a'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('b'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('c'), timestamp=0),
                ],
                tag=LETTERS_TAG)),
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=1 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=10 * 10**6, tag=NUMBERS_TAG)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=0, tag=LETTERS_TAG)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('1'),
                        timestamp=15 * 10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('2'),
                        timestamp=15 * 10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('3'),
                        timestamp=15 * 10**6),
                ],
                tag=NUMBERS_TAG)),
    ]

    self.assertListEqual(actual_events, expected_events)
Example #4
    def test_read_and_write(self):
        """An integration test between the Sink and Source.

    This ensures that the sink and source speak the same language in terms of
    coders, protos, order, and units.
    """

        # Units here are in seconds.
        test_stream = (TestStream()
                       .advance_watermark_to(0, tag='records')
                       .advance_processing_time(5)
                       .add_elements(['a', 'b', 'c'], tag='records')
                       .advance_watermark_to(10, tag='records')
                       .advance_processing_time(1)
                       .add_elements(
                           [
                               TimestampedValue('1', 15),
                               TimestampedValue('2', 15),
                               TimestampedValue('3', 15)
                           ],
                           tag='records')) # yapf: disable

        coder = SafeFastPrimitivesCoder()
        cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

        options = StandardOptions(streaming=True)
        options.view_as(DebugOptions).add_experiment(
            'passthrough_pcollection_output_ids')
        with TestPipeline(options=options) as p:
            # pylint: disable=expression-not-assigned
            p | test_stream | cache.sink(['records'])

        reader, _ = cache.read('records')
        actual_events = list(reader)

        # Units here are in microseconds.
        expected_events = [
            TestStreamPayload.Event(
                processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                    advance_duration=5 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=0, tag='records')),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('a'), timestamp=0),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('b'), timestamp=0),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('c'), timestamp=0),
                    ],
                    tag='records')),
            TestStreamPayload.Event(
                processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                    advance_duration=1 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=10 * 10**6, tag='records')),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('1'),
                            timestamp=15 * 10**6),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('2'),
                            timestamp=15 * 10**6),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('3'),
                            timestamp=15 * 10**6),
                    ],
                    tag='records')),
        ]
        self.assertEqual(actual_events, expected_events)
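

# The sketch below is not part of Apache Beam. It is a minimal, standard-library
# illustration (all names here are assumed/hypothetical) of the merge idea
# behind StreamingCache.Reader.read(): take the earliest pending event across
# all cached streams, advance a monotonic clock to its time, emit it, repeat.
# It uses a heap for brevity, whereas the Reader above pulls at most one record
# per stream per iteration and merge-sorts the pending records.
import heapq


def replay(streams):
    """Yields (clock_advance, tag, value) tuples in global timestamp order.

    `streams` maps a tag to a list of (timestamp, value) pairs that are already
    sorted per stream, mirroring how each cached PCollection file is ordered.
    """
    clock = 0
    heap = []
    iterators = {tag: iter(events) for tag, events in streams.items()}
    # Seed the heap with the first event of every stream.
    for tag, it in iterators.items():
        first = next(it, None)
        if first is not None:
            heapq.heappush(heap, (first[0], tag, first[1]))
    while heap:
        ts, tag, value = heapq.heappop(heap)
        # Analogous to the AdvanceProcessingTime events the Reader injects.
        advance = max(0, ts - clock)
        clock = max(clock, ts)
        yield advance, tag, value
        nxt = next(iterators[tag], None)
        if nxt is not None:
            heapq.heappush(heap, (nxt[0], tag, nxt[1]))


if __name__ == '__main__':
    # Hypothetical streams: letters at t=5s, numbers at t=6s and t=7s.
    for event in replay({'letters': [(5, 'a'), (5, 'b')],
                         'numbers': [(6, '1'), (7, '2')]}):
        print(event)  # (5, 'letters', 'a'), (0, 'letters', 'b'), (1, ...), ...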