def test_gbk_execution(self):
    """GroupByKey over a streaming TestStream emits one pane per window firing.

    'a'-'c' (t=10) and 'd'-'e' (t=20) land in the 15-second windows [0, 15)
    and [15, 30). 'late' (t=12) arrives after the watermark passed 300, so it
    forms its own pane. 'last' (t=310) falls in [300, 315).
    """
    test_stream = (
        TestStream()
        .advance_watermark_to(10)
        .add_elements(['a', 'b', 'c'])
        .advance_watermark_to(20)
        .add_elements(['d'])
        .add_elements(['e'])
        .advance_processing_time(10)
        .advance_watermark_to(300)
        .add_elements([TimestampedValue('late', 12)])
        .add_elements([TimestampedValue('last', 310)]))

    streaming_options = PipelineOptions()
    streaming_options.view_as(StandardOptions).streaming = True
    pipeline = TestPipeline(options=streaming_options)

    keyed = (
        pipeline
        | test_stream
        | beam.WindowInto(FixedWindows(15))
        | beam.Map(lambda element: ('k', element)))
    grouped = keyed | beam.GroupByKey()

    # TODO(BEAM-2519): timestamp assignment for elements from a GBK should
    # respect the TimestampCombiner. The test below should also verify the
    # timestamps of the outputted elements once this is implemented.
    expected_panes = [
        ('k', ['a', 'b', 'c']),
        ('k', ['d', 'e']),
        ('k', ['late']),
        ('k', ['last']),
    ]
    assert_that(grouped, equal_to(expected_panes))
    pipeline.run()
def test_read_from_text_file_pattern(self):
    """ReadFromText on a file pattern yields the lines of every matching file."""
    pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
    # The line counts passed to write_pattern sum to 40.
    assert len(expected_data) == 40

    pipeline = TestPipeline()
    lines = pipeline | 'Read' >> ReadFromText(pattern)
    assert_that(lines, equal_to(expected_data))
    pipeline.run()
def test_multi(self):
    """An ExternalTransform with several named inputs/outputs expands correctly.

    The 'main' output is main1 + main2 flattened, with the singleton side
    value appended. The 'side' output is every side value doubled.
    """

    @ptransform.PTransform.register_urn('multi', None)
    class MultiTransform(ptransform.PTransform):

        def expand(self, pcolls):
            side_value = beam.pvalue.AsSingleton(pcolls['side'])
            flattened = (pcolls['main1'], pcolls['main2']) | beam.Flatten()
            main_out = flattened | beam.Map(lambda x, s: x + s, side_value)
            side_out = pcolls['side'] | beam.Map(lambda x: x + x)
            return {'main': main_out, 'side': side_out}

        def to_runner_api_parameter(self, unused_context):
            return 'multi', None

        @staticmethod
        def from_runner_api_parameter(unused_parameter, unused_context):
            return MultiTransform()

    with beam.Pipeline() as p:
        inputs = {
            'main1': p | 'Main1' >> beam.Create(['a', 'bb'], reshuffle=False),
            'main2': p | 'Main2' >> beam.Create(
                ['x', 'yy', 'zzz'], reshuffle=False),
            'side': p | 'Side' >> beam.Create(['s']),
        }
        res = inputs | beam.ExternalTransform(
            'multi', None, expansion_service.ExpansionServiceServicer())
        assert_that(res['main'], equal_to(['as', 'bbs', 'xs', 'yys', 'zzzs']))
        assert_that(res['side'], equal_to(['ss']), label='CheckSide')
def test_payload(self):
    """An ExternalTransform receives its bytes payload via the expansion service.

    The payload b's' is decoded to 's' and appended to each input element.
    """

    @ptransform.PTransform.register_urn('payload', bytes)
    class PayloadTransform(ptransform.PTransform):

        def __init__(self, payload):
            self._payload = payload

        def expand(self, pcoll):
            # Append the (decoded, str) payload to every element.
            return pcoll | beam.Map(lambda x, s: x + s, self._payload)

        def to_runner_api_parameter(self, unused_context):
            # Report the same str URN that was registered above (test_multi
            # does the same); only the payload itself is serialized to bytes.
            return 'payload', self._payload.encode('ascii')

        @staticmethod
        def from_runner_api_parameter(payload, unused_context):
            return PayloadTransform(payload.decode('ascii'))

    with beam.Pipeline() as p:
        res = (
            p
            | beam.Create(['a', 'bb'], reshuffle=False)
            | beam.ExternalTransform(
                'payload', b's',
                expansion_service.ExpansionServiceServicer()))
        assert_that(res, equal_to(['as', 'bbs']))
def test_setting_timestamp(self):
    """A DoFn assigns event timestamps, then values are summed per 60s window."""
    with TestPipeline() as p:
        raw_values = p | beam.Create([12, 30, 60, 61, 66])
        items = raw_values | 'key' >> beam.Map(lambda value: ('k', value))

        def extract_timestamp_from_log_entry(entry):
            # The "log entry" here is a ('k', value) pair; the value doubles
            # as its timestamp.
            return entry[1]

        # [START setting_timestamp]
        class AddTimestampDoFn(beam.DoFn):

            def process(self, element):
                # Extract the numeric Unix seconds-since-epoch timestamp to be
                # associated with the current log entry.
                unix_timestamp = extract_timestamp_from_log_entry(element)
                # Wrap and emit the current entry and new timestamp in a
                # TimestampedValue.
                yield beam.window.TimestampedValue(element, unix_timestamp)

        timestamped_items = items | 'timestamp' >> beam.ParDo(AddTimestampDoFn())
        # [END setting_timestamp]
        windowed = timestamped_items | 'window' >> beam.WindowInto(
            beam.window.FixedWindows(60))
        per_window_sums = (
            windowed
            | 'group' >> beam.GroupByKey()
            | 'combine' >> beam.CombineValues(sum))
        sums_only = per_window_sums | 'unkey' >> beam.Map(lambda kv: kv[1])
        # [0, 60) holds 12 + 30 = 42; [60, 120) holds 60 + 61 + 66 = 187.
        assert_that(sums_only, equal_to([42, 187]))
def test_read_from_text_single_file(self):
    """ReadFromText on a single file yields exactly the lines written to it."""
    file_name, expected_data = write_data(5)
    assert len(expected_data) == 5

    pipeline = TestPipeline()
    lines = pipeline | 'Read' >> ReadFromText(file_name)
    assert_that(lines, equal_to(expected_data))
    pipeline.run()
def test_group_by_key(self):
    """GroupByKey gathers every value of a key; values are sorted to compare."""

    def sort_values(key_and_values):
        key, values = key_and_values
        return key, sorted(values)

    with self.create_pipeline() as p:
        grouped = (
            p
            | beam.Create([('a', 1), ('a', 2), ('b', 3)])
            | beam.GroupByKey())
        normalized = grouped | beam.Map(sort_values)
        assert_that(normalized, equal_to([('a', [1, 2]), ('b', [3])]))
def test_pardo(self):
    """Chained Maps run in order: double each string, then append 'x'."""
    with self.create_pipeline() as p:
        source = p | beam.Create(['a', 'bc'])
        doubled = source | beam.Map(lambda s: s * 2)
        suffixed = doubled | beam.Map(lambda s: s + 'x')
        assert_that(suffixed, equal_to(['aax', 'bcbcx']))
def test_pardo_windowed_side_inputs(self):
    """Each main-input window sees only the side-input window it maps to."""
    with self.create_pipeline() as p:
        # Now with some windowing.
        pcoll = (
            p
            | beam.Create(list(range(10)))
            | beam.Map(lambda t: window.TimestampedValue(t, t)))
        # Intentionally choosing non-aligned windows to highlight the transition.
        main = pcoll | 'WindowMain' >> beam.WindowInto(window.FixedWindows(5))
        side = pcoll | 'WindowSide' >> beam.WindowInto(window.FixedWindows(7))
        res = main | beam.Map(
            lambda x, s: (x, sorted(s)), beam.pvalue.AsList(side))

        first_side_window = list(range(7))
        second_side_window = list(range(7, 10))
        expected = (
            # The window [0, 5) maps to the window [0, 7).
            [(value, first_side_window) for value in range(5)]
            # The window [5, 10) maps to the window [7, 14).
            + [(value, second_side_window) for value in range(5, 10)])
        assert_that(res, equal_to(expected), label='windowed')
def test_gbk_side_input(self):
    """A GroupByKey output can be read as an AsDict side input."""
    with self.create_pipeline() as p:
        main = p | 'main' >> beam.Create([None])
        grouped_side = (
            p
            | 'side' >> beam.Create([('a', 1)])
            | beam.GroupByKey())
        paired = main | beam.Map(
            lambda a, b: (a, b), beam.pvalue.AsDict(grouped_side))
        assert_that(paired, equal_to([(None, {'a': [1]})]))
def test_read_all_from_avro_many_single_files(self):
    """ReadAllFromAvro reads every file named in its input PCollection."""
    paths = [self._write_data() for _ in range(3)]
    with TestPipeline() as p:
        records = p | Create(paths) | avroio.ReadAllFromAvro()
        assert_that(records, equal_to(self.RECORDS * 3))
def test_to_list_and_to_dict(self):
    """ToList and ToDict each collapse a PCollection into a single element.

    Two separate pipelines are run: one checks that ToList yields the input
    elements as one list, the other that ToDict yields one dict whose items
    equal the input pairs.
    """
    pipeline = TestPipeline()
    the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
    pcoll = pipeline | 'start' >> Create(the_list)
    result = pcoll | 'to list' >> combine.ToList()

    def list_matcher(expected):
        def match(actual):
            # Compare the single emitted list against the input, ignoring order.
            equal_to(expected[0])(actual[0])
        return match

    assert_that(result, list_matcher([the_list]))
    pipeline.run()

    pipeline = TestPipeline()
    pairs = [(1, 2), (3, 4), (5, 6)]
    pcoll = pipeline | 'start-pairs' >> Create(pairs)
    result = pcoll | 'to dict' >> combine.ToDict()

    def dict_matcher():
        def match(actual):
            equal_to([1])([len(actual)])
            # dict.items() works on Python 2 and 3; iteritems() is Python 2
            # only and raises AttributeError on Python 3.
            equal_to(pairs)(list(actual[0].items()))
        return match

    assert_that(result, dict_matcher())
    pipeline.run()
def test_window_preserved(self):
    """Windows set explicitly by a DoFn survive an _IdentityWindowFn WindowInto."""
    expected_timestamp = timestamp.Timestamp(5)
    expected_window = window.IntervalWindow(1.0, 2.0)

    class AddWindowDoFn(beam.DoFn):
        def process(self, element):
            yield WindowedValue(element, expected_timestamp, [expected_window])

    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    expected_windows = [
        TestWindowedValue(kv, expected_timestamp, [expected_window])
        for kv in data
    ]

    pipeline = TestPipeline()
    before_identity = (
        pipeline
        | 'start' >> beam.Create(data)
        | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
    assert_that(before_identity, equal_to(expected_windows),
                label='before_identity', reify_windows=True)

    identity_fn = beam.transforms.util._IdentityWindowFn(
        coders.IntervalWindowCoder())
    after_identity = before_identity | 'window' >> beam.WindowInto(identity_fn)
    assert_that(after_identity, equal_to(expected_windows),
                label='after_identity', reify_windows=True)
    pipeline.run()
def test_no_window_context_fails(self):
    """_IdentityWindowFn raises when its AssignContext carries no window.

    A DoFn that emits TimestampedValue after the identity WindowInto leaves
    the window unset, and running the pipeline is expected to raise a
    ValueError mentioning the offending step ('add_timestamps2').
    """
    expected_timestamp = timestamp.Timestamp(5)
    # Assuming the default window function is window.GlobalWindows.
    expected_window = window.GlobalWindow()

    class AddTimestampDoFn(beam.DoFn):
        def process(self, element):
            yield window.TimestampedValue(element, expected_timestamp)

    pipeline = TestPipeline()
    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    expected_windows = [
        TestWindowedValue(kv, expected_timestamp, [expected_window])
        for kv in data]
    before_identity = (pipeline
                       | 'start' >> beam.Create(data)
                       | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
    assert_that(before_identity, equal_to(expected_windows),
                label='before_identity', reify_windows=True)
    after_identity = (before_identity
                      | 'window' >> beam.WindowInto(
                          beam.transforms.util._IdentityWindowFn(
                              coders.GlobalWindowCoder()))
                      # This DoFn will return TimestampedValues, making
                      # WindowFn.AssignContext passed to IdentityWindowFn
                      # contain a window of None. IdentityWindowFn should
                      # raise an exception.
                      | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
    assert_that(after_identity, equal_to(expected_windows),
                label='after_identity', reify_windows=True)
    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, r'window.*None.*add_timestamps2'):
        pipeline.run()
def test_timestamped_with_combiners(self):
    """CombinePerKey and Mean.PerKey aggregate per key *and* per window.

    Values 0-9 are re-timestamped with their own value (in seconds), so they
    fall into the 5-second windows [0, 5) and [5, 10). They are keyed by
    window index (0 or 1).
    """
    with TestPipeline() as p:
        result = (p
                  # Create some initial test values.
                  | 'start' >> Create([(k, k) for k in range(10)])
                  # The purpose of the WindowInto transform is to establish a
                  # FixedWindows windowing function for the PCollection.
                  # It does not bucket elements into windows since the
                  # timestamps from Create are not spaced 5 seconds apart and
                  # very likely they all fall into the same window.
                  | 'w' >> WindowInto(FixedWindows(5))
                  # Generate timestamped values using the values as timestamps.
                  # Now the timestamps span 0-9 seconds and since Map propagates
                  # the windowing function from input to output the output
                  # PCollection will have elements falling into different
                  # 5-second windows.
                  | Map(lambda x_t2: TimestampedValue(x_t2[0], x_t2[1]))
                  # We add a 'key' to each value representing the index of the
                  # window. This is important since there is no guarantee of
                  # order for the elements of a PCollection. Floor division
                  # keeps the key an int on Python 3 (true division would give
                  # 0.2, 0.4, ... and never match the expected keys).
                  | Map(lambda v: (v // 5, v)))
        # Sum all elements associated with a key and window. Although it
        # is called CombinePerKey it is really CombinePerKeyAndWindow the
        # same way GroupByKey is really GroupByKeyAndWindow.
        sum_per_window = result | CombinePerKey(sum)
        # Compute mean per key and window.
        mean_per_window = result | combiners.Mean.PerKey()
        assert_that(sum_per_window, equal_to([(0, 10), (1, 35)]),
                    label='assert:sum')
        assert_that(mean_per_window, equal_to([(0, 2.0), (1, 7.0)]),
                    label='assert:mean')
def test_after_count(self):
    """Panes produced by an AfterCount(3) trigger with discarding accumulation.

    Results are formatted as ('<key>-<pane size>', set(values)).
    """
    with TestPipeline() as p:

        def construct_timestamped(k_t):
            key, event_time = k_t
            return TimestampedValue((key, event_time), event_time)

        def format_result(k_v):
            key, values = k_v
            return ('%s-%s' % (key, len(values)), set(values))

        windowing = beam.WindowInto(
            FixedWindows(10),
            trigger=AfterCount(3),
            accumulation_mode=AccumulationMode.DISCARDING)
        result = (
            p
            | beam.Create([1, 2, 3, 4, 5, 10, 11])
            | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
            | beam.Map(construct_timestamped)
            | windowing
            | beam.GroupByKey()
            | beam.Map(format_result))

        expected = {
            'A-5': {1, 2, 3, 4, 5},
            # A-10, A-11 never emitted due to AfterCount(3) never firing.
            'B-4': {6, 7, 8, 9},
            'B-3': {10, 15, 16},
        }
        assert_that(result, equal_to(list(expected.items())))
def test_read_messages_timestamp_attribute_rfc3339_success(self, mock_pubsub):
    """An RFC 3339 value in the timestamp attribute becomes the event time."""
    data = 'data'
    message_id = 'message_id'
    attributes = {'time': '2018-03-12T13:37:01.234567Z'}
    publish_time = '2018-03-12T13:37:01.234567Z'
    payloads = [
        create_client_message(data, message_id, attributes, publish_time),
    ]

    event_time = timestamp.Timestamp.from_rfc3339(attributes['time'])
    expected_elements = [
        TestWindowedValue(PubsubMessage(data, attributes), event_time,
                          [window.GlobalWindow()]),
    ]

    # Route the module's client and auto-ack helpers to the fakes.
    mock_pubsub.Client = functools.partial(FakePubsubClient, payloads)
    mock_pubsub.subscription.AutoAck = FakeAutoAck

    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    read = ReadFromPubSub('projects/fakeprj/topics/a_topic', None, 'a_label',
                          with_attributes=True, timestamp_attribute='time')
    messages = p | read
    assert_that(messages, equal_to(expected_elements), reify_windows=True)
    p.run()
def test_read_messages_timestamp_attribute_missing(self, mock_pubsub):
    """Without the named timestamp attribute, the publish time is used.

    The message is also expected to be acknowledged by its ack id.
    """
    data = 'data'
    attributes = {}
    publish_time_secs = 1520861821
    publish_time_nanos = 234567000
    publish_time = '2018-03-12T13:37:01.234567Z'
    ack_id = 'ack_id'

    message = test_utils.PullResponseMessage(
        data, attributes, publish_time_secs, publish_time_nanos, ack_id)
    pull_response = test_utils.create_pull_response([message])
    expected_elements = [
        TestWindowedValue(PubsubMessage(data, attributes),
                          timestamp.Timestamp.from_rfc3339(publish_time),
                          [window.GlobalWindow()]),
    ]
    mock_pubsub.return_value.pull.return_value = pull_response

    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    read = ReadFromPubSub('projects/fakeprj/topics/a_topic', None, None,
                          with_attributes=True,
                          timestamp_attribute='nonexistent')
    messages = p | read
    assert_that(messages, equal_to(expected_elements), reify_windows=True)
    p.run()

    expected_ack = mock.call(mock.ANY, [ack_id])
    mock_pubsub.return_value.acknowledge.assert_has_calls([expected_ack])
def test_pardo_side_inputs(self):
    """Side inputs: a global cross product, then per-window side lookups.

    The first part pairs every main element with every AsList side element.
    The second windows main (5s) and side (7s) differently and checks that
    each main element sees only the side window its own window maps to.
    """
    def cross_product(elem, sides):
        for side in sides:
            yield elem, side

    with self.create_pipeline() as p:
        main = p | 'main' >> beam.Create(['a', 'b', 'c'])
        side = p | 'side' >> beam.Create(['x', 'y'])
        assert_that(main | beam.FlatMap(cross_product,
                                        beam.pvalue.AsList(side)),
                    equal_to([('a', 'x'), ('b', 'x'), ('c', 'x'),
                              ('a', 'y'), ('b', 'y'), ('c', 'y')]))

        # Now with some windowing.
        pcoll = p | beam.Create(list(range(10))) | beam.Map(
            lambda t: window.TimestampedValue(t, t))
        # Intentionally choosing non-aligned windows to highlight the transition.
        main = pcoll | 'WindowMain' >> beam.WindowInto(window.FixedWindows(5))
        side = pcoll | 'WindowSide' >> beam.WindowInto(window.FixedWindows(7))
        res = main | beam.Map(lambda x, s: (x, sorted(s)),
                              beam.pvalue.AsList(side))
        # sorted() returns a list, so the expected side values must be lists
        # as well: on Python 3 a range object never compares equal to a list.
        assert_that(
            res,
            equal_to([
                # The window [0, 5) maps to the window [0, 7).
                (0, list(range(7))),
                (1, list(range(7))),
                (2, list(range(7))),
                (3, list(range(7))),
                (4, list(range(7))),
                # The window [5, 10) maps to the window [7, 14).
                (5, list(range(7, 10))),
                (6, list(range(7, 10))),
                (7, list(range(7, 10))),
                (8, list(range(7, 10))),
                (9, list(range(7, 10)))]),
            label='windowed')
def test_read_messages_timestamp_attribute_milli_success(self, mock_pubsub):
    """A numeric timestamp attribute is read as milliseconds since the epoch.

    The message is also expected to be acknowledged by its ack id.
    """
    data = b'data'
    attributes = {'time': '1337'}
    publish_time_secs = 1520861821
    publish_time_nanos = 234567000
    ack_id = 'ack_id'

    message = test_utils.PullResponseMessage(
        data, attributes, publish_time_secs, publish_time_nanos, ack_id)
    pull_response = test_utils.create_pull_response([message])
    # '1337' ms -> 1337 * 1000 microseconds.
    event_time = timestamp.Timestamp(micros=int(attributes['time']) * 1000)
    expected_elements = [
        TestWindowedValue(PubsubMessage(data, attributes), event_time,
                          [window.GlobalWindow()]),
    ]
    mock_pubsub.return_value.pull.return_value = pull_response

    options = PipelineOptions([])
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)
    read = ReadFromPubSub('projects/fakeprj/topics/a_topic', None, None,
                          with_attributes=True, timestamp_attribute='time')
    messages = p | read
    assert_that(messages, equal_to(expected_elements), reify_windows=True)
    p.run()

    expected_ack = mock.call(mock.ANY, [ack_id])
    mock_pubsub.return_value.acknowledge.assert_has_calls([expected_ack])
def run_pipeline(self, count_implementation, factor=1):
    """Apply ``count_implementation`` to a fixed word list and check counts.

    Args:
      count_implementation: transform applied to the words PCollection; its
        output must equal the (word, count) pairs asserted below.
      factor: multiplier applied to the expected per-word counts.
    """
    words = ['CAT', 'DOG', 'CAT', 'CAT', 'DOG']
    expected_counts = [('CAT', 3 * factor), ('DOG', 2 * factor)]

    p = TestPipeline()
    counted = p | beam.Create(words) | count_implementation
    assert_that(counted, equal_to(expected_counts))
    p.run()
def test_run_direct(self):
    """Reading a LineSource over a temp file yields each of its lines."""
    file_name = self._create_temp_file('aaaa\nbbbb\ncccc\ndddd')
    expected_lines = ['aaaa', 'bbbb', 'cccc', 'dddd']

    pipeline = TestPipeline()
    lines = pipeline | beam.io.Read(LineSource(file_name))
    assert_that(lines, equal_to(expected_lines))
    pipeline.run()
def test_compute_top_sessions(self):
    """ComputeTopSessions(1.0) over the EDITS fixture yields EXPECTED."""
    p = TestPipeline()
    top_sessions = (
        p
        | beam.Create(self.EDITS)
        | top_wikipedia_sessions.ComputeTopSessions(1.0))
    assert_that(top_sessions, equal_to(self.EXPECTED))
    p.run()
def test_global_fanout(self):
    """CombineGlobally with fanout still computes the exact mean of 0..99."""
    mean_fn = combine.MeanCombineFn()
    with TestPipeline() as p:
        mean = (
            p
            | beam.Create(range(100))
            | beam.CombineGlobally(mean_fn).with_fanout(11))
        assert_that(mean, equal_to([49.5]))
def test_basic_execution_sideinputs(self):
    """A TestStream side input is visible to the main input as a list.

    The side values arrive in the order 2, 1, 7, 4, and the single main
    element 'e' (t=10) sees them all, in that order.
    """
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)

    main_events = TestStream().advance_watermark_to(10).add_elements(['e'])
    side_events = TestStream()
    for value in (2, 1, 7, 4):
        side_events = side_events.add_elements(
            [window.TimestampedValue(value, value)])

    main_stream = p | 'main TestStream' >> main_events
    side_stream = p | 'side TestStream' >> side_events

    class RecordFn(beam.DoFn):
        def process(self,
                    elm=beam.DoFn.ElementParam,
                    ts=beam.DoFn.TimestampParam,
                    side=beam.DoFn.SideInputParam):
            yield (elm, ts, side)

    records = main_stream | beam.ParDo(
        RecordFn(), beam.pvalue.AsList(side_stream))
    assert_that(records, equal_to([('e', Timestamp(10), [2, 1, 7, 4])]))
    p.run()
def test_reified_value_assert_fail_unmatched_window(self):
    """assert_that with reify_windows fails when only the windows differ."""
    expected = [
        TestWindowedValue(value, MIN_TIMESTAMP, [IntervalWindow(0, 1)])
        for value in [1, 2, 3]
    ]
    with self.assertRaises(Exception):
        with TestPipeline() as p:
            actual = p | Create([2, 3, 1])
            assert_that(actual, equal_to(expected), reify_windows=True)
def check_many_files(output_pcs):
    """Assert on WriteRecordsToFile outputs when some records are spilled.

    Expects 3 spilled (unwritten) records, exactly one written file each for
    table1 and table3 and none for table2. Also checks that every written
    file path exists.
    """
    written_files = output_pcs[bqfl.WriteRecordsToFile.WRITTEN_FILE_TAG]
    spilled_records = output_pcs[
        bqfl.WriteRecordsToFile.UNWRITTEN_RECORD_TAG]

    spilled_count = spilled_records | beam.combiners.Count.Globally()
    assert_that(spilled_count, equal_to([3]), label='spilled count')

    # Declare (destination, path) KV types so Count.PerKey can key on them.
    typed_files = written_files | beam.Map(lambda x: x).with_output_types(
        beam.typehints.KV[str, str])
    counts_per_dest = typed_files | beam.combiners.Count.PerKey()
    files_per_dest = counts_per_dest | "GetDests" >> beam.Map(
        lambda x: (bigquery_tools.get_hashable_destination(x[0]), x[1]))

    # Only table1 and table3 get files. table2 records get spilled.
    assert_that(files_per_dest,
                equal_to([('project1:dataset1.table1', 1),
                          ('project1:dataset1.table3', 1)]),
                label='file count')

    # Check that the files exist
    file_paths = written_files | beam.Map(lambda x: x[1])
    _ = file_paths | beam.Map(
        lambda path: hamcrest_assert(os.path.exists(path), is_(True)))
def test_model_composite_triggers(self):
    """AfterWatermark with a late AfterProcessingTime trigger emits two panes.

    On-time elements give ('a', 3) and ('b', 2). The elements stamped t=10
    that arrive after the watermark reached 70 give ('a', 2) and ('c', 2).
    """
    pipeline_options = PipelineOptions()
    pipeline_options.view_as(StandardOptions).streaming = True

    late_arrivals = [TimestampedValue(word, 10) for word in ('a', 'a', 'c', 'c')]
    test_stream = (TestStream()
                   .advance_watermark_to(10)
                   .add_elements(['a', 'a', 'a', 'b', 'b'])
                   .advance_watermark_to(70)
                   .add_elements(late_arrivals)
                   .advance_processing_time(600))

    with TestPipeline(options=pipeline_options) as p:
        pcollection = (
            p
            | test_stream
            | 'pair_with_one' >> beam.Map(lambda x: (x, 1)))

        counts = (
            # [START model_composite_triggers]
            pcollection | WindowInto(
                FixedWindows(1 * 60),
                trigger=AfterWatermark(
                    late=AfterProcessingTime(10 * 60)),
                accumulation_mode=AccumulationMode.DISCARDING)
            # [END model_composite_triggers]
            | 'group' >> beam.GroupByKey()
            | 'count' >> beam.Map(
                lambda word_ones: (word_ones[0], sum(word_ones[1]))))
        assert_that(counts, equal_to([('a', 3), ('b', 2), ('a', 2), ('c', 2)]))
def test_reified_value_assert_fail_unmatched_timestamp(self):
    """assert_that with reify_windows fails when only the timestamps differ."""
    expected = [
        TestWindowedValue(value, 1, [GlobalWindow()])
        for value in [1, 2, 3]
    ]
    with self.assertRaises(Exception):
        with TestPipeline() as p:
            actual = p | Create([2, 3, 1])
            assert_that(actual, equal_to(expected), reify_windows=True)
def test_basic_execution(self):
    """A ParDo over a TestStream observes every element with its timestamp.

    Elements added via add_elements take the current watermark as their
    timestamp. The TimestampedValues keep their own timestamps, even 'late'.
    """
    test_stream = (TestStream()
                   .advance_watermark_to(10)
                   .add_elements(['a', 'b', 'c'])
                   .advance_watermark_to(20)
                   .add_elements(['d'])
                   .add_elements(['e'])
                   .advance_processing_time(10)
                   .advance_watermark_to(300)
                   .add_elements([TimestampedValue('late', 12)])
                   .add_elements([TimestampedValue('last', 310)]))

    class RecordFn(beam.DoFn):
        # The parameter is named element_timestamp so it does not shadow the
        # `timestamp` module used to build the expectations below.
        def process(self,
                    element=beam.DoFn.ElementParam,
                    element_timestamp=beam.DoFn.TimestampParam):
            yield (element, element_timestamp)

    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)
    my_record_fn = RecordFn()
    records = p | test_stream | beam.ParDo(my_record_fn)

    observed = [('a', 10), ('b', 10), ('c', 10), ('d', 20), ('e', 20),
                ('late', 12), ('last', 310)]
    expected = [(value, timestamp.Timestamp(seconds))
                for value, seconds in observed]
    assert_that(records, equal_to(expected))
    p.run()
def test_std_dev_usage(self):
    """filter_by_std_dev on the test input keeps exactly three addresses."""
    expected = ['1.1.1.2', '1.1.1.3', '1.1.1.4']
    with TestPipeline() as p:
        filtered = filter_by_std_dev(p, self.test_path)
        assert_that(filtered, equal_to(expected))
def test_stats_pipeline(self):
    """GenerateStatistics over three examples matches a golden stats proto.

    Feature 'a' is FLOAT and contains one NaN, 'b' is STRING and 'c' is INT.
    The expected common/numeric/string stats and histograms are given in
    protobuf text format.
    """
    # input with three examples.
    examples = [{'a': np.array([1.0, 2.0]),
                 'b': np.array(['a', 'b', 'c', 'e']),
                 'c': np.linspace(1, 500, 500, dtype=np.int32)},
                # np.NaN was removed in NumPy 2.0; np.nan is the canonical
                # spelling and works on every NumPy version.
                {'a': np.array([3.0, 4.0, np.nan, 5.0]),
                 'b': np.array(['a', 'c', 'd', 'a']),
                 'c': np.linspace(501, 1250, 750, dtype=np.int32)},
                {'a': np.array([1.0]),
                 'b': np.array(['a', 'b', 'c', 'd']),
                 'c': np.linspace(1251, 3000, 1750, dtype=np.int32)}]

    expected_result = text_format.Parse("""
        datasets {
          num_examples: 3
          features {
            name: 'a'
            type: FLOAT
            num_stats {
              common_stats {
                num_non_missing: 3
                num_missing: 0
                min_num_values: 1
                max_num_values: 4
                avg_num_values: 2.33333333
                tot_num_values: 7
                num_values_histogram {
                  buckets {
                    low_value: 1.0
                    high_value: 1.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 1.0
                    high_value: 4.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 4.0
                    high_value: 4.0
                    sample_count: 1.0
                  }
                  type: QUANTILES
                }
              }
              mean: 2.66666666
              std_dev: 1.49071198
              num_zeros: 0
              min: 1.0
              max: 5.0
              median: 3.0
              histograms {
                num_nan: 1
                buckets {
                  low_value: 1.0
                  high_value: 2.3333333
                  sample_count: 2.9866667
                }
                buckets {
                  low_value: 2.3333333
                  high_value: 3.6666667
                  sample_count: 1.0066667
                }
                buckets {
                  low_value: 3.6666667
                  high_value: 5.0
                  sample_count: 2.0066667
                }
                type: STANDARD
              }
              histograms {
                num_nan: 1
                buckets {
                  low_value: 1.0
                  high_value: 1.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 1.0
                  high_value: 3.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 3.0
                  high_value: 4.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 4.0
                  high_value: 5.0
                  sample_count: 1.5
                }
                type: QUANTILES
              }
            }
          }
          features {
            name: 'c'
            type: INT
            num_stats {
              common_stats {
                num_non_missing: 3
                num_missing: 0
                min_num_values: 500
                max_num_values: 1750
                avg_num_values: 1000.0
                tot_num_values: 3000
                num_values_histogram {
                  buckets {
                    low_value: 500.0
                    high_value: 500.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 500.0
                    high_value: 1750.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 1750.0
                    high_value: 1750.0
                    sample_count: 1.0
                  }
                  type: QUANTILES
                }
              }
              mean: 1500.5
              std_dev: 866.025355672
              min: 1.0
              max: 3000.0
              median: 1501.0
              histograms {
                buckets {
                  low_value: 1.0
                  high_value: 1000.66666667
                  sample_count: 999.666666667
                }
                buckets {
                  low_value: 1000.66666667
                  high_value: 2000.33333333
                  sample_count: 999.666666667
                }
                buckets {
                  low_value: 2000.33333333
                  high_value: 3000.0
                  sample_count: 1000.66666667
                }
                type: STANDARD
              }
              histograms {
                buckets {
                  low_value: 1.0
                  high_value: 751.0
                  sample_count: 750.0
                }
                buckets {
                  low_value: 751.0
                  high_value: 1501.0
                  sample_count: 750.0
                }
                buckets {
                  low_value: 1501.0
                  high_value: 2250.0
                  sample_count: 750.0
                }
                buckets {
                  low_value: 2250.0
                  high_value: 3000.0
                  sample_count: 750.0
                }
                type: QUANTILES
              }
            }
          }
          features {
            name: "b"
            type: STRING
            string_stats {
              common_stats {
                num_non_missing: 3
                min_num_values: 4
                max_num_values: 4
                avg_num_values: 4.0
                tot_num_values: 12
                num_values_histogram {
                  buckets {
                    low_value: 4.0
                    high_value: 4.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 4.0
                    high_value: 4.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 4.0
                    high_value: 4.0
                    sample_count: 1.0
                  }
                  type: QUANTILES
                }
              }
              unique: 5
              top_values {
                value: "a"
                frequency: 4.0
              }
              top_values {
                value: "c"
                frequency: 3.0
              }
              avg_length: 1.0
              rank_histogram {
                buckets {
                  low_rank: 0
                  high_rank: 0
                  label: "a"
                  sample_count: 4.0
                }
                buckets {
                  low_rank: 1
                  high_rank: 1
                  label: "c"
                  sample_count: 3.0
                }
                buckets {
                  low_rank: 2
                  high_rank: 2
                  label: "d"
                  sample_count: 2.0
                }
              }
            }
          }
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())

    with beam.Pipeline() as p:
        options = stats_options.StatsOptions(
            num_top_values=2,
            num_rank_histogram_buckets=3,
            num_values_histogram_buckets=3,
            num_histogram_buckets=3,
            num_quantiles_histogram_buckets=4,
            epsilon=0.001)
        result = (
            p | beam.Create(examples)
            | stats_api.GenerateStatistics(options))
        util.assert_that(
            result,
            test_util.make_dataset_feature_stats_list_proto_equal_fn(
                self, expected_result))
def test_assert_that(self):
    """A failing assert_that makes the pipeline raise 'Failed assert'."""
    # TODO: figure out a way for fn_api_runner to parse and raise the
    # underlying exception.
    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(Exception, 'Failed assert'):
        with self.create_pipeline() as p:
            assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
def test_combine_per_key(self):
    """CombinePerKey with MeanCombineFn averages the values of each key."""
    pairs = [('a', 1), ('a', 2), ('b', 3)]
    mean_fn = beam.combiners.MeanCombineFn()
    with self.create_pipeline() as p:
        means = p | beam.Create(pairs) | beam.CombinePerKey(mean_fn)
        assert_that(means, equal_to([('a', 1.5), ('b', 3.0)]))
def test_reshuffle(self):
    """Reshuffle neither drops nor duplicates elements."""
    with self.create_pipeline() as p:
        shuffled = p | beam.Create([1, 2, 3]) | beam.Reshuffle()
        assert_that(shuffled, equal_to([1, 2, 3]))
def test_pardo_side_and_main_outputs(self):
    """FlatMap emits to the main output and to 'even'/'odd' tagged outputs.

    Checked twice: once with explicitly declared tags plus main='all', and
    once with undeclared tags, where the main output is addressed as None.
    """
    def even_odd(elem):
        yield elem
        tag = 'odd' if elem % 2 else 'even'
        yield beam.pvalue.TaggedOutput(tag, elem)

    with self.create_pipeline() as p:
        ints = p | beam.Create([1, 2, 3])

        def check(pcoll, expected, label):
            assert_that(pcoll, equal_to(expected), label=label)

        named = ints | 'named' >> beam.FlatMap(
            even_odd).with_outputs('even', 'odd', main='all')
        check(named.all, [1, 2, 3], 'named.all')
        check(named.even, [2], 'named.even')
        check(named.odd, [1, 3], 'named.odd')

        unnamed = ints | 'unnamed' >> beam.FlatMap(even_odd).with_outputs()
        # Consume the untagged main output once before asserting on it.
        _ = unnamed[None] | beam.Map(id)
        check(unnamed[None], [1, 2, 3], 'unnamed.all')
        check(unnamed.even, [2], 'unnamed.even')
        check(unnamed.odd, [1, 3], 'unnamed.odd')
def test_stats_pipeline_with_examples_with_no_values(self):
    """Features present with empty value arrays still count as non-missing.

    All three examples carry empty 'a' (FLOAT), 'b' (STRING) and 'c' (INT)
    arrays, plus a weight feature 'w' of 2. The weighted non-missing count is
    therefore 6.
    """
    # np.object was removed in NumPy 1.24 and the abstract np.floating dtype
    # is deprecated; `object` and np.float64 are the concrete equivalents.
    examples = [{'a': np.array([], dtype=np.float64),
                 'b': np.array([], dtype=object),
                 'c': np.array([], dtype=np.int32),
                 'w': np.array([2])},
                {'a': np.array([], dtype=np.float64),
                 'b': np.array([], dtype=object),
                 'c': np.array([], dtype=np.int32),
                 'w': np.array([2])},
                {'a': np.array([], dtype=np.float64),
                 'b': np.array([], dtype=object),
                 'c': np.array([], dtype=np.int32),
                 'w': np.array([2])}]

    expected_result = text_format.Parse(
        """
        datasets{
          num_examples: 3
          features {
            name: 'a'
            type: FLOAT
            num_stats {
              common_stats {
                num_non_missing: 3
                num_values_histogram {
                  buckets {
                    sample_count: 1.5
                  }
                  buckets {
                    sample_count: 1.5
                  }
                  type: QUANTILES
                }
                weighted_common_stats {
                  num_non_missing: 6
                }
              }
            }
          }
          features {
            name: 'b'
            type: STRING
            string_stats {
              common_stats {
                num_non_missing: 3
                num_values_histogram {
                  buckets {
                    sample_count: 1.5
                  }
                  buckets {
                    sample_count: 1.5
                  }
                  type: QUANTILES
                }
                weighted_common_stats {
                  num_non_missing: 6
                }
              }
            }
          }
          features {
            name: 'c'
            type: INT
            num_stats {
              common_stats {
                num_non_missing: 3
                num_values_histogram {
                  buckets {
                    sample_count: 1.5
                  }
                  buckets {
                    sample_count: 1.5
                  }
                  type: QUANTILES
                }
                weighted_common_stats {
                  num_non_missing: 6
                }
              }
            }
          }
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())

    with beam.Pipeline() as p:
        options = stats_options.StatsOptions(
            weight_feature='w',
            num_top_values=1,
            num_rank_histogram_buckets=1,
            num_values_histogram_buckets=2,
            num_histogram_buckets=1,
            num_quantiles_histogram_buckets=1,
            epsilon=0.001)
        result = (
            p | beam.Create(examples)
            | stats_api.GenerateStatistics(options))
        util.assert_that(
            result,
            test_util.make_dataset_feature_stats_list_proto_equal_fn(
                self, expected_result))
def testCalibrationPlot(self):
    """CalibrationPlot(num_buckets=10) produces the expected weighted buckets.

    The expected plot has the ten [0, 1) buckets plus a (-inf, 0.0) and a
    [1.0, inf) edge bucket. Each holds total weighted label, weighted
    prediction and weighted example count.
    """
    computations = calibration_plot.CalibrationPlot(
        num_buckets=10).computations()
    histogram, plot = computations[0], computations[1]

    # (label, prediction, example_weight) for each input example.
    rows = [
        (0.0, 0.2, 1.0),
        (1.0, 0.8, 2.0),
        (0.0, 0.5, 3.0),
        (1.0, -0.1, 4.0),
        (1.0, 0.5, 5.0),
        (1.0, 0.8, 6.0),
        (0.0, 0.2, 7.0),
        (1.0, 1.1, 8.0),
    ]
    examples = [
        {
            'labels': np.array([label]),
            'predictions': np.array([prediction]),
            'example_weights': np.array([weight]),
        }
        for label, prediction, weight in rows
    ]

    expected_plot = """
        buckets {
          lower_threshold_inclusive: -inf
          upper_threshold_exclusive: 0.0
          total_weighted_label { value: 4.0 }
          total_weighted_refined_prediction { value: -0.4 }
          num_weighted_examples { value: 4.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.0
          upper_threshold_exclusive: 0.1
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.1
          upper_threshold_exclusive: 0.2
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.2
          upper_threshold_exclusive: 0.3
          total_weighted_label { }
          total_weighted_refined_prediction { value: 1.6 }
          num_weighted_examples { value: 8.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.3
          upper_threshold_exclusive: 0.4
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.4
          upper_threshold_exclusive: 0.5
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.5
          upper_threshold_exclusive: 0.6
          total_weighted_label { value: 5.0 }
          total_weighted_refined_prediction { value: 4.0 }
          num_weighted_examples { value: 8.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.6
          upper_threshold_exclusive: 0.7
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.7
          upper_threshold_exclusive: 0.8
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.8
          upper_threshold_exclusive: 0.9
          total_weighted_label { value: 8.0 }
          total_weighted_refined_prediction { value: 6.4 }
          num_weighted_examples { value: 8.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.9
          upper_threshold_exclusive: 1.0
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 1.0
          upper_threshold_exclusive: inf
          total_weighted_label { value: 8.0 }
          total_weighted_refined_prediction { value: 8.8 }
          num_weighted_examples { value: 8.0 }
        }
    """

    with beam.Pipeline() as pipeline:
        # pylint: disable=no-value-for-parameter
        result = (
            pipeline
            | 'Create' >> beam.Create(examples)
            | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
            | 'AddSlice' >> beam.Map(lambda x: ((), x))
            | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
            | 'ComputePlot' >> beam.Map(
                lambda x: (x[0], plot.result(x[1]))))
        # pylint: enable=no-value-for-parameter

        def check_result(got):
            try:
                self.assertLen(got, 1)
                got_slice_key, got_plots = got[0]
                self.assertEqual(got_slice_key, ())
                self.assertLen(got_plots, 1)
                plot_key = metric_types.PlotKey(name='calibration_plot')
                self.assertIn(plot_key, got_plots)
                self.assertProtoEquals(expected_plot, got_plots[plot_key])
            except AssertionError as err:
                # Beam matchers report failures via BeamAssertException.
                raise util.BeamAssertException(err)

        util.assert_that(result, check_result, label='result')
def test_read(self):
    """Reading a local text file (/etc/profile) yields at least one line.

    NOTE(review): this depends on /etc/profile existing on the test host.
    """
    def assert_not_empty(lines):
        # assert_that matchers signal failure by raising, as equal_to does.
        # The previous matcher `lambda lines: len(lines) > 0` only returned a
        # bool, which is discarded, so it could never fail.
        if len(lines) == 0:
            raise AssertionError('Expected at least one line, got none.')

    with self.create_pipeline() as p:
        lines = p | beam.io.ReadFromText('/etc/profile')
        assert_that(lines, assert_not_empty)
def test_synthetic_sdf_step_multiplies_output_elements_count(self):
    """get_synthetic_sdf_step(0, 0, 10) turns 10 inputs into 100 outputs."""
    sdf_step = synthetic_pipeline.get_synthetic_sdf_step(0, 0, 10)
    with beam.Pipeline() as p:
        outputs = p | beam.Create(list(range(10))) | beam.ParDo(sdf_step)
        total = outputs | beam.combiners.Count.Globally()
        assert_that(total, equal_to([100]))
def test_basic_empty_missing(self):
    """Test that the correct empty result is returned for a missing month."""
    missing_month = 4
    with TestPipeline() as p:
        assert_that(self._get_result_for_month(p, missing_month), equal_to([]))
def run(argv=None, assert_results=None): parser = argparse.ArgumentParser() parser.add_argument( '--input_email', required=True, help='Email database, with each line formatted as "name<TAB>email".') parser.add_argument( '--input_phone', required=True, help='Phonebook, with each line formatted as "name<TAB>phone number".') parser.add_argument( '--input_snailmail', required=True, help='Address database, with each line formatted as "name<TAB>address".' ) parser.add_argument('--output_tsv', required=True, help='Tab-delimited output file.') parser.add_argument('--output_stats', required=True, help='Output file for statistics about the input.') known_args, pipeline_args = parser.parse_known_args(argv) # We use the save_main_session option because one or more DoFn's in this # workflow rely on global context (e.g., a module imported at module level). pipeline_options = PipelineOptions(pipeline_args) pipeline_options.view_as(SetupOptions).save_main_session = True with beam.Pipeline(options=pipeline_options) as p: # Helper: read a tab-separated key-value mapping from a text file, # escape all quotes/backslashes, and convert it a PCollection of # (key, value) pairs. def read_kv_textfile(label, textfile): return (p | 'Read: %s' % label >> ReadFromText(textfile) | 'Backslash: %s' % label >> beam.Map(lambda x: re.sub(r'\\', r'\\\\', x)) | 'EscapeQuotes: %s' % label >> beam.Map(lambda x: re.sub(r'"', r'\"', x)) | 'Split: %s' % label >> beam.Map(lambda x: re.split(r'\t+', x, 1))) # Read input databases. email = read_kv_textfile('email', known_args.input_email) phone = read_kv_textfile('phone', known_args.input_phone) snailmail = read_kv_textfile('snailmail', known_args.input_snailmail) # Group together all entries under the same name. 
grouped = (email, phone, snailmail) | 'group_by_name' >> beam.CoGroupByKey() # Prepare tab-delimited output; something like this: # "name"<TAB>"email_1,email_2"<TAB>"phone"<TAB>"first_snailmail_only" tsv_lines = grouped | beam.Map( lambda (name, (email, phone, snailmail)): '\t'.join([ '"%s"' % name, '"%s"' % ','.join(email), '"%s"' % ','.join(phone), '"%s"' % next(iter(snailmail), '') ])) # Compute some stats about our database of people. luddites = grouped | beam.Filter( # People without email. lambda (name, (email, phone, snailmail)): not next(iter(email), None)) writers = grouped | beam.Filter( # People without phones. lambda (name, (email, phone, snailmail)): not next(iter(phone), None)) nomads = grouped | beam.Filter( # People without addresses. lambda (name, (e, p, snailmail)): not next(iter(snailmail), None)) num_luddites = luddites | 'Luddites' >> beam.combiners.Count.Globally() num_writers = writers | 'Writers' >> beam.combiners.Count.Globally() num_nomads = nomads | 'Nomads' >> beam.combiners.Count.Globally() # Write tab-delimited output. # pylint: disable=expression-not-assigned tsv_lines | 'WriteTsv' >> WriteToText(known_args.output_tsv) # TODO(silviuc): Move the assert_results logic to the unit test. if assert_results is not None: expected_luddites, expected_writers, expected_nomads = assert_results assert_that(num_luddites, equal_to([expected_luddites]), label='assert:luddites') assert_that(num_writers, equal_to([expected_writers]), label='assert:writers') assert_that(num_nomads, equal_to([expected_nomads]), label='assert:nomads')
def testEqual(self):
  """CalculateCoefficients(0.5) over the fixture input yields 1.0 for en and fr."""
  with TestPipeline() as pipeline:
    token_pcoll = pipeline | beam.Create(self.sample_input)
    combine_fn = utils.CalculateCoefficients(0.5)
    coefficients = token_pcoll | beam.CombineGlobally(combine_fn)
    expected = [{'en': 1.0, 'fr': 1.0}]
    assert_that(coefficients, equal_to(expected))
def test_basic_empty(self):
  """Querying month 3 of the simple dataset produces an empty result."""
  with TestPipeline() as pipeline:
    month_results = self._get_result_for_month(pipeline, 3)
    no_rows = []
    assert_that(month_results, equal_to(no_rows))
def testUnsorted(self):
  """SortByCount emits a single list of (token, count) in descending count order."""
  expected_ranking = [[('c', 9), ('a', 5), ('d', 4), ('b', 2)]]
  with TestPipeline() as pipeline:
    token_pcoll = pipeline | 'CreateInput' >> beam.Create(self.sample_input)
    ranked = token_pcoll | beam.CombineGlobally(utils.SortByCount())
    assert_that(ranked, equal_to(expected_ranking))
def testEvaluateWithSlicingAndDifferentBatchSizes(self):
  """Sliced Evaluate metrics are identical for batch sizes 1, 2, 4 and 8.

  Builds five examples split across 'first_slice' and 'second_slice'. Runs
  evaluate.Evaluate with an overall slice plus a per-'slice_key' slice, then
  checks the metrics of all three slices and that no plots are emitted.
  """
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = linear_classifier.simple_linear_classifier(
      None, temp_eval_export_dir)

  for batch_size in [1, 2, 4, 8]:
    # A fresh pipeline per batch size; it runs when the `with` block exits,
    # so every assert_that below must be registered inside it.
    with beam.Pipeline() as pipeline:
      example1 = self._makeExample(
          age=3.0, language='english', label=1.0, slice_key='first_slice')
      example2 = self._makeExample(
          age=3.0, language='chinese', label=0.0, slice_key='first_slice')
      example3 = self._makeExample(
          age=4.0, language='english', label=0.0, slice_key='second_slice')
      example4 = self._makeExample(
          age=5.0, language='chinese', label=1.0, slice_key='second_slice')
      example5 = self._makeExample(
          age=5.0, language='chinese', label=1.0, slice_key='second_slice')

      metrics, plots = (
          pipeline
          | beam.Create([
              example1.SerializeToString(),
              example2.SerializeToString(),
              example3.SerializeToString(),
              example4.SerializeToString(),
              example5.SerializeToString(),
          ])
          | evaluate.Evaluate(
              eval_saved_model_path=eval_export_dir,
              add_metrics_callbacks=[_addExampleCountMetricCallback],
              slice_spec=[
                  slicer.SingleSliceSpec(),
                  slicer.SingleSliceSpec(columns=['slice_key'])
              ],
              desired_batch_size=batch_size))

      def check_result(got):
        """Checks (slice_key, metrics_dict) pairs for the three slices."""
        try:
          self.assertEqual(3, len(got), 'got: %s' % got)
          slices = {}
          for slice_key, value in got:
            slices[slice_key] = value
          overall_slice = ()
          first_slice = (('slice_key', 'first_slice'),)
          second_slice = (('slice_key', 'second_slice'),)
          self.assertItemsEqual(slices.keys(),
                                [overall_slice, first_slice, second_slice])
          # Overall: ages 3,3,4,5,5 and labels 1,0,0,1,1 across 5 examples.
          self.assertDictElementsAlmostEqual(
              slices[overall_slice], {
                  'accuracy': 0.4,
                  'label/mean': 0.6,
                  'my_mean_age': 4.0,
                  'my_mean_age_times_label': 2.6,
                  'added_example_count': 5.0
              })
          # first_slice: examples 1-2 (ages 3,3; labels 1,0).
          self.assertDictElementsAlmostEqual(
              slices[first_slice], {
                  'accuracy': 1.0,
                  'label/mean': 0.5,
                  'my_mean_age': 3.0,
                  'my_mean_age_times_label': 1.5,
                  'added_example_count': 2.0
              })
          # second_slice: examples 3-5 (ages 4,5,5; labels 0,1,1).
          self.assertDictElementsAlmostEqual(
              slices[second_slice], {
                  'accuracy': 0.0,
                  'label/mean': 2.0 / 3.0,
                  'my_mean_age': 14.0 / 3.0,
                  'my_mean_age_times_label': 10.0 / 3.0,
                  'added_example_count': 3.0
              })
        except AssertionError as err:
          # This function is redefined every iteration, so it will have the
          # right value of batch_size.
          raise util.BeamAssertException('batch_size = %d, error: %s' %
                                         (batch_size, err))  # pylint: disable=cell-var-from-loop

      util.assert_that(metrics, check_result, label='metrics')
      util.assert_that(plots, util.is_empty(), label='plots')
def test_run_direct(self):
  """A LineSource over a four-line temp file reads back each line as bytes."""
  contents = b'aaaa\nbbbb\ncccc\ndddd'
  file_name = self._create_temp_file(contents)
  with TestPipeline() as pipeline:
    read_lines = pipeline | beam.io.Read(LineSource(file_name))
    expected_lines = [b'aaaa', b'bbbb', b'cccc', b'dddd']
    assert_that(read_lines, equal_to(expected_lines))
def testLangNotInLangSet(self):
  """Filtering the fixture tokens to the {'fr'} language set keeps nothing."""
  with TestPipeline() as pipeline:
    token_pcoll = pipeline | beam.Create(self.sample_input)
    lang_filter = utils.FilterTokensByLang({'fr'})
    kept = token_pcoll | beam.ParDo(lang_filter)
    assert_that(kept, equal_to([]))
def test_multiple_destinations_transform(self):
  """WriteToBigQuery routes rows to two tables via side-input-driven callables.

  Rows containing 'language' go to table 1 and all others to table 2. The
  table name and schema are resolved through AsDict side inputs. One
  malformed record is expected to come back on the FAILED_ROWS output. Runs
  in batch (beam.Create) or streaming (TestStream) mode, depending on the
  test pipeline options.
  """
  streaming = self.test_pipeline.options.view_as(StandardOptions).streaming
  if streaming and isinstance(self.test_pipeline.runner, TestDataflowRunner):
    self.skipTest("TestStream is not supported on TestDataflowRunner")

  output_table_1 = '%s%s' % (self.output_table, 1)
  output_table_2 = '%s%s' % (self.output_table, 2)

  # Fully qualified "project:dataset.table" names used as routing targets.
  full_output_table_1 = '%s:%s' % (self.project, output_table_1)
  full_output_table_2 = '%s:%s' % (self.project, output_table_2)

  schema1 = {
      'fields': [{
          'name': 'name',
          'type': 'STRING',
          'mode': 'NULLABLE'
      }, {
          'name': 'language',
          'type': 'STRING',
          'mode': 'NULLABLE'
      }]
  }
  schema2 = {
      'fields': [{
          'name': 'name',
          'type': 'STRING',
          'mode': 'NULLABLE'
      }, {
          'name': 'foundation',
          'type': 'STRING',
          'mode': 'NULLABLE'
      }]
  }

  # Contains 'language', so the table callable routes it to table 1. Its
  # 'manguage' field is not in schema1 and its 'language' value is an int;
  # the assertion below expects the row on FAILED_ROWS.
  bad_record = {'language': 1, 'manguage': 2}

  # Post-run checks of table contents, passed as on_success_matcher.
  if streaming:
    pipeline_verifiers = [
        PipelineStateMatcher(PipelineState.RUNNING),
        BigqueryFullResultStreamingMatcher(
            project=self.project,
            query="SELECT name, language FROM %s" % output_table_1,
            data=[(d['name'], d['language'])
                  for d in _ELEMENTS
                  if 'language' in d]),
        BigqueryFullResultStreamingMatcher(
            project=self.project,
            query="SELECT name, foundation FROM %s" % output_table_2,
            data=[(d['name'], d['foundation'])
                  for d in _ELEMENTS
                  if 'foundation' in d])
    ]
  else:
    pipeline_verifiers = [
        BigqueryFullResultMatcher(
            project=self.project,
            query="SELECT name, language FROM %s" % output_table_1,
            data=[(d['name'], d['language'])
                  for d in _ELEMENTS
                  if 'language' in d]),
        BigqueryFullResultMatcher(
            project=self.project,
            query="SELECT name, foundation FROM %s" % output_table_2,
            data=[(d['name'], d['foundation'])
                  for d in _ELEMENTS
                  if 'foundation' in d])
    ]

  args = self.test_pipeline.get_full_options_as_args(
      on_success_matcher=hc.all_of(*pipeline_verifiers))

  with beam.Pipeline(argv=args) as p:
    if streaming:
      # Emit the first half of the elements at watermark 0 and the rest at
      # watermark 100, then close the stream.
      _SIZE = len(_ELEMENTS)
      test_stream = (
          TestStream().advance_watermark_to(0).add_elements(
              _ELEMENTS[:_SIZE // 2]).advance_watermark_to(100).add_elements(
                  _ELEMENTS[_SIZE // 2:]).advance_watermark_to_infinity())
      # NOTE(review): `input` shadows the builtin; consider renaming.
      input = p | test_stream
    else:
      input = p | beam.Create(_ELEMENTS)

    # Side input: full table name -> schema, consumed by the schema callable.
    schema_table_pcv = beam.pvalue.AsDict(
        p | "MakeSchemas" >> beam.Create([(full_output_table_1, schema1),
                                          (full_output_table_2, schema2)]))

    # Side input: logical key ('table1'/'table2') -> full table name,
    # consumed by the table callable.
    table_record_pcv = beam.pvalue.AsDict(
        p | "MakeTables" >> beam.Create([('table1', full_output_table_1),
                                         ('table2', full_output_table_2)]))

    input2 = p | "Broken record" >> beam.Create([bad_record])

    input = (input, input2) | beam.Flatten()

    r = (
        input
        | "WriteWithMultipleDests" >> beam.io.gcp.bigquery.WriteToBigQuery(
            table=lambda x, tables: (tables['table1']
                                     if 'language' in x else tables['table2']),
            table_side_inputs=(table_record_pcv, ),
            schema=lambda dest, table_map: table_map.get(dest, None),
            schema_side_inputs=(schema_table_pcv, ),
            insert_retry_strategy=RetryStrategy.RETRY_ON_TRANSIENT_ERROR,
            method='STREAMING_INSERTS'))

    # Only the malformed record should be reported as failed, tagged with
    # its destination table.
    assert_that(
        r[beam.io.gcp.bigquery.BigQueryWriteFn.FAILED_ROWS],
        equal_to([(full_output_table_1, bad_record)]))
def test_native_source(self):
  """Reading self.query through the native BigQuerySource returns the expected rows."""
  with beam.Pipeline(argv=self.args) as pipeline:
    bq_source = beam.io.BigQuerySource(query=self.query, use_standard_sql=True)
    rows = pipeline | 'read' >> beam.io.Read(bq_source)
    assert_that(rows, equal_to(self.get_expected_data()))
def test_timestamp_param_map(self):
  """A Map callable can receive the element timestamp via DoFn.TimestampParam.

  Elements from Create are expected to carry MIN_TIMESTAMP.
  """
  with TestPipeline() as pipeline:
    created = pipeline | Create([1, 2])
    timestamps = created | beam.Map(
        lambda unused_element, ts=DoFn.TimestampParam: ts)
    expected = [MIN_TIMESTAMP, MIN_TIMESTAMP]
    assert_that(timestamps, equal_to(expected))
def expand(self, pcoll):
  """Attach an assert_that check of `pcoll` against `self.matcher`; returns None."""
  matcher = self.matcher
  assert_that(pcoll, matcher)
def test_fake_read(self):
  """Read(FakeSource(values)) emits exactly the values given to the source."""
  pipeline = TestPipeline()
  numbers = pipeline | 'read' >> Read(FakeSource([1, 2, 3]))
  expected = [1, 2, 3]
  assert_that(numbers, equal_to(expected))
  pipeline.run()
def test_read_from_avro(self):
  """Records written by _write_data are read back unchanged by ReadFromAvro."""
  avro_path = self._write_data()
  with TestPipeline() as pipeline:
    records = pipeline | avroio.ReadFromAvro(
        avro_path, use_fastavro=self.use_fastavro)
    assert_that(records, equal_to(self.RECORDS))
def testEvaluateWithPlots(self):
  """ComputeMetricsAndPlots emits an example count and AUC plot matrices.

  Uses a fixed-prediction estimator with four (prediction, label) examples
  and no slicing, so a single overall slice () is expected for both the
  metrics and the plots outputs.
  """
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = (
      fixed_prediction_estimator.simple_fixed_prediction_estimator(
          None, temp_eval_export_dir))
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=eval_export_dir,
      add_metrics_callbacks=[
          post_export_metrics.example_count(),
          post_export_metrics.auc_plots()
      ])
  extractors = [
      predict_extractor.PredictExtractor(eval_shared_model),
      slice_key_extractor.SliceKeyExtractor()
  ]

  # The pipeline runs when the `with` block exits, so both assert_that calls
  # below must be registered inside it.
  with beam.Pipeline() as pipeline:
    example1 = self._makeExample(prediction=0.0, label=1.0)
    example2 = self._makeExample(prediction=0.7, label=0.0)
    example3 = self._makeExample(prediction=0.8, label=1.0)
    example4 = self._makeExample(prediction=1.0, label=1.0)

    (metrics, plots), _ = (
        pipeline
        | 'Create' >> beam.Create([
            example1.SerializeToString(),
            example2.SerializeToString(),
            example3.SerializeToString(),
            example4.SerializeToString()
        ])
        | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
        | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
        | 'ComputeMetricsAndPlots' >> metrics_and_plots_evaluator
        .ComputeMetricsAndPlots(eval_shared_model=eval_shared_model))

    def check_metrics(got):
      """Expects one overall-slice entry whose example count is 4."""
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        self.assertDictElementsAlmostEqual(
            got_values_dict=value,
            expected_values_dict={
                metric_keys.EXAMPLE_COUNT: 4.0,
            })
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(metrics, check_metrics, label='metrics')

    def check_plots(got):
      """Expects one overall-slice entry with the given AUC matrix row."""
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        # NOTE(review): row 8001 looks like [fn, tn, fp, tp, precision,
        # recall] at a threshold near 0.8 for the four examples above --
        # confirm against post_export_metrics.auc_plots.
        self.assertDictMatrixRowsAlmostEqual(
            got_values_dict=value,
            expected_values_dict={
                metric_keys.AUC_PLOTS_MATRICES: [
                    (8001, [2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0])
                ],
            })
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(plots, check_plots, label='plots')
def test_apply_custom_transform(self):
  """PipelineTest.CustomTransform maps [1, 2, 3] to [2, 3, 4] (per the assertion)."""
  pipeline = TestPipeline()
  source = pipeline | 'pcoll' >> Create([1, 2, 3])
  transformed = source | PipelineTest.CustomTransform()
  assert_that(transformed, equal_to([2, 3, 4]))
  pipeline.run()
def test_input_output_polymorphism(self):
  """DataframeTransform accepts and returns single, tuple and dict PCollections."""
  one_series = pd.Series([1])
  two_series = pd.Series([2])
  three_series = pd.Series([3])
  # An empty slice of a sample series serves as the schema proxy.
  proxy = one_series[:0]

  def equal_to_series(expected):
    """Build an assert_that matcher comparing concatenated output to `expected`."""

    def check(actual):
      combined = pd.concat(actual)
      if expected.equals(combined):
        return
      raise AssertionError('Series not equal: \n%s\n%s\n' %
                           (expected, combined))

    return check

  with beam.Pipeline() as p:
    one = p | 'One' >> beam.Create([one_series])
    two = p | 'Two' >> beam.Create([two_series])

    # Single PCollection in, single PCollection out.
    tripled = one | 'PcollInPcollOut' >> transforms.DataframeTransform(
        lambda x: 3 * x, proxy=proxy, yield_elements='pandas')
    assert_that(tripled,
                equal_to_series(three_series),
                label='CheckPcollInPcollOut')

    # Tuple of PCollections in (positional lambda arguments).
    tuple_sum = (one, two) | 'TupleIn' >> transforms.DataframeTransform(
        lambda x, y: (x + y), (proxy, proxy), yield_elements='pandas')
    assert_that(tuple_sum,
                equal_to_series(three_series),
                label='CheckTupleIn')

    # Dict of PCollections in; the keys must match the lambda's parameter
    # names.
    dict_sum = dict(x=one, y=two) | 'DictIn' >> transforms.DataframeTransform(
        lambda x, y: (x + y),
        proxy=dict(x=proxy, y=proxy),
        yield_elements='pandas')
    assert_that(dict_sum,
                equal_to_series(three_series),
                label='CheckDictIn')

    # Tuple out.
    times_two, times_three = one | 'TupleOut' >> transforms.DataframeTransform(
        lambda x: (2 * x, 3 * x), proxy, yield_elements='pandas')
    assert_that(times_two, equal_to_series(two_series), label='CheckTupleOut0')
    assert_that(times_three,
                equal_to_series(three_series),
                label='CheckTupleOut1')

    # Dict out.
    dict_out = one | 'DictOut' >> transforms.DataframeTransform(
        lambda x: {'res': 3 * x}, proxy, yield_elements='pandas')
    assert_that(dict_out['res'],
                equal_to_series(three_series),
                label='CheckDictOut')
def test_create_singleton_pcollection(self):
  """Create([[1, 2, 3]]) yields one element: the list itself."""
  pipeline = TestPipeline()
  nested = pipeline | 'label' >> Create([[1, 2, 3]])
  assert_that(nested, equal_to([[1, 2, 3]]))
  pipeline.run()
def test_create(self):
  """beam.Create round-trips a small list of strings."""
  with self.create_pipeline() as pipeline:
    letters = pipeline | beam.Create(['a', 'b'])
    assert_that(letters, equal_to(['a', 'b']))
def testEvaluateNoSlicingAddPostExportAndCustomMetrics(self):
  """Unsliced evaluation combines model, custom, py_func and post-export metrics.

  Four examples run through the extractors and ComputeMetricsAndPlots. The
  single overall-slice metrics dict is checked for every metric family, and
  the plots output is expected to be empty.
  """
  temp_eval_export_dir = self._getEvalExportDir()
  _, eval_export_dir = linear_classifier.simple_linear_classifier(
      None, temp_eval_export_dir)
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=eval_export_dir,
      add_metrics_callbacks=[
          _addExampleCountMetricCallback,
          # Note that since everything runs in-process this doesn't
          # actually test that the py_func can be correctly recreated
          # on workers in a distributed context.
          _addPyFuncMetricCallback,
          post_export_metrics.example_count(),
          post_export_metrics.example_weight(example_weight_key='age')
      ])
  extractors = [
      predict_extractor.PredictExtractor(eval_shared_model),
      slice_key_extractor.SliceKeyExtractor()
  ]

  # The pipeline runs when the `with` block exits, so the assert_that calls
  # must be registered inside it.
  with beam.Pipeline() as pipeline:
    example1 = self._makeExample(age=3.0, language='english', label=1.0)
    example2 = self._makeExample(age=3.0, language='chinese', label=0.0)
    example3 = self._makeExample(age=4.0, language='english', label=1.0)
    example4 = self._makeExample(age=5.0, language='chinese', label=0.0)

    (metrics, plots), _ = (
        pipeline
        | 'Create' >> beam.Create([
            example1.SerializeToString(),
            example2.SerializeToString(),
            example3.SerializeToString(),
            example4.SerializeToString()
        ])
        | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
        | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
        | 'ComputeMetricsAndPlots' >> metrics_and_plots_evaluator
        .ComputeMetricsAndPlots(eval_shared_model=eval_shared_model))

    def check_result(got):
      """Expects one overall-slice entry with the metric values below."""
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        # Expected values are consistent with the examples above: ages
        # 3,3,4,5 (sum 15, mean 3.75), labels 1,0,1,0 (sum 2, mean 0.5) and
        # age*label 3,0,4,0 (mean 1.75). EXAMPLE_WEIGHT uses
        # example_weight_key='age', hence 15.0.
        self.assertDictElementsAlmostEqual(
            got_values_dict=value,
            expected_values_dict={
                'accuracy': 1.0,
                'label/mean': 0.5,
                'my_mean_age': 3.75,
                'my_mean_age_times_label': 1.75,
                'added_example_count': 4.0,
                'py_func_label_sum': 2.0,
                metric_keys.EXAMPLE_COUNT: 4.0,
                metric_keys.EXAMPLE_WEIGHT: 15.0
            })
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(metrics, check_result, label='metrics')
    util.assert_that(plots, util.is_empty(), label='plots')