Example #1
async def run_queue(context):
    # Prepare our context and sockets
    frontend = context.socket(zmq.ROUTER)
    backend = context.socket(zmq.ROUTER)
    frontend.bind("tcp://*:5559")
    backend.bind("tcp://*:5560")
    # Initialize poll set
    poller = Poller()
    poller.register(frontend, zmq.POLLIN)
    poller.register(backend, zmq.POLLIN)
    # Switch messages between sockets
    while True:
        socks = await poller.poll()
        socks = dict(socks)

        if socks.get(frontend) == zmq.POLLIN:
            frames = await frontend.recv_multipart()
            print('received from frontend: {}'.format(frames))
            # Add the worker ident to the envelope - simplified for example.
            frames.insert(0, b'Worker1')
            await backend.send_multipart(frames)

        if socks.get(backend) == zmq.POLLIN:
            frames = await backend.recv_multipart()
            print('received from backend: {}'.format(frames))
            msg = frames[1:]  # Slice off worker ident
            await frontend.send_multipart(msg)
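
A minimal driver for this coroutine (a sketch; it assumes the imports these snippets rely on):

import asyncio
import zmq
from zmq.asyncio import Context, Poller

if __name__ == '__main__':
    # run the queue on a fresh asyncio event loop
    asyncio.run(run_queue(Context.instance()))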
Example #2
def run_worker(context):
    # Socket to receive messages on
    receiver = context.socket(zmq.PULL)
    receiver.connect("tcp://localhost:5557")
    # Socket to send messages to
    sender = context.socket(zmq.PUSH)
    sender.connect("tcp://localhost:5558")
    # Socket for control input
    controller = context.socket(zmq.SUB)
    controller.connect("tcp://localhost:5559")
    controller.setsockopt(zmq.SUBSCRIBE, b"")
    # Process messages from receiver and controller
    poller = Poller()
    poller.register(receiver, zmq.POLLIN)
    poller.register(controller, zmq.POLLIN)
    # Process messages from both sockets
    while True:
        socks = yield from poller.poll()
        socks = dict(socks)
        if socks.get(receiver) == zmq.POLLIN:
            message = yield from receiver.recv()
            # Process task
            workload = int(message)  # Workload in msecs
            # Do the work
            yield from asyncio.sleep(workload / 1000.0)
            # Send results to sink
            yield from sender.send(message)
            # Simple progress indicator for the viewer
            sys.stdout.write(".")
            sys.stdout.flush()
        # Any waiting controller command acts as 'KILL'
        if socks.get(controller) == zmq.POLLIN:
            break
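
The `controller` SUB socket above listens for a kill broadcast; the matching sender is a PUB socket on the other end (a sketch, bound to the port the worker connects to):

import zmq
from zmq.asyncio import Context

async def send_kill():
    # PUB socket broadcasting the kill signal to every connected worker
    controller = Context.instance().socket(zmq.PUB)
    controller.bind("tcp://*:5559")
    await controller.send(b"KILL")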
Example #3
class AsyncioAuthenticator(Authenticator):
    """ZAP authentication for use in the asyncio IO loop"""

    def __init__(self, context=None, loop=None):
        super().__init__(context)
        self.loop = loop or asyncio.get_event_loop()
        self.__poller = None
        self.__task = None

    @asyncio.coroutine
    def __handle_zap(self):
        while True:
            events = yield from self.__poller.poll()
            if self.zap_socket in dict(events):
                msg = yield from self.zap_socket.recv_multipart()
                self.handle_zap_message(msg)

    def start(self):
        """Start ZAP authentication"""
        super().start()
        self.__poller = Poller()
        self.__poller.register(self.zap_socket, zmq.POLLIN)
        self.__task = asyncio.ensure_future(self.__handle_zap())

    def stop(self):
        """Stop ZAP authentication"""
        if self.__task:
            self.__task.cancel()
        if self.__poller:
            self.__poller.unregister(self.zap_socket)
            self.__poller = None
        super().stop()
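
Typical usage (a sketch; `allow` and `configure_plain` are inherited from `zmq.auth.Authenticator`):

from zmq.asyncio import Context

auth = AsyncioAuthenticator(Context.instance())
auth.start()
auth.allow('127.0.0.1')                                     # whitelist localhost
auth.configure_plain(domain='*', passwords={'user': 'pw'})  # PLAIN credentials
# ... create server sockets with plain_server = True here ...
auth.stop()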
Example #4
def run():
    subscriber = Ctx.socket(zmq.SUB)
    subscriber.connect(Url)
    subscription = b"%03d" % 5
    subscriber.setsockopt(zmq.SUBSCRIBE, subscription)
    poller = Poller()
    poller.register(subscriber, zmq.POLLIN)  # SUB sockets are read from, so poll for input
    while True:
        topic, data = yield from subscriber.recv_multipart()
        # assert topic == subscription
        print(data)
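
The publisher this subscriber pairs with (a sketch, assuming the same `Ctx` and `Url` globals):

async def publish():
    publisher = Ctx.socket(zmq.PUB)
    publisher.bind(Url)
    while True:
        # topic frame first; only b"005" subscribers receive this
        await publisher.send_multipart([b"%03d" % 5, b"payload"])
        await asyncio.sleep(1)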
Example #5
async def receiver():
    """receive messages with polling"""
    pull = ctx.socket(zmq.PULL)
    pull.connect(url)
    poller = Poller()
    poller.register(pull, zmq.POLLIN)
    while True:
        events = await poller.poll()
        if pull in dict(events):
            print("recving", events)
            msg = await pull.recv_multipart()
            print('recvd', msg)
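
A matching sender for the same endpoint (a sketch, assuming the module's `ctx` and `url` globals):

async def sender():
    push = ctx.socket(zmq.PUSH)
    push.bind(url)
    for i in range(10):
        # two-frame message; the receiver prints the frame list
        await push.send_multipart([b"message", b"%d" % i])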
Example #6
def run_client(context):
    print("I: Connecting to server...")
    client = context.socket(zmq.REQ)
    client.connect(SERVER_ENDPOINT)
    poll = Poller()
    poll.register(client, zmq.POLLIN)
    sequence = 0
    retries_left = REQUEST_RETRIES
    while retries_left:
        sequence += 1
        request = str(sequence)
        print("I: Sending (%s)" % request)
        yield from client.send_string(request)
        expect_reply = True
        while expect_reply:
            socks = yield from poll.poll(REQUEST_TIMEOUT)
            socks = dict(socks)
            if socks.get(client) == zmq.POLLIN:
                reply = yield from client.recv()
                if not reply:
                    break
                if int(reply) == sequence:
                    print("I: Server replied OK (%s)" % reply)
                    retries_left = REQUEST_RETRIES
                    expect_reply = False
                else:
                    print("E: Malformed reply from server: %s" % reply)
            else:
                print("W: No response from server, retrying...")
                # Socket is confused. Close and remove it.
                client.setsockopt(zmq.LINGER, 0)
                client.close()
                poll.unregister(client)
                retries_left -= 1
                if retries_left == 0:
                    print("E: Server seems to be offline, abandoning")
                    return
                print("I: Reconnecting and resending (%s)" % request)
                # Create new connection
                client = context.socket(zmq.REQ)
                client.connect(SERVER_ENDPOINT)
                poll.register(client, zmq.POLLIN)
                yield from client.send_string(request)
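
The module-level constants this Lazy Pirate client assumes (a sketch; values follow the classic pattern):

REQUEST_TIMEOUT = 2500                      # msecs
REQUEST_RETRIES = 3                         # attempts before giving up
SERVER_ENDPOINT = "tcp://localhost:5555"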
Example #7
def run(ident):
    #  Socket to talk to server
    print("Connecting to hello world server.  Ctrl-C to exit early.\n")
    socket = Ctx.socket(zmq.REQ)
    socket.connect(Url)
    poller = Poller()
    poller.register(socket, zmq.POLLOUT)
    #  Do multiple requests, waiting each time for a response
    for request in range(10):
        message = '{} Hello {}'.format(ident, request)
        message = message.encode('utf-8')
        print("Sending message: {}".format(message))
        yield from socket.send(message)
        #  Get the reply.
        message = yield from socket.recv()
        print("Received reply: {}".format(message))
    print('exiting')
    return 'nothing'
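
Running a couple of these clients side by side (a sketch; generator-based coroutines can be gathered like regular ones on the Python versions this legacy style targets):

loop = asyncio.get_event_loop()
loop.run_until_complete(asyncio.gather(run('client-A'), run('client-B')))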
Example #8
def run():
    print("Getting ready for hello world client.  Ctrl-C to exit.\n")
    socket = Ctx.socket(zmq.REP)
    socket.bind(Url)
    poller = Poller()
    poller.register(socket, zmq.POLLIN)
    while True:
        #  Wait for next request from client
        message = yield from socket.recv()
        print("Received request: {}".format(message))
        #  Do some 'work'
        yield from asyncio.sleep(1)
        #  Send reply back to client
        message = message.decode('utf-8')
        message = '{}, world'.format(message)
        message = message.encode('utf-8')
        print("Sending reply: {}".format(message))
        yield from socket.send(message)
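
The globals shared by this server and the client in the previous example (a sketch):

import asyncio
import zmq
from zmq.asyncio import Context, Poller

Url = 'tcp://127.0.0.1:5555'
Ctx = Context.instance()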
Example #9
def run_proxy(socket_from, socket_to):
    poller = Poller()
    poller.register(socket_from, zmq.POLLIN)
    poller.register(socket_to, zmq.POLLIN)
    printdbg('(run_proxy) started')
    while True:
        events = yield from poller.poll()
        events = dict(events)
        if socket_from in events:
            msg = yield from socket_from.recv_multipart()
            printdbg('(run_proxy) received from frontend -- msg: {}'.format(
                msg))
            yield from socket_to.send_multipart(msg)
            printdbg('(run_proxy) sent to backend -- msg: {}'.format(msg))
        elif socket_to in events:
            msg = yield from socket_to.recv_multipart()
            printdbg('(run_proxy) received from backend -- msg: {}'.format(
                msg))
            yield from socket_from.send_multipart(msg)
            printdbg('(run_proxy) sent to frontend -- msg: {}'.format(msg))
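
Wiring the proxy between a client-facing and a worker-facing socket (a sketch, assuming a `zmq.asyncio.Context`):

context = Context()
frontend = context.socket(zmq.ROUTER)
frontend.bind("tcp://*:5559")
backend = context.socket(zmq.DEALER)
backend.bind("tcp://*:5560")
asyncio.ensure_future(run_proxy(frontend, backend))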
Example #10
def run_server(loop):
    """Server routine"""
    # Prepare our context and sockets
    # Socket to talk to clients
    clients = Ctx.socket(zmq.ROUTER)
    clients.bind(Url_client)
    workers = Ctx.socket(zmq.DEALER)
    workers.bind(Url_worker)
    # Start the workers
    tasks = []
    for idx in range(5):
        ident = 'worker {}'.format(idx)
        task = asyncio.ensure_future(run_worker(ident))
        tasks.append(task)
    poller = Poller()
    poller.register(clients, zmq.POLLIN)
    poller.register(workers, zmq.POLLIN)
    print('mtserver ready for requests')
    while True:
        events = yield from poller.poll()
        events = dict(events)
        if clients in events:
            message = yield from clients.recv_multipart()
            printdbg('(run) received from client message_parts: {}'.format(
                message))
            client, empty, message = message[:3]
            printdbg('(run) received from client message: {}'.format(
                message))
            printdbg('(run) sending message to workers: {}'.format(message))
            yield from workers.send_multipart([client, b'', message])
        elif workers in events:
            message = yield from workers.recv_multipart()
            printdbg('(run) received from worker message_parts: {}'.format(
                message))
            client, empty, message = message[:3]
            printdbg('(run) received from worker message: {}'.format(
                message))
            yield from clients.send_multipart([client, b'', message])
            printdbg('(run) sent message to clients: {}'.format(message))
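
A worker coroutine that fits this server (a sketch; `Url_worker`, `Ctx`, and `printdbg` as in the snippet):

async def run_worker(ident):
    # REP worker serving requests relayed by the DEALER backend
    socket = Ctx.socket(zmq.REP)
    socket.connect(Url_worker)
    while True:
        message = await socket.recv()
        printdbg('({}) received: {}'.format(ident, message))
        await asyncio.sleep(1)      # simulate work
        await socket.send(message)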
Example #11
def run_broker(context):
    # Prepare our context and sockets
    frontend = context.socket(zmq.ROUTER)
    backend = context.socket(zmq.DEALER)
    frontend.bind("tcp://*:5559")
    backend.bind("tcp://*:5560")
    # Initialize poll set
    poller = Poller()
    poller.register(frontend, zmq.POLLIN)
    poller.register(backend, zmq.POLLIN)
    # Switch messages between sockets
    while True:
        socks = yield from poller.poll()
        socks = dict(socks)
        if socks.get(frontend) == zmq.POLLIN:
            message = yield from frontend.recv_multipart()
            print('received from frontend: {}'.format(message))
            yield from backend.send_multipart(message)
        if socks.get(backend) == zmq.POLLIN:
            message = yield from backend.recv_multipart()
            print('received from backend: {}'.format(message))
            yield from frontend.send_multipart(message)
Example #12
def run_queue():
    context = Context(1)

    frontend = context.socket(zmq.ROUTER)    # ROUTER
    backend = context.socket(zmq.ROUTER)     # ROUTER
    frontend.bind("tcp://*:5555")            # For clients
    backend.bind("tcp://*:5556")             # For workers

    poll_workers = Poller()
    poll_workers.register(backend, zmq.POLLIN)

    poll_both = Poller()
    poll_both.register(frontend, zmq.POLLIN)
    poll_both.register(backend, zmq.POLLIN)

    workers = []

    while True:
        if workers:
            socks = yield from poll_both.poll()
        else:
            socks = yield from poll_workers.poll()
        socks = dict(socks)

        # Handle worker activity on backend
        if socks.get(backend) == zmq.POLLIN:
            # Use worker address for LRU routing
            msg = yield from backend.recv_multipart()
            if not msg:
                break
            print('I: received msg: {}'.format(msg))
            address = msg[0]
            workers.append(address)

            # Everything after the second (delimiter) frame is reply
            reply = msg[2:]

            # Forward message to client if it's not a READY
            if reply[0] != LRU_READY:
                print('I: sending -- reply: {}'.format(reply))
                yield from frontend.send_multipart(reply)
            else:
                print('I: received ready -- address: {}'.format(address))

        if socks.get(frontend) == zmq.POLLIN:
            # Get client request, route to first available worker
            msg = yield from frontend.recv_multipart()
            worker = workers.pop(0)
            request = [worker, b''] + msg
            print('I: sending -- worker: {}  msg: {}'.format(worker, msg))
            yield from backend.send_multipart(request)
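
The `LRU_READY` handshake and a matching worker (a sketch following the zguide LRU pattern; endpoints as bound above):

LRU_READY = b"\x01"

async def run_worker(context):
    worker = context.socket(zmq.REQ)
    worker.connect("tcp://localhost:5556")
    await worker.send(LRU_READY)    # register as available
    while True:
        address, empty, request = await worker.recv_multipart()
        # route an OK back through the queue to the requesting client
        await worker.send_multipart([address, b'', b'OK'])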
Example #13
def run_queue(context):
    frontend = context.socket(zmq.ROUTER)    # ROUTER
    backend = context.socket(zmq.ROUTER)     # ROUTER
    frontend.bind(FRONT_END_ADDRESS)    # For clients
    backend.bind(BACK_END_ADDRESS)      # For workers
    poll_workers = Poller()
    poll_workers.register(backend, zmq.POLLIN)
    poll_both = Poller()
    poll_both.register(frontend, zmq.POLLIN)
    poll_both.register(backend, zmq.POLLIN)
    workers = WorkerQueue()
    heartbeat_at = time.time() + HEARTBEAT_INTERVAL
    while True:
        if len(workers.queue) > 0:
            poller = poll_both
        else:
            poller = poll_workers
        socks = yield from poller.poll(HEARTBEAT_INTERVAL * 1000)
        socks = dict(socks)
        # Handle worker activity on backend
        if socks.get(backend) == zmq.POLLIN:
            # Use worker address for LRU routing
            frames = yield from backend.recv_multipart()
            if not frames:
                break
            address = frames[0]
            workers.ready(Worker(address))
            # Validate control message, or return reply to client
            msg = frames[1:]
            if len(msg) == 1:
                if msg[0] not in (PPP_READY, PPP_HEARTBEAT):
                    print("E: Invalid message from worker: %s" % msg)
            else:
                yield from frontend.send_multipart(msg)
            # Send heartbeats to idle workers if it's time
            if time.time() >= heartbeat_at:
                for worker in workers.queue:
                    msg = [worker, PPP_HEARTBEAT]
                    yield from backend.send_multipart(msg)
                heartbeat_at = time.time() + HEARTBEAT_INTERVAL
        if socks.get(frontend) == zmq.POLLIN:
            frames = yield from frontend.recv_multipart()
            if not frames:
                break
            frames.insert(0, next(workers))
            yield from backend.send_multipart(frames)
        workers.purge()
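
The heartbeat constants and helper classes this queue depends on (a sketch following the zguide Paranoid Pirate queue):

import time
from collections import OrderedDict

HEARTBEAT_LIVENESS = 3      # missed heartbeats before a worker is dropped
HEARTBEAT_INTERVAL = 1.0    # seconds
PPP_READY = b"\x01"
PPP_HEARTBEAT = b"\x02"

class Worker(object):
    def __init__(self, address):
        self.address = address
        self.expiry = time.time() + HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS

class WorkerQueue(object):
    def __init__(self):
        self.queue = OrderedDict()

    def ready(self, worker):
        # re-queue the worker, refreshing its position and expiry
        self.queue.pop(worker.address, None)
        self.queue[worker.address] = worker

    def purge(self):
        # drop workers whose heartbeats have expired
        now = time.time()
        expired = [a for a, w in self.queue.items() if now >= w.expiry]
        for address in expired:
            self.queue.pop(address, None)

    def __next__(self):
        # least recently used worker first
        address, _ = self.queue.popitem(last=False)
        return address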
Example #14
class AsyncRdsBusClient(object):
    """
    RDS-BUS client
    """
    ASC = 1
    DESC = -1

    def __init__(self, url, logger, request_timeout=None, database=None):
        self._logger = logger
        self._database = database
        self._context = Context.instance()
        self._poller = Poller()
        self._request = self._context.socket(zmq.DEALER)
        self._request_timeout = request_timeout or 60
        self._rds_bus_url = url
        self._request.connect(self._rds_bus_url)
        self._request_dict = dict()
        self._io_loop = asyncio.get_event_loop()
        self._running = False
        asyncio.ensure_future(self.start())

    @classmethod
    def pack(cls,
             database: str,
             key: str,
             parameter: dict,
             is_query: bool = False,
             order_by: list = None,
             page_no: int = None,
             per_page: int = None,
             found_rows: bool = False):
        """
        Pack the request payload.
        :param database: name of the database class on the RDS-BUS
        :param key: name of the statement instance held by the database class
        :param parameter: parameter dict
        :param is_query: whether this is a query operation
        :param order_by: sort spec [{"column": "column name", "order": AsyncRdsBusClient.ASC/AsyncRdsBusClient.DESC}]
        :param page_no: current page (range [0, n), where n is the nth page)
        :param per_page: records per page
        :param found_rows: whether to also count the total number of rows
        :return:
        """
        if is_query:
            amount = int(per_page) if per_page else None
            offset = int(page_no) * amount if page_no else None
            limit = (dict(amount=amount, offset=offset) if offset else dict(
                amount=amount)) if amount else None
            result = dict(command="{}/{}".format(database, key),
                          data=dict(var=parameter,
                                    order_by=order_by,
                                    limit=limit,
                                    found_rows=found_rows))
        else:
            result = dict(command="{}/{}".format(database, key),
                          data=dict(var=parameter))
        return result

    async def query(self,
                    key: str,
                    parameter: dict,
                    order_by: list = None,
                    page_no: int = None,
                    per_page: int = None,
                    found_rows: bool = False,
                    database: str = None,
                    execute: bool = True):
        """
        Query interface.
        :param database: name of the database class on the RDS-BUS
        :param key: name of the statement instance held by the database class
        :param parameter: parameter dict
        :param order_by: sort spec [{"column": "column name", "order": AsyncRdsBusClient.ASC/AsyncRdsBusClient.DESC}]
        :param page_no: current page (range [0, n), where n is the nth page)
        :param per_page: records per page
        :param found_rows: whether to also count the total number of rows
        :param execute: whether to execute immediately (otherwise return the packed request)
        :return:
        """
        _database = database or self._database
        argument = self.pack(database=_database,
                             key=key,
                             parameter=parameter,
                             is_query=True,
                             order_by=order_by,
                             page_no=page_no,
                             per_page=per_page,
                             found_rows=found_rows)
        if execute:
            response = await self._send(operation=OperationType.QUERY,
                                        argument=argument)
            result = RdsData(response)
        else:
            result = argument
        return result

    async def insert(self,
                     key: str,
                     parameter: dict,
                     database: str = None,
                     execute: bool = True):
        """
        Insert interface.
        :param database: name of the database class on the RDS-BUS
        :param key: name of the statement instance held by the database class
        :param parameter: parameter dict
        :param execute: whether to execute immediately (otherwise return the packed request)
        :return:
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter)
        if execute:
            response = await self._send(operation=OperationType.INSERT,
                                        argument=argument)
            result = RdsData(response)
        else:
            result = argument
        return result

    async def update(self,
                     key: str,
                     parameter: dict,
                     database: str = None,
                     execute: bool = True):
        """
        Update interface.
        :param database: name of the database class on the RDS-BUS
        :param key: name of the statement instance held by the database class
        :param parameter: parameter dict
        :param execute: whether to execute immediately (otherwise return the packed request)
        :return:
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter)
        if execute:
            response = await self._send(operation=OperationType.UPDATE,
                                        argument=argument)
            result = RdsData(response)
        else:
            result = argument
        return result

    async def delete(self,
                     key: str,
                     parameter: dict,
                     database: str = None,
                     execute: bool = False):
        """
        Delete interface.
        :param database: name of the database class on the RDS-BUS
        :param key: name of the statement instance held by the database class
        :param parameter: parameter dict
        :param execute: whether to execute immediately (otherwise return the packed request)
        :return:
        """
        _database = database or self._database
        argument = self.pack(database=_database, key=key, parameter=parameter)
        if execute:
            response = await self._send(operation=OperationType.DELETE,
                                        argument=argument)
            result = RdsData(response)
        else:
            result = argument
        return result

    async def transaction(self, data: list, database: str = None):
        """
        Transaction interface.
        :param database: name of the database class on the RDS-BUS
        :param data: list of operations
        :return:
        """
        _database = database or self._database
        result = await self._send(
            operation=OperationType.TRANSACTION,
            argument=dict(command="{}/transaction".format(_database),
                          data=data))
        return RdsListData(result)

    async def batch(self, data: list, database: str = None):
        """
        Batch interface.
        :param database: name of the database class on the RDS-BUS
        :param data: list of operations
        :return:
        """
        _database = database or self._database
        result = await self._send(operation=OperationType.BATCH,
                                  argument=dict(
                                      command="{}/batch".format(_database),
                                      data=data))
        return RdsListData(result)

    async def start(self):
        self._poller.register(self._request, zmq.POLLIN)
        self._running = True
        while True:
            events = await self._poller.poll()
            if self._request in dict(events):
                response = await self._request.recv_json()
                self._logger.debug("received {}".format(response))
                if response["id"] in self._request_dict:
                    future = self._request_dict.pop(response["id"])
                    if HttpResult.is_duplicate_data_failure(response["code"]):
                        future.set_exception(
                            DuplicateDataException.new_exception(
                                response["desc"]))
                    elif HttpResult.is_failure(response["code"]):
                        future.set_exception(
                            CallServiceException(method="ZMQ",
                                                 url=self._rds_bus_url,
                                                 errmsg=response["desc"]))
                    else:
                        future.set_result(response["data"])
                else:
                    self._logger.warning(
                        "unknown response {}".format(response))

    def stop(self):
        if self._running:
            self._poller.unregister(self._request)
            self._running = False

    def shutdown(self):
        self.stop()
        self._request.close()

    def _send(self, operation, argument):
        """

        :param operation:
        :param argument:
        :return:
        """
        request_id = get_unique_id()
        self._request_dict[request_id] = asyncio.Future()
        self._io_loop.call_later(self._request_timeout, self._session_timeout,
                                 request_id)
        self._request.send_multipart([
            json.dumps(
                dict(id=request_id,
                     operation=operation.value,
                     argument=argument)).encode("utf-8")
        ])
        return self._request_dict[request_id]

    def _session_timeout(self, request_id):
        if request_id in self._request_dict:
            future = self._request_dict.pop(request_id)
            future.set_exception(
                ServerTimeoutException(method="ZMQ", url=self._rds_bus_url))
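
Typical usage (a sketch; `OperationType`, `RdsData`, and the exception helpers are whatever the surrounding project provides):

import logging

async def main():
    client = AsyncRdsBusClient("tcp://localhost:5555",
                               logging.getLogger(__name__),
                               database="demo")
    # hypothetical statement name and parameters
    rows = await client.query(key="get_user", parameter={"user_id": 1})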
Example #15
class SchedulerConnection(object):
    __slots__ = (
        'address',
        # context object to open socket connections
        'context',
        # pull socket to receive check definitions from scheduler
        'pull',
        # poller object for `pull` socket
        'poller',
        # monitor socket for `pull` socket
        'monitor_socket',
        # poller object for monitor socket
        'monitor_poller',
        'first_missing',
    )

    def __init__(self, address):
        self.pull = self.poller = None
        self.monitor_poller = self.monitor_socket = None
        self.address = address
        self.context = Context.instance()
        self.open()
        self.first_missing = None

    def __str__(self):
        return self.address

    def __repr__(self):
        return 'Scheduler({})'.format(self.address)

    def open(self):
        self.pull = self.context.socket(PULL)
        logger.info('%s - opening pull socket ...', self)
        self.pull.connect(self.address)
        if settings.SCHEDULER_MONITOR:
            logger.info('%s - opening monitor socket ...', self)
            self.monitor_socket = self.pull.get_monitor_socket(
                events=EVENT_DISCONNECTED
            )
        self.register()

    def register(self):
        self.poller = Poller()
        self.poller.register(self.pull, POLLIN)
        if settings.SCHEDULER_MONITOR:
            self.monitor_poller = Poller()
            self.monitor_poller.register(self.monitor_socket, POLLIN)
        logger.info('%s - all sockets are successfully registered '
                    'in poller objects ...', self)

    def close(self):
        """Unregister open sockets from poller objects and close them."""
        self.unregister()
        logger.info('%s - closing open sockets ...', self)
        self.pull.close()
        if settings.SCHEDULER_MONITOR:
            self.monitor_socket.close()
        logger.info('%s - connection closed successfully ...', self)

    def unregister(self):
        """Unregister open sockets from poller object."""
        logger.info('%s - unregistering sockets from poller objects ...', self)
        self.poller.unregister(self.pull)
        if settings.SCHEDULER_MONITOR:
            self.monitor_poller.unregister(self.monitor_socket)

    def reconnect(self):
        self.close()
        self.open()
        self.first_missing = None

    @asyncio.coroutine
    def receive(self):
        check = None
        events = yield from self.poller.poll(timeout=2000)
        if self.pull in dict(events):
            check = yield from self.pull.recv_multipart()
            check = jsonapi.loads(check[0])

        if check:
            self.first_missing = None
        elif self.first_missing is None:
            self.first_missing = datetime.now(tz=pytz.utc)
        if self.first_missing:
            diff = datetime.now(tz=pytz.utc) - self.first_missing
            delta = timedelta(minutes=settings.SCHEDULER_LIVENESS_IN_MINUTES)
            if diff > delta:
                logger.warning(
                    'Alamo worker `%s` pid `%s` try to reconnect to '
                    '`%s` scheduler.',
                    settings.WORKER_FQDN, settings.WORKER_PID, self
                )
                self.reconnect()
        return check

    @asyncio.coroutine
    def receive_event(self):
        event = None
        events = yield from self.monitor_poller.poll(timeout=2000)
        if self.monitor_socket in dict(events):
            msg = yield from self.monitor_socket.recv_multipart(
                flags=NOBLOCK)
            event = parse_monitor_message(msg)
        return event
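
Consuming checks from the connection (a sketch in the same legacy coroutine style; `process` is a hypothetical handler):

@asyncio.coroutine
def consume(connection):
    while True:
        check = yield from connection.receive()
        if check is not None:
            process(check)  # hypothetical handler for one check definition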
Example #16
def run_broker(loop):
    """ main broker method """
    print('(run_broker) starting')
    url_worker = "inproc://workers"
    url_client = "inproc://clients"
    client_nbr = NBR_CLIENTS * 3
    # Prepare our context and sockets
    context = Context()
    frontend = context.socket(zmq.ROUTER)
    frontend.bind(url_client)
    backend = context.socket(zmq.ROUTER)
    backend.bind(url_worker)
    print('(run_broker) creating workers and clients')
    # create workers and clients threads
    worker_tasks = []
    for idx in range(NBR_WORKERS):
        task = asyncio.ensure_future(run_worker(url_worker, context, idx))
        worker_tasks.append(task)
    client_tasks = []
    for idx in range(NBR_CLIENTS):
        task = asyncio.ensure_future(run_client(url_client, context, idx))
        client_tasks.append(task)
    print('(run_broker) after creating workers and clients')
    # Logic of LRU loop
    # - Poll backend always, frontend only if 1+ worker ready
    # - If worker replies, queue worker as ready and forward reply
    # to client if necessary
    # - If client requests, pop next worker and send request to it
    # Queue of available workers
    available_workers = 0
    workers_list = []
    # init poller
    poller = Poller()
    # Always poll for worker activity on backend
    poller.register(backend, zmq.POLLIN)
    # Poll front-end only if we have available workers
    poller.register(frontend, zmq.POLLIN)
    while True:
        socks = yield from poller.poll()
        socks = dict(socks)
        # Handle worker activity on backend
        if backend in socks and socks[backend] == zmq.POLLIN:
            # Queue worker address for LRU routing
            message = yield from backend.recv_multipart()
            assert available_workers < NBR_WORKERS
            worker_addr = message[0]
            # add worker back to the list of workers
            available_workers += 1
            workers_list.append(worker_addr)
            #   Second frame is empty
            empty = message[1]
            assert empty == b""
            # Third frame is READY or else a client reply address
            client_addr = message[2]
            # If client reply, send rest back to frontend
            if client_addr != b'READY':
                # Following frame is empty
                empty = message[3]
                assert empty == b""
                reply = message[4]
                yield from frontend.send_multipart([client_addr, b"", reply])
                printdbg('(run_broker) to frontend -- reply: "{}"'.format(
                    reply))
                client_nbr -= 1
                if client_nbr == 0:
                    printdbg('(run_broker) exiting')
                    break   # Exit after N messages
        # poll on frontend only if workers are available
        if available_workers > 0:
            if frontend in socks and socks[frontend] == zmq.POLLIN:
                # Now get next client request, route to LRU worker
                # Client request is [address][empty][request]
                response = yield from frontend.recv_multipart()
                [client_addr, empty, request] = response
                assert empty == b""
                #  Dequeue and drop the next worker address
                available_workers -= 1
                worker_id = workers_list.pop()
                yield from backend.send_multipart(
                    [worker_id, b"", client_addr, b"", request])
                printdbg('(run_broker) to backend -- request: "{}"'.format(
                    request))
    # out of the infinite loop: do some housekeeping
    printdbg('(run_broker) finished')
    for worker_task in worker_tasks:
        worker_task.cancel()
    printdbg('(run_broker) workers cancelled')
    yield from asyncio.sleep(1)
    frontend.close()
    backend.close()
    #context.term()     # Caution: calling term() blocks.
    loop.stop()
    printdbg('(run_broker) returning')
    return 'finished ok'
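
The counts this broker assumes (a sketch; the classic zguide load-balancing example uses these values):

NBR_CLIENTS = 10
NBR_WORKERS = 3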
Example #17
class Backend(object):
    """
        Backend. Central point of the architecture, manages communication between clients (frontends) and agents.
        Schedule jobs on agents.
    """
    def __init__(self, context, agent_addr, client_addr):
        self._context = context
        self._loop = asyncio.get_event_loop()
        self._agent_addr = agent_addr
        self._client_addr = client_addr

        self._agent_socket = context.socket(zmq.ROUTER)
        self._client_socket = context.socket(zmq.ROUTER)
        self._logger = logging.getLogger("inginious.backend")

        # Enable support for ipv6
        self._agent_socket.ipv6 = True
        self._client_socket.ipv6 = True

        self._poller = Poller()
        self._poller.register(self._agent_socket, zmq.POLLIN)
        self._poller.register(self._client_socket, zmq.POLLIN)

        # Dict of available environments
        # {
        #     "name": ("last_id", "created_last", ["agent_addr1", "agent_addr2"], "type")
        # }
        self._environments = {}
        self._registered_clients = set()  # addr of registered clients

        # Dict of registered agents
        # {
        #     agent_address: {"name": "friendly_name", "environments": environment_dict}
        # } environment_dict is a described in AgentHello.
        self._registered_agents = {}

        # addr of available agents. May contain multiple times the same agent, because some agent can
        # manage multiple jobs at once!
        self._available_agents = []

        # ping count per addr of agents
        self._ping_count = {}

        # These two share the same job entries! Entries are never recreated,
        # only mutated in place (see handle_client_kill_job).
        # priority queue of waiting jobs
        self._waiting_jobs_pq = TopicPriorityQueue()
        # maps (client_addr_as_bytes, job_id) to the job entry
        self._waiting_jobs = {}

        # on which agent each job runs: {BackendJobId: (addr_as_bytes, ClientNewJob, start_time)}
        self._job_running = {}

    async def handle_agent_message(self, agent_addr, message):
        """Dispatch messages received from agents to the right handlers"""
        message_handlers = {
            AgentHello: self.handle_agent_hello,
            AgentJobStarted: self.handle_agent_job_started,
            AgentJobDone: self.handle_agent_job_done,
            AgentJobSSHDebug: self.handle_agent_job_ssh_debug,
            Pong: self._handle_pong
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._create_safe_task(func(agent_addr, message))

    async def handle_client_message(self, client_addr, message):
        """Dispatch messages received from clients to the right handlers"""

        # Verify that the client is registered
        if message.__class__ != ClientHello and client_addr not in self._registered_clients:
            await ZMQUtils.send_with_addr(self._client_socket, client_addr,
                                          Unknown())
            return

        message_handlers = {
            ClientHello: self.handle_client_hello,
            ClientNewJob: self.handle_client_new_job,
            ClientKillJob: self.handle_client_kill_job,
            ClientGetQueue: self.handle_client_get_queue,
            Ping: self.handle_client_ping
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._create_safe_task(func(client_addr, message))

    async def send_environment_update_to_client(self, client_addrs):
        """ :param client_addrs: list of clients to which we should send the update """
        self._logger.debug("Sending environments updates...")
        available_environments = {
            idx: environment[3]
            for idx, environment in self._environments.items()
        }
        msg = BackendUpdateEnvironments(available_environments)
        for client in client_addrs:
            await ZMQUtils.send_with_addr(self._client_socket, client, msg)

    async def handle_client_hello(self, client_addr, _: ClientHello):
        """ Handle an ClientHello message. Send available environments to the client """
        self._logger.info("New client connected %s", client_addr)
        self._registered_clients.add(client_addr)
        await self.send_environment_update_to_client([client_addr])

    async def handle_client_ping(self, client_addr, _: Ping):
        """ Handle an Ping message. Pong the client """
        await ZMQUtils.send_with_addr(self._client_socket, client_addr, Pong())

    async def handle_client_new_job(self, client_addr, message: ClientNewJob):
        """ Handle an ClientNewJob message. Add a job to the queue and triggers an update """
        self._logger.info("Adding a new job %s %s to the queue", client_addr,
                          message.job_id)

        # a mutable list, so handle_client_kill_job can blank the message in place
        job = [message.priority, time.time(), client_addr, message.job_id,
               message]
        self._waiting_jobs[(client_addr, message.job_id)] = job
        self._waiting_jobs_pq.put(message.environment, job)

        await self.update_queue()

    async def handle_client_kill_job(self, client_addr,
                                     message: ClientKillJob):
        """ Handle an ClientKillJob message. Remove a job from the waiting list or send the kill message to the right agent. """
        # Check if the job is not in the queue
        if (client_addr, message.job_id) in self._waiting_jobs:

            # Erase the job reference in priority queue
            job = self._waiting_jobs.pop((client_addr, message.job_id))
            job[-1] = None

            # Do not forget to send a JobDone
            await ZMQUtils.send_with_addr(
                self._client_socket, client_addr,
                BackendJobDone(message.job_id,
                               ("killed", "You killed the job"), 0.0, {}, {},
                               {}, "", None, "", ""))
        # If the job is running, transmit the info to the agent
        elif (client_addr, message.job_id) in self._job_running:
            agent_addr = self._job_running[(client_addr, message.job_id)][0]
            await ZMQUtils.send_with_addr(
                self._agent_socket, agent_addr,
                BackendKillJob((client_addr, message.job_id)))
        else:
            self._logger.warning("Client %s attempted to kill unknown job %s",
                                 str(client_addr), str(message.job_id))

    async def handle_client_get_queue(self, client_addr, _: ClientGetQueue):
        """ Handles a ClientGetQueue message. Send back info about the job queue"""
        #jobs_running: a list of tuples in the form
        #(job_id, is_current_client_job, agent_name, info, launcher, started_at, max_time)
        jobs_running = list()

        for backend_job_id, content in self._job_running.items():
            agent_friendly_name = self._registered_agents[content[0]]["name"]
            jobs_running.append(
                (content[1].job_id, backend_job_id[0] == client_addr,
                 agent_friendly_name,
                 content[1].course_id + "/" + content[1].task_id,
                 content[1].launcher, int(content[2]),
                 self._get_time_limit_estimate(content[1])))

        #jobs_waiting: a list of tuples in the form
        #(job_id, is_current_client_job, info, launcher, max_time)
        jobs_waiting = list()

        for job_client_addr, job in self._waiting_jobs.items():
            msg = job[-1]
            if isinstance(msg, ClientNewJob):
                jobs_waiting.append(
                    (msg.job_id, job_client_addr[0] == client_addr,
                     msg.course_id + "/" + msg.task_id, msg.launcher,
                     self._get_time_limit_estimate(msg)))

        await ZMQUtils.send_with_addr(
            self._client_socket, client_addr,
            BackendGetQueue(jobs_running, jobs_waiting))

    async def update_queue(self):
        """
        Send waiting jobs to available agents
        """

        jobs_ignored = []

        # copy, since we mutate self._available_agents while iterating
        available_agents = list(self._available_agents)

        # Loop on available agents to maximize running jobs, and break if priority queue empty
        for agent_addr in available_agents:
            if self._waiting_jobs_pq.empty():
                break  # nothing to do

            try:
                job = None
                while job is None:
                    # keep the object, do not unzip it directly! It's sometimes modified when a job is killed.
                    job = self._waiting_jobs_pq.get(
                        self._registered_agents[agent_addr]
                        ["environments"].keys())
                    priority, insert_time, client_addr, job_id, job_msg = job

                    # Killed job, removing it from the mapping
                    if not job_msg:
                        del self._waiting_jobs[(client_addr, job_id)]
                        job = None  # repeat the while loop. we need a job
            except queue.Empty:
                continue  # skip agent, nothing to do!

            # We have found a job, let's remove the agent from the available list
            self._available_agents.remove(agent_addr)

            # Remove the job from the queue
            del self._waiting_jobs[(client_addr, job_id)]

            # Send the job to agent
            job_id = (client_addr, job_msg.job_id)
            self._job_running[job_id] = (agent_addr, job_msg, time.time())
            self._logger.info("Sending job %s %s to agent %s", client_addr,
                              job_msg.job_id, agent_addr)
            await ZMQUtils.send_with_addr(
                self._agent_socket, agent_addr,
                BackendNewJob(job_id, job_msg.course_id, job_msg.task_id,
                              job_msg.inputdata, job_msg.environment,
                              job_msg.environment_parameters, job_msg.debug))

        # Let's not forget to add again the ignored jobs to the PQ.
        for entry in jobs_ignored:
            self._waiting_jobs_pq.put(entry)

    async def handle_agent_hello(self, agent_addr, message: AgentHello):
        """
        Handle an AgentAvailable message. Add agent_addr to the list of available agents
        """
        self._logger.info("Agent %s (%s) said hello", agent_addr,
                          message.friendly_name)

        if agent_addr in self._registered_agents:
            # Delete previous instance of this agent, if any
            await self._delete_agent(agent_addr)

        self._registered_agents[agent_addr] = {
            "name": message.friendly_name,
            "environments": message.available_environments
        }
        self._available_agents.extend(
            [agent_addr for _ in range(0, message.available_job_slots)])
        self._ping_count[agent_addr] = 0

        # update information about available environments
        for environment_name, environment_info in message.available_environments.items():
            if environment_name in self._environments:
                # check if the id is the same
                if self._environments[environment_name][0] == environment_info[
                        "id"]:
                    # ok, just add the agent to the list of agents that have the environment
                    self._logger.debug(
                        "Registering environment %s for agent %s",
                        environment_name, str(agent_addr))
                    self._environments[environment_name][2].append(agent_addr)
                elif self._environments[environment_name][
                        1] > environment_info["created"]:
                    # environments stored have been created after the new one
                    # add the agent, but emit a warning
                    self._logger.warning(
                        "Environment %s has multiple version: \n"
                        "\t Currently registered agents have version %s (%i)\n"
                        "\t New agent %s has version %s (%i)",
                        environment_name,
                        self._environments[environment_name][0],
                        self._environments[environment_name][1],
                        str(agent_addr), environment_info["id"],
                        environment_info["created"])
                    self._environments[environment_name][2].append(agent_addr)
                else:  # self._environments[environment_name][1] < environment_info["created"]:
                    # environments stored have been created before the new one
                    # add the agent, update the infos, and emit a warning
                    self._logger.warning(
                        "Environment %s has multiple version: \n"
                        "\t Currently registered agents have version %s (%i)\n"
                        "\t New agent %s has version %s (%i)",
                        environment_name,
                        self._environments[environment_name][0],
                        self._environments[environment_name][1],
                        str(agent_addr), environment_info["id"],
                        environment_info["created"])
                    self._environments[environment_name] = (
                        environment_info["id"], environment_info["created"],
                        self._environments[environment_name][2] + [agent_addr],
                        environment_info["type"])
            else:
                # just add it
                self._logger.debug("Registering environment %s for agent %s",
                                   environment_name, str(agent_addr))
                self._environments[environment_name] = (
                    environment_info["id"], environment_info["created"],
                    [agent_addr], environment_info["type"])

        # update the queue
        await self.update_queue()

        # update clients
        await self.send_environment_update_to_client(self._registered_clients)

    async def handle_agent_job_started(self, agent_addr,
                                       message: AgentJobStarted):
        """Handle an AgentJobStarted message. Send the data back to the client"""
        self._logger.debug("Job %s %s started on agent %s", message.job_id[0],
                           message.job_id[1], agent_addr)
        await ZMQUtils.send_with_addr(self._client_socket, message.job_id[0],
                                      BackendJobStarted(message.job_id[1]))

    async def handle_agent_job_done(self, agent_addr, message: AgentJobDone):
        """Handle an AgentJobDone message. Send the data back to the client, and start new job if needed"""

        if agent_addr in self._registered_agents:

            if message.job_id in self._job_running:
                self._logger.info("Job %s %s finished on agent %s",
                                  message.job_id[0], message.job_id[1],
                                  agent_addr)
                # Remove the job from the list of running jobs
                del self._job_running[message.job_id]
                # The agent is available now
                self._available_agents.append(agent_addr)
            else:
                self._logger.warning(
                    "Job result %s %s from agent %s was not running",
                    message.job_id[0], message.job_id[1], agent_addr)

            # Send the data back to the client, even if we didn't know the job. This ensures everything can recover
            # in case of problems.
            await ZMQUtils.send_with_addr(
                self._client_socket, message.job_id[0],
                BackendJobDone(message.job_id[1], message.result,
                               message.grade, message.problems, message.tests,
                               message.custom, message.state, message.archive,
                               message.stdout, message.stderr))
        else:
            self._logger.warning(
                "Job result %s %s from non-registered agent %s",
                message.job_id[0], message.job_id[1], agent_addr)

        # update the queue
        await self.update_queue()

    async def handle_agent_job_ssh_debug(self, _, message: AgentJobSSHDebug):
        """Handle an AgentJobSSHDebug message. Send the data back to the client"""
        await ZMQUtils.send_with_addr(
            self._client_socket, message.job_id[0],
            BackendJobSSHDebug(message.job_id[1], message.host, message.port,
                               message.password))

    async def run(self):
        self._logger.info("Backend started")
        self._agent_socket.bind(self._agent_addr)
        self._client_socket.bind(self._client_addr)
        self._loop.call_later(1, self._create_safe_task, self._do_ping())

        try:
            while True:
                socks = await self._poller.poll()
                socks = dict(socks)

                # New message from agent
                if self._agent_socket in socks:
                    agent_addr, message = await ZMQUtils.recv_with_addr(
                        self._agent_socket)
                    await self.handle_agent_message(agent_addr, message)

                # New message from client
                if self._client_socket in socks:
                    client_addr, message = await ZMQUtils.recv_with_addr(
                        self._client_socket)
                    await self.handle_client_message(client_addr, message)

        except asyncio.CancelledError:
            return
        except KeyboardInterrupt:
            return

    async def _handle_pong(self, agent_addr, _: Pong):
        """ Handle a pong """
        self._ping_count[agent_addr] = 0

    async def _do_ping(self):
        """ Ping the agents """

        # the list() call here is needed, as we remove entries from _registered_agents!
        for agent_addr, agent_data in list(self._registered_agents.items()):
            friendly_name = agent_data["name"]

            try:
                ping_count = self._ping_count.get(agent_addr, 0)
                if ping_count > 5:
                    self._logger.warning(
                        "Agent %s (%s) does not respond: removing from list.",
                        agent_addr, friendly_name)
                    delete_agent = True
                else:
                    self._ping_count[agent_addr] = ping_count + 1
                    await ZMQUtils.send_with_addr(self._agent_socket,
                                                  agent_addr, Ping())
                    delete_agent = False
            except Exception:
                # This should not happen, but it's better to check anyway.
                self._logger.exception(
                    "Failed to send ping to agent %s (%s). Removing it from list.",
                    agent_addr, friendly_name)
                delete_agent = True

            if delete_agent:
                try:
                    await self._delete_agent(agent_addr)
                except Exception:
                    self._logger.exception("Failed to delete agent %s (%s)!",
                                           agent_addr, friendly_name)

        self._loop.call_later(1, self._create_safe_task, self._do_ping())

    async def _delete_agent(self, agent_addr):
        """ Deletes an agent """
        self._available_agents = [
            agent for agent in self._available_agents if agent != agent_addr
        ]
        del self._registered_agents[agent_addr]
        await self._recover_jobs()

    async def _recover_jobs(self):
        """ Recover the jobs sent to a crashed agent """
        for (client_addr,
             job_id), (agent_addr, job_msg,
                       _) in reversed(list(self._job_running.items())):
            if agent_addr not in self._registered_agents:
                await ZMQUtils.send_with_addr(
                    self._client_socket, client_addr,
                    BackendJobDone(job_id, ("crash", "Agent restarted"), 0.0,
                                   {}, {}, {}, "", None, None, None))
                del self._job_running[(client_addr, job_id)]

        await self.update_queue()

    def _create_safe_task(self, coroutine):
        """ Calls self._loop.create_task with a safe (== with logged exception) coroutine """
        task = self._loop.create_task(coroutine)
        task.add_done_callback(self.__log_safe_task)
        return task

    def __log_safe_task(self, task):
        exception = task.exception()
        if exception is not None:
            self._logger.exception(
                "An exception occurred while running a Task",
                exc_info=exception)

    def _get_time_limit_estimate(self, job_info: ClientNewJob):
        """
            Returns an estimate of the time taken by a given job, if available in the environment_parameters.
            For this to work, ["limits"]["time"] must be a parameter of the environment.
        """
        try:
            return int(job_info.environment_parameters["limits"]["time"])
        except (KeyError, TypeError, ValueError):
            return -1  # unknown
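
Starting the backend (a sketch; the endpoints are illustrative):

import asyncio
import zmq.asyncio

async def main():
    context = zmq.asyncio.Context()
    backend = Backend(context, "tcp://*:2001", "tcp://*:2000")
    await backend.run()

asyncio.run(main())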
Example #18
class Backend(object):
    """
        Backend. Central point of the architecture, manages communication between clients (frontends) and agents.
        Schedule jobs on agents.
    """
    def __init__(self, context, agent_addr, client_addr):
        self._context = context
        self._loop = asyncio.get_event_loop()
        self._agent_addr = agent_addr
        self._client_addr = client_addr

        self._agent_socket = context.socket(zmq.ROUTER)
        self._client_socket = context.socket(zmq.ROUTER)
        self._logger = logging.getLogger("inginious.backend")

        # Enable support for ipv6
        self._agent_socket.ipv6 = True
        self._client_socket.ipv6 = True

        self._poller = Poller()
        self._poller.register(self._agent_socket, zmq.POLLIN)
        self._poller.register(self._client_socket, zmq.POLLIN)

        # List of containers available
        # {
        #     "name": ("last_id", "created_last", ["agent_addr1", "agent_addr2"])
        # }
        self._containers = {}

        # List of batch containers available
        # {
        #   "name": {
        #       "description": "a description written in RST",
        #       "id": "container img id",
        #       "created": 123456789,
        #       "agents": ["agent_addr1", "agent_addr2"]
        #       "parameters": {
        #       "key": {
        #           "type:" "file",  # or "text",
        #           "path": "path/to/file/inside/input/dir",  # not mandatory in file, by default "key"
        #           "name": "name of the field",  # not mandatory in file, default "key"
        #           "description": "a short description of what this field is used for"  # not mandatory, default ""
        #       }
        #   }
        # }
        self._batch_containers = {}

        # Batch containers available per agent {"agent_addr": ["batch_id_1", ...]}
        self._batch_containers_on_agent = {}

        # Containers available per agent {"agent_addr": ["container_id_1", ...]}
        self._containers_on_agent = {}

        self._registered_clients = set()  # addr of registered clients
        self._registered_agents = {}  # addr of registered agents
        self._available_agents = []  # addr of available agents
        self._ping_count = {}  # ping count per addr of agents
        self._waiting_jobs = OrderedDict(
        )  # rb queue for waiting jobs format:[(client_addr_as_bytes, Union[ClientNewJob,ClientNewBatchJob])]
        self._job_running = {
        }  # indicates on which agent which job is running. format: {BackendJobId:addr_as_bytes}
        self._batch_job_running = {
        }  # indicates on which agent which job is running. format: {BackendJobId:addr_as_bytes}

    async def handle_agent_message(self, agent_addr, message):
        """Dispatch messages received from agents to the right handlers"""
        message_handlers = {
            AgentHello: self.handle_agent_hello,
            AgentBatchJobStarted: self.handle_agent_batch_job_started,
            AgentBatchJobDone: self.handle_agent_batch_job_done,
            AgentJobStarted: self.handle_agent_job_started,
            AgentJobDone: self.handle_agent_job_done,
            AgentJobSSHDebug: self.handle_agent_job_ssh_debug,
            Pong: self._handle_pong
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._loop.create_task(func(agent_addr, message))

    async def handle_client_message(self, client_addr, message):
        """Dispatch messages received from clients to the right handlers"""

        # Verify that the client is registered
        if message.__class__ != ClientHello and client_addr not in self._registered_clients:
            await ZMQUtils.send_with_addr(self._client_socket, client_addr,
                                          Unknown())
            return

        message_handlers = {
            ClientHello: self.handle_client_hello,
            ClientNewBatchJob: self.handle_client_new_batch_job,
            ClientNewJob: self.handle_client_new_job,
            ClientKillJob: self.handle_client_kill_job,
            Ping: self.handle_client_ping
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._loop.create_task(func(client_addr, message))

    async def send_container_update_to_client(self, client_addrs):
        """ :param client_addrs: list of clients to which we should send the update """
        self._logger.debug("Sending container updates...")
        available_containers = tuple(self._containers.keys())
        available_batch_containers = {
            x: {
                "description": y["description"],
                "parameters": y["parameters"]
            }
            for x, y in self._batch_containers.items()
        }
        msg = BackendUpdateContainers(available_containers,
                                      available_batch_containers)
        for client in client_addrs:
            await ZMQUtils.send_with_addr(self._client_socket, client, msg)

    async def handle_client_hello(self, client_addr, _: ClientHello):
        """ Handle a ClientHello message. Send the available (batch) containers to the client """
        self._logger.info("New client connected %s", client_addr)
        self._registered_clients.add(client_addr)
        await self.send_container_update_to_client([client_addr])

    async def handle_client_ping(self, client_addr, _: Ping):
        """ Handle a Ping message. Pong the client """
        await ZMQUtils.send_with_addr(self._client_socket, client_addr, Pong())

    async def handle_client_new_batch_job(self, client_addr,
                                          message: ClientNewBatchJob):
        """ Handle a ClientNewBatchJob message. Add the job to the queue and trigger an update """
        self._logger.info("Adding a new batch job %s %s to the queue",
                          client_addr, message.job_id)
        self._waiting_jobs[(client_addr, message.job_id, "batch")] = message
        await self.update_queue()

    async def handle_client_new_job(self, client_addr, message: ClientNewJob):
        """ Handle a ClientNewJob message. Add the job to the queue and trigger an update """
        self._logger.info("Adding a new job %s %s to the queue", client_addr,
                          message.job_id)
        self._waiting_jobs[(client_addr, message.job_id, "grade")] = message
        await self.update_queue()

    async def handle_client_kill_job(self, client_addr,
                                     message: ClientKillJob):
        """ Handle a ClientKillJob message. Remove the job from the waiting list or send the kill message to the right agent. """
        # If the job is still waiting in the queue, just remove it
        if (client_addr, message.job_id, "grade") in self._waiting_jobs:
            del self._waiting_jobs[(client_addr, message.job_id, "grade")]
            # Do not forget to send a JobDone
            await ZMQUtils.send_with_addr(
                self._client_socket, client_addr,
                BackendJobDone(message.job_id,
                               ("killed", "You killed the job"), 0.0, {}, {},
                               {}, None, "", ""))
        # If the job is running, transmit the info to the agent
        elif (client_addr, message.job_id) in self._job_running:
            agent_addr, _ = self._job_running[(client_addr, message.job_id)]
            await ZMQUtils.send_with_addr(
                self._agent_socket, agent_addr,
                BackendKillJob((client_addr, message.job_id)))
        else:
            self._logger.warning("Client %s attempted to kill unknown job %s",
                                 str(client_addr), str(message.job_id))

    async def update_queue(self):
        """
        Send waiting jobs to available agents
        """

        # For now, round-robin
        not_found_for_agent = []

        while len(self._available_agents) > 0 and len(self._waiting_jobs) > 0:
            agent_addr = self._available_agents.pop(0)

            # Find first job that can be run on this agent
            found = False
            client_addr, job_id, typestr, job_msg = None, None, None, None
            for (client_addr, job_id,
                 typestr), job_msg in self._waiting_jobs.items():
                if typestr == "batch" and job_msg.container_name in self._batch_containers_on_agent[
                        agent_addr]:
                    found = True
                    break
                elif typestr == "grade" and job_msg.environment in self._containers_on_agent[
                        agent_addr]:
                    found = True
                    break

            if not found:
                self._logger.debug("Nothing to do for agent %s", agent_addr)
                not_found_for_agent.append(agent_addr)
                continue

            # Remove the job from the queue
            del self._waiting_jobs[(client_addr, job_id, typestr)]

            if typestr == "grade" and isinstance(job_msg, ClientNewJob):
                job_id = (client_addr, job_msg.job_id)
                self._job_running[job_id] = agent_addr, job_msg
                self._logger.info("Sending job %s %s to agent %s", client_addr,
                                  job_msg.job_id, agent_addr)
                await ZMQUtils.send_with_addr(
                    self._agent_socket, agent_addr,
                    BackendNewJob(job_id, job_msg.course_id, job_msg.task_id,
                                  job_msg.inputdata, job_msg.environment,
                                  job_msg.enable_network, job_msg.time_limit,
                                  job_msg.hard_time_limit, job_msg.mem_limit,
                                  job_msg.debug))
            elif typestr == "batch":
                job_id = (client_addr, job_msg.job_id)
                self._batch_job_running[job_id] = agent_addr, job_msg
                self._logger.info("Sending batch job %s %s to agent %s",
                                  client_addr, job_msg.job_id, agent_addr)
                await ZMQUtils.send_with_addr(
                    self._agent_socket, agent_addr,
                    BackendNewBatchJob(job_id, job_msg.container_name,
                                       job_msg.input_data))

        # Re-add the agents for which we did not find a suitable job
        self._available_agents += not_found_for_agent

    async def handle_agent_hello(self, agent_addr, message: AgentHello):
        """
        Handle an AgentHello message. Add agent_addr to the list of available agents
        """
        self._logger.info("Agent %s (%s) said hello", agent_addr,
                          message.friendly_name)

        self._registered_agents[agent_addr] = message.friendly_name
        for _ in range(message.available_job_slots):
            self._available_agents.append(agent_addr)

        self._batch_containers_on_agent[
            agent_addr] = message.available_batch_containers.keys()
        self._containers_on_agent[
            agent_addr] = message.available_containers.keys()

        # update information about available containers
        for container_name, container_info in message.available_containers.items(
        ):
            if container_name in self._containers:
                # check if the id is the same
                if self._containers[container_name][0] == container_info["id"]:
                    # ok, just add the agent to the list of agents that have the container
                    self._logger.debug("Registering container %s for agent %s",
                                       container_name, str(agent_addr))
                    self._containers[container_name][2].append(agent_addr)
                elif self._containers[container_name][1] > container_info[
                        "created"]:
                    # containers stored have been created after the new one
                    # add the agent, but emit a warning
                    self._logger.warning(
                        "Container %s has multiple versions:\n"
                        "\t Currently registered agents have version %s (%i)\n"
                        "\t New agent %s has version %s (%i)", container_name,
                        self._containers[container_name][0],
                        self._containers[container_name][1], str(agent_addr),
                        container_info["id"], container_info["created"])
                    self._containers[container_name][2].append(agent_addr)
                else:  # self._containers[container_name][1] < container_info["created"]:
                    # containers stored have been created before the new one
                    # add the agent, update the infos, and emit a warning
                    self._logger.warning(
                        "Container %s has multiple versions:\n"
                        "\t Currently registered agents have version %s (%i)\n"
                        "\t New agent %s has version %s (%i)", container_name,
                        self._containers[container_name][0],
                        self._containers[container_name][1], str(agent_addr),
                        container_info["id"], container_info["created"])
                    self._containers[container_name] = (
                        container_info["id"], container_info["created"],
                        self._containers[container_name][2] + [agent_addr])
            else:
                # just add it
                self._logger.debug("Registering container %s for agent %s",
                                   container_name, str(agent_addr))
                self._containers[container_name] = (container_info["id"],
                                                    container_info["created"],
                                                    [agent_addr])

        # update information about available batch containers
        for container_name, container_info in message.available_batch_containers.items(
        ):
            if container_name in self._batch_containers:
                if self._batch_containers[container_name][
                        "id"] == container_info["id"]:
                    # just add it
                    self._logger.debug(
                        "Registering batch container %s for agent %s",
                        container_name, str(agent_addr))
                    self._batch_containers[container_name]["agents"].append(
                        agent_addr)
                elif self._batch_containers[container_name][
                        "created"] > container_info["created"]:
                    # containers stored have been created after the new one
                    # add the agent, but emit a warning
                    self._logger.warning(
                        "Batch container %s has multiple versions:\n"
                        "\t Currently registered agents have version %s (%i)\n"
                        "\t New agent %s has version %s (%i)", container_name,
                        self._batch_containers[container_name]["id"],
                        self._batch_containers[container_name]["created"],
                        str(agent_addr), container_info["id"],
                        container_info["created"])
                    self._batch_containers[container_name]["agents"].append(
                        agent_addr)
                else:  # self._batch_containers[container_name]["created"] < container_info["created"]:
                    # containers stored have been created before the new one
                    # add the agent, update the infos, and emit a warning
                    self._logger.warning(
                        "Batch container %s has multiple versions:\n"
                        "\t Currently registered agents have version %s (%i)\n"
                        "\t New agent %s has version %s (%i)", container_name,
                        self._batch_containers[container_name]["id"],
                        self._batch_containers[container_name]["created"],
                        str(agent_addr), container_info["id"],
                        container_info["created"])
                    old_agents = self._batch_containers[container_name]["agents"]
                    self._batch_containers[container_name] = container_info.copy()
                    self._batch_containers[container_name][
                        "agents"] = old_agents + [agent_addr]
            else:
                # just add it
                self._logger.debug(
                    "Registering batch container %s for agent %s",
                    container_name, str(agent_addr))
                self._batch_containers[container_name] = container_info.copy()
                self._batch_containers[container_name]["agents"] = [agent_addr]

        # update the queue
        await self.update_queue()

        # update clients
        await self.send_container_update_to_client(self._registered_clients)

    async def handle_agent_job_started(self, agent_addr,
                                       message: AgentJobStarted):
        """Handle an AgentJobStarted message. Send the data back to the client"""
        self._logger.debug("Job %s %s started on agent %s", message.job_id[0],
                           message.job_id[1], agent_addr)
        await ZMQUtils.send_with_addr(self._client_socket, message.job_id[0],
                                      BackendJobStarted(message.job_id[1]))

    async def handle_agent_job_done(self, agent_addr, message: AgentJobDone):
        """Handle an AgentJobDone message. Send the data back to the client, and start new job if needed"""

        if agent_addr in self._registered_agents:
            self._logger.info("Job %s %s finished on agent %s",
                              message.job_id[0], message.job_id[1], agent_addr)

            # Remove the job from the list of running jobs
            del self._job_running[message.job_id]

            # Send the data back to the client
            await ZMQUtils.send_with_addr(
                self._client_socket, message.job_id[0],
                BackendJobDone(message.job_id[1], message.result,
                               message.grade, message.problems, message.tests,
                               message.custom, message.archive, message.stdout,
                               message.stderr))

            # The agent is available now
            self._available_agents.append(agent_addr)
        else:
            self._logger.warning(
                "Job result %s %s from non-registered agent %s",
                message.job_id[0], message.job_id[1], agent_addr)

        # update the queue
        await self.update_queue()

    async def handle_agent_job_ssh_debug(self, _, message: AgentJobSSHDebug):
        """Handle an AgentJobSSHDebug message. Send the data back to the client"""
        await ZMQUtils.send_with_addr(
            self._client_socket, message.job_id[0],
            BackendJobSSHDebug(message.job_id[1], message.host, message.port,
                               message.password))

    async def handle_agent_batch_job_started(self, agent_addr,
                                             message: AgentBatchJobStarted):
        """Handle an AgentBatchJobStarted message. Send the data back to the client"""
        self._logger.debug("Batch job %s %s started on agent %s",
                           message.job_id[0], message.job_id[1], agent_addr)
        await ZMQUtils.send_with_addr(
            self._client_socket, message.job_id[0],
            BackendBatchJobStarted(message.job_id[1]))

    async def handle_agent_batch_job_done(self, agent_addr,
                                          message: AgentBatchJobDone):
        """Handle an AgentBatchJobDone message. Send the data back to the client, and start new job if needed"""

        if agent_addr in self._registered_agents:
            self._logger.info("Batch job %s %s finished on agent %s",
                              message.job_id[0], message.job_id[1], agent_addr)

            # Remove the job from the list of running jobs
            del self._batch_job_running[message.job_id]

            # Send the data back to the client
            await ZMQUtils.send_with_addr(
                self._client_socket, message.job_id[0],
                BackendBatchJobDone(message.job_id[1], message.retval,
                                    message.stdout, message.stderr,
                                    message.file))

            # The agent is available now
            self._available_agents.append(agent_addr)
        else:
            self._logger.warning(
                "Batch job result %s %s from non-registered agent %s",
                message.job_id[0], message.job_id[1], agent_addr)

        # update the queue
        await self.update_queue()

    async def run(self):
        self._logger.info("Backend started")
        self._agent_socket.bind(self._agent_addr)
        self._client_socket.bind(self._client_addr)
        self._loop.call_later(1, asyncio.ensure_future, self._do_ping())

        try:
            while True:
                socks = await self._poller.poll()
                socks = dict(socks)

                # New message from agent
                if self._agent_socket in socks:
                    agent_addr, message = await ZMQUtils.recv_with_addr(
                        self._agent_socket)
                    await self.handle_agent_message(agent_addr, message)

                # New message from client
                if self._client_socket in socks:
                    client_addr, message = await ZMQUtils.recv_with_addr(
                        self._client_socket)
                    await self.handle_client_message(client_addr, message)

        except asyncio.CancelledError:
            return
        except KeyboardInterrupt:
            return

    async def _handle_pong(self, agent_addr, _: Pong):
        """ Handle a pong """
        self._ping_count[agent_addr] = 0

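    # Heartbeat: _do_ping reschedules itself every second and pings each
    # registered agent; _handle_pong resets the counter, so more than five
    # unanswered pings (roughly five seconds of silence) gets the agent
    # dropped and its jobs recovered.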
    async def _do_ping(self):
        """ Ping the agents """
        for agent_addr, friendly_name in list(self._registered_agents.items()):
            ping_count = self._ping_count.get(agent_addr, 0)
            if ping_count > 5:
                self._logger.warning(
                    "Agent %s (%s) does not respond: removing from list.",
                    agent_addr, friendly_name)
                self._available_agents = [
                    agent for agent in self._available_agents
                    if agent != agent_addr
                ]
                del self._registered_agents[agent_addr]
                await self._recover_jobs(agent_addr)
            else:
                self._ping_count[agent_addr] = ping_count + 1
                await ZMQUtils.send_with_addr(self._agent_socket, agent_addr,
                                              Ping())
        self._loop.call_later(1, asyncio.ensure_future, self._do_ping())

    async def _recover_jobs(self, agent_addr):
        """ Recover the jobs sent to a crashed agent """
        for (client_addr,
             job_id), (agent,
                       job_msg) in reversed(list(self._job_running.items())):
            if agent == agent_addr:
                self._waiting_jobs[(client_addr, job_id, "grade")] = job_msg
                del self._job_running[(client_addr, job_id)]

        for (client_addr, job_id), (agent, job_msg) in reversed(
                list(self._batch_job_running.items())):
            if agent == agent_addr:
                self._waiting_jobs[(client_addr, job_id, "batch")] = job_msg
                del self._batch_job_running[(client_addr, job_id)]

        await self.update_queue()
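
A minimal way to launch the backend above, as a sketch only: it assumes the INGInious message classes and ZMQUtils helpers used by the Backend are importable alongside it, and the two bind addresses are placeholders.

import asyncio
import zmq.asyncio

async def main():
    # The Backend awaits its sockets, so a zmq.asyncio context is required.
    context = zmq.asyncio.Context()
    backend = Backend(context, agent_addr="tcp://*:2001", client_addr="tcp://*:2000")
    try:
        await backend.run()
    finally:
        context.destroy()

if __name__ == "__main__":
    asyncio.get_event_loop().run_until_complete(main())
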
Example No. 19
class Backend(object):
    """
        Backend. Central point of the architecture, manages communication between clients (frontends) and agents.
        Schedule jobs on agents.
    """

    def __init__(self, context, agent_addr, client_addr):
        self._context = context
        self._loop = asyncio.get_event_loop()
        self._agent_addr = agent_addr
        self._client_addr = client_addr

        self._agent_socket = context.socket(zmq.ROUTER)
        self._client_socket = context.socket(zmq.ROUTER)
        self._logger = logging.getLogger("inginious.backend")

        # Enable support for ipv6
        self._agent_socket.ipv6 = True
        self._client_socket.ipv6 = True

        self._poller = Poller()
        self._poller.register(self._agent_socket, zmq.POLLIN)
        self._poller.register(self._client_socket, zmq.POLLIN)

        # List of containers available
        # {
        #     "name": ("last_id", "created_last", ["agent_addr1", "agent_addr2"])
        # }
        self._containers = {}

        # Containers available per agent {"agent_addr": ["container_id_1", ...]}
        self._containers_on_agent = {}

        self._registered_clients = set()  # addr of registered clients
        self._registered_agents = {}  # addr of registered agents
        self._available_agents = []  # addr of available agents
        self._ping_count = {} # ping count per addr of agents
        self._waiting_jobs = OrderedDict()  # round-robin queue of waiting jobs. format: {(client_addr_as_bytes, job_id): ClientNewJob}
        self._job_running = {}  # indicates on which agent each job is running. format: {BackendJobId: (addr_as_bytes, ClientNewJob, start_time)}

    async def handle_agent_message(self, agent_addr, message):
        """Dispatch messages received from agents to the right handlers"""
        message_handlers = {
            AgentHello: self.handle_agent_hello,
            AgentJobStarted: self.handle_agent_job_started,
            AgentJobDone: self.handle_agent_job_done,
            AgentJobSSHDebug: self.handle_agent_job_ssh_debug,
            Pong: self._handle_pong
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._create_safe_task(func(agent_addr, message))

    async def handle_client_message(self, client_addr, message):
        """Dispatch messages received from clients to the right handlers"""

        # Verify that the client is registered
        if message.__class__ != ClientHello and client_addr not in self._registered_clients:
            await ZMQUtils.send_with_addr(self._client_socket, client_addr, Unknown())
            return

        message_handlers = {
            ClientHello: self.handle_client_hello,
            ClientNewJob: self.handle_client_new_job,
            ClientKillJob: self.handle_client_kill_job,
            ClientGetQueue: self.handle_client_get_queue,
            Ping: self.handle_client_ping
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._create_safe_task(func(client_addr, message))

    async def send_container_update_to_client(self, client_addrs):
        """ :param client_addrs: list of clients to which we should send the update """
        self._logger.debug("Sending container updates...")
        available_containers = tuple(self._containers.keys())
        msg = BackendUpdateContainers(available_containers)
        for client in client_addrs:
            await ZMQUtils.send_with_addr(self._client_socket, client, msg)

    async def handle_client_hello(self, client_addr, _: ClientHello):
        """ Handle a ClientHello message. Send available containers to the client """
        self._logger.info("New client connected %s", client_addr)
        self._registered_clients.add(client_addr)
        await self.send_container_update_to_client([client_addr])

    async def handle_client_ping(self, client_addr, _: Ping):
        """ Handle a Ping message. Pong the client """
        await ZMQUtils.send_with_addr(self._client_socket, client_addr, Pong())

    async def handle_client_new_job(self, client_addr, message: ClientNewJob):
        """ Handle a ClientNewJob message. Add the job to the queue and trigger an update """
        self._logger.info("Adding a new job %s %s to the queue", client_addr, message.job_id)
        self._waiting_jobs[(client_addr, message.job_id)] = message
        await self.update_queue()

    async def handle_client_kill_job(self, client_addr, message: ClientKillJob):
        """ Handle a ClientKillJob message. Remove the job from the waiting list or send the kill message to the right agent. """
        # Check if the job is not in the queue
        if (client_addr, message.job_id) in self._waiting_jobs:
            del self._waiting_jobs[(client_addr, message.job_id)]
            # Do not forget to send a JobDone
            await ZMQUtils.send_with_addr(self._client_socket, client_addr, BackendJobDone(message.job_id, ("killed", "You killed the job"),
                                                                                           0.0, {}, {}, {}, "", None, "", ""))
        # If the job is running, transmit the info to the agent
        elif (client_addr, message.job_id) in self._job_running:
            agent_addr = self._job_running[(client_addr, message.job_id)][0]
            await ZMQUtils.send_with_addr(self._agent_socket, agent_addr, BackendKillJob((client_addr, message.job_id)))
        else:
            self._logger.warning("Client %s attempted to kill unknown job %s", str(client_addr), str(message.job_id))

    async def handle_client_get_queue(self, client_addr, _: ClientGetQueue):
        """ Handles a ClientGetQueue message. Send back info about the job queue"""
        #jobs_running: a list of tuples in the form
        #(job_id, is_current_client_job, agent_name, info, launcher, started_at, max_end)
        jobs_running = list()

        for backend_job_id, content in self._job_running.items():
            jobs_running.append((content[1].job_id, backend_job_id[0] == client_addr, self._registered_agents[content[0]],
                                 content[1].course_id+"/"+content[1].task_id,
                                 content[1].launcher, int(content[2]), int(content[2])+content[1].time_limit))

        #jobs_waiting: a list of tuples in the form
        #(job_id, is_current_client_job, info, launcher, max_time)
        jobs_waiting = list()

        for job_client_addr, msg in self._waiting_jobs.items():
            if isinstance(msg, ClientNewJob):
                jobs_waiting.append((msg.job_id, job_client_addr[0] == client_addr, msg.course_id+"/"+msg.task_id, msg.launcher,
                                     msg.time_limit))

        await ZMQUtils.send_with_addr(self._client_socket, client_addr, BackendGetQueue(jobs_running, jobs_waiting))

    async def update_queue(self):
        """
        Send waiting jobs to available agents
        """

        # For now, round-robin
        not_found_for_agent = []

        while len(self._available_agents) > 0 and len(self._waiting_jobs) > 0:
            agent_addr = self._available_agents.pop(0)

            # Find first job that can be run on this agent
            found = False
            client_addr, job_id, job_msg = None, None, None
            for (client_addr, job_id), job_msg in self._waiting_jobs.items():
                if job_msg.environment in self._containers_on_agent[agent_addr]:
                    found = True
                    break

            if not found:
                self._logger.debug("Nothing to do for agent %s", agent_addr)
                not_found_for_agent.append(agent_addr)
                continue

            # Remove the job from the queue
            del self._waiting_jobs[(client_addr, job_id)]

            job_id = (client_addr, job_msg.job_id)
            self._job_running[job_id] = (agent_addr, job_msg, time.time())
            self._logger.info("Sending job %s %s to agent %s", client_addr, job_msg.job_id, agent_addr)
            await ZMQUtils.send_with_addr(self._agent_socket, agent_addr, BackendNewJob(job_id, job_msg.course_id, job_msg.task_id,
                                                                                        job_msg.inputdata, job_msg.environment,
                                                                                        job_msg.enable_network, job_msg.time_limit,
                                                                                        job_msg.hard_time_limit, job_msg.mem_limit,
                                                                                        job_msg.debug))

        # Re-add the agents for which we did not find a suitable job
        self._available_agents += not_found_for_agent

    async def handle_agent_hello(self, agent_addr, message: AgentHello):
        """
        Handle an AgentHello message. Add agent_addr to the list of available agents
        """
        self._logger.info("Agent %s (%s) said hello", agent_addr, message.friendly_name)

        if agent_addr in self._registered_agents:
            # Delete previous instance of this agent, if any
            await self._delete_agent(agent_addr)

        self._registered_agents[agent_addr] = message.friendly_name
        self._available_agents.extend([agent_addr for _ in range(0, message.available_job_slots)])
        self._containers_on_agent[agent_addr] = message.available_containers.keys()
        self._ping_count[agent_addr] = 0

        # update information about available containers
        for container_name, container_info in message.available_containers.items():
            if container_name in self._containers:
                # check if the id is the same
                if self._containers[container_name][0] == container_info["id"]:
                    # ok, just add the agent to the list of agents that have the container
                    self._logger.debug("Registering container %s for agent %s", container_name, str(agent_addr))
                    self._containers[container_name][2].append(agent_addr)
                elif self._containers[container_name][1] > container_info["created"]:
                    # containers stored have been created after the new one
                    # add the agent, but emit a warning
                    self._logger.warning("Container %s has multiple versions:\n"
                                         "\t Currently registered agents have version %s (%i)\n"
                                         "\t New agent %s has version %s (%i)",
                                         container_name,
                                         self._containers[container_name][0], self._containers[container_name][1],
                                         str(agent_addr), container_info["id"], container_info["created"])
                    self._containers[container_name][2].append(agent_addr)
                else:  # self._containers[container_name][1] < container_info["created"]:
                    # containers stored have been created before the new one
                    # add the agent, update the infos, and emit a warning
                    self._logger.warning("Container %s has multiple versions:\n"
                                         "\t Currently registered agents have version %s (%i)\n"
                                         "\t New agent %s has version %s (%i)",
                                         container_name,
                                         self._containers[container_name][0], self._containers[container_name][1],
                                         str(agent_addr), container_info["id"], container_info["created"])
                    self._containers[container_name] = (container_info["id"], container_info["created"],
                                                        self._containers[container_name][2] + [agent_addr])
            else:
                # just add it
                self._logger.debug("Registering container %s for agent %s", container_name, str(agent_addr))
                self._containers[container_name] = (container_info["id"], container_info["created"], [agent_addr])

        # update the queue
        await self.update_queue()

        # update clients
        await self.send_container_update_to_client(self._registered_clients)

    async def handle_agent_job_started(self, agent_addr, message: AgentJobStarted):
        """Handle an AgentJobStarted message. Send the data back to the client"""
        self._logger.debug("Job %s %s started on agent %s", message.job_id[0], message.job_id[1], agent_addr)
        await ZMQUtils.send_with_addr(self._client_socket, message.job_id[0], BackendJobStarted(message.job_id[1]))

    async def handle_agent_job_done(self, agent_addr, message: AgentJobDone):
        """Handle an AgentJobDone message. Send the data back to the client, and start new job if needed"""

        if agent_addr in self._registered_agents:
            self._logger.info("Job %s %s finished on agent %s", message.job_id[0], message.job_id[1], agent_addr)

            # Remove the job from the list of running jobs
            del self._job_running[message.job_id]

            # Send the data back to the client
            await ZMQUtils.send_with_addr(self._client_socket, message.job_id[0], BackendJobDone(message.job_id[1], message.result,
                                                                                                 message.grade, message.problems,
                                                                                                 message.tests, message.custom,
                                                                                                 message.state, message.archive,
                                                                                                 message.stdout, message.stderr))

            # The agent is available now
            self._available_agents.append(agent_addr)
        else:
            self._logger.warning("Job result %s %s from non-registered agent %s", message.job_id[0], message.job_id[1], agent_addr)

        # update the queue
        await self.update_queue()

    async def handle_agent_job_ssh_debug(self, _, message: AgentJobSSHDebug):
        """Handle an AgentJobSSHDebug message. Send the data back to the client"""
        await ZMQUtils.send_with_addr(self._client_socket, message.job_id[0], BackendJobSSHDebug(message.job_id[1], message.host, message.port,
                                                                                                 message.password))

    async def run(self):
        self._logger.info("Backend started")
        self._agent_socket.bind(self._agent_addr)
        self._client_socket.bind(self._client_addr)
        self._loop.call_later(1, self._create_safe_task, self._do_ping())

        try:
            while True:
                socks = await self._poller.poll()
                socks = dict(socks)

                # New message from agent
                if self._agent_socket in socks:
                    agent_addr, message = await ZMQUtils.recv_with_addr(self._agent_socket)
                    await self.handle_agent_message(agent_addr, message)

                # New message from client
                if self._client_socket in socks:
                    client_addr, message = await ZMQUtils.recv_with_addr(self._client_socket)
                    await self.handle_client_message(client_addr, message)

        except asyncio.CancelledError:
            return
        except KeyboardInterrupt:
            return

    async def _handle_pong(self, agent_addr, _: Pong):
        """ Handle a pong """
        self._ping_count[agent_addr] = 0

    async def _do_ping(self):
        """ Ping the agents """

        # the list() call here is needed, as we remove entries from _registered_agents!
        for agent_addr, friendly_name in list(self._registered_agents.items()):
            try:
                ping_count = self._ping_count.get(agent_addr, 0)
                if ping_count > 5:
                    self._logger.warning("Agent %s (%s) does not respond: removing from list.", agent_addr, friendly_name)
                    delete_agent = True
                else:
                    self._ping_count[agent_addr] = ping_count + 1
                    await ZMQUtils.send_with_addr(self._agent_socket, agent_addr, Ping())
                    delete_agent = False
            except Exception:
                # This should not happen, but it's better to check anyway.
                self._logger.exception("Failed to send ping to agent %s (%s). Removing it from list.", agent_addr, friendly_name)
                delete_agent = True

            if delete_agent:
                try:
                    await self._delete_agent(agent_addr)
                except Exception:
                    self._logger.exception("Failed to delete agent %s (%s)!", agent_addr, friendly_name)

        self._loop.call_later(1, self._create_safe_task, self._do_ping())

    async def _delete_agent(self, agent_addr):
        """ Deletes an agent """
        self._available_agents = [agent for agent in self._available_agents if agent != agent_addr]
        del self._registered_agents[agent_addr]
        await self._recover_jobs(agent_addr)

    async def _recover_jobs(self, agent_addr):
        """ Recover the jobs sent to a crashed agent """
        for (client_addr, job_id), (agent, job_msg, _) in reversed(list(self._job_running.items())):
            if agent == agent_addr:
                await ZMQUtils.send_with_addr(self._client_socket, client_addr,
                                              BackendJobDone(job_id, ("crash", "Agent restarted"),
                                                             0.0, {}, {}, {}, "", None, None, None))
                del self._job_running[(client_addr, job_id)]

        await self.update_queue()

    def _create_safe_task(self, coroutine):
        """ Calls self._loop.create_task with a safe (== with logged exception) coroutine """
        task = self._loop.create_task(coroutine)
        task.add_done_callback(self.__log_safe_task)
        return task

    def __log_safe_task(self, task):
        if task.cancelled():
            return  # task.exception() would raise CancelledError here
        exception = task.exception()
        if exception is not None:
            self._logger.exception("An exception occurred while running a Task", exc_info=exception)
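
The _create_safe_task helper above exists so background tasks never fail silently: a done-callback inspects Task.exception() and logs it. A self-contained sketch of the same pattern (the names here are illustrative, not part of the original module):

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("safe_tasks")

def create_safe_task(loop, coro):
    # Schedule the coroutine and attach a callback that logs any exception.
    task = loop.create_task(coro)
    task.add_done_callback(log_task_exception)
    return task

def log_task_exception(task):
    if task.cancelled():
        return  # .exception() raises CancelledError on a cancelled task
    exc = task.exception()
    if exc is not None:
        logger.error("Task failed", exc_info=exc)

async def failing_job():
    raise RuntimeError("demo failure")

loop = asyncio.get_event_loop()
create_safe_task(loop, failing_job())
loop.run_until_complete(asyncio.sleep(0.1))
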
Example No. 20
class Backend(object):
    """
        Backend. Central point of the architecture, manages communication between clients (frontends) and agents.
        Schedule jobs on agents.
    """

    def __init__(self, context, agent_addr, client_addr):
        self._context = context
        self._loop = asyncio.get_event_loop()
        self._agent_addr = agent_addr
        self._client_addr = client_addr

        self._agent_socket = context.socket(zmq.ROUTER)
        self._client_socket = context.socket(zmq.ROUTER)
        self._logger = logging.getLogger("inginious.backend")

        # Enable support for ipv6
        self._agent_socket.ipv6 = True
        self._client_socket.ipv6 = True

        self._poller = Poller()
        self._poller.register(self._agent_socket, zmq.POLLIN)
        self._poller.register(self._client_socket, zmq.POLLIN)

        # List of containers available
        # {
        #     "name": ("last_id", "created_last", ["agent_addr1", "agent_addr2"])
        # }
        self._containers = {}

        # Containers available per agent {"agent_addr": ["container_id_1", ...]}
        self._containers_on_agent = {}

        self._registered_clients = set()  # addr of registered clients
        self._registered_agents = {}  # addr of registered agents
        self._available_agents = []  # addr of available agents
        self._ping_count = {} # ping count per addr of agents
        self._waiting_jobs = OrderedDict()  # round-robin queue of waiting jobs. format: {(client_addr_as_bytes, job_id): ClientNewJob}
        self._job_running = {}  # indicates on which agent each job is running. format: {BackendJobId: (addr_as_bytes, ClientNewJob, start_time)}

    async def handle_agent_message(self, agent_addr, message):
        """Dispatch messages received from agents to the right handlers"""
        message_handlers = {
            AgentHello: self.handle_agent_hello,
            AgentJobStarted: self.handle_agent_job_started,
            AgentJobDone: self.handle_agent_job_done,
            AgentJobSSHDebug: self.handle_agent_job_ssh_debug,
            Pong: self._handle_pong
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._create_safe_task(func(agent_addr, message))

    async def handle_client_message(self, client_addr, message):
        """Dispatch messages received from clients to the right handlers"""

        # Verify that the client is registered
        if message.__class__ != ClientHello and client_addr not in self._registered_clients:
            await ZMQUtils.send_with_addr(self._client_socket, client_addr, Unknown())
            return

        message_handlers = {
            ClientHello: self.handle_client_hello,
            ClientNewJob: self.handle_client_new_job,
            ClientKillJob: self.handle_client_kill_job,
            ClientGetQueue: self.handle_client_get_queue,
            Ping: self.handle_client_ping
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._create_safe_task(func(client_addr, message))

    async def send_container_update_to_client(self, client_addrs):
        """ :param client_addrs: list of clients to which we should send the update """
        self._logger.debug("Sending container updates...")
        available_containers = tuple(self._containers.keys())
        msg = BackendUpdateContainers(available_containers)
        for client in client_addrs:
            await ZMQUtils.send_with_addr(self._client_socket, client, msg)

    async def handle_client_hello(self, client_addr, _: ClientHello):
        """ Handle a ClientHello message. Send available containers to the client """
        self._logger.info("New client connected %s", client_addr)
        self._registered_clients.add(client_addr)
        await self.send_container_update_to_client([client_addr])

    async def handle_client_ping(self, client_addr, _: Ping):
        """ Handle a Ping message. Pong the client """
        await ZMQUtils.send_with_addr(self._client_socket, client_addr, Pong())

    async def handle_client_new_job(self, client_addr, message: ClientNewJob):
        """ Handle a ClientNewJob message. Add the job to the queue and trigger an update """
        self._logger.info("Adding a new job %s %s to the queue", client_addr, message.job_id)
        self._waiting_jobs[(client_addr, message.job_id)] = message
        await self.update_queue()

    async def handle_client_kill_job(self, client_addr, message: ClientKillJob):
        """ Handle a ClientKillJob message. Remove the job from the waiting list or send the kill message to the right agent. """
        # Check if the job is not in the queue
        if (client_addr, message.job_id) in self._waiting_jobs:
            del self._waiting_jobs[(client_addr, message.job_id)]
            # Do not forget to send a JobDone
            await ZMQUtils.send_with_addr(self._client_socket, client_addr, BackendJobDone(message.job_id, ("killed", "You killed the job"),
                                                                                           0.0, {}, {}, {}, None, "", ""))
        # If the job is running, transmit the info to the agent
        elif (client_addr, message.job_id) in self._job_running:
            agent_addr = self._job_running[(client_addr, message.job_id)][0]
            await ZMQUtils.send_with_addr(self._agent_socket, agent_addr, BackendKillJob((client_addr, message.job_id)))
        else:
            self._logger.warning("Client %s attempted to kill unknown job %s", str(client_addr), str(message.job_id))

    async def handle_client_get_queue(self, client_addr, _: ClientGetQueue):
        """ Handles a ClientGetQueue message. Send back info about the job queue"""
        #jobs_running: a list of tuples in the form
        #(job_id, is_current_client_job, agent_name, info, launcher, started_at, max_end)
        jobs_running = list()

        for backend_job_id, content in self._job_running.items():
            jobs_running.append((content[1].job_id, backend_job_id[0] == client_addr, self._registered_agents[content[0]],
                                 content[1].course_id+"/"+content[1].task_id,
                                 content[1].launcher, int(content[2]), int(content[2])+content[1].time_limit))

        #jobs_waiting: a list of tuples in the form
        #(job_id, is_current_client_job, info, launcher, max_time)
        jobs_waiting = list()

        for job_client_addr, msg in self._waiting_jobs.items():
            if isinstance(msg, ClientNewJob):
                jobs_waiting.append((msg.job_id, job_client_addr[0] == client_addr, msg.course_id+"/"+msg.task_id, msg.launcher,
                                     msg.time_limit))

        await ZMQUtils.send_with_addr(self._client_socket, client_addr, BackendGetQueue(jobs_running, jobs_waiting))

    async def update_queue(self):
        """
        Send waiting jobs to available agents
        """

        # For now, round-robin
        not_found_for_agent = []

        while len(self._available_agents) > 0 and len(self._waiting_jobs) > 0:
            agent_addr = self._available_agents.pop(0)

            # Find first job that can be run on this agent
            found = False
            client_addr, job_id, job_msg = None, None, None
            for (client_addr, job_id), job_msg in self._waiting_jobs.items():
                if job_msg.environment in self._containers_on_agent[agent_addr]:
                    found = True
                    break

            if not found:
                self._logger.debug("Nothing to do for agent %s", agent_addr)
                not_found_for_agent.append(agent_addr)
                continue

            # Remove the job from the queue
            del self._waiting_jobs[(client_addr, job_id)]

            job_id = (client_addr, job_msg.job_id)
            self._job_running[job_id] = (agent_addr, job_msg, time.time())
            self._logger.info("Sending job %s %s to agent %s", client_addr, job_msg.job_id, agent_addr)
            await ZMQUtils.send_with_addr(self._agent_socket, agent_addr, BackendNewJob(job_id, job_msg.course_id, job_msg.task_id,
                                                                                        job_msg.inputdata, job_msg.environment,
                                                                                        job_msg.enable_network, job_msg.time_limit,
                                                                                        job_msg.hard_time_limit, job_msg.mem_limit,
                                                                                        job_msg.debug))

        # Re-add the agents for which we did not find a suitable job
        self._available_agents += not_found_for_agent

    async def handle_agent_hello(self, agent_addr, message: AgentHello):
        """
        Handle an AgentHello message. Add agent_addr to the list of available agents
        """
        self._logger.info("Agent %s (%s) said hello", agent_addr, message.friendly_name)

        self._registered_agents[agent_addr] = message.friendly_name
        self._available_agents.extend([agent_addr for _ in range(0, message.available_job_slots)])
        self._containers_on_agent[agent_addr] = message.available_containers.keys()

        # update information about available containers
        for container_name, container_info in message.available_containers.items():
            if container_name in self._containers:
                # check if the id is the same
                if self._containers[container_name][0] == container_info["id"]:
                    # ok, just add the agent to the list of agents that have the container
                    self._logger.debug("Registering container %s for agent %s", container_name, str(agent_addr))
                    self._containers[container_name][2].append(agent_addr)
                elif self._containers[container_name][1] > container_info["created"]:
                    # containers stored have been created after the new one
                    # add the agent, but emit a warning
                    self._logger.warning("Container %s has multiple versions:\n"
                                         "\t Currently registered agents have version %s (%i)\n"
                                         "\t New agent %s has version %s (%i)",
                                         container_name,
                                         self._containers[container_name][0], self._containers[container_name][1],
                                         str(agent_addr), container_info["id"], container_info["created"])
                    self._containers[container_name][2].append(agent_addr)
                else:  # self._containers[container_name][1] < container_info["created"]:
                    # containers stored have been created before the new one
                    # add the agent, update the infos, and emit a warning
                    self._logger.warning("Container %s has multiple versions:\n"
                                         "\t Currently registered agents have version %s (%i)\n"
                                         "\t New agent %s has version %s (%i)",
                                         container_name,
                                         self._containers[container_name][0], self._containers[container_name][1],
                                         str(agent_addr), container_info["id"], container_info["created"])
                    self._containers[container_name] = (container_info["id"], container_info["created"],
                                                        self._containers[container_name][2] + [agent_addr])
            else:
                # just add it
                self._logger.debug("Registering container %s for agent %s", container_name, str(agent_addr))
                self._containers[container_name] = (container_info["id"], container_info["created"], [agent_addr])

        # update the queue
        await self.update_queue()

        # update clients
        await self.send_container_update_to_client(self._registered_clients)

    async def handle_agent_job_started(self, agent_addr, message: AgentJobStarted):
        """Handle an AgentJobStarted message. Send the data back to the client"""
        self._logger.debug("Job %s %s started on agent %s", message.job_id[0], message.job_id[1], agent_addr)
        await ZMQUtils.send_with_addr(self._client_socket, message.job_id[0], BackendJobStarted(message.job_id[1]))

    async def handle_agent_job_done(self, agent_addr, message: AgentJobDone):
        """Handle an AgentJobDone message. Send the data back to the client, and start new job if needed"""

        if agent_addr in self._registered_agents:
            self._logger.info("Job %s %s finished on agent %s", message.job_id[0], message.job_id[1], agent_addr)

            # Remove the job from the list of running jobs
            del self._job_running[message.job_id]

            # Send the data back to the client
            await ZMQUtils.send_with_addr(self._client_socket, message.job_id[0], BackendJobDone(message.job_id[1], message.result,
                                                                                                 message.grade, message.problems,
                                                                                                 message.tests, message.custom, message.archive,
                                                                                                 message.stdout, message.stderr))

            # The agent is available now
            self._available_agents.append(agent_addr)
        else:
            self._logger.warning("Job result %s %s from non-registered agent %s", message.job_id[0], message.job_id[1], agent_addr)

        # update the queue
        await self.update_queue()

    async def handle_agent_job_ssh_debug(self, _, message: AgentJobSSHDebug):
        """Handle an AgentJobSSHDebug message. Send the data back to the client"""
        await ZMQUtils.send_with_addr(self._client_socket, message.job_id[0], BackendJobSSHDebug(message.job_id[1], message.host, message.port,
                                                                                                 message.password))

    async def run(self):
        self._logger.info("Backend started")
        self._agent_socket.bind(self._agent_addr)
        self._client_socket.bind(self._client_addr)
        self._loop.call_later(1, self._create_safe_task, self._do_ping())

        try:
            while True:
                socks = await self._poller.poll()
                socks = dict(socks)

                # New message from agent
                if self._agent_socket in socks:
                    agent_addr, message = await ZMQUtils.recv_with_addr(self._agent_socket)
                    await self.handle_agent_message(agent_addr, message)

                # New message from client
                if self._client_socket in socks:
                    client_addr, message = await ZMQUtils.recv_with_addr(self._client_socket)
                    await self.handle_client_message(client_addr, message)

        except asyncio.CancelledError:
            return
        except KeyboardInterrupt:
            return

    async def _handle_pong(self, agent_addr, _: Pong):
        """ Handle a pong """
        self._ping_count[agent_addr] = 0

    async def _do_ping(self):
        """ Ping the agents """

        # the list() call here is needed, as we remove entries from _registered_agents!
        for agent_addr, friendly_name in list(self._registered_agents.items()):
            try:
                ping_count = self._ping_count.get(agent_addr, 0)
                if ping_count > 5:
                    self._logger.warning("Agent %s (%s) does not respond: removing from list.", agent_addr, friendly_name)
                    delete_agent = True
                else:
                    self._ping_count[agent_addr] = ping_count + 1
                    await ZMQUtils.send_with_addr(self._agent_socket, agent_addr, Ping())
                    delete_agent = False
            except Exception:
                # This should not happen, but it's better to check anyway.
                self._logger.exception("Failed to send ping to agent %s (%s). Removing it from list.", agent_addr, friendly_name)
                delete_agent = True

            if delete_agent:
                try:
                    self._available_agents = [agent for agent in self._available_agents if agent != agent_addr]
                    del self._registered_agents[agent_addr]
                    await self._recover_jobs(agent_addr)
                except Exception:
                    self._logger.exception("Failed to delete agent %s (%s)!", agent_addr, friendly_name)

        self._loop.call_later(1, self._create_safe_task, self._do_ping())

    async def _recover_jobs(self, agent_addr):
        """ Recover the jobs sent to a crashed agent """
        for (client_addr, job_id), (agent, job_msg, _) in reversed(list(self._job_running.items())):
            if agent == agent_addr:
                self._waiting_jobs[(client_addr, job_id)] = job_msg
                del self._job_running[(client_addr, job_id)]

        await self.update_queue()

    def _create_safe_task(self, coroutine):
        """ Calls self._loop.create_task with a safe (== with logged exception) coroutine """
        return self._loop.create_task(self._create_safe_task_coro(coroutine))

    async def _create_safe_task_coro(self, coroutine):
        """ Helper for _create_safe_task """
        try:
            await coroutine
        except Exception:
            self._logger.exception("An exception occurred while running a Task.")
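
All three Backend variants funnel their traffic through ZMQUtils.send_with_addr and ZMQUtils.recv_with_addr, which pair each message object with the ROUTER identity frame of its peer. A rough standalone equivalent for zmq.asyncio sockets, assuming pickled payloads (the actual INGInious wire format may differ):

import pickle

async def send_with_addr(socket, addr, obj):
    # ROUTER sockets route on the first frame: the peer's identity.
    await socket.send_multipart([addr, pickle.dumps(obj)])

async def recv_with_addr(socket):
    # Mirror of send_with_addr: returns (identity, message object).
    addr, data = await socket.recv_multipart()
    return addr, pickle.loads(data)
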
Example No. 21
class QWeatherStation:
    """Central broker for the communication done in QWeather"""
    def __init__(self, IP, loop=None, verbose=False, debug=False):
        if loop is None:
            self.loop = asyncio.get_event_loop()
        else:
            self.loop = loop

        IpAndPort = re.search(IPREPATTERN, IP)
        assert IpAndPort is not None, 'IP not understood (tcp://xxx.xxx.xxx.xxx:XXXX or tcp://*:XXXX)'
        self.StationIP = IpAndPort.group(1)
        self.StationSocket = IpAndPort.group(2)
        assert self.StationIP[:6] == 'tcp://', 'IP not understood (tcp://xxx.xxx.xxx.xxx:XXXX or tcp://*:XXXX)'
        assert len(self.StationSocket) == 4, 'Port not understood (tcp://xxx.xxx.xxx.xxx:XXXX or tcp://*:XXXX)'
        formatting = '{:}: %(levelname)s: %(message)s'.format('QWeatherStation')
        if debug:
            logging.basicConfig(format=formatting, level=logging.DEBUG)
        if verbose:
            logging.basicConfig(format=formatting, level=logging.INFO)
        self.servers = {}  # key:value = clientaddress:value, bytes:string
        self.clients = {}  # key:value = clientaddress:value, bytes:string
        self.servermethods = {}
        self.serverjobs = {}
        self.pinged = []
        self.requesttimeoutdict = {}
        self.cnx = Context()
        self.socket = self.cnx.socket(zmq.ROUTER)
        self.socket.bind(self.StationIP + ':' + self.StationSocket)
        self.proxy = ThreadProxy(zmq.XSUB, zmq.XPUB)
        self.proxy.bind_in(self.StationIP + ':' +
                           str(int(self.StationSocket) + PUBLISHSOCKET))
        self.proxy.bind_out(self.StationIP + ':' +
                            str(int(self.StationSocket) + SUBSOCKET))
        self.proxy.start()
        self.poller = Poller()
        self.poller.register(self.socket, zmq.POLLIN)

        logging.info('Ready to run on IP: {:}'.format(self.get_own_ip()))

    def get_own_ip(self):
        import socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # UDP connect() never sends a packet; it only makes the OS pick
            # the outbound interface, whose address getsockname() reports.
            s.connect(('10.255.255.255', 1))
            IP = s.getsockname()[0]
        except OSError:
            IP = '127.0.0.1'
        finally:
            s.close()
        return IP

    async def async_run(self):
        """Ansynchronous run the broker by polling the socket repeatedly"""
        while True:
            try:
                items = await self.poller.poll(1000)
            except KeyboardInterrupt:
                self.close()
                break

            if items:
                msg = await self.receive_message()
                self.handle_message(msg)

    def run(self):
        """Runs the broker, enabling message handling (blocking if called from a scrip)"""
        self.loop.run_until_complete(self.async_run())

    def close(self):
        """Closing function, called at exit"""
        self.poller.unregister(self.socket)
        self.socket.close()

    def handle_message(self, msg):
        """The first step of message handling.\n
        First assert that the second frame is empty\n
        Then process either [S]erver, [C]lient, [P]ing [b]pong, or [#] for executing broker functions
        """
        sender = msg.pop(0)
        if sender in self.clients.keys():
            logging.debug('Received message from {:}:\n{:}'.format(
                self.clients[sender], msg))
        else:
            logging.debug('Received message from ID:{:}:\n{:}'.format(
                int.from_bytes(sender, byteorder='big'), msg))
        empty = msg.pop(0)
        assert empty == b''
        SenderType = msg.pop(0)

        #Server
        if SenderType == b'S':
            command = msg.pop(0)  # 0xF? for server and 0x0? for client
            self.process_server(sender, command, msg)

        #Client
        elif (SenderType == b'C'):
            command = msg.pop(0)  # 0xF? for server and 0x0? for client
            self.process_client(sender, command, msg)

        #Ping
        elif SenderType == b'P':
            if sender in self.clients.keys():
                logging.debug('Received Ping from "{:}"'.format(
                    self.clients[sender]))
            else:
                logging.debug('Received Ping from ID:{:}'.format(
                    int.from_bytes(sender, byteorder='big')))

            # Send an upside-down P (b) to indicate a pong
            self.socket.send_multipart([sender, b'', b'b'])

        #Pong
        elif SenderType == b'b':
            logging.debug('Received Pong from ID:{:}'.format(
                int.from_bytes(sender, byteorder='big')))
            if sender in self.pinged:
                self.pinged.remove(sender)

        #Execute command
        elif SenderType == b'#':
            command = msg.pop(0)
            if command == b'P':  #request broker to ping all servers and remove old ones
                logging.debug('Ping of all servers requested')
                self.loop.create_task(self.ping_connections())
            elif command == b'R':  #request the broker to "restart" by removing all connections
                for atask in self.requesttimeoutdict.values():
                    atask.cancel()
                self.requesttimeoutdict = {}
                self.servers = {}
                self.clients = {}

            if sender in self.clients.keys():
                logging.debug('Received command from "{:}"'.format(
                    self.clients[sender]))
            else:
                logging.debug('Received command from ID:{:}'.format(
                    int.from_bytes(sender, byteorder='big')))

        #SenderType not understood
        else:
            logging.info('Invalid message')

    def process_client(self, sender, command, msg):
        """Second stage of the message handling. Messages go here if they came from a client"""
        if command == CREADY:
            version = msg.pop(0)
            self.handle_client_ready(sender, version, msg)

        elif command == CREQUEST:
            messageid = msg.pop(0)
            servername = msg.pop(0).decode()
            self.handle_client_request(sender, messageid, servername, msg)

        elif command == CDISCONNECT:
            self.handle_client_disconnect(sender)

    def handle_client_ready(self, sender, version, msg):
        """Check the client is using the same version of QWeather, add client to clientlist and send client list of servers and servermethods"""
        if not version == PCLIENT:
            newmsg = [
                sender, b'', CREADY + CFAIL,
                'Mismatch in protocol between client and broker'.encode()
            ]
        else:
            newmsg = [sender, b'', CREADY + CSUCCESS] + [
                pickle.dumps(self.servers)
            ] + [pickle.dumps(self.servermethods)]

            name = msg.pop(0).decode()
            self.clients[sender] = name
            logging.info('Client ready at ID:{:} name:{:}'.format(
                int.from_bytes(sender, byteorder='big'), self.clients[sender]))
        self.send_message(newmsg)

    def handle_client_request(self, sender, messageid, servername, msg):
        """Send a client request to the correct server. Add a timeout callback in case the server response timeouts"""
        try:
            #Find the server address in the server dict based on the name {address:name}
            serveraddr = next(key for key, value in self.servers.items()
                              if value == servername)
            #Create a timeout call which returns an exception if the reply from the server times out.
            timeout = self.loop.call_later(
                B_SERVERRESPONSE_TIMEOUT, self.send_message, [
                    sender, b'', CREQUEST + CFAIL, messageid,
                    servername.encode(),
                    pickle.dumps((Exception('Timeout error')))
                ])
            #Add the timeout to a dictionary so we can find it later (and cancel it before it times out)
            self.requesttimeoutdict[messageid + sender] = timeout

            msg = [serveraddr, b'', CREQUEST, messageid, sender] + msg
            #If the joblist for the requested server is empty, send it to the server, else add it to the serverjoblist for later execution
            if len(self.serverjobs[serveraddr]) == 0:
                self.send_message(msg)
                logging.debug('Client request from "{:}":\n{:}'.format(
                    self.clients[sender], msg))
            else:
                self.serverjobs[serveraddr].append(msg)
        except StopIteration:
            logging.debug('Trying to contact a server that does not exist')
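
    # Timeout pattern used above (summary): loop.call_later() returns an
    # asyncio.TimerHandle, stored under messageid + sender so that
    # handle_server_reply() can cancel() it when the server answers in time.
    # If it is never cancelled, the scheduled send_message() fires and
    # delivers a pickled Timeout exception back to the requesting client.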

    def handle_client_disconnect(self, sender):
        """Remove the client from the client dictionary"""
        logging.debug('Client "{:}" disconnecting'.format(
            self.clients[sender]))
        self.clients.pop(sender)

    def process_server(self, sender, command, msg):
        """Second stage of the message handling. Messages go here if they came from a server"""
        if command == CREADY:
            version = msg.pop(0)
            self.handle_server_ready(sender, version, msg)

        elif command == CREPLY:
            messageid = msg.pop(0)
            servername = self.servers[sender]
            clientaddr = msg.pop(0)
            answ = msg.pop(0)
            self.handle_server_reply(sender, messageid, servername, clientaddr,
                                     answ)

        elif command == SDISCONNECT:
            self.handle_server_disconnect(sender)

    def handle_server_ready(self, sender, version, msg):
        """Check the server is using the same version of QWeather.\n
        Add the server to the serverdict, add the methods to the servermethods dict, add an empty list to the serverjobs dict\n
        Keys for all 3 dicts are the serveraddress/id assigned by ZMQ (the first frame of every message received)"""
        if not version == PSERVER:
            newmsg = [
                sender, b'', CREADY + CFAIL,
                'Mismatch in protocol between server and broker'.encode()
            ]
        else:
            servername = msg.pop(0).decode()
            servermethods = pickle.loads(msg.pop(0))
            self.servers[sender] = servername
            self.servermethods[sender] = servermethods
            self.serverjobs[sender] = []
            newmsg = [sender, b'', CREADY + CSUCCESS]
            logging.info('Server "{:}" ready at: {:}'.format(
                servername, int.from_bytes(sender, byteorder='big')))
        self.send_message(newmsg)

    def handle_server_reply(self, sender, messageid, servername, clientaddr,
                            answer):
        """Forward the server reply to the client that requested it.\n
        Also cancel the timeout callback now that the server has replied in time\n
        If there are more jobs in the serverjob list for this server, send the oldest one to the server"""
        msg = [
            clientaddr, b'', CREQUEST + CSUCCESS, messageid,
            servername.encode(), answer
        ]
        try:
            #Cancel the timeout callback created when the request was sent to the server
            timeouttask = self.requesttimeoutdict.pop(messageid + clientaddr)
            timeouttask.cancel()
            self.send_message(msg)
            logging.debug('Server answer to Client "{:}":\n{:}'.format(
                self.clients[clientaddr], msg))
            #If there are more requests in queue for the server, send the oldest one
            if len(self.serverjobs[sender]) > 0:
                self.send_message(self.serverjobs[sender].pop(0))
        except KeyError:
            print("Trying to send answer to client that does not exist")

    def handle_server_disconnect(self, sender):
        """Remove the server from the server, serverjobs and servermethods dictionaries"""
        logging.debug('Server "{:}" disconnecting'.format(
            self.servers[sender]))
        self.servers.pop(sender)
        self.serverjobs.pop(sender)
        self.servermethods.pop(sender)

    def send_message(self, msg):
        """Send a multi-frame-message over the zmq socket"""
        self.socket.send_multipart(msg)

    async def receive_message(self):
        """Receive a multi-frame message over the zmq socket (async)"""
        msg = await self.socket.recv_multipart()
        return msg

    async def ping_connections(self):
        """Ping all connections, then await 2 seconds and check if the pings responded"""
        self.__ping()
        await asyncio.sleep(2)
        self.__check_ping()

    def __ping(self):
        self.pinged = []
        for address in self.servers.keys():
            self.socket.send_multipart([address, b'', CPING, b'P'])
            self.pinged.append(address)

    def __check_ping(self):
        # Any address still in self.pinged never answered with a pong:
        # remove the corresponding server.
        for aping in self.pinged:
            if aping in self.servers:
                logging.info('Server "{:}" did not answer ping: removing it'.format(self.servers[aping]))
                del self.servers[aping]
        self.pinged = []

    def get_servers(self):
        """Return the server dictionary"""
        return self.servers

    def get_clients(self):
        """Return the client dictionary"""
        return self.clients
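
A hedged sketch of the multipart frame convention the broker above expects. The real constant values (CREADY, PCLIENT, ...) are defined elsewhere in QWeather, so placeholders are used here; only the frame layout is taken from handle_message().

CREADY = b'\x01'   # placeholder value, for illustration only
PCLIENT = b'C1'    # placeholder protocol-version frame, for illustration only

def build_client_ready(name):
    """Frames a DEALER-socket client would send to announce itself."""
    # A ROUTER socket prepends the sender identity on receipt, so the broker
    # sees [sender, b'', b'C', CREADY, PCLIENT, name].
    return [b'', b'C', CREADY, PCLIENT, name.encode()]

def parse_on_broker(frames):
    """Mirrors the start of handle_message(): identity, empty delimiter, type."""
    sender = frames.pop(0)        # identity frame added by the ROUTER socket
    assert frames.pop(0) == b''   # empty delimiter frame
    sender_type = frames.pop(0)   # b'S', b'C', b'P', b'b' or b'#'
    return sender, sender_type, frames

sender, stype, rest = parse_on_broker([b'\x00\x01'] + build_client_ready('myclient'))
assert stype == b'C' and rest == [CREADY, PCLIENT, b'myclient']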
Example #22
class Backend(object):
    """
        Backend. Central point of the architecture, manages communication between clients (frontends) and agents.
        Schedule jobs on agents.
    """

    def __init__(self, context, agent_addr, client_addr):
        self._context = context
        self._loop = asyncio.get_event_loop()
        self._agent_addr = agent_addr
        self._client_addr = client_addr

        self._agent_socket = context.socket(zmq.ROUTER)
        self._client_socket = context.socket(zmq.ROUTER)
        self._logger = logging.getLogger("inginious.backend")

        # Enable support for ipv6
        self._agent_socket.ipv6 = True
        self._client_socket.ipv6 = True

        self._poller = Poller()
        self._poller.register(self._agent_socket, zmq.POLLIN)
        self._poller.register(self._client_socket, zmq.POLLIN)

        # dict of available environments. Keys are first the type of environment (docker, mcq, kata...), then the
        # name of the environment.
        self._environments: Dict[str, Dict[str, EnvironmentInfo]] = {}
        self._registered_clients = set()  # addr of registered clients

        self._registered_agents: Dict[bytes, AgentInfo] = {}  # all registered agents
        self._ping_count = {}  # ping count per addr of agents

        # addr of available agents. May contain the same agent multiple times,
        # as some agents can manage several jobs at once!
        self._available_agents = []

        # These two share the same objects! Tuples should never be recreated.
        self._waiting_jobs_pq = TopicPriorityQueue()  # priority queue for waiting jobs
        self._waiting_jobs: Dict[str, WaitingJob] = {}  # all jobs waiting in queue

        self._job_running: Dict[str, RunningJob] = {}  # all running jobs

    async def handle_agent_message(self, agent_addr, message):
        """Dispatch messages received from agents to the right handlers"""
        message_handlers = {
            AgentHello: self.handle_agent_hello,
            AgentJobStarted: self.handle_agent_job_started,
            AgentJobDone: self.handle_agent_job_done,
            AgentJobSSHDebug: self.handle_agent_job_ssh_debug,
            Pong: self._handle_pong
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        create_safe_task(self._loop, self._logger, func(agent_addr, message))

    async def handle_client_message(self, client_addr, message):
        """Dispatch messages received from clients to the right handlers"""

        # Verify that the client is registered
        if message.__class__ != ClientHello and client_addr not in self._registered_clients:
            await ZMQUtils.send_with_addr(self._client_socket, client_addr, Unknown())
            return

        message_handlers = {
            ClientHello: self.handle_client_hello,
            ClientNewJob: self.handle_client_new_job,
            ClientKillJob: self.handle_client_kill_job,
            ClientGetQueue: self.handle_client_get_queue,
            Ping: self.handle_client_ping
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        create_safe_task(self._loop, self._logger, func(client_addr, message))

    async def send_environment_update_to_client(self, client_addrs):
        """ :param client_addrs: list of clients to which we should send the update """
        self._logger.debug("Sending environments updates...")
        available_environments = {type: list(environments.keys()) for type, environments in self._environments.items()}
        msg = BackendUpdateEnvironments(available_environments)
        for client in client_addrs:
            await ZMQUtils.send_with_addr(self._client_socket, client, msg)

    async def handle_client_hello(self, client_addr, _: ClientHello):
        """ Handle an ClientHello message. Send available environments to the client """
        self._logger.info("New client connected %s", client_addr)
        self._registered_clients.add(client_addr)
        await self.send_environment_update_to_client([client_addr])

    async def handle_client_ping(self, client_addr, _: Ping):
        """ Handle an Ping message. Pong the client """
        await ZMQUtils.send_with_addr(self._client_socket, client_addr, Pong())

    async def handle_client_new_job(self, client_addr, message: ClientNewJob):
        """ Handle an ClientNewJob message. Add a job to the queue and triggers an update """

        if message.job_id in self._waiting_jobs or message.job_id in self._job_running:
            self._logger.info("Client %s asked to add a job with id %s to the queue, but it's already inside. "
                              "Duplicate random id, message repeat are possible causes, "
                              "and both should be inprobable at best.", client_addr, message.job_id)
            await ZMQUtils.send_with_addr(self._client_socket, client_addr,
                                          BackendJobDone(message.job_id, ("crash", "Duplicate job id"),
                                                         0.0, {}, {}, {}, "", None, "", ""))
            return

        self._logger.info("Adding a new job %s %s to the queue", client_addr, message.job_id)
        job = WaitingJob(message.priority, time.time(), client_addr, message.job_id, message)
        self._waiting_jobs[message.job_id] = job
        self._waiting_jobs_pq.put((message.environment_type, message.environment), job)

        await self.update_queue()

    async def handle_client_kill_job(self, client_addr, message: ClientKillJob):
        """ Handle an ClientKillJob message. Remove a job from the waiting list or send the kill message to the right agent. """
        # Check if the job is not in the queue
        if message.job_id in self._waiting_jobs:
            # Erase the job reference in the priority queue. Note that if
            # WaitingJob is a namedtuple, _replace returns a *copy*; the entry
            # shared with the priority queue must end up with msg=None so that
            # update_queue() skips it (see the `if not job_msg` check there).
            job = self._waiting_jobs.pop(message.job_id)
            job._replace(msg=None)

            # Do not forget to send a JobDone
            await ZMQUtils.send_with_addr(self._client_socket, client_addr, BackendJobDone(message.job_id, ("killed", "You killed the job"),
                                                                                           0.0, {}, {}, {}, "", None, "", ""))
        # If the job is running, transmit the info to the agent
        elif message.job_id in self._job_running:
            agent_addr = self._job_running[message.job_id].agent_addr
            await ZMQUtils.send_with_addr(self._agent_socket, agent_addr, BackendKillJob(message.job_id))
        else:
            self._logger.warning("Client %s attempted to kill unknown job %s", str(client_addr), str(message.job_id))

    async def handle_client_get_queue(self, client_addr, _: ClientGetQueue):
        """ Handles a ClientGetQueue message. Send back info about the job queue"""
        #jobs_running: a list of tuples in the form
        #(job_id, is_current_client_job, agent_name, info, launcher, started_at, max_time)
        jobs_running = list()

        for job_id, content in self._job_running.items():
            agent_friendly_name = self._registered_agents[content.agent_addr].name
            jobs_running.append((content.msg.job_id, content.client_addr == client_addr, agent_friendly_name,
                                 content.msg.course_id+"/"+content.msg.task_id,
                                 content.msg.launcher, int(content.time_started), self._get_time_limit_estimate(content.msg)))

        #jobs_waiting: a list of tuples in the form
        #(job_id, is_current_client_job, info, launcher, max_time)
        jobs_waiting = list()

        for job in self._waiting_jobs.values():
            if isinstance(job.msg, ClientNewJob):
                jobs_waiting.append((job.job_id, job.client_addr == client_addr, job.msg.course_id+"/"+job.msg.task_id, job.msg.launcher,
                                     self._get_time_limit_estimate(job.msg)))

        await ZMQUtils.send_with_addr(self._client_socket, client_addr, BackendGetQueue(jobs_running, jobs_waiting))

    async def update_queue(self):
        """
        Send waiting jobs to available agents
        """
        available_agents = list(self._available_agents) # do a copy to avoid bad things

        # Loop on available agents to maximize running jobs, and break if priority queue empty
        for agent_addr in available_agents:
            if self._waiting_jobs_pq.empty():
                break  # nothing to do

            try:
                job = None
                while job is None:
                    # keep the object, do not unpack it directly! It's sometimes modified when a job is killed.
                    job = self._waiting_jobs_pq.get(self._registered_agents[agent_addr].environments)
                    priority, insert_time, client_addr, job_id, job_msg = job

                    # Killed job, removing it from the mapping
                    if not job_msg:
                        del self._waiting_jobs[job_id]
                        job = None  # repeat the while loop. we need a job
            except queue.Empty:
                continue  # skip agent, nothing to do!

            # We have found a job, let's remove the agent from the available list
            self._available_agents.remove(agent_addr)

            # Remove the job from the queue
            del self._waiting_jobs[job_id]

            # Send the job to agent
            self._job_running[job_id] = RunningJob(agent_addr, client_addr, job_msg, time.time())
            self._logger.info("Sending job %s %s to agent %s", client_addr, job_id, agent_addr)
            await ZMQUtils.send_with_addr(self._agent_socket, agent_addr, BackendNewJob(job_id, job_msg.course_id, job_msg.task_id,
                                                                                        job_msg.task_problems, job_msg.inputdata,
                                                                                        job_msg.environment_type,
                                                                                        job_msg.environment,
                                                                                        job_msg.environment_parameters,
                                                                                        job_msg.debug))

    async def handle_agent_hello(self, agent_addr, message: AgentHello):
        """
        Handle an AgentHello message. Add agent_addr to the list of available agents
        """
        self._logger.info("Agent %s (%s) said hello", agent_addr, message.friendly_name)

        if agent_addr in self._registered_agents:
            # Delete previous instance of this agent, if any
            await self._delete_agent(agent_addr)

        self._registered_agents[agent_addr] = AgentInfo(message.friendly_name,
                                                        [(etype, env) for etype, envs in
                                                         message.available_environments.items() for env in envs])
        self._available_agents.extend([agent_addr for _ in range(0, message.available_job_slots)])
        self._ping_count[agent_addr] = 0

        # update information about available environments
        for environment_type, environments in message.available_environments.items():
            if environment_type not in self._environments:
                self._environments[environment_type] = {}
            env_dict = self._environments[environment_type]
            for name, environment_info in environments.items():
                if name in env_dict:
                    # check if the id is the same
                    if env_dict[name].last_id == environment_info["id"]:
                        # ok, just add the agent to the list of agents that have the environment
                        self._logger.debug("Registering environment %s/%s for agent %s", environment_type, name, str(agent_addr))
                        env_dict[name].agents.append(agent_addr)
                    elif env_dict[name].created_last > environment_info["created"]:
                        # environments stored have been created after the new one
                        # add the agent, but emit a warning
                        self._logger.warning("Environment %s has multiple version: \n"
                                             "\t Currently registered agents have version %s (%i)\n"
                                             "\t New agent %s has version %s (%i)",
                                             name,
                                             env_dict[name].last_id, env_dict[name].created_last,
                                             str(agent_addr), environment_info["id"], environment_info["created"])
                        env_dict[name].agents.append(agent_addr)
                    else:
                        # environments stored have been created before the new one
                        # add the agent, update the infos, and emit a warning
                        self._logger.warning("Environment %s has multiple version: \n"
                                             "\t Currently registered agents have version %s (%i)\n"
                                             "\t New agent %s has version %s (%i)",
                                             name,
                                             env_dict[name].last_id, env_dict[name].created_last,
                                             str(agent_addr), environment_info["id"], environment_info["created"])
                        env_dict[name] = EnvironmentInfo(environment_info["id"], environment_info["created"],
                                                         env_dict[name].agents + [agent_addr], environment_type)
                else:
                    # just add it
                    self._logger.debug("Registering environment %s/%s for agent %s", environment_type, name, str(agent_addr))
                    env_dict[name] = EnvironmentInfo(environment_info["id"], environment_info["created"], [agent_addr], environment_type)

        # update the queue
        await self.update_queue()

        # update clients
        await self.send_environment_update_to_client(self._registered_clients)

    async def handle_agent_job_started(self, agent_addr, message: AgentJobStarted):
        """Handle an AgentJobStarted message. Send the data back to the client"""
        self._logger.debug("Job %s started on agent %s", message.job_id, agent_addr)
        if message.job_id not in self._job_running:
            self._logger.warning("Agent %s said job %s was running, but it is not in the list of running jobs", agent_addr, message.job_id)
            return

        await ZMQUtils.send_with_addr(self._client_socket, self._job_running[message.job_id].client_addr, BackendJobStarted(message.job_id))

    async def handle_agent_job_done(self, agent_addr, message: AgentJobDone):
        """Handle an AgentJobDone message. Send the data back to the client, and start new job if needed"""

        if agent_addr in self._registered_agents:
            if message.job_id not in self._job_running:
                self._logger.warning("Job result %s from agent %s was not running", message.job_id, agent_addr)
            else:
                self._logger.info("Job %s finished on agent %s", message.job_id, agent_addr)
                # Remove the job from the list of running jobs
                running_job = self._job_running.pop(message.job_id)
                # The agent is available now
                self._available_agents.append(agent_addr)

                await ZMQUtils.send_with_addr(self._client_socket, running_job.client_addr,
                                              BackendJobDone(message.job_id, message.result, message.grade,
                                                             message.problems, message.tests, message.custom,
                                                             message.state, message.archive, message.stdout,
                                                             message.stderr))
        else:
            self._logger.warning("Job result %s from non-registered agent %s", message.job_id, agent_addr)

        # update the queue
        await self.update_queue()

    async def handle_agent_job_ssh_debug(self, agent_addr, message: AgentJobSSHDebug):
        """Handle an AgentJobSSHDebug message. Send the data back to the client"""
        if message.job_id not in self._job_running:
            self._logger.warning("Agent %s sent ssh debug info for job %s, but it is not in the list of running jobs", agent_addr, message.job_id)
            return
        await ZMQUtils.send_with_addr(self._client_socket, self._job_running[message.job_id].client_addr,
                                      BackendJobSSHDebug(message.job_id, message.host, message.port, message.user, message.password))

    async def run(self):
        self._logger.info("Backend started")
        self._agent_socket.bind(self._agent_addr)
        self._client_socket.bind(self._client_addr)
        self._loop.call_later(1, create_safe_task, self._loop, self._logger, self._do_ping())

        try:
            while True:
                socks = await self._poller.poll()
                socks = dict(socks)

                # New message from agent
                if self._agent_socket in socks:
                    agent_addr, message = await ZMQUtils.recv_with_addr(self._agent_socket)
                    await self.handle_agent_message(agent_addr, message)

                # New message from client
                if self._client_socket in socks:
                    client_addr, message = await ZMQUtils.recv_with_addr(self._client_socket)
                    await self.handle_client_message(client_addr, message)

        except (asyncio.CancelledError, KeyboardInterrupt):
            return

    async def _handle_pong(self, agent_addr, _: Pong):
        """ Handle a pong """
        self._ping_count[agent_addr] = 0

    async def _do_ping(self):
        """ Ping the agents """

        # the list() call here is needed, as we remove entries from _registered_agents!
        for agent_addr, agent_data in list(self._registered_agents.items()):
            friendly_name = agent_data.name

            try:
                ping_count = self._ping_count.get(agent_addr, 0)
                if ping_count > 5:
                    self._logger.warning("Agent %s (%s) does not respond: removing from list.", agent_addr, friendly_name)
                    delete_agent = True
                else:
                    self._ping_count[agent_addr] = ping_count + 1
                    await ZMQUtils.send_with_addr(self._agent_socket, agent_addr, Ping())
                    delete_agent = False
            except:
                # This should not happen, but it's better to check anyway.
                self._logger.exception("Failed to send ping to agent %s (%s). Removing it from list.", agent_addr, friendly_name)
                delete_agent = True

            if delete_agent:
                try:
                    await self._delete_agent(agent_addr)
                except:
                    self._logger.exception("Failed to delete agent %s (%s)!", agent_addr, friendly_name)

        self._loop.call_later(1, create_safe_task, self._loop, self._logger, self._do_ping())

    async def _delete_agent(self, agent_addr):
        """ Deletes an agent """
        self._available_agents = [agent for agent in self._available_agents if agent != agent_addr]
        del self._registered_agents[agent_addr]
        await self._recover_jobs()

    async def _recover_jobs(self):
        """ Recover the jobs sent to a crashed agent """
        for job_id, running_job in reversed(list(self._job_running.items())):
            if running_job.agent_addr not in self._registered_agents:
                await ZMQUtils.send_with_addr(self._client_socket, running_job.client_addr,
                                              BackendJobDone(job_id, ("crash", "Agent restarted"),
                                                             0.0, {}, {}, {}, "", None, None, None))
                del self._job_running[job_id]

        await self.update_queue()

    def _get_time_limit_estimate(self, job_info: ClientNewJob):
        """
            Returns an estimate of the time taken by a given job, if available in the environment_parameters.
            For this to work, ["limits"]["time"] must be a parameter of the environment.
        """
        try:
            return int(job_info.environment_parameters["limits"]["time"])
        except (KeyError, TypeError, ValueError):
            return -1 # unknown
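
A standalone sketch of the ping-count liveness bookkeeping used by _do_ping()/_handle_pong() above (a minimal reimplementation for illustration, not the original API): each tick increments a per-peer counter before a Ping is sent, a Pong resets it to 0, and a peer whose counter ever exceeds 5 is considered dead.

class LivenessTracker:
    """Counts unanswered pings per peer; a pong resets the counter."""
    MAX_MISSED = 5

    def __init__(self):
        self._ping_count = {}

    def on_pong(self, addr):
        self._ping_count[addr] = 0

    def tick(self, addr):
        """Called once per ping interval. Returns True if the peer is dead."""
        count = self._ping_count.get(addr, 0)
        if count > self.MAX_MISSED:
            return True  # more than MAX_MISSED pings went unanswered
        self._ping_count[addr] = count + 1
        return False

tracker = LivenessTracker()
for _ in range(6):
    assert tracker.tick(b'agent-1') is False  # counter climbs to 6
assert tracker.tick(b'agent-1') is True       # now past the threshold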
Example #23
class DockerAgent(object):
    def __init__(self, context, backend_addr, friendly_name, nb_sub_agents, task_directory, ssh_host=None, ssh_ports=None, tmp_dir="./agent_tmp"):
        """
        :param context: ZeroMQ context for this process
        :param backend_addr: address of the backend (for example, "tcp://127.0.0.1:2222")
        :param friendly_name: a string containing a friendly name to identify agent
        :param nb_sub_agents: nb of slots available for this agent
        :param task_directory: path to the task directory
        :param ssh_host: hostname/ip/... to which external client should connect to access to an ssh remote debug session
        :param ssh_ports: iterable containing ports to which the docker instance can assign ssh servers (for remote debugging)
        :param tmp_dir: temp dir that is used by the agent to start new containers
        """
        self._logger = logging.getLogger("inginious.agent.docker")

        self._logger.info("Starting agent")

        self._backend_addr = backend_addr
        self._context = context
        self._loop = asyncio.get_event_loop()
        self._friendly_name = friendly_name
        self._nb_sub_agents = nb_sub_agents
        self._max_memory_per_slot = int(psutil.virtual_memory().total/nb_sub_agents/1024/1024)

        # data about running containers
        self._containers_running = {}
        self._student_containers_running = {}
        self._containers_ending = {}
        self._student_containers_ending = {}
        self._container_for_job = {}
        self._student_containers_for_job = {}

        self.tmp_dir = tmp_dir
        self.task_directory = task_directory

        # Delete tmp_dir and recreate it
        try:
            rmtree(tmp_dir)
        except OSError:
            pass

        try:
            os.mkdir(tmp_dir)
        except OSError:
            pass

        # Docker
        self._docker = DockerInterface()

        # Auto discover containers
        self._logger.info("Discovering containers")
        self._containers = self._docker.get_containers()

        # SSH remote debug
        self.ssh_host = ssh_host
        if self.ssh_host is None and len(self._containers) != 0:
            self._logger.info("Guessing external host IP")
            self.ssh_host = self._docker.get_host_ip(next(iter(self._containers.values()))["id"])
        if self.ssh_host is None:
            self._logger.warning("Cannot find external host IP. Please indicate it in the configuration. Remote SSH debug has been deactivated.")
            ssh_ports = None
        else:
            self._logger.info("External address for SSH remote debug is %s", self.ssh_host)
        self.ssh_ports = set(ssh_ports) if ssh_ports is not None else set()
        self.running_ssh_debug = {}  # container_id : ssh_port

        # Sockets
        self._backend_socket = self._context.socket(zmq.DEALER)
        self._backend_socket.ipv6 = True
        self._docker_events_publisher = self._context.socket(zmq.PUB)
        self._docker_events_subscriber = self._context.socket(zmq.SUB)

        # Watchers
        self._killer_watcher_push = PipelinePush(context, "agentpush")
        self._killer_watcher_pull = PipelinePull(context, "agentpull")
        self._timeout_watcher = TimeoutWatcher(context, self._docker)

        self._containers_killed = dict()

        # Poller
        self._poller = Poller()
        self._poller.register(self._backend_socket, zmq.POLLIN)
        self._poller.register(self._docker_events_subscriber, zmq.POLLIN)
        self._poller.register(self._killer_watcher_pull.get_pull_socket(), zmq.POLLIN)

    async def init_watch_docker_events(self):
        """ Init everything needed to watch docker events """
        url = "inproc://docker_events"
        self._docker_events_publisher.bind(url)
        self._docker_events_subscriber.connect(url)
        self._docker_events_subscriber.setsockopt(zmq.SUBSCRIBE, b'')
        self._loop.create_task(self._watch_docker_events())
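
    # Note: the PUB/SUB pair above is wired over an inproc:// transport, so
    # events produced by the _watch_docker_events() task can be polled in the
    # main loop alongside the backend socket, without any extra thread.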

    async def init_watcher_pipe(self):
        """ Init the killer pipeline """
        # Start elements in the pipeline
        self._loop.create_task(self._timeout_watcher.run_pipeline())

        # Link the pipeline
        self._timeout_watcher.link(self._killer_watcher_push)
        # [ if one day we have more watchers, add them here in the pipeline ]
        self._killer_watcher_pull.link(self._timeout_watcher)

    async def _watch_docker_events(self):
        """ Get raw docker events and convert them to more readable objects, and then give them to self._docker_events_subscriber """
        try:
            source = AsyncIteratorWrapper(self._docker.event_stream(filters={"event": ["die", "oom"]}))
            async for i in source:
                if i["Type"] == "container" and i["status"] == "die":
                    container_id = i["id"]
                    try:
                        retval = int(i["Actor"]["Attributes"]["exitCode"])
                    except:
                        self._logger.exception("Cannot parse exitCode for container %s", container_id)
                        retval = -1
                    await ZMQUtils.send(self._docker_events_publisher, EventContainerDied(container_id, retval))
                elif i["Type"] == "container" and i["status"] == "oom":
                    await ZMQUtils.send(self._docker_events_publisher, EventContainerOOM(i["id"]))
                else:
                    raise TypeError(str(i))
        except:
            self._logger.exception("Exception in _watch_docker_events")

    async def handle_backend_message(self, message):
        """Dispatch messages received from clients to the right handlers"""
        message_handlers = {
            BackendNewJob: self.handle_new_job,
            BackendKillJob: self.handle_kill_job,
            Ping: self.handle_ping
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._loop.create_task(func(message))

    async def handle_watcher_pipe_message(self, message):
        """Dispatch messages received from the watcher pipe to the right handlers"""
        message_handlers = {
            KWPKilledStatus: self.handle_kwp_killed_status,
            KWPRegisterContainer: self.handle_kwp_register_container
        }
        try:
            func = message_handlers[message.__class__]
        except KeyError:
            raise TypeError("Unknown message type %s" % message.__class__)
        self._loop.create_task(func(message))

    async def handle_kwp_killed_status(self, message: KWPKilledStatus):
        """
        Handles the messages returned by the "killer pipeline", that indicates if a particular container was killed
        by an element of the pipeline. Gives the message to the right handler.
        """
        if message.container_id in self._containers_ending:
            self._loop.create_task(self.handle_job_closing_p2(message))
        elif message.container_id in self._student_containers_ending:
            self._loop.create_task(self.handle_student_job_closing_p2(message))

    async def handle_kwp_register_container(self, message: KWPRegisterContainer):
        # ignore
        pass

    async def handle_ping(self, _: Ping):
        """ Handle an Ping message. Pong the backend """
        await ZMQUtils.send(self._backend_socket, Pong())

    async def handle_new_job(self, message: BackendNewJob):
        """
        Handles a new job: starts the grading container
        """
        try:
            self._logger.info("Received request for jobid %s", message.job_id)

            course_id = message.course_id
            task_id = message.task_id

            debug = message.debug
            environment_name = message.environment
            enable_network = message.enable_network
            time_limit = message.time_limit
            hard_time_limit = message.hard_time_limit or time_limit * 3
            mem_limit = message.mem_limit

            if not os.path.exists(os.path.join(self.task_directory, course_id, task_id)):
                self._logger.warning("Task %s/%s unavailable on this agent", course_id, task_id)
                await self.send_job_result(message.job_id, "crash",
                                           'Task unavailable on agent. Please retry later, the agents should synchronize soon. If the error '
                                           'persists, please contact your course administrator.')
                return

            # Check for realistic memory limit value
            if mem_limit < 20:
                mem_limit = 20
            elif mem_limit > self._max_memory_per_slot:
                self._logger.warning("Task %s/%s ask for too much memory (%dMB)! Available: %dMB", course_id, task_id, mem_limit, self._max_memory_per_slot)
                await self.send_job_result(message.job_id, "crash", 'Not enough memory on agent (available: %dMB). Please contact your course administrator.' % self._max_memory_per_slot)
                return

            if environment_name not in self._containers:
                self._logger.warning("Task %s/%s ask for an unknown environment %s (not in aliases)", course_id, task_id, environment_name)
                await self.send_job_result(message.job_id, "crash", 'Unknown container. Please contact your course administrator.')
                return

            environment = self._containers[environment_name]["id"]

            # Handle ssh debugging
            ssh_port = None
            if debug == "ssh":
                # allow 30 minutes of real time.
                time_limit = 30 * 60
                hard_time_limit = 30 * 60

                # select a port
                if len(self.ssh_ports) == 0:
                    self._logger.warning("User asked for an ssh debug but no ports are available")
                    await self.send_job_result(message.job_id, "crash", 'No ports are available for SSH debug right now. Please retry later.')
                    return
                ssh_port = self.ssh_ports.pop()

            # Create directories for storing all the data for the job
            try:
                container_path = tempfile.mkdtemp(dir=self.tmp_dir)
            except Exception as e:
                self._logger.error("Cannot make container temp directory! %s", str(e), exc_info=True)
                await self.send_job_result(message.job_id, "crash", 'Cannot make container temp directory.')
                if ssh_port is not None:
                    self.ssh_ports.add(ssh_port)
                return

            task_path = os.path.join(container_path, 'task')  # tmp_dir/id/task/
            sockets_path = os.path.join(container_path, 'sockets')  # tmp_dir/id/socket/
            student_path = os.path.join(task_path, 'student')  # tmp_dir/id/task/student/
            systemfiles_path = os.path.join(task_path, 'systemfiles')  # tmp_dir/id/task/systemfiles/

            # Create the needed directories
            os.mkdir(sockets_path)
            os.chmod(container_path, 0o777)
            os.chmod(sockets_path, 0o777)

            # TODO: avoid copy
            await self._loop.run_in_executor(None, lambda: copytree(os.path.join(self.task_directory, course_id, task_id), task_path))
            os.chmod(task_path, 0o777)

            if not os.path.exists(student_path):
                os.mkdir(student_path)
                os.chmod(student_path, 0o777)

            # Run the container
            try:
                container_id = await self._loop.run_in_executor(None, lambda: self._docker.create_container(environment, enable_network, mem_limit,
                                                                                                            task_path, sockets_path, ssh_port))
            except Exception as e:
                self._logger.warning("Cannot create container! %s", str(e), exc_info=True)
                await self.send_job_result(message.job_id, "crash", 'Cannot create container.')
                await self._loop.run_in_executor(None, lambda: rmtree(container_path))
                if ssh_port is not None:
                    self.ssh_ports.add(ssh_port)
                return

            # Store info
            future_results = asyncio.Future()
            self._containers_running[container_id] = message, container_path, future_results
            self._container_for_job[message.job_id] = container_id
            self._student_containers_for_job[message.job_id] = set()
            if ssh_port is not None:
                self.running_ssh_debug[container_id] = ssh_port

            try:
                # Start the container
                await self._loop.run_in_executor(None, lambda: self._docker.start_container(container_id))
            except Exception as e:
                self._logger.warning("Cannot start container! %s", str(e), exc_info=True)
                await self.send_job_result(message.job_id, "crash", 'Cannot start container')
                await self._loop.run_in_executor(None, lambda: rmtree(container_path))
                if ssh_port is not None:
                    self.ssh_ports.add(ssh_port)
                return

            # Talk to the container
            self._loop.create_task(self.handle_running_container(message.job_id, container_id, message.inputdata, debug, ssh_port,
                                                                 environment_name, mem_limit, time_limit, hard_time_limit,
                                                                 sockets_path, student_path, systemfiles_path,
                                                                 future_results))

            # Ask the "cgroup" thread to verify the timeout/memory limit
            await ZMQUtils.send(self._killer_watcher_push.get_push_socket(), KWPRegisterContainer(container_id, mem_limit, time_limit, hard_time_limit))

            # Tell the backend/client the job has started
            await ZMQUtils.send(self._backend_socket, AgentJobStarted(message.job_id))
        except:
            self._logger.exception("Exception in handle_new_job")

    async def create_student_container(self, job_id, parent_container_id, sockets_path, student_path, systemfiles_path, socket_id, environment_name,
                                       memory_limit, time_limit, hard_time_limit, share_network, write_stream):
        """
        Creates a new student container.
        :param write_stream: stream on which to write the return value of the container (with a correctly formatted msgpack message)
        """
        try:
            self._logger.debug("Starting new student container... %s %s %s %s", environment_name, memory_limit, time_limit, hard_time_limit)

            if environment_name not in self._containers:
                self._logger.warning("Student container asked for an unknown environment %s (not in aliases)", environment_name)
                await self._write_to_container_stdin(write_stream, {"type": "run_student_retval", "retval": 254, "socket_id": socket_id})
                return

            environment = self._containers[environment_name]["id"]

            try:
                socket_path = os.path.join(sockets_path, str(socket_id) + ".sock")
                container_id = await self._loop.run_in_executor(None,
                                                                lambda: self._docker.create_container_student(parent_container_id, environment,
                                                                                                              share_network, memory_limit,
                                                                                                              student_path, socket_path,
                                                                                                              systemfiles_path))
            except:
                self._logger.exception("Cannot create student container!")
                await self._write_to_container_stdin(write_stream, {"type": "run_student_retval", "retval": 254, "socket_id": socket_id})
                return

            self._student_containers_for_job[job_id].add(container_id)
            self._student_containers_running[container_id] = job_id, parent_container_id, socket_id, write_stream

            # send to the container that the sibling has started
            await self._write_to_container_stdin(write_stream, {"type": "run_student_started", "socket_id": socket_id})

            try:
                await self._loop.run_in_executor(None, lambda: self._docker.start_container(container_id))
            except:
                self._logger.exception("Cannot start student container!")
                await self._write_to_container_stdin(write_stream, {"type": "run_student_retval", "retval": 254, "socket_id": socket_id})
                return

            # Ask the "cgroup" thread to verify the timeout/memory limit
            await ZMQUtils.send(self._killer_watcher_push.get_push_socket(),
                                KWPRegisterContainer(container_id, memory_limit, time_limit, hard_time_limit))
        except:
            self._logger.exception("Exception in create_student_container")

    async def _write_to_container_stdin(self, write_stream, message):
        """
        Send a message to the stdin of a container, with the right data
        :param write_stream: asyncio write stream to the stdin of the container
        :param message: dict to be msgpacked and sent
        """
        msg = msgpack.dumps(message, encoding="utf8", use_bin_type=True)
        self._logger.debug("Sending %i bytes to container", len(msg))
        write_stream.write(struct.pack('I', len(msg)))
        write_stream.write(msg)
        await write_stream.drain()
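
    # Wire-format note (inferred from the code, not an official spec): each
    # message written to the container's stdin is a 4-byte native-endian
    # length (struct format 'I') followed by the msgpack payload. Output read
    # back from docker's attach endpoint is additionally wrapped in docker's
    # own 8-byte stream header ('>BxxxL': stream type, padding, big-endian
    # length), inside which the same length-prefixed msgpack framing is used.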

    async def handle_running_container(self, job_id, container_id,
                                       inputdata, debug, ssh_port,
                                       orig_env, orig_memory_limit, orig_time_limit, orig_hard_time_limit,
                                       sockets_path, student_path, systemfiles_path,
                                       future_results):
        """ Talk with a container. Sends the initial input. Allows to start student containers """
        sock = await self._loop.run_in_executor(None, lambda: self._docker.attach_to_container(container_id))
        try:
            read_stream, write_stream = await asyncio.open_connection(sock=sock.get_socket())
        except:
            self._logger.exception("Exception occurred while creating read/write stream to container")
            return None

        # Send hello msg
        await self._write_to_container_stdin(write_stream, {"type": "start", "input": inputdata, "debug": debug})

        buffer = bytearray()
        try:
            while not read_stream.at_eof():
                msg_header = await read_stream.readexactly(8)
                outtype, length = struct.unpack_from('>BxxxL', msg_header)  # format imposed by docker in the attach endpoint
                if length != 0:
                    content = await read_stream.readexactly(length)
                    if outtype == 1:  # stdout
                        buffer += content

                    if outtype == 2:  # stderr
                        self._logger.debug("Received stderr from containers:\n%s", content)

                    # The first 4 bytes are the length of the message. If we have a complete message...
                    while len(buffer) > 4 and len(buffer) >= 4+struct.unpack('I',buffer[0:4])[0]:
                        msg_encoded = buffer[4:4 + struct.unpack('I', buffer[0:4])[0]]  # ... get it
                        buffer = buffer[4 + struct.unpack('I', buffer[0:4])[0]:]  # ... withdraw it from the buffer
                        try:
                            msg = msgpack.unpackb(msg_encoded, encoding="utf8", use_list=False)
                            self._logger.debug("Received msg %s from container %s", msg["type"], container_id)
                            if msg["type"] == "run_student":
                                # start a new student container
                                environment = msg["environment"] or orig_env
                                memory_limit = min(msg["memory_limit"] or orig_memory_limit, orig_memory_limit)
                                time_limit = min(msg["time_limit"] or orig_time_limit, orig_time_limit)
                                hard_time_limit = min(msg["hard_time_limit"] or orig_hard_time_limit, orig_hard_time_limit)
                                share_network = msg["share_network"]
                                socket_id = msg["socket_id"]
                                assert "/" not in socket_id  # ensure task creator do not try to break the agent :-(
                                self._loop.create_task(self.create_student_container(job_id, container_id, sockets_path, student_path,
                                                                                     systemfiles_path, socket_id, environment, memory_limit,
                                                                                     time_limit, hard_time_limit, share_network, write_stream))
                            elif msg["type"] == "ssh_key":
                                # send the data to the backend (and client)
                                self._logger.info("%s %s", self.running_ssh_debug[container_id], str(msg))
                                await ZMQUtils.send(self._backend_socket, AgentJobSSHDebug(job_id, self.ssh_host, ssh_port, msg["ssh_key"]))
                            elif msg["type"] == "result":
                                # last message containing the results of the container
                                future_results.set_result(msg["result"])
                                write_stream.close()
                                sock.close_socket()
                                return  # this is the last message
                        except:
                            self._logger.exception("Received incorrect message from container %s (job id %s)", container_id, job_id)
                            future_results.set_result(None)
                            write_stream.close()
                            sock.close_socket()
                            return
        except asyncio.IncompleteReadError:
            self._logger.debug("Container output ended with an IncompleteReadError; It was probably killed.")
        except Exception:
            self._logger.exception("Exception while reading container %s output", container_id)

        # EOF without result :-(
        self._logger.warning("Container %s has not given any result", container_id)
        write_stream.close()
        sock.close_socket()
        future_results.set_result(None)

    async def handle_student_job_closing_p1(self, container_id, retval):
        """ First part of the student container ending handler. Ask the killer pipeline if they killed the container that recently died. Do some cleaning. """
        try:
            self._logger.debug("Closing student (p1) for %s", container_id)
            try:
                job_id, parent_container_id, socket_id, write_stream = self._student_containers_running[container_id]
                del self._student_containers_running[container_id]
            except KeyError:
                self._logger.warning("Student container %s that finished (p1) was not launched by this agent", str(container_id), exc_info=True)
                return

            # Delete remaining student containers
            if job_id in self._student_containers_for_job:  # if it does not exist, the parent container has already closed
                self._student_containers_for_job[job_id].remove(container_id)
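            # Keep the state around for p2, which runs once the killer pipeline has answered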
            self._student_containers_ending[container_id] = (job_id, parent_container_id, socket_id, write_stream, retval)

            await ZMQUtils.send(self._killer_watcher_push.get_push_socket(),
                                KWPKilledStatus(container_id, self._containers_killed.get(container_id)))
        except Exception:
            self._logger.exception("Exception in handle_student_job_closing_p1")

    async def handle_student_job_closing_p2(self, killed_msg: KWPKilledStatus):
        """ Second part of the student container ending handler. Gather results and send them to the grading container associated with the job. """
        try:
            container_id = killed_msg.container_id
            self._logger.debug("Closing student (p2) for %s", container_id)
            try:
                _, parent_container_id, socket_id, write_stream, retval = self._student_containers_ending[container_id]
                del self._student_containers_ending[container_id]
            except KeyError:
                self._logger.warning("Student container %s that finished (p2) was not launched by this agent", str(container_id))
                return

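            # Conventional return codes: 253 signals a timeout, 252 a memory overflow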
            if killed_msg.killed_result == "timeout":
                retval = 253
            elif killed_msg.killed_result == "overflow":
                retval = 252

            try:
                await self._write_to_container_stdin(write_stream, {"type": "run_student_retval", "retval": retval, "socket_id": socket_id})
            except Exception:
                pass  # the parent container has already closed

            # Do not forget to remove the container
            try:
                await self._loop.run_in_executor(None, lambda: self._docker.remove_container(container_id))
            except Exception:
                pass  # ignore removal errors
        except:
            self._logger.exception("Exception in handle_student_job_closing_p1")

    async def handle_job_closing_p1(self, container_id, retval):
        """ First part of the end job handler. Ask the killer pipeline if they killed the container that recently died. Do some cleaning. """
        try:
            self._logger.debug("Closing (p1) for %s", container_id)
            try:
                message, container_path, future_results = self._containers_running[container_id]
                del self._containers_running[container_id]
            except KeyError:
                self._logger.warning("Container %s that finished (p1) was not launched by this agent", str(container_id), exc_info=True)
                return

            self._containers_ending[container_id] = (message, container_path, retval, future_results)

            # Close sub containers
            for student_container_id_loop in self._student_containers_for_job[message.job_id]:
                # default-argument trick to bind the current value of student_container_id_loop into the closure
                def close_and_delete(student_container_id=student_container_id_loop):
                    try:
                        self._docker.kill_container(student_container_id)
                        self._docker.remove_container(student_container_id)
                    except Exception:
                        pass  # ignore
                asyncio.ensure_future(self._loop.run_in_executor(None, close_and_delete))
            del self._student_containers_for_job[message.job_id]

            # Allow other containers to reuse the SSH port that this container was using
            if container_id in self.running_ssh_debug:
                self.ssh_ports.add(self.running_ssh_debug[container_id])
                del self.running_ssh_debug[container_id]

            await ZMQUtils.send(self._killer_watcher_push.get_push_socket(),
                                KWPKilledStatus(container_id, self._containers_killed.get(container_id)))
        except Exception:
            self._logger.exception("Exception in handle_job_closing_p1")

    async def handle_job_closing_p2(self, killed_msg: KWPKilledStatus):
        """ Second part of the end job handler. Gather results and send them to the backend. """
        try:
            container_id = killed_msg.container_id
            self._logger.debug("Closing (p2) for %s", container_id)
            try:
                message, container_path, retval, future_results = self._containers_ending[container_id]
                del self._containers_ending[container_id]
            except KeyError:
                self._logger.warning("Container %s that finished (p2) was not launched by this agent", str(container_id))
                return

            stdout = ""
            stderr = ""
            result = "crash" if retval == -1 else None
            error_msg = None
            grade = None
            problems = {}
            custom = {}
            tests = {}
            archive = None

            if killed_msg.killed_result is not None:
                result = killed_msg.killed_result

            # If everything went well, go on and retrieve the status from the container
            if result is None:
                # Get logs back
                try:
                    return_value = await future_results

                    # Accepted types for return dict
                    accepted_types = {"stdout": str, "stderr": str, "result": str, "text": str, "grade": float,
                                      "problems": dict, "custom": dict, "tests": dict, "archive": str}

                    # Check dict content
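                    # (top-level values must match the expected type; dict values may only
                    # nest one level deep and their keys must be valid identifiers)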
                    for key, item in return_value.items():
                        if not isinstance(item, accepted_types[key]):
                            raise Exception("Feedback file is badly formatted.")
                        elif accepted_types[key] == dict:
                            for sub_key, sub_item in item.items():
                                if not id_checker(sub_key) or isinstance(sub_item, dict):
                                    raise Exception("Feedback file is badly formatted.")

                    # Set output fields
                    stdout = return_value.get("stdout", "")
                    stderr = return_value.get("stderr", "")
                    result = return_value.get("result", "error")
                    error_msg = return_value.get("text", "")
                    grade = return_value.get("grade", None)
                    problems = return_value.get("problems", {})
                    custom = return_value.get("custom", {})
                    tests = return_value.get("tests", {})
                    archive = return_value.get("archive", None)
                    if archive is not None:
                        archive = base64.b64decode(archive)
                except Exception as e:
                    self._logger.exception("Cannot get back output of container %s! (%s)", container_id, str(e))
                    result = "crash"
                    error_msg = 'The grader did not return a readable output: {}'.format(str(e))

            # Default values
            if error_msg is None:
                error_msg = ""
            if grade is None:
                if result == "success":
                    grade = 100.0
                else:
                    grade = 0.0

            # Remove container
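            # (the returned executor future is not awaited, so removal failures are ignored)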
            self._loop.run_in_executor(None, lambda: self._docker.remove_container(container_id))

            # Delete folders
            try:
                await self._loop.run_in_executor(None, lambda: rmtree(container_path))
            except PermissionError:
                self._logger.debug("Cannot remove old container path!")
                # todo: run a docker container to force removal
            
            # Return!
            await self.send_job_result(message.job_id, result, error_msg, grade, problems, tests, custom, archive, stdout, stderr)

            # Do not forget to remove data from internal state
            del self._container_for_job[message.job_id]
            if container_id in self._containers_killed:
                del self._containers_killed[container_id]
        except Exception:
            self._logger.exception("Exception in handle_job_closing_p2")

    async def handle_kill_job(self, message: BackendKillJob):
        """ Handles `kill` messages. Kill things. """
        try:
            if message.job_id in self._container_for_job:
                self._containers_killed[self._container_for_job[message.job_id]] = "killed"
                await self._loop.run_in_executor(None, self._docker.kill_container, self._container_for_job[message.job_id])
            else:
                self._logger.warning("Cannot kill container for job %s because it is not running", str(message.job_id))
        except Exception:
            self._logger.exception("Exception in handle_kill_job")

    async def handle_docker_event(self, message):
        """ Handles events from Docker, notably `die` and `oom` """
        try:
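            # `die` events start the first phase of the closing handlers; `oom` events
            # mark the container as killed by overflow before actually killing it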
            if isinstance(message, EventContainerDied):
                if message.container_id in self._containers_running:
                    self._loop.create_task(self.handle_job_closing_p1(message.container_id, message.retval))
                elif message.container_id in self._student_containers_running:
                    self._loop.create_task(self.handle_student_job_closing_p1(message.container_id, message.retval))
            elif isinstance(message, EventContainerOOM):
                if message.container_id in self._containers_running or message.container_id in self._student_containers_running:
                    self._logger.info("Container %s did OOM, killing it", message.container_id)
                    self._containers_killed[message.container_id] = "overflow"
                    await self._loop.run_in_executor(None, lambda: self._docker.kill_container(message.container_id))
        except Exception:
            self._logger.exception("Exception in handle_docker_event")

    async def send_job_result(self, job_id: BackendJobId, result: str, text: str = "", grade: Optional[float] = None,
                              problems: Optional[Dict[str, SPResult]] = None, tests: Optional[Dict[str, Any]] = None,
                              custom: Optional[Dict[str, Any]] = None, archive: Optional[bytes] = None,
                              stdout: Optional[str] = None, stderr: Optional[str] = None):
        """ Send the result of a job back to the backend """
        if grade is None:
            if result == "success":
                grade = 100.0
            else:
                grade = 0.0
        if problems is None:
            problems = {}
        if custom is None:
            custom = {}
        if tests is None:
            tests = {}

        await ZMQUtils.send(self._backend_socket, AgentJobDone(job_id, (result, text), round(grade, 2), problems, tests, custom, archive, stdout, stderr))

    async def run_dealer(self):
        """ Run the agent """
        self._logger.info("Agent started")
        self._backend_socket.connect(self._backend_addr)

        # Init Docker events watcher
        await self.init_watch_docker_events()

        # Init watcher pipe
        await self.init_watcher_pipe()

        # Tell the backend we are up and have `nb_sub_agents` threads available
        self._logger.info("Saying hello to the backend")
        await ZMQUtils.send(self._backend_socket, AgentHello(self._friendly_name, self._nb_sub_agents, self._containers))

        # And then run the agent
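        # Main loop: a single poller multiplexes the backend socket, the Docker
        # events subscriber and the killer-watcher pull socket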
        try:
            while True:
                socks = await self._poller.poll()
                socks = dict(socks)

                # New message from backend
                if self._backend_socket in socks:
                    message = await ZMQUtils.recv(self._backend_socket)
                    await self.handle_backend_message(message)

                # New docker event
                if self._docker_events_subscriber in socks:
                    message = await ZMQUtils.recv(self._docker_events_subscriber)
                    await self.handle_docker_event(message)

                # End of watcher pipe
                if self._killer_watcher_pull.get_pull_socket() in socks:
                    message = await ZMQUtils.recv(self._killer_watcher_pull.get_pull_socket())
                    await self.handle_watcher_pipe_message(message)

        except asyncio.CancelledError:
            return
        except KeyboardInterrupt:
            return