def test_sharded_key_coder(self):
    """Exercise ShardedKeyCoder over several key types.

    For every (key, encoded key, key coder) case this checks the cloud
    object representation, the exact encoded bytes for an empty and a
    non-empty shard id, and round-trips through ``self.check_coder`` both
    on its own and nested inside a TupleCoder with every other case.
    """
    # Each case: (key, expected encoding of the key, coder for the key).
    cases = [
        (b'', b'\x00', coders.BytesCoder()),
        (b'key', b'\x03key', coders.BytesCoder()),
        ('key', b'\03\x6b\x65\x79', coders.StrUtf8Coder()),
        (
            ('k', 1),
            b'\x01\x6b\x01',
            coders.TupleCoder((coders.StrUtf8Coder(), coders.VarIntCoder())),
        ),
    ]
    # Shard ids under test, paired with their expected encoded prefix.
    shard_cases = ((b'', b'\x00'), (b'123', b'\x03123'))

    for key, encoded_key, key_coder in cases:
        coder = coders.ShardedKeyCoder(key_coder)

        # Verify cloud object representation
        expected_cloud_object = {
            '@type': 'kind:sharded_key',
            'component_encodings': [key_coder.as_cloud_object()]
        }
        self.assertEqual(expected_cloud_object, coder.as_cloud_object())

        # The shard id is written first, followed by the key.
        for shard_id, encoded_shard_id in shard_cases:
            self.assertEqual(
                encoded_shard_id + encoded_key,
                coder.encode(ShardedKey(key, shard_id)))

        # Test unnested
        for shard_id, _ in shard_cases:
            self.check_coder(coder, ShardedKey(key, shard_id))

        # Test nested
        for other_key, _, other_key_coder in cases:
            other_coder = coders.ShardedKeyCoder(other_key_coder)
            for shard_id, _ in shard_cases:
                self.check_coder(
                    coders.TupleCoder((coder, other_coder)),
                    (ShardedKey(key, shard_id), ShardedKey(other_key, b'')))
def test_type_check_invalid_key_type(self):
    """A ShardedKey whose key is a str must fail a ShardedKeyType[int] check.

    Verifies both that ``type_check`` raises ``TypeError`` and that the
    first exception argument carries the expected explanatory message.
    """
    constraint = ShardedKeyType[int]
    obj = ShardedKey(key='abc', shard_id=b'123')
    # A single exception class is enough; the original tuple listed
    # TypeError twice.
    with self.assertRaises(TypeError) as e:
        constraint.type_check(obj)
    self.assertEqual(
        "ShardedKey[int] type-constraint violated. The type of key in "
        "'ShardedKey' is incorrect. Expected an instance of type 'int', "
        "instead received an instance of type 'str'.",
        e.exception.args[0])
def test_sharded_key_coder(self):
    """Exercise ShardedKeyCoder over several key types.

    For every (key, encoded key, key coder) case this checks the cloud
    object representation, the string form, the exact encoded bytes, the
    type hint round trip (``to_type_hint`` / ``from_type_hint``), and
    round-trips through ``self.check_coder`` both on its own and nested
    inside a TupleCoder with every other case.
    """
    key_and_coders = [
        (b'', b'\x00', coders.BytesCoder()),
        (b'key', b'\x03key', coders.BytesCoder()),
        ('key', b'\03\x6b\x65\x79', coders.StrUtf8Coder()),
        (
            ('k', 1),
            b'\x01\x6b\x01',
            coders.TupleCoder((coders.StrUtf8Coder(), coders.VarIntCoder())),
        ),
    ]

    for key, bytes_repr, key_coder in key_and_coders:
        coder = coders.ShardedKeyCoder(key_coder)

        # Verify cloud object representation
        self.assertEqual(
            {
                '@type': 'kind:sharded_key',
                'component_encodings': [key_coder.as_cloud_object()]
            },
            coder.as_cloud_object())

        # Test str repr
        self.assertEqual('%s' % coder, 'ShardedKeyCoder[%s]' % key_coder)

        # The shard id is written first, followed by the key.
        self.assertEqual(
            b'\x00' + bytes_repr, coder.encode(ShardedKey(key, b'')))
        self.assertEqual(
            b'\x03123' + bytes_repr, coder.encode(ShardedKey(key, b'123')))

        # Test unnested
        self.check_coder(coder, ShardedKey(key, b''))
        self.check_coder(coder, ShardedKey(key, b'123'))

        # Test type hints
        type_hint = coder.to_type_hint()
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(
            type_hint, sharded_key_type.ShardedKeyTypeConstraint)
        key_type = type_hint.key_type
        if isinstance(key_type, typehints.TupleConstraint):
            self.assertEqual(
                key_type.tuple_types, (type(key[0]), type(key[1])))
        else:
            self.assertEqual(key_type, type(key))
        self.assertEqual(
            coders.ShardedKeyCoder.from_type_hint(
                type_hint, typecoders.CoderRegistry()),
            coder)

        for other_key, _, other_key_coder in key_and_coders:
            other_coder = coders.ShardedKeyCoder(other_key_coder)
            # Test nested
            self.check_coder(
                coders.TupleCoder((coder, other_coder)),
                (ShardedKey(key, b''), ShardedKey(other_key, b'')))
            self.check_coder(
                coders.TupleCoder((coder, other_coder)),
                (ShardedKey(key, b'123'), ShardedKey(other_key, b'')))
def expand(self, pcoll):
    """Wrap each element's key in a ShardedKey, then group into batches.

    Every ``(key, value)`` element is mapped to ``(ShardedKey(key,
    shard_id), value)`` and the result is passed to ``GroupIntoBatches``
    configured with this transform's ``batch_size`` and
    ``max_buffering_duration_secs``.

    Args:
      pcoll: the input PCollection; elements are indexed as
        ``key_value[0]`` / ``key_value[1]``.

    Returns:
      The output of applying ``GroupIntoBatches`` to the sharded-key
      PCollection.
    """
    sharded_pcoll = pcoll | Map(
        lambda key_value: (
            ShardedKey(
                key_value[0],
                # Use [uuid, thread id] as the shard id.
                # int.to_bytes already returns bytes, so no extra bytes()
                # copy is needed.
                GroupIntoBatches.WithShardedKey._shard_id_prefix +
                threading.get_ident().to_bytes(8, 'big')),
            key_value[1]))
    return (
        sharded_pcoll
        | GroupIntoBatches(self.batch_size, self.max_buffering_duration_secs))
def expand(self, pcoll):
    """Wrap each element's key in a ShardedKey, then group into batches.

    Every ``(key, value)`` element is mapped to ``(ShardedKey(key,
    shard_id), value)``, the output is annotated as
    ``Tuple[ShardedKeyType[key_type], value_type]`` using the key and value
    types taken from ``pcoll.element_type.tuple_types``, and the result is
    passed to ``GroupIntoBatches`` configured from ``self.params``.

    Args:
      pcoll: the input PCollection; its ``element_type`` must expose a
        two-item ``tuple_types``.

    Returns:
      The output of applying ``GroupIntoBatches`` to the sharded-key
      PCollection.
    """
    key_type, value_type = pcoll.element_type.tuple_types
    sharded_pcoll = pcoll | Map(
        lambda key_value: (
            ShardedKey(
                key_value[0],
                # Use [uuid, thread id] as the shard id.
                # int.to_bytes already returns bytes, so no extra bytes()
                # copy is needed.
                GroupIntoBatches.WithShardedKey._shard_id_prefix +
                threading.get_ident().to_bytes(8, 'big')),
            key_value[1])).with_output_types(
                typehints.Tuple[
                    ShardedKeyType[key_type],  # type: ignore[misc]
                    value_type])
    return (
        sharded_pcoll
        | GroupIntoBatches(
            self.params.batch_size, self.params.max_buffering_duration_secs))
def test_type_check_valid_composite_type(self):
    """type_check returns None for a ShardedKey with an (int, str) key."""
    tuple_keyed_hint = ShardedKeyType[Tuple[int, str]]
    sharded_key = ShardedKey(key=(1, 'a'), shard_id=b'123')
    check_result = tuple_keyed_hint.type_check(sharded_key)
    self.assertIsNone(check_result)
def test_type_check_valid_simple_type(self):
    """type_check returns None for a ShardedKey whose key is a str."""
    str_keyed_hint = ShardedKeyType[str]
    sharded_key = ShardedKey(key='abc', shard_id=b'123')
    check_result = str_keyed_hint.type_check(sharded_key)
    self.assertIsNone(check_result)
class StandardCodersTest(unittest.TestCase):
    """Runs the examples loaded from STANDARD_CODERS_YAML against coders.

    Each test case names a coder spec and a mapping of expected encoded
    strings to JSON values. The JSON value is converted to a Python value
    with a parser chosen by the coder's URN, then encoded/decoded and
    compared with the expected bytes.
    """

    # Maps a coder URN to a function turning the JSON example value into the
    # Python value to encode. Parsers for composite coders additionally
    # receive one parser per component coder.
    _urn_to_json_value_parser = {
        'beam:coder:bytes:v1': lambda x: x.encode('utf-8'),
        'beam:coder:bool:v1': lambda x: x,
        'beam:coder:string_utf8:v1': lambda x: x,
        'beam:coder:varint:v1': lambda x: x,
        'beam:coder:kv:v1': lambda x, key_parser, value_parser:
        (key_parser(x['key']), value_parser(x['value'])),
        'beam:coder:interval_window:v1': lambda x: IntervalWindow(
            start=Timestamp(micros=(x['end'] - x['span']) * 1000),
            end=Timestamp(micros=x['end'] * 1000)),
        'beam:coder:iterable:v1': lambda x, parser: list(map(parser, x)),
        'beam:coder:global_window:v1': lambda x: window.GlobalWindow(),
        'beam:coder:windowed_value:v1': lambda x, value_parser, window_parser:
        windowed_value.create(
            value_parser(x['value']),
            x['timestamp'] * 1000,
            tuple([window_parser(w) for w in x['windows']])),
        'beam:coder:param_windowed_value:v1':
        lambda x, value_parser, window_parser: windowed_value.create(
            value_parser(x['value']),
            x['timestamp'] * 1000,
            tuple([window_parser(w) for w in x['windows']]),
            PaneInfo(
                x['pane']['is_first'],
                x['pane']['is_last'],
                PaneInfoTiming.from_string(x['pane']['timing']),
                x['pane']['index'],
                x['pane']['on_time_index'])),
        # A cleared timer carries no timestamps or pane info.
        'beam:coder:timer:v1': lambda x, value_parser, window_parser:
        userstate.Timer(
            user_key=value_parser(x['userKey']),
            dynamic_timer_tag=x['dynamicTimerTag'],
            clear_bit=x['clearBit'],
            windows=tuple([window_parser(w) for w in x['windows']]),
            fire_timestamp=None,
            hold_timestamp=None,
            paneinfo=None) if x['clearBit'] else userstate.Timer(
                user_key=value_parser(x['userKey']),
                dynamic_timer_tag=x['dynamicTimerTag'],
                clear_bit=x['clearBit'],
                fire_timestamp=Timestamp(micros=x['fireTimestamp'] * 1000),
                hold_timestamp=Timestamp(micros=x['holdTimestamp'] * 1000),
                windows=tuple([window_parser(w) for w in x['windows']]),
                paneinfo=PaneInfo(
                    x['pane']['is_first'],
                    x['pane']['is_last'],
                    PaneInfoTiming.from_string(x['pane']['timing']),
                    x['pane']['index'],
                    x['pane']['on_time_index'])),
        'beam:coder:double:v1': parse_float,
        'beam:coder:sharded_key:v1': lambda x, value_parser: ShardedKey(
            key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')),
        'beam:coder:custom_window:v1': lambda x, window_parser:
        window_parser(x['window'])
    }

    def test_standard_coders(self):
        """Run every test case loaded from STANDARD_CODERS_YAML."""
        for name, spec in _load_test_cases(STANDARD_CODERS_YAML):
            logging.info('Executing %s test.', name)
            self._run_standard_coder(name, spec)

    def _run_standard_coder(self, name, spec):
        """Check one test case's examples, nested and/or unnested.

        Args:
          name: test case name (not used in the body).
          spec: dict with a 'coder' spec, an 'examples' mapping of expected
            encoded string to JSON value, an optional 'nested' flag, and an
            'index' used to record mismatches in fix mode.
        """
        def assert_equal(actual, expected):
            """Handle nan values which self.assertEqual fails on."""
            if (isinstance(actual, float) and isinstance(expected, float) and
                    math.isnan(actual) and math.isnan(expected)):
                return
            self.assertEqual(actual, expected)

        coder = self.parse_coder(spec['coder'])
        parse_value = self.json_value_parser(spec['coder'])
        # Without an explicit 'nested' entry, exercise both contexts.
        nested_list = [spec['nested']] if 'nested' in spec else [True, False]
        for nested in nested_list:
            for expected_encoded, json_value in spec['examples'].items():
                value = parse_value(json_value)
                expected_encoded = expected_encoded.encode('latin1')
                if not spec['coder'].get('non_deterministic', False):
                    actual_encoded = encode_nested(coder, value, nested)
                    if self.fix and actual_encoded != expected_encoded:
                        # Fix mode: remember the mismatch instead of failing;
                        # tearDownClass rewrites the YAML file from these.
                        self.to_fix[spec['index'],
                                    expected_encoded] = actual_encoded
                    else:
                        self.assertEqual(expected_encoded, actual_encoded)
                        decoded = decode_nested(
                            coder, expected_encoded, nested)
                        assert_equal(decoded, value)
                else:
                    # Only verify decoding for a non-deterministic coder
                    self.assertEqual(
                        decode_nested(coder, expected_encoded, nested), value)

    def parse_coder(self, spec):
        """Build a coder from a spec dict with 'urn', optional 'payload'
        and optional 'components' (each itself a coder spec).

        The coder proto is registered in a fresh PipelineContext under an id
        derived from the spec's string form and then read back as a coder.
        """
        context = pipeline_context.PipelineContext()
        coder_id = str(hash(str(spec)))
        component_ids = [
            context.coders.get_id(self.parse_coder(c))
            for c in spec.get('components', ())
        ]
        context.coders.put_proto(
            coder_id,
            beam_runner_api_pb2.Coder(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=spec['urn'],
                    payload=spec.get('payload', '').encode('latin1')),
                component_coder_ids=component_ids))
        return context.coders.get_by_id(coder_id)

    def json_value_parser(self, coder_spec):
        """Return a one-argument function parsing JSON example values for
        the coder described by ``coder_spec``.

        Row coders get a parser built from the schema in their payload; all
        other coders are looked up by URN in ``_urn_to_json_value_parser``
        and handed the parsers of their component coders.
        """
        # TODO: integrate this with the logic for the other parsers
        if coder_spec['urn'] == 'beam:coder:row:v1':
            schema = schema_pb2.Schema.FromString(
                coder_spec['payload'].encode('latin1'))
            return value_parser_from_schema(schema)

        component_parsers = [
            self.json_value_parser(c)
            for c in coder_spec.get('components', ())
        ]
        return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
            x, *component_parsers)

    # Used when --fix is passed.

    fix = False
    to_fix = {}  # type: Dict[Tuple[int, bytes], bytes]

    @classmethod
    def tearDownClass(cls):
        """In fix mode, rewrite STANDARD_CODERS_YAML so that every recorded
        mismatching expected encoding is replaced by the actual encoding.
        """
        if not (cls.fix and cls.to_fix):
            return

        print("FIXING", len(cls.to_fix), "TESTS")
        doc_sep = '\n---\n'
        # Context managers close the file handles deterministically; the
        # read handle in particular must be closed before the file is
        # reopened for writing.
        with open(STANDARD_CODERS_YAML) as yaml_file:
            docs = yaml_file.read().split(doc_sep)

        def quote(s):
            return json.dumps(s.decode('latin1')).replace(r'\u0000', r'\0')

        for (doc_ix, expected_encoded), actual_encoded in cls.to_fix.items():
            print(quote(expected_encoded), "->", quote(actual_encoded))
            docs[doc_ix] = docs[doc_ix].replace(
                quote(expected_encoded) + ':', quote(actual_encoded) + ':')
        with open(STANDARD_CODERS_YAML, 'w') as yaml_file:
            yaml_file.write(doc_sep.join(docs))
def decode_from_stream(self, in_stream, nested):
    # type: (create_InputStream, bool) -> ShardedKey
    """Read a ShardedKey from ``in_stream``.

    The shard id is read first, then the key. Both component decoders are
    always called with ``True`` as their nested argument; the ``nested``
    parameter of this method is not consulted.
    """
    # Keyword arguments are evaluated left to right, so the shard id is
    # consumed from the stream before the key.
    return ShardedKey(
        shard_id=self._shard_id_coder_impl.decode_from_stream(in_stream, True),
        key=self._key_coder_impl.decode_from_stream(in_stream, True))