Ejemplo n.º 1
0
def get_radix_array(size, use_max_radix=False):
    """
    Decompose a power-of-two ``size`` into radices for the local memory transpose based FFT.

    The first radix (``radix_array[0]``) is the largest one. It determines the number of
    registers used by each work item, while the product of the remaining radices
    determines the size of the work group needed.

    As a concrete example, ``size`` = 1024 is decomposed into 16 x 16 x 4. The kernel
    therefore uses ``float2 a[16]`` for the in-register FFT and needs 16 x 4 = 64 work items
    per work group. It first performs 64 length-16 FFTs (64 work items working in parallel),
    then a transpose through local memory, then again 64 length-16 FFTs, another transpose
    through local memory, and finally 256 length-4 FFTs. In that last step the work group
    still has 64 work items, each able to hold 16 values, so every work item computes
    four length-4 FFTs.

    Similarly, for ``size`` = 2048 = 8 x 8 x 8 x 4 each work group has 8 x 8 x 4 = 256 work
    items. They compute 256 (in parallel) length-8 in-register FFTs three times, with a
    transpose through local memory after each pass, followed by 512 length-4 in-register
    FFTs, where each work item computes two length-4 FFTs.

    For ``size`` = 32 = 8 x 4, 4 work items first compute 4 in-register length-8 FFTs,
    followed by a transpose through local memory and 8 in-register length-4 FFTs, each work
    item computing two of them. If a work group size of, say, 64 is chosen, each work group
    can compute 64 / 4 = 16 size-32 FFTs (batched transform).

    Users can play with these parameters to find what performs best on a particular device:
    some devices have less register space, so a smaller base radix can avoid spilling;
    some have little local memory, so a smaller work group size may be required.

    If ``use_max_radix`` is true, ``min(size, MAX_RADIX)`` is used for every radix except
    the last one, which receives the remaining factor. Otherwise a table of fixed
    decompositions is consulted for sizes 2 through 2048, and any other size falls back to
    radix-16 factors plus a final remainder radix.
    """
    assert size == 2 ** helpers.log2(size)

    if use_max_radix:
        step = min(size, MAX_RADIX)
        factors = []
        remaining = size
        # Split off `step` until what is left no longer exceeds it.
        while remaining > step:
            factors.append(step)
            remaining //= step
        factors.append(remaining)
        return factors

    # Fixed decompositions for the small sizes.
    known = {
        2: [2],
        4: [4],
        8: [8],
        16: [8, 2],
        32: [8, 4],
        64: [8, 8],
        128: [8, 4, 4],
        256: [4, 4, 4, 4],
        512: [8, 8, 8],
        1024: [16, 16, 4],
        2048: [8, 8, 8, 4],
    }
    tuned = known.get(size)
    if tuned is not None:
        return tuned

    # Naive fallback (can be improved): radix-16 factors, with the last radix
    # taking whatever power of two is left over.
    log_size = helpers.log2(size)
    count = helpers.min_blocks(log_size, 4)
    leftover = log_size % 4
    last = 16 if leftover == 0 else 2 ** leftover
    return [16] * (count - 1) + [last]
Ejemplo n.º 2
0
def get_radix_array(size, use_max_radix=False):
    """
    Decompose a power-of-two ``size`` into radices for the local memory transpose based FFT.

    The first radix (``radix_array[0]``) is the largest one. It determines the number of
    registers used by each work item, while the product of the remaining radices
    determines the size of the work group needed.

    As a concrete example, ``size`` = 1024 is decomposed into 16 x 16 x 4. The kernel
    therefore uses ``float2 a[16]`` for the in-register FFT and needs 16 x 4 = 64 work items
    per work group. It first performs 64 length-16 FFTs (64 work items working in parallel),
    then a transpose through local memory, then again 64 length-16 FFTs, another transpose
    through local memory, and finally 256 length-4 FFTs. In that last step the work group
    still has 64 work items, each able to hold 16 values, so every work item computes
    four length-4 FFTs.

    Similarly, for ``size`` = 2048 = 8 x 8 x 8 x 4 each work group has 8 x 8 x 4 = 256 work
    items. They compute 256 (in parallel) length-8 in-register FFTs three times, with a
    transpose through local memory after each pass, followed by 512 length-4 in-register
    FFTs, where each work item computes two length-4 FFTs.

    For ``size`` = 32 = 8 x 4, 4 work items first compute 4 in-register length-8 FFTs,
    followed by a transpose through local memory and 8 in-register length-4 FFTs, each work
    item computing two of them. If a work group size of, say, 64 is chosen, each work group
    can compute 64 / 4 = 16 size-32 FFTs (batched transform).

    Users can play with these parameters to find what performs best on a particular device:
    some devices have less register space, so a smaller base radix can avoid spilling;
    some have little local memory, so a smaller work group size may be required.

    If ``use_max_radix`` is true, ``min(size, MAX_RADIX)`` is used for every radix except
    the last one, which receives the remaining factor. Otherwise a table of fixed
    decompositions is consulted for sizes 2 through 2048, and any other size falls back to
    radix-16 factors plus a final remainder radix.
    """
    assert size == 2 ** helpers.log2(size)

    if use_max_radix:
        step = min(size, MAX_RADIX)
        factors = []
        remaining = size
        # Split off `step` until what is left no longer exceeds it.
        while remaining > step:
            factors.append(step)
            remaining //= step
        factors.append(remaining)
        return factors

    # Fixed decompositions for the small sizes.
    known = {
        2: [2],
        4: [4],
        8: [8],
        16: [8, 2],
        32: [8, 4],
        64: [8, 8],
        128: [8, 4, 4],
        256: [4, 4, 4, 4],
        512: [8, 8, 8],
        1024: [16, 16, 4],
        2048: [8, 8, 8, 4],
    }
    tuned = known.get(size)
    if tuned is not None:
        return tuned

    # Naive fallback (can be improved): radix-16 factors, with the last radix
    # taking whatever power of two is left over.
    log_size = helpers.log2(size)
    count = helpers.min_blocks(log_size, 4)
    leftover = log_size % 4
    last = 16 if leftover == 0 else 2 ** leftover
    return [16] * (count - 1) + [last]
Ejemplo n.º 3
0
Archivo: fft.py Proyecto: ringw/reikna
def get_global_radix_info(size):
    """
    Decompose a power-of-two ``size`` into per-kernel base radices and their sub-factors.

    For a ``size`` larger than what the local memory FFT can handle, global transposes and
    multiple kernel launches are needed. Such sizes can be decomposed using much larger
    base radices, e.g. ``size`` = 262144 = 128 x 64 x 32. Three kernel launches would then
    be needed: the first computing 64 x 32 length-128 FFTs, the second computing 128 x 32
    length-64 FFTs, and the last computing 128 x 64 length-32 FFTs.

    Each base radix can be further divided into factors so that its FFT is computed within
    one kernel launch using in-register FFTs and local memory transposes. For the first
    kernel above, 128 = 16 x 8: 8 work items compute 8 length-16 FFTs, followed by a
    transpose through local memory, followed by each of these 8 work items computing
    2 length-8 FFTs (16 length-8 FFTs in total). So only 8 work items are needed for one
    length-128 FFT. With a work group size of, say, 64, one work group computes
    64 / 8 = 8 length-128 FFTs, and since the first kernel has to compute 64 x 32 of them,
    64 x 32 / 8 = 256 work groups of 64 work items each must be launched. The same logic
    applies to the other two kernels of the example.

    Users can play with different base radices and different decompositions of the base
    radices to generate different kernels and see which gives the best performance.
    This function is fixed to use 128 as the base radix (or ``size`` itself if smaller).

    Returns a tuple ``(radix_list, radix1_list, radix2_list)`` of equally long lists where,
    for each position, ``radix1 * radix2 == radix`` and ``radix2 <= radix1``.
    """
    assert size == 2 ** helpers.log2(size)

    base = min(size, 128)

    # Strip full base radices; whatever is left becomes the last radix.
    full_count = 0
    remaining = size
    while remaining > base:
        remaining //= base
        full_count += 1
    radix_list = [base] * full_count + [remaining]

    radix1_list = []
    radix2_list = []
    for radix in radix_list:
        # Radices up to 8 are not split any further.
        first, second = radix, 1
        if radix > 8:
            # Grow the first factor by powers of two until it is no smaller
            # than the second one.
            first = 2
            second = radix // first
            while second > first:
                first *= 2
                second = radix // first
        radix1_list.append(first)
        radix2_list.append(second)

    # sanity checks:
    for radix, first, second in zip(radix_list, radix1_list, radix2_list):
        assert second <= first
        assert first * second == radix
        assert first <= MAX_RADIX

    return radix_list, radix1_list, radix2_list
Ejemplo n.º 4
0
def get_global_radix_info(size):
    """
    Decompose a power-of-two ``size`` into per-kernel base radices and their sub-factors.

    For a ``size`` larger than what the local memory FFT can handle, global transposes and
    multiple kernel launches are needed. Such sizes can be decomposed using much larger
    base radices, e.g. ``size`` = 262144 = 128 x 64 x 32. Three kernel launches would then
    be needed: the first computing 64 x 32 length-128 FFTs, the second computing 128 x 32
    length-64 FFTs, and the last computing 128 x 64 length-32 FFTs.

    Each base radix can be further divided into factors so that its FFT is computed within
    one kernel launch using in-register FFTs and local memory transposes. For the first
    kernel above, 128 = 16 x 8: 8 work items compute 8 length-16 FFTs, followed by a
    transpose through local memory, followed by each of these 8 work items computing
    2 length-8 FFTs (16 length-8 FFTs in total). So only 8 work items are needed for one
    length-128 FFT. With a work group size of, say, 64, one work group computes
    64 / 8 = 8 length-128 FFTs, and since the first kernel has to compute 64 x 32 of them,
    64 x 32 / 8 = 256 work groups of 64 work items each must be launched. The same logic
    applies to the other two kernels of the example.

    Users can play with different base radices and different decompositions of the base
    radices to generate different kernels and see which gives the best performance.
    This function is fixed to use 128 as the base radix (or ``size`` itself if smaller).

    Returns a tuple ``(radix_list, radix1_list, radix2_list)`` of equally long lists where,
    for each position, ``radix1 * radix2 == radix`` and ``radix2 <= radix1``.
    """
    assert size == 2 ** helpers.log2(size)

    base = min(size, 128)

    # Strip full base radices; whatever is left becomes the last radix.
    full_count = 0
    remaining = size
    while remaining > base:
        remaining //= base
        full_count += 1
    radix_list = [base] * full_count + [remaining]

    radix1_list = []
    radix2_list = []
    for radix in radix_list:
        # Radices up to 8 are not split any further.
        first, second = radix, 1
        if radix > 8:
            # Grow the first factor by powers of two until it is no smaller
            # than the second one.
            first = 2
            second = radix // first
            while second > first:
                first *= 2
                second = radix // first
        radix1_list.append(first)
        radix2_list.append(second)

    # sanity checks:
    for radix, first, second in zip(radix_list, radix1_list, radix2_list):
        assert second <= first
        assert first * second == radix
        assert first <= MAX_RADIX

    return radix_list, radix1_list, radix2_list
Ejemplo n.º 5
0
    def _build_plan(self, plan_factory, device_params, output, matrix_a, matrix_b):
        """
        Build a plan with a single ``kernel_matrixmul`` call, picking a block width that fits.

        Candidate block widths are either the single overridden value
        (``self._block_width_override``) or the powers of two from
        ``2 ** log2(device_params.local_mem_banks)`` down to 1, tried in that order.
        A fresh plan is requested from ``plan_factory`` for every candidate.
        A candidate is skipped if its square exceeds ``device_params.max_work_group_size``
        or if the kernel call raises ``OutOfResourcesError``.

        Returns the plan for the first accepted candidate.
        Raises ``ValueError`` if no candidate is accepted.
        """
        override = self._block_width_override
        if override is None:
            top_power = helpers.log2(device_params.local_mem_banks)
            candidates = [2 ** power for power in range(top_power, -1, -1)]
        else:
            candidates = [override]

        # All dimensions except the last two are folded into a batch size.
        a_batch = helpers.product(matrix_a.shape[:-2])
        b_batch = helpers.product(matrix_b.shape[:-2])
        batch = max(a_batch, b_batch)

        for block_width in candidates:

            plan = plan_factory()

            # The work group is block_width x block_width work items.
            if block_width ** 2 > device_params.max_work_group_size:
                continue

            num_steps = helpers.min_blocks(self._convolution_size, block_width)
            a_blocks = helpers.min_blocks(self._a_outer_size, block_width)
            b_blocks = helpers.min_blocks(self._b_outer_size, block_width)

            render_kwds = {
                "batched_a": a_batch != 1,
                "batched_b": b_batch != 1,
                "transposed_a": self._transposed_a,
                "transposed_b": self._transposed_b,
                "num_steps": num_steps,
                "a_slices": (len(matrix_a.shape) - 2, 1, 1),
                "b_slices": (len(matrix_b.shape) - 2, 1, 1),
                "output_slices": (len(output.shape) - 2, 1, 1),
                "block_width": block_width,
                "mul": functions.mul(
                    matrix_a.dtype, matrix_b.dtype, out_dtype=output.dtype),
            }
            global_size = (batch, a_blocks * block_width, b_blocks * block_width)
            local_size = (1, block_width, block_width)

            try:
                plan.kernel_call(
                    TEMPLATE.get_def('matrixmul'),
                    [output, matrix_a, matrix_b],
                    kernel_name="kernel_matrixmul",
                    global_size=global_size,
                    local_size=local_size,
                    render_kwds=render_kwds)
            except OutOfResourcesError:
                # This block width does not fit; try the next one.
                continue

            return plan

        raise ValueError("Could not find suitable call parameters for the kernel")
Ejemplo n.º 6
0
    def _build_plan(self, plan_factory, device_params, output, matrix_a,
                    matrix_b):
        """
        Build a plan with a single ``kernel_matrixmul`` call, picking a block width that fits.

        Candidate block widths are either the single overridden value
        (``self._block_width_override``) or the powers of two from
        ``2 ** log2(device_params.local_mem_banks)`` down to 1, tried in that order.
        A fresh plan is requested from ``plan_factory`` for every candidate.
        A candidate is skipped if its square exceeds ``device_params.max_work_group_size``
        or if the kernel call raises ``OutOfResourcesError``.

        Returns the plan for the first accepted candidate.
        Raises ``ValueError`` if no candidate is accepted.
        """
        override = self._block_width_override
        if override is None:
            top_power = helpers.log2(device_params.local_mem_banks)
            candidates = [2 ** power for power in range(top_power, -1, -1)]
        else:
            candidates = [override]

        # All dimensions except the last two are folded into a batch size.
        a_batch = helpers.product(matrix_a.shape[:-2])
        b_batch = helpers.product(matrix_b.shape[:-2])
        batch = max(a_batch, b_batch)

        for block_width in candidates:

            plan = plan_factory()

            # The work group is block_width x block_width work items.
            if block_width ** 2 > device_params.max_work_group_size:
                continue

            num_steps = helpers.min_blocks(self._convolution_size, block_width)
            a_blocks = helpers.min_blocks(self._a_outer_size, block_width)
            b_blocks = helpers.min_blocks(self._b_outer_size, block_width)

            render_kwds = {
                "batched_a": a_batch != 1,
                "batched_b": b_batch != 1,
                "transposed_a": self._transposed_a,
                "transposed_b": self._transposed_b,
                "num_steps": num_steps,
                "a_slices": (len(matrix_a.shape) - 2, 1, 1),
                "b_slices": (len(matrix_b.shape) - 2, 1, 1),
                "output_slices": (len(output.shape) - 2, 1, 1),
                "block_width": block_width,
                "mul": functions.mul(
                    matrix_a.dtype, matrix_b.dtype, out_dtype=output.dtype),
            }
            global_size = (batch, a_blocks * block_width, b_blocks * block_width)
            local_size = (1, block_width, block_width)

            try:
                plan.kernel_call(
                    TEMPLATE.get_def('matrixmul'),
                    [output, matrix_a, matrix_b],
                    kernel_name="kernel_matrixmul",
                    global_size=global_size,
                    local_size=local_size,
                    render_kwds=render_kwds)
            except OutOfResourcesError:
                # This block width does not fit; try the next one.
                continue

            return plan

        raise ValueError(
            "Could not find suitable call parameters for the kernel")
Ejemplo n.º 7
0
    def _build_plan_for_wg_size(self, plan_factory, warp_size, max_wg_size,
                                output, input_):
        """
        Build the reduction plan for a given maximum work group size.

        If ``self._transpose_axes`` is set, the input is first transposed into a temporary
        array. Then ``reduce`` kernels are queued until the part being reduced shrinks to a
        single element: while a part is larger than ``max_wg_size * max_seq_size``, a pass
        writes per-block partial results into a temporary array; the final pass writes
        into ``output``. No kernel is queued if the part has at most one element.

        Returns the assembled plan.
        """
        plan = plan_factory()

        # Using algorithm cascading: sequential reduction, and then the parallel one.
        # According to Brent's theorem, the optimal sequential size is O(log(n)).
        # Setting it to the nearest power of 2 to simplify integer operations.
        max_seq_size = helpers.bounding_power_of_2(helpers.log2(max_wg_size))
        max_reduce_power = max_wg_size * max_seq_size

        # Normal reduction works on the input directly; otherwise on its transposed copy.
        cur_input = input_
        if self._transpose_axes is not None:
            transpose = Transpose(input_, axes=self._transpose_axes)
            cur_input = plan.temp_array_like(transpose.parameter.output)
            plan.computation_call(transpose, cur_input, input_)

        # The first `axis_start` axes are kept, the rest are reduced.
        axis_start = len(output.shape)
        input_slices = (axis_start, len(input_.shape) - axis_start)

        part_size = helpers.product(cur_input.shape[axis_start:])
        final_size = helpers.product(cur_input.shape[:axis_start])

        while part_size > 1:

            if part_size > max_reduce_power:
                # Too big for a single block: produce one partial result per block.
                seq_size = max_seq_size
                block_size = max_wg_size
                blocks_per_part = helpers.min_blocks(part_size, block_size * seq_size)
                cur_output = plan.temp_array(
                    (final_size, blocks_per_part), input_.dtype)
                output_slices = (1, 1)
            else:
                # Fits into a single block: this is the final pass.
                if part_size > max_wg_size:
                    seq_size = helpers.min_blocks(part_size, max_wg_size)
                    block_size = max_wg_size
                else:
                    seq_size = 1
                    block_size = helpers.bounding_power_of_2(part_size)
                blocks_per_part = 1
                cur_output = output
                output_slices = (len(cur_output.shape), 0)

            # Number of elements handled by the last (possibly partial) block.
            per_block = block_size * seq_size
            leftover = part_size % per_block
            last_block_size = per_block if leftover == 0 else leftover

            render_kwds = {
                "seq_size": seq_size,
                "blocks_per_part": blocks_per_part,
                "last_block_size": last_block_size,
                "log2": helpers.log2,
                "block_size": block_size,
                "warp_size": warp_size,
                "empty": self._empty,
                "operation": self._operation,
                "input_slices": input_slices,
                "output_slices": output_slices,
            }

            plan.kernel_call(
                TEMPLATE.get_def('reduce'),
                [cur_output, cur_input],
                global_size=(final_size, blocks_per_part * block_size),
                local_size=(1, block_size),
                render_kwds=render_kwds)

            # The output of this pass is the input of the next one.
            part_size = blocks_per_part
            cur_input = cur_output
            input_slices = output_slices

        return plan
Ejemplo n.º 8
0
    def _build_plan_for_wg_size(self, plan_factory, warp_size, max_wg_size, output, input_):
        """
        Build the reduction plan for a given maximum work group size.

        If ``self._transpose_axes`` is set, the input is first transposed into a temporary
        array. Then ``reduce`` kernels are queued until the part being reduced shrinks to a
        single element: while a part is larger than ``max_wg_size * max_seq_size``, a pass
        writes per-block partial results into a temporary array; the final pass writes
        into ``output``. No kernel is queued if the part has at most one element.

        Returns the assembled plan.
        """
        plan = plan_factory()

        # Using algorithm cascading: sequential reduction, and then the parallel one.
        # According to Brent's theorem, the optimal sequential size is O(log(n)).
        # Setting it to the nearest power of 2 to simplify integer operations.
        max_seq_size = helpers.bounding_power_of_2(helpers.log2(max_wg_size))
        max_reduce_power = max_wg_size * max_seq_size

        # Normal reduction works on the input directly; otherwise on its transposed copy.
        cur_input = input_
        if self._transpose_axes is not None:
            transpose = Transpose(input_, axes=self._transpose_axes)
            cur_input = plan.temp_array_like(transpose.parameter.output)
            plan.computation_call(transpose, cur_input, input_)

        # The first `axis_start` axes are kept, the rest are reduced.
        axis_start = len(output.shape)
        input_slices = (axis_start, len(input_.shape) - axis_start)

        part_size = helpers.product(cur_input.shape[axis_start:])
        final_size = helpers.product(cur_input.shape[:axis_start])

        while part_size > 1:

            if part_size > max_reduce_power:
                # Too big for a single block: produce one partial result per block.
                seq_size = max_seq_size
                block_size = max_wg_size
                blocks_per_part = helpers.min_blocks(part_size, block_size * seq_size)
                cur_output = plan.temp_array(
                    (final_size, blocks_per_part), input_.dtype)
                output_slices = (1, 1)
            else:
                # Fits into a single block: this is the final pass.
                if part_size > max_wg_size:
                    seq_size = helpers.min_blocks(part_size, max_wg_size)
                    block_size = max_wg_size
                else:
                    seq_size = 1
                    block_size = helpers.bounding_power_of_2(part_size)
                blocks_per_part = 1
                cur_output = output
                output_slices = (len(cur_output.shape), 0)

            # Number of elements handled by the last (possibly partial) block.
            per_block = block_size * seq_size
            leftover = part_size % per_block
            last_block_size = per_block if leftover == 0 else leftover

            render_kwds = {
                "seq_size": seq_size,
                "blocks_per_part": blocks_per_part,
                "last_block_size": last_block_size,
                "log2": helpers.log2,
                "block_size": block_size,
                "warp_size": warp_size,
                "empty": self._empty,
                "operation": self._operation,
                "input_slices": input_slices,
                "output_slices": output_slices,
            }

            plan.kernel_call(
                TEMPLATE.get_def('reduce'),
                [cur_output, cur_input],
                global_size=(final_size, blocks_per_part * block_size),
                local_size=(1, block_size),
                render_kwds=render_kwds)

            # The output of this pass is the input of the next one.
            part_size = blocks_per_part
            cur_input = cur_output
            input_slices = output_slices

        return plan
Ejemplo n.º 9
0
    def _build_plan(self, plan_factory, device_params, output, input_):
        """
        Build the computation plan for the scan.

        If ``self._transpose_to`` is set, the plan transposes the input, runs a nested
        ``Scan`` on the transposed array, and transposes the result back into ``output``.

        Otherwise the scanned axes are taken to be the innermost ones. The plan queues
        the ``scan`` kernel, which also fills an array of per-workgroup totals. If a
        single scan needs more than one workgroup, the totals are scanned by a nested
        exclusive ``Scan`` and added to the partial results by the ``add_wg_totals`` kernel.

        :param plan_factory: callable returning a new, empty plan.
        :param device_params: object providing ``max_work_group_size`` and
            ``local_mem_banks``.
        :param output: the output array parameter (its ``shape`` and ``dtype`` are used).
        :param input_: the input array parameter.
        :returns: the assembled plan.
        :raises ValueError: if ``self._seq_size`` is set and the workgroup size it
            requires exceeds the maximum workgroup size.
        """
        plan = plan_factory()

        if self._transpose_to is not None:

            transpose_to = Transpose(input_, axes=self._transpose_to)
            transposed = plan.temp_array_like(transpose_to.parameter.output)

            sub_scan = Scan(
                transposed, self._predicate, axes=self._axes, exclusive=self._exclusive,
                max_work_group_size=self._max_work_group_size)
            transposed_scanned = plan.temp_array_like(sub_scan.parameter.output)

            transpose_from = Transpose(
                transposed_scanned, axes=self._transpose_from, output_arr_t=output)

            plan.computation_call(transpose_to, transposed, input_)
            plan.computation_call(sub_scan, transposed_scanned, transposed)
            plan.computation_call(transpose_from, output, transposed_scanned)

        else:

            scan_ndim = len(self._axes) # assuming that at this point axes are inner and sorted
            batch_shape = output.shape[:-scan_ndim]
            batch_size = helpers.product(batch_shape)
            scan_shape = output.shape[-scan_ndim:]
            scan_size = helpers.product(scan_shape)

            if self._max_work_group_size is None:
                max_wg_size = device_params.max_work_group_size
            else:
                max_wg_size = self._max_work_group_size

            # The current algorithm requires workgroup size to be a power of 2.
            assert max_wg_size == 2**helpers.log2(max_wg_size)

            # Using algorithm cascading: sequential reduction, and then the parallel one.
            # According to Brent's theorem, the optimal sequential size is O(log(n)).
            # So, ideally we want the minimum `wg_size` for which
            # `wg_size * log2(wg_size) >= scan_size`.
            if self._seq_size is None:
                # NOTE(review): if `max_wg_size <= 2` this loop never runs and `seq_size`
                # stays unbound, failing below - confirm whether such sizes can occur.
                wg_size = 2
                while wg_size < max_wg_size:
                    seq_size = helpers.bounding_power_of_2(helpers.log2(wg_size) - 1)
                    if wg_size * seq_size >= scan_size:
                        break
                    wg_size *= 2
            else:
                seq_size = self._seq_size
                wg_size = helpers.bounding_power_of_2(helpers.min_blocks(scan_size, seq_size))
                if wg_size > max_wg_size:
                    # `max_wg_size` is an integer, so it must be converted explicitly;
                    # otherwise the concatenation itself raises a TypeError.
                    raise ValueError(
                        "Sequential size " + str(seq_size)
                        + " cannot be set because of the maximum workgroup size "
                        + str(max_wg_size))

            # One total per workgroup taking part in a single scan.
            wg_totals_size = helpers.min_blocks(scan_size, wg_size * seq_size)
            wg_totals = plan.temp_array((batch_size, wg_totals_size,), output.dtype)

            # With several workgroups per scan the kernel result is only partial
            # and has to be combined with the scanned totals afterwards.
            if wg_totals_size > 1:
                temp_output = plan.temp_array_like(output)
            else:
                temp_output = output

            # Number of elements processed by the last workgroup of a scan.
            last_part_size = scan_size % (wg_size * seq_size)
            if last_part_size == 0:
                last_part_size = wg_size * seq_size

            plan.kernel_call(
                TEMPLATE.get_def('scan'),
                [temp_output, input_, wg_totals],
                kernel_name="kernel_scan_wg",
                global_size=(batch_size, wg_size * wg_totals_size),
                local_size=(1, wg_size),
                render_kwds=dict(
                    slices=(len(batch_shape), len(scan_shape)),
                    log_num_banks=helpers.log2(device_params.local_mem_banks),
                    exclusive=self._exclusive,
                    wg_size=wg_size,
                    seq_size=seq_size,
                    scan_size=scan_size,
                    last_part_size=last_part_size,
                    wg_totals_size=wg_totals_size,
                    log_wg_size=helpers.log2(wg_size),
                    predicate=self._predicate
                    ))

            if wg_totals_size > 1:
                sub_scan = Scan(
                    wg_totals, self._predicate, axes=(1,), exclusive=True,
                    max_work_group_size=self._max_work_group_size)
                scanned_wg_totals = plan.temp_array_like(wg_totals)
                plan.computation_call(sub_scan, scanned_wg_totals, wg_totals)

                plan.kernel_call(
                    TEMPLATE.get_def('add_wg_totals'),
                    [output, temp_output, scanned_wg_totals],
                    kernel_name="kernel_scan_add_wg_totals",
                    global_size=(batch_size, scan_size,),
                    render_kwds=dict(
                        slices=(len(batch_shape), len(scan_shape),),
                        wg_size=wg_size,
                        seq_size=seq_size,
                        ))

        return plan
Ejemplo n.º 10
0
    def _build_plan(self, plan_factory, device_params, output, input_):
        """
        Build the computation plan for the scan.

        If ``self._transpose_to`` is set, the plan transposes the input, runs a nested
        ``Scan`` on the transposed array, and transposes the result back into ``output``.

        Otherwise the scanned axes are taken to be the innermost ones. The plan queues
        the ``scan`` kernel, which also fills an array of per-workgroup totals. If a
        single scan needs more than one workgroup, the totals are scanned by a nested
        exclusive ``Scan`` and added to the partial results by the ``add_wg_totals`` kernel.

        :param plan_factory: callable returning a new, empty plan.
        :param device_params: object providing ``max_work_group_size`` and
            ``local_mem_banks``.
        :param output: the output array parameter (its ``shape`` and ``dtype`` are used).
        :param input_: the input array parameter.
        :returns: the assembled plan.
        :raises ValueError: if ``self._seq_size`` is set and the workgroup size it
            requires exceeds the maximum workgroup size.
        """
        plan = plan_factory()

        if self._transpose_to is not None:

            transpose_to = Transpose(input_, axes=self._transpose_to)
            transposed = plan.temp_array_like(transpose_to.parameter.output)

            sub_scan = Scan(transposed,
                            self._predicate,
                            axes=self._axes,
                            exclusive=self._exclusive,
                            max_work_group_size=self._max_work_group_size)
            transposed_scanned = plan.temp_array_like(
                sub_scan.parameter.output)

            transpose_from = Transpose(transposed_scanned,
                                       axes=self._transpose_from,
                                       output_arr_t=output)

            plan.computation_call(transpose_to, transposed, input_)
            plan.computation_call(sub_scan, transposed_scanned, transposed)
            plan.computation_call(transpose_from, output, transposed_scanned)

        else:

            # assuming that at this point axes are inner and sorted
            scan_ndim = len(self._axes)
            batch_shape = output.shape[:-scan_ndim]
            batch_size = helpers.product(batch_shape)
            scan_shape = output.shape[-scan_ndim:]
            scan_size = helpers.product(scan_shape)

            if self._max_work_group_size is None:
                max_wg_size = device_params.max_work_group_size
            else:
                max_wg_size = self._max_work_group_size

            # The current algorithm requires workgroup size to be a power of 2.
            assert max_wg_size == 2**helpers.log2(max_wg_size)

            # Using algorithm cascading: sequential reduction, and then the parallel one.
            # According to Brent's theorem, the optimal sequential size is O(log(n)).
            # So, ideally we want the minimum `wg_size` for which
            # `wg_size * log2(wg_size) >= scan_size`.
            if self._seq_size is None:
                # NOTE(review): if `max_wg_size <= 2` this loop never runs and `seq_size`
                # stays unbound, failing below - confirm whether such sizes can occur.
                wg_size = 2
                while wg_size < max_wg_size:
                    seq_size = helpers.bounding_power_of_2(
                        helpers.log2(wg_size) - 1)
                    if wg_size * seq_size >= scan_size:
                        break
                    wg_size *= 2
            else:
                seq_size = self._seq_size
                wg_size = helpers.bounding_power_of_2(
                    helpers.min_blocks(scan_size, seq_size))
                if wg_size > max_wg_size:
                    # `max_wg_size` is an integer, so it must be converted explicitly;
                    # otherwise the concatenation itself raises a TypeError.
                    raise ValueError(
                        "Sequential size " + str(seq_size) +
                        " cannot be set because of the maximum workgroup size "
                        + str(max_wg_size))

            # One total per workgroup taking part in a single scan.
            wg_totals_size = helpers.min_blocks(scan_size, wg_size * seq_size)
            wg_totals = plan.temp_array((
                batch_size,
                wg_totals_size,
            ), output.dtype)

            # With several workgroups per scan the kernel result is only partial
            # and has to be combined with the scanned totals afterwards.
            if wg_totals_size > 1:
                temp_output = plan.temp_array_like(output)
            else:
                temp_output = output

            # Number of elements processed by the last workgroup of a scan.
            last_part_size = scan_size % (wg_size * seq_size)
            if last_part_size == 0:
                last_part_size = wg_size * seq_size

            plan.kernel_call(
                TEMPLATE.get_def('scan'), [temp_output, input_, wg_totals],
                kernel_name="kernel_scan_wg",
                global_size=(batch_size, wg_size * wg_totals_size),
                local_size=(1, wg_size),
                render_kwds=dict(slices=(len(batch_shape), len(scan_shape)),
                                 log_num_banks=helpers.log2(
                                     device_params.local_mem_banks),
                                 exclusive=self._exclusive,
                                 wg_size=wg_size,
                                 seq_size=seq_size,
                                 scan_size=scan_size,
                                 last_part_size=last_part_size,
                                 wg_totals_size=wg_totals_size,
                                 log_wg_size=helpers.log2(wg_size),
                                 predicate=self._predicate))

            if wg_totals_size > 1:
                sub_scan = Scan(wg_totals,
                                self._predicate,
                                axes=(1, ),
                                exclusive=True,
                                max_work_group_size=self._max_work_group_size)
                scanned_wg_totals = plan.temp_array_like(wg_totals)
                plan.computation_call(sub_scan, scanned_wg_totals, wg_totals)

                plan.kernel_call(TEMPLATE.get_def('add_wg_totals'),
                                 [output, temp_output, scanned_wg_totals],
                                 kernel_name="kernel_scan_add_wg_totals",
                                 global_size=(
                                     batch_size,
                                     scan_size,
                                 ),
                                 render_kwds=dict(
                                     slices=(
                                         len(batch_shape),
                                         len(scan_shape),
                                     ),
                                     wg_size=wg_size,
                                     seq_size=seq_size,
                                 ))

        return plan
Ejemplo n.º 11
0
    def _build_plan(self, plan_factory, device_params, output, alpha, beta):
        """
        Assemble the computation plan that fills ``output`` from ``alpha`` and ``beta``.

        The plan consists of four stages:

        1. a ``compound_click_probability_prepare`` kernel producing a temporary
           array shaped like ``alpha`` from ``alpha`` and ``beta``;
        2. a ``compound_click_probability_aggregate`` kernel producing a
           ``(samples, max_total_clicks + 1)`` array of products;
        3. a reduction of those products over axis 0 (the samples axis);
        4. an FFT of the reduced array, with an attached transformation that
           stores only the ``.x`` component of each value into ``output``.

        :param plan_factory: callable returning a fresh, empty plan.
        :param device_params: device description; ``local_mem_size`` and
            ``max_work_group_size`` are read from it.
        :param output: array-like metadata of the result.
        :param alpha: 2D array-like metadata with shape ``(samples, modes)``.
        :param beta: array-like metadata passed to the prepare kernel together
            with ``alpha``.
        :returns: the populated plan.
        :raises ValueError: if the aggregate kernel could not be scheduled with
            any of the attempted work group sizes.
        """

        plan = plan_factory()

        samples, modes = alpha.shape

        for_reduction = Type(alpha.dtype, (samples, self._max_total_clicks + 1))

        prepared_state = plan.temp_array_like(alpha)

        plan.kernel_call(
            TEMPLATE.get_def("compound_click_probability_prepare"),
            [prepared_state, alpha, beta],
            kernel_name="compound_click_probability_prepare",
            global_size=alpha.shape,
            render_kwds=dict(
                mul_cc=functions.mul(alpha.dtype, alpha.dtype),
                exp_c=functions.exp(alpha.dtype),
                ))

        # Block size is limited by the amount of available local memory.
        # In some OpenCL implementations the number reported cannot actually be fully used
        # (because it's used by kernel arguments), so we're padding it a little.
        local_mem_size = device_params.local_mem_size
        max_elems = (local_mem_size - 256) // alpha.dtype.itemsize
        block_size = 2**helpers.log2(max_elems)

        # No reason to have block size larger than the number of modes
        block_size = min(block_size, helpers.bounding_power_of_2(modes))

        # Second dimension of the global size is padded up to a multiple of the block size.
        products_gsize = (samples, helpers.min_blocks(self._max_total_clicks + 1, block_size) * block_size)
        products = plan.temp_array_like(for_reduction)

        # These depend only on the block size, not on the read size being tried,
        # so they are computed once outside the retry loop.
        full_steps = modes // block_size
        remainder_size = modes % block_size

        read_size = min(block_size, device_params.max_work_group_size)

        # Try progressively smaller work group sizes until the kernel is accepted.
        # The loop must only be left via ``break`` after a successful ``kernel_call``;
        # otherwise the plan would lack the aggregate kernel and the later stages
        # would consume an unfilled ``products`` array.
        while read_size > 1:

            try:
                plan.kernel_call(
                    TEMPLATE.get_def("compound_click_probability_aggregate"),
                    [products, prepared_state],
                    kernel_name="compound_click_probability_aggregate",
                    global_size=products_gsize,
                    local_size=(1, read_size,),
                    render_kwds=dict(
                        block_size=block_size,
                        read_size=read_size,
                        full_steps=full_steps,
                        remainder_size=remainder_size,
                        output_size=self._max_total_clicks + 1,
                        mul_cc=functions.mul(alpha.dtype, alpha.dtype),
                        add_cc=functions.add(alpha.dtype, alpha.dtype),
                        polar_unit=functions.polar_unit(dtypes.real_for(alpha.dtype)),
                        modes=self._system.modes,
                        max_total_clicks=self._max_total_clicks,
                        ))

            except OutOfResourcesError:
                # Not enough device resources for this work group size; halve and retry.
                read_size //= 2
                continue

            break

        else:
            # Reached when every attempt failed, or when the initial read size
            # was already too small for the loop to run at all.
            raise ValueError(
                "Could not find suitable call parameters for the "
                "compound_click_probability_aggregate kernel")

        reduction = Reduce(for_reduction, predicate_sum(alpha.dtype), axes=(0,))

        temp = plan.temp_array_like(reduction.parameter.output)

        plan.computation_call(reduction, temp, products)

        fft = FFT(temp)
        # Output transformation: load the FFT result and store only its ``.x`` component.
        real_trf = Transformation([
            Parameter('output', Annotation(output, 'o')),
            Parameter('input', Annotation(temp, 'i')),
            ],
            """
                ${input.ctype} val = ${input.load_same};
                ${output.store_same}(val.x);
                """)
        fft.parameter.output.connect(real_trf, real_trf.input, output_p=real_trf.output)

        # NOTE(review): the trailing ``True`` is presumably the FFT's ``inverse``
        # argument -- confirm against the FFT computation's signature.
        plan.computation_call(fft, output, temp, True)

        return plan