Example #1
0
    def test_concatenate_jax(self):
        """Concatenating JAX arrays with no ``axis`` argument must join them
        along the 0th dimension, agreeing elementwise with ``jnp.concatenate``."""
        first = jnp.array([5.0, 8.0, 101.0])
        second = jnp.array([0.6, 0.1, 0.6])
        third = jnp.array([0.1, 0.2, 0.3])

        result = fn.concatenate([first, second, third])
        expected = jnp.concatenate([first, second, third])
        assert jnp.all(result == expected)
Example #2
0
    def test_concatenate_array(self):
        """Concatenating a mix of a Python list, an ``np`` array and a vanilla
        NumPy array with no ``axis`` argument joins them along the 0th
        dimension and produces an ``np.ndarray``."""
        as_list = [0.6, 0.1, 0.6]
        as_np = np.array([0.1, 0.2, 0.3])
        as_onp = onp.array([5.0, 8.0, 101.0])

        result = fn.concatenate([as_list, as_np, as_onp])

        # The return type is checked first, then the values.
        assert isinstance(result, np.ndarray)
        expected = np.concatenate([as_list, as_np, as_onp])
        assert np.all(result == expected)
Example #3
0
    def test_stack_torch(self):
        """Concatenating a float64 NumPy array with float64 torch tensors and no
        ``axis`` argument joins them along the 0th dimension and returns a
        ``torch.Tensor``.

        Note: despite the ``stack`` in its name, this test exercises
        ``fn.concatenate``.
        """
        numpy_part = onp.array([5.0, 8.0, 101.0], dtype=np.float64)
        torch_a = torch.tensor([0.6, 0.1, 0.6], dtype=torch.float64)
        torch_b = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)

        result = fn.concatenate([numpy_part, torch_a, torch_b])
        assert isinstance(result, torch.Tensor)

        # Compare in NumPy space: the reference result is built with np.concatenate.
        reference = np.concatenate([numpy_part, torch_a.numpy(), torch_b.numpy()])
        assert np.all(result.numpy() == reference)
Example #4
0
    def test_stack_tensorflow(self):
        """Concatenating a ``tf.constant``, a ``tf.Variable`` and a NumPy array
        with no ``axis`` argument joins them along the 0th dimension and
        returns a ``tf.Tensor``.

        Note: despite the ``stack`` in its name, this test exercises
        ``fn.concatenate``.
        """
        const_part = tf.constant([0.6, 0.1, 0.6])
        var_part = tf.Variable([0.1, 0.2, 0.3])
        numpy_part = onp.array([5.0, 8.0, 101.0])

        result = fn.concatenate([const_part, var_part, numpy_part])
        assert isinstance(result, tf.Tensor)

        # Compare in NumPy space: the reference result is built with np.concatenate.
        reference = np.concatenate([const_part.numpy(), var_part.numpy(), numpy_part])
        assert np.all(result.numpy() == reference)
Example #5
0
    def test_concatenate_flattened_arrays(self, t1):
        """With ``axis=None`` every input is flattened before concatenation.

        ``t1`` is supplied externally (fixture/parametrization). Given the
        expected output, it must flatten to ``[1, 2]``.
        """
        tail = onp.array([5])
        result = fn.concatenate([t1, tail], axis=None)

        # TensorFlow/PyTorch results expose ``.numpy()``; others are used as-is.
        result = result.numpy() if hasattr(result, "numpy") else result

        assert fn.allclose(result, np.array([1, 2, 5]))
        assert tuple(result.shape) == (3,)
Example #6
0
    def test_stack_axis(self, t1):
        """Passing ``axis=1`` concatenates along the second dimension instead
        of the first.

        ``t1`` is supplied externally (fixture/parametrization). Given the
        expected output, it must be the column ``[[1], [2]]``. Despite the
        ``stack`` in its name, this test exercises ``fn.concatenate``.
        """
        column = onp.array([[3], [4]])
        result = fn.concatenate([t1, column], axis=1)

        # TensorFlow/PyTorch results expose ``.numpy()``; others are used as-is.
        result = result.numpy() if hasattr(result, "numpy") else result

        assert fn.allclose(result, np.array([[1, 3], [2, 4]]))
        assert tuple(result.shape) == (2, 2)