Ejemplo n.º 1
0
class ValueRangeAnalyis(SyntaxVisitor):
  def __init__(self):
    """Start with no known ranges and one open scope of overwritten values."""
    SyntaxVisitor.__init__(self)
    # variable name -> abstract value currently recorded for it
    self.ranges = {}
    # values that were replaced by `set`, tracked per scope
    overwritten = ScopedDict()
    overwritten.push()
    self.old_values = overwritten

  def get(self, expr):
    """
    Return the abstract value recorded or derivable for an expression.

    Falls back to any_value whenever the expression kind is not handled
    or not enough is known about its operands.
    """
    kind = expr.__class__

    # None-typed expressions are answered before looking at the node class
    if expr.type.__class__ is NoneT:
      return const_none

    if kind is Const:
      # a constant is the interval containing only itself
      return Interval(expr.value, expr.value)

    if kind is Var and expr.name in self.ranges:
      return self.ranges[expr.name]

    if kind is Tuple:
      return mk_tuple([self.get(elt) for elt in expr.elts])

    if kind is TupleProj:
      tuple_value = self.get(expr.tuple)
      position = unwrap_constant(expr.index)
      if tuple_value.__class__ is TupleOfIntervals:
        return tuple_value.elts[position]
      return any_value

    if kind is Slice:
      lower = self.get(expr.start)
      upper = self.get(expr.stop)
      # a None-typed step is replaced by const_one
      stride = const_one if expr.step.type == NoneType else self.get(expr.step)
      return mk_slice(lower, upper, stride)

    if kind is Shape:
      # one positive_interval per array dimension
      return mk_tuple([positive_interval] * get_rank(expr.array.type))

    if kind is Attribute:
      attr = expr.name
      if attr == 'shape':
        return mk_tuple([positive_interval] * get_rank(expr.value.type))
      if attr in ('start', 'stop', 'step'):
        slice_value = self.get(expr.value)
        if isinstance(slice_value, SliceOfIntervals):
          return getattr(slice_value, attr)
      return any_value

    if kind is PrimCall:
      prim = expr.prim
      if prim.nin == 2:
        lhs = self.get(expr.args[0])
        rhs = self.get(expr.args[1])
        if prim == prims.add:
          return self.add_range(lhs, rhs)
        if prim == prims.subtract:
          return self.sub_range(lhs, rhs)
        if prim == prims.multiply:
          return self.mul_range(lhs, rhs)
      elif prim.nin == 1:
        operand = self.get(expr.args[0])
        if prim == prims.negative:
          return self.neg_range(operand)

    return any_value
  
  def set(self, name, val):
    """
    Record `val` as the range of `name`; None and unknown_value are ignored.

    When a previously known, different value gets replaced, the old value
    is saved in self.old_values.
    """
    if val is None or val is unknown_value:
      return
    previous = self.ranges.get(name, unknown_value)
    self.ranges[name] = val
    if previous != val and previous is not unknown_value:
      self.old_values[name] = previous

      
  def add_range(self, x, y):
    """Sum of two intervals; any_value unless both operands are Intervals."""
    if isinstance(x, Interval) and isinstance(y, Interval):
      return Interval(x.lower + y.lower, x.upper + y.upper)
    return any_value
  
  def sub_range(self, x, y):
    """Difference x - y of two intervals; any_value unless both are Intervals."""
    if isinstance(x, Interval) and isinstance(y, Interval):
      # smallest result: low x minus high y; largest: high x minus low y
      return Interval(x.lower - y.upper, x.upper - y.lower)
    return any_value
  
  def mul_range(self, x, y):
    """
    Product of two intervals; any_value unless both operands are Intervals.

    The result spans the smallest and largest of the four products of the
    operands' bounds.  A product of a zero bound with an infinite bound
    evaluates to NaN in floating point; NaN compares false with everything,
    which would make min/max depend on argument order, so such products are
    treated as 0 (the usual interval-arithmetic convention).
    """
    if not isinstance(x, Interval) or not isinstance(y, Interval):
      return any_value
    products = []
    for x_bound in (x.lower, x.upper):
      for y_bound in (y.lower, y.upper):
        product = x_bound * y_bound
        if product != product:
          # only NaN is unequal to itself (0 * inf)
          product = 0
        products.append(product)
    return Interval(min(products), max(products))
  
  def neg_range(self, x):
    """Negation of an interval; any_value when x is not an Interval."""
    if isinstance(x, Interval):
      # negating swaps the roles of the two bounds
      return Interval(-x.upper, -x.lower)
    return any_value
  
  def visit_Assign(self, stmt):
    """Record the range of the right-hand side when the target is a Var."""
    target = stmt.lhs
    if target.__class__ is not Var:
      # other kinds of targets are not tracked
      return
    self.set(target.name, self.get(stmt.rhs))
  
  def visit_merge_left(self, phi_nodes):
    """Give every phi variable the range of its left incoming value only."""
    for name, branches in phi_nodes.iteritems():
      incoming, _ = branches
      self.set(name, self.get(incoming))
      
      
  def visit_merge(self, phi_nodes):
    """Give every phi variable the combination of both incoming ranges."""
    for name, branches in phi_nodes.iteritems():
      left, right = branches
      left_range = self.get(left)
      right_range = self.get(right)
      self.set(name, left_range.combine(right_range))

  def visit_Select(self, expr):
    """Combine the ranges of the two values a Select can produce."""
    if_true = self.get(expr.true_value)
    if_false = self.get(expr.false_value)
    return if_true.combine(if_false)
  
  def always_positive(self, x, inclusive = True):
    """
    True when x is an Interval whose lower bound is >= 0 (or strictly > 0
    when `inclusive` is false); False for anything that is not an Interval.
    """
    if not isinstance(x, Interval):
      return False
    return x.lower >= 0 if inclusive else x.lower > 0
  
  def always_negative(self, x, inclusive = True):
    """
    True when x is an Interval whose upper bound is <= 0 (or strictly < 0
    when `inclusive` is false); False for anything that is not an Interval.
    """
    if not isinstance(x, Interval):
      return False
    return x.upper <= 0 if inclusive else x.upper < 0
    
  
  def widen(self, old_values):
    """
    For every variable in `old_values` whose current range differs from the
    saved one, replace the current range by the saved value widened with it.
    """
    for name, previous in old_values.iteritems():
      current = self.ranges[name]
      if previous != current:
        self.ranges[name] = previous.widen(current)
          
  def run_loop(self, body, merge):
    """
    Analyze a loop: seed the phi variables from their left values, visit
    the body and merge twice, then widen every range that changed.
    """
    # changes made while analyzing the loop are collected in a fresh scope
    self.old_values.push()
    self.visit_merge_left(merge)

    # two passes over the body, each followed by the phi merge
    for _ in (1, 2):
      self.visit_block(body)
      self.visit_merge(merge)

    self.widen(self.old_values.pop())

    # TODO: verify that it's safe not to run loop with widened values
    #self.visit_block(body)
    #self.visit_merge(merge)
    
    
  def visit_While(self, stmt):
    """Analyze a while loop through the shared loop procedure."""
    loop_body, phi_nodes = stmt.body, stmt.merge
    self.run_loop(loop_body, phi_nodes)
  
  
  def visit_ForLoop(self, stmt):
    """
    Estimate the range of the loop variable from the ranges of start, stop
    and step, record it, then analyze the loop body.
    """
    start = self.get(stmt.start)
    stop = self.get(stmt.stop)
    step = self.get(stmt.step)
    name = stmt.var.name

    known_start = isinstance(start, Interval)
    known_stop = isinstance(stop, Interval)
    known_step = isinstance(step, Interval)

    iterator_range = any_value
    if known_start and known_stop:
      # bounded on both sides by the extreme endpoints
      endpoints = (start.lower, start.upper, stop.lower, stop.upper)
      iterator_range = Interval(min(endpoints), max(endpoints))
    elif known_start and known_step:
      # only the start is known: unbounded in the direction of the step
      if self.always_positive(step):
        iterator_range = Interval(start.lower, np.inf)
      elif self.always_negative(step, inclusive = False):
        iterator_range = Interval(-np.inf, start.upper)
    elif known_stop and known_step:
      # only the stop is known: unbounded against the direction of the step
      if self.always_positive(step):
        iterator_range = Interval(-np.inf, stop.upper)
      elif self.always_negative(step, inclusive = False):
        iterator_range = Interval(stop.lower, np.inf)

    self.set(name, iterator_range)
    self.run_loop(stmt.body, stmt.merge)

      
    
Ejemplo n.º 2
0
class Simplify(Transform):
  def __init__(self):
    """Create the scoped tables used while simplifying a function."""
    transform.Transform.__init__(self)

    # variable name -> immutable value the variable is bound to
    self.bindings = ScopedDict()

    # expression -> variable which already holds its computed value
    self.available_expressions = ScopedDict()

  def pre_apply(self, fn):
    """Compute the mutable types and variable use counts of `fn`."""
    # types whose elements might change between two accesses
    self.mutable_types = TypeBasedMutabilityAnalysis().visit_fn(fn)
    self.use_counts = use_count(fn)
    
  def immutable_type(self, t):
    """True when `t` is not among the mutable types found by pre_apply."""
    is_mutable = t in self.mutable_types
    return not is_mutable

  def children(self, expr, allow_mutable = False):
    """
    Return the sub-expressions of `expr`, or None when they can't be listed.

    None is returned for a Call with a possibly mutable argument, for
    array/struct constructors of a mutable type (unless `allow_mutable` is
    set) and for expression classes not handled here.
    """
    c = expr.__class__

    if c is Const or c is Var:
      return ()
    elif c is PrimCall or c is Closure:
      return expr.args
    elif c is ClosureElt:
      return (expr.closure,)
    elif c is Tuple:
      return expr.elts
    elif c is TupleProj:
      return (expr.tuple,)
    # WARNING: this is only valid
    # if attributes are immutable
    elif c is Attribute:
      return (expr.value,)
    elif c is Slice:
      return (expr.start, expr.stop, expr.step)
    elif c is Cast:
      return (expr.value,)
    elif c is Map or c is OuterMap or c is IndexMap:
      return expr.args
    elif c is Scan or c is Reduce or c is IndexReduce or c is FilterReduce:
      args = tuple(expr.args)
      init = (expr.init,) if expr.init else ()
      return init + args
    elif c is Call:
      # assume all Calls might modify their arguments
      if allow_mutable or all(self.immutable(arg) for arg in expr.args):
        return expr.args
      else:
        return None

    # constructors of possibly mutable data are only decomposed when the
    # caller allows it or the result type is known to be immutable
    # (Attribute needs no case here: it always returns above)
    if allow_mutable or self.immutable_type(expr.type):
      if c is Array:
        return expr.elts
      elif c is ArrayView:
        return (expr.data, expr.shape, expr.strides, expr.offset, expr.size)
      elif c is Struct:
        return expr.args
      elif c is AllocArray:
        return (expr.shape,)
    return None

  # Expression classes which `immutable` accepts as immutable, provided
  # that all of their children are immutable as well.
  _immutable_classes = set([Const,  Var, 
                            Closure, ClosureElt, 
                            Tuple, TupleProj, 
                            Cast, PrimCall, 
                            TypedFn, UntypedFn, 
                            ArrayView, 
                            Slice, 
                            ])
  
  def immutable(self, expr):
    """
    Decide whether an expression is immutable: either its class is listed
    in _immutable_classes and all of its children are immutable, or it is
    an Attribute whose type is an ImmutableT.

    TODO: make all this mutability/immutability stuff sane
    """
    klass = expr.__class__
    if klass in self._immutable_classes and \
       all(self.immutable(child) for child in expr.children()):
      return True
    return klass is Attribute and isinstance(expr.type, ImmutableT)
    
                      
  def temp(self, expr, name = None, use_count = 1):
    """
    Assign a non-simple expression to a new variable via assign_name,
    recording the variable's binding and use count.  Simple expressions
    are returned unchanged.
    """
    if not self.is_simple(expr):
      fresh = self.assign_name(expr, name = name)
      self.bindings[fresh.name] = expr
      self.use_counts[fresh.name] = use_count
      return fresh
    return expr

  def transform_expr(self, expr):
    """
    Transform an expression, replacing None-typed simple expressions by
    `none` and already computed non-simple expressions by their variable.
    """
    if self.is_simple(expr):
      if expr.type == NoneType:
        return none
    else:
      # reuse the variable holding an identical, already computed expression
      cached = self.available_expressions.get(expr)
      if cached is not None:
        return cached
    return Transform.transform_expr(self, expr)

    
  def transform_Var(self, expr):
    """
    Replace a variable by `none` / `slice_none` when its type allows it,
    otherwise follow its chain of bindings: the result is the final Var or
    Const of the chain, or the last variable before a non-simple binding.
    """
    var_type = expr.type
    type_class = var_type.__class__
    if type_class is NoneT:
      return none
    if type_class is SliceT and \
       var_type.start_type == NoneType and \
       var_type.stop_type == NoneType and \
       var_type.step_type == NoneType:
      return slice_none

    name = expr.name
    current = previous = expr
    while name in self.bindings:
      previous, current = current, self.bindings[name]
      if current.__class__ is not Var:
        break
      name = current.name

    final_class = current.__class__
    if final_class is Var or final_class is Const:
      return current
    # the chain ended in a complex expression: keep the variable naming it
    return previous
  
  def transform_Cast(self, expr):
    """
    Simplify a cast: drop it when the value already has the target type,
    fold it for constants cast to scalar types, otherwise make sure the
    value being cast is a simple expression.
    """
    inner = self.transform_expr(expr.value)
    target_type = expr.type
    if inner.type == target_type:
      return inner
    if inner.__class__ is Const and isinstance(target_type, ScalarT):
      return Const(target_type.dtype.type(inner.value), type = target_type)
    if not self.is_simple(inner):
      inner = self.assign_name(inner)
    expr.value = inner
    return expr

  def transform_Attribute(self, expr):
    """
    Resolve an attribute access directly when the value it is taken from
    is bound to a known constructor (ArrayView, AllocArray, Slice, Struct);
    otherwise rebuild the access on a simplified value.
    """
    base = self.transform_expr(expr.value)
    attr = expr.name

    if base.__class__ is Var and base.name in self.bindings:
      bound = self.bindings[base.name]
      bound_class = bound.__class__
      if bound_class is Var or bound_class is Struct:
        base = bound
      elif bound_class is ArrayView:
        if attr == 'shape' or attr == 'strides' or attr == 'data':
          return self.transform_expr(getattr(bound, attr))
      elif bound_class is AllocArray:
        if attr == 'shape':
          return self.transform_expr(bound.shape)
      elif bound_class is Slice:
        if attr == "start" or attr == "stop":
          return self.transform_expr(getattr(bound, attr))
        assert attr == "step", "Unexpected attribute for slice: %s" % attr
        return self.transform_expr(bound.step)

    if base.__class__ is Struct:
      # read the field straight out of the constructor arguments
      return base.args[base.type.field_pos(attr)]
    if base.__class__ is not Var:
      base = self.temp(base, "struct")

    if expr.value == base:
      return expr
    return Attribute(value = base, name = attr, type = expr.type)
  
  def transform_Closure(self, expr):
    """Simplify the closure's arguments in place."""
    simplified = self.transform_simple_exprs(expr.args)
    expr.args = tuple(simplified)
    return expr

  def transform_Tuple(self, expr):
    """Simplify the tuple's elements in place."""
    simplified = self.transform_simple_exprs(expr.elts)
    expr.elts = tuple(simplified)
    return expr

  def transform_TupleProj(self, expr):
    """
    Simplify a tuple projection.

    When the projected value is a variable bound to a Tuple or Struct
    constructor, the selected element is returned directly.  Otherwise the
    projected value is made simple and the projection is kept.
    """
    idx = expr.index
    assert isinstance(idx, int), \
        "TupleProj index must be an integer, got: " + str(idx)
    new_tuple = self.transform_expr(expr.tuple)

    if new_tuple.__class__ is Var and new_tuple.name in self.bindings:
      tuple_expr = self.bindings[new_tuple.name]
      if tuple_expr.__class__ is Tuple:
        assert idx < len(tuple_expr.elts), \
          "Too few elements in tuple %s : %s, elts = %s" % (expr, tuple_expr.type, tuple_expr.elts)
        return tuple_expr.elts[idx]
      elif tuple_expr.__class__ is Struct:
        # a Struct keeps its fields in `args`, so the failure message must
        # read `args` as well instead of the Tuple-only `elts` attribute
        assert idx < len(tuple_expr.args), \
          "Too few args in struct %s : %s, args = %s" % (expr, tuple_expr.type, tuple_expr.args)
        return tuple_expr.args[idx]

    if not self.is_simple(new_tuple):
      new_tuple = self.assign_name(new_tuple, "tuple")
    expr.tuple = new_tuple
    return expr

  def transform_ClosureElt(self, expr):
    """
    Simplify a closure element access, returning the argument directly
    when the closure is a variable bound to a Closure constructor.
    """
    position = expr.index
    assert isinstance(position, int), \
        "ClosureElt index must be an integer, got: " + str(position)
    closure = self.transform_expr(expr.closure)

    if closure.__class__ is Var and closure.name in self.bindings:
      bound = self.bindings[closure.name]
      if bound.__class__ is Closure:
        return bound.args[position]

    expr.closure = closure if self.is_simple(closure) \
                   else self.assign_name(closure, "closure")
    return expr

  def transform_Call(self, expr):
    """
    Simplify a call.  Calls through a closure are turned into direct calls
    of the underlying function, with the closure's elements prepended to
    the argument list; other calls only get their function and arguments
    simplified in place.
    """
    fn = self.transform_expr(expr.fn)
    args = self.transform_simple_exprs(expr.args)
    if fn.type.__class__ is ClosureT:
      # the closed-over values become leading arguments of the direct call
      closure_elts = self.closure_elts(fn)
      combined_args = tuple(closure_elts) + tuple(args)
      if fn.type.fn.__class__ is TypedFn:
        fn = fn.type.fn
      else:
        assert isinstance(fn.type.fn, UntypedFn)
        # imported here rather than at module level
        # (presumably to avoid a circular import -- TODO confirm)
        from .. type_inference import specialize 
        # NOTE(review): `fn` is still the closure expression at this point,
        # not `fn.type.fn`; confirm that specialize accepts it
        fn = specialize(fn, get_types(combined_args))
      assert fn.return_type == expr.type
      return Call(fn, combined_args, type = fn.return_type)
    else:
      expr.fn = fn
      expr.args = args
      return expr

  
  def transform_simple_expr(self, expr, name = None):
    """
    Transform an expression and make sure the result is simple, assigning
    it to a variable (called `name`, by default "temp") when it is not.
    """
    result = self.transform_expr(expr)
    if self.is_simple(result):
      return result
    return self.assign_name(result, "temp" if name is None else name)
  
  def transform_simple_exprs(self, args):
    """Apply transform_simple_expr to each expression, returning a list."""
    simplified = []
    for arg in args:
      simplified.append(self.transform_simple_expr(arg))
    return simplified

  def transform_Array(self, expr):
    """Simplify the array literal's elements in place."""
    simplified = self.transform_simple_exprs(expr.elts)
    expr.elts = tuple(simplified)
    return expr
  
  def transform_Slice(self, expr):
    """Simplify start, stop and step (in that order) in place."""
    for field in ('start', 'stop', 'step'):
      simplified = self.transform_simple_expr(getattr(expr, field))
      setattr(expr, field, simplified)
    return expr

  
  def transform_index_expr(self, expr):
    """
    Transform an index expression.  The elements of a tuple index are
    transformed one by one; non-simple elements are moved into temporaries
    unless they are slices.
    """
    if expr.__class__ is not Tuple:
      return self.transform_expr(expr)

    simplified = []
    for elt in expr.elts:
      elt = self.transform_expr(elt)
      needs_temp = not self.is_simple(elt) and \
                   elt.type.__class__ is not SliceT
      if needs_temp:
        elt = self.temp(elt, "index_tuple_elt")
      simplified.append(elt)
    expr.elts = tuple(simplified)
    return expr
  
  def transform_Index(self, expr):
    """
    Simplify an indexing expression: constant indexing into an array
    literal is resolved directly, nested indexing of an array is flattened
    into a single index, and the indexed value is forced to be a variable.
    """
    expr.value = self.transform_expr(expr.value)

    expr.index = self.transform_index_expr(expr.index)

    # constant index into an array literal: just pick the element
    if expr.value.__class__ is Array and expr.index.__class__ is Const:
      assert isinstance(expr.index.value, (int, long)) and \
             len(expr.value.elts) > expr.index.value
      return expr.value.elts[expr.index.value]
    
    # take expressions like "a[i][j]" and turn them into "a[i,j]" 
    if expr.value.__class__ is Index: 
      base_array = expr.value.value
      if isinstance(base_array.type, ArrayT):
        # collect the indices of the inner access...
        base_index = expr.value.index 
        if isinstance(base_index.type, TupleT):
          indices = self.tuple_elts(base_index)
        else:
          assert isinstance(base_index.type, ScalarT), \
            "Unexpected index type %s : %s in %s" % (base_index, base_index.type, expr)
          indices = [base_index]
        # ...followed by the indices of the outer access
        if isinstance(expr.index.type, TupleT):
          indices = tuple(indices) + tuple(self.tuple_elts(expr.index))
        else:
          assert isinstance(expr.index.type, ScalarT), \
            "Unexpected index type %s : %s in %s" % (expr.index, expr.index.type, expr)
          indices = tuple(indices) + (expr.index,)
        expr = self.index(base_array, self.tuple(indices))
        # the combined access is transformed again
        return self.transform_expr(expr)
    # anything other than a variable being indexed gets a temporary
    if expr.value.__class__ is not Var:
      expr.value = self.temp(expr.value, "array")
    return expr

  def transform_Struct(self, expr):
    """Build a new Struct of the same type from simplified arguments."""
    simplified_args = self.transform_simple_exprs(expr.args)
    struct_type = expr.type
    return syntax.Struct(simplified_args, type = struct_type)

  def transform_Select(self, expr):
    """
    Simplify a Select: pick a branch when the condition is a known boolean
    or both branches are equal, otherwise update the node in place.
    """
    condition = self.transform_expr(expr.cond)
    if_true = self.transform_expr(expr.true_value)
    if_false = self.transform_expr(expr.false_value)

    if is_true(condition):
      return if_true
    if is_false(condition):
      return if_false
    if if_true == if_false:
      # the condition doesn't matter
      return if_true

    expr.cond = condition
    expr.false_value = if_false
    expr.true_value = if_true
    return expr
  
  def transform_PrimCall(self, expr):
    """
    Simplify a primitive call.

    Calls whose arguments are all constants are folded into a Const by
    applying the primitive's `fn`.  Otherwise a set of algebraic identities
    is tried (x + 0, x * 1, x ** 2, boolean identities, ...).  When none of
    them applies the call is returned with simplified arguments.
    """
    args = self.transform_simple_exprs(expr.args)
    prim = expr.prim
    if all_constants(args):
      return syntax.Const(value = prim.fn(*collect_constants(args)),
                          type = expr.type)
    
    if len(args) == 1:
      x = args[0]
      if prim == prims.logical_not:
        if is_false(x):
          return true 
        elif is_true(x):
          return false 
    if len(args) == 2:
      x,y = args 
      
      if prim == prims.add:
        if is_zero(x):
          return y
        elif is_zero(y):
          return x
        # x + (-c) --> x - c
        if y.__class__ is Const and y.value < 0:
          expr.prim = prims.subtract
          expr.args = (x, Const(value = -y.value, type = y.type))
          return expr 
        # (-c) + y --> y - c
        elif x.__class__ is Const and x.value < 0:
          expr.prim = prims.subtract
          expr.args = (y, Const(value = -x.value, type = x.type)) 
          return expr 
        
      elif prim == prims.subtract:
        if is_zero(y):
          return x
        elif is_zero(x) and y.__class__ is Var:
          stored = self.bindings.get(y.name)
          
          # 0 - (a * b) --> -a * b |or| a * -b 
          if stored and stored.__class__ is PrimCall and stored.prim == prims.multiply:
            a,b = stored.args
            if a.__class__ is Const:
              expr.prim = prims.multiply
              neg_a = Const(value = -a.value, type = a.type)
              expr.args = [neg_a, b]
              return expr 
            elif b.__class__ is Const:
              expr.prim = prims.multiply
              neg_b = Const(value = -b.value, type = b.type)
              expr.args = [a, neg_b]
              return expr 
            
      elif prim == prims.multiply:
        if is_one(x):
          return y
        elif is_one(y):
          return x
        elif is_zero(x):
          return x
        elif is_zero(y):
          return y
        
      elif prim == prims.divide and is_one(y):
        return x
      
      elif prim == prims.power:
        if is_one(y):
          return self.cast(x, expr.type)
        elif is_zero(y):
          return one(expr.type)
        elif y.__class__ is Const and y.value == 2:
          return self.cast(self.mul(x, x, "sqr"), expr.type)

      elif prim == prims.logical_and:
        if is_true(x):
          return y 
        elif is_false(x) or is_false(y):
          return false 
        elif is_true(y):
          return x

      elif prim == prims.logical_or:
        if is_true(x) or is_true(y):
          return true 
        # a false operand is the identity of `or`: the result is the OTHER
        # operand, not false
        elif is_false(x):
          return y
        elif is_false(y):
          return x
    expr.args = args
    return expr 
  
  def transform_Map(self, expr):
    """Simplify the arguments, function and axis of a Map in place."""
    simple_args = self.transform_simple_exprs(expr.args)
    expr.args = simple_args
    new_fn = self.transform_expr(expr.fn)
    expr.fn = new_fn
    new_axis = self.transform_if_expr(expr.axis)
    expr.axis = new_axis
    return expr
  
  def transform_OuterMap(self, expr):
    """Simplify the arguments, function and axis of an OuterMap in place."""
    simple_args = self.transform_simple_exprs(expr.args)
    expr.args = simple_args
    new_fn = self.transform_expr(expr.fn)
    expr.fn = new_fn
    new_axis = self.transform_if_expr(expr.axis)
    expr.axis = new_axis
    return expr
  
  def transform_shape(self, expr):
    """
    Simplify a shape: a Tuple keeps its structure and gets simple elements,
    anything else is turned into a simple expression as a whole.
    """
    if not isinstance(expr, Tuple):
      return self.transform_simple_expr(expr)
    expr.elts = tuple(self.transform_simple_exprs(expr.elts))
    return expr
  
  def transform_ParFor(self, stmt):
    """Simplify the bounds and the function of a ParFor statement in place."""
    new_bounds = self.transform_shape(stmt.bounds)
    stmt.bounds = new_bounds
    new_fn = self.transform_expr(stmt.fn)
    stmt.fn = new_fn
    return stmt
  
  def transform_IndexMap(self, expr):
    """Simplify the function and the shape of an IndexMap in place."""
    new_fn = self.transform_expr(expr.fn)
    expr.fn = new_fn
    new_shape = self.transform_shape(expr.shape)
    expr.shape = new_shape
    return expr
  
  def transform_IndexReduce(self, expr):
    """Simplify fn, combine, init and shape of an IndexReduce in place."""
    new_fn = self.transform_if_expr(expr.fn)
    expr.fn = new_fn
    new_combine = self.transform_expr(expr.combine)
    expr.combine = new_combine
    new_init = self.transform_if_expr(expr.init)
    expr.init = new_init
    new_shape = self.transform_shape(expr.shape)
    expr.shape = new_shape
    return expr
  
  def transform_IndexScan(self, expr):
    """Simplify fn, combine, emit, init and shape of an IndexScan in place."""
    new_fn = self.transform_if_expr(expr.fn)
    expr.fn = new_fn
    new_combine = self.transform_expr(expr.combine)
    expr.combine = new_combine
    new_emit = self.transform_if_expr(expr.emit)
    expr.emit = new_emit
    new_init = self.transform_if_expr(expr.init)
    expr.init = new_init
    new_shape = self.transform_shape(expr.shape)
    expr.shape = new_shape
    return expr
  
     
  def transform_ConstArray(self, expr):
    """Simplify the shape and the fill value of a ConstArray in place."""
    new_shape = self.transform_shape(expr.shape)
    expr.shape = new_shape
    new_value = self.transform_simple_expr(expr.value)
    expr.value = new_value
    return expr
  
  def transform_ConstArrayLike(self, expr):
    """
    Simplify the template array and the fill value of a ConstArrayLike in
    place and return the updated expression.
    """
    expr.array = self.transform_simple_expr(expr.array)
    expr.value = self.transform_simple_expr(expr.value)
    # like every other transform_* method, hand the node back to the caller
    # (without this the method returned None)
    return expr
  
  def transform_Reduce(self, expr):
    """
    Simplify a Reduce in place: the initial value is made simple, then the
    arguments, the function and the combiner are transformed.
    """
    init = self.transform_expr(expr.init)
    expr.init = init if self.is_simple(init) \
                else self.assign_name(init, 'init')
    expr.args = self.transform_simple_exprs(expr.args)
    expr.fn = self.transform_expr(expr.fn)
    expr.combine = self.transform_expr(expr.combine)
    return expr
  
  def temp_in_block(self, expr, block, name = None):
    """
    Create a temporary for `expr` whose assignment is appended to the given
    block rather than to the current top scope.  The new variable is also
    recorded in the bindings dictionary.
    """
    var_name = "temp" if name is None else name
    fresh = self.fresh_var(expr.type, var_name)
    block.append(Assign(fresh, expr))
    self.bindings[fresh.name] = expr
    return fresh

  def set_binding(self, name, value):
    """Bind `name` to `value`, refusing to bind a variable to itself."""
    self_reference = value.__class__ is Var and value.name == name
    assert not self_reference, \
        "Can't set name %s bound to itself" % name
    self.bindings[name] = value

  def bind_var(self, name, rhs):
    """
    Bind `name` to `rhs`.  When `rhs` is a variable which is itself bound
    to a simple expression, bind to that expression instead.
    """
    value = rhs
    if rhs.__class__ is Var:
      known = self.bindings.get(rhs.name)
      if known and self.is_simple(known):
        value = known
    self.set_binding(name, value)

  def bind(self, lhs, rhs):
    """
    Record bindings for an assignment target: a Var is bound directly and a
    Tuple target is matched element-wise against a Tuple right-hand side.
    Any other combination is ignored.
    """
    target_class = lhs.__class__
    if target_class is Var:
      self.bind_var(lhs.name, rhs)
      return
    if target_class is Tuple and rhs.__class__ is Tuple:
      assert len(lhs.elts) == len(rhs.elts)
      for target_elt, value_elt in zip(lhs.elts, rhs.elts):
        self.bind(target_elt, value_elt)

  def transform_lhs_Index(self, lhs):
    """
    Simplify an indexing expression used as an assignment target: the index
    is transformed and the indexed value is made a variable.
    """
    lhs.index = self.transform_index_expr(lhs.index)
    array = lhs.value
    if array.__class__ is not Var:
      lhs.value = self.assign_name(array, "array")
      return lhs
    # a variable bound to another variable is replaced by that variable
    bound = self.bindings.get(array.name)
    if bound and bound.__class__ is Var:
      lhs.value = bound
    return lhs

  def transform_lhs_Attribute(self, lhs):
    """Return an attribute assignment target unchanged."""
    # transforming the value is currently disabled:
    # lhs.value = self.transform_expr(lhs.value)
    return lhs

  def transform_ExprStmt(self, stmt):
    """Drop an expression statement unless it may have a side effect."""
    value = self.transform_expr(stmt.value)
    if not self.immutable(value):
      stmt.value = value
      return stmt
    # an immutable expression whose result is unused does nothing
    return None

  def transform_Assign(self, stmt):
    """
    Simplify an assignment.  Depending on the kind of target this records
    bindings / available expressions, removes assignments without effect,
    or expands the assignment of an array literal into element-wise
    assignments.  Returns the updated statement, or None when the
    statement was removed or replaced by statements emitted directly.
    """
    
    lhs = stmt.lhs
    rhs = self.transform_expr(stmt.rhs)
    lhs_class = lhs.__class__
    rhs_class = rhs.__class__

    if lhs_class is Var:
      # an unused None-typed variable: only keep the right-hand side
      # (as an expression statement, which may itself get dropped)
      if lhs.type.__class__ is NoneT and self.use_counts.get(lhs.name,0) == 0:
        return self.transform_stmt(ExprStmt(rhs))
      elif self.immutable(rhs):
        
        self.bind_var(lhs.name, rhs)
        # remember which variable holds this expression so that later
        # occurrences can reuse it
        if rhs_class is not Var and rhs_class is not Const:
          self.available_expressions.setdefault(rhs, lhs)
    elif lhs_class is Tuple:
      self.bind(lhs, rhs)

    elif lhs_class is Index:
      if rhs_class is Index and \
         lhs.value == rhs.value and \
         lhs.index == rhs.index:
        # kill effect-free writes like x[i] = x[i]
        return None
      elif rhs_class is Var and \
           lhs.value.__class__ is Var and \
           lhs.value.name == rhs.name and \
           lhs.index.type.__class__ is TupleT and \
           all(elt_t == slice_none_t for elt_t in lhs.index.type.elt_types):
        # also kill x[:] = x
        return None
      else:
        lhs = self.transform_lhs_Index(lhs)
        # when assigning x[j] = [1,2,3]
        # just rewrite it as a sequence of element assignments 
        if lhs.type.__class__ is ArrayT and \
           lhs.type.rank == 1 and \
           rhs.__class__ is Array:
          lhs_slice = self.assign_name(lhs, "lhs_slice")
          for (elt_idx, elt) in enumerate(rhs.elts):
            lhs_idx = self.index(lhs_slice, const_int(elt_idx), temp = False)
            self.assign(lhs_idx, elt)
          # the element assignments replace the original statement
          return None
        elif not self.is_simple(rhs):
          rhs = self.assign_name(rhs)
    else:
      assert lhs_class is Attribute
      # attribute assignment is deliberately rejected; the line after the
      # failing assert can only run when assertions are disabled
      assert False, "Considering making attributes immutable"
      lhs = self.transform_lhs_Attribute(lhs)

    # a right-hand side variable with a known binding that is used only
    # here is replaced by the expression it is bound to.  Note that
    # rhs_class was captured before any re-assignment of rhs above.
    if rhs_class is Var and \
       rhs.name in self.bindings and \
       self.use_counts.get(rhs.name, 1) == 1:
      self.use_counts[rhs.name] = 0
      rhs = self.bindings[rhs.name]
    stmt.lhs = lhs
    stmt.rhs = rhs
    return stmt

  def transform_block(self, stmts, keep_bindings = False):
    """
    Transform a block inside fresh scopes for available expressions and
    bindings.  With `keep_bindings` the bindings scope is left open, so the
    caller becomes responsible for popping it.
    """
    for scoped in (self.available_expressions, self.bindings):
      scoped.push()

    transformed = Transform.transform_block(self, stmts)

    self.available_expressions.pop()
    if keep_bindings:
      return transformed
    self.bindings.pop()
    return transformed

  def enter_loop(self, phi_nodes):
    """
    Transform the left (loop-entry) value of every phi node.  Nodes whose
    two values turn out to be equal become plain bindings; the remaining
    nodes are returned as a new dictionary.
    """
    remaining = {}
    for name, (incoming, carried) in phi_nodes.iteritems():
      simplified = self.transform_expr(incoming)
      if simplified == carried:
        self.set_binding(name, simplified)
        continue
      remaining[name] = (simplified, carried)
    return remaining
  
  def transform_merge(self, phi_nodes, left_block, right_block):
    """
    Simplify phi nodes.  Both incoming values are transformed and, when not
    a Const or Var, stored in a temporary at the end of their block.  A phi
    node whose two values are equal is replaced by a plain assignment and
    binding; the others are returned as a new dictionary.
    """
    kept = {}
    simple_classes = (Const, Var)
    for name, (left, right) in phi_nodes.iteritems():
      left_value = self.transform_expr(left)
      right_value = self.transform_expr(right)

      if not isinstance(left_value, simple_classes):
        left_value = self.temp_in_block(left_value, left_block)
      if not isinstance(right_value, simple_classes):
        right_value = self.temp_in_block(right_value, right_block)

      if left_value == right_value:
        # both control flows yield the same value, so the phi-bound
        # variable is not needed: the left value can be used everywhere
        self.assign(Var(name= name, type = left_value.type), left_value)
        self.set_binding(name, left_value)
      else:
        kept[name] = left_value, right_value
    return kept

  def transform_If(self, stmt):
    """
    Simplify an If statement.  When both branches end up empty and at most
    two variables are merged, the statement is replaced by Select
    assignments and None is returned.
    """
    # both branches keep their bindings scope open so that the merge can
    # still see what was bound inside them
    stmt.true = self.transform_block(stmt.true, keep_bindings = True)
    stmt.false = self.transform_block(stmt.false, keep_bindings=True)
    stmt.merge = self.transform_merge(stmt.merge,
                                      left_block = stmt.true,
                                      right_block = stmt.false)
    # close the two scopes left open by transform_block above
    self.bindings.pop()
    self.bindings.pop()
    stmt.cond = self.transform_simple_expr(stmt.cond, "cond")
    if len(stmt.true) == 0 and len(stmt.false) == 0 and len(stmt.merge) <= 2:
      # nothing is executed in either branch: express each merged
      # variable as a Select on the condition
      for (lhs_name, (true_expr, false_expr)) in stmt.merge.items():
        lhs_type = self.lookup_type(lhs_name)
        lhs_var = Var(name = lhs_name, type = lhs_type)
        assert true_expr.type == false_expr.type, \
          "Unexpcted type mismatch: %s != %s" % (true_expr.type, false_expr.type)
        rhs = Select(stmt.cond, true_expr, false_expr, type = true_expr.type)
        self.bind_var(lhs_name, rhs)
        self.assign(lhs_var, rhs)
      return None 
    return stmt

  def transform_loop_condition(self, expr, outer_block, loop_body, merge):
    """
    Normalize loop conditions so they are just simple variables.

    A non-simple condition which mentions loop-carried variables is
    evaluated once before the loop (in `outer_block`) and once at the end
    of `loop_body`; both results are merged into a fresh condition
    variable, which gets added to `merge` and is returned.
    """

    if self.is_simple(expr):
      return self.transform_expr(expr)
    else:
      loop_carried_vars = [name for name in collect_var_names(expr)
                           if name in merge]
      # a condition without loop-carried variables is returned as it is
      if len(loop_carried_vars) == 0:
        return expr

      left_values = [merge[name][0] for name in loop_carried_vars]
      right_values = [merge[name][1] for name in loop_carried_vars]

      # the condition in terms of the values entering the loop
      left_cond = subst.subst_expr(expr, dict(zip(loop_carried_vars,
                                                  left_values)))
      if not self.is_simple(left_cond):
        left_cond = self.temp_in_block(left_cond, outer_block, name = "cond")

      # the condition in terms of the values produced by an iteration
      right_cond = subst.subst_expr(expr, dict(zip(loop_carried_vars,
                                                   right_values)))
      if not self.is_simple(right_cond):
        right_cond = self.temp_in_block(right_cond, loop_body, name = "cond")

      # note: `merge` is modified in place
      cond_var = self.fresh_var(left_cond.type, "cond")
      merge[cond_var.name] = (left_cond, right_cond)
      return cond_var

    
  def transform_While(self, stmt):
    """
    Simplify a while loop in place: first its phi nodes and body, then the
    condition, which is normalized against the simplified phi nodes.
    """
    entry_merge = self.enter_loop(stmt.merge)
    stmt.body = self.transform_block(stmt.body)
    stmt.merge = self.transform_merge(entry_merge,
                                      left_block = self.blocks.current(),
                                      right_block = stmt.body)
    new_cond = self.transform_loop_condition(stmt.cond,
                                             outer_block = self.blocks.current(),
                                             loop_body = stmt.body,
                                             merge = stmt.merge)
    stmt.cond = new_cond
    return stmt

  def transform_ForLoop(self, stmt):
    """
    Simplify a for loop in place.

    The phi nodes, body, start, stop and step are simplified (a missing
    step becomes one).  When start, stop and step are all constants and the
    loop runs zero times or exactly once, the loop is removed: None is
    returned and the replacement statements are emitted directly.  The
    direction of the step is taken into account, so loops counting down
    with a negative step are not mistaken for empty loops.
    """
    merge = self.enter_loop(stmt.merge)
    stmt.body = self.transform_block(stmt.body)
    stmt.merge = self.transform_merge(merge,
                                      left_block = self.blocks.current(),
                                      right_block = stmt.body)
    stmt.start = self.transform_simple_expr(stmt.start, 'start')
    stmt.stop = self.transform_simple_expr(stmt.stop, 'stop')
    if self.is_none(stmt.step):
      stmt.step = one(stmt.start.type)
    else:
      stmt.step = self.transform_simple_expr(stmt.step, 'step')

    # if a loop is only going to run for at most one iteration, might as
    # well get rid of it
    if stmt.start.__class__ is Const and \
       stmt.stop.__class__ is Const and \
       stmt.step.__class__ is Const:
      start = stmt.start.value
      stop = stmt.stop.value
      step = stmt.step.value
      if step < 0:
        # counting down: the loop is empty when start is already at or
        # below stop
        no_iterations = start <= stop
        one_iteration = start + step <= stop
      else:
        no_iterations = start >= stop
        one_iteration = start + step >= stop

      if no_iterations or one_iteration:
        # the phi variables take the values they had on loop entry
        # NOTE(review): this is also done in the one-iteration case, where
        # the body does run once -- confirm that the loop-carried values
        # are not needed after the loop
        for (var_name, (input_value, _)) in stmt.merge.iteritems():
          var = Var(var_name, input_value.type)
          self.blocks.append(Assign(var, input_value))
        if no_iterations:
          return None
        # single iteration: bind the loop variable and inline the body
        self.assign(stmt.var, stmt.start)
        self.blocks.top().extend(stmt.body)
        return None
    return stmt

  def transform_Return(self, stmt):
    """Simplify the returned value, updating the statement if it changed."""
    simplified = self.transform_expr(stmt.value)
    changed = simplified != stmt.value
    if changed:
      stmt.value = simplified
    return stmt
Ejemplo n.º 3
0
class AST_Translator(ast.NodeVisitor):
  def __init__(self, 
               globals_dict=None, 
               closure_cell_dict=None,
               parent=None, 
               function_name = None, 
               filename = None):
    """
    Set up the translator state and open the first scope/block pair.

    globals_dict and closure_cell_dict are the Python-side name tables
    consulted by the lookup methods; parent is the translator of the enclosing
    function (if any); function_name and filename are kept for error reports.
    """
    # origin of the code being translated, passed along by generic_visit
    self.function_name = function_name
    self.filename = filename

    self.parent = parent
    self.closure_cell_dict = closure_cell_dict
    # the same globals dictionary is stored under two attribute names
    self.globals = globals_dict
    self.globals_dict = globals_dict

    # mapping from names/paths to either a closure cell reference or a
    # global value
    self.python_refs = OrderedDict()

    # outer names which were given a local alias, and those aliases
    self.original_outer_names = []
    self.localized_outer_names = []

    self.scopes = ScopedDict()
    self.blocks = NestedBlocks()
    self.push()

  def push(self, scope = None, block = None):
    if scope is None:
      scope = {}
    if block is None:
      block = []
    self.scopes.push(scope)
    self.blocks.push(block)


  def pop(self):
    scope = self.scopes.pop()
    block = self.blocks.pop()
    return scope, block
  
  def fresh_name(self, original_name):
    """
    Ask names.fresh for a new name derived from original_name, record it in
    self.scopes under original_name, and return it.
    """
    unique = names.fresh(original_name)
    self.scopes[original_name] = unique
    return unique
    
  def fresh_names(self, original_names):
    return map(self.fresh_name, original_names)

  def fresh_var(self, name):
    """
    Register a fresh name for the given name and wrap it in a syntax.Var.
    """
    unique = self.fresh_name(name)
    return syntax.Var(unique)

  def fresh_vars(self, original_names):
    return map(self.fresh_var, original_names)
  
  def current_block(self):
    return self.blocks.top()

  def current_scope(self):
    return self.scopes.top()


  def syntax_to_value(self, expr):
    if isinstance(expr, ast.Num):
      return expr.n
    elif isinstance(expr, ast.Tuple):
      return tuple(self.syntax_to_value(elt) for elt in expr.elts)
    elif isinstance(expr, ast.Name):
      return self.lookup_global(expr.id)
    elif isinstance(expr, ast.Attribute):
      left = self.syntax_to_value(expr.value)
      if isinstance(left, ExternalValue):
        left = left.value 
      return getattr(left, expr.attr) 

  def lookup_global(self, key):
    if isinstance(key, (list, tuple)):
      assert len(key) == 1
      key = key[0]
    else:
      assert isinstance(key, str), "Invalid global key: %s" % (key,)

    if self.globals:
      if key in self.globals:
        return self.globals[key]
      elif key in __builtin__.__dict__:

        return __builtin__.__dict__[key]
      else:
        assert False, "Couldn't find global name %s" % key
    else:
      assert self.parent is not None
      return self.parent.lookup_global(key)
    
  def is_global(self, key):
    if isinstance(key, (list, tuple)):
      key = key[0]
    if key in self.scopes:
      return False
    elif self.closure_cell_dict and key in self.closure_cell_dict:
      return False
    if self.globals:
      return key in self.globals or key in __builtins__  
    assert self.parent is not None 
    return self.parent.is_global(key)
  
  
  def local_ref_name(self, ref, python_name):
    """
    Return a Var naming the local alias of an external reference.

    If an equal reference was registered before, its existing alias is
    reused; otherwise a fresh alias is created and recorded in the scope, in
    python_refs and in the lists of outer names.
    """
    for (known_name, known_ref) in self.python_refs.iteritems():
      if ref == known_ref:
        return Var(known_name)

    alias = names.fresh(python_name)
    self.scopes[python_name] = alias
    self.original_outer_names.append(python_name)
    self.localized_outer_names.append(alias)
    self.python_refs[alias] = ref
    return Var(alias)
  
  def lookup(self, name):
    """
    Resolve a Python name to a syntax node (or an ExternalValue wrapper).

    Tried in order: local scopes, the enclosing translator, closure cells,
    then globals. Raises NameNotFound when none of them applies.
    """
    if name in self.scopes:
      return Var(self.scopes[name])

    if self.parent:
      # don't actually keep the outer binding name, we just
      # need to check that it's possible and tell the outer scope
      # to register any necessary python refs
      alias = names.fresh(name)
      self.scopes[name] = alias
      self.original_outer_names.append(name)
      self.localized_outer_names.append(alias)
      return Var(alias)

    if self.closure_cell_dict and name in self.closure_cell_dict:
      cell_ref = ClosureCellRef(self.closure_cell_dict[name], name)
      return self.local_ref_name(cell_ref, name)

    if not self.is_global(name):
      raise NameNotFound(name)

    value = self.lookup_global(name)
    if is_static_value(value):
      return value_to_syntax(value)
    if isinstance(value, np.ndarray): 
      return self.local_ref_name(GlobalValueRef(value), name)
    # assume that this is a module or object which will have some 
    # statically convertible value pulled out of it 
    return ExternalValue(value)
      
      
  def visit_list(self, nodes):
    return map(self.visit, nodes)


  def tuple_arg_assignments(self, elts, var):
    """
    Recursively decompose a nested tuple argument like
      def f((x,(y,z))):
        ...
    into a single name and a series of assignments:
      def f(tuple_arg):
        x = tuple_arg[0]
        tuple_arg_elt = tuple_arg[1]
        y = tuple_arg_elt[0]
        z = tuple_arg_elt[1]
    """

    assignments = []
    for (i, sub_arg) in enumerate(elts):
      if isinstance(sub_arg, ast.Tuple):
        name = "tuple_arg_elt"
      else:
        assert isinstance(sub_arg, ast.Name)
        name = sub_arg.id
      lhs = self.fresh_var(name)
      stmt = Assign(lhs, syntax.Index(var, syntax.Const(i)))
      assignments.append(stmt)
      if isinstance(sub_arg, ast.Tuple):
        more_stmts = self.tuple_arg_assignments(sub_arg.elts, lhs)
        assignments.extend(more_stmts)
    return assignments

  def translate_args(self, args):
    assert not args.kwarg
    formals = FormalArgs()
    assignments = []
    for arg in args.args:
      if isinstance(arg, ast.Name):
        visible_name = arg.id
        local_name = self.fresh_name(visible_name)
        formals.add_positional(local_name, visible_name)
      else:
        assert isinstance(arg, ast.Tuple)
        arg_name = self.fresh_name("tuple_arg")
        formals.add_positional(arg_name)
        var = Var(arg_name)
        stmts = self.tuple_arg_assignments(arg.elts, var)
        assignments.extend(stmts)

    n_defaults = len(args.defaults)
    if n_defaults > 0:
      local_names = formals.positional[-n_defaults:]
      for (k,expr) in zip(local_names, args.defaults):
        formals.defaults[k] = self.syntax_to_value(expr)


    if args.vararg:
      assert isinstance(args.vararg, str)
      formals.starargs = self.fresh_name(args.vararg)

    return formals, assignments


  def visit_Name(self, expr):
    assert isinstance(expr, ast.Name), "Expected AST Name object: %s" % expr
    return self.lookup(expr.id)

  def create_phi_nodes(self, left_scope, right_scope, new_names = {}):
    """
    Phi nodes make explicit the possible sources of each variable's values and
    are needed when either two branches merge or when one was optionally taken.
    """
    merge = {}
    for (name, ssa_name) in left_scope.iteritems():
      left = Var(ssa_name)
      if name in right_scope:
        right = Var(right_scope[name])
      else:
        try:
          right = self.lookup(name)
        except NameNotFound:
          continue 
          
      if name in new_names:
        new_name = new_names[name]
      else:
        new_name = self.fresh_name(name)
      merge[new_name] = (left, right)

    for (name, ssa_name) in right_scope.iteritems():
      if name not in left_scope:
        try:
          left = self.lookup(name)
          right = Var(ssa_name)

          if name in new_names:
            new_name = new_names[name]
          else:
            new_name = self.fresh_name(name)
          merge[new_name] = (left, right)
        except names.NameNotFound:
          # for now skip over variables which weren't defined before
          # a control flow split, which means that loop-local variables
          # can't be used after the loop.
          # TODO: Fix this. Maybe with 'undef' nodes?
          pass
    return merge


  def visit_Index(self, expr):
    return self.visit(expr.value)

  def visit_Ellipsis(self, expr):
    raise RuntimeError("Ellipsis operator unsupported")

  def visit_Slice(self, expr):
    """
    x[l:u:s]
    Optional fields
      expr.lower
      expr.upper
      expr.step

    Each missing field is replaced by the `none` constant.
    """
    def bound(node):
      # translate a slice component if present
      return self.visit(node) if node else none

    start = bound(expr.lower)
    stop = bound(expr.upper)
    step = bound(expr.step)
    return syntax.Slice(start, stop, step)

  def visit_ExtSlice(self, expr):
    """
    Translate a multi-dimensional slice: several dimensions become a tuple,
    a single dimension is returned on its own.
    """
    dims = [self.visit(dim) for dim in expr.dims]
    if len(dims) <= 1:
      return dims[0]
    return syntax.Tuple(dims)

  def visit_UnaryOp(self, expr):
    ssa_val = self.visit(expr.operand)
    # UAdd doesn't do anything!
    
    if expr.op.__class__.__name__ == 'UAdd':
      return ssa_val 
    prim = prims.find_ast_op(expr.op)
    return syntax.PrimCall(prim, [ssa_val])

  def visit_BinOp(self, expr):
    """
    Translate a binary operation into a two-argument primitive call.
    """
    operands = [self.visit(expr.left), self.visit(expr.right)]
    prim = prims.find_ast_op(expr.op)
    return syntax.PrimCall(prim, operands)

  def visit_BoolOp(self, expr):
    """
    Translate 'a op b op c ...' into a left-nested chain of two-argument
    primitive calls.
    """
    operands = [self.visit(value) for value in expr.values]
    prim = prims.find_ast_op(expr.op)
    # Python, strangely, allows more than two arguments to
    # Boolean operators
    combined = operands[0]
    for operand in operands[1:]:
      combined = syntax.PrimCall(prim, [combined, operand])
    return combined

  def visit_Compare(self, expr):
    lhs = self.visit(expr.left)
    assert len(expr.ops) == 1
    prim = prims.find_ast_op(expr.ops[0])
    assert len(expr.comparators) == 1
    rhs = self.visit(expr.comparators[0])
    return syntax.PrimCall(prim, [lhs, rhs])

  def visit_Subscript(self, expr):
    """
    Translate 'value[slice]' into an Index node, translating the indexed
    value before the index expression.
    """
    indexed = self.visit(expr.value)
    position = self.visit(expr.slice)
    return syntax.Index(indexed, position)

  def generic_visit(self, expr):
    """
    Fallback for node types without a dedicated visit method: always raises
    UnsupportedSyntax tagged with the current function name and file name.
    """
    error = UnsupportedSyntax(expr, 
                              function_name = self.function_name, 
                              filename = self.filename)
    raise error
  
  def translate_builtin_call(self, value, positional, keywords_dict):
    """
    Translate a call to a Python value which is neither a macro nor a user
    function.

    sum, max, min and map get dedicated translations; values with a macro
    registered in function_mappings are expanded by that macro; anything else
    becomes an ordinary Call on the value converted to syntax.
    """
    from ..mappings import function_mappings     
    if value is sum:
      return mk_reduce_call(prim_wrapper(prims.add), positional, zero_i64)

    if value is max or value is min:
      prim = prims.maximum if value is max else prims.minimum
      if len(positional) == 1:
        # single argument: reduce over it
        return mk_reduce_call(prim_wrapper(prim), positional)
      assert len(positional) == 2
      return syntax.PrimCall(prim, positional)

    if value is map:
      # NOTE(review): the assertion forbids keywords, so the 'axis' lookup
      # below can only produce None while assertions are enabled
      assert len(keywords_dict) == 0
      assert len(positional) > 1
      axis = keywords_dict.get("axis", None)
      return Map(fn = positional[0], args = positional[1:], axis = axis)

    if value in function_mappings:
      parakeet_equiv = function_mappings[value]
      if isinstance(parakeet_equiv, macro):
        return parakeet_equiv.transform(positional, keywords_dict)

    fn = value_to_syntax(value)
    return syntax.Call(fn, ActualArgs(positional, keywords_dict))
    
      
  def visit(self, node):
    res = ast.NodeVisitor.visit(self, node)
    return res 
    
  def translate_value_call(self, value, positional, keywords_dict = None, starargs_expr = None):
    """
    Translate a call whose callee is a known Python value.

    The value is first replaced by its entry in function_mappings, if it has
    one. Macros are expanded, user functions are translated and called, and
    everything else is handed to translate_builtin_call.

    keywords_dict defaults to a new empty dict on every call, so the dict
    handed to ActualArgs or to a macro is never shared between call sites.
    """
    from ..mappings import function_mappings 
    if keywords_dict is None:
      keywords_dict = {}

    if value in function_mappings:
      value = function_mappings[value]
      
    if isinstance(value, macro):
      return value.transform(positional, keywords_dict)
    elif is_user_function(value):
      return syntax.Call(translate_function_value(value), 
                           ActualArgs(positional, keywords_dict, starargs_expr))
    else:
      return self.translate_builtin_call(value, positional, keywords_dict)
    
  def visit_Call(self, expr):
    """
    TODO: 
    The logic here is broken and haphazard, eventually try to handle nested
    scopes correctly, along with globals, cell refs, etc..
    """
    
    fn, args, keywords_list, starargs, kwargs = \
        expr.func, expr.args, expr.keywords, expr.starargs, expr.kwargs
    assert kwargs is None, "Dictionary of keyword args not supported"

    positional = self.visit_list(args)

    keywords_dict = {}
    for kwd in keywords_list:
      keywords_dict[kwd.arg] = self.visit(kwd.value)

    if starargs:
      starargs_expr = self.visit(starargs)
    else:
      starargs_expr = None
    
    def is_attr_chain(expr):
      return isinstance(expr, ast.Name) or \
             (isinstance(expr, ast.Attribute) and is_attr_chain(expr.value))
    def extract_attr_chain(expr):
      if isinstance(expr, ast.Name):
        return [expr.id]
      else:
        base = extract_attr_chain(expr.value)
        base.append(expr.attr)
        return base
    
    def lookup_attr_chain(names):
      value = self.lookup_global(names[0])
    
      for name in names[1:]:
        if hasattr(value, name):
          value = getattr(value, name)
        else:
          try:
            value = value[name]
          except:
            assert False, "Couldn't find global name %s" % ('.'.join(names))
      return value
          
    if is_attr_chain(fn):
      names = extract_attr_chain(fn)
      
      if self.is_global(names):
        return self.translate_value_call(lookup_attr_chain(names), 
                                         positional, keywords_dict, starargs_expr)
    fn_node = self.visit(fn)    
    if isinstance(fn_node, syntax.Expr):
      actuals = ActualArgs(positional, keywords_dict, starargs_expr)
      return syntax.Call(fn_node, actuals)
    else:
      assert isinstance(fn_node, ExternalValue)
      return self.translate_value_call(fn_node.value, 
                                       positional, keywords_dict, starargs_expr)

  def visit_List(self, expr):
    """
    Translate a list literal into an Array node of its translated elements.
    """
    elements = self.visit_list(expr.elts)
    return syntax.Array(elements)
    
  def visit_Expr(self, expr):
    """
    Translate an expression statement as an assignment to a dummy variable.
    A bare string statement is replaced by an assignment of zero_i64
    rather than being translated.
    """
    # dummy assignment to allow for side effects on RHS
    dummy = self.fresh_var("dummy")
    if not isinstance(expr.value, ast.Str):
      return syntax.Assign(dummy, self.visit(expr.value))
    return syntax.Assign(dummy, zero_i64)

    
  def visit_GeneratorExp(self, expr):
    return self.visit_ListComp(expr)
    
  def visit_ListComp(self, expr):
    gens = expr.generators
    assert len(gens) == 1
    gen = gens[0]
    target = gen.target
    if target.__class__ is ast.Name:
      arg_vars = [target]
    else:
      assert target.__class__ is ast.Tuple and all(e.__class__ is ast.Name for e in target.elts),\
       "Expected comprehension target to be variable or tuple of variables, got %s" % ast.dump(target)
      arg_vars = [ast.Tuple(elts = target.elts)]
    # build a lambda as a Python ast representing 
    # what we do to each element 

    args = ast.arguments(args = arg_vars, 
                         vararg = None,  
                         kwarg = None,  
                         defaults = ())

    fn = translate_function_ast(name = "comprehension_map", 
                                args = args, 
                                body = [ast.Return(expr.elt)], 
                                parent = self)
    seq = self.visit(gen.iter)
    ifs = gen.ifs
    assert len(ifs) == 0, "Parakeet: Conditions in array comprehensions not yet supported"
    return Map(fn = fn, args=(seq,), axis = zero_i64)
      
  def visit_Attribute(self, expr):
    """
    Translate 'value.attr'.

    On an ExternalValue the attribute is read from the wrapped Python object
    and either converted to syntax (static values) or wrapped again.
    Otherwise the attribute name must appear in property_mappings (giving a
    macro expansion or a call) or in method_mappings (giving a closure over
    the value); any other attribute fails an assertion.
    """
    # TODO:
    # Recursive lookup to see if:
    #  (1) base object is local, if so-- create chain of attributes
    #  (2) base object is global but an adverb primitive-- use it locally
    #      without adding it to nonlocals
    #  (3) not local at all-- in which case, add the whole chain of strings
    #      to nonlocals
    #
    #  AN IDEA:
    #     Allow external values to be brought into the syntax tree as 
    #     a designated ExternalValue node
    #     and then here check if the LHS is an ExternalValue and if so, 
    #     pull out the value. If it's a constant, then make it into syntax, 
    #     if it's a function, then parse it, else raise an error. 
    #
    
    from ..mappings import property_mappings, method_mappings
    value = self.visit(expr.value)
    attr = expr.attr

    if isinstance(value, ExternalValue):
      python_value = getattr(value.value, attr)
      if is_static_value(python_value):
        return value_to_syntax(python_value)
      return ExternalValue(python_value) 

    if attr in property_mappings:
      fn = property_mappings[attr]
      if isinstance(fn, macro):
        return fn.transform( [value] )
      return syntax.Call(translate_function_value(fn),
                          ActualArgs(positional = (value,)))  

    if attr in method_mappings:
      fn_syntax = translate_function_value(method_mappings[attr])
      return syntax.Closure(fn_syntax, args=(value,))

    assert False, "Attribute %s not supported" % attr 

  def visit_Num(self, expr):
    """
    Translate a numeric literal into a Const holding its value.
    """
    number = expr.n
    return syntax.Const(number)

  def visit_Tuple(self, expr):
    """
    Translate a tuple literal into a Tuple node of its translated elements.
    """
    elements = self.visit_list(expr.elts)
    return syntax.Tuple(elements)

  def visit_IfExp(self, expr):
    """
    Translate 'a if cond else b' by appending an If statement to the current
    block; each branch assigns to its own temporary and the two temporaries
    are merged into a result variable, which is returned.
    """
    true_var, false_var, result_var = \
      self.fresh_vars(["if_true", "if_false", "if_result"])
    cond = self.visit(expr.test)
    when_true = [Assign(true_var, self.visit(expr.body))]
    when_false = [Assign(false_var, self.visit(expr.orelse))]
    phi = {result_var.name : (true_var, false_var)}
    self.current_block().append(If(cond, when_true, when_false, phi))
    return result_var

  def visit_lhs(self, lhs):
    if isinstance(lhs, ast.Name):
      return self.fresh_var(lhs.id)
    elif isinstance(lhs, ast.Tuple):
      return syntax.Tuple( map(self.visit_lhs, lhs.elts))
    else:
      # in case of slicing or attributes
      res = self.visit(lhs)
      return res

  def visit_Assign(self, stmt):
    """
    Translate an assignment statement.

    NOTE(review): only stmt.targets[0] is translated, so any further targets
    of a chained assignment are not bound here.
    """
    # important to evaluate RHS before LHS for statements like 'x = x + 1'
    rhs = self.visit(stmt.value)
    first_target = stmt.targets[0]
    lhs = self.visit_lhs(first_target)
    return Assign(lhs, rhs)
  
  def visit_AugAssign(self, stmt):
    """
    Translate 'target op= value' into an assignment of the primitive applied
    to the old value of the target and the increment.
    """
    increment = self.visit(stmt.value)
    target = stmt.target
    # read the old binding before the target is rebound by visit_lhs
    old_value = self.visit(target)
    new_value = self.visit_lhs(target)
    prim = prims.find_ast_op(stmt.op) 
    updated = PrimCall(prim, [old_value, increment])
    return Assign(new_value, updated)

  def visit_Return(self, stmt):
    """
    Translate a return statement by translating the returned expression.
    """
    returned = self.visit(stmt.value)
    return syntax.Return(returned)

  def visit_If(self, stmt):
    """
    Translate an if statement: the condition first, then each branch in its
    own scope, then phi nodes merging the names bound by the two branches.
    """
    cond = self.visit(stmt.test)
    then_scope, then_block = self.visit_block(stmt.body)
    else_scope, else_block = self.visit_block(stmt.orelse)
    phi_nodes = self.create_phi_nodes(then_scope, else_scope)
    return syntax.If(cond, then_block, else_block, phi_nodes)

  def visit_loop_body(self, body, *exprs):
    """
    Translate a loop body together with expressions evaluated in the loop's
    context (such as a while-condition).

    For each name bound in the body which is also visible in self.scopes, a
    fresh '<name>_loop' name is created and mapped in the merge dict to its
    values before and after the body; the old name is replaced by the new one
    in the translated expressions and body, and the enclosing scope is
    rebound to it.

    Returns (body, merge, exprs).
    """
    outer_scope = self.current_scope()
    translated_exprs = [self.visit(expr) for expr in exprs]
    body_scope, body_stmts = self.visit_block(body)

    merge = {}
    renames = {}
    for (py_name, name_after) in body_scope.iteritems():
      if py_name not in self.scopes:
        continue
      name_before = self.scopes[py_name]
      loop_name = names.fresh(py_name + "_loop")
      merge[loop_name] = (Var(name_before), Var(name_after))
      renames[name_before] = loop_name
      outer_scope[py_name] = loop_name

    translated_exprs = [subst_expr(e, renames) for e in translated_exprs]
    body_stmts = subst_stmt_list(body_stmts, renames)
    return body_stmts, merge, translated_exprs 

  def visit_While(self, stmt):
    assert not stmt.orelse
    body, merge, (cond,) = self.visit_loop_body(stmt.body, stmt.test)
    return syntax.While(cond, body, merge)

  def visit_For(self, stmt):
    assert not stmt.orelse 
    var = self.visit_lhs(stmt.target)
    assert isinstance(var, Var)
    seq = self.visit(stmt.iter)
    body, merge, _ = self.visit_loop_body(stmt.body)
    if isinstance(seq, syntax.Range):
      return ForLoop(var, seq.start, seq.stop, seq.step, body, merge)
    else:
      seq_name = self.fresh_name("seq")
      seq_var = Var(seq_name)
      self.current_block().append(Assign(seq_var, seq))
      len_fn = translate_function_value(len)
      n = syntax.Call(len_fn, ActualArgs([seq_var]))
      start = zero_i64
      step = one_i64
      loop_counter_name = self.fresh_name('loop_counter')
      loop_var = Var(loop_counter_name)
      body = [Assign(var, syntax.Index(seq_var, loop_var))] + body
      return ForLoop(loop_var, start, n, step, body, merge)
    
  def visit_block(self, stmts):
    self.push()
    curr_block = self.current_block()
    for stmt in stmts:
      parakeet_stmt = self.visit(stmt)
      curr_block.append(parakeet_stmt)
    return self.pop()

  def visit_FunctionDef(self, node):
    """
    Translate a nested function and assign the result to a fresh local
    variable named after the function.
    """
    translated = translate_function_ast(node.name, node.args, node.body, parent = self)
    target = self.fresh_var(node.name)
    return Assign(target, translated)

  def visit_Lambda(self, node):
    """
    Translate a lambda as a function named "lambda" whose body is a single
    return of the lambda's expression.
    """
    return translate_function_ast("lambda",
                                  node.args,
                                  [ast.Return(node.body)],
                                  parent = self)