def visit_With(self, node):
    """Deal with the special ``with insert_grad_of(x) as dx:`` statement.

    An ``insert_grad_of`` block contributes no primal statements. Its body is
    returned as adjoint code, after each ``as`` target name (``dx``) has been
    replaced with the gradient expression that ``create.create_grad`` builds
    for the variable passed to ``insert_grad_of`` (``x``). Any other ``with``
    statement is returned unchanged as primal code.

    Args:
        node: The ``With`` node being visited.

    Returns:
        A ``(primal, adjoint)`` pair: ``([], adjoint_statements)`` for an
        ``insert_grad_of`` block, otherwise ``(node, [])``.

    Raises:
        ValueError: If an ``insert_grad_of`` item's first argument, or its
            ``as`` target, is not a plain variable name.
    """
    if not ast_.is_insert_grad_of_statement(node):
        # Ordinary ``with`` blocks are passed through as primal code only.
        return node, []

    primal = []
    adjoint = node.body
    if isinstance(adjoint[0], gast.With):
        # A nested ``with`` as the first statement is visited recursively, and
        # its adjoint replaces this block's adjoint.
        # NOTE(review): statements after the nested ``with`` are not part of
        # the returned adjoint, and a nested non-insert_grad_of ``with`` yields
        # an empty adjoint. Confirm this is intended.
        _, adjoint = self.visit(adjoint[0])
    node.body[0] = comments.add_comment(node.body[0], 'Inserted code')

    # Rename the gradients: map each ``as`` target name to the gradient
    # expression for the variable given to insert_grad_of.
    replacements = {}
    for item in node.items:
        var = item.context_expr.args[0]
        target = item.optional_vars
        if not (isinstance(var, gast.Name) and isinstance(target, gast.Name)):
            raise ValueError(
                'insert_grad_of expects a variable name as its argument and a '
                'plain name as its "as" target, e.g. '
                '"with insert_grad_of(x) as dx:"')
        replacements[target.id] = create.create_grad(var, self.namer)
    # Applied to ``node``. In the non-nested case ``adjoint`` is the same list
    # object as ``node.body``, so the renaming is presumably visible in the
    # returned adjoint. This assumes ReplaceTransformer mutates in place; verify.
    template.ReplaceTransformer(replacements).visit(node)
    return primal, adjoint
def visit_With(self, node):
    """Strip ``with insert_grad_of(...)`` blocks; keep every other ``with``.

    Presumably this relies on the ``ast.NodeTransformer`` convention that a
    visitor returning ``None`` deletes the node. The enclosing class is not
    visible here, so confirm.

    Args:
        node: The ``With`` node being visited.

    Returns:
        ``None`` when ``node`` is an ``insert_grad_of`` statement, otherwise
        ``node`` itself.
    """
    # NOTE(review): a kept ``with`` is returned without visiting its children,
    # so an insert_grad_of block nested inside it is not stripped here.
    is_grad_block = ast_.is_insert_grad_of_statement(node)
    return None if is_grad_block else node