Example #1
    def test_non_differentiable_output_invalid_type(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
                                              func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    }
                },
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        assert "grad_z = grads[2]" not in definition
        assert "grad_z = grads[1]" in definition
Example #2
    def test_non_differentiable_output_invalid_type(self) -> None:
        specification = 'func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)'
        schema = tools.codegen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
                                              func=schema)

        differentiability_info = load_derivatives.create_differentiability_info(
            defn={
                'name': specification,
                'a': 'grad_x',
                'b': 'grad_z',
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info, gen_autograd_functions.FUNCTION_DEFINITION)
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        assert 'grad_z = grads[2]' not in definition
        assert 'grad_z = grads[1]' in definition
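
Example #2 is the same test against the older pre-torchgen codegen: the model lives under tools.codegen.model, create_differentiability_info takes a flat defn dict (no per-dispatch-key "dispatch" section and no used_dispatch_keys argument), and it returns the info directly rather than as part of a tuple. Under that same legacy harness, a complementary check (a sketch, not part of the original test) is that grad_x keeps index 0, since output x precedes the non-differentiable bool:

        # x is output 0 and comes before the skipped bool output, so its
        # gradient index is unaffected by the remapping.
        assert 'grad_x = grads[0]' in definition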
Example #3
    def test_non_differentiable_output_output_differentiability(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
                                              func=schema)

        differentiability_info = load_derivatives.create_differentiability_info(
            defn={
                "name": specification,
                "a": "grad_x",
                "b": "grad_z",
                "output_differentiability": [True, False, True],
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info, gen_autograd_functions.FUNCTION_DEFINITION)
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        assert "grad_z = grads[2]" not in definition
        assert "grad_z = grads[1]" in definition