def test_side_inputs(self):
        """The visitor sees one Impulse root, three steps and one AsList view.

        An Impulse feeds a multi-output ParDo; its 'tag_negative' output is
        consumed as an ``AsList`` side input by a second ParDo applied to the
        'positive' main output.
        """
        class SplitNumbersFn(DoFn):
            def process(self, element):
                # Negative values are routed to the tagged output.
                if element < 0:
                    yield pvalue.TaggedOutput('tag_negative', element)
                else:
                    yield element

        class ProcessNumbersFn(DoFn):
            def process(self, element, negatives):
                yield element

        impulse = beam.Impulse()

        split = (self.pipeline
                 | 'read' >> impulse
                 | ParDo(SplitNumbersFn()).with_outputs('tag_negative',
                                                        main='positive'))
        main_out, tagged_out = split
        main_out | ParDo(ProcessNumbersFn(), AsList(tagged_out))

        self.pipeline.visit(self.visitor)

        visited_roots = [node.transform
                         for node in self.visitor.root_transforms]
        self.assertEqual(visited_roots, [impulse])
        self.assertEqual(len(self.visitor.step_names), 3)
        self.assertEqual(len(self.visitor.views), 1)
        self.assertIsInstance(self.visitor.views[0], pvalue.AsList)
Example #2
0
    def expand(self, pcoll):
        """Read entities matching ``self._query`` from Cloud Datastore.

        This composite transform has three stages:

          1. A singleton ``Create`` of the user query is split by
             ``SplitQueryFn`` into sub-queries, each keyed by a unique
             ``int``, giving ``PCollection[(int, Query)]``. When
             ``num_splits`` is less than or equal to 0 the number of splits
             is computed dynamically from the size of the query's data.
          2. A ``GroupByKey`` shards the sub-queries; the grouped
             ``(int, Iterable[Query])`` values are flattened back into a
             ``PCollection[Query]``.
          3. ``ReadFn`` runs each sub-query and emits ``PCollection[Entity]``.
        """
        user_query = pcoll.pipeline | 'UserQuery' >> Create([self._query])
        split_queries = user_query | 'SplitQuery' >> ParDo(
            ReadFromDatastore.SplitQueryFn(
                self._project, self._query,
                self._datastore_namespace, self._num_splits))

        grouped_values = split_queries | GroupByKey() | Values()
        flat_queries = grouped_values | 'Flatten' >> FlatMap(
            lambda grouped: grouped)

        read_fn = ReadFromDatastore.ReadFn(
            self._project, self._datastore_namespace)
        return flat_queries | 'Read' >> ParDo(read_fn)
    def test_side_inputs(self):
        """A Read root feeding a tagged split whose side output is an AsList.

        Expects one root transform (the Read), three steps and one
        ``AsList`` view.
        """
        class SplitNumbersFn(DoFn):
            def process(self, element):
                # Negative values go to the 'tag_negative' output.
                if element < 0:
                    yield pvalue.OutputValue('tag_negative', element)
                else:
                    yield element

        class ProcessNumbersFn(DoFn):
            def process(self, element, negatives):
                yield element

        class DummySource(iobase.BoundedSource):
            pass

        source_read = Read(DummySource())

        outputs = (self.pipeline
                   | 'read' >> source_read
                   | ParDo(SplitNumbersFn()).with_outputs('tag_negative',
                                                          main='positive'))
        main_pc, negatives_pc = outputs
        main_pc | ParDo(ProcessNumbersFn(), AsList(negatives_pc))

        self.pipeline.visit(self.visitor)

        visited_roots = sorted(
            node.transform for node in self.visitor.root_transforms)
        self.assertEqual(visited_roots, sorted([source_read]))
        self.assertEqual(len(self.visitor.step_names), 3)
        self.assertEqual(len(self.visitor.views), 1)
        self.assertIsInstance(self.visitor.views[0], pvalue.AsList)
Example #4
0
 def expand(self, pcoll):
     """Write batches with ``self._mutate_fn``, optionally throttled first.

     When ``self._throttle_rampup`` is truthy, a ``RampupThrottlingFn`` built
     from ``self._hint_num_workers`` is applied before the write step.
     """
     to_write = pcoll
     if self._throttle_rampup:
         to_write = to_write | 'Enforce throttling during ramp-up' >> ParDo(
             RampupThrottlingFn(self._hint_num_workers))
     return to_write | 'Write Batch to Datastore' >> ParDo(self._mutate_fn)
Example #5
0
    def test_side_inputs(self):
        """A Create root with a side-output split consumed as a list view.

        Expects one root transform (the Create), four steps and a single
        ``ListPCollectionView``.
        """
        class SplitNumbersFn(NewDoFn):
            def process(self, element):
                # Negative values are emitted on the 'tag_negative' side output.
                if element < 0:
                    yield pvalue.SideOutputValue('tag_negative', element)
                else:
                    yield element

        class ProcessNumbersFn(NewDoFn):
            def process(self, element, negatives):
                yield element

        create_root = Create('create', [[-1, 2, 3]])

        split = (self.pipeline
                 | create_root
                 | ParDo(SplitNumbersFn()).with_outputs('tag_negative',
                                                        main='positive'))
        main_pc, side_pc = split
        main_pc | ParDo(ProcessNumbersFn(), AsList(side_pc))

        self.pipeline.visit(self.visitor)

        visited_roots = sorted(
            node.transform for node in self.visitor.root_transforms)
        self.assertEqual(visited_roots, sorted([create_root]))
        self.assertEqual(len(self.visitor.step_names), 4)
        self.assertEqual(len(self.visitor.views), 1)
        self.assertIsInstance(
            self.visitor.views[0], pvalue.ListPCollectionView)
Example #6
0
        def _process_numbers(pcoll, negatives):
            """Apply ProcessNumbersFn twice in a chain and flatten both outputs.

            ``negatives`` is passed as the extra (side input) argument to
            both ParDo steps.
            """
            step_one = pcoll | 'process numbers step 1' >> ParDo(
                ProcessNumbersFn(), negatives)
            step_two = step_one | 'process numbers step 2' >> ParDo(
                ProcessNumbersFn(), negatives)
            return (step_one, step_two) | 'flatten results' >> beam.Flatten()
Example #7
0
  def test_pipeline_sdk_not_overridden(self):
    """A user-supplied SDK image must not be rewritten to the Dataflow repo.

    The pipeline proto is built with a ``DockerEnvironment`` using
    ``dummy_prefix/dummy_name:dummy_tag`` (also passed as
    ``--sdk_container_image``). After the SDK environment overrides are
    applied, no environment may point at
    ``names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY``.
    """
    pipeline_options = PipelineOptions([
        '--experiments=beam_fn_api',
        '--experiments=use_unified_worker',
        '--temp_location',
        'gs://any-location/temp',
        '--sdk_container_image=dummy_prefix/dummy_name:dummy_tag'
    ])

    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    dummy_env = DockerEnvironment(
        container_image='dummy_prefix/dummy_name:dummy_tag')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=dummy_env)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, {}, pipeline_options)

    from apache_beam.utils import proto_utils
    images = [
        proto_utils.parse_Bytes(
            env.payload, beam_runner_api_pb2.DockerPayload).container_image
        for env in proto_pipeline.components.environments.values()
    ]
    # Guard against a vacuous pass: with no environments the scan below
    # would trivially find no override.
    self.assertTrue(images)

    overridden = [
        image for image in images
        if image.startswith(names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY)
    ]
    self.assertEqual([], overridden)
 def expand(self, pcoll):
     """Write each element of ``pcoll`` to MongoDB via ``MongoDBWriteFn``.

     The target and connection come from this transform's
     ``database_name``, ``collection_name``, ``connection_string`` and
     ``test_client`` attributes.
     """
     writer = MongoDBWriteFn(database_name=self.database_name,
                             collection_name=self.collection_name,
                             connection_string=self.connection_string,
                             test_client=self.test_client)
     labelled_write = 'WriteToMongoDB' >> ParDo(writer)
     return pcoll | labelled_write
Example #9
0
    def expand(self, pcoll):
        """Write ``pcoll`` to BigQuery.

        Streaming pipelines write through a ``BigQueryWriteFn`` ParDo; batch
        pipelines delegate to ``bigquery_file_loads.BigQueryBatchFileLoads``.

        When ``self.table_reference`` is not callable and its ``projectId`` is
        ``None``, the project is filled in (in place) from the pipeline's
        ``GoogleCloudOptions.project``.
        """
        options = pcoll.pipeline.options

        # TODO(pabloem): Use a different method to determine if streaming or batch.
        standard_options = options.view_as(StandardOptions)

        table_ref = self.table_reference
        if not callable(table_ref) and table_ref.projectId is None:
            table_ref.projectId = options.view_as(GoogleCloudOptions).project

        if not standard_options.streaming:
            from apache_beam.io.gcp import bigquery_file_loads
            return pcoll | bigquery_file_loads.BigQueryBatchFileLoads(
                destination=table_ref,
                schema=self.get_dict_table_schema(self.schema),
                create_disposition=self.create_disposition,
                write_disposition=self.write_disposition,
                max_file_size=self.max_file_size,
                max_files_per_bundle=self.max_files_per_bundle,
                gs_location=self.gs_location,
                test_client=self.test_client)

        # TODO: Support load jobs for streaming pipelines.
        # NOTE(review): this path reads tableId/datasetId/projectId directly,
        # so a callable table_reference would fail here — confirm whether
        # callables are meant to be supported for streaming.
        write_fn = BigQueryWriteFn(
            table_id=table_ref.tableId,
            dataset_id=table_ref.datasetId,
            project_id=table_ref.projectId,
            batch_size=self.batch_size,
            schema=self.get_dict_table_schema(self.schema),
            create_disposition=self.create_disposition,
            write_disposition=self.write_disposition,
            kms_key=self.kms_key,
            test_client=self.test_client)
        return pcoll | 'WriteToBigQuery' >> ParDo(write_fn)
 def expand(self, pcoll):
     """Write every element of ``pcoll`` to Elasticsearch.

     Builds an ``ElasticSearchWriteFn`` from this transform's ``es_client``,
     ``index_name``, ``doc_type``, ``test_client`` and ``mapping`` attributes
     and applies it under the label ``'WriteToElasticSearch'``.
     """
     # The DoFn is an Elasticsearch writer, so name the local accordingly.
     es_write_fn = ElasticSearchWriteFn(es_client=self.es_client,
                                        index_name=self.index_name,
                                        doc_type=self.doc_type,
                                        test_client=self.test_client,
                                        mapping=self.mapping)
     return pcoll | 'WriteToElasticSearch' >> ParDo(es_write_fn)
Example #11
0
  def test_timestamp_param(self):
    """TimestampParam yields MIN_TIMESTAMP for elements produced by Create."""
    class TestDoFn(DoFn):
      def process(self, element, timestamp=DoFn.TimestampParam):
        yield timestamp

    expected = [MIN_TIMESTAMP, MIN_TIMESTAMP]
    with TestPipeline() as p:
      timestamps = (
          p
          | 'Create' >> Create([1, 2])
          | 'Do' >> ParDo(TestDoFn()))
      assert_that(timestamps, equal_to(expected))
Example #12
0
  def test_element(self):
    """process() receives each element directly and may transform it."""
    class TestDoFn(DoFn):
      def process(self, element):
        yield element + 10

    expected = [11, 12]
    with TestPipeline() as p:
      plus_ten = (
          p
          | 'Create' >> Create([1, 2])
          | 'Do' >> ParDo(TestDoFn()))
      assert_that(plus_ten, equal_to(expected))
Example #13
0
def run(argv=None):
    """Build and launch the streaming Pub/Sub -> AI Platform -> BigQuery pipeline.

    Positional arguments are ``input_topic``
    (``projects/<project>/topics/<topic>``) and ``output_table``
    (``<dataset>.<table>``). Unrecognised arguments are forwarded to
    ``PipelineOptions``.

    Args:
      argv: Command-line arguments; when ``None`` argparse reads ``sys.argv``.

    Raises:
      ValueError: if ``input_topic`` or ``output_table`` is malformed. This is
        checked before any pipeline object is created.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input_topic',
                        type=str,
                        help="Input Pub/Sub topic name.")
    parser.add_argument(
        'output_table',
        type=str,
        # Must agree with the two-part check below.
        help="Output BigQuery table name. Example: dataset.table")
    parser.add_argument('--model_project',
                        type=str,
                        help="Google Project ID with model.")
    parser.add_argument('--model_name',
                        type=str,
                        help="Name of the Google AI Platform model name.")
    parser.add_argument('--model_region',
                        type=str,
                        help="AI Platform region name.")
    parser.add_argument('--model_version',
                        type=str,
                        help="AI Platform model version.")

    known_args, pipeline_args = parser.parse_known_args(argv)

    # Expect exactly projects/<project>/topics/<topic> with non-empty parts.
    topic_parts = known_args.input_topic.split('/')
    if (len(topic_parts) != 4 or topic_parts[0] != 'projects'
            or topic_parts[2] != 'topics' or not all(topic_parts)):
        raise ValueError("Topic name has inappropriate format.")

    # Expect exactly <dataset>.<table> with both parts non-empty.
    table_parts = known_args.output_table.split('.')
    if len(table_parts) != 2 or not all(table_parts):
        raise ValueError("Table name has inappropriate format.")

    inf_args = [
        known_args.model_project, known_args.model_name,
        known_args.model_region, known_args.model_version
    ]
    options = PipelineOptions(pipeline_args)
    options.view_as(SetupOptions).save_main_session = True
    options.view_as(StandardOptions).streaming = True

    p = Pipeline(options=options)
    _ = (p | 'read from pub/sub' >> ReadFromPubSub(
        known_args.input_topic).with_output_types(bytes)
         | 'windowing' >> WindowInto(window.FixedWindows(10, 0))
         | 'convert to dict' >> Map(json.loads)
         | 'pre processing' >> PreProcessing()
         | 'make inference' >> ParDo(MakeRemoteInferenceDoFn(*inf_args))
         | 'format message' >> Map(formatter)
         | 'write to BQ' >> WriteToBigQuery(
             table=known_args.output_table,
             schema=build_bq_schema(),
             create_disposition=BigQueryDisposition.CREATE_IF_NEEDED,
             write_disposition=BigQueryDisposition.WRITE_APPEND))

    # p.run() is used instead of a `with Pipeline() as p` block so that a
    # deployment launch can return without waiting for the job to finish.
    if os.environ.get('DEPLOY'):
        p.run()
    else:
        p.run().wait_until_finish()
Example #14
0
  def test_context_param(self):
    """ContextParam exposes the current element through ``.element``."""
    class TestDoFn(DoFn):
      def process(self, element, context=DoFn.ContextParam):
        yield context.element + 10

    pipeline = TestPipeline()
    shifted = (
        pipeline
        | 'Create' >> Create([1, 2])
        | 'Do' >> ParDo(TestDoFn()))
    assert_that(shifted, equal_to([11, 12]))
    pipeline.run()
Example #15
0
 def expand(self, pcoll):
     """Apply ``_InspectFn`` to every element of ``pcoll``.

     If ``self.project`` is unset, it is filled in from the pipeline's
     ``GoogleCloudOptions.project``.

     Raises:
       ValueError: if no GCP project can be determined.
     """
     if self.project is None:
         options = pcoll.pipeline.options
         self.project = options.view_as(GoogleCloudOptions).project
     if self.project is None:
         raise ValueError(
             'GCP project name needs to be specified in "project" pipeline option'
         )
     inspect_fn = _InspectFn(self.config, self.timeout, self.project)
     return pcoll | ParDo(inspect_fn)
Example #16
0
    def expand(self, pcoll):
        """Read entities for ``self._query`` from Cloud Datastore.

        Stages:
          1. ``Create`` a singleton of the user query and split it with
             ``_SplitQueryFn`` into ``self._num_splits`` queries if possible;
             a value of 0 means the number of splits is computed dynamically
             from the size of the query's data.
          2. ``Reshuffle`` the resulting queries across workers.
          3. Run each query with ``_QueryFn`` to produce
             ``PCollection[Entity]``.
        """
        user_query = pcoll.pipeline | 'UserQuery' >> Create([self._query])
        split_queries = user_query | 'SplitQuery' >> ParDo(
            ReadFromDatastore._SplitQueryFn(self._num_splits))
        shuffled = split_queries | Reshuffle()
        return shuffled | 'Read' >> ParDo(ReadFromDatastore._QueryFn())
Example #17
0
 def expand(self, pvalue):
     """Annotate images in batches with ``_ImageAnnotateFn``.

     Elements are first expanded by ``self._create_image_annotation_pairs``,
     then grouped by ``util.BatchElements`` between ``self.min_batch_size``
     and ``self.max_batch_size`` elements, then annotated.

     Note: the parameter name ``pvalue`` shadows any module-level ``pvalue``
     import inside this method; it is kept for interface compatibility.
     """
     pairs = pvalue | FlatMap(self._create_image_annotation_pairs)
     batches = pairs | util.BatchElements(min_batch_size=self.min_batch_size,
                                          max_batch_size=self.max_batch_size)
     annotate_fn = _ImageAnnotateFn(features=self.features,
                                    retry=self.retry,
                                    timeout=self.timeout,
                                    client_options=self.client_options,
                                    metadata=self.metadata)
     return batches | ParDo(annotate_fn)
Example #18
0
 def expand(self, pcoll):
     """Apply ``_CreateCatalogItemFn`` to every element of ``pcoll``.

     If ``self.project`` is unset, it is filled in from the pipeline's
     ``GoogleCloudOptions.project``.

     Raises:
       ValueError: if no GCP project can be determined.
     """
     if self.project is None:
         self.project = pcoll.pipeline.options.view_as(
             GoogleCloudOptions).project
     if self.project is None:
         # Keep the message on one line: a triple-quoted literal would embed
         # a newline and source indentation into the error text.
         raise ValueError(
             'GCP project name needs to be specified in "project" pipeline option'
         )
     return pcoll | ParDo(
         _CreateCatalogItemFn(self.project, self.retry, self.timeout,
                              self.metadata, self.catalog_name))
Example #19
0
    def test_window_param(self):
        """WindowParam exposes each element's sliding-window bounds.

        Elements 1 and 7, timestamped with their own value, fall into two
        overlapping 10-wide windows each (period 5). Applying the DoFn a
        second time pairs each output with the same window again.
        """
        class TestDoFn(DoFn):
            def process(self, element, window=DoFn.WindowParam):
                bounds = (float(window.start), float(window.end))
                yield (element, bounds)

        once = [(1, (-5, 5)), (1, (0, 10)), (7, (0, 10)), (7, (5, 15))]
        # The second pass tags each first-pass output with its own window.
        twice = [(pair, pair[1]) for pair in once]

        pipeline = TestPipeline()
        windowed = (pipeline
                    | Create([1, 7])
                    | Map(lambda x: TimestampedValue(x, x))
                    | WindowInto(windowfn=SlidingWindows(10, 5))
                    | ParDo(TestDoFn()))
        assert_that(windowed, equal_to(once))
        rewindowed = windowed | 'Again' >> ParDo(TestDoFn())
        assert_that(rewindowed, equal_to(twice), label='doubled windows')
        pipeline.run()
Example #20
0
 def expand(self, pcoll):
     """Apply ``_PredictUserEventFn`` to every element of ``pcoll``.

     If ``self.project`` is unset, it is filled in from the pipeline's
     ``GoogleCloudOptions.project``.

     Raises:
       ValueError: if no GCP project can be determined.
     """
     if self.project is None:
         options = pcoll.pipeline.options
         self.project = options.view_as(GoogleCloudOptions).project
     if self.project is None:
         raise ValueError(
             'GCP project name needs to be specified in "project" pipeline option'
         )
     predict_fn = _PredictUserEventFn(self.project, self.retry, self.timeout,
                                      self.metadata, self.catalog_name,
                                      self.event_store, self.placement_id)
     return pcoll | ParDo(predict_fn)
Example #21
0
 def expand(self, pcoll):
     """Batch ``pcoll`` by sharded key, then import each batch.

     Batches of at most ``self.max_batch_size`` elements are formed with
     ``GroupIntoBatches.WithShardedKey`` and passed to
     ``_ImportCatalogItemsFn``. If ``self.project`` is unset, it is filled
     in from the pipeline's ``GoogleCloudOptions.project``.

     Raises:
       ValueError: if no GCP project can be determined.
     """
     if self.project is None:
         options = pcoll.pipeline.options
         self.project = options.view_as(GoogleCloudOptions).project
     if self.project is None:
         raise ValueError(
             'GCP project name needs to be specified in "project" pipeline option'
         )
     batched = pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size)
     import_fn = _ImportCatalogItemsFn(self.project, self.retry, self.timeout,
                                       self.metadata, self.catalog_name)
     return batched | ParDo(import_fn)
Example #22
0
  def test_side_input_tagged(self):
    """A positional extra arg and a keyword AsSingleton side input combine."""
    class TestDoFn(DoFn):
      def process(self, element, prefix, suffix=DoFn.SideInputParam):
        return ['%s-%s-%s' % (prefix, element, suffix)]

    words_list = ['aa', 'bb', 'cc']
    prefix = 'zyx'
    expected = ['zyx-%s-xyz' % word for word in words_list]
    with TestPipeline() as pipeline:
      words = pipeline | 'SomeWords' >> Create(words_list)
      suffix = pipeline | 'SomeString' >> Create(['xyz'])  # side input
      decorated = words | 'DecorateWordsDoFnNoTag' >> ParDo(
          TestDoFn(), prefix, suffix=AsSingleton(suffix))
      assert_that(decorated, equal_to(expected))
Example #23
0
    def test_sdk_harness_container_images_get_set(self):
        """Worker pool harness images are populated and overridden.

        A second Docker environment is injected into the proto because
        Dataflow only sets 'sdkHarnessContainerImages' when there are at
        least two environments. The first reported image is expected to be
        rewritten to a Dataflow-specific URL.
        """
        options = PipelineOptions([
            '--experiments=beam_fn_api', '--experiments=use_unified_worker',
            '--temp_location', 'gs://any-location/temp'
        ])

        pipeline = Pipeline(options=options)
        pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

        default_env = DockerEnvironment(container_image='test_default_image')
        proto_pipeline, _ = pipeline.to_runner_api(
            return_context=True, default_environment=default_env)

        # Manually add a second environment (and a transform that uses it).
        components = proto_pipeline.components
        components.environments['dummy_env_id'].CopyFrom(
            beam_runner_api_pb2.Environment(
                urn=common_urns.environments.DOCKER.urn,
                payload=beam_runner_api_pb2.DockerPayload(
                    container_image='dummy_image').SerializeToString()))
        components.transforms['dummy_transform_id'].CopyFrom(
            beam_runner_api_pb2.PTransform(environment_id='dummy_env_id'))

        env = apiclient.Environment(
            [],  # packages
            options,
            '2.0.0',  # any environment version
            FAKE_PIPELINE_URL,
            proto_pipeline,
            _sdk_image_overrides={
                '.*dummy.*': 'dummy_image',
                '.*test.*': 'test_default_image'
            })
        harness_images = env.proto.workerPools[0].sdkHarnessContainerImages

        # A third environment is added because Dataflow's actual default
        # container image differs from the 'test_default_image' given above.
        self.assertEqual(3, len(harness_images))

        # Container image should be overridden by a Dataflow specific URL.
        first_image = harness_images[0].containerImage
        self.assertTrue(
            first_image.startswith('gcr.io/cloud-dataflow/v1beta3/python'))
Example #24
0
    def test_java_sdk_harness_dedup(self):
        """Two identical Java SDK environments are de-duplicated.

        Two Docker environments with the same ``apache/beam_java:dummy_tag``
        image are each attached to their own dummy transform. After the SDK
        environment overrides are applied, exactly one of the two environment
        ids may still be referenced by the pipeline's transforms.
        """
        pipeline_options = PipelineOptions([
            '--experiments=beam_fn_api', '--experiments=use_unified_worker',
            '--temp_location', 'gs://any-location/temp'
        ])

        pipeline = Pipeline(options=pipeline_options)
        pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

        proto_pipeline, _ = pipeline.to_runner_api(return_context=True)

        dummy_env_ids = ('dummy_env_id_1', 'dummy_env_id_2')
        # Inserts happen in the same order as before: env 1, transform 1,
        # env 2, transform 2.
        for index, env_id in enumerate(dummy_env_ids, start=1):
            java_env = beam_runner_api_pb2.Environment(
                urn=common_urns.environments.DOCKER.urn,
                payload=beam_runner_api_pb2.DockerPayload(
                    container_image='apache/beam_java:dummy_tag'
                ).SerializeToString())
            proto_pipeline.components.environments[env_id].CopyFrom(java_env)

            java_transform = beam_runner_api_pb2.PTransform(
                environment_id=env_id)
            proto_pipeline.components.transforms[
                'dummy_transform_id_%d' % index].CopyFrom(java_transform)

        # Accessing non-public method for testing.
        apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
            proto_pipeline, dict(), pipeline_options)

        # Only one of 'dummy_env_id_1' or 'dummy_env_id_2' should be in the set
        # of environment IDs used by the proto after Java environment
        # de-duping.
        env_ids_from_transforms = {
            transform.environment_id
            for transform in proto_pipeline.components.transforms.values()
        }
        used_dummy_ids = [
            env_id for env_id in dummy_env_ids
            if env_id in env_ids_from_transforms
        ]
        self.assertEqual(1, len(used_dummy_ids), used_dummy_ids)
Example #25
0
    def expand(self, pcoll):
        """Write ``pcoll`` to BigQuery by streaming inserts or file loads.

        Streaming pipelines, and any pipeline whose ``self.method`` is
        ``WriteToBigQuery.Method.STREAMING_INSERTS``, pair each row with
        ``table_fn(row)`` and write through ``BigQueryWriteFn``. The result is
        a dict holding the ``BigQueryWriteFn.FAILED_ROWS`` output. All other
        (batch) pipelines delegate to ``BigQueryBatchFileLoads``.

        When ``self.table_reference`` is not callable and has no
        ``projectId``, it is filled in (in place) from the pipeline's
        ``GoogleCloudOptions.project``.
        """
        p = pcoll.pipeline

        # TODO(pabloem): Use a different method to determine if streaming or batch.
        standard_options = p.options.view_as(StandardOptions)

        if (not callable(self.table_reference)
                and self.table_reference.projectId is None):
            self.table_reference.projectId = pcoll.pipeline.options.view_as(
                GoogleCloudOptions).project

        if (standard_options.streaming
                or self.method == WriteToBigQuery.Method.STREAMING_INSERTS):
            # TODO: Support load jobs for streaming pipelines.
            bigquery_write_fn = BigQueryWriteFn(
                batch_size=self.batch_size,
                create_disposition=self.create_disposition,
                write_disposition=self.write_disposition,
                kms_key=self.kms_key,
                retry_strategy=self.insert_retry_strategy,
                test_client=self.test_client)

            # TODO: Use utility functions from BQTools
            table_fn = self._get_table_fn()

            outputs = (
                pcoll
                | 'AppendDestination' >> beam.Map(lambda x: (table_fn(x), x))
                | 'StreamInsertRows' >> ParDo(bigquery_write_fn).with_outputs(
                    BigQueryWriteFn.FAILED_ROWS, main='main'))

            return {
                BigQueryWriteFn.FAILED_ROWS:
                outputs[BigQueryWriteFn.FAILED_ROWS]
            }

        # Every streaming pipeline takes the branch above, so only batch
        # pipelines reach this point, and a streaming check here could never
        # fire. NOTE(review): if file loads were meant to be rejected for
        # streaming pipelines, the condition above would need to change —
        # confirm the intended behavior.
        from apache_beam.io.gcp import bigquery_file_loads
        return pcoll | bigquery_file_loads.BigQueryBatchFileLoads(
            destination=self.table_reference,
            schema=self.get_dict_table_schema(self.schema),
            create_disposition=self.create_disposition,
            write_disposition=self.write_disposition,
            max_file_size=self.max_file_size,
            max_files_per_bundle=self.max_files_per_bundle,
            gs_location=self.gs_location,
            test_client=self.test_client)
Example #26
0
    def expand(self, pcoll):
        """Write ``pcoll`` to BigQuery using the method from ``_compute_method``.

        For ``WriteToBigQuery.Method.STREAMING_INSERTS``, rows get their
        destination appended by ``bigquery_tools.AppendDestinationsFn`` and are
        written by ``BigQueryWriteFn``. The result is a dict holding the
        ``BigQueryWriteFn.FAILED_ROWS`` output. Any other method uses
        ``BigQueryBatchFileLoads``.

        When ``self.table_reference`` is a ``bigquery.TableReference`` without
        a ``projectId``, it is filled in (in place) from the pipeline's
        ``GoogleCloudOptions.project``.

        Raises:
          NotImplementedError: for file loads in a streaming pipeline.
        """
        p = pcoll.pipeline
        options = p.options

        table_ref = self.table_reference
        if (isinstance(table_ref, bigquery.TableReference)
                and table_ref.projectId is None):
            table_ref.projectId = options.view_as(GoogleCloudOptions).project

        method_to_use = self._compute_method(p, options)

        if method_to_use != WriteToBigQuery.Method.STREAMING_INSERTS:
            if options.view_as(StandardOptions).streaming:
                raise NotImplementedError(
                    'File Loads to BigQuery are only supported on Batch pipelines.'
                )

            from apache_beam.io.gcp import bigquery_file_loads
            return (pcoll
                    | bigquery_file_loads.BigQueryBatchFileLoads(
                        destination=table_ref,
                        schema=self.schema,
                        create_disposition=self.create_disposition,
                        write_disposition=self.write_disposition,
                        max_file_size=self.max_file_size,
                        max_files_per_bundle=self.max_files_per_bundle,
                        custom_gcs_temp_location=self.custom_gcs_temp_location,
                        test_client=self.test_client,
                        validate=self._validate))

        # TODO: Support load jobs for streaming pipelines.
        write_fn = BigQueryWriteFn(
            schema=self.schema,
            batch_size=self.batch_size,
            create_disposition=self.create_disposition,
            write_disposition=self.write_disposition,
            kms_key=self.kms_key,
            retry_strategy=self.insert_retry_strategy,
            test_client=self.test_client)

        failed_tag = BigQueryWriteFn.FAILED_ROWS
        with_destinations = pcoll | 'AppendDestination' >> beam.ParDo(
            bigquery_tools.AppendDestinationsFn(table_ref))
        outputs = with_destinations | 'StreamInsertRows' >> ParDo(
            write_fn).with_outputs(failed_tag, main='main')
        return {failed_tag: outputs[failed_tag]}
Example #27
0
 def expand(self, pcoll):
     """Write ``pcoll`` to BigQuery with a single ``BigQueryWriteFn`` ParDo.

     When ``self.table_reference.projectId`` is ``None``, it is filled in
     (in place) from the pipeline's ``GoogleCloudOptions.project``.
     """
     table_ref = self.table_reference
     if table_ref.projectId is None:
         options = pcoll.pipeline.options
         table_ref.projectId = options.view_as(GoogleCloudOptions).project
     write_fn = BigQueryWriteFn(
         table_id=table_ref.tableId,
         dataset_id=table_ref.datasetId,
         project_id=table_ref.projectId,
         batch_size=self.batch_size,
         schema=self.get_dict_table_schema(self.schema),
         create_disposition=self.create_disposition,
         write_disposition=self.write_disposition,
         test_client=self.test_client)
     labelled_write = 'WriteToBigQuery' >> ParDo(write_fn)
     return pcoll | labelled_write
Example #28
0
  def test_default_environment_get_set(self):
    """The default environment's image appears among the harness images.

    A second Docker environment (plus a transform using it) is injected, and
    the worker pool is expected to report two harness images, including
    'test_default_image'.
    """
    options = PipelineOptions([
        '--experiments=beam_fn_api',
        '--experiments=use_unified_worker',
        '--temp_location',
        'gs://any-location/temp'
    ])

    pipeline = Pipeline(options=options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    default_env = DockerEnvironment(container_image='test_default_image')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=default_env)

    components = proto_pipeline.components
    components.environments['dummy_env_id'].CopyFrom(
        beam_runner_api_pb2.Environment(
            urn=common_urns.environments.DOCKER.urn,
            payload=beam_runner_api_pb2.DockerPayload(
                container_image='dummy_image').SerializeToString()))
    components.transforms['dummy_transform_id'].CopyFrom(
        beam_runner_api_pb2.PTransform(environment_id='dummy_env_id'))

    env = apiclient.Environment(
        [],  # packages
        options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL,
        proto_pipeline,
        _sdk_image_overrides={
            '.*dummy.*': 'dummy_image', '.*test.*': 'test_default_image'
        })
    harness_images = env.proto.workerPools[0].sdkHarnessContainerImages

    self.assertEqual(2, len(harness_images))
    self.assertIn('test_default_image',
                  [info.containerImage for info in harness_images])
Example #29
0
  def test_non_apache_container_not_overridden(self):
    """A non-Apache container image must not be rewritten to the Dataflow repo.

    An environment with image ``other_org/dummy_name:dummy_tag`` is injected.
    After the SDK environment overrides are applied, that environment must
    still be present and no environment may point at
    ``names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY``.
    """
    pipeline_options = PipelineOptions([
        '--experiments=beam_fn_api',
        '--experiments=use_unified_worker',
        '--temp_location',
        'gs://any-location/temp'
    ])

    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    proto_pipeline, _ = pipeline.to_runner_api(return_context=True)

    dummy_env = beam_runner_api_pb2.Environment(
        urn=common_urns.environments.DOCKER.urn,
        payload=(
            beam_runner_api_pb2.DockerPayload(
                container_image='other_org/dummy_name:dummy_tag')
        ).SerializeToString())
    proto_pipeline.components.environments['dummy_env_id'].CopyFrom(dummy_env)

    dummy_transform = beam_runner_api_pb2.PTransform(
        environment_id='dummy_env_id')
    proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom(
        dummy_transform)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, dict(), pipeline_options)

    # The injected environment must survive; otherwise the scan below would
    # pass without ever inspecting the non-Apache image.
    self.assertIn('dummy_env_id', proto_pipeline.components.environments)

    from apache_beam.utils import proto_utils
    overridden = [
        image for image in (
            proto_utils.parse_Bytes(
                env.payload, beam_runner_api_pb2.DockerPayload).container_image
            for env in proto_pipeline.components.environments.values())
        if image.startswith(names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY)
    ]
    self.assertEqual([], overridden)
Example #30
0
    def test_side_inputs(self):
        """One AsList side input shared by two chained ParDos and a Flatten.

        Expects one Impulse root, five steps and exactly one view, which is
        an ``AsList``.
        """
        class SplitNumbersFn(DoFn):
            def process(self, element):
                # Negative values are routed to the tagged output.
                if element < 0:
                    yield pvalue.TaggedOutput('tag_negative', element)
                else:
                    yield element

        class ProcessNumbersFn(DoFn):
            def process(self, element, negatives):
                yield element

        def _process_numbers(pcoll, negatives):
            # Two chained ParDos share the same side input; their outputs
            # are then flattened together.
            step_one = pcoll | 'process numbers step 1' >> ParDo(
                ProcessNumbersFn(), negatives)
            step_two = step_one | 'process numbers step 2' >> ParDo(
                ProcessNumbersFn(), negatives)
            return (step_one, step_two) | 'flatten results' >> beam.Flatten()

        impulse = beam.Impulse()

        split = (self.pipeline
                 | 'read' >> impulse
                 | ParDo(SplitNumbersFn()).with_outputs('tag_negative',
                                                        main='positive'))
        main_out, tagged_out = split
        _process_numbers(main_out, AsList(tagged_out))

        self.pipeline.visit(self.visitor)

        visited_roots = [node.transform
                         for node in self.visitor.root_transforms]
        self.assertEqual(visited_roots, [impulse])
        self.assertEqual(len(self.visitor.step_names), 5)
        self.assertEqual(len(self.visitor.views), 1)
        self.assertIsInstance(self.visitor.views[0], pvalue.AsList)