def make_args_data(definition_ir: gt_ir.StencilDefinition, sir_field_info: Dict[str, Any]) -> Dict[str, Any]:
    """Build per-argument metadata for a stencil from its definition IR.

    Args:
        definition_ir: Stencil definition whose ``api_signature`` is walked;
            ``api_fields`` and ``parameters`` supply the dtypes.
        sir_field_info: Per-field analysis results keyed by field name; each
            entry is read for its ``"access"`` and ``"extent"`` items.

    Returns:
        A dict with the keys:

        * ``"field_info"``: field name -> ``gt_definitions.FieldInfo``.
        * ``"parameter_info"``: parameter name -> ``gt_definitions.ParameterInfo``.
        * ``"unreferenced"``: list of field names whose recorded access is None.
    """
    data: Dict[str, Any] = {
        "field_info": {},
        "parameter_info": {},
        # A list: field names are appended to it below.
        "unreferenced": [],
    }
    fields = {item.name: item for item in definition_ir.api_fields}
    parameters = {item.name: item for item in definition_ir.parameters}
    for arg in definition_ir.api_signature:
        if arg.name in fields:
            access = sir_field_info[arg.name]["access"]
            if access is None:
                # No recorded access: report the field as read-only and list it
                # as unreferenced.
                access = gt_definitions.AccessKind.READ_ONLY
                data["unreferenced"].append(arg.name)
            extent = sir_field_info[arg.name]["extent"]
            # The Boundary is built from each extent pair with its first
            # component negated.
            boundary = gt_definitions.Boundary([(-pair[0], pair[1]) for pair in extent])
            data["field_info"][arg.name] = gt_definitions.FieldInfo(
                access=access, dtype=fields[arg.name].data_type.dtype, boundary=boundary
            )
        else:
            data["parameter_info"][arg.name] = gt_definitions.ParameterInfo(
                dtype=parameters[arg.name].data_type.dtype
            )
    return data
def make_args_data(definition_ir: gt_ir.StencilDefinition, sir_field_info: Dict[str, Any]) -> ModuleData:
    """Collect field and parameter metadata for every stencil argument.

    Args:
        definition_ir: Stencil definition; its ``api_signature`` determines
            which arguments are described, in order.
        sir_field_info: Per-field analysis results keyed by field name; each
            entry provides ``"access"`` and ``"extent"`` items.

    Returns:
        A ``ModuleData`` whose ``field_info``, ``parameter_info`` and
        ``unreferenced`` members are filled in for each argument.
    """
    data = ModuleData()
    api_fields = {field.name: field for field in definition_ir.api_fields}
    api_params = {param.name: param for param in definition_ir.parameters}

    for arg in definition_ir.api_signature:
        name = arg.name

        # Scalar parameters only carry a dtype.
        if name not in api_fields:
            data.parameter_info[name] = gt_definitions.ParameterInfo(
                dtype=api_params[name].data_type.dtype
            )
            continue

        field_access = sir_field_info[name]["access"]
        if field_access is None:
            # Missing access information: fall back to read-only and record
            # the field as unreferenced.
            field_access = gt_definitions.AccessKind.READ_ONLY
            data.unreferenced.add(name)

        field_extent = sir_field_info[name]["extent"]
        field_boundary = gt_definitions.Boundary(
            [(-first, second) for first, second, *_ in (tuple(p) for p in field_extent)]
        ) if False else gt_definitions.Boundary(
            [(-pair[0], pair[1]) for pair in field_extent]
        )
        declaration = api_fields[name]
        data.field_info[name] = gt_definitions.FieldInfo(
            access=field_access,
            boundary=field_boundary,
            axes=declaration.axes,
            dtype=declaration.data_type.dtype,
        )

    return data
def _generate_module_info(cls, definition_ir, options, field_info) -> Dict[str, Any]:
    """Assemble the metadata dict used when generating a stencil module.

    Args:
        definition_ir: Stencil definition IR providing the sources, domain,
            docstring, API signature, fields, parameters and externals.
        options: Build options; the entries of ``options.as_dict()`` other than
            ``"build_info"`` are copied into ``"gt_options"``.
        field_info: Per-field analysis results keyed by field name; each entry
            is read for its ``"access"`` and ``"extent"`` items.

    Returns:
        A dict with the keys ``"sources"``, ``"domain_info"`` (the repr of a
        ``DomainInfo``), ``"docstring"``, ``"field_info"``,
        ``"parameter_info"``, ``"unreferenced"``, ``"gt_constants"`` and
        ``"gt_options"``.
    """
    info: Dict[str, Any] = {}

    if definition_ir.sources is not None:
        # ``info`` starts empty, so the sources dict must be assigned, not
        # updated. ``dict(...).items()`` works whether ``sources`` is a mapping
        # or an iterable of (name, source) pairs.
        info["sources"] = {
            key: gt_utils.text.format_source(value, line_length=100)
            for key, value in dict(definition_ir.sources).items()
        }
    else:
        info["sources"] = {}

    parallel_axes = definition_ir.domain.parallel_axes or []
    sequential_axis = definition_ir.domain.sequential_axis.name
    domain_info = gt_definitions.DomainInfo(
        parallel_axes=tuple(ax.name for ax in parallel_axes),
        sequential_axis=sequential_axis,
        ndims=len(parallel_axes) + (1 if sequential_axis else 0),
    )
    info["domain_info"] = repr(domain_info)
    info["docstring"] = definition_ir.docstring

    info["field_info"] = {}
    info["parameter_info"] = {}
    info["unreferenced"] = []

    fields = {item.name: item for item in definition_ir.api_fields}
    parameters = {item.name: item for item in definition_ir.parameters}
    for arg in definition_ir.api_signature:
        if arg.name in fields:
            access = field_info[arg.name]["access"]
            if access is None:
                # No recorded access: report the field as read-only and list it
                # as unreferenced.
                access = gt_definitions.AccessKind.READ_ONLY
                info["unreferenced"].append(arg.name)
            extent = field_info[arg.name]["extent"]
            # The Boundary is built from each extent pair with its first
            # component negated.
            boundary = gt_definitions.Boundary([(-pair[0], pair[1]) for pair in extent])
            info["field_info"][arg.name] = gt_definitions.FieldInfo(
                access=access, dtype=fields[arg.name].data_type.dtype, boundary=boundary
            )
        else:
            info["parameter_info"][arg.name] = gt_definitions.ParameterInfo(
                dtype=parameters[arg.name].data_type.dtype
            )

    if definition_ir.externals:
        # Only numeric externals are exported, as their repr().
        info["gt_constants"] = {
            name: repr(value)
            for name, value in definition_ir.externals.items()
            if isinstance(value, numbers.Number)
        }
    else:
        info["gt_constants"] = {}

    info["gt_options"] = {
        key: value for key, value in options.as_dict().items() if key not in ["build_info"]
    }

    return info
def make_args_data_from_iir(
    implementation_ir: gt_ir.StencilImplementation) -> Dict[str, Any]:
    """Derive per-argument metadata from a stencil implementation IR.

    A field is reported as READ_WRITE when at least one stage accesses it
    through a ``FieldAccessor`` with ``READ_WRITE`` intent; otherwise it is
    READ_ONLY. Arguments listed in ``implementation_ir.unreferenced`` map
    to ``None``.

    Args:
        implementation_ir: The stencil implementation IR to inspect.

    Returns:
        A dict with the keys ``"field_info"``, ``"parameter_info"`` and
        ``"unreferenced"`` (the IR's own ``unreferenced`` object).
    """
    written_fields = {
        accessor.symbol
        for multi_stage in implementation_ir.multi_stages
        for group in multi_stage.groups
        for stage in group.stages
        for accessor in stage.accessors
        if isinstance(accessor, gt_ir.FieldAccessor)
        and accessor.intent == gt_ir.AccessIntent.READ_WRITE
    }

    field_entries: Dict[str, Any] = {}
    parameter_entries: Dict[str, Any] = {}

    for arg in implementation_ir.api_signature:
        name = arg.name
        is_unreferenced = name in implementation_ir.unreferenced

        if name not in implementation_ir.fields:
            if is_unreferenced:
                parameter_entries[name] = None
            else:
                parameter_entries[name] = gt_definitions.ParameterInfo(
                    dtype=implementation_ir.parameters[name].data_type.dtype)
            continue

        if is_unreferenced:
            field_entries[name] = None
            continue

        if name in written_fields:
            access_kind = gt_definitions.AccessKind.READ_WRITE
        else:
            access_kind = gt_definitions.AccessKind.READ_ONLY
        declaration = implementation_ir.fields[name]
        field_entries[name] = gt_definitions.FieldInfo(
            access=access_kind,
            boundary=implementation_ir.fields_extents[name].to_boundary(),
            axes=declaration.axes,
            dtype=declaration.data_type.dtype,
        )

    return {
        "field_info": field_entries,
        "parameter_info": parameter_entries,
        "unreferenced": implementation_ir.unreferenced,
    }
def _generate_module_info(self) -> Dict[str, Any]:
    """Assemble the metadata dict used when generating a stencil module.

    Reads ``self.definition_ir`` (for the sources), ``self.implementation_ir``
    (for everything else) and ``self.options``.

    Returns:
        A dict with the keys ``"sources"``, ``"docstring"``, ``"domain_info"``
        (the repr of a ``DomainInfo``), ``"field_info"``, ``"parameter_info"``,
        ``"gt_constants"``, ``"gt_options"`` and ``"unreferenced"``. Fields and
        parameters listed in ``implementation_ir.unreferenced`` map to ``None``.
    """
    info: Dict[str, Any] = {}
    implementation_ir = self.implementation_ir

    if self.definition_ir.sources is not None:
        # ``info`` starts empty, so the sources dict must be assigned, not
        # updated. ``dict(...).items()`` works whether ``sources`` is a mapping
        # or an iterable of (name, source) pairs.
        info["sources"] = {
            key: gt_utils.text.format_source(value, line_length=self.SOURCE_LINE_LENGTH)
            for key, value in dict(self.definition_ir.sources).items()
        }
    else:
        info["sources"] = {}

    info["docstring"] = implementation_ir.docstring

    parallel_axes = implementation_ir.domain.parallel_axes or []
    sequential_axis = implementation_ir.domain.sequential_axis.name
    info["domain_info"] = repr(
        gt_definitions.DomainInfo(
            parallel_axes=tuple(ax.name for ax in parallel_axes),
            sequential_axis=sequential_axis,
            ndims=len(parallel_axes) + (1 if sequential_axis else 0),
        )
    )

    info["field_info"] = field_info = {}
    info["parameter_info"] = parameter_info = {}

    # Collect the names of fields written by any stage (READ_WRITE accessors).
    out_fields = set()
    for ms in implementation_ir.multi_stages:
        for sg in ms.groups:
            for st in sg.stages:
                for acc in st.accessors:
                    if (
                        isinstance(acc, gt_ir.FieldAccessor)
                        and acc.intent == gt_ir.AccessIntent.READ_WRITE
                    ):
                        out_fields.add(acc.symbol)

    for arg in implementation_ir.api_signature:
        if arg.name in implementation_ir.fields:
            access = (
                gt_definitions.AccessKind.READ_WRITE
                if arg.name in out_fields
                else gt_definitions.AccessKind.READ_ONLY
            )
            if arg.name not in implementation_ir.unreferenced:
                field_info[arg.name] = gt_definitions.FieldInfo(
                    access=access,
                    dtype=implementation_ir.fields[arg.name].data_type.dtype,
                    boundary=implementation_ir.fields_extents[arg.name].to_boundary(),
                )
            else:
                field_info[arg.name] = None
        else:
            if arg.name not in implementation_ir.unreferenced:
                parameter_info[arg.name] = gt_definitions.ParameterInfo(
                    dtype=implementation_ir.parameters[arg.name].data_type.dtype
                )
            else:
                parameter_info[arg.name] = None

    if implementation_ir.externals:
        # Only numeric externals are exported, as their repr().
        info["gt_constants"] = {
            name: repr(value)
            for name, value in implementation_ir.externals.items()
            if isinstance(value, numbers.Number)
        }
    else:
        info["gt_constants"] = {}

    info["gt_options"] = {
        key: value
        for key, value in self.options.as_dict().items()
        if key not in ["build_info"]
    }

    info["unreferenced"] = self.implementation_ir.unreferenced
    return info
def __call__(self, stencil_id, implementation_ir):
    """Render the module source for a stencil implementation.

    Stores ``stencil_id`` and ``implementation_ir`` on ``self`` before calling
    the ``generate_*`` hooks, which presumably read them. It then collects
    per-argument metadata from the IR and renders ``self.template``.

    Args:
        stencil_id: Identifier of the stencil being generated.
        implementation_ir: The stencil implementation IR.

    Returns:
        The rendered module source, passed through
        ``gt_utils.text.format_source``.
    """
    self.stencil_id = stencil_id
    self.implementation_ir = implementation_ir

    stencil_signature = self.generate_signature()

    sources = {}
    if implementation_ir.sources is not None:
        # ``dict(...).items()`` works whether ``sources`` is a mapping or an
        # iterable of (name, source) pairs. Iterating a mapping directly would
        # yield only its keys.
        sources = {
            key: gt_utils.text.format_source(value, line_length=self.SOURCE_LINE_LENGTH)
            for key, value in dict(implementation_ir.sources).items()
        }

    parallel_axes = implementation_ir.domain.parallel_axes or []
    sequential_axis = implementation_ir.domain.sequential_axis.name
    domain_info = repr(
        gt_definitions.DomainInfo(
            parallel_axes=tuple(ax.name for ax in parallel_axes),
            sequential_axis=sequential_axis,
            ndims=len(parallel_axes) + (1 if sequential_axis else 0),
        )
    )

    field_info = {}
    field_names = []
    parameter_info = {}
    param_names = []

    # Collect the names of fields written by any stage (READ_WRITE accessors).
    out_fields = set()
    for ms in implementation_ir.multi_stages:
        for sg in ms.groups:
            for st in sg.stages:
                for acc in st.accessors:
                    if (
                        isinstance(acc, gt_ir.FieldAccessor)
                        and acc.intent == gt_ir.AccessIntent.READ_WRITE
                    ):
                        out_fields.add(acc.symbol)

    for arg in implementation_ir.api_signature:
        if arg.name in implementation_ir.fields:
            access = (
                gt_definitions.AccessKind.READ_WRITE
                if arg.name in out_fields
                else gt_definitions.AccessKind.READ_ONLY
            )
            if arg.name not in implementation_ir.unreferenced:
                field_info[arg.name] = gt_definitions.FieldInfo(
                    access=access,
                    dtype=implementation_ir.fields[arg.name].data_type.dtype,
                    boundary=implementation_ir.fields_extents[arg.name].to_boundary(),
                )
            else:
                # Unreferenced fields are still listed, with no metadata.
                field_info[arg.name] = None
            field_names.append(arg.name)
        else:
            if arg.name not in implementation_ir.unreferenced:
                parameter_info[arg.name] = gt_definitions.ParameterInfo(
                    dtype=implementation_ir.parameters[arg.name].data_type.dtype
                )
            else:
                parameter_info[arg.name] = None
            param_names.append(arg.name)

    # The template receives the reprs of the metadata dicts.
    field_info = repr(field_info)
    parameter_info = repr(parameter_info)

    if implementation_ir.externals:
        # Only numeric externals are exported, as their repr().
        gt_constants = {
            name: repr(value)
            for name, value in implementation_ir.externals.items()
            if isinstance(value, numbers.Number)
        }
    else:
        gt_constants = {}

    gt_options = dict(self.options.__dict__)
    if "build_info" in gt_options:
        del gt_options["build_info"]

    # Concrete implementation in the subclasses
    imports = self.generate_imports()
    module_members = self.generate_module_members()
    class_members = self.generate_class_members()
    implementation = self.generate_implementation()

    module_source = self.template.render(
        imports=imports,
        module_members=module_members,
        class_name=self.stencil_class_name,
        class_members=class_members,
        docstring=implementation_ir.docstring,
        gt_backend=self.backend_name,
        gt_source=sources,
        gt_domain_info=domain_info,
        gt_field_info=field_info,
        gt_parameter_info=parameter_info,
        gt_constants=gt_constants,
        gt_options=gt_options,
        stencil_signature=stencil_signature,
        field_names=field_names,
        param_names=param_names,
        # Non-temporary, referenced fields.
        synchronization=self.generate_synchronization(
            [
                k
                for k in implementation_ir.fields.keys()
                if k not in implementation_ir.temporary_fields
                and k not in implementation_ir.unreferenced
            ]
        ),
        # Written fields that are neither temporary nor unreferenced.
        mark_modified=self.generate_mark_modified(
            [
                k
                for k in out_fields
                if k not in implementation_ir.temporary_fields
                and k not in implementation_ir.unreferenced
            ]
        ),
        implementation=implementation,
    )
    module_source = gt_utils.text.format_source(
        module_source, line_length=self.SOURCE_LINE_LENGTH
    )

    return module_source