def match(self, expr: Expr, bindings: Bindings, succeed: Continuation) -> Success:
  """Matches the name, shape, and dtype of a `JaxVar`.

  Args:
    expr: the expression to test against this pattern.
    bindings: current bindings, forwarded unchanged to the underlying matcher.
    succeed: continuation, forwarded unchanged to the underlying matcher.

  Yields:
    Whatever the matcher built from this pattern's `(name, shape, dtype)`
    tuple yields. Nothing is yielded when `expr` is not a `JaxVar`.
  """
  if isinstance(expr, JaxVar):
    pattern = (self.name, self.shape, self.dtype)
    triple_matcher = matcher.matcher(pattern)
    candidate = (expr.name, expr.shape, expr.dtype)
    yield from triple_matcher(candidate, bindings, succeed)
def match(self, expr: Expr, bindings: Bindings, succeed: Continuation) -> Success:
  """Matches the formula and operands of an `Einsum`.

  Args:
    expr: the expression to test against this pattern.
    bindings: current bindings, forwarded unchanged to the underlying matcher.
    succeed: continuation, forwarded unchanged to the underlying matcher.

  Yields:
    Whatever the matcher built from this pattern's `(operands, formula)`
    pair yields. Nothing is yielded when `expr` is not an `Einsum`.
  """
  if isinstance(expr, Einsum):
    pair_matcher = matcher.matcher((self.operands, self.formula))
    candidate = (expr.operands, expr.formula)
    yield from pair_matcher(candidate, bindings, succeed)
def match(self, expr: Expr, bindings: Bindings, succeed: Continuation) -> Success:
  """Matches the operands of an `AddN`.

  Only `operands` take part in the match; nothing else on the `AddN` is
  compared.

  Args:
    expr: the expression to test against this pattern.
    bindings: current bindings, forwarded unchanged to the underlying matcher.
    succeed: continuation, forwarded unchanged to the underlying matcher.

  Yields:
    Whatever the matcher built from this pattern's `operands` yields.
    Nothing is yielded when `expr` is not an `AddN`.
  """
  if not isinstance(expr, AddN):
    return
  yield from matcher.matcher(self.operands)(expr.operands, bindings, succeed)
def match(self, expr: Expr, bindings: Bindings, succeed: Continuation) -> Success:
  """Matches the primitive, operands, and params of a `Primitive`.

  Args:
    expr: the expression to test against this pattern.
    bindings: current bindings, forwarded unchanged to the underlying matcher.
    succeed: continuation, forwarded unchanged to the underlying matcher.

  Yields:
    Whatever the matcher built from this pattern's
    `(primitive, operands, params)` tuple yields. Nothing is yielded when
    `expr` is not a `Primitive`.
  """
  if isinstance(expr, Primitive):
    pattern = (self.primitive, self.operands, self.params)
    primitive_matcher = matcher.matcher(pattern)
    candidate = (expr.primitive, expr.operands, expr.params)
    yield from primitive_matcher(candidate, bindings, succeed)
def match(self, expr: Expr, bindings: Bindings, succeed: Continuation) -> Success:
  """Matches the `expressions` of a `BoundExpression`.

  Args:
    expr: the expression to test against this pattern.
    bindings: current bindings, forwarded unchanged to the underlying matcher.
    succeed: continuation, forwarded unchanged to the underlying matcher.

  Yields:
    Whatever the matcher built from this pattern's `expressions` yields.
    Nothing is yielded when `expr` is not a `BoundExpression`.
  """
  if isinstance(expr, BoundExpression):
    expressions_matcher = matcher.matcher(self.expressions)
    yield from expressions_matcher(expr.expressions, bindings, succeed)
def match(self, expr: Expr, bindings: Bindings, succeed: Continuation) -> Success:
  """Matches the operand and index of a `Part`.

  Args:
    expr: the expression to test against this pattern.
    bindings: current bindings, forwarded unchanged to the underlying matcher.
    succeed: continuation, forwarded unchanged to the underlying matcher.

  Yields:
    Whatever the matcher built from this pattern's `(operand, index)` pair
    yields. Nothing is yielded when `expr` is not a `Part`.
  """
  if isinstance(expr, Part):
    part_matcher = matcher.matcher((self.operand, self.index))
    candidate = (expr.operand, expr.index)
    yield from part_matcher(candidate, bindings, succeed)
def match(self, expr: Expr, bindings: Bindings, succeed: Continuation) -> Success:
  """Matches the sorted keys and sorted values of a `Params`.

  Args:
    expr: the expression to test against this pattern.
    bindings: current bindings, forwarded unchanged to the underlying matcher.
    succeed: continuation, forwarded unchanged to the underlying matcher.

  Yields:
    Whatever the matcher built from this pattern's
    `(sorted_keys, sorted_values)` pair yields. Nothing is yielded when
    `expr` is not a `Params`.
  """
  if isinstance(expr, Params):
    params_matcher = matcher.matcher((self.sorted_keys, self.sorted_values))
    candidate = (expr.sorted_keys, expr.sorted_values)
    yield from params_matcher(candidate, bindings, succeed)
def match(self, expr: Expr, bindings: Bindings, succeed: Continuation) -> Success:
  """Matches the `value` of a `Literal`.

  Args:
    expr: the expression to test against this pattern.
    bindings: current bindings, forwarded unchanged to the underlying matcher.
    succeed: continuation, forwarded unchanged to the underlying matcher.

  Yields:
    Whatever the matcher built from this pattern's `value` yields. Nothing
    is yielded when `expr` is not a `Literal`.
  """
  if isinstance(expr, Literal):
    value_matcher = matcher.matcher(self.value)
    yield from value_matcher(expr.value, bindings, succeed)