def test_kwargs_are_considered(self):
    """Check that keyword arguments are honoured by the independence check.

    With the default ``kw=False`` the function returns a constant and is
    independent of ``x``; with ``kw=True`` it scales ``x`` and becomes
    dependent. Its Jacobian w.r.t. ``x`` under ``kw=True`` is the constant
    ``0.1`` and must therefore be reported as independent again.
    """

    def f(x, kw=False):
        # Only depends on ``x`` when ``kw`` is truthy.
        return 0.1 * x if kw else 0.2

    jac = jax.jacobian(f, argnums=0)
    args = (0.2,)

    assert is_independent(f, self.interface, args)
    assert not is_independent(f, self.interface, args, {"kw": True})
    assert is_independent(jac, self.interface, args, {"kw": True})
def test_kwargs_are_considered(self):
    """Check that keyword arguments are honoured by the independence check.

    ``f`` is constant for the default ``kw=False`` and linear in ``x`` for
    ``kw=True``; the Jacobian of the latter is constant and hence independent.
    """

    def f(x, kw=False):
        # Only depends on ``x`` when ``kw`` is truthy.
        return 0.1 * x if kw else 0.2

    def jac(x, kw):
        # The functional Jacobian differentiates w.r.t. the callable's inputs,
        # so ``kw`` is captured by closure rather than passed as an input.
        return torch.autograd.functional.jacobian(lambda inp: f(inp, kw), x)

    args = (torch.tensor(0.2),)

    assert is_independent(f, self.interface, args)
    assert not is_independent(f, self.interface, args, {"kw": True})
    assert is_independent(jac, self.interface, args, {"kw": True})
def test_dependent(self, func, args):
    """Check that a dependent function is detected as such.

    The arguments are wrapped as ``tf.Variable`` objects. Functions whose
    output is not a ``tf.Tensor`` are wrapped so that their output is turned
    into a ``tf.Variable``. If that conversion, or the independence check on
    the wrapped function, raises, the case is skipped as TF-incompatible.
    """
    args = tuple(tf.Variable(_arg) for _arg in args)

    out = func(*args)
    if isinstance(out, tf.Tensor):
        assert not is_independent(func, self.interface, args)
        return

    # Filter out functions with TF-incompatible output format.
    def _func(*inner_args):
        return tf.Variable(func(*inner_args))

    # Only errors raised while converting/checking may turn into a skip. The
    # assertion stays outside the ``try`` so that a genuine failure is reported
    # as a failure instead of being swallowed (a bare ``except:`` around the
    # assert would also catch ``AssertionError``).
    try:
        independent = is_independent(_func, self.interface, args)
    except Exception:
        pytest.skip("Function output cannot be converted to a TF-compatible format.")
    assert not independent
def test_kwargs_are_considered(self):
    """Check that keyword arguments are honoured by the independence check.

    ``f`` returns a constant tensor for the default ``kw=False`` and scales
    ``x`` for ``kw=True``. The Jacobian of the scaled version is constant
    and therefore independent of the input.
    """

    def f(x, kw=False):
        # Only depends on ``x`` when ``kw`` is truthy.
        return 0.1 * x if kw else tf.constant(0.2)

    def _jac(x, kw):
        # ``x`` is a ``tf.Variable``, so the tape records operations on it.
        with tf.GradientTape() as tape:
            out = f(x, kw)
        return tape.jacobian(out, x)

    args = (tf.Variable(0.2),)

    assert is_independent(f, self.interface, args)
    assert not is_independent(f, self.interface, args, {"kw": True})
    assert is_independent(_jac, self.interface, args, {"kw": True})
def test_unknown_interface(self):
    """Requesting an unsupported interface name raises a ``ValueError``."""
    identity = lambda x: x
    bogus_interface = "hello"
    with pytest.raises(ValueError, match="Unknown interface: hello"):
        is_independent(identity, bogus_interface, (0.1,))
def test_overlooked_dependence(self, func, args):
    """Check that functions whose input dependence is easy to miss are still
    reported as dependent (i.e. ``is_independent`` returns a falsy value)."""
    reported_independent = is_independent(func, self.interface, args)
    assert not reported_independent
def test_dependent(self, func, args, exp_fail):
    """Check how dependent functions are classified.

    For parametrizations with a falsy ``exp_fail`` the function must be
    reported as dependent. For a truthy ``exp_fail`` the test instead asserts
    that it is reported as independent, which pins an expected misdetection.
    """
    reported_independent = is_independent(func, self.interface, args)
    # ``exp_fail`` flips the expectation for the known misdetection cases.
    assert bool(reported_independent) == bool(exp_fail)
def test_independent(self, func, args):
    """Check that a function not depending on its input is reported as independent."""
    reported_independent = is_independent(func, self.interface, args)
    assert reported_independent
def test_independent(self, func, args):
    """Check that an independent function is reported as such when its
    arguments are supplied as TensorFlow variables."""
    tf_args = tuple(map(tf.Variable, args))
    assert is_independent(func, self.interface, tf_args)