コード例 #1
0
def test_generation(name, backend):
    """Compile the stencil registered under *name* for *backend* and run it once.

    Each annotated argument of the stencil definition is prepared as follows:
    field descriptors get a ``gt_storage.ones`` storage of shape (23, 23, 23)
    with default origin (10, 10, 10); any other annotation is called with
    ``1.5`` to produce the argument value. The stencil is then executed on a
    3x3x16 domain starting at (10, 10, 5).
    """
    definition = stencil_definitions[name]
    registered_externals = externals_registry[name]
    compiled = gtscript.stencil(backend, definition, externals=registered_externals)

    call_args = {}
    for arg_name, annotation in definition.__annotations__.items():
        if not isinstance(annotation, gtscript._FieldDescriptor):
            call_args[arg_name] = annotation(1.5)
            continue
        # Descriptors with extra data dimensions are allocated with a
        # (dtype, data_dims) pair instead of the plain dtype.
        if annotation.data_dims:
            storage_dtype = (annotation.dtype, annotation.data_dims)
        else:
            storage_dtype = annotation.dtype
        call_args[arg_name] = gt_storage.ones(
            dtype=storage_dtype,
            mask=gtscript.mask_from_axes(annotation.axes),
            backend=backend,
            shape=(23, 23, 23),
            default_origin=(10, 10, 10),
        )

    # vertical domain size >= 16 required for test_large_k_interval
    compiled(**call_args, origin=(10, 10, 5), domain=(3, 3, 16))
コード例 #2
0
def test_generation_gpu(name, backend):
    """Compile the stencil registered under *name* for *backend* and execute it.

    Field-annotated arguments receive a ``gt_storage.ones`` storage of shape
    (23, 23, 23) with default origin (10, 10, 10); every other annotated
    argument is obtained by calling its annotation with ``1.5``. The stencil
    runs on a 3x3x3 domain starting at (10, 10, 10).
    """
    definition = stencil_definitions[name]
    registered_externals = externals_registry[name]
    compiled = gtscript.stencil(backend, definition, externals=registered_externals)

    def _build_argument(annotation):
        # Anything that is not a field descriptor is called to produce a value.
        if not isinstance(annotation, gtscript._FieldDescriptor):
            return annotation(1.5)
        # NOTE(review): only the descriptor's plain dtype is used here; if
        # descriptors can carry extra data dimensions, confirm none are needed.
        return gt_storage.ones(
            dtype=annotation.dtype,
            mask=gtscript.mask_from_axes(annotation.axes),
            backend=backend,
            shape=(23, 23, 23),
            default_origin=(10, 10, 10),
        )

    call_args = {
        arg_name: _build_argument(annotation)
        for arg_name, annotation in definition.__annotations__.items()
    }
    compiled(**call_args, origin=(10, 10, 10), domain=(3, 3, 3))
コード例 #3
0
ファイル: utils.py プロジェクト: tehrengruber/gt4py
def make_definition(
    stencil_name: str,
    definition_func: types.FunctionType,
    args_list: list,
    fields_with_storage_descriptor: dict,
    temp_fields_with_type: dict,
    parameters_with_type: dict,
    domain=None,
    externals=None,
    sources=None,
):
    """Assemble a ``StencilDefinition`` from a Python definition function.

    Args:
        stencil_name: Name given to the resulting definition.
        definition_func: Function handed to ``make_computations`` to build the
            definition's computations.
        args_list: Passed to ``make_api_signature`` to build the API signature.
        fields_with_storage_descriptor: Maps API field names to descriptors
            exposing ``dtype`` and ``axes``.
        temp_fields_with_type: Maps temporary field names to their dtype.
        parameters_with_type: Maps parameter names either to a dtype or to a
            ``(dtype, length)`` tuple; a bare dtype is declared with length 0.
        domain: Domain to use; a falsy value is replaced by
            ``Domain.LatLonGrid()``.
        externals: External symbols; a falsy value is replaced by ``{}``.
        sources: Sources mapping; a falsy value is replaced by ``{}``.

    Returns:
        The constructed ``StencilDefinition``.

    Raises:
        AssertionError: If a tuple parameter spec does not contain exactly two
            items (only while assertions are enabled).
    """
    api_signature = make_api_signature(args_list)
    if not domain:
        domain = Domain.LatLonGrid()
    if not externals:
        externals = {}
    if not sources:
        sources = {}

    # NOTE(review): the indices collected below are those where
    # mask_from_axes(...) is truthy; confirm this matches what
    # make_field_decl expects as "masked" axes.
    api_field_decls = [
        make_field_decl(
            name=field_name,
            dtype=descriptor.dtype,
            masked_axes=[
                axis_index
                for axis_index, flag in enumerate(gtscript.mask_from_axes(descriptor.axes))
                if flag
            ],
            is_api=True,
            layout_id=field_name,
        )
        for field_name, descriptor in fields_with_storage_descriptor.items()
    ]

    temp_decls = {
        temp_name: make_field_decl(name=temp_name, dtype=temp_dtype, is_api=False)
        for temp_name, temp_dtype in temp_fields_with_type.items()
    }

    param_decls = []
    for param_name, spec in parameters_with_type.items():
        if not isinstance(spec, tuple):
            # A bare dtype declares a parameter with length 0.
            data_type, length = spec, 0
        else:
            assert len(spec) == 2
            data_type, length = spec[0], spec[1]
        param_decls.append(
            VarDecl(
                name=param_name,
                data_type=DataType.from_dtype(data_type),
                length=length,
                is_api=True,
            )
        )

    computations = make_computations(
        definition_func,
        fields={decl.name: decl for decl in api_field_decls},
        parameters={decl.name: decl for decl in param_decls},
        local_symbols=None,
        externals=externals,
        domain=domain,
        extra_temp_decls=temp_decls,
    )

    return StencilDefinition(
        name=stencil_name,
        domain=domain,
        api_signature=api_signature,
        api_fields=api_field_decls,
        parameters=param_decls,
        computations=computations,
        externals=externals,
        sources=sources,
    )