Beispiel #1
0
 def test_kwargs_are_considered(self):
     """Check that keyword arguments influence the independence check.

     ``f`` returns a constant unless ``kw`` is truthy, in which case it is
     linear in ``x``. Its Jacobian with respect to ``x`` is then the constant
     ``0.1`` and must be reported as independent even when ``kw=True``.
     """

     def f(x, kw=False):
         if kw:
             return 0.1 * x
         return 0.2

     jacobian_of_f = jax.jacobian(f, argnums=0)
     args = (0.2,)

     assert is_independent(f, self.interface, args)
     assert not is_independent(f, self.interface, args, {"kw": True})
     assert is_independent(jacobian_of_f, self.interface, args, {"kw": True})
 def test_kwargs_are_considered(self):
     """Check that keyword arguments influence the independence check.

     With ``kw=True`` the function is linear in ``x`` (dependent), but its
     Jacobian is the constant ``0.1`` and so must be reported as independent.
     """

     def f(x, kw=False):
         if kw:
             return 0.1 * x
         return 0.2

     def jac(x, kw):
         # Close over ``kw`` so torch differentiates with respect to ``x`` only.
         return torch.autograd.functional.jacobian(lambda inp: f(inp, kw), x)

     args = (torch.tensor(0.2),)

     assert is_independent(f, self.interface, args)
     assert not is_independent(f, self.interface, args, {"kw": True})
     assert is_independent(jac, self.interface, args, {"kw": True})
Beispiel #3
0
 def test_dependent(self, func, args):
     """Tests that a dependent function is correctly detected as such.

     The raw arguments are wrapped in ``tf.Variable``. If ``func`` does not
     return a ``tf.Tensor``, its output is wrapped in a ``tf.Variable`` before
     checking. If that wrapped function cannot be checked at all, the case is
     skipped as having a TF-incompatible output format.
     """
     variables = tuple(tf.Variable(arg) for arg in args)
     out = func(*variables)
     if isinstance(out, tf.Tensor):
         assert not is_independent(func, self.interface, variables)
         return

     # Filter out functions with TF-incompatible output format.
     def _func(*inputs):
         return tf.Variable(func(*inputs))

     try:
         independent = is_independent(_func, self.interface, variables)
     except Exception:  # pylint: disable=broad-except
         pytest.skip("Function output cannot be converted to a tf.Variable.")
     else:
         # Assert outside the ``try`` so a genuine failure is reported
         # instead of being swallowed and turned into a skip.
         assert not independent
Beispiel #4
0
        def test_kwargs_are_considered(self):
            """Check that keyword arguments influence the independence check.

            ``f`` is the constant ``tf.constant(0.2)`` for ``kw=False`` and
            linear in ``x`` for ``kw=True``. In the latter case its Jacobian
            is constant and must be reported as independent.
            """

            def f(x, kw=False):
                if kw:
                    return 0.1 * x
                return tf.constant(0.2)

            def jacobian_fn(x, kw):
                # Record the forward pass so the tape can differentiate it.
                with tf.GradientTape() as tape:
                    output = f(x, kw)
                return tape.jacobian(output, x)

            args = (tf.Variable(0.2),)

            assert is_independent(f, self.interface, args)
            assert not is_independent(f, self.interface, args, {"kw": True})
            assert is_independent(jacobian_fn, self.interface, args, {"kw": True})
Beispiel #5
0
 def test_unknown_interface(self):
     """Requesting an unsupported interface name must raise ``ValueError``."""

     def identity(x):
         return x

     with pytest.raises(ValueError, match="Unknown interface: hello"):
         is_independent(identity, "hello", (0.1,))
Beispiel #6
0
 def test_overlooked_dependence(self, func, args):
     """Check functions whose input dependence the heuristic is known to miss.

     The assertion states the *correct* classification (dependent). Because
     this test targets dependence that ``is_independent`` overlooks, it is
     presumably marked as an expected failure by a decorator outside this
     view -- TODO confirm.
     """
     found_independent = is_independent(func, self.interface, args)
     assert not found_independent
Beispiel #7
0
 def test_dependent(self, func, args, exp_fail):
     """Tests that a dependent function is correctly detected as such.

     Cases flagged with ``exp_fail`` are known blind spots. For them the
     check is expected to (wrongly) report independence, and the test pins
     that outcome.
     """
     reported_independent = is_independent(func, self.interface, args)
     if not exp_fail:
         assert not reported_independent
     else:
         assert reported_independent
Beispiel #8
0
 def test_independent(self, func, args):
     """An input-independent function must be reported as independent."""
     reported_independent = is_independent(func, self.interface, args)
     assert reported_independent
Beispiel #9
0
 def test_independent(self, func, args):
     """An input-independent function must be reported as independent.

     Each raw argument is wrapped in a ``tf.Variable`` before the check.
     """
     tf_args = tuple(tf.Variable(value) for value in args)
     assert is_independent(func, self.interface, tf_args)