def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...]) -> Derivative:
    arguments = cpp_arguments(f)
    argument_names = tuple(a.name for a in arguments)
    argument_types = tuple(a.type for a in arguments)

    return_names = tuple(n if n != 'self' else 'result' for n in cpp.return_names(f))
    return_types = tuple(cpp.return_type(r) for r in f.func.returns)

    formula, saved_inputs = saved_variables(formula, argument_names, argument_types, var_names)
    formula, saved_outputs = saved_variables(formula, return_names, return_types, var_names)

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
                f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.'
            )

    return Derivative(
        formula=formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
    )
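# For intuition, saved_variables needs to know which of the candidate names a
# formula actually references, so those values can be stashed for the backward
# pass. A minimal, hypothetical sketch of that detection step, using
# whole-identifier matching so that e.g. 'self' does not match 'self_scale'.
# Illustrative only; the real saved_variables also rewrites derived
# expressions such as sizes()/strides() into saved attributes:
import re
from typing import Tuple

def referenced_names_sketch(formula: str, candidates: Tuple[str, ...]) -> Tuple[str, ...]:
    # A name counts as referenced only when it appears as a standalone
    # identifier inside the formula text.
    return tuple(n for n in candidates
                 if re.search(rf'(^|\W){re.escape(n)}($|\W)', formula))

# e.g. referenced_names_sketch('grad * self.sizes()', ('self', 'other')) -> ('self',)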
def gen_differentiable_outputs(f: NativeFunction) -> List[DifferentiableOutput]:
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(name=name, type=ret.type, cpp_type=cpp.return_type(ret))
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    # Note: 'info' (the matching DifferentiabilityInfo, possibly None) is a
    # free variable here; this version relies on it being in the enclosing scope.
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        differentiable_outputs: List[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation "
                "(version_counter won't get updated)")
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs

    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type), outputs))

    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...]) -> Derivative:
    original_formula = formula
    arguments: List[NamedCType] = [a.nctype.remove_const_ref() for a in cpp_arguments(f)]

    return_names = tuple(n if n != 'self' else 'result' for n in cpp.return_names(f))
    return_types = tuple(cpp.return_type(r).remove_const_ref() for r in f.func.returns)

    named_returns = [NamedCType(name, type) for name, type in zip(return_names, return_types)]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
                f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.'
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
    )
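# The bounds check above relies on used_gradient_indices. A minimal sketch,
# assuming derivative formulas reference multi-output gradients as
# grads[<int>] (consistent with the error message above); illustrative, not
# necessarily the authoritative implementation:
import re
from typing import List

GRAD_INDEX_REGEX_SKETCH = r'(?:^|\W)grads\[(\d+)\]'

def used_gradient_indices_sketch(formula: str) -> List[int]:
    """Return every index i for which grads[i] appears in the formula."""
    return [int(i) for i in re.findall(GRAD_INDEX_REGEX_SKETCH, formula)]

# e.g. used_gradient_indices_sketch('grads[0] * self + grads[1]') -> [0, 1]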
def gen_differentiable_outputs(fn: NativeFunctionWithDifferentiabilityInfo) -> List[DifferentiableOutput]:
    f = fn.func
    info = fn.info
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(name=name, type=ret.type, cpp_type=cpp.return_type(ret).cpp_type())
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}), "
                f"does not match the number of outputs ({len(outputs)}).")
        differentiable_outputs: List[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation "
                "(version_counter won't get updated)")
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs

    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs))

    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
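# A hedged sketch of the uses_single_grad check used above: if any backward
# formula refers to the combined identifier 'grad' (as opposed to per-output
# 'grads[i]'), only the first differentiable output participates. The
# attribute names below (info.derivatives, .formula) follow this file's data
# model; the helper itself is illustrative:
import re

def uses_single_grad_sketch(info) -> bool:
    if info is None:
        return False
    # Whole-word match so 'grad' does not match 'grads[0]' or 'grad_input'.
    return any(re.search(r'(^|\W)grad($|\W)', d.formula) for d in info.derivatives)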
def compute_returns_yaml(f: NativeFunction) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    returns = []
    for i, r in enumerate(f.func.returns):
        # If we have an inplace function, the return argument is
        # implicitly named self.
        # TODO: Consider incorporating this into the data model
        if f.func.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            name = 'self'
        # If we are an out function, the name is the name of the
        # corresponding output argument (r.name will get recorded
        # in field_name later.)
        elif f.func.is_out_fn():
            name = f.func.out_arguments[i].name
        # If the return argument is explicitly named...
        elif r.name:
            name_conflict = any(r.name == a.name for a in f.func.schema_order_arguments())
            if name_conflict and not f.func.is_out_fn():
                name = f'{r.name}_return'
            else:
                name = r.name
        # If there is no explicit name, we just name the output result,
        # unless it's a multi-return, in which case it's result0,
        # result1, etc (zero-indexed)
        else:
            name = 'result' if len(f.func.returns) == 1 else f'result{i}'

        ret = {
            'dynamic_type': dynamic_type(r.type),
            'name': name,
            'type': cpp.return_type(r),
        }
        if r.name:
            # See Note [name and field_name]
            ret['field_name'] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.out_arguments[i].name] = r.name
        returns.append(ret)

    return returns, name_to_field_name
def compute_returns_yaml(f: NativeFunction) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # See the identically titled note on the previous definition of
    # compute_returns_yaml above for the lstsq.X example and the rationale
    # behind name_to_field_name.
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for i, (r, name) in enumerate(zip(f.func.returns, names)):
        ret = {
            'dynamic_type': dynamic_type(r.type),
            'name': name,
            'type': cpp.return_type(r).cpp_type(),
        }
        if r.name:
            # See Note [name and field_name]
            ret['field_name'] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.arguments.out[i].name] = r.name
        returns.append(ret)

    return returns, name_to_field_name
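# Worked example (illustrative): for the lstsq.X schema from
# Note [name and field_name], the result would look roughly like:
#
#   returns = [
#       {'dynamic_type': 'Tensor', 'name': 'X',  'type': 'at::Tensor &', 'field_name': 'solution'},
#       {'dynamic_type': 'Tensor', 'name': 'qr', 'type': 'at::Tensor &', 'field_name': 'QR'},
#   ]
#   name_to_field_name = {'X': 'solution', 'qr': 'QR'}
#
# The exact 'type'/'dynamic_type' strings depend on dynamic_type and
# cpp.return_type, so treat the values above as a sketch.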
def check_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
    # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
    # Note: 'f' (the NativeFunction being processed) comes from the
    # enclosing scope.
    stmts_before_call: List[str] = []
    stmts_after_call: List[str] = []

    if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
        return call

    # Check properties of inputs (enforce (1))
    for unpacked_binding in unpacked_bindings:
        arg = unpacked_binding.name
        noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
        if noref_cpp_type == BaseCType(tensorListT):
            stmts_before_call += [
                SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
            ]
            stmts_after_call += [
                ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
            ]
        elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
            stmts_before_call += [
                SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg),
            ]
            stmts_after_call += [
                ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg),
            ]
        elif noref_cpp_type == BaseCType(tensorT):
            stmts_before_call += [
                SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                SAVE_TENSOR_IMPL.substitute(tensor_name=arg),
            ]
            stmts_after_call += [
                ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg, out_tensor_name=arg),
                ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg),
            ]

    assert (stmts_before_call and stmts_after_call) or (not stmts_before_call and not stmts_after_call)

    # Check properties of outputs (enforce (2), (3))
    if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
        base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
        aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
        if aliased_arg_name is not None:
            aliased_arg_name = unpacked_name(aliased_arg_name)
        for i, (ret, ret_name) in enumerate(zip(f.func.returns, cpp.return_names(f))):
            noref_cpp_type = cpp.return_type(ret).remove_const_ref()
            if noref_cpp_type == BaseCType(tensorT):
                if aliased_arg_name is not None:
                    assert i == 0, \
                        f"Expect non-CompositeImplicitAutograd view function {base_name} to return a single output"
                    stmts_after_call += [
                        ENFORCE_SAME_TENSOR_STORAGE.substitute(
                            tensor_name=aliased_arg_name, out_tensor_name=ret_name)
                    ]
                else:
                    if type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT:
                        stmts_after_call += [
                            ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
                                tensor_name=ret_name, fn_name=type_wrapper_name(f))
                        ]
                    if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
                        stmts_after_call += [
                            ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
                                tensor_name=ret_name, fn_name=type_wrapper_name(f))
                        ]
            # Currently we don't have any functions that return the following types, but
            # we should update the checks once we do
            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")
            elif noref_cpp_type == BaseCType(tensorListT):
                raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")

    if stmts_before_call and stmts_after_call:
        call = (RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) +
                call +
                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call))
    return call
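# For context, the SAVE_*/ENFORCE_* names above are CodeTemplates that emit
# C++ debug checks around the dispatcher call. A hedged sketch of what one
# such save/enforce pair might look like; the real templates live alongside
# this function, and the exact C++ details here are illustrative:
from torchgen.code_template import CodeTemplate  # import path varies across PyTorch versions

SAVE_TENSOR_STORAGE_SKETCH = CodeTemplate("""\
c10::optional<Storage> ${tensor_name}_storage_saved =
  ${tensor_name}.has_storage() ? c10::optional<Storage>(${tensor_name}.storage()) : c10::nullopt;
""")

ENFORCE_SAME_TENSOR_STORAGE_SKETCH = CodeTemplate("""\
if (${tensor_name}_storage_saved.has_value())
  AT_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage()));
""")

# e.g. SAVE_TENSOR_STORAGE_SKETCH.substitute(tensor_name='self') yields the
# C++ statement that snapshots self's Storage before the call, and the
# matching ENFORCE template asserts the snapshot still aliases the output's
# storage afterwards.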