def test_ignore_pcoll_from_other_pipeline(self):
  p = beam.Pipeline()
  p2 = beam.Pipeline()
  cacheable_from_p2 = p2 | beam.Create([1, 2, 3])
  ib.watch(locals())

  aug_p = ap.AugmentedPipeline(p)
  cacheables = aug_p.cacheables()
  self.assertNotIn(cacheable_from_p2, cacheables)

def test_error_when_pcolls_from_mixed_pipelines(self):
  p = beam.Pipeline()
  cacheable_from_p = p | beam.Create([1, 2, 3])
  p2 = beam.Pipeline()
  cacheable_from_p2 = p2 | beam.Create([1, 2, 3])
  ib.watch(locals())

  self.assertRaises(
      AssertionError,
      lambda: ap.AugmentedPipeline(p, (cacheable_from_p, cacheable_from_p2)))

def test_ignore_cacheables(self):
  p = beam.Pipeline()
  cacheable_pcoll_1 = p | 'cacheable_pcoll_1' >> beam.Create([1, 2, 3])
  cacheable_pcoll_2 = p | 'cacheable_pcoll_2' >> beam.Create([4, 5, 6])
  ib.watch(locals())

  aug_p = ap.AugmentedPipeline(p, (cacheable_pcoll_1, ))
  cacheables = aug_p.cacheables()
  self.assertIn(cacheable_pcoll_1, cacheables)
  self.assertNotIn(cacheable_pcoll_2, cacheables)

def test_find_all_cacheables(self):
  p = beam.Pipeline()
  cacheable_pcoll_1 = p | beam.Create([1, 2, 3])
  cacheable_pcoll_2 = cacheable_pcoll_1 | beam.Map(lambda x: x * x)
  ib.watch(locals())

  aug_p = ap.AugmentedPipeline(p)
  cacheables = aug_p.cacheables()
  self.assertIn(cacheable_pcoll_1, cacheables)
  self.assertIn(cacheable_pcoll_2, cacheables)

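# NOTE: the two cache tests below take ``mocked_get_cache_manager`` as a mock
# argument; they assume the interactive environment's ``get_cache_manager``
# lookup is patched (e.g. via a ``unittest.mock.patch`` decorator on each
# test) so that the test can substitute the ``InMemoryCache`` it builds. The
# exact patch target is an assumption here, not shown in these tests.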
def test_read_cache(self, mocked_get_cache_manager):
  p = beam.Pipeline()
  pcoll = p | beam.Create([1, 2, 3])
  consumer_transform = beam.Map(lambda x: x * x)
  _ = pcoll | consumer_transform
  ib.watch(locals())

  # Create the cache in memory.
  cache_manager = InMemoryCache()
  mocked_get_cache_manager.return_value = cache_manager
  aug_p = ap.AugmentedPipeline(p)
  key = repr(aug_p._cacheables[pcoll].to_key())
  cache_manager.write('test', 'full', key)

  # Capture the applied transform of the consumer_transform.
  pcoll_id = aug_p._context.pcollections.get_id(pcoll)
  consumer_transform_id = None
  pipeline_proto = p.to_runner_api()
  for (transform_id,
       transform) in pipeline_proto.components.transforms.items():
    if pcoll_id in transform.inputs.values():
      consumer_transform_id = transform_id
      break
  self.assertIsNotNone(consumer_transform_id)

  # Read cache on the pipeline proto.
  _, cache_id = read_cache.ReadCache(
      pipeline_proto,
      aug_p._context,
      aug_p._cache_manager,
      aug_p._cacheables[pcoll]).read_cache()
  actual_pipeline = pipeline_proto

  # Read cache directly on the pipeline instance.
  label = '{}{}'.format('_cache_', key)
  transform = read_cache._ReadCacheTransform(
      aug_p._cache_manager, key, label)
  p | 'source' + label >> transform
  expected_pipeline = p.to_runner_api()

  # This roughly checks the equivalence between the two protos, not the
  # detailed wiring of sub-transforms under top-level transforms.
  assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

  # Check that the actual_pipeline uses the cache as the input of the
  # consumer_transform instead of the original pcoll from the source.
  inputs = actual_pipeline.components.transforms[
      consumer_transform_id].inputs
  self.assertIn(cache_id, inputs.values())
  self.assertNotIn(pcoll_id, inputs.values())

def test_write_cache(self, mocked_get_cache_manager):
  p = beam.Pipeline()
  pcoll = p | beam.Create([1, 2, 3])
  ib.watch(locals())

  cache_manager = InMemoryCache()
  mocked_get_cache_manager.return_value = cache_manager
  aug_p = ap.AugmentedPipeline(p)
  key = repr(aug_p._cacheables[pcoll].to_key())
  pipeline_proto = p.to_runner_api()

  # Write cache on the pipeline proto.
  write_cache.WriteCache(
      pipeline_proto,
      aug_p._context,
      aug_p._cache_manager,
      aug_p._cacheables[pcoll]).write_cache()
  actual_pipeline = pipeline_proto

  # Write cache directly on the pipeline instance.
  label = '{}{}'.format('_cache_', key)
  transform = write_cache._WriteCacheTransform(
      aug_p._cache_manager, key, label)
  _ = pcoll | 'sink' + label >> transform
  expected_pipeline = p.to_runner_api()
  assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

  # Check that the actual_pipeline uses the pcoll as an input of a write
  # transform.
  pcoll_id = aug_p._context.pcollections.get_id(pcoll)
  write_transform_id = None
  for (transform_id,
       transform) in actual_pipeline.components.transforms.items():
    if pcoll_id in transform.inputs.values():
      write_transform_id = transform_id
      break
  self.assertIsNotNone(write_transform_id)
  self.assertIn(
      'sink',
      actual_pipeline.components.transforms[write_transform_id].unique_name)

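# For reference, the ``InMemoryCache`` used above only needs a tiny cache
# manager surface for these tests: a ``write(value, *labels)`` call to seed a
# cache entry plus a lookup keyed by the same label tuple. The sketch below is
# a hypothetical minimal stand-in illustrating that surface; it is not the
# real implementation shipped with the interactive runner's testing utilities.
class _MinimalInMemoryCache:
  """Hypothetical cache manager double keyed by its label tuple."""
  def __init__(self):
    self._cached = {}

  def write(self, value, *labels):
    # Append the value under the composite label, creating the entry lazily.
    self._cached.setdefault(labels, []).append(value)

  def exists(self, *labels):
    # True if anything has been written under the composite label.
    return labels in self._cached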