def test_choose_params_and_methods_raises(self, argnum):
    """Test that the ``_choose_params_with_methods`` helper method raises an
    error if incorrect trainable parameters are specified."""
    tape = JacobianTape()
    tape.trainable_params = [0]
    diff_methods = ["F"]

    # ``argnum`` comes from the test parametrization (not visible here); each
    # value is expected to be rejected for this single-parameter tape.
    with pytest.raises(ValueError, match="Incorrect trainable parameters"):
        tape._choose_params_with_methods(diff_methods, argnum)
def test_choose_params_and_methods_warns_no_params(self):
    """Test that the ``_choose_params_with_methods`` helper method warns if an
    empty list was passed as ``argnum``."""
    tape = JacobianTape()
    tape.trainable_params = [0]
    diff_methods = ["F"]
    # An explicitly empty selection of parameters should trigger the warning.
    argnum = []

    with pytest.warns(UserWarning, match="No trainable parameters"):
        tape._choose_params_with_methods(diff_methods, argnum)
def test_choose_params_and_methods(self, diff_methods, argnum):
    """Test that the ``_choose_params_with_methods`` helper method returns the
    expected ``(parameter index, differentiation method)`` pairs for the
    given ``argnum`` selection."""
    num_all_params = len(diff_methods)

    tape = JacobianTape()
    # NOTE(review): unlike the sibling tests, this sets the private attribute
    # directly rather than the public ``trainable_params`` setter; presumably
    # to avoid setter validation on an empty tape — confirm.
    tape._trainable_params = list(range(num_all_params))

    res = list(tape._choose_params_with_methods(diff_methods, argnum))

    # Each entry unpacks as (index, method): the index must be a valid
    # parameter index and the method must come from ``diff_methods``.
    assert all(k in range(num_all_params) for k, _ in res)
    assert all(v in diff_methods for _, v in res)

    # Expected count: ``None`` selects every parameter, an int selects exactly
    # one, and any other value is treated as a sequence of indices.
    if argnum is None:
        num_params = num_all_params
    elif isinstance(argnum, int):
        num_params = 1
    else:
        num_params = len(argnum)

    assert len(res) == num_params