Example #1
    def test_torch(self):
        """Torch tensors will simply return their requires_grad attribute"""
        t = torch.tensor([1.0, 2.0], requires_grad=True)
        assert fn.requires_grad(t)

        t = torch.tensor([1.0, 2.0], requires_grad=False)
        assert not fn.requires_grad(t)
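The fn module is not defined in these snippets; for torch tensors it simply mirrors the tensor's own attribute, as the docstring says. A minimal standalone sketch of that attribute, assuming only PyTorch is installed:

    import torch

    t = torch.tensor([1.0, 2.0], requires_grad=True)
    print(t.requires_grad)   # True

    t = t.detach()           # detaching drops the gradient requirement
    print(t.requires_grad)   # False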
Example #2
    def test_autograd(self):
        """Autograd arrays will simply return their requires_grad attribute"""
        t = np.array([1.0, 2.0], requires_grad=True)
        assert fn.requires_grad(t)

        t = np.array([1.0, 2.0], requires_grad=False)
        assert not fn.requires_grad(t)
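Note that the np used above cannot be vanilla NumPy, since plain np.array rejects a requires_grad keyword (compare Example #5); it is presumably an autograd-wrapping module such as pennylane.numpy. A minimal sketch under that assumption:

    # Assumption: np above refers to an autograd-wrapped NumPy, e.g. pennylane.numpy
    from pennylane import numpy as pnp

    t = pnp.array([1.0, 2.0], requires_grad=True)
    print(t.requires_grad)   # True

    t = pnp.array([1.0, 2.0], requires_grad=False)
    print(t.requires_grad)   # False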
Example #3
    def test_tf(self):
        """TensorFlow tensors will True *if* they are being watched by a gradient tape"""
        t1 = tf.Variable([1.0, 2.0])
        t2 = tf.constant([1.0, 2.0])
        assert not fn.requires_grad(t1)
        assert not fn.requires_grad(t2)

        with tf.GradientTape():
            # variables are automatically watched within a context,
            # but constants are not
            assert fn.requires_grad(t1)
            assert not fn.requires_grad(t2)

        with tf.GradientTape() as tape:
            # watching makes all tensors trainable
            tape.watch([t1, t2])
            assert fn.requires_grad(t1)
            assert fn.requires_grad(t2)
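A hedged standalone sketch of the TensorFlow behaviour this test relies on: inside a GradientTape, variables are watched automatically, while constants only contribute gradients after an explicit tape.watch:

    import tensorflow as tf

    v = tf.Variable([1.0, 2.0])
    c = tf.constant([1.0, 2.0])

    with tf.GradientTape() as tape:
        tape.watch(c)                    # constants must be watched explicitly
        loss = tf.reduce_sum(v * c)      # variables are watched automatically

    grads = tape.gradient(loss, [v, c])
    print([g is not None for g in grads])   # [True, True]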
Example #4
    def test_jax(self, t):
        """jax.DeviceArrays will always return True"""
        assert fn.requires_grad(t)
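For context (not part of the snippet): JAX attaches no requires_grad flag to its arrays; whether an array is differentiated is decided at the call site by transformations such as jax.grad, so every array is treated as potentially trainable. A minimal sketch:

    import jax
    import jax.numpy as jnp

    def loss(x):
        return jnp.sum(x ** 2)

    x = jnp.array([1.0, 2.0])
    print(jax.grad(loss)(x))   # [2. 4.]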
Example #5
    def test_numpy(self, t):
        """Vanilla NumPy arrays, sequences, and lists will always return False"""
        assert not fn.requires_grad(t)
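The parametrization supplying t is not shown above. A standalone sketch of the underlying fact, namely that vanilla NumPy has no notion of gradient tracking:

    import numpy as np

    arr = np.array([1.0, 2.0])
    print(hasattr(arr, "requires_grad"))   # False

    try:
        np.array([1.0, 2.0], requires_grad=True)
    except TypeError:
        print("requires_grad is not a keyword vanilla np.array accepts")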