def _create_wrapper(self, n_pos, static_pairs, dynamic_keywords):
    """Build an untyped syntax.Fn that forwards its inputs to ``self.f``.

    The generated function takes ``n_pos`` anonymous positional inputs
    followed by one positional input per name in ``dynamic_keywords``.
    Entries of ``static_pairs`` whose value is not None are baked into the
    call as constant keyword arguments; None values are skipped.

    :param n_pos: number of leading positional inputs to generate.
    :param static_pairs: iterable of (keyword name, value) pairs; a value is
        a syntax.Const, a python constant, or None.
    :param dynamic_keywords: keyword names that become formal parameters.
    :returns: syntax.Fn whose body returns ``self.f(*positional, **keywords)``.
    """
    formals = FormalArgs()

    # Anonymous positional inputs input_0, input_1, ...
    positional = []
    for idx in xrange(n_pos):
        fresh = names.fresh("input_%d" % idx)
        formals.add_positional(fresh)
        positional.append(syntax.Var(fresh))

    # Keyword inputs that stay parameters of the wrapper.
    keywords = {}
    for kw in dynamic_keywords:
        fresh = names.fresh(kw)
        formals.add_positional(fresh, kw)
        keywords[kw] = syntax.Var(fresh)

    # Keyword inputs fixed at wrapper-creation time.
    for static_name, value in static_pairs:
        if isinstance(value, syntax.Expr):
            assert isinstance(value, syntax.Const)
            keywords[static_name] = value
        elif value is not None:
            assert syntax_helpers.is_python_constant(value), \
                "Unexpected type for static/staged value: %s : %s" % \
                (value, type(value))
            keywords[static_name] = syntax_helpers.const(value)

    body = [syntax.Return(self.f(*positional, **keywords))]
    base_name = "%s_wrapper_%d_%d" % (self.name, n_pos, len(dynamic_keywords))
    return syntax.Fn(name = names.fresh(base_name), args = formals, body = body)
def gen_tiled_wrapper(adverb_class, fn, arg_types, nonlocal_types):
    """Build, type-specialize and (optionally) tile a wrapper around ``fn``.

    The untyped wrapper takes the nonlocal values followed by the data
    arguments, and calls the generic adverb wrapper for ``adverb_class`` with
    ``fn`` closed over the nonlocals. It is specialized on
    ``arg_types.prepend_positional(nonlocal_types)``, run through
    ``high_level_optimizations``, and tiled when ``config.opt_tile`` is set.

    Results are memoized in ``_lowered_wrapper_cache`` keyed on
    (adverb_class, fn.name, tuple(arg_types), config.opt_tile).

    :returns: the typed (possibly tiled) wrapper function.
    """
    key = (adverb_class, fn.name, tuple(arg_types), config.opt_tile)
    if key in _lowered_wrapper_cache:
        return _lowered_wrapper_cache[key]

    # Generate a wrapper for the payload function, and then type specialize it
    # as well as tile it. Tiling needs to happen here, as we don't want to
    # tile the outer parallel wrapper function.
    nested_wrapper = \
        adverb_wrapper.untyped_wrapper(adverb_class,
                                       map_fn_name = 'fn',
                                       data_names = fn.args.positional,
                                       varargs_name = None,
                                       axis = 0)
    nonlocal_args = ActualArgs([syntax.Var(names.fresh("arg"))
                                for _ in nonlocal_types])
    untyped_args = [syntax.Var(names.fresh("arg")) for _ in arg_types]
    fn_args_obj = FormalArgs()
    for arg in nonlocal_args:
        fn_args_obj.add_positional(arg.name)
    for arg in untyped_args:
        fn_args_obj.add_positional(arg.name)
    nested_closure = syntax.Closure(nested_wrapper, [])
    call = syntax.Call(nested_closure,
                       [syntax.Closure(fn, nonlocal_args)] + untyped_args)
    body = [syntax.Return(call)]
    fn_name = names.fresh(adverb_class.node_type() + fn.name + "_wrapper")
    untyped_wrapper = syntax.Fn(fn_name, fn_args_obj, body)
    all_types = arg_types.prepend_positional(nonlocal_types)
    typed = type_inference.specialize(untyped_wrapper, all_types)
    typed = high_level_optimizations(typed)
    if config.opt_tile:
        typed = tiling(typed)
    # The cache was previously consulted but never populated.
    _lowered_wrapper_cache[key] = typed
    return typed
def __getitem__(self, key):
    """Return the local SSA name bound to the source-level name ``key``.

    Resolution order:
      1. the scopes of this environment, searched from the last to the first;
      2. the enclosing environment ``outer_env``: the outer lookup is only
         performed for its check/side effects, and a fresh local name is
         bound in ``top_scope()`` and recorded in ``original_outer_names`` /
         ``localized_outer_names``;
      3. closure cells, then globals: these become python refs, reusing the
         local name of an equal ref that is already registered.

    Raises NameNotFound when nothing resolves ``key``.
    """
    for scope in reversed(self.scopes):
        found = scope.get(key)
        if found:
            return found

    if self.outer_env:
        # The outer binding's name is discarded; the lookup just has to
        # succeed and lets the outer scope register any python refs it needs.
        self.outer_env[key]
        renamed = names.fresh(key)
        self.top_scope()[key] = renamed
        self.original_outer_names.append(key)
        self.localized_outer_names.append(renamed)
        return renamed

    if self.closure_cell_dict and key in self.closure_cell_dict:
        ref = ClosureCellRef(self.closure_cell_dict[key], key)
    elif self.globals_dict and key in self.globals_dict:
        ref = GlobalRef(self.globals_dict, key)
    else:
        raise NameNotFound(key)

    existing = next((local for (local, other) in self.python_refs.iteritems()
                     if ref == other), None)
    if existing is not None:
        return existing

    renamed = names.fresh(key)
    self.python_refs[renamed] = ref
    return renamed
def mk_simple_fn(mk_body, input_name = "x", fn_name = "cast"):
    """Build an untyped function of a single positional argument.

    :param mk_body: callable receiving the syntax.Var of the argument and
        returning the function body.
    :param input_name: visible name of the parameter.
    :param fn_name: prefix for the function's unique name.
    :returns: syntax.Fn with fresh names for both function and parameter.
    """
    # The parameter's fresh name is drawn before the function's.
    arg_id = names.fresh(input_name)
    fn_id = names.fresh(fn_name)
    params = FormalArgs()
    params.add_positional(arg_id, input_name)
    return syntax.Fn(fn_id, params, mk_body(syntax.Var(arg_id)))
def value_to_syntax(v):
    """Convert a static python value into untyped syntax.

    * python constants -> syntax constants;
    * numpy dtypes -> a one-argument function casting its input to the
      corresponding core type;
    * anything else must be a function value and is translated.
    """
    if syntax_helpers.is_python_constant(v):
        return syntax_helpers.const(v)

    if not isinstance(v, np.dtype):
        assert is_function_value(v), "Can't make value %s : %s into static syntax" % (v, type(v))
        return translate_function_value(v)

    arg_id = names.fresh("x")
    cast_fn_id = names.fresh("cast")
    params = FormalArgs()
    params.add_positional(arg_id, "x")
    target_t = core_types.from_dtype(v)
    cast_expr = syntax.Cast(syntax.Var(arg_id), type = target_t)
    return syntax.Fn(cast_fn_id, params, [syntax.Return(cast_expr)])
def flatten_Reduce(self, map_fn, combine, x, init):
    """Turn an axis-less reduction into an IndexReduce.

    Builds an untyped index function ``idx_map(c0, ..., ck, x, i)`` that
    calls the underlying map function on its closure values and ``x[i]``.
    It closes that function over the map function's closure values plus
    ``x``, specializes the IndexReduce, and returns the resulting node.
    ``init`` is cast to the combine function's return type unless it is none.
    """
    shape = self.shape(x)
    n_indices = self.rank(x)

    # Formals of the index function: closure values first, then data and index.
    closure_vals = self.closure_elts(map_fn)
    formals = FormalArgs()
    closure_params = []
    for pos in xrange(len(closure_vals)):
        label = "c%d" % pos
        local = names.fresh(label)
        formals.add_positional(local, label)
        closure_params.append(Var(local))

    data_name = names.fresh("x")
    index_name = names.fresh("i")
    formals.add_positional(data_name, "x")
    formals.add_positional(index_name, "i")

    element = syntax.Index(Var(data_name), Var(index_name))
    map_call = syntax.Call(self.get_fn(map_fn), tuple(closure_params) + (element,))
    idx_fn = syntax.Fn(name = names.fresh("idx_map"),
                       args = formals,
                       body = [syntax.Return(map_call)])

    captured = tuple(closure_vals) + (x,)
    captured_t = closure_type.make_closure_type(idx_fn, get_types(captured))
    idx_closure = Closure(idx_fn, args = captured, type = captured_t)

    result_type, typed_fn, typed_combine = \
        specialize_IndexReduce(idx_closure, combine, n_indices, init)
    if not self.is_none(init):
        init = self.cast(init, typed_combine.return_type)
    return syntax.IndexReduce(shape = shape,
                              fn = make_typed_closure(idx_closure, typed_fn),
                              combine = make_typed_closure(combine, typed_combine),
                              init = init,
                              type = result_type)
def nested_maps(inner_fn, depth, arg_names):
    """Wrap ``inner_fn`` in ``depth`` levels of Map over axis 0.

    Each level is an untyped function with one positional parameter per name
    in ``arg_names`` (renamed via ``names.refresh``), returning
    ``Map(next_level, args, axis=0)``. Levels are memoized in
    ``_nested_map_cache`` keyed on (inner_fn.name, depth, tuple(arg_names)).
    A non-positive ``depth`` returns ``inner_fn`` unchanged.
    """
    if depth <= 0:
        return inner_fn

    cache_key = (inner_fn.name, depth, tuple(arg_names))
    if cache_key in _nested_map_cache:
        return _nested_map_cache[cache_key]

    formals = args.FormalArgs()
    level_vars = []
    for original in arg_names:
        renamed = names.refresh(original)
        formals.add_positional(renamed)
        level_vars.append(syntax.Var(renamed))

    level_name = names.fresh(inner_fn.name + "_broadcast%d" % depth)
    # Build the next-lower level only after this level has been named.
    below = nested_maps(inner_fn, depth - 1, arg_names)
    level_map = syntax.Map(fn = below, axis = syntax_helpers.zero_i64, args = level_vars)
    level_fn = syntax.Fn(name = level_name, args = formals, body = [syntax.Return(level_map)])
    _nested_map_cache[cache_key] = level_fn
    return level_fn
def translate_function_ast(function_def_ast, globals_dict = None, closure_vars = [], closure_cells = [], outer_env = None):
    """
    Translate a python FunctionDef AST into an untyped parakeet function.

    The body is preceded by the argument-unpacking assignments. Python refs
    collected by the translator (closure cells / globals), together with the
    names borrowed from an enclosing parakeet environment, are prepended to
    the arguments as nonlocals.

    :param closure_vars: names of the python closure variables.
    :param closure_cells: closure cells matching ``closure_vars`` one-to-one.
    :returns: syntax.Fn(name, args, body, refs, original_outer_names).
    """
    assert len(closure_vars) == len(closure_cells)
    cells_by_name = dict(zip(closure_vars, closure_cells))
    translator = AST_Translator(globals_dict, cells_by_name, outer_env)

    ssa_args, arg_assignments = translator.translate_args(function_def_ast.args)
    _, translated_body = translator.visit_block(function_def_ast.body)
    full_body = arg_assignments + translated_body
    ssa_fn_name = names.fresh(function_def_ast.name)

    env = translator.env
    # One pass over the refs keeps names and ref objects in the same order.
    ref_items = list(env.python_refs.iteritems())
    ref_names = [local for (local, _) in ref_items]
    refs = [ref for (_, ref) in ref_items]

    # Names from a surrounding parakeet scope can't be captured with a python
    # ref cell, so they are passed as leading nonlocal arguments as well.
    ssa_args.prepend_nonlocal_args(env.localized_outer_names + ref_names)
    return syntax.Fn(ssa_fn_name, ssa_args, full_body, refs, env.original_outer_names)
def get_index_fn(self, array_t, idx_t):
    """Return a memoized TypedFn computing ``array[idx]``.

    The function takes (array, idx) with input types (array_t, idx_t). A
    non-Int64 index is cast to Int64 before indexing. The element type is
    ``array_t.index_type(idx_t)``. Results are cached in
    ``self.index_function_cache`` keyed on (array_t, idx_t).
    """
    cache_key = (array_t, idx_t)
    if cache_key in self.index_function_cache:
        return self.index_function_cache[cache_key]

    arr_id = names.fresh("array")
    ix_id = names.fresh("idx")
    arr = Var(arr_id, type = array_t)
    ix = Var(ix_id, type = idx_t)
    index_expr = ix if idx_t is Int64 else Cast(value = ix, type = Int64)
    result_t = array_t.index_type(idx_t)

    new_fn = syntax.TypedFn(
        name = names.fresh("idx"),
        arg_names = (arr_id, ix_id),
        input_types = (array_t, idx_t),
        return_type = result_t,
        type_env = {arr_id: array_t, ix_id: idx_t},
        body = [Return(Index(arr, index_expr, type = result_t))])
    self.index_function_cache[cache_key] = new_fn
    return new_fn
def lookup(self, name):
    """Resolve the source-level ``name`` to untyped syntax.

    Resolution order:
      1. a binding in ``self.scopes`` -> Var of its SSA name;
      2. when this translator has a parent, bind a fresh local name and
         record it in ``original_outer_names`` / ``localized_outer_names``
         (the parent itself is not consulted here);
      3. a python closure cell -> reuse the local name of an equal registered
         python ref, or bind a fresh local name, record it as an outer name
         and register the ref under it;
      4. a global -> ``value_to_syntax`` for static values, otherwise
         ``ExternalValue``.

    Raises NameNotFound when none of these apply.
    """
    def bind_outer(local_name):
        # Bind locally and remember the outer name it stands in for.
        self.scopes[name] = local_name
        self.original_outer_names.append(name)
        self.localized_outer_names.append(local_name)

    if name in self.scopes:
        return Var(self.scopes[name])

    if self.parent:
        local_name = names.fresh(name)
        bind_outer(local_name)
        return Var(local_name)

    if self.closure_cell_dict and name in self.closure_cell_dict:
        ref = ClosureCellRef(self.closure_cell_dict[name], name)
        for existing, other in self.python_refs.iteritems():
            if ref == other:
                return Var(existing)
        local_name = names.fresh(name)
        bind_outer(local_name)
        self.python_refs[local_name] = ref
        return Var(local_name)

    if self.is_global(name):
        value = self.lookup_global(name)
        return value_to_syntax(value) if is_static_value(value) else ExternalValue(value)

    raise NameNotFound(name)
def visit_loop_body(self, body, *exprs):
    """Translate a loop body and compute its merge information.

    ``exprs`` are visited first, then ``body``. For every name in the body's
    resulting scope that is also visible in ``self.scopes``, a fresh
    "<name>_loop" name is created and bound in the current scope.

    :returns: (body, merge, exprs). ``merge`` maps each new loop name to
        (Var(name before the loop), Var(name after the body)). ``body`` and
        ``exprs`` have every pre-loop name replaced by its loop name.
    """
    enclosing_scope = self.current_scope()
    visited_exprs = [self.visit(e) for e in exprs]
    scope_after, new_body = self.visit_block(body)

    merge = {}
    renaming = {}
    for source_name, name_after in scope_after.iteritems():
        if source_name not in self.scopes:
            continue
        name_before = self.scopes[source_name]
        loop_name = names.fresh(source_name + "_loop")
        merge[loop_name] = (Var(name_before), Var(name_after))
        renaming[name_before] = loop_name
        enclosing_scope[source_name] = loop_name

    renamed_exprs = [subst_expr(e, renaming) for e in visited_exprs]
    return subst_stmt_list(new_body, renaming), merge, renamed_exprs
def translate_function_ast(name, args, body, globals_dict = None, closure_vars = [], closure_cells = [], parent = None, filename = None):
    """
    Helper to launch translation of a python function's AST, and then construct
    an untyped parakeet function from the arguments, refs, and translated body.

    :param name: source name of the function.
    :param args: the AST arguments node.
    :param body: list of AST statements.
    :param globals_dict: globals of a top-level python function; when given,
        ``parent`` must be None.
    :param closure_vars: names of python closure variables.
    :param closure_cells: cells matching ``closure_vars`` one-to-one.
    :param parent: enclosing translator for a function nested in parakeet code.
    :param filename: source file; defaults to ``parent.filename`` if available.
    :returns: a syntax.Fn, or a syntax.Closure over the parent's variables
        when a nested function references names from its parent.
    """
    assert len(closure_vars) == len(closure_cells)
    closure_cell_dict = dict(zip(closure_vars, closure_cells))
    if filename is None and parent is not None:
        filename = parent.filename
    translator = AST_Translator(globals_dict, closure_cell_dict, parent,
                                function_name = name,
                                filename = filename)
    ssa_args, assignments = translator.translate_args(args)
    _, body = translator.visit_block(body)
    body = assignments + body
    ssa_fn_name = names.fresh(name)

    # if function was nested in parakeet, it can have references to its
    # surrounding parakeet scope, which can't be captured with a python ref cell
    original_outer_names = translator.original_outer_names
    localized_outer_names = translator.localized_outer_names
    python_refs = translator.python_refs
    ssa_args.prepend_nonlocal_args(localized_outer_names)

    if globals_dict:
        assert parent is None
        assert len(original_outer_names) == len(python_refs)
        # The refs must line up positionally with the nonlocal args prepended
        # above. python_refs.values() follows dict order, which is arbitrary
        # before Python 3.7, so order the refs by localized_outer_names.
        ordered_refs = [python_refs[local_name]
                        for local_name in localized_outer_names]
        return syntax.Fn(ssa_fn_name, ssa_args, body, ordered_refs, [])

    assert parent
    fn = syntax.Fn(ssa_fn_name, ssa_args, body, [], original_outer_names)
    if not original_outer_names:
        return fn
    # Close over the parent's SSA variables for the borrowed names.
    outer_ssa_vars = [parent.lookup(x) for x in original_outer_names]
    return syntax.Closure(fn, outer_ssa_vars)
def prim_wrapper(p):
    """Return the memoized untyped function that applies primitive ``p``.

    The wrapper takes ``p.nin`` fresh positional parameters and returns
    ``PrimCall(p, params)``. It is cached in ``_untyped_prim_wrappers``.
    """
    if p in _untyped_prim_wrappers:
        return _untyped_prim_wrappers[p]

    wrapper_name = names.fresh(p.name)
    formals = FormalArgs()
    # Materialize once: the names are used both as formals and as call args.
    param_names = list(names.fresh_list(p.nin))
    for param in param_names:
        formals.add_positional(param)
    call = syntax.PrimCall(p, [syntax.Var(param) for param in param_names])
    wrapper = syntax.Fn(wrapper_name, formals, [syntax.Return(call)], [])
    _untyped_prim_wrappers[p] = wrapper
    return wrapper
def transform_AllPairs(self, expr):
    """
    Transform each AllPairs(f, X, Y) operation into a pair of nested maps:

        def g(x_elt):
            def h(y_elt):
                return f(x_elt, y_elt)
            return map(h, Y)
        return map(g, X)

    Both maps use ``expr.axis``. If ``f`` is a closure, its captured values
    are threaded through the outer map's closure and forwarded to ``f``.
    AllPairs nodes with an explicit output location (``expr.out`` is not
    None) are returned unchanged.
    """
    if expr.out is not None:
        return expr

    # if the adverb function is a closure, give me all the values it
    # closes over (as a list, so it can be concatenated with a list below)
    closure_elts = list(self.closure_elts(expr.fn))
    n_closure_elts = len(closure_elts)

    # strip off the closure wrappings and give me the underlying TypedFn
    fn = self.get_fn(expr.fn)

    # the two array arguments to this AllPairs adverb
    x, y_outer = expr.args

    # Element of X handed to the outer map's function; its type is f's
    # first non-closure input type.
    x_elt_name = names.fresh('x_elt')
    x_elt_t = fn.input_types[n_closure_elts]
    x_elt_var = Var(x_elt_name, type = x_elt_t)

    # The whole of Y, captured by the outer map's function.
    y_inner_name = names.fresh('y')
    y_inner = Var(y_inner_name, type = y_outer.type)

    # Local copies of f's closure values inside the outer map's function.
    inner_closure_args = []
    for (i, elt) in enumerate(closure_elts):
        t = elt.type
        if elt.__class__ is Var:
            name = names.refresh(elt.name)
        else:
            name = names.fresh('closure_arg%d' % i)
        inner_closure_args.append(Var(name, type = t))

    # Outer function signature: (closure values..., Y, x_elt).
    inner_arg_names = []
    inner_input_types = []
    type_env = {}
    for var in inner_closure_args + [y_inner, x_elt_var]:
        type_env[var.name] = var.type
        inner_arg_names.append(var.name)
        inner_input_types.append(var.type)

    # f partially applied to its closure values and x_elt, mapped over Y.
    inner_closure_rhs = self.closure(fn, inner_closure_args + [x_elt_var])
    inner_result_t = array_type.lower_rank(expr.type, 1)
    inner_fn = TypedFn(
        name = names.fresh('allpairs_into_maps_wrapper'),
        arg_names = tuple(inner_arg_names),
        input_types = tuple(inner_input_types),
        return_type = inner_result_t,
        type_env = type_env,
        body = [
            Return(Map(inner_closure_rhs, args=[y_inner],
                       axis = expr.axis,
                       type = inner_result_t))
        ]
    )
    closure = self.closure(inner_fn, closure_elts + [y_outer])
    return Map(closure, [x], axis = expr.axis, type = expr.type)
def untyped_wrapper(adverb_class, map_fn_name = None, combine_fn_name = None, emit_fn_name = None, data_names = [], varargs_name = 'xs', axis = 0):
    """
    Given:
      - an adverb class (i.e. Map, Reduce, Scan, or AllPairs)
      - function parameter names (any of which may be None to omit it)
      - an optional list of positional data parameter names
      - an optional name for a varargs parameter
      - the axis along which the adverb operates

    return an untyped function that applies the adverb to its data arguments
    plus the unpacked varargs tuple. Parameters appear in the order fn,
    combine, emit, data args, then ``init`` (default None) when the adverb
    has an ``init`` member. Results are cached in ``_adverb_wrapper_cache``.
    """
    axis = syntax_helpers.wrap_if_constant(axis)
    cache_key = (adverb_class.__name__, map_fn_name, combine_fn_name,
                 emit_fn_name, axis, tuple(data_names), varargs_name)
    if cache_key in _adverb_wrapper_cache:
        return _adverb_wrapper_cache[cache_key]

    formals = FormalArgs()

    def param_var(visible):
        # None means the parameter is absent.
        if visible is None:
            return None
        local = names.refresh(visible)
        formals.add_positional(local, visible)
        return syntax.Var(local)

    fn_var = param_var(map_fn_name)
    combine_var = param_var(combine_fn_name)
    emit_var = param_var(emit_fn_name)
    data_vars = [param_var(data_name) for data_name in data_names]

    starargs_var = None
    if varargs_name:
        star_name = names.refresh(varargs_name)
        formals.starargs = star_name
        starargs_var = syntax.Var(star_name)

    params = {'axis': axis,
              'args': ActualArgs(data_vars, starargs = starargs_var)}
    if 'init' in adverb_class.members():
        init_local = names.fresh('init')
        formals.add_positional(init_local, 'init')
        formals.defaults[init_local] = None
        params['init'] = syntax.Var(init_local)

    for field, value in (('fn', fn_var), ('combine', combine_var), ('emit', emit_var)):
        if value:
            params[field] = value

    wrapper_body = [syntax.Return(adverb_class(**params))]
    wrapper = syntax.Fn(names.fresh(adverb_class.node_type() + "_wrapper"),
                        formals, wrapper_body)
    _adverb_wrapper_cache[cache_key] = wrapper
    return wrapper
def transform_Map(self, expr):
    """Tile a Map node, returning the equivalent TiledMap.

    Pushes the map onto the adverb stack and propagates expansion depths to
    the mapped function's formals:
      * closure formals copy their actuals' expansions;
      * data formals also get the current depth appended.
    Then one of two paths runs:
      * the function contains further adverbs: its body is transformed
        recursively, with each formal's rank raised by its expansion count;
      * otherwise: tile sizes are estimated and the nested unpacking
        functions are built with ``gen_unpack_tree``, a second time for
        register tiling when ``config.opt_reg_tile`` is set.
    The Map's argument and closure types are updated in place to the new
    function's input types.
    """
    self.num_tiles += 1
    depth = len(self.adverbs_visited)

    closure = expr.fn
    closure_args = []
    fn = closure
    if isinstance(fn, syntax.Closure):
        closure_args = closure.args
        fn = closure.fn
    n_closure = len(closure_args)

    axes = [self.get_num_expansions_at_depth(arg.name, depth) + expr.axis
            for arg in expr.args]
    self.push_exp(syntax.Map, AdverbArgs(expr.fn, expr.args, expr.axis, axes))

    for fn_arg, adverb_arg in zip(fn.arg_names[:n_closure], closure_args):
        name = self.get_closure_arg(adverb_arg).name
        new_expansions = copy.deepcopy(self.get_expansions(name))
        self.expansions[fn_arg] = new_expansions
    for fn_arg, adverb_arg in zip(fn.arg_names[n_closure:], expr.args):
        new_expansions = copy.deepcopy(self.get_expansions(adverb_arg.name))
        new_expansions.append(depth)
        self.expansions[fn_arg] = new_expansions

    depths = self.get_depths_list(fn.arg_names)

    # Fixed: this previously called visit_fn/has_adverbs on an undefined name
    # ``find_syntax`` instead of the FindAdverbs instance created here.
    find_adverbs = FindAdverbs()
    find_adverbs.visit_fn(fn)
    if find_adverbs.has_adverbs:
        arg_names = list(fn.arg_names)
        input_types = []
        self.push_type_env(fn.type_env)
        for arg, t in zip(arg_names, fn.input_types):
            new_type = array_type.increase_rank(t, len(self.get_expansions(arg)))
            input_types.append(new_type)
            self.type_env[arg] = new_type
        exps = self.get_depths_list(fn.arg_names)
        # Return rank grows by the position of the first depth >= current one.
        rank_inc = 0
        for i, exp in enumerate(exps):
            if exp >= depth:
                rank_inc = i
                break
        return_t = array_type.increase_rank(expr.type, rank_inc)
        new_fn = syntax.TypedFn(name = names.fresh("expanded_map_fn"),
                                arg_names = tuple(arg_names),
                                body = self.transform_block(fn.body),
                                input_types = input_types,
                                return_type = return_t,
                                type_env = self.pop_type_env())
        new_fn.has_tiles = True
    else:
        # Estimate the tile sizes
        self.estimate_tile_sizes(fn.arg_names, depths)

        new_fn = self.gen_unpack_tree(self.adverbs_visited, depths,
                                      fn.arg_names, fn.body, fn.type_env)

        if config.opt_reg_tile:
            adverb_tree = [get_tiled_version(adv)
                           for adv in self.adverbs_visited]
            new_fn = self.gen_unpack_tree(adverb_tree, depths, fn.arg_names,
                                          new_fn, fn.type_env,
                                          reg_tiling = True)

    # Data arguments take the new function's non-closure input types.
    for arg, t in zip(expr.args, new_fn.input_types[n_closure:]):
        arg.type = t
    return_t = new_fn.return_type
    if isinstance(closure, syntax.Closure):
        for c_arg, t in zip(closure.args, new_fn.input_types):
            c_arg.type = t
        closure_arg_types = [arg.type for arg in closure.args]
        closure.fn = new_fn
        closure.type = closure_type.make_closure_type(new_fn, closure_arg_types)
        new_fn = closure
    self.pop_exp()
    return syntax.TiledMap(fn = new_fn,
                           args = expr.args,
                           axes = axes,
                           type = return_t)
def gen_unpack_fn(depth_idx, arg_order):
    """Recursively build the function for one level of the unpacking tree.

    Nested helper: ``depths``, ``reg_tiling``, ``inner``, ``type_env``,
    ``order_args``, ``v_names``, ``adverb_tree`` and ``self`` are free
    variables from the enclosing scope (not visible here).

    Base case (``depth_idx >= len(depths)``):
      * with ``reg_tiling``, ``inner`` itself is returned;
      * otherwise a TypedFn named "inner_block" is returned. Its body is the
        statement list ``inner``. Its type env is ``type_env`` plus the
        types of names bound by Assign statements. Its return type comes
        from the last Return statement seen, defaulting to Int32.

    Recursive case: builds a TypedFn named "reg_tile" (register tiling) or
    "intermediate_depth". Its body returns a new adverb of class
    ``adverb_tree[depth_idx]`` applied to a closure over the function
    generated for ``depth_idx + 1``.

    :param depth_idx: index into ``depths`` / ``adverb_tree`` /
        ``self.adverb_args`` for the level being generated.
    :param arg_order: argument names of the generated function, in order.
    :returns: the generated function (``inner`` in the reg-tiling base case).
    """
    if depth_idx >= len(depths):
        if reg_tiling:
            return inner
        else:
            # For each stmt in body, add its lhs free vars to the type env
            inner_type_env = copy.copy(type_env)
            return_t = Int32 # Dummy type
            for s in inner:
                if isinstance(s, syntax.Assign):
                    lhs_names = free_vars(s.lhs)
                    lhs_types = [type_env[name] for name in lhs_names]
                    for name, t in zip(lhs_names, lhs_types):
                        inner_type_env[name] = t
                elif isinstance(s, syntax.Return):
                    # NOTE(review): a str has no ``.name`` attribute, so this
                    # branch looks like it was meant to test for syntax.Var.
                    # As written, expression values presumably always take the
                    # ``.type`` branch below. TODO confirm.
                    if isinstance(s.value, str):
                        return_t = type_env[s.value.name]
                    else:
                        return_t = s.value.type
            # The innermost function always uses all the variables
            input_types = [type_env[arg] for arg in arg_order]
            # NOTE(review): arg_names come from the enclosing ``v_names`` while
            # input_types follow ``arg_order``. Presumably both list the same
            # names in the same order. TODO confirm.
            fn = syntax.TypedFn(name = names.fresh("inner_block"),
                                arg_names = tuple([name for name in v_names]),
                                body = inner,
                                input_types = input_types,
                                return_type = return_t,
                                type_env = inner_type_env)
            return fn
    else:
        # Get the current depth
        depth = depths[depth_idx]
        # Order the arguments for the current depth, i.e. for the nested fn
        cur_arg_names, fixed_arg_names = order_args(depth)
        # The nested fn takes the fixed (closure) args first, then the ones
        # the adverb at this depth iterates over.
        nested_arg_names = fixed_arg_names + cur_arg_names

        # Make a type env for this function based on the number of expansions
        # left for each arg
        adv_args = self.adverb_args[depth_idx]
        if reg_tiling:
            new_adverb = adverb_tree[depth_idx](fn = adv_args.fn,
                                                args = adv_args.args,
                                                axes = adv_args.axes,
                                                fixed_tile_size = True)
        else:
            new_adverb = adverb_tree[depth_idx](fn = adv_args.fn,
                                                args = adv_args.args,
                                                axis = adv_args.axis)

        # Increase the rank of each arg by the number of nested expansions
        # (i.e. the expansions of that arg that occur deeper in the nesting)
        new_type_env = {}
        if reg_tiling:
            # Register tiling reuses the types of the previously built ``inner`` fn.
            for arg in nested_arg_names:
                new_type_env[arg] = inner.type_env[arg]
        else:
            for arg in nested_arg_names:
                exps = self.get_expansions(arg)
                rank_increase = 0
                for i, e in enumerate(exps):
                    # From the first expansion at depth >= the current one,
                    # every remaining expansion adds one dimension.
                    if e >= depth:
                        rank_increase = len(exps) - i
                        break
                new_type_env[arg] = \
                    array_type.increase_rank(type_env[arg], rank_increase)

        cur_arg_types = [new_type_env[arg] for arg in cur_arg_names]
        fixed_arg_types = [new_type_env[arg] for arg in fixed_arg_names]

        # Generate the nested function with the proper arg order and wrap it
        # in a closure
        nested_fn = gen_unpack_fn(depth_idx+1, nested_arg_names)
        nested_args = [syntax.Var(name, type = t)
                       for name, t in zip(cur_arg_names, cur_arg_types)]
        nested_fixed_args = \
            [syntax.Var(name, type = t)
             for name, t in zip(fixed_arg_names, fixed_arg_types)]
        nested_closure = self.closure(nested_fn, nested_fixed_args)

        # Make an adverb that wraps the nested fn
        new_adverb.fn = nested_closure
        new_adverb.args = nested_args
        return_t = nested_fn.return_type
        if isinstance(new_adverb, syntax.Reduce):
            if reg_tiling:
                # Unpack the combine fn over every depth except this one.
                ds = copy.copy(depths)
                ds.remove(depth)
                new_adverb.combine = self.unpack_combine(adv_args.combine, ds)
            else:
                new_adverb.combine = adv_args.combine
            new_adverb.init = adv_args.init
        elif not reg_tiling:
            # Non-reduce adverbs add one dimension to the nested result.
            return_t = array_type.increase_rank(nested_fn.return_type, 1)
        new_adverb.type = return_t

        # Add the adverb to the body of the current fn and return the fn
        name = names.fresh("reg_tile" if reg_tiling else "intermediate_depth")
        arg_types = [new_type_env[arg] for arg in arg_order]
        fn = syntax.TypedFn(name = name,
                            arg_names = arg_order,
                            body = [syntax.Return(new_adverb)],
                            input_types = arg_types,
                            return_type = return_t,
                            type_env = new_type_env)
        return fn
def fresh_var(self, t, prefix = "temp"):
    """Create a typed Var with a unique name derived from ``prefix``.

    The new name is registered in ``self.type_env`` with type ``t``.
    """
    assert prefix is not None
    assert t is not None, "Type required for new variable %s" % prefix
    new_var = Var(names.fresh(prefix), type = t)
    self.type_env[new_var.name] = t
    return new_var
def gen_par_work_function(adverb_class, f, nonlocals, nonlocal_types, args_t, arg_types, dont_slice_position = -1):
    """Build and memoize the lowered parallel work function for an adverb.

    The returned function has parameters (start, stop, args, tile_sizes). It:
      * reads struct fields "arg0", "arg1", ... of ``args`` (nonlocals first,
        then data arguments);
      * slices each array data argument to [start:stop] along its first
        dimension, except the one at ``dont_slice_position``;
      * appends ``tile_sizes`` when ``config.opt_tile`` is set;
      * calls the payload built by ``gen_tiled_wrapper``;
      * stores the result into ``args.output[start:stop]`` and returns None.
    The lowered function also carries num_tiles and the payload's tile-size
    estimates. Results are cached in ``_par_wrapper_cache``. The
    ``nonlocals`` argument is not used here.
    """
    cache_key = (adverb_class, f.name, tuple(arg_types), config.opt_tile)
    if cache_key in _par_wrapper_cache:
        return _par_wrapper_cache[cache_key]

    payload = gen_tiled_wrapper(adverb_class, f, arg_types, nonlocal_types)
    num_tiles = payload.num_tiles

    # Formal parameters of the parallel wrapper.
    start_var = syntax.Var(names.fresh("start"), type = Int64)
    stop_var = syntax.Var(names.fresh("stop"), type = Int64)
    args_var = syntax.Var(names.fresh("args"), type = args_t)
    tile_sizes_t = tuple_type.make_tuple_type([Int64] * num_tiles)
    tile_sizes_var = syntax.Var(names.fresh("tile_sizes"), type = tile_sizes_t)
    inputs = [start_var, stop_var, args_var, tile_sizes_var]

    # [start:stop] along the leading dimension.
    chunk_t = array_type.make_slice_type(Int64, Int64, Int64)
    chunk = syntax.Slice(start_var, stop_var, syntax_helpers.one_i64, type = chunk_t)

    def slice_arg(arg, t):
        # chunk on dimension 0, full slices on all remaining dimensions.
        indices = [chunk] + [syntax_helpers.slice_none] * (arg.type.rank - 1)
        idx_tuple_t = tuple_type.make_tuple_type(syntax_helpers.get_types(indices))
        idx_tuple = syntax.Tuple(indices, idx_tuple_t)
        return syntax.Index(arg, idx_tuple, type = t.index_type(idx_tuple_t))

    # Struct fields are numbered consecutively: nonlocals, then data args.
    unpacked = [syntax.Attribute(args_var, ("arg%d" % pos), type = t)
                for pos, t in enumerate(nonlocal_types)]
    for pos, t in enumerate(arg_types, len(unpacked)):
        field = syntax.Attribute(args_var, ("arg%d" % pos), type = t)
        # TODO: Handle axis.
        sliceable = isinstance(t, array_type.ArrayT) and pos != dont_slice_position
        unpacked.append(slice_arg(field, t) if sliceable else field)

    # If tiling, pass in the tile params array.
    if config.opt_tile:
        unpacked.append(tile_sizes_var)

    # Call the payload through an empty typed closure.
    payload_closure = syntax.Closure(payload, [], type = closure_type.make_closure_type(payload, []))
    return_t = payload.return_type
    call = syntax.Call(payload_closure, unpacked, type = return_t)

    out_name = names.fresh("output")
    out_field = syntax.Attribute(args_var, "output", type = return_t)
    out_var = syntax.Var(out_name, type = out_field.type)
    out_slice = slice_arg(out_var, return_t)
    body = [syntax.Assign(out_var, out_field),
            syntax.Assign(out_slice, call),
            syntax.Return(syntax_helpers.none)]

    # NOTE(review): the output name is recorded with the sliced type rather
    # than out_var's own type. Presumably these coincide; confirm.
    type_env = {out_name: out_slice.type}
    type_env.update((v.name, v.type) for v in inputs)

    wrapper = syntax.TypedFn(
        name = names.fresh(adverb_class.node_type() + payload.name + "_par"),
        arg_names = [v.name for v in inputs],
        input_types = syntax_helpers.get_types(inputs),
        body = body,
        return_type = core_types.NoneType,
        type_env = type_env)
    lowered = lowering(wrapper)
    lowered.num_tiles = num_tiles
    lowered.dl_tile_estimates = payload.dl_tile_estimates
    lowered.ml_tile_estimates = payload.ml_tile_estimates
    _par_wrapper_cache[cache_key] = lowered
    return lowered
def replace_return_with_var(body, type_env, return_type):
    """Route ``body``'s returned value into a fresh "result" variable.

    A fresh name of type ``return_type`` is added to ``type_env`` (mutated
    in place). The result is ``replace_returns(body, result_var)``,
    presumably with each Return rewritten into an assignment to that
    variable.

    :returns: (result_var, new_body)
    """
    result = Var(names.fresh("result"), type = return_type)
    type_env[result.name] = return_type
    return result, replace_returns(body, result)
def fuse(prev_fn, prev_fixed_args, next_fn, next_fixed_args, fusion_args):
    """
    Fuse two typed functions so that ``next_fn`` consumes ``prev_fn``'s result.

    Expects the prev_fn's returned value to be one or more of the arguments
    to next_fn. ``fusion_args`` describes next_fn's non-closure arguments in
    order:
      * None  -> replaced by the value returned from prev_fn;
      * int   -> next_fn's positional argument at that index, which becomes a
                 new formal of the fused function;
      * Const -> a scalar literal spliced directly into the fused body.

    Fused formals, in order: prev_fn's closure formals, next_fn's closure
    formals (renamed), prev_fn's direct formals, then one formal per integer
    entry of ``fusion_args``. Both functions are assumed to have a single
    return at the outermost scope.

    :returns: (fused TypedFn, prev_fixed_args + next_fixed_args). If next_fn
        is an identity function, (prev_fn, prev_fixed_args) is returned as-is.
    """
    if syntax_helpers.is_identity_fn(next_fn):
        assert len(next_fixed_args) == 0
        return prev_fn, prev_fixed_args

    fused_formals = []
    fused_input_types = []
    fused_type_env = prev_fn.type_env.copy()
    fused_name = names.fresh('fused')

    prev_closure_formals = prev_fn.arg_names[:len(prev_fixed_args)]
    for prev_closure_arg_name in prev_closure_formals:
        t = prev_fn.type_env[prev_closure_arg_name]
        fused_formals.append(prev_closure_arg_name)
        fused_input_types.append(t)

    # next_fn's closure formals are renamed to avoid clashing with prev_fn's
    # names. The renamed Vars are what must be inlined in their place.
    next_closure_formals = next_fn.arg_names[:len(next_fixed_args)]
    next_closure_vars = []
    for next_closure_arg_name in next_closure_formals:
        t = next_fn.type_env[next_closure_arg_name]
        new_name = names.refresh(next_closure_arg_name)
        fused_type_env[new_name] = t
        fused_formals.append(new_name)
        fused_input_types.append(t)
        next_closure_vars.append(Var(new_name, type = t))

    prev_direct_formals = prev_fn.arg_names[len(prev_fixed_args):]
    for arg_name in prev_direct_formals:
        t = prev_fn.type_env[arg_name]
        fused_formals.append(arg_name)
        fused_input_types.append(t)

    prev_return_var, fused_body = \
        inline.replace_return_with_var(prev_fn.body,
                                       fused_type_env,
                                       prev_fn.return_type)
    # for now we're restricting both functions to have a single return at the
    # outermost scope
    # Previously this was seeded with next_fn's *original* closure formal
    # names (plain strings), which are not formals of the fused function.
    inline_args = list(next_closure_vars)
    for arg in fusion_args:
        if arg is None:
            inline_args.append(prev_return_var)
        elif isinstance(arg, int):
            # positional arg which is not being fused out
            inner_name = next_fn.arg_names[arg]
            inner_type = next_fn.type_env[inner_name]
            new_name = names.refresh(inner_name)
            fused_formals.append(new_name)
            fused_type_env[new_name] = inner_type
            fused_input_types.append(inner_type)
            inline_args.append(Var(new_name, type = inner_type))
        else:
            assert arg.__class__ is Const, \
                "Only scalars can be spliced as literals into a fused fn: %s" % arg
            inline_args.append(arg)

    next_return_var = inline.do_inline(next_fn, inline_args,
                                       fused_type_env, fused_body)
    fused_body.append(Return(next_return_var))

    # we're not renaming variables that originate from the predecessor function
    new_fn = TypedFn(name = fused_name,
                     arg_names = fused_formals,
                     body = fused_body,
                     input_types = tuple(fused_input_types),
                     return_type = next_fn.return_type,
                     type_env = fused_type_env)

    combined_args = prev_fixed_args + next_fixed_args
    return new_fn, combined_args
def fresh_name(self, original_name):
    """Bind ``original_name`` to a new unique name in ``self.scopes``.

    :returns: the newly generated unique name.
    """
    unique = names.fresh(original_name)
    self.scopes[original_name] = unique
    return unique
def fresh(self, name):
    """Bind ``name`` to a new unique name in the last scope of ``self.scopes``.

    Presumably the last scope is the innermost one. Returns the new name.
    """
    unique = names.fresh(name)
    current = self.scopes[-1]
    current[name] = unique
    return unique