Example No. 1
def cache_output(output_name: str, output: PValue) -> None:
    user_pipeline = ie.current_env().user_pipeline(output.pipeline)
    if user_pipeline:
        cache_manager = ie.current_env().get_cache_manager(
            user_pipeline, create_if_absent=True)
    else:
        _LOGGER.warning(
            'Something is wrong with %s. Cannot introspect its data.', output)
        return
    key = CacheKey.from_pcoll(output_name, output).to_str()
    _ = reify_to_cache(pcoll=output,
                       cache_key=key,
                       cache_manager=cache_manager)
    try:
        output.pipeline.run().wait_until_finish()
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        _LOGGER.warning(_NOT_SUPPORTED_MSG, e, output.pipeline.runner)
        return
    ie.current_env().mark_pcollection_computed([output])
    visualize_computed_pcoll(output_name,
                             output,
                             max_n=float('inf'),
                             max_duration_secs=float('inf'))
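
The cache key that cache_output builds is a plain string, so it can be handed to cache managers and parsed back later. Below is a minimal sketch of that round trip; the import path is an assumption (recent Apache Beam releases keep CacheKey in apache_beam.runners.interactive.caching.cacheable, older versions elsewhere), and the pipeline and variable name are purely illustrative.

import apache_beam as beam
# Assumed import path; it may differ across Beam versions.
from apache_beam.runners.interactive.caching.cacheable import CacheKey

p = beam.Pipeline()
pcoll = p | 'Create' >> beam.Create([1, 2, 3])

key_str = CacheKey.from_pcoll('pcoll', pcoll).to_str()  # same kind of key cache_output computes
parsed = CacheKey.from_str(key_str)                     # exposes fields such as .var and .pipeline_id
assert parsed.to_str() == key_str                       # the key round-trips through its string form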
Example No. 2
    def __init__(
            self,
            user_pipeline,  # type: beam.Pipeline
            pcolls,  # type: List[beam.pvalue.PCollection]
            result,  # type: beam.runner.PipelineResult
            max_n,  # type: int
            max_duration_secs,  # type: float
    ):
        self._user_pipeline = user_pipeline
        self._result = result
        self._result_lock = threading.Lock()
        self._pcolls = pcolls
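        # Look up the variable name a PCollection was watched under by inverting
        # the {name: pcoll} mapping of the user's watched scope.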
        pcoll_var = lambda pcoll: {
            v: k
            for k, v in utils.pcoll_by_name().items()
        }.get(pcoll, None)

        self._streams = {
            pcoll: ElementStream(
                pcoll, pcoll_var(pcoll),
                CacheKey.from_pcoll(pcoll_var(pcoll), pcoll).to_str(), max_n,
                max_duration_secs)
            for pcoll in pcolls
        }

        self._start = time.time()
        self._duration_secs = max_duration_secs
        self._set_computed = bcj.is_cache_complete(str(id(user_pipeline)))

        # Run a separate thread for marking the PCollections done. This is because
        # the pipeline run may be asynchronous.
        self._mark_computed = threading.Thread(target=self._mark_all_computed)
        self._mark_computed.daemon = True
        self._mark_computed.start()
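
Because the pipeline run backing this recording may still be executing when the constructor returns, the completion bookkeeping is pushed to a daemon thread. A stripped-down, Beam-free sketch of that pattern follows; all names here are illustrative.

import threading
import time

class BackgroundRecording(object):
    """Marks itself computed on a daemon thread, mimicking the pattern above."""
    def __init__(self, duration_secs):
        self._duration_secs = duration_secs
        self._computed = threading.Event()
        self._mark_computed = threading.Thread(target=self._mark_all_computed)
        self._mark_computed.daemon = True  # never blocks interpreter shutdown
        self._mark_computed.start()

    def _mark_all_computed(self):
        time.sleep(self._duration_secs)  # stand-in for waiting on the pipeline result
        self._computed.set()

recording = BackgroundRecording(duration_secs=0.1)
recording._computed.wait(timeout=1)
assert recording._computed.is_set()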
Example No. 3
def _build_query_components(
    query: str,
    found: Dict[str, beam.PCollection],
    output_name: str,
    run: bool = True
) -> Tuple[str,
           Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline],
           SqlChain]:
  """Builds necessary components needed to apply the SqlTransform.

  Args:
    query: The SQL query to be executed by the magic.
    found: The PCollections with variable names found to be used by the query.
    output_name: The output variable name in __main__ module.
    run: Whether to prepare components for a local run or not.

  Returns:
    The processed query to be executed by the magic; a source to apply the
    SqlTransform to: a dictionary of tagged PCollections, or a single
    PCollection, or the pipeline to execute the query; the chain of applied
    beam_sql magics this one belongs to.
  """
  if found:
    user_pipeline = ie.current_env().user_pipeline(
        next(iter(found.values())).pipeline)
    sql_pipeline = beam.Pipeline(options=user_pipeline._options)
    ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)
    sql_source = {}
    if run:
      if has_source_to_cache(user_pipeline):
        sql_source = pcolls_from_streaming_cache(
            user_pipeline, sql_pipeline, found)
      else:
        cache_manager = ie.current_env().get_cache_manager(
            user_pipeline, create_if_absent=True)
        for pcoll_name, pcoll in found.items():
          cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str()
          sql_source[pcoll_name] = unreify_from_cache(
              pipeline=sql_pipeline,
              cache_key=cache_key,
              cache_manager=cache_manager,
              element_type=pcoll.element_type)
    else:
      sql_source = found
    if len(sql_source) == 1:
      query = replace_single_pcoll_token(query, next(iter(sql_source.keys())))
      sql_source = next(iter(sql_source.values()))

    node = SqlNode(
        output_name=output_name, source=set(found.keys()), query=query)
    chain = ie.current_env().get_sql_chain(
        user_pipeline, set_user_pipeline=True).append(node)
  else:  # does not query any existing PCollection
    sql_source = beam.Pipeline()
    ie.current_env().add_user_pipeline(sql_source)

    # The node should be the root node of the chain created below.
    node = SqlNode(output_name=output_name, source=sql_source, query=query)
    chain = ie.current_env().get_sql_chain(sql_source).append(node)
  return query, sql_source, chain
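
When the query references exactly one PCollection, the source dictionary collapses to that single PCollection and the variable name in the query is rewritten by replace_single_pcoll_token. A plain-Python sketch of just the collapse step, with hypothetical names:

# Hypothetical stand-ins for {variable name: PCollection}.
sql_source = {'orders': 'orders-pcollection'}
if len(sql_source) == 1:
    only_name = next(iter(sql_source.keys()))      # 'orders', the token rewritten in the query
    sql_source = next(iter(sql_source.values()))   # the single PCollection itself
assert sql_source == 'orders-pcollection'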
Example No. 4
    def cache_key(self, pcoll):
        """Gets the identifier of a cacheable PCollection in cache.

        If the pcoll is not a cacheable, return ''.
        The key is what the pcoll would use as its identifier if it were
        materialized in cache; it does not mean that such a cache necessarily
        exists already. Also, the pcoll can come from the original user-defined
        pipeline object or from an equivalent pcoll in a transformed copy of
        the original pipeline.

        The 'pcoll_id' of a cacheable is not stable, so it is not included in
        the cache key. A combination of 'var', 'version' and 'producer_version'
        is sufficient to identify a cached PCollection.
        """
        cacheable = self.cacheables.get(self._cacheable_key(pcoll), None)
        if cacheable:
            if cacheable.pcoll in self.runner_pcoll_to_user_pcoll:
                user_pcoll = self.runner_pcoll_to_user_pcoll[cacheable.pcoll]
            else:
                user_pcoll = cacheable.pcoll

            return repr(
                CacheKey(cacheable.var, cacheable.version,
                         cacheable.producer_version,
                         str(id(user_pcoll.pipeline))))
        return ''
Example No. 5
    def test_empty(self):
        CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))

        cache = StreamingCache(cache_dir=None)
        self.assertFalse(cache.exists(CACHED_PCOLLECTION_KEY))
        cache.write([], CACHED_PCOLLECTION_KEY)
        reader, _ = cache.read(CACHED_PCOLLECTION_KEY)

        # Assert that an empty reader returns an empty list.
        self.assertFalse([e for e in reader])
Example No. 6
    def setUp(self):
        self.cache = InMemoryCache()
        self.p = beam.Pipeline()
        self.pcoll = self.p | beam.Create([])
        self.cache_key = str(CacheKey('pcoll', '', '', ''))

        # Create a MockPipelineResult to control the state of a fake run of the
        # pipeline.
        self.mock_result = MockPipelineResult()
        ie.current_env().add_user_pipeline(self.p)
        ie.current_env().set_pipeline_result(self.p, self.mock_result)
        ie.current_env().set_cache_manager(self.cache, self.p)
Example No. 7
    def __init__(self, cache_dir, labels, is_cache_complete=None, coder=None):
        if not coder:
            coder = SafeFastPrimitivesCoder()

        if not is_cache_complete:
            is_cache_complete = lambda _: True

        self._cache_dir = cache_dir
        self._coder = coder
        self._labels = labels
        self._path = os.path.join(self._cache_dir, *self._labels)
        self._is_cache_complete = is_cache_complete
        self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id
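
Both coder and is_cache_complete are optional; falsy arguments fall back to safe defaults so callers can simply pass None. A tiny stdlib sketch of the same defaulting idiom (the function name is illustrative):

def make_completion_check(is_cache_complete=None):
    # Default to "always complete" when no callable is supplied, as above.
    if not is_cache_complete:
        is_cache_complete = lambda _: True
    return is_cache_complete

assert make_completion_check()('any-pipeline-id') is True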
Example No. 8
    def test_single_reader(self):
        """Tests that we expect to see all the correctly emitted
        TestStreamPayloads."""
        CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))

        values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
                  .add_element(element=0, event_time_secs=0)
                  .advance_processing_time(1)
                  .add_element(element=1, event_time_secs=1)
                  .advance_processing_time(1)
                  .add_element(element=2, event_time_secs=2)
                  .build()) # yapf: disable

        cache = StreamingCache(cache_dir=None)
        cache.write(values, CACHED_PCOLLECTION_KEY)

        reader, _ = cache.read(CACHED_PCOLLECTION_KEY)
        coder = coders.FastPrimitivesCoder()
        events = list(reader)

        # Units here are in microseconds.
        expected = [
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode(0), timestamp=0)
                    ],
                    tag=CACHED_PCOLLECTION_KEY)),
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=1 * 10**6)),
            TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                                    AddElements(elements=[
                                        TestStreamPayload.TimestampedElement(
                                            encoded_element=coder.encode(1),
                                            timestamp=1 * 10**6)
                                    ],
                                                tag=CACHED_PCOLLECTION_KEY)),
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=1 * 10**6)),
            TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                                    AddElements(elements=[
                                        TestStreamPayload.TimestampedElement(
                                            encoded_element=coder.encode(2),
                                            timestamp=2 * 10**6)
                                    ],
                                                tag=CACHED_PCOLLECTION_KEY)),
        ]
        self.assertSequenceEqual(events, expected)
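
The FileRecordsBuilder calls take seconds while the TestStreamPayload events carry microseconds, which is why the expected list multiplies by 10**6. A one-line helper (illustrative, not from the Beam source) makes the conversion explicit:

def secs_to_usecs(secs):
    return int(secs * 10**6)

assert secs_to_usecs(1) == 1 * 10**6   # the AdvanceProcessingTime durations above
assert secs_to_usecs(2) == 2 * 10**6   # the timestamp of the element added at 2s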
Example No. 9
def pcolls_from_streaming_cache(
    user_pipeline: beam.Pipeline,
    query_pipeline: beam.Pipeline,
    name_to_pcoll: Dict[str, beam.PCollection]) -> Dict[str, beam.PCollection]:
  """Reads PCollection cache through the TestStream.

  Args:
    user_pipeline: The beam.Pipeline object defined by the user in the
        notebook.
    query_pipeline: The beam.Pipeline object built by the magic to execute the
        SQL query.
    name_to_pcoll: PCollections with variable names used in the SQL query.

  Returns:
    A Dict[str, beam.PCollection], where each PCollection, read from the cache,
    is keyed by its variable name.

  When the user_pipeline has unbounded sources, we force all cache reads to go
  through the TestStream even if they are bounded sources.
  """
  def exception_handler(e):
    _LOGGER.error(str(e))
    return True

  cache_manager = ie.current_env().get_cache_manager(
      user_pipeline, create_if_absent=True)
  test_stream_service = ie.current_env().get_test_stream_service_controller(
      user_pipeline)
  if not test_stream_service:
    test_stream_service = TestStreamServiceController(
        cache_manager, exception_handler=exception_handler)
    test_stream_service.start()
    ie.current_env().set_test_stream_service_controller(
        user_pipeline, test_stream_service)

  tag_to_name = {}
  for name, pcoll in name_to_pcoll.items():
    key = CacheKey.from_pcoll(name, pcoll).to_str()
    tag_to_name[key] = name
  output_pcolls = query_pipeline | test_stream.TestStream(
      output_tags=set(tag_to_name.keys()),
      coder=cache_manager._default_pcoder,
      endpoint=test_stream_service.endpoint)
  sql_source = {}
  for tag, output in output_pcolls.items():
    name = tag_to_name[tag]
    # Must mark the element_type to avoid introducing pickled Python coder
    # to the Java expansion service.
    output.element_type = name_to_pcoll[name].element_type
    sql_source[name] = output
  return sql_source
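
The cache keys double as the TestStream output tags, so the function inverts the mapping and hands back PCollections keyed by their original variable names. A Beam-free sketch of that re-keying (keys and outputs are stand-ins):

name_to_key = {'orders': 'orders-cache-key', 'users': 'users-cache-key'}  # hypothetical cache keys
tag_to_name = {key: name for name, key in name_to_key.items()}
outputs_by_tag = {'orders-cache-key': 'orders-pcoll', 'users-cache-key': 'users-pcoll'}
sql_source = {tag_to_name[tag]: out for tag, out in outputs_by_tag.items()}
assert sql_source == {'orders': 'orders-pcoll', 'users': 'users-pcoll'}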
Example No. 10
    def test_always_default_coder_for_test_stream_records(self):
        CACHED_NUMBERS = repr(CacheKey('numbers', '', '', ''))
        numbers = (FileRecordsBuilder(CACHED_NUMBERS)
                   .advance_processing_time(2)
                   .add_element(element=1, event_time_secs=0)
                   .advance_processing_time(1)
                   .add_element(element=2, event_time_secs=0)
                   .advance_processing_time(1)
                   .add_element(element=2, event_time_secs=0)
                   .build()) # yapf: disable
        cache = StreamingCache(cache_dir=None)
        cache.write(numbers, CACHED_NUMBERS)
        self.assertIs(type(cache.load_pcoder(CACHED_NUMBERS)),
                      type(cache._default_pcoder))
Example No. 11
    def _wait_until_file_exists(self, timeout_secs=30):
        """Blocks until the file exists for a maximum of timeout_secs."""
        # Wait for up to `timeout_secs` for the file to be available.
        start = time.time()
        while not os.path.exists(self._path):
            time.sleep(1)
            if time.time() - start > timeout_secs:
                pcollection_var = CacheKey.from_str(self._labels[-1]).var
                raise RuntimeError(
                    'Timed out waiting for cache file for PCollection `{}` to be '
                    'available with path {}.'.format(pcollection_var, self._path))
        return open(self._path, mode='rb')
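
The same poll-until-timeout pattern, as a generic stdlib-only sketch (names are illustrative, not from the Beam source); unlike the method above, it checks the deadline before sleeping so a zero timeout fails fast:

import os
import time

def wait_for_path(path, timeout_secs=30, poll_secs=1.0):
    start = time.time()
    while not os.path.exists(path):
        if time.time() - start > timeout_secs:
            raise RuntimeError('Timed out waiting for {} to exist.'.format(path))
        time.sleep(poll_secs)
    return open(path, mode='rb')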
Example No. 12
    def test_cache_output(self):
        p_cache_output = beam.Pipeline()
        pcoll_co = p_cache_output | 'Create Source' >> beam.Create([1, 2, 3])
        cache_manager = FileBasedCacheManager()
        ie.current_env().set_cache_manager(cache_manager, p_cache_output)
        ib.watch(locals())
        with patch(
                'apache_beam.runners.interactive.display.pcoll_visualization.'
                'visualize_computed_pcoll', lambda a, b: None):
            cache_output('pcoll_co', pcoll_co)
            self.assertIn(pcoll_co, ie.current_env().computed_pcollections)
            self.assertTrue(
                cache_manager.exists(
                    'full',
                    CacheKey.from_pcoll('pcoll_co', pcoll_co).to_str()))
Example No. 13
def _build_query_components(
    query: str, found: Dict[str, beam.PCollection]
) -> Tuple[str, Union[Dict[str, beam.PCollection], beam.PCollection,
                      beam.Pipeline]]:
    """Builds necessary components needed to apply the SqlTransform.

  Args:
    query: The SQL query to be executed by the magic.
    found: The PCollections with variable names found to be used by the query.

  Returns:
    The processed query to be executed by the magic and a source to apply the
    SqlTransform to: a dictionary of tagged PCollections, or a single
    PCollection, or the pipeline to execute the query.
  """
    if found:
        user_pipeline = ie.current_env().user_pipeline(
            next(iter(found.values())).pipeline)
        sql_pipeline = beam.Pipeline(options=user_pipeline._options)
        ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)
        sql_source = {}
        if has_source_to_cache(user_pipeline):
            sql_source = pcolls_from_streaming_cache(user_pipeline,
                                                     sql_pipeline, found)
        else:
            cache_manager = ie.current_env().get_cache_manager(
                user_pipeline, create_if_absent=True)
            for pcoll_name, pcoll in found.items():
                cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str()
                sql_source[pcoll_name] = unreify_from_cache(
                    pipeline=sql_pipeline,
                    cache_key=cache_key,
                    cache_manager=cache_manager,
                    element_type=pcoll.element_type)
        if len(sql_source) == 1:
            query = replace_single_pcoll_token(query,
                                               next(iter(sql_source.keys())))
            sql_source = next(iter(sql_source.values()))
    else:
        sql_source = beam.Pipeline()
        ie.current_env().add_user_pipeline(sql_source)
    return query, sql_source
Example No. 14
    def read(self, pcoll_name, pcoll, max_n, max_duration_secs):
        # type: (str, beam.pvalue.PValue, int, float) -> Union[None, ElementStream]
        """Reads an ElementStream of a computed PCollection.

        Returns None if an error occurs. The caller is responsible for
        validating that the given pcoll_name and pcoll identify a watched and
        computed PCollection without ambiguity in the notebook.
        """

        try:
            cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str()
            return ElementStream(pcoll, pcoll_name, cache_key, max_n,
                                 max_duration_secs)
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception as e:
            # Caller should handle all validations. Here to avoid redundant
            # validations, simply log errors if caller fails to do so.
            _LOGGER.error(str(e))
            return None
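
read() re-raises interrupts but logs and swallows everything else, leaving validation to the caller. The same guard as a stdlib-only helper (the helper name is illustrative):

import logging

_LOGGER = logging.getLogger(__name__)

def safe_call(fn, *args, **kwargs):
    """Calls fn, re-raising interrupts but logging any other exception."""
    try:
        return fn(*args, **kwargs)
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:  # pylint: disable=broad-except
        _LOGGER.error(str(e))
        return None

assert safe_call(lambda: 1 / 0) is None  # the ZeroDivisionError is logged, not raised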
Example No. 15
    def cache_key(self, pcoll):
        """Gets the identifier of a cacheable PCollection in cache.

        If the pcoll is not a cacheable, return ''.
        This is only needed in pipeline instrument when the origin of the given
        pcoll is unknown (whether it is from the user pipeline or a runner
        pipeline). If a pcoll is from the user pipeline, always use
        CacheKey.from_pcoll to build the key.
        The key is what the pcoll would use as its identifier if it were
        materialized in cache; it does not mean that such a cache necessarily
        exists already. Also, the pcoll can come from the original user-defined
        pipeline object or from an equivalent pcoll in a transformed copy of
        the original pipeline.
        """
        cacheable = self._cacheables.get(self.pcoll_id(pcoll), None)
        if cacheable:
            if cacheable.pcoll in self.runner_pcoll_to_user_pcoll:
                user_pcoll = self.runner_pcoll_to_user_pcoll[cacheable.pcoll]
            else:
                user_pcoll = cacheable.pcoll
            return CacheKey.from_pcoll(cacheable.var, user_pcoll).to_str()
        return ''
Example No. 16
    def test_recording_manager_clears_cache(self):
        """Tests that the RecordingManager clears the cache before recording.

        A job may have incomplete PCollections when the job terminates. Clearing
        the cache ensures that correct results are computed every run.
        """
        # Add the TestStream so that it can be cached.
        ib.options.recordable_sources.add(TestStream)
        p = beam.Pipeline(InteractiveRunner(),
                          options=PipelineOptions(streaming=True))
        elems = (p
                 | TestStream().advance_watermark_to(
                     0).advance_processing_time(1).add_elements(list(
                         range(10))).advance_processing_time(1))
        squares = elems | beam.Map(lambda x: x**2)

        # Watch the local scope for Interactive Beam so that referenced PCollections
        # will be cached.
        ib.watch(locals())

        # This is normally done in the interactive_utils when a transform is
        # applied but needs an IPython environment. So we manually run this here.
        ie.current_env().track_user_pipelines()

        # Do the first recording to get the timestamp of the first time the fragment
        # was run.
        rm = RecordingManager(p)

        # Set up a mock for the Cache's clear function which will be used to clear
        # uncomputed PCollections.
        rm._clear_pcolls = MagicMock()
        rm.record([squares], max_n=1, max_duration=500)
        rm.cancel()

        # Assert that the cache cleared the PCollection.
        rm._clear_pcolls.assert_any_call(
            unittest.mock.ANY,
            # elems is an unbounded source populated by the background job, thus
            # not cleared.
            {CacheKey.from_pcoll('squares', squares).to_str()})
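
The assertion above relies on MagicMock.assert_any_call, which passes if the mock was invoked at least once with the given arguments, and on unittest.mock.ANY as a wildcard for the pipeline argument. A minimal refresher, unrelated to Beam:

from unittest.mock import ANY, MagicMock

clear = MagicMock()
clear('some-pipeline', {'cache-key-1'})
clear.assert_any_call(ANY, {'cache-key-1'})  # ANY matches the first positional argument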
Example No. 17
    def test_streaming_cache_uses_local_ib_cache_root(self):
        """
    Checks that StreamingCache._cache_dir is set to the
    cache_root set under Interactive Beam for a local directory
    and that the cached values are the same as the values of a
    cache using default settings.
    """
        CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
        values = (FileRecordsBuilder(CACHED_PCOLLECTION_KEY)
                      .advance_processing_time(1)
                      .advance_watermark(watermark_secs=0)
                      .add_element(element=1, event_time_secs=0)
                      .build()) # yapf: disable

        local_cache = StreamingCache(cache_dir=None)
        local_cache.write(values, CACHED_PCOLLECTION_KEY)
        reader_one, _ = local_cache.read(CACHED_PCOLLECTION_KEY)
        pcoll_list_one = list(reader_one)

        # Set the Interactive Beam cache dir to a local directory.
        ib.options.cache_root = '/tmp/it-test/'
        cache_manager_with_ib_option = StreamingCache(
            cache_dir=ib.options.cache_root)

        self.assertEqual(ib.options.cache_root,
                         cache_manager_with_ib_option._cache_dir)

        cache_manager_with_ib_option.write(values, CACHED_PCOLLECTION_KEY)
        reader_two, _ = cache_manager_with_ib_option.read(
            CACHED_PCOLLECTION_KEY)
        pcoll_list_two = list(reader_two)

        self.assertEqual(pcoll_list_one, pcoll_list_two)

        # Reset Interactive Beam setting
        ib.options.cache_root = None
Example No. 18
    def test_multiple_readers(self):
        """Tests that the service advances the clock with multiple outputs.
    """

        CACHED_LETTERS = repr(CacheKey('letters', '', '', ''))
        CACHED_NUMBERS = repr(CacheKey('numbers', '', '', ''))
        CACHED_LATE = repr(CacheKey('late', '', '', ''))

        letters = (FileRecordsBuilder(CACHED_LETTERS)
                   .advance_processing_time(1)
                   .advance_watermark(watermark_secs=0)
                   .add_element(element='a', event_time_secs=0)
                   .advance_processing_time(10)
                   .advance_watermark(watermark_secs=10)
                   .add_element(element='b', event_time_secs=10)
                   .build()) # yapf: disable

        numbers = (FileRecordsBuilder(CACHED_NUMBERS)
                   .advance_processing_time(2)
                   .add_element(element=1, event_time_secs=0)
                   .advance_processing_time(1)
                   .add_element(element=2, event_time_secs=0)
                   .advance_processing_time(1)
                   .add_element(element=2, event_time_secs=0)
                   .build()) # yapf: disable

        late = (FileRecordsBuilder(CACHED_LATE)
                .advance_processing_time(101)
                .add_element(element='late', event_time_secs=0)
                .build()) # yapf: disable

        cache = StreamingCache(cache_dir=None)
        cache.write(letters, CACHED_LETTERS)
        cache.write(numbers, CACHED_NUMBERS)
        cache.write(late, CACHED_LATE)

        reader = cache.read_multiple([[CACHED_LETTERS], [CACHED_NUMBERS],
                                      [CACHED_LATE]])
        coder = coders.FastPrimitivesCoder()
        events = list(reader)

        # Units here are in microseconds.
        expected = [
            # Advances clock from 0 to 1
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=1 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=0, tag=CACHED_LETTERS)),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('a'), timestamp=0)
                    ],
                    tag=CACHED_LETTERS)),

            # Advances clock from 1 to 2
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=1 * 10**6)),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode(1), timestamp=0)
                    ],
                    tag=CACHED_NUMBERS)),

            # Advances clock from 2 to 3
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=1 * 10**6)),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode(2), timestamp=0)
                    ],
                    tag=CACHED_NUMBERS)),

            # Advances clock from 3 to 4
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=1 * 10**6)),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode(2), timestamp=0)
                    ],
                    tag=CACHED_NUMBERS)),

            # Advances clock from 4 to 11
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=7 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=10 * 10**6, tag=CACHED_LETTERS)),
            TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                                    AddElements(elements=[
                                        TestStreamPayload.TimestampedElement(
                                            encoded_element=coder.encode('b'),
                                            timestamp=10 * 10**6)
                                    ],
                                                tag=CACHED_LETTERS)),

            # Advances clock from 11 to 101
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=90 * 10**6)),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('late'), timestamp=0)
                    ],
                    tag=CACHED_LATE)),
        ]

        self.assertSequenceEqual(events, expected)
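
The AdvanceProcessingTime deltas in the expected output are the gaps between the merged per-stream clocks, and they add up to the 101-second advance written for the 'late' stream. A quick arithmetic check of the values above:

deltas_secs = [1, 1, 1, 1, 7, 90]  # the six AdvanceProcessingTime events, in seconds
assert sum(deltas_secs) == 101     # equals the advance_processing_time(101) of 'late'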
Example No. 19
    def test_read_and_write_multiple_outputs(self):
        """An integration test between the Sink and Source with multiple outputs.

        This tests that the StreamingCache reads from multiple files and
        combines them into a single, sorted output.
        """
        LETTERS_TAG = repr(CacheKey('letters', '', '', ''))
        NUMBERS_TAG = repr(CacheKey('numbers', '', '', ''))

        # Units here are in seconds.
        test_stream = (TestStream()
                       .advance_watermark_to(0, tag=LETTERS_TAG)
                       .advance_processing_time(5)
                       .add_elements(['a', 'b', 'c'], tag=LETTERS_TAG)
                       .advance_watermark_to(10, tag=NUMBERS_TAG)
                       .advance_processing_time(1)
                       .add_elements(
                           [
                               TimestampedValue('1', 15),
                               TimestampedValue('2', 15),
                               TimestampedValue('3', 15)
                           ],
                           tag=NUMBERS_TAG)) # yapf: disable

        cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

        coder = SafeFastPrimitivesCoder()

        options = StandardOptions(streaming=True)
        with TestPipeline(options=options) as p:
            # pylint: disable=expression-not-assigned
            events = p | test_stream
            events[LETTERS_TAG] | 'Letters sink' >> cache.sink([LETTERS_TAG])
            events[NUMBERS_TAG] | 'Numbers sink' >> cache.sink([NUMBERS_TAG])

        reader = cache.read_multiple([[LETTERS_TAG], [NUMBERS_TAG]])
        actual_events = list(reader)

        # Units here are in microseconds.
        expected_events = [
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=5 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=0, tag=LETTERS_TAG)),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('a'), timestamp=0),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('b'), timestamp=0),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('c'), timestamp=0),
                    ],
                    tag=LETTERS_TAG)),
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=1 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=10 * 10**6, tag=NUMBERS_TAG)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=0, tag=LETTERS_TAG)),
            TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                                    AddElements(elements=[
                                        TestStreamPayload.TimestampedElement(
                                            encoded_element=coder.encode('1'),
                                            timestamp=15 * 10**6),
                                        TestStreamPayload.TimestampedElement(
                                            encoded_element=coder.encode('2'),
                                            timestamp=15 * 10**6),
                                        TestStreamPayload.TimestampedElement(
                                            encoded_element=coder.encode('3'),
                                            timestamp=15 * 10**6),
                                    ],
                                                tag=NUMBERS_TAG)),
        ]

        self.assertListEqual(actual_events, expected_events)
Example No. 20
    def cache_key_of(self, name, pcoll):
        return CacheKey.from_pcoll(name, pcoll).to_str()
Example No. 21
    def test_read_and_write(self):
        """An integration test between the Sink and Source.

        This ensures that the sink and source speak the same language in terms
        of coders, protos, order, and units.
        """
        CACHED_RECORDS = repr(CacheKey('records', '', '', ''))

        # Units here are in seconds.
        test_stream = (
            TestStream(output_tags=(CACHED_RECORDS))
                       .advance_watermark_to(0, tag=CACHED_RECORDS)
                       .advance_processing_time(5)
                       .add_elements(['a', 'b', 'c'], tag=CACHED_RECORDS)
                       .advance_watermark_to(10, tag=CACHED_RECORDS)
                       .advance_processing_time(1)
                       .add_elements(
                           [
                               TimestampedValue('1', 15),
                               TimestampedValue('2', 15),
                               TimestampedValue('3', 15)
                           ],
                           tag=CACHED_RECORDS)) # yapf: disable

        coder = SafeFastPrimitivesCoder()
        cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

        # Assert that there are no capture keys at first.
        self.assertEqual(cache.capture_keys, set())

        options = StandardOptions(streaming=True)
        with TestPipeline(options=options) as p:
            records = (p | test_stream)[CACHED_RECORDS]

            # pylint: disable=expression-not-assigned
            records | cache.sink([CACHED_RECORDS], is_capture=True)

        reader, _ = cache.read(CACHED_RECORDS)
        actual_events = list(reader)

        # Assert that the capture keys are forwarded correctly.
        self.assertEqual(cache.capture_keys, set([CACHED_RECORDS]))

        # Units here are in microseconds.
        expected_events = [
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=5 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=0, tag=CACHED_RECORDS)),
            TestStreamPayload.Event(
                element_event=TestStreamPayload.Event.AddElements(
                    elements=[
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('a'), timestamp=0),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('b'), timestamp=0),
                        TestStreamPayload.TimestampedElement(
                            encoded_element=coder.encode('c'), timestamp=0),
                    ],
                    tag=CACHED_RECORDS)),
            TestStreamPayload.Event(processing_time_event=TestStreamPayload.
                                    Event.AdvanceProcessingTime(
                                        advance_duration=1 * 10**6)),
            TestStreamPayload.Event(
                watermark_event=TestStreamPayload.Event.AdvanceWatermark(
                    new_watermark=10 * 10**6, tag=CACHED_RECORDS)),
            TestStreamPayload.Event(element_event=TestStreamPayload.Event.
                                    AddElements(elements=[
                                        TestStreamPayload.TimestampedElement(
                                            encoded_element=coder.encode('1'),
                                            timestamp=15 * 10**6),
                                        TestStreamPayload.TimestampedElement(
                                            encoded_element=coder.encode('2'),
                                            timestamp=15 * 10**6),
                                        TestStreamPayload.TimestampedElement(
                                            encoded_element=coder.encode('3'),
                                            timestamp=15 * 10**6),
                                    ],
                                                tag=CACHED_RECORDS)),
        ]
        self.assertEqual(actual_events, expected_events)