Exemplo n.º 1
0
 def transform_IndexReduce(self, expr):
   """Type-specialize an IndexReduce adverb.

   The shape, the map function (``untyped_identity_function`` when
   ``expr.fn`` is empty), the combine function and the optional initial
   value are transformed. The shape is normalized to Int64: either a single
   integer, or a tuple whose elements are all cast to Int64. The functions
   are specialized via ``specialize_IndexReduce`` and a typed IndexReduce
   node is returned; a non-none init is cast to the result type.
   """
   shape = self.transform_expr(expr.shape)
   map_fn_closure = self.transform_fn(expr.fn if expr.fn else untyped_identity_function)
   combine_closure = self.transform_fn(expr.combine)
   init = self.transform_if_expr(expr.init)
   shape_t = shape.type
   if isinstance(shape_t, IntT):
     # a single integer shape describes a one-dimensional index space
     shape = self.cast(shape, Int64)
     n_indices = 1
   else:
     assert isinstance(shape_t, TupleT), \
       "Invalid shape for IndexReduce: %s : %s" % (shape, shape_t)
     assert all(isinstance(t, ScalarT) for t in shape_t.elt_types), \
       "Expected scalar elements in IndexReduce shape, got %s" % (shape_t,)
     n_indices = len(shape_t.elt_types)
     # cast every dimension to Int64 so the index space has a uniform type
     if not all(t == Int64 for t in shape_t.elt_types):
       elts = tuple(self.cast(elt, Int64) for elt in self.tuple_elts(shape))
       shape = self.tuple(elts)
   # NOTE(review): the map fn is passed as a closure *type* while the combine
   # fn and init are passed as expressions (flatten_Reduce passes the map
   # closure itself) -- confirm against specialize_IndexReduce's signature.
   result_type, typed_fn, typed_combine = \
     specialize_IndexReduce(map_fn_closure.type, combine_closure, n_indices, init)
   if not self.is_none(init):
     init = self.cast(init, result_type)
   return syntax.IndexReduce(shape = shape,
                             fn = make_typed_closure(map_fn_closure, typed_fn),
                             combine = make_typed_closure(combine_closure, typed_combine),
                             init = init,
                             type = result_type)
Exemplo n.º 2
0
 def transform_Scan(self, expr):
   """Type-specialize a Scan adverb.

   The map function (``untyped_identity_function`` when ``expr.fn`` is
   empty), combine function and emit function are transformed, the
   arguments are flattened and transformed, and ``specialize_Scan`` computes
   the result type plus typed versions of the three functions for the
   argument types, normalized axes and initial-value type. A typed Scan
   node carrying the transformed axis expression is returned.
   """
   map_closure = self.transform_fn(expr.fn or untyped_identity_function)
   combine_closure = self.transform_fn(expr.combine)
   emit_closure = self.transform_fn(expr.emit)
   args = self.transform_args(expr.args, flat = True)
   in_types = get_types(args)

   # an absent initial value stays as Python None and has no type
   if expr.init:
     init = self.transform_expr(expr.init)
   else:
     init = None
   init_type = None
   if init:
     init_type = get_type(init)

   axis = self.transform_if_expr(expr.axis)
   axes = self.normalize_axes(args, axis)
   specialized = specialize_Scan(map_closure.type,
                                 combine_closure.type,
                                 emit_closure.type,
                                 in_types,
                                 axes,
                                 init_type)
   result_type, typed_map, typed_combine, typed_emit = specialized

   # NOTE(review): make_typed_closure below also receives the typed fn;
   # confirm whether this in-place re-pointing of the closures is still needed.
   map_closure.fn = typed_map
   combine_closure.fn = typed_combine
   emit_closure.fn = typed_emit

   typed_map_closure = make_typed_closure(map_closure, typed_map)
   typed_combine_closure = make_typed_closure(combine_closure, typed_combine)
   typed_emit_closure = make_typed_closure(emit_closure, typed_emit)
   return syntax.Scan(fn = typed_map_closure,
                      combine = typed_combine_closure,
                      emit = typed_emit_closure,
                      args = args,
                      axis = axis,
                      type = result_type,
                      init = init)
Exemplo n.º 3
0
 def transform_Scan(self, expr):
   """Type-specialize a Scan adverb (variant without explicit axes).

   The map (``untyped_identity_function`` when ``expr.fn`` is empty),
   combine and emit functions are transformed as expressions and specialized
   for the argument types and initial-value type. When no axis is given the
   inputs must have maximum rank 1 and the scan runs along axis 0.
   """
   map_closure = self.transform_expr(expr.fn or untyped_identity_function)
   combine_closure = self.transform_expr(expr.combine)
   emit_closure = self.transform_expr(expr.emit)
   args = self.transform_args(expr.args, flat = True)
   in_types = get_types(args)

   init = None
   if expr.init:
     init = self.transform_expr(expr.init)
   init_type = None
   if init:
     init_type = get_type(init)

   specialized = specialize_Scan(map_closure.type,
                                 combine_closure.type,
                                 emit_closure.type,
                                 in_types,
                                 init_type)
   result_type, typed_map, typed_combine, typed_emit = specialized
   map_closure.fn = typed_map
   combine_closure.fn = typed_combine
   emit_closure.fn = typed_emit

   axis = self.transform_if_expr(expr.axis)
   missing_axis = axis is None or self.is_none(axis)
   if missing_axis:
     # without an explicit axis only inputs of max rank 1 are accepted
     assert adverb_helpers.max_rank(in_types) == 1
     axis = zero_i64

   typed_map_closure = make_typed_closure(map_closure, typed_map)
   typed_combine_closure = make_typed_closure(combine_closure, typed_combine)
   typed_emit_closure = make_typed_closure(emit_closure, typed_emit)
   return syntax.Scan(fn = typed_map_closure,
                      combine = typed_combine_closure,
                      emit = typed_emit_closure,
                      args = args,
                      axis = axis,
                      type = result_type,
                      init = init)
Exemplo n.º 4
0
  def transform_Reduce(self, expr):
    """Type-specialize a Reduce adverb.

    If every argument is a scalar, the Reduce collapses to one call of the
    map function, cast to the type the combine function would return for two
    such values. Otherwise the map and combine functions are specialized for
    the argument types, normalized axes and init type. When a non-none init
    has a lower rank than the result (e.g. a scalar 0 for an array-valued
    reduction), the first slice of the single argument along ``axis`` is
    combined into init and the reduction runs over the remaining slices.
    """
    assert len(expr.args) > 0, \
      "Can't have Reduce without any arguments %s" % (expr,)
    new_args = self.transform_args(expr.args, flat = True)
    arg_types = get_types(new_args)
    axis = self.transform_if_expr(expr.axis)

    map_fn = self.transform_fn(expr.fn if expr.fn else untyped_identity_function)
    combine_fn = self.transform_fn(expr.combine)

    # if there aren't any arrays, just treat this as a function call
    if all(isinstance(t, ScalarT) for t in arg_types):
      scalar_result = self.invoke(map_fn, new_args)
      cast_t = self.invoke_type(combine_fn, [scalar_result, scalar_result])
      return self.cast(scalar_result, cast_t)

    if self.is_none(expr.init):
      init = none
    else:
      init = self.transform_expr(expr.init)

    axes = self.normalize_axes(new_args, axis)
    result_type, typed_map_fn, typed_combine_fn = \
      specialize_Reduce(map_fn.type,
                        combine_fn.type,
                        arg_types,
                        axes,
                        init.type)

    typed_map_closure = make_typed_closure(map_fn, typed_map_fn)
    typed_combine_closure = make_typed_closure(combine_fn, typed_combine_fn)
    # if we encounter init = 0 for a Reduce which produces an array
    # then we need to broadcast to get an initial value of the appropriate rank
    if init.type.__class__ is not NoneT and \
       init.type != result_type and \
       array_type.rank(init.type) < array_type.rank(result_type):
      # the message needs a %s placeholder: without one, formatting raises
      # TypeError on failure instead of reporting this AssertionError
      assert len(new_args) == 1, \
        "Can't have more than one arg in %s" % (expr,)
      arg = new_args[0]
      first_elt = self.slice_along_axis(arg, axis, zero_i64)
      first_combine = specialize(combine_fn, (init.type, first_elt.type))

      first_combine_closure = make_typed_closure(combine_fn, first_combine)
      # fold the first slice into init, then reduce over the remaining slices
      init = self.call(first_combine_closure, (init, first_elt))
      slice_rest = syntax.Slice(start = one_i64, stop = none, step = one_i64,
                                type = array_type.SliceT(Int64, NoneType, Int64))
      rest = self.slice_along_axis(arg, axis, slice_rest)
      new_args = (rest,)

    return syntax.Reduce(fn = typed_map_closure,
                         combine = typed_combine_closure,
                         args = new_args,
                         axis = axis,
                         type = result_type,
                         init = init)
Exemplo n.º 5
0
  def flatten_Reduce(self, map_fn, combine, x, init):
    """Turn an axis-less reduction over ``x`` into an IndexReduce.

    Builds an untyped function ``idx_map(c0, ..., cK, x, i)``. Its body
    returns ``map_fn``'s underlying function applied to the closure values
    c0..cK followed by ``x[i]``. That function is closed over ``map_fn``'s
    closure values and ``x``. It is then reduced with ``combine`` over the
    index space given by ``self.shape(x)``, using ``self.rank(x)`` indices.

    If ``init`` is not none, it is cast to the typed combine function's
    return type before being stored on the IndexReduce node.
    """
    shape = self.shape(x)
    n_indices = self.rank(x)

    # Build a function from indices which picks out the data elements
    # needed by the original map_fn. Its formal arguments are one fresh name
    # per closure value of map_fn, then the data array, then the index.
    outer_closure_args = self.closure_elts(map_fn)
    args_obj = FormalArgs()
    inner_closure_vars = []
    for i in xrange(len(outer_closure_args)):
      visible_name = "c%d" % i
      name = names.fresh(visible_name)
      args_obj.add_positional(name, visible_name)
      inner_closure_vars.append(Var(name))

    data_arg_name = names.fresh("x")
    data_arg_var = Var(data_arg_name)
    idx_arg_name = names.fresh("i")
    idx_arg_var = Var(idx_arg_name)

    args_obj.add_positional(data_arg_name, "x")
    args_obj.add_positional(idx_arg_name, "i")

    idx_expr = syntax.Index(data_arg_var, idx_arg_var)
    inner_fn = self.get_fn(map_fn)
    fn_call_expr = syntax.Call(inner_fn, tuple(inner_closure_vars) + (idx_expr,))
    idx_fn = UntypedFn(name = names.fresh("idx_map"),
                       args = args_obj,
                       body = [syntax.Return(fn_call_expr)])

    # Close idx_fn over map_fn's closure values plus the data array, leaving
    # only the index argument to be supplied by the IndexReduce.
    outer_closure_args = tuple(outer_closure_args) + (x,)
    idx_closure_t = closure_type.make_closure_type(idx_fn, get_types(outer_closure_args))
    idx_closure = Closure(idx_fn, args = outer_closure_args, type = idx_closure_t)

    result_type, typed_fn, typed_combine = \
      specialize_IndexReduce(idx_closure, combine, n_indices, init)
    if not self.is_none(init):
      init = self.cast(init, typed_combine.return_type)
    return syntax.IndexReduce(shape = shape,
                              fn = make_typed_closure(idx_closure, typed_fn),
                              combine = make_typed_closure(combine, typed_combine),
                              init = init,
                              type = result_type)
Exemplo n.º 6
0
 def transform_Map(self, expr):
   """Type-specialize a Map adverb.

   All-scalar arguments turn the Map into a plain call of its function.
   Tuple arguments must all be tuples of the same length; they are mapped
   position by position, and a tuple of the call results is produced.
   Otherwise the function is specialized for the argument types and
   normalized axes, and a typed Map node is returned.
   """
   closure = self.transform_fn(expr.fn)
   args = self.transform_args(expr.args, flat = True)
   in_types = get_types(args)
   assert len(in_types) > 0, "Map requires array arguments"

   # nothing to map over: a Map of scalars is just a call
   if all(isinstance(t, ScalarT) for t in in_types):
     return self.invoke(closure, args)

   if any(isinstance(t, TupleT) for t in in_types):
     assert all(isinstance(t, TupleT) for t in in_types), \
       "Map doesn't support input types %s" % (in_types,)
     n_elts = len(in_types[0].elt_types)
     assert all(len(t.elt_types) == n_elts for t in in_types[1:]), \
       "Tuple arguments to Map must be of same length"
     # project every position out of every argument first ...
     columns = [[self.tuple_proj(arg, i) for arg in args]
                for i in xrange(n_elts)]
     # ... then call the function once per position
     results = [self.invoke(closure, column) for column in columns]
     return self.tuple(results)

   axis = self.transform_if_expr(expr.axis)
   axes = self.normalize_axes(args, axis)
   result_type, typed_fn = specialize_Map(closure.type, in_types, axes)
   typed_closure = make_typed_closure(closure, typed_fn)
   return syntax.Map(fn = typed_closure,
                     args = args,
                     axis = axis,
                     type = result_type)
Exemplo n.º 7
0
 def transform_OuterMap(self, expr):
   """Type-specialize an OuterMap adverb.

   The function and the flattened arguments are transformed. The function is
   specialized for the argument types and normalized axes, and a typed
   OuterMap node is returned.
   """
   fn_closure = self.transform_fn(expr.fn)
   args = self.transform_args(expr.args, flat = True)
   in_types = get_types(args)
   assert len(in_types) > 0
   axis = self.transform_if_expr(expr.axis)
   axes = self.normalize_axes(args, axis)
   result_type, typed_fn = specialize_OuterMap(fn_closure.type, in_types, axes)
   typed_closure = make_typed_closure(fn_closure, typed_fn)
   return syntax.OuterMap(fn = typed_closure,
                          args = args,
                          axis = axis,
                          type = result_type)
Exemplo n.º 8
0
  def transform_OuterMap(self, expr):
    """Type-specialize an OuterMap adverb (variant without explicit axes).

    The function is transformed as an expression and specialized for the
    argument types. A missing axis defaults to ``zero_i64``.
    """
    fn_closure = self.transform_expr(expr.fn)
    args = self.transform_args(expr.args, flat = True)
    in_types = get_types(args)
    assert len(in_types) > 0
    result_type, typed_fn = specialize_OuterMap(fn_closure.type, in_types)
    axis = self.transform_if_expr(expr.axis)
    axis = zero_i64 if (axis is None or self.is_none(axis)) else axis
    typed_closure = make_typed_closure(fn_closure, typed_fn)
    return syntax.OuterMap(fn = typed_closure,
                           args = args,
                           axis = axis,
                           type = result_type)
Exemplo n.º 9
0
 def transform_IndexMap(self, expr):
   """Type-specialize an IndexMap adverb.

   The shape is normalized to a tuple of Int64 values. A scalar shape is
   first wrapped in a 1-tuple, so the number of indices passed to the
   function is the length of that tuple. The function is specialized for
   that many indices, and a typed IndexMap node is returned.
   """
   shape = self.transform_expr(expr.shape)
   if not isinstance(shape.type, TupleT):
     assert isinstance(shape.type, ScalarT), "Invalid shape for IndexMap: %s : %s" % (shape, shape.type)
     shape = self.tuple((shape,))
   closure = self.transform_fn(expr.fn)
   # A scalar shape was wrapped above, so the shape is always a tuple here;
   # there is no separate single-integer case to handle.
   shape_t = shape.type
   assert isinstance(shape_t, TupleT), "Expected shape to be tuple, instead got %s" % (shape_t,)
   assert all(isinstance(t, ScalarT) for t in shape_t.elt_types)
   n_indices = len(shape_t.elt_types)
   # cast every dimension to Int64 so the index space has a uniform type
   if not all(t == Int64 for t in shape_t.elt_types):
     elts = tuple(self.cast(elt, Int64) for elt in self.tuple_elts(shape))
     shape = self.tuple(elts)
   result_type, typed_fn = specialize_IndexMap(closure.type, n_indices)
   return syntax.IndexMap(shape = shape,
                          fn = make_typed_closure(closure, typed_fn),
                          type = result_type)
Exemplo n.º 10
0
  def transform_Reduce(self, expr):
    """Type-specialize a Reduce adverb (variant without normalized axes).

    If every argument is a scalar, the Reduce is replaced by a single call of
    the map function; the combine function is not applied on that path.

    When no axis is given:
      * inputs of max rank > 1 (only a single input is supported) keep a
        none axis and have their types lowered via array_type.lower_rank, on
        the expectation that the arguments get raveled before the adverb is
        evaluated;
      * otherwise axis 0 is used.

    If an initial value has a lower rank than the result, the element at
    index 0 of the single argument is combined into it, and the reduction
    runs over the remaining elements (``arg[1:]``).
    """
    new_args = self.transform_args(expr.args, flat = True)
    arg_types = get_types(new_args)
    axis = self.transform_if_expr(expr.axis)

    map_fn = self.transform_expr(expr.fn if expr.fn else untyped_identity_function) 
    combine_fn = self.transform_expr(expr.combine)
    
    # init stays Python None when expr.init is empty/falsy
    init = self.transform_expr(expr.init) if expr.init else None
    
    # if there aren't any arrays, just treat this as a function call
    # (init, although transformed above, is ignored on this path)
    if all(isinstance(t, ScalarT) for t in arg_types):
      return self.invoke(map_fn, new_args)
    
    init_type = init.type if init else None
    
    if self.is_none(axis):
      if adverb_helpers.max_rank(arg_types) > 1:
        assert len(new_args) == 1, \
          "Can't handle multiple reduction inputs and flattening from axis=None"
        # Earlier approaches (flattening via flatten_Reduce, or explicit ravel):
        #x = new_args[0]
        #return self.flatten_Reduce(map_fn, combine_fn, x, init)
        #new_args = [self.ravel(new_args[0])]
        #arg_types = get_types(new_args)
        
        # Expect that the arguments will get raveled before 
        # the adverb gets evaluated 
        # NOTE(review): lower_rank(t, t.rank - 1) presumably yields the rank-1
        # (raveled) type -- confirm against array_type.lower_rank.
        axis = self.none
        arg_types = [array_type.lower_rank(t, t.rank - 1) 
                     for t in arg_types
                     if t.rank > 1]
      else:
        axis = self.int(0)                        
    
    result_type, typed_map_fn, typed_combine_fn = \
        specialize_Reduce(map_fn.type,
                          combine_fn.type,
                          arg_types, 
                          init_type)
    typed_map_closure = make_typed_closure (map_fn, typed_map_fn)
    typed_combine_closure = make_typed_closure(combine_fn, typed_combine_fn)
    
    
    # An init of lower rank than the result (e.g. scalar 0 for an
    # array-valued reduction) is combined with the first element so that the
    # initial value has the result's rank; the rest of arg is then reduced.
    if init_type and init_type != result_type and \
       array_type.rank(init_type) < array_type.rank(result_type):
      assert len(new_args) == 1
      # NOTE(review): the first element is taken at index 0 of the outermost
      # dimension regardless of `axis` (the is_zero check below is disabled).
      #assert is_zero(axis), "Unexpected axis %s : %s" % (axis, axis.type)
      arg = new_args[0]
      first_elt = syntax.Index(arg, zero_i64, 
                               type = arg.type.index_type(zero_i64))
      first_combine = specialize(combine_fn, (init_type, first_elt.type))
      first_combine_closure = make_typed_closure(combine_fn, first_combine)
      init = syntax.Call(first_combine_closure, (init, first_elt), 
                                 type = first_combine.return_type)
      # arg[1:] -- everything after the element folded into init
      slice_rest = syntax.Slice(start = one_i64, stop = none, step = one_i64, 
                                   type = array_type.SliceT(Int64, NoneType, Int64))
      rest = syntax.Index(arg, slice_rest, 
                             type = arg.type.index_type(slice_rest))
      new_args = (rest,)  
    
    return syntax.Reduce(fn = typed_map_closure,
                         combine = typed_combine_closure,
                         args = new_args,
                         axis = axis,
                         type = result_type,
                         init = init)