示例#1
0
    def test_unparse_invalid_examples(self):
        """Check that invalid examples are rejected when unparsed and when parsed.

        For every available tree in INVALID_EXAMPLES, unparsing must raise
        a SyntaxError whose message mentions PEP 526, and parsing the matching
        source code must raise a SyntaxError as well.
        """
        for description, example in INVALID_EXAMPLES.items():
            for mode in MODES:
                invalid_tree = example['trees'][mode]
                if invalid_tree is None:
                    continue
                label = (description, mode)

                with self.assertRaises(SyntaxError, msg=label) as caught:
                    typed_astunparse.unparse(invalid_tree)
                self.assertIn('PEP 526', str(caught.exception), msg=label)

                with self.assertRaises(SyntaxError, msg=label):
                    typed_ast.ast3.parse(source=example['code'], mode=mode)
示例#2
0
    def _inline_call_in_expr(self, expr):
        """Inline the call stored in ``expr.value`` and return the inlined result.

        Declarations collected by the DeclarationReplacer while inlining are
        stored in ``self._omitted_declarations``, unless that list already has
        content, in which case they are dropped with a warning. In verbose mode
        the stored declarations are surrounded by two marker comments that name
        the inlined call.
        """
        call = expr.value
        assert self._is_valid_target_for_inlining(call)

        # The declaration replacer always runs first, then the argument name replacers.
        decl_replacer = DeclarationReplacer()
        replacers = [decl_replacer]
        replacers.extend(create_name_replacers(self._inlined_args, call.args))

        inlined = self._inline_call(call, replacers)

        if not decl_replacer.replaced:
            return inlined

        _LOG.warning('omitted %i declarations', len(decl_replacer.replaced))
        if self._omitted_declarations:
            _LOG.warning('will not restore them because there already are some declarations')
            return inlined

        if self._verbose:
            call_code = typed_astunparse.unparse(call).strip()
            start_marker = horast_nodes.Comment(
                ' start of declarations from inlined {}'.format(call_code), eol=False)
            self._omitted_declarations.append((start_marker, None))
        self._omitted_declarations += decl_replacer.replaced
        if self._verbose:
            end_marker = horast_nodes.Comment(
                ' end of declarations from inlined {}'.format(call_code), eol=False)
            self._omitted_declarations.append((end_marker, None))

        return inlined
示例#3
0
 def test_ann_assign(self):
     """Type annotated assignments and verify the variables recorded for them.

     Each case maps the source of one annotated assignment to the expected
     mapping from variable name to resolved type.
     """
     cases = {
         'a: int\n': {'a': int},
         'a: int = 0\n': {'a': int},
         'value: float = "oh my"\n': {'value': float},
     }
     for ast_module in AST_MODULES:
         resolver = TypeHintResolver[ast_module, ast](globals_=GLOBALS_EXTERNAL)
         typer = StaticTyper[ast_module]()
         for example, assigned_vars in cases.items():
             parsed = ast_module.parse(example, mode='single')
             node = resolver.visit(parsed.body[0])
             with self.subTest(ast_module=ast_module, example=example,
                               assigned_vars=assigned_vars,
                               node=ast_module.dump(node)):
                 ann_assign = typer.visit(node)
                 self.assertIsInstance(
                     ann_assign, StaticallyTypedAnnAssign[ast_module])
                 self.assertIsInstance(ann_assign._vars, dict)
                 self.assertEqual(len(ann_assign._vars), len(assigned_vars))
                 # Key the recorded variables by their source text for comparison.
                 found_vars = {}
                 for name, type_ in ann_assign._vars.items():
                     found_vars[typed_astunparse.unparse(name).rstrip()] = type_
                 self.assertDictEqual(found_vars, assigned_vars)
                 _LOG.info('%s', ann_assign)
示例#4
0
    def annotate(self, fpath, pred_idx, type_idx):
        """Annotate the file at ``fpath`` and write the result to a sibling file.

        Args:
            fpath: path of the Python source file to annotate.
            pred_idx: ``-1`` to select predictions by ``fpath``; any other value
                is passed to the prediction filter as ``pred_idx``.
            type_idx: index into the ``predicted_annotation_logprob_dist`` list
                of a prediction, selecting which candidate type to apply.

        Returns:
            ``fpath`` itself when there are no relevant predictions, when
            ``type_idx`` is out of range for the first prediction, or when
            visiting the tree left it unmodified. Otherwise the path of the
            newly written file, derived from ``fpath`` by ``rreplace`` with
            ``.py`` -> ``_tpl_<type_idx>.py``.
        """
        self.__reset()

        if pred_idx == -1:
            self.__sift(fpath=fpath)
        else:
            self.__sift(pred_idx=pred_idx)

        # if no proper (i.e. non-property-access) predictions are for this file,
        # or the predictions are fewer than args.top
        if len(self.__rel_lines) == 0 or type_idx >= len(
                self.__rel_lines[0]["predicted_annotation_logprob_dist"]):
            return fpath

        self.__type_idx = type_idx

        # Read with the same encoding that is used for writing below; relying on
        # the locale default can fail or corrupt non-ASCII source on some systems.
        with open(fpath, encoding="utf8") as src:
            self.__tree = parse(src.read())

        new_tree = self.visit(self.__tree)
        if self.__unmodified:
            return fpath
        self.__add_type_imports(self.__get_types_2_import())
        new_tree = fix_missing_locations(new_tree)

        old_ext = ".py"
        new_ext = f"_tpl_{type_idx}.py"
        new_fpath = rreplace(fpath, old_ext, new_ext, 1)
        with open(new_fpath, "w", encoding="utf8") as dst:
            dst.write(typed_astunparse.unparse(new_tree))

        return new_fpath
示例#5
0
 def test_roundtrip(self):
     """Check parse/unparse roundtrips for both the plain and the complete trees.

     For each example, the code is processed twice: with typed_ast /
     typed_astunparse, and with the unqualified parse() / unparse() functions
     available in this module. Both results are re-parsed and compared with
     the trees they came from.
     """
     only_localizable = False
     for name, example in EXAMPLES.items():
         # Examples with end-of-line comments and multiline examples are skipped.
         if ' with eol comments' in name or name.startswith('multiline '):
             continue
         # Filled in step by step so that a subTest failure report shows the intermediate results.
         data = {}
         with self.subTest(name=name, example=example, data=data):
             tree = typed_ast.ast3.parse(example)
             code = typed_astunparse.unparse(tree)
             complete_tree = parse(example)
             data['complete_tree'] = complete_tree
             complete_code = unparse(complete_tree)
             data['complete_code'] = complete_code
             # Ignoring spaces, the complete code must not be shorter than the plain code.
             self.assertGreaterEqual(len(complete_code.replace(' ', '')),
                                     len(code.replace(' ', '')),
                                     (complete_code, code))
             # Plain roundtrip: same number of nodes and identical dump.
             reparsed_tree = typed_ast.ast3.parse(code)
             tree_nodes = ast_to_list(tree, only_localizable)
             reparsed_tree_nodes = ast_to_list(reparsed_tree,
                                               only_localizable)
             self.assertEqual(len(reparsed_tree_nodes), len(tree_nodes))
             self.assertEqual(typed_ast.ast3.dump(reparsed_tree),
                              typed_ast.ast3.dump(tree))
             # Complete roundtrip: same number of nodes and identical dump.
             reparsed_complete_tree = parse(complete_code)
             complete_tree_nodes = ast_to_list(complete_tree,
                                               only_localizable)
             reparsed_complete_tree_nodes = ast_to_list(
                 reparsed_complete_tree, only_localizable)
             self.assertEqual(len(reparsed_complete_tree_nodes),
                              len(complete_tree_nodes), complete_code)
             self.assertEqual(
                 typed_ast.ast3.dump(reparsed_complete_tree),
                 typed_ast.ast3.dump(complete_tree),
                 '"""\n{}\n""" vs. original """\n{}\n"""'.format(
                     complete_code, example))
示例#6
0
def _transform(path: str, code: str,
               target: CompilationTarget) -> Tuple[str, List[str]]:
    """Apply every transformer that is relevant for the passed target.

    The code is re-parsed before each transformer, so every transformer works
    on the output of the previous ones. Transformers whose own target is lower
    than ``target`` are skipped, and so are results that report no tree change.

    Args:
        path: file name handed to the parser and used in error reports.
        code: source code to transform.
        target: compilation target that the transformers are filtered by.

    Returns:
        The transformed code passed through ``fix_code``, and the list of
        dependencies collected from the transformers that changed the tree.

    Raises:
        TransformationError: when a transformer or the unparsing step fails.
    """
    dependencies = []  # type: List[str]

    for transformer in transformers:
        tree = ast.parse(code, path)
        if transformer.target < target:
            continue

        # Only ordinary errors are wrapped; KeyboardInterrupt and SystemExit
        # must propagate untouched instead of becoming a TransformationError.
        try:
            result = transformer.transform(tree)
        except Exception as exc:
            raise TransformationError(path, transformer, dump(tree),
                                      format_exc()) from exc

        if not result.tree_changed:
            continue

        dependencies.extend(result.dependencies)

        try:
            code = unparse(tree)
        except Exception as exc:
            raise TransformationError(path, transformer, dump(tree),
                                      format_exc()) from exc

    return fix_code(code), dependencies
    def test_unparse_invalid_examples(self):
        """Check that invalid example code fails to parse and that its tree can be unparsed.

        The source of each invalid example must raise a SyntaxError when parsed.
        The stored tree is then unparsed; if the resulting code happens to parse,
        it is unparsed once more, otherwise the case ends there.
        """
        for description, example in INVALID_EXAMPLES.items():
            for mode in MODES:
                invalid_tree = example['trees'][mode]
                if invalid_tree is None:
                    continue
                with self.assertRaises(SyntaxError, msg=(description, mode)):
                    typed_ast.ast3.parse(source=example['code'], mode=mode)

                code = typed_astunparse.unparse(invalid_tree)
                try:
                    reparsed = typed_ast.ast3.parse(source=code, mode=mode)
                except SyntaxError:
                    continue
                # The second unparse only has to complete without raising.
                typed_astunparse.unparse(reparsed)
示例#8
0
def test_variables_replacer():
    """Replace every ``f`` with ``x`` in a snippet and compare the unparsed code."""
    source = '''
from f.f import f as f
import f as f

class f(f):
    def f(f):
        f = f
        for f in f:
            with f as f:
                yield f
        return f

    '''
    expected = '''
from x.x import x as x
import x as x

class x(x):

    def x(x):
        x = x
        for x in x:
            with x as x:
                (yield x)
        return x
    '''

    tree = ast.parse(source)
    VariablesReplacer.replace(tree, {'f': 'x'})
    actual = unparse(tree)

    # Surrounding whitespace is irrelevant for the comparison.
    assert actual.strip() == expected.strip()
 def test_bad_raw_literal(self):
     """Unparse a Bytes node built from a quote-heavy raw literal and parse the result.

     The unparsed code is printed and must be parseable in every mode.
     """
     raw_literal = rb'''\t\t ' """ ''' + rb""" " ''' \n"""
     literal_node = typed_ast.ast3.Bytes(raw_literal, 'rb')
     code = typed_astunparse.unparse(literal_node)
     print(code)
     for mode in MODES:
         # Parsing only has to succeed; the resulting tree is not inspected.
         typed_ast.ast3.parse(source=code, mode=mode)
示例#10
0
 def test_unparse_examples(self):
     """Unparse every available example tree and compare with the stored code."""
     for description, example in EXAMPLES.items():
         for mode in MODES:
             tree = example['trees'][mode]
             if tree is None:
                 continue
             unparsed = typed_astunparse.unparse(tree)
             _LOG.debug('%s', unparsed)
             self.assertEqual(
                 unparsed.strip(), example['code'], msg=(description, mode))
示例#11
0
 def test_untyped_files(self):
     """Roundtrip each file in PATHS through the built-in ast parser and typed_astunparse.

     The dump of the tree parsed from the unparsed code must equal the dump of
     the original tree, attributes excluded.
     """
     for path in PATHS:
         with open(path, 'r', encoding='utf-8') as py_file:
             original_code = py_file.read()
         original_tree = ast.parse(source=original_code, filename=path)
         unparsed = typed_astunparse.unparse(original_tree)
         reparsed_tree = ast.parse(source=unparsed)
         self.assertEqual(
             ast.dump(original_tree, include_attributes=False),
             ast.dump(reparsed_tree, include_attributes=False),
             msg=path)
示例#12
0
def transform(transformer, before):
    """Run a fresh ``transformer`` instance over ``before`` and return the stripped code.

    On any failure the dumps of the untouched and of the (possibly partially)
    transformed tree are printed, and the original exception is re-raised.
    """
    tree = parse(before)
    try:
        transformer().visit(tree)
        transformed = unparse(tree).strip()
    except BaseException:
        print('Before:')
        print(dump(parse(before)))
        print('After:')
        print(dump(tree))
        raise
    return transformed
示例#13
0
 def test_files(self):
     """Roundtrip each file in PATHS through typed_ast parsing and typed_astunparse.

     The attribute-free dump of the re-parsed tree must match the dump of the
     tree parsed from the original file.
     """
     for path in PATHS:
         with open(path, 'r', encoding='utf-8') as py_file:
             original_code = py_file.read()
         original_tree = typed_ast.ast3.parse(source=original_code, filename=path)
         unparsed = typed_astunparse.unparse(original_tree)
         reparsed_tree = typed_ast.ast3.parse(source=unparsed)
         self.assertEqual(
             typed_ast.ast3.dump(original_tree, include_attributes=False),
             typed_ast.ast3.dump(reparsed_tree, include_attributes=False),
             msg=path)
示例#14
0
def test_replace_at(as_ast, as_str):
    """Check that replace_at() rewrites a tree into the expected code.

    ``as_ast`` and ``as_str`` are supplied by the caller (presumably test
    fixtures that turn a function into a tree and into source text -- confirm).
    """
    def fn():
        print('hi there')

    tree = as_ast(fn)
    # Replace at index 0 inside the first node of the tree, using the body of to_insert.
    replace_at(0, tree.body[0], to_insert.get_body())

    # Redefined under the same name; the expected code is rendered from this version.
    def fn():
        print(10)

    expected_code = as_str(fn)
    assert unparse(tree).strip() == expected_code
示例#15
0
def test_extend_tree():
    """Expand an ``extend(y)`` marker with the body of ``to_extend`` and compare the code."""
    source = '''
x = 1
extend(y)
    '''
    expected = '''
x = 1
y = 5
    '''
    tree = ast.parse(source)
    extend_tree(tree, {'y': to_extend.get_body()})
    actual = unparse(tree)
    # Surrounding whitespace is irrelevant for the comparison.
    assert actual.strip() == expected.strip()
示例#16
0
 def test_generalize_examples(self):
     """Parse each C example, generalize its tree and sanity-check both trees.

     The dump and the unparsed code of the generalized tree are logged at
     debug level.
     """
     code_reader = CodeReader()
     parser = C99Parser()
     ast_generalizer = CAstGeneralizer()
     for path in EXAMPLES_C11_FILES:
         c_tree = parser.parse(code_reader.read_file(path), path)
         basic_check_c_ast(self, path, c_tree)
         python_tree = ast_generalizer.generalize(c_tree)
         basic_check_python_ast(self, path, python_tree)
         _LOG.debug('%s', typed_astunparse.dump(python_tree))
         _LOG.debug('%s', typed_astunparse.unparse(python_tree))
示例#17
0
 def _unsupported_syntax(self, tree):
     """Emit an ``unsupported_syntax`` marker and raise SyntaxError for ``tree``.

     The error message contains the node class name, the dump of the node,
     its Python rendering (or ``invalid`` when unparsing it fails with
     AttributeError) and the name of the output language.
     """
     try:
         python_code = '"""{}"""'.format(typed_astunparse.unparse(tree).strip())
     except AttributeError:
         python_code = 'invalid'
     self.fill('unsupported_syntax')
     message = 'unparsing {} like """{}""" ({} in Python) is unsupported for {}'.format(
         tree.__class__.__name__, typed_ast3.dump(tree), python_code, self.lang_name)
     raise SyntaxError(message)
def as_init_str(init_args):
    """
    Create the __init__ argument string by using unparse in ast.

    Args:
        init_args: The ast args object of the init arguments.

    Returns:
        The unparsed arguments with everything up to and including the first
        comma dropped, stripped, and with every double space removed.
        An empty string when the unparsed arguments contain no comma, i.e.
        when there is nothing after the first argument.
    """
    unparsed = ast_unparse.unparse(init_args).strip()
    # partition() yields an empty remainder when there is no comma, where
    # indexing the result of split() used to raise IndexError.
    _first, _comma, remainder = unparsed.partition(",")
    return remainder.strip().replace("  ", "")
 def test_generalize_examples(self):
     """Parse each C++ example, generalize its tree and sanity-check both trees.

     Generalization and the checks that follow it run in a per-path subTest.
     A new generalizer scoped to the path is created for every file.
     """
     code_reader = CodeReader()
     parser = CppParser()
     for path in EXAMPLES_CPP14_FILES:
         ast_generalizer = CppAstGeneralizer(scope={'path': path})
         cpp_tree = parser.parse(code_reader.read_file(path), path)
         basic_check_cpp_ast(self, path, cpp_tree)
         with self.subTest(path=path):
             python_tree = ast_generalizer.generalize(cpp_tree)
             basic_check_python_ast(self, path, python_tree)
             _LOG.debug('%s', typed_astunparse.dump(python_tree))
             _LOG.debug('%s', typed_astunparse.unparse(python_tree))
示例#20
0
 def test_generalize_examples(self, input_path):
     """Parse one C++ file, generalize its tree under a timer and check the result.

     The elapsed generalization time is logged at info level; the dump and the
     unparsed code of the generalized tree are logged at debug level.
     """
     source = CodeReader().read_file(input_path)
     cpp_ast = CppParser().parse(source, input_path)
     basic_check_cpp_ast(self, input_path, cpp_ast)
     ast_generalizer = CppAstGeneralizer(scope={'path': input_path})
     # Dots in the file name are replaced so they do not act as separators in the timer name.
     timer_name = 'generalize.{}'.format(input_path.name.replace('.', '_'))
     with _TIME.measure(timer_name) as timer:
         syntax = ast_generalizer.generalize(cpp_ast)
     basic_check_python_ast(self, input_path, syntax)
     _LOG.info('generalized "%s" in %fs', input_path, timer.elapsed)
     _LOG.debug('%s', typed_astunparse.dump(syntax))
     _LOG.debug('%s', typed_astunparse.unparse(syntax))
 def test_unparse_examples(self):
     """Unparse the trees of verified and unverified examples and compare with their code."""
     all_examples = itertools.chain(
         EXAMPLES.items(), UNVERIFIED_EXAMPLES.items())
     for description, example in all_examples:
         for mode in MODES:
             tree = example['trees'][mode]
             if tree is None:
                 continue
             with self.subTest(description=description):
                 unparsed = typed_astunparse.unparse(tree)
                 _LOG.debug('%s', unparsed)
                 self.assertEqual(
                     unparsed.strip(), example['code'], msg=(description, mode))
示例#22
0
 def visit_node(self, node):
     """Append an ``_observe`` call after assignments to instrumented attributes.

     Only assignments with exactly one target that is an attribute access are
     considered, and only when the source text of that target is a key of
     ``instrumented_targets``. Any other node is returned unchanged; a matching
     node is returned as a two-element list holding the node and the new call.
     """
     is_candidate = (
         isinstance(node, ast.Assign)
         and len(node.targets) == 1
         and isinstance(node.targets[0], ast.Attribute))
     if not is_candidate:
         return node

     target_code = typed_astunparse.unparse(node.targets[0]).strip()
     if target_code not in instrumented_targets:
         return node

     obj, attr = instrumented_targets[target_code]
     call_source = 'protonn.parameters.core._observe({}, {}, {})'.format(
         repr(target_code), obj, repr(attr))
     observe_call = ast.parse(call_source, mode='eval').body
     return [node, ast.Expr(observe_call)]
示例#23
0
    def test_many_dump_roundtrips(self):
        """Keep the dump text stable across four consecutive parse/unparse cycles.

        The dump of each available example tree is parsed as source code and
        unparsed again; with newlines and spaces removed, the result must equal
        the stored dump after every cycle.
        """
        for description, example in EXAMPLES.items():
            for mode in MODES:
                tree = example['trees'][mode]
                if tree is None:
                    continue

                text = typed_astunparse.dump(tree)
                for _ in range(4):
                    reparsed = typed_ast.ast3.parse(source=text, mode=mode)
                    text = typed_astunparse.unparse(reparsed)
                    _LOG.debug('%s', text)
                    compact = text.replace('\n', '').replace(' ', '')
                    self.assertEqual(compact, example['dumps'][mode], msg=(description, mode))
示例#24
0
 def test_files(self):
     """Roundtrip each file in PATHS and compare the trees before and after.

     A SyntaxError while parsing the unparsed code is reported as a test
     failure that names the offending file.
     """
     for path in PATHS:
         with open(path, 'r', encoding='utf-8') as py_file:
             original_code = py_file.read()
         original_tree = typed_ast.ast3.parse(source=original_code, filename=path)
         unparsed = typed_astunparse.unparse(original_tree)
         try:
             reparsed_tree = typed_ast.ast3.parse(source=unparsed)
         except SyntaxError as err:
             self.fail(msg='bad syntax after unparsing "{}"\n{}'.format(path, err))
         self.assertEqual(
             typed_ast.ast3.dump(original_tree, include_attributes=False),
             typed_ast.ast3.dump(reparsed_tree, include_attributes=False),
             msg=path)
示例#25
0
    def test_many_roundtrips(self):
        """Preserve the code across four consecutive unparse/parse cycles.

        Starting from each available example tree, the unparsed and stripped
        code must equal the stored example code in every cycle; the code is
        then parsed again to produce the tree for the next cycle.
        """
        for description, example in EXAMPLES.items():
            for mode in MODES:
                tree = example['trees'][mode]
                if tree is None:
                    continue

                for _ in range(4):
                    code = typed_astunparse.unparse(tree)
                    _LOG.debug('%s', code)
                    self.assertEqual(code.strip(), example['code'], msg=(description, mode))
                    tree = typed_ast.ast3.parse(source=code, mode=mode)
                    tree = typed_ast.ast3.parse(source=code, mode=mode)
示例#26
0
def transform(path: str, code: str, target: CompilationTarget) -> str:
    """Apply every transformer that is relevant for the passed target.

    The code is parsed and unparsed once per transformer, also for the
    transformers that are not applied, so each step sees the output of the
    previous one.

    Args:
        path: file name handed to the parser and used in error reports.
        code: source code to transform.
        target: compilation target; a transformer is applied when its own
            target is greater than or equal to it.

    Returns:
        The transformed code passed through ``fix_code``.

    Raises:
        TransformationError: when unparsing a transformed tree fails.
    """
    from ..exceptions import TransformationError

    for transformer in transformers:
        tree = ast.parse(code, path)
        if transformer.target >= target:
            transformer().visit(tree)
        # Only ordinary errors are wrapped; KeyboardInterrupt and SystemExit
        # must propagate untouched instead of becoming a TransformationError.
        try:
            code = unparse(tree)
        except Exception as exc:
            raise TransformationError(path, transformer, dump(tree),
                                      format_exc()) from exc

    return fix_code(code)
 def test_files(self):
     """Roundtrip each file in PATHS and compare the trees, one subTest per file.

     ``dataclasses.py`` is skipped when running on Python 3.7.
     """
     on_python_37 = sys.version_info[:2] == (3, 7)
     for path in PATHS:
         if on_python_37 and pathlib.Path(path).name == 'dataclasses.py':
             continue
         with open(path, 'r', encoding='utf-8') as py_file:
             original_code = py_file.read()
         original_tree = typed_ast.ast3.parse(source=original_code, filename=path)
         unparsed = typed_astunparse.unparse(original_tree)
         with self.subTest(path=path):
             reparsed_tree = typed_ast.ast3.parse(source=unparsed)
             self.assertEqual(
                 typed_ast.ast3.dump(original_tree, include_attributes=False),
                 typed_ast.ast3.dump(reparsed_tree, include_attributes=False),
                 msg=path)
示例#28
0
    def test_many_dump_roundtrips(self):
        """Keep the dump text stable across four consecutive parse/unparse cycles.

        The dump of each available example tree is treated as source code:
        it is parsed and unparsed repeatedly, and after each cycle the text,
        with newlines and spaces removed, must match the stored dump.
        """
        for description, example in EXAMPLES.items():
            for mode in MODES:
                example_tree = example['trees'][mode]
                if example_tree is None:
                    continue

                current = typed_astunparse.dump(example_tree)
                for _ in range(4):
                    current = typed_astunparse.unparse(
                        typed_ast.ast3.parse(source=current, mode=mode))
                    _LOG.debug('%s', current)
                    squeezed = current.replace('\n', '').replace(' ', '')
                    self.assertEqual(
                        squeezed, example['dumps'][mode], msg=(description, mode))
示例#29
0
def annotate(fname, outfname, line2shape, debug=False):
    """Annotate the source in ``fname`` and write the resulting code to ``outfname``.

    Args:
        fname: path of the Python source file to read.
        outfname: path of the file that receives the unparsed, annotated code.
        line2shape: data handed to the Annotator; printed when ``debug`` is set.
        debug: when true, print ``line2shape`` before the code is generated.
    """
    # Read inside a context manager so the input file is closed right away
    # instead of whenever the garbage collector gets to it.
    with open(fname) as src:
        tree = ast.parse(src.read())

    ann = Annotator(line2shape)
    tree = ann.visit(tree)

    if debug:
        print(line2shape)

    code = typed_astunparse.unparse(tree)
    print(f'Writing to annotated file {outfname}')
    with open(outfname, 'w') as f:
        f.write(code)
示例#30
0
 def _inline_call(self, call, replacers):
     """Build the statements that replace ``call`` with the body of the inlined function.

     Every statement of ``self._inlined_function.body`` is deep-copied,
     augmented and passed through all ``replacers`` in order; statements that
     end up as None are dropped. In verbose mode the result is wrapped in
     a pair of marker comments.

     Returns a single statement when exactly one was produced, otherwise the
     list of statements.
     """
     call_code = typed_astunparse.unparse(call).strip()
     verbose = self._verbose
     result = []
     if verbose:
         result.append(
             horast_nodes.Comment(' inlined {}'.format(call_code), eol=False))
     for original in self._inlined_function.body:
         # Work on a copy so that the function being inlined stays untouched.
         inlined = st.augment(copy.deepcopy(original), eval_=False)
         for replacer in replacers:
             inlined = replacer.visit(inlined)
         if inlined is None:
             continue
         result.append(inlined)
     if verbose:
         result.append(
             horast_nodes.Comment(' end of inlined {}'.format(call_code), eol=False))
     _LOG.warning('inlined a call %s using replacers %s', call_code, replacers)
     assert result
     return result[0] if len(result) == 1 else result
示例#31
0
def insert_comment_tokens_approx(
        tree: typed_ast.ast3.AST, tokens: t.List[tokenize.TokenInfo]) -> typed_ast.ast3.AST:
    """Insert comments built from ``tokens`` into ``tree``, placed by line number.

    For each token an insertion index into the list of localizable nodes is
    computed by comparing line numbers only (columns are ignored), and it is
    recorded whether the comment shares its line with a node. The comments are
    then inserted one by one, last token first.

    Args:
        tree: tree that receives the comments.
        tokens: list of tokens to convert into comments.

    Returns:
        The tree returned by the last ``insert_in_tree`` call; when there are
        no tokens this is the input tree.
    """
    assert isinstance(tree, typed_ast.ast3.AST)
    assert isinstance(tokens, list)
    token_locations = get_token_locations(tokens)
    _LOG.debug('token locations: %s', token_locations)
    nodes = ast_to_list(tree, only_localizable=True)
    if not nodes and tokens:
        # Nothing to anchor the comments to: start over from an empty module.
        _LOG.debug('overwriting empty AST with simplest editable tree')
        tree = typed_ast.ast3.Module(body=[], type_ignores=[], lineno=1, col_offset=0)
        nodes = ast_to_list(tree, only_localizable=True)
    node_locations = get_ast_node_locations(nodes)
    _LOG.debug('node locations: %s', node_locations)
    # One iterator shared by all tokens: it only moves forward and is never
    # rewound, so every token continues where the previous one stopped.
    node_locations_iter = enumerate(node_locations)
    token_insertion_indices = []
    tokens_eol_status = []
    for token_index, token_location in enumerate(token_locations):
        eol_comment_here = False
        try:
            node_index, node_location = next(node_locations_iter)
        except StopIteration:
            # No nodes left: the index one past the last node is used.
            node_index = len(node_locations)
            node_location = None
        while node_location is not None:
            token_line, _ = token_location
            node_line, _ = node_location
            if node_line > token_line:
                # First node located below the token line has been reached.
                break
            if node_line == token_line:
                # The node shares the line with the token, so the comment counts
                # as end-of-line unless the following node is on that line too.
                eol_comment_here = True
                if node_index < len(node_locations) - 1:
                    next_node_line, _ = node_locations[node_index + 1]
                    if next_node_line == token_line:
                        eol_comment_here = False
                # if eol_comment_here:
                #    raise NotImplementedError(
                #        'code "{}" and comment "{}" in line {}'
                #        ' -- only whole line comments are currently supported'
                #        .format(typed_astunparse.unparse(nodes[node_index]).strip(),
                #                tokens[token_index].string, node_line))
            try:
                node_index, node_location = next(node_locations_iter)
            except StopIteration:
                node_index = len(node_locations)
                break
        tokens_eol_status.append(eol_comment_here)
        token_insertion_indices.append(node_index)
    _LOG.debug('token insertion indices: %s', token_insertion_indices)
    _LOG.debug('tree before insertion:\n"""\n%s\n"""', typed_astunparse.dump(tree))
    _LOG.debug('code before insertion:\n"""\n%s\n"""', typed_astunparse.unparse(tree).strip())
    # Tokens are processed in reverse order, i.e. the last comment is inserted first.
    for token_index, token_insertion_index in reversed(list(enumerate(token_insertion_indices))):
        token = tokens[token_index]
        eol = tokens_eol_status[token_index]
        comment = Comment.from_token(token, eol)
        if token_insertion_index == 0:
            # Comment goes before the very first node.
            anchor = nodes[token_insertion_index]
            before_anchor = True
        elif token_insertion_index == len(node_locations):
            # Comment goes after the very last node.
            anchor = nodes[-1]
            before_anchor = False
        else:
            # Comment goes after the node preceding the insertion index.
            anchor = nodes[token_insertion_index - 1]
            before_anchor = False
        _LOG.debug('inserting %s %s %s', comment, 'before' if before_anchor else 'after', anchor)
        tree = insert_in_tree(tree, comment, anchor=anchor, before_anchor=before_anchor)
    _LOG.debug('tree after insertion:\n"""\n%s\n"""', typed_astunparse.dump(tree))
    # _LOG.warning('code after insertion:\n"""\n%s\n"""', typed_astunparse.unparse(tree).strip())
    return tree