Ejemplo n.º 1
0
    def test_sum_axis_keepdims(self, t1):
        """Check that reducing over ``axis=(0, 2)`` with ``keepdims=True``
        sums along those axes while retaining them as size-one dimensions."""
        total = fn.sum_(t1, axis=(0, 2), keepdims=True)

        # Results exposing ``numpy()`` (e.g. TensorFlow or PyTorch tensors)
        # are converted before the comparisons below.
        total = total.numpy() if hasattr(total, "numpy") else total

        expected = np.array([[[14], [6], [3]]])
        assert fn.allclose(total, expected)
        # allclose alone could pass via broadcasting, so pin the shape too.
        assert total.shape == (1, 3, 1)
Ejemplo n.º 2
0
    def test_sum_axis(self, t1):
        """Check that reducing over ``axis=(0, 2)`` without ``keepdims``
        sums along those axes and yields a one-dimensional result."""
        summed = fn.sum_(t1, axis=(0, 2))

        # Results exposing ``numpy()`` (e.g. TensorFlow or PyTorch tensors)
        # are converted before the comparisons below.
        summed = summed.numpy() if hasattr(summed, "numpy") else summed

        expected = np.array([14, 6, 3])
        assert fn.allclose(summed, expected)
        # The summed axes must be removed, leaving only the middle axis.
        assert summed.shape == (3,)
Ejemplo n.º 3
0
 def test_jax(self):
     """Check that ``fn.sum_`` without ``axis`` reduces a 2x3 JAX array
     to the total of all its entries."""
     values = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
     expected_total = 2.1  # 0.1 + 0.2 + 0.3 + 0.4 + 0.5 + 0.6

     total = fn.sum_(values)

     assert fn.allclose(total, expected_total)
Ejemplo n.º 4
0
 def test_torch(self):
     """Check that ``fn.sum_`` without ``axis`` reduces a 2x3 torch tensor
     to the total of all its entries and still returns a ``torch.Tensor``."""
     values = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
     expected_total = 2.1  # 0.1 + 0.2 + 0.3 + 0.4 + 0.5 + 0.6

     total = fn.sum_(values)

     # The result must stay a torch tensor rather than a plain Python float.
     assert isinstance(total, torch.Tensor)
     assert fn.allclose(total, expected_total)