def map_schedule_onto_host_or_device(kernel):
    """Wrap the schedule of *kernel* into device-kernel call/return items.

    If the target does not split kernels at global barriers, the whole
    schedule is enclosed in a single :class:`CallKernel` /
    :class:`ReturnFromKernel` pair. Otherwise the work is delegated to
    :func:`map_schedule_onto_host_or_device_impl`. In both cases the result
    is then passed through :func:`add_extra_args_to_schedule` and
    :func:`restore_and_save_temporaries`.

    :arg kernel: a kernel whose ``state`` is ``kernel_state.SCHEDULED``
        (asserted).
    :returns: the transformed kernel.
    """
    from loopy.kernel import kernel_state
    assert kernel.state == kernel_state.SCHEDULED

    from functools import partial
    # Each call yields a new device-program name derived from the kernel
    # name plus the target's prefix/suffix.
    device_prog_name_gen = partial(
            kernel.get_var_name_generator(),
            kernel.target.device_program_name_prefix
            + kernel.name
            + kernel.target.device_program_name_suffix)

    if not kernel.target.split_kernel_at_global_barriers():
        # Use one name for the call and its matching return. This matches
        # what map_schedule_onto_host_or_device_impl does for every pair it
        # emits.
        device_prog_name = device_prog_name_gen()
        new_schedule = (
                [CallKernel(kernel_name=device_prog_name,
                            extra_args=[],
                            extra_inames=[])]
                + list(kernel.schedule)
                + [ReturnFromKernel(kernel_name=device_prog_name)])
        kernel = kernel.copy(schedule=new_schedule)
    else:
        kernel = map_schedule_onto_host_or_device_impl(
                kernel, device_prog_name_gen)

    return restore_and_save_temporaries(add_extra_args_to_schedule(kernel))
def map_schedule_onto_host_or_device(kernel):
    """Wrap the schedule of *kernel* into device-kernel call/return items.

    If the target does not split kernels at global barriers, the whole
    schedule is enclosed in a single :class:`CallKernel` /
    :class:`ReturnFromKernel` pair. Otherwise splitting is delegated to
    :func:`map_schedule_onto_host_or_device_impl`.

    :arg kernel: a kernel whose ``state`` is ``KernelState.LINEARIZED``
        (asserted).
    :returns: a copy of *kernel* with the new schedule.
    """
    # FIXME: Should be idempotent.
    from loopy.kernel import KernelState
    assert kernel.state == KernelState.LINEARIZED

    from functools import partial
    # Each call yields a new device-program name derived from the kernel
    # name plus the target's prefix/suffix.
    device_prog_name_gen = partial(
            kernel.get_var_name_generator(),
            kernel.target.device_program_name_prefix
            + kernel.name
            + kernel.target.device_program_name_suffix)

    if not kernel.target.split_kernel_at_global_barriers():
        # Use one name for the call and its matching return. This matches
        # what map_schedule_onto_host_or_device_impl does for every pair it
        # emits.
        device_prog_name = device_prog_name_gen()
        new_schedule = (
                [CallKernel(kernel_name=device_prog_name,
                            extra_args=[],
                            extra_inames=[])]
                + list(kernel.schedule)
                + [ReturnFromKernel(kernel_name=device_prog_name)])
        kernel = kernel.copy(schedule=new_schedule)
    else:
        kernel = map_schedule_onto_host_or_device_impl(
                kernel, device_prog_name_gen)

    return kernel
def map_schedule_onto_host_or_device_impl(kernel, device_prog_name_gen):
    """Split the schedule of *kernel* into device kernels at global barriers.

    Runs of schedule items not separated by a global barrier are enclosed
    in :class:`CallKernel` / :class:`ReturnFromKernel` pairs. A global
    barrier item is kept in the schedule between those pairs. A loop whose
    body contains a global barrier is left outside any call/return pair,
    and its body is split recursively. If no global barrier occurs
    anywhere, the entire schedule becomes a single call/return pair.

    :arg kernel: the kernel whose ``schedule`` is rewritten.
    :arg device_prog_name_gen: a zero-argument callable, invoked once per
        emitted :class:`CallKernel` to obtain its name.
    :returns: a copy of *kernel* with the rewritten schedule.
    :raises LoopyError: if the schedule contains an item of an unexpected
        type.
    """
    sched = kernel.schedule
    block_end_of = get_block_boundaries(sched)

    # {{{ recursive splitter

    call_template = CallKernel(kernel_name="", extra_args=[], extra_inames=[])
    return_template = ReturnFromKernel(kernel_name="")

    def wrapped(items):
        # Enclose *items* in a fresh, not-yet-named call/return pair.
        return [call_template.copy()] + items + [return_template.copy()]

    def split(first, last, out):
        # Append the rewritten form of sched[first:last+1] to *out*.
        # Return True if a global barrier occurred in that range, either
        # directly or inside a nested loop.
        did_split = False
        pending = []
        pos = first

        while pos <= last:
            item = sched[pos]

            if isinstance(item, RunInstruction):
                pending.append(item)
                pos += 1
                continue

            if isinstance(item, EnterLoop):
                # block_end_of[pos] is the index of the item closing this loop.
                loop_last = block_end_of[pos]
                body = []
                body_split = split(pos + 1, loop_last - 1, body)
                loop_items = [item] + body + [sched[loop_last]]
                pos = loop_last + 1

                if not body_split:
                    # The loop can live entirely inside one device kernel.
                    pending.extend(loop_items)
                    continue

                did_split = True
                if pending:
                    out.extend(wrapped(pending))
                out.extend(loop_items)
                pending = []
                continue

            if isinstance(item, Barrier):
                if item.synchronization_kind == "global":
                    # Close the current device kernel at the barrier.
                    did_split = True
                    if pending:
                        out.extend(wrapped(pending))
                    out.append(item)
                    pending = []
                else:
                    pending.append(item)
                pos += 1
                continue

            raise LoopyError("unexpected type of schedule item: %s"
                    % type(item).__name__)

        if pending and did_split:
            # Trailing items after the last split form their own kernel.
            out.extend(wrapped(pending))
        else:
            out.extend(pending)

        return did_split

    # }}}

    result = []
    if not split(0, len(sched) - 1, result):
        # No global barrier anywhere: one device kernel covers everything.
        result = wrapped(result)

    # Name every call/return pair. Each CallKernel also records the inames
    # of the loops that enclose it.
    open_inames = []
    for pos, item in enumerate(result):
        if isinstance(item, CallKernel):
            current_name = device_prog_name_gen()
            result[pos] = item.copy(
                    kernel_name=current_name,
                    extra_inames=list(open_inames))
        elif isinstance(item, ReturnFromKernel):
            result[pos] = item.copy(kernel_name=current_name)
        elif isinstance(item, EnterLoop):
            open_inames.append(item.iname)
        elif isinstance(item, LeaveLoop):
            open_inames.pop()

    return kernel.copy(schedule=result)
def map_schedule_onto_host_or_device_impl(kernel, device_prog_name_gen):
    """Split the schedule of *kernel* into device kernels at global barriers.

    Runs of schedule items not separated by a global barrier are enclosed
    in :class:`CallKernel` / :class:`ReturnFromKernel` pairs. A loop whose
    body contains a global barrier is left outside any call/return pair,
    and its body is split recursively. If no global barrier occurs
    anywhere, the entire schedule becomes a single call/return pair.

    :arg kernel: the kernel whose ``schedule`` is rewritten.
    :arg device_prog_name_gen: a zero-argument callable, invoked once per
        emitted :class:`CallKernel` to obtain its name.
    :returns: a copy of *kernel* with the rewritten schedule.
    :raises LoopyError: if the schedule contains an item of an unexpected
        type.
    """
    schedule = kernel.schedule
    loop_bounds = get_block_boundaries(schedule)

    # {{{ Inner mapper function

    dummy_call = CallKernel(kernel_name="", extra_args=[], extra_inames=[])
    dummy_return = ReturnFromKernel(kernel_name="")

    def inner_mapper(start_idx, end_idx, new_schedule):
        # Append the rewritten form of schedule[start_idx:end_idx+1] to
        # new_schedule. Return True if a global barrier occurred in that
        # range, either directly or inside a nested loop.
        schedule_required_splitting = False

        i = start_idx
        current_chunk = []
        while i <= end_idx:
            sched_item = schedule[i]

            if isinstance(sched_item, RunInstruction):
                current_chunk.append(sched_item)
                i += 1

            elif isinstance(sched_item, EnterLoop):
                # loop_bounds[i] is the index of the item closing this loop.
                loop_end = loop_bounds[i]
                inner_schedule = []
                loop_required_splitting = inner_mapper(
                        i + 1, loop_end - 1, inner_schedule)

                start_item = schedule[i]
                end_item = schedule[loop_end]

                i = loop_end + 1

                if loop_required_splitting:
                    schedule_required_splitting = True
                    if current_chunk:
                        new_schedule.extend(
                                [dummy_call.copy()]
                                + current_chunk
                                + [dummy_return.copy()])
                    new_schedule.extend(
                            [start_item] + inner_schedule + [end_item])
                    current_chunk = []
                else:
                    # The loop can live entirely inside one device kernel.
                    current_chunk.extend(
                            [start_item] + inner_schedule + [end_item])

            elif isinstance(sched_item, Barrier):
                if sched_item.kind == "global":
                    # Wrap the current chunk into a kernel call.
                    # NOTE(review): the global barrier item itself is not
                    # added to new_schedule here. Confirm that later passes
                    # do not need to see it.
                    schedule_required_splitting = True
                    if current_chunk:
                        new_schedule.extend(
                                [dummy_call.copy()]
                                + current_chunk
                                + [dummy_return.copy()])
                    current_chunk = []
                else:
                    current_chunk.append(sched_item)
                i += 1

            else:
                raise LoopyError("unexpected type of schedule item: %s"
                        % type(sched_item).__name__)

        if current_chunk and schedule_required_splitting:
            # Wrap remainder of schedule into a kernel call.
            new_schedule.extend(
                    [dummy_call.copy()]
                    + current_chunk
                    + [dummy_return.copy()])
        else:
            new_schedule.extend(current_chunk)

        return schedule_required_splitting

    # }}}

    new_schedule = []
    split_kernel = inner_mapper(0, len(schedule) - 1, new_schedule)
    if not split_kernel:
        # Wrap everything into a kernel call.
        new_schedule = (
                [dummy_call.copy()]
                + new_schedule
                + [dummy_return.copy()])

    # Assign names, extra_inames to CallKernel / ReturnFromKernel
    # instructions. Each ReturnFromKernel reuses the name of the preceding
    # CallKernel. extra_inames holds the inames of the loops enclosing
    # the call.
    inames = []
    for idx, sched_item in enumerate(new_schedule):
        if isinstance(sched_item, CallKernel):
            last_kernel_name = device_prog_name_gen()
            new_schedule[idx] = sched_item.copy(
                    kernel_name=last_kernel_name,
                    extra_inames=list(inames))
        elif isinstance(sched_item, ReturnFromKernel):
            new_schedule[idx] = sched_item.copy(
                    kernel_name=last_kernel_name)
        elif isinstance(sched_item, EnterLoop):
            inames.append(sched_item.iname)
        elif isinstance(sched_item, LeaveLoop):
            inames.pop()

    new_kernel = kernel.copy(schedule=new_schedule)

    return new_kernel