예제 #1
0
def compile_definition(
    definition_func,
    name: str,
    module: str,
    *,
    externals: Optional[dict] = None,
    dtypes: Optional[dict] = None,
    rebuild=False,
    **kwargs,
):
    """Parse ``definition_func`` with the GTScript frontend and return its stencil id.

    The argument dtypes of ``definition_func`` are set through
    ``gtscript._set_arg_dtypes`` only for the duration of this call; the
    function's original ``__annotations__`` are restored before returning,
    also when parsing raises.

    Args:
        definition_func: GTScript stencil definition function.
        name: Stencil name, forwarded to ``gt_definitions.BuildOptions``.
        module: Module name, forwarded to ``gt_definitions.BuildOptions``.
        externals: External symbols; forwarded as-is to
            ``frontend.get_stencil_id`` and as ``externals or {}`` to the parser.
        dtypes: Mapping forwarded to ``gtscript._set_arg_dtypes``
            (``None`` is treated as an empty mapping).
        rebuild: Forwarded to ``gt_definitions.BuildOptions``.
        **kwargs: Forwarded as ``backend_opts`` to ``gt_definitions.BuildOptions``.

    Returns:
        The id computed by ``frontend.get_stencil_id``.
    """
    # Snapshot the annotations ourselves instead of relying on the return value
    # of ``_set_arg_dtypes``: it replaces placeholder annotations with concrete
    # dtypes, and leaving them replaced would affect later uses of the same
    # definition function (e.g. a second call with different ``dtypes``).
    original_annotations = dict(getattr(definition_func, "__annotations__", {}))
    try:
        gtscript._set_arg_dtypes(definition_func, dtypes=dtypes or {})
        build_options = gt_definitions.BuildOptions(
            name=name,
            module=module,
            rebuild=rebuild,
            backend_opts=kwargs,
            build_info=None,
        )

        options_id = gt_utils.shashed_id(build_options)
        stencil_id = frontend.get_stencil_id(
            build_options.qualified_name, definition_func, externals, options_id
        )
        gt_frontend.GTScriptParser(
            definition_func, externals=externals or {}, options=build_options
        ).run()
    finally:
        definition_func.__annotations__ = original_annotations

    return stencil_id


def parse_definition(
    definition_func: Callable[..., None],
    *,
    name: str,
    module: str,
    externals: Optional[Dict[str, Any]] = None,
    dtypes: Optional[Dict[Type, Type]] = None,
    rebuild=False,
    **kwargs,
):
    """Run the GTScript frontend on ``definition_func`` and return the parser result.

    The argument dtypes of ``definition_func`` are set through
    ``gtscript._set_arg_dtypes`` for the duration of parsing. The original
    annotations returned by that call are written back to
    ``definition_func.__annotations__`` afterwards, also when preparation or
    parsing raises, so the definition function is left as it was found.

    Args:
        definition_func: GTScript stencil definition function.
        name: Stencil name, forwarded to ``gt_definitions.BuildOptions``.
        module: Module name, forwarded to ``gt_definitions.BuildOptions``.
        externals: External symbols for the frontend (``None`` means none).
        dtypes: Mapping forwarded to ``gtscript._set_arg_dtypes``
            (``None`` is treated as an empty mapping).
        rebuild: Forwarded to ``gt_definitions.BuildOptions``.
        **kwargs: Forwarded as ``backend_opts`` to ``gt_definitions.BuildOptions``.

    Returns:
        Whatever ``gt_frontend.GTScriptParser(...).run()`` returns (the
        definition IR).
    """
    original_annotations = gtscript._set_arg_dtypes(definition_func, dtypes=dtypes or {})

    try:
        build_options = gt_definitions.BuildOptions(
            name=name,
            module=module,
            rebuild=rebuild,
            backend_opts=kwargs,
            build_info=None,
        )

        gt_frontend.GTScriptFrontend.prepare_stencil_definition(
            definition_func, externals=externals or {}
        )
        definition_ir = gt_frontend.GTScriptParser(
            definition_func, externals=externals or {}, options=build_options
        ).run()
    finally:
        # Restore unconditionally: a parse error must not leave the
        # definition function with the substituted dtypes.
        definition_func.__annotations__ = original_annotations

    return definition_ir
예제 #3
0
    def test_set_arg_dtypes(self, dtype_in, dtype_out, dtype_scalar):
        """Check that ``gtscript._set_arg_dtypes`` substitutes dtype placeholders.

        ``self.sumdiff_defs`` carries placeholder dtype names (``"dtype_in"``,
        ``"dtype_out"``, ``"dtype_scalar"``) next to concrete types. After the
        call, the returned original annotations must still hold the
        placeholders, while the definition's current annotations must hold the
        concrete dtypes passed in. The original annotations are put back
        afterwards, even if an assertion fails.
        """
        dtypes = {
            "dtype_in": dtype_in,
            "dtype_out": dtype_out,
            "dtype_scalar": dtype_scalar,
        }

        definition, original_annotations = gtscript._set_arg_dtypes(
            self.sumdiff_defs, dtypes
        )

        def assert_annotations(annotations, field_dtypes, scalar_annotations):
            # Field arguments are ``_FieldDescriptor`` instances; compare their dtype.
            for arg_name, expected_dtype in field_dtypes.items():
                assert arg_name in annotations
                assert isinstance(annotations[arg_name], gtscript._FieldDescriptor)
                assert annotations[arg_name].dtype == expected_dtype
            # Scalar arguments are annotated directly with the type or placeholder.
            for arg_name, expected in scalar_annotations.items():
                assert arg_name in annotations
                assert annotations[arg_name] == expected
            # No arguments beyond the expected ones.
            assert len(annotations) == len(field_dtypes) + len(scalar_annotations)

        try:
            # Before substitution: placeholders are still the plain names.
            assert_annotations(
                original_annotations,
                {
                    "in_a": "dtype_in",
                    "in_b": "dtype_in",
                    "out_c": "dtype_out",
                    "out_d": float,
                },
                {"wa": "dtype_scalar", "wb": int},
            )
            # After substitution: placeholders replaced by the requested dtypes.
            assert_annotations(
                getattr(definition, "__annotations__", {}),
                {
                    "in_a": dtype_in,
                    "in_b": dtype_in,
                    "out_c": dtype_out,
                    "out_d": float,
                },
                {"wa": dtype_scalar, "wb": int},
            )
        finally:
            # Restore even on assertion failure so the (possibly shared)
            # definition is not left with concrete dtypes for later cases.
            setattr(definition, "__annotations__", original_annotations)