コード例 #1
0
ファイル: compressor.py プロジェクト: wolfv/pyjet
    def init_op(self):
        """Prepare the format strings and Armadillo slice spans for an assignment.

        Reads the two inputs of ``self.el``. If both input shapes are
        non-empty and their first dimensions differ, the right-hand side is
        emitted transposed (``'.t()'``). When ``self.el.slices`` is set, every
        Python index/slice is converted into an inclusive ``(start, end)``
        pair stored in ``self.slice_tuples``, and ``self.slice_fmt`` is chosen
        by comparing the requested slice shape with the Armadillo span shape.

        Raises:
            NotImplementedError: if the requested slice shape matches neither
                the Armadillo span shape nor its transpose.
        """
        import operator  # local import: only needed for index normalization

        # NOTE(review): the placeholder is spelled ``{tranpose}`` while the
        # attribute is ``self.transpose``; whoever formats ``self.fmt`` must
        # pass the misspelled keyword -- confirm against the caller.
        self.fmt = '{in1} = {in2}{tranpose};' + self.fmt_base
        self.fmt_with_idx = '{in1}({idx}) = {in2};'
        self.at_idx = self.el.at_idx
        self.transpose = ''
        self.slices = self.el.slices
        self.input1 = sanitize_name(self.el.inputs[0].name)
        self.input2 = sanitize_name(self.el.inputs[1].name)
        shape0 = self.el.inputs[0].shape
        shape1 = self.el.inputs[1].shape

        # Leading dimensions disagree -> assign the transposed value.
        if len(shape0) and len(shape1) and shape0[0] != shape1[0]:
            self.transpose = '.t()'

        if self.el.slices:
            self.slice_tuples = []
            slices = self.el.slices
            input_shape = self.el.inputs[0].shape
            nslices = len(slices)
            slice_ndim = len(input_shape)

            # Armadillo has no 1-d vectors here -> treat the input as (n, 1).
            if slice_ndim == 1:
                input_shape = (input_shape[0], 1)
                slice_ndim += 1

            # Pad missing trailing dimensions with ':' (slice(None)).
            # As before, this extends self.el.slices in place.
            while nslices < slice_ndim:
                slices.append(slice(None))
                nslices += 1

            # Convert to inclusive Armadillo spans.
            for idx in range(nslices):
                item = slices[idx]
                if isinstance(item, slice):
                    # slice.indices() applies Python semantics: default
                    # bounds, negative indices and clamping to the dimension
                    # (so an explicit stop of 0 is no longer read as "to end").
                    start, stop, _ = item.indices(input_shape[idx])
                    self.slice_tuples.append((start, stop - 1))
                else:
                    # Accept any integral index (int, numpy integer, ...).
                    pos = operator.index(item)
                    if pos < 0:
                        pos += input_shape[idx]
                    self.slice_tuples.append((pos, pos))

            out_shape = self.el.output.kwargs['slice_shape']
            # A 1-d requested shape is compared as a column (n, 1).
            slice_shape = out_shape if len(out_shape) != 1 else (out_shape[0],
                                                                 1)
            arma_shape = (self.slice_tuples[0][1] - self.slice_tuples[0][0] +
                          1, self.slice_tuples[1][1] -
                          self.slice_tuples[1][0] + 1)

            if slice_shape == arma_shape or (slice_shape == ()
                                             and arma_shape == (1, 1)):
                self.slice_fmt = 'set_items({lhs}, {rhs}, {{{start}}}, {{{end}}});'
            elif slice_shape == (arma_shape[1], arma_shape[0]):
                # Requested shape is the transpose of the span: pass the flag.
                self.slice_fmt = 'set_items({lhs}, {rhs}, {{{start}}}, {{{end}}}, true);'
            else:
                raise NotImplementedError("You should not end up here.")
コード例 #2
0
ファイル: compressor.py プロジェクト: wolfv/pyjet
 def __init__(self, el):
     """Remember the element, its sanitized name and debug caller text.

     When ``config.debug`` is on, ``self.fmt_caller`` is the module-level
     ``fmt_caller`` template filled with the element name and, if
     ``el.caller_info`` is truthy, its entries 0, 2 and 3 (otherwise empty
     strings). With debugging off it is the empty string. Finally
     ``self.init_op()`` is invoked.
     """
     self.el = el
     self.name = sanitize_name(el.name)
     if not config.debug:
         self.fmt_caller = ''
     else:
         info = el.caller_info
         self.fmt_caller = fmt_caller.format(
             name=self.name,
             caller_class=info[0] if info else '',
             caller_fun=info[2] if info else '',
             caller_line=info[3] if info else '',
         )
     self.init_op()
コード例 #3
0
ファイル: compressor.py プロジェクト: wolfv/pyjet
    def __init__(self, el):
        """Remember the element and its sanitized name, then run ``init_op``.

        With ``config.debug`` on, ``self.fmt_base`` becomes a space, the
        ``fmt_caller`` template and a newline, followed by ``fmt_print_double``
        when ``el.output.shape == ()`` or ``fmt_print_mat`` otherwise. The
        whole string is formatted with the element name and, if
        ``el.caller_info`` is truthy, its entries 0, 2 and 3 (empty strings
        otherwise). With debugging off ``self.fmt_base`` is ``''``.
        """
        self.el = el
        self.name = sanitize_name(el.name)
        if not config.debug:
            self.fmt_base = ''
        else:
            if el.output.shape == ():
                printer = fmt_print_double
            else:
                printer = fmt_print_mat
            template = ' ' + fmt_caller + '\n' + printer
            info = el.caller_info
            self.fmt_base = template.format(
                name=self.name,
                caller_class=info[0] if info else '',
                caller_fun=info[2] if info else '',
                caller_line=info[3] if info else '',
            )

        self.init_op()
コード例 #4
0
        def wrapper(*args):
            """Call ``func`` directly, or trace, compile and cache it in jet mode.

            Outside ``jet.jet_mode`` the wrapped function is called as-is.
            Otherwise a previously compiled callable from ``_func_cached_dict``
            is reused if present. If there is none, ``func`` is traced with
            one placeholder per argument, shaped like the cached ``shapes`` or
            like each argument's ``.shape`` (``()`` if absent). The trace is
            built with ``JetBuilder``, and the compiled callable is cached
            and invoked.

            Raises:
                TypeError: if the number of positional arguments differs from
                    the function's named positional parameters.
                ValueError: if cached shapes do not match the argument count.
            """
            if not jet.jet_mode:
                return func(*args)

            cache = _func_cached_dict[id(func)]
            func_cached = cache['func']
            if func_cached is not None:
                return func_cached(*args)

            shapes = cache['shapes']

            code = func.__code__
            # Bound methods carry ``self`` as first parameter; skip it.
            first = 1 if inspect.ismethod(func) else 0
            arg_names = code.co_varnames[first:code.co_argcount]

            if len(arg_names) != len(args):
                # Only functions without named positional parameters (e.g.
                # ``def f(*xs)``) may take a differing count; they get
                # generated placeholder names. An explicit raise (instead of
                # ``assert``) keeps this check active under ``python -O``.
                if arg_names:
                    raise TypeError(
                        'Expected {} positional arguments {}, got {}.'.format(
                            len(arg_names), tuple(arg_names), len(args)))
                arg_names = [get_unique_name('ph') for _ in args]

            if shapes and len(shapes) != len(arg_names):
                raise ValueError(
                    'Shapes length does not match the arguments length.')

            if not shapes:
                shapes = [getattr(arg, 'shape', ()) for arg in args]
                cache['shapes'] = shapes

            ph = [
                placeholder(name=name, shape=shape)
                for name, shape in zip(arg_names, shapes)
            ]
            fun_name = code.co_name
            if fun_name == '<lambda>':
                fun_name = get_unique_name('lambda')

            # Keyword arguments are evaluated in order: tracing func(*ph)
            # happens before the unique file/function names are drawn.
            jb = JetBuilder(args=ph,
                            out=func(*ph),
                            file_name=get_unique_name(
                                sanitize_name('{}_{}_{func_name}'.format(
                                    *get_caller_info('jit.py')[1:-1],
                                    func_name=fun_name))),
                            fun_name=get_unique_name(fun_name))

            jet_class = getattr(jb.build(), jb.class_name)
            jet_func = getattr(jet_class(), jb.fun_name)
            cache['func'] = jet_func

            return jet_func(*args)
コード例 #5
0
ファイル: compressor.py プロジェクト: wolfv/pyjet
 def __repr__(self):
     """Render this unary operation through ``self.fmt``.

     The template is filled with the dtype from ``self.get_dtype()``, the
     node name, the sanitized name of the first input (``in1``) and the
     operator symbol looked up in ``self.op_map`` for ``self.el.op``.
     """
     operand = sanitize_name(self.el.inputs[0].name)
     template = self.fmt
     fields = {
         'dtype': self.get_dtype(),
         'name': self.name,
         'in1': operand,
         'operator': self.op_map[self.el.op],
     }
     return template.format(**fields)