コード例 #1
0
ファイル: test_stream_test.py プロジェクト: yifanmai/beam
    def test_roundtrip_proto_multi(self):
        """Round-trip a multi-tag TestStream through its runner API proto.

        Builds a TestStream with events on output tags 'a' and 'b', applies it
        to a TestPipeline created with streaming options, serializes the
        pipeline, and rebuilds the TestStream from the matching transform
        proto. The rebuilt stream must have the same events, output tags and
        coder as the original.
        """
        test_stream = (TestStream()
                       .advance_processing_time(1)
                       .advance_watermark_to(2, tag='a')
                       .advance_watermark_to(3, tag='b')
                       .add_elements([1, 2, 3], tag='a')
                       .add_elements([4, 5, 6], tag='b')) # yapf: disable

        options = StandardOptions(streaming=True)

        p = TestPipeline(options=options)
        p | test_stream

        pipeline_proto, context = p.to_runner_api(return_context=True)

        # Start from None so that a missing TestStream transform is reported by
        # the assertion below instead of a NameError on an unbound local.
        test_stream_proto = None
        for t in pipeline_proto.components.transforms.values():
            if t.spec.urn == common_urns.primitives.TEST_STREAM.urn:
                test_stream_proto = t

        self.assertIsNotNone(
            test_stream_proto,
            'no transform with the TEST_STREAM URN in the pipeline proto')
        roundtrip_test_stream = TestStream().from_runner_api(
            test_stream_proto, context)

        self.assertListEqual(test_stream._events,
                             roundtrip_test_stream._events)
        self.assertSetEqual(test_stream.output_tags,
                            roundtrip_test_stream.output_tags)
        self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
コード例 #2
0
    def test_job_python_from_python_it(self):
        """Apply a Python transform via ExternalTransform and run the result.

        Registers a local transform under the 'simple' URN, applies it through
        ``beam.ExternalTransform`` with an ``ExpansionServiceServicer``
        instance, then converts the pipeline to its proto form, rebuilds it
        with the same runner and options, runs the rebuilt pipeline and waits
        until it finishes.
        """

        @ptransform.PTransform.register_urn('simple', None)
        class SimpleTransform(ptransform.PTransform):
            # Formats every input element as 'Simple(<element>)'.
            def expand(self, pcoll):
                return pcoll | beam.Map(lambda x: 'Simple(%s)' % x)

            # Serialized form: the 'simple' URN with no payload.
            def to_runner_api_parameter(self, unused_context):
                return 'simple', None

            # No state to restore, so all three arguments are ignored.
            @staticmethod
            def from_runner_api_parameter(_0, _1, _2):
                return SimpleTransform()

        it_pipeline = TestPipeline(is_integration_test=True)

        letters = it_pipeline | beam.Create(['a', 'b'])
        servicer = expansion_service.ExpansionServiceServicer()
        wrapped = letters | beam.ExternalTransform('simple', None, servicer)
        assert_that(wrapped, equal_to(['Simple(a)', 'Simple(b)']))

        # Only the proto is needed here; the returned context is discarded.
        serialized, _ = it_pipeline.to_runner_api(return_context=True)
        rebuilt = Pipeline.from_runner_api(
            serialized, it_pipeline.runner, it_pipeline._options)
        run_result = rebuilt.run()
        run_result.wait_until_finish()
コード例 #3
0
  def test_to_from_runner_api(self):
    """Checks that WriteToBigQuery survives a runner API round trip.

    Deliberately not a change-detector test: only the parameters with the more
    complicated serialization logic are verified, namely ValueProviders,
    callables, and side inputs.
    """
    expected_table = 'test_project:output_table'

    p = TestPipeline()

    # Side input parameter under test: a dict view holding the table name.
    table_rows = p | "MakeTable" >> beam.Create([('table', expected_table)])
    table_view = beam.pvalue.AsDict(table_rows)

    # Value provider parameter under test.
    schema_provider = value_provider.StaticValueProvider(str, '"a:str"')

    original = WriteToBigQuery(
        table=lambda _, side_input: side_input['table'],
        table_side_inputs=(table_view, ),
        schema=schema_provider)

    # pylint: disable=expression-not-assigned
    p | 'MyWriteToBigQuery' >> original

    # Rebuild the pipeline from its proto and serialize it again, so the
    # proto and context below come from the deserialized pipeline and the
    # serialization code is known to have run.
    rebuilt_pipeline = TestPipeline.from_runner_api(
        p.to_runner_api(), p.runner, p.get_pipeline_options())
    pipeline_proto, context = rebuilt_pipeline.to_runner_api(
        return_context=True)

    # Locate the transform in the context by its unique name.
    matching_ids = [
        transform_id
        for transform_id, transform_proto in
        pipeline_proto.components.transforms.items()
        if transform_proto.unique_name == 'MyWriteToBigQuery'
    ]
    restored = context.transforms.get_by_id(matching_ids[0]).transform
    self.assertIsInstance(restored, WriteToBigQuery)

    # Value provider: the schema must compare equal after the round trip.
    self.assertEqual(original.schema, restored.schema)

    # Callable: the restored table function must still read the side input.
    resolved_table = restored._table(None, {'table': expected_table})
    self.assertEqual(resolved_table, expected_table)

    # Side input: same count, and the first one has matching side input data.
    self.assertEqual(
        len(original.table_side_inputs), len(restored.table_side_inputs))
    expected_data = original.table_side_inputs[0]._side_input_data()
    actual_data = restored.table_side_inputs[0]._side_input_data()
    self.assertEqual(expected_data.access_pattern, actual_data.access_pattern)
    self.assertEqual(
        expected_data.window_mapping_fn, actual_data.window_mapping_fn)
    self.assertEqual(expected_data.view_fn, actual_data.view_fn)