def clusterize(exprs, stencils, atomics=None):
    """
    Derive :class:`Cluster` objects from an iterable of expressions; a stencil for
    each expression must be provided. A list of atomic dimensions (see description
    in Cluster.__doc__) may be provided.

    :param exprs: The expressions to be clustered.
    :param stencils: One stencil per entry of ``exprs``.
    :param atomics: (Optional) atomic dimensions, forwarded unchanged to every
                    :class:`Cluster` created here.
    :returns: The result of calling ``merge`` on the list of created clusters.
    :raises AssertionError: If ``exprs`` and ``stencils`` differ in length.
    """
    assert len(exprs) == len(stencils)

    exprs, stencils = aggregate(exprs, stencils)

    Info = namedtuple('Info', 'trace stencil')

    # Build a dependence graph and associate each tensor node with its Stencil
    mapper = OrderedDict()
    g = TemporariesGraph(exprs)
    for (k, v), j in zip(g.items(), stencils):
        if v.is_tensor:
            trace = g.trace(k)
            trace += tuple(i for i in g.trace(k, readby=True) if i not in trace)
            mapper[k] = Info(trace, j)

    # A cluster stencil is determined iteratively, by first calculating the
    # "local" stencil and then by looking at the stencils of all other clusters
    # depending on it. The stencil information is propagated until there are
    # no more updates.
    queue = list(mapper)
    while queue:
        target = queue.pop(0)

        info = mapper[target]
        strict_trace = [i.lhs for i in info.trace if i.lhs != target]

        stencil = Stencil(info.stencil.entries)
        for i in strict_trace:
            if i in mapper:
                stencil = stencil.add(mapper[i].stencil)

        mapper[target] = Info(info.trace, stencil)

        if stencil != info.stencil:
            # Something has changed, need to propagate the update. Only
            # targets tracked in /mapper/ can be revisited: a trace may also
            # contain entries that are not keys of /mapper/ (only tensor nodes
            # are), and re-queuing those would make the lookup at the top of
            # this loop fail with a KeyError.
            queue.extend([i for i in strict_trace
                          if i in mapper and i not in queue])

    clusters = []
    for target, info in mapper.items():
        # Drop all non-output tensors, as computed by other clusters
        retained = [i for i in info.trace if i.lhs.is_Symbol or i.lhs == target]

        # Create and track the cluster
        clusters.append(Cluster(retained, info.stencil.frozen, atomics))

    return merge(clusters)
def clusterize(exprs, stencils):
    """
    Derive :class:`Cluster` objects from an iterable of expressions; a stencil for
    each expression must be provided.

    :param exprs: The expressions to be clustered.
    :param stencils: One stencil per entry of ``exprs``.
    :returns: The result of calling ``groupby`` on the :class:`ClusterGroup`
              of PartialClusters built here.
    :raises AssertionError: If ``exprs`` and ``stencils`` differ in length.
    """
    assert len(exprs) == len(stencils)

    exprs, stencils = aggregate(exprs, stencils)

    # Create a PartialCluster for each sequence of expressions computing a tensor
    mapper = OrderedDict()
    g = TemporariesGraph(exprs)
    for (k, v), j in zip(g.items(), stencils):
        if v.is_tensor:
            # Use a dedicated name rather than rebinding the /exprs/ argument
            trace = g.trace(k)
            trace += tuple(i for i in g.trace(k, readby=True) if i not in trace)
            mapper[k] = PartialCluster(trace, j)

    # Update the PartialClusters' Stencils by looking at the Stencil of the
    # surrounding PartialClusters.
    queue = list(mapper)
    while queue:
        target = queue.pop(0)

        pc = mapper[target]
        strict_trace = [i.lhs for i in pc.exprs if i.lhs != target]

        stencil = pc.stencil.copy()
        for i in strict_trace:
            if i in mapper:
                stencil = stencil.add(mapper[i].stencil)

        if stencil != pc.stencil:
            # Something has changed, need to propagate the update. Only
            # targets tracked in /mapper/ can be revisited: a trace may also
            # contain entries that are not keys of /mapper/ (only tensor nodes
            # are), and re-queuing those would make the lookup at the top of
            # this loop fail with a KeyError.
            pc.stencil = stencil
            queue.extend([i for i in strict_trace
                          if i in mapper and i not in queue])

    # Drop all non-output tensors, as computed by other clusters
    clusters = ClusterGroup()
    for target, pc in mapper.items():
        retained = [i for i in pc.exprs if i.lhs.is_Symbol or i.lhs == target]
        clusters.append(PartialCluster(retained, pc.stencil))

    # Attempt grouping as many PartialClusters as possible together
    return groupby(clusters)
class Cluster(object):

    """
    An ordered sequence of expressions required to compute a tensor, followed
    by the tensor expression itself.

    Each Cluster carries a stencil, which records the neighboring points
    needed, along each dimension, to compute one entry of the tensor.

    ``atomics`` lists dimensions (a subset of those appearing in ``stencil``)
    along which the Cluster cannot be fused with other clusters. This is useful,
    for example, when the Cluster evaluates a tensor temporary whose values
    must all be updated before subsequent clusters access them.
    """

    def __init__(self, exprs, stencil, atomics):
        # The expressions are stored as a graph, not as the raw sequence
        graph = TemporariesGraph(exprs)
        self.trace = graph
        self.stencil = stencil
        self.atomics = as_tuple(atomics)

    @property
    def exprs(self):
        """The values held by the underlying trace."""
        graph = self.trace
        return graph.values()

    @property
    def unknown(self):
        """The ``unknown`` attribute of the underlying trace."""
        graph = self.trace
        return graph.unknown

    @property
    def tensors(self):
        """The ``tensors`` attribute of the underlying trace."""
        graph = self.trace
        return graph.tensors

    @property
    def is_dense(self):
        """
        Truthy if the trace has space indices and is not time invariant. When
        there are no space indices, the (falsy) ``space_indices`` value itself
        is returned.
        """
        indices = self.trace.space_indices
        if not indices:
            return indices
        return not self.trace.time_invariant()

    @property
    def is_sparse(self):
        """The logical negation of :attr:`is_dense`."""
        if self.is_dense:
            return False
        return True

    def rebuild(self, exprs):
        """
        Return a new Cluster made of the expressions ``exprs``, with the same
        stencil and atomics as ``self``.
        """
        return Cluster(exprs, self.stencil, self.atomics)
def compact_temporaries(temporaries, leaves):
    """
    Drop temporaries consisting of single symbols.

    A scalar temporary is dropped when its right-hand side satisfies ``q_leaf``
    or is a Function, and its readers are not all among the left-hand sides
    of ``leaves``. Every retained expression has the dropped temporaries
    substituted by their right-hand sides.

    :param temporaries: The temporaries; concatenated with ``leaves`` to
                        build the graph.
    :param leaves: The expressions whose left-hand sides are the targets.
    :returns: The list of retained expressions, after substitution.
    """
    graph = TemporariesGraph(temporaries + leaves)
    targets = {leaf.lhs for leaf in leaves}

    # Collect the temporaries to be dropped, mapped to their replacement
    subs = {}
    for lhs, node in graph.items():
        if not node.is_scalar:
            continue
        if not (q_leaf(node.rhs) or node.rhs.is_Function):
            continue
        if node.readby.issubset(targets):
            continue
        subs[lhs] = node.rhs

    retained = []
    for lhs, node in graph.items():
        if lhs in subs:
            continue
        # The temporary /node/ is retained, and substitutions may be applied
        handle, _ = xreplace_constrained(node, subs, repeat=True)
        assert len(handle) == 1
        retained += handle

    return retained
def promote_scalar_expressions(exprs, shape, indices, onstack):
    """
    Transform a collection of scalar expressions into tensor expressions.

    :param exprs: The expressions to be processed.
    :param shape: Passed as ``shape`` to each newly created Array.
    :param indices: Passed as ``dimensions`` to each newly created Array, and
                    used to index it.
    :param onstack: Passed as ``onstack`` to each newly created Array.
    :returns: A list of equations in which every scalar left-hand side has
              been replaced, on both sides, by an indexed Array access.
    """
    promoted = []
    subs = {}

    # First promote the LHS
    for symbol, node in TemporariesGraph(exprs).items():
        if node.is_scalar:
            # Create a new function symbol, named after the scalar it replaces
            array = Array(name=symbol.name, shape=shape, dimensions=indices,
                          onstack=onstack)
            lhs = Indexed(array.indexed, *indices)
            subs[symbol] = lhs
        else:
            lhs = symbol
        promoted.append(Eq(lhs, node.rhs))

    # Propagate the transformed LHS through the expressions
    return [Eq(eq.lhs, eq.rhs.xreplace(subs)) for eq in promoted]
def __init__(self, exprs, stencil, atomics):
    """
    Store ``exprs`` as a TemporariesGraph under ``self.trace``, keep
    ``stencil`` as given, and normalize ``atomics`` through ``as_tuple``.
    """
    graph = TemporariesGraph(exprs)
    self.trace = graph
    self.stencil = stencil
    self.atomics = as_tuple(atomics)
def trace(self):
    """Return a new TemporariesGraph built from ``self.exprs``."""
    exprs = self.exprs
    return TemporariesGraph(exprs)