コード例 #1
0
ファイル: test_stacked.py プロジェクト: TNonet/lmdec
def test_mean_non_consistent_shape(shapes):
    """Check ``StackedArray.mean`` against ``mean`` of its ``.array``.

    ``shapes`` is an iterable of component shapes (possibly differing); it is
    presumably supplied by a hypothesis ``@given`` decorator not shown here —
    TODO confirm.

    Axes exercised: ``None``, ``-1``, and every leading-axes prefix
    ``()``, ``(0,)``, ``(0, 1)``, ... with up to ``ndim - 1`` axes.
    """
    sa = StackedArray([da.random.random(shape) for shape in shapes])
    # Multi-axis reductions take a *tuple* of ints: NumPy's reductions reject
    # list-valued ``axis``, so build tuples rather than lists.
    leading_axes = (tuple(range(n)) for n in range(len(sa.shape)))
    for axis in [None, -1, *leading_axes]:
        note(f"shapes: {shapes}, axis: {axis}")
        np.testing.assert_array_almost_equal(sa.mean(axis=axis),
                                             sa.array.mean(axis=axis))
コード例 #2
0
ファイル: test_stacked.py プロジェクト: TNonet/lmdec
def test_reshape():
    """Reshaping a StackedArray yields the requested shape and same statistics.

    Seven random (4, 4) dask arrays are stacked. Every target shape below holds
    16 elements, and ``mean``/``std`` of the reshaped result must agree with
    those of the original to 12 decimals.
    """
    stacked = StackedArray([da.random.random(size=(4, 4)) for _ in range(7)])
    target_shapes = [(16, ), (2, 8), (4, 4), (8, 2)]

    for target in target_shapes:
        reshaped = stacked.reshape(target)
        assert reshaped.shape == target

        for reduction in ("mean", "std"):
            np.testing.assert_array_almost_equal(
                getattr(reshaped, reduction)(),
                getattr(stacked, reduction)(),
                decimal=12,
            )
コード例 #3
0
ファイル: test_stacked.py プロジェクト: TNonet/lmdec
def test_persist():
    """A persisted StackedArray should compute ``mean`` far faster.

    Each component carries an expensive but value-neutral dependency
    (``d + x - d``), so computing the non-persisted stack is slow. The
    persisted stack is presumably already materialized by ``persist()``. The
    test asserts the persisted ``mean`` takes at most 1/10 of the time.

    NOTE(review): wall-clock ratio assertions are sensitive to machine load
    and may be flaky on busy CI hosts.
    """
    def delay_array(x):
        # Slow operation (about 2 seconds on the author's machine). ``d``
        # cancels out, so the value is ``x`` up to floating-point rounding.
        # The size must be an int: NumPy rejects float sizes such as ``1e9``.
        d = da.mean(da.random.random(10**9))
        return d + x - d

    arrays = [delay_array(da.random.random(size=(4, ))) for _ in range(2)]

    sa = StackedArray(arrays)
    sa_persist = sa.persist()

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted during the measurement.
    start = time.perf_counter()
    sa_persist.mean().compute()
    persist_took = time.perf_counter() - start

    start_mean = time.perf_counter()
    sa.mean().compute()
    mean_took = time.perf_counter() - start_mean
    assert persist_took <= mean_took / 10
コード例 #4
0
ファイル: test_stacked.py プロジェクト: TNonet/lmdec
def test_fallback_methods():
    """Attributes and reductions of a StackedArray agree with its ``.array``.

    ``shape`` and ``chunks`` must be equal. ``std``, ``mean``, ``max`` and
    ``min`` must match to 12 decimals for ``axis`` in ``None``, ``0``, ``1``.
    """
    stacked = StackedArray([da.random.random(size=(4, 4)) for _ in range(7)])

    assert stacked.shape == stacked.array.shape
    assert stacked.chunks == stacked.array.chunks

    reductions = ("std", "mean", "max", "min")
    for axis in (None, 0, 1):
        for method in reductions:
            np.testing.assert_array_almost_equal(
                getattr(stacked, method)(axis=axis),
                getattr(stacked.array, method)(axis=axis),
                decimal=12,
            )
コード例 #5
0
ファイル: test_stacked.py プロジェクト: TNonet/lmdec
def test_mean_consistent_shape():
    """``mean`` of same-shaped stacked arrays matches ``mean`` of ``.array``.

    Seven random (10, 10, 10) arrays are stacked. The test compares single
    axes, a negative axis, axis tuples and ``None`` at the default precision.
    """
    components = [da.random.random(size=(10, 10, 10)) for _ in range(7)]
    stacked = StackedArray(components)
    axes_to_check = (None, 0, 1, 2, (0, 1), -1, (1, 2), (0, 1, 2))

    for axis in axes_to_check:
        actual = stacked.mean(axis=axis)
        expected = stacked.array.mean(axis=axis)
        np.testing.assert_array_almost_equal(actual, expected)