def test_build_shards():
    """Check that ``_build_shards`` splits an array into contiguous shards.

    Covers an even split (one element per shard), an uneven split where the
    trailing shard holds the leftover element, and a request for more shards
    than the array has elements, which must raise ``ValueError``.
    """

    def assert_shards_equal(shards, expected_shards):
        # Check the count explicitly: zip() stops at the shorter input, so a
        # missing or extra shard would otherwise go unnoticed.
        assert len(shards) == len(expected_shards)
        # Compare each pair with np.array_equal. A list ``==`` on ndarrays depends
        # on the truthiness of a boolean array, which is only defined for size-1
        # arrays.
        for out, expected in zip(shards, expected_shards):
            assert np.array_equal(out, expected)

    array = np.array([1, 2, 3, 4])

    # Even split: four shards of one element each.
    assert_shards_equal(
        _build_shards(4, array),
        [np.array([1]), np.array([2]), np.array([3]), np.array([4])],
    )

    # Uneven split: the last shard holds the remaining elements.
    assert_shards_equal(
        _build_shards(3, array),
        [np.array([1]), np.array([2]), np.array([3, 4])],
    )

    # More shards requested than there are elements.
    with pytest.raises(ValueError):
        _build_shards(5, array)