def applier(expr):
    """Rename the free indices of ``expr`` that are listed in ``renames``.

    ``renames`` is not a parameter of this function: it is looked up in
    the surrounding scope and iterated as ``(old, new)`` index pairs.

    :arg expr: expression exposing a ``free_indices`` attribute
    :returns: ``expr`` itself when none of the old indices in
        ``renames`` is free in it; otherwise ``expr`` wrapped in a
        :py:class:`ComponentTensor` over the matching old indices and
        indexed again with the corresponding new indices
    """
    # Keep only the renames that actually apply to this expression.
    applicable = [(old, new) for old, new in renames
                  if old in expr.free_indices]
    if not applicable:
        return expr

    old_indices, new_indices = zip(*applicable)
    return Indexed(ComponentTensor(expr, old_indices), new_indices)
def split_variable(variable_ref, index, multiindices):
    """Split a flexibly indexed variable along a concatenation index.

    The split index is moved to the front of a :py:class:`ComponentTensor`
    over ``variable_ref``; consecutive ranges of that leading axis are
    then viewed, reshaped and indexed, one range per entry of
    ``multiindices``.

    :param variable_ref: flexibly indexed variable to split
    :param index: :py:class:`Concatenate` index to split along
    :param multiindices: one multiindex for each split variable

    :returns: generator of split indexed variables
    """
    assert isinstance(variable_ref, FlexiblyIndexed)

    # Every index of the variable except the one being split.
    remaining = list(variable_ref.index_ordering())
    remaining.remove(index)
    remaining = tuple(remaining)

    tensor = ComponentTensor(variable_ref, (index,) + remaining)
    # The remaining axes are taken whole and keep their own extents.
    whole_axes = [slice(None) for _ in remaining]
    remaining_shapes = [(idx.extent,) for idx in remaining]

    start = 0
    for multiindex in multiindices:
        extents = tuple(idx.extent for idx in multiindex)
        # Each multiindex consumes as many entries of the leading axis
        # as the product of its extents.
        stop = start + numpy.prod(extents, dtype=int)
        window = slice(start, stop)
        start = stop

        piece = reshape(view(tensor, window, *whole_axes),
                        extents, *remaining_shapes)
        sub_ref = Indexed(piece, multiindex + remaining)
        sub_ref, = remove_componenttensors((sub_ref,))
        yield sub_ref
def substitute(expression, from_, to_):
    """Replace the free index ``from_`` by ``to_`` in ``expression``.

    :arg expression: expression exposing a ``free_indices`` attribute
    :arg from_: index to be replaced
    :arg to_: index to put in its place
    :returns: ``expression`` unchanged when ``from_`` is not one of its
        free indices; for a :py:class:`Delta`, the result of applying
        ``filtered_replace_indices`` through a ``MemoizerArg`` mapper;
        otherwise ``expression`` wrapped in a
        :py:class:`ComponentTensor` over ``from_`` and indexed with
        ``to_``
    """
    # Nothing to do if the index does not occur freely.
    if from_ not in expression.free_indices:
        return expression

    # Deltas get their indices replaced directly rather than wrapped.
    if isinstance(expression, Delta):
        replace = MemoizerArg(filtered_replace_indices)
        substitution = ((from_, to_),)
        return replace(expression, substitution)

    wrapped = ComponentTensor(expression, (from_,))
    return Indexed(wrapped, (to_,))
def _select_expression(expressions, index):
    """Helper function to select an expression from a list of
    expressions with an index.  This function expects sanitised input,
    one should normally call :py:func:`select_expression` instead.

    The candidates are factorised recursively: structure shared by all
    of them is kept once, and the selection is pushed down towards the
    parts that differ.

    :arg expressions: a list of expressions
    :arg index: an index (free, fixed or variable)
    :returns: an expression
    :raises NotImplementedError: if the candidates share a single type
        that has ``__front__`` or ``__back__`` set, or if no
        factorisation rule matches their types
    """
    first = expressions[0]
    if all(e == first for e in expressions):
        # Every candidate is the same expression: nothing to select.
        return first

    kinds = set(map(type, expressions))

    if kinds <= {Indexed, Zero}:
        # The unpacking requires all Indexed candidates to share one
        # multiindex.
        multiindex, = {e.multiindex for e in expressions
                       if isinstance(e, Indexed)}
        # Shape only determined by free indices
        shape = tuple(i.extent for i in multiindex if isinstance(i, Index))
        # Strip the indexing; a Zero candidate (the only other type
        # allowed here) is replaced by a Zero of the matching shape.
        operands = [e.children[0] if isinstance(e, Indexed) else Zero(shape)
                    for e in expressions]
        return Indexed(_select_expression(operands, index), multiindex)

    if kinds <= {Literal, Zero, Failure}:
        return partial_indexed(ListTensor(expressions), (index,))

    if kinds <= {ComponentTensor, Zero}:
        # The unpacking requires all candidates to have the same shape.
        shape, = {e.shape for e in expressions}
        multiindex = tuple(Index(extent=extent) for extent in shape)
        stripped = remove_componenttensors(
            [Indexed(e, multiindex) for e in expressions])
        return ComponentTensor(_select_expression(stripped, index),
                               multiindex)

    if len(kinds) != 1:
        raise NotImplementedError(
            "No rule for factorising expressions of this kind.")

    cls, = kinds
    if cls.__front__ or cls.__back__:
        raise NotImplementedError(
            "How to factorise {} expressions?".format(cls.__name__))
    assert all(len(e.children) == len(first.children) for e in expressions)
    assert len(first.children) > 0

    # Select child-by-child, then rebuild a node of the shared type.
    columns = zip(*[e.children for e in expressions])
    return first.reconstruct(
        *[_select_expression(column, index) for column in columns])
def _(node, self):
    """Unroll the summation indices of ``node`` selected by
    ``self.predicate``.

    :arg node: node with a ``multiindex`` and a single summand in
        ``children[0]`` (presumably an :py:class:`IndexSum`, since that
        is what gets rebuilt -- confirm against the dispatch
        registration, which is not visible here)
    :arg self: callable mapper applied to the summand; it must provide
        a ``predicate`` attribute deciding which indices to unroll
    :returns: an :py:class:`IndexSum` over the indices that were not
        selected, whose summand is the explicit sum over every value of
        the selected indices; when no index is selected, the result of
        ``reuse_if_untouched(node, self)``
    """
    selected = tuple(filter(self.predicate, node.multiindex))
    if not selected:
        return reuse_if_untouched(node, self)

    # Unrolling
    summand = self(node.children[0])
    extents = tuple(idx.extent for idx in selected)
    total = Zero()
    for point in numpy.ndindex(extents):
        # Fix the selected indices to one concrete point and accumulate.
        total = Sum(total, Indexed(ComponentTensor(summand, selected), point))

    kept = tuple(idx for idx in node.multiindex if idx not in selected)
    return IndexSum(total, kept)
def contraction(expression):
    """Optimise the contractions of the tensor product at the root of
    the expression, including:

    - IndexSum-Delta cancellation
    - Sum factorisation

    This routine was designed with finite element coefficient
    evaluation in mind.

    :arg expression: expression to optimise
    :returns: the rebuilt expression; when free indices of
        ``expression`` are used to index a :py:class:`ListTensor`, an
        :py:class:`Indexed` :py:class:`ListTensor` whose entries were
        rebuilt one by one
    """
    # Eliminate annoying ComponentTensors
    expression, = remove_componenttensors([expression])

    def rebuild(subexpression):
        # Flatten product tree, eliminate deltas, sum factorise
        sum_indices, factors = delta_elimination(
            *traverse_product(subexpression))
        return sum_factorise(sum_indices, remove_componenttensors(factors))

    # Sometimes the value shape is composed as a ListTensor, which
    # could get in the way of decomposing factors.  In particular,
    # this is the case for H(div) and H(curl) conforming tensor
    # product elements.  So if ListTensors are used, they are pulled
    # out to be outermost, so we can straightforwardly factorise each
    # of its entries.
    seen = OrderedDict()  # ordered set of indices applied to ListTensors
    for node in traversal((expression,)):
        if not isinstance(node, Indexed):
            continue
        child, = node.children
        if isinstance(child, ListTensor):
            for idx in node.multiindex:
                seen.setdefault(idx)

    # ListTensor free indices
    lt_fis = tuple(idx for idx in seen if idx in expression.free_indices)

    if not lt_fis:
        # Rebuild whole expression at once
        return rebuild(expression)

    # Rebuild each split component
    tensor = ComponentTensor(expression, lt_fis)
    entries = remove_componenttensors(
        [Indexed(tensor, zeta) for zeta in numpy.ndindex(tensor.shape)])
    rebuilt = numpy.array([rebuild(entry) for entry in entries])
    return Indexed(ListTensor(rebuilt.reshape(tensor.shape)), lt_fis)
def select_expression(expressions, index):
    """Select an expression from a list of expressions with an index.
    Semantically equivalent to

        partial_indexed(ListTensor(expressions), (index,))

    but has a much more optimised implementation.

    :arg expressions: a non-empty list of expressions of the same shape
    :arg index: an index (free, fixed or variable)
    :returns: an expression of the same shape as the given expressions
    """
    # Check arguments: all candidates must agree on their shape.
    common_shape = expressions[0].shape
    assert all(e.shape == common_shape for e in expressions)

    # Sanitise input expressions: index every candidate with the same
    # fresh multiindex (one new Index per axis) and strip
    # ComponentTensors.
    alpha = tuple(Index() for _ in common_shape)
    indexed = [Indexed(e, alpha) for e in expressions]
    sanitised = remove_componenttensors(indexed)

    # Factor the expressions recursively and convert result
    return ComponentTensor(_select_expression(sanitised, index), alpha)