def tuple_from_shape(self, expr):
    """Evaluate ``expr`` abstractly and package the result as a ``Tuple``.

    * a ``Shape`` result contributes its ``dims``;
    * a ``Const`` result becomes a one-element tuple holding its ``value``;
    * any other result yields ``expr.type.rank`` copies of ``any_scalar``.
    """
    abstract = self.visit_expr(expr)
    kind = abstract.__class__  # exact-class match, subclasses do not qualify
    if kind is Shape:
        return Tuple(tuple(abstract.dims))
    if kind is Const:
        return Tuple((abstract.value,))
    # Nothing more precise is available: one placeholder scalar per dimension
    # of the expression's static type.
    placeholders = (any_scalar,) * expr.type.rank
    return Tuple(placeholders)
def visit_Struct(self, expr):
    """Compute the abstract value of a struct-construction expression.

    The result depends on ``expr.type``:

    * ``ArrayT``  -> ``make_shape`` applied to the ``elts`` of the abstract
      value of the second argument (``expr.args[1]``);
    * ``TupleT``  -> ``Tuple`` of the abstract values of all arguments;
    * ``SliceT``  -> ``Slice`` built from exactly three argument values;
    * otherwise   -> ``unknown_value``.
    """
    struct_type = expr.type

    if isinstance(struct_type, ArrayT):
        # Only args[1] is evaluated for arrays; presumably it is the shape
        # tuple of the array being constructed -- confirm against ArrayT.
        dims = self.visit_expr(expr.args[1]).elts
        return make_shape(dims)

    if isinstance(struct_type, TupleT):
        return Tuple(self.visit_expr_list(expr.args))

    if isinstance(struct_type, SliceT):
        # Unpacking enforces that there are exactly three arguments.
        start, stop, step = self.visit_expr_list(expr.args)
        return Slice(start, stop, step)

    return unknown_value
def visit_Attribute(self, expr):
    """Compute the abstract value of an attribute access ``expr.value.<name>``.

    Dispatches on the exact class of the abstract value of ``expr.value``:

    * ``Shape``: ``shape`` -> ``Tuple`` of the dims; ``strides`` -> ``Tuple``
      of one ``any_scalar`` per dim; ``offset``/``size``/``nelts`` ->
      ``any_scalar``; ``data`` -> ``Ptr(any_scalar)``.
    * ``Tuple``: the name is an index, optionally prefixed with ``elt``.
    * ``Slice``: plain ``getattr`` on the abstract slice.
    * ``Closure``: the name is an index into ``args``, optionally prefixed
      with ``elt`` or ``closure_elt``.
    * ``Struct``: the value stored at the position of the field name.

    Anything not handled above (including unrecognised ``Shape`` attributes)
    falls back on the declared field type: ``any_scalar`` for scalar fields,
    ``any_value`` otherwise.
    """
    base = self.visit_expr(expr.value)
    attr = expr.name
    kind = base.__class__

    if kind is Shape:
        if attr == 'shape':
            return Tuple(base.dims)
        if attr == 'strides':
            return Tuple((any_scalar,) * len(base.dims))
        if attr in ('offset', 'size', 'nelts'):
            return any_scalar
        if attr == 'data':
            return Ptr(any_scalar)
        # Unrecognised array attribute: handled by the fallback below.
    elif kind is Tuple:
        digits = attr[3:] if attr.startswith('elt') else attr
        return base[int(digits)]
    elif kind is Slice:
        return getattr(base, attr)
    elif kind is Closure:
        if attr.startswith('elt'):
            digits = attr[3:]
        elif attr.startswith('closure_elt'):
            digits = attr[len('closure_elt'):]
        else:
            digits = attr
        return base.args[int(digits)]
    elif kind is Struct:
        position = base.fields.index(attr)
        return base.values[position]

    # Fallback: consult the static type of the object being accessed.
    field_t = expr.value.type.field_type(attr)
    return any_scalar if isinstance(field_t, ScalarT) else any_value
def from_type(self, t):
    """Build a fresh abstract value for something of static type ``t``.

    Mapping (checked in this order):

    * ``ScalarT`` (isinstance)      -> ``self.fresh_var()``
    * ``ArrayT`` (exact class)      -> ``Shape`` of ``t.rank`` fresh variables
    * ``TupleT`` (exact class)      -> ``Tuple`` of values for ``t.elt_types``
    * ``SliceT`` (exact class)      -> ``Slice`` of values for the start,
      stop and step types
    * ``ClosureT`` (exact class)    -> ``Closure(t.fn, <values for arg_types>)``
    * ``FnT`` (exact class)         -> ``Closure(t.fn, ())``
    * ``StructT`` (isinstance)      -> ``Struct`` of field names and values
      taken from ``t._fields_``
    * ``TypeValueT`` / ``NoneT``    -> empty ``Tuple``
    * ``PtrT`` (isinstance)         -> ``Ptr(any_scalar)``

    Raises:
        AssertionError: if ``t`` matches none of the cases above.
    """
    if isinstance(t, ScalarT):
        return self.fresh_var()
    elif t.__class__ is ArrayT:
        dim_vars = [self.fresh_var() for _ in range(t.rank)]
        return Shape(dim_vars)
    elif t.__class__ is TupleT:
        elt_values = self.from_types(t.elt_types)
        return Tuple(elt_values)
    elif t.__class__ is SliceT:
        start = self.from_type(t.start_type)
        stop = self.from_type(t.stop_type)
        step = self.from_type(t.step_type)
        return Slice(start, stop, step)
    elif t.__class__ is ClosureT:
        arg_vals = self.from_types(t.arg_types)
        return Closure(t.fn, arg_vals)
    elif t.__class__ is FnT:
        return Closure(t.fn, ())
    elif isinstance(t, StructT):
        field_names = [name for (name, _) in t._fields_]
        field_types = [field_t for (_, field_t) in t._fields_]
        field_vals = self.from_types(field_types)
        return Struct(field_names, field_vals)
    elif isinstance(t, (TypeValueT, NoneT)):
        return Tuple(())
    elif isinstance(t, PtrT):
        return Ptr(any_scalar)
    else:
        # Raised explicitly rather than via ``assert False``: assert
        # statements are removed under ``python -O``, which would make this
        # branch silently return None.  The exception type is unchanged.
        # The 1-tuple keeps %-formatting safe whatever ``t`` is.
        raise AssertionError("Unsupported type: %s" % (t,))
def subst(x, env):
    """Return abstract value ``x`` with its variables replaced using ``env``.

    ``env`` maps ``Var`` objects to their replacements.  Substitution recurses
    through ``Shape`` dims, ``Tuple`` elements, ``Closure`` args and the
    ``elt_shape`` of a ``Ptr``; other ``Scalar`` values are returned as-is.

    Raises:
        AssertionError: if a ``Var`` is not present in ``env``.
        RuntimeError: if ``x`` is none of the handled kinds.
    """
    # Var must be tested before Scalar; presumably Var is a kind of Scalar,
    # so swapping the two checks would skip substitution -- confirm.
    if isinstance(x, Var):
        assert x in env, "Unknown variable %s" % x
        return env[x]
    if isinstance(x, Scalar):
        return x
    if isinstance(x, Shape):
        new_dims = subst_list(x.dims, env)
        return make_shape(new_dims)
    if isinstance(x, Tuple):
        new_elts = subst_list(x.elts, env)
        return Tuple(tuple(new_elts))
    if isinstance(x, Closure):
        new_args = subst_list(x.args, env)
        return Closure(x.fn, new_args)
    if isinstance(x, Ptr):
        return Ptr(subst(x.elt_shape, env))
    raise RuntimeError("Unexpected abstract expression: %s" % x)
def visit_Tuple(self, expr):
    """Return a ``Tuple`` of the abstract values of ``expr.elts``."""
    elt_values = self.visit_expr_list(expr.elts)
    return Tuple(elt_values)
def concat_tuples(self, t1, t2):
    """Return a ``Tuple`` whose elements are those of ``t1`` then ``t2``."""
    combined = t1.elts + t2.elts
    return Tuple(combined)
def tuple(self, elts):
    """Wrap the items of ``elts`` in an abstract ``Tuple``."""
    # NOTE(review): the call below is meant to reach the builtin ``tuple``.
    # That holds when this def is a method inside a class body; at module
    # level the name would resolve to this function instead -- confirm.
    frozen = tuple(elts)
    return Tuple(frozen)
def shape(self, x):
    """Return the dims of ``x`` as a ``Tuple``; empty when ``x`` is no ``Shape``."""
    dims = x.dims if isinstance(x, Shape) else ()
    return Tuple(dims)