Example #1
0
    def compile(self, *args, **kwargs) -> AscendOpKernel:
        """
        Compile the operator inside a "dynamic" TBE op context and wrap the result.

        The op function returned by ``self._load_op_func()`` is called with
        ``*args`` and ``**kwargs``. The generated ``<kernel_name>.o`` and
        ``<kernel_name>.json`` files are then looked up in ``./kernel_meta``
        (resolved against the current working directory).

        Args:
            *args: positional arguments forwarded to the op function; they are
                also handed to ``self._pick_kernel_args`` to derive the kernel
                input/output info.
            **kwargs: keyword arguments forwarded to the op function. Must
                contain ``kernel_name`` (str), which is used to locate the
                generated files.

        Returns:
            AscendOpKernel: kernel built from the generated files, with compile
            info and input/output info set.

        Raises:
            TypeError: if ``kernel_name`` is not supplied as a ``str`` keyword
                argument.
            RuntimeError: if the op function fails, or if the ``.o``/``.json``
                files are missing afterwards.
        """
        # Fail fast: without a usable kernel_name the generated files cannot be
        # located, and building the paths below would fail with an unhelpful
        # "NoneType + str" TypeError only after the whole compile has run.
        kernel_name = kwargs.get("kernel_name")
        if not isinstance(kernel_name, str):
            raise TypeError(
                "kernel_name must be passed as a keyword argument of type str, "
                "got {!r}.".format(kernel_name))

        import tbe  # 'pylint: disable=import-outside-toplevel
        import tbe.common.context.op_info as operator_info  # 'pylint: disable=import-outside-toplevel
        op_func = self._load_op_func()
        try:
            with tbe.common.context.op_context.OpContext("dynamic"):
                op_info = operator_info.OpInfo(self.op_type, self.op_type)
                tbe.common.context.op_context.get_context().add_op_info(
                    op_info)
                op_func(*args, **kwargs)
                # The compile info has to be read while the context is still open.
                compile_info = tbe.common.context.get_context(
                ).get_compile_info()
        # NOTE(review): BaseException also captures SystemExit/KeyboardInterrupt
        # raised during compilation - confirm this breadth is intended.
        except BaseException as compile_err:
            raise RuntimeError("Compile op failed.") from compile_err

        kernel_meta_dir = os.path.realpath("./kernel_meta")
        bin_path = os.path.join(kernel_meta_dir, kernel_name + ".o")
        json_path = os.path.join(kernel_meta_dir, kernel_name + ".json")
        if not os.path.exists(bin_path) or not os.path.exists(json_path):
            raise RuntimeError(
                "Compile op failed, .o or .json is not generate successful.")

        kernel = AscendOpKernel(bin_path, json_path)
        kernel.set_compile_info(compile_info)

        kernel_inputs, kernel_outputs = self._pick_kernel_args(args)
        kernel.set_input_info(kernel_inputs)
        kernel.set_output_info(kernel_outputs)
        return kernel
Example #2
0
def build_op(build_type, json_str, tune_mode=None):
    """
    Call the TBE op function described by ``json_str``.

    Args:
        build_type: requested build action; only the value equal to the
            module-level ``op_build`` is accepted.
        json_str (str): JSON text of the kernel description. Its ``op_info``
            entry supplies ``socVersion``, ``name``, ``Type``, ``kernel_name``,
            ``is_dynamic_shape`` and the inputs/outputs/attrs; an optional
            top-level ``impl_path`` points at a custom op implementation file.
        tune_mode (str): if not None, the tuning-oriented tuple described
            below is returned instead of the plain result.

    Returns:
        With ``tune_mode`` None: the compile info (dynamic shape) or the op
        function's return value (static shape).
        With ``tune_mode`` set: ``(compile_info_or_None,
        (inputs_args, outputs_args, attrs_args), op_module_name)``.

    Raises:
        RuntimeError: wraps any ``Exception`` raised while locating, importing
            or calling the op function (the original error is chained as the
            cause).
    """
    kernel_info = json.loads(json_str)
    check_kernel_info(kernel_info)
    te_set_version(kernel_info["op_info"]["socVersion"])
    op_name = kernel_info['op_info']['name']
    op_type = kernel_info['op_info']['Type']

    try:
        custom_flag = False
        # Default to "" so _initialize always receives a value, including when
        # 'impl_path' is absent, None, or does not point at a regular file.
        impl_path = ""
        raw_impl_path = kernel_info.get('impl_path')
        if raw_impl_path is not None:
            real_impl_path = os.path.realpath(raw_impl_path)
            if os.path.isfile(real_impl_path):
                # Custom op: the file's directory is handed to _initialize and
                # the file's base name (without extension) is the module name.
                impl_path, file_name = os.path.split(real_impl_path)
                op_name, _ = os.path.splitext(file_name)
                custom_flag = True
        _initialize(impl_path)

        inputs_args = get_args(kernel_info['op_info'], 'inputs')
        outputs_args = get_args(kernel_info['op_info'], 'outputs')
        attrs_args = get_args(kernel_info['op_info'], 'attrs')
        kernel_name = kernel_info['op_info']['kernel_name']
        is_dynamic_shape = kernel_info['op_info']['is_dynamic_shape']
        if is_dynamic_shape:
            _replace_range(inputs_args)
            _replace_range(outputs_args)

        if custom_flag:
            # Bind the module name for custom ops too, so the tune_mode return
            # paths below never hit an unbound local.
            op_module_name = op_name
            op_module = __import__(op_name)
        else:
            package_prefix = "impl.dynamic." if is_dynamic_shape else "impl."
            op_module_name = package_prefix + op_name
            # A non-empty fromlist makes __import__ return the leaf module
            # rather than the top-level "impl" package.
            op_module = __import__(op_module_name, globals(), locals(), [op_name], 0)

        # get function
        if build_type != op_build:
            raise ValueError("function {} is not supported by Tbe op {}.".format(build_type, op_name))
        # For custom ops op_name was replaced by the file name above, so the
        # function name is taken from the original op_info entry.
        py_fn_name = kernel_info['op_info']['name'] if custom_flag else op_name
        op_func = getattr(op_module, py_fn_name, None)
        if op_func is None:
            raise ValueError("Op:{} function {} is not supported by Tbe.".format(op_name, build_type))

        # call function
        if is_dynamic_shape:
            import tbe.common.context.op_context as op_context
            with op_context.OpContext("dynamic"):
                op_info = operator_info.OpInfo(op_type, op_type)
                op_context.get_context().add_op_info(op_info)
                op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
                # Read the compile info while the context is still open.
                compile_info = op_context.get_context().get_compile_info()
                if tune_mode is not None:
                    return compile_info, (inputs_args, outputs_args, attrs_args), op_module_name
                return compile_info
        else:
            res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
            if tune_mode is not None:
                return None, (inputs_args, outputs_args, attrs_args), op_module_name
            return res

    except Exception as e:
        # Keep the public contract (RuntimeError) but chain the original error.
        raise RuntimeError(e) from e