def test_concatenate_jax(self):
    """Concatenating JAX arrays with no ``axis`` argument joins them along axis 0.

    The result of ``fn.concatenate`` is compared element-wise against
    ``jnp.concatenate`` applied to the same three arrays.
    """
    first = jnp.array([5.0, 8.0, 101.0])
    second = jnp.array([0.6, 0.1, 0.6])
    third = jnp.array([0.1, 0.2, 0.3])

    result = fn.concatenate([first, second, third])
    expected = jnp.concatenate([first, second, third])

    assert jnp.all(result == expected)
def test_concatenate_array(self):
    """A plain list, an ``np`` array and an ``onp`` array concatenate along axis 0.

    With no ``axis`` argument the output must be an ``np.ndarray`` whose
    values equal ``np.concatenate`` of the same inputs.
    """
    as_list = [0.6, 0.1, 0.6]
    from_np = np.array([0.1, 0.2, 0.3])
    from_onp = onp.array([5.0, 8.0, 101.0])

    result = fn.concatenate([as_list, from_np, from_onp])

    assert isinstance(result, np.ndarray)
    assert np.all(result == np.concatenate([as_list, from_np, from_onp]))
def test_stack_torch(self):
    """Concatenating an ``onp`` array with two float64 torch tensors along axis 0.

    With no ``axis`` argument, ``fn.concatenate`` must return a
    ``torch.Tensor`` whose values match ``np.concatenate`` of the same data.
    """
    onp_part = onp.array([5.0, 8.0, 101.0], dtype=np.float64)
    torch_a = torch.tensor([0.6, 0.1, 0.6], dtype=torch.float64)
    torch_b = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)

    result = fn.concatenate([onp_part, torch_a, torch_b])
    assert isinstance(result, torch.Tensor)

    actual = result.numpy()
    reference = np.concatenate([onp_part, torch_a.numpy(), torch_b.numpy()])
    assert np.all(actual == reference)
def test_stack_tensorflow(self):
    """Concatenating a TF constant, a TF Variable and an ``onp`` array along axis 0.

    With no ``axis`` argument, ``fn.concatenate`` must return a ``tf.Tensor``
    whose values match ``np.concatenate`` of the same data.
    """
    tf_const = tf.constant([0.6, 0.1, 0.6])
    tf_var = tf.Variable([0.1, 0.2, 0.3])
    onp_part = onp.array([5.0, 8.0, 101.0])

    result = fn.concatenate([tf_const, tf_var, onp_part])
    assert isinstance(result, tf.Tensor)

    actual = result.numpy()
    reference = np.concatenate([tf_const.numpy(), tf_var.numpy(), onp_part])
    assert np.all(actual == reference)
def test_concatenate_flattened_arrays(self, t1):
    """With ``axis=None`` every input is flattened before being concatenated.

    ``t1`` is supplied by a parametrization outside this view; the asserted
    result implies it flattens to ``[1, 2]``.
    """
    extra = onp.array([5])
    flat = fn.concatenate([t1, extra], axis=None)

    # TensorFlow/PyTorch results expose ``.numpy()``; unwrap them so the
    # value and shape checks below work uniformly across frameworks.
    flat = flat.numpy() if hasattr(flat, "numpy") else flat

    assert fn.allclose(flat, np.array([1, 2, 5]))
    assert list(flat.shape) == [3]
def test_stack_axis(self, t1):
    """Passing ``axis=1`` concatenates side by side instead of along axis 0.

    ``t1`` is supplied by a parametrization outside this view; the asserted
    result implies it holds the column ``[[1], [2]]``.
    """
    column = onp.array([[3], [4]])
    joined = fn.concatenate([t1, column], axis=1)

    if hasattr(joined, "numpy"):
        # TensorFlow/PyTorch tensors: compare on the underlying NumPy data.
        joined = joined.numpy()

    assert fn.allclose(joined, np.array([[1, 3], [2, 4]]))
    assert list(joined.shape) == [2, 2]