def test_warning_torch_and_autograd(self):
    """Combining a Torch tensor with an Autograd (``np``) array should emit a
    ``UserWarning`` recommending vanilla NumPy in place of Autograd."""
    torch_tensor = torch.tensor([1.0, 2.0, 3.0])
    autograd_array = np.array([0.5, 0.1])
    expected_msg = "Consider replacing Autograd with vanilla NumPy"

    # Torch comes first in the sequence, exactly as in the original ordering.
    with pytest.warns(UserWarning, match=expected_msg):
        fn._get_multi_tensorbox([torch_tensor, autograd_array])
def test_exception_tensorflow_and_torch(self):
    """Passing a sequence that mixes TensorFlow, ``onp`` and Torch objects must
    raise a ``ValueError``, because those dispatch libraries cannot be combined."""
    # Elements are built left to right: tf.Variable, then onp array, then torch tensor.
    mixed_tensors = [
        tf.Variable([1.0, 2.0, 3.0]),
        onp.array([0.5, 0.1]),
        torch.tensor([0.6]),
    ]

    with pytest.raises(ValueError, match="Tensors contain mixed types"):
        fn._get_multi_tensorbox(mixed_tensors)
def test_return_numpy_box(self):
    """A plain Python list followed by an ``onp`` array should dispatch to the
    ``"numpy"`` interface."""
    plain_list = [0.5, 0.1]
    numpy_array = onp.array([1.0, 2.0, 3.0])

    box = fn._get_multi_tensorbox([plain_list, numpy_array])
    assert box.interface == "numpy"
def test_return_torch_box(self):
    """When an ``onp`` array is followed by a Torch tensor, the resulting box
    should report the ``"torch"`` interface."""
    sequence = [onp.array([0.5, 0.1]), torch.tensor([1.0, 2.0, 3.0])]

    box = fn._get_multi_tensorbox(sequence)
    assert box.interface == "torch"
def test_return_tensorflow_box(self):
    """When an ``onp`` array is followed by a TensorFlow variable, the resulting
    box should report the ``"tf"`` interface."""
    tf_variable = tf.Variable([1.0, 2.0, 3.0])
    numpy_array = onp.array([0.5, 0.1])

    interface = fn._get_multi_tensorbox([numpy_array, tf_variable]).interface
    assert interface == "tf"