Example #1
    def test_ignore_pcoll_from_other_pipeline(self):
        p = beam.Pipeline()
        p2 = beam.Pipeline()
        cacheable_from_p2 = p2 | beam.Create([1, 2, 3])
        ib.watch(locals())

        aug_p = ap.AugmentedPipeline(p)
        cacheables = aug_p.cacheables()
        # Only PCollections that belong to p are tracked; the PCollection from
        # p2 must not show up among p's cacheables.
        self.assertNotIn(cacheable_from_p2, cacheables)
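These snippets are test methods taken out of their surrounding test class, so the imports and aliases they rely on are not shown. In each example, ib.watch(locals()) registers the locally defined pipelines and PCollections with the interactive environment so that AugmentedPipeline can discover them as cacheables. A minimal sketch of the assumed setup follows; the module paths are assumptions based on the Apache Beam interactive runner layout and may differ between Beam versions, and helpers such as InMemoryCache and assert_pipeline_proto_equal come from Beam's interactive testing utilities (their import paths are omitted here):

import unittest
from unittest.mock import patch

import apache_beam as beam
# The aliases below match how the examples refer to these modules; the exact
# import paths are assumptions and may vary by Beam version.
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive import augmented_pipeline as ap
from apache_beam.runners.interactive.caching import read_cache
from apache_beam.runners.interactive.caching import write_cache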
Example #2
    def test_error_when_pcolls_from_mixed_pipelines(self):
        p = beam.Pipeline()
        cacheable_from_p = p | beam.Create([1, 2, 3])
        p2 = beam.Pipeline()
        cacheable_from_p2 = p2 | beam.Create([1, 2, 3])
        ib.watch(locals())

        # Mixing PCollections from different pipelines should trip the
        # construction-time assertion.
        self.assertRaises(
            AssertionError, lambda: ap.AugmentedPipeline(
                p, (cacheable_from_p, cacheable_from_p2)))
Example #3
    def test_ignore_cacheables(self):
        p = beam.Pipeline()
        cacheable_pcoll_1 = p | 'cacheable_pcoll_1' >> beam.Create([1, 2, 3])
        cacheable_pcoll_2 = p | 'cacheable_pcoll_2' >> beam.Create([4, 5, 6])
        ib.watch(locals())

        # Passing an explicit tuple of PCollections restricts the cacheables
        # to just those PCollections.
        aug_p = ap.AugmentedPipeline(p, (cacheable_pcoll_1, ))
        cacheables = aug_p.cacheables()
        self.assertIn(cacheable_pcoll_1, cacheables)
        self.assertNotIn(cacheable_pcoll_2, cacheables)
Example #4
    def test_find_all_cacheables(self):
        p = beam.Pipeline()
        cacheable_pcoll_1 = p | beam.Create([1, 2, 3])
        cacheable_pcoll_2 = cacheable_pcoll_1 | beam.Map(lambda x: x * x)
        ib.watch(locals())

        # Without an explicit subset, every watched PCollection in p, including
        # derived ones, is treated as a cacheable.
        aug_p = ap.AugmentedPipeline(p)
        cacheables = aug_p.cacheables()
        self.assertIn(cacheable_pcoll_1, cacheables)
        self.assertIn(cacheable_pcoll_2, cacheables)
Example #5
    def test_read_cache(self, mocked_get_cache_manager):
        # mocked_get_cache_manager is presumably injected by a @patch decorator
        # that is not shown in this excerpt.
        p = beam.Pipeline()
        pcoll = p | beam.Create([1, 2, 3])
        consumer_transform = beam.Map(lambda x: x * x)
        _ = pcoll | consumer_transform
        ib.watch(locals())

        # Create the cache in memory.
        cache_manager = InMemoryCache()
        mocked_get_cache_manager.return_value = cache_manager
        aug_p = ap.AugmentedPipeline(p)
        key = repr(aug_p._cacheables[pcoll].to_key())
        cache_manager.write('test', 'full', key)

        # Capture the applied transform of the consumer_transform.
        pcoll_id = aug_p._context.pcollections.get_id(pcoll)
        consumer_transform_id = None
        pipeline_proto = p.to_runner_api()
        for (transform_id,
             transform) in pipeline_proto.components.transforms.items():
            if pcoll_id in transform.inputs.values():
                consumer_transform_id = transform_id
                break
        self.assertIsNotNone(consumer_transform_id)

        # Read cache on the pipeline proto.
        _, cache_id = read_cache.ReadCache(
            pipeline_proto, aug_p._context, aug_p._cache_manager,
            aug_p._cacheables[pcoll]).read_cache()
        actual_pipeline = pipeline_proto

        # Read cache directly on the pipeline instance.
        label = '{}{}'.format('_cache_', key)
        transform = read_cache._ReadCacheTransform(aug_p._cache_manager, key,
                                                   label)
        p | 'source' + label >> transform
        expected_pipeline = p.to_runner_api()

        # This roughly checks the equivalence between the two protos, not the
        # detailed wiring of sub-transforms under the top-level transforms.
        assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

        # Check that actual_pipeline uses the PCollection read from the cache
        # as the input of consumer_transform instead of the original pcoll
        # from the source.
        inputs = actual_pipeline.components.transforms[
            consumer_transform_id].inputs
        self.assertIn(cache_id, inputs.values())
        self.assertNotIn(pcoll_id, inputs.values())

    def test_write_cache(self, mocked_get_cache_manager):
        # As above, mocked_get_cache_manager is presumably injected by a @patch
        # decorator that is not shown in this excerpt.
        p = beam.Pipeline()
        pcoll = p | beam.Create([1, 2, 3])
        ib.watch(locals())

        cache_manager = InMemoryCache()
        mocked_get_cache_manager.return_value = cache_manager
        aug_p = ap.AugmentedPipeline(p)
        key = repr(aug_p._cacheables[pcoll].to_key())
        pipeline_proto = p.to_runner_api()

        # Write cache on the pipeline proto.
        write_cache.WriteCache(pipeline_proto, aug_p._context,
                               aug_p._cache_manager,
                               aug_p._cacheables[pcoll]).write_cache()
        actual_pipeline = pipeline_proto

        # Write cache directly on the pipeline instance.
        label = '{}{}'.format('_cache_', key)
        transform = write_cache._WriteCacheTransform(aug_p._cache_manager, key,
                                                     label)
        _ = pcoll | 'sink' + label >> transform
        expected_pipeline = p.to_runner_api()

        assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

        # Check if the actual_pipeline uses pcoll as an input of a write transform.
        pcoll_id = aug_p._context.pcollections.get_id(pcoll)
        write_transform_id = None
        for (transform_id,
             transform) in actual_pipeline.components.transforms.items():
            if pcoll_id in transform.inputs.values():
                write_transform_id = transform_id
                break
        self.assertIsNotNone(write_transform_id)
        self.assertIn(
            'sink', actual_pipeline.components.transforms[write_transform_id].
            unique_name)
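The two cache tests in Example #5 receive mocked_get_cache_manager as an argument, which is the usual signature of a test method wrapped by a unittest.mock patch decorator that injects a mock for the cache-manager lookup. A minimal sketch of that scaffolding is below; the class name is hypothetical and the patch target path is an assumption, not necessarily the exact location used in Beam:

import unittest
from unittest.mock import patch


class CacheTransformTest(unittest.TestCase):  # hypothetical class name
    # The patch target below is an assumption; point it at wherever the code
    # under test looks up its cache manager in your Beam version.
    @patch(
        'apache_beam.runners.interactive.interactive_environment.'
        'InteractiveEnvironment.get_cache_manager')
    def test_read_cache(self, mocked_get_cache_manager):
        ...  # body as shown in Example #5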