コード例 #1
0
    def test_sum_axis_keepdims(self, t1):
        """Check that summing over a tuple of axes with ``keepdims=True``
        reduces those axes to length 1 instead of dropping them.

        ``t1`` is reduced over axes 0 and 2; the expected per-slice sums
        along the remaining axis are 14, 6 and 3.
        """
        summed = fn.sum(t1, axis=(0, 2), keepdims=True)

        # Objects exposing a ``.numpy()`` method (TensorFlow / PyTorch tensors,
        # per the original note) are converted so the checks below operate on
        # a NumPy array.
        summed = summed.numpy() if hasattr(summed, "numpy") else summed

        expected = np.array([[[14], [6], [3]]])
        assert fn.allclose(summed, expected)
        assert summed.shape == (1, 3, 1)
コード例 #2
0
    def test_sum_axis(self, t1):
        """Check that summing over a tuple of axes (without ``keepdims``)
        removes those axes, leaving a 1-D result of per-slice sums.
        """
        reduced = fn.sum(t1, axis=(0, 2))

        # Convert to NumPy when the result offers ``.numpy()`` (TensorFlow /
        # PyTorch tensors, per the original note); otherwise keep it as-is.
        reduced = reduced.numpy() if hasattr(reduced, "numpy") else reduced

        expected = np.array([14, 6, 3])
        assert fn.allclose(reduced, expected)
        assert reduced.shape == (3,)
コード例 #3
0
 def cost_fn(t):
     """Return the total of the entries of ``t`` gathered at ``indices`` along axis 1.

     ``indices`` is not a parameter; it is resolved from the enclosing scope.
     """
     gathered = fn.take(t, indices, axis=1)
     return fn.sum(gathered)
コード例 #4
0
 def test_jax(self):
     """Test that sum, called without the axis arguments, returns a scalar
     (a 0-d JAX array) equal to the sum of every entry of the input."""
     t = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
     res = fn.sum(t)
     # The docstring promises a scalar, so verify the result's type and rank,
     # not just its value.
     assert isinstance(res, jnp.ndarray)
     assert res.ndim == 0
     assert fn.allclose(res, 2.1)
コード例 #5
0
 def test_torch(self):
     """Test that sum, called without the axis arguments, returns a scalar
     (a 0-d ``torch.Tensor``) equal to the sum of every entry of the input."""
     t = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
     res = fn.sum(t)
     assert isinstance(res, torch.Tensor)
     # The docstring promises a scalar: a full reduction must yield a 0-d tensor.
     assert res.ndim == 0
     assert fn.allclose(res, 2.1)