Example #1
  def run_pipeline(self, pipeline):
    if not hasattr(self, '_desired_cache_labels'):
      self._desired_cache_labels = set()

    # Invoke a round trip through the runner API. This makes sure the Pipeline
    # proto is stable.
    pipeline = beam.pipeline.Pipeline.from_runner_api(
        pipeline.to_runner_api(),
        pipeline.runner,
        pipeline._options)

    # Snapshot the pipeline in a portable proto before mutating it.
    pipeline_proto, original_context = pipeline.to_runner_api(
        return_context=True)
    pcolls_to_pcoll_id = self._pcolls_to_pcoll_id(pipeline, original_context)

    analyzer = pipeline_analyzer.PipelineAnalyzer(self._cache_manager,
                                                  pipeline_proto,
                                                  self._underlying_runner,
                                                  pipeline._options,
                                                  self._desired_cache_labels)
    # Should only be accessed for debugging purposes.
    self._analyzer = analyzer

    pipeline_to_execute = beam.pipeline.Pipeline.from_runner_api(
        analyzer.pipeline_proto_to_execute(),
        self._underlying_runner,
        pipeline._options)

    # Index the snapshotted proto's components so PCollection producers and
    # derivations can be looked up by id.
    pipeline_info = pipeline_analyzer.PipelineInfo(pipeline_proto.components)

    # Periodically render execution status while the analyzed pipeline runs.
    display = display_manager.DisplayManager(
        pipeline_info=pipeline_info,
        pipeline_proto=pipeline_proto,
        caches_used=analyzer.caches_used(),
        cache_manager=self._cache_manager,
        referenced_pcollections=analyzer.top_level_referenced_pcollection_ids(),
        required_transforms=analyzer.top_level_required_transforms())
    display.start_periodic_update()
    result = pipeline_to_execute.run()
    result.wait_until_finish()
    display.stop_periodic_update()

    return PipelineResult(result, self, pipeline_info, self._cache_manager,
                          pcolls_to_pcoll_id)
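
For context, a minimal sketch of how a runner implementing the run_pipeline above is typically driven. The InteractiveRunner import path and the toy transforms below are assumptions for illustration; the snippet itself never names the runner class.

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

# The interactive runner wraps an underlying runner and implements a
# run_pipeline like the one above (assumed; the example doesn't show the class).
p = beam.Pipeline(runner=InteractiveRunner(), options=PipelineOptions())
counts = (
    p
    | 'Create' >> beam.Create(['to', 'be', 'or', 'not', 'to', 'be'])
    | 'Count' >> beam.combiners.Count.PerElement())

# Pipeline.run() dispatches to runner.run_pipeline(pipeline); the returned
# PipelineResult wraps the underlying result plus cache bookkeeping.
result = p.run()
result.wait_until_finish()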
Example #2
    def run_pipeline(self, pipeline):
        if not hasattr(self, '_desired_cache_labels'):
            self._desired_cache_labels = set()

        # Invoke a round trip through the runner API. This makes sure the Pipeline
        # proto is stable.
        pipeline = beam.pipeline.Pipeline.from_runner_api(
            pipeline.to_runner_api(), pipeline.runner, pipeline._options)

        # Snapshot the pipeline in a portable proto before mutating it.
        pipeline_proto, original_context = pipeline.to_runner_api(
            return_context=True)
        pcolls_to_pcoll_id = self._pcolls_to_pcoll_id(pipeline,
                                                      original_context)

        # TODO(qinyeli): Refactor the rest of this function into
        # def manipulate_pipeline(pipeline_proto) -> pipeline_proto_to_run:

        # Make a copy of the original pipeline to avoid accidentally mutating it.
        pipeline, context = beam.pipeline.Pipeline.from_runner_api(
            pipeline_proto,
            self._underlying_runner,
            pipeline._options,  # pylint: disable=protected-access
            return_context=True)
        pipeline_info = PipelineInfo(pipeline_proto.components)

        caches_used = set()

        def _producing_transforms(pcoll_id, leaf=False):
            """Returns PTransforms (and their names) that produces the given PColl."""
            if pcoll_id in _producing_transforms.analyzed_pcoll_ids:
                return
            _producing_transforms.analyzed_pcoll_ids.add(pcoll_id)

            derivation = pipeline_info.derivation(pcoll_id)
            if self._cache_manager.exists('full', derivation.cache_label()):
                # If the PCollection is cached, yield the ReadCache composite
                # that reads it, along with all of ReadCache's sub-PTransforms.
                if not leaf:
                    caches_used.add(pcoll_id)

                    cache_label = derivation.cache_label()
                    dummy_pcoll = pipeline | 'Load%s' % cache_label >> cache.ReadCache(
                        self._cache_manager, cache_label)

                    # Find the top level ReadCache composite PTransform.
                    read_cache = dummy_pcoll.producer
                    while read_cache.parent.parent:
                        read_cache = read_cache.parent

                    def _include_subtransforms(transform):
                        """Yields the transform and its parts, depth-first."""
                        yield transform
                        for subtransform in transform.parts:
                            for yielded in _include_subtransforms(
                                    subtransform):
                                yield yielded

                    for transform in _include_subtransforms(read_cache):
                        transform_proto = transform.to_runner_api(context)
                        if dummy_pcoll in transform.outputs.values():
                            # Rewire the ReadCache output to the original
                            # PCollection id so downstream consumers connect.
                            transform_proto.outputs['None'] = pcoll_id
                        yield context.transforms.get_id(
                            transform), transform_proto

            else:
                transform_id, _ = pipeline_info.producer(pcoll_id)
                transform_proto = pipeline_proto.components.transforms[
                    transform_id]
                for input_id in transform_proto.inputs.values():
                    for transform in _producing_transforms(input_id):
                        yield transform
                yield transform_id, transform_proto

        desired_pcollections = self._desired_pcollections(pipeline_info)

        # TODO(qinyeli): Preserve composite structure.
        required_transforms = collections.OrderedDict()
        _producing_transforms.analyzed_pcoll_ids = set()
        for pcoll_id in desired_pcollections:
            # TODO(qinyeli): Collections consumed by no-output transforms.
            required_transforms.update(_producing_transforms(pcoll_id, True))

        referenced_pcollections = self._referenced_pcollections(
            pipeline_proto, required_transforms)

        required_transforms['_root'] = beam_runner_api_pb2.PTransform(
            subtransforms=required_transforms.keys())

        # Assemble the sliced pipeline: only the required transforms, the
        # PCollections they reference, and the coders they need.
        pipeline_to_execute = copy.deepcopy(pipeline_proto)
        pipeline_to_execute.root_transform_ids[:] = ['_root']
        set_proto_map(pipeline_to_execute.components.transforms,
                      required_transforms)
        set_proto_map(pipeline_to_execute.components.pcollections,
                      referenced_pcollections)
        set_proto_map(pipeline_to_execute.components.coders,
                      context.to_runner_api().coders)

        pipeline_slice, context = beam.pipeline.Pipeline.from_runner_api(
            pipeline_to_execute,
            self._underlying_runner,
            pipeline._options,  # pylint: disable=protected-access
            return_context=True)

        # TODO(qinyeli): cache only top-level pcollections.
        for pcoll_id in pipeline_info.all_pcollections():
            if pcoll_id not in referenced_pcollections:
                continue
            cache_label = pipeline_info.derivation(pcoll_id).cache_label()
            pcoll = context.pcollections.get_by_id(pcoll_id)

            if pcoll_id in desired_pcollections:
                # pylint: disable=expression-not-assigned
                pcoll | 'CacheFull%s' % cache_label >> cache.WriteCache(
                    self._cache_manager, cache_label)

            if pcoll_id in referenced_pcollections:
                # pylint: disable=expression-not-assigned
                pcoll | 'CacheSample%s' % cache_label >> cache.WriteCache(
                    self._cache_manager,
                    cache_label,
                    sample=True,
                    sample_size=SAMPLE_SIZE)

        display = display_manager.DisplayManager(
            pipeline_info=pipeline_info,
            pipeline_proto=pipeline_proto,
            caches_used=caches_used,
            cache_manager=self._cache_manager,
            referenced_pcollections=referenced_pcollections,
            required_transforms=required_transforms)
        display.start_periodic_update()
        result = pipeline_slice.run()
        result.wait_until_finish()
        display.stop_periodic_update()

        return PipelineResult(result, self, pipeline_info, self._cache_manager,
                              pcolls_to_pcoll_id)
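
Examples #2 and #3 call a set_proto_map helper that isn't shown in the snippets. A plausible sketch, under the assumption that it simply replaces the entries of a protobuf map field with those of a Python dict (proto map fields can't be assigned wholesale):

def set_proto_map(proto_map, new_values):
  """Replaces the entries of a proto map field with those of a dict.

  Sketch of the unshown helper; assumes the map values are messages
  (transforms, pcollections, coders), so each entry must be copied with
  CopyFrom rather than plain assignment.
  """
  proto_map.clear()
  for key, value in new_values.items():
    proto_map[key].CopyFrom(value)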
Example #3
  def run_pipeline(self, pipeline):
    if not hasattr(self, '_desired_cache_labels'):
      self._desired_cache_labels = set()
    print('Running...')

    # Snapshot the pipeline in a portable proto before mutating it.
    pipeline_proto, original_context = pipeline.to_runner_api(
        return_context=True)

    pcolls_to_pcoll_id = self._pcolls_to_pcoll_id(pipeline, original_context)

    pipeline_info = PipelineInfo(pipeline_proto.components)
    context = pipeline_context.PipelineContext(pipeline_proto.components)

    caches_used = set()

    def _producing_transforms(pcoll_id, leaf=False):
      """Returns PTransforms (and their names) that produces the given PColl."""
      derivation = pipeline_info.derivation(pcoll_id)
      if self._cache_manager.exists('full', derivation.cache_label()):
        # The PCollection is cached; read it back rather than recomputing it.
        if not leaf:
          caches_used.add(pcoll_id)
          yield ('Read' + derivation.cache_label(),
                 beam_runner_api_pb2.PTransform(
                     unique_name='Read' + derivation.cache_label(),
                     spec=beam.io.Read(
                         beam.io.ReadFromText(
                             self._cache_manager.glob_path(
                                 'full', derivation.cache_label()),
                             coder=SafeFastPrimitivesCoder())._source)
                     .to_runner_api(context),
                     outputs={'None': pcoll_id}))
      else:
        transform_id, _ = pipeline_info.producer(pcoll_id)
        transform_proto = pipeline_proto.components.transforms[transform_id]
        for input_id in transform_proto.inputs.values():
          for transform in _producing_transforms(input_id):
            yield transform
        yield transform_id, transform_proto

    desired_pcollections = self._desired_pcollections(pipeline_info)

    # TODO(qinyeli): Preserve composite structure.
    required_transforms = collections.OrderedDict()
    for pcoll_id in desired_pcollections:
      # TODO(qinyeli): Collections consumed by no-output transforms.
      required_transforms.update(_producing_transforms(pcoll_id, True))

    referenced_pcollections = self._referenced_pcollections(
        pipeline_proto, required_transforms)

    required_transforms['_root'] = beam_runner_api_pb2.PTransform(
        subtransforms=required_transforms.keys())

    # Assemble the sliced pipeline: only the required transforms, the
    # PCollections they reference, and the coders they need.
    pipeline_to_execute = copy.deepcopy(pipeline_proto)
    pipeline_to_execute.root_transform_ids[:] = ['_root']
    set_proto_map(pipeline_to_execute.components.transforms,
                  required_transforms)
    set_proto_map(pipeline_to_execute.components.pcollections,
                  referenced_pcollections)
    set_proto_map(pipeline_to_execute.components.coders,
                  context.to_runner_api().coders)

    pipeline_slice, context = beam.pipeline.Pipeline.from_runner_api(
        pipeline_to_execute,
        self._underlying_runner,
        pipeline._options,  # pylint: disable=protected-access
        return_context=True)
    # Bucket referenced PCollections by cache label so full and sampled
    # cache writes can be attached as two composite transforms below.
    pcolls_to_write = {}
    pcolls_to_sample = {}

    for pcoll_id in pipeline_info.all_pcollections():
      if pcoll_id not in referenced_pcollections:
        continue
      cache_label = pipeline_info.derivation(pcoll_id).cache_label()
      if pcoll_id in desired_pcollections:
        pcolls_to_write[cache_label] = context.pcollections.get_by_id(pcoll_id)
      if pcoll_id in referenced_pcollections:
        pcolls_to_sample[cache_label] = context.pcollections.get_by_id(pcoll_id)

    # pylint: disable=expression-not-assigned
    if pcolls_to_write:
      pcolls_to_write | WriteCache(self._cache_manager)
    if pcolls_to_sample:
      pcolls_to_sample | 'WriteSample' >> WriteCache(
          self._cache_manager, sample=True)

    display = display_manager.DisplayManager(
        pipeline_info=pipeline_info,
        pipeline_proto=pipeline_proto,
        caches_used=caches_used,
        cache_manager=self._cache_manager,
        referenced_pcollections=referenced_pcollections,
        required_transforms=required_transforms)
    display.start_periodic_update()
    result = pipeline_slice.run()
    result.wait_until_finish()
    display.stop_periodic_update()

    return PipelineResult(result, self, pipeline_info, self._cache_manager,
                          pcolls_to_pcoll_id)
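
The _desired_pcollections helper used by Examples #2 and #3 is likewise not shown. A plausible sketch, assuming it selects the PCollections whose cache labels the user requested (self._desired_cache_labels) plus the pipeline's leaf PCollections, and that PipelineInfo exposes a leaf_pcollections() accessor (an assumption; only derivation, producer, and all_pcollections appear above):

  def _desired_pcollections(self, pipeline_info):
    """Returns ids of PCollections whose contents should be computed.

    Sketch of the unshown helper: leaves are always desired so the sliced
    pipeline keeps its terminal outputs; everything else is desired only
    if the user asked for its cache label.
    """
    desired_pcollections = set(pipeline_info.leaf_pcollections())
    for pcoll_id in pipeline_info.all_pcollections():
      cache_label = pipeline_info.derivation(pcoll_id).cache_label()
      if cache_label in self._desired_cache_labels:
        desired_pcollections.add(pcoll_id)
    return desired_pcollections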