def compile_definition(
    definition_func,
    name: str,
    module: str,
    *,
    externals: dict = None,
    dtypes: dict = None,
    rebuild=False,
    **kwargs,
):
    """Resolve dtype placeholders on a stencil definition, parse it, and return its stencil id.

    The argument annotations of ``definition_func`` are resolved in place via
    ``gtscript._set_arg_dtypes`` for the duration of the call and are always
    restored afterwards, so the same definition can be compiled again with a
    different ``dtypes`` mapping.

    Args:
        definition_func: The GTScript stencil definition function.
        name: Stencil name, forwarded to ``BuildOptions``.
        module: Module name, forwarded to ``BuildOptions``.
        externals: External symbols for the definition. Passed unchanged to
            ``frontend.get_stencil_id`` and as ``externals or {}`` to the parser.
        dtypes: Mapping from dtype placeholder to concrete type; ``None`` means
            no substitutions.
        rebuild: Forwarded to ``BuildOptions``.
        **kwargs: Forwarded to ``BuildOptions`` as ``backend_opts``.

    Returns:
        The stencil id computed by ``frontend.get_stencil_id``.

    Raises:
        Whatever ``_set_arg_dtypes``, ``get_stencil_id`` or the parser raise.
        The original annotations are restored before the exception propagates
        (if ``_set_arg_dtypes`` itself fails, nothing is restored).
    """
    # Take our own snapshot instead of relying on _set_arg_dtypes's return
    # value; its exact shape is not visible from here.
    original_annotations = dict(getattr(definition_func, "__annotations__", {}))
    gtscript._set_arg_dtypes(definition_func, dtypes=dtypes or {})
    try:
        build_options = gt_definitions.BuildOptions(
            name=name, module=module, rebuild=rebuild, backend_opts=kwargs, build_info=None
        )
        options_id = gt_utils.shashed_id(build_options)
        # NOTE(review): get_stencil_id receives `externals` as given (possibly
        # None), while the parser gets `externals or {}`; kept as in the original.
        stencil_id = frontend.get_stencil_id(
            build_options.qualified_name, definition_func, externals, options_id
        )
        gt_frontend.GTScriptParser(
            definition_func, externals=externals or {}, options=build_options
        ).run()
    finally:
        # Undo the in-place dtype resolution so later calls on the same
        # definition see the original placeholders again.
        setattr(definition_func, "__annotations__", original_annotations)
    return stencil_id
def parse_definition(
    definition_func: Callable[..., None],
    *,
    name: str,
    module: str,
    externals: Optional[Dict[str, Any]] = None,
    dtypes: Optional[Dict[str, Type]] = None,
    rebuild=False,
    **kwargs,
):
    """Resolve dtype placeholders on a stencil definition and parse it into definition IR.

    The argument annotations of ``definition_func`` are resolved in place via
    ``gtscript._set_arg_dtypes`` for the duration of the call and are restored
    afterwards, whether parsing succeeds or fails.

    Args:
        definition_func: The GTScript stencil definition function.
        name: Stencil name, forwarded to ``BuildOptions``.
        module: Module name, forwarded to ``BuildOptions``.
        externals: External symbols for the definition; ``None`` means none.
        dtypes: Mapping from dtype placeholder (e.g. ``"dtype_in"``) to
            concrete type; ``None`` means no substitutions.
        rebuild: Forwarded to ``BuildOptions``.
        **kwargs: Forwarded to ``BuildOptions`` as ``backend_opts``.

    Returns:
        The result of ``GTScriptParser(...).run()``.

    Raises:
        Whatever ``_set_arg_dtypes``, ``prepare_stencil_definition`` or the
        parser raise. The original annotations are restored before a
        prepare/parse exception propagates.
    """
    # Take our own snapshot instead of relying on _set_arg_dtypes's return
    # value, whose shape is not visible from here.
    original_annotations = dict(getattr(definition_func, "__annotations__", {}))
    gtscript._set_arg_dtypes(definition_func, dtypes=dtypes or {})
    try:
        build_options = gt_definitions.BuildOptions(
            name=name,
            module=module,
            rebuild=rebuild,
            backend_opts=kwargs,
            build_info=None,
        )
        gt_frontend.GTScriptFrontend.prepare_stencil_definition(
            definition_func, externals=externals or {}
        )
        definition_ir = gt_frontend.GTScriptParser(
            definition_func, externals=externals or {}, options=build_options
        ).run()
    finally:
        # Restore even on failure so the definition is not left with
        # resolved dtypes for subsequent calls.
        setattr(definition_func, "__annotations__", original_annotations)
    return definition_ir
def test_set_arg_dtypes(self, dtype_in, dtype_out, dtype_scalar):
    """Check that ``gtscript._set_arg_dtypes`` resolves dtype placeholders.

    Field arguments whose dtype is a placeholder string, and scalar arguments
    annotated with a bare placeholder string, must be replaced by the concrete
    types from ``dtypes``. Concretely annotated arguments (``out_d: float``
    field, ``wb: int``) must be left untouched. The returned original
    annotations must still contain the unresolved placeholders.
    """
    definition = self.sumdiff_defs
    dtypes = {"dtype_in": dtype_in, "dtype_out": dtype_out, "dtype_scalar": dtype_scalar}

    definition, original_annotations = gtscript._set_arg_dtypes(definition, dtypes)
    try:
        # arg name -> (dtype before resolution, dtype expected after resolution)
        field_args = {
            "in_a": ("dtype_in", dtype_in),
            "in_b": ("dtype_in", dtype_in),
            "out_c": ("dtype_out", dtype_out),
            "out_d": (float, float),
        }
        scalar_args = {
            "wa": ("dtype_scalar", dtype_scalar),
            "wb": (int, int),
        }
        expected_count = len(field_args) + len(scalar_args)

        # The returned original annotations keep the unresolved placeholders.
        for arg, (before, _) in field_args.items():
            assert arg in original_annotations, arg
            assert isinstance(original_annotations[arg], gtscript._FieldDescriptor), arg
            assert original_annotations[arg].dtype == before, arg
        for arg, (before, _) in scalar_args.items():
            assert arg in original_annotations, arg
            assert original_annotations[arg] == before, arg
        assert len(original_annotations) == expected_count

        # The definition's live annotations carry the resolved types.
        annotations = getattr(definition, "__annotations__", {})
        for arg, (_, after) in field_args.items():
            assert arg in annotations, arg
            assert isinstance(annotations[arg], gtscript._FieldDescriptor), arg
            assert annotations[arg].dtype == after, arg
        for arg, (_, after) in scalar_args.items():
            assert arg in annotations, arg
            assert annotations[arg] == after, arg
        assert len(annotations) == expected_count
    finally:
        # Restore even when an assertion fails: the definition comes from the
        # test class and is presumably reused by other parametrized runs, which
        # would otherwise see already-resolved dtypes.
        setattr(definition, "__annotations__", original_annotations)