def pcoll_from_file_cache(
    query_pipeline: beam.Pipeline,
    pcoll: beam.PCollection,
    cache_manager: FileBasedCacheManager,
    key: str) -> beam.PCollection:
  """Reads PCollection cache from files.

  Args:
    query_pipeline: The beam.Pipeline object built by the magic to execute
      the SQL query.
    pcoll: The PCollection to read cache for.
    cache_manager: The file based cache manager that holds the PCollection
      cache.
    key: The key of the PCollection cache.

  Returns:
    A PCollection read from the cache.
  """
  schema = pcoll.element_type

  class Unreify(beam.DoFn):
    def process(self, e):
      if isinstance(e, beam.Row) and hasattr(e, 'windowed_value'):
        yield e.windowed_value

  return (
      query_pipeline
      | '{}{}'.format('QuerySource', key) >> cache.ReadCache(
          cache_manager, key)
      | '{}{}'.format('Unreify', key) >> beam.ParDo(
          Unreify()).with_output_types(schema))

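# A small, runnable sketch of what schema = pcoll.element_type captures for a
# schema-aware PCollection built from beam.Row elements; the exact printed
# representation may vary by Beam version.
import apache_beam as beam

p = beam.Pipeline()
rows = p | beam.Create([beam.Row(x=1), beam.Row(x=2)])
print(rows.element_type)  # e.g. a row type constraint like Row(x=int)
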
def _read_cache(self, pipeline, pcoll):
  """Reads a cached pvalue.

  A noop will cause the pipeline to execute the transform as it is and cache
  nothing from this transform for the next run.

  Modifies:
    pipeline
  """
  # Makes sure the pcoll belongs to the pipeline being instrumented.
  if pcoll.pipeline is not pipeline:
    return
  # The keyed cache is always valid within this instrumentation.
  key = self.cache_key(pcoll)
  # Can only read from cache when the cache with the expected key exists and
  # its computation has been completed.
  if (self._cache_manager.exists('full', key) and
      self._runner_pcoll_to_user_pcoll[pcoll] in
      ie.current_env().computed_pcollections):
    if key not in self._cached_pcoll_read:
      # Mutates the pipeline with a cache read transform attached to the
      # root of the pipeline.
      pcoll_from_cache = (
          pipeline
          | '{}{}'.format(READ_CACHE, key) >> cache.ReadCache(
              self._cache_manager, key))
      self._cached_pcoll_read[key] = pcoll_from_cache

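# A hedged, end-to-end sketch of the cache round trip this guard relies on.
# Assumptions not confirmed by the snippet: ReadCache, WriteCache, and
# FileBasedCacheManager live in apache_beam.runners.interactive.cache_manager
# and behave as in the versions shown here; details differ across Beam
# versions.
import tempfile

import apache_beam as beam
from apache_beam.runners.interactive import cache_manager as cache

cm = cache.FileBasedCacheManager(tempfile.mkdtemp())
key = 'demo_key'  # hypothetical cache key

# First run: populate the cache.
with beam.Pipeline() as p:
  _ = p | beam.Create([1, 2, 3]) | 'Cache' >> cache.WriteCache(cm, key)

# Second run: the same existence guard as above, then read the cache back.
if cm.exists('full', key):
  with beam.Pipeline() as p:
    _ = p | 'Read' >> cache.ReadCache(cm, key) | beam.Map(print)
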
def unreify_from_cache(
    pipeline: beam.Pipeline,
    cache_key: str,
    cache_manager: cache.CacheManager,
    element_type: Optional[type] = None,
    source_label: Optional[str] = None,
    unreify_label: Optional[str] = None) -> beam.pvalue.PCollection:
  """Reads from cache and unreifies elements from windowed values.

  Args:
    pipeline: The pipeline that's reading from the cache.
    cache_key: The key of the cache.
    cache_manager: The cache manager to manage the cache.
    element_type: (optional) The element type of the PCollection's elements.
    source_label: (optional) A transform label for the cache-reading
      transform.
    unreify_label: (optional) A transform label for the Unreify transform.
  """
  if not source_label:
    source_label = '{}{}'.format(READ_CACHE, cache_key)
  if not unreify_label:
    unreify_label = '{}{}{}'.format('UnreifyAfter_', READ_CACHE, cache_key)
  read_cache = pipeline | source_label >> cache.ReadCache(
      cache_manager, cache_key)
  if element_type:
    # If the PCollection is schema-aware, explicitly sets the output types.
    return read_cache | unreify_label >> beam.ParDo(
        Unreify()).with_output_types(element_type)
  return read_cache | unreify_label >> beam.ParDo(Unreify())

def expand(self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
  class Unreify(beam.DoFn):
    def process(self, e):
      yield e.windowed_value

  return (
      pcoll.pipeline
      | 'read' + self._label >> cache.ReadCache(
          self._cache_manager, self._key)
      | 'unreify' + self._label >> beam.ParDo(Unreify()))

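# A minimal reconstruction of the enclosing composite, under the assumption
# (not shown in the snippet) that __init__ stores the cache manager, cache
# key, and a label suffix; the class name here is hypothetical.
class ReadCacheAndUnreify(beam.PTransform):
  def __init__(self, cache_manager, key, label):
    self._cache_manager = cache_manager
    self._key = key
    self._label = label

  # expand(self, pcoll) is the method shown above; note that it ignores the
  # input PCollection's elements and reads from the pipeline root instead.
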
def test_read_cache_expansion(self):
  p = beam.Pipeline(runner=self.runner)

  # The cold run.
  pcoll = (p
           | 'Create' >> beam.Create([1, 2, 3])
           | 'Double' >> beam.Map(lambda x: x * 2)
           | 'Square' >> beam.Map(lambda x: x**2))
  pipeline_proto = to_stable_runner_api(p)

  pipeline_info = pipeline_analyzer.PipelineInfo(pipeline_proto.components)
  pcoll_id = 'ref_PCollection_PCollection_3'  # Output PCollection of Square
  cache_label1 = pipeline_info.cache_label(pcoll_id)

  analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
                                                pipeline_proto,
                                                self.runner)
  pipeline_to_execute = beam.pipeline.Pipeline.from_runner_api(
      analyzer.pipeline_proto_to_execute(), p.runner, p._options)
  pipeline_to_execute.run().wait_until_finish()

  # The second run.
  _ = (pcoll
       | 'Triple' >> beam.Map(lambda x: x * 3)
       | 'Cube' >> beam.Map(lambda x: x**3))
  analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
                                                to_stable_runner_api(p),
                                                self.runner)

  expected_pipeline = beam.Pipeline(runner=self.runner)
  pcoll1 = (expected_pipeline
            | 'Load%s' % cache_label1 >> cache.ReadCache(
                self.cache_manager, cache_label1))
  pcoll2 = pcoll1 | 'Triple' >> beam.Map(lambda x: x * 3)
  pcoll3 = pcoll2 | 'Cube' >> beam.Map(lambda x: x**3)

  cache_label2 = 'PColl-7654321'
  cache_label3 = 'PColl-3141593'

  # pylint: disable=expression-not-assigned
  pcoll2 | 'CacheSample%s' % cache_label2 >> cache.WriteCache(
      self.cache_manager, cache_label2, sample=True, sample_size=10)
  pcoll3 | 'CacheSample%s' % cache_label3 >> cache.WriteCache(
      self.cache_manager, cache_label3, sample=True, sample_size=10)
  pcoll3 | 'CacheFull%s' % cache_label3 >> cache.WriteCache(
      self.cache_manager, cache_label3)

  # Since ReadCache & WriteCache expansion leads to more than 50 PTransform
  # protos in the pipeline, a simple check of proto map size is enough.
  self.assertPipelineEqual(analyzer.pipeline_proto_to_execute(),
                           to_stable_runner_api(expected_pipeline))

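# Where ids like 'ref_PCollection_PCollection_3' come from: a runnable peek
# at the runner API proto, whose components map PCollections by such ids
# (exact numbering depends on application order, so it may differ here).
import apache_beam as beam

p = beam.Pipeline()
_ = (p
     | 'Create' >> beam.Create([1, 2, 3])
     | 'Double' >> beam.Map(lambda x: x * 2)
     | 'Square' >> beam.Map(lambda x: x**2))
for pcoll_id in p.to_runner_api().components.pcollections:
  print(pcoll_id)  # e.g. ref_PCollection_PCollection_1, ...
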
def _producing_transforms(pcoll_id, leaf=False):
  """Returns PTransforms (and their names) that produce the given PColl."""
  if pcoll_id in _producing_transforms.analyzed_pcoll_ids:
    return
  else:
    _producing_transforms.analyzed_pcoll_ids.add(pcoll_id)

  derivation = pipeline_info.derivation(pcoll_id)
  if self._cache_manager.exists('full', derivation.cache_label()):
    # If the PCollection is cached, yield the ReadCache PTransform that
    # reads the PCollection, along with all its sub PTransforms.
    if not leaf:
      caches_used.add(pcoll_id)

      cache_label = pipeline_info.derivation(pcoll_id).cache_label()
      dummy_pcoll = pipeline | 'Load%s' % cache_label >> cache.ReadCache(
          self._cache_manager, cache_label)

      # Find the top level ReadCache composite PTransform.
      read_cache = dummy_pcoll.producer
      while read_cache.parent.parent:
        read_cache = read_cache.parent

      def _include_subtransforms(transform):
        """Depth-first yields the PTransform itself and its sub PTransforms."""
        yield transform
        for subtransform in transform.parts:
          for yielded in _include_subtransforms(subtransform):
            yield yielded

      for transform in _include_subtransforms(read_cache):
        transform_proto = transform.to_runner_api(context)
        if dummy_pcoll in transform.outputs.values():
          transform_proto.outputs['None'] = pcoll_id
        yield context.transforms.get_id(transform), transform_proto
  else:
    transform_id, _ = pipeline_info.producer(pcoll_id)
    transform_proto = pipeline_proto.components.transforms[transform_id]
    for input_id in transform_proto.inputs.values():
      for transform in _producing_transforms(input_id):
        yield transform
    yield transform_id, transform_proto

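# A hedged, runnable sketch of the same depth-first expansion, using the
# internal AppliedPTransform attributes (producer, parent, parts) that the
# snippet above also relies on; internal attributes may change across Beam
# versions.
import apache_beam as beam

def include_subtransforms(transform):
  yield transform
  for subtransform in transform.parts:
    for yielded in include_subtransforms(subtransform):
      yield yielded

p = beam.Pipeline()
pcoll = p | 'Load' >> beam.Create([1, 2, 3])

# Climb from the leaf producer to the top-level composite, as above.
top = pcoll.producer
while top.parent.parent:
  top = top.parent
for transform in include_subtransforms(top):
  print(transform.full_label)
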
def test_instrument_example_pipeline_to_read_cache(self):
  p_origin, init_pcoll, second_pcoll = self._example_pipeline()
  p_copy, _, _ = self._example_pipeline(False)

  # Mock as if cacheable PCollections are cached.
  init_pcoll_cache_key = self.cache_key_of('init_pcoll', init_pcoll)
  self._mock_write_cache(p_origin, [b'1', b'2', b'3'], init_pcoll_cache_key)
  second_pcoll_cache_key = self.cache_key_of('second_pcoll', second_pcoll)
  self._mock_write_cache(
      p_origin, [b'1', b'4', b'9'], second_pcoll_cache_key)
  # Mark the completeness of PCollections from the original (user) pipeline.
  ie.current_env().mark_pcollection_computed((init_pcoll, second_pcoll))
  ie.current_env().add_derived_pipeline(p_origin, p_copy)
  instr.build_pipeline_instrument(p_copy)

  cached_init_pcoll = (
      p_origin
      | '_ReadCache_' + init_pcoll_cache_key >> cache.ReadCache(
          ie.current_env().get_cache_manager(p_origin), init_pcoll_cache_key)
      | 'unreify' >> beam.Map(lambda _: _))

  # second_pcoll is never used as input and there is no need to read cache.

  class TestReadCacheWireVisitor(PipelineVisitor):
    """Replaces init_pcoll with cached_init_pcoll for all occurring inputs."""
    def enter_composite_transform(self, transform_node):
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      if transform_node.inputs:
        main_inputs = dict(transform_node.main_inputs)
        for tag, main_input in main_inputs.items():
          if main_input == init_pcoll:
            main_inputs[tag] = cached_init_pcoll
        transform_node.main_inputs = main_inputs

  v = TestReadCacheWireVisitor()
  p_origin.visit(v)
  assert_pipeline_equal(self, p_origin, p_copy)

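# A runnable sketch of the PipelineVisitor hooks the test relies on, assuming
# the apache_beam.pipeline.PipelineVisitor interface; it only prints the
# transform tree instead of rewiring inputs.
import apache_beam as beam
from apache_beam.pipeline import PipelineVisitor

class PrintVisitor(PipelineVisitor):
  def enter_composite_transform(self, transform_node):
    print('composite:', transform_node.full_label)

  def visit_transform(self, transform_node):
    print('leaf:', transform_node.full_label)

p = beam.Pipeline()
_ = p | beam.Create([1, 2, 3]) | 'AddOne' >> beam.Map(lambda x: x + 1)
p.visit(PrintVisitor())
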
def test_instrument_example_pipeline_to_read_cache(self):
  p_origin, init_pcoll, second_pcoll = self._example_pipeline()
  p_copy, _, _ = self._example_pipeline(False)

  # Mock as if cacheable PCollections are cached.
  init_pcoll_cache_key = 'init_pcoll_' + str(id(init_pcoll)) + '_' + str(
      id(init_pcoll.producer))
  self._mock_write_cache([b'1', b'2', b'3'], init_pcoll_cache_key)
  second_pcoll_cache_key = 'second_pcoll_' + str(
      id(second_pcoll)) + '_' + str(id(second_pcoll.producer))
  self._mock_write_cache([b'1', b'4', b'9'], second_pcoll_cache_key)
  # Mark the completeness of PCollections from the original (user) pipeline.
  ie.current_env().mark_pcollection_computed(
      (p_origin, init_pcoll, second_pcoll))
  instr.build_pipeline_instrument(p_copy)

  cached_init_pcoll = (
      p_origin
      | '_ReadCache_' + init_pcoll_cache_key >> cache.ReadCache(
          ie.current_env().cache_manager(), init_pcoll_cache_key)
      | 'unreify' >> beam.Map(lambda _: _))

  # second_pcoll is never used as input and there is no need to read cache.

  class TestReadCacheWireVisitor(PipelineVisitor):
    """Replaces init_pcoll with cached_init_pcoll for all occurring inputs."""
    def enter_composite_transform(self, transform_node):
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      if transform_node.inputs:
        input_list = list(transform_node.inputs)
        for i in range(len(input_list)):
          if input_list[i] == init_pcoll:
            input_list[i] = cached_init_pcoll
        transform_node.inputs = tuple(input_list)

  v = TestReadCacheWireVisitor()
  p_origin.visit(v)
  assert_pipeline_equal(self, p_origin, p_copy)

def _read_cache(self, pipeline, pcoll, is_unbounded_source_output):
  """Reads a cached pvalue.

  A noop will cause the pipeline to execute the transform as it is and cache
  nothing from this transform for the next run.

  Modifies:
    pipeline
  """
  # Makes sure the pcoll belongs to the pipeline being instrumented.
  if pcoll.pipeline is not pipeline:
    return
  # The keyed cache is always valid within this instrumentation.
  key = self.cache_key(pcoll)
  # Can only read from cache when the cache with the expected key exists and
  # its computation has been completed.
  is_cached = self._cache_manager.exists('full', key)
  is_computed = (
      pcoll in self._runner_pcoll_to_user_pcoll and
      self._runner_pcoll_to_user_pcoll[pcoll] in
      ie.current_env().computed_pcollections)
  if (is_cached and is_computed) or is_unbounded_source_output:
    if key not in self._cached_pcoll_read:
      # Mutates the pipeline with a cache read transform attached to the
      # root of the pipeline.
      # To put the cached value into the correct window, simply return a
      # WindowedValue constructed from the element.
      class Unreify(beam.DoFn):
        def process(self, e):
          yield e.windowed_value

      pcoll_from_cache = (
          pipeline
          | '{}{}'.format(READ_CACHE, key) >> cache.ReadCache(
              self._cache_manager, key)
          | '{}{}unreify'.format(READ_CACHE, key) >> beam.ParDo(Unreify()))
      self._cached_pcoll_read[key] = pcoll_from_cache

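# A self-contained demonstration of why yielding e.windowed_value restores
# windowing: the Python SDK passes a WindowedValue yielded from process()
# through as-is. FakeReified is a hypothetical stand-in for a cached element
# carrying its WindowedValue; only public Beam APIs are used.
import apache_beam as beam
from apache_beam.transforms.window import IntervalWindow
from apache_beam.utils.windowed_value import WindowedValue

class FakeReified(object):
  def __init__(self, windowed_value):
    self.windowed_value = windowed_value

class Unreify(beam.DoFn):
  def process(self, e):
    # Yielding a WindowedValue places the element back into its original
    # window and timestamp.
    yield e.windowed_value

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(
          [FakeReified(WindowedValue(1, 100, (IntervalWindow(60, 120),)))])
      | beam.ParDo(Unreify())
      | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: print(x, ts)))
  # Expected to print something like: 1 Timestamp(100)
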
def test_instrument_example_pipeline_to_read_cache(self):
  p_origin, init_pcoll, second_pcoll = self._example_pipeline()
  p_copy, _, _ = self._example_pipeline(False)

  # Mock as if cacheable PCollections are cached.
  init_pcoll_cache_key = 'init_pcoll_' + str(
      id(init_pcoll)) + '_ref_PCollection_PCollection_10_' + str(
          id(init_pcoll.producer))
  self._mock_write_cache(init_pcoll, init_pcoll_cache_key)
  second_pcoll_cache_key = 'second_pcoll_' + str(
      id(second_pcoll)) + '_ref_PCollection_PCollection_11_' + str(
          id(second_pcoll.producer))
  self._mock_write_cache(second_pcoll, second_pcoll_cache_key)
  ie.current_env().cache_manager().exists = MagicMock(return_value=True)
  instr.pin(p_copy)

  cached_init_pcoll = p_origin | (
      '_ReadCache_' + init_pcoll_cache_key) >> cache.ReadCache(
          ie.current_env().cache_manager(), init_pcoll_cache_key)

  # second_pcoll is never used as input and there is no need to read cache.

  class TestReadCacheWireVisitor(PipelineVisitor):
    """Replaces init_pcoll with cached_init_pcoll for all occurring inputs."""
    def enter_composite_transform(self, transform_node):
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      if transform_node.inputs:
        input_list = list(transform_node.inputs)
        for i in range(len(input_list)):
          if input_list[i] == init_pcoll:
            input_list[i] = cached_init_pcoll
        transform_node.inputs = tuple(input_list)

  v = TestReadCacheWireVisitor()
  p_origin.visit(v)
  self.assertPipelineEqual(p_origin, p_copy)

def _insert_producing_transforms(self,
                                 pcoll_id,
                                 required_transforms,
                                 top_level_required_transforms,
                                 leaf=False):
  """Inserts PTransforms producing the given PCollection into the dicts.

  Args:
    pcoll_id: (str)
    required_transforms: (Dict[str, PTransform proto])
    top_level_required_transforms: (Dict[str, PTransform proto])
    leaf: (bool) whether the PCollection should be read from cache if the
      cache exists.

  Modifies:
    required_transforms
    top_level_required_transforms
    self._read_cache_ids
  """
  if pcoll_id in self._analyzed_pcoll_ids:
    return
  else:
    self._analyzed_pcoll_ids.add(pcoll_id)

  cache_label = self._pipeline_info.cache_label(pcoll_id)
  if self._cache_manager.exists('full', cache_label) and not leaf:
    self._caches_used.add(pcoll_id)

    cache_label = self._pipeline_info.cache_label(pcoll_id)
    dummy_pcoll = (
        self._pipeline
        | 'Load%s' % cache_label >> cache.ReadCache(
            self._cache_manager, cache_label))

    read_cache = self._top_level_producer(dummy_pcoll)
    read_cache_id = self._context.transforms.get_id(read_cache)
    read_cache_proto = read_cache.to_runner_api(self._context)
    read_cache_proto.outputs['None'] = pcoll_id
    top_level_required_transforms[read_cache_id] = read_cache_proto
    self._read_cache_ids.add(read_cache_id)

    for transform in self._include_subtransforms(read_cache):
      transform_id = self._context.transforms.get_id(transform)
      transform_proto = transform.to_runner_api(self._context)
      if dummy_pcoll in transform.outputs.values():
        transform_proto.outputs['None'] = pcoll_id
      required_transforms[transform_id] = transform_proto
  else:
    pcoll = self._context.pcollections.get_by_id(pcoll_id)
    top_level_transform = self._top_level_producer(pcoll)
    for transform in self._include_subtransforms(top_level_transform):
      transform_id = self._context.transforms.get_id(transform)
      transform_proto = self._context.transforms.get_proto(transform)

      # Inserting ancestor PTransforms.
      for input_id in transform_proto.inputs.values():
        self._insert_producing_transforms(
            input_id, required_transforms, top_level_required_transforms)
      required_transforms[transform_id] = transform_proto

    # Must be inserted after inserting ancestor PTransforms.
    top_level_id = self._context.transforms.get_id(top_level_transform)
    top_level_proto = self._context.transforms.get_proto(top_level_transform)
    top_level_required_transforms[top_level_id] = top_level_proto

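# A hedged, runnable peek at the proto fields manipulated above: each
# PTransform proto carries tag -> PCollection-id maps for inputs and outputs,
# and 'None' is the tag the Python SDK uses for a single unnamed output (the
# exact ids printed may vary by Beam version).
import apache_beam as beam

p = beam.Pipeline()
_ = p | beam.Create([1]) | 'AddOne' >> beam.Map(lambda x: x + 1)
proto = p.to_runner_api()
for transform_id, transform in proto.components.transforms.items():
  if transform.unique_name == 'AddOne':
    print(dict(transform.inputs))   # e.g. {'None': 'ref_PCollection_PCollection_1'}
    print(dict(transform.outputs))  # e.g. {'None': 'ref_PCollection_PCollection_2'}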