def test_count_optional_arg_type_check(self):
    """count() with start/end arguments typed as Optional.

    An Optional(float64) signature for the indices must fail during typing,
    while an Optional(int64) signature must compile and agree with Python.
    """
    pyfunc = count_with_start_end_usecase

    def try_compile_bad_optional(*args):
        # Eager compilation with float-valued optional indices.
        float_or_none = types.Optional(types.float64)
        bad_sig = types.int64(types.unicode_type, types.unicode_type,
                              float_or_none, float_or_none)
        njit([bad_sig])(pyfunc)

    with self.assertRaises(TypingError) as raises:
        try_compile_bad_optional('tú quis?', 'tú', 1.1, 1.1)
    self.assertIn('The slice indices must be an Integer or None',
                  str(raises.exception))

    error_msg = "%s\n%s" % ("'{0}'.py_count('{1}', {2}, {3}) = {4}",
                            "'{0}'.c_count_op('{1}', {2}, {3}) = {5}")
    int_or_none = types.Optional(types.int64)
    sig_optional = types.int64(types.unicode_type, types.unicode_type,
                               int_or_none, int_or_none)
    cfunc_optional = njit([sig_optional])(pyfunc)

    call_args = ('tú quis?', 'tú', 0, 8)
    py_result = pyfunc(*call_args)
    c_result = cfunc_optional(*call_args)
    self.assertEqual(py_result, c_result,
                     error_msg.format(*call_args, py_result, c_result))
def unify_pairs(self, first, second):
    """
    Choose PyObject type as the abstract if we fail to determine a concrete
    type.

    Unify two types into a single type that both can be converted to, based
    on the conversion kind reported by ``self.type_compatibility``.

    Parameters
    ----------
    first, second : Numba types to unify.

    Returns
    -------
    The unified type; ``types.pyobject`` when no conversion exists in either
    direction.

    Raises
    ------
    TypeError
        If a 'safe'/'unsafe' conversion is reported for a pair of types this
        function has no unification rule for.
    Exception
        If ``type_compatibility`` returns an unknown conversion kind.
    """
    # TODO: should add an option to reject unsafe type conversion
    if types.none in (first, second):
        if first == types.none:
            return types.Optional(second)
        elif second == types.none:
            return types.Optional(first)

    # Handle optional type
    # XXX: really need to refactor type infer to reduce the number of
    # special cases
    if (isinstance(first, types.Optional) or
            isinstance(second, types.Optional)):
        a = (first.type if isinstance(first, types.Optional) else first)
        b = (second.type if isinstance(second, types.Optional) else second)
        return types.Optional(self.unify_pairs(a, b))

    d = self.type_compatibility(fromty=first, toty=second)
    if d is None:
        # Complex is not allowed to downcast implicitly.
        # Need to try the other direction of implicit cast to find the
        # most general type of the two.
        first, second = second, first  # swap operand order
        d = self.type_compatibility(fromty=first, toty=second)

    if d is None:
        return types.pyobject
    elif d == 'exact':
        # Same type
        return first
    elif d == 'promote':
        return second
    elif d in ('safe', 'unsafe'):
        if first in types.number_domain and second in types.number_domain:
            a = numpy.dtype(str(first))
            b = numpy.dtype(str(second))
            # Just use NumPy coercion rules
            sel = numpy.promote_types(a, b)
            # Convert NumPy dtype back to Numba types
            return getattr(types, str(sel))
        elif (isinstance(first, types.UniTuple) and
                isinstance(second, types.UniTuple)):
            a = numpy.dtype(str(first.dtype))
            b = numpy.dtype(str(second.dtype))
            # NumPy dtype ordering: ``a > b`` holds when ``b`` can be safely
            # cast to ``a`` (and they differ), i.e. ``a`` is the wider dtype.
            if a > b:
                return first
            else:
                return second
        else:
            msg = "unrecognized '{0}' unify for {1} and {2}"
            raise TypeError(msg.format(d, first, second))
    elif d == 'int-tuple-coerce':
        # Exact comparison: the previous ``d in 'int-tuple-coerce'`` was a
        # substring test that also matched e.g. 'int', 'tuple' or ''.
        return types.UniTuple(dtype=types.intp, count=len(first))
    else:
        raise Exception("type_compatibility returned %s" % d)
def hashmap_lookup(typingctx, dict_type, key_type):
    """Typing and codegen for a lookup in a native hashmap-backed dict.

    The produced signature is ``(dict_type, key_type) ->
    Tuple(bool_, Optional(value_type))``: the first element is the status
    returned by the native ``hashmap_lookup_<key>_to_<value>`` function,
    the second is the value when the status is true, otherwise none.

    Returns
    -------
    (signature, codegen) pair.
    """
    ty_key, ty_val = dict_type.key_type, dict_type.value_type
    return_type = types.Tuple([types.bool_, types.Optional(ty_val)])
    # Suffixes select the type-specialized native entry point by name.
    key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val)

    def codegen(context, builder, sig, args):
        dict_val, key_val = args

        key_val, lir_key_type = transform_input_arg(context, builder, key_type, key_val)
        # Out-parameter slot the native function writes the found value into.
        native_value_ptr, lir_value_type = alloc_native_value(
            context, builder, ty_val)

        cdict = cgutils.create_struct_proxy(dict_type)(context, builder, value=dict_val)
        # Native signature: int8 status(i8* data_ptr, key, value* out)
        fnty = lir.FunctionType(lir.IntType(8), [
            lir.IntType(8).as_pointer(),
            lir_key_type,
            lir_value_type.as_pointer()
        ])
        func_name = f"hashmap_lookup_{key_type_postfix}_to_{value_type_postfix}"
        fn_hashmap_lookup = cgutils.get_or_insert_function(builder.module, fnty, name=func_name)

        status = builder.call(fn_hashmap_lookup, [cdict.data_ptr, key_val, native_value_ptr])
        status_as_bool = context.cast(builder, status, types.uint8, types.bool_)

        # if key was not found nothing would be stored to native_value_ptr, so depending on status
        # we either deref it or not, wrapping final result into types.Optional value
        result_ptr = cgutils.alloca_once(
            builder, context.get_value_type(types.Optional(ty_val)))
        with builder.if_else(status_as_bool, likely=True) as (if_ok, if_not_ok):
            with if_ok:
                native_value = builder.load(native_value_ptr)
                result_value = transform_native_val(context, builder, ty_val, native_value)

                # The value stays in the dict (unlike pop), so the returned
                # copy takes its own reference when NRT is enabled.
                if context.enable_nrt:
                    context.nrt.incref(builder, ty_val, result_value)

                builder.store(
                    context.make_optional_value(builder, ty_val, result_value),
                    result_ptr)

            with if_not_ok:
                builder.store(context.make_optional_none(builder, ty_val), result_ptr)

        opt_result = builder.load(result_ptr)
        return context.make_tuple(builder, return_type, [status_as_bool, opt_result])

    func_sig = return_type(dict_type, key_type)
    return func_sig, codegen
def _list_getitem_pop_helper(typingctx, l, index, op):
    """Wrap numba_list_getitem and numba_list_pop

    Returns 2-tuple of (int32, ?item_type)

    This is a helper that is parametrized on the type of operation, which can
    be either 'pop' or 'getitem'. This is because, signature wise, getitem and
    pop are the same.

    When the list's item type is ``NoneType`` there is no payload to load;
    the optional slot is then typed as a placeholder ``Optional(int64)`` and
    is always returned as none.
    """
    assert (op in ("pop", "getitem"))
    IS_NOT_NONE = not isinstance(l.item_type, types.NoneType)
    resty = types.Tuple([
        types.int32,
        types.Optional(l.item_type if IS_NOT_NONE else types.int64)
    ])
    sig = resty(l, index)

    def codegen(context, builder, sig, args):
        # Native signature: status(list, ssize_t index, void* out_item)
        fnty = ir.FunctionType(
            ll_status,
            [ll_list_type, ll_ssize_t, ll_bytes],
        )
        [tl, tindex] = sig.args
        [l, index] = args
        # Resolves to either numba_list_pop or numba_list_getitem.
        fn = builder.module.get_or_insert_function(
            fnty, name='numba_list_{}'.format(op))

        dm_item = context.data_model_manager[tl.item_type]
        ll_item = context.get_data_type(tl.item_type)
        # Out-parameter slot the native function writes the item into.
        ptr_item = cgutils.alloca_once(builder, ll_item)

        lp = _container_get_data(context, builder, tl, l)
        status = builder.call(
            fn,
            [
                lp,
                index,
                _as_bytes(builder, ptr_item),
            ],
        )
        # Load item if output is available
        found = builder.icmp_signed('>=', status,
                                    status.type(int(ListStatus.LIST_OK)))
        # Default result is none; overwritten below only on success.
        out = context.make_optional_none(
            builder, tl.item_type if IS_NOT_NONE else types.int64)
        pout = cgutils.alloca_once_value(builder, out)

        with builder.if_then(found):
            if IS_NOT_NONE:
                item = dm_item.load_from_data_pointer(builder, ptr_item)
                # Take a new reference to the loaded item for the caller.
                context.nrt.incref(builder, tl.item_type, item)
                loaded = context.make_optional_value(
                    builder, tl.item_type, item)
                builder.store(loaded, pout)

        out = builder.load(pout)
        return context.make_tuple(builder, resty, [status, out])

    return sig, codegen
def test_optional_to_optional(self):
    """
    Regression test for mishandled Optional -> Optional casting.

    Related issue: https://github.com/numba/numba/issues/1718
    """
    # The signature forces an optional(intp) -> optional(float64) cast.
    sig = types.Optional(types.float64)(types.Optional(types.intp))

    @njit(sig)
    def foo(a):
        return a

    self.assertEqual(foo(2), 2)
    self.assertIsNone(foo(None))
def test_optional_tuple(self):
    """Unification of tuples with none/Optional components."""
    pair_i32 = types.UniTuple(i32, 2)
    opt_pair_i32 = types.Optional(pair_i32)

    # Unify to optional tuple
    self.assert_unify(types.none, pair_i32, opt_pair_i32)
    self.assert_unify(types.Optional(types.UniTuple(i16, 2)), pair_i32,
                      opt_pair_i32)

    # Unify to tuple of optionals
    self.assert_unify(
        types.Tuple((types.none, i32)),
        types.Tuple((i16, types.none)),
        types.Tuple((types.Optional(i16), types.Optional(i32))))
    self.assert_unify(
        types.Tuple((types.Optional(i32), i64)),
        types.Tuple((i16, types.Optional(i8))),
        types.Tuple((types.Optional(i32), types.Optional(i64))))
def test_optional(self):
    """Conversion rules into and out of Optional scalars and arrays."""
    # Optional scalar
    scalar = types.int32
    opt_scalar = types.Optional(i32)
    self.assert_can_convert(types.none, opt_scalar, Conversion.promote)
    self.assert_can_convert(scalar, opt_scalar, Conversion.promote)
    self.assert_cannot_convert(opt_scalar, types.none)
    self.assert_can_convert(opt_scalar, scalar, Conversion.safe)  # XXX ???

    # Optional array, same layout
    c_array = types.Array(i32, 2, "C")
    opt_c_array = types.Optional(c_array)
    self.assert_can_convert(types.none, opt_c_array, Conversion.promote)
    self.assert_can_convert(c_array, opt_c_array, Conversion.promote)
    self.assert_can_convert(opt_c_array, c_array, Conversion.safe)

    # Optional array, layout "A"
    opt_a_array = types.Optional(c_array.copy(layout="A"))
    self.assert_can_convert(c_array, opt_a_array, Conversion.safe)  # C -> A
    self.assert_cannot_convert(opt_a_array, c_array)  # A -> C

    # Optional array, layout "F": no conversion either way
    opt_f_array = types.Optional(c_array.copy(layout="F"))
    self.assert_cannot_convert(c_array, opt_f_array)
    self.assert_cannot_convert(opt_f_array, c_array)
def _dict_lookup(typingctx, d, key, hashval):
    """Wrap numba_dict_lookup

    Returns 2-tuple of (intp, ?value_type)

    The first element is the index reported by ``numba_dict_lookup``; the
    second holds the value only when the lookup hit, otherwise none.
    """
    resty = types.Tuple([types.intp, types.Optional(d.value_type)])
    sig = resty(d, key, hashval)

    def codegen(context, builder, sig, args):
        # Native signature: ssize_t(dict, void* key, hash, void* out_value)
        fnty = ir.FunctionType(
            ll_ssize_t,
            [ll_dict_type, ll_bytes, ll_hash, ll_bytes],
        )
        [td, tkey, thashval] = sig.args
        [d, key, hashval] = args
        fn = builder.module.get_or_insert_function(fnty,
                                                   name='numba_dict_lookup')

        dm_key = context.data_model_manager[tkey]
        dm_val = context.data_model_manager[td.value_type]

        # The key is passed by pointer in its data representation.
        data_key = dm_key.as_data(builder, key)
        ptr_key = cgutils.alloca_once_value(builder, data_key)

        # Out-parameter slot the native function writes the value into.
        ll_val = context.get_data_type(td.value_type)
        ptr_val = cgutils.alloca_once(builder, ll_val)

        dp = _dict_get_data(context, builder, td, d)
        ix = builder.call(
            fn,
            [
                dp,
                _as_bytes(builder, ptr_key),
                hashval,
                _as_bytes(builder, ptr_val),
            ],
        )
        # Load value if output is available.  DKIX.EMPTY is the
        # "key not found" sentinel (error codes are more negative), so only
        # indices strictly greater than it are hits.  Using '>=' would treat
        # a miss as a hit and load/incref the uninitialized ``ptr_val``.
        found = builder.icmp_signed('>', ix, ix.type(int(DKIX.EMPTY)))

        out = context.make_optional_none(builder, td.value_type)
        pout = cgutils.alloca_once_value(builder, out)

        with builder.if_then(found):
            val = dm_val.load_from_data_pointer(builder, ptr_val)
            # The value remains owned by the dict; the caller gets its own
            # reference.
            context.nrt.incref(builder, td.value_type, val)
            loaded = context.make_optional_value(builder, td.value_type, val)
            builder.store(loaded, pout)

        out = builder.load(pout)
        return context.make_tuple(builder, resty, [ix, out])

    return sig, codegen
def test_strip(self):
    """Compare strip/lstrip/rstrip between Python and njit-compiled code.

    Covers the no-argument form and the ``chars`` form, including an
    Optional(unicode) signature and rejection of non-string ``chars``.
    """
    STRIP_CASES = [
        ('ass cii', 'ai'),
        ('ass cii', None),
        ('asscii', 'ai '),
        ('asscii ', 'ai '),
        (' asscii ', 'ai '),
        (' asscii ', 'asci '),
        (' asscii ', 's'),
        (' ', ' '),
        ('', ' '),
        ('', ''),
        (' asscii ', 'ai '),
        (' asscii ', ''),
        (' asscii ', None),
        ('tú quién te crees?', 'étú? '),
        (' tú quién te crees? ', 'étú? '),
        (' tú qrees? ', ''),
        (' tú quién te crees? ', None),
        ('大处 着眼,小处着手。大大大处', '大处'),
        (' 大处大处 ', ''),
        (' 大处大处 ', None),
    ]

    # form with no parameter
    no_arg_usecases = (
        (strip_usecase, 'strip'),
        (lstrip_usecase, 'lstrip'),
        (rstrip_usecase, 'rstrip'),
    )
    for pyfunc, case_name in no_arg_usecases:
        cfunc = njit(pyfunc)
        for string, _unused_chars in STRIP_CASES:
            self.assertEqual(pyfunc(string), cfunc(string),
                             "'%s'.%s()?" % (string, case_name))

    # parametrized form
    chars_usecases = (
        (strip_usecase_chars, 'strip'),
        (lstrip_usecase_chars, 'lstrip'),
        (rstrip_usecase_chars, 'rstrip'),
    )
    for pyfunc, case_name in chars_usecases:
        cfunc = njit(pyfunc)
        sig1 = types.unicode_type(types.unicode_type,
                                  types.Optional(types.unicode_type))
        cfunc_optional = njit([sig1])(pyfunc)

        def try_compile_bad_optional(*args):
            bad = types.unicode_type(types.unicode_type,
                                     types.Optional(types.float64))
            njit([bad])(pyfunc)

        # A float ``chars`` must be rejected, whether it comes from the
        # call site or from an explicit Optional(float64) signature.
        for fn in (cfunc, try_compile_bad_optional):
            with self.assertRaises(TypingError) as raises:
                fn('tú quis?', 1.1)
            self.assertIn('The arg must be a UnicodeType or None',
                          str(raises.exception))

        for fn in (cfunc, cfunc_optional):
            for string, chars in STRIP_CASES:
                self.assertEqual(
                    pyfunc(string, chars), fn(string, chars),
                    "'%s'.%s('%s')?" % (string, case_name, chars))
def _dict_popitem(typingctx, d):
    """Wrap numba_dict_popitem

    Returns 2-tuple of (int32 status, ?(key, value)). The optional tuple is
    populated only when the status equals ``Status.OK``; otherwise it is none.
    """

    keyvalty = types.Tuple([d.key_type, d.value_type])
    resty = types.Tuple([types.int32, types.Optional(keyvalty)])
    sig = resty(d)

    def codegen(context, builder, sig, args):
        # Native signature: status(dict, void* out_key, void* out_value)
        fnty = ir.FunctionType(
            ll_status,
            [ll_dict_type, ll_bytes, ll_bytes],
        )
        [d] = args
        [td] = sig.args
        fn = builder.module.get_or_insert_function(fnty,
                                                   name='numba_dict_popitem')

        dm_key = context.data_model_manager[td.key_type]
        dm_val = context.data_model_manager[td.value_type]

        # Out-parameter slots the native function writes key/value into.
        ptr_key = cgutils.alloca_once(builder, dm_key.get_data_type())
        ptr_val = cgutils.alloca_once(builder, dm_val.get_data_type())

        dp = _dict_get_data(context, builder, td, d)
        status = builder.call(
            fn,
            [
                dp,
                _as_bytes(builder, ptr_key),
                _as_bytes(builder, ptr_val),
            ],
        )
        # Default result is none; overwritten below only on success.
        out = context.make_optional_none(builder, keyvalty)
        pout = cgutils.alloca_once_value(builder, out)

        cond = builder.icmp_signed('==', status, status.type(int(Status.OK)))
        with builder.if_then(cond):
            # No incref here: the loaded key/value are handed to the caller
            # as-is (NOTE(review): presumably the dict released ownership on
            # pop — confirm against numba_dict_popitem).
            key = dm_key.load_from_data_pointer(builder, ptr_key)
            val = dm_val.load_from_data_pointer(builder, ptr_val)
            keyval = context.make_tuple(builder, keyvalty, [key, val])
            optkeyval = context.make_optional_value(builder, keyvalty, keyval)
            builder.store(optkeyval, pout)

        out = builder.load(pout)
        return cgutils.pack_struct(builder, [status, out])

    return sig, codegen
def test_none_to_optional(self):
    """
    Unifying `none` with two number types must yield the Optional of the
    unified number type, regardless of argument order.
    """
    ctx = typing.Context()
    for pair in itertools.combinations(types.number_domain, 2):
        candidates = list(pair)
        # Control value: unify the numbers alone, then wrap in Optional.
        expected = types.Optional(ctx.unify_types(*candidates))
        candidates.append(types.none)
        # Every ordering of (num, num, none) must give the same answer.
        for comb in itertools.permutations(candidates):
            self.assertEqual(ctx.unify_types(*comb), expected)
def test_optional(self):
    """Unification rules for Optional against none, Optional and scalars."""
    opt_i32 = types.Optional(i32)
    opt_i64 = types.Optional(i64)

    self.assert_unify(opt_i32, types.none, opt_i32)
    self.assert_unify(opt_i32, opt_i64, opt_i64)
    self.assert_unify(opt_i32, i64, opt_i64)

    # Failure: no common type between int32 and len_type
    self.assert_unify_failure(opt_i32, types.Optional(types.len_type))
def test_optional_unpack(self):
    """
    Regression test for issue 2171: unpacking an Optional tuple.
    """
    # Kept as-is: this exact control-flow shape is the regression case.
    def pyfunc(x):
        if x is None:
            return
        else:
            a, b = x
            return a, b

    opt_tup = types.Optional(types.Tuple([types.intp] * 2))
    cfunc = njit((opt_tup, ))(pyfunc)

    for arg in (None, (1, 2)):
        self.assertEqual(pyfunc(arg), cfunc(arg))
def get_return_type(self, typemap):
    """Infer the return type from every ``ir.Return`` terminator.

    If ``none`` is among the returned types, the remaining types are unified
    and wrapped in ``Optional``; if ``none`` is the only one, ``types.none``
    is returned.
    """
    terminators = (blk.terminator
                   for blk in utils.dict_itervalues(self.blocks))
    returned = {typemap[term.value.name]
                for term in terminators
                if isinstance(term, ir.Return)}

    if types.none not in returned:
        return self.context.unify_types(*returned)

    # Special case None return
    others = returned - {types.none}
    if not others:
        return types.none
    return types.Optional(self.context.unify_types(*others))
def try_compile_bad_optional(*args):
    """Eagerly compile ``pyfunc`` with an Optional(float64) second argument.

    ``args`` are ignored; they only let this be called like the function
    under test.
    """
    float_or_none = types.Optional(types.float64)
    bad = types.unicode_type(types.unicode_type, float_or_none)
    njit([bad])(pyfunc)
def codegen(context, builder, sig, args):
    """Emit IR for popping a key from a native hashmap-backed dict.

    Produces ``(status_as_bool, Optional(value))``. ``key_type``, ``ty_val``,
    ``dict_type``, ``return_type``, the type postfixes and
    ``reduced_type_map`` are free names from the enclosing scope (not
    visible here).
    """
    dict_val, key_val = args

    key_val, lir_key_type = transform_input_arg(context, builder, key_type, key_val)

    # unlike in lookup operation we allocate value here and pass into native function
    # voidptr to allocated data, which copies and frees its copy
    if isinstance(ty_val, types.Number):
        ret_val_ptr, lir_val_type = alloc_native_value(
            context, builder, ty_val)
    else:
        lir_val_type = context.get_value_type(ty_val)
        ret_val_ptr = cgutils.alloca_once(builder, lir_val_type)

    llvoidptr = context.get_value_type(types.voidptr)
    ret_val_ptr = builder.bitcast(ret_val_ptr, llvoidptr)

    cdict = cgutils.create_struct_proxy(dict_type)(context, builder, value=dict_val)
    # Native signature: int8 status(i8* data_ptr, key, void* out_value)
    fnty = lir.FunctionType(lir.IntType(8), [
        lir.IntType(8).as_pointer(),
        lir_key_type,
        llvoidptr,
    ])
    func_name = f"hashmap_pop_{key_type_postfix}_to_{value_type_postfix}"
    fn_hashmap_pop = cgutils.get_or_insert_function(builder.module, fnty, name=func_name)

    status = builder.call(fn_hashmap_pop, [cdict.data_ptr, key_val, ret_val_ptr])
    status_as_bool = context.cast(builder, status, types.uint8, types.bool_)

    # same logic to handle non-existing key as in hashmap_lookup
    result_ptr = cgutils.alloca_once(
        builder, context.get_value_type(types.Optional(ty_val)))
    with builder.if_else(status_as_bool, likely=True) as (if_ok, if_not_ok):
        with if_ok:
            # Undo the voidptr cast to read the value back in its own type.
            ret_val_ptr = builder.bitcast(ret_val_ptr, lir_val_type.as_pointer())
            native_value = builder.load(ret_val_ptr)
            if isinstance(ty_val, types.Number):
                # Numbers may be stored in a reduced native type; widen
                # back to ty_val (identity when no mapping exists).
                reduced_value_type = reduced_type_map.get(ty_val, ty_val)
                native_value = context.cast(builder, native_value, reduced_value_type, ty_val)

            # no incref of the value here, since it was removed from the dict
            # w/o decref to consider the case when value in the dict had refcnt == 1
            builder.store(
                context.make_optional_value(builder, ty_val, native_value),
                result_ptr)

        with if_not_ok:
            builder.store(context.make_optional_none(builder, ty_val), result_ptr)

    opt_result = builder.load(result_ptr)
    return context.make_tuple(builder, return_type, [status_as_bool, opt_result])
def try_compile_bad_optional(*args):
    """Eagerly compile ``pyfunc`` with Optional(float64) start/end indices.

    ``args`` are ignored; they only let this be called like the function
    under test.
    """
    float_or_none = types.Optional(types.float64)
    bad_sig = types.int64(types.unicode_type, types.unicode_type,
                          float_or_none, float_or_none)
    njit([bad_sig])(pyfunc)
def make_optional(valtype):
    """
    Return the Structure representation of an optional value of ``valtype``.
    """
    optional_type = types.Optional(valtype)
    return cgutils.create_struct_proxy(optional_type)
def test_optional(self):
    """An Optional(int32) type must survive pickling."""
    self.check_pickling(types.Optional(types.int32))
def test_cleanup_optional(self):
    """Argument cleanup for a buffer passed as an Optional(Buffer) arg."""
    arg = memoryview(bytearray(b"xyz"))
    buffer_type = types.Buffer(types.intc, 1, 'C')
    self.check_argument_cleanup(types.Optional(buffer_type), arg)
def try_compile_wrong_end_optional(*args):
    """Eagerly compile rfind with an integer start but Optional(float64) end.

    ``args`` are ignored; they only let this be called like the function
    under test.
    """
    start_type = types.Optional(types.intp)
    end_type = types.Optional(types.float64)
    wrong_sig_optional = types.int64(types.unicode_type, types.unicode_type,
                                     start_type, end_type)
    njit([wrong_sig_optional])(rfind_with_start_end_usecase)
def make_optional_none(self, builder, valtype):
    """Build an ``Optional(valtype)`` value whose ``valid`` flag is false.

    The ``data`` field is not assigned here.
    """
    helper = self.make_helper(builder, types.Optional(valtype))
    helper.valid = cgutils.false_bit
    return helper._getvalue()
def make_optional_value(self, builder, valtype, value):
    """Build an ``Optional(valtype)`` value that holds ``value``.

    Sets ``valid`` to true and ``data`` to ``value``.
    """
    helper = self.make_helper(builder, types.Optional(valtype))
    helper.valid = cgutils.true_bit
    helper.data = value
    return helper._getvalue()
import numpy as np from numba import types, jit, prange from stratego_env.game.stratego_procedural_impl import _get_state_from_player_perspective, _get_valid_moves_as_1d_mask, \ _get_partially_observable_observation, _get_fully_observable_observation, _get_valid_moves_as_spatial_mask, \ _get_action_1d_index_from_player_perspective, _get_next_state, _get_game_ended, _get_partially_observable_observation_extended_channels, \ _get_fully_observable_observation_extended_channels, _get_heuristic_rewards_from_move STSH = "int64[:, :, ::1]" OTSH = "float32[:, :, ::1]" COTSH = types.Optional(types.float32[:, :, ::1]) MSTSH = "int64[:, :, :, ::1]" MOTSH = "float32[:, :, :, :, ::1]" VTSH = "int64[:, :, ::1]" MVTSH = "int64[:, :, :, :, ::1]" @jit( f"void(int64, int64, int64, int64, int64, " f"float32[:, :, ::1], float32[:, :, ::1], float32[:, :, ::1], float32[:, :, ::1], boolean, boolean, boolean, boolean," f"{STSH}, {VTSH}, {OTSH}, {OTSH}, {OTSH})", nopython=True, cache=True) def _fill_observation(player, action_size, max_actions, rows, columns, p_obs_mids, p_ops_range, f_obs_mids, f_obs_range, partial_out, full_out, internal_out, extended_channels, state, valid_actions_mask, partial_observation, full_observation, internal_state): pstate = _get_state_from_player_perspective(state=state, player=player)