def _fuse_two_kernels(knla, knlb):
    """Fuse two kernels in ``INITIAL`` state into a single kernel.

    Domains of *knlb* that :func:`_find_fusable_loop_domain_index` matches
    with a domain of *knla* are intersected with it; all other domains are
    appended. Arguments, temporaries, instructions, assumptions and the
    remaining kernel attributes are merged. Arguments of *knlb* not present
    in *knla*, and all temporaries of *knlb*, are passed through *knla*'s
    variable name generator and renamed when their names clash.

    :returns: a tuple ``(fused_kernel, old_b_id_to_new_b_id)``, where the
        second entry is the instruction-id mapping returned by pymbolic's
        ``fuse_instruction_streams_with_unique_ids``.
    :raises LoopyError: if either kernel is not in ``INITIAL`` state, if an
        argument present in both kernels is defined differently, if the
        projections of two fused domains onto their shared inames differ, or
        if the assumptions differ when projected onto their shared parameters.
    """
    from loopy.kernel import KernelState
    if (knla.state != KernelState.INITIAL
            or knlb.state != KernelState.INITIAL):
        raise LoopyError("can only fuse kernels in INITIAL state")

    # {{{ fuse domains

    new_domains = knla.domains[:]

    for dom_b in knlb.domains:
        i_fuse = _find_fusable_loop_domain_index(dom_b, new_domains)
        if i_fuse is None:
            new_domains.append(dom_b)
            continue

        dom_a = new_domains[i_fuse]

        # Determine the shared inames *before* aligning: after isl.align_two
        # both sets carry the union of their set dimensions, which would make
        # every iname look shared.
        shared_inames = sorted(
                set(dom_a.get_var_dict(dim_type.set))
                & set(dom_b.get_var_dict(dim_type.set)))

        # Project each domain onto the shared inames; the projections must
        # agree, otherwise intersecting would silently shrink a loop.
        dom_a_s = dom_a.project_out_except(shared_inames, [dim_type.set])
        dom_b_s = dom_b.project_out_except(shared_inames, [dim_type.set])
        dom_a_s, dom_b_s = isl.align_two(dom_a_s, dom_b_s)

        if not (dom_a_s <= dom_b_s and dom_b_s <= dom_a_s):
            raise LoopyError("kernels do not agree on domain of "
                    "inames '%s'" % (",".join(shared_inames)))

        dom_a, dom_b = isl.align_two(dom_a, dom_b)
        new_domains[i_fuse] = dom_a & dom_b

    # }}}

    vng = knla.get_var_name_generator()
    b_var_renames = {}

    # {{{ fuse args

    new_args = knla.args[:]
    for b_arg in knlb.args:
        if b_arg.name not in knla.arg_dict:
            new_arg_name = vng(b_arg.name)

            if new_arg_name != b_arg.name:
                b_var_renames[b_arg.name] = var(new_arg_name)

            new_args.append(b_arg.copy(name=new_arg_name))
        else:
            # Arguments with the same name are shared, so they must match.
            if b_arg != knla.arg_dict[b_arg.name]:
                raise LoopyError(
                        "argument '{arg_name}' has inconsistent definition between "
                        "the two kernels being merged ({arg_a} <-> {arg_b})"
                        .format(
                            arg_name=b_arg.name,
                            arg_a=str(knla.arg_dict[b_arg.name]),
                            arg_b=str(b_arg)))

    # }}}

    # {{{ fuse temporaries

    new_temporaries = knla.temporary_variables.copy()
    for b_name, b_tv in six.iteritems(knlb.temporary_variables):
        assert b_name == b_tv.name

        new_tv_name = vng(b_name)

        if new_tv_name != b_name:
            b_var_renames[b_name] = var(new_tv_name)

        assert new_tv_name not in new_temporaries
        new_temporaries[new_tv_name] = b_tv.copy(name=new_tv_name)

    # }}}

    knlb = _apply_renames_in_exprs(knlb, b_var_renames)

    from pymbolic.imperative.transform import \
            fuse_instruction_streams_with_unique_ids
    new_instructions, old_b_id_to_new_b_id = \
            fuse_instruction_streams_with_unique_ids(
                    knla.instructions, knlb.instructions)

    # {{{ fuse assumptions

    assump_a = knla.assumptions
    assump_b = knlb.assumptions

    # The assumptions are compared on their shared *parameters*. As for the
    # domains, determine the shared names before alignment, which would
    # otherwise give both sets the same parameter space.
    shared_param_names = sorted(
            set(assump_a.get_var_dict(dim_type.param))
            & set(assump_b.get_var_dict(dim_type.param)))

    assump_a_s = assump_a.project_out_except(
            shared_param_names, [dim_type.param])
    assump_b_s = assump_b.project_out_except(
            shared_param_names, [dim_type.param])
    assump_a_s, assump_b_s = isl.align_two(assump_a_s, assump_b_s)

    if not (assump_a_s <= assump_b_s and assump_b_s <= assump_a_s):
        raise LoopyError("assumptions do not agree on kernels to be merged")

    assump_a, assump_b = isl.align_two(assump_a, assump_b)
    new_assumptions = (assump_a & assump_b).params()

    # }}}

    from loopy.kernel import LoopKernel
    return LoopKernel(
            domains=new_domains,
            instructions=new_instructions,
            args=new_args,
            name="%s_and_%s" % (knla.name, knlb.name),
            preambles=_ordered_merge_lists(knla.preambles, knlb.preambles),
            preamble_generators=_ordered_merge_lists(
                knla.preamble_generators, knlb.preamble_generators),
            assumptions=new_assumptions,
            local_sizes=_merge_dicts(
                "local size", knla.local_sizes, knlb.local_sizes),
            temporary_variables=new_temporaries,
            iname_to_tags=_merge_dicts(
                "iname-to-tag mapping",
                knla.iname_to_tags,
                knlb.iname_to_tags),
            substitutions=_merge_dicts(
                "substitution",
                knla.substitutions,
                knlb.substitutions),
            function_manglers=_ordered_merge_lists(
                knla.function_manglers,
                knlb.function_manglers),
            symbol_manglers=_ordered_merge_lists(
                knla.symbol_manglers,
                knlb.symbol_manglers),
            iname_slab_increments=_merge_dicts(
                "iname slab increment",
                knla.iname_slab_increments,
                knlb.iname_slab_increments),
            loop_priority=knla.loop_priority.union(knlb.loop_priority),
            silenced_warnings=_ordered_merge_lists(
                knla.silenced_warnings,
                knlb.silenced_warnings),
            applied_iname_rewrites=_ordered_merge_lists(
                knla.applied_iname_rewrites,
                knlb.applied_iname_rewrites),
            index_dtype=_merge_values(
                "index dtype",
                knla.index_dtype,
                knlb.index_dtype),
            target=_merge_values(
                "target",
                knla.target,
                knlb.target),
            options=knla.options), old_b_id_to_new_b_id
def insn_inames(self, insn):
    """Return what :meth:`LoopKernel.insn_inames` computes for *insn*.

    The unbound ``LoopKernel.insn_inames`` is invoked with *self* passed
    explicitly, so the ``LoopKernel`` implementation is used regardless of
    how ``insn_inames`` would otherwise resolve on ``type(self)``.
    """
    base_impl = LoopKernel.insn_inames
    return base_impl(self, insn)