Пример #1
0
 def shape_from_tuple(self, expr):
     """Build a Shape from the abstract value of the shape expression `expr`.

     An abstract Tuple contributes its elements as the dimensions, a single
     Const becomes a one-element shape, and any other value yields
     ``expr.type.rank`` copies of ``any_scalar``.
     """
     abstract = self.visit_expr(expr)
     kind = abstract.__class__
     if kind is Tuple:
         shape_dims = tuple(abstract.elts)
     elif kind is Const:
         shape_dims = (abstract.value, )
     else:
         shape_dims = (any_scalar, ) * expr.type.rank
     return make_shape(shape_dims)
Пример #2
0
 def shape_from_tuple(self, expr):
   """Return the Shape described by the shape expression `expr`.

   Abstract Tuples map element-wise onto dimensions, a lone Const gives a
   one-element shape, and anything else falls back to `expr.type.rank`
   copies of `any_scalar`.
   """
   value = self.visit_expr(expr)
   if value.__class__ is Tuple:
     return make_shape(tuple(value.elts))
   if value.__class__ is Const:
     return make_shape((value.value,))
   rank = expr.type.rank
   return make_shape((any_scalar,) * rank)
Пример #3
0
    def visit_Index(self, expr):
        """Infer the abstract value of ``expr.value[expr.index]``.

        Supported cases:
          * abstract Tuple indexed by a Const: the selected element;
          * Shape indexed by a Scalar: ``shape.lower_rank(arr, 0)``;
          * Shape indexed by a Shape: the leading dims of ``arr`` are
            overwritten by the index's dims;
          * Shape indexed by anything else: delegated to ``self.index``;
          * Ptr with a scalar element shape indexed by a Scalar:
            ``any_scalar``.

        Raises:
            ShapeInferenceFailure: if the indexed value is ``shape.any_value``.
            AssertionError: for any other unsupported combination. Raised
                explicitly (not with ``assert``) so the checks survive
                ``python -O`` instead of letting the method return None.
        """
        arr = self.visit_expr(expr.value)
        idx = self.visit_expr(expr.index)

        if arr.__class__ is Tuple and idx.__class__ is Const:
            return arr[idx.value]
        elif arr.__class__ is Shape:
            if isinstance(idx, Scalar):
                return shape.lower_rank(arr, 0)
            elif idx.__class__ is Shape:
                if len(idx.dims) > len(arr.dims):
                    raise AssertionError(
                        "Can't index into rank %d array with rank %d indices" %
                        (len(arr.dims), len(idx.dims)))
                # Overwrite arr's leading dims with the index array's dims;
                # arr's remaining trailing dims are kept.
                dims = list(arr.dims)
                for (i, d) in enumerate(idx.dims):
                    dims[i] = d
                return shape.make_shape(dims)
            else:
                return self.index(arr, idx)
        elif arr.__class__ is Ptr:
            if not isinstance(arr.elt_shape, Scalar):
                raise AssertionError(
                    "Expected pointer to scalars, got element shape %s" %
                    (arr.elt_shape, ))
            if not isinstance(idx, Scalar):
                raise AssertionError(
                    "Expected scalar index into pointer, got %s" % (idx, ))
            return any_scalar

        if isinstance(arr, Scalar):
            raise AssertionError(
                "Expected %s to be array, shape inference found scalar" %
                (arr, ))
        elif arr == shape.any_value:
            raise ShapeInferenceFailure(expr, self.fn)
        raise AssertionError(
            "Can't index (%s) with array shape %s and index shape %s" %
            (expr, arr, idx))
Пример #4
0
 def visit_Index(self, expr):
   """Infer the abstract value of `expr.value[expr.index]`.

   Supported cases: a Const index into an abstract Tuple (gives the
   element); a Scalar, Shape or other index into a Shape; and a Scalar index
   into a Ptr whose element shape is scalar (gives `any_scalar`).

   Raises ShapeInferenceFailure when the indexed value is `shape.any_value`,
   and AssertionError for any other unsupported combination. The latter are
   raised explicitly instead of via `assert` so they survive `python -O`
   (where a stripped `assert False` would make this return None).
   """
   arr = self.visit_expr(expr.value)
   idx = self.visit_expr(expr.index)

   if arr.__class__ is Tuple and idx.__class__ is Const:
     return arr[idx.value]
   elif arr.__class__ is Shape:
     if isinstance(idx, Scalar):
       return shape.lower_rank(arr, 0)
     elif idx.__class__ is Shape:
       if len(idx.dims) > len(arr.dims):
         raise AssertionError(
             "Can't index into rank %d array with rank %d indices" %
             (len(arr.dims), len(idx.dims)))
       # Overwrite arr's leading dims with the index array's dims; trailing
       # dims of arr are kept.
       dims = list(arr.dims)
       for (i, d) in enumerate(idx.dims):
         dims[i] = d
       return shape.make_shape(dims)
     else:
       return self.index(arr, idx)
   elif arr.__class__ is Ptr:
     if not isinstance(arr.elt_shape, Scalar):
       raise AssertionError(
           "Expected pointer to scalars, got element shape %s" % (arr.elt_shape,))
     if not isinstance(idx, Scalar):
       raise AssertionError(
           "Expected scalar index into pointer, got %s" % (idx,))
     return any_scalar

   if isinstance(arr, Scalar):
     raise AssertionError(
         "Expected %s to be array, shape inference found scalar" % (arr,))
   elif arr == shape.any_value:
     raise ShapeInferenceFailure(expr, self.fn)
   raise AssertionError(
       "Can't index (%s) with array shape %s and index shape %s" %
       (expr, arr, idx))
Пример #5
0
 def visit_IndexMap(self, expr):
   """Infer the result shape of an IndexMap.

   The element function is called symbolically on the index bounds, and its
   result is combined with the bounds via `combine_dims`. The bounds are
   passed as a single argument when the function's last input type is a
   TupleT or when the bounds are not an abstract Tuple. Otherwise each
   element of the bounds Tuple becomes its own argument.
   """
   shape_tuple = self.visit_expr(expr.shape)
   clos = self.visit_expr(expr.fn)
   # Only an abstract Tuple is unpacked via `.elts`. Any other bounds value
   # is passed through whole instead of hitting a missing `.elts`.
   if isinstance(clos.fn.input_types[-1], TupleT) or \
      shape_tuple.__class__ is not Tuple:
     indices = [shape_tuple]
   else:
     indices = shape_tuple.elts
   elt_result = symbolic_call(clos, indices)
   return make_shape(combine_dims(shape_tuple, elt_result))
Пример #6
0
 def visit_IndexMap(self, expr):
   """Infer the result shape of an IndexMap.

   The closure from `expr.fn` is called symbolically. It receives one
   argument per bound when the bounds form an abstract Tuple and its last
   input type is not a TupleT; otherwise it receives the bounds as a single
   argument. The call's result is combined with the bounds via
   `combine_dims`.
   """
   bounds = self.visit_expr(expr.shape)
   clos = self.visit_expr(expr.fn)
   takes_tuple = isinstance(clos.fn.input_types[-1], TupleT)
   if not takes_tuple and bounds.__class__ is Tuple:
     indices = bounds.elts
   else:
     indices = [bounds]
   return make_shape(combine_dims(bounds, symbolic_call(clos, indices)))
Пример #7
0
 def visit_IndexMap(self, expr):
     """Infer the result shape of an IndexMap.

     The bounds are unpacked into separate call arguments only when they
     form an abstract Tuple and the closure's last input type is not a
     TupleT. Otherwise the bounds are passed whole. The symbolic result is
     merged with the bounds via ``combine_dims``.
     """
     bounds = self.visit_expr(expr.shape)
     closure = self.visit_expr(expr.fn)
     last_input_type = closure.fn.input_types[-1]
     unpack = (not isinstance(last_input_type, TupleT)
               and bounds.__class__ is Tuple)
     indices = bounds.elts if unpack else [bounds]
     elt_result = symbolic_call(closure, indices)
     return make_shape(combine_dims(bounds, elt_result))
Пример #8
0
 def visit_IndexScan(self, expr):
   """Infer the result shape of an IndexScan.

   `fn` is called symbolically on the index bounds. Its result shape, or the
   visited `expr.init` when one is given, is combined with it through
   `combine`. `emit` maps the accumulator to the per-element output, whose
   dims are joined with the bounds via `combine_dims`.
   """
   fn = self.visit_expr(expr.fn)
   combine = self.visit_expr(expr.combine)
   emit = self.visit_expr(expr.emit)
   bounds = self.visit_expr(expr.shape)
   # Pass the bounds whole when fn expects a tuple or the bounds are not an
   # abstract Tuple; otherwise unpack them into one argument per index.
   # This mirrors how visit_IndexMap calls its element function.
   if isinstance(fn.fn.input_types[-1], TupleT) or bounds.__class__ is not Tuple:
     indices = [bounds]
   else:
     indices = bounds.elts
   elt_shape = symbolic_call(fn, indices)
   if self.expr_is_none(expr.init):
     init_shape = elt_shape
   else:
     init_shape = self.visit_expr(expr.init)
   acc_shape = symbolic_call(combine, [init_shape, elt_shape])
   output_elt_shape = symbolic_call(emit, [acc_shape])
   return make_shape(combine_dims(bounds, output_elt_shape))
Пример #9
0
 def visit_Struct(self, expr):
   """Abstract value of a Struct construction, chosen by `expr.type`.

   ArrayT yields a Shape built from the `.elts` of the visited second
   argument (`expr.args[1]`). TupleT yields an abstract Tuple of the visited
   arguments, and SliceT yields an abstract Slice. Any other type gives
   `unknown_value`.
   """
   struct_type = expr.type
   if isinstance(struct_type, ArrayT):
     dims_value = self.visit_expr(expr.args[1])
     return make_shape(dims_value.elts)
   if isinstance(struct_type, TupleT):
     return Tuple(self.visit_expr_list(expr.args))
   if isinstance(struct_type, SliceT):
     start, stop, step = self.visit_expr_list(expr.args)
     return Slice(start, stop, step)
   return unknown_value
Пример #10
0
 def outer_map_result_shape(self, elt_result, arg_shapes, axes):
   """Build the result Shape of an outer map.

   The result starts with the dims of `elt_result`. Then, for every argument
   whose `self.rank` is positive, it appends either all of that argument's
   dims (when its entry in `axes` is None) or only the dim at that axis.
   Arguments of rank 0 contribute nothing.
   """
   result_dims = list(dims(elt_result))
   for position, arg_shape in enumerate(arg_shapes):
     if self.rank(arg_shape) <= 0:
       continue
     axis = axes[position]
     if axis is None:
       result_dims.extend(arg_shape.dims)
     else:
       result_dims.append(arg_shape.dims[axis])
   return make_shape(result_dims)
Пример #11
0
 def outer_map_result_shape(self, elt_result, arg_shapes, axes):
     """Return the Shape produced by an outer map.

     ``elt_result``'s dims come first. Each argument with positive rank then
     contributes all of its dims when its axis is None, or just
     ``dims[axis]`` otherwise. Rank-0 arguments are skipped.
     """
     collected = list(dims(elt_result))
     for i, arg_shape in enumerate(arg_shapes):
         if self.rank(arg_shape) > 0:
             axis = axes[i]
             contribution = (list(arg_shape.dims) if axis is None
                             else [arg_shape.dims[axis]])
             collected.extend(contribution)
     return make_shape(collected)
Пример #12
0
 def visit_Struct(self, expr):
     """Abstract value for a Struct expression, chosen by ``expr.type``.

     Arrays map to a Shape made from the ``.elts`` of their visited
     ``args[1]``, tuples to an abstract Tuple of all visited args, and
     slices to an abstract Slice. Anything else is ``unknown_value``.
     """
     t = expr.type
     if isinstance(t, ArrayT):
         return make_shape(self.visit_expr(expr.args[1]).elts)
     elif isinstance(t, TupleT):
         elts = self.visit_expr_list(expr.args)
         return Tuple(elts)
     elif isinstance(t, SliceT):
         start, stop, step = self.visit_expr_list(expr.args)
         return Slice(start, stop, step)
     return unknown_value
Пример #13
0
def subst(x, env):
  """Replace every Var inside the abstract value `x` using the mapping `env`.

  Scalars are returned unchanged. Shapes, Tuples and Closures are rebuilt
  with their components substituted via `subst_list`.

  Raises AssertionError (when assertions are enabled) for a Var missing from
  `env`, and RuntimeError for any other kind of abstract value.
  """
  if isinstance(x, Var):
    assert x in env, "Unknown variable %s" % x
    return env[x]
  if isinstance(x, Scalar):
    return x
  if isinstance(x, Shape):
    return make_shape(subst_list(x.dims, env))
  if isinstance(x, Tuple):
    return shape_semantics.tuple(subst_list(x.elts, env))
  if isinstance(x, Closure):
    return Closure(x.fn, subst_list(x.args, env))
  raise RuntimeError("Unexpected abstract expression: %s" % x)
Пример #14
0
 def visit_IndexScan(self, expr):
   """Infer the result shape of an IndexScan.

   `fn` receives the bounds either unpacked (abstract Tuple bounds and a
   non-TupleT last input) or whole. The accumulator shape is
   `combine(init, elt)`, where `init` is the element shape unless `expr.init`
   is given. `emit` of the accumulator gives the per-element output, whose
   dims are joined with the bounds.
   """
   fn = self.visit_expr(expr.fn)
   combine = self.visit_expr(expr.combine)
   emit = self.visit_expr(expr.emit)
   bounds = self.visit_expr(expr.shape)
   unpack = not isinstance(fn.fn.input_types[-1], TupleT) and \
       bounds.__class__ is Tuple
   elt_shape = symbolic_call(fn, bounds.elts if unpack else [bounds])
   if self.expr_is_none(expr.init):
     init_shape = elt_shape
   else:
     init_shape = self.visit_expr(expr.init)
   acc_shape = symbolic_call(combine, [init_shape, elt_shape])
   return make_shape(combine_dims(bounds, symbolic_call(emit, [acc_shape])))
Пример #15
0
def subst(x, env):
    """Substitute variables in the abstract value ``x`` using ``env``.

    A Var is looked up in ``env``. Scalars are returned as-is. Shapes,
    Tuples, Closures and Ptrs are rebuilt with substituted components.

    Raises AssertionError (when assertions are enabled) for an unbound Var,
    and RuntimeError for an unrecognised abstract value.
    """
    if isinstance(x, Var):
        assert x in env, "Unknown variable %s" % x
        return env[x]
    if isinstance(x, Scalar):
        return x
    if isinstance(x, Shape):
        return make_shape(subst_list(x.dims, env))
    if isinstance(x, Tuple):
        new_elts = tuple(subst_list(x.elts, env))
        return Tuple(new_elts)
    if isinstance(x, Closure):
        new_args = subst_list(x.args, env)
        return Closure(x.fn, new_args)
    if isinstance(x, Ptr):
        return Ptr(subst(x.elt_shape, env))
    raise RuntimeError("Unexpected abstract expression: %s" % x)
Пример #16
0
 def visit_IndexScan(self, expr):
     """Infer the result shape of an IndexScan.

     The index function is applied symbolically to the bounds. It gets one
     argument per bound when the bounds are an abstract Tuple and its last
     input type is not a TupleT; otherwise it gets the bounds whole. The
     accumulator shape is ``combine(init, elt)``, where ``init`` is the
     element shape unless ``expr.init`` is given. ``emit`` of the
     accumulator is the per-element output, joined with the bounds.
     """
     fn = self.visit_expr(expr.fn)
     combine = self.visit_expr(expr.combine)
     emit = self.visit_expr(expr.emit)
     bounds = self.visit_expr(expr.shape)
     expects_tuple = isinstance(fn.fn.input_types[-1], TupleT)
     if not expects_tuple and bounds.__class__ is Tuple:
         args = bounds.elts
     else:
         args = [bounds]
     elt_shape = symbolic_call(fn, args)
     if self.expr_is_none(expr.init):
         init_shape = elt_shape
     else:
         init_shape = self.visit_expr(expr.init)
     acc_shape = symbolic_call(combine, [init_shape, elt_shape])
     emitted = symbolic_call(emit, [acc_shape])
     return make_shape(combine_dims(bounds, emitted))
Пример #17
0
 def visit_Index(self, expr):
   """Infer the abstract value of `expr.value[expr.index]`.

   Handles a Const index into an abstract Tuple (the element) and indexing
   into a Shape: a Scalar index uses `shape.lower_rank(arr, 0)`, a Shape
   index overwrites the leading dims, and anything else is delegated to
   `shape_semantics.index`.

   Raises AssertionError for unsupported combinations. It is raised
   explicitly (not via `assert`) so it survives `python -O`, where a
   stripped `assert` would make this method return None.
   """
   arr = self.visit_expr(expr.value)
   idx = self.visit_expr(expr.index)
   if arr.__class__ is Tuple and idx.__class__ is Const:
     return arr[idx.value]
   elif arr.__class__ is Shape:
     if isinstance(idx, Scalar):
       return shape.lower_rank(arr, 0)
     elif idx.__class__ is Shape:
       if len(idx.dims) > len(arr.dims):
         raise AssertionError(
             "Can't index into rank %d array with rank %d indices" %
             (len(arr.dims), len(idx.dims)))
       # Overwrite arr's leading dims with the index array's dims; trailing
       # dims of arr are kept.
       dims = list(arr.dims)
       for (i, d) in enumerate(idx.dims):
         dims[i] = d
       return shape.make_shape(dims)
     else:
       return shape_semantics.index(arr, idx)
   raise AssertionError(
       "Can't index (%s) with array shape %s and index shape %s" %
       (expr, arr, idx))
Пример #18
0
 def visit_AllocArray(self, expr):
   """Shape of an AllocArray: built from the `.elts` of its visited shape."""
   dims_value = self.visit_expr(expr.shape)
   return make_shape(dims_value.elts)
Пример #19
0
  def index(self, arr, idx):
    """Return the abstract shape of ``arr[idx]``.

    `arr` is expected to be a Shape; a Scalar `arr` is returned unchanged.
    `idx` is a single index (Scalar, Slice or ConstSlice) or an abstract
    Tuple of indices. Index k applies to dimension k of `arr`, and dimensions
    past the last index are carried over unchanged.

    How each index affects its dimension:
      * None or Const(None): the dimension is kept as is;
      * any other Scalar: the dimension is dropped;
      * ConstSlice: replaced by the slice's `nelts`;
      * Shape of rank 0: dropped (treated like an unknown scalar);
      * Shape of rank 1: replaced by that index array's length;
      * Slice: replaced by ``upper - lower``, divided by a constant step
        other than 1. Negative constant bounds count back from the end of
        the dimension, as in Python slicing.

    Raises AssertionError for an unsupported `arr` or index. These are raised
    explicitly rather than with ``assert`` so they survive ``python -O``.
    """
    if isinstance(arr, Scalar):
      return arr
    if arr.__class__ is not Shape:
      raise AssertionError("Expected a Shape to index into, got %s" % (arr,))
    # A bare ConstSlice is a single index too. The loop below already
    # handles it, so it must not be rejected here.
    if isinstance(idx, (Scalar, Slice, ConstSlice)):
      indices = [idx]
    elif idx.__class__ is Tuple:
      indices = idx.elts
    else:
      raise AssertionError("Unexpected index: %s" % (idx,))
    result_dims = []
    for (i, curr_idx) in enumerate(indices):
      old_dim = arr.dims[i]
      if curr_idx is None or \
         (isinstance(curr_idx, Const) and curr_idx.value is None):
        result_dims.append(old_dim)
      elif isinstance(curr_idx, Scalar):
        pass
      elif curr_idx.__class__ is ConstSlice:
        result_dims.append(curr_idx.nelts)
      elif curr_idx.__class__ is Shape:
        if len(curr_idx.dims) == 0:
          # same as unknown scalar
          pass
        elif len(curr_idx.dims) == 1:
          result_dims.append(curr_idx.dims[0])
        else:
          raise AssertionError(
              "Indexing by a multi-dimensional array not yet supported")
      else:
        if curr_idx.__class__ is not Slice:
          raise AssertionError("Unsupported index %s" % (curr_idx,))

        if curr_idx.start is None:
          lower = const(0)
        elif isinstance(curr_idx.start, Const):
          if curr_idx.start.value is None:
            lower = const(0)
          elif curr_idx.start.value < 0:
            # A negative start -k denotes position dim - k. Subtracting the
            # negative constant itself would yield dim + k.
            lower = self.sub(old_dim, const(-curr_idx.start.value))
          else:
            lower = curr_idx.start
        else:
          lower = any_scalar

        if curr_idx.stop is None:
          upper = old_dim
        elif isinstance(curr_idx.stop, Const):
          if curr_idx.stop.value is None:
            upper = old_dim
          elif curr_idx.stop.value < 0:
            # Same end-relative rule as for the start bound.
            upper = self.sub(old_dim, const(-curr_idx.stop.value))
          else:
            upper = curr_idx.stop
        else:
          upper = any_scalar

        n = self.sub(upper, lower)
        step = curr_idx.step
        if step and \
            isinstance(step, Const) and \
            step.value is not None and \
            step.value != 1:
          # NOTE(review): if self.div floors, a length that is not a multiple
          # of the step is undercounted (Python yields ceil(n / step)); confirm.
          n = self.div(n, step)
        result_dims.append(n)
    # Dimensions not covered by any index are carried over unchanged.
    result_dims.extend(arr.dims[len(indices):])

    return make_shape(result_dims)
Пример #20
0
 def alloc_array(self, _, dims):
   """Return the Shape of a newly allocated array with dimensions `dims`.

   The first positional argument is unused.
   """
   result = make_shape(dims)
   return result
Пример #21
0
 def alloc_array(self, _, dims):
     """Shape of an allocated array: simply ``make_shape(dims)``.

     The first positional argument is ignored.
     """
     allocated_shape = make_shape(dims)
     return allocated_shape
Пример #22
0
    def index(self, arr, idx):
        """Return the abstract shape of ``arr[idx]``.

        ``arr`` is expected to be a Shape; a Scalar ``arr`` is returned
        unchanged. ``idx`` is a single index (Scalar, Slice or ConstSlice) or
        an abstract Tuple of indices. Index k applies to dimension k of
        ``arr``, and dimensions past the last index are carried over unchanged.

        How each index affects its dimension:
          * None or Const(None): the dimension is kept as is;
          * any other Scalar: the dimension is dropped;
          * ConstSlice: replaced by the slice's ``nelts``;
          * Shape of rank 0: dropped (treated like an unknown scalar);
          * Shape of rank 1: replaced by that index array's length;
          * Slice: replaced by ``upper - lower``, divided by a constant step
            other than 1. Negative constant bounds count back from the end
            of the dimension, as in Python slicing.

        Raises AssertionError for an unsupported ``arr`` or index. These are
        raised explicitly rather than with ``assert`` so they survive
        ``python -O``.
        """
        if isinstance(arr, Scalar):
            return arr
        if arr.__class__ is not Shape:
            raise AssertionError(
                "Expected a Shape to index into, got %s" % (arr, ))
        if isinstance(idx, (Scalar, Slice, ConstSlice)):
            indices = [idx]
        elif idx.__class__ is Tuple:
            indices = idx.elts
        else:
            raise AssertionError("Unexpected index: %s" % (idx, ))
        result_dims = []
        for (i, curr_idx) in enumerate(indices):
            old_dim = arr.dims[i]
            if curr_idx is None or \
               (isinstance(curr_idx, Const) and curr_idx.value is None):
                result_dims.append(old_dim)
            elif isinstance(curr_idx, Scalar):
                pass
            elif curr_idx.__class__ is ConstSlice:
                result_dims.append(curr_idx.nelts)
            elif curr_idx.__class__ is Shape:
                if len(curr_idx.dims) == 0:
                    # same as unknown scalar
                    pass
                elif len(curr_idx.dims) == 1:
                    result_dims.append(curr_idx.dims[0])
                else:
                    raise AssertionError(
                        "Indexing by a multi-dimensional array not yet supported")
            else:
                if curr_idx.__class__ is not Slice:
                    raise AssertionError("Unsupported index %s" % (curr_idx, ))

                if curr_idx.start is None:
                    lower = const(0)
                elif isinstance(curr_idx.start, Const):
                    if curr_idx.start.value is None:
                        lower = const(0)
                    elif curr_idx.start.value < 0:
                        # A negative start -k denotes position dim - k.
                        # Subtracting the negative constant itself would
                        # yield dim + k.
                        lower = self.sub(old_dim, const(-curr_idx.start.value))
                    else:
                        lower = curr_idx.start
                else:
                    lower = any_scalar

                if curr_idx.stop is None:
                    upper = old_dim
                elif isinstance(curr_idx.stop, Const):
                    if curr_idx.stop.value is None:
                        upper = old_dim
                    elif curr_idx.stop.value < 0:
                        # Same end-relative rule as for the start bound.
                        upper = self.sub(old_dim, const(-curr_idx.stop.value))
                    else:
                        upper = curr_idx.stop
                else:
                    upper = any_scalar

                n = self.sub(upper, lower)
                step = curr_idx.step
                if step and \
                    isinstance(step, Const) and \
                    step.value is not None and \
                    step.value != 1:
                    # NOTE(review): if self.div floors, a length that is not a
                    # multiple of the step is undercounted (Python yields
                    # ceil(n / step)); confirm.
                    n = self.div(n, step)
                result_dims.append(n)
        # Dimensions not covered by any index are carried over unchanged.
        result_dims.extend(arr.dims[len(indices):])

        return make_shape(result_dims)