def visit_Call(self, node):
  """Applies the data type declared through the type-annotation marker call.

  A call whose callee resolves (through its `live_val` annotation) to
  `self.context.type_annotation_func` is expected to look like
  `marker(symbol, data_type)`. The data type is stored as the `element_type`
  annotation on the call node and on the definition of `symbol` found in the
  current scope, so that later uses of the symbol can pick it up.

  Args:
    node: the call node being visited.

  Returns:
    The result of `self.generic_visit(node)`.

  Raises:
    ValueError: if a marker call does not have exactly two positional
      arguments, if its first argument has no qualified name (is not a
      symbol), or if its second argument has no `live_val` annotation (is not
      statically resolvable).
  """
  if anno.hasanno(node.func, 'live_val'):
    marker = self.context.type_annotation_func
    if anno.getanno(node.func, 'live_val') is marker:
      # The data type is expected to be the second argument.
      if len(node.args) != 2:
        raise ValueError('"%s" must have exactly two parameters' % marker)
      if not anno.hasanno(node.args[0], anno.Basic.QN):
        raise ValueError(
            'the first argument of "%s" must be a symbol' % marker)
      if not anno.hasanno(node.args[1], 'live_val'):
        raise ValueError(
            'the second argument of "%s" must be statically resolvable' %
            marker)
      target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
      element_type = anno.getanno(node.args[1], 'live_val')

      # Annotating the symbol's definition makes later references to the
      # symbol receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(definition, 'element_type', element_type)
      # TODO(mdan): Should we update references between definition and here?
  return self.generic_visit(node)
def _wrap_to_py_func_no_return(self, node):
  """Rewrites a call with no used result into a `tf.py_func` invocation.

  The generated wrapper forwards the symbols used by the call's arguments
  (taken from the node's `args_scope` annotation) to the original callee and
  returns 1. `tf.py_func` invokes it with a `tf.int64` output type.

  Args:
    node: the call node; it must carry an `args_scope` annotation and its
      `func` must be a plain name (its `.id` is read).

  Returns:
    A `(wrapper_def, call_expr)` tuple. The wrapper definition is annotated
    with `skip_processing`, and the call's value carries the original
    `args_scope`.
  """
  args_scope = anno.getanno(node, 'args_scope')
  # TODO(mdan): Properly handle varargs, kwargs, etc.
  arg_names = tuple(
      gast.Name(used_name, gast.Load(), None)
      for used_name in args_scope.used)

  # pylint:disable=undefined-variable,unused-argument,function-redefined
  def template(call, wrapper, args):

    def wrapper(args):
      call(args)
      return 1

    tf.py_func(wrapper, [args], [tf.int64])
  # pylint:enable=undefined-variable,unused-argument,function-redefined

  wrapper_ref = gast.Name(
      self.namer.compiled_function_name(node.func.id), gast.Load(), None)
  wrapper_def, call_expr = templates.replace(
      template, call=node.func, wrapper=wrapper_ref, args=arg_names)
  anno.setanno(call_expr.value, 'args_scope', args_scope)
  anno.setanno(wrapper_def, 'skip_processing', True)
  return wrapper_def, call_expr
def visit_Call(self, node):
  """Renames calls to compilable functions.

  Calls to one of `self.nocompile_decorators` mark their first argument as
  `graph_ready`. After the children are visited, a call whose callee has a
  `live_val` is renamed when the callee is compilable.

  Args:
    node: the call node being visited.

  Returns:
    The (possibly renamed) call node.

  Raises:
    ValueError: if a marker decorator is called without arguments.
    NotImplementedError: if the resolved callee is not compilable, or if the
      callee is unresolved while `self.context.recursive` is set.
  """
  # If the function is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.nocompile_decorators:
      if len(node.args) < 1:
        # The message has a %s placeholder; it must be filled in, otherwise
        # the user sees a literal "%s" instead of the decorator.
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least an argument.' % (target_entity,))
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    else:
      raise NotImplementedError('py_func with return values')
  else:
    if self.context.recursive:
      raise NotImplementedError('Could not resolve target function.')
    else:
      # TODO(mdan): Double check. Is this reachable code?
      pass
  return node
def visit_Call(self, node):
  """Converts calls to compilable functions and known NumPy functions.

  Calls to one of `self.nocompile_decorators` mark their first argument as
  `graph_ready`. After the children are visited:
    * a resolved, compilable callee is renamed;
    * a resolved callee whose FQN is in `KNOWN_NUMPY_FUNCTIONS` is wrapped
      into a single-return py_func;
    * an unresolved callee gets a dynamic conversion in recursive mode and is
      left alone otherwise.

  Args:
    node: the call node being visited.

  Returns:
    The converted node.

  Raises:
    ValueError: if a marker decorator is called without arguments.
    NotImplementedError: if a resolved callee is neither compilable nor a
      known NumPy function.
  """
  # If the function is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.nocompile_decorators:
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least an argument.' % (target_entity,))
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    # Not every resolved callee carries an FQN. Defaulting to None keeps the
    # NumPy lookup below from reading an unbound local.
    target_fqn = None
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    elif target_fqn is not None and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(node, target_fqn)
    else:
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  else:
    if self.context.recursive:
      node = self._insert_dynamic_conversion(node)
    else:
      # Unresolved functions are allowed in non-recursive mode.
      pass
  return node
def visit_Call(self, node):
  """Propagates the receiver's type onto unresolved method-call targets.

  For a call like `opt.foo()` whose callee has no `live_val`, the definition
  of `opt` is looked up in the current scope, and its `type` and `type_fqn`
  annotations are copied onto the `opt.foo` attribute node.

  Args:
    node: the call node being visited.

  Returns:
    The visited node.

  Raises:
    ValueError: if the callee is not an attribute, if its receiver is not a
      plain name, if the receiver has no definition in scope, or if that
      definition has no `type` annotation.
  """
  func = node.func
  if not anno.hasanno(func, 'live_val'):
    if not isinstance(func, gast.Attribute):
      # Probably reached by calling an alias, e.g.:
      #   foo = bar
      #   foo()
      raise ValueError('Dont know how to handle dynamic functions.')
    if not isinstance(func.value, gast.Name):
      # E.g. a method called on an attribute of an object:
      #   foo = module.Foo()
      #   foo.bar.baz()
      # TODO(mdan): This should be doable by using the FQN.
      raise ValueError('Dont know how to handle object properties yet.')

    receiver = func.value.id
    if not self.scope.hasval(receiver):
      # TODO(mdan): Figure out what could the user do to get past this.
      raise ValueError('No info on "%s". Is it dynamically built?' %
                       (receiver))

    # The receiver's definition; for
    #   opt = tf.train.Optimizer()
    #   opt.foo()
    # it is the `tf.train.Optimizer()` node.
    object_source = self.scope.getval(receiver)
    if not anno.hasanno(object_source, 'type'):
      raise ValueError('Could not determine type of "%s". Is it dynamic?' %
                       (receiver))
    anno.setanno(func, 'type', anno.getanno(object_source, 'type'))
    anno.setanno(func, 'type_fqn', anno.getanno(object_source, 'type_fqn'))

  self.generic_visit(node)
  return node
def visit_Name(self, node):
  """Annotates loaded names with their live values where they can be resolved.

  Names that are neither local nor parameters are looked up first in
  `self.literals` and then in `self.context.namespace`. Names not modified
  since function entry that appear in `self.context.arg_values` take the
  argument's value. Resolved values go in the `live_val` annotation, with an
  `fqn` annotation when a name can be derived.

  Args:
    node: the name node. Loaded names must carry the `NodeAnno.IS_LOCAL`,
      `NodeAnno.IS_MODIFIED_SINCE_ENTRY` and `NodeAnno.IS_PARAM` annotations.

  Returns:
    The (annotated) node.
  """
  self.generic_visit(node)
  if isinstance(node.ctx, gast.Load):
    assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
    symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
    assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
    symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
    assert anno.hasanno(node, NodeAnno.IS_PARAM), node
    symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)

    if not symbol_is_local and not symbol_is_param:
      if node.id in self.literals:
        anno.setanno(node, 'live_val', self.literals[node.id])
        # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
      elif node.id in self.context.namespace:
        obj = self.context.namespace[node.id]
        anno.setanno(node, 'live_val', obj)
        # Values such as primitives have no `__name__`. Reading it
        # unconditionally would raise AttributeError, so such values get no
        # FQN.
        if hasattr(obj, '__name__'):
          anno.setanno(node, 'fqn', (obj.__name__,))
      else:
        # TODO(mdan): Should we raise an error here?
        # Can encounter this when:
        #  * a symbol truly lacks reference
        #  * a symbol is new, like the new name of a function we just renamed.
        pass
    else:
      # TODO(mdan): Attempt to trace its value through the local chain.
      # TODO(mdan): Use type annotations as fallback.
      pass

    if not symbol_is_modified:
      if node.id in self.context.arg_values:
        obj = self.context.arg_values[node.id]
        anno.setanno(node, 'live_val', obj)
        anno.setanno(node, 'fqn', (obj.__class__.__name__,))
  return node
def visit_Name(self, node):
  """Annotates loaded names with their live values, rejecting unknown globals.

  Names that are neither local nor parameters must resolve through
  `self.literals` or `self.context.namespace`. Names not modified since
  function entry that appear in `self.context.arg_values` take the
  argument's value. Resolved values go in the `live_val` annotation, with an
  `fqn` annotation when a name can be derived.

  Args:
    node: the name node. Loaded names must carry the `is_local`,
      `is_modified_since_entry` and `is_param` annotations.

  Returns:
    The (annotated) node.

  Raises:
    ValueError: if a non-local, non-parameter name cannot be resolved.
  """
  self.generic_visit(node)
  if isinstance(node.ctx, gast.Load):
    assert anno.hasanno(node, 'is_local'), node
    symbol_is_local = anno.getanno(node, 'is_local')
    assert anno.hasanno(node, 'is_modified_since_entry'), node
    symbol_is_modified = anno.getanno(node, 'is_modified_since_entry')
    assert anno.hasanno(node, 'is_param'), node
    symbol_is_param = anno.getanno(node, 'is_param')

    if not symbol_is_local and not symbol_is_param:
      if node.id in self.literals:
        anno.setanno(node, 'live_val', self.literals[node.id])
        # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
      elif node.id in self.context.namespace:
        obj = self.context.namespace[node.id]
        anno.setanno(node, 'live_val', obj)
        # Values such as primitives have no `__name__`; they get no FQN
        # instead of raising AttributeError.
        if hasattr(obj, '__name__'):
          anno.setanno(node, 'fqn', (obj.__name__,))
      else:
        raise ValueError('Could not resolve symbol "%s".' % node.id)
    else:
      # TODO(mdan): Attempt to trace its value through the local chain.
      # TODO(mdan): Use type annotations as fallback.
      pass

    if not symbol_is_modified:
      if node.id in self.context.arg_values:
        obj = self.context.arg_values[node.id]
        anno.setanno(node, 'live_val', obj)
        anno.setanno(node, 'fqn', (obj.__class__.__name__,))
  return node
def visit_Call(self, node):
  """Renames calls to compilable functions and to members of known types.

  Calls to one of `self.nocompile_decorators` mark their first argument as
  `graph_ready`. After the children are visited, a resolved compilable
  callee is renamed. A callee with a `type_fqn` annotation is renamed as a
  member function of a known type.

  Args:
    node: the call node being visited.

  Returns:
    The (possibly renamed) call node.

  Raises:
    ValueError: if a marker decorator is called without arguments.
    NotImplementedError: if a resolved callee is not compilable, or if the
      callee is neither resolved nor of a known type.
  """
  # If the function is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_obj = anno.getanno(node.func, 'live_val')
    if target_obj in self.nocompile_decorators:
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least an argument.' % (target_obj,))
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_obj = anno.getanno(node.func, 'live_val')
    if self._function_is_compilable(target_obj):
      node = self._rename_compilable_function(node)
    else:
      raise NotImplementedError('py_func with return values')
  elif anno.hasanno(node.func, 'type_fqn'):
    node = self._rename_member_function_of_known_type(node)
  else:
    # Member calls have an Attribute callee, which has no `.id`. Reading it
    # would raise AttributeError instead of the intended error.
    func = node.func
    if isinstance(func, gast.Attribute):
      func_name = func.attr
    else:
      func_name = getattr(func, 'id', func)
    raise NotImplementedError(
        'Member function call (of unknown type): %s.' % (func_name,))
  return node
def _wrap_to_py_func_no_return(self, node):
  """Rewrites a call whose result is unused into a `tf.py_func` invocation.

  A wrapper function is generated that forwards its parameters to the
  original callee and returns 1. `tf.py_func` invokes it with the call's
  original arguments and a `tf.int64` output type. Fresh names for the
  wrapper and each of its parameters avoid the symbols in
  `args_scope.referenced`.

  Args:
    node: the call node; `node.func` must carry a QN annotation and `node`
      a `NodeAnno.ARGS_SCOPE` annotation.

  Returns:
    A `(wrapper_def, call_expr)` tuple; the wrapper definition is annotated
    with `anno.Basic.SKIP_PROCESSING`.
  """
  namer = self.context.namer
  func_qn = anno.getanno(node.func, anno.Basic.QN)
  args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  wrapper_name = namer.new_symbol(func_qn.ssf(), args_scope.referenced)

  def _param_stem(arg):
    # Arguments without a qualified name (e.g. literals) use a generic stem.
    if anno.hasanno(arg, anno.Basic.QN):
      return anno.getanno(arg, anno.Basic.QN)
    return qual_names.QN('arg')

  wrapper_args = [
      namer.new_symbol(_param_stem(arg).ssf(), args_scope.referenced)
      for arg in node.args
  ]

  # TODO(mdan): Properly handle varargs, kwargs, etc.
  # TODO(mdan): This is best handled as a dynamic dispatch.
  # That way we can separate tensors from non-tensor args.
  template = """
    def wrapper(wrapper_args):
      call(wrapper_args)
      return 1
    tf.py_func(wrapper, original_args, [tf.int64])
  """
  original_args = gast.List(elts=node.args, ctx=None)
  wrapper_def, call_expr = templates.replace(
      template,
      call=node.func,
      wrapper=wrapper_name,
      original_args=original_args,
      wrapper_args=wrapper_args)
  anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
  return wrapper_def, call_expr
def visit_With(self, node):
  """Visits a `with` statement inside a non-isolated child scope.

  The child scope is attached to the node as `NodeAnno.BODY_SCOPE`, and the
  enclosing scope is restored afterwards.

  Args:
    node: the `with` node.

  Returns:
    The visited node.
  """
  enclosing = self.scope
  body_scope = Scope(enclosing, isolated=False)
  self.scope = body_scope
  self.generic_visit(node)
  anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
  self.scope = enclosing
  return node
def _as_function(self, func_name, args):
  """Builds the expression `func_name(args)` as a safe boolean operand.

  Args:
    func_name: source text of the callee, parsed with
      `parser.parse_expression`.
    args: the arguments to substitute into the call.

  Returns:
    The call expression, annotated with `SAFE_BOOLEAN_OPERAND`.
  """
  callee = parser.parse_expression(func_name)
  call_expr = templates.replace_as_expression(
      """
    func_name(args)
  """,
      func_name=callee,
      args=args)
  anno.setanno(call_expr, SAFE_BOOLEAN_OPERAND, True)
  return call_expr
def visit_Print(self, node):
  """Visits a `print` statement's values inside a dedicated child scope.

  The child scope is attached to the node as `args_scope`, and the enclosing
  scope is restored afterwards.

  Args:
    node: the `Print` node.

  Returns:
    The visited node.
  """
  enclosing = self.scope
  values_scope = Scope(enclosing)
  self.scope = values_scope
  for value in node.values:
    self.visit(value)
  anno.setanno(node, 'args_scope', values_scope)
  self.scope = enclosing
  return node
def _inline_tf_op(self, op_name, args):
  """Builds the expression `tf.<op_name>(args)` as a safe boolean operand.

  Args:
    op_name: name of the `tf` attribute to call.
    args: the arguments to substitute into the call.

  Returns:
    The call expression, annotated with `SAFE_BOOLEAN_OPERAND`.
  """
  template = """
    tf.op_name(args)
  """
  statements = templates.replace(template, op_name=op_name, args=args)
  # The template yields one expression statement; its value is the call.
  call_expr = statements[0].value
  anno.setanno(call_expr, SAFE_BOOLEAN_OPERAND, True)
  return call_expr
def _process_block_node(self, node, block, scope_name):
  """Visits a statement block of `node` inside a non-isolated child scope.

  Args:
    node: the node owning the block.
    block: the list of statements to visit.
    scope_name: prefix of the annotation key; the child scope is stored on
      `node` under `'<scope_name>_scope'`.

  Returns:
    `node`.
  """
  enclosing = self.scope
  inner_scope = Scope(enclosing, isolated=False)
  self.scope = inner_scope
  for stmt in block:
    self.visit(stmt)
  anno_key = '%s_scope' % scope_name
  anno.setanno(node, anno_key, inner_scope)
  self.scope = enclosing
  return node
def test_basic(self):
  """Checks has/get/set of a single annotation on a fresh node."""
  name_node = ast.Name()

  # A fresh node has no annotation, and reading a missing one raises.
  self.assertFalse(anno.hasanno(name_node, 'foo'))
  with self.assertRaises(AttributeError):
    anno.getanno(name_node, 'foo')

  # Once set, the annotation is present and readable.
  anno.setanno(name_node, 'foo', 3)
  self.assertTrue(anno.hasanno(name_node, 'foo'))
  self.assertEqual(3, anno.getanno(name_node, 'foo'))
def test_copyanno(self):
  """Checks that copyanno copies present annotations and skips missing ones."""
  source = ast.Name()
  anno.setanno(source, 'foo', 3)
  destination = ast.Name()

  for key in ('foo', 'bar'):
    anno.copyanno(source, destination, key)

  self.assertTrue(anno.hasanno(destination, 'foo'))
  self.assertFalse(anno.hasanno(destination, 'bar'))
def _process_function_arg(self, arg_name):
  """Defines a typed placeholder for a top-level function argument.

  Only applies at function level 1 to arguments listed in
  `self.context.arg_types`. Each entry of that mapping is a
  `(type_string, type_obj)` pair, where `type_string` is a dotted name.

  Args:
    arg_name: the argument's qualified name; its `str()` is the lookup key
      and its `.ast()` builds the placeholder node.
  """
  key = str(arg_name)
  if self.function_level != 1 or key not in self.context.arg_types:
    return

  # A forged node carries the type information, so that method calls on the
  # argument can resolve its type.
  placeholder = arg_name.ast()
  type_string, type_obj = self.context.arg_types[key]
  anno.setanno(placeholder, 'type', type_obj)
  anno.setanno(placeholder, 'type_fqn', tuple(type_string.split('.')))
  self.scope.setval(arg_name, placeholder)
def visit_FunctionDef(self, node):
  """Visits a function definition inside a new, isolated scope.

  When there is an enclosing scope, the function's name is marked as written
  there. The body's scope is attached to the node as `NodeAnno.BODY_SCOPE`.

  Args:
    node: the function definition node.

  Returns:
    The visited node.
  """
  enclosing = self.scope
  if enclosing:
    # Defining a function binds its name in the enclosing scope.
    enclosing.mark_write(QN(node.name))

  body_scope = Scope(enclosing, isolated=True)
  self.scope = body_scope
  self.generic_visit(node)
  anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
  self.scope = enclosing
  return node
def visit_Call(self, node):
  """Visits a call, gathering its arguments' activity in a child scope.

  Positional and keyword arguments are visited in a non-isolated child scope
  that is attached to the node as `NodeAnno.ARGS_SCOPE`. The callee
  expression is visited afterwards, in the enclosing scope.

  Args:
    node: the call node.

  Returns:
    The visited node.
  """
  enclosing = self.scope
  call_args_scope = Scope(enclosing, isolated=False)
  self.scope = call_args_scope
  for positional in node.args:
    self.visit(positional)
  # TODO(mdan): Account starargs, kwargs
  for keyword in node.keywords:
    self.visit(keyword)
  anno.setanno(node, NodeAnno.ARGS_SCOPE, call_args_scope)
  self.scope = enclosing
  self.visit(node.func)
  return node
def visit_Name(self, node):
  """Defines function parameters in scope, attaching declared types.

  Every parameter name is defined in the current scope as a load of itself.
  At function level 1, parameters listed in `self.context.arg_types` are
  instead defined by a placeholder carrying `type` and `type_fqn`
  annotations.

  Args:
    node: the name node.

  Returns:
    The visited node.
  """
  self.generic_visit(node)
  if not isinstance(node.ctx, gast.Param):
    return node

  param = node.id
  self.scope.setval(param, gast.Name(param, gast.Load(), None))
  if self.function_level == 1 and param in self.context.arg_types:
    # The forged node carries the type information, so that method calls on
    # the parameter can resolve its type.
    placeholder = gast.Name(param, gast.Load(), None)
    type_string, type_obj = self.context.arg_types[param]
    anno.setanno(placeholder, 'type', type_obj)
    anno.setanno(placeholder, 'type_fqn', tuple(type_string.split('.')))
    self.scope.setval(param, placeholder)
  return node
def visit_Call(self, node):
  """Visits a call, gathering its arguments' activity in a child scope.

  Positional and keyword arguments are visited in a child scope that is
  attached to the node as `args_scope`. The callee expression is visited
  afterwards, in the enclosing scope.

  Args:
    node: the call node.

  Returns:
    The visited node.
  """
  enclosing = self.scope
  call_args_scope = Scope(enclosing)
  self.scope = call_args_scope
  for positional in node.args:
    self.visit(positional)
  # TODO(mdan): Account starargs, kwargs
  for keyword in node.keywords:
    self.visit(keyword)
  anno.setanno(node, 'args_scope', call_args_scope)
  self.scope = enclosing
  self.visit(node.func)
  return node
def visit_Name(self, node):
  """Processes parameters and propagates known types onto loaded names.

  Parameters are handed to `self._process_function_arg`. For a loaded name
  whose definition is in scope and has a `type` annotation, that `type` and
  its `type_fqn` are copied onto the name.

  Args:
    node: the name node; it must carry an `anno.Basic.QN` annotation.

  Returns:
    The visited node.
  """
  self.generic_visit(node)
  qn = anno.getanno(node, anno.Basic.QN)
  ctx = node.ctx
  if isinstance(ctx, gast.Param):
    self._process_function_arg(qn)
    return node
  if not (isinstance(ctx, gast.Load) and self.scope.hasval(qn)):
    return node

  # After `a = b`, later loads of `a` trace back to the definition `b`.
  traced_source = self.scope.getval(qn)
  if anno.hasanno(traced_source, 'type'):
    for key in ('type', 'type_fqn'):
      anno.setanno(node, key, anno.getanno(traced_source, key))
  return node
def visit_While(self, node):
  """Visits a while loop, recording its enclosing and body scopes.

  The test is visited in the enclosing scope, which is stored as
  `parent_scope`. The body runs in a non-isolated child scope, stored as
  `body_scope`.

  Args:
    node: the while node.

  Returns:
    The visited node.

  Raises:
    NotImplementedError: if the loop has an `else` clause.
  """
  self.visit(node.test)
  enclosing = self.scope
  anno.setanno(node, 'parent_scope', enclosing)

  loop_scope = Scope(enclosing, isolated=False)
  self.scope = loop_scope
  for stmt in node.body:
    self.visit(stmt)
  anno.setanno(node, 'body_scope', loop_scope)

  if node.orelse:
    # TODO(mdan): Add support for orelse.
    raise NotImplementedError()

  self.scope = enclosing
  return node
def visit_Attribute(self, node):
  """Resolves attribute live values from their parent's value or type.

  If the parent has a `live_val`, the attribute's value is read from it. If
  the parent instead has a `type`, static members (such as methods) are read
  from that type. Either way `live_val` and an extended `fqn` are annotated.

  Args:
    node: the attribute node.

  Returns:
    The visited node.

  Raises:
    AttributeError: if the parent's live value lacks the attribute.
  """
  self.generic_visit(node)
  parent = node.value
  attr_name = node.attr

  if anno.hasanno(parent, 'live_val'):
    assert anno.hasanno(parent, 'fqn')
    parent_object = anno.getanno(parent, 'live_val')
    if not hasattr(parent_object, attr_name):
      raise AttributeError('%s has no attribute %s' %
                           (parent_object, attr_name))
    anno.setanno(node, 'live_val', getattr(parent_object, attr_name))
    anno.setanno(node, 'fqn', anno.getanno(parent, 'fqn') + (attr_name,))
    # TODO(mdan): Investigate the role built-in annotations can play here.
  elif anno.hasanno(parent, 'type'):
    parent_type = anno.getanno(parent, 'type')
    # Static members (e.g. methods) are found on the type. Dynamic members
    # (e.g. function attributes) are not; those nodes stay unannotated for
    # downstream consumers to handle.
    if hasattr(parent_type, attr_name):
      anno.setanno(node, 'live_val', getattr(parent_type, attr_name))
      anno.setanno(node, 'fqn',
                   anno.getanno(parent, 'type_fqn') + (attr_name,))
  elif isinstance(parent, gast.Name):
    # All nonlocal symbols should be fully resolved.
    assert anno.hasanno(parent, 'is_local'), parent
    # TODO(mdan): Figure out what to do when calling attribute on local object
    # Maybe just leave as-is?
  return node
def _inline_tf_op(self, op_name, args):
  """Builds a call to a `tf` or `py2tf_utils` op as a safe boolean operand.

  Op names containing `py2tf_utils` are called on `py2tf_utils` (with the
  `py2tf_utils.` prefix stripped); all others are called on `tf`.

  Args:
    op_name: the op name, optionally prefixed with `py2tf_utils.`.
    args: the arguments to substitute into the call.

  Returns:
    The call expression, annotated with `SAFE_BOOLEAN_OPERAND`.
  """
  if 'py2tf_utils' not in op_name:
    template = """
      tf.op_name(args)
    """
  else:
    # TODO(alexbw): explicitly spelling out the attribute function name
    # until fix for issue highlighted in cl/188931581 lands.
    template = """
      py2tf_utils.op_name(args)
    """
    op_name = op_name.replace('py2tf_utils.', '')

  call_expr = templates.replace_as_expression(
      template, op_name=op_name, args=args)
  anno.setanno(call_expr, SAFE_BOOLEAN_OPERAND, True)
  return call_expr
def visit_Attribute(self, node):
  """Resolves an attribute's live value from its parent's live value.

  Args:
    node: the attribute node.

  Returns:
    The visited node.

  Raises:
    AttributeError: if the parent's live value lacks the attribute.
  """
  self.generic_visit(node)
  parent = node.value
  attr_name = node.attr

  if anno.hasanno(parent, 'live_val'):
    assert anno.hasanno(parent, 'fqn')
    parent_object = anno.getanno(parent, 'live_val')
    if not hasattr(parent_object, attr_name):
      raise AttributeError('%s has no attribute %s' %
                           (parent_object, attr_name))
    anno.setanno(node, 'live_val', getattr(parent_object, attr_name))
    anno.setanno(node, 'fqn', anno.getanno(parent, 'fqn') + (attr_name,))
  elif isinstance(parent, gast.Name):
    # All nonlocal symbols should be fully resolved, so an unresolved stem
    # must be local.
    assert anno.hasanno(parent, 'is_local'), parent
    assert anno.getanno(parent, 'is_local'), parent
    # TODO(mdan): Figure out what to do when calling attribute on local object
    # Maybe just leave as-is?
  return node
def visit_Subscript(self, node):
  """Attaches a qualified name to simple subscripts such as `a[0]` or `a[b]`.

  Only single-index subscripts are named. The index must be a number literal,
  a string literal, or an expression that itself has a qualified name, and
  the subscripted value must have a qualified name too.

  Args:
    node: the subscript node.

  Returns:
    The visited node, annotated with `anno.Basic.QN` when a name applies.
  """
  node = self.generic_visit(node)
  s = node.slice
  if not isinstance(s, gast.Index):
    # TODO(mdan): Support range and multi-dimensional indices.
    # Continuing silently because some demos use these.
    return node
  if isinstance(s.value, gast.Num):
    subscript = QN(NumberLiteral(s.value.n))
  elif isinstance(s.value, gast.Str):
    subscript = QN(StringLiteral(s.value.s))
  elif anno.hasanno(s.value, anno.Basic.QN):
    subscript = anno.getanno(s.value, anno.Basic.QN)
  else:
    # The index is a general expression (e.g. `a[i + 1]`). It has no
    # qualified name, so neither does the subscript; reading the missing
    # annotation would raise.
    return node
  if anno.hasanno(node.value, anno.Basic.QN):
    anno.setanno(node, anno.Basic.QN,
                 QN(anno.getanno(node.value, anno.Basic.QN),
                    subscript=subscript))
  return node
def visit_Name(self, node):
  """Records reads and writes of a name in the current scope.

  Loads are marked as reads and annotated with `is_local`. Stores and
  function parameters are marked as writes.

  Args:
    node: the name node.

  Returns:
    The visited node.

  Raises:
    ValueError: if the name has an unrecognized context.
  """
  # TODO(mdan): This is insufficient for object fields, e.g. hp.learning_rate.
  self.generic_visit(node)
  ctx = node.ctx
  if isinstance(ctx, gast.Load):
    anno.setanno(node, 'is_local', self.scope.has(node.id))
    self.scope.mark_read(node.id)
  elif isinstance(ctx, (gast.Store, gast.Param)):
    # Param contexts appear in function defs, where they define a variable.
    # TODO(mdan): This may be incorrect with nested functions.
    # For nested functions, we'll have to add the notion of hiding args from
    # the parent scope, not writing to them.
    self.scope.mark_write(node.id)
  else:
    raise ValueError('Unknown context %s for node %s.' % (type(ctx), node.id))
  return node
def _wrap_to_py_func_no_return(self, node):
  """Rewrites a call whose result is unused into a `tf.py_func` invocation.

  The generated wrapper forwards the symbols used by the call's arguments
  (from the node's `args_scope`) to the original callee and returns 1.
  `tf.py_func` invokes it with a `tf.int64` output type.

  Args:
    node: the call node; it must carry an `args_scope` annotation and its
      `func` must be a plain name (its `.id` is read).

  Returns:
    A `(wrapper_def, call_expr)` tuple. The wrapper definition is annotated
    with `skip_processing`, and the call's value carries `args_scope`.
  """
  args_scope = anno.getanno(node, 'args_scope')
  wrapper_name = self.context.namer.compiled_function_name(node.func.id)[0]
  arg_names = tuple(
      gast.Name(used_name, gast.Load(), None)
      for used_name in args_scope.used)

  # TODO(mdan): Properly handle varargs, kwargs, etc.
  template = """
    def wrapper(args):
      call(args)
      return 1
    tf.py_func(wrapper, [args], [tf.int64])
  """
  wrapper_def, call_expr = templates.replace(
      template, call=node.func, wrapper=wrapper_name, args=arg_names)
  anno.setanno(call_expr.value, 'args_scope', args_scope)
  # TODO(mdan): Rename this annotation to 'graph_ready'
  anno.setanno(wrapper_def, 'skip_processing', True)
  return wrapper_def, call_expr
def visit_For(self, node):
  """Rewrites `break` statements in a for loop's body.

  A fresh `break_requested` flag name is pushed onto `self.break_uses` while
  the body is visited. If the body used it, the loop gets an `extra_cond`
  annotation (`not <flag>`) and the flag's initialization is emitted before
  the loop.

  Args:
    node: the for node; it must carry a `body_scope` annotation.

  Returns:
    The loop node, or `[break_init, loop]` when a break was rewritten.
  """
  self.generic_visit(node.target)
  self.generic_visit(node.iter)

  body_scope = anno.getanno(node, 'body_scope')
  flag_name = self.namer.new_symbol('break_requested', body_scope.referenced)
  self.break_uses.append([False, flag_name])
  node.body = self._manual_visit_list(node.body)

  if not self.break_uses[-1][0]:
    replacement = node
  else:
    not_requested = gast.UnaryOp(
        gast.Not(), gast.Name(flag_name, gast.Load(), None))
    anno.setanno(node, 'extra_cond', not_requested)
    replacement = [self._create_break_init(), node]
  self.break_uses.pop()

  for stmt in node.orelse:
    self.generic_visit(stmt)
  return replacement
def visit_For(self, node):
  """Rewrites `break` statements in a for loop's body.

  A fresh `break_requested` flag name is pushed onto `self.break_uses` while
  the body is visited. If the body used it, the loop gets an `extra_cond`
  annotation (`not <flag>`) and the flag's initialization is emitted before
  the loop.

  Args:
    node: the for node; it must carry a `NodeAnno.BODY_SCOPE` annotation.

  Returns:
    The loop node, or `[break_init, loop]` when a break was rewritten.
  """
  self.generic_visit(node.target)
  self.generic_visit(node.iter)

  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  flag_name = self.context.namer.new_symbol('break_requested',
                                            body_scope.referenced)
  self.break_uses.append([False, flag_name])
  node.body = self._manual_visit_list(node.body)

  if not self.break_uses[-1][0]:
    replacement = node
  else:
    not_requested = gast.UnaryOp(
        gast.Not(), gast.Name(flag_name, gast.Load(), None))
    anno.setanno(node, 'extra_cond', not_requested)
    replacement = [self._create_break_init(), node]
  self.break_uses.pop()

  for stmt in node.orelse:
    self.generic_visit(stmt)
  return replacement
def generic_visit(self, node):
  """Returns a deep copy of an AST node.

  Fields whose names start with `__` and fields absent from the node are
  skipped. AST children are copied recursively, including those nested in
  lists or tuples. Any other value (strings, numbers, None) is reused as-is,
  even when it sits inside a list or tuple (e.g. `Global.names`, or `None`
  keys in `Dict.keys`). Recursing into those values would fail because they
  have no `_fields`. Of the annotations, only `anno.Basic.SKIP_PROCESSING`
  is carried over.

  Args:
    node: the AST node (gast or ast) to copy.

  Returns:
    A new node of the same type.
  """

  def copy_value(value):
    # Recurse into containers and AST nodes; leave value types untouched.
    if isinstance(value, list):
      return [copy_value(item) for item in value]
    if isinstance(value, tuple):
      return tuple(copy_value(item) for item in value)
    if isinstance(value, (gast.AST, ast.AST)):
      return self.generic_visit(value)
    return value

  new_fields = {}
  for f in node._fields:
    if f.startswith('__') or not hasattr(node, f):
      continue
    new_fields[f] = copy_value(getattr(node, f))
  new_node = type(node)(**new_fields)
  if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
    anno.setanno(new_node, anno.Basic.SKIP_PROCESSING, True)
  return new_node
def visit_Name(self, node):
  """Annotates loaded names with their live values where they can be resolved.

  Names that are neither local nor parameters are looked up first in
  `self.literals` and then in `self.context.namespace`. Names not modified
  since function entry that appear in `self.context.arg_values` take the
  argument's value. Resolved values go in `live_val`, with an `fqn` when the
  value has a name.

  Args:
    node: the name node. Loaded names must carry the `NodeAnno.IS_LOCAL`,
      `NodeAnno.IS_MODIFIED_SINCE_ENTRY` and `NodeAnno.IS_PARAM` annotations.

  Returns:
    The (annotated) node.
  """
  self.generic_visit(node)
  if not isinstance(node.ctx, gast.Load):
    return node

  assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
  is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
  assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
  is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
  assert anno.hasanno(node, NodeAnno.IS_PARAM), node
  is_param = anno.getanno(node, NodeAnno.IS_PARAM)

  name = node.id
  if not (is_local or is_param):
    if name in self.literals:
      anno.setanno(node, 'live_val', self.literals[name])
    elif name in self.context.namespace:
      live_obj = self.context.namespace[name]
      anno.setanno(node, 'live_val', live_obj)
      # Values such as primitives have no name, hence no FQN.
      if hasattr(live_obj, '__name__'):
        anno.setanno(node, 'fqn', (live_obj.__name__,))
    # Otherwise the name stays unresolved.
    # TODO(mdan): Should we raise an error here?
    # Can encounter this when:
    #  * a symbol truly lacks reference
    #  * a symbol is new, like the new name of a function we just renamed.
  # TODO(mdan): For local names, attempt to trace the value through the local
  # chain, and use type annotations as fallback.

  if not is_modified and name in self.context.arg_values:
    arg_obj = self.context.arg_values[name]
    anno.setanno(node, 'live_val', arg_obj)
    anno.setanno(node, 'fqn', (arg_obj.__class__.__name__,))
  return node
def visit_Compare(self, node):
  """Replaces a comparison with the equivalent inlined TF ops.

  Chained comparisons become conjunctions:
    a < b < c   ->   logical_and(b < c, a < b)
  Individual comparisons between two plain names are marked as safe boolean
  operands.

  Args:
    node: the compare node.

  Returns:
    The expression replacing the comparison.
  """
  node = self.generic_visit(node)
  left = node.left
  op_tree = None

  for op, right in list(zip(node.ops, node.comparators)):
    comparison = self._inline_tf_op(self._matching_tf_op(op), (left, right))
    if isinstance(left, gast.Name) and isinstance(right, gast.Name):
      anno.setanno(comparison, SAFE_BOOLEAN_OPERAND, True)

    if op_tree is None:
      op_tree = comparison
    else:
      # NOTE(review): in `a < b < c` the operand evaluated twice is `b` (this
      # iteration's `left`), yet the check is applied to `right`. Confirm
      # which operand `_expect_simple_symbol` is meant to guard.
      self._expect_simple_symbol(right)
      op_tree = self._inline_tf_op('logical_and', (comparison, op_tree))
    left = right

  assert op_tree is not None
  return op_tree
def visit_Call(self, node):
  """Propagates the receiver's type FQN onto unresolved method-call targets.

  For a call like `opt.foo()` whose callee has no `live_val`, the definition
  of `opt` is read from the current scope, and its `type_fqn` annotation is
  copied onto the `opt.foo` attribute node.

  Args:
    node: the call node being visited.

  Returns:
    The visited node.

  Raises:
    ValueError: if the callee is not an attribute, if its receiver is not a
      plain name, or if the receiver's definition has no `type` annotation.
  """
  func = node.func
  if not anno.hasanno(func, 'live_val'):
    if not isinstance(func, gast.Attribute):
      # Probably reached by calling an alias, e.g.:
      #   foo = bar
      #   foo()
      raise ValueError('Dont know how to handle dynamic functions.')
    if not isinstance(func.value, gast.Name):
      # E.g. a method called on an attribute of an object:
      #   foo = module.Foo()
      #   foo.bar.baz()
      # TODO(mdan): This should be doable by using the FQN.
      raise ValueError('Dont know how to handle object properties yet.')

    receiver = func.value.id
    # The receiver's definition; for
    #   opt = tf.train.Optimizer()
    #   opt.foo()
    # it is the `tf.train.Optimizer()` node.
    object_source = self.scope.getval(receiver)
    if not anno.hasanno(object_source, 'type'):
      raise ValueError('Could not determine type of "%s". Is it dynamic?' %
                       (receiver))
    anno.setanno(func, 'type_fqn', anno.getanno(object_source, 'type_fqn'))

  self.generic_visit(node)
  return node
def visit_Assign(self, node):
  """Records assignment sources in scope, annotating constructor calls.

  When the assigned value is a call to a class (a constructor), the call node
  is annotated with `type` (the class) and `type_fqn` (the callee's `fqn`).
  Each target name is then defined in the current scope as the assigned
  value. A name unpacked from a tuple target is defined as the subscript
  `value[i]` instead.

  Args:
    node: the assignment node.

  Returns:
    The visited node.
  """
  self.generic_visit(node)
  if isinstance(node.value, gast.Call):
    target = node.value.func
    if anno.hasanno(target, 'live_val'):
      target_obj = anno.getanno(target, 'live_val')
      if tf_inspect.isclass(target_obj):
        # This is then a constructor.
        anno.setanno(node.value, 'type', target_obj)
        anno.setanno(node.value, 'type_fqn', anno.getanno(target, 'fqn'))
        # TODO(mdan): Raise an error if constructor has side effects.
        # We can have a whitelist of no-side-effects constructors.
        # We can also step inside the constructor and further analyze.

  for n in node.targets:
    if isinstance(n, gast.Tuple):
      for i, e in enumerate(n.elts):
        # The element's definition is `value[i]`. That subscript is read,
        # not assigned, so it takes a Load context, and its index must be an
        # AST node rather than a raw int.
        element = gast.Subscript(
            node.value, gast.Index(gast.Num(i)), ctx=gast.Load())
        self.scope.setval(e.id, element)
    else:
      self.scope.setval(n.id, node.value)
  return node
def visit_Print(self, node):
  """Replaces a Python 2 `print` statement with a call to `print`.

  The new call's callee is annotated with `live_val` (the `print` builtin)
  and `fqn`. The call inherits the statement's `args_scope` annotation.

  Args:
    node: the `Print` node; it must carry an `args_scope` annotation.

  Returns:
    An expression statement wrapping the `print(...)` call.
  """
  self.generic_visit(node)
  for n in node.values:
    n.ctx = gast.Param()
  call_node = gast.Call(
      func=gast.Name('print', gast.Load(), None),
      args=node.values,
      keywords=[])
  anno.setanno(call_node.func, 'live_val', print)
  # FQNs are tuples of name components elsewhere (consumers extend them with
  # `fqn + (attr,)`), so a bare string here would be inconsistent.
  anno.setanno(call_node.func, 'fqn', ('print',))
  anno.setanno(call_node, 'args_scope', anno.getanno(node, 'args_scope'))
  node = gast.Expr(call_node)
  return node
def visit_Name(self, node):
  """Processes parameters and propagates type annotations onto loaded names.

  Parameters are handed to `self._process_function_arg`. For a loaded name
  whose definition is in scope, the definition's `type`/`type_fqn` (when it
  has a `type`) and its `element_type` (when present) are copied onto the
  name.

  Args:
    node: the name node; it must carry an `anno.Basic.QN` annotation.

  Returns:
    The visited node.
  """
  self.generic_visit(node)
  qn = anno.getanno(node, anno.Basic.QN)
  ctx = node.ctx
  if isinstance(ctx, gast.Param):
    self._process_function_arg(qn)
    return node
  if not (isinstance(ctx, gast.Load) and self.scope.hasval(qn)):
    return node

  # After `a = b`, later references to `a` see the definition `b`.
  definition = self.scope.getval(qn)
  if anno.hasanno(definition, 'type'):
    for key in ('type', 'type_fqn'):
      anno.setanno(node, key, anno.getanno(definition, key))
  if anno.hasanno(definition, 'element_type'):
    anno.setanno(node, 'element_type',
                 anno.getanno(definition, 'element_type'))
  return node
def visit_Name(self, node):
  """Annotates loaded non-local names with their live values.

  Non-local names are resolved through `self.literals`, then
  `self.namespace`. Resolved values go in `live_val`, with an `fqn` when the
  value has a name.

  Args:
    node: the name node; loaded names must carry an `is_local` annotation.

  Returns:
    The (annotated) node.

  Raises:
    ValueError: if a non-local name cannot be resolved.
  """
  self.generic_visit(node)
  if isinstance(node.ctx, gast.Load):
    assert anno.hasanno(node, 'is_local'), node
    symbol_is_local = anno.getanno(node, 'is_local')
    if not symbol_is_local:
      if node.id in self.literals:
        anno.setanno(node, 'live_val', self.literals[node.id])
        # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
      elif node.id in self.namespace:
        obj = self.namespace[node.id]
        anno.setanno(node, 'live_val', obj)
        # Values such as primitives have no `__name__`; they get no FQN
        # instead of raising AttributeError.
        if hasattr(obj, '__name__'):
          anno.setanno(node, 'fqn', (obj.__name__,))
      else:
        raise ValueError('Could not find global symbol %s.' % node.id)
    else:
      # TODO(mdan): Attempt to trace its value through the local chain.
      # TODO(mdan): Use type annotations as fallback.
      pass
  return node
def _process_variable_assignment(self, source, targets):
  """Records `source` as the definition of each assignment target.

  A call to a class is treated as a constructor: the call node is annotated
  with `is_constructor`, `type` and `type_fqn`. Tuple targets are delegated
  to `self._process_tuple_assignment`. Name and attribute targets are
  defined in scope by their qualified name.

  Args:
    source: the assigned value node.
    targets: the assignment's target nodes.

  Raises:
    ValueError: for targets that are not tuples, names or attributes.
  """
  if isinstance(source, gast.Call):
    callee = source.func
    callee_obj = (
        anno.getanno(callee, 'live_val')
        if anno.hasanno(callee, 'live_val') else None)
    if callee_obj is not None and tf_inspect.isclass(callee_obj):
      anno.setanno(source, 'is_constructor', True)
      anno.setanno(source, 'type', callee_obj)
      anno.setanno(source, 'type_fqn', anno.getanno(callee, 'fqn'))
      # TODO(mdan): Raise an error if constructor has side effects.
      # We can have a whitelist of no-side-effects constructors.
      # We can also step inside the constructor and further analyze.

  for target in targets:
    if isinstance(target, gast.Tuple):
      # Nested tuples such as `a, (b, c) = f()` need recursion.
      self._process_tuple_assignment(source, target)
    elif isinstance(target, (gast.Name, gast.Attribute)):
      self.scope.setval(anno.getanno(target, anno.Basic.QN), source)
    else:
      raise ValueError('Dont know how to handle assignment to %s' % target)
def visit_While(self, node):
  """Snapshots the current scope's values, then visits the while loop.

  Args:
    node: the while node; receives a `parent_scope_values` annotation.

  Returns:
    The visited node.
  """
  values_before_loop = self.scope.copy()
  anno.setanno(node, 'parent_scope_values', values_before_loop)
  self.generic_visit(node)
  return node
def visit_Attribute(self, node):
  """Resolves attribute live values and owner types from the parent node.

  If the parent has a `live_val`, the attribute's value is read from it. If
  the parent instead has a `type`, static members are read from that type.
  Either way `parent_type`, `live_val` and an extended `fqn` are annotated.

  Args:
    node: the attribute node.

  Returns:
    The visited node.

  Raises:
    AttributeError: if the parent's live value lacks the attribute.
  """
  self.generic_visit(node)
  parent = node.value
  attr_name = node.attr

  if anno.hasanno(parent, 'live_val'):
    assert anno.hasanno(parent, 'fqn')
    parent_object = anno.getanno(parent, 'live_val')
    if not hasattr(parent_object, attr_name):
      raise AttributeError('%s has no attribute %s' %
                           (parent_object, attr_name))
    anno.setanno(node, 'parent_type', type(parent_object))
    anno.setanno(node, 'live_val', getattr(parent_object, attr_name))
    anno.setanno(node, 'fqn', anno.getanno(parent, 'fqn') + (attr_name,))
    # TODO(mdan): Investigate the role built-in annotations can play here.
  elif anno.hasanno(parent, 'type'):
    parent_type = anno.getanno(parent, 'type')
    # Static members (e.g. methods) are found on the type. Dynamic members
    # (e.g. function attributes) are not; those nodes stay unannotated for
    # downstream consumers to handle.
    if hasattr(parent_type, attr_name):
      anno.setanno(node, 'parent_type', parent_type)
      anno.setanno(node, 'live_val', getattr(parent_type, attr_name))
      anno.setanno(node, 'fqn',
                   anno.getanno(parent, 'type_fqn') + (attr_name,))
  elif isinstance(parent, gast.Name):
    # All nonlocal symbols should be fully resolved.
    assert anno.hasanno(parent, NodeAnno.IS_LOCAL), parent
    # TODO(mdan): Figure out what to do when calling attribute on local object
    # Maybe just leave as-is?
  return node
def visit_ClassDef(self, node):
  """Annotates a class definition with the class object from the namespace.

  Args:
    node: the class definition; its name must be in
      `self.context.namespace`.

  Returns:
    The visited node.
  """
  self.generic_visit(node)
  class_obj = self.context.namespace[node.name]
  anno.setanno(node, 'live_val', class_obj)
  return node
def visit_Attribute(self, node):
  """Attaches the qualified name `<value>.<attr>` to an attribute node.

  The name is only attached when the attribute's value has a qualified name
  itself. Values such as call results (e.g. `f().x`) have none, and reading
  the missing annotation would raise.

  Args:
    node: the attribute node.

  Returns:
    The visited node.
  """
  self.generic_visit(node)
  if anno.hasanno(node.value, anno.Basic.QN):
    anno.setanno(node, anno.Basic.QN,
                 QN(anno.getanno(node.value, anno.Basic.QN), node.attr))
  return node
def visit_Name(self, node):
  """Attaches a qualified name built from the identifier to a name node.

  Args:
    node: the name node.

  Returns:
    The visited node.
  """
  self.generic_visit(node)
  qualified_name = QN(node.id)
  anno.setanno(node, anno.Basic.QN, qualified_name)
  return node
def visit_Expr(self, node):
  """Guards a standalone call statement with control dependencies.

  An expression statement whose value is a call (e.g. `opt.minimize(loss)`)
  is replaced by a
  `with py2tf_utils.control_dependency_on_returns(tf, call):` node. When the
  call's scope uses non-composite symbols, the block re-binds them through
  `py2tf_utils.alias_tensors`. Symbols not modified in the parent scope get
  fresh alias names. The new node is annotated with
  `anno.Basic.INDENT_BLOCK_REMAINDER` = `(body, alias_map)`. Presumably this
  lets the statements that follow be moved into the block; see the reference
  to `_visit_and_reindent` below (TODO confirm). Other expression statements
  are returned unchanged.

  Args:
    node: the expression statement. When it wraps a call, the call must
      carry a `NodeAnno.ARGS_SCOPE` annotation.

  Returns:
    The `with` node replacing a call statement, or `node` otherwise.
  """
  self.generic_visit(node)
  if isinstance(node.value, gast.Call):
    # Patterns of single function calls, like:
    #   opt.minimize(loss)
    # or:
    #   tf.py_func(...)

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
    # NOTE: We can't guard object attributes because they may not be writable.
    guarded_args = tuple(s for s in args_scope.used if not s.is_composite())

    # TODO(mdan): Include all arguments which depended on guarded_args too.
    # For example, the following will still cause a race:
    #   tf.assign(a, a + 1)
    #   b = a + 1
    #   tf.assign(a, a + 1)  # Control deps here should include `b`
    #   c = b + 1
    # Or maybe we should just raise an "unsafe assign" error?

    if guarded_args:
      # The aliases may need new names to avoid incorrectly making them local.
      # TODO(mdan): This is brutal. It will even rename modules - any fix?
      # Only symbols that the parent scope does not modify get a fresh name.
      need_alias = tuple(
          s for s in guarded_args if s not in args_scope.parent.modified)
      aliased_new_names = tuple(
          qual_names.QN(
              self.context.namer.new_symbol(
                  s.ssf(), args_scope.parent.referenced)) for s in need_alias)
      alias_map = dict(zip(need_alias, aliased_new_names))
      # A single guarded symbol is assigned directly; several are assigned
      # as a tuple. Symbols without an alias keep their own name.
      if len(guarded_args) == 1:
        s, = guarded_args
        aliased_guarded_args = alias_map.get(s, s)
      else:
        aliased_guarded_args = gast.Tuple(
            [alias_map.get(s, s).ast() for s in guarded_args], None)

      template = """
        with py2tf_utils.control_dependency_on_returns(tf, call):
          aliased_guarded_args = py2tf_utils.alias_tensors(tf, guarded_args)
      """
      # The last statement produced by the template is the `with` node.
      control_deps_guard = templates.replace(
          template,
          call=node.value,
          aliased_guarded_args=aliased_guarded_args,
          guarded_args=guarded_args)[-1]
    else:
      alias_map = {}

      template = """
        with py2tf_utils.control_dependency_on_returns(tf, call):
          pass
      """
      control_deps_guard = templates.replace(template, call=node.value)[-1]
      # Drop the placeholder `pass`, leaving an empty body.
      control_deps_guard.body = []

    node = control_deps_guard
    anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                 (node.body, alias_map))
  return node