예제 #1
0
    def test_ptransform_override_type_hints(self):
        """Replacement transforms determine the downstream element type.

        An override returning an un-annotated ToStringParDo should leave the
        replaced output typed as ``typehints.Any``; one that declares
        ``with_output_types(str)`` should propagate ``str``.
        """

        class NoTypeHintOverride(PTransformOverride):
            # Swaps every DoubleParDo for an un-annotated ToStringParDo.
            def matches(self, applied_ptransform):
                return isinstance(applied_ptransform.transform, DoubleParDo)

            def get_replacement_transform_for_applied_ptransform(
                    self, applied_ptransform):
                return ToStringParDo()

        class WithTypeHintOverride(PTransformOverride):
            # Swaps every DoubleParDo for a ToStringParDo hinted int -> str.
            def matches(self, applied_ptransform):
                return isinstance(applied_ptransform.transform, DoubleParDo)

            def get_replacement_transform_for_applied_ptransform(
                    self, applied_ptransform):
                annotated = ToStringParDo().with_input_types(int)
                return annotated.with_output_types(str)

        cases = (
            (NoTypeHintOverride(), typehints.Any),
            (WithTypeHintOverride(), str),
        )
        for replacement, wanted_type in cases:
            # Each case builds its own pipeline.
            pipeline = TestPipeline()
            created = pipeline | beam.Create([1, 2, 3])
            operated = created | 'Operate' >> DoubleParDo()
            noop_output = operated | 'NoOp' >> beam.Map(lambda x: x)

            pipeline.replace_all([replacement])
            # The NoOp Map's input is the output of the replaced 'Operate'.
            replaced_output = noop_output.producer.inputs[0]
            self.assertEqual(replaced_output.element_type, wanted_type)
예제 #2
0
    def test_expand_with_other_options(self):
        """ReadFromPubSub keeps with_attributes and timestamp_attribute.

        After the transform overrides are applied, the replacement read must
        still carry the options that were given to ReadFromPubSub.
        """
        streaming_options = PipelineOptions([])
        streaming_options.view_as(StandardOptions).streaming = True
        pipeline = TestPipeline(options=streaming_options)

        read = ReadFromPubSub('projects/fakeprj/topics/a_topic',
                              None,
                              'a_label',
                              with_attributes=True,
                              timestamp_attribute='time')
        output = pipeline | read | beam.Map(lambda x: x)
        # Reading with attributes is expected to yield PubsubMessage elements.
        self.assertEqual(PubsubMessage, output.element_type)

        pipeline.replace_all(_get_transform_overrides(streaming_options))

        # The override replaces ReadFromPubSub's own output, so the
        # replacement read is reached through the input of the no-op Map.
        replacement = output.producer.inputs[0].producer.transform
        source = replacement._source
        self.assertTrue(source.with_attributes)
        self.assertEqual('time', source.timestamp_attribute)
예제 #3
0
  def test_ptransform_override_type_hints(self):
    """Replacement transforms determine the downstream element type.

    An override returning an un-annotated ToStringParDo must leave the
    replaced output typed as ``typehints.Any``; one that declares
    ``with_output_types(str)`` must propagate ``str``.
    """

    class NoTypeHintOverride(PTransformOverride):

      def matches(self, applied_ptransform):
        return isinstance(applied_ptransform.transform, DoubleParDo)

      def get_replacement_transform(self, ptransform):
        # No hints, so the replaced output is expected to be typed Any.
        return ToStringParDo()

    class WithTypeHintOverride(PTransformOverride):

      def matches(self, applied_ptransform):
        return isinstance(applied_ptransform.transform, DoubleParDo)

      def get_replacement_transform(self, ptransform):
        # Explicit int -> str hints should flow to the replaced output.
        return (ToStringParDo()
                .with_input_types(int)
                .with_output_types(str))

    for override, expected_type in [(NoTypeHintOverride(), typehints.Any),
                                    (WithTypeHintOverride(), str)]:
      # Each case builds its own pipeline.
      p = TestPipeline()
      pcoll = (p
               | beam.Create([1, 2, 3])
               | 'Operate' >> DoubleParDo()
               | 'NoOp' >> beam.Map(lambda x: x))

      p.replace_all([override])
      # The NoOp Map consumes the replaced 'Operate' output. assertEqual is
      # used because the assertEquals alias was removed in Python 3.12.
      self.assertEqual(pcoll.producer.inputs[0].element_type, expected_type)
예제 #4
0
    def test_expand_with_subscription(self):
        """A subscription-based ReadFromPubSub keeps its settings.

        Checks the bytes output type before overriding, then that the
        replacement read still carries the subscription and id_label.
        """
        pipeline = TestPipeline()
        pipeline.options.view_as(StandardOptions).streaming = True
        read = ReadFromPubSub(
            None,
            'projects/fakeprj/subscriptions/a_subscription',
            'a_label',
            with_attributes=False,
            timestamp_attribute=None)
        output = pipeline | read | beam.Map(lambda x: x)
        # Without attributes the read is expected to emit raw bytes.
        self.assertEqual(bytes, output.element_type)

        overrides = _get_transform_overrides(pipeline.options)
        pipeline.replace_all(overrides)

        # ReadFromPubSub's own output is swapped by the override. Step back
        # through the no-op Map's input to reach the replacement read.
        source = output.producer.inputs[0].producer.transform._source
        self.assertEqual('a_subscription', source.subscription_name)
        self.assertEqual('a_label', source.id_label)
예제 #5
0
  def test_expand(self):
    """WriteToPubSub's settings survive the PTransformOverrides.

    After the overrides from _get_transform_overrides are applied, the
    replacement write must still carry the topic and with_attributes flag.
    """
    options = PipelineOptions([])
    options.view_as(StandardOptions).streaming = True
    p = TestPipeline(options=options)
    pcoll = (
        p
        | ReadFromPubSub('projects/fakeprj/topics/baz')
        | WriteToPubSub(
            'projects/fakeprj/topics/a_topic', with_attributes=True)
        | beam.Map(lambda x: x))

    # Apply the necessary PTransformOverrides.
    overrides = _get_transform_overrides(options)
    p.replace_all(overrides)

    # Note that the direct output of WriteToPubSub will be replaced by a
    # PTransformOverride, so we use a no-op Map and step back through its
    # input to reach the replacement write transform.
    write_transform = pcoll.producer.inputs[0].producer.transform

    # Ensure that the properties passed through correctly
    self.assertEqual('a_topic', write_transform.dofn.short_topic_name)
    self.assertEqual(True, write_transform.dofn.with_attributes)
    # TODO(BEAM-4275): These properties aren't supported yet in direct runner.
    self.assertEqual(None, write_transform.dofn.id_label)
    self.assertEqual(None, write_transform.dofn.timestamp_attribute)
예제 #6
0
    def test_expand_with_multiple_sources(self):
        """MultipleReadFromPubSub expands into one read per descriptor.

        After overriding, the sources behind the individual reads must list
        both topics (from different projects) and the subscription, in order.
        """
        options = PipelineOptions([])
        options.view_as(StandardOptions).streaming = True
        pipeline = TestPipeline(options=options)
        topics = [
            'projects/fakeprj/topics/a_topic',
            'projects/fakeprj2/topics/b_topic'
        ]
        subscriptions = ['projects/fakeprj/subscriptions/a_subscription']

        descriptors = [PubSubSourceDescriptor(path)
                       for path in topics + subscriptions]
        output = (pipeline
                  | MultipleReadFromPubSub(descriptors)
                  | beam.Map(lambda x: x))

        pipeline.replace_all(_get_transform_overrides(options))

        self.assertEqual(bytes, output.element_type)

        # The step feeding the no-op Map takes one input per individual read;
        # bucket each read's source by whether it names a topic.
        found_topics = []
        found_subscriptions = []
        for read_output in output.producer.inputs[0].producer.inputs:
            source = read_output.producer.transform._source
            if not source.full_topic:
                found_subscriptions.append(source.full_subscription)
                continue
            found_topics.append(source.full_topic)
        self.assertEqual(found_topics, topics)
        self.assertEqual(found_subscriptions, subscriptions)
예제 #7
0
    def test_expand_deprecated(self):
        """WriteStringsToPubSub keeps its topic after the overrides.

        After the PTransformOverrides are applied, the replacement write must
        still target the topic that was given to WriteStringsToPubSub.
        """
        p = TestPipeline()
        p.options.view_as(StandardOptions).streaming = True
        pcoll = (p
                 | ReadFromPubSub('projects/fakeprj/topics/baz')
                 | WriteStringsToPubSub('projects/fakeprj/topics/a_topic')
                 | beam.Map(lambda x: x))

        # Apply the necessary PTransformOverrides.
        overrides = _get_transform_overrides(p.options)
        p.replace_all(overrides)

        # Note that the direct output of WriteStringsToPubSub will be replaced
        # by a PTransformOverride, so we use a no-op Map and step back through
        # its input to reach the replacement write transform.
        write_transform = pcoll.producer.inputs[0].producer.transform

        # Ensure that the properties passed through correctly
        self.assertEqual('a_topic', write_transform.dofn.short_topic_name)
예제 #8
0
  def test_expand_deprecated(self):
    """WriteStringsToPubSub keeps its topic after the overrides.

    After the PTransformOverrides are applied, the replacement write must
    still target the topic that was given to WriteStringsToPubSub.
    """
    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    pcoll = (p
             | ReadFromPubSub('projects/fakeprj/topics/baz')
             | WriteStringsToPubSub('projects/fakeprj/topics/a_topic')
             | beam.Map(lambda x: x))

    # Apply the necessary PTransformOverrides.
    overrides = _get_transform_overrides(p.options)
    p.replace_all(overrides)

    # Note that the direct output of WriteStringsToPubSub will be replaced by
    # a PTransformOverride, so we use a no-op Map and step back through its
    # input to reach the replacement write transform.
    write_transform = pcoll.producer.inputs[0].producer.transform

    # Ensure that the properties passed through correctly
    self.assertEqual('a_topic', write_transform.dofn.short_topic_name)
예제 #9
0
    def test_expand_with_multiple_sources_and_other_options(self):
        """Per-source id_label and timestamp_attribute survive the override.

        Each PubSubSourceDescriptor carries its own id_label and
        timestamp_attribute. After overriding, the read for each source must
        carry exactly those values, with with_attributes left False.
        """
        options = PipelineOptions([])
        options.view_as(StandardOptions).streaming = True
        p = TestPipeline(options=options)
        sources = [
            'projects/fakeprj/topics/a_topic',
            'projects/fakeprj2/topics/b_topic',
            'projects/fakeprj/subscriptions/a_subscription'
        ]
        id_labels = ['a_label_topic', 'b_label_topic', 'a_label_subscription']
        timestamp_attributes = [
            'a_ta_topic', 'b_ta_topic', 'a_ta_subscription'
        ]

        pubsub_sources = [
            PubSubSourceDescriptor(source=source,
                                   id_label=id_label,
                                   timestamp_attribute=timestamp_attribute)
            for source, id_label, timestamp_attribute in zip(
                sources, id_labels, timestamp_attributes)
        ]

        pcoll = (p | MultipleReadFromPubSub(pubsub_sources)
                 | beam.Map(lambda x: x))

        # Apply the necessary PTransformOverrides.
        overrides = _get_transform_overrides(options)
        p.replace_all(overrides)

        self.assertEqual(bytes, pcoll.element_type)

        # Ensure that the sources are passed through correctly. The step
        # feeding the no-op Map takes one input per individual read; the test
        # expects them in the same order as the descriptors.
        read_transforms = pcoll.producer.inputs[0].producer.inputs
        # Without this check a missing read would make the loop below pass
        # vacuously for the sources it never visits.
        self.assertEqual(len(sources), len(read_transforms))
        for read_transform, id_label, timestamp_attribute in zip(
                read_transforms, id_labels, timestamp_attributes):
            source = read_transform.producer.transform._source
            self.assertEqual(source.id_label, id_label)
            self.assertEqual(source.with_attributes, False)
            self.assertEqual(source.timestamp_attribute, timestamp_attribute)
예제 #10
0
  def test_expand_with_other_options(self):
    """ReadFromPubSub keeps with_attributes and timestamp_attribute.

    After the overrides are applied, the replacement read must still carry
    the options that were given to ReadFromPubSub.
    """
    pipeline = TestPipeline()
    pipeline.options.view_as(StandardOptions).streaming = True
    read = ReadFromPubSub('projects/fakeprj/topics/a_topic', None, 'a_label',
                          with_attributes=True, timestamp_attribute='time')
    output = pipeline | read | beam.Map(lambda x: x)
    # Reading with attributes is expected to yield PubsubMessage elements.
    self.assertEqual(PubsubMessage, output.element_type)

    pipeline.replace_all(_get_transform_overrides(pipeline.options))

    # The override swaps out ReadFromPubSub's own output, so the replacement
    # read is reached through the input of the trailing no-op Map.
    source = output.producer.inputs[0].producer.transform._source
    self.assertTrue(source.with_attributes)
    self.assertEqual('time', source.timestamp_attribute)
예제 #11
0
  def test_expand_with_subscription(self):
    """A subscription-based ReadFromPubSub keeps subscription and id_label.

    The bytes output type is checked before overriding; the replacement
    read's source is checked afterwards.
    """
    pipeline = TestPipeline()
    pipeline.options.view_as(StandardOptions).streaming = True
    read = ReadFromPubSub(
        None, 'projects/fakeprj/subscriptions/a_subscription',
        'a_label', with_attributes=False, timestamp_attribute=None)
    output = pipeline | read | beam.Map(lambda x: x)
    # Without attributes the read is expected to emit raw bytes.
    self.assertEqual(bytes, output.element_type)

    pipeline.replace_all(_get_transform_overrides(pipeline.options))

    # ReadFromPubSub's own output is swapped by the override; reach the
    # replacement read through the input of the trailing no-op Map.
    replacement = output.producer.inputs[0].producer.transform
    source = replacement._source
    self.assertEqual('a_subscription', source.subscription_name)
    self.assertEqual('a_label', source.id_label)
예제 #12
0
    def test_expand_with_topic(self):
        """ReadStringsFromPubSub on a topic keeps topic_name and id_label.

        The element type is checked before overriding; the replacement read's
        source is checked afterwards.
        """
        pipeline = TestPipeline()
        pipeline.options.view_as(StandardOptions).streaming = True
        read = ReadStringsFromPubSub('projects/fakeprj/topics/a_topic',
                                     None, 'a_label')
        output = pipeline | read | beam.Map(lambda x: x)
        # String reads are expected to yield text elements, spelled
        # ``unicode`` here. NOTE(review): ``unicode`` is not a Python 3
        # builtin; confirm the module brings it into scope.
        self.assertEqual(unicode, output.element_type)

        pipeline.replace_all(_get_transform_overrides(pipeline.options))

        # ReadStringsFromPubSub's own output is replaced by the override, so
        # the replacement read is reached via the no-op Map's input.
        source = output.producer.inputs[0].producer.transform._source
        self.assertEqual('a_topic', source.topic_name)
        self.assertEqual('a_label', source.id_label)
예제 #13
0
  def test_expand(self):
    """WriteToPubSub's settings survive the PTransformOverrides.

    After the overrides from _get_transform_overrides are applied, the
    replacement write must still carry the topic and with_attributes flag.
    """
    p = TestPipeline()
    p.options.view_as(StandardOptions).streaming = True
    pcoll = (p
             | ReadFromPubSub('projects/fakeprj/topics/baz')
             | WriteToPubSub('projects/fakeprj/topics/a_topic',
                             with_attributes=True)
             | beam.Map(lambda x: x))

    # Apply the necessary PTransformOverrides.
    overrides = _get_transform_overrides(p.options)
    p.replace_all(overrides)

    # Note that the direct output of WriteToPubSub will be replaced by a
    # PTransformOverride, so we use a no-op Map and step back through its
    # input to reach the replacement write transform.
    write_transform = pcoll.producer.inputs[0].producer.transform

    # Ensure that the properties passed through correctly
    self.assertEqual('a_topic', write_transform.dofn.short_topic_name)
    self.assertEqual(True, write_transform.dofn.with_attributes)
    # TODO(BEAM-4275): These properties aren't supported yet in direct runner.
    self.assertEqual(None, write_transform.dofn.id_label)
    self.assertEqual(None, write_transform.dofn.timestamp_attribute)
예제 #14
0
    def test_ptransform_override_multiple_outputs(self):
        """A multi-output ParDo can be replaced by a multi-output composite.

        The 'MyMultiOutput' ParDo is overridden with MultiOutputComposite.
        The test checks two things:

        - elements flow through the replacement (via assert_that);
        - the visited graph contains the replacement's output PCollections
          but none of the replaced ParDo's outputs.
        """
        class MultiOutputComposite(PTransform):
            # Doubles each element and tags ints as 'numbers' and everything
            # else as 'letters'. Letters are then multiplied by 3 and numbers
            # by 5 in separate Maps.
            def __init__(self):
                # Run PTransform's own initializer instead of skipping it.
                super(MultiOutputComposite, self).__init__()
                self.output_tags = set()

            def expand(self, pcoll):
                def mux_input(x):
                    x = x * 2
                    if isinstance(x, int):
                        yield TaggedOutput('numbers', x)
                    else:
                        yield TaggedOutput('letters', x)

                multi = pcoll | 'MyReplacement' >> beam.ParDo(
                    mux_input).with_outputs()
                letters = multi.letters | 'LettersComposite' >> beam.Map(
                    lambda x: x * 3)
                numbers = multi.numbers | 'NumbersComposite' >> beam.Map(
                    lambda x: x * 5)

                return {
                    'letters': letters,
                    'numbers': numbers,
                }

        class MultiOutputOverride(PTransformOverride):
            # Targets the original ParDo by its full label.
            def matches(self, applied_ptransform):
                return applied_ptransform.full_label == 'MyMultiOutput'

            def get_replacement_transform(self, ptransform):
                return MultiOutputComposite()

        def mux_input(x):
            if isinstance(x, int):
                yield TaggedOutput('numbers', x)
            else:
                yield TaggedOutput('letters', x)

        p = TestPipeline()
        multi = (p
                 | beam.Create([1, 2, 3, 'a', 'b', 'c'])
                 | 'MyMultiOutput' >> beam.ParDo(mux_input).with_outputs())
        letters = multi.letters | 'MyLetters' >> beam.Map(lambda x: x)
        numbers = multi.numbers | 'MyNumbers' >> beam.Map(lambda x: x)

        # Assert that the PCollection replacement worked correctly and that
        # elements are flowing through. The replacement transform first
        # multiplies by 2, then the leaf nodes inside the composite multiply
        # by an additional 3 and 5. Use prime numbers to ensure that each
        # transform is getting executed once.
        assert_that(letters,
                    equal_to(['a' * 2 * 3, 'b' * 2 * 3, 'c' * 2 * 3]),
                    label='assert letters')
        assert_that(numbers,
                    equal_to([1 * 2 * 5, 2 * 2 * 5, 3 * 2 * 5]),
                    label='assert numbers')

        # Do the replacement and run the element assertions.
        p.replace_all([MultiOutputOverride()])
        p.run()

        # The following checks the graph to make sure the replacement occurred.
        visitor = PipelineTest.Visitor(visited=[])
        p.visit(visitor)
        pcollections = visitor.visited
        composites = visitor.enter_composite

        # Assert the replacement is in the composite list and retrieve the
        # AppliedPTransform.
        self.assertIn(MultiOutputComposite,
                      [t.transform.__class__ for t in composites])
        multi_output_composite = next(
            t for t in composites
            if t.transform.__class__ == MultiOutputComposite)

        # Assert that all of the replacement PCollections are in the graph.
        for output in multi_output_composite.outputs.values():
            self.assertIn(output, pcollections)

        # Assert that all of the "old"/replaced PCollections are not in the
        # graph.
        self.assertNotIn(multi[None], pcollections)
        self.assertNotIn(multi.letters, pcollections)
        self.assertNotIn(multi.numbers, pcollections)