def _select_expression(expressions, index): """Helper function to select an expression from a list of expressions with an index. This function expect sanitised input, one should normally call :py:func:`select_expression` instead. :arg expressions: a list of expressions :arg index: an index (free, fixed or variable) :returns: an expression """ expr = expressions[0] if all(e == expr for e in expressions): return expr types = set(map(type, expressions)) if types <= {Indexed, Zero}: multiindex, = set(e.multiindex for e in expressions if isinstance(e, Indexed)) # Shape only determined by free indices shape = tuple(i.extent for i in multiindex if isinstance(i, Index)) def child(expression): if isinstance(expression, Indexed): return expression.children[0] elif isinstance(expression, Zero): return Zero(shape) return Indexed( _select_expression(list(map(child, expressions)), index), multiindex) if types <= {Literal, Zero, Failure}: return partial_indexed(ListTensor(expressions), (index, )) if types <= {ComponentTensor, Zero}: shape, = set(e.shape for e in expressions) multiindex = tuple(Index(extent=d) for d in shape) children = remove_componenttensors( [Indexed(e, multiindex) for e in expressions]) return ComponentTensor(_select_expression(children, index), multiindex) if len(types) == 1: cls, = types if cls.__front__ or cls.__back__: raise NotImplementedError( "How to factorise {} expressions?".format(cls.__name__)) assert all(len(e.children) == len(expr.children) for e in expressions) assert len(expr.children) > 0 return expr.reconstruct(*[ _select_expression(nth_children, index) for nth_children in zip(*[e.children for e in expressions]) ]) raise NotImplementedError( "No rule for factorising expressions of this kind.")
def _flatten_concatenate(node, self):
    """Replace ``node`` by a literal built from its evaluated value.

    ``node`` is passed to :func:`evaluate`, which must return exactly
    one result. That result's ``arr`` is wrapped in a :class:`Literal`,
    and ``partial_indexed`` then indexes the literal with the result's
    ``fids`` (presumably its free indices; confirm against ``evaluate``).

    :arg node: the node to flatten (presumably a Concatenate node,
        judging by the name; verify at the registration site)
    :arg self: the traversal/mapper object; not used by this rule,
        presumably kept to match the dispatch signature
    :returns: an expression
    """
    evaluations = evaluate([node])
    (single,) = evaluations
    literal = Literal(single.arr)
    return partial_indexed(literal, single.fids)