def test_torch(self):
    """For a torch tensor, the result mirrors the tensor's own ``requires_grad`` flag."""
    # Build one tensor with gradients enabled and one without, and check
    # that ``fn.requires_grad`` agrees with the flag in both cases.
    for flag in (True, False):
        tensor = torch.tensor([1.0, 2.0], requires_grad=flag)
        assert bool(fn.requires_grad(tensor)) == flag
def test_autograd(self):
    """For an autograd array, the result mirrors the array's own ``requires_grad`` flag."""
    # ``np`` here is the autograd-aware numpy whose ``array`` accepts a
    # ``requires_grad`` keyword; check both settings of that keyword.
    for flag in (True, False):
        arr = np.array([1.0, 2.0], requires_grad=flag)
        assert bool(fn.requires_grad(arr)) == flag
def test_tf(self):
    """TensorFlow tensors will return True *if* they are being watched by a gradient tape.

    Outside of any tape nothing is watched. Inside a ``tf.GradientTape``
    context, variables are watched automatically while constants are not.
    Explicitly calling ``tape.watch`` on a tensor makes the tape watch it too.
    """
    t1 = tf.Variable([1.0, 2.0])
    t2 = tf.constant([1.0, 2.0])

    # No tape is active, so neither tensor is being watched.
    assert not fn.requires_grad(t1)
    assert not fn.requires_grad(t2)

    with tf.GradientTape():
        # variables are automatically watched within a context,
        # but constants are not
        assert fn.requires_grad(t1)
        assert not fn.requires_grad(t2)

    with tf.GradientTape() as tape:
        # Explicitly watching tells the tape to record operations on these
        # tensors, so the constant is now watched as well as the variable.
        tape.watch([t1, t2])
        assert fn.requires_grad(t1)
        assert fn.requires_grad(t2)
def test_jax(self, t):
    """A JAX array is always reported as requiring gradients."""
    # NOTE(review): ``t`` is presumably provided by a parametrize decorator
    # outside this view; confirm against the surrounding test class.
    outcome = fn.requires_grad(t)
    assert outcome
def test_numpy(self, t):
    """Plain NumPy arrays, sequences and lists are never reported as requiring gradients."""
    # NOTE(review): ``t`` is presumably provided by a parametrize decorator
    # outside this view; confirm against the surrounding test class.
    outcome = fn.requires_grad(t)
    assert not outcome