def extract(cls, n, context, min_cost, max_alias, cluster, sregistry): make = lambda: Scalar(name=sregistry.make_name(), dtype=cluster.dtype ).indexify() # The `depth` determines "how big" the extracted sum-of-products will be. # We observe that in typical FD codes: # add(mul, mul, ...) -> stems from first order derivative # add(mul(add(mul, mul, ...), ...), ...) -> stems from second order derivative # To search the muls in the former case, we need `depth=0`; to search the outer # muls in the latter case, we need `depth=2` depth = n exclude = { i.source.indexed for i in cluster.scope.d_flow.independent() } rule0 = lambda e: not e.free_symbols & exclude rule1 = lambda e: e.is_Mul and q_terminalop(e, depth) rule = lambda e: rule0(e) and rule1(e) extracted = OrderedDict() mapper = {} for e in cluster.exprs: for i in search(e, rule, 'all', 'bfs_first_hit'): if i in mapper: continue key = lambda a: a.is_Add terms, others = split(list(i.args), key) if max_alias: # Treat `e` as an FD expression and pull out the derivative # coefficient from `i` # Note: typically derivative coefficients are numbers, but # sometimes they could be provided in symbolic form through an # arbitrary Function. In the latter case, we rely on the # heuristic that such Function's basically never span the whole # grid, but rather a single Grid dimension (e.g., `c[z, n]` for a # stencil of diameter `n` along `z`) if e.grid is not None and terms: key = partial(maybe_coeff_key, e.grid) others, more_terms = split(others, key) terms.extend(more_terms) if terms: k = i.func(*terms) try: symbol, _ = extracted[k] except KeyError: symbol, _ = extracted.setdefault(k, (make(), e)) mapper[i] = i.func(symbol, *others) if mapper: extracted = [e.func(v, k) for k, (v, e) in extracted.items()] processed = [uxreplace(e, mapper) for e in cluster.exprs] return extracted + processed, extracted else: return cluster.exprs, []
def _extract_sum_of_products(self, cluster, template, **kwargs):
    """
    Pull out sub-expressions in sum-of-product form, bind each one to a
    fresh scalar temporary, and return a rebuilt Cluster.
    """
    def build_temporary():
        # Fresh scalar symbol, named via the caller-supplied template
        return Scalar(name=template(), dtype=cluster.dtype).indexify()

    def worth_extracting(expr):
        # Skip leaves and terminal operations: too cheap to be temporaries
        return not (q_leaf(expr) or q_terminalop(expr))

    rewritten, _ = yreplace(cluster.exprs, build_temporary,
                            q_sum_of_product, worth_extracting)

    return cluster.rebuild(rewritten)
def _extract_sum_of_products(self, cluster, template, **kwargs):
    """
    Pull out sub-expressions in sum-of-product form, assign each to a
    newly created temporary, and return the rebuilt Cluster.
    """
    def new_symbol():
        # One fresh scalar per extracted sub-expression
        return Scalar(name=template(), dtype=cluster.dtype).indexify()

    def is_costly(expr):
        # Only extract when the sub-expression is neither a leaf nor a
        # terminal operation
        return not (q_leaf(expr) or q_terminalop(expr))

    transformed, _ = xreplace_constrained(cluster.exprs, new_symbol,
                                          q_sum_of_product, is_costly)

    return cluster.rebuild(transformed)
def callbacks_sops(context, n):
    """
    Build the four CIRE callbacks for the sums-of-products ('sops') mode:
    extractor, cost model, collected-group filter, and selection rule.
    """
    # The `depth` determines "how big" the extracted sum-of-products will be.
    # We observe that in typical FD codes:
    #   add(mul, mul, ...) -> stems from first order derivative
    #   add(mul(add(mul, mul, ...), ...), ...) -> stems from second order derivative
    # To catch the former, we would need `depth=1`; for the latter, `depth=3`
    depth = 2 * n + 1

    def extractor(e):
        return q_sum_of_product(e, depth)

    def model(e):
        return not (q_leaf(e) or q_terminalop(e, depth - 1))

    def ignore_collected(g):
        # Groups of one carry no cross-iteration redundancy
        return len(g) <= 1

    def selector(c, n):
        # Worth a temporary only if costly enough and actually redundant
        return c >= MIN_COST_ALIAS and n > 1

    return extractor, model, ignore_collected, selector
def extract(cls, n, context, min_cost, cluster, sregistry): make = lambda: Scalar(name=sregistry.make_name(), dtype=cluster.dtype ).indexify() # The `depth` determines "how big" the extracted sum-of-products will be. # We observe that in typical FD codes: # add(mul, mul, ...) -> stems from first order derivative # add(mul(add(mul, mul, ...), ...), ...) -> stems from second order derivative # To search the muls in the former case, we need `depth=0`; to search the outer # muls in the latter case, we need `depth=2` depth = n exclude = { i.source.indexed for i in cluster.scope.d_flow.independent() } rule0 = lambda e: not e.free_symbols & exclude rule1 = lambda e: e.is_Mul and q_terminalop(e, depth) rule = lambda e: rule0(e) and rule1(e) extracted = OrderedDict() mapper = {} for e in cluster.exprs: for i in search(e, rule, 'all', 'bfs_first_hit'): if i in mapper: continue # Separate numbers and Functions, as they could be a derivative coeff terms, others = split(i.args, lambda a: a.is_Add) if terms: k = i.func(*terms) try: symbol, _ = extracted[k] except KeyError: symbol, _ = extracted.setdefault(k, (make(), e)) mapper[i] = i.func(symbol, *others) if mapper: extracted = [e.func(v, k) for k, (v, e) in extracted.items()] processed = [uxreplace(e, mapper) for e in cluster.exprs] return extracted + processed, extracted else: return cluster.exprs, []
def cire(cluster, template, mode, options, platform): """ Cross-iteration redundancies elimination. Parameters ---------- cluster : Cluster Input Cluster, subject of the optimization pass. template : callable To build the symbols (temporaries) storing the redundant expressions. mode : str The transformation mode. Accepted: ['invariants', 'sops']. * 'invariants' is for sub-expressions that are invariant w.r.t. one or more Dimensions. * 'sops' stands for sums-of-products, that is redundancies are searched across all expressions in sum-of-product form. options : dict The optimization mode. Accepted: ['min-storage']. * 'min-storage': if True, the pass will try to minimize the amount of storage introduced for the tensor temporaries. This might also reduce the operation count. On the other hand, this might affect fusion and therefore data locality. Defaults to False (legacy). platform : Platform The underlying platform. Used to optimize the shape of the introduced tensor symbols. Examples -------- 1) 'invariants'. Below is an expensive sub-expression invariant w.r.t. `t` t0 = (cos(a[x,y,z])*sin(b[x,y,z]))*c[t,x,y,z] becomes t1[x,y,z] = cos(a[x,y,z])*sin(b[x,y,z]) t0 = t1[x,y,z]*c[t,x,y,z] 2) 'sops'. Below are redundant sub-expressions in sum-of-product form (in this case, the sum degenerates to a single product). 
t0 = 2.0*a[x,y,z]*b[x,y,z] t1 = 3.0*a[x,y,z+1]*b[x,y,z+1] becomes t2[x,y,z] = a[x,y,z]*b[x,y,z] t0 = 2.0*t2[x,y,z] t1 = 3.0*t2[x,y,z+1] """ # Sanity checks assert mode in ['invariants', 'sops'] assert all(i > 0 for i in options['cire-repeats'].values()) # Relevant options min_storage = options['min-storage'] # Setup callbacks if mode == 'invariants': # Extraction rule def extractor(context): is_time_invariant = make_is_time_invariant(context) return lambda e: is_time_invariant(e) # Extraction model model = lambda e: estimate_cost(e, True) >= MIN_COST_ALIAS_INV # Selection rule selector = lambda c, n: c >= MIN_COST_ALIAS_INV and n >= 1 elif mode == 'sops': # Extraction rule def extractor(context): return lambda e: q_sum_of_product(e) # Extraction model model = lambda e: not (q_leaf(e) or q_terminalop(e)) # Selection rule selector = lambda c, n: c >= MIN_COST_ALIAS and n > 1 # Actual CIRE processed = [] context = cluster.exprs for _ in range(options['cire-repeats'][mode]): # Extract potentially aliasing expressions exprs, extracted = extract(cluster, extractor(context), model, template) if not extracted: # Do not waste time break # Search aliasing expressions aliases = collect(extracted, min_storage) # Rule out aliasing expressions with a bad flops/memory trade-off chosen, others = choose(exprs, aliases, selector) if not chosen: # Do not waste time break # Create Aliases and assign them to Clusters clusters, subs = process(cluster, chosen, aliases, template, platform) # Rebuild `cluster` so as to use the newly created Aliases rebuilt = rebuild(cluster, others, aliases, subs) # Prepare for the next round processed.extend(clusters) cluster = rebuilt context = flatten(c.exprs for c in processed) + list(cluster.exprs) processed.append(cluster) return processed