Beispiel #1
0
def main():
    """
    main() for gateway state replication service
    """
    service = MagmaService('state', mconfigs_pb2.State())

    # Optionally pipe errors to Sentry
    sentry_init()

    # _grpc_client_manager to manage grpc client recycling
    grpc_client_manager = GRPCClientManager(
        service_name="state",
        service_stub=StateServiceStub,
        max_client_reuse=60,
    )

    # Garbage collector propagates state deletions back to Orchestrator
    garbage_collector = GarbageCollector(service, grpc_client_manager)

    # Start state replication loop
    state_manager = StateReplicator(service, garbage_collector,
                                    grpc_client_manager)
    state_manager.start()

    # Run the service loop
    service.run()

    # Cleanup the service
    service.close()
    def setUp(self):

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        service = MagicMock()
        service.config = {
            # Replicate arbitrary orc8r protos
            'state_protos': [{
                'proto_file': 'orc8r.protos.common_pb2',
                'proto_msg': 'NetworkID',
                'redis_key': NID_TYPE,
                'state_scope': 'network'
            }, {
                'proto_file': 'orc8r.protos.common_pb2',
                'proto_msg': 'IDList',
                'redis_key': IDList_TYPE,
                'state_scope': 'gateway'
            }, {
                'proto_file': 'orc8r.protos.service303_pb2',
                'proto_msg': 'LogVerbosity',
                'redis_key': LOG_TYPE,
                'state_scope': 'gateway'
            }],
            'json_state': [{
                'redis_key': FOO_TYPE,
                'state_scope': 'network'
            }]
        }
        service.loop = self.loop

        # Bind the rpc server to a free port
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10))
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')
        # Add the servicer
        self._servicer = DummyStateServer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()
        # Create a rpc stub
        self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))

        serde1 = RedisSerde(NID_TYPE, get_proto_serializer(),
                            get_proto_deserializer(NetworkID))
        serde2 = RedisSerde(IDList_TYPE, get_proto_serializer(),
                            get_proto_deserializer(IDList))
        serde3 = RedisSerde(LOG_TYPE, get_proto_serializer(),
                            get_proto_deserializer(LogVerbosity))
        serde4 = RedisSerde(FOO_TYPE, get_json_serializer(),
                            get_json_deserializer())

        self.nid_client = RedisFlatDict(get_default_client(), serde1)
        self.idlist_client = RedisFlatDict(get_default_client(), serde2)
        self.log_client = RedisFlatDict(get_default_client(), serde3)
        self.foo_client = RedisFlatDict(get_default_client(), serde4)

        # Set up and start state replicating loop
        grpc_client_manager = GRPCClientManager(
            service_name="state",
            service_stub=StateServiceStub,
            max_client_reuse=60,
        )

        garbage_collector = GarbageCollector(service, grpc_client_manager)

        self.state_replicator = StateReplicator(
            service=service,
            garbage_collector=garbage_collector,
            grpc_client_manager=grpc_client_manager,
        )
        self.state_replicator.start()
class StateReplicatorTests(TestCase):
    @mock.patch("redis.Redis", MockRedis)
    def setUp(self):

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        service = MagicMock()
        service.config = {
            # Replicate arbitrary orc8r protos
            'state_protos': [{
                'proto_file': 'orc8r.protos.common_pb2',
                'proto_msg': 'NetworkID',
                'redis_key': NID_TYPE,
                'state_scope': 'network'
            }, {
                'proto_file': 'orc8r.protos.common_pb2',
                'proto_msg': 'IDList',
                'redis_key': IDList_TYPE,
                'state_scope': 'gateway'
            }, {
                'proto_file': 'orc8r.protos.service303_pb2',
                'proto_msg': 'LogVerbosity',
                'redis_key': LOG_TYPE,
                'state_scope': 'gateway'
            }],
            'json_state': [{
                'redis_key': FOO_TYPE,
                'state_scope': 'network'
            }]
        }
        service.loop = self.loop

        # Bind the rpc server to a free port
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10))
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')
        # Add the servicer
        self._servicer = DummyStateServer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()
        # Create a rpc stub
        self.channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))

        serde1 = RedisSerde(NID_TYPE, get_proto_serializer(),
                            get_proto_deserializer(NetworkID))
        serde2 = RedisSerde(IDList_TYPE, get_proto_serializer(),
                            get_proto_deserializer(IDList))
        serde3 = RedisSerde(LOG_TYPE, get_proto_serializer(),
                            get_proto_deserializer(LogVerbosity))
        serde4 = RedisSerde(FOO_TYPE, get_json_serializer(),
                            get_json_deserializer())

        self.nid_client = RedisFlatDict(get_default_client(), serde1)
        self.idlist_client = RedisFlatDict(get_default_client(), serde2)
        self.log_client = RedisFlatDict(get_default_client(), serde3)
        self.foo_client = RedisFlatDict(get_default_client(), serde4)

        # Set up and start state replicating loop
        grpc_client_manager = GRPCClientManager(
            service_name="state",
            service_stub=StateServiceStub,
            max_client_reuse=60,
        )

        garbage_collector = GarbageCollector(service, grpc_client_manager)

        self.state_replicator = StateReplicator(
            service=service,
            garbage_collector=garbage_collector,
            grpc_client_manager=grpc_client_manager,
        )
        self.state_replicator.start()

    @mock.patch("redis.Redis", MockRedis)
    def tearDown(self):
        self._rpc_server.stop(None)
        self.state_replicator.stop()
        self.loop.close()

    def convert_msg_to_state(self, redis_state, is_proto=True):
        if is_proto:
            json_converted_state = MessageToDict(redis_state)
            serialized_json_state = json.dumps(json_converted_state)
        else:
            serialized_json_state = jsonpickle.encode(redis_state)
        return serialized_json_state.encode("utf-8")

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_collect_states_to_replicate(self):
        async def test():
            # Ensure setup is initialized properly
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'

            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            self.foo_client[key] = Foo("boo", 3)

            exp1 = self.convert_msg_to_state(self.nid_client[key])
            exp2 = self.convert_msg_to_state(self.idlist_client[key])
            exp3 = self.convert_msg_to_state(self.foo_client[key], False)

            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(3, len(req.states))
            for state in req.states:
                if state.type == NID_TYPE:
                    self.assertEqual('id1', state.deviceID)
                    self.assertEqual(1, state.version)
                    self.assertEqual(exp1, state.value)
                elif state.type == IDList_TYPE:
                    self.assertEqual('aaa-bbb:id1', state.deviceID)
                    self.assertEqual(1, state.version)
                    self.assertEqual(exp2, state.value)
                elif state.type == FOO_TYPE:
                    self.assertEqual('id1', state.deviceID)
                    self.assertEqual(1, state.version)
                    self.assertEqual(exp3, state.value)
                else:
                    self.fail("Unknown state type %s" % state.type)

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_replicate_states_success(self, get_rpc_mock):
        async def test():
            get_rpc_mock.return_value = self.channel

            # Ensure setup is initialized properly
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'
            foo = Foo("boo", 4)
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            self.foo_client[key] = foo
            # Increment version
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])

            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(3, len(req.states))

            # Ensure in-memory map updates properly
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(3, len(self.state_replicator._state_versions))
            mem_key1 = make_mem_key('id1', NID_TYPE)
            mem_key2 = make_mem_key('aaa-bbb:id1', IDList_TYPE)
            mem_key3 = make_mem_key('id1', FOO_TYPE)
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key1])
            self.assertEqual(2,
                             self.state_replicator._state_versions[mem_key2])
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key3])

            # Now add new state and update some existing state
            key2 = 'id2'
            self.nid_client[key2] = NetworkID(id='bar')
            self.idlist_client[key] = IDList(ids=['bar', 'foo'])
            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(2, len(req.states))

            # Ensure in-memory map updates properly
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(4, len(self.state_replicator._state_versions))
            mem_key4 = make_mem_key('id2', NID_TYPE)
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key1])
            self.assertEqual(3,
                             self.state_replicator._state_versions[mem_key2])
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key3])
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key4])

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_unreplicated_states(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel

            # Add initial state to be replicated
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'
            key2 = 'id2'
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            # Increment version
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            # Set state that will be 'unreplicated'
            self.log_client[key2] = LogVerbosity(verbosity=5)

            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(3, len(req.states))

            # Ensure in-memory map updates properly for successful replications
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(2, len(self.state_replicator._state_versions))
            mem_key1 = make_mem_key('id1', NID_TYPE)
            mem_key2 = make_mem_key('aaa-bbb:id1', IDList_TYPE)
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key1])
            self.assertEqual(2,
                             self.state_replicator._state_versions[mem_key2])

            # Now run again, ensuring only the state the wasn't replicated
            # will be sent again
            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(1, len(req.states))
            self.assertEqual('aaa-bbb:id2', req.states[0].deviceID)
            self.assertEqual(LOG_TYPE, req.states[0].type)

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_resync_success(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'
            # Set state that will be 'unsynced'
            self.nid_client[key] = NetworkID(id='foo')
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])
            # Increment state's version
            self.idlist_client[key] = IDList(ids=['bar', 'blah'])

            await self.state_replicator._resync()
            self.assertEqual(True, self.state_replicator._has_resync_completed)
            self.assertEqual(1, len(self.state_replicator._state_versions))
            mem_key = make_mem_key('aaa-bbb:id1', IDList_TYPE)
            self.assertEqual(2, self.state_replicator._state_versions[mem_key])

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_resync_failure(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            # Set state that will trigger the RpcError
            log_key = 'id1'
            self.log_client[log_key] = LogVerbosity(verbosity=5)

            try:
                await self.state_replicator._resync()
            except grpc.RpcError:
                pass

            self.assertEqual(False,
                             self.state_replicator._has_resync_completed)
            self.assertEqual(0, len(self.state_replicator._state_versions))

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    @mock.patch('magma.magmad.state_reporter.ServiceRegistry.get_rpc_channel')
    def test_deleted_replicated_state(self, get_grpc_mock):
        async def test():
            get_grpc_mock.return_value = self.channel
            self.nid_client.clear()
            self.idlist_client.clear()
            self.log_client.clear()
            self.foo_client.clear()

            key = 'id1'
            self.nid_client[key] = NetworkID(id='foo')
            req = await self.state_replicator._collect_states_to_replicate()
            self.assertEqual(1, len(req.states))

            # Ensure in-memory map updates properly
            await self.state_replicator._send_to_state_service(req)
            self.assertEqual(1, len(self.state_replicator._state_versions))
            mem_key1 = make_mem_key('id1', NID_TYPE)
            self.assertEqual(1,
                             self.state_replicator._state_versions[mem_key1])

            # Now delete state and ensure in-memory map gets updated properly
            del self.nid_client[key]
            req = await self.state_replicator._collect_states_to_replicate()
            self.assertIsNone(req)

            await self.state_replicator._cleanup_deleted_keys()
            self.assertFalse(key in self.state_replicator._state_versions)

        # Cancel the replicator's loop so there are no other activities
        self.state_replicator._periodic_task.cancel()
        self.loop.run_until_complete(test())