Beispiel #1
0
    def map_basic_index(self, expr: BasicIndex) -> IndexLambda:
        """Lower a :class:`BasicIndex` into an :class:`IndexLambda`.

        The indexed array is bound under a freshly generated name. For each
        index/axis pair:

        * an integer index becomes ``idx % axis_len``; when the axis length
          is parametric it is bound as an additional input and referenced
          through a variable;
        * a :class:`NormalizedSlice` becomes ``start + step * _k``, where
          ``_k`` is the next output index variable (starting from ``_0``).

        All bindings, including parametric axis lengths, are mapped through
        :meth:`rec`.

        :raises NotImplementedError: for index types other than integers
            and normalized slices.
        """
        vng = UniqueNameGenerator()
        indices = []

        in_ary = vng("in")
        bindings = {in_ary: self.rec(expr.array)}
        islice_idx = 0

        for idx, axis_len in zip(expr.indices, expr.array.shape):
            if isinstance(idx, INT_CLASSES):
                if isinstance(axis_len, INT_CLASSES):
                    indices.append(idx % axis_len)
                else:
                    bnd_name = vng("in")
                    # Map the shape expression like every other binding so
                    # the result only refers to already-mapped arrays.
                    bindings[bnd_name] = self.rec(axis_len)
                    indices.append(idx % prim.Variable(bnd_name))
            elif isinstance(idx, NormalizedSlice):
                indices.append(idx.start
                               + idx.step * prim.Variable(f"_{islice_idx}"))
                islice_idx += 1
            else:
                raise NotImplementedError(f"Indices of type {type(idx)}.")

        return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary),
                                               tuple(indices)),
                           bindings=bindings,
                           shape=expr.shape,
                           dtype=expr.dtype,
                           axes=expr.axes,
                           tags=expr.tags,
                           )
Beispiel #2
0
def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
                        op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression],  # noqa:E501
                        get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]],  # noqa:E501
                        ) -> ArrayOrScalar:
    """Apply the binary scalar operation *op* to *a1* and *a2*, broadcasting.

    Python scalars are first converted to NumPy scalars of the type NumPy
    infers for them (``np.dtype(type(x)).type(x)``). If both operands are
    then scalars, ``op(a1, a2)`` is evaluated eagerly and returned.
    Otherwise an :class:`IndexLambda` is built with:

    * the shape obtained by broadcasting the two operands,
    * the dtype ``get_result_type(dtype_or_scalar_1, dtype_or_scalar_2)``,
    * the expression *op* applied to the broadcast element expressions of
      the operands, which are bound as ``_in0`` and ``_in1`` when needed.
    """
    from pytato.array import _get_default_axes

    def _to_numpy_scalar(operand):
        # Give Python scalars a concrete NumPy type; leave anything else be.
        if isinstance(operand, SCALAR_CLASSES):
            return np.dtype(type(operand)).type(operand)
        return operand

    a1 = _to_numpy_scalar(a1)
    a2 = _to_numpy_scalar(a2)

    both_scalars = np.isscalar(a1) and np.isscalar(a2)
    if both_scalars:
        from pytato.scalar_expr import evaluate
        return evaluate(op(a1, a2))  # type: ignore

    result_shape = get_shape_after_broadcasting([a1, a2])
    result_dtype = get_result_type(*extract_dtypes_or_scalars([a1, a2]))

    bindings: Dict[str, Array] = {}
    lhs_expr, rhs_expr = (
        update_bindings_and_get_broadcasted_expr(operand, bnd_name, bindings,
                                                 result_shape)
        for operand, bnd_name in ((a1, "_in0"), (a2, "_in1")))

    result_axes = _get_default_axes(len(result_shape))
    return IndexLambda(op(lhs_expr, rhs_expr),
                       shape=result_shape,
                       dtype=result_dtype,
                       bindings=bindings,
                       axes=result_axes)
Beispiel #3
0
    def map_roll(self, expr: Roll) -> Array:
        """Lower a :class:`Roll` into an :class:`IndexLambda`.

        Along ``expr.axis`` output index ``i`` reads input index
        ``(i - expr.shift) % n``, where ``n`` is that axis' (possibly
        parametric) length; every other axis is read at the same index.
        The rolled array is bound as ``_in0``; any arrays needed to express
        ``n`` become further bindings. All bindings and array-valued shape
        components are mapped through :meth:`rec`.
        """
        from pytato.utils import dim_to_index_lambda_components

        roll_axis = expr.axis
        axis_len_expr, bindings = dim_to_index_lambda_components(
            expr.shape[roll_axis],
            UniqueNameGenerator({"_in0"}))

        subscript = [var(f"_{d}") for d in range(expr.ndim)]
        subscript[roll_axis] = (
            (subscript[roll_axis] - expr.shift) % axis_len_expr)

        rolled = var("_in0")[tuple(subscript)] if subscript else var("_in0")

        bindings["_in0"] = expr.array  # type: ignore

        new_shape = tuple(self.rec(dim) if isinstance(dim, Array) else dim
                          for dim in expr.shape)
        new_bindings = {bnd_name: self.rec(bnd)
                        for bnd_name, bnd in bindings.items()}

        return IndexLambda(expr=rolled,
                           shape=new_shape,
                           dtype=expr.dtype,
                           bindings=new_bindings,
                           axes=expr.axes,
                           tags=expr.tags)
Beispiel #4
0
    def map_index_lambda(self, expr: IndexLambda,
                         state: CodeGenState) -> ImplementedResult:
        """Generate code for an :class:`IndexLambda`, memoized in *state*.

        ``expr.expr`` is translated by ``self.exprgen_mapper`` with
        ``expr.bindings`` as the local namespace. The default result is an
        :class:`InlinedResult` wrapping the generated expression. If *expr*
        carries an :class:`ImplStored` tag, the value is instead written to
        a temporary via ``add_store`` and a :class:`StoredResult` referring
        to that temporary is returned.

        :arg expr: the index lambda to generate code for.
        :arg state: code generation state; its ``results`` mapping caches
            the result for *expr*.
        :returns: the :class:`ImplementedResult` for *expr*.
        """
        # Each expression is only generated once; reuse a cached result.
        if expr in state.results:
            return state.results[expr]

        prstnt_ctx = PersistentExpressionContext(state)
        # No reduction bounds are passed in at the top level of the
        # expression; presumably the expression generator adds them when it
        # enters a reduction -- TODO confirm.
        local_ctx = LocalExpressionContext(local_namespace=expr.bindings,
                                           num_indices=expr.ndim,
                                           reduction_bounds={})
        loopy_expr = self.exprgen_mapper(expr.expr, prstnt_ctx, local_ctx)

        # depends_on is read after expression generation, which presumably
        # records the dependencies in prstnt_ctx -- verify.
        result: ImplementedResult = InlinedResult(loopy_expr, expr.ndim,
                                                  prstnt_ctx.depends_on)

        shape_to_scalar_expression(expr.shape, self,
                                   state)  # walk over size params

        # {{{ implementation tag

        # Materialize into a named temporary when requested by the tag.
        # NOTE(review): meaning of add_store's final positional True is not
        # visible here -- check add_store's signature.
        if expr.tags_of_type(ImplStored):
            name = _generate_name_for_temp(expr, state)
            result = StoredResult(
                name, expr.ndim,
                frozenset([add_store(name, expr, result, state, self, True)]))
        # }}}

        state.results[expr] = result
        return result
Beispiel #5
0
    def map_stack(self, expr: Stack) -> Array:
        """Lower a :class:`Stack` into an :class:`IndexLambda`.

        Array ``expr.arrays[i]`` is bound as ``_in{i}``. The result picks
        the array to read by comparing the index along ``expr.axis`` (``_I``
        below) against each array number::

            If(_I == 0,
               _in0[_0, _1, ...],
               If(_I == 1,
                  _in1[_0, _1, ...],
                  ...
                      _inNm1[_0, _1, ...] ...))

        Every input is read at the output indices with ``_I`` removed. The
        last array is the fallback branch and needs no comparison.

        :raises ValueError: if there are no arrays to stack.
        """
        from pymbolic.primitives import If, Comparison

        n_arrays = len(expr.arrays)
        if n_arrays == 0:
            raise ValueError("cannot stack an empty sequence of arrays")

        # The read position is the same for every input array, so build it
        # once: all output index variables except the stacked axis.
        subscript: SymbolicIndex = tuple(var(f"_{i}")
                                         for i in range(expr.ndim)
                                         if i != expr.axis)
        axis_var = var(f"_{expr.axis}")

        # Build the If-chain from the innermost (last array) outwards.
        stack_expr = var(f"_in{n_arrays - 1}")[subscript]
        for i in range(n_arrays - 2, -1, -1):
            stack_expr = If(Comparison(axis_var, "==", i),
                            var(f"_in{i}")[subscript], stack_expr)

        bindings = {
            f"_in{i}": self.rec(array)
            for i, array in enumerate(expr.arrays)
        }

        return IndexLambda(namespace=self.namespace,
                           expr=stack_expr,
                           shape=expr.shape,
                           dtype=expr.dtype,
                           bindings=bindings)
Beispiel #6
0
    def map_non_contiguous_advanced_index(self,
                                          expr: AdvancedIndexInNoncontiguousAxes
                                          ) -> IndexLambda:
        """Lower advanced indexing over non-contiguous axes to an IndexLambda.

        The advanced indices (arrays and integers) are broadcast together.
        Sliced axes use the output index variables starting at ``_k``, where
        ``k`` is the number of dimensions of that broadcast shape. The
        leading ``k`` output indices presumably address the advanced-index
        part (via ``get_indexing_expression``) -- confirm against that
        helper.

        * Integer index: ``idx % axis_len``. A parametric axis length is
          bound as an extra (mapped) input.
        * Slice: ``start + step * _j`` for the next slice variable ``_j``.
        * Array index: bound as an extra (mapped) input and subscripted by
          ``get_indexing_expression(idx.shape, adv_idx_shape)``. Unless the
          index array is tagged :class:`AssumeNonNegative`, the fetched
          value is taken modulo the axis length.

        :raises NotImplementedError: for array indices over parametric axis
            lengths, or for unsupported index types.
        """
        from pytato.utils import (get_shape_after_broadcasting,
                                  get_indexing_expression)

        adv_idx_shape = get_shape_after_broadcasting(
            [idx for idx in expr.indices
             if isinstance(idx, (Array, INT_CLASSES))])

        name_gen = UniqueNameGenerator()
        in_ary = name_gen("in")
        bindings = {in_ary: self.rec(expr.array)}

        subscript = []
        next_slice_var = len(adv_idx_shape)

        for idx, axis_len in zip(expr.indices, expr.array.shape):
            if isinstance(idx, INT_CLASSES):
                if isinstance(axis_len, INT_CLASSES):
                    subscript.append(idx % axis_len)
                    continue
                len_name = name_gen("in")
                bindings[len_name] = self.rec(axis_len)
                subscript.append(idx % prim.Variable(len_name))
                continue

            if isinstance(idx, NormalizedSlice):
                subscript.append(
                    idx.start
                    + idx.step * prim.Variable(f"_{next_slice_var}"))
                next_slice_var += 1
                continue

            if not isinstance(idx, Array):
                raise NotImplementedError(f"Indices of type {type(idx)}.")

            if not isinstance(axis_len, INT_CLASSES):
                raise NotImplementedError("Advanced indexing over"
                                          " parametric axis lengths.")

            idx_name = name_gen("in")
            bindings[idx_name] = self.rec(idx)
            gathered = prim.Subscript(
                prim.Variable(idx_name),
                get_indexing_expression(idx.shape, adv_idx_shape))
            if not idx.tags_of_type(AssumeNonNegative):
                gathered = gathered % axis_len
            subscript.append(gathered)

        body = prim.Subscript(prim.Variable(in_ary), tuple(subscript))
        return IndexLambda(expr=body,
                           bindings=bindings,
                           shape=expr.shape,
                           dtype=expr.dtype,
                           axes=expr.axes,
                           tags=expr.tags,
                           )
Beispiel #7
0
 def map_index_lambda(self, expr: IndexLambda) -> Array:
     """Return a copy of *expr* whose bindings are mapped through :meth:`rec`.

     The scalar expression, shape, dtype and tags are carried over as-is;
     the new index lambda is created in ``self.namespace``.
     """
     mapped_bindings = {}
     for bnd_name, bnd in expr.bindings.items():
         mapped_bindings[bnd_name] = self.rec(bnd)

     return IndexLambda(namespace=self.namespace,
                        expr=expr.expr,
                        shape=expr.shape,
                        dtype=expr.dtype,
                        bindings=mapped_bindings,
                        tags=expr.tags)
Beispiel #8
0
    def map_concatenate(self, expr: Concatenate) -> Array:
        """Lower a :class:`Concatenate` into an :class:`IndexLambda`.

        Array ``expr.arrays[i]`` is bound as ``_in{i}`` and occupies the
        half-open range ``[lbounds[i], ubounds[i])`` of the concatenation
        axis. ``lbounds[i]`` is the sum of the preceding arrays' lengths
        along that axis. With ``_I`` the output index along the axis::

            If(0 <= _I < arrays[0].shape[axis],
               _in0[_0, _1, ..., _I, ...],
               If(arrays[0].shape[axis] <= _I < (arrays[1].shape[axis]
                                                 + arrays[0].shape[axis]),
                  _in1[_0, _1, ..., _I - arrays[0].shape[axis], ...],
                  ...
                      _inNm1[_0, _1, ...] ...))

        The last array is the fallback branch and needs no range check.
        """
        from pymbolic.primitives import If, Comparison, LogicalAnd, Subscript

        def get_subscript(array_index: int,
                          offset: ScalarExpression) -> Subscript:
            # Read _in{array_index} at the output position, shifted back by
            # *offset* along the concatenation axis.
            aggregate = var(f"_in{array_index}")
            index = [
                var(f"_{i}") if i != expr.axis else (var(f"_{i}") - offset)
                for i in range(len(expr.shape))
            ]
            return Subscript(aggregate, tuple(index))

        lbounds: List[Any] = [0]
        ubounds: List[Any] = [expr.arrays[0].shape[expr.axis]]

        for i, array in enumerate(expr.arrays[1:], start=1):
            ubounds.append(ubounds[i - 1] + array.shape[expr.axis])
            lbounds.append(ubounds[i - 1])

        axis_var = var(f"_{expr.axis}")

        for i in range(len(expr.arrays) - 1, -1, -1):
            lbound, ubound = lbounds[i], ubounds[i]
            subarray_expr = get_subscript(i, lbound)
            if i == len(expr.arrays) - 1:
                stack_expr = subarray_expr
            else:
                # Build a symbolic conjunction. Python's ``and`` would call
                # bool() on the first Comparison and, if that did not raise,
                # just return the second one, silently dropping the
                # lower-bound check.
                in_range = LogicalAnd((Comparison(axis_var, ">=", lbound),
                                       Comparison(axis_var, "<", ubound)))
                stack_expr = If(in_range, subarray_expr, stack_expr)

        bindings = {
            f"_in{i}": self.rec(array)
            for i, array in enumerate(expr.arrays)
        }

        return IndexLambda(namespace=self.namespace,
                           expr=stack_expr,
                           shape=expr.shape,
                           dtype=expr.dtype,
                           bindings=bindings)
Beispiel #9
0
def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...],
                          func_name: str,
                          ret_dtype: Optional[_dtype_any] = None,
                          np_func_name: Optional[str] = None
                          ) -> ArrayOrScalar:
    """Apply the element-wise function *func_name* to *inputs*.

    If every input is a scalar, the NumPy function named *np_func_name*
    (defaulting to *func_name*) is called eagerly and its result returned.
    Otherwise an :class:`IndexLambda` calling ``pytato.c99.<func_name>`` on
    the inputs is returned. Array inputs are bound as ``in_<position>`` and
    read element-wise; scalar inputs are passed through as-is.

    :arg ret_dtype: dtype of the result; if not given, the dtype of the
        first array input is used.
    :raises ValueError: if *inputs* is empty, or an array input's dtype is
        neither floating-point nor complex.
    :raises NotImplementedError: if array inputs differ in shape
        (broadcasting is not implemented).
    """
    # Test for emptiness first: all() over an empty iterable is True, so
    # checking this after the all-scalars path made it unreachable and sent
    # an empty tuple into np_func().
    if not inputs:
        raise ValueError("at least one argument must be present")

    if all(isinstance(x, SCALAR_CLASSES) for x in inputs):
        if np_func_name is None:
            np_func_name = func_name

        np_func = getattr(np, np_func_name)
        return np_func(*inputs)  # type: ignore

    shape = None

    sym_args = []
    bindings = {}
    for index, inp in enumerate(inputs):
        if isinstance(inp, Array):
            if inp.dtype.kind not in ["f", "c"]:
                raise ValueError("only floating-point or complex "
                        "arguments supported")

            if shape is None:
                shape = inp.shape
            elif inp.shape != shape:
                # FIXME: merge this logic with arithmetic, so that broadcasting
                # is implemented properly
                raise NotImplementedError("broadcasting in function application")

            if ret_dtype is None:
                ret_dtype = inp.dtype

            bindings[f"in_{index}"] = inp
            sym_args.append(
                    prim.Subscript(var(f"in_{index}"),
                        tuple(var(f"_{i}") for i in range(len(shape)))))
        else:
            sym_args.append(inp)

    assert shape is not None
    assert ret_dtype is not None

    return IndexLambda(
            prim.Call(var(f"pytato.c99.{func_name}"), tuple(sym_args)),
            shape, ret_dtype, bindings, axes=_get_default_axes(len(shape)))
Beispiel #10
0
    def handle_index_remapping(
            self,
            indices_getter: Callable[[CodeGenPreprocessor, Array],
                                     SymbolicIndex],
            expr: IndexRemappingBase) -> Array:
        """Lower an index-remapping node into an :class:`IndexLambda`.

        ``indices_getter(self, expr)`` supplies the symbolic index at which
        the mapped input array, bound as ``_in0``, is read. If that index
        is empty, ``_in0`` is used without a subscript.
        """
        input_index = indices_getter(self, expr)

        if input_index:
            body = var("_in0")[input_index]
        else:
            body = var("_in0")

        mapped_input = self.rec(expr.array)

        return IndexLambda(namespace=self.namespace,
                           expr=body,
                           shape=expr.shape,
                           dtype=expr.dtype,
                           bindings={"_in0": mapped_input})
Beispiel #11
0
    def handle_index_remapping(
            self,
            indices_getter: Callable[[CodeGenPreprocessor, Array],
                                     SymbolicIndex],
            expr: IndexRemappingBase) -> Array:
        """Lower an index-remapping node into an :class:`IndexLambda`.

        ``indices_getter(self, expr)`` supplies the symbolic index at which
        the mapped input array, bound as ``_in0``, is read. If that index
        is empty, ``_in0`` is used without a subscript. Array-valued shape
        components are mapped through :meth:`rec`; axes and tags are copied
        from *expr*.
        """
        input_index = indices_getter(self, expr)

        if input_index:
            body = var("_in0")[input_index]
        else:
            body = var("_in0")

        mapped_input = self.rec(expr.array)
        new_shape = tuple(self.rec(dim) if isinstance(dim, Array) else dim
                          for dim in expr.shape)

        return IndexLambda(expr=body,
                           shape=new_shape,
                           dtype=expr.dtype,
                           bindings={"_in0": mapped_input},
                           axes=expr.axes,
                           tags=expr.tags)
Beispiel #12
0
    def map_einsum(self, expr: Einsum) -> Array:
        """Lower an :class:`Einsum` into an :class:`IndexLambda`.

        Argument ``expr.args[k]`` is bound (mapped through :meth:`rec`) as
        ``in{k}`` and subscripted according to its access descriptor:

        * an elementwise axis uses the output index ``_{dim}``,
        * a reduction axis uses the reduction index ``_r{dim}``,
        * an axis whose length differs from the einsum's length for that
          descriptor must have length 1 (broadcast) and is indexed with 0.

        The subscripted arguments are multiplied together. If any reduction
        index occurs, the product is wrapped in a sum :class:`Reduce` with
        bounds ``(0, bound)`` per reduction index. Parametric bounds are
        converted with :func:`dim_to_index_lambda_components`, and the
        arrays they reference become additional, mapped bindings.
        """
        import operator
        from functools import reduce
        from pytato.scalar_expr import Reduce
        from pytato.utils import (dim_to_index_lambda_components,
                                  are_shape_components_equal)
        from pytato.array import ElementwiseAxis, ReductionAxis

        bindings = {f"in{k}": self.rec(arg) for k, arg in enumerate(expr.args)}
        redn_bounds: Dict[str, Tuple[ScalarExpression, ScalarExpression]] = {}
        args_as_pym_expr: List[prim.Subscript] = []
        namegen = UniqueNameGenerator(set(bindings))

        # Loop-invariant: the einsum's length for each access descriptor.
        # Query it once instead of once per axis of every argument.
        descr_to_axis_len = expr._access_descr_to_axis_len()

        # {{{ build subscripts; add bindings coming from the shape expressions

        for access_descr, (iarg, arg) in zip(expr.access_descriptors,
                                             enumerate(expr.args)):
            subscript_indices = []
            for iaxis, axis in enumerate(access_descr):
                if not are_shape_components_equal(arg.shape[iaxis],
                                                  descr_to_axis_len[axis]):
                    # axis is broadcasted
                    assert are_shape_components_equal(arg.shape[iaxis], 1)
                    subscript_indices.append(0)
                    continue

                if isinstance(axis, ElementwiseAxis):
                    subscript_indices.append(prim.Variable(f"_{axis.dim}"))
                else:
                    assert isinstance(axis, ReductionAxis)
                    redn_idx_name = f"_r{axis.dim}"
                    if redn_idx_name not in redn_bounds:
                        # convert the ShapeComponent to a ScalarExpression
                        redn_bound, redn_bound_bindings = (
                            dim_to_index_lambda_components(
                                arg.shape[iaxis], namegen))
                        redn_bounds[redn_idx_name] = (0, redn_bound)

                        bindings.update({k: self.rec(v)
                                         for k, v in redn_bound_bindings.items()})

                    subscript_indices.append(prim.Variable(redn_idx_name))

            args_as_pym_expr.append(prim.Subscript(prim.Variable(f"in{iarg}"),
                                                   tuple(subscript_indices)))

        # }}}

        inner_expr = reduce(operator.mul, args_as_pym_expr[1:],
                            args_as_pym_expr[0])

        if redn_bounds:
            from pytato.reductions import SumReductionOperation
            inner_expr = Reduce(inner_expr,
                                SumReductionOperation(),
                                redn_bounds)

        return IndexLambda(expr=inner_expr,
                           shape=tuple(self.rec(s) if isinstance(s, Array) else s
                                       for s in expr.shape),
                           dtype=expr.dtype,
                           bindings=bindings,
                           axes=expr.axes,
                           tags=expr.tags)
Beispiel #13
0
def test_array_dot_repr():
    """Check the reprs of several pytato expressions, ignoring whitespace."""
    x = pt.make_placeholder("x", (10, 4), np.int64)
    y = pt.make_placeholder("y", (10, 4), np.int64)

    def _assert_stripped_repr(ary: pt.Array, expected_repr: str):
        # Spaces and newlines are removed from both sides so that the
        # expected reprs below can be laid out for readability.
        def _strip(text: str) -> str:
            return "".join(c for c in text if c not in " \n")

        actual_str = _strip(repr(ary))
        expected_str = _strip(expected_repr)
        assert actual_str == expected_str, (actual_str, expected_str)

    _assert_stripped_repr(
        3*x + 4*y,
        """
IndexLambda(
    expr=Sum((Subscript(Variable('_in0'),
                        (Variable('_0'), Variable('_1'))),
              Subscript(Variable('_in1'),
                        (Variable('_0'), Variable('_1'))))),
    shape=(10, 4),
    dtype='int64',
    bindings={'_in0': IndexLambda(expr=Product((3, Subscript(Variable('_in1'),
                                                             (Variable('_0'),
                                                              Variable('_1'))))),
                                  shape=(10, 4),
                                  dtype='int64',
                                  bindings={'_in1': Placeholder(shape=(10, 4),
                                                                dtype='int64',
                                                                name='x')}),
              '_in1': IndexLambda(expr=Product((4, Subscript(Variable('_in1'),
                                                             (Variable('_0'),
                                                              Variable('_1'))))),
                                  shape=(10, 4),
                                  dtype='int64',
                                  bindings={'_in1': Placeholder(shape=(10, 4),
                                                                dtype='int64',
                                                                name='y')})})""")

    _assert_stripped_repr(
        pt.roll(x.reshape(2, 20).reshape(-1), 3),
        """
Roll(
    array=Reshape(array=Reshape(array=Placeholder(shape=(10, 4),
                                                  dtype='int64',
                                                  name='x'),
                                newshape=(2, 20),
                                order='C'),
                  newshape=(40),
                  order='C'),
    shift=3, axis=0)""")
    _assert_stripped_repr(y * pt.not_equal(x, 3),
                          """
IndexLambda(
    expr=Product((Subscript(Variable('_in0'),
                            (Variable('_0'), Variable('_1'))),
                  Subscript(Variable('_in1'),
                            (Variable('_0'), Variable('_1'))))),
    shape=(10, 4),
    dtype='int64',
    bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='y'),
              '_in1': IndexLambda(
                  expr=Comparison(Subscript(Variable('_in0'),
                                            (Variable('_0'), Variable('_1'))),
                                  '!=',
                                  3),
                  shape=(10, 4),
                  dtype=<class 'numpy.bool_'>,
                  bindings={'_in0': Placeholder(shape=(10, 4),
                                                dtype='int64',
                                                name='x')})})""")
    _assert_stripped_repr(
        x[y[:, 2:3], x[2, :]],
        """
AdvancedIndexInContiguousAxes(
    array=Placeholder(shape=(10, 4), dtype='int64', name='x'),
    indices=(BasicIndex(array=Placeholder(shape=(10, 4),
                                          dtype='int64',
                                          name='y'),
                        indices=(NormalizedSlice(start=0, stop=10, step=1),
                                 NormalizedSlice(start=2, stop=3, step=1))),
             BasicIndex(array=Placeholder(shape=(10, 4),
                                          dtype='int64',
                                          name='x'),
                        indices=(2, NormalizedSlice(start=0, stop=4, step=1)))))""")

    _assert_stripped_repr(
        pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]),
        """
Stack(
    arrays=(
        AxisPermutation(
            array=AdvancedIndexInContiguousAxes(
                array=Placeholder(shape=(10, 4),
                                  dtype='int64',
                                  name='x'),
                indices=(BasicIndex(array=(...),
                                    indices=(NormalizedSlice(start=0,
                                                             stop=10,
                                                             step=1),
                                             NormalizedSlice(start=2,
                                                             stop=3,
                                                             step=1))),
                         BasicIndex(array=(...),
                                    indices=(2,
                                             NormalizedSlice(start=0,
                                                             stop=4,
                                                             step=1))))),
            axis_permutation=(1, 0)),
        AxisPermutation(array=AdvancedIndexInContiguousAxes(
            array=Placeholder(shape=(10,
                                     4),
                              dtype='int64',
                              name='y'),
            indices=(BasicIndex(array=(...),
                                indices=(NormalizedSlice(start=0,
                                                         stop=10,
                                                         step=1),
                                         NormalizedSlice(start=2,
                                                         stop=3,
                                                         step=1))),
                     BasicIndex(array=(...),
                                indices=(2,
                                         NormalizedSlice(start=0,
                                                         stop=4,
                                                         step=1))))),
                        axis_permutation=(1, 0))), axis=0)
    """)