def test_apply_scalar(n_threads, halo, grid, loop):
    """Run the ``apply_scalar`` traversal and verify every cell of the output.

    The traversal built from ``grid``, ``halo`` and ``n_threads`` is called
    with ``_cell_id_scalar`` (passed two more times when ``loop`` is truthy,
    ``None`` otherwise), the output field, and null vector/scalar arguments.
    Each cell read back through ``indexers`` must then equal
    ``cell_id(i, j, k)``, multiplied by ``n_dims`` when ``loop`` is truthy.
    Finally the ``META_HALO_VALID`` flags are checked: truthy on both null
    fields, falsy on the output field.

    Only 1D and 2D grids are verified; any other dimensionality raises
    ``NotImplementedError`` while the cells are being checked. A 1D grid
    combined with more than one thread returns without testing anything.
    """
    n_dims = len(grid)
    skip_case = n_dims == 1 and n_threads > 1
    if skip_case:
        return

    # arrange
    traversal_factory = Traversals(grid, halo, jit_flags, n_threads)
    sut = traversal_factory.apply_scalar(loop=loop)

    null_scalar = ScalarField.make_null(n_dims).impl
    null_vector = VectorField.make_null(n_dims).impl

    zeros = np.zeros(grid)
    boundary_conditions = [ConstantBoundaryCondition(np.nan)] * n_dims
    out = ScalarField(zeros, halo, boundary_conditions)

    # act
    optional_fun = _cell_id_scalar if loop else None
    vector_null_args = (
        *null_vector[IMPL_META_AND_DATA],
        *null_vector[IMPL_BC],
    )
    scalar_null_args = (
        *null_scalar[IMPL_META_AND_DATA],
        *null_scalar[IMPL_BC],
    )
    sut(
        _cell_id_scalar,
        optional_fun,
        optional_fun,
        *out.impl[IMPL_META_AND_DATA],
        *vector_null_args,
        # the null scalar argument group is passed three times in a row
        *scalar_null_args,
        *scalar_null_args,
        *scalar_null_args,
    )

    # assert
    data = out.get()
    assert data.shape == grid

    focus = (-halo,) * 3
    factor = n_dims if loop else 1
    at_key = INNER if n_dims == 1 else OUTER

    # axes absent from the grid are walked once with the INVALID_INDEX placeholder
    if n_dims > 1:
        outer_indices = range(halo, halo + grid[OUTER])
    else:
        outer_indices = (INVALID_INDEX,)
    if n_dims > 2:
        mid_indices = range(halo, halo + grid[MID3D])
    else:
        mid_indices = (INVALID_INDEX,)
    inner_indices = range(halo, halo + grid[INNER])

    for i in outer_indices:
        for j in mid_indices:
            for k in inner_indices:
                if n_dims not in (1, 2):
                    raise NotImplementedError()
                if n_dims == 1:
                    ijk = (k, INVALID_INDEX, INVALID_INDEX)
                else:
                    ijk = (i, k, INVALID_INDEX)
                value = indexers[n_dims].at[at_key](focus, data, *ijk)
                assert factor * cell_id(i, j, k) == value

    scalar_meta = null_scalar[IMPL_META_AND_DATA][META_AND_DATA_META]
    assert scalar_meta[META_HALO_VALID]
    vector_meta = null_vector[IMPL_META_AND_DATA][META_AND_DATA_META]
    assert vector_meta[META_HALO_VALID]
    out_meta = out.impl[IMPL_META_AND_DATA][META_AND_DATA_META]
    assert not out_meta[META_HALO_VALID]
def make_commons(grid, halo, num_threads):
    """Bundle the objects shared by the traversal tests into one record.

    Returns a ``_Commons`` named tuple with the fields ``n_dims``
    (``len(grid)``), ``traversals`` (from ``make_traversals``),
    ``scl_null_arg_impl`` and ``vec_null_arg_impl`` (the ``impl`` of the null
    scalar/vector fields) and ``halos``.
    """
    traversals = make_traversals(grid, halo, num_threads)
    dimensionality = len(grid)

    # row ``r`` holds ``halo`` everywhere except position ``r``, which is ``halo - 1``
    halos = tuple(
        tuple(halo - 1 if axis == row else halo for axis in range(3))
        for row in range(3)
    )

    field_names = (
        'n_dims',
        'traversals',
        'scl_null_arg_impl',
        'vec_null_arg_impl',
        'halos',
    )
    record_type = namedtuple('_Commons', field_names)

    null_scalar_impl = ScalarField.make_null(dimensionality, traversals).impl
    null_vector_impl = VectorField.make_null(dimensionality, traversals).impl

    # positional order matches ``field_names``
    return record_type(
        dimensionality,
        traversals,
        null_scalar_impl,
        null_vector_impl,
        halos,
    )
def test_apply_vector(n_threads, halo, grid):
    """Run the ``apply_vector`` traversal and verify every cell of each component.

    The traversal built from ``grid``, ``halo`` and ``n_threads`` is called
    with ``_cell_id_vector`` repeated ``MAX_DIM_NUM`` times, the output vector
    field, and null scalar/vector arguments. For every component of the
    output, each cell read back through ``indexers`` must equal
    ``cell_id(i, j, k)``. Finally the ``META_HALO_VALID`` flags are checked:
    truthy on both null fields, falsy on the output field.

    A 1D grid combined with more than one thread returns without testing
    anything.

    Raises:
        NotImplementedError: if the grid is neither 1D nor 2D (3D support is
            still a TODO).
    """
    n_dims = len(grid)
    if n_dims == 1 and n_threads > 1:
        return

    # arrange
    traversals = Traversals(grid, halo, jit_flags, n_threads)
    sut = traversals.apply_vector()

    scl_null_arg_impl = ScalarField.make_null(n_dims).impl
    vec_null_arg_impl = VectorField.make_null(n_dims).impl

    # each component array has one extra element along its own axis
    if n_dims == 1:
        data = (np.zeros(grid[0] + 1),)
    elif n_dims == 2:
        data = (
            np.zeros((grid[0] + 1, grid[1])),
            np.zeros((grid[0], grid[1] + 1)),
        )
    else:
        # TODO: add 3D support. Fail explicitly here: the former ``pass``
        # branch for 3D left ``data`` unbound and surfaced as an
        # UnboundLocalError on the next statement.
        raise NotImplementedError(
            "test_apply_vector covers 1D and 2D grids only"
        )

    out = VectorField(data, halo, [ConstantBoundaryCondition(np.nan)] * n_dims)

    # act
    sut(
        *([_cell_id_vector] * MAX_DIM_NUM),
        *out.impl[IMPL_META_AND_DATA],
        *scl_null_arg_impl[IMPL_META_AND_DATA],
        *scl_null_arg_impl[IMPL_BC],
        *vec_null_arg_impl[IMPL_META_AND_DATA],
        *vec_null_arg_impl[IMPL_BC],
        *scl_null_arg_impl[IMPL_META_AND_DATA],
        *scl_null_arg_impl[IMPL_BC],
    )

    # assert
    halos = (
        (halo - 1, halo, halo),
        (halo, halo - 1, halo),
        (halo, halo, halo - 1),
    )
    # n_dims is guaranteed to be 1 or 2 at this point (checked above)
    dims = (INNER,) if n_dims == 1 else (OUTER, INNER)
    at_key = INNER if n_dims == 1 else OUTER

    for d in dims:
        component = out.get_component(d)
        component_halo = halos[d]
        focus = tuple(-component_halo[axis] for axis in range(MAX_DIM_NUM))

        # axes absent from the grid are walked once with the INVALID_INDEX placeholder
        if n_dims > 1:
            outer_indices = range(
                component_halo[OUTER],
                component_halo[OUTER] + component.shape[OUTER],
            )
        else:
            outer_indices = (INVALID_INDEX,)
        if n_dims > 2:
            mid_indices = range(
                component_halo[MID3D],
                component_halo[MID3D] + component.shape[MID3D],
            )
        else:
            mid_indices = (INVALID_INDEX,)
        inner_indices = range(
            component_halo[INNER],
            component_halo[INNER] + component.shape[INNER],
        )

        for i in outer_indices:
            for j in mid_indices:
                for k in inner_indices:
                    if n_dims == 1:
                        ijk = (k, INVALID_INDEX, INVALID_INDEX)
                    else:
                        ijk = (i, k, INVALID_INDEX)
                    value = indexers[n_dims].at[at_key](focus, component, *ijk)
                    # the message replaces the former debug prints
                    assert cell_id(i, j, k) == value, (
                        "DIM {}, focus {}, check at {} {} {}".format(
                            d, focus, i, j, k
                        )
                    )

    assert scl_null_arg_impl[IMPL_META_AND_DATA][META_AND_DATA_META][
        META_HALO_VALID]
    assert vec_null_arg_impl[IMPL_META_AND_DATA][META_AND_DATA_META][
        META_HALO_VALID]
    assert not out.impl[IMPL_META_AND_DATA][META_AND_DATA_META][
        META_HALO_VALID]