def scheduler(ip, mgmt_ip, route_addr):
    '''
    Runs the scheduler's main event loop. This function does not return.

    One ZeroMQ socket is bound per request type, the policy engine is
    started, and then each loop iteration polls the sockets (with a one
    second timeout), dispatches whatever messages are pending, and -- once
    the corresponding thresholds have elapsed -- refreshes the policy
    metadata and reports statistics.

    Args:
        ip: This scheduler's own IP address. It is handed to the KVS client
            and the policy engine, and is used to skip this node when
            broadcasting status to the other schedulers.
        mgmt_ip: The IP address of the management server, or None when
            running in local mode.
        route_addr: The address handed to the KVS client; it is also sent
            back verbatim to any client that connects.
    '''
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    context = zmq.Context(1)

    # A mapping from a DAG's name to a tuple of its protobuf representation
    # and the result of sched_utils.find_dag_source for that DAG.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = []

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    # The pin-accept socket is handed to the policy engine rather than being
    # polled here; receives on it time out after 500 milliseconds.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    requestor_cache = SocketCache(context, zmq.REQ)
    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket, pusher_cache,
                                              kvs, ip, local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored, but a REP socket has to receive
            # before it is allowed to reply.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            # Arrivals are recorded before the DAG is validated, so calls to
            # unknown DAGs are tracked as well.
            t = time.time()
            if name in last_arrivals:
                interarrivals.setdefault(name, []).append(
                    t - last_arrivals[name])
            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                # We deliberately fall through rather than restarting the
                # loop, so that the remaining sockets and the periodic
                # metadata work below are still serviced in this iteration.
            else:
                dag = dags[name]
                for fname in dag[0].functions:
                    call_frequency[fname.name] += 1

                response = call_dag(call, pusher_cache, dags, policy)
                dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and
                socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # NOTE(review): this retries for as long as None is "in"
                    # the payload, with no backoff; what exactly kvs.get
                    # returns on a miss should be confirmed against the
                    # client.
                    payload = kvs.get(dname)
                    while None in payload:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    # The counters are keyed by function *name*, which is
                    # what the DAG call handler above increments.
                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if mgmt_ip:
                schedulers = sched_utils.get_ip_set(
                    sched_utils.get_scheduler_list_address(mgmt_ip),
                    requestor_cache, False)

        if end - start > REPORT_THRESHOLD:
            # Tell every other scheduler which DAGs we know about and where
            # each function is pinned.
            status = SchedulerStatus()
            for name in dags:
                status.dags.append(name)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            # Gather the per-function and per-DAG statistics, resetting the
            # counters as we go.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.',
                             call_frequency[fname], fname)

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            # The timer shared by both thresholds is only reset after a
            # report has been made.
            start = time.time()
class TestDefaultSchedulerPolicy(unittest.TestCase):
    '''
    This test suite tests the parts of the default scheduler policy that
    aren't covered by the scheduler creation and call test cases. In
    particular, most of these test cases have to do with the metadata
    management that is invoked periodically from the server.
    '''

    def setUp(self):
        '''
        Builds a fresh policy engine on top of mock sockets and a mock KVS
        client before each test.
        '''
        self.pusher_cache = zmq_utils.MockPusherCache()
        self.socket = zmq_utils.MockZmqSocket()
        self.pin_socket = zmq_utils.MockZmqSocket()

        self.kvs_client = kvs_client.MockAnnaClient()
        self.ip = '127.0.0.1'

        self.policy = DefaultCloudburstSchedulerPolicy(self.pin_socket,
                                                       self.pusher_cache,
                                                       self.kvs_client,
                                                       self.ip,
                                                       random_threshold=0)

    def tearDown(self):
        # Clear all policy metadata.
        self.policy.running_counts.clear()
        self.policy.backoff.clear()
        self.policy.key_locations.clear()
        self.policy.unpinned_executors.clear()
        self.policy.function_locations.clear()
        self.policy.pending_dags.clear()
        self.policy.thread_statuses.clear()

    def test_policy_ignore_overloaded(self):
        '''
        This test ensures that the policy engine correctly ignores nodes that
        have either explicitly reported a high load recently or have received
        many calls in the recent past.
        '''
        # Create two executors, one of which has received too many requests,
        # and the other of which has reported high load.
        address_set = {(self.ip, 1), (self.ip, 2)}
        self.policy.unpinned_executors.update(address_set)

        self.policy.backoff[(self.ip, 1)] = time.time()
        self.policy.running_counts[(self.ip, 2)] = set()
        for _ in range(1100):
            # The timestamps are stored in a set; the short sleep is
            # presumably there to keep successive timestamps distinct.
            time.sleep(.0001)
            self.policy.running_counts[(self.ip, 2)].add(time.time())

        # Ensure that we have returned None because both our valid executors
        # were overloaded.
        result = self.policy.pick_executor([])
        self.assertIsNone(result)

    def test_pin_reject(self):
        '''
        This test explicitly rejects a pin request from the policy and ensures
        that it tries again to pin on another node.
        '''
        # Create two unpinned executors.
        address_set = {(self.ip, 1), (self.ip, 2)}
        self.policy.unpinned_executors.update(address_set)

        # Create one failing and one successful response.
        self.pin_socket.inbox.append(sutils.ok_resp)
        self.pin_socket.inbox.append(sutils.error.SerializeToString())

        success = self.policy.pin_function('dag', 'function')
        self.assertTrue(success)

        # Ensure that both executors have been removed from unpinned and that
        # the DAG commit is pending.
        self.assertEqual(len(self.policy.unpinned_executors), 0)
        self.assertEqual(len(self.policy.pending_dags), 1)

    def test_process_status(self):
        '''
        This test ensures that when a new status update is received from an
        executor, the local server metadata is correctly updated in the normal
        case.
        '''
        # Construct a new thread status to pass into the policy engine.
        function_name = 'square'
        status = ThreadStatus()
        status.running = True
        status.ip = self.ip
        status.tid = 1
        status.functions.append(function_name)
        status.utilization = 0.90

        key = (status.ip, status.tid)

        # Process the newly created status.
        self.policy.process_status(status)

        self.assertNotIn(key, self.policy.unpinned_executors)
        self.assertIn(key, self.policy.function_locations[function_name])
        self.assertIn(key, self.policy.backoff)

    def test_process_status_restart(self):
        '''
        This test checks that when we receive a status update from a restarted
        executor, we correctly update its metadata.
        '''
        # Construct a new thread status to pass into the policy engine.
        function_name = 'square'
        status = ThreadStatus()
        status.running = True
        status.ip = self.ip
        status.tid = 1
        status.functions.append(function_name)
        status.utilization = 0

        key = (status.ip, status.tid)

        # Add metadata to the policy engine to make it think that this node
        # used to have a pinned function.
        self.policy.function_locations[function_name] = {key}
        self.policy.thread_statuses[key] = status

        # Clear the status' pinned functions (i.e., restart).
        status = ThreadStatus()
        status.ip = self.ip
        status.tid = 1
        status.running = True
        status.utilization = 0

        # Process the status and check the metadata.
        self.policy.process_status(status)

        self.assertEqual(len(self.policy.function_locations[function_name]),
                         0)
        self.assertIn(key, self.policy.unpinned_executors)

    def test_process_status_not_running(self):
        '''
        This test passes in a status for a server that is leaving the system
        and ensures that all the metadata reflects this fact after the
        processing.
        '''
        # Construct a new thread status to prime the policy engine with.
        function_name = 'square'
        status = ThreadStatus()
        status.running = True
        status.ip = self.ip
        status.tid = 1
        status.functions.append(function_name)
        status.utilization = 0

        key = (status.ip, status.tid)

        # Add metadata to the policy engine to make it think that this node
        # used to have a pinned function.
        self.policy.function_locations[function_name] = {key}
        self.policy.thread_statuses[key] = status

        # Clear the status' fields to report that it is turning off.
        status = ThreadStatus()
        status.running = False
        status.ip = self.ip
        status.tid = 1

        # Process the status and check the metadata.
        self.policy.process_status(status)

        self.assertNotIn(key, self.policy.thread_statuses)
        self.assertNotIn(key, self.policy.unpinned_executors)
        self.assertEqual(len(self.policy.function_locations[function_name]),
                         0)

    def test_metadata_update(self):
        '''
        This test calls the periodic metadata update protocol and ensures that
        the correct metadata is removed from the system and that the correct
        metadata is retrieved/updated from the KVS.
        '''
        # Create two executor threads on separate machines.
        old_ip = '127.0.0.1'
        new_ip = '192.168.0.1'

        old_executor = (old_ip, 1)
        new_executor = (new_ip, 2)

        old_status = ThreadStatus()
        old_status.ip = old_ip
        old_status.tid = 1
        old_status.running = True

        new_status = ThreadStatus()
        new_status.ip = new_ip
        new_status.tid = 2
        new_status.running = True

        self.policy.thread_statuses[old_executor] = old_status
        self.policy.thread_statuses[new_executor] = new_status

        # Add two executors, one with an old backoff and one with a new time.
        self.policy.backoff[old_executor] = time.time() - 10
        self.policy.backoff[new_executor] = time.time()

        # For the new executor, add 10 old running times and 10 new ones.
        self.policy.running_counts[new_executor] = set()

        for _ in range(10):
            time.sleep(.0001)
            self.policy.running_counts[new_executor].add(time.time() - 10)

        for _ in range(10):
            time.sleep(.0001)
            self.policy.running_counts[new_executor].add(time.time())

        # Publish some caching metadata into the KVS for each executor.
        old_set = StringSet()
        old_set.keys.extend(['key1', 'key2', 'key3'])
        new_set = StringSet()
        new_set.keys.extend(['key3', 'key4', 'key5'])
        self.kvs_client.put(get_cache_ip_key(old_ip),
                            LWWPairLattice(0, old_set.SerializeToString()))
        self.kvs_client.put(get_cache_ip_key(new_ip),
                            LWWPairLattice(0, new_set.SerializeToString()))

        self.policy.update()

        # Check that the metadata has been correctly pruned.
        self.assertEqual(len(self.policy.backoff), 1)
        self.assertIn(new_executor, self.policy.backoff)
        self.assertEqual(len(self.policy.running_counts[new_executor]), 10)

        # Check that the caching information is correct. These have to be
        # assertEqual: assertTrue would treat the expected count as the
        # failure message and pass for any non-empty set.
        self.assertEqual(len(self.policy.key_locations['key1']), 1)
        self.assertEqual(len(self.policy.key_locations['key2']), 1)
        self.assertEqual(len(self.policy.key_locations['key3']), 2)
        self.assertEqual(len(self.policy.key_locations['key4']), 1)
        self.assertEqual(len(self.policy.key_locations['key5']), 1)

        self.assertIn(old_ip, self.policy.key_locations['key1'])
        self.assertIn(old_ip, self.policy.key_locations['key2'])
        self.assertIn(old_ip, self.policy.key_locations['key3'])
        self.assertIn(new_ip, self.policy.key_locations['key3'])
        self.assertIn(new_ip, self.policy.key_locations['key4'])
        self.assertIn(new_ip, self.policy.key_locations['key5'])

    def test_update_function_locations(self):
        '''
        This test ensures that the update_function_locations method correctly
        updates local metadata about which functions are pinned on which
        nodes.
        '''
        # Construct function location metadata received from another node.
        locations = {}
        function1 = 'square'
        function2 = 'incr'
        key1 = ('127.0.0.1', 0)
        key2 = ('127.0.0.2', 0)
        key3 = ('192.168.0.1', 0)
        locations[function1] = [key1, key3]
        locations[function2] = [key2, key3]

        # Serialize the location map in the expected protobuf.
        status = SchedulerStatus()
        for function_name in locations:
            for ip, tid in locations[function_name]:
                location = status.function_locations.add()
                location.name = function_name
                location.ip = ip
                location.tid = tid

        self.policy.update_function_locations(status.function_locations)

        self.assertEqual(len(self.policy.function_locations[function1]), 2)
        self.assertEqual(len(self.policy.function_locations[function2]), 2)

        self.assertIn(key1, self.policy.function_locations[function1])
        self.assertIn(key3, self.policy.function_locations[function1])
        self.assertIn(key2, self.policy.function_locations[function2])
        self.assertIn(key3, self.policy.function_locations[function2])
def scheduler(ip, mgmt_ip, route_addr, policy_type):
    '''
    Runs the scheduler's main event loop. This function does not return.

    One ZeroMQ socket is bound per request type, the policy engine is
    started, and then each loop iteration polls the sockets (with a one
    second timeout), dispatches whatever messages are pending, and -- once
    the corresponding thresholds have elapsed -- refreshes the policy
    metadata and reports statistics.

    Args:
        ip: This scheduler's own IP address. It is handed to the KVS client
            and the policy engine, and is used to skip this node when
            broadcasting status to the other schedulers.
        mgmt_ip: The IP address of the management server, or None when
            running in local mode.
        route_addr: The address handed to the KVS client; it is also sent
            back verbatim to any client that connects.
        policy_type: Passed through unchanged to the scheduler policy's
            constructor.
    '''
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    context = zmq.Context(1)
    context.set(zmq.MAX_SOCKETS, 10000)

    # A mapping from a DAG's name to a tuple of its protobuf representation
    # and the result of sched_utils.find_dag_source for that DAG.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = set()

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    # This socket handles invocations that arrive from a queue, mainly for
    # storage events.
    func_call_queue_socket = context.socket(zmq.PULL)
    func_call_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                                (FUNC_CALL_QUEUE_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    # The pin-accept socket is handed to the policy engine rather than being
    # polled here.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    continuation_socket = context.socket(zmq.PULL)
    continuation_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.CONTINUATION_PORT))

    if not local:
        management_request_socket = context.socket(zmq.REQ)
        management_request_socket.setsockopt(zmq.RCVTIMEO, 500)
        # By setting this flag, zmq matches replies with requests.
        management_request_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        # Relax strict alternation between request and reply.
        # For detailed explanation, see here:
        # http://api.zeromq.org/4-1:zmq-setsockopt
        management_request_socket.setsockopt(zmq.REQ_RELAXED, 1)
        management_request_socket.connect(
            sched_utils.get_scheduler_list_address(mgmt_ip))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(func_call_queue_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(continuation_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket, pusher_cache,
                                              kvs, ip, policy_type,
                                              local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request body is ignored, but a REP socket has to receive
            # before it is allowed to reply.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (func_call_queue_socket in socks and
                socks[func_call_queue_socket] == zmq.POLLIN):
            call_function_from_queue(func_call_queue_socket, pusher_cache,
                                     policy)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            # Receive timestamp in microseconds, used for the timing log
            # line below.
            start_t = int(time.time() * 1000000)
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            # Arrivals are recorded before the DAG is validated, so calls to
            # unknown DAGs are tracked as well.
            t = time.time()
            if name in last_arrivals:
                interarrivals.setdefault(name, []).append(
                    t - last_arrivals[name])
            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                # We deliberately fall through rather than restarting the
                # loop, so that the remaining sockets and the periodic
                # metadata work below are still serviced in this iteration.
            else:
                dag = dags[name]
                for fname in dag[0].functions:
                    call_frequency[fname.name] += 1

                response = call_dag(call, pusher_cache, dags, policy)
                sched_t = int(time.time() * 1000000)
                logging.info(
                    f'App function {name} recv: {start_t}, scheduled: {sched_t}')
                dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and
                socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # NOTE(review): this retries for as long as None is "in"
                    # the payload, with no backoff; what exactly kvs.get
                    # returns on a miss should be confirmed against the
                    # client.
                    payload = kvs.get(dname)
                    while None in payload:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        if continuation_socket in socks and socks[continuation_socket] == \
                zmq.POLLIN:
            start_t = int(time.time() * 1000000)

            continuation = Continuation()
            continuation.ParseFromString(continuation_socket.recv())

            call = continuation.call
            call.name = continuation.name

            result = Value()
            result.ParseFromString(continuation.result)

            if call.name not in dags:
                # The continuation arrives on a PULL socket, so there is
                # nobody to reply to; log and drop it instead of letting a
                # KeyError take down the whole scheduler loop.
                logging.error('Dropping continuation for unknown DAG %s.',
                              call.name)
            else:
                dag, sources = dags[call.name]
                # Feed the previous result to every source function of the
                # DAG that is being continued into.
                for source in sources:
                    call.function_args[source].values.extend([result])

                call_dag(call, pusher_cache, dags, policy, continuation.id)
                sched_t = int(time.time() * 1000000)
                # Logged (rather than printed) so that it ends up in the
                # same place as the equivalent line for regular DAG calls.
                logging.info(
                    f'App function {call.name} recv: {start_t}, scheduled: {sched_t}'
                )

                for fname in dag.functions:
                    call_frequency[fname.name] += 1

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if not local:
                latest_schedulers = sched_utils.get_ip_set(
                    management_request_socket, False)
                # Keep the previous view if the lookup came back empty.
                if latest_schedulers:
                    schedulers = latest_schedulers

        if end - start > REPORT_THRESHOLD:
            # Tell every other scheduler which DAGs we know about and where
            # each function is pinned.
            status = SchedulerStatus()
            for name in dags:
                status.dags.append(name)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            # Gather the per-function and per-DAG statistics, resetting the
            # counters as we go.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.debug('Reporting %d calls for function %s.',
                              call_frequency[fname], fname)

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            # The timer shared by both thresholds is only reset after a
            # report has been made.
            start = time.time()