예제 #1
0
파일: syntax.py 프로젝트: cournape/parakeet
  def node_init(self):
    """Validate a freshly constructed function node and register it.

    Checks that ``self.name`` is a ``str``, ``self.args`` is an
    ``args.FormalArgs`` and ``self.body`` is a list. It then sets
    ``self.type`` to a closure type with no closure arguments and records
    the node in ``self.registry`` under its name.

    Raises:
      AssertionError: if any of the attribute checks fails.
    """
    # The offending value is wrapped in a 1-tuple so that a tuple-valued
    # attribute yields the intended AssertionError message rather than a
    # TypeError ("not all arguments converted") from the % operator.
    assert isinstance(self.name, str), \
        "Expected string for fn name, got %s" % (self.name,)
    assert isinstance(self.args, args.FormalArgs), \
        "Expected arguments to fn to be FormalArgs object, got %s" % (self.args,)
    assert isinstance(self.body, list), \
        "Expected body of fn to be list of statements, got " + str(self.body)

    # Imported here rather than at module level (presumably to avoid an
    # import cycle with closure_type -- TODO confirm).
    import closure_type
    self.type = closure_type.make_closure_type(self, ())
    self.registry[self.name] = self
예제 #2
0
def _get_closure_type(fn):
  """Return the closure type associated with ``fn``.

  * An exact ``closure_type.ClosureT`` instance is returned unchanged.
  * A ``typed_ast.Closure`` or ``typed_ast.Var`` yields its ``.type``; for a
    Var that type is asserted to be a ``ClosureT``.
  * Anything else is resolved with ``_get_fundef`` and given a fresh closure
    type with no closure arguments.
  """
  if fn.__class__ is closure_type.ClosureT:
    return fn

  if isinstance(fn, typed_ast.Closure):
    return fn.type

  if isinstance(fn, typed_ast.Var):
    assert isinstance(fn.type, closure_type.ClosureT)
    return fn.type

  return closure_type.make_closure_type(_get_fundef(fn), [])
예제 #3
0
def make_typed_closure(untyped_closure, typed_fn):
  """Re-attach the closure arguments of ``untyped_closure`` to ``typed_fn``.

  A bare ``untyped_ast.Fn``, or a closure expression that captures nothing,
  maps directly to ``typed_fn``. Otherwise a ``typed_ast.Closure`` binding
  the same captured arguments to ``typed_fn`` is returned.
  """
  if untyped_closure.__class__ is untyped_ast.Fn:
    return typed_fn

  assert isinstance(untyped_closure, untyped_ast.Expr) and \
      isinstance(untyped_closure.type, closure_type.ClosureT)
  _fn, captured = unpack_closure(untyped_closure)
  if len(captured) > 0:
    closure_t = closure_type.make_closure_type(typed_fn, get_types(captured))
    return typed_ast.Closure(typed_fn, captured, closure_t)
  return typed_fn
예제 #4
0
 def closure(self, maybe_fn, extra_args, name = None):
   """Extend the closure of ``maybe_fn`` with ``extra_args``.

   The new closure elements are the existing ones of ``maybe_fn`` followed
   by ``extra_args``. If there are none, the bare function is returned.
   Otherwise a ``Closure`` node is built whose type lists the element types.
   When a truthy ``name`` is given, that closure is passed through
   ``self.assign_temp`` and its result is returned instead.
   """
   fn = self.get_fn(maybe_fn)
   elts = self.closure_elts(maybe_fn) + tuple(extra_args)
   if len(elts) == 0:
     return fn
   result = Closure(fn, elts,
                    type = make_closure_type(fn, [e.type for e in elts]))
   return self.assign_temp(result, name) if name else result
예제 #5
0
def _get_closure_type(fn):
  """Return the closure type associated with the function-like value ``fn``.

  ``fn`` must be a ``Fn``, ``TypedFn``, ``ClosureT``, ``Closure`` or ``Var``:
  * an exact ``closure_type.ClosureT`` is returned unchanged;
  * a ``Closure`` or ``Var`` yields its ``.type`` (asserted to be a
    ``ClosureT`` for a Var);
  * otherwise ``_get_fundef`` resolves it and a closure type with no closure
    arguments is built.

  Raises:
    AssertionError: if ``fn`` is not one of the accepted node/type classes.
  """
  # Wrap fn in a 1-tuple: if fn were itself a tuple, "%s" % fn would raise
  # TypeError while building the message instead of the AssertionError.
  assert isinstance(fn, (Fn, TypedFn, ClosureT, Closure, Var)), \
    "Expected function, got %s" % (fn,)

  if fn.__class__ is closure_type.ClosureT:
    return fn
  elif isinstance(fn, Closure):
    return fn.type
  elif isinstance(fn, Var):
    assert isinstance(fn.type, closure_type.ClosureT)
    return fn.type
  else:
    fundef = _get_fundef(fn)
    return closure_type.make_closure_type(fundef, [])
예제 #6
0
  def flatten_Reduce(self, map_fn, combine, x, init):
    """Turn an axis-less reduction into an IndexReduce.

    Builds an untyped index function ``idx_map(c0, ..., cK, x, i)``. It
    calls ``map_fn``'s underlying function on the closure values and
    ``x[i]``. The function is closed over ``map_fn``'s closure elements plus
    ``x``, so only the index stays free. It is then specialized together
    with ``combine``, and the result is returned as a ``syntax.IndexReduce``
    over ``self.shape(x)``.
    """
    shape = self.shape(x)
    n_indices = self.rank(x)

    # Formals of the index function: one fresh name per closure element of
    # map_fn, followed by the data array and the index.
    outer_closure_args = self.closure_elts(map_fn)
    args_obj = FormalArgs()
    inner_closure_vars = []
    for pos, _elt in enumerate(outer_closure_args):
      visible = "c%d" % pos
      fresh = names.fresh(visible)
      args_obj.add_positional(fresh, visible)
      inner_closure_vars.append(Var(fresh))

    data_name = names.fresh("x")
    data_var = Var(data_name)
    index_name = names.fresh("i")
    index_var = Var(index_name)

    args_obj.add_positional(data_name, "x")
    args_obj.add_positional(index_name, "i")

    # Body: return map_fn's function applied to (closure vars..., x[i]).
    element = syntax.Index(data_var, index_var)
    inner_fn = self.get_fn(map_fn)
    call = syntax.Call(inner_fn, tuple(inner_closure_vars) + (element,))
    idx_fn = syntax.Fn(name = names.fresh("idx_map"),
                       args = args_obj,
                       body = [syntax.Return(call)])

    # Bind the closure elements and the data array, leaving only the index.
    bound_args = tuple(outer_closure_args) + (x,)
    idx_closure = Closure(
        idx_fn,
        args = bound_args,
        type = closure_type.make_closure_type(idx_fn, get_types(bound_args)))

    result_type, typed_fn, typed_combine = \
      specialize_IndexReduce(idx_closure, combine, n_indices, init)
    if not self.is_none(init):
      init = self.cast(init, typed_combine.return_type)
    typed_idx_fn = make_typed_closure(idx_closure, typed_fn)
    typed_combine_fn = make_typed_closure(combine, typed_combine)
    return syntax.IndexReduce(shape = shape,
                              fn = typed_idx_fn,
                              combine = typed_combine_fn,
                              init = init,
                              type = result_type)
예제 #7
0
 def expr_Fn():
   """Wrap the free variable ``expr`` in a typed Closure capturing nothing."""
   return typed_ast.Closure(expr, [],
                            type = closure_type.make_closure_type(expr, ()))
예제 #8
0
 def prim_to_closure(p):
   """Wrap primitive ``p`` as a typed Closure with no captured arguments.

   The closed-over function is the wrapper returned by
   ``prims.prim_wrapper(p)``.
   """
   wrapper = prims.prim_wrapper(p)
   return typed_ast.Closure(
       wrapper, (), type = closure_type.make_closure_type(wrapper, ()))
예제 #9
0
 def expr_Closure():
   """Rebuild the free variable ``expr`` (a closure) as a typed Closure.

   The captured arguments are annotated first, and the closure type is
   derived from their types.
   """
   annotated = annotate_children(expr.args)
   target = expr.fn
   return typed_ast.Closure(
       target, annotated,
       type = closure_type.make_closure_type(target, get_types(annotated)))
예제 #10
0
  def transform_Reduce(self, expr):
    """Tile a Reduce adverb and return the equivalent ``syntax.TiledReduce``.

    Steps:
    * Push the adverb onto the expansion stack (``push_exp``) and propagate
      expansion information into the reduced function's formals.
    * Estimate tile sizes and regenerate the function via
      ``gen_unpack_tree``. A second pass with ``reg_tiling = True`` runs
      when ``config.opt_reg_tile`` is set.
    * Update argument and closure types to match the new function.
    * Unpack ``expr.combine`` when more than one depth is involved.

    Note: mutates in place the ``type`` of each node in ``expr.args``. For a
    closure ``expr.fn``, it also mutates the closure's argument types and its
    ``fn``/``type`` attributes.
    """
    self.num_tiles += 1

    depth = len(self.adverbs_visited)
    closure = expr.fn
    closure_args = []
    fn = closure
    # Peel off a closure wrapper: its bound arguments occupy the leading
    # positions of fn.arg_names, ahead of the adverb arguments.
    if isinstance(fn, syntax.Closure):
      closure_args = closure.args
      fn = closure.fn

    axes = [self.get_num_expansions_at_depth(arg.name, depth) + expr.axis
            for arg in expr.args]
    self.push_exp(syntax.Reduce, AdverbArgs(combine = expr.combine,
                                             init = expr.init,
                                             fn = expr.fn,
                                             args = expr.args,
                                             axis = expr.axis,
                                             axes = axes))
    # Closure formals inherit the expansions of the value they are bound to.
    for fn_arg, adverb_arg in zip(fn.arg_names[:len(closure_args)],
                                  closure_args):
      # get_expansions is keyed by variable name (cf. adverb_arg.name in the
      # loop below). get_closure_arg is assumed to return a Var, as
      # transform_Map also relies on, so take its .name.
      name = self.get_closure_arg(adverb_arg).name
      new_expansions = copy.deepcopy(self.get_expansions(name))
      self.expansions[fn_arg] = new_expansions
    # Adverb formals additionally record an expansion at the current depth.
    for fn_arg, adverb_arg in zip(fn.arg_names[len(closure_args):], expr.args):
      new_expansions = copy.deepcopy(self.get_expansions(adverb_arg.name))
      new_expansions.append(depth)
      self.expansions[fn_arg] = new_expansions

    depths = self.get_depths_list(fn.arg_names)

    # Estimate the tile sizes
    self.estimate_tile_sizes(fn.arg_names, depths)

    new_fn = self.gen_unpack_tree(self.adverbs_visited, depths,
                                  fn.arg_names, fn.body, fn.type_env)
    if config.opt_reg_tile:
      adverb_tree = [get_tiled_version(adv) for adv in self.adverbs_visited]
      new_fn = self.gen_unpack_tree(adverb_tree, depths, fn.arg_names, new_fn,
                                    fn.type_env, reg_tiling = True)

    # The adverb arguments follow the closure arguments in new_fn's inputs.
    for arg, t in zip(expr.args, new_fn.input_types[len(closure_args):]):
      arg.type = t
    init = expr.init # Initial value lifted to proper shape in lowering
    if len(depths) > 1:
      depths.remove(depth)
      new_combine = self.unpack_combine(expr.combine, depths)
    else:
      new_combine = expr.combine
    return_t = new_fn.return_type
    # Re-point the original closure at the tiled function and refresh its type.
    if isinstance(closure, syntax.Closure):
      for c_arg, t in zip(closure.args, new_fn.input_types):
        c_arg.type = t
      closure_arg_types = [arg.type for arg in closure.args]
      closure.fn = new_fn
      closure.type = closure_type.make_closure_type(new_fn, closure_arg_types)
      new_fn = closure
    self.pop_exp()
    return syntax.TiledReduce(fn = new_fn,
                               combine = new_combine,
                               init = init,
                               args = expr.args,
                               axes = axes,
                               type = return_t)
예제 #11
0
  def transform_Map(self, expr):
    """Tile a Map adverb and return the equivalent ``syntax.TiledMap``.

    Steps:
    * Push the adverb onto the expansion stack (``push_exp``) and propagate
      expansion information into the mapped function's formals.
    * Check the function body for nested adverbs with ``FindAdverbs``.
    * With nested adverbs: lift each formal's type by its number of
      expansions and recursively transform the body into a fresh
      ``TypedFn``.
    * Otherwise: estimate tile sizes and generate the function via
      ``gen_unpack_tree``, optionally with a register-tiling second pass.

    Note: mutates in place the ``type`` of each node in ``expr.args``. For a
    closure ``expr.fn``, it also mutates the closure's argument types and its
    ``fn``/``type`` attributes.
    """
    self.num_tiles += 1

    depth = len(self.adverbs_visited)
    closure = expr.fn
    closure_args = []
    fn = closure
    # Peel off a closure wrapper: its bound arguments occupy the leading
    # positions of fn.arg_names, ahead of the adverb arguments.
    if isinstance(fn, syntax.Closure):
      closure_args = closure.args
      fn = closure.fn

    axes = [self.get_num_expansions_at_depth(arg.name, depth) + expr.axis
            for arg in expr.args]
    self.push_exp(syntax.Map, AdverbArgs(expr.fn, expr.args, expr.axis, axes))
    # Closure formals inherit the expansions of the value they are bound to.
    for fn_arg, adverb_arg in zip(fn.arg_names[:len(closure_args)],
                                  closure_args):
      name = self.get_closure_arg(adverb_arg).name
      new_expansions = copy.deepcopy(self.get_expansions(name))
      self.expansions[fn_arg] = new_expansions
    # Adverb formals additionally record an expansion at the current depth.
    for fn_arg, adverb_arg in zip(fn.arg_names[len(closure_args):], expr.args):
      new_expansions = copy.deepcopy(self.get_expansions(adverb_arg.name))
      new_expansions.append(depth)
      self.expansions[fn_arg] = new_expansions

    depths = self.get_depths_list(fn.arg_names)
    # Use the visitor we just created; the previous code referred to an
    # undefined name `find_syntax` and raised NameError here.
    find_adverbs = FindAdverbs()
    find_adverbs.visit_fn(fn)

    if find_adverbs.has_adverbs:
      # Nested adverbs: recursively transform the body with each formal's
      # type lifted by the number of expansions recorded for it.
      arg_names = list(fn.arg_names)
      input_types = []
      self.push_type_env(fn.type_env)
      for arg, t in zip(arg_names, fn.input_types):
        new_type = array_type.increase_rank(t, len(self.get_expansions(arg)))
        input_types.append(new_type)
        self.type_env[arg] = new_type
      exps = self.get_depths_list(fn.arg_names)
      # rank_inc is the index of the first entry of exps that is >= the
      # current depth, or 0 if there is none.
      rank_inc = 0
      for i, exp in enumerate(exps):
        if exp >= depth:
          rank_inc = i
          break
      return_t = array_type.increase_rank(expr.type, rank_inc)
      new_fn = syntax.TypedFn(name = names.fresh("expanded_map_fn"),
                              arg_names = tuple(arg_names),
                              body = self.transform_block(fn.body),
                              input_types = input_types,
                              return_type = return_t,
                              type_env = self.pop_type_env())
      new_fn.has_tiles = True
    else:
      # Estimate the tile sizes
      self.estimate_tile_sizes(fn.arg_names, depths)

      new_fn = self.gen_unpack_tree(self.adverbs_visited, depths, fn.arg_names,
                                    fn.body, fn.type_env)
      if config.opt_reg_tile:
        adverb_tree = [get_tiled_version(adv) for adv in self.adverbs_visited]
        new_fn = self.gen_unpack_tree(adverb_tree, depths, fn.arg_names, new_fn,
                                      fn.type_env, reg_tiling = True)

    # The adverb arguments follow the closure arguments in new_fn's inputs.
    for arg, t in zip(expr.args, new_fn.input_types[len(closure_args):]):
      arg.type = t
    return_t = new_fn.return_type
    # Re-point the original closure at the tiled function and refresh its type.
    if isinstance(closure, syntax.Closure):
      for c_arg, t in zip(closure.args, new_fn.input_types):
        c_arg.type = t
      closure_arg_types = [arg.type for arg in closure.args]
      closure.fn = new_fn
      closure.type = closure_type.make_closure_type(new_fn, closure_arg_types)
      new_fn = closure
    self.pop_exp()
    return syntax.TiledMap(fn = new_fn, args = expr.args, axes = axes,
                            type = return_t)
예제 #12
0
 def transform_Arith(self, expr):
   """Replace the primitive ``expr`` with a Closure over its wrapper function.

   The closure captures nothing; the wrapper comes from
   ``prims.prim_wrapper(expr)``.
   """
   wrapper = prims.prim_wrapper(expr)
   return Closure(wrapper, (),
                  type = closure_type.make_closure_type(wrapper, ()))
예제 #13
0
 def transform_Closure(self, expr):
   """Transform a closure's captured arguments and rebuild its closure type."""
   transformed = self.transform_expr_list(expr.args)
   target = expr.fn
   return Closure(target, transformed,
                  type = closure_type.make_closure_type(target,
                                                        get_types(transformed)))
예제 #14
0
def gen_par_work_function(adverb_class, f, nonlocals, nonlocal_types,
                          args_t, arg_types, dont_slice_position = -1):
  """Build, or fetch from ``_par_wrapper_cache``, the parallel work function.

  The returned wrapper takes ``(start, stop, args, tile_sizes)`` and:
  * reads the fields ``arg0, arg1, ...`` of the ``args`` struct, nonlocals
    first and then the adverb arguments;
  * slices every array-typed adverb argument to ``[start:stop]`` along its
    first axis, except the one whose combined field index equals
    ``dont_slice_position``;
  * calls the (possibly tiled) payload built by ``gen_tiled_wrapper``;
  * stores the result into the ``[start:stop]`` slice of ``args.output``.

  Args:
    adverb_class: adverb class; ``adverb_class.node_type()`` prefixes the
      wrapper's name.
    f: payload function handed to ``gen_tiled_wrapper``.
    nonlocals: not used by this function; kept for interface compatibility.
    nonlocal_types: types of the leading ``argN`` fields of ``args_t``.
    args_t: struct type of the ``args`` parameter.
    arg_types: types of the adverb-argument fields that follow.
    dont_slice_position: combined field index (nonlocals + args) of an
      array argument that must not be sliced; -1 slices every array arg.

  Returns:
    The lowered wrapper, annotated with ``num_tiles``,
    ``dl_tile_estimates`` and ``ml_tile_estimates`` from the payload.
  """
  # The wrapper depends on more than f and arg_types. nonlocal_types
  # determines how many leading argN fields are nonlocals (and is forwarded
  # to gen_tiled_wrapper); dont_slice_position picks the unsliced argument.
  # Both belong in the key, or a wrapper generated for different values
  # could be returned from the cache.
  key = (adverb_class, f.name, tuple(nonlocal_types), tuple(arg_types),
         dont_slice_position, config.opt_tile)
  if key in _par_wrapper_cache:
    return _par_wrapper_cache[key]

  fn = gen_tiled_wrapper(adverb_class, f, arg_types, nonlocal_types)
  num_tiles = fn.num_tiles
  # Construct a typed parallel wrapper function that unpacks the args struct
  # and calls the (possibly tiled) payload function with its slices of the
  # arguments.
  start_var = syntax.Var(names.fresh("start"), type = Int64)
  stop_var = syntax.Var(names.fresh("stop"), type = Int64)
  args_var = syntax.Var(names.fresh("args"), type = args_t)
  tile_type = tuple_type.make_tuple_type([Int64] * num_tiles)
  tile_sizes_var = syntax.Var(names.fresh("tile_sizes"), type = tile_type)
  inputs = [start_var, stop_var, args_var, tile_sizes_var]

  # Each sliced array is indexed with [start:stop:1] on its first axis and
  # syntax_helpers.slice_none (presumably a full ':' slice) on the others.
  slice_t = array_type.make_slice_type(Int64, Int64, Int64)
  arg_slice = syntax.Slice(start_var, stop_var, syntax_helpers.one_i64,
                           type = slice_t)
  def slice_arg(arg, t):
    """Index ``arg`` (whose type is ``t``) by this worker's slice tuple."""
    indices = [arg_slice]
    for _ in xrange(1, arg.type.rank):
      indices.append(syntax_helpers.slice_none)
    tuple_t = tuple_type.make_tuple_type(syntax_helpers.get_types(indices))
    index_tuple = syntax.Tuple(indices, tuple_t)
    result_t = t.index_type(tuple_t)
    return syntax.Index(arg, index_tuple, type = result_t)

  # Manually unpack the args into typed Attributes. Fields arg0..argK hold
  # the nonlocals (never sliced); the following fields hold the adverb args.
  unpacked_args = [syntax.Attribute(args_var, "arg%d" % i, type = t)
                   for i, t in enumerate(nonlocal_types)]
  for i, t in enumerate(arg_types, len(unpacked_args)):
    attr = syntax.Attribute(args_var, ("arg%d" % i), type = t)
    if isinstance(t, array_type.ArrayT) and i != dont_slice_position:
      # TODO: Handle axis.
      unpacked_args.append(slice_arg(attr, t))
    else:
      unpacked_args.append(attr)

  # If tiling, pass in the tile params array.
  if config.opt_tile:
    unpacked_args.append(tile_sizes_var)

  # Make a typed closure that calls the payload function with the arg slices.
  closure_t = closure_type.make_closure_type(fn, [])
  nested_closure = syntax.Closure(fn, [], type = closure_t)
  return_t = fn.return_type
  call = syntax.Call(nested_closure, unpacked_args, type = return_t)

  # output = args.output; output[start:stop, ...] = payload(...)
  output_name = names.fresh("output")
  output_attr = syntax.Attribute(args_var, "output", type = return_t)
  output_var = syntax.Var(output_name, type = output_attr.type)
  output_slice = slice_arg(output_var, return_t)
  body = [syntax.Assign(output_var, output_attr),
          syntax.Assign(output_slice, call),
          syntax.Return(syntax_helpers.none)]
  # NOTE(review): output_var itself is typed return_t, but the environment
  # records output_slice.type; presumably identical for array outputs --
  # confirm.
  type_env = {output_name:output_slice.type}
  for arg in inputs:
    type_env[arg.name] = arg.type

  # Construct the typed wrapper.
  wrapper_name = adverb_class.node_type() + fn.name + "_par"
  parallel_wrapper = \
      syntax.TypedFn(name = names.fresh(wrapper_name),
                     arg_names = [var.name for var in inputs],
                     input_types = syntax_helpers.get_types(inputs),
                     body = body,
                     return_type = core_types.NoneType,
                     type_env = type_env)
  lowered = lowering(parallel_wrapper)

  lowered.num_tiles = num_tiles
  lowered.dl_tile_estimates = fn.dl_tile_estimates
  lowered.ml_tile_estimates = fn.ml_tile_estimates
  _par_wrapper_cache[key] = lowered
  return lowered