def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
        """Rewrite a flattened API call so it passes a single ``request`` dict.

        The method name is read from ``original.func.attr.value`` and looked
        up in ``self.METHOD_TO_PARAMS`` to get the ordered parameter names.
        Arguments named in ``self.CTRL_PARAMS`` (or positional arguments past
        the declared parameters) are kept as separate keyword arguments; every
        other argument becomes an entry of the ``request`` dict.

        Args:
            original: The call node before any child transformations.
            updated: The call node after child transformations.

        Returns:
            ``updated`` unchanged when the call is not a known method or
            already has a ``request`` keyword; otherwise a copy of ``updated``
            whose arguments are ``request={...}`` followed by the control
            keyword arguments.
        """
        try:
            key = original.func.attr.value
            kword_params = self.METHOD_TO_PARAMS[key]
        except (AttributeError, KeyError):
            # Either not a method from the API or too convoluted to be sure.
            return updated

        # If the existing code is valid, keyword args come after positional args.
        # Therefore, all positional args must map to the first parameters.
        # NOTE(review): assumes partition() returns (matching, non-matching)
        # lists, as the unpacking names imply -- confirm against its definition.
        args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
        if any(k.keyword.value == "request" for k in kwargs):
            # We've already fixed this file, don't fix it again.
            return updated

        kwargs, ctrl_kwargs = partition(
            lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs)

        # Positional arguments beyond the declared parameters are control
        # parameters passed positionally; turn them into keyword arguments.
        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
        ctrl_kwargs.extend(
            cst.Arg(value=a.value, keyword=cst.Name(value=ctrl))
            for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))

        # Positional arguments take their names from the declared order.
        # Keyword arguments already carry their own name, so use it directly:
        # zipping them against kword_params would mislabel any keyword that is
        # out of order or that skips an earlier optional parameter.
        request_items = [
            (name, arg.value) for name, arg in zip(kword_params, args)
        ]
        request_items.extend((arg.keyword.value, arg.value) for arg in kwargs)

        request_arg = cst.Arg(
            value=cst.Dict([
                # The dict value is the expression itself; DictElement expects
                # an expression node, not a cst.Element wrapper.
                cst.DictElement(cst.SimpleString("'{}'".format(name)), value)
                for name, value in request_items
            ]),
            keyword=cst.Name("request"))

        return updated.with_changes(args=[request_arg] + ctrl_kwargs)
Exemplo n.º 2
0
    def to_dictionary_of_samples(random_variables, *_):
        """Build a ``cst.Dict`` literal that maps variable names to variables.

        Each random variable contributes an entry whose key is the quoted
        ``name`` and whose value is a ``cst.Name`` of that same ``name``.
        Variables are grouped by their ``scope`` attribute. With exactly one
        distinct scope the result is a flat dict; otherwise it is a dict keyed
        by scope whose values are the per-scope dicts.

        Any extra positional arguments are accepted and ignored.
        """

        def as_dict_node(variables):
            # One ``'name': name`` entry per variable in this group.
            entries = []
            for key, rv in variables.items():
                entries.append(
                    cst.DictElement(
                        cst.SimpleString(f"'{key}'"),
                        cst.Name(rv.name),
                    )
                )
            return cst.Dict(entries)

        scopes = [rv.scope for rv in random_variables]
        names = [rv.name for rv in random_variables]

        by_scope = defaultdict(dict)
        for scope, name, rv in zip(scopes, names, random_variables):
            by_scope[scope][name] = rv

        if len(set(scopes)) != 1:
            # Several scopes (or none): nest the variables under their scope.
            nested = []
            for scope in by_scope:
                nested.append(
                    cst.DictElement(
                        cst.SimpleString(f"'{scope}'"),
                        as_dict_node(by_scope[scope]),
                    )
                )
            return cst.Dict(nested)

        # A single scope (the common case): return a flat dictionary.
        return as_dict_node(by_scope[scopes[0]])
Exemplo n.º 3
0
 def _dict_call(self, node: cst.Call) -> Union[cst.Call, cst.Dict, cst.DictComp]:
     """Convert a ``dict(...)`` call into an equivalent literal or comprehension.

     NOTE(review): this method does not check ``node.func``; presumably the
     caller has already established that the call targets ``dict`` -- confirm.

     Handled forms:
       * no arguments -> empty ``cst.Dict``
       * a single positional dict comprehension -> that comprehension
       * a single positional list comprehension / generator expression whose
         element is a two-item tuple or list -> ``cst.DictComp``
       * a single positional tuple/list literal whose items are all two-item
         tuples or lists -> ``cst.Dict``

     Anything else is returned unchanged. This includes a keyword or starred
     argument (``dict(a=[...])``, ``dict(*x)``, ``dict(**x)``), where the
     argument value is not the iterable of pairs. It also includes literals
     that contain starred elements, which cannot be split into one key and
     one value.

     Args:
         node: The call node to convert.

     Returns:
         The replacement node, or ``node`` itself when no rewrite applies.
     """

     def split_pair(expr):
         # Return (key, value) when expr is a two-item tuple/list made only
         # of plain elements; otherwise None.
         if not isinstance(expr, (cst.Tuple, cst.List)):
             return None
         if len(expr.elements) != 2:
             return None
         first, second = expr.elements
         if not (isinstance(first, cst.Element)
                 and isinstance(second, cst.Element)):
             # A starred item expands to an unknown number of values.
             return None
         return first.value, second.value

     if not node.args:
         return cst.Dict(elements=[])
     if len(node.args) != 1:
         return node
     arg = node.args[0]
     if arg.keyword is not None or arg.star:
         # Only a plain positional argument is an iterable of pairs.
         return node
     value = arg.value
     if isinstance(value, cst.DictComp):
         return value
     if isinstance(value, (cst.ListComp, cst.GeneratorExp)):
         pair = split_pair(value.elt)
         if pair is None:
             return node
         return cst.DictComp(
             key=pair[0],
             value=pair[1],
             for_in=value.for_in,
         )
     if isinstance(value, (cst.Tuple, cst.List)):
         elements = []
         for el in value.elements:
             pair = split_pair(el.value) if isinstance(el, cst.Element) else None
             if pair is None:
                 # One non-pair item makes the whole literal unconvertible.
                 return node
             elements.append(cst.DictElement(key=pair[0], value=pair[1]))
         # An empty literal falls through to here and yields an empty dict.
         return cst.Dict(elements=elements)
     return node
Exemplo n.º 4
0
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``tuple``/``list``/``set``/``dict`` calls replaceable by literals.

        Two call shapes are considered: one of the four builtins called with
        a single list or tuple literal, and ``tuple``/``list``/``dict`` called
        with no arguments. A matching call is passed to ``self.report``
        together with a replacement node. ``dict`` calls whose argument is
        not a sequence of two-item tuples/lists are left alone.
        """
        literal_arg_call = m.Call(
            func=m.Name("tuple") | m.Name("list") | m.Name("set")
            | m.Name("dict"),
            args=[m.Arg(value=m.List() | m.Tuple())],
        )
        no_arg_call = m.Call(
            func=m.Name("tuple") | m.Name("list") | m.Name("dict"),
            args=[],
        )
        if not (m.matches(node, literal_arg_call)
                or m.matches(node, no_arg_call)):
            return

        call = cst.ensure_type(node, cst.Call)
        func_name = cst.ensure_type(call.func, cst.Name).value

        if call.args:
            literal = call.args[0].value
            literal_type = cst.List if isinstance(literal, cst.List) else cst.Tuple
            elements = cst.ensure_type(literal, literal_type).elements
            message_formatter = UNNECESSARY_LITERAL
        else:
            # An empty call is an "unnecessary call": it is rewritten to the
            # empty literal (set() is never matched without arguments).
            elements = []
            message_formatter = UNNCESSARY_CALL

        if func_name == "tuple":
            new_node = cst.Tuple(elements=elements)
        elif func_name == "list":
            new_node = cst.List(elements=elements)
        elif func_name == "set":
            if len(elements) == 0:
                # There is no empty-set literal, so an empty list/tuple
                # argument is replaced by a bare set() call instead.
                bare_set_call = cst.Call(func=cst.Name("set"))
                self.report(
                    node,
                    UNNECESSARY_LITERAL.format(func=func_name),
                    replacement=node.deep_replace(node, bare_set_call),
                )
                return
            new_node = cst.Set(elements=elements)
        else:
            if len(elements) != 0:
                # Every item must itself be a two-item tuple or list.
                pairs_matcher = m.ZeroOrMore(
                    m.Element(m.Tuple(
                        elements=[m.DoNotCare(), m.DoNotCare()]))
                    | m.Element(m.List(
                        elements=[m.DoNotCare(), m.DoNotCare()])))
                sequence_of_pairs = (m.Tuple(elements=[pairs_matcher])
                                     | m.List(elements=[pairs_matcher]))
                if not m.matches(call.args[0].value, sequence_of_pairs):
                    # Unrecognized form.
                    return

            dict_elements = []
            for ele in elements:
                pair_type = (cst.Tuple
                             if isinstance(ele.value, cst.Tuple) else cst.List)
                pair = cst.ensure_type(ele.value, pair_type)
                dict_elements.append(
                    cst.DictElement(pair.elements[0].value,
                                    pair.elements[1].value))
            new_node = cst.Dict(elements=dict_elements)

        self.report(
            node,
            message_formatter.format(func=func_name),
            replacement=node.deep_replace(node, new_node),
        )
Exemplo n.º 5
0
    def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
        """Rewrite a known API method call into keyword or ``request`` form.

        The method name is read from ``original.func.attr.value`` and looked
        up in ``self.METHOD_TO_PARAMS`` to get the ordered parameter names.
        Every argument is paired with a name: positional arguments by
        declared order, keyword arguments by their own keyword, and surplus
        positional arguments by ``self.CTRL_PARAMS`` order.

        When ``self._use_keywords`` is true, the call is rebuilt with one
        ``name=value`` argument per pair. Otherwise all pairs are folded into
        a single ``request={...}`` argument.

        Args:
            original: The call node before any child transformations.
            updated: The call node after child transformations.

        Returns:
            ``updated`` unchanged when the call is not a known method or
            already has a ``request`` keyword; otherwise a copy of ``updated``
            with the rebuilt argument list.
        """
        try:
            key = original.func.attr.value
            kword_params = self.METHOD_TO_PARAMS[key]
        except (AttributeError, KeyError):
            # Either not a method from the API or too convoluted to be sure.
            return updated

        # If the existing code is valid, keyword args come after positional args.
        # Therefore, all positional args must map to the first parameters.
        # NOTE(review): assumes partition() returns (matching, non-matching)
        # lists, as the unpacking names imply -- confirm against its definition.
        args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
        if any(k.keyword.value == "request" for k in kwargs):
            # We've already fixed this file, don't fix it again.
            return updated

        kwargs, ctrl_kwargs = partition(
            lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs)

        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]

        def make_kwarg(name, value):
            # Build ``name=value`` with no whitespace around the equals sign.
            return cst.Arg(
                value=value,
                keyword=cst.Name(value=name),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace(""),
                    whitespace_after=cst.SimpleWhitespace(""),
                ),
            )

        # Positional arguments take their names from the declared order.
        # Keyword arguments already carry their own name, so use it directly:
        # zipping them against kword_params would mislabel any keyword that is
        # out of order or that skips an earlier optional parameter.
        named_values = [
            (name, arg.value) for name, arg in zip(kword_params, args)
        ]
        named_values.extend((arg.keyword.value, arg.value) for arg in kwargs)
        # Control parameters follow: keyword ones first, then positional ones.
        named_values.extend(
            (arg.keyword.value, arg.value) for arg in ctrl_kwargs)
        named_values.extend(
            (ctrl, arg.value)
            for arg, ctrl in zip(ctrl_args, self.CTRL_PARAMS))

        if self._use_keywords:
            return updated.with_changes(
                args=[make_kwarg(name, value) for name, value in named_values])

        request_dict = cst.Dict([
            # The dict value is the expression itself; DictElement expects an
            # expression node, not a cst.Element wrapper.
            cst.DictElement(cst.SimpleString('"{}"'.format(name)), value)
            for name, value in named_values
        ])
        return updated.with_changes(args=[make_kwarg("request", request_dict)])
Exemplo n.º 6
0
def make_dict(items):
    """Return a ``cst.Dict`` with one ``cst.DictElement`` per ``(key, value)`` pair."""
    dict_elements = []
    for key_node, value_node in items:
        dict_elements.append(cst.DictElement(key=key_node, value=value_node))
    return cst.Dict(elements=dict_elements)
Exemplo n.º 7
0
class DictTest(CSTNodeTest):
    """Data-driven tests for ``cst.Dict`` and its element nodes.

    Each ``data_provider`` entry is a dict of keyword arguments forwarded to
    a helper inherited from ``CSTNodeTest`` (``validate_node``,
    ``assert_invalid`` or ``assert_parses``).
    NOTE(review): those helpers are defined outside this block; presumably
    they compare the node with its generated code, parsed form and position.
    """

    # Valid nodes: each case pairs a node with the source it should produce.
    @data_provider([
        # zero-element dict
        {
            "node": cst.Dict([]),
            "code": "{}",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 0), (1, 2)),
        },
        # one-element dict, sentinel comma value
        {
            "node": cst.Dict([cst.DictElement(cst.Name("k"), cst.Name("v"))]),
            "code": "{k: v}",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 0), (1, 6)),
        },
        # one starred (``**``) element, sentinel comma value
        {
            "node": cst.Dict([cst.StarredDictElement(cst.Name("expanded"))]),
            "code": "{**expanded}",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 0), (1, 12)),
        },
        # two-element dict, sentinel comma value
        # NOTE(review): "parser" is None here -- presumably the parse
        # comparison is skipped for this case; confirm in validate_node.
        {
            "node":
            cst.Dict([
                cst.DictElement(cst.Name("k1"), cst.Name("v1")),
                cst.DictElement(cst.Name("k2"), cst.Name("v2")),
            ]),
            "code":
            "{k1: v1, k2: v2}",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 16)),
        },
        # custom whitespace between brackets
        {
            "node":
            cst.Dict(
                [cst.DictElement(cst.Name("k"), cst.Name("v"))],
                lbrace=cst.LeftCurlyBrace(
                    whitespace_after=cst.SimpleWhitespace("\t")),
                rbrace=cst.RightCurlyBrace(
                    whitespace_before=cst.SimpleWhitespace("\t\t")),
            ),
            "code":
            "{\tk: v\t\t}",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 0), (1, 9)),
        },
        # with parenthesis (the expected position excludes the parentheses)
        {
            "node":
            cst.Dict(
                [cst.DictElement(cst.Name("k"), cst.Name("v"))],
                lpar=[cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "code":
            "({k: v})",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 7)),
        },
        # starred element
        # NOTE(review): "parser" is None here as well -- see the note on the
        # two-element case above.
        {
            "node":
            cst.Dict([
                cst.StarredDictElement(cst.Name("one")),
                cst.StarredDictElement(cst.Name("two")),
            ]),
            "code":
            "{**one, **two}",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 14)),
        },
        # custom comma on DictElement
        {
            "node":
            cst.Dict([
                cst.DictElement(cst.Name("k"),
                                cst.Name("v"),
                                comma=cst.Comma())
            ]),
            "code":
            "{k: v,}",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 0), (1, 7)),
        },
        # custom comma on StarredDictElement
        {
            "node":
            cst.Dict([
                cst.StarredDictElement(cst.Name("expanded"), comma=cst.Comma())
            ]),
            "code":
            "{**expanded,}",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 0), (1, 13)),
        },
        # custom whitespace on DictElement
        {
            "node":
            cst.Dict([
                cst.DictElement(
                    cst.Name("k"),
                    cst.Name("v"),
                    whitespace_before_colon=cst.SimpleWhitespace("\t"),
                    whitespace_after_colon=cst.SimpleWhitespace("\t\t"),
                )
            ]),
            "code":
            "{k\t:\t\tv}",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 0), (1, 8)),
        },
        # custom whitespace on StarredDictElement
        {
            "node":
            cst.Dict([
                cst.DictElement(cst.Name("k"),
                                cst.Name("v"),
                                comma=cst.Comma()),
                cst.StarredDictElement(
                    cst.Name("expanded"),
                    whitespace_before_value=cst.SimpleWhitespace("  "),
                ),
            ]),
            "code":
            "{k: v,**  expanded}",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 0), (1, 19)),
        },
        # missing spaces around dict is always okay
        # (this case supplies no "expected_position")
        {
            "node":
            cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    cst.Name("b"),
                    cst.Dict([cst.DictElement(cst.Name("k"), cst.Name("v"))]),
                    ifs=[
                        cst.CompIf(
                            cst.Name("c"),
                            whitespace_before=cst.SimpleWhitespace(""),
                        )
                    ],
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
            ),
            "parser":
            parse_expression,
            "code":
            "(a for b in{k: v}if c)",
        },
    ])
    def test_valid(self, **kwargs: Any) -> None:
        # Forward one data_provider case to the inherited validator.
        self.validate_node(**kwargs)

    # Invalid nodes: "get_node" builds the node lazily and "expected_re" is
    # the pattern the resulting error is expected to match.
    @data_provider([
        # unbalanced Dict
        {
            "get_node": lambda: cst.Dict([], lpar=[cst.LeftParen()]),
            "expected_re": "left paren without right paren",
        }
    ])
    def test_invalid(self, **kwargs: Any) -> None:
        # Forward one data_provider case to the inherited invalid-node check.
        self.assert_invalid(**kwargs)

    # Version gating: a starred element inside a dict display is expected to
    # parse with python_version="3.5" and to be rejected with "3.3".
    @data_provider((
        {
            "code": "{**{}}",
            "parser": parse_expression_as(python_version="3.5"),
            "expect_success": True,
        },
        {
            "code": "{**{}}",
            "parser": parse_expression_as(python_version="3.3"),
            "expect_success": False,
        },
    ))
    def test_versions(self, **kwargs: Any) -> None:
        # Forward one data_provider case to the inherited parse check.
        self.assert_parses(**kwargs)