Example 1
def pcolls_from_streaming_cache(
    user_pipeline: beam.Pipeline,
    query_pipeline: beam.Pipeline,
    name_to_pcoll: Dict[str, beam.PCollection],
    instrumentation: inst.PipelineInstrument,
    cache_manager: StreamingCache) -> Dict[str, beam.PCollection]:
  """Reads PCollection cache through the TestStream.

  Args:
    user_pipeline: The beam.Pipeline object defined by the user in the
        notebook.
    query_pipeline: The beam.Pipeline object built by the magic to execute the
        SQL query.
    name_to_pcoll: A dict mapping the variable names used in the SQL query to
        their PCollections.
    instrumentation: A pipeline_instrument.PipelineInstrument that helps
        calculate the cache key of a given PCollection.
    cache_manager: The streaming cache manager that holds the PCollection cache.

  Returns:
    A Dict[str, beam.PCollection] of PCollections read from the cache, each
    keyed by its variable name.

  When the user_pipeline has unbounded sources, all cache reads are forced to
  go through the TestStream, even for PCollections backed by bounded sources.
  """
  def exception_handler(e):
    _LOGGER.error(str(e))
    return True

  test_stream_service = ie.current_env().get_test_stream_service_controller(
      user_pipeline)
  if not test_stream_service:
    test_stream_service = TestStreamServiceController(
        cache_manager, exception_handler=exception_handler)
    test_stream_service.start()
    ie.current_env().set_test_stream_service_controller(
        user_pipeline, test_stream_service)

  tag_to_name = {}
  for name, pcoll in name_to_pcoll.items():
    key = instrumentation.cache_key(pcoll)
    tag_to_name[key] = name
  output_pcolls = query_pipeline | test_stream.TestStream(
      output_tags=set(tag_to_name.keys()),
      coder=cache_manager._default_pcoder,
      endpoint=test_stream_service.endpoint)
  sql_source = {}
  for tag, output in output_pcolls.items():
    name = tag_to_name[tag]
    # Must mark the element_type to avoid introducing pickled Python coder
    # to the Java expansion service.
    output.element_type = name_to_pcoll[name].element_type
    sql_source[name] = output
  return sql_source
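
The dict returned above maps SQL table names to cached PCollections whose element_type is already set to a schema-aware type, which is what lets the magic hand them to SqlTransform on the query_pipeline. Below is a minimal, hypothetical sketch of that downstream usage: SqlTransform is real Beam API, but the Order schema, the table name, and the query are invented for illustration, and actually running it requires a Java expansion service (e.g. via Docker).

import typing

import apache_beam as beam
from apache_beam.transforms.sql import SqlTransform


# Hypothetical schema type; in the magic, the element_type is copied from the
# user's cached PCollections instead of being declared here.
class Order(typing.NamedTuple):
  item: str
  amount: int


beam.coders.registry.register_coder(Order, beam.coders.RowCoder)

with beam.Pipeline() as p:
  orders = (
      p
      | beam.Create([Order('apple', 3), Order('pear', 5)]).with_output_types(Order))
  # Tag the PCollection with its SQL table name, mirroring the sql_source
  # dict built by pcolls_from_streaming_cache.
  result = {'orders': orders} | SqlTransform(
      'SELECT item, amount FROM orders WHERE amount > 3')
  _ = result | beam.Map(print)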
Example 2
    def _replace_with_cached_inputs(self, pipeline):
        """Replace PCollection inputs in the pipeline with cache if possible.

    For any input PCollection, find out whether there is valid cache. If so,
    replace the input of the AppliedPTransform with output of the
    AppliedPtransform that sources pvalue from the cache. If there is no valid
    cache, noop.
    """

        # Find all cached unbounded PCollections.

        # If the pipeline has unbounded sources, then we want to force all cache
        # reads to go through the TestStream (even if they are bounded sources).
        if self.has_unbounded_sources:

            class CacheableUnboundedPCollectionVisitor(PipelineVisitor):
                def __init__(self, pin):
                    self._pin = pin
                    self.unbounded_pcolls = set()

                def enter_composite_transform(self, transform_node):
                    self.visit_transform(transform_node)

                def visit_transform(self, transform_node):
                    if transform_node.outputs:
                        for output_pcoll in transform_node.outputs.values():
                            key = self._pin.cache_key(output_pcoll)
                            if key in self._pin._cached_pcoll_read:
                                self.unbounded_pcolls.add(key)

                    if transform_node.inputs:
                        for input_pcoll in transform_node.inputs:
                            key = self._pin.cache_key(input_pcoll)
                            if key in self._pin._cached_pcoll_read:
                                self.unbounded_pcolls.add(key)

            v = CacheableUnboundedPCollectionVisitor(self)
            pipeline.visit(v)

            # The set of keys from the cached unbounded PCollections will be used
            # as the output tags for the TestStream. This is to remember which
            # cache key is associated with which PCollection.
            output_tags = v.unbounded_pcolls

            # Take the PCollections that will be read from the TestStream and insert
            # them back into the dictionary of cached PCollections. The next step will
            # replace the downstream consumer of the non-cached PCollections with
            # these PCollections.
            if output_tags:
                output_pcolls = pipeline | test_stream.TestStream(
                    output_tags=output_tags,
                    coder=self._cache_manager._default_pcoder)
                for tag, pcoll in output_pcolls.items():
                    self._cached_pcoll_read[tag] = pcoll

        class ReadCacheWireVisitor(PipelineVisitor):
            """Visitor wires cache read as inputs to replace corresponding original
      input PCollections in pipeline.
      """
            def __init__(self, pin):
                """Initializes with a PipelineInstrument."""
                self._pin = pin

            def enter_composite_transform(self, transform_node):
                self.visit_transform(transform_node)

            def visit_transform(self, transform_node):
                if transform_node.inputs:
                    main_inputs = dict(transform_node.main_inputs)
                    for tag, input_pcoll in main_inputs.items():
                        key = self._pin.cache_key(input_pcoll)

                        # Replace the input pcollection with the cached pcollection (if it
                        # has been cached).
                        if key in self._pin._cached_pcoll_read:
                            # Ignore this pcoll in the final pruned instrumented pipeline.
                            self._pin._ignored_targets.add(input_pcoll)
                            main_inputs[tag] = self._pin._cached_pcoll_read[
                                key]
                    # Update the transform with its new inputs.
                    transform_node.main_inputs = main_inputs

        v = ReadCacheWireVisitor(self)
        pipeline.visit(v)
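
Both examples rely on Beam's PipelineVisitor protocol: pipeline.visit(visitor) walks the applied transforms and calls visit_transform (plus enter_composite_transform for composites) with each AppliedPTransform, whose inputs and outputs the visitor can then inspect or rewire. Below is a minimal, self-contained sketch of that traversal pattern, independent of the interactive cache machinery; the visitor class and transform labels are invented for illustration.

import apache_beam as beam
from apache_beam.pipeline import PipelineVisitor


class ProducerLabelVisitor(PipelineVisitor):
    """Collects the label of every transform that produces a PCollection."""
    def __init__(self):
        self.producer_labels = []

    def visit_transform(self, transform_node):
        # transform_node is an AppliedPTransform; outputs is a dict of tagged
        # output PCollections.
        if transform_node.outputs:
            self.producer_labels.append(transform_node.full_label)


p = beam.Pipeline()
_ = (
    p
    | 'Create' >> beam.Create([1, 2, 3])
    | 'Double' >> beam.Map(lambda x: x * 2))

visitor = ProducerLabelVisitor()
p.visit(visitor)
print(visitor.producer_labels)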