def map_basic_index(self, expr: BasicIndex) -> IndexLambda:
    """Lower a :class:`BasicIndex` (integer and slice indices only) to an
    :class:`IndexLambda`.

    The indexed array is bound under a fresh name. Integer indices are
    wrapped modulo the axis length; when that length is not an integer, it
    is itself mapped with ``self.rec`` and bound under a fresh name. Each
    slice consumes the next output index variable ``_0``, ``_1``, ... in
    order.

    :raises NotImplementedError: for index types other than integers and
        :class:`NormalizedSlice`.
    """
    vng = UniqueNameGenerator()
    indices = []

    in_ary = vng("in")
    bindings = {in_ary: self.rec(expr.array)}
    islice_idx = 0

    for idx, axis_len in zip(expr.indices, expr.array.shape):
        if isinstance(idx, INT_CLASSES):
            if isinstance(axis_len, INT_CLASSES):
                # '%' wraps negative indices into [0, axis_len).
                indices.append(idx % axis_len)
            else:
                # Parametric axis length: bind the *mapped* length expression
                # (as every other shape-component binding does) so that it
                # can be referenced from the scalar expression.
                bnd_name = vng("in")
                bindings[bnd_name] = self.rec(axis_len)
                indices.append(idx % prim.Variable(bnd_name))
        elif isinstance(idx, NormalizedSlice):
            indices.append(idx.start
                           + idx.step * prim.Variable(f"_{islice_idx}"))
            islice_idx += 1
        else:
            raise NotImplementedError(f"Indices of type {type(idx)}.")

    return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary),
                                           tuple(indices)),
                       bindings=bindings,
                       shape=expr.shape,
                       dtype=expr.dtype,
                       axes=expr.axes,
                       tags=expr.tags,
                       )
def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
                        op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression],  # noqa:E501
                        get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]],  # noqa:E501
                        ) -> ArrayOrScalar:
    """Combine *a1* and *a2* element-wise using the scalar operator *op*.

    Python scalars are first converted to the matching NumPy scalar type.
    If both operands are then scalars, ``op`` is evaluated eagerly and its
    value returned. Otherwise both operands are broadcast to a common shape
    and an :class:`IndexLambda` applying *op* is returned, with its dtype
    chosen by *get_result_type*. The binding names passed for the two
    operands are ``"_in0"`` and ``"_in1"``.
    """
    from pytato.array import _get_default_axes

    def _to_np_scalar(operand):
        # e.g. a Python ``float`` becomes an ``np.float64``
        if isinstance(operand, SCALAR_CLASSES):
            return np.dtype(type(operand)).type(operand)
        return operand

    a1 = _to_np_scalar(a1)
    a2 = _to_np_scalar(a2)

    both_scalar = np.isscalar(a1) and np.isscalar(a2)
    if both_scalar:
        # Nothing to broadcast: fold the operation right away.
        from pytato.scalar_expr import evaluate
        return evaluate(op(a1, a2))  # type: ignore

    result_shape = get_shape_after_broadcasting([a1, a2])
    result_dtype = get_result_type(*extract_dtypes_or_scalars([a1, a2]))

    bindings: Dict[str, Array] = {}
    lhs = update_bindings_and_get_broadcasted_expr(
        a1, "_in0", bindings, result_shape)
    rhs = update_bindings_and_get_broadcasted_expr(
        a2, "_in1", bindings, result_shape)

    return IndexLambda(op(lhs, rhs),
                       shape=result_shape,
                       dtype=result_dtype,
                       bindings=bindings,
                       axes=_get_default_axes(len(result_shape)))
def map_roll(self, expr: Roll) -> Array:
    """Lower a :class:`Roll` to an :class:`IndexLambda`.

    Along ``expr.axis`` the output index ``i`` reads the input at
    ``(i - expr.shift) % axis_len``; every other axis is read unchanged.
    The rolled array is bound as ``_in0``. Any bindings needed for a
    parametric axis length, the array itself, and any array-valued shape
    components are all mapped with ``self.rec``.
    """
    from pytato.utils import dim_to_index_lambda_components

    axis = expr.axis
    # "_in0" is reserved for the rolled array, so the name generator must
    # not hand it out for the axis-length bindings.
    axis_len_expr, bindings = dim_to_index_lambda_components(
        expr.shape[axis], UniqueNameGenerator({"_in0"}))

    subscript = tuple(
        (var(f"_{d}") - expr.shift) % axis_len_expr if d == axis
        else var(f"_{d}")
        for d in range(expr.ndim))

    source = var("_in0")
    index_expr = source[subscript] if subscript else source

    bindings["_in0"] = expr.array  # type: ignore

    mapped_shape = tuple(self.rec(dim) if isinstance(dim, Array) else dim
                         for dim in expr.shape)
    mapped_bindings = {name: self.rec(bound)
                       for name, bound in bindings.items()}

    return IndexLambda(expr=index_expr,
                       shape=mapped_shape,
                       dtype=expr.dtype,
                       bindings=mapped_bindings,
                       axes=expr.axes,
                       tags=expr.tags)
def map_index_lambda(self, expr: IndexLambda,
                     state: CodeGenState) -> ImplementedResult:
    """Generate code for an :class:`IndexLambda`.

    Results are memoized in ``state.results``. By default, the lambda's
    scalar expression is translated with ``self.exprgen_mapper`` and
    returned as an :class:`InlinedResult`. If *expr* carries an
    :class:`ImplStored` tag, a store into a newly named temporary is added
    via ``add_store``, and a :class:`StoredResult` referring to that
    temporary is returned instead.
    """
    if expr in state.results:
        return state.results[expr]

    prstnt_ctx = PersistentExpressionContext(state)
    # The lambda's bindings form the local namespace for name lookup in
    # expr.expr; one index per output dimension.
    local_ctx = LocalExpressionContext(
        local_namespace=expr.bindings,
        num_indices=expr.ndim,
        reduction_bounds={})
    loopy_expr = self.exprgen_mapper(expr.expr, prstnt_ctx, local_ctx)

    result: ImplementedResult = InlinedResult(loopy_expr,
            expr.ndim, prstnt_ctx.depends_on)

    # Return value is discarded; called only for its effect on *state*
    # (presumably registering the size parameters used in the shape --
    # confirm against shape_to_scalar_expression).
    shape_to_scalar_expression(expr.shape, self, state)  # walk over size params

    # {{{ implementation tag

    if expr.tags_of_type(ImplStored):
        name = _generate_name_for_temp(expr, state)
        result = StoredResult(
            name, expr.ndim,
            frozenset([add_store(name, expr, result, state, self, True)]))

    # }}}

    state.results[expr] = result
    return result
def map_stack(self, expr: Stack) -> Array:
    """Express a :class:`Stack` as an :class:`IndexLambda`.

    Input ``k`` is bound as ``_in{k}``, with each input mapped through
    ``self.rec``. The value at an output point is taken from the input
    whose position equals the index along the stacking axis. That input is
    read at the remaining output indices.
    """
    from pymbolic.primitives import If, Comparison

    # Every output index except the one along the stacking axis, in order;
    # identical for all stacked inputs.
    subscript = tuple(var(f"_{i}")
                      for i in range(expr.ndim)
                      if i != expr.axis)
    axis_idx = var(f"_{expr.axis}")

    # I = axis index
    #
    # => If(_I == 0,
    #        _in0[_0, _1, ...],
    #        If(_I == 1,
    #            _in1[_0, _1, ...],
    #            ...
    #                _inNm1[_0, _1, ...] ...))
    #
    # Built from the innermost (last input) outward.
    last = len(expr.arrays) - 1
    for k in range(last, -1, -1):
        operand = var(f"_in{k}")[subscript]
        if k == last:
            stack_expr = operand
        else:
            stack_expr = If(Comparison(axis_idx, "==", k),
                            operand, stack_expr)

    mapped_inputs = {f"_in{k}": self.rec(ary)
                     for k, ary in enumerate(expr.arrays)}

    return IndexLambda(namespace=self.namespace,
                       expr=stack_expr,
                       shape=expr.shape,
                       dtype=expr.dtype,
                       bindings=mapped_inputs)
def map_non_contiguous_advanced_index(self,
                                      expr: AdvancedIndexInNoncontiguousAxes
                                      ) -> IndexLambda:
    """Lower an :class:`AdvancedIndexInNoncontiguousAxes` to an
    :class:`IndexLambda`.

    The advanced indices (arrays and integers) are broadcast together into
    ``adv_idx_shape``. Slice indices consume the output index variables
    starting at ``_{len(adv_idx_shape)}``, so the first
    ``len(adv_idx_shape)`` output indices are presumably the
    advanced-index axes (TODO confirm against the shape of *expr*).

    :raises NotImplementedError: for array indices into an axis whose
        length is not an integer, and for unsupported index types.
    """
    from pytato.utils import (get_shape_after_broadcasting,
                              get_indexing_expression)

    # Positions of the advanced indices (arrays and integers).
    i_adv_indices = tuple(i
                          for i, idx_expr in enumerate(expr.indices)
                          if isinstance(idx_expr, (Array, INT_CLASSES)))
    # Shape obtained by broadcasting all advanced indices together.
    adv_idx_shape = get_shape_after_broadcasting([expr.indices[i_idx]
                                                  for i_idx in i_adv_indices])

    vng = UniqueNameGenerator()
    indices = []

    in_ary = vng("in")
    bindings = {in_ary: self.rec(expr.array)}
    # Slice axes are numbered after the advanced-index axes.
    islice_idx = len(adv_idx_shape)

    for idx, axis_len in zip(expr.indices, expr.array.shape):
        if isinstance(idx, INT_CLASSES):
            if isinstance(axis_len, INT_CLASSES):
                # '%' wraps negative integer indices.
                indices.append(idx % axis_len)
            else:
                # Parametric axis length: bind it so it can be referenced.
                bnd_name = vng("in")
                bindings[bnd_name] = self.rec(axis_len)
                indices.append(idx % prim.Variable(bnd_name))
        elif isinstance(idx, NormalizedSlice):
            indices.append(idx.start
                           + idx.step * prim.Variable(f"_{islice_idx}"))
            islice_idx += 1
        elif isinstance(idx, Array):
            if isinstance(axis_len, INT_CLASSES):
                bnd_name = vng("in")
                bindings[bnd_name] = self.rec(idx)
                # Read the index array; get_indexing_expression presumably
                # maps the advanced-index output axes onto idx's
                # (broadcast) axes -- confirm in pytato.utils.
                indirect_idx_expr = prim.Subscript(prim.Variable(bnd_name),
                                                   get_indexing_expression(
                                                       idx.shape,
                                                       adv_idx_shape))

                # The wrap-around modulo is skipped when the index array is
                # tagged AssumeNonNegative.
                if not idx.tags_of_type(AssumeNonNegative):
                    indirect_idx_expr = indirect_idx_expr % axis_len

                indices.append(indirect_idx_expr)
            else:
                raise NotImplementedError("Advanced indexing over"
                                          " parametric axis lengths.")
        else:
            raise NotImplementedError(f"Indices of type {type(idx)}.")

    return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary),
                                           tuple(indices)),
                       bindings=bindings,
                       shape=expr.shape,
                       dtype=expr.dtype,
                       axes=expr.axes,
                       tags=expr.tags,
                       )
def map_index_lambda(self, expr: IndexLambda) -> Array:
    """Rebuild *expr* with each of its bindings mapped through ``self.rec``.

    The scalar expression, shape, dtype and tags are reused as-is. The new
    lambda is constructed with ``namespace=self.namespace``.
    """
    mapped_bindings = dict(zip(expr.bindings.keys(),
                               map(self.rec, expr.bindings.values())))

    return IndexLambda(namespace=self.namespace,
                       expr=expr.expr,
                       shape=expr.shape,
                       dtype=expr.dtype,
                       bindings=mapped_bindings,
                       tags=expr.tags)
def map_concatenate(self, expr: Concatenate) -> Array:
    """Express a :class:`Concatenate` as an :class:`IndexLambda`.

    Input ``k`` is bound as ``_in{k}``, with each input mapped through
    ``self.rec``. It occupies the half-open range
    ``[lbounds[k], ubounds[k])`` of the concatenation axis, and is read at
    the output index shifted back by ``lbounds[k]`` along that axis.
    """
    from pymbolic.primitives import If, Comparison, Subscript

    def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript:
        # Index _in{array_index} at the current output point, shifted by
        # *offset* along the concatenation axis.
        aggregate = var(f"_in{array_index}")
        index = [var(f"_{i}") if i != expr.axis else (var(f"_{i}") - offset)
                 for i in range(len(expr.shape))]
        return Subscript(aggregate, tuple(index))

    # lbounds[k] / ubounds[k]: start (inclusive) / end (exclusive) of input
    # k along the concatenation axis.
    lbounds: List[Any] = [0]
    ubounds: List[Any] = [expr.arrays[0].shape[expr.axis]]

    for i, array in enumerate(expr.arrays[1:], start=1):
        ubounds.append(ubounds[i - 1] + array.shape[expr.axis])
        lbounds.append(ubounds[i - 1])

    # I = axis index
    #
    # => If(_I < ubounds[0],
    #       _in0[_0, _1, ..., _I - lbounds[0], ...],
    #       If(_I < ubounds[1],
    #           _in1[_0, _1, ..., _I - lbounds[1], ...],
    #           ...
    #               _inNm1[_0, _1, ..., _I - lbounds[N-1], ...] ...))
    #
    # Only the upper bound needs testing. Branch k is reached only after
    # every enclosing condition failed, so _I >= ubounds[k-1] == lbounds[k]
    # already holds there.
    axis_idx = var(f"_{expr.axis}")
    last = len(expr.arrays) - 1
    for i in range(last, -1, -1):
        subarray_expr = get_subscript(i, lbounds[i])
        if i == last:
            concat_expr = subarray_expr
        else:
            # NOTE: do not combine two Comparisons with Python's ``and``.
            # That consults the truth value of the first Comparison object
            # instead of building a symbolic conjunction.
            concat_expr = If(Comparison(axis_idx, "<", ubounds[i]),
                             subarray_expr,
                             concat_expr)

    bindings = {
        f"_in{i}": self.rec(array)
        for i, array in enumerate(expr.arrays)
    }

    return IndexLambda(namespace=self.namespace,
                       expr=concat_expr,
                       shape=expr.shape,
                       dtype=expr.dtype,
                       bindings=bindings)
def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], func_name: str,
                          ret_dtype: Optional[_dtype_any] = None,
                          np_func_name: Optional[str] = None
                          ) -> ArrayOrScalar:
    """Apply the element-wise function *func_name* to *inputs*.

    If every input is a scalar, ``getattr(np, np_func_name)`` is called
    directly; *np_func_name* defaults to *func_name*. Otherwise an
    :class:`IndexLambda` calling ``pytato.c99.<func_name>`` element-wise on
    the inputs is returned. Array inputs are bound as ``in_<position>``.

    :arg ret_dtype: dtype of the result. If *None*, the dtype of the first
        array input is used.
    :raises ValueError: if *inputs* is empty, or if an array input is
        neither floating-point nor complex.
    :raises NotImplementedError: if the array inputs have differing shapes
        (broadcasting is not supported).
    """
    # This check must come first: ``all()`` of an empty iterable is True,
    # so an empty *inputs* would otherwise be dispatched to the all-scalars
    # branch below and this check would never be reached.
    if not inputs:
        raise ValueError("at least one argument must be present")

    if all(isinstance(x, SCALAR_CLASSES) for x in inputs):
        if np_func_name is None:
            np_func_name = func_name
        np_func = getattr(np, np_func_name)
        return np_func(*inputs)  # type: ignore

    shape = None
    sym_args = []
    bindings = {}
    for index, inp in enumerate(inputs):
        if isinstance(inp, Array):
            if inp.dtype.kind not in ["f", "c"]:
                raise ValueError("only floating-point or complex "
                        "arguments supported")

            if shape is None:
                shape = inp.shape
            elif inp.shape != shape:
                # FIXME: merge this logic with arithmetic, so that broadcasting
                # is implemented properly
                raise NotImplementedError("broadcasting in function application")

            if ret_dtype is None:
                ret_dtype = inp.dtype

            bindings[f"in_{index}"] = inp
            sym_args.append(
                    prim.Subscript(var(f"in_{index}"),
                        tuple(var(f"_{i}") for i in range(len(shape)))))
        else:
            # Scalars are embedded directly into the call expression.
            sym_args.append(inp)

    assert shape is not None
    assert ret_dtype is not None

    return IndexLambda(
            prim.Call(var(f"pytato.c99.{func_name}"), tuple(sym_args)),
            shape, ret_dtype, bindings, axes=_get_default_axes(len(shape)))
def handle_index_remapping(self,
        indices_getter: Callable[[CodeGenPreprocessor, Array], SymbolicIndex],
        expr: IndexRemappingBase) -> Array:
    """Lower an index-remapping node to an :class:`IndexLambda`.

    *indices_getter* supplies the symbolic index tuple used to read the
    source array. The source array is mapped with ``self.rec`` and bound as
    ``_in0``. A falsy (e.g. empty) index tuple leaves ``_in0``
    unsubscripted.
    """
    remapped = indices_getter(self, expr)
    source = var("_in0")
    index_expr = source[remapped] if remapped else source

    return IndexLambda(namespace=self.namespace,
                       expr=index_expr,
                       shape=expr.shape,
                       dtype=expr.dtype,
                       bindings={"_in0": self.rec(expr.array)})
def handle_index_remapping(self,
        indices_getter: Callable[[CodeGenPreprocessor, Array], SymbolicIndex],
        expr: IndexRemappingBase) -> Array:
    """Lower an index-remapping node to an :class:`IndexLambda`.

    *indices_getter* supplies the symbolic index tuple used to read the
    source array, which is mapped with ``self.rec`` and bound as ``_in0``.
    A falsy (e.g. empty) index tuple leaves ``_in0`` unsubscripted.
    Array-valued shape components are also mapped with ``self.rec``; axes
    and tags are carried over from *expr*.
    """
    remapped = indices_getter(self, expr)
    source = var("_in0")
    index_expr = source[remapped] if remapped else source

    mapped_array = self.rec(expr.array)
    mapped_shape = tuple(self.rec(dim) if isinstance(dim, Array) else dim
                         for dim in expr.shape)

    return IndexLambda(expr=index_expr,
                       shape=mapped_shape,
                       dtype=expr.dtype,
                       bindings={"_in0": mapped_array},
                       axes=expr.axes,
                       tags=expr.tags)
def map_einsum(self, expr: Einsum) -> Array:
    """Lower an :class:`Einsum` to an :class:`IndexLambda`.

    Operand ``expr.args[k]`` is bound as ``in{k}`` and subscripted
    according to its access descriptor:

    - elementwise axes use the output index ``_{dim}``;
    - reduction axes use ``_r{dim}``;
    - operand axes broadcast against a longer einsum axis, which must have
      length 1, are indexed with 0.

    The product of all subscripted operands is wrapped in a sum
    :class:`Reduce` over the reduction indices, if there are any.
    """
    import operator
    from functools import reduce
    from pytato.scalar_expr import Reduce
    from pytato.utils import (dim_to_index_lambda_components,
                              are_shape_components_equal)
    from pytato.array import ElementwiseAxis, ReductionAxis

    bindings = {f"in{k}": self.rec(arg) for k, arg in enumerate(expr.args)}
    redn_bounds: Dict[str, Tuple[ScalarExpression, ScalarExpression]] = {}
    args_as_pym_expr: List[prim.Subscript] = []
    namegen = UniqueNameGenerator(set(bindings))
    # Loop-invariant: computed once, not once per (operand, axis) pair.
    descr_to_axis_len = expr._access_descr_to_axis_len()

    # {{{ add bindings coming from the shape expressions

    for iarg, (access_descr, arg) in enumerate(zip(expr.access_descriptors,
                                                   expr.args)):
        subscript_indices = []
        for iaxis, axis in enumerate(access_descr):
            if not are_shape_components_equal(arg.shape[iaxis],
                                              descr_to_axis_len[axis]):
                # axis is broadcasted
                assert are_shape_components_equal(arg.shape[iaxis], 1)
                subscript_indices.append(0)
                continue

            if isinstance(axis, ElementwiseAxis):
                subscript_indices.append(prim.Variable(f"_{axis.dim}"))
            else:
                assert isinstance(axis, ReductionAxis)
                redn_idx_name = f"_r{axis.dim}"
                if redn_idx_name not in redn_bounds:
                    # convert the ShapeComponent to a ScalarExpression
                    redn_bound, redn_bound_bindings = (
                        dim_to_index_lambda_components(
                            arg.shape[iaxis], namegen))
                    redn_bounds[redn_idx_name] = (0, redn_bound)

                    bindings.update({k: self.rec(v)
                                     for k, v in redn_bound_bindings.items()})

                subscript_indices.append(prim.Variable(redn_idx_name))

        args_as_pym_expr.append(prim.Subscript(prim.Variable(f"in{iarg}"),
                                               tuple(subscript_indices)))

    # }}}

    inner_expr = reduce(operator.mul, args_as_pym_expr[1:],
                        args_as_pym_expr[0])

    if redn_bounds:
        from pytato.reductions import SumReductionOperation
        inner_expr = Reduce(inner_expr,
                            SumReductionOperation(),
                            redn_bounds)

    return IndexLambda(expr=inner_expr,
                       shape=tuple(self.rec(s) if isinstance(s, Array) else s
                                   for s in expr.shape),
                       dtype=expr.dtype,
                       bindings=bindings,
                       axes=expr.axes,
                       tags=expr.tags)
def test_array_dot_repr():
    """Check the ``repr`` of several small expression graphs.

    Comparisons ignore spaces and newlines, so the expected strings below
    are laid out for readability only.
    """
    x = pt.make_placeholder("x", (10, 4), np.int64)
    y = pt.make_placeholder("y", (10, 4), np.int64)

    def _strip_ws(s: str) -> str:
        # Drop spaces and newlines (only those two characters).
        return s.replace(" ", "").replace("\n", "")

    def _assert_stripped_repr(ary: pt.Array, expected_repr: str):
        actual = _strip_ws(repr(ary))
        expected = _strip_ws(expected_repr)
        assert actual == expected

    _assert_stripped_repr(
        3*x + 4*y,
        """
IndexLambda(
    expr=Sum((Subscript(Variable('_in0'),
                        (Variable('_0'), Variable('_1'))),
              Subscript(Variable('_in1'),
                        (Variable('_0'), Variable('_1'))))),
    shape=(10, 4),
    dtype='int64',
    bindings={'_in0': IndexLambda(
                  expr=Product((3, Subscript(Variable('_in1'),
                                             (Variable('_0'), Variable('_1'))))),
                  shape=(10, 4),
                  dtype='int64',
                  bindings={'_in1': Placeholder(shape=(10, 4), dtype='int64', name='x')}),
              '_in1': IndexLambda(
                  expr=Product((4, Subscript(Variable('_in1'),
                                             (Variable('_0'), Variable('_1'))))),
                  shape=(10, 4),
                  dtype='int64',
                  bindings={'_in1': Placeholder(shape=(10, 4), dtype='int64', name='y')})})""")

    _assert_stripped_repr(
        pt.roll(x.reshape(2, 20).reshape(-1), 3),
        """
Roll(
    array=Reshape(array=Reshape(array=Placeholder(shape=(10, 4),
                                                  dtype='int64',
                                                  name='x'),
                                newshape=(2, 20),
                                order='C'),
                  newshape=(40),
                  order='C'),
    shift=3,
    axis=0)""")

    _assert_stripped_repr(y * pt.not_equal(x, 3),
                          """
IndexLambda(
    expr=Product((Subscript(Variable('_in0'),
                            (Variable('_0'), Variable('_1'))),
                  Subscript(Variable('_in1'),
                            (Variable('_0'), Variable('_1'))))),
    shape=(10, 4),
    dtype='int64',
    bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='y'),
              '_in1': IndexLambda(
                  expr=Comparison(Subscript(Variable('_in0'),
                                            (Variable('_0'), Variable('_1'))),
                                  '!=',
                                  3),
                  shape=(10, 4),
                  dtype=<class 'numpy.bool_'>,
                  bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='x')})})""")

    _assert_stripped_repr(
        x[y[:, 2:3], x[2, :]],
        """
AdvancedIndexInContiguousAxes(
    array=Placeholder(shape=(10, 4), dtype='int64', name='x'),
    indices=(BasicIndex(array=Placeholder(shape=(10, 4),
                                          dtype='int64',
                                          name='y'),
                        indices=(NormalizedSlice(start=0, stop=10, step=1),
                                 NormalizedSlice(start=2, stop=3, step=1))),
             BasicIndex(array=Placeholder(shape=(10, 4),
                                          dtype='int64',
                                          name='x'),
                        indices=(2, NormalizedSlice(start=0, stop=4, step=1)))))""")

    _assert_stripped_repr(
        pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]),
        """
Stack(
    arrays=(
        AxisPermutation(
            array=AdvancedIndexInContiguousAxes(
                array=Placeholder(shape=(10, 4), dtype='int64', name='x'),
                indices=(BasicIndex(array=(...),
                                    indices=(NormalizedSlice(start=0, stop=10, step=1),
                                             NormalizedSlice(start=2, stop=3, step=1))),
                         BasicIndex(array=(...),
                                    indices=(2, NormalizedSlice(start=0, stop=4, step=1))))),
            axis_permutation=(1, 0)),
        AxisPermutation(
            array=AdvancedIndexInContiguousAxes(
                array=Placeholder(shape=(10, 4), dtype='int64', name='y'),
                indices=(BasicIndex(array=(...),
                                    indices=(NormalizedSlice(start=0, stop=10, step=1),
                                             NormalizedSlice(start=2, stop=3, step=1))),
                         BasicIndex(array=(...),
                                    indices=(2, NormalizedSlice(start=0, stop=4, step=1))))),
            axis_permutation=(1, 0))),
    axis=0)
""")