class ServiceStateWrapper:
    """
    Class wraps ServiceState interactions with redis
    """

    # Unique typename for Redis key
    REDIS_VALUE_TYPE = "systemd_status"

    def __init__(self):
        serde = RedisSerde(
            self.REDIS_VALUE_TYPE,
            get_proto_serializer(),
            get_proto_deserializer(ServiceExitStatus),
        )
        self._flat_dict = RedisFlatDict(get_default_client(), serde)

    def update_service_status(
        self, service_name: str, service_status: ServiceExitStatus,
    ) -> None:
        """
        Update the service exit status for a given service
        """
        if service_name in self._flat_dict:
            current_service_status = self._flat_dict[service_name]
        else:
            current_service_status = ServiceExitStatus()

        if service_status.latest_service_result == \
                ServiceExitStatus.ServiceResult.Value("SUCCESS"):
            service_status.num_clean_exits = \
                current_service_status.num_clean_exits + 1
            service_status.num_fail_exits = \
                current_service_status.num_fail_exits
        else:
            service_status.num_fail_exits = \
                current_service_status.num_fail_exits + 1
            service_status.num_clean_exits = \
                current_service_status.num_clean_exits
        self._flat_dict[service_name] = service_status

    def get_service_status(self, service_name: str) -> ServiceExitStatus:
        """
        Get the service status protobuf for a given service
        @returns ServiceExitStatus protobuf object
        """
        return self._flat_dict[service_name]

    def get_all_services_status(self) -> Dict[str, ServiceExitStatus]:
        """
        Get a dict of service name to service status
        @return dict of service_name to service status map
        """
        service_status = {}
        for k, v in self._flat_dict.items():
            service_status[k] = v
        return service_status

    def cleanup_service_status(self) -> None:
        """
        Cleanup service status for all services in redis; mostly used
        for testing
        """
        self._flat_dict.clear()
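# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of driving ServiceStateWrapper, assuming a reachable
# Redis behind get_default_client(); the service name "magmad" is a
# placeholder:
#
#   wrapper = ServiceStateWrapper()
#   status = ServiceExitStatus()
#   status.latest_service_result = \
#       ServiceExitStatus.ServiceResult.Value("SUCCESS")
#   wrapper.update_service_status("magmad", status)  # bumps num_clean_exits
#   print(wrapper.get_service_status("magmad").num_clean_exits)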
async def _delete_state_from_redis(
    self,
    redis_dict: RedisFlatDict,
    key: str,
) -> None:
    # Ensure that the object isn't updated before deletion
    with redis_dict.lock(key):
        deleted = redis_dict.delete_garbage(key)
        if deleted:
            logging.debug(
                "Successfully garbage collected "
                "state for key: %s", key,
            )
        else:
            logging.debug(
                "Successfully garbage collected "
                "state in cloud for key %s. "
                "Didn't delete locally as the "
                "object is no longer garbage", key,
            )
def setUp(self):
    self.loop = asyncio.new_event_loop()
    asyncio.set_event_loop(self.loop)
    service = MagicMock()
    service.config = {
        # Replicate arbitrary orc8r protos
        'state_protos': [
            {
                'proto_file': 'orc8r.protos.common_pb2',
                'proto_msg': 'NetworkID',
                'redis_key': NID_TYPE,
                'state_scope': 'network',
            },
            {
                'proto_file': 'orc8r.protos.service303_pb2',
                'proto_msg': 'LogVerbosity',
                'redis_key': LOG_TYPE,
                'state_scope': 'gateway',
            },
        ],
        'json_state': [{
            'redis_key': FOO_TYPE,
            'state_scope': 'gateway',
        }],
    }
    service.loop = self.loop

    # Bind the rpc server to a free port
    self._rpc_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
    )
    port = self._rpc_server.add_insecure_port('0.0.0.0:0')
    # Add the servicer
    self._servicer = DummyStateServer()
    self._servicer.add_to_server(self._rpc_server)
    self._rpc_server.start()
    # Create a rpc stub
    self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))

    serde1 = RedisSerde(
        NID_TYPE,
        get_proto_serializer(),
        get_proto_deserializer(NetworkID),
    )
    serde2 = RedisSerde(
        FOO_TYPE,
        get_json_serializer(),
        get_json_deserializer(),
    )
    serde3 = RedisSerde(
        LOG_TYPE,
        get_proto_serializer(),
        get_proto_deserializer(LogVerbosity),
    )

    self.nid_client = RedisFlatDict(get_default_client(), serde1)
    self.foo_client = RedisFlatDict(get_default_client(), serde2)
    self.log_client = RedisFlatDict(get_default_client(), serde3)

    # Set up and start garbage collecting loop
    grpc_client_manager = GRPCClientManager(
        service_name="state",
        service_stub=StateServiceStub,
        max_client_reuse=60,
    )
    # Start state garbage collection loop
    self.garbage_collector = GarbageCollector(service, grpc_client_manager)
class RedisDictTests(TestCase):
    """
    Tests for the RedisHashDict and RedisFlatDict containers
    """

    @mock.patch("redis.Redis", MockRedis)
    def setUp(self):
        client = get_default_client()
        # Use arbitrary orc8r proto to test with
        self._hash_dict = RedisHashDict(
            client,
            "unittest",
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )
        serde = RedisSerde(
            'log_verbosity',
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )
        self._flat_dict = RedisFlatDict(client, serde)

    @mock.patch("redis.Redis", MockRedis)
    def test_hash_insert(self):
        expected = LogVerbosity(verbosity=0)
        expected2 = LogVerbosity(verbosity=1)

        # insert proto
        self._hash_dict['key1'] = expected
        version = self._hash_dict.get_version("key1")
        actual = self._hash_dict['key1']
        self.assertEqual(1, version)
        self.assertEqual(expected, actual)

        # update proto
        self._hash_dict['key1'] = expected2
        version2 = self._hash_dict.get_version("key1")
        actual2 = self._hash_dict['key1']
        self.assertEqual(2, version2)
        self.assertEqual(expected2, actual2)

    @mock.patch("redis.Redis", MockRedis)
    def test_missing_version(self):
        missing_version = self._hash_dict.get_version("key2")
        self.assertEqual(0, missing_version)

    @mock.patch("redis.Redis", MockRedis)
    def test_hash_delete(self):
        expected = LogVerbosity(verbosity=2)
        self._hash_dict['key3'] = expected
        actual = self._hash_dict['key3']
        self.assertEqual(expected, actual)

        self._hash_dict.pop('key3')
        self.assertRaises(KeyError, self._hash_dict.__getitem__, 'key3')

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_insert(self):
        expected = LogVerbosity(verbosity=5)
        expected2 = LogVerbosity(verbosity=1)

        # insert proto
        self._flat_dict['key1'] = expected
        version = self._flat_dict.get_version("key1")
        actual = self._flat_dict['key1']
        self.assertEqual(1, version)
        self.assertEqual(expected, actual)

        # update proto
        self._flat_dict["key1"] = expected2
        version2 = self._flat_dict.get_version("key1")
        actual2 = self._flat_dict["key1"]
        actual3 = self._flat_dict.get("key1")
        self.assertEqual(2, version2)
        self.assertEqual(expected2, actual2)
        self.assertEqual(expected2, actual3)

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_missing_version(self):
        missing_version = self._flat_dict.get_version("key2")
        self.assertEqual(0, missing_version)

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_bad_key(self):
        expected = LogVerbosity(verbosity=2)
        self.assertRaises(
            ValueError, self._flat_dict.__setitem__, 'bad:key', expected,
        )
        self.assertRaises(ValueError, self._flat_dict.__getitem__, 'bad:key')
        self.assertRaises(ValueError, self._flat_dict.__delitem__, 'bad:key')

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_delete(self):
        expected = LogVerbosity(verbosity=2)
        self._flat_dict['key3'] = expected
        actual = self._flat_dict['key3']
        self.assertEqual(expected, actual)

        del self._flat_dict['key3']
        self.assertRaises(KeyError, self._flat_dict.__getitem__, 'key3')
        self.assertEqual(None, self._flat_dict.get('key3'))

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_clear(self):
        expected = LogVerbosity(verbosity=2)
        self._flat_dict['key3'] = expected
        actual = self._flat_dict['key3']
        self.assertEqual(expected, actual)

        self._flat_dict.clear()
        self.assertEqual(0, len(self._flat_dict.keys()))

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_garbage_methods(self):
        expected = LogVerbosity(verbosity=2)
        expected2 = LogVerbosity(verbosity=3)
        key = "k1"
        key2 = "k2"
        bad_key = "bad_key"
        self._flat_dict[key] = expected
        self._flat_dict[key2] = expected2

        self._flat_dict.mark_as_garbage(key)
        is_garbage = self._flat_dict.is_garbage(key)
        self.assertTrue(is_garbage)
        is_garbage2 = self._flat_dict.is_garbage(key2)
        self.assertFalse(is_garbage2)

        self.assertEqual([key], self._flat_dict.garbage_keys())
        self.assertEqual([key2], self._flat_dict.keys())

        self.assertIsNone(self._flat_dict.get(key))
        self.assertEqual(expected2, self._flat_dict.get(key2))

        deleted = self._flat_dict.delete_garbage(key)
        not_deleted = self._flat_dict.delete_garbage(key2)
        self.assertTrue(deleted)
        self.assertFalse(not_deleted)

        self.assertIsNone(self._flat_dict.get(key))
        self.assertEqual(expected2, self._flat_dict.get(key2))

        with self.assertRaises(KeyError):
            self._flat_dict.is_garbage(bad_key)
        with self.assertRaises(KeyError):
            self._flat_dict.mark_as_garbage(bad_key)
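# --- Garbage-collection lifecycle sketch (illustrative, mirrors the test
# above) --- Given a RedisFlatDict `d` built with a proto serde as in
# setUp():
#
#   d["k1"] = LogVerbosity(verbosity=2)  # live: visible via keys()/get()
#   d.mark_as_garbage("k1")              # hidden from keys()/get(), kept in redis
#   assert d.is_garbage("k1")
#   assert "k1" in d.garbage_keys()
#   assert d.delete_garbage("k1")        # physically removed; returns True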
class GarbageCollectorTests(TestCase):
    def setUp(self):
        self.mock_redis = fakeredis.FakeStrictRedis()
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        service = MagicMock()
        service.config = {
            # Replicate arbitrary orc8r protos
            'state_protos': [
                {
                    'proto_file': 'orc8r.protos.common_pb2',
                    'proto_msg': 'NetworkID',
                    'redis_key': NID_TYPE,
                    'state_scope': 'network',
                },
                {
                    'proto_file': 'orc8r.protos.service303_pb2',
                    'proto_msg': 'LogVerbosity',
                    'redis_key': LOG_TYPE,
                    'state_scope': 'gateway',
                },
            ],
            'json_state': [{
                'redis_key': FOO_TYPE,
                'state_scope': 'gateway',
            }],
        }
        service.loop = self.loop

        # Bind the rpc server to a free port
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10),
        )
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')
        # Add the servicer
        self._servicer = DummyStateServer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()
        # Create a rpc stub
        self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))

        serde1 = RedisSerde(
            NID_TYPE,
            get_proto_serializer(),
            get_proto_deserializer(NetworkID),
        )
        serde2 = RedisSerde(
            FOO_TYPE,
            get_json_serializer(),
            get_json_deserializer(),
        )
        serde3 = RedisSerde(
            LOG_TYPE,
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )

        self.nid_client = RedisFlatDict(self.mock_redis, serde1)
        self.foo_client = RedisFlatDict(self.mock_redis, serde2)
        self.log_client = RedisFlatDict(self.mock_redis, serde3)

        # Set up and start garbage collecting loop
        grpc_client_manager = GRPCClientManager(
            service_name="state",
            service_stub=StateServiceStub,
            max_client_reuse=60,
        )

        # mock the get_default_client function used to return the same
        # fakeredis object
        func_mock = mock.MagicMock(return_value=self.mock_redis)
        with patch('magma.state.redis_dicts.get_default_client', func_mock):
            # Start state garbage collection loop
            self.garbage_collector = GarbageCollector(
                service,
                grpc_client_manager,
            )

    def tearDown(self):
        self._rpc_server.stop(None)
        self.loop.close()

    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_collect_states_to_delete(self):
        async def test():
            # Ensure setup is initialized properly
            self.nid_client.clear()
            self.foo_client.clear()
            self.log_client.clear()

            key = 'id1'
            self.nid_client[key] = NetworkID(id='foo')
            self.foo_client[key] = Foo("boo", 3)
            req = await self.garbage_collector._collect_states_to_delete()
            self.assertIsNone(req)

            self.nid_client.mark_as_garbage(key)
            self.foo_client.mark_as_garbage(key)
            req = await self.garbage_collector._collect_states_to_delete()
            self.assertEqual(2, len(req.ids))
            for state_id in req.ids:
                if state_id.type == NID_TYPE:
                    self.assertEqual('id1', state_id.deviceID)
                elif state_id.type == FOO_TYPE:
                    self.assertEqual('aaa-bbb:id1', state_id.deviceID)
                else:
                    self.fail("Unknown state type %s" % state_id.type)

            # Cleanup
            del self.foo_client[key]
            del self.nid_client[key]

        self.loop.run_until_complete(test())

    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_garbage_collect_success(self, get_rpc_mock):
        async def test():
            get_rpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.foo_client.clear()
            self.log_client.clear()

            key = 'id1'
            foo = Foo("boo", 4)
            self.nid_client[key] = NetworkID(id='foo')
            self.foo_client[key] = foo
            self.nid_client.mark_as_garbage(key)
            self.foo_client.mark_as_garbage(key)
            req = await self.garbage_collector._collect_states_to_delete()
            self.assertEqual(2, len(req.ids))

            # Ensure all garbage collected objects get deleted from Redis
            await self.garbage_collector._send_to_state_service(req)
            self.assertEqual(0, len(self.nid_client.keys()))
            self.assertEqual(0, len(self.foo_client.keys()))
            self.assertEqual(0, len(self.nid_client.garbage_keys()))
            self.assertEqual(0, len(self.foo_client.garbage_keys()))

        self.loop.run_until_complete(test())

    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_garbage_collect_rpc_failure(self, get_rpc_mock):
        async def test():
            get_rpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.foo_client.clear()
            self.log_client.clear()

            key = 'id1'
            self.nid_client[key] = NetworkID(id='foo')
            self.log_client[key] = LogVerbosity(verbosity=3)
            self.nid_client.mark_as_garbage(key)
            self.log_client.mark_as_garbage(key)
            req = await self.garbage_collector._collect_states_to_delete()
            self.assertEqual(2, len(req.ids))

            # Ensure objects are not deleted from Redis on RPC failure
            await self.garbage_collector._send_to_state_service(req)
            self.assertEqual(0, len(self.nid_client.keys()))
            self.assertEqual(0, len(self.log_client.keys()))
            self.assertEqual(1, len(self.nid_client.garbage_keys()))
            self.assertEqual(1, len(self.log_client.garbage_keys()))

            # Cleanup
            del self.log_client[key]
            del self.nid_client[key]

        self.loop.run_until_complete(test())

    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_garbage_collect_with_state_update(self, get_rpc_mock):
        async def test():
            get_rpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.foo_client.clear()
            self.log_client.clear()

            key = 'id1'
            foo = Foo("boo", 4)
            self.nid_client[key] = NetworkID(id='foo')
            self.foo_client[key] = foo
            self.nid_client.mark_as_garbage(key)
            self.foo_client.mark_as_garbage(key)
            req = await self.garbage_collector._collect_states_to_delete()
            self.assertEqual(2, len(req.ids))

            # Update one of the states, to ensure we don't delete valid
            # state from Redis
            expected = NetworkID(id='bar')
            self.nid_client[key] = expected

            # Ensure all garbage collected objects get deleted from Redis
            await self.garbage_collector._send_to_state_service(req)
            self.assertEqual(1, len(self.nid_client.keys()))
            self.assertEqual(0, len(self.foo_client.keys()))
            self.assertEqual(0, len(self.nid_client.garbage_keys()))
            self.assertEqual(0, len(self.foo_client.garbage_keys()))
            self.assertEqual(expected, self.nid_client[key])

        self.loop.run_until_complete(test())
class GatewayDirectoryServiceRpcServicer(GatewayDirectoryServiceServicer):
    """
    gRPC based server for the Directoryd Gateway service.
    """

    def __init__(self):
        serde = RedisSerde(
            DIRECTORYD_REDIS_TYPE,
            get_json_serializer(),
            get_json_deserializer(),
        )
        self._redis_dict = RedisFlatDict(get_default_client(), serde)

    def add_to_server(self, server):
        """ Add the servicer to a gRPC server """
        add_GatewayDirectoryServiceServicer_to_server(self, server)

    @return_void
    def UpdateRecord(self, request, context):
        """ Update the directory record of an object

        Args:
            request (UpdateRecordRequest): update record request
        """
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "ID argument cannot be empty in "
                "UpdateRecordRequest",
            )
            return

        # Lock Redis for requested key until update is complete
        with self._redis_dict.lock(request.id):
            hwid = get_gateway_hwid()
            record = self._redis_dict.get(request.id) or \
                DirectoryRecord(location_history=[hwid], identifiers={})
            if record.location_history[0] != hwid:
                record.location_history = [hwid] + record.location_history
            for field_key in request.fields:
                record.identifiers[field_key] = request.fields[field_key]
            # Truncate location history to the five most recent hwid's
            record.location_history = \
                record.location_history[:LOCATION_MAX_LEN]
            self._redis_dict[request.id] = record

    @return_void
    def DeleteRecord(self, request, context):
        """ Delete the directory record for an ID

        Args:
            request (DeleteRecordRequest): delete record request
        """
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "ID argument cannot be empty in "
                "DeleteRecordRequest",
            )
            return

        # Lock Redis for requested key until delete is complete
        with self._redis_dict.lock(request.id):
            if request.id not in self._redis_dict:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(
                    "Record for ID %s was not found." % request.id,
                )
                return
            self._redis_dict.mark_as_garbage(request.id)

    def GetDirectoryField(self, request, context):
        """ Get the directory record field for an ID and key

        Args:
            request (GetDirectoryFieldRequest): get directory field request
        """
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "ID argument cannot be empty in "
                "GetDirectoryFieldRequest",
            )
            return
        if len(request.field_key) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "Field key argument cannot be empty in "
                "GetDirectoryFieldRequest",
            )
            return

        # Lock Redis for requested key until get is complete
        with self._redis_dict.lock(request.id):
            if request.id not in self._redis_dict:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(
                    "Record for ID %s was not found." % request.id,
                )
                return
            record = self._redis_dict[request.id]
            if request.field_key not in record.identifiers:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(
                    "Field %s was not found in record for "
                    "ID %s" % (request.field_key, request.id),
                )
                return
            return DirectoryField(
                key=request.field_key,
                value=record.identifiers[request.field_key],
            )

    def GetAllDirectoryRecords(self, request, context):
        """ Get all directory records

        Args:
            request (Void): void
        """
        response = AllDirectoryRecords()
        for key in self._redis_dict.keys():
            with self._redis_dict.lock(key):
                # Lookup may produce an exception if the key has been deleted
                # between the call to __iter__ and lock
                try:
                    stored_record = self._redis_dict[key]
                except KeyError:
                    continue
                directory_record = response.records.add()
                directory_record.id = key
                directory_record.location_history[:] = \
                    stored_record.location_history
                for identifier_key in stored_record.identifiers:
                    directory_record.fields[identifier_key] = \
                        stored_record.identifiers[identifier_key]
        return response
class GatewayDirectoryServiceRpcServicer(GatewayDirectoryServiceServicer):
    """
    gRPC based server for the Directoryd Gateway service.
    """

    def __init__(self):
        serde = RedisSerde(
            DIRECTORYD_REDIS_TYPE,
            get_json_serializer(),
            get_json_deserializer(),
        )
        self._redis_dict = RedisFlatDict(get_default_client(), serde)

    def add_to_server(self, server):
        """ Add the servicer to a gRPC server """
        add_GatewayDirectoryServiceServicer_to_server(self, server)

    @return_void
    def UpdateRecord(self, request, context):
        """ Update the directory record of an object

        Args:
            request (UpdateRecordRequest): update record request
        """
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "ID argument cannot be empty in "
                "UpdateRecordRequest",
            )
            return

        # Lock Redis for requested key until update is complete
        with self._redis_dict.lock(request.id):
            hwid = get_gateway_hwid()
            record = self._redis_dict.get(request.id) or \
                DirectoryRecord(location_history=[hwid], identifiers={})
            if record.location_history[0] != hwid:
                record.location_history = [hwid] + record.location_history
            for field in request.fields:
                record.identifiers[field.key] = field.value
            # Truncate location history to the five most recent hwid's
            record.location_history = \
                record.location_history[:LOCATION_MAX_LEN]
            self._redis_dict[request.id] = record

    @return_void
    def DeleteRecord(self, request, context):
        """ Delete the directory record for an ID

        Args:
            request (DeleteRecordRequest): delete record request
        """
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "ID argument cannot be empty in "
                "DeleteRecordRequest",
            )
            return

        # Lock Redis for requested key until delete is complete
        with self._redis_dict.lock(request.id):
            if request.id not in self._redis_dict:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(
                    "Record for ID %s was not found." % request.id,
                )
                return
            # TODO: Set record to be garbage collected rather than deleting
            # directly
            del self._redis_dict[request.id]

    def GetDirectoryField(self, request, context):
        """ Get the directory record field for an ID and key

        Args:
            request (GetDirectoryFieldRequest): get directory field request
        """
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "ID argument cannot be empty in "
                "GetDirectoryFieldRequest",
            )
            return
        if len(request.field_key) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "Field key argument cannot be empty in "
                "GetDirectoryFieldRequest",
            )
            return

        # Lock Redis for requested key until get is complete
        with self._redis_dict.lock(request.id):
            if request.id not in self._redis_dict:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(
                    "Record for ID %s was not found." % request.id,
                )
                return
            record = self._redis_dict[request.id]
            if request.field_key not in record.identifiers:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(
                    "Field %s was not found in record for "
                    "ID %s" % (request.field_key, request.id),
                )
                return
            return DirectoryField(
                key=request.field_key,
                value=record.identifiers[request.field_key],
            )
class StateReplicator(SDWatchdogTask):
    """
    StateReplicator periodically fetches all configured state from Redis,
    reporting any updates to the Orchestrator State service.
    """

    def __init__(
        self,
        service: MagmaService,
        grpc_client_manager: GRPCClientManager,
    ):
        super().__init__(DEFAULT_SYNC_INTERVAL, service.loop)
        self._service = service
        # In memory mapping of states to version
        self._state_versions = {}
        # Serdes for each type of state to replicate
        self._serdes = {}
        self._get_proto_redis_serdes()
        self._get_json_redis_serdes()
        self._redis_client = RedisFlatDict(get_default_client(), self._serdes)
        # _grpc_client_manager to manage grpc client recyclings
        self._grpc_client_manager = grpc_client_manager
        # Flag to indicate if resync has completed successfully.
        # Replication cannot proceed until this flag is True
        self._has_resync_completed = False

    def _get_proto_redis_serdes(self):
        state_protos = self._service.config.get('state_protos', []) or []
        for proto_cfg in state_protos:
            is_invalid_cfg = 'proto_msg' not in proto_cfg or \
                             'proto_file' not in proto_cfg or \
                             'redis_key' not in proto_cfg or \
                             'state_scope' not in proto_cfg
            if is_invalid_cfg:
                logging.warning(
                    "Invalid proto config found in state_protos "
                    "configuration: %s", proto_cfg,
                )
                continue
            try:
                proto_module = importlib.import_module(
                    proto_cfg['proto_file'],
                )
                msg = getattr(proto_module, proto_cfg['proto_msg'])
                redis_key = proto_cfg['redis_key']
                logging.info(
                    'Initializing RedisSerde for proto state %s',
                    proto_cfg['redis_key'],
                )
                serde = StateSerde(
                    redis_key,
                    get_proto_serializer(),
                    get_proto_deserializer(msg),
                    proto_cfg['state_scope'],
                    PROTO_FORMAT,
                )
                self._serdes[redis_key] = serde
            except (ImportError, AttributeError) as err:
                logging.error(err)

    def _get_json_redis_serdes(self):
        json_state = self._service.config.get('json_state', []) or []
        for json_cfg in json_state:
            is_invalid_cfg = 'redis_key' not in json_cfg or \
                             'state_scope' not in json_cfg
            if is_invalid_cfg:
                logging.warning(
                    "Invalid json state config found in json_state "
                    "configuration: %s", json_cfg,
                )
                continue
            logging.info(
                'Initializing RedisSerde for json state %s',
                json_cfg['redis_key'],
            )
            redis_key = json_cfg['redis_key']
            serde = StateSerde(
                redis_key,
                get_json_serializer(),
                get_json_deserializer(),
                json_cfg['state_scope'],
                JSON_FORMAT,
            )
            self._serdes[redis_key] = serde

    async def _run(self):
        if not self._has_resync_completed:
            try:
                await self._resync()
            except grpc.RpcError as err:
                logging.error(
                    "GRPC call failed for initial state re-sync: %s",
                    err,
                )
                return
        request = await self._collect_states_to_replicate()
        if request is not None:
            await self._send_to_state_service(request)

    async def _resync(self):
        states_to_sync = []
        for key in self._redis_client:
            try:
                idval, state_type = self._parse_key(key)
            except ValueError as err:
                logging.debug(err)
                continue
            state_scope = self._serdes[state_type].state_scope
            version = self._redis_client.get_version(idval, state_type)
            device_id = self.make_scoped_device_id(idval, state_scope)
            state_id = StateID(type=state_type, deviceID=device_id)
            id_and_version = IDAndVersion(id=state_id, version=version)
            states_to_sync.append(id_and_version)

        if len(states_to_sync) == 0:
            logging.debug("Not re-syncing state. No local state found.")
            return
        state_client = self._grpc_client_manager.get_client()
        request = SyncStatesRequest(states=states_to_sync)
        response = await grpc_async_wrapper(
            state_client.SyncStates.future(
                request,
                DEFAULT_GRPC_TIMEOUT,
            ),
            self._loop,
        )
        unsynced_states = set()
        for id_and_version in response.unsyncedStates:
            unsynced_states.add(
                (id_and_version.id.type, id_and_version.id.deviceID),
            )
        # Update in-memory map to add already synced states
        for state in request.states:
            in_mem_key = self.make_mem_key(state.id.deviceID, state.id.type)
            if (state.id.type, state.id.deviceID) not in unsynced_states:
                self._state_versions[in_mem_key] = state.version

        self._has_resync_completed = True
        logging.info("Successfully resynced state with Orchestrator!")

    async def _collect_states_to_replicate(self):
        states_to_report = []
        for key in self._redis_client:
            try:
                idval, state_type = self._parse_key(key)
            except ValueError as err:
                logging.debug(err)
                continue
            state_scope = self._serdes[state_type].state_scope
            device_id = self.make_scoped_device_id(idval, state_scope)

            in_mem_key = self.make_mem_key(device_id, state_type)
            redis_version = self._redis_client.get_version(idval, state_type)
            if in_mem_key in self._state_versions and \
                    self._state_versions[in_mem_key] == redis_version:
                continue

            redis_state = self._redis_client.get(key)
            if self._serdes[state_type].state_format == PROTO_FORMAT:
                state_to_serialize = MessageToDict(redis_state)
            else:
                state_to_serialize = redis_state
            serialized_json_state = json.dumps(state_to_serialize)
            state_proto = State(
                type=state_type,
                deviceID=device_id,
                value=serialized_json_state.encode("utf-8"),
                version=redis_version,
            )
            states_to_report.append(state_proto)

        if len(states_to_report) == 0:
            logging.debug("Not replicating state. No state has changed!")
            return None
        return ReportStatesRequest(states=states_to_report)

    async def _send_to_state_service(self, request: ReportStatesRequest):
        state_client = self._grpc_client_manager.get_client()
        try:
            response = await grpc_async_wrapper(
                state_client.ReportStates.future(
                    request,
                    DEFAULT_GRPC_TIMEOUT,
                ),
                self._loop,
            )
        except grpc.RpcError as err:
            logging.error("GRPC call failed for state replication: %s", err)
        else:
            unreplicated_states = set()
            for idAndError in response.unreportedStates:
                logging.warning(
                    "Failed to replicate state for (%s,%s): %s",
                    idAndError.type, idAndError.deviceID, idAndError.error,
                )
                unreplicated_states.add(
                    (idAndError.type, idAndError.deviceID),
                )
            # Update in-memory map for successfully reported states
            for state in request.states:
                if (state.type, state.deviceID) in unreplicated_states:
                    continue
                in_mem_key = self.make_mem_key(state.deviceID, state.type)
                self._state_versions[in_mem_key] = state.version

                logging.debug(
                    "Successfully replicated state for: "
                    "deviceID: %s, "
                    "type: %s, "
                    "version: %d",
                    state.deviceID, state.type, state.version,
                )
        finally:
            # reset timeout to config-specified + some buffer
            self.set_timeout(self._interval * 2)

    def _parse_key(self, key):
        split_key = key.split(REDIS_KEY_DELIMITER, 1)
        if len(split_key) != 2:
            raise ValueError(
                "Redis key: %s is not of format <id>:<type>. "
                "Not replicating." % key,
            )
        idval = split_key[0]
        state_type = split_key[1]
        if state_type not in self._serdes:
            raise ValueError(
                "No serde found for state type: %s. "
                "Not replicating key: %s" % (state_type, idval),
            )
        return idval, state_type

    @staticmethod
    def make_mem_key(device_id, state_type):
        """
        Create a key of the format <id>:<type>
        """
        return device_id + ":" + state_type

    @staticmethod
    def make_scoped_device_id(id, scope):
        """
        Create a deviceID of the format <id> for scope 'network'.
        Otherwise create a key of the format <hwid>:<id> for 'gateway'
        or an unrecognized scope.
        """
        if scope == "network":
            return id
        else:
            return snowflake.snowflake() + ":" + id
def setUp(self):
    self.mock_redis = fakeredis.FakeStrictRedis()
    self.loop = asyncio.new_event_loop()
    asyncio.set_event_loop(self.loop)
    service = MagicMock()
    service.config = {
        # Replicate arbitrary orc8r protos
        'state_protos': [
            {
                'proto_file': 'orc8r.protos.common_pb2',
                'proto_msg': 'NetworkID',
                'redis_key': NID_TYPE,
                'state_scope': 'network',
            },
            {
                'proto_file': 'orc8r.protos.common_pb2',
                'proto_msg': 'IDList',
                'redis_key': IDList_TYPE,
                'state_scope': 'gateway',
            },
            {
                'proto_file': 'orc8r.protos.service303_pb2',
                'proto_msg': 'LogVerbosity',
                'redis_key': LOG_TYPE,
                'state_scope': 'gateway',
            },
        ],
        'json_state': [{
            'redis_key': FOO_TYPE,
            'state_scope': 'network',
        }],
    }
    service.loop = self.loop

    # Bind the rpc server to a free port
    self._rpc_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
    )
    port = self._rpc_server.add_insecure_port('0.0.0.0:0')
    # Add the servicer
    self._servicer = DummyStateServer()
    self._servicer.add_to_server(self._rpc_server)
    self._rpc_server.start()
    # Create a rpc stub
    self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))

    serde1 = RedisSerde(
        NID_TYPE,
        get_proto_serializer(),
        get_proto_deserializer(NetworkID),
    )
    serde2 = RedisSerde(
        IDList_TYPE,
        get_proto_serializer(),
        get_proto_deserializer(IDList),
    )
    serde3 = RedisSerde(
        LOG_TYPE,
        get_proto_serializer(),
        get_proto_deserializer(LogVerbosity),
    )
    serde4 = RedisSerde(
        FOO_TYPE,
        get_json_serializer(),
        get_json_deserializer(),
    )

    self.nid_client = RedisFlatDict(self.mock_redis, serde1)
    self.idlist_client = RedisFlatDict(self.mock_redis, serde2)
    self.log_client = RedisFlatDict(self.mock_redis, serde3)
    self.foo_client = RedisFlatDict(self.mock_redis, serde4)

    # Set up and start state replicating loop
    grpc_client_manager = GRPCClientManager(
        service_name="state",
        service_stub=StateServiceStub,
        max_client_reuse=60,
    )

    # mock the get_default_client function used to return the same
    # fakeredis object
    func_mock = mock.MagicMock(return_value=self.mock_redis)
    with mock.patch(
        'magma.state.redis_dicts.get_default_client',
        func_mock,
    ):
        garbage_collector = GarbageCollector(service, grpc_client_manager)
        self.state_replicator = StateReplicator(
            service=service,
            garbage_collector=garbage_collector,
            grpc_client_manager=grpc_client_manager,
        )
    self.state_replicator.start()
class GatewayDirectoryServiceRpcServicer(GatewayDirectoryServiceServicer):
    """gRPC based server for the Directoryd Gateway service"""

    def __init__(self, print_grpc_payload: bool = False):
        """Initialize Directoryd grpc endpoints."""
        serde = RedisSerde(
            DIRECTORYD_REDIS_TYPE,
            get_json_serializer(),
            get_json_deserializer(),
        )
        self._redis_dict = RedisFlatDict(get_default_client(), serde)
        self._print_grpc_payload = print_grpc_payload

        if self._print_grpc_payload:
            logging.info("Printing GRPC messages")

    def add_to_server(self, server):
        """ Add the servicer to a gRPC server """
        add_GatewayDirectoryServiceServicer_to_server(self, server)

    @return_void
    def UpdateRecord(self, request, context):
        """ Update the directory record of an object

        Args:
            request (UpdateRecordRequest): update record request
        """
        logging.debug("UpdateRecord request received")
        self._print_grpc(request)
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "ID argument cannot be empty in "
                "UpdateRecordRequest",
            )
            return

        try:
            # Lock Redis for requested key until update is complete
            with self._redis_dict.lock(request.id):
                hwid = get_gateway_hwid()
                record = self._redis_dict.get(request.id) or \
                    DirectoryRecord(
                        location_history=[hwid],
                        identifiers={},
                    )
                if record.location_history[0] != hwid:
                    record.location_history = \
                        [hwid] + record.location_history
                for field_key in request.fields:
                    record.identifiers[field_key] = \
                        request.fields[field_key]
                # Truncate location history to the five most recent hwid's
                record.location_history = \
                    record.location_history[:LOCATION_MAX_LEN]
                self._redis_dict[request.id] = record
        except (RedisError, LockError) as e:
            logging.error(e)
            context.set_code(grpc.StatusCode.UNAVAILABLE)
            context.set_details("Could not connect to redis: %s" % e)

    @return_void
    def DeleteRecord(self, request, context):
        """ Delete the directory record for an ID

        Args:
            request (DeleteRecordRequest): delete record request
        """
        logging.debug("DeleteRecord request received")
        self._print_grpc(request)
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "ID argument cannot be empty in "
                "DeleteRecordRequest",
            )
            return

        # Lock Redis for requested key until delete is complete
        try:
            with self._redis_dict.lock(request.id):
                if request.id not in self._redis_dict:
                    context.set_code(grpc.StatusCode.NOT_FOUND)
                    context.set_details(
                        "Record for ID %s was not found." % request.id,
                    )
                    return
                self._redis_dict.mark_as_garbage(request.id)
        except (RedisError, LockError) as e:
            logging.error(e)
            context.set_code(grpc.StatusCode.UNAVAILABLE)
            context.set_details("Could not connect to redis: %s" % e)

    def GetDirectoryField(self, request, context):
        """ Get the directory record field for an ID and key

        Args:
            request (GetDirectoryFieldRequest): get directory field request
        """
        logging.debug("GetDirectoryField request received")
        self._print_grpc(request)
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "ID argument cannot be empty in "
                "GetDirectoryFieldRequest",
            )
            return
        if len(request.field_key) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "Field key argument cannot be empty in "
                "GetDirectoryFieldRequest",
            )
            response = DirectoryField()
            self._print_grpc(response)
            return response

        # Lock Redis for requested key until get is complete
        try:
            with self._redis_dict.lock(request.id):
                if request.id not in self._redis_dict:
                    context.set_code(grpc.StatusCode.NOT_FOUND)
                    context.set_details(
                        "Record for ID %s was not found." % request.id,
                    )
                    return DirectoryField()
                record = self._redis_dict[request.id]
        except (RedisError, LockError) as e:
            logging.error(e)
            context.set_code(grpc.StatusCode.UNAVAILABLE)
            context.set_details("Could not connect to redis: %s" % e)
            response = DirectoryField()
            self._print_grpc(response)
            return response

        if request.field_key not in record.identifiers:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(
                "Field %s was not found in record for "
                "ID %s" % (request.field_key, request.id),
            )
            return DirectoryField()

        response = DirectoryField(
            key=request.field_key,
            value=record.identifiers[request.field_key],
        )
        self._print_grpc(response)
        return response

    def GetAllDirectoryRecords(self, request, context):
        """ Get all directory records

        Args:
            request (Void): void
        """
        logging.debug("GetAllDirectoryRecords request received")
        self._print_grpc(request)
        response = AllDirectoryRecords()
        try:
            redis_keys = self._redis_dict.keys()
        except RedisError as e:
            logging.error(e)
            context.set_code(grpc.StatusCode.UNAVAILABLE)
            context.set_details("Could not connect to redis: %s" % e)
            self._print_grpc(response)
            return response

        for key in redis_keys:
            try:
                with self._redis_dict.lock(key):
                    # Lookup may produce an exception if the key has been
                    # deleted between the call to __iter__ and lock
                    stored_record = self._redis_dict[key]
            except (RedisError, LockError) as e:
                logging.error(e)
                context.set_code(grpc.StatusCode.UNAVAILABLE)
                context.set_details("Could not connect to redis: %s" % e)
                self._print_grpc(response)
                return response
            except KeyError:
                continue
            directory_record = response.records.add()
            directory_record.id = key
            directory_record.location_history[:] = \
                stored_record.location_history
            for identifier_key in stored_record.identifiers:
                directory_record.fields[identifier_key] = \
                    stored_record.identifiers[identifier_key]
        self._print_grpc(response)
        return response

    def _print_grpc(self, message):
        if self._print_grpc_payload:
            log_msg = "{} {}".format(
                message.DESCRIPTOR.full_name,
                MessageToJson(message),
            )
            # add indentation
            padding = 2 * ' '
            log_msg = ''.join(
                "{}{}".format(padding, line)
                for line in log_msg.splitlines(True)
            )
            log_msg = "GRPC message:\n{}".format(log_msg)
            logging.info(log_msg)
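# --- Client-side sketch (illustrative; the proto module paths below are
# assumptions, not confirmed by this file) --- Exercising the servicer above
# through its generated stub over an existing grpc `channel`:
#
#   from orc8r.protos.directoryd_pb2 import UpdateRecordRequest
#   from orc8r.protos.directoryd_pb2_grpc import GatewayDirectoryServiceStub
#
#   stub = GatewayDirectoryServiceStub(channel)
#   req = UpdateRecordRequest(id="imsi001")   # hypothetical record id
#   req.fields["ip"] = "10.0.0.1"
#   stub.UpdateRecord(req)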
class StateReplicatorTests(TestCase):
    @mock.patch("redis.Redis", MockRedis)
    def setUp(self):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        service = MagicMock()
        service.config = {
            # Replicate arbitrary orc8r protos
            'state_protos': [
                {
                    'proto_file': 'orc8r.protos.common_pb2',
                    'proto_msg': 'NetworkID',
                    'redis_key': NID_TYPE,
                    'state_scope': 'network',
                },
                {
                    'proto_file': 'orc8r.protos.common_pb2',
                    'proto_msg': 'IDList',
                    'redis_key': IDList_TYPE,
                    'state_scope': 'gateway',
                },
                {
                    'proto_file': 'orc8r.protos.service303_pb2',
                    'proto_msg': 'LogVerbosity',
                    'redis_key': LOG_TYPE,
                    'state_scope': 'gateway',
                },
            ],
            'json_state': [{
                'redis_key': FOO_TYPE,
                'state_scope': 'network',
            }],
        }
        service.loop = self.loop

        # Bind the rpc server to a free port
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10),
        )
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')
        # Add the servicer
        self._servicer = DummyStateServer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()
        # Create a rpc stub
        self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))

        serde1 = RedisSerde(
            NID_TYPE,
            get_proto_serializer(),
            get_proto_deserializer(NetworkID),
        )
        serde2 = RedisSerde(
            IDList_TYPE,
            get_proto_serializer(),
            get_proto_deserializer(IDList),
        )
        serde3 = RedisSerde(
            LOG_TYPE,
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )
        serde4 = RedisSerde(
            FOO_TYPE,
            get_json_serializer(),
            get_json_deserializer(),
        )

        self.nid_client = RedisFlatDict(get_default_client(), serde1)
        self.idlist_client = RedisFlatDict(get_default_client(), serde2)
        self.log_client = RedisFlatDict(get_default_client(), serde3)
        self.foo_client = RedisFlatDict(get_default_client(), serde4)

        # Set up and start state replicating loop
        grpc_client_manager = GRPCClientManager(
            service_name="state",
            service_stub=StateServiceStub,
            max_client_reuse=60,
        )
        garbage_collector = GarbageCollector(service, grpc_client_manager)
        self.state_replicator = StateReplicator(
            service=service,
            garbage_collector=garbage_collector,
            grpc_client_manager=grpc_client_manager,
        )
        self.state_replicator.start()

    @mock.patch("redis.Redis", MockRedis)
    def tearDown(self):
        self._rpc_server.stop(None)
        self.state_replicator.stop()
        self.loop.close()

    def convert_msg_to_state(self, redis_state, is_proto=True):
        if is_proto:
            json_converted_state = MessageToDict(redis_state)
            serialized_json_state = json.dumps(json_converted_state)
        else:
            serialized_json_state = jsonpickle.encode(redis_state)
        return serialized_json_state.encode("utf-8")

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_collect_states_to_replicate(self):
        async def test():
            # Ensure setup is initialized properly
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            self.foo_client[key] = Foo("boo", 3)
            exp1 = self.convert_msg_to_state(self.nid_client[key])
            exp2 = self.convert_msg_to_state(self.idlist_client[key])
            exp3 = self.convert_msg_to_state(self.foo_client[key], False)

            req = await \
                self.state_replicator._collect_states_to_replicate()
            self.assertEqual(3, len(req.states))
            for state in req.states:
                if state.type == NID_TYPE:
                    self.assertEqual('id1', state.deviceID)
                    self.assertEqual(1, state.version)
                    self.assertEqual(exp1, state.value)
                elif state.type == IDList_TYPE:
                    self.assertEqual('aaa-bbb:id1', state.deviceID)
                    self.assertEqual(1, state.version)
                    self.assertEqual(exp2, state.value)
                elif state.type == FOO_TYPE:
                    self.assertEqual('id1', state.deviceID)
                    self.assertEqual(1, state.version)
                    self.assertEqual(exp3, state.value)
                else:
                    self.fail("Unknown state type %s" % state.type)

            # Cancel the replicator's loop so there are no other activities
            self.state_replicator._periodic_task.cancel()

        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_replicate_states_success(self, get_rpc_mock):
        async def test():
            get_rpc_mock.return_value = self.channel
            # Ensure setup is initialized properly
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'
            foo = Foo("boo", 4)
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            self.foo_client[key] = foo
            # Increment version
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])

            req = await \
                self.state_replicator._collect_states_to_replicate()
            self.assertEqual(3, len(req.states))

            # Ensure in-memory map updates properly
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(
                3, len(self.state_replicator._state_versions),
            )
            mem_key1 = make_mem_key('id1', NID_TYPE)
            mem_key2 = make_mem_key('aaa-bbb:id1', IDList_TYPE)
            mem_key3 = make_mem_key('id1', FOO_TYPE)
            self.assertEqual(
                1, self.state_replicator._state_versions[mem_key1],
            )
            self.assertEqual(
                2, self.state_replicator._state_versions[mem_key2],
            )
            self.assertEqual(
                1, self.state_replicator._state_versions[mem_key3],
            )

            # Now add new state and update some existing state
            key2 = 'id2'
            self.nid_client[key2] = NetworkID(id='bar')
            self.idlist_client[key] = IDList(ids=['bar', 'foo'])

            req = await \
                self.state_replicator._collect_states_to_replicate()
            self.assertEqual(2, len(req.states))

            # Ensure in-memory map updates properly
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(
                4, len(self.state_replicator._state_versions),
            )
            mem_key4 = make_mem_key('id2', NID_TYPE)
            self.assertEqual(
                1, self.state_replicator._state_versions[mem_key1],
            )
            self.assertEqual(
                3, self.state_replicator._state_versions[mem_key2],
            )
            self.assertEqual(
                1, self.state_replicator._state_versions[mem_key3],
            )
            self.assertEqual(
                1, self.state_replicator._state_versions[mem_key4],
            )

            # Cancel the replicator's loop so there are no other activities
            self.state_replicator._periodic_task.cancel()

        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_unreplicated_states(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            # Add initial state to be replicated
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'
            key2 = 'id2'
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            # Increment version
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            # Set state that will be 'unreplicated'
            self.log_client[key2] = LogVerbosity(verbosity=5)

            req = await \
                self.state_replicator._collect_states_to_replicate()
            self.assertEqual(3, len(req.states))

            # Ensure in-memory map updates properly for successful
            # replications
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(
                2, len(self.state_replicator._state_versions),
            )
            mem_key1 = make_mem_key('id1', NID_TYPE)
            mem_key2 = make_mem_key('aaa-bbb:id1', IDList_TYPE)
            self.assertEqual(
                1, self.state_replicator._state_versions[mem_key1],
            )
            self.assertEqual(
                2, self.state_replicator._state_versions[mem_key2],
            )

            # Now run again, ensuring only the state that wasn't replicated
            # will be sent again
            req = await \
                self.state_replicator._collect_states_to_replicate()
            self.assertEqual(1, len(req.states))
            self.assertEqual('aaa-bbb:id2', req.states[0].deviceID)
            self.assertEqual(LOG_TYPE, req.states[0].type)

            # Cancel the replicator's loop so there are no other activities
            self.state_replicator._periodic_task.cancel()

        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_resync_success(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'
            # Set state that will be 'unsynced'
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            # Increment state's version
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])

            await self.state_replicator._resync()
            self.assertEqual(
                True, self.state_replicator._has_resync_completed,
            )
            self.assertEqual(
                1, len(self.state_replicator._state_versions),
            )
            mem_key = make_mem_key('aaa-bbb:id1', IDList_TYPE)
            self.assertEqual(
                2, self.state_replicator._state_versions[mem_key],
            )

            # Cancel the replicator's loop so there are no other activities
            self.state_replicator._periodic_task.cancel()

        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_resync_failure(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            # Set state that will trigger the RpcError
            log_key = 'id1'
            self.log_client[log_key] = LogVerbosity(verbosity=5)

            try:
                await self.state_replicator._resync()
            except grpc.RpcError:
                pass
            self.assertEqual(
                False, self.state_replicator._has_resync_completed,
            )
            self.assertEqual(
                0, len(self.state_replicator._state_versions),
            )

            # Cancel the replicator's loop so there are no other activities
            self.state_replicator._periodic_task.cancel()

        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_deleted_replicated_state(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'
            self.nid_client[key] = NetworkID(id='foo')

            req = await \
                self.state_replicator._collect_states_to_replicate()
            self.assertEqual(1, len(req.states))

            # Ensure in-memory map updates properly
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(
                1, len(self.state_replicator._state_versions),
            )
            mem_key1 = make_mem_key('id1', NID_TYPE)
            self.assertEqual(
                1, self.state_replicator._state_versions[mem_key1],
            )

            # Now delete state and ensure in-memory map gets updated properly
            del self.nid_client[key]
            req = await \
                self.state_replicator._collect_states_to_replicate()
            self.assertIsNone(req)
            await self.state_replicator._cleanup_deleted_keys()
            self.assertFalse(
                key in self.state_replicator._state_versions,
            )

            # Cancel the replicator's loop so there are no other activities
            self.state_replicator._periodic_task.cancel()

        self.loop.run_until_complete(test())
class RedisDictTests(TestCase):
    """
    Tests for the RedisHashDict and RedisFlatDict containers
    """

    @mock.patch("redis.Redis", MockRedis)
    def setUp(self):
        client = get_default_client()
        # Use arbitrary orc8r proto to test with
        self._hash_dict = RedisHashDict(
            client,
            "unittest",
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )
        serdes = {}
        serdes['log_verbosity'] = RedisSerde(
            'log_verbosity',
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )
        self._flat_dict = RedisFlatDict(client, serdes)

    @mock.patch("redis.Redis", MockRedis)
    def test_hash_insert(self):
        expected = LogVerbosity(verbosity=0)
        expected2 = LogVerbosity(verbosity=1)

        # insert proto
        self._hash_dict['key1'] = expected
        version = self._hash_dict.get_version("key1")
        actual = self._hash_dict['key1']
        self.assertEqual(1, version)
        self.assertEqual(expected, actual)

        # update proto
        self._hash_dict['key1'] = expected2
        version2 = self._hash_dict.get_version("key1")
        actual2 = self._hash_dict['key1']
        self.assertEqual(2, version2)
        self.assertEqual(expected2, actual2)

    @mock.patch("redis.Redis", MockRedis)
    def test_missing_version(self):
        missing_version = self._hash_dict.get_version("key2")
        self.assertEqual(0, missing_version)

    @mock.patch("redis.Redis", MockRedis)
    def test_hash_delete(self):
        expected = LogVerbosity(verbosity=2)
        self._hash_dict['key3'] = expected
        actual = self._hash_dict['key3']
        self.assertEqual(expected, actual)

        self._hash_dict.pop('key3')
        self.assertRaises(KeyError, self._hash_dict.__getitem__, 'key3')

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_insert(self):
        expected = LogVerbosity(verbosity=5)
        expected2 = LogVerbosity(verbosity=1)

        # insert proto
        self._flat_dict['key1:log_verbosity'] = expected
        version = self._flat_dict.get_version("key1", "log_verbosity")
        actual = self._flat_dict['key1:log_verbosity']
        self.assertEqual(1, version)
        self.assertEqual(expected, actual)

        # update proto
        self._flat_dict["key1:log_verbosity"] = expected2
        version2 = self._flat_dict.get_version("key1", "log_verbosity")
        actual2 = self._flat_dict["key1:log_verbosity"]
        self.assertEqual(2, version2)
        self.assertEqual(expected2, actual2)

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_missing_version(self):
        missing_version = self._flat_dict.get_version(
            "key2", "log_verbosity",
        )
        self.assertEqual(0, missing_version)

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_invalid_key(self):
        expected = LogVerbosity(verbosity=5)
        self.assertRaises(
            ValueError, self._flat_dict.__setitem__, 'key3', expected,
        )

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_invalid_serde(self):
        expected = LogVerbosity(verbosity=5)
        self.assertRaises(
            ValueError,
            self._flat_dict.__setitem__,
            'key3:missing_serde',
            expected,
        )

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_delete(self):
        expected = LogVerbosity(verbosity=2)
        self._flat_dict['key3:log_verbosity'] = expected
        actual = self._flat_dict['key3:log_verbosity']
        self.assertEqual(expected, actual)

        self._flat_dict.pop('key3:log_verbosity')
        self.assertRaises(
            KeyError,
            self._flat_dict.__getitem__,
            'key3:log_verbosity',
        )
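# --- Composite-key sketch (illustrative) --- In this RedisFlatDict variant,
# keys carry their serde type inline as "<id>:<type>", and get_version()
# takes the two halves separately:
#
#   d = RedisFlatDict(client, serdes)
#   d["gw1:log_verbosity"] = LogVerbosity(verbosity=1)
#   d.get_version("gw1", "log_verbosity")    # -> 1
#   d["gw1"] = LogVerbosity(verbosity=1)     # ValueError: no type suffix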