def visit_If(self, node):
    """Lower an `if` statement node into a `tfp.IfOp` with two regions.

    The test expression is visited first and its value becomes the condition
    operand. Every symbol listed as modified by either branch (taken from the
    BODY_SCOPE and ORELSE_SCOPE annotations on `node`) becomes one result of
    the op. Each branch is lowered into its own region, terminated by a
    `tfp.ReturnOp` that yields the current value of each of those symbols.
    After both regions are built, the symbols are rebound in the symbol
    table to the corresponding op results.

    Args:
        node: The `if` statement AST node. It must carry the
            `annos.NodeAnno.BODY_SCOPE` and `annos.NodeAnno.ORELSE_SCOPE`
            annotations.

    Returns:
        The op's only result when it has exactly one result, otherwise a
        tuple holding all of its results.
    """
    cond = self.visit(node.test)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    # The union has no defined iteration order, so sort by name: this keeps
    # the order of the op's results reproducible from run to run.
    modified_in_cond = sorted(
        body_scope.modified | orelse_scope.modified, key=str)
    names = [str(var) for var in modified_in_cond]

    result_types = [self.symbol_table.lookup_type(name) for name in names]
    ifop = tfp.IfOp.create(
        self.opbuilder, self.opbuilder.getUnknownLoc(), cond, result_types)

    # Remember the current builder; it is swapped for a per-region builder
    # while each branch is lowered and must be restored even on failure.
    cache_builder = self.opbuilder
    try:
        # Region 0 holds the `if` body, region 1 the `else` body.
        for region_index, stmts in enumerate((node.body, node.orelse)):
            self.opbuilder = tfp.OpBuilder(ifop.getRegion(region_index))
            # A fresh scope keeps values created inside the region from
            # leaking into the enclosing symbol table.
            self.symbol_table.enter_scope()
            try:
                for stmt in stmts:
                    self.visit(stmt)
                retvals = [self.symbol_table.lookup(name) for name in names]
                tfp.ReturnOp.create(
                    self.opbuilder, self.opbuilder.getUnknownLoc(), retvals)
            finally:
                self.symbol_table.exit_scope()
    finally:
        self.opbuilder = cache_builder

    # From here on the modified symbols refer to the results of the op.
    for idx, name in enumerate(names):
        self.symbol_table.insert_symbol(name, ifop.getResult(idx))

    num_results = ifop.getNumResults()
    if num_results == 1:
        return ifop.getResult(0)
    return tuple(ifop.getResult(i) for i in range(num_results))
def visit_FunctionDef(self, node):
    """Create a function in the program for `node` and lower its body.

    Input types come from the argument annotations and the (optional) single
    output type from the return annotation, both converted with
    `self.process_type`. The new function is registered in the current
    scope under `node.name`; its arguments are bound in a nested scope in
    which the body is then visited.

    Args:
        node: The function definition AST node.
    """
    # Remember the current builder; it is replaced while the body is lowered.
    cache_builder = self.opbuilder

    inputs = [self.process_type(arg.annotation) for arg in node.args.args]
    outputs = [self.process_type(node.returns)] if node.returns else []

    currfunc = self.prog.add_function(
        self.ctx.namer.new_symbol(node.name, []),
        self.prog.get_function_type(inputs, outputs))

    # The function itself is visible in the enclosing scope; its arguments
    # and locals live in a new one.
    self.symbol_table.insert_symbol(node.name, currfunc)
    self.symbol_table.enter_scope()
    try:
        # NOTE(review): argument nodes are read through `.id`, so they are
        # presumably name-like nodes rather than `ast.arg` -- confirm.
        for arg, value in zip(node.args.args, currfunc.getArguments()):
            self.symbol_table.insert_symbol(arg.id, value)

        self.opbuilder = tfp.OpBuilder(currfunc.getBody())
        self.visit_block(node.body)
    finally:
        # Leave the scope and restore the builder even if lowering fails, so
        # the visitor is not left pointing inside a half-built function.
        self.symbol_table.exit_scope()
        self.opbuilder = cache_builder
def visit_While(self, node):
    """Lower a `while` statement node into a `tfp.WhileOp`.

    The loop variables are obtained from `self._get_loop_vars` using the
    symbols listed as modified in the BODY_SCOPE annotation. Their current
    values are the op's inputs, and their types are passed as the op's
    types. Region 0 receives the loop test and region 1 the loop body; in
    both, every loop variable is rebound to a fresh block argument. After
    the op is built, the loop variables are rebound to the op's results.

    Args:
        node: The `while` statement AST node. It must carry the
            `annos.NodeAnno.BODY_SCOPE` annotation.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    loop_vars, _, _ = self._get_loop_vars(node, body_scope.modified)
    names = [str(var) for var in loop_vars]

    # `inputs` are the initial values of the loop variables.
    inputs = [self.symbol_table.lookup(name) for name in names]
    types = [value.getType() for value in inputs]
    while_op = tfp.WhileOp.create(
        self.opbuilder, self.opbuilder.getUnknownLoc(), inputs, types)

    def bind_block_args(region_index):
        """Rebind each loop variable to a new argument of the region's block."""
        for name, type_ in zip(names, types):
            self.symbol_table.insert_symbol(
                name,
                while_op.getRegion(region_index).front().addArgument(type_))

    # Remember the current builder; it is swapped for a per-region builder
    # below and must be restored even if lowering fails.
    cache_builder = self.opbuilder
    try:
        # Condition region: yields the value of the loop test.
        self.symbol_table.enter_scope()
        try:
            bind_block_args(0)
            self.opbuilder = tfp.OpBuilder(while_op.getRegion(0))
            tfp.ReturnOp.create(
                self.opbuilder, self.opbuilder.getUnknownLoc(),
                [self.visit(node.test)])
        finally:
            self.symbol_table.exit_scope()

        # Body region: yields the updated value of every loop variable.
        self.symbol_table.enter_scope()
        try:
            bind_block_args(1)
            self.opbuilder = tfp.OpBuilder(while_op.getRegion(1))
            self.visit_block(node.body)
            tfp.ReturnOp.create(
                self.opbuilder, self.opbuilder.getUnknownLoc(),
                [self.symbol_table.lookup(name) for name in names])
        finally:
            self.symbol_table.exit_scope()
    finally:
        self.opbuilder = cache_builder

    # From here on the loop variables refer to the results of the op.
    for idx, name in enumerate(names):
        self.symbol_table.insert_symbol(name, while_op.getResult(idx))