def test_warning_torch_and_autograd(self):
        """Check that mixing Torch and Autograd tensors emits a ``UserWarning``.

        A Torch tensor and an array built through ``np`` are passed together to
        ``fn._get_multi_tensorbox``; the call must warn with a message
        suggesting that Autograd be replaced with vanilla NumPy.
        """
        # NOTE(review): ``np`` is presumably the Autograd-wrapped NumPy (as
        # opposed to ``onp`` used by the sibling tests) -- confirm against the
        # file's imports.
        torch_tensor = torch.tensor([1.0, 2.0, 3.0])
        autograd_tensor = np.array([0.5, 0.1])
        mixed_tensors = [torch_tensor, autograd_tensor]

        expected_message = "Consider replacing Autograd with vanilla NumPy"
        with pytest.warns(UserWarning, match=expected_message):
            fn._get_multi_tensorbox(mixed_tensors)
Exemple #2
0
    def test_exception_tensorflow_and_torch(self):
        """Check that combining TensorFlow and Torch tensors raises ``ValueError``.

        The sequence handed to ``fn._get_multi_tensorbox`` holds a TensorFlow
        variable, an ``onp`` array and a Torch tensor; the call must fail with
        a "Tensors contain mixed types" error.
        """
        incompatible_tensors = [
            tf.Variable([1.0, 2.0, 3.0]),
            onp.array([0.5, 0.1]),
            torch.tensor([0.6]),
        ]

        with pytest.raises(ValueError, match="Tensors contain mixed types"):
            fn._get_multi_tensorbox(incompatible_tensors)
Exemple #3
0
    def test_return_numpy_box(self):
        """Check that a plain list plus an ``onp`` array dispatches to NumPy.

        The box returned by ``fn._get_multi_tensorbox`` must report
        ``"numpy"`` as its interface.
        """
        plain_list = [0.5, 0.1]
        array = onp.array([1.0, 2.0, 3.0])

        box = fn._get_multi_tensorbox([plain_list, array])
        interface = box.interface
        assert interface == "numpy"
Exemple #4
0
    def test_return_torch_box(self):
        """Check that an ``onp`` array plus a Torch tensor dispatches to Torch.

        The box returned by ``fn._get_multi_tensorbox`` must report
        ``"torch"`` as its interface, even though the Torch tensor is not the
        first element of the sequence.
        """
        array = onp.array([0.5, 0.1])
        torch_tensor = torch.tensor([1.0, 2.0, 3.0])

        box = fn._get_multi_tensorbox([array, torch_tensor])
        interface = box.interface
        assert interface == "torch"
Exemple #5
0
    def test_return_tensorflow_box(self):
        """Check that an ``onp`` array plus a TensorFlow variable dispatches to TensorFlow.

        The box returned by ``fn._get_multi_tensorbox`` must report ``"tf"``
        as its interface, even though the TensorFlow variable is not the first
        element of the sequence.
        """
        array = onp.array([0.5, 0.1])
        tf_variable = tf.Variable([1.0, 2.0, 3.0])

        box = fn._get_multi_tensorbox([array, tf_variable])
        interface = box.interface
        assert interface == "tf"