def test_apply_vector(n_threads, halo, grid):
    """Check that the ``apply_vector`` traversal fills every vector element.

    The traversal obtained from ``Traversals.apply_vector()`` is called with
    ``_cell_id_vector`` repeated ``MAX_DIM_NUM`` times, and afterwards each
    element of each component of the output field is read back through
    ``indexers`` and compared against ``cell_id(i, j, k)``.

    Only 1D and 2D grids are supported; any other dimensionality raises
    ``NotImplementedError``.

    Args:
        n_threads: number of threads passed to ``Traversals``.
        halo: halo size passed to ``Traversals`` and ``VectorField``.
        grid: sequence of grid sizes, one entry per dimension.

    Note:
        The arguments are presumably supplied by pytest parametrization or
        fixtures declared outside this block -- TODO confirm.
    """
    n_dims = len(grid)
    # The 1D multi-threaded combination is not exercised.
    # NOTE(review): presumably threading is unsupported in 1D -- confirm.
    if n_dims == 1 and n_threads > 1:
        return

    # arrange
    traversals = Traversals(grid, halo, jit_flags, n_threads)
    sut = traversals.apply_vector()

    scl_null_arg_impl = ScalarField.make_null(n_dims).impl
    vec_null_arg_impl = VectorField.make_null(n_dims).impl

    # One array per vector component; component `d` has one extra element
    # along dimension `d`.
    if n_dims == 1:
        data = (np.zeros(grid[0] + 1),)
    elif n_dims == 2:
        data = (
            np.zeros((grid[0] + 1, grid[1])),
            np.zeros((grid[0], grid[1] + 1)),
        )
    else:
        # TODO: 3D is not implemented yet. Fail explicitly here instead of
        # leaving `data` unbound, which previously surfaced as an
        # UnboundLocalError when constructing the VectorField below.
        raise NotImplementedError()

    out = VectorField(data, halo, [ConstantBoundaryCondition(np.nan)] * n_dims)

    # act
    sut(
        *[_cell_id_vector] * MAX_DIM_NUM,
        *out.impl[IMPL_META_AND_DATA],
        *scl_null_arg_impl[IMPL_META_AND_DATA],
        *scl_null_arg_impl[IMPL_BC],
        *vec_null_arg_impl[IMPL_META_AND_DATA],
        *vec_null_arg_impl[IMPL_BC],
        *scl_null_arg_impl[IMPL_META_AND_DATA],
        *scl_null_arg_impl[IMPL_BC],
    )

    # assert
    # Per-component halo extents: each row uses `halo - 1` at one position
    # and `halo` elsewhere.
    halos = (
        (halo - 1, halo, halo),
        (halo, halo - 1, halo),
        (halo, halo, halo - 1),
    )

    if n_dims == 1:
        dims = (INNER,)
    elif n_dims == 2:
        dims = (OUTER, INNER)
    else:
        raise NotImplementedError()

    for d in dims:
        data = out.get_component(d)
        # Negated halo extents are handed to the indexer as the focus.
        focus = tuple(-halos[d][i] for i in range(MAX_DIM_NUM))
        # Dimensions that do not exist for this grid are iterated once with
        # INVALID_INDEX so that the loop nest has the same shape in 1D and 2D.
        for i in (
            range(halos[d][OUTER], halos[d][OUTER] + data.shape[OUTER])
            if n_dims > 1
            else (INVALID_INDEX,)
        ):
            for j in (
                range(halos[d][MID3D], halos[d][MID3D] + data.shape[MID3D])
                if n_dims > 2
                else (INVALID_INDEX,)
            ):
                for k in range(
                    halos[d][INNER], halos[d][INNER] + data.shape[INNER]
                ):
                    if n_dims == 1:
                        ijk = (k, INVALID_INDEX, INVALID_INDEX)
                    elif n_dims == 2:
                        ijk = (i, k, INVALID_INDEX)
                    else:
                        raise NotImplementedError()
                    value = indexers[n_dims].at[
                        INNER if n_dims == 1 else OUTER
                    ](focus, data, *ijk)
                    # The message replaces the former per-element debug
                    # prints and is only built when the check fails.
                    assert cell_id(i, j, k) == value, (
                        f"component {d}, focus {focus}: "
                        f"mismatch at ({i}, {j}, {k})"
                    )

    # The null arguments must still be flagged halo-valid, while the output
    # field must not be flagged halo-valid after the traversal.
    assert scl_null_arg_impl[IMPL_META_AND_DATA][META_AND_DATA_META][
        META_HALO_VALID]
    assert vec_null_arg_impl[IMPL_META_AND_DATA][META_AND_DATA_META][
        META_HALO_VALID]
    assert not out.impl[IMPL_META_AND_DATA][META_AND_DATA_META][
        META_HALO_VALID]
def test_apply_vector(n_threads: int, halo: int, grid: tuple):
    """Verify that the ``apply_vector`` traversal fills every vector element.

    The traversal is invoked with ``_cell_id_vector`` repeated
    ``MAX_DIM_NUM`` times; each element of each output component is then
    read back through the traversals' indexers and compared with
    ``cell_id(i, j, k)``. Finally the halo-valid flags of the null
    arguments and of the output field are checked.

    Args:
        n_threads: number of threads forwarded to ``make_commons``.
        halo: halo size forwarded to ``make_commons`` and ``VectorField``.
        grid: grid sizes, one entry per dimension (1, 2 or 3 entries).
    """
    is_one_dimensional = len(grid) == 1
    if is_one_dimensional and n_threads > 1:
        # This combination is not exercised.
        return
    cmn = make_commons(grid, halo, n_threads)

    # arrange
    sut = cmn.traversals.apply_vector()
    n_dims = cmn.n_dims

    # One zero-filled array per component; component `c` is one element
    # longer along dimension `c`.
    if n_dims == 1:
        components = (np.zeros(grid[0] + 1),)
    elif n_dims == 2:
        components = (
            np.zeros((grid[0] + 1, grid[1])),
            np.zeros((grid[0], grid[1] + 1)),
        )
    elif n_dims == 3:
        components = (
            np.zeros((grid[0] + 1, grid[1], grid[2])),
            np.zeros((grid[0], grid[1] + 1, grid[2])),
            np.zeros((grid[0], grid[1], grid[2] + 1)),
        )
    else:
        raise KeyError(n_dims)

    out = VectorField(components, halo, (Constant(np.nan),) * n_dims)
    out.assemble(cmn.traversals)

    scalar_null = cmn.scl_null_arg_impl
    vector_null = cmn.vec_null_arg_impl

    # act
    sut(
        *(_cell_id_vector,) * MAX_DIM_NUM,
        *out.impl[IMPL_META_AND_DATA],
        *scalar_null[IMPL_META_AND_DATA],
        *scalar_null[IMPL_BC],
        *vector_null[IMPL_META_AND_DATA],
        *vector_null[IMPL_BC],
        *scalar_null[IMPL_META_AND_DATA],
        *scalar_null[IMPL_BC],
    )

    # assert
    if n_dims == 1:
        axes = (INNER,)
    elif n_dims == 2:
        axes = (OUTER, INNER)
    elif n_dims == 3:
        axes = (OUTER, MID3D, INNER)
    else:
        raise KeyError(n_dims)

    no_index = (INVALID_INDEX,)
    read_at = cmn.traversals.indexers[n_dims].ats[
        INNER if n_dims == 1 else OUTER
    ]

    for axis in axes:
        component = out.get_component(axis)
        extents = cmn.halos[axis]
        focus = tuple(-extents[pos] for pos in range(MAX_DIM_NUM))

        # Dimensions absent from the grid are visited once with
        # INVALID_INDEX so the same loop nest serves 1D, 2D and 3D.
        if n_dims > 1:
            outer_span = range(
                extents[OUTER], extents[OUTER] + component.shape[OUTER]
            )
        else:
            outer_span = no_index
        if n_dims > 2:
            mid_span = range(
                extents[MID3D], extents[MID3D] + component.shape[MID3D]
            )
        else:
            mid_span = no_index
        inner_span = range(
            extents[INNER], extents[INNER] + component.shape[INNER]
        )

        for i in outer_span:
            for j in mid_span:
                for k in inner_span:
                    if n_dims == 1:
                        position = (k, INVALID_INDEX, INVALID_INDEX)
                    elif n_dims == 2:
                        position = (i, k, INVALID_INDEX)
                    else:
                        position = (i, j, k)
                    value = read_at(focus, component, *position)
                    assert cell_id(i, j, k) == value

    # Null arguments must still be flagged halo-valid ...
    for null_impl in (scalar_null, vector_null):
        assert null_impl[IMPL_META_AND_DATA][META_AND_DATA_META][
            META_HALO_VALID]
    # ... whereas the output field must not be.
    assert not out.impl[IMPL_META_AND_DATA][META_AND_DATA_META][
        META_HALO_VALID]