def test_sum_axis_keepdims(self, t1):
    """Test that summing over a subset of axes with ``keepdims=True`` keeps
    the reduced axes as length-one dimensions instead of removing them."""
    summed = fn.sum(t1, axis=(0, 2), keepdims=True)

    # TensorFlow and PyTorch results expose ``.numpy()``; convert so the
    # checks below operate on the underlying array data.
    if hasattr(summed, "numpy"):
        summed = summed.numpy()

    expected = np.array([[[14], [6], [3]]])
    assert fn.allclose(summed, expected)
    assert summed.shape == (1, 3, 1)
def test_sum_axis(self, t1):
    """Test that the ``axis`` argument restricts the sum to the given axes,
    dropping those axes from the result."""
    summed = fn.sum(t1, axis=(0, 2))

    # TensorFlow and PyTorch results expose ``.numpy()``; convert so the
    # checks below operate on the underlying array data.
    if hasattr(summed, "numpy"):
        summed = summed.numpy()

    expected = np.array([14, 6, 3])
    assert fn.allclose(summed, expected)
    assert summed.shape == (3,)
def cost_fn(t):
    # Select the entries at ``indices`` along axis 1, then reduce them to a sum.
    # ``indices`` is not defined here; it is looked up from an outer scope.
    gathered = fn.take(t, indices, axis=1)
    return fn.sum(gathered)
def test_jax(self):
    """Test that sum, called without the axis arguments, returns a scalar
    JAX array holding the total of all elements."""
    t = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    res = fn.sum(t)

    # The result should stay a JAX array (mirrors the torch test's type check).
    assert isinstance(res, jnp.ndarray)
    # The docstring promises a scalar: every axis must be reduced away, so the
    # result is 0-d. Checking the value alone would not catch a partial sum
    # that merely broadcasts against 2.1 in ``allclose``.
    assert res.shape == ()
    assert fn.allclose(res, 2.1)
def test_torch(self):
    """Test that sum, called without the axis arguments, returns a scalar
    torch tensor holding the total of all elements."""
    t = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    res = fn.sum(t)

    assert isinstance(res, torch.Tensor)
    # The docstring promises a scalar: every axis must be reduced away, so the
    # tensor is 0-d. torch.Size is a tuple subclass, so it compares equal to ().
    assert res.shape == ()
    assert fn.allclose(res, 2.1)