def test_indexed_grads(self) -> None:
    """Check that index-based gradient references yield no named gradients.

    The formula refers to the output gradients only through ``grads[0]`` and
    ``grads[1]``. Even though ``grad_x`` and ``grad_y`` are offered as
    available named gradients, the resulting derivative should record an
    empty ``named_gradients`` set.
    """
    # Use the same ``torchgen.model`` path as the sibling tests; the old
    # ``tools.codegen`` package path is the pre-rename location.
    schema = torchgen.model.FunctionSchema.parse(
        "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
    )
    native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

    derivative = load_derivatives.create_derivative(
        native_function,
        formula="func_backward(grads[0], grads[1])",
        var_names=(),
        available_named_gradients=["grad_x", "grad_y"],
    )
    self.assertSetEqual(derivative.named_gradients, set())
def test_named_grads(self) -> None:
    """Check that named gradients used in a formula are collected.

    The formula references ``grad_x`` and ``grad_y`` directly by name. Both
    are in the list of available named gradients, so both should appear in
    the derivative's ``named_gradients`` set.
    """
    parsed_schema = torchgen.model.FunctionSchema.parse(
        "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
    )
    fn = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=parsed_schema)

    # Every gradient named in the formula is expected to be picked up.
    expected_names = {"grad_x", "grad_y"}

    result = load_derivatives.create_derivative(
        fn,
        formula="func_backward(grad_x, grad_y)",
        var_names=(),
        available_named_gradients=["grad_x", "grad_y"],
    )
    self.assertSetEqual(result.named_gradients, expected_names)