def _wrap_to_py_func_no_return(self, node):
  """Wraps a call into a `tf.py_func` whose result is a dummy constant.

  A fresh wrapper function is generated that forwards its parameters to
  `node.func` and returns 1. The wrapper is then invoked through `tf.py_func`
  with the original call arguments and a `[tf.int64]` output signature.

  Args:
    node: the call node. `node.func` must carry an `anno.Basic.QN` annotation
      and `node` must carry a `NodeAnno.ARGS_SCOPE` annotation.

  Returns:
    A `(wrapper_def, call_expr)` pair, as unpacked from `templates.replace`.
    `wrapper_def` is annotated with `anno.Basic.SKIP_PROCESSING`.
  """
  namer = self.context.namer
  callee_qn = anno.getanno(node.func, anno.Basic.QN)
  args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  wrapper_name = namer.new_symbol(callee_qn.ssf(), args_scope.referenced)

  # One wrapper parameter per positional argument. It is named after the
  # argument's symbol when it has one, otherwise after the stem 'arg'.
  wrapper_args = []
  for call_arg in node.args:
    if anno.hasanno(call_arg, anno.Basic.QN):
      stem_qn = anno.getanno(call_arg, anno.Basic.QN)
    else:
      stem_qn = qual_names.QN('arg')
    wrapper_args.append(namer.new_symbol(stem_qn.ssf(), args_scope.referenced))

  # TODO(mdan): Properly handle varargs, kwargs, etc.
  # TODO(mdan): This is best handled as a dynamic dispatch.
  # That way we can separate tensors from non-tensor args.
  template = """
    def wrapper(wrapper_args):
      call(wrapper_args)
      return 1
    tf.py_func(wrapper, original_args, [tf.int64])
  """
  original_args = gast.List(elts=node.args, ctx=None)
  wrapper_def, call_expr = templates.replace(
      template,
      call=node.func,
      wrapper=wrapper_name,
      original_args=original_args,
      wrapper_args=wrapper_args)
  anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
  return wrapper_def, call_expr
def _rename_member_function_of_known_type(self, node):
  """Renames a call to a method of an object whose type is statically known.

  `node.func` must be a `gast.Attribute` with `type` and `type_fqn`
  annotations. If `self._should_compile` accepts the call and the namer
  returns a name different from the attribute name, the call `obj.meth(...)`
  is rewritten to `new_name(obj, ...)`.

  Returns:
    The call node, rewritten in place when renamed.
  """
  assert isinstance(node.func, gast.Attribute)
  owner_fqn = anno.getanno(node.func, 'type_fqn')
  assert anno.hasanno(node.func, 'type')
  owner_type = anno.getanno(node.func, 'type')

  if not self._should_compile(node, owner_fqn):
    return node

  # TODO(mdan): We should not assume that the namer only needs the
  # member function name.
  attr_name = node.func.attr
  bound_member = getattr(owner_type, attr_name)
  renamed = self.namer.compiled_function_name(
      attr_name, live_object=bound_member, owner_type=owner_type)
  if renamed == node.func.attr:
    return node

  # The renamed function is no longer bound to the receiver, so the call is
  # refactored from
  #   foo.bar(...)
  # to
  #   renamed_foo(bar, ...)
  # TODO(mdan): This risks causing duplication, if target_type is renamed.
  node.args = [node.func.value] + node.args
  node.func = gast.Name(renamed, gast.Load(), None)
  return node
def test_if(self):
  """Checks the body/orelse scopes of an `if` and their shared parent."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  node = self._parse_and_analyze(test_fn)
  if_node = node.body[0].body[0]
  parent_expected = (('x', 'z', 'u'), ('x', 'y', 'z', 'u'),
                     ('x', 'y', 'z', 'u'))

  body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
  self.assertScopeIs(body_scope, ('x', 'y'), ('x', 'y', 'z'), ('y', 'z'))
  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  self.assertScopeIs(body_scope.parent, *parent_expected)

  orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
  self.assertScopeIs(orelse_scope, ('x', 'y'), ('x', 'y', 'u'), ('y', 'u'))
  self.assertScopeIs(orelse_scope.parent, *parent_expected)
def test_if(self):
  """Checks the body/orelse scopes of an `if` and each branch's parent scope."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  node = parser.parse_object(test_fn)
  node = access.resolve(node)

  if_node = node.body[0].body[0]
  self.assertScopeIs(
      anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
      ('y', 'z'))
  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  self.assertScopeIs(
      anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'),
      ('y', 'u'))
  # The else branch's parent scope gets its own check, mirroring the body
  # branch's parent-scope assertion above.
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
def test_if(self):
  """Checks the body/orelse scopes of an `if` and each branch's parent scope."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  node = parser.parse_object(test_fn)
  node = access.resolve(node)

  if_node = node.body[0].body[0]
  parent_expected = (('x', 'z', 'u'), ('x', 'y', 'z', 'u'),
                     ('x', 'y', 'z', 'u'))
  self.assertScopeIs(anno.getanno(if_node, 'body_scope'), ('x', 'y'),
                     ('x', 'y', 'z'), ('y', 'z'))
  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  self.assertScopeIs(anno.getanno(if_node, 'body_parent_scope'),
                     *parent_expected)
  self.assertScopeIs(anno.getanno(if_node, 'orelse_scope'), ('x', 'y'),
                     ('x', 'y', 'u'), ('y', 'u'))
  # The else branch's parent scope gets its own check, mirroring the body
  # branch's parent-scope assertion above.
  self.assertScopeIs(anno.getanno(if_node, 'orelse_parent_scope'),
                     *parent_expected)
def visit_Call(self, node):
  """Marks decorator arguments as graph-ready and renames resolvable calls.

  Before the children are visited, the first argument of any call to one of
  `self.nocompile_decorators` is annotated as `graph_ready`. After visiting,
  calls to compilable live functions are renamed, and so are method calls
  on objects of known type (those with a `type_fqn` annotation).

  Returns:
    The call node, possibly rewritten.

  Raises:
    ValueError: if a nocompile decorator is called with no arguments.
    NotImplementedError: if the target is a live value that is not
      compilable, or if the target cannot be resolved at all.
  """
  # If the function is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_obj = anno.getanno(node.func, 'live_val')
    if target_obj in self.nocompile_decorators:
      if len(node.args) < 1:
        # The adjacent literals are joined first, so `%` formats the whole
        # message (previously the "%s" placeholder was left unfilled).
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least an argument.' % (target_obj,))
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_obj = anno.getanno(node.func, 'live_val')
    if self._function_is_compilable(target_obj):
      node = self._rename_compilable_function(node)
    else:
      raise NotImplementedError('py_func with return values')
  elif anno.hasanno(node.func, 'type_fqn'):
    node = self._rename_member_function_of_known_type(node)
  else:
    # node.func may be an Attribute, which has no `.id`. Describe it without
    # raising AttributeError, so the intended error is what surfaces.
    if isinstance(node.func, gast.Name):
      target_desc = node.func.id
    else:
      target_desc = getattr(node.func, 'attr', node.func)
    raise NotImplementedError(
        'Member function call (of unknown type): %s.' % (target_desc,))
  return node
def test_print_statement(self):
  """The arguments of a print call capture `a` and `b` and modify nothing."""

  def test_fn(a):
    b = 0
    c = 1
    print(a, b)
    return c

  node = parser.parse_object(test_fn)
  node = access.resolve(node)

  print_node = node.body[0].body[2]
  if not isinstance(print_node, gast.Print):
    # Python 3: the statement is an Expr wrapping the call, and the call
    # node is the one being annotated.
    assert isinstance(print_node, gast.Expr)
    print_node = print_node.value
  print_args_scope = anno.getanno(print_node, 'args_scope')

  # We basically need to detect which variables are captured by the call
  # arguments.
  self.assertItemsEqual(['a', 'b'], print_args_scope.used)
  self.assertItemsEqual([], print_args_scope.modified)
  self.assertItemsEqual([], print_args_scope.created)
def visit_Call(self, node):
  """Propagates the inferred type of a receiver object onto a method call.

  For calls whose target has no `live_val`, the target must be an attribute
  of a plain name (`obj.meth`). That name's traced value in `self.scope` must
  carry a `type` annotation, and its `type`/`type_fqn` are copied onto the
  call target.

  Raises:
    ValueError: if the target is not of the `name.attr` form, if the name has
      no traced value, or if the traced value has no `type` annotation.
  """
  target = node.func
  if not anno.hasanno(target, 'live_val'):
    if not isinstance(target, gast.Attribute):
      # Suspecting this pattern would reach here:
      #   foo = bar
      #   foo()
      raise ValueError('Dont know how to handle dynamic functions.')
    if not isinstance(target.value, gast.Name):
      # Possible example of this kind:
      #   foo = module.Foo()
      #   foo.bar.baz()
      # TODO(mdan): This should be doable by using the FQN.
      raise ValueError('Dont know how to handle object properties yet.')

    # In the example below, object_source is 'tf.train.Optimizer()':
    #   opt = tf.train.Optimizer()
    #   opt.foo()
    receiver_name = target.value.id
    if not self.scope.hasval(receiver_name):
      # TODO(mdan): Figure out what could the user do to get past this.
      raise ValueError('No info on "%s". Is it dynamically built?' %
                       receiver_name)
    object_source = self.scope.getval(receiver_name)
    if not anno.hasanno(object_source, 'type'):
      raise ValueError('Could not determine type of "%s". Is it dynamic?' %
                       receiver_name)
    anno.setanno(target, 'type', anno.getanno(object_source, 'type'))
    anno.setanno(target, 'type_fqn', anno.getanno(object_source, 'type_fqn'))

  self.generic_visit(node)
  return node
def visit_Assign(self, node):
  """Records assignment targets and tags constructor calls with their type.

  When the assigned value is a call whose target's `live_val` is a class, the
  call node is annotated with `type` (the class) and `type_fqn` (the target's
  `fqn`). Every target name is then bound in `self.scope`: plain targets to
  the assigned value, and tuple elements to a `gast.Subscript` of it.
  """
  self.generic_visit(node)
  rhs = node.value

  if isinstance(rhs, gast.Call) and anno.hasanno(rhs.func, 'live_val'):
    callee_obj = anno.getanno(rhs.func, 'live_val')
    if tf_inspect.isclass(callee_obj):
      # Calling a class: this is then a constructor.
      anno.setanno(rhs, 'type', callee_obj)
      anno.setanno(rhs, 'type_fqn', anno.getanno(rhs.func, 'fqn'))
      # TODO(mdan): Raise an error if constructor has side effects.
      # We can have a whitelist of no-side-effects constructors.
      # We can also step inside the constructor and further analyze.

  for lhs in node.targets:
    if not isinstance(lhs, gast.Tuple):
      self.scope.setval(lhs.id, rhs)
      continue
    for position, element in enumerate(lhs.elts):
      self.scope.setval(
          element.id,
          gast.Subscript(rhs, gast.Index(position), ctx=gast.Store()))
  return node
def _process_variable_assignment(self, source, targets):
  """Binds the QN of each of `targets` to `source` in `self.scope`.

  If `source` is a call whose target's `live_val` is a class, `source` is
  first annotated with `is_constructor`, `type` and `type_fqn`. Tuple targets
  bind each element to a `gast.Subscript` of `source`. `gast.Name` and
  `gast.Attribute` targets bind directly to `source`.

  Raises:
    ValueError: for any other kind of target node.
  """
  if isinstance(source, gast.Call) and anno.hasanno(source.func, 'live_val'):
    constructed = anno.getanno(source.func, 'live_val')
    if tf_inspect.isclass(constructed):
      anno.setanno(source, 'is_constructor', True)
      anno.setanno(source, 'type', constructed)
      anno.setanno(source, 'type_fqn', anno.getanno(source.func, 'fqn'))
      # TODO(mdan): Raise an error if constructor has side effects.
      # We can have a whitelist of no-side-effects constructors.
      # We can also step inside the constructor and further analyze.

  for target_node in targets:
    if isinstance(target_node, gast.Tuple):
      for position, element in enumerate(target_node.elts):
        element_qn = anno.getanno(element, anno.Basic.QN)
        self.scope.setval(
            element_qn,
            gast.Subscript(source, gast.Index(position), ctx=gast.Store()))
    elif isinstance(target_node, (gast.Name, gast.Attribute)):
      self.scope.setval(anno.getanno(target_node, anno.Basic.QN), source)
    else:
      raise ValueError('Dont know how to handle assignment to %s' % target_node)
def _rename_compilable_function(self, node):
  """Points a call at the compiled version of its (class or function) target.

  `node.func` must carry `live_val` and `fqn` annotations. Constructor calls
  (annotated `is_constructor`) are always renamed via `compiled_class_name`.
  Other calls are renamed only when `compiled_function_name` says so. When
  the target is a bound method, its receiver is prepended to the arguments,
  because the renamed callee is a plain function.

  Returns:
    The call node, modified in place when renamed.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  target_entity = anno.getanno(node.func, 'live_val')
  target_fqn = anno.getanno(node.func, 'fqn')

  if not self._should_compile(node, target_fqn):
    return node

  namer = self.context.namer
  if anno.hasanno(node, 'is_constructor'):
    new_name = namer.compiled_class_name(
        target_fqn, live_entity=target_entity)
    do_rename = True
  else:
    owner_type = self._determine_function_owner(target_entity)
    new_name, do_rename = namer.compiled_function_name(
        target_fqn, live_entity=target_entity, owner_type=owner_type)

  if not do_rename:
    return node

  if target_entity is not None and tf_inspect.ismethod(target_entity):
    # The renaming process will transform it into a regular function.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [node.func.value] + node.args
  node.func = gast.Name(new_name, gast.Load(), None)
  return node
def visit_Name(self, node):
  """Annotates loaded names with a `live_val` and a one-element `fqn`.

  Non-local, non-parameter names are resolved first from `self.literals`,
  then from `self.context.namespace`. Names not modified since entry that
  appear in `self.context.arg_values` get the argument's value, with the
  argument's class name as `fqn`.

  Raises:
    ValueError: if a non-local, non-parameter name cannot be resolved.
  """
  self.generic_visit(node)
  if not isinstance(node.ctx, gast.Load):
    return node

  assert anno.hasanno(node, 'is_local'), node
  is_local = anno.getanno(node, 'is_local')
  assert anno.hasanno(node, 'is_modified_since_entry'), node
  is_modified = anno.getanno(node, 'is_modified_since_entry')
  assert anno.hasanno(node, 'is_param'), node
  is_param = anno.getanno(node, 'is_param')

  name = node.id
  if not (is_local or is_param):
    if name in self.literals:
      anno.setanno(node, 'live_val', self.literals[name])
      # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
    elif name in self.context.namespace:
      resolved = self.context.namespace[name]
      anno.setanno(node, 'live_val', resolved)
      anno.setanno(node, 'fqn', (resolved.__name__,))
    else:
      raise ValueError('Could not resolve symbol "%s".' % name)
  # Otherwise (local or parameter):
  # TODO(mdan): Attempt to trace its value through the local chain.
  # TODO(mdan): Use type annotations as fallback.

  if not is_modified and name in self.context.arg_values:
    arg_value = self.context.arg_values[name]
    anno.setanno(node, 'live_val', arg_value)
    anno.setanno(node, 'fqn', (arg_value.__class__.__name__,))
  return node
def _rename_compilable_function(self, node):
  """Points a call at the compiled version of its (class or function) target.

  `node.func` must carry `live_val` and `fqn` annotations. Constructor calls
  (annotated `is_constructor`) are always renamed via `compiled_class_name`.
  Other calls are renamed only when `compiled_function_name` says so. The new
  callee expression is built with `templates.replace`. For bound-method
  targets, the receiver is prepended to the arguments.

  Returns:
    The call node, modified in place when renamed.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  target_entity = anno.getanno(node.func, 'live_val')
  target_fqn = anno.getanno(node.func, 'fqn')

  if not self._should_compile(node, target_fqn):
    return node

  namer = self.context.namer
  if anno.hasanno(node, 'is_constructor'):
    new_name = namer.compiled_class_name(
        target_fqn, live_entity=target_entity)
    do_rename = True
  else:
    owner_type = self._determine_function_owner(target_entity)
    new_name, do_rename = namer.compiled_function_name(
        target_fqn, live_entity=target_entity, owner_type=owner_type)

  if not do_rename:
    return node

  if target_entity is not None and tf_inspect.ismethod(target_entity):
    # The renaming process will transform it into a regular function.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [node.func.value] + node.args
  node.func = templates.replace('func_name', func_name=new_name)[0]
  return node
def visit_Call(self, node):
  """Marks decorator arguments as graph-ready and converts resolvable calls.

  Before the children are visited, the first argument of any call to one of
  `self.nocompile_decorators` is annotated as `graph_ready`. After visiting:
  compilable live targets are renamed, and targets whose `fqn` is in
  `KNOWN_NUMPY_FUNCTIONS` are wrapped via `_wrap_to_py_func_single_return`.
  Unresolved targets get a dynamic conversion in recursive mode and are left
  as-is otherwise.

  Returns:
    The call node, possibly rewritten.

  Raises:
    ValueError: if a nocompile decorator is called with no arguments.
    NotImplementedError: for live targets that are neither compilable nor
      known numpy functions.
  """
  # If the function is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.nocompile_decorators:
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least an argument.' % (target_entity,))
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    # Always bind target_fqn. Previously it was only assigned when the `fqn`
    # annotation existed, which raised UnboundLocalError in the elif below.
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    elif target_fqn is not None and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(node, target_fqn)
    else:
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  else:
    if self.context.recursive:
      node = self._insert_dynamic_conversion(node)
    else:
      # Unresolved functions are allowed in non-recursive mode.
      pass
  return node
def _process_variable_assignment(self, source, targets):
  """Binds each assignment target's name to `source` in `self.scope`.

  If `source` is a call whose target's `live_val` is a class, `source` is
  first annotated with `is_constructor`, `type` and `type_fqn`. Tuple targets
  bind each element's `id` to a `gast.Subscript` of `source`, and names bind
  their `id` to `source`. Attribute targets are accepted only on `self` and
  are not bound.

  Raises:
    ValueError: for attribute targets on anything other than `self`, and for
      any other kind of target node.
  """
  if isinstance(source, gast.Call) and anno.hasanno(source.func, 'live_val'):
    constructed = anno.getanno(source.func, 'live_val')
    if tf_inspect.isclass(constructed):
      anno.setanno(source, 'is_constructor', True)
      anno.setanno(source, 'type', constructed)
      anno.setanno(source, 'type_fqn', anno.getanno(source.func, 'fqn'))
      # TODO(mdan): Raise an error if constructor has side effects.
      # We can have a whitelist of no-side-effects constructors.
      # We can also step inside the constructor and further analyze.

  for target_node in targets:
    if isinstance(target_node, gast.Tuple):
      for position, element in enumerate(target_node.elts):
        self.scope.setval(
            element.id,
            gast.Subscript(source, gast.Index(position), ctx=gast.Store()))
    elif isinstance(target_node, gast.Name):
      self.scope.setval(target_node.id, source)
    elif isinstance(target_node, gast.Attribute):
      owner = target_node.value
      owner_is_self = isinstance(owner, gast.Name) and owner.id == 'self'
      if not owner_is_self:
        raise ValueError(
            'Dont know how to handle assignment to attributes of objects'
            ' other than "self": [%s].%s' % (owner, target_node.attr))
    else:
      raise ValueError('Dont know how to handle assignment to %s' % target_node)
def test_call_with_composite_names(self):
  """Composite names like `a.b` are tracked in call and branch scopes."""

  def foo(*_):
    pass

  def test_fn(a):
    foo(a.b, a.c)
    if a > 0:
      a.b = 2
    else:
      d = 2
      d.e = a.c
      f = d.e + 1
      a.c = f

  node = self._parse_and_analyze(test_fn)
  fn_body = node.body[0].body

  call_scope = anno.getanno(fn_body[0].value, NodeAnno.ARGS_SCOPE)
  self.assertScopeIs(call_scope, ('a', 'a.b', 'a.c'), (), ())

  if_node = fn_body[1]
  self.assertScopeIs(
      anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ())
  self.assertScopeIs(
      anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
      ('a', 'a.c', 'd', 'd.e', 'f'),
      ('a.c', 'd', 'd.e', 'f'),
      ('d', 'f'))
def visit_Name(self, node):
  """Annotates loaded names with a `live_val` and a one-element `fqn`.

  Non-local, non-parameter names are resolved from `self.literals` or
  `self.context.namespace`; names found in neither are left unannotated.
  Names not modified since entry that appear in `self.context.arg_values`
  get the argument's value, with the argument's class name as `fqn`.
  """
  self.generic_visit(node)
  if not isinstance(node.ctx, gast.Load):
    return node

  assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
  is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
  assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
  is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
  assert anno.hasanno(node, NodeAnno.IS_PARAM), node
  is_param = anno.getanno(node, NodeAnno.IS_PARAM)

  name = node.id
  if not (is_local or is_param):
    if name in self.literals:
      anno.setanno(node, 'live_val', self.literals[name])
      # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
    elif name in self.context.namespace:
      resolved = self.context.namespace[name]
      anno.setanno(node, 'live_val', resolved)
      anno.setanno(node, 'fqn', (resolved.__name__,))
    # Otherwise the name is left unannotated.
    # TODO(mdan): Should we raise an error here?
    # Can encounter this when:
    #  * a symbol truly lacks reference
    #  * a symbol is new, like the new name of a function we just renamed.
  # For locals and parameters:
  # TODO(mdan): Attempt to trace its value through the local chain.
  # TODO(mdan): Use type annotations as fallback.

  if not is_modified and name in self.context.arg_values:
    arg_value = self.context.arg_values[name]
    anno.setanno(node, 'live_val', arg_value)
    anno.setanno(node, 'fqn', (arg_value.__class__.__name__,))
  return node
def visit_Attribute(self, node):
  """Annotates an attribute with `parent_type`, `live_val` and `fqn`.

  If the receiver (`node.value`) has a `live_val`, the attribute is read from
  that object. Otherwise, if the receiver has a `type`, the attribute is read
  from the type, but only when the type has it.

  Raises:
    AttributeError: if the receiver's live value lacks the attribute.
  """
  self.generic_visit(node)
  receiver = node.value
  attr_name = node.attr

  if anno.hasanno(receiver, 'live_val'):
    assert anno.hasanno(receiver, 'fqn')
    receiver_obj = anno.getanno(receiver, 'live_val')
    if not hasattr(receiver_obj, attr_name):
      raise AttributeError('%s has no attribute %s' % (receiver_obj, attr_name))
    anno.setanno(node, 'parent_type', type(receiver_obj))
    anno.setanno(node, 'live_val', getattr(receiver_obj, attr_name))
    anno.setanno(node, 'fqn', anno.getanno(receiver, 'fqn') + (attr_name,))
  elif anno.hasanno(receiver, 'type'):
    # TODO(mdan): Investigate the role built-in annotations can play here.
    receiver_type = anno.getanno(receiver, 'type')
    # Static members (like methods) are found on the type. Dynamic members
    # (like function attributes) are not; those nodes stay unannotated and
    # downstream consumers decide what to do.
    if hasattr(receiver_type, attr_name):
      anno.setanno(node, 'parent_type', receiver_type)
      anno.setanno(node, 'live_val', getattr(receiver_type, attr_name))
      anno.setanno(node, 'fqn',
                   anno.getanno(receiver, 'type_fqn') + (attr_name,))
  elif isinstance(receiver, gast.Name):
    # All nonlocal symbols should be fully resolved.
    assert anno.hasanno(receiver, NodeAnno.IS_LOCAL), receiver
    # TODO(mdan): Figure out what to do when calling attribute on local object
    # Maybe just leave as-is?
  return node
def _rename_compilable_function(self, node):
  """Points a call at the compiled version of its (class or function) target.

  `node.func` must carry `live_val` and `fqn` annotations. Constructor calls
  (annotated `is_constructor`) are always renamed via `compiled_class_name`.
  For other calls, the owner type comes from the `parent_type` annotation
  when present, falling back to `inspect_utils.getmethodclass`. For
  bound-method targets, the receiver is prepended to the arguments.

  Returns:
    The call node, modified in place when renamed.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  target_entity = anno.getanno(node.func, 'live_val')
  target_fqn = anno.getanno(node.func, 'fqn')

  if not self._should_compile(node, target_fqn):
    return node

  namer = self.context.namer
  if anno.hasanno(node, 'is_constructor'):
    new_name = namer.compiled_class_name(
        target_fqn, live_entity=target_entity)
    do_rename = True
  else:
    if not anno.hasanno(node.func, 'parent_type'):
      # Fallback - not reliable.
      owner_type = inspect_utils.getmethodclass(target_entity)
    else:
      owner_type = anno.getanno(node.func, 'parent_type')
    new_name, do_rename = namer.compiled_function_name(
        target_fqn, live_entity=target_entity, owner_type=owner_type)

  if not do_rename:
    return node

  if target_entity is not None and tf_inspect.ismethod(target_entity):
    # The renaming process will transform it into a regular function.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [node.func.value] + node.args
  node.func = templates.replace('func_name', func_name=new_name)[0]
  return node
def visit_Call(self, node):
  """Marks decorator arguments as graph-ready and renames compilable calls.

  Before the children are visited, the first argument of any call to one of
  `self.nocompile_decorators` is annotated as `graph_ready`. After visiting,
  calls to compilable live targets are renamed.

  Returns:
    The call node, possibly rewritten.

  Raises:
    ValueError: if a nocompile decorator is called with no arguments.
    NotImplementedError: if a live target is not compilable, or if the
      target is unresolved while `self.context.recursive` is set.
  """
  # If the function is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.nocompile_decorators:
      if len(node.args) < 1:
        # The adjacent literals are joined first, so `%` formats the whole
        # message (previously the "%s" placeholder was left unfilled).
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least an argument.' % (target_entity,))
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    else:
      raise NotImplementedError('py_func with return values')
  else:
    if self.context.recursive:
      raise NotImplementedError('Could not resolve target function.')
    else:
      # TODO(mdan): Double check. Is this reachable code?
      pass
  return node
def test_if(self):
  """Checks the body/orelse scopes of an `if` and their shared parent."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  node = self._parse_and_analyze(test_fn)
  if_node = node.body[0].body[0]
  parent_expected = (('x', 'z', 'u'), ('x', 'y', 'z', 'u'),
                     ('x', 'y', 'z', 'u'))

  body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
  self.assertScopeIsRmc(body_scope, ('x', 'y'), ('x', 'y', 'z'), ('y', 'z'))
  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  self.assertScopeIsRmc(body_scope.parent, *parent_expected)

  orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
  self.assertScopeIsRmc(orelse_scope, ('x', 'y'), ('x', 'y', 'u'), ('y', 'u'))
  self.assertScopeIsRmc(orelse_scope.parent, *parent_expected)
def visit_Call(self, node):
  """Applies the element type declared through the type-annotation marker.

  For a call `self.context.type_annotation_func(symbol, element_type)`, the
  statically known `live_val` of the second argument is stored as the
  `element_type` annotation. It is stored on the call node and on the value
  bound to `symbol` in `self.scope`.

  Returns:
    The result of `self.generic_visit(node)`.

  Raises:
    ValueError: if the marker call does not have exactly two arguments, if
      its first argument has no QN, or if its second argument has no
      `live_val`.
  """
  if anno.hasanno(node.func, 'live_val'):
    # Symbols targeted by the "set_type" marker function are assigned the data
    # type that it specified.
    if (anno.getanno(node.func, 'live_val') is
        self.context.type_annotation_func):
      # Expecting the actual type to be the second argument.
      if len(node.args) != 2:
        raise ValueError('"%s" must have exactly two parameters'
                         % self.context.type_annotation_func)
      if not anno.hasanno(node.args[0], anno.Basic.QN):
        # Message typo fixed: "must by" -> "must be".
        raise ValueError('the first argument of "%s" must be a symbol'
                         % self.context.type_annotation_func)
      if not anno.hasanno(node.args[1], 'live_val'):
        raise ValueError(
            'the second argument of "%s" must be statically resolvable' %
            self.context.type_annotation_func)
      target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
      element_type = anno.getanno(node.args[1], 'live_val')
      # Find the definition of this symbol and annotate it with the given
      # data type. That in turn will cause future uses of the symbol
      # to receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(definition, 'element_type', element_type)
      # TODO(mdan): Should we update references between definition and here?
  return self.generic_visit(node)
def test_function_def(self):
  """A nested function gets its own body scope, linked to the outer scope."""

  def test_fn(a):

    def f(x):
      y = x * x
      return y

    b = a
    for i in a:
      c = b
      b -= f(i)
    return b, c

  node = self._parse_and_analyze(test_fn)
  fndef_node = node.body[0].body[0]
  inner_scope = anno.getanno(fndef_node, NodeAnno.BODY_SCOPE)

  self.assertScopeIsRmc(
      inner_scope.parent,
      ('b', 'i', 'f', 'c', 'a'),
      ('f', 'b', 'c', 'i'),
      ('f', 'a', 'b', 'c', 'i'))
  self.assertScopeIsRmc(inner_scope, ('x', 'y'), ('y',), ('x', 'y'))
def _wrap_to_py_func_no_return(self, node):
  """Wraps a call into a `tf.py_func` whose result is a dummy constant.

  The generated wrapper forwards its parameters to `node.func` and returns 1.
  It is invoked via `tf.py_func` with the original call arguments and a
  `[tf.int64]` output signature.

  Args:
    node: the call node. `node.func` must carry an `anno.Basic.QN` annotation
      and `node` must carry a `NodeAnno.ARGS_SCOPE` annotation.

  Returns:
    A `(wrapper_def, call_expr)` pair, as unpacked from `templates.replace`.
    `wrapper_def` is annotated with `anno.Basic.SKIP_PROCESSING`.
  """
  namer = self.context.namer
  callee_qn = anno.getanno(node.func, anno.Basic.QN)
  args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  wrapper_name = namer.new_symbol(callee_qn.ssf(), args_scope.referenced)

  def stem_for(call_arg):
    # Named symbols keep their own stem; other expressions become 'arg'.
    if anno.hasanno(call_arg, anno.Basic.QN):
      return anno.getanno(call_arg, anno.Basic.QN)
    return qual_names.QN('arg')

  wrapper_args = [
      namer.new_symbol(stem_for(a).ssf(), args_scope.referenced)
      for a in node.args
  ]

  # TODO(mdan): Properly handle varargs, kwargs, etc.
  # TODO(mdan): This is best handled as a dynamic dispatch.
  # That way we can separate tensors from non-tensor args.
  template = """
    def wrapper(wrapper_args):
      call(wrapper_args)
      return 1
    tf.py_func(wrapper, original_args, [tf.int64])
  """
  wrapper_def, call_expr = templates.replace(
      template,
      call=node.func,
      wrapper=wrapper_name,
      original_args=gast.List(elts=node.args, ctx=None),
      wrapper_args=wrapper_args)
  anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
  return (wrapper_def, call_expr)
def test_call_with_composite_names(self):
  """Composite names like `a.b` are tracked in call and branch scopes."""

  def foo(*_):
    pass

  def test_fn(a):
    foo(a.b, a.c)
    if a > 0:
      a.b = 2
    else:
      d = 2
      d.e = a.c
      f = d.e + 1
      a.c = f

  node = self._parse_and_analyze(test_fn)
  fn_body = node.body[0].body

  call_scope = anno.getanno(fn_body[0].value, NodeAnno.ARGS_SCOPE)
  self.assertScopeIsRmc(call_scope, ('a', 'a.b', 'a.c'), (), ())

  if_node = fn_body[1]
  self.assertScopeIsRmc(
      anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ())
  self.assertScopeIsRmc(
      anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
      ('a', 'a.c', 'd', 'd.e', 'f'),
      ('a.c', 'd', 'd.e', 'f'),
      ('d', 'f'))
def visit_Call(self, node):
  """Applies the element type declared through the type-annotation marker.

  For a call `self.context.type_annotation_func(symbol, element_type)`, the
  statically known `live_val` of the second argument is stored as the
  `element_type` annotation. It is stored on the call node and on the value
  bound to `symbol` in `self.scope`.

  Returns:
    The result of `self.generic_visit(node)`.

  Raises:
    ValueError: if the marker call does not have exactly two arguments, if
      its first argument has no QN, or if its second argument has no
      `live_val`.
  """
  if anno.hasanno(node.func, 'live_val'):
    # Symbols targeted by the "set_type" marker function are assigned the data
    # type that it specified.
    if (anno.getanno(node.func, 'live_val') is
        self.context.type_annotation_func):
      # Expecting the actual type to be the second argument.
      if len(node.args) != 2:
        raise ValueError('"%s" must have exactly two parameters'
                         % self.context.type_annotation_func)
      if not anno.hasanno(node.args[0], anno.Basic.QN):
        # Message typo fixed: "must by" -> "must be".
        raise ValueError(
            'the first argument of "%s" must be a symbol' %
            self.context.type_annotation_func)
      if not anno.hasanno(node.args[1], 'live_val'):
        raise ValueError(
            'the second argument of "%s" must be statically resolvable' %
            self.context.type_annotation_func)
      target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
      element_type = anno.getanno(node.args[1], 'live_val')
      # Find the definition of this symbol and annotate it with the given
      # data type. That in turn will cause future uses of the symbol
      # to receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(definition, 'element_type', element_type)
      # TODO(mdan): Should we update references between definition and here?
  return self.generic_visit(node)
def visit_Attribute(self, node):
  """Annotates an attribute node with its `live_val` and `fqn`.

  If the receiver (`node.value`) has a `live_val`, the attribute is read from
  that object. Otherwise, if the receiver has a `type`, the attribute is read
  from the type, but only when the type has it.

  Raises:
    AttributeError: if the receiver's live value lacks the attribute.
  """
  self.generic_visit(node)
  receiver = node.value
  attr_name = node.attr

  if anno.hasanno(receiver, 'live_val'):
    assert anno.hasanno(receiver, 'fqn')
    receiver_obj = anno.getanno(receiver, 'live_val')
    if not hasattr(receiver_obj, attr_name):
      raise AttributeError('%s has no attribute %s' % (receiver_obj, attr_name))
    anno.setanno(node, 'live_val', getattr(receiver_obj, attr_name))
    anno.setanno(node, 'fqn', anno.getanno(receiver, 'fqn') + (attr_name,))
  elif anno.hasanno(receiver, 'type'):
    # TODO(mdan): Investigate the role built-in annotations can play here.
    receiver_type = anno.getanno(receiver, 'type')
    # Static members (like methods) are found on the type. Dynamic members
    # (like function attributes) are not; those nodes stay unannotated and
    # downstream consumers decide what to do.
    if hasattr(receiver_type, attr_name):
      anno.setanno(node, 'live_val', getattr(receiver_type, attr_name))
      anno.setanno(node, 'fqn',
                   anno.getanno(receiver, 'type_fqn') + (attr_name,))
  elif isinstance(receiver, gast.Name):
    # All nonlocal symbols should be fully resolved.
    assert anno.hasanno(receiver, 'is_local'), receiver
    # TODO(mdan): Figure out what to do when calling attribute on local object
    # Maybe just leave as-is?
  return node
def _try_resolve_target(self, node):
  """Returns the statically known object behind `node`, or None.

  Uses the node's `live_val` annotation when present. Otherwise, for
  attributes annotated with a `type` (methods of objects of known type), the
  member is looked up on that type.
  """
  if anno.hasanno(node, 'live_val'):
    return anno.getanno(node, 'live_val')
  is_typed_member = (isinstance(node, gast.Attribute) and
                     anno.hasanno(node, 'type'))
  if not is_typed_member:
    return None
  owner_type = anno.getanno(node, 'type')
  return getattr(owner_type, node.attr)
def test_basic(self):
  """Annotations are absent until set, then readable by key."""
  node = ast.Name()
  key = 'foo'

  self.assertFalse(anno.hasanno(node, key))
  with self.assertRaises(AttributeError):
    anno.getanno(node, key)

  anno.setanno(node, key, 3)
  self.assertTrue(anno.hasanno(node, key))
  self.assertEqual(3, anno.getanno(node, key))
def test_attribute_names(self):
  """A module attribute resolves to its live object and dotted fqn."""

  def test_fn():
    return constant_op.constant(0)

  node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
  return_stmt = node.body[0].body[0]
  func_node = return_stmt.value.func

  self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val'))
  expected_fqn = (constant_op.__name__, 'constant')
  self.assertEquals(expected_fqn, anno.getanno(func_node, 'fqn'))
def test_constructor_detection(self):
  """Calling a class annotates the call with the class and its fqn."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    return opt

  node = self._parse_and_analyze(test_fn, {'training': training})
  assign_stmt = node.body[0].body[0]
  call_node = assign_stmt.value

  self.assertEquals(training.GradientDescentOptimizer,
                    anno.getanno(call_node, 'type'))
  expected_fqn = (training.__name__, 'GradientDescentOptimizer')
  self.assertEquals(expected_fqn, anno.getanno(call_node, 'type_fqn'))
def test_namespace(self):
  """A name from the namespace resolves to its object and one-element fqn."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = self._parse_and_analyze(test_fn, {'foo': foo})
  return_stmt = node.body[0].body[0]
  func_node = return_stmt.value.func

  self.assertEquals(foo, anno.getanno(func_node, 'live_val'))
  self.assertEquals(('foo',), anno.getanno(func_node, 'fqn'))
def _try_resolve_target(self, node):
  """Returns the statically known object behind `node`, or None.

  Uses the node's `live_val` annotation when present. Otherwise, for
  attributes annotated with a `type` (methods of objects of known type), the
  member is looked up on that type.

  Returns:
    The resolved object, or None when neither annotation applies.

  Raises:
    ValueError: if the annotated type has no attribute named `node.attr`.
  """
  if anno.hasanno(node, 'live_val'):
    return anno.getanno(node, 'live_val')
  if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
    owner_type = anno.getanno(node, 'type')
    if hasattr(owner_type, node.attr):
      return getattr(owner_type, node.attr)
    # Message typo fixed: "has not attribute" -> "has no attribute".
    raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                     (owner_type, node.attr))
  return None
def test_namespace(self):
  """A name from the namespace resolves to its object and one-element fqn."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = self._parse_and_analyze(test_fn, {'foo': foo})
  func_node = node.body[0].body[0].value.func

  resolved = anno.getanno(func_node, 'live_val')
  self.assertEquals(foo, resolved)
  fqn = anno.getanno(func_node, 'fqn')
  self.assertEquals(('foo',), fqn)
def _rename_compilable_function(self, node):
  """Replaces the callee with a name for its compiled version.

  `node.func` must carry `live_val` and `fqn` annotations. When
  `self._should_compile` accepts the fqn, `node.func` becomes a load of the
  name returned by `self.namer.compiled_function_name` for the dotted fqn.

  Returns:
    The call node, modified in place when renamed.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  target_obj = anno.getanno(node.func, 'live_val')
  target_fqn = anno.getanno(node.func, 'fqn')

  if self._should_compile(target_fqn):
    dotted_name = '.'.join(target_fqn)
    compiled_name = self.namer.compiled_function_name(
        dotted_name, live_object=target_obj)
    node.func = gast.Name(id=compiled_name, ctx=gast.Load(), annotation=None)
  return node
def test_attribute_names(self):
  """A module attribute resolves to its live object and dotted fqn."""

  def test_fn():
    return constant_op.constant(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'constant_op': constant_op}, {})

  func_node = node.body[0].body[0].value.func
  self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val'))
  expected_fqn = (constant_op.__name__, 'constant')
  self.assertEquals(expected_fqn, anno.getanno(func_node, 'fqn'))
def _rename_compilable_function(self, node):
  """Replaces the callee with a name for its compiled version.

  `node.func` must carry `live_val` and `fqn` annotations. The rename only
  happens when `self._should_compile` accepts the fqn. The new name comes from
  `self.namer.compiled_function_name` applied to the dot-joined fqn.

  Returns:
    The call node, modified in place when renamed.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  live_target = anno.getanno(node.func, 'live_val')
  fqn_parts = anno.getanno(node.func, 'fqn')
  if not self._should_compile(fqn_parts):
    return node

  compiled_name = self.namer.compiled_function_name(
      '.'.join(fqn_parts), live_object=live_target)
  node.func = gast.Name(id=compiled_name, ctx=gast.Load(), annotation=None)
  return node
def test_type_annotation(self):
  """`set_element_type` tags both the definition and the marked value."""

  class Foo(object):
    pass

  def test_fn():
    f = []
    f = utils.set_element_type(f, Foo)
    return f

  node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
  fn_body = node.body[0].body

  f_def = fn_body[0].value
  self.assertEqual(anno.getanno(f_def, 'element_type'), Foo)
  f_ref = fn_body[1].value
  self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
def test_while(self):
  """Checks the while body scope and its parent scope."""

  def test_fn(a):
    b = a
    while b > 0:
      c = b
      b -= 1
    return b, c

  node = self._parse_and_analyze(test_fn)
  while_node = node.body[0].body[1]

  body_scope = anno.getanno(while_node, 'body_scope')
  self.assertScopeIs(body_scope, ('b',), ('b', 'c'), ('c',))
  parent_scope = anno.getanno(while_node, 'body_parent_scope')
  self.assertScopeIs(parent_scope, ('a', 'b', 'c'), ('b', 'c'),
                     ('a', 'b', 'c'))
def test_namespace(self):
  """A name from the namespace resolves to its object and one-element fqn."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'foo': foo}, {})

  func_node = node.body[0].body[0].value.func
  self.assertEquals(foo, anno.getanno(func_node, 'live_val'))
  self.assertEquals(('foo',), anno.getanno(func_node, 'fqn'))
def visit_Attribute(self, node):
  """Annotates an attribute of a live value with its `live_val` and `fqn`.

  Only attributes whose receiver (`node.value`) has a `live_val` annotation
  are resolved. Other attributes are returned unannotated.

  Raises:
    AttributeError: if the receiver's live value lacks the attribute.
  """
  self.generic_visit(node)
  receiver = node.value
  if not anno.hasanno(receiver, 'live_val'):
    # TODO(mdan): Figure out what to do when calling attribute on local object.
    # Maybe just leave as-is?
    return node

  assert anno.hasanno(receiver, 'fqn')
  receiver_obj = anno.getanno(receiver, 'live_val')
  if not hasattr(receiver_obj, node.attr):
    raise AttributeError('%s has no attribute %s' % (receiver_obj, node.attr))
  anno.setanno(node, 'live_val', getattr(receiver_obj, node.attr))
  anno.setanno(node, 'fqn', anno.getanno(receiver, 'fqn') + (node.attr,))
  return node
def visit_Name(self, node):
  """Handles function parameters and propagates traced types to name loads.

  Parameter names are passed (by `id`) to `self._process_function_arg`. For
  loaded names bound in `self.scope`, the bound source's `type` and
  `type_fqn` annotations are copied onto the node when a `type` is present.
  """
  self.generic_visit(node)
  name = node.id
  ctx = node.ctx

  if isinstance(ctx, gast.Param):
    self._process_function_arg(name)
  elif isinstance(ctx, gast.Load) and self.scope.hasval(name):
    # E.g. after
    #   a = b
    # future references to `a` have traced_source = `b`.
    traced_source = self.scope.getval(name)
    if anno.hasanno(traced_source, 'type'):
      anno.setanno(node, 'type', anno.getanno(traced_source, 'type'))
      anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn'))
  return node
def visit_Name(self, node):
  """Handles function parameters and propagates traced types to name loads.

  The node's `anno.Basic.QN` is the scope key. Parameter QNs are passed to
  `self._process_function_arg`. For loaded QNs bound in `self.scope`, the
  bound source's `type` and `type_fqn` are copied onto the node when a
  `type` is present.
  """
  self.generic_visit(node)
  qn = anno.getanno(node, anno.Basic.QN)
  ctx = node.ctx

  if isinstance(ctx, gast.Param):
    self._process_function_arg(qn)
  elif isinstance(ctx, gast.Load) and self.scope.hasval(qn):
    # E.g. after
    #   a = b
    # future references to `a` have traced_source = `b`.
    traced_source = self.scope.getval(qn)
    if anno.hasanno(traced_source, 'type'):
      anno.setanno(node, 'type', anno.getanno(traced_source, 'type'))
      anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn'))
  return node
def test_namespace(self):
  """A name from the namespace resolves to its object and one-element fqn."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'foo': foo}, {})

  return_stmt = node.body[0].body[0]
  func_node = return_stmt.value.func
  resolved = anno.getanno(func_node, 'live_val')
  self.assertEquals(foo, resolved)
  self.assertEquals(('foo',), anno.getanno(func_node, 'fqn'))
def test_class_members_in_with_stmt(self):
  """The type from a `with ... as` constructor reaches method calls."""

  def test_fn(x):
    with session.Session() as sess:
      sess.run(x)

  node = self._parse_and_analyze(test_fn, {'session': session})
  with_node = node.body[0].body[0]

  constructor_call = with_node.items[0].context_expr
  self.assertEquals(session.Session, anno.getanno(constructor_call, 'type'))
  self.assertEquals((session.__name__, 'Session'),
                    anno.getanno(constructor_call, 'type_fqn'))

  method_call = with_node.body[0].value.func
  self.assertEquals(session.Session.run, anno.getanno(method_call, 'live_val'))
def visit_For(self, node):
  """Lowers a `for` loop into an index-based `while` loop via a template.

  The loop counter `i` and length `n` get fresh names that avoid the symbols
  referenced in the loop body. If the node has an `extra_cond` annotation,
  that condition is also required for every iteration.

  Returns:
    The result of `templates.replace` on the selected template.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, 'body_scope')
  # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)`
  # Or maybe we should replace range with tf.range?

  has_extra_cond = anno.hasanno(node, 'extra_cond')
  if has_extra_cond:

    def template(loop_iter, target, body, i, n, extra_cond):  # pylint:disable=unused-argument
      i = 0
      n = len(loop_iter)  # pylint:disable=undefined-variable
      while i < n and extra_cond:
        # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
        target = loop_iter[i]
        body  # pylint:disable=pointless-statement
        i += 1
  else:

    def template(loop_iter, target, body, i, n):  # pylint:disable=unused-argument
      i = 0
      n = len(loop_iter)  # pylint:disable=undefined-variable
      while i < n:
        # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
        target = loop_iter[i]
        body  # pylint:disable=pointless-statement
        i += 1

  # Fresh names are requested for 'i' first, then 'n'.
  counter_name = self.namer.new_symbol('i', body_scope.referenced)
  length_name = self.namer.new_symbol('n', body_scope.referenced)
  replacements = {
      'loop_iter': node.iter,
      'target': node.target,
      'body': node.body,
      'i': gast.Name(counter_name, None, None),
      'n': gast.Name(length_name, None, None),
  }
  if has_extra_cond:
    replacements['extra_cond'] = anno.getanno(node, 'extra_cond')
  return templates.replace(template, **replacements)