def transform_IndexReduce(self, expr):
    """Specialize an untyped IndexReduce into a typed IndexReduce node.

    Normalizes the shape expression to an int64 scalar (or tuple of
    int64s), specializes the map and combine functions against the
    index arity, and casts the initial value to the result type when
    one is given.
    """
    shape = self.transform_expr(expr.shape)
    map_closure = self.transform_fn(expr.fn if expr.fn else untyped_identity_function)
    combine_closure = self.transform_fn(expr.combine)
    init = self.transform_if_expr(expr.init)

    shape_t = shape.type
    if isinstance(shape_t, IntT):
        # single scalar dimension: coerce to int64, one index variable
        shape = self.cast(shape, Int64)
        n_indices = 1
    else:
        assert isinstance(shape_t, TupleT)
        assert all(isinstance(t, ScalarT) for t in shape_t.elt_types)
        n_indices = len(shape_t.elt_types)
        # coerce every dimension to int64 if any element differs
        if any(t != Int64 for t in shape_t.elt_types):
            coerced = tuple(self.cast(elt, Int64)
                            for elt in self.tuple_elts(shape))
            shape = self.tuple(coerced)

    result_type, typed_fn, typed_combine = \
        specialize_IndexReduce(map_closure.type, combine_closure,
                               n_indices, init)
    if not self.is_none(init):
        init = self.cast(init, result_type)
    return syntax.IndexReduce(shape = shape,
                              fn = make_typed_closure(map_closure, typed_fn),
                              combine = make_typed_closure(combine_closure,
                                                           typed_combine),
                              init = init,
                              type = result_type)
def transform_Scan(self, expr):
    """Specialize an untyped Scan into a typed Scan node.

    The map / combine / emit closures are specialized against the
    argument element types and normalized axes; each untyped closure is
    then patched to point at its typed function and re-wrapped as a
    typed closure.
    """
    mapper = self.transform_fn(expr.fn if expr.fn else untyped_identity_function)
    combiner = self.transform_fn(expr.combine)
    emitter = self.transform_fn(expr.emit)

    typed_args = self.transform_args(expr.args, flat = True)
    arg_types = get_types(typed_args)

    # init is optional; only compute its type when present
    init = self.transform_expr(expr.init) if expr.init else None
    init_type = get_type(init) if init else None

    axis = self.transform_if_expr(expr.axis)
    axes = self.normalize_axes(typed_args, axis)

    result_type, typed_map_fn, typed_combine_fn, typed_emit_fn = \
        specialize_Scan(mapper.type, combiner.type, emitter.type,
                        arg_types, axes, init_type)
    # patch each untyped closure to reference its specialized function
    mapper.fn = typed_map_fn
    combiner.fn = typed_combine_fn
    emitter.fn = typed_emit_fn
    return syntax.Scan(fn = make_typed_closure(mapper, typed_map_fn),
                       combine = make_typed_closure(combiner, typed_combine_fn),
                       emit = make_typed_closure(emitter, typed_emit_fn),
                       args = typed_args,
                       axis = axis,
                       init = init,
                       type = result_type)
def transform_Scan(self, expr):
    """Specialize an untyped Scan into a typed Scan node.

    NOTE(review): duplicate of an earlier transform_Scan definition in
    this file; being defined later, it shadows the earlier one if both
    live in the same class -- confirm which version is intended. Unlike
    the other version, this one calls transform_expr (not transform_fn)
    on the function arguments, passes no per-argument axes to
    specialize_Scan, and defaults a missing axis to 0 (asserting rank-1
    inputs in that case).
    """
    map_fn = self.transform_expr(expr.fn if expr.fn else untyped_identity_function)
    combine_fn = self.transform_expr(expr.combine)
    emit_fn = self.transform_expr(expr.emit)
    new_args = self.transform_args(expr.args, flat = True)
    arg_types = get_types(new_args)
    # init is optional; only compute its type when present
    init = self.transform_expr(expr.init) if expr.init else None
    init_type = get_type(init) if init else None
    result_type, typed_map_fn, typed_combine_fn, typed_emit_fn = \
        specialize_Scan(map_fn.type, combine_fn.type, emit_fn.type,
                        arg_types, init_type)
    # patch each untyped closure to reference its specialized function
    map_fn.fn = typed_map_fn
    combine_fn.fn = typed_combine_fn
    emit_fn.fn = typed_emit_fn
    axis = self.transform_if_expr(expr.axis)
    if axis is None or self.is_none(axis):
        # no axis given: only legal when max input rank is 1; scan axis 0
        assert adverb_helpers.max_rank(arg_types) == 1
        axis = zero_i64
    return syntax.Scan(fn = make_typed_closure(map_fn, typed_map_fn),
                       combine = make_typed_closure(combine_fn, typed_combine_fn),
                       emit = make_typed_closure(emit_fn, typed_emit_fn),
                       args = new_args,
                       axis = axis,
                       type = result_type,
                       init = init)
def transform_Reduce(self, expr):
    """Specialize an untyped Reduce into a typed Reduce node.

    - All-scalar arguments degenerate into a plain function call whose
      result is cast to the combine function's result type.
    - When an init value of lower rank than the result is given, it is
      first combined with the initial slice of the (single) argument so
      the accumulator has the right rank, and the reduction then runs
      over the remaining slices [1:].
    """
    assert len(expr.args) > 0, \
        "Can't have Reduce without any arguments %s" % expr
    new_args = self.transform_args(expr.args, flat = True)
    arg_types = get_types(new_args)
    axis = self.transform_if_expr(expr.axis)
    map_fn = self.transform_fn(expr.fn if expr.fn else untyped_identity_function)
    combine_fn = self.transform_fn(expr.combine)

    # if there aren't any arrays, just treat this as a function call
    if all(isinstance(t, ScalarT) for t in arg_types):
        scalar_result = self.invoke(map_fn, new_args)
        cast_t = self.invoke_type(combine_fn, [scalar_result, scalar_result])
        return self.cast(scalar_result, cast_t)

    if self.is_none(expr.init):
        init = none
    else:
        init = self.transform_expr(expr.init)

    axes = self.normalize_axes(new_args, axis)
    result_type, typed_map_fn, typed_combine_fn = \
        specialize_Reduce(map_fn.type, combine_fn.type,
                          arg_types, axes, init.type)
    typed_map_closure = make_typed_closure(map_fn, typed_map_fn)
    typed_combine_closure = make_typed_closure(combine_fn, typed_combine_fn)

    # if we encounter init = 0 for a Reduce which produces an array
    # then need to broadcast to get an initial value of the appropriate rank
    if init.type.__class__ is not NoneT and \
       init.type != result_type and \
       array_type.rank(init.type) < array_type.rank(result_type):
        # BUGFIX: the message previously read "..." % expr with no %s
        # placeholder, so the formatting itself raised TypeError instead
        # of producing the assertion message
        assert len(new_args) == 1, "Can't have more than one arg in %s" % expr
        arg = new_args[0]
        # fold the first slice into init to lift it to the result rank
        first_elt = self.slice_along_axis(arg, axis, zero_i64)
        first_combine = specialize(combine_fn, (init.type, first_elt.type))
        first_combine_closure = make_typed_closure(combine_fn, first_combine)
        init = self.call(first_combine_closure, (init, first_elt))
        # reduce over the remaining slices [1:]
        slice_rest = syntax.Slice(start = one_i64, stop = none, step = one_i64,
                                  type = array_type.SliceT(Int64, NoneType,
                                                           Int64))
        rest = self.slice_along_axis(arg, axis, slice_rest)
        new_args = (rest,)
    return syntax.Reduce(fn = typed_map_closure,
                         combine = typed_combine_closure,
                         args = new_args,
                         axis = axis,
                         type = result_type,
                         init = init)
def flatten_Reduce(self, map_fn, combine, x, init):
    """Turn an axis-less reduction into an IndexReduce.

    Builds a fresh untyped function that takes map_fn's closure values,
    the data array, and an index, and applies map_fn to the element of
    the array at that index. The original closure values plus the array
    are then captured in a new closure over that function, which is
    specialized as the map side of an IndexReduce over the array's full
    index space.
    """
    shape = self.shape(x)
    n_indices = self.rank(x)
    # build a function from indices which picks out the data elements
    # need for the original map_fn
    outer_closure_args = self.closure_elts(map_fn)
    args_obj = FormalArgs()
    inner_closure_vars = []
    # one positional formal per captured closure value, freshly named
    for i in xrange(len(outer_closure_args)):
        visible_name = "c%d" % i
        name = names.fresh(visible_name)
        args_obj.add_positional(name, visible_name)
        inner_closure_vars.append(Var(name))
    # two extra formals: the data array and the index into it
    data_arg_name = names.fresh("x")
    data_arg_var = Var(data_arg_name)
    idx_arg_name = names.fresh("i")
    idx_arg_var = Var(idx_arg_name)
    args_obj.add_positional(data_arg_name, "x")
    args_obj.add_positional(idx_arg_name, "i")
    # body of the new function: map_fn(c0, ..., cN, x[i])
    idx_expr = syntax.Index(data_arg_var, idx_arg_var)
    inner_fn = self.get_fn(map_fn)
    fn_call_expr = syntax.Call(inner_fn,
                               tuple(inner_closure_vars) + (idx_expr,))
    idx_fn = UntypedFn(name = names.fresh("idx_map"),
                       args = args_obj,
                       body = [syntax.Return(fn_call_expr)]
                       )
    #t = closure_type.make_closure_type(typed_fn, get_types(closure_args))
    #return Closure(typed_fn, closure_args, t)
    # close over the original captured values plus the data array itself
    outer_closure_args = tuple(outer_closure_args) + (x,)
    idx_closure_t = closure_type.make_closure_type(idx_fn,
                                                   get_types(outer_closure_args))
    idx_closure = Closure(idx_fn, args = outer_closure_args,
                          type = idx_closure_t)
    result_type, typed_fn, typed_combine = \
        specialize_IndexReduce(idx_closure, combine, n_indices, init)
    if not self.is_none(init):
        # NOTE(review): here init is cast to the combine fn's return type,
        # whereas transform_IndexReduce casts to result_type -- confirm
        # the two are meant to differ
        init = self.cast(init, typed_combine.return_type)
    return syntax.IndexReduce(shape = shape,
                              fn = make_typed_closure(idx_closure, typed_fn),
                              combine = make_typed_closure(combine,
                                                           typed_combine),
                              init = init,
                              type = result_type)
def transform_Map(self, expr):
    """Specialize an untyped Map into a typed node (or a simpler form).

    Scalar-only arguments collapse into a direct function call; tuple
    arguments are mapped component-wise across their elements;
    otherwise a typed Map node is built over the array arguments.
    """
    fn_closure = self.transform_fn(expr.fn)
    typed_args = self.transform_args(expr.args, flat = True)
    arg_types = get_types(typed_args)
    assert len(arg_types) > 0, "Map requires array arguments"

    # no arrays at all: a Map over scalars is just a call
    if all(isinstance(t, ScalarT) for t in arg_types):
        return self.invoke(fn_closure, typed_args)

    if any(isinstance(t, TupleT) for t in arg_types):
        # tuple inputs: every argument must be a tuple of the same
        # length, and the Map applies the function component-wise
        assert all(isinstance(t, TupleT) for t in arg_types), \
            "Map doesn't support input types %s" % (arg_types,)
        nelts = len(arg_types[0].elt_types)
        assert all(len(t.elt_types) == nelts for t in arg_types[1:]), \
            "Tuple arguments to Map must be of same length"
        components = [self.invoke(fn_closure,
                                  [self.tuple_proj(arg, i)
                                   for arg in typed_args])
                      for i in xrange(nelts)]
        return self.tuple(components)

    axis = self.transform_if_expr(expr.axis)
    axes = self.normalize_axes(typed_args, axis)
    result_type, typed_fn = specialize_Map(fn_closure.type, arg_types, axes)
    return syntax.Map(fn = make_typed_closure(fn_closure, typed_fn),
                      args = typed_args,
                      axis = axis,
                      type = result_type)
def transform_OuterMap(self, expr):
    """Specialize an untyped OuterMap into a typed OuterMap node."""
    fn_closure = self.transform_fn(expr.fn)
    typed_args = self.transform_args(expr.args, flat = True)
    arg_types = get_types(typed_args)
    assert len(arg_types) > 0
    axis = self.transform_if_expr(expr.axis)
    axes = self.normalize_axes(typed_args, axis)
    result_type, typed_fn = specialize_OuterMap(fn_closure.type,
                                                arg_types, axes)
    return syntax.OuterMap(fn = make_typed_closure(fn_closure, typed_fn),
                           args = typed_args,
                           axis = axis,
                           type = result_type)
def transform_OuterMap(self, expr):
    """Specialize an untyped OuterMap into a typed OuterMap node.

    NOTE(review): duplicate of another transform_OuterMap in this file;
    being defined later, it shadows the earlier one if both live in the
    same class -- confirm which version is intended. This variant uses
    transform_expr on the function, calls specialize_OuterMap without
    axes, and defaults a missing/None axis to 0.
    """
    closure = self.transform_expr(expr.fn)
    new_args = self.transform_args (expr.args, flat = True)
    arg_types = get_types(new_args)
    n_args = len(arg_types)
    assert n_args > 0
    result_type, typed_fn = specialize_OuterMap(closure.type, arg_types)
    axis = self.transform_if_expr(expr.axis)
    if axis is None or self.is_none(axis):
        # default a missing/None axis to 0
        axis = zero_i64
    result = syntax.OuterMap(fn = make_typed_closure(closure, typed_fn),
                             args = new_args,
                             axis = axis,
                             type = result_type)
    return result
def transform_IndexMap(self, expr):
    """Specialize an untyped IndexMap into a typed IndexMap node.

    Normalizes the shape expression to a tuple of int64 dimensions and
    specializes the mapped function against the index arity.
    """
    shape = self.transform_expr(expr.shape)
    if not isinstance(shape.type, TupleT):
        # wrap a scalar shape into a 1-tuple
        assert isinstance(shape.type, ScalarT), \
            "Invalid shape for IndexMap: %s : %s" % (shape, shape.type)
        shape = self.tuple((shape,))
    closure = self.transform_fn(expr.fn)
    shape_t = shape.type
    # NOTE(review): shape was just wrapped into a tuple above, so this
    # IntT branch looks unreachable (shape.type should be a TupleT here);
    # confirm whether self.tuple can ever yield a non-tuple-typed node
    # before removing it
    if isinstance(shape_t, IntT):
        shape = self.cast(shape, Int64)
        n_indices = 1
    else:
        assert isinstance(shape_t, TupleT), \
            "Expected shape to be tuple, instead got %s" % (shape_t,)
        assert all(isinstance(t, ScalarT) for t in shape_t.elt_types)
        n_indices = len(shape_t.elt_types)
        # coerce every dimension to int64 if any element differs
        if not all(t == Int64 for t in shape_t.elt_types):
            elts = tuple(self.cast(elt, Int64)
                         for elt in self.tuple_elts(shape))
            shape = self.tuple(elts)
    result_type, typed_fn = specialize_IndexMap(closure.type, n_indices)
    return syntax.IndexMap(shape = shape,
                           fn = make_typed_closure(closure, typed_fn),
                           type = result_type)
def transform_Reduce(self, expr):
    """Specialize an untyped Reduce into a typed Reduce node.

    NOTE(review): duplicate of another transform_Reduce in this file;
    being defined later, it shadows the earlier one if both live in the
    same class -- confirm which version is intended. This variant uses
    transform_expr on the functions, handles axis=None by lowering the
    argument ranks (expecting a later ravel), and calls
    specialize_Reduce without per-argument axes.
    """
    new_args = self.transform_args(expr.args, flat = True)
    arg_types = get_types(new_args)
    axis = self.transform_if_expr(expr.axis)
    map_fn = self.transform_expr(expr.fn if expr.fn else untyped_identity_function)
    combine_fn = self.transform_expr(expr.combine)
    # init is optional; only compute its type when present
    init = self.transform_expr(expr.init) if expr.init else None
    # if there aren't any arrays, just treat this as a function call
    if all(isinstance(t, ScalarT) for t in arg_types):
        return self.invoke(map_fn, new_args)
    init_type = init.type if init else None
    if self.is_none(axis):
        if adverb_helpers.max_rank(arg_types) > 1:
            assert len(new_args) == 1, \
                "Can't handle multiple reduction inputs and flattening from axis=None"
            #x = new_args[0]
            #return self.flatten_Reduce(map_fn, combine_fn, x, init)
            #new_args = [self.ravel(new_args[0])]
            #arg_types = get_types(new_args)
            # Expect that the arguments will get raveled before
            # the adverb gets evaluated
            axis = self.none
            arg_types = [array_type.lower_rank(t, t.rank - 1)
                         for t in arg_types if t.rank > 1]
        else:
            # rank-1 input with no axis: reduce along axis 0
            axis = self.int(0)
    result_type, typed_map_fn, typed_combine_fn = \
        specialize_Reduce(map_fn.type, combine_fn.type,
                          arg_types, init_type)
    typed_map_closure = make_typed_closure (map_fn, typed_map_fn)
    typed_combine_closure = make_typed_closure(combine_fn, typed_combine_fn)
    # when init has lower rank than the result, fold the first element
    # into init to lift it to the result rank, then reduce over the
    # remaining elements [1:]
    if init_type and init_type != result_type and \
       array_type.rank(init_type) < array_type.rank(result_type):
        assert len(new_args) == 1
        #assert is_zero(axis), "Unexpected axis %s : %s" % (axis, axis.type)
        arg = new_args[0]
        first_elt = syntax.Index(arg, zero_i64,
                                 type = arg.type.index_type(zero_i64))
        first_combine = specialize(combine_fn, (init_type, first_elt.type))
        first_combine_closure = make_typed_closure(combine_fn, first_combine)
        init = syntax.Call(first_combine_closure, (init, first_elt),
                           type = first_combine.return_type)
        slice_rest = syntax.Slice(start = one_i64, stop = none, step = one_i64,
                                  type = array_type.SliceT(Int64, NoneType,
                                                           Int64))
        rest = syntax.Index(arg, slice_rest,
                            type = arg.type.index_type(slice_rest))
        new_args = (rest,)
    return syntax.Reduce(fn = typed_map_closure,
                         combine = typed_combine_closure,
                         args = new_args,
                         axis = axis,
                         type = result_type,
                         init = init)