def test_loop_optimise():
    """Check the factorisations produced by ``optimise_expressions``.

    Three sums of products over the indices ``j`` and ``k`` are
    optimised and compared against hand-written expected trees.
    """
    extent_i = 20
    extent_j = extent_k = 10
    i = Index('i', extent_i)
    j = Index('j', extent_j)
    k = Index('k', extent_k)

    # Operands indexed by i
    A1i = Indexed(Variable('a1', (extent_i, )), (i, ))
    A2i = Indexed(Variable('a2', (extent_i, )), (i, ))
    A3i = Indexed(Variable('a3', (extent_i, )), (i, ))

    # Operands indexed by j
    Bj = Indexed(Variable('b', (extent_j, )), (j, ))
    Cj = Indexed(Variable('c', (extent_j, )), (j, ))

    # Operands indexed by k
    Ek = Indexed(Variable('e', (extent_k, )), (k, ))
    Fk = Indexed(Variable('f', (extent_k, )), (k, ))
    Gk = Indexed(Variable('g', (extent_k, )), (k, ))

    # Scalar operand
    Z = Variable('z', ())

    def optimised(expression):
        # Exactly one expression goes in, so exactly one must come out.
        outcome, = optimise_expressions([expression], (j, k))
        return outcome

    # Bj*Ek + Bj*Fk => (Ek + Fk)*Bj
    two_terms = Sum(Product(Bj, Ek), Product(Bj, Fk))
    assert optimised(two_terms) == Product(Sum(Ek, Fk), Bj)

    # Bj*Ek + Bj*Fk + Bj*Gk + Cj*Ek + Cj*Fk =>
    # (Ek + Fk + Gk)*Bj + (Ek + Fk)*Cj
    five_terms = Sum(Product(Bj, Ek), Product(Bj, Fk))
    five_terms = Sum(five_terms, Product(Bj, Gk))
    five_terms = Sum(five_terms, Product(Cj, Ek))
    five_terms = Sum(five_terms, Product(Cj, Fk))
    b_part = Product(Sum(Sum(Ek, Fk), Gk), Bj)
    c_part = Product(Sum(Ek, Fk), Cj)
    assert optimised(five_terms) == Sum(b_part, c_part)

    # Z*A1i*Bj*Ek + Z*A2i*Bj*Ek + A3i*Bj*Ek + Z*A1i*Bj*Fk =>
    # (Ek*(Z*A1i + Z*A2i + A3i) + Fk*(Z*A1i))*Bj
    # Every term is built from fresh nodes; no sub-expression object is
    # shared between terms.
    term_1 = Product(Z, Product(A1i, Product(Bj, Ek)))
    term_2 = Product(Z, Product(A2i, Product(Bj, Ek)))
    term_3 = Product(A3i, Product(Bj, Ek))
    term_4 = Product(Z, Product(A1i, Product(Bj, Fk)))
    four_terms = Sum(Sum(Sum(term_1, term_2), term_3), term_4)
    e_part = Product(
        Ek, Sum(Sum(Product(Z, A1i), Product(Z, A2i)), A3i))
    f_part = Product(Fk, Product(Z, A1i))
    assert optimised(four_terms) == Product(Sum(e_part, f_part), Bj)
def test_delta_elimination():
    """Check ``delta_elimination`` on two deltas and an indexed identity.

    With summation over ``i`` and ``j``, no summation index may remain
    and both deltas must have been replaced by ``one``.
    """
    i, j, k = Index(), Index(), Index()
    identity = Identity(3)

    remaining_indices, remaining_factors = delta_elimination(
        (i, j),
        [Delta(i, j), Delta(i, k), Indexed(identity, (j, k))])
    remaining_factors = remove_componenttensors(remaining_factors)

    assert remaining_indices == []
    assert remaining_factors == [one, one, Indexed(identity, (k, k))]
def test_conditional_zero_folding():
    """A ``Conditional`` whose two branches are both zero equals ``Zero()``."""
    scalar = Variable("B", ())
    vector = Variable("A", (3, ))
    free_index = Index()

    condition = LogicalAnd(scalar, scalar)
    # Product with a zero factor
    then_branch = Product(Indexed(vector, (free_index, )), Zero())
    else_branch = Zero()

    assert Conditional(condition, then_branch, else_branch) == Zero()
def test_loop_fusion():
    """Compare the generated trees for two closely related expressions.

    The expression ``A[p] * sum_q A[q]`` is built once with distinct
    indices and once with the same index in both roles; the compiled
    trees must have the same number of children.
    """
    i = Index()
    j = Index()
    Ri = Indexed(Variable('R', (6, )), (i, ))

    def build(free, summed):
        # A[free] * (sum over `summed` of A[summed])
        A = Variable('A', (6, ))
        total = IndexSum(Indexed(A, (summed, )), (summed, ))
        return Product(Indexed(A, (free, )), total)

    distinct = build(i, j)
    repeated = build(i, i)

    def compiled_tree(expression):
        # Compile the assignment Ri <- expression and return its tree.
        return impero_utils.compile_gem([(Ri, expression)], (i, j)).tree

    distinct_children = compiled_tree(distinct).children
    repeated_children = compiled_tree(repeated).children
    assert len(distinct_children) == len(repeated_children)
def test_replace_div():
    """``replace_division`` rewrites ``Bi / A`` as ``Bi * (1.0 / A)``."""
    i = Index()
    denominator = Variable('A', ())
    numerator = Indexed(Variable('B', (6,)), (i,))

    rewritten, = replace_division([Division(numerator, denominator)])

    reciprocal = Division(Literal(1.0), denominator)
    assert rewritten == Product(numerator, reciprocal)
def _select_expression(expressions, index):
    """Recursive worker behind :py:func:`select_expression`.

    Selects one expression out of a list of expressions by means of an
    index.  The input is expected to be sanitised already; normal
    callers should go through :py:func:`select_expression`.

    :arg expressions: a list of expressions
    :arg index: an index (free, fixed or variable)
    :returns: an expression
    """
    first = expressions[0]

    # All candidates are equal: nothing to select.
    if all(e == first for e in expressions):
        return first

    kinds = {type(e) for e in expressions}

    if kinds <= {Indexed, Zero}:
        # All Indexed candidates must share a single multiindex.
        multiindex, = {e.multiindex
                       for e in expressions if isinstance(e, Indexed)}
        # Shape only determined by free indices
        shape = tuple(idx.extent
                      for idx in multiindex if isinstance(idx, Index))

        def unindexed(node):
            # Strip the indexing; zeros become zeros of matching shape.
            if isinstance(node, Indexed):
                return node.children[0]
            if isinstance(node, Zero):
                return Zero(shape)
            return None

        operands = [unindexed(e) for e in expressions]
        return Indexed(_select_expression(operands, index), multiindex)

    if kinds <= {Literal, Zero, Failure}:
        # Index into a ListTensor of the candidates.
        return partial_indexed(ListTensor(expressions), (index, ))

    if kinds <= {ComponentTensor, Zero}:
        # Re-index every candidate with one common set of fresh indices,
        # select on the indexed expressions, then wrap the result again.
        shape, = {e.shape for e in expressions}
        multiindex = tuple(Index(extent=extent) for extent in shape)
        indexed = remove_componenttensors(
            [Indexed(e, multiindex) for e in expressions])
        return ComponentTensor(_select_expression(indexed, index), multiindex)

    if len(kinds) == 1:
        cls, = kinds
        if cls.__front__ or cls.__back__:
            raise NotImplementedError(
                "How to factorise {} expressions?".format(cls.__name__))

        arity = len(first.children)
        assert all(len(e.children) == arity for e in expressions)
        assert arity > 0

        # Select child-by-child, then rebuild a node of the common type.
        per_position = zip(*[e.children for e in expressions])
        selected = [_select_expression(group, index)
                    for group in per_position]
        return first.reconstruct(*selected)

    raise NotImplementedError(
        "No rule for factorising expressions of this kind.")
def _unconcatenate(cache, pairs):
    """Tail-call recursive core of unconcatenate.

    Assumes that input has already been sanitised.

    :arg cache: mapping from a concatenation index to the multiindices
                used for splitting it; new entries are added here
    :arg pairs: sequence of (variable, expression) assignment pairs
    :returns: the pairs unchanged when ``find_group`` finds no group,
              otherwise the result of recursing on the split pairs
    """
    group = find_group([expression for _, expression in pairs])
    if group is None:
        return pairs

    # Get the index split from an arbitrary member of the group
    representative = next(iter(group))
    assert isinstance(representative, Indexed)
    concatenation, = representative.children
    index, = representative.multiindex
    assert isinstance(concatenation, Concatenate)

    try:
        multiindices = cache[index]
    except KeyError:
        # First time this index is split: create one multiindex for
        # each child of the Concatenate node and remember them.
        multiindices = tuple(
            tuple(Index(extent=extent) for extent in operand.shape)
            for operand in concatenation.children)
        cache[index] = multiindices

    def cut(node):
        """Nodes not depending on the concatenation index need no
        rebuilding."""
        return index not in node.free_indices

    # Build Concatenate node replacement mappings, one per split part
    mappings = [{} for _ in multiindices]
    for reference in group:
        concatenation, = reference.children
        for position, multiindex in enumerate(multiindices):
            part = Indexed(concatenation.children[position], multiindex)
            part, = remove_componenttensors((part, ))
            mappings[position][reference] = part

    # Finally, split assignment pairs
    split_pairs = []
    for variable, expression in pairs:
        if index not in variable.free_indices:
            # This assignment does not involve the concatenation index.
            split_pairs.append((variable, expression))
            continue
        pieces = split_variable(variable, index, multiindices)
        split_pairs.extend(
            (piece, replace_node(expression, mapping, cut))
            for piece, mapping in zip(pieces, mappings))

    # Run again, there may be other Concatenate groups
    return _unconcatenate(cache, split_pairs)
def select_expression(expressions, index):
    """Pick one of several expressions according to an index.

    Semantically equivalent to

        partial_indexed(ListTensor(expressions), (index,))

    but has a much more optimised implementation.

    :arg expressions: a list of expressions of the same shape
    :arg index: an index (free, fixed or variable)
    :returns: an expression of the same shape as the given expressions
    """
    # All arguments must agree with the shape of the first one.
    shape = expressions[0].shape
    assert all(e.shape == shape for e in expressions)

    # Sanitise the input: index every expression with the same fresh
    # indices, one per dimension.
    alpha = tuple(Index() for _ in shape)
    sanitised = remove_componenttensors(
        [Indexed(e, alpha) for e in expressions])

    # Factor the expressions recursively, then wrap the result over the
    # same indices.
    return ComponentTensor(_select_expression(sanitised, index), alpha)