def model_composite_transform_example(contents, output_path):
    """Example of a composite transform.

    To declare a composite transform, define a subclass of PTransform.

    To override the apply method, define a method "apply" that
    takes a PCollection as its only parameter and returns a PCollection.
    """
    import re

    import apache_beam as beam

    # [START composite_transform_example]
    # [START composite_ptransform_apply_method]
    # [START composite_ptransform_declare]
    class CountWords(beam.PTransform):
        # [END composite_ptransform_declare]

        def expand(self, pcoll):
            return (pcoll
                    | beam.FlatMap(lambda x: re.findall(r'\w+', x))
                    | beam.combiners.Count.PerElement()
                    # Tuple-parameter lambdas (`lambda (word, c): ...`) were
                    # removed in Python 3 (PEP 3113); unpack by index instead.
                    | beam.Map(
                        lambda word_c: '%s: %s' % (word_c[0], word_c[1])))
    # [END composite_ptransform_apply_method]
    # [END composite_transform_example]

    p = TestPipeline()  # Use TestPipeline for testing.
    (p
     | beam.Create(contents)
     | CountWords()
     | beam.io.WriteToText(output_path))
    p.run()
def model_multiple_pcollections_flatten(contents, output_path):
    """Merging a PCollection with Flatten."""
    some_hash_fn = lambda s: ord(s[0])
    import apache_beam as beam
    p = TestPipeline()  # Use TestPipeline for testing.

    # Route each element to one of three partitions based on its hash.
    partition_fn = lambda element, partitions: some_hash_fn(element) % partitions

    # Partition into deciles
    partitioned = p | beam.Create(contents) | beam.Partition(partition_fn, 3)
    pcoll1, pcoll2, pcoll3 = partitioned[0], partitioned[1], partitioned[2]

    # Flatten them back into 1

    # A collection of PCollection objects can be represented simply
    # as a tuple (or list) of PCollections.
    # (The SDK for Python has no separate type to store multiple
    # PCollection objects, whether containing the same or different
    # types.)
    # [START model_multiple_pcollections_flatten]
    merged = (
        (pcoll1, pcoll2, pcoll3)
        # A list of tuples can be "piped" directly into a Flatten transform.
        | beam.Flatten())
    # [END model_multiple_pcollections_flatten]
    merged | beam.io.WriteToText(output_path)
    p.run()
def pipeline_logging(lines, output):
    """Logging Pipeline Messages."""
    import re
    import apache_beam as beam

    # [START pipeline_logging]
    # import Python logging module.
    import logging

    class ExtractWordsFn(beam.DoFn):
        """Splits each line into words, logging every occurrence of 'love'."""

        def process(self, element):
            found = re.findall(r'[A-Za-z\']+', element)
            for word in found:
                yield word
                if word.lower() == 'love':
                    # Log using the root logger at info or higher levels
                    logging.info('Found : %s', word.lower())

    # Remaining WordCount example code ...
    # [END pipeline_logging]

    p = TestPipeline()  # Use TestPipeline for testing.
    (p
     | beam.Create(lines)
     | beam.ParDo(ExtractWordsFn())
     | beam.io.WriteToText(output))
    p.run()
def test_read_from_text_file_pattern(self):
    """Reading a glob pattern yields the lines of every matched file."""
    pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
    # 5 + 3 + 12 + 8 + 8 + 4 lines in total.
    assert len(expected_data) == 40
    p = TestPipeline()
    read_lines = p | 'Read' >> ReadFromText(pattern)
    assert_that(read_lines, equal_to(expected_data))
    p.run()
def model_multiple_pcollections_partition(contents, output_path):
    """Splitting a PCollection with Partition."""

    def get_percentile(i):
        """Assume i in [0,100)."""
        return i

    import apache_beam as beam
    p = TestPipeline()  # Use TestPipeline for testing.

    students = p | beam.Create(contents)

    # [START model_multiple_pcollections_partition]
    def partition_fn(student, num_partitions):
        return int(get_percentile(student) * num_partitions / 100)

    by_decile = students | beam.Partition(partition_fn, 10)
    # [END model_multiple_pcollections_partition]
    # [START model_multiple_pcollections_partition_40th]
    fortieth_percentile = by_decile[4]
    # [END model_multiple_pcollections_partition_40th]

    # `xrange` was removed in Python 3; `range` behaves the same here.
    ([by_decile[d] for d in range(10) if d != 4] + [fortieth_percentile]
     | beam.Flatten()
     | beam.io.WriteToText(output_path))
    p.run()
def test_window_preserved(self):
    # The timestamp/window pair that AddWindowDoFn stamps on every element;
    # _IdentityWindowFn must pass both through untouched.
    expected_timestamp = timestamp.Timestamp(5)
    expected_window = window.IntervalWindow(1.0, 2.0)

    class AddWindowDoFn(beam.DoFn):
        # Re-emits each element wrapped in the fixed timestamp and window.
        def process(self, element):
            yield WindowedValue(
                element, expected_timestamp, [expected_window])

    pipeline = TestPipeline()
    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    expected_windows = [
        TestWindowedValue(kv, expected_timestamp, [expected_window])
        for kv in data]
    before_identity = (pipeline
                       | 'start' >> beam.Create(data)
                       | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
    assert_that(before_identity, equal_to(expected_windows),
                label='before_identity', reify_windows=True)
    # Applying the identity window fn should leave windows/timestamps as-is.
    after_identity = (before_identity
                      | 'window' >> beam.WindowInto(
                          beam.transforms.util._IdentityWindowFn(
                              coders.IntervalWindowCoder())))
    assert_that(after_identity, equal_to(expected_windows),
                label='after_identity', reify_windows=True)
    pipeline.run()
def test_no_window_context_fails(self): expected_timestamp = timestamp.Timestamp(5) # Assuming the default window function is window.GlobalWindows. expected_window = window.GlobalWindow() class AddTimestampDoFn(beam.DoFn): def process(self, element): yield window.TimestampedValue(element, expected_timestamp) pipeline = TestPipeline() data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)] expected_windows = [ TestWindowedValue(kv, expected_timestamp, [expected_window]) for kv in data] before_identity = (pipeline | 'start' >> beam.Create(data) | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn())) assert_that(before_identity, equal_to(expected_windows), label='before_identity', reify_windows=True) after_identity = (before_identity | 'window' >> beam.WindowInto( beam.transforms.util._IdentityWindowFn( coders.GlobalWindowCoder())) # This DoFn will return TimestampedValues, making # WindowFn.AssignContext passed to IdentityWindowFn # contain a window of None. IdentityWindowFn should # raise an exception. | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn())) assert_that(after_identity, equal_to(expected_windows), label='after_identity', reify_windows=True) with self.assertRaisesRegexp(ValueError, r'window.*None.*add_timestamps2'): pipeline.run()
def test_timestamped_with_combiners(self):
    """Per-key combines respect the window each timestamped element lands in."""
    p = TestPipeline()
    result = (p
              # Create some initial test values.
              | 'start' >> Create([(k, k) for k in range(10)])
              # The purpose of the WindowInto transform is to establish a
              # FixedWindows windowing function for the PCollection.
              # It does not bucket elements into windows since the timestamps
              # from Create are not spaced 5 ms apart and very likely they all
              # fall into the same window.
              | 'w' >> WindowInto(FixedWindows(5))
              # Generate timestamped values using the values as timestamps.
              # Now there are values 5 ms apart and since Map propagates the
              # windowing function from input to output the output PCollection
              # will have elements falling into different 5ms windows.
              # Tuple-parameter lambdas were removed in Python 3 (PEP 3113),
              # so unpack the (value, timestamp) pair by index.
              | Map(lambda x_t: TimestampedValue(x_t[0], x_t[1]))
              # We add a 'key' to each value representing the index of the
              # window. This is important since there is no guarantee of
              # order for the elements of a PCollection.
              # Floor division keeps the key an int under Python 3's true
              # division (v / 5 would produce floats).
              | Map(lambda v: (v // 5, v)))
    # Sum all elements associated with a key and window. Although it
    # is called CombinePerKey it is really CombinePerKeyAndWindow the
    # same way GroupByKey is really GroupByKeyAndWindow.
    sum_per_window = result | CombinePerKey(sum)
    # Compute mean per key and window.
    mean_per_window = result | combiners.Mean.PerKey()
    assert_that(sum_per_window, equal_to([(0, 10), (1, 35)]),
                label='assert:sum')
    assert_that(mean_per_window, equal_to([(0, 2.0), (1, 7.0)]),
                label='assert:mean')
    p.run()
def test_setting_timestamp(self):
    """Timestamped elements are bucketed into the right fixed windows."""
    p = TestPipeline()
    unkeyed_items = p | beam.Create([12, 30, 60, 61, 66])
    items = (unkeyed_items | 'key' >> beam.Map(lambda x: ('k', x)))

    def extract_timestamp_from_log_entry(entry):
        return entry[1]

    # [START setting_timestamp]
    class AddTimestampDoFn(beam.DoFn):

        def process(self, element):
            # Extract the numeric Unix seconds-since-epoch timestamp to be
            # associated with the current log entry.
            unix_timestamp = extract_timestamp_from_log_entry(element)
            # Wrap and emit the current entry and new timestamp in a
            # TimestampedValue.
            yield beam.window.TimestampedValue(element, unix_timestamp)

    timestamped_items = items | 'timestamp' >> beam.ParDo(AddTimestampDoFn())
    # [END setting_timestamp]

    windowed = (timestamped_items
                | 'window' >> beam.WindowInto(beam.window.FixedWindows(60)))
    summed = (windowed
              | 'group' >> beam.GroupByKey()
              | 'combine' >> beam.CombineValues(sum))
    totals = summed | 'unkey' >> beam.Map(lambda x: x[1])
    # Window [0, 60): 12 + 30 = 42; window [60, 120): 60 + 61 + 66 = 187.
    assert_that(totals, equal_to([42, 187]))
    p.run()
def test_gbk_execution(self):
    """GroupByKey groups streamed elements per 15s window."""
    stream = (TestStream()
              .advance_watermark_to(10)
              .add_elements(['a', 'b', 'c'])
              .advance_watermark_to(20)
              .add_elements(['d'])
              .add_elements(['e'])
              .advance_processing_time(10)
              .advance_watermark_to(300)
              .add_elements([TimestampedValue('late', 12)])
              .add_elements([TimestampedValue('last', 310)]))

    opts = PipelineOptions()
    opts.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=opts)
    grouped = (p
               | stream
               | beam.WindowInto(FixedWindows(15))
               | beam.Map(lambda x: ('k', x))
               | beam.GroupByKey())
    # TODO(BEAM-2519): timestamp assignment for elements from a GBK should
    # respect the TimestampCombiner. The test below should also verify the
    # timestamps of the outputted elements once this is implemented.
    assert_that(grouped, equal_to([
        ('k', ['a', 'b', 'c']),
        ('k', ['d', 'e']),
        ('k', ['late']),
        ('k', ['last'])]))
    p.run()
def test_to_list_and_to_dict(self):
    """ToList gathers all elements into one list; ToDict into one dict."""
    pipeline = TestPipeline()
    the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
    pcoll = pipeline | 'start' >> Create(the_list)
    result = pcoll | 'to list' >> combine.ToList()

    def matcher(expected):
        def match(actual):
            # The output is a single element: the collected list.
            equal_to(expected[0])(actual[0])
        return match

    assert_that(result, matcher([the_list]))
    pipeline.run()

    pipeline = TestPipeline()
    pairs = [(1, 2), (3, 4), (5, 6)]
    pcoll = pipeline | 'start-pairs' >> Create(pairs)
    result = pcoll | 'to dict' >> combine.ToDict()

    def matcher():
        def match(actual):
            equal_to([1])([len(actual)])
            # dict.iteritems() was removed in Python 3; items() yields the
            # same (key, value) pairs on both 2 and 3.
            equal_to(pairs)(actual[0].items())
        return match

    assert_that(result, matcher())
    pipeline.run()
def test_run_direct(self):
    """LineSource splits a temp file into one element per line."""
    path = self._create_temp_file('aaaa\nbbbb\ncccc\ndddd')
    p = TestPipeline()
    lines = p | beam.io.Read(LineSource(path))
    assert_that(lines, equal_to(['aaaa', 'bbbb', 'cccc', 'dddd']))
    p.run()
def test_reshuffle_window_fn_preserved(self):
    # Reshuffle must preserve both the assigned windows/timestamps and the
    # window fn itself (verified by grouping afterwards).
    pipeline = TestPipeline()
    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    # Each element is timestamped with its second field and windowed into
    # 2-second sessions (see the 'window' step below).
    expected_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
        ((1, 1), 1.0, IntervalWindow(1.0, 3.0)),
        ((2, 1), 1.0, IntervalWindow(1.0, 3.0)),
        ((3, 1), 1.0, IntervalWindow(1.0, 3.0)),
        ((1, 2), 2.0, IntervalWindow(2.0, 4.0)),
        ((2, 2), 2.0, IntervalWindow(2.0, 4.0)),
        ((1, 4), 4.0, IntervalWindow(4.0, 6.0))]]
    # After GroupByKey, overlapping sessions for the same key merge.
    expected_merged_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
        ((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
        ((2, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
        ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
        ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
    before_reshuffle = (pipeline
                        | 'start' >> beam.Create(data)
                        | 'add_timestamp' >> beam.Map(
                            lambda v: TimestampedValue(v, v[1]))
                        | 'window' >> beam.WindowInto(Sessions(gap_size=2)))
    assert_that(before_reshuffle, equal_to(expected_windows),
                label='before_reshuffle', reify_windows=True)
    after_reshuffle = (before_reshuffle | 'reshuffle' >> beam.Reshuffle())
    # Windows and timestamps must be unchanged by the reshuffle.
    assert_that(after_reshuffle, equal_to(expected_windows),
                label='after_reshuffle', reify_windows=True)
    after_group = (after_reshuffle | 'group_by_key' >> beam.GroupByKey())
    # Sessions still merge correctly, proving the window fn survived.
    assert_that(after_group, equal_to(expected_merged_windows),
                label='after_group', reify_windows=True)
    pipeline.run()
def run_pipeline(self, count_implementation, factor=1):
    """Run `count_implementation` over a fixed word list and check counts.

    `factor` scales the expected per-word counts for implementations that
    multiply their output.
    """
    pipeline = TestPipeline()
    words = pipeline | beam.Create(['CAT', 'DOG', 'CAT', 'CAT', 'DOG'])
    counted = words | count_implementation
    assert_that(
        counted, equal_to([('CAT', (3 * factor)), ('DOG', (2 * factor))]))
    pipeline.run()
def model_co_group_by_key_tuple(email_list, phone_list, output_path):
    """Applying a CoGroupByKey Transform to a tuple."""
    import apache_beam as beam
    p = TestPipeline()  # Use TestPipeline for testing.
    # [START model_group_by_key_cogroupbykey_tuple]
    # Each data set is represented by key-value pairs in separate PCollections.
    # Both data sets share a common key type (in this example str).
    # The email_list contains values such as: ('joe', '*****@*****.**') with
    # multiple possible values for each key.
    # The phone_list contains values such as: ('mary': '111-222-3333') with
    # multiple possible values for each key.
    emails = p | 'email' >> beam.Create(email_list)
    phones = p | 'phone' >> beam.Create(phone_list)
    # The result PCollection contains one key-value element for each key in the
    # input PCollections. The key of the pair will be the key from the input and
    # the value will be a dictionary with two entries: 'emails' - an iterable of
    # all values for the current key in the emails PCollection and 'phones': an
    # iterable of all values for the current key in the phones PCollection.
    # For instance, if 'emails' contained ('joe', '*****@*****.**') and
    # ('joe', '*****@*****.**'), then 'result' will contain the element
    # ('joe', {'emails': ['*****@*****.**', '*****@*****.**'], 'phones': ...})
    result = {'emails': emails, 'phones': phones} | beam.CoGroupByKey()

    def join_info(name_info):
        # Tuple-parameter function signatures (`def f((a, b))`) were removed
        # in Python 3 (PEP 3113); unpack the pair explicitly instead.
        (name, info) = name_info
        return '; '.join(['%s' % name,
                          '%s' % ','.join(info['emails']),
                          '%s' % ','.join(info['phones'])])

    contact_lines = result | beam.Map(join_info)
    # [END model_group_by_key_cogroupbykey_tuple]
    contact_lines | beam.io.WriteToText(output_path)
    p.run()
def test_write_messages_unsupported_features(self, mock_pubsub):
    """WriteToPubSub rejects id_label and timestamp_attribute at run time."""
    payloads = [PubsubMessage(b'data', {'key': 'value'})]

    # id_label is not implemented by this runner.
    opts = PipelineOptions([])
    opts.view_as(StandardOptions).streaming = True
    pipeline = TestPipeline(options=opts)
    _ = (pipeline
         | Create(payloads)
         | WriteToPubSub('projects/fakeprj/topics/a_topic',
                         id_label='a_label'))
    with self.assertRaisesRegexp(NotImplementedError,
                                 r'id_label is not supported'):
        pipeline.run()

    # Neither is timestamp_attribute.
    opts = PipelineOptions([])
    opts.view_as(StandardOptions).streaming = True
    pipeline = TestPipeline(options=opts)
    _ = (pipeline
         | Create(payloads)
         | WriteToPubSub('projects/fakeprj/topics/a_topic',
                         timestamp_attribute='timestamp'))
    with self.assertRaisesRegexp(NotImplementedError,
                                 r'timestamp_attribute is not supported'):
        pipeline.run()
def test_runtime_checks_on(self):
    """With runtime_type_check, a wrong output type fails at execution."""
    # pylint: disable=expression-not-assigned
    p = TestPipeline(options=PipelineOptions(runtime_type_check=True))
    with self.assertRaises(typehints.TypeCheckError):
        # [START type_hints_runtime_on]
        p | beam.Create(['a']) | beam.Map(lambda x: 3).with_output_types(str)
        p.run()
def test_compute_top_sessions(self):
    """ComputeTopSessions over the fixture edits yields the expected output."""
    pipeline = TestPipeline()
    edits = pipeline | beam.Create(self.EDITS)
    sessions = edits | top_wikipedia_sessions.ComputeTopSessions(1.0)
    assert_that(sessions, equal_to(self.EXPECTED))
    pipeline.run()
def test_read_messages_timestamp_attribute_rfc3339_success(self, mock_pubsub):
    """An RFC 3339 timestamp attribute becomes the element timestamp."""
    data = 'data'
    message_id = 'message_id'
    attributes = {'time': '2018-03-12T13:37:01.234567Z'}
    publish_time = '2018-03-12T13:37:01.234567Z'
    payloads = [
        create_client_message(data, message_id, attributes, publish_time)]
    # The element timestamp must come from the 'time' attribute, not the
    # publish time.
    expected_elements = [
        TestWindowedValue(
            PubsubMessage(data, attributes),
            timestamp.Timestamp.from_rfc3339(attributes['time']),
            [window.GlobalWindow()]),
    ]
    mock_pubsub.Client = functools.partial(FakePubsubClient, payloads)
    mock_pubsub.subscription.AutoAck = FakeAutoAck

    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    messages = (p
                | ReadFromPubSub(
                    'projects/fakeprj/topics/a_topic', None, 'a_label',
                    with_attributes=True, timestamp_attribute='time'))
    assert_that(messages, equal_to(expected_elements), reify_windows=True)
    p.run()
def test_basic_execution(self):
    # Stream: watermark 10 -> a, b, c; watermark 20 -> d, e; then a late
    # element at ts=12 and a final element at ts=310.
    test_stream = (TestStream()
                   .advance_watermark_to(10)
                   .add_elements(['a', 'b', 'c'])
                   .advance_watermark_to(20)
                   .add_elements(['d'])
                   .add_elements(['e'])
                   .advance_processing_time(10)
                   .advance_watermark_to(300)
                   .add_elements([TimestampedValue('late', 12)])
                   .add_elements([TimestampedValue('last', 310)]))

    class RecordFn(beam.DoFn):
        # Pairs each element with the timestamp the runner assigned to it.
        def process(self, element=beam.DoFn.ElementParam,
                    timestamp=beam.DoFn.TimestampParam):
            yield (element, timestamp)

    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)
    my_record_fn = RecordFn()
    records = p | test_stream | beam.ParDo(my_record_fn)
    # Untimestamped elements get the watermark position at insertion time;
    # TimestampedValues keep their explicit timestamps.
    assert_that(records, equal_to([
        ('a', timestamp.Timestamp(10)),
        ('b', timestamp.Timestamp(10)),
        ('c', timestamp.Timestamp(10)),
        ('d', timestamp.Timestamp(20)),
        ('e', timestamp.Timestamp(20)),
        ('late', timestamp.Timestamp(12)),
        ('last', timestamp.Timestamp(310)),]))
    p.run()
def test_read_messages_timestamp_attribute_missing(self, mock_pubsub):
    """Falls back to the publish time when the timestamp attribute is absent."""
    # Pub/Sub payloads are bytes; the sibling tests in this file
    # (e.g. the milli-success and read-data tests) pass bytes to the same
    # test_utils.PullResponseMessage helper.
    data = b'data'
    attributes = {}
    publish_time_secs = 1520861821
    publish_time_nanos = 234567000
    publish_time = '2018-03-12T13:37:01.234567Z'
    ack_id = 'ack_id'
    pull_response = test_utils.create_pull_response([
        test_utils.PullResponseMessage(
            data, attributes, publish_time_secs, publish_time_nanos, ack_id)
    ])
    # With no 'nonexistent' attribute, the message keeps the publish time.
    expected_elements = [
        TestWindowedValue(
            PubsubMessage(data, attributes),
            timestamp.Timestamp.from_rfc3339(publish_time),
            [window.GlobalWindow()]),
    ]
    mock_pubsub.return_value.pull.return_value = pull_response

    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    pcoll = (p
             | ReadFromPubSub(
                 'projects/fakeprj/topics/a_topic', None, None,
                 with_attributes=True, timestamp_attribute='nonexistent'))
    assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
    p.run()
    mock_pubsub.return_value.acknowledge.assert_has_calls([
        mock.call(mock.ANY, [ack_id])])
def test_read_messages_timestamp_attribute_milli_success(self, mock_pubsub):
    """An integer-milliseconds timestamp attribute becomes the timestamp."""
    data = b'data'
    attributes = {'time': '1337'}
    publish_time_secs = 1520861821
    publish_time_nanos = 234567000
    ack_id = 'ack_id'
    pull_response = test_utils.create_pull_response([
        test_utils.PullResponseMessage(
            data, attributes, publish_time_secs, publish_time_nanos, ack_id)
    ])
    # 'time' holds milliseconds; the expected timestamp converts to micros.
    expected_elements = [
        TestWindowedValue(
            PubsubMessage(data, attributes),
            timestamp.Timestamp(micros=int(attributes['time']) * 1000),
            [window.GlobalWindow()]),
    ]
    mock_pubsub.return_value.pull.return_value = pull_response

    opts = PipelineOptions([])
    opts.view_as(StandardOptions).streaming = True
    pipeline = TestPipeline(options=opts)
    messages = (pipeline
                | ReadFromPubSub(
                    'projects/fakeprj/topics/a_topic', None, None,
                    with_attributes=True, timestamp_attribute='time'))
    assert_that(messages, equal_to(expected_elements), reify_windows=True)
    pipeline.run()
    mock_pubsub.return_value.acknowledge.assert_has_calls([
        mock.call(mock.ANY, [ack_id])])
def test_basic_execution_sideinputs(self):
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)

    # Main stream emits a single element once the watermark reaches 10.
    main_stream = (p
                   | 'main TestStream' >> TestStream()
                   .advance_watermark_to(10)
                   .add_elements(['e']))
    # Side stream supplies timestamped values to be consumed as a side input.
    side_stream = (p
                   | 'side TestStream' >> TestStream()
                   .add_elements([window.TimestampedValue(2, 2)])
                   .add_elements([window.TimestampedValue(1, 1)])
                   .add_elements([window.TimestampedValue(7, 7)])
                   .add_elements([window.TimestampedValue(4, 4)])
                   )

    class RecordFn(beam.DoFn):
        # Records the element, its timestamp, and the side input it saw.
        def process(self,
                    elm=beam.DoFn.ElementParam,
                    ts=beam.DoFn.TimestampParam,
                    side=beam.DoFn.SideInputParam):
            yield (elm, ts, side)

    records = (main_stream  # pylint: disable=unused-variable
               | beam.ParDo(RecordFn(), beam.pvalue.AsList(side_stream)))

    # The side-input list preserves the order the elements were added in.
    assert_that(records, equal_to([('e', Timestamp(10), [2, 1, 7, 4])]))

    p.run()
def test_on_direct_runner(self):
    """_NativeWrite routes every element through the sink's writer."""

    class FakeSink(NativeSink):
        """A fake sink outputing a number of elements."""

        def __init__(self):
            # Shared list the writer appends into; inspected after run().
            self.written_values = []
            self.writer_instance = FakeSinkWriter(self.written_values)

        def writer(self):
            return self.writer_instance

    class FakeSinkWriter(NativeSinkWriter):
        """A fake sink writer for testing."""

        def __init__(self, written_values):
            self.written_values = written_values

        def __enter__(self):
            return self

        def __exit__(self, *unused_args):
            pass

        def Write(self, value):
            self.written_values.append(value)

    pipeline = TestPipeline()
    sink = FakeSink()
    # pylint: disable=expression-not-assigned
    pipeline | Create(['a', 'b', 'c']) | _NativeWrite(sink)
    pipeline.run()

    self.assertEqual(['a', 'b', 'c'], sink.written_values)
def test_read_from_text_single_file(self):
    """A single text file is read back line-for-line."""
    file_name, expected_data = write_data(5)
    assert len(expected_data) == 5
    p = TestPipeline()
    lines = p | 'Read' >> ReadFromText(file_name)
    assert_that(lines, equal_to(expected_data))
    p.run()
def test_read_from_text_with_file_name_single_file(self):
    """Each output element pairs the source file name with a line."""
    file_name, data = write_data(5)
    expected_data = [(file_name, el) for el in data]
    assert len(expected_data) == 5
    p = TestPipeline()
    pairs = p | 'Read' >> ReadFromTextWithFilename(file_name)
    assert_that(pairs, equal_to(expected_data))
    p.run()
def test_read_all_single_file(self):
    """ReadAllFromText reads files named by an upstream PCollection."""
    file_name, expected_data = write_data(5)
    assert len(expected_data) == 5
    p = TestPipeline()
    lines = (p
             | 'Create' >> Create([file_name])
             | 'ReadAll' >> ReadAllFromText())
    assert_that(lines, equal_to(expected_data))
    p.run()
def test_reshuffle_contents_unchanged(self):
    """Reshuffle redistributes elements but never alters them."""
    p = TestPipeline()
    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
    shuffled = (p
                | 'start' >> beam.Create(data)
                | 'reshuffle' >> beam.Reshuffle())
    assert_that(shuffled, equal_to(data))
    p.run()
def test_pipeline_read_single_file(self):
    """ReadFromVcf yields one record per non-header line of a VCF file."""
    with TempDir() as tempdir:
        vcf_path = self._create_temp_vcf_file(
            _SAMPLE_HEADER_LINES + _SAMPLE_TEXT_LINES, tempdir)
        p = TestPipeline()
        records = p | 'Read' >> ReadFromVcf(vcf_path)
        assert_that(records, _count_equals_to(len(_SAMPLE_TEXT_LINES)))
        p.run()
def test_read_message_id_label_unsupported(self, unused_mock_pubsub):
    # id_label is unsupported in DirectRunner.
    pipeline = TestPipeline()
    pipeline.options.view_as(StandardOptions).streaming = True
    _ = (pipeline
         | ReadFromPubSub('projects/fakeprj/topics/a_topic', None,
                          'a_label'))
    with self.assertRaisesRegexp(NotImplementedError,
                                 r'id_label is not supported'):
        pipeline.run()
def test_top_py2(self):
    """Top.Of / Top.PerKey with a Python-2-style binary comparator."""
    pipeline = TestPipeline()

    # A parameter we'll be sharing with a custom comparator.
    names = {
        0: 'zo',
        1: 'one',
        2: 'twoo',
        3: 'three',
        5: 'fiiive',
        6: 'sssssix',
        9: 'nniiinne'
    }

    # First for global combines.
    source = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
    result_cmp = source | 'cmp' >> combine.Top.Of(
        6,
        lambda a, b, names: len(names[a]) < len(names[b]),
        names)  # Note parameter passed to comparator.
    result_cmp_rev = source | 'cmp_rev' >> combine.Top.Of(
        3,
        lambda a, b, names: len(names[a]) < len(names[b]),
        names,  # Note parameter passed to comparator.
        reverse=True)
    assert_that(result_cmp, equal_to([[9, 6, 6, 5, 3, 2]]),
                label='assert:cmp')
    assert_that(result_cmp_rev, equal_to([[0, 1, 1]]),
                label='assert:cmp_rev')

    # Again for per-key combines.
    keyed_source = pipeline | 'start-perkye' >> Create(
        [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
    result_key_cmp = keyed_source | 'cmp-perkey' >> combine.Top.PerKey(
        6,
        lambda a, b, names: len(names[a]) < len(names[b]),
        names)  # Note parameter passed to comparator.
    assert_that(result_key_cmp, equal_to([('a', [9, 6, 6, 5, 3, 2])]),
                label='key:cmp')
    pipeline.run()
def test_multiple_outputs(self):
    """Tests that the TestStream supports emitting to multiple PCollections."""
    letters_elements = [
        TimestampedValue('a', 6),
        TimestampedValue('b', 7),
        TimestampedValue('c', 8),
    ]
    numbers_elements = [
        TimestampedValue('1', 11),
        TimestampedValue('2', 12),
        TimestampedValue('3', 13),
    ]
    # Tagged calls route each watermark/element batch to its own output.
    test_stream = (TestStream()
                   .advance_watermark_to(5, tag='letters')
                   .add_elements(letters_elements, tag='letters')
                   .advance_watermark_to(10, tag='numbers')
                   .add_elements(numbers_elements, tag='numbers'))  # yapf: disable

    class RecordFn(beam.DoFn):
        # Pairs each element with its runner-assigned timestamp.
        def process(self, element=beam.DoFn.ElementParam,
                    timestamp=beam.DoFn.TimestampParam):
            yield (element, timestamp)

    options = StandardOptions(streaming=True)
    p = TestPipeline(options=options)

    main = p | test_stream
    # Each tag is addressable as a separate output PCollection.
    letters = main['letters'] | 'record letters' >> beam.ParDo(RecordFn())
    numbers = main['numbers'] | 'record numbers' >> beam.ParDo(RecordFn())

    assert_that(letters, equal_to([('a', Timestamp(6)),
                                   ('b', Timestamp(7)),
                                   ('c', Timestamp(8))]),
                label='assert letters')
    assert_that(numbers, equal_to([('1', Timestamp(11)),
                                   ('2', Timestamp(12)),
                                   ('3', Timestamp(13))]),
                label='assert numbers')

    p.run()
def test_read_with_query_batch(self, mock_batch_snapshot_class,
                               mock_client_class):
    """Reads rows via sql=, read_operations=, and a PCollection of reads."""
    mock_snapshot = mock.MagicMock()
    mock_snapshot.generate_query_batches.return_value = [{
        'query': {'sql': 'SELECT * FROM users'},
        'partition': 'test_partition'
    } for _ in range(3)]
    mock_snapshot.process_query_batch.side_effect = [
        FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
    ]

    ro = [ReadOperation.query("Select * from users")]
    pipeline = TestPipeline()

    read = (pipeline
            | 'read' >> ReadFromSpanner(TEST_PROJECT_ID, TEST_INSTANCE_ID,
                                        _generate_database_name(),
                                        sql="SELECT * FROM users"))
    readall = (pipeline
               | 'read all' >> ReadFromSpanner(TEST_PROJECT_ID,
                                               TEST_INSTANCE_ID,
                                               _generate_database_name(),
                                               read_operations=ro))
    readpipeline = (
        pipeline
        | 'create reads' >> beam.Create(ro)
        | 'reads' >> ReadFromSpanner(TEST_PROJECT_ID, TEST_INSTANCE_ID,
                                     _generate_database_name()))

    # Bug fix: assert_that installs verification transforms into the
    # pipeline graph, so it must be applied BEFORE run(). The original
    # called pipeline.run() first, which meant none of these three checks
    # was ever executed.
    assert_that(read, equal_to(FAKE_ROWS), label='checkRead')
    assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')
    assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
    pipeline.run()
def test_gbk_execution_no_triggers(self):
    # Stream: a,b,c at watermark 10; d,e at 20; a late element at ts=12
    # after the watermark has passed 300; and a final element at ts=310.
    test_stream = (TestStream().advance_watermark_to(10).add_elements([
        'a', 'b', 'c'
    ]).advance_watermark_to(20).add_elements(['d']).add_elements([
        'e'
    ]).advance_processing_time(10).advance_watermark_to(300).add_elements([
        TimestampedValue('late', 12)
    ]).add_elements([TimestampedValue('last', 310)
                     ]).advance_watermark_to_infinity())
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)
    # allowed_lateness=300 lets the 'late' element (ts=12, arriving after
    # watermark 300) still be grouped into its original window.
    records = (p
               | test_stream
               | beam.WindowInto(FixedWindows(15), allowed_lateness=300)
               | beam.Map(lambda x: ('k', x))
               | beam.GroupByKey())

    # TODO(https://github.com/apache/beam/issues/18441): timestamp assignment
    # for elements from a GBK should respect the TimestampCombiner.  The test
    # below should also verify the timestamps of the outputted elements once
    # this is implemented.

    # assert per window
    expected_window_to_elements = {
        window.IntervalWindow(0, 15): [
            ('k', ['a', 'b', 'c']),
            ('k', ['late']),
        ],
        window.IntervalWindow(15, 30): [
            ('k', ['d', 'e']),
        ],
        window.IntervalWindow(300, 315): [
            ('k', ['last']),
        ],
    }
    assert_that(
        records,
        equal_to_per_window(expected_window_to_elements),
        label='assert per window')

    p.run()
def test_gbk_execution_after_processing_trigger_fired(self): """Advance TestClock to (X + delta) and see the pipeline does finish.""" # TODO(mariagh): Add test_gbk_execution_after_processing_trigger_unfired # Advance TestClock to (X + delta) and see the pipeline does finish # Possibly to the framework trigger_transcripts.yaml test_stream = (TestStream() .advance_watermark_to(10) .add_elements(['a']) .advance_processing_time(5.1)) options = PipelineOptions() options.view_as(StandardOptions).streaming = True p = TestPipeline(options=options) records = (p | test_stream | beam.WindowInto( beam.window.FixedWindows(15), trigger=trigger.AfterProcessingTime(5), accumulation_mode=trigger.AccumulationMode.DISCARDING ) | beam.Map(lambda x: ('k', x)) | beam.GroupByKey()) # TODO(BEAM-2519): timestamp assignment for elements from a GBK should # respect the TimestampCombiner. The test below should also verify the # timestamps of the outputted elements once this is implemented. assert_that(records, equal_to([ ('k', ['a'])])) # assert per window expected_window_to_elements = { window.IntervalWindow(15, 30): [('k', ['a'])], } assert_that( records, equal_to_per_window(expected_window_to_elements), custom_windowing=window.FixedWindows(15), label='assert per window') p.run()
def testCalculateProgramMetricCombinations(self):
    """Tests the CalculateProgramMetricCombinations DoFn."""
    fake_person = StatePerson.new_with_defaults(
        person_id=123, gender=Gender.MALE,
        birthdate=date(1970, 1, 1),
        residency_status=ResidencyStatus.PERMANENT)

    # One referral event and one participation event for the same program.
    program_events = [
        ProgramReferralEvent(
            state_code='US_TX',
            event_date=date(2011, 4, 3),
            program_id='program',
            participation_status=StateProgramAssignmentParticipationStatus.
            IN_PROGRESS),
        ProgramParticipationEvent(state_code='US_TX',
                                  event_date=date(2011, 6, 3),
                                  program_id='program')
    ]

    # Each event will be have an output for each methodology type
    expected_metric_count = 2

    expected_combination_counts = \
        {'referrals': expected_metric_count,
         'participation': expected_metric_count}

    test_pipeline = TestPipeline()

    output = (test_pipeline
              | beam.Create([(fake_person, program_events)])
              | 'Calculate Program Metrics' >> beam.ParDo(
                  pipeline.CalculateProgramMetricCombinations(),
                  None, -1, ALL_METRIC_INCLUSIONS_DICT))

    assert_that(output, AssertMatchers.count_combinations(
        expected_combination_counts),
                'Assert number of metrics is expected value')

    test_pipeline.run()
def test_top(self):
    """Largest/Smallest return sorted extremes, globally and per key."""
    pipeline = TestPipeline()

    # First for global combines.
    source = pipeline | 'start' >> Create([6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
    largest = source | 'top' >> combine.Top.Largest(5)
    smallest = source | 'bot' >> combine.Top.Smallest(4)
    assert_that(largest, equal_to([[9, 6, 6, 5, 3]]), label='assert:top')
    assert_that(smallest, equal_to([[0, 1, 1, 1]]), label='assert:bot')

    # Again for per-key combines.
    keyed = pipeline | 'start-perkey' >> Create(
        [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
    largest_per_key = keyed | 'top-perkey' >> combine.Top.LargestPerKey(5)
    smallest_per_key = keyed | 'bot-perkey' >> combine.Top.SmallestPerKey(4)
    assert_that(largest_per_key, equal_to([('a', [9, 6, 6, 5, 3])]),
                label='key:top')
    assert_that(smallest_per_key, equal_to([('a', [0, 1, 1, 1])]),
                label='key:bot')
    pipeline.run()
def test_convert_variant_to_bigquery_row(self):
    """Each processed variant converts to its expected BigQuery row."""
    variant_1, row_1 = self._get_sample_variant_1()
    variant_2, row_2 = self._get_sample_variant_2()
    variant_3, row_3 = self._get_sample_variant_3()
    header_fields = vcf_header_parser.HeaderFields({}, {})
    # One factory per variant, matching the original per-variant setup.
    proc_variants = [
        processed_variant.ProcessedVariantFactory(
            header_fields).create_processed_variant(v)
        for v in (variant_1, variant_2, variant_3)]
    pipeline = TestPipeline()
    bigquery_rows = (
        pipeline
        | Create(proc_variants)
        | 'ConvertToRow' >> ParDo(ConvertToBigQueryTableRow(
            mock_bigquery_schema_descriptor.MockSchemaDescriptor())))
    assert_that(bigquery_rows, equal_to([row_1, row_2, row_3]))
    pipeline.run()
def test_partition_variants(self):
    """Every variant lands in the partition its partitioner assigns."""
    expected_partitions = self._get_standard_variant_partitions()
    expected_partitions.update(self._get_nonstandard_variant_partitions())
    variants = [variant
                for variant_list in expected_partitions.values()
                for variant in variant_list]

    partitioner = variant_partition.VariantPartition()
    pipeline = TestPipeline()
    partitions = (pipeline
                  | Create(variants)
                  | 'PartitionVariants' >> Partition(
                      partition_variants.PartitionVariants(partitioner),
                      partitioner.get_num_partitions()))
    # `xrange` was removed in Python 3; `range` is equivalent here.
    for i in range(partitioner.get_num_partitions()):
        assert_that(partitions[i], equal_to(expected_partitions.get(i, [])),
                    label=str(i))
    pipeline.run()
def test_approximate_unique_global_by_sample_size_with_duplicates(self):
    # test if estimation error with a given sample size is not greater than
    # expected max error with duplicated input.
    sample_size = 30
    max_err = 2 / math.sqrt(sample_size)
    # Only two distinct values, each heavily duplicated.
    test_input = [10] * 50 + [20] * 50
    actual_count = len(set(test_input))

    p = TestPipeline()
    within_bounds = (
        p
        | 'create' >> beam.Create(test_input)
        | 'get_estimate' >> beam.ApproximateUnique.Globally(size=sample_size)
        | 'compare' >> beam.FlatMap(
            lambda x: [abs(x - actual_count) * 1.0 / actual_count
                       <= max_err]))

    assert_that(within_bounds, equal_to([True]),
                label='assert:global_by_size_with_duplicates')
    p.run()
def test_metrics_in_fake_source(self):
    # FakeSource mock requires DirectRunner.
    p = TestPipeline(runner='DirectRunner')
    elements = p | Read(FakeSource([1, 2, 3, 4, 5, 6]))
    assert_that(elements, equal_to([1, 2, 3, 4, 5, 6]))
    run_result = p.run()

    # The source step should report one 'outputs' counter with value 6.
    metric_results = run_result.metrics().query()
    outputs_counter = metric_results['counters'][0]
    self.assertEqual(outputs_counter.key.step, 'Read')
    self.assertEqual(outputs_counter.key.metric.name, 'outputs')
    self.assertEqual(outputs_counter.committed, 6)
def test_kafkaio_write(self):
  """Runs the cross-language Kafka write pipeline against a local broker."""
  local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR')
  with self.local_kafka_service(local_kafka_jar) as kafka_port:
    pipeline = TestPipeline()
    pipeline.not_use_test_runner_api = True
    bootstrap_server = '%s:%s' % (self.get_platform_localhost(), kafka_port)
    kafka_io = CrossLanguageKafkaIO(bootstrap_server, 'xlang_kafkaio_test')
    kafka_io.build_write_pipeline(pipeline)
    run_result = pipeline.run()
    run_result.wait_until_finish()
def testClassifyIncarcerationEvents_NoSentenceGroups(self):
  """A person with no sentence groups yields no incarceration events."""
  person = StatePerson.new_with_defaults(
      person_id=123,
      gender=Gender.MALE,
      birthdate=date(1970, 1, 1),
      residency_status=ResidencyStatus.PERMANENT)
  person_entities = {'person': [person], 'sentence_groups': []}

  test_pipeline = TestPipeline()
  output = (
      test_pipeline
      | beam.Create([(person.person_id, person_entities)])
      | 'Identify Incarceration Events' >> beam.ParDo(
          pipeline.ClassifyIncarcerationEvents(), {}))

  assert_that(output, equal_to([]))
  test_pipeline.run()
def test_read_data_success(self, mock_pubsub):
  """ReadFromPubSub emits the pulled message payload and acks the message."""
  data_encoded = u'🤷 ¯\\_(ツ)_/¯'.encode('utf-8')
  ack_id = 'ack_id'
  pull_response = test_utils.create_pull_response([
      test_utils.PullResponseMessage(data_encoded, ack_id=ack_id)])
  expected_elements = [data_encoded]
  # Make the mocked Pub/Sub client hand the canned response to the source.
  mock_pubsub.return_value.pull.return_value = pull_response

  options = PipelineOptions([])
  options.view_as(StandardOptions).streaming = True
  p = TestPipeline(options=options)
  pcoll = (p
           | ReadFromPubSub('projects/fakeprj/topics/a_topic', None, None))
  assert_that(pcoll, equal_to(expected_elements))
  p.run()

  # After the run, the source must have acknowledged the consumed message
  # and closed the transport channel.
  mock_pubsub.return_value.acknowledge.assert_has_calls([
      mock.call(mock.ANY, [ack_id])])
  mock_pubsub.return_value.api.transport.channel.close.assert_has_calls([
      mock.call()])
def test_approximate_unique_global_by_sample_size_with_big_population(self):
  """Estimation error stays within bound with a sample of ~1% of the input."""
  sample_size = 100
  expected_max_err = 2 / math.sqrt(sample_size)
  test_input = [random.randint(0, 1000) for _ in range(10000)]
  distinct_count = len(set(test_input))

  pipeline = TestPipeline()
  within_bound = (
      pipeline
      | 'create' >> beam.Create(test_input)
      | 'get_estimate' >> beam.ApproximateUnique.Globally(size=sample_size)
      | 'compare' >> beam.FlatMap(
          lambda estimate:
          [abs(estimate - distinct_count) * 1.0 / distinct_count
           <= expected_max_err]))

  assert_that(within_bound, equal_to([True]),
              label='assert:global_by_sample_size_with_big_population')
  pipeline.run()
def test_shard_variants(self):
  """Each variant lands in the shard chosen by the sharding config."""
  expected_shards = self._get_expected_variant_shards()
  variants = [variant
              for shard_variants_list in expected_shards.values()
              for variant in shard_variants_list]
  sharding = variant_sharding.VariantSharding(
      'gcp_variant_transforms/data/sharding_configs/'
      'homo_sapiens_default.yaml')
  pipeline = TestPipeline()
  shards = (
      pipeline
      | Create(variants, reshuffle=False)
      | 'ShardVariants' >> beam.Partition(
          shard_variants.ShardVariants(sharding),
          sharding.get_num_shards()))
  for shard_index in range(sharding.get_num_shards()):
    assert_that(shards[shard_index],
                equal_to(expected_shards.get(shard_index, [])),
                label=str(shard_index))
  pipeline.run()
def test_densify_variants_pipeline(self):
  """DensifyVariants gives every variant a call for every known call name."""
  call_names = ['sample1', 'sample2', 'sample3']
  calls = [vcfio.VariantCall(name=name) for name in call_names]
  variants = [
      vcfio.Variant(calls=[calls[0], calls[1]]),
      vcfio.Variant(calls=[calls[1], calls[2]]),
  ]
  pipeline = TestPipeline()
  densified_variants = (
      pipeline
      | Create(variants)
      | 'DensifyVariants' >> densify_variants.DensifyVariants())
  assert_that(densified_variants, asserts.has_calls(call_names))
  pipeline.run()
def test_read_pattern_gzip(self):
  """LineSource reads a pattern of gzip files back to the original lines."""
  _, lines = write_data(200)
  splits = [0, 34, 100, 140, 164, 188, 200]
  # Pair consecutive split points to slice the lines into chunks.
  chunks = [lines[start:end] for start, end in zip(splits, splits[1:])]
  compressed_chunks = []
  for chunk in chunks:
    buf = io.BytesIO()
    with gzip.GzipFile(fileobj=buf, mode="wb") as gz:
      gz.write(b'\n'.join(chunk))
    compressed_chunks.append(buf.getvalue())
  file_pattern = write_prepared_pattern(compressed_chunks)
  pipeline = TestPipeline()
  pcoll = pipeline | 'Read' >> beam.io.Read(
      LineSource(file_pattern,
                 splittable=False,
                 compression_type=CompressionTypes.GZIP))
  assert_that(pcoll, equal_to(lines))
  pipeline.run()
def testProduceViolationMetricsNoInput(self) -> None:
  """An empty input PCollection produces no violation metrics."""
  test_pipeline = TestPipeline()
  output = (
      test_pipeline
      | beam.Create([])
      | beam.ParDo(ExtractPersonEventsMetadata())
      | "Produce ViolationMetrics" >> beam.ParDo(
          ProduceMetrics(),
          self.pipeline_config,
          ALL_METRIC_INCLUSIONS_DICT,
          test_pipeline_options(),
          None,
          -1))
  assert_that(output, equal_to([]))
  test_pipeline.run()
def test_call_names_combiner_pipeline_preserve_call_names_order_error(self):
  """Combining raises when call names are not in one consistent order."""
  call_names = ['sample1', 'sample2', 'sample3']
  calls = [vcfio.VariantCall(name=name) for name in call_names]
  variants = [
      vcfio.Variant(calls=[calls[0], calls[1]]),
      vcfio.Variant(calls=[calls[1], calls[2]]),
  ]
  pipeline = TestPipeline()
  _ = (pipeline
       | transforms.Create(variants)
       | 'CombineCallNames' >> combine_call_names.CallNamesCombiner(
           preserve_call_names_order=True))
  with self.assertRaises(ValueError):
    pipeline.run()
def test_pipeline_read_file_pattern(self):
  """ReadVcfHeaders over a glob pattern yields the headers of every file."""
  with temp_dir.TempDir() as tempdir:
    header_sets = [
        [self.lines[1], self.lines[-1]],
        [self.lines[2], self.lines[3], self.lines[-1]],
        [self.lines[4], self.lines[-1]],
    ]
    file_names = [tempdir.create_temp_file(suffix='.vcf', lines=headers)
                  for headers in header_sets]

    pipeline = TestPipeline()
    pcoll = pipeline | 'ReadHeaders' >> ReadVcfHeaders(
        os.path.join(tempdir.get_path(), '*.vcf'))

    expected = [_get_vcf_header_from_lines(headers, file_name=name)
                for headers, name in zip(header_sets, file_names)]
    assert_that(pcoll, asserts.header_vars_equal(expected))
    pipeline.run()
def test_reshuffle_windows_unchanged(self):
  """Reshuffle must preserve windows, timestamps and grouped values."""
  pipeline = TestPipeline()
  data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
  expected_data = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
      ((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
      ((2, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
      ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
      ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
  before_reshuffle = (pipeline
                      | 'start' >> beam.Create(data)
                      | 'add_timestamp' >> beam.Map(
                          lambda v: beam.window.TimestampedValue(v, v[1]))
                      | 'window' >> beam.WindowInto(Sessions(gap_size=2))
                      | 'group_by_key' >> beam.GroupByKey())
  assert_that(before_reshuffle, equal_to(expected_data),
              label='before_reshuffle', reify_windows=True)
  after_reshuffle = before_reshuffle | beam.Reshuffle()
  # Label made consistent with 'before_reshuffle' (was 'after reshuffle').
  assert_that(after_reshuffle, equal_to(expected_data),
              label='after_reshuffle', reify_windows=True)
  pipeline.run()
def test_reshuffle_sliding_window(self):
  """Reshuffle must not re-apply the sliding window function."""
  pipeline = TestPipeline()
  data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
  window_size = 2
  # Each element belongs to `window_size` sliding windows, hence the repeat.
  expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])] * window_size
  before_reshuffle = (pipeline
                      | 'start' >> beam.Create(data)
                      | 'window' >> beam.WindowInto(
                          SlidingWindows(size=window_size, period=1))
                      | 'group_by_key' >> beam.GroupByKey())
  assert_that(before_reshuffle, equal_to(expected_data),
              label='before_reshuffle')
  after_reshuffle = (before_reshuffle | 'reshuffle' >> beam.Reshuffle())
  # If Reshuffle applies the sliding window function a second time there
  # should be extra values for each key.
  # Label made consistent with 'before_reshuffle' (was 'after reshuffle').
  assert_that(after_reshuffle, equal_to(expected_data),
              label='after_reshuffle')
  pipeline.run()
def test_basic_execution_sideinputs(self):
  """Streaming main stream with a TestStream side input delivered as AsList."""
  # TODO(BEAM-3377): Remove after assert_that in streaming is fixed.
  global result  # pylint: disable=global-variable-undefined
  result = []

  def recorded_elements(elem):
    # Capture each output into the module-level list so it can be checked
    # after p.run() returns (workaround for BEAM-3377).
    result.append(elem)
    return elem

  options = PipelineOptions()
  options.view_as(StandardOptions).streaming = True
  p = TestPipeline(options=options)

  main_stream = (p
                 | 'main TestStream' >> TestStream()
                 .advance_watermark_to(10)
                 .add_elements(['e']))
  side_stream = (p
                 | 'side TestStream' >> TestStream()
                 .add_elements([window.TimestampedValue(2, 2)])
                 .add_elements([window.TimestampedValue(1, 1)])
                 .add_elements([window.TimestampedValue(7, 7)])
                 .add_elements([window.TimestampedValue(4, 4)])
                 )

  class RecordFn(beam.DoFn):
    # Emits (element, timestamp, side-input contents) triples.
    def process(self,
                elm=beam.DoFn.ElementParam,
                ts=beam.DoFn.TimestampParam,
                side=beam.DoFn.SideInputParam):
      yield (elm, ts, side)

  records = (main_stream  # pylint: disable=unused-variable
             | beam.ParDo(RecordFn(), beam.pvalue.AsList(side_stream))
             | beam.Map(recorded_elements))

  p.run()

  # TODO(BEAM-3377): Remove after assert_that in streaming is fixed.
  self.assertEqual([('e', Timestamp(10), [2, 1, 7, 4])], result)
def test_basic_execution_batch_sideinputs_fixed_windows(self):
  """Fixed-windowed batch side input paired with a streaming main stream."""
  # TODO(BEAM-3377): Remove after assert_that in streaming is fixed.
  global result  # pylint: disable=global-variable-undefined
  result = []

  def recorded_elements(elem):
    # Capture each output into the module-level list so it can be checked
    # after p.run() returns (workaround for BEAM-3377).
    result.append(elem)
    return elem

  options = PipelineOptions()
  options.view_as(StandardOptions).streaming = True
  p = TestPipeline(options=options)

  main_stream = (p
                 | 'main TestStream' >> TestStream()
                 .advance_watermark_to(2)
                 .add_elements(['a'])
                 .advance_watermark_to(4)
                 .add_elements(['b'])
                 | 'main window' >> beam.WindowInto(window.FixedWindows(1)))
  # Batch side input windowed into 2-second fixed windows, so each main
  # element only observes the side values in its matching window.
  side = (p
          | beam.Create([2, 1, 4])
          | beam.Map(lambda t: window.TimestampedValue(t, t))
          | beam.WindowInto(window.FixedWindows(2)))

  class RecordFn(beam.DoFn):
    # Emits (element, timestamp, side-input contents) triples.
    def process(self,
                elm=beam.DoFn.ElementParam,
                ts=beam.DoFn.TimestampParam,
                side=beam.DoFn.SideInputParam):
      yield (elm, ts, side)

  records = (main_stream  # pylint: disable=unused-variable
             | beam.ParDo(RecordFn(), beam.pvalue.AsList(side))
             | beam.Map(recorded_elements))

  p.run()

  # TODO(BEAM-3377): Remove after assert_that in streaming is fixed.
  self.assertEqual([('a', Timestamp(2), [2]),
                    ('b', Timestamp(4), [4])], result)
def test_read_bzip2_concat(self):
  """ReadFromText handles a file made of concatenated bzip2 streams."""
  with TempDir() as tempdir:
    # Write each group of lines as its own complete bzip2 file (the three
    # copy-pasted write blocks collapsed into one loop).
    chunk_files = []
    for chunk_lines in (['a', 'b', 'c'], ['p', 'q', 'r'], ['x', 'y', 'z']):
      chunk_file = tempdir.create_temp_file()
      with bz2.BZ2File(chunk_file, 'wb') as dst:
        data = '\n'.join(chunk_lines) + '\n'
        dst.write(data.encode('utf-8'))
      chunk_files.append(chunk_file)

    # Concatenating complete bzip2 files byte-for-byte produces a valid
    # multi-stream bzip2 file.
    final_bzip2_file = tempdir.create_temp_file()
    with open(final_bzip2_file, 'wb') as dst:
      for chunk_file in chunk_files:
        with open(chunk_file, 'rb') as src:
          dst.writelines(src.readlines())

    pipeline = TestPipeline()
    lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
        final_bzip2_file,
        compression_type=beam.io.filesystem.CompressionTypes.BZIP2)

    expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
    assert_that(lines, equal_to(expected))
    pipeline.run()
def test_basic_execution_sideinputs(self):
  """TestStream side input delivered as AsList to a streaming main stream."""
  options = PipelineOptions()
  options.view_as(StandardOptions).streaming = True
  p = TestPipeline(options=options)

  main_stream = (p
                 | 'main TestStream' >> TestStream()
                 .advance_watermark_to(10)
                 .add_elements(['e']))
  side_stream = (p
                 | 'side TestStream' >> TestStream()
                 .add_elements([window.TimestampedValue(2, 2)])
                 .add_elements([window.TimestampedValue(1, 1)])
                 .add_elements([window.TimestampedValue(7, 7)])
                 .add_elements([window.TimestampedValue(4, 4)])
                 )

  class RecordFn(beam.DoFn):
    # Emits (element, timestamp, side-input contents) triples.
    def process(self,
                elm=beam.DoFn.ElementParam,
                ts=beam.DoFn.TimestampParam,
                side=beam.DoFn.SideInputParam):
      yield (elm, ts, side)

  records = (main_stream  # pylint: disable=unused-variable
             | beam.ParDo(RecordFn(), beam.pvalue.AsList(side_stream)))

  # assert per window
  expected_window_to_elements = {
      window.IntervalWindow(0, 15): [
          ('e', Timestamp(10), [2, 1, 7, 4]),
      ],
  }
  assert_that(
      records,
      equal_to_per_window(expected_window_to_elements),
      custom_windowing=window.FixedWindows(15),
      label='assert per window')

  assert_that(records, equal_to([('e', Timestamp(10), [2, 1, 7, 4])]))
  p.run()
def test_gbk_execution_after_watermark_trigger(self):
  """GroupByKey under an AfterWatermark trigger with an early firing."""
  test_stream = (TestStream()
                 .advance_watermark_to(10)
                 .add_elements([TimestampedValue('a', 11)])
                 .advance_watermark_to(20)
                 .add_elements([TimestampedValue('b', 21)])
                 .advance_watermark_to_infinity())

  options = PipelineOptions()
  options.view_as(StandardOptions).streaming = True
  p = TestPipeline(options=options)
  records = (p  # pylint: disable=unused-variable
             | test_stream
             | beam.WindowInto(
                 FixedWindows(15),
                 trigger=trigger.AfterWatermark(early=trigger.AfterCount(1)),
                 accumulation_mode=trigger.AccumulationMode.DISCARDING)
             | beam.Map(lambda x: ('k', x))
             | beam.GroupByKey())

  # TODO(BEAM-2519): timestamp assignment for elements from a GBK should
  # respect the TimestampCombiner.  The test below should also verify the
  # timestamps of the outputted elements once this is implemented.

  # assert per window
  # Each window fires twice under DISCARDING mode: an early pane holding
  # the buffered element, then an empty on-time pane at the watermark.
  expected_window_to_elements = {
      window.IntervalWindow(0, 15): [
          ('k', ['a']),
          ('k', [])
      ],
      window.IntervalWindow(15, 30): [
          ('k', ['b']),
          ('k', [])
      ],
  }
  assert_that(
      records,
      equal_to_per_window(expected_window_to_elements),
      label='assert per window')

  p.run()
def test_deterministic_key(self):
  """A deterministic custom coder lets a user type key a CombinePerKey."""
  p = TestPipeline()
  lines = (p | beam.Create([
      'banana,fruit,3', 'kiwi,fruit,2', 'kiwi,fruit,2', 'zucchini,veg,3']))

  # For pickling
  global Player  # pylint: disable=global-variable-not-assigned

  # [START type_hints_deterministic_key]
  class Player(object):
    def __init__(self, team, name):
      self.team = team
      self.name = name

  class PlayerCoder(beam.coders.Coder):
    def encode(self, player):
      return '%s:%s' % (player.team, player.name)

    def decode(self, s):
      return Player(*s.split(':'))

    def is_deterministic(self):
      return True

  beam.coders.registry.register_coder(Player, PlayerCoder)

  def parse_player_and_score(csv):
    name, team, score = csv.split(',')
    return Player(team, name), int(score)

  totals = (lines
            | beam.Map(parse_player_and_score)
            | beam.CombinePerKey(sum).with_input_types(
                beam.typehints.Tuple[Player, int]))
  # [END type_hints_deterministic_key]

  # `lambda (k, v):` tuple-parameter unpacking is Python-2-only syntax and a
  # SyntaxError under Python 3; index the tuple explicitly instead.
  assert_that(totals | beam.Map(lambda kv: (kv[0].name, kv[1])),
              equal_to([('banana', 3), ('kiwi', 4), ('zucchini', 3)]))
  p.run()
def test_basic_execution_batch_sideinputs_fixed_windows(self):
  """Per-window assertion of a fixed-windowed batch side input."""
  options = PipelineOptions()
  options.view_as(StandardOptions).streaming = True
  p = TestPipeline(options=options)

  main_stream = (
      p
      | 'main TestStream' >> TestStream().advance_watermark_to(2).add_elements(
          ['a']).advance_watermark_to(4).add_elements(
              ['b']).advance_watermark_to_infinity()
      | 'main window' >> beam.WindowInto(window.FixedWindows(1)))
  # Batch side input windowed into 2-second fixed windows, so each main
  # element only observes the side values whose window matches its own.
  side = (
      p
      | beam.Create([2, 1, 4])
      | beam.Map(lambda t: window.TimestampedValue(t, t))
      | beam.WindowInto(window.FixedWindows(2)))

  class RecordFn(beam.DoFn):
    # Emits (element, timestamp, side-input contents) triples.
    def process(
        self,
        elm=beam.DoFn.ElementParam,
        ts=beam.DoFn.TimestampParam,
        side=beam.DoFn.SideInputParam):
      yield (elm, ts, side)

  records = (
      main_stream  # pylint: disable=unused-variable
      | beam.ParDo(RecordFn(), beam.pvalue.AsList(side)))

  # assert per window
  expected_window_to_elements = {
      window.IntervalWindow(2, 3): [('a', Timestamp(2), [2])],
      window.IntervalWindow(4, 5): [('b', Timestamp(4), [4])]
  }
  assert_that(
      records,
      equal_to_per_window(expected_window_to_elements),
      label='assert per window')

  p.run()