Example #1
    def run():
        asdl_text = open('asdl/lang/py/py_asdl.txt').read()
        grammar = ASDLGrammar.from_text(asdl_text)

        annot_file = 'data/django/all.anno'
        code_file = 'data/django/all.code'

        transition_system = PythonTransitionSystem(grammar)

        for idx, (src_query, tgt_code) in enumerate(zip(open(annot_file), open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code).body[0]
            gold_source = astor.to_source(python_ast)
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            hyp2 = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
                hyp = hyp.clone_and_apply_action(action)
                hyp2.apply_action(action)

            src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar))
            assert src_from_hyp == gold_source
            assert hyp.tree == hyp2.tree and hyp.tree is not hyp2.tree

            print(idx)
Example #2
    def output(cls, node1, node2, operator_name):
        """
        Compare the original source code and mutant.
        :param node1:
        :param node2:
        :return:
        """
        if not os.path.exists(os.path.curdir + '/output'):
            os.mkdir(os.path.curdir + '/output')
        dest_dir = os.path.curdir + '/output/'

        if not os.path.isfile(dest_dir + 'original.py'):
            # write the original code to a file
            # original_code = codegen.to_source(node1)
            original_code = astor.to_source(node1, add_line_information=True)
            filename = "original.py"
            path = os.path.join(dest_dir, filename)
            cls.write_to_file(path, original_code)

        # write the mutated code to a file
        # mutated_code = codegen.to_source(node2)
        mutated_code = astor.to_source(node2, add_line_information=True)
        filename = None
        while True:
            timestamp = str(int(time.time()))
            filename = operator_name + "_mutant_" + timestamp + ".py"
            if not os.path.isfile(dest_dir + filename):
                break
        path = os.path.join(dest_dir, filename)
        cls.write_to_file(path, mutated_code)
Example #3
def compare(input_src, expected_output_src, transformer_class):
    """
    Testing utility. Takes the input source and transforms it with
    `transformer_class`. It then compares the output with the given
    reference and throws an exception if they don't match.
    
    This method also deals with name-mangling.
    """
    
    uid = naming.UniqueIdentifierFactory()
    
    actual_root = ast.parse(unindent(input_src))
    EncodeNames().visit(actual_root)
    actual_root = transformer_class(uid).checked_visit(actual_root)
    actual_root = ast.fix_missing_locations(actual_root)
    compile(actual_root, "<string>", 'exec')
    actual_src = astor.to_source(actual_root)
    expected_root = ast.parse(unindent(expected_output_src))
    EncodeNames().visit(expected_root)
    expected_src = astor.to_source(expected_root)
    
    cmps = itertools.izip_longest(expected_src.splitlines(), actual_src.splitlines())
    for linenr, c in enumerate(cmps, 1):
        expected_line = c[0]
        actual_line = c[1]
        if expected_line != actual_line:
            sys.stderr.write(actual_src)
            sys.stderr.write("\n")
            raise AssertionError("Line %s differs. Expected %s but got %s." % (linenr, repr(expected_line), repr(actual_line)))
Example #4
def extract_grammar(code_file, prefix='py'):
    line_num = 0
    parse_trees = []
    for line in open(code_file):
        code = line.strip()
        parse_tree = parse(code)

        # leaves = parse_tree.get_leaves()
        # for leaf in leaves:
        #     if not is_terminal_type(leaf.type):
        #         print parse_tree

        # parse_tree = add_root(parse_tree)

        parse_trees.append(parse_tree)

        # sanity check
        ast_tree = parse_tree_to_python_ast(parse_tree)
        ref_ast_tree = ast.parse(canonicalize_code(code)).body[0]
        source1 = astor.to_source(ast_tree)
        source2 = astor.to_source(ref_ast_tree)

        assert source1 == source2

        # check rules
        # rule_list = parse_tree.get_rule_list(include_leaf=True)
        # for rule in rule_list:
        #     if rule.parent.type == int and rule.children[0].type == int:
        #         # rule.parent.type == str and rule.children[0].type == str:
        #         pass

        # ast_tree = tree_to_ast(parse_tree)
        # print astor.to_source(ast_tree)
            # print parse_tree
        # except Exception as e:
        #     error_num += 1
        #     #pass
        #     #print e

        line_num += 1

    print 'total line of code: %d' % line_num

    grammar = get_grammar(parse_trees)

    with open(prefix + '.grammar.txt', 'w') as f:
        for rule in grammar:
            rule_str = repr(rule)
            f.write(rule_str + '\n')

    with open(prefix + '.parse_trees.txt', 'w') as f:
        for tree in parse_trees:
            f.write(tree.__repr__() + '\n')

    return grammar, parse_trees
Example #5
def process_query(query, code):
    from parse import code_to_ast, ast_to_tree, tree_to_ast, parse
    import astor
    str_count = 0
    str_map = dict()

    match_count = 1
    match = QUOTED_STRING_RE.search(query)
    while match:
        str_repr = '_STR:%d_' % str_count
        str_literal = match.group(0)
        str_string = match.group(2)

        match_count += 1

        # if match_count > 50:
        #     return
        #

        query = QUOTED_STRING_RE.sub(str_repr, query, 1)
        str_map[str_literal] = str_repr

        str_count += 1
        match = QUOTED_STRING_RE.search(query)

        code = code.replace(str_literal, '\'' + str_repr + '\'')

    # clean the annotation
    # query = query.replace('.', ' . ')

    for k, v in str_map.iteritems():
        if k == '\'%s\'' or k == '\"%s\"':
            query = query.replace(v, k)
            code = code.replace('\'' + v + '\'', k)

    # tokenize
    query_tokens = nltk.word_tokenize(query)

    new_query_tokens = []
    # break up function calls
    for token in query_tokens:
        new_query_tokens.append(token)
        i = token.find('.')
        if 0 < i < len(token) - 1:
            new_tokens = ['['] + token.replace('.', ' . ').split(' ') + [']']
            new_query_tokens.extend(new_tokens)

    # check if the code compiles
    tree = parse(code)
    ast_tree = tree_to_ast(tree)
    astor.to_source(ast_tree)

    return new_query_tokens, code, str_map
Example #6
def _log_failure(arg_num, msg=None):
    """ Retrace stack and log the failed expresion information """

    # stack() returns a list of frame records
    #   0 is the _log_failure() function
    #   1 is the expect() function 
    #   2 is the function that called expect(), that's what we want
    #
    # a frame record is a tuple like this:
    #   (frame, filename, line, funcname, contextlist, index)
    # we're only interested in the first 4. 
    frame,  filename, file_lineno, funcname = inspect.stack()[2][:4]
    # Note that a frame object should be deleted once used to be safe and stop possible 
    # memory leak from circular referencing 
    try:
        frame_source_lines, frame_start_lineno = (inspect.getsourcelines(frame)) 
    finally:
        del frame

    filename = os.path.basename(filename)

    # Build abstract syntax tree from source of frame
    source_ast = ast.parse(''.join(frame_source_lines))

    # Locate the executed expect function 
    func_body = source_ast.body[0].body

    map_lineno_to_node = {}
    for idx, node in enumerate(func_body):
        map_lineno_to_node[node.lineno] = node
    
    last_lineno = file_lineno - frame_start_lineno + 1

    element_idx = [x for x in map_lineno_to_node.keys() if x <= last_lineno]
    element_idx = max(element_idx)

    expect_function_ast = map_lineno_to_node[element_idx]

    # Return the source code of the numbered argument
    arg = expect_function_ast.value.args[arg_num]
    line = arg.lineno
    if isinstance(arg, (ast.Tuple, ast.List)):
        expr = astor.to_source(arg.elts[0])
    else:
        expr = astor.to_source(arg)

    filename = os.path.basename(filename)

    failure_info = {'file': filename, 'line': line, 'funcname': funcname, 'msg': msg, 'expression': expr}

    _failed_expectations.append(failure_info)
Example #7
def test_more_captures():
    name_types = (ast.Name, ast.arg) if six.PY3 else ast.Name

    @compile_template
    def map_lambda(var=name_types, body=ast.expr, seq=ast.expr):
        map(lambda var: body, seq)

    @get_body_ast
    def tree():
        squares = map(lambda x: x ** 2, range(10))

    m = match(map_lambda, tree)[0]
    assert astor.to_source(m.captures['body']).strip() == '(x ** 2)'
    assert astor.to_source(m.captures['seq']).strip() == 'range(10)'
Example #8
 def compare(self, src, expected_src):
     actual_root = ast.parse(utils.unindent(src))
     scoping.ScopeAssigner().visit(actual_root)
     EncodeScopeInIdentifier().visit(actual_root)
     actual_src = astor.to_source(actual_root)
     
     expected_root = ast.parse(utils.unindent(expected_src))
     expected_src = astor.to_source(expected_root)
             
     cmps = itertools.izip_longest(expected_src.splitlines(), actual_src.splitlines())
     for linenr, c in enumerate(cmps, 1):
         expected_line = c[0]
         actual_line = c[1]
         self.assertEqual(expected_line, actual_line, "Line %s differs. Expected %s but got %s." % (linenr, repr(expected_line), repr(actual_line)))
Example #9
def parse_bbscript(f,basename,filename,filesize):
    global commandDB,astRoot,charName,j,MODE
    BASE = f.tell()
    astRoot = Module(body=[])
    j = OrderedDict()
    j["Functions"] = []
    j["FunctionsPy"] = []
    charName = filename[-6:-4]
    FUNCTION_COUNT, = struct.unpack(MODE+"I",f.read(4))
   # f.seek(BASE+4+0x20)
   # initEnd, = struct.unpack(MODE+"I",f.read(4))
   # initEnd = BASE + initEnd+4+0x24*FUNCTION_COUNT
   # initEnd = BASE+filesize
    f.seek(BASE+4+0x24*(FUNCTION_COUNT))
    parse_bbscript_routine(f,os.path.getsize(f.name))
    '''
    for i in range(0,FUNCTION_COUNT):
        f.seek(BASE+4+0x24*i)
        FUNCTION_NAME = f.read(0x20).split("\x00")[0]
        if log: log.write("\n#---------------{0} {1}/{2}\n".format(FUNCTION_NAME,i,FUNCTION_COUNT))
        FUNCTION_OFFSET, = struct.unpack(MODE+"I",f.read(4))
        f.seek(BASE+4+0x24*FUNCTION_COUNT+FUNCTION_OFFSET)
        parse_bbscript_routine(f)
    '''
    if len(sys.argv) == 3:
        outpath = os.path.join(sys.argv[2],filename[:-4] + '.py')
    else:
        outpath = filename[:-4] + '.py'
    py = open(outpath,"wb")
    py.write(astor.to_source(astRoot))
    py.close()
    return filename,j
Example #10
def canonicalize_hs_example(query, code):
    query = re.sub(r'<.*?>', '', query)
    query_tokens = nltk.word_tokenize(query)

    code = code.replace('§', '\n').strip()

    # sanity check
    parse_tree = parse_raw(code)
    gold_ast_tree = ast.parse(code).body[0]
    gold_source = astor.to_source(gold_ast_tree)
    ast_tree = parse_tree_to_python_ast(parse_tree)
    pred_source = astor.to_source(ast_tree)

    assert gold_source == pred_source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, pred_source)

    return query_tokens, code, parse_tree
Example #11
    def canonicalize_raw_django_oneliner(code):
        # use the astor-style code
        code = Django.canonicalize_code(code)
        py_ast = ast.parse(code).body[0]
        code = astor.to_source(py_ast).strip()

        return code
Example #12
 def log_mutant(self, active_file, logger):
     """ Prints a one-line summary to highlight the difference between the original code and the mutant
     split('\n')[0] is used to truncate if/elif mutation instances (entire if sections were printed before)
     """
     logger.info("{0} - Line {1}".format(active_file, self.line_no))
     logger.info("Original: {0}".format(self.original_source.split('\n')[0]))
     logger.info("Mutant  : {0}".format(astor.to_source(self.base_node)).split('\n')[0])
Example #13
    def startJob(self, body):
        project = json.loads(body)

        os.mkdir(self.job_dir)

        # sprit_id = project['sprite_idx']
        block_id = project['block_idx']

        # Write uploaded program
        xml = os.path.join(self.job_dir, 'job.xml')
        with open(xml, 'w') as file:
            file.write(project['project'].encode('utf-8'))

        # Parse and write python program
        p = parser.parses(project['project'].encode('utf-8'))
        ctx = p.create_context()
        file_ast = p.to_ast(ctx, 'main_%s' % block_id)
        code = astor.to_source(file_ast)
        program = os.path.join(self.job_dir, 'job.py')
        with open(program, 'w') as file:
            file.write(code)
        self.job_process = JobProcess(self, self.id)
        reactor.spawnProcess(
            self.job_process, sys.executable,
            [sys.executable, program], env=os.environ)
Example #14
 def visit_If(self, node):
   """Converts "if A: B" to "A <= B"."""
   if not self._auto_constrain:
     return self.generic_visit(node)
   if len(node.body) == 0:
     _fail(node, msg='If statement missing body expressions')
   test = self.visit(node.test)
   body = self.visit_if_body(node.body)
   orelse = self.visit_if_body(node.orelse)
   assignments, body, orelse = _collect_conditional_assignments(
       test, body, orelse)
   constraints, body, orelse = _collect_conditional_constraints(
       test, body, orelse)
   if body or orelse:
     remaining = ast.If(test=node.test, body=body, orelse=orelse)
     _fail(remaining, msg='if statement expressions unconverted')
   result = []
   for assignment in assignments:
     result.append(self.visit(assignment))
   for constraint in constraints:
     result.append(self._constrain_expr(self.visit(constraint)))
   return ast.If(
       test=ast.Str(s=astor.to_source(node.test).replace('\n', ' ').strip()),
       body=result,
       orelse=[],
   )
Example #15
def parse_bbscript(f,basename,dirname):
    global commandDB,astRoot,j,MODE
    BASE = f.tell()
    astRoot = Module(body=[])
    j = OrderedDict()
    j["Functions"] = []
    j["FunctionsPy"] = []
    f.seek(0x30)
    filesize = struct.unpack(MODE+"I",f.read(4))[0]
    f.seek(0x38)
    FUNCTION_COUNT, = struct.unpack(MODE+"I",f.read(4))
    f.seek(0x24*(FUNCTION_COUNT),1)
    parse_bbscript_routine(f, filesize + 0x38)
    '''
    for i in range(0,FUNCTION_COUNT):
        f.seek(BASE+4+0x24*i)
        FUNCTION_NAME = f.read(0x20).split("\x00")[0]
        if log: log.write("\n#---------------{0} {1}/{2}\n".format(FUNCTION_NAME,i,FUNCTION_COUNT))
        FUNCTION_OFFSET, = struct.unpack(MODE+"I",f.read(4))
        f.seek(BASE+4+0x24*FUNCTION_COUNT+FUNCTION_OFFSET)
        parse_bbscript_routine(f)
    '''
    py = open(os.path.join(dirname, basename) + ".py","w")
    py.write(astor.to_source(astRoot))
    py.close()
    return j
Example #16
def _fail(node, msg='Visit error'):
  try:
    raise NotImplementedError('%s (in ast.%s). Source:\n\t\t\t%s' % (
      msg, node.__class__.__name__, astor.to_source(node)))
  except AttributeError:  # Astor was unable to convert the source.
    raise NotImplementedError('%s (in ast.%s).' % (
      msg, node.__class__.__name__))
Example #17
 def execute(self, src, output_contains=None):
     src = utils.unindent(src)
     
     expected = self._run(src)
     
     node = ast.parse(src)
     node = saneitizer.Saneitizer().process(node)
     naming.MakeIdsValid().visit(node)
     
     transformed_code = astor.to_source(node)
     
     pydron_builtins = "from pydron.translation.builtins import *"
     transformed_code = pydron_builtins + "\n\n" + transformed_code
     
     try:
         # just to see if it compiles
         compile(node, "[string]", 'exec')
         
         # we actually use the source code to run
         actual = self._run(transformed_code)
         
         self.assertEqual(actual, expected)
         if output_contains:
             self.assertIn(output_contains, actual)
     except:
         sys.stderr.write(transformed_code)
         sys.stderr.write("\n\n")
         raise
Example #18
 def assertAstSourceEqual(self, srctxt):
     """This asserts that the reconstituted source
        code is identical to the original source code.
        This is a much stronger statement than assertAstEqual,
        which may not always be appropriate.
     """
     srctxt = canonical(srctxt)
     self.assertEqual(astor.to_source(ast.parse(srctxt)).rstrip(), srctxt)
Example #19
    def hyp_correct(self, hyp, example):
        ref_code = example.tgt_code
        ref_py_ast = ast.parse(ref_code).body[0]
        ref_reformatted_code = astor.to_source(ref_py_ast).strip()

        ref_code_tokens = tokenize_code(ref_reformatted_code)
        hyp_code_tokens = tokenize_code(hyp.code)

        return ref_code_tokens == hyp_code_tokens
Example #20
 def test_pass_arguments_node(self):
     source = textwrap.dedent("""\
     j = [1, 2, 3]
     def test(a1, a2, b1=j, b2='123', b3={}, b4=[]):
         pass""")
     root_node = ast.parse(source)
     arguments_node = [n for n in ast.walk(root_node)
                       if isinstance(n, ast.arguments)][0]
     self.assertEqual(astor.to_source(arguments_node),
                      "a1, a2, b1=j, b2='123', b3={}, b4=[]")
Example #21
    def __parse_abort(self, node):
        status, n = node.value.args
        help_ = list(filter(lambda x: x.arg == 'help_text', node.value.keywords))
        message = astor.to_source(n)
        if type(n) is ast.Str:
            message = n.s

        error = OrderedDict({
            'status': status.n,
            'message': message
        })

        if any(help_):
            help_text = astor.to_source(help_[0].value)
            if type(help_[0].value) is ast.Str:
                help_text = help_[0].value.s

            error['help'] = help_text
        self.errors.append(error)
Example #22
def print_source(my_ast, source_file='out/inline'):
	try:
		source = codegen.to_source(my_ast)
		if os.path.exists('out'):
			open(source_file + ".py", 'wt').write(source)
		print(source)  # => CODE
	except:
		# raise
		print("SOURCE NOT STANDARD CONFORM")
		import traceback
		traceback.print_exc()  # backtrace
Example #23
 def assertAstEqual(self, srctxt):
     """This asserts that the reconstituted source
        code can be compiled into the exact same AST
        as the original source code.
     """
     srctxt = canonical(srctxt)
     srcast = ast.parse(srctxt)
     dsttxt = astor.to_source(srcast)
     dstast = ast.parse(dsttxt)
     srcdmp = astor.dump_tree(srcast)
     dstdmp = astor.dump_tree(dstast)
     self.assertEqual(dstdmp, srcdmp)
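The helper above relies on the parse, to_source, reparse round trip preserving the tree; a minimal standalone sketch of the same check, assuming only ast and astor:

import ast
import astor

# Regenerate source with astor, reparse it, and compare the dumped trees,
# which is the same technique assertAstEqual uses above.
src = "x = 1 + 2\n"
regenerated = astor.to_source(ast.parse(src))
assert astor.dump_tree(ast.parse(regenerated)) == astor.dump_tree(ast.parse(src))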
Example #24
 def assertCompiles(self, expr, code):
     ec = ExpressionCompiler({'foo', 'baz'})
     py_expr = ec.visit(to_expr(expr, {}))
     first = astor.to_source(py_expr)
     if not PY3:
         first = first.replace("u'", "'")
     second = dedent(code).strip()
     if first != second:
         msg = ('Compiled code is not equal:\n\n{}'
                .format('\n'.join(difflib.ndiff(first.splitlines(),
                                                second.splitlines()))))
         raise self.failureException(msg)
Example #25
    def test_parse(self):
        snap_stdlib.cleanReport()
        if self.script:
            document = self.xml.format(
                script=self.script,
                block="".join(self.blocks) or self.block)
        else:
            document = self.xml
        parser = snap_parser.parses(document)
        ctx = parser.create_context()
        script = parser.to_ast(ctx)
        ast.fix_missing_locations(script)
        try:
            code = compile(script, '<string>', 'exec')
            module = imp.new_module(__name__ + '.compiled_block')

            exec code in module.__dict__
            module.main_0()

            if self.report:
                self.assertEqual(
                    module.stdlib._report, self.report,
                    "%s != %s\ncode::\n\n%s" % (module.stdlib._report,
                                                self.report,
                                                astor.to_source(script)))
            if self.vars:
                self.assertEqual(
                    module._globals, self.vars,
                    "%s != %s\ncode::\n\n%s" % (module._globals,
                                                self.vars,
                                                astor.to_source(script)))
        except:
            print "Generated AST object\n", ast.dump(script)
            parsed = ast.parse(astor.to_source(script))
            print "Parsed AST object\n", ast.dump(parsed)
            print "Script\n", astor.to_source(script)
            raise

        self.assertTrue(len(parser.stack) == 0, parser.stack)
Example #26
 def get_subscript_or_attribute(self, node):
     var = None
     if isinstance(node, ast.Subscript) or isinstance(node, ast.Attribute):
         src = astor.to_source(node)
         idx = src.rfind('[') if isinstance(node, ast.Subscript) else src.rfind('.')
         var_name = src[:idx]
         obj_var = self.get_obj_var(var_name)
         var = self.get_var(src)
         obj_var.lines.append(node.lineno)
         var.lines.append(node.lineno)
         if obj_var.children.get(src) is None:
             obj_var.children[src] = var
     return var
Example #27
def getHistory(history):
    from itertools import islice
    if history is None and '__IPYTHON__' in __builtins__:
        import IPython
        ip = IPython.core.getipython.get_ipython()
        history = ip.history_manager.input_hist_parsed  # @UndefinedVariable

    newhistory = []
    for line in islice(history, len(history)-1):
        with ignored(SyntaxError):
            for lline in ast.parse(line).body:
                newhistory.append(astor.to_source(lline))
    return newhistory
Example #28
 def visit_Expr(self, node):
   """
     Visits expressions and check if they are in the form of either 
     `environment.define` or `environment.undefine` properly stores the 
     arguments definition as string.
   """
   value = node.value
   if isinstance(value, ast.Call):
     function = value.func
     if isinstance(function, ast.Attribute):
       attribute = function.value
       if isinstance(attribute, ast.Name):
         name = attribute.id
         if name == 'environment' and function.attr == 'define' and not value.keywords:
           if not len(value.args) == 2:
             message = (
               'Not enough arguments for environment definition. Function '
               'name and alias are required.'
             )
             raise Exception(message)
           func_name = value.args[0].id
           func_alias = value.args[1].s
           function_node = self.function_dict[func_name]
           function_string = astor.to_source(function_node)
           self.environment_setup_dict[func_name] = {
             "code": function_string,
             "alias": func_alias
           }
         elif name == 'environment' and function.attr == 'define' and value.keywords:
           for keyword in value.keywords:
             arg_name = keyword.arg
             arg_value_node = keyword.value
             
             # The value can be a number, string or name. We need to handle 
              # them separately. This dict trick was used to avoid the very
             # ugly if.
             node_value_dict = {
               ast.Num: lambda node: str(node.n),
               ast.Str: lambda node: node.s,
               ast.Name: lambda node: node.id
             }
             arg_value = node_value_dict[type(arg_value_node)](arg_value_node)
             self.environment_var_dict[arg_name] = arg_value
         elif name == 'environment' and function.attr == 'undefine':
           func_alias = value.args[0].s
           self.environment_remove_list.append(func_alias)
         elif name == 'environment' and function.attr == 'clearAll':
           self.environment_clear_all = True
         elif name == 'environment'and function.attr == 'showSetup':
           self.show_environment_setup = True
   return node
Example #29
def get_trait_definition(parent, trait_name):
    """ Retrieve the Trait attribute definition from the source file.

    Parameters
    ----------
    parent :
        The module or class where the trait is defined.

    trait_name : string
        The name of the trait.

    Returns
    -------
    definition : string
        The trait definition from the source.

    """
    # Get the class source.
    source = inspect.getsource(parent)
    nodes = ast.parse(source)

    if not inspect.ismodule(parent):
        for node in ast.iter_child_nodes(nodes):
            if isinstance(node, ClassDef):
                parent_node = node
                break
        else:
            message = 'Could not find class definition {0} for {1}'
            raise DefinitionError(message.format(parent, trait_name))
    else:
        parent_node = nodes

    # Get the container node(s)
    targets = collections.defaultdict(list)
    for node in ast.walk(parent_node):
        if isinstance(node, Assign):
            target = trait_node(node, trait_name)
            if target is not None:
                targets[node.col_offset].append((node, target))

    if len(targets) == 0:
        message = 'Could not find trait definition of {0} in {1}'
        raise DefinitionError(message.format(trait_name, parent))
    else:
        # keep the assignment with the smallest column offset
        assignments = targets[min(targets)]
        # we always get the last assignment in the file
        node, name = assignments[-1]

    return astor.to_source(node.value).strip()
Example #30
    def inject_into_main_urls(self):
        urls_file = open(os.path.join(self.PROJECT_DIR,"urls.py"), "r+")
        urls_ast = ast.parse(urls_file.read())
        is_api_url_allredy_injected = False
        
        class CheckIsUrlsPatched(ast.NodeTransformer):
            is_api_url_allredy_injected = False
            def visit_Call(self, node):
                if len(node.args) > 0 and node.args[0].s == '^api/v1':
                    self.is_api_url_allredy_injected = True
                return node

        class MainUrlTransformer(ast.NodeTransformer):
            def visit_Assign(self, node):
                if node.targets[0].id == "urlpatterns":
                    api_url_elt = ast.Call(
                        func=ast.Name(id='url', ctx=ast.Load()),
                        args=[ast.Str('^api/v1'), 'include(api_urls.router.urls)'],
                        keywords = [], starargs=None, kwargs=None
                    )
                    api_url_elt.lineno = 5                    
                    node.value.elts.insert(2,api_url_elt)
                return node

        check_is_urls_patched_visitor = CheckIsUrlsPatched()
        check_is_urls_patched_visitor.visit(urls_ast)


        if check_is_urls_patched_visitor.is_api_url_allredy_injected:
            print_warn("  api_urls is already injected to main urls - SKIP")
            return

        urls_ast.body.insert(0, ast.Import(
            names=[
                ast.alias(name='%s.api_urls' % (self.BASE_APP,), asname='api_urls')
            ]
            )
        )

        MainUrlTransformer().visit(urls_ast)
        ast.fix_missing_locations(urls_ast)
        
        urls_file.seek(0)

        out_code = FormatCode(astor.to_source(urls_ast, add_line_information=False))[0]
                
        urls_file.write(out_code)
        urls_file.truncate()
        urls_file.close()
        print_ok("      %s/urls.py - MODIFIED" %(self.BASE_APP, ))
Example #31
def _(ra: ras.Select, ctx: Context):
    # Function + Loop body
    filter_ast = ast.parse(inspect.getsource(FILTER_FUNC))
    filter_func = filter_ast.body[0]
    filter_func.name = unique_name('filter')

    # Create comparison
    cmp_ast = single_row_expression_func(ra.operands[1], ctx)

    # Inject comparison function defintion into Loop body
    filter_func.body.insert(0, cmp_ast)

    for _import in ctx.imports_required:
        filter_func.body.insert(0, _import)

    filter_ast = ast_transformer.FindAndReplaceNames({
        '__cmp_func__': cmp_ast,
    }).visit(filter_ast)

    ast.fix_missing_locations(filter_ast)

    logger.debug(astor.to_source(filter_ast))

    return filter_ast
Example #32
def func(foo):
    from .impl import get_runtime
    src = remove_indent(inspect.getsource(foo))
    tree = ast.parse(src)

    func_body = tree.body[0]
    func_body.decorator_list = []

    visitor = ASTTransformer(is_kernel=False)
    visitor.visit(tree)
    ast.fix_missing_locations(tree)

    if get_runtime().print_preprocessed:
        import astor
        print('After preprocessing:')
        print(astor.to_source(tree.body[0], indent_with='  '))

    ast.increment_lineno(tree, inspect.getsourcelines(foo)[1] - 1)

    frame = inspect.currentframe().f_back
    exec(compile(tree, filename=inspect.getsourcefile(foo), mode='exec'),
         dict(frame.f_globals, **frame.f_locals), locals())
    compiled = locals()[foo.__name__]
    return compiled
Example #33
def update_args_of_func(node, dygraph_node, method_name):
    assert isinstance(node, gast.Call)
    if method_name not in ["__init__", "forward"]:
        raise ValueError(
            "The method name of class to update args should be '__init__' or 'forward'"
        )

    class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func))
    import paddle.fluid as fluid
    if method_name == "__init__" or eval(
            "issubclass({}, fluid.dygraph.Layer)".format(class_src)):
        full_args = eval("inspect.getargspec({}.{})".format(class_src,
                                                            method_name))
        full_args_name = [
            arg_name for arg_name in full_args[0] if arg_name != "self"
        ]
    else:
        full_args_name = []
    added_keywords = []
    for idx, arg in enumerate(node.args):
        added_keywords.append(gast.keyword(arg=full_args_name[idx], value=arg))

    node.args = []
    node.keywords = added_keywords + node.keywords
Example #34
    def build_Assert(ctx, node):
        extra_args = ast.List(elts=[], ctx=ast.Load())
        if node.msg is not None:
            if isinstance(node.msg, ast.Constant):
                msg = node.msg.value
            elif isinstance(node.msg, ast.Str):
                msg = node.msg.s
            elif StmtBuilder._is_string_mod_args(node.msg):
                msg = build_expr(ctx, node.msg)
                msg, extra_args = StmtBuilder._handle_string_mod_args(ctx, msg)
            else:
                raise ValueError(
                    f"assert info must be constant, not {ast.dump(node.msg)}")
        else:
            import astor
            msg = astor.to_source(node.test)
        node.test = build_expr(ctx, node.test)

        new_node = parse_stmt('ti.ti_assert(0, 0, [])')
        new_node.value.args[0] = node.test
        new_node.value.args[1] = parse_expr("'{}'".format(msg.strip()))
        new_node.value.args[2] = extra_args
        new_node = ast.copy_location(new_node, node)
        return new_node
Example #35
def parse_nni_variable(code):
    """Parse `nni.variable` expression.
    Return the name argument and AST node of annotated expression.
    code: annotation string
    """
    name, call = parse_annotation_function(code, 'variable')

    assert len(call.args) == 1, 'nni.variable contains more than one arguments'
    arg = call.args[0]
    assert type(
        arg) is ast.Call, 'Value of nni.variable is not a function call'
    assert type(
        arg.func) is ast.Attribute, 'nni.variable value is not a NNI function'
    assert type(
        arg.func.value) is ast.Name, 'nni.variable value is not a NNI function'
    assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function'

    name_str = astor.to_source(name).strip()
    keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str))
    arg.keywords.append(keyword_arg)
    if arg.func.attr == 'choice':
        convert_args_to_dict(arg)

    return name, arg
Example #36
def parse(code, para, module):
    """Annotate user code.
    Return annotated code (str) if annotation detected; return None if not.
    code: original user code (str)
    """
    global para_cfg
    global prefix_name
    para_cfg = para
    prefix_name = module
    try:
        ast_tree = ast.parse(code)
    except Exception:
        raise RuntimeError('Bad Python code')

    transformer = Transformer()
    try:
        transformer.visit(ast_tree)
    except AssertionError as exc:
        raise RuntimeError('%d: %s' % (ast_tree.last_line, exc.args[0]))

    if not transformer.annotated:
        return None

    return astor.to_source(ast_tree)
Example #37
 def visit_Lambda(self, node):
     try:
         logging.info("Transforming Lambda: " + to_source(node))
     except:
         logging.info("Transforming Lambda: " + ast.dump(node))
     method_name = "_lambda_" + str(self.id_counter)
     SQLiteSerializer.serializeLambdaTransformation(self.filename, self.id_counter, node.lineno)
     self.methods.append(getLocalsFunction(method_name))
     self.methods.append(getLambdaMethod(
         method_name, node.body, node.args, self.filename, self.id_counter))
     self.globalvars.append(getGlobalVariable(method_name))
     self.id_counter += 1
     return Call(
         func=Name(
             id=method_name, ctx=Load()),
         args=[
             Call(
                 func=Name(id='locals', ctx=Load()),
                 args=[],
                 keywords=[]
             )
         ],
         keywords=[]
     )
Example #38
def replace_fluent_alias(source, fluent_mapping):
    fluent_mapping = {a: b for a, b in fluent_mapping}
    new_src = source
    for _ in range(100):  # 100 is a (random) big enough number
        replaced = False
        tree = ast.parse(new_src)
        for node in ast.walk(tree):
            if (isinstance(node, ast.Call)
                    and isinstance(node.func, ast.Attribute)
                    and isinstance(node.func.value, ast.Name)
                    and node.func.value.id == 'd2l'
                    and node.func.attr in fluent_mapping):
                new_node = ast.Call(
                    ast.Attribute(value=node.args[0],
                                  attr=fluent_mapping[node.func.attr]),
                    node.args[1:], node.keywords)
                new_src = new_src.replace(
                    ast.get_source_segment(new_src, node),
                    astor.to_source(new_node).rstrip())
                replaced = True
                break
        if not replaced:
            break
    return new_src
Example #39
def handle_function_def(filen, destination_tree, src_entry, update_args):
    """ Add or modify a 'def'

    Args:
        filen (str): The name of the file being modified
        destination_tree (ast): An ast generated from the destination file
        src_entry (ast node): An ast node found missing from the destination_tree
        update_args (bool): Whether to overwrite the destination function's arguments when they differ

    """
    found_by_name = False
    for i, dst_entry in enumerate(destination_tree.body):
        if isinstance(dst_entry, ast.FunctionDef):
            if dst_entry.name == src_entry.name:
                found_by_name = True
                if astor.to_source(dst_entry.args) != astor.to_source(
                        src_entry.args):
                    if update_args:
                        dst_entry.args = src_entry.args
                    else:
                        log_src = copy.copy(src_entry)
                        log_src.body = []
                        log_dst = copy.copy(dst_entry)
                        log_dst.body = []
                        PrintInColor.message(color='RED',
                                             action="warning",
                                             string=filen)
                        PrintInColor.diff(left=astor.to_source(log_src),
                                          right=astor.to_source(log_dst),
                                          fromfile="Codegen package",
                                          tofile="Project")
                elif astor.to_source(dst_entry.body[0]) != astor.to_source(
                        src_entry.body[0]):
                    if isinstance(src_entry.body[0], ast.Expr):
                        if isinstance(dst_entry.body[0], ast.Expr):
                            destination_tree.body[i].body[0] = src_entry.body[
                                0]
                        else:
                            destination_tree.body[i].body.insert(
                                0, src_entry.body[0])
    if not found_by_name:
        destination_tree.body.append(src_entry)
    return destination_tree
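The argument comparison in handle_function_def works because astor can render a bare ast.arguments node; a minimal standalone sketch of that check (the function names here are made up):

import ast
import astor

# Render each def's ast.arguments node and compare the resulting strings.
a = ast.parse("def f(x, y=1):\n    pass").body[0]
b = ast.parse("def f(x, y=2):\n    pass").body[0]
print(astor.to_source(a.args) == astor.to_source(b.args))  # False: defaults differ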
Example #40
def test_primitive_bool(assertion_to_ast):
    assertion = MagicMock(value=True)
    assertion_to_ast.visit_primitive_assertion(assertion)
    assert (astor.to_source(
        Module(body=assertion_to_ast.nodes)) == "assert var0 is True\n")
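For reference, a hand-built sketch of the kind of assert node the visitor is expected to emit in this case (constructed manually here, so the variable name var0 and the exact node layout are assumptions):

import ast
import astor

# Build "assert var0 is True" by hand and render it with astor.
# Python 3.8+: ast.Module takes a type_ignores list.
assert_node = ast.Assert(
    test=ast.Compare(
        left=ast.Name(id="var0", ctx=ast.Load()),
        ops=[ast.Is()],
        comparators=[ast.Constant(value=True)],
    ),
    msg=None,
)
print(astor.to_source(ast.Module(body=[assert_node], type_ignores=[])))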
Example #41
def test_not_none(assertion_to_ast):
    assertion = MagicMock(value=False)
    assertion_to_ast.visit_none_assertion(assertion)
    assert (astor.to_source(
        Module(body=assertion_to_ast.nodes)) == "assert var0 is not None\n")
Example #42
    elif dct['ast_type'] == "GtE":
        return ast.GtE()
    elif dct['ast_type'] == "Is":
        return ast.Is()
    elif dct['ast_type'] == "IsNot":
        return ast.IsNot()
    elif dct['ast_type'] == "In":
        return ast.In()
    elif dct['ast_type'] == "NotIn":
        return ast.NotIn()
    elif dct['ast_type'] == "comprehension":
        return ast.comprehension(dct["target"], dct["iter"], dct["ifs"])
    elif dct['ast_type'] == "ExceptHandler":
        return ast.ExceptHandler(dct["type"], dct["name"], dct["body"])
    elif dct['ast_type'] == "arguments":
        return ast.arguments(dct["args"], dct["vararg"], dct["kwarg"],
                             dct["defaults"])
    elif dct['ast_type'] == "keyword":
        return ast.keyword(dct["arg"], dct["value"])
    elif dct['ast_type'] == "alias":
        return ast.alias(dct["name"], dct["asname"])
    else:
        return dct


content = sys.stdin.read()

tree = json.loads(content, object_hook=as_ast)
#print ast.dump(tree)
print astor.to_source(tree)
Example #43
def test_test_case_to_ast_once(simple_test_case):
    visitor = tc_to_ast.TestCaseToAstVisitor()
    simple_test_case.accept(visitor)
    simple_test_case.accept(visitor)
    assert (astor.to_source(Module(body=visitor.test_case_asts[0])) ==
            "var0 = 5\nvar1 = module0.SomeType(var0)\nassert var1 == 3\n")
Example #44
def stateful_eval(expr, env, metadata, state, config):
    """
    Evaluate an expression with a given state.

    WARNING: State can be mutated. If you want to preserve a previous state,
    create a copy before passing it to this function.
    """

    metadata = {} if metadata is None else metadata
    state = {} if state is None else state
    env = LayeredMapping(
        env
    )  # We sometimes mutate env, so we make sure we do so in a local mutable layer.

    # Ensure that variable names in code are valid for Python's interpreter
    # If not, create new variable in mutable env layer, and update code.
    expr = sanitize_variable_names(expr, env)

    # Parse Python code
    code = ast.parse(expr, mode='eval')

    # Extract the nodes of the graph that correspond to stateful transforms
    stateful_nodes = {}
    for node in ast.walk(code):
        if isinstance(node, ast.Call) and getattr(env.get(
                node.func.id), '__is_stateful_transform__', False):
            stateful_nodes[astor.to_source(node).strip()] = node

    # Mutate stateful nodes to pass in state from a shared dictionary.
    for name, node in stateful_nodes.items():
        name = name.replace('"', r'\\\\"')
        if name not in state:
            state[name] = {}
        node.keywords.append(
            ast.keyword(
                'metadata',
                ast.parse(f'__FORMULAIC_METADATA__.get("{name}")',
                          mode='eval').body))
        node.keywords.append(
            ast.keyword(
                'state',
                ast.parse(f'__FORMULAIC_STATE__["{name}"]', mode='eval').body))
        node.keywords.append(
            ast.keyword('config',
                        ast.parse('__FORMULAIC_CONFIG__', mode='eval').body))

    # Compile mutated AST
    code = compile(ast.fix_missing_locations(code), '', 'eval')

    assert "__FORMULAIC_METADATA__" not in env
    assert "__FORMULAIC_STATE__" not in env
    assert "__FORMULAIC_CONFIG__" not in env

    # Evaluate and return
    return eval(code, {},
                LayeredMapping(
                    {
                        '__FORMULAIC_METADATA__': metadata,
                        '__FORMULAIC_CONFIG__': config,
                        '__FORMULAIC_STATE__': state
                    }, env))  # nosec
Example #45
 def _slot_item_name(self, node: ast.AST) -> Optional[str]:
     if isinstance(node, ast.Str):
         return node.s
     if isinstance(node, ast.Starred):
         return astor.to_source(node).strip()
     return None
Example #46
import astor, ast
import sys
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Rewrites programs.')
    parser.add_argument('-t', '--target', required=True)
    parser.add_argument("remaining", nargs="*")
    args = parser.parse_args()

    target = args.target
    sys.argv[1:] = args.remaining

    root = astor.parse_file(target)

    # implement rewriting routines here
    # make modifications to the AST

    modified = astor.to_source(root)
    with open(target, "w") as f:
        f.write(modified)
        f.close()
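The script above leaves the rewriting step as a placeholder; one possible routine to drop in is a small ast.NodeTransformer (purely illustrative, old_name and new_name are made up):

import ast

class RenameCalls(ast.NodeTransformer):
    """Illustrative rewrite: rename every call to old_name into new_name."""

    def visit_Call(self, node):
        self.generic_visit(node)
        if isinstance(node.func, ast.Name) and node.func.id == "old_name":
            node.func.id = "new_name"
        return node

# Inside the script one might then call:
#     RenameCalls().visit(root)
#     ast.fix_missing_locations(root)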
Example #47
def replace_calls(frame_groups: Dict[FrameID, List[NodeInfo]]):
    """Replaces call exprs with intermediate variables."""
    for _, frame in frame_groups.items():
        i = 0  # call index in this frame.
        for _, group in itertools.groupby(frame, lambda x: x.surrounding):
            ast_to_intermediate: Dict[str, str] = {}
            intermediate_vars = {}  # Mapping of intermediate vars and their values.
            for node, _, arg_values in group:
                # ri_ appeared before should be captured in node.vars.
                node.vars.update(intermediate_vars)
                # Replaces nested calls with intermediate vars.
                for inner_call, intermediate in ast_to_intermediate.items():
                    node.code_str = node.code_str.replace(inner_call, intermediate, 1)
                if node.type is NodeType.CALL:
                    ast_to_intermediate[node.code_str] = f"r{i}_"
                    node.code_str = f"r{i}_ = " + node.code_str
                    node.step_into = frame_groups[node.frame_id + (i,)][0].node
                    node.step_into.prev = node
                    node.returned_from = frame_groups[node.frame_id + (i,)][-1].node
                    intermediate_vars[f"r{i}_"] = node.returned_from.return_value
                    i += 1
                node.code_ast = utils.parse_code_str(node.code_str)
                if node.type is NodeType.CALL:
                    assert arg_values, "call node should have arg_values."
                    node.set_param_arg_mapping(arg_values)

            # Deals with some special cases.
            assert len(node.code_ast.body) == 1
            stmt = node.code_ast.body[0]

            # Checks if LHS is ri_ and ri_ only, e.g. r1_ = f(1, 2)
            lhs_is_ri = lambda stmt: (
                isinstance(stmt.value, ast.Name) and re.match(r"r[\d]+_", stmt.value.id)
            )

            if isinstance(stmt, ast.Expr) and lhs_is_ri(stmt):
                # Current node is "r0_", previous node is "r0_ = f()".
                # This happens when the whole line is just "f()".
                # Solution: removes current node, restores previous node to "f()".
                assert node.type is NodeType.LINE
                prev = node.prev
                assert (
                    prev
                    and prev.type is NodeType.CALL
                    and prev.code_str.startswith(f"{stmt.value.id}")
                )
                prev.next = node.next
                if node.next:
                    node.next.prev = prev
                prev.code_str = prev.code_str.split("=", 1)[1].lstrip()
                prev.code_ast = utils.parse_code_str(node.code_str)
            elif isinstance(stmt, ast.Assign) and lhs_is_ri(stmt):
                # Current node represents "a = r0_", previous node is "r0_ = f()"
                # Solution: changes previous to 'a = f()', discards current node.
                # We don't need to modify frame_groups, it's not used in tracing.
                value = stmt.value
                prev = node.prev
                assert (
                    prev
                    and prev.type is NodeType.CALL
                    and prev.code_str.startswith(f"{value.id} =")
                )
                prev.next = node.next
                if node.next:
                    node.next.prev = prev
                prev.code_ast.body[0].targets = stmt.targets
                prev.code_str = astor.to_source(prev.code_ast).strip()
Example #48
def process(item):
    (split, the_hash, og_code) = item

    transforms = [('transforms.Identity', t_identity)]

    doDepthK = 'DEPTH' in os.environ and len(os.environ['DEPTH']) > 0
    if doDepthK:
        assert 'NUM_SAMPLES' in os.environ and len(
            os.environ['NUM_SAMPLES']) > 0
        DEPTH = int(os.environ['DEPTH'])
        NUM_SAMPLES = int(os.environ['NUM_SAMPLES'])

        for s in range(NUM_SAMPLES):
            the_seq = []
            for _ in range(DEPTH):
                rand_int = random.randint(1, 8)
                if rand_int == 1:
                    the_seq.append(t_replace_true_false)
                elif rand_int == 2:
                    the_seq.append(t_rename_local_variables)
                elif rand_int == 3:
                    the_seq.append(t_rename_parameters)
                elif rand_int == 4:
                    the_seq.append(t_rename_fields)
                elif rand_int == 5:
                    the_seq.append(t_insert_print_statements)
                elif rand_int == 6:
                    the_seq.append(t_add_dead_code)
                elif rand_int == 7:
                    the_seq.append(t_unroll_whiles)
                elif rand_int == 8:
                    the_seq.append(t_wrap_try_catch)

            transforms.append(('depth-{}-sample-{}'.format(DEPTH, s + 1),
                               t_seq(the_seq)))
    else:
        # transforms.append(('renamevar-param', t_seq([t_rename_local_variables, t_rename_parameters], all_sites=True)))
        # transforms.append(('transforms.InsertPrintStatements', t_insert_print_statements))
        # transforms.append(('transforms.RenameLocalVariables',  t_rename_local_variables))
        # transforms.append(('transforms.RenameParameters', t_rename_parameters))
        # transforms.append(('transforms.ReplaceTrueFalse',  t_replace_true_false))
        # transforms.append(('transforms.RenameFields', t_rename_fields))
        # transforms.append(('transforms.AddDeadCode', t_add_dead_code))
        # transforms.append(('transforms.UnrollWhiles', t_unroll_whiles))
        # transforms.append(('transforms.WrapTryCatch', t_wrap_try_catch))
        #transforms.append(('transforms.Combined', t_seq([t_rename_local_variables, t_rename_parameters, t_rename_fields, t_replace_true_false, t_insert_print_statements, t_add_dead_code], all_sites=True)))
        #transforms.append(('transforms.Insert', t_seq([t_insert_print_statements, t_add_dead_code], all_sites=True)))
        transforms.append(('transforms.Replace',
                           t_seq([
                               t_rename_local_variables, t_rename_parameters,
                               t_rename_fields, t_replace_true_false
                           ],
                                 all_sites=True)))

    results = []
    for t_name, t_func in transforms:
        try:
            # print(t_func)
            changed, result, last_idx, site_map = t_func(ast.parse(og_code),
                                                         all_sites=True)
            results.append((changed, split, t_name, the_hash,
                            astor.to_source(result), site_map))
        except Exception as ex:
            import traceback
            traceback.print_exc()
            results.append((False, split, t_name, the_hash, og_code, {}))
    return results
Example #49
def get_attr_full_name(node): 
    #assert isinstance(node, gast.Attribute)
    return astor.to_source(gast.gast_to_ast(node)).strip()
Example #50
def test_literal_string_annotation(annotation: str, expected: str) -> None:
    """Strings inside Literal annotations must not be recursively parsed."""
    stmt, = ast.parse(annotation).body
    assert isinstance(stmt, ast.Expr)
    unstringed = astbuilder._AnnotationStringParser().visit(stmt.value)
    assert astor.to_source(unstringed).strip() == expected
Example #51
def type2str(type_expr):
    if type_expr is None:
        return None
    else:
        return astor.to_source(type_expr).strip()
Example #52
def to_source(tree: ast.AST) -> str:
    """
    Dump the AST to generated source code.
    """
    return astor.to_source(tree)
Example #53
def test_primitive_float(assertion_to_ast):
    assertion = MagicMock(value=1.5)
    assertion_to_ast.visit_primitive_assertion(assertion)
    assert (astor.to_source(Module(body=assertion_to_ast.nodes)) ==
            "assert math.isclose(var0, 1.5, abs_tol=0.01)\n")
Example #54
def normalise(path):

    split = os.path.split(path)
    base = split[0]
    dirname = split[1]

    normalised_target_path = os.path.join(base, dirname + "_normalised")
    processed_file_path = os.path.join(path, "processed.txt")

    print("Writing normalised files to %s" % normalised_target_path)

    python_files = [y[len(path)+1:] for x in os.walk(path) for y in iglob(os.path.join(x[0], '*.py'))]

    # For debugging
    # python_files = ["debug/test.py"]
    # python_files = ["web2py/gluon/contrib/memcache/memcache.py"]

    processed_files = []
    initial_processed = 0
    syntax_errors = []
    filenotfound_errors = []
    errors = []
    skipped = []

    if os.path.exists(processed_file_path):
        print("Found processed files from previous session, continuing...")
        with open(processed_file_path) as p:
            processed_files = p.read().splitlines()
            initial_processed = len(processed_files)

    def complete():
        write_output(processed_file_path, processed_files)
        print("Processed files: %d\nSyntax errors: %d\nFile not found errors: %d\nOther errors: %d\nSkipped: %d" %
              (len(processed_files) - initial_processed, len(syntax_errors),
               len(filenotfound_errors), len(errors), len(skipped)))

    for filename in python_files:
        if filename in processed_files:
            skipped.append(filename)
            continue

        error = False
        try:
            input_file = os.path.join(path, filename)
            normalised_target_file = os.path.join(normalised_target_path, filename)
            source, tree = get_source_tree(input_file)
        except SyntaxError:
            syntax_errors.append(filename)
            continue
        except FileNotFoundError:
            filenotfound_errors.append(filename)
            continue
        except KeyboardInterrupt:
            print("Keyboard interrupt, saving...")
            complete()
            sys.exit()
        except:
            print("Failed to parse %s due to %s" % (filename, sys.exc_info()[0]))
            errors.append((filename, sys.exc_info()[0]))
            continue

        # AST variable replacement and formatting
        try:
            walker = astwalker.ASTWalker()
            # walker.randomise = False  # For debugging
            walker.walk(tree)
            walker.process_replace_queue()
            ast_source = astor.to_source(tree)
            writefile(normalised_target_file, ast_source)
        except KeyboardInterrupt:
            print("Keyboard interrupt, saving...")
            complete()
            sys.exit()
        except:
            print("Failed to process normalisation for file %s" % filename)
            print(sys.exc_info()[0])
            error = True
            if len(python_files) == 1:
                raise

        if not error:
            processed_files.append(filename)

    complete()
Example #55
def test_primitive_non_bool(assertion_to_ast):
    assertion = MagicMock(value=42)
    assertion_to_ast.visit_primitive_assertion(assertion)
    assert astor.to_source(
        Module(body=assertion_to_ast.nodes)) == "assert var0 == 42\n"
Example #56
def get_attribute_full_name(node):
    assert isinstance(
        node,
        gast.Attribute), "Input non-Attribute node to get attribute full name"
    return astor.to_source(gast.gast_to_ast(node)).strip()
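Since astor only understands the standard-library ast, gast trees are converted back first via gast.gast_to_ast, which is what the helper above does. A small standalone sketch, assuming the gast package is installed:

import astor
import gast

# Render the full dotted name of a gast.Attribute node.
tree = gast.parse("result = pkg.module.func()")
attr_node = tree.body[0].value.func  # the gast.Attribute for pkg.module.func
print(astor.to_source(gast.gast_to_ast(attr_node)).strip())  # expected: pkg.module.func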
Example #57
def to_source(node) -> str:
    s = astor.to_source(node, pretty_source=lambda x: "".join(x))
    return s
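astor's default pretty_source joins the generated fragments and also re-wraps long lines; passing pretty_source=lambda x: "".join(x) as above keeps the output exactly as generated. A minimal usage sketch:

import ast
import astor

# Join astor's source fragments verbatim, skipping the default line-wrapping.
node = ast.parse("value = some_function(first_argument, second_argument, third_argument)")
print(astor.to_source(node, pretty_source=lambda fragments: "".join(fragments)))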
Example #58
def translate_function(function, scheduler, saneitize=True):
    """
    Translates a function into a :class:`tasks.ScheduledCallable`.
    """
    def main_workaround(module_name):
        """
        If the function is inside __main__ we have problem
        since the workers will have a different module called
        __main__.
        So we make a best-effort attemt to find if __main__
        is also reachable as a module.
        """
        if module_name != "__main__":
            return module_name  # we are fine.

        # See if we can find a sys.path that matches the location of __main__
        main_file = getattr(sys.modules["__main__"], "__file__", "")
        candidates = {path for path in sys.path if main_file.startswith(path)}
        candidates = sorted(candidates, key=len)
        for candidate in candidates:

            # Try to create the absolute module name from the filename only.
            module_name = main_file[len(candidate):]
            if module_name.endswith(".py"):
                module_name = module_name[:-3]
            if module_name.endswith(".pyc"):
                module_name = module_name[:-4]
            module_name = module_name.replace("/", ".")
            module_name = module_name.replace("\\", ".")
            while module_name.startswith("."):
                module_name = module_name[1:]

            # Check if it actually works.
            try:
                module = importlib.import_module(module_name)
                return module.__name__
            except ImportError:
                pass

        # we were unlucky.
        raise ValueError(
            "The functions in the __main__ module cannot be translated.")

    source = inspect.getsourcelines(function)
    source = "".join(source[0])

    logger.info("Translating: \n%s" % source)

    node = ast.parse(utils.unindent(source))

    # Remove decorators

    # TODO handle decorators properly
    assert len(node.body) == 1
    funcdef = node.body[0]
    assert isinstance(funcdef, ast.FunctionDef)
    funcdef.decorator_list = []

    if len(funcdef.args.defaults) != 0:
        # TODO add support
        raise ValueError(
            "Cannot translate %s: @schedule does not support functions with default arguments"
            % function.__name__)

    id_factory = naming.UniqueIdentifierFactory()

    if saneitize:
        makesane = saneitizer.Saneitizer()
        node = makesane.process(node, id_factory)

    module_name = getattr(function, "__module__", None)
    if not module_name:
        raise ValueError(
            "Cannot translate %s: The module in which it is defined is unknown."
            % function.__name__)
    module_name = main_workaround(module_name)

    import astor
    logger.info("Preprocessed source:\n%s" % astor.to_source(node))

    translator = Translator(id_factory, scheduler, module_name)
    graph = translator.visit(node)

    def find_FunctionDefTask(graph):
        for tick in graph.get_all_ticks():
            task = graph.get_task(tick)
            if isinstance(task, tasks.FunctionDefTask):
                return task
        raise ValueError("No function was translated.")

    funcdeftask = find_FunctionDefTask(graph)

    defaults = function.__defaults__
    if not defaults:
        defaults = tuple()

    if funcdeftask.num_defaults != len(defaults):
        raise ValueError("Number of default arguments doesn't match.")

    if function.__closure__:
        raise ValueError("Translating closures currently not supported.")

    inputs = {"default_%s" % i: v for i, v in enumerate(defaults)}
    scheduled_callable = funcdeftask.evaluate(inputs)['function']
    return scheduled_callable
Example #59
File: model.py Project: Kuree/karst
def define_memory(func: Callable[["MemoryModel"], None]):
    func_src = inspect.getsource(func)
    func_tree = ast.parse(textwrap.dedent(func_src))
    # remove the decorator
    func_tree.body[0].decorator_list = []
    # find the model name
    find_model_name = FindModelVariableName()
    find_model_name.visit(func_tree)
    assert find_model_name.name, "unable to find model variable name"
    model_name = find_model_name.name
    action_visitor = FindActionDefine()
    action_visitor.visit(func_tree)
    assert len(action_visitor.nodes) > 0
    # get all the marked as well
    mark_visitor = FindMarkedFunction("mark")
    mark_visitor.visit(func_tree)
    after_config_visitor = FindLoopRangeVar("after_config", model_name)
    after_config_visitor.visit(func_tree)
    # also transform the functions marked as global
    global_visitor = FindMarkedFunction("global_func")
    global_visitor.visit(func_tree)
    nodes = action_visitor.nodes + mark_visitor.nodes +\
        after_config_visitor.nodes + global_visitor.nodes
    for action_node in nodes:
        # multiple passes
        # 1. convert all the assignment into function
        assign_visitor = AssignNodeVisitor(model_name)
        action_node = assign_visitor.visit(action_node)
        ast.fix_missing_locations(action_node)
        # 2. convert if statement
        if_visitor = IfNodeVisitor(model_name)
        if_visitor.visit(action_node)
        ast.fix_missing_locations(action_node)
        # 3. add int() call for every for loop
        for_transform = ForVarVisitor()
        for_transform.visit(action_node)
        # 4. add eval to list index
        index_transform = ListIndex(model_name)
        index_transform.visit(action_node)

    # let the model know which config variables are used on loop generation
    # it's done through adding extra line to the source code
    if after_config_visitor.range_vars:
        node = add_model_loop_vars(model_name, after_config_visitor.range_vars,
                                   "add_loop_var")
        assert isinstance(func_tree.body[0].body[-1], ast.Return)
        func_tree.body[0].body.insert(-1, node)

    # insert name to the model
    node = add_model_name(model_name)
    func_tree.body[0].body.insert(-1, node)
    ast.fix_missing_locations(func_tree)

    # formatting
    def pretty_source(source):
        return "".join(source)

    new_src = astor.to_source(func_tree,
                              indent_with=" " * 2,
                              pretty_source=pretty_source)
    func_name = func.__name__
    code_obj = compile(new_src, "<ast>", "exec")
    exec(code_obj, globals(), locals())
    namespace = locals()
    return namespace[func_name]
Example #60
def render_view(tree, selected_node, _):
    # Ignore the window width, astor is not that smart

    generator_class = cursor_highlighter_of(selected_node)
    return astor.to_source(tree, source_generator_class=generator_class)