def visit_If(self, node):
    """Rewrite an `if` statement as two branch functions plus a cond call.

    Reads the BODY_SCOPE / ORELSE_SCOPE annotations and the DEFINED_VARS_IN /
    LIVE_VARS_OUT annotations from `node`, renames the symbols reported by
    `_determine_aliased_symbols` independently in each branch, and assembles
    the replacement statements.

    Args:
      node: the `if` node to convert; it must carry the annotations read below.

    Returns:
      The concatenation of the undefined-symbol assignments, the composite
      state functions, both branch definitions, the condition assignment and
      the cond expression.
    """
    true_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    false_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # The aliasing analysis has to run before generic_visit: converting the
    # branch bodies produces nodes without static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        true_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        false_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    # Split the modified symbols: composites are always tracked (so the
    # if_stmt can handle their state itself), basic symbols are returned only
    # when they are live after the conditional.
    outputs = set()
    composite_syms = set()
    for sym in true_scope.modified | false_scope.modified:
        if sym.is_composite():
            composite_syms.add(sym)
        elif sym in live_out:
            outputs.add(sym)

    created_in_body = true_scope.modified & (outputs - defined_in)
    created_in_orelse = false_scope.modified & (outputs - defined_in)

    basic_created_in_body = tuple(
        sym for sym in created_in_body if not sym.is_composite())
    basic_created_in_orelse = tuple(
        sym for sym in created_in_orelse if not sym.is_composite())

    # Symbols created by exactly one branch. Plain Python tolerates this, so
    # they are passed through; other backends may need to special-case them.
    possibly_undefined = (
        set(basic_created_in_body) ^ set(basic_created_in_orelse))

    # Each branch gets its own aliases for the closure variables it needs,
    # because the two branches may write different variables.
    body_originals = tuple(need_alias_in_body)
    orelse_originals = tuple(need_alias_in_orelse)
    body_aliases = tuple(
        self.ctx.namer.new_symbol(sym.ssf(), true_scope.referenced)
        for sym in body_originals)
    orelse_aliases = tuple(
        self.ctx.namer.new_symbol(sym.ssf(), false_scope.referenced)
        for sym in orelse_originals)

    alias_body_map = dict(zip(body_originals, body_aliases))
    alias_orelse_map = dict(zip(orelse_originals, orelse_aliases))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    namer = self.ctx.namer
    cond_var_name = namer.new_symbol('cond', true_scope.referenced)
    body_name = namer.new_symbol('if_true', true_scope.referenced)
    orelse_name = namer.new_symbol('if_false', false_scope.referenced)
    all_referenced = true_scope.referenced | false_scope.referenced
    state_getter_name = namer.new_symbol('get_state', all_referenced)
    state_setter_name = namer.new_symbol('set_state', all_referenced)

    outputs = tuple(outputs)
    composite_syms = tuple(composite_syms)

    if not outputs:
        # Nothing is returned: the cond is called without results, and each
        # branch returns a dummy value so that cond still has a return value.
        cond_results = None
        # TODO(mdan): Replace with None once side_effect_guards is retired.
        returned_from_body = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name),)
        returned_from_orelse = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name),)
    else:
        if len(outputs) == 1:
            cond_results = outputs[0]
        else:
            cond_results = gast.Tuple([sym.ast() for sym in outputs], None)

        returned_from_body = tuple(
            alias_body_map[sym] if sym in need_alias_in_body else sym
            for sym in outputs)
        returned_from_orelse = tuple(
            alias_orelse_map[sym] if sym in need_alias_in_orelse else sym
            for sym in outputs)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=body_originals,
        aliased_new_names=body_aliases,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=orelse_originals,
        aliased_new_names=orelse_aliases,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(
        composite_syms, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in outputs)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_syms)

    cond_expr = self._create_cond_expr(
        cond_results, cond_var_name, body_name, orelse_name,
        state_getter_name, state_setter_name, basic_symbol_names,
        composite_symbol_names)

    return (undefined_assigns + composite_defs + body_def + orelse_def +
            cond_assign + cond_expr)
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
                                 max_return_length, parent_node_of_return):
    """Replace `return_node` and everything after it in `stmt_list`.

    The tail of `stmt_list` starting at the return statement is replaced by
    assignments: a boolean flag assignment (only when the return's parent is
    a `gast.If`) followed by an assignment of the returned value to the
    function's return-value variable. When the return carries fewer values
    than `max_return_length`, the value is padded with freshly named
    "no value" placeholders.

    Args:
      stmt_list: statement list that may contain `return_node`; modified in
        place.
      return_node: the return statement to replace.
      return_name: name of the boolean flag variable for this return.
      max_return_length: largest number of values returned by any return
        statement of the current function; must be >= 0.
      parent_node_of_return: node that directly contains `return_node`.

    Returns:
      False when `return_node` is not in `stmt_list`, True otherwise.
    """
    assert max_return_length >= 0, "Input illegal max_return_length"
    position = index_in_list(stmt_list, return_node)
    if position == -1:
        return False

    replacement = []
    # Here assume that the parent node of return is gast.If
    if isinstance(parent_node_of_return, gast.If):
        # Prepend control flow boolean nodes such as '__return@1 = True'
        flag_source = (
            "{} = paddle.jit.dy2static.create_bool_as_type({}, True)".format(
                return_name,
                ast_to_source_code(parent_node_of_return.test).strip()))
        replacement.append(gast.parse(flag_source).body[0])

    func_node = self.function_def[-1]

    def make_name(identifier, ctx):
        return gast.Name(
            id=identifier, ctx=ctx, annotation=None, type_comment=None)

    def ensure_value_name():
        # The return-value variable is named lazily, on first use.
        if self.return_value_name[func_node] is None:
            self.return_value_name[func_node] = unique_name.generate(
                RETURN_VALUE_PREFIX)

    missing = max_return_length - get_return_size(return_node)
    if missing > 0:
        # Pad with RETURN_NO_VALUE placeholders. max_return_length is >= 1
        # here because the return size is at least 0.
        ensure_value_name()
        placeholders = [
            unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
            for _ in range(missing)
        ]
        self.return_no_value_name[func_node].extend(placeholders)

        if max_return_length == 1:
            padded_value = make_name(placeholders[0], gast.Load())
        else:
            # More than one slot: assign a tuple of the real values (if any)
            # followed by the placeholders.
            elements = []
            returned = return_node.value
            if returned is not None:
                if isinstance(returned, gast.Tuple):
                    elements.extend(returned.elts)
                else:
                    elements.append(returned)
            elements.extend(
                make_name(name, gast.Load()) for name in placeholders)
            padded_value = gast.Tuple(elts=elements, ctx=gast.Load())

        replacement.append(
            gast.Assign(
                targets=[
                    make_name(self.return_value_name[func_node], gast.Store())
                ],
                value=padded_value))
    elif return_node.value is not None:
        # Full-width return: no placeholder is needed.
        ensure_value_name()
        replacement.append(
            gast.Assign(
                targets=[
                    make_name(self.return_value_name[func_node], gast.Store())
                ],
                value=return_node.value))

    stmt_list[position:] = replacement
    return True
def visit_If(self, node):
    """Convert an `if` statement into two branch functions and a cond call.

    The variables returned from the conditional are those defined by either
    branch that are live (or have a live member of their support set) after
    the statement. Variables modified but not created by a branch are aliased
    inside that branch's function.

    Args:
      node: the `if` node; it must carry the BODY_SCOPE / ORELSE_SCOPE
        annotations and the 'live_out' and 'defined_in' annotations.

    Returns:
      The body function definition, the orelse function definition and the
      cond expression, concatenated.

    Raises:
      ValueError: if a live symbol is created by only one of the branches, or
        if a branch would have to return a symbol that it neither aliases,
        creates, nor finds defined before the statement.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
    body_defs = body_scope.created | body_scope.modified
    orelse_defs = orelse_scope.created | orelse_scope.modified
    live = anno.getanno(node, 'live_out')

    # We'll need to check if we're closing over variables that are defined
    # elsewhere in the function
    # NOTE: we can only detect syntactic closure in the scope
    # of the code passed in. If the AutoGraph'd function itself closes
    # over other variables, this analysis won't take that into account.
    defined = anno.getanno(node, 'defined_in')

    # We only need to return variables that are
    # - modified by one or both branches
    # - live (or has a live parent) at the end of the conditional
    modified = []
    for def_ in body_defs | orelse_defs:
        def_with_parents = set((def_,)) | def_.support_set
        if live & def_with_parents:
            modified.append(def_)

    # We need to check if live created variables are balanced
    # in both branches
    created = live & (body_scope.created | orelse_scope.created)

    # The if statement is illegal if there are variables that are created,
    # that are also live, but both branches don't create them.
    if created:
        if created != (body_scope.created & live):
            raise ValueError(
                'The main branch does not create all live symbols that the else '
                'branch does.')
        if created != (orelse_scope.created & live):
            raise ValueError(
                'The else branch does not create all live symbols that the main '
                'branch does.')

    # Alias the closure variables inside the conditional functions
    # to avoid errors caused by the local variables created in the branch
    # functions.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(body_scope.modified - body_scope.created)
    aliased_orelse_orig_names = tuple(
        orelse_scope.modified - orelse_scope.created)
    aliased_body_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(
        zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    if not modified:
        # When the cond would return no value, we leave the cond called without
        # results. That in turn should trigger the side effect guards. The
        # branch functions will return a dummy value that ensures cond
        # actually has some return value as well.
        results = None
    elif len(modified) == 1:
        results = modified[0]
    else:
        results = gast.Tuple([s.ast() for s in modified], None)

    body_name = self.context.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.context.namer.new_symbol('if_false',
                                                orelse_scope.referenced)

    if modified:

        def build_returns(aliased_names, alias_map, scope, branch_label):
            """Builds list of return variables for a branch of a conditional.

            `branch_label` ('true' or 'false') names the branch in the error
            message, so that a failure in the else branch is not reported as
            coming from the true branch.
            """
            # Symbols a branch may return without an alias: those it creates
            # itself and those defined before the conditional.
            returnable = scope.created | defined
            returns = []
            for s in modified:
                if s in aliased_names:
                    returns.append(alias_map[s])
                elif s in returnable:
                    returns.append(s)
                else:
                    raise ValueError(
                        'Attempting to return variable "%s" from the %s branch of '
                        'a conditional, but it was not closed over, or created in '
                        'this branch.' % (str(s), branch_label))
            return tuple(returns)

        body_returns = build_returns(aliased_body_orig_names, alias_body_map,
                                     body_scope, 'true')
        orelse_returns = build_returns(aliased_orelse_orig_names,
                                       alias_orelse_map, orelse_scope, 'false')
    else:
        # Dummy return value shared by both branches.
        body_returns = orelse_returns = templates.replace(
            'tf.ones(())')[0].value

    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=tuple(aliased_body_orig_names),
        aliased_new_names=tuple(aliased_body_new_names),
        body=node_body,
        returns=body_returns)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=tuple(aliased_orelse_orig_names),
        aliased_new_names=tuple(aliased_orelse_new_names),
        body=node_orelse,
        returns=orelse_returns)
    cond_expr = self._create_cond_expr(results, node.test, body_name,
                                       orelse_name)

    return body_def + orelse_def + cond_expr
def sub():
    """Return a load-context tuple node whose elements are `Placeholder(0)`."""
    elements = Placeholder(0)
    context = ast.Load()
    return ast.Tuple(elements, context)
def visit_If(self, node):
    """Outline both branches of a static `if` and dispatch between them.

    An `if` whose test is not in `self.static_expressions` is only visited
    generically. Otherwise each branch is outlined into its own function, the
    two functions are appended to `self.new_functions`, and the statement is
    replaced by a call to the dispatcher plus whatever handling is needed for
    `return` / `break` / `continue` found in the branches.

    Args:
      node: the `if` node.

    Returns:
      The generically visited node for a non-static test; otherwise a single
      statement or a list of statements replacing the `if`.
    """
    if node.test not in self.static_expressions:
        return self.generic_visit(node)

    imported_ids = self.gather(ImportedIds, node)

    left_ids = self.escaping_ids(node, node.body)
    right_ids = self.escaping_ids(node, node.orelse)
    assigned_ids = sorted(left_ids.union(right_ids))

    # Identifiers assigned in only one branch are also treated as imported.
    imported_ids.update(
        [i for i in left_ids if i not in right_ids] +
        [i for i in right_ids if i not in left_ids])
    imported_ids = sorted(imported_ids)

    analyses = (HasReturn, HasBreak, HasContinue)
    fbody = self.make_fake(node.body)
    true_flags = [self.gather(analysis, fbody) for analysis in analyses]
    felse = self.make_fake(node.orelse)
    false_flags = [self.gather(analysis, felse) for analysis in analyses]
    has_return, has_break, has_cont = [
        in_true or in_false
        for in_true, in_false in zip(true_flags, false_flags)
    ]

    self.generic_visit(node)

    func_true = outline(self.true_name(), imported_ids, assigned_ids,
                        node.body, has_return, has_break, has_cont)
    func_false = outline(self.false_name(), imported_ids, assigned_ids,
                         node.orelse, has_return, has_break, has_cont)
    self.new_functions.extend((func_true, func_false))

    actual_call = self.make_dispatcher(node.test, func_true, func_false,
                                       imported_ids)

    # variable modified within the static_if
    expected_return = [
        ast.Name(ident, ast.Store(), None) for ident in assigned_ids
    ]

    self.update = True

    # name for various variables resulting from the static_if
    count = len(self.new_functions)
    status_n = "$status{}".format(count)
    return_n = "$return{}".format(count)
    cont_n = "$cont{}".format(count)

    # No control flow escapes the branches: a plain call is enough.
    if not (has_return or has_break or has_cont):
        if expected_return:
            return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                              actual_call)
        return ast.Expr(actual_call)

    cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                               expected_return, has_cont,
                                               has_break)

    if not has_return:
        # Only break/continue: unpack (status, continuation values).
        fast_return = [
            ast.Name(status_n, ast.Store(), None),
            ast.Name(cont_n, ast.Store(), None),
        ]
        unpack = ast.Assign([ast.Tuple(fast_return, ast.Store())],
                            actual_call)
        return [unpack] + cont_ass

    # A branch may return early: unpack (status, return value, continuation
    # values) and return right away when the status says so.
    cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None), [ast.Eq()],
                       [ast.Num(EARLY_RET)])
    fast_return = [
        ast.Name(status_n, ast.Store(), None),
        ast.Name(return_n, ast.Store(), None),
        ast.Name(cont_n, ast.Store(), None),
    ]
    return [
        ast.Assign([ast.Tuple(fast_return, ast.Store())], actual_call),
        ast.If(cmpr, [ast.Return(ast.Name(return_n, ast.Load(), None))],
               cont_ass),
    ]
def visit_While(self, node):
    """Convert a `while` loop into a call to `ag__.while_stmt`.

    The loop state is the set of symbols the body modifies but does not
    create. The test and body are renamed to use fresh state names and wrapped
    in two functions that are passed to `ag__.while_stmt`.

    Args:
      node: the `while` node; it must carry BODY_SCOPE and COND_SCOPE
        annotations.

    Returns:
      The nodes produced by expanding the template below.

    Raises:
      ValueError: if the loop has no state variables.
    """
    self.generic_visit(node)

    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    all_referenced = body_scope.referenced
    loop_vars = list(body_scope.modified - body_scope.created)

    # Roots of everything the condition uses, except what the body creates.
    cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
    cond_closure = set()
    for used in cond_scope.used:
        for root in used.support_set:
            if root not in body_scope.created:
                cond_closure.add(root)

    if not loop_vars:
        # TODO(mdan): Implement this properly.
        # To complete this statement, we need to check whether any variable
        # created inside the body scope is used before being modified outside
        # the scope. This should be done during activity analysis, and in
        # general should cover the case where variables may not be initialized.
        raise ValueError('cannot convert while loop: no outputs')

    loop_var_names = [
        self.ctx.namer.new_symbol(var.ssf(), all_referenced)
        for var in loop_vars
    ]
    # Only rename the symbols whose fresh name actually differs.
    ssf_map = {}
    for original, fresh in zip(loop_vars, loop_var_names):
        if str(original) != fresh:
            ssf_map[original] = fresh

    if len(loop_vars) == 1:
        state = loop_vars[0]
        state_ssf = loop_var_names[0]
        state_ast_tuple = state
    else:
        state = loop_vars
        state_ssf = loop_var_names
        state_ast_tuple = gast.Tuple([var.ast() for var in loop_vars], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_stmt(
          test_name, body_name, (state,), (extra_deps,))
    """
    test_fn_name = self.ctx.namer.new_symbol('loop_test',
                                             body_scope.referenced)
    body_fn_name = self.ctx.namer.new_symbol('loop_body',
                                             body_scope.referenced)
    extra_deps = tuple(root.ast() for root in cond_closure)

    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=test_fn_name,
        test=test,
        body_name=body_fn_name,
        body=node_body,
        extra_deps=extra_deps,
    )
def visit_DictComp(self, node):
    """Handle a dict comprehension through `visit_AnyComp`.

    Sets `node.elt` to a one-element list holding the `(key, value)` tuple and
    delegates to `visit_AnyComp` with "dict", "__dispatch__" and "update".
    """
    # Quick fix so that the node matches what visit_AnyComp expects;
    # there is room for improvement here.
    key_value = ast.Tuple([node.key, node.value], ast.Load())
    node.elt = ast.List([key_value], ast.Load())
    return self.visit_AnyComp(node, "dict", "__dispatch__", "update")
def visit_ExtSlice(self, node):
    """Convert an ExtSlice into a load-context tuple of its visited dims."""
    visited_dims = self._visit(node.dims)
    as_tuple = gast.Tuple(visited_dims, gast.Load())
    gast.copy_location(as_tuple, node)
    return as_tuple
def _ast_tuple_or_item(self, elts, ctx):
    """Return the only element of `elts`, or a `gast.Tuple` of all of them.

    Args:
      elts: iterable of elements; it is materialised into a list.
      ctx: context passed to `gast.Tuple` when a tuple is built.

    Returns:
      The single element when there is exactly one, otherwise a `gast.Tuple`
      holding every element (including the empty case).
    """
    items = list(elts)
    return items[0] if len(items) == 1 else gast.Tuple(items, ctx)
def visit_FunctionDef(self, node):
    """Rewrite a function so that it ends in a single return statement.

    The function is visited repeatedly while the return analysis reports more
    than one return. Afterwards, if a return-value variable was allocated, a
    final `return` of that variable is appended and its initialisation is
    prepended, followed by the prepending of the boolean return flags and the
    "no value" placeholders recorded for this function.

    Args:
      node: the function definition node; its body is modified in place.

    Returns:
      The same `node`.
    """
    self.function_def.append(node)
    self.return_value_name[node] = None
    self.return_name[node] = []
    self.return_no_value_name[node] = []

    self.pre_analysis = ReturnAnalysisVisitor(node)
    max_return_length = self.pre_analysis.get_func_max_return_length(node)
    while self.pre_analysis.get_func_return_count(node) > 1:
        self.generic_visit(node)
        self.pre_analysis = ReturnAnalysisVisitor(node)

    if max_return_length == 0:
        self.function_def.pop()
        return node

    def make_name(identifier, ctx):
        return gast.Name(
            id=identifier, ctx=ctx, annotation=None, type_comment=None)

    # Append the final return statement and prepend the initialisation of
    # the returned variable.
    value_name = self.return_value_name[node]
    if value_name is not None:
        node.body.append(
            gast.Return(value=make_name(value_name, gast.Load())))

        init_names = [
            unique_name.generate(RETURN_VALUE_INIT_NAME)
            for _ in range(max_return_length)
        ]
        zero_inits = [
            create_fill_constant_node(init_name, 0.0)
            for init_name in init_names
        ]
        if len(init_names) == 1:
            initial_value = make_name(init_names[0], gast.Load())
        else:
            # The return value is initialised as a tuple because control flow
            # requires some inputs or outputs to have the same structure.
            initial_value = gast.Tuple(
                elts=[
                    make_name(init_name, gast.Load())
                    for init_name in init_names
                ],
                ctx=gast.Load())
        value_init = gast.Assign(
            targets=[make_name(value_name, gast.Store())],
            value=initial_value)
        # Resulting order: zero fills, value initialisation, original body.
        node.body[:0] = zero_inits + [value_init]

    # Prepend control flow boolean nodes such as '__return@1 = False'. Each
    # node goes to the very front, so the last name ends up first.
    flag_inits = [
        create_fill_constant_node(name, False)
        for name in self.return_name[node]
    ]
    node.body[:0] = flag_inits[::-1]

    # Prepend no value placeholders, again last name first.
    placeholder_inits = [
        create_fill_constant_node(name, RETURN_NO_VALUE_MAGIC_NUM)
        for name in self.return_no_value_name[node]
    ]
    node.body[:0] = placeholder_inits[::-1]

    self.function_def.pop()
    return node
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
                                 max_return_length):
    """Replace `return_node` and everything after it in `stmt_list`.

    The tail of `stmt_list` starting at the return statement is replaced by
    an assignment setting the flag `return_name` to True, followed by an
    assignment of the returned value to the function's return-value variable.
    When the return carries fewer values than `max_return_length`, the value
    is padded with freshly named "no value" placeholders.

    Args:
      stmt_list: statement list that may contain `return_node`; modified in
        place.
      return_node: the return statement to replace.
      return_name: name of the boolean flag variable for this return.
      max_return_length: largest number of values returned by any return
        statement of the current function; must be >= 0.

    Returns:
      False when `return_node` is not in `stmt_list`, True otherwise.
    """
    assert max_return_length >= 0, "Input illegal max_return_length"
    position = index_in_list(stmt_list, return_node)
    if position == -1:
        return False

    replacement = [create_fill_constant_node(return_name, True)]
    func_node = self.function_def[-1]

    def make_name(identifier, ctx):
        return gast.Name(
            id=identifier, ctx=ctx, annotation=None, type_comment=None)

    def ensure_value_name():
        # The return-value variable is named lazily, on first use.
        if self.return_value_name[func_node] is None:
            self.return_value_name[func_node] = unique_name.generate(
                RETURN_VALUE_PREFIX)

    missing = max_return_length - get_return_size(return_node)
    if missing > 0:
        # Pad with RETURN_NO_VALUE placeholders. max_return_length is >= 1
        # here because the return size is at least 0.
        ensure_value_name()
        placeholders = [
            unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
            for _ in range(missing)
        ]
        self.return_no_value_name[func_node].extend(placeholders)

        if max_return_length == 1:
            padded_value = make_name(placeholders[0], gast.Load())
        else:
            # More than one slot: assign a tuple of the real values (if any)
            # followed by the placeholders.
            elements = []
            returned = return_node.value
            if returned is not None:
                if isinstance(returned, gast.Tuple):
                    elements.extend(returned.elts)
                else:
                    elements.append(returned)
            elements.extend(
                make_name(name, gast.Load()) for name in placeholders)
            padded_value = gast.Tuple(elts=elements, ctx=gast.Load())

        replacement.append(
            gast.Assign(
                targets=[
                    make_name(self.return_value_name[func_node], gast.Store())
                ],
                value=padded_value))
    elif return_node.value is not None:
        # Full-width return: no placeholder is needed.
        ensure_value_name()
        replacement.append(
            gast.Assign(
                targets=[
                    make_name(self.return_value_name[func_node], gast.Store())
                ],
                value=return_node.value))

    stmt_list[position:] = replacement
    return True
def visit_List(self, node):
    """Wrap a list literal in a call to `builtins.pythran.static_list`.

    The visited elements are passed to the call as a single tuple argument.
    """
    # because global lists in pythran are static lists
    static_list = path_to_attr(('builtins', 'pythran', 'static_list'))
    visited_elts = [self.visit(elt) for elt in node.elts]
    packed = ast.Tuple(visited_elts, ast.Load())
    return ast.Call(static_list, [packed], [])
def visit_Call(self, node):
    """Create the primal and adjoint for a function call.

    Parameter unpacking is not allowed in calls, so every argument is passed
    explicitly and a partial can be created for each one. Gradient templates,
    however, may take `*args` (for functions with a variable number of
    arguments) and express their gradient as a tuple; in that case the
    arguments are packed before the template and the partials are unpacked
    after it.

    Args:
      node: the call node, annotated with the called function under 'func'.

    Returns:
      A `(primal, adjoint)` pair: the primal node and a list of adjoint
      statements.

    Raises:
      errors.ReverseNotImplementedError: if the called function is listed in
        `grads.UNIMPLEMENTED_ADJOINTS`.
    """
    # Find the function we are differentiating
    func = anno.getanno(node, 'func')

    if func in non_differentiable.NON_DIFFERENTIABLE:
        return node, []
    if func == tracing.Traceable:
        return self.primal_and_adjoint_for_tracing(node)
    if func in grads.UNIMPLEMENTED_ADJOINTS:
        raise errors.ReverseNotImplementedError(func)

    # Without a gradient template we have to step into the called function
    # and differentiate it.
    if func not in grads.adjoints:
        active_args = tuple(
            index for index, arg in enumerate(node.args)
            if arg.id in self.active_variables)

        # Record the (function, active arguments) pair unless an equivalent
        # one has been recorded already.
        already_counted = any(
            known.__name__ == func.__name__ and
            set(known_args) == set(active_args)
            for known, known_args in self.required)
        if not already_counted:
            self.required.append((func, active_args))

        pri_name = naming.primal_name(func, active_args)
        pri_call = gast.Call(
            func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
            args=[self.substack] + node.args,
            keywords=node.keywords)
        anno.setanno(pri_call, 'pri_call', True)

        dy = create.create_grad(self.target, self.namer)
        dy.ctx = gast.Load()
        # NOTE: dx is not used below; the call is kept as in the original.
        dx = create.create_grad(node.args[0], self.namer)
        dx.ctx = gast.Store()

        adj_name = naming.adjoint_name(func, active_args)
        adj_call = gast.Call(
            func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
            args=[self.substack, dy] + node.args,
            keywords=node.keywords)
        anno.setanno(adj_call, 'adj_call', True)

        adjoint = [
            template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)
        ]
        adjoint.extend(
            template.replace('d[x] = dxs[i]',
                             namer=self.namer,
                             x=node.args[arg_index].id,
                             i=gast.Num(n=position))
            for position, arg_index in enumerate(active_args))
        return pri_call, adjoint

    # We have a template for the gradient that we need to fill in
    template_ = grads.adjoints[func]
    template_code = six.get_function_code(template_)

    # Match the function call to the template; the template's first
    # parameter is dropped from the signature before binding.
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    call_kwargs = dict(
        (keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **call_kwargs)

    # Fill in any missing kwargs with the defaults from the template.
    # Defaults line up with the trailing positional arguments, hence the
    # reversed pairing.
    template_args = quoting.parse_function(template_).body[0].args
    defaults = dict(
        zip(reversed(template_args.args), reversed(template_args.defaults)))
    defaults.update(
        dict(zip(template_args.kwonlyargs, template_args.kw_defaults)))
    for arg, val in defaults.items():
        if arg.id not in bound_args.arguments:
            bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = template_code.co_varnames[0]
    arg_replacements = {output_name: ast_.copy_node(self.target)}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    packing = []
    flags = template_code.co_flags
    uses_varargs = flags & inspect.CO_VARARGS

    if uses_varargs:
        to_pack = node.args[template_code.co_argcount - 1:]
        vararg_name = template_code.co_varnames[-1]
        target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
        value = gast.Tuple(elts=to_pack, ctx=gast.Load())
        packing = [gast.Assign(targets=[target], value=value)]

        # And we fill in the packed tuple into the template
        arg_replacements[template_code.co_varnames[-1]] = target

    adjoint = template.replace(template_, namer=self.namer, **arg_replacements)

    unpacking = []
    if uses_varargs:
        # If the template packs arguments, then we have to unpack the
        # derivatives afterwards
        dto_pack = [
            create.create_temp_grad(arg, self.namer) for arg in to_pack
        ]
        value = create.create_grad(target, self.namer)
        target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
        unpacking = [gast.Assign(targets=[target], value=value)]

    return node, packing + adjoint + unpacking
def visit_FunctionDef(self, node):
    """Split a function definition into its primal and its adjoint.

    Args:
      node: the function definition node, annotated with the function under
        'func'. It is modified in place to become the primal.

    Returns:
      A `(primal, adjoint)` pair of function definition nodes.

    Raises:
      ValueError: if the function has more than one return statement or does
        not end with one.
    """
    # A namer built from this function guarantees that the names we create
    # are unique and do not override existing ones.
    self.namer = naming.Namer.build(node)

    # The function must have exactly one return statement, at the end.
    return_nodes = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
    single_trailing_return = (
        len(return_nodes) <= 1 and isinstance(node.body[-1], gast.Return))
    if not single_trailing_return:
        raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body (everything except the return).
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Mark the first statement of each pass with a comment.
    if body:
        body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
        adjoint_body[0] = comments.add_comment(
            adjoint_body[0], 'Beginning of backward pass')

    # The gradients to return have to be collected before the primal's
    # argument list is changed below.
    wrt_grads = [
        create.create_grad(node.args.args[i], self.namer) for i in self.wrt
    ]
    dx = gast.Tuple(wrt_grads, ctx=gast.Load())

    if self.preserve_result:
        # Save the original output value in the primal body and return it
        # along with the gradients.
        stored_result_node = quoting.quote(self.namer.unique('result'))
        assign_stored_result = template.replace(
            'result=orig_result',
            result=stored_result_node,
            orig_result=return_node.value)
        body.append(assign_stored_result)
        dx.elts.append(stored_result_node)

    for returned in dx.elts:
        returned.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # The cost is the first of potentially multiple return values; the
    # adjoint receives the initial gradient of the cost as an argument.
    cost = node.body[-1].value
    if isinstance(cost, gast.Tuple):
        cost = cost.elts[0]
    dy = gast.Name(
        id=self.namer.grad(cost.id), ctx=gast.Param(), annotation=None)

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    (adjoint,) = template.replace(
        adjoint_template,
        namer=self.namer,
        adjoint_body=adjoint_body,
        return_dx=return_dx)
    adjoint.args.args.extend([self.stack, dy] + node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint
def visit_If(self, node):
    """Convert an `if` statement into two branch functions and a cond call.

    Both branches must create the same symbols. Every symbol modified by
    either branch is returned from both branch functions; symbols that are
    modified but not created are aliased inside the branch functions.

    Args:
      node: the `if` node, carrying BODY_SCOPE and ORELSE_SCOPE annotations.

    Returns:
      The body function definition, the orelse function definition and the
      cond expression, concatenated.

    Raises:
      ValueError: if one branch creates symbols that the other does not.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

    if body_scope.created - orelse_scope.created:
        raise ValueError(
            'The if branch creates new symbols that the else branch does not.')
    if orelse_scope.created - body_scope.created:
        raise ValueError(
            'The else branch creates new symbols that the if branch does not.')

    modified = tuple(body_scope.modified | orelse_scope.modified)
    all_referenced = body_scope.referenced | orelse_scope.referenced

    # Closure variables are aliased inside the conditional functions to avoid
    # errors caused by the local variables created in the branch functions.
    need_alias = (
        (body_scope.modified | orelse_scope.modified) -
        (body_scope.created | orelse_scope.created))
    aliased_orig_names = tuple(need_alias)
    aliased_new_names = tuple(
        self.context.namer.new_symbol(sym.ssf(), all_referenced)
        for sym in aliased_orig_names)
    alias_map = dict(zip(aliased_orig_names, aliased_new_names))
    node_body = ast_util.rename_symbols(node.body, alias_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

    if modified:
        if len(modified) == 1:
            results = modified[0]
        else:
            results = gast.Tuple([sym.ast() for sym in modified], None)
    else:
        # Nothing is returned: the cond is called without results, which in
        # turn should trigger the side effect guards. The branch functions
        # return a dummy value so that cond still has a return value.
        results = None

    body_name = self.context.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.context.namer.new_symbol('if_false', all_referenced)

    if modified:
        branch_returns = tuple(
            alias_map[sym] if sym in aliased_orig_names else sym
            for sym in modified)
    else:
        branch_returns = templates.replace('tf.ones(())')[0].value

    # Both branches share the aliases and the return values.
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=tuple(aliased_orig_names),
        aliased_new_names=tuple(aliased_new_names),
        body=node_body,
        returns=branch_returns)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=tuple(aliased_orig_names),
        aliased_new_names=tuple(aliased_new_names),
        body=node_orelse,
        returns=branch_returns)
    cond_expr = self._create_cond_expr(results, node.test, body_name,
                                       orelse_name)

    return body_def + orelse_def + cond_expr
def visit_If(self, node):
    """Convert an `if` statement into a `tf.cond` call via a code template.

    Both branches must create the same symbols. Symbols that are modified but
    not created by the branches are aliased inside the two generated branch
    functions, and every modified symbol is returned from both of them.

    Args:
      node: the `if` node, carrying 'body_scope' and 'orelse_scope'
        annotations.

    Returns:
      The result of expanding the template below with `templates.replace`.

    Raises:
      ValueError: if one branch creates symbols that the other does not.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, 'body_scope')
    orelse_scope = anno.getanno(node, 'orelse_scope')

    if body_scope.created - orelse_scope.created:
        raise ValueError(
            'The if branch creates new symbols that the else branch does not.'
        )
    if orelse_scope.created - body_scope.created:
        raise ValueError(
            'The else branch creates new symbols that the if branch does not.'
        )

    # This function is a code template: it is handed to templates.replace
    # below and is kept exactly as written.
    def template(  # pylint:disable=missing-docstring
        test,
        body_name,
        body,
        orelse_name,
        orelse,
        aliased,
        aliases,  # pylint:disable=unused-argument
        aliased_results,
        results):  # pylint:disable=unused-argument

        def body_name():  # pylint:disable=function-redefined
            aliases, = aliased,  # pylint:disable=unused-variable
            body  # pylint:disable=pointless-statement
            return (aliased_results,)

        def orelse_name():  # pylint:disable=function-redefined
            aliases, = aliased,  # pylint:disable=unused-variable
            orelse  # pylint:disable=pointless-statement
            return (aliased_results,)

        results = tf.cond(test, body_name, orelse_name)  # pylint:disable=undefined-variable

    all_modified = tuple(body_scope.modified | orelse_scope.modified)
    all_referenced = body_scope.referenced | orelse_scope.referenced

    def as_name(identifier):
        return gast.Name(identifier, None, None)

    # Closure variables are aliased inside the conditional functions to avoid
    # errors caused by the local variables created in the branch functions.
    aliased = tuple(
        (body_scope.modified | orelse_scope.modified) -
        (body_scope.created | orelse_scope.created))
    aliases = tuple(
        self.namer.new_symbol(sym, all_referenced) for sym in aliased)
    alias_map = dict(zip(aliased, aliases))

    # A fresh renamer is used for every statement.
    node_body = [SymbolRenamer(alias_map).visit(stmt) for stmt in node.body]
    node_orelse = [
        SymbolRenamer(alias_map).visit(stmt) for stmt in node.orelse
    ]

    if len(all_modified) == 1:
        results = as_name(all_modified[0])
    else:
        results = gast.Tuple(tuple(as_name(sym) for sym in all_modified),
                             None)

    true_fn_name = as_name(self.namer.new_symbol('if_true', all_referenced))
    false_fn_name = as_name(self.namer.new_symbol('if_false', all_referenced))
    aliased_results = tuple(
        as_name(alias_map[sym] if sym in aliased else sym)
        for sym in all_modified)

    return templates.replace(
        template,
        test=node.test,
        body_name=true_fn_name,
        body=node_body,
        orelse_name=false_fn_name,
        orelse=node_orelse,
        aliased=tuple(as_name(sym) for sym in aliased),
        aliases=tuple(as_name(sym) for sym in aliases),
        aliased_results=aliased_results,
        results=results)
def _consume_args(self):
    """Flush the accumulated arguments into the argspec as one tuple node.

    When `self._arg_accumulator` is non-empty, its contents are wrapped in a
    load-context `gast.Tuple`, appended to `self._argspec`, and the
    accumulator is replaced by a fresh empty list. An empty accumulator
    leaves the argspec untouched.
    """
    pending = self._arg_accumulator
    if not pending:
        return
    packed = gast.Tuple(elts=pending, ctx=gast.Load())
    self._argspec.append(packed)
    self._arg_accumulator = []
def visit_ExtSlice(self, node):
    """Replace an extended-slice node by a load-context tuple of its dims.

    The dims are converted through `self._visit`, and the location of `node`
    is copied onto the new tuple.
    """
    dims_tuple = gast.Tuple(self._visit(node.dims), gast.Load())
    gast.copy_location(dims_tuple, node)
    # The end position is explicitly cleared after the location copy.
    dims_tuple.end_lineno = None
    dims_tuple.end_col_offset = None
    return dims_tuple
def visit_If(self, node):
    """Lower an `if` statement into two branch functions and a cond expression.

    The children of `node` are visited first. The scope, definition and
    liveness annotations on `node` then determine which symbols the
    conditional returns and which symbols each branch must alias.

    Args:
      node: The `if` node being visited.

    Returns:
      The concatenation of the body branch definition, the orelse branch
      definition and the cond expression, as produced by
      `self._create_cond_branch` and `self._create_cond_expr`.

    Raises:
      ValueError: If the two branches do not create the same set of returned
        symbols.
    """
    node = self.generic_visit(node)

    true_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    false_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    def is_output(sym):
        """Tell whether a modified symbol has to be returned from the cond."""
        if sym in live_out:
            return True
        # Special treatment for compound objects: if any of their owner
        # entities are live, then they are outputs as well.
        return bool(sym.is_composite() and live_out & sym.owner_set)

    outputs = {
        sym for sym in true_scope.modified | false_scope.modified
        if is_output(sym)
    }

    alias_in_true = true_scope.modified & defined_in
    alias_in_false = false_scope.modified & defined_in

    # Returned symbols that were not defined before the statement; each
    # branch is expected to create all of them.
    newly_created = outputs - defined_in
    created_in_true = true_scope.modified & newly_created
    created_in_false = false_scope.modified & newly_created

    if created_in_true != created_in_false:
        raise ValueError(
            'if statement may not initialize all variables: the true branch'
            ' creates %s, while the false branch creates %s. Make sure all'
            ' these variables are initialized either in both'
            ' branches or before the if statement.' %
            (self._fmt_symbols(created_in_true),
             self._fmt_symbols(created_in_false)))

    # Closure variables are aliased inside the branch functions so that the
    # functions can access them. The two branches are aliased independently
    # because they may write different variables.
    true_orig_names = tuple(alias_in_true)
    false_orig_names = tuple(alias_in_false)
    true_new_names = tuple(
        self.ctx.namer.new_symbol(sym.ssf(), true_scope.referenced)
        for sym in true_orig_names)
    false_new_names = tuple(
        self.ctx.namer.new_symbol(sym.ssf(), false_scope.referenced)
        for sym in false_orig_names)
    true_alias_map = dict(zip(true_orig_names, true_new_names))
    false_alias_map = dict(zip(false_orig_names, false_new_names))

    true_stmts = ast_util.rename_symbols(node.body, true_alias_map)
    false_stmts = ast_util.rename_symbols(node.orelse, false_alias_map)

    outputs = tuple(outputs)
    if outputs:
        cond_results = (
            outputs[0] if len(outputs) == 1
            else gast.Tuple([sym.ast() for sym in outputs], None))
        # Each branch returns its own alias where it has one, and the
        # original symbol otherwise.
        true_returns = tuple(true_alias_map.get(sym, sym) for sym in outputs)
        false_returns = tuple(
            false_alias_map.get(sym, sym) for sym in outputs)
    else:
        # When the cond would return no value, we leave the cond called
        # without results. That in turn should trigger the side effect
        # guards. The branch functions will return a dummy value that ensures
        # cond actually has some return value as well.
        cond_results = None
        # TODO(mdan): This doesn't belong here; it's specific to the operator.
        true_returns = (templates.replace_as_expression('tf.constant(1)'),)
        false_returns = (templates.replace_as_expression('tf.constant(1)'),)

    true_fn_name = self.ctx.namer.new_symbol('if_true', true_scope.referenced)
    false_fn_name = self.ctx.namer.new_symbol(
        'if_false', false_scope.referenced)

    true_fn_def = self._create_cond_branch(
        true_fn_name,
        aliased_orig_names=true_orig_names,
        aliased_new_names=true_new_names,
        body=true_stmts,
        returns=true_returns)
    false_fn_def = self._create_cond_branch(
        false_fn_name,
        aliased_orig_names=false_orig_names,
        aliased_new_names=false_new_names,
        body=false_stmts,
        returns=false_returns)
    cond_expr = self._create_cond_expr(
        cond_results, node.test, true_fn_name, false_fn_name)

    return true_fn_def + false_fn_def + cond_expr
def visit_If(self, node):
    """Rewrite an `if` statement as a `tf.cond` over two generated functions.

    The children of `node` are visited first. Symbols modified in either
    branch become the results of the cond; modified symbols that neither
    branch creates are aliased inside the branch functions.

    Args:
      node: The `if` node being visited. It must carry the 'body_scope' and
        'orelse_scope' annotations.

    Returns:
      Whatever `templates.replace` returns for the cond template filled in
      with the values computed here.

    Raises:
      ValueError: If one branch creates symbols that the other branch does
        not create.
    """
    self.generic_visit(node)

    true_scope = anno.getanno(node, 'body_scope')
    false_scope = anno.getanno(node, 'orelse_scope')

    only_in_true = true_scope.created - false_scope.created
    if only_in_true:
        raise ValueError(
            'The if branch creates new symbols that the else branch does not.')
    only_in_false = false_scope.created - true_scope.created
    if only_in_false:
        raise ValueError(
            'The else branch creates new symbols that the if branch does not.')

    modified = true_scope.modified | false_scope.modified
    created = true_scope.created | false_scope.created
    all_modified = tuple(modified)
    all_referenced = true_scope.referenced | false_scope.referenced

    # Alias the closure variables inside the conditional functions to avoid
    # errors caused by the local variables created in the branch functions.
    aliased_orig_names = tuple(modified - created)
    aliased_new_names = tuple(
        self.namer.new_symbol(sym, all_referenced)
        for sym in aliased_orig_names)
    alias_map = dict(zip(aliased_orig_names, aliased_new_names))

    def rename(statements):
        """Apply the aliases to each statement with a fresh renamer."""
        return [SymbolRenamer(alias_map).visit(stmt) for stmt in statements]

    true_stmts = rename(node.body)
    false_stmts = rename(node.orelse)

    if len(all_modified) == 1:
        (only_symbol,) = all_modified
        results = gast.Name(only_symbol, None, None)
    else:
        results = gast.Tuple(
            tuple(gast.Name(sym, None, None) for sym in all_modified), None)

    template = """
      def body_name():
        aliased_new_names, = aliased_orig_names,
        body
        return (all_results,)
      def orelse_name():
        aliased_new_names, = aliased_orig_names,
        orelse
        return (all_results,)
      results = tf.cond(test, body_name, orelse_name)
    """
    true_fn_name = self.namer.new_symbol('if_true', all_referenced)
    false_fn_name = self.namer.new_symbol('if_false', all_referenced)
    # The branch functions return the alias for aliased symbols and the
    # original name for all others.
    all_results = tuple(alias_map.get(sym, sym) for sym in all_modified)

    return templates.replace(
        template,
        test=node.test,
        body_name=true_fn_name,
        body=true_stmts,
        orelse_name=false_fn_name,
        orelse=false_stmts,
        aliased_orig_names=aliased_orig_names,
        aliased_new_names=aliased_new_names,
        all_results=all_results,
        results=results)
def totuple(node):
    """Return a new Tuple node that reuses the `elts` and `ctx` of `node`."""
    elements, context = node.elts, node.ctx
    return ast.Tuple(elements, context)
def visit_If(self, node):
    """Outline both branches of a static `if` and dispatch between them.

    The children of `node` are visited first. A node whose test is not in
    `self.static_expressions` is returned as is. Otherwise each branch is
    outlined into a function (recorded in `self.new_functions`) and the
    statement is replaced by a call built with `self.make_dispatcher`.

    Returns:
      `node` itself for a non-static test. Otherwise: a two-statement list
      (assignment plus status check) when either branch contains a return,
      a single assignment when only variables are assigned, or a bare
      expression statement when nothing is assigned.
    """
    self.generic_visit(node)
    if node.test not in self.static_expressions:
        return node

    def analyse(analysis, statements):
        """Run `analysis` on a fake node wrapping `statements`."""
        return self.passmanager.gather(
            analysis, self.make_fake(statements), self.ctx)

    imported_ids = self.passmanager.gather(ImportedIds, node, self.ctx)

    assigned_in_true = set(analyse(IsAssigned, node.body).keys())
    assigned_in_false = set(analyse(IsAssigned, node.orelse).keys())

    # Identifiers assigned in only one branch are added to the imported ones.
    imported_ids.update(assigned_in_true - assigned_in_false)
    imported_ids.update(assigned_in_false - assigned_in_true)
    imported_ids = sorted(imported_ids)
    assigned_ids = sorted(assigned_in_true | assigned_in_false)

    returns_in_true = analyse(HasReturn, node.body)
    returns_in_false = analyse(HasReturn, node.orelse)
    has_return = returns_in_true or returns_in_false

    func_true = outline(self.true_name(), imported_ids, assigned_ids,
                        node.body, has_return)
    func_false = outline(self.false_name(), imported_ids, assigned_ids,
                         node.orelse, has_return)
    self.new_functions.extend((func_true, func_false))

    actual_call = self.make_dispatcher(node.test, func_true, func_false,
                                       imported_ids)

    def load_name(identifier):
        """Create a fresh name node for `identifier` in load context."""
        return ast.Name(identifier, ast.Load(), None)

    expected_return = [load_name(ident) for ident in assigned_ids]

    if not has_return:
        if expected_return:
            return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                              actual_call)
        return ast.Expr(actual_call)

    # The helper names are suffixed with the current number of outlined
    # functions.
    n = len(self.new_functions)
    status_id = "$status{}".format(n)
    return_id = "$return{}".format(n)
    cont_id = "$cont{}".format(n)

    fast_return = [load_name(status_id), load_name(return_id),
                   load_name(cont_id)]

    cont_ass = []
    if expected_return:
        cont_ass.append(
            ast.Assign([ast.Tuple(expected_return, ast.Store())],
                       load_name(cont_id)))

    return [
        ast.Assign([ast.Tuple(fast_return, ast.Store())], actual_call),
        ast.If(load_name(status_id),
               [ast.Return(load_name(return_id))],
               cont_ass),
    ]
def visit_Expr(self, node):
    """Wrap a bare call statement in a control-dependency guard.

    The children of `node` are visited first. Expression statements whose
    value is not a `gast.Call` are returned unchanged. For call statements
    (patterns such as `opt.minimize(loss)` or `tf.py_func(...)`) a `with`
    block built from a template replaces the statement, and the replacement
    is annotated with `anno.Basic.INDENT_BLOCK_REMAINDER`.

    Returns:
      `node` itself, or the generated guard node.
    """
    self.generic_visit(node)
    call = node.value
    if not isinstance(call, gast.Call):
        return node

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(call, NodeAnno.ARGS_SCOPE)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # NOTE: We can't guard object attributes because they may not be
    # writable. In addition, avoid renaming well-known names.
    # TODO(mdan): Move these names into config.
    unguarded_names = (qual_names.QN('self'), qual_names.QN('ag__'))
    guarded_args = tuple(
        sym for sym in live_out
        if not (sym.is_composite() or sym in unguarded_names))

    # TODO(mdan): Include all arguments which depended on guarded_args too.
    # For example, the following will still cause a race:
    #   tf.assign(a, a + 1)
    #   b = a + 1
    #   tf.assign(a, a + 1)  # Control deps here should include `b`
    #   c = b + 1
    # Or maybe we should just raise an "unsafe assign" error?

    if not guarded_args:
        alias_map = {}
        template = """
          with ag__.utils.control_dependency_on_returns(call):
            pass
        """
        guard = templates.replace(template, call=call)[-1]
        guard.body = []
    else:
        # The aliases may need new names to avoid incorrectly making them
        # local.
        # TODO(mdan): This is brutal. It will even rename modules - any fix?
        parent_scope = args_scope.parent
        need_alias = tuple(
            sym for sym in guarded_args if sym not in parent_scope.modified)
        alias_map = {
            sym: qual_names.QN(
                self.ctx.namer.new_symbol(sym.ssf(), parent_scope.referenced))
            for sym in need_alias
        }

        if len(guarded_args) == 1:
            (single,) = guarded_args
            aliased_guarded_args = alias_map.get(single, single)
        else:
            aliased_guarded_args = gast.Tuple(
                [alias_map.get(sym, sym).ast() for sym in guarded_args], None)

        template = """
          with ag__.utils.control_dependency_on_returns(call):
            aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
        """
        guard = templates.replace(
            template,
            call=call,
            aliased_guarded_args=aliased_guarded_args,
            guarded_args=guarded_args)[-1]

    anno.setanno(guard, anno.Basic.INDENT_BLOCK_REMAINDER,
                 (guard.body, alias_map))
    return guard
def visit_Call(self, node):
    """Build the forward-mode (tangent) statements for a function call.

    Args:
      node: The call node being visited. Its 'func' annotation holds the
        function object being called.

    Returns:
      `node` unchanged when `self.target` is not set. When the called
      function has no registered tangent template, a one-element list
      holding a call to the function named by `naming.tangent_name`, with the
      tangents of the active arguments passed as extra keywords. Otherwise
      the statements produced by filling in the registered tangent template.

    Raises:
      errors.ForwardNotImplementedError: If the function is listed in
        `tangents.UNIMPLEMENTED_TANGENTS`.
      NotImplementedError: If the function is `tracing.Traceable`.
      ValueError: If there is no tangent template and the source of the
        function cannot be parsed.
    """
    if not self.target:
        return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
        raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
        raise NotImplementedError(
            'Tracing of %s is not enabled in forward mode' %
            quoting.unquote(node))

    if func not in tangents.tangents:
        try:
            quoting.parse_function(func)
        except Exception:
            # Only ordinary errors are turned into ValueError; a bare
            # `except:` here would also swallow KeyboardInterrupt and
            # SystemExit.
            raise ValueError(
                'No tangent found for %s, and could not get source.' %
                func.__name__)

        # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
        active_args = tuple(i for i, arg in enumerate(node.args)
                            if isinstance(arg, gast.Name))
        # TODO: Stack arguments are currently not considered
        # active, but for forward-mode applied to call trees,
        # they have to be. When we figure out how to update activity
        # analysis to do the right thing, we'll want to add the extra check:
        # `and arg.id in self.active_variables`

        # TODO: Duplicate of code in reverse_ad.
        # Record the (function, active arguments) pair once; pairs are
        # matched by function name and by the set of active indices.
        already_counted = False
        for f, a in self.required:
            if f.__name__ == func.__name__ and set(a) == set(active_args):
                already_counted = True
                break
        if not already_counted:
            self.required.append((func, active_args))

        fn_name = naming.tangent_name(func, active_args)
        orig_args = quoting.parse_function(func).body[0].args
        tangent_keywords = []
        for i in active_args:
            grad_node = create.create_grad(
                node.args[i], self.namer, tangent=True)
            arg_grad_node = create.create_grad(
                orig_args.args[i], self.namer, tangent=True)
            grad_node.ctx = gast.Load()
            tangent_keywords.append(
                gast.keyword(arg=arg_grad_node.id, value=grad_node))
        # Update the original call
        rhs = gast.Call(
            func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
            args=node.args,
            keywords=tangent_keywords + node.keywords)
        # Set self.value to False to trigger whole primal replacement
        self.value = False
        return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template; the first template parameter
    # is dropped from the signature before binding.
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict(
        (keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)

    # Fill in any missing kwargs with the defaults from the template
    args = quoting.parse_function(template_).body[0].args
    # Defaults belong to the trailing positional parameters, so both lists
    # are paired up from the end.
    kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
    kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
    for arg, val in kwargs.items():
        if arg.id not in bound_args.arguments:
            bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
        to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
        vararg_name = six.get_function_code(template_).co_varnames[-1]
        target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
        value = gast.Tuple(elts=to_pack, ctx=gast.Load())

        # And we fill in the packed tuple into the template
        arg_replacements[six.get_function_code(
            template_).co_varnames[-1]] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
        for succ in gast.walk(_node):
            if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
                succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
        # If the template packs arguments, then we have to unpack the
        # derivatives afterwards
        # We also have to update the replacements tuple then
        # NOTE(review): `value` and `target` are not read again in this
        # method; the assignments are kept because the create.* calls go
        # through self.namer, which may be stateful - confirm before removing.
        dto_pack = [
            create.create_temp_grad(arg, self.namer, True) for arg in to_pack
        ]
        value = create.create_grad(target, self.namer, tangent=True)
        target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
        if len(self.metastack):
            anno.setanno(tangent_node[0], 'push', self.metastack.pop())
        else:
            anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
id='__builtin__', ctx=ast.Load(), annotation=None), attr="pythran", ctx=ast.Load()), attr="abssqr", ctx=ast.Load()), args=[Placeholder(0)], keywords=[])), # __builtin__.tuple([X, ..., Z]) => (X, ..., Z) (ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None), attr="tuple", ctx=ast.Load()), args=[ast.List(Placeholder(0), ast.Load())], keywords=[]), lambda: ast.Tuple(Placeholder(0), ast.Load())), # __builtin__.reversed(__builtin__.xrange(X)) => # __builtin__.xrange(X-1, -1, -1) # FIXME : We should do it even when begin/end/step are given (ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None), attr="reversed", ctx=ast.Load()), args=[ ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None), attr=range_name, ctx=ast.Load()),
(ast.BinOp(left=Placeholder(0), op=ast.Mult(), right=Placeholder(0)), lambda: ast.BinOp(left=Placeholder(0), op=ast.Pow(), right=ast.Num(n=2))), # a + "..." + b => "...".join((a, b)) (ast.BinOp(left=ast.BinOp(left=Placeholder(0), op=ast.Add(), right=ast.Str(Placeholder(1))), op=ast.Add(), right=Placeholder(2)), lambda: ast.Call(func=ast.Attribute( ast.Attribute( ast.Name('__builtin__', ast.Load(), None), 'str', ast.Load()), 'join', ast.Load()), args=[ast.Str(Placeholder(1)), ast.Tuple([Placeholder(0), Placeholder(2)], ast.Load())], keywords=[])), ] class PlaceholderReplace(Transformation): """ Helper class to replace the placeholder once value is collected. """ def __init__(self, placeholders): """ Store palceholders value collected. """ self.placeholders = placeholders super(PlaceholderReplace, self).__init__() def visit(self, node): """ Replace the placeholder if it is one or continue. """