def test_xreplace_constrained_time_varying(tu, tv, tw, ti0, ti1, t0, t1, exprs, expected):
    """
    Check which time-varying temporaries ``xreplace_constrained`` extracts.

    ``exprs`` is first evaluated against the provided symbols. Then every
    extracted temporary whose cost estimate is positive is compared, as a
    string, with the matching entry of ``expected``.
    """
    exprs = EVAL(exprs, tu, tv, tw, ti0, ti1, t0, t1)

    next_id = generator()

    def make():
        # Each call yields a fresh indexified Scalar named r<N>
        return Scalar(name='r%d' % next_id()).indexify()

    rule = iq_timevarying(FlowGraph(exprs))
    _, found = xreplace_constrained(exprs, make, rule,
                                    lambda e: estimate_cost(e) > 0)

    assert len(found) == len(expected)
    for temporary, reference in zip(found, expected):
        assert str(temporary.rhs) == reference
def _extract_sum_of_products(self, cluster, template, **kwargs):
    """
    Hoist sum-of-products sub-expressions of ``cluster`` into temporaries.

    Candidates are matched by ``q_sum_of_product``. Only candidates that are
    neither leaves nor terminal operations are extracted. Each temporary is
    an indexified ``Scalar`` named by ``template()`` and typed with
    ``cluster.dtype``. Returns ``cluster`` rebuilt with the processed
    expressions.
    """
    def make():
        return Scalar(name=template(), dtype=cluster.dtype).indexify()

    def costmodel(e):
        # Extracting a bare leaf or a terminal operation gains nothing
        return not q_leaf(e) and not q_terminalop(e)

    processed, _ = xreplace_constrained(cluster.exprs, make,
                                        q_sum_of_product, costmodel)
    return cluster.rebuild(processed)
def _extract_time_varying(self, cluster, template, **kwargs):
    """
    Extract time-varying subexpressions, and assign them to temporaries.

    Time varying subexpressions arise for example when approximating
    derivatives through finite differences.

    Candidates are selected by ``iq_timevarying(cluster.trace)``. Only those
    with a positive ``estimate_cost`` are extracted. Returns ``cluster``
    rebuilt with the processed expressions.
    """
    # The temporary factory is nullary and carries the cluster's dtype. This
    # matches every other xreplace_constrained caller in this module, which
    # all pass ``lambda: Scalar(name=template(), dtype=...)``. The previous
    # one-argument factory called ``template(i)`` and left dtype unset.
    # NOTE(review): confirm against xreplace_constrained's calling convention.
    def make():
        return Scalar(name=template(), dtype=cluster.dtype).indexify()

    rule = iq_timevarying(cluster.trace)

    def costmodel(e):
        return estimate_cost(e) > 0

    processed, _ = xreplace_constrained(cluster.exprs, make, rule, costmodel)
    return cluster.rebuild(processed)
def _extract_time_varying(self, cluster, template, **kwargs):
    """
    Assign the time-varying subexpressions of ``cluster`` to fresh temporaries.

    Such subexpressions show up, for instance, when derivatives are
    approximated through finite differences. Candidates come from
    ``iq_timevarying(cluster.trace)``, and only those with a positive
    ``estimate_cost`` are extracted. Returns ``cluster`` rebuilt with the
    processed expressions.
    """
    def new_temporary():
        return Scalar(name=template(), dtype=cluster.dtype).indexify()

    rule = iq_timevarying(cluster.trace)

    def worth_extracting(expr):
        return estimate_cost(expr) > 0

    processed, _ = xreplace_constrained(cluster.exprs, new_temporary,
                                        rule, worth_extracting)
    return cluster.rebuild(processed)
def _extract_time_invariants(self, cluster, template, with_cse=True, **kwargs):
    """
    Assign the time-invariant subexpressions of ``cluster`` to temporaries.

    Candidates come from ``iq_timeinvariant(cluster.trace)``, and only those
    with a positive ``estimate_cost`` are extracted. When ``with_cse`` is
    true, two further steps run:

    * common sub-expression elimination, restricted to the extracted
      temporaries only;
    * ``compact_temporaries``, applied to the result together with the
      remaining (non-extracted) expressions.

    Returns ``cluster`` rebuilt with the final expressions.
    """
    def make():
        return Scalar(name=template(), dtype=cluster.dtype).indexify()

    rule = iq_timeinvariant(cluster.trace)
    processed, found = xreplace_constrained(cluster.exprs, make, rule,
                                            lambda e: estimate_cost(e) > 0)

    if not with_cse:
        return cluster.rebuild(processed)

    # Everything that was not turned into a temporary
    leaves = [expr for expr in processed if expr not in found]
    # CSE is applied amongst the extracted temporaries and nothing else
    found = common_subexprs_elimination(found, make)
    # Some temporaries may now be redundant and can be dropped
    return cluster.rebuild(compact_temporaries(found, leaves))
def compact_temporaries(temporaries, leaves):
    """
    Drop trivial scalar temporaries and inline their values.

    A node of ``FlowGraph(temporaries + leaves)`` is dropped when all of the
    following hold:

    * it is a Scalar;
    * its rhs is either a leaf (``q_leaf``) or a Function;
    * its ``readby`` set is not a subset of the leaves' lhs symbols.

    Every remaining node is rewritten by substituting the dropped keys with
    their rhs (``xreplace_constrained(..., repeat=True)``). The list of
    rewritten expressions is returned in graph order.
    """
    exprs = temporaries + leaves
    leaf_targets = {e.lhs for e in leaves}
    graph = FlowGraph(exprs)

    def droppable(node):
        if not node.is_Scalar:
            return False
        if not (q_leaf(node.rhs) or node.rhs.is_Function):
            return False
        return not node.readby.issubset(leaf_targets)

    mapper = {key: node.rhs for key, node in graph.items() if droppable(node)}

    processed = []
    for key, node in graph.items():
        if key in mapper:
            continue
        # The node is retained, but the dropped temporaries are inlined into it
        rewritten, _ = xreplace_constrained(node, mapper, repeat=True)
        assert len(rewritten) == 1
        processed.extend(rewritten)
    return processed
def compact_temporaries(temporaries, leaves):
    """
    Drop temporaries consisting of single symbols (or Functions).

    Works over ``FlowGraph(temporaries + leaves)``. A Scalar node whose rhs
    is a leaf or a Function is dropped, unless it is read only by the lhs
    symbols of ``leaves``. The rhs of every dropped node is substituted into
    the surviving nodes. The surviving, rewritten expressions are returned in
    graph order.
    """
    exprs = temporaries + leaves
    targets = {leaf.lhs for leaf in leaves}
    flow = FlowGraph(exprs)

    mapper = {}
    for sym, node in flow.items():
        if node.is_Scalar and (q_leaf(node.rhs) or node.rhs.is_Function):
            if not node.readby.issubset(targets):
                mapper[sym] = node.rhs

    processed = []
    for node in (n for s, n in flow.items() if s not in mapper):
        # Retained node: apply the substitutions collected above
        handle, _ = xreplace_constrained(node, mapper, repeat=True)
        assert len(handle) == 1
        processed += handle
    return processed