Beispiel #1
0
    def __init__(self, ip, tid, anna_client):
        '''
        Set up messaging state for one executor thread and start the
        background thread that fills the inbox.

        ip: stored as `executor_ip`
        tid: stored as `executor_tid`
        anna_client: stored as `client`
        '''
        # Identity of the owning executor thread and its KVS client handle.
        self.executor_ip = ip
        self.executor_tid = tid
        self.client = anna_client

        # ZeroMQ context plus a cache of PUSH sockets used for sending.
        self.ctx = zmq.Context()
        self.send_socket_cache = SocketCache(self.ctx, zmq.PUSH)

        # This node's inbox: a threadsafe, currently unbounded queue whose
        # items are (sender string, message bytestring).
        self.recv_inbox = queue.Queue()

        # Background thread that receives messages into the inbox. The
        # `do_run` attribute is set on the thread object before it starts.
        listener = threading.Thread(target=self._recv_inbox_listener)
        listener.do_run = True
        self.recv_inbox_thread = listener
        listener.start()
Beispiel #2
0
    def __init__(self, elb_addr, ip, local=False, offset=0):
        '''
        The AnnaTcpClient allows you to interact with a local copy of Anna
        or with a remote cluster running on AWS.

        elb_addr: Either 127.0.0.1 (local mode) or the address of an AWS ELB
        for the routing tier
        ip: The IP address of the machine being used -- if None is provided,
        one is inferred by using socket.gethostbyname(); WARNING: this does not
        always work
        local: If True, the routing tier is contacted on port 6450 only;
        otherwise ports 6450 through 6453 are used
        offset: A port numbering offset, which is only needed if multiple
        clients are running on the same machine
        '''
        self.elb_addr = elb_addr
        self.elb_ports = [6450] if local else list(range(6450, 6454))

        # If the IP is not provided, we attempt to infer it.
        client_ip = ip if ip else socket.gethostbyname(socket.gethostname())
        self.ut = UserThread(client_ip, offset)

        ctx = zmq.Context(1)
        self.context = ctx

        self.address_cache = {}
        self.pusher_cache = SocketCache(ctx, zmq.PUSH)

        # PULL socket on which responses to data requests arrive.
        response_sock = ctx.socket(zmq.PULL)
        response_sock.bind(self.ut.get_request_pull_bind_addr())
        self.response_puller = response_sock

        # PULL socket on which key-address (routing) responses arrive.
        address_sock = ctx.socket(zmq.PULL)
        address_sock.bind(self.ut.get_key_address_bind_addr())
        self.key_address_puller = address_sock

        self.rid = 0

        self.cache = {}
Beispiel #3
0
class AnnaTcpClient(BaseAnnaClient):
    '''
    TCP client for Anna that keeps separate worker-address caches for GET
    and PUT requests.
    '''

    # Address that is filtered out of routing responses before they are
    # cached. TODO: change this to the address of the master node of your
    # deployment (override on a subclass or on the instance).
    MONITOR_ADDRESS = 'tcp://192.168.0.31:5900'

    # Values accepted for the `access_type` argument of
    # `_get_worker_address`.
    _ACCESS_GET = 1
    _ACCESS_PUT = 2

    def __init__(self, elb_addr, ip, local=False, offset=0):
        '''
        The AnnaTcpClient allows you to interact with a local copy of Anna
        or with a remote cluster running on AWS.

        elb_addr: Either 127.0.0.1 (local mode) or the address of an AWS ELB
        for the routing tier
        ip: The IP address of the machine being used -- if None is provided,
        one is inferred by using socket.gethostbyname(); WARNING: this does not
        always work
        local: If True, the routing tier is contacted on port 6450 only;
        otherwise ports 6450 through 6453 are used
        offset: A port numbering offset, which is only needed if multiple
        clients are running on the same machine
        '''

        self.elb_addr = elb_addr

        if local:
            self.elb_ports = [6450]
        else:
            self.elb_ports = list(range(6450, 6454))

        if ip:
            self.ut = UserThread(ip, offset)
        else:  # If the IP is not provided, we attempt to infer it.
            self.ut = UserThread(socket.gethostbyname(socket.gethostname()),
                                 offset)

        self.context = zmq.Context(1)

        # key -> list of worker addresses, tracked separately per request
        # type.
        self.get_address_cache = {}
        self.put_address_cache = {}
        self.pusher_cache = SocketCache(self.context, zmq.PUSH)

        self.response_puller = self.context.socket(zmq.PULL)
        self.response_puller.bind(self.ut.get_request_pull_bind_addr())

        self.key_address_puller = self.context.socket(zmq.PULL)
        self.key_address_puller.bind(self.ut.get_key_address_bind_addr())

        self.rid = 0

    def get(self, keys):
        '''
        Issue one GET per key to a single worker for that key.

        keys: a single key or a list of keys

        Returns a dict mapping every requested key to its deserialized value,
        or to None if no worker address was found or no valid response was
        received for it.
        '''
        if not isinstance(keys, list):
            keys = [keys]

        # Resolve one worker per key. The normalisation of a list result is
        # done for every key, not only for the last one.
        worker_addresses = {}
        for key in keys:
            address = self._get_worker_address(key, self._ACCESS_GET)
            if isinstance(address, list):
                address = address[0] if address else None
            worker_addresses[key] = address

        # Initialize all KV pairs to None. Only change a value if we get a
        # valid response from the server.
        kv_pairs = {key: None for key in keys}

        request_ids = []
        for key, address in worker_addresses.items():
            if not address:
                continue

            send_sock = self.pusher_cache.get(address)

            req, _ = self._prepare_data_request([key])
            req.type = GET

            send_request(req, send_sock)
            request_ids.append(req.request_id)

        # Nothing was sent, so there is nothing to wait for.
        if not request_ids:
            return kv_pairs

        # Wait for all responses to return.
        responses = recv_response(request_ids, self.response_puller,
                                  KeyResponse)

        for response in responses:
            for tup in response.tuples:
                if tup.invalidate:
                    self._invalidate_cache(tup.key, 'get')

                if tup.error == NO_ERROR and not tup.invalidate:
                    kv_pairs[tup.key] = self._deserialize(tup)

        return kv_pairs

    def get_all(self, keys):
        '''
        Issue a GET for a key to every cached worker address for that key and
        merge the returned values.

        keys: a single key, or a list containing exactly one key

        Returns a dict mapping the key to the merged value, or to None if no
        worker address was found or no successful response was received.

        Raises ValueError if more than one key (or no key) is given.
        '''
        if not isinstance(keys, list):
            keys = [keys]

        if len(keys) != 1:
            raise ValueError('`get_all` currently only supports single key' +
                             ' GETs.')

        # Initialize all KV pairs to None. Only change a value if we get a
        # valid response from the server.
        kv_pairs = {key: None for key in keys}

        req_ids = []
        for key in keys:
            addresses = self._get_worker_address(key, self._ACCESS_GET,
                                                 pick=False)
            if not addresses:
                continue

            # We pass in a list because the data request preparation can
            # prepare multiple tuples.
            req, _ = self._prepare_data_request([key])
            req.type = GET

            # The same request is sent to every address under a fresh request
            # ID, and every ID is recorded so that all responses are awaited.
            for address in addresses:
                req.request_id = self._get_request_id()

                send_sock = self.pusher_cache.get(address)
                send_request(req, send_sock)

                req_ids.append(req.request_id)

        if not req_ids:
            return kv_pairs

        responses = recv_response(req_ids, self.response_puller, KeyResponse)

        for resp in responses:
            for tup in resp.tuples:
                if tup.invalidate:
                    self._invalidate_cache(tup.key)

                if tup.error == NO_ERROR:
                    val = self._deserialize(tup)

                    if kv_pairs[tup.key]:
                        kv_pairs[tup.key].merge(val)
                    else:
                        kv_pairs[tup.key] = val

        return kv_pairs

    def put(self, key, value):
        '''
        Write `value` under `key` on the first worker returned by the routing
        tier.

        Returns True if the worker reported no error; False if the routing
        tier returned no address or the worker reported an error.
        '''
        # NOTE(review): this queries the routing tier directly, so neither
        # the PUT address cache nor the MONITOR_ADDRESS filter applies here
        # -- confirm that this is intended.
        port = random.choice(self.elb_ports)
        worker_address = self._query_routing(key, port)

        # Check for an empty result before indexing into it.
        if not worker_address:
            return False

        if isinstance(worker_address, list):
            worker_address = worker_address[0]

        send_sock = self.pusher_cache.get(worker_address)

        # We pass in a list because the data request preparation can prepare
        # multiple tuples
        req, tup = self._prepare_data_request([key])
        req.type = PUT

        # PUT only supports one key operations, we only ever have to look at
        # the first KeyTuple returned.
        tup = tup[0]
        tup.payload, tup.lattice_type = self._serialize(value)

        send_request(req, send_sock)
        response = recv_response([req.request_id], self.response_puller,
                                 KeyResponse)[0]

        tup = response.tuples[0]

        if tup.invalidate:
            self._invalidate_cache(tup.key)

        return tup.error == NO_ERROR

    def put_all(self, key, value):
        '''
        Write `value` under `key` on every cached worker address for the key.

        Returns True only if every worker reported no error; False if no
        worker address is known or any worker reported an error. If a worker
        asks for cache invalidation, the cache entry is dropped and the whole
        request is reissued.
        '''
        worker_addresses = self._get_worker_address(key, self._ACCESS_PUT,
                                                    pick=False)

        if not worker_addresses:
            return False

        # We pass in a list because the data request preparation can prepare
        # multiple tuples; only the first KeyTuple is relevant here.
        req, tuples = self._prepare_data_request([key])
        req.type = PUT
        tup = tuples[0]
        tup.payload, tup.lattice_type = self._serialize(value)
        tup.timestamp = 0

        req_ids = []
        for address in worker_addresses:
            req.request_id = self._get_request_id()

            send_sock = self.pusher_cache.get(address)
            send_request(req, send_sock)

            req_ids.append(req.request_id)

        responses = recv_response(req_ids, self.response_puller, KeyResponse)

        for resp in responses:
            tup = resp.tuples[0]
            if tup.invalidate:
                # reissue the request
                # NOTE(review): this used to call `self.durable_put`, which
                # is not defined in this class -- confirm that no base class
                # provides it with different semantics.
                self._invalidate_cache(tup.key)
                return self.put_all(key, value)

            if tup.error != NO_ERROR:
                return False

        return True

    # Returns the worker address for a particular key. If worker addresses for
    # that key are not cached locally, a query is synchronously issued to the
    # routing tier, and the address cache is updated.
    def _get_worker_address(self, key, access_type, pick=True):
        '''
        key: the key to look up
        access_type: 1 to use the GET address cache, 2 to use the PUT address
        cache; any other value yields None
        pick: if True, return one randomly chosen address; otherwise return
        the cached list of addresses

        Returns None if no usable address is known for the key.
        '''
        if access_type == self._ACCESS_GET:
            address_cache = self.get_address_cache
        elif access_type == self._ACCESS_PUT:
            address_cache = self.put_address_cache
        else:
            return None

        if key not in address_cache:
            port = random.choice(self.elb_ports)
            routed = self._query_routing(key, port)

            # Drop duplicates (keeping first-seen order) and the monitor
            # address; every remaining address is kept.
            addresses = [address for address in dict.fromkeys(routed)
                         if address != self.MONITOR_ADDRESS]

            # An empty result is not cached, so the next call queries the
            # routing tier again.
            if not addresses:
                return None

            address_cache[key] = addresses

        addresses = address_cache[key]
        if not addresses:
            return None

        if pick:
            return random.choice(addresses)

        return addresses

    # Invalidates the address cache for a particular key when the server tells
    # the client that its cache is out of date.
    def _invalidate_cache(self, key, type='both'):
        '''
        type: 'get' drops only the GET cache entry, 'put' only the PUT cache
        entry, anything else drops both. A key that is not cached is ignored.
        '''
        if type != 'put':
            self.get_address_cache.pop(key, None)

        if type != 'get':
            self.put_address_cache.pop(key, None)

    # Issues a synchronous query to the routing tier. Takes in a key and a
    # (randomly chosen) routing port to issue the request to. Returns a list of
    # addresses that the routing tier returned that correspond to the input
    # key.
    def _query_routing(self, key, port):
        key_request = KeyAddressRequest()

        key_request.query_type = u'GET'

        key_request.response_address = self.ut.get_key_address_connect_addr()
        key_request.keys.append(key)
        key_request.request_id = self._get_request_id()

        dst_addr = 'tcp://' + self.elb_addr + ':' + str(port)
        send_sock = self.pusher_cache.get(dst_addr)

        send_request(key_request, send_sock)
        response = recv_response([key_request.request_id],
                                 self.key_address_puller,
                                 KeyAddressResponse)[0]

        # A non-zero error code yields an empty address list.
        if response.error != 0:
            return []

        return [ip for entry in response.addresses if entry.key == key
                for ip in entry.ips]

    @property
    def response_address(self):
        '''The connect address of this client's response PULL socket.'''
        return self.ut.get_request_pull_connect_addr()
Beispiel #4
0
def scheduler(ip, mgmt_ip, route_addr):
    '''
    Run the scheduler event loop. This function never returns.

    It binds one socket per request type, polls them with a one second
    timeout, and dispatches each ready socket to its handler. At the end of
    every iteration it prunes expired `running_counts` and `backoff`
    entries. Once more than THRESHOLD seconds have passed since the last
    report, it refreshes the cluster state, pushes its known DAGs and
    function locations to the other schedulers, and reports per-function
    call counts to the management address.

    ip: this scheduler's own IP; used to skip itself when broadcasting and
    passed to the KVS client and to `create_dag`
    mgmt_ip: passed to `_update_cluster_state` and used to build the
    statistics report address
    route_addr: passed to the KVS client and returned verbatim to clients
    that connect
    '''
    logging.basicConfig(filename='log_scheduler.txt', level=logging.INFO,
                        format='%(asctime)s %(message)s')

    kvs = AnnaClient(route_addr, ip)

    key_ip_map = {}
    ctx = zmq.Context(1)

    # Each dag consists of a set of functions and connections. Each one of
    # the functions is pinned to one or more nodes, which is tracked here.
    dags = {}
    thread_statuses = {}
    func_locations = {}
    running_counts = {}
    backoff = {}

    connect_socket = ctx.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = ctx.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = ctx.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = ctx.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = ctx.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    list_socket = ctx.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = ctx.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = ctx.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    backoff_socket = ctx.socket(zmq.PULL)
    backoff_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.BACKOFF_PORT))

    # This socket is not registered with the poller; it is handed to
    # `create_dag` and has a 500ms receive timeout.
    pin_accept_socket = ctx.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    requestor_cache = SocketCache(ctx, zmq.REQ)
    pusher_cache = SocketCache(ctx, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(backoff_socket, zmq.POLLIN)

    executors = set()
    executor_status_map = {}
    schedulers = _update_cluster_state(requestor_cache, mgmt_ip, executors,
                                       key_ip_map, kvs)

    # track how often each DAG function is called
    call_frequency = {}

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request payload is ignored, but it has to be received
            # before the REP socket may send its reply.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_func(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, executors,
                          key_ip_map, executor_status_map, running_counts,
                          backoff)

        if (dag_create_socket in socks and socks[dag_create_socket]
                == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, executors, dags,
                       ip, pin_accept_socket, func_locations, call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            if call.name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                continue

            # NOTE(review): exec_id is never read below.
            exec_id = generate_timestamp(0)

            dag = dags[call.name]
            for fname in dag[0].functions:
                # A function without a counter yet starts from zero instead
                # of raising KeyError.
                call_frequency[fname] = call_frequency.get(fname, 0) + 1

            rid = call_dag(call, pusher_cache, dags, func_locations,
                           key_ip_map, running_counts, backoff)

            resp = GenericResponse()
            resp.success = True
            resp.response_id = rid
            dag_call_socket.send(resp.SerializeToString())

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            logging.info('Received query for function list.')
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = FunctionList()
            resp.names.extend(utils._get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            key = (status.ip, status.tid)
            logging.info('Received status update from executor %s:%d.' %
                         (key[0], int(key[1])))

            if key in executor_status_map:
                if status.type == PERIODIC:
                    # A periodic update only replaces an entry older than
                    # five seconds; otherwise the update is ignored.
                    # NOTE(review): assumes the map stores time.time()
                    # timestamps written by call_function -- confirm.
                    if time.time() - executor_status_map[key] > 5:
                        del executor_status_map[key]
                    else:
                        continue
                elif status.type == POST_REQUEST:
                    del executor_status_map[key]

            # this means that this node is currently departing, so we remove it
            # from all of our metadata tracking
            if not status.running:
                if key in thread_statuses:
                    old_status = thread_statuses[key]
                    del thread_statuses[key]

                    for fname in old_status.functions:
                        if fname in func_locations:
                            func_locations[fname].discard((old_status.ip,
                                                           old_status.tid))

                executors.discard(key)
                continue

            if key not in executors:
                executors.add(key)

            if key in thread_statuses and thread_statuses[key] != status:
                # remove all the old function locations, and all the new ones
                # -- there will probably be a large overlap, but this shouldn't
                # be much different than calculating two different set
                # differences anyway
                for func in thread_statuses[key].functions:
                    if func in func_locations:
                        func_locations[func].discard(key)

            thread_statuses[key] = status
            for func in status.functions:
                if func not in func_locations:
                    func_locations[func] = set()

                func_locations[func].add(key)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # retrieve any DAG that some other scheduler knows about that we do
            # not yet know about
            for dname in status.dags:
                if dname not in dags:
                    # Retry until the KVS returns a payload for the DAG.
                    payload = kvs.get(dname)
                    while not payload:
                        payload = kvs.get(dname)
                    dag = Dag()
                    dag.ParseFromString(payload.reveal()[1])

                    dags[dag.name] = (dag, utils._find_dag_source(dag))

                    for fname in dag.functions:
                        if fname not in call_frequency:
                            call_frequency[fname] = 0

                        if fname not in func_locations:
                            func_locations[fname] = set()

            for floc in status.func_locations:
                key = (floc.ip, floc.tid)
                fname = floc.name

                if fname not in func_locations:
                    func_locations[fname] = set()

                func_locations[fname].add(key)

        if backoff_socket in socks and socks[backoff_socket] == zmq.POLLIN:
            # The message has the form '<node>:<tid>'.
            msg = backoff_socket.recv_string()
            splits = msg.split(':')
            node, tid = splits[0], int(splits[1])

            backoff[(node, tid)] = time.time()

        # periodically clean up the running counts map: only call times from
        # the last 2.5 seconds are kept
        now = time.time()
        for executor in running_counts:
            running_counts[executor] = {ts for ts in running_counts[executor]
                                        if now - ts < 2.5}

        # backoff entries expire after five seconds
        expired = [executor for executor in backoff
                   if now - backoff[executor] > 5]
        for executor in expired:
            del backoff[executor]

        end = time.time()
        if end - start > THRESHOLD:
            schedulers = _update_cluster_state(requestor_cache, mgmt_ip,
                                               executors, key_ip_map, kvs)

            # Share our known DAG names and function locations with every
            # other scheduler.
            status = SchedulerStatus()
            for name in dags.keys():
                status.dags.append(name)

            for fname in func_locations:
                for loc in func_locations[fname]:
                    floc = status.func_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(utils._get_scheduler_update_address
                                            (sched_ip))
                    sckt.send(msg)

            # Report the per-function call counts, then reset each counter.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.statistics.add()
                fstats.fname = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.' %
                             (call_frequency[fname], fname))

                call_frequency[fname] = 0

            sckt = pusher_cache.get(sutils._get_statistics_report_address
                                    (mgmt_ip))
            sckt.send(stats.SerializeToString())

            start = time.time()
Beispiel #5
0
def executor(ip, mgmt_ip, schedulers, thread_id):
    # logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(asctime)s %(message)s')
    logging.basicConfig(filename='log_executor.txt',
                        level=logging.INFO,
                        filemode="w",
                        format='%(asctime)s %(message)s')

    # Check what resources we have access to, set as an environment variable.
    if os.getenv('EXECUTOR_TYPE', 'CPU') == 'GPU':
        exec_type = GPU
    else:
        exec_type = CPU

    context = zmq.Context(1)
    poller = zmq.Poller()

    pin_socket = context.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = context.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                      (sutils.UNPIN_PORT + thread_id))

    exec_socket = context.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                     (sutils.FUNC_EXEC_PORT + thread_id))

    dag_queue_socket = context.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                          (sutils.DAG_QUEUE_PORT + thread_id))

    dag_exec_socket = context.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                         (sutils.DAG_EXEC_PORT + thread_id))

    self_depart_socket = context.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    # If the management IP is set to None, that means that we are running in
    # local mode, so we use a regular AnnaTcpClient rather than an IPC client.
    has_ephe = False
    if mgmt_ip:
        if 'STORAGE_OR_DEFAULT' in os.environ and os.environ[
                'STORAGE_OR_DEFAULT'] == '0':
            client = AnnaTcpClient(os.environ['ROUTE_ADDR'],
                                   ip,
                                   local=False,
                                   offset=thread_id)
            has_ephe = True
        else:
            client = AnnaIpcClient(thread_id, context)
        # force_remote_anna = 1
        # if 'FORCE_REMOTE' in os.environ:
        #     force_remote_anna = int(os.environ['FORCE_REMOTE'])

        # if force_remote_anna == 0: # remote anna only
        #     client = AnnaTcpClient(os.environ['ROUTE_ADDR'], ip, local=False, offset=thread_id)
        # elif force_remote_anna == 1: # anna cache
        #     client = AnnaIpcClient(thread_id, context)
        # elif force_remote_anna == 2: # control both cache and remote anna
        #     remote_client = AnnaTcpClient(os.environ['ROUTE_ADDR'], ip, local=False, offset=thread_id)
        #     cache_client = AnnaIpcClient(thread_id, context)
        #     client = cache_client
        #     user_library = CloudburstUserLibrary(context, pusher_cache, ip, thread_id, (cache_client, remote_client))

        local = False
    else:
        client = AnnaTcpClient('127.0.0.1', '127.0.0.1', local=True, offset=1)
        local = True

    user_library = CloudburstUserLibrary(context,
                                         pusher_cache,
                                         ip,
                                         thread_id,
                                         client,
                                         has_ephe=has_ephe)

    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    status.type = exec_type
    utils.push_status(schedulers, pusher_cache, status)

    departing = False

    # Maintains a request queue for each function pinned on this executor. Each
    # function will have a set of request IDs mapped to it, and this map stores
    # a schedule for each request ID.
    queue = {}

    # Tracks the actual function objects that are pinned to this executor.
    function_cache = {}

    # Tracks runtime cost of executing a DAG function.
    runtimes = {}

    # If multiple triggers are necessary for a function, track the triggers as
    # we receive them. This is also used if a trigger arrives before its
    # corresponding schedule.
    received_triggers = {}

    # Tracks when we received a function request, so we can report end-to-end
    # latency for the whole execution.
    receive_times = {}

    # Tracks the number of requests we are finishing for each function pinned
    # here.
    exec_counts = {}

    # Tracks the end-to-end runtime of each DAG request for which we are the
    # sink function.
    dag_runtimes = {}

    # A map with KVS keys and their corresponding deserialized payloads.
    cache = {}

    # A map which tracks the most recent DAGs for which we have finished our
    # work.
    finished_executions = {}

    # The set of pinned functions and whether they support batching. NOTE: This
    # is only a set for local mode -- in cluster mode, there will only be one
    # pinned function per executor.
    batching = False

    # Internal metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    }
    total_occupancy = 0.0

    while True:
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            batching = pin(pin_socket, pusher_cache, client, status,
                           function_cache, runtimes, exec_counts, user_library,
                           local, batching)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, function_cache, runtimes, exec_counts)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            # logging.info(f'Executor timer. exec_socket recv: {work_start}')
            exec_function(exec_socket,
                          client,
                          user_library,
                          cache,
                          function_cache,
                          has_ephe=has_ephe)
            user_library.close()

            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()
            logging.info(
                f'Executor timer. dag_queue_socket recv: {work_start}')
            # In order to effectively support batching, we have to make sure we
            # dequeue lots of schedules in addition to lots of triggers. Right
            # now, we're not going to worry about supporting batching here,
            # just on the trigger dequeue side, but we still have to dequeue
            # all schedules we've received. We just process them one at a time.
            while True:
                schedule = DagSchedule()
                try:
                    msg = dag_queue_socket.recv(zmq.DONTWAIT)
                except zmq.ZMQError as e:
                    if e.errno == zmq.EAGAIN:
                        break  # There are no more messages.
                    else:
                        raise e  # Unexpected error.

                schedule.ParseFromString(msg)
                fname = schedule.target_function

                logging.info(
                    'Received a schedule for DAG %s (%s), function %s.' %
                    (schedule.dag.name, schedule.id, fname))

                if fname not in queue:
                    queue[fname] = {}

                queue[fname][schedule.id] = schedule

                if (schedule.id, fname) not in receive_times:
                    receive_times[(schedule.id, fname)] = time.time()

                # In case we receive the trigger before we receive the schedule, we
                # can trigger from this operation as well.
                trkey = (schedule.id, fname)
                fref = None

                # Check to see what type of execution this function is.
                for ref in schedule.dag.functions:
                    if ref.name == fname:
                        fref = ref

                if (trkey in received_triggers and
                    ((len(received_triggers[trkey]) == len(schedule.triggers))
                     or (fref.type == MULTIEXEC))):

                    triggers = list(received_triggers[trkey].values())

                    if fname not in function_cache:
                        logging.error('%s not in function cache', fname)
                        utils.generate_error_response(schedule, client, fname)
                        continue
                    exec_start = time.time()
                    # logging.info(f'Executor timer. dag_queue_socket exec_dag: {exec_start}')
                    # We don't support actual batching for when we receive a
                    # schedule before a trigger, so everything is just a batch of
                    # size 1 if anything.
                    success = exec_dag_function(pusher_cache, client,
                                                [triggers],
                                                function_cache[fname],
                                                [schedule], user_library,
                                                dag_runtimes, cache,
                                                schedulers, batching)[0]
                    user_library.close()

                    del received_triggers[trkey]
                    if success:
                        del queue[fname][schedule.id]

                        fend = time.time()
                        fstart = receive_times[(schedule.id, fname)]
                        runtimes[fname].append(fend - work_start)
                        exec_counts[fname] += 1

                        finished_executions[(schedule.id, fname)] = time.time()

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()
            # logging.info(f'Executor timer. dag_exec_socket recv: {work_start}')

            # How many messages to dequeue -- BATCH_SIZE_MAX or 1 depending on
            # the function configuration.
            if batching:
                count = BATCH_SIZE_MAX
            else:
                count = 1

            trigger_keys = set()

            for _ in range(count):  # Dequeue count number of messages.
                trigger = DagTrigger()

                try:
                    msg = dag_exec_socket.recv(zmq.DONTWAIT)
                except zmq.ZMQError as e:
                    if e.errno == zmq.EAGAIN:  # There are no more messages.
                        break
                    else:
                        raise e  # Unexpected error.

                trigger.ParseFromString(msg)

                # We have received a repeated trigger for a function that has
                # already finished executing.
                if trigger.id in finished_executions:
                    continue

                fname = trigger.target_function
                logging.info(
                    'Received a trigger for schedule %s, function %s.' %
                    (trigger.id, fname))

                key = (trigger.id, fname)
                trigger_keys.add(key)
                if key not in received_triggers:
                    received_triggers[key] = {}

                if (trigger.id, fname) not in receive_times:
                    receive_times[(trigger.id, fname)] = time.time()

                received_triggers[key][trigger.source] = trigger

            # Only execute the functions for which we have received a schedule.
            # Everything else will wait.
            for tid, fname in list(trigger_keys):
                if fname not in queue or tid not in queue[fname]:
                    trigger_keys.remove((tid, fname))

            if len(trigger_keys) == 0:
                continue

            fref = None
            schedule = queue[fname][list(trigger_keys)[0]
                                    [0]]  # Pick a random schedule to check.
            # Check to see what type of execution this function is.
            for ref in schedule.dag.functions:
                if ref.name == fname:
                    fref = ref
                    break

            # Compile a list of all the trigger sets for which we have
            # enough triggers.
            trigger_sets = []
            schedules = []
            for key in trigger_keys:
                if (len(received_triggers[key]) == len(schedule.triggers)) or \
                        fref.type == MULTIEXEC:

                    if fref.type == MULTIEXEC:
                        triggers = [trigger]
                    else:
                        triggers = list(received_triggers[key].values())

                    if fname not in function_cache:
                        logging.error('%s not in function cache', fname)
                        utils.generate_error_response(schedule, client, fname)
                        continue

                    trigger_sets.append(triggers)
                    schedule = queue[fname][key[0]]
                    schedules.append(schedule)

            exec_start = time.time()
            # logging.info(f'Executor timer. dag_exec_socket exec_dag: {exec_start}')
            # Pass all of the trigger_sets into exec_dag_function at once.
            # We also include the batching variable to make sure we know
            # whether to pass lists into the fn or not.
            if len(trigger_sets) > 0:
                successes = exec_dag_function(pusher_cache, client,
                                              trigger_sets,
                                              function_cache[fname], schedules,
                                              user_library, dag_runtimes,
                                              cache, schedulers, batching)
                user_library.close()
                del received_triggers[key]

                for key, success in zip(trigger_keys, successes):
                    if success:
                        del queue[fname][key[0]]  # key[0] is trigger.id.

                        fend = time.time()
                        fstart = receive_times[key]

                        average_time = (fend - work_start) / len(trigger_keys)

                        runtimes[fname].append(average_time)
                        exec_counts[fname] += 1

                        finished_executions[(schedule.id, fname)] = time.time()

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
                zmq.POLLIN:
            # This message does not matter.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils.push_status(schedulers, pusher_cache, status)

            departing = True

        # periodically report function occupancy
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            if len(cache) > 100:
                extra_keys = list(cache.keys())[:len(cache) - 100]
                for key in extra_keys:
                    del cache[key]

            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            # Periodically report my status to schedulers with the utilization
            # set.
            utils.push_status(schedulers, pusher_cache, status)

            logging.debug('Total thread occupancy: %.6f' % (utilization))

            for event in event_occupancy:
                occ = event_occupancy[event] / (report_end - report_start)
                logging.debug('\tEvent %s occupancy: %.6f' % (event, occ))
                event_occupancy[event] = 0.0

            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.functions.add()
                    fstats.name = fname
                    fstats.call_count = exec_counts[fname]
                    fstats.runtime.extend(runtimes[fname])

                runtimes[fname].clear()
                exec_counts[fname] = 0

            for dname in dag_runtimes:
                dstats = stats.dags.add()
                dstats.name = dname

                dstats.runtimes.extend(dag_runtimes[dname])

                dag_runtimes[dname].clear()

            # If we are running in cluster mode, mgmt_ip will be set, and we
            # will report our status and statistics to it. Otherwise, we will
            # write to the local conf file
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

                sckt = pusher_cache.get(utils.get_util_report_address(mgmt_ip))
                sckt.send(status.SerializeToString())
            else:
                logging.info(stats)

            status.ClearField('utilization')
            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for.
            del_list = []
            for fname in queue:
                if len(queue[fname]) == 0 and fname not in status.functions:
                    del_list.append(fname)
                    del function_cache[fname]
                    del runtimes[fname]
                    del exec_counts[fname]

            for fname in del_list:
                del queue[fname]

            del_list = []
            for tid in finished_executions:
                if (time.time() - finished_executions[tid]) > 10:
                    del_list.append(tid)

            for tid in del_list:
                del finished_executions[tid]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils.get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)

                # We specifically pass 1 as the exit code when ending our
                # process so that the wrapper script does not restart us.
                sys.exit(1)
Beispiel #6
0
def scheduler(ip, mgmt_ip, route_addr):
    """Run the scheduler's request loop; this function does not return.

    Binds one ZMQ socket per request type (connect, function create/call,
    DAG create/call/delete, list, executor status, scheduler update), then
    polls them forever and dispatches each ready socket to its handler.
    Policy metadata is refreshed and statistics are reported periodically.

    Args:
        ip: This scheduler's own IP address. It is passed to the KVS client
            and the policy, and used to skip ourselves when pushing status
            to the other schedulers.
        mgmt_ip: Address of the management server, or None when running in
            local mode (no peer schedulers, no statistics reporting).
        route_addr: Address handed to the KVS client; it is also returned
            verbatim to any client that connects.
    """

    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)

    # A mapping from a DAG's name to a (Dag protobuf, DAG sources) tuple.
    dags = {}

    # Tracks how often a request for each function is received. Keys are
    # function names (strings).
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = []

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    # This socket is not polled below; it is handed to the policy engine.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    requestor_cache = SocketCache(context, zmq.REQ)
    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket,
                                              pusher_cache,
                                              kvs,
                                              ip,
                                              local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request's contents are ignored, but a REP socket must
            # receive before it can reply.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks
                and socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (dag_create_socket in socks
                and socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            # Interarrival times are recorded even for unknown DAG names.
            t = time.time()
            if name in last_arrivals:
                if name not in interarrivals:
                    interarrivals[name] = []

                interarrivals[name].append(t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                # Skips the remaining handlers for this iteration; any other
                # ready sockets are picked up by the next poll.
                continue

            dag = dags[name]
            for fref in dag[0].functions:
                # Tolerate a function that was never registered in
                # call_frequency rather than killing the scheduler loop.
                call_frequency[fref.name] = \
                    call_frequency.get(fref.name, 0) + 1

            response = call_dag(call, pusher_cache, dags, policy)
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks
                and socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    payload = kvs.get(dname)
                    # NOTE(review): if payload is a dict, `None in payload`
                    # tests its keys, not its values; presumably the intent
                    # is to retry until the value for dname is present --
                    # confirm against kvs.get before changing.
                    while None in payload:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    # call_frequency is keyed by function *name*, matching
                    # the DAG-call handler and the statistics report; the
                    # elements of dag.functions are reference objects.
                    for fref in dag.functions:
                        if fref.name not in call_frequency:
                            call_frequency[fref.name] = 0

            policy.update_function_locations(status.function_locations)

        end = time.time()

        # NOTE(review): `start` is only reset in the report branch below, so
        # once METADATA_THRESHOLD has elapsed this branch runs on every
        # iteration until the next report.
        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if mgmt_ip:
                schedulers = sched_utils.get_ip_set(
                    sched_utils.get_scheduler_list_address(mgmt_ip),
                    requestor_cache, False)

        if end - start > REPORT_THRESHOLD:
            # NOTE(review): key and data are assembled here but are not used
            # anywhere below in this function.
            num_unique_executors = policy.get_unique_executors()
            key = scheduler_id + ':' + str(time.time())
            data = {'key': key, 'count': num_unique_executors}

            # Tell the other schedulers which DAGs we know about and where
            # each function is pinned.
            status = SchedulerStatus()
            for name in dags:
                status.dags.append(name)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.',
                             call_frequency[fname], fname)

                # Counts are per reporting epoch.
                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()
Beispiel #7
0
class FluentUserLibrary(AbstractFluentUserLibrary):
    """User library backed by an Anna client and ZMQ executor messaging.

    Provides put/get against the KVS client, plus send()/recv() for passing
    byte strings between executor threads. Incoming messages are collected
    by a background listener thread into an in-memory inbox; call close()
    to stop that thread.
    """

    def __init__(self, ip, tid, anna_client):
        """Create the library and start the inbox listener thread.

        Args:
            ip: Executor IP.
            tid: Executor thread ID.
            anna_client: The Anna client, used for interfacing with the kvs.
        """
        self.ctx = zmq.Context()
        self.send_socket_cache = SocketCache(self.ctx, zmq.PUSH)

        self.executor_ip = ip
        self.executor_tid = tid
        self.client = anna_client

        # Threadsafe queue to serve as this node's inbox.
        # Items are (sender string, message bytestring).
        # NB: currently unbounded in size.
        self.recv_inbox = queue.Queue()

        # Thread for receiving messages into our inbox. The do_run attribute
        # is the stop flag that close() clears.
        self.recv_inbox_thread = threading.Thread(
              target=self._recv_inbox_listener)
        self.recv_inbox_thread.do_run = True
        self.recv_inbox_thread.start()

    def put(self, ref, ltc):
        """Store ltc under ref via the KVS client and return its result."""
        return self.client.put(ref, ltc)

    def get(self, ref):
        """Fetch ref from the KVS client.

        A list of keys returns the client's result unchanged; a single key
        returns just the entry for that key.
        """
        if isinstance(ref, list):
            return self.client.get(ref)

        return self.client.get(ref)[ref]

    def getid(self):
        """Return this executor's (IP, thread id) pair."""
        return (self.executor_ip, self.executor_tid)

    # dest is currently (IP string, thread id int) of destination executor.
    def send(self, dest, bytestr):
        """Send bytestr to the executor identified by dest."""
        ip, tid = dest
        dest_addr = server_utils._get_user_msg_inbox_addr(ip, tid)
        sender = (self.executor_ip, self.executor_tid)

        # Named sckt so the stdlib `socket` module name is not shadowed.
        sckt = self.send_socket_cache.get(dest_addr)
        sckt.send_pyobj((sender, bytestr))

    def close(self):
        """Signal the listener thread to stop and wait for it to exit."""
        self.recv_inbox_thread.do_run = False
        self.recv_inbox_thread.join()

    def recv(self):
        """Return every (sender, message) pair currently in the inbox.

        Never blocks; returns an empty list when the inbox is empty.
        """
        res = []
        while True:
            try:
                res.append(self.recv_inbox.get_nowait())
            except queue.Empty:
                break
        return res

    # Function that continuously listens for send()s sent by other nodes,
    # and stores the messages in an inbox.
    def _recv_inbox_listener(self):
        # Socket for receiving send() messages from other nodes.
        recv_inbox_socket = self.ctx.socket(zmq.PULL)
        t = threading.current_thread()

        try:
            recv_inbox_socket.bind(server_utils.BIND_ADDR_TEMPLATE %
                                   (server_utils.RECV_INBOX_PORT +
                                    self.executor_tid))

            while t.do_run:
                try:
                    # NOTE(review): recv_pyobj unpickles whatever arrives on
                    # this socket; only safe if all senders are trusted.
                    (sender, msg) = recv_inbox_socket.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError as e:
                    if e.errno != zmq.EAGAIN:
                        raise
                    # Nothing pending: back off briefly instead of spinning,
                    # then re-check the stop flag.
                    time.sleep(.010)
                    continue

                self.recv_inbox.put((sender, msg))
        finally:
            # Release the socket even if the loop exits with an error.
            recv_inbox_socket.close()
Beispiel #8
0
def run(self_ip):
    """Run the management server's event loop; this function does not return.

    Binds the management sockets on ports 7000-7006 plus the pin-accept
    port, then polls them forever: it handles churn and restart-count
    requests, serves executor and scheduler IP lists, collects executor
    thread statuses and statistics, and once per REPORT_PERIOD checks the
    hash ring and invokes the replica policy with the metrics gathered
    during that epoch.

    Args:
        self_ip: This server's own IP address; passed to the scaler.
    """
    context = zmq.Context(1)

    pusher_cache = SocketCache(context, zmq.PUSH)

    restart_pull_socket = context.socket(zmq.REP)
    restart_pull_socket.bind('tcp://*:7000')

    churn_pull_socket = context.socket(zmq.PULL)
    churn_pull_socket.bind('tcp://*:7001')

    list_executors_socket = context.socket(zmq.PULL)
    list_executors_socket.bind('tcp://*:7002')

    function_status_socket = context.socket(zmq.PULL)
    function_status_socket.bind('tcp://*:7003')

    list_schedulers_socket = context.socket(zmq.REP)
    list_schedulers_socket.bind('tcp://*:7004')

    executor_depart_socket = context.socket(zmq.PULL)
    executor_depart_socket.bind('tcp://*:7005')

    statistics_socket = context.socket(zmq.PULL)
    statistics_socket.bind('tcp://*:7006')

    # This socket is not polled below; it is handed to the scaler.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind('tcp://*:' + PIN_ACCEPT_PORT)

    poller = zmq.Poller()
    poller.register(restart_pull_socket, zmq.POLLIN)
    poller.register(churn_pull_socket, zmq.POLLIN)
    poller.register(function_status_socket, zmq.POLLIN)
    poller.register(list_executors_socket, zmq.POLLIN)
    poller.register(list_schedulers_socket, zmq.POLLIN)
    poller.register(executor_depart_socket, zmq.POLLIN)
    poller.register(statistics_socket, zmq.POLLIN)

    add_push_socket = context.socket(zmq.PUSH)
    add_push_socket.connect('ipc:///tmp/node_add')

    remove_push_socket = context.socket(zmq.PUSH)
    remove_push_socket.connect('ipc:///tmp/node_remove')

    client, _ = util.init_k8s()

    scaler = DefaultScaler(self_ip, context, add_push_socket,
                           remove_push_socket, pin_accept_socket)
    policy = DefaultHydroPolicy(scaler)

    # Tracks the self-reported statuses of each executor thread in the system.
    executor_statuses = {}

    # Tracks which executors are departing. This is used to ensure all
    # threads acknowledge that they are finished before we remove a thread from
    # the system.
    departing_executors = {}

    # Tracks how often each function is called.
    function_frequencies = {}

    # Tracks the aggregated runtime for each function.
    function_runtimes = {}

    # Tracks the arrival times of DAG requests.
    arrival_times = {}

    # Tracks how often each DAG is called.
    dag_frequencies = {}

    # Tracks how long each DAG request spends in the system, end to end.
    dag_runtimes = {}

    start = time.time()
    while True:
        socks = dict(poller.poll(timeout=1000))

        if (churn_pull_socket in socks and socks[churn_pull_socket] ==
                zmq.POLLIN):
            # Messages are colon-separated: <add|remove>:<args[1]>:<args[2]>.
            msg = churn_pull_socket.recv_string()
            args = msg.split(':')

            if args[0] == 'add':
                scaler.add_vms(args[2], args[1])
            elif args[0] == 'remove':
                scaler.remove_vms(args[2], args[1])

        if (restart_pull_socket in socks and socks[restart_pull_socket] ==
                zmq.POLLIN):
            msg = restart_pull_socket.recv_string()
            args = msg.split(':')

            # Reply with the restart count of the pod's first container.
            pod = util.get_pod_from_ip(client, args[1])
            count = str(pod.status.container_statuses[0].restart_count)

            restart_pull_socket.send_string(count)

        if (list_executors_socket in socks and socks[list_executors_socket] ==
                zmq.POLLIN):
            # The message is the address to which the response is pushed.
            response_ip = list_executors_socket.recv_string()

            ips = StringSet()
            ips.keys.extend(util.get_pod_ips(client, 'role=function'))
            ips.keys.extend(util.get_pod_ips(client, 'role=gpu'))

            sckt = pusher_cache.get(response_ip)
            sckt.send(ips.SerializeToString())

        if (function_status_socket in socks and
                socks[function_status_socket] == zmq.POLLIN):
            # Dequeue all available ThreadStatus messages rather than doing
            # them one at a time---this prevents starvation if other operations
            # (e.g., pin) take a long time.
            while True:
                try:
                    msg = function_status_socket.recv(zmq.DONTWAIT)
                except zmq.ZMQError as e:
                    if e.errno == zmq.EAGAIN:
                        break  # We've run out of messages.
                    raise  # Unexpected error.

                status = ThreadStatus()
                status.ParseFromString(msg)

                key = (status.ip, status.tid)

                # If this executor is one of the ones that's currently departing,
                # we can just ignore its status updates since we don't want
                # utilization to be skewed downwards. The reason we might still
                # receive this message is because the depart message may not have
                # arrived when this was sent.
                if key[0] in departing_executors:
                    continue

                executor_statuses[key] = status
                logging.info(f"Functions {status.functions} is placed on node "
                             f"{status.ip}:{status.tid}")

        if (list_schedulers_socket in socks and
                socks[list_schedulers_socket] == zmq.POLLIN):
            # We can safely ignore this message's contents, and the response
            # does not depend on it.
            list_schedulers_socket.recv_string()

            ips = StringSet()
            ips.keys.extend(util.get_pod_ips(client, 'role=scheduler'))

            list_schedulers_socket.send(ips.SerializeToString())

        if (executor_depart_socket in socks and
                socks[executor_depart_socket] == zmq.POLLIN):
            ip = executor_depart_socket.recv_string()

            if ip not in departing_executors:
                # Nothing in this loop currently marks executors as
                # departing (the executor policy call below is disabled), so
                # an acknowledgement for an untracked IP must not take down
                # the server with a KeyError.
                logging.warning('Received a departure acknowledgement from '
                                '%s, which is not marked as departing; '
                                'ignoring it.', ip)
            else:
                departing_executors[ip] -= 1

                # We wait until all the threads at this executor have
                # acknowledged that they are ready to leave, and we then
                # remove the VM from the system.
                if departing_executors[ip] == 0:
                    logging.info('Removing node with ip %s' % ip)
                    scaler.remove_vms('function', ip)
                    del departing_executors[ip]

        if (statistics_socket in socks and
                socks[statistics_socket] == zmq.POLLIN):
            stats = ExecutorStatistics()
            stats.ParseFromString(statistics_socket.recv())

            # Aggregates statistics reported for individual functions including
            # call frequencies, processed requests, and total runtimes.
            for fstats in stats.functions:
                fname = fstats.name

                if fname not in function_frequencies:
                    function_frequencies[fname] = 0

                if fname not in function_runtimes:
                    function_runtimes[fname] = (0.0, 0)

                if fstats.runtime:
                    old_latency = function_runtimes[fname]

                    # This tracks how many calls were processed for the
                    # function and the length of the total runtime of all
                    # calls.
                    function_runtimes[fname] = (
                          old_latency[0] + sum(fstats.runtime),
                          old_latency[1] + fstats.call_count)
                else:
                    # A report without runtimes only contributes to the
                    # function's call frequency.
                    function_frequencies[fname] += fstats.call_count

            # Aggregates statistics for DAG requests, including call
            # frequencies, arrival rates, and end-to-end runtimes.
            for dstats in stats.dags:
                dname = dstats.name

                # Tracks the interarrival rates of requests to this DAG as
                # perceived by the scheduler.
                arrival_times.setdefault(dname, []).extend(
                    dstats.interarrival)

                # Tracks how many calls to this DAG were received.
                dag_frequencies[dname] = \
                    dag_frequencies.get(dname, 0) + dstats.call_count

                # Tracks the end-to-end runtime of individual requests
                # completed in the last epoch.
                dag_runtimes.setdefault(dname, []).extend(dstats.runtimes)

        end = time.time()
        if end - start > REPORT_PERIOD:
            logging.info('Checking hash ring...')
            check_hash_ring(client, context)

            # Invoke the configured policy to check system load and respond
            # appropriately.
            policy.replica_policy(function_frequencies, function_runtimes,
                                  dag_runtimes, executor_statuses,
                                  arrival_times)
            # TODO(simon): this turn off node scaling policy, which is what we want for static exp env
            # policy.executor_policy(executor_statuses, departing_executors)

            # Clears all metadata that was passed in for this epoch.
            # dag_frequencies is not cleared here and keeps accumulating.
            function_runtimes.clear()
            function_frequencies.clear()
            dag_runtimes.clear()
            arrival_times.clear()

            # Restart the timer for the next reporting epoch.
            start = time.time()
Beispiel #9
0
def executor(ip, mgmt_ip, schedulers, thread_id):
    '''
    Runs the event loop for a single executor thread.

    Binds one PULL socket per request type (pin, unpin, function execution,
    DAG schedule, DAG trigger, self-depart), dispatches each incoming message
    to its handler, and periodically reports utilization and per-function
    statistics to the management server.

    ip: The IP address of this machine; recorded in the thread's status and
    sent to the management server when the thread departs
    mgmt_ip: The IP address of the management server, used to look up the
    addresses that reports and the departure notice are sent to
    schedulers: The schedulers that status updates are pushed to
    thread_id: This thread's ID; added to each base port as an offset and
    passed to the IpcAnnaClient

    Returns 0 once a depart request has been received and the request queue
    is empty.
    '''
    logging.basicConfig(filename='log_executor.txt',
                        level=logging.INFO,
                        format='%(asctime)s %(message)s')

    ctx = zmq.Context(1)

    pin_socket = ctx.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = ctx.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                      (sutils.UNPIN_PORT + thread_id))

    exec_socket = ctx.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                     (sutils.FUNC_EXEC_PORT + thread_id))

    dag_queue_socket = ctx.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                          (sutils.DAG_QUEUE_PORT + thread_id))

    dag_exec_socket = ctx.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                         (sutils.DAG_EXEC_PORT + thread_id))

    self_depart_socket = ctx.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(ctx, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    client = IpcAnnaClient(thread_id)

    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    utils._push_status(schedulers, pusher_cache, status)

    departing = False

    # Maps each function name to a map from execution ID to the DAG schedule
    # that is waiting to run.
    queue = {}

    # Tracks the actual function objects that we are storing here.
    pinned_functions = {}

    # Tracks the accumulated runtime cost of executing each DAG function.
    runtimes = {}

    # If multiple triggers are necessary for a function, track the triggers
    # as we receive them. Keyed by (execution ID, function name).
    received_triggers = {}

    # Tracks when we first heard about a request, so we can report e2e
    # latency. Keyed by (execution ID, function name).
    receive_times = {}

    # Tracks how many times each function was executed in this epoch.
    exec_counts = {}

    # Metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    }
    total_occupancy = 0.0

    while True:
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            pin(pin_socket, client, status, pinned_functions, runtimes,
                exec_counts)
            utils._push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, pinned_functions, runtimes,
                  exec_counts)
            utils._push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()

            schedule = DagSchedule()
            schedule.ParseFromString(dag_queue_socket.recv())
            fname = schedule.target_function

            logging.info('Received a schedule for DAG %s (%s), function %s.' %
                         (schedule.dag.name, schedule.id, fname))

            if fname not in queue:
                queue[fname] = {}

            queue[fname][schedule.id] = schedule

            trkey = (schedule.id, fname)
            if trkey not in receive_times:
                receive_times[trkey] = time.time()

            # In case we receive the trigger before we receive the schedule,
            # we can trigger from this operation as well.
            if trkey in received_triggers and \
                    len(received_triggers[trkey]) == len(schedule.triggers):
                exec_dag_function(pusher_cache, client,
                                  received_triggers[trkey],
                                  pinned_functions[fname], schedule)
                del received_triggers[trkey]
                del queue[fname][schedule.id]

                fend = time.time()
                # The request is finished, so drop its timestamp instead of
                # keeping it for the lifetime of the process.
                fstart = receive_times.pop(trkey)
                runtimes[fname] += fend - fstart
                exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()
            trigger = DagTrigger()
            trigger.ParseFromString(dag_exec_socket.recv())

            fname = trigger.target_function
            logging.info('Received a trigger for schedule %s, function %s.' %
                         (trigger.id, fname))

            # The membership test must use the same (ID, function) key that
            # the map is indexed by; testing the bare ID would always miss
            # and wipe the triggers collected so far.
            key = (trigger.id, fname)
            if key not in received_triggers:
                received_triggers[key] = {}

            if key not in receive_times:
                receive_times[key] = time.time()

            received_triggers[key][trigger.source] = trigger
            if fname in queue and trigger.id in queue[fname]:
                schedule = queue[fname][trigger.id]
                if len(received_triggers[key]) == len(schedule.triggers):
                    exec_dag_function(pusher_cache, client,
                                      received_triggers[key],
                                      pinned_functions[fname], schedule)
                    del received_triggers[key]
                    del queue[fname][trigger.id]

                    fend = time.time()
                    # The request is finished, so drop its timestamp.
                    fstart = receive_times.pop(key)
                    runtimes[fname] += fend - fstart
                    exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
                zmq.POLLIN:
            # The content of this message does not matter; receiving it is
            # the signal.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils._push_status(schedulers, pusher_cache, status)

            departing = True

        # Periodically report function occupancy.
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            sckt = pusher_cache.get(utils._get_util_report_address(mgmt_ip))
            sckt.send(status.SerializeToString())

            logging.info('Total thread occupancy: %.6f%%' % (utilization))

            for event in event_occupancy:
                occ = event_occupancy[event]
                logging.info('Event %s occupancy: %.6f%%' % (event, occ))
                event_occupancy[event] = 0.0

            # Only functions that ran in this epoch are reported, but every
            # function's counters are reset.
            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.statistics.add()
                    fstats.fname = fname
                    fstats.runtime = runtimes[fname]
                    fstats.call_count = exec_counts[fname]

                runtimes[fname] = 0.0
                exec_counts[fname] = 0

            sckt = pusher_cache.get(
                sutils._get_statistics_report_address(mgmt_ip))
            sckt.send(stats.SerializeToString())

            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for. Iterate over a snapshot of
            # the keys because entries are deleted inside the loop.
            for fname in list(queue):
                if len(queue[fname]) == 0 and fname not in status.functions:
                    del queue[fname]
                    del pinned_functions[fname]
                    del runtimes[fname]
                    del exec_counts[fname]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils._get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)

                return 0
Beispiel #10
0
def executor(ip, mgmt_ip, schedulers, thread_id):
    '''
    Runs the event loop for a single executor thread.

    Binds REP sockets for pin, unpin, function execution and DAG schedule
    requests and PULL sockets for DAG triggers and the self-depart signal,
    dispatches each incoming message to its handler, and periodically
    reports utilization and per-function runtimes to the management server.

    ip: The IP address of this machine; recorded in the thread's status,
    compared against the target location of incoming DAG schedules, and sent
    to the management server when the thread departs
    mgmt_ip: The IP address of the management server, used to look up the
    addresses that reports and the departure notice are sent to
    schedulers: The schedulers that status updates are pushed to
    thread_id: This thread's ID; added to each base port as an offset

    Returns 0 once a depart request has been received and the request queue
    is empty.
    '''
    logging.basicConfig(filename='log_executor.txt', level=logging.INFO)

    ctx = zmq.Context(1)

    pin_socket = ctx.socket(zmq.REP)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = ctx.socket(zmq.REP)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                      (sutils.UNPIN_PORT + thread_id))

    exec_socket = ctx.socket(zmq.REP)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                     (sutils.FUNC_EXEC_PORT + thread_id))

    dag_queue_socket = ctx.socket(zmq.REP)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                          (sutils.DAG_QUEUE_PORT + thread_id))

    dag_exec_socket = ctx.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                         (sutils.DAG_EXEC_PORT + thread_id))

    self_depart_socket = ctx.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(ctx, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    client = IpcAnnaClient()

    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    utils._push_status(schedulers, pusher_cache, status)

    departing = False

    # Maps each function name to a map from execution ID to the DAG schedule
    # that was queued for it.
    queue = {}

    # Tracks the actual function objects that we are storing here.
    pinned_functions = {}

    # Tracks the accumulated runtime cost of executing each DAG function.
    runtimes = {}

    # Metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    }
    total_occupancy = 0.0

    while True:
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            pin(pin_socket, client, status, pinned_functions, runtimes)
            utils._push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, pinned_functions, runtimes)
            utils._push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()

            schedule = DagSchedule()
            schedule.ParseFromString(dag_queue_socket.recv())
            fname = schedule.target_function

            logging.info('Received a schedule for DAG %s, function %s.' %
                         (schedule.dag.name, fname))

            # If we are trying to kill this node or unpin this function, we
            # don't accept requests anymore for DAG schedules; this also
            # checks to make sure it's the right IP for the target.
            rejected = (
                not status.running
                or (fname not in status.functions
                    and fname in queue
                    and schedule.id not in queue[fname])
                or schedule.locations[fname].split(':')[0] != ip)

            # Reply in both branches and then fall through, so that a
            # rejected schedule still counts toward occupancy and does not
            # skip the remaining sockets or the periodic report.
            if rejected:
                sutils.error.error = INVALID_TARGET
                dag_queue_socket.send(sutils.error.SerializeToString())
            else:
                if fname not in queue:
                    queue[fname] = {}

                queue[fname][schedule.id] = schedule
                dag_queue_socket.send(sutils.ok_resp)

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()
            trigger = DagTrigger()
            trigger.ParseFromString(dag_exec_socket.recv())

            fname = trigger.target_function

            exec_dag_function(pusher_cache, client, trigger,
                              pinned_functions[fname],
                              queue[fname][trigger.id])

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed
            runtimes[fname] += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
                zmq.POLLIN:
            # The content of this message does not matter; receiving it is
            # the signal.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils._push_status(schedulers, pusher_cache, status)

            departing = True

        # Periodically report function occupancy.
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            sckt = pusher_cache.get(utils._get_util_report_address(mgmt_ip))
            sckt.send(status.SerializeToString())

            logging.info('Total thread occupancy: %.4f%%' % (utilization))

            for event in event_occupancy:
                occ = event_occupancy[event]
                logging.info('Event %s occupancy: %.4f%%' % (event, occ))
                event_occupancy[event] = 0.0

            # Report the runtime accumulated for each function in this epoch
            # and reset it for the next one.
            stats = ExecutorStatistics()
            for fname in runtimes:
                fstats = stats.statistics.add()
                fstats.fname = fname
                fstats.runtime = runtimes[fname]

                runtimes[fname] = 0.0

            sckt = pusher_cache.get(
                sutils._get_statistics_report_address(mgmt_ip))
            sckt.send(stats.SerializeToString())

            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for. Iterate over a snapshot of
            # the keys because entries are deleted inside the loop.
            for fname in list(queue):
                if len(queue[fname]) == 0 and fname not in status.functions:
                    del queue[fname]
                    del pinned_functions[fname]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils._get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)

                return 0
Beispiel #11
0
def executor(ip, mgmt_ip, schedulers, thread_id):
    '''
    Runs the event loop for a single executor thread.

    Binds one PULL socket per request type (pin, unpin, function execution,
    DAG schedule, DAG trigger, self-depart), dispatches each incoming message
    to its handler, and periodically reports utilization and per-function and
    per-DAG statistics.

    ip: The IP address of this machine; recorded in the thread's status,
    passed to the user library, and sent to the management server when the
    thread departs
    mgmt_ip: The IP address of the management server; if None, we are
    running in local mode, so a local AnnaTcpClient is used and statistics
    are logged rather than sent
    schedulers: The schedulers that status updates are pushed to
    thread_id: This thread's ID; added to each base port as an offset and
    passed to the KVS client and the user library

    This function does not return: once a depart request has been received
    and the request queue is empty, the process exits via os._exit(1).
    '''
    logging.basicConfig(filename='log_executor.txt',
                        level=logging.INFO,
                        format='%(asctime)s %(message)s')

    context = zmq.Context(1)

    pin_socket = context.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = context.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                      (sutils.UNPIN_PORT + thread_id))

    exec_socket = context.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                     (sutils.FUNC_EXEC_PORT + thread_id))

    dag_queue_socket = context.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                          (sutils.DAG_QUEUE_PORT + thread_id))

    dag_exec_socket = context.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                         (sutils.DAG_EXEC_PORT + thread_id))

    self_depart_socket = context.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    # If the management IP is set to None, that means that we are running in
    # local mode, so we use a regular AnnaTcpClient rather than an IPC client.
    if mgmt_ip:
        client = AnnaIpcClient(thread_id, context)
    else:
        client = AnnaTcpClient('127.0.0.1', '127.0.0.1', local=True, offset=1)

    user_library = CloudburstUserLibrary(context, pusher_cache, ip, thread_id,
                                         client)

    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    utils.push_status(schedulers, pusher_cache, status)

    departing = False

    # Maintains a request queue for each function pinned on this executor. Each
    # function will have a set of request IDs mapped to it, and this map stores
    # a schedule for each request ID.
    queue = {}

    # Tracks the actual function objects that are pinned to this executor.
    function_cache = {}

    # Tracks runtime cost of executing a DAG function.
    runtimes = {}

    # If multiple triggers are necessary for a function, track the triggers as
    # we receive them. This is also used if a trigger arrives before its
    # corresponding schedule.
    received_triggers = {}

    # Tracks when we received a function request, so we can report end-to-end
    # latency for the whole execution. Entries are removed once the request
    # has been executed.
    receive_times = {}

    # Tracks the number of requests we are finishing for each function pinned
    # here.
    exec_counts = {}

    # Tracks the end-to-end runtime of each DAG request for which we are the
    # sink function.
    dag_runtimes = {}

    # A map with KVS keys and their corresponding deserialized payloads.
    cache = {}

    # Internal metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    }
    total_occupancy = 0.0

    while True:
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            pin(pin_socket, pusher_cache, client, status, function_cache,
                runtimes, exec_counts, user_library)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, function_cache, runtimes, exec_counts)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, user_library, cache,
                          function_cache)
            user_library.close()

            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()

            schedule = DagSchedule()
            schedule.ParseFromString(dag_queue_socket.recv())
            fname = schedule.target_function

            logging.info('Received a schedule for DAG %s (%s), function %s.' %
                         (schedule.dag.name, schedule.id, fname))

            if fname not in queue:
                queue[fname] = {}

            queue[fname][schedule.id] = schedule

            trkey = (schedule.id, fname)
            if trkey not in receive_times:
                receive_times[trkey] = time.time()

            # In case we receive the trigger before we receive the schedule, we
            # can trigger from this operation as well.
            if (trkey in received_triggers and
                    len(received_triggers[trkey]) == len(schedule.triggers)):

                exec_dag_function(pusher_cache, client,
                                  received_triggers[trkey],
                                  function_cache[fname], schedule,
                                  user_library, dag_runtimes, cache)
                user_library.close()

                del received_triggers[trkey]
                del queue[fname][schedule.id]

                fend = time.time()
                # The request is finished, so drop its timestamp instead of
                # keeping it for the lifetime of the process.
                fstart = receive_times.pop(trkey)
                runtimes[fname].append(fend - fstart)
                exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()
            trigger = DagTrigger()
            trigger.ParseFromString(dag_exec_socket.recv())

            fname = trigger.target_function
            logging.info('Received a trigger for schedule %s, function %s.' %
                         (trigger.id, fname))

            key = (trigger.id, fname)
            if key not in received_triggers:
                received_triggers[key] = {}

            if key not in receive_times:
                receive_times[key] = time.time()

            received_triggers[key][trigger.source] = trigger
            if fname in queue and trigger.id in queue[fname]:
                schedule = queue[fname][trigger.id]
                if len(received_triggers[key]) == len(schedule.triggers):
                    exec_dag_function(pusher_cache, client,
                                      received_triggers[key],
                                      function_cache[fname], schedule,
                                      user_library, dag_runtimes, cache)
                    user_library.close()

                    del received_triggers[key]
                    del queue[fname][trigger.id]

                    fend = time.time()
                    # The request is finished, so drop its timestamp.
                    fstart = receive_times.pop(key)
                    runtimes[fname].append(fend - fstart)
                    exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
                zmq.POLLIN:
            # This message does not matter.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils.push_status(schedulers, pusher_cache, status)

            departing = True

        # Periodically report function occupancy.
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            cache.clear()

            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            # Periodically report my status to schedulers with the utilization
            # set.
            utils.push_status(schedulers, pusher_cache, status)

            logging.info('Total thread occupancy: %.6f' % (utilization))

            for event in event_occupancy:
                occ = event_occupancy[event] / (report_end - report_start)
                logging.info('\tEvent %s occupancy: %.6f' % (event, occ))
                event_occupancy[event] = 0.0

            # Only functions that ran in this epoch are reported, but every
            # function's counters are reset.
            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.functions.add()
                    fstats.name = fname
                    fstats.call_count = exec_counts[fname]
                    fstats.runtime.extend(runtimes[fname])

                runtimes[fname].clear()
                exec_counts[fname] = 0

            for dname in dag_runtimes:
                dstats = stats.dags.add()
                dstats.name = dname

                dstats.runtimes.extend(dag_runtimes[dname])

                dag_runtimes[dname].clear()

            # If we are running in cluster mode, mgmt_ip will be set, and we
            # will report our status and statistics to it. Otherwise, we will
            # write the statistics to the local log.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

                sckt = pusher_cache.get(utils.get_util_report_address(mgmt_ip))
                sckt.send(status.SerializeToString())
            else:
                logging.info(stats)

            status.ClearField('utilization')
            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for. The queue entries themselves
            # are removed in a second pass so that the queue is not resized
            # while it is being iterated.
            del_list = []
            for fname in queue:
                if len(queue[fname]) == 0 and fname not in status.functions:
                    del_list.append(fname)
                    del function_cache[fname]
                    del runtimes[fname]
                    del exec_counts[fname]

            for fname in del_list:
                del queue[fname]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            # NOTE(review): mgmt_ip is None in local mode -- confirm that
            # get_depart_done_addr handles that.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils.get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)

                # We specifically pass 1 as the exit code when ending our
                # process so that the wrapper script does not restart us.
                os._exit(1)
Beispiel #12
0
def scheduler(ip, mgmt_ip, route_addr, policy_type):
    '''
    Run the scheduler's main event loop. This function does not return.

    The loop polls a set of ZeroMQ sockets (connect, function create/call,
    DAG create/call/delete, list, executor status, scheduler updates and
    continuations) and dispatches each ready socket to its handler. It also
    periodically refreshes policy metadata and reports statistics.

    ip: The IP address of this scheduler; passed to the KVS client and the
    policy engine, and used to skip ourselves when pushing updates to peers
    mgmt_ip: The address of the management server, or None when running in
    local mode
    route_addr: The routing address given to the KVS client; it is also
    returned verbatim to clients on the connect socket
    policy_type: Forwarded unchanged to DefaultCloudburstSchedulerPolicy
    '''

    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)
    context.set(zmq.MAX_SOCKETS, 10000)

    # A mapping from a DAG's name to a (protobuf representation, sources)
    # tuple.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = set()

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    # This socket handles function invocations that arrive via a queue,
    # mainly for storage events.
    func_call_queue_socket = context.socket(zmq.PULL)
    func_call_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                                (FUNC_CALL_QUEUE_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    # This socket is handed to the policy engine rather than registered with
    # the poller below.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    continuation_socket = context.socket(zmq.PULL)
    continuation_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.CONTINUATION_PORT))

    if not local:
        management_request_socket = context.socket(zmq.REQ)
        management_request_socket.setsockopt(zmq.RCVTIMEO, 500)
        # By setting this flag, zmq matches replies with requests.
        management_request_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        # Relax strict alternation between request and reply.
        # For detailed explanation, see here: http://api.zeromq.org/4-1:zmq-setsockopt
        management_request_socket.setsockopt(zmq.REQ_RELAXED, 1)
        management_request_socket.connect(
            sched_utils.get_scheduler_list_address(mgmt_ip))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(func_call_queue_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(continuation_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket,
                                              pusher_cache,
                                              kvs,
                                              ip,
                                              policy_type,
                                              local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored, but a REP socket must receive
            # before it is allowed to reply.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks
                and socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if func_call_queue_socket in socks and socks[
                func_call_queue_socket] == zmq.POLLIN:
            call_function_from_queue(func_call_queue_socket, pusher_cache,
                                     policy)

        if (dag_create_socket in socks
                and socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            start_t = int(time.time() * 1000000)
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            t = time.time()
            if name in last_arrivals:
                if name not in interarrivals:
                    interarrivals[name] = []

                interarrivals[name].append(t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
            else:
                # Handled in an else branch (rather than `continue`-ing
                # above) so that a request for an unknown DAG does not skip
                # the other ready sockets and the periodic work below.
                dag = dags[name]
                for fname in dag[0].functions:
                    call_frequency[fname.name] += 1

                response = call_dag(call, pusher_cache, dags, policy)
                sched_t = int(time.time() * 1000000)
                logging.info(
                    f'App function {name} recv: {start_t}, scheduled: {sched_t}')
                dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks
                and socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # NOTE(review): this retries the KVS read with no backoff
                    # until the payload no longer contains None -- confirm
                    # that is the intended way to wait for the DAG.
                    payload = kvs.get(dname)
                    while None in payload:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        if continuation_socket in socks and socks[continuation_socket] == \
                zmq.POLLIN:
            start_t = int(time.time() * 1000000)

            continuation = Continuation()
            continuation.ParseFromString(continuation_socket.recv())

            call = continuation.call
            call.name = continuation.name

            result = Value()
            result.ParseFromString(continuation.result)

            if call.name not in dags:
                # A continuation for a DAG we do not know about must not
                # take down the whole scheduler loop with a KeyError. This
                # is a PULL socket, so there is nobody to reply to; log the
                # problem and drop the message.
                logging.error(
                    'Received continuation for unknown DAG %s; dropping it.'
                    % (call.name))
            else:
                dag, sources = dags[call.name]
                for source in sources:
                    call.function_args[source].values.extend([result])

                call_dag(call, pusher_cache, dags, policy, continuation.id)
                sched_t = int(time.time() * 1000000)
                # Logged (not printed) so that this matches the equivalent
                # message emitted by the DAG call handler above.
                logging.info(
                    f'App function {call.name} recv: {start_t}, scheduled: {sched_t}'
                )
                for fname in dag.functions:
                    call_frequency[fname.name] += 1

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if not local:
                latest_schedulers = sched_utils.get_ip_set(
                    management_request_socket, False)
                # Keep the previous scheduler set if the lookup came back
                # empty.
                if latest_schedulers:
                    schedulers = latest_schedulers

        if end - start > REPORT_THRESHOLD:
            status = SchedulerStatus()
            for name in dags.keys():
                status.dags.append(name)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            # Push our view of the DAGs and function locations to every
            # other scheduler.
            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.debug('Reporting %d calls for function %s.' %
                              (call_frequency[fname], fname))

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())
            else:
                logging.info(stats)

            # The report interval is measured from the end of the last
            # report.
            start = time.time()
Beispiel #13
0
def scheduler(ip, mgmt_ip, route_addr):
    '''
    Run the scheduler's main event loop. This function does not return.

    The loop polls a set of ZeroMQ sockets (connect, function create/call,
    DAG create/call, list, executor status and scheduler updates) and
    dispatches each ready socket to its handler. Once THRESHOLD seconds have
    elapsed it refreshes cluster state, pushes the known DAG names to the
    other schedulers and reports call statistics to the management server.

    ip: The IP address of this scheduler; passed to the KVS client and used
    to skip ourselves when pushing updates to peers
    mgmt_ip: The address of the management server; used for cluster state
    lookups and as the statistics report target
    route_addr: The routing address given to the KVS client; it is also
    returned verbatim to clients on the connect socket
    '''
    logging.basicConfig(filename='log_scheduler.txt', level=logging.INFO)

    kvs = AnnaClient(route_addr, ip)

    key_cache_map = {}
    key_ip_map = {}
    ctx = zmq.Context(1)

    # Each dag consists of a set of functions and connections. Each one of
    # the functions is pinned to one or more nodes, which is tracked here.
    dags = {}
    thread_statuses = {}
    func_locations = {}

    connect_socket = ctx.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = ctx.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = ctx.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = ctx.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = ctx.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    list_socket = ctx.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = ctx.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = ctx.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    requestor_cache = SocketCache(ctx, zmq.REQ)
    pusher_cache = SocketCache(ctx, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)

    departed_executors = set()
    executors, schedulers = _update_cluster_state(requestor_cache, mgmt_ip,
                                                  departed_executors,
                                                  key_cache_map, key_ip_map,
                                                  kvs)

    # track how often each DAG function is called
    call_frequency = {}

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored, but a REP socket must receive
            # before it is allowed to reply.
            connect_socket.recv_string()
            # Reply with the routing address we were given; the parameter is
            # named route_addr.
            connect_socket.send_string(route_addr)

        if func_create_socket in socks and socks[
                func_create_socket] == zmq.POLLIN:
            create_func(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, requestor_cache, executors,
                          key_ip_map)

        if dag_create_socket in socks and socks[
                dag_create_socket] == zmq.POLLIN:
            create_dag(dag_create_socket, requestor_cache, kvs, executors,
                       dags, func_locations, call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())
            exec_id = generate_timestamp(0)

            dag = dags[call.name]
            for fname in dag[0].functions:
                # DAGs learned from another scheduler never had their
                # counters initialised here, so default to zero instead of
                # raising a KeyError.
                call_frequency[fname] = call_frequency.get(fname, 0) + 1

            accepted, error, rid = call_dag(call, requestor_cache,
                                            pusher_cache, dags, func_locations,
                                            key_ip_map)

            while not accepted:
                # the assumption here is that the request was not accepted
                # because the cluster was out of date -- so we update cluster
                # state before proceeding
                executors, schedulers = _update_cluster_state(
                    requestor_cache, mgmt_ip, departed_executors,
                    key_cache_map, key_ip_map, kvs)

                accepted, error, rid = call_dag(call, requestor_cache,
                                                pusher_cache, dags,
                                                func_locations, key_ip_map)

            resp = GenericResponse()
            resp.success = True
            resp.response_id = rid
            dag_call_socket.send(resp.SerializeToString())

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            logging.info('Received query for function list.')
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = FunctionList()
            resp.names.extend(utils._get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            key = (status.ip, status.tid)
            logging.info('Received status update from executor %s:%d.' %
                         (key[0], int(key[1])))

            if not status.running:
                # this means that this node is currently departing, so we
                # remove it from all of our metadata tracking; pop/discard
                # are used so that a departure notice from an executor we
                # never tracked does not crash the loop
                old_status = thread_statuses.pop(key, None)

                executors.discard(key)
                departed_executors.add(status.ip)

                if old_status is not None:
                    for fname in old_status.functions:
                        # Locations are stored as (ip, tid) tuples, so the
                        # tuple itself has to be removed.
                        if fname in func_locations:
                            func_locations[fname].discard(key)
            elif key not in thread_statuses:
                thread_statuses[key] = status

                if key not in executors:
                    executors.add(key)
            elif thread_statuses[key] != status:
                # remove all the old function locations, and all the new ones
                # -- there will probably be a large overlap, but this shouldn't
                # be much different than calculating two different set
                # differences anyway
                for func in thread_statuses[key].functions:
                    if func in func_locations:
                        func_locations[func].discard(key)

                for func in status.functions:
                    if func not in func_locations:
                        func_locations[func] = set()

                    func_locations[func].add(key)

                thread_statuses[key] = status

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            logging.info('Received update from another scheduler.')
            ks = KeySet()
            ks.ParseFromString(sched_update_socket.recv())

            # retrieve any DAG that some other scheduler knows about that we do
            # not yet know about
            for dname in ks.keys:
                if dname not in dags:
                    dag = Dag()
                    dag.ParseFromString(kvs.get(dname).value)

                    # NOTE(review): the DAG call handler indexes entries as
                    # dag[0], which suggests create_dag stores tuples, but a
                    # bare Dag is stored here -- confirm the intended shape.
                    dags[dname] = dag

        end = time.time()
        if end - start > THRESHOLD:
            executors, schedulers = _update_cluster_state(
                requestor_cache, mgmt_ip, departed_executors, key_cache_map,
                key_ip_map, kvs)

            dag_names = KeySet()
            for name in dags.keys():
                dag_names.keys.append(name)
            msg = dag_names.SerializeToString()

            # Push the names of the DAGs we know about to every other
            # scheduler.
            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        utils._get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.statistics.add()
                fstats.fname = fname
                fstats.call_count = call_frequency[fname]

                call_frequency[fname] = 0

            sckt = pusher_cache.get(
                sutils._get_statistics_report_address(mgmt_ip))
            sckt.send(stats.SerializeToString())

            start = time.time()