def test_side_inputs(self):
  class SplitNumbersFn(DoFn):
    def process(self, element):
      if element < 0:
        yield pvalue.TaggedOutput('tag_negative', element)
      else:
        yield element

  class ProcessNumbersFn(DoFn):
    def process(self, element, negatives):
      yield element

  root_read = beam.Impulse()

  result = (
      self.pipeline
      | 'read' >> root_read
      | ParDo(SplitNumbersFn()).with_outputs('tag_negative', main='positive'))
  positive, negative = result
  positive | ParDo(ProcessNumbersFn(), AsList(negative))

  self.pipeline.visit(self.visitor)

  root_transforms = [t.transform for t in self.visitor.root_transforms]
  self.assertEqual(root_transforms, [root_read])
  self.assertEqual(len(self.visitor.step_names), 3)
  self.assertEqual(len(self.visitor.views), 1)
  self.assertTrue(isinstance(self.visitor.views[0], pvalue.AsList))
def expand(self, pcoll):
  # This is a composite transform that involves the following steps:
  #
  # 1. Create a singleton of the user-provided `query` and apply a ``ParDo``
  #    that splits the query into `num_splits` queries and assigns each split
  #    query a unique `int` as the key. The resulting output is of the type
  #    ``PCollection[(int, Query)]``.
  #
  #    If the value of `num_splits` is less than or equal to 0, then the
  #    number of splits is computed dynamically based on the size of the
  #    data for the `query`.
  #
  # 2. The resulting ``PCollection`` is sharded using a ``GroupByKey``
  #    operation. The queries are extracted from the (int, Iterable[Query])
  #    pairs and flattened to output a ``PCollection[Query]``.
  #
  # 3. In the third step, a ``ParDo`` reads entities for each query and
  #    outputs a ``PCollection[Entity]``.

  queries = (
      pcoll.pipeline
      | 'UserQuery' >> Create([self._query])
      | 'SplitQuery' >> ParDo(
          ReadFromDatastore.SplitQueryFn(
              self._project,
              self._query,
              self._datastore_namespace,
              self._num_splits)))

  sharded_queries = (
      queries
      | GroupByKey()
      | Values()
      | 'Flatten' >> FlatMap(lambda x: x))

  entities = sharded_queries | 'Read' >> ParDo(
      ReadFromDatastore.ReadFn(self._project, self._datastore_namespace))
  return entities
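# Hedged sketch (not part of the original source): the "shard via GroupByKey"
# pattern described in step 2 of the comments above, shown on plain ints
# instead of Datastore queries. Assigning an arbitrary key, grouping, and
# re-flattening forces a fusion break so the downstream ParDo can be spread
# across workers. The shard-key range of 10 is an arbitrary choice here.
import random

import apache_beam as beam

with beam.Pipeline() as p:
  resharded = (
      p
      | beam.Create(list(range(100)))
      | 'AssignShardKey' >> beam.Map(lambda x: (random.randint(0, 9), x))
      | beam.GroupByKey()
      | beam.Values()
      | 'Flatten' >> beam.FlatMap(lambda xs: xs))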
def test_side_inputs(self):
  class SplitNumbersFn(DoFn):
    def process(self, element):
      if element < 0:
        yield pvalue.OutputValue('tag_negative', element)
      else:
        yield element

  class ProcessNumbersFn(DoFn):
    def process(self, element, negatives):
      yield element

  class DummySource(iobase.BoundedSource):
    pass

  root_read = Read(DummySource())

  result = (
      self.pipeline
      | 'read' >> root_read
      | ParDo(SplitNumbersFn()).with_outputs('tag_negative', main='positive'))
  positive, negative = result
  positive | ParDo(ProcessNumbersFn(), AsList(negative))

  self.pipeline.visit(self.visitor)

  root_transforms = sorted(
      [t.transform for t in self.visitor.root_transforms])
  self.assertEqual(root_transforms, sorted([root_read]))
  self.assertEqual(len(self.visitor.step_names), 3)
  self.assertEqual(len(self.visitor.views), 1)
  self.assertTrue(isinstance(self.visitor.views[0], pvalue.AsList))
def expand(self, pcoll):
  if self._throttle_rampup:
    throttling_fn = RampupThrottlingFn(self._hint_num_workers)
    pcoll = (
        pcoll | 'Enforce throttling during ramp-up' >> ParDo(throttling_fn))
  return pcoll | 'Write Batch to Datastore' >> ParDo(self._mutate_fn)
def test_side_inputs(self):
  class SplitNumbersFn(NewDoFn):
    def process(self, element):
      if element < 0:
        yield pvalue.SideOutputValue('tag_negative', element)
      else:
        yield element

  class ProcessNumbersFn(NewDoFn):
    def process(self, element, negatives):
      yield element

  root_create = Create('create', [[-1, 2, 3]])

  result = (
      self.pipeline
      | root_create
      | ParDo(SplitNumbersFn()).with_outputs('tag_negative', main='positive'))
  positive, negative = result
  positive | ParDo(ProcessNumbersFn(), AsList(negative))

  self.pipeline.visit(self.visitor)

  root_transforms = sorted(
      [t.transform for t in self.visitor.root_transforms])
  self.assertEqual(root_transforms, sorted([root_create]))
  self.assertEqual(len(self.visitor.step_names), 4)
  self.assertEqual(len(self.visitor.views), 1)
  self.assertTrue(
      isinstance(self.visitor.views[0], pvalue.ListPCollectionView))
def _process_numbers(pcoll, negatives):
  first_output = (
      pcoll
      | 'process numbers step 1' >> ParDo(ProcessNumbersFn(), negatives))

  second_output = (
      first_output
      | 'process numbers step 2' >> ParDo(ProcessNumbersFn(), negatives))

  output_pc = (
      (first_output, second_output)
      | 'flatten results' >> beam.Flatten())
  return output_pc
def test_pipeline_sdk_not_overridden(self):
  pipeline_options = PipelineOptions([
      '--experiments=beam_fn_api',
      '--experiments=use_unified_worker',
      '--temp_location',
      'gs://any-location/temp',
      '--sdk_container_image=dummy_prefix/dummy_name:dummy_tag'
  ])

  pipeline = Pipeline(options=pipeline_options)
  pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

  dummy_env = DockerEnvironment(
      container_image='dummy_prefix/dummy_name:dummy_tag')
  proto_pipeline, _ = pipeline.to_runner_api(
      return_context=True, default_environment=dummy_env)

  # Accessing non-public method for testing.
  apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
      proto_pipeline, {}, pipeline_options)

  self.assertIsNotNone(2, len(proto_pipeline.components.environments))

  from apache_beam.utils import proto_utils
  found_override = False
  for env in proto_pipeline.components.environments.values():
    docker_payload = proto_utils.parse_Bytes(
        env.payload, beam_runner_api_pb2.DockerPayload)
    if docker_payload.container_image.startswith(
        names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
      found_override = True

  self.assertFalse(found_override)
def expand(self, pcoll):
  mongodb_write_fn = MongoDBWriteFn(
      database_name=self.database_name,
      collection_name=self.collection_name,
      connection_string=self.connection_string,
      test_client=self.test_client)
  return pcoll | 'WriteToMongoDB' >> ParDo(mongodb_write_fn)
def expand(self, pcoll):
  p = pcoll.pipeline
  # TODO(pabloem): Use a different method to determine if streaming or batch.
  standard_options = p.options.view_as(StandardOptions)

  if (not callable(self.table_reference) and
      self.table_reference.projectId is None):
    self.table_reference.projectId = pcoll.pipeline.options.view_as(
        GoogleCloudOptions).project

  if standard_options.streaming:
    # TODO: Support load jobs for streaming pipelines.
    bigquery_write_fn = BigQueryWriteFn(
        table_id=self.table_reference.tableId,
        dataset_id=self.table_reference.datasetId,
        project_id=self.table_reference.projectId,
        batch_size=self.batch_size,
        schema=self.get_dict_table_schema(self.schema),
        create_disposition=self.create_disposition,
        write_disposition=self.write_disposition,
        kms_key=self.kms_key,
        test_client=self.test_client)
    return pcoll | 'WriteToBigQuery' >> ParDo(bigquery_write_fn)
  else:
    from apache_beam.io.gcp import bigquery_file_loads
    return pcoll | bigquery_file_loads.BigQueryBatchFileLoads(
        destination=self.table_reference,
        schema=self.get_dict_table_schema(self.schema),
        create_disposition=self.create_disposition,
        write_disposition=self.write_disposition,
        max_file_size=self.max_file_size,
        max_files_per_bundle=self.max_files_per_bundle,
        gs_location=self.gs_location,
        test_client=self.test_client)
def expand(self, pcoll):
  elasticsearch_write_fn = ElasticSearchWriteFn(
      es_client=self.es_client,
      index_name=self.index_name,
      doc_type=self.doc_type,
      test_client=self.test_client,
      mapping=self.mapping)
  return pcoll | 'WriteToElasticSearch' >> ParDo(elasticsearch_write_fn)
def test_timestamp_param(self):
  class TestDoFn(DoFn):
    def process(self, element, timestamp=DoFn.TimestampParam):
      yield timestamp

  with TestPipeline() as pipeline:
    pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
    assert_that(pcoll, equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))
def test_element(self):
  class TestDoFn(DoFn):
    def process(self, element):
      yield element + 10

  with TestPipeline() as pipeline:
    pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
    assert_that(pcoll, equal_to([11, 12]))
def run(argv=None):
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'input_topic', type=str, help="Input Pub/Sub topic name.")
  parser.add_argument(
      'output_table',
      type=str,
      help="Output BigQuery table name. Example: project.db.name")
  parser.add_argument(
      '--model_project',
      type=str,
      help="Google Cloud project ID that hosts the model.")
  parser.add_argument(
      '--model_name', type=str, help="Name of the Google AI Platform model.")
  parser.add_argument(
      '--model_region', type=str, help="AI Platform region name.")
  parser.add_argument(
      '--model_version', type=str, help="AI Platform model version.")
  known_args, pipeline_args = parser.parse_known_args(argv)

  _topic_comp = known_args.input_topic.split('/')
  if len(_topic_comp) != 4 or _topic_comp[0] != 'projects' or _topic_comp[
      2] != 'topics':
    raise ValueError("Pub/Sub topic name has an inappropriate format.")

  if len(known_args.output_table.split('.')) != 2:
    raise ValueError("Table name has an inappropriate format.")

  inf_args = [
      known_args.model_project,
      known_args.model_name,
      known_args.model_region,
      known_args.model_version
  ]

  options = PipelineOptions(pipeline_args)
  options.view_as(SetupOptions).save_main_session = True
  options.view_as(StandardOptions).streaming = True

  p = Pipeline(options=options)
  _ = (
      p
      | 'read from pub/sub' >> ReadFromPubSub(
          known_args.input_topic).with_output_types(bytes)
      | 'windowing' >> WindowInto(window.FixedWindows(10, 0))
      | 'convert to dict' >> Map(json.loads)
      | 'pre processing' >> PreProcessing()
      | 'make inference' >> ParDo(MakeRemoteInferenceDoFn(*inf_args))
      | 'format message' >> Map(formatter)
      | 'write to BQ' >> WriteToBigQuery(
          table=known_args.output_table,
          schema=build_bq_schema(),
          create_disposition=BigQueryDisposition.CREATE_IF_NEEDED,
          write_disposition=BigQueryDisposition.WRITE_APPEND))

  if os.environ.get('DEPLOY'):
    # Use p.run() instead of the `with Pipeline() as p` context manager here,
    # because the process needs to exit after submitting the streaming job.
    p.run()
  else:
    p.run().wait_until_finish()
def test_context_param(self):
  class TestDoFn(DoFn):
    def process(self, element, context=DoFn.ContextParam):
      yield context.element + 10

  pipeline = TestPipeline()
  pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
  assert_that(pcoll, equal_to([11, 12]))
  pipeline.run()
def expand(self, pcoll):
  if self.project is None:
    self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
  if self.project is None:
    raise ValueError(
        'GCP project name needs to be specified in "project" pipeline option')
  return pcoll | ParDo(_InspectFn(self.config, self.timeout, self.project))
def expand(self, pcoll):
  # This is a composite transform that involves the following steps:
  #
  # 1. Create a singleton of the user-provided `query` and apply a ``ParDo``
  #    that splits the query into `num_splits` queries if possible.
  #
  #    If the value of `num_splits` is 0, the number of splits is computed
  #    dynamically based on the size of the data for the `query`.
  #
  # 2. The resulting ``PCollection`` is sharded across workers using a
  #    ``Reshuffle`` operation.
  #
  # 3. In the third step, a ``ParDo`` reads entities for each query and
  #    outputs a ``PCollection[Entity]``.
  return (
      pcoll.pipeline
      | 'UserQuery' >> Create([self._query])
      | 'SplitQuery' >> ParDo(
          ReadFromDatastore._SplitQueryFn(self._num_splits))
      | Reshuffle()
      | 'Read' >> ParDo(ReadFromDatastore._QueryFn()))
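# Hedged usage sketch (not part of the original source): how the composite
# ReadFromDatastore transform above might be applied from user code. The
# module paths, the Query constructor fields, and the project/kind names are
# assumptions based on the apache_beam.io.gcp.datastore.v1new API and may
# differ across SDK versions.
import apache_beam as beam
from apache_beam.io.gcp.datastore.v1new.datastoreio import ReadFromDatastore
from apache_beam.io.gcp.datastore.v1new.types import Query

with beam.Pipeline() as p:
  entities = (
      p
      | 'ReadEntities' >> ReadFromDatastore(
          query=Query(kind='ExampleKind', project='example-project'),
          num_splits=0))  # 0 -> split count estimated from the data size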
def expand(self, pvalue):
  return (
      pvalue
      | FlatMap(self._create_image_annotation_pairs)
      | util.BatchElements(
          min_batch_size=self.min_batch_size,
          max_batch_size=self.max_batch_size)
      | ParDo(
          _ImageAnnotateFn(
              features=self.features,
              retry=self.retry,
              timeout=self.timeout,
              client_options=self.client_options,
              metadata=self.metadata)))
def expand(self, pcoll):
  if self.project is None:
    self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
  if self.project is None:
    raise ValueError(
        'GCP project name needs to be specified in "project" pipeline option')
  return pcoll | ParDo(
      _CreateCatalogItemFn(
          self.project,
          self.retry,
          self.timeout,
          self.metadata,
          self.catalog_name))
def test_window_param(self):
  class TestDoFn(DoFn):
    def process(self, element, window=DoFn.WindowParam):
      yield (element, (float(window.start), float(window.end)))

  pipeline = TestPipeline()
  pcoll = (
      pipeline
      | Create([1, 7])
      | Map(lambda x: TimestampedValue(x, x))
      | WindowInto(windowfn=SlidingWindows(10, 5))
      | ParDo(TestDoFn()))
  # With SlidingWindows(size=10, period=5), the element with timestamp 1 is
  # assigned to windows [-5, 5) and [0, 10), and the element with timestamp 7
  # is assigned to windows [0, 10) and [5, 15).
  assert_that(
      pcoll,
      equal_to([(1, (-5, 5)), (1, (0, 10)), (7, (0, 10)), (7, (5, 15))]))

  pcoll2 = pcoll | 'Again' >> ParDo(TestDoFn())
  assert_that(
      pcoll2,
      equal_to([((1, (-5, 5)), (-5, 5)), ((1, (0, 10)), (0, 10)),
                ((7, (0, 10)), (0, 10)), ((7, (5, 15)), (5, 15))]),
      label='doubled windows')
  pipeline.run()
def expand(self, pcoll):
  if self.project is None:
    self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
  if self.project is None:
    raise ValueError(
        'GCP project name needs to be specified in "project" pipeline option')
  return pcoll | ParDo(
      _PredictUserEventFn(
          self.project,
          self.retry,
          self.timeout,
          self.metadata,
          self.catalog_name,
          self.event_store,
          self.placement_id))
def expand(self, pcoll):
  if self.project is None:
    self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
  if self.project is None:
    raise ValueError(
        'GCP project name needs to be specified in "project" pipeline option')
  return (
      pcoll
      | GroupIntoBatches.WithShardedKey(self.max_batch_size)
      | ParDo(
          _ImportCatalogItemsFn(
              self.project,
              self.retry,
              self.timeout,
              self.metadata,
              self.catalog_name)))
def test_side_input_tagged(self):
  class TestDoFn(DoFn):
    def process(self, element, prefix, suffix=DoFn.SideInputParam):
      return ['%s-%s-%s' % (prefix, element, suffix)]

  with TestPipeline() as pipeline:
    words_list = ['aa', 'bb', 'cc']
    words = pipeline | 'SomeWords' >> Create(words_list)
    prefix = 'zyx'
    suffix = pipeline | 'SomeString' >> Create(['xyz'])  # side input
    result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
        TestDoFn(), prefix, suffix=AsSingleton(suffix))
    assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
def test_sdk_harness_container_images_get_set(self):
  pipeline_options = PipelineOptions([
      '--experiments=beam_fn_api',
      '--experiments=use_unified_worker',
      '--temp_location',
      'gs://any-location/temp'
  ])

  pipeline = Pipeline(options=pipeline_options)
  pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

  test_environment = DockerEnvironment(container_image='test_default_image')
  proto_pipeline, _ = pipeline.to_runner_api(
      return_context=True, default_environment=test_environment)

  # We have to manually add environments since Dataflow only sets
  # 'sdkHarnessContainerImages' when there are at least two environments.
  dummy_env = beam_runner_api_pb2.Environment(
      urn=common_urns.environments.DOCKER.urn,
      payload=(beam_runner_api_pb2.DockerPayload(
          container_image='dummy_image')).SerializeToString())
  proto_pipeline.components.environments['dummy_env_id'].CopyFrom(dummy_env)

  dummy_transform = beam_runner_api_pb2.PTransform(
      environment_id='dummy_env_id')
  proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom(
      dummy_transform)

  env = apiclient.Environment(
      [],  # packages
      pipeline_options,
      '2.0.0',  # any environment version
      FAKE_PIPELINE_URL,
      proto_pipeline,
      _sdk_image_overrides={
          '.*dummy.*': 'dummy_image', '.*test.*': 'test_default_image'
      })
  worker_pool = env.proto.workerPools[0]

  # For the test, a third environment gets added since the actual default
  # container image for Dataflow is different from the 'test_default_image'
  # we've provided above.
  self.assertEqual(3, len(worker_pool.sdkHarnessContainerImages))

  # The container image should be overridden by a Dataflow-specific URL.
  self.assertTrue(
      str.startswith(
          (worker_pool.sdkHarnessContainerImages[0]).containerImage,
          'gcr.io/cloud-dataflow/v1beta3/python'))
def test_java_sdk_harness_dedup(self):
  pipeline_options = PipelineOptions([
      '--experiments=beam_fn_api',
      '--experiments=use_unified_worker',
      '--temp_location',
      'gs://any-location/temp'
  ])

  pipeline = Pipeline(options=pipeline_options)
  pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

  proto_pipeline, _ = pipeline.to_runner_api(return_context=True)

  dummy_env_1 = beam_runner_api_pb2.Environment(
      urn=common_urns.environments.DOCKER.urn,
      payload=(beam_runner_api_pb2.DockerPayload(
          container_image='apache/beam_java:dummy_tag')).SerializeToString())
  proto_pipeline.components.environments['dummy_env_id_1'].CopyFrom(
      dummy_env_1)

  dummy_transform_1 = beam_runner_api_pb2.PTransform(
      environment_id='dummy_env_id_1')
  proto_pipeline.components.transforms['dummy_transform_id_1'].CopyFrom(
      dummy_transform_1)

  dummy_env_2 = beam_runner_api_pb2.Environment(
      urn=common_urns.environments.DOCKER.urn,
      payload=(beam_runner_api_pb2.DockerPayload(
          container_image='apache/beam_java:dummy_tag')).SerializeToString())
  proto_pipeline.components.environments['dummy_env_id_2'].CopyFrom(
      dummy_env_2)

  dummy_transform_2 = beam_runner_api_pb2.PTransform(
      environment_id='dummy_env_id_2')
  proto_pipeline.components.transforms['dummy_transform_id_2'].CopyFrom(
      dummy_transform_2)

  # Accessing non-public method for testing.
  apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
      proto_pipeline, dict(), pipeline_options)

  # Only one of 'dummy_env_id_1' or 'dummy_env_id_2' should be in the set of
  # environment IDs used by the proto after Java environment de-duping.
  env_ids_from_transforms = [
      proto_pipeline.components.transforms[transform_id].environment_id
      for transform_id in proto_pipeline.components.transforms
  ]
  if 'dummy_env_id_1' in env_ids_from_transforms:
    self.assertTrue('dummy_env_id_2' not in env_ids_from_transforms)
  else:
    self.assertTrue('dummy_env_id_2' in env_ids_from_transforms)
def expand(self, pcoll):
  p = pcoll.pipeline
  # TODO(pabloem): Use a different method to determine if streaming or batch.
  standard_options = p.options.view_as(StandardOptions)

  if (not callable(self.table_reference) and
      self.table_reference.projectId is None):
    self.table_reference.projectId = pcoll.pipeline.options.view_as(
        GoogleCloudOptions).project

  if (standard_options.streaming or
      self.method == WriteToBigQuery.Method.STREAMING_INSERTS):
    # TODO: Support load jobs for streaming pipelines.
    bigquery_write_fn = BigQueryWriteFn(
        batch_size=self.batch_size,
        create_disposition=self.create_disposition,
        write_disposition=self.write_disposition,
        kms_key=self.kms_key,
        retry_strategy=self.insert_retry_strategy,
        test_client=self.test_client)

    # TODO: Use utility functions from BQTools
    table_fn = self._get_table_fn()

    outputs = (
        pcoll
        | 'AppendDestination' >> beam.Map(lambda x: (table_fn(x), x))
        | 'StreamInsertRows' >> ParDo(bigquery_write_fn).with_outputs(
            BigQueryWriteFn.FAILED_ROWS, main='main'))
    return {BigQueryWriteFn.FAILED_ROWS: outputs[BigQueryWriteFn.FAILED_ROWS]}
  else:
    if standard_options.streaming:
      raise NotImplementedError(
          'File Loads to BigQuery are only supported on Batch pipelines.')

    from apache_beam.io.gcp import bigquery_file_loads
    return pcoll | bigquery_file_loads.BigQueryBatchFileLoads(
        destination=self.table_reference,
        schema=self.get_dict_table_schema(self.schema),
        create_disposition=self.create_disposition,
        write_disposition=self.write_disposition,
        max_file_size=self.max_file_size,
        max_files_per_bundle=self.max_files_per_bundle,
        gs_location=self.gs_location,
        test_client=self.test_client)
def expand(self, pcoll):
  p = pcoll.pipeline

  if (isinstance(self.table_reference, bigquery.TableReference) and
      self.table_reference.projectId is None):
    self.table_reference.projectId = pcoll.pipeline.options.view_as(
        GoogleCloudOptions).project

  method_to_use = self._compute_method(p, p.options)

  if method_to_use == WriteToBigQuery.Method.STREAMING_INSERTS:
    # TODO: Support load jobs for streaming pipelines.
    bigquery_write_fn = BigQueryWriteFn(
        schema=self.schema,
        batch_size=self.batch_size,
        create_disposition=self.create_disposition,
        write_disposition=self.write_disposition,
        kms_key=self.kms_key,
        retry_strategy=self.insert_retry_strategy,
        test_client=self.test_client)

    outputs = (
        pcoll
        | 'AppendDestination' >> beam.ParDo(
            bigquery_tools.AppendDestinationsFn(self.table_reference))
        | 'StreamInsertRows' >> ParDo(bigquery_write_fn).with_outputs(
            BigQueryWriteFn.FAILED_ROWS, main='main'))
    return {BigQueryWriteFn.FAILED_ROWS: outputs[BigQueryWriteFn.FAILED_ROWS]}
  else:
    if p.options.view_as(StandardOptions).streaming:
      raise NotImplementedError(
          'File Loads to BigQuery are only supported on Batch pipelines.')

    from apache_beam.io.gcp import bigquery_file_loads
    return (
        pcoll
        | bigquery_file_loads.BigQueryBatchFileLoads(
            destination=self.table_reference,
            schema=self.schema,
            create_disposition=self.create_disposition,
            write_disposition=self.write_disposition,
            max_file_size=self.max_file_size,
            max_files_per_bundle=self.max_files_per_bundle,
            custom_gcs_temp_location=self.custom_gcs_temp_location,
            test_client=self.test_client,
            validate=self._validate))
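# Hedged usage sketch (not part of the original source): how the expand()
# above is typically reached through WriteToBigQuery, and how the FAILED_ROWS
# output of the streaming-inserts branch can be inspected. The table and
# schema values are hypothetical, and the dict-style lookup matches the return
# value shown in the snippet above; other SDK versions return a WriteResult
# object instead, so this may need adjusting.
import apache_beam as beam
from apache_beam.io.gcp.bigquery import BigQueryWriteFn, WriteToBigQuery

with beam.Pipeline() as p:
  rows = p | beam.Create([{'name': 'a', 'count': 1}])
  write_result = (
      rows
      | WriteToBigQuery(
          table='example-project:example_dataset.example_table',
          schema='name:STRING,count:INTEGER',
          method=WriteToBigQuery.Method.STREAMING_INSERTS,
          create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
          write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND))
  # Rows rejected by the streaming-insert API land on the FAILED_ROWS output.
  failed_rows = write_result[BigQueryWriteFn.FAILED_ROWS]
  failed_rows | 'LogFailures' >> beam.Map(print)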
def expand(self, pcoll):
  if self.table_reference.projectId is None:
    self.table_reference.projectId = pcoll.pipeline.options.view_as(
        GoogleCloudOptions).project

  bigquery_write_fn = BigQueryWriteFn(
      table_id=self.table_reference.tableId,
      dataset_id=self.table_reference.datasetId,
      project_id=self.table_reference.projectId,
      batch_size=self.batch_size,
      schema=self.get_dict_table_schema(self.schema),
      create_disposition=self.create_disposition,
      write_disposition=self.write_disposition,
      test_client=self.test_client)
  return pcoll | 'WriteToBigQuery' >> ParDo(bigquery_write_fn)
def test_default_environment_get_set(self):
  pipeline_options = PipelineOptions([
      '--experiments=beam_fn_api',
      '--experiments=use_unified_worker',
      '--temp_location',
      'gs://any-location/temp'
  ])

  pipeline = Pipeline(options=pipeline_options)
  pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

  test_environment = DockerEnvironment(container_image='test_default_image')
  proto_pipeline, _ = pipeline.to_runner_api(
      return_context=True, default_environment=test_environment)

  dummy_env = beam_runner_api_pb2.Environment(
      urn=common_urns.environments.DOCKER.urn,
      payload=(beam_runner_api_pb2.DockerPayload(
          container_image='dummy_image')).SerializeToString())
  proto_pipeline.components.environments['dummy_env_id'].CopyFrom(dummy_env)

  dummy_transform = beam_runner_api_pb2.PTransform(
      environment_id='dummy_env_id')
  proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom(
      dummy_transform)

  env = apiclient.Environment(
      [],  # packages
      pipeline_options,
      '2.0.0',  # any environment version
      FAKE_PIPELINE_URL,
      proto_pipeline,
      _sdk_image_overrides={
          '.*dummy.*': 'dummy_image', '.*test.*': 'test_default_image'
      })
  worker_pool = env.proto.workerPools[0]
  self.assertEqual(2, len(worker_pool.sdkHarnessContainerImages))

  images_from_proto = [
      sdk_info.containerImage
      for sdk_info in worker_pool.sdkHarnessContainerImages
  ]
  self.assertIn('test_default_image', images_from_proto)
def test_non_apache_container_not_overridden(self):
  pipeline_options = PipelineOptions([
      '--experiments=beam_fn_api',
      '--experiments=use_unified_worker',
      '--temp_location',
      'gs://any-location/temp'
  ])

  pipeline = Pipeline(options=pipeline_options)
  pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

  proto_pipeline, _ = pipeline.to_runner_api(return_context=True)

  dummy_env = beam_runner_api_pb2.Environment(
      urn=common_urns.environments.DOCKER.urn,
      payload=(beam_runner_api_pb2.DockerPayload(
          container_image='other_org/dummy_name:dummy_tag')
              ).SerializeToString())
  proto_pipeline.components.environments['dummy_env_id'].CopyFrom(dummy_env)

  dummy_transform = beam_runner_api_pb2.PTransform(
      environment_id='dummy_env_id')
  proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom(
      dummy_transform)

  # Accessing non-public method for testing.
  apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
      proto_pipeline, dict(), pipeline_options)

  self.assertIsNotNone(2, len(proto_pipeline.components.environments))

  from apache_beam.utils import proto_utils
  found_override = False
  for env in proto_pipeline.components.environments.values():
    docker_payload = proto_utils.parse_Bytes(
        env.payload, beam_runner_api_pb2.DockerPayload)
    if docker_payload.container_image.startswith(
        names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
      found_override = True

  self.assertFalse(found_override)
def test_side_inputs(self):
  class SplitNumbersFn(DoFn):
    def process(self, element):
      if element < 0:
        yield pvalue.TaggedOutput('tag_negative', element)
      else:
        yield element

  class ProcessNumbersFn(DoFn):
    def process(self, element, negatives):
      yield element

  def _process_numbers(pcoll, negatives):
    first_output = (
        pcoll
        | 'process numbers step 1' >> ParDo(ProcessNumbersFn(), negatives))

    second_output = (
        first_output
        | 'process numbers step 2' >> ParDo(ProcessNumbersFn(), negatives))

    output_pc = (
        (first_output, second_output)
        | 'flatten results' >> beam.Flatten())
    return output_pc

  root_read = beam.Impulse()

  result = (
      self.pipeline
      | 'read' >> root_read
      | ParDo(SplitNumbersFn()).with_outputs('tag_negative', main='positive'))
  positive, negative = result
  _process_numbers(positive, AsList(negative))

  self.pipeline.visit(self.visitor)

  root_transforms = [t.transform for t in self.visitor.root_transforms]
  self.assertEqual(root_transforms, [root_read])
  self.assertEqual(len(self.visitor.step_names), 5)
  self.assertEqual(len(self.visitor.views), 1)
  self.assertTrue(isinstance(self.visitor.views[0], pvalue.AsList))