def test_xreplace_constrained_time_varying(tu, tv, tw, ti0, ti1, t0, t1, exprs, expected):
    """
    Check the subexpressions captured by ``xreplace_constrained`` when the
    candidates are selected by the ``iq_timevarying`` query.

    Only candidates with a positive ``estimate_cost`` are captured. Each
    captured temporary is built by ``make`` from the ``'r%d'`` name template.
    ``expected`` lists, in order, the string form of each captured RHS.
    """
    exprs = EVAL(exprs, tu, tv, tw, ti0, ti1, t0, t1)

    def make(i):
        # Fresh scalar temporary; the index comes from xreplace_constrained
        return Scalar(name='r%d' % i).indexify()

    query = iq_timevarying(FlowGraph(exprs))
    _, found = xreplace_constrained(exprs, make, query,
                                    lambda e: estimate_cost(e) > 0)

    assert len(found) == len(expected)
    assert all(str(got.rhs) == want for got, want in zip(found, expected))
def test_xreplace_constrained_time_invariants(tu, tv, tw, ti0, ti1, t0, t1, exprs, expected):
    """
    Check the subexpressions captured by ``xreplace_constrained`` when the
    candidates are selected by the ``iq_timeinvariant`` query.

    Only candidates with a positive ``estimate_cost`` are captured. Temporaries
    are named ``r<n>``, where ``n`` is drawn from a ``generator()`` counter.
    ``expected`` lists, in order, the string form of each captured RHS.
    """
    exprs = EVAL(exprs, tu, tv, tw, ti0, ti1, t0, t1)
    counter = generator()

    def make():
        # Zero-argument factory: the name suffix comes from the local counter
        return Scalar(name='r%d' % counter()).indexify()

    query = iq_timeinvariant(FlowGraph(exprs))
    _, found = xreplace_constrained(exprs, make, query,
                                    lambda e: estimate_cost(e) > 0)

    assert len(found) == len(expected)
    assert all(str(got.rhs) == want for got, want in zip(found, expected))
def compact_temporaries(temporaries, leaves):
    """
    Drop temporaries consisting of single symbols.

    A FlowGraph node is dropped when all of the following hold:
      * it is scalar (``is_Scalar``);
      * its RHS satisfies ``q_leaf`` or is a function (``is_Function``);
      * its ``readby`` set is not a subset of the LHSs of ``leaves``.
    Every dropped node's RHS is then substituted, via ``xreplace_constrained``
    with ``repeat=True``, into each node that is kept.

    Returns the kept expressions, in FlowGraph iteration order.
    """
    leaf_targets = {leaf.lhs for leaf in leaves}
    graph = FlowGraph(temporaries + leaves)

    def is_droppable(node):
        # NOTE(review): presumably ``readby`` holds the LHSs of the nodes that
        # read this one, making it comparable with ``leaf_targets``; confirm.
        return (node.is_Scalar
                and (q_leaf(node.rhs) or node.rhs.is_Function)
                and not node.readby.issubset(leaf_targets))

    substitutions = {}
    for lhs, node in graph.items():
        if is_droppable(node):
            substitutions[lhs] = node.rhs

    retained = []
    for lhs, node in graph.items():
        if lhs in substitutions:
            continue
        # Keep this node, inlining the dropped temporaries into it
        rewritten, _ = xreplace_constrained(node, substitutions, repeat=True)
        assert len(rewritten) == 1
        retained.extend(rewritten)
    return retained
def promote_scalar_expressions(exprs, shape, indices, onstack):
    """
    Transform a collection of scalar expressions into tensor expressions.

    For each FlowGraph node whose ``is_scalar`` is truthy, a new ``Array`` is
    created. It reuses the symbol's name and takes ``shape``, ``indices`` as
    its dimensions, and ``onstack``. The LHS is replaced with
    ``Indexed(array.indexed, *indices)``, and every other node is rebuilt
    unchanged as ``Eq(lhs, rhs)``. Finally, the promoted LHSs are substituted
    (``xreplace``) into all RHSs.

    Returns a new list of ``Eq`` objects.
    """
    promoted = {}
    rebuilt = []

    # Step 1: promote the LHS of every scalar expression to an array access.
    # NOTE(review): other code in this module spells the attribute
    # ``is_Scalar``; confirm ``is_scalar`` exists on the node type in use.
    for lhs, node in FlowGraph(exprs).items():
        if not node.is_scalar:
            rebuilt.append(Eq(lhs, node.rhs))
            continue
        array = Array(name=lhs.name, shape=shape,
                      dimensions=indices, onstack=onstack)
        access = Indexed(array.indexed, *indices)
        promoted[lhs] = access
        rebuilt.append(Eq(access, node.rhs))

    # Step 2: make every RHS refer to the promoted LHSs
    return [Eq(eq.lhs, eq.rhs.xreplace(promoted)) for eq in rebuilt]
def test_graph_isindex(fa, fb, fc, t0, t1, t2, exprs, expected):
    """
    Check ``FlowGraph.is_index`` against an expected ``{symbol: bool}`` map.

    ``expected`` is a source string that is ``eval``-ed in this function's
    frame, so it can reference the fixture parameters (``fa``, ``t0``, ...).
    """
    graph = FlowGraph(EVAL(exprs, fa, fb, fc, t0, t1, t2))
    # eval must run here (not inside a comprehension) to see the parameters
    for symbol, is_index in eval(expected).items():
        assert graph.is_index(symbol) == is_index
def test_graph_trace(tu, tv, tw, ti0, ti1, t0, t1, exprs, expected):
    """
    Check ``FlowGraph.trace`` for each fixture symbol.

    ``expected`` is a source string ``eval``-ed in this function's frame. It
    yields a mapping from each symbol to the set of LHSs its trace should
    contain.
    """
    graph = FlowGraph(EVAL(exprs, tu, tv, tw, ti0, ti1, t0, t1))
    # eval must run in this frame so the expected string can see the parameters
    traces = eval(expected)
    for symbol in (tu, tv, tw, ti0, ti1, t0, t1):
        assert {node.lhs for node in graph.trace(symbol)} == traces[symbol]