Пример #1
0
    def test_count_optional_arg_type_check(self):
        """
        ``str.count`` with optional start/end: float bounds must be rejected
        at typing time, integer bounds must agree with CPython.
        """
        pyfunc = count_with_start_end_usecase

        def try_compile_bad_optional(*args):
            # Both optional slice bounds typed as float64 -> invalid.
            float_opt = types.Optional(types.float64)
            bad_sig = types.int64(types.unicode_type, types.unicode_type,
                                  float_opt, float_opt)
            njit([bad_sig])(pyfunc)

        with self.assertRaises(TypingError) as ctx:
            try_compile_bad_optional('tú quis?', 'tú', 1.1, 1.1)
        self.assertIn('The slice indices must be an Integer or None',
                      str(ctx.exception))

        msg_template = ("'{0}'.py_count('{1}', {2}, {3}) = {4}\n"
                        "'{0}'.c_count_op('{1}', {2}, {3}) = {5}")
        # Integer optional bounds compile fine.
        int_opt = types.Optional(types.int64)
        good_sig = types.int64(types.unicode_type, types.unicode_type,
                               int_opt, int_opt)
        compiled = njit([good_sig])(pyfunc)

        call_args = ('tú quis?', 'tú', 0, 8)
        expected = pyfunc(*call_args)
        actual = compiled(*call_args)
        self.assertEqual(
            expected, actual,
            msg_template.format(*call_args, expected, actual))
Пример #2
0
    def unify_pairs(self, first, second):
        """
        Unify two types into a single type that both can be converted to.

        Choose PyObject type as the abstract if we fail to determine a concrete
        type.

        Rules, applied in order:

        * ``none`` with ``T`` gives ``Optional(T)``.  If ``T`` is already an
          ``Optional`` it is returned unchanged (no nested Optionals), and
          ``none`` with ``none`` gives ``none``.
        * If either side is ``Optional``, the underlying types are unified
          and the result is wrapped in ``Optional``.
        * Otherwise ``self.type_compatibility`` is consulted, trying both
          conversion directions, and the result kind decides the outcome.

        Raises:
            TypeError: a 'safe'/'unsafe' conversion between types that are
                neither both numbers nor both UniTuples.
            Exception: ``type_compatibility`` returned an unknown kind.
        """
        # TODO: should add an option to reject unsafe type conversion
        if first == types.none or second == types.none:
            other = second if first == types.none else first
            if other == types.none:
                # Both sides are none: there is nothing to make optional.
                return types.none
            if isinstance(other, types.Optional):
                # Already optional; wrapping again would nest Optionals.
                return other
            return types.Optional(other)

        # Handle optional type
        # XXX: really need to refactor type infer to reduce the number of
        #      special cases
        if (isinstance(first, types.Optional)
                or isinstance(second, types.Optional)):
            a = (first.type if isinstance(first, types.Optional) else first)
            b = (second.type if isinstance(second, types.Optional) else second)
            return types.Optional(self.unify_pairs(a, b))

        d = self.type_compatibility(fromty=first, toty=second)
        if d is None:
            # Complex is not allowed to downcast implicitly.
            # Need to try the other direction of implicit cast to find the
            # most general type of the two.
            first, second = second, first  # swap operand order
            d = self.type_compatibility(fromty=first, toty=second)

        if d is None:
            return types.pyobject
        elif d == 'exact':
            # Same type
            return first
        elif d == 'promote':
            return second
        elif d in ('safe', 'unsafe'):
            if first in types.number_domain and second in types.number_domain:
                a = numpy.dtype(str(first))
                b = numpy.dtype(str(second))
                # Just use NumPy coercion rules
                sel = numpy.promote_types(a, b)
                # Convert NumPy dtype back to Numba types
                return getattr(types, str(sel))
            elif (isinstance(first, types.UniTuple)
                  and isinstance(second, types.UniTuple)):
                a = numpy.dtype(str(first.dtype))
                b = numpy.dtype(str(second.dtype))
                if a > b:
                    return first
                else:
                    return second
            else:
                msg = "unrecognized '{0}' unify for {1} and {2}"
                raise TypeError(msg.format(d, first, second))
        elif d == 'int-tuple-coerce':
            # Exact match required: ``d in 'int-tuple-coerce'`` would be a
            # substring test and also accept e.g. 'int' or 'tuple'.
            return types.UniTuple(dtype=types.intp, count=len(first))
        else:
            raise Exception("type_compatibility returned %s" % d)
Пример #3
0
def hashmap_lookup(typingctx, dict_type, key_type):
    """Typing/codegen pair that looks up a key in a native hashmap.

    NOTE(review): presumably registered as a Numba ``@intrinsic`` (the
    decorator is not visible here); ``typingctx`` is unused.

    Returns ``(signature, codegen)``.  The signature is
    ``Tuple(bool_, Optional(value_type))(dict_type, key_type)``: the bool is
    the status reported by the native lookup, and the optional carries the
    value only when that status is true (otherwise it is None).
    """

    ty_key, ty_val = dict_type.key_type, dict_type.value_type
    return_type = types.Tuple([types.bool_, types.Optional(ty_val)])
    # Suffixes select the type-specialised native symbol (see func_name).
    key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val)

    def codegen(context, builder, sig, args):
        dict_val, key_val = args

        # Convert the key into the form the native function accepts.
        key_val, lir_key_type = transform_input_arg(context, builder, key_type,
                                                    key_val)
        # Output slot the native function writes the found value into.
        native_value_ptr, lir_value_type = alloc_native_value(
            context, builder, ty_val)

        cdict = cgutils.create_struct_proxy(dict_type)(context,
                                                       builder,
                                                       value=dict_val)
        # Native signature: i8 fn(i8* data, key, value* out)
        fnty = lir.FunctionType(lir.IntType(8), [
            lir.IntType(8).as_pointer(), lir_key_type,
            lir_value_type.as_pointer()
        ])
        func_name = f"hashmap_lookup_{key_type_postfix}_to_{value_type_postfix}"
        fn_hashmap_lookup = cgutils.get_or_insert_function(builder.module,
                                                           fnty,
                                                           name=func_name)

        status = builder.call(fn_hashmap_lookup,
                              [cdict.data_ptr, key_val, native_value_ptr])
        status_as_bool = context.cast(builder, status, types.uint8,
                                      types.bool_)

        # if key was not found nothing would be stored to native_value_ptr, so depending on status
        # we either deref it or not, wrapping final result into types.Optional value
        result_ptr = cgutils.alloca_once(
            builder, context.get_value_type(types.Optional(ty_val)))
        with builder.if_else(status_as_bool,
                             likely=True) as (if_ok, if_not_ok):
            with if_ok:
                native_value = builder.load(native_value_ptr)
                # Convert from the native representation back to ty_val.
                result_value = transform_native_val(context, builder, ty_val,
                                                    native_value)

                # The returned optional holds a new reference to the value.
                if context.enable_nrt:
                    context.nrt.incref(builder, ty_val, result_value)

                builder.store(
                    context.make_optional_value(builder, ty_val, result_value),
                    result_ptr)

            with if_not_ok:
                builder.store(context.make_optional_none(builder, ty_val),
                              result_ptr)

        opt_result = builder.load(result_ptr)
        return context.make_tuple(builder, return_type,
                                  [status_as_bool, opt_result])

    func_sig = return_type(dict_type, key_type)
    return func_sig, codegen
Пример #4
0
def _list_getitem_pop_helper(typingctx, l, index, op):
    """Wrap numba_list_getitem and numba_list_pop

    Returns 2-tuple of (int32, ?item_type)

    This is a helper that is parametrized on the type of operation, which can
    be either 'pop' or 'getitem'. This is because, signature wise, getitem and
    pop are the same.

    The int32 is the status returned by the native function.  The optional
    holds the item only when that status is >= ``ListStatus.LIST_OK``.  When
    the list's item type is ``NoneType``, ``int64`` is used as a placeholder
    optional payload and no item is ever loaded, so the optional stays empty.
    """
    assert (op in ("pop", "getitem"))
    # False when the list holds NoneType items (placeholder case above).
    IS_NOT_NONE = not isinstance(l.item_type, types.NoneType)
    resty = types.Tuple([
        types.int32,
        types.Optional(l.item_type if IS_NOT_NONE else types.int64)
    ])
    sig = resty(l, index)

    def codegen(context, builder, sig, args):
        # Native signature: status fn(list, ssize_t index, bytes* out)
        fnty = ir.FunctionType(
            ll_status,
            [ll_list_type, ll_ssize_t, ll_bytes],
        )
        [tl, tindex] = sig.args
        [l, index] = args
        # Resolves to numba_list_pop or numba_list_getitem depending on op.
        fn = builder.module.get_or_insert_function(
            fnty, name='numba_list_{}'.format(op))

        dm_item = context.data_model_manager[tl.item_type]
        ll_item = context.get_data_type(tl.item_type)
        # Output slot the native function writes the item into.
        ptr_item = cgutils.alloca_once(builder, ll_item)

        lp = _container_get_data(context, builder, tl, l)
        status = builder.call(
            fn,
            [
                lp,
                index,
                _as_bytes(builder, ptr_item),
            ],
        )
        # Load item if output is available
        found = builder.icmp_signed('>=', status,
                                    status.type(int(ListStatus.LIST_OK)))
        # Default result: an empty optional, overwritten below on success.
        out = context.make_optional_none(
            builder, tl.item_type if IS_NOT_NONE else types.int64)
        pout = cgutils.alloca_once_value(builder, out)

        with builder.if_then(found):
            if IS_NOT_NONE:
                item = dm_item.load_from_data_pointer(builder, ptr_item)
                # The returned optional holds a new reference to the item.
                context.nrt.incref(builder, tl.item_type, item)
                loaded = context.make_optional_value(builder, tl.item_type,
                                                     item)
                builder.store(loaded, pout)

        out = builder.load(pout)
        return context.make_tuple(builder, resty, [status, out])

    return sig, codegen
Пример #5
0
    def test_optional_to_optional(self):
        """
        Regression test: casting between two Optional types must work.

        An identity function typed optional(intp) -> optional(float64) is
        compiled, and both the value path and the None path are checked.

        Related issue: https://github.com/numba/numba/issues/1718
        """
        src_ty = types.Optional(types.intp)
        dst_ty = types.Optional(types.float64)

        @njit(dst_ty(src_ty))
        def passthrough(a):
            return a

        self.assertEqual(passthrough(2), 2)
        self.assertIsNone(passthrough(None))
Пример #6
0
 def test_optional_tuple(self):
     pair_i32 = types.UniTuple(i32, 2)
     opt_pair_i32 = types.Optional(types.UniTuple(i32, 2))
     # none with a homogeneous tuple -> optional tuple
     self.assert_unify(types.none, pair_i32, opt_pair_i32)
     # optional narrower tuple with wider tuple -> optional wider tuple
     self.assert_unify(types.Optional(types.UniTuple(i16, 2)), pair_i32,
                       opt_pair_i32)
     # heterogeneous tuples unify element-wise into tuples of optionals
     self.assert_unify(
         types.Tuple((types.none, i32)),
         types.Tuple((i16, types.none)),
         types.Tuple((types.Optional(i16), types.Optional(i32))))
     self.assert_unify(
         types.Tuple((types.Optional(i32), i64)),
         types.Tuple((i16, types.Optional(i8))),
         types.Tuple((types.Optional(i32), types.Optional(i64))))
Пример #7
0
 def test_optional(self):
     # Scalar: int32 <-> optional(int32)
     opt_scalar = types.Optional(i32)
     self.assert_can_convert(types.none, opt_scalar, Conversion.promote)
     self.assert_can_convert(types.int32, opt_scalar, Conversion.promote)
     self.assert_cannot_convert(opt_scalar, types.none)
     self.assert_can_convert(opt_scalar, types.int32, Conversion.safe)  # XXX ???
     # Optional array with identical layout
     arr_c = types.Array(i32, 2, "C")
     opt_arr_c = types.Optional(arr_c)
     self.assert_can_convert(types.none, opt_arr_c, Conversion.promote)
     self.assert_can_convert(arr_c, opt_arr_c, Conversion.promote)
     self.assert_can_convert(opt_arr_c, arr_c, Conversion.safe)
     # Layout change: C -> A is safe, A -> C is not
     opt_arr_a = types.Optional(types.Array(i32, 2, "C").copy(layout="A"))
     self.assert_can_convert(arr_c, opt_arr_a, Conversion.safe)
     self.assert_cannot_convert(opt_arr_a, arr_c)
     # C and F layouts are incompatible in both directions
     opt_arr_f = types.Optional(types.Array(i32, 2, "C").copy(layout="F"))
     self.assert_cannot_convert(arr_c, opt_arr_f)
     self.assert_cannot_convert(opt_arr_f, arr_c)
Пример #8
0
def _dict_lookup(typingctx, d, key, hashval):
    """Wrap numba_dict_lookup

    Returns 2-tuple of (intp, ?value_type): the index returned by
    ``numba_dict_lookup`` and the value, which is present only when the
    returned index is greater than ``DKIX.EMPTY`` (i.e. the key was found).
    """
    resty = types.Tuple([types.intp, types.Optional(d.value_type)])
    sig = resty(d, key, hashval)

    def codegen(context, builder, sig, args):
        # Native signature: ssize_t fn(dict, bytes* key, hash, bytes* out)
        fnty = ir.FunctionType(
            ll_ssize_t,
            [ll_dict_type, ll_bytes, ll_hash, ll_bytes],
        )
        [td, tkey, thashval] = sig.args
        [d, key, hashval] = args
        fn = builder.module.get_or_insert_function(fnty,
                                                   name='numba_dict_lookup')

        dm_key = context.data_model_manager[tkey]
        dm_val = context.data_model_manager[td.value_type]

        # Spill the key to the stack so it can be passed by pointer.
        data_key = dm_key.as_data(builder, key)
        ptr_key = cgutils.alloca_once_value(builder, data_key)

        # Output slot the native function writes the value into on a hit.
        ll_val = context.get_data_type(td.value_type)
        ptr_val = cgutils.alloca_once(builder, ll_val)

        dp = _dict_get_data(context, builder, td, d)
        ix = builder.call(
            fn,
            [
                dp,
                _as_bytes(builder, ptr_key),
                hashval,
                _as_bytes(builder, ptr_val),
            ],
        )
        # Load value if output is available.  ``DKIX.EMPTY`` marks a missing
        # key, so only indices strictly greater than it are hits; ``>=``
        # would treat a miss as a hit and load (and incref) the
        # uninitialised ``ptr_val`` slot.
        found = builder.icmp_signed('>', ix, ix.type(int(DKIX.EMPTY)))

        # Default result: an empty optional, overwritten below on a hit.
        out = context.make_optional_none(builder, td.value_type)
        pout = cgutils.alloca_once_value(builder, out)

        with builder.if_then(found):
            val = dm_val.load_from_data_pointer(builder, ptr_val)
            # The returned optional holds a new reference to the value.
            context.nrt.incref(builder, td.value_type, val)
            loaded = context.make_optional_value(builder, td.value_type, val)
            builder.store(loaded, pout)

        out = builder.load(pout)
        return context.make_tuple(builder, resty, [ix, out])

    return sig, codegen
Пример #9
0
    def test_strip(self):
        """
        strip/lstrip/rstrip, with and without a ``chars`` argument, must
        match CPython; a float ``chars`` must be rejected at typing time.
        """
        strip_cases = [('ass cii', 'ai'), ('ass cii', None), ('asscii', 'ai '),
                       ('asscii ', 'ai '), (' asscii  ', 'ai '),
                       (' asscii  ', 'asci '), (' asscii  ', 's'),
                       ('      ', ' '), ('', ' '), ('', ''),
                       ('  asscii  ', 'ai '), ('  asscii  ', ''),
                       ('  asscii  ', None), ('tú quién te crees?', 'étú? '),
                       ('  tú quién te crees?   ', 'étú? '),
                       ('  tú qrees?   ', ''),
                       ('  tú quién te crees?   ', None),
                       ('大处 着眼,小处着手。大大大处', '大处'), (' 大处大处  ', ''),
                       (' 大处大处  ', None)]

        # Form with no ``chars`` parameter: only the strings are used.
        no_arg_variants = [(strip_usecase, 'strip'),
                           (lstrip_usecase, 'lstrip'),
                           (rstrip_usecase, 'rstrip')]
        for pyfunc, case_name in no_arg_variants:
            jitted = njit(pyfunc)
            for string, _unused_chars in strip_cases:
                self.assertEqual(pyfunc(string), jitted(string),
                                 "'%s'.%s()?" % (string, case_name))

        # Parametrized form.
        chars_variants = [(strip_usecase_chars, 'strip'),
                          (lstrip_usecase_chars, 'lstrip'),
                          (rstrip_usecase_chars, 'rstrip')]
        for pyfunc, case_name in chars_variants:
            jitted = njit(pyfunc)

            optional_sig = types.unicode_type(
                types.unicode_type, types.Optional(types.unicode_type))
            jitted_optional = njit([optional_sig])(pyfunc)

            def try_compile_bad_optional(*args):
                bad = types.unicode_type(types.unicode_type,
                                         types.Optional(types.float64))
                njit([bad])(pyfunc)

            # A float ``chars`` must fail typing in both the lazy dispatcher
            # and an explicitly typed compilation.
            for fn in (jitted, try_compile_bad_optional):
                with self.assertRaises(TypingError) as raises:
                    fn('tú quis?', 1.1)
                self.assertIn('The arg must be a UnicodeType or None',
                              str(raises.exception))

            for fn in (jitted, jitted_optional):
                for string, chars in strip_cases:
                    self.assertEqual(
                        pyfunc(string, chars), fn(string, chars),
                        "'%s'.%s('%s')?" % (string, case_name, chars))
Пример #10
0
def _dict_popitem(typingctx, d):
    """Wrap numba_dict_popitem

    Returns 2-tuple of (int32, ?(key_type, value_type)): the status returned
    by ``numba_dict_popitem`` and the popped key/value pair, which is present
    only when the status equals ``Status.OK``.
    """

    keyvalty = types.Tuple([d.key_type, d.value_type])
    resty = types.Tuple([types.int32, types.Optional(keyvalty)])
    sig = resty(d)

    def codegen(context, builder, sig, args):
        # Native signature: status fn(dict, bytes* key_out, bytes* val_out)
        fnty = ir.FunctionType(
            ll_status,
            [ll_dict_type, ll_bytes, ll_bytes],
        )
        [d] = args
        [td] = sig.args
        fn = builder.module.get_or_insert_function(fnty,
                                                   name='numba_dict_popitem')

        dm_key = context.data_model_manager[td.key_type]
        dm_val = context.data_model_manager[td.value_type]

        # Output slots the native function writes the popped key/value into.
        ptr_key = cgutils.alloca_once(builder, dm_key.get_data_type())
        ptr_val = cgutils.alloca_once(builder, dm_val.get_data_type())

        dp = _dict_get_data(context, builder, td, d)
        status = builder.call(
            fn,
            [
                dp,
                _as_bytes(builder, ptr_key),
                _as_bytes(builder, ptr_val),
            ],
        )
        # Default result: an empty optional, overwritten below on success.
        out = context.make_optional_none(builder, keyvalty)
        pout = cgutils.alloca_once_value(builder, out)

        cond = builder.icmp_signed('==', status, status.type(int(Status.OK)))
        with builder.if_then(cond):
            # NOTE(review): no incref here, presumably because the popped
            # references are transferred out of the dict by the C side.
            key = dm_key.load_from_data_pointer(builder, ptr_key)
            val = dm_val.load_from_data_pointer(builder, ptr_val)
            keyval = context.make_tuple(builder, keyvalty, [key, val])
            optkeyval = context.make_optional_value(builder, keyvalty, keyval)
            builder.store(optkeyval, pout)

        out = builder.load(pout)
        return cgutils.pack_struct(builder, [status, out])

    return sig, codegen
Пример #11
0
 def test_none_to_optional(self):
     """
     Unifying `none` with two number types, in every argument order, must
     give the Optional of the two numbers' unified type.
     """
     ctx = typing.Context()
     for pair in itertools.combinations(types.number_domain, 2):
         numbers = list(pair)
         # Control value: unify the numbers alone, then make it optional.
         expected = types.Optional(ctx.unify_types(*numbers))
         with_none = numbers + [types.none]
         results = [ctx.unify_types(*order)
                    for order in itertools.permutations(with_none)]
         # Argument order must not matter.
         for got in results:
             self.assertEqual(got, expected)
Пример #12
0
 def test_optional(self):
     opt_i32 = types.Optional(i32)
     opt_i64 = types.Optional(i64)
     # optional(int32) absorbs none
     self.assert_unify(opt_i32, types.none, opt_i32)
     # two optionals unify to the wider optional
     self.assert_unify(opt_i32, opt_i64, opt_i64)
     # optional with a wider plain scalar -> optional of that scalar
     self.assert_unify(opt_i32, i64, types.Optional(i64))
     # Failure: optional(int32) vs optional(len_type)
     self.assert_unify_failure(opt_i32, types.Optional(types.len_type))
Пример #13
0
    def test_optional_unpack(self):
        """
        Tuple-unpacking an optional 2-tuple after a None check (issue 2171).
        """
        def pyfunc(x):
            if x is None:
                return
            else:
                a, b = x
                return a, b

        pair_ty = types.Tuple([types.intp] * 2)
        cfunc = njit((types.Optional(pair_ty), ))(pyfunc)
        for arg in (None, (1, 2)):
            self.assertEqual(pyfunc(arg), cfunc(arg))
Пример #14
0
    def get_return_type(self, typemap):
        """
        Infer the function's return type from its ``return`` terminators.

        The types of all returned values are unified.  If some paths return
        None, the result is ``Optional`` of the unified remaining types, or
        plain ``none`` when None is the only returned type.
        """
        returned = set()
        for block in utils.dict_itervalues(self.blocks):
            last = block.terminator
            if not isinstance(last, ir.Return):
                continue
            returned.add(typemap[last.value.name])

        if types.none not in returned:
            return self.context.unify_types(*returned)

        # Special case None return
        non_none = returned - set([types.none])
        if not non_none:
            return types.none
        return types.Optional(self.context.unify_types(*non_none))
Пример #15
0
 def try_compile_bad_optional(*args):
     # Compile `pyfunc` with a float64 optional second argument instead of a
     # unicode one; `args` is ignored.
     float_opt = types.Optional(types.float64)
     njit([types.unicode_type(types.unicode_type, float_opt)])(pyfunc)
Пример #16
0
    def codegen(context, builder, sig, args):
        # Emits IR that pops a key from a native hashmap and returns a
        # (bool status, Optional(ty_val)) tuple.  Closes over key_type,
        # ty_val, dict_type, key_type_postfix, value_type_postfix and
        # return_type from an enclosing scope not visible here.
        dict_val, key_val = args

        # Convert the key into the form the native function accepts.
        key_val, lir_key_type = transform_input_arg(context, builder, key_type,
                                                    key_val)

        # Unlike in the lookup operation, we allocate the value slot here and
        # pass the native function a voidptr to it; the native side copies the
        # value into it and frees its own copy.
        # Numbers use the native representation from alloc_native_value
        # (converted back to ty_val below); other types use a slot of the
        # Numba value type directly.
        if isinstance(ty_val, types.Number):
            ret_val_ptr, lir_val_type = alloc_native_value(
                context, builder, ty_val)
        else:
            lir_val_type = context.get_value_type(ty_val)
            ret_val_ptr = cgutils.alloca_once(builder, lir_val_type)

        llvoidptr = context.get_value_type(types.voidptr)
        ret_val_ptr = builder.bitcast(ret_val_ptr, llvoidptr)

        cdict = cgutils.create_struct_proxy(dict_type)(context,
                                                       builder,
                                                       value=dict_val)
        # Native signature: i8 fn(i8* data, key, void* out)
        fnty = lir.FunctionType(lir.IntType(8), [
            lir.IntType(8).as_pointer(),
            lir_key_type,
            llvoidptr,
        ])
        func_name = f"hashmap_pop_{key_type_postfix}_to_{value_type_postfix}"
        fn_hashmap_pop = cgutils.get_or_insert_function(builder.module,
                                                        fnty,
                                                        name=func_name)

        status = builder.call(fn_hashmap_pop,
                              [cdict.data_ptr, key_val, ret_val_ptr])
        status_as_bool = context.cast(builder, status, types.uint8,
                                      types.bool_)

        # same logic to handle non-existing key as in hashmap_lookup
        result_ptr = cgutils.alloca_once(
            builder, context.get_value_type(types.Optional(ty_val)))
        with builder.if_else(status_as_bool,
                             likely=True) as (if_ok, if_not_ok):
            with if_ok:

                # Reinterpret the voidptr slot as a pointer to the value type.
                ret_val_ptr = builder.bitcast(ret_val_ptr,
                                              lir_val_type.as_pointer())
                native_value = builder.load(ret_val_ptr)
                if isinstance(ty_val, types.Number):
                    # Widen/convert the stored native number back to ty_val.
                    reduced_value_type = reduced_type_map.get(ty_val, ty_val)
                    native_value = context.cast(builder, native_value,
                                                reduced_value_type, ty_val)

                # no incref of the value here, since it was removed from the dict
                # w/o decref to consider the case when value in the dict had refcnt == 1

                builder.store(
                    context.make_optional_value(builder, ty_val, native_value),
                    result_ptr)

            with if_not_ok:
                builder.store(context.make_optional_none(builder, ty_val),
                              result_ptr)

        opt_result = builder.load(result_ptr)
        return context.make_tuple(builder, return_type,
                                  [status_as_bool, opt_result])
Пример #17
0
 def try_compile_bad_optional(*args):
     # Compile `pyfunc` with float64 optional start/end bounds; `args` is
     # ignored.
     float_opt = types.Optional(types.float64)
     bad_sig = types.int64(types.unicode_type, types.unicode_type,
                           float_opt, float_opt)
     njit([bad_sig])(pyfunc)
Пример #18
0
def make_optional(valtype):
    """
    Build the struct proxy class describing ``Optional(valtype)``.

    The class comes from ``cgutils.create_struct_proxy``.
    """
    optional_ty = types.Optional(valtype)
    return cgutils.create_struct_proxy(optional_ty)
Пример #19
0
 def test_optional(self):
     # An Optional type must survive a pickle round-trip.
     self.check_pickling(types.Optional(types.int32))
Пример #20
0
 def test_cleanup_optional(self):
     # An Optional buffer argument, backed by a memoryview.
     buf_ty = types.Optional(types.Buffer(types.intc, 1, 'C'))
     arg = memoryview(bytearray(b"xyz"))
     self.check_argument_cleanup(buf_ty, arg)
Пример #21
0
 def try_compile_wrong_end_optional(*args):
     # Compile rfind with an integer optional start but a float64 optional
     # end (presumably expected to fail typing); `args` is ignored.
     start_ty = types.Optional(types.intp)
     end_ty = types.Optional(types.float64)
     sig = types.int64(types.unicode_type, types.unicode_type,
                       start_ty, end_ty)
     njit([sig])(rfind_with_start_end_usecase)
Пример #22
0
 def make_optional_none(self, builder, valtype):
     """Return an ``Optional(valtype)`` IR value whose valid bit is cleared."""
     helper = self.make_helper(builder, types.Optional(valtype))
     helper.valid = cgutils.false_bit
     return helper._getvalue()
Пример #23
0
 def make_optional_value(self, builder, valtype, value):
     """Return an ``Optional(valtype)`` IR value holding *value* (valid bit set)."""
     helper = self.make_helper(builder, types.Optional(valtype))
     helper.valid = cgutils.true_bit
     helper.data = value
     return helper._getvalue()
Пример #24
0
import numpy as np
from numba import types, jit, prange

from stratego_env.game.stratego_procedural_impl import _get_state_from_player_perspective, _get_valid_moves_as_1d_mask, \
    _get_partially_observable_observation, _get_fully_observable_observation, _get_valid_moves_as_spatial_mask, \
    _get_action_1d_index_from_player_perspective, _get_next_state, _get_game_ended, _get_partially_observable_observation_extended_channels, \
    _get_fully_observable_observation_extended_channels, _get_heuristic_rewards_from_move

# Numba type strings used to build the @jit signatures in this module.
# In `_fill_observation`, STSH types `state`, VTSH types `valid_actions_mask`,
# and OTSH types the observation/internal-state outputs.
# NOTE(review): the name meanings (S=state, O=observation, V=valid actions,
# M=multiple/batched, C=optional) are inferred from usage only -- confirm.

# 3-D C-contiguous int64 array.
STSH = "int64[:, :, ::1]"
# 3-D C-contiguous float32 array.
OTSH = "float32[:, :, ::1]"
# Optional 3-D C-contiguous float32 array (a Numba type object, not a string).
COTSH = types.Optional(types.float32[:, :, ::1])
# 4-D C-contiguous int64 array.
MSTSH = "int64[:, :, :, ::1]"
# 5-D C-contiguous float32 array.
# NOTE(review): MSTSH is 4-D while MOTSH/MVTSH are 5-D -- confirm intended.
MOTSH = "float32[:, :, :, :, ::1]"

# 3-D C-contiguous int64 array (same string as STSH).
VTSH = "int64[:, :, ::1]"
# 5-D C-contiguous int64 array.
MVTSH = "int64[:, :, :, :, ::1]"


@jit(
    f"void(int64, int64, int64, int64, int64, "
    f"float32[:, :, ::1], float32[:, :, ::1], float32[:, :, ::1], float32[:, :, ::1], boolean, boolean, boolean, boolean,"
    f"{STSH}, {VTSH}, {OTSH}, {OTSH}, {OTSH})",
    nopython=True,
    cache=True)
def _fill_observation(player, action_size, max_actions, rows, columns,
                      p_obs_mids, p_ops_range, f_obs_mids, f_obs_range,
                      partial_out, full_out, internal_out, extended_channels,
                      state, valid_actions_mask, partial_observation,
                      full_observation, internal_state):
    pstate = _get_state_from_player_perspective(state=state, player=player)