Example #1
 def test_reshuffle_window_fn_preserved(self):
   pipeline = TestPipeline()
   data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
   expected_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
       ((1, 1), 1.0, IntervalWindow(1.0, 3.0)),
       ((2, 1), 1.0, IntervalWindow(1.0, 3.0)),
       ((3, 1), 1.0, IntervalWindow(1.0, 3.0)),
       ((1, 2), 2.0, IntervalWindow(2.0, 4.0)),
       ((2, 2), 2.0, IntervalWindow(2.0, 4.0)),
       ((1, 4), 4.0, IntervalWindow(4.0, 6.0))]]
   expected_merged_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
       ((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
       ((2, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
       ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
       ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
   before_reshuffle = (pipeline
                       | 'start' >> beam.Create(data)
                       | 'add_timestamp' >> beam.Map(
                           lambda v: TimestampedValue(v, v[1]))
                       | 'window' >> beam.WindowInto(Sessions(gap_size=2)))
   assert_that(before_reshuffle, equal_to(expected_windows),
               label='before_reshuffle', reify_windows=True)
   after_reshuffle = (before_reshuffle
                      | 'reshuffle' >> beam.Reshuffle())
   assert_that(after_reshuffle, equal_to(expected_windows),
               label='after_reshuffle', reify_windows=True)
   after_group = (after_reshuffle
                  | 'group_by_key' >> beam.GroupByKey())
   assert_that(after_group, equal_to(expected_merged_windows),
               label='after_group', reify_windows=True)
   pipeline.run()
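A minimal standalone sketch (an illustration under assumed Beam windowing APIs, not part of the test above) of why the expected_merged_windows come out the way they do: Sessions(gap_size=2) first assigns every timestamp t its own window [t, t + 2), and overlapping session windows are merged during grouping, so [1, 3) and [2, 4) collapse into [1, 4).

from apache_beam.transforms.window import Sessions, WindowFn
from apache_beam.utils.timestamp import Timestamp

session_fn = Sessions(gap_size=2)
# Each element initially gets a two-second window starting at its timestamp.
print(session_fn.assign(WindowFn.AssignContext(Timestamp(1))))  # the window [1, 3)
print(session_fn.assign(WindowFn.AssignContext(Timestamp(2))))  # the window [2, 4)
# Because [1, 3) and [2, 4) overlap, grouping merges them into [1, 4).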
Example #2
  def test_no_window_context_fails(self):
    expected_timestamp = timestamp.Timestamp(5)
    # Assuming the default window function is window.GlobalWindows.
    expected_window = window.GlobalWindow()

    class AddTimestampDoFn(beam.DoFn):
      def process(self, element):
        yield window.TimestampedValue(element, expected_timestamp)

    pipeline = TestPipeline()
    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    expected_windows = [
        TestWindowedValue(kv, expected_timestamp, [expected_window])
        for kv in data]
    before_identity = (pipeline
                       | 'start' >> beam.Create(data)
                       | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
    assert_that(before_identity, equal_to(expected_windows),
                label='before_identity', reify_windows=True)
    after_identity = (before_identity
                      | 'window' >> beam.WindowInto(
                          beam.transforms.util._IdentityWindowFn(
                              coders.GlobalWindowCoder()))
                      # This DoFn will return TimestampedValues, making
                      # WindowFn.AssignContext passed to IdentityWindowFn
                      # contain a window of None. IdentityWindowFn should
                      # raise an exception.
                      | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
    assert_that(after_identity, equal_to(expected_windows),
                label='after_identity', reify_windows=True)
    with self.assertRaisesRegexp(ValueError, r'window.*None.*add_timestamps2'):
      pipeline.run()
Example #3
  def test_multi(self):

    @ptransform.PTransform.register_urn('multi', None)
    class MultiTransform(ptransform.PTransform):
      def expand(self, pcolls):
        return {
            'main':
                (pcolls['main1'], pcolls['main2'])
                | beam.Flatten()
                | beam.Map(lambda x, s: x + s,
                           beam.pvalue.AsSingleton(pcolls['side'])),
            'side': pcolls['side'] | beam.Map(lambda x: x + x),
        }

      def to_runner_api_parameter(self, unused_context):
        return 'multi', None

      @staticmethod
      def from_runner_api_parameter(unused_parameter, unused_context):
        return MultiTransform()

    with beam.Pipeline() as p:
      main1 = p | 'Main1' >> beam.Create(['a', 'bb'], reshuffle=False)
      main2 = p | 'Main2' >> beam.Create(['x', 'yy', 'zzz'], reshuffle=False)
      side = p | 'Side' >> beam.Create(['s'])
      res = dict(main1=main1, main2=main2, side=side) | beam.ExternalTransform(
          'multi', None, expansion_service.ExpansionServiceServicer())
      assert_that(res['main'], equal_to(['as', 'bbs', 'xs', 'yys', 'zzzs']))
      assert_that(res['side'], equal_to(['ss']), label='CheckSide')
Example #4
 def test_timestamped_with_combiners(self):
   with TestPipeline() as p:
     result = (p
               # Create some initial test values.
               | 'start' >> Create([(k, k) for k in range(10)])
               # The purpose of the WindowInto transform is to establish a
               # FixedWindows windowing function for the PCollection.
               # It does not bucket elements into windows since the timestamps
               # from Create are not spaced 5 ms apart and very likely they all
               # fall into the same window.
               | 'w' >> WindowInto(FixedWindows(5))
               # Generate timestamped values using the values as timestamps.
               # Now there are values 5 ms apart and since Map propagates the
               # windowing function from input to output the output PCollection
               # will have elements falling into different 5ms windows.
               | Map(lambda x_t2: TimestampedValue(x_t2[0], x_t2[1]))
               # We add a 'key' to each value representing the index of the
               # window. This is important since there is no guarantee of
               # order for the elements of a PCollection.
                | Map(lambda v: (v // 5, v)))
     # Sum all elements associated with a key and window. Although it
     # is called CombinePerKey it is really CombinePerKeyAndWindow the
     # same way GroupByKey is really GroupByKeyAndWindow.
     sum_per_window = result | CombinePerKey(sum)
     # Compute mean per key and window.
     mean_per_window = result | combiners.Mean.PerKey()
     assert_that(sum_per_window, equal_to([(0, 10), (1, 35)]),
                 label='assert:sum')
     assert_that(mean_per_window, equal_to([(0, 2.0), (1, 7.0)]),
                 label='assert:mean')
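The comments above explain that WindowInto(FixedWindows(5)) only takes effect once the elements carry timestamps spaced far enough apart; a minimal sketch (an illustration under assumed Beam windowing APIs, not taken from the test) of the underlying assignment, which is also why keying by t // 5 groups elements per window:

from apache_beam.transforms.window import FixedWindows, WindowFn
from apache_beam.utils.timestamp import Timestamp

# Timestamp 7 falls into the fixed window [5, 10), i.e. window index 7 // 5 == 1.
print(FixedWindows(5).assign(WindowFn.AssignContext(Timestamp(7))))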
Example #5
    def check_many_files(output_pcs):
      dest_file_pc = output_pcs[bqfl.WriteRecordsToFile.WRITTEN_FILE_TAG]
      spilled_records_pc = output_pcs[
          bqfl.WriteRecordsToFile.UNWRITTEN_RECORD_TAG]

      spilled_records_count = (spilled_records_pc |
                               beam.combiners.Count.Globally())
      assert_that(spilled_records_count, equal_to([3]), label='spilled count')

      files_per_dest = (dest_file_pc
                        | beam.Map(lambda x: x).with_output_types(
                            beam.typehints.KV[str, str])
                        | beam.combiners.Count.PerKey())
      files_per_dest = (
          files_per_dest
          | "GetDests" >> beam.Map(
              lambda x: (bigquery_tools.get_hashable_destination(x[0]),
                         x[1])))

      # Only table1 and table3 get files. table2 records get spilled.
      assert_that(files_per_dest,
                  equal_to([('project1:dataset1.table1', 1),
                            ('project1:dataset1.table3', 1)]),
                  label='file count')

      # Check that the files exist
      _ = dest_file_pc | beam.Map(lambda x: x[1]) | beam.Map(
          lambda x: hamcrest_assert(os.path.exists(x), is_(True)))
Example #6
  def test_window_preserved(self):
    expected_timestamp = timestamp.Timestamp(5)
    expected_window = window.IntervalWindow(1.0, 2.0)

    class AddWindowDoFn(beam.DoFn):
      def process(self, element):
        yield WindowedValue(
            element, expected_timestamp, [expected_window])

    pipeline = TestPipeline()
    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    expected_windows = [
        TestWindowedValue(kv, expected_timestamp, [expected_window])
        for kv in data]
    before_identity = (pipeline
                       | 'start' >> beam.Create(data)
                       | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
    assert_that(before_identity, equal_to(expected_windows),
                label='before_identity', reify_windows=True)
    after_identity = (before_identity
                      | 'window' >> beam.WindowInto(
                          beam.transforms.util._IdentityWindowFn(
                              coders.IntervalWindowCoder())))
    assert_that(after_identity, equal_to(expected_windows),
                label='after_identity', reify_windows=True)
    pipeline.run()
Example #7
  def test_pardo_side_inputs(self):
    def cross_product(elem, sides):
      for side in sides:
        yield elem, side
    with self.create_pipeline() as p:
      main = p | 'main' >> beam.Create(['a', 'b', 'c'])
      side = p | 'side' >> beam.Create(['x', 'y'])
      assert_that(main | beam.FlatMap(cross_product, beam.pvalue.AsList(side)),
                  equal_to([('a', 'x'), ('b', 'x'), ('c', 'x'),
                            ('a', 'y'), ('b', 'y'), ('c', 'y')]))

      # Now with some windowing.
      pcoll = p | beam.Create(range(10)) | beam.Map(
          lambda t: window.TimestampedValue(t, t))
      # Intentionally choosing non-aligned windows to highlight the transition.
      main = pcoll | 'WindowMain' >> beam.WindowInto(window.FixedWindows(5))
      side = pcoll | 'WindowSide' >> beam.WindowInto(window.FixedWindows(7))
      res = main | beam.Map(lambda x, s: (x, sorted(s)),
                            beam.pvalue.AsList(side))
      assert_that(
          res,
          equal_to([
              # The window [0, 5) maps to the window [0, 7).
              (0, list(range(7))),
              (1, list(range(7))),
              (2, list(range(7))),
              (3, list(range(7))),
              (4, list(range(7))),
              # The window [5, 10) maps to the window [7, 14).
              (5, list(range(7, 10))),
              (6, list(range(7, 10))),
              (7, list(range(7, 10))),
              (8, list(range(7, 10))),
              (9, list(range(7, 10)))]),
          label='windowed')
Example #8
  def test_records_traverse_transform_with_mocks(self):
    destination = 'project1:dataset1.table1'

    job_reference = bigquery_api.JobReference()
    job_reference.projectId = 'project1'
    job_reference.jobId = 'job_name1'
    result_job = bigquery_api.Job()
    result_job.jobReference = job_reference

    mock_job = mock.Mock()
    mock_job.status.state = 'DONE'
    mock_job.status.errorResult = None
    mock_job.jobReference = job_reference

    bq_client = mock.Mock()
    bq_client.jobs.Get.return_value = mock_job

    bq_client.jobs.Insert.return_value = result_job

    transform = bqfl.BigQueryBatchFileLoads(
        destination,
        custom_gcs_temp_location=self._new_tempdir(),
        test_client=bq_client,
        validate=False)

    # Need to test this with the DirectRunner to avoid serializing mocks
    with TestPipeline('DirectRunner') as p:
      outputs = p | beam.Create(_ELEMENTS) | transform

      dest_files = outputs[bqfl.BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS]
      dest_job = outputs[bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS]

      jobs = dest_job | "GetJobs" >> beam.Map(lambda x: x[1])

      files = dest_files | "GetFiles" >> beam.Map(lambda x: x[1])
      destinations = (
          dest_files
          | "GetDests" >> beam.Map(
              lambda x: (
                  bigquery_tools.get_hashable_destination(x[0]), x[1]))
          | "GetUniques" >> beam.combiners.Count.PerKey()
          | "GetFinalDests" >> beam.Keys())

      # All files exist
      _ = (files | beam.Map(
          lambda x: hamcrest_assert(os.path.exists(x), is_(True))))

      # One file per destination
      assert_that(files | beam.combiners.Count.Globally(),
                  equal_to([1]),
                  label='CountFiles')

      assert_that(destinations,
                  equal_to([destination]),
                  label='CheckDestinations')

      assert_that(jobs,
                  equal_to([job_reference]), label='CheckJobs')
Example #9
 def test_multi(self):
   with beam.Pipeline() as p:
     main1 = p | 'Main1' >> beam.Create(['a', 'bb'], reshuffle=False)
     main2 = p | 'Main2' >> beam.Create(['x', 'yy', 'zzz'], reshuffle=False)
     side = p | 'Side' >> beam.Create(['s'])
     res = dict(main1=main1, main2=main2, side=side) | beam.ExternalTransform(
         'multi', None, expansion_service.ExpansionServiceServicer())
     assert_that(res['main'], equal_to(['as', 'bbs', 'xs', 'yys', 'zzzs']))
     assert_that(res['side'], equal_to(['ss']), label='CheckSide')
Example #10
  def test_create(self):
    pipeline = TestPipeline()
    pcoll = pipeline | 'label1' >> Create([1, 2, 3])
    assert_that(pcoll, equal_to([1, 2, 3]))

    # Test if initial value is an iterator object.
    pcoll2 = pipeline | 'label2' >> Create(iter((4, 5, 6)))
    pcoll3 = pcoll2 | 'do' >> FlatMap(lambda x: [x + 10])
    assert_that(pcoll3, equal_to([14, 15, 16]), label='pcoll3')
    pipeline.run()
Example #11
 def test_reuse_cloned_custom_transform_instance(self):
   pipeline = TestPipeline()
   pcoll1 = pipeline | 'pc1' >> Create([1, 2, 3])
   pcoll2 = pipeline | 'pc2' >> Create([4, 5, 6])
   transform = PipelineTest.CustomTransform()
   result1 = pcoll1 | transform
   result2 = pcoll2 | 'new_label' >> transform
   assert_that(result1, equal_to([2, 3, 4]), label='r1')
   assert_that(result2, equal_to([5, 6, 7]), label='r2')
   pipeline.run()
Example #12
def assert_reentrant_reads_succeed(source_info):
  """Tests if a given source can be read in a reentrant manner.

  Assume that given source produces the set of values ``{v1, v2, v3, ... vn}``.
  For ``i`` in range ``[1, n-1]`` this method performs a reentrant read after
  reading ``i`` elements and verifies that both the original and reentrant read
  produce the expected set of values.

  Args:
    source_info (Tuple[~apache_beam.io.iobase.BoundedSource, int, int]):
      a three-tuple that gives the reference
      :class:`~apache_beam.io.iobase.BoundedSource`, position to start reading
      at, and a position to stop reading at.

  Raises:
    ~exceptions.ValueError: if source is too trivial or reentrant read result
      in an incorrect read.
  """

  source, start_position, stop_position = source_info
  assert isinstance(source, iobase.BoundedSource)

  expected_values = [val for val in source.read(source.get_range_tracker(
      start_position, stop_position))]
  if len(expected_values) < 2:
    raise ValueError('Source is too trivial since it produces only %d '
                     'values. Please give a source that reads at least 2 '
                     'values.' % len(expected_values))

  for i in range(1, len(expected_values) - 1):
    read_iter = source.read(source.get_range_tracker(
        start_position, stop_position))
    original_read = []
    for _ in range(i):
      original_read.append(next(read_iter))

    # Reentrant read
    reentrant_read = [val for val in source.read(
        source.get_range_tracker(start_position, stop_position))]

    # Continuing original read.
    for val in read_iter:
      original_read.append(val)

    if equal_to(original_read)(expected_values):
      raise ValueError('Source did not produce expected values when '
                       'performing a reentrant read after reading %d values. '
                       'Expected %r received %r.'
                       % (i, expected_values, original_read))

    if equal_to(reentrant_read)(expected_values):
      raise ValueError('A reentrant read of source after reading %d values '
                       'did not produce expected values. Expected %r '
                       'received %r.'
                       % (i, expected_values, reentrant_read))
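A hypothetical usage sketch of the helper above. The toy in-memory BoundedSource below is an assumption made for illustration (it is not defined anywhere in this snippet); the final call re-reads the source after every prefix length and raises ValueError on any mismatch.

from apache_beam.io import iobase
from apache_beam.io.range_trackers import OffsetRangeTracker


class _InMemorySource(iobase.BoundedSource):
  """Toy source emitting the integers [0, n)."""

  def __init__(self, n):
    self._n = n

  def estimate_size(self):
    return self._n

  def get_range_tracker(self, start_position, stop_position):
    start = 0 if start_position is None else start_position
    stop = self._n if stop_position is None else stop_position
    return OffsetRangeTracker(start, stop)

  def read(self, range_tracker):
    pos = range_tracker.start_position()
    while range_tracker.try_claim(pos):
      yield pos
      pos += 1


# Passing None for the positions lets get_range_tracker() cover the full range.
assert_reentrant_reads_succeed((_InMemorySource(5), None, None))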
Example #13
  def test_java_expansion(self):
    if not self.expansion_service_jar:
      raise unittest.SkipTest('No expansion service jar provided.')

    # The actual definitions of these transforms is in
    # org.apache.beam.runners.core.construction.TestExpansionService.
    TEST_COUNT_URN = "pytest:beam:transforms:count"
    TEST_FILTER_URN = "pytest:beam:transforms:filter_less_than"

    # Run as cheaply as possible on the portable runner.
    # TODO(robertwb): Support this directly in the direct runner.
    options = beam.options.pipeline_options.PipelineOptions(
        runner='PortableRunner',
        experiments=['beam_fn_api'],
        environment_type=python_urns.EMBEDDED_PYTHON,
        job_endpoint='embed')

    try:
      # Start the java server and wait for it to be ready.
      port = '8091'
      address = 'localhost:%s' % port
      server = subprocess.Popen(
          ['java', '-jar', self.expansion_service_jar, port])
      with grpc.insecure_channel(address) as channel:
        grpc.channel_ready_future(channel).result()

      # Run a simple count-filtered-letters pipeline.
      with beam.Pipeline(options=options) as p:
        res = (
            p
            | beam.Create(list('aaabccxyyzzz'))
            | beam.Map(unicode)
            # TODO(BEAM-6587): Use strings directly rather than ints.
            | beam.Map(lambda x: int(ord(x)))
            | beam.ExternalTransform(TEST_FILTER_URN, b'middle', address)
            | beam.ExternalTransform(TEST_COUNT_URN, None, address)
            # TODO(BEAM-6587): Remove when above is removed.
            | beam.Map(lambda kv: (chr(kv[0]), kv[1]))
            | beam.Map(lambda kv: '%s: %s' % kv))

        assert_that(res, equal_to(['a: 3', 'b: 1', 'c: 2']))

      # Test GenerateSequence Java transform
      with beam.Pipeline(options=options) as p:
        res = (
            p
            | GenerateSequence(start=1, stop=10,
                               expansion_service=address)
        )

        assert_that(res, equal_to([i for i in range(1, 10)]))

    finally:
      server.kill()
Example #14
 def test_pardo_side_outputs(self):
   def tee(elem, *tags):
     for tag in tags:
       if tag in elem:
         yield beam.pvalue.TaggedOutput(tag, elem)
   with self.create_pipeline() as p:
     xy = (p
           | 'Create' >> beam.Create(['x', 'y', 'xy'])
           | beam.FlatMap(tee, 'x', 'y').with_outputs())
     assert_that(xy.x, equal_to(['x', 'xy']), label='x')
     assert_that(xy.y, equal_to(['y', 'xy']), label='y')
Example #15
    def check_files_created(output_pc):
      files = output_pc | "GetFiles" >> beam.Map(lambda x: x[1])
      file_count = files | "CountFiles" >> beam.combiners.Count.Globally()

      _ = files | "FilesExist" >> beam.Map(
          lambda x: hamcrest_assert(os.path.exists(x), is_(True)))
      assert_that(file_count, equal_to([3]), label='check file count')

      destinations = output_pc | "GetDests" >> beam.Map(lambda x: x[0])
      assert_that(destinations, equal_to(list(_DISTINCT_DESTINATIONS)),
                  label='check destinations ')
Example #16
  def test_combine_globally_with_default_side_input(self):
    class SideInputCombine(PTransform):
      def expand(self, pcoll):
        side = pcoll | CombineGlobally(sum).as_singleton_view()
        main = pcoll.pipeline | Create([None])
        return main | Map(lambda _, s: s, side)

    with TestPipeline() as p:
      result1 = p | 'i1' >> Create([]) | 'c1' >> SideInputCombine()
      result2 = p | 'i2' >> Create([1, 2, 3, 4]) | 'c2' >> SideInputCombine()
      assert_that(result1, equal_to([0]), label='r1')
      assert_that(result2, equal_to([10]), label='r2')
Example #17
  def test_reshuffle_after_gbk_contents_unchanged(self):
    pipeline = TestPipeline()
    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
    expected_result = [(1, [1, 2, 3]), (2, [1, 2]), (3, [1])]

    after_gbk = (pipeline
                 | beam.Create(data)
                 | beam.GroupByKey())
    assert_that(after_gbk, equal_to(expected_result), label='after_gbk')
    after_reshuffle = after_gbk | beam.Reshuffle()
    assert_that(after_reshuffle, equal_to(expected_result),
                label='after_reshuffle')
    pipeline.run()
Example #18
 def test_reshuffle_global_window(self):
   pipeline = TestPipeline()
   data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
   expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
   before_reshuffle = (pipeline
                       | beam.Create(data)
                       | beam.WindowInto(GlobalWindows())
                       | beam.GroupByKey())
   assert_that(before_reshuffle, equal_to(expected_data),
               label='before_reshuffle')
   after_reshuffle = before_reshuffle | beam.Reshuffle()
   assert_that(after_reshuffle, equal_to(expected_data),
               label='after reshuffle')
   pipeline.run()
Example #19
  def test_flatmap_builtin(self):
    pipeline = TestPipeline()
    pcoll = pipeline | 'label1' >> Create([1, 2, 3])
    assert_that(pcoll, equal_to([1, 2, 3]))

    pcoll2 = pcoll | 'do' >> FlatMap(lambda x: [x + 10])
    assert_that(pcoll2, equal_to([11, 12, 13]), label='pcoll2')

    pcoll3 = pcoll2 | 'm1' >> Map(lambda x: [x, 12])
    assert_that(pcoll3,
                equal_to([[11, 12], [12, 12], [13, 12]]), label='pcoll3')

    pcoll4 = pcoll3 | 'do2' >> FlatMap(set)
    assert_that(pcoll4, equal_to([11, 12, 12, 12, 13]), label='pcoll4')
    pipeline.run()
Example #20
 def test_read_from_text_file_pattern(self):
   pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
   assert len(expected_data) == 40
   pipeline = TestPipeline()
   pcoll = pipeline | 'Read' >> ReadFromText(pattern)
   assert_that(pcoll, equal_to(expected_data))
   pipeline.run()
Example #21
 def test_pardo(self):
   with self.create_pipeline() as p:
     res = (p
            | beam.Create(['a', 'bc'])
            | beam.Map(lambda e: e * 2)
            | beam.Map(lambda e: e + 'x'))
     assert_that(res, equal_to(['aax', 'bcbcx']))
Example #22
  def test_read_messages_timestamp_attribute_rfc3339_success(self, mock_pubsub):
    data = 'data'
    message_id = 'message_id'
    attributes = {'time': '2018-03-12T13:37:01.234567Z'}
    publish_time = '2018-03-12T13:37:01.234567Z'
    payloads = [
        create_client_message(data, message_id, attributes, publish_time)]
    expected_elements = [
        TestWindowedValue(
            PubsubMessage(data, attributes),
            timestamp.Timestamp.from_rfc3339(attributes['time']),
            [window.GlobalWindow()]),
    ]

    mock_pubsub.Client = functools.partial(FakePubsubClient, payloads)
    mock_pubsub.subscription.AutoAck = FakeAutoAck

    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    pcoll = (p
             | ReadFromPubSub(
                 'projects/fakeprj/topics/a_topic', None, 'a_label',
                 with_attributes=True, timestamp_attribute='time'))
    assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
    p.run()
Example #23
  def test_setting_timestamp(self):
    with TestPipeline() as p:
      unkeyed_items = p | beam.Create([12, 30, 60, 61, 66])
      items = (unkeyed_items | 'key' >> beam.Map(lambda x: ('k', x)))

      def extract_timestamp_from_log_entry(entry):
        return entry[1]

      # [START setting_timestamp]
      class AddTimestampDoFn(beam.DoFn):

        def process(self, element):
          # Extract the numeric Unix seconds-since-epoch timestamp to be
          # associated with the current log entry.
          unix_timestamp = extract_timestamp_from_log_entry(element)
          # Wrap and emit the current entry and new timestamp in a
          # TimestampedValue.
          yield beam.window.TimestampedValue(element, unix_timestamp)

      timestamped_items = items | 'timestamp' >> beam.ParDo(AddTimestampDoFn())
      # [END setting_timestamp]
      fixed_windowed_items = (
          timestamped_items | 'window' >> beam.WindowInto(
              beam.window.FixedWindows(60)))
      summed = (fixed_windowed_items
                | 'group' >> beam.GroupByKey()
                | 'combine' >> beam.CombineValues(sum))
      unkeyed = summed | 'unkey' >> beam.Map(lambda x: x[1])
      assert_that(unkeyed, equal_to([42, 187]))
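A quick plain-Python sanity check (not part of the test above) that the asserted sums 42 and 187 match the FixedWindows(60) buckets of the input timestamps:

# 12 and 30 fall into [0, 60); 60, 61 and 66 fall into [60, 120).
items = [12, 30, 60, 61, 66]
buckets = {}
for t in items:
  buckets.setdefault(t // 60, []).append(t)
print({window: sum(values) for window, values in buckets.items()})  # {0: 42, 1: 187}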
Example #24
 def test_gbk_side_input(self):
   with self.create_pipeline() as p:
     main = p | 'main' >> beam.Create([None])
     side = p | 'side' >> beam.Create([('a', 1)]) | beam.GroupByKey()
     assert_that(
         main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
         equal_to([(None, {'a': [1]})]))
Example #25
 def test_read_all_from_avro_many_single_files(self):
   path1 = self._write_data()
   path2 = self._write_data()
   path3 = self._write_data()
   with TestPipeline() as p:
     assert_that(p | Create([path1, path2, path3]) | avroio.ReadAllFromAvro(),
                 equal_to(self.RECORDS * 3))
Example #26
 def test_group_by_key(self):
   with self.create_pipeline() as p:
     res = (p
            | beam.Create([('a', 1), ('a', 2), ('b', 3)])
            | beam.GroupByKey()
            | beam.Map(lambda k_vs: (k_vs[0], sorted(k_vs[1]))))
     assert_that(res, equal_to([('a', [1, 2]), ('b', [3])]))
Example #27
  def test_after_count(self):
    with TestPipeline() as p:
      def construct_timestamped(k_t):
        return TimestampedValue((k_t[0], k_t[1]), k_t[1])

      def format_result(k_v):
        return ('%s-%s' % (k_v[0], len(k_v[1])), set(k_v[1]))

      result = (p
                | beam.Create([1, 2, 3, 4, 5, 10, 11])
                | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
                | beam.Map(construct_timestamped)
                | beam.WindowInto(FixedWindows(10), trigger=AfterCount(3),
                                  accumulation_mode=AccumulationMode.DISCARDING)
                | beam.GroupByKey()
                | beam.Map(format_result))
      assert_that(result, equal_to(
          list(
              {
                  'A-5': {1, 2, 3, 4, 5},
                  # A-10, A-11 never emitted due to AfterCount(3) never firing.
                  'B-4': {6, 7, 8, 9},
                  'B-3': {10, 15, 16},
              }.items()
          )))
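A quick plain-Python sanity check (not part of the test above) of which timestamps land in each FixedWindows(10) pane, which is why the comment notes that A-10/A-11 are never emitted: the second A window only ever receives two elements, so AfterCount(3) never fires for it.

elements = [1, 2, 3, 4, 5, 10, 11]
pairs = [('A', t) for t in elements] + [('B', t + 5) for t in elements]
panes = {}
for key, t in pairs:
  panes.setdefault((key, t // 10 * 10), set()).add(t)
for (key, window_start), values in sorted(panes.items()):
  print(key, window_start, sorted(values))
# A 0 [1, 2, 3, 4, 5] -> at least three elements arrive, trigger fires ('A-5')
# A 10 [10, 11]       -> only two elements, AfterCount(3) never fires
# B 0 [6, 7, 8, 9]    -> emitted as 'B-4'
# B 10 [10, 15, 16]   -> emitted as 'B-3'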
Example #28
 def test_pardo_windowed_side_inputs(self):
   with self.create_pipeline() as p:
     # Now with some windowing.
     pcoll = p | beam.Create(list(range(10))) | beam.Map(
         lambda t: window.TimestampedValue(t, t))
     # Intentionally choosing non-aligned windows to highlight the transition.
     main = pcoll | 'WindowMain' >> beam.WindowInto(window.FixedWindows(5))
     side = pcoll | 'WindowSide' >> beam.WindowInto(window.FixedWindows(7))
     res = main | beam.Map(lambda x, s: (x, sorted(s)),
                           beam.pvalue.AsList(side))
     assert_that(
         res,
         equal_to([
             # The window [0, 5) maps to the window [0, 7).
             (0, list(range(7))),
             (1, list(range(7))),
             (2, list(range(7))),
             (3, list(range(7))),
             (4, list(range(7))),
             # The window [5, 10) maps to the window [7, 14).
             (5, list(range(7, 10))),
             (6, list(range(7, 10))),
             (7, list(range(7, 10))),
             (8, list(range(7, 10))),
             (9, list(range(7, 10)))]),
         label='windowed')
Example #29
  def test_read_messages_timestamp_attribute_missing(self, mock_pubsub):
    data = 'data'
    attributes = {}
    publish_time_secs = 1520861821
    publish_time_nanos = 234567000
    publish_time = '2018-03-12T13:37:01.234567Z'
    ack_id = 'ack_id'
    pull_response = test_utils.create_pull_response([
        test_utils.PullResponseMessage(
            data, attributes, publish_time_secs, publish_time_nanos, ack_id)
    ])
    expected_elements = [
        TestWindowedValue(
            PubsubMessage(data, attributes),
            timestamp.Timestamp.from_rfc3339(publish_time),
            [window.GlobalWindow()]),
    ]
    mock_pubsub.return_value.pull.return_value = pull_response

    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    pcoll = (p
             | ReadFromPubSub(
                 'projects/fakeprj/topics/a_topic', None, None,
                 with_attributes=True, timestamp_attribute='nonexistent'))
    assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
    p.run()
    mock_pubsub.return_value.acknowledge.assert_has_calls([
        mock.call(mock.ANY, [ack_id])])
Example #30
 def test_read_from_text_single_file(self):
   file_name, expected_data = write_data(5)
   assert len(expected_data) == 5
   pipeline = TestPipeline()
   pcoll = pipeline | 'Read' >> ReadFromText(file_name)
   assert_that(pcoll, equal_to(expected_data))
   pipeline.run()
Example #31
 def test_combine_per_key(self):
   with self.create_pipeline() as p:
     res = (p
            | beam.Create([('a', 1), ('a', 2), ('b', 3)])
            | beam.CombinePerKey(beam.combiners.MeanCombineFn()))
     assert_that(res, equal_to([('a', 1.5), ('b', 3.0)]))
Example #32
    def test_single_phase_mixed_analyzer_run_once(self):
        span_0_key = 'span-0'
        span_1_key = 'span-1'

        def preprocessing_fn(inputs):

            integerized_s = tft.compute_and_apply_vocabulary(inputs['s'])

            _ = tft.bucketize(inputs['x'], 2, name='bucketize')

            return {
                'integerized_s':
                integerized_s,
                'x_min':
                tft.min(inputs['x'], name='x') + tf.zeros_like(inputs['x']),
                'x_mean':
                tft.mean(inputs['x'], name='x') + tf.zeros_like(inputs['x']),
                'y_min':
                tft.min(inputs['y'], name='y') + tf.zeros_like(inputs['y']),
                'y_mean':
                tft.mean(inputs['y'], name='y') + tf.zeros_like(inputs['y']),
            }

        # Run AnalyzeAndTransform on some input data and compare with expected
        # output.
        input_data = [{'x': 12, 'y': 1, 's': 'd'}, {'x': 10, 'y': 1, 's': 'c'}]
        input_metadata = dataset_metadata.DatasetMetadata(
            schema_utils.schema_from_feature_spec({
                'x':
                tf.io.FixedLenFeature([], tf.float32),
                'y':
                tf.io.FixedLenFeature([], tf.float32),
                's':
                tf.io.FixedLenFeature([], tf.string),
            }))
        input_data_dict = {
            span_0_key: [{
                'x': -2,
                'y': 1,
                's': 'b',
            }, {
                'x': 4,
                'y': -4,
                's': 'b',
            }],
            span_1_key:
            input_data,
        }

        with _TestPipeline() as p:
            flat_data = p | 'CreateInputData' >> beam.Create(
                list(itertools.chain(*input_data_dict.values())))
            cache_dict = {
                span_0_key: {
                    b'__v0__CacheableCombineAccumulate[x_1/mean_and_var]-.\xc4t>ZBv\xea\xa5SU\xf4\x065\xc6\x1c\x81W\xf9\x1b':
                    p | 'CreateA' >> beam.Create([b'[2.0, 1.0, 9.0, 0.0]']),
                    b'__v0__CacheableCombineAccumulate[x/x]-\x95\xc5w\x88\x85\x8b5V\xc9\x00\xe0\x0f\x03\x1a\xdaL\x9d\xd5\xb3\xe3':
                    p | 'CreateB' >> beam.Create([b'[2.0, 4.0]']),
                    b'__v0__CacheableCombineAccumulate[y_1/mean_and_var]-E^\xb7VZ\xeew4rm\xab\xa3\xa4k|J\x80ck\x16':
                    p | 'CreateC' >> beam.Create([b'[2.0, -1.5, 6.25, 0.0]']),
                    b'__v0__CacheableCombineAccumulate[y/y]-\xdf\x1ey\x03\x1c\x96\xd5'
                    b' e\x9bJ\xa1\xd2\xfc\x9c\x03\x0fM \xdb':
                    p | 'CreateD' >> beam.Create([b'[4.0, 1.0]']),
                },
                span_1_key: {},
            }

            transform_fn, cache_output = (
                (flat_data, input_data_dict, cache_dict, input_metadata)
                | 'Analyze' >>
                tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
            _ = (cache_output | 'WriteCache' >>
                 analyzer_cache.WriteAnalysisCacheToFS(p, self._cache_dir))

            transformed_dataset = (
                ((input_data_dict[span_1_key], input_metadata), transform_fn)
                | 'Transform' >> tft_beam.TransformDataset())

            dot_string = nodes.get_dot_graph(
                [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
            self.WriteRenderedDotFile(dot_string)

            # The output cache should not have entries for the cache that is present
            # in the input cache.
            self.assertEqual(len(cache_output[span_0_key]),
                             len(cache_output[span_1_key]) - 4)

            transformed_data, unused_transformed_metadata = transformed_dataset

            expected_transformed = [
                {
                    'x_mean': 6.0,
                    'x_min': -2.0,
                    'y_mean': -0.25,
                    'y_min': -4.0,
                    'integerized_s': 1,
                },
                {
                    'x_mean': 6.0,
                    'x_min': -2.0,
                    'y_mean': -0.25,
                    'y_min': -4.0,
                    'integerized_s': 2,
                },
            ]
            beam_test_util.assert_that(
                transformed_data,
                beam_test_util.equal_to(expected_transformed))

            transform_fn_dir = os.path.join(self.base_test_dir, 'transform_fn')
            _ = transform_fn | tft_beam.WriteTransformFn(transform_fn_dir)

        # 4 from analyzing 2 spans, and 2 from transform.
        self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 6)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_decoded'), 4)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_encoded'), 8)
        self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'),
                         2)
Example #33
    def testClassifyIncarcerationEvents(self):
        """Tests the ClassifyIncarcerationEvents DoFn."""
        fake_person_id = 12345

        fake_person = StatePerson.new_with_defaults(
            person_id=fake_person_id,
            gender=Gender.MALE,
            birthdate=date(1970, 1, 1),
            residency_status=ResidencyStatus.PERMANENT)

        incarceration_period = StateIncarcerationPeriod.new_with_defaults(
            incarceration_period_id=1111,
            incarceration_type=StateIncarcerationType.STATE_PRISON,
            status=StateIncarcerationPeriodStatus.NOT_IN_CUSTODY,
            state_code='TX',
            facility='PRISON XX',
            admission_date=date(2010, 11, 20),
            admission_reason=StateIncarcerationPeriodAdmissionReason.
            PROBATION_REVOCATION,
            release_date=date(2010, 11, 21),
            release_reason=StateIncarcerationPeriodReleaseReason.
            SENTENCE_SERVED)

        incarceration_sentence = StateIncarcerationSentence.new_with_defaults(
            incarceration_sentence_id=123,
            incarceration_periods=[incarceration_period],
            start_date=date(2009, 2, 9),
            charges=[
                StateCharge.new_with_defaults(ncic_code='5699',
                                              statute='30A123',
                                              offense_date=date(2009, 1, 9))
            ])

        sentence_group = StateSentenceGroup.new_with_defaults(
            sentence_group_id=123,
            incarceration_sentences=[incarceration_sentence])

        incarceration_sentence.sentence_group = sentence_group

        incarceration_period.incarceration_sentences = [incarceration_sentence]

        person_entities = {
            'person': [fake_person],
            'sentence_groups': [sentence_group]
        }

        fake_person_id_to_county_query_result = [{
            'person_id':
            fake_person_id,
            'county_of_residence':
            _COUNTY_OF_RESIDENCE
        }]

        incarceration_events = [
            IncarcerationStayEvent(
                admission_reason=incarceration_period.admission_reason,
                admission_reason_raw_text=incarceration_period.
                admission_reason_raw_text,
                supervision_type_at_admission=
                StateSupervisionPeriodSupervisionType.PROBATION,
                state_code=incarceration_period.state_code,
                event_date=incarceration_period.admission_date,
                facility=incarceration_period.facility,
                county_of_residence=_COUNTY_OF_RESIDENCE,
                most_serious_offense_ncic_code='5699',
                most_serious_offense_statute='30A123'),
            IncarcerationAdmissionEvent(
                state_code=incarceration_period.state_code,
                event_date=incarceration_period.admission_date,
                facility=incarceration_period.facility,
                county_of_residence=_COUNTY_OF_RESIDENCE,
                admission_reason=incarceration_period.admission_reason,
                admission_reason_raw_text=incarceration_period.
                admission_reason_raw_text,
                supervision_type_at_admission=
                StateSupervisionPeriodSupervisionType.PROBATION,
            ),
            IncarcerationReleaseEvent(
                state_code=incarceration_period.state_code,
                event_date=incarceration_period.release_date,
                facility=incarceration_period.facility,
                county_of_residence=_COUNTY_OF_RESIDENCE,
                release_reason=incarceration_period.release_reason)
        ]

        correct_output = [(fake_person, incarceration_events)]

        test_pipeline = TestPipeline()

        person_id_to_county_kv = (
            test_pipeline
            | "Read person id to county associations from BigQuery" >>
            beam.Create(fake_person_id_to_county_query_result)
            |
            "Convert to KV" >> beam.ParDo(ConvertDictToKVTuple(), 'person_id'))

        output = (test_pipeline
                  | beam.Create([(fake_person_id, person_entities)])
                  | 'Identify Incarceration Events' >> beam.ParDo(
                      pipeline.ClassifyIncarcerationEvents(),
                      AsDict(person_id_to_county_kv)))

        assert_that(output, equal_to(correct_output))

        test_pipeline.run()
Example #34
 def test_project(self):
     with TestPipeline() as p:
         out = (p | beam.Create([SimpleRow(1, "foo", 3.14)])
                | SqlTransform("SELECT `id`, `flt` FROM PCOLLECTION"))
         assert_that(out, equal_to([(1, 3.14)]))
Example #35
  def test_compute_top_sessions(self):
    with TestPipeline() as p:
      edits = p | beam.Create(self.EDITS)
      result = edits | top_wikipedia_sessions.ComputeTopSessions(1.0)

      assert_that(result, equal_to(self.EXPECTED))
Example #36
 def test_log_distribution(self):
   with TestPipeline() as p:
     data = [int(math.log(x)) for x in range(1, 1000)]
     pc = p | Create(data)
     quantiles = pc | beam.ApproximateQuantiles.Globally(5)
     assert_that(quantiles, equal_to([[0, 5, 6, 6, 6]]))
Example #37
    def test_multiple_destinations_transform(self):
        output_table_1 = '%s%s' % (self.output_table, 1)
        output_table_2 = '%s%s' % (self.output_table, 2)
        schema1 = {
            'fields': [{
                'name': 'name',
                'type': 'STRING',
                'mode': 'NULLABLE'
            }, {
                'name': 'language',
                'type': 'STRING',
                'mode': 'NULLABLE'
            }]
        }
        schema2 = {
            'fields': [{
                'name': 'name',
                'type': 'STRING',
                'mode': 'NULLABLE'
            }, {
                'name': 'foundation',
                'type': 'STRING',
                'mode': 'NULLABLE'
            }]
        }

        bad_record = {'language': 1, 'manguage': 2}

        pipeline_verifiers = [
            BigqueryFullResultMatcher(
                project=self.project,
                query="SELECT * FROM %s" % output_table_1,
                data=[(d['name'], d['language']) for d in _ELEMENTS
                      if 'language' in d]),
            BigqueryFullResultMatcher(
                project=self.project,
                query="SELECT * FROM %s" % output_table_2,
                data=[(d['name'], d['foundation']) for d in _ELEMENTS
                      if 'foundation' in d])
        ]

        args = self.test_pipeline.get_full_options_as_args(
            on_success_matcher=hc.all_of(*pipeline_verifiers))

        with beam.Pipeline(argv=args) as p:
            input = p | beam.Create(_ELEMENTS)

            input2 = p | "Broken record" >> beam.Create([bad_record])

            input = (input, input2) | beam.Flatten()

            r = (input
                 | "WriteWithMultipleDests" >>
                 beam.io.gcp.bigquery.WriteToBigQuery(
                     table=lambda x:
                     ((output_table_1, schema1) if 'language' in x else
                      (output_table_2, schema2)),
                     method='STREAMING_INSERTS'))

            assert_that(r[beam.io.gcp.bigquery.BigQueryWriteFn.FAILED_ROWS],
                        equal_to([(output_table_1, bad_record)]))
Example #38
 def test_read_from_avro(self):
     path = self._write_data()
     with TestPipeline() as p:
         assert_that(
             p | avroio.ReadFromAvro(path, use_fastavro=self.use_fastavro),
             equal_to(self.RECORDS))
Example #39
 def test_reshuffle(self):
   with self.create_pipeline() as p:
     assert_that(p | beam.Create([1, 2, 3]) | beam.Reshuffle(),
                 equal_to([1, 2, 3]))
Example #40
 def test_create(self):
   with self.create_pipeline() as p:
     assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
Example #41
 def test_assert_that(self):
   # TODO: figure out a way for fn_api_runner to parse and raise the
   # underlying exception.
   with self.assertRaisesRegexp(Exception, 'Failed assert'):
     with self.create_pipeline() as p:
       assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
Example #42
 def test_nested(self):
     with beam.Pipeline() as p:
         assert_that(p | FibTransform(6), equal_to([8]))
Example #43
    def test_multiple_partition_files(self):
        destination = 'project1:dataset1.table1'

        job_reference = bigquery_api.JobReference()
        job_reference.projectId = 'project1'
        job_reference.jobId = 'job_name1'
        result_job = mock.Mock()
        result_job.jobReference = job_reference

        mock_job = mock.Mock()
        mock_job.status.state = 'DONE'
        mock_job.status.errorResult = None
        mock_job.jobReference = job_reference

        bq_client = mock.Mock()
        bq_client.jobs.Get.return_value = mock_job

        bq_client.jobs.Insert.return_value = result_job
        bq_client.tables.Delete.return_value = None

        with TestPipeline('DirectRunner') as p:
            outputs = (p
                       | beam.Create(_ELEMENTS, reshuffle=False)
                       | bqfl.BigQueryBatchFileLoads(
                           destination,
                           custom_gcs_temp_location=self._new_tempdir(),
                           test_client=bq_client,
                           validate=False,
                           temp_file_format=bigquery_tools.FileFormat.JSON,
                           max_file_size=45,
                           max_partition_size=80,
                           max_files_per_partition=2))

            dest_files = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS]
            dest_load_jobs = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS]
            dest_copy_jobs = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_COPY_JOBID_PAIRS]

            load_jobs = dest_load_jobs | "GetLoadJobs" >> beam.Map(
                lambda x: x[1])
            copy_jobs = dest_copy_jobs | "GetCopyJobs" >> beam.Map(
                lambda x: x[1])

            files = dest_files | "GetFiles" >> beam.Map(lambda x: x[1][0])
            destinations = (
                dest_files
                | "GetDests" >>
                beam.Map(lambda x:
                         (bigquery_tools.get_hashable_destination(x[0]), x[1]))
                | "GetUniques" >> combiners.Count.PerKey()
                | "GetFinalDests" >> beam.Keys())

            # All files exist
            _ = (files
                 | beam.Map(
                     lambda x: hamcrest_assert(os.path.exists(x), is_(True))))

            # One file per destination
            assert_that(files | "CountFiles" >> combiners.Count.Globally(),
                        equal_to([6]),
                        label='CheckFileCount')

            assert_that(destinations,
                        equal_to([destination]),
                        label='CheckDestinations')

            assert_that(load_jobs
                        | "CountLoadJobs" >> combiners.Count.Globally(),
                        equal_to([6]),
                        label='CheckLoadJobCount')
            assert_that(copy_jobs
                        | "CountCopyJobs" >> combiners.Count.Globally(),
                        equal_to([6]),
                        label='CheckCopyJobCount')
Example #44
    def test_records_traverse_transform_with_mocks(self):
        destination = 'project1:dataset1.table1'

        job_reference = bigquery_api.JobReference()
        job_reference.projectId = 'project1'
        job_reference.jobId = 'job_name1'
        result_job = bigquery_api.Job()
        result_job.jobReference = job_reference

        mock_job = mock.Mock()
        mock_job.status.state = 'DONE'
        mock_job.status.errorResult = None
        mock_job.jobReference = job_reference

        bq_client = mock.Mock()
        bq_client.jobs.Get.return_value = mock_job

        bq_client.jobs.Insert.return_value = result_job

        transform = bqfl.BigQueryBatchFileLoads(
            destination,
            custom_gcs_temp_location=self._new_tempdir(),
            test_client=bq_client,
            validate=False,
            temp_file_format=bigquery_tools.FileFormat.JSON)

        # Need to test this with the DirectRunner to avoid serializing mocks
        with TestPipeline('DirectRunner') as p:
            outputs = p | beam.Create(_ELEMENTS) | transform

            dest_files = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS]
            dest_job = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS]

            jobs = dest_job | "GetJobs" >> beam.Map(lambda x: x[1])

            files = dest_files | "GetFiles" >> beam.Map(lambda x: x[1][0])
            destinations = (
                dest_files
                | "GetDests" >>
                beam.Map(lambda x:
                         (bigquery_tools.get_hashable_destination(x[0]), x[1]))
                | "GetUniques" >> combiners.Count.PerKey()
                | "GetFinalDests" >> beam.Keys())

            # All files exist
            _ = (files
                 | beam.Map(
                     lambda x: hamcrest_assert(os.path.exists(x), is_(True))))

            # One file per destination
            assert_that(files | combiners.Count.Globally(),
                        equal_to([1]),
                        label='CountFiles')

            assert_that(destinations,
                        equal_to([destination]),
                        label='CheckDestinations')

            assert_that(jobs, equal_to([job_reference]), label='CheckJobs')
Example #45
 def test_singleton(self):
   with TestPipeline() as p:
     data = [389]
     pc = p | Create(data)
     quantiles = pc | beam.ApproximateQuantiles.Globally(5)
     assert_that(quantiles, equal_to([[389, 389, 389, 389, 389]]))
Example #46
 def test_tostring_kvs_delimeter(self):
     with TestPipeline() as p:
         result = (p | beam.Create([("one", 1),
                                    ("two", 2)]) | util.ToString.Kvs("\t"))
         assert_that(result, equal_to(["one\t1", "two\t2"]))
Example #47
def run(argv=None, assert_results=None, save_main_session=True):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input_email',
        required=True,
        help='Email database, with each line formatted as "name<TAB>email".')
    parser.add_argument(
        '--input_phone',
        required=True,
        help='Phonebook, with each line formatted as "name<TAB>phone number".')
    parser.add_argument(
        '--input_snailmail',
        required=True,
        help='Address database, with each line formatted as "name<TAB>address".'
    )
    parser.add_argument('--output_tsv',
                        required=True,
                        help='Tab-delimited output file.')
    parser.add_argument('--output_stats',
                        required=True,
                        help='Output file for statistics about the input.')
    known_args, pipeline_args = parser.parse_known_args(argv)
    # We use the save_main_session option because one or more DoFn's in this
    # workflow rely on global context (e.g., a module imported at module level).
    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(
        SetupOptions).save_main_session = save_main_session
    with beam.Pipeline(options=pipeline_options) as p:

        # Helper: read a tab-separated key-value mapping from a text file,
        # escape all quotes/backslashes, and convert it a PCollection of
        # (key, value) pairs.
        def read_kv_textfile(label, textfile):
            return (p
                    | 'Read: %s' % label >> ReadFromText(textfile)
                    | 'Backslash: %s' % label >>
                    beam.Map(lambda x: re.sub(r'\\', r'\\\\', x))
                    | 'EscapeQuotes: %s' % label >>
                    beam.Map(lambda x: re.sub(r'"', r'\"', x))
                    | 'Split: %s' % label >>
                    beam.Map(lambda x: re.split(r'\t+', x, 1)))

        # Read input databases.
        email = read_kv_textfile('email', known_args.input_email)
        phone = read_kv_textfile('phone', known_args.input_phone)
        snailmail = read_kv_textfile('snailmail', known_args.input_snailmail)

        # Group together all entries under the same name.
        grouped = (email, phone,
                   snailmail) | 'group_by_name' >> beam.CoGroupByKey()

        # Prepare tab-delimited output; something like this:
        # "name"<TAB>"email_1,email_2"<TAB>"phone"<TAB>"first_snailmail_only"
        def format_as_tsv(name_email_phone_snailmail):
            (name, (email, phone, snailmail)) = name_email_phone_snailmail
            return '\t'.join([
                '"%s"' % name,
                '"%s"' % ','.join(email),
                '"%s"' % ','.join(phone),
                '"%s"' % next(iter(snailmail), '')
            ])

        tsv_lines = grouped | beam.Map(format_as_tsv)

        # Compute some stats about our database of people.
        def without_email(name_email_phone_snailmail):
            (_, (email, _, _)) = name_email_phone_snailmail
            return not next(iter(email), None)

        def without_phones(name_email_phone_snailmail):
            (_, (_, phone, _)) = name_email_phone_snailmail
            return not next(iter(phone), None)

        def without_address(name_email_phone_snailmail):
            (_, (_, _, snailmail)) = name_email_phone_snailmail
            return not next(iter(snailmail), None)

        luddites = grouped | beam.Filter(
            without_email)  # People without email.
        writers = grouped | beam.Filter(
            without_phones)  # People without phones.
        nomads = grouped | beam.Filter(
            without_address)  # People without addresses.

        num_luddites = luddites | 'Luddites' >> beam.combiners.Count.Globally()
        num_writers = writers | 'Writers' >> beam.combiners.Count.Globally()
        num_nomads = nomads | 'Nomads' >> beam.combiners.Count.Globally()

        # Write tab-delimited output.
        # pylint: disable=expression-not-assigned
        tsv_lines | 'WriteTsv' >> WriteToText(known_args.output_tsv)

        # TODO(silviuc): Move the assert_results logic to the unit test.
        if assert_results is not None:
            expected_luddites, expected_writers, expected_nomads = assert_results
            assert_that(num_luddites,
                        equal_to([expected_luddites]),
                        label='assert:luddites')
            assert_that(num_writers,
                        equal_to([expected_writers]),
                        label='assert:writers')
            assert_that(num_nomads,
                        equal_to([expected_nomads]),
                        label='assert:nomads')
Example #48
  def test_multiple_destinations_transform(self):
    streaming = self.test_pipeline.options.view_as(StandardOptions).streaming
    if streaming and isinstance(self.test_pipeline.runner, TestDataflowRunner):
      self.skipTest("TestStream is not supported on TestDataflowRunner")

    output_table_1 = '%s%s' % (self.output_table, 1)
    output_table_2 = '%s%s' % (self.output_table, 2)

    full_output_table_1 = '%s:%s' % (self.project, output_table_1)
    full_output_table_2 = '%s:%s' % (self.project, output_table_2)

    schema1 = {'fields': [
        {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'},
        {'name': 'language', 'type': 'STRING', 'mode': 'NULLABLE'}]}
    schema2 = {'fields': [
        {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'},
        {'name': 'foundation', 'type': 'STRING', 'mode': 'NULLABLE'}]}

    bad_record = {'language': 1, 'manguage': 2}

    if streaming:
      pipeline_verifiers = [
          PipelineStateMatcher(PipelineState.RUNNING),
          BigqueryFullResultStreamingMatcher(
              project=self.project,
              query="SELECT name, language FROM %s" % output_table_1,
              data=[(d['name'], d['language'])
                    for d in _ELEMENTS
                    if 'language' in d]),
          BigqueryFullResultStreamingMatcher(
              project=self.project,
              query="SELECT name, foundation FROM %s" % output_table_2,
              data=[(d['name'], d['foundation'])
                    for d in _ELEMENTS
                    if 'foundation' in d])]
    else:
      pipeline_verifiers = [
          BigqueryFullResultMatcher(
              project=self.project,
              query="SELECT name, language FROM %s" % output_table_1,
              data=[(d['name'], d['language'])
                    for d in _ELEMENTS
                    if 'language' in d]),
          BigqueryFullResultMatcher(
              project=self.project,
              query="SELECT name, foundation FROM %s" % output_table_2,
              data=[(d['name'], d['foundation'])
                    for d in _ELEMENTS
                    if 'foundation' in d])]

    args = self.test_pipeline.get_full_options_as_args(
        on_success_matcher=hc.all_of(*pipeline_verifiers),
        experiments='use_beam_bq_sink')

    with beam.Pipeline(argv=args) as p:
      if streaming:
        _SIZE = len(_ELEMENTS)
        test_stream = (TestStream()
                       .advance_watermark_to(0)
                       .add_elements(_ELEMENTS[:_SIZE//2])
                       .advance_watermark_to(100)
                       .add_elements(_ELEMENTS[_SIZE//2:])
                       .advance_watermark_to_infinity())
        input = p | test_stream
      else:
        input = p | beam.Create(_ELEMENTS)

      schema_table_pcv = beam.pvalue.AsDict(
          p | "MakeSchemas" >> beam.Create([(full_output_table_1, schema1),
                                            (full_output_table_2, schema2)]))

      table_record_pcv = beam.pvalue.AsDict(
          p | "MakeTables" >> beam.Create([('table1', full_output_table_1),
                                           ('table2', full_output_table_2)]))

      input2 = p | "Broken record" >> beam.Create([bad_record])

      input = (input, input2) | beam.Flatten()

      r = (input
           | "WriteWithMultipleDests" >> beam.io.gcp.bigquery.WriteToBigQuery(
               table=lambda x, tables: (tables['table1']
                                        if 'language' in x
                                        else tables['table2']),
               table_side_inputs=(table_record_pcv,),
               schema=lambda dest, table_map: table_map.get(dest, None),
               schema_side_inputs=(schema_table_pcv,),
               method='STREAMING_INSERTS'))

      assert_that(r[beam.io.gcp.bigquery.BigQueryWriteFn.FAILED_ROWS],
                  equal_to([(full_output_table_1, bad_record)]))
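The FAILED_ROWS output checked above is an ordinary PCollection of (destination table, row) pairs, so it can also be routed to a dead-letter sink. A minimal hedged sketch, not part of the test itself, assuming `r` is the WriteToBigQuery result from the example and using a hypothetical output path:

      # Sketch only: dead-letter the failed rows as formatted text lines.
      failed_rows = r[beam.io.gcp.bigquery.BigQueryWriteFn.FAILED_ROWS]
      _ = (failed_rows
           | 'FormatFailedRow' >> beam.Map(lambda kv: '%s: %r' % (kv[0], kv[1]))
           | 'WriteDeadLetter' >> beam.io.WriteToText('/tmp/failed_rows'))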
Example #49
0
 def test_tostring_kvs(self):
     with TestPipeline() as p:
         result = (p | beam.Create([("one", 1),
                                    ("two", 2)]) | util.ToString.Kvs())
         assert_that(result, equal_to(["one,1", "two,2"]))
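ToString.Kvs joins the key and value of each pair with a comma by default. Like ToString.Iterables in Example #55 further below, it also appears to accept a delimiter argument; a short hedged sketch under that assumption, reusing the imports from the test above:

with TestPipeline() as p:
    # Delimiter argument assumed to mirror ToString.Iterables("\t") shown later.
    result = (p
              | beam.Create([("one", 1), ("two", 2)])
              | util.ToString.Kvs("\t"))
    assert_that(result, equal_to(["one\t1", "two\t2"]))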
Example #50
0
    def test_streaming_complex_timing(self):
        # Use state on the TestCase class, since other references would be pickled
        # into a closure and not have the desired side effects.
        #
        # TODO(BEAM-5295): Use assert_that after it works for the cases here in
        # streaming mode.
        WriteFilesTest.all_records = []

        dir = '%s%s' % (self._new_tempdir(), os.sep)

        # Setting up the input (TestStream)
        ts = TestStream().advance_watermark_to(0)
        for elm in WriteFilesTest.LARGER_COLLECTION:
            timestamp = int(elm)

            ts.add_elements([('key', '%s' % elm)])
            if timestamp % 5 == 0 and timestamp != 0:
                # TODO(BEAM-3759): Add many firings per window after getting PaneInfo.
                ts.advance_processing_time(5)
                ts.advance_watermark_to(timestamp)
        ts.advance_watermark_to_infinity()

        def no_colon_file_naming(*args):
            file_name = fileio.destination_prefix_naming()(*args)
            return file_name.replace(':', '_')

        # The pipeline that we are testing
        options = PipelineOptions()
        options.view_as(StandardOptions).streaming = True
        with TestPipeline(options=options) as p:
            res = (p
                   | ts
                   | beam.WindowInto(
                       FixedWindows(10),
                       trigger=trigger.AfterWatermark(),
                       accumulation_mode=trigger.AccumulationMode.DISCARDING)
                   | beam.GroupByKey()
                   | beam.FlatMap(lambda x: x[1]))
            # Triggering after 5 processing-time seconds, and on the watermark. Also
            # discarding old elements.

            _ = (res
                 | beam.io.fileio.WriteToFiles(
                     path=dir,
                     file_naming=no_colon_file_naming,
                     max_writers_per_bundle=0)
                 | beam.Map(lambda fr: FileSystems.join(dir, fr.file_name))
                 | beam.ParDo(self.record_dofn()))

        # Verification pipeline
        with TestPipeline() as p:
            files = (p | beam.io.fileio.MatchFiles(FileSystems.join(dir, '*')))

            file_names = (files | beam.Map(lambda fm: fm.path))

            file_contents = (
                files
                | beam.io.fileio.ReadMatches()
                | beam.Map(lambda rf: (rf.metadata.path,
                                       rf.read_utf8().strip().split('\n'))))

            content = (file_contents
                       | beam.FlatMap(lambda fc: [ln.strip() for ln in fc[1]]))

            assert_that(file_names,
                        equal_to(WriteFilesTest.all_records),
                        label='AssertFilesMatch')
            assert_that(content,
                        matches_all(WriteFilesTest.LARGER_COLLECTION),
                        label='AssertContentsMatch')
Example #51
0
 def test_tostring_iterables(self):
     with TestPipeline() as p:
         result = (p | beam.Create([("one", "two", "three"),
                                    ("four", "five", "six")])
                   | util.ToString.Iterables())
         assert_that(result, equal_to(["one,two,three", "four,five,six"]))
Example #52
0
 def test_assert_that(self):
     # We still want to make sure asserts fail, even if the message
     # isn't right (BEAM-6019).
     with self.assertRaises(Exception):
         with self.create_pipeline() as p:
             assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
Example #53
0
    def test_non_frequency_vocabulary_merge(self):
        """This test compares vocabularies produced with and without cache."""

        mi_vocab_name = 'mutual_information_vocab'
        adjusted_mi_vocab_name = 'adjusted_mutual_information_vocab'
        weighted_frequency_vocab_name = 'weighted_frequency_vocab'

        def preprocessing_fn(inputs):
            _ = tft.vocabulary(inputs['s'],
                               labels=inputs['label'],
                               store_frequency=True,
                               vocab_filename=mi_vocab_name,
                               min_diff_from_avg=0.1,
                               use_adjusted_mutual_info=False)

            _ = tft.vocabulary(inputs['s'],
                               labels=inputs['label'],
                               store_frequency=True,
                               vocab_filename=adjusted_mi_vocab_name,
                               min_diff_from_avg=1.0,
                               use_adjusted_mutual_info=True)

            _ = tft.vocabulary(inputs['s'],
                               weights=inputs['weight'],
                               store_frequency=True,
                               vocab_filename=weighted_frequency_vocab_name,
                               use_adjusted_mutual_info=False)
            return inputs

        span_0_key = 'span-0'
        span_1_key = 'span-1'

        input_data = [
            dict(s='a', weight=1, label=1),
            dict(s='a', weight=0.5, label=1),
            dict(s='b', weight=0.75, label=1),
            dict(s='b', weight=1, label=0),
        ]
        input_metadata = dataset_metadata.DatasetMetadata(
            schema_utils.schema_from_feature_spec({
                's':
                tf.io.FixedLenFeature([], tf.string),
                'label':
                tf.io.FixedLenFeature([], tf.int64),
                'weight':
                tf.io.FixedLenFeature([], tf.float32),
            }))
        input_data_dict = {
            span_0_key: input_data,
            span_1_key: input_data,
        }

        with _TestPipeline() as p:
            flat_data = p | 'CreateInputData' >> beam.Create(
                list(itertools.chain(*input_data_dict.values())))

            # wrap each value in input_data_dict as a pcoll.
            input_data_pcoll_dict = {}
            for a, b in six.iteritems(input_data_dict):
                input_data_pcoll_dict[a] = p | a >> beam.Create(b)

            transform_fn_with_cache, output_cache = (
                (flat_data, input_data_pcoll_dict, {}, input_metadata)
                | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
            transform_fn_with_cache_dir = os.path.join(
                self.base_test_dir, 'transform_fn_with_cache')
            _ = transform_fn_with_cache | tft_beam.WriteTransformFn(
                transform_fn_with_cache_dir)

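            # The byte-string keys below are versioned ('__v0__'), content-hashed
            # identifiers for each vocabulary accumulator, so they are expected to
            # stay stable across runs of the same preprocessing_fn.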
            expected_accumulators = {
                b'__v0__VocabularyAccumulate[vocabulary]-<GhZ\xac\xb8\xa9\x8c\xce\x1c\xb2-ck\xca\xe8\xec\t%\x8f':
                [
                    b'["a", [2, [0.0, 1.0], [0.0, 0.0], 1.0]]',
                    b'["b", [2, [0.5, 0.5], [0.0, 0.0], 1.0]]',
                    b'["global_y_count_sentinel", [4, [0.25, 0.75], [0.0, 0.0], '
                    b'1.0]]'
                ],
                b'__v0__VocabularyAccumulate[vocabulary_1]-\xa6\xae\nd\xe3\xd1\x9f\xa0\xe2\xb4\x05j\xa5\xfd\x8c\xfaeN\xd1\x1f':
                [
                    b'["a", [2, [0.0, 1.0], [0.0, 0.0], 1.0]]',
                    b'["b", [2, [0.5, 0.5], [0.0, 0.0], 1.0]]',
                    b'["global_y_count_sentinel", [4, [0.25, 0.75], [0.0, 0.0], '
                    b'1.0]]'
                ],
                b"__v0__VocabularyAccumulate[vocabulary_2]-\x97\x1c>\x851\x94'\xdc\xdf\xfd\xcc\x86\xb7\xb8\xe1\xe8*\x89B\t":
                [b'["a", 1.5]', b'["b", 1.75]'],
            }
            spans = [span_0_key, span_1_key]
            self.assertCountEqual(output_cache.keys(), spans)
            for span in spans:
                self.assertCountEqual(output_cache[span].keys(),
                                      expected_accumulators.keys())
                for idx, (key, value) in enumerate(
                        six.iteritems(expected_accumulators)):
                    beam_test_util.assert_that(
                        output_cache[span][key],
                        beam_test_util.equal_to(value),
                        label='AssertCache[{}][{}]'.format(span, idx))

        # 4 from analysis on each of the input spans.
        self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_decoded'), 0)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_encoded'), 6)
        self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'),
                         2)

        with _TestPipeline() as p:
            flat_data = p | 'CreateInputData' >> beam.Create(input_data * 2)

            transform_fn_no_cache = (
                (flat_data, input_metadata)
                | tft_beam.AnalyzeDataset(preprocessing_fn))

            transform_fn_no_cache_dir = os.path.join(self.base_test_dir,
                                                     'transform_fn_no_cache')
            _ = transform_fn_no_cache | tft_beam.WriteTransformFn(
                transform_fn_no_cache_dir)

        # 4 from analysis on each of the input spans.
        self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_decoded'), 0)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_encoded'), 0)
        self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'),
                         2)

        tft_output_cache = tft.TFTransformOutput(transform_fn_with_cache_dir)
        tft_output_no_cache = tft.TFTransformOutput(transform_fn_no_cache_dir)

        for vocab_filename in (mi_vocab_name, adjusted_mi_vocab_name,
                               weighted_frequency_vocab_name):
            cache_path = tft_output_cache.vocabulary_file_by_name(
                vocab_filename)
            no_cache_path = tft_output_no_cache.vocabulary_file_by_name(
                vocab_filename)
            with tf.io.gfile.GFile(cache_path, 'rb') as f1, tf.io.gfile.GFile(
                    no_cache_path, 'rb') as f2:
                self.assertEqual(
                    f1.readlines(), f2.readlines(),
                    'vocab with cache != vocab without cache for: {}'.format(
                        vocab_filename))
Example #54
0
 def test_assert_that(self):
     with self.assertRaisesRegexp(BeamAssertException, 'bad_assert'):
         with self.create_pipeline() as p:
             assert_that(p | beam.Create(['a', 'b']), equal_to(['a']),
                         'bad_assert')
Example #55
0
 def test_tostring_iterables_with_delimeter(self):
     with TestPipeline() as p:
         data = [("one", "two", "three"), ("four", "five", "six")]
         result = (p | beam.Create(data) | util.ToString.Iterables("\t"))
         assert_that(result,
                     equal_to(["one\ttwo\tthree", "four\tfive\tsix"]))
Example #56
0
    def test_caching_vocab_for_integer_categorical(self):

        span_0_key = 'span-0'
        span_1_key = 'span-1'

        def preprocessing_fn(inputs):
            return {
                'x_vocab':
                tft.compute_and_apply_vocabulary(inputs['x'],
                                                 frequency_threshold=2)
            }

        input_metadata = dataset_metadata.DatasetMetadata(
            schema_utils.schema_from_feature_spec({
                'x':
                tf.FixedLenFeature([], tf.int64),
            }))
        input_data_dict = {
            span_0_key: [{
                'x': -2,
            }, {
                'x': -4,
            }, {
                'x': -1,
            }, {
                'x': 4,
            }],
            span_1_key: [{
                'x': -2,
            }, {
                'x': -1,
            }, {
                'x': 6,
            }, {
                'x': 7,
            }],
        }
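        # With frequency_threshold=2 only -2 and -1 (each seen twice across the
        # two spans) enter the vocabulary; 6 and 7 fall below the threshold and
        # map to the out-of-vocabulary default of -1.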
        expected_transformed_data = [{
            'x_vocab': 0,
        }, {
            'x_vocab': 1,
        }, {
            'x_vocab': -1,
        }, {
            'x_vocab': -1,
        }]
        with _TestPipeline() as p:
            flat_data = p | 'CreateInputData' >> beam.Create(
                list(itertools.chain(*input_data_dict.values())))

            cache_dict = {
                span_0_key: {
                    b'__v0__VocabularyAccumulate[compute_and_apply_vocabulary/vocabulary]-\x05e\xfe4\x03H.P\xb5\xcb\xd22\xe3\x16\x15\xf8\xf5\xe38\xd9':
                    p | 'CreateB' >> beam.Create(
                        [b'[-2, 2]', b'[-4, 1]', b'[-1, 1]', b'[4, 1]']),
                },
                span_1_key: {},
            }
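            # span-0's accumulator comes from cache_dict, so only span-1 is analyzed
            # from raw data; the assertion below verifies that no new cache entry is
            # produced for span-0.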

            transform_fn, cache_output = (
                (flat_data, input_data_dict, cache_dict, input_metadata)
                | 'Analyze' >>
                tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))

            dot_string = nodes.get_dot_graph(
                [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
            self.WriteRenderedDotFile(dot_string)

            self.assertNotIn(span_0_key, cache_output)

            _ = cache_output | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
                p, self._cache_dir)

            transformed_dataset = (
                ((input_data_dict[span_1_key], input_metadata), transform_fn)
                | 'Transform' >> tft_beam.TransformDataset())

            transformed_data, _ = transformed_dataset

            beam_test_util.assert_that(
                transformed_data,
                beam_test_util.equal_to(expected_transformed_data),
                label='first')

        # 4 from analysis since 1 span was completely cached, and 4 from transform.
        self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_decoded'), 1)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_encoded'), 1)
        self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'),
                         2)
Example #57
0
    def test_tostring_elements(self):

        with TestPipeline() as p:
            result = (p | beam.Create([1, 1, 2, 3]) | util.ToString.Element())
            assert_that(result, equal_to(["1", "1", "2", "3"]))
Example #58
0
    def test_single_phase_run_twice(self):

        span_0_key = 'span-0'
        span_1_key = 'span-1'

        def preprocessing_fn(inputs):

            _ = tft.vocabulary(inputs['s'], vocab_filename='vocab1')

            _ = tft.bucketize(inputs['x'], 2, name='bucketize')

            return {
                'x_min':
                tft.min(inputs['x'], name='x') + tf.zeros_like(inputs['x']),
                'x_mean':
                tft.mean(inputs['x'], name='x') + tf.zeros_like(inputs['x']),
                'y_min':
                tft.min(inputs['y'], name='y') + tf.zeros_like(inputs['y']),
                'y_mean':
                tft.mean(inputs['y'], name='y') + tf.zeros_like(inputs['y']),
                's_integerized':
                tft.compute_and_apply_vocabulary(
                    inputs['s'],
                    labels=inputs['label'],
                    use_adjusted_mutual_info=True),
            }

        input_metadata = dataset_metadata.DatasetMetadata(
            schema_utils.schema_from_feature_spec({
                'x':
                tf.io.FixedLenFeature([], tf.float32),
                'y':
                tf.io.FixedLenFeature([], tf.float32),
                's':
                tf.io.FixedLenFeature([], tf.string),
                'label':
                tf.io.FixedLenFeature([], tf.int64),
            }))
        input_data_dict = {
            span_0_key: [{
                'x': -2,
                'y': 1,
                's': 'a',
                'label': 0,
            }, {
                'x': 4,
                'y': -4,
                's': 'a',
                'label': 1,
            }, {
                'x': 5,
                'y': 11,
                's': 'a',
                'label': 1,
            }, {
                'x': 1,
                'y': -4,
                's': u'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'),
                'label': 1,
            }],
            span_1_key: [{
                'x': 12,
                'y': 1,
                's': u'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'),
                'label': 0
            }, {
                'x': 10,
                'y': 1,
                's': 'c',
                'label': 1
            }],
        }
        expected_vocabulary_contents = np.array(
            [b'a', u'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'), b'c'], dtype=object)
        with _TestPipeline() as p:
            flat_data = p | 'CreateInputData' >> beam.Create(
                list(itertools.chain(*input_data_dict.values())))

            # wrap each value in input_data_dict as a pcoll.
            input_data_pcoll_dict = {}
            for a, b in six.iteritems(input_data_dict):
                input_data_pcoll_dict[a] = p | a >> beam.Create(b)

            transform_fn_1, cache_output = (
                (flat_data, input_data_pcoll_dict, {}, input_metadata)
                | 'Analyze' >>
                tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
            _ = (cache_output
                 | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
                     p, self._cache_dir))

            transformed_dataset = (
                ((input_data_pcoll_dict[span_1_key], input_metadata),
                 transform_fn_1)
                | 'Transform' >> tft_beam.TransformDataset())

            del input_data_pcoll_dict
            transformed_data, unused_transformed_metadata = transformed_dataset

            expected_transformed_data = [
                {
                    'x_mean': 5.0,
                    'x_min': -2.0,
                    'y_mean': 1.0,
                    'y_min': -4.0,
                    's_integerized': 0,
                },
                {
                    'x_mean': 5.0,
                    'x_min': -2.0,
                    'y_mean': 1.0,
                    'y_min': -4.0,
                    's_integerized': 2,
                },
            ]
            beam_test_util.assert_that(
                transformed_data,
                beam_test_util.equal_to(expected_transformed_data),
                label='first')

            transform_fn_dir = os.path.join(self.base_test_dir,
                                            'transform_fn_1')
            _ = transform_fn_1 | tft_beam.WriteTransformFn(transform_fn_dir)

            for key in input_data_dict:
                self.assertIn(key, cache_output)
                self.assertEqual(7, len(cache_output[key]))

        tf_transform_output = tft.TFTransformOutput(transform_fn_dir)
        vocab1_path = tf_transform_output.vocabulary_file_by_name('vocab1')
        self.AssertVocabularyContents(vocab1_path,
                                      expected_vocabulary_contents)

        # 4 from analyzing 2 spans, and 2 from transform.
        self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_decoded'), 0)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_encoded'), 14)
        self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'),
                         2)

        with _TestPipeline() as p:
            flat_data = p | 'CreateInputData' >> beam.Create(
                list(itertools.chain(*input_data_dict.values())))

            # wrap each value in input_data_dict as a pcoll.
            input_data_pcoll_dict = {}
            for a, b in six.iteritems(input_data_dict):
                input_data_pcoll_dict[a] = p | a >> beam.Create(b)

            input_cache = p | analyzer_cache.ReadAnalysisCacheFromFS(
                self._cache_dir, list(input_data_dict.keys()))

            transform_fn_2, second_output_cache = (
                (flat_data, input_data_pcoll_dict, input_cache, input_metadata)
                | 'AnalyzeAgain' >>
                (tft_beam.AnalyzeDatasetWithCache(preprocessing_fn)))
            _ = (second_output_cache
                 | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
                     p, self._cache_dir))

            dot_string = nodes.get_dot_graph(
                [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
            self.WriteRenderedDotFile(dot_string)

            transformed_dataset = (
                ((input_data_dict[span_1_key], input_metadata), transform_fn_2)
                | 'TransformAgain' >> tft_beam.TransformDataset())
            transformed_data, unused_transformed_metadata = transformed_dataset
            beam_test_util.assert_that(
                transformed_data,
                beam_test_util.equal_to(expected_transformed_data),
                label='second')

            transform_fn_dir = os.path.join(self.base_test_dir,
                                            'transform_fn_2')
            _ = transform_fn_2 | tft_beam.WriteTransformFn(transform_fn_dir)

        tf_transform_output = tft.TFTransformOutput(transform_fn_dir)
        vocab1_path = tf_transform_output.vocabulary_file_by_name('vocab1')
        self.AssertVocabularyContents(vocab1_path,
                                      expected_vocabulary_contents)

        self.assertFalse(second_output_cache)

        # Only 2 from transform.
        self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 2)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_decoded'), 14)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_encoded'), 0)

        # The root CreateSavedModel is optimized away because the data doesn't get
        # processed at all (only cache).
        self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'),
                         1)
Example #59
0
 def test_flatten(self):
   with self.create_pipeline() as p:
     res = (p | 'a' >> beam.Create(['a']),
            p | 'bc' >> beam.Create(['b', 'c']),
            p | 'd' >> beam.Create(['d'])) | beam.Flatten()
     assert_that(res, equal_to(['a', 'b', 'c', 'd']))
Example #60
0
 def _run_source_test(self, pattern, expected_data, splittable=True):
     pipeline = TestPipeline()
     pcoll = pipeline | 'Read' >> beam.io.Read(
         LineSource(pattern, splittable=splittable))
     assert_that(pcoll, equal_to(expected_data))
     pipeline.run()
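A hedged sketch of how this helper might be driven; the temporary files, names, and expected lines below are hypothetical and not taken from the original test module:

 def test_read_two_files(self):
     # Hypothetical caller: write two small files, then read them back through
     # LineSource via the _run_source_test helper above.
     import os
     import tempfile
     tmp_dir = tempfile.mkdtemp()
     for name, lines in [('a.txt', ['aaa', 'bbb']), ('b.txt', ['ccc'])]:
         with open(os.path.join(tmp_dir, name), 'w') as f:
             f.write('\n'.join(lines))
     self._run_source_test(os.path.join(tmp_dir, '*.txt'), ['aaa', 'bbb', 'ccc'])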