def cache_output(output_name: str, output: PValue) -> None:
  """Caches ``output``, runs its pipeline, then visualizes the cached data.

  The owning user pipeline is looked up first; when none is found a warning
  is logged and nothing else happens. Otherwise ``output`` is reified into the
  pipeline's cache manager under its CacheKey, the pipeline is run to
  completion, the PCollection is marked computed and its full cached contents
  are visualized. If running the pipeline raises (other than
  KeyboardInterrupt/SystemExit), a warning is logged and the function returns
  without marking anything computed.

  Args:
    output_name: The variable name the PCollection is known by.
    output: The PCollection (PValue) to cache.
  """
  user_pipeline = ie.current_env().user_pipeline(output.pipeline)
  if not user_pipeline:
    _LOGGER.warning(
        'Something is wrong with %s. Cannot introspect its data.', output)
    return

  cache_manager = ie.current_env().get_cache_manager(
      user_pipeline, create_if_absent=True)
  cache_key = CacheKey.from_pcoll(output_name, output).to_str()
  reify_to_cache(pcoll=output, cache_key=cache_key, cache_manager=cache_manager)

  try:
    output.pipeline.run().wait_until_finish()
  except (KeyboardInterrupt, SystemExit):
    raise
  except Exception as e:
    _LOGGER.warning(_NOT_SUPPORTED_MSG, e, output.pipeline.runner)
    return

  ie.current_env().mark_pcollection_computed([output])
  unbounded = float('inf')
  visualize_computed_pcoll(
      output_name, output, max_n=unbounded, max_duration_secs=unbounded)
def __init__(
    self,
    user_pipeline,  # type: beam.Pipeline
    pcolls,  # type: List[beam.pvalue.PCollection]
    result,  # type: beam.runner.PipelineResult
    max_n,  # type: int
    max_duration_secs,  # type: float
):
  """Sets up a recording of ``pcolls`` from a run of ``user_pipeline``.

  One ElementStream is created per PCollection, keyed by the PCollection,
  and a daemon thread is started running ``self._mark_all_computed``.

  Args:
    user_pipeline: The user pipeline the PCollections belong to.
    pcolls: The PCollections being recorded.
    result: The result of the pipeline run producing the data.
    max_n: Element limit forwarded to every ElementStream.
    max_duration_secs: Duration limit forwarded to every ElementStream and
      kept as ``self._duration_secs``.
  """
  self._user_pipeline = user_pipeline
  self._result = result
  self._result_lock = threading.Lock()
  self._pcolls = pcolls

  # Invert the watched {variable name: PCollection} mapping once so each
  # PCollection's name is a single dict lookup. Previously the inverted map
  # was rebuilt (and pcoll_by_name() re-scanned) twice per PCollection.
  pcoll_to_var = {v: k for k, v in utils.pcoll_by_name().items()}

  def _make_stream(pcoll):
    # The name is None when the PCollection is not found in the watched map.
    var = pcoll_to_var.get(pcoll, None)
    return ElementStream(
        pcoll,
        var,
        CacheKey.from_pcoll(var, pcoll).to_str(),
        max_n,
        max_duration_secs)

  self._streams = {pcoll: _make_stream(pcoll) for pcoll in pcolls}

  self._start = time.time()
  self._duration_secs = max_duration_secs
  # NOTE(review): presumably reports whether the background caching job for
  # this pipeline (identified by str(id(user_pipeline))) already has complete
  # cache — confirm against bcj.is_cache_complete.
  self._set_computed = bcj.is_cache_complete(str(id(user_pipeline)))

  # Run a separate thread for marking the PCollections done. This is because
  # the pipeline run may be asynchronous.
  self._mark_computed = threading.Thread(target=self._mark_all_computed)
  self._mark_computed.daemon = True
  self._mark_computed.start()
def _build_query_components(
    query: str,
    found: Dict[str, beam.PCollection],
    output_name: str,
    run: bool = True
) -> Tuple[str,
           Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline],
           SqlChain]:
  """Builds necessary components needed to apply the SqlTransform.

  Args:
    query: The SQL query to be executed by the magic.
    found: The PCollections with variable names found to be used by the query.
    output_name: The output variable name in __main__ module.
    run: Whether to prepare components for a local run or not. When False the
      found PCollections themselves are used as the source instead of being
      read back from cache.

  Returns:
    A 3-tuple: the processed query to be executed by the magic; a source to
    apply the SqlTransform to (a dictionary of tagged PCollections, a single
    PCollection, or the pipeline to execute the query); and the chain of
    applied beam_sql magics this one belongs to.
  """
  if not found:
    # The query does not read any existing PCollection: it runs against a
    # brand-new user pipeline whose node is the root of the chain created here.
    fresh_pipeline = beam.Pipeline()
    ie.current_env().add_user_pipeline(fresh_pipeline)
    root_node = SqlNode(
        output_name=output_name, source=fresh_pipeline, query=query)
    root_chain = ie.current_env().get_sql_chain(fresh_pipeline).append(
        root_node)
    return query, fresh_pipeline, root_chain

  # The pipeline of the first found PCollection identifies the user pipeline.
  user_pipeline = ie.current_env().user_pipeline(
      next(iter(found.values())).pipeline)
  sql_pipeline = beam.Pipeline(options=user_pipeline._options)
  ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)

  if not run:
    sql_source = found
  elif has_source_to_cache(user_pipeline):
    sql_source = pcolls_from_streaming_cache(user_pipeline, sql_pipeline, found)
  else:
    cache_manager = ie.current_env().get_cache_manager(
        user_pipeline, create_if_absent=True)
    sql_source = {
        pcoll_name: unreify_from_cache(
            pipeline=sql_pipeline,
            cache_key=CacheKey.from_pcoll(pcoll_name, pcoll).to_str(),
            cache_manager=cache_manager,
            element_type=pcoll.element_type)
        for pcoll_name, pcoll in found.items()
    }

  if len(sql_source) == 1:
    # A lone source is handed over as a bare PCollection, and the query is
    # adjusted for it via replace_single_pcoll_token.
    single_name, single_pcoll = next(iter(sql_source.items()))
    query = replace_single_pcoll_token(query, single_name)
    sql_source = single_pcoll

  node = SqlNode(output_name=output_name, source=set(found), query=query)
  chain = ie.current_env().get_sql_chain(
      user_pipeline, set_user_pipeline=True).append(node)
  return query, sql_source, chain
def cache_key(self, pcoll):
  """Returns the cache identifier of a cacheable PCollection, or ''.

  The returned key is what ``pcoll`` would be stored under if materialized in
  cache; it does not imply the cache already exists. ``pcoll`` may come from
  the original user pipeline or from an equivalent runner (transformed)
  pipeline; runner PCollections are mapped back to their user PCollection so
  the key refers to the user pipeline.

  'pcoll_id' of a cacheable is not stable, so it is excluded from the key. A
  combination of 'var', 'version' and 'producer_version' is sufficient to
  identify a cached PCollection.
  """
  cacheable = self.cacheables.get(self._cacheable_key(pcoll), None)
  if not cacheable:
    return ''

  user_pcoll = self.runner_pcoll_to_user_pcoll.get(
      cacheable.pcoll, cacheable.pcoll)
  key = CacheKey(
      cacheable.var,
      cacheable.version,
      cacheable.producer_version,
      str(id(user_pcoll.pipeline)))
  return repr(key)
def test_empty(self):
  """Writing no records yields a reader that produces nothing."""
  key = repr(CacheKey('arbitrary_key', '', '', ''))
  streaming_cache = StreamingCache(cache_dir=None)

  # Nothing has been written yet.
  self.assertFalse(streaming_cache.exists(key))

  streaming_cache.write([], key)
  empty_reader, _ = streaming_cache.read(key)

  # An empty reader must materialize to an empty list.
  self.assertFalse(list(empty_reader))
def setUp(self):
  """Registers a small pipeline with an in-memory cache and a fake result."""
  self.cache = InMemoryCache()
  self.p = beam.Pipeline()
  self.pcoll = self.p | beam.Create([])
  self.cache_key = str(CacheKey('pcoll', '', '', ''))

  # The MockPipelineResult lets tests control the state of a simulated run
  # of the pipeline.
  self.mock_result = MockPipelineResult()

  env = ie.current_env()
  env.add_user_pipeline(self.p)
  env.set_pipeline_result(self.p, self.mock_result)
  env.set_cache_manager(self.cache, self.p)
def __init__(self, cache_dir, labels, is_cache_complete=None, coder=None):
  """Sets up access to the cache file at ``cache_dir`` joined with ``labels``.

  Args:
    cache_dir: Directory the cache file path is rooted at.
    labels: Path components appended to ``cache_dir``. The last label is
      parsed with CacheKey.from_str to obtain the pipeline id.
    is_cache_complete: Optional one-argument callable; when falsy, a callable
      that always returns True is used.
    coder: Optional coder; when falsy, a SafeFastPrimitivesCoder is used.
  """
  self._cache_dir = cache_dir
  self._coder = coder if coder else SafeFastPrimitivesCoder()
  self._labels = labels
  self._path = os.path.join(self._cache_dir, *self._labels)
  self._is_cache_complete = (
      is_cache_complete if is_cache_complete else (lambda _: True))
  self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id
def test_single_reader(self):
  """Tests that a single cached PCollection is read back as the expected,
  correctly emitted sequence of TestStreamPayload events.
  """
  CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))

  # Three elements separated by one-second processing-time advances. Event
  # times passed to the builder are in seconds.
  values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
              .add_element(element=0, event_time_secs=0)
              .advance_processing_time(1)
              .add_element(element=1, event_time_secs=1)
              .advance_processing_time(1)
              .add_element(element=2, event_time_secs=2)
              .build()) # yapf: disable

  cache = StreamingCache(cache_dir=None)
  cache.write(values, CACHED_PCOLLECTION_KEY)

  reader, _ = cache.read(CACHED_PCOLLECTION_KEY)
  coder = coders.FastPrimitivesCoder()
  events = list(reader)

  # Units here are in microseconds: each one-second value from the builder
  # is expected as 1 * 10**6.
  expected = [
      TestStreamPayload.Event(
          element_event=TestStreamPayload.Event.AddElements(
              elements=[
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode(0), timestamp=0)
              ],
              tag=CACHED_PCOLLECTION_KEY)),
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=1 * 10**6)),
      TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                              AddElements(elements=[
                                  TestStreamPayload.TimestampedElement(
                                      encoded_element=coder.encode(1),
                                      timestamp=1 * 10**6)
                              ],
                                          tag=CACHED_PCOLLECTION_KEY)),
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=1 * 10**6)),
      TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                              AddElements(elements=[
                                  TestStreamPayload.TimestampedElement(
                                      encoded_element=coder.encode(2),
                                      timestamp=2 * 10**6)
                              ],
                                          tag=CACHED_PCOLLECTION_KEY)),
  ]
  self.assertSequenceEqual(events, expected)
def pcolls_from_streaming_cache(
    user_pipeline: beam.Pipeline,
    query_pipeline: beam.Pipeline,
    name_to_pcoll: Dict[str, beam.PCollection]) -> Dict[str, beam.PCollection]:
  """Reads PCollection cache through the TestStream.

  When the user_pipeline has unbounded sources, all cache reads are forced to
  go through the TestStream even if they are bounded sources.

  Args:
    user_pipeline: The beam.Pipeline object defined by the user in the
      notebook.
    query_pipeline: The beam.Pipeline object built by the magic to execute the
      SQL query.
    name_to_pcoll: PCollections with variable names used in the SQL query.

  Returns:
    A Dict[str, beam.PCollection] mapping each PCollection variable name to
    the PCollection read back from the cache for it.
  """
  def exception_handler(e):
    _LOGGER.error(str(e))
    return True

  cache_manager = ie.current_env().get_cache_manager(
      user_pipeline, create_if_absent=True)

  # Reuse the TestStream service already registered for this user pipeline;
  # otherwise start one and register it.
  test_stream_service = ie.current_env().get_test_stream_service_controller(
      user_pipeline)
  if not test_stream_service:
    test_stream_service = TestStreamServiceController(
        cache_manager, exception_handler=exception_handler)
    test_stream_service.start()
    ie.current_env().set_test_stream_service_controller(
        user_pipeline, test_stream_service)

  # Each output of the TestStream is tagged with the cache key string.
  tag_to_name = {
      CacheKey.from_pcoll(name, pcoll).to_str(): name
      for name, pcoll in name_to_pcoll.items()
  }

  tagged_outputs = query_pipeline | test_stream.TestStream(
      output_tags=set(tag_to_name),
      coder=cache_manager._default_pcoder,
      endpoint=test_stream_service.endpoint)

  sql_source = {}
  for tag, tagged_pcoll in tagged_outputs.items():
    pcoll_name = tag_to_name[tag]
    # Must mark the element_type to avoid introducing pickled Python coder
    # to the Java expansion service.
    tagged_pcoll.element_type = name_to_pcoll[pcoll_name].element_type
    sql_source[pcoll_name] = tagged_pcoll
  return sql_source
def test_always_default_coder_for_test_stream_records(self):
  """Loading the pcoder for TestStream records yields the default pcoder."""
  numbers_key = repr(CacheKey('numbers', '', '', ''))
  records = (FileRecordsBuilder(numbers_key)
               .advance_processing_time(2)
               .add_element(element=1, event_time_secs=0)
               .advance_processing_time(1)
               .add_element(element=2, event_time_secs=0)
               .advance_processing_time(1)
               .add_element(element=2, event_time_secs=0)
               .build()) # yapf: disable

  streaming_cache = StreamingCache(cache_dir=None)
  streaming_cache.write(records, numbers_key)

  loaded_pcoder = streaming_cache.load_pcoder(numbers_key)
  self.assertIs(type(loaded_pcoder), type(streaming_cache._default_pcoder))
def _wait_until_file_exists(self, timeout_secs=30):
  """Blocks until the cache file exists, then opens it for binary reading.

  Polls ``self._path`` once per second for up to ``timeout_secs`` seconds.

  Args:
    timeout_secs: Maximum number of seconds to wait for the file.

  Returns:
    A file object for ``self._path`` opened in 'rb' mode. The caller owns the
    handle and is responsible for closing it.

  Raises:
    RuntimeError: If the file does not appear within ``timeout_secs``.
  """
  # Use a monotonic clock: wall-clock time (time.time) can jump when the
  # system clock is adjusted, which would shorten or stretch the timeout.
  start = time.monotonic()
  while not os.path.exists(self._path):
    time.sleep(1)
    if time.monotonic() - start > timeout_secs:
      pcollection_var = CacheKey.from_str(self._labels[-1]).var
      raise RuntimeError(
          'Timed out waiting for cache file for PCollection `{}` to be '
          'available with path {}.'.format(pcollection_var, self._path))
  return open(self._path, mode='rb')
def test_cache_output(self):
  """cache_output caches the PCollection and marks it computed."""
  p_cache_output = beam.Pipeline()
  pcoll_co = p_cache_output | 'Create Source' >> beam.Create([1, 2, 3])
  cache_manager = FileBasedCacheManager()
  ie.current_env().set_cache_manager(cache_manager, p_cache_output)
  # Every local is watched, so keep local names stable: 'pcoll_co' is the
  # name passed to cache_output and used to build the expected cache key.
  ib.watch(locals())
  # cache_output calls visualize_computed_pcoll with the max_n and
  # max_duration_secs keyword arguments, so the stub must accept arbitrary
  # positional and keyword arguments (a two-positional-arg lambda would raise
  # TypeError if the patch took effect).
  with patch(
      'apache_beam.runners.interactive.display.pcoll_visualization.'
      'visualize_computed_pcoll',
      lambda *args, **kwargs: None):
    cache_output('pcoll_co', pcoll_co)
    self.assertIn(pcoll_co, ie.current_env().computed_pcollections)
    self.assertTrue(
        cache_manager.exists(
            'full', CacheKey.from_pcoll('pcoll_co', pcoll_co).to_str()))
def _build_query_components(
    query: str, found: Dict[str, beam.PCollection]
) -> Tuple[str,
           Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline]]:
  """Builds necessary components needed to apply the SqlTransform.

  Args:
    query: The SQL query to be executed by the magic.
    found: The PCollections with variable names found to be used by the query.

  Returns:
    A 2-tuple: the processed query to be executed by the magic, and a source
    to apply the SqlTransform to (a dictionary of tagged PCollections, a
    single PCollection, or the pipeline to execute the query).
  """
  if not found:
    # No existing PCollection is queried: run against a new user pipeline.
    fresh_pipeline = beam.Pipeline()
    ie.current_env().add_user_pipeline(fresh_pipeline)
    return query, fresh_pipeline

  # The pipeline of the first found PCollection identifies the user pipeline.
  user_pipeline = ie.current_env().user_pipeline(
      next(iter(found.values())).pipeline)
  sql_pipeline = beam.Pipeline(options=user_pipeline._options)
  ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)

  if has_source_to_cache(user_pipeline):
    sql_source = pcolls_from_streaming_cache(user_pipeline, sql_pipeline, found)
  else:
    cache_manager = ie.current_env().get_cache_manager(
        user_pipeline, create_if_absent=True)
    sql_source = {
        pcoll_name: unreify_from_cache(
            pipeline=sql_pipeline,
            cache_key=CacheKey.from_pcoll(pcoll_name, pcoll).to_str(),
            cache_manager=cache_manager,
            element_type=pcoll.element_type)
        for pcoll_name, pcoll in found.items()
    }

  if len(sql_source) == 1:
    # A lone source is handed over as a bare PCollection, and the query is
    # adjusted for it via replace_single_pcoll_token.
    single_name, single_pcoll = next(iter(sql_source.items()))
    query = replace_single_pcoll_token(query, single_name)
    sql_source = single_pcoll
  return query, sql_source
def read(self, pcoll_name, pcoll, max_n, max_duration_secs):
  # type: (str, beam.pvalue.PValue, int, float) -> Union[None, ElementStream]
  """Returns an ElementStream over a computed PCollection, or None on error.

  Validation is the caller's job: the caller must make sure pcoll_name and
  pcoll unambiguously identify a watched and computed PCollection in the
  notebook. Any error raised while building the stream (other than
  KeyboardInterrupt/SystemExit) is logged and None is returned.
  """
  try:
    key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str()
    stream = ElementStream(pcoll, pcoll_name, key, max_n, max_duration_secs)
  except (KeyboardInterrupt, SystemExit):
    raise
  except Exception as e:
    # Validation is left to the caller to avoid doing it twice; just log
    # here in case the caller skipped it.
    _LOGGER.error(str(e))
    return None
  return stream
def cache_key(self, pcoll):
  """Returns the cache identifier of a cacheable PCollection, or ''.

  Only needed in pipeline instrument when it is unknown whether ``pcoll``
  comes from the user pipeline or a runner pipeline. For a PCollection known
  to be from the user pipeline, always build the key with
  CacheKey.from_pcoll directly.

  The key is what ``pcoll`` would be stored under if materialized in cache;
  it does not imply the cache already exists. A runner PCollection is mapped
  back to its user PCollection before the key is built.
  """
  cacheable = self._cacheables.get(self.pcoll_id(pcoll), None)
  if not cacheable:
    return ''

  user_pcoll = self.runner_pcoll_to_user_pcoll.get(
      cacheable.pcoll, cacheable.pcoll)
  key = CacheKey.from_pcoll(cacheable.var, user_pcoll)
  return key.to_str()
def test_recording_manager_clears_cache(self): """Tests that the RecordingManager clears the cache before recording. A job may have incomplete PCollections when the job terminates. Clearing the cache ensures that correct results are computed every run. """ # Add the TestStream so that it can be cached. ib.options.recordable_sources.add(TestStream) p = beam.Pipeline(InteractiveRunner(), options=PipelineOptions(streaming=True)) elems = (p | TestStream().advance_watermark_to( 0).advance_processing_time(1).add_elements(list( range(10))).advance_processing_time(1)) squares = elems | beam.Map(lambda x: x**2) # Watch the local scope for Interactive Beam so that referenced PCollections # will be cached. ib.watch(locals()) # This is normally done in the interactive_utils when a transform is # applied but needs an IPython environment. So we manually run this here. ie.current_env().track_user_pipelines() # Do the first recording to get the timestamp of the first time the fragment # was run. rm = RecordingManager(p) # Set up a mock for the Cache's clear function which will be used to clear # uncomputed PCollections. rm._clear_pcolls = MagicMock() rm.record([squares], max_n=1, max_duration=500) rm.cancel() # Assert that the cache cleared the PCollection. rm._clear_pcolls.assert_any_call( unittest.mock.ANY, # elems is unbounded source populated by the background job, thus not # cleared. {CacheKey.from_pcoll('squares', squares).to_str()})
def test_streaming_cache_uses_local_ib_cache_root(self):
  """
  Checks that StreamingCache._cache_dir is set to the cache_root set under
  Interactive Beam for a local directory and that the cached values are the
  same as the values of a cache using default settings.
  """
  CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
  values = (FileRecordsBuilder(CACHED_PCOLLECTION_KEY)
                .advance_processing_time(1)
                .advance_watermark(watermark_secs=0)
                .add_element(element=1, event_time_secs=0)
                .build()) # yapf: disable

  local_cache = StreamingCache(cache_dir=None)
  local_cache.write(values, CACHED_PCOLLECTION_KEY)
  reader_one, _ = local_cache.read(CACHED_PCOLLECTION_KEY)
  pcoll_list_one = list(reader_one)

  # Point the Interactive Beam cache_root option at a local directory. The
  # option is global, so restore its previous value in `finally`; otherwise a
  # failing assertion would leak the setting into later tests.
  original_cache_root = ib.options.cache_root
  ib.options.cache_root = '/tmp/it-test/'
  try:
    cache_manager_with_ib_option = StreamingCache(
        cache_dir=ib.options.cache_root)

    self.assertEqual(ib.options.cache_root,
                     cache_manager_with_ib_option._cache_dir)

    cache_manager_with_ib_option.write(values, CACHED_PCOLLECTION_KEY)
    reader_two, _ = cache_manager_with_ib_option.read(
        CACHED_PCOLLECTION_KEY)
    pcoll_list_two = list(reader_two)

    self.assertEqual(pcoll_list_one, pcoll_list_two)
  finally:
    # Reset Interactive Beam setting.
    ib.options.cache_root = original_cache_root
def test_multiple_readers(self):
  """Tests that reading several cached PCollections together interleaves
  their events and advances the processing-time clock across all outputs.
  """
  CACHED_LETTERS = repr(CacheKey('letters', '', '', ''))
  CACHED_NUMBERS = repr(CacheKey('numbers', '', '', ''))
  CACHED_LATE = repr(CacheKey('late', '', '', ''))

  # Builder durations and event times are in seconds.
  letters = (FileRecordsBuilder(CACHED_LETTERS)
               .advance_processing_time(1)
               .advance_watermark(watermark_secs=0)
               .add_element(element='a', event_time_secs=0)
               .advance_processing_time(10)
               .advance_watermark(watermark_secs=10)
               .add_element(element='b', event_time_secs=10)
               .build()) # yapf: disable

  numbers = (FileRecordsBuilder(CACHED_NUMBERS)
               .advance_processing_time(2)
               .add_element(element=1, event_time_secs=0)
               .advance_processing_time(1)
               .add_element(element=2, event_time_secs=0)
               .advance_processing_time(1)
               .add_element(element=2, event_time_secs=0)
               .build()) # yapf: disable

  late = (FileRecordsBuilder(CACHED_LATE)
            .advance_processing_time(101)
            .add_element(element='late', event_time_secs=0)
            .build()) # yapf: disable

  cache = StreamingCache(cache_dir=None)
  cache.write(letters, CACHED_LETTERS)
  cache.write(numbers, CACHED_NUMBERS)
  cache.write(late, CACHED_LATE)

  reader = cache.read_multiple([[CACHED_LETTERS], [CACHED_NUMBERS],
                                [CACHED_LATE]])
  coder = coders.FastPrimitivesCoder()
  events = list(reader)

  # Units here are in microseconds. The expected list orders events from the
  # three outputs by their absolute processing time, with the clock advances
  # expressed as deltas between consecutive points.
  expected = [
      # Advances clock from 0 to 1
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=1 * 10**6)),
      TestStreamPayload.Event(
          watermark_event=TestStreamPayload.Event.AdvanceWatermark(
              new_watermark=0, tag=CACHED_LETTERS)),
      TestStreamPayload.Event(
          element_event=TestStreamPayload.Event.AddElements(
              elements=[
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode('a'), timestamp=0)
              ],
              tag=CACHED_LETTERS)),

      # Advances clock from 1 to 2
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=1 * 10**6)),
      TestStreamPayload.Event(
          element_event=TestStreamPayload.Event.AddElements(
              elements=[
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode(1), timestamp=0)
              ],
              tag=CACHED_NUMBERS)),

      # Advances clock from 2 to 3
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=1 * 10**6)),
      TestStreamPayload.Event(
          element_event=TestStreamPayload.Event.AddElements(
              elements=[
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode(2), timestamp=0)
              ],
              tag=CACHED_NUMBERS)),

      # Advances clock from 3 to 4
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=1 * 10**6)),
      TestStreamPayload.Event(
          element_event=TestStreamPayload.Event.AddElements(
              elements=[
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode(2), timestamp=0)
              ],
              tag=CACHED_NUMBERS)),

      # Advances clock from 4 to 11
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=7 * 10**6)),
      TestStreamPayload.Event(
          watermark_event=TestStreamPayload.Event.AdvanceWatermark(
              new_watermark=10 * 10**6, tag=CACHED_LETTERS)),
      TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                              AddElements(elements=[
                                  TestStreamPayload.TimestampedElement(
                                      encoded_element=coder.encode('b'),
                                      timestamp=10 * 10**6)
                              ],
                                          tag=CACHED_LETTERS)),

      # Advances clock from 11 to 101
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=90 * 10**6)),
      TestStreamPayload.Event(
          element_event=TestStreamPayload.Event.AddElements(
              elements=[
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode('late'), timestamp=0)
              ],
              tag=CACHED_LATE)),
  ]

  self.assertSequenceEqual(events, expected)
def test_read_and_write_multiple_outputs(self):
  """An integration test between the Sink and Source with multiple outputs.

  This tests the functionality that the StreamingCache reads from multiple
  files and combines them into a single sorted output.
  """
  LETTERS_TAG = repr(CacheKey('letters', '', '', ''))
  NUMBERS_TAG = repr(CacheKey('numbers', '', '', ''))

  # Units here are in seconds.
  test_stream = (TestStream()
                 .advance_watermark_to(0, tag=LETTERS_TAG)
                 .advance_processing_time(5)
                 .add_elements(['a', 'b', 'c'], tag=LETTERS_TAG)
                 .advance_watermark_to(10, tag=NUMBERS_TAG)
                 .advance_processing_time(1)
                 .add_elements(
                     [
                         TimestampedValue('1', 15),
                         TimestampedValue('2', 15),
                         TimestampedValue('3', 15)
                     ],
                     tag=NUMBERS_TAG)) # yapf: disable

  cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

  coder = SafeFastPrimitivesCoder()

  # Each tagged output of the TestStream is written to its own cache sink;
  # the pipeline finishes when the `with` block exits.
  options = StandardOptions(streaming=True)
  with TestPipeline(options=options) as p:
    # pylint: disable=expression-not-assigned
    events = p | test_stream
    events[LETTERS_TAG] | 'Letters sink' >> cache.sink([LETTERS_TAG])
    events[NUMBERS_TAG] | 'Numbers sink' >> cache.sink([NUMBERS_TAG])

  reader = cache.read_multiple([[LETTERS_TAG], [NUMBERS_TAG]])
  actual_events = list(reader)

  # Units here are in microseconds.
  # NOTE(review): a LETTERS_TAG watermark of 0 is expected again after the
  # second processing-time advance; presumably the cache sink re-emits the
  # current watermark per sample — confirm against StreamingCache.sink.
  expected_events = [
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=5 * 10**6)),
      TestStreamPayload.Event(
          watermark_event=TestStreamPayload.Event.AdvanceWatermark(
              new_watermark=0, tag=LETTERS_TAG)),
      TestStreamPayload.Event(
          element_event=TestStreamPayload.Event.AddElements(
              elements=[
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode('a'), timestamp=0),
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode('b'), timestamp=0),
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode('c'), timestamp=0),
              ],
              tag=LETTERS_TAG)),
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=1 * 10**6)),
      TestStreamPayload.Event(
          watermark_event=TestStreamPayload.Event.AdvanceWatermark(
              new_watermark=10 * 10**6, tag=NUMBERS_TAG)),
      TestStreamPayload.Event(
          watermark_event=TestStreamPayload.Event.AdvanceWatermark(
              new_watermark=0, tag=LETTERS_TAG)),
      TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                              AddElements(elements=[
                                  TestStreamPayload.TimestampedElement(
                                      encoded_element=coder.encode('1'),
                                      timestamp=15 * 10**6),
                                  TestStreamPayload.TimestampedElement(
                                      encoded_element=coder.encode('2'),
                                      timestamp=15 * 10**6),
                                  TestStreamPayload.TimestampedElement(
                                      encoded_element=coder.encode('3'),
                                      timestamp=15 * 10**6),
                              ],
                                          tag=NUMBERS_TAG)),
  ]

  self.assertListEqual(actual_events, expected_events)
def cache_key_of(self, name, pcoll):
  """Returns the string cache key for ``pcoll`` known by variable ``name``."""
  key = CacheKey.from_pcoll(name, pcoll)
  return key.to_str()
def test_read_and_write(self):
  """An integration test between the Sink and Source.

  This ensures that the sink and source speak the same language in terms of
  coders, protos, order, and units.
  """
  CACHED_RECORDS = repr(CacheKey('records', '', '', ''))

  # Units here are in seconds.
  # `output_tags` expects a collection of tags, so pass a one-element list:
  # a parenthesized name such as `(CACHED_RECORDS)` is just the str itself,
  # not a tuple.
  test_stream = (
      TestStream(output_tags=[CACHED_RECORDS])
      .advance_watermark_to(0, tag=CACHED_RECORDS)
      .advance_processing_time(5)
      .add_elements(['a', 'b', 'c'], tag=CACHED_RECORDS)
      .advance_watermark_to(10, tag=CACHED_RECORDS)
      .advance_processing_time(1)
      .add_elements(
          [
              TimestampedValue('1', 15),
              TimestampedValue('2', 15),
              TimestampedValue('3', 15)
          ],
          tag=CACHED_RECORDS)) # yapf: disable

  coder = SafeFastPrimitivesCoder()
  cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

  # Assert that there are no capture keys at first.
  self.assertEqual(cache.capture_keys, set())

  options = StandardOptions(streaming=True)
  with TestPipeline(options=options) as p:
    records = (p | test_stream)[CACHED_RECORDS]

    # pylint: disable=expression-not-assigned
    records | cache.sink([CACHED_RECORDS], is_capture=True)

  reader, _ = cache.read(CACHED_RECORDS)
  actual_events = list(reader)

  # Assert that the capture keys are forwarded correctly.
  self.assertEqual(cache.capture_keys, set([CACHED_RECORDS]))

  # Units here are in microseconds.
  expected_events = [
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=5 * 10**6)),
      TestStreamPayload.Event(
          watermark_event=TestStreamPayload.Event.AdvanceWatermark(
              new_watermark=0, tag=CACHED_RECORDS)),
      TestStreamPayload.Event(
          element_event=TestStreamPayload.Event.AddElements(
              elements=[
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode('a'), timestamp=0),
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode('b'), timestamp=0),
                  TestStreamPayload.TimestampedElement(
                      encoded_element=coder.encode('c'), timestamp=0),
              ],
              tag=CACHED_RECORDS)),
      TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                              Event.AdvanceProcessingTime(
                                  advance_duration=1 * 10**6)),
      TestStreamPayload.Event(
          watermark_event=TestStreamPayload.Event.AdvanceWatermark(
              new_watermark=10 * 10**6, tag=CACHED_RECORDS)),
      TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                              AddElements(elements=[
                                  TestStreamPayload.TimestampedElement(
                                      encoded_element=coder.encode('1'),
                                      timestamp=15 * 10**6),
                                  TestStreamPayload.TimestampedElement(
                                      encoded_element=coder.encode('2'),
                                      timestamp=15 * 10**6),
                                  TestStreamPayload.TimestampedElement(
                                      encoded_element=coder.encode('3'),
                                      timestamp=15 * 10**6),
                              ],
                                          tag=CACHED_RECORDS)),
  ]
  self.assertEqual(actual_events, expected_events)