def test_exec_function_with_error(self):
    '''
    Attempts to execute a function that raises an error during its
    execution. Ensures that an error is returned to the user.
    '''
    e_msg = 'This is a broken_function!'

    # A function that always fails when invoked.
    def func(_, x):
        raise ValueError(e_msg)

    fname = 'func'
    create_function(func, self.kvs_client, fname)

    # Put the function into the KVS and create a function call.
    call = self._create_function_call(fname, [''], NORMAL)
    self.socket.inbox.append(call.SerializeToString())

    # Execute the function call.
    exec_function(self.socket, self.kvs_client, self.user_library, {}, {})

    # Retrieve the result from the KVS and ensure that it is the correct
    # lattice type.
    result = self.kvs_client.get(self.response_key)[self.response_key]
    self.assertEqual(type(result), LWWPairLattice)
    result = serializer.load_lattice(result)

    # Check the type and values of the error: the error tuple's first
    # element should carry the raised exception's message.
    self.assertEqual(type(result), tuple)
    self.assertTrue(e_msg in result[0])

    # Unpack the GenericResponse and check its values.
    response = GenericResponse()
    response.ParseFromString(result[1])
    self.assertEqual(response.success, False)
    self.assertEqual(response.error, EXECUTION_ERROR)
def test_exec_function_nonexistent(self): ''' Attempts to execute a non-existent function and ensures that an error is thrown and returned to the user. ''' # Create a call to a function that does not exist. call = self._create_function_call('bad_func', [1], NORMAL) self.socket.inbox.append(call.SerializeToString()) # Attempt to execute the nonexistent function. exec_function(self.socket, self.kvs_client, self.user_library, {}, {}) # Assert that there have been 0 messages sent. self.assertEqual(len(self.socket.outbox), 0) # Retrieve the result from the KVS and ensure that it is a lattice as # we expect. Deserialize it and check for an error. result = self.kvs_client.get(self.response_key)[self.response_key] self.assertEqual(type(result), LWWPairLattice) result = serializer.load_lattice(result) # Check the type and values of the error. self.assertEqual(type(result), tuple) self.assertEqual(result[0], 'ERROR') # Unpack the GenericResponse and check its values. response = GenericResponse() response.ParseFromString(result[1]) self.assertEqual(response.success, False) self.assertEqual(response.error, FUNC_NOT_FOUND)
def register_dag(self, name, functions, connections):
    '''
    Registers a new DAG with the system. This operation will fail if any
    of the functions provided cannot be identified in the system.

    name: A unique name for this DAG.
    functions: A list of names of functions to be included in this DAG.
    connections: A list of ordered pairs of function names that represent
    the edges in this DAG.

    Returns a (success, error) pair; on an unknown function, returns
    (False, None) without contacting the server.
    '''
    flist = self._get_func_list()
    for fname in functions:
        if fname not in flist:
            # Bug fix: the original message used a backslash line
            # continuation *inside* the string literal, which embedded a
            # run of indentation spaces into the logged text. Also use
            # lazy %-style logging args instead of eager interpolation.
            logging.info('Function %s not registered. Please register '
                         'before including it in a DAG.', fname)
            return False, None

    # Build the protobuf describing the DAG's nodes and edges.
    dag = Dag()
    dag.name = name
    dag.functions.extend(functions)

    for pair in connections:
        conn = dag.connections.add()
        conn.source = pair[0]
        conn.sink = pair[1]

    # Send the registration request to the scheduler and relay its answer.
    self.dag_create_sock.send(dag.SerializeToString())

    r = GenericResponse()
    r.ParseFromString(self.dag_create_sock.recv())

    return r.success, r.error
def test_function_call_no_resources(self):
    '''
    Constructs a scenario where there are no available resources in the
    system, and ensures that the scheduler correctly returns an error to
    the user.
    '''
    # Empty out all executor metadata so the policy has nothing to pick.
    self.policy.thread_statuses.clear()
    self.policy.unpinned_executors.clear()

    # Build a function call with a single integer argument and queue it
    # on the inbound socket.
    request = FunctionCall()
    request.name = 'function'
    request.request_id = 12
    serializer.dump(2, request.arguments.values.add())
    self.socket.inbox.append(request.SerializeToString())

    # Run the scheduling policy over the queued request.
    call_function(self.socket, self.pusher_cache, self.policy)

    # Exactly one reply to the user, and nothing forwarded to executors.
    self.assertEqual(len(self.socket.outbox), 1)
    self.assertEqual(len(self.pusher_cache.socket.outbox), 0)

    # The reply must be a failure carrying the NO_RESOURCES error code.
    reply = GenericResponse()
    reply.ParseFromString(self.socket.outbox[0])
    self.assertFalse(reply.success)
    self.assertEqual(reply.error, NO_RESOURCES)
def test_succesful_pin(self): ''' This test executes a pin operation that is supposed to be successful, and it checks to make sure that that the correct metadata for execution and reporting is generated. ''' # Create a new function in the KVS. fname = 'incr' def func(_, x): return x + 1 create_function(func, self.kvs_client, fname) # Create a pin message and put it into the socket. msg = PinFunction(name=fname, response_address=self.ip) self.socket.inbox.append(msg.SerializeToString()) # Execute the pin operation. pin(self.socket, self.pusher_cache, self.kvs_client, self.status, self.pinned_functions, self.runtimes, self.exec_counts, self.user_library, False, False) # Check that the correct messages were sent and the correct metadata # created. self.assertEqual(len(self.pusher_cache.socket.outbox), 1) response = GenericResponse() response.ParseFromString(self.pusher_cache.socket.outbox[0]) self.assertTrue(response.success) self.assertEqual(func('', 1), self.pinned_functions[fname]('', 1)) self.assertTrue(fname in self.pinned_functions) self.assertTrue(fname in self.runtimes) self.assertTrue(fname in self.exec_counts) self.assertTrue(fname in self.status.functions)
def register(self, function, name):
    '''
    Registers a new function or class with the system. The returned object
    can be called like a regular Python function, which returns a
    Cloudburst Future. If the input is a class, the class is expected to
    have a run method, which is what is invoked at runtime.

    function: The function object that we are registering.
    name: A unique name for the function to be stored with in the system.
    '''
    # Serialize the function body into the registration protobuf.
    request = Function()
    request.name = name
    request.body = serializer.dump(function)

    # Ship the request to the scheduler and wait for its verdict.
    self.func_create_sock.send(request.SerializeToString())

    resp = GenericResponse()
    resp.ParseFromString(self.func_create_sock.recv())

    if not resp.success:
        raise RuntimeError(
            f'Unexpected error while registering function: {resp}.')

    # Hand back a callable proxy bound to this client's KVS connection.
    return CloudburstFunction(name, self, self.kvs_client)
def test_create_gpu_dag_no_resources(self):
    '''
    Attempts to create a single-function GPU DAG when no GPU executors
    are available, and verifies that the request is rejected cleanly
    with no metadata side effects.
    '''
    # Create a simple two-function DAG and add it to the inbound socket.
    dag_name = 'dag'

    dag = create_linear_dag([None], ['fn'], self.kvs_client, dag_name)
    dag.functions[0].gpu = True
    self.socket.inbox.append(dag.SerializeToString())

    dags = {}
    call_frequency = {}

    create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
               self.policy, call_frequency)

    # Check that an error was returned to the user.
    self.assertEqual(len(self.socket.outbox), 1)
    response = GenericResponse()
    response.ParseFromString(self.socket.outbox[0])
    self.assertFalse(response.success)
    self.assertEqual(response.error, NO_RESOURCES)

    # Test that the correct pin messages were sent.
    self.assertEqual(len(self.pusher_cache.socket.outbox), 0)

    # Check that no additional messages were sent.
    self.assertEqual(len(self.policy.unpinned_cpu_executors), 0)
    self.assertEqual(len(self.policy.function_locations), 0)
    self.assertEqual(len(self.policy.pending_dags), 0)

    # Check that no additional metadata was created or sent.
    self.assertEqual(len(call_frequency), 0)
    self.assertEqual(len(dags), 0)
def call_dag(self, dname, arg_map, direct_response=False,
             consistency=NORMAL, output_key=None, client_id=None):
    '''
    Issues a new request to execute the DAG. Returns a CloudburstFuture
    that can be used to retrieve the result, or -- when direct_response
    is True -- the deserialized result itself (None on a receive
    timeout).

    dname: The name of the DAG to execute.
    arg_map: A map from function names to lists of arguments for each of
    the functions in the DAG.
    direct_response: If True, the response will be synchronously received
    by the client; otherwise, the result will be stored in the KVS.
    consistency: The consistency mode to use with this function: either
    NORMAL or MULTI.
    output_key: The KVS key in which to store the result of this DAG.
    client_id: An optional ID associated with an individual client across
    requests; this is used for causal metadata.
    '''
    dc = DagCall()
    dc.name = dname
    dc.consistency = consistency

    if output_key:
        dc.output_key = output_key

    if client_id:
        dc.client_id = client_id

    # Serialize each function's arguments; a bare (non-list) value is
    # treated as a single-argument list.
    for fname in arg_map:
        fname_args = arg_map[fname]
        if type(fname_args) != list:
            fname_args = [fname_args]
        args = [serializer.dump(arg, serialize=False)
                for arg in fname_args]
        al = dc.function_args[fname]
        al.values.extend(args)

    if direct_response:
        dc.response_address = self.response_address

    self.dag_call_sock.send(dc.SerializeToString())

    r = GenericResponse()
    r.ParseFromString(self.dag_call_sock.recv())

    if direct_response:
        # NOTE(review): the scheduler's ack (r.success) is not checked on
        # this path -- presumably a failed call never delivers a direct
        # response and we rely on the recv timeout; confirm.
        try:
            result = self.response_sock.recv()
            return serializer.load(result)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # Receive timed out; treat as "no result".
                return None
            else:
                raise e
    else:
        if r.success:
            return CloudburstFuture(r.response_id, self.kvs_client,
                                    serializer)
        else:
            return None
def pin_function(self, dag_name, function_ref):
    '''
    Attempts to pin function_ref onto some unpinned executor for the DAG
    named dag_name. Returns True on success and False when no executor
    could be found (note: False, not None -- the original comment was
    wrong about the sentinel).
    '''
    # If there are no executors left to choose from, we ran out of
    # resources to use.
    if len(self.unpinned_executors) == 0:
        return False

    if dag_name not in self.pending_dags:
        self.pending_dags[dag_name] = []

    # Make a copy of the set of executors, so that we don't modify the
    # system's metadata.
    candidates = set(self.unpinned_executors)

    # Construct a PinFunction message to be sent to executors.
    pin_msg = PinFunction()
    pin_msg.name = function_ref.name
    pin_msg.response_address = self.ip

    serialized = pin_msg.SerializeToString()

    # Bug fix: loop only while candidates remain. The original
    # `while True` raised ValueError from random.sample once every
    # candidate had been discarded.
    while candidates:
        # Pick a random executor from the set of candidates and attempt
        # to pin this function there. Bug fix: random.sample on a set is
        # deprecated since Python 3.9 and a TypeError on 3.11+, so sample
        # from a concrete sequence instead.
        node, tid = sys_random.choice(tuple(candidates))
        sckt = self.pusher_cache.get(get_pin_address(node, tid))
        sckt.send(serialized)

        response = GenericResponse()
        try:
            response.ParseFromString(self.pin_accept_socket.recv())
        except zmq.ZMQError:
            logging.error('Pin operation to %s:%d timed out. Retrying.'
                          % (node, tid))
            continue

        # Do not use this executor for this pin attempt either way: if it
        # rejected, it has something else pinned, and if it accepted, it
        # has pinned what we just asked it to pin. In local mode we allow
        # executors to keep multiple functions pinned, so the global
        # unpinned set is only updated in cluster mode.
        if not self.local:
            self.unpinned_executors.discard((node, tid))
        candidates.discard((node, tid))

        if response.success:
            # The pin operation succeeded, so we record the placement and
            # report success to the caller.
            self.pending_dags[dag_name].append((function_ref.name,
                                                (node, tid)))
            return True
        else:
            # The pin operation was rejected; the node was already
            # removed from the candidate set, so just try again.
            logging.error('Node %s:%d rejected pin for %s. Retrying.'
                          % (node, tid, function_ref.name))
            continue

    # Every candidate rejected (or timed out on) the pin request.
    return False
def delete_dag(self, dname):
    '''
    Removes the specified DAG from the system.

    dname: The name of the DAG to delete.
    '''
    # The delete protocol takes the raw DAG name rather than a protobuf.
    self.dag_delete_sock.send_string(dname)

    # Wait for the scheduler's acknowledgement and unpack it.
    reply = GenericResponse()
    reply.ParseFromString(self.dag_delete_sock.recv())
    return reply.success, reply.error
def test_delete_dag(self): ''' We attempt to delete a DAG that has already been created and check to ensure that the correct unpin messages are sent to executors and that the metadata is updated appropriately. ''' # Create a simple two fucntion DAG and add it to the system metadata. source = 'source' sink = 'sink' dag_name = 'dag' dag = create_linear_dag([None, None], [source, sink], self.kvs_client, dag_name) dags = {} call_frequency = {} dags[dag.name] = (dag, {source}) call_frequency[source] = 100 call_frequency[sink] = 100 # Add the correct metadata to the policy engine. source_location = (self.ip, 1) sink_location = (self.ip, 2) self.policy.function_locations[source] = {source_location} self.policy.function_locations[sink] = {sink_location} self.socket.inbox.append(dag.name) # Attempt to delete the DAG. delete_dag(self.socket, dags, self.policy, call_frequency) # Check that the correct unpin messages were sent. messages = self.pusher_cache.socket.outbox self.assertEqual(len(messages), 2) self.assertEqual(messages[0], source) self.assertEqual(messages[1], sink) addresses = self.pusher_cache.addresses self.assertEqual(len(addresses), 2) self.assertEqual(addresses[0], get_unpin_address(*source_location)) self.assertEqual(addresses[1], get_unpin_address(*sink_location)) # Check that the server metadata was updated correctly. self.assertEqual(len(dags), 0) self.assertEqual(len(call_frequency), 0) # Check that the correct message was sent to the user. self.assertEqual(len(self.socket.outbox), 1) response = GenericResponse() response.ParseFromString(self.socket.outbox.pop()) self.assertTrue(response.success)
def call_function(func_call_socket, pusher_cache, policy):
    '''
    Handles one inbound FunctionCall request: picks an executor via the
    policy engine, forwards the call there, and replies to the user with
    either the response key or a NO_RESOURCES error.
    '''
    # Parse the received protobuf for this function call.
    call = FunctionCall()
    call.ParseFromString(func_call_socket.recv())

    # If there is no response key set for this request, we generate a random
    # UUID.
    if not call.response_key:
        call.response_key = str(uuid.uuid4())

    # Filter the arguments for CloudburstReferences, and use the policy engine to
    # pick a node for this request.
    refs = list(
        filter(lambda arg: type(arg) == CloudburstReference,
               map(lambda arg: serializer.load(arg),
                   call.arguments.values)))
    result = policy.pick_executor(refs)

    response = GenericResponse()
    if result is None:
        # No executor could be found; report the failure to the user.
        response.success = False
        response.error = NO_RESOURCES
        func_call_socket.send(response.SerializeToString())
        return

    # Forward the request on to the chosen executor node.
    ip, tid = result
    sckt = pusher_cache.get(utils.get_exec_address(ip, tid))
    sckt.send(call.SerializeToString())

    # Send a success response to the user with the response key.
    response.success = True
    response.response_id = call.response_key
    func_call_socket.send(response.SerializeToString())
def register_dag(self, name, functions, connections):
    '''
    Registers a new DAG with the system. This operation will fail if any
    of the functions provided cannot be identified in the system.

    name: A unique name for this DAG.
    functions: A list of names of functions to be included in this DAG.
    An entry may also be a (name, invalid_results) tuple, which marks the
    function as MULTIEXEC with the given invalid results.
    connections: A list of ordered pairs of function names that represent
    the edges in this DAG.

    Raises RuntimeError if any referenced function is not registered;
    otherwise returns the scheduler's (success, error) pair.
    '''
    flist = self._get_func_list()
    for fname in functions:
        if isinstance(fname, tuple):
            # Tuple entries are (function_name, invalid_results).
            fname = fname[0]

        if fname not in flist:
            raise RuntimeError(
                f'Function {fname} not registered. Please register before '
                + 'including it in a DAG.')

    dag = Dag()
    dag.name = name
    for function in functions:
        ref = dag.functions.add()

        if type(function) == tuple:
            fname = function[0]
            invalids = function[1]
            ref.type = MULTIEXEC
        else:
            fname = function
            invalids = []

        ref.name = fname
        for invalid in invalids:
            ref.invalid_results.append(serializer.dump(invalid))

    for pair in connections:
        conn = dag.connections.add()
        conn.source = pair[0]
        conn.sink = pair[1]

    self.dag_create_sock.send(dag.SerializeToString())

    r = GenericResponse()
    r.ParseFromString(self.dag_create_sock.recv())

    return r.success, r.error
def exec_func(self, name, args):
    '''
    Sends a single function invocation to the scheduler and returns the
    response ID under which the result will be made available.
    '''
    # Build the protobuf describing this invocation.
    request = FunctionCall()
    request.name = name
    request.request_id = self.rid

    # Serialize each argument into the call's argument list.
    for argument in args:
        slot = request.arguments.values.add()
        serializer.dump(argument, slot)

    # Send the call and wait for the scheduler's acknowledgement.
    self.func_call_sock.send(request.SerializeToString())

    ack = GenericResponse()
    ack.ParseFromString(self.func_call_sock.recv())

    # Bump the per-client request counter for the next call.
    self.rid += 1
    return ack.response_id
def test_call_function_with_refs(self):
    '''
    Creates a scenario where the policy should deterministically pick the
    same executor to run a request on: There is one reference, and it's
    cached only on the node we create in this test.
    '''
    # Add a new executor for which we will construct cached references.
    ip_address = '192.168.0.1'
    new_key = (ip_address, 2)
    self.policy.unpinned_executors.add(new_key)

    # Create a new reference and add its metadata.
    ref_name = 'reference'
    self.policy.key_locations[ref_name] = [ip_address]

    # Create a function call that asks for this reference.
    call = FunctionCall()
    call.name = 'function'
    call.request_id = 12
    val = call.arguments.values.add()
    serializer.dump(CloudburstReference(ref_name, True), val)
    # Bug fix: SerializeToString takes no positional arguments (its only
    # parameter, `deterministic`, is keyword-only in protobuf), so the
    # original SerializeToString(0) raised a TypeError.
    self.socket.inbox.append(call.SerializeToString())

    # Execute the scheduling policy.
    call_function(self.socket, self.pusher_cache, self.policy)

    # Check that the correct number of messages were sent.
    self.assertEqual(len(self.socket.outbox), 1)
    self.assertEqual(len(self.pusher_cache.socket.outbox), 1)

    # Extract and deserialize the messages.
    response = GenericResponse()
    forwarded = FunctionCall()
    response.ParseFromString(self.socket.outbox[0])
    forwarded.ParseFromString(self.pusher_cache.socket.outbox[0])

    self.assertTrue(response.success)
    self.assertEqual(response.response_id, forwarded.response_key)
    self.assertEqual(forwarded.name, call.name)
    self.assertEqual(forwarded.request_id, call.request_id)

    # Makes sure that the correct executor was chosen.
    self.assertEqual(len(self.pusher_cache.addresses), 1)
    self.assertEqual(self.pusher_cache.addresses[0],
                     utils.get_exec_address(*new_key))
def test_occupied_pin(self): ''' This test attempts to pin a function onto a node where another function is already pinned. We currently only allow one pinned node per machine, so this operation should fail. ''' # Create a new function in the KVS. fname = 'incr' def func(_, x): return x + 1 create_function(func, self.kvs_client, fname) # Create a pin message and put it into the socket. msg = PinFunction(name=fname, response_address=self.ip) self.socket.inbox.append(msg.SerializeToString()) # Add an already pinned_function, so that we reject the request. self.pinned_functions['square'] = lambda _, x: x * x self.runtimes['square'] = [] self.exec_counts['square'] = [] self.status.functions.append('square') # Execute the pin operation. pin(self.socket, self.pusher_cache, self.kvs_client, self.status, self.pinned_functions, self.runtimes, self.exec_counts, self.user_library, False, False) # Check that the correct messages were sent and the correct metadata # created. self.assertEqual(len(self.pusher_cache.socket.outbox), 1) response = GenericResponse() response.ParseFromString(self.pusher_cache.socket.outbox[0]) self.assertFalse(response.success) # Make sure that none of the metadata was corrupted with this failed # pin attempt self.assertTrue(fname not in self.pinned_functions) self.assertTrue(fname not in self.runtimes) self.assertTrue(fname not in self.exec_counts) self.assertTrue(fname not in self.status.functions)
def test_delete_nonexistent_dag(self):
    '''
    This test attempts to delete a nonexistent DAG and ensures that no
    metadata is affected by the failed operation.
    '''
    # Ask the server to delete a DAG it has never heard of.
    self.socket.inbox.append('dag')
    delete_dag(self.socket, {}, self.policy, {})

    # Exactly one reply should have gone back to the user, and it must be
    # a NO_SUCH_DAG error.
    self.assertEqual(len(self.socket.outbox), 1)

    reply = GenericResponse()
    reply.ParseFromString(self.socket.outbox[0])
    self.assertFalse(reply.success)
    self.assertEqual(reply.error, NO_SUCH_DAG)

    # No executor traffic and no policy metadata changes.
    self.assertEqual(len(self.pusher_cache.socket.outbox), 0)
    self.assertEqual(len(self.policy.function_locations), 0)
    self.assertEqual(len(self.policy.unpinned_executors), 0)
def test_create_dag_already_exists(self): ''' This test attempts to create a DAG that already exists and makes sure that the server correctly rejects the request. ''' # Create a simple two-function DAG and add it to the inbound socket. source = 'source' sink = 'sink' dag_name = 'dag' dag = create_linear_dag([None, None], [source, sink], self.kvs_client, dag_name) self.socket.inbox.append(dag.SerializeToString()) # Add this to the existing server metadata. dags = {dag.name: (dag, {source})} # Add relevant metadata to the policy engine. address_set = {(self.ip, 1), (self.ip, 2)} self.policy.unpinned_executors.update(address_set) # Attempt to create the DAG. call_frequency = {} create_dag(self.socket, self.pusher_cache, self.kvs_client, dags, self.policy, call_frequency) # Check that an error was returned to the user. self.assertEqual(len(self.socket.outbox), 1) response = GenericResponse() response.ParseFromString(self.socket.outbox[0]) self.assertFalse(response.success) self.assertEqual(response.error, DAG_ALREADY_EXISTS) # Check that no additional metadata was created or sent. self.assertEqual(len(self.pusher_cache.socket.outbox), 0) self.assertEqual(len(self.policy.unpinned_executors), 2) self.assertEqual(len(self.policy.function_locations), 0) self.assertEqual(len(self.policy.pending_dags), 0)
def test_call_function_no_refs(self): ''' A basic test that makes sure that an executor is successfully chosen when there is only one possible executor to choose from. ''' # Create a new function call for a function that doesn't exist. call = FunctionCall() call.name = 'function' call.request_id = 12 # Add an argument to thhe function. val = call.arguments.values.add() serializer.dump(2, val) self.socket.inbox.append(call.SerializeToString()) # Execute the scheduling policy. call_function(self.socket, self.pusher_cache, self.policy) # Check that the correct number of messages were sent. self.assertEqual(len(self.socket.outbox), 1) self.assertEqual(len(self.pusher_cache.socket.outbox), 1) # Extract and deserialize the messages. response = GenericResponse() forwarded = FunctionCall() response.ParseFromString(self.socket.outbox[0]) forwarded.ParseFromString(self.pusher_cache.socket.outbox[0]) self.assertTrue(response.success) self.assertEqual(response.response_id, forwarded.response_key) self.assertEqual(forwarded.name, call.name) self.assertEqual(forwarded.request_id, call.request_id) # Makes sure that the correct executor was chosen. self.assertEqual(len(self.pusher_cache.addresses), 1) self.assertEqual(self.pusher_cache.addresses[0], utils.get_exec_address(*self.executor_key))
def scheduler(ip, mgmt_ip, route_addr, policy_type):
    '''
    Main event loop for a Cloudburst scheduler node. Binds the request
    sockets, instantiates the scheduling policy, and then services
    function/DAG creation, invocation, deletion, status, and continuation
    messages until the process is killed.

    ip: This scheduler's own IP address.
    mgmt_ip: The management server's IP, or None for local mode.
    route_addr: The Anna routing address used to construct the KVS client.
    policy_type: Forwarded to DefaultCloudburstSchedulerPolicy.
    '''
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)
    context.set(zmq.MAX_SOCKETS, 10000)

    # A mapping from a DAG's name to its protobuf representation.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = set()

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    # Handles invocations arriving from a queue -- mainly for storage
    # events.
    func_call_queue_socket = context.socket(zmq.PULL)
    func_call_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                                (FUNC_CALL_QUEUE_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    continuation_socket = context.socket(zmq.PULL)
    continuation_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.CONTINUATION_PORT))

    if not local:
        management_request_socket = context.socket(zmq.REQ)
        management_request_socket.setsockopt(zmq.RCVTIMEO, 500)
        # By setting this flag, zmq matches replies with requests.
        management_request_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        # Relax strict alternation between request and reply.
        # For detailed explanation, see here:
        # http://api.zeromq.org/4-1:zmq-setsockopt
        management_request_socket.setsockopt(zmq.REQ_RELAXED, 1)
        management_request_socket.connect(
            sched_utils.get_scheduler_list_address(mgmt_ip))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(func_call_queue_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(continuation_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket,
                                              pusher_cache, kvs, ip,
                                              policy_type, local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        # A client is connecting; reply with the KVS routing address.
        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            msg = connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == \
                zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if func_call_queue_socket in socks and socks[
                func_call_queue_socket] == zmq.POLLIN:
            call_function_from_queue(func_call_queue_socket, pusher_cache,
                                     policy)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            start_t = int(time.time() * 1000000)
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            # Record the interarrival time for this DAG, used for the
            # periodic statistics report.
            t = time.time()
            if name in last_arrivals:
                if name not in interarrivals:
                    interarrivals[name] = []
                interarrivals[name].append(t - last_arrivals[name])
            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                continue

            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            sched_t = int(time.time() * 1000000)
            logging.info(
                f'App function {name} recv: {start_t}, '
                f'scheduled: {sched_t}')
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and
                socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        # List all registered functions whose names match the given
        # prefix (empty prefix lists everything).
        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())
            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that
            # we do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    payload = kvs.get(dname)
                    # Busy-wait until the DAG's payload is visible in the
                    # KVS.
                    while None in payload:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        if continuation_socket in socks and socks[continuation_socket] == \
                zmq.POLLIN:
            start_t = int(time.time() * 1000000)
            continuation = Continuation()
            continuation.ParseFromString(continuation_socket.recv())

            call = continuation.call
            call.name = continuation.name

            result = Value()
            result.ParseFromString(continuation.result)

            # Feed the upstream result into every source function of the
            # continued DAG.
            dag, sources = dags[call.name]
            for source in sources:
                call.function_args[source].values.extend([result])

            call_dag(call, pusher_cache, dags, policy, continuation.id)
            sched_t = int(time.time() * 1000000)
            print(
                f'App function {call.name} recv: {start_t}, '
                f'scheduled: {sched_t}'
            )
            for fname in dag.functions:
                call_frequency[fname.name] += 1

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and
            # other schedulers.
            if not local:
                latest_schedulers = sched_utils.get_ip_set(
                    management_request_socket, False)
                if latest_schedulers:
                    schedulers = latest_schedulers

        if end - start > REPORT_THRESHOLD:
            # Broadcast our DAG and function-location knowledge to every
            # other scheduler.
            status = SchedulerStatus()
            for name in dags.keys():
                status.dags.append(name)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            # Assemble per-function call counts and per-DAG interarrival
            # statistics, resetting the counters afterwards.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.debug('Reporting %d calls for function %s.' %
                              (call_frequency[fname], fname))

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them
            # to the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()
def pin_function(self, dag_name, function_ref, colocated):
    '''
    Attempts to pin function_ref for the DAG named dag_name, respecting
    GPU placement and (optionally) colocation with the functions named in
    `colocated`. Returns True on success, False when resources are
    exhausted.
    '''
    # If there are no executors of the required kind left to choose from,
    # we ran out of resources to use.
    if function_ref.gpu and len(self.unpinned_gpu_executors) == 0:
        return False
    elif not function_ref.gpu and len(self.unpinned_cpu_executors) == 0:
        return False

    if dag_name not in self.pending_dags:
        self.pending_dags[dag_name] = []

    # Make a copy of the set of executors, so that we don't modify the
    # system's metadata.
    if function_ref.gpu:
        candidates = set(self.unpinned_gpu_executors)
    elif len(colocated) == 0:
        # If this is not a GPU function, just look at all of the unpinned
        # executors.
        candidates = set(self.unpinned_cpu_executors)
    else:
        candidates = set()

        # Find any colocation partners that have already been pinned for
        # this DAG.
        already_pinned = set()
        for fn, thread in self.pending_dags[dag_name]:
            if fn in colocated:
                already_pinned.add((fn, thread))

        if len(already_pinned) > 0:
            # Restrict candidates to threads on the same nodes as the
            # already-pinned partners.
            candidate_nodes = set()
            for fn, thread in already_pinned:
                candidate_nodes.add(thread[0])  # The node's IP.

            for node, tid in self.unpinned_cpu_executors:
                if node in candidate_nodes:
                    candidates.add((node, tid))
        else:
            # If this is the first colocate to be pinned, try to assign
            # it to a fully-empty node (all executor threads unpinned).
            nodes = {}
            for node, tid in self.unpinned_cpu_executors:
                if node not in nodes:
                    nodes[node] = 0
                nodes[node] += 1

            for node in nodes:
                if nodes[node] == NUM_EXECUTOR_THREADS:
                    for i in range(NUM_EXECUTOR_THREADS):
                        candidates.add((node, i))

        if len(candidates) == 0:
            # There are no valid executors to colocate on; retry without
            # the colocation constraint.
            return self.pin_function(dag_name, function_ref, [])

    # Construct a PinFunction message to be sent to executors.
    pin_msg = PinFunction()
    pin_msg.name = function_ref.name
    pin_msg.batching = function_ref.batching
    pin_msg.response_address = self.ip

    serialized = pin_msg.SerializeToString()

    # Bug fix: loop only while candidates remain. The original
    # `while True` raised ValueError from random.sample once every
    # candidate had been discarded, and the fallback code after the loop
    # was unreachable (and passed a spurious extra `self` argument).
    while candidates:
        # Pick a random executor from the set of candidates and attempt
        # to pin this function there. Bug fix: random.sample on a set is
        # deprecated since Python 3.9 and a TypeError on 3.11+, so sample
        # from a concrete sequence instead.
        node, tid = sys_random.choice(tuple(candidates))
        sckt = self.pusher_cache.get(get_pin_address(node, tid))
        sckt.send(serialized)

        response = GenericResponse()
        try:
            response.ParseFromString(self.pin_accept_socket.recv())
        except zmq.ZMQError:
            logging.error('Pin operation to %s:%d timed out. Retrying.'
                          % (node, tid))
            continue

        # Do not use this executor for this pin attempt either way: if it
        # rejected, it has something else pinned, and if it accepted, it
        # has pinned what we just asked it to pin. In local mode we allow
        # executors to keep multiple functions pinned, so the global
        # unpinned sets are only updated in cluster mode.
        if not self.local:
            if function_ref.gpu:
                self.unpinned_gpu_executors.discard((node, tid))
            else:
                self.unpinned_cpu_executors.discard((node, tid))
        candidates.discard((node, tid))

        if response.success:
            # The pin operation succeeded, so we record the placement and
            # report success to the caller.
            self.pending_dags[dag_name].append((function_ref.name,
                                                (node, tid)))
            return True
        else:
            # The pin operation was rejected; the node was already
            # removed from the candidate set, so just try again.
            logging.error('Node %s:%d rejected pin for %s. Retrying.'
                          % (node, tid, function_ref.name))
            continue

    # All candidates were exhausted. If we were constrained by
    # colocation, retry without that constraint; otherwise give up.
    if len(colocated) > 0:
        return self.pin_function(dag_name, function_ref, [])
    return False
def test_create_gpu_dag(self):
    '''
    Creates a single-function GPU DAG and checks that it is pinned on a
    GPU executor, persisted in the KVS, and that the scheduler and policy
    metadata are updated correctly.
    '''
    # Create a simple one-function DAG and add it to the inbound socket.
    dag_name = 'dag'
    fn = 'fn'
    dag = create_linear_dag([None], [fn], self.kvs_client, dag_name)
    dag.functions[0].gpu = True

    self.socket.inbox.append(dag.SerializeToString())

    dags = {}
    call_frequency = {}

    # Register a single GPU executor with the policy engine, and prime the
    # pin socket with a success response.
    address_set = {(self.ip, 1)}
    self.policy.unpinned_gpu_executors.update(address_set)

    self.pin_socket.inbox.append(sutils.ok_resp)

    create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
               self.policy, call_frequency)

    # Test that the correct metadata was created.
    self.assertTrue(dag_name in dags)
    created, dag_source = dags[dag_name]
    self.assertEqual(created, dag)
    self.assertEqual(len(dag_source), 1)
    self.assertEqual(list(dag_source)[0], fn)
    self.assertTrue(fn in call_frequency)
    self.assertEqual(call_frequency[fn], 0)

    # Test that the DAG is stored in the KVS correctly.
    result = self.kvs_client.get(dag_name)[dag_name]
    created = Dag()
    created.ParseFromString(result.reveal())
    self.assertEqual(created, dag)

    # Test that the correct response was returned to the user.
    # BUGFIX: this was assertTrue(len(...), 1), which always passes
    # because the second argument is treated as the failure message.
    self.assertEqual(len(self.socket.outbox), 1)
    response = GenericResponse()
    response.ParseFromString(self.socket.outbox.pop())
    self.assertTrue(response.success)

    # Test that the correct pin messages were sent.
    self.assertEqual(len(self.pusher_cache.socket.outbox), 1)
    messages = self.pusher_cache.socket.outbox
    function_set = {fn}
    for message in messages:
        pin_msg = PinFunction()
        pin_msg.ParseFromString(message)

        self.assertEqual(pin_msg.response_address, self.ip)
        self.assertTrue(pin_msg.name in function_set)
        function_set.discard(pin_msg.name)

    self.assertEqual(len(function_set), 0)

    for address in address_set:
        self.assertTrue(
            get_pin_address(*address) in self.pusher_cache.addresses)

    # Test that the policy engine has the correct metadata stored.
    self.assertEqual(len(self.policy.unpinned_cpu_executors), 0)
    self.assertEqual(len(self.policy.pending_dags), 0)
    self.assertTrue(fn in self.policy.function_locations)
    self.assertEqual(len(self.policy.function_locations[fn]), 1)
DAG_EXEC_PORT = 4040 SELF_DEPART_PORT = 4050 STATUS_PORT = 5007 SCHED_UPDATE_PORT = 5008 BACKOFF_PORT = 5009 PIN_ACCEPT_PORT = 5010 CONTINUATION_PORT = 5011 # For message sending via the user library. RECV_INBOX_PORT = 5500 STATISTICS_REPORT_PORT = 7006 # Create a generic error response protobuf. error = GenericResponse() error.success = False # Create a generic success response protobuf. ok = GenericResponse() ok.success = True ok_resp = ok.SerializeToString() # Create a default vector clock for keys that have no dependencies. DEFAULT_VC = VectorClock({'base': MaxIntLattice(1)}) def get_func_kvs_name(fname): return FUNC_PREFIX + fname
def call_dag(call, pusher_cache, dags, policy): dag, sources = dags[call.name] schedule = DagSchedule() schedule.id = str(uuid.uuid4()) schedule.dag.CopyFrom(dag) schedule.start_time = time.time() schedule.consistency = call.consistency if call.response_address: schedule.response_address = call.response_address if call.output_key: schedule.output_key = call.output_key if call.client_id: schedule.client_id = call.client_id for fref in dag.functions: args = call.function_args[fref.name].values refs = list( filter(lambda arg: type(arg) == CloudburstReference, map(lambda arg: serializer.load(arg), args))) result = policy.pick_executor(refs, fref.name) if result is None: response = GenericResponse() response.success = False response.error = NO_RESOURCES return response ip, tid = result schedule.locations[fref.name] = ip + ':' + str(tid) # copy over arguments into the dag schedule arg_list = schedule.arguments[fref.name] arg_list.values.extend(args) for fref in dag.functions: loc = schedule.locations[fref.name].split(':') ip = utils.get_queue_address(loc[0], loc[1]) schedule.target_function = fref.name triggers = sutils.get_dag_predecessors(dag, fref.name) if len(triggers) == 0: triggers.append('BEGIN') schedule.ClearField('triggers') schedule.triggers.extend(triggers) sckt = pusher_cache.get(ip) sckt.send(schedule.SerializeToString()) for source in sources: trigger = DagTrigger() trigger.id = schedule.id trigger.source = 'BEGIN' trigger.target_function = source ip = sutils.get_dag_trigger_address(schedule.locations[source]) sckt = pusher_cache.get(ip) sckt.send(trigger.SerializeToString()) response = GenericResponse() response.success = True if schedule.output_key: response.response_id = schedule.output_key else: response.response_id = schedule.id return response
def scheduler(ip, mgmt_ip, route_addr):
    """Main entry point for the scheduler's event loop.

    Binds all of the scheduler's request and status sockets, constructs
    the scheduling policy, and then loops forever, dispatching inbound
    messages to the appropriate handlers and periodically reporting
    metadata and statistics.

    Args:
        ip: The IP address of this scheduler.
        mgmt_ip: The IP of the cluster management daemon, or None when
            running in local mode.
        route_addr: The address of the Anna routing tier.
    """
    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)

    # A mapping from a DAG's name to its protobuf representation.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = []

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    # Pin responses are received with a timeout so that an unresponsive
    # executor does not block DAG creation forever.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    requestor_cache = SocketCache(context, zmq.REQ)
    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket, pusher_cache,
                                              kvs, ip, local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # A client is connecting; it only needs the routing address.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks and
                socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (dag_create_socket in socks and
                socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            # Record the interarrival time for this DAG.
            t = time.time()
            if name in last_arrivals:
                if name not in interarrivals:
                    interarrivals[name] = []

                interarrivals[name].append(t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                continue

            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks and
                socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    payload = kvs.get(dname)
                    while None in payload:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    # BUGFIX: dag.functions holds FunctionReference
                    # messages, and call_frequency is keyed by function
                    # *name* everywhere else (see the dag_call handler
                    # above), so key on fname.name here. The original used
                    # the message object itself as the key, which would
                    # make the dag_call handler raise a KeyError.
                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if mgmt_ip:
                schedulers = sched_utils.get_ip_set(
                    sched_utils.get_scheduler_list_address(mgmt_ip),
                    requestor_cache, False)

        if end - start > REPORT_THRESHOLD:
            # NOTE(review): `data` is currently unused -- presumably meant
            # to be reported somewhere; confirm before removing.
            num_unique_executors = policy.get_unique_executors()
            key = scheduler_id + ':' + str(time.time())
            data = {'key': key, 'count': num_unique_executors}

            # Share our DAG and function-location metadata with the other
            # schedulers.
            status = SchedulerStatus()
            for name in dags.keys():
                status.dags.append(name)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            # Build and reset the per-interval call statistics.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.' %
                             (call_frequency[fname], fname))

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()
def test_create_dag(self):
    '''
    This test creates a new DAG, checking that the correct pin messages
    are sent to executors and that it is persisted in the KVS correctly.
    It also checks that the server metadata was updated as expected.
    '''
    # Create a simple two-function DAG and add it to the inbound socket.
    source = 'source'
    sink = 'sink'
    dag_name = 'dag'

    dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                            dag_name)
    self.socket.inbox.append(dag.SerializeToString())

    # Add relevant metadata to the policy engine.
    address_set = {(self.ip, 1), (self.ip, 2)}
    self.policy.unpinned_executors.update(address_set)

    # Prepopulate the pin_accept socket with sufficient success messages.
    self.pin_socket.inbox.append(sutils.ok_resp)
    self.pin_socket.inbox.append(sutils.ok_resp)

    # Call the DAG creation method.
    dags = {}
    call_frequency = {}
    create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
               self.policy, call_frequency)

    # Test that the correct metadata was created.
    self.assertTrue(dag_name in dags)
    created, dag_source = dags[dag_name]
    self.assertEqual(created, dag)
    self.assertEqual(len(dag_source), 1)
    self.assertEqual(list(dag_source)[0], source)
    self.assertTrue(source in call_frequency)
    self.assertTrue(sink in call_frequency)
    self.assertEqual(call_frequency[source], 0)
    self.assertEqual(call_frequency[sink], 0)

    # Test that the DAG is stored in the KVS correctly.
    result = self.kvs_client.get(dag_name)[dag_name]
    created = Dag()
    created.ParseFromString(result.reveal())
    self.assertEqual(created, dag)

    # Test that the correct response was returned to the user.
    # BUGFIX: this was assertTrue(len(...), 1), which always passes
    # because the second argument is treated as the failure message.
    self.assertEqual(len(self.socket.outbox), 1)
    response = GenericResponse()
    response.ParseFromString(self.socket.outbox.pop())
    self.assertTrue(response.success)

    # Test that the correct pin messages were sent.
    self.assertEqual(len(self.pusher_cache.socket.outbox), 2)
    messages = self.pusher_cache.socket.outbox
    function_set = {source, sink}
    for message in messages:
        self.assertTrue(':' in message)
        ip, fname = message.split(':')
        self.assertEqual(ip, self.ip)
        self.assertTrue(fname in function_set)
        function_set.discard(fname)

    self.assertEqual(len(function_set), 0)

    for address in address_set:
        self.assertTrue(
            get_pin_address(*address) in self.pusher_cache.addresses)

    # Test that the policy engine has the correct metadata stored.
    self.assertEqual(len(self.policy.unpinned_executors), 0)
    self.assertEqual(len(self.policy.pending_dags), 0)
    self.assertTrue(source in self.policy.function_locations)
    self.assertTrue(sink in self.policy.function_locations)
    self.assertEqual(len(self.policy.function_locations[source]), 1)
    self.assertEqual(len(self.policy.function_locations[sink]), 1)
def test_create_dag_insufficient_resources(self):
    '''
    This test attempts to create a DAG even though there are not enough
    free executors in the system. It checks that a pin message is
    attempted to be sent, we run out of resources, and then the request
    is rejected. We check that the metadata is properly restored back to
    its original state.
    '''
    # Create a simple two-function DAG and add it to the inbound socket.
    source = 'source'
    sink = 'sink'
    dag_name = 'dag'

    dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                            dag_name)
    self.socket.inbox.append(dag.SerializeToString())

    # Add relevant metadata to the policy engine, but set the number of
    # executors to fewer than needed.
    address_set = {(self.ip, 1)}
    self.policy.unpinned_executors.update(address_set)

    # Prepopulate the pin_accept socket with sufficient success messages.
    self.pin_socket.inbox.append(sutils.ok_resp)

    # Attempt to create the DAG.
    dags = {}
    call_frequency = {}
    create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
               self.policy, call_frequency)

    # Check that an error was returned to the user.
    self.assertEqual(len(self.socket.outbox), 1)
    response = GenericResponse()
    response.ParseFromString(self.socket.outbox[0])
    self.assertFalse(response.success)
    self.assertEqual(response.error, NO_RESOURCES)

    # Test that the correct pin messages were sent.
    self.assertEqual(len(self.pusher_cache.socket.outbox), 2)
    messages = self.pusher_cache.socket.outbox

    # Checks for the pin message.
    self.assertTrue(':' in messages[0])
    ip, fname = messages[0].split(':')
    self.assertEqual(ip, self.ip)
    self.assertEqual(source, fname)

    # Checks for the unpin message.
    self.assertEqual(messages[1], source)

    # BUGFIX: random.sample no longer accepts a set as its population
    # (deprecated in 3.9, removed in Python 3.11), so convert it first.
    address = random.sample(list(address_set), 1)[0]
    addresses = self.pusher_cache.addresses
    self.assertEqual(get_pin_address(*address), addresses[0])
    self.assertEqual(get_unpin_address(*address), addresses[1])

    # Check that the policy engine's state was restored: nothing is left
    # pinned or pending.
    self.assertEqual(len(self.policy.unpinned_executors), 0)
    self.assertEqual(len(self.policy.function_locations), 0)
    self.assertEqual(len(self.policy.pending_dags), 0)

    # Check that no additional metadata was created or sent.
    self.assertEqual(len(call_frequency), 0)
    self.assertEqual(len(dags), 0)