def eager_reduce_exp(op, arg, reduced_vars):
    """Eagerly rewrite a reduction of an ``exp`` term via log-space reduction.

    Uses the identity ``x.exp().reduce(ops.add) == x.reduce(ops.logaddexp).exp()``.
    ``arg.arg`` is the operand inside the exp term.

    NOTE(review): ``op`` is not read here; presumably the rule's registration
    restricts it to ``ops.add`` -- confirm at the registration site.

    :param op: the outer reduction op (unused in the body).
    :param arg: the exp term being reduced; its ``.arg`` is reduced in log space.
    :param reduced_vars: variables to reduce over.
    :return: ``log_result.exp()`` if the log-space reduction made progress,
        otherwise ``None`` (no rewrite).
    """
    inner = arg.arg
    log_result = inner.reduce(ops.logaddexp, reduced_vars)
    # If reduce() just handed back the lazy normal form, nothing was gained.
    lazy_form = normalize(Reduce, ops.logaddexp, inner, reduced_vars)
    if log_result is lazy_form:
        return None
    return log_result.exp()
def normalize_contraction_commutative_canonical_order(red_op, bin_op, reduced_vars, *terms):
    """Sort the terms of a commutative contraction into a canonical order.

    Terms are ordered by ``ORDERING`` rank of their type's ``__origin__``
    (unknown types rank ``-1``). Ties keep their original relative position,
    which makes the order deterministic for pattern matching.

    :return: a new ``Contraction`` over the reordered terms if the order
        changed, otherwise the result of ``normalize`` on the unchanged terms.
    """
    def sort_key(indexed_term):
        position, term = indexed_term
        return ORDERING.get(type(term).__origin__, -1), position

    ordered_pairs = sorted(enumerate(terms), key=sort_key)
    new_terms = tuple(term for _, term in ordered_pairs)

    # Compare by identity: any moved term means the order was not canonical yet.
    changed = any(before is not after for before, after in zip(terms, new_terms))
    if changed:
        return Contraction(red_op, bin_op, reduced_vars, *new_terms)
    return normalize(Contraction, red_op, bin_op, reduced_vars, new_terms)
def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
    """Eagerly simplify a generic Contraction by pushing reductions inward.

    Two strategies are tried, in order:

    1. Leaf reduction: each reduced variable that occurs in exactly one
       term is reduced inside that term with ``term.reduce(red_op, ...)``.
       If any leaf reduction makes progress, a new Contraction is returned.
    2. Pairwise contraction (relies on ``bin_op`` being associative): for
       each pair of terms ``(lhs, rhs)``, contract them together with the
       reduced variables that occur in both and in no other term. The first
       pair whose contraction is not just its lazy normal form is replaced
       by its result.

    :param red_op: the reduction operator.
    :param bin_op: the binary operator combining the terms.
    :param reduced_vars: variables to reduce over.
    :param terms: the terms being combined.
    :return: a simpler ``Contraction``, or ``None`` if neither strategy
        made progress.
    """
    # Strategy 1: push down leaf reductions.
    terms = list(terms)
    reduced_vars = frozenset(reduced_vars)
    leaf_reduced = False
    for i, term in enumerate(terms):
        # Exclude the other terms by position, not identity. If the same object
        # appears at two positions, an identity test would drop both copies
        # and wrongly treat variables they share as unique to this one.
        unique_vars = _exclusive_reduced_vars(reduced_vars, terms, (i,))
        if unique_vars:
            result = term.reduce(red_op, unique_vars)
            # Only commit if reduce() did real work rather than returning
            # the lazy normal form.
            if result is not normalize(Contraction, red_op, nullop, unique_vars, (term,)):
                terms[i] = result
                reduced_vars -= unique_vars
                leaf_reduced = True

    if leaf_reduced:
        return Contraction(red_op, bin_op, reduced_vars, *terms)

    # Strategy 2: exploit associativity to recursively evaluate this
    # contraction. This is a bit expensive, but it handles
    # interpreter-imposed directionality constraints.
    terms = tuple(terms)
    for i, lhs in enumerate(terms[:-1]):
        for j in range(i + 1, len(terms)):
            rhs = terms[j]
            unique_vars = _exclusive_reduced_vars(reduced_vars, terms, (i, j))
            result = Contraction(red_op, bin_op, unique_vars, lhs, rhs)
            # Did we make progress? Pick the first evaluable pair.
            if result is not normalize(Contraction, red_op, bin_op, unique_vars, (lhs, rhs)):
                reduced_vars -= unique_vars
                new_terms = terms[:i] + (result,) + terms[i + 1:j] + terms[j + 1:]
                return Contraction(red_op, bin_op, reduced_vars, *new_terms)

    return None


def _exclusive_reduced_vars(reduced_vars, terms, positions):
    """Return reduced vars present in every term at ``positions`` and in no other term.

    Terms are distinguished by index rather than identity, so repeated
    occurrences of one object at different positions are counted separately.

    :param reduced_vars: frozenset of candidate variables.
    :param terms: indexable sequence of objects exposing ``.inputs``.
    :param positions: non-empty iterable of indices into ``terms``.
    :return: frozenset of the exclusive reduced variables.
    """
    positions = frozenset(positions)
    shared = reduced_vars.intersection(*(terms[k].inputs for k in positions))
    elsewhere = frozenset().union(*(
        reduced_vars.intersection(term.inputs)
        for k, term in enumerate(terms)
        if k not in positions))
    return shared - elsewhere
def normalize_contraction_generic_args(red_op, bin_op, reduced_vars, *terms):
    """Adapt the variadic-terms form of a Contraction to the tuple-of-terms form.

    :param red_op: the reduction operator, forwarded unchanged.
    :param bin_op: the binary operator, forwarded unchanged.
    :param reduced_vars: variables to reduce over, forwarded unchanged.
    :param terms: the terms, given positionally.
    :return: the result of ``normalize`` with the terms packed into one tuple.
    """
    packed_terms = tuple(terms)
    return normalize(Contraction, red_op, bin_op, reduced_vars, packed_terms)