def test_choose_params_and_methods_raises(self, argnum):
    """Test that the ``_choose_params_with_methods`` helper raises a
    ``ValueError`` when ``argnum`` specifies incorrect trainable parameters.

    Args:
        argnum: value supplied to the test from outside this view
            (presumably a parametrization of indices that are not
            trainable on the tape -- confirm against the decorator).
    """
    tape = JacobianTape()
    # Only parameter index 0 is marked trainable on this tape.
    tape.trainable_params = [0]
    diff_methods = ["F"]
    # The return value is irrelevant here; only the raised error matters.
    with pytest.raises(ValueError, match="Incorrect trainable parameters"):
        tape._choose_params_with_methods(diff_methods, argnum)
Example #2
0
 def test_choose_params_and_methods_warns_no_params(self):
     """Test that the ``_choose_params_with_methods`` helper emits a
     ``UserWarning`` when an empty list is passed as ``argnum``."""
     tape = JacobianTape()
     # Only parameter index 0 is marked trainable on this tape.
     tape.trainable_params = [0]
     diff_methods = ["F"]
     # An empty selection of parameters should trigger the warning.
     argnum = []
     # The return value is irrelevant here; only the emitted warning matters.
     with pytest.warns(UserWarning, match="No trainable parameters"):
         tape._choose_params_with_methods(diff_methods, argnum)
Example #3
0
    def test_choose_params_and_methods(self, diff_methods, argnum):
        """Check the ``(index, method)`` pairs from ``_choose_params_with_methods``.

        Every index must lie in ``range(len(diff_methods))`` and every method
        must come from ``diff_methods``. The number of pairs must match the
        selection given by ``argnum``. ``None`` selects all parameters, an
        ``int`` selects one, and any other sized value selects ``len(argnum)``.
        """
        tape = JacobianTape()
        # NOTE(review): the private attribute is assigned directly, presumably
        # to bypass validation in the public ``trainable_params`` setter on an
        # otherwise empty tape -- confirm.
        tape._trainable_params = list(range(len(diff_methods)))
        pairs = list(tape._choose_params_with_methods(diff_methods, argnum))

        total = len(diff_methods)
        valid_indices = range(total)

        assert all(idx in valid_indices for idx, _ in pairs)
        assert all(method in diff_methods for _, method in pairs)

        if argnum is None:
            expected = total
        else:
            expected = 1 if isinstance(argnum, int) else len(argnum)

        assert len(pairs) == expected