コード例 #1
0
def annotate_funcs_in_stmt(stmt: ast.stmt, decorator: ast.expr, copy: bool = False) -> ast.stmt:
    """Prepend ``decorator`` to the function definitions found in ``stmt``.

    A ``FunctionDef`` is decorated unless one of its direct child statements
    mentions ``savefig`` or ``self`` in its unparsed source; in both cases the
    search continues into the function body.  Other statements are traversed
    through their ``body``/``orelse``/``finalbody`` lists when their type is
    in the module-level ``stmts_with_body``/``stmts_with_orelse``/
    ``stmts_with_finalbody`` groups.

    Args:
        stmt: Statement to process; mutated in place unless ``copy`` is true.
        decorator: Expression inserted at the front of each decorated
            function's ``decorator_list``.  The same node object is shared by
            every decorated function (it is not copied).
        copy: If true, deep-copy ``stmt`` first so the caller's tree is left
            untouched.

    Returns:
        The (possibly copied) statement with decorators applied.
    """
    if copy:
        stmt = deepcopy(stmt)

    def _recurse(children: list) -> list:
        # Children are handled in place: the top-level call already copied
        # the whole tree if a copy was requested.
        return [annotate_funcs_in_stmt(child, decorator, copy=False) for child in children]

    if isinstance(stmt, ast.FunctionDef):
        # Unparse each child only once (and lazily, so ``any`` can stop early).
        skip = any(
            "savefig" in child_source or "self" in child_source
            for child_source in (ast.unparse(child) for child in stmt.body)
        )
        if not skip:
            stmt.decorator_list = [decorator] + stmt.decorator_list
            stmt.body = _recurse(stmt.body)
            return stmt

    if isinstance(stmt, stmts_with_body):
        stmt.body = _recurse(stmt.body)
    if isinstance(stmt, stmts_with_orelse):
        stmt.orelse = _recurse(stmt.orelse)
    if isinstance(stmt, stmts_with_finalbody):
        stmt.finalbody = _recurse(stmt.finalbody)
    # NOTE: ``except`` handler bodies (``ast.Try.handlers``) are deliberately
    # not traversed; handlers are ``ast.excepthandler`` nodes, not statements.
    return cast(ast.stmt, stmt)
コード例 #2
0
 def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
     """Rewrite a logic-expression subscript into a ``state`` call or constant.

     * ``NotchCost[n, ...]`` -> ``state._hk_notches(player, n - 1, ...)``
     * ``StartLocation[name]`` -> ``False`` if the normalised name is in
       ``removed_starts``, else ``state._hk_start(player, name)``
     * ``COMBAT[...]`` and any other subscript found in ``macros`` -> the
       macro's body
     * names in ``self.additional_truths``/``self.additional_falses`` ->
       ``True``/``False``
     * anything else is taken to be an entrance and becomes
       ``state.can_reach(<name>, "Entrance", player)``.
     """
     # A subscript on something other than a bare name (e.g. ``a.b[c]``)
     # has no ``.id``; route it through the generic name lookup below.
     head = node.value.id if isinstance(node.value, ast.Name) else None
     if head == "NotchCost":
         # A single notch parses as a bare constant rather than a tuple.
         elts = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
         notches = [ast.Constant(value=notch.value - 1) for notch in elts]  # apparently 1-indexed
         return ast.Call(
             func=ast.Attribute(value=ast.Name(id='state', ctx=ast.Load()), attr='_hk_notches', ctx=ast.Load()),
             args=[ast.Name(id="player", ctx=ast.Load())] + notches, keywords=[])
     if head == "StartLocation":
         # Normalise the start name in place: spaces -> underscores, lower-case.
         node.slice.value = node.slice.value.replace(" ", "_").lower()
         if node.slice.value in removed_starts:
             # ast.Constant has no ``ctx`` field; passing one is deprecated.
             return ast.Constant(False)
         return ast.Call(
             func=ast.Attribute(value=ast.Name(id='state', ctx=ast.Load()), attr='_hk_start', ctx=ast.Load()),
             args=[ast.Name(id="player", ctx=ast.Load()), node.slice], keywords=[])

     name = unparse(node)
     if head == "COMBAT":
         return macros[name].body
     if name in self.additional_truths:
         return ast.Constant(True)
     if name in self.additional_falses:
         return ast.Constant(False)
     if name in macros:
         # macro such as "COMBAT[White_Palace_Arenas]"
         return macros[name].body
     # assume Entrance
     assert name in connectors, name
     return ast.Call(
         func=ast.Attribute(value=ast.Name(id='state', ctx=ast.Load()), attr='can_reach', ctx=ast.Load()),
         args=[ast.Constant(value=name),
               ast.Constant(value="Entrance"),
               ast.Name(id="player", ctx=ast.Load())],
         keywords=[])
コード例 #3
0
def main(argv: typing.Optional[typing.Sequence[str]] = None):
    """Compare the functions of two versions of a Python file.

    For every function pair yielded by ``extract_function_pairs`` a separator
    line and the function name are printed; when the unparsed sources
    differ, both versions are printed under "Before:"/"After:" headings.

    Args:
        argv: Command-line arguments without the program name; ``None``
            lets ``argparse`` read ``sys.argv[1:]``.
    """
    import argparse
    import pathlib

    module_stem = pathlib.Path(__file__).stem
    cli = argparse.ArgumentParser(prog="python -m ast_codez_tools." + module_stem)
    cli.add_argument("before_file", help="Path to Python file before changes")
    cli.add_argument("after_file", help="Path to Python file after changes")
    options = cli.parse_args(argv)

    old_path: str = options.before_file
    new_path: str = options.after_file
    pairs = extract_function_pairs(
        before_code=pathlib.Path(old_path).read_text(encoding="utf8"),
        before_name=old_path,
        after_code=pathlib.Path(new_path).read_text(encoding="utf8"),
        after_name=new_path,
    )

    rule = "-" * 80
    for func_name, old_node, new_node in pairs:
        old_src, new_src = ast.unparse(old_node), ast.unparse(new_node)
        print(rule)
        print(f"Function name: {func_name}()")
        if old_src == new_src:
            continue
        print("\nHas differences!\n")
        for heading, text in (("Before: ", old_src), ("After: ", new_src)):
            print(f"{heading:-<80}")
            print(text)
コード例 #4
0
    def globals(self, root: str, node: _G) -> None:
        """Record a simple module-level assignment.

        Only ``name: T = value`` and single-target ``name = value`` statements
        are handled; anything else is ignored.  For a handled statement:

        + the unparsed value is stored in ``self.alias`` under ``_m(root, name)``;
        + an UPPER_CASE name is recorded in ``self.root`` and, unless a
          non-``ANY`` type is already known, its type goes into
          ``self.const`` (the resolved annotation, the type comment, or
          ``const_type(value)``);
        + a literal ``__all__`` tuple/list adds ``_m(root, entry)`` to
          ``self.imp[root]`` for each string entry.
        """
        if isinstance(node, AnnAssign):
            if not (isinstance(node.target, Name) and node.value is not None):
                return
            target = node.target
            expression = unparse(node.value)
            ann = self.resolve(root, node.annotation)
        elif isinstance(node, Assign):
            if len(node.targets) != 1 or not isinstance(node.targets[0], Name):
                return
            target = node.targets[0]
            expression = unparse(node.value)
            ann = const_type(node.value) if node.type_comment is None else node.type_comment
        else:
            return

        key = _m(root, target.id)
        self.alias[key] = expression
        if target.id.isupper():
            self.root[key] = root
            if self.const.get(key, ANY) == ANY:
                self.const[key] = ann

        exports_literal = target.id == '__all__' and isinstance(node.value, (Tuple, List))
        if not exports_literal:
            return
        exported = (item.value for item in node.value.elts
                    if isinstance(item, Constant) and isinstance(item.value, str))
        for entry in exported:
            self.imp[root].add(_m(root, entry))
コード例 #5
0
 def test_visit_constant(self):
     self.assertEqual(
         ast.unparse(self.tree.body[0]),
         "a = ''.join([chr(x) for x in [97, 32, 115, 116, 114, 105, 110, "
         "103, 32, 108, 105, 116, 101, 114, 97, 108]])")
     self.assertEqual(ast.unparse(self.tree.body[1]), "b = int('0x2a', 16)")
     self.assertEqual(ast.unparse(self.tree.body[2]),
                      "c = float.fromhex('0x1.9000000000000p+6')")
コード例 #6
0
 def _extract_codes(self, contract_source):
     """Split a contract source into its Vyper code and its MPC code.

     The source is pythonized and parsed into ``self._code_tree``.  A deep
     copy, ``self._vyper_code_tree``, is transformed in place, unparsed, and
     turned back into Vyper by ``self._recover_vyper_code`` using the
     recorded class types.  ``self._mpc_code_tree`` is unparsed into the MPC
     code.  NOTE(review): that tree is not built here; presumably
     ``self._transform`` populates it — confirm.

     Returns:
         ``(vyper_code, mpc_code)``; both are also stored on ``self``.
     """
     class_types, pythonized_source = self._pythonize(contract_source)
     parsed = self._parse(pythonized_source)
     self._code_tree = parsed
     self._vyper_code_tree = copy.deepcopy(parsed)
     self._transform(self._vyper_code_tree)
     vyper_as_python = unparse(self._vyper_code_tree)
     self._vyper_code = self._recover_vyper_code(vyper_as_python, class_types)
     self._mpc_code = unparse(self._mpc_code_tree)
     return self._vyper_code, self._mpc_code
コード例 #7
0
ファイル: check_new_syntax.py プロジェクト: J-M0/typeshed
 def visit_If(self, node: ast.If) -> None:
     """Report ``if sys.version_info < X: ... else: ...`` blocks.

     An ``if`` with an ``else``/``elif`` branch whose comparison test starts
     with ``sys.version_info < `` is reported in ``errors``, suggesting the
     ``>=`` form so the code for newer Python versions comes first.  Child
     nodes are always visited afterwards.
     """
     if isinstance(node.test, ast.Compare) and node.orelse:
         test_source = ast.unparse(node.test)
         if test_source.startswith("sys.version_info < "):
             new_syntax = "if " + test_source.replace("<", ">=", 1)
             errors.append(
                 f"{path}:{node.lineno}: When using if/else with sys.version_info, "
                 f"put the code for new Python versions first, e.g. `{new_syntax}`"
             )
     self.generic_visit(node)
コード例 #8
0
    def functions_with_invariants(self, function_name=None):
        """Return the source of function(s) carrying invariants.

        Args:
            function_name: Single function to render.  When ``None``, every
                name from ``self.invariants()`` is rendered, skipping names
                for which ``self.functions_with_invariants_ast`` raises
                ``KeyError``.

        Returns:
            The unparsed source.  Several functions are separated by a
            newline: ``ast.unparse`` emits no trailing newline, so plain
            concatenation would glue each definition onto the previous one's
            last line and produce invalid Python.  With no functions the
            result is ``""``.

        Exceptions raised by ``self.functions_with_invariants_ast`` propagate
        when ``function_name`` is given.
        """
        if function_name is not None:
            return ast.unparse(self.functions_with_invariants_ast(function_name))

        sources = []
        for f_name in self.invariants():
            try:
                sources.append(ast.unparse(self.functions_with_invariants_ast(f_name)))
            except KeyError:
                pass
        return "\n".join(sources)
コード例 #9
0
    def typed_functions(self, function_name=None):
        """Return the source of type-annotated function(s).

        Args:
            function_name: Single function to render.  When ``None``, every
                name from ``self.calls()`` is rendered, skipping names for
                which ``self.typed_functions_ast`` raises ``KeyError``.

        Returns:
            The unparsed source.  Several functions are separated by a
            newline: ``ast.unparse`` emits no trailing newline, so plain
            concatenation would glue each definition onto the previous one's
            last line and produce invalid Python.  With no functions the
            result is ``""``.

        Exceptions raised by ``self.typed_functions_ast`` propagate when
        ``function_name`` is given.
        """
        if function_name is not None:
            return ast.unparse(self.typed_functions_ast(function_name))

        sources = []
        for f_name in self.calls():
            try:
                sources.append(ast.unparse(self.typed_functions_ast(f_name)))
            except KeyError:
                pass
        return "\n".join(sources)
コード例 #10
0
ファイル: __init__.py プロジェクト: Buzzvil/gen_py2_dc
def parse_input(class_def: str) -> Iterable[Class]:
    """Yield a ``Class`` description for each top-level class in ``class_def``.

    Args:
        class_def: Python source text containing class definitions.

    Yields:
        ``Class(name=..., fields=...)`` where ``fields`` is a tuple of
        ``(target, annotation)`` source strings, one per annotated assignment
        directly in the class body, in declaration order.  Top-level
        statements that are not classes are skipped.
    """
    module: ast.Module = ast.parse(class_def)

    # Use a distinct loop name rather than rebinding the ``class_def``
    # parameter (a ``str``) to ``ast.ClassDef`` nodes.
    for node in module.body:
        if not isinstance(node, ast.ClassDef):
            continue

        yield Class(
            name=node.name,
            fields=tuple(
                (ast.unparse(stmt.target), ast.unparse(stmt.annotation))
                for stmt in node.body
                if isinstance(stmt, ast.AnnAssign)
            ),
        )
コード例 #11
0
def main() -> int:
    """Command-line entry point: parse a script, then dump, run, print or compile it.

    The script is parsed with the module-level ``parse``.  Modes
    (``args.mode``; ``'auto'`` means ``'run'``):

    * ``dump`` -- print ``ast.dump`` of the tree, including attributes;
    * ``run`` -- compile the tree and execute it as ``__main__``;
    * ``py`` -- print the tree unparsed back to Python source;
    * ``compile_only`` -- time ``compile()`` and report success or failure.

    Returns:
        ``1`` if ``compile_only`` hit a ``SyntaxError``, otherwise ``0``.
    """
    args = parser.parse_args()
    if args.mode == 'auto':
        args.mode = 'run'
    source = args.script.read()
    try:
        filename = args.script.name
    except Exception:
        # Fall back when the input stream has no usable ``name``.
        filename = '<unknown>'
    tree = parse(source, filename)
    if args.mode == 'dump':
        print(ast.dump(tree, indent=3, include_attributes=True))
    elif args.mode == 'run':
        compiled = compile(tree, filename, 'exec')
        _run_code(compiled, {
            '__builtins__': builtins
        }, mod_name='__main__', script_name=filename)
    elif args.mode == 'py':
        print(ast.unparse(tree))
    elif args.mode == 'compile_only':
        # Bind ``error`` explicitly: a bare annotation (``error: None``) only
        # declares the name, so the success path would raise
        # UnboundLocalError when reading it.
        error = None
        start = time.process_time_ns()
        try:
            compile(tree, filename, 'exec')
        except SyntaxError as e:
            error = e
        end = time.process_time_ns()
        if error is not None:
            print(f'Failed to compile "{filename}" (AST had {count_nodes(tree)} nodes) in {end - start}ns.')
            import traceback
            traceback.print_exception(SyntaxError, error, error.__traceback__, 0)
            return 1
        print(f'Successfully compiled "{filename}" (AST had {count_nodes(tree)} nodes) in {end - start}ns.')
    return 0
コード例 #12
0
ファイル: teyit.py プロジェクト: isidentical/teyit
    def visit_assertTrue(self, node, positive=True):
        """Suggest a more specific assertion for a truthiness assertion call.

        Args:
            node: The ``ast.Call`` of the assertion.  Its first positional
                argument is the checked expression; the rest (e.g. ``msg``)
                are carried over.
            positive: ``False`` for the negated form (presumably
                ``assertFalse``); the comparison operator is then replaced by
                its contrary from ``CONTRA_OPS``, if there is one.

        Returns:
            ``Rewrite(node, <method name>, <positional args>)``, or ``None``
            when no rewrite applies.
        """
        if not node.args:
            # ``assertTrue(expr=...)`` passes the checked expression by
            # keyword; there is nothing positional to unpack, so leave it.
            return None
        expr, *args = node.args
        if isinstance(expr, ast.Compare) and len(expr.ops) == 1:
            left = expr.left
            operator = type(expr.ops[0])
            if not positive:
                if operator not in CONTRA_OPS:
                    return None
                operator = CONTRA_OPS[operator]

            (comparator, ) = expr.comparators
            if (operator in (ast.Is, ast.IsNot)
                    and isinstance(comparator, ast.Constant)
                    and comparator.value is None):
                # ``x is None`` -> ``assertIsNone(x)``; ``is not`` -> ``assertIsNotNone``.
                func = f"assert{operator.__name__}None"
                args = [left, *args]
            elif operator in OPERATOR_TABLE:
                func = OPERATOR_TABLE[operator]
                args = [left, comparator, *args]
            else:
                return None
        elif (isinstance(expr, ast.Call)
              and ast.unparse(expr.func) == "isinstance"
              and len(expr.args) == 2):
            func = "assertIsInstance" if positive else "assertNotIsInstance"
            args = [*expr.args, *args]
        else:
            return None
        return Rewrite(node, func, args)
コード例 #13
0
ファイル: goggles.py プロジェクト: charles-l/pygoggles
def serialize_cells(stream, cells):
    """Write ``cells`` to ``stream`` as an autogenerated Python module.

    A "DO NOT EDIT" header is written first.  The module then starts with
    ``cells = [cell_0, cell_1, ...]``, followed by one ``def cell_<i>():``
    per cell whose body is the cell's parsed code (``pass`` for an empty
    cell).  The output ends with a newline.

    Args:
        stream: Writable text stream.
        cells: Sequence of objects whose ``textbox.get('0.0', 'end')``
            returns the cell's source text (presumably Tk ``Text`` widgets).

    Raises:
        SyntaxError: if a cell's text is not valid Python.
    """
    # TODO: store AST in Cell and serialize that instead. This will come
    #       with challenges since the AST won't include comments, but it's
    #       cleaner.

    cells_node = ast.List(elts=[], ctx=ast.Load())
    result_tree = ast.Module(
        body=[
            # cells = [...]
            ast.Assign(targets=[ast.Name(id='cells', ctx=ast.Store())],
                       value=cells_node,
                       lineno=1)
        ],
        type_ignores=[])

    for i, cell in enumerate(cells):
        text = cell.textbox.get('0.0', 'end')
        tree = ast.parse(text, mode='exec')
        result_tree.body.append(
            ast.FunctionDef(name=f'cell_{i}',
                            decorator_list=[],
                            # Spell out every list field of ``ast.arguments``
                            # (``kw_defaults`` included) so the node is complete.
                            args=ast.arguments(args=[],
                                               posonlyargs=[],
                                               defaults=[],
                                               kwonlyargs=[],
                                               kw_defaults=[]),
                            lineno=i,
                            body=tree.body or [ast.Pass()]))
        # Elements of the list display on the right-hand side are *loaded*.
        cells_node.elts.append(ast.Name(id=f'cell_{i}', ctx=ast.Load()))

    stream.write(
        textwrap.dedent('''\
        # DO NOT EDIT
        # This file is autogenerated by pygoggles
        '''))
    stream.write(ast.unparse(result_tree))
    # ast.unparse() emits no trailing newline; end the file with one.
    stream.write('\n')
コード例 #14
0
 def _get_replacement_f_string(self, node: ast.JoinedStr) -> str:
     # There is no easy way to compare two AST nodes for equality.
     # Instead we compare their serialized forms
     code = ast.unparse(node)
     return self._f_strings_seen.get(
         code) or self._f_strings_seen.setdefault(
             code, next(self._f_string_name_generator))
コード例 #15
0
def scan_file(filepath: str) -> set:
    """
    Scan a Python file and return a set of annotations.

    Since parsing `Optional[typing.List]` and `Optional[typing.Dict]` is the same,
    we're not interested in keeping the actual names.
    Therefore we replace every word with "a".
    It has two benefits:

    - we can get rid of syntactically equivalent annotations (duplicates)
    - the resulting annotations take fewer bytes

    Arguments:
        filepath: The path to the Python file to scan.

    Returns:
        A set of annotations. Files that cannot be read or parsed yield an
        empty set.
    """
    annotations: set = set()
    try:
        # Parse raw bytes so ``ast`` honours PEP 263 coding declarations
        # instead of decoding with the platform's locale encoding.
        code = ast.parse(Path(filepath).read_bytes())
    except Exception:
        # Best effort: skip unreadable or unparsable files.  Unlike a bare
        # ``except:``, this still lets KeyboardInterrupt/SystemExit through.
        return annotations
    for node in ast.walk(code):
        annotation = getattr(node, "annotation", None)
        if annotation is None:
            # e.g. an ``ast.arg`` without an annotation.
            continue
        try:
            unparsed = unparse(annotation)  # type: ignore
        except Exception:
            continue
        annotations.add(regex.sub("a", unparsed))
    return annotations
コード例 #16
0
    def test_ast_transform(self):
        src = """
def foo(flag: bool) -> int:
    print('Hello World')
    if flag:
        return 1
    else:
        return 'you stupid'
        """
        target = """
def foo(flag: bool) -> int:
    print('Hello World')
    if flag:
        untypy._before_return(0)
        return 1
    else:
        untypy._before_return(1)
        return 'you stupid'
        """

        tree = ast.parse(src)
        mgr = ReturnTraceManager()
        ReturnTracesTransformer("<dummyfile>", mgr).visit(tree)
        ast.fix_missing_locations(tree)
        self.assertEqual(ast.unparse(tree).strip(), target.strip())
        self.assertEqual(mgr.get(0), ("<dummyfile>", 5))
        self.assertEqual(mgr.get(1), ("<dummyfile>", 7))
コード例 #17
0
def get_source_code_executed(function, function_graph):
    """Return the source of ``function`` and of every function reachable from it.

    Args:
        function: Callable whose ``__qualname__`` selects the start vertex.
        function_graph: Mapping from function-definition AST nodes (each with
            a ``qualname`` attribute) to an iterable of the nodes they link
            to.

    Returns:
        The unparsed source of every reachable vertex in breadth-first order
        from the start vertex, joined with newlines.  Returns ``""`` if no
        vertex's ``qualname`` matches.
    """
    from collections import deque

    start = next(
        (vertex for vertex in function_graph
         if vertex.qualname == function.__qualname__),
        None)
    if start is None:
        return ""

    # ``seen`` holds every vertex ever enqueued, giving O(1) membership
    # tests; the deque gives O(1) pops from the front.
    seen = {start}
    queue = deque([start])
    source_codes_executed = []
    while queue:
        vertex = queue.popleft()
        source_codes_executed.append(ast.unparse(vertex))
        for linked in function_graph[vertex]:
            if linked not in seen:
                seen.add(linked)
                queue.append(linked)

    return "\n".join(source_codes_executed)
コード例 #18
0
def test_stubgen_pydantic():
    """``stubs_for_pydantic`` renders the expected stub module for User/Article."""
    # Register the SQLAlchemy models under "<Model>Model" names
    models = Models(__name__, types=AttributeType.ALL, naming='{model}Model')
    for sa_model in (User, Article):
        models.sa_model(sa_model)

    # Render and compare against the expected stub source
    generated = ast.unparse(stubs_for_pydantic(models))
    expected = '''
from __future__ import annotations
import pydantic
import builtins, datetime, typing
NoneType = type(None)

class UserModel(pydantic.main.BaseModel):
    """ User model """
    id: int = ...
    login: typing.Union[str, NoneType] = ...
    articles: list[ArticleModel] = ...

class ArticleModel(pydantic.main.BaseModel):
    """ Article model """
    id: int = ...
    user_id: typing.Union[int, NoneType] = ...
    ctime: typing.Union[datetime.datetime, NoneType] = ...
    user: typing.Union[UserModel, NoneType] = ...    
'''.strip()
    assert generated == expected
コード例 #19
0
def generate_code_file(mod_body,
                       file,
                       imports,
                       external_functions_source=False,
                       names="#"):
    """Prepend import statements to ``mod_body`` and unparse it to source.

    ``mod_body`` is modified in place.  Each ``(module, name)`` in
    ``imports.as_imports`` becomes ``import module as name``.  Each
    ``(module, name)`` in ``imports.from_imports`` becomes
    ``from module import name``.  A truthy ``external_functions_source``
    adds ``from <external_functions_source> import *``.  Because every
    statement is inserted at index 0, the final order is: the star import,
    the from-imports (reversed), the plain imports (reversed), then the
    original body.  Prints ``Generating Source`` as progress output.

    Args:
        mod_body: List of statements to extend (mutated).
        file: Unused here; kept for interface compatibility.
        imports: Object exposing ``as_imports`` and ``from_imports``
            iterables of ``(module, name)`` pairs.
        external_functions_source: Module name for a star import, or a
            falsy value for none.
        names: Text prepended verbatim to the generated source.
            NOTE(review): the default ``"#"`` has no trailing newline, so it
            turns the first generated line into a comment — confirm this is
            intended.

    Returns:
        ``names`` followed by the unparsed result of ``wrap_module(mod_body)``.
    """
    for (module, name) in imports.as_imports:
        # ``ast.Import`` has no ``level`` field (only ``ast.ImportFrom``
        # does); unknown keyword arguments are deprecated since Python 3.13.
        mod_body.insert(
            0, ast.Import(names=[ast.alias(name=module, asname=name)]))
    for (module, name) in imports.from_imports:
        mod_body.insert(
            0,
            ast.ImportFrom(module=module,
                           names=[ast.alias(name=name, asname=None)],
                           level=0))
    if external_functions_source:
        mod_body.insert(
            0,
            ast.ImportFrom(module=external_functions_source,
                           names=[ast.alias(name='*', asname=None)],
                           level=0))

    mod = wrap_module(mod_body)
    print('Generating Source')
    source = names + ast.unparse(mod)

    return source
コード例 #20
0
    def _add_to_total_imports(self, node: Union[ast.Import, ast.ImportFrom]):
        """Record an import statement in ``self.total_imports``.

        A single metadata dict is built for ``node``.  It holds:

        * ``'exact_line'``: the statement's source, when ``ast.unparse``
          works;
        * one flag per ``SKETCHY_TYPES_TABLE`` value, ``True`` for the types
          present in ``self.sketchy_nodes``;
        * ``'import'`` or ``'import_from'``: the imported module names.

        The dict is stored under ``(self.filename, node.lineno)`` in
        ``self.total_imports[name]`` for each imported module name.

        Raises:
            NotImplementedError: if ``node`` is neither ``ast.Import`` nor
                ``ast.ImportFrom``.
        """
        import_metadata = {}
        try:
            import_metadata['exact_line'] = ast.unparse(node)
        except AttributeError:
            # Presumably guards interpreters without ``ast.unparse`` (3.9+).
            pass

        for flag_name in SKETCHY_TYPES_TABLE.values():
            import_metadata[flag_name] = False
        sketchy_flags = {
            SKETCHY_TYPES_TABLE[sketchy.__class__]: True
            for sketchy in self.sketchy_nodes
        }
        import_metadata.update(sketchy_flags)

        if isinstance(node, ast.Import):
            imported = {alias.name for alias in node.names}
            import_metadata['import'] = imported
            names = set(imported)
        elif isinstance(node, ast.ImportFrom):
            import_metadata['import_from'] = {node.module}
            names = {node.module}
        else:
            raise NotImplementedError(
                f"Expected ast.Import or ast.ImportFrom this is {type(node)}")

        for name in names:
            location = (self.filename, node.lineno)
            self.total_imports[name].update({location: import_metadata})
コード例 #21
0
        def visit_Subscript(self, node: ast.Subscript) -> None:
            """Report ``typing`` generics that have a modern spelling.

            ``Union[...]`` (with several members) and ``Optional[...]``
            should use PEP 604 ``|`` syntax.  ``List``/``Dict``/``Tuple``
            should use the built-in generics; ``Tuple[X, ...]`` is exempt.
            Child nodes are always visited afterwards.
            """
            if isinstance(node.value, ast.Name):
                generic = node.value.id
                message = None
                if generic == "Union" and isinstance(node.slice, ast.Tuple):
                    suggestion = " | ".join(ast.unparse(x) for x in node.slice.elts)
                    message = f"{path}:{node.lineno}: Use PEP 604 syntax for Union, e.g. `{suggestion}`"
                elif generic == "Optional":
                    suggestion = f"{ast.unparse(node.slice)} | None"
                    message = f"{path}:{node.lineno}: Use PEP 604 syntax for Optional, e.g. `{suggestion}`"
                elif generic == "List":
                    suggestion = f"list[{ast.unparse(node.slice)}]"
                    message = f"{path}:{node.lineno}: Use built-in generics, e.g. `{suggestion}`"
                elif generic == "Dict":
                    suggestion = f"dict[{unparse_without_tuple_parens(node.slice)}]"
                    message = f"{path}:{node.lineno}: Use built-in generics, e.g. `{suggestion}`"
                elif generic == "Tuple" and not (
                        isinstance(node.slice, ast.Tuple)
                        and len(node.slice.elts) == 2
                        and is_dotdotdot(node.slice.elts[1])):
                    # Tuple[Foo, ...] must be allowed because of mypy bugs
                    suggestion = f"tuple[{unparse_without_tuple_parens(node.slice)}]"
                    message = f"{path}:{node.lineno}: Use built-in generics, e.g. `{suggestion}`"
                if message is not None:
                    errors.append(message)

            self.generic_visit(node)
コード例 #22
0
    def visit_Compare(self, node):
        """Rewrite template syntax ``name < ctx > args`` into a template call.

        Consider a chained comparison whose left operand is a bare name,
        whose first operator is ``<`` and whose last is ``>``.  It becomes
        ``__template_<name>(<ctx'>, <args tuple>)``, where ``ctx'`` is the
        context expression passed through ``ArgRewriter.rewrite``.
        Operators between the first and last (``name < a == b > args``) are
        kept as one comparison ``a' == b'`` passed as the single context
        argument; its source text is also appended to the template
        arguments.  A non-tuple ``args`` is wrapped in a one-element tuple.
        Every other comparison is visited normally.
        """
        if len(node.ops) < 2 or not isinstance(node.left, ast.Name):
            # Only ``<name> < ... > ...`` is template syntax.  Ordinary chains
            # such as ``0 < x > y`` or ``a.b < x > y`` have no ``.id`` on
            # their left operand and must be left alone.
            return self.generic_visit(node)

        left, *checks, right = node.ops
        if not (isinstance(left, ast.Lt) and isinstance(right, ast.Gt)):
            return self.generic_visit(node)

        template_name = node.left.id
        *context_vars, template_args = node.comparators

        if not isinstance(template_args, ast.Tuple):
            template_args = ast.copy_location(
                ast.Tuple(elts=[template_args], ctx=ast.Load()), template_args)

        arg_rewriter = ArgRewriter(template_name, template_args.elts)
        context_args = [
            arg_rewriter.rewrite(context_var) for context_var in context_vars
        ]
        if checks:
            context_args = [
                ast.Compare(context_args[0], checks, context_args[1:])
            ]
            template_args.elts.append(
                ast.Constant(ast.unparse(context_args[0])))

        call = ast.Call(ast.Name(f"__template_{template_name}", ast.Load()),
                        [*context_args, template_args],
                        keywords=[])
        return ast.copy_location(call, node)
コード例 #23
0
ファイル: paramutils.py プロジェクト: bek0s/gbkfit
 def visit_Subscript(self, node):
     """Classify a subscripted parameter symbol such as ``name[...]``.

     * Unknown symbols are ignored, but their children are still visited.
     * Scalar parameters cannot be subscripted: the symbol is added to
       ``self._invalid_scalars``.
     * An unsupported subscript is recorded in ``self._invalid_vectors``
       with value ``None``.
     * Out-of-range indices are recorded in ``self._invalid_vectors`` as the
       list of offending indices.
     * Otherwise the (negative-unwrapped) indices are added to
       ``self._symbols[name]``.
     """
     code = ast.unparse(node).strip('\n')
     name, subscript = parse_param_symbol_into_name_and_subscript_str(code)
     desc = self._descs.get(name)
     if not desc:
         # Not a known parameter: keep traversing this branch.
         self.generic_visit(node)
     elif isinstance(desc, ParamScalarDesc):
         self._invalid_scalars.add(code)
     elif not is_param_symbol_subscript(subscript):
         self._invalid_vectors[code] = None
     else:
         length = desc.size()
         requested = parse_param_symbol_subscript(subscript, length)
         requested, out_of_range = iterutils.validate_sequence_indices(
             requested, length)
         requested = iterutils.unwrap_sequence_indices(requested, length)
         if out_of_range:
             self._invalid_vectors[code] = out_of_range
         else:
             self._symbols[name].update(requested)
コード例 #24
0
ファイル: CodeGenerator.py プロジェクト: zsh2020/Ryven
    def get_modules(self, ni_ast_dict: dict) -> dict:
        """Collect the import statements found in all node-instance ASTs.

        Args:
            ni_ast_dict: Mapping whose values are sequences; item ``[1]`` of
                each value is an AST node with a ``body`` of statements.

        Returns:
            ``{'imports': {module: [ast.Import, True]},
            'fromimports': {module: [ast.ImportFrom, True, [names...]]}}``.

            * Statements whose unparsed text is in ``self.IGNORED_COMMANDS``
              are skipped.
            * Every module of ``import a, b`` gets its own entry, holding a
              single-name ``ast.Import``.
            * For ``from`` imports, the first statement seen for a module is
              kept and later names are merged into its name list without
              duplicates.
        """
        modules = {'imports': {}, 'fromimports': {}}

        for entry in ni_ast_dict.values():
            tree: ast.AST = entry[1]
            for stmt in tree.body:
                if ast.unparse(stmt) in self.IGNORED_COMMANDS:
                    continue
                if isinstance(stmt, ast.Import):
                    # Register each imported module, not just the first one.
                    for alias in stmt.names:
                        if len(stmt.names) == 1:
                            single = stmt
                        else:
                            single = ast.copy_location(ast.Import(names=[alias]), stmt)
                        modules['imports'][alias.name] = [single, True]
                elif isinstance(stmt, ast.ImportFrom):
                    names = [alias.name for alias in stmt.names]
                    known = modules['fromimports'].get(stmt.module)
                    if known is None:
                        modules['fromimports'][stmt.module] = [stmt, True, names]
                    else:
                        merged = known[2]
                        for n in names:
                            if n not in merged:
                                merged.append(n)

        return modules
コード例 #25
0
ファイル: main.py プロジェクト: ubuntumaroon/staticcode
 def generic_visit(self, node: ast.AST) -> Any:
     """Debug helper: print ``node``, its source and its fields, then recurse.

     Per node it prints ``child for: <node>``, the unparsed source, all
     ``field value`` pairs on one ``;``-separated line, and a ``-----``
     separator.  The parent's ``generic_visit`` result is not returned.
     """
     print(*("child for:", node))
     print(ast.unparse(node))
     for field_and_value in ast.iter_fields(node):
         print(*field_and_value, end=';')
     print('\n-----')
     super().generic_visit(node)
コード例 #26
0
ファイル: astparse.py プロジェクト: edart76/tree
def objToQuoteless(obj):
	"""Return ``str(obj)`` re-rendered after turning constants into names.

	``str(obj)`` is parsed as a single expression and rewritten by
	``ASTConstantToName``.  Missing locations are filled in, and the tree is
	unparsed back to source.  Requires Python 3.9+ (``ast.unparse``).
	"""
	expression_tree = ast.parse(str(obj), mode="eval")
	rewritten = ASTConstantToName().visit(expression_tree)
	return ast.unparse(ast.fix_missing_locations(rewritten))
コード例 #27
0
ファイル: __init__.py プロジェクト: vishalbelsare/flor
 def visit_If(self, node: ast.If):
     """Inject the payload into the loop guarded by ``flor.SkipBlock.step_into``.

     For an ``if`` whose test source contains ``flor.SkipBlock.step_into``,
     the last statement parsed from ``self.payload`` is appended to the body
     of the single ``for`` loop directly inside the ``if`` body.  Child nodes
     are then visited and the node is returned.

     Raises:
         ValueError: if that ``if`` body does not contain exactly one ``for``
             loop.
     """
     if "flor.SkipBlock.step_into" in ast.unparse(node.test):
         loops = [stmt for stmt in node.body if isinstance(stmt, ast.For)]
         # An ``assert`` would be stripped under ``python -O``, letting a
         # wrong loop be picked (or IndexError) silently.
         if len(loops) != 1:
             raise ValueError(
                 f"expected exactly one for-loop inside the SkipBlock guard, found {len(loops)}")
         loops[0].body.append(ast.parse(self.payload).body.pop())
     return self.generic_visit(node)
コード例 #28
0
    def traverse(node):
        """Append the stripped source of every ``if`` test under ``node`` to ``conditions``.

        Nodes are visited depth-first in pre-order: each node comes before
        its children, and children come in ``ast.iter_child_nodes`` order.
        The walk uses an explicit stack instead of recursion.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if isinstance(current, ast.If):
                conditions.append(ast.unparse(current.test).strip())
            # Push children reversed so the first child is processed next.
            pending.extend(reversed(list(ast.iter_child_nodes(current))))
コード例 #29
0
def save_trees(args=None):
    """Write transformed module trees as Python source under ``OUT_PATH``.

    Args:
        args: Mapping with ``"dst"`` (a ``Path`` relative to ``OUT_PATH``)
            and ``"trees"`` (an iterable of AST trees).  Despite the ``None``
            default, it is required.

    The destination is created and must not exist yet.  When
    ``SHOULD_SAVE_AST`` is set, an ``astpretty`` dump is also written beside
    it with suffix ``.ast.py``.  Destinations whose name starts with
    ``test_`` get an ``import ....unittest`` header.  Output is UTF-8.

    Raises:
        TypeError: if ``args`` is omitted.
        FileExistsError: if the destination already exists.
    """
    if args is None:
        raise TypeError("save_trees() requires an 'args' mapping with 'dst' and 'trees'")
    dst: Path = args["dst"]
    trees = args["trees"]
    dst_full = OUT_PATH.joinpath(dst)
    dst_full.parent.mkdir(parents=True, exist_ok=True)
    dst_full.touch(exist_ok=False)
    # TODO: append "doctest.testmod(raise_on_error=True)"
    trees = [ast.fix_missing_locations(tree) for tree in trees]
    if SHOULD_SAVE_AST:
        new_txt = "\n".join([str(astpretty.pformat(tree)) for tree in trees])
        new_txt = f"""from ast import *
{new_txt}
"""
        # Python source is UTF-8 by default (PEP 3120); don't depend on the
        # locale's preferred encoding.
        dst_full.with_suffix(".ast.py").write_text(new_txt, encoding="utf-8")
    new_txt = ""
    if dst.name.startswith("test_"):
        if "compatible" in str(dst):
            new_txt += f"""
import {COMPATIBLE_MODULE}.unittest
"""
        else:
            new_txt += """
import oneflow.unittest
"""
    new_txt += "\n".join([ast.unparse(tree) for tree in trees])
    dst_full.write_text(new_txt, encoding="utf-8")
コード例 #30
0
ファイル: _compile.py プロジェクト: petioptrv/cupy
def _transpile_function(func, attributes, mode, consts, in_types, ret_type):
    """Transpile the function
    Args:
        func (ast.FunctionDef): Target function.
        attributes (str): The attributes of target function.
        mode ('numpy' or 'cuda'): The rule for typecast.
        consts (dict): The dictionary with keys as variable names and
            values as concrete data object.
        in_types (list of _types.TypeBase): The types of arguments.
        ret_type (_types.TypeBase): The type of return value.

    Returns:
        code (str): The generated CUDA code.
        env (Environment): More details of analysis result of the function,
            which includes preambles, estimated return type and more.
    """
    consts = dict([(k, Constant(v)) for k, v, in consts.items()])

    if not isinstance(func, ast.FunctionDef):
        # TODO(asi1024): Support for `ast.ClassDef`.
        raise NotImplementedError('Not supported: {}'.format(type(func)))
    if len(func.decorator_list) > 0:
        if sys.version_info >= (3, 9):
            # Code path for Python versions that support `ast.unparse`.
            for deco in func.decorator_list:
                deco_code = ast.unparse(deco)
                if deco_code not in ['rawkernel', 'vectorize']:
                    warnings.warn(
                        f'Decorator {deco_code} may not supported in JIT.',
                        RuntimeWarning)
    arguments = func.args
    if arguments.vararg is not None:
        raise NotImplementedError('`*args` is not supported currently.')
    if len(arguments.kwonlyargs) > 0:  # same length with `kw_defaults`.
        raise NotImplementedError(
            'keyword only arguments are not supported currently .')
    if arguments.kwarg is not None:
        raise NotImplementedError('`**kwargs` is not supported currently.')
    if len(arguments.defaults) > 0:
        raise NotImplementedError(
            'Default values are not supported currently.')

    args = [arg.arg for arg in arguments.args]
    if len(args) != len(in_types):
        raise TypeError(
            f'{func.name}() takes {len(args)} positional arguments '
            f'but {len(in_types)} were given.')
    params = dict([(x, CudaObject(x, t)) for x, t in zip(args, in_types)])
    env = Environment(mode, consts, params, ret_type)
    body = _transpile_stmts(func.body, True, env)
    params = ', '.join([env[a].ctype.declvar(a) for a in args])
    local_vars = [v.ctype.declvar(n) + ';' for n, v in env.locals.items()]

    if env.ret_type is None:
        env.ret_type = _types.Void()

    head = f'{attributes} {env.ret_type} {func.name}({params})'
    code = CodeBlock(head, local_vars + body)
    return str(code), env