def make_loop(start, stop, step, do_min = True):
    """
    Assemble the loop that maps over one tile per iteration and folds the
    tile's result into the running accumulator.

    start, stop, step -- passed through unchanged to syntax.ForLoop.
    do_min -- when true the tile's upper bound is min(i + step, stop);
              when false the bound is i + step with no clamping.

    Returns syntax.ForLoop(i, start, stop, step, body, merge), where
    `merge` comes from the enclosing scope.
    """
    loop_var = self.fresh_var(niters.type, "i")

    # Statements emitted between push() and pop() are collected into the
    # block that becomes the loop body.
    self.blocks.push()
    upper = self.add(loop_var, step, "next_bound")
    if do_min:
        upper = self.min(upper, stop)
    tile_bounds = syntax.Slice(loop_var, upper, one_i64, type=slice_t)

    # Restrict every input to the current tile along its own axis.
    tile_args = []
    for whole_arg, whole_axis in zip(args, axes):
        tile_args.append(
            self.index_along_axis(whole_arg, whole_axis, tile_bounds))

    partial = self.fresh_var(tiled_map_fn.return_type, "new_acc")
    self.comment("TiledReduce in %s: map_fn " % self.fn.name)
    do_inline(tiled_map_fn,
              map_closure_args + tile_args,
              self.type_env,
              self.blocks.top(),
              result_var=partial)
    body = self.blocks.pop()

    if acc_is_array:
        # NOTE(review): these builder calls run after the pop above, so
        # anything they emit presumably lands in the enclosing block rather
        # than the loop body -- confirm this is intended.
        everything = self.tuple(
            [syntax_helpers.slice_none] * result.type.rank)
        out_view = self.index(result, everything, temp=False)
        self.comment("")
        # The combiner receives the output view as an extra argument and
        # no result variable is requested from the inliner.
        combine_inputs = [result, partial, out_view]
        combine_target = None
    else:
        combine_inputs = [result, partial]
        combine_target = result_after

    do_inline(tiled_combine,
              combine_closure_args + combine_inputs,
              self.type_env,
              body,
              result_var=combine_target)
    return syntax.ForLoop(loop_var, start, stop, step, body, merge)
def make_loop(start, stop, step, do_min = True):
    """
    Assemble the loop that applies the tiled inner function to one tile of
    every input (and the matching region of the output) per iteration.

    start, stop, step -- passed through unchanged to syntax.ForLoop.
    do_min -- when true the tile's upper bound is min(i + step, niters);
              when false the bound is i + step with no clamping.

    Returns syntax.ForLoop(i, start, stop, step, body, {}).
    """
    loop_var = self.fresh_var(niters.type, "i")

    # Statements emitted between push() and pop() are collected into the
    # block that becomes the loop body.
    self.blocks.push()
    upper = self.add(loop_var, step, "slice_stop")
    if do_min:
        # The clamp is against the enclosing `niters`, not the `stop`
        # argument of this function.
        upper = self.min(upper, niters, "slice_min")
    tile_bounds = syntax.Slice(loop_var, upper, one_i64, type=slice_t)

    # Restrict every input to the current tile along its own axis.
    inner_args = []
    for whole_arg, whole_axis in zip(args, axes):
        inner_args.append(
            self.index_along_axis(whole_arg, whole_axis, tile_bounds))

    # The output region is passed as an additional argument; which axis it
    # is sliced along depends on whether the tile size is fixed.
    if expr.fixed_tile_size:
        out_axis = self.fixed_idx
    else:
        out_axis = self.nesting_idx
    inner_args.append(
        self.index_along_axis(array_result, out_axis, tile_bounds))

    if nested_has_tiles:
        inner_args.append(self.tile_sizes_param)

    loop_body = self.blocks.pop()
    do_inline(tiled_inner_fn,
              closure_args + inner_args,
              self.type_env,
              loop_body,
              result_var=None)
    return syntax.ForLoop(loop_var, start, stop, step, loop_body, {})
def fuse(prev_fn, prev_fixed_args, next_fn, next_fixed_args, fusion_args):
    """
    Build one TypedFn that runs prev_fn and feeds its result into next_fn.

    Expects prev_fn's returned value to be one or more of the arguments to
    next_fn. Each element of 'fusion_args' describes one non-closure
    argument of next_fn:

      * None         -- replaced by the Var holding prev_fn's return value
      * an int       -- index into next_fn.arg_names of an argument that is
                        not being fused out; it becomes an extra formal of
                        the fused function (under a refreshed name)
      * anything else must be a Const, which is spliced in as a literal

    The fused function's formals are, in order: prev_fn's closure formals,
    next_fn's closure formals (renamed), prev_fn's remaining formals, then
    one formal per int entry of 'fusion_args'.

    Returns (fused_fn, prev_fixed_args + next_fixed_args). If next_fn is an
    identity function, (prev_fn, prev_fixed_args) is returned unchanged.
    """
    if syntax_helpers.is_identity_fn(next_fn):
        assert len(next_fixed_args) == 0
        return prev_fn, prev_fixed_args

    fused_formals = []
    fused_input_types = []
    fused_type_env = prev_fn.type_env.copy()
    fused_name = names.fresh('fused')

    n_prev_fixed = len(prev_fixed_args)

    # Closure formals of the predecessor keep their names: its body is
    # reused as-is, so nothing originating from prev_fn is renamed.
    for prev_closure_arg_name in prev_fn.arg_names[:n_prev_fixed]:
        fused_formals.append(prev_closure_arg_name)
        fused_input_types.append(prev_fn.type_env[prev_closure_arg_name])

    # Closure formals of the successor are renamed so they cannot collide
    # with predecessor names. The inlined successor body must be bound to
    # these *renamed* formals, so keep a Var for each one; passing the old
    # name strings would leave the fused function's parameters unused.
    next_closure_vars = []
    for next_closure_arg_name in next_fn.arg_names[:len(next_fixed_args)]:
        t = next_fn.type_env[next_closure_arg_name]
        new_name = names.refresh(next_closure_arg_name)
        fused_type_env[new_name] = t
        fused_formals.append(new_name)
        fused_input_types.append(t)
        next_closure_vars.append(Var(new_name, t))

    for arg_name in prev_fn.arg_names[n_prev_fixed:]:
        fused_formals.append(arg_name)
        fused_input_types.append(prev_fn.type_env[arg_name])

    # for now we're restricting both functions to have a single return at
    # the outermost scope
    prev_return_var, fused_body = \
        inline.replace_return_with_var(prev_fn.body,
                                       fused_type_env,
                                       prev_fn.return_type)

    inline_args = list(next_closure_vars)
    for arg in fusion_args:
        if arg is None:
            inline_args.append(prev_return_var)
        elif isinstance(arg, int):
            # positional arg which is not being fused out
            inner_name = next_fn.arg_names[arg]
            inner_type = next_fn.type_env[inner_name]
            new_name = names.refresh(inner_name)
            fused_formals.append(new_name)
            fused_type_env[new_name] = inner_type
            fused_input_types.append(inner_type)
            inline_args.append(Var(new_name, inner_type))
        else:
            # Wrap arg in a 1-tuple so a tuple-valued arg cannot break the
            # formatting of the message itself.
            assert arg.__class__ is Const, \
                "Only scalars can be spliced as literals into a fused fn: %s" \
                % (arg,)
            inline_args.append(arg)

    next_return_var = inline.do_inline(next_fn, inline_args,
                                       fused_type_env, fused_body)
    fused_body.append(Return(next_return_var))

    # we're not renaming variables that originate from the predecessor
    # function
    new_fn = TypedFn(name=fused_name,
                     arg_names=fused_formals,
                     body=fused_body,
                     input_types=tuple(fused_input_types),
                     return_type=next_fn.return_type,
                     type_env=fused_type_env)
    combined_args = prev_fixed_args + next_fixed_args
    return new_fn, combined_args