예제 #1
0
파일: fusion.py 프로젝트: tkerola/cupy
def _get_fusion(func, nin, reduce, post_map, identity, input_types, name):
    """Trace ``func`` symbolically and build a CUDA kernel from the trace.

    ``func`` is called on ``_FusionRef`` wrappers around one ``_FusionVar``
    per entry of ``input_types``; the variables and operations recorded in
    the shared ``_FusionMem`` are turned into kernel source code.

    Args:
        func: Function to fuse. It may return a single value, a tuple of
            values (``None`` entries are dropped), or ``None``.
        nin (int): Number of leading entries of ``input_types`` that are
            kernel input parameters.
        reduce: Reduction to apply to the output of ``func``, or ``None``
            to build a plain elementwise kernel. Its ``_raw`` attribute is
            used to look up the reduction operation and ``_raw._preamble``
            is emitted into the kernel preamble.
        post_map: Function applied to the reduced value (reduction only).
        identity: Identity value passed to ``core.ReductionKernel``.
        input_types: Types used to create the traced input variables
            (presumably numpy dtypes -- confirm against callers).
        name (str): Name of the generated kernel.

    Returns:
        ``core.ElementwiseKernel`` if ``reduce`` is None, otherwise
        ``core.ReductionKernel``.

    Raises:
        Exception: If a reduction is requested but ``func`` does not
            produce exactly one output, or if ``post_map`` returns a tuple.
    """
    # Trace ``func``: every operation it performs on the references is
    # recorded in ``mem``.
    in_vars = [_FusionVar(i, t) for i, t in enumerate(input_types)]
    mem = _FusionMem(in_vars)
    in_refs = [_FusionRef(var, mem) for var in in_vars]
    out_refs = func(*in_refs)
    out_refs = list(out_refs) if isinstance(out_refs, tuple) else [out_refs]
    out_refs = [ref for ref in out_refs if ref is not None]
    out_refs = [_FusionRef(_normalize_arg(ref, mem), mem) for ref in out_refs]
    out_vars = [_normalize_arg(copy(ref), mem) for ref in out_refs]
    nout = len(out_vars)
    op_list = mem.op_list
    # Variables between the first ``nin`` and the last ``nout`` entries are
    # declared as locals in the kernel body (the trailing ``nout`` are
    # presumably the output copies created just above).
    tmpvars = mem.var_list[nin:-nout] if nout > 0 else mem.var_list[nin:]

    in_params = ', '.join(_get_params(in_vars[:nin]))
    out_params = ', '.join(_get_params(out_vars))
    operation = ''.join(_get_declaration_from_var(var) for var in tmpvars)
    operation += ''.join(_get_declaration_from_op(op) for op in op_list)
    operation += '\n'.join(_get_operation_code(op) for op in op_list)

    if reduce is None:
        if not out_params:
            # ``func`` produced no outputs: expose the last traced variable
            # as the output parameter and the rest as inputs.
            in_params = ', '.join(_get_params(in_vars[:-1]))
            out_params = ', '.join(_get_params([in_vars[-1]]))
        submodules = _gather_submodules(op_list)
        submodule_code = ''.join(
            _get_submodule_code(sub) for sub in submodules.values())
        return core.ElementwiseKernel(in_params,
                                      out_params,
                                      operation,
                                      preamble=submodule_code,
                                      name=name)
    else:
        if nout != 1:
            raise Exception("Wrong number of arguments")
        # pre-map
        pre_type = out_vars[0].ty
        pre_code = _get_pre_code(in_vars, out_vars, operation)

        # reduce
        reduce_op = _get_reduce_op(reduce._raw, pre_type)
        reduce_code = reduce_op[2][1]
        reduce_type = numpy.dtype(reduce_op[1][0])
        rtype = reduce_op[2][3]
        # Fall back to the reduced input type when no explicit reduce type
        # is given.
        post_type = "type_in0_raw" if rtype is None else rtype
        pre_code += "typedef %s type_in0_raw;\n" % _dtype_to_ctype[reduce_type]

        # post-map: trace ``post_map`` on the reduced value in a fresh memory.
        post_in = [_FusionVar(0, reduce_type)]
        mem = _FusionMem(post_in)
        post_in_ref = [_FusionRef(var, mem) for var in post_in]
        post_ret = post_map(*post_in_ref)
        # Check the raw return value: once passed through ``_normalize_arg``
        # it is no longer the object ``post_map`` returned, so a tuple check
        # on the normalized value would not catch a tuple return.
        if isinstance(post_ret, tuple):
            raise Exception("Can't reduce a tuple")
        post_out = _normalize_arg(post_ret, mem)
        post_vars = mem.var_list
        post_ops = mem.op_list
        # post_vars[0] is the reduced-value input; only later vars are locals.
        post_code = ''.join(
            _get_declaration_from_var(var) for var in post_vars[1:])
        post_code += ''.join(_get_declaration_from_op(op) for op in post_ops)
        post_code += '\n'.join(_get_operation_code(op) for op in post_ops)
        post_code = _get_post_code(post_vars, post_code, post_out)
        post_code += ("typedef %s type_out0_raw;\n" %
                      _dtype_to_ctype[reduce_type])
        post_code += _get_fix_code(post_type, reduce_type, reduce_op[2][2])

        submodules = _gather_submodules(op_list + post_ops)
        submodule_code = ''.join(
            _get_submodule_code(v) for v in submodules.values())
        submodule_code += reduce._raw._preamble + pre_code + post_code
        operation_args = ['v' + str(i) for i in six.moves.range(nin)]
        operation = '_pre_map(' + ', '.join(operation_args) + ')'
        out_params = '%s res' % post_out.ty
        return core.ReductionKernel(in_params,
                                    out_params,
                                    operation,
                                    reduce_code,
                                    'res = _post_map(_post_fix(a))',
                                    identity,
                                    name=name,
                                    reduce_type=post_type,
                                    preamble=submodule_code)
예제 #2
0
    def get_fusion(self, func, in_params_info, name):
        """Generate a CUDA kernel from the given function and parameter info.

        ``func`` is traced on fusion variables created from
        ``in_params_info``. If no reduction was recorded while tracing
        (``self.reduce_op`` is None) an ElementwiseKernel is generated,
        otherwise a ReductionKernel.

        Args:
            func (function): The function to be fused. It may return a
                fusion variable, a tuple of them, or None.
            in_params_info (sequence of (dtype, ndim) pairs): One pair per
                input parameter. An ``ndim`` of -1 makes the parameter a
                scalar; any other value makes it an array. The sequence is
                iterated more than once, so it must not be a one-shot
                iterator.
            name (str): The name of the kernel.

        Returns:
            tuple: ``(kernel, kwargs)``, where ``kernel`` is the generated
            ElementwiseKernel or ReductionKernel and ``kwargs`` is the dict
            of keyword arguments to pass when calling it (``{}`` for the
            elementwise case, ``self.reduce_kwargs`` for the reduction
            case).

        Raises:
            TypeError: If ``func`` returns anything other than a tuple, a
                fusion variable, or None.
        """
        in_dtypes = [dtype for dtype, ndim in in_params_info]
        in_ndims = [ndim for dtype, ndim in in_params_info]
        # Record the largest input ndim on the instance.
        self.ndim = max(in_ndims)
        in_params = [self._fresh_premap_param(t) for t in in_dtypes]
        # An ndim of -1 denotes a scalar parameter; anything else an array.
        in_pvars = [
            _FusionVarScalar(v, d, False) if d == -1 else _FusionVarArray(
                v, d, False) for v, d in zip(in_params, in_ndims)
        ]
        return_value = func(*in_pvars)

        if isinstance(return_value, tuple):
            return_tuple = True
            no_return = False
            out_pvars = return_value
        elif isinstance(return_value, (_FusionVarScalar, _FusionVarArray)):
            return_tuple = False
            no_return = False
            out_pvars = [return_value]
        elif return_value is None:
            return_tuple = False
            no_return = True
            out_pvars = []
        else:
            raise TypeError('Fusion function can\'t return {}'.format(
                type(return_value)))

        # None entries of a returned tuple produce no output parameter.
        out_pvars = [pvar for pvar in out_pvars if pvar is not None]
        out_cvars = [self._get_fusion_var(pvar)._var for pvar in out_pvars]

        out_dtypes = [pvar.dtype for pvar in out_pvars]
        out_params = [self._fresh_premap_param(t) for t in out_dtypes]

        in_params_code = ', '.join(var.declaration_in_param()
                                   for var in in_params)
        out_params_code = ', '.join(var.declaration_out_param()
                                    for var in out_params)

        operation = self._emit_operation_code()
        submodule_code = self._emit_submodules_code()

        if self.reduce_op is None:
            # Assign each computed value to its output parameter.
            operation += ' '.join('{} = {};'.format(param, cvar)
                                  for cvar, param in zip(out_cvars, out_params))
            kernel = core.ElementwiseKernel(in_params_code,
                                            out_params_code,
                                            operation,
                                            preamble=submodule_code,
                                            return_tuple=return_tuple,
                                            no_return=no_return,
                                            name=name)
            return kernel, {}
        else:
            _, (postmap_type, ), (_, reduce_code, postmap_cast_code,
                                  reduce_ctype) = self.reduce_op
            # Without an explicit reduce type, reduce in the input raw type.
            if reduce_ctype is None:
                reduce_ctype = 'type_in0_raw'

            postmap_dtype = numpy.dtype(postmap_type)
            postmap_ctype = _dtype_to_ctype[postmap_dtype]

            postmap_code = '// {} operations\n'.format(
                len(self.postmap_op_list))
            postmap_code += ''.join(v.declaration()
                                    for v in self.postmap_local_list)
            postmap_code += ''.join(op.declaration_args()
                                    for op in self.postmap_op_list)
            postmap_code += ''.join(op.code() for op in self.postmap_op_list)
            postmap_code += ' '.join('{} = {};'.format(param, cvar)
                                     for cvar, param in zip(out_cvars,
                                                            out_params))

            submodule_code += self._emit_premap_code(in_params, operation)
            submodule_code += 'typedef {} type_in0_raw;\n'.format(
                postmap_ctype)
            submodule_code += 'typedef {} type_out0_raw;\n'.format(
                postmap_ctype)
            submodule_code += self._emit_postmap_cast_code(
                reduce_ctype, postmap_dtype, postmap_cast_code)
            submodule_code += self._emit_postmap_code(out_params, postmap_code)

            kernel = core.ReductionKernel(
                in_params_code,
                out_params_code,
                '_pre_map({})'.format(', '.join([repr(p) for p in in_params])),
                reduce_code,
                '_post_map(_postmap_cast(a), {})'.format(', '.join(
                    [repr(p) for p in out_params])),
                self.reduce_identity,
                name=name,
                reduce_type=reduce_ctype,
                preamble=submodule_code)
            return kernel, self.reduce_kwargs