def test_write_cache_expansion(self):
  p = beam.Pipeline(runner=self.runner)
  pcoll1 = p | 'Create' >> beam.Create([1, 2, 3])
  pcoll2 = pcoll1 | 'Double' >> beam.Map(lambda x: x * 2)
  pcoll3 = pcoll2 | 'Square' >> beam.Map(lambda x: x**2)

  # The analyzer sees the pipeline *before* any caching transforms exist; its
  # output is compared against the same pipeline with caching added by hand.
  analyzer = pipeline_analyzer.PipelineAnalyzer(
      self.cache_manager, to_stable_runner_api(p), self.runner)

  # Attach a sampled cache to every PCollection, in pipeline order, and then a
  # full cache to the last one.
  sampled_targets = [
      (pcoll1, 'PColl-1234567'),
      (pcoll2, 'PColl-7654321'),
      (pcoll3, 'PColl-3141593'),
  ]
  for target, label in sampled_targets:
    _ = target | 'CacheSample%s' % label >> cache.WriteCache(
        self.cache_manager, label, sample=True, sample_size=10)
  leaf_label = sampled_targets[-1][1]
  _ = pcoll3 | 'CacheFull%s' % leaf_label >> cache.WriteCache(
      self.cache_manager, leaf_label)

  expected_proto = to_stable_runner_api(p)
  # WriteCache expands into more than 50 PTransform protos, so a simple check
  # of the proto map sizes is considered sufficient here.
  self.assertPipelineEqual(analyzer.pipeline_proto_to_execute(),
                           expected_proto)
def test_background_caching_pipeline_proto(self):
  p = beam.Pipeline(interactive_runner.InteractiveRunner())

  # Two unbounded sources that the instrumenter should cut out.
  a = p | 'ReadUnboundedSourceA' >> beam.io.ReadFromPubSub(
      subscription='projects/fake-project/subscriptions/fake_sub')
  b = p | 'ReadUnboundedSourceB' >> beam.io.ReadFromPubSub(
      subscription='projects/fake-project/subscriptions/fake_sub')

  # Downstream work that must not survive: only the unbounded sources should
  # remain in the background caching pipeline.
  c = (a, b) | beam.CoGroupByKey()
  _ = c | beam.Map(lambda x: x)

  # NOTE: ib.watch captures the local names defined above, so they are kept
  # exactly as-is.
  ib.watch(locals())
  actual_pipeline = instr.pin(p).background_caching_pipeline_proto()

  # Expected result: each unbounded source immediately followed by a cache
  # write, built in the same order (A read, A write, B read, B write).
  expected = beam.Pipeline(interactive_runner.InteractiveRunner())
  for read_label, write_label in (('ReadUnboundedSourceA', 'a'),
                                  ('ReadUnboundedSourceB', 'b')):
    source = expected | read_label >> beam.io.ReadFromPubSub(
        subscription='projects/fake-project/subscriptions/fake_sub')
    _ = source | write_label >> cache.WriteCache(
        ie.current_env().cache_manager(), '')

  expected_pipeline = expected.to_runner_api(
      return_context=False, use_fake_coders=True)
  assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)
def test_read_cache_expansion(self):
  p = beam.Pipeline(runner=self.runner)

  # Cold run: execute Create -> Double -> Square through the analyzer.
  square_out = (
      p
      | 'Create' >> beam.Create([1, 2, 3])
      | 'Double' >> beam.Map(lambda x: x * 2)
      | 'Square' >> beam.Map(lambda x: x**2))
  cold_proto = to_stable_runner_api(p)
  info = pipeline_analyzer.PipelineInfo(cold_proto.components)
  # 'ref_PCollection_PCollection_3' is the output PCollection of 'Square'.
  square_label = info.cache_label('ref_PCollection_PCollection_3')

  cold_analyzer = pipeline_analyzer.PipelineAnalyzer(
      self.cache_manager, cold_proto, self.runner)
  cold_run = beam.pipeline.Pipeline.from_runner_api(
      cold_analyzer.pipeline_proto_to_execute(), p.runner, p._options)
  cold_run.run().wait_until_finish()

  # Warm run: extend the pipeline; 'Square' should now be read from cache.
  _ = (
      square_out
      | 'Triple' >> beam.Map(lambda x: x * 3)
      | 'Cube' >> beam.Map(lambda x: x**3))
  warm_analyzer = pipeline_analyzer.PipelineAnalyzer(
      self.cache_manager, to_stable_runner_api(p), self.runner)

  expected_pipeline = beam.Pipeline(runner=self.runner)
  loaded = (
      expected_pipeline
      | 'Load%s' % square_label >> cache.ReadCache(
          self.cache_manager, square_label))
  tripled = loaded | 'Triple' >> beam.Map(lambda x: x * 3)
  cubed = tripled | 'Cube' >> beam.Map(lambda x: x**3)

  cubed_label = 'PColl-3141593'
  for target, label in ((tripled, 'PColl-7654321'), (cubed, cubed_label)):
    _ = target | 'CacheSample%s' % label >> cache.WriteCache(
        self.cache_manager, label, sample=True, sample_size=10)
  _ = cubed | 'CacheFull%s' % cubed_label >> cache.WriteCache(
      self.cache_manager, cubed_label)

  # ReadCache and WriteCache together expand into more than 50 PTransform
  # protos, so a simple check of the proto map sizes is considered enough.
  self.assertPipelineEqual(warm_analyzer.pipeline_proto_to_execute(),
                           to_stable_runner_api(expected_pipeline))
def test_instrument_example_pipeline_to_write_cache(self):
  # The user-facing instance holds handles to every PCollection variable.
  p_origin, init_pcoll, second_pcoll = self._example_pipeline()
  # The execution-time copy carries no user-defined variables.
  p_copy, _, _ = self._example_pipeline(False)

  pin = instr.pin(p_copy)

  # Hand-apply to the original exactly the cache writes that instrumentation
  # is expected to add to the copy, in the same order.
  for target in (init_pcoll, second_pcoll):
    key = pin.cache_key(target)
    _ = target | ('_WriteCache_' + key) >> cache.WriteCache(
        ie.current_env().cache_manager(), key)

  # After instrumentation both pipelines should match.
  self.assertPipelineEqual(p_copy, p_origin)
def test_instrument_example_pipeline_to_write_cache(self):
  # The user-facing instance holds handles to every PCollection variable.
  p_origin, init_pcoll, second_pcoll = self._example_pipeline()
  # The execution-time copy carries no user-defined variables.
  p_copy, _, _ = self._example_pipeline(False)

  pipeline_instrument = instr.build_pipeline_instrument(p_copy)

  # Hand-apply an identity "reify" step followed by a cache write to each
  # PCollection of the original, mirroring what instrumentation adds.
  for reify_label, target in (('reify init', init_pcoll),
                              ('reify second', second_pcoll)):
    key = pipeline_instrument.cache_key(target)
    _ = (
        target
        | reify_label >> beam.Map(lambda _: _)
        | '_WriteCache_' + key >> cache.WriteCache(
            ie.current_env().cache_manager(), key))

  # After instrumentation both pipelines should match.
  assert_pipeline_equal(self, p_copy, p_origin)
def _insert_caching_transforms(self,
                               pcoll_id,
                               required_transforms,
                               top_level_required_transforms,
                               sample=False):
  """Applies a WriteCache to one PCollection and records its transforms.

  The transform returned by ``self._top_level_producer`` for the WriteCache
  output is stored in ``top_level_required_transforms``. It and every
  transform yielded by ``self._include_subtransforms`` for it are stored in
  ``required_transforms``. Both dicts are keyed by transform id.

  Args:
    pcoll_id: (str) id of the PCollection to cache.
    required_transforms: (Dict[str, PTransform proto]) updated in place.
    top_level_required_transforms: (Dict[str, PTransform proto]) updated in
      place.
    sample: (bool) True writes a sampled cache (sample_size=10); False
      caches the full PCollection.

  Modifies:
    required_transforms
    top_level_required_transforms
    self._write_cache_ids
  """
  cache_label = self._pipeline_info.cache_label(pcoll_id)
  pcoll = self._context.pcollections.get_by_id(pcoll_id)

  if sample:
    label = 'CacheSample%s' % cache_label
    writer = cache.WriteCache(
        self._cache_manager, cache_label, sample=True, sample_size=10)
  else:
    label = 'CacheFull%s' % cache_label
    writer = cache.WriteCache(self._cache_manager, cache_label)
  pdone = pcoll | label >> writer

  write_cache = self._top_level_producer(pdone)
  # The id is requested before serialization so ids are assigned in the same
  # order as before.
  write_cache_id = self._context.transforms.get_id(write_cache)
  top_level_required_transforms[write_cache_id] = write_cache.to_runner_api(
      self._context)
  self._write_cache_ids.add(write_cache_id)

  for sub in self._include_subtransforms(write_cache):
    sub_id = self._context.transforms.get_id(sub)
    required_transforms[sub_id] = sub.to_runner_api(self._context)
def expand(self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
  # Wraps every element together with its timestamp, window and pane info in
  # a WindowedValueHolder before it is written to the cache.
  # (Class and parameter names are kept stable: the DoFn is serialized.)
  class Reify(beam.DoFn):
    def process(self,
                e,
                w=beam.DoFn.WindowParam,
                p=beam.DoFn.PaneInfoParam,
                t=beam.DoFn.TimestampParam):
      yield test_stream.WindowedValueHolder(WindowedValue(e, t, [w], p))

  reify_label = 'reify' + self._label
  write_label = 'write' + self._label
  reified = pcoll | reify_label >> beam.ParDo(Reify())
  return reified | write_label >> cache.WriteCache(
      self._cache_manager, self._key, is_capture=False)
def test_word_count(self):
  p = beam.Pipeline(runner=self.runner)

  class WordExtractingDoFn(beam.DoFn):

    def process(self, element):
      text_line = element.strip()
      words = text_line.split()
      return words

  # Count the occurrences of each word.
  pcoll1 = p | beam.Create(['to be or not to be that is the question'])
  pcoll2 = pcoll1 | 'Split' >> beam.ParDo(WordExtractingDoFn())
  pcoll3 = pcoll2 | 'Pair with One' >> beam.Map(lambda x: (x, 1))
  pcoll4 = pcoll3 | 'Group' >> beam.GroupByKey()
  pcoll5 = pcoll4 | 'Count' >> beam.Map(lambda item: (item[0], sum(item[1])))

  analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
                                                to_stable_runner_api(p),
                                                self.runner)

  cache_label1 = 'PColl-1111111'
  cache_label2 = 'PColl-2222222'
  cache_label3 = 'PColl-3333333'
  cache_label4 = 'PColl-4444444'
  cache_label5 = 'PColl-5555555'

  # Each WriteCache must be keyed by the label of the PCollection it caches.
  # pcoll4 and pcoll5 previously reused cache_label3 (copy-paste error), so
  # three distinct PCollections were cached under the same key.
  # pylint: disable=expression-not-assigned
  pcoll1 | 'CacheSample%s' % cache_label1 >> cache.WriteCache(
      self.cache_manager, cache_label1, sample=True, sample_size=10)
  pcoll2 | 'CacheSample%s' % cache_label2 >> cache.WriteCache(
      self.cache_manager, cache_label2, sample=True, sample_size=10)
  pcoll3 | 'CacheSample%s' % cache_label3 >> cache.WriteCache(
      self.cache_manager, cache_label3, sample=True, sample_size=10)
  pcoll4 | 'CacheSample%s' % cache_label4 >> cache.WriteCache(
      self.cache_manager, cache_label4, sample=True, sample_size=10)
  pcoll5 | 'CacheSample%s' % cache_label5 >> cache.WriteCache(
      self.cache_manager, cache_label5, sample=True, sample_size=10)
  pcoll5 | 'CacheFull%s' % cache_label5 >> cache.WriteCache(
      self.cache_manager, cache_label5)

  expected_pipeline_proto = to_stable_runner_api(p)
  self.assertPipelineEqual(analyzer.pipeline_proto_to_execute(),
                           expected_pipeline_proto)

  # The analyzed pipeline must also be executable end to end.
  pipeline_to_execute = beam.pipeline.Pipeline.from_runner_api(
      analyzer.pipeline_proto_to_execute(), p.runner, p._options)
  pipeline_to_execute.run().wait_until_finish()
def _write_cache(self, pipeline, pcoll):
  """Appends a cache-writing transform for a cacheable PCollection.

  The transform is added to ``pcoll``'s pipeline and materializes the
  PCollection into the cache through a sink. Nothing is written immediately:
  the write happens only when the runner executes the transformed pipeline,
  so the cache is not usable within the current run.

  The write is added only when ``pcoll`` belongs to ``pipeline`` (the
  pipeline being instrumented) and no full cache exists yet under its key.

  Modifies:
    pipeline
  """
  # PCollections from other pipelines are left untouched.
  if pcoll.pipeline is pipeline:
    # The keyed cache is always valid within this instrumentation.
    key = self.cache_key(pcoll)
    # Skip the write if a full cache already exists under this key.
    if not self._cache_manager.exists('full', key):
      _ = pcoll | '{}{}'.format(WRITE_CACHE, key) >> cache.WriteCache(
          self._cache_manager, key)
def reify_to_cache(
    pcoll: beam.pvalue.PCollection,
    cache_key: str,
    cache_manager: cache.CacheManager,
    reify_label: Optional[str] = None,
    write_cache_label: Optional[str] = None,
    is_capture: bool = False) -> beam.pvalue.PValue:
  """Reifies elements into windowed values and writes them to the cache.

  Args:
    pcoll: The PCollection to be cached.
    cache_key: The key of the cache.
    cache_manager: The cache manager to manage the cache.
    reify_label: (optional) Label for the Reify ParDo. When missing or empty,
      it defaults to 'ReifyBefore_' + WRITE_CACHE + cache_key.
    write_cache_label: (optional) Label for the cache-writing transform. When
      missing or empty, it defaults to WRITE_CACHE + cache_key.
    is_capture: Whether the cache is capturing a record of recordable sources.

  Returns:
    The output of the cache-writing transform.
  """
  reify_label = reify_label or '{}{}{}'.format('ReifyBefore_', WRITE_CACHE,
                                               cache_key)
  write_cache_label = write_cache_label or '{}{}'.format(
      WRITE_CACHE, cache_key)

  reified = pcoll | reify_label >> beam.ParDo(Reify())
  return reified | write_cache_label >> cache.WriteCache(
      cache_manager, cache_key, is_capture=is_capture)
def test_instrument_example_unbounded_pipeline_to_read_cache_not_cached(
    self):
  """Tests that the instrumenter works when the PCollection is not cached.
  """
  # Create the pipeline that will be instrumented.
  from apache_beam.options.pipeline_options import StandardOptions
  options = StandardOptions(streaming=True)
  p_original_read_cache = beam.Pipeline(
      interactive_runner.InteractiveRunner(), options)
  ie.current_env().set_cache_manager(
      StreamingCache(cache_dir=None), p_original_read_cache)
  source_1 = p_original_read_cache | 'source1' >> beam.io.ReadFromPubSub(
      subscription='projects/fake-project/subscriptions/fake_sub')
  # pylint: disable=possibly-unused-variable
  pcoll_1 = source_1 | 'square1' >> beam.Map(lambda x: x * x)

  # Watch but do not cache the PCollections. ib.watch captures the local
  # names above, so they are kept exactly as-is.
  ib.watch(locals())
  # This should be a no-op.
  utils.watch_sources(p_original_read_cache)

  # Instrument a runner-API copy of the original: the pipeline the user sees.
  p_copy = beam.Pipeline.from_runner_api(
      p_original_read_cache.to_runner_api(),
      runner=interactive_runner.InteractiveRunner(),
      options=options)
  ie.current_env().add_derived_pipeline(p_original_read_cache, p_copy)
  instrumenter = instr.build_pipeline_instrument(p_copy)
  actual_pipeline = beam.Pipeline.from_runner_api(
      proto=instrumenter.instrumented_pipeline_proto(),
      runner=interactive_runner.InteractiveRunner(),
      options=options)

  # Expected pipeline: the unbounded source is replaced by a TestStream.
  source_1_cache_key = self.cache_key_of('source_1', source_1)
  p_expected = beam.Pipeline()
  ie.current_env().set_cache_manager(
      StreamingCache(cache_dir=None), p_expected)
  expected_stream = p_expected | TestStream(output_tags=[source_1_cache_key])
  _ = (
      expected_stream[source_1_cache_key]
      | 'square1' >> beam.Map(lambda x: x * x)
      | 'reify' >> beam.Map(lambda _: _)
      | cache.WriteCache(
          ie.current_env().get_cache_manager(p_expected), 'unused'))

  class TestStreamVisitor(PipelineVisitor):
    """Records output_tags of the last TestStream visited."""

    def __init__(self):
      self.output_tags = set()

    def enter_composite_transform(self, transform_node):
      # Composites are inspected exactly like leaf transforms.
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      transform = transform_node.transform
      if isinstance(transform, TestStream):
        self.output_tags = transform.output_tags

  # The TestStream must output to the correct PCollection.
  visitor = TestStreamVisitor()
  actual_pipeline.visit(visitor)
  self.assertSetEqual({source_1_cache_key}, visitor.output_tags)

  # And the whole instrumented pipeline must match the expected one.
  assert_pipeline_proto_equal(self, p_expected.to_runner_api(),
                              instrumenter.instrumented_pipeline_proto())
def test_instrument_mixed_streaming_batch(self):
  """Tests caching for both batch and streaming sources in the same pipeline.

  This ensures that cached bounded and unbounded sources are read from the
  TestStream.
  """
  # Create the pipeline that will be instrumented.
  from apache_beam.options.pipeline_options import StandardOptions
  options = StandardOptions(streaming=True)
  p_original = beam.Pipeline(interactive_runner.InteractiveRunner(), options)
  streaming_cache_manager = StreamingCache(cache_dir=None)
  ie.current_env().set_cache_manager(streaming_cache_manager, p_original)
  source_1 = p_original | 'source1' >> beam.io.ReadFromPubSub(
      subscription='projects/fake-project/subscriptions/fake_sub')
  source_2 = p_original | 'source2' >> beam.Create([1, 2, 3, 4, 5])

  # pylint: disable=possibly-unused-variable
  pcoll_1 = ((source_1, source_2)
             | beam.Flatten()
             | 'square1' >> beam.Map(lambda x: x * x))

  # Watch but do not cache the PCollections. ib.watch captures the local
  # names above, so they are kept exactly as-is.
  ib.watch(locals())
  # This should be a no-op.
  utils.watch_sources(p_original)
  # Mark the bounded source as already computed, with a mocked cache write.
  self._mock_write_cache(p_original, [],
                         self.cache_key_of('source_2', source_2))
  ie.current_env().mark_pcollection_computed([source_2])

  # Instrument a runner-API copy of the original: the pipeline the user sees.
  p_copy = beam.Pipeline.from_runner_api(
      p_original.to_runner_api(),
      runner=interactive_runner.InteractiveRunner(),
      options=options)
  ie.current_env().add_derived_pipeline(p_original, p_copy)
  instrumenter = instr.build_pipeline_instrument(p_copy)
  actual_pipeline = beam.Pipeline.from_runner_api(
      proto=instrumenter.instrumented_pipeline_proto(),
      runner=interactive_runner.InteractiveRunner(),
      options=options)

  # Expected pipeline: both sources are replaced by a single TestStream.
  source_1_cache_key = self.cache_key_of('source_1', source_1)
  source_2_cache_key = self.cache_key_of('source_2', source_2)
  p_expected = beam.Pipeline()
  ie.current_env().set_cache_manager(streaming_cache_manager, p_expected)
  expected_stream = p_expected | TestStream(
      output_tags=[source_1_cache_key, source_2_cache_key])
  _ = ((expected_stream[source_1_cache_key],
        expected_stream[source_2_cache_key])
       | beam.Flatten()
       | 'square1' >> beam.Map(lambda x: x * x)
       | 'reify' >> beam.Map(lambda _: _)
       | cache.WriteCache(
           ie.current_env().get_cache_manager(p_expected), 'unused'))

  class TestStreamVisitor(PipelineVisitor):
    """Records output_tags of the last TestStream visited."""

    def __init__(self):
      self.output_tags = set()

    def enter_composite_transform(self, transform_node):
      # Composites are inspected exactly like leaf transforms.
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      transform = transform_node.transform
      if isinstance(transform, TestStream):
        self.output_tags = transform.output_tags

  # The TestStream must output to both source PCollections.
  visitor = TestStreamVisitor()
  actual_pipeline.visit(visitor)
  self.assertSetEqual({source_1_cache_key, source_2_cache_key},
                      visitor.output_tags)

  # And the whole instrumented pipeline must match the expected one.
  assert_pipeline_proto_equal(self, p_expected.to_runner_api(),
                              instrumenter.instrumented_pipeline_proto())
def test_able_to_cache_intermediate_unbounded_source_pcollection(self):
  """Tests being able to cache an intermediate source PCollection.

  In the following pipeline, the source doesn't have a reference and so is
  not automatically cached in the watch() command. This tests that this case
  is taken care of.
  """
  # Create the pipeline that will be instrumented.
  from apache_beam.options.pipeline_options import StandardOptions
  options = StandardOptions(streaming=True)
  streaming_cache_manager = StreamingCache(cache_dir=None)
  p_original_cache_source = beam.Pipeline(
      interactive_runner.InteractiveRunner(), options)
  ie.current_env().set_cache_manager(streaming_cache_manager,
                                     p_original_cache_source)

  # pylint: disable=possibly-unused-variable
  source_1 = (
      p_original_cache_source
      | 'source1' >> beam.io.ReadFromPubSub(
          subscription='projects/fake-project/subscriptions/fake_sub')
      | beam.Map(lambda e: e))

  # Watch but do not cache the PCollections.
  ib.watch(locals())

  # Make sure that sources without a user reference are still cached.
  utils.watch_sources(p_original_cache_source)

  # Locate the PCollection that watch_sources registered under a synthetic
  # variable name. Stop at the first match across all watched scopes; a
  # `break` in the inner loop alone would let later scopes keep scanning and
  # overwrite it. Fail clearly if nothing was registered, rather than
  # computing a cache key for None below.
  intermediate_source_pcoll = None
  for watching in ie.current_env().watching():
    for var, watchable in list(watching):
      if 'synthetic' in var:
        intermediate_source_pcoll = watchable
        break
    if intermediate_source_pcoll is not None:
      break
  self.assertIsNotNone(
      intermediate_source_pcoll,
      'utils.watch_sources() registered no synthetic source variable')

  # Instrument a runner-API copy of the original: the pipeline the user sees.
  p_copy = beam.Pipeline.from_runner_api(
      p_original_cache_source.to_runner_api(),
      runner=interactive_runner.InteractiveRunner(),
      options=options)
  ie.current_env().add_derived_pipeline(p_original_cache_source, p_copy)
  instrumenter = instr.build_pipeline_instrument(p_copy)
  actual_pipeline = beam.Pipeline.from_runner_api(
      proto=instrumenter.instrumented_pipeline_proto(),
      runner=interactive_runner.InteractiveRunner(),
      options=options)
  ie.current_env().add_derived_pipeline(p_original_cache_source,
                                        actual_pipeline)

  # Now, build the expected pipeline, which replaces the unbounded source
  # with a TestStream.
  intermediate_source_pcoll_cache_key = self.cache_key_of(
      'synthetic_var_' + str(id(intermediate_source_pcoll)),
      intermediate_source_pcoll)
  p_expected = beam.Pipeline()
  ie.current_env().set_cache_manager(streaming_cache_manager, p_expected)
  test_stream = (
      p_expected
      | TestStream(output_tags=[intermediate_source_pcoll_cache_key]))
  # pylint: disable=expression-not-assigned
  (test_stream[intermediate_source_pcoll_cache_key]
   | 'square1' >> beam.Map(lambda e: e)
   | 'reify' >> beam.Map(lambda _: _)
   | cache.WriteCache(
       ie.current_env().get_cache_manager(p_expected), 'unused'))

  # Test that the TestStream is outputting to the correct PCollection.
  class TestStreamVisitor(PipelineVisitor):
    """Records output_tags of the last TestStream visited."""

    def __init__(self):
      self.output_tags = set()

    def enter_composite_transform(self, transform_node):
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      transform = transform_node.transform
      if isinstance(transform, TestStream):
        self.output_tags = transform.output_tags

  v = TestStreamVisitor()
  actual_pipeline.visit(v)
  expected_output_tags = set([intermediate_source_pcoll_cache_key])
  actual_output_tags = v.output_tags
  self.assertSetEqual(expected_output_tags, actual_output_tags)

  # Test that the pipeline is as expected.
  assert_pipeline_proto_equal(self, p_expected.to_runner_api(),
                              instrumenter.instrumented_pipeline_proto())
def test_load_saved_pcoder(self):
  pipeline = beam.Pipeline()
  ints = pipeline | beam.Create([1, 2, 3])
  # Only apply the cache write; the pipeline is never run.
  _ = ints | cache.WriteCache(self.cache_manager, 'a key')

  # The cache manager should already report the int coder for this key.
  saved_coder = self.cache_manager.load_pcoder('full', 'a key')
  int_coder = coders.registry.get_coder(int)
  self.assertIs(type(saved_coder), type(int_coder))
def _write_cache(self,
                 pipeline,
                 pcoll,
                 output_as_extended_target=True,
                 ignore_unbounded_reads=False,
                 is_capture=False):
  """Caches a cacheable PCollection.

  Appends to ``pcoll``'s pipeline a sub-graph that materializes the
  PCollection into the cache through a sink. The write is not immediate: it
  happens only when the runner runs the transformed pipeline, so the cache is
  not usable within this run.

  Args:
    pipeline: the pipeline being instrumented. PCollections that belong to
      any other pipeline are ignored.
    pcoll: the PCollection to cache.
    output_as_extended_target: if True, the output of the cache-writing
      transform is added to ``self._extended_targets``.
    ignore_unbounded_reads: if True and ``pcoll.producer`` or any of its
      ``parent`` transforms is an instance of one of
      ``ie.current_env().options.capturable_sources``, ``pcoll`` is added to
      ``self._ignored_targets`` and no cache write is appended.
    is_capture: forwarded to ``cache.WriteCache``.

  Modifies:
    pipeline
  """
  if pcoll.pipeline is not pipeline:
    return

  # Reads from capturable sources are skipped here; the PipelineFragment
  # prunes them later on.
  if ignore_unbounded_reads:

    def _producer_chain(node):
      # Walks from the producing transform up through its parents.
      while node:
        yield node
        node = node.parent

    from_capturable_source = any(
        isinstance(node.transform,
                   tuple(ie.current_env().options.capturable_sources))
        for node in _producer_chain(pcoll.producer))
    if from_capturable_source:
      self._ignored_targets.add(pcoll)
      return

  # The keyed cache is always valid within this instrumentation, so there is
  # nothing to do when a full cache already exists under the key.
  key = self.cache_key(pcoll)
  if self._cache_manager.exists('full', key):
    return

  label = '{}{}'.format(WRITE_CACHE, key)

  # Cache each element together with its window, pane and timestamp. The
  # values are wrapped in a holder rather than emitted as a WindowedValue,
  # because Python detects a WindowedValue returned from a DoFn and, when it
  # does, places the element into that window before emitting it downstream.
  # (Class and parameter names are kept stable: the DoFn is serialized.)
  class Reify(beam.DoFn):
    def process(self,
                e,
                w=beam.DoFn.WindowParam,
                p=beam.DoFn.PaneInfoParam,
                t=beam.DoFn.TimestampParam):
      yield test_stream.WindowedValueHolder(WindowedValue(e, t, [w], p))

  extended_target = (
      pcoll
      | label + 'reify' >> beam.ParDo(Reify())
      | label >> cache.WriteCache(
          self._cache_manager, key, is_capture=is_capture))
  if output_as_extended_target:
    self._extended_targets.add(extended_target)
def run_pipeline(self, pipeline):
  """Runs the slice of ``pipeline`` needed to compute the desired PCollections.

  Steps:
  1. Round-trip the pipeline through the runner API.
  2. Keep only the transforms that produce the desired PCollections. A
     non-leaf PCollection whose full cache already exists is replaced by a
     ``cache.ReadCache``.
  3. Attach ``cache.WriteCache`` transforms: full caches for desired
     PCollections and sampled caches for every referenced one.
  4. Run the slice on ``self._underlying_runner`` while a ``DisplayManager``
     updates periodically.

  Args:
    pipeline: the user's ``beam.Pipeline``.

  Returns:
    A ``PipelineResult`` wrapping the underlying runner's result.
  """
  if not hasattr(self, '_desired_cache_labels'):
    self._desired_cache_labels = set()

  # Invoke a round trip through the runner API. This makes sure the Pipeline
  # proto is stable.
  pipeline = beam.pipeline.Pipeline.from_runner_api(
      pipeline.to_runner_api(),
      pipeline.runner,
      pipeline._options)  # pylint: disable=protected-access

  # Snapshot the pipeline in a portable proto before mutating it.
  pipeline_proto, original_context = pipeline.to_runner_api(
      return_context=True)
  pcolls_to_pcoll_id = self._pcolls_to_pcoll_id(pipeline, original_context)

  # TODO(qinyeli): Refactor the rest of this function into
  # def manipulate_pipeline(pipeline_proto) -> pipeline_proto_to_run:

  # Make a copy of the original pipeline to avoid accidental manipulation
  pipeline, context = beam.pipeline.Pipeline.from_runner_api(
      pipeline_proto,
      self._underlying_runner,
      pipeline._options,  # pylint: disable=protected-access
      return_context=True)
  pipeline_info = PipelineInfo(pipeline_proto.components)

  # Ids of non-leaf PCollections that are read back from an existing cache.
  caches_used = set()

  def _producing_transforms(pcoll_id, leaf=False):
    """Yields (transform_id, PTransform proto) pairs producing the PColl.

    Each PCollection is analyzed at most once, tracked in the function
    attribute ``analyzed_pcoll_ids``. A non-leaf PCollection with an existing
    full cache is served by a ReadCache composite (and its sub-transforms).
    A cached leaf yields nothing. An uncached PCollection yields the
    producers of its inputs first, then its own producer.
    """
    if pcoll_id in _producing_transforms.analyzed_pcoll_ids:
      return
    _producing_transforms.analyzed_pcoll_ids.add(pcoll_id)

    derivation = pipeline_info.derivation(pcoll_id)
    if self._cache_manager.exists('full', derivation.cache_label()):
      # If the PCollection is cached, yield ReadCache PTransform that reads
      # the PCollection and all its sub PTransforms.
      if not leaf:
        caches_used.add(pcoll_id)

        cache_label = pipeline_info.derivation(pcoll_id).cache_label()
        dummy_pcoll = pipeline | 'Load%s' % cache_label >> cache.ReadCache(
            self._cache_manager, cache_label)

        # Find the top level ReadCache composite PTransform.
        read_cache = dummy_pcoll.producer
        while read_cache.parent.parent:
          read_cache = read_cache.parent

        def _include_subtransforms(transform):
          """Depth-first yield the PTransform itself and its sub PTransforms.
          """
          yield transform
          for subtransform in transform.parts:
            for yielded in _include_subtransforms(subtransform):
              yield yielded

        for transform in _include_subtransforms(read_cache):
          transform_proto = transform.to_runner_api(context)
          # Point the ReadCache output at the cached PCollection's id, so
          # consumers of that id read from the cache instead.
          if dummy_pcoll in transform.outputs.values():
            transform_proto.outputs['None'] = pcoll_id
          yield context.transforms.get_id(transform), transform_proto
    else:
      transform_id, _ = pipeline_info.producer(pcoll_id)
      transform_proto = pipeline_proto.components.transforms[transform_id]
      # Producers of every input come before the transform itself.
      for input_id in transform_proto.inputs.values():
        for transform in _producing_transforms(input_id):
          yield transform
      yield transform_id, transform_proto

  desired_pcollections = self._desired_pcollections(pipeline_info)

  # TODO(qinyeli): Preserve composite structure.
  required_transforms = collections.OrderedDict()
  _producing_transforms.analyzed_pcoll_ids = set()
  for pcoll_id in desired_pcollections:
    # TODO(qinyeli): Collections consumed by no-output transforms.
    required_transforms.update(_producing_transforms(pcoll_id, True))

  referenced_pcollections = self._referenced_pcollections(
      pipeline_proto, required_transforms)

  required_transforms['_root'] = beam_runner_api_pb2.PTransform(
      subtransforms=required_transforms.keys())

  pipeline_to_execute = copy.deepcopy(pipeline_proto)
  pipeline_to_execute.root_transform_ids[:] = ['_root']
  set_proto_map(pipeline_to_execute.components.transforms,
                required_transforms)
  set_proto_map(pipeline_to_execute.components.pcollections,
                referenced_pcollections)
  set_proto_map(pipeline_to_execute.components.coders,
                context.to_runner_api().coders)

  pipeline_slice, context = beam.pipeline.Pipeline.from_runner_api(
      pipeline_to_execute,
      self._underlying_runner,
      pipeline._options,  # pylint: disable=protected-access
      return_context=True)

  # TODO(qinyeli): cache only top-level pcollections.
  for pcoll_id in pipeline_info.all_pcollections():
    if pcoll_id not in referenced_pcollections:
      continue
    cache_label = pipeline_info.derivation(pcoll_id).cache_label()
    pcoll = context.pcollections.get_by_id(pcoll_id)
    if pcoll_id in desired_pcollections:
      # pylint: disable=expression-not-assigned
      pcoll | 'CacheFull%s' % cache_label >> cache.WriteCache(
          self._cache_manager, cache_label)
    # Every PCollection reaching this point is referenced, so it always gets
    # a sampled cache.
    # pylint: disable=expression-not-assigned
    pcoll | 'CacheSample%s' % cache_label >> cache.WriteCache(
        self._cache_manager, cache_label, sample=True,
        sample_size=SAMPLE_SIZE)

  display = display_manager.DisplayManager(
      pipeline_info=pipeline_info,
      pipeline_proto=pipeline_proto,
      caches_used=caches_used,
      cache_manager=self._cache_manager,
      referenced_pcollections=referenced_pcollections,
      required_transforms=required_transforms)
  display.start_periodic_update()
  try:
    result = pipeline_slice.run()
    result.wait_until_finish()
  finally:
    # Stop the periodic display updates even when the run fails, so the
    # updater does not keep running after an exception.
    display.stop_periodic_update()

  return PipelineResult(result, self, pipeline_info, self._cache_manager,
                        pcolls_to_pcoll_id)