Пример #1
0
 def visit_Reduce(self, expr):
   """
   Evaluate a Reduce node under shape semantics.

   Visits the map function, the combine function, the argument list and
   (when present) the initial value, unwraps the constant axis, and
   delegates to shape_semantics.eval_reduce.
   """
   map_fn = self.visit_expr(expr.fn)
   combine_fn = self.visit_expr(expr.combine)
   arg_shapes = self.visit_expr_list(expr.args)
   if expr.init:
     init_shape = self.visit_expr(expr.init)
   else:
     init_shape = None
   reduce_axis = unwrap_constant(expr.axis)
   return shape_semantics.eval_reduce(map_fn, combine_fn, init_shape,
                                      arg_shapes, reduce_axis)
Пример #2
0
 def expr_Reduce():
   """
   Evaluate a Reduce node: evaluate its map/combine functions, arguments and
   optional initial value, then run adverb_evaluator.eval_reduce.
   """
   mapper = eval_expr(expr.fn)
   combiner = eval_expr(expr.combine)
   arg_values = eval_args(expr.args)
   if expr.init:
     init_value = eval_expr(expr.init)
   else:
     init_value = None
   reduce_axis = syntax_helpers.unwrap_constant(expr.axis)
   return adverb_evaluator.eval_reduce(mapper, combiner, init_value,
                                       arg_values, reduce_axis)
Пример #3
0
 def transform_AllPairs(self, expr):
   """
   Transform an AllPairs node and rebuild it via self.eval_allpairs.

   The function and the two argument expressions are each transformed
   exactly once, matching the other adverb transforms (Map/Reduce/Scan).
   """
   fn = self.transform_expr(expr.fn)
   args = self.transform_expr_list(expr.args)
   assert len(args) == 2
   # `args` is already transformed; unpack it directly instead of running it
   # through transform_expr_list again, which would re-transform the output
   # of the transformation.
   x, y = args
   axis = syntax_helpers.unwrap_constant(expr.axis)
   return self.eval_allpairs(fn, x, y, axis)
Пример #4
0
 def transform_Reduce(self, expr):
   """
   Transform every component of a Reduce node (map fn, args, combine fn,
   init) and rebuild it via self.eval_reduce with the unwrapped axis.
   """
   new_fn = self.transform_expr(expr.fn)
   new_args = self.transform_expr_list(expr.args)
   new_combine = self.transform_expr(expr.combine)
   # init goes through transform_if_expr rather than transform_expr.
   new_init = self.transform_if_expr(expr.init)
   reduce_axis = syntax_helpers.unwrap_constant(expr.axis)
   return self.eval_reduce(new_fn, new_combine, new_init,
                           new_args, reduce_axis)
Пример #5
0
 def expr_Scan():
   """
   Evaluate a Scan node.

   Evaluates the map, combine and emit functions, the arguments and the
   optional initial value, then runs adverb_evaluator.eval_scan.
   """
   map_fn = eval_expr(expr.fn)
   combine = eval_expr(expr.combine)
   emit = eval_expr(expr.emit)
   args = eval_args(expr.args)
   # init is optional: the Reduce evaluator and the Scan type-inference pass
   # both treat a missing init as None, so don't evaluate an absent one.
   init = eval_expr(expr.init) if expr.init else None
   axis = syntax_helpers.unwrap_constant(expr.axis)
   return adverb_evaluator.eval_scan(map_fn, combine, emit, init, args, axis)
Пример #6
0
  def transform_Scan(self, expr):
    """
    Transform every component of a Scan node (map fn, args, combine fn,
    emit fn, init) and rebuild it via self.eval_scan with the unwrapped axis.
    """
    fn = self.transform_expr(expr.fn)
    args = self.transform_expr_list(expr.args)
    combine = self.transform_expr(expr.combine)
    # (A leftover debug `print expr.emit` used to sit here; it wrote to stdout
    # on every transform and is Python-2-only syntax.)
    emit = self.transform_expr(expr.emit)
    init = self.transform_if_expr(expr.init)
    axis = syntax_helpers.unwrap_constant(expr.axis)
    return self.eval_scan(fn, combine, emit, init, args, axis)
Пример #7
0
def num_outer_axes(arg_types, axis):
  """
  Return how many outer axes an adverb loops over.

  `arg_types` is either a single core_types.Type (its `.rank` is used) or
  something accepted by max_rank (the maximum rank is used).  `axis` is
  first passed through syntax_helpers.unwrap_constant.

  - axis is None: the adverb loops over all axes, so return the rank.
  - rank <= 0: there is nothing to loop over, so return the rank.
  - otherwise: exactly one axis is looped over, so return 1.
  """
  axis = syntax_helpers.unwrap_constant(axis)
  if isinstance(arg_types, core_types.Type):
    rank = arg_types.rank
  else:
    rank = max_rank(arg_types)
  if axis is None or rank <= 0:
    return rank
  return 1
Пример #8
0
 def expr_AllPairs():
   """
   Type-annotate an AllPairs node: annotate the function and its two
   arguments, specialize the function for the pair of argument types, and
   build a typed adverbs.AllPairs node.
   """
   fn_closure = annotate_child(expr.fn)
   typed_args = annotate_args(expr.args, flat = True)
   arg_types = get_types(typed_args)
   assert len(arg_types) == 2
   x_type, y_type = arg_types
   result_type, typed_fn = specialize_AllPairs(fn_closure.type, x_type, y_type)
   pair_axis = unwrap_constant(expr.axis)
   typed_closure = make_typed_closure(fn_closure, typed_fn)
   return adverbs.AllPairs(typed_closure,
                           args = typed_args,
                           axis = pair_axis,
                           type = result_type)
Пример #9
0
  def expr_Map():
    """
    Type-annotate a Map node: annotate the function and arguments,
    specialize the function for the argument types, and build a typed
    adverbs.Map.  A missing axis becomes 0 when the maximum argument rank
    is 1.
    """
    fn_closure = annotate_child(expr.fn)
    typed_args = annotate_args(expr.args, flat = True)
    map_axis = unwrap_constant(expr.axis)
    arg_types = get_types(typed_args)
    result_type, typed_fn = specialize_Map(fn_closure.type, arg_types)

    # With no explicit axis and inputs of rank at most 1, axis 0 is the only
    # axis available, so record it explicitly.
    if map_axis is None and adverb_helpers.max_rank(arg_types) == 1:
      map_axis = 0
    typed_closure = make_typed_closure(fn_closure, typed_fn)
    return adverbs.Map(fn = typed_closure,
                       args = typed_args,
                       axis = map_axis,
                       type = result_type)
Пример #10
0
 def expr_Reduce():
   """
   Type-annotate a Reduce node.

   Annotates the map function, combine function, arguments and (optional)
   initial value; specializes the map/combine functions with
   specialize_Reduce; and returns a typed adverbs.Reduce node.  When the
   initial value has a lower rank than the result, the first element of the
   single argument is folded into init up front, and only the remaining
   elements are reduced.
   """
   map_fn = annotate_child(expr.fn)
   combine_fn = annotate_child(expr.combine)
   new_args = annotate_args(expr.args, flat = True)
   arg_types = get_types(new_args)
   init = annotate_child(expr.init) if expr.init else None
   init_type = init.type if init else None
   result_type, typed_map_fn, typed_combine_fn = \
       specialize_Reduce(map_fn.type,
                         combine_fn.type,
                         arg_types,
                         init_type)
   typed_map_closure = make_typed_closure (map_fn, typed_map_fn)
   typed_combine_closure = make_typed_closure(combine_fn, typed_combine_fn)
   axis = unwrap_constant(expr.axis)
   # No explicit axis and a maximum argument rank of 1: use axis 0 explicitly.
   if axis is None and adverb_helpers.max_rank(arg_types) == 1:
     axis = 0
   # The init value has a lower rank than the reduction result, so it cannot
   # serve directly as the accumulator.  Instead compute
   #     init' = combine(init, arg[0])
   # with combine specialized for (init_type, element type), then reduce
   # over arg[1:] starting from init'.  This is only supported for a single
   # argument reduced along axis 0.  The asserts enforcing that are stripped
   # under `python -O`.
   if init_type and init_type != result_type and \
      array_type.rank(init_type) < array_type.rank(result_type):
     assert len(new_args) == 1
     assert axis == 0
     arg = new_args[0]
     # arg[0]: the first element along the reduced axis.
     first_elt = typed_ast.Index(arg, zero_i64,
                                 type = arg.type.index_type(zero_i64))
     # NOTE(review): specialize() receives the annotated closure node here,
     # whereas specialize_Reduce above receives `.type` values -- confirm
     # this is the argument specialize() expects.
     first_combine = specialize(combine_fn, (init_type, first_elt.type))
     first_combine_closure = make_typed_closure(combine_fn, first_combine)
     init = typed_ast.Call(first_combine_closure, (init, first_elt),
                                type = first_combine.return_type)
     # arg[1:]: the elements not yet folded into init.
     slice_rest = typed_ast.Slice(start = one_i64, stop = none, step = one_i64,
                                  type = array_type.SliceT(Int64, NoneType, Int64))
     rest = typed_ast.Index(arg, slice_rest,
                            type = arg.type.index_type(slice_rest))
     new_args = (rest,)
   return adverbs.Reduce(fn = typed_map_closure,
                         combine = typed_combine_closure,
                         args = new_args,
                         axis = axis,
                         type = result_type,
                         init = init)
Пример #11
0
 def expr_Scan():
   """
   Type-annotate a Scan node: annotate the map, combine and emit functions,
   the arguments and the optional init value; specialize all three
   functions with specialize_Scan; and build a typed adverbs.Scan.
   """
   map_closure = annotate_child(expr.fn)
   combine_closure = annotate_child(expr.combine)
   emit_closure = annotate_child(expr.emit)
   typed_args = annotate_args(expr.args, flat = True)
   arg_types = get_types(typed_args)
   if expr.init:
     typed_init = annotate_child(expr.init)
   else:
     typed_init = None
   init_type = get_type(typed_init) if typed_init else None
   specialized = specialize_Scan(map_closure.type, combine_closure.type,
                                 emit_closure.type, arg_types, init_type)
   result_type, typed_map_fn, typed_combine_fn, typed_emit_fn = specialized
   # The annotated closures are also updated in place to point at their
   # specialized functions.
   for closure_node, specialized_fn in ((map_closure, typed_map_fn),
                                        (combine_closure, typed_combine_fn),
                                        (emit_closure, typed_emit_fn)):
     closure_node.fn = specialized_fn
   scan_axis = unwrap_constant(expr.axis)
   return adverbs.Scan(fn = make_typed_closure(map_closure, typed_map_fn),
                       combine = make_typed_closure(combine_closure,
                                                    typed_combine_fn),
                       emit = make_typed_closure(emit_closure, typed_emit_fn),
                       args = typed_args,
                       axis = scan_axis,
                       type = result_type,
                       init = typed_init)
Пример #12
0
 def visit_Map(self, expr):
   """
   Evaluate a Map node under shape semantics: visit the arguments and the
   function, then delegate to shape_semantics.eval_map with the unwrapped
   axis.
   """
   shapes_of_args = self.visit_expr_list(expr.args)
   fn_value = self.visit_expr(expr.fn)
   map_axis = unwrap_constant(expr.axis)
   return shape_semantics.eval_map(fn_value, shapes_of_args, map_axis)
Пример #13
0
  def transform_TiledReduce(self, expr):
    """
    Lower a TiledReduce node into explicit tile loops appended to self.blocks.

    The emitted code:
      1. zero-fills the accumulator (see init_unpack);
      2. loops i over [0, niters) in steps of the tile size, where
         niters = self.shape(args[0], axes[0]).  Each iteration slices every
         argument along its axis to the tile [i, i + step), inlines the tiled
         map function on those slices, and inlines the tiled combine
         function to fold the tile's value into the accumulator.

    With a fixed tile size > 1 and config.opt_reg_tiles_not_tile_size_dependent
    set, step 2 is split into two loops: one over whole tiles (no min()
    clamp), then a step-1 loop over the leftover iterations.

    Returns the accumulator variable `result`.
    """
    args = expr.args
    axes = expr.axes

    # TODO: Should make sure that all the shapes conform here,
    # but we don't yet have anything like assertions or error handling.
    niters = self.shape(args[0], syntax_helpers.unwrap_constant(axes[0]))

    # Tile size is either the next constant from self.fixed_tile_sizes, or a
    # runtime value read from self.tile_sizes_param.  The latter also marks
    # this function as tiled.
    if expr.fixed_tile_size:
      self.fixed_idx += 1
      tile_size = syntax_helpers.const(self.fixed_tile_sizes[self.fixed_idx])
    else:
      self.tiling = True
      self.fn.has_tiles = True
      self.nesting_idx += 1
      tile_size = self.index(self.tile_sizes_param, self.nesting_idx,
                             temp = True, name = "tilesize")

    slice_t = array_type.make_slice_type(Int64, Int64, Int64)

    untiled_map_fn = self.get_fn(expr.fn)

    # Non-scalar accumulators are combined via an output slice with
    # result_var=None (see make_loop).  Scalar accumulators are threaded
    # through the loop's `merge` dict.
    acc_type = untiled_map_fn.return_type
    acc_is_array = not isinstance(acc_type, ScalarT)

    tiled_map_fn = self.transform_TypedFn(untiled_map_fn)
    map_closure_args = [self.get_closure_arg(e)
                        for e in self.closure_elts(expr.fn)]

    untiled_combine = self.get_fn(expr.combine)
    # NOTE(review): unlike map_closure_args, the closure arguments of
    # expr.combine are never collected.  A combine fn with captured values
    # would be inlined without them -- confirm combine closures are always
    # empty here.
    combine_closure_args = []

    # The second positional argument is acc_is_array; TiledMap passes the
    # analogous flag as `preallocate_output` -- presumably the same parameter.
    tiled_combine = self.transform_TypedFn(untiled_combine, acc_is_array)
    # Reuse the caller's output variable for array accumulators; otherwise
    # create a new output array from the map fn and its arguments.
    if self.output_var and acc_is_array:
      result = self.output_var
    else:
      shape_args = map_closure_args + args
      result = self._create_output_array(untiled_map_fn, shape_args,
                                         [], "loop_result")
    init = result
    rslt_t = result.type

    # For a scalar accumulator, the value entering the loop is a separate
    # fresh variable (result_before); see the `merge` entry below.
    if not acc_is_array:
      result_before = self.fresh_var(rslt_t, "result_before")
      init = result_before

    # Lift the initial value and fill it.
    def init_unpack(i, cur):
      """
      Return a statement that zero-fills `cur`.

      For i == 0 this is a single Assign of syntax_helpers.zero_f64.
      Otherwise it is a ForLoop over axis 0 of `cur` whose body recursively
      fills each slice with i - 1 remaining levels.

      NOTE(review): the fill value is always a float64 zero, regardless of
      the accumulator's element type and of expr.init's actual value --
      confirm this is intended.
      """
      if i == 0:
        return syntax.Assign(cur, syntax_helpers.zero_f64)
      else:
        j = self.fresh_i64("j")
        start = zero_i64
        stop = self.shape(cur, 0)

        self.blocks.push()
        n = self.index_along_axis(cur, 0, j)
        self.blocks += init_unpack(i-1, n)
        body = self.blocks.pop()

        return syntax.ForLoop(j, start, stop, one_i64, body, {})
    # Number of loop levels for the zero-fill: accumulator rank minus
    # expr.init's rank.
    # NOTE(review): this dereferences expr.init.type and fails if expr.init
    # is None.
    num_exps = array_type.get_rank(init.type) - \
               array_type.get_rank(expr.init.type)

    # TODO: Get rid of this when safe to do so.
    # NOTE: `or True` makes this condition always true, so the zero-fill is
    # always emitted.
    if not expr.fixed_tile_size or True:
      self.comment("TiledReduce in %s: init_unpack" % self.fn.name)
      self.blocks += init_unpack(num_exps, init)

    # Loop over the remaining tiles.
    # For scalar accumulators, `merge` maps the result name to
    # (value before the loop, value produced by each iteration) and is
    # handed to syntax.ForLoop -- presumably phi-style loop-carried state.
    merge = {}

    if not acc_is_array:
      result_after = self.fresh_var(rslt_t, "result_after")
      merge[result.name] = (result_before, result_after)

    def make_loop(start, stop, step, do_min = True):
      """
      Build one ForLoop over tiles: i runs from start to stop by step.

      With do_min, each tile's upper bound is clamped to `stop`, so the last
      tile may be partial.  Without it, the caller guarantees whole tiles.

      NOTE(review): when the split two-loop schedule is used, both loops
      share the same `merge` dict and the same result_before/result_after
      variables -- confirm the loop representation tolerates that.
      """
      i = self.fresh_var(niters.type, "i")
      self.blocks.push()
      slice_stop = self.add(i, step, "next_bound")
      slice_stop_min = self.min(slice_stop, stop) if do_min \
                       else slice_stop

      tile_bounds = syntax.Slice(i, slice_stop_min, one_i64, type = slice_t)
      nested_args = [self.index_along_axis(arg, axis, tile_bounds)
                     for arg, axis in zip(args, axes)]

      new_acc = self.fresh_var(tiled_map_fn.return_type, "new_acc")
      self.comment("TiledReduce in %s: map_fn " % self.fn.name)
      do_inline(tiled_map_fn,
                map_closure_args + nested_args,
                self.type_env,
                self.blocks.top(),
                result_var = new_acc)

      loop_body = self.blocks.pop()
      if acc_is_array:
        # NOTE(review): these helpers run after the loop body block was
        # popped, so any statements they emit land in the enclosing block
        # rather than inside the loop.
        outidx = self.tuple([syntax_helpers.slice_none] * result.type.rank)
        result_slice = self.index(result, outidx, temp = False)
        self.comment("")
        do_inline(tiled_combine,
                  combine_closure_args + [result, new_acc, result_slice],
                  self.type_env,
                  loop_body,
                  result_var = None)
      else:
        do_inline(tiled_combine,
                  combine_closure_args + [result, new_acc],
                  self.type_env, loop_body,
                  result_var = result_after)
      return syntax.ForLoop(i, start, stop, step, loop_body, merge)

    assert isinstance(tile_size, syntax.Expr), "%s not an expr" % tile_size

    self.comment("TiledReduce in %s: combine" % self.fn.name)

    # Split schedule: whole tiles in [0, tile_stop) without the min() clamp,
    # then a step-1 loop for the leftover iterations.  loop2 starts from
    # loop1's index variable (presumably equal to tile_stop once loop1 ends).
    if expr.fixed_tile_size and \
       config.opt_reg_tiles_not_tile_size_dependent and \
       syntax_helpers.unwrap_constant(tile_size) > 1:
      num_tiles = self.div(niters, tile_size, "num_tiles")
      tile_stop = self.mul(num_tiles, tile_size, "tile_stop")
      loop1 = make_loop(zero_i64, tile_stop, tile_size, False)
      self.blocks.append(loop1)
      loop2_start = self.assign_temp(loop1.var, "loop2_start")
      self.blocks.append(make_loop(loop2_start, niters, one_i64, False))
    else:
      self.blocks.append(make_loop(zero_i64, niters, tile_size))

    return result
Пример #14
0
  def transform_TiledMap(self, expr):
    """
    Lower a TiledMap node into explicit tile loops appended to self.blocks.

    Loops i over [0, niters) in steps of the tile size, where
    niters = self.shape(expr.args[0], axes[0]).  Each iteration slices every
    argument along its axis, plus the output array, to the tile
    [i, i + step).  It then inlines the tiled inner function with those
    slices (and self.tile_sizes_param when the inner function is itself
    tiled), using result_var=None.  With a fixed tile size > 1 and
    config.opt_reg_tiles_not_tile_size_dependent set, this is split into a
    loop over whole tiles followed by a step-1 loop over the leftovers.

    Returns the output array: self.output_var when it is set and the inner
    function returns a non-scalar, otherwise a newly created array.
    """
    args = expr.args
    axes = expr.axes

    # TODO: Should make sure that all the shapes conform here,
    # but we don't yet have anything like assertions or error handling
    niters = self.shape(expr.args[0],
                        syntax_helpers.unwrap_constant(axes[0]))

    # Pick the tile size: either the next constant from
    # self.fixed_tile_sizes, or a runtime value read from
    # self.tile_sizes_param (which also marks this function as tiled).
    if expr.fixed_tile_size:
      self.fixed_idx += 1
      tile_size = syntax_helpers.const(self.fixed_tile_sizes[self.fixed_idx])
    else:
      self.tiling = True
      self.fn.has_tiles = True
      self.nesting_idx += 1
      tile_size = self.index(self.tile_sizes_param, self.nesting_idx,
                             temp = True, name = "tilesize")

    # Non-scalar inner functions are transformed with preallocate_output.
    untiled_inner_fn = self.get_fn(expr.fn)
    if isinstance(untiled_inner_fn.return_type, ScalarT):
      tiled_inner_fn = self.transform_TypedFn(untiled_inner_fn)
    else:
      tiled_inner_fn = self.transform_TypedFn(untiled_inner_fn,
                                              preallocate_output = True)

    nested_has_tiles = tiled_inner_fn.has_tiles

    # Increase the nesting_idx by the number of tiles in the nested fn
    self.nesting_idx += tiled_inner_fn.num_tiles

    slice_t = array_type.make_slice_type(Int64, Int64, Int64)

    closure_args = [self.get_closure_arg(e)
                    for e in self.closure_elts(expr.fn)]

    # Reuse the caller's output variable for non-scalar results; otherwise
    # create a new output array from the inner fn and its arguments.
    if self.output_var and \
       not isinstance(untiled_inner_fn.return_type, ScalarT):
      array_result = self.output_var
    else:
      shape_args = closure_args + expr.args
      array_result = self._create_output_array(untiled_inner_fn, shape_args,
                                               [], "array_result")

    assert self.output_var is None or \
           self.output_var.type.__class__ is ArrayT, \
           "Invalid output var %s : %s" % \
           (self.output_var, self.output_var.type)

    def make_loop(start, stop, step, do_min = True):
      """
      Build one ForLoop over tiles: i runs from start to stop by step.

      With do_min, each tile's upper bound is clamped (to niters), so the
      last tile may be partial.  Without it, the caller guarantees whole
      tiles.
      """
      i = self.fresh_var(niters.type, "i")

      self.blocks.push()
      slice_stop = self.add(i, step, "slice_stop")
      # Clamped against niters rather than `stop`.  This is equivalent in
      # practice: do_min is only True for the single-loop schedule, whose
      # stop is niters.
      slice_stop_min = self.min(slice_stop, niters, "slice_min") if do_min \
                       else slice_stop

      tile_bounds = syntax.Slice(i, slice_stop_min, one_i64, type = slice_t)
      nested_args = [self.index_along_axis(arg, axis, tile_bounds)
                     for arg, axis in zip(args, axes)]
      # NOTE(review): out_idx is a tile counter (fixed_idx / nesting_idx,
      # read when make_loop is called, i.e. after nesting_idx was advanced
      # by num_tiles), yet it is passed as the *axis* argument of
      # index_along_axis -- confirm this is intended rather than an axis
      # taken from `axes`.
      out_idx = self.fixed_idx if expr.fixed_tile_size else self.nesting_idx
      output_region = self.index_along_axis(array_result, out_idx, tile_bounds)
      nested_args.append(output_region)

      if nested_has_tiles:
        nested_args.append(self.tile_sizes_param)
      body = self.blocks.pop()
      do_inline(tiled_inner_fn,
                closure_args + nested_args,
                self.type_env,
                body,
                result_var = None)
      return syntax.ForLoop(i, start, stop, step, body, {})

    assert isinstance(tile_size, syntax.Expr)
    self.comment("TiledMap in %s" % self.fn.name)

    # Split schedule: whole tiles in [0, tile_stop) without the min() clamp,
    # then a step-1 loop for the leftover iterations.  loop2 starts from
    # loop1's index variable (presumably equal to tile_stop once loop1 ends).
    if expr.fixed_tile_size and \
       config.opt_reg_tiles_not_tile_size_dependent and \
       syntax_helpers.unwrap_constant(tile_size) > 1:
      num_tiles = self.div(niters, tile_size, "num_tiles")
      tile_stop = self.mul(num_tiles, tile_size, "tile_stop")
      loop1 = make_loop(zero_i64, tile_stop, tile_size, False)
      self.blocks.append(loop1)
      loop2_start = self.assign_temp(loop1.var, "loop2_start")
      self.blocks.append(make_loop(loop2_start, niters, one_i64, False))
    else:
      self.blocks.append(make_loop(zero_i64, niters, tile_size))
    return array_result
Пример #15
0
 def expr_AllPairs():
   """
   Evaluate an AllPairs node: evaluate the function and its two arguments,
   then run adverb_evaluator.eval_allpairs.
   """
   pair_fn = eval_expr(expr.fn)
   left, right = eval_args(expr.args)
   pair_axis = syntax_helpers.unwrap_constant(expr.axis)
   return adverb_evaluator.eval_allpairs(pair_fn, left, right, pair_axis)
Пример #16
0
 def transform_Map(self, expr, output = None):
   """
   Transform the function and arguments of a Map node and rebuild it via
   self.eval_map with the unwrapped axis, forwarding `output` unchanged.
   """
   new_fn = self.transform_expr(expr.fn)
   new_args = self.transform_expr_list(expr.args)
   map_axis = syntax_helpers.unwrap_constant(expr.axis)
   return self.eval_map(new_fn, new_args, map_axis, output = output)
Пример #17
0
def get_axis(kwargs):
  """
  Return kwargs['axis'] (or 0 when the key is absent), passed through
  syntax_helpers.unwrap_constant.
  """
  if 'axis' in kwargs:
    raw_axis = kwargs['axis']
  else:
    raw_axis = 0
  return syntax_helpers.unwrap_constant(raw_axis)
Пример #18
0
 def expr_Map():
   """
   Evaluate a Map node: evaluate the function and arguments, then run
   adverb_evaluator.eval_map with the unwrapped axis.
   """
   mapped_fn = eval_expr(expr.fn)
   arg_values = eval_args(expr.args)
   map_axis = syntax_helpers.unwrap_constant(expr.axis)
   return adverb_evaluator.eval_map(mapped_fn, arg_values, map_axis)