Example #1
  def test_gbk_execution(self):
    test_stream = (TestStream()
                   .advance_watermark_to(10)
                   .add_elements(['a', 'b', 'c'])
                   .advance_watermark_to(20)
                   .add_elements(['d'])
                   .add_elements(['e'])
                   .advance_processing_time(10)
                   .advance_watermark_to(300)
                   .add_elements([TimestampedValue('late', 12)])
                   .add_elements([TimestampedValue('last', 310)]))

    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)
    records = (p
               | test_stream
               | beam.WindowInto(FixedWindows(15))
               | beam.Map(lambda x: ('k', x))
               | beam.GroupByKey())
    # TODO(BEAM-2519): timestamp assignment for elements from a GBK should
    # respect the TimestampCombiner.  The test below should also verify the
    # timestamps of the outputted elements once this is implemented.
    assert_that(records, equal_to([
        ('k', ['a', 'b', 'c']),
        ('k', ['d', 'e']),
        ('k', ['late']),
        ('k', ['last'])]))
    p.run()
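The TODO above refers to the timestamp_combiner argument that WindowInto accepts. A minimal sketch of how it could be configured, not part of the original test; the OUTPUT_AT_LATEST choice and the sample elements are assumptions:

import apache_beam as beam
from apache_beam.transforms.window import (
    FixedWindows, TimestampCombiner, TimestampedValue)

with beam.Pipeline() as p:
  _ = (p
       | beam.Create([('k', 1), ('k', 14)])
       | beam.Map(lambda kv: TimestampedValue(kv, kv[1]))
       | beam.WindowInto(
           FixedWindows(15),
           # Assumed choice: stamp each grouped output with the latest
           # element timestamp in its window instead of the default
           # end-of-window timestamp.
           timestamp_combiner=TimestampCombiner.OUTPUT_AT_LATEST)
       | beam.GroupByKey())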
Example #2
 def test_read_from_text_file_pattern(self):
   pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
   assert len(expected_data) == 40
   pipeline = TestPipeline()
   pcoll = pipeline | 'Read' >> ReadFromText(pattern)
   assert_that(pcoll, equal_to(expected_data))
   pipeline.run()
Example #3
  def test_multi(self):

    @ptransform.PTransform.register_urn('multi', None)
    class MultiTransform(ptransform.PTransform):
      def expand(self, pcolls):
        return {
            'main':
                (pcolls['main1'], pcolls['main2'])
                | beam.Flatten()
                | beam.Map(lambda x, s: x + s,
                           beam.pvalue.AsSingleton(pcolls['side'])),
            'side': pcolls['side'] | beam.Map(lambda x: x + x),
        }

      def to_runner_api_parameter(self, unused_context):
        return 'multi', None

      @staticmethod
      def from_runner_api_parameter(unused_parameter, unused_context):
        return MultiTransform()

    with beam.Pipeline() as p:
      main1 = p | 'Main1' >> beam.Create(['a', 'bb'], reshuffle=False)
      main2 = p | 'Main2' >> beam.Create(['x', 'yy', 'zzz'], reshuffle=False)
      side = p | 'Side' >> beam.Create(['s'])
      res = dict(main1=main1, main2=main2, side=side) | beam.ExternalTransform(
          'multi', None, expansion_service.ExpansionServiceServicer())
      assert_that(res['main'], equal_to(['as', 'bbs', 'xs', 'yys', 'zzzs']))
      assert_that(res['side'], equal_to(['ss']), label='CheckSide')
Example #4
  def test_payload(self):

    @ptransform.PTransform.register_urn('payload', bytes)
    class PayloadTransform(ptransform.PTransform):
      def __init__(self, payload):
        self._payload = payload

      def expand(self, pcoll):
        return pcoll | beam.Map(lambda x, s: x + s, self._payload)

      def to_runner_api_parameter(self, unused_context):
        return b'payload', self._payload.encode('ascii')

      @staticmethod
      def from_runner_api_parameter(payload, unused_context):
        return PayloadTransform(payload.decode('ascii'))

    with beam.Pipeline() as p:
      res = (
          p
          | beam.Create(['a', 'bb'], reshuffle=False)
          | beam.ExternalTransform(
              'payload', b's',
              expansion_service.ExpansionServiceServicer()))
      assert_that(res, equal_to(['as', 'bbs']))
Example #5
  def test_setting_timestamp(self):
    with TestPipeline() as p:
      unkeyed_items = p | beam.Create([12, 30, 60, 61, 66])
      items = (unkeyed_items | 'key' >> beam.Map(lambda x: ('k', x)))

      def extract_timestamp_from_log_entry(entry):
        return entry[1]

      # [START setting_timestamp]
      class AddTimestampDoFn(beam.DoFn):

        def process(self, element):
          # Extract the numeric Unix seconds-since-epoch timestamp to be
          # associated with the current log entry.
          unix_timestamp = extract_timestamp_from_log_entry(element)
          # Wrap and emit the current entry and new timestamp in a
          # TimestampedValue.
          yield beam.window.TimestampedValue(element, unix_timestamp)

      timestamped_items = items | 'timestamp' >> beam.ParDo(AddTimestampDoFn())
      # [END setting_timestamp]
      fixed_windowed_items = (
          timestamped_items | 'window' >> beam.WindowInto(
              beam.window.FixedWindows(60)))
      summed = (fixed_windowed_items
                | 'group' >> beam.GroupByKey()
                | 'combine' >> beam.CombineValues(sum))
      unkeyed = summed | 'unkey' >> beam.Map(lambda x: x[1])
      assert_that(unkeyed, equal_to([42, 187]))
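For simple cases the same timestamping can be done with a plain Map instead of a DoFn; a minimal sketch, assuming elements whose second field already holds the Unix timestamp (mirroring extract_timestamp_from_log_entry above):

import apache_beam as beam

with beam.Pipeline() as p:
  timestamped = (p
                 | beam.Create([('k', 12), ('k', 30)])
                 # The second tuple field is assumed to be the Unix timestamp.
                 | beam.Map(lambda e: beam.window.TimestampedValue(e, e[1])))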
Example #6
 def test_read_from_text_single_file(self):
   file_name, expected_data = write_data(5)
   assert len(expected_data) == 5
   pipeline = TestPipeline()
   pcoll = pipeline | 'Read' >> ReadFromText(file_name)
   assert_that(pcoll, equal_to(expected_data))
   pipeline.run()
 def test_group_by_key(self):
   with self.create_pipeline() as p:
     res = (p
            | beam.Create([('a', 1), ('a', 2), ('b', 3)])
            | beam.GroupByKey()
            | beam.Map(lambda k_vs: (k_vs[0], sorted(k_vs[1]))))
     assert_that(res, equal_to([('a', [1, 2]), ('b', [3])]))
 def test_pardo(self):
   with self.create_pipeline() as p:
     res = (p
            | beam.Create(['a', 'bc'])
            | beam.Map(lambda e: e * 2)
            | beam.Map(lambda e: e + 'x'))
     assert_that(res, equal_to(['aax', 'bcbcx']))
 def test_pardo_windowed_side_inputs(self):
   with self.create_pipeline() as p:
     # Now with some windowing.
     pcoll = p | beam.Create(list(range(10))) | beam.Map(
         lambda t: window.TimestampedValue(t, t))
     # Intentionally choosing non-aligned windows to highlight the transition.
     main = pcoll | 'WindowMain' >> beam.WindowInto(window.FixedWindows(5))
     side = pcoll | 'WindowSide' >> beam.WindowInto(window.FixedWindows(7))
     res = main | beam.Map(lambda x, s: (x, sorted(s)),
                           beam.pvalue.AsList(side))
     assert_that(
         res,
         equal_to([
             # The window [0, 5) maps to the window [0, 7).
             (0, list(range(7))),
             (1, list(range(7))),
             (2, list(range(7))),
             (3, list(range(7))),
             (4, list(range(7))),
             # The window [5, 10) maps to the window [7, 14).
             (5, list(range(7, 10))),
             (6, list(range(7, 10))),
             (7, list(range(7, 10))),
             (8, list(range(7, 10))),
             (9, list(range(7, 10)))]),
         label='windowed')
 def test_gbk_side_input(self):
   with self.create_pipeline() as p:
     main = p | 'main' >> beam.Create([None])
     side = p | 'side' >> beam.Create([('a', 1)]) | beam.GroupByKey()
     assert_that(
         main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
         equal_to([(None, {'a': [1]})]))
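The side-input wrappers differ in what they hand to the mapped function: AsSingleton passes a single value, AsList the full list of elements, and AsDict a dict built from KV pairs. A minimal sketch contrasting AsList and AsDict (the sample values are illustrative):

import apache_beam as beam
from apache_beam.testing.util import assert_that, equal_to

with beam.Pipeline() as p:
  main = p | 'Main' >> beam.Create([0])
  side = p | 'Side' >> beam.Create([('a', 1), ('b', 2)])
  assert_that(
      main | 'UseList' >> beam.Map(lambda _, s: sorted(s),
                                   beam.pvalue.AsList(side)),
      equal_to([[('a', 1), ('b', 2)]]), label='as_list')
  assert_that(
      main | 'UseDict' >> beam.Map(lambda _, s: s, beam.pvalue.AsDict(side)),
      equal_to([{'a': 1, 'b': 2}]), label='as_dict')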
Example #11
 def test_read_all_from_avro_many_single_files(self):
   path1 = self._write_data()
   path2 = self._write_data()
   path3 = self._write_data()
   with TestPipeline() as p:
     assert_that(p | Create([path1, path2, path3]) | avroio.ReadAllFromAvro(),
                 equal_to(self.RECORDS * 3))
Example #12
  def test_to_list_and_to_dict(self):
    pipeline = TestPipeline()
    the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
    pcoll = pipeline | 'start' >> Create(the_list)
    result = pcoll | 'to list' >> combine.ToList()

    def matcher(expected):
      def match(actual):
        equal_to(expected[0])(actual[0])
      return match
    assert_that(result, matcher([the_list]))
    pipeline.run()

    pipeline = TestPipeline()
    pairs = [(1, 2), (3, 4), (5, 6)]
    pcoll = pipeline | 'start-pairs' >> Create(pairs)
    result = pcoll | 'to dict' >> combine.ToDict()

    def matcher():
      def match(actual):
        equal_to([1])([len(actual)])
        equal_to(pairs)(actual[0].items())
      return match
    assert_that(result, matcher())
    pipeline.run()
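In the same style, the combiners module (aliased as combine above) also provides Count.PerElement; a minimal sketch reusing the_list, with counts that follow directly from its contents:

    pipeline = TestPipeline()
    pcoll = pipeline | 'start-count' >> Create(the_list)
    result = pcoll | 'count per element' >> combine.Count.PerElement()
    assert_that(result, equal_to(
        [(6, 2), (3, 1), (1, 3), (9, 1), (5, 1), (2, 1), (0, 1)]))
    pipeline.run()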
Example #13
  def test_window_preserved(self):
    expected_timestamp = timestamp.Timestamp(5)
    expected_window = window.IntervalWindow(1.0, 2.0)

    class AddWindowDoFn(beam.DoFn):
      def process(self, element):
        yield WindowedValue(
            element, expected_timestamp, [expected_window])

    pipeline = TestPipeline()
    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    expected_windows = [
        TestWindowedValue(kv, expected_timestamp, [expected_window])
        for kv in data]
    before_identity = (pipeline
                       | 'start' >> beam.Create(data)
                       | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
    assert_that(before_identity, equal_to(expected_windows),
                label='before_identity', reify_windows=True)
    after_identity = (before_identity
                      | 'window' >> beam.WindowInto(
                          beam.transforms.util._IdentityWindowFn(
                              coders.IntervalWindowCoder())))
    assert_that(after_identity, equal_to(expected_windows),
                label='after_identity', reify_windows=True)
    pipeline.run()
Example #14
  def test_no_window_context_fails(self):
    expected_timestamp = timestamp.Timestamp(5)
    # Assuming the default window function is window.GlobalWindows.
    expected_window = window.GlobalWindow()

    class AddTimestampDoFn(beam.DoFn):
      def process(self, element):
        yield window.TimestampedValue(element, expected_timestamp)

    pipeline = TestPipeline()
    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    expected_windows = [
        TestWindowedValue(kv, expected_timestamp, [expected_window])
        for kv in data]
    before_identity = (pipeline
                       | 'start' >> beam.Create(data)
                       | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
    assert_that(before_identity, equal_to(expected_windows),
                label='before_identity', reify_windows=True)
    after_identity = (before_identity
                      | 'window' >> beam.WindowInto(
                          beam.transforms.util._IdentityWindowFn(
                              coders.GlobalWindowCoder()))
                      # This DoFn will return TimestampedValues, making
                      # WindowFn.AssignContext passed to IdentityWindowFn
                      # contain a window of None. IdentityWindowFn should
                      # raise an exception.
                      | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
    assert_that(after_identity, equal_to(expected_windows),
                label='after_identity', reify_windows=True)
    with self.assertRaisesRegexp(ValueError, r'window.*None.*add_timestamps2'):
      pipeline.run()
Example #15
 def test_timestamped_with_combiners(self):
   with TestPipeline() as p:
     result = (p
               # Create some initial test values.
               | 'start' >> Create([(k, k) for k in range(10)])
               # The purpose of the WindowInto transform is to establish a
               # FixedWindows windowing function for the PCollection.
               # It does not bucket elements into windows since the timestamps
               # from Create are not spaced 5 ms apart and very likely they all
               # fall into the same window.
               | 'w' >> WindowInto(FixedWindows(5))
               # Generate timestamped values using the values as timestamps.
               # Now there are values 5 ms apart and since Map propagates the
               # windowing function from input to output the output PCollection
               # will have elements falling into different 5ms windows.
               | Map(lambda x_t2: TimestampedValue(x_t2[0], x_t2[1]))
               # We add a 'key' to each value representing the index of the
               # window. This is important since there is no guarantee of
               # order for the elements of a PCollection.
                | Map(lambda v: (v // 5, v)))
     # Sum all elements associated with a key and window. Although it
     # is called CombinePerKey it is really CombinePerKeyAndWindow the
     # same way GroupByKey is really GroupByKeyAndWindow.
     sum_per_window = result | CombinePerKey(sum)
     # Compute mean per key and window.
     mean_per_window = result | combiners.Mean.PerKey()
     assert_that(sum_per_window, equal_to([(0, 10), (1, 35)]),
                 label='assert:sum')
     assert_that(mean_per_window, equal_to([(0, 2.0), (1, 7.0)]),
                 label='assert:mean')
Example #16
  def test_after_count(self):
    with TestPipeline() as p:
      def construct_timestamped(k_t):
        return TimestampedValue((k_t[0], k_t[1]), k_t[1])

      def format_result(k_v):
        return ('%s-%s' % (k_v[0], len(k_v[1])), set(k_v[1]))

      result = (p
                | beam.Create([1, 2, 3, 4, 5, 10, 11])
                | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
                | beam.Map(construct_timestamped)
                | beam.WindowInto(FixedWindows(10), trigger=AfterCount(3),
                                  accumulation_mode=AccumulationMode.DISCARDING)
                | beam.GroupByKey()
                | beam.Map(format_result))
      assert_that(result, equal_to(
          list(
              {
                  'A-5': {1, 2, 3, 4, 5},
                  # A-10, A-11 never emitted due to AfterCount(3) never firing.
                  'B-4': {6, 7, 8, 9},
                  'B-3': {10, 15, 16},
              }.items()
          )))
Example #17
  def test_read_messages_timestamp_attribute_rfc3339_success(self, mock_pubsub):
    data = 'data'
    message_id = 'message_id'
    attributes = {'time': '2018-03-12T13:37:01.234567Z'}
    publish_time = '2018-03-12T13:37:01.234567Z'
    payloads = [
        create_client_message(data, message_id, attributes, publish_time)]
    expected_elements = [
        TestWindowedValue(
            PubsubMessage(data, attributes),
            timestamp.Timestamp.from_rfc3339(attributes['time']),
            [window.GlobalWindow()]),
    ]

    mock_pubsub.Client = functools.partial(FakePubsubClient, payloads)
    mock_pubsub.subscription.AutoAck = FakeAutoAck

    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    pcoll = (p
             | ReadFromPubSub(
                 'projects/fakeprj/topics/a_topic', None, 'a_label',
                 with_attributes=True, timestamp_attribute='time'))
    assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
    p.run()
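Outside of a mocked test, the same read is built directly against a real topic; a minimal sketch in which the project and topic names are placeholders, not values from the original test:

import apache_beam as beam
from apache_beam.io.gcp.pubsub import ReadFromPubSub
from apache_beam.options.pipeline_options import PipelineOptions

options = PipelineOptions(streaming=True)
p = beam.Pipeline(options=options)
messages = (p
            | ReadFromPubSub(
                topic='projects/my-project/topics/my-topic',  # placeholder
                with_attributes=True,
                timestamp_attribute='time'))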
Example #18
  def test_read_messages_timestamp_attribute_missing(self, mock_pubsub):
    data = 'data'
    attributes = {}
    publish_time_secs = 1520861821
    publish_time_nanos = 234567000
    publish_time = '2018-03-12T13:37:01.234567Z'
    ack_id = 'ack_id'
    pull_response = test_utils.create_pull_response([
        test_utils.PullResponseMessage(
            data, attributes, publish_time_secs, publish_time_nanos, ack_id)
    ])
    expected_elements = [
        TestWindowedValue(
            PubsubMessage(data, attributes),
            timestamp.Timestamp.from_rfc3339(publish_time),
            [window.GlobalWindow()]),
    ]
    mock_pubsub.return_value.pull.return_value = pull_response

    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    pcoll = (p
             | ReadFromPubSub(
                 'projects/fakeprj/topics/a_topic', None, None,
                 with_attributes=True, timestamp_attribute='nonexistent'))
    assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
    p.run()
    mock_pubsub.return_value.acknowledge.assert_has_calls([
        mock.call(mock.ANY, [ack_id])])
Example #19
  def test_pardo_side_inputs(self):
    def cross_product(elem, sides):
      for side in sides:
        yield elem, side
    with self.create_pipeline() as p:
      main = p | 'main' >> beam.Create(['a', 'b', 'c'])
      side = p | 'side' >> beam.Create(['x', 'y'])
      assert_that(main | beam.FlatMap(cross_product, beam.pvalue.AsList(side)),
                  equal_to([('a', 'x'), ('b', 'x'), ('c', 'x'),
                            ('a', 'y'), ('b', 'y'), ('c', 'y')]))

      # Now with some windowing.
      pcoll = p | beam.Create(range(10)) | beam.Map(
          lambda t: window.TimestampedValue(t, t))
      # Intentionally choosing non-aligned windows to highlight the transition.
      main = pcoll | 'WindowMain' >> beam.WindowInto(window.FixedWindows(5))
      side = pcoll | 'WindowSide' >> beam.WindowInto(window.FixedWindows(7))
      res = main | beam.Map(lambda x, s: (x, sorted(s)),
                            beam.pvalue.AsList(side))
      assert_that(
          res,
          equal_to([
              # The window [0, 5) maps to the window [0, 7).
              (0, list(range(7))),
              (1, list(range(7))),
              (2, list(range(7))),
              (3, list(range(7))),
              (4, list(range(7))),
              # The window [5, 10) maps to the window [7, 14).
              (5, list(range(7, 10))),
              (6, list(range(7, 10))),
              (7, list(range(7, 10))),
              (8, list(range(7, 10))),
              (9, list(range(7, 10)))]),
          label='windowed')
Example #20
  def test_read_messages_timestamp_attribute_milli_success(self, mock_pubsub):
    data = b'data'
    attributes = {'time': '1337'}
    publish_time_secs = 1520861821
    publish_time_nanos = 234567000
    ack_id = 'ack_id'
    pull_response = test_utils.create_pull_response([
        test_utils.PullResponseMessage(
            data, attributes, publish_time_secs, publish_time_nanos, ack_id)
    ])
    expected_elements = [
        TestWindowedValue(
            PubsubMessage(data, attributes),
            timestamp.Timestamp(micros=int(attributes['time']) * 1000),
            [window.GlobalWindow()]),
    ]
    mock_pubsub.return_value.pull.return_value = pull_response

    options = PipelineOptions([])
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)
    pcoll = (p
             | ReadFromPubSub(
                 'projects/fakeprj/topics/a_topic', None, None,
                 with_attributes=True, timestamp_attribute='time'))
    assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
    p.run()
    mock_pubsub.return_value.acknowledge.assert_has_calls([
        mock.call(mock.ANY, [ack_id])])
 def run_pipeline(self, count_implementation, factor=1):
   p = TestPipeline()
   words = p | beam.Create(['CAT', 'DOG', 'CAT', 'CAT', 'DOG'])
   result = words | count_implementation
   assert_that(
       result, equal_to([('CAT', (3 * factor)), ('DOG', (2 * factor))]))
   p.run()
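A hypothetical caller for the helper above, assuming beam.combiners.Count.PerElement() as the count_implementation (it yields exactly ('CAT', 3) and ('DOG', 2) for the hard-coded input, so factor stays at its default of 1):

 def test_count_per_element(self):
   # Hypothetical test method; Count.PerElement counts occurrences of each word.
   self.run_pipeline(beam.combiners.Count.PerElement())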
Example #22
  def test_run_direct(self):
    file_name = self._create_temp_file('aaaa\nbbbb\ncccc\ndddd')
    pipeline = TestPipeline()
    pcoll = pipeline | beam.io.Read(LineSource(file_name))
    assert_that(pcoll, equal_to(['aaaa', 'bbbb', 'cccc', 'dddd']))

    pipeline.run()
  def test_compute_top_sessions(self):
    p = TestPipeline()
    edits = p | beam.Create(self.EDITS)
    result = edits | top_wikipedia_sessions.ComputeTopSessions(1.0)

    assert_that(result, equal_to(self.EXPECTED))
    p.run()
Example #24
File: combiners_test.py  Project: lyft/beam
 def test_global_fanout(self):
   with TestPipeline() as p:
     result = (
         p
         | beam.Create(range(100))
         | beam.CombineGlobally(combine.MeanCombineFn()).with_fanout(11))
     assert_that(result, equal_to([49.5]))
Example #25
  def test_basic_execution_sideinputs(self):
    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)

    main_stream = (p
                   | 'main TestStream' >> TestStream()
                   .advance_watermark_to(10)
                   .add_elements(['e']))
    side_stream = (p
                   | 'side TestStream' >> TestStream()
                   .add_elements([window.TimestampedValue(2, 2)])
                   .add_elements([window.TimestampedValue(1, 1)])
                   .add_elements([window.TimestampedValue(7, 7)])
                   .add_elements([window.TimestampedValue(4, 4)])
                  )

    class RecordFn(beam.DoFn):
      def process(self,
                  elm=beam.DoFn.ElementParam,
                  ts=beam.DoFn.TimestampParam,
                  side=beam.DoFn.SideInputParam):
        yield (elm, ts, side)

    records = (main_stream        # pylint: disable=unused-variable
               | beam.ParDo(RecordFn(), beam.pvalue.AsList(side_stream)))

    assert_that(records, equal_to([('e', Timestamp(10), [2, 1, 7, 4])]))

    p.run()
Example #26
 def test_reified_value_assert_fail_unmatched_window(self):
   expected = [TestWindowedValue(v, MIN_TIMESTAMP, [IntervalWindow(0, 1)])
               for v in [1, 2, 3]]
   with self.assertRaises(Exception):
     with TestPipeline() as p:
       assert_that(p | Create([2, 3, 1]), equal_to(expected),
                   reify_windows=True)
    def check_many_files(output_pcs):
      dest_file_pc = output_pcs[bqfl.WriteRecordsToFile.WRITTEN_FILE_TAG]
      spilled_records_pc = output_pcs[
          bqfl.WriteRecordsToFile.UNWRITTEN_RECORD_TAG]

      spilled_records_count = (spilled_records_pc |
                               beam.combiners.Count.Globally())
      assert_that(spilled_records_count, equal_to([3]), label='spilled count')

      files_per_dest = (dest_file_pc
                        | beam.Map(lambda x: x).with_output_types(
                            beam.typehints.KV[str, str])
                        | beam.combiners.Count.PerKey())
      files_per_dest = (
          files_per_dest
          | "GetDests" >> beam.Map(
              lambda x: (bigquery_tools.get_hashable_destination(x[0]),
                         x[1])))

      # Only table1 and table3 get files. table2 records get spilled.
      assert_that(files_per_dest,
                  equal_to([('project1:dataset1.table1', 1),
                            ('project1:dataset1.table3', 1)]),
                  label='file count')

      # Check that the files exist
      _ = dest_file_pc | beam.Map(lambda x: x[1]) | beam.Map(
          lambda x: hamcrest_assert(os.path.exists(x), is_(True)))
Example #28
  def test_model_composite_triggers(self):
    pipeline_options = PipelineOptions()
    pipeline_options.view_as(StandardOptions).streaming = True

    with TestPipeline(options=pipeline_options) as p:
      test_stream = (TestStream()
                     .advance_watermark_to(10)
                     .add_elements(['a', 'a', 'a', 'b', 'b'])
                     .advance_watermark_to(70)
                     .add_elements([TimestampedValue('a', 10),
                                    TimestampedValue('a', 10),
                                    TimestampedValue('c', 10),
                                    TimestampedValue('c', 10)])
                     .advance_processing_time(600))
      pcollection = (p
                     | test_stream
                     | 'pair_with_one' >> beam.Map(lambda x: (x, 1)))

      counts = (
          # [START model_composite_triggers]
          pcollection | WindowInto(
              FixedWindows(1 * 60),
              trigger=AfterWatermark(
                  late=AfterProcessingTime(10 * 60)),
              accumulation_mode=AccumulationMode.DISCARDING)
          # [END model_composite_triggers]
          | 'group' >> beam.GroupByKey()
          | 'count' >> beam.Map(
              lambda word_ones: (word_ones[0], sum(word_ones[1]))))
      assert_that(counts, equal_to([('a', 3), ('b', 2), ('a', 2), ('c', 2)]))
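The trigger above only adds late firings; a composite trigger can also emit speculative early panes before the watermark passes. A minimal sketch of such a configuration (the window size and delays are arbitrary, not from the original test):

import apache_beam as beam
from apache_beam.transforms.trigger import (
    AccumulationMode, AfterCount, AfterProcessingTime, AfterWatermark)
from apache_beam.transforms.window import FixedWindows

window_into = beam.WindowInto(
    FixedWindows(60),
    trigger=AfterWatermark(
        early=AfterProcessingTime(10),  # speculative pane roughly every 10s
        late=AfterCount(1)),            # one extra pane per late element
    accumulation_mode=AccumulationMode.DISCARDING)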
Example #29
 def test_reified_value_assert_fail_unmatched_timestamp(self):
   expected = [TestWindowedValue(v, 1, [GlobalWindow()])
               for v in [1, 2, 3]]
   with self.assertRaises(Exception):
     with TestPipeline() as p:
       assert_that(p | Create([2, 3, 1]), equal_to(expected),
                   reify_windows=True)
Example #30
  def test_basic_execution(self):
    test_stream = (TestStream()
                   .advance_watermark_to(10)
                   .add_elements(['a', 'b', 'c'])
                   .advance_watermark_to(20)
                   .add_elements(['d'])
                   .add_elements(['e'])
                   .advance_processing_time(10)
                   .advance_watermark_to(300)
                   .add_elements([TimestampedValue('late', 12)])
                   .add_elements([TimestampedValue('last', 310)]))

    class RecordFn(beam.DoFn):
      def process(self, element=beam.DoFn.ElementParam,
                  timestamp=beam.DoFn.TimestampParam):
        yield (element, timestamp)

    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)
    my_record_fn = RecordFn()
    records = p | test_stream | beam.ParDo(my_record_fn)
    assert_that(records, equal_to([
        ('a', timestamp.Timestamp(10)),
        ('b', timestamp.Timestamp(10)),
        ('c', timestamp.Timestamp(10)),
        ('d', timestamp.Timestamp(20)),
        ('e', timestamp.Timestamp(20)),
        ('late', timestamp.Timestamp(12)),
        ('last', timestamp.Timestamp(310)),]))
    p.run()
Example #31
 def test_std_dev_usage(self):
     with TestPipeline() as p:
         output = filter_by_std_dev(p, self.test_path)
         assert_that(output, equal_to(['1.1.1.2', '1.1.1.3', '1.1.1.4']))
Example #32
  def test_stats_pipeline(self):
    # input with three examples.
    examples = [{'a': np.array([1.0, 2.0]),
                 'b': np.array(['a', 'b', 'c', 'e']),
                 'c': np.linspace(1, 500, 500, dtype=np.int32)},
                {'a': np.array([3.0, 4.0, np.NaN, 5.0]),
                 'b': np.array(['a', 'c', 'd', 'a']),
                 'c': np.linspace(501, 1250, 750, dtype=np.int32)},
                {'a': np.array([1.0]),
                 'b': np.array(['a', 'b', 'c', 'd']),
                 'c': np.linspace(1251, 3000, 1750, dtype=np.int32)}]

    expected_result = text_format.Parse("""
    datasets {
      num_examples: 3
      features {
        name: 'a'
        type: FLOAT
        num_stats {
          common_stats {
            num_non_missing: 3
            num_missing: 0
            min_num_values: 1
            max_num_values: 4
            avg_num_values: 2.33333333
            tot_num_values: 7
            num_values_histogram {
              buckets {
                low_value: 1.0
                high_value: 1.0
                sample_count: 1.0
              }
              buckets {
                low_value: 1.0
                high_value: 4.0
                sample_count: 1.0
              }
              buckets {
                low_value: 4.0
                high_value: 4.0
                sample_count: 1.0
              }
              type: QUANTILES
            }
          }
          mean: 2.66666666
          std_dev: 1.49071198
          num_zeros: 0
          min: 1.0
          max: 5.0
          median: 3.0
          histograms {
            num_nan: 1
            buckets {
              low_value: 1.0
              high_value: 2.3333333
              sample_count: 2.9866667
            }
            buckets {
              low_value: 2.3333333
              high_value: 3.6666667
              sample_count: 1.0066667
            }
            buckets {
              low_value: 3.6666667
              high_value: 5.0
              sample_count: 2.0066667
            }
            type: STANDARD
          }
          histograms {
            num_nan: 1
            buckets {
              low_value: 1.0
              high_value: 1.0
              sample_count: 1.5
            }
            buckets {
              low_value: 1.0
              high_value: 3.0
              sample_count: 1.5
            }
            buckets {
              low_value: 3.0
              high_value: 4.0
              sample_count: 1.5
            }
            buckets {
              low_value: 4.0
              high_value: 5.0
              sample_count: 1.5
            }
            type: QUANTILES
          }
        }
      }
      features {
        name: 'c'
        type: INT
        num_stats {
          common_stats {
            num_non_missing: 3
            num_missing: 0
            min_num_values: 500
            max_num_values: 1750
            avg_num_values: 1000.0
            tot_num_values: 3000
            num_values_histogram {
              buckets {
                low_value: 500.0
                high_value: 500.0
                sample_count: 1.0
              }
              buckets {
                low_value: 500.0
                high_value: 1750.0
                sample_count: 1.0
              }
              buckets {
                low_value: 1750.0
                high_value: 1750.0
                sample_count: 1.0
              }
              type: QUANTILES
            }
          }
          mean: 1500.5
          std_dev: 866.025355672
          min: 1.0
          max: 3000.0
          median: 1501.0
          histograms {
            buckets {
              low_value: 1.0
              high_value: 1000.66666667
              sample_count: 999.666666667
            }
            buckets {
              low_value: 1000.66666667
              high_value: 2000.33333333
              sample_count: 999.666666667
            }
            buckets {
              low_value: 2000.33333333
              high_value: 3000.0
              sample_count: 1000.66666667
            }
            type: STANDARD
          }
          histograms {
            buckets {
              low_value: 1.0
              high_value: 751.0
              sample_count: 750.0
            }
            buckets {
              low_value: 751.0
              high_value: 1501.0
              sample_count: 750.0
            }
            buckets {
              low_value: 1501.0
              high_value: 2250.0
              sample_count: 750.0
            }
            buckets {
              low_value: 2250.0
              high_value: 3000.0
              sample_count: 750.0
            }
            type: QUANTILES
          }
        }
      }
      features {
        name: "b"
        type: STRING
        string_stats {
          common_stats {
            num_non_missing: 3
            min_num_values: 4
            max_num_values: 4
            avg_num_values: 4.0
            tot_num_values: 12
            num_values_histogram {
              buckets {
                low_value: 4.0
                high_value: 4.0
                sample_count: 1.0
              }
              buckets {
                low_value: 4.0
                high_value: 4.0
                sample_count: 1.0
              }
              buckets {
                low_value: 4.0
                high_value: 4.0
                sample_count: 1.0
              }
              type: QUANTILES
            }
          }
          unique: 5
          top_values {
            value: "a"
            frequency: 4.0
          }
          top_values {
            value: "c"
            frequency: 3.0
          }
          avg_length: 1.0
          rank_histogram {
            buckets {
              low_rank: 0
              high_rank: 0
              label: "a"
              sample_count: 4.0
            }
            buckets {
              low_rank: 1
              high_rank: 1
              label: "c"
              sample_count: 3.0
            }
            buckets {
              low_rank: 2
              high_rank: 2
              label: "d"
              sample_count: 2.0
            }
          }
        }
      }
    }
    """, statistics_pb2.DatasetFeatureStatisticsList())

    with beam.Pipeline() as p:
      options = stats_options.StatsOptions(
          num_top_values=2,
          num_rank_histogram_buckets=3,
          num_values_histogram_buckets=3,
          num_histogram_buckets=3,
          num_quantiles_histogram_buckets=4,
          epsilon=0.001)
      result = (
          p | beam.Create(examples) | stats_api.GenerateStatistics(options))
      util.assert_that(
          result,
          test_util.make_dataset_feature_stats_list_proto_equal_fn(
              self, expected_result))
Example #33
 def test_assert_that(self):
   # TODO: figure out a way for fn_api_runner to parse and raise the
   # underlying exception.
   with self.assertRaisesRegexp(Exception, 'Failed assert'):
     with self.create_pipeline() as p:
       assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
Example #34
 def test_combine_per_key(self):
   with self.create_pipeline() as p:
     res = (p
            | beam.Create([('a', 1), ('a', 2), ('b', 3)])
            | beam.CombinePerKey(beam.combiners.MeanCombineFn()))
     assert_that(res, equal_to([('a', 1.5), ('b', 3.0)]))
Example #35
 def test_reshuffle(self):
   with self.create_pipeline() as p:
     assert_that(p | beam.Create([1, 2, 3]) | beam.Reshuffle(),
                 equal_to([1, 2, 3]))
Example #36
  def test_pardo_side_and_main_outputs(self):
    def even_odd(elem):
      yield elem
      yield beam.pvalue.TaggedOutput('odd' if elem % 2 else 'even', elem)
    with self.create_pipeline() as p:
      ints = p | beam.Create([1, 2, 3])
      named = ints | 'named' >> beam.FlatMap(
          even_odd).with_outputs('even', 'odd', main='all')
      assert_that(named.all, equal_to([1, 2, 3]), label='named.all')
      assert_that(named.even, equal_to([2]), label='named.even')
      assert_that(named.odd, equal_to([1, 3]), label='named.odd')

      unnamed = ints | 'unnamed' >> beam.FlatMap(even_odd).with_outputs()
      unnamed[None] | beam.Map(id)  # pylint: disable=expression-not-assigned
      assert_that(unnamed[None], equal_to([1, 2, 3]), label='unnamed.all')
      assert_that(unnamed.even, equal_to([2]), label='unnamed.even')
      assert_that(unnamed.odd, equal_to([1, 3]), label='unnamed.odd')
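When the outputs are disjoint and fixed in number, beam.Partition is an alternative to tagged outputs; a minimal sketch using a parity split (illustrative, not part of the original test):

import apache_beam as beam
from apache_beam.testing.util import assert_that, equal_to

with beam.Pipeline() as p:
  ints = p | beam.Create([1, 2, 3])
  # The partition function returns an index in [0, num_parts) per element.
  parts = ints | beam.Partition(lambda n, num_parts: n % 2, 2)
  assert_that(parts[0], equal_to([2]), label='evens')
  assert_that(parts[1], equal_to([1, 3]), label='odds')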
Example #37
 def test_stats_pipeline_with_examples_with_no_values(self):
   examples = [{'a': np.array([], dtype=np.floating),
                'b': np.array([], dtype=np.object),
                'c': np.array([], dtype=np.int32),
                'w': np.array([2])},
               {'a': np.array([], dtype=np.floating),
                'b': np.array([], dtype=np.object),
                'c': np.array([], dtype=np.int32),
                'w': np.array([2])},
               {'a': np.array([], dtype=np.floating),
                'b': np.array([], dtype=np.object),
                'c': np.array([], dtype=np.int32),
                'w': np.array([2])}]
   expected_result = text_format.Parse(
       """
     datasets{
       num_examples: 3
       features {
         name: 'a'
         type: FLOAT
         num_stats {
           common_stats {
             num_non_missing: 3
             num_values_histogram {
               buckets {
                 sample_count: 1.5
               }
               buckets {
                 sample_count: 1.5
               }
               type: QUANTILES
             }
             weighted_common_stats {
               num_non_missing: 6
             }
           }
         }
       }
       features {
         name: 'b'
         type: STRING
         string_stats {
           common_stats {
             num_non_missing: 3
             num_values_histogram {
               buckets {
                 sample_count: 1.5
               }
               buckets {
                 sample_count: 1.5
               }
               type: QUANTILES
             }
             weighted_common_stats {
               num_non_missing: 6
             }
           }
         }
       }
       features {
         name: 'c'
         type: INT
         num_stats {
           common_stats {
             num_non_missing: 3
             num_values_histogram {
               buckets {
                 sample_count: 1.5
               }
               buckets {
                 sample_count: 1.5
               }
               type: QUANTILES
             }
             weighted_common_stats {
               num_non_missing: 6
             }
           }
         }
       }
     }
   """, statistics_pb2.DatasetFeatureStatisticsList())
   with beam.Pipeline() as p:
     options = stats_options.StatsOptions(
         weight_feature='w',
         num_top_values=1,
         num_rank_histogram_buckets=1,
         num_values_histogram_buckets=2,
         num_histogram_buckets=1,
         num_quantiles_histogram_buckets=1,
         epsilon=0.001)
     result = (
         p | beam.Create(examples) | stats_api.GenerateStatistics(options))
     util.assert_that(
         result,
         test_util.make_dataset_feature_stats_list_proto_equal_fn(
             self, expected_result))
Example #38
    def testCalibrationPlot(self):
        computations = calibration_plot.CalibrationPlot(
            num_buckets=10).computations()
        histogram = computations[0]
        plot = computations[1]

        example1 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.2]),
            'example_weights': np.array([1.0])
        }
        example2 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.8]),
            'example_weights': np.array([2.0])
        }
        example3 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.5]),
            'example_weights': np.array([3.0])
        }
        example4 = {
            'labels': np.array([1.0]),
            'predictions': np.array([-0.1]),
            'example_weights': np.array([4.0])
        }
        example5 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.5]),
            'example_weights': np.array([5.0])
        }
        example6 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.8]),
            'example_weights': np.array([6.0])
        }
        example7 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.2]),
            'example_weights': np.array([7.0])
        }
        example8 = {
            'labels': np.array([1.0]),
            'predictions': np.array([1.1]),
            'example_weights': np.array([8.0])
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([
                    example1, example2, example3, example4, example5, example6,
                    example7, example8
                ])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
                |
                'ComputePlot' >> beam.Map(lambda x: (x[0], plot.result(x[1]))))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_plots, 1)
                    key = metric_types.PlotKey(name='calibration_plot')
                    self.assertIn(key, got_plots)
                    got_plot = got_plots[key]
                    self.assertProtoEquals(
                        """
              buckets {
                lower_threshold_inclusive: -inf
                upper_threshold_exclusive: 0.0
                total_weighted_label {
                  value: 4.0
                }
                total_weighted_refined_prediction {
                  value: -0.4
                }
                num_weighted_examples {
                  value: 4.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.0
                upper_threshold_exclusive: 0.1
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.1
                upper_threshold_exclusive: 0.2
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.2
                upper_threshold_exclusive: 0.3
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                  value: 1.6
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.3
                upper_threshold_exclusive: 0.4
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.4
                upper_threshold_exclusive: 0.5
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.5
                upper_threshold_exclusive: 0.6
                total_weighted_label {
                  value: 5.0
                }
                total_weighted_refined_prediction {
                  value: 4.0
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.6
                upper_threshold_exclusive: 0.7
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.7
                upper_threshold_exclusive: 0.8
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.8
                upper_threshold_exclusive: 0.9
                total_weighted_label {
                  value: 8.0
                }
                total_weighted_refined_prediction {
                  value: 6.4
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.9
                upper_threshold_exclusive: 1.0
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 1.0
                upper_threshold_exclusive: inf
                total_weighted_label {
                  value: 8.0
                }
                total_weighted_refined_prediction {
                  value: 8.8
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
          """, got_plot)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
Example #39
 def test_read(self):
     print('name:', __name__)
     with self.create_pipeline() as p:
         lines = p | beam.io.ReadFromText('/etc/profile')
         assert_that(lines, lambda lines: len(lines) > 0)
Example #40
 def test_synthetic_sdf_step_multiplies_output_elements_count(self):
     with beam.Pipeline() as p:
         pcoll = p | beam.Create(list(range(10))) | beam.ParDo(
             synthetic_pipeline.get_synthetic_sdf_step(0, 0, 10))
         assert_that(pcoll | beam.combiners.Count.Globally(),
                     equal_to([100]))
Example #41
 def test_basic_empty_missing(self):
     """Test that the correct empty result is returned for a missing month."""
     with TestPipeline() as p:
         results = self._get_result_for_month(p, 4)
         assert_that(results, equal_to([]))
Example #42
def run(argv=None, assert_results=None):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input_email',
        required=True,
        help='Email database, with each line formatted as "name<TAB>email".')
    parser.add_argument(
        '--input_phone',
        required=True,
        help='Phonebook, with each line formatted as "name<TAB>phone number".')
    parser.add_argument(
        '--input_snailmail',
        required=True,
        help='Address database, with each line formatted as "name<TAB>address".'
    )
    parser.add_argument('--output_tsv',
                        required=True,
                        help='Tab-delimited output file.')
    parser.add_argument('--output_stats',
                        required=True,
                        help='Output file for statistics about the input.')
    known_args, pipeline_args = parser.parse_known_args(argv)
    # We use the save_main_session option because one or more DoFn's in this
    # workflow rely on global context (e.g., a module imported at module level).
    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(SetupOptions).save_main_session = True
    with beam.Pipeline(options=pipeline_options) as p:

        # Helper: read a tab-separated key-value mapping from a text file,
        # escape all quotes/backslashes, and convert it into a PCollection of
        # (key, value) pairs.
        def read_kv_textfile(label, textfile):
            return (p
                    | 'Read: %s' % label >> ReadFromText(textfile)
                    | 'Backslash: %s' % label >>
                    beam.Map(lambda x: re.sub(r'\\', r'\\\\', x))
                    | 'EscapeQuotes: %s' % label >>
                    beam.Map(lambda x: re.sub(r'"', r'\"', x))
                    | 'Split: %s' % label >>
                    beam.Map(lambda x: re.split(r'\t+', x, 1)))

        # Read input databases.
        email = read_kv_textfile('email', known_args.input_email)
        phone = read_kv_textfile('phone', known_args.input_phone)
        snailmail = read_kv_textfile('snailmail', known_args.input_snailmail)

        # Group together all entries under the same name.
        grouped = (email, phone,
                   snailmail) | 'group_by_name' >> beam.CoGroupByKey()

        # Prepare tab-delimited output; something like this:
        # "name"<TAB>"email_1,email_2"<TAB>"phone"<TAB>"first_snailmail_only"
        tsv_lines = grouped | beam.Map(
            lambda kv: '\t'.join([
                '"%s"' % kv[0],                    # name
                '"%s"' % ','.join(kv[1][0]),       # emails
                '"%s"' % ','.join(kv[1][1]),       # phone numbers
                '"%s"' % next(iter(kv[1][2]), '')  # first snailmail only
            ]))

        # Compute some stats about our database of people.
        luddites = grouped | beam.Filter(  # People without email.
            lambda kv: not next(iter(kv[1][0]), None))
        writers = grouped | beam.Filter(  # People without phones.
            lambda kv: not next(iter(kv[1][1]), None))
        nomads = grouped | beam.Filter(  # People without addresses.
            lambda kv: not next(iter(kv[1][2]), None))

        num_luddites = luddites | 'Luddites' >> beam.combiners.Count.Globally()
        num_writers = writers | 'Writers' >> beam.combiners.Count.Globally()
        num_nomads = nomads | 'Nomads' >> beam.combiners.Count.Globally()

        # Write tab-delimited output.
        # pylint: disable=expression-not-assigned
        tsv_lines | 'WriteTsv' >> WriteToText(known_args.output_tsv)

        # TODO(silviuc): Move the assert_results logic to the unit test.
        if assert_results is not None:
            expected_luddites, expected_writers, expected_nomads = assert_results
            assert_that(num_luddites,
                        equal_to([expected_luddites]),
                        label='assert:luddites')
            assert_that(num_writers,
                        equal_to([expected_writers]),
                        label='assert:writers')
            assert_that(num_nomads,
                        equal_to([expected_nomads]),
                        label='assert:nomads')
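A hypothetical driver for run(); every path below is a placeholder rather than a value from the original example:

if __name__ == '__main__':
    run(argv=[
        '--input_email=/tmp/email.txt',  # placeholder paths
        '--input_phone=/tmp/phone.txt',
        '--input_snailmail=/tmp/snailmail.txt',
        '--output_tsv=/tmp/contacts.tsv',
        '--output_stats=/tmp/contacts_stats.txt',
    ])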
Example #43
 def testEqual(self):
     with TestPipeline() as p:
         tokens = p | beam.Create(self.sample_input)
         result = tokens | beam.CombineGlobally(
             utils.CalculateCoefficients(0.5))
         assert_that(result, equal_to([{'en': 1.0, 'fr': 1.0}]))
Example #44
 def test_basic_empty(self):
     """Test that the correct empty result is returned for a simple dataset."""
     with TestPipeline() as p:
         results = self._get_result_for_month(p, 3)
         assert_that(results, equal_to([]))
Example #45
 def testUnsorted(self):
     with TestPipeline() as p:
         tokens = p | 'CreateInput' >> beam.Create(self.sample_input)
         result = tokens | beam.CombineGlobally(utils.SortByCount())
         assert_that(result,
                     equal_to([[('c', 9), ('a', 5), ('d', 4), ('b', 2)]]))
Example #46
  def testEvaluateWithSlicingAndDifferentBatchSizes(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = linear_classifier.simple_linear_classifier(
        None, temp_eval_export_dir)

    for batch_size in [1, 2, 4, 8]:

      with beam.Pipeline() as pipeline:
        example1 = self._makeExample(
            age=3.0, language='english', label=1.0, slice_key='first_slice')
        example2 = self._makeExample(
            age=3.0, language='chinese', label=0.0, slice_key='first_slice')
        example3 = self._makeExample(
            age=4.0, language='english', label=0.0, slice_key='second_slice')
        example4 = self._makeExample(
            age=5.0, language='chinese', label=1.0, slice_key='second_slice')
        example5 = self._makeExample(
            age=5.0, language='chinese', label=1.0, slice_key='second_slice')

        metrics, plots = (
            pipeline
            | beam.Create([
                example1.SerializeToString(),
                example2.SerializeToString(),
                example3.SerializeToString(),
                example4.SerializeToString(),
                example5.SerializeToString(),
            ])
            | evaluate.Evaluate(
                eval_saved_model_path=eval_export_dir,
                add_metrics_callbacks=[_addExampleCountMetricCallback],
                slice_spec=[
                    slicer.SingleSliceSpec(),
                    slicer.SingleSliceSpec(columns=['slice_key'])
                ],
                desired_batch_size=batch_size))

        def check_result(got):
          try:
            self.assertEqual(3, len(got), 'got: %s' % got)
            slices = {}
            for slice_key, value in got:
              slices[slice_key] = value
            overall_slice = ()
            first_slice = (('slice_key', 'first_slice'),)
            second_slice = (('slice_key', 'second_slice'),)
            self.assertItemsEqual(slices.keys(),
                                  [overall_slice, first_slice, second_slice])
            self.assertDictElementsAlmostEqual(
                slices[overall_slice], {
                    'accuracy': 0.4,
                    'label/mean': 0.6,
                    'my_mean_age': 4.0,
                    'my_mean_age_times_label': 2.6,
                    'added_example_count': 5.0
                })
            self.assertDictElementsAlmostEqual(
                slices[first_slice], {
                    'accuracy': 1.0,
                    'label/mean': 0.5,
                    'my_mean_age': 3.0,
                    'my_mean_age_times_label': 1.5,
                    'added_example_count': 2.0
                })
            self.assertDictElementsAlmostEqual(
                slices[second_slice], {
                    'accuracy': 0.0,
                    'label/mean': 2.0 / 3.0,
                    'my_mean_age': 14.0 / 3.0,
                    'my_mean_age_times_label': 10.0 / 3.0,
                    'added_example_count': 3.0
                })

          except AssertionError as err:
            # This function is redefined every iteration, so it will have the
            # right value of batch_size.
            raise util.BeamAssertException('batch_size = %d, error: %s' %
                                           (batch_size, err))  # pylint: disable=cell-var-from-loop

        util.assert_that(metrics, check_result, label='metrics')
        util.assert_that(plots, util.is_empty(), label='plots')
Example #47
 def test_run_direct(self):
     file_name = self._create_temp_file(b'aaaa\nbbbb\ncccc\ndddd')
     with TestPipeline() as pipeline:
         pcoll = pipeline | beam.io.Read(LineSource(file_name))
         assert_that(pcoll, equal_to([b'aaaa', b'bbbb', b'cccc', b'dddd']))
Example #48
 def testLangNotInLangSet(self):
     with TestPipeline() as p:
         tokens = p | beam.Create(self.sample_input)
         result = tokens | beam.ParDo(utils.FilterTokensByLang({'fr'}))
         assert_that(result, equal_to([]))
Example #49
  def test_multiple_destinations_transform(self):
    streaming = self.test_pipeline.options.view_as(StandardOptions).streaming
    if streaming and isinstance(self.test_pipeline.runner, TestDataflowRunner):
      self.skipTest("TestStream is not supported on TestDataflowRunner")

    output_table_1 = '%s%s' % (self.output_table, 1)
    output_table_2 = '%s%s' % (self.output_table, 2)

    full_output_table_1 = '%s:%s' % (self.project, output_table_1)
    full_output_table_2 = '%s:%s' % (self.project, output_table_2)

    schema1 = {
        'fields': [{
            'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'
        }, {
            'name': 'language', 'type': 'STRING', 'mode': 'NULLABLE'
        }]
    }
    schema2 = {
        'fields': [{
            'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'
        }, {
            'name': 'foundation', 'type': 'STRING', 'mode': 'NULLABLE'
        }]
    }

    bad_record = {'language': 1, 'manguage': 2}

    if streaming:
      pipeline_verifiers = [
          PipelineStateMatcher(PipelineState.RUNNING),
          BigqueryFullResultStreamingMatcher(
              project=self.project,
              query="SELECT name, language FROM %s" % output_table_1,
              data=[(d['name'], d['language']) for d in _ELEMENTS
                    if 'language' in d]),
          BigqueryFullResultStreamingMatcher(
              project=self.project,
              query="SELECT name, foundation FROM %s" % output_table_2,
              data=[(d['name'], d['foundation']) for d in _ELEMENTS
                    if 'foundation' in d])
      ]
    else:
      pipeline_verifiers = [
          BigqueryFullResultMatcher(
              project=self.project,
              query="SELECT name, language FROM %s" % output_table_1,
              data=[(d['name'], d['language']) for d in _ELEMENTS
                    if 'language' in d]),
          BigqueryFullResultMatcher(
              project=self.project,
              query="SELECT name, foundation FROM %s" % output_table_2,
              data=[(d['name'], d['foundation']) for d in _ELEMENTS
                    if 'foundation' in d])
      ]

    args = self.test_pipeline.get_full_options_as_args(
        on_success_matcher=hc.all_of(*pipeline_verifiers))

    with beam.Pipeline(argv=args) as p:
      if streaming:
        _SIZE = len(_ELEMENTS)
        test_stream = (
            TestStream().advance_watermark_to(0).add_elements(
                _ELEMENTS[:_SIZE // 2]).advance_watermark_to(100).add_elements(
                    _ELEMENTS[_SIZE // 2:]).advance_watermark_to_infinity())
        input = p | test_stream
      else:
        input = p | beam.Create(_ELEMENTS)

      schema_table_pcv = beam.pvalue.AsDict(
          p | "MakeSchemas" >> beam.Create([(full_output_table_1, schema1),
                                            (full_output_table_2, schema2)]))

      table_record_pcv = beam.pvalue.AsDict(
          p | "MakeTables" >> beam.Create([('table1', full_output_table_1),
                                           ('table2', full_output_table_2)]))

      input2 = p | "Broken record" >> beam.Create([bad_record])

      input = (input, input2) | beam.Flatten()

      r = (
          input
          | "WriteWithMultipleDests" >> beam.io.gcp.bigquery.WriteToBigQuery(
              table=lambda x, tables: (
                  tables['table1'] if 'language' in x else tables['table2']),
              table_side_inputs=(table_record_pcv,),
              schema=lambda dest, table_map: table_map.get(dest, None),
              schema_side_inputs=(schema_table_pcv,),
              insert_retry_strategy=RetryStrategy.RETRY_ON_TRANSIENT_ERROR,
              method='STREAMING_INSERTS'))

      assert_that(
          r[beam.io.gcp.bigquery.BigQueryWriteFn.FAILED_ROWS],
          equal_to([(full_output_table_1, bad_record)]))
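
The FAILED_ROWS output asserted on above is also what a real pipeline would tap for dead-lettering. A minimal sketch (formatting and labels are illustrative) of routing rejected rows to a log instead of asserting on them; it would sit inside the same `with beam.Pipeline(argv=args) as p:` block:

import json
import logging

import apache_beam as beam
from apache_beam.io.gcp.bigquery import BigQueryWriteFn

failed_rows = r[BigQueryWriteFn.FAILED_ROWS]
_ = (
    failed_rows
    | 'FormatFailed' >> beam.Map(
        lambda dest_row: 'destination=%s row=%s' % (
            dest_row[0], json.dumps(dest_row[1])))
    | 'LogFailed' >> beam.Map(logging.warning))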
Example #50
0
 def test_native_source(self):
     with beam.Pipeline(argv=self.args) as p:
         result = (p | 'read' >> beam.io.Read(
             beam.io.BigQuerySource(query=self.query,
                                    use_standard_sql=True)))
         assert_that(result, equal_to(self.get_expected_data()))
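
For comparison, a hedged sketch of the same read using the newer ReadFromBigQuery transform available in recent Beam releases (it may additionally need a gcs_location or temporary dataset, omitted here):

with beam.Pipeline(argv=self.args) as p:
    result = (p | 'read' >> beam.io.ReadFromBigQuery(
        query=self.query, use_standard_sql=True))
    assert_that(result, equal_to(self.get_expected_data()))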
Example #51
0
 def test_timestamp_param_map(self):
     with TestPipeline() as p:
         assert_that(
             p | Create([1, 2])
             | beam.Map(lambda _, t=DoFn.TimestampParam: t),
             equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))
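
Elements produced by Create carry no explicit timestamp, which is why DoFn.TimestampParam reports MIN_TIMESTAMP twice. A companion sketch, assuming the standard TimestampedValue mechanism, that assigns timestamps first and then reads them back:

import apache_beam as beam
from apache_beam import DoFn
from apache_beam.transforms import window
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, equal_to
from apache_beam.utils.timestamp import Timestamp

with TestPipeline() as p:
    stamped = (
        p
        | beam.Create([1, 2])
        | beam.Map(lambda x: window.TimestampedValue(x, x * 10))
        | beam.Map(lambda _, t=DoFn.TimestampParam: t))
    assert_that(stamped, equal_to([Timestamp(10), Timestamp(20)]))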
Example #52
0
 def expand(self, pcoll):
   assert_that(pcoll, self.matcher)
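
This fragment omits its enclosing class. A plausible completion, assuming a small matcher-applying PTransform (the class name and constructor are hypothetical):

import apache_beam as beam
from apache_beam.testing.util import assert_that


class CheckMatches(beam.PTransform):
  """Sketch: applies a testing matcher to the input PCollection."""

  def __init__(self, matcher):
    super().__init__()
    self.matcher = matcher

  def expand(self, pcoll):
    assert_that(pcoll, self.matcher)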
Example #53
0
 def test_fake_read(self):
     pipeline = TestPipeline()
     pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3]))
     assert_that(pcoll, equal_to([1, 2, 3]))
     pipeline.run()
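
FakeSource is defined in the test module; a minimal sketch of a bounded source with the behavior the test relies on (emit a fixed in-memory list), assuming the standard BoundedSource and OffsetRangeTracker APIs:

from apache_beam.io import iobase
from apache_beam.io.range_trackers import OffsetRangeTracker


class FakeSource(iobase.BoundedSource):
    """Sketch: a bounded source that emits a fixed in-memory list."""

    def __init__(self, values):
        self._values = list(values)

    def estimate_size(self):
        return len(self._values)

    def get_range_tracker(self, start_position, stop_position):
        if start_position is None:
            start_position = 0
        if stop_position is None:
            stop_position = len(self._values)
        return OffsetRangeTracker(start_position, stop_position)

    def split(self, desired_bundle_size, start_position=None,
              stop_position=None):
        # Keep everything in a single bundle for simplicity.
        yield iobase.SourceBundle(
            weight=len(self._values), source=self,
            start_position=0, stop_position=len(self._values))

    def read(self, range_tracker):
        for index, value in enumerate(self._values):
            if not range_tracker.try_claim(index):
                return
            yield value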
Example #54
0
 def test_read_from_avro(self):
     path = self._write_data()
     with TestPipeline() as p:
         assert_that(
             p | avroio.ReadFromAvro(path, use_fastavro=self.use_fastavro),
             equal_to(self.RECORDS))
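
self._write_data() and self.RECORDS are fixtures from the Avro test case. A hedged sketch of such a helper using fastavro directly (the schema and records below are illustrative, not the test's fixtures):

import tempfile

from fastavro import parse_schema, writer

SCHEMA = parse_schema({
    'type': 'record',
    'name': 'User',
    'fields': [{'name': 'name', 'type': 'string'}],
})
RECORDS = [{'name': 'alice'}, {'name': 'bob'}]


def write_data():
    with tempfile.NamedTemporaryFile(suffix='.avro', delete=False) as f:
        writer(f, SCHEMA, RECORDS)
        return f.name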
Example #55
0
  def testEvaluateWithPlots(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = (
        fixed_prediction_estimator.simple_fixed_prediction_estimator(
            None, temp_eval_export_dir))
    eval_shared_model = self.createTestEvalSharedModel(
        eval_saved_model_path=eval_export_dir,
        add_metrics_callbacks=[
            post_export_metrics.example_count(),
            post_export_metrics.auc_plots()
        ])
    extractors = [
        predict_extractor.PredictExtractor(eval_shared_model),
        slice_key_extractor.SliceKeyExtractor()
    ]

    with beam.Pipeline() as pipeline:
      example1 = self._makeExample(prediction=0.0, label=1.0)
      example2 = self._makeExample(prediction=0.7, label=0.0)
      example3 = self._makeExample(prediction=0.8, label=1.0)
      example4 = self._makeExample(prediction=1.0, label=1.0)

      (metrics, plots), _ = (
          pipeline
          | 'Create' >> beam.Create([
              example1.SerializeToString(),
              example2.SerializeToString(),
              example3.SerializeToString(),
              example4.SerializeToString()
          ])
          | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
          | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
          | 'ComputeMetricsAndPlots' >> metrics_and_plots_evaluator
          .ComputeMetricsAndPlots(eval_shared_model=eval_shared_model))

      def check_metrics(got):
        try:
          self.assertEqual(1, len(got), 'got: %s' % got)
          (slice_key, value) = got[0]
          self.assertEqual((), slice_key)
          self.assertDictElementsAlmostEqual(
              got_values_dict=value,
              expected_values_dict={
                  metric_keys.EXAMPLE_COUNT: 4.0,
              })
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(metrics, check_metrics, label='metrics')

      def check_plots(got):
        try:
          self.assertEqual(1, len(got), 'got: %s' % got)
          (slice_key, value) = got[0]
          self.assertEqual((), slice_key)
          self.assertDictMatrixRowsAlmostEqual(
              got_values_dict=value,
              expected_values_dict={
                  metric_keys.AUC_PLOTS_MATRICES: [
                      (8001, [2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0])
                  ],
              })
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(plots, check_plots, label='plots')
Example #56
0
 def test_apply_custom_transform(self):
     pipeline = TestPipeline()
     pcoll = pipeline | 'pcoll' >> Create([1, 2, 3])
     result = pcoll | PipelineTest.CustomTransform()
     assert_that(result, equal_to([2, 3, 4]))
     pipeline.run()
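
PipelineTest.CustomTransform is not shown; given the expected output [2, 3, 4] for input [1, 2, 3], a consistent sketch is a PTransform that adds one to each element (assumed, not the test's actual definition):

import apache_beam as beam


class CustomTransform(beam.PTransform):
    """Sketch: increments every element by one."""

    def expand(self, pcoll):
        return pcoll | 'AddOne' >> beam.Map(lambda x: x + 1)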
Example #57
0
    def test_input_output_polymorphism(self):
        one_series = pd.Series([1])
        two_series = pd.Series([2])
        three_series = pd.Series([3])
        proxy = one_series[:0]

        def equal_to_series(expected):
            def check(actual):
                actual = pd.concat(actual)
                if not expected.equals(actual):
                    raise AssertionError('Series not equal: \n%s\n%s\n' %
                                         (expected, actual))

            return check

        with beam.Pipeline() as p:
            one = p | 'One' >> beam.Create([one_series])
            two = p | 'Two' >> beam.Create([two_series])

            assert_that(
                one | 'PcollInPcollOut' >> transforms.DataframeTransform(
                    lambda x: 3 * x, proxy=proxy, yield_elements='pandas'),
                equal_to_series(three_series),
                label='CheckPcollInPcollOut')

            assert_that((one, two)
                        | 'TupleIn' >> transforms.DataframeTransform(
                            lambda x, y: (x + y), (proxy, proxy),
                            yield_elements='pandas'),
                        equal_to_series(three_series),
                        label='CheckTupleIn')

            assert_that(dict(x=one, y=two)
                        | 'DictIn' >> transforms.DataframeTransform(
                            lambda x, y: (x + y),
                            proxy=dict(x=proxy, y=proxy),
                            yield_elements='pandas'),
                        equal_to_series(three_series),
                        label='CheckDictIn')

            double, triple = one | 'TupleOut' >> transforms.DataframeTransform(
                lambda x: (2 * x, 3 * x), proxy, yield_elements='pandas')
            assert_that(double, equal_to_series(two_series), 'CheckTupleOut0')
            assert_that(triple, equal_to_series(three_series),
                        'CheckTupleOut1')

            res = one | 'DictOut' >> transforms.DataframeTransform(
                lambda x: {'res': 3 * x}, proxy, yield_elements='pandas')
            assert_that(res['res'], equal_to_series(three_series),
                        'CheckDictOut')
Example #58
0
 def test_create_singleton_pcollection(self):
     pipeline = TestPipeline()
     pcoll = pipeline | 'label' >> Create([[1, 2, 3]])
     assert_that(pcoll, equal_to([[1, 2, 3]]))
     pipeline.run()
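
The double brackets matter here: Create([[1, 2, 3]]) produces a single element whose value is the list, whereas Create([1, 2, 3]) produces three separate elements. A one-line contrast under the same setup (label name is illustrative):

with TestPipeline() as pipeline:
    pcoll = pipeline | 'elements' >> Create([1, 2, 3])
    assert_that(pcoll, equal_to([1, 2, 3]))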
Example #59
0
 def test_create(self):
   with self.create_pipeline() as p:
     assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
Example #60
0
  def testEvaluateNoSlicingAddPostExportAndCustomMetrics(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = linear_classifier.simple_linear_classifier(
        None, temp_eval_export_dir)
    eval_shared_model = self.createTestEvalSharedModel(
        eval_saved_model_path=eval_export_dir,
        add_metrics_callbacks=[
            _addExampleCountMetricCallback,
            # Note that since everything runs in-process this doesn't
            # actually test that the py_func can be correctly recreated
            # on workers in a distributed context.
            _addPyFuncMetricCallback,
            post_export_metrics.example_count(),
            post_export_metrics.example_weight(example_weight_key='age')
        ])
    extractors = [
        predict_extractor.PredictExtractor(eval_shared_model),
        slice_key_extractor.SliceKeyExtractor()
    ]

    with beam.Pipeline() as pipeline:
      example1 = self._makeExample(age=3.0, language='english', label=1.0)
      example2 = self._makeExample(age=3.0, language='chinese', label=0.0)
      example3 = self._makeExample(age=4.0, language='english', label=1.0)
      example4 = self._makeExample(age=5.0, language='chinese', label=0.0)

      (metrics, plots), _ = (
          pipeline
          | 'Create' >> beam.Create([
              example1.SerializeToString(),
              example2.SerializeToString(),
              example3.SerializeToString(),
              example4.SerializeToString()
          ])
          | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
          | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
          | 'ComputeMetricsAndPlots' >> metrics_and_plots_evaluator
          .ComputeMetricsAndPlots(eval_shared_model=eval_shared_model))

      def check_result(got):
        try:
          self.assertEqual(1, len(got), 'got: %s' % got)
          (slice_key, value) = got[0]
          self.assertEqual((), slice_key)
          self.assertDictElementsAlmostEqual(
              got_values_dict=value,
              expected_values_dict={
                  'accuracy': 1.0,
                  'label/mean': 0.5,
                  'my_mean_age': 3.75,
                  'my_mean_age_times_label': 1.75,
                  'added_example_count': 4.0,
                  'py_func_label_sum': 2.0,
                  metric_keys.EXAMPLE_COUNT: 4.0,
                  metric_keys.EXAMPLE_WEIGHT: 15.0
              })
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(metrics, check_result, label='metrics')
      util.assert_that(plots, util.is_empty(), label='plots')