def visit_Expr(self, node):
    """Remove expression statements known to be pure.

    When ``node`` belongs to ``self.pure_expressions`` and its value is not a
    ``yield`` expression, ``self.update`` is set to True and the statement is
    replaced by ``ast.Pass()``. Otherwise the children are visited and
    ``node`` itself is returned.
    """
    removable = (node in self.pure_expressions
                 and not isinstance(node.value, ast.Yield))
    if not removable:
        self.generic_visit(node)
        return node
    self.update = True
    return ast.Pass()
Beispiel #2
0
    def visit_While(self, node):
        """Lower a ``while`` loop to ``overload.while_stmt`` when available.

        The children are visited first. If ``self.overload.module`` has no
        ``while_stmt`` attribute, the visited loop is returned unchanged.
        Otherwise the test, body and else branch become freshly named local
        functions and are passed to ``while_stmt``, together with the symbols
        modified in the body or else scopes.
        """
        body_scope = anno.getanno(node, anno.Static.BODY_SCOPE)
        orelse_scope = anno.getanno(node, anno.Static.ORELSE_SCOPE)
        # Symbols written by either branch, read from the annotations of the
        # node before its children are visited.
        loop_writes = body_scope.modified | orelse_scope.modified

        node = self.generic_visit(node)
        if not hasattr(self.overload.module, 'while_stmt'):
            return node

        overload_name = self.overload.symbol_name
        # Names are requested in the same order as before: test, body, orelse.
        test_fn_name = self.ctx.namer.new_symbol('while_test', set())
        body_fn_name = self.ctx.namer.new_symbol('while_body', set())
        orelse_fn_name = self.ctx.namer.new_symbol('while_orelse', set())

        template = """
      def test_name():
        return test
      def body_name():
        body
      def orelse_name():
        orelse
      overload.while_stmt(test_name, body_name, orelse_name, (local_writes,))
    """
        return templates.replace(
            template,
            overload=overload_name,
            test_name=test_fn_name,
            test=node.test,
            body_name=body_fn_name,
            body=node.body,
            orelse_name=orelse_fn_name,
            orelse=node.orelse or gast.Pass(),
            local_writes=tuple(loop_writes))
    def visit_If(self, node):
        """Simplify an ``if`` statement after visiting its children.

        * If the test is a literal accepted by ``ast.literal_eval`` and the
          node carries no ``OMPDirective`` metadata, the statement is replaced
          by the statements of the branch the literal selects.
        * If the body holds nothing but ``pass`` while the ``else`` branch has
          other statements, the test is negated and the branches are swapped.
        * If neither branch holds anything but ``pass``, the statement becomes
          ``pass`` when the test is in ``self.pure_expressions``, and an
          expression statement evaluating the test otherwise.

        ``self.update`` is set to True whenever the node is rewritten.
        """
        self.generic_visit(node)

        try:
            test_value = ast.literal_eval(node.test)
        except (ValueError, TypeError):
            # Not a constant expression. literal_eval raises TypeError (not
            # ValueError) for literal displays it cannot build, e.g. a set
            # with an unhashable element such as ``{[]}``.
            pass
        else:
            # Statements carrying OpenMP directives are not folded here.
            if not metadata.get(node, OMPDirective):
                self.update = True
                return node.body if test_value else node.orelse

        have_body = any(not isinstance(x, ast.Pass) for x in node.body)
        have_else = any(not isinstance(x, ast.Pass) for x in node.orelse)
        # If the "body" is empty but "else content" is useful, switch branches
        # and remove else content
        if not have_body and have_else:
            test = ast.UnaryOp(op=ast.Not(), operand=node.test)
            self.update = True
            return ast.If(test=test, body=node.orelse, orelse=list())
        # if neither "if" and "else" are useful, keep test if it is not pure
        if not have_body:
            self.update = True
            if node.test in self.pure_expressions:
                return ast.Pass()
            node = ast.Expr(value=node.test)
            self.generic_visit(node)
        return node
Beispiel #4
0
    def attach_data(self, node):
        '''Attach pending OpenMP directives to ``node`` and visit it.

        Generic method called for visit_XXXX() with XXXX in the
        GatherOMPData.statements list (the dispatch itself is not visible
        here).

        Steps:
          1. If ``self.current`` is non-empty, each entry is wrapped in an
             ``OMPDirective`` and added to ``node``'s metadata, then
             ``self.current`` is reset to an empty list.
          2. For each field named in ``GatherOMPData.statement_lists`` whose
             last statement is an ``ast.Expr`` accepted by
             ``self.isompdirective``, an ``ast.Pass()`` is appended.
          3. The children of ``node`` are visited.
          4. If ``node`` has no field named in
             ``GatherOMPData.statement_lists``, carries ``OMPDirective``
             metadata, and the concatenated directive text contains
             'parallel', 'task' or 'section', the directives are moved from
             ``node`` onto a new ``if 1:`` statement wrapping ``node``.

        Returns:
          ``node``, or the wrapping ``ast.If`` built in step 4.
        '''
        # Attach every pending directive to this statement, then reset the
        # pending list.
        if self.current:
            for curr in self.current:
                md = OMPDirective(curr)
                metadata.add(node, md)
            self.current = list()
        # add a Pass to hold some directives
        # (presumably so that a directive ending a statement list has a
        # following statement to be attached to)
        for field_name, field in ast.iter_fields(node):
            if field_name in GatherOMPData.statement_lists:
                if(field and
                   isinstance(field[-1], ast.Expr) and
                   self.isompdirective(field[-1].value)):
                    field.append(ast.Pass())
        self.generic_visit(node)

        # add an If to hold scoping OpenMP directives
        directives = metadata.get(node, OMPDirective)
        field_names = {n for n, _ in ast.iter_fields(node)}
        # True when none of node's fields is a statement list.
        has_no_scope = field_names.isdisjoint(GatherOMPData.statement_lists)
        if directives and has_no_scope:
            # some directives create a scope, but the holding stmt may not
            # artificially create one here if needed
            # NOTE(review): this is a plain substring test on the concatenated
            # directive text, so e.g. 'taskwait' or 'sections' also match.
            sdirective = ''.join(d.s for d in directives)
            scoping = ('parallel', 'task', 'section')
            if any(s in sdirective for s in scoping):
                # Move the directives from node onto an ``if 1:`` wrapper.
                metadata.clear(node, OMPDirective)
                node = ast.If(ast.Num(1), [node], [])
                for directive in directives:
                    metadata.add(node, directive)

        return node
Beispiel #5
0
    def visit_If(self, node):
        """Lower an ``if`` statement to a call to ``ag__.if_stmt``.

        After the children are visited, the branches become local functions
        (preceded by ``nonlocal`` declarations for the conditional variables),
        state getter/setter functions are generated, and the statement turns
        into ``ag__.if_stmt(test, body, orelse, get_state, set_state,
        (symbol_names,), nouts)``. A missing ``else`` branch becomes ``pass``.

        Returns:
          The statements produced by ``templates.replace``. The origin
          information of ``node`` is copied onto the last of them.
        """
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)

        bound_in_branches = body_scope.bound | orelse_scope.bound
        cond_vars, undefined, nouts = self._get_block_vars(node,
                                                           bound_in_branches)
        undefined_assigns = self._create_undefined_assigns(undefined)
        nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)

        # Names referenced in either branch are passed to new_symbol as
        # reserved names.
        reserved = body_scope.referenced | orelse_scope.referenced
        state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
        state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
        state_functions = self._create_state_functions(
            cond_vars, nonlocal_declarations, state_getter_name,
            state_setter_name)

        # Branch function names are requested after the state function names,
        # body first, as before.
        body_fn_name = self.ctx.namer.new_symbol('if_body', reserved)
        orelse_fn_name = self.ctx.namer.new_symbol('else_body', reserved)
        symbol_names = tuple(
            gast.Constant(str(s), kind=None) for s in cond_vars)

        template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def orelse_name():
        nonlocal_declarations
        orelse
      undefined_assigns
      ag__.if_stmt(
        test,
        body_name,
        orelse_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        nouts)
    """
        new_nodes = templates.replace(
            template,
            body=node.body,
            body_name=body_fn_name,
            orelse=node.orelse or [gast.Pass()],
            orelse_name=orelse_fn_name,
            nonlocal_declarations=nonlocal_declarations,
            nouts=gast.Constant(nouts, kind=None),
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            symbol_names=symbol_names,
            test=node.test,
            undefined_assigns=undefined_assigns)
        origin_info.copy_origin(node, new_nodes[-1])
        return new_nodes
Beispiel #6
0
 def visit_Assign(self, node):
     """Capture the value assigned to ``self.D`` and drop that target.

     The children are visited first. If ``self.D`` is among the assignment
     targets, the assigned value is stored in ``self.capture``. When ``self.D``
     was the only target the statement is replaced by ``ast.Pass()``;
     otherwise ``self.D`` is removed from the targets and ``node`` is kept.
     Nodes that do not assign ``self.D`` are returned unchanged.
     """
     self.generic_visit(node)
     if self.D not in node.targets:
         return node
     self.capture = node.value
     if len(node.targets) > 1:
         node.targets.remove(self.D)
         return node
     return ast.Pass()
Beispiel #7
0
 def _track_and_visit_loop(self, node):
     """Visit a loop node inside a fresh ``_LoopScope`` frame.

     A frame is entered on ``self.state[_LoopScope]`` and tagged with ``node``
     before the children are visited, and exited afterwards. Returns the
     visited node. If its body ended up empty, it is set to ``[gast.Pass()]``.
     """
     self.state[_LoopScope].enter()
     self.state[_LoopScope].ast_node = node
     visited = self.generic_visit(node)
     # Edge case: a loop whose only statement was a directive would be left
     # with an empty body, which is not valid Python.
     if not visited.body:
         visited.body = [gast.Pass()]
     self.state[_LoopScope].exit()
     return visited
Beispiel #8
0
 def visit_Assign(self, node):
     """Remove the targets of ``node`` that are scheduled for pruning.

     ``self.nodes`` maps an assignment node to the targets to remove from it.
     If ``node`` is listed, the remaining targets are kept. When none remain,
     the statement is replaced by ``ast.Pass()``. Unlisted nodes are returned
     unchanged. Children are not visited.
     """
     if node not in self.nodes:
         return node
     doomed = self.nodes[node]
     node.targets = [tgt for tgt in node.targets if tgt not in doomed]
     return node if node.targets else ast.Pass()
 def visit_Assign(self, node):
     """Drop unused assignment targets and, if none remain, the assignment.

     Targets for which ``self.used_target`` is false are removed from
     ``node.targets``. If some target survives, ``node`` is returned.
     Otherwise ``self.update`` is set to True and the statement becomes
     ``ast.Pass()`` when the value is in ``self.pure_expressions``, or an
     expression statement evaluating the value otherwise.
     """
     node.targets = [t for t in node.targets if self.used_target(t)]
     if not node.targets:
         self.update = True
         if node.value in self.pure_expressions:
             return ast.Pass()
         return ast.Expr(value=node.value)
     return node
Beispiel #10
0
    def visit_Compare(self, node):
        """Outline a chained comparison into a generated helper function.

        After the children are visited, a comparison with a single operator
        is returned unchanged. A chain such as ``a < b < c`` is replaced by a
        call ``<self.prefix>_compare<N>(ids...)``. Here ``ids`` are the
        identifiers gathered by ``ImportedIds``, sorted, and ``N`` is the
        number of helpers generated so far. The helper's ``FunctionDef`` is
        appended to ``self.compare_functions``. Its body evaluates each
        operand only once it is needed, returns ``__builtin__.False`` at the
        first failing link, and returns ``__builtin__.True`` otherwise.
        """
        node = self.generic_visit(node)
        if len(node.ops) <= 1:
            return node

        captured_ids = sorted(
            self.passmanager.gather(ImportedIds, node, self.ctx))
        forged_name = "{0}_compare{1}".format(self.prefix,
                                              len(self.compare_functions))

        # Call site: every captured identifier is passed by name.
        call = ast.Call(
            ast.Name(forged_name, ast.Load(), None),
            [ast.Name(ident, ast.Load(), None) for ident in captured_ids],
            [])

        # Helper signature: one positional parameter per captured identifier.
        params = ast.arguments(
            [ast.Name(ident, ast.Param(), None) for ident in captured_ids],
            None, [], [], None, [])

        statements = []

        def operand_holder(expr, slot):
            # Trivially copied operands are reused directly. Any other operand
            # is evaluated once into the temporary ``$<slot>``; '$' cannot
            # appear in a Python identifier, so it cannot clash with user names.
            if is_trivially_copied(expr):
                return expr
            temp_name = '${}'.format(slot)
            statements.append(
                ast.Assign([ast.Name(temp_name, ast.Store(), None)], expr))
            return ast.Name(temp_name, ast.Load(), None)

        left = operand_holder(node.left, 0)
        for slot, operand in enumerate(node.comparators, start=1):
            right = operand_holder(operand, slot)
            link = ast.Compare(left, [node.ops[slot - 1]], [right])
            # Emits: ``if link: pass`` / ``else: return False``
            statements.append(
                ast.If(
                    link, [ast.Pass()],
                    [ast.Return(path_to_attr(('__builtin__', 'False')))]))
            left = right

        statements.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

        helper = ast.FunctionDef(forged_name, params, statements, [], None)
        self.compare_functions.append(helper)
        return call
Beispiel #11
0
    def visit_For(self, node):
        """Lower a ``for`` loop to ``overload.for_stmt`` when available.

        The children are visited first. If ``self.overload.module`` has no
        ``for_stmt`` attribute, the visited loop is returned unchanged.
        Otherwise an initialization is emitted for each loop target. The body
        and else branch become freshly named local functions, and the loop
        turns into a call to ``for_stmt``. That call also receives the symbols
        modified in the body or else scopes.

        Raises:
          ValueError: if the loop target is not a ``gast.Tuple``,
            ``gast.List`` or ``gast.Name``.
        """
        body_scope = anno.getanno(node, anno.Static.BODY_SCOPE)
        orelse_scope = anno.getanno(node, anno.Static.ORELSE_SCOPE)
        # Symbols written by either branch, read from the annotations of the
        # node before its children are visited.
        loop_writes = body_scope.modified | orelse_scope.modified

        node = self.generic_visit(node)
        if not hasattr(self.overload.module, 'for_stmt'):
            return node

        # TODO(jmd1011): Handle extra_test

        loop_target = node.target
        if isinstance(loop_target, (gast.Tuple, gast.List)):
            targets = list(loop_target.elts)
        elif isinstance(loop_target, gast.Name):
            targets = [loop_target]
        else:
            raise ValueError(
                'For target must be gast.Tuple, gast.List, or gast.Name, got {}.'
                .format(type(node.target)))

        target_inits = [self._make_target_init(t, self.overload)
                        for t in targets]
        # Names are requested after the target inits, body first, as before.
        body_fn_name = self.ctx.namer.new_symbol('for_body', set())
        orelse_fn_name = self.ctx.namer.new_symbol('for_orelse', set())

        template = """
      target_inits
      def body_name():
        body
      def orelse_name():
        orelse
      overload.for_stmt(target, iter_, body_name, orelse_name, (local_writes,))
    """
        return templates.replace(
            template,
            target_inits=target_inits,
            target=node.target,
            body_name=body_fn_name,
            body=node.body,
            orelse_name=orelse_fn_name,
            orelse=node.orelse or gast.Pass(),
            overload=self.overload.symbol_name,
            iter_=node.iter,
            local_writes=tuple(loop_writes))
    def _guard_if_present(self, block, var_name):
        """Prevents the block from executing if var_name is set.

        Returns the result of ``templates.replace`` for
        ``if not var_name: block``. An empty (falsy) ``block`` is replaced by
        ``[gast.Pass()]``, so the guard is emitted in every case.
        """
        # Keep the guard, and therefore the use of the break variable, even
        # when no statement depends on it yet. The break may become useful in
        # later transformation stages; omitting the guard broke the
        # break_in_inner_loop test.
        guarded_block = block or [gast.Pass()]
        template = """
        if not var_name:
          block
      """
        return templates.replace(template, var_name=var_name,
                                 block=guarded_block)
Beispiel #13
0
 def visit_Try(self, node):
     """Rewrite a ``try`` statement so that it has no ``else``/``finally``.

     * ``else`` clause: appended to the end of the ``try`` body, inside a
       nested ``try`` whose bare ``except`` executes ``pass``.
     * ``finally`` clause: the statement becomes
       ``try: <body>; <finally>`` / ``except: <finally>; raise``. If the
       original statement also had ``except`` handlers, they are kept on an
       inner ``try`` around the body. An exception they handle then completes
       normally instead of reaching the re-raising handler.

     ``self.update`` is set to True whenever the node is rewritten. Child
     statements are not visited by this method.
     """
     if node.orelse:
         # NOTE(review): any exception raised by the former ``else`` clause
         # is discarded by the bare ``except: pass``; confirm this
         # approximation is intended.
         node.body.append(
             ast.Try(node.orelse,
                     [ast.ExceptHandler(None, None, [ast.Pass()])], [], []))
         node.orelse = []
         self.update = True
     if node.finalbody:
         import copy

         protected = node.body
         if node.handlers:
             # Keep the user handlers on an inner try, so that a handled
             # exception is not re-raised by the catch-all handler below.
             protected = [ast.Try(protected, node.handlers, [], [])]
         # Normal path: run a deep copy of the final block after the body,
         # so the two paths do not share the same statement nodes.
         protected = protected + copy.deepcopy(node.finalbody)
         # Exceptional path: run the final block, then re-raise.
         on_error = node.finalbody + [ast.Raise(None, None)]
         self.update = True
         node = ast.Try(protected,
                        [ast.ExceptHandler(None, None, on_error)], [], [])
     return node
Beispiel #14
0
 def visit_If(self, node):
   """Replace every ``if`` statement by a fresh ``pass`` statement.

   The ``if`` node and its branches are discarded without being visited.
   """
   replacement = gast.Pass()
   return replacement
Beispiel #15
0
 def visit_If(self, node):
     """Visit an ``if`` statement, then append ``pass`` to both branches.

     ``node`` is mutated in place and returned. Afterwards both ``node.body``
     and ``node.orelse`` end with a ``gast.Pass()``.
     """
     self.generic_visit(node)
     for branch in (node.body, node.orelse):
         branch.append(gast.Pass())
     return node
Beispiel #16
0
  def test_buildable(self, template):
    """Test that each template can be built when given acceptable arguments.

    For the given ``template`` this checks that:
      * it can fill a hole of its own ``fills_type``;
      * building it with dummy hole values yields the kind of value that
        ``fills_type`` promises (non-empty statement list, statement list,
        single statement, or expression);
      * ``template.required_cost`` equals the number of AST nodes built, and
        ``filler.cost`` equals that count minus ``ALL_COSTS`` of each hole;
      * refilling with the same seed 20 times reproduces the same AST
        (compared with ``gast.dump``).
    """
    rng = np.random.RandomState(1234)

    # Construct a hole that this template can always fill.
    hole = top_down_refinement.Hole(
        template.fills_type,
        python_numbers_control_flow.ASTHoleMetadata(
            names_in_scope=frozenset({"a"}),
            inside_function=True,
            inside_loop=True,
            op_depth=0))
    self.assertTrue(template.can_fill(hole))

    # Make sure we can build this object with no errors.
    filler = template.fill(hole, rng)
    # Zero-argument factories giving a placeholder value per hole type; each
    # hole gets a freshly created value.
    dummy_values = {
        python_numbers_control_flow.ASTHoleType.NUMBER:
            (lambda: gast.Constant(value=1, kind=None)),
        python_numbers_control_flow.ASTHoleType.BOOL:
            (lambda: gast.Constant(value=True, kind=None)),
        python_numbers_control_flow.ASTHoleType.STMT: gast.Pass,
        python_numbers_control_flow.ASTHoleType.STMTS: (lambda: []),
        python_numbers_control_flow.ASTHoleType.STMTS_NONEMPTY:
            (lambda: [gast.Pass()]),
        python_numbers_control_flow.ASTHoleType.BLOCK: (lambda: [gast.Pass()]),
    }
    hole_values = [dummy_values[h.hole_type]() for h in filler.holes]
    value = filler.build(*hole_values)

    # Check the type of the value that was built.
    if template.fills_type in (
        python_numbers_control_flow.ASTHoleType.STMTS_NONEMPTY,
        python_numbers_control_flow.ASTHoleType.BLOCK):
      self.assertTrue(value)
      for item in value:
        self.assertIsInstance(item, gast.stmt)
    elif template.fills_type == python_numbers_control_flow.ASTHoleType.STMTS:
      for item in value:
        self.assertIsInstance(item, gast.stmt)
    elif template.fills_type == python_numbers_control_flow.ASTHoleType.STMT:
      self.assertIsInstance(value, gast.stmt)
    elif template.fills_type in (python_numbers_control_flow.ASTHoleType.NUMBER,
                                 python_numbers_control_flow.ASTHoleType.BOOL):
      self.assertIsInstance(value, gast.expr)
    else:
      raise NotImplementedError(f"Unexpected fill type {template.fills_type}; "
                                "please update this test.")

    # Check that cost reflects number of AST nodes.
    # ``value`` is either one AST node or a list of statements; every node
    # reachable through ``gast.walk`` counts once.
    total_cost = 0
    if isinstance(value, gast.AST):
      for _ in gast.walk(value):
        total_cost += 1
    else:
      for item in value:
        for _ in gast.walk(item):
          total_cost += 1

    self.assertEqual(template.required_cost, total_cost)

    # The filler's own cost excludes the cost attributed to its holes.
    cost_without_holes = total_cost - sum(
        python_numbers_control_flow.ALL_COSTS[h.hole_type]
        for h in filler.holes)

    self.assertEqual(filler.cost, cost_without_holes)

    # Check determinism
    # Re-seeding with the same value must reproduce the same AST every time.
    for _ in range(20):
      rng = np.random.RandomState(1234)
      redo_value = template.fill(hole, rng).build(*hole_values)
      if isinstance(value, list):
        self.assertEqual([gast.dump(v) for v in value],
                         [gast.dump(v) for v in redo_value])
      else:
        self.assertEqual(gast.dump(value), gast.dump(redo_value))
Beispiel #17
0
 def generic_visit(self, node):
     """Replace nodes listed in ``self.nodes`` by ``pass``; visit the rest.

     A listed node is replaced by ``ast.Pass()`` without its children being
     visited. Any other node goes through the parent class's
     ``generic_visit``.
     """
     if node not in self.nodes:
         return super(Remover, self).generic_visit(node)
     return ast.Pass()