def test_dag_call_no_refs(self):
    '''
    Tests a DAG call without any references. We do not currently have a
    test for selecting a DAG call with references because the reference
    logic in the default policy is the same for individual functions
    (tested above) and for DAGs.
    '''
    # Create a simple DAG.
    source = 'source'
    sink = 'sink'
    dag, source_address, sink_address = self._construct_dag_with_locations(
        source, sink)

    # Create a DAG call that corresponds to this new DAG.
    call = DagCall()
    call.name = dag.name
    call.consistency = NORMAL
    call.output_key = 'output_key'
    call.client_id = '0'

    # Execute the scheduling policy.
    call_dag(call, self.pusher_cache, {dag.name: (dag, {source})},
             self.policy)

    # Check that the correct number of messages were sent.
    self.assertEqual(len(self.pusher_cache.socket.outbox), 3)

    # Extract each of the two schedules and ensure that they are correct.
    source_schedule = DagSchedule()
    source_schedule.ParseFromString(self.pusher_cache.socket.outbox[0])
    self._verify_dag_schedule(source, 'BEGIN', source_schedule, dag, call)

    sink_schedule = DagSchedule()
    sink_schedule.ParseFromString(self.pusher_cache.socket.outbox[1])
    self._verify_dag_schedule(sink, source, sink_schedule, dag, call)

    # Make sure that only one trigger was sent, and that it was for the DAG
    # source.
    trigger = DagTrigger()
    trigger.ParseFromString(self.pusher_cache.socket.outbox[2])
    self.assertEqual(trigger.id, source_schedule.id)
    self.assertEqual(trigger.target_function, source)
    self.assertEqual(trigger.source, 'BEGIN')
    self.assertEqual(len(trigger.version_locations), 0)
    self.assertEqual(len(trigger.dependencies), 0)

    # Ensure that all the destination addresses match the addresses we
    # expect.
    self.assertEqual(len(self.pusher_cache.addresses), 3)
    self.assertEqual(self.pusher_cache.addresses[0],
                     utils.get_queue_address(*source_address))
    self.assertEqual(self.pusher_cache.addresses[1],
                     utils.get_queue_address(*sink_address))
    self.assertEqual(
        self.pusher_cache.addresses[2],
        sutils.get_dag_trigger_address(':'.join(
            map(lambda s: str(s), source_address))))

def _exec_dag_function_normal(pusher_cache, kvs, triggers, function,
                              schedule, user_lib, cache):
    fname = schedule.target_function
    fargs = list(schedule.arguments[fname].values)

    for trigger in triggers:
        fargs += list(trigger.arguments.values)

    fargs = [serializer.load(arg) for arg in fargs]
    result = _exec_func_normal(kvs, function, fargs, user_lib, cache)

    this_ref = None
    for ref in schedule.dag.functions:
        if ref.name == fname:
            this_ref = ref  # There must be a match.

    success = True
    if this_ref.type == MULTIEXEC:
        if serializer.dump(result) in this_ref.invalid_results:
            return False, False

    is_sink = True
    new_trigger = _construct_trigger(schedule.id, fname, result)
    for conn in schedule.dag.connections:
        if conn.source == fname:
            is_sink = False
            new_trigger.target_function = conn.sink

            dest_ip = schedule.locations[conn.sink]
            sckt = pusher_cache.get(sutils.get_dag_trigger_address(dest_ip))
            sckt.send(new_trigger.SerializeToString())

    if is_sink:
        if schedule.response_address:
            sckt = pusher_cache.get(schedule.response_address)
            logging.info('DAG %s (ID %s) result returned to requester.' %
                         (schedule.dag.name, schedule.id))
            sckt.send(serializer.dump(result))
        else:
            lattice = serializer.dump_lattice(result)
            output_key = schedule.output_key if schedule.output_key \
                else schedule.id
            logging.info('DAG %s (ID %s) result in KVS at %s.' %
                         (schedule.dag.name, schedule.id, output_key))
            kvs.put(output_key, lattice)

    return is_sink, success

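# For reference, a minimal sketch of the trigger-construction helper used
# above. This is an illustrative assumption inferred only from how
# `_construct_trigger` is called here and how downstream executors read
# `trigger.arguments.values`; it is not necessarily the repository's actual
# implementation, and the tuple fan-out behavior is a guess.
def _construct_trigger_sketch(sid, fname, result):
    trigger = DagTrigger()
    trigger.id = sid
    trigger.source = fname

    # Multiple return values (a tuple) become multiple downstream arguments.
    if not isinstance(result, tuple):
        result = (result,)

    trigger.arguments.values.extend([serializer.dump(v) for v in result])
    return trigger
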
def _exec_dag_function_causal(pusher_cache, kvs, triggers, function,
                              schedule, user_lib):
    schedule = schedule[0]
    triggers = triggers[0]

    fname = schedule.target_function
    fargs = list(schedule.arguments[fname].values)

    key_version_locations = {}
    dependencies = {}

    for trigger in triggers:
        fargs += list(trigger.arguments.values)

        # Combine the locations of upstream cached key versions from all
        # triggers.
        for addr in trigger.version_locations:
            if addr in key_version_locations:
                key_version_locations[addr].extend(
                    trigger.version_locations[addr].key_versions)
            else:
                key_version_locations[addr] = list(
                    trigger.version_locations[addr])

        # Combine the dependency sets from all triggers.
        for dependency in trigger.dependencies:
            vc = VectorClock(dict(dependency.vector_clock), True)
            key = dependency.key

            if key in dependencies:
                dependencies[key].merge(vc)
            else:
                dependencies[key] = vc

    fargs = [serializer.load(arg) for arg in fargs]

    result = _exec_func_causal(kvs, function, fargs, user_lib, schedule,
                               key_version_locations, dependencies)

    this_ref = None
    for ref in schedule.dag.functions:
        if ref.name == fname:
            this_ref = ref  # There must be a match.

    success = True
    if this_ref.type == MULTIEXEC:
        if serializer.dump(result) in this_ref.invalid_results:
            return False, False

    # Create a new trigger with the schedule ID and results of this
    # execution.
    new_trigger = _construct_trigger(schedule.id, fname, result)

    # Serialize the key version location information into this new trigger.
    for addr in key_version_locations:
        new_trigger.version_locations[addr].keys.extend(
            key_version_locations[addr])

    # Serialize the set of dependency versions for causal metadata.
    for key in dependencies:
        dep = new_trigger.dependencies.add()
        dep.key = key
        dependencies[key].serialize(dep.vector_clock)

    is_sink = True
    for conn in schedule.dag.connections:
        if conn.source == fname:
            is_sink = False
            new_trigger.target_function = conn.sink

            dest_ip = schedule.locations[conn.sink]
            sckt = pusher_cache.get(sutils.get_dag_trigger_address(dest_ip))
            sckt.send(new_trigger.SerializeToString())

    if is_sink:
        logging.info('DAG %s (ID %s) completed in causal mode; result at %s.'
                     % (schedule.dag.name, schedule.id, schedule.output_key))

        vector_clock = {}
        okey = schedule.output_key
        if okey in dependencies:
            prev_count = 0
            if schedule.client_id in dependencies[okey]:
                prev_count = dependencies[okey][schedule.client_id]

            dependencies[okey].update(schedule.client_id, prev_count + 1)
            dependencies[okey].serialize(vector_clock)
            del dependencies[okey]
        else:
            vector_clock = {schedule.client_id: 1}

        # Serialize result into a MultiKeyCausalLattice.
        vector_clock = VectorClock(vector_clock, True)
        result = serializer.dump(result)
        dependencies = MapLattice(dependencies)
        lattice = MultiKeyCausalLattice(vector_clock, dependencies,
                                        SetLattice({result}))

        succeed = kvs.causal_put(schedule.output_key, lattice,
                                 schedule.client_id)
        while not succeed:
            succeed = kvs.causal_put(schedule.output_key, lattice,
                                     schedule.client_id)

        # Issue requests to all upstream caches for this particular request
        # and ask them to garbage collect pinned versions stored for the
        # context of this request.
        for cache_addr in key_version_locations:
            gc_address = utils.get_cache_gc_address(cache_addr)
            sckt = pusher_cache.get(gc_address)
            sckt.send_string(schedule.client_id)

    return is_sink, [success]

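# The dependency merging above assumes VectorClock.merge takes a pointwise
# maximum over per-client counters and that update overwrites a single
# counter. A minimal, self-contained sketch of those semantics (the real
# VectorClock ships with the Anna client library; this field layout is an
# assumption for illustration only):
class VectorClockSketch:
    def __init__(self, clock=None):
        self.clock = dict(clock) if clock else {}

    def merge(self, other):
        # Keep the largest counter observed for each client ID.
        for cid, count in other.clock.items():
            self.clock[cid] = max(self.clock.get(cid, 0), count)

    def update(self, cid, count):
        # Overwrite this client's counter, e.g. when writing the DAG result.
        self.clock[cid] = count
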
def _exec_dag_function_normal(pusher_cache, kvs, trigger_sets, function,
                              schedules, user_lib, cache, schedulers,
                              batching):
    fname = schedules[0].target_function

    # We construct farg_sets to have a request-by-request set of arguments.
    # That is, each element in farg_sets will have all the arguments for one
    # invocation.
    farg_sets = []
    for schedule, trigger_set in zip(schedules, trigger_sets):
        fargs = list(schedule.arguments[fname].values)

        for trigger in trigger_set:
            fargs += list(trigger.arguments.values)

        fargs = [serializer.load(arg) for arg in fargs]
        farg_sets.append(fargs)

    if batching:
        fargs = [[]] * len(farg_sets[0])
        for idx in range(len(fargs)):
            fargs[idx] = [fset[idx] for fset in farg_sets]
    else:
        # There will only be one set of arguments in farg_sets.
        fargs = farg_sets[0]

    result_list = _exec_func_normal(kvs, function, fargs, user_lib, cache)
    if not isinstance(result_list, list):
        result_list = [result_list]

    successes = []
    is_sink = True

    for schedule, result in zip(schedules, result_list):
        this_ref = None
        for ref in schedule.dag.functions:
            if ref.name == fname:
                this_ref = ref  # There must be a match.

        if this_ref.type == MULTIEXEC:
            if serializer.dump(result) in this_ref.invalid_results:
                successes.append(False)
                continue

        successes.append(True)
        new_trigger = _construct_trigger(schedule.id, fname, result)

        for conn in schedule.dag.connections:
            if conn.source == fname:
                is_sink = False
                new_trigger.target_function = conn.sink

                dest_ip = schedule.locations[conn.sink]
                sckt = pusher_cache.get(
                    sutils.get_dag_trigger_address(dest_ip))
                sckt.send(new_trigger.SerializeToString())

    if is_sink:
        if schedule.continuation.name:
            for idx, pair in enumerate(zip(schedules, result_list)):
                schedule, result = pair
                if successes[idx]:
                    cont = schedule.continuation
                    cont.id = schedule.id
                    cont.result = serializer.dump(result)

                    logging.info('Sending continuation to scheduler for DAG'
                                 ' %s.' % (schedule.id))
                    sckt = pusher_cache.get(
                        utils.get_continuation_address(schedulers))
                    sckt.send(cont.SerializeToString())
        elif schedule.response_address:
            for idx, pair in enumerate(zip(schedules, result_list)):
                schedule, result = pair
                if successes[idx]:
                    sckt = pusher_cache.get(schedule.response_address)
                    logging.info('DAG %s (ID %s) result returned to'
                                 ' requester.' %
                                 (schedule.dag.name, schedule.id))
                    sckt.send(serializer.dump(result))
        else:
            keys = []
            lattices = []
            for idx, pair in enumerate(zip(schedules, result_list)):
                schedule, result = pair
                if successes[idx]:
                    lattice = serializer.dump_lattice(result)
                    output_key = schedule.output_key if schedule.output_key \
                        else schedule.id
                    logging.info('DAG %s (ID %s) result in KVS at %s.' %
                                 (schedule.dag.name, schedule.id,
                                  output_key))

                    keys.append(output_key)
                    lattices.append(lattice)

            kvs.put(keys, lattices)

    return is_sink, successes

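# When batching is enabled, the transpose loop above turns per-invocation
# argument lists (farg_sets) into per-position lists, so a batch-aware
# function receives one list per positional argument. A standalone
# illustration with made-up values:
def _batch_transpose_example():
    farg_sets = [[1, 'a'], [2, 'b'], [3, 'c']]  # three queued invocations
    fargs = [[fset[idx] for fset in farg_sets]
             for idx in range(len(farg_sets[0]))]
    assert fargs == [[1, 2, 3], ['a', 'b', 'c']]
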
def call_dag(call, pusher_cache, dags, policy):
    dag, sources = dags[call.name]

    schedule = DagSchedule()
    schedule.id = str(uuid.uuid4())
    schedule.dag.CopyFrom(dag)
    schedule.start_time = time.time()
    schedule.consistency = call.consistency

    if call.response_address:
        schedule.response_address = call.response_address

    if call.output_key:
        schedule.output_key = call.output_key

    if call.client_id:
        schedule.client_id = call.client_id

    for fref in dag.functions:
        args = call.function_args[fref.name].values

        refs = list(
            filter(lambda arg: type(arg) == CloudburstReference,
                   map(lambda arg: serializer.load(arg), args)))

        result = policy.pick_executor(refs, fref.name)
        if result is None:
            response = GenericResponse()
            response.success = False
            response.error = NO_RESOURCES
            return response

        ip, tid = result
        schedule.locations[fref.name] = ip + ':' + str(tid)

        # Copy over arguments into the DAG schedule.
        arg_list = schedule.arguments[fref.name]
        arg_list.values.extend(args)

    for fref in dag.functions:
        loc = schedule.locations[fref.name].split(':')
        ip = utils.get_queue_address(loc[0], loc[1])
        schedule.target_function = fref.name

        triggers = sutils.get_dag_predecessors(dag, fref.name)
        if len(triggers) == 0:
            triggers.append('BEGIN')

        schedule.ClearField('triggers')
        schedule.triggers.extend(triggers)

        sckt = pusher_cache.get(ip)
        sckt.send(schedule.SerializeToString())

    for source in sources:
        trigger = DagTrigger()
        trigger.id = schedule.id
        trigger.source = 'BEGIN'
        trigger.target_function = source

        ip = sutils.get_dag_trigger_address(schedule.locations[source])
        sckt = pusher_cache.get(ip)
        sckt.send(trigger.SerializeToString())

    response = GenericResponse()
    response.success = True
    if schedule.output_key:
        response.response_id = schedule.output_key
    else:
        response.response_id = schedule.id

    return response

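# call_dag derives each function's trigger list from its predecessors in the
# DAG. A minimal sketch of the predecessor lookup it relies on, assuming only
# the connection semantics visible above (conn.source feeds conn.sink); the
# actual helper lives in the shared scheduler utilities:
def _get_dag_predecessors_sketch(dag, fname):
    return [conn.source for conn in dag.connections if conn.sink == fname]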