Example #1
    def test_indexed_grads(self) -> None:
        schema = tools.codegen.model.FunctionSchema.parse(
            'func(Tensor a, Tensor b) -> (Tensor x, Tensor y)')
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
                                              func=schema)

        derivative = load_derivatives.create_derivative(
            native_function,
            formula='func_backward(grads[0], grads[1])',
            var_names=(),
            available_named_gradients=['grad_x', 'grad_y'])
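        # The formula refers to the output gradients positionally via
        # grads[0]/grads[1], so no named gradients should be recorded.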
        self.assertSetEqual(derivative.named_gradients, set())
Example #2
    def test_named_grads(self) -> None:
        schema = torchgen.model.FunctionSchema.parse(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)")
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
                                              func=schema)

        derivative = load_derivatives.create_derivative(
            native_function,
            formula="func_backward(grad_x, grad_y)",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
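        # The formula refers to the output gradients by name, so both
        # grad_x and grad_y should be recorded as named gradients.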
        self.assertSetEqual(derivative.named_gradients, {"grad_x", "grad_y"})