コード例 #1
0
    def test_jax(self):
        """Test that a jax array is automatically converted into
        a diagonal tensor.

        Covers the main diagonal (default offset) and the ``k=1`` offset,
        comparing each against the matrix ``onp.diag`` builds from the same
        values. Also checks that the result is still a jax array, so that a
        silent conversion to another array type cannot pass unnoticed.
        """
        t = jnp.array([0.1, 0.2, 0.3])
        res = fn.diag(t)
        # A value-only comparison would also accept a plain NumPy result,
        # so pin the returned type explicitly.
        assert isinstance(res, jnp.ndarray)
        assert fn.allclose(res, onp.diag([0.1, 0.2, 0.3]))

        res = fn.diag(t, k=1)
        assert isinstance(res, jnp.ndarray)
        assert fn.allclose(res, onp.diag([0.1, 0.2, 0.3], k=1))
コード例 #2
0
    def test_torch(self):
        """Check that ``fn.diag`` turns a 1-D torch tensor into a diagonal
        matrix that is itself a torch tensor.

        The main-diagonal result is type-checked and compared with
        ``onp.diag``; the ``k=1`` result is compared by value only.
        """
        values = [0.1, 0.2, 0.3]
        tensor = torch.tensor(values)

        # Default offset: values land on the main diagonal.
        main_diag = fn.diag(tensor)
        assert isinstance(main_diag, torch.Tensor)
        assert fn.allclose(main_diag, onp.diag(values))

        # Offset of one: same values, shifted diagonal.
        shifted_diag = fn.diag(tensor, k=1)
        assert fn.allclose(shifted_diag, onp.diag(values, k=1))
コード例 #3
0
    def test_array(self):
        """Check that ``fn.diag`` turns a 1-D array created with ``np.array``
        into a diagonal matrix that is an ``np.ndarray``.

        The main-diagonal result is type-checked and compared with
        ``onp.diag``; the ``k=1`` result is compared by value only.
        """
        values = [0.1, 0.2, 0.3]
        array = np.array(values)

        # Default offset: values land on the main diagonal.
        main_diag = fn.diag(array)
        assert isinstance(main_diag, np.ndarray)
        assert fn.allclose(main_diag, onp.diag(values))

        # Offset of one: same values, shifted diagonal.
        shifted_diag = fn.diag(array, k=1)
        assert fn.allclose(shifted_diag, onp.diag(values, k=1))
コード例 #4
0
    def test_tensorflow(self):
        """Check that ``fn.diag`` turns a 1-D ``tf.Variable`` into a diagonal
        matrix that is a ``tf.Tensor``.

        The main-diagonal result is type-checked and compared with
        ``onp.diag``; the ``k=1`` result is compared by value only.
        """
        values = [0.1, 0.2, 0.3]
        variable = tf.Variable(values)

        # Default offset: values land on the main diagonal.
        main_diag = fn.diag(variable)
        assert isinstance(main_diag, tf.Tensor)
        assert fn.allclose(main_diag, onp.diag(values))

        # Offset of one: same values, shifted diagonal.
        shifted_diag = fn.diag(variable, k=1)
        assert fn.allclose(shifted_diag, onp.diag(values, k=1))
コード例 #5
0
 def test_sequence(self, a, interface):
     """Check that ``fn.diag`` accepts a plain Python list and returns a
     diagonal tensor in the interface named by ``interface``.

     The list mixes two floats with the supplied ``a``. The reference
     matrix uses 0.5 for the last entry, so ``a`` must hold that value
     for the comparison to succeed.
     """
     entries = [0.1, 0.2, a]
     diagonal = fn.diag(entries)
     reference = onp.diag([0.1, 0.2, 0.5])

     assert fn.get_interface(diagonal) == interface
     assert fn.allclose(diagonal, reference)