def _select_expression(expressions, index):
    """Helper function to select an expression from a list of
    expressions with an index. This function expects sanitised input,
    one should normally call :py:func:`select_expression` instead.

    The selection is pushed as deep into the expressions as possible by
    recursively factoring out their common structure.

    :arg expressions: a list of expressions
    :arg index: an index (free, fixed or variable)
    :returns: an expression
    :raises NotImplementedError: if no factorisation rule applies
    """
    # Trivial case: every candidate is the same expression.
    expr = expressions[0]
    if all(e == expr for e in expressions):
        return expr

    types = set(map(type, expressions))
    if types <= {Indexed, Zero}:
        # The single-element unpacking enforces that all Indexed
        # candidates share one multiindex.
        multiindex, = set(e.multiindex for e in expressions if isinstance(e, Indexed))
        # Shape only determined by free indices
        shape = tuple(i.extent for i in multiindex if isinstance(i, Index))

        def child(expression):
            # Strip the Indexed wrapper; a Zero stands in for a
            # tensor-valued Zero of the matching shape.
            if isinstance(expression, Indexed):
                return expression.children[0]
            elif isinstance(expression, Zero):
                return Zero(shape)
        # Select among the un-indexed children, then re-apply the
        # shared multiindex.
        return Indexed(
            _select_expression(list(map(child, expressions)), index),
            multiindex
        )

    if types <= {Literal, Zero, Failure}:
        # Leaf values: fall back to an explicit ListTensor lookup.
        return partial_indexed(ListTensor(expressions), (index, ))

    if types <= {ComponentTensor, Zero}:
        # Push the selection inside the component tensors using fresh
        # indices, then rebuild a single ComponentTensor around it.
        shape, = set(e.shape for e in expressions)
        multiindex = tuple(Index(extent=d) for d in shape)
        children = remove_componenttensors(
            [Indexed(e, multiindex) for e in expressions]
        )
        return ComponentTensor(_select_expression(children, index), multiindex)

    if len(types) == 1:
        # All candidates are the same node class: select child-wise
        # and reconstruct one node of that class.
        cls, = types
        if cls.__front__ or cls.__back__:
            # Node classes declaring __front__/__back__ are not
            # factorised here (presumably because they carry data
            # besides their children).
            raise NotImplementedError(
                "How to factorise {} expressions?".format(cls.__name__)
            )
        assert all(len(e.children) == len(expr.children) for e in expressions)
        assert len(expr.children) > 0

        return expr.reconstruct(*[
            _select_expression(nth_children, index)
            for nth_children in zip(*[e.children for e in expressions])
        ])

    raise NotImplementedError(
        "No rule for factorising expressions of this kind."
    )
def test_delta_elimination():
    """delta_elimination drops both summed indices and turns the deltas into one."""
    a, b, c = Index(), Index(), Index()
    identity = Identity(3)

    remaining_indices, remaining_factors = delta_elimination(
        (a, b),
        [Delta(a, b), Delta(a, c), Indexed(identity, (b, c))],
    )
    remaining_factors = remove_componenttensors(remaining_factors)

    assert remaining_indices == []
    assert remaining_factors == [one, one, Indexed(identity, (c, c))]
def contraction(expression): """Optimise the contractions of the tensor product at the root of the expression, including: - IndexSum-Delta cancellation - Sum factorisation This routine was designed with finite element coefficient evaluation in mind. """ # Eliminate annoying ComponentTensors expression, = remove_componenttensors([expression]) # Flatten product tree, eliminate deltas, sum factorise def rebuild(expression): sum_indices, factors = delta_elimination(*traverse_product(expression)) factors = remove_componenttensors(factors) return sum_factorise(sum_indices, factors) # Sometimes the value shape is composed as a ListTensor, which # could get in the way of decomposing factors. In particular, # this is the case for H(div) and H(curl) conforming tensor # product elements. So if ListTensors are used, they are pulled # out to be outermost, so we can straightforwardly factorise each # of its entries. lt_fis = OrderedDict() # ListTensor free indices for node in traversal((expression, )): if isinstance(node, Indexed): child, = node.children if isinstance(child, ListTensor): lt_fis.update(zip_longest(node.multiindex, ())) lt_fis = tuple(index for index in lt_fis if index in expression.free_indices) if lt_fis: # Rebuild each split component tensor = ComponentTensor(expression, lt_fis) entries = [ Indexed(tensor, zeta) for zeta in numpy.ndindex(tensor.shape) ] entries = remove_componenttensors(entries) return Indexed( ListTensor( numpy.array(list(map(rebuild, entries))).reshape(tensor.shape)), lt_fis) else: # Rebuild whole expression at once return rebuild(expression)
def _replace_delta_delta(node, self):
    """Rewrite a Delta node without using Delta.

    If at least one of the two indices is an :py:class:`Index`, the delta
    becomes an entry of an identity matrix indexed by both. Otherwise both
    indices are converted to expressions (an ``int`` becomes a
    :py:class:`Literal`, a :py:class:`VariableIndex` contributes its
    expression), and the delta becomes a Conditional yielding one or Zero.

    :raises ValueError: if a non-Index index cannot be converted
    """
    first, second = node.i, node.j
    first_free = isinstance(first, Index)
    second_free = isinstance(second, Index)

    if not (first_free or second_free):
        def to_expression(idx):
            if isinstance(idx, int):
                return Literal(idx)
            if isinstance(idx, VariableIndex):
                return idx.expression
            raise ValueError("Cannot convert running index to expression.")

        lhs = to_expression(first)
        rhs = to_expression(second)
        return Conditional(Comparison("==", lhs, rhs), one, Zero())

    if first_free and second_free:
        assert first.extent == second.extent
    if first_free:
        assert first.extent is not None
        size = first.extent
    if second_free:
        assert second.extent is not None
        size = second.extent
    return Indexed(Identity(size), (first, second))
def applier(expr):
    """Rename the free indices of ``expr`` using the ``renames`` pairs.

    ``renames`` is a free variable (not defined in this function) holding
    ``(old, new)`` index pairs. Only pairs whose ``old`` index is free in
    ``expr`` are applied; if none apply, ``expr`` is returned unchanged.
    """
    applicable = [(old, new) for old, new in renames if old in expr.free_indices]
    if not applicable:
        return expr
    old_indices = tuple(old for old, _ in applicable)
    new_indices = tuple(new for _, new in applicable)
    return Indexed(ComponentTensor(expr, old_indices), new_indices)
def split_variable(variable_ref, index, multiindices):
    """Split a flexibly indexed variable along a concatenation index.

    Consecutive slices of the variable along ``index`` are carved out,
    each reshaped to the extents of the corresponding multiindex.

    :param variable_ref: the :py:class:`FlexiblyIndexed` variable to split
    :param index: :py:class:`Concatenate` index to split along
    :param multiindices: one multiindex for each piece to produce
    :returns: generator yielding one indexed sub-variable per multiindex
    """
    assert isinstance(variable_ref, FlexiblyIndexed)

    ordering = list(variable_ref.index_ordering())
    ordering.remove(index)
    remaining = tuple(ordering)

    # Put the split index first so a plain slice selects each piece.
    data = ComponentTensor(variable_ref, (index, ) + remaining)
    full_slices = [slice(None)] * len(remaining)
    trailing_shapes = [(idx.extent, ) for idx in remaining]

    start = 0
    for piece_indices in multiindices:
        piece_shape = tuple(idx.extent for idx in piece_indices)
        stop = start + numpy.prod(piece_shape, dtype=int)
        piece = view(data, slice(start, stop), *full_slices)
        start = stop
        ref = Indexed(reshape(piece, piece_shape, *trailing_shapes),
                      piece_indices + remaining)
        ref, = remove_componenttensors((ref, ))
        yield ref
def test_loop_optimise():
    """Check how optimise_expressions factorises sums of products
    with respect to the indices (j, k)."""
    I = 20
    J = K = 10
    i = Index('i', I)
    j = Index('j', J)
    k = Index('k', K)

    A1 = Variable('a1', (I, ))
    A2 = Variable('a2', (I, ))
    A3 = Variable('a3', (I, ))
    A1i = Indexed(A1, (i, ))
    A2i = Indexed(A2, (i, ))
    A3i = Indexed(A3, (i, ))

    B = Variable('b', (J, ))
    C = Variable('c', (J, ))
    Bj = Indexed(B, (j, ))
    Cj = Indexed(C, (j, ))

    E = Variable('e', (K, ))
    F = Variable('f', (K, ))
    G = Variable('g', (K, ))
    Ek = Indexed(E, (k, ))
    Fk = Indexed(F, (k, ))
    Gk = Indexed(G, (k, ))

    Z = Variable('z', ())

    # Bj*Ek + Bj*Fk => (Ek + Fk)*Bj
    expr = Sum(Product(Bj, Ek), Product(Bj, Fk))
    result, = optimise_expressions([expr], (j, k))
    expected = Product(Sum(Ek, Fk), Bj)
    assert result == expected

    # Bj*Ek + Bj*Fk + Bj*Gk + Cj*Ek + Cj*Fk =>
    # (Ek + Fk + Gk)*Bj + (Ek+Fk)*Cj
    expr = Sum(
        Sum(Sum(Sum(Product(Bj, Ek), Product(Bj, Fk)), Product(Bj, Gk)),
            Product(Cj, Ek)),
        Product(Cj, Fk)
    )
    result, = optimise_expressions([expr], (j, k))
    expected = Sum(Product(Sum(Sum(Ek, Fk), Gk), Bj), Product(Sum(Ek, Fk), Cj))
    assert result == expected

    # Z*A1i*Bj*Ek + Z*A2i*Bj*Ek + A3i*Bj*Ek + Z*A1i*Bj*Fk =>
    # (Ek*(Z*A1i + Z*A2i + A3i) + Fk*(Z*A1i))*Bj
    expr = Sum(
        Sum(
            Sum(Product(Z, Product(A1i, Product(Bj, Ek))),
                Product(Z, Product(A2i, Product(Bj, Ek)))),
            Product(A3i, Product(Bj, Ek))),
        Product(Z, Product(A1i, Product(Bj, Fk)))
    )
    result, = optimise_expressions([expr], (j, k))
    expected = Product(
        Sum(Product(Ek, Sum(Sum(Product(Z, A1i), Product(Z, A2i)), A3i)),
            Product(Fk, Product(Z, A1i))),
        Bj
    )
    assert result == expected
def substitute(expression, from_, to_):
    """Replace the free index ``from_`` of ``expression`` by ``to_``.

    If ``from_`` is not free in ``expression`` it is returned unchanged.
    A :py:class:`Delta` is rewritten directly with a memoized
    ``filtered_replace_indices``; any other expression is wrapped as
    ``Indexed(ComponentTensor(expression, (from_,)), (to_,))``.
    """
    if from_ in expression.free_indices:
        if isinstance(expression, Delta):
            replacer = MemoizerArg(filtered_replace_indices)
            return replacer(expression, ((from_, to_), ))
        return Indexed(ComponentTensor(expression, (from_, )), (to_, ))
    return expression
def test_conditional_zero_folding():
    """A Conditional whose branches are a product with Zero and Zero equals Zero."""
    flag = Variable("B", ())
    vec = Variable("A", (3, ))
    idx = Index()
    condition = LogicalAnd(flag, flag)
    then_branch = Product(Indexed(vec, (idx, )), Zero())
    folded = Conditional(condition, then_branch, Zero())
    assert folded == Zero()
def test_replace_div():
    """replace_division rewrites ``B[i] / A`` as ``B[i] * (1.0 / A)``."""
    idx = Index()
    denominator = Variable('A', ())
    numerator = Indexed(Variable('B', (6,)), (idx,))
    rewritten, = replace_division([Division(numerator, denominator)])
    assert rewritten == Product(numerator, Division(Literal(1.0), denominator))
def aggressive_unroll(expression):
    """Unroll both the shape and the index sums of an expression.

    A non-scalar expression is first expanded into a
    :py:class:`ListTensor` of its individually indexed entries. Then
    ``unroll_indexsum`` is applied with a predicate accepting every index.

    :arg expression: a GEM expression
    :returns: the unrolled expression
    """
    shape = expression.shape
    if shape:
        entries = numpy.empty(shape, dtype=object)
        for multi in numpy.ndindex(shape):
            entries[multi] = Indexed(expression, multi)
        expression, = remove_componenttensors((ListTensor(entries), ))

    def unroll_everything(index):
        return True

    expression, = unroll_indexsum((expression, ), predicate=unroll_everything)
    expression, = remove_componenttensors((expression, ))
    return expression
def _(node, self):
    """Unroll the indices of an IndexSum that satisfy ``self.predicate``.

    Selected indices are expanded into an explicit chain of Sums over
    every value of those indices. Unselected indices stay in an outer
    :py:class:`IndexSum`. If nothing is selected, the node is handed to
    ``reuse_if_untouched``.
    """
    selected = tuple(idx for idx in node.multiindex if self.predicate(idx))
    if not selected:
        return reuse_if_untouched(node, self)

    # Transform the summand first so nested sums are handled as well.
    body = self(node.children[0])
    extents = tuple(idx.extent for idx in selected)
    terms = (Indexed(ComponentTensor(body, selected), alpha)
             for alpha in numpy.ndindex(extents))
    total = reduce(Sum, terms, Zero())
    kept = tuple(idx for idx in node.multiindex if idx not in selected)
    return IndexSum(total, kept)
def replace_indices_indexed(node, self, subst):
    """Apply an index substitution to an :py:class:`Indexed` node.

    :arg node: the Indexed node
    :arg self: the recursive mapper, called as ``self(expr, subst)``
    :arg subst: substitution rules as a tuple of ``(old, new)`` pairs
    :returns: the rewritten expression (``node`` itself if unchanged)
    """
    child, = node.children
    rules = dict(subst)
    new_multiindex = tuple(rules.get(idx, idx) for idx in node.multiindex)

    if not isinstance(child, ComponentTensor):
        rewritten = self(child, subst)
        unchanged = rewritten == child and new_multiindex == node.multiindex
        return node if unchanged else Indexed(rewritten, new_multiindex)

    # Indexing a ComponentTensor: inline it by mapping the tensor's own
    # indices onto the (substituted) indices used to index it.
    rules.update(zip(child.multiindex, new_multiindex))
    return self(child.children[0], tuple(sorted(rules.items())))
def _unconcatenate(cache, pairs):
    """Split assignment pairs so no expression indexes a Concatenate.

    :arg cache: dict mapping a concatenation index to the tuple of fresh
        multiindices used to split along it (filled in on first use, so
        repeated splits along one index reuse the same indices)
    :arg pairs: list of ``(variable, expression)`` assignment pairs
    :returns: the split list of ``(variable, expression)`` pairs
    """
    # Tail-call recursive core of unconcatenate.
    # Assumes that input has already been sanitised.
    # NOTE: Python does not eliminate tail calls, so each Concatenate
    # group handled adds one stack frame.
    concat_group = find_group([e for v, e in pairs])
    if concat_group is None:
        return pairs

    # Get the index split
    concat_ref = next(iter(concat_group))
    assert isinstance(concat_ref, Indexed)
    concat_expr, = concat_ref.children
    index, = concat_ref.multiindex
    assert isinstance(concat_expr, Concatenate)
    try:
        multiindices = cache[index]
    except KeyError:
        # One fresh multiindex per concatenated child, matching its shape.
        multiindices = tuple(tuple(Index(extent=d) for d in child.shape)
                             for child in concat_expr.children)
        cache[index] = multiindices

    def cut(node):
        """Skip rebuilding any subexpression that is independent of the
        relevant concatenation index."""
        return index not in node.free_indices

    # Build Concatenate node replacement mappings
    # mappings[i] maps each indexed Concatenate in the group to its
    # i-th child, indexed by multiindices[i].
    mappings = [{} for i in range(len(multiindices))]
    for concat_ref in concat_group:
        concat_expr, = concat_ref.children
        for i in range(len(multiindices)):
            sub_ref = Indexed(concat_expr.children[i], multiindices[i])
            sub_ref, = remove_componenttensors((sub_ref, ))
            mappings[i][concat_ref] = sub_ref

    # Finally, split assignment pairs
    split_pairs = []
    for var, expr in pairs:
        if index not in var.free_indices:
            # Variable does not depend on the split index: keep as is.
            split_pairs.append((var, expr))
        else:
            # One new pair per piece of the split variable.
            for v, m in zip(split_variable(var, index, multiindices), mappings):
                split_pairs.append((v, replace_node(expr, m, cut)))

    # Run again, there may be other Concatenate groups
    return _unconcatenate(cache, split_pairs)
def test_loop_fusion():
    """Two expressions differing only in the summation index compile alike.

    The first sums over a separate index ``j``; the second reuses ``i``.
    The compiled trees must have the same number of children.
    """
    i = Index()
    j = Index()
    Ri = Indexed(Variable('R', (6, )), (i, ))

    def build(outer, inner):
        A = Variable('A', (6, ))
        total = IndexSum(Indexed(A, (inner, )), (inner, ))
        return Product(Indexed(A, (outer, )), total)

    def tree_of(expr):
        return impero_utils.compile_gem([(Ri, expr)], (i, j)).tree

    separate_index = build(i, j)
    reused_index = build(i, i)
    assert len(tree_of(separate_index).children) == len(tree_of(reused_index).children)
def select_expression(expressions, index):
    """Pick one of several same-shaped expressions by an index.

    The result is semantically the same as
    ``partial_indexed(ListTensor(expressions), (index,))``. It is built by
    recursively factoring out the common structure of the expressions
    (see :py:func:`_select_expression`).

    :arg expressions: a list of expressions, all of the same shape
    :arg index: an index (free, fixed or variable)
    :returns: an expression with the shape of the given expressions
    """
    shape = expressions[0].shape
    assert all(expr.shape == shape for expr in expressions)

    # Index every expression with one shared set of fresh indices (one
    # per dimension), so the helper only deals with indexed expressions.
    free = tuple(Index() for _ in shape)
    scalars = remove_componenttensors([Indexed(expr, free) for expr in expressions])
    return ComponentTensor(_select_expression(scalars, index), free)
def make_expression(i, j):
    """Return ``A[i] * sum_j A[j]`` for a length-6 variable ``A``."""
    vector = Variable('A', (6, ))
    summed = IndexSum(Indexed(vector, (j, )), (j, ))
    return Product(Indexed(vector, (i, )), summed)