示例#1
0
class ServiceStateWrapper:
    """
    Class wraps ServiceState interactions with redis.

    Service exit statuses are stored as ServiceExitStatus protos in a
    RedisFlatDict, keyed by service name.
    """

    # Unique typename for Redis key
    REDIS_VALUE_TYPE = "systemd_status"

    def __init__(self):
        serde = RedisSerde(self.REDIS_VALUE_TYPE, get_proto_serializer(),
                           get_proto_deserializer(ServiceExitStatus))
        self._flat_dict = RedisFlatDict(get_default_client(), serde)

    def update_service_status(self, service_name: str,
                              service_status: ServiceExitStatus) -> None:
        """
        Update the service exit status for a given service.

        The exit counters on ``service_status`` are overwritten in place
        before it is stored: the counter matching ``latest_service_result``
        (clean on SUCCESS, fail otherwise) becomes the previously stored
        value plus one, and the other counter is carried over unchanged.

        Args:
            service_name: key under which the status is stored
            service_status: latest exit status; mutated by this call
        """
        if service_name in self._flat_dict:
            current_service_status = self._flat_dict[service_name]
        else:
            # No previous report: a default-constructed proto is the baseline
            current_service_status = ServiceExitStatus()

        if service_status.latest_service_result == \
                ServiceExitStatus.ServiceResult.Value("SUCCESS"):
            service_status.num_clean_exits = \
                current_service_status.num_clean_exits + 1
            service_status.num_fail_exits = \
                current_service_status.num_fail_exits
        else:
            service_status.num_fail_exits = \
                current_service_status.num_fail_exits + 1
            service_status.num_clean_exits = \
                current_service_status.num_clean_exits
        self._flat_dict[service_name] = service_status

    def get_service_status(self, service_name: str) -> ServiceExitStatus:
        """
        Get the service status protobuf for a given service
        @returns ServiceExitStatus protobuf object
        @raises KeyError if no status is stored for service_name
        """
        return self._flat_dict[service_name]

    def get_all_services_status(self) -> dict:
        """
        Get a dict of service name to service status
        @return dict mapping service_name to its ServiceExitStatus
        """
        # Materialize into a plain dict so callers get a snapshot
        return dict(self._flat_dict.items())

    def cleanup_service_status(self) -> None:
        """
        Cleanup service status for all services in redis, mostly using for
        testing
        """
        self._flat_dict.clear()
示例#2
0
 def __init__(self):
     """Create the Redis flat dict holding ServiceExitStatus protos."""
     proto_serializer = get_proto_serializer()
     proto_deserializer = get_proto_deserializer(ServiceExitStatus)
     serde = RedisSerde(self.REDIS_VALUE_TYPE, proto_serializer,
                        proto_deserializer)
     redis_client = get_default_client()
     self._flat_dict = RedisFlatDict(redis_client, serde)
示例#3
0
    def __init__(self, print_grpc_payload: bool = False):
        """Initialize Directoryd grpc endpoints.

        Args:
            print_grpc_payload: stored on the instance; when true, a notice
                is logged at construction time.
        """
        json_serde = RedisSerde(
            DIRECTORYD_REDIS_TYPE,
            get_json_serializer(),
            get_json_deserializer(),
        )
        self._redis_dict = RedisFlatDict(get_default_client(), json_serde)
        self._print_grpc_payload = print_grpc_payload
        if not print_grpc_payload:
            return
        logging.info("Printing GRPC messages")
示例#4
0
    def setUp(self):
        """Build a hash dict and a flat dict over one default Redis client."""
        redis_client = get_default_client()

        # LogVerbosity is just a convenient orc8r proto to store in tests
        self._hash_dict = RedisHashDict(
            redis_client,
            "unittest",
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )

        flat_serde = RedisSerde(
            'log_verbosity',
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )
        self._flat_dict = RedisFlatDict(redis_client, flat_serde)
示例#5
0
 async def _delete_state_from_redis(self, redis_dict: RedisFlatDict,
                                    key: str) -> None:
     """Remove ``key`` locally only if it is still marked as garbage.

     The per-key lock is held across the check-and-delete so the object
     cannot be updated in between; the outcome is only logged at debug level.
     """
     with redis_dict.lock(key):
         if redis_dict.delete_garbage(key):
             logging.debug("Successfully garbage collected state for key: %s",
                           key)
             return
         # The key was re-written after being marked, so keep the local copy
         logging.debug("Successfully garbage collected state in cloud for "
                       "key %s. Didn't delete locally as the object is no "
                       "longer garbage", key)
示例#6
0
    def __init__(self, service: MagmaService,
                 grpc_client_manager: GRPCClientManager):
        super().__init__(DEFAULT_SYNC_INTERVAL, service.loop)
        self._service = service

        # Client manager used to obtain (and recycle) the gRPC client
        self._grpc_client_manager = grpc_client_manager
        # Replication is held back until a full resync succeeds
        self._has_resync_completed = False

        # Last known version per replicated state, kept in memory only
        self._state_versions = {}
        # Per-type serdes, populated by the _get_*_redis_serdes helpers
        self._serdes = {}
        self._get_proto_redis_serdes()
        self._get_json_redis_serdes()
        self._redis_client = RedisFlatDict(get_default_client(), self._serdes)
示例#7
0
    def setUp(self):
        """Start a dummy state gRPC server, Redis clients and the collector."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        # Replicate arbitrary orc8r protos plus one JSON state type
        nid_proto_cfg = {
            'proto_file': 'orc8r.protos.common_pb2',
            'proto_msg': 'NetworkID',
            'redis_key': NID_TYPE,
            'state_scope': 'network',
        }
        log_proto_cfg = {
            'proto_file': 'orc8r.protos.service303_pb2',
            'proto_msg': 'LogVerbosity',
            'redis_key': LOG_TYPE,
            'state_scope': 'gateway',
        }
        foo_json_cfg = {'redis_key': FOO_TYPE, 'state_scope': 'gateway'}

        service = MagicMock()
        service.config = {
            'state_protos': [nid_proto_cfg, log_proto_cfg],
            'json_state': [foo_json_cfg],
        }
        service.loop = self.loop

        # Serve the dummy state servicer on an OS-assigned free port
        executor = futures.ThreadPoolExecutor(max_workers=10)
        self._rpc_server = grpc.server(executor)
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')
        self._servicer = DummyStateServer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()
        # Client channel pointed at that server
        self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))

        nid_serde = RedisSerde(NID_TYPE, get_proto_serializer(),
                               get_proto_deserializer(NetworkID))
        foo_serde = RedisSerde(FOO_TYPE, get_json_serializer(),
                               get_json_deserializer())
        log_serde = RedisSerde(LOG_TYPE, get_proto_serializer(),
                               get_proto_deserializer(LogVerbosity))

        self.nid_client = RedisFlatDict(get_default_client(), nid_serde)
        self.foo_client = RedisFlatDict(get_default_client(), foo_serde)
        self.log_client = RedisFlatDict(get_default_client(), log_serde)

        manager = GRPCClientManager(
            service_name="state",
            service_stub=StateServiceStub,
            max_client_reuse=60,
        )
        # Collector under test
        self.garbage_collector = GarbageCollector(service, manager)
示例#8
0
class RedisDictTests(TestCase):
    """
    Exercise the RedisHashDict and RedisFlatDict containers against a
    mocked Redis client.
    """

    @mock.patch("redis.Redis", MockRedis)
    def setUp(self):
        redis_client = get_default_client()
        # LogVerbosity is just a convenient orc8r proto to store
        self._hash_dict = RedisHashDict(
            redis_client,
            "unittest",
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )
        flat_serde = RedisSerde(
            'log_verbosity',
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )
        self._flat_dict = RedisFlatDict(redis_client, flat_serde)

    @mock.patch("redis.Redis", MockRedis)
    def test_hash_insert(self):
        first = LogVerbosity(verbosity=0)
        second = LogVerbosity(verbosity=1)

        # A fresh insert starts at version 1
        self._hash_dict['key1'] = first
        self.assertEqual(self._hash_dict.get_version("key1"), 1)
        self.assertEqual(self._hash_dict['key1'], first)

        # Overwriting bumps the version
        self._hash_dict['key1'] = second
        self.assertEqual(self._hash_dict.get_version("key1"), 2)
        self.assertEqual(self._hash_dict['key1'], second)

    @mock.patch("redis.Redis", MockRedis)
    def test_missing_version(self):
        self.assertEqual(self._hash_dict.get_version("key2"), 0)

    @mock.patch("redis.Redis", MockRedis)
    def test_hash_delete(self):
        value = LogVerbosity(verbosity=2)
        self._hash_dict['key3'] = value
        self.assertEqual(self._hash_dict['key3'], value)

        self._hash_dict.pop('key3')
        with self.assertRaises(KeyError):
            _ = self._hash_dict['key3']

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_insert(self):
        first = LogVerbosity(verbosity=5)
        second = LogVerbosity(verbosity=1)

        # A fresh insert starts at version 1
        self._flat_dict['key1'] = first
        self.assertEqual(self._flat_dict.get_version("key1"), 1)
        self.assertEqual(self._flat_dict['key1'], first)

        # Overwriting bumps the version; [] and get() agree
        self._flat_dict["key1"] = second
        self.assertEqual(self._flat_dict.get_version("key1"), 2)
        self.assertEqual(self._flat_dict["key1"], second)
        self.assertEqual(self._flat_dict.get("key1"), second)

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_missing_version(self):
        self.assertEqual(self._flat_dict.get_version("key2"), 0)

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_bad_key(self):
        value = LogVerbosity(verbosity=2)
        # Keys containing ':' are rejected by every accessor
        with self.assertRaises(ValueError):
            self._flat_dict['bad:key'] = value
        with self.assertRaises(ValueError):
            _ = self._flat_dict['bad:key']
        with self.assertRaises(ValueError):
            del self._flat_dict['bad:key']

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_delete(self):
        value = LogVerbosity(verbosity=2)
        self._flat_dict['key3'] = value
        self.assertEqual(self._flat_dict['key3'], value)

        del self._flat_dict['key3']
        with self.assertRaises(KeyError):
            _ = self._flat_dict['key3']
        self.assertIsNone(self._flat_dict.get('key3'))

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_clear(self):
        value = LogVerbosity(verbosity=2)
        self._flat_dict['key3'] = value
        self.assertEqual(self._flat_dict['key3'], value)

        self._flat_dict.clear()
        self.assertEqual(len(self._flat_dict.keys()), 0)

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_garbage_methods(self):
        garbage_value = LogVerbosity(verbosity=2)
        live_value = LogVerbosity(verbosity=3)
        garbage_key, live_key, unknown_key = "k1", "k2", "bad_key"

        self._flat_dict[garbage_key] = garbage_value
        self._flat_dict[live_key] = live_value

        # Only the marked key is reported as garbage
        self._flat_dict.mark_as_garbage(garbage_key)
        self.assertTrue(self._flat_dict.is_garbage(garbage_key))
        self.assertFalse(self._flat_dict.is_garbage(live_key))

        # Garbage keys are hidden from keys() and get()
        self.assertEqual(self._flat_dict.garbage_keys(), [garbage_key])
        self.assertEqual(self._flat_dict.keys(), [live_key])
        self.assertIsNone(self._flat_dict.get(garbage_key))
        self.assertEqual(self._flat_dict.get(live_key), live_value)

        # delete_garbage only removes keys that are marked as garbage
        self.assertTrue(self._flat_dict.delete_garbage(garbage_key))
        self.assertFalse(self._flat_dict.delete_garbage(live_key))
        self.assertIsNone(self._flat_dict.get(garbage_key))
        self.assertEqual(self._flat_dict.get(live_key), live_value)

        # Unknown keys cannot be inspected or marked
        with self.assertRaises(KeyError):
            self._flat_dict.is_garbage(unknown_key)
        with self.assertRaises(KeyError):
            self._flat_dict.mark_as_garbage(unknown_key)
示例#9
0
class GarbageCollectorTests(TestCase):
    """
    Tests for GarbageCollector against a dummy state gRPC server and a
    fakeredis backend shared by the test clients and the collector.
    """
    def setUp(self):
        self.mock_redis = fakeredis.FakeStrictRedis()

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        service = MagicMock()
        service.config = {
            # Replicate arbitrary orc8r protos
            'state_protos': [{'proto_file': 'orc8r.protos.common_pb2',
                              'proto_msg': 'NetworkID',
                              'redis_key': NID_TYPE,
                              'state_scope': 'network'},
                             {'proto_file': 'orc8r.protos.service303_pb2',
                              'proto_msg': 'LogVerbosity',
                              'redis_key': LOG_TYPE,
                              'state_scope': 'gateway'},
                             ],
            'json_state': [{'redis_key': FOO_TYPE, 'state_scope': 'gateway'}]
        }
        service.loop = self.loop

        # Bind the rpc server to a free port
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10)
        )
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')
        # Add the servicer
        self._servicer = DummyStateServer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()
        # Create a rpc stub
        self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))

        serde1 = RedisSerde(NID_TYPE,
                            get_proto_serializer(),
                            get_proto_deserializer(NetworkID))
        serde2 = RedisSerde(FOO_TYPE,
                            get_json_serializer(),
                            get_json_deserializer())
        serde3 = RedisSerde(LOG_TYPE,
                            get_proto_serializer(),
                            get_proto_deserializer(LogVerbosity))

        self.nid_client = RedisFlatDict(self.mock_redis, serde1)
        self.foo_client = RedisFlatDict(self.mock_redis, serde2)
        self.log_client = RedisFlatDict(self.mock_redis, serde3)

        # Set up and start garbage collecting loop
        grpc_client_manager = GRPCClientManager(
            service_name="state",
            service_stub=StateServiceStub,
            max_client_reuse=60,
        )

        # mock the get_default_client function used to return the same
        # fakeredis object
        func_mock = mock.MagicMock(return_value=self.mock_redis)
        with patch('magma.state.redis_dicts.get_default_client', func_mock):
            # Start state garbage collection loop
            self.garbage_collector = GarbageCollector(service,
                                                      grpc_client_manager)

    def tearDown(self):
        # Close the client channel opened in setUp so its resources do not
        # leak into later test cases, then stop the server and the loop.
        self.channel.close()
        self._rpc_server.stop(None)
        self.loop.close()

    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_collect_states_to_delete(self):
        """Only states marked as garbage end up in the delete request."""
        async def test():
            # Ensure setup is initialized properly
            self.nid_client.clear()
            self.foo_client.clear()
            self.log_client.clear()

            key = 'id1'
            self.nid_client[key] = NetworkID(id='foo')
            self.foo_client[key] = Foo("boo", 3)
            req = await self.garbage_collector._collect_states_to_delete()
            self.assertIsNone(req)

            self.nid_client.mark_as_garbage(key)
            self.foo_client.mark_as_garbage(key)
            req = await self.garbage_collector._collect_states_to_delete()
            self.assertEqual(2, len(req.ids))
            for state_id in req.ids:
                if state_id.type == NID_TYPE:
                    self.assertEqual('id1', state_id.deviceID)
                elif state_id.type == FOO_TYPE:
                    self.assertEqual('aaa-bbb:id1', state_id.deviceID)
                else:
                    self.fail("Unknown state type %s" % state_id.type)

            # Cleanup
            del self.foo_client[key]
            del self.nid_client[key]

        self.loop.run_until_complete(test())

    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_garbage_collect_success(self, get_rpc_mock):
        """A successful delete RPC removes the garbage states from Redis."""
        async def test():
            get_rpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.foo_client.clear()
            self.log_client.clear()

            key = 'id1'
            foo = Foo("boo", 4)
            self.nid_client[key] = NetworkID(id='foo')
            self.foo_client[key] = foo
            self.nid_client.mark_as_garbage(key)
            self.foo_client.mark_as_garbage(key)
            req = await self.garbage_collector._collect_states_to_delete()
            self.assertEqual(2, len(req.ids))

            # Ensure all garbage collected objects get deleted from Redis
            await self.garbage_collector._send_to_state_service(req)
            self.assertEqual(0, len(self.nid_client.keys()))
            self.assertEqual(0, len(self.foo_client.keys()))
            self.assertEqual(0, len(self.nid_client.garbage_keys()))
            self.assertEqual(0, len(self.foo_client.garbage_keys()))

        self.loop.run_until_complete(test())

    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_garbage_collect_rpc_failure(self, get_rpc_mock):
        """A failed delete RPC leaves the states marked as garbage."""
        async def test():
            get_rpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.foo_client.clear()
            self.log_client.clear()

            key = 'id1'
            self.nid_client[key] = NetworkID(id='foo')
            self.log_client[key] = LogVerbosity(verbosity=3)
            self.nid_client.mark_as_garbage(key)
            self.log_client.mark_as_garbage(key)

            req = await self.garbage_collector._collect_states_to_delete()
            self.assertEqual(2, len(req.ids))

            # Ensure objects are NOT deleted from Redis on RPC failure: they
            # stay hidden from keys() but remain marked as garbage
            await self.garbage_collector._send_to_state_service(req)
            self.assertEqual(0, len(self.nid_client.keys()))
            self.assertEqual(0, len(self.log_client.keys()))
            self.assertEqual(1, len(self.nid_client.garbage_keys()))
            self.assertEqual(1, len(self.log_client.garbage_keys()))

            # Cleanup
            del self.log_client[key]
            del self.nid_client[key]

        self.loop.run_until_complete(test())

    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_garbage_collect_with_state_update(self, get_rpc_mock):
        """A state re-written after collection must survive the delete."""
        async def test():
            get_rpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.foo_client.clear()
            self.log_client.clear()

            key = 'id1'
            foo = Foo("boo", 4)
            self.nid_client[key] = NetworkID(id='foo')
            self.foo_client[key] = foo
            self.nid_client.mark_as_garbage(key)
            self.foo_client.mark_as_garbage(key)
            req = await self.garbage_collector._collect_states_to_delete()
            self.assertEqual(2, len(req.ids))

            # Update one of the states, to ensure we don't delete valid state
            # from Redis
            expected = NetworkID(id='bar')
            self.nid_client[key] = expected

            # Only the still-garbage foo state is deleted; the updated nid
            # state is kept and is no longer marked as garbage
            await self.garbage_collector._send_to_state_service(req)
            self.assertEqual(1, len(self.nid_client.keys()))
            self.assertEqual(0, len(self.foo_client.keys()))
            self.assertEqual(0, len(self.nid_client.garbage_keys()))
            self.assertEqual(0, len(self.foo_client.garbage_keys()))
            self.assertEqual(expected, self.nid_client[key])

        self.loop.run_until_complete(test())
示例#10
0
 def __init__(self):
     """Create the Redis flat dict storing JSON values for directoryd."""
     json_serde = RedisSerde(
         DIRECTORYD_REDIS_TYPE,
         get_json_serializer(),
         get_json_deserializer(),
     )
     self._redis_dict = RedisFlatDict(get_default_client(), json_serde)
示例#11
0
class GatewayDirectoryServiceRpcServicer(GatewayDirectoryServiceServicer):
    """ gRPC based server for the Directoryd Gateway service.

    Directory records are stored as JSON in a RedisFlatDict keyed by ID.
    """
    def __init__(self):
        serde = RedisSerde(DIRECTORYD_REDIS_TYPE, get_json_serializer(),
                           get_json_deserializer())
        self._redis_dict = RedisFlatDict(get_default_client(), serde)

    def add_to_server(self, server):
        """ Add the servicer to a gRPC server """
        add_GatewayDirectoryServiceServicer_to_server(self, server)

    @return_void
    def UpdateRecord(self, request, context):
        """ Update the directory record of an object

        Creates the record if missing, prepends this gateway's hwid to the
        location history unless it is already the most recent entry, merges
        the request fields into the record identifiers and truncates the
        history to LOCATION_MAX_LEN entries.

        Args:
            request (UpdateRecordRequest): update record request
        """
        if not request.id:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("ID argument cannot be empty in "
                                "UpdateRecordRequest")
            return

        # Lock Redis for requested key until update is complete
        with self._redis_dict.lock(request.id):
            hwid = get_gateway_hwid()
            record = self._redis_dict.get(request.id) or \
                     DirectoryRecord(location_history=[hwid], identifiers={})

            # A stored record may carry an empty history; indexing [0]
            # would raise IndexError, so treat it as "hwid not present"
            if not record.location_history or \
                    record.location_history[0] != hwid:
                record.location_history = [hwid] + record.location_history

            for field_key in request.fields:
                record.identifiers[field_key] = request.fields[field_key]

            # Keep only the LOCATION_MAX_LEN most recent hwids
            record.location_history = \
                record.location_history[:LOCATION_MAX_LEN]
            self._redis_dict[request.id] = record

    @return_void
    def DeleteRecord(self, request, context):
        """ Delete the directory record for an ID

        The record is marked as garbage rather than removed immediately.

        Args:
            request (DeleteRecordRequest): delete record request
        """
        if not request.id:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("ID argument cannot be empty in "
                                "DeleteRecordRequest")
            return

        # Lock Redis for requested key until delete is complete
        with self._redis_dict.lock(request.id):
            if request.id not in self._redis_dict:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details("Record for ID %s was not found." %
                                    request.id)
                return
            self._redis_dict.mark_as_garbage(request.id)

    def GetDirectoryField(self, request, context):
        """ Get the directory record field for an ID and key

        Args:
            request (GetDirectoryFieldRequest): get directory field request
        Returns:
            DirectoryField on success; None after setting an error status
            on ``context`` otherwise.
        """
        if not request.id:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("ID argument cannot be empty in "
                                "GetDirectoryFieldRequest")
            return
        if not request.field_key:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("Field key argument cannot be empty in "
                                "GetDirectoryFieldRequest")
            return

        # Lock Redis for requested key until get is complete
        with self._redis_dict.lock(request.id):
            if request.id not in self._redis_dict:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details("Record for ID %s was not found." %
                                    request.id)
                return
            record = self._redis_dict[request.id]
            if request.field_key not in record.identifiers:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details("Field %s was not found in record for "
                                    "ID %s" % (request.field_key, request.id))
                return
            return DirectoryField(key=request.field_key,
                                  value=record.identifiers[request.field_key])

    def GetAllDirectoryRecords(self, request, context):
        """ Get all directory records

        Args:
            request (Void): void
        Returns:
            AllDirectoryRecords with one entry per readable stored record
        """
        response = AllDirectoryRecords()
        for key in self._redis_dict.keys():
            with self._redis_dict.lock(key):
                # Lookup may produce an exception if the key has been deleted
                # between the call to __iter__ and lock
                try:
                    stored_record = self._redis_dict[key]
                except KeyError:
                    continue
                directory_record = response.records.add()
                directory_record.id = key
                directory_record.location_history[:] = \
                    stored_record.location_history
                for identifier_key in stored_record.identifiers:
                    directory_record.fields[identifier_key] = \
                        stored_record.identifiers[identifier_key]

        return response
示例#12
0
class GatewayDirectoryServiceRpcServicer(GatewayDirectoryServiceServicer):
    """ gRPC based server for the Directoryd Gateway service.

    Directory records are stored as JSON in a RedisFlatDict keyed by ID.
    """
    def __init__(self):
        serde = RedisSerde(DIRECTORYD_REDIS_TYPE, get_json_serializer(),
                           get_json_deserializer())
        self._redis_dict = RedisFlatDict(get_default_client(), serde)

    def add_to_server(self, server):
        """ Add the servicer to a gRPC server """
        add_GatewayDirectoryServiceServicer_to_server(self, server)

    @return_void
    def UpdateRecord(self, request, context):
        """ Update the directory record of an object

        Creates the record if missing, prepends this gateway's hwid to the
        location history unless it is already the most recent entry, copies
        each request field (key/value) into the record identifiers and
        truncates the history to LOCATION_MAX_LEN entries.

        Args:
            request (UpdateRecordRequest): update record request
        """
        if not request.id:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("ID argument cannot be empty in "
                                "UpdateRecordRequest")
            return

        # Lock Redis for requested key until update is complete
        with self._redis_dict.lock(request.id):
            hwid = get_gateway_hwid()
            record = self._redis_dict.get(request.id) or \
                     DirectoryRecord(location_history=[hwid], identifiers={})

            # A stored record may carry an empty history; indexing [0]
            # would raise IndexError, so treat it as "hwid not present"
            if not record.location_history or \
                    record.location_history[0] != hwid:
                record.location_history = [hwid] + record.location_history

            for field in request.fields:
                record.identifiers[field.key] = field.value

            # Keep only the LOCATION_MAX_LEN most recent hwids
            record.location_history = \
                record.location_history[:LOCATION_MAX_LEN]
            self._redis_dict[request.id] = record

    @return_void
    def DeleteRecord(self, request, context):
        """ Delete the directory record for an ID

        Args:
            request (DeleteRecordRequest): delete record request
        """
        if not request.id:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("ID argument cannot be empty in "
                                "DeleteRecordRequest")
            return

        # Lock Redis for requested key until delete is complete
        with self._redis_dict.lock(request.id):
            if request.id not in self._redis_dict:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details("Record for ID %s was not found." %
                                    request.id)
                return
            # TODO: Set record to be garbage collected rather than deleting
            # directly
            del self._redis_dict[request.id]

    def GetDirectoryField(self, request, context):
        """ Get the directory record field for an ID and key

        Args:
            request (GetDirectoryFieldRequest): get directory field request
        Returns:
            DirectoryField on success; None after setting an error status
            on ``context`` otherwise.
        """
        if not request.id:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("ID argument cannot be empty in "
                                "GetDirectoryFieldRequest")
            return
        if not request.field_key:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("Field key argument cannot be empty in "
                                "GetDirectoryFieldRequest")
            return

        # Lock Redis for requested key until get is complete
        with self._redis_dict.lock(request.id):
            if request.id not in self._redis_dict:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details("Record for ID %s was not found." %
                                    request.id)
                return
            record = self._redis_dict[request.id]
            if request.field_key not in record.identifiers:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details("Field %s was not found in record for "
                                    "ID %s" % (request.field_key, request.id))
                return
            return DirectoryField(key=request.field_key,
                                  value=record.identifiers[request.field_key])
示例#13
0
class StateReplicator(SDWatchdogTask):
    """
    StateReplicator periodically fetches all configured state from Redis,
    reporting any updates to the Orchestrator State service.

    Replicated state types come from the service config: 'state_protos'
    entries are protobuf-encoded in Redis and 'json_state' entries are
    JSON-encoded. Every state is reported to the State service as a JSON
    string, under a device ID whose form depends on the entry's
    'state_scope' (see make_scoped_device_id).
    """
    def __init__(self, service: MagmaService,
                 grpc_client_manager: GRPCClientManager):
        """
        Args:
            service: owning MagmaService. Its `config` lists the state types
                to replicate and its `loop` runs this periodic task.
            grpc_client_manager: provides State service gRPC clients.
        """
        super().__init__(DEFAULT_SYNC_INTERVAL, service.loop)
        self._service = service
        # In-memory mapping of "<deviceID>:<type>" -> last reported version
        self._state_versions = {}
        # Serdes for each type of state to replicate, keyed by redis type
        self._serdes = {}
        self._get_proto_redis_serdes()
        self._get_json_redis_serdes()
        self._redis_client = RedisFlatDict(get_default_client(), self._serdes)

        # _grpc_client_manager to manage grpc client recyclings
        self._grpc_client_manager = grpc_client_manager

        # Set once a resync with the Orchestrator succeeds. Until then every
        # _run attempts a resync first; if that resync RPC fails, the
        # replication round is skipped.
        self._has_resync_completed = False

    def _get_proto_redis_serdes(self):
        """
        Register a protobuf StateSerde for each valid 'state_protos' entry.

        Entries missing a required key are logged and skipped. Entries whose
        module or message class cannot be loaded are logged as errors and
        skipped.
        """
        required_keys = ('proto_msg', 'proto_file', 'redis_key',
                         'state_scope')
        state_protos = self._service.config.get('state_protos', []) or []
        for proto_cfg in state_protos:
            if any(key not in proto_cfg for key in required_keys):
                logging.warning(
                    "Invalid proto config found in state_protos "
                    "configuration: %s", proto_cfg)
                continue
            try:
                proto_module = importlib.import_module(proto_cfg['proto_file'])
                msg = getattr(proto_module, proto_cfg['proto_msg'])
            except (ImportError, AttributeError) as err:
                logging.error(err)
                continue
            redis_key = proto_cfg['redis_key']
            logging.info('Initializing RedisSerde for proto state %s',
                         redis_key)
            self._serdes[redis_key] = StateSerde(
                redis_key, get_proto_serializer(),
                get_proto_deserializer(msg),
                proto_cfg['state_scope'], PROTO_FORMAT)

    def _get_json_redis_serdes(self):
        """
        Register a JSON StateSerde for each valid 'json_state' entry.
        Entries missing 'redis_key' or 'state_scope' are logged and skipped.
        """
        required_keys = ('redis_key', 'state_scope')
        json_state = self._service.config.get('json_state', []) or []
        for json_cfg in json_state:
            if any(key not in json_cfg for key in required_keys):
                logging.warning(
                    "Invalid json state config found in json_state "
                    "configuration: %s", json_cfg)
                continue
            redis_key = json_cfg['redis_key']
            logging.info('Initializing RedisSerde for json state %s',
                         redis_key)
            self._serdes[redis_key] = StateSerde(
                redis_key, get_json_serializer(), get_json_deserializer(),
                json_cfg['state_scope'], JSON_FORMAT)

    async def _run(self):
        """
        One replication round: resync if not done yet (skip the round on an
        RPC failure), then report any changed states.
        """
        if not self._has_resync_completed:
            try:
                await self._resync()
            except grpc.RpcError as err:
                logging.error("GRPC call failed for initial state re-sync: %s",
                              err)
                return
        request = await self._collect_states_to_replicate()
        if request is not None:
            await self._send_to_state_service(request)

    async def _resync(self):
        """
        Send the (id, version) of every local state to the State service via
        SyncStates, and record as already-reported those states the service
        does not list as unsynced.

        Returns without marking the resync complete if there is no local
        state. grpc.RpcError from the SyncStates call propagates.
        """
        states_to_sync = []
        for key in self._redis_client:
            try:
                idval, state_type = self._parse_key(key)
            except ValueError as err:
                logging.debug(err)
                continue

            state_scope = self._serdes[state_type].state_scope
            version = self._redis_client.get_version(idval, state_type)
            device_id = self.make_scoped_device_id(idval, state_scope)
            state_id = StateID(type=state_type, deviceID=device_id)
            id_and_version = IDAndVersion(id=state_id, version=version)
            states_to_sync.append(id_and_version)

        if not states_to_sync:
            logging.debug("Not re-syncing state. No local state found.")
            return
        state_client = self._grpc_client_manager.get_client()
        request = SyncStatesRequest(states=states_to_sync)
        response = await grpc_async_wrapper(
            state_client.SyncStates.future(
                request,
                DEFAULT_GRPC_TIMEOUT,
            ), self._loop)
        unsynced_states = {
            (id_and_version.id.type, id_and_version.id.deviceID)
            for id_and_version in response.unsyncedStates
        }
        # Update in-memory map to add already synced states
        for state in request.states:
            if (state.id.type, state.id.deviceID) in unsynced_states:
                continue
            in_mem_key = self.make_mem_key(state.id.deviceID, state.id.type)
            self._state_versions[in_mem_key] = state.version

        self._has_resync_completed = True
        logging.info("Successfully resynced state with Orchestrator!")

    async def _collect_states_to_replicate(self):
        """
        Build a ReportStatesRequest for every state whose Redis version
        differs from the last reported version.

        Returns:
            ReportStatesRequest, or None if no state has changed.
        """
        states_to_report = []
        for key in self._redis_client:
            try:
                idval, state_type = self._parse_key(key)
            except ValueError as err:
                logging.debug(err)
                continue

            state_scope = self._serdes[state_type].state_scope
            device_id = self.make_scoped_device_id(idval, state_scope)
            in_mem_key = self.make_mem_key(device_id, state_type)
            redis_version = self._redis_client.get_version(idval, state_type)

            if in_mem_key in self._state_versions and \
                    self._state_versions[in_mem_key] == redis_version:
                continue

            redis_state = self._redis_client.get(key)
            # Proto states are converted to a dict first; JSON states are
            # already JSON-serializable as deserialized.
            if self._serdes[state_type].state_format == PROTO_FORMAT:
                state_to_serialize = MessageToDict(redis_state)
            else:
                state_to_serialize = redis_state
            serialized_json_state = json.dumps(state_to_serialize)
            state_proto = State(type=state_type,
                                deviceID=device_id,
                                value=serialized_json_state.encode("utf-8"),
                                version=redis_version)

            states_to_report.append(state_proto)

        if not states_to_report:
            logging.debug("Not replicating state. No state has changed!")
            return None
        return ReportStatesRequest(states=states_to_report)

    async def _send_to_state_service(self, request: ReportStatesRequest):
        """
        Report `request` to the State service. Record the version of every
        state the service accepted. Log (and do not record) states it
        reports as failed. The watchdog timeout is reset in every case.
        """
        state_client = self._grpc_client_manager.get_client()
        try:
            response = await grpc_async_wrapper(
                state_client.ReportStates.future(
                    request,
                    DEFAULT_GRPC_TIMEOUT,
                ), self._loop)

        except grpc.RpcError as err:
            logging.error("GRPC call failed for state replication: %s", err)
        else:
            unreplicated_states = set()
            for idAndError in response.unreportedStates:
                logging.warning("Failed to replicate state for (%s,%s): %s",
                                idAndError.type, idAndError.deviceID,
                                idAndError.error)
                unreplicated_states.add((idAndError.type, idAndError.deviceID))
            # Update in-memory map for successfully reported states
            for state in request.states:
                if (state.type, state.deviceID) in unreplicated_states:
                    continue
                in_mem_key = self.make_mem_key(state.deviceID, state.type)
                self._state_versions[in_mem_key] = state.version

                logging.debug(
                    "Successfully replicated state for: "
                    "deviceID: %s, "
                    "type: %s, "
                    "version: %d", state.deviceID, state.type, state.version)
        finally:
            # reset timeout to config-specified + some buffer
            self.set_timeout(self._interval * 2)

    def _parse_key(self, key):
        """
        Split a Redis key of the form <id><REDIS_KEY_DELIMITER><type>.

        Returns:
            (idval, state_type) tuple.

        Raises:
            ValueError: if the key has no delimiter or no serde is
                registered for its type.
        """
        split_key = key.split(REDIS_KEY_DELIMITER, 1)
        if len(split_key) != 2:
            raise ValueError("Redis key: %s is not of format <id>:<type>. "
                             "Not replicating." % key)
        idval, state_type = split_key
        if state_type not in self._serdes:
            raise ValueError("No serde found for state type: %s. "
                             "Not replicating key: %s" % (state_type, key))

        return idval, state_type

    @staticmethod
    def make_mem_key(device_id, state_type):
        """
        Create a key of the format <id>:<type>
        """
        return device_id + ":" + state_type

    @staticmethod
    def make_scoped_device_id(id, scope):
        """
        Create a deviceID of the format <id> for scope 'network'
        Otherwise create a key of the format <hwid>:<id> for 'gateway' or
        unrecognized scope.
        """
        # NOTE: `id` shadows the builtin; kept for keyword-call compatibility.
        if scope == "network":
            return id
        return snowflake.snowflake() + ":" + id
示例#14
0
    def setUp(self):
        """
        Build a StateReplicator wired to a shared fakeredis instance and a
        local DummyStateServer, then start it.
        """
        # Every Redis-backed object created below shares this one fake.
        self.mock_redis = fakeredis.FakeStrictRedis()

        # Each test gets its own event loop, installed as the current loop.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        # Replicate arbitrary orc8r protos: (module, message, key, scope)
        proto_specs = (
            ('orc8r.protos.common_pb2', 'NetworkID', NID_TYPE, 'network'),
            ('orc8r.protos.common_pb2', 'IDList', IDList_TYPE, 'gateway'),
            ('orc8r.protos.service303_pb2', 'LogVerbosity', LOG_TYPE,
             'gateway'),
        )
        fake_service = MagicMock()
        fake_service.config = {
            'state_protos': [
                {
                    'proto_file': proto_file,
                    'proto_msg': proto_msg,
                    'redis_key': redis_key,
                    'state_scope': scope,
                }
                for proto_file, proto_msg, redis_key, scope in proto_specs
            ],
            'json_state': [{'redis_key': FOO_TYPE, 'state_scope': 'network'}],
        }
        fake_service.loop = self.loop

        # Serve a DummyStateServer on an OS-assigned free port and open a
        # client channel to it.
        executor = futures.ThreadPoolExecutor(max_workers=10)
        self._rpc_server = grpc.server(executor)
        bound_port = self._rpc_server.add_insecure_port('0.0.0.0:0')
        self._servicer = DummyStateServer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()
        self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(bound_port))

        nid_serde = RedisSerde(NID_TYPE, get_proto_serializer(),
                               get_proto_deserializer(NetworkID))
        idlist_serde = RedisSerde(IDList_TYPE, get_proto_serializer(),
                                  get_proto_deserializer(IDList))
        log_serde = RedisSerde(LOG_TYPE, get_proto_serializer(),
                               get_proto_deserializer(LogVerbosity))
        foo_serde = RedisSerde(FOO_TYPE, get_json_serializer(),
                               get_json_deserializer())

        self.nid_client = RedisFlatDict(self.mock_redis, nid_serde)
        self.idlist_client = RedisFlatDict(self.mock_redis, idlist_serde)
        self.log_client = RedisFlatDict(self.mock_redis, log_serde)
        self.foo_client = RedisFlatDict(self.mock_redis, foo_serde)

        client_manager = GRPCClientManager(
            service_name="state",
            service_stub=StateServiceStub,
            max_client_reuse=60,
        )

        # While constructing, patch get_default_client in
        # magma.state.redis_dicts so it hands back the shared fake.
        patched_client = mock.patch(
            'magma.state.redis_dicts.get_default_client',
            mock.MagicMock(return_value=self.mock_redis),
        )
        with patched_client:
            collector = GarbageCollector(fake_service, client_manager)
            self.state_replicator = StateReplicator(
                service=fake_service,
                garbage_collector=collector,
                grpc_client_manager=client_manager,
            )
        self.state_replicator.start()
示例#15
0
class GatewayDirectoryServiceRpcServicer(GatewayDirectoryServiceServicer):
    """gRPC based server for the Directoryd Gateway service

    Directory records are stored in Redis through a JSON serde, keyed by
    record ID. Handlers report failures by setting a gRPC status code and
    details on `context`. Handlers that have a response message still return
    an (empty) message in that case, so the response is always serializable.
    """
    def __init__(self, print_grpc_payload: bool = False):
        """Initialize Directoryd grpc endpoints.

        Args:
            print_grpc_payload: if True, request and response messages are
                logged at INFO level by _print_grpc.
        """
        serde = RedisSerde(DIRECTORYD_REDIS_TYPE, get_json_serializer(),
                           get_json_deserializer())
        self._redis_dict = RedisFlatDict(get_default_client(), serde)
        self._print_grpc_payload = print_grpc_payload

        if self._print_grpc_payload:
            logging.info("Printing GRPC messages")

    def add_to_server(self, server):
        """ Add the servicer to a gRPC server """
        add_GatewayDirectoryServiceServicer_to_server(self, server)

    @return_void
    def UpdateRecord(self, request, context):
        """ Update the directory record of an object

        Prepends this gateway's hwid to the record's location history (if it
        is not already the most recent entry) and merges the request fields
        into the record's identifiers.

        Args:
            request (UpdateRecordRequest): update record request
        """
        logging.debug("UpdateRecord request received")
        self._print_grpc(request)
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("ID argument cannot be empty in "
                                "UpdateRecordRequest")
            return

        try:
            # Lock Redis for requested key until update is complete
            with self._redis_dict.lock(request.id):
                hwid = get_gateway_hwid()
                record = self._redis_dict.get(request.id) or \
                         DirectoryRecord(location_history=[hwid],
                                         identifiers={})

                # Guard against an empty stored history so indexing [0]
                # cannot raise IndexError.
                if not record.location_history or \
                        record.location_history[0] != hwid:
                    record.location_history = [hwid] + record.location_history

                for field_key, field_value in request.fields.items():
                    record.identifiers[field_key] = field_value

                # Keep only the LOCATION_MAX_LEN most recent hwids
                record.location_history = \
                    record.location_history[:LOCATION_MAX_LEN]
                self._redis_dict[request.id] = record
        except (RedisError, LockError) as e:
            self._set_redis_unavailable(context, e)

    @return_void
    def DeleteRecord(self, request, context):
        """ Delete the directory record for an ID

        The record is marked as garbage rather than deleted directly.

        Args:
             request (DeleteRecordRequest): delete record request
        """
        logging.debug("DeleteRecord request received")
        self._print_grpc(request)
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("ID argument cannot be empty in "
                                "DeleteRecordRequest")
            return

        # Lock Redis for requested key until delete is complete
        try:
            with self._redis_dict.lock(request.id):
                if request.id not in self._redis_dict:
                    context.set_code(grpc.StatusCode.NOT_FOUND)
                    context.set_details("Record for ID %s was not found." %
                                        request.id)
                    return
                self._redis_dict.mark_as_garbage(request.id)
        except (RedisError, LockError) as e:
            self._set_redis_unavailable(context, e)

    def GetDirectoryField(self, request, context):
        """ Get the directory record field for an ID and key

        Args:
             request (GetDirectoryFieldRequest): get directory field request
             context: gRPC servicer context used to report errors

        Returns:
            DirectoryField: the requested field, or an empty DirectoryField
            if the request is invalid, the record/field is missing, or Redis
            is unavailable (the status code set on `context` says which).
        """
        logging.debug("GetDirectoryField request received")
        self._print_grpc(request)
        if len(request.id) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("ID argument cannot be empty in "
                                "GetDirectoryFieldRequest")
            return self._respond(DirectoryField())
        if len(request.field_key) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("Field key argument cannot be empty in "
                                "GetDirectoryFieldRequest")
            return self._respond(DirectoryField())

        # Lock Redis for requested key until get is complete
        try:
            with self._redis_dict.lock(request.id):
                record_exists = request.id in self._redis_dict
                record = self._redis_dict[request.id] if record_exists \
                    else None
        except (RedisError, LockError) as e:
            self._set_redis_unavailable(context, e)
            return self._respond(DirectoryField())

        if record is None:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details("Record for ID %s was not found." %
                                request.id)
            return self._respond(DirectoryField())

        if request.field_key not in record.identifiers:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details("Field %s was not found in record for "
                                "ID %s" % (request.field_key, request.id))
            return self._respond(DirectoryField())

        return self._respond(
            DirectoryField(key=request.field_key,
                           value=record.identifiers[request.field_key]))

    def GetAllDirectoryRecords(self, request, context):
        """ Get all directory records

        On a Redis/lock failure, UNAVAILABLE is set on `context` and the
        records collected so far are returned.

        Args:
             request (Void): void
        """
        logging.debug("GetAllDirectoryRecords request received")
        self._print_grpc(request)
        response = AllDirectoryRecords()
        try:
            redis_keys = self._redis_dict.keys()
        except RedisError as e:
            self._set_redis_unavailable(context, e)
            return self._respond(response)

        for key in redis_keys:
            try:
                with self._redis_dict.lock(key):
                    # Lookup may produce an exception if the key has been
                    # deleted between the call to __iter__ and lock
                    stored_record = self._redis_dict[key]
            except (RedisError, LockError) as e:
                self._set_redis_unavailable(context, e)
                return self._respond(response)
            except KeyError:
                continue

            directory_record = response.records.add()
            directory_record.id = key
            directory_record.location_history[:] = \
                stored_record.location_history
            for identifier_key, identifier_value in \
                    stored_record.identifiers.items():
                directory_record.fields[identifier_key] = identifier_value

        return self._respond(response)

    @staticmethod
    def _set_redis_unavailable(context, err):
        """Log a Redis/lock error and report it as UNAVAILABLE on context."""
        logging.error(err)
        context.set_code(grpc.StatusCode.UNAVAILABLE)
        context.set_details("Could not connect to redis: %s" % err)

    def _respond(self, response):
        """Log `response` via _print_grpc (if enabled) and return it."""
        self._print_grpc(response)
        return response

    def _print_grpc(self, message):
        """Log `message` as indented JSON when payload printing is on."""
        if self._print_grpc_payload:
            log_msg = "{} {}".format(message.DESCRIPTOR.full_name,
                                     MessageToJson(message))
            # add indentation
            padding = 2 * ' '
            log_msg = ''.join("{}{}".format(padding, line)
                              for line in log_msg.splitlines(True))

            log_msg = "GRPC message:\n{}".format(log_msg)
            logging.info(log_msg)
示例#16
0
class StateReplicatorTests(TestCase):
    """Tests for StateReplicator backed by MockRedis and DummyStateServer."""

    @mock.patch("redis.Redis", MockRedis)
    def setUp(self):

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        service = MagicMock()
        service.config = {
            # Replicate arbitrary orc8r protos
            'state_protos': [{
                'proto_file': 'orc8r.protos.common_pb2',
                'proto_msg': 'NetworkID',
                'redis_key': NID_TYPE,
                'state_scope': 'network'
            }, {
                'proto_file': 'orc8r.protos.common_pb2',
                'proto_msg': 'IDList',
                'redis_key': IDList_TYPE,
                'state_scope': 'gateway'
            }, {
                'proto_file': 'orc8r.protos.service303_pb2',
                'proto_msg': 'LogVerbosity',
                'redis_key': LOG_TYPE,
                'state_scope': 'gateway'
            }],
            'json_state': [{
                'redis_key': FOO_TYPE,
                'state_scope': 'network'
            }]
        }
        service.loop = self.loop

        # Bind the rpc server to a free port
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10))
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')
        # Add the servicer
        self._servicer = DummyStateServer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()
        # Create a rpc stub
        self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))

        serde1 = RedisSerde(NID_TYPE, get_proto_serializer(),
                            get_proto_deserializer(NetworkID))
        serde2 = RedisSerde(IDList_TYPE, get_proto_serializer(),
                            get_proto_deserializer(IDList))
        serde3 = RedisSerde(LOG_TYPE, get_proto_serializer(),
                            get_proto_deserializer(LogVerbosity))
        serde4 = RedisSerde(FOO_TYPE, get_json_serializer(),
                            get_json_deserializer())

        self.nid_client = RedisFlatDict(get_default_client(), serde1)
        self.idlist_client = RedisFlatDict(get_default_client(), serde2)
        self.log_client = RedisFlatDict(get_default_client(), serde3)
        self.foo_client = RedisFlatDict(get_default_client(), serde4)

        # Set up and start state replicating loop
        grpc_client_manager = GRPCClientManager(
            service_name="state",
            service_stub=StateServiceStub,
            max_client_reuse=60,
        )

        garbage_collector = GarbageCollector(service, grpc_client_manager)

        self.state_replicator = StateReplicator(
            service=service,
            garbage_collector=garbage_collector,
            grpc_client_manager=grpc_client_manager,
        )
        self.state_replicator.start()

    @mock.patch("redis.Redis", MockRedis)
    def tearDown(self):
        self._rpc_server.stop(None)
        self.state_replicator.stop()
        self.loop.close()

    def _clear_all_clients(self):
        """Remove all stored state so each test starts from empty Redis."""
        for client in (self.nid_client, self.idlist_client,
                       self.log_client, self.foo_client):
            client.clear()

    def convert_msg_to_state(self, redis_state, is_proto=True):
        """Encode `redis_state` the way the replicator serializes it."""
        if is_proto:
            json_converted_state = MessageToDict(redis_state)
            serialized_json_state = json.dumps(json_converted_state)
        else:
            serialized_json_state = jsonpickle.encode(redis_state)
        return serialized_json_state.encode("utf-8")

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_collect_states_to_replicate(self):
        async def test():
            # Ensure setup is initialized properly
            self._clear_all_clients()

            key = 'id1'

            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            self.foo_client[key] = Foo("boo", 3)

            exp1 = self.convert_msg_to_state(self.nid_client[key])
            exp2 = self.convert_msg_to_state(self.idlist_client[key])
            exp3 = self.convert_msg_to_state(self.foo_client[key], False)

            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(3, len(req.states))
            for state in req.states:
                if state.type == NID_TYPE:
                    self.assertEqual('id1', state.deviceID)
                    self.assertEqual(1, state.version)
                    self.assertEqual(exp1, state.value)
                elif state.type == IDList_TYPE:
                    self.assertEqual('aaa-bbb:id1', state.deviceID)
                    self.assertEqual(1, state.version)
                    self.assertEqual(exp2, state.value)
                elif state.type == FOO_TYPE:
                    self.assertEqual('id1', state.deviceID)
                    self.assertEqual(1, state.version)
                    self.assertEqual(exp3, state.value)
                else:
                    self.fail("Unknown state type %s" % state.type)

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_replicate_states_success(self, get_rpc_mock):
        async def test():
            get_rpc_mock.return_value = self.channel

            # Ensure setup is initialized properly
            self._clear_all_clients()

            key = 'id1'
            foo = Foo("boo", 4)
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            self.foo_client[key] = foo
            # Increment version
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])

            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(3, len(req.states))

            # Ensure in-memory map updates properly
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(3, len(self.state_replicator._state_versions))
            mem_key1 = make_mem_key('id1', NID_TYPE)
            mem_key2 = make_mem_key('aaa-bbb:id1', IDList_TYPE)
            mem_key3 = make_mem_key('id1', FOO_TYPE)
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key1])
            self.assertEqual(2,
                             self.state_replicator._state_versions[mem_key2])
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key3])

            # Now add new state and update some existing state
            key2 = 'id2'
            self.nid_client[key2] = NetworkID(id='bar')
            self.idlist_client[key] = IDList(ids=['bar', 'foo'])
            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(2, len(req.states))

            # Ensure in-memory map updates properly
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(4, len(self.state_replicator._state_versions))
            mem_key4 = make_mem_key('id2', NID_TYPE)
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key1])
            self.assertEqual(3,
                             self.state_replicator._state_versions[mem_key2])
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key3])
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key4])

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_unreplicated_states(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel

            # Add initial state to be replicated
            self._clear_all_clients()

            key = 'id1'
            key2 = 'id2'
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            # Increment version
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            # Set state that will be 'unreplicated'
            self.log_client[key2] = LogVerbosity(verbosity=5)

            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(3, len(req.states))

            # Ensure in-memory map updates properly for successful replications
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(2, len(self.state_replicator._state_versions))
            mem_key1 = make_mem_key('id1', NID_TYPE)
            mem_key2 = make_mem_key('aaa-bbb:id1', IDList_TYPE)
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key1])
            self.assertEqual(2,
                             self.state_replicator._state_versions[mem_key2])

            # Now run again, ensuring only the state the wasn't replicated
            # will be sent again
            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(1, len(req.states))
            self.assertEqual('aaa-bbb:id2', req.states[0].deviceID)
            self.assertEqual(LOG_TYPE, req.states[0].type)

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_resync_success(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            self._clear_all_clients()

            key = 'id1'
            # Set state that will be 'unsynced'
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            # Increment state's version
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])

            await self.state_replicator._resync()
            self.assertTrue(self.state_replicator._has_resync_completed)
            self.assertEqual(1, len(self.state_replicator._state_versions))
            mem_key = make_mem_key('aaa-bbb:id1', IDList_TYPE)
            self.assertEqual(2, self.state_replicator._state_versions[mem_key])

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_resync_failure(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            self._clear_all_clients()

            # Set state that will trigger the RpcError
            log_key = 'id1'
            self.log_client[log_key] = LogVerbosity(verbosity=5)

            try:
                await self.state_replicator._resync()
            except grpc.RpcError:
                pass

            self.assertFalse(self.state_replicator._has_resync_completed)
            self.assertEqual(0, len(self.state_replicator._state_versions))

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_deleted_replicated_state(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            self._clear_all_clients()

            key = 'id1'
            self.nid_client[key] = NetworkID(id='foo')
            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(1, len(req.states))

            # Ensure in-memory map updates properly
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(1, len(self.state_replicator._state_versions))
            mem_key1 = make_mem_key('id1', NID_TYPE)
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key1])

            # Now delete state and ensure in-memory map gets updated properly
            del self.nid_client[key]
            req = await self.state_replicator._collect_states_to_replicate()
            self.assertIsNone(req)

            await self.state_replicator._cleanup_deleted_keys()
            # _state_versions is keyed by mem key (<deviceID>:<type>), not by
            # the raw id, so checking `key` would pass vacuously.
            self.assertNotIn(mem_key1, self.state_replicator._state_versions)

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())
示例#17
0
class RedisDictTests(TestCase):
    """
    Exercise the RedisHashDict and RedisFlatDict containers against a client
    obtained while redis.Redis is patched with MockRedis.

    Both containers round-trip LogVerbosity protos. The flat dict addresses
    entries with "<key>:<serde type>" strings.
    """

    @mock.patch("redis.Redis", MockRedis)
    def setUp(self):
        redis_client = get_default_client()
        # LogVerbosity is simply a convenient orc8r proto to store.
        self._hash_dict = RedisHashDict(
            redis_client,
            "unittest",
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )

        log_verbosity_serde = RedisSerde(
            'log_verbosity',
            get_proto_serializer(),
            get_proto_deserializer(LogVerbosity),
        )
        self._flat_dict = RedisFlatDict(
            redis_client, {'log_verbosity': log_verbosity_serde})

    @mock.patch("redis.Redis", MockRedis)
    def test_hash_insert(self):
        protos = (LogVerbosity(verbosity=0), LogVerbosity(verbosity=1))
        # The first write stores version 1; each overwrite bumps it by one.
        for want_version, proto in enumerate(protos, start=1):
            self._hash_dict['key1'] = proto
            got_version = self._hash_dict.get_version("key1")
            got_proto = self._hash_dict['key1']
            self.assertEqual(want_version, got_version)
            self.assertEqual(proto, got_proto)

    @mock.patch("redis.Redis", MockRedis)
    def test_missing_version(self):
        # A key that was never written reports version 0.
        self.assertEqual(0, self._hash_dict.get_version("key2"))

    @mock.patch("redis.Redis", MockRedis)
    def test_hash_delete(self):
        proto = LogVerbosity(verbosity=2)
        self._hash_dict['key3'] = proto
        self.assertEqual(proto, self._hash_dict['key3'])

        self._hash_dict.pop('key3')
        with self.assertRaises(KeyError):
            self._hash_dict['key3']

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_insert(self):
        protos = (LogVerbosity(verbosity=5), LogVerbosity(verbosity=1))
        # The first write stores version 1; each overwrite bumps it by one.
        for want_version, proto in enumerate(protos, start=1):
            self._flat_dict['key1:log_verbosity'] = proto
            got_version = self._flat_dict.get_version("key1", "log_verbosity")
            got_proto = self._flat_dict['key1:log_verbosity']
            self.assertEqual(want_version, got_version)
            self.assertEqual(proto, got_proto)

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_missing_version(self):
        # A key that was never written reports version 0.
        self.assertEqual(
            0, self._flat_dict.get_version("key2", "log_verbosity"))

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_invalid_key(self):
        value = LogVerbosity(verbosity=5)
        # A key without the ":<serde type>" suffix is rejected.
        with self.assertRaises(ValueError):
            self._flat_dict['key3'] = value

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_invalid_serde(self):
        value = LogVerbosity(verbosity=5)
        # A serde type that was not registered in setUp is rejected.
        with self.assertRaises(ValueError):
            self._flat_dict['key3:missing_serde'] = value

    @mock.patch("redis.Redis", MockRedis)
    def test_flat_delete(self):
        proto = LogVerbosity(verbosity=2)
        self._flat_dict['key3:log_verbosity'] = proto
        self.assertEqual(proto, self._flat_dict['key3:log_verbosity'])

        self._flat_dict.pop('key3:log_verbosity')
        with self.assertRaises(KeyError):
            self._flat_dict['key3:log_verbosity']