コード例 #1
0
ファイル: native.py プロジェクト: huaxz1986/pytorch
def name(func: FunctionSchema) -> str:
    name = str(func.name.name)
    # TODO: delete this!
    if func.is_out_fn():
        name += "_out"
    if func.name.overload_name:
        name += f"_{func.name.overload_name}"
    return name
コード例 #2
0
ファイル: native.py プロジェクト: huaxz1986/pytorch
def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    args.extend(func.arguments.non_out)
    args.extend(func.arguments.out)
    return [
        r for arg in args
        for r in argument(arg, symint=symint, is_out=func.is_out_fn())
    ]
コード例 #3
0
ファイル: cpp.py プロジェクト: Mu-L/pytorch
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
    name = str(func.name.name)
    if func.is_symint_fn():
        name += "_symint"
    if func.is_out_fn():
        if faithful_name_for_out_overloads:
            name += "_outf"
        else:
            name += "_out"

    return name
コード例 #4
0
ファイル: generator.py プロジェクト: Mu-L/pytorch
def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
    assert not schema.is_out_fn()
    schema_name = schema.name.name.base
    arg_map = {}
    for arg in schema.schema_order_arguments():
        test_value_exp = test_value_expression(arg.type, index, schema_name)
        arg_map[arg.name] = test_value_exp
    config.override_test_values(arg_map, schema_name, index)
    arg_populations = []
    for arg_name, arg_value in arg_map.items():
        arg_populations.append(f"auto {arg_name}{index} = {arg_value}")
    return ";\n    ".join(arg_populations) + ";"
コード例 #5
0
ファイル: ufunc.py プロジェクト: huaxz1986/pytorch
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
    return f"ufunc_{func.name.name}_{dispatch_key}"
コード例 #6
0
ファイル: generator.py プロジェクト: Mu-L/pytorch
def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
    assert not schema.is_out_fn()
    return ",".join(f"{arg.name}{index}"
                    for arg in schema.schema_order_arguments())