Пример #1
0
  def _create_wrapper(self, n_pos, static_pairs, dynamic_keywords):
    """
    Build an untyped wrapper Fn that forwards its inputs to `self.f`.

    The wrapper takes `n_pos` anonymous positional inputs followed by one
    input per name in `dynamic_keywords` (visible under that name). Every
    entry of `static_pairs` whose value is not None is baked into the call
    as a constant keyword argument.
    """
    formals = FormalArgs()

    # Anonymous positional inputs: input_0, input_1, ...
    positional_vars = []
    for idx in xrange(n_pos):
      fresh_input = names.fresh("input_%d" % idx)
      formals.add_positional(fresh_input)
      positional_vars.append(syntax.Var(fresh_input))

    # Keyword inputs whose values are supplied when the wrapper is called.
    kw_vars = {}
    for kw in dynamic_keywords:
      fresh_kw = names.fresh(kw)
      formals.add_positional(fresh_kw, kw)
      kw_vars[kw] = syntax.Var(fresh_kw)

    # Keyword inputs whose values are already known; None means "leave unset".
    for (static_name, static_value) in static_pairs:
      if isinstance(static_value, syntax.Expr):
        assert isinstance(static_value, syntax.Const)
        kw_vars[static_name] = static_value
      elif static_value is None:
        continue
      else:
        assert syntax_helpers.is_python_constant(static_value), \
            "Unexpected type for static/staged value: %s : %s" % \
            (static_value, type(static_value))
        kw_vars[static_name] = syntax_helpers.const(static_value)

    call_expr = self.f(*positional_vars, **kw_vars)
    base_name = "%s_wrapper_%d_%d" % (self.name, n_pos, len(dynamic_keywords))
    return syntax.Fn(name = names.fresh(base_name),
                     args = formals,
                     body = [syntax.Return(call_expr)])
Пример #2
0
def gen_tiled_wrapper(adverb_class, fn, arg_types, nonlocal_types):
  """
  Return a typed, optimized (and, when config.opt_tile, tiled) function that
  applies `adverb_class` to `fn`.

  The generated function takes one fresh formal per entry of
  `nonlocal_types` followed by one per entry of `arg_types`. The nonlocal
  formals are bound into a closure over `fn`, and that closure is passed to
  an untyped adverb wrapper together with the remaining formals.

  Results are memoized in `_lowered_wrapper_cache`.
  """
  # The nonlocal types change the wrapper's signature, so they are part of
  # the key alongside the data argument types.
  key = (adverb_class, fn.name, tuple(arg_types), tuple(nonlocal_types),
         config.opt_tile)
  if key in _lowered_wrapper_cache:
    return _lowered_wrapper_cache[key]

  # Generate a wrapper for the payload function, and then type specialize it
  # as well as tile it.  Tiling needs to happen here, as we don't want to
  # tile the outer parallel wrapper function.
  nested_wrapper = \
      adverb_wrapper.untyped_wrapper(adverb_class,
                                     map_fn_name = 'fn',
                                     data_names = fn.args.positional,
                                     varargs_name = None,
                                     axis = 0)
  nonlocal_args = ActualArgs([syntax.Var(names.fresh("arg"))
                              for _ in nonlocal_types])
  untyped_args = [syntax.Var(names.fresh("arg")) for _ in arg_types]
  fn_args_obj = FormalArgs()
  for arg in nonlocal_args:
    fn_args_obj.add_positional(arg.name)
  for arg in untyped_args:
    fn_args_obj.add_positional(arg.name)
  nested_closure = syntax.Closure(nested_wrapper, [])
  call = syntax.Call(nested_closure,
                     [syntax.Closure(fn, nonlocal_args)] + untyped_args)
  body = [syntax.Return(call)]
  fn_name = names.fresh(adverb_class.node_type() + fn.name + "_wrapper")
  untyped_wrapper = syntax.Fn(fn_name, fn_args_obj, body)
  all_types = arg_types.prepend_positional(nonlocal_types)
  typed = type_inference.specialize(untyped_wrapper, all_types)
  typed = high_level_optimizations(typed)
  if config.opt_tile:
    typed = tiling(typed)

  # Previously the cache was consulted but never populated, so every call
  # regenerated (and re-tiled) the wrapper.
  _lowered_wrapper_cache[key] = typed
  return typed
Пример #3
0
  def __getitem__(self, key):
    """
    Resolve `key` to a local SSA name, registering it on first use.

    Lookup order:
      1. the local scopes, innermost first;
      2. the enclosing translation environment (`outer_env`), whose binding
         is re-localized under a fresh name in the top scope;
      3. Python closure cells, then globals, captured as python refs (an
         existing local name is reused for an equal ref).
    Raises NameNotFound when none of these define `key`.
    """
    for scope in reversed(self.scopes):
      found = scope.get(key)
      if found:
        return found

    if self.outer_env:
      # Only validate the name against the outer environment (letting it
      # register whatever python refs it needs); the binding we hand back is
      # a fresh local name recorded for later plumbing.
      self.outer_env[key]
      fresh_local = names.fresh(key)
      self.top_scope()[key] = fresh_local
      self.original_outer_names.append(key)
      self.localized_outer_names.append(fresh_local)
      return fresh_local

    if self.closure_cell_dict and key in self.closure_cell_dict:
      ref = ClosureCellRef(self.closure_cell_dict[key], key)
    elif self.globals_dict and key in self.globals_dict:
      ref = GlobalRef(self.globals_dict, key)
    else:
      raise NameNotFound(key)

    # Reuse the local name already assigned to an equal ref, if any.
    existing = next((local for (local, other) in self.python_refs.iteritems()
                     if ref == other), None)
    if existing is not None:
      return existing

    fresh_local = names.fresh(key)
    self.python_refs[fresh_local] = ref
    return fresh_local
Пример #4
0
def mk_simple_fn(mk_body, input_name = "x", fn_name = "cast"):
  """
  Build a one-argument untyped Fn.

  The single argument gets a fresh SSA name but keeps `input_name` as its
  visible name. `mk_body` receives a Var for that argument and must return
  the function's statement list.
  """
  arg_ssa_name = names.fresh(input_name)
  fn_ssa_name = names.fresh(fn_name)
  formals = FormalArgs()
  formals.add_positional(arg_ssa_name, input_name)
  return syntax.Fn(fn_ssa_name, formals, mk_body(syntax.Var(arg_ssa_name)))
Пример #5
0
def value_to_syntax(v):
  """
  Convert a static Python value into parakeet syntax.

    - Python constants become Const nodes.
    - A numpy dtype becomes a one-argument function casting its input to
      the corresponding parakeet type.
    - Anything else must be a function value and is translated as such.
  """
  if syntax_helpers.is_python_constant(v):
    return syntax_helpers.const(v)

  if not isinstance(v, np.dtype):
    assert is_function_value(v), \
        "Can't make value %s : %s into static syntax" % (v, type(v))
    return translate_function_value(v)

  arg_name = names.fresh("x")
  cast_fn_name = names.fresh("cast")
  formals = FormalArgs()
  formals.add_positional(arg_name, "x")
  target_t = core_types.from_dtype(v)
  cast_expr = syntax.Cast(syntax.Var(arg_name), type = target_t)
  return syntax.Fn(cast_fn_name, formals, [syntax.Return(cast_expr)])
Пример #6
0
  def flatten_Reduce(self, map_fn, combine, x, init):
    """Turn an axis-less reduction into an IndexReduce.

    Builds an untyped index function `idx_map(c0, ..., cK, x, i)` which calls
    the function underlying `map_fn` on its closure values c0..cK and on
    `x[i]`. It then closes that function over map_fn's closure values plus
    `x`, specializes it together with `combine` via specialize_IndexReduce,
    and returns an IndexReduce over `self.shape(x)`. A non-None `init` is
    cast to the specialized combine function's return type.
    """
    shape = self.shape(x)
    n_indices = self.rank(x)
    # build a function from indices which picks out the data elements
    # needed for the original map_fn

    outer_closure_args = self.closure_elts(map_fn)
    args_obj = FormalArgs()
    inner_closure_vars = []
    # one formal (visible name c0, c1, ...) per value that map_fn closes over
    for i in xrange(len(outer_closure_args)):
      visible_name = "c%d" % i
      name = names.fresh(visible_name)
      args_obj.add_positional(name, visible_name)
      inner_closure_vars.append(Var(name))

    data_arg_name = names.fresh("x")
    data_arg_var = Var(data_arg_name)
    idx_arg_name = names.fresh("i")
    idx_arg_var = Var(idx_arg_name)

    # the data array and the index follow the closure formals
    args_obj.add_positional(data_arg_name, "x")
    args_obj.add_positional(idx_arg_name, "i")

    idx_expr = syntax.Index(data_arg_var, idx_arg_var)
    inner_fn = self.get_fn(map_fn)
    # call the unwrapped map function on its closure values followed by x[i]
    fn_call_expr = syntax.Call(inner_fn, tuple(inner_closure_vars)  + (idx_expr,))
    idx_fn = syntax.Fn(name = names.fresh("idx_map"),
                       args = args_obj,
                       body =  [syntax.Return(fn_call_expr)]
                       )

    # NOTE(review): stale commented-out code; refers to names not defined here.
    #t = closure_type.make_closure_type(typed_fn, get_types(closure_args))
    #return Closure(typed_fn, closure_args, t)
    # idx_fn's leading formals are bound to map_fn's closure values and x,
    # leaving only the index to be supplied by the IndexReduce
    outer_closure_args = tuple(outer_closure_args) + (x,)

    idx_closure_t = closure_type.make_closure_type(idx_fn, get_types(outer_closure_args))

    idx_closure = Closure(idx_fn, args = outer_closure_args, type = idx_closure_t)

    # NOTE(review): n_indices is the rank of x; presumably the index `i`
    # covers all of x's dimensions -- confirm against specialize_IndexReduce.
    result_type, typed_fn, typed_combine = \
      specialize_IndexReduce(idx_closure, combine, n_indices, init)
    if not self.is_none(init):
      init = self.cast(init, typed_combine.return_type)
    return syntax.IndexReduce(shape = shape,
                              fn = make_typed_closure(idx_closure, typed_fn),
                              combine = make_typed_closure(combine, typed_combine),
                              init = init,
                              type = result_type)
Пример #7
0
def nested_maps(inner_fn, depth, arg_names):
  """
  Wrap `inner_fn` in `depth` layers of Map along axis 0.

  When depth <= 0 the function itself is returned. Otherwise the result is
  an untyped Fn taking one argument per name in `arg_names`, whose body maps
  the (depth - 1)-level wrapper over those arguments. Wrappers are memoized
  in `_nested_map_cache` by (function name, depth, argument names).
  """
  if depth <= 0:
    return inner_fn

  cache_key = (inner_fn.name, depth, tuple(arg_names))
  if cache_key in _nested_map_cache:
    return _nested_map_cache[cache_key]

  formals = args.FormalArgs()
  mapped_vars = []
  for original in arg_names:
    renamed = names.refresh(original)
    formals.add_positional(renamed)
    mapped_vars.append(syntax.Var(renamed))

  wrapper_name = names.fresh(inner_fn.name + "_broadcast%d" % depth)
  body_map = syntax.Map(fn = nested_maps(inner_fn, depth - 1, arg_names),
                        axis = syntax_helpers.zero_i64,
                        args = mapped_vars)
  wrapper = syntax.Fn(name = wrapper_name,
                      args = formals,
                      body = [syntax.Return(body_map)])

  _nested_map_cache[cache_key] = wrapper
  return wrapper
Пример #8
0
def translate_function_ast(function_def_ast, globals_dict = None,
                           closure_vars = None, closure_cells = None,
                           outer_env = None):
  """
  Helper to launch translation of a python function's AST, and then construct
  an untyped parakeet function from the arguments, refs, and translated body.

  :param function_def_ast: function definition AST node to translate
  :param globals_dict: the function's globals, handed to AST_Translator
  :param closure_vars: names of the function's Python closure variables
                       (default: none)
  :param closure_cells: closure cells, parallel to `closure_vars`
  :param outer_env: environment of an enclosing parakeet function, if any
  :returns: a syntax.Fn whose prepended nonlocal args are the localized outer
            names followed by the local names bound to python refs
  """
  # None defaults instead of shared mutable [] defaults.
  if closure_vars is None:
    closure_vars = []
  if closure_cells is None:
    closure_cells = []

  assert len(closure_vars) == len(closure_cells)
  closure_cell_dict = dict(zip(closure_vars, closure_cells))

  translator = AST_Translator(globals_dict, closure_cell_dict, outer_env)

  ssa_args, assignments = translator.translate_args(function_def_ast.args)
  _, body = translator.visit_block(function_def_ast.body)
  body = assignments + body
  ssa_fn_name = names.fresh(function_def_ast.name)

  # Collect refs and their local names in one pass so the two lists stay
  # positionally aligned.
  refs = []
  ref_names = []
  for (ssa_name, ref) in translator.env.python_refs.iteritems():
    refs.append(ref)
    ref_names.append(ssa_name)

  # if function was nested in parakeet, it can have references to its
  # surrounding parakeet scope, which can't be captured with a python ref cell
  original_outer_names = translator.env.original_outer_names
  localized_outer_names = translator.env.localized_outer_names

  ssa_args.prepend_nonlocal_args(localized_outer_names + ref_names)

  return syntax.Fn(ssa_fn_name, ssa_args, body, refs, original_outer_names)
Пример #9
0
  def get_index_fn(self, array_t, idx_t):
    """
    Return a TypedFn `idx(array, idx)` that returns `array[idx]`.

    An index whose type is not Int64 is cast to Int64 before indexing. The
    result type is `array_t.index_type(idx_t)`. Functions are memoized per
    (array_t, idx_t) in self.index_function_cache.
    """
    cache_key = (array_t, idx_t)
    if cache_key in self.index_function_cache:
      return self.index_function_cache[cache_key]

    arr_name = names.fresh("array")
    arr_var = Var(arr_name, type = array_t)
    ix_name = names.fresh("idx")
    ix_expr = Var(ix_name, type = idx_t)
    if idx_t is not Int64:
      ix_expr = Cast(value = ix_expr, type = Int64)
    result_t = array_t.index_type(idx_t)

    index_fn = syntax.TypedFn(
        name = names.fresh("idx"),
        arg_names = (arr_name, ix_name),
        input_types = (array_t, idx_t),
        return_type = result_t,
        type_env = {arr_name: array_t, ix_name: idx_t},
        body = [Return(Index(arr_var, ix_expr, type = result_t))])
    self.index_function_cache[cache_key] = index_fn
    return index_fn
Пример #10
0
 def lookup(self, name):
   """
   Resolve a Python variable name to a parakeet expression.

   Resolution order:
     1. a name bound in self.scopes -> Var of its local SSA name;
     2. otherwise, if there is a parent translator -> a fresh local name is
        bound, and the pair (name, local name) is recorded in
        original_outer_names / localized_outer_names;
     3. a Python closure cell -> a Var for a local name backed by a
        ClosureCellRef in self.python_refs (an existing local bound to an
        equal ref is reused);
     4. a global -> value_to_syntax(value) for static values, otherwise
        ExternalValue(value).
   Raises NameNotFound if none of these apply.
   """
   #if name in reserved_names:
   #  return reserved_names[name]
   if name in self.scopes:
     return Var(self.scopes[name])
   elif self.parent:
     # The parent is not consulted here: only a fresh local name is reserved
     # and the original name recorded, so that the outer value can be passed
     # in later.
     # NOTE(review): presumably the caller resolves original_outer_names
     # against the parent afterwards -- confirm.
     local_name = names.fresh(name)
     self.scopes[name] = local_name
     self.original_outer_names.append(name)
     self.localized_outer_names.append(local_name)
     return Var(local_name)
   elif self.closure_cell_dict and name in self.closure_cell_dict:
     ref = ClosureCellRef(self.closure_cell_dict[name], name)
     # reuse the local name of an equal ref that was already captured
     # NOTE(review): names bound below are also put in self.scopes, so the
     # first branch normally catches them before this loop runs.
     for (local_name, other_ref) in self.python_refs.iteritems():
       if ref == other_ref:
         return Var(local_name)
     local_name = names.fresh(name)
     self.scopes[name] = local_name
     self.original_outer_names.append(name)
     self.localized_outer_names.append(local_name)
     self.python_refs[local_name] = ref
     return Var(local_name)
   elif self.is_global(name):
     value = self.lookup_global(name)
     if is_static_value(value):
       return value_to_syntax(value)
     else:
       return ExternalValue(value)


     #else:
     #  assert False, "External values must be scalars or functions"
   else:
     raise NameNotFound(name)
Пример #11
0
 def visit_loop_body(self, body, *exprs):
   """
   Translate a loop body together with its loop-header expressions.

   `exprs` are visited before the body. After the body is visited, every
   name in the body's resulting scope that is also in self.scopes gets a
   fresh "<name>_loop" variable, and then:
     - `merge` maps it to (Var(name before the loop), Var(name after body)),
     - the current scope rebinds the name to it,
     - uses of the pre-loop name in the visited exprs and body are
       substituted with it.

   Returns (body, merge, exprs).
   """
   merge = {}
   substitutions = {}
   curr_scope = self.current_scope()
   # header expressions are visited first, presumably so they resolve
   # against the pre-loop bindings
   exprs = [self.visit(expr) for expr in exprs]
   scope_after, body = self.visit_block(body)
   for (k, name_after) in scope_after.iteritems():
     if k in self.scopes:
       name_before = self.scopes[k]
       new_name = names.fresh(k + "_loop")
       merge[new_name] = (Var(name_before), Var(name_after))
       substitutions[name_before]  = new_name
       curr_scope[k] = new_name

   # inside the loop, the merged "_loop" variable replaces the pre-loop name
   exprs = [subst_expr(expr, substitutions) for expr in exprs]
   body = subst_stmt_list(body, substitutions)
   return body, merge, exprs
Пример #12
0
def translate_function_ast(name,
                           args,
                           body,
                           globals_dict = None,
                           closure_vars = None,
                           closure_cells = None,
                           parent = None,
                           filename = None):
  """
  Helper to launch translation of a python function's AST, and then construct
  an untyped parakeet function from the arguments, refs, and translated body.

  Top-level functions (globals_dict given, no parent) return a Fn whose
  python refs line up positionally with its prepended nonlocal args.
  Nested functions (parent given) return a Closure over the parent's
  bindings of the captured outer names, or the bare Fn if nothing is
  captured. `filename` defaults to the parent's filename.
  """
  # None defaults instead of shared mutable [] defaults.
  if closure_vars is None:
    closure_vars = []
  if closure_cells is None:
    closure_cells = []

  assert len(closure_vars) == len(closure_cells)
  closure_cell_dict = dict(zip(closure_vars, closure_cells))

  if filename is None and parent is not None:
    filename = parent.filename

  translator = AST_Translator(globals_dict, closure_cell_dict,
                              parent, function_name = name, filename = filename)

  ssa_args, assignments = translator.translate_args(args)
  _, body = translator.visit_block(body)
  body = assignments + body

  ssa_fn_name = names.fresh(name)

  # if function was nested in parakeet, it can have references to its
  # surrounding parakeet scope, which can't be captured with a python ref cell
  original_outer_names = translator.original_outer_names
  localized_outer_names = translator.localized_outer_names
  python_refs = translator.python_refs
  ssa_args.prepend_nonlocal_args(localized_outer_names)
  if globals_dict:
    assert parent is None
    assert len(original_outer_names) == len(python_refs)
    # The refs fill the nonlocal args prepended above, so they must be in
    # localized_outer_names order; python_refs.values() comes back in
    # arbitrary dict order. python_refs is presumably keyed by those same
    # localized names (that is how the translator registers closure cells)
    # -- a KeyError here means that invariant was broken.
    ordered_refs = [python_refs[local_name]
                    for local_name in localized_outer_names]
    return syntax.Fn(ssa_fn_name, ssa_args, body, ordered_refs, [])
  else:
    assert parent
    fn = syntax.Fn(ssa_fn_name, ssa_args, body, [], original_outer_names)
    if len(original_outer_names) > 0:
      outer_ssa_vars = [parent.lookup(x) for x in original_outer_names]
      return syntax.Closure(fn, outer_ssa_vars)
    else:
      return fn
Пример #13
0
def prim_wrapper(p):
  """
  Given a primitive, return an untyped function which calls that prim.

  The function takes `p.nin` positional arguments and returns
  PrimCall(p, args). Wrappers are memoized per primitive in
  `_untyped_prim_wrappers`.
  """
  if p in _untyped_prim_wrappers:
    return _untyped_prim_wrappers[p]

  wrapper_name = names.fresh(p.name)
  formals = FormalArgs()
  arg_names = list(names.fresh_list(p.nin))
  for arg_name in arg_names:
    formals.add_positional(arg_name)
  arg_vars = [syntax.Var(arg_name) for arg_name in arg_names]

  prim_call = syntax.PrimCall(p, arg_vars)
  wrapper = syntax.Fn(wrapper_name, formals, [syntax.Return(prim_call)], [])
  _untyped_prim_wrappers[p] = wrapper
  return wrapper
Пример #14
0
  def transform_AllPairs(self, expr):
    """
    Transform each AllPairs(f, X, Y) operation into a pair of nested maps:
      def g(x_elt):
        def h(y_elt):
          return f(x_elt, y_elt)
        return map(h, Y)
      return map(g, X)

    AllPairs expressions with an explicit `out` are returned unchanged.
    Values that f closes over are threaded through the generated outer
    function (`g` above, built as a TypedFn) so that the inner Map can
    re-close f over them together with the current x element.
    """

    if expr.out is not None:
      return expr

    # if the adverb function is a closure, give me all the values it
    # closes over
    closure_elts = self.closure_elts(expr.fn)
    n_closure_elts = len(closure_elts)

    # strip off the closure wrappings and give me the underlying TypedFn
    fn = self.get_fn(expr.fn)

    # the two array arguments to this AllPairs adverb
    x, y_outer = expr.args

    # f's first non-closure input is the x element
    x_elt_name = names.fresh('x_elt')
    x_elt_t = fn.input_types[n_closure_elts]
    x_elt_var = Var(x_elt_name, type = x_elt_t)
    y_inner_name = names.fresh('y')
    y_inner = Var(y_inner_name, type = y_outer.type)

    # fresh formals standing in for f's closure values inside the generated fn
    inner_closure_args = []
    for (i, elt) in enumerate(closure_elts):
      t = elt.type
      if elt.__class__ is Var:
        name = names.refresh(elt.name)
      else:
        name = names.fresh('closure_arg%d' % i)
      inner_closure_args.append(Var(name, type = t))

    inner_arg_names = []
    inner_input_types = []
    type_env = {}

    # generated fn's formals: closure stand-ins, then Y, then the x element
    # (the x element comes last because the outer Map supplies it)
    for var in inner_closure_args + [y_inner, x_elt_var]:
      type_env[var.name] = var.type
      inner_arg_names.append(var.name)
      inner_input_types.append(var.type)

    # f closed over its closure values and the current x element; the inner
    # Map supplies the y element
    inner_closure_rhs = self.closure(fn, inner_closure_args + [x_elt_var])

    inner_result_t = array_type.lower_rank(expr.type, 1)
    inner_fn = TypedFn(
      name = names.fresh('allpairs_into_maps_wrapper'),
      arg_names = tuple(inner_arg_names),
      input_types = tuple(inner_input_types),
      return_type = inner_result_t,
      type_env = type_env,
      body = [
        Return(Map(inner_closure_rhs,
                   args=[y_inner],
                   axis = expr.axis,
                   type = inner_result_t))
      ]
    )
    # bind the original closure values and Y; the outer Map over X supplies
    # the remaining x element argument
    closure = self.closure(inner_fn, closure_elts + [y_outer])
    return Map(closure, [x], axis = expr.axis, type = expr.type)
Пример #15
0
def untyped_wrapper(adverb_class,
                    map_fn_name = None,
                    combine_fn_name = None,
                    emit_fn_name = None,
                    data_names = None,
                    varargs_name = 'xs',
                    axis = 0):
  """
  Given:
    - an adverb class (i.e. Map, Reduce, Scan, or AllPairs)
    - function var names (some of which can be None)
    - optional list of positional data arg names (default: none)
    - optional name for the varargs parameter
    - an axis along which the adverb operates
  Return a function which calls the desired adverb with the data args and
  unpacked varargs tuple.

  Positional formals are added in this order: the non-None function names
  (map, combine, emit), then the data names, then, for adverbs with an
  `init` member, an `init` parameter defaulting to None. The varargs
  parameter (if any) becomes the wrapper's starargs. Wrappers are memoized
  in `_adverb_wrapper_cache`.
  """
  # None instead of a shared mutable [] default.
  if data_names is None:
    data_names = []

  axis = syntax_helpers.wrap_if_constant(axis)
  key = (adverb_class.__name__,
         map_fn_name,
         combine_fn_name,
         emit_fn_name,
         axis,
         tuple(data_names),
         varargs_name)

  if key in _adverb_wrapper_cache:
    return _adverb_wrapper_cache[key]

  fn_args_obj = FormalArgs()

  def mk_input_var(name):
    """Add a positional formal for `name` (if not None) and return its Var."""
    if name is None:
      return None
    local_name = names.refresh(name)
    fn_args_obj.add_positional(local_name, name)
    return syntax.Var(local_name)

  map_fn = mk_input_var(map_fn_name)
  combine_fn = mk_input_var(combine_fn_name)
  emit_fn = mk_input_var(emit_fn_name)
  # Built eagerly (a lazy map() would defer the add_positional side effects
  # and scramble the formal order).
  data_arg_vars = [mk_input_var(name) for name in data_names]

  if varargs_name:
    local_name = names.refresh(varargs_name)
    fn_args_obj.starargs = local_name
    unpack = syntax.Var(local_name)
  else:
    unpack = None

  data_args = ActualArgs(data_arg_vars, starargs = unpack)
  adverb_param_names = adverb_class.members()

  adverb_params = {'axis': axis, 'args': data_args}

  if 'init' in adverb_param_names:
    init_name = names.fresh('init')
    fn_args_obj.add_positional(init_name, 'init')
    fn_args_obj.defaults[init_name] = None
    adverb_params['init'] = syntax.Var(init_name)

  # only pass function fields that were actually requested
  for (field, value) in (('fn', map_fn),
                         ('combine', combine_fn),
                         ('emit', emit_fn)):
    if value:
      adverb_params[field] = value

  adverb = adverb_class(**adverb_params)
  body = [syntax.Return(adverb)]
  fn_name = names.fresh(adverb_class.node_type() + "_wrapper")

  fundef = syntax.Fn(fn_name, fn_args_obj, body)
  _adverb_wrapper_cache[key] = fundef
  return fundef
Пример #16
0
  def transform_Map(self, expr):
    """
    Tile a Map expression, returning an equivalent TiledMap.

    For every formal of the mapped function, records the list of adverb
    depths at which it is expanded (self.expansions): closure formals
    inherit the expansions of the values bound to them, and mapped formals
    additionally get the current depth.

    If the function body itself contains adverbs, the body is transformed
    recursively, with each input type's rank raised by that input's number
    of expansions. Otherwise tile sizes are estimated and an unpack tree of
    nested functions is generated, which is also register-tiled when
    config.opt_reg_tile is set.

    The types of `expr.args` and of the closure's args are updated in place
    to match the new function's input types.
    """
    self.num_tiles += 1

    depth = len(self.adverbs_visited)
    closure = expr.fn
    closure_args = []
    fn = closure
    if isinstance(fn, syntax.Closure):
      closure_args = closure.args
      fn = closure.fn

    axes = [self.get_num_expansions_at_depth(arg.name, depth) + expr.axis
            for arg in expr.args]
    self.push_exp(syntax.Map, AdverbArgs(expr.fn, expr.args, expr.axis, axes))
    # closure formals inherit the expansions of the values bound to them
    for fn_arg, adverb_arg in zip(fn.arg_names[:len(closure_args)],
                                  closure_args):
      name = self.get_closure_arg(adverb_arg).name
      new_expansions = copy.deepcopy(self.get_expansions(name))
      self.expansions[fn_arg] = new_expansions
    # mapped formals are additionally expanded at the current depth
    for fn_arg, adverb_arg in zip(fn.arg_names[len(closure_args):], expr.args):
      new_expansions = copy.deepcopy(self.get_expansions(adverb_arg.name))
      new_expansions.append(depth)
      self.expansions[fn_arg] = new_expansions

    depths = self.get_depths_list(fn.arg_names)
    # Does fn's body contain nested adverbs? Query the visitor created here;
    # the old code called an undefined `find_syntax` and left this one unused.
    find_adverbs = FindAdverbs()
    find_adverbs.visit_fn(fn)

    if find_adverbs.has_adverbs:
      arg_names = list(fn.arg_names)
      input_types = []
      self.push_type_env(fn.type_env)
      for arg, t in zip(arg_names, fn.input_types):
        new_type = array_type.increase_rank(t, len(self.get_expansions(arg)))
        input_types.append(new_type)
        self.type_env[arg] = new_type
      exps = self.get_depths_list(fn.arg_names)
      rank_inc = 0
      for i, exp in enumerate(exps):
        if exp >= depth:
          rank_inc = i
          break
      return_t = array_type.increase_rank(expr.type, rank_inc)
      new_fn = syntax.TypedFn(name = names.fresh("expanded_map_fn"),
                              arg_names = tuple(arg_names),
                              body = self.transform_block(fn.body),
                              input_types = input_types,
                              return_type = return_t,
                              type_env = self.pop_type_env())
      new_fn.has_tiles = True
    else:
      # Estimate the tile sizes
      self.estimate_tile_sizes(fn.arg_names, depths)

      new_fn = self.gen_unpack_tree(self.adverbs_visited, depths, fn.arg_names,
                                    fn.body, fn.type_env)
      if config.opt_reg_tile:
        adverb_tree = [get_tiled_version(adv) for adv in self.adverbs_visited]
        new_fn = self.gen_unpack_tree(adverb_tree, depths, fn.arg_names, new_fn,
                                      fn.type_env, reg_tiling = True)

    # propagate the (possibly rank-increased) input types back to the args
    for arg, t in zip(expr.args, new_fn.input_types[len(closure_args):]):
      arg.type = t
    return_t = new_fn.return_type
    if isinstance(closure, syntax.Closure):
      for c_arg, t in zip(closure.args, new_fn.input_types):
        c_arg.type = t
      closure_arg_types = [arg.type for arg in closure.args]
      closure.fn = new_fn
      closure.type = closure_type.make_closure_type(new_fn, closure_arg_types)
      new_fn = closure
    self.pop_exp()
    return syntax.TiledMap(fn = new_fn, args = expr.args, axes = axes,
                            type = return_t)
Пример #17
0
    def gen_unpack_fn(depth_idx, arg_order):
      """
      Recursively build the chain of functions that unpacks one adverb level
      per call, starting at position `depth_idx` of `depths`.

      Relies on names from the enclosing scope: depths, reg_tiling, inner,
      type_env, v_names, order_args, adverb_tree and self.

        - Past the last depth: returns `inner` unchanged when reg_tiling;
          otherwise returns an "inner_block" TypedFn with body `inner`,
          whose input types are read from type_env in `arg_order`.
        - Otherwise: builds the adverb for this depth (from
          adverb_tree[depth_idx] and self.adverb_args[depth_idx]). Its fn is
          a closure over the recursively built nested function, bound to the
          "fixed" args. It then returns a TypedFn ("reg_tile" or
          "intermediate_depth") with formals `arg_order` whose body returns
          that adverb.
      """
      if depth_idx >= len(depths):
        if reg_tiling:
          return inner
        else:
          # For each stmt in body, add its lhs free vars to the type env
          # (inner_type_env starts as a copy of type_env, so this re-stores
          # the same types for those names)
          inner_type_env = copy.copy(type_env)
          return_t = Int32 # Dummy type
          for s in inner:
            if isinstance(s, syntax.Assign):
              lhs_names = free_vars(s.lhs)
              lhs_types = [type_env[name] for name in lhs_names]
              for name, t in zip(lhs_names, lhs_types):
                inner_type_env[name] = t
            elif isinstance(s, syntax.Return):
              # NOTE(review): a str has no `.name`, so this branch would raise
              # AttributeError if ever taken; it was probably meant to test
              # for a syntax.Var -- confirm before changing.
              if isinstance(s.value, str):
                return_t = type_env[s.value.name]
              else:
                return_t = s.value.type

          # The innermost function always uses all the variables
          # NOTE(review): arg_names come from v_names while input_types follow
          # arg_order; presumably these agree at the innermost level -- confirm.
          input_types = [type_env[arg] for arg in arg_order]
          fn = syntax.TypedFn(name = names.fresh("inner_block"),
                              arg_names = tuple([name for name in v_names]),
                              body = inner,
                              input_types = input_types,
                              return_type = return_t,
                              type_env = inner_type_env)
          return fn
      else:
        # Get the current depth
        depth = depths[depth_idx]

        # Order the arguments for the current depth, i.e. for the nested fn
        cur_arg_names, fixed_arg_names = order_args(depth)
        nested_arg_names = fixed_arg_names + cur_arg_names

        # Make a type env for this function based on the number of expansions
        # left for each arg
        adv_args = self.adverb_args[depth_idx]
        if reg_tiling:
          new_adverb = adverb_tree[depth_idx](fn = adv_args.fn,
                                              args = adv_args.args,
                                              axes = adv_args.axes,
                                              fixed_tile_size = True)
        else:
          new_adverb = adverb_tree[depth_idx](fn = adv_args.fn,
                                              args = adv_args.args,
                                              axis = adv_args.axis)

        # Increase the rank of each arg by the number of nested expansions
        # (i.e. the expansions of that arg that occur deeper in the nesting).
        # The count len(exps) - i assumes exps is sorted ascending.
        new_type_env = {}
        if reg_tiling:
          for arg in nested_arg_names:
            new_type_env[arg] = inner.type_env[arg]
        else:
          for arg in nested_arg_names:
            exps = self.get_expansions(arg)
            rank_increase = 0
            for i, e in enumerate(exps):
              if e >= depth:
                rank_increase = len(exps) - i
                break
            new_type_env[arg] = \
                array_type.increase_rank(type_env[arg], rank_increase)

        cur_arg_types = [new_type_env[arg] for arg in cur_arg_names]
        fixed_arg_types = [new_type_env[arg] for arg in fixed_arg_names]

        # Generate the nested function with the proper arg order and wrap it
        # in a closure
        nested_fn = gen_unpack_fn(depth_idx+1, nested_arg_names)
        nested_args = [syntax.Var(name, type = t)
                       for name, t in zip(cur_arg_names, cur_arg_types)]
        nested_fixed_args = \
            [syntax.Var(name, type = t)
             for name, t in zip(fixed_arg_names, fixed_arg_types)]
        nested_closure = self.closure(nested_fn, nested_fixed_args)

        # Make an adverb that wraps the nested fn
        new_adverb.fn = nested_closure
        new_adverb.args = nested_args
        return_t = nested_fn.return_type
        if isinstance(new_adverb, syntax.Reduce):
          # Reductions keep the nested return type and carry combine/init over
          # (with reg tiling, combine is unpacked for the remaining depths)
          if reg_tiling:
            ds = copy.copy(depths)
            ds.remove(depth)
            new_adverb.combine = self.unpack_combine(adv_args.combine, ds)
          else:
            new_adverb.combine = adv_args.combine
          new_adverb.init = adv_args.init
        elif not reg_tiling:
          # non-reducing adverbs add one output dimension
          return_t = array_type.increase_rank(nested_fn.return_type, 1)
        new_adverb.type = return_t

        # Add the adverb to the body of the current fn and return the fn
        name = names.fresh("reg_tile" if reg_tiling else "intermediate_depth")
        arg_types = [new_type_env[arg] for arg in arg_order]
        fn = syntax.TypedFn(name = name,
                            arg_names = arg_order,
                            body = [syntax.Return(new_adverb)],
                            input_types = arg_types,
                            return_type = return_t,
                            type_env = new_type_env)
        return fn
Пример #18
0
 def fresh_var(self, t, prefix = "temp"):
   """
   Create a Var of type `t` with a fresh SSA name derived from `prefix`,
   recording its type in self.type_env.
   """
   assert prefix is not None
   assert t is not None, "Type required for new variable %s" % prefix
   new_var = Var(names.fresh(prefix), type = t)
   self.type_env[new_var.name] = t
   return new_var
Пример #19
0
def gen_par_work_function(adverb_class, f, nonlocals, nonlocal_types,
                          args_t, arg_types, dont_slice_position = -1):
  """
  Build (and memoize) a lowered parallel work function for an adverb.

  The generated TypedFn has formals (start, stop, args, tile_sizes). It:
    - reads field "arg<i>" of the `args` struct for every nonlocal and then
      every data argument (i counts across both groups);
    - slices array-typed data arguments to [start:stop] along their first
      dimension, except the one whose field index is `dont_slice_position`;
    - appends `tile_sizes` when config.opt_tile is set;
    - calls the (possibly tiled) payload from gen_tiled_wrapper and stores
      the result into output[start:stop].

  The returned function carries num_tiles and the payload's tile-size
  estimates as attributes. `nonlocals` itself is not used here.
  """
  # Every input that changes the generated code must be in the key: the
  # nonlocal types change the payload's signature, and dont_slice_position
  # changes which argument is sliced.
  key = (adverb_class, f.name, tuple(arg_types), tuple(nonlocal_types),
         dont_slice_position, config.opt_tile)
  if key in _par_wrapper_cache:
    return _par_wrapper_cache[key]

  fn = gen_tiled_wrapper(adverb_class, f, arg_types, nonlocal_types)
  num_tiles = fn.num_tiles
  # Construct a typed parallel wrapper function that unpacks the args struct
  # and calls the (possibly tiled) payload function with its slices of the
  # arguments.
  start_var = syntax.Var(names.fresh("start"), type = Int64)
  stop_var = syntax.Var(names.fresh("stop"), type = Int64)
  args_var = syntax.Var(names.fresh("args"), type = args_t)
  tile_type = tuple_type.make_tuple_type([Int64 for _ in range(num_tiles)])
  tile_sizes_var = syntax.Var(names.fresh("tile_sizes"), type = tile_type)
  inputs = [start_var, stop_var, args_var, tile_sizes_var]

  # Manually unpack the args into types Vars and slice into them.
  slice_t = array_type.make_slice_type(Int64, Int64, Int64)
  arg_slice = syntax.Slice(start_var, stop_var, syntax_helpers.one_i64,
                           type = slice_t)

  def slice_arg(arg, t):
    """Index `arg` with [start:stop, :, ..., :] (first dimension only)."""
    indices = [arg_slice]
    for _ in xrange(1, arg.type.rank):
      indices.append(syntax_helpers.slice_none)
    tuple_t = tuple_type.make_tuple_type(syntax_helpers.get_types(indices))
    index_tuple = syntax.Tuple(indices, tuple_t)
    result_t = t.index_type(tuple_t)
    return syntax.Index(arg, index_tuple, type = result_t)

  unpacked_args = []
  i = 0
  for t in nonlocal_types:
    unpacked_args.append(syntax.Attribute(args_var, ("arg%d" % i), type = t))
    i += 1
  for t in arg_types:
    attr = syntax.Attribute(args_var, ("arg%d" % i), type = t)

    if isinstance(t, array_type.ArrayT) and i != dont_slice_position:
      # TODO: Handle axis.
      unpacked_args.append(slice_arg(attr, t))
    else:
      unpacked_args.append(attr)
    i += 1

  # If tiling, pass in the tile params array.
  if config.opt_tile:
    unpacked_args.append(tile_sizes_var)

  # Make a typed closure that calls the payload function with the arg slices.
  closure_t = closure_type.make_closure_type(fn, [])
  nested_closure = syntax.Closure(fn, [], type = closure_t)
  return_t = fn.return_type
  call = syntax.Call(nested_closure, unpacked_args, type = return_t)

  output_name = names.fresh("output")
  output_attr = syntax.Attribute(args_var, "output", type = return_t)
  output_var = syntax.Var(output_name, type = output_attr.type)
  output_slice = slice_arg(output_var, return_t)
  body = [syntax.Assign(output_var, output_attr),
          syntax.Assign(output_slice, call),
          syntax.Return(syntax_helpers.none)]
  type_env = {output_name:output_slice.type}
  for arg in inputs:
    type_env[arg.name] = arg.type

  # Construct the typed wrapper.
  wrapper_name = adverb_class.node_type() + fn.name + "_par"
  parallel_wrapper = \
      syntax.TypedFn(name = names.fresh(wrapper_name),
                     arg_names = [var.name for var in inputs],
                     input_types = syntax_helpers.get_types(inputs),
                     body = body,
                     return_type = core_types.NoneType,
                     type_env = type_env)
  lowered = lowering(parallel_wrapper)

  lowered.num_tiles = num_tiles
  lowered.dl_tile_estimates = fn.dl_tile_estimates
  lowered.ml_tile_estimates = fn.ml_tile_estimates
  _par_wrapper_cache[key] = lowered
  return lowered
Пример #20
0
def replace_return_with_var(body, type_env, return_type):
  """
  Rewrite `body` so that its returns target a fresh "result" variable.

  The new variable's type (`return_type`) is registered in `type_env`, which
  is mutated in place. Returns (result_var, rewritten_body).
  """
  result_var = Var(names.fresh("result"), type = return_type)
  type_env[result_var.name] = return_type
  return result_var, replace_returns(body, result_var)
Пример #21
0
def fuse(prev_fn, prev_fixed_args, next_fn, next_fixed_args, fusion_args):
  """
  Fuse two typed functions into one that feeds prev_fn's result into next_fn.

  Expects prev_fn's returned value to be one or more of the arguments to
  next_fn. `fusion_args` describes next_fn's non-closure arguments in order:
    - None:  replaced by prev_fn's returned value;
    - int:   index into next_fn.arg_names of a positional arg that is not
             fused out; it becomes a fresh formal of the fused function;
    - Const: a scalar literal spliced directly into the inlined call.

  The fused function's formals are, in order: prev_fn's closure formals,
  next_fn's closure formals (renamed), prev_fn's direct formals, then one
  formal per int entry of `fusion_args`.

  If next_fn is the identity function, (prev_fn, prev_fixed_args) is
  returned unchanged. Otherwise returns
  (fused_fn, prev_fixed_args + next_fixed_args).
  """
  if syntax_helpers.is_identity_fn(next_fn):
    assert len(next_fixed_args) == 0
    return prev_fn, prev_fixed_args

  fused_formals = []
  fused_input_types = []
  fused_type_env = prev_fn.type_env.copy()
  fused_name = names.fresh('fused')

  prev_closure_formals = prev_fn.arg_names[:len(prev_fixed_args)]

  for prev_closure_arg_name in prev_closure_formals:
    t = prev_fn.type_env[prev_closure_arg_name]
    fused_formals.append(prev_closure_arg_name)
    fused_input_types.append(t)

  # next_fn's closure formals get fresh names in the fused function; keep
  # Vars for them so the inlined body is bound to the renamed formals.
  next_closure_formals = next_fn.arg_names[:len(next_fixed_args)]
  next_closure_vars = []
  for next_closure_arg_name in next_closure_formals:
    t = next_fn.type_env[next_closure_arg_name]
    new_name = names.refresh(next_closure_arg_name)
    fused_type_env[new_name] = t
    fused_formals.append(new_name)
    fused_input_types.append(t)
    next_closure_vars.append(Var(new_name, type = t))

  prev_direct_formals = prev_fn.arg_names[len(prev_fixed_args):]
  for arg_name in prev_direct_formals:
    t = prev_fn.type_env[arg_name]
    fused_formals.append(arg_name)
    fused_input_types.append(t)

  prev_return_var, fused_body = \
      inline.replace_return_with_var(prev_fn.body,
                                     fused_type_env,
                                     prev_fn.return_type)
  # for now we're restricting both functions to have a single return at the
  # outermost scope
  # Previously this list was seeded with next_fn's *original* closure formal
  # names (plain strings), which are not formals of the fused function.
  inline_args = list(next_closure_vars)
  for arg in fusion_args:
    if arg is None:
      inline_args.append(prev_return_var)
    elif isinstance(arg, int):
      # positional arg which is not being fused out
      inner_name = next_fn.arg_names[arg]

      inner_type = next_fn.type_env[inner_name]
      new_name = names.refresh(inner_name)
      fused_formals.append(new_name)
      fused_type_env[new_name] = inner_type
      fused_input_types.append(inner_type)
      var = Var(new_name, inner_type)
      inline_args.append(var)
    else:
      assert arg.__class__ is Const, \
         "Only scalars can be spliced as literals into a fused fn: %s" % arg
      inline_args.append(arg)
  next_return_var = inline.do_inline(next_fn, inline_args,
                                     fused_type_env, fused_body)
  fused_body.append(Return(next_return_var))

  # we're not renaming variables that originate from the predecessor function
  new_fn = TypedFn(name = fused_name,
                   arg_names = fused_formals,
                   body = fused_body,
                   input_types = tuple(fused_input_types),
                   return_type = next_fn.return_type,
                   type_env = fused_type_env)

  combined_args = prev_fixed_args + next_fixed_args
  return new_fn, combined_args
Пример #22
0
 def fresh_name(self, original_name):
   """
   Generate a unique SSA name for `original_name`, bind it in self.scopes
   and return it.
   """
   ssa_name = names.fresh(original_name)
   self.scopes[original_name] = ssa_name
   return ssa_name
Пример #23
0
 def fresh(self, name):
   """
   Generate a unique SSA name for `name`, bind it in the innermost scope
   (self.scopes[-1]) and return it.
   """
   ssa_name = names.fresh(name)
   innermost_scope = self.scopes[-1]
   innermost_scope[name] = ssa_name
   return ssa_name