def test_process_status(self):
    '''
    This test ensures that when a new status update is received from an
    executor, the local server metadata is correctly updated in the normal
    case: a lightly utilized executor is registered for its function, and a
    heavily utilized executor is additionally placed into backoff.
    '''
    function_name = 'square'

    def make_status(tid, utilization):
        # Build a fresh status for each executor. Mutating a single protobuf
        # after handing it to process_status would also change any reference
        # the policy engine kept for the earlier executor.
        status = ThreadStatus()
        status.running = True
        status.ip = self.ip
        status.tid = tid
        status.functions.append(function_name)
        status.utilization = utilization
        return status

    # A lightly utilized executor on thread 1.
    self.policy.process_status(make_status(1, 0.10))

    # A heavily utilized executor on thread 2.
    busy_status = make_status(2, 0.90)
    self.policy.process_status(busy_status)

    key = (busy_status.ip, busy_status.tid)
    self.assertNotIn(key, self.policy.unpinned_cpu_executors)
    self.assertIn(key, self.policy.function_locations[function_name])
    self.assertIn(key, self.policy.backoff)
def setUp(self):
    # Mocked networking and storage that the policy engine talks to.
    self.pusher_cache = zmq_utils.MockPusherCache()
    self.socket = zmq_utils.MockZmqSocket()
    self.pin_socket = zmq_utils.MockZmqSocket()
    self.kvs_client = kvs_client.MockAnnaClient()
    self.ip = '127.0.0.1'

    self.policy = DefaultCloudburstSchedulerPolicy(
        self.pin_socket,
        self.pusher_cache,
        self.kvs_client,
        self.ip,
        random_threshold=0,
    )

    # Every test starts with one unpinned executor: thread 0 on this IP.
    default_status = ThreadStatus()
    default_status.ip = self.ip
    default_status.tid = 0
    default_key = (default_status.ip, default_status.tid)

    self.executor_key = default_key
    self.policy.unpinned_executors.add(default_key)
def setUp(self):
    # Mocked storage and networking for the user library under test.
    self.kvs_client = kvs_client.MockAnnaClient()
    self.socket = zmq_utils.MockZmqSocket()
    self.pusher_cache = zmq_utils.MockPusherCache()
    self.ip = '127.0.0.1'

    # Status of a running executor on thread 0 of the local IP.
    status = ThreadStatus()
    status.ip = self.ip
    status.tid = 0
    status.running = True
    self.status = status

    # Empty executor bookkeeping maps.
    self.pinned_functions, self.runtimes, self.exec_counts = {}, {}, {}

    self.user_library = CloudburstUserLibrary(
        zmq_utils.MockZmqContext(),
        self.pusher_cache,
        self.ip,
        0,
        self.kvs_client,
    )
def test_metadata_update(self):
    '''
    This test calls the periodic metadata update protocol and ensures that
    the correct metadata is removed from the system and that the correct
    metadata is retrieved/updated from the KVS.
    '''
    # Create two executor threads on separate machines.
    old_ip = '127.0.0.1'
    new_ip = '192.168.0.1'
    old_executor = (old_ip, 1)
    new_executor = (new_ip, 2)

    old_status = ThreadStatus()
    old_status.ip = old_ip
    old_status.tid = 1
    old_status.running = True

    new_status = ThreadStatus()
    new_status.ip = new_ip
    new_status.tid = 2
    new_status.running = True

    self.policy.thread_statuses[old_executor] = old_status
    self.policy.thread_statuses[new_executor] = new_status

    # Add two executors, one with an old backoff time and one with a new
    # backoff time.
    self.policy.backoff[old_executor] = time.time() - 10
    self.policy.backoff[new_executor] = time.time()

    # For the new executor, add 10 old running times and 10 new ones. The
    # short sleeps keep the timestamps distinct, so the set really holds 10
    # entries of each kind.
    self.policy.running_counts[new_executor] = set()
    for _ in range(10):
        time.sleep(.0001)
        self.policy.running_counts[new_executor].add(time.time() - 10)

    for _ in range(10):
        time.sleep(.0001)
        self.policy.running_counts[new_executor].add(time.time())

    # Publish some caching metadata into the KVS for each executor.
    old_set = StringSet()
    old_set.keys.extend(['key1', 'key2', 'key3'])
    new_set = StringSet()
    new_set.keys.extend(['key3', 'key4', 'key5'])

    self.kvs_client.put(get_cache_ip_key(old_ip),
                        LWWPairLattice(0, old_set.SerializeToString()))
    self.kvs_client.put(get_cache_ip_key(new_ip),
                        LWWPairLattice(0, new_set.SerializeToString()))

    self.policy.update()

    # Check that the metadata has been correctly pruned.
    self.assertEqual(len(self.policy.backoff), 1)
    self.assertIn(new_executor, self.policy.backoff)
    self.assertEqual(len(self.policy.running_counts[new_executor]), 10)

    # Check that the caching information is correct. assertEqual is needed
    # here: assertTrue(x, n) only treats n as a failure message and would
    # pass for any non-empty location list.
    expected_counts = {'key1': 1, 'key2': 1, 'key3': 2, 'key4': 1, 'key5': 1}
    for cached_key, count in expected_counts.items():
        self.assertEqual(len(self.policy.key_locations[cached_key]), count)

    for cached_key in ('key1', 'key2', 'key3'):
        self.assertIn(old_ip, self.policy.key_locations[cached_key])

    for cached_key in ('key3', 'key4', 'key5'):
        self.assertIn(new_ip, self.policy.key_locations[cached_key])
def test_process_status_not_running(self):
    '''
    Feed the policy engine a status from a server that is leaving the
    system, then verify that all of its metadata reflects the departure.
    '''
    function_name = 'square'

    # Prime the policy engine with a running executor on thread 1.
    running = ThreadStatus()
    running.running = True
    running.ip = self.ip
    running.tid = 1
    running.functions.append(function_name)
    running.utilization = 0
    key = (running.ip, running.tid)

    # Make the engine believe the function used to be pinned here.
    self.policy.function_locations[function_name] = {key}
    self.policy.thread_statuses[key] = running

    # A fresh status with running=False and no functions announces shutdown.
    departing = ThreadStatus()
    departing.running = False
    departing.ip = self.ip
    departing.tid = 1

    self.policy.process_status(departing)

    self.assertNotIn(key, self.policy.thread_statuses)
    self.assertNotIn(key, self.policy.unpinned_executors)
    self.assertEqual(len(self.policy.function_locations[function_name]), 0)
def test_process_status_restart(self):
    '''
    Verify that a status update from a restarted executor clears its old
    function placement and puts it back into the unpinned set. A restarted
    executor is one that reports running but with no pinned functions.
    '''
    function_name = 'square'

    # The executor as it looked before the restart: one pinned function.
    before = ThreadStatus()
    before.running = True
    before.ip = self.ip
    before.tid = 1
    before.functions.append(function_name)
    before.utilization = 0
    key = (before.ip, before.tid)

    self.policy.function_locations[function_name] = {key}
    self.policy.thread_statuses[key] = before

    # The same executor after the restart: running, nothing pinned.
    after = ThreadStatus()
    after.ip = self.ip
    after.tid = 1
    after.running = True
    after.utilization = 0

    self.policy.process_status(after)

    self.assertEqual(len(self.policy.function_locations[function_name]), 0)
    self.assertIn(key, self.policy.unpinned_executors)
def executor(ip, mgmt_ip, schedulers, thread_id):
    '''
    Run the event loop for one Cloudburst executor thread.

    The thread binds one PULL socket per message type: pin, unpin, single
    function execution, DAG schedules, DAG triggers, and self-departure. It
    announces itself to the schedulers and then serves requests forever,
    periodically reporting utilization and execution statistics.

    Args:
        ip: IP address of the machine this executor runs on.
        mgmt_ip: Address of the management server, or a falsy value when
            running in local mode.
        schedulers: Scheduler addresses that receive this thread's status.
        thread_id: Index of this thread. It is added to every bound port
            number and passed as the KVS client offset / IPC thread id.

    The function does not return normally. After a self-depart request has
    been received and the request queues have drained, it notifies the
    management server and exits the process with status 1.
    '''
    logging.basicConfig(filename='log_executor.txt', level=logging.INFO,
                        filemode='w', format='%(asctime)s %(message)s')

    # Check what resources we have access to, set as an environment variable.
    exec_type = GPU if os.getenv('EXECUTOR_TYPE', 'CPU') == 'GPU' else CPU

    context = zmq.Context(1)

    pin_socket = context.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = context.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                      (sutils.UNPIN_PORT + thread_id))

    exec_socket = context.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                     (sutils.FUNC_EXEC_PORT + thread_id))

    dag_queue_socket = context.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                          (sutils.DAG_QUEUE_PORT + thread_id))

    dag_exec_socket = context.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                         (sutils.DAG_EXEC_PORT + thread_id))

    self_depart_socket = context.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    for sock in (pin_socket, unpin_socket, exec_socket, dag_queue_socket,
                 dag_exec_socket, self_depart_socket):
        poller.register(sock, zmq.POLLIN)

    # If the management IP is set to None, we are running in local mode, so
    # we use a regular AnnaTcpClient against localhost. In cluster mode the
    # default is an AnnaIpcClient. Setting STORAGE_OR_DEFAULT=0 instead
    # selects a TCP client to ROUTE_ADDR and sets has_ephe, which is passed
    # to the user library and exec_function.
    has_ephe = False
    if mgmt_ip:
        if os.environ.get('STORAGE_OR_DEFAULT') == '0':
            client = AnnaTcpClient(os.environ['ROUTE_ADDR'], ip, local=False,
                                   offset=thread_id)
            has_ephe = True
        else:
            client = AnnaIpcClient(thread_id, context)
        local = False
    else:
        client = AnnaTcpClient('127.0.0.1', '127.0.0.1', local=True, offset=1)
        local = True

    user_library = CloudburstUserLibrary(context, pusher_cache, ip, thread_id,
                                         client, has_ephe=has_ephe)

    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    status.type = exec_type
    utils.push_status(schedulers, pusher_cache, status)

    departing = False

    # Maintains a request queue for each function pinned on this executor:
    # function name -> {request id -> DagSchedule}.
    queue = {}

    # Tracks the actual function objects that are pinned to this executor.
    function_cache = {}

    # Tracks runtime cost of executing a DAG function.
    runtimes = {}

    # If multiple triggers are necessary for a function, track the triggers
    # as we receive them: (request id, function name) -> {source -> trigger}.
    # This is also used if a trigger arrives before its corresponding
    # schedule.
    received_triggers = {}

    # Tracks the number of requests we are finishing for each function
    # pinned here.
    exec_counts = {}

    # Tracks the end-to-end runtime of each DAG request for which we are the
    # sink function.
    dag_runtimes = {}

    # A map with KVS keys and their corresponding deserialized payloads.
    cache = {}

    # (request id, function name) -> time of successful completion. Used to
    # drop repeated triggers for work that has already finished.
    finished_executions = {}

    # Whether pinned functions support batching, as returned by pin(). NOTE:
    # in cluster mode there is only one pinned function per executor.
    batching = False

    # Internal metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    }
    total_occupancy = 0.0

    while True:
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            batching = pin(pin_socket, pusher_cache, client, status,
                           function_cache, runtimes, exec_counts, user_library,
                           local, batching)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, function_cache, runtimes, exec_counts)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, user_library, cache,
                          function_cache, has_ephe=has_ephe)
            user_library.close()

            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()
            logging.info(
                f'Executor timer. dag_queue_socket recv: {work_start}')

            # To support batching effectively, we dequeue every schedule that
            # is waiting, not just one. Batching is only supported on the
            # trigger side, so schedules are still processed one at a time.
            while True:
                try:
                    msg = dag_queue_socket.recv(zmq.DONTWAIT)
                except zmq.ZMQError as e:
                    if e.errno == zmq.EAGAIN:
                        break  # There are no more messages.
                    raise  # Unexpected error.

                schedule = DagSchedule()
                schedule.ParseFromString(msg)
                fname = schedule.target_function
                trkey = (schedule.id, fname)

                logging.info(
                    'Received a schedule for DAG %s (%s), function %s.' %
                    (schedule.dag.name, schedule.id, fname))

                queue.setdefault(fname, {})[schedule.id] = schedule

                # Check to see what type of execution this function is.
                fref = None
                for ref in schedule.dag.functions:
                    if ref.name == fname:
                        fref = ref

                # In case we received the trigger(s) before this schedule, we
                # can trigger from this operation as well.
                if (trkey in received_triggers and
                        (len(received_triggers[trkey]) ==
                         len(schedule.triggers) or
                         fref.type == MULTIEXEC)):
                    if fname not in function_cache:
                        logging.error('%s not in function cache', fname)
                        utils.generate_error_response(schedule, client, fname)
                        continue

                    triggers = list(received_triggers[trkey].values())

                    # We don't support actual batching when a schedule arrives
                    # after its triggers, so this is a batch of size 1.
                    success = exec_dag_function(
                        pusher_cache, client, [triggers],
                        function_cache[fname], [schedule], user_library,
                        dag_runtimes, cache, schedulers, batching)[0]
                    user_library.close()

                    del received_triggers[trkey]
                    if success:
                        del queue[fname][schedule.id]
                        runtimes[fname].append(time.time() - work_start)
                        exec_counts[fname] += 1
                        finished_executions[trkey] = time.time()

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()

            # How many messages to dequeue -- BATCH_SIZE_MAX or 1 depending on
            # the function configuration.
            count = BATCH_SIZE_MAX if batching else 1

            # (request id, function name) keys seen in this batch, each mapped
            # to the most recent trigger dequeued for that key.
            latest_triggers = {}
            for _ in range(count):
                try:
                    msg = dag_exec_socket.recv(zmq.DONTWAIT)
                except zmq.ZMQError as e:
                    if e.errno == zmq.EAGAIN:
                        break  # There are no more messages.
                    raise  # Unexpected error.

                trigger = DagTrigger()
                trigger.ParseFromString(msg)
                key = (trigger.id, trigger.target_function)

                # We have received a repeated trigger for a function that has
                # already finished executing. finished_executions is keyed by
                # (id, function), so the lookup must use the same key.
                if key in finished_executions:
                    continue

                logging.info(
                    'Received a trigger for schedule %s, function %s.' %
                    (trigger.id, trigger.target_function))

                received_triggers.setdefault(key, {})[trigger.source] = trigger
                latest_triggers[key] = trigger

            # Only execute the functions for which we have received a
            # schedule; everything else waits in received_triggers. Ready keys
            # are grouped by function so each batch runs with its own
            # function object.
            ready = {}
            for key in latest_triggers:
                tid, fname = key
                if fname in queue and tid in queue[fname]:
                    ready.setdefault(fname, []).append(key)

            for fname, keys in ready.items():
                # Pick one schedule to check what type of execution this
                # function is.
                fref = None
                for ref in queue[fname][keys[0][0]].dag.functions:
                    if ref.name == fname:
                        fref = ref
                        break

                # Compile the trigger sets for which we have enough triggers,
                # remembering exactly which keys are being executed.
                exec_keys = []
                trigger_sets = []
                schedules = []
                for key in keys:
                    schedule = queue[fname][key[0]]
                    if not (len(received_triggers[key]) ==
                            len(schedule.triggers) or
                            fref.type == MULTIEXEC):
                        continue

                    if fname not in function_cache:
                        logging.error('%s not in function cache', fname)
                        utils.generate_error_response(schedule, client, fname)
                        continue

                    if fref.type == MULTIEXEC:
                        triggers = [latest_triggers[key]]
                    else:
                        triggers = list(received_triggers[key].values())

                    exec_keys.append(key)
                    trigger_sets.append(triggers)
                    schedules.append(schedule)

                if not exec_keys:
                    continue

                # Pass all of the trigger sets into exec_dag_function at once.
                # The batching flag tells it whether to pass lists into the fn.
                successes = exec_dag_function(
                    pusher_cache, client, trigger_sets, function_cache[fname],
                    schedules, user_library, dag_runtimes, cache, schedulers,
                    batching)
                user_library.close()

                # Triggers are consumed whether or not the execution succeeded,
                # just as in the schedule-queue path.
                for key in exec_keys:
                    del received_triggers[key]

                for key, success in zip(exec_keys, successes):
                    if success:
                        del queue[fname][key[0]]  # key[0] is trigger.id.
                        average_time = ((time.time() - work_start) /
                                        len(exec_keys))
                        runtimes[fname].append(average_time)
                        exec_counts[fname] += 1
                        finished_executions[key] = time.time()

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        if (self_depart_socket in socks and
                socks[self_depart_socket] == zmq.POLLIN):
            # This message does not matter.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils.push_status(schedulers, pusher_cache, status)

            departing = True

        # periodically report function occupancy
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            # Drop the oldest-inserted payloads so at most 100 remain.
            if len(cache) > 100:
                for cache_key in list(cache)[:len(cache) - 100]:
                    del cache[cache_key]

            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            # Periodically report my status to schedulers with the
            # utilization set.
            utils.push_status(schedulers, pusher_cache, status)

            logging.debug('Total thread occupancy: %.6f' % (utilization))
            for event in event_occupancy:
                occ = event_occupancy[event] / (report_end - report_start)
                logging.debug('\tEvent %s occupancy: %.6f' % (event, occ))
                event_occupancy[event] = 0.0

            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.functions.add()
                    fstats.name = fname
                    fstats.call_count = exec_counts[fname]
                    fstats.runtime.extend(runtimes[fname])

                runtimes[fname].clear()
                exec_counts[fname] = 0

            for dname in dag_runtimes:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.runtimes.extend(dag_runtimes[dname])

                dag_runtimes[dname].clear()

            # If we are running in cluster mode, mgmt_ip will be set, and we
            # will report our status and statistics to it. Otherwise, we will
            # write the statistics to the local log.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

                sckt = pusher_cache.get(utils.get_util_report_address(mgmt_ip))
                sckt.send(status.SerializeToString())
            else:
                logging.info(stats)

            status.ClearField('utilization')
            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for.
            stale = [fname for fname in queue
                     if not queue[fname] and fname not in status.functions]
            for fname in stale:
                del queue[fname]
                del function_cache[fname]
                del runtimes[fname]
                del exec_counts[fname]

            # Forget finished executions after 10 seconds.
            now = time.time()
            expired = [done_key for done_key, done_at
                       in finished_executions.items() if now - done_at > 10]
            for done_key in expired:
                del finished_executions[done_key]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and not queue:
                sckt = pusher_cache.get(utils.get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)

                # We specifically pass 1 as the exit code when ending our
                # process so that the wrapper script does not restart us.
                sys.exit(1)
def scheduler(ip, mgmt_ip, route_addr):
    '''
    Run the main event loop of a Cloudburst scheduler.

    The scheduler binds REP sockets for client requests: connect, function
    create/call, DAG create/call/delete, and list. It also binds PULL sockets
    for executor status updates and peer-scheduler updates, then serves
    requests forever.

    Once METADATA_THRESHOLD seconds have passed since the last report, the
    policy metadata is refreshed. Once REPORT_THRESHOLD seconds have passed,
    this scheduler's DAG names and function locations are pushed to the
    other schedulers, call statistics are reported, and the timer resets.

    Args:
        ip: IP address of this scheduler.
        mgmt_ip: Address of the management server; None means local mode.
        route_addr: Address of the KVS routing tier. It is also the reply
            sent to clients on the connect socket.
    '''
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    context = zmq.Context(1)

    # A mapping from a DAG's name to a (Dag protobuf,
    # sched_utils.find_dag_source(dag)) tuple.
    dags = {}

    # Tracks how often a request for each function (by name) is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = []

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    requestor_cache = SocketCache(context, zmq.REQ)
    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    for sock in (connect_socket, func_create_socket, func_call_socket,
                 dag_create_socket, dag_call_socket, dag_delete_socket,
                 list_socket, exec_status_socket, sched_update_socket):
        poller.register(sock, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket, pusher_cache,
                                              kvs, ip, local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored; every client gets the KVS routing
            # address.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            t = time.time()
            if name in last_arrivals:
                interarrivals.setdefault(name, []).append(
                    t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
            else:
                # Count one call for each function in the DAG, keyed by name.
                for fref in dags[name][0].functions:
                    call_frequency[fref.name] = (
                        call_frequency.get(fref.name, 0) + 1)

                response = call_dag(call, pusher_cache, dags, policy)
                dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and
                socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if (exec_status_socket in socks and
                socks[exec_status_socket] == zmq.POLLIN):
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if (sched_update_socket in socks and
                socks[sched_update_socket] == zmq.POLLIN):
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname in dags:
                    continue

                # kvs.get returns a dict keyed by the requested key; the
                # lattice is presumably None while the key is not available
                # yet. Skip it for now. Peers re-send their DAG names every
                # report period, so it can be picked up on a later update.
                lattice = kvs.get(dname).get(dname)
                if lattice is None:
                    continue

                dag = Dag()
                dag.ParseFromString(lattice.reveal())
                dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                # call_frequency is keyed by function name, not by the
                # FunctionReference protobuf itself.
                for fref in dag.functions:
                    call_frequency.setdefault(fref.name, 0)

            policy.update_function_locations(status.function_locations)

        end = time.time()

        # NOTE(review): start is only reset after a report. If
        # METADATA_THRESHOLD < REPORT_THRESHOLD, policy.update() runs on
        # every loop iteration in between.
        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if mgmt_ip:
                schedulers = sched_utils.get_ip_set(
                    sched_utils.get_scheduler_list_address(mgmt_ip),
                    requestor_cache, False)

        if end - start > REPORT_THRESHOLD:
            status = SchedulerStatus()
            status.dags.extend(dags.keys())

            for fname, locations in policy.function_locations.items():
                for loc in locations:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.' %
                             (call_frequency[fname], fname))

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # Statistics are only sent in cluster mode. In local mode, only
            # the per-function call counts logged above are recorded.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()
def executor(ip, mgmt_ip, schedulers, thread_id):
    '''
    Run the event loop for one executor thread (unbatched variant).

    Binds one PULL socket per message type: pin, unpin, function execution,
    DAG schedule, DAG trigger, and self-departure. The port for each is its
    base port plus thread_id. The thread then pushes its initial status to
    the schedulers and loops forever handling whichever sockets are
    readable. Each DAG schedule/trigger message is received and processed
    individually.

    Args:
        ip: IP address of the machine this executor runs on.
        mgmt_ip: Management server address, or a falsy value in local mode.
        schedulers: Scheduler addresses that receive this thread's status.
        thread_id: Index of this thread; used as a port offset and as the
            AnnaIpcClient thread id.

    Once a self-depart request has been received and the queues are empty
    at a report tick, it notifies the management server and calls
    os._exit(1).
    '''
    logging.basicConfig(filename='log_executor.txt', level=logging.INFO,
                        format='%(asctime)s %(message)s')

    context = zmq.Context(1)
    # NOTE(review): this Poller is replaced below before any registration.
    poller = zmq.Poller()

    pin_socket = context.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = context.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.UNPIN_PORT +
                                                   thread_id))

    exec_socket = context.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.FUNC_EXEC_PORT +
                                                  thread_id))

    dag_queue_socket = context.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_QUEUE_PORT
                                                       + thread_id))

    dag_exec_socket = context.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_EXEC_PORT
                                                      + thread_id))

    self_depart_socket = context.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    # If the management IP is set to None, that means that we are running in
    # local mode, so we use a regular AnnaTcpClient rather than an IPC client.
    if mgmt_ip:
        client = AnnaIpcClient(thread_id, context)
    else:
        client = AnnaTcpClient('127.0.0.1', '127.0.0.1', local=True, offset=1)

    user_library = CloudburstUserLibrary(context, pusher_cache, ip, thread_id,
                                         client)

    # Announce this (running) thread to the schedulers.
    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    utils.push_status(schedulers, pusher_cache, status)

    # Set once a self-depart message arrives; the process exits after the
    # queues drain.
    departing = False

    # Maintains a request queue for each function pinned on this executor.
    # Each function maps request IDs to the schedule for that request.
    queue = {}

    # Tracks the actual function objects that are pinned to this executor.
    function_cache = {}

    # Tracks runtime cost of executing a DAG function.
    runtimes = {}

    # If multiple triggers are necessary for a function, track the triggers
    # as we receive them: (request id, function name) -> {source -> trigger}.
    # This is also used if a trigger arrives before its corresponding
    # schedule.
    received_triggers = {}

    # Tracks when we received a function request, so we can report end-to-end
    # latency for the whole execution.
    # NOTE(review): entries are never removed from this map.
    receive_times = {}

    # Tracks the number of requests we are finishing for each function pinned
    # here.
    exec_counts = {}

    # Tracks the end-to-end runtime of each DAG request for which we are the
    # sink function.
    dag_runtimes = {}

    # A map with KVS keys and their corresponding deserialized payloads.
    # Cleared completely at every report tick.
    cache = {}

    # Internal metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    }
    total_occupancy = 0.0

    while True:
        # Wait up to one second so the periodic reporting below still runs
        # when no messages arrive.
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            pin(pin_socket, pusher_cache, client, status, function_cache,
                runtimes, exec_counts, user_library)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, function_cache, runtimes,
                  exec_counts)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, user_library, cache,
                          function_cache)
            user_library.close()

            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()

            # One schedule is received per readable event.
            schedule = DagSchedule()
            schedule.ParseFromString(dag_queue_socket.recv())
            fname = schedule.target_function

            logging.info('Received a schedule for DAG %s (%s), function %s.' %
                         (schedule.dag.name, schedule.id, fname))

            if fname not in queue:
                queue[fname] = {}

            queue[fname][schedule.id] = schedule

            if (schedule.id, fname) not in receive_times:
                receive_times[(schedule.id, fname)] = time.time()

            # In case we receive the trigger before we receive the schedule, we
            # can trigger from this operation as well.
            trkey = (schedule.id, fname)
            if (trkey in received_triggers and
                    (len(received_triggers[trkey]) == len(schedule.triggers))):

                exec_dag_function(pusher_cache, client,
                                  received_triggers[trkey],
                                  function_cache[fname], schedule,
                                  user_library, dag_runtimes, cache)
                user_library.close()

                del received_triggers[trkey]
                del queue[fname][schedule.id]

                # Runtime is measured from the first arrival (schedule or
                # trigger) for this request.
                fend = time.time()
                fstart = receive_times[(schedule.id, fname)]
                runtimes[fname].append(fend - fstart)
                exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()

            # One trigger is received per readable event.
            trigger = DagTrigger()
            trigger.ParseFromString(dag_exec_socket.recv())

            fname = trigger.target_function
            logging.info('Received a trigger for schedule %s, function %s.' %
                         (trigger.id, fname))

            key = (trigger.id, fname)
            if key not in received_triggers:
                received_triggers[key] = {}

            if (trigger.id, fname) not in receive_times:
                receive_times[(trigger.id, fname)] = time.time()

            received_triggers[key][trigger.source] = trigger

            # Execute only once the schedule is known and one trigger per
            # expected source has arrived.
            if fname in queue and trigger.id in queue[fname]:
                schedule = queue[fname][trigger.id]
                if len(received_triggers[key]) == len(schedule.triggers):
                    exec_dag_function(pusher_cache, client,
                                      received_triggers[key],
                                      function_cache[fname], schedule,
                                      user_library, dag_runtimes, cache)
                    user_library.close()

                    del received_triggers[key]
                    del queue[fname][trigger.id]

                    fend = time.time()
                    fstart = receive_times[(trigger.id, fname)]
                    runtimes[fname].append(fend - fstart)
                    exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
                zmq.POLLIN:
            # This message does not matter.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils.push_status(schedulers, pusher_cache, status)

            departing = True

        # periodically report function occupancy
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            cache.clear()

            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            # Periodically report my status to schedulers with the utilization
            # set.
            utils.push_status(schedulers, pusher_cache, status)

            logging.info('Total thread occupancy: %.6f' % (utilization))

            for event in event_occupancy:
                occ = event_occupancy[event] / (report_end - report_start)
                logging.info('\tEvent %s occupancy: %.6f' % (event, occ))
                event_occupancy[event] = 0.0

            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.functions.add()
                    fstats.name = fname
                    fstats.call_count = exec_counts[fname]
                    fstats.runtime.extend(runtimes[fname])

                runtimes[fname].clear()
                exec_counts[fname] = 0

            for dname in dag_runtimes:
                dstats = stats.dags.add()
                dstats.name = dname

                dstats.runtimes.extend(dag_runtimes[dname])

                dag_runtimes[dname].clear()

            # If we are running in cluster mode, mgmt_ip will be set, and we
            # will report our status and statistics to it. Otherwise, we will
            # write the statistics to the local log.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

                sckt = pusher_cache.get(utils.get_util_report_address(mgmt_ip))
                sckt.send(status.SerializeToString())
            else:
                logging.info(stats)

            status.ClearField('utilization')
            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for.
            del_list = []
            for fname in queue:
                if len(queue[fname]) == 0 and fname not in status.functions:
                    del_list.append(fname)
                    del function_cache[fname]
                    del runtimes[fname]
                    del exec_counts[fname]

            for fname in del_list:
                del queue[fname]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils.get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)

                # We specifically pass 1 as the exit code when ending our
                # process so that the wrapper script does not restart us.
                os._exit(1)
def scheduler(ip, mgmt_ip, route_addr, policy_type):
    '''
    Run the scheduler's main event loop. This function never returns.

    The loop polls all scheduler sockets, dispatches requests (function and
    DAG creation/calls/deletion, listing, executor status, peer-scheduler
    updates, continuations), and periodically refreshes policy metadata and
    reports statistics.

    Args:
        ip: This scheduler's IP address. Passed to the KVS client and policy
            engine, and used to avoid sending scheduler updates to ourselves.
        mgmt_ip: The management server's IP address, or None to run in local
            mode (no peer-scheduler discovery and no statistics reporting to
            a management server).
        route_addr: The KVS routing address. Used to build the KVS client and
            returned verbatim to clients that hit the connect socket.
        policy_type: Passed through to DefaultCloudburstSchedulerPolicy.
    '''
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)
    context.set(zmq.MAX_SOCKETS, 10000)

    # A mapping from a DAG's name to a (DAG protobuf, DAG sources) tuple.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = set()

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    # Handles invocations that arrive from a queue (mainly storage events).
    func_call_queue_socket = context.socket(zmq.PULL)
    func_call_queue_socket.bind(
        sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_QUEUE_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    continuation_socket = context.socket(zmq.PULL)
    continuation_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.CONTINUATION_PORT))

    if not local:
        management_request_socket = context.socket(zmq.REQ)
        management_request_socket.setsockopt(zmq.RCVTIMEO, 500)
        # By setting this flag, zmq matches replies with requests.
        management_request_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        # Relax strict alternation between request and reply.
        # For detailed explanation, see here:
        # http://api.zeromq.org/4-1:zmq-setsockopt
        management_request_socket.setsockopt(zmq.REQ_RELAXED, 1)
        management_request_socket.connect(
            sched_utils.get_scheduler_list_address(mgmt_ip))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(func_call_queue_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(continuation_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket, pusher_cache,
                                              kvs, ip, policy_type,
                                              local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored; any request gets the KVS route.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (func_call_queue_socket in socks and
                socks[func_call_queue_socket] == zmq.POLLIN):
            call_function_from_queue(func_call_queue_socket, pusher_cache,
                                     policy)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            start_t = int(time.time() * 1000000)
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            # Record the gap since the previous call to this DAG, if any.
            t = time.time()
            if name in last_arrivals:
                interarrivals.setdefault(name, []).append(
                    t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                # NOTE: this skips the remaining handlers (and the periodic
                # metadata/report work) for this loop iteration only.
                continue

            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            sched_t = int(time.time() * 1000000)
            logging.info(
                f'App function {name} recv: {start_t}, scheduled: {sched_t}')
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and
                socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname in dags:
                    continue

                # kvs.get returns a mapping indexed by key; a None value
                # means the DAG is not readable yet. Check the value, not
                # membership: `None in payload` would test the keys.
                payload = kvs.get(dname)
                lattice = payload.get(dname)
                if lattice is None:
                    # Don't block the event loop waiting on the KVS. Peers
                    # re-broadcast every DAG name they know in each periodic
                    # report, so this DAG is retried on a later update.
                    logging.info('DAG %s not yet readable from the KVS; '
                                 'will retry on a later scheduler update.'
                                 % (dname))
                    continue

                dag = Dag()
                dag.ParseFromString(lattice.reveal())
                dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                for fname in dag.functions:
                    call_frequency.setdefault(fname.name, 0)

            policy.update_function_locations(status.function_locations)

        if continuation_socket in socks and socks[continuation_socket] == \
                zmq.POLLIN:
            start_t = int(time.time() * 1000000)

            continuation = Continuation()
            continuation.ParseFromString(continuation_socket.recv())

            call = continuation.call
            call.name = continuation.name

            if call.name not in dags:
                # An unknown DAG would otherwise raise KeyError and take the
                # whole scheduler down; drop just this continuation instead.
                logging.error('Received continuation for unknown DAG %s; '
                              'dropping it.' % (call.name))
            else:
                result = Value()
                result.ParseFromString(continuation.result)

                # Feed the previous result into every source function.
                dag, sources = dags[call.name]
                for source in sources:
                    call.function_args[source].values.extend([result])

                call_dag(call, pusher_cache, dags, policy, continuation.id)
                sched_t = int(time.time() * 1000000)
                logging.info(
                    f'App function {call.name} recv: {start_t}, scheduled: {sched_t}'
                )

                for fname in dag.functions:
                    call_frequency[fname.name] += 1

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if not local:
                latest_schedulers = sched_utils.get_ip_set(
                    management_request_socket, False)
                if latest_schedulers:
                    schedulers = latest_schedulers

        if end - start > REPORT_THRESHOLD:
            # Broadcast every DAG we know and every known function location
            # to the other schedulers.
            status = SchedulerStatus()
            status.dags.extend(dags)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.debug('Reporting %d calls for function %s.' %
                              (call_frequency[fname], fname))

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())
            else:
                logging.info(stats)

            start = time.time()