def test_roundtrip_proto_multi(self):
        test_stream = (TestStream()
                       .advance_processing_time(1)
                       .advance_watermark_to(2, tag='a')
                       .advance_watermark_to(3, tag='b')
                       .add_elements([1, 2, 3], tag='a')
                       .add_elements([4, 5, 6], tag='b')) # yapf: disable

        options = StandardOptions(streaming=True)
        options.view_as(DebugOptions).add_experiment(
            'passthrough_pcollection_output_ids')

        p = TestPipeline(options=options)
        p | test_stream

        pipeline_proto, context = p.to_runner_api(return_context=True)

        for t in pipeline_proto.components.transforms.values():
            if t.spec.urn == common_urns.primitives.TEST_STREAM.urn:
                test_stream_proto = t

        self.assertTrue(test_stream_proto)
        roundtrip_test_stream = TestStream().from_runner_api(
            test_stream_proto, context)

        self.assertListEqual(test_stream._events,
                             roundtrip_test_stream._events)
        self.assertSetEqual(test_stream.output_tags,
                            roundtrip_test_stream.output_tags)
        self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
Example #2
def validate_options(args=None, option_classes=None):
    """Parses pipeline options, adding runner-appropriate option groups."""
    args = args or sys.argv
    # `flatten` is a helper defined elsewhere in this module; it flattens a
    # possibly nested list of option classes.
    option_classes = flatten(option_classes)

    help_flags = ['-h', '--help']
    help_requested = any(arg in help_flags for arg in args)

    # First check whether we are using the DirectRunner or the DataflowRunner;
    # strip out any help flags so that we don't exit too early.
    nohelp_args = [arg for arg in args if arg not in help_flags]
    # Parse args just for StandardOptions and see which runner we are using.
    local = StandardOptions(nohelp_args).runner in (None, 'DirectRunner')

    # make a new parser
    parser = argparse.ArgumentParser()

    # add args for all the options classes that we are using
    for opt in option_classes:
        opt._add_argparse_args(parser)
    StandardOptions._add_argparse_args(
        parser.add_argument_group('Dataflow Runner'))

    if help_requested or not local:
        GoogleCloudOptions._add_argparse_args(
            parser.add_argument_group('Dataflow Runtime'))
        WorkerOptions._add_argparse_args(
            parser.add_argument_group('Dataflow Workers'))
        SetupOptions._add_argparse_args(
            parser.add_argument_group('Dataflow Setup'))

    # parse all args and trigger help if any required args are missing
    parser.parse_known_args(args)

    return PipelineOptions(args)
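A minimal usage sketch (hedged: `MyJobOptions` and its `--input_topic` flag are
illustrative stand-ins, not part of the original module):

from apache_beam.options.pipeline_options import PipelineOptions

class MyJobOptions(PipelineOptions):
    # _add_argparse_args is the standard hook that PipelineOptions
    # subclasses implement to register their own flags.
    @classmethod
    def _add_argparse_args(cls, parser):
        parser.add_argument('--input_topic', required=True)

# Locally this validates only the custom and standard options; with --help
# or a non-local runner, the Dataflow option groups are added as well.
options = validate_options(
    args=['--input_topic', 'projects/p/topics/t'],
    option_classes=[MyJobOptions])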
Example #3
    def test_in_streaming_mode(self):
        timestamp_interval = 1
        offset = itertools.count(0)
        start_time = timestamp.Timestamp(0)
        window_duration = 6
        test_stream = (
            TestStream()
            .advance_watermark_to(start_time)
            .add_elements([
                TimestampedValue(x, next(offset) * timestamp_interval)
                for x in GroupIntoBatchesTest._create_test_data()
            ])
            .advance_watermark_to(start_time + (window_duration - 1))
            .advance_watermark_to(start_time + (window_duration + 1))
            .advance_watermark_to(start_time + GroupIntoBatchesTest.NUM_ELEMENTS)
            .advance_watermark_to_infinity())  # yapf: disable
        pipeline = TestPipeline(options=StandardOptions(streaming=True))
        # Window duration is 6 and batch size is 5, so output batch size should
        # be 5 (flush because the batch size was reached).
        expected_0 = 5
        # There is only one element left in the window, so batch size should be
        # 1 (flush because the end of the window was reached).
        expected_1 = 1
        # The collection has 10 elements; only 4 are left, so batch size should
        # be 4 (flush because the end of the collection was reached).
        expected_2 = 4

        collection = pipeline | test_stream \
                     | WindowInto(FixedWindows(window_duration)) \
                     | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
        num_elements_in_batches = collection | beam.Map(len)

        result = pipeline.run()
        result.wait_until_finish()
        assert_that(num_elements_in_batches,
                    equal_to([expected_0, expected_1, expected_2]))
Example #4
  def test_multiple_outputs(self):
    """Tests that the TestStream supports emitting to multiple PCollections."""
    letters_elements = [
        TimestampedValue('a', 6),
        TimestampedValue('b', 7),
        TimestampedValue('c', 8),
    ]
    numbers_elements = [
        TimestampedValue('1', 11),
        TimestampedValue('2', 12),
        TimestampedValue('3', 13),
    ]
    test_stream = (TestStream()
        .advance_watermark_to(5, tag='letters')
        .add_elements(letters_elements, tag='letters')
        .advance_watermark_to(10, tag='numbers')
        .add_elements(numbers_elements, tag='numbers'))  # yapf: disable

    class RecordFn(beam.DoFn):
      def process(
          self,
          element=beam.DoFn.ElementParam,
          timestamp=beam.DoFn.TimestampParam):
        yield (element, timestamp)

    options = StandardOptions(streaming=True)
    options.view_as(DebugOptions).add_experiment(
        'passthrough_pcollection_output_ids')
    p = TestPipeline(options=options)

    main = p | test_stream
    letters = main['letters'] | 'record letters' >> beam.ParDo(RecordFn())
    numbers = main['numbers'] | 'record numbers' >> beam.ParDo(RecordFn())

    assert_that(
        letters,
        equal_to([('a', Timestamp(6)), ('b', Timestamp(7)),
                  ('c', Timestamp(8))]),
        label='assert letters')

    assert_that(
        numbers,
        equal_to([('1', Timestamp(11)), ('2', Timestamp(12)),
                  ('3', Timestamp(13))]),
        label='assert numbers')

    p.run()
Example #5
  def test_basic_execution_with_service(self):
    """Tests that the TestStream can correctly read from an RPC service.
    """
    coder = beam.coders.FastPrimitivesCoder()

    test_stream_events = (TestStream(coder=coder)
        .advance_watermark_to(10000)
        .add_elements(['a', 'b', 'c'])
        .advance_watermark_to(20000)
        .add_elements(['d'])
        .add_elements(['e'])
        .advance_processing_time(10)
        .advance_watermark_to(300000)
        .add_elements([TimestampedValue('late', 12000)])
        .add_elements([TimestampedValue('last', 310000)])
        .advance_watermark_to_infinity())._events  # yapf: disable

    test_stream_proto_events = [
        e.to_runner_api(coder) for e in test_stream_events
    ]

    class InMemoryEventReader:
      def read_multiple(self, unused_keys):
        for e in test_stream_proto_events:
          yield e

    service = TestStreamServiceController(reader=InMemoryEventReader())
    service.start()

    test_stream = TestStream(coder=coder, endpoint=service.endpoint)

    class RecordFn(beam.DoFn):
      def process(
          self,
          element=beam.DoFn.ElementParam,
          timestamp=beam.DoFn.TimestampParam):
        yield (element, timestamp)

    options = StandardOptions(streaming=True)

    p = TestPipeline(options=options)
    my_record_fn = RecordFn()
    records = p | test_stream | beam.ParDo(my_record_fn)

    assert_that(
        records,
        equal_to([
            ('a', timestamp.Timestamp(10)),
            ('b', timestamp.Timestamp(10)),
            ('c', timestamp.Timestamp(10)),
            ('d', timestamp.Timestamp(20)),
            ('e', timestamp.Timestamp(20)),
            ('late', timestamp.Timestamp(12)),
            ('last', timestamp.Timestamp(310)),
        ]))

    p.run()
Example #6
    def test_multi_triggered_gbk_side_input(self):
        """Test a GBK sideinput, with multiple triggering."""
        options = StandardOptions(streaming=True)
        p = TestPipeline(options=options)

        test_stream = (
            p
            | 'Mixed TestStream' >> TestStream()
              .advance_watermark_to(3, tag='main')
              .add_elements(['a1'], tag='main')
              .advance_watermark_to(8, tag='main')
              .add_elements(['a2'], tag='main')
              .add_elements([window.TimestampedValue(('k', 100), 2)], tag='side')
              .add_elements([window.TimestampedValue(('k', 400), 7)], tag='side')
              .advance_watermark_to_infinity(tag='main')
              .advance_watermark_to_infinity(tag='side'))  # yapf: disable

        main_data = (
            test_stream['main']
            | 'Main windowInto' >> beam.WindowInto(
                window.FixedWindows(5),
                accumulation_mode=trigger.AccumulationMode.DISCARDING))

        side_data = (
            test_stream['side']
            | 'Side windowInto' >> beam.WindowInto(
                window.FixedWindows(5),
                trigger=trigger.AfterWatermark(early=trigger.AfterCount(1)),
                accumulation_mode=trigger.AccumulationMode.DISCARDING)
            | beam.CombinePerKey(sum)
            | 'Values' >> Map(lambda k_vs: k_vs[1]))

        class RecordFn(beam.DoFn):
            def process(self,
                        elm=beam.DoFn.ElementParam,
                        ts=beam.DoFn.TimestampParam,
                        side=beam.DoFn.SideInputParam):
                yield (elm, ts, side)

        records = (main_data
                   | beam.ParDo(RecordFn(), beam.pvalue.AsList(side_data)))

        expected_window_to_elements = {
            window.IntervalWindow(0, 5): [
                ('a1', Timestamp(3), [100, 0]),
            ],
            window.IntervalWindow(5, 10): [('a2', Timestamp(8), [400, 0])],
        }

        assert_that(records,
                    equal_to_per_window(expected_window_to_elements),
                    use_global_window=False,
                    label='assert per window')

        p.run()
Example #7
  def test_basic_execution(self):
    test_stream = (TestStream()
                   .advance_watermark_to(0)
                   .advance_processing_time(5)
                   .add_elements(['a', 'b', 'c'])
                   .advance_watermark_to(2)
                   .advance_processing_time(1)
                   .advance_watermark_to(4)
                   .advance_processing_time(1)
                   .advance_watermark_to(6)
                   .advance_processing_time(1)
                   .advance_watermark_to(8)
                   .advance_processing_time(1)
                   .advance_watermark_to(10)
                   .advance_processing_time(1)
                   .add_elements([TimestampedValue('1', 15),
                                  TimestampedValue('2', 15),
                                  TimestampedValue('3', 15)]))  # yapf: disable

    options = StandardOptions(streaming=True)
    p = TestPipeline(options=options)

    records = (
        p
        | test_stream
        | ReverseTestStream(sample_resolution_sec=1, output_tag=None))

    assert_that(
        records,
        equal_to_per_window({
            beam.window.GlobalWindow(): [
                [ProcessingTimeEvent(5), WatermarkEvent(0)],
                [
                    ElementEvent([
                        TimestampedValue('a', 0),
                        TimestampedValue('b', 0),
                        TimestampedValue('c', 0)
                    ])
                ],
                [ProcessingTimeEvent(1), WatermarkEvent(2000000)],
                [ProcessingTimeEvent(1), WatermarkEvent(4000000)],
                [ProcessingTimeEvent(1), WatermarkEvent(6000000)],
                [ProcessingTimeEvent(1), WatermarkEvent(8000000)],
                [ProcessingTimeEvent(1), WatermarkEvent(10000000)],
                [
                    ElementEvent([
                        TimestampedValue('1', 15),
                        TimestampedValue('2', 15),
                        TimestampedValue('3', 15)
                    ])
                ],
            ],
        }))

    p.run()
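A note on the WatermarkEvent values above: watermarks are advanced in seconds
but recorded in microseconds. A quick sanity check of that conversion (a
sketch using apache_beam.utils.timestamp):

from apache_beam.utils import timestamp

# advance_watermark_to(2) surfaces above as WatermarkEvent(2000000) because
# Timestamp stores seconds as microseconds internally.
assert timestamp.Timestamp(seconds=2).micros == 2 * 10**6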
Example #8
  def test_equal_to_per_window_succeeds_no_reify_windows(self):
    start = int(MIN_TIMESTAMP.micros // 1e6) - 5
    end = start + 20
    expected = {
        window.IntervalWindow(start, end): [('k', [1])],
    }
    with TestPipeline(options=StandardOptions(streaming=True)) as p:
      assert_that((p
                   | Create([1])
                   | beam.WindowInto(
                       FixedWindows(20),
                       trigger=trigger.AfterWatermark(),
                       accumulation_mode=trigger.AccumulationMode.DISCARDING)
                   | beam.Map(lambda x: ('k', x))
                   | beam.GroupByKey()),
                  equal_to_per_window(expected))
Example #9
  def test_equal_to_per_window_fail_unmatched_window(self):
    with self.assertRaises(BeamAssertException):
      expected = {
          window.IntervalWindow(50, 100): [('k', [1])],
      }
      with TestPipeline(options=StandardOptions(streaming=True)) as p:
        assert_that((p
                     | Create([1])
                     | beam.WindowInto(
                         FixedWindows(20),
                         trigger=trigger.AfterWatermark(),
                         accumulation_mode=trigger.AccumulationMode.DISCARDING)
                     | beam.Map(lambda x: ('k', x))
                     | beam.GroupByKey()),
                    equal_to_per_window(expected),
                    reify_windows=True)
Example #10
  def test_fragment_does_not_prune_teststream(self):
    """Tests that the fragment does not prune the TestStream composite parts.
    """
    options = StandardOptions(streaming=True)
    p = beam.Pipeline(ir.InteractiveRunner(), options)

    test_stream = p | TestStream(output_tags=['a', 'b'])

    # pylint: disable=unused-variable
    a = test_stream['a'] | 'a' >> beam.Map(lambda _: _)
    b = test_stream['b'] | 'b' >> beam.Map(lambda _: _)

    fragment = pf.PipelineFragment([b]).deduce_fragment()

    # If the fragment does prune the TestStream composite parts, then the
    # resulting graph is invalid and the following call will raise an exception.
    fragment.to_runner_api()
Example #11
  def test_buffering_timer_in_fixed_window_streaming(self):
    window_duration = 6
    max_buffering_duration_secs = 100

    start_time = timestamp.Timestamp(0)
    test_stream = (
        TestStream()
        .add_elements([
            TimestampedValue(value, start_time + i)
            for i, value in enumerate(GroupIntoBatchesTest._create_test_data())
        ])
        .advance_processing_time(150)
        .advance_watermark_to(start_time + window_duration)
        .advance_watermark_to(start_time + window_duration + 1)
        .advance_watermark_to_infinity())  # yapf: disable

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0).
      fake_clock = FakeClock(now=start_time)

      num_elements_per_batch = (
          pipeline | test_stream
          | "fixed window" >> WindowInto(FixedWindows(window_duration))
          | util.GroupIntoBatches(
              GroupIntoBatchesTest.BATCH_SIZE,
              max_buffering_duration_secs,
              fake_clock)
          | "count elements in batch" >> Map(lambda x: (None, len(x[1])))
          | "global window" >> WindowInto(GlobalWindows())
          | GroupByKey()
          | FlatMapTuple(lambda k, vs: vs))

      # Window duration is 6 and batch size is 5, so output batch size
      # should be 5 (flush because the batch size was reached).
      expected_0 = 5
      # There is only one element left in the window, so batch size
      # should be 1 (flush because the max buffering duration was reached).
      expected_1 = 1
      # The collection has 10 elements; only 4 are left, so batch size should
      # be 4 (flush because the end of the window was reached).
      expected_2 = 4
      assert_that(
          num_elements_per_batch,
          equal_to([expected_0, expected_1, expected_2]),
          "assert2")
Example #12
  def test_buffering_timer_in_global_window_streaming(self):
    max_buffering_duration_secs = 42

    start_time = timestamp.Timestamp(0)
    test_stream = TestStream().advance_watermark_to(start_time)
    for i, value in enumerate(GroupIntoBatchesTest._create_test_data()):
      test_stream.add_elements(
          [TimestampedValue(value, start_time + i)]) \
        .advance_processing_time(5)
    test_stream.advance_watermark_to(
        start_time + GroupIntoBatchesTest.NUM_ELEMENTS + 1) \
      .advance_watermark_to_infinity()

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # Set a batch size larger than the total number of elements.
      # Since we're in a global window, without the buffering time limit we
      # would be waiting for all the elements to arrive before flushing.
      batch_size = GroupIntoBatchesTest.NUM_ELEMENTS * 2

      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0). Since the fake clock never advances during pipeline
      # execution, the timer is always set to the same value, so it fires on
      # every element after the first firing.
      fake_clock = FakeClock(now=start_time)

      num_elements_per_batch = (
          pipeline | test_stream
          | WindowInto(
              GlobalWindows(),
              trigger=Repeatedly(AfterCount(1)),
              accumulation_mode=trigger.AccumulationMode.DISCARDING)
          | util.GroupIntoBatches(
              batch_size, max_buffering_duration_secs, fake_clock)
          | 'count elements in batch' >> Map(lambda x: (None, len(x[1])))
          | GroupByKey()
          | FlatMapTuple(lambda k, vs: vs))

      # We will flush twice when the max buffering duration is reached and when
      # the global window ends.
      assert_that(num_elements_per_batch, equal_to([9, 1]))
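FakeClock here is test-suite infrastructure rather than a public Beam API; a
minimal illustrative sketch (assumption: GroupIntoBatches only needs a callable
that returns the current time) looks like this:

class FakeClock(object):
    """Illustrative stand-in: reports a fixed 'now' unless advanced."""
    def __init__(self, now):
        self._now = now

    def __call__(self):
        # GroupIntoBatches calls the clock to decide whether the max
        # buffering duration for the current batch has elapsed.
        return self._now

    def advance(self, duration_secs):
        self._now += duration_secs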
Example #13
  def test_roundtrip_proto(self):
    test_stream = (TestStream()
                   .advance_processing_time(1)
                   .advance_watermark_to(2)
                   .add_elements([1, 2, 3])) # yapf: disable

    p = TestPipeline(options=StandardOptions(streaming=True))
    p | test_stream

    pipeline_proto, context = p.to_runner_api(return_context=True)

    for t in pipeline_proto.components.transforms.values():
      if t.spec.urn == common_urns.primitives.TEST_STREAM.urn:
        test_stream_proto = t

    self.assertTrue(test_stream_proto)
    roundtrip_test_stream = TestStream().from_runner_api(
        test_stream_proto, context)

    self.assertListEqual(test_stream._events, roundtrip_test_stream._events)
    self.assertSetEqual(
        test_stream.output_tags, roundtrip_test_stream.output_tags)
    self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
Example #14
    def test_multiple_outputs_with_watermark_advancement(self):
        """Tests that the TestStream can independently control output watermarks."""

        # Purposely set the watermark of numbers to 20 then letters to 5 to test
        # that the watermark advancement is per PCollection.
        #
        # This creates two PCollections, (a, b, c) and (1, 2, 3). These will be
        # emitted at different times so that they will have different windows. The
        # watermark advancement is checked by inspecting their windows. If the
        # watermark does not advance, then the windows will be [-inf, -inf). If
        # the watermarks do not advance separately, then the PCollections will
        # both be windowed in [15, 30).
        letters_elements = [
            TimestampedValue('a', 6),
            TimestampedValue('b', 7),
            TimestampedValue('c', 8),
        ]
        numbers_elements = [
            TimestampedValue('1', 21),
            TimestampedValue('2', 22),
            TimestampedValue('3', 23),
        ]
        test_stream = (TestStream()
                       .advance_watermark_to(0, tag='letters')
                       .advance_watermark_to(0, tag='numbers')
                       .advance_watermark_to(20, tag='numbers')
                       .advance_watermark_to(5, tag='letters')
                       .add_elements(letters_elements, tag='letters')
                       .advance_watermark_to(10, tag='letters')
                       .add_elements(numbers_elements, tag='numbers')
                       .advance_watermark_to(30, tag='numbers'))  # yapf: disable

        options = StandardOptions(streaming=True)
        p = TestPipeline(is_integration_test=True, options=options)

        main = p | test_stream

        # Use an AfterWatermark trigger with an early firing to test that the
        # watermark is advancing properly and that the element is being emitted in
        # the correct window.
        letters = (
            main['letters']
            | 'letter windows' >> beam.WindowInto(
                FixedWindows(15),
                trigger=trigger.AfterWatermark(early=trigger.AfterCount(1)),
                accumulation_mode=trigger.AccumulationMode.DISCARDING)
            | 'letter with key' >> beam.Map(lambda x: ('k', x))
            | 'letter gbk' >> beam.GroupByKey())

        numbers = (
            main['numbers']
            | 'number windows' >> beam.WindowInto(
                FixedWindows(15),
                trigger=trigger.AfterWatermark(early=trigger.AfterCount(1)),
                accumulation_mode=trigger.AccumulationMode.DISCARDING)
            | 'number with key' >> beam.Map(lambda x: ('k', x))
            | 'number gbk' >> beam.GroupByKey())

        # The letters were emitted when the watermark was at 5, thus we expect to
        # see the elements in the [0, 15) window. We used an early trigger to make
        # sure that the ON_TIME empty pane was also emitted with a TestStream.
        # This pane has no data because the early trigger causes the elements
        # to fire before the end of the window and because the accumulation
        # mode discards any data after the trigger fires.
        expected_letters = {
            window.IntervalWindow(0, 15): [
                ('k', ['a', 'b', 'c']),
                ('k', []),
            ],
        }

        # Same here, except the numbers were emitted at watermark = 20, thus they
        # are in the [15, 30) window.
        expected_numbers = {
            window.IntervalWindow(15, 30): [
                ('k', ['1', '2', '3']),
                ('k', []),
            ],
        }
        assert_that(letters,
                    equal_to_per_window(expected_letters),
                    label='letters assert per window')
        assert_that(numbers,
                    equal_to_per_window(expected_numbers),
                    label='numbers assert per window')

        p.run()
Example #15
  def test_read_and_write_multiple_outputs(self):
    """An integration test between the Sink and Source with multiple outputs.

    This tests that the StreamingCache reads from multiple files and
    combines them into a single sorted output.
    """
    LETTERS_TAG = repr(CacheKey('letters', '', '', ''))
    NUMBERS_TAG = repr(CacheKey('numbers', '', '', ''))

    # Units here are in seconds.
    test_stream = (TestStream()
                   .advance_watermark_to(0, tag=LETTERS_TAG)
                   .advance_processing_time(5)
                   .add_elements(['a', 'b', 'c'], tag=LETTERS_TAG)
                   .advance_watermark_to(10, tag=NUMBERS_TAG)
                   .advance_processing_time(1)
                   .add_elements(
                       [
                           TimestampedValue('1', 15),
                           TimestampedValue('2', 15),
                           TimestampedValue('3', 15)
                       ],
                       tag=NUMBERS_TAG)) # yapf: disable

    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

    coder = SafeFastPrimitivesCoder()

    options = StandardOptions(streaming=True)
    with TestPipeline(options=options) as p:
      # pylint: disable=expression-not-assigned
      events = p | test_stream
      events[LETTERS_TAG] | 'Letters sink' >> cache.sink([LETTERS_TAG])
      events[NUMBERS_TAG] | 'Numbers sink' >> cache.sink([NUMBERS_TAG])

    reader = cache.read_multiple([[LETTERS_TAG], [NUMBERS_TAG]])
    actual_events = list(reader)

    # Units here are in microseconds.
    expected_events = [
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=5 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=0, tag=LETTERS_TAG)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('a'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('b'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('c'), timestamp=0),
                ],
                tag=LETTERS_TAG)),
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=1 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=10 * 10**6, tag=NUMBERS_TAG)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=0, tag=LETTERS_TAG)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('1'), timestamp=15 *
                        10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('2'), timestamp=15 *
                        10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('3'), timestamp=15 *
                        10**6),
                ],
                tag=NUMBERS_TAG)),
    ]

    self.assertListEqual(actual_events, expected_events)
Example #16
  def test_read_and_write(self):
    """An integration test between the Sink and Source.

    This ensures that the sink and source speak the same language in terms of
    coders, protos, order, and units.
    """
    CACHED_RECORDS = repr(CacheKey('records', '', '', ''))

    # Units here are in seconds.
    test_stream = (
        TestStream(output_tags=[CACHED_RECORDS])
        .advance_watermark_to(0, tag=CACHED_RECORDS)
        .advance_processing_time(5)
        .add_elements(['a', 'b', 'c'], tag=CACHED_RECORDS)
        .advance_watermark_to(10, tag=CACHED_RECORDS)
        .advance_processing_time(1)
        .add_elements(
            [
                TimestampedValue('1', 15),
                TimestampedValue('2', 15),
                TimestampedValue('3', 15)
            ],
            tag=CACHED_RECORDS)) # yapf: disable

    coder = SafeFastPrimitivesCoder()
    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

    # Assert that there are no capture keys at first.
    self.assertEqual(cache.capture_keys, set())

    options = StandardOptions(streaming=True)
    with TestPipeline(options=options) as p:
      records = (p | test_stream)[CACHED_RECORDS]

      # pylint: disable=expression-not-assigned
      records | cache.sink([CACHED_RECORDS], is_capture=True)

    reader, _ = cache.read(CACHED_RECORDS)
    actual_events = list(reader)

    # Assert that the capture keys are forwarded correctly.
    self.assertEqual(cache.capture_keys, set([CACHED_RECORDS]))

    # Units here are in microseconds.
    expected_events = [
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=5 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=0, tag=CACHED_RECORDS)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('a'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('b'), timestamp=0),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('c'), timestamp=0),
                ],
                tag=CACHED_RECORDS)),
        TestStreamPayload.Event(
            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                advance_duration=1 * 10**6)),
        TestStreamPayload.Event(
            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                new_watermark=10 * 10**6, tag=CACHED_RECORDS)),
        TestStreamPayload.Event(
            element_event=TestStreamPayload.Event.AddElements(
                elements=[
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('1'), timestamp=15 *
                        10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('2'), timestamp=15 *
                        10**6),
                    TestStreamPayload.TimestampedElement(
                        encoded_element=coder.encode('3'), timestamp=15 *
                        10**6),
                ],
                tag=CACHED_RECORDS)),
    ]
    self.assertEqual(actual_events, expected_events)
Example #17
    def test_read_and_write(self):
        """An integration test between the Sink and Source.

    This ensures that the sink and source speak the same language in terms of
    coders, protos, order, and units.
    """

        # Units here are in seconds.
        test_stream = (TestStream()
                       .advance_watermark_to(0, tag='records')
                       .advance_processing_time(5)
                       .add_elements(['a', 'b', 'c'], tag='records')
                       .advance_watermark_to(10, tag='records')
                       .advance_processing_time(1)
                       .add_elements(
                           [
                               TimestampedValue('1', 15),
                               TimestampedValue('2', 15),
                               TimestampedValue('3', 15)
                           ],
                           tag='records')) # yapf: disable

        coder = SafeFastPrimitivesCoder()
        cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

        options = StandardOptions(streaming=True)
        options.view_as(DebugOptions).add_experiment(
            'passthrough_pcollection_output_ids')
        with TestPipeline(options=options) as p:
            # pylint: disable=expression-not-assigned
            p | test_stream | cache.sink(['records'])

        reader, _ = cache.read('records')
        actual_events = list(reader)

        # Units here are in microseconds.
        expected_events = [
            TestStreamPayload.Event(
                processing_time_event=TestStreamPayload.Event.
                AdvanceProcessingTime(advance_duration=5 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=0, tag='records')),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('a'), timestamp=0),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('b'), timestamp=0),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('c'), timestamp=0),
                    ],
                    tag='records')),
            TestStreamPayload.Event(
                processing_time_event=TestStreamPayload.Event.
                AdvanceProcessingTime(advance_duration=1 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=10 * 10**6, tag='records')),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('1'),
                            timestamp=15 * 10**6),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('2'),
                            timestamp=15 * 10**6),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('3'),
                            timestamp=15 * 10**6),
                    ],
                    tag='records')),
        ]
        self.assertEqual(actual_events, expected_events)
Example #18
import apache_beam as beam
from apache_beam.options.pipeline_options import StandardOptions, GoogleCloudOptions, SetupOptions
from apache_beam.io.gcp.pubsub import ReadFromPubSub
from apache_beam.io.gcp.bigquery import WriteToBigQuery

# Settings
PROJECT = 'GCP_PROJECT'
SUBSCRIPTION = 'PUBSUB_SUBSCRIPTION'
REGION = 'GCP_REGION'
BUCKET = 'GCS_BUCKET'
DATABASE = 'BQ_DATABASE'
TABLE = 'BQ_TABLE'

# Define options
opt = StandardOptions()
opt.streaming = True
opt.runner = 'DataflowRunner'

stp = opt.view_as(SetupOptions)
stp.requirements_file = "./requirements.txt"

gcp = opt.view_as(GoogleCloudOptions)
gcp.project = PROJECT
gcp.region = REGION
gcp.staging_location = 'gs://{bucket}/staging'.format(bucket=BUCKET)
gcp.temp_location = 'gs://{bucket}/temp'.format(bucket=BUCKET)
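With the options assembled, a typical next step is to hand them to a pipeline.
A sketch only, reusing the placeholder settings above (the enrichment step
shown next and the BigQuery schema are omitted; SUBSCRIPTION must be a full
'projects/<project>/subscriptions/<name>' path for ReadFromPubSub):

with beam.Pipeline(options=opt) as p:
    (p
     | 'Read' >> ReadFromPubSub(subscription=SUBSCRIPTION)
     | 'Write' >> WriteToBigQuery(
         table='{}:{}.{}'.format(PROJECT, DATABASE, TABLE)))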


# Enrichment Function
def enrichment(text):
    import json, user_agents
Example #19
    def test_instrument_example_unbounded_pipeline_to_read_cache_not_cached(
            self):
        """Tests that the instrumenter works when the PCollection is not cached.
    """
        # Create the pipeline that will be instrumented.
        from apache_beam.options.pipeline_options import StandardOptions
        options = StandardOptions(streaming=True)
        p_original_read_cache = beam.Pipeline(
            interactive_runner.InteractiveRunner(), options)
        ie.current_env().set_cache_manager(StreamingCache(cache_dir=None),
                                           p_original_read_cache)
        source_1 = p_original_read_cache | 'source1' >> beam.io.ReadFromPubSub(
            subscription='projects/fake-project/subscriptions/fake_sub')
        # pylint: disable=possibly-unused-variable
        pcoll_1 = source_1 | 'square1' >> beam.Map(lambda x: x * x)

        # Watch but do not cache the PCollections.
        ib.watch(locals())
        # This should be a no-op.
        utils.watch_sources(p_original_read_cache)
        # Instrument the original pipeline to create the pipeline the user will see.
        p_copy = beam.Pipeline.from_runner_api(
            p_original_read_cache.to_runner_api(),
            runner=interactive_runner.InteractiveRunner(),
            options=options)
        ie.current_env().add_derived_pipeline(p_original_read_cache, p_copy)
        instrumenter = instr.build_pipeline_instrument(p_copy)
        actual_pipeline = beam.Pipeline.from_runner_api(
            proto=instrumenter.instrumented_pipeline_proto(),
            runner=interactive_runner.InteractiveRunner(),
            options=options)

        # Now, build the expected pipeline which replaces the unbounded source with
        # a TestStream.
        source_1_cache_key = self.cache_key_of('source_1', source_1)
        p_expected = beam.Pipeline()
        ie.current_env().set_cache_manager(StreamingCache(cache_dir=None),
                                           p_expected)
        test_stream = (p_expected
                       | TestStream(output_tags=[source_1_cache_key]))
        # pylint: disable=expression-not-assigned
        (test_stream[source_1_cache_key]
         | 'square1' >> beam.Map(lambda x: x * x)
         | 'reify' >> beam.Map(lambda _: _)
         | cache.WriteCache(ie.current_env().get_cache_manager(p_expected),
                            'unused'))

        # Test that the TestStream is outputting to the correct PCollection.
        class TestStreamVisitor(PipelineVisitor):
            def __init__(self):
                self.output_tags = set()

            def enter_composite_transform(self, transform_node):
                self.visit_transform(transform_node)

            def visit_transform(self, transform_node):
                transform = transform_node.transform
                if isinstance(transform, TestStream):
                    self.output_tags = transform.output_tags

        v = TestStreamVisitor()
        actual_pipeline.visit(v)
        expected_output_tags = set([source_1_cache_key])
        actual_output_tags = v.output_tags
        self.assertSetEqual(expected_output_tags, actual_output_tags)

        # Test that the pipeline is as expected.
        assert_pipeline_proto_equal(self, p_expected.to_runner_api(),
                                    instrumenter.instrumented_pipeline_proto())
Example #20
    def test_able_to_cache_intermediate_unbounded_source_pcollection(self):
        """Tests being able to cache an intermediate source PCollection.

    In the following pipeline, the source doesn't have a reference and so is
    not automatically cached in the watch() command. This tests that this case
    is taken care of.
    """
        # Create the pipeline that will be instrumented.
        from apache_beam.options.pipeline_options import StandardOptions
        options = StandardOptions(streaming=True)
        streaming_cache_manager = StreamingCache(cache_dir=None)
        p_original_cache_source = beam.Pipeline(
            interactive_runner.InteractiveRunner(), options)
        ie.current_env().set_cache_manager(streaming_cache_manager,
                                           p_original_cache_source)

        # pylint: disable=possibly-unused-variable
        source_1 = (
            p_original_cache_source
            | 'source1' >> beam.io.ReadFromPubSub(
                subscription='projects/fake-project/subscriptions/fake_sub')
            | beam.Map(lambda e: e))

        # Watch but do not cache the PCollections.
        ib.watch(locals())
        # Make sure that sources without a user reference are still cached.
        utils.watch_sources(p_original_cache_source)

        intermediate_source_pcoll = None
        for watching in ie.current_env().watching():
            watching = list(watching)
            for var, watchable in watching:
                if 'synthetic' in var:
                    intermediate_source_pcoll = watchable
                    break

        # Instrument the original pipeline to create the pipeline the user will see.
        p_copy = beam.Pipeline.from_runner_api(
            p_original_cache_source.to_runner_api(),
            runner=interactive_runner.InteractiveRunner(),
            options=options)
        ie.current_env().add_derived_pipeline(p_original_cache_source, p_copy)
        instrumenter = instr.build_pipeline_instrument(p_copy)
        actual_pipeline = beam.Pipeline.from_runner_api(
            proto=instrumenter.instrumented_pipeline_proto(),
            runner=interactive_runner.InteractiveRunner(),
            options=options)
        ie.current_env().add_derived_pipeline(p_original_cache_source,
                                              actual_pipeline)

        # Now, build the expected pipeline which replaces the unbounded source with
        # a TestStream.
        intermediate_source_pcoll_cache_key = self.cache_key_of(
            'synthetic_var_' + str(id(intermediate_source_pcoll)),
            intermediate_source_pcoll)
        p_expected = beam.Pipeline()
        ie.current_env().set_cache_manager(streaming_cache_manager, p_expected)
        test_stream = (
            p_expected
            | TestStream(output_tags=[intermediate_source_pcoll_cache_key]))
        # pylint: disable=expression-not-assigned
        (test_stream[intermediate_source_pcoll_cache_key]
         | 'square1' >> beam.Map(lambda e: e)
         | 'reify' >> beam.Map(lambda _: _)
         | cache.WriteCache(ie.current_env().get_cache_manager(p_expected),
                            'unused'))

        # Test that the TestStream is outputting to the correct PCollection.
        class TestStreamVisitor(PipelineVisitor):
            def __init__(self):
                self.output_tags = set()

            def enter_composite_transform(self, transform_node):
                self.visit_transform(transform_node)

            def visit_transform(self, transform_node):
                transform = transform_node.transform
                if isinstance(transform, TestStream):
                    self.output_tags = transform.output_tags

        v = TestStreamVisitor()
        actual_pipeline.visit(v)
        expected_output_tags = set([intermediate_source_pcoll_cache_key])
        actual_output_tags = v.output_tags
        self.assertSetEqual(expected_output_tags, actual_output_tags)

        # Test that the pipeline is as expected.
        assert_pipeline_proto_equal(self, p_expected.to_runner_api(),
                                    instrumenter.instrumented_pipeline_proto())
Example #21
    def test_instrument_example_unbounded_pipeline_direct_from_source(self):
        """Tests that the it caches PCollections from a source.
        """
        # Create a new interactive environment to make the test idempotent.
        ie.new_env(cache_manager=streaming_cache.StreamingCache(
            cache_dir=None))

        # Create the pipeline that will be instrumented.
        from apache_beam.options.pipeline_options import StandardOptions
        options = StandardOptions(streaming=True)
        p_original = beam.Pipeline(interactive_runner.InteractiveRunner(),
                                   options)
        source_1 = p_original | 'source1' >> beam.io.ReadFromPubSub(
            subscription='projects/fake-project/subscriptions/fake_sub')
        # pylint: disable=possibly-unused-variable

        # Watch but do not cache the PCollections.
        ib.watch(locals())

        def cache_key_of(name, pcoll):
            return name + '_' + str(id(pcoll)) + '_' + str(id(pcoll.producer))

        # Instrument the original pipeline to create the pipeline the user will see.
        p_copy = beam.Pipeline.from_runner_api(
            p_original.to_runner_api(),
            runner=interactive_runner.InteractiveRunner(),
            options=options)
        instrumenter = instr.build_pipeline_instrument(p_copy)
        actual_pipeline = beam.Pipeline.from_runner_api(
            proto=instrumenter.instrumented_pipeline_proto(),
            runner=interactive_runner.InteractiveRunner(),
            options=options)

        # Now, build the expected pipeline which replaces the unbounded source with
        # a TestStream.
        source_1_cache_key = cache_key_of('source_1', source_1)
        p_expected = beam.Pipeline()

        # pylint: disable=unused-variable
        test_stream = (
            p_expected
            | TestStream(output_tags=[source_1_cache_key]))

        # Test that the TestStream is outputting to the correct PCollection.
        class TestStreamVisitor(PipelineVisitor):
            def __init__(self):
                self.output_tags = set()

            def enter_composite_transform(self, transform_node):
                self.visit_transform(transform_node)

            def visit_transform(self, transform_node):
                transform = transform_node.transform
                if isinstance(transform, TestStream):
                    self.output_tags = transform.output_tags

        v = TestStreamVisitor()
        actual_pipeline.visit(v)
        expected_output_tags = set([source_1_cache_key])
        actual_output_tags = v.output_tags
        self.assertSetEqual(expected_output_tags, actual_output_tags)

        # Test that the pipeline is as expected.
        assert_pipeline_proto_equal(
            self, p_expected.to_runner_api(use_fake_coders=True),
            instrumenter.instrumented_pipeline_proto())
Example #22
    def test_streaming_wordcount(self):
        self.skipTest('[BEAM-9601] Test is breaking PreCommits')

        class WordExtractingDoFn(beam.DoFn):
            def process(self, element):
                text_line = element.strip()
                words = text_line.split()
                return words

        # Add the TestStream so that it can be cached.
        ib.options.capturable_sources.add(TestStream)
        ib.options.capture_duration = timedelta(seconds=1)

        p = beam.Pipeline(runner=interactive_runner.InteractiveRunner(),
                          options=StandardOptions(streaming=True))

        data = (
            p
            | TestStream()
                .advance_watermark_to(0)
                .advance_processing_time(1)
                .add_elements(['to', 'be', 'or', 'not', 'to', 'be'])
                .advance_watermark_to(20)
                .advance_processing_time(1)
                .add_elements(['that', 'is', 'the', 'question'])
            | beam.WindowInto(beam.window.FixedWindows(10))) # yapf: disable

        counts = (data
                  | 'split' >> beam.ParDo(WordExtractingDoFn())
                  | 'pair_with_one' >> beam.Map(lambda x: (x, 1))
                  | 'group' >> beam.GroupByKey()
                  | 'count' >> beam.Map(lambda wordones:
                                        (wordones[0], sum(wordones[1]))))

        # Watch the local scope for Interactive Beam so that referenced PCollections
        # will be cached.
        ib.watch(locals())

        # This is normally done in the interactive_utils when a transform is
        # applied but needs an IPython environment. So we manually run this here.
        ie.current_env().track_user_pipelines()

        # This tests that the data was correctly cached.
        pane_info = PaneInfo(True, True, PaneInfoTiming.UNKNOWN, 0, 0)
        expected_data_df = pd.DataFrame(
            [('to', 0, [beam.window.IntervalWindow(0, 10)], pane_info),
             ('be', 0, [beam.window.IntervalWindow(0, 10)], pane_info),
             ('or', 0, [beam.window.IntervalWindow(0, 10)], pane_info),
             ('not', 0, [beam.window.IntervalWindow(0, 10)], pane_info),
             ('to', 0, [beam.window.IntervalWindow(0, 10)], pane_info),
             ('be', 0, [beam.window.IntervalWindow(0, 10)], pane_info),
             ('that', 20000000, [beam.window.IntervalWindow(20, 30)],
              pane_info),
             ('is', 20000000, [beam.window.IntervalWindow(20, 30)], pane_info),
             ('the', 20000000, [beam.window.IntervalWindow(20, 30)],
              pane_info),
             ('question', 20000000, [beam.window.IntervalWindow(20, 30)],
              pane_info)],
            columns=[0, 'event_time', 'windows', 'pane_info'])

        data_df = ib.collect(data, include_window_info=True)
        pd.testing.assert_frame_equal(expected_data_df, data_df)

        # This tests that the windowing was passed correctly so that all the
        # data is also aggregated correctly.
        pane_info = PaneInfo(True, False, PaneInfoTiming.ON_TIME, 0, 0)
        expected_counts_df = pd.DataFrame(
            [('to', 2, 9999999, [beam.window.IntervalWindow(0, 10)],
              pane_info),
             ('be', 2, 9999999, [beam.window.IntervalWindow(0, 10)],
              pane_info),
             ('or', 1, 9999999, [beam.window.IntervalWindow(0, 10)],
              pane_info),
             ('not', 1, 9999999, [beam.window.IntervalWindow(0, 10)],
              pane_info),
             ('that', 1, 29999999, [beam.window.IntervalWindow(20, 30)],
              pane_info),
             ('is', 1, 29999999, [beam.window.IntervalWindow(20, 30)],
              pane_info),
             ('the', 1, 29999999, [beam.window.IntervalWindow(20, 30)],
              pane_info),
             ('question', 1, 29999999, [beam.window.IntervalWindow(20, 30)],
              pane_info)],
            columns=[0, 1, 'event_time', 'windows', 'pane_info'])

        counts_df = ib.collect(counts, include_window_info=True)
        pd.testing.assert_frame_equal(expected_counts_df, counts_df)
Example #23
    def test_windowing(self):
        test_stream = (TestStream()
                       .advance_watermark_to(0)
                       .add_elements(['a', 'b', 'c'])
                       .advance_processing_time(1)
                       .advance_processing_time(1)
                       .advance_processing_time(1)
                       .advance_processing_time(1)
                       .advance_processing_time(1)
                       .advance_watermark_to(5)
                       .add_elements(['1', '2', '3'])
                       .advance_processing_time(1)
                       .advance_watermark_to(6)
                       .advance_processing_time(1)
                       .advance_watermark_to(7)
                       .advance_processing_time(1)
                       .advance_watermark_to(8)
                       .advance_processing_time(1)
                       .advance_watermark_to(9)
                       .advance_processing_time(1)
                       .advance_watermark_to(10)
                       .advance_processing_time(1)
                       .advance_watermark_to(11)
                       .advance_processing_time(1)
                       .advance_watermark_to(12)
                       .advance_processing_time(1)
                       .advance_watermark_to(13)
                       .advance_processing_time(1)
                       .advance_watermark_to(14)
                       .advance_processing_time(1)
                       .advance_watermark_to(15)
                       .advance_processing_time(1)
                       )  # yapf: disable

        options = StandardOptions(streaming=True)
        p = TestPipeline(options=options)

        records = (p
                   | test_stream
                   | 'letter windows' >> beam.WindowInto(
                       FixedWindows(5),
                       accumulation_mode=trigger.AccumulationMode.DISCARDING)
                   | 'letter with key' >> beam.Map(lambda x: ('k', x))
                   | 'letter gbk' >> beam.GroupByKey()
                   | ReverseTestStream(sample_resolution_sec=1,
                                       output_tag=None))

        assert_that(
            records,
            equal_to_per_window({
                beam.window.GlobalWindow(): [
                    [ProcessingTimeEvent(5),
                     WatermarkEvent(4999998)],
                    [
                        ElementEvent([
                            TimestampedValue(('k', ['a', 'b', 'c']), 4.999999)
                        ])
                    ],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(5000000)],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(6000000)],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(7000000)],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(8000000)],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(9000000)],
                    [
                        ElementEvent([
                            TimestampedValue(('k', ['1', '2', '3']), 9.999999)
                        ])
                    ],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(10000000)],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(11000000)],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(12000000)],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(13000000)],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(14000000)],
                    [ProcessingTimeEvent(1),
                     WatermarkEvent(15000000)],
                ],
            }))

        p.run()
Example #24
    def test_instrument_mixed_streaming_batch(self):
        """Tests caching for both batch and streaming sources in the same pipeline.

    This ensures that cached bounded and unbounded sources are read from the
    TestStream.
    """
        # Create the pipeline that will be instrumented.
        from apache_beam.options.pipeline_options import StandardOptions
        options = StandardOptions(streaming=True)
        p_original = beam.Pipeline(interactive_runner.InteractiveRunner(),
                                   options)
        streaming_cache_manager = StreamingCache(cache_dir=None)
        ie.current_env().set_cache_manager(streaming_cache_manager, p_original)
        source_1 = p_original | 'source1' >> beam.io.ReadFromPubSub(
            subscription='projects/fake-project/subscriptions/fake_sub')
        source_2 = p_original | 'source2' >> beam.Create([1, 2, 3, 4, 5])

        # pylint: disable=possibly-unused-variable
        pcoll_1 = ((source_1, source_2)
                   | beam.Flatten()
                   | 'square1' >> beam.Map(lambda x: x * x))

        # Watch but do not cache the PCollections.
        ib.watch(locals())
        # This should be a no-op.
        utils.watch_sources(p_original)
        self._mock_write_cache(p_original, [],
                               self.cache_key_of('source_2', source_2))
        ie.current_env().mark_pcollection_computed([source_2])

        # Instrument the original pipeline to create the pipeline the user will see.
        p_copy = beam.Pipeline.from_runner_api(
            p_original.to_runner_api(),
            runner=interactive_runner.InteractiveRunner(),
            options=options)
        ie.current_env().add_derived_pipeline(p_original, p_copy)
        instrumenter = instr.build_pipeline_instrument(p_copy)
        actual_pipeline = beam.Pipeline.from_runner_api(
            proto=instrumenter.instrumented_pipeline_proto(),
            runner=interactive_runner.InteractiveRunner(),
            options=options)

        # Now, build the expected pipeline which replaces the unbounded source with
        # a TestStream.
        source_1_cache_key = self.cache_key_of('source_1', source_1)
        source_2_cache_key = self.cache_key_of('source_2', source_2)
        p_expected = beam.Pipeline()
        ie.current_env().set_cache_manager(streaming_cache_manager, p_expected)
        test_stream = (
            p_expected
            | TestStream(output_tags=[source_1_cache_key, source_2_cache_key]))
        # pylint: disable=expression-not-assigned
        ((test_stream[source_1_cache_key],
          test_stream[source_2_cache_key])
         | beam.Flatten()
         | 'square1' >> beam.Map(lambda x: x * x)
         | 'reify' >> beam.Map(lambda _: _)
         | cache.WriteCache(ie.current_env().get_cache_manager(p_expected),
                            'unused'))

        # Test that the TestStream is outputting to the correct PCollection.
        class TestStreamVisitor(PipelineVisitor):
            def __init__(self):
                self.output_tags = set()

            def enter_composite_transform(self, transform_node):
                self.visit_transform(transform_node)

            def visit_transform(self, transform_node):
                transform = transform_node.transform
                if isinstance(transform, TestStream):
                    self.output_tags = transform.output_tags

        v = TestStreamVisitor()
        actual_pipeline.visit(v)
        expected_output_tags = set([source_1_cache_key, source_2_cache_key])
        actual_output_tags = v.output_tags
        self.assertSetEqual(expected_output_tags, actual_output_tags)

        # Test that the pipeline is as expected.
        assert_pipeline_proto_equal(self, p_expected.to_runner_api(),
                                    instrumenter.instrumented_pipeline_proto())
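
The test above clones the user pipeline by round-tripping it through its
runner API proto before instrumenting it. A minimal sketch of that
clone-via-proto pattern, under the assumption that a fresh DirectRunner and
default options are acceptable for the copy:

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.direct.direct_runner import DirectRunner

p = beam.Pipeline()
_ = p | 'create' >> beam.Create([1, 2, 3])
# Serializing to the proto and rebuilding yields an equivalent,
# independently mutable copy of the pipeline.
p_copy = beam.Pipeline.from_runner_api(
    p.to_runner_api(), runner=DirectRunner(), options=PipelineOptions())
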
Example #25
    def test_basic_execution_in_records_format(self):
        test_stream = (TestStream()
                       .advance_watermark_to(0)
                       .advance_processing_time(5)
                       .add_elements(['a', 'b', 'c'])
                       .advance_watermark_to(2)
                       .advance_processing_time(1)
                       .advance_watermark_to(4)
                       .advance_processing_time(1)
                       .advance_watermark_to(6)
                       .advance_processing_time(1)
                       .advance_watermark_to(8)
                       .advance_processing_time(1)
                       .advance_watermark_to(10)
                       .advance_processing_time(1)
                       .add_elements([TimestampedValue('1', 15),
                                      TimestampedValue('2', 15),
                                      TimestampedValue('3', 15)]))  # yapf: disable

        options = StandardOptions(streaming=True)
        p = TestPipeline(options=options)

        coder = beam.coders.FastPrimitivesCoder()
        records = (p
                   | test_stream
                   | ReverseTestStream(
                       sample_resolution_sec=1,
                       coder=coder,
                       output_format=OutputFormat.TEST_STREAM_FILE_RECORDS,
                       output_tag=None)
                   | 'stringify' >> beam.Map(str))

        assert_that(
            records,
            equal_to_per_window({
                beam.window.GlobalWindow(): [
                    str(TestStreamFileHeader()),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                processing_time_event=TestStreamPayload.Event.
                                AdvanceProcessingTime(
                                    advance_duration=5000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                watermark_event=TestStreamPayload.Event.
                                AdvanceWatermark(new_watermark=0)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                element_event=TestStreamPayload.Event.
                                AddElements(elements=[
                                    TestStreamPayload.TimestampedElement(
                                        encoded_element=coder.encode('a'),
                                        timestamp=0),
                                    TestStreamPayload.TimestampedElement(
                                        encoded_element=coder.encode('b'),
                                        timestamp=0),
                                    TestStreamPayload.TimestampedElement(
                                        encoded_element=coder.encode('c'),
                                        timestamp=0),
                                ])))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                watermark_event=TestStreamPayload.Event.
                                AdvanceWatermark(new_watermark=2000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                processing_time_event=TestStreamPayload.Event.
                                AdvanceProcessingTime(
                                    advance_duration=1000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                watermark_event=TestStreamPayload.Event.
                                AdvanceWatermark(new_watermark=4000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                processing_time_event=TestStreamPayload.Event.
                                AdvanceProcessingTime(
                                    advance_duration=1000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                watermark_event=TestStreamPayload.Event.
                                AdvanceWatermark(new_watermark=6000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                processing_time_event=TestStreamPayload.Event.
                                AdvanceProcessingTime(
                                    advance_duration=1000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                watermark_event=TestStreamPayload.Event.
                                AdvanceWatermark(new_watermark=8000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                processing_time_event=TestStreamPayload.Event.
                                AdvanceProcessingTime(
                                    advance_duration=1000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                watermark_event=TestStreamPayload.Event.
                                AdvanceWatermark(new_watermark=10000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                processing_time_event=TestStreamPayload.Event.
                                AdvanceProcessingTime(
                                    advance_duration=1000000)))),
                    str(
                        TestStreamFileRecord(
                            recorded_event=TestStreamPayload.Event(
                                element_event=TestStreamPayload.Event.
                                AddElements(elements=[
                                    TestStreamPayload.TimestampedElement(
                                        encoded_element=coder.encode('1'),
                                        timestamp=15000000),
                                    TestStreamPayload.TimestampedElement(
                                        encoded_element=coder.encode('2'),
                                        timestamp=15000000),
                                    TestStreamPayload.TimestampedElement(
                                        encoded_element=coder.encode('3'),
                                        timestamp=15000000),
                                ])))),
                ],
            }))

        p.run()
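
The encoded_element bytes in the expected records come from the same
FastPrimitivesCoder instance handed to ReverseTestStream, which is why the
expectation can embed coder.encode(...) calls directly. A quick round-trip
check of that coder:

import apache_beam as beam

coder = beam.coders.FastPrimitivesCoder()
# Encoding is deterministic for these primitives, so encode/decode
# round-trips the original value.
assert coder.decode(coder.encode('a')) == 'a'
assert coder.decode(coder.encode('1')) == '1'
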
    def test_streaming_wordcount(self):
        class WordExtractingDoFn(beam.DoFn):
            def process(self, element):
                text_line = element.strip()
                words = text_line.split()
                return words

        # Add the TestStream so that it can be cached.
        ib.options.capturable_sources.add(TestStream)

        p = beam.Pipeline(runner=interactive_runner.InteractiveRunner(),
                          options=StandardOptions(streaming=True))

        data = (
            p
            | TestStream()
                .advance_watermark_to(0)
                .advance_processing_time(1)
                .add_elements(['to', 'be', 'or', 'not', 'to', 'be'])
                .advance_watermark_to(20)
                .advance_processing_time(1)
                .add_elements(['that', 'is', 'the', 'question'])
            | beam.WindowInto(beam.window.FixedWindows(10))) # yapf: disable

        counts = (data
                  | 'split' >> beam.ParDo(WordExtractingDoFn())
                  | 'pair_with_one' >> beam.Map(lambda x: (x, 1))
                  | 'group' >> beam.GroupByKey()
                  | 'count' >> beam.Map(lambda wordones:
                                        (wordones[0], sum(wordones[1]))))

        # Watch the local scope for Interactive Beam so that referenced PCollections
        # will be cached.
        ib.watch(locals())

        # This is normally done in interactive_utils when a transform is
        # applied, but it requires an IPython environment, so we run it
        # manually here.
        ie.current_env().track_user_pipelines()

        # Create a fake limiter that cancels the background caching job (BCJ)
        # once the main job receives the expected number of results.
        class FakeLimiter:
            def __init__(self, p, pcoll):
                self.p = p
                self.pcoll = pcoll

            def is_triggered(self):
                result = ie.current_env().pipeline_result(self.p)
                if result:
                    try:
                        results = result.get(self.pcoll)
                    except ValueError:
                        return False
                    return len(results) >= 10
                return False

        # This sets the limiters to stop reading when the test receives 10 elements.
        ie.current_env().options.capture_control.set_limiters_for_test(
            [FakeLimiter(p, data)])

        # This tests that the data was correctly cached.
        pane_info = PaneInfo(True, True, PaneInfoTiming.UNKNOWN, 0, 0)
        expected_data_df = pd.DataFrame([
            ('to', 0, [IntervalWindow(0, 10)], pane_info),
            ('be', 0, [IntervalWindow(0, 10)], pane_info),
            ('or', 0, [IntervalWindow(0, 10)], pane_info),
            ('not', 0, [IntervalWindow(0, 10)], pane_info),
            ('to', 0, [IntervalWindow(0, 10)], pane_info),
            ('be', 0, [IntervalWindow(0, 10)], pane_info),
            ('that', 20000000, [IntervalWindow(20, 30)], pane_info),
            ('is', 20000000, [IntervalWindow(20, 30)], pane_info),
            ('the', 20000000, [IntervalWindow(20, 30)], pane_info),
            ('question', 20000000, [IntervalWindow(20, 30)], pane_info)
        ], columns=[0, 'event_time', 'windows', 'pane_info']) # yapf: disable

        data_df = ib.collect(data, include_window_info=True)
        pd.testing.assert_frame_equal(expected_data_df, data_df)

        # This tests that the windowing was applied correctly and that the
        # data was therefore aggregated correctly as well.
        pane_info = PaneInfo(True, False, PaneInfoTiming.ON_TIME, 0, 0)
        expected_counts_df = pd.DataFrame([
            ('be', 2, 9999999, [IntervalWindow(0, 10)], pane_info),
            ('not', 1, 9999999, [IntervalWindow(0, 10)], pane_info),
            ('or', 1, 9999999, [IntervalWindow(0, 10)], pane_info),
            ('to', 2, 9999999, [IntervalWindow(0, 10)], pane_info),
            ('is', 1, 29999999, [IntervalWindow(20, 30)], pane_info),
            ('question', 1, 29999999, [IntervalWindow(20, 30)], pane_info),
            ('that', 1, 29999999, [IntervalWindow(20, 30)], pane_info),
            ('the', 1, 29999999, [IntervalWindow(20, 30)], pane_info),
        ], columns=[0, 1, 'event_time', 'windows', 'pane_info']) # yapf: disable

        counts_df = ib.collect(counts, include_window_info=True)

        # GroupByKey provides no ordering guarantee, so we post-process the
        # DataFrame by sorting to make the equality check deterministic.
        sorted_counts_df = (counts_df
                            .sort_values(['event_time', 0], ascending=True)
                            .reset_index(drop=True)) # yapf: disable
        pd.testing.assert_frame_equal(expected_counts_df, sorted_counts_df)
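
The event_time values in the expected DataFrames (9999999 and 29999999) are
each window's maximum timestamp in microseconds, i.e. one microsecond before
the window's end. A small check, assuming the standard .micros accessor on
Beam timestamps:

from apache_beam.transforms.window import IntervalWindow

# A window covering [0, 10) seconds ends just before second 10, so its
# max timestamp is 10 * 10**6 - 1 microseconds.
assert IntervalWindow(0, 10).max_timestamp().micros == 9999999
assert IntervalWindow(20, 30).max_timestamp().micros == 29999999
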
    def test_triggering_frequency(self, is_streaming, with_auto_sharding):
        destination = 'project1:dataset1.table1'

        job_reference = bigquery_api.JobReference()
        job_reference.projectId = 'project1'
        job_reference.jobId = 'job_name1'
        result_job = bigquery_api.Job()
        result_job.jobReference = job_reference

        mock_job = mock.Mock()
        mock_job.status.state = 'DONE'
        mock_job.status.errorResult = None
        mock_job.jobReference = job_reference

        bq_client = mock.Mock()
        bq_client.jobs.Get.return_value = mock_job
        bq_client.jobs.Insert.return_value = result_job

        # Insert a fake clock to work with auto-sharding which needs a processing
        # time timer.
        class _FakeClock(object):
            def __init__(self, now=None):
                # Avoid a default evaluated once at class-definition time; fall
                # back to the current time only when no explicit clock is given.
                self._now = now if now is not None else time.time()

            def __call__(self):
                return self._now

        start_time = timestamp.Timestamp(0)
        bq_client.test_clock = _FakeClock(now=start_time)

        triggering_frequency = 20 if is_streaming else None
        transform = bqfl.BigQueryBatchFileLoads(
            destination,
            custom_gcs_temp_location=self._new_tempdir(),
            test_client=bq_client,
            validate=False,
            temp_file_format=bigquery_tools.FileFormat.JSON,
            is_streaming_pipeline=is_streaming,
            triggering_frequency=triggering_frequency,
            with_auto_sharding=with_auto_sharding)

        # Need to test this with the DirectRunner to avoid serializing mocks
        with TestPipeline(
                runner='BundleBasedDirectRunner',
                options=StandardOptions(streaming=is_streaming)) as p:
            if is_streaming:
                _SIZE = len(_ELEMENTS)
                first_batch = [
                    TimestampedValue(value, start_time + i + 1)
                    for i, value in enumerate(_ELEMENTS[:_SIZE // 2])
                ]
                second_batch = [
                    TimestampedValue(value, start_time + _SIZE // 2 + i + 1)
                    for i, value in enumerate(_ELEMENTS[_SIZE // 2:])
                ]
                # Advance processing time between batches of input elements to fire the
                # user triggers. Intentionally advance the processing time twice for the
                # auto-sharding case since we need to first fire the timer and then
                # fire the trigger.
                test_stream = (
                    TestStream()
                    .advance_watermark_to(start_time)
                    .add_elements(first_batch)
                    .advance_processing_time(30)
                    .advance_processing_time(30)
                    .add_elements(second_batch)
                    .advance_processing_time(30)
                    .advance_processing_time(30)
                    .advance_watermark_to_infinity())  # yapf: disable
                input = p | test_stream
            else:
                input = p | beam.Create(_ELEMENTS)
            outputs = input | transform

            dest_files = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS]
            dest_job = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS]

            files = dest_files | "GetFiles" >> beam.Map(lambda x: x[1][0])
            destinations = (
                dest_files
                | "GetDests" >>
                beam.Map(lambda x:
                         (bigquery_tools.get_hashable_destination(x[0]), x[1]))
                | "GetUniques" >> combiners.Count.PerKey()
                | "GetFinalDests" >> beam.Keys())
            jobs = dest_job | "GetJobs" >> beam.Map(lambda x: x[1])

            # Check that all files exist.
            _ = (files
                 | beam.Map(
                     lambda x: hamcrest_assert(os.path.exists(x), is_(True))))

            # Expect two load jobs to be generated in the streaming case due to
            # the triggering frequency. Grouping happens per trigger, so we
            # expect two entries in the output rather than one.
            file_count = files | combiners.Count.Globally().without_defaults()
            expected_file_count = [1, 1] if is_streaming else [1]
            expected_destinations = (
                [destination, destination] if is_streaming else [destination])
            expected_jobs = (
                [job_reference, job_reference] if is_streaming else [job_reference])
            assert_that(file_count,
                        equal_to(expected_file_count),
                        label='CountFiles')
            assert_that(destinations,
                        equal_to(expected_destinations),
                        label='CheckDestinations')
            assert_that(jobs, equal_to(expected_jobs), label='CheckJobs')
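
The assertions above encode a simple rule: in streaming mode every trigger
firing produces one file and one load job per destination, and the two element
batches (separated by processing-time advances past the 20-second triggering
frequency) fire twice. A sketch of that expectation logic; the helper
expected_per_destination is hypothetical, not part of the transform:

def expected_per_destination(is_streaming, n_firings=2):
    # One output per trigger firing in streaming mode, a single output in
    # batch mode.
    return [1] * n_firings if is_streaming else [1]

assert expected_per_destination(True) == [1, 1]
assert expected_per_destination(False) == [1]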