예제 #1
0
def _check_supported(kernel_info):
    """
    Call the op implementation's ``check_supported`` to check whether the op is supported.

    The implementation module is imported from the ``impl`` (or, for
    dynamic-shape ops, ``impl.dynamic``) package under ``build_in_impl_path``.
    When ``kernel_info['impl_path']`` points to an existing file, that file is
    imported instead as a top-level custom module.

    Args:
        kernel_info (dict): kernel info loaded from a json string

    Returns:
        The value returned by ``check_supported``. If that value is a 2-tuple,
        only its first element is returned. Returns ``""`` when the module
        does not define ``check_supported``.

    Raises:
        TBEException: wrapping any error raised while loading or calling the
            op. The original error is chained as ``__cause__``.
    """
    try:
        op_info = kernel_info['op_info']
        op_name = op_info['name']
        is_dynamic_shape = op_info['is_dynamic_shape']
        te_set_version(op_info["socVersion"])
        impl_path = build_in_impl_path
        custom_flag = False
        if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
            op_impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(op_impl_path):
                # Custom op: import the file by its stem from its own directory.
                impl_path, file_name = os.path.split(op_impl_path)
                op_name, _ = os.path.splitext(file_name)
                custom_flag = True
        if impl_path not in sys.path:
            sys.path.insert(0, impl_path)

        if custom_flag:
            op_module = __import__(op_name)
        elif is_dynamic_shape:
            op_module = __import__("impl.dynamic." + op_name, globals(),
                                   locals(), [op_name], 0)
        else:
            op_module = __import__("impl." + op_name, globals(), locals(),
                                   [op_name], 0)

        # get function; a module without check_supported yields ""
        if not hasattr(op_module, "check_supported"):
            return ""
        op_func = op_module.check_supported

        # call function
        inputs_args = get_args(op_info, 'inputs')
        outputs_args = get_args(op_info, 'outputs')
        attrs_args = get_args(op_info, 'attrs')
        kernel_name = op_info['kernel_name']
        if op_name in ("resize_nearest_neighbor_v2_grad_d",
                       "resize_bilinear_v2_grad"):
            # NOTE(review): these ops' check_supported presumably does not
            # accept the last attr -- confirm against the op implementations.
            attrs_args.pop(-1)
        ret = op_func(*inputs_args,
                      *outputs_args,
                      *attrs_args,
                      kernel_name=kernel_name)
        if isinstance(ret, tuple) and len(ret) == 2:
            # Only the first element of a 2-tuple result is reported.
            ret = ret[0]

    except Exception as e:
        # Chain the cause so the op's original traceback is not lost.
        raise TBEException(str(e)) from e

    return ret
예제 #2
0
def _initialize(impl_path):
    """
    Set the TE version and put the op implementation directory first on ``sys.path``.

    Args:
        impl_path (str): directory holding the op implementation. An empty
            string selects ``build_in_impl_path``.

    Raises:
        ValueError: if the resolved implementation directory is empty or None.
    """
    te_set_version(ddk_version)
    op_module_name = build_in_impl_path if impl_path == "" else impl_path
    if not op_module_name:
        raise ValueError("Can not find the env TBE_IMPL_PATH")

    # This runs on every build. Inserting unconditionally would add a duplicate
    # entry each time and grow sys.path without bound. Moving the entry to the
    # front gives it the same top priority: later duplicates never change
    # which module an import finds.
    while op_module_name in sys.path:
        sys.path.remove(op_module_name)
    sys.path.insert(0, op_module_name)
예제 #3
0
def _op_select_format(kernel_info):
    """
    Call the op implementation's ``op_select_format`` to get the formats the op supports.

    The implementation module is imported from the ``impl`` package under
    ``build_in_impl_path``. When ``kernel_info['impl_path']`` points to an
    existing file, that file is imported instead as a top-level custom module.

    Args:
        kernel_info (dict): kernel info loaded from a json string

    Returns:
        The value returned by ``op_select_format``, or ``""`` when the module
        does not define it.

    Raises:
        TBEException: wrapping any error raised while loading or calling the
            op. The original error is chained as ``__cause__``.
    """
    try:
        op_info = kernel_info['op_info']
        op_name = op_info['name']
        te_set_version(op_info["socVersion"])
        impl_path = build_in_impl_path
        custom_flag = False
        if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
            op_impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(op_impl_path):
                # Custom op: import the file by its stem from its own directory.
                impl_path, file_name = os.path.split(op_impl_path)
                op_name, _ = os.path.splitext(file_name)
                custom_flag = True
        if impl_path not in sys.path:
            sys.path.insert(0, impl_path)

        if custom_flag:
            op_module = __import__(op_name)
        else:
            op_module = __import__("impl." + op_name, globals(), locals(),
                                   [op_name], 0)

        # get function; a module without op_select_format yields ""
        if not hasattr(op_module, "op_select_format"):
            return ""
        op_func = op_module.op_select_format

        # call function
        inputs_args = get_args(op_info, 'inputs')
        outputs_args = get_args(op_info, 'outputs')
        attrs_args = get_args(op_info, 'attrs')
        kernel_name = op_info['kernel_name']
        ret = op_func(*inputs_args,
                      *outputs_args,
                      *attrs_args,
                      kernel_name=kernel_name)

    except Exception as e:
        # Chain the cause so the op's original traceback is not lost.
        raise TBEException(str(e)) from e

    return ret
예제 #4
0
def compile_fusion_op(json_str):
    """
    Compile a fusion op with input args json_str.

    Args:
        json_str (str): op function input args. It must contain a non-empty
            ``fusion_op`` object, which is expected to carry ``socVersion``.

    Returns:
        The result of ``fusion_op`` called with the re-serialized
        ``fusion_op`` section.

    Raises:
        ValueError: if ``fusion_op`` is missing or empty.
    """
    args = json.loads(json_str)
    # Validate before reading args['fusion_op']. Otherwise a missing key
    # raises KeyError and this check can never fire.
    if 'fusion_op' not in args or not args['fusion_op']:
        raise ValueError("Json string Errors, key:fusion_op not found.")
    fusion_op_arg = args['fusion_op']
    te_set_version(fusion_op_arg["socVersion"])
    return fusion_op(json.dumps(fusion_op_arg))
예제 #5
0
    def init_tune_interface(self, json_str, process_num):
        """
        Prepare the tuner interface from its JSON description.

        A failed SoC setup (``te_set_version`` returning None) is only logged
        as a warning; initialization still continues.

        :param json_str: original JSON string
        :param process_num: number of processes used by the tuner
        :return: bool, True if the parallel compilation environment was
            initialized, otherwise False
        """
        info = json.loads(json_str)
        soc = self.get_soc_info(info)
        if te_set_version(*soc) is None:
            log.warning("Set Soc Info failed.")
        mode = self.get_tune_mode(info)
        if self.parallel_compilation_init(soc, mode, process_num):
            return True
        log.error("Init parallel compilation env failed")
        return False
예제 #6
0
def build_op(build_type, json_str):
    """
    Call op functions with function name and input args json_str.

    Imports the op implementation module and calls its build function. The
    module comes from the built-in ``impl`` / ``impl.dynamic`` package, or is
    a custom top-level module when ``kernel_info['impl_path']`` names an
    existing file.

    Args:
        build_type: op function name; only ``op_build`` is supported
        json_str (str): op function input args

    Returns:
        The op function's return value. For dynamic-shape ops, returns
        ``te.op.get_compile_info()`` after the op function runs.

    Raises:
        RuntimeError: wrapping any error raised while loading or building the
            op. The original error is chained as ``__cause__``.
    """
    kernel_info = json.loads(json_str)
    check_kernel_info(kernel_info)
    te_set_version(kernel_info["op_info"]["socVersion"])
    op_name = kernel_info['op_info']['name']

    try:
        custom_flag = False
        # "" makes _initialize fall back to the built-in impl path. Without
        # this default, impl_path would be unbound when kernel_info has no
        # (or a None) 'impl_path'.
        impl_path = ""
        if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
            real_impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(real_impl_path):
                impl_path, file_name = os.path.split(real_impl_path)
                op_name, _ = os.path.splitext(file_name)
                custom_flag = True
        _initialize(impl_path)

        op_info = kernel_info['op_info']
        inputs_args = get_args(op_info, 'inputs')
        outputs_args = get_args(op_info, 'outputs')
        attrs_args = get_args(op_info, 'attrs')
        kernel_name = op_info['kernel_name']
        is_dynamic_shape = op_info['is_dynamic_shape']
        if is_dynamic_shape:
            _replace_range(inputs_args)
            _replace_range(outputs_args)

        if custom_flag:
            op_module = __import__(op_name)
        elif is_dynamic_shape:
            op_module = __import__("impl.dynamic." + op_name, globals(),
                                   locals(), [op_name], 0)
        else:
            op_module = __import__("impl." + op_name, globals(), locals(),
                                   [op_name], 0)

        # get function
        if build_type != op_build:
            raise ValueError(
                "function {} is not supported by Tbe op {}.".format(
                    build_type, op_name))
        # For a custom module, op_name is the file stem. The function looked
        # up inside it is still op_info['name'].
        py_fn_name = op_info['name'] if custom_flag else op_name
        op_func = getattr(op_module, py_fn_name, None)
        if op_func is None:
            raise ValueError(
                "Op:{} function {} is not supported by Tbe.".format(
                    op_name, build_type))

        # call function
        if op_name == "bounding_box_encode":
            # This op's build function takes kernel_name_val, not kernel_name.
            return op_func(*inputs_args,
                           *outputs_args,
                           *attrs_args,
                           kernel_name_val=kernel_name)

        if is_dynamic_shape:
            with te.op.dynamic():
                op_func(*inputs_args,
                        *outputs_args,
                        *attrs_args,
                        kernel_name=kernel_name)
                return te.op.get_compile_info()
        return op_func(*inputs_args,
                       *outputs_args,
                       *attrs_args,
                       kernel_name=kernel_name)

    except Exception as e:
        # Chain the cause so the op's original traceback is not lost.
        raise RuntimeError(e) from e
예제 #7
0
def build_op(build_type, json_str, tune_mode=None):
    """
    Call op functions with function name and input args json_str.

    Imports the op implementation module and calls its build function. The
    module comes from the built-in ``impl`` / ``impl.dynamic`` package, or is
    a custom top-level module when ``kernel_info['impl_path']`` names an
    existing file.

    Args:
        build_type : op function name; only ``op_build`` is supported
        json_str (str): op function input args
        tune_mode (str): if use auto_tune

    Returns:
        If ``tune_mode`` is None: the compile info for dynamic-shape ops, or
        the op function's return value otherwise.
        If ``tune_mode`` is set: a tuple ``(compile_info or None,
        (inputs_args, outputs_args, attrs_args), op_module_name)``.

    Raises:
        RuntimeError: wrapping any error raised while loading or building the
            op. The original error is chained as ``__cause__``.
    """
    kernel_info = json.loads(json_str)
    check_kernel_info(kernel_info)
    te_set_version(kernel_info["op_info"]["socVersion"])
    op_name = kernel_info['op_info']['name']
    op_type = kernel_info['op_info']['Type']

    try:
        custom_flag = False
        # "" makes _initialize fall back to the built-in impl path. Without
        # this default, impl_path would be unbound when kernel_info has no
        # (or a None) 'impl_path'.
        impl_path = ""
        if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
            real_impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(real_impl_path):
                impl_path, file_name = os.path.split(real_impl_path)
                op_name, _ = os.path.splitext(file_name)
                custom_flag = True
        _initialize(impl_path)

        inputs_args = get_args(kernel_info['op_info'], 'inputs')
        outputs_args = get_args(kernel_info['op_info'], 'outputs')
        attrs_args = get_args(kernel_info['op_info'], 'attrs')
        kernel_name = kernel_info['op_info']['kernel_name']
        is_dynamic_shape = kernel_info['op_info']['is_dynamic_shape']
        if is_dynamic_shape:
            _replace_range(inputs_args)
            _replace_range(outputs_args)

        if custom_flag:
            # A custom op is imported by its bare module name. Recording it
            # here keeps the tune_mode return from hitting an unbound name.
            op_module_name = op_name
            op_module = __import__(op_module_name)
        else:
            prefix = "impl.dynamic." if is_dynamic_shape else "impl."
            op_module_name = prefix + op_name
            op_module = __import__(op_module_name, globals(), locals(), [op_name], 0)

        # get function
        if build_type != op_build:
            raise ValueError("function {} is not supported by Tbe op {}.".format(build_type, op_name))
        # For a custom module, op_name is the file stem. The function looked
        # up inside it is still op_info['name'].
        py_fn_name = kernel_info['op_info']['name'] if custom_flag else op_name
        op_func = getattr(op_module, py_fn_name, None)
        if op_func is None:
            raise ValueError("Op:{} function {} is not supported by Tbe.".format(op_name, build_type))

        # call function
        if is_dynamic_shape:
            import tbe.common.context.op_context as op_context
            with op_context.OpContext("dynamic"):
                context_op_info = operator_info.OpInfo(op_type, op_type)
                op_context.get_context().add_op_info(context_op_info)
                op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
                compile_info = op_context.get_context().get_compile_info()
                if tune_mode is not None:
                    return compile_info, (inputs_args, outputs_args, attrs_args), op_module_name
                return compile_info

        res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
        if tune_mode is not None:
            return None, (inputs_args, outputs_args, attrs_args), op_module_name
        return res

    except Exception as e:
        # Chain the cause so the op's original traceback is not lost.
        raise RuntimeError(e) from e