def sum_factorise(sum_indices, factors):
    """Rewrite a product with contractions using sum factorisation.

    Every ordering of ``sum_indices`` is tried.  For each ordering the
    contractions are applied one index at a time.  The resulting
    expression with the smallest estimated operation count is returned.

    :arg sum_indices: free indices for contractions
    :arg factors: product factors
    :returns: optimised GEM expression
    """
    if len(sum_indices) > 5:
        raise NotImplementedError("Too many indices for sum factorisation!")

    def size(indices):
        # Number of points in the iteration space spanned by ``indices``.
        return numpy.prod([index.extent for index in indices], dtype=int)

    # Multiply together the factors that share exactly the same free indices.
    # OrderedDict keeps the grouping order deterministic.
    by_free_indices = OrderedDict()
    for factor in factors:
        by_free_indices.setdefault(factor.free_indices, []).append(factor)
    initial_terms = [reduce(Product, same) for same in by_free_indices.values()]

    best_expression, lowest_cost = None, numpy.inf
    for order in permutations(sum_indices):
        remaining = list(initial_terms)
        cost = 0
        for index in order:
            # Split the current terms into those carrying ``index`` (which
            # must enter the contraction) and those that can wait.
            # Relative order is preserved in both lists.
            involved, untouched = [], []
            for term in remaining:
                if index in term.free_indices:
                    involved.append(term)
                else:
                    untouched.append(term)

            # Further optimisation is possible by choosing among different
            # product trees; here a plain left fold is used.
            contracted = reduce(Product, involved)
            summed = IndexSum(contracted, (index,))
            # Cost model: assume the product tree saves nothing, charging
            # len(involved) operations at every point of the product's
            # iteration space.
            cost += len(involved) * size(contracted.free_indices)
            remaining = untouched + [summed]

        # Independent contraction indices may leave several terms behind;
        # multiply them together and charge for those multiplications.
        candidate = reduce(Product, remaining)
        cost += (len(remaining) - 1) * size(candidate.free_indices)
        if cost < lowest_cost:
            best_expression, lowest_cost = candidate, cost
    return best_expression
def sum_factorise(sum_indices, factors):
    """Rewrite a product with contractions using sum factorisation.

    All orderings of ``sum_indices`` are explored.  Each contraction's
    product is built by ``associate``, which also reports that product's
    cost.  The cheapest resulting expression is returned.

    :arg sum_indices: free indices for contractions
    :arg factors: product factors
    :returns: optimised GEM expression
    """
    if not len(factors) and not len(sum_indices):
        # Nothing to multiply and nothing to sum: the empty product.
        return one
    if len(sum_indices) > 6:
        raise NotImplementedError("Too many indices for sum factorisation!")

    def size(indices):
        # Number of points in the iteration space spanned by ``indices``.
        return numpy.prod([index.extent for index in indices], dtype=int)

    # One product per distinct tuple of free indices.
    initial_terms = [reduce(Product, members)
                     for _, members in groupby(factors, key=lambda f: f.free_indices)]

    best_expression, lowest_cost = None, numpy.inf
    for order in permutations(sum_indices):
        remaining, cost = initial_terms[:], 0
        for index in order:
            involved = [t for t in remaining if index in t.free_indices]
            untouched = [t for t in remaining if index not in t.free_indices]

            # Let ``associate`` pick the product tree and report its cost.
            contracted, product_cost = associate(Product, involved)
            summed = IndexSum(contracted, (index, ))
            # Charge the product's cost plus the size of the contracted
            # product's index space for the summation.
            cost += product_cost + size(contracted.free_indices)
            remaining = untouched + [summed]

        # Independent contraction indices may leave several terms behind.
        candidate, product_cost = associate(Product, remaining)
        cost += product_cost
        if cost < lowest_cost:
            best_expression, lowest_cost = candidate, cost
    return best_expression
def _(node, self):
    """Unroll the IndexSum indices accepted by ``self.predicate``.

    ``self`` is the recursive mapper: it is called on the summand and
    supplies ``predicate``.  Each selected index is expanded into an
    explicit chain of Sum nodes over all of its values.  Indices that are
    not selected remain in an IndexSum.  If nothing is selected, the node
    is handed to ``reuse_if_untouched``.
    """
    selected = tuple(index for index in node.multiindex if self.predicate(index))
    if not selected:
        return reuse_if_untouched(node, self)

    # Transform the summand first, then add up one copy of it for every
    # combination of values of the selected indices, starting from Zero().
    tensor_summand = self(node.children[0])
    extents = tuple(index.extent for index in selected)
    expanded = Zero()
    for alpha in numpy.ndindex(extents):
        expanded = Sum(expanded,
                       Indexed(ComponentTensor(tensor_summand, selected), alpha))

    kept = tuple(index for index in node.multiindex if index not in selected)
    return IndexSum(expanded, kept)
def monomial_sum_to_expression(monomial_sum):
    """Build a GEM expression from a monomial sum.

    Monomials are grouped by the set of their sum indices.  Each group
    becomes one IndexSum over the sum of its monomials' products, and the
    result is the sum of those IndexSums.

    :arg monomial_sum: an iterable of :class:`Monomial`s
    :returns: GEM expression
    """
    def to_product(monomial):
        # Product of the monomial's atomic factors followed by its rest.
        return make_product(monomial.atomics + (monomial.rest,))

    # Grouping key is a frozenset, so index order does not split groups; the
    # first member's sum_indices tuple is used for the IndexSum.
    grouped = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))
    return make_sum([
        IndexSum(make_sum([to_product(m) for m in members]),
                 members[0].sum_indices)
        for _, members in grouped
    ])
def make_expression(i, j):
    """Return ``A[i] * (sum over j of A[j])`` for a Variable ``A`` of shape (6,).

    :arg i: index used for the outer factor
    :arg j: index summed over in the inner IndexSum
    """
    vector = Variable('A', (6, ))
    summed = IndexSum(Indexed(vector, (j, )), (j, ))
    return Product(Indexed(vector, (i, )), summed)