def pcolls_from_streaming_cache(
    user_pipeline: beam.Pipeline,
    query_pipeline: beam.Pipeline,
    name_to_pcoll: Dict[str, beam.PCollection],
    instrumentation: inst.PipelineInstrument,
    cache_manager: StreamingCache) -> Dict[str, beam.PCollection]:
  """Reads the PCollection cache back through a TestStream.

  When the user_pipeline contains unbounded sources, every cache read is
  forced through the TestStream, including reads of bounded sources.

  Args:
    user_pipeline: The beam.Pipeline the user defined in the notebook.
    query_pipeline: The beam.Pipeline the magic built to run the SQL query.
    name_to_pcoll: PCollections keyed by the variable names used in the query.
    instrumentation: A pipeline_instrument.PipelineInstrument used to compute
      the cache key of each PCollection.
    cache_manager: The streaming cache manager holding the PCollection cache.

  Returns:
    A Dict[str, beam.PCollection] mapping each PCollection variable name to
    the PCollection read back from the cache.
  """

  def _log_and_continue(e):
    # Log cache-read errors but keep the streaming service alive.
    _LOGGER.error(str(e))
    return True

  # Reuse the per-user-pipeline TestStream service if one is already running;
  # otherwise start one and register it with the interactive environment.
  controller = ie.current_env().get_test_stream_service_controller(
      user_pipeline)
  if not controller:
    controller = TestStreamServiceController(
        cache_manager, exception_handler=_log_and_continue)
    controller.start()
    ie.current_env().set_test_stream_service_controller(
        user_pipeline, controller)

  # Cache keys double as TestStream output tags, so remember which variable
  # name each key corresponds to.
  key_to_name = {
      instrumentation.cache_key(pcoll): name
      for name, pcoll in name_to_pcoll.items()
  }

  tagged_outputs = query_pipeline | test_stream.TestStream(
      output_tags=set(key_to_name),
      coder=cache_manager._default_pcoder,
      endpoint=controller.endpoint)

  sql_source = {}
  for tag, pcoll in tagged_outputs.items():
    name = key_to_name[tag]
    # Must mark the element_type to avoid introducing pickled Python coder
    # to the Java expansion service.
    pcoll.element_type = name_to_pcoll[name].element_type
    sql_source[name] = pcoll
  return sql_source
def _replace_with_cached_inputs(self, pipeline):
  """Replace PCollection inputs in the pipeline with cache if possible.

  For every input PCollection that has a valid cache entry, rewires its
  consuming AppliedPTransform to read from the AppliedPTransform that
  sources the value from the cache. Inputs without a valid cache entry are
  left untouched.
  """
  # When the pipeline has unbounded sources, all cache reads are forced
  # through the TestStream (even reads of bounded sources). First collect
  # the cache keys of every cached PCollection reachable in the pipeline.
  if self.has_unbounded_sources:

    class _CachedKeyCollector(PipelineVisitor):
      """Gathers cache keys for PCollections that have cached reads."""
      def __init__(self, pin):
        self._pin = pin
        self.unbounded_pcolls = set()

      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        # Consider both the outputs and the inputs of each transform.
        candidates = list((transform_node.outputs or {}).values())
        candidates.extend(transform_node.inputs or ())
        for pcoll in candidates:
          key = self._pin.cache_key(pcoll)
          if key in self._pin._cached_pcoll_read:
            self.unbounded_pcolls.add(key)

    collector = _CachedKeyCollector(self)
    pipeline.visit(collector)

    # The collected cache keys become the TestStream output tags, which is
    # how each read PCollection is matched back to its cache entry. Insert
    # the TestStream reads into the cached-PCollection dictionary so the
    # wiring pass below swaps them in for the original inputs.
    if collector.unbounded_pcolls:
      reads = pipeline | test_stream.TestStream(
          output_tags=collector.unbounded_pcolls,
          coder=self._cache_manager._default_pcoder)
      for key, read_pcoll in reads.items():
        self._cached_pcoll_read[key] = read_pcoll

  class _CacheWireVisitor(PipelineVisitor):
    """Visitor that swaps original inputs for their cached reads."""
    def __init__(self, pin):
      """Initializes with a PipelineInstrument."""
      self._pin = pin

    def enter_composite_transform(self, transform_node):
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      if not transform_node.inputs:
        return
      rewired = dict(transform_node.main_inputs)
      for tag, original_input in rewired.items():
        key = self._pin.cache_key(original_input)
        # Swap in the cached PCollection when one exists.
        if key in self._pin._cached_pcoll_read:
          # Drop the original pcoll from the final pruned instrumented
          # pipeline.
          self._pin._ignored_targets.add(original_input)
          rewired[tag] = self._pin._cached_pcoll_read[key]
      # Update the transform with its (possibly unchanged) inputs.
      transform_node.main_inputs = rewired

  pipeline.visit(_CacheWireVisitor(self))