Esempio n. 1
0
    def visit_If(self, node):
        """Lower a Python ``if`` statement into a ``tfp.IfOp``.

        Every variable modified in either branch (union of the BODY_SCOPE and
        ORELSE_SCOPE ``modified`` sets) becomes a result of the op. Each branch
        is emitted into its own region and terminated by a ``tfp.ReturnOp``
        that yields the branch-local values of those variables. Afterwards the
        variables are rebound in the enclosing symbol table to the op results.

        The result types are read from the symbol table before the op is
        built, so presumably every modified variable must already be defined
        before the ``if`` (TODO confirm against ``symbol_table.lookup_type``).

        Args:
          node: the ``If`` AST node; it must carry ``NodeAnno.BODY_SCOPE`` and
            ``NodeAnno.ORELSE_SCOPE`` annotations.

        Returns:
          The single op result when the op has exactly one result, otherwise a
          tuple of all results (empty when no variable is modified).
        """
        cond = self.visit(node.test)

        # Create ifop whose results are the variables modified in either branch.
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
        modified_in_cond = list(body_scope.modified | orelse_scope.modified)
        outputs = [
            self.symbol_table.lookup_type(str(var)) for var in modified_in_cond
        ]
        ifop = tfp.IfOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
                               cond, outputs)

        # Cache the builder; the region helper repoints self.opbuilder.
        cache_builder = self.opbuilder

        # Region 0 holds the "then" branch, region 1 the "else" branch.
        self._emit_if_region(ifop.getRegion(0), node.body, modified_in_cond)
        self._emit_if_region(ifop.getRegion(1), node.orelse, modified_in_cond)

        # Reset builder and enter return values in symbol table.
        self.opbuilder = cache_builder
        for idx, var in enumerate(modified_in_cond):
            self.symbol_table.insert_symbol(str(var), ifop.getResult(idx))

        if ifop.getNumResults() == 1:
            return ifop.getResult(0)

        return tuple(ifop.getResult(i) for i in range(ifop.getNumResults()))

    def _emit_if_region(self, region, stmts, modified_in_cond):
        """Emit one branch of an ``IfOp`` into ``region``.

        Points ``self.opbuilder`` at ``region`` (the caller restores it),
        visits ``stmts`` inside a fresh symbol-table scope, and terminates the
        region with a ``tfp.ReturnOp`` yielding the branch-local value of each
        variable in ``modified_in_cond``, in that order.

        Args:
          region: the ``IfOp`` region to populate.
          stmts: the list of statement nodes making up the branch.
          modified_in_cond: variables whose values the region must yield.
        """
        self.opbuilder = tfp.OpBuilder(region)
        # Enter scope to avoid values generated inside the region to come in
        # symbol table.
        self.symbol_table.enter_scope()
        for stmt in stmts:
            self.visit(stmt)
        retvals = [
            self.symbol_table.lookup(str(varname))
            for varname in modified_in_cond
        ]
        tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
                            retvals)
        self.symbol_table.exit_scope()
Esempio n. 2
0
  def visit_FunctionDef(self, node):
    """Lower a Python function definition into a function of ``self.prog``.

    Argument and return annotations are converted with ``process_type`` to
    build the function type. The new function is recorded in the symbol table
    under its source name. Its parameters are then bound to the function's
    arguments in a fresh scope, and the body is generated with a builder
    pointing at the function body. The previous builder is restored at the end.

    Args:
      node: the ``FunctionDef`` AST node.
    """
    outer_builder = self.opbuilder

    arg_types = [self.process_type(arg.annotation) for arg in node.args.args]
    result_types = [self.process_type(node.returns)] if node.returns else []

    func_name = self.ctx.namer.new_symbol(node.name, [])
    func_type = self.prog.get_function_type(arg_types, result_types)
    func = self.prog.add_function(func_name, func_type)

    # Recorded before the body is visited, so references to the function's
    # own name inside its body can presumably resolve (e.g. recursion).
    self.symbol_table.insert_symbol(node.name, func)
    self.symbol_table.enter_scope()
    for param, func_arg in zip(node.args.args, func.getArguments()):
      self.symbol_table.insert_symbol(param.id, func_arg)
    self.opbuilder = tfp.OpBuilder(func.getBody())

    self.visit_block(node.body)

    self.symbol_table.exit_scope()
    self.opbuilder = outer_builder
Esempio n. 3
0
    def visit_While(self, node):
        """Lower a Python ``while`` loop into a ``tfp.WhileOp``.

        The loop variables come from ``self._get_loop_vars`` applied to the
        body's modified set; their current symbol-table values are the op's
        initial operands. Region 0 (the condition) and region 1 (the body) each
        get one block argument per loop variable, bound to that variable's name
        in a fresh scope. The condition region yields the value of
        ``node.test``. The body region yields the updated loop variables. After
        the op is built, the loop variables are rebound to the op results.

        Args:
          node: the ``While`` AST node; it must carry a
            ``NodeAnno.BODY_SCOPE`` annotation.
        """
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        loop_vars, _, _ = self._get_loop_vars(node, body_scope.modified)
        init_values = [self.symbol_table.lookup(str(var)) for var in loop_vars]
        init_types = [value.getType() for value in init_values]
        while_op = tfp.WhileOp.create(self.opbuilder,
                                      self.opbuilder.getUnknownLoc(),
                                      init_values, init_types)

        outer_builder = self.opbuilder

        def open_region(region_index):
            # Open a scope in which every loop variable names a fresh block
            # argument of the given region, then build into that region.
            self.symbol_table.enter_scope()
            for var, var_type in zip(loop_vars, init_types):
                self.symbol_table.insert_symbol(
                    str(var),
                    while_op.getRegion(region_index).front().addArgument(
                        var_type))
            self.opbuilder = tfp.OpBuilder(while_op.getRegion(region_index))

        # Condition region: yields the loop predicate.
        open_region(0)
        builder = self.opbuilder
        location = builder.getUnknownLoc()
        tfp.ReturnOp.create(builder, location, [self.visit(node.test)])
        self.symbol_table.exit_scope()

        # Body region: yields the loop variables as updated by the body.
        open_region(1)
        self.visit_block(node.body)
        builder = self.opbuilder
        location = builder.getUnknownLoc()
        tfp.ReturnOp.create(
            builder, location,
            [self.symbol_table.lookup(str(var)) for var in loop_vars])
        self.symbol_table.exit_scope()

        # After the loop, each loop variable refers to the matching op result.
        for position, var in enumerate(loop_vars):
            self.symbol_table.insert_symbol(str(var),
                                            while_op.getResult(position))

        self.opbuilder = outer_builder