Ejemplo n.º 1
0
    def test_vector_2d(halo, n_threads):
        """Run the ``fill_halos_vector`` kernel on a small 2D vector field.

        The field has two components on a ``(4, 2)`` grid.
        - The first component has shape ``(5, 2)``: one extra point along axis 0.
        - The second has shape ``(4, 3)``: one extra point along axis 1.

        The boundary-condition tuple pairs ``Periodic()`` with a ``Polar``
        condition (``longitude_idx=OUTER``, ``latitude_idx=INNER``).

        NOTE(review): no assertions follow the kernel calls here, so this test
        only checks that the kernel runs without raising for every thread id.
        """
        # arrange
        grid = (4, 2)
        component_0 = np.array(
            [[1, 6], [2, 7], [3, 8], [4, 9], [5, 10]],
            dtype=float,
        )
        component_1 = np.array(
            [[1, 5, 9], [2, 6, 10], [3, 7, 11], [4, 8, 12]],
            dtype=float,
        )
        periodic_bc = Periodic()
        polar_bc = Polar(grid=grid, longitude_idx=OUTER, latitude_idx=INNER)
        field = VectorField(
            (component_0, component_1),
            halo,
            (periodic_bc, polar_bc),
        )
        traversals = Traversals(
            grid=grid,
            halo=halo,
            jit_flags=JIT_FLAGS,
            n_threads=n_threads,
        )
        field.assemble(traversals)
        meta_and_data, fill_halos = field.impl
        # pylint: disable-next=protected-access
        kernel = traversals._code['fill_halos_vector']

        # act
        # pylint: disable-next=not-an-iterable
        for tid in numba.prange(n_threads):
            kernel(tid, *meta_and_data, *fill_halos)
Ejemplo n.º 2
0
    def test_scalar_2d(halo, n_threads):
        """Check the polar halo filling of a 2D scalar field.

        After the ``fill_halos_scalar`` kernel runs, the halo columns at both
        ends of the INNER axis are compared with the interior columns at the
        same end, rolled along OUTER by half the field's OUTER extent.

        NOTE(review): the expected values keep the interior column order
        unchanged; confirm this is intended for every parametrised halo width
        (it is unambiguous for halo == 1).
        """
        # arrange
        data = np.array([[1, 6], [2, 7], [3, 8], [4, 9]], dtype=float)
        grid = data.shape
        periodic_bc = Periodic()
        polar_bc = Polar(grid=grid, longitude_idx=OUTER, latitude_idx=INNER)
        field = ScalarField(data, halo, (periodic_bc, polar_bc))
        traversals = Traversals(
            grid=grid,
            halo=halo,
            jit_flags=JIT_FLAGS,
            n_threads=n_threads,
        )
        field.assemble(traversals)
        meta_and_data, fill_halos = field.impl
        # pylint: disable-next=protected-access
        kernel = traversals._code['fill_halos_scalar']

        # act
        # pylint: disable-next=not-an-iterable
        for tid in numba.prange(n_threads):
            kernel(tid, *meta_and_data, *fill_halos)

        # assert: leading edge first, then trailing edge, each rolled by half
        # the extent along OUTER (the axis passed as longitude_idx)
        shift = data.shape[OUTER] // 2
        for edge in (slice(None, halo), slice(-halo, None)):
            np.testing.assert_array_equal(
                field.data[halo:-halo, edge],
                np.roll(field.get()[:, edge], shift, axis=OUTER),
            )
Ejemplo n.º 3
0
    def test_1d_vector(data, n_threads=1, halo=2):
        """Run ``fill_halos_vector`` once on a 1D vector field.

        The field has a single component built from *data* and uses an
        ``Extrapolated`` boundary condition.

        NOTE(review): nothing is asserted. The resulting halo-padded array is
        only printed, so this test fails only if the kernel raises.
        """
        # arrange
        field = VectorField((data, ), halo, (Extrapolated(), ))
        traversals = Traversals(
            grid=field.grid,
            halo=halo,
            jit_flags=JIT_FLAGS,
            n_threads=n_threads,
        )
        field.assemble(traversals)
        meta_and_data, fill_halos = field.impl
        # pylint: disable-next=protected-access
        kernel = traversals._code['fill_halos_vector']

        # act: a single call for thread 0 only (no loop over n_threads)
        kernel(0, *meta_and_data, *fill_halos)

        # assert
        print(field.data)
def test_formulae_upwind():
    """Apply the upwind formula once to a 3-point periodic scalar field.

    The scalar is ``(0, 1, 0)`` and the single-component flux is
    ``(0, 0, 1, 0)``. After the call, ``psi`` must equal the initial data
    rolled by one position.
    """
    # Arrange
    psi_data = np.array((0, 1, 0))
    flux_data = np.array((0, 0, 1, 0))

    options = Options()
    halo = options.n_halo
    traversals = Traversals(
        grid=psi_data.shape,
        halo=halo,
        jit_flags=options.jit_flags,
        n_threads=1,
    )
    upwind = make_upwind(
        options=options,
        non_unit_g_factor=False,
        traversals=traversals,
    )

    # one shared tuple of boundary conditions for both fields
    bcs = (Periodic(), )

    psi = ScalarField(psi_data, halo, bcs)
    psi.assemble(traversals)
    psi_impl = psi.impl

    flux = VectorField((flux_data, ), halo, bcs)
    flux.assemble(traversals)
    flux_impl = flux.impl

    def as_impl(impl):
        # repackage an impl pair (meta-and-data, bc) as an _Impl
        return _Impl(field=impl[IMPL_META_AND_DATA], bc=impl[IMPL_BC])

    # Act
    with warnings.catch_warnings():
        warnings.simplefilter(
            'ignore', category=NumbaExperimentalFeatureWarning)
        upwind(
            traversals.null_impl,
            as_impl(psi_impl),
            as_impl(flux_impl),
            as_impl(traversals.null_impl.scalar),
        )

    # Assert
    expected = np.roll(psi_data, 1)
    np.testing.assert_array_equal(psi.get(), expected)
def make_traversals(grid, halo, n_threads):
    """Build a ``Traversals`` object for the given grid, halo width and thread count.

    The JIT flags are always taken from the module-level ``JIT_FLAGS``.
    """
    traversals = Traversals(
        grid=grid,
        halo=halo,
        jit_flags=JIT_FLAGS,
        n_threads=n_threads,
    )
    return traversals