コード例 #1
0
    def _get_remove_kernels(self):
        """Build the compyle kernels used to drop selected entries of an array.

        Returns a pair ``(fill_if_remove_knl, remove_knl)``:

        * ``fill_if_remove_knl`` -- an ``Elementwise`` kernel taking
          ``(indices, if_remove)`` that sets ``if_remove[indices[i]] = 1`` for
          each ``i``.
          It only ever writes 1s, so the caller presumably has to zero
          ``if_remove`` beforehand -- TODO confirm against callers.
        * ``remove_knl`` -- an ``int32`` ``'a+b'`` ``Scan`` whose input is the
          ``if_remove`` flag array.
          For every unflagged position ``i`` it copies ``old_array[i]`` into
          ``new_array[i - item]``, where ``item`` is the scan's running sum of
          flags.
          On that path ``if_remove[i]`` is 0, so ``item`` equals the number
          of flagged positions before ``i``, and the kept entries are
          written contiguously from index 0 of ``new_array``.

        Both kernels are built for ``self.backend``. The two array arguments
        of the scan's output expression are typed with ``self.gptr_type``.
        """
        import compyle.parallel as parallel

        # ``gintp='indices, if_remove'`` is compyle's "type -> comma-separated
        # argument names" form of ``annotate``: both arguments are int pointers.
        @annotate(i='int', gintp='indices, if_remove')
        def fill_if_remove(i, indices, if_remove):
            if_remove[indices[i]] = 1

        fill_if_remove_knl = parallel.Elementwise(fill_if_remove,
                                                  backend=self.backend)

        # Scan input: the 0/1 removal flag of element ``i``.
        @annotate(i='int', if_remove='gintp', return_='int')
        def remove_input_expr(i, if_remove):
            return if_remove[i]

        # The data arrays use the container's own pointer type, while the
        # flags stay ``gintp``.
        types = {
            'i': 'int',
            'item': 'int',
            'if_remove': 'gintp',
            'new_array': self.gptr_type,
            'old_array': self.gptr_type
        }

        # Scan output: compact the surviving elements. NOTE: compyle appears
        # to bind the scan index/value by parameter name (``i``, ``item``), so
        # these names should not be changed -- confirm with compyle's Scan docs.
        @annotate(**types)
        def remove_output_expr(i, item, if_remove, new_array, old_array):
            if not if_remove[i]:
                new_array[i - item] = old_array[i]

        remove_knl = parallel.Scan(remove_input_expr,
                                   remove_output_expr,
                                   'a+b',
                                   dtype=np.int32,
                                   backend=self.backend)

        return fill_if_remove_knl, remove_knl
コード例 #2
0
def get_align_kernel(ary_list, order, backend=None):
    """Wrap an ``AlignMultiple`` kernel in a compyle ``Elementwise``.

    The ``AlignMultiple`` kernel is generated under the name
    ``'align_multiple_knl'`` for ``len(ary_list)`` arrays. It is then
    compiled as an ``Elementwise`` operation for ``backend``.

    Parameters
    ----------
    ary_list : sequence
        Only its length is used here, to size the generated kernel.
    order :
        Not used while building the kernel. Presumably the caller passes it
        when invoking the returned kernel -- TODO confirm.
    backend : str, optional
        Backend forwarded to ``parallel.Elementwise``.

    Returns
    -------
    The ``parallel.Elementwise`` object wrapping the generated function.
    """
    import compyle.parallel as parallel

    num_arrays = len(ary_list)
    kernel_source = AlignMultiple('align_multiple_knl', num_arrays)
    return parallel.Elementwise(kernel_source.function, backend=backend)
コード例 #3
0
def take(ary, indices, backend=None, out=None):
    """Gather the elements of ``ary`` selected by ``indices`` into ``out``.

    Parameters
    ----------
    ary : Array
        Source array. Its ``backend`` is used when ``backend`` is None.
    indices : Array
        Indices into ``ary``.
    backend : str, optional
        One of ``'cython'``, ``'opencl'`` or ``'cuda'``. Defaults to
        ``ary.backend``.
    out : Array, optional
        Destination. When not given, it is allocated with ``indices.length``
        elements of ``ary.dtype`` on ``backend``.

    Returns
    -------
    Array
        ``out``. On the cython backend it is filled with
        ``numpy.take(ary.dev, indices.dev)``. On opencl/cuda the work is
        delegated to the ``take_elwise`` kernel defined elsewhere, presumably
        with the same semantics.

    Raises
    ------
    ValueError
        If ``backend`` is not one of the supported backends. Previously such
        a backend silently returned ``out`` unfilled.
    """
    if backend is None:
        backend = ary.backend
    # Validate before allocating so an unknown backend can't hand back an
    # untouched (possibly uninitialized) buffer.
    if backend not in ('cython', 'opencl', 'cuda'):
        raise ValueError(
            "Unsupported backend %r; expected 'cython', 'opencl' or 'cuda'"
            % (backend,)
        )
    if out is None:
        out = empty(indices.length, ary.dtype, backend=backend)
    if backend == 'cython':
        np.take(ary.dev, indices.dev, out=out.dev)
    else:
        # Only the GPU backends need compyle's parallel kernels.
        import compyle.parallel as parallel
        take_knl = parallel.Elementwise(take_elwise, backend=backend)
        take_knl(indices, ary, out)
    return out