Exemplo n.º 1
0
def test_where(nps_app_inst):
    """Check single-argument ``nps.where`` against ``np.where``.

    For each shape, draw a random array with ``nps.random.rand``. Non-scalar
    arrays are re-blocked so the leading axis is split into chunks of
    ``shape[0] // 12`` rows while trailing axes stay whole. Then compare the
    index tuples returned for ``arr < 0.5`` and ``arr >= 0.5`` with NumPy's.

    Args:
        nps_app_inst: only checked to be non-None here (presumably the
            pytest fixture that sets up the nums application — confirm).
    """
    import nums.numpy as nps

    assert nps_app_inst is not None

    def check_indices(nps_cond, np_cond):
        """Assert nps.where(nps_cond) yields exactly np.where(np_cond)."""
        results: tuple = nps.where(nps_cond)
        np_results = np.where(np_cond)
        # One index array per axis. Check the count so that extra or missing
        # arrays from nps.where are not silently ignored.
        assert len(results) == len(np_results)
        for np_idx, nps_idx in zip(np_results, results):
            # Indices are integers, so compare exactly. array_equal also
            # checks shape, whereas allclose would broadcast mismatched lengths.
            assert np.array_equal(np_idx, nps_idx.get())

    shapes = [
        (),
        (10**6,),
        (10**6, 1),
        (10**5, 10)
    ]
    for shape in shapes:
        arr: BlockArray = nps.random.rand(*shape)
        if shape:
            # Split only the leading axis; keep any trailing axes in one block.
            arr = arr.reshape(block_shape=(shape[0] // 12,) + shape[1:])
        arr_np = arr.get()
        check_indices(arr < 0.5, arr_np < 0.5)
        check_indices(arr >= 0.5, arr_np >= 0.5)
Exemplo n.º 2
0
def test_where(nps_app_inst):
    """Check both forms of ``nps.where`` against ``np.where``.

    For each shape, draw three random arrays (``arr``, ``x``, ``y``) with
    ``nps.random.rand``. Non-scalar arrays are re-blocked so the leading axis
    is split into chunks of ``shape[0] // 12`` rows while trailing axes stay
    whole. Two forms are then compared with NumPy for the conditions
    ``arr < 0.5`` and ``arr >= 0.5``:

    * ``where(cond)``: the tuple of index arrays.
    * ``where(cond, x, y)``: elementwise selection from ``x`` / ``y``.

    Args:
        nps_app_inst: only checked to be non-None here (presumably the
            pytest fixture that sets up the nums application — confirm).
    """
    import nums.numpy as nps

    assert nps_app_inst is not None

    def check_indices(nps_cond, np_cond):
        """Assert nps.where(nps_cond) yields exactly np.where(np_cond)."""
        results: tuple = nps.where(nps_cond)
        np_results = np.where(np_cond)
        # One index array per axis. Check the count so that extra or missing
        # arrays from nps.where are not silently ignored.
        assert len(results) == len(np_results)
        for np_idx, nps_idx in zip(np_results, results):
            # Indices are integers, so compare exactly. array_equal also
            # checks shape, whereas allclose would broadcast mismatched lengths.
            assert np.array_equal(np_idx, nps_idx.get())

    def check_select(nps_cond, np_cond, if_true, if_false):
        """Assert nps.where(cond, a, b) matches np.where(cond, a, b)."""
        np_result = np.where(np_cond, if_true.get(), if_false.get())
        result = nps.where(nps_cond, if_true, if_false).get()
        # allclose broadcasts, so check the shape explicitly as well.
        assert np.shape(result) == np.shape(np_result)
        assert np.allclose(np_result, result)

    shapes = [(), (10**6, ), (10**6, 1), (10**5, 10)]
    for shape in shapes:
        arr: BlockArray = nps.random.rand(*shape)
        x: BlockArray = nps.random.rand(*shape)
        y: BlockArray = nps.random.rand(*shape)
        if shape:
            # Split only the leading axis; keep any trailing axes in one block.
            bs = (shape[0] // 12,) + shape[1:]
            arr, x, y = (a.reshape(block_shape=bs) for a in (arr, x, y))
        arr_np = arr.get()

        check_indices(arr < 0.5, arr_np < 0.5)
        check_indices(arr >= 0.5, arr_np >= 0.5)

        # Do an xy test.
        check_select(arr < 0.5, arr_np < 0.5, x, y)
        check_select(arr >= 0.5, arr_np >= 0.5, x, y)