示例#1
0
 def test_TryFinally(self):
     code = 'try:pass\nfinally:pass'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[Try(body=[Pass()], handlers=[], orelse=[], "
             "finalbody=[Pass()])], type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#2
0
 def test_Bytes(self):
     code = 'b"0012"'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[Expr(value=Constant(value=b'0012', "
             "kind=None))], type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#3
0
 def test_walk(self):
     """``gast.walk`` visits every node of a small expression tree."""
     expression_tree = gast.parse('x + 1', mode='eval')
     expected = ("Expression(body=BinOp(left=Name(id='x', ctx=Load(), "
                 "annotation=None), op=Add(), right=Num(n=1)))")
     self.assertEqual(gast.dump(expression_tree), expected)
     # Expression, BinOp, Name, Load, Add and Num: six nodes in total.
     visited = list(gast.walk(expression_tree))
     self.assertEqual(len(visited), 6)
示例#4
0
 def test_TryExcept(self):
     code = 'try:pass\nexcept e:pass\nelse:pass'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[Try(body=[Pass()], handlers=[ExceptHandler("
             "type=Name(id='e', ctx=Load(), annotation=None, "
             "type_comment=None), name=None, body=[Pass()])]"
             ", orelse=[Pass()], finalbody=[])], type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#5
0
 def test_dump(self):
     """``gast.dump`` shows lambda parameters as ``Name`` nodes in ``Param`` ctx."""
     lambda_tree = gast.parse('lambda x: x', mode='eval')
     expected = ("Expression(body=Lambda(args=arguments(args=[Name(id='x', "
                 "ctx=Param(), annotation=None)], vararg=None, kwonlyargs=[], "
                 "kw_defaults=[], kwarg=None, defaults=[]), body=Name(id='x', "
                 "ctx=Load(), annotation=None)))")
     self.assertEqual(gast.dump(lambda_tree), expected)
示例#6
0
 def test_NamedExpr(self):
     code = '(x := 1) '
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[Expr(value=NamedExpr(target=Name(id='x',"
             " ctx=Store(), annotation=None, type_comment=None), "
             "value=Constant(value=1, kind=None)))], type_ignores="
             "[])")
     self.assertEqual(gast.dump(tree), norm)
示例#7
0
def test_preprocess_augassign():
    # Each case function has three statements: body[0] sets up the
    # variables, body[1] is the augmented assignment given to
    # preprocess_augassign, and body[2] is the plain assignment it is
    # expected to be rewritten into.
    def add_augassign_fn():
        x, y = 1, 2
        x += y
        x = x + y

    def sub_augassign_fn():
        x, y = 1, 2
        x -= y
        x = x - y

    def mul_augassign_fn():
        x, y = 1, 2
        x *= y
        x = x * y

    def div_augassign_fn():
        x, y = 1, 2
        x /= y
        x = x / y

    class C:
        x = 1

    def attribute_augassign_fn():
        y = 1, 2
        C.x /= y
        C.x = C.x / y

    augassign_fns = [
        add_augassign_fn,
        sub_augassign_fn,
        mul_augassign_fn,
        div_augassign_fn,
        attribute_augassign_fn,
    ]

    for case_fn in augassign_fns:
        case_body = parse_ast(case_fn).body[0].body
        augassign_node = case_body[1]
        expected_node = case_body[2]
        rewritten = preprocess_augassign(augassign_node)
        assert gast.dump(rewritten) == gast.dump(expected_node)
示例#8
0
 def test_With(self):
     code = 'with open("any"): pass'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[With(items=[withitem(context_expr=Call(func="
             "Name(id='open', ctx=Load(), annotation=None, "
             "type_comment=None), args=[Constant(value='any', "
             "kind=None)], keywords=[]), optional_vars=None)], body=["
             "Pass()], type_comment=None)], type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#9
0
 def test_TypeIgnore(self):
     code = 'def foo(): pass  # type: ignore[excuse]'
     tree = gast.parse(code, type_comments=True)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[FunctionDef(name='foo', args=arguments("
             "args=[], posonlyargs=[], vararg=None, kwonlyargs=[], "
             "kw_defaults=[], kwarg=None, defaults=[]), body=["
             "Pass()], decorator_list=[], returns=None, "
             "type_comment=None)], type_ignores="
             "[TypeIgnore(lineno=1, tag='[excuse]')])")
     self.assertEqual(gast.dump(tree), norm)
示例#10
0
 def dump_ast(mod, name):
     """Print the gast dump of ``mod`` and optionally render it to a PNG.

     When ``IMPORT_ASTMONKEY`` is true, a copy of the tree is annotated with
     parent/child links and written to ``<name>.png``. Otherwise a hint to
     install astmonkey is printed.

     Args:
         mod: the AST node to dump.
         name: output file stem; ``.png`` is appended.
     """
     print(gast.dump(mod))
     if IMPORT_ASTMONKEY:
         # Transform a private copy so the caller's tree is never mutated.
         # A single deepcopy is enough; copying twice only wasted work.
         annotated = transformers.ParentChildNodeTransformer().visit(deepcopy(mod))
         visitor = visitors.GraphNodeVisitor()
         visitor.visit(annotated)
         visitor.graph.write_png(name + '.png')
         print("\033[1;32;40mAST visualization saved as \033[94m%s.png\033[0m" % name)
     else:
         print("\033[93mInstall astmonkey for visualization.\033[0m")
示例#11
0
 def test_keyword_argument(self):
     code = 'def foo(**a): pass'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args=[], "
             "posonlyargs=[], vararg=None, kwonlyargs=[], kw_defaults=[], "
             "kwarg=Name(id='a', ctx=Param(), annotation=None, "
             "type_comment=None), defaults=[]), body=[Pass()], "
             "decorator_list=[], returns=None, type_comment=None)], "
             "type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#12
0
 def test_KeywordOnlyArgument(self):
     code = 'def foo(*, x=1): pass'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args="
             "[], posonlyargs=[], vararg=None, kwonlyargs=[Name"
             "(id='x', ctx=Param(), annotation=None, type_comment=None"
             ")], kw_defaults=[Constant(value=1, kind=None)], kwarg="
             "None, defaults=[]), body=[Pass()], decorator_list=[], "
             "returns=None, type_comment=None)], type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#13
0
 def test_FormattedValue(self):
     code = 'e = 1; f"{e}"'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[Assign(targets=[Name(id='e', ctx=Store()"
             ", annotation=None, type_comment=None"
             ")], value=Constant(value=1, kind=None)), Expr(value="
             "JoinedStr(values=[FormattedValue(value=Name(id='e', "
             "ctx=Load(), annotation=None, type_comment=None), "
             "conversion=-1, format_spec=None)]))], "
             "type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#14
0
def eval_ast(nast, env):
    """Evaluate ``nast`` in ``env`` and return the unwrapped result.

    Before evaluating, checks that no variable in ``env`` is still an
    ``onnx.ValueInfoProto``. When ``nast`` is a single node (not a list), a
    debug line prefixed by ``_eval_ast_depth`` dashes is emitted.

    Args:
        nast: a gast node or a list of nodes, passed to ``eval_ast_impl``.
        env: evaluation environment exposing ``get_var_dict()``.

    Returns:
        ``_value(...)`` applied to the result of ``eval_ast_impl``.
    """
    global _eval_ast_depth

    for k, v in env.get_var_dict().items():
        assert not isinstance(v, onnx.ValueInfoProto), '%s %s' % (k, v)

    if not isinstance(nast, list):
        dprint('-' * _eval_ast_depth, gast.dump(nast), env.get_var_dict().keys())

    _eval_ast_depth += 1
    try:
        r = eval_ast_impl(nast, env)
    finally:
        # Restore the nesting depth even if evaluation raises. Otherwise the
        # global counter stays incremented and later debug output is
        # permanently mis-indented.
        _eval_ast_depth -= 1
    return _value(r)
示例#15
0
 def test_Index(self):
     code = 'def foo(a): a[1]'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args=["
             "Name(id='a', ctx=Param(), annotation=None, type_comment=None)"
             "], posonlyargs=[], vararg=None, kwonlyargs=[], kw_defaults=[]"
             ", kwarg=None, defaults=[]), body=[Expr(value=Subscript(value="
             "Name(id='a', ctx=Load(), annotation=None, type_comment=None)"
             ", slice=Index(value=Constant(value=1, kind=None)), ctx=Load()"
             "))], decorator_list=[], returns=None, type_comment=None)]"
             ", type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#16
0
    def infer_stmt(self, node):
        """Infer the type of statement ``node``, record it, and return it.

        Dispatches on the gast statement class and stores the result in
        ``self.nodetype[node]``. Assign, AugAssign, For, Delete, Assert,
        Import/ImportFrom, Expr and Pass are recorded as ``TyNone()``.
        Return, If and Try record the type produced by their sub-inference.

        Args:
            node: a gast statement node.

        Returns:
            The type recorded in ``self.nodetype[node]``.

        Raises:
            AssertionError: if no branch recorded a type for ``node``. This
                covers unhandled statement kinds and also ``gast.While``,
                whose branch is currently a no-op.
        """
        if self.is_debug:
            debug(gast.dump(node))

        if isinstance(node, gast.FunctionDef):
            self.nodetype[node] = self.infer_FunctionDef(node)
        elif isinstance(node, gast.Return):
            # Return(expr? value)
            # A bare `return` is typed as None; otherwise the statement takes
            # the type of the returned expression.
            if node.value is None:
                self.nodetype[node] = TyNone()
            else:
                self.nodetype[node] = self.infer_expr(node.value)
        elif isinstance(node, gast.Delete):
            # TODO(momohatt): erase from tyenv, etc.
            # TODO(momohatt): support deletion of element from list
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.Assign):
            self.infer_Assign(node)
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.AugAssign):
            self.infer_AugAssign(node)
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.For):
            self.infer_For(node)
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.While):
            # While(expr test, stmt* body, stmt* orelse)
            # NOTE(review): nothing is recorded for While, so the assertion
            # at the end of this method fails for every While statement.
            pass
        elif isinstance(node, gast.If):
            self.nodetype[node] = self.infer_If(node)
        elif isinstance(node, gast.Raise):
            # A fresh type variable, presumably so a raising branch can
            # unify with any other branch type -- confirm.
            self.nodetype[node] = TyVar()
        elif isinstance(node, gast.Try):
            # TODO(momohatt): What is 'finalbody' ?
            # finalbody is the list of statements in the `finally:` clause.
            # Only body and orelse are inferred here; handlers and finalbody
            # are not visited.
            # NOTE(review): `self` is passed twice on top of the bound
            # receiver; presumably these are the inference contexts for the
            # two blocks -- verify against infer_2blocks' signature.
            ty_ret = self.infer_2blocks(self, self, node.body, node.orelse)
            self.nodetype[node] = ty_ret
        elif isinstance(node, gast.Assert):
            self.nodetype[node] = TyNone()
        elif isinstance(node, (gast.Import, gast.ImportFrom)):
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.Expr):
            # Expr(expr value)
            # The expression is inferred for its side effects on the
            # environment; the statement itself has no value.
            self.infer_expr(node.value)
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.Pass):
            self.nodetype[node] = TyNone()

        # Guards against statement kinds that no branch above handled.
        assert node in self.nodetype.keys(), type(node).__name__
        return self.nodetype[node]
示例#17
0
 def test_ExtSlices(self):
     self.maxDiff = None
     code = 'def foo(a): a[1,:]'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args=["
             "Name(id='a', ctx=Param(), annotation=None, type_comment=None)"
             "], posonlyargs=[], vararg=None, kwonlyargs=[], kw_defaults=[]"
             ", kwarg=None, defaults=[]), body=[Expr(value=Subscript(value="
             "Name(id='a', ctx=Load(), annotation=None, type_comment=None)"
             ", slice=Tuple(elts=[Constant(value=1, kind="
             "None), Slice(lower=None, upper=None, step=None)], ctx=Load())"
             ", ctx=Load()))], decorator_list=[], returns=None, "
             "type_comment=None)], type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#18
0
 def test_Call(self):
     self.maxDiff = None
     code = 'foo(x, y=1, *args, **kwargs)'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[Expr(value=Call(func=Name(id='foo', ctx=Load"
             "(), annotation=None, type_comment=None"
             "), args=[Name(id='x', ctx=Load(), "
             "annotation=None, type_comment=None), Starred(value=Name("
             "id='args', ctx=Load(), annotation=None, type_comment=None)"
             ", ctx=Load())], keywords=[keyword("
             "arg='y', value=Constant(value=1, kind=None)), keyword(arg"
             "=None, value=Name(id='kwargs', ctx=Load(), annotation=None, "
             "type_comment=None))]))], type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#19
0
 def test_FunctionDef(self):
     code = 'def foo((x, y)): return x, y'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args="
             "[Tuple(elts=[Name(id='x', ctx=Store(), annotation=None, "
             "type_comment=None), Name(id='y', ctx=Store(), "
             "annotation=None, type_comment=None)], ctx=Store())], "
             "posonlyargs=[], vararg=None, "
             "kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), "
             "body=[Return(value=Tuple(elts=[Name(id='x', ctx=Load(), "
             "annotation=None, type_comment=None), "
             "Name(id='y', ctx=Load(), "
             "annotation=None, type_comment=None"
             ")], ctx=Load()))], decorator_list="
             "[], returns=None, type_comment=None)], type_ignores=[])")
     self.assertEqual(gast.dump(tree), norm)
示例#20
0
    def infer_stmt(self, node):
        """Infer the type of statement ``node``, record it, and return it.

        ``node`` is pushed on ``self.stack`` for the duration of inference
        and popped before returning. Apart from Return, all handled statement
        kinds (including If, whose sub-inference result is discarded) are
        recorded as ``TyNone()``.

        Args:
            node: a gast statement node.

        Returns:
            The type recorded in ``self.nodetype[node]``.

        Raises:
            AssertionError: if no branch recorded a type for ``node``. This
                covers unhandled statement kinds and also ``gast.While``,
                whose branch is currently a no-op.
        """
        if self.is_debug:
            debug(gast.dump(node))

        # NOTE(review): the matching pop is not in a `finally`, so if
        # inference raises, `node` stays on the stack. This may be
        # intentional for error reporting -- confirm before changing.
        self.stack.append(node)

        if isinstance(node, gast.FunctionDef):
            self.nodetype[node] = self.infer_FunctionDef(node)
        elif isinstance(node, gast.Return):
            # Return(expr? value)
            # A bare `return` is typed as None; otherwise the statement takes
            # the type of the returned expression.
            if node.value is None:
                self.nodetype[node] = TyNone()
            else:
                self.nodetype[node] = self.infer_expr(node.value)
        elif isinstance(node, gast.Delete):
            # TODO(momohatt): erase from tyenv, etc.
            # TODO(momohatt): support deletion of element from list
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.Assign):
            self.infer_Assign(node)
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.AugAssign):
            self.infer_AugAssign(node)
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.For):
            self.infer_For(node)
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.While):
            # While(expr test, stmt* body, stmt* orelse)
            # NOTE(review): nothing is recorded for While, so the assertion
            # below fails for every While statement.
            pass
        elif isinstance(node, gast.If):
            self.infer_If(node)
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.Expr):
            # Expr(expr value)
            # The expression is inferred for its effects; the statement
            # itself has no value.
            self.infer_expr(node.value)
            self.nodetype[node] = TyNone()
        elif isinstance(node, gast.Pass):
            self.nodetype[node] = TyNone()

        # Guards against statement kinds that no branch above handled.
        assert node in self.nodetype.keys(), type(node).__name__
        self.stack.pop()
        return self.nodetype[node]
示例#21
0
    def test_var_env(self):
        """Check the scope variable types inferred for each test function.

        ``result_var_type[i]`` is the expected mapping from variable name to
        type for ``test_funcs[i]``. Each function must produce exactly one
        sub scope, and that scope's variables must match the mapping exactly.
        """
        for i, func in enumerate(test_funcs):
            var_type = result_var_type[i]
            test_source_code = inspect.getsource(func)
            ast_root = gast.parse(test_source_code)
            print(gast.dump(ast_root))
            visitor = StaticAnalysisVisitor(ast_root)
            var_env = visitor.get_var_env()

            # There must be 1 sub scope for the test function
            self.assertEqual(1, len(var_env.cur_scope.sub_scopes))
            var_env.cur_scope = var_env.cur_scope.sub_scopes[0]

            scope_var_type = var_env.get_scope_var_type()
            print(scope_var_type)
            self.assertEqual(len(scope_var_type), len(var_type))
            for name in scope_var_type:
                print("Test var name %s" % (name))
                # assertIn reports the missing name and the expected keys on
                # failure, whereas assertTrue(name in ...) only says "False".
                self.assertIn(name, var_type)
                self.assertEqual(scope_var_type[name], var_type[name])
示例#22
0
    def test_Raise(self):
        codes = (
            'raise Exception',
            'raise "Exception"',
            'raise Exception, "err"',
            'raise Exception("err")',
            'raise E, V, T',
        )
        norms = (
            "Module(body=[Raise(exc=Name(id='Exception', ctx=Load(), "
            "annotation=None, type_comment=None),"
            " cause=None)], type_ignores=[])",
            "Module(body=[Raise(exc=Constant(value='Exception', kind="
            "None), cause=None)], type_ignores=[])",
            "Module(body=[Raise(exc=Call(func=Name(id='Exception', "
            "ctx=Load(), annotation=None, type_comment=None), "
            "args=[Constant(value='err', kind=None)], "
            "keywords=[]), cause=None)], type_ignores=[])",
            "Module(body=[Raise(exc=Call(func=Name(id='Exception', "
            "ctx=Load(), annotation=None, type_comment=None), "
            "args=[Constant(value='err', kind=None)], "
            "keywords=[]), cause=None)], type_ignores=[])",
            "Module(body=[Raise(exc=Call(func=Attribute(value=Call("
            "func=Name(id='E', ctx=Load(), annotation=None, "
            "type_comment=None), args=[Name(id='V', ctx="
            "Load(), annotation=None, type_comment=None)], keywords=[]), "
            "attr='with_traceback', ctx=Load"
            "()), args=[Name(id='T', ctx=Load(), annotation=None, "
            "type_comment=None)], keywords=[]), "
            "cause=None)], type_ignores=[])",
        )

        if sys.version_info.major == 3:
            codes = codes[0], codes[1], codes[3]
            norms = norms[0], norms[1], norms[3]

        for code, norm in zip(codes, norms):
            tree = gast.parse(code)
            compile(gast.gast_to_ast(tree), '<test>', 'exec')
            self.assertEqual(gast.dump(tree), norm)
示例#23
0
    def get_function_instance(self, node):
        """Resolve the Python object named by the call target ``node``.

        Args:
            node: a ``gast.Attribute`` or ``gast.Name`` naming the callee.

        Returns:
            A pair ``(func, receiver_type)``. ``func`` is the resolved object
            and may be ``None`` when an attribute lookup on a list, tensor or
            user-defined class fails, or when the receiver's type is not
            handled. ``receiver_type`` is the inferred receiver type for list
            and tensor methods, and ``None`` otherwise.

        Raises:
            AttributeError: if an imported module lacks the requested
                attribute (that lookup has no default).
            AssertionError: if ``node`` cannot be resolved at all.
        """
        import builtins

        if isinstance(node, gast.Attribute):
            if isinstance(node.value, gast.Name) and \
                    hasattr(self.module, node.value.id):
                # function of imported libraries (eg. np, chainer, F, L)
                module = getattr(self.module, node.value.id)
                return getattr(module, node.attr), None

            ty_obj = self.infer_expr(node.value).deref()

            if isinstance(ty_obj, TyList):
                return getattr(list, node.attr, None), ty_obj

            if isinstance(ty_obj, TyTensor):
                if ty_obj.is_ndarray():
                    return getattr(np.ndarray, node.attr, None), ty_obj
                if ty_obj.is_torch_tensor():
                    return getattr(torch.Tensor, node.attr, None), ty_obj

            if isinstance(ty_obj, TyUserDefinedClass):
                # if there is no such attribute, just return None (undefined)
                return getattr(ty_obj.instance, node.attr, None), None

            return None, None

        if isinstance(node, gast.Name):
            if node.id in self.tyenv.keys():
                ty = self.tyenv[node.id].deref()
                if isinstance(ty, TyUserDefinedClass):
                    return ty.instance, None

            # `__builtins__` is a dict only in imported modules; in __main__
            # it is the builtins module itself and has no `.keys()`.
            # vars(builtins) is the same namespace dict in both cases.
            builtin_namespace = vars(builtins)
            if node.id in builtin_namespace:
                return builtin_namespace[node.id], None

            if hasattr(self.module, node.id):
                return getattr(self.module, node.id), None

        # Explicit raise rather than `assert False`: under `python -O` the
        # assert is stripped and the method would silently return None,
        # which callers then fail to unpack.
        raise AssertionError(gast.dump(node))
示例#24
0
 def dump(node):
     """Return ``ast.dump(node)`` pretty-printed with two-space indentation."""
     two_spaces = " " * 2
     return ast.dump(node, indent=two_spaces)
示例#25
0
  def test_buildable(self, template):
    """Test that each template can be built when given acceptable arguments.

    Fills a permissive hole with ``template`` and builds it from dummy hole
    values. Then checks that the result has the AST kind implied by
    ``template.fills_type``, that the cost bookkeeping matches the number of
    AST nodes, and that rebuilding with the same RNG seed is deterministic.
    """
    hole_type = python_numbers_control_flow.ASTHoleType
    rng = np.random.RandomState(1234)

    # Name "a" in scope, inside both a function and a loop, at op_depth 0:
    # a hole this template can always fill.
    hole = top_down_refinement.Hole(
        template.fills_type,
        python_numbers_control_flow.ASTHoleMetadata(
            names_in_scope=frozenset({"a"}),
            inside_function=True,
            inside_loop=True,
            op_depth=0))
    self.assertTrue(template.can_fill(hole))

    # Building must not raise for any combination of dummy hole values.
    filler = template.fill(hole, rng)
    dummy_factories = {
        hole_type.NUMBER: lambda: gast.Constant(value=1, kind=None),
        hole_type.BOOL: lambda: gast.Constant(value=True, kind=None),
        hole_type.STMT: gast.Pass,
        hole_type.STMTS: lambda: [],
        hole_type.STMTS_NONEMPTY: lambda: [gast.Pass()],
        hole_type.BLOCK: lambda: [gast.Pass()],
    }
    hole_values = [dummy_factories[h.hole_type]() for h in filler.holes]
    value = filler.build(*hole_values)

    # The built value must have the AST kind the template claims to fill.
    fills = template.fills_type
    if fills in (hole_type.STMTS_NONEMPTY, hole_type.BLOCK):
      self.assertTrue(value)
      for stmt in value:
        self.assertIsInstance(stmt, gast.stmt)
    elif fills == hole_type.STMTS:
      for stmt in value:
        self.assertIsInstance(stmt, gast.stmt)
    elif fills == hole_type.STMT:
      self.assertIsInstance(value, gast.stmt)
    elif fills in (hole_type.NUMBER, hole_type.BOOL):
      self.assertIsInstance(value, gast.expr)
    else:
      raise NotImplementedError(f"Unexpected fill type {template.fills_type}; "
                                "please update this test.")

    # The required cost must equal the total number of AST nodes built.
    roots = [value] if isinstance(value, gast.AST) else value
    total_cost = sum(1 for root in roots for _ in gast.walk(root))
    self.assertEqual(template.required_cost, total_cost)

    hole_cost = sum(python_numbers_control_flow.ALL_COSTS[h.hole_type]
                    for h in filler.holes)
    self.assertEqual(filler.cost, total_cost - hole_cost)

    # Rebuilding from a freshly seeded RNG must reproduce the same AST.
    for _ in range(20):
      fresh_rng = np.random.RandomState(1234)
      redo_value = template.fill(hole, fresh_rng).build(*hole_values)
      if isinstance(value, list):
        self.assertEqual([gast.dump(v) for v in value],
                         [gast.dump(v) for v in redo_value])
      else:
        self.assertEqual(gast.dump(value), gast.dump(redo_value))
示例#26
0
def test_analyze_block():
    def ternary_fn():
        x = 1 if True else False

    def list_comp_fn():
        x = [i for i in [1, 2, 3]]

    def dict_comp_fn():
        x = {i: j for i, j in [[1, 2], [2, 3], [3, 4]]}

    def lambda_fn():
        x = lambda y: y + 2

    def func_fn():
        def fn(y):
            return y + 2

    def ifelse_fn():
        if True:
            x = 1
        else:
            x = 3

    def for_fn():
        for i in [1, 2, 3]:
            x = i
        else:
            y = i

    def while_fn():
        while True:
            x = 2

    def with_fn():
        with 1 as x:
            y = x

    def try_fn():
        try:
            x = 1
        finally:
            z = x == 1

    # (case function, whether its first statement is a code block)
    block_fns = [
        (ternary_fn, False),
        (list_comp_fn, False),
        (dict_comp_fn, False),
        (lambda_fn, False),
        (func_fn, True),
        (ifelse_fn, True),
        (for_fn, True),
        (while_fn, True),
        (with_fn, True),
        (try_fn, True),
    ]

    for case_fn, expect_block in block_fns:
        first_stmt = analyze_block(parse_ast(case_fn)).body[0].body[0]
        # Code blocks must be labeled as such, and non-blocks must not be.
        assert first_stmt.is_block == expect_block
        if not expect_block:
            continue
        # Every descendant of a code block must point back to it through
        # `.block`. gast.walk also yields the root itself, which is skipped.
        for descendant in gast.walk(first_stmt):
            if descendant == first_stmt:
                continue
            if descendant.block != first_stmt:
                __import__("pprint").pprint((gast.dump(descendant.block)))
                __import__("pprint").pprint((gast.dump(first_stmt)))
            assert descendant.block == first_stmt
示例#27
0
 def test_keyword_argument(self):
     code = 'def foo(**a): pass'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     gast.dump(tree, include_attributes=True)
示例#28
0
 def test_keyword_argument(self):
     code = 'def foo(**a): pass'
     tree = gast.parse(code)
     compile(gast.gast_to_ast(tree), '<test>', 'exec')
     gast.dump(tree, include_attributes=True)
示例#29
0
def test_symbol_resolution():
    def simple_fn():
        simple = 2
        simple

    def multi_assign():
        multi_a, multi_b = True, False
        multi_a, multi_b

    def repeated_assign():
        repeated = "first"
        repeated = "second"
        repeated

    def scoped_assign():
        scoped = False

        def fn():
            scoped = True

        # should reference the first definition of 'scoped' as the second definition
        # is scoped to only within function
        scoped

    def aug_assign():
        aug = 0
        aug = aug + 1
        aug

    class Qualified:
        a = 1

    def qualified_assign():
        Qualified.a = 1
        Qualified.a

    # Each case: (function, index of the statement holding the latest
    # definition, indices of every statement that defines the symbol).
    # Indices are positions in the function body; the last statement of each
    # function is the reference being resolved.
    symbol_fns = [
        (simple_fn, 0, [0]),
        (multi_assign, 0, [0]),
        (repeated_assign, 1, [0, 1]),
        (scoped_assign, 0, [0]),
        (aug_assign, 1, [0, 1]),
        (qualified_assign, 0, [0]),
    ]

    for symbol_fn, n_latest_def_line, n_def_lines in symbol_fns:
        tree = parse_ast(symbol_fn)
        for analyzer in (analyze_symbol, analyze_assign):
            tree = analyzer(tree)
        fn_ast = resolve_symbol(tree).body[0]

        latest_values = fn_ast.body[n_latest_def_line].values
        ref_expr = fn_ast.body[-1].value
        refs = ref_expr.elts if isinstance(ref_expr, Tuple) else [ref_expr]

        # `.definition` must point at the value from the latest assignment.
        for expected_def, ref in zip(latest_values, refs):
            if ref.definition != expected_def:
                print(gast.dump(ref))
                print(gast.dump(ref.definition))
            assert ref.definition == expected_def

        # Every defining statement's value must appear exactly once in
        # `.definitions`.
        def_stmts = [fn_ast.body[idx] for idx in n_def_lines]
        for def_stmt in def_stmts:
            for candidate, ref in zip(def_stmt.values, refs):
                if ref.definitions.count(candidate) != 1:
                    print(gast.dump(ref))
                    print([gast.dump(d) for d in ref.definitions])
                assert ref.definitions.count(candidate) == 1
示例#30
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed it creates TF a graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    NotImplementedError: if `o` is an object instance (not a class, function
        or method).
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Plain functions and methods go through the same converter.
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'contrib/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. it should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
      entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(node))
  if logging.has_verbosity(4):
    for n in node:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n', o, gast.dump(n))

  if program_ctx.options.recursive:
    # Candidates deliberately not converted on their own. Skipping a
    # candidate never adds it to dependency_cache, so without this set the
    # loop would select the same candidate again forever.
    skipped = set()
    while True:
      candidate = next(
          (obj for obj in program_ctx.name_map
           if obj not in program_ctx.dependency_cache and obj not in skipped),
          None)
      if candidate is None:
        break
      if (hasattr(candidate, 'im_class') and
          getattr(candidate, 'im_class') not in program_ctx.partial_types):
        # Class members are converted with their objects, unless they're
        # only converted partially.
        skipped.add(candidate)
        continue
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns