Example #1
def pcoll_from_file_cache(
    query_pipeline: beam.Pipeline,
    pcoll: beam.PCollection,
    cache_manager: FileBasedCacheManager,
    key: str) -> beam.PCollection:
  """Reads PCollection cache from files.

  Args:
    query_pipeline: The beam.Pipeline object built by the magic to execute the
        SQL query.
    pcoll: The PCollection to read cache for.
    cache_manager: The file based cache manager that holds the PCollection
        cache.
    key: The key of the PCollection cache.

  Returns:
    A PCollection read from the cache.
  """
  schema = pcoll.element_type

  class Unreify(beam.DoFn):
    def process(self, e):
      if isinstance(e, beam.Row) and hasattr(e, 'windowed_value'):
        yield e.windowed_value

  return (
      query_pipeline
      |
      '{}{}'.format('QuerySource', key) >> cache.ReadCache(cache_manager, key)
      | '{}{}'.format('Unreify', key) >> beam.ParDo(
          Unreify()).with_output_types(schema))
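
The Unreify DoFn above relies on a Beam behavior worth spelling out: when a DoFn yields a WindowedValue, the runner adopts that object's value, timestamp and windows directly instead of wrapping the yielded object as a new element. Below is a minimal, self-contained sketch of that mechanism; the ReifiedElement wrapper is hypothetical and stands in for the cached rows that carry a windowed_value attribute.

import apache_beam as beam
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.windowed_value import WindowedValue


class ReifiedElement(object):
  """Hypothetical stand-in for a cached element carrying a WindowedValue."""
  def __init__(self, windowed_value):
    self.windowed_value = windowed_value


class Unreify(beam.DoFn):
  def process(self, e):
    # Yielding the stored WindowedValue restores the element's original value,
    # timestamp and windows instead of emitting the wrapper object itself.
    yield e.windowed_value


with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([1, 2, 3])
      | beam.Map(
          lambda x: ReifiedElement(WindowedValue(x, 0, (GlobalWindow(),))))
      | beam.ParDo(Unreify())
      | beam.Map(print))  # prints the plain elements 1, 2, 3 (order not guaranteed)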
Example #2
  def _read_cache(self, pipeline, pcoll):
    """Reads a cached pvalue.

    A no-op will cause the pipeline to execute the transform as
    it is and cache nothing from this transform for the next run.

    Modifies:
      pipeline
    """
    # Makes sure the pcoll belongs to the pipeline being instrumented.
    if pcoll.pipeline is not pipeline:
      return
    # The keyed cache is always valid within this instrumentation.
    key = self.cache_key(pcoll)
    # Can only read from cache when the cache with expected key exists and its
    # computation has been completed.
    if (self._cache_manager.exists('full', key) and
        (self._runner_pcoll_to_user_pcoll[pcoll] in
         ie.current_env().computed_pcollections)):
      if key not in self._cached_pcoll_read:
        # Mutates the pipeline with cache read transform attached
        # to root of the pipeline.
        pcoll_from_cache = (
            pipeline
            | '{}{}'.format(READ_CACHE, key) >> cache.ReadCache(
                self._cache_manager, key))
        self._cached_pcoll_read[key] = pcoll_from_cache
Example #3
def unreify_from_cache(
        pipeline: beam.Pipeline,
        cache_key: str,
        cache_manager: cache.CacheManager,
        element_type: Optional[type] = None,
        source_label: Optional[str] = None,
        unreify_label: Optional[str] = None) -> beam.pvalue.PCollection:
    """Reads from cache and unreifies elements from windowed values.

  pipeline: The pipeline that's reading from the cache.
  cache_key: The key of the cache.
  cache_manager: The cache manager to manage the cache.
  element_type: (optional) The element type of the PCollection's elements.
  source_label: (optional) A transform label for the cache-reading transform.
  unreify_label: (optional) A transform label for the Unreify transform.
  """
    if not source_label:
        source_label = '{}{}'.format(READ_CACHE, cache_key)
    if not unreify_label:
        unreify_label = '{}{}{}'.format('UnreifyAfter_', READ_CACHE, cache_key)
    read_cache = pipeline | source_label >> cache.ReadCache(
        cache_manager, cache_key)
    if element_type:
        # If the PCollection is schema-aware, explicitly sets the output types.
        return read_cache | unreify_label >> beam.ParDo(
            Unreify()).with_output_types(element_type)
    return read_cache | unreify_label >> beam.ParDo(Unreify())
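
A hedged usage sketch for unreify_from_cache (assumed to be in scope as defined above): it presumes an Interactive Beam session in which the target PCollection has already been cached, the cache key below is a placeholder, and the cache manager is fetched the same way the test in Example #7 does.

import apache_beam as beam
from apache_beam.runners.interactive import interactive_environment as ie

p = beam.Pipeline()
# Placeholder key; Interactive Beam derives real cache keys from the cached PCollection.
cache_key = 'init_pcoll_placeholder'
cache_manager = ie.current_env().get_cache_manager(p)

# Reads the cached elements back into the pipeline and unreifies them, tagging the
# output with the known element type so schema-aware transforms keep working.
cached_pcoll = unreify_from_cache(
    pipeline=p,
    cache_key=cache_key,
    cache_manager=cache_manager,
    element_type=int)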
Example #4
    def expand(self,
               pcoll: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
        class Unreify(beam.DoFn):
            def process(self, e):
                yield e.windowed_value

        return (pcoll.pipeline
                | 'read' + self._label >> cache.ReadCache(
                    self._cache_manager, self._key)
                | 'unreify' + self._label >> beam.ParDo(Unreify()))
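
The expand above assumes an enclosing PTransform constructed with a cache manager, a cache key and a label suffix. A minimal hypothetical wrapper, sketched here only to show which attributes expand relies on (the real class name and constructor in Interactive Beam may differ):

import apache_beam as beam


class ReadCacheAndUnreify(beam.PTransform):  # hypothetical class name
    def __init__(self, cache_manager, key, label):
        self._cache_manager = cache_manager  # cache manager holding the cached PCollection
        self._key = key                      # cache key identifying the PCollection to load
        self._label = label                  # suffix keeping the generated transform labels unique

Note that expand reads from pcoll.pipeline rather than from the input PCollection itself, so the input only identifies which pipeline the cache-read is attached to.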
Example #5
    def test_read_cache_expansion(self):
        p = beam.Pipeline(runner=self.runner)

        # The cold run.
        pcoll = (p
                 | 'Create' >> beam.Create([1, 2, 3])
                 | 'Double' >> beam.Map(lambda x: x * 2)
                 | 'Square' >> beam.Map(lambda x: x**2))
        pipeline_proto = to_stable_runner_api(p)

        pipeline_info = pipeline_analyzer.PipelineInfo(
            pipeline_proto.components)
        pcoll_id = 'ref_PCollection_PCollection_3'  # Output PCollection of Square
        cache_label1 = pipeline_info.cache_label(pcoll_id)

        analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
                                                      pipeline_proto,
                                                      self.runner)
        pipeline_to_execute = beam.pipeline.Pipeline.from_runner_api(
            analyzer.pipeline_proto_to_execute(), p.runner, p._options)
        pipeline_to_execute.run().wait_until_finish()

        # The second run.
        _ = (pcoll
             | 'Triple' >> beam.Map(lambda x: x * 3)
             | 'Cube' >> beam.Map(lambda x: x**3))
        analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
                                                      to_stable_runner_api(p),
                                                      self.runner)

        expected_pipeline = beam.Pipeline(runner=self.runner)
        pcoll1 = (expected_pipeline
                  | 'Load%s' % cache_label1 >> cache.ReadCache(
                      self.cache_manager, cache_label1))
        pcoll2 = pcoll1 | 'Triple' >> beam.Map(lambda x: x * 3)
        pcoll3 = pcoll2 | 'Cube' >> beam.Map(lambda x: x**3)

        cache_label2 = 'PColl-7654321'
        cache_label3 = 'PColl-3141593'

        # pylint: disable=expression-not-assigned
        pcoll2 | 'CacheSample%s' % cache_label2 >> cache.WriteCache(
            self.cache_manager, cache_label2, sample=True, sample_size=10)
        pcoll3 | 'CacheSample%s' % cache_label3 >> cache.WriteCache(
            self.cache_manager, cache_label3, sample=True, sample_size=10)
        pcoll3 | 'CacheFull%s' % cache_label3 >> cache.WriteCache(
            self.cache_manager, cache_label3)

        # ReadCache & WriteCache expansion leads to more than 50 PTransform protos
        # in the pipeline, so assert equality of the full pipeline protos rather
        # than checking individual transforms.
        self.assertPipelineEqual(analyzer.pipeline_proto_to_execute(),
                                 to_stable_runner_api(expected_pipeline))
Example #6
        def _producing_transforms(pcoll_id, leaf=False):
            """Returns PTransforms (and their names) that produces the given PColl."""
            if pcoll_id in _producing_transforms.analyzed_pcoll_ids:
                return
            else:
                _producing_transforms.analyzed_pcoll_ids.add(pcoll_id)

            derivation = pipeline_info.derivation(pcoll_id)
            if self._cache_manager.exists('full', derivation.cache_label()):
                # If the PCollection is cached, yield the ReadCache PTransform that
                # reads the PCollection, together with all its sub-PTransforms.
                if not leaf:
                    caches_used.add(pcoll_id)

                    cache_label = pipeline_info.derivation(
                        pcoll_id).cache_label()
                    dummy_pcoll = pipeline | 'Load%s' % cache_label >> cache.ReadCache(
                        self._cache_manager, cache_label)

                    # Find the top level ReadCache composite PTransform.
                    read_cache = dummy_pcoll.producer
                    while read_cache.parent.parent:
                        read_cache = read_cache.parent

                    def _include_subtransforms(transform):
                        """Depth-first yield the PTransform itself and its sub PTransforms.
            """
                        yield transform
                        for subtransform in transform.parts:
                            for yielded in _include_subtransforms(
                                    subtransform):
                                yield yielded

                    for transform in _include_subtransforms(read_cache):
                        transform_proto = transform.to_runner_api(context)
                        if dummy_pcoll in transform.outputs.values():
                            transform_proto.outputs['None'] = pcoll_id
                        yield context.transforms.get_id(
                            transform), transform_proto

            else:
                transform_id, _ = pipeline_info.producer(pcoll_id)
                transform_proto = pipeline_proto.components.transforms[
                    transform_id]
                for input_id in transform_proto.inputs.values():
                    for transform in _producing_transforms(input_id):
                        yield transform
                yield transform_id, transform_proto
Example #7
    def test_instrument_example_pipeline_to_read_cache(self):
        p_origin, init_pcoll, second_pcoll = self._example_pipeline()
        p_copy, _, _ = self._example_pipeline(False)

        # Mock as if cacheable PCollections are cached.
        init_pcoll_cache_key = self.cache_key_of('init_pcoll', init_pcoll)
        self._mock_write_cache(p_origin, [b'1', b'2', b'3'],
                               init_pcoll_cache_key)
        second_pcoll_cache_key = self.cache_key_of('second_pcoll',
                                                   second_pcoll)
        self._mock_write_cache(p_origin, [b'1', b'4', b'9'],
                               second_pcoll_cache_key)
        # Mark the completeness of PCollections from the original (user) pipeline.
        ie.current_env().mark_pcollection_computed((init_pcoll, second_pcoll))
        ie.current_env().add_derived_pipeline(p_origin, p_copy)
        instr.build_pipeline_instrument(p_copy)

        cached_init_pcoll = (
            p_origin
            | '_ReadCache_' + init_pcoll_cache_key >> cache.ReadCache(
                ie.current_env().get_cache_manager(p_origin),
                init_pcoll_cache_key)
            | 'unreify' >> beam.Map(lambda _: _))

        # second_pcoll is never used as input and there is no need to read cache.

        class TestReadCacheWireVisitor(PipelineVisitor):
            """Replace init_pcoll with cached_init_pcoll for all occuring inputs."""
            def enter_composite_transform(self, transform_node):
                self.visit_transform(transform_node)

            def visit_transform(self, transform_node):
                if transform_node.inputs:
                    main_inputs = dict(transform_node.main_inputs)
                    for tag, main_input in main_inputs.items():
                        if main_input == init_pcoll:
                            main_inputs[tag] = cached_init_pcoll
                    transform_node.main_inputs = main_inputs

        v = TestReadCacheWireVisitor()
        p_origin.visit(v)
        assert_pipeline_equal(self, p_origin, p_copy)
Example #8
    def test_instrument_example_pipeline_to_read_cache(self):
        p_origin, init_pcoll, second_pcoll = self._example_pipeline()
        p_copy, _, _ = self._example_pipeline(False)

        # Mock as if cacheable PCollections are cached.
        init_pcoll_cache_key = 'init_pcoll_' + str(id(init_pcoll)) + '_' + str(
            id(init_pcoll.producer))
        self._mock_write_cache([b'1', b'2', b'3'], init_pcoll_cache_key)
        second_pcoll_cache_key = 'second_pcoll_' + str(
            id(second_pcoll)) + '_' + str(id(second_pcoll.producer))
        self._mock_write_cache([b'1', b'4', b'9'], second_pcoll_cache_key)
        # Mark the completeness of PCollections from the original (user) pipeline.
        ie.current_env().mark_pcollection_computed(
            (p_origin, init_pcoll, second_pcoll))
        instr.build_pipeline_instrument(p_copy)

        cached_init_pcoll = (
            p_origin
            | '_ReadCache_' + init_pcoll_cache_key >> cache.ReadCache(
                ie.current_env().cache_manager(), init_pcoll_cache_key)
            | 'unreify' >> beam.Map(lambda _: _))

        # second_pcoll is never used as input and there is no need to read cache.

        class TestReadCacheWireVisitor(PipelineVisitor):
            """Replace init_pcoll with cached_init_pcoll for all occuring inputs."""
            def enter_composite_transform(self, transform_node):
                self.visit_transform(transform_node)

            def visit_transform(self, transform_node):
                if transform_node.inputs:
                    input_list = list(transform_node.inputs)
                    for i in range(len(input_list)):
                        if input_list[i] == init_pcoll:
                            input_list[i] = cached_init_pcoll
                    transform_node.inputs = tuple(input_list)

        v = TestReadCacheWireVisitor()
        p_origin.visit(v)
        assert_pipeline_equal(self, p_origin, p_copy)
Example #9
  def _read_cache(self, pipeline, pcoll, is_unbounded_source_output):
    """Reads a cached pvalue.

    A no-op will cause the pipeline to execute the transform as
    it is and cache nothing from this transform for the next run.

    Modifies:
      pipeline
    """
    # Makes sure the pcoll belongs to the pipeline being instrumented.
    if pcoll.pipeline is not pipeline:
      return
    # The keyed cache is always valid within this instrumentation.
    key = self.cache_key(pcoll)
    # Can only read from cache when the cache with expected key exists and its
    # computation has been completed.

    is_cached = self._cache_manager.exists('full', key)
    is_computed = (
        pcoll in self._runner_pcoll_to_user_pcoll and
        self._runner_pcoll_to_user_pcoll[pcoll] in
        ie.current_env().computed_pcollections)
    if ((is_cached and is_computed) or is_unbounded_source_output):
      if key not in self._cached_pcoll_read:
        # Mutates the pipeline with cache read transform attached
        # to root of the pipeline.

        # To put the cached value into the correct window, simply return a
        # WindowedValue constructed from the element.
        class Unreify(beam.DoFn):
          def process(self, e):
            yield e.windowed_value

        pcoll_from_cache = (
            pipeline
            | '{}{}'.format(READ_CACHE, key) >> cache.ReadCache(
                self._cache_manager, key)
            | '{}{}unreify'.format(READ_CACHE, key) >> beam.ParDo(Unreify()))
        self._cached_pcoll_read[key] = pcoll_from_cache
Example #10
    def test_instrument_example_pipeline_to_read_cache(self):
        p_origin, init_pcoll, second_pcoll = self._example_pipeline()
        p_copy, _, _ = self._example_pipeline(False)

        # Mock as if cacheable PCollections are cached.
        init_pcoll_cache_key = 'init_pcoll_' + str(
            id(init_pcoll)) + '_ref_PCollection_PCollection_10_' + str(
                id(init_pcoll.producer))
        self._mock_write_cache(init_pcoll, init_pcoll_cache_key)
        second_pcoll_cache_key = 'second_pcoll_' + str(
            id(second_pcoll)) + '_ref_PCollection_PCollection_11_' + str(
                id(second_pcoll.producer))
        self._mock_write_cache(second_pcoll, second_pcoll_cache_key)
        ie.current_env().cache_manager().exists = MagicMock(return_value=True)
        instr.pin(p_copy)

        cached_init_pcoll = p_origin | (
            '_ReadCache_' + init_pcoll_cache_key) >> cache.ReadCache(
                ie.current_env().cache_manager(), init_pcoll_cache_key)

        # second_pcoll is never used as input and there is no need to read cache.

        class TestReadCacheWireVisitor(PipelineVisitor):
            """Replace init_pcoll with cached_init_pcoll for all occuring inputs."""
            def enter_composite_transform(self, transform_node):
                self.visit_transform(transform_node)

            def visit_transform(self, transform_node):
                if transform_node.inputs:
                    input_list = list(transform_node.inputs)
                    for i in range(len(input_list)):
                        if input_list[i] == init_pcoll:
                            input_list[i] = cached_init_pcoll
                    transform_node.inputs = tuple(input_list)

        v = TestReadCacheWireVisitor()
        p_origin.visit(v)
        self.assertPipelineEqual(p_origin, p_copy)
Example #11
    def _insert_producing_transforms(self,
                                     pcoll_id,
                                     required_transforms,
                                     top_level_required_transforms,
                                     leaf=False):
        """Inserts PTransforms producing the given PCollection into the dicts.

        Args:
          pcoll_id: (str)
          required_transforms: (Dict[str, PTransform proto])
          top_level_required_transforms: (Dict[str, PTransform proto])
          leaf: (bool) whether the PCollection is a leaf output; a leaf
            PCollection is not read from cache even when its cache exists.

        Modifies:
          required_transforms
          top_level_required_transforms
          self._read_cache_ids
        """
        if pcoll_id in self._analyzed_pcoll_ids:
            return
        else:
            self._analyzed_pcoll_ids.add(pcoll_id)

        cache_label = self._pipeline_info.cache_label(pcoll_id)
        if self._cache_manager.exists('full', cache_label) and not leaf:
            self._caches_used.add(pcoll_id)

            cache_label = self._pipeline_info.cache_label(pcoll_id)
            dummy_pcoll = (self._pipeline
                           | 'Load%s' % cache_label >> cache.ReadCache(
                               self._cache_manager, cache_label))

            read_cache = self._top_level_producer(dummy_pcoll)
            read_cache_id = self._context.transforms.get_id(read_cache)
            read_cache_proto = read_cache.to_runner_api(self._context)
            read_cache_proto.outputs['None'] = pcoll_id
            top_level_required_transforms[read_cache_id] = read_cache_proto
            self._read_cache_ids.add(read_cache_id)

            for transform in self._include_subtransforms(read_cache):
                transform_id = self._context.transforms.get_id(transform)
                transform_proto = transform.to_runner_api(self._context)
                if dummy_pcoll in transform.outputs.values():
                    transform_proto.outputs['None'] = pcoll_id
                required_transforms[transform_id] = transform_proto

        else:
            pcoll = self._context.pcollections.get_by_id(pcoll_id)

            top_level_transform = self._top_level_producer(pcoll)
            for transform in self._include_subtransforms(top_level_transform):
                transform_id = self._context.transforms.get_id(transform)
                transform_proto = self._context.transforms.get_proto(transform)

                # Inserting ancestor PTransforms.
                for input_id in transform_proto.inputs.values():
                    self._insert_producing_transforms(
                        input_id, required_transforms,
                        top_level_required_transforms)
                required_transforms[transform_id] = transform_proto

            # Must be inserted after inserting ancestor PTransforms.
            top_level_id = self._context.transforms.get_id(top_level_transform)
            top_level_proto = self._context.transforms.get_proto(
                top_level_transform)
            top_level_required_transforms[top_level_id] = top_level_proto
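
The helpers _top_level_producer and _include_subtransforms used above are not part of this snippet. As a hedged sketch, _include_subtransforms is presumably the method form of the nested generator shown in Example #6: a depth-first walk over a composite transform and its parts (the exact signature in the real class may differ).

    def _include_subtransforms(self, transform):
        """Depth-first yields a composite PTransform and all of its sub-PTransforms."""
        yield transform
        for subtransform in transform.parts:
            # AppliedPTransform.parts holds the immediate sub-transforms of a composite.
            for yielded in self._include_subtransforms(subtransform):
                yield yielded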