def scheduler(ip, mgmt_ip, route_addr):
    '''
    Run a scheduler event loop that serves function/DAG creation and
    invocation requests, tracks executor thread status and function
    placement, and periodically gossips metadata to other schedulers and
    reports call statistics. Loops forever (`while True` with no exit).

    ip: This scheduler's IP; passed to AnnaClient and create_dag, and used to
        avoid sending scheduler updates to ourselves.
    mgmt_ip: Management server IP; passed to _update_cluster_state and used to
        build the statistics report address.
    route_addr: Routing-tier address; handed to AnnaClient and returned to
        any client that connects on CONNECT_PORT.

    NOTE(review): SOURCE contains a later `def scheduler` with the same
    signature; if both live in one module, the later one shadows this one.
    '''
    logging.basicConfig(filename='log_scheduler.txt', level=logging.INFO,
                        format='%(asctime)s %(message)s')

    kvs = AnnaClient(route_addr, ip)

    key_ip_map = {}
    ctx = zmq.Context(1)

    # Each dag consists of a set of functions and connections. Each one of
    # the functions is pinned to one or more nodes, which is tracked here.
    dags = {}
    # (ip, tid) -> most recent ThreadStatus received from that executor.
    thread_statuses = {}
    # function name -> set of (ip, tid) locations where it is pinned.
    func_locations = {}
    # executor -> set of call timestamps; pruned to the last 2.5 s below.
    running_counts = {}
    # (node, tid) -> time a backoff message was received; expired after 5 s.
    backoff = {}

    connect_socket = ctx.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = ctx.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = ctx.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = ctx.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = ctx.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    list_socket = ctx.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = ctx.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = ctx.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    backoff_socket = ctx.socket(zmq.PULL)
    backoff_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.BACKOFF_PORT))

    # Not registered with the poller; it is handed to create_dag, and the
    # 500 ms receive timeout bounds how long a blocking recv on it can wait.
    pin_accept_socket = ctx.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    requestor_cache = SocketCache(ctx, zmq.REQ)
    pusher_cache = SocketCache(ctx, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(backoff_socket, zmq.POLLIN)

    executors = set()
    executor_status_map = {}
    schedulers = _update_cluster_state(requestor_cache, mgmt_ip, executors,
                                       key_ip_map, kvs)

    # track how often each DAG function is called
    call_frequency = {}

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        # A client asks where the routing tier is; reply with route_addr.
        # The request payload itself is ignored.
        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            msg = connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_func(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, executors,
                          key_ip_map, executor_status_map, running_counts,
                          backoff)

        if (dag_create_socket in socks and socks[dag_create_socket]
                == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, executors, dags,
                       ip, pin_accept_socket, func_locations, call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            if call.name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                # Skips all remaining handlers and periodic work for this
                # loop iteration.
                continue

            # NOTE(review): exec_id is never used below.
            exec_id = generate_timestamp(0)

            dag = dags[call.name]
            # Raises KeyError if fname was never put into call_frequency;
            # presumably create_dag (not visible here) initializes it —
            # confirm.
            for fname in dag[0].functions:
                call_frequency[fname] += 1

            rid = call_dag(call, pusher_cache, dags, func_locations,
                           key_ip_map, running_counts, backoff)

            resp = GenericResponse()
            resp.success = True
            resp.response_id = rid
            dag_call_socket.send(resp.SerializeToString())

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            logging.info('Received query for function list.')
            msg = list_socket.recv_string()
            # An empty message lists every function.
            prefix = msg if msg else ''

            resp = FunctionList()
            resp.names.extend(utils._get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            key = (status.ip, status.tid)
            logging.info('Received status update from executor %s:%d.' %
                         (key[0], int(key[1])))

            # Throttle updates for executors tracked in executor_status_map
            # (populated by call_function, not visible here).
            if key in executor_status_map:
                if status.type == PERIODIC:
                    # NOTE(review): if the map stores past timestamps, this
                    # difference is always negative and the periodic status is
                    # always dropped; the operands may be reversed — confirm.
                    if executor_status_map[key] - time.time() > 5:
                        del executor_status_map[key]
                    else:
                        # NOTE(review): `continue` restarts the outer
                        # `while True`, skipping the sched-update, backoff,
                        # cleanup and THRESHOLD reporting work below.
                        continue
                elif status.type == POST_REQUEST:
                    del executor_status_map[key]

            # this means that this node is currently departing, so we remove it
            # from all of our metadata tracking
            if not status.running:
                if key in thread_statuses:
                    old_status = thread_statuses[key]
                    del thread_statuses[key]

                    # Raises KeyError if fname is not in func_locations.
                    for fname in old_status.functions:
                        func_locations[fname].discard((old_status.ip,
                                                       old_status.tid))

                executors.discard(key)
                continue

            if key not in executors:
                executors.add(key)

            if key in thread_statuses and thread_statuses[key] != status:
                # remove all the old function locations, and all the new ones
                # -- there will probably be a large overlap, but this shouldn't
                # be much different than calculating two different set
                # differences anyway
                for func in thread_statuses[key].functions:
                    if func in func_locations:
                        func_locations[func].discard(key)

            thread_statuses[key] = status
            for func in status.functions:
                if func not in func_locations:
                    func_locations[func] = set()

                func_locations[func].add(key)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # retrieve any DAG that some other scheduler knows about that we do
            # not yet know about
            for dname in status.dags:
                if dname not in dags:
                    # Retries the KVS read until it returns a truthy payload;
                    # there is no backoff or retry limit.
                    payload = kvs.get(dname)
                    while not payload:
                        payload = kvs.get(dname)
                    dag = Dag()
                    dag.ParseFromString(payload.reveal()[1])

                    dags[dag.name] = (dag, utils._find_dag_source(dag))

                    for fname in dag.functions:
                        if fname not in call_frequency:
                            call_frequency[fname] = 0

                        if fname not in func_locations:
                            func_locations[fname] = set()

            # Merge the peer's view of where functions are pinned.
            for floc in status.func_locations:
                key = (floc.ip, floc.tid)
                fname = floc.name

                if fname not in func_locations:
                    func_locations[fname] = set()

                func_locations[fname].add(key)

        # A backoff message has the form "<node>:<tid>".
        if backoff_socket in socks and socks[backoff_socket] == zmq.POLLIN:
            msg = backoff_socket.recv_string()
            splits = msg.split(':')
            node, tid = splits[0], int(splits[1])

            backoff[(node, tid)] = time.time()

        # periodically clean up the running counts map
        # (keeps only call timestamps from the last 2.5 seconds)
        for executor in running_counts:
            call_times = running_counts[executor]
            new_set = set()
            for ts in call_times:
                if time.time() - ts < 2.5:
                    new_set.add(ts)

            running_counts[executor] = new_set

        # Forget backoff entries older than 5 seconds.
        remove_set = set()
        for executor in backoff:
            if time.time() - backoff[executor] > 5:
                remove_set.add(executor)

        for executor in remove_set:
            del backoff[executor]

        # Every THRESHOLD seconds: refresh cluster state, gossip our DAGs and
        # function locations to other schedulers, and report (then reset) the
        # per-function call counts to the management server.
        end = time.time()
        if end - start > THRESHOLD:
            schedulers = _update_cluster_state(requestor_cache, mgmt_ip,
                                               executors, key_ip_map, kvs)

            status = SchedulerStatus()
            for name in dags.keys():
                status.dags.append(name)

            for fname in func_locations:
                for loc in func_locations[fname]:
                    floc = status.func_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(utils._get_scheduler_update_address
                                            (sched_ip))
                    sckt.send(msg)

            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.statistics.add()
                fstats.fname = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.' %
                             (call_frequency[fname], fname))

                call_frequency[fname] = 0

            sckt = pusher_cache.get(sutils._get_statistics_report_address
                                    (mgmt_ip))
            sckt.send(stats.SerializeToString())

            start = time.time()
class AnnaTcpClient(BaseAnnaClient):
    '''
    TCP client for the Anna KVS. It queries the routing tier for the storage
    workers responsible for a key and caches those worker addresses
    separately for GET and PUT accesses.
    '''

    # Address that the routing tier may return alongside the storage workers.
    # Per the original TODO it is the master node's address; it is never
    # cached or used as a worker address. Override on the class (or an
    # instance) for other deployments.
    # TODO: Change here the IP to the address of the Master node!
    MONITOR_ADDRESS = 'tcp://192.168.0.31:5900'

    # Values accepted by _get_worker_address's `access_type` argument.
    _GET_ACCESS = 1
    _PUT_ACCESS = 2

    def __init__(self, elb_addr, ip, local=False, offset=0):
        '''
        The AnnaTcpClient allows you to interact with a local copy of Anna or
        with a remote cluster running on AWS.

        elb_addr: Either 127.0.0.1 (local mode) or the address of an AWS ELB
            for the routing tier
        ip: The IP address of the machine being used -- if None is provided,
            one is inferred by using socket.gethostbyname(); WARNING: this
            does not always work
        local: If True, only routing port 6450 is used (local mode);
            otherwise requests are spread over routing ports 6450-6453
        offset: A port numbering offset, which is only needed if multiple
            clients are running on the same machine
        '''
        self.elb_addr = elb_addr

        if local:
            self.elb_ports = [6450]
        else:
            self.elb_ports = list(range(6450, 6454))

        if ip:
            self.ut = UserThread(ip, offset)
        else:  # If the IP is not provided, we attempt to infer it.
            self.ut = UserThread(socket.gethostbyname(socket.gethostname()),
                                 offset)

        self.context = zmq.Context(1)

        # key -> list of worker addresses, one cache per access type.
        self.get_address_cache = {}
        self.put_address_cache = {}
        self.pusher_cache = SocketCache(self.context, zmq.PUSH)

        self.response_puller = self.context.socket(zmq.PULL)
        self.response_puller.bind(self.ut.get_request_pull_bind_addr())

        self.key_address_puller = self.context.socket(zmq.PULL)
        self.key_address_puller.bind(self.ut.get_key_address_bind_addr())

        self.rid = 0

    def get(self, keys):
        '''
        Fetch one or more keys, each from a single (randomly chosen) worker.

        keys: A single key or a list of keys.

        Returns a dict mapping every requested key to its deserialized value,
        or to None when no worker address is known, the server reported an
        error, or the server invalidated our address cache for that key.
        '''
        if not isinstance(keys, list):
            keys = [keys]

        worker_addresses = {}
        for key in keys:
            worker_addresses[key] = self._get_worker_address(
                key, self._GET_ACCESS)

        # Initialize all KV pairs to None. Only change a value if we get a
        # valid response from the server.
        kv_pairs = {key: None for key in keys}

        request_ids = []
        for key, address in worker_addresses.items():
            if address:
                send_sock = self.pusher_cache.get(address)

                req, _ = self._prepare_data_request([key])
                req.type = GET

                send_request(req, send_sock)
                request_ids.append(req.request_id)

        # Wait for all responses to return.
        responses = recv_response(request_ids, self.response_puller,
                                  KeyResponse)

        for response in responses:
            for tup in response.tuples:
                if tup.invalidate:
                    self._invalidate_cache(tup.key, 'get')

                if tup.error == NO_ERROR and not tup.invalidate:
                    kv_pairs[tup.key] = self._deserialize(tup)

        return kv_pairs

    def get_all(self, keys):
        '''
        Fetch each key from every worker that holds it and merge the
        returned lattices.

        keys: A list of keys. A non-list argument raises ValueError.

        Returns a dict mapping every requested key to the merged value, or
        None when no worker returned a valid value for it.
        '''
        # NOTE(review): the message says "single key", yet this rejects a
        # bare key and accepts a list of any length; kept for compatibility.
        if not isinstance(keys, list):
            raise ValueError('`get_all` currently only supports single key'
                             + ' GETs.')

        worker_addresses = {}
        for key in keys:
            worker_addresses[key] = self._get_worker_address(
                key, self._GET_ACCESS, pick=False)

        # Initialize all KV pairs to None. Only change a value if we get a
        # valid response from the server.
        kv_pairs = {key: None for key in keys}

        for key in keys:
            if not worker_addresses[key]:
                continue

            # _prepare_data_request takes a list of keys.
            req, _ = self._prepare_data_request([key])
            req.type = GET

            # Send the same request to every replica, each with its own id.
            req_ids = []
            for address in worker_addresses[key]:
                req.request_id = self._get_request_id()

                send_sock = self.pusher_cache.get(address)
                send_request(req, send_sock)

                req_ids.append(req.request_id)

            responses = recv_response(req_ids, self.response_puller,
                                      KeyResponse)

            for resp in responses:
                for tup in resp.tuples:
                    if tup.invalidate:
                        self._invalidate_cache(tup.key)

                    if tup.error == NO_ERROR:
                        val = self._deserialize(tup)

                        if kv_pairs[tup.key] is not None:
                            kv_pairs[tup.key].merge(val)
                        else:
                            kv_pairs[tup.key] = val

        return kv_pairs

    def put(self, key, value):
        '''
        Write `value` under `key` on one worker returned by the routing tier.

        Returns True if the server acknowledged the write without error,
        False otherwise (including when no worker address is available).
        '''
        port = random.choice(self.elb_ports)
        worker_address = self._query_routing(key, port)

        # _query_routing returns a (possibly empty) list of addresses.
        if isinstance(worker_address, list):
            worker_address = worker_address[0] if worker_address else None

        if not worker_address:
            return False

        send_sock = self.pusher_cache.get(worker_address)

        # We pass in a list because the data request preparation can prepare
        # multiple tuples
        req, tup = self._prepare_data_request([key])
        req.type = PUT

        # PUT only supports one key operations, we only ever have to look at
        # the first KeyTuple returned.
        tup = tup[0]
        tup.payload, tup.lattice_type = self._serialize(value)

        send_request(req, send_sock)
        response = recv_response([req.request_id], self.response_puller,
                                 KeyResponse)[0]

        tup = response.tuples[0]

        if tup.invalidate:
            self._invalidate_cache(tup.key)

        return tup.error == NO_ERROR

    def put_all(self, key, value):
        '''
        Write `value` under `key` on every worker the routing tier returns.

        Returns True only if every replica acknowledged the write without
        error. Returns False if no worker address is known or any replica
        reported an error.
        '''
        worker_addresses = self._get_worker_address(key, self._PUT_ACCESS,
                                                    pick=False)

        if not worker_addresses:
            return False

        # _prepare_data_request takes a list of keys and returns one tuple
        # per key.
        req, tup = self._prepare_data_request([key])
        req.type = PUT
        tup = tup[0]
        tup.payload, tup.lattice_type = self._serialize(value)
        tup.timestamp = 0

        req_ids = []
        for address in worker_addresses:
            req.request_id = self._get_request_id()

            send_sock = self.pusher_cache.get(address)
            send_request(req, send_sock)

            req_ids.append(req.request_id)

        responses = recv_response(req_ids, self.response_puller, KeyResponse)

        for resp in responses:
            tup = resp.tuples[0]
            if tup.invalidate:
                # reissue the request
                self._invalidate_cache(tup.key)
                # NOTE(review): durable_put is not defined in this class;
                # confirm BaseAnnaClient provides it.
                return self.durable_put(key, value)

            if tup.error != NO_ERROR:
                return False

        return True

    # Returns the worker address for a particular key. If worker addresses for
    # that key are not cached locally, a query is synchronously issued to the
    # routing tier, and the address cache is updated.
    def _get_worker_address(self, key, access_type, pick=True):
        '''
        Return the worker address(es) for `key`.

        access_type: 1 (GET) uses get_address_cache, 2 (PUT) uses
            put_address_cache; any other value returns None.
        pick: If True, return one cached address chosen at random; otherwise
            return the cached list itself.

        Returns None when the routing tier yields no usable address. Empty
        results are not cached, so a later call queries the routing tier
        again.
        '''
        if access_type == self._GET_ACCESS:
            cache = self.get_address_cache
        elif access_type == self._PUT_ACCESS:
            cache = self.put_address_cache
        else:
            return None

        if key not in cache:
            port = random.choice(self.elb_ports)
            # De-duplicate and drop the monitor address; keep every worker.
            addresses = [address
                         for address in set(self._query_routing(key, port))
                         if address != self.MONITOR_ADDRESS]

            if not addresses:
                return None

            cache[key] = addresses

        addresses = cache[key]
        if not addresses:
            return None

        if pick:
            return random.choice(addresses)

        return addresses

    # Invalidates the address cache for a particular key when the server tells
    # the client that its cache is out of date.
    def _invalidate_cache(self, key, type='both'):
        '''
        Drop `key` from the GET cache ('get'), the PUT cache ('put'), or
        both (any other value). Keys that are not cached are ignored.
        '''
        if type == 'get':
            self.get_address_cache.pop(key, None)
        elif type == 'put':
            self.put_address_cache.pop(key, None)
        else:
            self.get_address_cache.pop(key, None)
            self.put_address_cache.pop(key, None)

    # Issues a synchronous query to the routing tier. Takes in a key and a
    # (randomly chosen) routing port to issue the request to. Returns a list of
    # addresses that the routing tier returned that correspond to the input
    # key.
    def _query_routing(self, key, port):
        '''
        Ask the routing tier (on `port`) which workers hold `key`.

        Returns the list of addresses for `key`, or [] if the routing tier
        reported an error.
        '''
        key_request = KeyAddressRequest()

        key_request.query_type = u'GET'
        key_request.response_address = self.ut.get_key_address_connect_addr()
        key_request.keys.append(key)
        key_request.request_id = self._get_request_id()

        dst_addr = 'tcp://' + self.elb_addr + ':' + str(port)
        send_sock = self.pusher_cache.get(dst_addr)

        send_request(key_request, send_sock)
        response = recv_response([key_request.request_id],
                                 self.key_address_puller,
                                 KeyAddressResponse)[0]

        if response.error != 0:
            return []

        result = []
        for t in response.addresses:
            if t.key == key:
                result.extend(t.ips)

        return result

    @property
    def response_address(self):
        '''Address that servers push KeyResponses back to.'''
        return self.ut.get_request_pull_connect_addr()
def scheduler(ip, mgmt_ip, route_addr):
    '''
    Run the Cloudburst scheduler event loop. Loops forever (`while True`
    with no exit).

    The loop serves function/DAG creation, invocation, deletion and listing
    requests. It feeds executor status updates and peer-scheduler updates to
    the scheduling policy. Periodically it gossips metadata to other
    schedulers and reports call statistics.

    ip: This scheduler's IP; used for the KVS client and the policy, and to
        avoid sending scheduler updates to ourselves.
    mgmt_ip: Management server IP, or None to run in local mode (no peer
        schedulers, statistics are logged instead of sent).
    route_addr: Routing-tier address; used by the KVS client and returned to
        any client that connects on CONNECT_PORT.
    '''
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)

    # A mapping from a DAG's name to its protobuf representation.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Number of requests received for each DAG since the last report. This is
    # reported as the DAG's call count and reset after every report.
    dag_call_counts = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = []

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    # Not polled here; handed to the policy, and the 500 ms receive timeout
    # bounds how long a blocking recv on it can wait.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    requestor_cache = SocketCache(context, zmq.REQ)
    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket, pusher_cache,
                                              kvs, ip, local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        # A client asks where the routing tier is; reply with route_addr.
        # The request payload itself is ignored.
        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            msg = connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (dag_create_socket in socks and socks[dag_create_socket]
                == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            t = time.time()
            if name in last_arrivals:
                if name not in interarrivals:
                    interarrivals[name] = []

                interarrivals[name].append(t - last_arrivals[name])

            last_arrivals[name] = t
            dag_call_counts[name] = dag_call_counts.get(name, 0) + 1

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                # Skips the remaining handlers for this iteration; the poller
                # reports any still-pending sockets again on the next poll.
                continue

            dag = dags[name]
            # dag[0].functions holds function refs; counts are keyed by name.
            for fref in dag[0].functions:
                call_frequency[fref.name] = \
                    call_frequency.get(fref.name, 0) + 1

            response = call_dag(call, pusher_cache, dags, policy)
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and socks[dag_delete_socket] ==
                zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            # An empty message lists every function.
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # kvs.get returns {key: value-or-None}; retry until the
                    # DAG's value is present. (`None in payload` would test
                    # the dict's keys, not its values.)
                    # NOTE(review): this retries without a limit or backoff.
                    payload = kvs.get(dname)
                    while payload.get(dname) is None:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    # Function refs carry their name in `.name`; counts are
                    # keyed by that name (as in the DAG-call path).
                    for fref in dag.functions:
                        if fref.name not in call_frequency:
                            call_frequency[fref.name] = 0

            policy.update_function_locations(status.function_locations)

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if mgmt_ip:
                schedulers = sched_utils.get_ip_set(
                    sched_utils.get_scheduler_list_address(mgmt_ip),
                    requestor_cache, False)

        if end - start > REPORT_THRESHOLD:
            # NOTE(review): num_unique_executors, key and data are computed
            # but never used.
            num_unique_executors = policy.get_unique_executors()
            key = scheduler_id + ':' + str(time.time())
            data = {'key': key, 'count': num_unique_executors}

            # Gossip our DAG names and function locations to peer schedulers.
            status = SchedulerStatus()
            for name in dags.keys():
                status.dags.append(name)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            # Per-function call counts since the last report, then reset.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.' %
                             (call_frequency[fname], fname))

                call_frequency[fname] = 0

            # Per-DAG call counts and interarrival times since the last
            # report, then reset.
            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = dag_call_counts.get(dname, 0)
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            dag_call_counts.clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())
            else:
                logging.info(stats)

            start = time.time()
def executor(ip, mgmt_ip, schedulers, thread_id): # logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(asctime)s %(message)s') logging.basicConfig(filename='log_executor.txt', level=logging.INFO, filemode="w", format='%(asctime)s %(message)s') # Check what resources we have access to, set as an environment variable. if os.getenv('EXECUTOR_TYPE', 'CPU') == 'GPU': exec_type = GPU else: exec_type = CPU context = zmq.Context(1) poller = zmq.Poller() pin_socket = context.socket(zmq.PULL) pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id)) unpin_socket = context.socket(zmq.PULL) unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.UNPIN_PORT + thread_id)) exec_socket = context.socket(zmq.PULL) exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.FUNC_EXEC_PORT + thread_id)) dag_queue_socket = context.socket(zmq.PULL) dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_QUEUE_PORT + thread_id)) dag_exec_socket = context.socket(zmq.PULL) dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_EXEC_PORT + thread_id)) self_depart_socket = context.socket(zmq.PULL) self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.SELF_DEPART_PORT + thread_id)) pusher_cache = SocketCache(context, zmq.PUSH) poller = zmq.Poller() poller.register(pin_socket, zmq.POLLIN) poller.register(unpin_socket, zmq.POLLIN) poller.register(exec_socket, zmq.POLLIN) poller.register(dag_queue_socket, zmq.POLLIN) poller.register(dag_exec_socket, zmq.POLLIN) poller.register(self_depart_socket, zmq.POLLIN) # If the management IP is set to None, that means that we are running in # local mode, so we use a regular AnnaTcpClient rather than an IPC client. 
has_ephe = False if mgmt_ip: if 'STORAGE_OR_DEFAULT' in os.environ and os.environ[ 'STORAGE_OR_DEFAULT'] == '0': client = AnnaTcpClient(os.environ['ROUTE_ADDR'], ip, local=False, offset=thread_id) has_ephe = True else: client = AnnaIpcClient(thread_id, context) # force_remote_anna = 1 # if 'FORCE_REMOTE' in os.environ: # force_remote_anna = int(os.environ['FORCE_REMOTE']) # if force_remote_anna == 0: # remote anna only # client = AnnaTcpClient(os.environ['ROUTE_ADDR'], ip, local=False, offset=thread_id) # elif force_remote_anna == 1: # anna cache # client = AnnaIpcClient(thread_id, context) # elif force_remote_anna == 2: # control both cache and remote anna # remote_client = AnnaTcpClient(os.environ['ROUTE_ADDR'], ip, local=False, offset=thread_id) # cache_client = AnnaIpcClient(thread_id, context) # client = cache_client # user_library = CloudburstUserLibrary(context, pusher_cache, ip, thread_id, (cache_client, remote_client)) local = False else: client = AnnaTcpClient('127.0.0.1', '127.0.0.1', local=True, offset=1) local = True user_library = CloudburstUserLibrary(context, pusher_cache, ip, thread_id, client, has_ephe=has_ephe) status = ThreadStatus() status.ip = ip status.tid = thread_id status.running = True status.type = exec_type utils.push_status(schedulers, pusher_cache, status) departing = False # Maintains a request queue for each function pinned on this executor. Each # function will have a set of request IDs mapped to it, and this map stores # a schedule for each request ID. queue = {} # Tracks the actual function objects that are pinned to this executor. function_cache = {} # Tracks runtime cost of excuting a DAG function. runtimes = {} # If multiple triggers are necessary for a function, track the triggers as # we receive them. This is also used if a trigger arrives before its # corresponding schedule. received_triggers = {} # Tracks when we received a function request, so we can report end-to-end # latency for the whole executio. 
receive_times = {} # Tracks the number of requests we are finishing for each function pinned # here. exec_counts = {} # Tracks the end-to-end runtime of each DAG request for which we are the # sink function. dag_runtimes = {} # A map with KVS keys and their corresponding deserialized payloads. cache = {} # A map which tracks the most recent DAGs for which we have finished our # work. finished_executions = {} # The set of pinned functions and whether they support batching. NOTE: This # is only a set for local mode -- in cluster mode, there will only be one # pinned function per executor. batching = False # Internal metadata to track thread utilization. report_start = time.time() event_occupancy = { 'pin': 0.0, 'unpin': 0.0, 'func_exec': 0.0, 'dag_queue': 0.0, 'dag_exec': 0.0 } total_occupancy = 0.0 while True: socks = dict(poller.poll(timeout=1000)) if pin_socket in socks and socks[pin_socket] == zmq.POLLIN: work_start = time.time() batching = pin(pin_socket, pusher_cache, client, status, function_cache, runtimes, exec_counts, user_library, local, batching) utils.push_status(schedulers, pusher_cache, status) elapsed = time.time() - work_start event_occupancy['pin'] += elapsed total_occupancy += elapsed if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN: work_start = time.time() unpin(unpin_socket, status, function_cache, runtimes, exec_counts) utils.push_status(schedulers, pusher_cache, status) elapsed = time.time() - work_start event_occupancy['unpin'] += elapsed total_occupancy += elapsed if exec_socket in socks and socks[exec_socket] == zmq.POLLIN: work_start = time.time() # logging.info(f'Executor timer. 
exec_socket recv: {work_start}') exec_function(exec_socket, client, user_library, cache, function_cache, has_ephe=has_ephe) user_library.close() utils.push_status(schedulers, pusher_cache, status) elapsed = time.time() - work_start event_occupancy['func_exec'] += elapsed total_occupancy += elapsed if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN: work_start = time.time() logging.info( f'Executor timer. dag_queue_socket recv: {work_start}') # In order to effectively support batching, we have to make sure we # dequeue lots of schedules in addition to lots of triggers. Right # now, we're not going to worry about supporting batching here, # just on the trigger dequeue side, but we still have to dequeue # all schedules we've received. We just process them one at a time. while True: schedule = DagSchedule() try: msg = dag_queue_socket.recv(zmq.DONTWAIT) except zmq.ZMQError as e: if e.errno == zmq.EAGAIN: break # There are no more messages. else: raise e # Unexpected error. schedule.ParseFromString(msg) fname = schedule.target_function logging.info( 'Received a schedule for DAG %s (%s), function %s.' % (schedule.dag.name, schedule.id, fname)) if fname not in queue: queue[fname] = {} queue[fname][schedule.id] = schedule if (schedule.id, fname) not in receive_times: receive_times[(schedule.id, fname)] = time.time() # In case we receive the trigger before we receive the schedule, we # can trigger from this operation as well. trkey = (schedule.id, fname) fref = None # Check to see what type of execution this function is. 
for ref in schedule.dag.functions: if ref.name == fname: fref = ref if (trkey in received_triggers and ((len(received_triggers[trkey]) == len(schedule.triggers)) or (fref.type == MULTIEXEC))): triggers = list(received_triggers[trkey].values()) if fname not in function_cache: logging.error('%s not in function cache', fname) utils.generate_error_response(schedule, client, fname) continue exec_start = time.time() # logging.info(f'Executor timer. dag_queue_socket exec_dag: {exec_start}') # We don't support actual batching for when we receive a # schedule before a trigger, so everything is just a batch of # size 1 if anything. success = exec_dag_function(pusher_cache, client, [triggers], function_cache[fname], [schedule], user_library, dag_runtimes, cache, schedulers, batching)[0] user_library.close() del received_triggers[trkey] if success: del queue[fname][schedule.id] fend = time.time() fstart = receive_times[(schedule.id, fname)] runtimes[fname].append(fend - work_start) exec_counts[fname] += 1 finished_executions[(schedule.id, fname)] = time.time() elapsed = time.time() - work_start event_occupancy['dag_queue'] += elapsed total_occupancy += elapsed if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN: work_start = time.time() # logging.info(f'Executor timer. dag_exec_socket recv: {work_start}') # How many messages to dequeue -- BATCH_SIZE_MAX or 1 depending on # the function configuration. if batching: count = BATCH_SIZE_MAX else: count = 1 trigger_keys = set() for _ in range(count): # Dequeue count number of messages. trigger = DagTrigger() try: msg = dag_exec_socket.recv(zmq.DONTWAIT) except zmq.ZMQError as e: if e.errno == zmq.EAGAIN: # There are no more messages. break else: raise e # Unexpected error. trigger.ParseFromString(msg) # We have received a repeated trigger for a function that has # already finished executing. 
if trigger.id in finished_executions: continue fname = trigger.target_function logging.info( 'Received a trigger for schedule %s, function %s.' % (trigger.id, fname)) key = (trigger.id, fname) trigger_keys.add(key) if key not in received_triggers: received_triggers[key] = {} if (trigger.id, fname) not in receive_times: receive_times[(trigger.id, fname)] = time.time() received_triggers[key][trigger.source] = trigger # Only execute the functions for which we have received a schedule. # Everything else will wait. for tid, fname in list(trigger_keys): if fname not in queue or tid not in queue[fname]: trigger_keys.remove((tid, fname)) if len(trigger_keys) == 0: continue fref = None schedule = queue[fname][list(trigger_keys)[0] [0]] # Pick a random schedule to check. # Check to see what type of execution this function is. for ref in schedule.dag.functions: if ref.name == fname: fref = ref break # Compile a list of all the trigger sets for which we have # enough triggers. trigger_sets = [] schedules = [] for key in trigger_keys: if (len(received_triggers[key]) == len(schedule.triggers)) or \ fref.type == MULTIEXEC: if fref.type == MULTIEXEC: triggers = [trigger] else: triggers = list(received_triggers[key].values()) if fname not in function_cache: logging.error('%s not in function cache', fname) utils.generate_error_response(schedule, client, fname) continue trigger_sets.append(triggers) schedule = queue[fname][key[0]] schedules.append(schedule) exec_start = time.time() # logging.info(f'Executor timer. dag_exec_socket exec_dag: {exec_start}') # Pass all of the trigger_sets into exec_dag_function at once. # We also include the batching variaible to make sure we know # whether to pass lists into the fn or not. 
if len(trigger_sets) > 0: successes = exec_dag_function(pusher_cache, client, trigger_sets, function_cache[fname], schedules, user_library, dag_runtimes, cache, schedulers, batching) user_library.close() del received_triggers[key] for key, success in zip(trigger_keys, successes): if success: del queue[fname][key[0]] # key[0] is trigger.id. fend = time.time() fstart = receive_times[key] average_time = (fend - work_start) / len(trigger_keys) runtimes[fname].append(average_time) exec_counts[fname] += 1 finished_executions[(schedule.id, fname)] = time.time() elapsed = time.time() - work_start event_occupancy['dag_exec'] += elapsed total_occupancy += elapsed if self_depart_socket in socks and socks[self_depart_socket] == \ zmq.POLLIN: # This message does not matter. self_depart_socket.recv() logging.info('Preparing to depart. No longer accepting requests ' + 'and clearing all queues.') status.ClearField('functions') status.running = False utils.push_status(schedulers, pusher_cache, status) departing = True # periodically report function occupancy report_end = time.time() if report_end - report_start > REPORT_THRESH: if len(cache) > 100: extra_keys = list(cache.keys())[:len(cache) - 100] for key in extra_keys: del cache[key] utilization = total_occupancy / (report_end - report_start) status.utilization = utilization # Periodically report my status to schedulers with the utilization # set. 
utils.push_status(schedulers, pusher_cache, status) logging.debug('Total thread occupancy: %.6f' % (utilization)) for event in event_occupancy: occ = event_occupancy[event] / (report_end - report_start) logging.debug('\tEvent %s occupancy: %.6f' % (event, occ)) event_occupancy[event] = 0.0 stats = ExecutorStatistics() for fname in runtimes: if exec_counts[fname] > 0: fstats = stats.functions.add() fstats.name = fname fstats.call_count = exec_counts[fname] fstats.runtime.extend(runtimes[fname]) runtimes[fname].clear() exec_counts[fname] = 0 for dname in dag_runtimes: dstats = stats.dags.add() dstats.name = dname dstats.runtimes.extend(dag_runtimes[dname]) dag_runtimes[dname].clear() # If we are running in cluster mode, mgmt_ip will be set, and we # will report our status and statistics to it. Otherwise, we will # write to the local conf file if mgmt_ip: sckt = pusher_cache.get( sutils.get_statistics_report_address(mgmt_ip)) sckt.send(stats.SerializeToString()) sckt = pusher_cache.get(utils.get_util_report_address(mgmt_ip)) sckt.send(status.SerializeToString()) else: logging.info(stats) status.ClearField('utilization') report_start = time.time() total_occupancy = 0.0 # Periodically clear any old functions we have cached that we are # no longer accepting requests for. del_list = [] for fname in queue: if len(queue[fname]) == 0 and fname not in status.functions: del_list.append(fname) del function_cache[fname] del runtimes[fname] del exec_counts[fname] for fname in del_list: del queue[fname] del_list = [] for tid in finished_executions: if (time.time() - finished_executions[tid]) > 10: del_list.append(tid) for tid in del_list: del finished_executions[tid] # If we are departing and have cleared our queues, let the # management server know, and exit the process. 
if departing and len(queue) == 0: sckt = pusher_cache.get(utils.get_depart_done_addr(mgmt_ip)) sckt.send_string(ip) # We specifically pass 1 as the exit code when ending our # process so that the wrapper script does not restart us. sys.exit(1)
def run(self_ip):
    """Run the management server's event loop forever.

    Binds management sockets on TCP ports 7000-7006 (plus
    ``PIN_ACCEPT_PORT``, which is only handed to the scaler) and services
    them in a poll loop:

    * 7000 (REP)  -- replies with the container restart count of the pod
      whose IP is the second ``:``-separated field of the request.
    * 7001 (PULL) -- ``add:...:...`` / ``remove:...:...`` churn requests,
      forwarded to the scaler.
    * 7002 (PULL) -- a reply address; the IPs of all ``role=function`` and
      ``role=gpu`` pods are pushed back to it as a ``StringSet``.
    * 7003 (PULL) -- ``ThreadStatus`` updates from executor threads.
    * 7004 (REP)  -- replies with the IPs of all ``role=scheduler`` pods.
    * 7005 (PULL) -- per-thread depart acknowledgements (an executor IP).
    * 7006 (PULL) -- ``ExecutorStatistics`` reports, aggregated per epoch.

    Every ``REPORT_PERIOD`` seconds it checks the hash ring, passes the
    aggregated statistics to the replica policy, and resets them.

    Args:
        self_ip: This server's IP, passed to ``DefaultScaler``.

    This function never returns.
    """
    context = zmq.Context(1)

    pusher_cache = SocketCache(context, zmq.PUSH)

    restart_pull_socket = context.socket(zmq.REP)
    restart_pull_socket.bind('tcp://*:7000')

    churn_pull_socket = context.socket(zmq.PULL)
    churn_pull_socket.bind('tcp://*:7001')

    list_executors_socket = context.socket(zmq.PULL)
    list_executors_socket.bind('tcp://*:7002')

    function_status_socket = context.socket(zmq.PULL)
    function_status_socket.bind('tcp://*:7003')

    list_schedulers_socket = context.socket(zmq.REP)
    list_schedulers_socket.bind('tcp://*:7004')

    executor_depart_socket = context.socket(zmq.PULL)
    executor_depart_socket.bind('tcp://*:7005')

    statistics_socket = context.socket(zmq.PULL)
    statistics_socket.bind('tcp://*:7006')

    # Not polled here; it is only handed to the scaler below.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind('tcp://*:' + PIN_ACCEPT_PORT)

    poller = zmq.Poller()
    poller.register(restart_pull_socket, zmq.POLLIN)
    poller.register(churn_pull_socket, zmq.POLLIN)
    poller.register(function_status_socket, zmq.POLLIN)
    poller.register(list_executors_socket, zmq.POLLIN)
    poller.register(list_schedulers_socket, zmq.POLLIN)
    poller.register(executor_depart_socket, zmq.POLLIN)
    poller.register(statistics_socket, zmq.POLLIN)

    add_push_socket = context.socket(zmq.PUSH)
    add_push_socket.connect('ipc:///tmp/node_add')

    remove_push_socket = context.socket(zmq.PUSH)
    remove_push_socket.connect('ipc:///tmp/node_remove')

    client, _ = util.init_k8s()

    scaler = DefaultScaler(self_ip, context, add_push_socket,
                           remove_push_socket, pin_accept_socket)
    policy = DefaultHydroPolicy(scaler)

    # Tracks the self-reported statuses of each executor thread, keyed by
    # (ip, tid).
    executor_statuses = {}

    # Maps each departing executor IP to the number of its threads that have
    # not yet acknowledged departure; the VM is removed when this reaches 0.
    # Nothing in this function adds entries -- only executor_policy (passed
    # this dict, currently disabled below) would.
    departing_executors = {}

    # Tracks how often each function is called.
    function_frequencies = {}

    # Tracks the aggregated (total runtime, processed call count) per function.
    function_runtimes = {}

    # Tracks the interarrival times of DAG requests.
    arrival_times = {}

    # Tracks how often each DAG is called.
    dag_frequencies = {}

    # Tracks how long each DAG request spends in the system, end to end.
    dag_runtimes = {}

    start = time.time()
    while True:
        socks = dict(poller.poll(timeout=1000))

        if (churn_pull_socket in socks and
                socks[churn_pull_socket] == zmq.POLLIN):
            msg = churn_pull_socket.recv_string()
            args = msg.split(':')

            # args[2] is passed first, matching remove_vms('function', ip)
            # below -- presumably the node kind; confirm against the scaler.
            if len(args) < 3:
                # Previously an IndexError here killed the server.
                logging.error('Ignoring malformed churn request: %s', msg)
            elif args[0] == 'add':
                scaler.add_vms(args[2], args[1])
            elif args[0] == 'remove':
                scaler.remove_vms(args[2], args[1])

        if (restart_pull_socket in socks and
                socks[restart_pull_socket] == zmq.POLLIN):
            msg = restart_pull_socket.recv_string()
            args = msg.split(':')

            pod = util.get_pod_from_ip(client, args[1])
            count = str(pod.status.container_statuses[0].restart_count)

            restart_pull_socket.send_string(count)

        if (list_executors_socket in socks and
                socks[list_executors_socket] == zmq.POLLIN):
            # The message body is the address to push the reply to.
            response_ip = list_executors_socket.recv_string()

            ips = StringSet()
            ips.keys.extend(util.get_pod_ips(client, 'role=function'))
            ips.keys.extend(util.get_pod_ips(client, 'role=gpu'))

            sckt = pusher_cache.get(response_ip)
            sckt.send(ips.SerializeToString())

        if (function_status_socket in socks and
                socks[function_status_socket] == zmq.POLLIN):
            # Dequeue all available ThreadStatus messages rather than doing
            # them one at a time---this prevents starvation if other operations
            # (e.g., pin) take a long time.
            while True:
                try:
                    raw = function_status_socket.recv(zmq.DONTWAIT)
                except zmq.ZMQError:
                    break  # We've run out of messages.

                status = ThreadStatus()
                try:
                    status.ParseFromString(raw)
                except Exception:
                    # A single bad message should not end the drain loop.
                    logging.exception('Dropping malformed ThreadStatus '
                                      'message.')
                    continue

                key = (status.ip, status.tid)

                # If this executor is one of the ones that's currently
                # departing, we can just ignore its status updates since we
                # don't want utilization to be skewed downwards. The reason we
                # might still receive this message is because the depart
                # message may not have arrived when this was sent.
                if key[0] in departing_executors:
                    continue

                executor_statuses[key] = status
                logging.info(f"Functions {status.functions} is placed on node "
                             f"{status.ip}:{status.tid}")

        if (list_schedulers_socket in socks and
                socks[list_schedulers_socket] == zmq.POLLIN):
            # We can safely ignore this message's contents, and the response
            # does not depend on it.
            list_schedulers_socket.recv_string()

            ips = StringSet()
            ips.keys.extend(util.get_pod_ips(client, 'role=scheduler'))

            list_schedulers_socket.send(ips.SerializeToString())

        if (executor_depart_socket in socks and
                socks[executor_depart_socket] == zmq.POLLIN):
            ip = executor_depart_socket.recv_string()

            if ip not in departing_executors:
                # No thread count was registered for this executor, so there
                # is nothing to decrement (this used to raise KeyError and
                # crash the management server).
                logging.warning('Ignoring depart acknowledgement from '
                                'untracked executor %s.', ip)
            else:
                departing_executors[ip] -= 1

                # We wait until all the threads at this executor have
                # acknowledged that they are ready to leave, and we then
                # remove the VM from the system.
                if departing_executors[ip] == 0:
                    logging.info('Removing node with ip %s' % ip)
                    scaler.remove_vms('function', ip)
                    del departing_executors[ip]

        if (statistics_socket in socks and
                socks[statistics_socket] == zmq.POLLIN):
            stats = ExecutorStatistics()
            stats.ParseFromString(statistics_socket.recv())

            # Aggregates statistics reported for individual functions
            # including call frequencies, processed requests, and total
            # runtimes.
            for fstats in stats.functions:
                fname = fstats.name
                function_frequencies.setdefault(fname, 0)
                function_runtimes.setdefault(fname, (0.0, 0))

                if fstats.runtime:
                    # Reports carrying runtimes add to the total runtime and
                    # the count of processed calls.
                    total_runtime, total_calls = function_runtimes[fname]
                    function_runtimes[fname] = (
                        total_runtime + sum(fstats.runtime),
                        total_calls + fstats.call_count)
                else:
                    # Reports without runtimes count calls made to the
                    # function.
                    function_frequencies[fname] += fstats.call_count

            # Aggregates statistics for DAG requests, including call
            # frequencies, arrival rates, and end-to-end runtimes.
            for dstats in stats.dags:
                dname = dstats.name

                # Interarrival times of requests to this DAG as perceived by
                # the reporter.
                arrival_times.setdefault(dname, []).extend(dstats.interarrival)

                # How many calls to this DAG were received.
                dag_frequencies[dname] = (dag_frequencies.get(dname, 0) +
                                          dstats.call_count)

                # End-to-end runtimes of requests completed in the last epoch.
                dag_runtimes.setdefault(dname, []).extend(dstats.runtimes)

        end = time.time()
        if end - start > REPORT_PERIOD:
            logging.info('Checking hash ring...')
            check_hash_ring(client, context)

            # Invoke the configured policy to check system load and respond
            # appropriately.
            policy.replica_policy(function_frequencies, function_runtimes,
                                  dag_runtimes, executor_statuses,
                                  arrival_times)
            # TODO(simon): this turn off node scaling policy, which is what we
            # want for static exp env
            # policy.executor_policy(executor_statuses, departing_executors)

            # Clears all metadata that was passed in for this epoch.
            function_runtimes.clear()
            function_frequencies.clear()
            dag_runtimes.clear()
            arrival_times.clear()

            # Restart the timer for the next reporting epoch.
            start = time.time()
class FluentUserLibrary(AbstractFluentUserLibrary):
    """Library object handed to user functions on a Fluent executor thread.

    Wraps the Anna KVS client for ``put``/``get`` and lets executor threads
    message each other: ``send`` pushes a pickled ``(sender, bytestr)`` pair
    to the destination thread's inbox port, and a background thread drains
    this thread's own inbox socket into ``recv_inbox`` for ``recv`` to read.
    """

    # Seconds the inbox listener sleeps when no message is waiting. The
    # listener reads with NOBLOCK, so without this sleep it would busy-spin.
    _IDLE_POLL_INTERVAL = 0.010

    def __init__(self, ip, tid, anna_client):
        """Set up messaging state and start the inbox listener thread.

        Args:
            ip: Executor IP.
            tid: Executor thread ID; also offsets the inbox port.
            anna_client: The Anna client, used for interfacing with the KVS.
        """
        self.ctx = zmq.Context()
        self.send_socket_cache = SocketCache(self.ctx, zmq.PUSH)

        self.executor_ip = ip
        self.executor_tid = tid
        self.client = anna_client

        # Threadsafe queue to serve as this node's inbox.
        # Items are (sender, message bytestring).
        # NB: currently unbounded in size.
        self.recv_inbox = queue.Queue()

        # Thread for receiving messages into our inbox. `do_run` is the stop
        # flag checked by the listener loop and cleared by close().
        self.recv_inbox_thread = threading.Thread(
            target=self._recv_inbox_listener)
        self.recv_inbox_thread.do_run = True
        self.recv_inbox_thread.start()

    def put(self, ref, ltc):
        """Store ``ltc`` under key ``ref`` and return the client's result."""
        return self.client.put(ref, ltc)

    def get(self, ref):
        """Fetch from the KVS.

        A list of keys returns the client's whole response; a single key
        returns only that key's entry from the response.
        """
        if isinstance(ref, list):
            return self.client.get(ref)
        return self.client.get(ref)[ref]

    def getid(self):
        """Return this executor thread's ``(ip, tid)`` identity."""
        return (self.executor_ip, self.executor_tid)

    def send(self, dest, bytestr):
        """Push ``bytestr`` to another executor thread's inbox.

        Args:
            dest: ``(IP string, thread id int)`` of the destination executor.
            bytestr: Payload; sent together with this thread's ``getid()``.
        """
        ip, tid = dest
        dest_addr = server_utils._get_user_msg_inbox_addr(ip, tid)
        sender = (self.executor_ip, self.executor_tid)

        sckt = self.send_socket_cache.get(dest_addr)
        sckt.send_pyobj((sender, bytestr))

    def close(self):
        """Stop the inbox listener thread and wait for it to exit."""
        self.recv_inbox_thread.do_run = False
        self.recv_inbox_thread.join()

    def recv(self):
        """Return all messages currently in the inbox without blocking.

        Returns:
            A list of ``(sender, msg)`` pairs in arrival order; empty if the
            inbox is empty.
        """
        res = []
        while True:
            try:
                (sender, msg) = self.recv_inbox.get(block=False)
            except queue.Empty:
                break
            res.append((sender, msg))
        return res

    def _recv_inbox_listener(self):
        """Move messages from this thread's inbox socket into recv_inbox.

        Runs on the background thread until close() clears ``do_run``.
        """
        # Socket for receiving send() messages from other nodes.
        recv_inbox_socket = self.ctx.socket(zmq.PULL)
        recv_inbox_socket.bind(server_utils.BIND_ADDR_TEMPLATE %
                               (server_utils.RECV_INBOX_PORT +
                                self.executor_tid))

        t = threading.current_thread()
        try:
            while t.do_run:
                try:
                    # NOTE(review): recv_pyobj unpickles whatever arrives on
                    # this port; anyone who can reach it can execute code.
                    (sender, msg) = recv_inbox_socket.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError as e:
                    if e.errno != zmq.EAGAIN:
                        raise
                    # Nothing waiting: back off instead of spinning (the old
                    # loop skipped its sleep on this path).
                    time.sleep(self._IDLE_POLL_INTERVAL)
                    continue

                self.recv_inbox.put((sender, msg))
        finally:
            recv_inbox_socket.close()
def executor(ip, mgmt_ip, schedulers, thread_id):
    """Run one executor thread's event loop.

    Binds PULL sockets for pin, unpin, function-exec, DAG-schedule,
    DAG-trigger and self-depart messages on ports offset by ``thread_id``,
    then services them. A DAG function runs once its schedule and as many
    triggers as the schedule lists have both arrived, in either order.
    Every ``REPORT_THRESH`` seconds it reports utilization and per-function
    statistics to the management server at ``mgmt_ip`` and drops empty
    queues of functions no longer listed in ``status.functions``.

    Args:
        ip: This executor's IP.
        mgmt_ip: Management server IP, used for all reports.
        schedulers: Scheduler addresses that receive status updates.
        thread_id: This thread's ID; offsets every bound port.

    Returns:
        0, after a self-depart request once every queue has been removed
        and the management server has been notified.
    """
    logging.basicConfig(filename='log_executor.txt', level=logging.INFO,
                        format='%(asctime)s %(message)s')

    ctx = zmq.Context(1)

    pin_socket = ctx.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = ctx.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.UNPIN_PORT +
                                                   thread_id))

    exec_socket = ctx.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.FUNC_EXEC_PORT +
                                                  thread_id))

    dag_queue_socket = ctx.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_QUEUE_PORT
                                                       + thread_id))

    dag_exec_socket = ctx.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_EXEC_PORT
                                                      + thread_id))

    self_depart_socket = ctx.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(ctx, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    client = IpcAnnaClient(thread_id)

    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    utils._push_status(schedulers, pusher_cache, status)

    departing = False

    # fname -> {schedule id -> DagSchedule} for every function pinned here.
    queue = {}

    # Tracks the actual function objects that we are storing here.
    pinned_functions = {}

    # Tracks the runtime cost of executing each DAG function.
    runtimes = {}

    # (schedule id, fname) -> {trigger source -> DagTrigger}. Triggers
    # accumulate here until the schedule arrives and all are present.
    received_triggers = {}

    # (schedule id, fname) -> time of first arrival, for end-to-end latency.
    # Entries are removed when the function executes.
    receive_times = {}

    # Tracks how many executions of each function finished this epoch.
    exec_counts = {}

    # Metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    }
    total_occupancy = 0.0

    while True:
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            pin(pin_socket, client, status, pinned_functions, runtimes,
                exec_counts)
            utils._push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, pinned_functions, runtimes,
                  exec_counts)
            utils._push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()

            schedule = DagSchedule()
            schedule.ParseFromString(dag_queue_socket.recv())
            fname = schedule.target_function

            logging.info('Received a schedule for DAG %s (%s), function %s.'
                         % (schedule.dag.name, schedule.id, fname))

            queue.setdefault(fname, {})[schedule.id] = schedule

            trkey = (schedule.id, fname)
            receive_times.setdefault(trkey, time.time())

            # In case we receive the triggers before the schedule, we can
            # trigger from this operation as well.
            if (trkey in received_triggers and
                    len(received_triggers[trkey]) == len(schedule.triggers)):
                exec_dag_function(pusher_cache, client,
                                  received_triggers[trkey],
                                  pinned_functions[fname], schedule)
                del received_triggers[trkey]
                del queue[fname][schedule.id]

                fend = time.time()
                # Pop so receive_times does not grow without bound.
                fstart = receive_times.pop(trkey)
                runtimes[fname] += fend - fstart
                exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()

            trigger = DagTrigger()
            trigger.ParseFromString(dag_exec_socket.recv())
            fname = trigger.target_function

            logging.info('Received a trigger for schedule %s, function %s.'
                         % (trigger.id, fname))

            key = (trigger.id, fname)
            # received_triggers is keyed by (id, fname); the old check on
            # trigger.id alone was always true, so every trigger wiped the
            # ones collected before it.
            received_triggers.setdefault(key, {})[trigger.source] = trigger
            receive_times.setdefault(key, time.time())

            if fname in queue and trigger.id in queue[fname]:
                schedule = queue[fname][trigger.id]
                if len(received_triggers[key]) == len(schedule.triggers):
                    exec_dag_function(pusher_cache, client,
                                      received_triggers[key],
                                      pinned_functions[fname], schedule)
                    del received_triggers[key]
                    del queue[fname][trigger.id]

                    fend = time.time()
                    fstart = receive_times.pop(key)
                    runtimes[fname] += fend - fstart
                    exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
                zmq.POLLIN:
            # The contents of this message do not matter.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils._push_status(schedulers, pusher_cache, status)

            departing = True

        # Periodically report function occupancy.
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            sckt = pusher_cache.get(utils._get_util_report_address(mgmt_ip))
            sckt.send(status.SerializeToString())

            logging.info('Total thread occupancy: %.6f%%' % (utilization))

            for event in event_occupancy:
                occ = event_occupancy[event]
                logging.info('Event %s occupancy: %.6f%%' % (event, occ))
                event_occupancy[event] = 0.0

            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.statistics.add()
                    fstats.fname = fname
                    fstats.runtime = runtimes[fname]
                    fstats.call_count = exec_counts[fname]

                runtimes[fname] = 0.0
                exec_counts[fname] = 0

            sckt = pusher_cache.get(
                sutils._get_statistics_report_address(mgmt_ip))
            sckt.send(stats.SerializeToString())

            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for. Collect them first: deleting
            # from `queue` while iterating over it raises RuntimeError.
            stale = [fname for fname, pending in queue.items()
                     if not pending and fname not in status.functions]
            for fname in stale:
                del queue[fname]
                del pinned_functions[fname]
                del runtimes[fname]
                del exec_counts[fname]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils._get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)
                return 0
def executor(ip, mgmt_ip, schedulers, thread_id):
    """Run one executor thread's event loop.

    Binds REP sockets for pin, unpin, function-exec and DAG-schedule
    requests and PULL sockets for DAG triggers and self-depart, on ports
    offset by ``thread_id``. A DAG schedule is accepted (``sutils.ok_resp``)
    or rejected with ``INVALID_TARGET``; each trigger runs its function
    against the queued schedule. Every ``REPORT_THRESH`` seconds it reports
    utilization and per-function runtimes to the management server at
    ``mgmt_ip`` and drops empty queues of functions no longer listed in
    ``status.functions``.

    Args:
        ip: This executor's IP; schedules targeting another IP are rejected.
        mgmt_ip: Management server IP, used for all reports.
        schedulers: Scheduler addresses that receive status updates.
        thread_id: This thread's ID; offsets every bound port.

    Returns:
        0, after a self-depart request once every queue has been removed
        and the management server has been notified.
    """
    logging.basicConfig(filename='log_executor.txt', level=logging.INFO)

    ctx = zmq.Context(1)

    pin_socket = ctx.socket(zmq.REP)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = ctx.socket(zmq.REP)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.UNPIN_PORT +
                                                   thread_id))

    exec_socket = ctx.socket(zmq.REP)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.FUNC_EXEC_PORT +
                                                  thread_id))

    dag_queue_socket = ctx.socket(zmq.REP)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_QUEUE_PORT
                                                       + thread_id))

    dag_exec_socket = ctx.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_EXEC_PORT
                                                      + thread_id))

    self_depart_socket = ctx.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(ctx, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    client = IpcAnnaClient()

    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    utils._push_status(schedulers, pusher_cache, status)

    departing = False

    # fname -> {schedule id -> DagSchedule} for every function pinned here.
    queue = {}

    # Tracks the actual function objects that we are storing here.
    pinned_functions = {}

    # Tracks the runtime cost of executing each DAG function.
    runtimes = {}

    # Metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    }
    total_occupancy = 0.0

    while True:
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            pin(pin_socket, client, status, pinned_functions, runtimes)
            utils._push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, pinned_functions, runtimes)
            utils._push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()

            schedule = DagSchedule()
            schedule.ParseFromString(dag_queue_socket.recv())
            fname = schedule.target_function

            logging.info('Received a schedule for DAG %s, function %s.' %
                         (schedule.dag.name, fname))

            # If we are trying to kill this node or unpin this function, we
            # don't accept requests anymore for DAG schedules; this also
            # checks to make sure it's the right IP for the target.
            if (not status.running
                    or (fname not in status.functions
                        and fname in queue
                        and schedule.id not in queue[fname])
                    or schedule.locations[fname].split(':')[0] != ip):
                sutils.error.error = INVALID_TARGET
                dag_queue_socket.send(sutils.error.SerializeToString())
                # Skips the rest of this loop iteration, including the
                # occupancy accounting below.
                continue

            queue.setdefault(fname, {})[schedule.id] = schedule
            dag_queue_socket.send(sutils.ok_resp)

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()

            trigger = DagTrigger()
            trigger.ParseFromString(dag_exec_socket.recv())
            fname = trigger.target_function

            # NOTE(review): raises KeyError if the trigger arrives before its
            # schedule or for a function not pinned here.
            exec_dag_function(pusher_cache, client, trigger,
                              pinned_functions[fname],
                              queue[fname][trigger.id])

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed
            runtimes[fname] += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
                zmq.POLLIN:
            # The contents of this message do not matter.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils._push_status(schedulers, pusher_cache, status)

            departing = True

        # Periodically report function occupancy.
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            sckt = pusher_cache.get(utils._get_util_report_address(mgmt_ip))
            sckt.send(status.SerializeToString())

            logging.info('Total thread occupancy: %.4f%%' % (utilization))

            for event in event_occupancy:
                occ = event_occupancy[event]
                logging.info('Event %s occupancy: %.4f%%' % (event, occ))
                event_occupancy[event] = 0.0

            stats = ExecutorStatistics()
            for fname in runtimes:
                fstats = stats.statistics.add()
                fstats.fname = fname
                fstats.runtime = runtimes[fname]
                runtimes[fname] = 0.0

            sckt = pusher_cache.get(
                sutils._get_statistics_report_address(mgmt_ip))
            sckt.send(stats.SerializeToString())

            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for. Collect them first: deleting
            # from `queue` while iterating over it raises RuntimeError.
            stale = [fname for fname, pending in queue.items()
                     if not pending and fname not in status.functions]
            for fname in stale:
                del queue[fname]
                del pinned_functions[fname]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils._get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)
                return 0
def executor(ip, mgmt_ip, schedulers, thread_id):
    """Run one Cloudburst executor thread's event loop.

    Binds PULL sockets for pin, unpin, function-exec, DAG-schedule,
    DAG-trigger and self-depart messages on ports offset by ``thread_id``,
    then services them. A DAG function runs once its schedule and as many
    triggers as the schedule lists have both arrived, in either order.
    Every ``REPORT_THRESH`` seconds it clears the payload cache, pushes
    status to the schedulers, and reports statistics to the management
    server (or logs them when ``mgmt_ip`` is unset). It also drops empty
    queues of functions no longer listed in ``status.functions``.

    Args:
        ip: This executor's IP.
        mgmt_ip: Management server IP; falsy means local mode, which uses
            an ``AnnaTcpClient`` on 127.0.0.1 instead of IPC.
        schedulers: Scheduler addresses that receive status updates.
        thread_id: This thread's ID; offsets every bound port.

    Never returns: after a self-depart request, once every queue has been
    removed, it notifies the management server and calls ``os._exit(1)``.
    """
    logging.basicConfig(filename='log_executor.txt', level=logging.INFO,
                        format='%(asctime)s %(message)s')

    context = zmq.Context(1)

    pin_socket = context.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = context.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.UNPIN_PORT +
                                                   thread_id))

    exec_socket = context.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.FUNC_EXEC_PORT +
                                                  thread_id))

    dag_queue_socket = context.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_QUEUE_PORT
                                                       + thread_id))

    dag_exec_socket = context.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_EXEC_PORT
                                                      + thread_id))

    self_depart_socket = context.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    # If the management IP is set to None, that means that we are running in
    # local mode, so we use a regular AnnaTcpClient rather than an IPC client.
    if mgmt_ip:
        client = AnnaIpcClient(thread_id, context)
    else:
        client = AnnaTcpClient('127.0.0.1', '127.0.0.1', local=True, offset=1)

    user_library = CloudburstUserLibrary(context, pusher_cache, ip, thread_id,
                                         client)

    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    utils.push_status(schedulers, pusher_cache, status)

    departing = False

    # Maintains a request queue for each function pinned on this executor. Each
    # function will have a set of request IDs mapped to it, and this map stores
    # a schedule for each request ID.
    queue = {}

    # Tracks the actual function objects that are pinned to this executor.
    function_cache = {}

    # Tracks the runtime cost of executing a DAG function.
    runtimes = {}

    # If multiple triggers are necessary for a function, track the triggers as
    # we receive them. This is also used if a trigger arrives before its
    # corresponding schedule.
    received_triggers = {}

    # Tracks when we received a function request, so we can report end-to-end
    # latency for the whole execution. Entries are removed when the function
    # executes so the map does not grow without bound.
    receive_times = {}

    # Tracks the number of requests we are finishing for each function pinned
    # here.
    exec_counts = {}

    # Tracks the end-to-end runtime of each DAG request for which we are the
    # sink function.
    dag_runtimes = {}

    # A map with KVS keys and their corresponding deserialized payloads.
    cache = {}

    # Internal metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {
        'pin': 0.0,
        'unpin': 0.0,
        'func_exec': 0.0,
        'dag_queue': 0.0,
        'dag_exec': 0.0
    }
    total_occupancy = 0.0

    while True:
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            pin(pin_socket, pusher_cache, client, status, function_cache,
                runtimes, exec_counts, user_library)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, function_cache, runtimes,
                  exec_counts)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, user_library, cache,
                          function_cache)
            user_library.close()

            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == zmq.POLLIN:
            work_start = time.time()

            schedule = DagSchedule()
            schedule.ParseFromString(dag_queue_socket.recv())
            fname = schedule.target_function

            logging.info('Received a schedule for DAG %s (%s), function %s.' %
                         (schedule.dag.name, schedule.id, fname))

            queue.setdefault(fname, {})[schedule.id] = schedule

            trkey = (schedule.id, fname)
            receive_times.setdefault(trkey, time.time())

            # In case we receive the trigger before we receive the schedule, we
            # can trigger from this operation as well.
            if (trkey in received_triggers and
                    len(received_triggers[trkey]) == len(schedule.triggers)):
                exec_dag_function(pusher_cache, client,
                                  received_triggers[trkey],
                                  function_cache[fname], schedule,
                                  user_library, dag_runtimes, cache)
                user_library.close()

                del received_triggers[trkey]
                del queue[fname][schedule.id]

                fend = time.time()
                fstart = receive_times.pop(trkey)
                runtimes[fname].append(fend - fstart)
                exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == zmq.POLLIN:
            work_start = time.time()

            trigger = DagTrigger()
            trigger.ParseFromString(dag_exec_socket.recv())
            fname = trigger.target_function

            logging.info('Received a trigger for schedule %s, function %s.' %
                         (trigger.id, fname))

            key = (trigger.id, fname)
            received_triggers.setdefault(key, {})[trigger.source] = trigger
            receive_times.setdefault(key, time.time())

            if fname in queue and trigger.id in queue[fname]:
                schedule = queue[fname][trigger.id]
                if len(received_triggers[key]) == len(schedule.triggers):
                    exec_dag_function(pusher_cache, client,
                                      received_triggers[key],
                                      function_cache[fname], schedule,
                                      user_library, dag_runtimes, cache)
                    user_library.close()

                    del received_triggers[key]
                    del queue[fname][trigger.id]

                    fend = time.time()
                    fstart = receive_times.pop(key)
                    runtimes[fname].append(fend - fstart)
                    exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
                zmq.POLLIN:
            # This message does not matter.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests ' +
                         'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils.push_status(schedulers, pusher_cache, status)

            departing = True

        # Periodically report function occupancy.
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            cache.clear()

            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            # Periodically report my status to schedulers with the utilization
            # set.
            utils.push_status(schedulers, pusher_cache, status)

            logging.info('Total thread occupancy: %.6f' % (utilization))

            for event in event_occupancy:
                occ = event_occupancy[event] / (report_end - report_start)
                logging.info('\tEvent %s occupancy: %.6f' % (event, occ))
                event_occupancy[event] = 0.0

            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.functions.add()
                    fstats.name = fname
                    fstats.call_count = exec_counts[fname]
                    fstats.runtime.extend(runtimes[fname])

                runtimes[fname].clear()
                exec_counts[fname] = 0

            for dname in dag_runtimes:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.runtimes.extend(dag_runtimes[dname])

                dag_runtimes[dname].clear()

            # If we are running in cluster mode, mgmt_ip will be set, and we
            # will report our status and statistics to it. Otherwise, we will
            # write to the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

                sckt = pusher_cache.get(utils.get_util_report_address(mgmt_ip))
                sckt.send(status.SerializeToString())
            else:
                logging.info(stats)

            status.ClearField('utilization')
            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we are
            # no longer accepting requests for. Collect them first so we do
            # not mutate `queue` while iterating over it.
            stale = [fname for fname, pending in queue.items()
                     if not pending and fname not in status.functions]
            for fname in stale:
                del function_cache[fname]
                del runtimes[fname]
                del exec_counts[fname]
                del queue[fname]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils.get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)

                # We specifically pass 1 as the exit code when ending our
                # process so that the wrapper script does not restart us.
                os._exit(1)
def scheduler(ip, mgmt_ip, route_addr, policy_type):
    """Run the policy-based scheduler event loop; this function never returns.

    Binds the scheduler's request/reply and pull sockets, starts a
    DefaultCloudburstSchedulerPolicy, and then services, in a poll loop:
    connection requests, function creation and invocation (direct and
    queued), DAG creation/invocation/deletion, function listing, executor
    status updates, updates from peer schedulers, and DAG continuations.
    Periodically it refreshes policy metadata, propagates its known DAGs and
    function locations to peer schedulers, and reports call statistics.

    Args:
        ip: This scheduler's IP address. Passed to the KVS client and the
            policy, and used to avoid sending updates to ourselves.
        mgmt_ip: The management server's IP address, or None to run in
            local mode (no peer-scheduler discovery, no statistics sending).
        route_addr: The KVS routing address; also returned verbatim to
            clients that send a request on the connect socket.
        policy_type: Forwarded to DefaultCloudburstSchedulerPolicy.
    """
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    context = zmq.Context(1)
    context.set(zmq.MAX_SOCKETS, 10000)

    # A mapping from a DAG's name to a tuple of its protobuf representation
    # and its source functions (as computed by sched_utils.find_dag_source).
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = set()

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    # This handles invocations that arrive via a queue (mainly storage
    # events).
    func_call_queue_socket = context.socket(zmq.PULL)
    func_call_queue_socket.bind(
        sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_QUEUE_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    continuation_socket = context.socket(zmq.PULL)
    continuation_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.CONTINUATION_PORT))

    if not local:
        management_request_socket = context.socket(zmq.REQ)
        management_request_socket.setsockopt(zmq.RCVTIMEO, 500)
        # By setting this flag, zmq matches replies with requests.
        management_request_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        # Relax strict alternation between request and reply.
        # For detailed explanation, see here: http://api.zeromq.org/4-1:zmq-setsockopt
        management_request_socket.setsockopt(zmq.REQ_RELAXED, 1)
        management_request_socket.connect(
            sched_utils.get_scheduler_list_address(mgmt_ip))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(func_call_queue_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(continuation_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket, pusher_cache,
                                              kvs, ip, policy_type,
                                              local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored; every client gets the routing
            # address.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (func_call_queue_socket in socks and
                socks[func_call_queue_socket] == zmq.POLLIN):
            call_function_from_queue(func_call_queue_socket, pusher_cache,
                                     policy)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            start_t = int(time.time() * 1000000)
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            t = time.time()
            if name in last_arrivals:
                if name not in interarrivals:
                    interarrivals[name] = []

                interarrivals[name].append(t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                continue

            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            sched_t = int(time.time() * 1000000)
            logging.info(
                f'App function {name} recv: {start_t}, scheduled: {sched_t}')
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and
                socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if (exec_status_socket in socks and
                socks[exec_status_socket] == zmq.POLLIN):
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if (sched_update_socket in socks and
                socks[sched_update_socket] == zmq.POLLIN):
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    payload = kvs.get(dname)
                    # Check the value stored for this key: `None in payload`
                    # on a mapping tests its keys, not its values, and would
                    # never retry. Presumably a not-yet-available value is
                    # returned as None -- TODO confirm against AnnaTcpClient.
                    while payload[dname] is None:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        if (continuation_socket in socks and
                socks[continuation_socket] == zmq.POLLIN):
            start_t = int(time.time() * 1000000)

            continuation = Continuation()
            continuation.ParseFromString(continuation_socket.recv())

            call = continuation.call
            call.name = continuation.name

            if call.name not in dags:
                # Without a known DAG we cannot route the continuation; log
                # and drop it rather than crash the event loop on a KeyError.
                logging.error('Received continuation for unknown DAG %s.' %
                              (call.name))
            else:
                result = Value()
                result.ParseFromString(continuation.result)

                dag, sources = dags[call.name]
                # The previous stage's result is fed to every source function
                # of the DAG.
                for source in sources:
                    call.function_args[source].values.extend([result])

                call_dag(call, pusher_cache, dags, policy, continuation.id)
                sched_t = int(time.time() * 1000000)
                logging.info(
                    f'App function {call.name} recv: {start_t}, scheduled: {sched_t}'
                )

                for fname in dag.functions:
                    call_frequency[fname.name] += 1

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if not local:
                latest_schedulers = sched_utils.get_ip_set(
                    management_request_socket, False)
                if latest_schedulers:
                    schedulers = latest_schedulers

        if end - start > REPORT_THRESHOLD:
            # Share our known DAG names and function locations with peers.
            status = SchedulerStatus()
            for name in dags.keys():
                status.dags.append(name)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            # Build per-function call counts and per-DAG interarrival stats,
            # resetting the counters for the next reporting window.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.debug('Reporting %d calls for function %s.' %
                              (call_frequency[fname], fname))

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                # N interarrival gaps correspond to N + 1 arrivals.
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only send the statistics if we are running in cluster mode.
            # In local mode they are not sent anywhere; only the per-function
            # debug log lines above are emitted.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()
def scheduler(ip, mgmt_ip, route_addr):
    """Run the legacy scheduler event loop; this function never returns.

    Logs to 'log_scheduler.txt', binds the scheduler's sockets, loads the
    initial cluster state via _update_cluster_state, and then services, in a
    poll loop: connection requests, function creation and invocation, DAG
    creation and invocation, function listing, executor status updates, and
    updates from other schedulers. Every THRESHOLD seconds it refreshes
    cluster state, sends its known DAG names to other schedulers, and
    reports per-function call counts to the management server.

    Args:
        ip: This scheduler's IP address. Passed to the KVS client and used
            to avoid sending updates to ourselves.
        mgmt_ip: The management server's IP address.
        route_addr: The KVS routing address; also returned verbatim to
            clients that send a request on the connect socket.
    """
    logging.basicConfig(filename='log_scheduler.txt', level=logging.INFO)

    kvs = AnnaClient(route_addr, ip)

    key_cache_map = {}
    key_ip_map = {}
    ctx = zmq.Context(1)

    # Each dag consists of a set of functions and connections. Each one of
    # the functions is pinned to one or more nodes, which is tracked here.
    dags = {}
    # (ip, tid) -> most recent ThreadStatus received from that executor.
    thread_statuses = {}
    # function name -> set of (ip, tid) executor keys hosting it.
    func_locations = {}

    connect_socket = ctx.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = ctx.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = ctx.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = ctx.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = ctx.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    list_socket = ctx.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = ctx.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = ctx.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    requestor_cache = SocketCache(ctx, zmq.REQ)
    pusher_cache = SocketCache(ctx, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)

    departed_executors = set()
    executors, schedulers = _update_cluster_state(requestor_cache, mgmt_ip,
                                                  departed_executors,
                                                  key_cache_map, key_ip_map,
                                                  kvs)

    # track how often each DAG function is called
    call_frequency = {}

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored; reply with the routing address.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_func(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, requestor_cache, executors,
                          key_ip_map)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, requestor_cache, kvs, executors,
                       dags, func_locations, call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            if call.name not in dags:
                # The REP socket must always answer; report failure instead
                # of crashing the event loop on a KeyError.
                resp = GenericResponse()
                resp.success = False
                dag_call_socket.send(resp.SerializeToString())
            else:
                dag = dags[call.name]
                for fname in dag[0].functions:
                    # DAGs learned from other schedulers may not have a
                    # counter yet, so default to zero.
                    call_frequency[fname] = call_frequency.get(fname, 0) + 1

                accepted, error, rid = call_dag(call, requestor_cache,
                                                pusher_cache, dags,
                                                func_locations, key_ip_map)
                while not accepted:
                    # the assumption here is that the request was not
                    # accepted because the cluster was out of date -- so we
                    # update cluster state before proceeding
                    # NOTE(review): this retries indefinitely and ignores
                    # `error`; confirm that is the intended contract.
                    executors, schedulers = _update_cluster_state(
                        requestor_cache, mgmt_ip, departed_executors,
                        key_cache_map, key_ip_map, kvs)
                    accepted, error, rid = call_dag(call, requestor_cache,
                                                    pusher_cache, dags,
                                                    func_locations,
                                                    key_ip_map)

                resp = GenericResponse()
                resp.success = True
                resp.response_id = rid
                dag_call_socket.send(resp.SerializeToString())

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            logging.info('Received query for function list.')
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = FunctionList()
            resp.names.extend(utils._get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if (exec_status_socket in socks and
                socks[exec_status_socket] == zmq.POLLIN):
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            key = (status.ip, status.tid)
            logging.info('Received status update from executor %s:%d.' %
                         (key[0], int(key[1])))

            # this means that this node is currently departing, so we remove
            # it from all of our metadata tracking
            if not status.running:
                old_status = thread_statuses.pop(key, None)
                executors.discard(key)
                departed_executors.add(status.ip)

                # Locations are keyed by the (ip, tid) tuple.
                if old_status is not None:
                    for fname in old_status.functions:
                        if fname in func_locations:
                            func_locations[fname].discard(key)
                continue

            old_status = thread_statuses.get(key)
            if old_status is None:
                # First status from this executor thread.
                if key not in executors:
                    executors.add(key)

            if old_status is None or old_status != status:
                # remove all the old function locations, and add all the new
                # ones -- there will probably be a large overlap, but this
                # shouldn't be much different than calculating two different
                # set differences anyway. A first status still needs its
                # functions registered.
                if old_status is not None:
                    for func in old_status.functions:
                        if func in func_locations:
                            func_locations[func].discard(key)

                for func in status.functions:
                    if func not in func_locations:
                        func_locations[func] = set()

                    func_locations[func].add(key)

                thread_statuses[key] = status

        if (sched_update_socket in socks and
                socks[sched_update_socket] == zmq.POLLIN):
            logging.info('Received update from another scheduler.')
            ks = KeySet()
            ks.ParseFromString(sched_update_socket.recv())

            # retrieve any DAG that some other scheduler knows about that we
            # do not yet know about
            for dname in ks.keys:
                if dname not in dags:
                    dag = Dag()
                    dag.ParseFromString(kvs.get(dname).value)

                    # NOTE(review): the DAG-call branch reads
                    # `dags[name][0].functions`, which suggests entries are
                    # tuples; storing the bare Dag here looks inconsistent --
                    # confirm against create_dag.
                    dags[dname] = dag

        end = time.time()
        if end - start > THRESHOLD:
            executors, schedulers = _update_cluster_state(
                requestor_cache, mgmt_ip, departed_executors, key_cache_map,
                key_ip_map, kvs)

            # Tell every other scheduler which DAGs we know about.
            dag_names = KeySet()
            for name in dags.keys():
                dag_names.keys.append(name)
            msg = dag_names.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        utils._get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            # Report and reset per-function call counts.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.statistics.add()
                fstats.fname = fname
                fstats.call_count = call_frequency[fname]
                call_frequency[fname] = 0

            sckt = pusher_cache.get(
                sutils._get_statistics_report_address(mgmt_ip))
            sckt.send(stats.SerializeToString())

            start = time.time()