def test_read_cache(self, mocked_get_cache_manager):
  """Checks that ReadCache rewires a consumer to read from the cache.

  Builds a tiny pipeline, pre-populates an in-memory cache for the source
  PCollection, applies the cache read to the pipeline proto, and verifies
  the result matches applying the read transform directly on the pipeline
  instance. No pipeline is actually run.
  """
  # These local names matter: ib.watch(locals()) keys cacheables off them.
  p = beam.Pipeline()
  pcoll = p | beam.Create([1, 2, 3])
  consumer_transform = beam.Map(lambda x: x * x)
  _ = pcoll | consumer_transform
  ib.watch(locals())

  # Back the pipeline with an in-memory cache and write a value for pcoll so
  # the test never needs to execute the pipeline.
  in_mem_cache = InMemoryCache()
  mocked_get_cache_manager.return_value = in_mem_cache
  aug_p = ap.AugmentedPipeline(p)
  key = repr(aug_p._cacheables[pcoll].to_key())
  in_mem_cache.write('test', 'full', key)

  # Locate the applied transform that consumes pcoll in the runner API proto.
  pcoll_id = aug_p._context.pcollections.get_id(pcoll)
  pipeline_proto = p.to_runner_api()
  consumer_transform_id = next(
      (xform_id
       for xform_id, xform in pipeline_proto.components.transforms.items()
       if pcoll_id in xform.inputs.values()),
      None)
  self.assertIsNotNone(consumer_transform_id)

  # Apply the cache read onto the pipeline proto.
  _, cache_id = read_cache.ReadCache(
      pipeline_proto,
      aug_p._context,
      aug_p._cache_manager,
      aug_p._cacheables[pcoll]).read_cache()
  actual_pipeline = pipeline_proto

  # Apply the equivalent cache read directly on the pipeline instance.
  label = '_cache_' + key
  cache_read = read_cache._ReadCacheTransform(aug_p._cache_manager, key, label)
  p | 'source' + label >> cache_read
  expected_pipeline = p.to_runner_api()

  # This roughly checks the equivalence between the two protos, not detailed
  # wiring in sub transforms under top level transforms.
  assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

  # The consumer must now take the cache output as input instead of the
  # original pcoll from the source.
  inputs = actual_pipeline.components.transforms[consumer_transform_id].inputs
  self.assertIn(cache_id, inputs.values())
  self.assertNotIn(pcoll_id, inputs.values())
def test_describe(self):
  """Checks that Recording.describe() reports cache size and start time.

  Writes the PCollection contents straight into an in-memory cache so no
  pipeline has to run, then asserts the description aggregates the per-stream
  cache sizes and echoes the injected start time.
  """
  # These local names matter: ib.watch(locals()) keys cacheables off them.
  p = beam.Pipeline(InteractiveRunner())
  numbers = p | 'numbers' >> beam.Create([0, 1, 2])
  letters = p | 'letters' >> beam.Create(['a', 'b', 'c'])
  ib.watch(locals())

  # A MockPipelineResult controls the state of a fake run of the pipeline.
  fake_result = MockPipelineResult()
  ie.current_env().track_user_pipelines()
  ie.current_env().set_pipeline_result(p, fake_result)
  cache = InMemoryCache()
  ie.current_env().set_cache_manager(cache, p)

  # Create a recording with an arbitrary start time.
  recording_start = 100
  recording = Recording(
      p, [numbers, letters],
      fake_result,
      pi.PipelineInstrument(p),
      max_n=10,
      max_duration_secs=60,
      start_time_for_test=recording_start)

  # Write directly to the cache keys of both streams so the describe() call
  # has data to size up without running the pipeline.
  numbers_stream = recording.stream(numbers)
  cache.write([0, 1, 2], 'full', numbers_stream.cache_key)
  cache.save_pcoder(None, 'full', numbers_stream.cache_key)

  letters_stream = recording.stream(letters)
  cache.write(['a', 'b', 'c'], 'full', letters_stream.cache_key)
  cache.save_pcoder(None, 'full', letters_stream.cache_key)

  # The description's size must be the sum of both streams' cache sizes and
  # its start must be the injected test start time.
  description = recording.describe()
  expected_size = (
      cache.size('full', numbers_stream.cache_key) +
      cache.size('full', letters_stream.cache_key))
  self.assertEqual(description['size'], expected_size)
  self.assertEqual(description['start'], recording_start)
class ElementStreamTest(unittest.TestCase):
  """Tests for ElementStream reading from a pre-populated in-memory cache.

  NOTE(review): a second class with this exact name appears later in the
  file and will shadow this one at import time, so these tests may never
  run — confirm whether both definitions are intended to coexist.
  """
  def setUp(self):
    # Fresh cache and a trivial pipeline; the cache key is fabricated so the
    # tests can write to the cache without ever running the pipeline.
    self.cache = InMemoryCache()
    self.p = beam.Pipeline()
    self.pcoll = self.p | beam.Create([])
    self.cache_key = str(CacheKey('pcoll', '', '', ''))

    # Create a MockPipelineResult to control the state of a fake run of the
    # pipeline.
    self.mock_result = MockPipelineResult()
    ie.current_env().add_user_pipeline(self.p)
    ie.current_env().set_pipeline_result(self.p, self.mock_result)
    ie.current_env().set_cache_manager(self.cache, self.p)

  def test_read(self):
    """Test reading and if a stream is done no more elements are returned."""
    self.mock_result.set_state(PipelineState.DONE)
    self.cache.write(['expected'], 'full', self.cache_key)
    self.cache.save_pcoder(None, 'full', self.cache_key)

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=1, max_duration_secs=1)

    self.assertFalse(stream.is_done())
    self.assertEqual(list(stream.read())[0], 'expected')
    self.assertTrue(stream.is_done())

  def test_done_if_terminated(self):
    """Test that terminating the job sets the stream as done."""
    self.cache.write(['expected'], 'full', self.cache_key)
    self.cache.save_pcoder(None, 'full', self.cache_key)

    # Generous limits (max_n=100, 10s) so the limiters alone never finish the
    # stream; only pipeline termination should.
    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=100, max_duration_secs=10)

    self.assertFalse(stream.is_done())
    self.assertEqual(list(stream.read(tail=False))[0], 'expected')

    # The limiters were not reached, so the stream is not done yet.
    self.assertFalse(stream.is_done())

    self.mock_result.set_state(PipelineState.DONE)
    self.assertEqual(list(stream.read(tail=False))[0], 'expected')

    # The underlying pipeline is terminated, so the stream won't yield new
    # elements.
    self.assertTrue(stream.is_done())

  def test_read_n(self):
    """Test that the stream only reads 'n' elements."""
    self.mock_result.set_state(PipelineState.DONE)
    self.cache.write(list(range(5)), 'full', self.cache_key)
    self.cache.save_pcoder(None, 'full', self.cache_key)

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=1, max_duration_secs=1)
    self.assertEqual(list(stream.read()), [0])
    self.assertTrue(stream.is_done())

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=2, max_duration_secs=1)
    self.assertEqual(list(stream.read()), [0, 1])
    self.assertTrue(stream.is_done())

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=5, max_duration_secs=1)
    self.assertEqual(list(stream.read()), list(range(5)))
    self.assertTrue(stream.is_done())

    # Test that if the user asks for more than in the cache it still returns.
    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=10, max_duration_secs=1)
    self.assertEqual(list(stream.read()), list(range(5)))
    self.assertTrue(stream.is_done())

  def test_read_duration(self):
    """Test that the stream only reads a 'duration' of elements."""
    def as_windowed_value(element):
      # Wrap a raw element so the cached records carry windowing metadata.
      return WindowedValueHolder(WindowedValue(element, 0, []))

    # Build a synthetic record stream: one element per second of processing
    # time, so max_duration_secs maps directly onto element count.
    values = (FileRecordsBuilder(tag=self.cache_key)
              .advance_processing_time(1)
              .add_element(element=as_windowed_value(0), event_time_secs=0)
              .advance_processing_time(1)
              .add_element(element=as_windowed_value(1), event_time_secs=1)
              .advance_processing_time(1)
              .add_element(element=as_windowed_value(2), event_time_secs=3)
              .advance_processing_time(1)
              .add_element(element=as_windowed_value(3), event_time_secs=4)
              .advance_processing_time(1)
              .add_element(element=as_windowed_value(4), event_time_secs=5)
              .build()) # yapf: disable

    # Keep only the recorded events of TestStreamFileRecord messages; the
    # builder may also emit header/other record types.
    values = [
        v.recorded_event for v in values
        if isinstance(v, beam_interactive_api_pb2.TestStreamFileRecord)
    ]

    self.mock_result.set_state(PipelineState.DONE)
    self.cache.write(values, 'full', self.cache_key)
    self.cache.save_pcoder(coders.FastPrimitivesCoder(), 'full',
                           self.cache_key)

    # The following tests a progression of reading different durations from the
    # cache.
    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=100, max_duration_secs=1)
    self.assertSequenceEqual([e.value for e in stream.read()], [0])

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=100, max_duration_secs=2)
    self.assertSequenceEqual([e.value for e in stream.read()], [0, 1])

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=100, max_duration_secs=10)
    self.assertSequenceEqual([e.value for e in stream.read()], [0, 1, 2, 3, 4])
class ElementStreamTest(unittest.TestCase):
  """Tests for ElementStream reading from a pre-populated in-memory cache.

  NOTE(review): this duplicates the name of an earlier class in the file;
  being defined later, this version shadows the earlier one at import time —
  confirm the duplication is intentional.
  """
  def setUp(self):
    # Reset the interactive environment so state doesn't leak between tests.
    ie.new_env()
    self.cache = InMemoryCache()
    self.p = beam.Pipeline()
    self.pcoll = self.p | beam.Create([])
    # Fabricated cache key lets the tests write to the cache without ever
    # running the pipeline.
    self.cache_key = str(pi.CacheKey('pcoll', '', '', ''))

    # Create a MockPipelineResult to control the state of a fake run of the
    # pipeline.
    self.mock_result = MockPipelineResult()
    ie.current_env().track_user_pipelines()
    ie.current_env().set_pipeline_result(self.p, self.mock_result)
    ie.current_env().set_cache_manager(self.cache, self.p)

  def test_read(self):
    """Test reading and if a stream is done no more elements are returned."""
    self.mock_result.set_state(PipelineState.DONE)
    self.cache.write(['expected'], 'full', self.cache_key)
    self.cache.save_pcoder(None, 'full', self.cache_key)

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=1, max_duration_secs=1)

    self.assertFalse(stream.is_done())
    self.assertEqual(list(stream.read())[0], 'expected')
    self.assertTrue(stream.is_done())

  def test_done_if_terminated(self):
    """Test that terminating the job sets the stream as done."""
    self.cache.write(['expected'], 'full', self.cache_key)
    self.cache.save_pcoder(None, 'full', self.cache_key)

    # Generous limits (max_n=100, 10s) so the limiters alone never finish the
    # stream; only pipeline termination should.
    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=100, max_duration_secs=10)

    self.assertFalse(stream.is_done())
    self.assertEqual(list(stream.read(tail=False))[0], 'expected')

    # The limiters were not reached, so the stream is not done yet.
    self.assertFalse(stream.is_done())

    self.mock_result.set_state(PipelineState.DONE)
    self.assertEqual(list(stream.read(tail=False))[0], 'expected')

    # The underlying pipeline is terminated, so the stream won't yield new
    # elements.
    self.assertTrue(stream.is_done())

  def test_read_n(self):
    """Test that the stream only reads 'n' elements."""
    self.mock_result.set_state(PipelineState.DONE)
    self.cache.write(list(range(5)), 'full', self.cache_key)
    self.cache.save_pcoder(None, 'full', self.cache_key)

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=1, max_duration_secs=1)
    self.assertEqual(list(stream.read()), [0])
    self.assertTrue(stream.is_done())

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=2, max_duration_secs=1)
    self.assertEqual(list(stream.read()), [0, 1])
    self.assertTrue(stream.is_done())

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=5, max_duration_secs=1)
    self.assertEqual(list(stream.read()), list(range(5)))
    self.assertTrue(stream.is_done())

    # Test that if the user asks for more than in the cache it still returns.
    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=10, max_duration_secs=1)
    self.assertEqual(list(stream.read()), list(range(5)))
    self.assertTrue(stream.is_done())

  def test_read_duration(self):
    """Test that the stream only reads a 'duration' of elements."""
    # Build a synthetic record stream: one element per second of processing
    # time, so max_duration_secs maps directly onto element count.
    values = (FileRecordsBuilder(tag=self.cache_key)
              .advance_processing_time(1)
              .add_element(element=0, event_time_secs=0)
              .advance_processing_time(1)
              .add_element(element=1, event_time_secs=1)
              .advance_processing_time(1)
              .add_element(element=2, event_time_secs=3)
              .advance_processing_time(1)
              .add_element(element=3, event_time_secs=4)
              .advance_processing_time(1)
              .add_element(element=4, event_time_secs=5)
              .build()) # yapf: disable

    self.mock_result.set_state(PipelineState.DONE)
    self.cache.write(values, 'full', self.cache_key)
    self.cache.save_pcoder(None, 'full', self.cache_key)

    # The elements read from the cache are TestStreamFileRecord instances and
    # have the underlying elements encoded. This method decodes the elements
    # from the TestStreamFileRecord.
    def get_elements(events):
      coder = coders.FastPrimitivesCoder()
      elements = []
      for e in events:
        if not isinstance(e, TestStreamFileRecord):
          continue

        if e.recorded_event.element_event:
          elements += ([
              coder.decode(el.encoded_element)
              for el in e.recorded_event.element_event.elements
          ])
      return elements

    # The following tests a progression of reading different durations from the
    # cache.
    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=100, max_duration_secs=1)
    self.assertSequenceEqual(get_elements(stream.read()), [0])

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=100, max_duration_secs=2)
    self.assertSequenceEqual(get_elements(stream.read()), [0, 1])

    stream = ElementStream(
        self.pcoll, '', self.cache_key, max_n=100, max_duration_secs=10)
    self.assertSequenceEqual(get_elements(stream.read()), [0, 1, 2, 3, 4])