示例#1
0
    def op_ckernel(self, op):
        """Turn the ckernel description in ``op.args[0]`` into a concrete ckernel.

        ``op.args`` is unpacked as ``(ckernel, args)``.  The type of
        ``args[0]`` is used as the output type and ``self.get_arg_type`` is
        applied to each remaining entry to obtain the input types.

        When ``ckernel`` is a dict, its ``'tag'`` entry selects how the
        replacement stored in ``op.args[0]`` is built:

        * ``'elwise'``    -- ``_lowlevel.lift_ckernel_deferred`` over
          ``[out_type] + in_types``.  If ``op.metadata['rank']`` is smaller
          than the op's dimensionality, ``self.env['stream-outer']`` is truthy
          and ``self.env['result-ndim']`` equals the op's dimensionality, the
          output type and every input type whose ``ndim`` equals
          ``result-ndim`` are first rebuilt with ``ndt.make_strided_dim``
          around their ``element_type``.
        * ``'reduction'`` -- ``_lowlevel.lift_reduction_ckernel_deferred`` on
          the first input type, forwarding ``axis``, ``keepdims``, ``assoc``,
          ``comm`` and ``ident`` (the latter converted with ``nd.asarray``
          unless it is ``None``).
        * ``'rolling'``   -- ``_lowlevel.make_rolling_ckernel_deferred`` with
          the dict's ``window``.
        * ``'ckfactory'`` -- the dict's ``'ckernel_factory'`` callable,
          invoked as ``factory(out_type, *in_types)``.

        A ``ckernel`` that is not a dict is written back unchanged.

        Raises:
            ValueError: for a ``'rolling'`` ckernel whose ``minp`` is not 0.
            RuntimeError: for a dict whose ``'tag'`` is none of the above.
        """
        op_ndim = len(op.type.shape)
        result_ndim = self.env.get('result-ndim', 0)
        ckernel, args = op.args
        # args[0] supplies the output type; the remaining entries are inputs.
        in_types = [self.get_arg_type(arg) for arg in args[1:]]
        out_type = ndt.type(str(args[0].type))

        if not isinstance(ckernel, dict):
            # Not a tagged description: store it back as-is.
            op.args[0] = ckernel
            return

        tag = ckernel['tag']
        if tag == 'elwise':
            ck = ckernel['ckernel']
            if op.metadata['rank'] < op_ndim and \
                    self.env.get('stream-outer', False) and result_ndim == op_ndim:
                # Replace the leading dimension type with 'strided' in each operand
                # if we're streaming it for processing BLZ
                # TODO: Add dynd tp.subarray(N) function like datashape has
                for i, tp in enumerate(in_types):
                    if tp.ndim == result_ndim:
                        in_types[i] = ndt.make_strided_dim(tp.element_type)
                out_type = ndt.make_strided_dim(out_type.element_type)

            op.args[0] = _lowlevel.lift_ckernel_deferred(ck,
                                                         [out_type] + in_types)
        elif tag == 'reduction':
            ck = ckernel['ckernel']
            assoc = ckernel['assoc']
            comm = ckernel['comm']
            ident = ckernel['ident']
            ident = None if ident is None else nd.asarray(ident)
            axis = ckernel['axis']
            keepdims = ckernel['keepdims']
            op.args[0] = _lowlevel.lift_reduction_ckernel_deferred(
                            ck, in_types[0],
                            axis=axis, keepdims=keepdims,
                            associative=assoc, commutative=comm,
                            reduction_identity=ident)
        elif tag == 'rolling':
            ck = ckernel['ckernel']
            window = ckernel['window']
            minp = ckernel['minp']
            if minp != 0:
                raise ValueError('rolling window with minp != 0 not supported yet')
            op.args[0] = _lowlevel.make_rolling_ckernel_deferred(out_type,
                                                                 in_types[0],
                                                                 ck, window)
        elif tag == 'ckfactory':
            ckfactory = ckernel['ckernel_factory']
            op.args[0] = ckfactory(out_type, *in_types)
        else:
            raise RuntimeError('unrecognized ckernel tag %s' % tag)
#------------------------------------------------------------------------
# Other Funcs
#------------------------------------------------------------------------

# rolling_mean: a RollingWindowBlazeFunc with a single float64 overload whose
# ckernel comes from _lowlevel.make_builtin_mean1d_ckernel_deferred.
rolling_mean = RollingWindowBlazeFunc('blaze', 'rolling_mean')
# NOTE(review): the meaning of the second positional argument (0) is not
# visible from here -- confirm against the _lowlevel documentation.
mean1d = _lowlevel.make_builtin_mean1d_ckernel_deferred('float64', 0)
rolling_mean.add_overload('(M * float64) -> M * float64', mean1d)

# diff: a BlazeFunc with a single float64 overload, assembled in three steps:
#   1. subtract_doubles_ck -- a ckernel made from the np.subtract ufunc for
#      the (float64, float64, float64) type triple;
#   2. diff_pair_ck -- that ckernel lifted into a reduction along axis 0 over
#      'strided * float64', flagged as neither commutative nor associative;
#   3. diff_ck -- the reduction wrapped as a rolling ckernel with window 2.
diff = BlazeFunc('blaze', 'diff')
# NOTE(review): the trailing positional False is not explained here --
# confirm its meaning against ckernel_deferred_from_ufunc's documentation.
subtract_doubles_ck = _lowlevel.ckernel_deferred_from_ufunc(np.subtract,
                (np.float64, np.float64, np.float64),
                False)
diff_pair_ck = _lowlevel.lift_reduction_ckernel_deferred(subtract_doubles_ck,
                                         'strided * float64',
                                         axis=0,
                                         commutative=False,
                                         associative=False)
diff_ck = _lowlevel.make_rolling_ckernel_deferred('strided * float64',
                                                  'strided * float64',
                                                  diff_pair_ck, 2)
diff.add_overload('(M * float64) -> M * float64', diff_ck)

# take: a CKFBlazeFunc with two overloads. Both are registered with the
# _lowlevel.make_take_ckernel_deferred callable itself (it is not called
# here).
# NOTE(review): presumably CKFBlazeFunc invokes that callable later with the
# resolved types to build the actual ckernel -- confirm against its definition.
take = CKFBlazeFunc('blaze', 'take')
# Masked take: second argument is an M * bool array; result is var * T
take.add_overload('(M * T, M * bool) -> var * T',
                  _lowlevel.make_take_ckernel_deferred)
# Indexed take: second argument is an N * intptr array; result is N * T
take.add_overload('(M * T, N * intptr) -> N * T',
                  _lowlevel.make_take_ckernel_deferred)