コード例 #1
0
ファイル: astChecker.py プロジェクト: chenzhifei731/Pysmell
 def visit_ImportFrom(self, node):
     """Record names bound by a ``from ... import ...`` statement.

     Each bound name (the ``as`` alias when present, otherwise the imported
     name) is added to ``self.imports`` as ``(name, self.fileName, lineno)``
     unless that name is already recorded for the same file.  Nothing is
     recorded for package ``__init__.py`` files, for ``from __dunder__
     import ...`` modules, for dunder names, or for star imports.
     """
     def is_dunder(name):
         return len(name) > 4 and name[:2] == '__' and name[-2:] == '__'

     # Compare the last path component after normalising separators, so both
     # 'pkg/__init__.py' and 'pkg\\__init__.py' are recognised (the previous
     # check only matched a trailing '\\__init__.py').
     base_name = self.fileName.replace('\\', '/').rsplit('/', 1)[-1]
     if base_name == '__init__.py':
         self.generic_visit(node)
         return
     # ``node.module`` is None for relative imports such as ``from . import x``.
     if node.module is not None and is_dunder(node.module):
         self.generic_visit(node)
         return
     for alias in node.names:
         if is_dunder(alias.name):
             continue
         bound = alias.asname if alias.asname is not None else alias.name
         if bound == '*':
             continue
         already_recorded = any(name == bound and file == self.fileName
                                for (name, file, _lineno) in self.imports)
         if not already_recorded:
             self.imports.add((bound, self.fileName, node.lineno))
     self.generic_visit(node)
コード例 #2
0
ファイル: translate.py プロジェクト: rymurr/q
def transformHelper(command):
    """Run ``preprocess`` on *command*, print debug renderings and return the node.

    The ``ast.dump`` of the node is always printed.  The source regenerated by
    ``astunparse`` is printed on a best-effort basis: if unparsing fails, the
    node is still returned.
    """
    node = preprocess(command)
    print(ast.dump(node))
    try:
        print(astunparse.unparse(node))
    except Exception:
        # Debug output only: ignore unparse failures, but (unlike a bare
        # ``except:``) let KeyboardInterrupt and SystemExit propagate.
        pass
    return node
コード例 #3
0
ファイル: if.py プロジェクト: cbrentharris/pybugs
 def test_program_doesnt_return_false_after_if(self, node):
     u"""Report an ``if`` whose branches just return ``False``/``True``.

     Flags code shaped like::

         if (some boolean expression):
             return False
         else:
             return True

     which could simply ``return not (expression)``.  When *node* matches
     (as judged by ``self.return_false_in_if_body`` and
     ``self.return_true_in_else_body``), a message with the line number and
     the unparsed statement are printed.  Non-``If`` nodes are ignored.
     """
     if not isinstance(node, ast.If):
         return
     if self.return_false_in_if_body(node.body) and self.return_true_in_else_body(node.orelse):
         # Single-argument print() calls behave the same on Python 2 and 3.
         print("Unnecessary conditional - simply return negation of test - line {}".format(node.lineno))
         print(astunparse.unparse(node))
コード例 #4
0
ファイル: cases.py プロジェクト: andy7i/taurus
    def assertFilesEqual(expected, actual, replace_str="", replace_with="", python_files=False):
        """Assert that two text files have equal content after normalisation.

        In every line of both files ``replace_str`` is replaced by
        ``replace_with`` and trailing whitespace is stripped.  When
        ``python_files`` is true, both contents are additionally round-tripped
        through ``ast.parse`` and ``astunparse`` before comparison.

        :raises AssertionError: with a unified diff when the contents differ.
        """
        with open(expected) as exp, open(actual) as act:
            exp_lines = [x.replace(replace_str, replace_with).rstrip() for x in exp.readlines()]
            act_lines = [x.replace(replace_str, replace_with).rstrip() for x in act.readlines()]
        if python_files:
            act_lines = astunparse.unparse(ast.parse('\n'.join(act_lines))).split('\n')
            exp_lines = astunparse.unparse(ast.parse('\n'.join(exp_lines))).split('\n')

        # The lines carry no trailing newline, so difflib must not append one
        # to its ---/+++/@@ control lines either (lineterm=""); otherwise the
        # "\n".join below produced doubled newlines after those lines.
        diff = list(difflib.unified_diff(exp_lines, act_lines,
                                         fromfile=str(expected), tofile=str(actual),
                                         lineterm=""))
        if diff:
            ROOT_LOGGER.info("Replacements are: %s => %s", replace_str, replace_with)
            msg = "Failed asserting that two files are equal:\n%s\nversus\n%s\nDiff is:\n\n%s"
            raise AssertionError(msg % (actual, expected, "\n".join(diff)))
コード例 #5
0
ファイル: if.py プロジェクト: cbrentharris/pybugs
 def test_program_doesnt_have_if_false(self, node):
     u"""Report an ``if False:`` block, whose body can never run.

     Flags code shaped like::

         if False:
             ...

     printing a message with the line number and the unparsed statement.
     Any other node, or an ``if`` whose test is not the literal ``False``, is
     ignored.
     """
     if not isinstance(node, ast.If):
         return
     test = node.test
     # Python 2 parses ``False`` as a Name; Python 3 as a constant node
     # (NameConstant/Constant) whose ``value`` is False.  Other tests
     # (comparisons, calls, ...) have no ``id`` and used to raise
     # AttributeError here.
     is_false_literal = ((isinstance(test, ast.Name) and test.id == "False")
                         or getattr(test, "value", None) is False)
     if not is_false_literal:
         return
     # NOTE(review): an "Unecessary if False statement" message used to sit
     # behind ``if not hasattr(node, 'test')``.  That branch was unreachable:
     # every ast.If has ``test``, and it was read just above.  Confirm whether
     # it was meant to distinguish ``if False`` without an ``else``.
     print("Dead code in if block - line {}".format(node.lineno))
     print(astunparse.unparse(node))
コード例 #6
0
ファイル: truth_table.py プロジェクト: clj/truth-table
def collect_names(node):
    """Collect the atomic operands of a boolean expression, in order.

    ``BoolOp`` values and ``UnaryOp`` operands are searched recursively; any
    other node is an atom, wrapped as ``Name(<unparsed source>, node)``.
    Returns an ``OrderedSet`` of those ``Name`` objects.
    """
    collected = OrderedSet()
    node_kind = node.__class__.__name__

    if node_kind == 'UnaryOp':
        return collected | collect_names(node.operand)

    if node_kind != 'BoolOp':
        atom = Name(astunparse.unparse(node).strip(), node)
        return collected | OrderedSet([atom])

    for operand in node.values:
        if operand.__class__.__name__ in ('BoolOp', 'UnaryOp'):
            collected = collected | collect_names(operand)
        else:
            atom = Name(astunparse.unparse(operand).strip(), operand)
            collected = collected | OrderedSet([atom])
    return collected
コード例 #7
0
ファイル: unroller.py プロジェクト: ml-lab/TerpreT
def update_variable_domains(variable_domains, node_list):
    '''Updates variable_domains by inserting mappings from a variable name to the to the number of
    elements in their data domain.

    Program variables are created in assignments of the form
    "{varName} = Var({size})" or "{varName} = Param({size})".

    '''
    for ch in node_list:
        if isinstance(ch, ast.Assign) and len(ch.targets) == 1:
            name = astunparse.unparse(ch.targets[0]).rstrip()
            rhs = ch.value

            if isinstance(rhs, ast.Call):
                decl_name = rhs.func.id
                args = rhs.args
            elif isinstance(rhs, ast.Subscript) and isinstance(rhs.value, ast.Call):
                decl_name = rhs.value.func.id
                args = rhs.value.args
            else:
                continue

            if decl_name not in ["Param", "Var", "Input", "Output"]:
                continue
            
            if len(args) > 1:
                error.error('More than one size parameter in variable declaration of "%s".' % (name), ch)
            size = args[0]

            if isinstance(size, ast.Num):
                if name in variable_domains and variable_domains[name] != size.n:
                    error.fatal_error("Trying to reset the domain of variable '%s' to '%i' (old value '%i')." % (name, size.n, variable_domains[name]), size) 
                variable_domains[name] = size.n
            else:
                error.fatal_error("Trying to declare variable '%s', but size parameters '%s' is not understood." % (name, astunparse.unparse(size).rstrip()), size)
コード例 #8
0
ファイル: obfuscate.py プロジェクト: tobast/troll-obfuscator
def main():
    """Obfuscate every source file named on the command line.

    Every builtin name (plus ``__builtins__``) is pre-seeded in ``env`` mapped
    to itself.  ``env`` is passed to ``replaceIdents`` for each file and
    replaced by the mapping it returns, so later files see the names
    accumulated so far.  Each transformed module is printed as source.
    """
    try:
        import builtins  # Python 3
    except ImportError:  # Python 2
        import __builtin__ as builtins
    # ``__builtins__`` is the builtins *module* only in __main__.  In an
    # imported module it is a plain dict, and dir() on it lists dict methods
    # rather than builtin names, so query the builtins module directly.
    env = {}
    for builtin in dir(builtins) + ['__builtins__']:
        env[builtin] = builtin
    for f in sys.argv[1:]:
        nCode, env = replaceIdents(f, env)
        print(astunparse.unparse(nCode))
コード例 #9
0
ファイル: truth_table.py プロジェクト: clj/truth-table
def truth_table(node):
    """Build a truth table for the boolean expression *node*.

    The first row is a header holding ``str()`` of every name returned by
    ``collect_names`` followed by the unparsed expression.  Each following row
    is one False/True assignment to those names (in ``itertools.product``
    order) with the result of ``evaluate`` for that assignment appended.
    """
    names = collect_names(node)
    conditions = list(itertools.product([False, True], repeat=len(names)))
    # map() returns a lazy iterator on Python 3, so ``map(...) + [...]`` raised
    # TypeError; build the header as a real list instead.
    header = [str(name) for name in names] + [astunparse.unparse(node).strip()]
    rows = [tuple(header)]
    for condition in conditions:
        rows.append(condition + (evaluate(node, names, condition),))
    return rows
コード例 #10
0
ファイル: test_function.py プロジェクト: fjarri/peval
def function_from_source(source, globals_=None):
    """
    Build a ``Function`` from the first top-level ``def`` in *source*.

    *source* is dedented with ``unindent``, parsed, and compiled with
    ``dont_inherit=True``, so only the ``__future__`` imports written in
    *source* itself apply.  The compiled module runs with *globals_* as its
    globals (this module's globals when None) and a fresh dict as locals.
    The resulting function object gets ``_peval_source`` set to the unparsed
    definition and is wrapped with ``Function.from_object``.

    Raises ValueError if *source* has no top-level function definition.
    """
    module = ast.parse(unindent(source))
    ast.fix_missing_locations(module)

    function_defs = [stmt for stmt in module.body if type(stmt) == ast.FunctionDef]
    if not function_defs:
        raise ValueError("No function definitions found in the provided source")
    func_def = function_defs[0]

    namespace = {}
    compiled = compile(module, '<nofile>', 'exec', dont_inherit=True)
    eval(compiled, globals_, namespace)

    function_obj = namespace[func_def.name]
    function_obj._peval_source = astunparse.unparse(func_def)
    return Function.from_object(function_obj)
コード例 #11
0
ファイル: main.py プロジェクト: ChunHungLiu/pyformat.info
def unparse(node, strip=None):
    """Return the source text of *node* as produced by ``astunparse``.

    When *strip* is truthy, surrounding whitespace is removed.  For a
    ``BinOp``, the first and last characters are dropped as well (presumably
    the parentheses astunparse wraps binary operations in).
    """
    source = astunparse.unparse(node)
    if not strip:
        return source
    source = source.strip()
    return source[1:-1] if isinstance(node, _ast.BinOp) else source
コード例 #12
0
ファイル: astChecker.py プロジェクト: njuap/python-smells
 def visit_TryExcept(self, node):
     """Flag silent or over-general exception handlers (smell id 7).

     Appends ``(7, self.fileName, node.lineno, -1)`` to ``self.result`` in
     any of these cases:

     * some handler's first statement is ``pass``;
     * some handler catches a tuple containing one of the general names;
     * every handler is a bare ``except:`` or catches a single general name.

     Child nodes are visited in every case.
     """
     general_names = ("BaseException", "Exception", "StandardError")
     all_general = True
     for handler in node.handlers:
         # Same result as unparsing the statement and comparing it with
         # "pass", without regenerating source text for every handler.
         if isinstance(handler.body[0], _ast.Pass):
             self.result.append((7, self.fileName, node.lineno, -1))
             self.generic_visit(node)
             return
         if handler.type is None:
             # A bare ``except:`` keeps the handler set counted as general.
             continue
         if isinstance(handler.type, _ast.Tuple):
             if any(getattr(elt, "id", None) in general_names for elt in handler.type.elts):
                 self.result.append((7, self.fileName, node.lineno, -1))
                 self.generic_visit(node)
                 return
             all_general = False
         elif getattr(handler.type, "id", None) not in general_names:
             all_general = False
     if all_general:
         self.result.append((7, self.fileName, node.lineno, -1))
     self.generic_visit(node)
コード例 #13
0
ファイル: parse.py プロジェクト: jmanuel1/concat
def p_push_primary(p):  # noqa
    """push_primary : DOLLARSIGN primary"""
    # The docstring above is the PLY grammar production for this rule and is
    # read by the parser generator, so it must stay exactly as written.
    #
    # Semantic action: turn ``primary`` (p[2]) into an expression that is
    # passed to ``_push``.  p[0] receives a one-element statement list.
    arg_list = ast.arguments(
        # ``posonlyargs`` is a required field for compile() on Python 3.8+.
        # Older versions accept the extra keyword and ignore it.
        posonlyargs=[],
        args=[ast.arg(arg='stack', annotation=None),
              ast.arg(arg='stash', annotation=None)],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    if isinstance(p[2], ast.Name):
        # A plain name is pushed unchanged.
        pass
    # TODO: not a very good check
    elif 'stack.pop().' in astunparse.unparse(p[2][0]):
        # we are pushing an attributeref
        # get rid of the _call to leave the stack.pop().<attr> and concatify it
        p[2] = ast.Call(func=ast.Name(id='concatify', ctx=ast.Load()),
                        args=p[2][0].value.args[0:1], keywords=[])
    else:
        # Wrap the expression(s) in ConcatFunction(lambda stack, stash: ...),
        # presumably so they are pushed as a value rather than executed.
        p[2] = ast.Call(func=ast.Name(id='ConcatFunction', ctx=ast.Load()),
                        args=[ast.Lambda(arg_list, _combine_exprs(p[2]))],
                        keywords=[])
    p[0] = [ast.Expr(_push(p[2]))]
    _set_line_info(p)
コード例 #14
0
ファイル: dirt.py プロジェクト: aaron11496/dirt
 def __init__(self, node, filename):
     """Wrap a named AST *node* (it must have ``lineno`` and ``name``) from *filename*.

     Stores the node, its line number and name, the source regenerated by
     ``astunparse`` (raw in ``unparsed``, stripped in ``source``) and the
     number of lines in the stripped source.
     """
     self.node = node
     self.lineno = node.lineno
     self.filename = filename
     self.name = node.name
     self.unparsed = astunparse.unparse(node)
     # ``''.join`` over a str only rebuilt the same string; strip it directly.
     self.source = self.unparsed.strip()
     self.linecount = len(self.source.splitlines())
コード例 #15
0
ファイル: unroller.py プロジェクト: ml-lab/TerpreT
 def get_name_and_const_from_test(test):
     """Split an ``==`` comparison into ``(variable_node, numeric_constant)``.

     Accepts ``var == N`` or ``N == var``, where ``var`` is a Name or Subscript
     and ``N`` is a numeric literal.  Returns ``(None, None)`` for any other
     operand shape and raises for any operator other than ``==``.
     """
     first_op = test.ops[0]
     if not isinstance(first_op, ast.Eq):
         # NOTE(review): ``node`` is not defined in this function.  It is
         # presumably closed over from the enclosing scope; confirm.
         raise Exception("Tests in if can only use ==, not '%s'." % (astunparse.unparse(first_op).rstrip()), node)
     variable_kinds = (ast.Name, ast.Subscript)
     lhs = test.left
     rhs = test.comparators[0]
     if isinstance(lhs, variable_kinds) and isinstance(rhs, ast.Num):
         return (lhs, rhs.n)
     if isinstance(rhs, variable_kinds) and isinstance(lhs, ast.Num):
         return (rhs, lhs.n)
     return (None, None)
コード例 #16
0
ファイル: astChecker.py プロジェクト: njuap/python-smells
 def visit_IfExp(self, node):
     """Record long single-line conditional expressions (smell id 10).

     Applies when the ``else`` operand starts on the same line as the
     expression.  The length is the stripped unparsed text minus its spaces
     and 2 (presumably the parentheses astunparse adds).  When it reaches
     ``config['ifinrowexp']``, ``(10, fileName, lineno, length)`` is appended
     to ``self.result``.  Child nodes are always visited.
     """
     orelse_expr = node.orelse
     if orelse_expr and orelse_expr.lineno == node.lineno:
         text = astunparse.unparse(node)
         exprLength = len(text.strip()) - text.count(' ') - 2
         if exprLength >= config['ifinrowexp']:
             self.result.append((10, self.fileName, node.lineno, exprLength))
     self.generic_visit(node)
コード例 #17
0
ファイル: tptv1.py プロジェクト: ml-lab/TerpreT
def get_var_name(node):
    """Return the base variable name of a (possibly subscripted) reference.

    ``a`` and ``a[i][j]`` both give ``'a'``.  A call base such as
    ``Input()[10]`` gives None.  Any other base raises an Exception that
    includes its unparsed source.
    """
    current = node
    while isinstance(current, ast.Subscript):
        current = current.value
    if isinstance(current, ast.Name):
        return current.id
    if isinstance(current, ast.Call):
        return None  # This is the Input()[10] case
    raise Exception("Can't extract var name from '%s'" % (unparse(current).rstrip()))
コード例 #18
0
def find_timeits(tree):
    """
    :param ast.AST tree: The container of one or more timeit_funcs.
    :rtype: list[TimeitHolder]

    The ``tree`` may be:
        * an ast.Module containing the ``timeit_func``
        * an ast.Module containing a function containing the ``timeit_func``
        * the ``timeit_func`` directly as an ast.FunctionDef

    An AST parse tree will usually be an ast.Module, containing one or more
    items. But this function could also receive the timeit_func directly,
    or it could be nested within a function in the ast.Module.

    A timeit func is a function whose name starts with ``func_name_prefix``.
    Its holder is named after the remainder of that name.  When scanning a
    body, every non-timeit statement seen *before* a timeit func is unparsed
    as that func's setup code.

    :raises TypeError: if ``tree`` (after unwrapping) has no ``body``.
    """
    def isfunc(node):
        return isinstance(node, ast.FunctionDef)

    setup = []
    timeits = []

    # Unwrap a top module holding a single compound statement (normally the
    # enclosing function).  A lone simple statement such as ``x = 1`` has no
    # ``body``; it used to be unwrapped only to be rejected with TypeError
    # below, so scan such a module as a module instead.
    if isinstance(tree, ast.Module) and len(tree.body) == 1 \
            and hasattr(tree.body[0], 'body'):
        tree = tree.body[0]

    if not hasattr(tree, 'body'):
        raise TypeError("Can't handle '%s'" % type(tree))

    # we might have the timeit func directly
    if isfunc(tree) and tree.name.startswith(func_name_prefix) \
            and not any(isfunc(c) for c in tree.body):
        timeits.append(TimeitHolder(name=tree.name[len(func_name_prefix):],
                                    setup="",
                                    code=astunparse.unparse(tree.body)))
    else:
        # otherwise, iterate the statement body and
        # find setup code and timeit funcs
        for child in tree.body:
            if isfunc(child) and child.name.startswith(func_name_prefix):
                timeits.append(TimeitHolder(name=child.name[len(func_name_prefix):],
                                            setup=astunparse.unparse(setup),
                                            code=astunparse.unparse(child.body)))
            else:
                setup.append(child)

    return timeits
コード例 #19
0
ファイル: tui.py プロジェクト: skylerberg/ASTeditor
def _display():
    """Redraw the curses screen with the unparsed source of ``globals['node'].node``.

    Lines beyond the screen height are dropped, and each line is truncated to
    one column less than the screen width.
    """
    # NOTE(review): ``globals`` is indexed like a dict here, so it is presumably
    # a module-level dict shadowing the builtin function; confirm.
    screen = SCREEN
    node = globals['node']
    screen.erase()
    height, width = screen.getmaxyx()
    source_lines = unparse(node.node).split('\n')
    for row, text in enumerate(source_lines[:height]):
        screen.addnstr(row, 0, text, width - 1)
    screen.refresh()
コード例 #20
0
ファイル: utils.py プロジェクト: qfhuang/peval-1
def ast_to_source(tree):
    ''' Return python source of AST tree, as a string.

    Trailing whitespace is removed from every line, and leading/trailing
    newlines are dropped from the result, since some pretty printers add
    them.
    '''
    raw_source = astunparse.unparse(tree)
    cleaned_lines = [line.rstrip() for line in raw_source.split("\n")]
    return "\n".join(cleaned_lines).strip("\n")
コード例 #21
0
ファイル: utils.py プロジェクト: fjarri/peval
def assert_ast_equal(test_ast, expected_ast, print_ast=True):
    """
    Assert that *test_ast* equals *expected_ast*, as judged by ``ast_equal``.

    On mismatch, first print diagnostics and then fail the assertion.  When
    *print_ast* is true, print a diff of the two ``astunparse.dump`` outputs.
    Always print a diff of the two normalized unparsed sources.
    """
    equal = ast_equal(test_ast, expected_ast)
    if equal:
        return

    if print_ast:
        expected_dump = astunparse.dump(expected_ast)
        test_dump = astunparse.dump(test_ast)
        print_diff(test_dump, expected_dump)

    expected_src = normalize_source(unparse(expected_ast))
    test_src = normalize_source(unparse(test_ast))
    print_diff(test_src, expected_src)

    assert equal
コード例 #22
0
ファイル: __main__.py プロジェクト: jmanuel1/concat
def compile_and_run(filename, file_obj=None, debug=False, globals=None):
    """Parse, compile and execute a concat program.

    :param filename: path of the program.  It is opened only when *file_obj*
        is not given, and it is always used as the code object's filename.
    :param file_obj: optional open file to read the program from.  It is
        closed after reading.
    :param debug: forwarded to ``parse.parse``.
    :param globals: namespace to execute in.  A fresh dict is used when None.

    As a debugging aid, the generated Python source is written to
    ``debug.py`` and an AST dump to ``ast.out`` in the current directory.
    """
    # The file opened here used to be unbound, so reading from ``file_obj``
    # raised AttributeError whenever no file object was passed in.
    with (file_obj or open(filename, 'r')) as source_file:
        ast_ = parse.parse(source_file.read(), debug)
    ast.fix_missing_locations(ast_)
    with open('debug.py', 'w') as f:
        f.write(astunparse.unparse(ast_))
    with open('ast.out', 'w') as f:
        f.write('\n------------ AST DUMP ------------\n')
        f.write(astunparse.dump(ast_))
    prog = compile(ast_, filename, 'exec')
    # ``globals or {}`` silently replaced a caller's *empty* dict; only
    # substitute a fresh namespace when none was given.
    exec(prog, globals if globals is not None else {})
コード例 #23
0
ファイル: astChecker.py プロジェクト: chenzhifei731/Pysmell
 def visit_FunctionDef(self, node):
     """Collect size/complexity metrics for a function definition.

     Appends to ``self.result``:

     * ``(1, fileName, lineno, n)``: parameters without a default value.
       ``*args``, ``**kwargs`` and keyword-only parameters without a default
       are counted too.
     * ``(2, fileName, lineno, n)``: distinct line numbers of the function's
       descendants.  Nested defs/classes contribute their own line but are
       not descended into, and a leading expression statement (docstring)
       is ignored.
     * ``(3, fileName, lineno, depth)``: maximum depth of ``def`` statements
       nested directly in one another.  Emitted only when > 1, and only for
       functions not already reached as nested functions (tracked in
       ``self.scopenodes``).

     Dunder-named functions (other than ``__import__``/``__all__``) are also
     recorded in ``self.defmagic``.
     """
     import collections  # local import keeps this method self-contained

     def count_required_args(args):
         # Count from the AST.  The old approach split the unparsed signature
         # on ',', which counted 1 for an empty signature and mis-split
         # defaults containing commas.
         positional = list(getattr(args, "posonlyargs", [])) + list(args.args)
         count = len(positional) - len(args.defaults)
         count += sum(1 for default in getattr(args, "kw_defaults", []) if default is None)
         if args.vararg:
             count += 1
         if args.kwarg:
             count += 1
         return count

     def count_body_lines(func):
         # Breadth-first walk collecting line numbers.
         lines = set()
         queue = collections.deque([func])
         while queue:
             parent = queue.popleft()
             is_def = isinstance(parent, (_ast.FunctionDef, _ast.ClassDef))
             for child in ast.iter_child_nodes(parent):
                 if not hasattr(child, 'lineno'):
                     continue
                 if is_def and child is parent.body[0] and isinstance(child, _ast.Expr):
                     continue  # leading docstring expression
                 lines.add(child.lineno)
                 if not isinstance(child, (_ast.ClassDef, _ast.FunctionDef)):
                     queue.append(child)
         return len(lines)

     funcName = node.name.strip()
     if re.match(r"^(__[a-zA-Z0-9]+__)$", funcName) and funcName not in ("__import__", "__all__"):
         self.defmagic.add((funcName, self.fileName, node.lineno))

     # argsCount
     self.result.append((1, self.fileName, node.lineno, count_required_args(node.args)))
     # function length
     self.result.append((2, self.fileName, node.lineno, count_body_lines(node)))

     # nested scope depth: nested functions are measured as part of their
     # outermost function, which records them in self.scopenodes.
     if node in self.scopenodes:
         self.scopenodes.remove(node)
         self.generic_visit(node)
         return
     maxlevel = 1
     pending = collections.deque([(node, 1)])
     while pending:
         current, level = pending.popleft()
         for child in ast.iter_child_nodes(current):
             if isinstance(child, _ast.FunctionDef):
                 self.scopenodes.append(child)
                 pending.append((child, level + 1))
         maxlevel = max(maxlevel, level)
     if maxlevel > 1:
         self.result.append((3, self.fileName, node.lineno, maxlevel))  # DOC
     self.generic_visit(node)
コード例 #24
0
ファイル: tptv1.py プロジェクト: ml-lab/TerpreT
def extend_subscript_for_input(node, extension):
    """Prepend *extension* to the index of subscript *node*, in place.

    ``x[i]`` becomes ``x[extension, i]`` and ``x[i, j]`` becomes
    ``x[extension, i, j]``.  Slices are not supported and raise an
    Exception.  Returns *node*.

    Works both with the legacy ``ast.Index`` wrapper (Python < 3.9) and with
    the bare index expression used from Python 3.9 on.  With only the
    ``ast.Index`` check, every subscript on 3.9+ was rejected.
    """
    index_cls = getattr(ast, "Index", None)
    wrapped = index_cls is not None and isinstance(node.slice, index_cls)
    if wrapped:
        idx = node.slice.value
    elif (isinstance(node.slice, ast.expr)
          and not isinstance(node.slice, ast.Slice)
          and not (isinstance(node.slice, ast.Tuple)
                   and any(isinstance(elt, ast.Slice) for elt in node.slice.elts))):
        # Python 3.9+: the slice is the index expression itself.  Slices and
        # extended slices (tuples containing slices) remain unsupported.
        idx = node.slice
    else:
        raise Exception("Unhandled node indexing: '%s'" % (unparse(node).rstrip()))

    if isinstance(idx, ast.Tuple):
        new_idx = ast.Tuple([extension] + idx.elts, ast.Load())
    else:
        new_idx = ast.Tuple([extension, idx], ast.Load())

    if wrapped:
        node.slice.value = new_idx
    else:
        node.slice = new_idx
    return node
コード例 #25
0
ファイル: astChecker.py プロジェクト: chenzhifei731/Pysmell
 def visit_IfExp(self, node):
     """Record the size of a conditional expression (smell id 10).

     Appends ``(10, fileName, lineno, length, lines)`` to ``self.result``.
     *length* is the stripped unparsed text minus its spaces and 2
     (presumably the parentheses astunparse adds).  *lines* is 1 plus the
     largest line offset of any descendant from the expression's own line.
     Child nodes are then visited.
     """
     source = astunparse.unparse(node)
     exprLength = len(source.strip()) - source.count(' ') - 2
     offsets = [child.lineno - node.lineno
                for child in ast.walk(node) if hasattr(child, 'lineno')]
     span = max(offsets + [0]) + 1
     self.result.append((10, self.fileName, node.lineno, exprLength, span))
     self.generic_visit(node)
コード例 #26
0
ファイル: fol2.py プロジェクト: vpramo/xos-1
def xproto_fol_to_python_validator(policy, fol, model, message, tag=None):
    """Generate Python source for a validator implementing *policy*'s formula.

    Raises Exception when *fol* is an undefined Jinja2 value.  Raises
    TrivialPolicy when the hoisted formula is the literal 'True' or 'False'
    and differs from *fol*.  Otherwise returns the unparsed validation
    function.
    """
    if isinstance(fol, jinja2.Undefined):
        raise Exception('Could not find policy:', policy)

    converter = FOL2Python()
    reduced = converter.hoist_outer(fol)

    is_trivial = reduced in ['True', 'False'] and fol != reduced
    if is_trivial:
        details = {'name': policy, 'reduced': reduced}
        raise TrivialPolicy("Policy %(name)s trivially reduces to %(reduced)s. If this is what you want, replace its contents with %(reduced)s" % details)

    # NOTE(review): ``model`` is unused and the ``tag`` argument is ignored,
    # since 'validator' is always passed.  Confirm this is intentional.
    validator_ast = converter.gen_validation_function(reduced, policy, message, tag='validator')
    return astunparse.unparse(validator_ast)
コード例 #27
0
ファイル: testing.py プロジェクト: cyrus-/tydy
def translation_eq(f, truth, print_f=False):
    """helper function for test_translate functions

    compares an AST to the string truth, which should contain Python code.
    truth is first dedented.

    ``f.compile()`` is called first, and ``f.translation`` is then unparsed
    and compared with ``"\\n" + dedent(truth) + "\\n"``.  With *print_f*, the
    unparsed translation is printed.
    """
    f.compile()
    translation_s = astunparse.unparse(f.translation)
    if print_f:
        # print() with a single argument behaves the same on Python 2 and 3.
        print(translation_s)
    truth_s = "\n" + textwrap.dedent(truth) + "\n"
    assert translation_s == truth_s, \
        "translation:\n%s\nexpected:\n%s" % (translation_s, truth_s)
コード例 #28
0
ファイル: unroller.py プロジェクト: ml-lab/TerpreT
 def get_variable_domain(node):
     """Return the number of values the variable *node* ranges over.

     *node* may be a Name, a plain string, or a Subscript.  For a subscript,
     the domain of its base name is used when known; otherwise the domain
     recorded for the whole unparsed expression is used.  When no domain is
     known, ``error.fatal_error`` is called.
     """
     # ``variable_domains`` and ``error`` come from the enclosing scope.
     # Look up the number of values to switch on.
     if isinstance(node, ast.Name):
         return variable_domains[node.id]
     if isinstance(node, str):
         return variable_domains[node]
     if isinstance(node, ast.Subscript):
         # The base of a subscript need not be a Name (``a[i][j]``, ``o.a[i]``).
         # Such bases have no ``id`` and used to raise AttributeError here.
         base_name = getattr(node.value, "id", None)
         if base_name is not None and base_name in variable_domains:
             return variable_domains[base_name]
         node_name = astunparse.unparse(node).rstrip()
         if node_name in variable_domains:
             return variable_domains[node_name]
     error.fatal_error("No variable domain known for expression '%s', for which we want to unroll a for/with." % (astunparse.unparse(node).rstrip()), node)
コード例 #29
0
    def create_module(self, module, prefix, url, members = None, imports= None):
        """Generate an ontology module and write its source to the path *module*.

        *prefix* is normalised (``-`` replaced by ``_``), and its upper-cased
        form becomes the class name.  The module AST comes from
        ``create_ont_module`` and is written out via ``astunparse``.
        """
        prefix = str(prefix).replace('-', '_')
        class_name = prefix.upper()
        # Same output as the former Python 2 ``print "..." + module, class_name``
        # statement, in a form that also runs on Python 3.
        print("module create:" + module + " " + class_name)
        # Named ``module_ast`` to avoid shadowing the ``ast`` module.
        module_ast = create_ont_module(class_name, prefix, url, imports, members)
        code = astunparse.unparse(module_ast)

        # ``with`` closes the file even if the write fails.
        with open(module, "w") as out:
            out.write(code)
コード例 #30
0
ファイル: func.py プロジェクト: esben/xd-build-core
 def set_source(self, source, function=None):
     """Store normalised source text for *function* (default ``self.value``).

     *source* is either an ``ast.FunctionDef`` or a string that parses to a
     module with exactly one statement, which must be a function definition.
     The definition is unparsed, stripped of surrounding newlines and stored
     in ``self.source[function]``.  Violations fail an ``assert``.
     """
     target = self.value if function is None else function
     assert isinstance(target, types.FunctionType)
     if isinstance(source, str):
         parsed = ast.parse(source)
         assert isinstance(parsed, ast.Module)
         assert isinstance(parsed.body, list) and len(parsed.body) == 1
         source = parsed.body[0]
     assert isinstance(source, ast.FunctionDef)
     self.source[target] = astunparse.unparse(source).strip('\n')
コード例 #31
0
def t3st_unparse():
    """Print the regenerated source and the AST dump of ``test_unparse_ast``."""
    source_text = inspect.getsource(test_unparse_ast)
    print(astunparse.unparse(ast.parse(source_text)))
    # get a pretty-printed dump of the AST
    print(astunparse.dump(ast.parse(source_text)))
コード例 #32
0
def test_lambda_copy_nested_captured():
    """The outer lambda's argument is renamed, including where the inner lambda uses it."""
    _, copied = util_run_parse("lambda b: (lambda a: a+b)")
    expected = "(lambda arg_0: (lambda a: (a + arg_0)))"
    assert unparse(copied).strip() == expected
コード例 #33
0
 def check_test(test):
     """Validate an ``if`` test of the form ``identifier == numeric_constant``.

     Each violated rule appends a ``CheckError`` to ``self.__messages``
     (``self`` and ``node`` come from the enclosing scope, not visible here).
     Returns ``(identifier, constant)``.  A component whose node lacks the
     expected attribute is returned as None rather than raising
     AttributeError after the error has already been recorded.
     """
     if not isinstance(test.ops[0], ast.Eq):
         self.__messages.append(CheckError("Tests can only use ==, not '%s'." % (astunparse.unparse(test.ops[0]).rstrip()), node))
     if not isinstance(test.left, ast.Name):
         self.__messages.append(CheckError("Tests have to have identifier, not '%s' as left operand." % (astunparse.unparse(test.left).rstrip()), node))
     if len(test.comparators) != 1:
         self.__messages.append(CheckError("Tests cannot have multiple comparators in test '%s'." % (astunparse.unparse(test).rstrip()), node))
     if test.comparators and not isinstance(test.comparators[0], ast.Num):
         self.__messages.append(CheckError("Tests have to have constant, not '%s' as right operand." % (astunparse.unparse(test.comparators[0]).rstrip()), node))
     # The checks above only record errors.  Read the operands defensively so
     # a malformed test is reported instead of crashing here.
     left_name = getattr(test.left, "id", None)
     right_value = getattr(test.comparators[0], "n", None) if test.comparators else None
     return (left_name, right_value)
コード例 #34
0
def unparse(node):
    """Return the source of AST *node* with surrounding whitespace removed.

    Returns None when *node* is None.
    """
    return None if node is None else astunparse.unparse(node).strip()
コード例 #35
0
 def __call__(self, formula):
     """Parse *formula*, run ``self.visitor`` over the tree, and return the stripped source."""
     tree = ast.parse(formula)
     self.visitor.visit(tree)
     regenerated = unparse(tree)
     return regenerated.strip()
コード例 #36
0
ファイル: __init__.py プロジェクト: simhaonline/blockexplorer
 def to_source(self):
     """Return Python source regenerated from ``self.parsed`` by astunparse."""
     source_text = astunparse.unparse(self.parsed)
     return source_text
コード例 #37
0
ファイル: __init__.py プロジェクト: wangkela/WenyanLanguage
def ast_build(exprs):
    """Wrap the statement list *exprs* in a ``Module`` and return its source."""
    # Named ``module_node`` so the ``ast`` module name is not shadowed.
    module_node = Module(body=exprs)
    return unparse(module_node)
コード例 #38
0
            return node

        return self.prepend_print_line(node)

    def visit_If(self, node):
        """Transform the children of *node* (if any), then prepend a print line."""
        if len(node.children) == 0:
            return self.prepend_print_line(node)
        return self.prepend_print_line(self.generic_visit(node))

    # def visit_Expr(self, node):
    #     # if isinstance(node, ast.BinOp):
    #     #     return self.append_print_line(node)
    #
    #     return node


# Annotate parent/child links, instrument the tree, then emit the result both
# to the console and to my-instrumented-program.py.
tree = ParentChildNodeTransformer().visit(tree)
print("***************** New Tree ****************")

MyTransformer().visit(tree)

ast.fix_missing_locations(tree)

# Unparse once and reuse the text for both outputs.
instrumented_source = astunparse.unparse(tree)
print(instrumented_source)

# ``out_file`` avoids shadowing the Python 2 ``file`` builtin.
with open('my-instrumented-program.py', 'w') as out_file:
    out_file.write(instrumented_source)

print("Done and Bye!")
コード例 #39
0
            args=(
                self.visit(node.targets[0]), 
                self.visit(node.value),
            ),
            keywords=(),
            starargs=(),
            kwargs=(),
        ), node)
    
    def visit_While(self, node):
        """Rewrite a ``while`` loop as a ``tf.while_loop(cond, body)`` call node.

        The condition becomes a zero-argument lambda returning the visited
        test.  The body is passed as the list of visited statements.  The
        source location is copied from *node*.  The callee is a dotted name
        stored in a single ``Name``, so the result is presumably meant for
        unparsing rather than compilation.
        """
        condition = Lambda(
            args=arguments(args=[], defaults=[], vararg=[], kwarg=[]),
            body=self.visit(node.test),
        )
        # map() is lazy on Python 3: the body statements were never visited
        # and the map object could not be unparsed.  Build a real list.
        body = [self.visit(stmt) for stmt in node.body]
        return copy_location(Call(
            func=Name(id="tf.while_loop", ctx=Load()),
            args=(condition, body),
            keywords=(),
            starargs=(),
            kwargs=(),
        ), node)


# Parse ``txt``, apply RewriteName, and print the regenerated source.
myast = RewriteName().visit(ast.parse(txt))
print(astunparse.unparse(myast))
コード例 #40
0
def unparse(something):
    """Return astunparse's source for *something* without trailing newlines."""
    source_text = astunparse.unparse(something)
    return source_text.rstrip("\n")
コード例 #41
0
            newName = self.get_variable_name(len(self.seen))
            self.seen[oldName] = newName
            return newName


if __name__ == "__main__":
    # Usage: <script> INPUT.py OUTPUT.py
    with open(sys.argv[1], encoding="utf-8") as f:
        tree = ast.parse(f.read())

    obfuscator = Obfuscator()
    obfuscator.visit(tree)

    random.seed('𝒙')
    shuffled = list(range(len(obfuscator.constants)))
    random.shuffle(shuffled)

    # shuffled_constants[i] = constants[j] where shuffled[j] == i.  Filling
    # through the inverse permutation is O(n); the old comprehension called
    # shuffled.index(i) for every i, which is O(n^2).
    shuffled_constants = [None] * len(obfuscator.constants)
    for j, i in enumerate(shuffled):
        shuffled_constants[i] = obfuscator.constants[j]

    out = """
import random as X
X.seed('𝒙')
𝒙 = [""" + ",".join(shuffled_constants) + """]
X.shuffle(𝒙)
""" + astunparse.unparse(tree)

    # The generated header contains a non-ASCII identifier, so write UTF-8
    # explicitly instead of relying on the platform's default encoding.
    with open(sys.argv[2], "w", encoding="utf-8") as f:
        f.write(out)
コード例 #42
0
def json_pprint(data, stream=None):
    """Write *data* as JSON re-formatted through a Python AST round-trip.

    The JSON text is parsed as a Python expression and regenerated with
    ``astunparse``.  The regex then removes each ``,`` + newline +
    indentation that precedes a closing ``]``/``}`` followed by a
    non-space character.

    :param stream: writable text stream.  When omitted or None, ``sys.stdout``
        is looked up at call time rather than bound at import time, so
        later redirection (e.g. test capture) is honoured.
    """
    if stream is None:
        stream = sys.stdout
    s1 = json.dumps(data)
    tree = ast.parse(s1)
    s2 = astunparse.unparse(tree)
    s3 = re.sub(r',\n\s+([\]}][^ ])', r'\1', s2)
    stream.write(s3)
コード例 #43
0
ファイル: forktrans.py プロジェクト: vrthra/muforks.py
def forking_transform(src):
    """Apply ``ForkingTransformer`` to Python source *src* and return the new source."""
    tree = ast.parse(src)
    transformed = ForkingTransformer().visit(tree)
    return astunparse.unparse(transformed)
コード例 #44
0
ファイル: test_unparse.py プロジェクト: tszdanger/astunparse
 def check_roundtrip(self, code1, filename="internal", mode="exec"):
     """Assert that re-parsing the unparsed AST of *code1* gives an equal AST."""
     original_tree = compile(str(code1), filename, mode, ast.PyCF_ONLY_AST)
     regenerated = astunparse.unparse(original_tree)
     reparsed_tree = compile(regenerated, filename, mode, ast.PyCF_ONLY_AST)
     self.assertASTEqual(original_tree, reparsed_tree)
コード例 #45
0
def dfs(a):
    """Recursively rewrite a PyTorch AST (sub)tree *a* for Jittor, mostly in place.

    Returns one of:

    * a replacement node for *a*;
    * the string ``'delete'`` when *a* should be dropped from its parent's
      list;
    * None to keep *a* (its children may have been rewritten in place).

    Relies on names defined elsewhere in the module: ``import_flag``,
    ``unsupport_ops``, ``raise_unsupport``, ``pjmap``, ``convert_`` and
    ``replace``.
    """
    if isinstance(a, ast.Import):
        import_src = astunparse.unparse(a)
        if 'torch' in import_src and 'init' in import_src:
            import_flag.append('init')
            return ast.parse('from jittor import init').body[0]
        if 'torch' in import_src and 'nn' in import_src:
            import_flag.append('nn')
            return ast.parse('from jittor import nn').body[0]
        if a.names[0].name == 'torch':
            return 'delete'
    elif isinstance(a, ast.ImportFrom):
        # ``from . import x`` has module None; guard the substring test.
        if 'torch' in (a.module or ''):
            return 'delete'
    elif isinstance(a, ast.Call):
        for idx, ag in enumerate(a.args):
            ret = dfs(ag)
            if ret is not None:
                a.args[idx] = ret
        for idx, kw in enumerate(a.keywords):
            ret = dfs(kw)
            if ret is not None:
                a.keywords[idx] = ret
        func_src = astunparse.unparse(a.func).strip('\n')
        func = func_src.split('.')
        prefix = '.'.join(func[0:-1])
        func_name = func[-1]
        if func_name in unsupport_ops:
            raise_unsupport(func_name)
        if func_name in pjmap:
            ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
            kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
            ret = convert_(prefix, func_name, ags, kws)
            return ast.parse(ret).body[0].value
        # Only rename when the called attribute itself is load_state_dict.  A
        # substring match also renamed e.g. ``.cuda`` in
        # ``m.load_state_dict(sd).cuda()``.
        if isinstance(a.func, ast.Attribute) and a.func.attr == 'load_state_dict':
            a.func.attr = 'load_parameters'
        if func_src.endswith(".size"):
            # Rewrite only the trailing ``.size``.  The old split('.size') /
            # replace('size', ...) also mangled earlier names containing
            # "size" (e.g. ``x_size.size()``).
            base_src = func_src[:-len(".size")]
            ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
            if len(ags) != 0:
                con = base_src + '.shape[' + ','.join(ags) + ']'
            else:
                con = base_src + '.shape'
            return ast.parse(con).body[0].value
    elif isinstance(a, ast.Expr):
        pass
    elif isinstance(a, (ast.Attribute, ast.Name)):
        replace(a)
    elif isinstance(a, ast.FunctionDef):
        if a.name == 'forward':
            a.name = 'execute'
    if hasattr(a, '__dict__'):
        for k in a.__dict__.keys():
            value = a.__dict__[k]
            if isinstance(value, list):
                # Build the filtered list in a single pass.  The old code
                # deleted from the list while enumerating it, which skipped
                # the element after each deletion and misaligned its
                # keep/delete flags with the remaining items.
                kept = []
                for child in value:
                    ret = dfs(child)
                    if isinstance(ret, str) and ret == 'delete':
                        continue
                    kept.append(child if ret is None else ret)
                a.__dict__[k] = kept
            else:
                ret = dfs(value)
                if ret is not None:
                    a.__dict__[k] = ret
コード例 #46
0
def ast_construct(astbody, parent):
    """Recursively mirror the statements under *astbody* into a custom ``Node`` tree.

    Each direct child of *astbody* (from ``ast.iter_child_nodes``) that is an
    Assign, Break, Expr, AugAssign, Return, If, FunctionDef, ClassDef, For or
    While statement becomes a ``Node(label, value)`` attached to *parent* via
    ``parent.insert_child``; every other child node type is skipped. Compound
    statements recurse: their statement lists are unparsed to source,
    re-parsed into a fresh module and walked with the new node as parent.

    ``pattern`` is a module-level regex (not shown here) whose matches are
    removed from every unparsed value — presumably whitespace/newlines;
    TODO confirm.

    Args:
        astbody: AST node whose children are converted (typically an
            ``ast.Module``).
        parent: ``Node`` that receives the generated children; it is mutated
            in place.

    Returns:
        None.
    """
    for child in ast.iter_child_nodes(astbody):
        # Assignment: the source of the whole statement is the node value.
        if isinstance(child, ast.Assign):
            label = 'Assign'
            value = astunparse.unparse(child)
            value = pattern.sub('', value)
            assign = Node(label, value)
            parent.insert_child(assign)

        if isinstance(child, ast.Break):
            label = 'Break'
            value = 'break'
            n_break = Node(label, value)
            parent.insert_child(n_break)

        # Expression statements: bare calls get a dedicated 'Call' label and
        # store only the call expression; anything else stores the statement.
        if isinstance(child, ast.Expr):
            if isinstance(child.value, ast.Call):
                label = 'Call'
                value = astunparse.unparse(child.value)
                value = pattern.sub('', value)
                call = Node(label, value)
                parent.insert_child(call)

            else:
                label = 'Expr'
                value = astunparse.unparse(child)
                value = pattern.sub('', value)
                node_expr = Node(label, value)
                parent.insert_child(node_expr)

        if isinstance(child, ast.AugAssign):
            label = 'AugAssign'
            value = astunparse.unparse(child)
            value = pattern.sub('', value)
            augassign = Node(label, value)
            parent.insert_child(augassign)

        if isinstance(child, ast.Return):
            n_return = Node('Return', 'return')
            parent.insert_child(n_return)
            # If there is a return value, add one 'Return_Value' child per
            # comma-separated piece, after dropping one enclosing pair of
            # parentheses when the text both starts with '(' and ends with ')'.
            # NOTE(review): the split on ',' also splits commas nested inside
            # calls or literals (e.g. `return f(a, b)`); confirm this
            # granularity is intended.
            if child.value:
                return_value = astunparse.unparse(child.value)
                return_value = pattern.sub('', return_value)
                if return_value[0] == '(' and return_value[-1] == ')':
                    return_value = return_value[1:-1]
                return_value = return_value.split(',')
                for item in return_value:
                    node_return_value = Node('Return_Value', item)
                    n_return.insert_child(node_return_value)

        if isinstance(child, ast.If):
            label = 'If'
            value = astunparse.unparse(child.test)  # The If node's value is its test condition
            value = pattern.sub('', value)
            node_if = Node(label, value)
            parent.insert_child(node_if)

            # If the If has an orelse block, create 'Then' and 'Else' children of
            # the If node; otherwise the body's statements are inserted directly
            # as children of the If node.
            if child.orelse:
                node_then = Node('Then', value)  # Nodes without a value of their own reuse the parent's value
                node_if.insert_child(node_then)
                source_then = astunparse.unparse(
                    child.body)  # If.body is a list and cannot be walked directly: unparse it to source, then re-parse into a fresh AST
                ast_then = ast.parse(source_then)
                ast_construct(ast_then, node_then)

                node_else = Node('Else', value)
                node_if.insert_child(node_else)
                source_else = astunparse.unparse(child.orelse)
                ast_else = ast.parse(source_else)
                ast_construct(ast_else, node_else)
            else:
                source_body = astunparse.unparse(
                    child.body)  # If.body is a list and cannot be walked directly: unparse it to source, then re-parse into a fresh AST
                ast_then = ast.parse(source_body)
                ast_construct(ast_then, node_if)

        if isinstance(child, ast.FunctionDef):
            value = child.name
            value = pattern.sub('', value)
            label = 'FunctionDef'
            n_functiondef = Node(label, value)
            parent.insert_child(n_functiondef)
            # Handle the FunctionDef's parameters. The unparsed argument list is
            # split on ',' and each piece is classified as a positional arg,
            # *vararg, keyword-only arg, **kwarg, default or keyword-only default.
            # Once a '*' piece has been seen, later pieces are treated as
            # keyword-only.
            # NOTE(review): this substring-based classification misfires for a
            # bare '*' separator (an empty vararg name is recorded, assuming
            # `pattern` keeps '*'), and for defaults containing '*' or ','.
            # decorator_list is not processed here.
            source_args = astunparse.unparse(child.args)
            source_args = pattern.sub('', source_args)
            args = list()  # positional args
            vararg = list()  # variable-length positional *args
            kwonlyargs = list()  # keyword-only args (no default)
            kwargs = list()  # variable-length keyword **kwargs
            defaults = list()
            kw_defaults = list()
            varargs_flag = 0
            if source_args != "":
                source_args = source_args.split(',')
                for item in source_args:
                    if '**' in item:
                        kwargs.append(item[2:])
                    elif '*' in item:
                        vararg.append(item[1:])
                        varargs_flag = 1
                    elif '=' in item and varargs_flag == 1:
                        kw_defaults.append(item)
                    elif '=' not in item and varargs_flag == 1:
                        kwonlyargs.append(item)
                    elif '=' in item:
                        defaults.append(item)
                    else:
                        args.append(item)

            for arg in args:
                arg_node = Node('args', arg)
                n_functiondef.insert_child(arg_node)

            for arg in vararg:
                vararg_node = Node('vararg', arg)
                n_functiondef.insert_child(vararg_node)

            for arg in kwonlyargs:
                # NOTE(review): label 'kwnolyarg' looks like a typo of 'kwonlyarg';
                # left as-is because consumers of the tree may match on it.
                kwonlyargs_node = Node('kwnolyarg', arg)
                n_functiondef.insert_child(kwonlyargs_node)

            for arg in kwargs:
                kwargs_node = Node('kwargs', arg)
                n_functiondef.insert_child(kwargs_node)

            for arg in defaults:
                default_node = Node('default_args', arg)
                n_functiondef.insert_child(default_node)

            for arg in kw_defaults:
                kw_defaults_node = Node('kw_default', arg)
                n_functiondef.insert_child(kw_defaults_node)

            # Handle the function body (unparse + re-parse, as for If bodies).
            source_functionbody = astunparse.unparse(child.body)
            ast_functionbody = ast.parse(source_functionbody)
            ast_construct(ast_functionbody, n_functiondef)

        if isinstance(child, ast.ClassDef):
            class_name = child.name
            label = 'ClassDef'
            node_classdef = Node(label, class_name)
            parent.insert_child(node_classdef)
            # Handle the base classes: all base expressions become one 'base' child.
            bases = astunparse.unparse(
                child.bases)  # only the base expressions are considered; keywords, starargs etc. are ignored for now
            bases = pattern.sub('', bases)
            if bases != '':
                node_bases = Node('base', bases)
                node_classdef.insert_child(node_bases)

            # Handle the class body (unparse + re-parse).
            source_classbody = astunparse.unparse(child.body)
            ast_classbody = ast.parse(source_classbody)
            ast_construct(ast_classbody, node_classdef)

        if isinstance(child, ast.For):
            label = 'For'
            value = astunparse.unparse(child.iter)  # The For node's value is the iterable expression (choice still under consideration)
            value = pattern.sub('', value)
            node_for = Node(label, value)
            parent.insert_child(node_for)
            # If the For loop has an else part, create 'Then' and 'Else' children
            # of node_for; otherwise the loop body's statements are inserted
            # directly as children of the For node.
            if child.orelse:
                # Handle For.body
                node_then = Node('Then', value)
                node_for.insert_child(node_then)
                source_then = astunparse.unparse(child.body)
                ast_then = ast.parse(source_then)
                ast_construct(ast_then, node_then)
                node_else = Node('Else', value)
                node_for.insert_child(node_else)
                source_else = astunparse.unparse(child.orelse)
                ast_else = ast.parse(source_else)
                ast_construct(ast_else, node_else)
            else:
                source_forbody = astunparse.unparse(child.body)
                ast_forbody = ast.parse(source_forbody)
                ast_construct(ast_forbody, node_for)

        if isinstance(child, ast.While):
            # NOTE(review): unlike If/For, a While's orelse block is not converted.
            label = 'While'
            value = astunparse.unparse(child.test)
            value = pattern.sub('', value)
            node_while = Node(label, value)
            parent.insert_child(node_while)
            source_whilebody = astunparse.unparse(child.body)
            ast_whilebody = ast.parse(source_whilebody)
            ast_construct(ast_whilebody, node_while)
def _replace_class_methods(tree, class_names, replacements):
    """Swap methods of selected top-level classes in *tree* for replacement nodes.

    Args:
        tree: Parsed module whose top-level statements are scanned.
        class_names: Names of the classes whose methods may be replaced.
        replacements: Mapping from method name to the AST node substituted
            for a method with that name.

    Returns:
        True if at least one method was replaced, otherwise False.
    """
    replaced = False
    for stmt in tree.body:
        if not (isinstance(stmt, _ast.ClassDef) and stmt.name in class_names):
            continue
        # Assigning by position keeps the list length, so enumerating while
        # replacing is safe.
        for pos, member in enumerate(stmt.body):
            if isinstance(member, _ast.FunctionDef) and member.name in replacements:
                stmt.body[pos] = replacements[member.name]
                replaced = True
    return replaced


def _overwrite_with_tree(handle, tree):
    """Replace the entire content of the open file *handle* with the source of *tree*."""
    # Unparse first so a failure leaves the file untouched.
    source = astunparse.unparse(tree)
    handle.seek(0, 0)
    handle.truncate()
    handle.write(source)


def replace_validator():
    """Patch framework sources with the stub methods from validator_and_others.py.

    Methods of ``HostValidator``/``ScriptValidator`` (validator.py), ``Build``
    (host.py), ``Report`` (report.py) and ``ExecutionEngine``
    (execution_engine.py) whose names appear in ``validator_map_dict`` are
    replaced by the corresponding nodes taken from the fake validator module.
    validator.py is always rewritten; the other files are rewritten only when
    at least one method was replaced. All paths are relative to the module
    global ``root``.
    """
    with open(root+'/Framework_kernel/validator.py', 'r+', encoding='utf-8') as f1, \
            open(root+'/Framework_Performance/validator_and_others.py', 'r', encoding='utf-8') as f2, \
            open(root+'/Framework_kernel/host.py', 'r+', encoding='utf-8') as f3, \
            open(root+'/Framework_kernel/report.py', 'r+', encoding='utf-8') as f4, \
            open(root+'/Framework_kernel/execution_engine.py', 'r+', encoding='utf-8') as f5:
        validator_node = ast.parse(f1.read())
        # The fake validator module is only read, so it is opened read-only.
        fake_validator_node = ast.parse(f2.read())
        host_node = ast.parse(f3.read())
        report_node = ast.parse(f4.read())
        execution_engine_node = ast.parse(f5.read())
        # NOTE(review): positions in the fake module are hard-coded; reordering
        # validator_and_others.py silently maps the wrong nodes.
        validator_map_dict = {
            'validate_jenkins_server': fake_validator_node.body[5].body[0],
            'validate_build_server': fake_validator_node.body[5].body[1],
            '__validate_QTP': fake_validator_node.body[5].body[2],
            '__validate_HPDM': fake_validator_node.body[5].body[3],
            'validate_uut': fake_validator_node.body[5].body[4],
            'validate_ftp': fake_validator_node.body[5].body[5],
            'validate': fake_validator_node.body[6].body[0],
            'build_task': fake_validator_node.body[-3],
            '__load_uut_result': fake_validator_node.body[-2],
            'send_report': fake_validator_node.body[-1],
        }
        # skip validate
        _replace_class_methods(
            validator_node, ('HostValidator', 'ScriptValidator'), validator_map_dict)
        _overwrite_with_tree(f1, validator_node)
        # host skip generate script yml
        if _replace_class_methods(host_node, ('Build',), validator_map_dict):
            _overwrite_with_tree(f3, host_node)
        # report cancel error handler
        if _replace_class_methods(report_node, ('Report',), validator_map_dict):
            _overwrite_with_tree(f4, report_node)
        # execution_engine cancel email
        if _replace_class_methods(
                execution_engine_node, ('ExecutionEngine',), validator_map_dict):
            _overwrite_with_tree(f5, execution_engine_node)
コード例 #48
0
 def save(self):
     """Render ``self._source_tree`` back to source and overwrite ``self.file`` with it."""
     # Unparse before opening so a rendering error does not truncate the file.
     rendered_source = astunparse.unparse(self._source_tree)
     with open(self.file, 'w') as target:
         target.write(rendered_source)
コード例 #49
0
        print("set arg to 'hello world'")

        args[0] = ast.Constant(s="hello world", kind=None)
        ast.fix_missing_locations(args[0])
        print("add another arg")

        last_index = len(args)
        print("last index: ", last_index)

        args.insert(last_index, ast.Constant(s="!!!!!", kind=None))
        ast.fix_missing_locations(args[last_index])

        print("add test call one more time.")
        node = ast.parse(add_arg).body[0].value
        args.insert(last_index + 1, node)
        ast.fix_missing_locations(args[last_index + 1])


# Walk `tree` with the Visitor defined above; its visit code edits call
# arguments of the tree in place.
v = Visitor()

v.visit(tree)

# Compile the modified tree and run it. "log.log" is only the filename recorded
# in the resulting code object (shown in tracebacks).
co = compile(tree, "log.log", "exec")

exec(co)

import astunparse

# Print the modified tree as source text.
print(astunparse.unparse(tree))
コード例 #50
0
 def err():
     """Raise ``ValueError`` naming the slice of the enclosing scope's ``subscript``."""
     rendered = astunparse.unparse(subscript)
     raise ValueError('Unexpected kind of slice: {}'.format(rendered))
コード例 #51
0
 def visit_Assign(self, state, assgn):
     """Check one assignment statement and return the (possibly updated) analysis state.

     *state* is a ``(parameters, initialised, used)`` triple of variable-name
     sets. Only ``<name> = ...`` assignments are analysed:

     * ``x = Var(n)`` / ``Param(n)`` / ``Input(n)`` / ``Output(n)`` pass the
       literal range ``n`` to ``self.__set_domain``. ``Param`` adds ``x`` to
       the parameters and ``Input`` adds it to the initialised set. ``Output``
       appends it to ``self.__outputs`` and adds it to the used set.
     * ``x = f(...)`` for a function in ``self.__defined_functions`` visits the
       arguments and marks ``x`` initialised.
     * Assigning a number or another expression to a variable for which
       ``self.__is_declared`` is true is reported.

     Problems are appended as ``CheckError`` to ``self.__messages``; whenever no
     updated triple is returned, *state* is returned unchanged.
     """
     (parameters, initialised, used) = state
     if len(assgn.targets) > 1:
         # More than one target means a chained assignment (a = b = ...); only
         # the first target is analysed below.
         self.__messages.append(CheckError("Cannot process tuple assignment in '%s'." % (astunparse.unparse(assgn).rstrip()), assgn))
     target = assgn.targets[0]
     value = assgn.value
     if isinstance(target, ast.Name):
         assignedVar = target.id
         if isinstance(value, ast.Num):
             if self.__is_declared(assignedVar):
                 self.__messages.append(CheckError("Trying to assign num literal '%i' to model variable '%s'." % (value.n, assignedVar), assgn))
         elif isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
             if (value.func.id in ["Var", "Param", "Input", "Output"]):
                 # The declared range is the first positional argument of the call
                 # (value.func is an ast.Name and has no arguments of its own).
                 range_arg = value.args[0] if value.args else None
                 if isinstance(range_arg, ast.Num):
                     self.__set_domain(assgn, assignedVar, range_arg.n)
                 elif range_arg is None:
                     self.__messages.append(CheckError("Cannot declare variable/parameter '%s' without a range." % (assignedVar), assgn))
                 else:
                     self.__messages.append(CheckError("Cannot declare variable/parameter with non-constant range '%s'." % (astunparse.unparse(range_arg).rstrip()), assgn))
                 if value.func.id == "Param":
                     return (parameters | {assignedVar}, initialised, used)
                 if value.func.id == "Input":
                     return (parameters, initialised | {assignedVar}, used)
                 if value.func.id in ["Output"]:
                     self.__outputs.append(assignedVar)
                     return (parameters, initialised, used | {assignedVar})
             elif value.func.id in self.__defined_functions:
                 for argument in value.args:
                     (_, _, used) = self.visit((parameters, initialised, used), argument)
                 return (parameters, initialised | {assignedVar}, used)
             else:
                 self.__messages.append(CheckError("Cannot assign unknown function to variable '%s'." % (assignedVar), assgn))
         else:
             if self.__is_declared(assignedVar):
                 self.__messages.append(CheckError("Trying to assign '%s' to model variable '%s'." % (astunparse.unparse(value).rstrip(), assignedVar), assgn))
     else:
         self.__messages.append(CheckError("Cannot assign value to non-variable '%s'." % (astunparse.unparse(target).rstrip()), assgn))
     return state
コード例 #52
0
 def generic_visit(self, state, node):
     """Report *node* as unsupported and pass *state* through unchanged.

     Returning *state* (instead of the implicit ``None``) keeps callers that
     unpack the ``(parameters, initialised, used)`` triple from the visit result
     working after the error has been recorded.
     """
     self.__messages.append(CheckError("AST node '%s' unsupported." % (astunparse.unparse(node).strip()), node))
     return state
コード例 #53
0
        def visit_If(self, node):
            """Normalise an if/elif/else chain into explicit equality cases.

            The test is visited and passed to ``eval_const_expressions``; if that
            no longer yields an ``ast.If``, the result (a node or a list of nodes)
            is visited and returned instead. Otherwise tests and bodies of the
            chain are visited, and every ``or`` in a test is expanded into
            separate ``elif`` branches that each get a copy of the body.

            When every test has the form ``<var> == <number>`` (or
            ``<number> == <var>``) on a single variable, a non-empty trailing
            ``else`` is replaced by one ``elif <var> == k`` branch (with a copy of
            the else body) per value ``k`` of the variable's domain that no test
            covered. Tests on several variables raise ``Exception``; any other
            test shape leaves the chain as it is after visiting.
            """
            node.test = self.visit(node.test)
            node = eval_const_expressions(node)
            if not(isinstance(node, ast.If)):
                # The if was folded away into a node or a list of nodes.
                if isinstance(node, list):
                    return flat_map(self.visit, node)
                else:
                    return self.visit(node)

            # We want to unroll the "else: P" bit into something explicit for
            # all the cases that haven't been checked explicitly yet.
            def get_or_cases(test):
                """Flatten nested ``or`` expressions into their operands."""
                if not(isinstance(test, ast.BoolOp) and isinstance(test.op, ast.Or)):
                    return [test]
                else:
                    return itertools.chain.from_iterable(get_or_cases(value) for value in test.values)

            def get_name_and_const_from_test(test):
                """Return ``(var, number)`` for a ``var == number`` test, else ``(None, None)``."""
                # Only equality comparisons can be matched against constant cases.
                if not isinstance(test, ast.Compare) or not isinstance(test.ops[0], ast.Eq):
                    return (None, None)
                (name, const) = (None, None)
                if (isinstance(test.left, ast.Name) or isinstance(test.left, ast.Subscript)) and isinstance(test.comparators[0], ast.Num):
                    (name, const) = (test.left, test.comparators[0].n)
                elif (isinstance(test.comparators[0], ast.Name) or isinstance(test.comparators[0], ast.Subscript)) and isinstance(test.left, ast.Num):
                    (name, const) = (test.comparators[0], test.left.n)
                return (name, const)

            checked_values = set()
            checked_vars = set()

            # Now walk the .orelse branches, visiting each body independently. Expand ors on the way:
            last_if = None
            current_if = node
            while True:
                # Recurse with visitor first:
                current_if.test = eval_const_expressions(self.visit(current_if.test))
                current_if.body = flat_map(self.visit, current_if.body)
                # Now, unfold ors:
                test_cases = list(get_or_cases(current_if.test))
                if len(test_cases) > 1:
                    else_body = current_if.orelse
                    new_if_node = None
                    for test_case in reversed(test_cases):
                        body_copy = copy.deepcopy(current_if.body)
                        new_if_node = ast.copy_location(ast.If(test=test_case,
                                                               body=body_copy,
                                                               orelse=else_body),
                                                        node)
                        else_body = [new_if_node]
                    current_if = new_if_node
                    # Make the change stick:
                    if last_if is None:
                        node = current_if
                    else:
                        last_if.orelse = [current_if]

                # Do our deed: record which (variable, value) this branch checks.
                try:
                    (checked_var, checked_value) = get_name_and_const_from_test(current_if.test)
                except Exception:
                    # Unexpected test shape: just leave things as they are.
                    return node
                if checked_var is None:
                    # Not a "<var> == <number>" test, so the else case cannot be
                    # identified cleanly. Bail out here: letting None into
                    # checked_vars would break the unparse of the variables below.
                    return node
                checked_vars.add(checked_var)
                checked_values.add(checked_value)
                # Look at the next elif:
                if len(current_if.orelse) == 1 and isinstance(current_if.orelse[0], ast.If):
                    last_if = current_if
                    current_if = current_if.orelse[0]
                else:
                    break

            # We need to stringify them, to not be confused by several instances refering to the same thing:
            checked_var_strs = set(astunparse.unparse(var) for var in checked_vars)
            if len(checked_var_strs) != 1:
                raise Exception("If-else checking more than one variable (%s)." % (checked_var_strs))
            checked_var = checked_vars.pop()
            domain_to_check = set(range(get_variable_domain(checked_var)))
            still_unchecked = domain_to_check - checked_values

            else_body = flat_map(self.visit, current_if.orelse)

            if len(else_body) == 0:
                return node

            # Replace the trailing else by one explicit "var == value" branch per
            # value that no test covered, each with its own copy of the else body.
            for value in still_unchecked:
                var_node = copy.deepcopy(checked_var)
                eq_node = ast.copy_location(ast.Eq(), node)
                value_node = ast.copy_location(ast.Num(n=value), node)
                test_node = ast.copy_location(ast.Compare(var_node, [eq_node], [value_node]), node)
                case_body = copy.deepcopy(else_body)
                new_if_node = ast.copy_location(ast.If(test=test_node, body=case_body, orelse=[]), node)
                current_if.orelse = [new_if_node]
                current_if = new_if_node

            return node
コード例 #54
0
 def __get_domain_of_expr(self, node):
     """Return the domain size of the value expression *node*.

     * A variable name yields ``self.__get_domain`` for that name.
     * A call to a function in ``self.__defined_functions`` yields that
       function's value domain.
     * A numeric literal ``n`` yields ``n + 1``: domains are zero-based, so
       ``n + 1`` is the smallest domain containing ``n``.

     Anything else, including calls to functions that are not defined or whose
     callee is not a plain name, is reported as a ``CheckError`` and yields 0.
     """
     if isinstance(node, ast.Name):
         return self.__get_domain(node, node.id)
     if isinstance(node, ast.Num):
         return node.n + 1
     if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
         func_information = self.__defined_functions.get(node.func.id)
         if func_information is not None:
             (_, value_domain) = func_information
             return value_domain
     self.__messages.append(CheckError("Cannot determine domain of value '%s' used in assignment." % (astunparse.unparse(node).rstrip()), node))
     return 0
コード例 #55
0
def test_lambda_copy_simple():
    """A one-argument lambda's parameter is renamed to ``arg_0`` in the copy, whose dump differs."""
    original_tree, copied_tree = util_run_parse("lambda a: a")
    expected_source = "(lambda arg_0: arg_0)"
    assert unparse(copied_tree).strip() == expected_source
    assert ast.dump(copied_tree) != ast.dump(original_tree)
コード例 #56
0
 def visit_Call(self, state, call):
     """Check a call and return the (possibly updated) analysis state.

     *state* is a ``(parameters, initialised, used)`` triple of variable-name
     sets.

     * Method calls on a model variable update that state:
       ``x.set_to(v)`` / ``x.set_to_constant(v)`` initialise ``x`` and check the
       domain of ``v``; ``x.set_as_input()`` initialises ``x``;
       ``x.set_as_output()`` marks it used; ``x.observe_value(v)`` marks it used
       and requires it to be initialised.
     * Calls to functions in ``self.__defined_functions`` check the arity and
       every argument's domain against the parameter domains.

     Problems are appended to ``self.__messages``. On error paths that cannot
     produce an updated triple, *state* is returned unchanged.
     """
     (parameters, initialised, used) = state
     if isinstance(call.func, ast.Attribute):
         func = call.func
         func_name = func.attr
         if not isinstance(func.value, ast.Name):
             # e.g. a.b.set_to(1) or f().set_to(1): the receiver is not a model variable.
             self.__messages.append(CheckError("Unsupported call '%s'." % (astunparse.unparse(call).rstrip()), call))
             return state
         set_variable = func.value.id
         if func_name in ["set_to_constant", "set_to"]:
             #Check that the children are fine:
             if not call.args:
                 self.__messages.append(CheckError("'%s.%s' requires exactly one argument." % (set_variable, func_name), call))
                 return state
             if len(call.args) > 1:
                 self.__messages.append(CheckError("'%s.%s' with more than one argument unsupported." % (set_variable, func_name), call))
             value = call.args[0]
             (_, _, val_used_vars) = self.visit(state, value)
             if set_variable in initialised:
                 self.__messages.append(CheckError("Trying to reset value of variable '%s'." % (set_variable), call))
             if set_variable in parameters:
                 self.__messages.append(CheckWarning("Setting value of parameter '%s'." % (set_variable), call))
             domain = self.__get_domain(call, set_variable)
             if isinstance(value, ast.Num):
                 # A literal only has to lie inside the variable's domain.
                 if value.n < 0 or value.n >= domain:
                     self.__messages.append(CheckError("Trying to set variable '%s' (domain [0..%i]) to invalid value '%i'." % (set_variable, domain - 1, value.n), call))
             else:
                 value_domain = self.__get_domain_of_expr(value)
                 if value_domain != domain:
                     self.__messages.append(CheckError("Trying to set variable '%s' (domain [0..%i]) to value '%s' with different domain [0..%i]." % (set_variable, domain - 1, astunparse.unparse(value).rstrip(), value_domain - 1), value))
             return (parameters, initialised | {set_variable}, val_used_vars)
         elif func_name in ["set_as_input"]:
             return (parameters, initialised | {set_variable}, used)
         elif func_name in ["set_as_output"]:
             return (parameters, initialised, used | {set_variable})
         elif func_name == "observe_value":
             #Check that the children are fine:
             if not call.args:
                 self.__messages.append(CheckError("'%s.%s' requires exactly one argument." % (set_variable, func_name), call))
                 return state
             if len(call.args) > 1:
                 self.__messages.append(CheckError("'%s.%s' with more than one argument unsupported." % (set_variable, func_name), call))
             (_, _, val_used_vars) = self.visit(state, call.args[0])
             if set_variable not in initialised:
                 self.__messages.append(CheckError("Observation of potentially uninitialised variable '%s'." % (set_variable), call))
             return (parameters, initialised, val_used_vars | {set_variable})
         else:
             self.__messages.append(CheckError("Unsupported call '%s'." % (astunparse.unparse(call).rstrip()), call))
             return state

     if not isinstance(call.func, ast.Name):
         # e.g. table[i](x) or f()(x): nothing to look up.
         self.__messages.append(CheckError("Unsupported call '%s'." % (astunparse.unparse(call).rstrip()), call))
         return state
     func_name = call.func.id
     func_information = self.__defined_functions.get(func_name, None)
     if func_information is None:
         self.__messages.append(CheckError("Call to undefined functions '%s'." % (func_name), call))
         return state
     (par_domains, _) = func_information
     used_vars = used
     if len(call.args) != len(par_domains):
         self.__messages.append(CheckError("Call to %i-ary function '%s' with %i arguments." % (len(par_domains), func_name, len(call.args)), call))
     for idx, arg in enumerate(call.args):
         (_, _, used_vars) = self.visit((parameters, initialised, used_vars), arg)
         if idx >= len(par_domains):
             # Surplus argument: already reported by the arity check above.
             continue
         par_domain = par_domains[idx]
         arg_domain = self.__get_domain_of_expr(arg)
         if arg_domain != par_domain:
             if isinstance(arg, ast.Num) and (arg.n >= par_domain or arg.n < 0):
                 self.__messages.append(CheckError("Parameter %i of function '%s' has domain [0..%i], but argument value '%s' is incompatible." % (idx + 1, func_name, par_domain - 1, astunparse.unparse(arg).rstrip()), arg))
             elif not(isinstance(arg, ast.Num)):
                 self.__messages.append(CheckError("Parameter %i of function '%s' has domain [0..%i], but argument value '%s' has different domain [0..%i]." % (idx + 1, func_name, par_domain - 1, astunparse.unparse(arg).rstrip(), arg_domain - 1), arg))
     return (parameters, initialised, used_vars)
コード例 #57
0
def test_lambda_copy_no_arg():
    """An argument-less lambda unparses to the same text but is a new object."""
    source_tree, result_tree = util_run_parse("lambda: 1+1")
    rendered = unparse(result_tree).strip()
    assert rendered == "(lambda : (1 + 1))"
    assert result_tree is not source_tree
コード例 #58
0
def test_lambda_copy_nested_same_arg_name():
    """Only the outer lambda's parameter is renamed; an inner lambda reusing the name keeps it."""
    _, renamed_tree = util_run_parse("lambda a: (lambda a: a)(a)")
    expected = "(lambda arg_0: (lambda a: a)(arg_0))"
    assert unparse(renamed_tree).strip() == expected
コード例 #59
0
def build_ignore_context_manager(ctx, stmt):
    """Replace the body of a ``with`` statement by a call to a generated ignored function.

    The keyword arguments of the first context-manager expression declare the
    block's inputs and outputs as ``name="inp:<type>"`` / ``name="out:<type>"``.

    A function ``func_ignore_<filename>_<lineno>`` is generated from the
    ``with`` body. It takes the inputs, returns the outputs, and is decorated
    with ``@torch.jit.ignore``. It is executed and stored in this module's
    globals.

    Args:
        ctx: Source context; only ``ctx.filename`` is used.
        stmt: The ``with`` statement node; ``items[0]``, ``body`` and
            ``lineno`` are read. It is not modified.

    Returns:
        A statement calling ``torch.jit.frontend.<func>(<inputs>)``. When there
        are outputs, this is an assignment ``<outputs> = ...``; otherwise it is
        a bare expression statement.
    """
    InputType = namedtuple('InputType', ['name', 'ann'])
    OutputType = namedtuple('OutputType', ['name', 'ann'])

    def process_ins_outs(args):
        """Split ``name="inp:T"`` / ``name="out:T"`` keywords into input and output lists."""
        # parse the context manager to figure out inputs and outputs
        # with their annotated types
        # TODO: add input, output validator
        inputs = []
        outputs = []
        for arg in args:
            var_name = arg.arg
            if sys.version_info < (3, 8):
                # Starting python3.8 ast.Str is deprecated
                var_ann = arg.value.s
            else:
                var_ann = arg.value.value
            var_decl_type, var_ann = var_ann.split(":")
            if var_decl_type == "inp":
                inputs.append(InputType(var_name, var_ann))
            if var_decl_type == "out":
                outputs.append(OutputType(var_name, var_ann))
        return inputs, outputs

    def create_unique_name_ext(ctx, stmt):
        """Return a suffix built from the sanitised filename and the statement's line number."""
        # extension will be based on the full path filename plus
        # the line number of original context manager
        fn = re.sub(r'[^a-zA-Z0-9_]', '_', ctx.filename)
        return f"{fn}_{stmt.lineno}"

    def build_return_ann_stmt(outputs):
        """Return the ``-> T`` annotation text and the ``return ...`` source for *outputs*."""
        return_statement_str = "return "
        if len(outputs) == 0:
            return_type_ann = " -> None"
        elif len(outputs) == 1:
            return_type_ann = " -> " + outputs[0].ann
            return_statement_str += outputs[0].name
        else:
            return_type_ann = " -> Tuple[" + ", ".join([var.ann for var in outputs]) + "]"
            return_statement_str += ", ".join([var.name for var in outputs])
        return return_type_ann, return_statement_str

    def build_args(args):
        """Join the names of *args* with ', '."""
        return ", ".join([arg.name for arg in args])

    inputs, outputs = process_ins_outs(stmt.items[0].context_expr.keywords)

    # build the replacement function str with given inputs and outputs
    ignore_function_name = "func_ignore_" + create_unique_name_ext(ctx, stmt)
    ignore_function_str = "\ndef " + ignore_function_name
    ignore_function_str += "(" + ", ".join([var.name + " :" + var.ann for var in inputs]) + ")"

    return_ann, return_src = build_return_ann_stmt(outputs)
    ignore_function_str += return_ann + ": pass"

    # first create the functionDef object from just declaration
    ignore_function = ast.parse(ignore_function_str).body[0]

    # Move the body of the context manager into the dummy function. Copy the
    # list so that appending the return statement below does not also append
    # it to the caller's ``stmt.body``.
    ignore_function.body = list(stmt.body)  # type: ignore[attr-defined]

    # insert return statement to the function
    return_node = ast.parse(return_src).body[0]
    ignore_function.body.append(return_node)  # type: ignore[attr-defined]

    # registers the custom function in the global context
    ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
    ignore_func_str += "\nglobals()[\"{}\"] = {}".format(ignore_function_name, ignore_function_name)
    exec(ignore_func_str)  # noqa: P204

    # build the statements as:
    # <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)
    assign_str_lhs = build_args(outputs)
    # this function will be registered in torch.jit.frontend module by default
    assign_str_rhs = "torch.jit.frontend.{}(".format(ignore_function_name) + build_args(inputs) + ")"

    if len(outputs) > 0:
        assign_str = assign_str_lhs + " = " + assign_str_rhs
    else:
        assign_str = assign_str_rhs
    assign_ast = ast.parse(assign_str).body[0]
    return assign_ast
コード例 #60
0
def _numeric_literal_value(node):
    """Return the number held by a numeric literal AST node, or ``None`` if it is not one.

    Handles ``ast.Constant`` (what parsers produce from Python 3.8 on) and
    the legacy ``ast.Num`` nodes of older parsers. Booleans are not numbers here.
    """
    if isinstance(node, ast.Constant):
        value = node.value
        if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
            return value
        return None
    legacy_num = getattr(ast, "Num", None)  # removed from recent Python versions
    if legacy_num is not None and isinstance(node, legacy_num):
        return node.n
    return None


def slice_node_to_tuple_of_numbers(slice_node):
    """Return the constant indices of a subscript's slice as a tuple of numbers.

    *slice_node* is the ``slice`` of an ``ast.Subscript``. Before Python 3.9
    that is an ``ast.Index`` wrapper; from 3.9 on it is the index expression
    itself. Both forms are accepted.

    A tuple index (``a[1, 2]``) yields one number per element; any other index
    yields a 1-tuple. Each index must be a numeric literal; otherwise
    ``error.fatal_error`` is called with the offending expression (presumably
    aborting — confirm).

    Returns:
        A tuple of numbers, which can be iterated repeatedly and indexed.
    """
    index_cls = getattr(ast, "Index", None)
    if index_cls is not None and isinstance(slice_node, index_cls):
        expr = slice_node.value
    else:
        expr = slice_node

    if isinstance(expr, ast.Tuple):
        indices = list(expr.elts)
    else:
        indices = [expr]

    for index in indices:
        if _numeric_literal_value(index) is None:
            error.fatal_error("Trying to use non-constant value '%s' as array index." % (astunparse.unparse(index).rstrip()), index)

    # Convert to python numbers
    return tuple(_numeric_literal_value(index) for index in indices)