# Example no. 1
# 0
class CSE(Transform):
  """
  Common subexpression elimination.

  Remembers which non-simple, immutable expressions have already been
  assigned to a plain variable and, when an equal expression shows up again,
  returns the stored variable instead of the expression.
  """

  def pre_apply(self, fn):
    """Set up the per-function state before fn is transformed."""
    # expressions that were already computed, mapped to the variable
    # in which their value was stored
    self.available_expressions = ScopedDictionary()

    # types whose elements might change between two accesses
    analysis = TypeBasedMutabilityAnalysis()
    self.mutable_types = analysis.visit_fn(fn)

  def transform_expr(self, expr):
    """
    Return expr unchanged when it is simple, the variable recorded for an
    equal expression when there is one, and otherwise whatever the generic
    Transform.transform_expr produces.
    """
    if self.is_simple(expr):
      return expr
    previous = self.available_expressions.get(expr)
    if previous is None:
      return Transform.transform_expr(self, expr)
    return previous

  def transform_block(self, stmts):
    """Transform stmts between a push() and a pop() of available_expressions."""
    self.available_expressions.push()
    transformed = Transform.transform_block(self, stmts)
    self.available_expressions.pop()
    return transformed

  def transform_Assign(self, stmt):
    """
    Transform the right-hand side and, when the target is a plain Var and the
    value is non-simple, immutable and not recorded yet, remember that the
    target variable holds that value.
    """
    rhs = self.transform_expr(stmt.rhs)
    stmt.rhs = rhs
    target = stmt.lhs

    # each guard below rules out recording the expression
    if target.__class__ is not Var:
      return stmt
    if self.is_simple(rhs):
      return stmt
    if not self.immutable(rhs):
      return stmt
    if rhs in self.available_expressions:
      return stmt

    self.available_expressions[rhs] = target
    return stmt

  def immutable_type(self, t):
    """True when t was not reported as mutable by the mutability analysis."""
    return not (t in self.mutable_types)

  def children(self, expr, allow_mutable = False):
    """
    Return the sub-expressions of expr as a sequence, or None when they
    cannot be vouched for: a Call with a non-immutable argument, a node whose
    type is mutable (both only when allow_mutable is false), or a node class
    that is not handled here.
    """
    c = expr.__class__

    if c is Const or c is Var:
      return ()
    if c is PrimCall or c is Closure:
      return expr.args
    if c is ClosureElt:
      return (expr.closure,)
    if c is Tuple:
      return expr.elts
    if c is TupleProj:
      return (expr.tuple,)
    # WARNING: this is only valid if attributes are immutable
    if c is Attribute:
      return (expr.value,)
    if c is Slice:
      return (expr.start, expr.stop, expr.step)
    if c is Cast:
      return (expr.value,)
    if c is Map or c is AllPairs:
      return expr.args
    if c is Scan or c is Reduce:
      arg_tuple = tuple(expr.args)
      # the initial value, when present, comes before the arguments
      if expr.init:
        return (expr.init,) + arg_tuple
      return arg_tuple
    if c is Call:
      # assume all Calls might modify their arguments
      if allow_mutable:
        return expr.args
      for arg in expr.args:
        if not self.immutable(arg):
          return None
      return expr.args

    # the remaining node classes are only inspected when their type is
    # immutable (or the caller explicitly allows mutable values)
    if not (allow_mutable or self.immutable_type(expr.type)):
      return None
    if c is Array:
      return expr.elts
    if c is ArrayView:
      return (expr.data, expr.shape, expr.strides, expr.offset, expr.size)
    if c is Struct:
      return expr.args
    if c is AllocArray:
      return (expr.shape,)
    return None

  def immutable(self, expr):
    """
    Decide whether expr may be treated as an immutable value.

    Constants, tuples, tuple projections, closures and closure elements are
    always accepted, and so are attributes, structs, slices and array views.
    Anything else must have a type outside mutable_types and only immutable
    children.
    """
    c = expr.__class__
    if c is Const:
      return True
    if c is Tuple or c is TupleProj or c is Closure or c is ClosureElt:
      return True
    # WARNING: making attributes always immutable
    # (a narrower rule, accepting an Attribute only when its value has a
    # tuple type, was considered but is not in effect)
    if c in (Attribute, Struct, Slice, ArrayView):
      return True
    if expr.type in self.mutable_types:
      return False

    parts = self.children(expr, allow_mutable = False)
    if parts is None:
      return False
    return all(self.immutable(part) for part in parts)
# Example no. 2
# 0
class Simplify(Transform):
  def __init__(self):
    """Initialise the base transform and the state kept by this pass."""
    # Call the base initialiser through the same `Transform` name that the
    # class statement and the other methods use, so this does not depend on
    # the `transform` module also being imported under its own name.
    Transform.__init__(self)

    # associate var names with any immutable values they are bound to
    self.bindings = {}

    # which expressions have already been computed
    # and stored in some variable?
    self.available_expressions = ScopedDictionary()

  def pre_apply(self, fn):
    """Compute the per-function analyses used while simplifying fn."""
    # types whose elements might change between two accesses
    analysis = TypeBasedMutabilityAnalysis()
    self.mutable_types = analysis.visit_fn(fn)

    # result of use_count(fn); read and updated through .get() and item
    # assignment keyed by variable name elsewhere in this class
    self.use_counts = use_count(fn)

  def immutable_type(self, t):
    """True when t is not among the types recorded in mutable_types."""
    return not (t in self.mutable_types)

  def children(self, expr, allow_mutable = False):
    """
    Return the sub-expressions of expr as a sequence, or None when they
    cannot be vouched for: a Call with a non-immutable argument, a node whose
    type is mutable (both only when allow_mutable is false), or a node class
    that is not handled here.
    """
    c = expr.__class__

    if c is Const or c is Var:
      return ()
    if c is PrimCall or c is Closure:
      return expr.args
    if c is ClosureElt:
      return (expr.closure,)
    if c is Tuple:
      return expr.elts
    if c is TupleProj:
      return (expr.tuple,)
    # WARNING: this is only valid if attributes are immutable
    if c is Attribute:
      return (expr.value,)
    if c is Slice:
      return (expr.start, expr.stop, expr.step)
    if c is Cast:
      return (expr.value,)
    if c is Map or c is AllPairs or c is IndexMap:
      return expr.args
    if c is Scan or c is Reduce or c is IndexReduce:
      arg_tuple = tuple(expr.args)
      # the initial value, when present, comes before the arguments
      if expr.init:
        return (expr.init,) + arg_tuple
      return arg_tuple
    if c is Call:
      # assume all Calls might modify their arguments
      if allow_mutable:
        return expr.args
      for arg in expr.args:
        if not self.immutable(arg):
          return None
      return expr.args

    # the remaining node classes are only inspected when their type is
    # immutable (or the caller explicitly allows mutable values)
    if not (allow_mutable or self.immutable_type(expr.type)):
      return None
    if c is Array:
      return expr.elts
    if c is ArrayView:
      return (expr.data, expr.shape, expr.strides, expr.offset, expr.size)
    if c is Struct:
      return expr.args
    if c is AllocArray:
      return (expr.shape,)
    return None

  def immutable(self, expr):
    """
    Decide whether expr may be treated as an immutable value.

    A Const, Tuple, Adverb, Cast, Var or PrimCall qualifies when every node
    returned by its own children() method qualifies. An Attribute or
    TupleProj qualifies when its type is a scalar, tuple or slice type.
    """
    if isinstance(expr, (Const, Tuple, Adverb, Cast, Var, PrimCall)):
      if all(self.immutable(child) for child in expr.children()):
        return True
    if isinstance(expr, (Attribute, TupleProj)):
      return isinstance(expr.type, (ScalarT, TupleT, SliceT))
    return False
    
                      
  def temp(self, expr, name = None, use_count = 1):
    """
    Wrapper around assign_temp which also records the new variable in
    bindings and use_counts. Simple expressions are returned untouched.
    """
    if not self.is_simple(expr):
      fresh = self.assign_temp(expr, name = name)
      self.bindings[fresh.name] = expr
      self.use_counts[fresh.name] = use_count
      return fresh
    return expr

  def transform_expr(self, expr):
    """
    Reuse the variable recorded for an equal, non-simple expression when
    there is one; otherwise defer to the generic Transform.transform_expr.
    """
    if self.is_simple(expr):
      return Transform.transform_expr(self, expr)
    known = self.available_expressions.get(expr)
    if known is None:
      return Transform.transform_expr(self, expr)
    return known

  def transform_Var(self, expr):
    """
    Follow the chain of bindings starting at this variable.

    Returns the end of the chain when it is a Var or a Const; when the chain
    ends in any other kind of expression, returns the last variable that was
    bound to it.
    """
    previous = current = expr
    key = expr.name
    while key in self.bindings:
      previous, current = current, self.bindings[key]
      if current.__class__ is not Var:
        break
      key = current.name

    final_class = current.__class__
    if final_class is Var or final_class is Const:
      return current
    return previous
  
  def transform_Cast(self, expr):
    """
    Simplify a cast: drop it when the value already has the target type,
    fold it when a constant is cast to a scalar type, and otherwise make
    sure the cast is applied to a simple value.
    """
    value = self.transform_expr(expr.value)
    target = expr.type
    if value.type == target:
      return value
    if value.__class__ is Const and isinstance(target, ScalarT):
      converted = target.dtype.type(value.value)
      return Const(converted, type = target)

    # non-simple operands are first stored in a temporary
    if not self.is_simple(value):
      value = self.assign_temp(value)
    expr.value = value
    return expr

  def transform_Attribute(self, expr):
    """
    Simplify an attribute access, looking through the binding of the base
    variable when the requested field can be read from the bound expression.
    """
    base = self.transform_expr(expr.value)
    field = expr.name

    if base.__class__ is Var and base.name in self.bindings:
      bound = self.bindings[base.name]
      bound_class = bound.__class__
      if bound_class is Var or bound_class is Struct:
        base = bound
      elif bound_class is ArrayView:
        # these fields are read straight from the view's constructor
        if field in ('shape', 'strides', 'data'):
          return self.transform_expr(getattr(bound, field))
      elif bound_class is AllocArray:
        if field == 'shape':
          return self.transform_expr(bound.shape)

    base_class = base.__class__
    if base_class is Struct:
      # pick the constructor argument at the field's position
      return base.args[base.type.field_pos(field)]
    if base_class is not Var:
      base = self.temp(base, "struct")

    if expr.value == base:
      return expr
    return Attribute(value = base, name = field, type = expr.type)
    
  
  def transform_Closure(self, expr):
    """Replace the closure's arguments with their transformed versions."""
    new_args = self.transform_args(expr.args)
    expr.args = tuple(new_args)
    return expr

  def transform_Tuple(self, expr):
    """Replace the tuple's elements with their transformed versions."""
    new_elts = self.transform_args(expr.elts)
    expr.elts = tuple(new_elts)
    return expr

  def transform_TupleProj(self, expr):
    """
    Simplify a tuple projection. When the projected variable is bound to a
    Tuple or Struct constructor, the selected element is returned directly.
    """
    position = expr.index
    assert isinstance(position, int), \
        "TupleProj index must be an integer, got: " + str(position)
    source = self.transform_expr(expr.tuple)

    if source.__class__ is Var and source.name in self.bindings:
      bound = self.bindings[source.name]
      bound_class = bound.__class__
      if bound_class is Tuple:
        return bound.elts[position]
      if bound_class is Struct:
        return bound.args[position]

    # otherwise project out of a simple value, storing it first if needed
    if self.is_simple(source):
      expr.tuple = source
    else:
      expr.tuple = self.assign_temp(source, "tuple")
    return expr

  def transform_ClosureElt(self, expr):
    """
    Simplify access to a closure element. When the closure variable is bound
    to a Closure constructor, the selected argument is returned directly.
    """
    position = expr.index
    assert isinstance(position, int), \
        "ClosureElt index must be an integer, got: " + str(position)
    source = self.transform_expr(expr.closure)

    if source.__class__ is Var and source.name in self.bindings:
      bound = self.bindings[source.name]
      if bound.__class__ is Closure:
        return bound.args[position]

    # otherwise read from a simple value, storing it first if needed
    if self.is_simple(source):
      expr.closure = source
    else:
      expr.closure = self.assign_temp(source, "closure")
    return expr

  def transform_Call(self, expr):
    """
    Simplify a call. A call through a closure is rewritten into a new Call
    whose arguments are the closure's elements followed by the explicit
    arguments; any other call just gets its function and arguments replaced.
    """
    callee = self.transform_expr(expr.fn)
    actuals = self.transform_args(expr.args)

    if callee.type.__class__ is not ClosureT:
      expr.fn = callee
      expr.args = actuals
      return expr

    all_args = self.closure_elts(callee) + actuals
    target = callee.type.fn
    if target.__class__ is TypedFn:
      callee = target
    else:
      # specialize for the combined argument types
      assert isinstance(target, Fn)
      import type_inference
      callee = type_inference.specialize(callee, get_types(all_args))
    assert callee.return_type == expr.type
    return Call(callee, all_args, type = callee.return_type)

  def transform_arg(self, x, name = None):
    """Transform x, then pass the result through temp() with the given name."""
    transformed = self.transform_expr(x)
    return self.temp(transformed, name = name)
  
  def transform_args(self, args):
    """Apply transform_arg to every element of args and return a list."""
    results = []
    for arg in args:
      results.append(self.transform_arg(arg))
    return results

  def transform_Array(self, expr):
    """Replace the array literal's elements with their transformed versions."""
    new_elts = self.transform_args(expr.elts)
    expr.elts = tuple(new_elts)
    return expr

  def transform_Index(self, expr):
    """
    Simplify an indexing expression. Indexing an Array literal with a
    constant returns the selected element; otherwise the indexed value is
    forced to be a variable.
    """
    array = self.transform_expr(expr.value)
    expr.value = array
    index = self.transform_expr(expr.index)
    expr.index = index

    if array.__class__ is Array and index.__class__ is Const:
      position = index.value
      assert isinstance(position, (int, long)) and \
             len(array.elts) > position
      return array.elts[position]

    if array.__class__ is not Var:
      expr.value = self.temp(array, "array")
    return expr

  def transform_Struct(self, expr):
    """Build a new Struct of the same type from the transformed arguments."""
    return syntax.Struct(self.transform_args(expr.args), type = expr.type)

  def transform_PrimCall(self, expr):
    """
    Simplify a primitive call: fold it when every argument is a constant and
    apply the algebraic identities handled below for add, multiply, divide
    and power. When nothing applies, only the arguments are replaced.
    """
    new_args = self.transform_args(expr.args)
    op = expr.prim

    if all_constants(new_args):
      folded = op.fn(*collect_constants(new_args))
      return syntax.Const(value = folded, type = expr.type)

    if op == prims.add:
      left, right = new_args
      # 0 + y => y,  x + 0 => x
      if is_zero(left):
        return right
      if is_zero(right):
        return left
    elif op == prims.multiply:
      left, right = new_args
      # 1 * y => y,  x * 1 => x
      if is_one(left):
        return right
      if is_one(right):
        return left
      # a zero operand is returned as the result
      if is_zero(left):
        return left
      if is_zero(right):
        return right
    elif op == prims.divide:
      left, right = new_args
      # x / 1 => x
      if is_one(right):
        return left
    elif op == prims.power:
      base, exponent = new_args
      # x ** 1 => x,  x ** 0 => 1,  x ** 2 => x * x
      if is_one(exponent):
        return self.cast(base, expr.type)
      if is_zero(exponent):
        return syntax_helpers.one(expr.type)
      if exponent.__class__ is Const and exponent.value == 2:
        return self.cast(self.mul(base, base, "sqr"), expr.type)

    expr.args = new_args
    return expr
  
  def transform_Reduce(self, expr):
    """
    Transform the pieces of a Reduce in place: the initial value (stored in
    a temporary unless simple), then the arguments, fn and combine.
    """
    new_init = self.transform_expr(expr.init)
    if self.is_simple(new_init):
      expr.init = new_init
    else:
      expr.init = self.assign_temp(new_init, 'init')

    expr.args = self.transform_args(expr.args)
    expr.fn = self.transform_expr(expr.fn)
    expr.combine = self.transform_expr(expr.combine)
    return expr
  
  def temp_in_block(self, expr, block, name = None):
    """
    Create a temporary for expr inside a specific block rather than the
    current top scope: a fresh variable is created, its assignment is
    appended to block, and the binding is recorded in the bindings dictionary.
    """
    prefix = "temp" if name is None else name
    fresh = self.fresh_var(expr.type, prefix)
    block.append(Assign(fresh, expr))
    self.bindings[fresh.name] = expr
    return fresh

  def set_binding(self, name, value):
    """Record that name is bound to value; a variable may not bind to itself."""
    binds_to_itself = value.__class__ is Var and value.name == name
    assert not binds_to_itself, \
        "Can't set name %s bound to itself" % name
    self.bindings[name] = value

  def bind_var(self, name, rhs):
    """
    Bind name to rhs. When rhs is a variable that is itself bound to a
    simple value, bind name to that value instead.
    """
    target = rhs
    if rhs.__class__ is Var:
      known = self.bindings.get(rhs.name)
      if known and self.is_simple(known):
        target = known
    self.set_binding(name, target)

  def bind(self, lhs, rhs):
    """
    Record bindings for an assignment target: a Var is bound directly and a
    Tuple target matched with a Tuple value is bound element by element.
    Any other combination is ignored.
    """
    target_class = lhs.__class__
    if target_class is Var:
      self.bind_var(lhs.name, rhs)
      return
    if target_class is Tuple and rhs.__class__ is Tuple:
      assert len(lhs.elts) == len(rhs.elts)
      for pair in zip(lhs.elts, rhs.elts):
        self.bind(pair[0], pair[1])

  def transform_lhs_Index(self, lhs):
    """
    Prepare an indexing expression used as an assignment target. Only the
    index is transformed; the indexed value itself is not.
    """
    lhs.index = self.transform_expr(lhs.index)

    if lhs.value.__class__ is not Var:
      # index into a variable rather than into an arbitrary expression
      lhs.value = self.assign_temp(lhs.value, "array")
      return lhs

    # a variable that is merely an alias is replaced by what it is bound to
    alias = self.bindings.get(lhs.value.name)
    if alias and alias.__class__ is Var:
      lhs.value = alias
    return lhs

  def transform_lhs_Attribute(self, lhs):
    """Return an attribute assignment target untouched (its value is not transformed)."""
    return lhs

  def transform_ExprStmt(self, stmt):
    """Don't run an expression unless it possibly has a side effect"""
    value = self.transform_expr(stmt.value)
    # a value that immutable() accepts is dropped by returning None
    if self.immutable(value):
      return None
    stmt.value = value
    return stmt

  def transform_Assign(self, stmt):
    """
    Simplify an assignment statement.

    The right-hand side is transformed first; what happens next depends on
    the class of the left-hand side (Var, Tuple, Index or Attribute).
    Returns the updated statement, or None when the assignment is dropped or
    has already been replaced by statements emitted through other helpers.
    """
    lhs = stmt.lhs
    rhs = self.transform_expr(stmt.rhs)

    # NOTE: rhs_class is captured once here and is not refreshed when rhs is
    # rebound further down (for example by assign_temp)
    lhs_class = lhs.__class__
    rhs_class = rhs.__class__
    if lhs_class is Var:
      # a None-typed value assigned to a variable whose use count is zero is
      # re-dispatched as a bare expression statement
      if rhs.type.__class__ is NoneT and self.use_counts.get(lhs.name,0) == 0:
        return self.transform_stmt(ExprStmt(rhs))
      else:
        self.bind_var(lhs.name, rhs)
        # remember immutable, non-trivial values so that later equal
        # expressions can reuse this variable (the first entry wins)
        if rhs_class is not Var and \
           rhs_class is not Const and \
           self.immutable(rhs):
          self.available_expressions.setdefault(rhs, lhs)
    elif lhs_class is Tuple:
      # only records bindings; the statement itself is rewritten below
      self.bind(lhs, rhs)
    # assigning x[i] = x[i]
    # does nothing
    elif lhs_class is Index:
      if rhs_class is Index and \
         lhs.value == rhs.value and \
         lhs.index == rhs.index:
        # kill effect-free writes like x[i] = x[i]
        return None
      elif rhs_class is Var and \
           lhs.value.__class__ is Var and \
           lhs.value.name == rhs.name and \
           lhs.index.type.__class__ is TupleT and \
           all(elt_t == slice_none_t for elt_t in lhs.index.type.elt_types):
        # also kill x[:] = x
        return None
      else:
        lhs = self.transform_lhs_Index(lhs)
        # when assigning x[j] = [1,2,3]
        # just rewrite it as a sequence of element assignments:
        # the target slice is stored once and every element of the array
        # literal is assigned into it individually
        if lhs.type.__class__ is ArrayT and \
           lhs.type.rank == 1 and \
           rhs.__class__ is Array:
          lhs_slice = self.assign_temp(lhs, "lhs_slice")
          for (elt_idx, elt) in enumerate(rhs.elts):
            lhs_idx = self.index(lhs_slice, const_int(elt_idx), temp = False)
            self.assign(lhs_idx, elt)
          # the element assignments replace this statement
          return None
        elif not self.is_simple(rhs):
          # store a non-simple value in a temporary before writing it
          rhs = self.assign_temp(rhs)
    else:
      assert lhs_class is Attribute
      # attribute assignment is rejected outright; the line after this
      # assertion is only reachable when assertions are disabled
      assert False, "Considering making attributes immutable"
      lhs = self.transform_lhs_Attribute(lhs)

    # when the value was a variable (per the class captured above) that has a
    # binding and a use count of exactly one (also the default for unknown
    # names), mark it as unused and assign the bound expression instead
    if rhs_class is Var and \
       rhs.name in self.bindings and \
       self.use_counts.get(rhs.name, 1) == 1:
      self.use_counts[rhs.name] = 0
      rhs = self.bindings[rhs.name]
    stmt.lhs = lhs
    
    stmt.rhs = rhs
    return stmt

  def transform_block(self, stmts):
    """Transform stmts between a push() and a pop() of available_expressions."""
    self.available_expressions.push()
    transformed = Transform.transform_block(self, stmts)
    self.available_expressions.pop()
    return transformed

  def transform_merge(self, phi_nodes, left_block, right_block):
    """
    Simplify the phi nodes of a control-flow merge.

    Both incoming values of every node are transformed and, unless they are
    already a Const or Var, stored in a temporary inside the block they come
    from. Nodes whose two values end up equal are turned into a binding and
    left out of the returned dictionary.
    """
    remaining = {}
    simple_classes = (Const, Var)
    for (name, (left, right)) in phi_nodes.iteritems():
      left_value = self.transform_expr(left)
      right_value = self.transform_expr(right)

      if not isinstance(left_value, simple_classes):
        left_value = self.temp_in_block(left_value, left_block)
      if not isinstance(right_value, simple_classes):
        right_value = self.temp_in_block(right_value, right_block)

      if left_value == right_value:
        # both control flows yield the same value, so the phi-bound
        # variable is not needed: bind it to the left value instead
        self.set_binding(name, left_value)
        continue
      remaining[name] = left_value, right_value
    return remaining

  def transform_If(self, stmt):
    """
    Transform an If statement in place: both branches first, then the merge
    (which may add temporaries to either branch), and the condition last.
    """
    true_block = self.transform_block(stmt.true)
    stmt.true = true_block
    false_block = self.transform_block(stmt.false)
    stmt.false = false_block

    stmt.merge = self.transform_merge(stmt.merge,
                                      left_block = true_block,
                                      right_block = false_block)
    stmt.cond = self.transform_expr(stmt.cond)
    return stmt

  def transform_loop_condition(self, expr, outer_block, loop_body, merge):
    """
    Normalize loop conditions so they are just simple variables.

    A simple condition is merely transformed, and a condition that mentions
    no variable from merge is returned as it is. Otherwise the condition is
    evaluated twice -- on the incoming values (in outer_block) and on the
    loop-carried values (in loop_body) -- and a fresh variable merging the
    two is added to merge and returned.
    """
    if self.is_simple(expr):
      return self.transform_expr(expr)

    carried = [name for name in collect_var_names(expr) if name in merge]
    if len(carried) == 0:
      return expr

    # substitutions for the value on entry and the value after an iteration
    entry_env = dict((name, merge[name][0]) for name in carried)
    repeat_env = dict((name, merge[name][1]) for name in carried)

    entry_cond = subst.subst_expr(expr, entry_env)
    if not self.is_simple(entry_cond):
      entry_cond = self.temp_in_block(entry_cond, outer_block, name = "cond")

    repeat_cond = subst.subst_expr(expr, repeat_env)
    if not self.is_simple(repeat_cond):
      repeat_cond = self.temp_in_block(repeat_cond, loop_body, name = "cond")

    merged = self.fresh_var(entry_cond.type, "cond")
    merge[merged.name] = (entry_cond, repeat_cond)
    return merged

    
  def transform_While(self, stmt):
    """
    Transform a While loop in place: the body first, then the merge between
    the enclosing block and the body, and finally the condition, which is
    normalized by transform_loop_condition.
    """
    new_body = self.transform_block(stmt.body)
    stmt.body = new_body

    new_merge = self.transform_merge(stmt.merge,
                                     left_block = self.blocks.current(),
                                     right_block = new_body)
    stmt.merge = new_merge

    new_cond = self.transform_loop_condition(stmt.cond,
                                             outer_block = self.blocks.current(),
                                             loop_body = new_body,
                                             merge = new_merge)
    stmt.cond = new_cond
    return stmt

  def transform_ForLoop(self, stmt):
    """
    Transform a ForLoop in place and remove it when constant bounds show
    that it runs at most once.

    The body and merge are transformed first, then start, stop and step (a
    missing step becomes the constant one of the start's type). When all
    three bounds are constants and the step is not negative:

      * start >= stop: the loop never runs, so only the incoming values of
        the merged variables are assigned;
      * start + step >= stop: the loop runs once, so the incoming values
        and the loop variable are assigned and the body is spliced into the
        enclosing block.

    Returns None when the loop was removed, otherwise the updated statement.
    """
    stmt.body = self.transform_block(stmt.body)
    stmt.merge = self.transform_merge(stmt.merge,
                                      left_block = self.blocks.current(),
                                      right_block = stmt.body)
    stmt.start = self.transform_arg(stmt.start, 'start')
    stmt.stop = self.transform_arg(stmt.stop, 'stop')
    if self.is_none(stmt.step):
      stmt.step = syntax_helpers.one(stmt.start.type)
    else:
      stmt.step = self.transform_arg(stmt.step, 'step')

    bounds_are_constant = stmt.start.__class__ is Const and \
                          stmt.stop.__class__ is Const and \
                          stmt.step.__class__ is Const
    if not bounds_are_constant:
      return stmt

    # The trip-count comparisons below assume the loop counts upwards. With
    # a negative step, start >= stop is the normal case (e.g. 10 down to 0),
    # so treating it as "never runs" would delete a live loop; leave such
    # loops alone.
    if stmt.step.value < 0:
      return stmt

    if stmt.start.value >= stmt.stop.value:
      runs_once = False
    elif stmt.start.value + stmt.step.value >= stmt.stop.value:
      runs_once = True
    else:
      # two or more iterations: keep the loop
      return stmt

    # the merged variables take their incoming values
    for (var_name, (input_value, _)) in stmt.merge.iteritems():
      var = Var(var_name, input_value.type)
      self.blocks.append(Assign(var, input_value))

    if runs_once:
      # NOTE(review): the values produced by the body for the merged
      # variables are not assigned after the inlined body -- confirm that
      # nothing after the loop relies on them
      self.assign(stmt.var, stmt.start)
      self.blocks.top().extend(stmt.body)
    return None

  def transform_Return(self, stmt):
    """
    Transform the returned value and store it back when it changed.

    An optimization that returned the bound expression directly for a
    variable used exactly once was disabled here and is not applied.
    """
    result = self.transform_expr(stmt.value)
    if result != stmt.value:
      stmt.value = result
    return stmt