Beispiel #1
0
class Find_LICM_Candidates(SyntaxVisitor):
  """
  Syntax visitor which collects the names of variables whose assignments
  appear safe to move out of the control-flow construct containing them
  (candidates for loop-invariant code motion).

  A name is treated as "volatile" (not safe to move) if it is a function
  argument, a merge variable of an enclosing loop/branch, a loop index,
  the target of an indexed write, or if it depends on another volatile name.
  Calling visit_fn returns the set of names judged safe to move.
  """

  def __init__(self):
    SyntaxVisitor.__init__(self)
    self.mutable_types = None
    # names which must not be moved, tracked per nested scope
    self.volatile_vars = ScopedSet()
    # maps a variable name to the set of names appearing on the RHS
    # of its assignments
    self.depends_on = {}
    # result of the analysis: names whose assignments may be hoisted
    self.safe_to_move = set()
    # id() of the statement block currently being visited
    self.curr_block_id = None
    # id()s of blocks known to contain (possibly nested) Return statements
    self.block_contains_return = set()
    # alias information for the current function, filled in by visit_fn
    self.may_alias = None

  def visit_fn(self, fn):
    """
    Analyze the function `fn` and return the set of variable names
    which were found safe to move.
    """
    # function arguments form the outermost volatile scope
    self.volatile_vars.push(fn.arg_names)
    self.may_alias = may_alias(fn)
    SyntaxVisitor.visit_fn(self, fn)
    return self.safe_to_move

  def mark_safe_assignments(self, block, volatile_set):
    """
    Walk the statements of `block` in order and record in
    self.safe_to_move every variable assigned by a simple `Var = ...`
    statement which is neither in `volatile_set` itself nor depends
    on any name in `volatile_set`.

    Scanning stops at the first Return, or at the first If/While whose
    nested blocks contain a Return.
    """
    for stmt in block:
      klass = stmt.__class__
      if klass is Assign and stmt.lhs.__class__ is Var:
        name = stmt.lhs.name
        dependencies = self.depends_on.get(name, ())
        volatile = name in volatile_set or \
                   any(d in volatile_set for d in dependencies)
        if not volatile:
          self.safe_to_move.add(name)
      # just in case there are Returns in nested control flow
      # we should probably avoid changing the performance characteristics
      # by pulling out statements which will never run
      # NOTE(review): nested ForLoop statements are not checked here,
      # only If and While -- confirm whether that is intentional
      elif klass is If:
        if self.does_block_return(stmt.true) or \
           self.does_block_return(stmt.false):
          break
      elif klass is While:
        if self.does_block_return(stmt.body):
          break
      elif klass is Return:
        break

  def mark_curr_block_returns(self):
    """Record that the block currently being visited contains a Return."""
    self.block_contains_return.add(self.curr_block_id)

  def does_block_return(self, block):
    """Return True if `block` was previously marked as containing a Return."""
    return id(block) in self.block_contains_return

  def visit_Return(self, stmt):
    self.mark_curr_block_returns()

  def visit_block(self, stmts):
    """
    Visit the statements of a block while tracking its identity
    in self.curr_block_id.
    """
    # Remember the enclosing block and restore it afterwards; otherwise,
    # once a nested block has been visited, later Return statements (and the
    # return-propagation in visit_If/visit_ForLoop/visit_While) would be
    # attributed to the nested block rather than the block containing them.
    enclosing_block_id = self.curr_block_id
    self.curr_block_id = id(stmts)
    try:
      SyntaxVisitor.visit_block(self, stmts)
    finally:
      self.curr_block_id = enclosing_block_id

  def visit_merge(self, merge, both_branches = True):
    # merge nodes are handled by the control-flow visitors below,
    # which push the merge's keys as volatile names
    pass

  def visit_ForLoop(self, stmt):
    """
    Analyze a for-loop: its merge variables and its index variable are
    volatile inside the body; afterwards mark the safe assignments of the body.
    """
    self.volatile_vars.push(stmt.merge.keys())
    self.volatile_vars.add(stmt.var.name)
    SyntaxVisitor.visit_ForLoop(self, stmt)
    # a Return inside the loop body also counts as a Return
    # of the block which contains the loop
    if self.does_block_return(stmt.body):
      self.mark_curr_block_returns()
    volatile_in_scope = self.volatile_vars.pop()
    self.mark_safe_assignments(stmt.body, volatile_in_scope)

  def visit_While(self, stmt):
    """
    Analyze a while-loop: its merge variables are volatile inside the body;
    afterwards mark the safe assignments of the body.
    """
    self.volatile_vars.push(stmt.merge.keys())
    SyntaxVisitor.visit_While(self, stmt)
    # a Return inside the loop body also counts as a Return
    # of the block which contains the loop
    if self.does_block_return(stmt.body):
      self.mark_curr_block_returns()
    volatile_in_scope = self.volatile_vars.pop()
    self.mark_safe_assignments(stmt.body, volatile_in_scope)

  def visit_Var(self, expr):
    # any variable reached through expression visiting
    # is marked volatile in the current scope
    self.volatile_vars.add(expr.name)

  def visit_If(self, stmt):
    """
    Analyze a conditional: its merge variables are volatile inside both
    branches; afterwards mark the safe assignments of each branch.
    """
    self.volatile_vars.push(stmt.merge.keys())
    self.visit_expr(stmt.cond)
    SyntaxVisitor.visit_If(self, stmt)
    # a Return in either branch also counts as a Return
    # of the block which contains the If
    if self.does_block_return(stmt.true) or self.does_block_return(stmt.false):
      self.mark_curr_block_returns()
    volatile_in_scope = self.volatile_vars.pop()
    self.mark_safe_assignments(stmt.true, volatile_in_scope)
    self.mark_safe_assignments(stmt.false, volatile_in_scope)

  def is_mutable_alloc(self, expr):
    """
    Return True if `expr` is one of the allocation/array-like expression
    classes listed below, or a Struct whose type is not an ImmutableT.
    """
    c = expr.__class__
    if c is Alloc or c is AllocArray or c is Array or \
       c is ArrayView or c is Slice:
      return True
    return c is Struct and not isinstance(expr.type, ImmutableT)

  def visit_Assign(self, stmt):
    """
    Record the dependencies introduced by an assignment and propagate
    volatility from the right-hand side to the names bound on the left.
    """
    lhs_names = collect_binding_names(stmt.lhs)
    rhs_names = collect_var_names(stmt.rhs)

    # every name bound on the LHS depends on every name used on the RHS
    for x in lhs_names:
      self.depends_on.setdefault(x, set()).update(rhs_names)

    if any(x in self.volatile_vars for x in rhs_names):
      self.volatile_vars.update(lhs_names)
    elif self.is_mutable_alloc(stmt.rhs):
      # a freshly allocated mutable value stays non-volatile only if it is
      # bound to a single name with at most one entry in its may-alias set
      unaliased = len(lhs_names) == 1 and \
                  len(self.may_alias.get(lhs_names[0], [])) <= 1
      if not unaliased:
        self.volatile_vars.update(lhs_names)
    # mark any array writes as volatile
    if stmt.lhs.__class__ is Index:
      assert stmt.lhs.value.__class__ is Var, \
        "Expected LHS array to be variable but instead got %s" % stmt
      self.volatile_vars.add(stmt.lhs.value.name)