def _get_func_list(self, prefix=None):
    """Request the list of function names from ``self.list_sock``.

    Args:
        prefix: optional string sent as the request body; a falsy value is
            replaced by the empty string.

    Returns:
        The ``keys`` field of the ``StringSet`` decoded from the reply.
    """
    request = '' if not prefix else prefix
    self.list_sock.send_string(request)

    reply = self.list_sock.recv()
    names = StringSet()
    names.ParseFromString(reply)
    return names.keys
def update(self):
    """Refresh the periodically maintained scheduling metadata.

    Three housekeeping steps are performed:

    1. Every timestamp in ``self.running_counts`` that is 5 or more seconds
       old is dropped.
    2. Every ``self.backoff`` entry recorded more than 5 seconds ago is
       removed.
    3. ``self.key_locations`` (cached key -> list of IPs) is cleared and
       rebuilt from the per-IP cache metadata stored in the KVS, for every IP
       that appears in ``self.thread_statuses``.

    A single ``now`` snapshot is used for every age check, so all entries are
    judged against the same cutoff.
    """
    now = time.time()

    # Periodically clean up the running counts map to drop any times older
    # than 5 seconds. Only values are replaced, so iterating the dict while
    # assigning to existing keys is safe.
    for executor in self.running_counts:
        self.running_counts[executor] = {
            ts for ts in self.running_counts[executor] if now - ts < 5}

    # Clean up any backoff messages that were added more than 5 seconds ago
    # -- this should be enough to drain a queue. Collect first, then delete,
    # so the dict is not resized while it is being iterated.
    expired = [executor for executor, added_at in self.backoff.items()
               if now - added_at > 5]
    for executor in expired:
        del self.backoff[executor]

    executor_ips = {status.ip for status in self.thread_statuses.values()}

    # Update the sets of keys that are being cached at each IP address.
    self.key_locations.clear()
    for ip in executor_ips:
        cache_key = get_cache_ip_key(ip)

        # This is of type LWWPairLattice, which has a StringSet protobuf
        # packed into it; we want the keys in that StringSet protobuf.
        lattice = self.kvs_client.get(cache_key)[cache_key]
        if lattice is None:
            # We will only get None if this executor is still joining; if
            # so, we just ignore this for now and move on.
            continue

        cached = StringSet()
        cached.ParseFromString(lattice.reveal())
        for key in cached.keys:
            self.key_locations.setdefault(key, []).append(ip)
def get_ip_set(request_ip, socket_cache, exec_threads=True):
    """Fetch the current set of IPs from the service behind ``request_ip``.

    Args:
        request_ip: address used to look up a socket in ``socket_cache``.
        socket_cache: object whose ``get`` returns a socket for an address.
        exec_threads: when true, expand every IP into one ``(ip, tid)`` pair
            for each ``tid`` in ``range(NUM_EXEC_THREADS)``.

    Returns:
        A set of ``(ip, tid)`` tuples when ``exec_threads`` is true,
        otherwise a set of IP strings.
    """
    sock = socket_cache.get(request_ip)

    # The request body is irrelevant: the response is always the same.
    sock.send(b'')

    reply = StringSet()
    reply.ParseFromString(sock.recv())

    if not exec_threads:
        return set(reply.keys)

    return {(addr, tid)
            for addr in reply.keys
            for tid in range(NUM_EXEC_THREADS)}
def get_ip_set(management_request_socket, exec_threads=True):
    """Ask the management server for the current set of IPs.

    Args:
        management_request_socket: socket used to send the request and
            receive the reply.
        exec_threads: when true, expand every IP into one ``(ip, tid)`` pair
            for each ``tid`` in ``range(NUM_EXEC_THREADS)``.

    Returns:
        A set of ``(ip, tid)`` tuples when ``exec_threads`` is true, a set of
        IP strings otherwise, or ``None`` when the receive fails with
        ``EAGAIN`` (e.g. a configured receive timeout expired).

    Raises:
        zmq.ZMQError: any receive error other than ``EAGAIN``, and any error
            raised by the send.
    """
    # We can send an empty request because the response is always the same.
    management_request_socket.send(b'')

    # Only the receive can fail with a ZMQError, so keep the try minimal.
    try:
        reply = management_request_socket.recv()
    except zmq.ZMQError as e:
        if e.errno == zmq.EAGAIN:
            return None
        raise

    ips = StringSet()
    ips.ParseFromString(reply)

    if exec_threads:
        return {(ip, tid) for ip in ips.keys
                for tid in range(NUM_EXEC_THREADS)}

    return set(ips.keys)
def test_metadata_update(self):
    '''
    This test calls the periodic metadata update protocol and ensures that
    the correct metadata is removed from the system and that the correct
    metadata is retrieved/updated from the KVS.
    '''
    # Create two executor threads on separate machines.
    old_ip = '127.0.0.1'
    new_ip = '192.168.0.1'
    old_executor = (old_ip, 1)
    new_executor = (new_ip, 2)

    old_status = ThreadStatus()
    old_status.ip = old_ip
    old_status.tid = 1
    old_status.running = True

    new_status = ThreadStatus()
    new_status.ip = new_ip
    new_status.tid = 2
    new_status.running = True

    self.policy.thread_statuses[old_executor] = old_status
    self.policy.thread_statuses[new_executor] = new_status

    # Add two executors, one with an old backoff and one with a new time.
    self.policy.backoff[old_executor] = time.time() - 10
    self.policy.backoff[new_executor] = time.time()

    # For the new executor, add 10 old running times and 10 new ones. The
    # short sleeps presumably keep consecutive timestamps distinct so that
    # all 10 of each survive in the set.
    self.policy.running_counts[new_executor] = set()
    for _ in range(10):
        time.sleep(.0001)
        self.policy.running_counts[new_executor].add(time.time() - 10)

    for _ in range(10):
        time.sleep(.0001)
        self.policy.running_counts[new_executor].add(time.time())

    # Publish some caching metadata into the KVS for each executor.
    old_set = StringSet()
    old_set.keys.extend(['key1', 'key2', 'key3'])
    new_set = StringSet()
    new_set.keys.extend(['key3', 'key4', 'key5'])
    self.kvs_client.put(get_cache_ip_key(old_ip),
                        LWWPairLattice(0, old_set.SerializeToString()))
    self.kvs_client.put(get_cache_ip_key(new_ip),
                        LWWPairLattice(0, new_set.SerializeToString()))

    self.policy.update()

    # Check that the metadata has been correctly pruned.
    self.assertEqual(len(self.policy.backoff), 1)
    self.assertIn(new_executor, self.policy.backoff)
    self.assertEqual(len(self.policy.running_counts[new_executor]), 10)

    # Check that the caching information is correct. assertEqual is needed
    # for the counts: assertTrue(x, y) treats y as the failure message and
    # only checks that x is truthy.
    locations = self.policy.key_locations
    self.assertEqual(len(locations['key1']), 1)
    self.assertEqual(len(locations['key2']), 1)
    self.assertEqual(len(locations['key3']), 2)
    self.assertEqual(len(locations['key4']), 1)
    self.assertEqual(len(locations['key5']), 1)

    self.assertIn(old_ip, locations['key1'])
    self.assertIn(old_ip, locations['key2'])
    self.assertIn(old_ip, locations['key3'])
    self.assertIn(new_ip, locations['key3'])
    self.assertIn(new_ip, locations['key4'])
    self.assertIn(new_ip, locations['key5'])
def scheduler(ip, mgmt_ip, route_addr):
    """Run the scheduler's main event loop; this function never returns.

    Binds the scheduler's sockets, starts the default scheduling policy, and
    then loops forever, serving the following:

    - connect requests;
    - function creation and invocation;
    - DAG creation, invocation and deletion;
    - function listing;
    - executor status reports;
    - metadata updates from other schedulers.

    Once more than METADATA_THRESHOLD seconds have passed since the last
    report, it refreshes the policy metadata and, in cluster mode, the
    scheduler list. Once more than REPORT_THRESHOLD seconds have passed, it
    sends its DAGs and function locations to the other schedulers and
    reports call statistics.

    Args:
        ip: this scheduler's IP address. It is passed to the KVS client and
            the policy, and skipped when sending to other schedulers.
        mgmt_ip: management server IP, or ``None`` to run in local mode.
        route_addr: KVS routing address. It is also sent back to clients
            that send a connect request.
    """
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)

    kvs = AnnaTcpClient(route_addr, ip, local=local)

    context = zmq.Context(1)

    # A mapping from a DAG's name to a (Dag protobuf,
    # sched_utils.find_dag_source(dag)) tuple.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = []

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    requestor_cache = SocketCache(context, zmq.REQ)
    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    for sock in (connect_socket, func_create_socket, func_call_socket,
                 dag_create_socket, dag_call_socket, dag_delete_socket,
                 list_socket, exec_status_socket, sched_update_socket):
        poller.register(sock, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket, pusher_cache,
                                              kvs, ip, local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored; every client gets the KVS route
            # address back.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            t = time.time()
            if name in last_arrivals:
                interarrivals.setdefault(name, []).append(
                    t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                continue

            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and
                socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            # The request body is passed through as the name prefix.
            prefix = list_socket.recv_string()

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if (exec_status_socket in socks and
                socks[exec_status_socket] == zmq.POLLIN):
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if (sched_update_socket in socks and
                socks[sched_update_socket] == zmq.POLLIN):
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # kvs.get returns a mapping from the requested name to its
                    # value; a failed lookup shows up as a None *value*, so
                    # retry on the value (`None in payload` only tested the
                    # keys and never retried).
                    payload = kvs.get(dname)
                    while payload.get(dname) is None:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    # Key by function name, matching the DAG-call path that
                    # increments call_frequency[fname.name].
                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if mgmt_ip:
                schedulers = sched_utils.get_ip_set(
                    sched_utils.get_scheduler_list_address(mgmt_ip),
                    requestor_cache, False)

        if end - start > REPORT_THRESHOLD:
            # NOTE(review): the return value was never used here; the call is
            # kept in case it has side effects on the policy -- confirm.
            policy.get_unique_executors()

            status = SchedulerStatus()
            status.dags.extend(dags.keys())

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.',
                             call_frequency[fname], fname)

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()
def scheduler(ip, mgmt_ip, route_addr, policy_type):
    """Run the scheduler's main event loop; this function never returns.

    Binds the scheduler's sockets, starts the scheduling policy, and then
    loops forever, serving the following:

    - connect requests;
    - function creation and invocation, both direct and from the queue
      socket;
    - DAG creation, invocation and deletion;
    - function listing;
    - executor status reports;
    - metadata updates from other schedulers;
    - DAG continuations.

    Once more than METADATA_THRESHOLD seconds have passed since the last
    report, it refreshes the policy metadata and, in cluster mode, the
    scheduler list. Once more than REPORT_THRESHOLD seconds have passed, it
    sends its DAGs and function locations to the other schedulers and
    reports call statistics.

    Args:
        ip: this scheduler's IP address. It is passed to the KVS client and
            the policy, and skipped when sending to other schedulers.
        mgmt_ip: management server IP, or ``None`` to run in local mode.
        route_addr: KVS routing address. It is also sent back to clients
            that send a connect request.
        policy_type: forwarded to ``DefaultCloudburstSchedulerPolicy``.
    """
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)

    kvs = AnnaTcpClient(route_addr, ip, local=local)

    context = zmq.Context(1)
    context.set(zmq.MAX_SOCKETS, 10000)

    # A mapping from a DAG's name to a (Dag protobuf,
    # sched_utils.find_dag_source(dag)) tuple.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a set of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = set()

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    # This handles invocations that arrive from a queue (mainly for storage
    # events).
    func_call_queue_socket = context.socket(zmq.PULL)
    func_call_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                                (FUNC_CALL_QUEUE_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    continuation_socket = context.socket(zmq.PULL)
    continuation_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.CONTINUATION_PORT))

    if not local:
        management_request_socket = context.socket(zmq.REQ)
        management_request_socket.setsockopt(zmq.RCVTIMEO, 500)
        # By setting this flag, zmq matches replies with requests.
        management_request_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        # Relax strict alternation between request and reply.
        # For detailed explanation, see here: http://api.zeromq.org/4-1:zmq-setsockopt
        management_request_socket.setsockopt(zmq.REQ_RELAXED, 1)
        management_request_socket.connect(
            sched_utils.get_scheduler_list_address(mgmt_ip))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    for sock in (connect_socket, func_create_socket, func_call_socket,
                 func_call_queue_socket, dag_create_socket, dag_call_socket,
                 dag_delete_socket, list_socket, exec_status_socket,
                 sched_update_socket, continuation_socket):
        poller.register(sock, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket, pusher_cache,
                                              kvs, ip, policy_type,
                                              local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored; every client gets the KVS route
            # address back.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (func_call_queue_socket in socks and
                socks[func_call_queue_socket] == zmq.POLLIN):
            call_function_from_queue(func_call_queue_socket, pusher_cache,
                                     policy)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            start_t = int(time.time() * 1000000)
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            t = time.time()
            if name in last_arrivals:
                interarrivals.setdefault(name, []).append(
                    t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                continue

            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            sched_t = int(time.time() * 1000000)
            logging.info(
                f'App function {name} recv: {start_t}, scheduled: {sched_t}')
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and
                socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            # The request body is passed through as the name prefix.
            prefix = list_socket.recv_string()

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if (exec_status_socket in socks and
                socks[exec_status_socket] == zmq.POLLIN):
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if (sched_update_socket in socks and
                socks[sched_update_socket] == zmq.POLLIN):
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # kvs.get returns a mapping from the requested name to its
                    # value; a failed lookup shows up as a None *value*, so
                    # retry on the value (`None in payload` only tested the
                    # keys and never retried).
                    payload = kvs.get(dname)
                    while payload.get(dname) is None:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        if (continuation_socket in socks and
                socks[continuation_socket] == zmq.POLLIN):
            start_t = int(time.time() * 1000000)

            continuation = Continuation()
            continuation.ParseFromString(continuation_socket.recv())

            call = continuation.call
            call.name = continuation.name

            if call.name not in dags:
                # An unknown DAG name would otherwise raise KeyError and
                # terminate the whole scheduler loop; drop it instead.
                logging.error(
                    f'Dropping continuation for unknown DAG {call.name}.')
            else:
                result = Value()
                result.ParseFromString(continuation.result)

                dag, sources = dags[call.name]
                for source in sources:
                    call.function_args[source].values.extend([result])

                call_dag(call, pusher_cache, dags, policy, continuation.id)
                sched_t = int(time.time() * 1000000)
                logging.info(
                    f'App function {call.name} recv: {start_t}, '
                    f'scheduled: {sched_t}')

                for fname in dag.functions:
                    call_frequency[fname.name] += 1

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # In local mode there is no need to deal with caches and other
            # schedulers.
            if not local:
                latest_schedulers = sched_utils.get_ip_set(
                    management_request_socket, False)
                # A falsy result (e.g. None on a receive timeout) keeps the
                # previously known scheduler set.
                if latest_schedulers:
                    schedulers = latest_schedulers

        if end - start > REPORT_THRESHOLD:
            status = SchedulerStatus()
            status.dags.extend(dags.keys())

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.debug('Reporting %d calls for function %s.',
                              call_frequency[fname], fname)

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()