Example #1
0
    def test_read_messages_timestamp_attribute_missing(self, mock_pubsub):
        """Falls back to the publish time when the timestamp attribute is absent.

        The pulled message has no 'nonexistent' attribute, so the element read
        by ReadFromPubSub is expected to be stamped with the message's publish
        time and placed in the global window.
        """
        payload = b'data'
        attrs = {}
        secs, nanos = 1520861821, 234567000
        rfc3339_publish_time = '2018-03-12T13:37:01.234567Z'
        ack = 'ack_id'
        response_message = test_utils.PullResponseMessage(
            payload, attrs, secs, nanos, ack)
        response = test_utils.create_pull_response([response_message])
        expected = [
            TestWindowedValue(
                PubsubMessage(payload, attrs),
                timestamp.Timestamp.from_rfc3339(rfc3339_publish_time),
                [window.GlobalWindow()]),
        ]
        client = mock_pubsub.return_value
        client.pull.return_value = response

        pipeline_options = PipelineOptions([])
        pipeline_options.view_as(StandardOptions).streaming = True
        with TestPipeline(options=pipeline_options) as pipeline:
            messages = pipeline | ReadFromPubSub(
                'projects/fakeprj/topics/a_topic',
                None,
                None,
                with_attributes=True,
                timestamp_attribute='nonexistent')
            assert_that(messages, equal_to(expected), reify_windows=True)

        # The single pulled message must be acknowledged and the channel closed.
        client.acknowledge.assert_has_calls([mock.call(mock.ANY, [ack])])
        client.api.transport.channel.close.assert_has_calls([mock.call()])
Example #2
0
            def finish_bundle(self):
                """Emits every record of the files matching ``file_to_read``.

                Each line is stripped of its trailing newline, decoded with
                ``self.coder`` and emitted into the global window with a
                timestamp of -1. Files are opened with ``open`` unless
                ``self.compression_type`` is set, in which case ``gzip.open``
                is used.
                """
                from apache_beam.transforms import window

                assert self.file_to_read
                for file_name in glob.glob(self.file_to_read):
                    if self.compression_type is None:
                        handle = open(file_name)
                    else:
                        handle = gzip.open(file_name, 'r')
                    with handle:
                        for record in handle:
                            # gzip yields bytes on Python 3 while open() yields
                            # str; strip the newline matching the record type
                            # so bytes.rstrip is never handed a str argument.
                            newline = (b'\n' if isinstance(record, bytes)
                                       else '\n')
                            value = self.coder.decode(record.rstrip(newline))
                            yield WindowedValue(value, -1,
                                                [window.GlobalWindow()])
 def finish_bundle(self):
     """Flushes every pending batch through the model at the end of a bundle.

     Yields:
       A ``WindowedValue`` in the global window for each prediction, stamped
       with the current wall-clock time in whole seconds. The pending batches
       are reset to an empty dict once all of them have been emitted.
     """
     logging.info("Run predictions on all intermediate elements.")
     for pending in self.batches.values():
         for prediction in self.make_predictions(pending):
             yield WindowedValue(
                 value=prediction,
                 timestamp=int(time.time()),
                 windows=(window.GlobalWindow(),))
     self.batches = {}
Example #4
0
 def test_global_window_coder(self):
   """GlobalWindowCoder round-trips the global window through empty bytes."""
   global_coder = coders.GlobalWindowCoder()
   global_window = window.GlobalWindow()
   # The cloud-object form only records the coder kind.
   self.assertEqual({'@type': 'kind:global_window'},
                    global_coder.as_cloud_object())
   # The binary form is expected to be empty in both directions.
   self.assertEqual(b'', global_coder.encode(global_window))
   self.assertEqual(global_window, global_coder.decode(b''))
   # Round-trip both standalone and as components of a nested coder.
   self.check_coder(global_coder, global_window)
   pair_coder = coders.TupleCoder((global_coder, global_coder))
   self.check_coder(pair_coder, (global_window, global_window))
Example #5
0
  def _window(self, output, add_window=False):
    """Forces an output into the global window.

    While 'process' will output to the same window as its incomming element,
    'finish_bundle' has to specify BatchedInferencea window to output into.
    Since we are dealing with a bounded input, we can use 'GlobalWindow'.

    Args:
      output: The function output that may need to be added to a window.
      add_window: Adds output to the GlobalWindow.

    Returns:
      output or output encapsulated in 'WindowedValue'.
    """
    if add_window:
      return windowed_value.WindowedValue(output, -1, [window.GlobalWindow()])
    return output
 def process(self, element):
     """Aggregates elements into per-(dataset, shape) batches and predicts.

     Elements are grouped by ``(element["dataset"], element["image"].shape)``.
     When a group reaches ``self.batch_size`` elements it is run through
     ``self.make_predictions`` and each prediction is yielded in the global
     window, stamped with the current wall-clock time in whole seconds.

     When the number of pending groups reaches ``self.batch_size``, all
     pending groups are flushed through ``finish_bundle`` and its outputs are
     yielded from here.

     Args:
       element: a mapping with at least an ``"image"`` entry (exposing
         ``.shape``) and a ``"dataset"`` entry.

     Yields:
       ``WindowedValue`` objects for predictions made here, plus whatever
       ``finish_bundle`` yields on a flush.
     """
     shape = element["image"].shape
     dataset = element["dataset"]
     key = (dataset, shape)
     batch = self.batches.setdefault(key, [])
     batch.append(element)
     if len(batch) >= self.batch_size:
         outputs = self.make_predictions(batch)
         del self.batches[key]
         for output in outputs:
             yield WindowedValue(
                 value=output,
                 timestamp=int(time.time()),
                 windows=(window.GlobalWindow(),))
     elif len(self.batches) >= self.batch_size:
         # TODO: fix so that it's counting total number of elements, not shapes
         logging.info("Intermediate storage too large. Flushing.")
         # finish_bundle is a generator: merely calling it would neither run
         # the predictions nor clear the pending batches, so its outputs must
         # be iterated and re-yielded here.
         for output in self.finish_bundle():
             yield output
Example #7
0
  def _test_read_messages_success(self, mock_pubsub):
    """Reads one attributed message through a fake Pub/Sub client.

    The element is expected to carry the payload and attributes and to be
    timestamped with the message's publish time in the global window.
    """
    body = 'payload'
    msg_id = 'message_id'
    rfc3339_time = '2018-03-12T13:37:01.234567Z'
    attrs = {'key': 'value'}
    client_messages = [
        create_client_message(body, msg_id, attrs, rfc3339_time)]
    expected = [
        TestWindowedValue(PubsubMessage(body, attrs),
                          timestamp.Timestamp(1520861821.234567),
                          [window.GlobalWindow()]),
    ]

    # Route the mocked module's client and auto-ack helpers to the fakes.
    mock_pubsub.Client = functools.partial(FakePubsubClient, client_messages)
    mock_pubsub.subscription.AutoAck = FakeAutoAck

    pipeline = TestPipeline()
    pipeline.options.view_as(StandardOptions).streaming = True
    messages = pipeline | ReadFromPubSub(
        'projects/fakeprj/topics/a_topic', None, 'a_label',
        with_attributes=True)
    assert_that(messages, equal_to(expected), reify_windows=True)
    pipeline.run()
Example #8
0
    def test_no_window_context_fails(self):
        """IdentityWindowFn must reject elements whose window context is None."""
        ts = timestamp.Timestamp(5)
        # Assuming the default window function is window.GlobalWindows.
        global_window = window.GlobalWindow()

        class AddTimestampDoFn(beam.DoFn):
            def process(self, element):
                yield window.TimestampedValue(element, ts)

        pipeline = TestPipeline()
        data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
        expected_windows = [TestWindowedValue(kv, ts, [global_window])
                            for kv in data]

        timestamped = (
            pipeline
            | 'start' >> beam.Create(data)
            | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
        assert_that(timestamped, equal_to(expected_windows),
                    label='before_identity', reify_windows=True)

        identity_fn = beam.transforms.util._IdentityWindowFn(
            coders.GlobalWindowCoder())
        # The second DoFn returns TimestampedValues, so the
        # WindowFn.AssignContext handed to IdentityWindowFn holds a window of
        # None, which IdentityWindowFn should reject with an exception.
        retimestamped = (
            timestamped
            | 'window' >> beam.WindowInto(identity_fn)
            | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
        assert_that(retimestamped, equal_to(expected_windows),
                    label='after_identity', reify_windows=True)

        with self.assertRaisesRegex(ValueError,
                                    r'window.*None.*add_timestamps2'):
            pipeline.run()
def globally_windowed_value():
  """Returns a ``small_int()`` value at timestamp 12345678 in the global window."""
  value = small_int()
  windows = (window.GlobalWindow(), )
  return windowed_value.WindowedValue(
      value=value, timestamp=12345678, windows=windows)
Example #10
0
 def finish_bundle(self):
   """Emits the flushed remaining batch, if any, into the global window."""
   if not self._batch:
     return
   yield WindowedValue(self._flush_batch(), -1, [window.GlobalWindow()])
class StandardCodersTest(unittest.TestCase):
  """Checks coders against the examples in STANDARD_CODERS_YAML.

  Each test case names a coder spec (a URN plus optional component specs) and
  maps expected encodings to JSON-described values. When ``fix`` is True,
  mismatching encodings are collected in ``to_fix`` and written back to the
  YAML file by ``tearDownClass`` instead of failing.
  """

  # Coder URN -> factory called with the already-parsed component coders.
  _urn_to_coder_class = {
      'urn:beam:coders:bytes:0.1': coders.BytesCoder,
      'urn:beam:coders:varint:0.1': coders.VarIntCoder,
      'urn:beam:coders:kv:0.1': lambda k, v: coders.TupleCoder((k, v)),
      'urn:beam:coders:interval_window:0.1': coders.IntervalWindowCoder,
      'urn:beam:coders:stream:0.1': lambda t: coders.IterableCoder(t),
      'urn:beam:coders:global_window:0.1': coders.GlobalWindowCoder,
      'urn:beam:coders:windowed_value:0.1':
          lambda v, w: coders.WindowedValueCoder(v, w)
  }

  # Coder URN -> function turning a JSON example into the expected Python
  # value; parsers for the component coders follow as positional arguments.
  _urn_to_json_value_parser = {
      'urn:beam:coders:bytes:0.1': lambda x: x,
      'urn:beam:coders:varint:0.1': lambda x: x,
      'urn:beam:coders:kv:0.1':
          lambda x, key_parser, value_parser: (key_parser(x['key']),
                                               value_parser(x['value'])),
      'urn:beam:coders:interval_window:0.1':
          lambda x: IntervalWindow(
              start=Timestamp(micros=(x['end'] - x['span']) * 1000),
              end=Timestamp(micros=x['end'] * 1000)),
      # list() keeps the parsed value a list on Python 3, where map() returns
      # a lazy iterator that cannot compare equal to a decoded list; on
      # Python 2 map() already returns a list, so nothing changes there.
      'urn:beam:coders:stream:0.1': lambda x, parser: list(map(parser, x)),
      'urn:beam:coders:global_window:0.1': lambda x: window.GlobalWindow(),
      'urn:beam:coders:windowed_value:0.1':
          lambda x, value_parser, window_parser: windowed_value.create(
              value_parser(x['value']), x['timestamp'] * 1000,
              tuple([window_parser(w) for w in x['windows']]))
  }

  def test_standard_coders(self):
    """Runs every test case loaded from STANDARD_CODERS_YAML."""
    for name, spec in _load_test_cases(STANDARD_CODERS_YAML):
      logging.info('Executing %s test.', name)
      self._run_standard_coder(name, spec)

  def _run_standard_coder(self, name, spec):
    """Checks one test case's examples in nested and/or unnested form.

    Deterministic coders must encode each value to exactly the expected bytes
    (or, in fix mode, have the mismatch recorded) and decode those bytes back
    to the value; non-deterministic coders are only checked for decoding.
    """
    coder = self.parse_coder(spec['coder'])
    parse_value = self.json_value_parser(spec['coder'])
    nested_list = [spec['nested']] if 'nested' in spec else [True, False]
    for nested in nested_list:
      for expected_encoded, json_value in spec['examples'].items():
        value = parse_value(json_value)
        expected_encoded = expected_encoded.encode('latin1')
        if not spec['coder'].get('non_deterministic', False):
          actual_encoded = encode_nested(coder, value, nested)
          if self.fix and actual_encoded != expected_encoded:
            self.to_fix[spec['index'], expected_encoded] = actual_encoded
          else:
            self.assertEqual(expected_encoded, actual_encoded)
            self.assertEqual(decode_nested(coder, expected_encoded, nested),
                             value)
        else:
          # Only verify decoding for a non-deterministic coder
          self.assertEqual(decode_nested(coder, expected_encoded, nested),
                           value)

  def parse_coder(self, spec):
    """Builds the coder for ``spec``, recursively parsing its components."""
    return self._urn_to_coder_class[spec['urn']](
        *[self.parse_coder(c) for c in spec.get('components', ())])

  def json_value_parser(self, coder_spec):
    """Returns a function mapping a JSON example to the expected value."""
    component_parsers = [
        self.json_value_parser(c) for c in coder_spec.get('components', ())]
    return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
        x, *component_parsers)

  # Used when --fix is passed.

  fix = False
  to_fix = {}

  @classmethod
  def tearDownClass(cls):
    """In fix mode, rewrites STANDARD_CODERS_YAML with the recorded encodings."""
    if cls.fix and cls.to_fix:
      # Single-argument prints behave identically under the Python 2 print
      # statement and the Python 3 print function.
      print("FIXING %d TESTS" % len(cls.to_fix))
      doc_sep = '\n---\n'
      with open(STANDARD_CODERS_YAML) as yaml_file:
        docs = yaml_file.read().split(doc_sep)

      def quote(s):
        return json.dumps(s.decode('latin1')).replace(r'\u0000', r'\0')
      for (doc_ix, expected_encoded), actual_encoded in cls.to_fix.items():
        print("%s -> %s" % (quote(expected_encoded), quote(actual_encoded)))
        docs[doc_ix] = docs[doc_ix].replace(
            quote(expected_encoded) + ':', quote(actual_encoded) + ':')
      # Close the file deterministically so the rewrite is fully flushed.
      with open(STANDARD_CODERS_YAML, 'w') as yaml_file:
        yaml_file.write(doc_sep.join(docs))
Example #12
0
    def __init__(
            self,
            fn,
            args,
            kwargs,
            side_inputs,
            windowing,
            context=None,
            tagged_receivers=None,
            logger=None,
            step_name=None,
            # Preferred alternative to logger
            # TODO(robertwb): Remove once all runners are updated.
            logging_context=None,
            # Preferred alternative to context
            # TODO(robertwb): Remove once all runners are updated.
            state=None,
            scoped_metrics_container=None):
        """Initializes a DoFnRunner.

    Args:
      fn: user DoFn to invoke
      args: positional side input arguments (static and placeholder), if any
      kwargs: keyword side input arguments (static and placeholder), if any
      side_inputs: list of sideinput.SideInputMaps for deferred side inputs
      windowing: windowing properties of the output PCollection(s)
      context: a DoFnContext to use (deprecated)
      tagged_receivers: a dict of tag name to Receiver objects; the entry
        under the ``None`` key is used as the main output
      logger: a logging module (deprecated)
      step_name: the name of this step
      logging_context: a LoggingContext object; when omitted, one is built
        from ``logger`` and ``step_name``
      state: handle for accessing DoFn state; when given, ``context`` must be
        None and a new DoFnContext is created, otherwise ``context`` is
        required (both checked with ``assert``)
      scoped_metrics_container: Context switcher for metrics container; a new
        ScopedMetricsContainer is created when omitted

    NOTE(review): ``tagged_receivers`` defaults to None but is indexed with
    ``[None]`` unconditionally below, so callers appear to be required to
    pass it.
    """
        self.step_name = step_name
        self.window_fn = windowing.windowfn
        self.tagged_receivers = tagged_receivers
        self.scoped_metrics_container = (scoped_metrics_container
                                         or ScopedMetricsContainer())

        # Used below to resolve globally windowed side inputs up front.
        global_window = window.GlobalWindow()

        # Need to support multiple iterations.
        side_inputs = list(side_inputs)

        # Prefer an explicit logging context; otherwise derive one from the
        # deprecated ``logger`` argument.
        if logging_context:
            self.logging_context = logging_context
        else:
            self.logging_context = get_logging_context(logger,
                                                       step_name=step_name)

        # Optimize for the common case.
        self.main_receivers = as_receiver(tagged_receivers[None])

        # TODO(sourabh): Deprecate the use of context
        # Exactly one of ``state`` and ``context`` is expected.
        if state:
            assert context is None
            self.context = DoFnContext(self.step_name, state=state)
        else:
            assert context is not None
            self.context = context

        # TODO(Sourabhbajaj): Remove the usage of OldDoFn
        if isinstance(fn, core.NewDoFn):
            self.is_new_dofn = True

            # Stash values for use in new_dofn_process.
            self.side_inputs = side_inputs
            self.has_windowed_side_inputs = not all(si.is_globally_windowed()
                                                    for si in self.side_inputs)

            self.args = args if args else []
            self.kwargs = kwargs if kwargs else {}
            self.dofn = fn

        else:
            # Legacy OldDoFn path: bind args/kwargs/side inputs into a single
            # ``dofn_process(context)`` callable.
            self.is_new_dofn = False
            self.has_windowed_side_inputs = False  # Set to True in one case below.
            if not args and not kwargs:
                # No args/kwargs: call the user's process directly
                # (side_inputs are not consulted on this path).
                self.dofn = fn
                self.dofn_process = fn.process
            else:
                # If every side input is globally windowed, insert their
                # global-window values into args/kwargs once (via
                # util.insert_values_in_args) and clear side_inputs.
                if side_inputs and all(side_input.is_globally_windowed()
                                       for side_input in side_inputs):
                    args, kwargs = util.insert_values_in_args(
                        args, kwargs, [
                            side_input[global_window]
                            for side_input in side_inputs
                        ])
                    side_inputs = []
                if side_inputs:
                    # Windowed side inputs: look their values up on every
                    # call, using the first window of the context.
                    self.has_windowed_side_inputs = True

                    def process(context):
                        w = context.windows[0]
                        cur_args, cur_kwargs = util.insert_values_in_args(
                            args, kwargs,
                            [side_input[w] for side_input in side_inputs])
                        return fn.process(context, *cur_args, **cur_kwargs)

                    self.dofn_process = process
                elif kwargs:
                    self.dofn_process = lambda context: fn.process(
                        context, *args, **kwargs)
                else:
                    self.dofn_process = lambda context: fn.process(
                        context, *args)

                # Expose the bound process alongside the user's original
                # start_bundle/finish_bundle hooks as a DoFn instance.
                class CurriedFn(core.DoFn):

                    start_bundle = staticmethod(fn.start_bundle)
                    process = staticmethod(self.dofn_process)
                    finish_bundle = staticmethod(fn.finish_bundle)

                self.dofn = CurriedFn()
Example #13
0
    def new_dofn_process(self, element):
        """Invokes the NewDoFn's ``process`` for ``element``.

        The call's positional arguments are built from three sources:
        - the element value;
        - the stashed side-input ``args``/``kwargs``;
        - the DoFn's default-valued parameters (ElementParam, ContextParam,
          WindowParam, TimestampParam, SideInputParam).

        The returned outputs are passed to ``_process_outputs``.
        ``process`` is called once per window of the element when there are
        windowed side inputs or ``process`` requests WindowParam. Otherwise
        it is called once, with the global window.
        """
        self.context.set_element(element)
        arguments, _, _, defaults = self.dofn.get_function_arguments('process')
        defaults = defaults if defaults else []

        self_in_args = int(self.dofn.is_process_bounded())

        # Call for the process function for each window if has windowed side inputs
        # or if the process accesses the window parameter. We can just call it once
        # otherwise as none of the arguments are changing
        if self.has_windowed_side_inputs or core.NewDoFn.WindowParam in defaults:
            windows = element.windows
        else:
            windows = [window.GlobalWindow()]

        for w in windows:
            args, kwargs = util.insert_values_in_args(
                self.args, self.kwargs, [s[w] for s in self.side_inputs])

            # If there are more arguments than the default then the first argument
            # should be the element and the rest should be picked from the side
            # inputs as window and timestamp should always be tagged
            if len(arguments) > len(defaults) + self_in_args:
                if core.NewDoFn.ElementParam not in defaults:
                    args_to_pick = len(arguments) - len(
                        defaults) - 1 - self_in_args
                    final_args = [element.value] + args[:args_to_pick]
                else:
                    args_to_pick = len(arguments) - len(
                        defaults) - self_in_args
                    final_args = args[:args_to_pick]
            else:
                args_to_pick = 0
                final_args = []
            args = iter(args[args_to_pick:])

            for a, d in zip(arguments[-len(defaults):], defaults):
                if d == core.NewDoFn.ElementParam:
                    final_args.append(element.value)
                elif d == core.NewDoFn.ContextParam:
                    final_args.append(self.context)
                elif d == core.NewDoFn.WindowParam:
                    final_args.append(w)
                elif d == core.NewDoFn.TimestampParam:
                    final_args.append(element.timestamp)
                elif d == core.NewDoFn.SideInputParam:
                    # If no more args are present then the value must be passed via kwarg
                    # next() (not the Python 2-only iterator .next() method)
                    # keeps this working on both Python 2 and Python 3.
                    try:
                        final_args.append(next(args))
                    except StopIteration:
                        if a not in kwargs:
                            raise
                else:
                    # If no more args are present then the value must be passed via kwarg
                    try:
                        final_args.append(next(args))
                    except StopIteration:
                        if a not in kwargs:
                            kwargs[a] = d
            final_args.extend(list(args))
            self._process_outputs(element,
                                  self.dofn.process(*final_args, **kwargs))
Example #14
0
def _global_window_mapping_fn(w, global_window=window.GlobalWindow()):
    """Maps any window ``w`` to the shared global window.

    ``global_window`` is created once, when the function is defined, so every
    call returns that same object regardless of ``w``.
    """
    del w  # Unused: every window maps to the global window.
    return global_window
Example #15
0
 def finish_bundle(self, context=None):
   """Emits the cached elements, if any, as one value in the global window."""
   from apache_beam.transforms import window
   from apache_beam.utils.windowed_value import WindowedValue
   if not len(self._cached):
     return
   yield WindowedValue(self._cached, -1, [window.GlobalWindow()])
Example #16
0
 def finish_bundle(self, *args, **kwargs):
     """Emits the pending batch, if non-empty, into the global window.

     Any positional or keyword arguments are accepted and ignored.
     """
     pending = self._batch
     if pending:
         yield WindowedValue(pending, -1, [window.GlobalWindow()])
class StandardCodersTest(unittest.TestCase):
    """Checks coders against the examples in STANDARD_CODERS_YAML.

    Each test case names a coder spec (a URN, optional payload and component
    specs) and maps expected encodings to JSON-described values. When ``fix``
    is True, mismatching encodings are collected in ``to_fix`` and written
    back to the YAML file by ``tearDownClass`` instead of failing.
    """

    # Coder URN -> function turning a JSON example into the expected Python
    # value; parsers for the component coders follow as positional arguments.
    _urn_to_json_value_parser = {
        'beam:coder:bytes:v1':
        lambda x: x.encode('utf-8'),
        'beam:coder:bool:v1':
        lambda x: x,
        'beam:coder:string_utf8:v1':
        lambda x: x,
        'beam:coder:varint:v1':
        lambda x: x,
        'beam:coder:kv:v1':
        lambda x, key_parser, value_parser:
        (key_parser(x['key']), value_parser(x['value'])),
        'beam:coder:interval_window:v1':
        lambda x: IntervalWindow(start=Timestamp(micros=(x['end'] - x['span'])
                                                 * 1000),
                                 end=Timestamp(micros=x['end'] * 1000)),
        'beam:coder:iterable:v1':
        lambda x, parser: list(map(parser, x)),
        'beam:coder:global_window:v1':
        lambda x: window.GlobalWindow(),
        'beam:coder:windowed_value:v1':
        lambda x, value_parser, window_parser: windowed_value.create(
            value_parser(x['value']), x['timestamp'] * 1000,
            tuple([window_parser(w) for w in x['windows']])),
        'beam:coder:param_windowed_value:v1':
        lambda x, value_parser, window_parser: windowed_value.create(
            value_parser(x['value']), x['timestamp'] * 1000,
            tuple([window_parser(w) for w in x['windows']]),
            PaneInfo(x['pane']['is_first'], x['pane']['is_last'],
                     PaneInfoTiming.from_string(x['pane']['timing']),
                     x['pane']['index'], x['pane']['on_time_index'])),
        'beam:coder:timer:v1':
        lambda x, value_parser, window_parser: userstate.Timer(
            user_key=value_parser(x['userKey']),
            dynamic_timer_tag=x['dynamicTimerTag'],
            clear_bit=x['clearBit'],
            windows=tuple([window_parser(w) for w in x['windows']]),
            fire_timestamp=None,
            hold_timestamp=None,
            paneinfo=None)
        if x['clearBit'] else userstate.Timer(
            user_key=value_parser(x['userKey']),
            dynamic_timer_tag=x['dynamicTimerTag'],
            clear_bit=x['clearBit'],
            fire_timestamp=Timestamp(micros=x['fireTimestamp'] * 1000),
            hold_timestamp=Timestamp(micros=x['holdTimestamp'] * 1000),
            windows=tuple([window_parser(w) for w in x['windows']]),
            paneinfo=PaneInfo(x['pane']['is_first'], x['pane']['is_last'],
                              PaneInfoTiming.from_string(x['pane']['timing']),
                              x['pane']['index'], x['pane']['on_time_index'])),
        'beam:coder:double:v1':
        parse_float,
        'beam:coder:sharded_key:v1':
        lambda x, value_parser: ShardedKey(
            key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')),
        'beam:coder:custom_window:v1':
        lambda x, window_parser: window_parser(x['window'])
    }

    def test_standard_coders(self):
        """Runs every test case loaded from STANDARD_CODERS_YAML."""
        for name, spec in _load_test_cases(STANDARD_CODERS_YAML):
            logging.info('Executing %s test.', name)
            self._run_standard_coder(name, spec)

    def _run_standard_coder(self, name, spec):
        """Checks one test case's examples in nested and/or unnested form.

        Deterministic coders must encode each value to exactly the expected
        bytes (or, in fix mode, have the mismatch recorded) and decode those
        bytes back to the value; non-deterministic coders are only checked
        for decoding.
        """
        def assert_equal(actual, expected):
            """Handle nan values which self.assertEqual fails on."""
            if (isinstance(actual, float) and isinstance(expected, float)
                    and math.isnan(actual) and math.isnan(expected)):
                return
            self.assertEqual(actual, expected)

        coder = self.parse_coder(spec['coder'])
        parse_value = self.json_value_parser(spec['coder'])
        nested_list = [spec['nested']] if 'nested' in spec else [True, False]
        for nested in nested_list:
            for expected_encoded, json_value in spec['examples'].items():
                value = parse_value(json_value)
                expected_encoded = expected_encoded.encode('latin1')
                if not spec['coder'].get('non_deterministic', False):
                    actual_encoded = encode_nested(coder, value, nested)
                    if self.fix and actual_encoded != expected_encoded:
                        self.to_fix[spec['index'],
                                    expected_encoded] = actual_encoded
                    else:
                        self.assertEqual(expected_encoded, actual_encoded)
                        decoded = decode_nested(coder, expected_encoded,
                                                nested)
                        assert_equal(decoded, value)
                else:
                    # Only verify decoding for a non-deterministic coder.
                    # Use the NaN-aware comparison here too, matching the
                    # deterministic branch.
                    assert_equal(
                        decode_nested(coder, expected_encoded, nested), value)

    def parse_coder(self, spec):
        """Builds the coder described by ``spec`` through a PipelineContext.

        Component specs are parsed recursively and registered first; the spec
        itself is then registered as a Coder proto and resolved by its id.
        """
        context = pipeline_context.PipelineContext()
        coder_id = str(hash(str(spec)))
        component_ids = [
            context.coders.get_id(self.parse_coder(c))
            for c in spec.get('components', ())
        ]
        context.coders.put_proto(
            coder_id,
            beam_runner_api_pb2.Coder(spec=beam_runner_api_pb2.FunctionSpec(
                urn=spec['urn'],
                payload=spec.get('payload', '').encode('latin1')),
                                      component_coder_ids=component_ids))
        return context.coders.get_by_id(coder_id)

    def json_value_parser(self, coder_spec):
        """Returns a function mapping a JSON example to the expected value."""
        # TODO: integrate this with the logic for the other parsers
        if coder_spec['urn'] == 'beam:coder:row:v1':
            schema = schema_pb2.Schema.FromString(
                coder_spec['payload'].encode('latin1'))
            return value_parser_from_schema(schema)

        component_parsers = [
            self.json_value_parser(c)
            for c in coder_spec.get('components', ())
        ]
        return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
            x, *component_parsers)

    # Used when --fix is passed.

    fix = False
    to_fix = {}  # type: Dict[Tuple[int, bytes], bytes]

    @classmethod
    def tearDownClass(cls):
        """In fix mode, rewrites STANDARD_CODERS_YAML with recorded encodings."""
        if cls.fix and cls.to_fix:
            print("FIXING", len(cls.to_fix), "TESTS")
            doc_sep = '\n---\n'
            with open(STANDARD_CODERS_YAML) as yaml_file:
                docs = yaml_file.read().split(doc_sep)

            def quote(s):
                return json.dumps(s.decode('latin1')).replace(r'\u0000', r'\0')

            for (doc_ix,
                 expected_encoded), actual_encoded in cls.to_fix.items():
                print(quote(expected_encoded), "->", quote(actual_encoded))
                docs[doc_ix] = docs[doc_ix].replace(
                    quote(expected_encoded) + ':',
                    quote(actual_encoded) + ':')
            # Close the file deterministically so the rewrite is fully flushed.
            with open(STANDARD_CODERS_YAML, 'w') as yaml_file:
                yaml_file.write(doc_sep.join(docs))
Example #18
0
def _global_window_mapping_fn(w, global_window=window.GlobalWindow()):
  # type: (...) -> window.GlobalWindow
  """Ignores ``w`` and returns the global window bound at definition time."""
  del w  # Every window maps to the same global window.
  return global_window
Example #19
0
 def finish_bundle(self, element=None):
     """Emits the cached elements, if any, as one global-window value."""
     cached = self._cached
     if not len(cached):  # pylint: disable=g-explicit-length-test
         return
     yield WindowedValue(cached, -1, [window.GlobalWindow()])
Example #20
0
 def finish_bundle(self):
   """Closes the open writer, if any, and emits its result.

   The value returned by ``self.writer.close()`` is emitted in the global
   window at the global window's maximum timestamp.
   """
   if self.writer is None:
     return
   result = self.writer.close()
   end_of_window = window.GlobalWindow().max_timestamp()
   yield WindowedValue(result, end_of_window, [window.GlobalWindow()])
Example #21
0
 def __init__(self):
     """Passes a fresh GlobalWindow to the base-class constructor.

     Presumably the base class is a single-value coder; confirm against the
     class definition.
     """
     from apache_beam.transforms import window
     singleton_window = window.GlobalWindow()
     super(GlobalWindowCoder, self).__init__(singleton_window)
Example #22
0
  def finish_bundle(self, element=None):
    """Yields the cached elements, if any, as a single global-window value."""
    from apache_beam.transforms import window
    from apache_beam.utils.windowed_value import WindowedValue

    if not len(self._cached):  # pylint: disable=g-explicit-length-test
      return
    yield WindowedValue(self._cached, -1, [window.GlobalWindow()])
Example #23
0
class StandardCodersTest(unittest.TestCase):
    """Checks coders against the examples in STANDARD_CODERS_YAML.

    Each test case names a coder spec (a URN, optional payload and component
    specs) and maps expected encodings to JSON-described values. When ``fix``
    is True, mismatching encodings are collected in ``to_fix`` and written
    back to the YAML file by ``tearDownClass`` instead of failing.
    """

    # Coder URN -> function turning a JSON example into the expected Python
    # value; parsers for the component coders follow as positional arguments.
    _urn_to_json_value_parser = {
        'beam:coder:bytes:v1':
        lambda x: x.encode('utf-8'),
        'beam:coder:string_utf8:v1':
        lambda x: x,
        'beam:coder:varint:v1':
        lambda x: x,
        'beam:coder:kv:v1':
        lambda x, key_parser, value_parser:
        (key_parser(x['key']), value_parser(x['value'])),
        'beam:coder:interval_window:v1':
        lambda x: IntervalWindow(start=Timestamp(micros=(x['end'] - x['span'])
                                                 * 1000),
                                 end=Timestamp(micros=x['end'] * 1000)),
        'beam:coder:iterable:v1':
        lambda x, parser: list(map(parser, x)),
        'beam:coder:global_window:v1':
        lambda x: window.GlobalWindow(),
        'beam:coder:windowed_value:v1':
        lambda x, value_parser, window_parser: windowed_value.create(
            value_parser(x['value']), x['timestamp'] * 1000,
            tuple([window_parser(w) for w in x['windows']])),
        'beam:coder:timer:v1':
        lambda x, payload_parser: dict(payload=payload_parser(x['payload']),
                                       timestamp=Timestamp(micros=x['timestamp'
                                                                    ] * 1000)),
        'beam:coder:double:v1':
        parse_float,
    }

    def test_standard_coders(self):
        """Runs every test case loaded from STANDARD_CODERS_YAML."""
        for name, spec in _load_test_cases(STANDARD_CODERS_YAML):
            logging.info('Executing %s test.', name)
            self._run_standard_coder(name, spec)

    def _run_standard_coder(self, name, spec):
        """Checks one test case's examples in nested and/or unnested form.

        Deterministic coders must encode each value to exactly the expected
        bytes (or, in fix mode, have the mismatch recorded) and decode those
        bytes back to the value; non-deterministic coders are only checked
        for decoding.
        """
        def assert_equal(actual, expected):
            """Handle nan values which self.assertEqual fails on."""
            if (isinstance(actual, float) and isinstance(expected, float)
                    and math.isnan(actual) and math.isnan(expected)):
                return
            self.assertEqual(actual, expected)

        coder = self.parse_coder(spec['coder'])
        parse_value = self.json_value_parser(spec['coder'])
        nested_list = [spec['nested']] if 'nested' in spec else [True, False]
        for nested in nested_list:
            for expected_encoded, json_value in spec['examples'].items():
                value = parse_value(json_value)
                expected_encoded = expected_encoded.encode('latin1')
                if not spec['coder'].get('non_deterministic', False):
                    actual_encoded = encode_nested(coder, value, nested)
                    if self.fix and actual_encoded != expected_encoded:
                        self.to_fix[spec['index'],
                                    expected_encoded] = actual_encoded
                    else:
                        self.assertEqual(expected_encoded, actual_encoded)
                        decoded = decode_nested(coder, expected_encoded,
                                                nested)
                        assert_equal(decoded, value)
                else:
                    # Only verify decoding for a non-deterministic coder.
                    # Use the NaN-aware comparison here too, matching the
                    # deterministic branch.
                    assert_equal(
                        decode_nested(coder, expected_encoded, nested), value)

    def parse_coder(self, spec):
        """Builds the coder described by ``spec`` through a PipelineContext.

        Component specs are parsed recursively and registered first; the spec
        itself is then registered as a Coder proto and resolved by its id.
        """
        context = pipeline_context.PipelineContext()
        coder_id = str(hash(str(spec)))
        component_ids = [
            context.coders.get_id(self.parse_coder(c))
            for c in spec.get('components', ())
        ]
        # FunctionSpec.payload takes bytes while YAML yields str (or nothing),
        # so encode it, defaulting to an empty payload when absent.
        context.coders.put_proto(
            coder_id,
            beam_runner_api_pb2.Coder(spec=beam_runner_api_pb2.FunctionSpec(
                urn=spec['urn'],
                payload=spec.get('payload', '').encode('latin1')),
                                      component_coder_ids=component_ids))
        return context.coders.get_by_id(coder_id)

    def json_value_parser(self, coder_spec):
        """Returns a function mapping a JSON example to the expected value."""
        component_parsers = [
            self.json_value_parser(c)
            for c in coder_spec.get('components', ())
        ]
        return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
            x, *component_parsers)

    # Used when --fix is passed.

    fix = False
    to_fix = {}

    @classmethod
    def tearDownClass(cls):
        """In fix mode, rewrites STANDARD_CODERS_YAML with recorded encodings."""
        if cls.fix and cls.to_fix:
            print("FIXING", len(cls.to_fix), "TESTS")
            doc_sep = '\n---\n'
            with open(STANDARD_CODERS_YAML) as yaml_file:
                docs = yaml_file.read().split(doc_sep)

            def quote(s):
                return json.dumps(s.decode('latin1')).replace(r'\u0000', r'\0')

            for (doc_ix,
                 expected_encoded), actual_encoded in cls.to_fix.items():
                print(quote(expected_encoded), "->", quote(actual_encoded))
                docs[doc_ix] = docs[doc_ix].replace(
                    quote(expected_encoded) + ':',
                    quote(actual_encoded) + ':')
            # Close the file deterministically so the rewrite is fully flushed.
            with open(STANDARD_CODERS_YAML, 'w') as yaml_file:
                yaml_file.write(doc_sep.join(docs))
Example #24
0
 def finish_bundle(self):
     """Closes the writer, if one is open, and emits its result at MAX_TIMESTAMP."""
     writer = self.writer
     if writer is None:
         return
     yield WindowedValue(writer.close(), window.MAX_TIMESTAMP,
                         [window.GlobalWindow()])