Example No. 1
0
    def test_apply_vector(n_threads, halo, grid):
        """Check that ``apply_vector`` stores ``cell_id(i, j, k)`` at every
        point of every component of the output vector field.

        Multi-threaded runs on 1-D grids are skipped by returning early.
        Each component array has one more point than the grid along its own
        axis. After the call, the null scalar and vector argument fields must
        still report META_HALO_VALID, while the output field must not.

        Raises:
            NotImplementedError: for grids with more than three dimensions.
        """
        n_dims = len(grid)
        if n_dims == 1 and n_threads > 1:
            return

        # arrange
        traversals = Traversals(grid, halo, jit_flags, n_threads)
        sut = traversals.apply_vector()

        scl_null_arg_impl = ScalarField.make_null(n_dims).impl
        vec_null_arg_impl = VectorField.make_null(n_dims).impl

        # one zero-filled array per component, each with an extra point along
        # the component's own axis; ``dims`` lists the matching component ids
        if n_dims == 1:
            data = (np.zeros(grid[0] + 1), )
            dims = (INNER, )
        elif n_dims == 2:
            data = (
                np.zeros((grid[0] + 1, grid[1])),
                np.zeros((grid[0], grid[1] + 1)),
            )
            dims = (OUTER, INNER)
        elif n_dims == 3:
            data = (
                np.zeros((grid[0] + 1, grid[1], grid[2])),
                np.zeros((grid[0], grid[1] + 1, grid[2])),
                np.zeros((grid[0], grid[1], grid[2] + 1)),
            )
            dims = (OUTER, MID3D, INNER)
        else:
            raise NotImplementedError()

        out = VectorField(data, halo,
                          [ConstantBoundaryCondition(np.nan)] * n_dims)

        # act
        sut(*[_cell_id_vector] * MAX_DIM_NUM, *out.impl[IMPL_META_AND_DATA],
            *scl_null_arg_impl[IMPL_META_AND_DATA],
            *scl_null_arg_impl[IMPL_BC],
            *vec_null_arg_impl[IMPL_META_AND_DATA],
            *vec_null_arg_impl[IMPL_BC],
            *scl_null_arg_impl[IMPL_META_AND_DATA],
            *scl_null_arg_impl[IMPL_BC])

        # assert
        # component d is read with ``halo - 1`` along its own axis and
        # ``halo`` along the other axes
        halos = (
            (halo - 1, halo, halo),
            (halo, halo - 1, halo),
            (halo, halo, halo - 1),
        )
        # the indexer accessor does not depend on the loop variables
        read_at = indexers[n_dims].at[INNER if n_dims == 1 else OUTER]

        for d in dims:
            component = out.get_component(d)
            offsets = halos[d]
            focus = tuple(-offsets[axis] for axis in range(MAX_DIM_NUM))

            # axes the grid does not have are visited once, at INVALID_INDEX
            i_range = (
                range(offsets[OUTER], offsets[OUTER] + component.shape[OUTER])
                if n_dims > 1 else (INVALID_INDEX, ))
            j_range = (
                range(offsets[MID3D], offsets[MID3D] + component.shape[MID3D])
                if n_dims > 2 else (INVALID_INDEX, ))
            k_range = range(offsets[INNER],
                            offsets[INNER] + component.shape[INNER])

            for i in i_range:
                for j in j_range:
                    for k in k_range:
                        # pack the indices the n-dimensional indexer expects
                        if n_dims == 1:
                            ijk = (k, INVALID_INDEX, INVALID_INDEX)
                        elif n_dims == 2:
                            ijk = (i, k, INVALID_INDEX)
                        else:
                            ijk = (i, j, k)
                        value = read_at(focus, component, *ijk)
                        assert cell_id(i, j, k) == value

        assert scl_null_arg_impl[IMPL_META_AND_DATA][META_AND_DATA_META][
            META_HALO_VALID]
        assert vec_null_arg_impl[IMPL_META_AND_DATA][META_AND_DATA_META][
            META_HALO_VALID]
        assert not out.impl[IMPL_META_AND_DATA][META_AND_DATA_META][
            META_HALO_VALID]
Example No. 2
0
    def test_apply_vector(n_threads: int, halo: int, grid: tuple):
        """Verify that ``apply_vector`` writes ``cell_id(i, j, k)`` into
        every point of each component of a freshly assembled vector field.

        Multi-threaded runs on 1-D grids return early without checking
        anything. Once the traversal has run, both null argument fields
        (scalar and vector) must still have META_HALO_VALID set, and the
        output field must have it cleared.
        """
        if n_threads > 1 and len(grid) == 1:
            return
        commons = make_commons(grid, halo, n_threads)
        n_dims = commons.n_dims

        # arrange
        sut = commons.traversals.apply_vector()

        # unsupported dimensionalities fail here with a KeyError,
        # before any field array is allocated
        component_dims = {
            1: (INNER, ),
            2: (OUTER, INNER),
            3: (OUTER, MID3D, INNER),
        }[n_dims]

        def zeros_with_extra_point(axis):
            # grid-sized zeros, one point longer along ``axis``
            shape = tuple(grid[a] + 1 if a == axis else grid[a]
                          for a in range(n_dims))
            return np.zeros(shape)

        components = tuple(zeros_with_extra_point(axis)
                           for axis in range(n_dims))

        out = VectorField(components, halo, (Constant(np.nan), ) * n_dims)
        out.assemble(commons.traversals)

        # act
        field_args = out.impl[IMPL_META_AND_DATA]
        scl_null = commons.scl_null_arg_impl
        vec_null = commons.vec_null_arg_impl
        sut(*[_cell_id_vector] * MAX_DIM_NUM,
            *field_args,
            *scl_null[IMPL_META_AND_DATA], *scl_null[IMPL_BC],
            *vec_null[IMPL_META_AND_DATA], *vec_null[IMPL_BC],
            *scl_null[IMPL_META_AND_DATA], *scl_null[IMPL_BC])

        # assert
        def index_span(offsets, shape, axis, present):
            """Halo-shifted indices along ``axis``, or a lone INVALID_INDEX
            placeholder when the grid has no such axis."""
            if not present:
                return (INVALID_INDEX, )
            return range(offsets[axis], offsets[axis] + shape[axis])

        def indexer_args(i, j, k):
            """Arrange loop indices in the order the n-D indexer takes."""
            if n_dims == 1:
                return k, INVALID_INDEX, INVALID_INDEX
            if n_dims == 2:
                return i, k, INVALID_INDEX
            return i, j, k

        read_at = commons.traversals.indexers[n_dims].ats[
            INNER if n_dims == 1 else OUTER]

        for dim in component_dims:
            component = out.get_component(dim)
            offsets = commons.halos[dim]
            focus = tuple(-offsets[axis] for axis in range(MAX_DIM_NUM))
            for i in index_span(offsets, component.shape, OUTER, n_dims > 1):
                for j in index_span(offsets, component.shape, MID3D,
                                    n_dims > 2):
                    for k in index_span(offsets, component.shape, INNER,
                                        True):
                        value = read_at(focus, component,
                                        *indexer_args(i, j, k))
                        assert cell_id(i, j, k) == value

        for null_impl in (scl_null, vec_null):
            assert null_impl[IMPL_META_AND_DATA][META_AND_DATA_META][
                META_HALO_VALID]
        assert not out.impl[IMPL_META_AND_DATA][META_AND_DATA_META][
            META_HALO_VALID]