def test_stack_array_jax(self):
    """Test that stack, called without the axis argument, stacks vertically,
    and that the result is a JAX array when a JAX array is among the inputs."""
    t1 = onp.array([0.6, 0.1, 0.6])
    t2 = jnp.array([0.1, 0.2, 0.3])
    t3 = jnp.array([5.0, 8.0, 101.0])
    res = fn.stack([t1, t2, t3])
    # Mixing JAX arrays with a plain NumPy array should dispatch to JAX; without
    # this check a silently-NumPy result would still pass the value comparison.
    assert isinstance(res, jnp.ndarray)
    assert np.all(res == np.stack([t1, t2, t3]))
def test_stack_torch(self):
    """Check that ``fn.stack`` with no ``axis`` argument stacks its inputs
    vertically and yields a Torch tensor when Torch tensors are among them."""
    np_row = onp.array([5.0, 8.0, 101.0], dtype=np.float64)
    torch_rows = [
        torch.tensor([0.6, 0.1, 0.6], dtype=torch.float64),
        torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64),
    ]

    stacked = fn.stack([np_row, *torch_rows])
    assert isinstance(stacked, torch.Tensor)

    # Build the reference with plain NumPy after converting the Torch rows.
    expected = np.stack([np_row] + [row.numpy() for row in torch_rows])
    assert np.all(stacked.numpy() == expected)
def test_stack_tensorflow(self):
    """Check that ``fn.stack`` with no ``axis`` argument stacks a TF constant,
    a TF variable and a NumPy array vertically into a ``tf.Tensor``."""
    const_row = tf.constant([0.6, 0.1, 0.6])
    var_row = tf.Variable([0.1, 0.2, 0.3])
    np_row = onp.array([5.0, 8.0, 101.0])

    stacked = fn.stack([const_row, var_row, np_row])
    assert isinstance(stacked, tf.Tensor)

    # Reference result computed with NumPy from the eager values of each input.
    expected = np.stack([const_row.numpy(), var_row.numpy(), np_row])
    assert np.all(stacked.numpy() == expected)
def test_stack_axis(self, t1):
    """Check that supplying ``axis=1`` to ``fn.stack`` stacks along columns
    instead of rows.

    ``t1`` is injected by the test harness; the expected result below implies
    it holds the values ``[1, 2]``.
    """
    other = onp.array([3, 4])
    result = fn.stack([t1, other], axis=1)

    # Results exposing ``.numpy()`` (e.g. TensorFlow / PyTorch tensors) are
    # converted to a NumPy view of their data before inspection.
    result = result.numpy() if hasattr(result, "numpy") else result

    assert fn.allclose(result, np.array([[1, 3], [2, 4]]))
    assert list(result.shape) == [2, 2]