示例#1
0
    def __init__(self, func: FunctionSchema):
        """Record processed argument types, name and returns of ``func``.

        ``positional_arg_types`` holds, in order, the pre-self positional
        arguments, the self argument (when present) and the post-self
        positional arguments.  ``keyword_arg_types`` holds, in order, the
        kwarg-only arguments before tensor options, the flattened tensor
        options, the kwarg-only arguments after tensor options and the out
        arguments.  Every entry is a ``NamedCType`` built from the argument's
        name and its type passed through ``process_ir_type``.
        ``wrapped_scalar_names`` lists, in schema order, the names of the
        arguments whose type satisfies ``isWrappedScalarType``.
        """
        schema_args = func.arguments

        positional = []
        for attr in ("pre_self_positional", "self_arg", "post_self_positional"):
            group = getattr(schema_args, attr)
            if group is None:
                continue
            if attr == "self_arg":
                # self_arg carries its single argument under ``.argument``.
                group = [group.argument]
            positional.extend(
                NamedCType(a.name, process_ir_type(a.type)) for a in group
            )
        self.positional_arg_types = tuple(positional)

        keyword = []
        for attr in (
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ):
            group = getattr(schema_args, attr)
            if group is None:
                continue
            if isinstance(group, TensorOptionsArguments):
                # Tensor options are bundled; ``all()`` flattens the bundle.
                group = group.all()
            keyword.extend(
                NamedCType(a.name, process_ir_type(a.type)) for a in group
            )
        self.keyword_arg_types = tuple(keyword)

        self.name = func.name
        self.returns = func.returns
        self.wrapped_scalar_names = [
            a.name
            for a in func.schema_order_arguments()
            if isWrappedScalarType(a.type)
        ]
示例#2
0
def name(func: FunctionSchema) -> str:
    """Build an identifier for ``func`` from its operator name.

    The base operator name is suffixed with ``_out`` for out= variants and
    then with ``_<overload_name>`` whenever the overload name is non-empty.
    """
    pieces = [str(func.name.name)]
    # TODO: delete this!
    if func.is_out_fn():
        pieces.append('out')
    overload = func.name.overload_name
    if overload:
        pieces.append(str(overload))
    return '_'.join(pieces)
示例#3
0
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
    """Return the operator name of ``func``, suffixed for out= variants.

    Non-out schemas get the bare operator name.  Out schemas get ``_outf``
    when ``faithful_name_for_out_overloads`` is true and ``_out`` otherwise.
    """
    base = str(func.name.name)
    if not func.is_out_fn():
        return base
    suffix = '_outf' if faithful_name_for_out_overloads else '_out'
    return base + suffix
示例#4
0
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    """Return ``ufunc_<operator name>_<dispatch_key>`` for an out= schema.

    ``func`` must be an out= schema; otherwise an assertion fails (the check
    is skipped when Python runs with ``-O``).
    """
    assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
    op_name = func.name.name
    return f"ufunc_{op_name}_{dispatch_key}"
示例#5
0
def arguments(func: FunctionSchema) -> List[Binding]:
    """Flatten ``func``'s arguments into a list of bindings.

    The non-out arguments come first, followed by the out arguments.  Each
    one is passed to ``argument``, which may yield several bindings per
    argument, and the results are concatenated in order.  Every call gets the
    same ``is_out`` flag, taken from ``func.is_out_fn()``.
    """
    # The flag depends only on the schema, so compute it once instead of once
    # per argument inside the comprehension.
    is_out = func.is_out_fn()
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    return [r for arg in args for r in argument(arg, is_out=is_out)]