Example #1
    def test_read_cache_expansion(self):
        p = beam.Pipeline(runner=self.runner)

        # The cold run.
        pcoll = (p
                 | 'Create' >> beam.Create([1, 2, 3])
                 | 'Double' >> beam.Map(lambda x: x * 2)
                 | 'Square' >> beam.Map(lambda x: x**2))
        pipeline_proto = to_stable_runner_api(p)

        pipeline_info = pipeline_analyzer.PipelineInfo(
            pipeline_proto.components)
        pcoll_id = 'ref_PCollection_PCollection_3'  # Output PCollection of Square
        cache_label1 = pipeline_info.cache_label(pcoll_id)

        analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
                                                      pipeline_proto,
                                                      self.runner)
        pipeline_to_execute = beam.pipeline.Pipeline.from_runner_api(
            analyzer.pipeline_proto_to_execute(), p.runner, p._options)
        pipeline_to_execute.run().wait_until_finish()

        # The second run.
        _ = (pcoll
             | 'Triple' >> beam.Map(lambda x: x * 3)
             | 'Cube' >> beam.Map(lambda x: x**3))
        analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
                                                      to_stable_runner_api(p),
                                                      self.runner)

        expected_pipeline = beam.Pipeline(runner=self.runner)
        pcoll1 = (expected_pipeline
                  | 'Load%s' % cache_label1 >> cache.ReadCache(
                      self.cache_manager, cache_label1))
        pcoll2 = pcoll1 | 'Triple' >> beam.Map(lambda x: x * 3)
        pcoll3 = pcoll2 | 'Cube' >> beam.Map(lambda x: x**3)

        cache_label2 = 'PColl-7654321'
        cache_label3 = 'PColl-3141593'

        # pylint: disable=expression-not-assigned
        pcoll2 | 'CacheSample%s' % cache_label2 >> cache.WriteCache(
            self.cache_manager, cache_label2, sample=True, sample_size=10)
        pcoll3 | 'CacheSample%s' % cache_label3 >> cache.WriteCache(
            self.cache_manager, cache_label3, sample=True, sample_size=10)
        pcoll3 | 'CacheFull%s' % cache_label3 >> cache.WriteCache(
            self.cache_manager, cache_label3)

        # ReadCache & WriteCache expansion leads to more than 50 PTransform
        # protos in the pipeline, so compare the whole expanded proto against
        # the expected pipeline rather than inspecting transforms one by one.
        self.assertPipelineEqual(analyzer.pipeline_proto_to_execute(),
                                 to_stable_runner_api(expected_pipeline))
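The helper to_stable_runner_api called throughout this example is not shown on
this page; from its uses it must turn a Pipeline into a runner API proto with
stable ids. A minimal sketch of such a helper, assuming (as the round trip in
run_pipeline below suggests) that converting the pipeline back and forth
through the runner API is what makes the proto stable:

import apache_beam as beam

def to_stable_runner_api(p):
    # Round-trip the Pipeline through the runner API so that repeated
    # conversions of the same pipeline produce an identical, stable proto.
    # (Sketch only; the real helper used by these tests may differ.)
    return beam.pipeline.Pipeline.from_runner_api(
        p.to_runner_api(), p.runner, p._options).to_runner_api()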
Example #2
  def run_pipeline(self, pipeline):
    if not hasattr(self, '_desired_cache_labels'):
      self._desired_cache_labels = set()

    # Invoke a round trip through the runner API. This makes sure the Pipeline
    # proto is stable.
    pipeline = beam.pipeline.Pipeline.from_runner_api(
        pipeline.to_runner_api(),
        pipeline.runner,
        pipeline._options)

    # Snapshot the pipeline in a portable proto before mutating it.
    pipeline_proto, original_context = pipeline.to_runner_api(
        return_context=True)
    pcolls_to_pcoll_id = self._pcolls_to_pcoll_id(pipeline, original_context)

    analyzer = pipeline_analyzer.PipelineAnalyzer(self._cache_manager,
                                                  pipeline_proto,
                                                  self._underlying_runner,
                                                  pipeline._options,
                                                  self._desired_cache_labels)
    # Should only be accessed for debugging purposes.
    self._analyzer = analyzer

    pipeline_to_execute = beam.pipeline.Pipeline.from_runner_api(
        analyzer.pipeline_proto_to_execute(),
        self._underlying_runner,
        pipeline._options)

    pipeline_info = pipeline_analyzer.PipelineInfo(pipeline_proto.components)

    display = display_manager.DisplayManager(
        pipeline_info=pipeline_info,
        pipeline_proto=pipeline_proto,
        caches_used=analyzer.caches_used(),
        cache_manager=self._cache_manager,
        referenced_pcollections=analyzer.top_level_referenced_pcollection_ids(),
        required_transforms=analyzer.top_level_required_transforms())
    display.start_periodic_update()
    result = pipeline_to_execute.run()
    result.wait_until_finish()
    display.stop_periodic_update()

    return PipelineResult(result, self, pipeline_info, self._cache_manager,
                          pcolls_to_pcoll_id)
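For context, run_pipeline is the hook a Beam runner exposes to Pipeline.run().
A hypothetical usage sketch (constructing an InteractiveRunner from the
apache_beam.runners.interactive package is an assumption and not part of the
snippet above):

import apache_beam as beam
from apache_beam.runners.interactive import interactive_runner

# Calling p.run() hands the pipeline to run_pipeline above, which analyzes it,
# reuses or writes caches, shows progress, and executes the pruned pipeline.
p = beam.Pipeline(runner=interactive_runner.InteractiveRunner())
squares = (p
           | 'Create' >> beam.Create([1, 2, 3])
           | 'Square' >> beam.Map(lambda x: x * x))
result = p.run()  # returns the PipelineResult built at the end of run_pipeline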
Example #3
    def test_passthrough(self):
        """
    Test that PTransforms which pass through their input PCollection can be
    used with PipelineInfo.
    """
        class Passthrough(beam.PTransform):
            def expand(self, pcoll):
                return pcoll

        p = beam.Pipeline(runner=self.runner)
        p | beam.Impulse() | Passthrough()  # pylint: disable=expression-not-assigned
        proto = to_stable_runner_api(p).components
        info = pipeline_analyzer.PipelineInfo(proto)
        for pcoll_id in info.all_pcollections():
            # FIXME: If PipelineInfo does not support passthrough PTransforms, this
            #        will only fail some of the time, depending on the ordering of
            #        transforms in the Pipeline proto.

            # Should not raise an exception.
            info.cache_label(pcoll_id)