Ejemplo n.º 1
0
 def inner_map_result_shape(self, elt_result, arg_shapes, axes):
   """Build the result shape of a map from its per-element result shape.

   The first entry of `arg_shapes` whose `self.rank(...)` equals
   `self.max_rank(arg_shapes)` decides the outcome:

   * if its entry in `axes` is None, the result is a `Shape` made of
     `dims(arg) + dims(elt_result)`, or `any_scalar` when that
     combination is empty;
   * otherwise the result is
     `increase_rank(elt_result, 0, arg.dims[axis])`.

   When no argument reaches the maximum rank, `elt_result` is returned
   unchanged.
   """
   target_rank = self.max_rank(arg_shapes)
   for idx, candidate in enumerate(arg_shapes):
     # Only the first argument of maximal rank is consulted.
     if self.rank(candidate) != target_rank:
       continue
     mapped_axis = axes[idx]
     if mapped_axis is not None:
       return increase_rank(elt_result, 0, candidate.dims[mapped_axis])
     all_dims = dims(candidate) + dims(elt_result)
     return Shape(all_dims) if len(all_dims) > 0 else any_scalar
   return elt_result
Ejemplo n.º 2
0
 def inner_map_result_shape(self, elt_result, arg_shapes, axes):
     """Derive a map's result shape from the shape of one element result.

     Scans `arg_shapes` for the first shape whose `self.rank(...)` matches
     `self.max_rank(arg_shapes)`. If none is found, `elt_result` is
     returned as-is. Otherwise the matching `axes` entry is inspected:
     None yields `Shape(dims(arg) + dims(elt_result))` (or `any_scalar`
     when the combined dims are empty), anything else yields
     `increase_rank(elt_result, 0, arg.dims[axis])`.
     """
     highest = self.max_rank(arg_shapes)
     for position, candidate in enumerate(arg_shapes):
         if self.rank(candidate) == highest:
             break
     else:
         # No argument has the maximum rank: nothing to prepend.
         return elt_result

     chosen_axis = axes[position]
     if chosen_axis is None:
         merged = dims(candidate) + dims(elt_result)
         if len(merged) > 0:
             return Shape(merged)
         return any_scalar
     return increase_rank(elt_result, 0, candidate.dims[chosen_axis])
Ejemplo n.º 3
0
 def visit_Map(self, expr):
   """Infer the shape produced by a Map expression.

   Visits the argument expressions and the mapped function, unwraps the
   constant axis (which must not be None), slices every argument shape
   along that axis with `self.slice_along_axis`, and symbolically calls
   the function on the sliced shapes to get the element result.

   Among the arguments that are `Shape` instances, the one with the most
   dims (the earliest wins a tie) supplies the outer dimension
   `dims[axis]`. If such a dimension was found, the element result has
   its rank increased at `axis` by that dimension; otherwise the element
   result is returned unchanged.
   """
   arg_shapes = self.visit_expr_list(expr.args)
   fn = self.visit_expr(expr.fn)
   axis = unwrap_constant(expr.axis)
   assert axis is not None, "Unexpected axis=None in Map %s" % expr 
   elt_result = symbolic_call(
     fn, [self.slice_along_axis(arg, axis) for arg in arg_shapes])

   best_rank = 0
   leading_dim = None
   for candidate in arg_shapes:
     if not isinstance(candidate, Shape):
       continue
     rank = len(candidate.dims)
     if rank > best_rank:
       # A strictly larger rank replaces the remembered outer dimension.
       best_rank = rank
       leading_dim = candidate.dims[axis]

   if leading_dim is None:
     return elt_result
   return shape.increase_rank(elt_result, axis, leading_dim)
Ejemplo n.º 4
0
 def visit_Array(self, expr):
   """Infer the shape of an array literal.

   The element shapes are visited and merged with `combine_list`; the
   merged shape then has its rank increased at axis 0 by a constant
   equal to the number of elements.
   """
   elt_shapes = self.visit_expr_list(expr.elts)
   combined = combine_list(elt_shapes)
   return increase_rank(combined, 0, const(len(elt_shapes)))
Ejemplo n.º 5
0
 def visit_Array(self, expr):
     """Compute the shape of an array expression from its elements.

     Visits every element, combines the resulting shapes via
     `combine_list`, and returns that combined shape with a new leading
     (axis 0) dimension: the constant element count.
     """
     member_shapes = self.visit_expr_list(expr.elts)
     merged = combine_list(member_shapes)
     length = const(len(member_shapes))
     return increase_rank(merged, 0, length)