def test_sharded_key_coder(self):
        r"""Check ShardedKeyCoder encoding, round-tripping and nesting.

        Each ``key_and_coders`` entry is ``(key, encoded_key, key_coder)``,
        where ``encoded_key`` is the expected encoding of ``key`` inside a
        ShardedKey. The assertions below expect a ShardedKey to encode as the
        shard id's bytes (``b'\x00'`` for ``b''``, ``b'\x03123'`` for
        ``b'123'``) followed by ``encoded_key``.
        """
        key_and_coders = [(b'', b'\x00', coders.BytesCoder()),
                          (b'key', b'\x03key', coders.BytesCoder()),
                          # The UTF-8 str key is expected to encode to the same
                          # bytes as the equivalent bytes key above.
                          ('key', b'\x03key', coders.StrUtf8Coder()),
                          (('k', 1), b'\x01\x6b\x01',
                           coders.TupleCoder(
                               (coders.StrUtf8Coder(), coders.VarIntCoder())))]

        for key, bytes_repr, key_coder in key_and_coders:
            coder = coders.ShardedKeyCoder(key_coder)
            # Verify cloud object representation
            self.assertEqual(
                {
                    '@type': 'kind:sharded_key',
                    'component_encodings': [key_coder.as_cloud_object()]
                }, coder.as_cloud_object())
            # Encoding is the shard id's bytes followed by the key's bytes.
            self.assertEqual(b'\x00' + bytes_repr,
                             coder.encode(ShardedKey(key, b'')))
            self.assertEqual(b'\x03123' + bytes_repr,
                             coder.encode(ShardedKey(key, b'123')))

            # Test unnested
            self.check_coder(coder, ShardedKey(key, b''))
            self.check_coder(coder, ShardedKey(key, b'123'))

            # Pair this coder with every key coder (including itself) so each
            # ShardedKeyCoder is also exercised inside a TupleCoder.
            for other_key, _, other_key_coder in key_and_coders:
                other_coder = coders.ShardedKeyCoder(other_key_coder)
                # Test nested
                self.check_coder(
                    coders.TupleCoder((coder, other_coder)),
                    (ShardedKey(key, b''), ShardedKey(other_key, b'')))
                self.check_coder(
                    coders.TupleCoder((coder, other_coder)),
                    (ShardedKey(key, b'123'), ShardedKey(other_key, b'')))
 def test_type_check_invalid_key_type(self):
     """A str key must fail ShardedKeyType[int].type_check with this message."""
     constraint = ShardedKeyType[int]
     obj = ShardedKey(key='abc', shard_id=b'123')
     # A single TypeError entry also catches any TypeError subclass raised.
     with self.assertRaises(TypeError) as e:
         constraint.type_check(obj)
     self.assertEqual(
         "ShardedKey[int] type-constraint violated. The type of key in "
         "'ShardedKey' is incorrect. Expected an instance of type 'int', "
         "instead received an instance of type 'str'.", e.exception.args[0])
    def test_sharded_key_coder(self):
        r"""Check ShardedKeyCoder encoding, repr, type hints and nesting.

        Each ``key_and_coders`` entry is ``(key, encoded_key, key_coder)``,
        where ``encoded_key`` is the expected encoding of ``key`` inside a
        ShardedKey. The assertions below expect a ShardedKey to encode as the
        shard id's bytes (``b'\x00'`` for ``b''``, ``b'\x03123'`` for
        ``b'123'``) followed by ``encoded_key``.
        """
        key_and_coders = [(b'', b'\x00', coders.BytesCoder()),
                          (b'key', b'\x03key', coders.BytesCoder()),
                          # The UTF-8 str key is expected to encode to the same
                          # bytes as the equivalent bytes key above.
                          ('key', b'\x03key', coders.StrUtf8Coder()),
                          (('k', 1), b'\x01\x6b\x01',
                           coders.TupleCoder(
                               (coders.StrUtf8Coder(), coders.VarIntCoder())))]

        for key, bytes_repr, key_coder in key_and_coders:
            coder = coders.ShardedKeyCoder(key_coder)
            # Verify cloud object representation
            self.assertEqual(
                {
                    '@type': 'kind:sharded_key',
                    'component_encodings': [key_coder.as_cloud_object()]
                }, coder.as_cloud_object())

            # Test str repr
            self.assertEqual('%s' % coder, 'ShardedKeyCoder[%s]' % key_coder)

            # Encoding is the shard id's bytes followed by the key's bytes.
            self.assertEqual(b'\x00' + bytes_repr,
                             coder.encode(ShardedKey(key, b'')))
            self.assertEqual(b'\x03123' + bytes_repr,
                             coder.encode(ShardedKey(key, b'123')))

            # Test unnested
            self.check_coder(coder, ShardedKey(key, b''))
            self.check_coder(coder, ShardedKey(key, b'123'))

            # Test type hints: the hint must be a ShardedKey constraint whose
            # key type matches the Python type of ``key`` (element-wise for
            # tuple keys), and it must map back to an equal coder.
            type_hint = coder.to_type_hint()
            self.assertIsInstance(type_hint,
                                  sharded_key_type.ShardedKeyTypeConstraint)
            key_type = type_hint.key_type
            if isinstance(key_type, typehints.TupleConstraint):
                self.assertEqual(key_type.tuple_types,
                                 (type(key[0]), type(key[1])))
            else:
                self.assertEqual(key_type, type(key))
            self.assertEqual(
                coders.ShardedKeyCoder.from_type_hint(
                    type_hint, typecoders.CoderRegistry()), coder)

            # Pair this coder with every key coder (including itself) so each
            # ShardedKeyCoder is also exercised inside a TupleCoder.
            for other_key, _, other_key_coder in key_and_coders:
                other_coder = coders.ShardedKeyCoder(other_key_coder)
                # Test nested
                self.check_coder(
                    coders.TupleCoder((coder, other_coder)),
                    (ShardedKey(key, b''), ShardedKey(other_key, b'')))
                self.check_coder(
                    coders.TupleCoder((coder, other_coder)),
                    (ShardedKey(key, b'123'), ShardedKey(other_key, b'')))
Beispiel #4
0
 def expand(self, pcoll):
     """Re-key each element with a ShardedKey, then batch per sharded key.

     Args:
       pcoll: input PCollection whose elements are indexed as
         ``(key, value)`` pairs.

     Returns:
       The result of applying ``GroupIntoBatches(self.batch_size,
       self.max_buffering_duration_secs)`` to the re-keyed elements.
     """
     sharded_pcoll = pcoll | Map(lambda key_value: (
         ShardedKey(
             key_value[0],
             # Shard id = class-level prefix (presumably a uuid; not visible
             # here) + the current thread id as 8 big-endian bytes, so
             # elements processed on different threads get different shard
             # ids. int.to_bytes already returns bytes; no bytes() wrap needed.
             GroupIntoBatches.WithShardedKey._shard_id_prefix +
             threading.get_ident().to_bytes(8, 'big')),
         key_value[1]))
     return (sharded_pcoll
             | GroupIntoBatches(self.batch_size,
                                self.max_buffering_duration_secs))
Beispiel #5
0
 def expand(self, pcoll):
     """Re-key each element with a typed ShardedKey, then batch per key.

     Assumes ``pcoll.element_type`` is a two-element Tuple hint
     (``tuple_types`` is unpacked into key and value types).

     Args:
       pcoll: input PCollection whose elements are indexed as
         ``(key, value)`` pairs.

     Returns:
       The result of applying ``GroupIntoBatches(self.params.batch_size,
       self.params.max_buffering_duration_secs)`` to the re-keyed elements,
       whose output type is declared as
       ``Tuple[ShardedKeyType[key_type], value_type]``.
     """
     key_type, value_type = pcoll.element_type.tuple_types
     sharded_pcoll = pcoll | Map(lambda key_value: (
         ShardedKey(
             key_value[0],
             # Shard id = class-level prefix (presumably a uuid; not visible
             # here) + the current thread id as 8 big-endian bytes, so
             # elements processed on different threads get different shard
             # ids. int.to_bytes already returns bytes; no bytes() wrap needed.
             GroupIntoBatches.WithShardedKey._shard_id_prefix +
             threading.get_ident().to_bytes(8, 'big')),
         key_value[1])).with_output_types(typehints.Tuple[
             ShardedKeyType[key_type],  # type: ignore[misc]
             value_type])
     return (sharded_pcoll
             |
             GroupIntoBatches(self.params.batch_size,
                              self.params.max_buffering_duration_secs))
 def test_type_check_valid_composite_type(self):
     """A (1, 'a') key passes ShardedKeyType[Tuple[int, str]].type_check."""
     sharded = ShardedKey(key=(1, 'a'), shard_id=b'123')
     # type_check signals success by returning None.
     outcome = ShardedKeyType[Tuple[int, str]].type_check(sharded)
     self.assertIsNone(outcome)
 def test_type_check_valid_simple_type(self):
     """A str key passes ShardedKeyType[str].type_check."""
     sharded = ShardedKey(key='abc', shard_id=b'123')
     # type_check signals success by returning None.
     outcome = ShardedKeyType[str].type_check(sharded)
     self.assertIsNone(outcome)
class StandardCodersTest(unittest.TestCase):
    """Runs the standard coder test cases loaded from STANDARD_CODERS_YAML.

    Each case gives a coder spec and maps expected encodings to JSON-style
    example values. When ``fix`` is True, encodings that do not match are
    recorded in ``to_fix`` instead of asserted, and ``tearDownClass``
    rewrites the YAML file with the actual encodings.
    """

    # Maps a coder URN to a parser that turns a JSON-style example value
    # from the YAML into the Python value the coder should encode/decode.
    # For coders with components, the parser takes one extra argument per
    # component: that component's own value parser, in component order.
    _urn_to_json_value_parser = {
        'beam:coder:bytes:v1':
        lambda x: x.encode('utf-8'),
        'beam:coder:bool:v1':
        lambda x: x,
        'beam:coder:string_utf8:v1':
        lambda x: x,
        'beam:coder:varint:v1':
        lambda x: x,
        'beam:coder:kv:v1':
        lambda x, key_parser, value_parser:
        (key_parser(x['key']), value_parser(x['value'])),
        'beam:coder:interval_window:v1':
        lambda x: IntervalWindow(start=Timestamp(micros=(x['end'] - x['span'])
                                                 * 1000),
                                 end=Timestamp(micros=x['end'] * 1000)),
        'beam:coder:iterable:v1':
        lambda x, parser: list(map(parser, x)),
        'beam:coder:global_window:v1':
        lambda x: window.GlobalWindow(),
        'beam:coder:windowed_value:v1':
        lambda x, value_parser, window_parser: windowed_value.create(
            value_parser(x['value']), x['timestamp'] * 1000,
            tuple([window_parser(w) for w in x['windows']])),
        'beam:coder:param_windowed_value:v1':
        lambda x, value_parser, window_parser: windowed_value.create(
            value_parser(x['value']), x['timestamp'] * 1000,
            tuple([window_parser(w) for w in x['windows']]),
            PaneInfo(x['pane']['is_first'], x['pane']['is_last'],
                     PaneInfoTiming.from_string(x['pane']['timing']), x['pane']
                     ['index'], x['pane']['on_time_index'])),
        'beam:coder:timer:v1':
        lambda x, value_parser, window_parser: userstate.Timer(
            user_key=value_parser(x['userKey']),
            dynamic_timer_tag=x['dynamicTimerTag'],
            clear_bit=x['clearBit'],
            windows=tuple([window_parser(w) for w in x['windows']]),
            fire_timestamp=None,
            hold_timestamp=None,
            paneinfo=None)
        if x['clearBit'] else userstate.Timer(
            user_key=value_parser(x['userKey']),
            dynamic_timer_tag=x['dynamicTimerTag'],
            clear_bit=x['clearBit'],
            fire_timestamp=Timestamp(micros=x['fireTimestamp'] * 1000),
            hold_timestamp=Timestamp(micros=x['holdTimestamp'] * 1000),
            windows=tuple([window_parser(w) for w in x['windows']]),
            paneinfo=PaneInfo(x['pane']['is_first'], x['pane']['is_last'],
                              PaneInfoTiming.from_string(x['pane']['timing']),
                              x['pane']['index'], x['pane']['on_time_index'])),
        'beam:coder:double:v1':
        parse_float,
        'beam:coder:sharded_key:v1':
        lambda x, value_parser: ShardedKey(
            key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')),
        'beam:coder:custom_window:v1':
        lambda x, window_parser: window_parser(x['window'])
    }

    def test_standard_coders(self):
        """Run every case loaded from STANDARD_CODERS_YAML."""
        for name, spec in _load_test_cases(STANDARD_CODERS_YAML):
            logging.info('Executing %s test.', name)
            self._run_standard_coder(name, spec)

    def _run_standard_coder(self, name, spec):
        """Check one case by encoding/decoding each of its examples.

        Args:
          name: case name (not used here; the caller logs it).
          spec: parsed case; reads 'coder', 'examples', the optional
            'nested' flag and, in fix mode, 'index'.
        """
        def assert_equal(actual, expected):
            """Handle nan values which self.assertEqual fails on."""
            if (isinstance(actual, float) and isinstance(expected, float)
                    and math.isnan(actual) and math.isnan(expected)):
                return
            self.assertEqual(actual, expected)

        coder = self.parse_coder(spec['coder'])
        parse_value = self.json_value_parser(spec['coder'])
        # Without an explicit 'nested' flag, check both contexts.
        nested_list = [spec['nested']] if 'nested' in spec else [True, False]
        for nested in nested_list:
            for expected_encoded, json_value in spec['examples'].items():
                value = parse_value(json_value)
                # YAML keys are text; latin-1 maps code points 0-255 one-to-one
                # onto byte values.
                expected_encoded = expected_encoded.encode('latin1')
                if not spec['coder'].get('non_deterministic', False):
                    actual_encoded = encode_nested(coder, value, nested)
                    if self.fix and actual_encoded != expected_encoded:
                        # Record the mismatch; tearDownClass rewrites the file.
                        self.to_fix[spec['index'],
                                    expected_encoded] = actual_encoded
                    else:
                        self.assertEqual(expected_encoded, actual_encoded)
                        decoded = decode_nested(coder, expected_encoded,
                                                nested)
                        assert_equal(decoded, value)
                else:
                    # Only verify decoding for a non-deterministic coder,
                    # using the same NaN-aware comparison as above.
                    assert_equal(
                        decode_nested(coder, expected_encoded, nested), value)

    def parse_coder(self, spec):
        """Build a coder from a YAML coder spec.

        Component specs are parsed recursively. The coder is registered as a
        proto in a fresh PipelineContext and then materialized from it.
        """
        context = pipeline_context.PipelineContext()
        coder_id = str(hash(str(spec)))
        component_ids = [
            context.coders.get_id(self.parse_coder(c))
            for c in spec.get('components', ())
        ]
        context.coders.put_proto(
            coder_id,
            beam_runner_api_pb2.Coder(spec=beam_runner_api_pb2.FunctionSpec(
                urn=spec['urn'],
                payload=spec.get('payload', '').encode('latin1')),
                                      component_coder_ids=component_ids))
        return context.coders.get_by_id(coder_id)

    def json_value_parser(self, coder_spec):
        """Return a function that parses example values for ``coder_spec``.

        Row coders get a schema-driven parser built from the spec payload;
        every other URN uses ``_urn_to_json_value_parser`` with recursively
        built component parsers appended as extra arguments.
        """
        # TODO: integrate this with the logic for the other parsers
        if coder_spec['urn'] == 'beam:coder:row:v1':
            schema = schema_pb2.Schema.FromString(
                coder_spec['payload'].encode('latin1'))
            return value_parser_from_schema(schema)

        component_parsers = [
            self.json_value_parser(c)
            for c in coder_spec.get('components', ())
        ]
        return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
            x, *component_parsers)

    # Used when --fix is passed.

    fix = False
    to_fix = {}  # type: Dict[Tuple[int, bytes], bytes]

    @classmethod
    def tearDownClass(cls):
        """In fix mode, rewrite mismatched encodings in STANDARD_CODERS_YAML.

        The file is split into documents on ``'\\n---\\n'``. In the document
        at each recorded index, every ``"<expected>":`` key is replaced with
        ``"<actual>":``, and the file is written back.
        """
        if cls.fix and cls.to_fix:
            print("FIXING", len(cls.to_fix), "TESTS")
            doc_sep = '\n---\n'
            with open(STANDARD_CODERS_YAML) as yaml_file:
                docs = yaml_file.read().split(doc_sep)

            def quote(s):
                return json.dumps(s.decode('latin1')).replace(r'\u0000', r'\0')

            for (doc_ix,
                 expected_encoded), actual_encoded in cls.to_fix.items():
                print(quote(expected_encoded), "->", quote(actual_encoded))
                docs[doc_ix] = docs[doc_ix].replace(
                    quote(expected_encoded) + ':',
                    quote(actual_encoded) + ':')
            # Use a context manager so the handle is closed (and the data
            # flushed) deterministically.
            with open(STANDARD_CODERS_YAML, 'w') as yaml_file:
                yaml_file.write(doc_sep.join(docs))
 def decode_from_stream(self, in_stream, nested):
   # type: (create_InputStream, bool) -> ShardedKey
   """Read a ShardedKey whose shard id precedes its key in ``in_stream``.

   Both components are decoded in nested mode; the ``nested`` argument is
   not consulted.
   """
   decoded_shard_id = self._shard_id_coder_impl.decode_from_stream(
       in_stream, True)
   decoded_key = self._key_coder_impl.decode_from_stream(in_stream, True)
   return ShardedKey(shard_id=decoded_shard_id, key=decoded_key)