예제 #1
    # Test that a status update from an executor is reflected in the policy's
    # metadata: the assertions at the end expect the executor key to be absent
    # from `unpinned_cpu_executors` and present in both
    # `function_locations[function_name]` and `backoff`.
    # NOTE(review): the docstring delimiters and the calls that hand the
    # status to the policy engine appear to be missing from this excerpt --
    # restore them from the original test before running it.
    def test_process_status(self):
        This test ensures that when a new status update is received from an
        executor, the local server metadata is correctly updated in the normal
        # Construct a new thread status to pass into the policy engine.
        function_name = 'square'
        status = ThreadStatus()
        status.running = True
        status.ip = self.ip
        status.tid = 1
        status.utilization = 0.10

        # Process the newly created status.
        # NOTE(review): no processing call follows this comment in the excerpt.

        # Reuse the same status object for a second thread (tid 2) that reports
        # a much higher utilization (0.90).
        status.tid = 2
        status.utilization = 0.90

        # Executors are identified by an (ip, tid) tuple.
        key = (status.ip, status.tid)

        self.assertTrue(key not in self.policy.unpinned_cpu_executors)
        self.assertTrue(key in self.policy.function_locations[function_name])
        self.assertTrue(key in self.policy.backoff)
예제 #2
    # Per-test fixture: creates a mock pusher cache, two mock ZMQ sockets and a
    # mock Anna KVS client, starts constructing the scheduler policy, and
    # records the (ip, tid) key of a default executor with thread id 0.
    def setUp(self):
        self.pusher_cache = zmq_utils.MockPusherCache()
        self.socket = zmq_utils.MockZmqSocket()
        self.pin_socket = zmq_utils.MockZmqSocket()

        self.kvs_client = kvs_client.MockAnnaClient()
        self.ip = ''

        # NOTE(review): the remaining constructor arguments and the closing
        # parenthesis of this call are missing from the excerpt.
        self.policy = DefaultCloudburstSchedulerPolicy(self.pin_socket,

        # Add an executor to the policy engine by default.
        # NOTE(review): only the status and its key are built below; the
        # statement that registers the executor with the policy is not visible
        # here.
        status = ThreadStatus()
        status.ip = self.ip
        status.tid = 0
        self.executor_key = (status.ip, status.tid)
예제 #3
    def setUp(self):
        """Create fresh mocks, a running thread status and empty executor
        metadata before every test."""
        executor_ip = ''
        thread_id = 0

        # Mock collaborators standing in for the KVS and the ZMQ plumbing.
        self.kvs_client = kvs_client.MockAnnaClient()
        self.socket = zmq_utils.MockZmqSocket()
        self.pusher_cache = zmq_utils.MockPusherCache()

        # Status of the single executor thread under test, marked as running.
        self.ip = executor_ip
        thread_status = ThreadStatus()
        thread_status.ip = executor_ip
        thread_status.tid = thread_id
        thread_status.running = True
        self.status = thread_status

        # Per-executor bookkeeping starts out empty for each test.
        self.pinned_functions, self.runtimes, self.exec_counts = {}, {}, {}

        mock_context = zmq_utils.MockZmqContext()
        self.user_library = CloudburstUserLibrary(
            mock_context, self.pusher_cache, self.ip, thread_id,
            self.kvs_client)
예제 #4
    # Test of the policy's periodic metadata update: seeds the policy with two
    # executors (one with a backoff timestamp 10 seconds in the past, one with
    # a current timestamp), then asserts that only the new executor's backoff
    # entry and 10 of its running-count timestamps remain and that the cached
    # key locations match what was published for each executor.
    # NOTE(review): several lines are missing from this excerpt (docstring
    # delimiters, a loop body, the KVS put calls and the update call itself),
    # so the method does not parse as shown.
    def test_metadata_update(self):
        This test calls the periodic metadata update protocol and ensures that
        the correct metadata is removed from the system and that the correct
        metadata is retrieved/updated from the KVS.
        # Create two executor threads on separate machines.
        # NOTE(review): both IPs are empty strings in this copy, so the two
        # executors differ only by thread id -- confirm the intended addresses.
        old_ip = ''
        new_ip = ''
        old_executor = (old_ip, 1)
        new_executor = (new_ip, 2)

        old_status = ThreadStatus()
        old_status.ip = old_ip
        old_status.tid = 1
        old_status.running = True

        new_status = ThreadStatus()
        new_status.ip = new_ip
        new_status.tid = 2
        new_status.running = True

        self.policy.thread_statuses[old_executor] = old_status
        self.policy.thread_statuses[new_executor] = new_status

        # Add two executors, one with an old backoff and one with a new
        # time.
        self.policy.backoff[old_executor] = time.time() - 10
        self.policy.backoff[new_executor] = time.time()

        # For the new executor, add 10 old running times and 10 new ones.
        self.policy.running_counts[new_executor] = set()
        for _ in range(10):
            self.policy.running_counts[new_executor].add(time.time() - 10)

        # NOTE(review): the body of this second loop is missing from the
        # excerpt.
        for _ in range(10):

        # Publish some caching metadata into the KVS for each executor.
        # 'key3' appears in both executors' key sets.
        old_set = StringSet()
        old_set.keys.extend(['key1', 'key2', 'key3'])
        new_set = StringSet()
        new_set.keys.extend(['key3', 'key4', 'key5'])
        # NOTE(review): the two lines below are the tails of calls whose first
        # lines are missing; the metadata update call that should precede the
        # assertions is not visible either.
                            LWWPairLattice(0, old_set.SerializeToString()))
                            LWWPairLattice(0, new_set.SerializeToString()))


        # Check that the metadata has been correctly pruned.
        self.assertEqual(len(self.policy.backoff), 1)
        self.assertTrue(new_executor in self.policy.backoff)
        self.assertEqual(len(self.policy.running_counts[new_executor]), 10)

        # Check that the caching information is correct.
        # NOTE(review): unittest's assertTrue(expr, msg) treats its second
        # argument as the failure message, so these five checks only verify
        # that each length is non-zero, not that it equals 1 or 2. assertEqual
        # is presumably what was intended -- confirm before changing.
        self.assertTrue(len(self.policy.key_locations['key1']), 1)
        self.assertTrue(len(self.policy.key_locations['key2']), 1)
        self.assertTrue(len(self.policy.key_locations['key3']), 2)
        self.assertTrue(len(self.policy.key_locations['key4']), 1)
        self.assertTrue(len(self.policy.key_locations['key5']), 1)

        self.assertTrue(old_ip in self.policy.key_locations['key1'])
        self.assertTrue(old_ip in self.policy.key_locations['key2'])
        self.assertTrue(old_ip in self.policy.key_locations['key3'])
        self.assertTrue(new_ip in self.policy.key_locations['key3'])
        self.assertTrue(new_ip in self.policy.key_locations['key4'])
        self.assertTrue(new_ip in self.policy.key_locations['key5'])
예제 #5
    # Test that a status with running=False removes the executor from the
    # policy: afterwards its key must be gone from `thread_statuses` and
    # `unpinned_executors`, and the function's location set must be empty.
    # NOTE(review): the docstring delimiters and the call that processes the
    # second status appear to be missing from this excerpt.
    def test_process_status_not_running(self):
        This test passes in a status for a server that is leaving the system
        and ensures that all the metadata reflects this fact after the
        # Construct a new thread status to prime the policy engine with.
        function_name = 'square'
        status = ThreadStatus()
        status.running = True
        status.ip = self.ip
        status.tid = 1
        status.utilization = 0

        # Executors are identified by an (ip, tid) tuple.
        key = (status.ip, status.tid)

        # Add metadata to the policy engine to make it think that this node
        # used to have a pinned function.
        self.policy.function_locations[function_name] = {key}
        self.policy.thread_statuses[key] = status

        # Clear the status' fields to report that it is turning off.
        # A brand-new message with the same ip/tid is built, so no other
        # fields are set on it in this test.
        status = ThreadStatus()
        status.running = False
        status.ip = self.ip
        status.tid = 1

        # Process the status and check the metadata.
        # NOTE(review): no processing call follows this comment in the excerpt.

        self.assertTrue(key not in self.policy.thread_statuses)
        self.assertTrue(key not in self.policy.unpinned_executors)
        self.assertEqual(len(self.policy.function_locations[function_name]), 0)
예제 #6
    # Test that a fresh status from an executor previously recorded as having
    # a pinned function resets its metadata: afterwards the function's
    # location set must be empty and the executor key must be in
    # `unpinned_executors`.
    # NOTE(review): the docstring delimiters and the call that processes the
    # second status appear to be missing from this excerpt.
    def test_process_status_restart(self):
        This tests checks that when we receive a status update from a restarted
        executor, we correctly update its metadata.
        # Construct a new thread status to pass into the policy engine.
        function_name = 'square'
        status = ThreadStatus()
        status.running = True
        status.ip = self.ip
        status.tid = 1
        status.utilization = 0

        # Executors are identified by an (ip, tid) tuple.
        key = (status.ip, status.tid)

        # Add metadata to the policy engine to make it think that this node
        # used to have a pinned function.
        self.policy.function_locations[function_name] = {key}
        self.policy.thread_statuses[key] = status

        # Clear the status' pinned functions (i.e., restart).
        # The replacement status keeps the same ip/tid and is still running.
        status = ThreadStatus()
        status.ip = self.ip
        status.tid = 1
        status.running = True
        status.utilization = 0

        # Process the status and check the metadata.
        # NOTE(review): no processing call follows this comment in the excerpt.

        self.assertEqual(len(self.policy.function_locations[function_name]), 0)
        self.assertTrue(key in self.policy.unpinned_executors)
예제 #7
# Executor thread main loop. Binds one PULL socket per event type (pin, unpin,
# function execution, DAG schedule queue, DAG trigger, self-depart) at that
# event's base port plus `thread_id`, pushes an initial ThreadStatus to the
# schedulers, then polls the sockets forever, timing how long each event type
# keeps the thread busy and periodically reporting utilization and statistics.
#
# Parameters, as used in the visible code:
#   ip         -- copied into the ThreadStatus this thread reports.
#   mgmt_ip    -- truthy selects the cluster-mode branches; per the comments
#                 below, None means local mode.
#   schedulers -- forwarded to utils.push_status and exec_dag_function.
#   thread_id  -- port offset for every bound socket and the reported tid.
#
# NOTE(review): this excerpt has lost a number of lines (apparently several
# `try:` / `else:` lines, closing brackets, call heads and the bodies of some
# branches), so it does not parse as shown. Restore them from the original
# file before changing any logic here.
def executor(ip, mgmt_ip, schedulers, thread_id):
    # logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(asctime)s %(message)s')
                        format='%(asctime)s %(message)s')

    # Check what resources we have access to, set as an environment variable.
    # NOTE(review): as written both assignments sit inside the `if`; an
    # `else:` appears to be missing between them.
    if os.getenv('EXECUTOR_TYPE', 'CPU') == 'GPU':
        exec_type = GPU
        exec_type = CPU

    context = zmq.Context(1)
    # NOTE(review): `poller` is created again further down before any socket
    # is registered, so this first instance is never used.
    poller = zmq.Poller()

    # Every socket below is a PULL socket bound at <base port> + thread_id.
    pin_socket = context.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = context.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                      (sutils.UNPIN_PORT + thread_id))

    exec_socket = context.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                     (sutils.FUNC_EXEC_PORT + thread_id))

    dag_queue_socket = context.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                          (sutils.DAG_QUEUE_PORT + thread_id))

    dag_exec_socket = context.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                         (sutils.DAG_EXEC_PORT + thread_id))

    self_depart_socket = context.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    # Cache of outbound PUSH sockets.
    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    # If the management IP is set to None, that means that we are running in
    # local mode, so we use a regular AnnaTcpClient rather than an IPC client.
    # NOTE(review): the branch structure below is incomplete in this excerpt
    # (an unclosed AnnaTcpClient call and apparently missing `else:` lines),
    # so which client and which value of `local` apply in each mode cannot be
    # read off the code as shown.
    has_ephe = False
    if mgmt_ip:
        if 'STORAGE_OR_DEFAULT' in os.environ and os.environ[
                'STORAGE_OR_DEFAULT'] == '0':
            client = AnnaTcpClient(os.environ['ROUTE_ADDR'],
            has_ephe = True
            client = AnnaIpcClient(thread_id, context)
        # force_remote_anna = 1
        # if 'FORCE_REMOTE' in os.environ:
        #     force_remote_anna = int(os.environ['FORCE_REMOTE'])

        # if force_remote_anna == 0: # remote anna only
        #     client = AnnaTcpClient(os.environ['ROUTE_ADDR'], ip, local=False, offset=thread_id)
        # elif force_remote_anna == 1: # anna cache
        #     client = AnnaIpcClient(thread_id, context)
        # elif force_remote_anna == 2: # control both cache and remote anna
        #     remote_client = AnnaTcpClient(os.environ['ROUTE_ADDR'], ip, local=False, offset=thread_id)
        #     cache_client = AnnaIpcClient(thread_id, context)
        #     client = cache_client
        #     user_library = CloudburstUserLibrary(context, pusher_cache, ip, thread_id, (cache_client, remote_client))

        local = False
        client = AnnaTcpClient('', '', local=True, offset=1)
        local = True

    # NOTE(review): the remaining arguments of this call are missing.
    user_library = CloudburstUserLibrary(context,

    # Announce this thread to the schedulers as running before serving any
    # requests.
    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    status.type = exec_type
    utils.push_status(schedulers, pusher_cache, status)

    # Set to True once a self-depart message has been handled.
    departing = False

    # Maintains a request queue for each function pinned on this executor. Each
    # function will have a set of request IDs mapped to it, and this map stores
    # a schedule for each request ID.
    queue = {}

    # Tracks the actual function objects that are pinned to this executor.
    function_cache = {}

    # Tracks runtime cost of executing a DAG function.
    runtimes = {}

    # If multiple triggers are necessary for a function, track the triggers as
    # we receive them. This is also used if a trigger arrives before its
    # corresponding schedule.
    received_triggers = {}

    # Tracks when we received a function request, so we can report end-to-end
    # latency for the whole execution.
    receive_times = {}

    # Tracks the number of requests we are finishing for each function pinned
    # here.
    exec_counts = {}

    # Tracks the end-to-end runtime of each DAG request for which we are the
    # sink function.
    dag_runtimes = {}

    # A map with KVS keys and their corresponding deserialized payloads.
    cache = {}

    # A map which tracks the most recent DAGs for which we have finished our
    # work.
    finished_executions = {}

    # The set of pinned functions and whether they support batching. NOTE: This
    # is only a set for local mode -- in cluster mode, there will only be one
    # pinned function per executor.
    batching = False

    # Internal metadata to track thread utilization.
    # `event_occupancy` accumulates busy seconds per event type and
    # `total_occupancy` across all of them; both are reset after each report.
    # NOTE(review): the closing brace of the dict literal is missing.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    total_occupancy = 0.0

    while True:
        # Wait for activity on any registered socket; pyzmq documents the
        # poll timeout as being in milliseconds.
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            # pin() returns the updated batching flag; the (possibly changed)
            # status is pushed to the schedulers straight afterwards.
            batching = pin(pin_socket, pusher_cache, client, status,
                           function_cache, runtimes, exec_counts, user_library,
                           local, batching)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, function_cache, runtimes, exec_counts)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            # logging.info(f'Executor timer. exec_socket recv: {work_start}')
            # NOTE(review): the call that actually executes the function is
            # not present in this excerpt.

            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()
                f'Executor timer. dag_queue_socket recv: {work_start}')
            # In order to effectively support batching, we have to make sure we
            # dequeue lots of schedules in addition to lots of triggers. Right
            # now, we're not going to worry about supporting batching here,
            # just on the trigger dequeue side, but we still have to dequeue
            # all schedules we've received. We just process them one at a time.
            while True:
                schedule = DagSchedule()
                    msg = dag_queue_socket.recv(zmq.DONTWAIT)
                except zmq.ZMQError as e:
                    if e.errno == zmq.EAGAIN:
                        break  # There are no more messages.
                        raise e  # Unexpected error.

                # NOTE(review): the statement that parses `msg` into
                # `schedule` is not visible in this excerpt.
                fname = schedule.target_function

                    'Received a schedule for DAG %s (%s), function %s.' %
                    (schedule.dag.name, schedule.id, fname))

                # Schedules are stored per function name, then per schedule id.
                if fname not in queue:
                    queue[fname] = {}

                queue[fname][schedule.id] = schedule

                # Record the first time this (schedule id, function) pair was
                # seen, whether via schedule or trigger.
                if (schedule.id, fname) not in receive_times:
                    receive_times[(schedule.id, fname)] = time.time()

                # In case we receive the trigger before we receive the schedule, we
                # can trigger from this operation as well.
                trkey = (schedule.id, fname)
                fref = None

                # Check to see what type of execution this function is.
                for ref in schedule.dag.functions:
                    if ref.name == fname:
                        fref = ref

                # Run now if triggers already arrived for this schedule and
                # either all of them are in or the function is MULTIEXEC.
                if (trkey in received_triggers and
                    ((len(received_triggers[trkey]) == len(schedule.triggers))
                     or (fref.type == MULTIEXEC))):

                    triggers = list(received_triggers[trkey].values())

                    if fname not in function_cache:
                        logging.error('%s not in function cache', fname)
                        utils.generate_error_response(schedule, client, fname)
                    exec_start = time.time()
                    # logging.info(f'Executor timer. dag_queue_socket exec_dag: {exec_start}')
                    # We don't support actual batching for when we receive a
                    # schedule before a trigger, so everything is just a batch of
                    # size 1 if anything.
                    success = exec_dag_function(pusher_cache, client,
                                                [schedule], user_library,
                                                dag_runtimes, cache,
                                                schedulers, batching)[0]

                    del received_triggers[trkey]
                    if success:
                        del queue[fname][schedule.id]

                        # `fstart` is looked up but the runtime appended below
                        # is measured from `work_start`.
                        fend = time.time()
                        fstart = receive_times[(schedule.id, fname)]
                        runtimes[fname].append(fend - work_start)
                        exec_counts[fname] += 1

                        finished_executions[(schedule.id, fname)] = time.time()

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()
            # logging.info(f'Executor timer. dag_exec_socket recv: {work_start}')

            # How many messages to dequeue -- BATCH_SIZE_MAX or 1 depending on
            # the function configuration.
            if batching:
                count = BATCH_SIZE_MAX
                count = 1

            trigger_keys = set()

            for _ in range(count):  # Dequeue count number of messages.
                trigger = DagTrigger()

                    msg = dag_exec_socket.recv(zmq.DONTWAIT)
                except zmq.ZMQError as e:
                    if e.errno == zmq.EAGAIN:  # There are no more messages.
                        raise e  # Unexpected error.


                # We have received a repeated trigger for a function that has
                # already finished executing.
                # NOTE(review): `finished_executions` is keyed by
                # (id, function name) tuples where it is written, but this
                # membership test uses `trigger.id` alone -- confirm intended.
                if trigger.id in finished_executions:

                fname = trigger.target_function
                    'Received a trigger for schedule %s, function %s.' %
                    (trigger.id, fname))

                key = (trigger.id, fname)
                if key not in received_triggers:
                    received_triggers[key] = {}

                if (trigger.id, fname) not in receive_times:
                    receive_times[(trigger.id, fname)] = time.time()

                # Triggers for one (id, function) are stored by their source.
                received_triggers[key][trigger.source] = trigger

            # Only execute the functions for which we have received a schedule.
            # Everything else will wait.
            # Iterates over a list copy because the set is mutated in the loop.
            for tid, fname in list(trigger_keys):
                if fname not in queue or tid not in queue[fname]:
                    trigger_keys.remove((tid, fname))

            if len(trigger_keys) == 0:

            fref = None
            schedule = queue[fname][list(trigger_keys)[0]
                                    [0]]  # Pick a random schedule to check.
            # Check to see what type of execution this function is.
            for ref in schedule.dag.functions:
                if ref.name == fname:
                    fref = ref

            # Compile a list of all the trigger sets for which we have
            # enough triggers.
            # NOTE(review): nothing visible here appends to `trigger_sets` or
            # `schedules`; those statements appear to be missing.
            trigger_sets = []
            schedules = []
            for key in trigger_keys:
                if (len(received_triggers[key]) == len(schedule.triggers)) or \
                        fref.type == MULTIEXEC:

                    if fref.type == MULTIEXEC:
                        triggers = [trigger]
                        triggers = list(received_triggers[key].values())

                    if fname not in function_cache:
                        logging.error('%s not in function cache', fname)
                        utils.generate_error_response(schedule, client, fname)

                    schedule = queue[fname][key[0]]

            exec_start = time.time()
            # logging.info(f'Executor timer. dag_exec_socket exec_dag: {exec_start}')
            # Pass all of the trigger_sets into exec_dag_function at once.
            # We also include the batching variable to make sure we know
            # whether to pass lists into the fn or not.
            if len(trigger_sets) > 0:
                successes = exec_dag_function(pusher_cache, client,
                                              function_cache[fname], schedules,
                                              user_library, dag_runtimes,
                                              cache, schedulers, batching)
                del received_triggers[key]

                for key, success in zip(trigger_keys, successes):
                    if success:
                        del queue[fname][key[0]]  # key[0] is trigger.id.

                        fend = time.time()
                        fstart = receive_times[key]

                        # Elapsed time for the batch split evenly across its
                        # trigger keys.
                        average_time = (fend - work_start) / len(trigger_keys)

                        exec_counts[fname] += 1

                        finished_executions[(schedule.id, fname)] = time.time()

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
            # This message does not matter.

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            # Tell the schedulers this thread is no longer running; the actual
            # exit is handled in the periodic report section below.
            status.running = False
            utils.push_status(schedulers, pusher_cache, status)

            departing = True

        # periodically report function occupancy
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            # Trim the payload cache back to 100 entries by deleting the keys
            # that come first in the dict's iteration order.
            if len(cache) > 100:
                extra_keys = list(cache.keys())[:len(cache) - 100]
                for key in extra_keys:
                    del cache[key]

            # Fraction of the reporting window this thread spent handling
            # events.
            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            # Periodically report my status to schedulers with the utilization
            # set.
            utils.push_status(schedulers, pusher_cache, status)

            logging.debug('Total thread occupancy: %.6f' % (utilization))

            for event in event_occupancy:
                occ = event_occupancy[event] / (report_end - report_start)
                logging.debug('\tEvent %s occupancy: %.6f' % (event, occ))
                event_occupancy[event] = 0.0

            # Per-function call counts are reported only when non-zero and are
            # reset for the next window either way.
            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.functions.add()
                    fstats.name = fname
                    fstats.call_count = exec_counts[fname]

                exec_counts[fname] = 0

            for dname in dag_runtimes:
                dstats = stats.dags.add()
                dstats.name = dname



            # If we are running in cluster mode, mgmt_ip will be set, and we
            # will report our status and statistics to it. Otherwise, we will
            # write to the local conf file
            if mgmt_ip:
                sckt = pusher_cache.get(

                sckt = pusher_cache.get(utils.get_util_report_address(mgmt_ip))

            # Start a new reporting window.
            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for.
            # NOTE(review): nothing visible appends to `del_list` in either of
            # the two collection loops below; those statements appear to be
            # missing.
            del_list = []
            for fname in queue:
                if len(queue[fname]) == 0 and fname not in status.functions:
                    del function_cache[fname]
                    del runtimes[fname]
                    del exec_counts[fname]

            for fname in del_list:
                del queue[fname]

            # Forget finished executions that are more than 10 seconds old.
            del_list = []
            for tid in finished_executions:
                if (time.time() - finished_executions[tid]) > 10:

            for tid in del_list:
                del finished_executions[tid]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils.get_depart_done_addr(mgmt_ip))

                # We specifically pass 1 as the exit code when ending our
                # process so that the wrapper script does not restart us.
예제 #8
# Scheduler main loop. Creates an Anna KVS client, binds REP sockets for the
# connect, function create/call, DAG create/call/delete and list endpoints
# plus PULL sockets for executor status, scheduler updates and
# `pin_accept_socket`, then polls and dispatches forever while tracking
# per-function call counts and per-DAG interarrival times for periodic
# reporting.
#
# Parameters, as used in the visible code:
#   ip         -- passed to AnnaTcpClient and compared against other
#                 schedulers' addresses so this scheduler skips itself.
#   mgmt_ip    -- None means local mode; truthy enables the cluster-only
#                 branches.
#   route_addr -- passed to AnnaTcpClient.
#
# NOTE(review): this excerpt has lost a number of lines (unclosed calls,
# missing message parsing and reply statements, empty loop bodies), so it
# does not parse as shown. Restore them from the original file before
# changing any logic here.
def scheduler(ip, mgmt_ip, route_addr):

    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    # Random identifier used as the prefix of this scheduler's report keys.
    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)

    # A mapping from a DAG's name to its protobuf representation.
    # Values are stored as (dag, sched_utils.find_dag_source(dag)) tuples
    # where this function adds them.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = []

    # Request/reply endpoints.
    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    # One-way (PULL) endpoints.
    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    # NOTE(review): the port argument and closing parenthesis of this bind
    # call are missing from the excerpt.
    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %

    # This socket gets a 500 receive timeout (RCVTIMEO) and is not registered
    # with the poller below; it is handed to the policy instead.
    # NOTE(review): the bind call is cut off here as well.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %

    requestor_cache = SocketCache(context, zmq.REQ)
    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)

    # Start the policy engine.
    # NOTE(review): the remaining constructor arguments are missing.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket,

    # Reference time for both periodic checks at the bottom of the loop.
    start = time.time()

    while True:
        # Wait for activity on any registered socket; pyzmq documents the
        # poll timeout as being in milliseconds.
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # NOTE(review): the reply on this REP socket is not visible in
            # the excerpt.
            msg = connect_socket.recv_string()

        if (func_create_socket in socks
                and socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (dag_create_socket in socks
                and socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            # NOTE(review): the statement that receives and parses the request
            # into `call` is not visible in this excerpt.
            call = DagCall()

            name = call.name

            # Record the gap since the previous call to this DAG; the first
            # call to a DAG only seeds `last_arrivals`.
            t = time.time()
            if name in last_arrivals:
                if name not in interarrivals:
                    interarrivals[name] = []

                interarrivals[name].append(t - last_arrivals[name])

            last_arrivals[name] = t

            # Unknown DAG: build an error response.
            # NOTE(review): the statements that send it and skip the rest of
            # this branch are not visible in the excerpt.
            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG


            # dags[name] is a (dag, source) tuple; element 0 is the DAG itself.
            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)

        if (dag_delete_socket in socks
                and socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            # An empty request lists every function; otherwise the message is
            # used as a name prefix.
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))


        if exec_status_socket in socks and socks[exec_status_socket] == \
            status = ThreadStatus()


        if sched_update_socket in socks and socks[sched_update_socket] == \
            status = SchedulerStatus()

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # Retry the KVS read until the payload contains no None.
                    payload = kvs.get(dname)
                    while None in payload:
                        payload = kvs.get(dname)

                    # NOTE(review): the statement that parses the payload into
                    # `dag` is not visible in this excerpt.
                    dag = Dag()
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    # NOTE(review): `call_frequency` is keyed by the element
                    # itself here but by its `.name` in the DAG-call branch
                    # above -- confirm which key is intended.
                    for fname in dag.functions:
                        if fname not in call_frequency:
                            call_frequency[fname] = 0


        end = time.time()

        # Both periodic checks below compare against the same `start`, which
        # in the visible code is reset only at the end of the report branch.
        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if mgmt_ip:
                schedulers = sched_utils.get_ip_set(
                    requestor_cache, False)

        if end - start > REPORT_THRESHOLD:
            num_unique_executors = policy.get_unique_executors()
            key = scheduler_id + ':' + str(time.time())
            data = {'key': key, 'count': num_unique_executors}

            # Build the status shared with the other schedulers: one
            # function-location entry per (function, executor) pair known to
            # the policy.
            # NOTE(review): the body of the loop over DAG names is missing.
            status = SchedulerStatus()
            for name in dags.keys():

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            # Every scheduler except this one is addressed here.
            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(

            # Per-function call counts for this window; each counter is reset
            # after being copied into the report.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.' %
                             (call_frequency[fname], fname))

                call_frequency[fname] = 0

            # n recorded gaps between calls correspond to n + 1 calls.
            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1


            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(

            start = time.time()
예제 #9
def executor(ip, mgmt_ip, schedulers, thread_id):
    # Event loop for a single executor thread.
    #
    # Binds one PULL socket per message type (pin, unpin, function exec, DAG
    # queue, DAG exec, self-depart), each on a port offset by thread_id, then
    # polls them in a loop. Every handler is timed so that per-event and total
    # occupancy can be reported once more than REPORT_THRESH seconds have
    # passed since the previous report.
    #
    # Args:
    #   ip: stored in the ThreadStatus that is pushed to the schedulers and
    #       handed to the user library.
    #   mgmt_ip: management server address; a falsy value means local mode.
    #   schedulers: forwarded to utils.push_status as the recipients of status
    #       updates.
    #   thread_id: stored as status.tid and used as the port offset for every
    #       socket bound here.
    #
    # NOTE(review): this copy of the function is an excerpt with lines missing
    # (dangling call continuations, unterminated calls, an unclosed dict
    # literal), so it does not parse as shown. Restore the missing lines from
    # the upstream file before relying on it.

    # NOTE(review): the opening of the call below is missing from this
    # excerpt -- presumably a logging.basicConfig(...) call; confirm.
                        format='%(asctime)s %(message)s')

    context = zmq.Context(1)
    poller = zmq.Poller()

    pin_socket = context.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = context.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                      (sutils.UNPIN_PORT + thread_id))

    exec_socket = context.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                     (sutils.FUNC_EXEC_PORT + thread_id))

    dag_queue_socket = context.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                          (sutils.DAG_QUEUE_PORT + thread_id))

    dag_exec_socket = context.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                         (sutils.DAG_EXEC_PORT + thread_id))

    self_depart_socket = context.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(context, zmq.PUSH)

    # NOTE(review): a Poller was already created above and is not used before
    # this point; this assignment replaces it.
    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    # If the management IP is set to None, that means that we are running in
    # local mode, so we use a regular AnnaTcpClient rather than an IPC client.
    # NOTE(review): an `else:` appears to be missing between the two
    # assignments below; as shown, the TCP client always overwrites the IPC
    # client and nothing is assigned when mgmt_ip is falsy.
    if mgmt_ip:
        client = AnnaIpcClient(thread_id, context)
        client = AnnaTcpClient('', '', local=True, offset=1)

    # NOTE(review): the remaining arguments of this call are missing from the
    # excerpt.
    user_library = CloudburstUserLibrary(context, pusher_cache, ip, thread_id,

    # Announce this thread to the schedulers as running.
    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    utils.push_status(schedulers, pusher_cache, status)

    # Set to True once a self-depart message has been received.
    departing = False

    # Maintains a request queue for each function pinned on this executor. Each
    # function will have a set of request IDs mapped to it, and this map stores
    # a schedule for each request ID.
    queue = {}

    # Tracks the actual function objects that are pinned to this executor.
    function_cache = {}

    # Tracks runtime cost of executing a DAG function.
    runtimes = {}

    # If multiple triggers are necessary for a function, track the triggers as
    # we receive them. This is also used if a trigger arrives before its
    # corresponding schedule.
    received_triggers = {}

    # Tracks when we received a function request, so we can report end-to-end
    # latency for the whole execution.
    receive_times = {}

    # Tracks the number of requests we are finishing for each function pinned
    # here.
    exec_counts = {}

    # Tracks the end-to-end runtime of each DAG request for which we are the
    # sink function.
    dag_runtimes = {}

    # A map with KVS keys and their corresponding deserialized payloads.
    cache = {}

    # Internal metadata to track thread utilization.
    # NOTE(review): the closing brace of this dict literal is missing from the
    # excerpt.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    total_occupancy = 0.0

    while True:
        # Wait up to 1000 (milliseconds, per the pyzmq poll API) for any
        # registered socket to become readable.
        socks = dict(poller.poll(timeout=1000))

        # Pin request: handle it, then push the updated status.
        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            pin(pin_socket, pusher_cache, client, status, function_cache,
                runtimes, exec_counts, user_library)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        # Unpin request: handle it, then push the updated status.
        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, function_cache, runtimes, exec_counts)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        # Single-function execution request.
        # NOTE(review): the exec_function call below is cut off in this
        # excerpt.
        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, user_library, cache,

            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        # A DAG schedule arrived: queue it under (function name, schedule id)
        # and run it immediately if all of its triggers are already here.
        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()

            # NOTE(review): nothing in this excerpt fills `schedule` from
            # dag_queue_socket -- presumably a ParseFromString(recv()) line is
            # missing; confirm.
            schedule = DagSchedule()
            fname = schedule.target_function

            logging.info('Received a schedule for DAG %s (%s), function %s.' %
                         (schedule.dag.name, schedule.id, fname))

            if fname not in queue:
                queue[fname] = {}

            queue[fname][schedule.id] = schedule

            # Only record the first arrival (schedule or trigger) so the
            # measured latency covers the whole request.
            if (schedule.id, fname) not in receive_times:
                receive_times[(schedule.id, fname)] = time.time()

            # In case we receive the trigger before we receive the schedule, we
            # can trigger from this operation as well.
            trkey = (schedule.id, fname)
            if (trkey in received_triggers and
                (len(received_triggers[trkey]) == len(schedule.triggers))):

                exec_dag_function(pusher_cache, client,
                                  function_cache[fname], schedule,
                                  user_library, dag_runtimes, cache)

                # The request is complete: drop its bookkeeping entries.
                del received_triggers[trkey]
                del queue[fname][schedule.id]

                # Record the time from first arrival to completion.
                fend = time.time()
                fstart = receive_times[(schedule.id, fname)]
                runtimes[fname].append(fend - fstart)
                exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        # A DAG trigger arrived: record it by source and run the function once
        # the schedule is queued and every trigger has been received.
        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()
            # NOTE(review): as with the schedule above, the line that parses
            # the trigger from dag_exec_socket appears to be missing.
            trigger = DagTrigger()

            fname = trigger.target_function
            logging.info('Received a trigger for schedule %s, function %s.' %
                         (trigger.id, fname))

            key = (trigger.id, fname)
            if key not in received_triggers:
                received_triggers[key] = {}

            if (trigger.id, fname) not in receive_times:
                receive_times[(trigger.id, fname)] = time.time()

            received_triggers[key][trigger.source] = trigger
            if fname in queue and trigger.id in queue[fname]:
                schedule = queue[fname][trigger.id]
                if len(received_triggers[key]) == len(schedule.triggers):
                    exec_dag_function(pusher_cache, client,
                                      function_cache[fname], schedule,
                                      user_library, dag_runtimes, cache)

                    del received_triggers[key]
                    del queue[fname][trigger.id]

                    fend = time.time()
                    fstart = receive_times[(trigger.id, fname)]
                    runtimes[fname].append(fend - fstart)
                    exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        # Self-depart request: mark this thread as no longer running and tell
        # the schedulers.
        # NOTE(review): the rest of this condition and the recv() that drains
        # the message are missing from the excerpt.
        if self_depart_socket in socks and socks[self_depart_socket] == \
            # This message does not matter.

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.running = False
            utils.push_status(schedulers, pusher_cache, status)

            departing = True

        # Periodically report function occupancy.
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:

            # Fraction of the reporting window spent inside handlers.
            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            # Periodically report my status to schedulers with the utilization
            # set.
            utils.push_status(schedulers, pusher_cache, status)

            logging.info('Total thread occupancy: %.6f' % (utilization))

            # Log each event type's share of the window, then reset it.
            for event in event_occupancy:
                occ = event_occupancy[event] / (report_end - report_start)
                logging.info('\tEvent %s occupancy: %.6f' % (event, occ))
                event_occupancy[event] = 0.0

            # Build the per-function / per-DAG statistics message; counters
            # are reset to zero for the next window.
            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.functions.add()
                    fstats.name = fname
                    fstats.call_count = exec_counts[fname]

                exec_counts[fname] = 0

            for dname in dag_runtimes:
                dstats = stats.dags.add()
                dstats.name = dname



            # If we are running in cluster mode, mgmt_ip will be set, and we
            # will report our status and statistics to it. Otherwise, we will
            # write to the local conf file
            # NOTE(review): the first pusher_cache.get( call and the send
            # calls are cut off in this excerpt.
            if mgmt_ip:
                sckt = pusher_cache.get(

                sckt = pusher_cache.get(utils.get_util_report_address(mgmt_ip))

            # Start a fresh reporting window.
            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for.
            # NOTE(review): nothing visible appends to del_list, so the
            # second loop below never removes anything as shown -- presumably
            # a `del_list.append(fname)` line is missing; confirm.
            del_list = []
            for fname in queue:
                if len(queue[fname]) == 0 and fname not in status.functions:
                    del function_cache[fname]
                    del runtimes[fname]
                    del exec_counts[fname]

            # Entries are removed in a second pass so that `queue` is not
            # mutated while it is being iterated.
            for fname in del_list:
                del queue[fname]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils.get_depart_done_addr(mgmt_ip))

                # We specifically pass 1 as the exit code when ending our
                # process so that the wrapper script does not restart us.
예제 #10
파일: server.py 프로젝트: MincYu/cloudburst
def scheduler(ip, mgmt_ip, route_addr, policy_type):
    # Event loop for a scheduler node.
    #
    # Creates the KVS client, binds one socket per request type (connect,
    # function create/call, queued function call, DAG create/call/delete,
    # list, executor status, scheduler update, pin accept, continuation),
    # starts the policy engine, and then polls the sockets in a loop. Between
    # polls it periodically refreshes the set of peer schedulers and reports
    # call statistics.
    #
    # Args:
    #   ip: passed to the KVS client and compared against peer scheduler IPs
    #       so that this node skips itself when pushing updates.
    #   mgmt_ip: management server address; None means local mode.
    #   route_addr: passed to AnnaTcpClient as its first argument.
    #   policy_type: not referenced in the lines visible here.
    #
    # NOTE(review): this copy of the function is an excerpt with lines missing
    # (unterminated bind/get calls, empty loop bodies, dangling f-string
    # lines), so it does not parse as shown. Restore the missing lines from
    # the upstream file before relying on it.

    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    # Used as the prefix of the key built in the periodic report.
    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)
    context.set(zmq.MAX_SOCKETS, 10000)

    # A mapping from a DAG's name to its protobuf representation.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = set()

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    # This socket handles invocations that arrive via a queue, mainly for
    # storage events.
    # NOTE(review): the port argument of this bind call is missing from the
    # excerpt.
    func_call_queue_socket = context.socket(zmq.PULL)
    func_call_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    # NOTE(review): the port arguments of the next three bind calls are
    # missing from the excerpt.
    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %

    # This socket is not registered with the poller below; it is handed to the
    # policy engine instead.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %

    continuation_socket = context.socket(zmq.PULL)
    continuation_socket.bind(sutils.BIND_ADDR_TEMPLATE %

    # Only cluster mode talks to the management server.
    # NOTE(review): no connect() call for this socket is visible in the
    # excerpt.
    if not local:
        management_request_socket = context.socket(zmq.REQ)
        management_request_socket.setsockopt(zmq.RCVTIMEO, 500)
        # By setting this flag, zmq matches replies with requests.
        management_request_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        # Relax strict alternation between request and reply.
        # For detailed explanation, see here: http://api.zeromq.org/4-1:zmq-setsockopt
        management_request_socket.setsockopt(zmq.REQ_RELAXED, 1)

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(func_call_queue_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(continuation_socket, zmq.POLLIN)

    # Start the policy engine.
    # NOTE(review): the remaining constructor arguments are missing from the
    # excerpt.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket,

    # Start of the current metadata/report window.
    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        # NOTE(review): connect_socket is a REP socket, but no reply is sent
        # in the visible lines; the send appears to be missing.
        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            msg = connect_socket.recv_string()

        if (func_create_socket in socks
                and socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if func_call_queue_socket in socks and socks[
                func_call_queue_socket] == zmq.POLLIN:
            call_function_from_queue(func_call_queue_socket, pusher_cache,

        if (dag_create_socket in socks
                and socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,

        # DAG call: record arrival statistics, then schedule the DAG.
        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            # Receive timestamp in microseconds, used in the timing message
            # further down.
            start_t = int(time.time() * 1000000)
            # NOTE(review): nothing in this excerpt fills `call` from
            # dag_call_socket -- presumably a ParseFromString(recv()) line is
            # missing; confirm.
            call = DagCall()

            name = call.name

            # Record the gap since the previous call to the same DAG.
            t = time.time()
            if name in last_arrivals:
                if name not in interarrivals:
                    interarrivals[name] = []

                interarrivals[name].append(t - last_arrivals[name])

            last_arrivals[name] = t

            # NOTE(review): the error response is built but, in this excerpt,
            # never sent, and control falls through to dags[name] below, which
            # would raise KeyError -- a send and `continue` appear to be
            # missing; confirm.
            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG


            # dags maps a name to a (dag, sources) tuple; index 0 is the DAG.
            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            sched_t = int(time.time() * 1000000)
                f'App function {name} recv: {start_t}, scheduled: {sched_t}')

        if (dag_delete_socket in socks
                and socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        # List request: the message body is the prefix passed to
        # get_func_list.
        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))


        # Executor status update.
        # NOTE(review): the rest of this condition and the handling of the
        # status message are missing from the excerpt.
        if exec_status_socket in socks and socks[exec_status_socket] == \
            status = ThreadStatus()


        # Metadata update from another scheduler.
        if sched_update_socket in socks and socks[sched_update_socket] == \
            status = SchedulerStatus()

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # Retry until the KVS returns no None entries.
                    payload = kvs.get(dname)
                    while None in payload:
                        payload = kvs.get(dname)

                    # NOTE(review): the line that parses `payload` into `dag`
                    # is missing from the excerpt.
                    dag = Dag()
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0


        # Continuation: re-invoke a DAG on behalf of an earlier request.
        if continuation_socket in socks and socks[continuation_socket] == \
            start_t = int(time.time() * 1000000)

            continuation = Continuation()

            call = continuation.call
            call.name = continuation.name

            result = Value()

            # NOTE(review): the body of the loop over `sources` is missing
            # from the excerpt.
            dag, sources = dags[call.name]
            for source in sources:

            call_dag(call, pusher_cache, dags, policy, continuation.id)
            sched_t = int(time.time() * 1000000)
                f'App function {call.name} recv: {start_t}, scheduled: {sched_t}'
            for fname in dag.functions:
                call_frequency[fname.name] += 1

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if not local:
                latest_schedulers = sched_utils.get_ip_set(
                    management_request_socket, False)
                # Keep the previous set when the lookup returns nothing.
                if latest_schedulers:
                    schedulers = latest_schedulers

        if end - start > REPORT_THRESHOLD:
            # Build a status message with every known function location.
            # NOTE(review): the body of the loop over dags.keys() is missing
            # from the excerpt.
            status = SchedulerStatus()
            for name in dags.keys():

            # Each location is an (ip, tid) pair.
            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            # Send to every scheduler except ourselves.
            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(

            # Report per-function call counts, resetting each counter for the
            # next window.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.debug('Reporting %d calls for function %s.' %
                              (call_frequency[fname], fname))

                call_frequency[fname] = 0

            # N recorded inter-arrival gaps correspond to N + 1 calls.
            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1


            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(

            # Start a fresh metadata/report window.
            start = time.time()