def inner_map_result_shape(self, elt_result, arg_shapes, axes):
    """Return the shape produced by mapping over `arg_shapes`.

    The decision is driven by the first argument whose rank (per
    `self.rank`) equals `self.max_rank(arg_shapes)`. Call that argument
    `candidate` and its entry in `axes` `chosen_axis`:

    * if `chosen_axis` is None, the result is a `Shape` built from
      `dims(candidate) + dims(elt_result)`, or `any_scalar` when that
      concatenation is empty;
    * otherwise the result is `increase_rank(elt_result, 0, d)`, where
      `d` is `candidate.dims[chosen_axis]`.

    If no argument reaches that rank, `elt_result` is returned as-is.
    """
    highest = self.max_rank(arg_shapes)
    # Lazily scan for the first argument at the highest rank; later
    # arguments are never ranked once a match is found.
    match = next(
        (
            (idx, candidate)
            for idx, candidate in enumerate(arg_shapes)
            if self.rank(candidate) == highest
        ),
        None,
    )
    if match is None:
        return elt_result

    idx, candidate = match
    chosen_axis = axes[idx]
    if chosen_axis is not None:
        return increase_rank(elt_result, 0, candidate.dims[chosen_axis])

    merged = dims(candidate) + dims(elt_result)
    return Shape(merged) if len(merged) > 0 else any_scalar
def visit_Map(self, expr):
    """Infer the symbolic result shape of a Map expression.

    Steps:
    1. Visit the argument expressions and the mapped function.
    2. Slice every argument shape along the map axis
       (`self.slice_along_axis`) to get per-element shapes.
    3. Symbolically apply the function to those element shapes.
    4. Re-expand the element result along the same axis. The extent comes
       from the first `Shape` argument with the largest number of dims.

    Returns the expanded shape, or the bare element result when no
    argument is a `Shape` with at least one dim.

    Raises:
        AssertionError: if the Map's axis unwraps to None. This visitor
            expects a concrete axis here (presumably filled in by an
            earlier pass; confirm).
    """
    arg_shapes = self.visit_expr_list(expr.args)
    fn = self.visit_expr(expr.fn)
    axis = unwrap_constant(expr.axis)
    assert axis is not None, "Unexpected axis=None in Map %s" % expr
    elt_shapes = [self.slice_along_axis(arg, axis) for arg in arg_shapes]
    elt_result = symbolic_call(fn, elt_shapes)

    # Pick the outer extent from the first Shape argument of strictly
    # greatest rank; arguments that are not Shape instances are skipped.
    outer_dim = None
    max_rank = 0
    for arg_shape in arg_shapes:
        if isinstance(arg_shape, Shape) and len(arg_shape.dims) > max_rank:
            max_rank = len(arg_shape.dims)
            outer_dim = arg_shape.dims[axis]

    if outer_dim is None:
        return elt_result
    # Use the same in-scope `increase_rank` helper as the sibling visitors
    # rather than reaching through the `shape` module alias.
    return increase_rank(elt_result, axis, outer_dim)
def visit_Array(self, expr):
    """Infer the shape of an array literal.

    The element shapes are merged with `combine_list`. That merged shape
    then gets a new dimension at axis 0 via `increase_rank`, with extent
    `const(number of elements)`.
    """
    elt_shapes = self.visit_expr_list(expr.elts)
    return increase_rank(combine_list(elt_shapes), 0, const(len(elt_shapes)))