Exemplo n.º 1
0
async def test_func_registry():
    """handle_function() must register the handler under the given name."""
    peer = Peer(
        connect=ZeroMQAddress('tcp://127.0.0.1:5000'),
        serializer=json.dumps,
        deserializer=json.loads,
        transport=ZeroMQRPCTransport,
    )

    def dummy():
        pass

    # The registry starts without the name and gains it upon registration.
    assert 'dummy' not in peer._func_registry
    peer.handle_function('dummy', dummy)
    assert 'dummy' in peer._func_registry
    assert peer._func_registry['dummy'] is dummy
Exemplo n.º 2
0
async def serve() -> None:
    """Run the Thrift-over-ZeroMQ RPC server until SIGINT/SIGTERM arrives."""
    rpc_peer = Peer(
        bind=ZeroMQAddress('tcp://127.0.0.1:5030'),
        serializer=noop_serializer,
        deserializer=noop_deserializer,
        transport=ZeroMQRPCTransport,
    )
    adaptor = ThriftServerAdaptor(
        rpc_peer, simple_thrift.SimpleService, SimpleDispatcher())
    rpc_peer.handle_function('simple', adaptor.handle_function)

    loop = asyncio.get_running_loop()
    forever = loop.create_future()
    # Either termination signal cancels the "forever" future to stop serving.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, forever.cancel)
    async with rpc_peer:
        try:
            print('server started')
            await forever
        except asyncio.CancelledError:
            pass
    print('server terminated')
Exemplo n.º 3
0
async def serve() -> None:
    """Serve the echo/add RPC functions over a Redis stream until a signal."""
    rpc_peer = Peer(
        bind=RedisStreamAddress('redis://localhost:6379', 'myservice',
                                'server-group', 'client1'),
        transport=RPCRedisTransport,
        serializer=lambda o: json.dumps(o).encode('utf8'),
        deserializer=lambda b: json.loads(b),
    )
    for name, handler in (('echo', handle_echo), ('add', handle_add)):
        rpc_peer.handle_function(name, handler)

    loop = asyncio.get_running_loop()
    forever = loop.create_future()
    # Either termination signal cancels the "forever" future to stop serving.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, forever.cancel)
    async with rpc_peer:
        try:
            print('server started')
            await forever
        except asyncio.CancelledError:
            pass
    print('server terminated')
Exemplo n.º 4
0
async def serve():
    """Serve echo/add/print_delim over ZeroMQ with an exit-ordered scheduler."""
    rpc_peer = Peer(
        bind=ZeroMQAddress('tcp://127.0.0.1:5010'),
        transport=ZeroMQRPCTransport,
        scheduler=ExitOrderedAsyncScheduler(),
        serializer=json.dumps,
        deserializer=json.loads,
    )
    for name, handler in (('echo', handle_echo),
                          ('add', handle_add),
                          ('print_delim', handle_delimeter)):
        rpc_peer.handle_function(name, handler)

    print('echo() will take 1 second and add() will take 0.5 second.')
    print('You can confirm the effect of scheduler '
          'and the ordering key by the console logs.\n')

    loop = asyncio.get_running_loop()
    forever = loop.create_future()
    # Either termination signal cancels the "forever" future to stop serving.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, forever.cancel)
    async with rpc_peer:
        try:
            print('server started')
            await forever
        except asyncio.CancelledError:
            pass
    print('server terminated')
Exemplo n.º 5
0
class AgentRPCServer(aobject):
    """RPC server exposing volume-agent operations over ZeroMQ.

    Methods decorated with ``@rpc_function`` are collected into the
    class-level registry and bound to the Peer during ``init()``.
    """

    # Class-level registry populated by the @rpc_function decorator below.
    rpc_function: ClassVar[RPCFunctionRegistry] = RPCFunctionRegistry()

    rpc_server: Peer
    rpc_addr: str
    agent: AbstractVolumeAgent

    def __init__(self, etcd, config):
        self.config = config
        self.etcd = etcd

        # The backend agent is created lazily in init(); None until then.
        self.agent = None

    async def init(self):
        """Create the storage backend agent and start serving RPC requests.

        Raises:
            NotImplementedError: for the (planned) btrfs storage mode.
            RuntimeError: for any unrecognized storage mode.
        """
        await self.update_status('starting')

        mode = self.config['storage']['mode']
        if mode == 'xfs':
            from .xfs.agent import VolumeAgent
            self.agent = VolumeAgent(self.config['storage']['path'],
                                     self.config['agent']['user-uid'],
                                     self.config['agent']['user-gid'])
        elif mode == 'btrfs':
            # TODO: Implement Btrfs Agent
            raise NotImplementedError('the btrfs storage mode is not implemented yet')
        else:
            # Fail fast with a clear message instead of crashing below
            # with "'NoneType' object has no attribute 'init'".
            raise RuntimeError(f'unsupported storage mode: {mode!r}')
        await self.agent.init()

        rpc_addr = self.config['agent']['rpc-listen-addr']
        self.rpc_server = Peer(
            bind=ZeroMQAddress(f"tcp://{rpc_addr}"),
            transport=ZeroMQRPCTransport,
            scheduler=KeySerializedAsyncScheduler(),
            serializer=msgpack.packb,
            deserializer=msgpack.unpackb,
            debug_rpc=self.config['debug']['enabled'],
        )
        # Expose every registered RPC function as a bound-method handler.
        for func_name in self.rpc_function.functions:
            self.rpc_server.handle_function(func_name,
                                            getattr(self, func_name))
        log.info('started handling RPC requests at {}', rpc_addr)

        # NOTE(review): rpc_addr is annotated as str on the class but `.host`
        # is accessed here — presumably a host/port pair object; confirm.
        await self.etcd.put('ip', rpc_addr.host, scope=ConfigScopes.NODE)
        await self.update_status('running')

    async def shutdown(self):
        """Stop the RPC server if it was ever started."""
        # getattr() keeps shutdown() safe even when init() never ran
        # (self.rpc_server would otherwise be an unset attribute).
        if getattr(self, 'rpc_server', None) is not None:
            self.rpc_server.close()
            await self.rpc_server.wait_closed()

    async def update_status(self, status):
        """Publish this node's status string to etcd."""
        await self.etcd.put('', status, scope=ConfigScopes.NODE)

    @aiotools.actxmgr
    async def handle_rpc_exception(self):
        """Log any exception raised inside an RPC handler, then re-raise."""
        try:
            yield
        except AssertionError:
            log.exception('assertion failure')
            raise
        except Exception:
            log.exception('unexpected error')
            raise

    @rpc_function
    async def hello(self, agent_id: str) -> str:
        """Liveness probe; returns a fixed greeting."""
        log.debug('rpc::hello({0})', agent_id)
        return 'OLLEH'

    @rpc_function
    async def create(self, kernel_id: str, size: str) -> str:
        """Create a volume of the given size for the given kernel."""
        log.debug('rpc::create({0}, {1})', kernel_id, size)
        async with self.handle_rpc_exception():
            return await self.agent.create(kernel_id, size)

    @rpc_function
    async def remove(self, kernel_id: str):
        """Remove the volume belonging to the given kernel."""
        log.debug('rpc::remove({0})', kernel_id)
        async with self.handle_rpc_exception():
            return await self.agent.remove(kernel_id)

    @rpc_function
    async def get(self, kernel_id: str) -> str:
        """Return volume information for the given kernel."""
        log.debug('rpc::get({0})', kernel_id)
        async with self.handle_rpc_exception():
            return await self.agent.get(kernel_id)
Exemplo n.º 6
0
class AgentRPCServer(aobject):
    """Backend.AI compute-agent RPC server.

    Exposes kernel/session lifecycle and service operations of an
    ``AbstractAgent`` over a ZeroMQ RPC peer.  Methods decorated with
    ``@rpc_function`` are collected into the class-level registry and
    bound to the peer during ``__ainit__()``.
    """

    # Class-level registry populated by the @rpc_function decorator below.
    rpc_function: ClassVar[RPCFunctionRegistry] = RPCFunctionRegistry()

    loop: asyncio.AbstractEventLoop
    agent: AbstractAgent
    rpc_server: Peer
    rpc_addr: str
    agent_addr: str

    # Signal recorded via mark_stop_signal(); forwarded to agent.shutdown().
    _stop_signal: signal.Signals

    def __init__(
        self,
        etcd: AsyncEtcd,
        local_config: Mapping[str, Any],
        *,
        skip_detect_manager: bool = False,
    ) -> None:
        """Store configuration; real initialization happens in __ainit__()."""
        self.loop = current_loop()
        self.etcd = etcd
        self.local_config = local_config
        self.skip_detect_manager = skip_detect_manager
        self._stop_signal = signal.SIGTERM

    async def __ainit__(self) -> None:
        """Asynchronous constructor: wait for a manager, load configs,
        instantiate the backend agent, and start serving RPC requests."""
        # Start serving requests.
        await self.update_status('starting')

        if not self.skip_detect_manager:
            await self.detect_manager()

        await self.read_agent_config()
        await self.read_agent_config_container()

        self.stats_monitor = StatsPluginContext(self.etcd, self.local_config)
        self.error_monitor = ErrorPluginContext(self.etcd, self.local_config)
        await self.stats_monitor.init()
        await self.error_monitor.init()

        # Dynamically load the configured backend implementation module.
        backend = self.local_config['agent']['backend']
        agent_mod = importlib.import_module(f"ai.backend.agent.{backend.value}")
        self.agent = await agent_mod.get_agent_cls().new(  # type: ignore
            self.etcd,
            self.local_config,
            stats_monitor=self.stats_monitor,
            error_monitor=self.error_monitor,
        )

        rpc_addr = self.local_config['agent']['rpc-listen-addr']
        self.rpc_server = Peer(
            bind=ZeroMQAddress(f"tcp://{rpc_addr}"),
            transport=ZeroMQRPCTransport,
            scheduler=ExitOrderedAsyncScheduler(),
            serializer=msgpack.packb,
            deserializer=msgpack.unpackb,
            debug_rpc=self.local_config['debug']['enabled'],
        )
        # Expose every registered RPC function as a bound-method handler.
        for func_name in self.rpc_function.functions:
            self.rpc_server.handle_function(func_name, getattr(self, func_name))
        log.info('started handling RPC requests at {}', rpc_addr)

        # NOTE(review): rpc_addr is annotated as str on the class but `.host`
        # is accessed here — presumably a host/port pair object; confirm.
        await self.etcd.put('ip', rpc_addr.host, scope=ConfigScopes.NODE)
        watcher_port = utils.nmget(self.local_config, 'watcher.service-addr.port', None)
        if watcher_port is not None:
            await self.etcd.put('watcher_port', watcher_port, scope=ConfigScopes.NODE)

        await self.update_status('running')

    async def detect_manager(self):
        """Block until at least one manager node is registered 'up' in etcd."""
        log.info('detecting the manager...')
        manager_instances = await self.etcd.get_prefix('nodes/manager')
        if not manager_instances:
            log.warning('watching etcd to wait for the manager being available')
            async with aclosing(self.etcd.watch_prefix('nodes/manager')) as agen:
                async for ev in agen:
                    if ev.event == 'put' and ev.value == 'up':
                        break
        log.info('detected at least one manager running')

    async def read_agent_config(self):
        """Merge redis, vfolder, and shared agent config from etcd
        into local_config, validating each section with its trafaret."""
        # Fill up Redis configs from etcd.
        self.local_config['redis'] = config.redis_config_iv.check(
            await self.etcd.get_prefix('config/redis')
        )
        log.info('configured redis_addr: {0}', self.local_config['redis']['addr'])

        # Fill up vfolder configs from etcd.
        self.local_config['vfolder'] = config.vfolder_config_iv.check(
            await self.etcd.get_prefix('volumes')
        )
        if self.local_config['vfolder']['mount'] is None:
            log.info('assuming use of storage-proxy since vfolder mount path is not configured in etcd')
        else:
            log.info('configured vfolder mount base: {0}', self.local_config['vfolder']['mount'])
            log.info('configured vfolder fs prefix: {0}', self.local_config['vfolder']['fsprefix'])

        # Fill up shared agent configurations from etcd.
        agent_etcd_config = agent_etcd_config_iv.check(
            await self.etcd.get_prefix('config/agent')
        )
        for k, v in agent_etcd_config.items():
            self.local_config['agent'][k] = v

    async def read_agent_config_container(self):
        """Merge global container config from etcd into local_config,
        falling back to an empty config when etcd data is invalid."""
        # Fill up global container configurations from etcd.
        try:
            container_etcd_config = container_etcd_config_iv.check(
                await self.etcd.get_prefix('config/container')
            )
        except TrafaretDataError as etrafa:
            # Invalid etcd data is logged and ignored (best-effort merge).
            log.warning("etcd: container-config error: {}".format(etrafa))
            container_etcd_config = {}
        for k, v in container_etcd_config.items():
            self.local_config['container'][k] = v
            log.info("etcd: container-config: {}={}".format(k, v))

    async def __aenter__(self) -> None:
        """Start accepting RPC requests."""
        await self.rpc_server.__aenter__()

    def mark_stop_signal(self, stop_signal: signal.Signals) -> None:
        """Record which signal triggered shutdown, for agent.shutdown()."""
        self._stop_signal = stop_signal

    async def __aexit__(self, *exc_info) -> None:
        """Stop the RPC server, shut down the agent, and clean up monitors."""
        # Stop receiving further requests.
        await self.rpc_server.__aexit__(*exc_info)
        await self.agent.shutdown(self._stop_signal)
        await self.stats_monitor.cleanup()
        await self.error_monitor.cleanup()

    @collect_error
    async def update_status(self, status):
        """Publish this node's status string to etcd."""
        await self.etcd.put('', status, scope=ConfigScopes.NODE)

    @rpc_function
    @collect_error
    async def ping(self, msg: str) -> str:
        """Liveness probe; echoes the given message back."""
        log.debug('rpc::ping()')
        return msg

    @rpc_function
    @collect_error
    async def gather_hwinfo(self) -> Mapping[str, HardwareMetadata]:
        """Collect hardware metadata from the backend agent."""
        log.debug('rpc::gather_hwinfo()')
        return await self.agent.gather_hwinfo()

    @rpc_function
    @collect_error
    async def ping_kernel(self, kernel_id: str):
        """Ping a kernel.  Currently only logs the call (no agent action)."""
        log.debug('rpc::ping_kernel({0})', kernel_id)

    @rpc_function
    @collect_error
    async def create_kernels(
        self,
        creation_id: str,
        raw_session_id: str,
        raw_kernel_ids: Sequence[str],
        raw_configs: Sequence[dict],
        raw_cluster_info: dict,
    ):
        """Create multiple kernels concurrently within one session.

        Returns a list of per-kernel result dicts; raises the first error
        if any kernel creation fails.
        """
        cluster_info = cast(ClusterInfo, raw_cluster_info)
        session_id = SessionId(UUID(raw_session_id))
        # NOTE(review): this initial value is dead — raw_results is rebuilt
        # from the gathered results below.
        raw_results = []
        coros = []
        for raw_kernel_id, raw_config in zip(raw_kernel_ids, raw_configs):
            log.info('rpc::create_kernel(k:{0}, img:{1})',
                     raw_kernel_id, raw_config['image']['canonical'])
            kernel_id = KernelId(UUID(raw_kernel_id))
            kernel_config = cast(KernelCreationConfig, raw_config)
            coros.append(self.agent.create_kernel(
                creation_id,
                session_id,
                kernel_id,
                kernel_config,
                cluster_info,
            ))
        # return_exceptions=True lets all creations finish before we decide.
        results = await asyncio.gather(*coros, return_exceptions=True)
        errors = [*filter(lambda item: isinstance(item, Exception), results)]
        if errors:
            # Raise up the first error.
            raise errors[0]
        raw_results = [
            {
                'id': str(result['id']),
                'kernel_host': result['kernel_host'],
                'repl_in_port': result['repl_in_port'],
                'repl_out_port': result['repl_out_port'],
                'stdin_port': result['stdin_port'],    # legacy
                'stdout_port': result['stdout_port'],  # legacy
                'service_ports': result['service_ports'],
                'container_id': result['container_id'],
                'resource_spec': result['resource_spec'],
                'attached_devices': result['attached_devices'],
            }
            for result in results
        ]
        return raw_results

    @rpc_function
    @collect_error
    async def destroy_kernel(
        self,
        kernel_id: str,
        # NOTE(review): default None implies the annotation should be
        # Optional[str].
        reason: str = None,
        suppress_events: bool = False,
    ):
        """Request kernel destruction and wait until it completes."""
        log.info('rpc::destroy_kernel(k:{0})', kernel_id)
        done = asyncio.Event()
        await self.agent.inject_container_lifecycle_event(
            KernelId(UUID(kernel_id)),
            LifecycleEvent.DESTROY,
            reason or 'user-requested',
            done_event=done,
            suppress_events=suppress_events,
        )
        await done.wait()
        # NOTE(review): the result is presumably stashed on the Event object
        # by the lifecycle processor — confirm against the agent code.
        return getattr(done, '_result', None)

    @rpc_function
    @collect_error
    async def interrupt_kernel(self, kernel_id: str):
        """Interrupt the currently running code in a kernel."""
        log.info('rpc::interrupt_kernel(k:{0})', kernel_id)
        await self.agent.interrupt_kernel(KernelId(UUID(kernel_id)))

    @rpc_function
    @collect_error
    async def get_completions(self, kernel_id: str,
                              text: str, opts: dict):
        """Forward a code-completion request to the kernel."""
        log.debug('rpc::get_completions(k:{0}, ...)', kernel_id)
        await self.agent.get_completions(KernelId(UUID(kernel_id)), text, opts)

    @rpc_function
    @collect_error
    async def get_logs(self, kernel_id: str):
        """Fetch the container logs of a kernel."""
        log.info('rpc::get_logs(k:{0})', kernel_id)
        return await self.agent.get_logs(KernelId(UUID(kernel_id)))

    @rpc_function
    @collect_error
    async def restart_kernel(
        self,
        creation_id: str,
        session_id: str,
        kernel_id: str,
        updated_config: dict,
    ):
        """Restart a kernel with an updated creation config."""
        log.info('rpc::restart_kernel(s:{0}, k:{1})', session_id, kernel_id)
        return await self.agent.restart_kernel(
            creation_id,
            SessionId(UUID(session_id)),
            KernelId(UUID(kernel_id)),
            cast(KernelCreationConfig, updated_config),
        )

    @rpc_function
    @collect_error
    async def execute(
        self,
        kernel_id,          # type: str
        api_version,        # type: int
        run_id,             # type: str
        mode,               # type: Literal['query', 'batch', 'continue', 'input']
        code,               # type: str
        opts,               # type: Dict[str, Any]
        flush_timeout,      # type: float
    ):
        # type: (...) -> Dict[str, Any]
        """Execute code in a kernel and return the execution result.

        'continue' requests are not logged to avoid flooding the log
        during long-polling execution.
        """
        if mode != 'continue':
            log.info('rpc::execute(k:{0}, run-id:{1}, mode:{2}, code:{3!r})',
                     kernel_id, run_id, mode,
                     code[:20] + '...' if len(code) > 20 else code)
        result = await self.agent.execute(
            KernelId(UUID(kernel_id)),
            run_id,
            mode,
            code,
            opts=opts,
            api_version=api_version,
            flush_timeout=flush_timeout
        )
        return result

    @rpc_function
    @collect_error
    async def execute_batch(
        self,
        kernel_id,          # type: str
        startup_command,    # type: str
    ) -> None:
        """Fire-and-forget execution of a batch startup command."""
        # TODO: use task-group to keep track of completion/cancellation
        asyncio.create_task(self.agent.execute_batch(
            KernelId(UUID(kernel_id)),
            startup_command,
        ))
        # Yield control so the task gets scheduled before we return.
        await asyncio.sleep(0)

    @rpc_function
    @collect_error
    async def start_service(
        self,
        kernel_id,   # type: str
        service,     # type: str
        opts         # type: Dict[str, Any]
    ):
        # type: (...) -> Dict[str, Any]
        """Start an in-kernel service application."""
        log.info('rpc::start_service(k:{0}, app:{1})', kernel_id, service)
        return await self.agent.start_service(KernelId(UUID(kernel_id)), service, opts)

    @rpc_function
    @collect_error
    async def shutdown_service(
        self,
        kernel_id,  # type: str
        service,    # type: str
    ):
        """Shut down an in-kernel service application."""
        log.info('rpc::shutdown_service(k:{0}, app:{1})', kernel_id, service)
        return await self.agent.shutdown_service(KernelId(UUID(kernel_id)), service)

    @rpc_function
    @collect_error
    async def upload_file(self, kernel_id: str, filename: str, filedata: bytes):
        """Upload a file into a kernel's working directory."""
        log.info('rpc::upload_file(k:{0}, fn:{1})', kernel_id, filename)
        await self.agent.accept_file(KernelId(UUID(kernel_id)), filename, filedata)

    @rpc_function
    @collect_error
    async def download_file(self, kernel_id: str, filepath: str):
        """Download a file from a kernel's working directory."""
        log.info('rpc::download_file(k:{0}, fn:{1})', kernel_id, filepath)
        return await self.agent.download_file(KernelId(UUID(kernel_id)), filepath)

    @rpc_function
    @collect_error
    async def list_files(self, kernel_id: str, path: str):
        """List files under the given path inside a kernel."""
        log.info('rpc::list_files(k:{0}, fn:{1})', kernel_id, path)
        return await self.agent.list_files(KernelId(UUID(kernel_id)), path)

    @rpc_function
    @collect_error
    async def shutdown_agent(self, terminate_kernels: bool):
        """Shut down this agent.  Currently a stub that only logs."""
        # TODO: implement
        log.info('rpc::shutdown_agent()')
        pass

    @rpc_function
    @collect_error
    async def create_overlay_network(self, network_name: str) -> None:
        """Create a cross-node overlay network for a cluster session."""
        log.debug('rpc::create_overlay_network(name:{})', network_name)
        return await self.agent.create_overlay_network(network_name)

    @rpc_function
    @collect_error
    async def destroy_overlay_network(self, network_name: str) -> None:
        """Destroy a cross-node overlay network."""
        log.debug('rpc::destroy_overlay_network(name:{})', network_name)
        return await self.agent.destroy_overlay_network(network_name)

    @rpc_function
    @collect_error
    async def create_local_network(self, network_name: str) -> None:
        """Create a single-node local network for a cluster session."""
        log.debug('rpc::create_local_network(name:{})', network_name)
        return await self.agent.create_local_network(network_name)

    @rpc_function
    @collect_error
    async def destroy_local_network(self, network_name: str) -> None:
        """Destroy a single-node local network."""
        log.debug('rpc::destroy_local_network(name:{})', network_name)
        return await self.agent.destroy_local_network(network_name)

    @rpc_function
    @collect_error
    async def reset_agent(self):
        """Destroy all kernels currently registered on this agent."""
        log.debug('rpc::reset()')
        kernel_ids = tuple(self.agent.kernel_registry.keys())
        tasks = []
        for kernel_id in kernel_ids:
            try:
                task = asyncio.ensure_future(
                    self.agent.destroy_kernel(kernel_id, 'agent-reset'))
                tasks.append(task)
            except Exception:
                # Keep destroying the remaining kernels even if one fails.
                await self.error_monitor.capture_exception()
                log.exception('reset: destroying {0}', kernel_id)
        await asyncio.gather(*tasks)
Exemplo n.º 7
0
async def serve(scheduler_type: str) -> None:
    """Serve demo RPC functions with the chosen scheduler until a signal.

    tracemalloc is enabled for the server's lifetime so the memstat
    handler can compare against the snapshot taken at startup.
    """
    global last_snapshot, scheduler

    scheduler = scheduler_types[scheduler_type]()
    rpc_peer = Peer(
        bind=ZeroMQAddress('tcp://127.0.0.1:5020'),
        transport=ZeroMQRPCTransport,
        scheduler=scheduler,
        serializer=lambda o: json.dumps(o).encode('utf8'),
        deserializer=lambda b: json.loads(b),
    )
    for name, handler in (
        ('echo', handle_echo),
        ('add', handle_add),
        ('long_delay', handle_long_delay),
        ('error', handle_error),
        ('set_output', handle_output),
        ('memstat', handle_show_memory_stat),
    ):
        rpc_peer.handle_function(name, handler)

    loop = asyncio.get_running_loop()
    forever = loop.create_future()
    # Either termination signal cancels the "forever" future to stop serving.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, forever.cancel)

    try:
        tracemalloc.start(10)
        async with rpc_peer:
            last_snapshot = tracemalloc.take_snapshot()
            try:
                print('server started')
                await forever
            except asyncio.CancelledError:
                pass
        print('server terminated')
    finally:
        tracemalloc.stop()