def _radd(self, o):
    """Reflected addition: build the expression for ``o + self``.

    Returns ``NotImplemented`` when ``o`` is not an instance of
    ``_valid_types``. A numeric ``o`` that compares equal to zero returns
    ``self`` unchanged; otherwise ``Sum(o, self)`` is returned.
    """
    if isinstance(o, _valid_types):
        # Adding the scalar 0 is a no-op, even when self has a shape.
        # The builtin sum([a, b]) relies on this, since it starts from 0.
        is_scalar_zero = isinstance(o, numbers.Number) and o == 0
        return self if is_scalar_zero else Sum(o, self)
    return NotImplemented
def _rsub(self, o):
    """Reflected subtraction: build the expression for ``o - self``.

    The difference is expressed as the sum of ``o`` and the negation of
    ``self``. Returns ``NotImplemented`` when ``o`` is not an instance of
    ``_valid_types``.
    """
    if isinstance(o, _valid_types):
        return Sum(o, -self)
    return NotImplemented
def _add(self, o):
    """Addition: build the expression for ``self + o``.

    Returns ``Sum(self, o)``, or ``NotImplemented`` when ``o`` is not an
    instance of ``_valid_types``.
    """
    return Sum(self, o) if isinstance(o, _valid_types) else NotImplemented
def sum(self, x):
    """Return the terms of a sum that might eventually yield the wanted parts.

    A sum may contain terms providing different arguments. The result is
    (a sum of) a suitable subset of those terms, all of which provide the
    same arguments. Every operand of ``x`` is visited, which yields a
    ``(part, provides)`` pair, and is then handled as follows:

    1a) The relevant part of the term is zero -> skip it.
    1b) The term provides something outside ``self._want`` -> skip it.
    2)  If every term was skipped, return zero.
    3)  The remaining terms provide exactly the wanted arguments, or
        fewer. Return the terms that provide the most arguments. Terms
        providing different non-empty sets of the same size are ambiguous
        (e.g. Argument(-1) + Argument(-2)) and are reported via ``error``.

    Returns:
        A tuple ``(expression, provided)``. When no term remains this is
        ``(zero(x), set())``; otherwise ``provided`` is the ``frozenset``
        provided by the returned terms.
    """
    parts_that_provide = {}

    # 1. Skip terms that are zero or that provide too much
    original_terms = x.operands()
    for term in original_terms:
        # Visit this term in the sum
        part, term_provides = self.visit(term)

        # A non-empty difference means the term provides something unwanted
        if isinstance(part, Zero) or (term_provides - self._want):
            continue

        # Group the surviving parts by the (hashable) set they provide
        term_provides = frozenset(term_provides)
        parts_that_provide.setdefault(term_provides, []).append(part)

    # 2. If there are no remaining terms, return zero
    if not parts_that_provide:
        return (zero(x), set())

    # 3. Return the terms that provide the biggest set
    most_provided = frozenset()
    # Iterate over the keys directly: dict.iteritems() does not exist on
    # Python 3, and the grouped parts are not needed inside this loop.
    for provideds in parts_that_provide:  # TODO: Just sort instead?
        # Throw error if size of sets are equal (and not zero)
        if len(provideds) == len(most_provided) and len(most_provided):
            error(
                "Don't know what to do with sums with different Arguments."
            )

        # Strict-superset comparison: only a set containing everything
        # seen so far (and more) replaces the current candidate.
        if provideds > most_provided:
            most_provided = provideds

    terms = parts_that_provide[most_provided]
    if len(terms) == len(original_terms):
        # Every operand survived, so the original node may be reusable
        x = self.reuse_if_possible(x, *terms)
    else:
        x = Sum(*terms)

    return (x, most_provided)