def test_cacheable_key_with_version_map(self):
        """Check that a supplied version map decides the cacheable key prefix.

        Two pipelines are built from the same transforms. The key is requested
        for the PCollection of the second pipeline, together with a version map
        that maps its PCollection id to the id() of the first pipeline's
        PCollection.
        """
        original = beam.Pipeline(interactive_runner.InteractiveRunner())
        ie.current_env().set_cache_manager(InMemoryCache(), original)
        # pylint: disable=range-builtin-not-iterating
        original_pcoll = original | 'Init Create' >> beam.Create(range(10))
        original_version = str(id(original_pcoll))

        # A pipeline that gets executed is normally a different but equivalent
        # instance of the one the user built. The pipeline instrument has to
        # tell whether the user's instance changed in the interactive env while
        # it mutates the other instance for execution. The version map links
        # PCollections back to the user's instance so a changed evaluation can
        # be detected.
        derived = beam.Pipeline(interactive_runner.InteractiveRunner())
        ie.current_env().set_cache_manager(InMemoryCache(), derived)
        # pylint: disable=range-builtin-not-iterating
        derived_pcoll = derived | 'Init Create' >> beam.Create(range(10))
        _, context = derived.to_runner_api(return_context=True)

        # Even though derived_pcoll is passed in, the key prefix must be
        # id(original_pcoll) because the version map says so.
        version_map = {'ref_PCollection_PCollection_8': original_version}
        pcoll_ids = instr.pcolls_to_pcoll_id(derived, context)
        actual_key = instr.cacheable_key(derived_pcoll, pcoll_ids, version_map)
        expected_key = original_version + '_ref_PCollection_PCollection_8'
        self.assertEqual(actual_key, expected_key)
Exemple #2
0
    def setUp(self):
        """Create the cache, pipeline, PCollection and mocked pipeline result.

        The pipeline is registered with the current interactive environment
        together with the mock result and the in-memory cache.
        """
        cache = InMemoryCache()
        pipeline = beam.Pipeline()
        self.cache = cache
        self.p = pipeline
        self.pcoll = pipeline | beam.Create([])
        self.cache_key = str(CacheKey('pcoll', '', '', ''))

        # The MockPipelineResult lets a test decide which state the fake
        # pipeline run reports.
        self.mock_result = MockPipelineResult()

        env = ie.current_env()
        env.add_user_pipeline(pipeline)
        env.set_pipeline_result(pipeline, self.mock_result)
        env.set_cache_manager(cache, pipeline)
Exemple #3
0
    def test_read_cache(self, mocked_get_cache_manager):
        """Compare ReadCache on a pipeline proto with a hand-applied transform.

        Args:
            mocked_get_cache_manager: mock whose return value is set to the
                in-memory cache created by this test.
        """
        # NOTE: ib.watch(locals()) captures the local names defined before it;
        # keep `p`, `pcoll` and friends named as they are.
        p = beam.Pipeline()
        pcoll = p | beam.Create([1, 2, 3])
        consumer_transform = beam.Map(lambda x: x * x)
        _ = pcoll | consumer_transform
        ib.watch(locals())

        # Hand an in-memory cache to the augmented pipeline and seed it.
        in_memory_cache = InMemoryCache()
        mocked_get_cache_manager.return_value = in_memory_cache
        augmented = ap.AugmentedPipeline(p)
        key = repr(augmented._cacheables[pcoll].to_key())
        in_memory_cache.write('test', 'full', key)

        # Find the first applied transform that consumes pcoll.
        pcoll_id = augmented._context.pcollections.get_id(pcoll)
        pipeline_proto = p.to_runner_api()
        consumer_transform_id = next(
            (candidate_id for candidate_id, applied in
             pipeline_proto.components.transforms.items()
             if pcoll_id in applied.inputs.values()),
            None)
        self.assertIsNotNone(consumer_transform_id)

        # Apply the cache read to the pipeline proto (mutated in place).
        reader = read_cache.ReadCache(
            pipeline_proto, augmented._context, augmented._cache_manager,
            augmented._cacheables[pcoll])
        _, cache_id = reader.read_cache()
        actual_pipeline = pipeline_proto

        # Apply the equivalent read transform straight to the pipeline object.
        label = '{}{}'.format('_cache_', key)
        transform = read_cache._ReadCacheTransform(
            augmented._cache_manager, key, label)
        p | 'source' + label >> transform
        expected_pipeline = p.to_runner_api()

        # This roughly checks the equivalence of the two protos; the detailed
        # wiring of sub transforms under top level transforms is not compared.
        assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

        # The consumer must now read from the cache instead of from the
        # original pcoll produced by the source.
        consumer_inputs = actual_pipeline.components.transforms[
            consumer_transform_id].inputs
        self.assertIn(cache_id, consumer_inputs.values())
        self.assertNotIn(pcoll_id, consumer_inputs.values())
Exemple #4
0
 def test_not_has_unbounded_source(self):
     """Check that a pipeline reading a text file reports no unbounded source.

     A small temporary file serves as the bounded input. It is created with
     delete=False so that it still exists after the `with` block closes it,
     and it is removed explicitly once the assertion has run.
     """
     import os  # Local import: only needed to remove the temporary file.

     p = beam.Pipeline()
     ie.current_env().set_cache_manager(InMemoryCache(), p)
     with tempfile.NamedTemporaryFile(delete=False) as f:
         f.write(b'test')
     try:
         _ = p | 'ReadBoundedSource' >> beam.io.ReadFromText(f.name)
         self.assertFalse(utils.has_unbounded_sources(p))
     finally:
         # delete=False leaves the file on disk; clean it up so repeated test
         # runs do not accumulate temp files.
         os.remove(f.name)
Exemple #5
0
  def test_cacheables(self):
    """Check the cacheables reported for three watched PCollections.

    The expected mapping is keyed by the instrument's _cacheable_key(...) and
    holds one instr.Cacheable per PCollection defined in this test.
    """
    # NOTE: these locals are captured by ib.watch(locals()); presumably their
    # names must stay in sync with the `var` values asserted below.
    p = beam.Pipeline(interactive_runner.InteractiveRunner())
    ie.current_env().set_cache_manager(InMemoryCache(), p)
    # pylint: disable=range-builtin-not-iterating
    init_pcoll = p | 'Init Create' >> beam.Create(range(10))
    squares = init_pcoll | 'Square' >> beam.Map(lambda x: x * x)
    cubes = init_pcoll | 'Cube' >> beam.Map(lambda x: x**3)
    ib.watch(locals())

    pipeline_instrument = instr.build_pipeline_instrument(p)
    # Fetch the actual value before building the expectation.
    actual = pipeline_instrument.cacheables

    def expected_cacheable(var, pcoll, pcoll_id):
      # Expected Cacheable for one PCollection: versions are derived from the
      # id() of the PCollection and of its producer.
      return instr.Cacheable(
          var=var,
          version=str(id(pcoll)),
          pcoll_id=pcoll_id,
          producer_version=str(id(pcoll.producer)),
          pcoll=pcoll)

    key_of = pipeline_instrument._cacheable_key
    expected = {
        key_of(init_pcoll): expected_cacheable(
            'init_pcoll', init_pcoll, 'ref_PCollection_PCollection_8'),
        key_of(squares): expected_cacheable(
            'squares', squares, 'ref_PCollection_PCollection_9'),
        key_of(cubes): expected_cacheable(
            'cubes', cubes, 'ref_PCollection_PCollection_10'),
    }
    self.assertEqual(actual, expected)
    def test_cacheables(self):
        """Check the instrument's _cacheables for three watched PCollections.

        The expected mapping is keyed by pipeline_instrument.pcoll_id(...) and
        holds one Cacheable per PCollection defined in this test.
        """
        # NOTE: these locals are captured by ib.watch(locals()); presumably
        # their names must stay in sync with the `var` values asserted below.
        p_cacheables = beam.Pipeline(interactive_runner.InteractiveRunner())
        ie.current_env().set_cache_manager(InMemoryCache(), p_cacheables)
        # pylint: disable=bad-option-value
        init_pcoll = p_cacheables | 'Init Create' >> beam.Create(range(10))
        squares = init_pcoll | 'Square' >> beam.Map(lambda x: x * x)
        cubes = init_pcoll | 'Cube' >> beam.Map(lambda x: x**3)
        ib.watch(locals())

        pipeline_instrument = instr.build_pipeline_instrument(p_cacheables)
        # Fetch the actual value before building the expectation.
        actual = pipeline_instrument._cacheables

        def expected_cacheable(var, pcoll):
            # Expected Cacheable for one PCollection: versions are derived
            # from the id() of the PCollection and of its producer.
            return Cacheable(var=var,
                             version=str(id(pcoll)),
                             producer_version=str(id(pcoll.producer)),
                             pcoll=pcoll)

        id_of = pipeline_instrument.pcoll_id
        expected = {
            id_of(init_pcoll): expected_cacheable('init_pcoll', init_pcoll),
            id_of(squares): expected_cacheable('squares', squares),
            id_of(cubes): expected_cacheable('cubes', cubes),
        }
        self.assertEqual(actual, expected)
 def test_pcoll_id_with_user_pipeline(self):
     """Check the pcoll id the instrument reports for a user pipeline."""
     user_pipeline = beam.Pipeline(interactive_runner.InteractiveRunner())
     ie.current_env().set_cache_manager(InMemoryCache(), user_pipeline)
     created = user_pipeline | 'Init Create' >> beam.Create([1, 2, 3])

     instrument = instr.build_pipeline_instrument(user_pipeline)
     actual_id = instrument.pcoll_id(created)
     self.assertEqual(actual_id, 'ref_PCollection_PCollection_8')
 def test_pcoll_to_pcoll_id(self):
     """Check the str(PCollection) -> pcoll id mapping for one Impulse."""
     pipeline = beam.Pipeline(interactive_runner.InteractiveRunner())
     ie.current_env().set_cache_manager(InMemoryCache(), pipeline)
     # pylint: disable=bad-option-value
     impulse = pipeline | 'Init Create' >> beam.Impulse()
     _, context = pipeline.to_runner_api(return_context=True)

     actual = instr.pcoll_to_pcoll_id(pipeline, context)
     expected = {str(impulse): 'ref_PCollection_PCollection_1'}
     self.assertEqual(actual, expected)
Exemple #9
0
 def test_pcolls_to_pcoll_id(self):
     """Check the str(PCollection) -> pcoll id mapping for one Impulse.

     The runner API context is produced with use_fake_coders=True.
     """
     pipeline = beam.Pipeline(interactive_runner.InteractiveRunner())
     ie.current_env().set_cache_manager(InMemoryCache(), pipeline)
     # pylint: disable=range-builtin-not-iterating
     impulse = pipeline | 'Init Create' >> beam.Impulse()
     _, context = pipeline.to_runner_api(
         use_fake_coders=True, return_context=True)

     actual = instr.pcolls_to_pcoll_id(pipeline, context)
     expected = {str(impulse): 'ref_PCollection_PCollection_1'}
     self.assertEqual(actual, expected)
 def test_cacheable_key_without_version_map(self):
     """Check the cacheable key when no version map is passed.

     The expected key is the id() of the PCollection itself followed by its
     PCollection id.
     """
     pipeline = beam.Pipeline(interactive_runner.InteractiveRunner())
     ie.current_env().set_cache_manager(InMemoryCache(), pipeline)
     # pylint: disable=range-builtin-not-iterating
     created = pipeline | 'Init Create' >> beam.Create(range(10))
     _, context = pipeline.to_runner_api(return_context=True)

     pcoll_ids = instr.pcolls_to_pcoll_id(pipeline, context)
     actual_key = instr.cacheable_key(created, pcoll_ids)
     expected_key = str(id(created)) + '_ref_PCollection_PCollection_8'
     self.assertEqual(actual_key, expected_key)
 def test_side_effect_pcoll_is_included(self):
     """Check that an unassigned transform result yields extended targets."""
     side_effect_pipeline = beam.Pipeline(
         interactive_runner.InteractiveRunner())
     ie.current_env().set_cache_manager(InMemoryCache(), side_effect_pipeline)

     # The result is deliberately not bound to a variable, which makes this a
     # "side effect" transform. Nothing from this locally defined pipeline is
     # ever watched either.
     # pylint: disable=bad-option-value,expression-not-assigned
     side_effect_pipeline | 'Init Create' >> beam.Create(range(10))

     instrument = instr.build_pipeline_instrument(side_effect_pipeline)
     self.assertTrue(instrument._extended_targets)
    def _example_pipeline(self, watch=True, bounded=True):
        """Build a two-step interactive pipeline for use as a test fixture.

        Args:
            watch: when true, ib.watch(locals()) is called so the locals of
                this method are watched.
            bounded: when true the source is beam.Create(range(10)); otherwise
                it is a beam.io.ReadFromPubSub on a fake subscription.

        Returns:
            A tuple (pipeline, first PCollection, second PCollection).
        """
        # NOTE: local names are kept stable because locals() may be watched.
        p_example = beam.Pipeline(interactive_runner.InteractiveRunner())
        ie.current_env().set_cache_manager(InMemoryCache(), p_example)
        # pylint: disable=bad-option-value
        source = (
            beam.Create(range(10)) if bounded else beam.io.ReadFromPubSub(
                subscription='projects/fake-project/subscriptions/fake_sub'))

        init_pcoll = p_example | 'Init Source' >> source
        second_pcoll = init_pcoll | 'Second' >> beam.Map(lambda x: x * x)
        if watch:
            ib.watch(locals())
        return p_example, init_pcoll, second_pcoll
    def test_cache_key(self):
        """Check the instrument's cache_key() for each watched PCollection.

        For init_pcoll, squares and cubes, the key returned by the pipeline
        instrument must equal self.cache_key_of(<variable name>, <PCollection>).
        """
        p = beam.Pipeline(interactive_runner.InteractiveRunner())
        ie.current_env().set_cache_manager(InMemoryCache(), p)
        # pylint: disable=bad-option-value
        init_pcoll = p | 'Init Create' >> beam.Create(range(10))
        squares = init_pcoll | 'Square' >> beam.Map(lambda x: x * x)
        cubes = init_pcoll | 'Cube' >> beam.Map(lambda x: x**3)
        # Watch the local variables, i.e., the Beam pipeline defined.
        # NOTE(review): the variable names passed to cache_key_of below match
        # these local names, so renaming a local presumably breaks the test.
        ib.watch(locals())

        pipeline_instrument = instr.build_pipeline_instrument(p)
        # One assertion per PCollection, in definition order.
        self.assertEqual(pipeline_instrument.cache_key(init_pcoll),
                         self.cache_key_of('init_pcoll', init_pcoll))
        self.assertEqual(pipeline_instrument.cache_key(squares),
                         self.cache_key_of('squares', squares))
        self.assertEqual(pipeline_instrument.cache_key(cubes),
                         self.cache_key_of('cubes', cubes))
Exemple #14
0
  def test_cacheables(self):
    """Check the cacheables that belong to this test's own pipeline.

    Only cacheables whose PCollection is attached to pipeline `p` are compared
    against the expected instr.Cacheable entries.
    """
    # NOTE: these locals are captured by ib.watch(locals()); presumably their
    # names must stay in sync with the `var` values asserted below.
    p = beam.Pipeline(interactive_runner.InteractiveRunner())
    ie.current_env().set_cache_manager(InMemoryCache(), p)
    # pylint: disable=range-builtin-not-iterating
    init_pcoll = p | 'Init Create' >> beam.Create(range(10))
    squares = init_pcoll | 'Square' >> beam.Map(lambda x: x * x)
    cubes = init_pcoll | 'Cube' >> beam.Map(lambda x: x**3)
    ib.watch(locals())

    pipeline_instrument = instr.build_pipeline_instrument(p)

    # TODO(BEAM-7760): The PipelineInstrument cacheables maintains a global list
    # of cacheable PCollections across all pipelines. Here we take the subset of
    # cacheables that only pertain to this test's pipeline.
    cacheables = {
        key: cacheable
        for key, cacheable in pipeline_instrument.cacheables.items()
        if cacheable.pcoll.pipeline is p
    }

    def expected_cacheable(var, pcoll, pcoll_id):
      # Expected Cacheable for one PCollection: versions are derived from the
      # id() of the PCollection and of its producer.
      return instr.Cacheable(
          var=var,
          version=str(id(pcoll)),
          pcoll_id=pcoll_id,
          producer_version=str(id(pcoll.producer)),
          pcoll=pcoll)

    key_of = pipeline_instrument._cacheable_key
    expected = {
        key_of(init_pcoll): expected_cacheable(
            'init_pcoll', init_pcoll, 'ref_PCollection_PCollection_8'),
        key_of(squares): expected_cacheable(
            'squares', squares, 'ref_PCollection_PCollection_9'),
        key_of(cubes): expected_cacheable(
            'cubes', cubes, 'ref_PCollection_PCollection_10'),
    }
    self.assertEqual(cacheables, expected)
    def test_write_cache(self, mocked_get_cache_manager):
        """Compare WriteCache on a pipeline proto with a hand-applied transform.

        Args:
            mocked_get_cache_manager: mock whose return value is set to the
                in-memory cache created by this test.
        """
        # NOTE: ib.watch(locals()) captures the local names defined before it;
        # keep `p` and `pcoll` named as they are.
        p = beam.Pipeline()
        pcoll = p | beam.Create([1, 2, 3])
        ib.watch(locals())

        in_memory_cache = InMemoryCache()
        mocked_get_cache_manager.return_value = in_memory_cache
        augmented = ap.AugmentedPipeline(p)
        key = repr(augmented._cacheables[pcoll].to_key())
        pipeline_proto = p.to_runner_api()

        # Apply the cache write to the pipeline proto (mutated in place).
        writer = write_cache.WriteCache(
            pipeline_proto, augmented._context, augmented._cache_manager,
            augmented._cacheables[pcoll])
        writer.write_cache()
        actual_pipeline = pipeline_proto

        # Apply the equivalent write transform straight to the pipeline object.
        label = '{}{}'.format('_cache_', key)
        transform = write_cache._WriteCacheTransform(
            augmented._cache_manager, key, label)
        _ = pcoll | 'sink' + label >> transform
        expected_pipeline = p.to_runner_api()

        assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

        # The actual pipeline must contain a transform that takes pcoll as an
        # input, and the first such transform must be named like a sink.
        pcoll_id = augmented._context.pcollections.get_id(pcoll)
        write_transform_id = next(
            (candidate_id for candidate_id, applied in
             actual_pipeline.components.transforms.items()
             if pcoll_id in applied.inputs.values()),
            None)
        self.assertIsNotNone(write_transform_id)
        write_transform = actual_pipeline.components.transforms[
            write_transform_id]
        self.assertIn('sink', write_transform.unique_name)
    def test_pcoll_id_with_runner_pipeline(self):
        """Check pcoll_id() for a PCollection of a derived (runner) pipeline.

        A user pipeline is built and watched, then a second pipeline is
        registered as derived from it. The instrument built from the second
        pipeline must report 'ref_PCollection_PCollection_8' for that
        pipeline's PCollection.
        """
        p_id_runner = beam.Pipeline(interactive_runner.InteractiveRunner())
        ie.current_env().set_cache_manager(InMemoryCache(), p_id_runner)
        # pylint: disable=possibly-unused-variable
        init_pcoll = p_id_runner | 'Init Create' >> beam.Create([1, 2, 3])
        # init_pcoll is only referenced through the watched locals().
        ib.watch(locals())

        # It's normal that when executing, the pipeline object is a different
        # but equivalent instance from what user has built. The pipeline instrument
        # should be able to identify if the original instance has changed in an
        # interactive env while mutating the other instance for execution. The
        # version map can be used to figure out what the PCollection instances are
        # in the original instance and if the evaluation has changed since last
        # execution.
        p2_id_runner = beam.Pipeline(interactive_runner.InteractiveRunner())
        # pylint: disable=bad-option-value
        init_pcoll_2 = p2_id_runner | 'Init Create' >> beam.Create(range(10))
        # Register the second pipeline as derived from the user pipeline.
        ie.current_env().add_derived_pipeline(p_id_runner, p2_id_runner)

        instrumentation = instr.build_pipeline_instrument(p2_id_runner)
        # The instrument is built from the derived pipeline and queried with
        # that pipeline's own PCollection; the expected id is hard-coded.
        self.assertEqual(instrumentation.pcoll_id(init_pcoll_2),
                         'ref_PCollection_PCollection_8')
Exemple #17
0
 def setUp(self):
     """Call ie.new_env with a new InMemoryCache as its cache manager."""
     in_memory_cache = InMemoryCache()
     ie.new_env(cache_manager=in_memory_cache)
Exemple #18
0
 def test_has_unbounded_source(self):
     """Check that a pipeline reading from Pub/Sub reports an unbounded source."""
     pipeline = beam.Pipeline()
     ie.current_env().set_cache_manager(InMemoryCache(), pipeline)
     pubsub_source = beam.io.ReadFromPubSub(
         subscription='projects/fake-project/subscriptions/fake_sub')
     _ = pipeline | 'ReadUnboundedSource' >> pubsub_source
     self.assertTrue(utils.has_unbounded_sources(pipeline))
Exemple #19
0
class ElementStreamTest(unittest.TestCase):
    """Exercises ElementStream against elements stored in an InMemoryCache."""

    def setUp(self):
        """Create the cache, pipeline, PCollection and mocked pipeline result."""
        cache = InMemoryCache()
        pipeline = beam.Pipeline()
        self.cache = cache
        self.p = pipeline
        self.pcoll = pipeline | beam.Create([])
        self.cache_key = str(CacheKey('pcoll', '', '', ''))

        # The MockPipelineResult lets a test decide which state the fake
        # pipeline run reports.
        self.mock_result = MockPipelineResult()
        env = ie.current_env()
        env.add_user_pipeline(pipeline)
        env.set_pipeline_result(pipeline, self.mock_result)
        env.set_cache_manager(cache, pipeline)

    def _fill_cache(self, elements, coder=None):
        """Write `elements` and `coder` under this test's cache key."""
        self.cache.write(elements, 'full', self.cache_key)
        self.cache.save_pcoder(coder, 'full', self.cache_key)

    def _new_stream(self, max_n, max_duration_secs):
        """Return an ElementStream over self.pcoll with the given limits."""
        return ElementStream(self.pcoll,
                             '',
                             self.cache_key,
                             max_n=max_n,
                             max_duration_secs=max_duration_secs)

    def test_read(self):
        """Test reading and if a stream is done no more elements are returned."""
        self.mock_result.set_state(PipelineState.DONE)
        self._fill_cache(['expected'])

        stream = self._new_stream(max_n=1, max_duration_secs=1)

        self.assertFalse(stream.is_done())
        elements = list(stream.read())
        self.assertEqual(elements[0], 'expected')
        self.assertTrue(stream.is_done())

    def test_done_if_terminated(self):
        """Test that terminating the job sets the stream as done."""
        self._fill_cache(['expected'])

        stream = self._new_stream(max_n=100, max_duration_secs=10)

        self.assertFalse(stream.is_done())
        first_pass = list(stream.read(tail=False))
        self.assertEqual(first_pass[0], 'expected')

        # Neither limiter was reached, so the stream is not done yet.
        self.assertFalse(stream.is_done())

        self.mock_result.set_state(PipelineState.DONE)
        second_pass = list(stream.read(tail=False))
        self.assertEqual(second_pass[0], 'expected')

        # The underlying pipeline is terminated, so the stream won't yield
        # new elements.
        self.assertTrue(stream.is_done())

    def test_read_n(self):
        """Test that the stream only reads 'n' elements."""
        self.mock_result.set_state(PipelineState.DONE)
        self._fill_cache(list(range(5)))

        # The last case asks for more elements than the cache holds; the read
        # must still return.
        cases = (
            (1, [0]),
            (2, [0, 1]),
            (5, list(range(5))),
            (10, list(range(5))),
        )
        for max_n, expected in cases:
            stream = self._new_stream(max_n=max_n, max_duration_secs=1)
            self.assertEqual(list(stream.read()), expected)
            self.assertTrue(stream.is_done())

    def test_read_duration(self):
        """Test that the stream only reads a 'duration' of elements."""
        def as_windowed_value(value):
            return WindowedValueHolder(WindowedValue(value, 0, []))

        # Each element is preceded by a one second processing time advance.
        builder = FileRecordsBuilder(tag=self.cache_key)
        element_event_times = ((0, 0), (1, 1), (2, 3), (3, 4), (4, 5))
        for element, event_time_secs in element_event_times:
            builder = builder.advance_processing_time(1).add_element(
                element=as_windowed_value(element),
                event_time_secs=event_time_secs)
        records = builder.build()

        # Only the recorded events of TestStreamFileRecord entries are cached.
        values = [
            record.recorded_event for record in records
            if isinstance(record,
                          beam_interactive_api_pb2.TestStreamFileRecord)
        ]

        self.mock_result.set_state(PipelineState.DONE)
        self._fill_cache(values, coders.FastPrimitivesCoder())

        # The following tests a progression of reading different durations
        # from the cache.
        cases = (
            (1, [0]),
            (2, [0, 1]),
            (10, [0, 1, 2, 3, 4]),
        )
        for max_duration_secs, expected in cases:
            stream = self._new_stream(max_n=100,
                                      max_duration_secs=max_duration_secs)
            self.assertSequenceEqual([e.value for e in stream.read()],
                                     expected)
Exemple #20
0
    def test_describe(self):
        """Check that describe() reports the summed cache size of both streams."""
        # NOTE: ib.watch(locals()) captures the local names defined before it;
        # keep `p`, `numbers` and `letters` named as they are.
        p = beam.Pipeline(InteractiveRunner())
        numbers = p | 'numbers' >> beam.Create([0, 1, 2])
        letters = p | 'letters' >> beam.Create(['a', 'b', 'c'])

        ib.watch(locals())

        # The MockPipelineResult lets the test control the state of a fake
        # run of the pipeline.
        fake_result = MockPipelineResult()
        env = ie.current_env()
        env.track_user_pipelines()
        env.set_pipeline_result(p, fake_result)

        cache = InMemoryCache()
        env.set_cache_manager(cache, p)

        recording = Recording(p, [numbers, letters],
                              fake_result,
                              max_n=10,
                              max_duration_secs=60)

        # Write elements straight into the cache under each stream's cache
        # key, so that no pipeline has to run in the test.
        numbers_stream = recording.stream(numbers)
        numbers_key = numbers_stream.cache_key
        cache.write([0, 1, 2], 'full', numbers_key)
        cache.save_pcoder(None, 'full', numbers_key)

        letters_stream = recording.stream(letters)
        letters_key = letters_stream.cache_key
        cache.write(['a', 'b', 'c'], 'full', letters_key)
        cache.save_pcoder(None, 'full', letters_key)

        # The described size must equal the sum of both cached sizes.
        reported_size = recording.describe()['size']
        expected_size = (cache.size('full', numbers_key) +
                         cache.size('full', letters_key))
        self.assertEqual(reported_size, expected_size)
class ElementStreamTest(unittest.TestCase):
    """Exercises ElementStream against elements stored in an InMemoryCache."""

    def setUp(self):
        """Reset the environment and create cache, pipeline and mock result."""
        ie.new_env()

        cache = InMemoryCache()
        pipeline = beam.Pipeline()
        self.cache = cache
        self.p = pipeline
        self.pcoll = pipeline | beam.Create([])
        self.cache_key = str(pi.CacheKey('pcoll', '', '', ''))

        # The MockPipelineResult lets a test decide which state the fake
        # pipeline run reports.
        self.mock_result = MockPipelineResult()
        env = ie.current_env()
        env.track_user_pipelines()
        env.set_pipeline_result(pipeline, self.mock_result)
        env.set_cache_manager(cache, pipeline)

    def _fill_cache(self, elements):
        """Write `elements` with no coder under this test's cache key."""
        self.cache.write(elements, 'full', self.cache_key)
        self.cache.save_pcoder(None, 'full', self.cache_key)

    def _new_stream(self, max_n, max_duration_secs):
        """Return an ElementStream over self.pcoll with the given limits."""
        return ElementStream(self.pcoll,
                             '',
                             self.cache_key,
                             max_n=max_n,
                             max_duration_secs=max_duration_secs)

    def test_read(self):
        """Test reading and if a stream is done no more elements are returned."""
        self.mock_result.set_state(PipelineState.DONE)
        self._fill_cache(['expected'])

        stream = self._new_stream(max_n=1, max_duration_secs=1)

        self.assertFalse(stream.is_done())
        elements = list(stream.read())
        self.assertEqual(elements[0], 'expected')
        self.assertTrue(stream.is_done())

    def test_done_if_terminated(self):
        """Test that terminating the job sets the stream as done."""
        self._fill_cache(['expected'])

        stream = self._new_stream(max_n=100, max_duration_secs=10)

        self.assertFalse(stream.is_done())
        first_pass = list(stream.read(tail=False))
        self.assertEqual(first_pass[0], 'expected')

        # Neither limiter was reached, so the stream is not done yet.
        self.assertFalse(stream.is_done())

        self.mock_result.set_state(PipelineState.DONE)
        second_pass = list(stream.read(tail=False))
        self.assertEqual(second_pass[0], 'expected')

        # The underlying pipeline is terminated, so the stream won't yield
        # new elements.
        self.assertTrue(stream.is_done())

    def test_read_n(self):
        """Test that the stream only reads 'n' elements."""
        self.mock_result.set_state(PipelineState.DONE)
        self._fill_cache(list(range(5)))

        # The last case asks for more elements than the cache holds; the read
        # must still return.
        cases = (
            (1, [0]),
            (2, [0, 1]),
            (5, list(range(5))),
            (10, list(range(5))),
        )
        for max_n, expected in cases:
            stream = self._new_stream(max_n=max_n, max_duration_secs=1)
            self.assertEqual(list(stream.read()), expected)
            self.assertTrue(stream.is_done())

    def test_read_duration(self):
        """Test that the stream only reads a 'duration' of elements."""
        # Each element is preceded by a one second processing time advance.
        builder = FileRecordsBuilder(tag=self.cache_key)
        element_event_times = ((0, 0), (1, 1), (2, 3), (3, 4), (4, 5))
        for element, event_time_secs in element_event_times:
            builder = builder.advance_processing_time(1).add_element(
                element=element, event_time_secs=event_time_secs)
        values = builder.build()

        self.mock_result.set_state(PipelineState.DONE)
        self._fill_cache(values)

        # The elements read from the cache are TestStreamFileRecord instances
        # that carry the underlying elements in encoded form. This helper
        # decodes the elements out of those records and skips anything else.
        def get_elements(events):
            coder = coders.FastPrimitivesCoder()
            decoded = []
            for event in events:
                if (isinstance(event, TestStreamFileRecord)
                        and event.recorded_event.element_event):
                    decoded.extend(
                        coder.decode(item.encoded_element) for item in
                        event.recorded_event.element_event.elements)
            return decoded

        # The following tests a progression of reading different durations
        # from the cache.
        cases = (
            (1, [0]),
            (2, [0, 1]),
            (10, [0, 1, 2, 3, 4]),
        )
        for max_duration_secs, expected in cases:
            stream = self._new_stream(max_n=100,
                                      max_duration_secs=max_duration_secs)
            self.assertSequenceEqual(get_elements(stream.read()), expected)