def test_exec_function_normal(self):
    '''
    Tests creating and executing a function in normal mode, ensuring that
    no messages are sent outside of the system and that the serialized
    result is as expected.
    '''
    # Create the function and put it into the KVS.
    def func(_, x):
        return x * x
    fname = 'square'
    arg = 2

    # Put the function into the KVS and create a function call.
    create_function(func, self.kvs_client, fname)
    call = self._create_function_call(fname, [arg], NORMAL)
    self.socket.inbox.append(call.SerializeToString())

    # Execute the function call.
    exec_function(self.socket, self.kvs_client, self.user_library, {})

    # Assert that there have been 0 messages sent.
    self.assertEqual(len(self.socket.outbox), 0)

    # Retrieve the result, ensure it is a LWWPairLattice, then deserialize
    # it.
    result = self.kvs_client.get(self.response_key)[self.response_key]
    self.assertEqual(type(result), LWWPairLattice)
    result = serializer.load_lattice(result)

    # Check that the output is equal to a local function execution.
    self.assertEqual(result, func('', arg))
def test_exec_function_causal(self):
    '''
    Tests creating and executing a function in causal mode, ensuring that
    no messages are sent outside of the system and that the serialized
    result is as expected.
    '''
    # Create the function and put it into the KVS as a causal lattice.
    def func(_, x):
        return x * x
    fname = 'square'
    create_function(func, self.kvs_client, fname, SingleKeyCausalLattice)
    arg = 2

    # Create a function call and serialize it.
    call = self._create_function_call(fname, [arg], MULTI)
    self.socket.inbox.append(call.SerializeToString())

    # Execute the function call.
    exec_function(self.socket, self.kvs_client, self.user_library, {})

    # Assert that there have been 0 messages sent.
    self.assertEqual(len(self.socket.outbox), 0)

    # Retrieve the result, ensure it is a MultiKeyCausalLattice, then
    # deserialize it. Also check to make sure we have an empty vector
    # clock, because this request populated no dependencies.
    result = self.kvs_client.get(self.response_key)[self.response_key]
    self.assertEqual(type(result), MultiKeyCausalLattice)
    self.assertEqual(result.vector_clock, DEFAULT_VC)
    result = serializer.load_lattice(result)[0]

    # Check that the output is equal to a local function execution.
    self.assertEqual(result, func('', arg))
def test_exec_with_ordered_set(self):
    '''
    Tests a single function execution with an ordered set input as an
    argument to validate that ordered sets are correctly handled.
    '''
    def func(_, x):
        return len(x) >= 2 and x[0] < x[1]
    fname = 'set_order'
    arg_value = [2, 3]
    arg_name = 'set'
    self.kvs_client.put(arg_name, serializer.dump_lattice(arg_value))

    # Put the function into the KVS and create a function call.
    create_function(func, self.kvs_client, fname)
    call = self._create_function_call(
        fname, [DropletReference(arg_name, True)], NORMAL)
    self.socket.inbox.append(call.SerializeToString())

    # Execute the function call.
    exec_function(self.socket, self.kvs_client, self.user_library, {})

    # Assert that there have been 0 messages sent.
    self.assertEqual(len(self.socket.outbox), 0)

    # Retrieve the result, ensure it is a LWWPairLattice, then deserialize
    # it.
    result = self.kvs_client.get(self.response_key)[self.response_key]
    self.assertEqual(type(result), LWWPairLattice)
    result = serializer.load_lattice(result)

    # Check that the output is equal to a local function execution.
    self.assertEqual(result, func('', arg_value))
def test_exec_function_with_error(self):
    '''
    Attempts to execute a function that raises an error during its
    execution. Ensures that an error is returned to the user.
    '''
    e_msg = 'This is a broken_function!'

    def func(_, x):
        raise ValueError(e_msg)
    fname = 'func'
    create_function(func, self.kvs_client, fname)

    # Create a function call and serialize it.
    call = self._create_function_call(fname, [''], NORMAL)
    self.socket.inbox.append(call.SerializeToString())

    # Execute the function call.
    exec_function(self.socket, self.kvs_client, self.user_library, {})

    # Retrieve the result from the KVS and ensure that it is the correct
    # lattice type.
    result = self.kvs_client.get(self.response_key)[self.response_key]
    self.assertEqual(type(result), LWWPairLattice)
    result = serializer.load_lattice(result)

    # Check the type and values of the error.
    self.assertEqual(type(result), tuple)
    self.assertTrue(e_msg in result[0])

    # Unpack the GenericResponse and check its values.
    response = GenericResponse()
    response.ParseFromString(result[1])
    self.assertEqual(response.success, False)
    self.assertEqual(response.error, EXECUTION_ERROR)
def test_exec_function_nonexistent(self):
    '''
    Attempts to execute a non-existent function and ensures that an error
    is thrown and returned to the user.
    '''
    # Create a call to a function that does not exist.
    call = self._create_function_call('bad_func', [1], NORMAL)
    self.socket.inbox.append(call.SerializeToString())

    # Attempt to execute the nonexistent function.
    exec_function(self.socket, self.kvs_client, self.user_library, {})

    # Assert that there have been 0 messages sent.
    self.assertEqual(len(self.socket.outbox), 0)

    # Retrieve the result from the KVS and ensure that it is a lattice as
    # we expect. Deserialize it and check for an error.
    result = self.kvs_client.get(self.response_key)[self.response_key]
    self.assertEqual(type(result), LWWPairLattice)
    result = serializer.load_lattice(result)

    # Check the type and values of the error.
    self.assertEqual(type(result), tuple)
    self.assertEqual(result[0], 'ERROR')

    # Unpack the GenericResponse and check its values.
    response = GenericResponse()
    response.ParseFromString(result[1])
    self.assertEqual(response.success, False)
    self.assertEqual(response.error, FUNC_NOT_FOUND)
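# The two error tests above exercise a convention where failures come back
# as a (message, serialized GenericResponse) tuple, while successes are
# plain values. Below is a minimal sketch of how a client might unpack
# results under that convention; unpack_result is an illustrative helper,
# not part of the project's API.
def unpack_result(result):
    if isinstance(result, tuple) and len(result) == 2:
        # Errors carry a human-readable message plus a serialized
        # GenericResponse with a success flag and an error code.
        response = GenericResponse()
        response.ParseFromString(result[1])
        if not response.success:
            raise RuntimeError('%s (error code %d)' % (result[0],
                                                       response.error))
    return result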
def test_exec_func_with_causal_ref(self):
    '''
    Tests a function execution where the argument is a reference to the
    KVS in causal mode. Ensures that the result has the correct causal
    dependencies and metadata.
    '''
    # Create the function and serialize it into a lattice.
    def func(_, x):
        return x * x
    fname = 'square'
    create_function(func, self.kvs_client, fname, SingleKeyCausalLattice)

    # Put an argument value into the KVS.
    arg_value = 2
    arg_name = 'key'
    self.kvs_client.put(
        arg_name, serializer.dump_lattice(arg_value, MultiKeyCausalLattice))

    # Create and serialize the function call.
    call = self._create_function_call(
        fname, [DropletReference(arg_name, True)], MULTI)
    self.socket.inbox.append(call.SerializeToString())

    # Execute the function call.
    exec_function(self.socket, self.kvs_client, self.user_library, {})

    # Assert that there have been 0 messages sent.
    self.assertEqual(len(self.socket.outbox), 0)

    # Retrieve the result, ensure it is a MultiKeyCausalLattice, then
    # deserialize it.
    result = self.kvs_client.get(self.response_key)[self.response_key]
    self.assertEqual(type(result), MultiKeyCausalLattice)
    self.assertEqual(result.vector_clock, DEFAULT_VC)
    self.assertEqual(len(result.dependencies.reveal()), 1)
    self.assertTrue(arg_name in result.dependencies.reveal())
    self.assertEqual(result.dependencies.reveal()[arg_name], DEFAULT_VC)
    result = serializer.load_lattice(result)[0]

    # Check that the output is equal to a local function execution.
    self.assertEqual(result, func('', arg_value))
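# A toy illustration of the vector-clock bookkeeping asserted above: causal
# dependency metadata maps each key to a vector clock, and clocks merge by
# taking the per-replica maximum. This is a generic sketch of the technique,
# not the project's MultiKeyCausalLattice implementation.
def merge_vector_clocks(vc_a, vc_b):
    merged = dict(vc_a)
    for replica, count in vc_b.items():
        # Each replica's entry advances to the larger of the two counts.
        merged[replica] = max(merged.get(replica, 0), count)
    return merged

# e.g. merge_vector_clocks({'r1': 1}, {'r1': 2, 'r2': 1})
#      == {'r1': 2, 'r2': 1}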
def test_exec_class_function(self):
    '''
    Tests creating and executing a class method in normal mode, ensuring
    that no messages are sent outside of the system and that the
    serialized result is as expected.
    '''
    # Create the function and put it into the KVS.
    class Test:
        def __init__(self, num):
            self.num = num

        def run(self, droplet, inp):
            return inp + self.num
    fname = 'class'
    init_arg = 3
    arg = 2

    # Put the function into the KVS and create a function call.
    create_function((Test, (init_arg,)), self.kvs_client, fname)
    call = self._create_function_call(fname, [arg], NORMAL)
    self.socket.inbox.append(call.SerializeToString())

    # Execute the function call.
    exec_function(self.socket, self.kvs_client, self.user_library, {})

    # Assert that there have been 0 messages sent.
    self.assertEqual(len(self.socket.outbox), 0)

    # Retrieve the result, ensure it is a LWWPairLattice, then deserialize
    # it.
    result = self.kvs_client.get(self.response_key)[self.response_key]
    self.assertEqual(type(result), LWWPairLattice)
    result = serializer.load_lattice(result)

    # Check that the output is equal to a local function execution.
    self.assertEqual(result, Test(init_arg).run('', arg))
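# The (class, init_args) tuple registered above lets the executor construct
# the object once and then dispatch each call to its run() method. The final
# assertion in the test implies behavior roughly like the sketch below;
# run_class_function is an illustrative stand-in, not the executor's actual
# internals.
def run_class_function(cls_and_args, droplet, *call_args):
    cls, init_args = cls_and_args
    instance = cls(*init_args)  # construct once with the pinned init args
    return instance.run(droplet, *call_args)  # dispatch the call to run()

# e.g. with the Test class above, run_class_function((Test, (3,)), '', 2)
# returns 5, matching Test(3).run('', 2).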
def test_exec_func_with_ref(self):
    '''
    Tests a function execution where the argument is a reference to the
    KVS in normal mode.
    '''
    # Create the function and serialize it into a lattice.
    def func(_, x):
        return x * x
    fname = 'square'
    create_function(func, self.kvs_client, fname)

    # Put an argument value into the KVS.
    arg_value = 2
    arg_name = 'key'
    self.kvs_client.put(arg_name, serializer.dump_lattice(arg_value))

    # Create and serialize the function call.
    call = self._create_function_call(
        fname, [DropletReference(arg_name, True)], NORMAL)
    self.socket.inbox.append(call.SerializeToString())

    # Execute the function call.
    exec_function(self.socket, self.kvs_client, self.user_library, {})

    # Assert that there have been 0 messages sent.
    self.assertEqual(len(self.socket.outbox), 0)

    # Retrieve the result, ensure it is a LWWPairLattice, then deserialize
    # it.
    result = self.kvs_client.get(self.response_key)[self.response_key]
    self.assertEqual(type(result), LWWPairLattice)
    result = serializer.load_lattice(result)

    # Check that the output is equal to a local function execution.
    self.assertEqual(result, func('', arg_value))
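# All of the tests above exchange messages through `self.socket`, which
# behaves like a ZeroMQ socket but records traffic in plain lists so that
# tests can inject requests and assert that nothing was sent. A minimal
# sketch of that kind of mock is below; the class name and exact shape of
# the real test harness are assumptions.
class MockSocket:
    def __init__(self):
        self.inbox = []   # serialized messages queued for recv()
        self.outbox = []  # messages the code under test sends out

    def recv(self):
        # Hand the next queued inbound message to the code under test.
        return self.inbox.pop(0)

    def send(self, msg):
        # Capture outbound messages so tests can assert on them.
        self.outbox.append(msg)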
def executor(ip, mgmt_ip, schedulers, thread_id):
    logging.basicConfig(filename='log_executor.txt', level=logging.INFO,
                        format='%(asctime)s %(message)s')

    context = zmq.Context(1)

    pin_socket = context.socket(zmq.PULL)
    pin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.PIN_PORT + thread_id))

    unpin_socket = context.socket(zmq.PULL)
    unpin_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.UNPIN_PORT +
                                                   thread_id))

    exec_socket = context.socket(zmq.PULL)
    exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.FUNC_EXEC_PORT +
                                                  thread_id))

    dag_queue_socket = context.socket(zmq.PULL)
    dag_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_QUEUE_PORT
                                                       + thread_id))

    dag_exec_socket = context.socket(zmq.PULL)
    dag_exec_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.DAG_EXEC_PORT +
                                                      thread_id))

    self_depart_socket = context.socket(zmq.PULL)
    self_depart_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                            (sutils.SELF_DEPART_PORT + thread_id))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(pin_socket, zmq.POLLIN)
    poller.register(unpin_socket, zmq.POLLIN)
    poller.register(exec_socket, zmq.POLLIN)
    poller.register(dag_queue_socket, zmq.POLLIN)
    poller.register(dag_exec_socket, zmq.POLLIN)
    poller.register(self_depart_socket, zmq.POLLIN)

    # If the management IP is set to None, that means that we are running in
    # local mode, so we use a regular AnnaTcpClient rather than an IPC
    # client.
    if mgmt_ip:
        client = AnnaIpcClient(thread_id, context)
    else:
        client = AnnaTcpClient('127.0.0.1', '127.0.0.1', local=True, offset=1)

    user_library = DropletUserLibrary(context, pusher_cache, ip, thread_id,
                                      client)

    status = ThreadStatus()
    status.ip = ip
    status.tid = thread_id
    status.running = True
    utils.push_status(schedulers, pusher_cache, status)

    departing = False

    # Maintains a request queue for each function pinned on this executor.
    # Each function will have a set of request IDs mapped to it, and this
    # map stores a schedule for each request ID.
    queue = {}

    # Tracks the actual function objects that are pinned to this executor.
    pinned_functions = {}

    # Tracks the runtime cost of executing a DAG function.
    runtimes = {}

    # If multiple triggers are necessary for a function, track the triggers
    # as we receive them. This is also used if a trigger arrives before its
    # corresponding schedule.
    received_triggers = {}

    # Tracks when we received a function request, so we can report
    # end-to-end latency for the whole execution.
    receive_times = {}

    # Tracks the number of requests we are finishing for each function
    # pinned here.
    exec_counts = {}

    # Tracks the end-to-end runtime of each DAG request for which we are
    # the sink function.
    dag_runtimes = {}

    # A map with KVS keys and their corresponding deserialized payloads.
    cache = {}

    # Internal metadata to track thread utilization.
    report_start = time.time()
    event_occupancy = {'pin': 0.0,
                       'unpin': 0.0,
                       'func_exec': 0.0,
                       'dag_queue': 0.0,
                       'dag_exec': 0.0}
    total_occupancy = 0.0

    while True:
        socks = dict(poller.poll(timeout=1000))

        if pin_socket in socks and socks[pin_socket] == zmq.POLLIN:
            work_start = time.time()
            pin(pin_socket, pusher_cache, client, status, pinned_functions,
                runtimes, exec_counts)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['pin'] += elapsed
            total_occupancy += elapsed

        if unpin_socket in socks and socks[unpin_socket] == zmq.POLLIN:
            work_start = time.time()
            unpin(unpin_socket, status, pinned_functions, runtimes,
                  exec_counts)
            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['unpin'] += elapsed
            total_occupancy += elapsed

        if exec_socket in socks and socks[exec_socket] == zmq.POLLIN:
            work_start = time.time()
            exec_function(exec_socket, client, user_library, cache)
            user_library.close()

            utils.push_status(schedulers, pusher_cache, status)

            elapsed = time.time() - work_start
            event_occupancy['func_exec'] += elapsed
            total_occupancy += elapsed

        if dag_queue_socket in socks and socks[dag_queue_socket] == \
                zmq.POLLIN:
            work_start = time.time()

            schedule = DagSchedule()
            schedule.ParseFromString(dag_queue_socket.recv())
            fname = schedule.target_function

            logging.info('Received a schedule for DAG %s (%s), function %s.'
                         % (schedule.dag.name, schedule.id, fname))

            if fname not in queue:
                queue[fname] = {}

            queue[fname][schedule.id] = schedule

            if (schedule.id, fname) not in receive_times:
                receive_times[(schedule.id, fname)] = time.time()

            # In case we receive the trigger before we receive the schedule,
            # we can trigger from this operation as well.
            trkey = (schedule.id, fname)
            if (trkey in received_triggers and
                    (len(received_triggers[trkey]) ==
                     len(schedule.triggers))):
                exec_dag_function(pusher_cache, client,
                                  received_triggers[trkey],
                                  pinned_functions[fname], schedule,
                                  user_library, dag_runtimes, cache)
                user_library.close()

                del received_triggers[trkey]
                del queue[fname][schedule.id]

                fend = time.time()
                fstart = receive_times[(schedule.id, fname)]
                runtimes[fname].append(fend - fstart)
                exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_queue'] += elapsed
            total_occupancy += elapsed

        if dag_exec_socket in socks and socks[dag_exec_socket] == \
                zmq.POLLIN:
            work_start = time.time()

            trigger = DagTrigger()
            trigger.ParseFromString(dag_exec_socket.recv())
            fname = trigger.target_function

            logging.info('Received a trigger for schedule %s, function %s.'
                         % (trigger.id, fname))

            key = (trigger.id, fname)
            if key not in received_triggers:
                received_triggers[key] = {}

            if (trigger.id, fname) not in receive_times:
                receive_times[(trigger.id, fname)] = time.time()

            received_triggers[key][trigger.source] = trigger
            if fname in queue and trigger.id in queue[fname]:
                schedule = queue[fname][trigger.id]
                if len(received_triggers[key]) == len(schedule.triggers):
                    exec_dag_function(pusher_cache, client,
                                      received_triggers[key],
                                      pinned_functions[fname], schedule,
                                      user_library, dag_runtimes, cache)
                    user_library.close()

                    del received_triggers[key]
                    del queue[fname][trigger.id]

                    fend = time.time()
                    fstart = receive_times[(trigger.id, fname)]
                    runtimes[fname].append(fend - fstart)
                    exec_counts[fname] += 1

            elapsed = time.time() - work_start
            event_occupancy['dag_exec'] += elapsed
            total_occupancy += elapsed

        if self_depart_socket in socks and socks[self_depart_socket] == \
                zmq.POLLIN:
            # This message does not matter.
            self_depart_socket.recv()

            logging.info('Preparing to depart. No longer accepting requests '
                         + 'and clearing all queues.')

            status.ClearField('functions')
            status.running = False
            utils.push_status(schedulers, pusher_cache, status)

            departing = True

        # Periodically report function occupancy.
        report_end = time.time()
        if report_end - report_start > REPORT_THRESH:
            cache.clear()

            utilization = total_occupancy / (report_end - report_start)
            status.utilization = utilization

            # Periodically report my status to schedulers with the
            # utilization set.
            utils.push_status(schedulers, pusher_cache, status)

            logging.info('Total thread occupancy: %.6f' % (utilization))
            for event in event_occupancy:
                occ = event_occupancy[event] / (report_end - report_start)
                logging.info('\tEvent %s occupancy: %.6f' % (event, occ))
                event_occupancy[event] = 0.0

            stats = ExecutorStatistics()
            for fname in runtimes:
                if exec_counts[fname] > 0:
                    fstats = stats.functions.add()
                    fstats.name = fname
                    fstats.call_count = exec_counts[fname]
                    fstats.runtime.extend(runtimes[fname])

                    runtimes[fname].clear()
                    exec_counts[fname] = 0

            for dname in dag_runtimes:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.runtimes.extend(dag_runtimes[dname])

                dag_runtimes[dname].clear()

            # If we are running in cluster mode, mgmt_ip will be set, and we
            # will report our status and statistics to it. Otherwise, we
            # will write to the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

                sckt = pusher_cache.get(
                    utils.get_util_report_address(mgmt_ip))
                sckt.send(status.SerializeToString())
            else:
                logging.info(stats)

            status.ClearField('utilization')
            report_start = time.time()
            total_occupancy = 0.0

            # Periodically clear any old functions we have cached that we
            # are no longer accepting requests for. Iterate over a copy of
            # the keys, since we delete entries as we go.
            for fname in list(queue):
                if len(queue[fname]) == 0 and fname not in status.functions:
                    del queue[fname]
                    del pinned_functions[fname]
                    del runtimes[fname]
                    del exec_counts[fname]

            # If we are departing and have cleared our queues, let the
            # management server know, and exit the process.
            if departing and len(queue) == 0:
                sckt = pusher_cache.get(utils.get_depart_done_addr(mgmt_ip))
                sckt.send_string(ip)

                # We specifically pass 1 as the exit code when ending our
                # process so that the wrapper script does not restart us.
                os._exit(1)
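# A hedged sketch of how this entry point might be launched locally: one
# executor process per thread ID, each binding its own port offsets. The
# multiprocessing wrapper and the argument values are illustrative
# assumptions, not the project's actual startup script.
if __name__ == '__main__':
    from multiprocessing import Process

    schedulers = ['127.0.0.1']  # assumed scheduler address list
    for tid in range(3):
        # mgmt_ip=None selects local mode (AnnaTcpClient), per the check at
        # the top of executor().
        Process(target=executor,
                args=('127.0.0.1', None, schedulers, tid)).start()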