def _mem_summary(self):
    """
    Estimate the Operator's memory footprint, in bytes.

    The estimate is symbolic: one expression per memory scope (``'external'``,
    ``'heap'``, ``'stack'``), plus their sum under ``'total'``. Functions are
    collected from the Operator itself and from the root of every entry in
    its function table.

    NOTE(review): every Function is converted to bytes using the Operator's
    own dtype (``self._dtype``), not the Function's dtype.
    """
    roots = [self] + [entry.root for entry in self._func_table.values()]
    functions = [p for p in derive_parameters(roots) if p.is_Function]
    itemsize = self._dtype().itemsize

    def footprint(in_scope):
        # Sum of the element counts of all Functions living in a given scope
        shapes = [f.symbolic_shape for f in functions if in_scope(f)]
        return sum(reduce(mul, shape, 1) for shape in shapes) * itemsize

    summary = {
        'external': footprint(lambda f: f._mem_external),
        'heap': footprint(lambda f: f._mem_heap),
        'stack': footprint(lambda f: f._mem_stack),
    }
    summary['total'] = summary['external'] + summary['heap'] + summary['stack']
    return summary
def _build_parameters(self, iet):
    """
    Derive the Operator parameters.

    Returns a tuple made of the parameters derived from ``iet`` followed by
    the entries of ``self.input`` that are not already among them.
    """
    derived = derive_parameters(iet, True)
    # Hackish: some parameters are not embedded directly in any IET node
    # (e.g. those produced by the DLE or by a backend), so they are picked
    # up from the Operator input instead
    extra = [obj for obj in self.input if obj not in derived]
    return tuple(derived) + tuple(extra)
def _mem_summary(self):
    """
    Return the amount of data, in bytes, used by the Operator.

    The result maps each memory scope (``'external'``, ``'heap'``,
    ``'stack'``) to a symbolic byte count, plus a ``'total'`` entry summing
    the three. All Functions reachable from the Operator and from the roots
    of its function table are considered.

    NOTE(review): byte counts use ``self._dtype`` for every Function,
    regardless of the Function's own dtype.
    """
    roots = [self]
    roots.extend(entry.root for entry in self._func_table.values())
    funcs = [obj for obj in derive_parameters(roots) if obj.is_Function]
    nbytes = self._dtype().itemsize

    summary = {}
    for scope in ('external', 'heap', 'stack'):
        flag = '_mem_%s' % scope
        volume = sum(reduce(mul, fn.symbolic_shape, 1)
                     for fn in funcs if getattr(fn, flag))
        summary[scope] = volume * nbytes
    summary['total'] = summary['external'] + summary['heap'] + summary['stack']
    return summary
def _make_fetchupdate(self, iet, sync_ops, pieces, *args):
    """
    Move ``iet`` into a static ``init_device`` Callable and return a call to it.

    The Callable body is ``iet.body`` followed by one update-device transfer
    per SyncOp. The Callable is appended to ``pieces.funcs``. The returned
    List carries an "Initialize data stream" comment header and a Call to
    the new Callable. Extra positional ``args`` are accepted and ignored.
    """
    def target_mask(sop):
        # Restrict the transfer to the `tstore` slab along the SyncOp dimension
        return [(sop.tstore, sop.size) if d.root is sop.dim.root else FULL
                for d in sop.dimensions]

    # Construct the update-device transfers, executed after the init body
    updates = []
    for sop in sync_ops:
        # The condition is already encoded in `iet` with a Conditional,
        # which stems from the originating Cluster's guards
        assert sop.fcond is None
        updates.append(PragmaTransfer(self.lang._map_update_device,
                                      sop.target, imask=target_mask(sop)))

    # Turn init IET into a Callable
    functions = filter_ordered(flatten([(s.target, s.function) for s in sync_ops]))
    name = self.sregistry.make_name(prefix='init_device')
    body = List(body=iet.body + tuple(updates))
    parameters = filter_sorted(functions + derive_parameters(body))
    pieces.funcs.append(Callable(name, body, 'void', parameters, 'static'))

    # The main thread performs the initial fetch by calling the new Callable
    return List(header=c.Comment("Initialize data stream"),
                body=Call(name, parameters))
def _finalize(self, iet, parameters):
    """
    Finalize every entry of the function table, then the main IET.

    Each entry of ``self._func_table`` is finalized with parameters derived
    from its own root and replaced by a new ``MetaCall`` that keeps the
    original ``local`` flag. ``iet`` is finalized with ``parameters`` as
    passed in by the caller.
    """
    for key, meta in self._func_table.items():
        # Use a dedicated local: rebinding `parameters` here (as done before)
        # clobbered the Operator's parameters used in the final call below
        efunc_parameters = derive_parameters(meta.root, True)
        root = super(OperatorOPS, self)._finalize(meta.root, efunc_parameters)
        # Overwriting an existing key while iterating does not resize the dict
        self._func_table[key] = MetaCall(root, meta.local)
    return super()._finalize(iet, parameters)
def _finalize(self, iet, parameters):
    """
    Finalize each generated OPS kernel, then the main IET.

    Every kernel in ``self._ops_kernels`` is finalized with parameters
    derived from the kernel itself (per the original note, this is what
    generates the .h file). ``iet`` is finalized with ``parameters``.
    """
    finalized = []
    for kernel in self._ops_kernels:
        kernel_parameters = derive_parameters(kernel, True)
        finalized.append(super(OperatorOPS, self)._finalize(kernel, kernel_parameters))
    self._ops_kernels = finalized
    return super()._finalize(iet, parameters)
def _lower_iet(cls, stree, profiler, **kwargs):
    """
    Lower a ScheduleTree into a target-specific Iteration/Expression tree.

    Steps performed here:

        1. build an IET from ``stree``;
        2. let ``profiler`` analyze the IET Sections for C-level profiling;
        3. wrap the IET in an ``EntryFunction`` (the Callable acting as the
           entry point of the generated library), named ``kwargs['name']``
           (default ``"Kernel"``);
        4. specialize the resulting Graph via ``cls._specialize_iet``;
        5. instrument the Graph for C-level profiling.

    ``kwargs['sregistry']`` is required.

    Returns
    -------
    tuple
        ``(graph.root, graph)``.
    """
    kernel_name = kwargs.get("name", "Kernel")
    sregistry = kwargs['sregistry']

    tree = iet_build(stree)
    profiler.analyze(tree)

    entry = EntryFunction(kernel_name, tree, 'int', derive_parameters(tree, True), ())
    graph = cls._specialize_iet(Graph(entry), **kwargs)

    # Instrumentation comes after specialization, since specialization
    # may introduce further Sections
    instrument(graph, profiler=profiler, sregistry=sregistry)

    return graph.root, graph
def _lower_iet(cls, stree, profiler, **kwargs):
    """
    Lower a ScheduleTree into a target-specific Iteration/Expression tree.

    ``stree`` is turned into an IET and instrumented by ``profiler`` for
    C-level profiling. The result is wrapped in an ``int``-returning
    Callable named ``kwargs['name']`` (default ``"Kernel"``). The remaining
    lowering stages are delegated to ``cls._specialize_iet``.

    Returns
    -------
    tuple
        ``(graph.root, graph)``.
    """
    kernel_name = kwargs.get("name", "Kernel")

    instrumented = profiler.instrument(iet_build(stree))
    kernel = Callable(kernel_name, instrumented, 'int',
                      derive_parameters(instrumented, True), ())

    graph = cls._specialize_iet(Graph(kernel), **kwargs)
    return graph.root, graph
def update_halo(f, fixed):
    """
    Construct an IET performing a halo exchange for a :class:`TensorFunction`.

    Parameters
    ----------
    f : TensorFunction
        Must be a Function with an associated Grid.
    fixed : iterable of Dimension
        Dimensions along which no exchange is generated. Each is mapped to a
        Symbol named ``o<root>`` before the views of ``f`` are computed.

    Returns
    -------
    Callable
        A static ``halo_exchange_<f.name>`` routine that, for every
        non-fixed Dimension ``d``, conditionally calls ``sendrecv_<f.name>``
        twice: guarded by ``m<d>l`` (send left, receive right) and by
        ``m<d>r`` (send right, receive left).
    """
    # Requirements
    assert f.is_Function
    assert f.grid is not None

    distributor = f.grid.distributor
    neighbours = distributor._C_neighbours.obj
    comm = distributor._C_comm
    routine = 'sendrecv_%s' % f.name

    fixed = {d: Symbol(name="o%s" % d.root) for d in fixed}
    views = get_views(f, fixed)

    body = []
    for d in f.dimensions:
        if d in fixed:
            continue

        rpeer = FieldFromPointer("%sright" % d, neighbours)
        lpeer = FieldFromPointer("%sleft" % d, neighbours)

        # Each entry: (side whose OWNED region is sent, side whose HALO
        # region is received, peer arguments, guard-symbol suffix)
        exchanges = (
            (LEFT, RIGHT, [rpeer, lpeer], 'l'),   # send to left, receive from right
            (RIGHT, LEFT, [lpeer, rpeer], 'r'),   # send to right, receive from left
        )
        for owned_side, halo_side, peers, suffix in exchanges:
            send_sizes, send_offsets = views[(d, owned_side, OWNED)]
            recv_sizes, recv_offsets = views[(d, halo_side, HALO)]
            assert send_sizes == recv_sizes

            arguments = ([f] + list(f.symbolic_shape) + send_sizes +
                         send_offsets + recv_offsets + peers + [comm])
            guard = Symbol(name='m%s%s' % (d, suffix))
            body.append(Conditional(guard, Call(routine, arguments)))

    iet = List(body=body)
    parameters = derive_parameters(iet, drop_locals=True)
    return Callable('halo_exchange_%s' % f.name, iet, 'void', parameters, ('static', ))
def __make_tfunc(self, name, iet, root, threads):
    """
    Wrap ``iet`` into a static, spinning thread Callable named ``name``.

    Parameters required by ``iet`` that are neither parameters of ``root``
    nor shared Arrays are passed through a newly created ``SharedData``
    object. The thread body unpacks them from it.

    Returns
    -------
    (Callable, SharedData)
    """
    # Create the SharedData
    required = derive_parameters(iet)
    known = (root.parameters +
             tuple(i for i in required if i.is_Array and i._mem_shared))
    parameters, dynamic_parameters = split(required, lambda i: i in known)

    sdata = SharedData(name=self.sregistry.make_name(prefix='sdata'),
                       nthreads_std=threads.size, fields=dynamic_parameters)
    parameters.append(sdata)

    def flag():
        # A fresh access to the SharedData flag field
        return FieldFromPointer(sdata._field_flag, sdata.symbolic_base)

    # Prepend the unwound SharedData fields, available upon thread activation
    unpack = [DummyExpr(i, FieldFromPointer(i.name, sdata.symbolic_base))
              for i in dynamic_parameters]
    unpack.append(DummyExpr(sdata.symbolic_id,
                            FieldFromPointer(sdata._field_id, sdata.symbolic_base)))

    # Append the flag reset (flag := 1) once the work is done
    reset = List(body=[BlankLine, DummyExpr(flag(), 1)])

    body = List(body=unpack + [iet] + [reset])
    # The thread has work to do only when the flag equals 2 -- per the
    # original design, the main thread's signal that all locks have been
    # set to 0
    body = Conditional(CondEq(flag(), 2), body)
    # The thread keeps spinning until the flag is set to 0 by the main thread
    body = While(CondNe(flag(), 0), body)

    return Callable(name, body, 'void', parameters, 'static'), sdata
def _build_parameters(self, nodes):
    """
    Return the Operator parameters for the Iteration/Expression tree ``nodes``.

    They are computed by ``derive_parameters`` with ``drop_locals=True``.
    """
    parameters = derive_parameters(nodes, drop_locals=True)
    return parameters
def _build(cls, expressions, **kwargs):
    """
    Build an Operator (an instance of ``cls``) from ``expressions``.

    Pipeline: expressions -> Clusters -> ScheduleTree -> Iteration/Expression
    tree (IET) -> target-specific IET. The resulting Callable is then turned
    into an instance of ``cls`` and decorated with the metadata needed for
    jit-compilation and execution.

    Parameters
    ----------
    expressions : Eq or iterable of Eq
        The equations to be lowered.
    **kwargs
        ``name`` (default ``"Kernel"``) and ``dse`` (default
        ``configuration['dse']``) are read here. The whole ``kwargs`` is also
        forwarded to ``_lower_exprs``, ``_specialize_iet`` and
        ``_initialize_state``.

    Raises
    ------
    InvalidOperator
        If any of ``expressions`` is not a ``devito.Eq``.
    """
    expressions = as_tuple(expressions)

    # Input check
    if any(not isinstance(i, Eq) for i in expressions):
        raise InvalidOperator("Only `devito.Eq` expressions are allowed.")

    name = kwargs.get("name", "Kernel")
    dse = kwargs.get("dse", configuration['dse'])

    # Python-level (i.e., compile time) and C-level (i.e., run time) performance
    profiler = create_profile('timers')

    # Lower input expressions to internal expressions (e.g., attaching metadata)
    expressions = cls._lower_exprs(expressions, **kwargs)

    # Group expressions based on their iteration space and data dependences
    # Several optimizations are applied (fusion, lifting, flop reduction via DSE, ...)
    clusters = clusterize(expressions, dse_mode=set_dse_mode(dse))

    # Lower Clusters to a Schedule tree
    stree = st_build(clusters)

    # Lower Schedule tree to an Iteration/Expression tree (IET)
    iet = iet_build(stree)

    # Instrument the IET for C-level profiling
    iet = profiler.instrument(iet)

    # Wrap the IET with a Callable
    parameters = derive_parameters(iet, True)
    op = Callable(name, iet, 'int', parameters, ())

    # Lower IET to a Target-specific IET
    op, target_state = cls._specialize_iet(op, **kwargs)

    # Make it an actual Operator
    # NOTE(review): the right-hand side of the next line reads `op.args` from
    # the specialized Callable; the `__init__` call then reads `op.args` from
    # the new, not-yet-initialized instance. This presumably relies on
    # `__new__` populating `args` -- verify before touching these two lines.
    op = Callable.__new__(cls, **op.args)
    Callable.__init__(op, **op.args)

    # Header files, etc.
    op._headers = list(cls._default_headers)
    op._headers.extend(target_state.headers)
    op._globals = list(cls._default_globals)
    op._includes = list(cls._default_includes)
    op._includes.extend(profiler._default_includes)
    op._includes.extend(target_state.includes)

    # Required for the jit-compilation
    op._compiler = configuration['compiler']
    op._lib = None
    op._cfunction = None

    # References to local or external routines: the profiler's external
    # calls (with no root) followed by the target-specific functions, keyed
    # by their root name
    op._func_table = OrderedDict()
    op._func_table.update(
        OrderedDict([(i, MetaCall(None, False)) for i in profiler._ext_calls]))
    op._func_table.update(
        OrderedDict([(i.root.name, i) for i in target_state.funcs]))

    # Internal state. May be used to store information about previous runs,
    # autotuning reports, etc
    op._state = cls._initialize_state(**kwargs)

    # Produced by the various compilation passes
    op._input = filter_sorted(
        flatten(e.reads + e.writes for e in expressions))
    op._output = filter_sorted(flatten(e.writes for e in expressions))
    op._dimensions = filter_sorted(
        flatten(e.dimensions for e in expressions))
    op._dimensions.extend(target_state.dimensions)
    op._dtype, op._dspace = clusters.meta
    op._profiler = profiler

    return op
def _build_parameters(self, iet):
    """
    Return the Operator parameters for the Iteration/Expression tree ``iet``.

    They are computed by ``derive_parameters`` with ``drop_locals=True``.
    """
    parameters = derive_parameters(iet, drop_locals=True)
    return parameters
def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
    """
    Set up an initial device fetch plus an asynchronous prefetch thread.

    For each SyncOp this builds three pragma-level pieces:

    * an initial fetch, placed in a static ``init_device`` Callable that the
      main thread calls (appended to ``pieces.init``);
    * "present" clauses, attached to the returned IET;
    * a prefetch of the next block, executed by a thread created via
      ``make_thread_ctx``.

    ``pieces`` is updated in place (funcs, init, threads, finalize). The
    returned IET busy-waits on the thread flag, then runs ``iet`` and
    re-activates the thread.

    NOTE: the order of the ``self.sregistry.make_name`` calls determines the
    generated names, so statements here should not be reordered lightly.
    """
    fetches = []
    prefetches = []
    presents = []
    for s in sync_ops:
        # Initial fetch index and next (prefetch) index, depending on the
        # iteration direction along `s.dim`
        if s.direction is Forward:
            fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
            pfc = s.fetch + 1
            fc_cond = s.next_cbk(s.dim.symbolic_min)
            pfc_cond = s.next_cbk(s.dim + 1)
        else:
            fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
            pfc = s.fetch - 1
            fc_cond = s.next_cbk(s.dim.symbolic_max)
            pfc_cond = s.next_cbk(s.dim - 1)

        # Construct init IET
        imask = [(fc, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
        fetch = PragmaList(self.lang._map_to(s.function, imask),
                           {s.function} | fc.free_symbols)
        fetches.append(Conditional(fc_cond, fetch))

        # Construct present clauses
        imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
        presents.extend(as_list(self.lang._map_present(s.function, imask)))

        # Construct prefetch IET
        imask = [(pfc, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
        prefetch = PragmaList(self.lang._map_to_wait(s.function, imask,
                                                     SharedData._field_id),
                              {s.function} | pfc.free_symbols)
        prefetches.append(Conditional(pfc_cond, prefetch))

    # Turn init IET into a Callable
    functions = filter_ordered(s.function for s in sync_ops)
    name = self.sregistry.make_name(prefix='init_device')
    body = List(body=fetches)
    parameters = filter_sorted(functions + derive_parameters(body))
    func = Callable(name, body, 'void', parameters, 'static')
    pieces.funcs.append(func)

    # Perform initial fetch by the main thread
    pieces.init.append(List(
        header=c.Comment("Initialize data stream"),
        body=[Call(name, parameters), BlankLine]
    ))

    # Turn prefetch IET into a ThreadFunction
    name = self.sregistry.make_name(prefix='prefetch_host_to_device')
    body = List(header=c.Line(), body=prefetches)
    tctx = make_thread_ctx(name, body, root, None, sync_ops, self.sregistry)
    pieces.funcs.extend(tctx.funcs)

    # Glue together all the IET pieces, including the activation logic:
    # spin until this thread's flag equals 1, then run `iet` under the
    # present clauses, then activate the thread again
    sdata = tctx.sdata
    threads = tctx.threads
    iet = List(body=[
        BlankLine,
        BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                           sdata[threads.index]), 1)),
        List(header=presents),
        iet,
        tctx.activate
    ])

    # Fire up the threads
    pieces.init.append(tctx.init)
    pieces.threads.append(threads)

    # Final wait before jumping back to Python land
    pieces.finalize.append(tctx.finalize)

    return iet
def _create_elemental_functions(self, nodes, state):
    """
    Extract :class:`Iteration` sub-trees and move them into :class:`Callable`s.

    Currently, only tagged, elementizable Iteration objects are targeted.
    For each such sub-tree, a static ``f_<tag>`` Callable is built whose loop
    bounds (and unbounded-index offsets) become scalar parameters. The
    sub-tree in ``nodes`` is replaced by a Call to it, preceded by a
    ``noinline`` decoration header.

    Returns
    -------
    tuple
        The transformed ``nodes`` and ``{'elemental_functions': ...}``, where
        the value is the collection of newly created Callables. ``state`` is
        not used.
    """
    noinline = self._compiler_decoration('noinline', c.Comment('noinline?'))

    functions = OrderedDict()
    mapper = {}
    for tree in retrieve_iteration_tree(nodes, mode='superset'):
        # Search an elementizable sub-tree (if any)
        tagged = filter_iterations(tree, lambda i: i.tag is not None, 'asap')
        if not tagged:
            continue
        root = tagged[0]
        if not root.is_Elementizable:
            continue
        target = tree[tree.index(root):]

        # Elemental function arguments
        args = []  # Found so far (scalars, tensors)
        defined_args = {}  # Map of argument values defined by loop bounds

        # Build a new Iteration/Expression tree with free bounds
        free = []
        for i in target:
            name, bounds = i.dim.name, i.bounds_symbolic

            # Iteration bounds
            start = Scalar(name='%s_start' % name, dtype=np.int32)
            finish = Scalar(name='%s_finish' % name, dtype=np.int32)
            defined_args[start.name] = bounds[0]
            defined_args[finish.name] = bounds[1]

            # Iteration unbounded indices
            ufunc = [Scalar(name='%s_ub%d' % (name, j), dtype=np.int32)
                     for j in range(len(i.uindices))]
            defined_args.update({uf.name: j.start for uf, j in zip(ufunc, i.uindices)})
            limits = [Scalar(name=start.name, dtype=np.int32),
                      Scalar(name=finish.name, dtype=np.int32), 1]
            uindices = [UnboundedIndex(j.index, i.dim + as_symbol(k))
                        for j, k in zip(i.uindices, ufunc)]
            free.append(i._rebuild(limits=limits, offsets=None, uindices=uindices))

        # Construct elemental function body, and inspect it
        free = NestedTransformer(dict((zip(target, free)))).visit(root)

        # Insert array casts for all non-defined
        f_symbols = FindSymbols('symbolics').visit(free)
        defines = [s.name for s in FindSymbols('defines').visit(free)]
        casts = [ArrayCast(f) for f in f_symbols
                 if f.is_Tensor and f.name not in defines]
        free = (List(body=casts), free)

        # Pair each parameter with the value passed at the call site:
        # loop-bound scalars get the original bound, Dimensions become
        # Scalars, everything else is passed through as is
        for i in derive_parameters(free):
            if i.name in defined_args:
                args.append((defined_args[i.name], i))
            elif i.is_Dimension:
                d = Scalar(name=i.name, dtype=i.dtype)
                args.append((d, d))
            else:
                args.append((i, i))
        # NOTE(review): raises ValueError if `args` is empty -- presumably
        # cannot happen since the free bounds are always parameters
        call, params = zip(*args)
        name = "f_%d" % root.tag

        # Produce the new Call
        mapper[root] = List(header=noinline, body=Call(name, call))

        # Produce the new Callable (the first one created for a tag wins)
        functions.setdefault(name, Callable(name, free, 'void', flatten(params), ('static', )))

    # Transform the main tree
    processed = Transformer(mapper).visit(nodes)

    return processed, {'elemental_functions': functions.values()}
def _make_fetchprefetch(self, iet, sync_ops, pieces, root):
    """
    Set up an initial device fetch plus an asynchronous prefetch thread.

    For each SyncOp, three PragmaTransfers are built:

    * an initial fetch at ``s.ifetch``, guarded by ``s.fcond``, placed in a
      static ``init_device`` Callable called by the main thread;
    * a "present" transfer at ``s.fetch``, prepended to ``iet``;
    * a prefetch at ``s.pfetch``, guarded by ``s.pcond``, run by a thread
      created via ``make_thread_ctx``.

    ``pieces`` is updated in place (funcs, init, finalize, objs). The returned
    IET busy-waits on the thread flag, runs ``iet`` and re-activates the
    thread.

    NOTE: the order of the ``self.sregistry.make_name`` calls determines the
    generated names.
    """
    fid = SharedData._field_id

    fetches = []
    prefetches = []
    presents = []
    for s in sync_ops:
        f = s.function
        dimensions = s.dimensions
        fc = s.fetch
        ifc = s.ifetch
        pfc = s.pfetch
        fcond = s.fcond
        pcond = s.pcond

        # Construct init IET
        imask = [(ifc, s.size) if d.root is s.dim.root else FULL for d in dimensions]
        fetch = PragmaTransfer(self.lang._map_to, f, imask=imask)
        fetches.append(Conditional(fcond, fetch))

        # Construct present clauses
        imask = [(fc, s.size) if d.root is s.dim.root else FULL for d in dimensions]
        presents.append(PragmaTransfer(self.lang._map_present, f, imask=imask))

        # Construct prefetch IET
        imask = [(pfc, s.size) if d.root is s.dim.root else FULL for d in dimensions]
        prefetch = PragmaTransfer(self.lang._map_to_wait, f, imask=imask, queueid=fid)
        prefetches.append(Conditional(pcond, prefetch))

    # Turn init IET into a Callable
    functions = filter_ordered(s.function for s in sync_ops)
    name = self.sregistry.make_name(prefix='init_device')
    body = List(body=fetches)
    parameters = filter_sorted(functions + derive_parameters(body))
    func = Callable(name, body, 'void', parameters, 'static')
    pieces.funcs.append(func)

    # Perform initial fetch by the main thread
    pieces.init.append(List(header=c.Comment("Initialize data stream"),
                            body=[Call(name, parameters), BlankLine]))

    # Turn prefetch IET into a ThreadFunction
    name = self.sregistry.make_name(prefix='prefetch_host_to_device')
    body = List(header=c.Line(), body=prefetches)
    tctx = make_thread_ctx(name, body, root, None, sync_ops, self.sregistry)
    pieces.funcs.extend(tctx.funcs)

    # Glue together all the IET pieces, including the activation logic:
    # spin until this thread's flag equals 1, then the present clauses,
    # then `iet`, then re-activate the thread
    sdata = tctx.sdata
    threads = tctx.threads
    iet = List(body=[
        BlankLine,
        BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                           sdata[threads.index]), 1))
    ] + presents + [iet, tctx.activate])

    # Fire up the threads
    pieces.init.append(tctx.init)

    # Final wait before jumping back to Python land
    pieces.finalize.append(tctx.finalize)

    # Keep track of created objects
    pieces.objs.add(sync_ops, sdata, threads)

    return iet
def __init__(self, expressions, **kwargs):
    """
    Build the Operator by lowering ``expressions`` down to an IET.

    Parameters
    ----------
    expressions : Eq or iterable of Eq
        The equations to be lowered.
    **kwargs
        ``name`` (default ``"Kernel"``), ``subs`` (default ``{}``) and ``dse``
        (default ``configuration['dse']``) are read here. The whole ``kwargs``
        is also forwarded to ``_initialize_state`` and ``_specialize_iet``.

    Raises
    ------
    InvalidOperator
        If any of ``expressions`` is not a ``devito.Eq``.
    """
    expressions = as_tuple(expressions)

    # Input check
    if any(not isinstance(i, Eq) for i in expressions):
        raise InvalidOperator("Only `devito.Eq` expressions are allowed.")

    self.name = kwargs.get("name", "Kernel")
    subs = kwargs.get("subs", {})
    dse = kwargs.get("dse", configuration['dse'])

    # Header files, etc.
    self._headers = list(self._default_headers)
    self._includes = list(self._default_includes)
    self._globals = list(self._default_globals)

    # Required for compilation
    self._compiler = configuration['compiler']
    self._lib = None
    self._cfunction = None

    # References to local or external routines
    self._func_table = OrderedDict()

    # Internal state. May be used to store information about previous runs,
    # autotuning reports, etc
    self._state = self._initialize_state(**kwargs)

    # Form and gather any required implicit expressions
    expressions = self._add_implicit(expressions)

    # Expression lowering: evaluation of derivatives, indexification,
    # substitution rules, specialization
    expressions = [i.evaluate for i in expressions]
    expressions = [indexify(i) for i in expressions]
    expressions = self._apply_substitutions(expressions, subs)
    expressions = self._specialize_exprs(expressions)

    # Expression analysis
    self._input = filter_sorted(
        flatten(e.reads + e.writes for e in expressions))
    self._output = filter_sorted(flatten(e.writes for e in expressions))
    self._dimensions = filter_sorted(
        flatten(e.dimensions for e in expressions))

    # Group expressions based on their iteration space and data dependences,
    # and apply the Devito Symbolic Engine (DSE) for flop optimization
    clusters = clusterize(expressions)
    clusters = rewrite(clusters, mode=set_dse_mode(dse))
    self._dtype, self._dspace = clusters.meta

    # Lower Clusters to a Schedule tree
    stree = st_build(clusters)

    # Lower Schedule tree to an Iteration/Expression tree (IET)
    iet = iet_build(stree)
    iet, self._profiler = self._profile_sections(iet)
    iet = self._specialize_iet(iet, **kwargs)

    # Derive all Operator parameters based on the IET
    parameters = derive_parameters(iet, True)

    # Finalization: introduce declarations, type casts, etc
    iet = self._finalize(iet, parameters)

    super(Operator, self).__init__(self.name, iet, 'int', parameters, ())
def __init__(self, expressions, **kwargs):
    """
    Build the Operator by lowering ``expressions`` down to an IET.

    Parameters
    ----------
    expressions : Eq or iterable of Eq
        The equations to be lowered.
    **kwargs
        ``name`` (default ``"Kernel"``), ``subs`` (default ``{}``) and ``dse``
        (default ``configuration['dse']``) are read here. The whole ``kwargs``
        is also forwarded to ``_specialize_iet``.

    Raises
    ------
    InvalidOperator
        If any of ``expressions`` is not a ``devito.Eq``.
    """
    expressions = as_tuple(expressions)

    # Input check
    if any(not isinstance(i, Eq) for i in expressions):
        raise InvalidOperator("Only `devito.Eq` expressions are allowed.")

    self.name = kwargs.get("name", "Kernel")
    subs = kwargs.get("subs", {})
    dse = kwargs.get("dse", configuration['dse'])

    # Header files, etc.
    self._headers = list(self._default_headers)
    self._includes = list(self._default_includes)
    self._globals = list(self._default_globals)

    # Required for compilation
    self._compiler = configuration['compiler']
    self._lib = None
    self._cfunction = None

    # References to local or external routines
    self._func_table = OrderedDict()

    # Internal state. May be used to store information about previous runs,
    # autotuning reports, etc
    self._state = {}

    # Form and gather any required implicit expressions
    expressions = self._add_implicit(expressions)

    # Expression lowering: indexification, substitution rules, specialization
    expressions = [indexify(i) for i in expressions]
    expressions = self._apply_substitutions(expressions, subs)
    expressions = self._specialize_exprs(expressions)

    # Expression analysis
    self._input = filter_sorted(flatten(e.reads + e.writes for e in expressions))
    self._output = filter_sorted(flatten(e.writes for e in expressions))
    self._dimensions = filter_sorted(flatten(e.dimensions for e in expressions))

    # Group expressions based on their iteration space and data dependences,
    # and apply the Devito Symbolic Engine (DSE) for flop optimization
    clusters = clusterize(expressions)
    clusters = rewrite(clusters, mode=set_dse_mode(dse))
    self._dtype, self._dspace = clusters.meta

    # Lower Clusters to a Schedule tree
    stree = st_build(clusters)

    # Lower Schedule tree to an Iteration/Expression tree (IET)
    iet = iet_build(stree)
    iet, self._profiler = self._profile_sections(iet)
    iet = self._specialize_iet(iet, **kwargs)

    # Derive all Operator parameters based on the IET
    parameters = derive_parameters(iet, True)

    # Finalization: introduce declarations, type casts, etc
    iet = self._finalize(iet, parameters)

    super(Operator, self).__init__(self.name, iet, 'int', parameters, ())
def make_yask_kernels(iet, **kwargs):
    """
    Offload the affine Iteration/Expression trees of ``iet`` onto YASK.

    For each affine section, a YASK compiler solution is generated and
    JIT-compiled into a kernel solution. The kernel solution is registered
    in ``kwargs['yk_solns']`` (mutated in place), and the section's trees
    are replaced by a single Offloaded call. Sections with mixed-precision
    arithmetic, or for which any step raises ``NotImplementedError``, are
    left untouched (the reason is logged).

    Returns
    -------
    tuple
        The transformed IET (with its parameters re-derived) and ``{}``.
    """
    yk_solns = kwargs.pop('yk_solns')

    mapper = {}
    for n, (section, trees) in enumerate(find_affine_trees(iet).items()):
        dimensions = tuple(filter_ordered(i.dim.root for i in flatten(trees)))

        # Retrieve the section dtype
        exprs = FindNodes(Expression).visit(section)
        dtypes = {e.dtype for e in exprs}
        if len(dtypes) != 1:
            log("Unable to offload in presence of mixed-precision arithmetic")
            continue
        dtype = dtypes.pop()

        context = contexts.fetch(dimensions, dtype)

        # A unique name for the 'real' compiler and kernel solutions
        name = namespace['jit-soln'](Signer._digest(configuration,
                                                    *[i.root for i in trees]))

        # Create a YASK compiler solution for this Operator
        yc_soln = context.make_yc_solution(name)

        try:
            # Generate YASK vars and populate `yc_soln` with equations
            local_vars = yaskit(trees, yc_soln)

            # Build the new IET nodes
            yk_soln_obj = YASKSolnObject(namespace['code-soln-name'](n))
            funcall = make_sharedptr_funcall(namespace['code-soln-run'], ['time'],
                                             yk_soln_obj)
            funcall = Offloaded(funcall, dtype)

            # JIT-compile the newly-created YASK kernel. This happens *before*
            # touching `mapper`: if it fails, the trees must stay untouched
            # rather than call a solution that was never registered
            yk_soln = context.make_yk_solution(name, yc_soln, local_vars)
            yk_solns[(dimensions, yk_soln_obj)] = yk_soln

            # Replace the first tree with the offloaded call, drop the others
            mapper[trees[0].root] = funcall
            mapper.update({i.root: mapper.get(i.root) for i in trees})  # Drop trees

            # Print some useful information about the newly constructed solution
            log("Solution '%s' contains %d var(s) and %d equation(s)." %
                (yc_soln.get_name(), yc_soln.get_num_vars(),
                 yc_soln.get_num_equations()))
        except NotImplementedError as e:
            log("Unable to offload a candidate tree. Reason: [%s]" % str(e))
    iet = Transformer(mapper).visit(iet)

    if not yk_solns:
        log("No offloadable trees found")

    # Some Iteration/Expression trees are not offloaded to YASK and may
    # require further processing to be executed through YASK, due to the
    # different storage layout
    yk_var_objs = {i.name: YASKVarObject(i.name)
                   for i in FindSymbols().visit(iet) if i.from_YASK}
    yk_var_objs.update({i: YASKVarObject(i) for i in get_local_vars(yk_solns)})
    iet = make_var_accesses(iet, yk_var_objs)

    # The signature needs to be updated
    # TODO: this could be done automagically through the iet pass engine, but
    # currently it only supports *appending* to the parameters list. While here
    # we actually need to change it as some parameters may disappear (x_m, x_M, ...)
    parameters = derive_parameters(iet, True)
    iet = iet._rebuild(parameters=parameters)

    return iet, {}
def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
    """
    Set up an initial device fetch plus an asynchronous prefetch thread.

    For each SyncOp, a bounds-guarded fetch (into a static ``init_device``
    Callable called by the main thread), a set of "present" clauses and a
    bounds-guarded prefetch of the next block (run by a thread Callable
    built with ``__make_tfunc``) are constructed. ``pieces`` is updated in
    place (funcs, init, threads, finalize).

    Returns
    -------
    The new IET: busy-wait on the thread flag, present clauses, ``iet``, and
    the thread activation.

    NOTE: the order of the ``self.sregistry.make_name`` calls determines the
    generated names.
    """
    threads = self.__make_threads()

    fetches = []
    prefetches = []
    presents = []
    for s in sync_ops:
        if s.direction is Forward:
            fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
            # NOTE(review): presumably the allocated extent of `s.function`
            # along `s.dim`; the guards keep the fetched block in bounds
            fsize = s.function._C_get_field(FULL, s.dim).size
            fc_cond = fc + (s.size - 1) < fsize
            pfc = s.fetch + 1
            pfc_cond = pfc + (s.size - 1) < fsize
        else:
            fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
            fc_cond = fc >= 0
            pfc = s.fetch - 1
            pfc_cond = pfc >= 0

        # Construct fetch IET
        imask = [(fc, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
        fetch = List(header=self._P._map_to(s.function, imask))
        fetches.append(Conditional(fc_cond, fetch))

        # Construct present clauses
        imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
        presents.extend(as_list(self._P._map_present(s.function, imask)))

        # Construct prefetch IET
        imask = [(pfc, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
        prefetch = List(header=self._P._map_to_wait(s.function, imask,
                                                    SharedData._field_id))
        prefetches.append(Conditional(pfc_cond, prefetch))

    functions = filter_ordered(s.function for s in sync_ops)
    casts = [PointerCast(f) for f in functions]

    # Turn init IET into a Callable
    name = self.sregistry.make_name(prefix='init_device')
    body = List(body=casts + fetches)
    parameters = filter_sorted(functions + derive_parameters(body))
    func = Callable(name, body, 'void', parameters, 'static')
    pieces.funcs.append(func)

    # Perform initial fetch by the main thread
    pieces.init.append(List(
        header=c.Comment("Initialize data stream for `%s`" % threads.name),
        body=[Call(name, func.parameters), BlankLine]
    ))

    # Turn prefetch IET into a threaded Callable
    name = self.sregistry.make_name(prefix='prefetch_host_to_device')
    body = List(header=c.Line(), body=casts + prefetches)
    tfunc, sdata = self.__make_tfunc(name, body, root, threads)
    pieces.funcs.append(tfunc)

    # Glue together all the IET pieces, including the activation bits:
    # spin until this thread's flag equals 1 before proceeding
    iet = List(body=[
        BlankLine,
        BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                           sdata[threads.index]), 1)),
        List(header=presents),
        iet,
        self.__make_activate_thread(threads, sdata, sync_ops)
    ])

    # Fire up the threads
    pieces.init.append(self.__make_init_threads(threads, sdata, tfunc, pieces))
    pieces.threads.append(threads)

    # Final wait before jumping back to Python land
    pieces.finalize.append(self.__make_finalize_threads(threads, sdata))

    return iet
def _build_parameters(self, iet):
    """
    Derive the Operator parameters from the Iteration/Expression tree ``iet``.

    Delegates to ``derive_parameters`` with ``drop_locals=True``.
    """
    parameters = derive_parameters(iet, drop_locals=True)
    return parameters
def _specialize_iet(cls, iet, **kwargs):
    """
    Transform the Iteration/Expression tree to offload the computation of one
    or more loop nests onto YASK. This involves calling the YASK compiler to
    generate YASK code. Such YASK code is then called from within the
    transformed Iteration/Expression tree.

    A candidate section is replaced only after its YASK kernel solution has
    been created and registered in ``kwargs['yk_solns']`` (mutated in place).
    If any step raises ``NotImplementedError``, the section is left untouched
    and the reason is logged. The result is then handed to the parent class'
    ``_specialize_iet`` (``yk_solns`` is popped from ``kwargs`` beforehand).
    """
    mapper = {}
    yk_solns = kwargs.pop('yk_solns')
    for n, (section, trees) in enumerate(find_affine_trees(iet).items()):
        dimensions = tuple(filter_ordered(i.dim.root for i in flatten(trees)))

        # Retrieve the section dtype
        exprs = FindNodes(Expression).visit(section)
        dtypes = {e.dtype for e in exprs}
        if len(dtypes) != 1:
            log("Unable to offload in presence of mixed-precision arithmetic")
            continue
        dtype = dtypes.pop()

        context = contexts.fetch(dimensions, dtype)

        # A unique name for the 'real' compiler and kernel solutions
        name = namespace['jit-soln'](Signer._digest(configuration,
                                                    *[i.root for i in trees]))

        # Create a YASK compiler solution for this Operator
        yc_soln = context.make_yc_solution(name)

        try:
            # Generate YASK vars and populate `yc_soln` with equations
            local_vars = yaskit(trees, yc_soln)

            # Build the new IET nodes
            yk_soln_obj = YaskSolnObject(namespace['code-soln-name'](n))
            funcall = make_sharedptr_funcall(namespace['code-soln-run'], ['time'],
                                             yk_soln_obj)
            funcall = Offloaded(funcall, dtype)

            # JIT-compile the newly-created YASK kernel. This happens *before*
            # touching `mapper`: if it fails, the trees must stay untouched
            # rather than call a solution that was never registered
            yk_soln = context.make_yk_solution(name, yc_soln, local_vars)
            yk_solns[(dimensions, yk_soln_obj)] = yk_soln

            # Replace the first tree with the offloaded call, drop the others
            mapper[trees[0].root] = funcall
            mapper.update({i.root: mapper.get(i.root) for i in trees})  # Drop trees

            # Print some useful information about the newly constructed solution
            log("Solution '%s' contains %d var(s) and %d equation(s)." %
                (yc_soln.get_name(), yc_soln.get_num_vars(),
                 yc_soln.get_num_equations()))
        except NotImplementedError as e:
            log("Unable to offload a candidate tree. Reason: [%s]" % str(e))
    iet = Transformer(mapper).visit(iet)

    if not yk_solns:
        log("No offloadable trees found")

    # Some Iteration/Expression trees are not offloaded to YASK and may
    # require further processing to be executed through YASK, due to the
    # different storage layout
    yk_var_objs = {i.name: YaskVarObject(i.name)
                   for i in FindSymbols().visit(iet) if i.from_YASK}
    yk_var_objs.update({i: YaskVarObject(i) for i in cls._get_local_vars(yk_solns)})
    iet = make_var_accesses(iet, yk_var_objs)

    # The signature needs to be updated
    parameters = derive_parameters(iet, True)
    iet = iet._rebuild(parameters=parameters)

    return super(OperatorYASK, cls)._specialize_iet(iet, **kwargs)
def _create_efuncs(self, nodes, state):
    """
    Extract Iteration sub-trees and turn them into Calls+Callables.

    Currently, only tagged, elementizable Iteration objects are targeted.
    For each one, a static ``f_<tag>`` Callable is built whose loop bounds
    become ``<dim>f_m``/``<dim>f_M`` const scalar parameters. The sub-tree in
    ``nodes`` is replaced by a Call passing the original bound values,
    preceded by a ``noinline`` decoration header.

    Returns
    -------
    tuple
        The transformed ``nodes`` and ``{'efuncs': ...}``, where the value is
        the collection of newly created Callables. ``state`` is not used.
    """
    noinline = self._compiler_decoration('noinline', c.Comment('noinline?'))

    efuncs = OrderedDict()
    mapper = {}
    for tree in retrieve_iteration_tree(nodes, mode='superset'):
        # Search an elementizable sub-tree (if any)
        tagged = filter_iterations(tree, lambda i: i.tag is not None, 'asap')
        if not tagged:
            continue
        root = tagged[0]
        if not root.is_Elementizable:
            continue
        target = tree[tree.index(root):]

        # Build a new Iteration/Expression tree with free bounds
        free = []
        defined_args = {}  # Map of argument values defined by loop bounds
        for i in target:
            name, bounds = i.dim.name, i.symbolic_bounds
            # Iteration bounds
            _min = Scalar(name='%sf_m' % name, dtype=np.int32, is_const=True)
            _max = Scalar(name='%sf_M' % name, dtype=np.int32, is_const=True)
            defined_args[_min.name] = bounds[0]
            defined_args[_max.name] = bounds[1]

            # Iteration unbounded indices
            ufunc = [Scalar(name='%s_ub%d' % (name, j), dtype=np.int32)
                     for j in range(len(i.uindices))]
            defined_args.update({uf.name: j.symbolic_min
                                 for uf, j in zip(ufunc, i.uindices)})
            uindices = [IncrDimension(j.parent, i.dim + as_symbol(k), 1, j.name)
                        for j, k in zip(i.uindices, ufunc)]
            free.append(i._rebuild(limits=(_min, _max, 1), offsets=None,
                                   uindices=uindices))

        # Construct elemental function body
        free = Transformer(dict((zip(target, free))), nested=True).visit(root)
        items = FindSymbols().visit(free)

        # Insert array casts
        casts = [ArrayCast(i) for i in items if i.is_Tensor]
        free = List(body=casts + [free])

        # Insert declarations
        external = [i for i in items if i.is_Array]
        free = iet_insert_C_decls(free, external)

        # Create the Callable (the first one created for a tag wins)
        name = "f_%d" % root.tag
        params = derive_parameters(free)
        efuncs.setdefault(name, Callable(name, free, 'void', params, 'static'))

        # Create the Call: loop-bound parameters receive the original bound
        # values, any other parameter is passed through as is
        args = [defined_args.get(i.name, i) for i in params]
        mapper[root] = List(header=noinline, body=Call(name, args))

    # Transform the main tree
    processed = Transformer(mapper).visit(nodes)

    return processed, {'efuncs': efuncs.values()}