class ValueRangeAnalyis(SyntaxVisitor):
    """
    Abstract interpretation pass that approximates the values each variable
    can hold.

    Abstract values are Intervals, tuple/slice aggregates built with
    mk_tuple / mk_slice, const_none for None-typed expressions, and
    any_value when nothing more precise is known. Results are accumulated
    in `self.ranges`, keyed by variable name.
    """

    def __init__(self):
        SyntaxVisitor.__init__(self)
        # variable name -> abstract value
        self.ranges = {}
        # Scoped record of the value a variable held before it was last
        # changed; run_loop() pushes a scope per loop and widens from it.
        self.old_values = ScopedDict()
        self.old_values.push()

    def get(self, expr):
        """
        Return the abstract value of `expr`, falling back to any_value when
        the expression form is not understood.
        """
        c = expr.__class__
        if expr.type.__class__ is NoneT:
            return const_none
        elif c is Const:
            return Interval(expr.value, expr.value)
        elif c is Var and expr.name in self.ranges:
            return self.ranges[expr.name]
        elif c is Tuple:
            elt_values = [self.get(elt) for elt in expr.elts]
            return mk_tuple(elt_values)
        elif c is TupleProj:
            tup = self.get(expr.tuple)
            idx = unwrap_constant(expr.index)
            if tup.__class__ is TupleOfIntervals:
                return tup.elts[idx]
        elif c is Slice:
            start = self.get(expr.start)
            stop = self.get(expr.stop)
            if expr.step.type == NoneType:
                step = const_one
            else:
                step = self.get(expr.step)
            return mk_slice(start, stop, step)
        elif c is Shape:
            ndims = get_rank(expr.array.type)
            return mk_tuple([positive_interval] * ndims)
        elif c is Attribute:
            if expr.name == 'shape':
                ndims = get_rank(expr.value.type)
                return mk_tuple([positive_interval] * ndims)
            elif expr.name == 'start':
                sliceval = self.get(expr.value)
                if isinstance(sliceval, SliceOfIntervals):
                    return sliceval.start
            elif expr.name == 'stop':
                sliceval = self.get(expr.value)
                if isinstance(sliceval, SliceOfIntervals):
                    return sliceval.stop
            elif expr.name == 'step':
                sliceval = self.get(expr.value)
                if isinstance(sliceval, SliceOfIntervals):
                    return sliceval.step
        elif c is PrimCall:
            p = expr.prim
            if p.nin == 2:
                x = self.get(expr.args[0])
                y = self.get(expr.args[1])
                if p == prims.add:
                    return self.add_range(x, y)
                elif p == prims.subtract:
                    return self.sub_range(x, y)
                elif p == prims.multiply:
                    return self.mul_range(x, y)
            elif p.nin == 1:
                x = self.get(expr.args[0])
                if p == prims.negative:
                    return self.neg_range(x)
        return any_value

    def set(self, name, val):
        """
        Record `val` as the abstract value of `name`.

        None and unknown_value are ignored. If a previously known value is
        replaced by a different one, the previous value is saved in the
        current old_values scope so run_loop() can widen it.
        """
        if val is not None and val is not unknown_value:
            old_value = self.ranges.get(name, unknown_value)
            self.ranges[name] = val
            if old_value != val and old_value is not unknown_value:
                self.old_values[name] = old_value

    def add_range(self, x, y):
        """Interval sum [xl + yl, xu + yu]; any_value for non-intervals."""
        if not isinstance(x, Interval) or not isinstance(y, Interval):
            return any_value
        return Interval(x.lower + y.lower, x.upper + y.upper)

    def sub_range(self, x, y):
        """Interval difference [xl - yu, xu - yl]; any_value for non-intervals."""
        if not isinstance(x, Interval) or not isinstance(y, Interval):
            return any_value
        return Interval(x.lower - y.upper, x.upper - y.lower)

    @staticmethod
    def _bound_product(a, b):
        """
        Multiply two interval endpoints using the interval-arithmetic
        convention 0 * (+/-inf) == 0.

        Plain multiplication yields NaN for 0 * inf. NaN compares False
        against everything, so min()/max() over it become order-dependent
        and can even return NaN bounds.
        """
        if a == 0 or b == 0:
            return 0
        return a * b

    def mul_range(self, x, y):
        """
        Interval product: the min and max of the four endpoint products.
        Returns any_value if either operand is not an Interval.
        """
        if not isinstance(x, Interval) or not isinstance(y, Interval):
            return any_value
        xl, xu = x.lower, x.upper
        yl, yu = y.lower, y.upper
        mul = self._bound_product
        products = (mul(xl, yl), mul(xl, yu), mul(xu, yl), mul(xu, yu))
        return Interval(min(products), max(products))

    def neg_range(self, x):
        """Interval negation [-xu, -xl]; any_value for non-intervals."""
        if not isinstance(x, Interval):
            return any_value
        return Interval(-x.upper, -x.lower)

    def visit_Assign(self, stmt):
        # only assignments to plain variables are tracked
        if stmt.lhs.__class__ is Var:
            name = stmt.lhs.name
            v = self.get(stmt.rhs)
            self.set(name, v)

    def visit_merge_left(self, phi_nodes):
        """Bind each phi variable to the value of its left (incoming) input."""
        for (k, (left, _)) in phi_nodes.iteritems():
            left_val = self.get(left)
            self.set(k, left_val)

    def visit_merge(self, phi_nodes):
        """Bind each phi variable to the combination of both inputs' values."""
        for (k, (left, right)) in phi_nodes.iteritems():
            left_val = self.get(left)
            right_val = self.get(right)
            self.set(k, left_val.combine(right_val))

    def visit_Select(self, expr):
        return self.get(expr.true_value).combine(self.get(expr.false_value))

    def always_positive(self, x, inclusive = True):
        """True if interval `x` is entirely >= 0 (or > 0 when not inclusive)."""
        if not isinstance(x, Interval):
            return False
        elif inclusive:
            return x.lower >= 0
        else:
            return x.lower > 0

    def always_negative(self, x, inclusive = True):
        """True if interval `x` is entirely <= 0 (or < 0 when not inclusive)."""
        if not isinstance(x, Interval):
            return False
        elif inclusive:
            return x.upper <= 0
        else:
            return x.upper < 0

    def widen(self, old_values):
        """Widen each recorded variable whose current value differs from its old one."""
        for (k, oldv) in old_values.iteritems():
            newv = self.ranges[k]
            if oldv != newv:
                self.ranges[k] = oldv.widen(newv)

    def run_loop(self, body, merge):
        """
        Approximate a loop: evaluate the body and merge twice, then widen
        the variables whose values were recorded as changing.
        """
        # run loop for the first time
        self.old_values.push()
        self.visit_merge_left(merge)
        self.visit_block(body)
        self.visit_merge(merge)
        # run loop for the second time
        self.visit_block(body)
        self.visit_merge(merge)
        old_values = self.old_values.pop()
        self.widen(old_values)
        # TODO: verify that it's safe not to run the loop (visit_block(body)
        # followed by visit_merge(merge)) again with the widened values.

    def visit_While(self, stmt):
        self.run_loop(stmt.body, stmt.merge)

    def visit_ForLoop(self, stmt):
        """
        Bound the loop variable from the start/stop/step ranges, then
        analyze the loop body.
        """
        start = self.get(stmt.start)
        stop = self.get(stmt.stop)
        step = self.get(stmt.step)
        name = stmt.var.name
        iterator_range = any_value
        if isinstance(start, Interval) and isinstance(stop, Interval):
            # conservative hull of both endpoints, whatever the step direction
            lower = min(start.lower, start.upper, stop.lower, stop.upper)
            upper = max(start.lower, start.upper, stop.lower, stop.upper)
            iterator_range = Interval(lower, upper)
        elif isinstance(start, Interval) and isinstance(step, Interval):
            if self.always_positive(step):
                iterator_range = Interval(start.lower, np.inf)
            elif self.always_negative(step, inclusive = False):
                iterator_range = Interval(-np.inf, start.upper)
        elif isinstance(stop, Interval) and isinstance(step, Interval):
            if self.always_positive(step):
                iterator_range = Interval(-np.inf, stop.upper)
            elif self.always_negative(step, inclusive = False):
                iterator_range = Interval(stop.lower, np.inf)
        self.set(name, iterator_range)
        self.run_loop(stmt.body, stmt.merge)
class Simplify(Transform):
    """
    Local simplification pass over the typed syntax tree.

    Folds PrimCalls on constant arguments and applies algebraic identities,
    propagates copies and constants through `self.bindings`, reuses
    already-computed expressions via `self.available_expressions`, and
    performs a few structural rewrites:
    - nested indexing x[i][j] becomes x[i, j];
    - constant-bound for-loops that run zero times (or once) are removed
      (or inlined);
    - if-statements with empty branches become Select assignments.
    """

    def __init__(self):
        # Initialize through the base class this class actually inherits.
        Transform.__init__(self)
        # associate var names with any immutable values
        # they are bound to
        self.bindings = ScopedDict()
        # which expressions have already been computed
        # and stored in some variable?
        self.available_expressions = ScopedDict()

    def pre_apply(self, fn):
        ma = TypeBasedMutabilityAnalysis()
        # which types have elements that might
        # change between two accesses?
        self.mutable_types = ma.visit_fn(fn)
        self.use_counts = use_count(fn)

    def immutable_type(self, t):
        return t not in self.mutable_types

    def children(self, expr, allow_mutable = False):
        """
        Return the sub-expressions of `expr`: () for leaves, or None when
        the children cannot be reported (e.g. a Call with possibly-mutable
        arguments).
        """
        c = expr.__class__

        if c is Const or c is Var:
            return ()
        elif c is PrimCall or c is Closure:
            return expr.args
        elif c is ClosureElt:
            return (expr.closure,)
        elif c is Tuple:
            return expr.elts
        elif c is TupleProj:
            return (expr.tuple,)
        # WARNING: this is only valid
        # if attributes are immutable
        elif c is Attribute:
            return (expr.value,)
        elif c is Slice:
            return (expr.start, expr.stop, expr.step)
        elif c is Cast:
            return (expr.value,)
        elif c is Map or c is OuterMap or c is IndexMap:
            return expr.args
        elif c is Scan or c is Reduce or c is IndexReduce or c is FilterReduce:
            args = tuple(expr.args)
            init = (expr.init,) if expr.init else ()
            return init + args
        elif c is Call:
            # assume all Calls might modify their arguments
            if allow_mutable or all(self.immutable(arg) for arg in expr.args):
                return expr.args
            else:
                return None

        if allow_mutable or self.immutable_type(expr.type):
            if c is Array:
                return expr.elts
            elif c is ArrayView:
                return (expr.data, expr.shape, expr.strides, expr.offset, expr.size)
            elif c is Struct:
                return expr.args
            elif c is AllocArray:
                return (expr.shape,)
            elif c is Attribute:
                return (expr.value,)
        return None

    _immutable_classes = set([Const, Var,
                              Closure, ClosureElt,
                              Tuple, TupleProj,
                              Cast, PrimCall,
                              TypedFn, UntypedFn,
                              ArrayView,
                              Slice,
                              ])

    def immutable(self, expr):
        """
        TODO: make all this mutability/immutability stuff sane
        """
        klass = expr.__class__

        result = (klass in self._immutable_classes and
                  (all(self.immutable(c) for c in expr.children()))) or \
                 (klass is Attribute and isinstance(expr.type, ImmutableT))
        return result

    def temp(self, expr, name = None, use_count = 1):
        """
        Wrapper around Codegen.assign_name which also updates bindings and
        use_counts
        """
        if self.is_simple(expr):
            return expr
        else:
            new_var = self.assign_name(expr, name = name)
            self.bindings[new_var.name] = expr
            self.use_counts[new_var.name] = use_count
            return new_var

    def transform_expr(self, expr):
        if self.is_simple(expr):
            if expr.type == NoneType:
                return none
            else:
                return Transform.transform_expr(self, expr)
        # reuse a variable that already holds this exact expression
        stored = self.available_expressions.get(expr)
        if stored is not None:
            return stored
        return Transform.transform_expr(self, expr)

    def transform_Var(self, expr):
        """
        Replace a variable with a known None / empty-slice constant, or with
        the end of its chain of Var-to-Var bindings when that ends in a Var
        or Const.
        """
        t = expr.type
        if t.__class__ is NoneT:
            return none
        elif t.__class__ is SliceT and \
             t.start_type == NoneType and \
             t.stop_type == NoneType and \
             t.step_type == NoneType:
            return slice_none

        name = expr.name
        prev_expr = expr
        while name in self.bindings:
            prev_expr = expr
            expr = self.bindings[name]
            if expr.__class__ is Var:
                name = expr.name
            else:
                break
        c = expr.__class__
        if c is Var or c is Const:
            return expr
        else:
            return prev_expr

    def transform_Cast(self, expr):
        v = self.transform_expr(expr.value)
        if v.type == expr.type:
            return v
        elif v.__class__ is Const and isinstance(expr.type, ScalarT):
            # fold casts of constants
            return Const(expr.type.dtype.type(v.value), type = expr.type)
        elif self.is_simple(v):
            expr.value = v
            return expr
        else:
            expr.value = self.assign_name(v)
            return expr

    def transform_Attribute(self, expr):
        """
        Resolve attribute reads against bound Struct / ArrayView / AllocArray
        / Slice constructions when possible.
        """
        v = self.transform_expr(expr.value)

        if v.__class__ is Var and v.name in self.bindings:
            stored_v = self.bindings[v.name]
            c = stored_v.__class__
            if c is Var or c is Struct:
                v = stored_v
            elif c is ArrayView:
                if expr.name == 'shape':
                    return self.transform_expr(stored_v.shape)
                elif expr.name == 'strides':
                    return self.transform_expr(stored_v.strides)
                elif expr.name == 'data':
                    return self.transform_expr(stored_v.data)
            elif c is AllocArray:
                if expr.name == 'shape':
                    return self.transform_expr(stored_v.shape)
            elif c is Slice:
                if expr.name == "start":
                    return self.transform_expr(stored_v.start)
                elif expr.name == "stop":
                    return self.transform_expr(stored_v.stop)
                else:
                    assert expr.name == "step", "Unexpected attribute for slice: %s" % expr.name
                    return self.transform_expr(stored_v.step)

        if v.__class__ is Struct:
            idx = v.type.field_pos(expr.name)
            return v.args[idx]
        elif v.__class__ is not Var:
            v = self.temp(v, "struct")
        if expr.value == v:
            return expr
        else:
            return Attribute(value = v, name = expr.name, type = expr.type)

    def transform_Closure(self, expr):
        expr.args = tuple(self.transform_simple_exprs(expr.args))
        return expr

    def transform_Tuple(self, expr):
        expr.elts = tuple(self.transform_simple_exprs(expr.elts))
        return expr

    def transform_TupleProj(self, expr):
        """Project directly out of a bound Tuple or Struct when possible."""
        idx = expr.index
        assert isinstance(idx, int), \
            "TupleProj index must be an integer, got: " + str(idx)
        new_tuple = self.transform_expr(expr.tuple)

        if new_tuple.__class__ is Var and new_tuple.name in self.bindings:
            tuple_expr = self.bindings[new_tuple.name]

            if tuple_expr.__class__ is Tuple:
                assert idx < len(tuple_expr.elts), \
                    "Too few elements in tuple %s : %s, elts = %s" % (expr, tuple_expr.type, tuple_expr.elts)
                return tuple_expr.elts[idx]
            elif tuple_expr.__class__ is Struct:
                # Struct keeps its fields in `args`. The old message read
                # `.elts`, so a failing check raised AttributeError instead.
                assert idx < len(tuple_expr.args), \
                    "Too few args in struct %s : %s, args = %s" % (expr, tuple_expr.type, tuple_expr.args)
                return tuple_expr.args[idx]

        if not self.is_simple(new_tuple):
            new_tuple = self.assign_name(new_tuple, "tuple")
        expr.tuple = new_tuple
        return expr

    def transform_ClosureElt(self, expr):
        idx = expr.index
        assert isinstance(idx, int), \
            "ClosureElt index must be an integer, got: " + str(idx)
        new_closure = self.transform_expr(expr.closure)

        if new_closure.__class__ is Var and new_closure.name in self.bindings:
            closure_expr = self.bindings[new_closure.name]
            if closure_expr.__class__ is Closure:
                return closure_expr.args[idx]

        if not self.is_simple(new_closure):
            new_closure = self.assign_name(new_closure, "closure")
        expr.closure = new_closure
        return expr

    def transform_Call(self, expr):
        """Turn calls through a closure into direct calls with the captured args prepended."""
        fn = self.transform_expr(expr.fn)
        args = self.transform_simple_exprs(expr.args)
        if fn.type.__class__ is ClosureT:
            closure_elts = self.closure_elts(fn)
            combined_args = tuple(closure_elts) + tuple(args)
            if fn.type.fn.__class__ is TypedFn:
                fn = fn.type.fn
            else:
                assert isinstance(fn.type.fn, UntypedFn)
                from .. type_inference import specialize
                fn = specialize(fn, get_types(combined_args))
            assert fn.return_type == expr.type
            return Call(fn, combined_args, type = fn.return_type)
        else:
            expr.fn = fn
            expr.args = args
            return expr

    def transform_simple_expr(self, expr, name = None):
        """Transform `expr` and bind it to a fresh variable unless it is simple."""
        if name is None:
            name = "temp"
        result = self.transform_expr(expr)
        if not self.is_simple(result):
            return self.assign_name(result, name)
        else:
            return result

    def transform_simple_exprs(self, args):
        return [self.transform_simple_expr(x) for x in args]

    def transform_Array(self, expr):
        expr.elts = tuple(self.transform_simple_exprs(expr.elts))
        return expr

    def transform_Slice(self, expr):
        expr.start = self.transform_simple_expr(expr.start)
        expr.stop = self.transform_simple_expr(expr.stop)
        expr.step = self.transform_simple_expr(expr.step)
        return expr

    def transform_index_expr(self, expr):
        """Transform an index; non-slice tuple elements are made simple."""
        if expr.__class__ is Tuple:
            new_elts = []
            for elt in expr.elts:
                new_elt = self.transform_expr(elt)
                if not self.is_simple(new_elt) and new_elt.type.__class__ is not SliceT:
                    new_elt = self.temp(new_elt, "index_tuple_elt")
                new_elts.append(new_elt)
            expr.elts = tuple(new_elts)
            return expr
        else:
            return self.transform_expr(expr)

    def transform_Index(self, expr):
        expr.value = self.transform_expr(expr.value)
        expr.index = self.transform_index_expr(expr.index)

        if expr.value.__class__ is Array and expr.index.__class__ is Const:
            assert isinstance(expr.index.value, (int, long)) and \
                len(expr.value.elts) > expr.index.value
            return expr.value.elts[expr.index.value]

        # take expressions like "a[i][j]" and turn them into "a[i,j]"
        if expr.value.__class__ is Index:
            merged = self._merge_nested_index(expr)
            if merged is not None:
                return self.transform_expr(merged)

        if expr.value.__class__ is not Var:
            expr.value = self.temp(expr.value, "array")
        return expr

    def _merge_nested_index(self, expr):
        """
        Build base[inner..., outer...] for expr = base[inner][outer], or
        return None when that rewrite would change meaning.

        The rewrite is only valid when every inner index is a scalar. With a
        slice inside, e.g. a[:, 0][1], the outer index applies to the
        *result* of the slice (a[1, 0]), not to a further axis
        (a[:, 0, 1]). Outer indices may be a tuple, a scalar or a slice.
        """
        inner = expr.value
        base_array = inner.value
        if not isinstance(base_array.type, ArrayT):
            return None
        base_index = inner.index
        outer_index = expr.index
        base_t = base_index.type
        outer_t = outer_index.type

        # decide eligibility from types before emitting any helper code
        if isinstance(base_t, TupleT):
            if not all(isinstance(elt_t, ScalarT) for elt_t in base_t.elt_types):
                return None
        elif not isinstance(base_t, ScalarT):
            return None
        if not isinstance(outer_t, (TupleT, ScalarT, SliceT)):
            return None

        if isinstance(base_t, TupleT):
            indices = tuple(self.tuple_elts(base_index))
        else:
            indices = (base_index,)
        if isinstance(outer_t, TupleT):
            indices = indices + tuple(self.tuple_elts(outer_index))
        else:
            indices = indices + (outer_index,)
        return self.index(base_array, self.tuple(indices))

    def transform_Struct(self, expr):
        new_args = self.transform_simple_exprs(expr.args)
        return syntax.Struct(new_args, type = expr.type)

    def transform_Select(self, expr):
        """Fold selects with a known condition or identical branches."""
        cond = self.transform_expr(expr.cond)
        trueval = self.transform_expr(expr.true_value)
        falseval = self.transform_expr(expr.false_value)
        if is_true(cond):
            return trueval
        elif is_false(cond):
            return falseval
        elif trueval == falseval:
            return trueval
        else:
            expr.cond = cond
            expr.false_value = falseval
            expr.true_value = trueval
            return expr

    def transform_PrimCall(self, expr):
        """Constant-fold and apply algebraic identities to primitive calls."""
        args = self.transform_simple_exprs(expr.args)
        prim = expr.prim
        if all_constants(args):
            return syntax.Const(value = prim.fn(*collect_constants(args)),
                                type = expr.type)

        if len(args) == 1:
            x = args[0]
            if prim == prims.logical_not:
                if is_false(x):
                    return true
                elif is_true(x):
                    return false

        if len(args) == 2:
            x, y = args

            if prim == prims.add:
                if is_zero(x):
                    return y
                elif is_zero(y):
                    return x
                # x + (-c) --> x - c
                if y.__class__ is Const and y.value < 0:
                    expr.prim = prims.subtract
                    expr.args = (x, Const(value = -y.value, type = y.type))
                    return expr
                # (-c) + y --> y - c
                elif x.__class__ is Const and x.value < 0:
                    expr.prim = prims.subtract
                    expr.args = (y, Const(value = -x.value, type = x.type))
                    return expr

            elif prim == prims.subtract:
                if is_zero(y):
                    return x
                elif is_zero(x) and y.__class__ is Var:
                    stored = self.bindings.get(y.name)
                    # 0 - (a * b) --> -a * b |or| a * -b
                    if stored and stored.__class__ is PrimCall and stored.prim == prims.multiply:
                        a, b = stored.args
                        if a.__class__ is Const:
                            expr.prim = prims.multiply
                            neg_a = Const(value = -a.value, type = a.type)
                            expr.args = [neg_a, b]
                            return expr
                        elif b.__class__ is Const:
                            expr.prim = prims.multiply
                            neg_b = Const(value = -b.value, type = b.type)
                            expr.args = [a, neg_b]
                            return expr

            elif prim == prims.multiply:
                if is_one(x):
                    return y
                elif is_one(y):
                    return x
                elif is_zero(x):
                    return x
                elif is_zero(y):
                    return y

            elif prim == prims.divide and is_one(y):
                return x

            elif prim == prims.power:
                if is_one(y):
                    return self.cast(x, expr.type)
                elif is_zero(y):
                    return one(expr.type)
                elif y.__class__ is Const and y.value == 2:
                    return self.cast(self.mul(x, x, "sqr"), expr.type)

            elif prim == prims.logical_and:
                # True and y == y; x and True == x; False on either side -> False
                if is_true(x):
                    return y
                elif is_true(y):
                    return x
                elif is_false(x) or is_false(y):
                    return false

            elif prim == prims.logical_or:
                # True on either side -> True; False or y == y; x or False == x.
                # (The old code returned False whenever either side was
                # False, which is wrong: False or y is y.)
                if is_true(x) or is_true(y):
                    return true
                elif is_false(x):
                    return y
                elif is_false(y):
                    return x

        expr.args = args
        return expr

    def transform_Map(self, expr):
        expr.args = self.transform_simple_exprs(expr.args)
        expr.fn = self.transform_expr(expr.fn)
        expr.axis = self.transform_if_expr(expr.axis)
        return expr

    def transform_OuterMap(self, expr):
        expr.args = self.transform_simple_exprs(expr.args)
        expr.fn = self.transform_expr(expr.fn)
        expr.axis = self.transform_if_expr(expr.axis)
        return expr

    def transform_shape(self, expr):
        """Simplify a shape: element-wise for a Tuple literal, else as one simple expr."""
        if isinstance(expr, Tuple):
            expr.elts = tuple(self.transform_simple_exprs(expr.elts))
            return expr
        else:
            return self.transform_simple_expr(expr)

    def transform_ParFor(self, stmt):
        stmt.bounds = self.transform_shape(stmt.bounds)
        stmt.fn = self.transform_expr(stmt.fn)
        return stmt

    def transform_IndexMap(self, expr):
        expr.fn = self.transform_expr(expr.fn)
        expr.shape = self.transform_shape(expr.shape)
        return expr

    def transform_IndexReduce(self, expr):
        expr.fn = self.transform_if_expr(expr.fn)
        expr.combine = self.transform_expr(expr.combine)
        expr.init = self.transform_if_expr(expr.init)
        expr.shape = self.transform_shape(expr.shape)
        return expr

    def transform_IndexScan(self, expr):
        expr.fn = self.transform_if_expr(expr.fn)
        expr.combine = self.transform_expr(expr.combine)
        expr.emit = self.transform_if_expr(expr.emit)
        expr.init = self.transform_if_expr(expr.init)
        expr.shape = self.transform_shape(expr.shape)
        return expr

    def transform_ConstArray(self, expr):
        expr.shape = self.transform_shape(expr.shape)
        expr.value = self.transform_simple_expr(expr.value)
        return expr

    def transform_ConstArrayLike(self, expr):
        expr.array = self.transform_simple_expr(expr.array)
        expr.value = self.transform_simple_expr(expr.value)
        # The old version fell off the end and returned None, which replaced
        # the expression with nothing.
        return expr

    def transform_Reduce(self, expr):
        init = self.transform_expr(expr.init)
        if not self.is_simple(init):
            expr.init = self.assign_name(init, 'init')
        else:
            expr.init = init
        expr.args = self.transform_simple_exprs(expr.args)
        expr.fn = self.transform_expr(expr.fn)
        expr.combine = self.transform_expr(expr.combine)
        return expr

    def temp_in_block(self, expr, block, name = None):
        """
        If we need a temporary variable not in the current top scope but in
        a particular block, then use this function. (This function also
        modifies the bindings dictionary.)
        """
        if name is None:
            name = "temp"
        var = self.fresh_var(expr.type, name)
        block.append(Assign(var, expr))
        self.bindings[var.name] = expr
        return var

    def set_binding(self, name, value):
        assert value.__class__ is not Var or \
            value.name != name, \
            "Can't set name %s bound to itself" % name
        self.bindings[name] = value

    def bind_var(self, name, rhs):
        """Bind `name` to `rhs`, skipping through `rhs` if it is a Var bound to something simple."""
        if rhs.__class__ is Var:
            old_val = self.bindings.get(rhs.name)
            if old_val and self.is_simple(old_val):
                self.set_binding(name, old_val)
            else:
                self.set_binding(name, rhs)
        else:
            self.set_binding(name, rhs)

    def bind(self, lhs, rhs):
        """Bind a Var or (element-wise) a Tuple pattern to `rhs`."""
        lhs_class = lhs.__class__
        if lhs_class is Var:
            self.bind_var(lhs.name, rhs)
        elif lhs_class is Tuple and rhs.__class__ is Tuple:
            assert len(lhs.elts) == len(rhs.elts)
            for lhs_elt, rhs_elt in zip(lhs.elts, rhs.elts):
                self.bind(lhs_elt, rhs_elt)

    def transform_lhs_Index(self, lhs):
        lhs.index = self.transform_index_expr(lhs.index)
        if lhs.value.__class__ is Var:
            stored = self.bindings.get(lhs.value.name)
            if stored and stored.__class__ is Var:
                lhs.value = stored
        else:
            lhs.value = self.assign_name(lhs.value, "array")
        return lhs

    def transform_lhs_Attribute(self, lhs):
        # lhs.value = self.transform_expr(lhs.value)
        return lhs

    def transform_ExprStmt(self, stmt):
        """Don't run an expression unless it possibly has a side effect"""
        v = self.transform_expr(stmt.value)
        if self.immutable(v):
            return None
        else:
            stmt.value = v
            return stmt

    def transform_Assign(self, stmt):
        lhs = stmt.lhs
        rhs = self.transform_expr(stmt.rhs)
        lhs_class = lhs.__class__
        rhs_class = rhs.__class__

        if lhs_class is Var:
            if lhs.type.__class__ is NoneT and self.use_counts.get(lhs.name, 0) == 0:
                return self.transform_stmt(ExprStmt(rhs))
            elif self.immutable(rhs):
                self.bind_var(lhs.name, rhs)
                if rhs_class is not Var and rhs_class is not Const:
                    self.available_expressions.setdefault(rhs, lhs)
        elif lhs_class is Tuple:
            self.bind(lhs, rhs)

        elif lhs_class is Index:
            if rhs_class is Index and \
               lhs.value == rhs.value and \
               lhs.index == rhs.index:
                # kill effect-free writes like x[i] = x[i]
                return None
            elif rhs_class is Var and \
                 lhs.value.__class__ is Var and \
                 lhs.value.name == rhs.name and \
                 lhs.index.type.__class__ is TupleT and \
                 all(elt_t == slice_none_t for elt_t in lhs.index.type.elt_types):
                # also kill x[:] = x
                return None
            else:
                lhs = self.transform_lhs_Index(lhs)
                # when assigning x[j] = [1,2,3]
                # just rewrite it as a sequence of element assignments
                if lhs.type.__class__ is ArrayT and \
                   lhs.type.rank == 1 and \
                   rhs.__class__ is Array:
                    lhs_slice = self.assign_name(lhs, "lhs_slice")
                    for (elt_idx, elt) in enumerate(rhs.elts):
                        lhs_idx = self.index(lhs_slice, const_int(elt_idx), temp = False)
                        self.assign(lhs_idx, elt)
                    return None
                elif not self.is_simple(rhs):
                    rhs = self.assign_name(rhs)
        else:
            assert lhs_class is Attribute
            assert False, "Considering making attributes immutable"
            lhs = self.transform_lhs_Attribute(lhs)

        # inline a single-use variable's bound expression
        if rhs_class is Var and \
           rhs.name in self.bindings and \
           self.use_counts.get(rhs.name, 1) == 1:
            self.use_counts[rhs.name] = 0
            rhs = self.bindings[rhs.name]
        stmt.lhs = lhs
        stmt.rhs = rhs
        return stmt

    def transform_block(self, stmts, keep_bindings = False):
        """
        Transform a block inside fresh scopes for available expressions and
        bindings. With keep_bindings=True the bindings scope is left pushed,
        and the caller must pop it.
        """
        self.available_expressions.push()
        self.bindings.push()
        new_stmts = Transform.transform_block(self, stmts)
        self.available_expressions.pop()
        if not keep_bindings:
            self.bindings.pop()
        return new_stmts

    def enter_loop(self, phi_nodes):
        """Drop loop phi nodes whose incoming value equals the loop-back value."""
        result = {}
        for (k, (left, right)) in phi_nodes.iteritems():
            new_left = self.transform_expr(left)
            if new_left == right:
                self.set_binding(k, new_left)
            else:
                result[k] = (new_left, right)
        return result

    def transform_merge(self, phi_nodes, left_block, right_block):
        """
        Simplify phi inputs, naming non-trivial ones in their source blocks.
        Phi nodes whose two inputs become identical are replaced by a plain
        assignment in the current block.
        """
        result = {}
        for (k, (left, right)) in phi_nodes.iteritems():
            new_left = self.transform_expr(left)
            new_right = self.transform_expr(right)

            if not isinstance(new_left, (Const, Var)):
                new_left = self.temp_in_block(new_left, left_block)
            if not isinstance(new_right, (Const, Var)):
                new_right = self.temp_in_block(new_right, right_block)

            if new_left == new_right:
                # if both control flows yield the same value then
                # we don't actually need the phi-bound variable, we can just
                # replace the left value everywhere
                self.assign(Var(name = k, type = new_left.type), new_left)
                self.set_binding(k, new_left)
            else:
                result[k] = new_left, new_right
        return result

    def transform_If(self, stmt):
        """Simplify branches; an If with empty branches becomes Select assignments."""
        stmt.true = self.transform_block(stmt.true, keep_bindings = True)
        stmt.false = self.transform_block(stmt.false, keep_bindings = True)
        stmt.merge = self.transform_merge(stmt.merge,
                                          left_block = stmt.true,
                                          right_block = stmt.false)
        # pop the two binding scopes left open by keep_bindings=True
        self.bindings.pop()
        self.bindings.pop()
        stmt.cond = self.transform_simple_expr(stmt.cond, "cond")
        if len(stmt.true) == 0 and len(stmt.false) == 0 and len(stmt.merge) <= 2:
            for (lhs_name, (true_expr, false_expr)) in stmt.merge.items():
                lhs_type = self.lookup_type(lhs_name)
                lhs_var = Var(name = lhs_name, type = lhs_type)
                assert true_expr.type == false_expr.type, \
                    "Unexpcted type mismatch: %s != %s" % (true_expr.type, false_expr.type)
                rhs = Select(stmt.cond, true_expr, false_expr, type = true_expr.type)
                self.bind_var(lhs_name, rhs)
                self.assign(lhs_var, rhs)
            return None
        return stmt

    def transform_loop_condition(self, expr, outer_block, loop_body, merge):
        """Normalize loop conditions so they are just simple variables"""
        if self.is_simple(expr):
            return self.transform_expr(expr)
        else:
            loop_carried_vars = [name for name in collect_var_names(expr)
                                 if name in merge]
            if len(loop_carried_vars) == 0:
                return expr

            left_values = [merge[name][0] for name in loop_carried_vars]
            right_values = [merge[name][1] for name in loop_carried_vars]

            left_cond = subst.subst_expr(expr, dict(zip(loop_carried_vars, left_values)))
            if not self.is_simple(left_cond):
                left_cond = self.temp_in_block(left_cond, outer_block, name = "cond")

            right_cond = subst.subst_expr(expr, dict(zip(loop_carried_vars, right_values)))
            if not self.is_simple(right_cond):
                right_cond = self.temp_in_block(right_cond, loop_body, name = "cond")

            cond_var = self.fresh_var(left_cond.type, "cond")
            merge[cond_var.name] = (left_cond, right_cond)
            return cond_var

    def transform_While(self, stmt):
        merge = self.enter_loop(stmt.merge)
        stmt.body = self.transform_block(stmt.body)
        stmt.merge = self.transform_merge(merge,
                                          left_block = self.blocks.current(),
                                          right_block = stmt.body)
        stmt.cond = \
            self.transform_loop_condition(stmt.cond,
                                          outer_block = self.blocks.current(),
                                          loop_body = stmt.body,
                                          merge = stmt.merge)
        return stmt

    def transform_ForLoop(self, stmt):
        """
        Simplify a for-loop. A loop with constant bounds is removed when it
        runs zero times, and inlined when it runs once and carries no
        loop variables.
        """
        merge = self.enter_loop(stmt.merge)
        stmt.body = self.transform_block(stmt.body)
        stmt.merge = self.transform_merge(merge,
                                          left_block = self.blocks.current(),
                                          right_block = stmt.body)
        stmt.start = self.transform_simple_expr(stmt.start, 'start')
        stmt.stop = self.transform_simple_expr(stmt.stop, 'stop')
        if self.is_none(stmt.step):
            stmt.step = one(stmt.start.type)
        else:
            stmt.step = self.transform_simple_expr(stmt.step, 'step')

        if stmt.start.__class__ is Const and \
           stmt.stop.__class__ is Const and \
           stmt.step.__class__ is Const:
            start = stmt.start.value
            stop = stmt.stop.value
            step = stmt.step.value
            # The trip-count tests depend on the sign of the step, as with
            # Python's range(). The old code assumed step > 0 and so deleted
            # loops like range(10, 0, -1) outright.
            if step > 0:
                runs_zero_times = start >= stop
                runs_once = start + step >= stop
            elif step < 0:
                runs_zero_times = start <= stop
                runs_once = start + step <= stop
            else:
                # zero step: no finite trip count to reason about here
                runs_zero_times = runs_once = False

            if runs_zero_times:
                # loop-carried variables simply keep their incoming values
                for (var_name, (input_value, _)) in stmt.merge.iteritems():
                    var = Var(var_name, input_value.type)
                    self.blocks.append(Assign(var, input_value))
                return None
            elif runs_once and len(stmt.merge) == 0:
                # Inline the single iteration. This is limited to loops without
                # loop-carried (phi) variables. After one iteration such a
                # variable should hold its end-of-body (right) value, but
                # inlining can only bind it once, to its incoming (left) value,
                # because the body reads it under the same name.
                self.assign(stmt.var, stmt.start)
                self.blocks.top().extend(stmt.body)
                return None
        return stmt

    def transform_Return(self, stmt):
        new_value = self.transform_expr(stmt.value)
        if new_value != stmt.value:
            stmt.value = new_value
        return stmt
class AST_Translator(ast.NodeVisitor): def __init__(self, globals_dict=None, closure_cell_dict=None, parent=None, function_name = None, filename = None): # assignments which need to get prepended at the beginning of the # function self.globals = globals_dict self.blocks = NestedBlocks() self.parent = parent self.scopes = ScopedDict() self.globals_dict = globals_dict self.closure_cell_dict = closure_cell_dict # mapping from names/paths to either a closure cell reference or a # global value self.python_refs = OrderedDict() self.original_outer_names = [] self.localized_outer_names = [] self.filename = filename self.function_name = function_name self.push() def push(self, scope = None, block = None): if scope is None: scope = {} if block is None: block = [] self.scopes.push(scope) self.blocks.push(block) def pop(self): scope = self.scopes.pop() block = self.blocks.pop() return scope, block def fresh_name(self, original_name): fresh_name = names.fresh(original_name) self.scopes[original_name] = fresh_name return fresh_name def fresh_names(self, original_names): return map(self.fresh_name, original_names) def fresh_var(self, name): return syntax.Var(self.fresh_name(name)) def fresh_vars(self, original_names): return map(self.fresh_var, original_names) def current_block(self): return self.blocks.top() def current_scope(self): return self.scopes.top() def syntax_to_value(self, expr): if isinstance(expr, ast.Num): return expr.n elif isinstance(expr, ast.Tuple): return tuple(self.syntax_to_value(elt) for elt in expr.elts) elif isinstance(expr, ast.Name): return self.lookup_global(expr.id) elif isinstance(expr, ast.Attribute): left = self.syntax_to_value(expr.value) if isinstance(left, ExternalValue): left = left.value return getattr(left, expr.attr) def lookup_global(self, key): if isinstance(key, (list, tuple)): assert len(key) == 1 key = key[0] else: assert isinstance(key, str), "Invalid global key: %s" % (key,) if self.globals: if key in self.globals: return 
self.globals[key] elif key in __builtin__.__dict__: return __builtin__.__dict__[key] else: assert False, "Couldn't find global name %s" % key else: assert self.parent is not None return self.parent.lookup_global(key) def is_global(self, key): if isinstance(key, (list, tuple)): key = key[0] if key in self.scopes: return False elif self.closure_cell_dict and key in self.closure_cell_dict: return False if self.globals: return key in self.globals or key in __builtins__ assert self.parent is not None return self.parent.is_global(key) def local_ref_name(self, ref, python_name): for (local_name, other_ref) in self.python_refs.iteritems(): if ref == other_ref: return Var(local_name) local_name = names.fresh(python_name) self.scopes[python_name] = local_name self.original_outer_names.append(python_name) self.localized_outer_names.append(local_name) self.python_refs[local_name] = ref return Var(local_name) def lookup(self, name): #if name in reserved_names: # return reserved_names[name] if name in self.scopes: return Var(self.scopes[name]) elif self.parent: # don't actually keep the outer binding name, we just # need to check that it's possible and tell the outer scope # to register any necessary python refs local_name = names.fresh(name) self.scopes[name] = local_name self.original_outer_names.append(name) self.localized_outer_names.append(local_name) return Var(local_name) elif self.closure_cell_dict and name in self.closure_cell_dict: ref = ClosureCellRef(self.closure_cell_dict[name], name) return self.local_ref_name(ref, name) elif self.is_global(name): value = self.lookup_global(name) if is_static_value(value): return value_to_syntax(value) elif isinstance(value, np.ndarray): ref = GlobalValueRef(value) return self.local_ref_name(ref, name) else: # assume that this is a module or object which will have some # statically convertible value pulled out of it return ExternalValue(value) #else: # assert False, "Can't use global value %s" % value else: raise NameNotFound(name) 
def visit_list(self, nodes):
    """Translate each AST node in `nodes` and return the results as a list.

    An explicit list is built (instead of returning `map(...)`) because
    callers index and slice the result, which a lazy iterator would not
    support.
    """
    return [self.visit(node) for node in nodes]

def tuple_arg_assignments(self, elts, var):
    """
    Recursively decompose a nested tuple argument like

      def f((x,(y,z))):
        ...

    into a single name and a series of assignments:

      def f(tuple_arg):
        x = tuple_arg[0]
        tuple_arg_elt = tuple_arg[1]
        y = tuple_arg_elt[0]
        z = tuple_arg_elt[1]

    `elts` are the AST elements (Name or nested Tuple) of the tuple pattern,
    `var` is the variable already holding the tuple value.  Returns the list
    of Assign statements in evaluation order.
    """
    assignments = []
    for (i, sub_arg) in enumerate(elts):
        is_nested = isinstance(sub_arg, ast.Tuple)
        if is_nested:
            name = "tuple_arg_elt"
        else:
            assert isinstance(sub_arg, ast.Name)
            name = sub_arg.id
        lhs = self.fresh_var(name)
        assignments.append(Assign(lhs, syntax.Index(var, syntax.Const(i))))
        if is_nested:
            # Unpack the nested pattern from the element just bound.
            assignments.extend(self.tuple_arg_assignments(sub_arg.elts, lhs))
    return assignments

def translate_args(self, args):
    """Convert a Python `ast.arguments` node into a FormalArgs object.

    Returns (formals, assignments).  Each tuple-pattern parameter is replaced
    by a single fresh name, and `assignments` holds the statements that
    unpack it (see tuple_arg_assignments).  Default values are converted
    with self.syntax_to_value and attached to the last len(args.defaults)
    positional names.  A `**kwargs` parameter is not supported.
    """
    assert not args.kwarg
    formals = FormalArgs()
    assignments = []
    for arg in args.args:
        if isinstance(arg, ast.Name):
            visible_name = arg.id
            local_name = self.fresh_name(visible_name)
            formals.add_positional(local_name, visible_name)
        else:
            assert isinstance(arg, ast.Tuple)
            arg_name = self.fresh_name("tuple_arg")
            formals.add_positional(arg_name)
            var = Var(arg_name)
            assignments.extend(self.tuple_arg_assignments(arg.elts, var))
    n_defaults = len(args.defaults)
    if n_defaults > 0:
        # Python defaults always belong to the trailing positional params.
        local_names = formals.positional[-n_defaults:]
        for (k, expr) in zip(local_names, args.defaults):
            formals.defaults[k] = self.syntax_to_value(expr)
    if args.vararg:
        assert isinstance(args.vararg, str)
        formals.starargs = self.fresh_name(args.vararg)
    return formals, assignments

def visit_Name(self, expr):
    """Resolve a variable reference through self.lookup."""
    assert isinstance(expr, ast.Name), "Expected AST Name object: %s" % expr
    return self.lookup(expr.id)

def create_phi_nodes(self, left_scope, right_scope, new_names=None):
    """
    Phi nodes make explicit the possible sources of each variable's values
    and are needed when either two branches merge or when one was
    optionally taken.

    `left_scope` and `right_scope` map Python variable names to the SSA
    names bound on each side.  A name bound on only one side is paired with
    the result of self.lookup(name).  The name is skipped when that lookup
    raises NameNotFound.  `new_names` optionally fixes the merged SSA name
    for a variable; otherwise a fresh name is generated.

    Returns a dict {merged_ssa_name: (left_value, right_value)}.
    """
    if new_names is None:
        new_names = {}

    def merged_name(name):
        # Caller-supplied merge name if any, else a fresh SSA name.
        if name in new_names:
            return new_names[name]
        return self.fresh_name(name)

    merge = {}
    for (name, ssa_name) in left_scope.iteritems():
        left = Var(ssa_name)
        if name in right_scope:
            right = Var(right_scope[name])
        else:
            try:
                right = self.lookup(name)
            except NameNotFound:
                continue
        merge[merged_name(name)] = (left, right)

    for (name, ssa_name) in right_scope.iteritems():
        if name in left_scope:
            # Already merged by the loop above.
            continue
        try:
            left = self.lookup(name)
        except NameNotFound:
            # for now skip over variables which weren't defined before
            # a control flow split, which means that loop-local variables
            # can't be used after the loop.
            # TODO: Fix this. Maybe with 'undef' nodes?
            continue
        merge[merged_name(name)] = (left, Var(ssa_name))
    return merge

def visit_Index(self, expr):
    """Unwrap a simple `x[i]` index node to its inner expression."""
    return self.visit(expr.value)

def visit_Ellipsis(self, expr):
    raise RuntimeError("Ellipsis operator unsupported")

def visit_Slice(self, expr):
    """
    x[l:u:s]
    Optional fields
      expr.lower
      expr.upper
      expr.step
    Missing fields are filled in with the `none` constant.
    """
    def optional(field):
        return none if field is None else self.visit(field)
    return syntax.Slice(optional(expr.lower),
                        optional(expr.upper),
                        optional(expr.step))

def visit_ExtSlice(self, expr):
    """Translate a multi-dimensional subscript.

    Returns a Tuple of the translated dimensions, or the single translated
    dimension itself when there is only one.
    """
    slice_elts = [self.visit(dim) for dim in expr.dims]
    if len(slice_elts) > 1:
        return syntax.Tuple(slice_elts)
    return slice_elts[0]

def visit_UnaryOp(self, expr):
    """Translate a unary operator into a one-argument PrimCall."""
    ssa_val = self.visit(expr.operand)
    # UAdd doesn't do anything!
    if isinstance(expr.op, ast.UAdd):
        return ssa_val
    prim = prims.find_ast_op(expr.op)
    return syntax.PrimCall(prim, [ssa_val])

def visit_BinOp(self, expr):
    """Translate a binary operator into a two-argument PrimCall."""
    ssa_left = self.visit(expr.left)
    ssa_right = self.visit(expr.right)
    prim = prims.find_ast_op(expr.op)
    return syntax.PrimCall(prim, [ssa_left, ssa_right])

def visit_BoolOp(self, expr):
    """Translate `a and b and ...` / `a or b or ...` into nested PrimCalls.

    Every operand is translated; no short-circuit branching is emitted.
    """
    values = [self.visit(v) for v in expr.values]
    prim = prims.find_ast_op(expr.op)
    # Python, strangely, allows more than two arguments to
    # Boolean operators, so fold them left to right.
    result = values[0]
    for v in values[1:]:
        result = syntax.PrimCall(prim, [result, v])
    return result

def visit_Compare(self, expr):
    """Translate a single (non-chained) comparison into a PrimCall."""
    assert len(expr.ops) == 1 and len(expr.comparators) == 1, \
        "Chained comparisons not supported"
    lhs = self.visit(expr.left)
    prim = prims.find_ast_op(expr.ops[0])
    rhs = self.visit(expr.comparators[0])
    return syntax.PrimCall(prim, [lhs, rhs])

def visit_Subscript(self, expr):
    """Translate `value[slice]` into an Index node."""
    value = self.visit(expr.value)
    index = self.visit(expr.slice)
    return syntax.Index(value, index)

def generic_visit(self, expr):
    """Reject any AST node type without a dedicated visit_* method."""
    raise UnsupportedSyntax(expr,
                            function_name = self.function_name,
                            filename = self.filename)

def translate_builtin_call(self, value, positional, keywords_dict,
                           starargs_expr=None):
    """
    Translate a call to a Python builtin or other non-user function value.

    sum/max/min/map get dedicated translations.  Values in function_mappings
    whose equivalent is a macro are expanded.  Anything else becomes a Call
    of value_to_syntax(value).  `starargs_expr` is the translated `*args`
    expression (or None).  It is forwarded to the generic Call and rejected
    where the translation cannot use it.
    """
    from ..mappings import function_mappings
    if value is sum or value is max or value is min or value is map:
        assert starargs_expr is None, \
            "*args not supported in call to builtin %s" % value.__name__
    if value is sum:
        return mk_reduce_call(prim_wrapper(prims.add), positional, zero_i64)
    elif value is max:
        if len(positional) == 1:
            return mk_reduce_call(prim_wrapper(prims.maximum), positional)
        else:
            assert len(positional) == 2
            return syntax.PrimCall(prims.maximum, positional)
    elif value is min:
        if len(positional) == 1:
            return mk_reduce_call(prim_wrapper(prims.minimum), positional)
        else:
            assert len(positional) == 2
            return syntax.PrimCall(prims.minimum, positional)
    elif value is map:
        # Only the 'axis' keyword is meaningful here.  The old check forbade
        # every keyword, so the axis lookup below could never see a value.
        assert all(k == "axis" for k in keywords_dict), \
            "Only the 'axis' keyword is supported for map, got %s" % sorted(keywords_dict)
        assert len(positional) > 1
        axis = keywords_dict.get("axis", None)
        return Map(fn = positional[0], args = positional[1:], axis = axis)
    elif value in function_mappings:
        parakeet_equiv = function_mappings[value]
        if isinstance(parakeet_equiv, macro):
            assert starargs_expr is None, \
                "*args not supported in call to %s" % (parakeet_equiv,)
            return parakeet_equiv.transform(positional, keywords_dict)
    fn = value_to_syntax(value)
    return syntax.Call(fn, ActualArgs(positional, keywords_dict, starargs_expr))

def visit(self, node):
    """Dispatch to the visit_<NodeClass> method for `node` (ast.NodeVisitor)."""
    return ast.NodeVisitor.visit(self, node)

def translate_value_call(self, value, positional, keywords_dict=None,
                         starargs_expr=None):
    """
    Translate a call whose callee is a concrete Python value.

    The value is first replaced by its function_mappings equivalent, if any.
    Macros are expanded; user functions become a Call of their translated
    definition; everything else goes to translate_builtin_call.
    `starargs_expr` is the translated `*args` expression or None.
    """
    from ..mappings import function_mappings
    if keywords_dict is None:
        keywords_dict = {}
    if value in function_mappings:
        value = function_mappings[value]
    if isinstance(value, macro):
        # Macro expansion only receives positional and keyword arguments.
        assert starargs_expr is None, "*args not supported in call to %s" % (value,)
        return value.transform(positional, keywords_dict)
    elif is_user_function(value):
        return syntax.Call(translate_function_value(value),
                           ActualArgs(positional, keywords_dict, starargs_expr))
    else:
        return self.translate_builtin_call(value, positional, keywords_dict,
                                           starargs_expr)

def visit_Call(self, expr):
    """
    Translate a call expression.

    If the callee is a dotted name chain (`f`, `mod.f`, ...) rooted at a
    global, the Python value is resolved at translation time and passed to
    translate_value_call.  Otherwise the callee is translated normally.  A
    syntax expression becomes a Call.  An ExternalValue is passed to
    translate_value_call.

    TODO: The logic here is broken and haphazard, eventually try to handle
    nested scopes correctly, along with globals, cell refs, etc..
    """
    fn, args, keywords_list, starargs, kwargs = \
        expr.func, expr.args, expr.keywords, expr.starargs, expr.kwargs
    assert kwargs is None, "Dictionary of keyword args not supported"
    positional = self.visit_list(args)
    keywords_dict = {}
    for kwd in keywords_list:
        keywords_dict[kwd.arg] = self.visit(kwd.value)
    starargs_expr = self.visit(starargs) if starargs else None

    def is_attr_chain(node):
        # True for `a`, `a.b`, `a.b.c`, ...: names joined by attribute access.
        return isinstance(node, ast.Name) or \
            (isinstance(node, ast.Attribute) and is_attr_chain(node.value))

    def extract_attr_chain(node):
        # `a.b.c` -> ['a', 'b', 'c']
        if isinstance(node, ast.Name):
            return [node.id]
        chain = extract_attr_chain(node.value)
        chain.append(node.attr)
        return chain

    def lookup_attr_chain(chain):
        # Resolve the chain against Python values.  Attribute access is tried
        # first, with item lookup as the fallback.
        value = self.lookup_global(chain[0])
        for name in chain[1:]:
            if hasattr(value, name):
                value = getattr(value, name)
            else:
                try:
                    value = value[name]
                except (LookupError, TypeError):
                    assert False, "Couldn't find global name %s" % ('.'.join(chain))
        return value

    if is_attr_chain(fn):
        # Named `chain` (not `names`) so the names module is not shadowed.
        chain = extract_attr_chain(fn)
        if self.is_global(chain):
            return self.translate_value_call(lookup_attr_chain(chain), positional,
                                             keywords_dict, starargs_expr)
    fn_node = self.visit(fn)
    if isinstance(fn_node, syntax.Expr):
        actuals = ActualArgs(positional, keywords_dict, starargs_expr)
        return syntax.Call(fn_node, actuals)
    else:
        assert isinstance(fn_node, ExternalValue)
        return self.translate_value_call(fn_node.value, positional,
                                         keywords_dict, starargs_expr)

def visit_List(self, expr):
    """Translate a list literal into an Array node."""
    return syntax.Array(self.visit_list(expr.elts))

def visit_Expr(self, expr):
    """Translate an expression statement into an assignment to a dummy var.

    The dummy assignment keeps any side effects of the RHS.  A bare string
    literal statement (e.g. a docstring) is not translated.  It becomes an
    assignment of zero_i64 instead.
    """
    lhs = self.fresh_var("dummy")
    if isinstance(expr.value, ast.Str):
        return syntax.Assign(lhs, zero_i64)
    rhs = self.visit(expr.value)
    return syntax.Assign(lhs, rhs)

def visit_GeneratorExp(self, expr):
    """Generator expressions are translated exactly like list comprehensions."""
    return self.visit_ListComp(expr)

def visit_ListComp(self, expr):
    """
    Translate `[elt for target in iter]` into a Map over `iter` of a nested
    "comprehension_map" function whose body returns `elt`.

    Only one generator, without `if` clauses, is supported.  The target may
    be a name or a flat tuple of names.
    """
    gens = expr.generators
    assert len(gens) == 1
    gen = gens[0]
    assert len(gen.ifs) == 0, \
        "Parakeet: Conditions in array comprehensions not yet supported"
    target = gen.target
    if target.__class__ is ast.Name:
        arg_vars = [target]
    else:
        assert target.__class__ is ast.Tuple and \
            all(e.__class__ is ast.Name for e in target.elts), \
            "Expected comprehension target to be variable or tuple of variables, got %s" % ast.dump(target)
        arg_vars = [ast.Tuple(elts = target.elts)]
    # build a lambda as a Python ast representing
    # what we do to each element
    args = ast.arguments(args = arg_vars, vararg = None, kwarg = None, defaults = ())
    fn = translate_function_ast(name = "comprehension_map",
                                args = args,
                                body = [ast.Return(expr.elt)],
                                parent = self)
    seq = self.visit(gen.iter)
    return Map(fn = fn, args = (seq,), axis = zero_i64)

def visit_Attribute(self, expr):
    """
    Translate `value.attr`.

    - If the translated `value` is an ExternalValue (a Python object resolved
      from a global name), the attribute is read with getattr at translation
      time.  Static results become syntax, and other results stay wrapped as
      an ExternalValue.
    - An attribute listed in property_mappings becomes a call (or a macro
      expansion) applied to `value`.
    - An attribute listed in method_mappings becomes a Closure over the
      mapped function with `value` as its bound argument.
    Any other attribute is unsupported.
    """
    # TODO: Recursive lookup to see if:
    #  (1) base object is local, if so-- create chain of attributes
    #  (2) base object is global but an adverb primitive-- use it locally
    #      without adding it to nonlocals
    #  (3) not local at all-- in which case, add the whole chain of strings
    #      to nonlocals
    from ..mappings import property_mappings, method_mappings
    value = self.visit(expr.value)
    attr = expr.attr
    if isinstance(value, ExternalValue):
        attr_value = getattr(value.value, attr)
        if is_static_value(attr_value):
            return value_to_syntax(attr_value)
        else:
            return ExternalValue(attr_value)
    elif attr in property_mappings:
        fn = property_mappings[attr]
        if isinstance(fn, macro):
            return fn.transform([value])
        else:
            return syntax.Call(translate_function_value(fn),
                               ActualArgs(positional = (value,)))
    elif attr in method_mappings:
        fn_syntax = translate_function_value(method_mappings[attr])
        return syntax.Closure(fn_syntax, args = (value,))
    else:
        assert False, "Attribute %s not supported" % attr

def visit_Num(self, expr):
    """Translate a numeric literal into a Const."""
    return syntax.Const(expr.n)

def visit_Tuple(self, expr):
    """Translate a tuple expression into a Tuple node."""
    return syntax.Tuple(self.visit_list(expr.elts))

def visit_IfExp(self, expr):
    """
    Translate `body if test else orelse`.

    An If statement is appended to the current block.  It assigns each
    branch to a temporary and merges the two into a fresh `if_result`
    variable, which is returned as the value of the expression.
    """
    # NOTE(review): both branch expressions are translated while the current
    # block is still active.  Any statements they emit themselves (e.g. a
    # nested IfExp) land outside this If and run unconditionally.
    temp1, temp2, result = self.fresh_vars(["if_true", "if_false", "if_result"])
    cond = self.visit(expr.test)
    true_block = [Assign(temp1, self.visit(expr.body))]
    false_block = [Assign(temp2, self.visit(expr.orelse))]
    merge = {result.name : (temp1, temp2)}
    if_stmt = If(cond, true_block, false_block, merge)
    self.current_block().append(if_stmt)
    return result

def visit_lhs(self, lhs):
    """Translate an assignment target.

    Plain names bind fresh SSA variables, and tuples are handled
    element-wise.  Other targets (subscripts, attributes) are translated
    as ordinary expressions.
    """
    if isinstance(lhs, ast.Name):
        return self.fresh_var(lhs.id)
    elif isinstance(lhs, ast.Tuple):
        return syntax.Tuple([self.visit_lhs(elt) for elt in lhs.elts])
    else:
        # in case of slicing or attributes
        return self.visit(lhs)

def visit_Assign(self, stmt):
    """Translate `target = value` (a single target only)."""
    # Chained targets (`a = b = e`) would otherwise silently lose every
    # target after the first.
    assert len(stmt.targets) == 1, "Chained assignment not supported"
    # important to evaluate RHS before LHS for statements like 'x = x + 1'
    ssa_rhs = self.visit(stmt.value)
    ssa_lhs = self.visit_lhs(stmt.targets[0])
    return Assign(ssa_lhs, ssa_rhs)

def visit_AugAssign(self, stmt):
    """Translate `target op= value` into `target = target op value`."""
    ssa_incr = self.visit(stmt.value)
    # Read the old value before visit_lhs rebinds the target name.
    ssa_old_value = self.visit(stmt.target)
    ssa_new_value = self.visit_lhs(stmt.target)
    prim = prims.find_ast_op(stmt.op)
    return Assign(ssa_new_value, PrimCall(prim, [ssa_old_value, ssa_incr]))

def visit_Return(self, stmt):
    """Translate a return statement; a bare `return` returns `none`."""
    if stmt.value is None:
        return syntax.Return(none)
    return syntax.Return(self.visit(stmt.value))

def visit_If(self, stmt):
    """Translate an if/else statement.

    Variables bound in either branch are merged with phi nodes (see
    create_phi_nodes).
    """
    cond = self.visit(stmt.test)
    true_scope, true_block = self.visit_block(stmt.body)
    false_scope, false_block = self.visit_block(stmt.orelse)
    merge = self.create_phi_nodes(true_scope, false_scope)
    return syntax.If(cond, true_block, false_block, merge)

def visit_loop_body(self, body, *exprs):
    """
    Translate a loop body together with loop-header expressions (e.g. the
    test of a while loop).

    Each variable rebound in the body that is also present in self.scopes
    gets a fresh "<name>_loop" SSA name and a merge entry
    (value_before, value_after).  References to the pre-loop SSA name in the
    header expressions and the body are substituted with the merged name.
    The current scope is updated to point at the merged name.

    Returns (body, merge, translated_exprs).
    """
    merge = {}
    substitutions = {}
    curr_scope = self.current_scope()
    exprs = [self.visit(expr) for expr in exprs]
    scope_after, body = self.visit_block(body)
    for (k, name_after) in scope_after.iteritems():
        if k in self.scopes:
            name_before = self.scopes[k]
            new_name = names.fresh(k + "_loop")
            merge[new_name] = (Var(name_before), Var(name_after))
            substitutions[name_before] = new_name
            curr_scope[k] = new_name
    exprs = [subst_expr(expr, substitutions) for expr in exprs]
    body = subst_stmt_list(body, substitutions)
    return body, merge, exprs

def visit_While(self, stmt):
    """Translate a while loop (an `else:` clause is not supported)."""
    assert not stmt.orelse
    body, merge, (cond,) = self.visit_loop_body(stmt.body, stmt.test)
    return syntax.While(cond, body, merge)

def visit_For(self, stmt):
    """
    Translate a for loop (an `else:` clause is not supported) into a ForLoop.

    Iteration over a syntax.Range uses its start/stop/step directly.  Any
    other sequence is bound to a temporary and indexed by a fresh counter
    running from zero_i64 to len(seq) in steps of one_i64.
    """
    assert not stmt.orelse
    # Translate the iterable before binding the loop variable, so that
    # `for x in f(x)` reads the pre-loop `x` (same rule as visit_Assign).
    seq = self.visit(stmt.iter)
    var = self.visit_lhs(stmt.target)
    assert isinstance(var, Var)
    body, merge, _ = self.visit_loop_body(stmt.body)
    if isinstance(seq, syntax.Range):
        return ForLoop(var, seq.start, seq.stop, seq.step, body, merge)
    else:
        seq_name = self.fresh_name("seq")
        seq_var = Var(seq_name)
        self.current_block().append(Assign(seq_var, seq))
        len_fn = translate_function_value(len)
        n = syntax.Call(len_fn, ActualArgs([seq_var]))
        start = zero_i64
        step = one_i64
        loop_counter_name = self.fresh_name('loop_counter')
        loop_var = Var(loop_counter_name)
        body = [Assign(var, syntax.Index(seq_var, loop_var))] + body
        return ForLoop(loop_var, start, n, step, body, merge)

def visit_block(self, stmts):
    """
    Translate a statement list inside a newly pushed scope/block.

    Statements appended to the block while a statement is being translated
    (e.g. by visit_IfExp) are placed before that statement.  Returns
    self.pop(), which callers unpack as (scope, block).
    """
    self.push()
    curr_block = self.current_block()
    for stmt in stmts:
        parakeet_stmt = self.visit(stmt)
        curr_block.append(parakeet_stmt)
    return self.pop()

def visit_FunctionDef(self, node):
    """
    Translate a nested function and bind it to a fresh local variable
    """
    fundef = translate_function_ast(node.name, node.args, node.body, parent = self)
    local_var = self.fresh_var(node.name)
    return Assign(local_var, fundef)

def visit_Lambda(self, node):
    """Translate a lambda as a nested function whose body returns its expression."""
    return translate_function_ast("lambda", node.args, [ast.Return(node.body)],
                                  parent = self)