コード例 #1
0
ファイル: apis.py プロジェクト: vfdev-5/python-record-api
    def parameters(
        self, type: typing.Literal["function", "classmethod", "method"]
    ) -> cst.Parameters:
        """Render the recorded signature as a libcst ``Parameters`` node.

        Required parameters are emitted with just their annotation; optional
        ones additionally receive a ``...`` default. Optional positional-only
        and positional-or-keyword parameters are passed through
        ``possibly_order_dict`` with their recorded ordering, while optional
        keyword-only parameters are emitted in their dict order.

        :param type: ``"classmethod"`` prepends a bare ``cls`` parameter and
            ``"method"`` prepends a bare ``self`` parameter, in both cases as the
            first positional-only parameter; ``"function"`` prepends nothing.
        :return: the assembled ``cst.Parameters`` node.
        """

        def to_param(name, info, optional=False):
            # ``info`` must expose an ``annotation`` expression node.
            name_node = cst.Name(name)
            annotation = cst.Annotation(info.annotation)
            if optional:
                return cst.Param(name_node, annotation, default=cst.Ellipsis())
            return cst.Param(name_node, annotation)

        def to_params(required, optional):
            # Required parameters always precede the optional (defaulted) ones.
            collected = [to_param(k, v) for k, v in required.items()]
            collected.extend(
                to_param(k, v, optional=True) for k, v in optional.items()
            )
            return collected

        posonly = to_params(
            self.pos_only_required,
            possibly_order_dict(self.pos_only_optional,
                                self.pos_only_optional_ordering),
        )
        if type in ("classmethod", "method"):
            receiver = "cls" if type == "classmethod" else "self"
            posonly.insert(0, cst.Param(cst.Name(receiver)))

        params = to_params(
            self.pos_or_kw_required,
            possibly_order_dict(self.pos_or_kw_optional,
                                self.pos_or_kw_optional_ordering),
        )
        if self.var_pos:
            star_arg = to_param(self.var_pos[0], self.var_pos[1])
        else:
            star_arg = cst.MaybeSentinel.DEFAULT
        star_kwarg = (
            to_param(self.var_kw[0], self.var_kw[1]) if self.var_kw else None
        )
        kwonly = to_params(self.kw_only_required, self.kw_only_optional)

        return cst.Parameters(
            posonly_params=posonly,
            params=params,
            star_arg=star_arg,
            star_kwarg=star_kwarg,
            kwonly_params=kwonly,
        )
コード例 #2
0
    def test_parameters(self) -> None:
        """``parse_template_statement`` fills a parameter slot from several node kinds.

        Each case substitutes a different replacement into a ``def foo(...)``
        template and checks the rendered source.
        """
        cases = (
            # A parameter can be inserted into a function def normally.
            (
                "def foo({arg}): pass",
                dict(arg=cst.Name("bar")),
                "def foo(bar): pass\n",
            ),
            # A Param node is inserted as a special case.
            (
                "def foo({arg}): pass",
                dict(arg=cst.Param(cst.Name("bar"))),
                "def foo(bar): pass\n",
            ),
            # A whole Parameters list is inserted as a special case.
            (
                "def foo({args}): pass",
                dict(args=cst.Parameters((cst.Param(cst.Name("bar")),))),
                "def foo(bar): pass\n",
            ),
            # Several parameters plus a ``**`` parameter.
            (
                "def foo({args}): pass",
                dict(
                    args=cst.Parameters(
                        params=(
                            cst.Param(cst.Name("bar")),
                            cst.Param(cst.Name("baz")),
                        ),
                        star_kwarg=cst.Param(cst.Name("rest")),
                    )
                ),
                "def foo(bar, baz, **rest): pass\n",
            ),
        )
        for template, replacements, expected in cases:
            statement = parse_template_statement(template, **replacements)
            self.assertEqual(self.code(statement), expected)
コード例 #3
0
    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Check that a ``@classmethod`` names its first positional parameter ``CLS``.

        - Not a classmethod: nothing is reported.
        - No positional parameters at all: report, with an autofix that adds a
          ``CLS`` parameter.
        - First positional parameter already named ``CLS``: nothing is reported.
        - Otherwise: report. An autofix that renames the parameter (and every
          Name-based assignment/reference of it in the function scope) to
          ``CLS`` is attached, unless scope metadata is unavailable or the
          scope already has an assignment to ``CLS``.
        """
        if not any(
            QualifiedNameProvider.has_name(
                self,
                decorator.decorator,
                QualifiedName(name="builtins.classmethod", source=QualifiedNameSource.BUILTIN),
            )
            for decorator in node.decorators
        ):
            return  # If it's not a @classmethod, we are not interested.

        # Positional-only parameters (``def f(klass, /, x)``) come before
        # ``params`` and are stored separately in ``posonly_params``. Looking at
        # ``params`` alone would rename the wrong parameter (``x`` above) or,
        # when every positional parameter is positional-only, append an extra
        # ``cls`` instead of renaming the real first parameter.
        positional_params = (*node.params.posonly_params, *node.params.params)
        if not positional_params:
            # No positional params, but there must be the 'cls' param.
            # Note that pyre[47] already catches this, but we also generate
            # an autofix, so it still makes sense for us to report it here.
            new_params = node.params.with_changes(params=(cst.Param(name=cst.Name(value=CLS)),))
            repl = node.with_changes(params=new_params)
            self.report(node, replacement=repl)
            return

        p0_name = positional_params[0].name
        if p0_name.value == CLS:
            return  # All good.

        # Rename all assignments and references of the first param within the
        # function scope, as long as they are done via a Name node.
        # We rely on the parser to correctly derive all
        # assignments and references within the FunctionScope.
        # The Param node's scope is our classmethod's FunctionScope.
        scope = self.get_metadata(ScopeProvider, p0_name, None)
        if not scope:
            # Cannot autofix without scope metadata. Only report in this case.
            # Not sure how to repro+cover this in a unit test...
            # If metadata creation fails, then the whole lint fails, and if it succeeds,
            # then there is valid metadata. But many other lint rule implementations contain
            # a defensive scope None check like this one, so I assume it is necessary.
            self.report(node)
            return

        if scope[CLS]:
            # The scope already has another assignment to "cls".
            # Trying to rename the first param to "cls" as well may produce broken code.
            # We should therefore refrain from suggesting an autofix in this case.
            self.report(node)
            return

        refs: List[Union[cst.Name, cst.Attribute]] = []
        assignments = scope[p0_name.value]
        for a in assignments:
            if isinstance(a, Assignment):
                assign_node = a.node
                if isinstance(assign_node, cst.Name):
                    refs.append(assign_node)
                elif isinstance(assign_node, cst.Param):
                    refs.append(assign_node.name)
                # There are other types of possible assignment nodes: ClassDef,
                # FunctionDef, Import, etc. We deliberately do not handle those here.
            refs += [r.node for r in a.references]

        repl = node.visit(_RenameTransformer(refs, CLS))
        self.report(node, replacement=repl)
コード例 #4
0
ファイル: target_functions.py プロジェクト: tblazina/mcx
def sample_predictive(model):
    """Build a sampler for the model's predictive distribution.

    Works on a deep copy of ``model.graph`` so the model itself is untouched.
    A ``rng_key`` placeholder node is added. Every random-variable node's
    ``cst_generator`` is rewritten to take that key: a ``SampleModelOp`` becomes
    a call ``<model_name>(rng_key, ...)``, and any other random variable becomes
    ``<distribution>.sample(rng_key)``. The rewritten graph is then passed to
    ``compile_graph`` under the name ``f"{graph.name}_sample"``.
    """
    graph = copy.deepcopy(model.graph)

    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")

    # Turns `a <~ Normal(0, 1)` into `a = Normal(0, 1).sample(rng_key)`.
    def _sample_from_distribution(cst_generator, *args, **kwargs):
        key = kwargs.pop("rng_key")
        distribution = cst_generator(*args, **kwargs)
        return cst.Call(
            func=cst.Attribute(distribution, cst.Name("sample")),
            args=[cst.Arg(value=key)],
        )

    # Calls the sub-model's sampler with the key as its first argument.
    def _sample_from_submodel(model_name, *args, **kwargs):
        key = kwargs.pop("rng_key")
        call_args = [cst.Arg(value=key)]
        call_args.extend(args)
        return cst.Call(func=cst.Name(value=model_name), args=call_args)

    random_variables = list(graph.random_variables)[::-1]
    for rv in random_variables:
        if isinstance(rv, SampleModelOp):
            rv.cst_generator = partial(_sample_from_submodel, rv.model_name)
        else:
            rv.cst_generator = partial(_sample_from_distribution, rv.cst_generator)

    # Feed the `rng_key` placeholder into every sampling expression.
    graph.add(rng_node)
    for rv in random_variables:
        graph.add_edge(rng_node, rv, type="kwargs", key=["rng_key"])

    return compile_graph(graph, model.namespace, f"{graph.name}_sample")
コード例 #5
0
 def test_from_function_data(self) -> None:
     """``from_function_data`` maps annotation counts to an annotation kind."""
     three_parameters = [
         cst.Param(name=cst.Name(param_name), annotation=None)
         for param_name in ("x1", "x2", "x3")
     ]
     annotated_self = [
         cst.Param(
             name=cst.Name("self"),
             annotation=cst.Annotation(cst.Name("Foo")),
         )
     ]
     # Each row: (is_return_annotated, annotated_parameter_count,
     #            is_method_or_classmethod, parameters, expected kind)
     cases = [
         (True, 3, False, three_parameters, FunctionAnnotationKind.FULLY_ANNOTATED),
         (True, 0, False, three_parameters, FunctionAnnotationKind.PARTIALLY_ANNOTATED),
         (False, 0, False, three_parameters, FunctionAnnotationKind.NOT_ANNOTATED),
         (False, 1, False, three_parameters, FunctionAnnotationKind.PARTIALLY_ANNOTATED),
         # An untyped `self` parameter of a method does not count for partial
         # annotation. As per PEP 484, we need an explicitly annotated parameter.
         (False, 1, True, three_parameters, FunctionAnnotationKind.NOT_ANNOTATED),
         (False, 2, True, three_parameters, FunctionAnnotationKind.PARTIALLY_ANNOTATED),
         # An annotated `self` suffices to make Pyre typecheck the method.
         (False, 1, True, annotated_self, FunctionAnnotationKind.PARTIALLY_ANNOTATED),
         (False, 0, True, [], FunctionAnnotationKind.NOT_ANNOTATED),
     ]
     for returns_annotated, annotated_count, is_method, params, expected in cases:
         self.assertEqual(
             FunctionAnnotationKind.from_function_data(
                 is_return_annotated=returns_annotated,
                 annotated_parameter_count=annotated_count,
                 is_method_or_classmethod=is_method,
                 parameters=params,
             ),
             expected,
         )
コード例 #6
0
 def argument_cst(name, default=None):
     """Return a libcst ``Param`` named ``name``, with ``default`` as its default (or none)."""
     param_name = cst.Name(name)
     return cst.Param(param_name, default=default)
コード例 #7
0
class LambdaParserTest(CSTNodeTest):
    """Parser tests for ``lambda`` expressions.

    Each data-provider case pairs an explicitly built ``cst.Lambda`` node with
    its source text; ``test_valid`` hands both, together with
    ``parse_expression``, to ``CSTNodeTest.validate_node`` (defined elsewhere).
    The nodes spell out ``star``, ``comma``, ``equal`` and whitespace fields
    explicitly -- presumably so they match the tree ``parse_expression`` builds
    for the code; confirm against ``validate_node``.
    """
    @data_provider((
        # Simple lambda with no parameters
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(
                    cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(cst.Name("baz"), star=""),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(
                        cst.Name("bar"),
                        default=cst.SimpleString('"one"'),
                        equal=cst.AssignEqual(),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ), ),
                    default_params=(cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ), ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params (a bare ParamStar renders as `*,`)
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(cst.Name("baz"), star=""),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(
                            cst.Name("first"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params"), star="*")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params: 5",
        ),
        # Named star_arg followed by kwonly_params (lambda params are never annotated)
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, star_arg and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_kwarg on its own
        (
            cst.Lambda(
                cst.Parameters(
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg and star_kwarg together
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**"),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Whitespace inside the enclosing parentheses and around the colon
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                params=cst.Parameters(),
                colon=cst.Colon(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda  : 5 )",
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Validate ``node`` against ``code`` using ``parse_expression``.

        ``position`` is forwarded to ``validate_node`` as well; none of the
        data-provider cases above supply one, so it stays ``None``.
        """
        self.validate_node(node, code, parse_expression, position)
コード例 #8
0
class LambdaCreationTest(CSTNodeTest):
    """Tests for ``cst.Lambda`` nodes built directly from constructors.

    ``test_valid`` pairs a hand-built node with the source code it is expected
    to correspond to (and, optionally, its expected ``CodeRange``).
    ``test_invalid`` pairs a zero-argument node factory with a regex for the
    expected validation error message.
    """

    @data_provider((
        # Simple lambda
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar")),
                                       cst.Param(cst.Name("baz")))),
                cst.Integer("5"),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz"), default=cst.Integer("5")),
                )),
                cst.Integer("5"),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(cst.Name("bar")), ),
                    default_params=(cst.Param(cst.Name("baz"),
                                              default=cst.Integer("5")), ),
                ),
                cst.Integer("5"),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params (a bare `*` separator is expected in the output)
        (
            cst.Lambda(
                cst.Parameters(kwonly_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz")),
                )),
                cst.Integer("5"),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(cst.Name("first"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("second"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params; the expected
        # position spans the whole 84-character expression.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
            CodeRange((1, 0), (1, 84)),
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("params"))),
                cst.Integer("5"),
            ),
            "lambda *params: 5",
        ),
        # star_arg followed by kwonly_params (no bare `*` separator needed)
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, star_arg and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_kwarg alone
        (
            cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("kwparams"))),
                cst.Integer("5"),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg and star_kwarg together
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    star_kwarg=cst.Param(cst.Name("kwparams")),
                ),
                cst.Integer("5"),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Inner whitespace; the expected position covers `lambda  : 5`,
        # excluding the parentheses and the padding inside them.
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                whitespace_after_lambda=cst.SimpleWhitespace("  "),
                params=cst.Parameters(),
                colon=cst.Colon(whitespace_after=cst.SimpleWhitespace(" ")),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda  : 5 )",
            CodeRange((1, 2), (1, 13)),
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Delegate to ``validate_node`` with the node, its expected source and
        the optional expected position (no parser is passed here)."""
        self.validate_node(node, code, expected_position=position)

    @data_provider((
        # Unbalanced parentheses.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        # Empty whitespace after `lambda` is rejected for every kind of
        # parameter that can directly follow the keyword.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"), default=cst.Integer("5")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        # Per-Param validation: AssignEqual without a default, unknown star.
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    equal=cst.AssignEqual())),
                cst.Integer("5"),
            ),
            "Must have a default when specifying an AssignEqual.",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    star="***")),
                cst.Integer("5"),
            ),
            r"Must specify either '', '\*' or '\*\*' for star.",
        ),
        # `params` must not have defaults; `default_params` must have them.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("bar"), default=cst.SimpleString('"one"')), )),
                cst.Integer("5"),
            ),
            "Cannot have defaults for params",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(cst.Name("bar")), )),
                cst.Integer("5"),
            ),
            "Must have defaults for default_params",
        ),
        # A bare ParamStar must be followed by at least one kwonly param.
        (
            lambda: cst.Lambda(cst.Parameters(star_arg=cst.ParamStar()),
                               cst.Integer("5")),
            "Must have at least one kwonly param if ParamStar is used.",
        ),
        # Each parameter slot only accepts its own star prefix.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar"), star="*"), )
                               ),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("bar"),
                    default=cst.SimpleString('"one"'),
                    star="*",
                ), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("bar"),
                                                        star="*"), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("bar"), star="**")),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*'",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="*")
                               ),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*\*'",
        ),
        # No parameter slot may carry an annotation in a lambda. These cases
        # also use empty whitespace_after_lambda, yet the expected error is
        # still the annotation one.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"),
                    default=cst.Integer("5"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"),
                                                  annotation=cst.Annotation(
                                                      cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"),
                                                    annotation=cst.Annotation(
                                                        cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Delegate to ``assert_invalid`` with a node factory and the regex the
        error message should match. Nodes are built lazily via the factory,
        presumably so construction-time validation errors are raised inside
        ``assert_invalid``."""
        self.assert_invalid(get_node, expected_re)
コード例 #9
0
    'async': lambda _: {cst.FunctionDef}
}


# Nesting operator -> function returning the children to descend into from a
# node: '.' walks a class body, '(' walks a function's parameters, ',' is
# delegated to getNextParam, and None falls back to the node's generic children.
# Nodes of the wrong type yield an empty tuple.
nesting_op_children_getter = {
    '.': lambda node: () if not isinstance(node, cst.ClassDef) else node.body.body,
    '(': lambda node: () if not isinstance(node, cst.FunctionDef) else node.params.params,
    ',': getNextParam,
    None: lambda node: node.children,
}


# Default constructor kwargs keyed by a relative node-type path. For a
# FunctionDef under a ClassDef (presumably a method; exact matching rules are
# RelativePathDict's), the generated parameters default to a single `self`.
per_path_default_kwargs = RelativePathDict({
    (cst.ClassDef, cst.FunctionDef): lambda path: dict(
        # TODO: only add `self` when the node is not decorated as
        # staticmethod or classmethod!
        params=cst.Parameters(params=[cst.Param(cst.Name('self'))]),
    ),
})


# NOTE: it is bad design to allow users to craft coincidentally unambiguous
# scope expressions; instead, it would be better to have plenty of unambiguous
# property keys (i.e. func, class, async, static, etc.) for the user to craft with
def possibleElemTypes(scope_expr: ScopeExpr) -> Set[cst.CSTNode]:
    """Return the node types compatible with every property of ``scope_expr``.

    Starting from ``node_types``, the result is intersected (``&``) with
    ``possible_node_classes_per_prop[key](value)`` for each property, in
    property order. With no properties, ``node_types`` itself is returned.
    The elements are presumably node *classes* rather than instances.
    """
    candidates = node_types
    for prop_key, prop_value in scope_expr.properties.items():
        # and_ builds a new object; node_types is never mutated in place.
        candidates = and_(candidates,
                          possible_node_classes_per_prop[prop_key](prop_value))
    return candidates

コード例 #10
0
ファイル: target_functions.py プロジェクト: tblazina/mcx
def sample_posterior_predictive(model, node_names):
    """Sample from the posterior predictive distribution.

    Example
    -------

    We transform MCX models of the form:

        >>> def linear_regression(X, lmbda=1.):
        ...    scale <~ Exponential(lmbda)
        ...    coef <~ Normal(jnp.zeros(X.shape[0]), 1)
        ...    y = jnp.dot(X, coef)
        ...    pred <~ Normal(y, scale)
        ...    return pred

    into:

        >>> def linear_regression_pred(rng_key, scale, coef, X, lmbda=1.):
        ...    idx = mcx.jax.choice(rng_key, scale.shape[0])
        ...    scale_sample = scale[idx]
        ...    coef_sample = coef[idx]
        ...    y = jnp.dot(X, coef_sample)
        ...    pred = Normal(y, scale_sample).sample(rng_key)
        ...    return pred

    Parameters
    ----------
    model
        The MCX model. Its ``graph`` is deep-copied, so the model itself is not
        modified; its ``namespace`` is passed on to ``compile_graph``.
    node_names
        Names of the random variables whose posterior samples become function
        arguments. The leading dimension (``shape[0]``) of the first one is
        used as the number of samples to choose from.

    Returns
    -------
    The result of ``compile_graph`` for a function named
    ``<graph name>_sample_posterior_predictive``.
    """
    graph = copy.deepcopy(model.graph)
    nodes = [graph.find_node(name) for name in node_names]

    # We will need to pass a RNG key to the function to sample from
    # the distributions; we create a placeholder for this key.
    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")
    graph.add_node(rng_node)

    # To take a sample from the posterior distribution we first choose a sample
    # id at random, `idx = mcx.jax.choice(rng_key, num_samples)`, where
    # `num_samples` is the leading dimension of the first targeted variable's
    # samples. We later index each array of samples passed with this `idx`.
    def choice_ast(rng_key):
        return cst.Call(
            func=cst.Attribute(
                value=cst.Attribute(cst.Name("mcx"), cst.Name("jax")),
                attr=cst.Name("choice"),
            ),
            args=[
                cst.Arg(rng_key),
                cst.Arg(
                    cst.Subscript(
                        cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape")),
                        [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
                    )
                ),
            ],
        )

    choice_node = Op(choice_ast, graph.name, "idx")
    graph.add(choice_node, rng_node)

    # Remove all edges incoming to the nodes that are targeted by the
    # intervention. The edges are copied into a list first so the graph is
    # not mutated while its edge view is being iterated.
    for edge in list(graph.in_edges(nodes)):
        graph.remove_edge(*edge)

    def sample_index(placeholder, idx):
        # `<samples>[idx]`: pick the chosen posterior sample.
        return cst.Subscript(placeholder, [cst.SubscriptElement(cst.Index(idx))])

    # Each SampleOp that is intervened on is replaced by a placeholder (a new
    # function argument) that is indexed by the id of the sample being taken.
    for node in reversed(nodes):
        rv_name = node.name

        placeholder = Placeholder(
            partial(lambda name: cst.Param(cst.Name(name)), rv_name),
            rv_name,
            is_random_variable=True,
        )
        graph.add_node(placeholder)

        chosen_sample = Op(sample_index, graph.name, rv_name + "_sample")
        graph.add(chosen_sample, placeholder, choice_node)

        # Reroute the node's outgoing edges (keeping their data) so that its
        # consumers read the indexed sample instead, then drop the node.
        original_edges = list(graph.out_edges(node))
        for edge in original_edges:
            graph.add_edge(chosen_sample, edge[1], **graph.get_edge_data(*edge))
        for edge in original_edges:
            graph.remove_edge(*edge)

        graph.remove_node(node)

    # Recursively remove every node that has no outgoing edge and is not
    # returned.
    graph = remove_dangling_nodes(graph)

    # Replace SampleOps by sampling instructions:
    # `x <~ Dist(...)` becomes `x = Dist(...).sample(rng_key)`.
    def to_sampler(cst_generator, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        return cst.Call(
            func=cst.Attribute(
                value=cst_generator(*args, **kwargs), attr=cst.Name("sample")
            ),
            args=[cst.Arg(value=rng_key)],
        )

    random_variables = [
        node for node in reversed(list(graph.nodes())) if isinstance(node, SampleOp)
    ]
    for node in random_variables:
        node.cst_generator = partial(to_sampler, node.cst_generator)

    # Feed the `rng_key` placeholder to every sampling expression as the
    # `rng_key` keyword argument popped by `to_sampler`.
    for var in random_variables:
        graph.add_edge(rng_node, var, type="kwargs", key=["rng_key"])

    return compile_graph(
        graph, model.namespace, f"{graph.name}_sample_posterior_predictive"
    )
コード例 #11
0
ファイル: target_functions.py プロジェクト: tblazina/mcx
def sample_joint(model):
    """Build a function drawing forward samples from the model's joint distribution.

    The model graph is deep-copied, then:

    * nothing from the original model is returned any more;
    * every ``x <~ Dist(...)`` becomes ``x = Dist(...).sample(rng_key)``, and
      every sub-model draw becomes ``<model_name>.sample(rng_key, *args)``;
    * consumers of a sub-model draw read ``x['<first returned variable>']``;
    * a single returned node, ``forward_samples``, collects all random
      variables into a dictionary keyed by variable name (nested by scope
      when the variables span more than one scope).

    The graph is compiled as ``<graph name>_sample_forward``.
    """
    graph = copy.deepcopy(model.graph)
    namespace = model.namespace

    def to_dictionary_of_samples(random_variables, *_):
        by_scope = defaultdict(dict)
        for rv in random_variables:
            by_scope[rv.scope][rv.name] = rv

        def variables_dict(variables):
            # `{'<name>': <name>, ...}` for the variables of a single scope.
            return cst.Dict(
                [
                    cst.DictElement(
                        cst.SimpleString(f"'{var_name}'"),
                        cst.Name(var.name),
                    )
                    for var_name, var in variables.items()
                ]
            )

        # Almost every model has a single scope: return a flat dictionary.
        if len(by_scope) == 1:
            (only_scope_variables,) = by_scope.values()
            return variables_dict(only_scope_variables)

        # Otherwise nest: the first level is the scope, then the variables.
        return cst.Dict(
            [
                cst.DictElement(
                    cst.SimpleString(f"'{scope}'"),
                    variables_dict(variables),
                )
                for scope, variables in by_scope.items()
            ]
        )

    # Nothing from the original model is returned any more.
    for op in (n for n in graph.nodes() if isinstance(n, Op)):
        op.is_returned = False

    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")

    # `a <~ Normal(0, 1)` becomes `a = Normal(0, 1).sample(rng_key)`.
    def distribution_to_sampler(cst_generator, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        distribution = cst_generator(*args, **kwargs)
        return cst.Call(
            func=cst.Attribute(distribution, cst.Name("sample")),
            args=[cst.Arg(value=rng_key)],
        )

    # A sub-model draw becomes `<model_name>.sample(rng_key, *args)`.
    def model_to_sampler(model_name, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        return cst.Call(
            func=cst.Attribute(cst.Name(value=model_name), cst.Name("sample")),
            args=[cst.Arg(value=rng_key), *args],
        )

    random_variables = list(graph.random_variables)[::-1]
    for rv in random_variables:
        rv.cst_generator = (
            partial(model_to_sampler, rv.model_name)
            if isinstance(rv, SampleModelOp)
            else partial(distribution_to_sampler, rv.cst_generator)
        )

    # Feed the `rng_key` placeholder to every sampling expression as the
    # `rng_key` keyword argument.
    graph.add(rng_node)
    for rv in random_variables:
        graph.add_edge(rng_node, rv, type="kwargs", key=["rng_key"])

    def sample_index(rv, returned_var, *_):
        # `<rv>['<returned_var>']`: pick the sub-model's returned variable.
        return cst.Subscript(
            cst.Name(rv),
            [cst.SubscriptElement(cst.SimpleString(f"'{returned_var}'"))],
        )

    # Consumers of a sub-model draw read its first returned variable instead.
    submodels = (rv for rv in graph.random_variables if isinstance(rv, SampleModelOp))
    for submodel in submodels:
        rv_name = submodel.name
        returned_var_name = submodel.graph.returned_variables[0].name

        chosen_sample = Op(
            partial(sample_index, rv_name, returned_var_name),
            graph.name,
            rv_name + "_value",
        )

        # Detach the outgoing edges (remembering their data), hang the
        # subscript node off the sub-model, and re-attach the edges to it.
        outgoing = [
            (edge, graph.get_edge_data(*edge)) for edge in graph.out_edges(submodel)
        ]
        for edge, _ in outgoing:
            graph.remove_edge(*edge)

        graph.add(chosen_sample, submodel)
        for edge, edge_data in outgoing:
            graph.add_edge(chosen_sample, edge[1], **edge_data)

    samples_node = Op(
        partial(to_dictionary_of_samples, graph.random_variables),
        graph.name,
        "forward_samples",
        is_returned=True,
    )
    graph.add(samples_node, *graph.random_variables)

    return compile_graph(graph, namespace, f"{graph.name}_sample_forward")
コード例 #12
0
ファイル: target_functions.py プロジェクト: tblazina/mcx
 def placeholder_to_param(name: str):
     """Return a ``cst.Param`` node declaring a parameter called ``name``."""
     return cst.Param(name=cst.Name(value=name))