def generate_out_args_from_schema(
    func: FunctionSchema,
) -> Tuple[List[Return], List[Argument]]:
    """Derive the returns and the out= arguments for an out= variant of ``func``.

    Every tensor-like return of ``func`` is turned into a new mutable out
    argument whose alias letter is not used by any existing argument.

    Args:
        func: The schema to derive out= arguments from. It must have at least
            one tensor-like return, and none of its returns may carry a
            write annotation.

    Returns:
        A ``(new_returns, new_out_args)`` pair:

        - If every return is a plain ``Tensor``, ``new_returns`` mirrors the
          original returns, each one unnamed and carrying the alias
          annotation of its out argument.
        - Otherwise the tensor-like returns are dropped from ``new_returns``
          and only the non-tensor-like returns are kept, unchanged.

    Raises:
        AssertionError: If a return has a write annotation, or if ``func``
            has no tensor-like return.
    """
    # More of a sanity check - our existing restrictions on schemas should
    # enforce that mutable schema kinds never return their mutable arguments.
    assert not any(
        r.annotation is not None and r.annotation.is_write for r in func.returns
    )
    assert any(r.type.is_tensor_like() for r in func.returns)

    # Materialize the used alias names into a set. They are probed once per
    # letter below, so a single-pass iterable would be drained by the first
    # probes and later letters would wrongly be reported as unused.
    used_annotations = {
        alias
        for a in func.arguments.flat_all
        if a.annotation is not None
        for alias in a.annotation.alias_set
    }
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]

    all_rets_are_tensors = all(
        r.type == BaseType(BaseTy.Tensor) for r in func.returns
    )
    # A single return gets the conventional bare "out" name; with several
    # returns each out argument is numbered by its return position.
    has_single_return = len(func.returns) == 1

    new_out_args: List[Argument] = []
    new_returns: List[Return] = []
    for i, r in enumerate(func.returns):
        if not r.type.is_tensor_like():
            # Non-tensor-like returns are passed through untouched.
            new_returns.append(r)
            continue

        new_out = Argument(
            name="out" if has_single_return else f"out{i}",
            type=r.type,
            default=None,
            annotation=Annotation.parse(f"{valid_annotations[i]}!"),
        )
        new_out_args.append(new_out)
        if all_rets_are_tensors:
            # The convention for out= schemas is that they only return their
            # out arguments if the return is a plain Tensor (or if it's a
            # tuple of plain Tensors).
            new_returns.append(
                Return(name=None, type=new_out.type, annotation=new_out.annotation)
            )

    return new_returns, new_out_args
def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    """Build an out= schema from an inplace schema.

    The resulting schema has:
      - the name of ``func`` with the inplace marker removed and an "out"
        overload name ("out" when ``func`` has no overload name, otherwise
        "<overload>_out");
      - the arguments of ``func`` with the self annotation removed, plus one
        new ``out`` argument that takes over the type and annotation of the
        original self argument;
      - the returns of ``func``, reused as-is.

    Raises:
        AssertionError: If ``func`` is not an inplace schema or has no self
            argument.
    """
    assert func.kind() == SchemaKind.inplace
    assert func.arguments.self_arg is not None

    base_name = func.name.remove_inplace()
    if func.name.overload_name:
        out_overload = f"{func.name.overload_name}_out"
    else:
        out_overload = "out"
    out_name = base_name.with_overload(out_overload)

    unannotated_args = func.arguments.remove_self_annotation()
    # The out argument inherits both the type and the (mutable) annotation
    # that originally sat on self.
    out_arg = Argument(
        name="out",
        type=func.arguments.self_arg.argument.type,
        default=None,
        annotation=func.arguments.self_arg.argument.annotation,
    )

    return FunctionSchema(
        name=out_name,
        arguments=unannotated_args.with_out_args([out_arg]),
        returns=func.returns,
    )
# - While the forward lambda just directly calls into the at::_ops API # (following the dispatcher convention), the logic here for the reverse lambda # is responsible for generating both the call-site, and the declarations # (which are implemented manually in the at::functionalization::impl namespace). # The lambdas generated for each view op in the functionalization pass are of the form # [capture_arguments](outer_arguments) -> returns_type { # return name(inner_arguments); # } # Define some specific lambda input arguments. base_binding = Binding( name="base", nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))), argument=Argument(name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None), default=None, ) mutated_view_binding = Binding( name="mutated_view", nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), argument=Argument(name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None), default=None, ) mutated_view_idx_binding = Binding( name="mutated_view_idx",
def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    """Build an out= schema from a mutable schema.

    The new out= schema has:
      - every tensor-like return converted into a mutable, aliased out=
        argument named ``out{i}`` (``i`` is the position of the return);
        if all returns are plain Tensors the out arguments are also returned
        (for method chaining), otherwise only the non-tensor-like returns
        are kept;
      - an "out" overload name ("out" when ``func`` has no overload name,
        otherwise "<overload>_out").

    Note that:
      (1) An out= variant can *only* be generated from a mutable schema that
          has at least one tensor-like non-aliasing return.
      (2) The generated out= variant still has mutable positional arguments,
          but if necessary we could probably add another out= variant that
          also functionalizes the mutable arguments (a functional_out
          variant).

    Raises:
        AssertionError: If ``func`` is not a mutable schema, if one of its
            returns has a write annotation, or if it has no tensor-like
            return.
    """
    assert func.kind() == SchemaKind.mutable

    # More of a sanity check - our existing restrictions on schemas should
    # enforce that mutable schema kinds never return their mutable arguments.
    assert not any(
        r.annotation is not None and r.annotation.is_write for r in func.returns
    )
    assert any(r.type.is_tensor_like() for r in func.returns)

    # Materialize the used alias names into a set. They are probed once per
    # letter below, so a single-pass iterable would be drained by the first
    # probes and later letters would wrongly be reported as unused.
    used_annotations = {
        alias
        for a in func.arguments.flat_all
        if a.annotation is not None
        for alias in a.annotation.alias_set
    }
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]

    all_rets_are_tensors = all(
        r.type == BaseType(BaseTy.Tensor) for r in func.returns
    )

    new_out_args: List[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old
    #   returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and
    #   we're only left with non-tensor-like returns, if any).
    new_returns: List[Return] = []
    for i, r in enumerate(func.returns):
        if not r.type.is_tensor_like():
            new_returns.append(r)
            continue

        new_out = Argument(
            name=f"out{i}",
            type=r.type,
            default=None,
            annotation=Annotation.parse(f"{valid_annotations[i]}!"),
        )
        new_out_args.append(new_out)
        if all_rets_are_tensors:
            # The convention for out= schemas is that they only return their
            # out arguments if the return is a plain Tensor (or if it's a
            # tuple of plain Tensors).
            new_returns.append(
                Return(name=None, type=new_out.type, annotation=new_out.annotation)
            )

    return FunctionSchema(
        name=func.name.remove_inplace().with_overload(
            "out"
            if not func.name.overload_name
            else f"{func.name.overload_name}_out"
        ),
        arguments=func.arguments.with_out_args(new_out_args),
        returns=tuple(new_returns),
    )