def __init__(self):
    """Initialise analysis state: empty containers plus unset placeholders."""
    SyntaxVisitor.__init__(self)
    # Containers that accumulate results while visiting.
    self.safe_to_move = set()
    self.block_contains_return = set()
    self.depends_on = {}
    # Scoped collection of names that are tracked as volatile.
    self.volatile_vars = ScopedSet()
    # Placeholders that are not given real values here.
    self.curr_block_id = None
    self.may_alias = None
    self.mutable_types = None
class Find_LICM_Candidates(SyntaxVisitor):
    """Find variables whose assignments are candidates for loop-invariant
    code motion.

    ``visit_fn`` walks a function and returns ``safe_to_move``: the names
    bound by a plain ``Var = expr`` assignment inside a loop or branch body
    where neither the name nor any variable read by its right-hand side(s)
    was marked volatile in that loop's/branch's scope.

    A name becomes volatile when it is a function argument, a key of a
    loop/branch ``merge``, a for-loop index, the name of a visited ``Var``
    expression, assigned from a volatile name, bound to a mutable allocation
    that may be aliased, or an array written through an ``Index`` LHS.
    """

    def __init__(self):
        SyntaxVisitor.__init__(self)
        # Not read by the methods of this class.
        self.mutable_types = None
        # Scoped stack of names that must not be treated as invariant.
        self.volatile_vars = ScopedSet()
        # name -> set of variable names read by the RHS of any assignment to it.
        self.depends_on = {}
        # Result: names whose assignments were judged safe to move.
        self.safe_to_move = set()
        # id() of the statement block currently being visited.
        self.curr_block_id = None
        # id()s of blocks containing a Return, directly or via nested control flow.
        self.block_contains_return = set()
        # Alias information for the current function, from may_alias(fn).
        self.may_alias = None

    def visit_fn(self, fn):
        """Analyse ``fn`` and return the set of names safe to move."""
        # Function arguments form the outermost volatile scope.
        self.volatile_vars.push(fn.arg_names)
        self.may_alias = may_alias(fn)
        SyntaxVisitor.visit_fn(self, fn)
        return self.safe_to_move

    def mark_safe_assignments(self, block, volatile_set):
        """Add to ``safe_to_move`` every ``Var`` assignment in ``block`` that is
        independent of ``volatile_set``.

        Scanning stops at the first ``Return`` and at the first nested
        statement whose sub-block contains a return. Statements after that
        point might never run, and pulling them out would change the
        performance characteristics of the code.
        """
        for stmt in block:
            klass = stmt.__class__
            if klass is Assign and stmt.lhs.__class__ is Var:
                name = stmt.lhs.name
                if name in volatile_set:
                    continue
                dependencies = self.depends_on.get(name, ())
                if not any(d in volatile_set for d in dependencies):
                    self.safe_to_move.add(name)
            elif klass is Return:
                break
            elif klass is If:
                if self.does_block_return(stmt.true) or \
                   self.does_block_return(stmt.false):
                    break
            else:
                # Loops (While, for-loops) and any other statement carrying a
                # nested ``body`` block: stop if that block returns. Stopping
                # early only makes the analysis more conservative.
                body = getattr(stmt, "body", None)
                if body is not None and self.does_block_return(body):
                    break

    def mark_curr_block_returns(self):
        """Record that the block currently being visited contains a return."""
        self.block_contains_return.add(self.curr_block_id)

    def does_block_return(self, block):
        """Return True if ``block`` has been recorded as containing a return."""
        return id(block) in self.block_contains_return

    def visit_Return(self, stmt):
        self.mark_curr_block_returns()

    def visit_block(self, stmts):
        # Restore the enclosing block's id afterwards. Otherwise, once a
        # nested block has been visited, later returns (including ones
        # propagated out of nested If/loops) would be attributed to the
        # nested block instead of the enclosing one.
        outer_block_id = self.curr_block_id
        self.curr_block_id = id(stmts)
        SyntaxVisitor.visit_block(self, stmts)
        self.curr_block_id = outer_block_id

    def visit_merge(self, merge, both_branches=True):
        # Deliberately a no-op: the If/While/ForLoop visitors push the
        # merge's keys onto volatile_vars directly.
        pass

    def _close_scope(self, *blocks):
        """Finish a loop/branch whose volatile scope is on top of the stack.

        Propagates a return from any of ``blocks`` to the enclosing block,
        pops the scope, and records the movable assignments in each block.
        """
        if any(self.does_block_return(b) for b in blocks):
            self.mark_curr_block_returns()
        volatile_in_scope = self.volatile_vars.pop()
        for b in blocks:
            self.mark_safe_assignments(b, volatile_in_scope)

    def visit_ForLoop(self, stmt):
        self.volatile_vars.push(stmt.merge.keys())
        # The loop index is rebound on every iteration.
        self.volatile_vars.add(stmt.var.name)
        SyntaxVisitor.visit_ForLoop(self, stmt)
        self._close_scope(stmt.body)

    def visit_While(self, stmt):
        self.volatile_vars.push(stmt.merge.keys())
        SyntaxVisitor.visit_While(self, stmt)
        self._close_scope(stmt.body)

    def visit_Var(self, expr):
        self.volatile_vars.add(expr.name)

    def visit_If(self, stmt):
        self.volatile_vars.push(stmt.merge.keys())
        # NOTE(review): possibly redundant if SyntaxVisitor.visit_If also
        # visits the condition -- TODO confirm.
        self.visit_expr(stmt.cond)
        SyntaxVisitor.visit_If(self, stmt)
        self._close_scope(stmt.true, stmt.false)

    def is_mutable_alloc(self, expr):
        """Return True if ``expr`` constructs a potentially mutable value."""
        c = expr.__class__
        if c in (Alloc, AllocArray, Array, ArrayView, Slice):
            return True
        return c is Struct and not isinstance(expr.type, ImmutableT)

    def visit_Assign(self, stmt):
        """Record the dependencies of the bound names and update volatility.

        The RHS is not traversed with ``visit_expr``, so ``Var`` reads on the
        RHS do not by themselves make names volatile.
        """
        lhs_names = collect_binding_names(stmt.lhs)
        rhs_names = collect_var_names(stmt.rhs)
        for lhs_name in lhs_names:
            self.depends_on.setdefault(lhs_name, set()).update(rhs_names)
        if any(rhs_name in self.volatile_vars for rhs_name in rhs_names):
            # A value computed from a volatile name is itself volatile.
            self.volatile_vars.update(lhs_names)
        elif self.is_mutable_alloc(stmt.rhs):
            # A mutable allocation may only stay movable when it binds a single
            # name with at most one entry in the may-alias map.
            single_unaliased = len(lhs_names) == 1 and \
                len(self.may_alias.get(lhs_names[0], [])) <= 1
            if not single_unaliased:
                self.volatile_vars.update(lhs_names)
        # Writing through an index mutates the array, so the array variable
        # must be treated as volatile.
        if stmt.lhs.__class__ is Index:
            assert stmt.lhs.value.__class__ is Var, \
                "Expected LHS array to be variable but instead got %s" % stmt
            self.volatile_vars.add(stmt.lhs.value.name)