class StreamingCache(CacheManager):
    """Cache manager that keeps streaming PCollections in files on disk.

    Data is written either directly through ``write`` or by a
    ``StreamingCacheSink`` obtained from ``sink``, and is read back through
    ``StreamingCacheSource``. The nested ``Reader`` merges one or more of
    those per-file readers into a single stream of TestStream events.
    """

    def __init__(
            self, cache_dir, is_cache_complete=None,
            sample_resolution_sec=0.1):
        """Sets up the cache directory and the in-memory bookkeeping.

        Args:
          cache_dir: directory holding the cache files. When falsy, a fresh
            temporary directory is created (under ``$TEST_TMPDIR`` if that
            environment variable is set).
          is_cache_complete: forwarded unchanged to every
            ``StreamingCacheSource`` created by this object.
            NOTE(review): presumably tells a tailing reader when to stop --
            confirm against StreamingCacheSource.
          sample_resolution_sec: forwarded unchanged to every
            ``StreamingCacheSink`` created by ``sink``.
        """
        self._cache_dir = cache_dir or tempfile.mkdtemp(
            prefix='interactive-temp-',
            dir=os.environ.get('TEST_TMPDIR', None))
        self._is_cache_complete = is_cache_complete
        self._sample_resolution_sec = sample_resolution_sec

        # Coders registered through save_pcoder, keyed by the joined labels
        # of the PCollection. They live in memory only, so they do not
        # survive re-creation of this object; better cache persistence would
        # have to keep each cached PCollection consistent with its coder.
        self._saved_pcoders = {}
        self._default_pcoder = SafeFastPrimitivesCoder()

        # Sinks created with is_capture=True, keyed by the sink's path
        # (str -> StreamingCacheSink), and the matching set of cache keys.
        self._capture_sinks = {}
        self._capture_keys = set()

    def size(self, *labels):
        """Returns the on-disk size in bytes of the entry, or 0 if absent."""
        if not self.exists(*labels):
            return 0
        return os.path.getsize(os.path.join(self._cache_dir, *labels))

    @property
    def capture_size(self):
        """Sum of ``size_in_bytes`` over all registered capture sinks."""
        return sum(
            sink.size_in_bytes for sink in self._capture_sinks.values())

    @property
    def capture_paths(self):
        """List of the paths of all registered capture sinks."""
        return list(self._capture_sinks)

    @property
    def capture_keys(self):
        """The set of cache keys registered for capture."""
        return self._capture_keys

    def exists(self, *labels):
        """Returns whether a file or directory exists for the given labels."""
        return os.path.exists(os.path.join(self._cache_dir, *labels))

    # TODO(srohde): Modify this to return the correct version.
    def read(self, *labels, **args):
        """Reads the events cached under the given labels.

        Args:
          *labels: path components of the cache entry.
          **args: only ``tail`` (default False) is consulted; it is passed on
            to the underlying source's ``read``.

        Returns:
          A ``(events, version)`` tuple. ``events`` is an empty iterator and
          ``version`` is -1 when there is nothing to read; otherwise
          ``events`` is the generator produced by ``Reader.read`` and
          ``version`` is 1.
        """
        tail = args.pop('tail', False)

        # A snapshot read (tail is False) of a missing file has nothing to
        # wait for, so it returns right away.
        if not (self.exists(*labels) or tail):
            return iter([]), -1

        source = StreamingCacheSource(
            self._cache_dir, labels, self._is_cache_complete)
        record_reader = source.read(tail=tail)

        # The file may exist but still be empty, in which case not even the
        # header can be read. This can only happen when tail is False.
        try:
            header = next(record_reader)
        except StopIteration:
            return iter([]), -1
        return StreamingCache.Reader([header], [record_reader]).read(), 1

    def read_multiple(self, labels, tail=True):
        """Reads several cache entries merged into one stream of events.

        Tails by default until the cache is complete, because this is used
        by the TestStreamServiceController to read from file during pipeline
        runtime, which needs to block.

        Args:
          labels: iterable where each item is the labels of one cache entry.
          tail: passed on to each underlying source's ``read``.

        Returns:
          The generator produced by ``Reader.read`` over all entries.
        """
        readers = []
        for entry_labels in labels:
            source = StreamingCacheSource(
                self._cache_dir, entry_labels, self._is_cache_complete)
            readers.append(source.read(tail=tail))
        # The first item of every reader is taken as that stream's header.
        headers = [next(reader) for reader in readers]
        return StreamingCache.Reader(headers, readers).read()

    def write(self, values, *labels):
        """Appends the given values to the cache file named by the labels.

        TestStreamFileHeader / TestStreamFileRecord values are serialized to
        bytes first; every value is then encoded with the default coder and
        written followed by a newline. Missing parent directories are
        created.
        """
        directory = os.path.join(self._cache_dir, *labels[:-1])
        filepath = os.path.join(directory, labels[-1])
        if not os.path.exists(directory):
            os.makedirs(directory)

        proto_types = (TestStreamFileHeader, TestStreamFileRecord)
        with open(filepath, 'ab') as cache_file:
            for value in values:
                payload = (
                    value.SerializeToString()
                    if isinstance(value, proto_types) else value)
                cache_file.write(
                    self._default_pcoder.encode(payload) + b'\n')

    def clear(self, *labels):
        """Removes the cache file for the labels and its capture key.

        Returns:
          True if a file existed and was removed, False otherwise.
        """
        directory = os.path.join(self._cache_dir, *labels[:-1])
        filepath = os.path.join(directory, labels[-1])
        self._capture_keys.discard(labels[-1])
        if not os.path.exists(filepath):
            return False
        os.remove(filepath)
        return True

    def source(self, *labels):
        """Returns the StreamingCacheManager source.

        This is beam.Impulse() because unbounded sources will be marked with
        this and then the PipelineInstrument will replace these with a
        TestStream.
        """
        return beam.Impulse()

    def sink(self, labels, is_capture=False):
        """Returns a StreamingCacheSink that writes elements to file.

        Note that this is assumed to only work in the DirectRunner, as the
        underlying StreamingCacheSink assumes a single machine to have
        correct element ordering.

        Args:
          labels: path components; the last one is the file name, the rest
            are sub-directories of the cache directory.
          is_capture: when True, the sink and its file name are registered
            in the capture bookkeeping of this cache.
        """
        filename = labels[-1]
        sink_dir = os.path.join(self._cache_dir, *labels[:-1])
        new_sink = StreamingCacheSink(
            sink_dir, filename, self._sample_resolution_sec)
        if is_capture:
            self._capture_sinks[new_sink.path] = new_sink
            self._capture_keys.add(filename)
        return new_sink

    def save_pcoder(self, pcoder, *labels):
        """Remembers the coder for the PCollection named by the labels."""
        self._saved_pcoders[os.path.join(*labels)] = pcoder

    def load_pcoder(self, *labels):
        """Returns the coder to use for the PCollection named by the labels.

        The default coder wins whenever it is set, which is always the case
        after ``__init__``; only otherwise is the saved coder looked up.
        """
        if self._default_pcoder is not None:
            return self._default_pcoder
        return self._saved_pcoders[os.path.join(*labels)]

    def cleanup(self):
        """Deletes the cache directory and resets the in-memory state."""
        if os.path.exists(self._cache_dir):
            shutil.rmtree(self._cache_dir)
        # Rebind (rather than clear) so previously handed-out containers are
        # left untouched.
        self._saved_pcoders = {}
        self._capture_sinks = {}
        self._capture_keys = set()

    class Reader(object):
        """Abstraction that reads from PCollection readers.

        This class is an abstraction layer over multiple PCollection readers
        to be used for supplying a TestStream service with events.

        This class is also responsible for holding the state of the clock,
        injecting clock advancement events, and watermark advancement events.
        """

        def __init__(self, headers, readers):
            """Pairs every header with its reader, in the order given.

            Args:
              headers: per-stream headers; each exposes a ``tag``.
              readers: iterators of records, parallel to ``headers``.
            """
            # Monotonic clock used to order events in the replay.
            self._monotonic_clock = timestamp.Timestamp.of(0)

            # The header carries metadata about an entire stream, so that
            # the data is not copied into every record.
            self._headers = {}
            for header in headers:
                self._headers[header.tag] = header

            # The PCollection cache readers, keyed by tag in input order.
            self._readers = OrderedDict()
            for header, reader in zip(headers, readers):
                self._readers[header.tag] = reader

            # The most recently read timestamp per tag.
            self._stream_times = {}
            for tag in self._headers:
                self._stream_times[tag] = timestamp.Timestamp(seconds=0)

        def _test_stream_events_before_target(self, target_timestamp):
            """Reads the next iteration of elements from each stream.

            Retrieves an element from each stream iff the most recently read
            timestamp from that stream is less than the target_timestamp.
            Since the amount of events may not fit into memory, this reads
            at most one element from each stream at a time.

            Returns:
              A list of ``(tag, event, stream_time)`` tuples.
            """
            fetched = []
            for tag, reader in self._readers.items():
                # The target is the timestamp the streams are catching up
                # to. Streams already at or past it are skipped so that not
                # everything is read into memory.
                if self._stream_times[tag] >= target_timestamp:
                    continue
                try:
                    event = next(reader).recorded_event
                    if event.HasField('processing_time_event'):
                        advance = event.processing_time_event.advance_duration
                        self._stream_times[tag] += timestamp.Duration(
                            micros=advance)
                    fetched.append((tag, event, self._stream_times[tag]))
                except StopIteration:
                    # This stream is exhausted; nothing to add for it.
                    continue
            return fetched

        def _merge_sort(self, previous_events, new_events):
            """Returns both lists combined, latest timestamp first."""
            combined = previous_events + new_events
            combined.sort(key=lambda event: event[2], reverse=True)
            return combined

        def _min_timestamp_of(self, events):
            """Returns the timestamp of the last event, or MAX if empty.

            With the descending order produced by ``_merge_sort`` the last
            event is the earliest one.
            """
            if events:
                return events[-1][2]
            return timestamp.MAX_TIMESTAMP

        def _event_stream_caught_up_to_target(self, events, target_timestamp):
            """True when no event at or before the target remains."""
            return (
                not events or
                self._min_timestamp_of(events) > target_timestamp)

        def read(self):
            """Reads records from PCollection readers.

            Yields:
              TestStream events ordered by stream time: a processing-time
              advancement whenever the clock has to move forward, followed
              by the element or watermark events recorded at that time, each
              stamped with the tag of the stream it came from.
            """
            # The timestamp the streams have to catch up to.
            target_timestamp = timestamp.MAX_TIMESTAMP

            # Events from the last iteration that are past the target.
            carried_over = []

            # Emit events until all events have been read.
            while True:
                # Freshly read events will most likely be out of order when
                # there are multiple readers, so merge and sort them.
                fresh = self._test_stream_events_before_target(
                    target_timestamp)
                pending = self._merge_sort(carried_over, fresh)
                if not pending:
                    return

                # Only events carrying the earliest pending timestamp are
                # emitted in this round.
                target_timestamp = self._min_timestamp_of(pending)

                while not self._event_stream_caught_up_to_target(
                        pending, target_timestamp):
                    tag, event, event_time = pending.pop()

                    # First advance the clock to match the time of the
                    # stream. This has the side effect of also advancing
                    # this cache's clock.
                    if event_time > self._monotonic_clock:
                        yield self._advance_processing_time(event_time)

                    # Then send either a new element or a watermark.
                    for field in ('element_event', 'watermark_event'):
                        if event.HasField(field):
                            getattr(event, field).tag = tag
                            yield event
                            break

                carried_over = pending
                target_timestamp = self._min_timestamp_of(carried_over)

        def _advance_processing_time(self, new_timestamp):
            """Advances the clock and returns an AdvanceProcessingTime event.

            The event's duration is the difference, in microseconds, between
            the new timestamp and the current clock.
            """
            advance_by = new_timestamp.micros - self._monotonic_clock.micros
            advance = TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=advance_by)
            self._monotonic_clock = new_timestamp
            return TestStreamPayload.Event(processing_time_event=advance)
def test_read_and_write(self):
    """An integration test between the Sink and Source.

    This ensures that the sink and source speak the same language in terms
    of coders, protos, order, and units. It also checks that a sink created
    with ``is_capture=True`` registers its key in ``cache.capture_keys``.
    """
    CACHED_RECORDS = repr(CacheKey('records', '', '', ''))

    # Units here are in seconds.
    # NOTE: output_tags is given as a real one-element tuple. A bare
    # ``(CACHED_RECORDS)`` is just the string itself, which any code
    # iterating over the tags would see as individual characters.
    test_stream = (
        TestStream(output_tags=(CACHED_RECORDS,))
        .advance_watermark_to(0, tag=CACHED_RECORDS)
        .advance_processing_time(5)
        .add_elements(['a', 'b', 'c'], tag=CACHED_RECORDS)
        .advance_watermark_to(10, tag=CACHED_RECORDS)
        .advance_processing_time(1)
        .add_elements(
            [
                TimestampedValue('1', 15),
                TimestampedValue('2', 15),
                TimestampedValue('3', 15)
            ],
            tag=CACHED_RECORDS))  # yapf: disable

    coder = SafeFastPrimitivesCoder()
    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

    # Assert that there are no capture keys at first.
    self.assertEqual(cache.capture_keys, set())

    options = StandardOptions(streaming=True)
    with TestPipeline(options=options) as p:
        records = (p | test_stream)[CACHED_RECORDS]

        # pylint: disable=expression-not-assigned
        records | cache.sink([CACHED_RECORDS], is_capture=True)

    # The pipeline has finished once the ``with`` block exits, so a
    # snapshot read sees everything that was written.
    reader, _ = cache.read(CACHED_RECORDS)
    actual_events = list(reader)

    # Assert that the capture keys are forwarded correctly.
    self.assertEqual(cache.capture_keys, {CACHED_RECORDS})

    # Units here are in microseconds.
    expected_events = [
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=5 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=0, tag=CACHED_RECORDS)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('a'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('b'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('c'), timestamp=0),
                ],
                tag=CACHED_RECORDS)),
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=1 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=10 * 10**6, tag=CACHED_RECORDS)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('1'),
                        timestamp=15 * 10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('2'),
                        timestamp=15 * 10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('3'),
                        timestamp=15 * 10**6),
                ],
                tag=CACHED_RECORDS)),
    ]
    self.assertEqual(actual_events, expected_events)
def test_read_and_write_multiple_outputs(self):
    """An integration test between the Sink and Source with multiple outputs.

    This tests the functionality that the StreamingCache reads from multiple
    files and combines them into a single sorted output.
    """
    LETTERS_TAG = repr(CacheKey('letters', '', '', ''))
    NUMBERS_TAG = repr(CacheKey('numbers', '', '', ''))

    # Units here are in seconds.
    numbers = [TimestampedValue(digit, 15) for digit in ('1', '2', '3')]
    test_stream = (
        TestStream()
        .advance_watermark_to(0, tag=LETTERS_TAG)
        .advance_processing_time(5)
        .add_elements(['a', 'b', 'c'], tag=LETTERS_TAG)
        .advance_watermark_to(10, tag=NUMBERS_TAG)
        .advance_processing_time(1)
        .add_elements(numbers, tag=NUMBERS_TAG))  # yapf: disable

    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)
    coder = SafeFastPrimitivesCoder()

    options = StandardOptions(streaming=True)
    with TestPipeline(options=options) as p:
        # pylint: disable=expression-not-assigned
        events = p | test_stream
        events[LETTERS_TAG] | 'Letters sink' >> cache.sink([LETTERS_TAG])
        events[NUMBERS_TAG] | 'Numbers sink' >> cache.sink([NUMBERS_TAG])

    reader = cache.read_multiple([[LETTERS_TAG], [NUMBERS_TAG]])
    actual_events = list(reader)

    # The helpers below take seconds and emit protos in microseconds.
    def processing_time(seconds):
        return TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=seconds * 10**6))

    def watermark(seconds, tag):
        return TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=seconds * 10**6, tag=tag))

    def elements(values, seconds, tag):
        return TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode(value),
                        timestamp=seconds * 10**6) for value in values
                ],
                tag=tag))

    expected_events = [
        processing_time(5),
        watermark(0, LETTERS_TAG),
        elements(['a', 'b', 'c'], 0, LETTERS_TAG),
        processing_time(1),
        watermark(10, NUMBERS_TAG),
        watermark(0, LETTERS_TAG),
        elements(['1', '2', '3'], 15, NUMBERS_TAG),
    ]

    self.assertListEqual(actual_events, expected_events)
def test_read_and_write(self):
    """An integration test between the Sink and Source.

    This ensures that the sink and source speak the same language in terms
    of coders, protos, order, and units.
    """
    # Units here are in seconds.
    late_values = [TimestampedValue(digit, 15) for digit in ('1', '2', '3')]
    test_stream = (
        TestStream()
        .advance_watermark_to(0, tag='records')
        .advance_processing_time(5)
        .add_elements(['a', 'b', 'c'], tag='records')
        .advance_watermark_to(10, tag='records')
        .advance_processing_time(1)
        .add_elements(late_values, tag='records'))  # yapf: disable

    coder = SafeFastPrimitivesCoder()
    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

    options = StandardOptions(streaming=True)
    debug_options = options.view_as(DebugOptions)
    debug_options.add_experiment('passthrough_pcollection_output_ids')
    with TestPipeline(options=options) as p:
        # pylint: disable=expression-not-assigned
        p | test_stream | cache.sink(['records'])

    reader, _ = cache.read('records')
    actual_events = list(reader)

    # The helpers below take seconds and emit protos in microseconds.
    def processing_time(seconds):
        return TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=seconds * 10**6))

    def watermark(seconds):
        return TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=seconds * 10**6, tag='records'))

    def elements(values, seconds):
        return TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode(value),
                        timestamp=seconds * 10**6) for value in values
                ],
                tag='records'))

    expected_events = [
        processing_time(5),
        watermark(0),
        elements(['a', 'b', 'c'], 0),
        processing_time(1),
        watermark(10),
        elements(['1', '2', '3'], 15),
    ]
    self.assertEqual(actual_events, expected_events)