def test_read_cache_expansion(self):
  p = beam.Pipeline(runner=self.runner)

  # The cold run.
  pcoll = (p
           | 'Create' >> beam.Create([1, 2, 3])
           | 'Double' >> beam.Map(lambda x: x * 2)
           | 'Square' >> beam.Map(lambda x: x**2))
  pipeline_proto = to_stable_runner_api(p)

  pipeline_info = pipeline_analyzer.PipelineInfo(pipeline_proto.components)
  pcoll_id = 'ref_PCollection_PCollection_3'  # Output PCollection of Square
  cache_label1 = pipeline_info.cache_label(pcoll_id)

  analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
                                                pipeline_proto,
                                                self.runner)
  pipeline_to_execute = beam.pipeline.Pipeline.from_runner_api(
      analyzer.pipeline_proto_to_execute(), p.runner, p._options)
  pipeline_to_execute.run().wait_until_finish()

  # The second run.
  _ = (pcoll
       | 'Triple' >> beam.Map(lambda x: x * 3)
       | 'Cube' >> beam.Map(lambda x: x**3))
  analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
                                                to_stable_runner_api(p),
                                                self.runner)

  expected_pipeline = beam.Pipeline(runner=self.runner)
  pcoll1 = (expected_pipeline
            | 'Load%s' % cache_label1 >> cache.ReadCache(
                self.cache_manager, cache_label1))
  pcoll2 = pcoll1 | 'Triple' >> beam.Map(lambda x: x * 3)
  pcoll3 = pcoll2 | 'Cube' >> beam.Map(lambda x: x**3)

  cache_label2 = 'PColl-7654321'
  cache_label3 = 'PColl-3141593'

  # pylint: disable=expression-not-assigned
  pcoll2 | 'CacheSample%s' % cache_label2 >> cache.WriteCache(
      self.cache_manager, cache_label2, sample=True, sample_size=10)
  pcoll3 | 'CacheSample%s' % cache_label3 >> cache.WriteCache(
      self.cache_manager, cache_label3, sample=True, sample_size=10)
  pcoll3 | 'CacheFull%s' % cache_label3 >> cache.WriteCache(
      self.cache_manager, cache_label3)

  # Since ReadCache & WriteCache expansion leads to more than 50 PTransform
  # protos in the pipeline, a simple check of proto map size is enough.
  self.assertPipelineEqual(analyzer.pipeline_proto_to_execute(),
                           to_stable_runner_api(expected_pipeline))
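# The test above calls to_stable_runner_api() and relies on self.runner and
# self.cache_manager from the test fixture. A minimal sketch of what those
# pieces are assumed to look like (illustrative, not the authoritative test
# harness): the helper simply round-trips the Pipeline through the runner API
# so the resulting proto is stable across calls.
def setUp(self):
  # Assumed fixture: a DirectRunner plus a file-based cache manager.
  self.runner = direct_runner.DirectRunner()
  self.cache_manager = cache.FileBasedCacheManager()

def to_stable_runner_api(p):
  """Returns a stable pipeline proto by round-tripping through the runner API."""
  return beam.pipeline.Pipeline.from_runner_api(
      p.to_runner_api(), p.runner, p._options).to_runner_api()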
def run_pipeline(self, pipeline):
  if not hasattr(self, '_desired_cache_labels'):
    self._desired_cache_labels = set()

  # Invoke a round trip through the runner API. This makes sure the Pipeline
  # proto is stable.
  pipeline = beam.pipeline.Pipeline.from_runner_api(
      pipeline.to_runner_api(), pipeline.runner, pipeline._options)

  # Snapshot the pipeline in a portable proto before mutating it.
  pipeline_proto, original_context = pipeline.to_runner_api(
      return_context=True)
  pcolls_to_pcoll_id = self._pcolls_to_pcoll_id(pipeline, original_context)

  analyzer = pipeline_analyzer.PipelineAnalyzer(self._cache_manager,
                                                pipeline_proto,
                                                self._underlying_runner,
                                                pipeline._options,
                                                self._desired_cache_labels)
  # Should only be accessed for debugging purposes.
  self._analyzer = analyzer

  pipeline_to_execute = beam.pipeline.Pipeline.from_runner_api(
      analyzer.pipeline_proto_to_execute(), self._underlying_runner,
      pipeline._options)

  pipeline_info = pipeline_analyzer.PipelineInfo(pipeline_proto.components)

  display = display_manager.DisplayManager(
      pipeline_info=pipeline_info,
      pipeline_proto=pipeline_proto,
      caches_used=analyzer.caches_used(),
      cache_manager=self._cache_manager,
      referenced_pcollections=analyzer.top_level_referenced_pcollection_ids(),
      required_transforms=analyzer.top_level_required_transforms())

  display.start_periodic_update()
  result = pipeline_to_execute.run()
  result.wait_until_finish()
  display.stop_periodic_update()

  return PipelineResult(result, self, pipeline_info, self._cache_manager,
                        pcolls_to_pcoll_id)
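# run_pipeline() above depends on self._pcolls_to_pcoll_id(). A minimal sketch
# of that helper, assuming it maps each output PCollection of the user
# pipeline to its id in the snapshotted proto context (illustrative, not the
# exact implementation):
def _pcolls_to_pcoll_id(self, pipeline, original_context):
  """Returns a dict mapping output PCollections to their PCollection ids."""
  from apache_beam.pipeline import PipelineVisitor

  pcolls_to_pcoll_id = {}

  class PCollVisitor(PipelineVisitor):
    """Records the proto id of every PCollection produced by the pipeline."""
    def visit_transform(self, transform_node):
      for pcoll in transform_node.outputs.values():
        pcolls_to_pcoll_id[pcoll] = original_context.pcollections.get_id(pcoll)

  pipeline.visit(PCollVisitor())
  return pcolls_to_pcoll_id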
def test_passthrough(self):
  """
  Test that PTransforms which pass through their input PCollection can be
  used with PipelineInfo.
  """
  class Passthrough(beam.PTransform):
    def expand(self, pcoll):
      return pcoll

  p = beam.Pipeline(runner=self.runner)
  p | beam.Impulse() | Passthrough()  # pylint: disable=expression-not-assigned
  proto = to_stable_runner_api(p).components
  info = pipeline_analyzer.PipelineInfo(proto)
  for pcoll_id in info.all_pcollections():
    # FIXME: If PipelineInfo does not support passthrough PTransforms, this
    # will only fail some of the time, depending on the ordering of
    # transforms in the Pipeline proto.

    # Should not throw exception
    info.cache_label(pcoll_id)
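# The excerpts above come from a test module and the interactive runner
# module; they assume roughly the following module-level imports (package
# paths follow apache_beam.runners.interactive and may differ between Beam
# versions). PipelineResult, referenced by run_pipeline(), is defined
# alongside the runner rather than imported.
import apache_beam as beam
from apache_beam.runners.direct import direct_runner
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive import pipeline_analyzer
from apache_beam.runners.interactive.display import display_manager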