def inner_map_result_shape(self, elt_result, arg_shapes, axes):
    """Compute the result shape of an inner (element-wise) map.

    The argument that drives the result shape is the FIRST entry of
    ``arg_shapes`` whose ``self.rank`` equals ``self.max_rank(arg_shapes)``.

    * If its entry in ``axes`` is ``None``, the result dims are
      ``dims(arg_shape) + dims(elt_result)``. These are wrapped in
      ``Shape``, or ``any_scalar`` is returned when that concatenation is
      empty.
    * Otherwise the result is
      ``increase_rank(elt_result, 0, arg_shape.dims[axis])``. Presumably
      this prepends the mapped dimension at position 0 (verify against
      ``increase_rank``).

    If no argument reaches the maximum rank, ``elt_result`` is returned
    unchanged. ``axes`` is indexed only for the selected argument.
    """
    highest = self.max_rank(arg_shapes)
    for idx, shape in enumerate(arg_shapes):
        if self.rank(shape) != highest:
            continue
        map_axis = axes[idx]
        if map_axis is not None:
            return increase_rank(elt_result, 0, shape.dims[map_axis])
        all_dims = dims(shape) + dims(elt_result)
        return Shape(all_dims) if len(all_dims) > 0 else any_scalar
    return elt_result
def outer_map_result_shape(self, elt_result, arg_shapes, axes):
    """Compute the result shape of an outer (cartesian) map.

    Start from ``dims(elt_result)``. Then, for every argument with
    ``self.rank(arg) > 0``, in order:

    * if its entry in ``axes`` is ``None``, append all of ``arg.dims``;
    * otherwise append the single dimension ``arg.dims[axis]``.

    Arguments of rank 0 or less contribute nothing, and their ``axes``
    entry is never read. The collected list is passed to ``make_shape``.
    """
    # NOTE(review): the element-result dims come BEFORE the argument dims
    # here. The inner-map path does the opposite (argument dims first).
    # Confirm this ordering is intended.
    collected = list(dims(elt_result))
    for idx, shape in enumerate(arg_shapes):
        if self.rank(shape) <= 0:
            continue
        map_axis = axes[idx]
        if map_axis is None:
            collected.extend(shape.dims)
        else:
            collected.append(shape.dims[map_axis])
    return make_shape(collected)