def linearize_accesses(iet, key, cache, sregistry):
    """
    Turn Indexeds into FIndexeds and create the necessary access Macros.

    Parameters
    ----------
    iet : Node
        The IET whose accesses should be linearized.
    key : callable
        Predicate on Functions; only Functions `f` such that `key(f)` is
        truthy and `f.ndim > 1` are considered for linearization.
    cache : dict-like
        Per-Function state shared across invocations (i.e. across IETs).
        Each entry exposes `stmts0` (size exprs), `stmts1` (stride exprs)
        and `cbk` (callback turning an Indexed into an FIndexed). Entries
        are accessed via `cache[f]` for Functions never seen before, so
        `cache` is expected to create missing entries on access
        (NOTE(review): presumably a defaultdict-like object; confirm with
        the caller).
    sregistry : object
        Provides `make_name`, used to generate fresh symbol names.

    Returns
    -------
    iet : Node
        The IET with linearized accesses and, if it defines any of the
        linearized Functions, the size/stride exprs prepended to its body.
    headers : list
        The headers returned by `_generate_macro` for each Function whose
        access macro was generated by this call.
    args : list
        Stride symbols used by `iet` for Functions it does not define; they
        are not computed locally and must be provided by an ancestor. Free
        of duplicates, in first-occurrence order.
    """
    # `functions` are all Functions that `iet` may need to linearize
    functions = [f for f in FindSymbols().visit(iet) if key(f) and f.ndim > 1]
    functions = sorted(functions, key=lambda f: len(f.dimensions), reverse=True)

    # `functions_unseen` are all Functions that `iet` may need to linearize
    # and have not been seen while processing other IETs
    functions_unseen = [f for f in functions if f not in cache]

    # Find unique sizes (unique -> minimize necessary registers)
    mapper = DefaultOrderedDict(list)
    for f in functions:
        # NOTE: the outermost dimension is unnecessary
        for d in f.dimensions[1:]:
            # TODO: same grid + same halo => same padding, however this is
            # never asserted throughout the compiler yet... maybe should do
            # it when in debug mode at `prepare_arguments` time, ie right
            # before jumping to C?
            mapper[(d, f._size_halo[d], getattr(f, 'grid', None))].append(f)

    # For all unseen Functions, build the size exprs. For example:
    # `x_fsz0 = u_vec->size[1]`
    imapper = DefaultOrderedDict(dict)
    for (d, _, _), v in mapper.items():
        v_unseen = [f for f in v if f in functions_unseen]
        if not v_unseen:
            continue
        # A single size expr is shared by all Functions in the group.
        # `_generate_fsz` may return a falsy value, in which case no size
        # expr is emitted for `d`
        expr = _generate_fsz(v_unseen[0], d, sregistry)
        if expr:
            for f in v_unseen:
                imapper[f][d] = expr.write
                cache[f].stmts0.append(expr)

    # For all unseen Functions, build the stride exprs. For example:
    # `y_stride0 = y_fsz0*z_fsz0`
    # NOTE(review): this assumes a size expr exists for every Dimension from
    # `d` inwards; otherwise `v[i]` raises KeyError
    built = {}
    mapper = DefaultOrderedDict(dict)
    for f, v in imapper.items():
        for d in v:
            n = f.dimensions.index(d)
            expr = prod(v[i] for i in f.dimensions[n:])
            # Identical products are materialized only once, then shared
            try:
                stmt = built[expr]
            except KeyError:
                name = sregistry.make_name(prefix='%s_stride' % d.name)
                s = Symbol(name=name, dtype=np.int64, is_const=True)
                stmt = built[expr] = DummyExpr(s, expr, init=True)
            mapper[f][d] = stmt.write
            cache[f].stmts1.append(stmt)
    mapper.update([(f, {}) for f in functions_unseen if f not in mapper])

    # For all Functions lacking an access macro, build defines. For example:
    # `#define uL(t, x, y, z) u[(t)*t_stride0 + (x)*x_stride0 + (y)*y_stride0 + (z)]`
    headers = []
    findexeds = {}
    for f in functions:
        if cache[f].cbk is None:
            header, cbk = _generate_macro(f, mapper[f], sregistry)
            headers.append(header)
            cache[f].cbk = findexeds[f] = cbk
        else:
            # Macro already built while processing another IET
            findexeds[f] = cache[f].cbk

    # Build "functional" Indexeds. For example:
    # `u[t2, x+8, y+9, z+7] => uL(t2, x+8, y+9, z+7)`
    # A membership test (rather than catching KeyError around the callback
    # call) avoids silently swallowing errors raised by the callback itself
    mapper = {}
    indexeds = FindSymbols('indexeds').visit(iet)
    for i in indexeds:
        if i.function in findexeds:
            mapper[i] = findexeds[i.function](i)

    # Introduce the linearized expressions
    iet = Uxreplace(mapper).visit(iet)

    # All Functions that actually require linearization in `iet`
    candidates = []
    candidates.extend(filter_ordered(i.function for i in indexeds))
    calls = FindNodes(Call).visit(iet)
    cfuncs = filter_ordered(flatten(i.functions for i in calls))
    candidates.extend(i for i in cfuncs if i.function.is_DiscreteFunction)

    # All Functions that can be linearized in `iet`
    defines = FindSymbols('defines-aliases').visit(iet)

    # Place the linearization expressions or delegate to ancestor efunc
    stmts0 = []
    stmts1 = []
    args = []
    for f in candidates:
        if f not in cache:
            # Not tracked in `cache` (e.g. rejected by `key` or 1D), so
            # there's nothing to place; skipping also avoids inserting empty
            # entries into `cache`
            continue
        if f in defines:
            stmts0.extend(cache[f].stmts0)
            stmts1.extend(cache[f].stmts1)
        else:
            args.extend(e.write for e in cache[f].stmts1)
    # Stride exprs may be shared among Functions, and a Function may appear
    # both as an Indexed and as a Call argument, so drop duplicates
    args = filter_ordered(args)

    if stmts0:
        assert len(stmts1) > 0
        stmts0 = filter_ordered(stmts0) + [BlankLine]
        stmts1 = filter_ordered(stmts1) + [BlankLine]
        body = iet.body._rebuild(body=tuple(stmts0) + tuple(stmts1) + iet.body.body)
        iet = iet._rebuild(body=body)
    else:
        # Stride exprs are only ever built for Functions that also received
        # size exprs, so no size exprs implies no stride exprs
        assert len(stmts1) == 0

    return iet, headers, args
def linearize_accesses(iet, cache, sregistry):
    """
    Turn Indexeds into FIndexeds and create the necessary access Macros.

    Parameters
    ----------
    iet : Node
        The IET whose accesses should be linearized.
    cache : dict-like
        Per-Function state shared across invocations. Each entry exposes
        `stmts0` (size exprs), `stmts1` (slice-size exprs) and `cbk`
        (callback turning an Indexed into an FIndexed). Entries are accessed
        via `cache[f]` for Functions never seen before (NOTE(review):
        presumably a defaultdict-like object; confirm with the caller).
    sregistry : object
        Provides `make_name`, used to generate fresh symbol names.

    Returns
    -------
    iet : Node
        The IET with linearized accesses and the size/slice-size exprs of
        all its linearizable objects prepended to its body.
    headers : list of (str, str)
        `(ccode(define), ccode(expr))` pairs, one per access macro built by
        this call.
    """
    # Find all objects amenable to linearization: Functions/Arrays with more
    # than one Dimension that are accessed via at least one Indexed
    symbol_names = {i.name for i in FindSymbols('indexeds').visit(iet)}
    functions = [f for f in FindSymbols().visit(iet)
                 if ((f.is_DiscreteFunction or f.is_Array) and
                     f.ndim > 1 and
                     f.name in symbol_names)]
    functions = sorted(functions, key=lambda f: len(f.dimensions), reverse=True)

    # Find unique sizes (unique -> minimize necessary registers)
    mapper = DefaultOrderedDict(list)
    for f in functions:
        if f not in cache:
            # NOTE: the outermost dimension is unnecessary
            for d in f.dimensions[1:]:
                # TODO: same grid + same halo => same padding, however this is
                # never asserted throughout the compiler yet... maybe should do
                # it when in debug mode at `prepare_arguments` time, ie right
                # before jumping to C?
                mapper[(d, f._size_halo[d], getattr(f, 'grid', None))].append(f)

    # Build all exprs such as `x_fsz0 = u_vec->size[1]`
    imapper = DefaultOrderedDict(list)
    for (d, _, _), v in mapper.items():
        name = sregistry.make_name(prefix='%s_fsz' % d.name)
        s = Symbol(name=name, dtype=np.int32, is_const=True)
        try:
            expr = DummyExpr(s, v[0]._C_get_field(FULL, d).size, init=True)
        except AttributeError:
            # Objects without `_C_get_field` are expected to be Arrays, whose
            # size is taken from their symbolic shape instead
            assert v[0].is_Array
            expr = DummyExpr(s, v[0].symbolic_shape[d], init=True)
        for f in v:
            imapper[f].append((d, s))
            cache[f].stmts0.append(expr)

    # Build all exprs such as `y_slc0 = y_fsz0*z_fsz0`
    built = {}
    mapper = DefaultOrderedDict(list)
    for f, v in imapper.items():
        for n, (d, _) in enumerate(v):
            # Product of the size symbols from the n-th entry inwards
            expr = prod(list(zip(*v[n:]))[1])
            # Identical products are materialized only once, then shared
            try:
                stmt = built[expr]
            except KeyError:
                name = sregistry.make_name(prefix='%s_slc' % d.name)
                s = Symbol(name=name, dtype=np.int32, is_const=True)
                stmt = built[expr] = DummyExpr(s, expr, init=True)
            mapper[f].append(stmt.write)
            cache[f].stmts1.append(stmt)
    mapper.update([(f, []) for f in functions if f not in mapper])

    # Build defines. For example:
    # `define uL(t, x, y, z) u[(t)*t_slice_sz + (x)*x_slice_sz + (y)*y_slice_sz + (z)]`
    headers = []
    findexeds = {}
    for f, szs in mapper.items():
        if cache[f].cbk is not None:
            # Perhaps we've already built an access macro for `f` through another efunc
            findexeds[f] = cache[f].cbk
        else:
            assert len(szs) == len(f.dimensions) - 1
            pname = sregistry.make_name(prefix='%sL' % f.name)
            # `zip` stops at the shorter `szs`, so every Dimension but the
            # innermost is scaled; the innermost is added unscaled below
            expr = sum([MacroArgument(d.name)*s for d, s in zip(f.dimensions, szs)])
            expr += MacroArgument(f.dimensions[-1].name)
            expr = Indexed(IndexedData(f.name, None, f), expr)
            define = DefFunction(pname, f.dimensions)
            headers.append((ccode(define), ccode(expr)))
            # `pname` bound as a default to avoid late-binding in the closure
            cache[f].cbk = findexeds[f] = lambda i, pname=pname: FIndexed(i, pname)

    # Build "functional" Indexeds. For example:
    # `u[t2, x+8, y+9, z+7] => uL(t2, x+8, y+9, z+7)`
    # Only Expressions that actually change are mapped, so untouched nodes
    # are not needlessly rebuilt by the Transformer
    mapper = {}
    for n in FindNodes(Expression).visit(iet):
        subs = {}
        for i in retrieve_indexed(n.expr):
            if i.function in findexeds:
                subs[i] = findexeds[i.function](i)
        if subs:
            mapper[n] = n._rebuild(expr=uxreplace(n.expr, subs))

    # Put together all of the necessary exprs for `y_fsz0`, ..., `y_slc0`, ...
    stmts0 = filter_ordered(flatten(cache[f].stmts0 for f in functions))
    if stmts0:
        stmts0.append(BlankLine)
    stmts1 = filter_ordered(flatten(cache[f].stmts1 for f in functions))
    if stmts1:
        stmts1.append(BlankLine)

    if mapper:
        iet = Transformer(mapper).visit(iet)
    if stmts0 or stmts1:
        body = iet.body._rebuild(body=tuple(stmts0) + tuple(stmts1) + iet.body.body)
        iet = iet._rebuild(body=body)

    return iet, headers