Exemple #1
0
def _fuse_two_kernels(knla, knlb):
    """Fuse the kernels *knla* and *knlb* into a single new kernel.

    Domains of *knlb* for which :func:`_find_fusable_loop_domain_index`
    finds a partner among the domains collected so far are intersected
    with that partner; all others are appended. Arguments and temporaries
    of *knlb* whose names clash with names already known to *knla*'s
    variable name generator are renamed, and the renames are applied to
    *knlb*'s expressions before the instruction streams are fused.

    :arg knla: the first kernel; must be in ``KernelState.INITIAL``.
    :arg knlb: the second kernel; must be in ``KernelState.INITIAL``.
    :returns: a tuple ``(fused_kernel, old_b_id_to_new_b_id)``, where the
        second entry is the value returned alongside the fused instructions
        by ``fuse_instruction_streams_with_unique_ids`` (presumably a
        mapping from *knlb*'s instruction IDs to their IDs in the fused
        kernel -- confirm against pymbolic).
    :raises LoopyError: if either kernel is not in the ``INITIAL`` state,
        if the two kernels disagree on the domain of inames they share,
        if an argument present in both kernels is defined differently,
        or if the assumptions disagree on parameters they share.

    .. note::

        The domain and assumption agreement checks compare the projections
        of *both* operands onto the names they genuinely share. Earlier
        revisions projected the same operand twice, which made both checks
        vacuously true.
    """
    from loopy.kernel import KernelState
    if knla.state != KernelState.INITIAL or knlb.state != KernelState.INITIAL:
        raise LoopyError("can only fuse kernels in INITIAL state")

    # {{{ fuse domains

    new_domains = knla.domains[:]

    for dom_b in knlb.domains:
        i_fuse = _find_fusable_loop_domain_index(dom_b, new_domains)
        if i_fuse is None:
            new_domains.append(dom_b)
        else:
            dom_a = new_domains[i_fuse]

            # Collect the shared inames *before* aligning: alignment may add
            # the other operand's dimensions to each set, after which every
            # name would look shared.
            shared_inames = sorted(
                    set(dom_a.get_var_dict(dim_type.set))
                    &
                    set(dom_b.get_var_dict(dim_type.set)))

            dom_a, dom_b = isl.align_two(dom_a, dom_b)

            # Restrict each domain to the shared inames so that only the
            # part both kernels have an opinion on gets compared.
            dom_a_s = dom_a.project_out_except(shared_inames, [dim_type.set])
            dom_b_s = dom_b.project_out_except(shared_inames, [dim_type.set])

            # Mutual inclusion, i.e. set equality.
            if not (dom_a_s <= dom_b_s and dom_b_s <= dom_a_s):
                raise LoopyError("kernels do not agree on domain of "
                        "inames '%s'" % (",".join(shared_inames)))

            new_domain = dom_a & dom_b

            new_domains[i_fuse] = new_domain

    # }}}

    # Names handed out by this generator are unique with respect to knla,
    # so anything from knlb that clashes comes back under a different name.
    vng = knla.get_var_name_generator()
    b_var_renames = {}

    # {{{ fuse args

    new_args = knla.args[:]
    for b_arg in knlb.args:
        if b_arg.name not in knla.arg_dict:
            new_arg_name = vng(b_arg.name)

            if new_arg_name != b_arg.name:
                b_var_renames[b_arg.name] = var(new_arg_name)

            new_args.append(b_arg.copy(name=new_arg_name))
        else:
            # Same-named arguments are treated as the same argument and
            # must therefore be defined identically in both kernels.
            if b_arg != knla.arg_dict[b_arg.name]:
                raise LoopyError(
                        "argument '{arg_name}' has inconsistent definition between "
                        "the two kernels being merged ({arg_a} <-> {arg_b})"
                        .format(
                            arg_name=b_arg.name,
                            arg_a=str(knla.arg_dict[b_arg.name]),
                            arg_b=str(b_arg)))

    # }}}

    # {{{ fuse temporaries

    new_temporaries = knla.temporary_variables.copy()
    for b_name, b_tv in six.iteritems(knlb.temporary_variables):
        assert b_name == b_tv.name

        # Unlike arguments, temporaries are never shared between the two
        # kernels: every temporary of knlb gets its own (possibly renamed)
        # entry.
        new_tv_name = vng(b_name)

        if new_tv_name != b_name:
            b_var_renames[b_name] = var(new_tv_name)

        assert new_tv_name not in new_temporaries
        new_temporaries[new_tv_name] = b_tv.copy(name=new_tv_name)

    # }}}

    # The renames must be applied before fusing the instruction streams so
    # that knlb's instructions refer to the new names.
    knlb = _apply_renames_in_exprs(knlb, b_var_renames)

    from pymbolic.imperative.transform import \
            fuse_instruction_streams_with_unique_ids
    new_instructions, old_b_id_to_new_b_id = \
            fuse_instruction_streams_with_unique_ids(
                    knla.instructions, knlb.instructions)

    # {{{ fuse assumptions

    assump_a = knla.assumptions
    assump_b = knlb.assumptions

    # As for the domains: find the shared names before alignment, and look
    # them up among the parameters, since that is the dimension type being
    # projected below.
    shared_param_names = sorted(
            set(assump_a.get_var_dict(dim_type.param))
            &
            set(assump_b.get_var_dict(dim_type.param)))

    assump_a, assump_b = isl.align_two(assump_a, assump_b)

    assump_a_s = assump_a.project_out_except(shared_param_names, [dim_type.param])
    assump_b_s = assump_b.project_out_except(shared_param_names, [dim_type.param])

    # Mutual inclusion, i.e. set equality.
    if not (assump_a_s <= assump_b_s and assump_b_s <= assump_a_s):
        raise LoopyError("assumptions do not agree on kernels to be merged")

    new_assumptions = (assump_a & assump_b).params()

    # }}}

    from loopy.kernel import LoopKernel
    return LoopKernel(
            domains=new_domains,
            instructions=new_instructions,
            args=new_args,
            name="%s_and_%s" % (knla.name, knlb.name),
            preambles=_ordered_merge_lists(knla.preambles, knlb.preambles),
            preamble_generators=_ordered_merge_lists(
                knla.preamble_generators, knlb.preamble_generators),
            assumptions=new_assumptions,
            local_sizes=_merge_dicts(
                "local size", knla.local_sizes, knlb.local_sizes),
            temporary_variables=new_temporaries,
            iname_to_tags=_merge_dicts(
                "iname-to-tag mapping",
                knla.iname_to_tags,
                knlb.iname_to_tags),
            substitutions=_merge_dicts(
                "substitution",
                knla.substitutions,
                knlb.substitutions),
            function_manglers=_ordered_merge_lists(
                knla.function_manglers,
                knlb.function_manglers),
            symbol_manglers=_ordered_merge_lists(
                knla.symbol_manglers,
                knlb.symbol_manglers),

            iname_slab_increments=_merge_dicts(
                "iname slab increment",
                knla.iname_slab_increments,
                knlb.iname_slab_increments),
            loop_priority=knla.loop_priority.union(knlb.loop_priority),
            silenced_warnings=_ordered_merge_lists(
                knla.silenced_warnings,
                knlb.silenced_warnings),
            applied_iname_rewrites=_ordered_merge_lists(
                knla.applied_iname_rewrites,
                knlb.applied_iname_rewrites),
            index_dtype=_merge_values(
                "index dtype",
                knla.index_dtype,
                knlb.index_dtype),
            target=_merge_values(
                "target",
                knla.target,
                knlb.target),
            # NOTE(review): only knla's options are carried over; knlb's
            # options are not consulted.
            options=knla.options), old_b_id_to_new_b_id
Exemple #2
0
 def insn_inames(self, insn):
     """Return whatever ``LoopKernel.insn_inames`` yields for *insn*.

     The call is made explicitly through the ``LoopKernel`` class, with
     this object passed as the instance.
     """
     inames = LoopKernel.insn_inames(self, insn)
     return inames