def recurse(task, val): if val == 0: nonlocal counter assert counter_vals[counter] counter_vals[counter] = False counter += 1 else: run_task(recurse, (val - 1, ), []) run_task(recurse, (val - 1, ), [])
def recursion_with_manual_continuation(task, val): if val == 0: nonlocal counter counter += 1 else: t1 = run_task(recursion_with_manual_continuation, (val - 1, ), []) t2 = run_task(recursion_with_manual_continuation, (val - 1, ), []) if val == 1: nonlocal outermost outermost.append(t1) outermost.append(t2)
def tasks_with_deps(task):
    """Check that a task spawned with a dependency runs after that dependency.

    ``increment`` bumps a local counter and ``check`` asserts it is already 1.
    ``check`` is therefore spawned with ``increment``'s task in its dependency
    list (the third ``run_task`` argument). ``task`` is unused.
    """
    counter = 0

    def increment(task):
        nonlocal counter
        counter += 1

    def check(task):
        assert counter == 1

    first = run_task(increment, tuple(), [])
    # Depend on ``first`` so ``check`` cannot observe the counter before the increment.
    run_task(check, tuple(), [first])
def recursion_without_continuation(task):
    """Spawn a depth-2 binary task tree whose 4 leaves each consume one slot.

    Every leaf asserts that its ``counter_vals`` slot is still ``True`` and then
    clears it, so a leaf running twice (or out of range) fails. ``task`` is
    unused.
    """
    counter = 0
    counter_vals = [True, True, True, True]

    def recurse(task, val):
        nonlocal counter
        if val == 0:
            assert counter_vals[counter]
            counter_vals[counter] = False
            counter += 1
        else:
            run_task(recurse, (val - 1, ), [])
            run_task(recurse, (val - 1, ), [])

    # Args are passed as a tuple, matching every other run_task call;
    # depth 2 gives 2**2 == len(counter_vals) leaves.
    run_task(recurse, (2, ), [])
def test_flag_increment():
    """A spawned task can mutate its enclosing scope and runs exactly once.

    Assumes the task has completed by the time ``run_task`` returns.
    """
    external_flag = 0

    def increment_flag(task):
        nonlocal external_flag
        external_flag += 1

    run_task(increment_flag, tuple(), [])
    # An exact check also catches a task body that executes more than once.
    assert external_flag == 1
def test_recursion_with_finalization():
    """Spawn a depth-3 binary task tree, then a finalizer gated on its leaves.

    The recursion records each leaf task (children of a ``val == 1`` call) in
    ``leaf_tasks``. ``check_counter`` is spawned with those tasks as dependencies
    and asserts that all 2**3 == 8 leaves ran.
    """
    leaf_tasks = []
    leaves_done = 0

    def recursion_with_manual_continuation(task, val):
        nonlocal leaves_done
        if val == 0:
            leaves_done += 1
            return
        children = (
            run_task(recursion_with_manual_continuation, (val - 1, ), []),
            run_task(recursion_with_manual_continuation, (val - 1, ), []),
        )
        if val == 1:
            leaf_tasks.extend(children)

    run_task(recursion_with_manual_continuation, (3, ), [])

    def check_counter(task):
        assert leaves_done == 8

    # NOTE(review): leaf_tasks must already hold every leaf task here, i.e. the
    # whole tree must have been spawned by the time run_task returns — confirm.
    run_task(check_counter, tuple(), leaf_tasks)
def test_exception_handling():
    """An exception raised inside a task body propagates out of ``run_task``."""

    class CustomException(Exception):
        pass

    def raise_exc(task):
        raise CustomException("error")

    try:
        run_task(raise_exc, tuple(), [])
    except Exception:
        # Catch Exception rather than everything, so KeyboardInterrupt/SystemExit
        # are not mistaken for success. NOTE(review): it is not visible here
        # whether run_task re-raises CustomException unwrapped; if it does,
        # narrow this to CustomException.
        success = True
    else:
        success = False
    assert success
def decorator(body): nonlocal taskid if inspect.isgeneratorfunction(body): raise TypeError( "Spawned tasks must be normal functions or coroutines; not generators." ) # Compute the flat dependency set (including unwrapping TaskID objects) deps = [] for ds in dependencies: if not isinstance(ds, Iterable): ds = (ds, ) for d in ds: if hasattr(d, "task"): d = d.task if not isinstance(d, task_runtime.Task): raise TypeError("Dependencies must be TaskIDs or Tasks: " + str(d)) deps.append(d) if inspect.iscoroutine(body): # An already running coroutine does not need changes since we assume # it was changed correctly when the original function was spawned. separated_body = body else: # Perform a horrifying hack to build a new function which will # not be able to observe changes in the original cells in the # tasks outer scope. To do this we build a new function with a # replaced closure which contains new cells. separated_body = type(body)(body.__code__, body.__globals__, body.__name__, body.__defaults__, closure=body.__closure__ and tuple( _make_cell(x.cell_contents) for x in body.__closure__)) separated_body.__annotations__ = body.__annotations__ separated_body.__doc__ = body.__doc__ separated_body.__kwdefaults__ = body.__kwdefaults__ separated_body.__module__ = body.__module__ data = _TaskData(_task_locals, separated_body, dependencies) if not taskid: taskid = TaskID("global_" + str(len(_task_locals.global_tasks)), len(_task_locals.global_tasks)) _task_locals.global_tasks += [taskid] taskid.data = data taskid.dependencies = dependencies data.taskid = taskid # Spawn the task via the Parla runtime API task = task_runtime.run_task(_task_callback, (data, ), deps, queue_identifier=placement) # Store the task object in it's ID object taskid.task = task logger.debug("Created: %s <%s, %s, %r>", taskid, placement, body) for scope in _task_locals.task_scopes: scope.append(task) # Return the task object return task