def _expand_conditional_conditional(node, self):
    """Rewrite a conditional node as a sum of two masked products.

    When ``self.predicate(node)`` holds, the three children are first
    processed through ``self`` and the node is replaced by::

        Conditional(c, one, Zero()) * then + Conditional(c, Zero(), one) * else_

    Otherwise the node is handed to ``reuse_if_untouched``.

    :arg node: the conditional node to process
    :arg self: function for recursive calls; must also provide ``predicate``
    :returns: the rewritten expression
    """
    if not self.predicate(node):
        return reuse_if_untouched(node, self)

    condition, then, else_ = (self(child) for child in node.children)

    # Each branch is multiplied by a conditional that selects `one` when
    # that branch is active and `Zero()` when it is not.
    then_term = Product(Conditional(condition, one, Zero()), then)
    else_term = Product(Conditional(condition, Zero(), one), else_)
    return Sum(then_term, else_term)
def _replace_delta_delta(node, self):
    """Replace a delta node by an equivalent expression.

    If at least one of ``node.i`` and ``node.j`` is an ``Index``, the
    result is ``Identity(size)`` indexed by ``(i, j)``, where ``size`` is
    the extent of the ``Index`` operand(s).  Otherwise both indices are
    converted to expressions and the result is a ``Conditional`` on their
    equality that selects ``one`` or ``Zero()``.

    :arg node: the delta node, providing the indices ``i`` and ``j``
    :arg self: function for recursive calls (not used here)
    :returns: the replacement expression
    :raises ValueError: if neither index is an ``Index`` and one of them
        is neither an ``int`` nor a ``VariableIndex``
    """
    i, j = node.i, node.j
    i_is_index = isinstance(i, Index)
    j_is_index = isinstance(j, Index)

    if not (i_is_index or j_is_index):
        def as_expression(index):
            # Only fixed integers and variable indices have an
            # expression form.
            if isinstance(index, int):
                return Literal(index)
            if isinstance(index, VariableIndex):
                return index.expression
            raise ValueError("Cannot convert running index to expression.")

        lhs = as_expression(i)
        rhs = as_expression(j)
        return Conditional(Comparison("==", lhs, rhs), one, Zero())

    # At least one operand is an `Index`: its extent fixes the size of
    # the identity matrix.  When both are, their extents must agree.
    if i_is_index and j_is_index:
        assert i.extent == j.extent
    if i_is_index:
        assert i.extent is not None
        size = i.extent
    if j_is_index:
        assert j.extent is not None
        size = j.extent
    return Indexed(Identity(size), (i, j))
def test_conditional_simplification():
    """A conditional with the same expression in both branches must
    compare equal to that expression."""
    a = Variable("A", ())
    b = Variable("B", ())

    condition = LogicalAnd(b, a)
    simplified = Conditional(condition, a, a)

    assert simplified == a
def _collect_monomials_conditional(expression, self):
    """Refactorise a conditional expression into sum-of-products form.

    Both branches are refactorised recursively.  For every monomial key
    that occurs in either branch, a new ``Conditional`` with the original
    condition is built from the two matching entries; a branch that lacks
    the key contributes ``Zero()``.

    :arg expression: a GEM expression to refactorise
    :arg self: function for recursive calls
    :returns: :py:class:`MonomialSum`
    """
    condition, then, else_ = expression.children

    # Refactorise each branch into its own `MonomialSum`.
    then_sum = self(then)
    else_sum = self(else_)

    combined = MonomialSum()

    # One Conditional per key: the key is pulled out of the conditional,
    # while the associated entries stay inside the branches.
    missing = Zero()
    all_keys = then_sum.monomials.keys() | else_sum.monomials.keys()
    for key in all_keys:
        then_part = then_sum.monomials.get(key, missing)
        else_part = else_sum.monomials.get(key, missing)
        combined.monomials[key] = Conditional(condition, then_part, else_part)

    # Deterministic ordering: start from the `then` ordering and add the
    # `else` entries for keys it does not already contain.
    combined.ordering = then_sum.ordering.copy()
    for key, position in else_sum.ordering.items():
        combined.ordering.setdefault(key, position)

    return combined
def test_conditional_zero_folding():
    """A conditional whose `then` branch is a product with ``Zero()`` and
    whose `else` branch is ``Zero()`` must compare equal to ``Zero()``."""
    b = Variable("B", ())
    a = Variable("A", (3, ))
    i = Index()

    condition = LogicalAnd(b, b)
    then_branch = Product(Indexed(a, (i, )), Zero())
    folded = Conditional(condition, then_branch, Zero())

    assert folded == Zero()