def _get_fusion(func, nin, reduce, post_map, identity, input_types, name):
    """Trace ``func`` on symbolic inputs and build one fused kernel from it.

    One ``_FusionVar`` is created per entry of ``input_types`` and ``func``
    is called with ``_FusionRef`` wrappers around them.  The variables and
    operations found in the shared ``_FusionMem`` afterwards are turned into
    kernel source code.

    Args:
        func (function): The function to be fused.  It may return a single
            value, a tuple of values, or ``None``; ``None`` entries are
            dropped from the outputs.
        nin (int): Number of leading variables treated as kernel inputs.
        reduce: ``None`` to build an elementwise kernel.  Otherwise an object
            whose ``_raw`` attribute describes the reduction (``_raw`` is
            passed to ``_get_reduce_op`` and its ``_preamble`` is prepended
            to the generated code).
        post_map (function): Called with a single reference to the reduced
            value; only used when ``reduce`` is not ``None``.
        identity: Identity value forwarded to ``core.ReductionKernel``.
        input_types (list of dtypes): Types of the symbolic inputs.
        name (str): The name of the kernel.

    Returns:
        ``core.ElementwiseKernel`` when ``reduce`` is ``None``, otherwise
        ``core.ReductionKernel``.

    Raises:
        ValueError: A reduction was requested but ``func`` did not produce
            exactly one output.
        TypeError: The normalized result of ``post_map`` is a tuple.
    """
    in_vars = [_FusionVar(i, t) for i, t in enumerate(input_types)]
    mem = _FusionMem(in_vars)
    in_refs = [_FusionRef(var, mem) for var in in_vars]

    result = func(*in_refs)
    # A tuple return means several outputs; anything else is a single one.
    out_refs = list(result) if isinstance(result, tuple) else [result]
    out_refs = [ref for ref in out_refs if ref is not None]
    out_refs = [_FusionRef(_normalize_arg(ref, mem), mem) for ref in out_refs]
    out_vars = [_normalize_arg(copy(ref), mem) for ref in out_refs]
    nout = len(out_vars)
    op_list = mem.op_list
    # ``var_list[nin:-0]`` would be empty, so nout == 0 needs its own slice.
    if nout > 0:
        tmpvars = mem.var_list[nin:-nout]
    else:
        tmpvars = mem.var_list[nin:]

    in_params = ', '.join(_get_params(in_vars[:nin]))
    out_params = ', '.join(_get_params(out_vars))
    operation = ''.join(_get_declaration_from_var(var) for var in tmpvars)
    operation += ''.join(_get_declaration_from_op(op) for op in op_list)
    operation += '\n'.join(_get_operation_code(op) for op in op_list)

    if reduce is None:
        if not out_params:
            # Nothing was returned: the last input variable is declared as
            # the kernel's output parameter instead.
            in_params = ', '.join(_get_params(in_vars[:-1]))
            out_params = ', '.join(_get_params([in_vars[-1]]))
        submodules = _gather_submodules(op_list)
        submodule_code = ''.join(
            _get_submodule_code(sub) for sub in submodules.values())
        return core.ElementwiseKernel(
            in_params, out_params, operation,
            preamble=submodule_code, name=name)

    if nout != 1:
        raise ValueError(
            'Wrong number of outputs for a reduction: expected 1, got %d'
            % nout)

    # pre-map
    pre_type = out_vars[0].ty
    pre_code = _get_pre_code(in_vars, out_vars, operation)

    # reduce
    reduce_op = _get_reduce_op(reduce._raw, pre_type)
    routine = reduce_op[2]
    reduce_code = routine[1]
    reduce_type = numpy.dtype(reduce_op[1][0])
    rtype = routine[3]
    post_type = 'type_in0_raw' if rtype is None else rtype
    reduce_ctype = _dtype_to_ctype[reduce_type]
    pre_code += 'typedef %s type_in0_raw;\n' % reduce_ctype

    # post-map: traced in its own memory so that its variables and
    # operations stay separate from the pre-map ones collected above.
    post_in = [_FusionVar(0, reduce_type)]
    post_mem = _FusionMem(post_in)
    post_in_refs = [_FusionRef(var, post_mem) for var in post_in]
    post_out = _normalize_arg(post_map(*post_in_refs), post_mem)
    if isinstance(post_out, tuple):
        raise TypeError("Can't reduce a tuple")
    post_vars = post_mem.var_list
    post_ops = post_mem.op_list
    # post_vars[0] is the reduced input itself, so it is not redeclared.
    post_code = ''.join(
        _get_declaration_from_var(var) for var in post_vars[1:])
    post_code += ''.join(_get_declaration_from_op(op) for op in post_ops)
    post_code += '\n'.join(_get_operation_code(op) for op in post_ops)
    post_code = _get_post_code(post_vars, post_code, post_out)
    post_code += 'typedef %s type_out0_raw;\n' % reduce_ctype
    post_code += _get_fix_code(post_type, reduce_type, routine[2])

    submodules = _gather_submodules(op_list + post_ops)
    submodule_code = ''.join(
        _get_submodule_code(sub) for sub in submodules.values())
    submodule_code += reduce._raw._preamble + pre_code + post_code
    operation_args = ['v' + str(i) for i in six.moves.range(nin)]
    operation = '_pre_map(' + ', '.join(operation_args) + ')'
    out_params = '%s res' % post_out.ty
    return core.ReductionKernel(
        in_params, out_params, operation, reduce_code,
        'res = _post_map(_post_fix(a))', identity,
        name=name, reduce_type=post_type, preamble=submodule_code)
def get_fusion(self, func, in_params_info, name):
    """Generate a CUDA kernel from the given function and parameter info.

    ``func`` is called once with symbolic fusion variables, and the result
    is compiled into either an ElementwiseKernel or, when ``self.reduce_op``
    is set, a ReductionKernel.  As a side effect ``self.ndim`` is set to the
    largest ndim found in ``in_params_info``.

    Args:
        func (function): The function to be fused.
        in_params_info (sequence of (dtype, ndim) pairs): One entry per
            input parameter.  An ndim of ``-1`` marks a scalar parameter;
            any other value marks an array parameter.
        name (str): The name of the kernel.

    Returns:
        tuple of ElementwiseKernel/ReductionKernel and dict: The kernel and
        the keyword arguments to pass along with it.  The dict is empty for
        an elementwise kernel and ``self.reduce_kwargs`` for a reduction
        kernel.

    Raises:
        TypeError: ``func`` returned something other than a tuple, a fusion
            variable or ``None``.
    """
    in_dtypes = [dtype for dtype, _ in in_params_info]
    in_ndims = [ndim for _, ndim in in_params_info]
    self.ndim = max(in_ndims)

    in_params = [self._fresh_premap_param(dtype) for dtype in in_dtypes]
    in_pvars = []
    for param, ndim in zip(in_params, in_ndims):
        if ndim == -1:
            in_pvars.append(_FusionVarScalar(param, ndim, False))
        else:
            in_pvars.append(_FusionVarArray(param, ndim, False))

    return_value = func(*in_pvars)
    return_tuple = isinstance(return_value, tuple)
    no_return = return_value is None
    if return_tuple:
        out_pvars = return_value
    elif isinstance(return_value, (_FusionVarScalar, _FusionVarArray)):
        out_pvars = [return_value]
    elif no_return:
        out_pvars = []
    else:
        raise TypeError(
            "Fusion function can't return {}".format(type(return_value)))

    # ``None`` entries inside a returned tuple are ignored.
    out_pvars = [pvar for pvar in out_pvars if pvar is not None]
    out_cvars = [self._get_fusion_var(pvar)._var for pvar in out_pvars]
    out_dtypes = [pvar.dtype for pvar in out_pvars]
    out_params = [self._fresh_premap_param(dtype) for dtype in out_dtypes]

    in_params_code = ', '.join(
        param.declaration_in_param() for param in in_params)
    out_params_code = ', '.join(
        param.declaration_out_param() for param in out_params)
    operation = self._emit_operation_code()
    submodule_code = self._emit_submodules_code()

    if self.reduce_op is None:
        # Elementwise: assign every computed value to its output parameter.
        operation += ' '.join(
            '{} = {};'.format(dst, src)
            for src, dst in zip(out_cvars, out_params))
        kernel = core.ElementwiseKernel(
            in_params_code, out_params_code, operation,
            preamble=submodule_code,
            return_tuple=return_tuple,
            no_return=no_return,
            name=name)
        return kernel, {}

    _, in_types, routine = self.reduce_op
    (postmap_type,) = in_types
    _, reduce_code, postmap_cast_code, reduce_ctype = routine
    if reduce_ctype is None:
        reduce_ctype = 'type_in0_raw'

    postmap_dtype = numpy.dtype(postmap_type)
    postmap_ctype = _dtype_to_ctype[postmap_dtype]

    # Post-map body: header comment, local declarations, op argument
    # declarations, op code, then the assignments to the output parameters.
    postmap_pieces = [
        '// {} operations\n'.format(len(self.postmap_op_list))]
    postmap_pieces.extend(
        local.declaration() for local in self.postmap_local_list)
    postmap_pieces.extend(
        op.declaration_args() for op in self.postmap_op_list)
    postmap_pieces.extend(op.code() for op in self.postmap_op_list)
    postmap_pieces.append(' '.join(
        '{} = {};'.format(dst, src)
        for src, dst in zip(out_cvars, out_params)))
    postmap_code = ''.join(postmap_pieces)

    preamble = ''.join([
        submodule_code,
        self._emit_premap_code(in_params, operation),
        'typedef {} type_in0_raw;\n'.format(postmap_ctype),
        'typedef {} type_out0_raw;\n'.format(postmap_ctype),
        self._emit_postmap_cast_code(
            reduce_ctype, postmap_dtype, postmap_cast_code),
        self._emit_postmap_code(out_params, postmap_code),
    ])

    premap_call = '_pre_map({})'.format(
        ', '.join(repr(param) for param in in_params))
    postmap_call = '_post_map(_postmap_cast(a), {})'.format(
        ', '.join(repr(param) for param in out_params))
    kernel = core.ReductionKernel(
        in_params_code, out_params_code,
        premap_call,
        reduce_code,
        postmap_call,
        self.reduce_identity,
        name=name,
        reduce_type=reduce_ctype,
        preamble=preamble)
    return kernel, self.reduce_kwargs