def _select_expression(expressions, index):
    """Helper function to select an expression from a list of
    expressions with an index.  This function expects sanitised input;
    one should normally call :py:func:`select_expression` instead.

    The result plays the role of "``expressions`` indexed by ``index``"
    (the terminal rule literally builds ``ListTensor(expressions)`` and
    indexes it).  Structure shared by all candidates is pulled outside
    the selection, so that only the parts that actually differ get
    wrapped in a ``ListTensor``.  The rules are tried in order, and the
    first one that matches wins.

    :arg expressions: a list of expressions.  It must be non-empty,
        because ``expressions[0]`` is read unconditionally.
    :arg index: an index (free, fixed or variable)
    :returns: an expression
    :raises NotImplementedError: if no factorisation rule applies, or if
        all expressions share a type that has ``__front__``/``__back__``
        data.
    """
    # Reference candidate; the checks below compare against it.
    expr = expressions[0]
    # Every candidate is equal, so there is nothing to select.
    if all(e == expr for e in expressions):
        return expr

    types = set(map(type, expressions))
    if types <= {Indexed, Zero}:
        # The unpacking requires all Indexed candidates to share exactly one
        # multiindex (sanitised input); otherwise it raises ValueError.
        multiindex, = set(e.multiindex for e in expressions if isinstance(e, Indexed))
        # Shape only determined by free indices
        # NOTE(review): only ``Index`` entries contribute extents here. If
        # ``multiindex`` also holds fixed or variable indices, this Zero's
        # shape may not match the Indexed children's shape. Confirm this is
        # intended.
        shape = tuple(i.extent for i in multiindex if isinstance(i, Index))

        def child(expression):
            # Strip the shared Indexed wrapper. A Zero candidate becomes a
            # Zero of ``shape`` so that it can sit next to the unwrapped
            # children.
            if isinstance(expression, Indexed):
                return expression.children[0]
            elif isinstance(expression, Zero):
                return Zero(shape)
        # Select among the unwrapped children, then re-apply the common
        # multiindex outside the selection.
        return Indexed(
            _select_expression(list(map(child, expressions)), index),
            multiindex)

    if types <= {Literal, Zero, Failure}:
        # Terminal case: nothing left to factor out, so select directly
        # from a ListTensor of the candidates.
        return partial_indexed(ListTensor(expressions), (index, ))

    if types <= {ComponentTensor, Zero}:
        # All candidates must share one shape (the unpacking enforces this).
        shape, = set(e.shape for e in expressions)
        # Fresh indices, so that each tensor can be taken apart and the
        # result rebuilt as a single ComponentTensor.
        multiindex = tuple(Index(extent=d) for d in shape)
        children = remove_componenttensors(
            [Indexed(e, multiindex) for e in expressions])
        return ComponentTensor(_select_expression(children, index), multiindex)

    if len(types) == 1:
        cls, = types
        # Presumably __front__/__back__ name non-child constructor data,
        # which this rule cannot reconcile between candidates. TODO confirm.
        if cls.__front__ or cls.__back__:
            raise NotImplementedError(
                "How to factorise {} expressions?".format(cls.__name__))
        assert all(len(e.children) == len(expr.children) for e in expressions)
        assert len(expr.children) > 0

        # Same node type throughout: select position-wise among the
        # children, then rebuild one node with ``expr`` as the template.
        return expr.reconstruct(*[
            _select_expression(nth_children, index)
            for nth_children in zip(*[e.children for e in expressions])
        ])

    raise NotImplementedError(
        "No rule for factorising expressions of this kind.")
def aggressive_unroll(expression):
    """Aggressively unrolls all loop structures.

    A non-scalar expression is first expanded into a ``ListTensor`` that
    holds one scalar ``Indexed`` entry per component. The result is then
    passed to ``unroll_indexsum`` with a predicate that accepts every
    index. Finally, any leftover ComponentTensors are removed.

    :arg expression: the expression to unroll
    :returns: the unrolled expression
    """
    shape = expression.shape
    if shape:
        # An object array holds the scalar component expressions,
        # one per multi-index in ``shape``.
        components = numpy.empty(shape, dtype=object)
        for position in numpy.ndindex(shape):
            components[position] = Indexed(expression, position)
        (expression,) = remove_componenttensors((ListTensor(components),))

    def every_index(_index):
        # Unroll unconditionally, whatever the index.
        return True

    (unrolled,) = unroll_indexsum((expression,), predicate=every_index)
    (result,) = remove_componenttensors((unrolled,))
    return result
def contraction(expression):
    """Optimise the contractions of the tensor product at the root of
    ``expression``.

    The optimisations are:

    - IndexSum-Delta cancellation
    - Sum factorisation

    This routine was designed with finite element coefficient evaluation
    in mind.

    :arg expression: the expression whose root product is optimised
    :returns: the optimised expression
    """
    # ComponentTensors get in the way of analysing the product tree,
    # so strip them up front.
    (expression,) = remove_componenttensors([expression])

    def rebuild(expr):
        # Flatten the product tree, cancel deltas, then sum factorise.
        indices, factor_list = delta_elimination(*traverse_product(expr))
        return sum_factorise(indices, remove_componenttensors(factor_list))

    # The value shape is sometimes composed as a ListTensor, which can
    # block the decomposition of factors. This happens for H(div) and
    # H(curl) conforming tensor product elements. Free indices that
    # address a ListTensor are pulled out to be outermost, so that each
    # entry can then be factorised on its own.
    seen = OrderedDict()  # used as an insertion-ordered set
    for node in traversal((expression,)):
        if not isinstance(node, Indexed):
            continue
        (child,) = node.children
        if isinstance(child, ListTensor):
            for idx in node.multiindex:
                seen[idx] = None

    free = expression.free_indices
    lt_fis = tuple(idx for idx in seen if idx in free)

    if not lt_fis:
        # No ListTensor-addressing free indices: rebuild everything at once.
        return rebuild(expression)

    # Split along those indices and rebuild each component separately.
    tensor = ComponentTensor(expression, lt_fis)
    components = remove_componenttensors(
        [Indexed(tensor, zeta) for zeta in numpy.ndindex(tensor.shape)])
    rebuilt = numpy.array([rebuild(c) for c in components]).reshape(tensor.shape)
    return Indexed(ListTensor(rebuilt), lt_fis)