示例#1
0
    def statements(self, py_stmts: t.List, in_stencil_root_scope: bool = False):
        """Translate a list of Python statements into a list of SIR statements.

        Annotated assignments encountered while ``in_stencil_root_scope`` is
        true are passed to ``temporary_field_declaration`` and add nothing to
        the result. Every other statement is dispatched on its shape; results
        that are ``None`` (the ``Pass`` handler) are left out.
        """
        # TODO: bad hardcoded strings
        sparse_with = With(
            items=FixedList(
                withitem(
                    context_expr=Subscript(value=name("sparse"), slice=_, ctx=_),
                    optional_vars=_,
                )
            ),
            body=_,
            type_comment=_,
        )
        # NOTE(review): presumably `dispatch` tries the patterns in insertion
        # order, so the `sparse[...]` pattern stays ahead of the bare `With`.
        handlers = {
            OneOf(Assign, AugAssign): self.assign,
            If: self.if_stmt,
            sparse_with: self.loop_stmt,
            # assume a vertical region by default
            With: self.vertical_loop,
            Pass: lambda pass_node: None,
        }

        translated = []
        for py_stmt in py_stmts:
            if in_stencil_root_scope and isinstance(py_stmt, AnnAssign):
                self.temporary_field_declaration(py_stmt)
                continue

            sir_stmt = dispatch(handlers, py_stmt)
            if sir_stmt is not None:
                translated.append(sir_stmt)
        return translated
示例#2
0
class Grammar:
    @staticmethod
    def is_stencil(node) -> bool:
        """Tell whether ``node`` matches the shape of a stencil definition.

        The pattern is a ``FunctionDef`` whose ``decorator_list`` matches
        ``FixedList(name(stencil_decorator.__name__))``; all remaining fields
        are matched against the ``_`` placeholder.
        """
        stencil_decorators = FixedList(name(stencil_decorator.__name__))
        stencil_pattern = FunctionDef(
            name=_,
            args=_,
            body=_,
            decorator_list=stencil_decorators,
            returns=_,
            type_comment=_,
        )
        return does_match(stencil_pattern, node)

    def __init__(self):
        """Create the grammar with a fresh ``DuskContextHelper`` in ``self.ctx``."""
        # The transformers reach scope and location state through `self.ctx`
        # (`self.ctx.scope`, `self.ctx.location`, `self.ctx.vertical_region`).
        self.ctx = DuskContextHelper()

    @transform(
        FunctionDef(
            name=Capture(str).to("name"),
            args=arguments(
                posonlyargs=EmptyList,
                args=Capture(Repeat(arg)).to("fields"),
                vararg=None,
                kwonlyargs=EmptyList,
                kw_defaults=EmptyList,
                kwarg=None,
                defaults=EmptyList,
            ),
            body=Capture(_).to("body"),
            decorator_list=FixedList(name(stencil_decorator.__name__)),
            returns=Optional(Constant(value=None, kind=None)),
            type_comment=None,
        ))
    def stencil(self, name: str, body: t.List, fields: t.List):
        """Build the SIR stencil for a matched stencil function definition.

        Inside a new scope, every argument in ``fields`` is declared through
        ``field_declaration`` and ``body`` is translated with
        ``in_stencil_root_scope=True``. The stencil's fields are then read back
        from the scope, and the result of ``make_stencil`` is returned.
        """
        with self.ctx.scope.new_scope():
            for argument in fields:
                self.field_declaration(argument)

            sir_statements = self.statements(body, in_stencil_root_scope=True)
            stencil_ast = make_ast(sir_statements)

            # Read the fields back from the scope rather than from `fields`:
            # temporaries declared by the body were added to the same scope.
            sir_fields = []
            for symbol in self.ctx.scope.current_scope:
                if isinstance(symbol, (DuskField, DuskIndexField)):
                    sir_fields.append(symbol.sir)

        return make_stencil(name, stencil_ast, sir_fields)

    @transform(
        arg(
            arg=Capture(str).to("name"),
            annotation=Capture(expr).to("field_type"),
            type_comment=None,
        ))
    def field_declaration(self, name: str, field_type: expr):
        """Declare the stencil argument ``name`` as a non-temporary field."""
        self.add_field_declaration(name, field_type, is_temporary=False)

    @transform(
        AnnAssign(
            target=name(Capture(str).to("name"), ctx=Store),
            value=None,
            annotation=Capture(expr).to("field_type"),
            simple=1,
        ), )
    def temporary_field_declaration(self, name: str, field_type: expr):
        """Declare ``name`` as a temporary field of the annotated type."""
        self.add_field_declaration(
            name=name,
            field_type=field_type,
            is_temporary=True,
        )

    def add_field_declaration(self,
                              name: str,
                              field_type: expr,
                              is_temporary: bool = False):
        """Create a field symbol for ``name`` and add it to the current scope.

        ``field_type`` is the annotation node; it is decomposed by
        ``self.field_type`` into the field kind (``"Field"`` or
        ``"IndexField"``), an optional horizontal location chain and a vertical
        flag. Without a horizontal chain the field gets vertical-only
        dimensions.
        """
        kind, hindex, vindex = self.field_type(field_type)

        assert kind in {"Field", "IndexField"}
        symbol_class = DuskIndexField if kind != "Field" else DuskField

        dimensions = (make_field_dimensions_vertical() if hindex is None else
                      make_field_dimensions_unstructured(hindex, vindex))

        sir_field = make_field(name, dimensions, is_temporary)
        self.ctx.scope.current_scope.add(name, symbol_class(sir_field))

    @transform(
        Subscript(
            # TODO: hardcoded string
            value=name(Capture(OneOf("Field", "IndexField")).to("field_type")),
            slice=Index(value=OneOf(
                Tuple(
                    elts=FixedList(
                        Capture(_).to("hindex"),
                        name(Capture("K").to("vindex")),
                    ),
                    ctx=Load,
                ),
                name(Capture("K").to("vindex")),
                Capture(_).to("hindex"),
            ), ),
            ctx=Load,
        ))
    def field_type(self,
                   field_type: str,
                   hindex: expr = None,
                   vindex: str = None):
        """Decompose a field annotation into ``(kind, location chain, vertical flag)``.

        ``kind`` is the captured type name. The location chain is
        ``self.location_chain(hindex)`` or ``None`` when no horizontal index was
        captured. The flag is ``1`` if a vertical index was captured, ``0``
        otherwise.
        """
        chain = None if hindex is None else self.location_chain(hindex)
        vertical_flag = 0 if vindex is None else 1
        return field_type, chain, vertical_flag

    @transform(
        OneOf(
            name(Capture(str).append("locations")),
            Compare(
                left=name(Capture(str).append("locations")),
                ops=Repeat(Gt),
                comparators=Repeat(name(Capture(str).append("locations"))),
            ),
        ))
    def location_chain(self, locations: t.List):
        """Convert each captured location name through ``location_type``, in order."""
        return list(map(self.location_type, locations))

    @transform(Capture(str).to("name"))
    def location_type(self, name: str):
        """Return ``sir.LocationType.Value(name)`` for a known location name.

        Raises:
            DuskSyntaxError: if ``name`` is not the ``__name__`` of any entry
                in ``LOCATION_TYPES``.
        """
        is_known = any(name == location.__name__
                       for location in LOCATION_TYPES)
        if is_known:
            return sir.LocationType.Value(name)
        raise DuskSyntaxError(f"Invalid location type '{name}'!", name)

    @transform(Capture(list).to("py_stmts"))
    def statements(self,
                   py_stmts: t.List,
                   in_stencil_root_scope: bool = False):
        """Translate a list of Python statements into a list of SIR statements.

        Annotated assignments encountered while ``in_stencil_root_scope`` is
        true are passed to ``temporary_field_declaration`` and add nothing to
        the result. Every other statement is dispatched on its shape; results
        that are ``None`` (the ``Pass`` handler) are left out.
        """
        # TODO: bad hardcoded strings
        sparse_with = With(
            items=FixedList(
                withitem(
                    context_expr=Subscript(value=name("sparse"),
                                           slice=_,
                                           ctx=_),
                    optional_vars=_,
                )),
            body=_,
            type_comment=_,
        )
        # NOTE(review): presumably `dispatch` tries the patterns in insertion
        # order, so the `sparse[...]` pattern stays ahead of the bare `With`.
        handlers = {
            OneOf(Assign, AugAssign): self.assign,
            If: self.if_stmt,
            sparse_with: self.loop_stmt,
            # assume a vertical region by default
            With: self.vertical_loop,
            Pass: lambda pass_node: None,
        }

        translated = []
        for py_stmt in py_stmts:
            if in_stencil_root_scope and isinstance(py_stmt, AnnAssign):
                self.temporary_field_declaration(py_stmt)
                continue

            sir_stmt = dispatch(handlers, py_stmt)
            if sir_stmt is not None:
                translated.append(sir_stmt)
        return translated

    @transform(
        OneOf(
            Assign(
                targets=FixedList(Capture(expr).to("lhs")),
                value=Capture(expr).to("rhs"),
                type_comment=None,
            ),
            AugAssign(
                target=Capture(expr).to("lhs"),
                op=Capture(operator).to("op"),
                value=Capture(expr).to("rhs"),
            ),
        ), )
    def assign(self, lhs: expr, rhs: expr, op: t.Optional[operator] = None):
        """Translate a plain or augmented assignment into an SIR assignment.

        ``op`` is ``None`` for ``lhs = rhs``. ``lhs **= rhs`` is emitted as
        ``lhs = gridtools::dawn::math::pow(lhs, rhs)``. The remaining supported
        augmented operators map to their SIR spelling.

        Raises:
            DuskSyntaxError: for an augmented operator that is not supported.
        """
        if op is None:
            return make_assignment_stmt(self.expression(lhs),
                                        self.expression(rhs), "=")

        if isinstance(op, Pow):
            # No `**=` on the SIR side: assign the result of a pow() call.
            power = make_fun_call_expr(
                "gridtools::dawn::math::pow",
                [self.expression(lhs),
                 self.expression(rhs)],
            )
            return make_assignment_stmt(self.expression(lhs), power, "=")

        augmented_ops = {
            Add: "+=",
            Sub: "-=",
            Mult: "*=",
            Div: "/=",
            Mod: "%=",
            LShift: "<<=",
            RShift: ">>=",
            BitOr: "|=",
            BitXor: "^=",
            BitAnd: "&=",
        }
        sir_op = augmented_ops.get(type(op))
        if sir_op is None:
            raise DuskSyntaxError(f"Unsupported assignment operator '{op}'!",
                                  op)

        return make_assignment_stmt(self.expression(lhs), self.expression(rhs),
                                    sir_op)

    @transform(
        If(
            test=Capture(expr).to("condition"),
            body=Capture(list).to("body"),
            orelse=Capture(list).to("orelse"),
        ))
    def if_stmt(self, condition: expr, body: t.List, orelse: t.List):
        """Translate an ``if`` statement.

        The condition becomes an expression statement; both branches become
        block statements.
        """
        # Arguments are evaluated left to right: condition, body, orelse.
        return make_if_stmt(
            make_expr_stmt(self.expression(condition)),
            make_block_stmt(self.statements(body)),
            make_block_stmt(self.statements(orelse)),
        )

    @transform(
        With(
            items=FixedList(
                # TODO: hardcoded strings
                withitem(
                    context_expr=OneOf(
                        name(
                            Capture(OneOf("levels_upward",
                                          "levels_downward")).to("order"), ),
                        Subscript(
                            value=name(id=Capture(
                                OneOf("levels_upward", "levels_downward")).to(
                                    "order")),
                            slice=Slice(
                                lower=Capture(_).to("lower"),
                                upper=Capture(_).to("upper"),
                                step=None,
                            ),
                            ctx=Load,
                        ),
                    ),
                    optional_vars=Optional(
                        name(Capture(str).to("var"), ctx=Store)),
                ), ),
            body=Capture(_).to("body"),
            type_comment=None,
        ), )
    def vertical_loop(self,
                      order,
                      body,
                      upper=None,
                      lower=None,
                      var: str = None):
        """Translate a ``with levels_upward/levels_downward`` block.

        A missing ``lower`` bound defaults to ``(sir.Interval.Start, 0)`` and a
        missing ``upper`` bound to ``(sir.Interval.End, 0)``; given bounds are
        resolved by ``vertical_interval_bound``. The body is translated inside
        ``self.ctx.vertical_region(var)``.
        """
        def resolve(bound, default_level):
            # (level, offset) for one end of the interval
            if bound is None:
                return default_level, 0
            return self.vertical_interval_bound(bound)

        lower_level, lower_offset = resolve(lower, sir.Interval.Start)
        upper_level, upper_offset = resolve(upper, sir.Interval.End)

        directions = {
            "levels_upward": sir.VerticalRegion.Forward,
            "levels_downward": sir.VerticalRegion.Backward,
        }
        with self.ctx.vertical_region(var):
            region_ast = make_ast(self.statements(body))
            interval = make_interval(lower_level, upper_level, lower_offset,
                                     upper_offset)
            return make_vertical_region_decl_stmt(region_ast, interval,
                                                  directions[order])

    # TODO: richer vertical interval bounds
    @transform(Capture(OneOf(Constant, UnaryOp)).to("bound"))
    def vertical_interval_bound(self, bound):
        """Resolve one bound of a vertical interval into ``(level, offset)``.

        An integer constant ``n`` yields ``(sir.Interval.Start, n)``. A negated
        integer constant ``-n`` yields ``(sir.Interval.End, -n)``.

        Raises:
            DuskSyntaxError: for any other bound expression.
        """
        int_literal = Constant(value=int, kind=None)

        if does_match(int_literal, bound):
            return sir.Interval.Start, bound.value

        if does_match(UnaryOp(op=USub, operand=int_literal), bound):
            return sir.Interval.End, -bound.operand.value

        raise DuskSyntaxError(
            f"Unrecognized vertical intervals bound '{bound}'!", bound)

    @transform(
        With(
            items=FixedList(
                # TODO: bad hardcoded string `neighbors`
                withitem(
                    context_expr=Subscript(
                        value=name(id="sparse"),
                        slice=Index(value=Capture(_).to("neighborhood")),
                        ctx=Load,
                    ),
                    optional_vars=None,
                )),
            body=Capture(_).to("body"),
            type_comment=None,
        ))
    def loop_stmt(self, neighborhood, body: t.List):
        """Translate a ``with sparse[...]`` block into an SIR loop statement.

        The body is translated while ``self.ctx.location.loop_stmt`` is active
        for the resolved location chain.
        """
        chain = self.location_chain(neighborhood)

        with self.ctx.location.loop_stmt(chain):
            sir_body = self.statements(body)

        return make_loop_stmt(sir_body, chain)

    @transform(Capture(expr).to("expr"))
    def expression(self, expr: expr):
        """Translate an expression node and wrap the result with ``make_expr``.

        The node is dispatched on its type to the matching transformer.
        """
        handlers = {
            Constant: self.constant,
            Name: self.var,
            Subscript: self.subscript,
            UnaryOp: self.unop,
            BinOp: self.binop,
            BoolOp: self.boolop,
            Compare: self.compare,
            IfExp: self.ifexp,
            Call: self.funcall,
        }
        translated = dispatch(handlers, expr)
        return make_expr(translated)

    @transform(Constant(value=Capture(_).to("value"), kind=None))
    def constant(self, value):
        """Translate a ``bool``, ``int`` or ``float`` constant into a literal access.

        Booleans are spelled ``"true"``/``"false"``; other values use ``str``.

        Raises:
            DuskSyntaxError: if ``type(value)`` is none of the three types.
        """
        # TODO: properly distinguish between float and double
        type_names = {bool: "Boolean", int: "Integer", float: "Double"}

        type_name = type_names.get(type(value))
        if type_name is None:
            raise DuskSyntaxError(
                f"Unsupported constant '{value}' of type '{type(value)}'!",
                value)

        type_id = sir.BuiltinType.TypeID.Value(type_name)

        if not isinstance(value, bool):
            # TODO: does `str` really work here? (what about NaNs, precision, 1e11 notation, etc)
            literal = str(value)
        elif value:
            literal = "true"
        else:
            literal = "false"

        return make_literal_access_expr(literal, type_id)

    @transform(Name(id=Capture(str).to("name"), ctx=AnyContext))
    def var(self, name: str, index: expr = None):
        """Translate a reference to ``name`` into a field access expression.

        Raises:
            DuskSyntaxError: if ``name`` is not in the current scope.
            DuskInternalError: if the symbol is not a ``DuskField``.
        """
        scope = self.ctx.scope.current_scope

        if not scope.contains(name):
            raise DuskSyntaxError(f"Undeclared variable '{name}'!", name)

        symbol = scope.fetch(name)
        if not isinstance(symbol, DuskField):
            raise DuskInternalError(
                f"Encountered unknown symbol type '{symbol}' ('{name}')!")

        return self.field_access_expr(symbol, index)

    @transform(
        Subscript(
            value=Capture(Name).to("var"),
            slice=Index(value=Capture(expr).to("index")),
            ctx=AnyContext,
        ))
    def subscript(self, var: expr, index: expr):
        """Translate ``var[index]`` by delegating to ``self.var``.

        ``var`` is the captured ``Name`` node; ``index`` is passed through as
        the ``index`` keyword argument.
        """
        return self.var(var, index=index)

    def field_access_expr(self, field: DuskField, index: expr = None):
        """Build the unstructured field access expression for ``field``.

        Args:
            field: the field symbol being accessed.
            index: optional index expression, resolved by ``field_index``.

        Raises:
            DuskSyntaxError: if the access occurs outside of a vertical region.
        """
        if not self.ctx.location.in_vertical_region:
            # Report the field's own name. A bare `name` is not a local here
            # and would resolve to the module-level `name()` helper instead.
            raise DuskSyntaxError(
                f"Invalid field access '{field.sir.name}' outside of a vertical region!"
            )
        return make_unstructured_field_access_expr(
            field.sir.name, *self.field_index(index, field=field))

    @transform(
        OneOf(
            Tuple(
                elts=FixedList(
                    Capture(OneOf(Compare, Name)).to("hindex"),
                    Capture(expr).to("vindex"),
                ),
                ctx=Load,
            ),
            # FIXME: ensure built-ins (like `Edge`) aren't _shadowed_ by variables
            # TODO: hardcoded string
            Capture(OneOf(Compare, name(OneOf("Edge", "Cell",
                                              "Vertex")))).to("hindex"),
            Capture(expr).to("vindex"),
            None,
        ))
    def field_index(self, field: DuskField, vindex=None, hindex=None):
        """Resolve a field's index into ``(horizontal offset, voffset, vbase)``.

        ``vindex`` and ``hindex`` are the captured vertical and horizontal
        index nodes; either may be absent. The horizontal offset is built with
        ``make_unstructured_offset``; ``voffset`` and ``vbase`` come from
        ``relative_vertical_offset`` and default to ``(0, None)``.

        Raises:
            DuskSyntaxError: if a horizontal index is given outside of a
                neighbor iteration, if a dense field has no horizontal index
                inside an ambiguous neighbor iteration, or if the horizontal
                index does not agree with the current neighbor iteration.
        """

        voffset, vbase = (self.relative_vertical_offset(vindex)
                          if vindex is not None else (0, None))
        # From here on `hindex` is a list of location types, or None.
        hindex = self.location_chain(hindex) if hindex is not None else None

        # Outside of a neighbor iteration no horizontal index is accepted.
        if not self.ctx.location.in_neighbor_iteration:
            if hindex is not None:
                raise DuskSyntaxError(
                    f"Invalid horizontal index for field '{field.sir.name}' "
                    "outside of neighbor iteration!")
            return make_unstructured_offset(False), voffset, vbase

        neighbor_iteration = self.ctx.location.current_neighbor_iteration
        field_dimension = self.ctx.location.get_field_dimension(field.sir)

        # TODO: `vindex` is _non-sensical_ if the field is 2d

        # TODO: we should check that `field_dimension` is valid for
        #       the current neighbor iteration(s?)

        # No explicit horizontal index: derive the offset from the field.
        if hindex is None:
            if self.ctx.location.is_dense(field_dimension):
                if self.ctx.location.is_ambiguous(neighbor_iteration):
                    raise DuskSyntaxError(
                        f"Field '{field.sir.name}' requires a horizontal index "
                        "inside of ambiguous neighbor iteration!")

                # The offset flag is set when the field's first dimension
                # equals the last location of the neighbor iteration.
                return (
                    make_unstructured_offset(
                        field_dimension[0] == neighbor_iteration[-1]),
                    voffset,
                    vbase,
                )
            # Non-dense field without a horizontal index: offset flag is True.
            return make_unstructured_offset(True), voffset, vbase

        # TODO: check if `hindex` is valid for this field's location type

        # A single-location index must name the first location of the
        # neighbor iteration.
        if len(hindex) == 1:
            if neighbor_iteration[0] != hindex[0]:
                raise DuskSyntaxError(
                    f"Invalid horizontal offset for field '{field.sir.name}'!")
            return make_unstructured_offset(False), voffset, vbase

        # A longer index must equal the whole neighbor iteration.
        if hindex != neighbor_iteration:
            raise DuskSyntaxError(
                f"Invalid horizontal offset for field '{field.sir.name}'!")

        return make_unstructured_offset(True), voffset, vbase

    @transform(
        OneOf(
            BinOp(
                left=name(Capture(str).to("base")),
                op=Capture(OneOf(Add, Sub)).to("vop"),
                right=Constant(value=Capture(int).to("shift"), kind=None),
            ),
            name(Capture(str).to("base")),
        ), )
    def relative_vertical_offset(self, base: str, shift: int = 0, vop=Add()):
        """Resolve a vertical index of the form ``base``, ``base + n`` or ``base - n``.

        Args:
            base: name of a vertical iteration variable or of an index field.
            shift: the integer constant of the index, ``0`` if absent.
            vop: the operator node; a ``Sub`` instance negates ``shift``.

        Returns:
            ``(offset, vertical_base)``. ``vertical_base`` is the SIR name of
            the index field, or ``None`` for a vertical iteration variable.

        Raises:
            DuskSyntaxError: if ``base`` names neither kind of symbol, or if an
                index field is used inside an ambiguous neighbor iteration.
        """
        # Keep `base` bound to the identifier so error messages can quote it;
        # the fetched symbol gets its own name.
        symbol = self.ctx.scope.current_scope.fetch(base)

        if isinstance(symbol, DuskIndexField):
            # TODO: check that `base` is valid in this context
            #       * compatible neighbor iteration
            #       * _correct_ usage (not clear what this entails)
            vertical_base = symbol.sir.name
            if (self.ctx.location.in_neighbor_iteration
                    and self.ctx.location.is_ambiguous(
                        self.ctx.location.current_neighbor_iteration)):
                raise DuskSyntaxError(
                    f"Index field {vertical_base} used in ambiguous neighbor iteration!"
                )
        elif isinstance(symbol, VerticalIterationVariable):
            vertical_base = None
        else:
            raise DuskSyntaxError(
                f"'{base}' isn't a vertical iteration variable or index field!",
                base)

        return (-shift if isinstance(vop, Sub) else shift), vertical_base

    @transform(
        UnaryOp(
            operand=Capture(expr).to("expr"),
            op=Capture(OneOf(UAdd, USub, Not)).to("op"),
        ))
    def unop(self, expr: expr, op):
        """Translate a unary ``+``, ``-`` or ``not`` into an SIR unary operator."""
        sir_spelling = {UAdd: "+", USub: "-", Not: "!"}
        operator_text = sir_spelling[type(op)]
        operand = self.expression(expr)
        return make_unary_operator(operator_text, operand)

    @transform(
        BinOp(
            left=Capture(expr).to("left"),
            op=Capture(operator).to("op"),
            right=Capture(expr).to("right"),
        ))
    def binop(self, left: expr, op: t.Any, right: expr):
        """Translate a binary operation.

        Operators in the table map to an SIR binary operator; ``**`` is emitted
        as a call to ``gridtools::dawn::math::pow``.

        Raises:
            DuskSyntaxError: for any other operator.
        """
        sir_spelling = {
            Add: "+",
            Sub: "-",
            Mult: "*",
            Div: "/",
            LShift: "<<",
            RShift: ">>",
            BitOr: "|",
            BitXor: "^",
            BitAnd: "&",
        }

        sir_op = sir_spelling.get(type(op))
        if sir_op is not None:
            return make_binary_operator(self.expression(left), sir_op,
                                        self.expression(right))

        if not isinstance(op, Pow):
            raise DuskSyntaxError(f"Unsupported binary operator '{op}'!", op)

        return make_fun_call_expr(
            "gridtools::dawn::math::pow",
            [self.expression(left),
             self.expression(right)],
        )

    @transform(
        BoolOp(
            op=Capture(OneOf(And, Or)).to("op"),
            values=Capture(Repeat(expr)).to("values"),
        ))
    def boolop(self, op, values: t.List):
        """Translate ``and``/``or`` chains into nested SIR binary operators.

        The operands are folded from the right: ``a and b and c`` becomes
        ``a && (b && c)``.
        """
        sir_spelling = {And: "&&", Or: "||"}
        sir_op = sir_spelling[type(op)]

        # `leading` is a fresh list, so popping from it leaves `values` intact.
        *leading, tail = values
        chain = self.expression(tail)

        while leading:
            chain = make_binary_operator(self.expression(leading.pop()),
                                         sir_op, chain)

        return chain

    @transform(
        Compare(
            left=Capture(expr).to("left"),
            # currently we only support two operands
            ops=FixedList(Capture(_).to("op")),
            comparators=FixedList(Capture(expr).to("right")),
        ), )
    def compare(self, left: expr, op, right: expr):
        """Translate a two-operand comparison into an SIR binary operator.

        Raises:
            DuskSyntaxError: if the comparison operator is not in the table.
        """
        # FIXME: we should probably have a better answer when we need such mappings
        sir_spelling = {
            Eq: "==",
            NotEq: "!=",
            Lt: "<",
            LtE: "<=",
            Gt: ">",
            GtE: ">=",
        }
        sir_op = sir_spelling.get(type(op))
        if sir_op is None:
            raise DuskSyntaxError(f"Unsupported comparison operator '{op}'!",
                                  op)
        return make_binary_operator(self.expression(left), sir_op,
                                    self.expression(right))

    @transform(
        IfExp(
            test=Capture(expr).to("condition"),
            body=Capture(expr).to("body"),
            orelse=Capture(expr).to("orelse"),
        ))
    def ifexp(self, condition: expr, body: expr, orelse: expr):
        """Translate ``body if condition else orelse`` into an SIR ternary operator."""
        # Arguments are evaluated left to right: condition, body, orelse.
        return make_ternary_operator(
            self.expression(condition),
            self.expression(body),
            self.expression(orelse),
        )

    @transform(
        Capture(Call(
            func=name(Capture(str).to("name")),
            args=_,
            keywords=_,
        )).to("node"))
    def funcall(self, name: str, node: Call):
        """Route a call node to the transformer responsible for ``name``.

        Raises:
            DuskSyntaxError: if ``name`` is neither a reduction nor one of the
                known math functions.
        """
        # TODO: bad hardcoded string
        if name == "reduce_over":
            handler = self.reduce_over
        elif name in {"sum_over", "min_over", "max_over"}:
            handler = self.short_reduce_over
        elif (name in self.unary_math_functions
              or name in self.binary_math_functions):
            handler = self.math_function
        else:
            raise DuskSyntaxError(f"Unrecognized function call '{name}'!",
                                  node)

        return handler(node)

    # Names of the math functions listed in the stubs; `funcall` and
    # `math_function` test call names against these sets.
    unary_math_functions = {
        function.__name__ for function in UNARY_MATH_FUNCTIONS
    }
    binary_math_functions = {
        function.__name__ for function in BINARY_MATH_FUNCTIONS
    }

    @transform(
        Call(
            func=name(Capture(str).to("name")),
            args=Capture(list).to("args"),
            keywords=EmptyList,
        ))
    def math_function(self, name: str, args: t.List):
        """Translate a call to a known math function.

        The result is a call to ``gridtools::dawn::math::<name>`` with the
        translated arguments.

        Raises:
            DuskSyntaxError: if ``name`` is not a known math function or the
                argument count does not fit.
        """
        is_unary = name in self.unary_math_functions

        if not is_unary and name not in self.binary_math_functions:
            raise DuskSyntaxError(f"Unrecognized function call '{name}'!")

        if is_unary:
            if len(args) != 1:
                raise DuskSyntaxError(
                    f"Function '{name}' takes exactly one argument!")
            operands = [self.expression(args[0])]
        else:
            if len(args) != 2:
                raise DuskSyntaxError(
                    f"Function '{name}' takes exactly two arguments!")
            operands = [self.expression(arg) for arg in args]

        return make_fun_call_expr(f"gridtools::dawn::math::{name}", operands)

    @transform(
        Call(
            # TODO: bad hardcoded string
            func=name("reduce_over"),
            args=FixedList(
                Capture(expr).to("neighborhood"),
                Capture(expr).to("expr"),
                name(Capture(str).to("op")),
            ),
            keywords=Repeat(
                keyword(
                    arg=Capture(str).append("kwargs_keys"),
                    value=Capture(expr).append("kwargs_values"),
                )),
        ), )
    def reduce_over(
        self,
        expr: expr,
        neighborhood: expr,
        op: str,
        kwargs_keys: t.Optional[t.List[str]] = None,
        kwargs_values: t.Optional[t.List[expr]] = None,
    ):
        """Translate ``reduce_over(neighborhood, expr, op, **kwargs)``.

        Args:
            expr: the expression being reduced.
            neighborhood: the location chain expression.
            op: name of the reduction operator.
            kwargs_keys: names of the keyword arguments, in call order.
            kwargs_values: value nodes of the keyword arguments, same order.

        Returns:
            Whatever ``self.reduction`` returns for these arguments.
        """
        # `None` defaults instead of `[]`: a list default would be one shared
        # object for every call that passes no keywords.
        return self.reduction(
            expr,
            neighborhood,
            op,
            [] if kwargs_keys is None else kwargs_keys,
            [] if kwargs_values is None else kwargs_values,
        )

    @transform(
        Call(
            func=name(
                # TODO: bad hardcoded string
                Capture(OneOf("sum_over", "min_over",
                              "max_over")).to("short_cut_name")),
            args=FixedList(
                Capture(expr).to("neighborhood"),
                Capture(expr).to("expr"),
            ),
            keywords=Repeat(
                keyword(
                    arg=Capture(str).append("kwargs_keys"),
                    value=Capture(expr).append("kwargs_values"),
                )),
        ), )
    def short_reduce_over(
        self,
        expr: expr,
        neighborhood: expr,
        short_cut_name: str,
        kwargs_keys: t.Optional[t.List[str]] = None,
        kwargs_values: t.Optional[t.List[expr]] = None,
    ):
        """Translate ``sum_over``/``min_over``/``max_over`` calls.

        The short-cut name selects the reduction operator (``"sum"``,
        ``"min"`` or ``"max"``); everything else is handed to
        ``self.reduction``.

        Args:
            expr: the expression being reduced.
            neighborhood: the location chain expression.
            short_cut_name: one of ``"sum_over"``, ``"min_over"``, ``"max_over"``.
            kwargs_keys: names of the keyword arguments, in call order.
            kwargs_values: value nodes of the keyword arguments, same order.
        """
        short_cut_to_op_map = {
            "sum_over": "sum",
            "min_over": "min",
            "max_over": "max"
        }
        op = short_cut_to_op_map[short_cut_name]

        # `None` defaults instead of `[]`: a list default would be one shared
        # object for every call that passes no keywords.
        return self.reduction(
            expr,
            neighborhood,
            op,
            [] if kwargs_keys is None else kwargs_keys,
            [] if kwargs_values is None else kwargs_values,
        )

    def reduction(
        self,
        expr: expr,
        neighborhood: expr,
        op: str,
        kwargs_keys: t.List[str],
        kwargs_values: t.List[expr],
    ):
        """Build the SIR reduction-over-neighbor expression.

        Args:
            expr: the expression being reduced; translated while
                ``self.ctx.location.reduction`` is active for the neighborhood.
            neighborhood: the location chain expression.
            op: one of ``"sum"``, ``"mul"``, ``"min"``, ``"max"``.
            kwargs_keys: keyword names; only ``"init"`` and ``"weights"`` are
                accepted.
            kwargs_values: the corresponding value nodes.

        Raises:
            DuskSyntaxError: for unsupported keywords, an invalid operator, or
                a ``weights`` value that has no elements to iterate.
        """
        kwargs = dict(zip(kwargs_keys, kwargs_values))
        # Internal consistency check: keys are unique and pair up with values.
        assert len(kwargs) == len(kwargs_keys) == len(kwargs_values)

        # TODO: what about these hard coded strings?
        wrong_kwargs = kwargs.keys() - {"init", "weights"}
        if wrong_kwargs:
            raise DuskSyntaxError(
                f"Unsupported kwargs '{wrong_kwargs}' in reduction!")

        neighborhood = self.location_chain(neighborhood)
        with self.ctx.location.reduction(neighborhood):
            sir_expr = self.expression(expr)

        op_map = {"sum": "+", "mul": "*", "min": "min", "max": "max"}
        if op not in op_map:
            raise DuskSyntaxError(f"Invalid operator '{op}' for reduction!")

        if "init" in kwargs:
            init = self.expression(kwargs["init"])
        else:
            # TODO: "min" and "max" are still kinda stupid
            # we should use something like this:
            # https://en.cppreference.com/w/cpp/types/numeric_limits/max
            # the current solution is:
            #   - appropriate for doubles
            #   - okish for floats (may trip floating point exceptions but correct outcome)
            #   - breaks for int (undefined behavior!)
            init_map = {
                "sum": "0",
                "mul": "1",
                "min": "1.79769313486231571e+308",
                "max": "-1.79769313486231571e+308",
            }
            init = make_literal_access_expr(
                init_map[op], sir.BuiltinType.TypeID.Value("Double"))

        sir_op = op_map[op]

        weights = None
        if "weights" in kwargs:
            # TODO: check for `kwargs["weight"].ctx == Load`?
            weights_node = kwargs["weights"]
            # Only nodes carrying `elts` (list/tuple style literals) can be
            # unpacked; report anything else as a syntax error rather than
            # letting an AttributeError escape.
            weight_nodes = getattr(weights_node, "elts", None)
            if weight_nodes is None:
                raise DuskSyntaxError(
                    "Reduction weights must be a literal sequence of expressions!",
                    weights_node)
            weights = [self.expression(weight) for weight in weight_nodes]

        return make_reduction_over_neighbor_expr(
            sir_op,
            sir_expr,
            init,
            neighborhood,
            weights,
        )
示例#3
0
    IndexField as DuskIndexField,
    VerticalIterationVariable,
    DuskContextHelper,
)
from dusk.script import stencil as stencil_decorator
from dusk.script.stubs import (
    LOCATION_TYPES,
    UNARY_MATH_FUNCTIONS,
    BINARY_MATH_FUNCTIONS,
)
from dusk.errors import DuskInternalError, DuskSyntaxError
from dusk.util import pprint_matcher as pprint

# Short cuts
# `FixedList()` built without element patterns — presumably matches only an
# empty list (confirm against the matcher implementation).
EmptyList = FixedList()
# Alternatives covering every expression-context class listed here, for
# patterns that should not constrain `ctx`.
AnyContext = OneOf(Load, Store, Del, AugLoad, AugStore, Param)


def name(id, ctx=Load) -> Name:
    """Build a ``Name`` from ``id`` and ``ctx``; ``ctx`` defaults to ``Load``."""
    fields = {"id": id, "ctx": ctx}
    return Name(**fields)


def transform(matcher) -> t.Callable:
    def decorator(transformer: t.Callable) -> t.Callable:
        def transformer_with_matcher(self, node, *args, **kwargs):
            captures = {}
            match(matcher, node, capturer=captures)
            return transformer(self, *args, **captures, **kwargs)

        return transformer_with_matcher