def _extract_increments(self, cluster, template, **kwargs):
    """
    Extract the RHS of non-local tensor expressions performing an associative
    and commutative increment, and assign them to temporaries.

    An expression is rewritten when ``e.is_Increment`` holds and its LHS
    function has ``is_Input`` set. Such an expression is replaced by two:

    * ``Eq(tmp, rhs')`` -- a fresh scalar temporary ``tmp`` (named via
      ``template()``) receives ``rhs'``;
    * ``e.func(e.lhs, tmp)`` -- the same kind of expression as ``e``, now
      with ``tmp`` as its right-hand side.

    ``rhs'`` is ``e.rhs`` itself when ``q_scalar(e.rhs)`` holds; otherwise it
    is ``e.rhs`` rebuilt from its arguments with any argument equal to
    ``e.lhs`` dropped. All other expressions are kept unchanged and in order.

    Returns the result of ``cluster.rebuild`` on the new expression list.
    """
    def split(expr):
        # Create the temporary first so template() is consumed in the same
        # order as before, relative to the q_scalar check below.
        temp = Scalar(name=template(), dtype=expr.dtype).indexify()
        rhs = expr.rhs
        if not q_scalar(rhs):
            # Rebuild the RHS with the same head, omitting any operand that
            # is the LHS being incremented.
            rhs = rhs.func(*[arg for arg in rhs.args if arg != expr.lhs])
        return [Eq(temp, rhs), expr.func(expr.lhs, temp)]

    processed = []
    for expr in cluster.exprs:
        candidate = expr.is_Increment and expr.lhs.function.is_Input
        processed.extend(split(expr) if candidate else [expr])
    return cluster.rebuild(processed)
def _extract_nonaffine_indices(self, cluster, template, **kwargs):
    """
    Extract non-affine array indices, and assign them to temporaries.

    For every Indexed appearing in ``cluster.exprs``, each index ``i`` paired
    with its function dimension ``d`` is examined. Indices that are affine in
    ``d`` (``q_affine(i, d)``) or scalar (``q_scalar(i)``) are left alone.
    Every other distinct index is bound to a fresh ``np.int32`` scalar
    temporary, named via ``template()``, in order of first appearance.

    Returns ``cluster.rebuild`` applied to one ``Eq(tmp, index)`` per extracted
    index, followed by the original expressions with those indices replaced
    by their temporaries.
    """
    # NOTE(review): another definition named _extract_nonaffine_indices exists
    # in this file; if both live in the same class, the later one shadows
    # this one -- confirm which is intended to remain.
    def make():
        return Scalar(name=template(), dtype=np.int32).indexify()

    mapper = OrderedDict()
    for e in cluster.exprs:
        # Visit every Indexed occurrence (mode='all') in traversal order and
        # de-duplicate via the mapper below. Retrieving only the unique
        # Indexeds yields a set, whose iteration order would make the order
        # (and thus the names) of the generated temporaries non-deterministic.
        for indexed in retrieve_indexed(e, mode='all'):
            for i, d in zip(indexed.indices, indexed.function.indices):
                if q_affine(i, d) or q_scalar(i):
                    continue
                elif i not in mapper:
                    mapper[i] = make()

    processed = [Eq(v, k) for k, v in mapper.items()]
    processed.extend([e.xreplace(mapper) for e in cluster.exprs])
    return cluster.rebuild(processed)
def _extract_nonaffine_indices(self, cluster, template, **kwargs):
    """
    Extract non-affine array indices, and assign them to temporaries.

    Each (index, dimension) pair of every Indexed in ``cluster.exprs`` is
    inspected. Pairs where the index is affine in its dimension
    (``q_affine``) or is scalar (``q_scalar``) are skipped. Every other
    distinct index gets its own ``np.int32`` scalar temporary, named through
    ``template()``, allocated in order of first occurrence.

    Returns ``cluster.rebuild`` on the temporary definitions
    (``Eq(tmp, index)``) followed by the original expressions with each
    extracted index substituted by its temporary.
    """
    temporaries = OrderedDict()
    for expr in cluster.exprs:
        # Note: using mode='all' and then checking for presence in the mapping
        # (rather than retrieving unique indexeds only, which yields a set)
        # is the key to deterministic code generation.
        for access in retrieve_indexed(expr, mode='all'):
            for index, dim in zip(access.indices, access.function.indices):
                trivial = q_affine(index, dim) or q_scalar(index)
                if trivial or index in temporaries:
                    continue
                temporaries[index] = Scalar(name=template(),
                                            dtype=np.int32).indexify()

    definitions = [Eq(temp, index) for index, temp in temporaries.items()]
    rewritten = [expr.xreplace(temporaries) for expr in cluster.exprs]
    return cluster.rebuild(definitions + rewritten)