Example #1
0
 def ol_bar(x):
     """Overload stub asserting that ``x`` is typed as the expected LiteralList.

     ``self`` is a free variable here (presumably the enclosing TestCase).
     Returns an identity implementation.
     """
     self.assertTrue(isinstance(x, types.LiteralList))
     items = x.literal_value
     self.assertTrue(isinstance(items, list))
     # The first four entries are compared directly, in order.
     leading_expected = (
         types.literal(1),
         types.literal('a'),
         types.Array(types.float64, 1, 'C'),
         types.List(types.intp, reflected=False, initial_value=[1, 2, 3]),
     )
     for pos, want in enumerate(leading_expected):
         self.assertEqual(items[pos], want)
     # The fifth entry is itself a LiteralList.
     nested = items[4]
     self.assertTrue(isinstance(nested, types.LiteralList))
     nested_items = nested.literal_value
     self.assertEqual(nested_items[0], types.literal('cat'))
     self.assertEqual(nested_items[1], types.literal(10))
     return lambda x: x
    def test_literal_nested(self):
        # A ``literally`` request in the inner jitted callee must make the
        # caller's corresponding argument literal, and only that argument.
        @njit
        def double_it(val):
            return literally(val) * 2

        @njit
        def outer(lit_arg, plain_arg):
            return double_it(lit_arg) + plain_arg

        lit_value, plain_value = 3, 7
        self.assertEqual(outer(lit_value, plain_value),
                         lit_value * 2 + plain_value)

        (inner_sig,) = double_it.signatures
        self.assertEqual(inner_sig[0], types.literal(lit_value))

        (outer_sig,) = outer.signatures
        self.assertEqual(outer_sig[0], types.literal(lit_value))
        self.assertNotIsInstance(outer_sig[1], types.Literal)
Example #3
0
    def lower_print(self, inst):
        """
        Lower an ir.Print() node.

        The node is lowered as an ordinary call to the builtin print().
        Positional arguments whose constant value (from ``inst.consts``) is a
        str have their type replaced by the matching literal type. That
        adjusted signature is used to look up the print implementation.
        """
        # Treating ir.Print as a regular print() call keeps it easy to drop
        # the special ir.Print rewrite once it is no longer needed (e.g. when
        # native strings are available).
        call_sig = self.fndesc.calltypes[inst]
        assert call_sig.return_type == types.none
        print_fnty = self.context.typing_context.resolve_value_type(print)

        arg_types = list(call_sig.args)
        arg_vars = list(inst.args)
        for position in range(len(arg_vars)):
            if position not in inst.consts:
                continue
            const_value = inst.consts[position]
            if isinstance(const_value, str):
                arg_types[position] = types.literal(const_value)

        print_sig = typing.signature(call_sig.return_type, *arg_types)
        print_sig.pysig = call_sig.pysig

        # Argument values are folded against the original call signature;
        # only the implementation lookup uses the literal-adjusted one.
        lowered_args = self.fold_call_args(print_fnty, call_sig, arg_vars,
                                           inst.vararg, {})
        print_impl = self.context.get_function(print, print_sig)
        print_impl(self.builder, lowered_args)
Example #4
0
 def generic(self, args, kws):
     # Typing of static setitem on a record field addressed by name.
     # Returns None (no match) unless value converts to the field's type.
     record, idx, value = args
     if not isinstance(record, types.Record):
         return None
     if not isinstance(idx, str):
         return None
     field_ty = record.typeof(idx)
     if self.context.can_convert(value, field_ty) is None:
         return None
     return signature(types.void, record, types.literal(idx), value)
Example #5
0
        def ol_bar(d):
            """Overload stub checking that ``d`` is a LiteralStrKeyDict whose
            literal contents match the dict built below and that it has no
            ``initial_value`` attribute. Returns an identity implementation.
            """
            # Keys "A".."R" map to 1, and "S" maps to the string 'a'.
            source = {key: 1 for key in "ABCDEFGHIJKLMNOPQR"}
            source["S"] = 'a'

            def specific_ty(z):
                # Literal type when Numba can make one, else the plain type.
                if types.maybe_literal(z):
                    return types.literal(z)
                return typeof(z)

            expected = {}
            for key, val in source.items():
                lit_key = types.literal(key)
                expected[lit_key] = specific_ty(val)
            self.assertTrue(isinstance(d, types.LiteralStrKeyDict))
            self.assertEqual(d.literal_value, expected)
            self.assertEqual(hasattr(d, 'initial_value'), False)
            return lambda d: d
Example #6
0
def _lit_or_omitted(value):
    """Return ``types.literal(value)`` when Numba can express *value* as a
    literal type; if that raises ``LiteralTypingError``, fall back to
    ``types.Omitted(value)``.
    """
    try:
        result = types.literal(value)
    except LiteralTypingError:
        result = types.Omitted(value)
    return result
Example #7
0
 def generic(self, args, kws):
     """Type ``static_setitem`` on a record, addressed by field name or index.

     Returns ``signature(void, record, literal(idx), value)`` when ``value``
     can be converted to the selected field's type, otherwise ``None``.

     Raises:
         NumbaIndexError: if an integer ``idx`` lies outside
             ``[-len(record.fields), len(record.fields))``.
     """
     # Resolution of members for record and structured arrays
     record, idx, value = args
     if not isinstance(record, types.Record):
         return None
     if isinstance(idx, str):
         field_name = idx
     elif isinstance(idx, int):
         nfields = len(record.fields)
         # Negative indices follow Python list semantics (list(...)[idx]
         # below). Any index outside the valid range gets the typed error,
         # instead of leaking a bare IndexError from the list lookup.
         if not -nfields <= idx < nfields:
             msg = f"Requested index {idx} is out of range"
             raise NumbaIndexError(msg)
         field_name = list(record.fields)[idx]
     else:
         return None
     expectedty = record.typeof(field_name)
     if self.context.can_convert(value, expectedty) is None:
         return None
     return signature(types.void, record, types.literal(idx), value)
    def test_literally_freevar(self):
        # ``literally`` is reached as an attribute of a locally imported
        # ``numba`` module rather than as a module-level global name.
        import numba

        @njit
        def echo(arg):
            return numba.literally(arg)

        expected = 123
        self.assertEqual(echo(expected), expected)
        first_sig = echo.signatures[0]
        self.assertEqual(first_sig[0], types.literal(expected))
 def run_pass(self, state):
     # Give every static_getitem signature a Literal-typed index argument,
     # replicating the problematic typing. Replacements are collected first
     # and applied after the loop, so calltypes is not mutated mid-iteration.
     replacements = {}
     for inst, sig in state.calltypes.items():
         is_static_getitem = (isinstance(inst, ir.Expr)
                              and inst.op == "static_getitem")
         if not is_static_getitem:
             continue
         obj, _ = sig.args
         replacements[inst] = sig.replace(
             args=(obj, types.literal(inst.index)))
     state.calltypes.update(replacements)
     return True
Example #10
0
        def ol_bar(x):
            """Overload stub asserting that ``x`` is a LiteralStrKeyDict whose
            literal contents structurally equal ``outer_literal``.

            Returns an identity implementation.
            """
            self.assertTrue(isinstance(x, types.LiteralStrKeyDict))
            dlv = x.literal_value
            inner_literal = {
                types.literal('g'): types.literal('h'),
                types.literal('i'): types.Array(types.float64, 1, 'C')
            }
            inner_dict = types.LiteralStrKeyDict(inner_literal)
            outer_literal = {
                types.literal('a'):
                types.LiteralList([
                    types.literal(1),
                    types.literal('a'),
                    types.DictType(types.unicode_type,
                                   types.intp,
                                   initial_value={'f': 1}), inner_dict
                ]),
                types.literal('b'):
                types.literal(2),
                types.literal('c'):
                types.List(types.complex128, reflected=False)
            }

            def check_same(a, b):
                """Recursively compare literal containers element-wise.

                Literal wrappers are unwrapped once and delegated to the
                list/dict branches. Leaves are compared with assertEqual.
                """
                if (isinstance(a, types.LiteralList)
                        and isinstance(b, types.LiteralList)):
                    # Walk the wrapped lists once (previously the whole
                    # lists were re-checked once per element).
                    check_same(a.literal_value, b.literal_value)
                elif isinstance(a, list) and isinstance(b, list):
                    # zip() would silently ignore trailing extra elements.
                    self.assertEqual(len(a), len(b))
                    for i, j in zip(a, b):
                        check_same(i, j)
                elif (isinstance(a, types.LiteralStrKeyDict)
                      and isinstance(b, types.LiteralStrKeyDict)):
                    check_same(a.literal_value, b.literal_value)
                elif isinstance(a, dict) and isinstance(b, dict):
                    self.assertEqual(len(a), len(b))
                    # Order-sensitive: entries are paired by insertion order.
                    for (ki, vi), (kj, vj) in zip(a.items(), b.items()):
                        check_same(ki, kj)
                        check_same(vi, vj)
                else:
                    self.assertEqual(a, b)

            check_same(dlv, outer_literal)
            return lambda x: x
Example #11
0
 def impl(cgctx, builder, sig, args):
     # Codegen: read each value of ``d`` by position through the
     # 'static_getitem' implementation, cast it, and pack the results into
     # the return tuple. ``d``, ``keys`` and ``literal_tys`` are free
     # variables (presumably the LiteralStrKeyDict type, its keys and the
     # per-value literal types -- confirm against the enclosing function).
     (dict_value,) = args
     getitem = cgctx.get_function('static_getitem',
                                  types.none(d, types.literal('dummy')))
     values = []
     for pos in range(len(keys)):
         raw = getitem(builder, (dict_value, pos))
         lit_ty = literal_tys[pos]
         field_ty = d.types[pos]
         values.append(cgctx.cast(builder, raw, lit_ty, field_ty))
         cgctx.nrt.incref(builder, field_ty, raw)
     return cgctx.make_tuple(builder, sig.return_type, values)
    def test_literal_nested_multi_arg(self):
        # Only the caller argument that flows into ``literally`` in the callee
        # (``c`` -> ``y``) should become literal in the caller's signature.
        @njit
        def add_literal(x, y):
            return x + literally(y)

        @njit
        def foo(a, b, c):
            return add_literal(a, c)

        call_kwargs = {'a': 1, 'b': 2, 'c': 3}
        got = foo(**call_kwargs)
        self.assertEqual(got, call_kwargs['a'] + call_kwargs['c'])
        (foo_sig,) = foo.signatures
        self.assertEqual(foo_sig[2], types.literal(call_kwargs['c']))
Example #13
0
    def impl(cgctx, builder, sig, args):
        # Codegen: build a tuple of (key string, value) pairs, one per entry
        # of ``d``. ``d``, ``keys`` and ``literal_tys`` are free variables
        # (presumably the LiteralStrKeyDict type, its literal keys and the
        # per-value literal types -- confirm against the enclosing function).
        (dict_value,) = args
        getitem = cgctx.get_function('static_getitem',
                                     types.none(d, types.literal('dummy')))
        pairs = []
        for pos in range(len(keys)):
            raw = getitem(builder, (dict_value, pos))
            lit_ty = literal_tys[pos]
            value_ty = d.types[pos]
            value = cgctx.cast(builder, raw, lit_ty, value_ty)
            cgctx.nrt.incref(builder, value_ty, raw)
            key_str = make_string_from_constant(cgctx, builder,
                                                types.unicode_type,
                                                keys[pos].literal_value)
            pair_ty = types.Tuple([types.unicode_type, value_ty])
            pairs.append(cgctx.make_tuple(builder, pair_ty, (key_str, value)))
        return cgctx.make_tuple(builder, sig.return_type, pairs)
Example #14
0
import numba
from numba.core.types import literal

# Literal string types compared against the Numba type of ``switch``.
t1, t2 = literal("on"), literal("off")


@numba.generated_jit()
def sim(x, switch):
    """Pick an implementation at compile time from the Numba type of *switch*.

    When ``switch`` is typed as ``t1`` (literal "on"), ``x`` is returned
    unchanged. When it is typed as ``t2`` (literal "off"), ``2 * x`` is
    returned.

    Raises:
        TypeError: if ``switch`` is typed as neither literal. Without this
            branch ``_sim`` would be unbound (UnboundLocalError). A plain,
            non-literal string argument presumably lands here -- confirm.
    """
    if switch.key == t1.key:

        def _sim(x, switch):
            return x
    elif switch.key == t2.key:

        def _sim(x, switch):
            return 2 * x
    else:
        raise TypeError(
            f"switch must be typed as {t1} or {t2}, got {switch}")

    return _sim


@numba.generated_jit()
def sim2(x, switch):
    """Runtime counterpart of ``sim``: branch on the *value* of ``switch``.

    Inside the compiled ``_sim``, ``switch`` is a runtime string rather than
    a Numba type, so ``switch.key`` cannot be typed. The comparison is made
    against the string values that ``t1``/``t2`` wrap ("on"/"off").

    Raises:
        ValueError: at run time, if ``switch`` is neither "on" nor "off".
    """
    def _sim(x, switch):
        if switch == "on":
            return x
        elif switch == "off":
            return 2 * x
        raise ValueError("switch must be 'on' or 'off'")

    return _sim
Example #15
0
    # A list of the actual nodes of the tree.
    ('nodes',ListType(TreeNodeType)),
    # Disabled field, kept for reference: ('u_ys', i4[::1]),

    # A cache of split contexts keyed by the sequence of splits so far;
    #  this is where split statistics are held between calls to ifit().
    #  NOTE(review): the active AKDType key is a scalar u8 -- presumably a
    #  hash of that split sequence; confirm.
    # Earlier variant, kept for reference:
    # ('context_cache', DictType(u8[::1],SplitterContextType)),
    ('context_cache', AKDType(u8,SplitterContextType)),

    # The data stats for this tree. This is kept around between calls
    #  to ifit() and replaced with each call to fit().
    ('data_stats', DataStatsType),

    # Whether or not iterative fitting is enabled. The flag is carried in
    #  the field's type, literal(False), rather than as a plain bool type.
    ('ifit_enabled', literal(False)),

]



# Structref class and type template for a tree; the constructor is written
# by hand (Tree_ctor) rather than generated.
Tree, TreeTypeTemplate = define_structref_template(
    "Tree", tree_fields, define_constructor=False)

@njit(cache=True)
def Tree_ctor(tree_type):
    """Allocate a ``tree_type`` structref and initialise its fields.

    The nodes list starts empty. The split-context cache and data stats are
    freshly created via ``new_akd`` and ``DataStats_ctor``.
    """
    tree = new(tree_type)
    tree.nodes = List.empty_list(TreeNodeType)
    tree.context_cache = new_akd(u8, SplitterContextType)
    tree.data_stats = DataStats_ctor()
    return tree
Example #16
0
    def __init__(self, preset_type='decision_tree', **kwargs):
        '''
        Build a tree classifier from a named preset plus keyword overrides.

        kwargs:
            preset_type: Key into ``tree_classifier_presets`` supplying the
              default values of the other kwargs; explicit kwargs override it.
            criterion: The name of the criterion function used 'entropy', 'gini', etc.
            total_func: The function for combining the impurities for two splits, default 'sum'.
            split_choice: The name of the split choice policy 'all_max', etc.
            pred_choice: The prediction choice policy 'pure_majority_general' etc.
            secondary_criterion: The name of the secondary criterion function used only if
              no split can be found with the primary impurity.
            secondary_total_func: The name of the secondary total_func, defaults to 'sum'
            positive_class: The integer id for the positive class (used in prediction)
            sep_nan: If set to True then use a ternary tree that treats nan's separately
            cache_nodes: Required key in the merged preset/kwargs.

        Raises:
            KeyError: if ``preset_type`` is unknown or a required key is
              missing from the merged preset/kwargs.
            ValueError: if criterion, split_choice or pred_choice does not
              name a known module-level enum constant.
        '''
        kwargs = {**tree_classifier_presets[preset_type], **kwargs}

        # Every key is fetched (KeyError if absent), even those unused below.
        # NOTE(review): total_func, secondary_criterion, secondary_total_func,
        # sep_nan and cache_nodes are not used by this constructor -- confirm
        # whether they should be forwarded into the config.
        (criterion, total_func, split_choice, pred_choice, secondary_criterion,
         secondary_total_func, positive_class, sep_nan, cache_nodes) = \
            itemgetter('criterion', 'total_func', 'split_choice', 'pred_choice',
                       'secondary_criterion', 'secondary_total_func',
                       'positive_class', 'sep_nan', 'cache_nodes')(kwargs)

        # Resolve option names to module-level enum constants, e.g.
        # criterion='gini' -> CRITERION_gini. total_func is not validated here.
        g = globals()
        criterion_enum = g.get(f"CRITERION_{criterion}", None)
        split_choice_enum = g.get(f"SPLIT_CHOICE_{split_choice}", None)
        pred_choice_enum = g.get(f"PRED_CHOICE_{pred_choice}", None)

        if criterion_enum is None:
            raise ValueError(f"Invalid criterion {criterion}")
        if split_choice_enum is None:
            raise ValueError(f"Invalid split_choice {split_choice}")
        if pred_choice_enum is None:
            raise ValueError(f"Invalid pred_choice {pred_choice}")
        self.positive_class = positive_class

        # The chosen enums are baked into the config type as literal types.
        config_dict = dict(config_fields)
        config_dict['criterion_enum'] = literal(criterion_enum)
        config_dict['split_choice_enum'] = literal(split_choice_enum)
        config_dict['pred_choice_enum'] = literal(pred_choice_enum)
        ConfigType = TreeClassifierConfigTemplate(list(config_dict.items()))

        # Same fields as the base tree type, but with iterative fitting
        # enabled; overwriting the existing key keeps its field position.
        tf_dict = dict(tree_fields)
        tf_dict['ifit_enabled'] = literal(True)
        self.tree_type = TreeTypeTemplate(list(tf_dict.items()))
        self.config = new_config(ConfigType)
        self.tree = Tree_ctor(self.tree_type)
Example #17
0
 def specific_ty(z):
     # Literal type when Numba can make one for ``z``, otherwise typeof(z).
     if types.maybe_literal(z):
         return types.literal(z)
     return typeof(z)