Example No. 1
0
    def _generate_module_info(cls, definition_ir, options, field_info) -> Dict[str, Any]:
        """Collect the metadata used to render a stencil module from a definition IR.

        Args:
            definition_ir: Stencil definition IR. Its ``sources``, ``domain``, ``docstring``,
                ``api_fields``, ``parameters``, ``api_signature`` and ``externals`` are read.
            options: Build options. ``options.as_dict()`` is copied into ``"gt_options"``
                without its ``"build_info"`` entry.
            field_info: Mapping from API field name to a mapping with ``"access"`` and
                ``"extent"`` entries.

        Returns:
            A dict with the keys ``"sources"``, ``"domain_info"``, ``"docstring"``,
            ``"field_info"``, ``"parameter_info"``, ``"unreferenced"``, ``"gt_constants"``
            and ``"gt_options"``.
        """
        info: Dict[str, Any] = {}

        # Assign the dict directly: "sources" is not yet a key of ``info``, so calling
        # ``.update`` on it would raise KeyError. ``dict(...).items()`` yields the
        # (key, value) pairs; iterating a mapping directly would yield only its keys.
        if definition_ir.sources is not None:
            info["sources"] = {
                key: gt_utils.text.format_source(value, line_length=100)
                for key, value in dict(definition_ir.sources).items()
            }
        else:
            info["sources"] = {}

        parallel_axes = definition_ir.domain.parallel_axes or []
        sequential_axis = definition_ir.domain.sequential_axis.name
        domain_info = gt_definitions.DomainInfo(
            parallel_axes=tuple(ax.name for ax in parallel_axes),
            sequential_axis=sequential_axis,
            # One dimension per parallel axis, plus one if a sequential axis name is set.
            ndims=len(parallel_axes) + (1 if sequential_axis else 0),
        )
        info["domain_info"] = repr(domain_info)

        info["docstring"] = definition_ir.docstring
        info["field_info"] = {}
        info["parameter_info"] = {}
        info["unreferenced"] = []

        fields = {item.name: item for item in definition_ir.api_fields}
        parameters = {item.name: item for item in definition_ir.parameters}

        for arg in definition_ir.api_signature:
            if arg.name in fields:
                access = field_info[arg.name]["access"]
                if access is None:
                    # A field with no recorded access is treated as read-only and
                    # reported as unreferenced.
                    access = gt_definitions.AccessKind.READ_ONLY
                    info["unreferenced"].append(arg.name)
                extent = field_info[arg.name]["extent"]
                # The boundary negates the first element of each extent pair.
                boundary = gt_definitions.Boundary([(-pair[0], pair[1]) for pair in extent])
                info["field_info"][arg.name] = gt_definitions.FieldInfo(
                    access=access, dtype=fields[arg.name].data_type.dtype, boundary=boundary
                )
            else:
                info["parameter_info"][arg.name] = gt_definitions.ParameterInfo(
                    dtype=parameters[arg.name].data_type.dtype
                )

        # Only numeric externals are exported as constants (as their repr strings).
        if definition_ir.externals:
            info["gt_constants"] = {
                name: repr(value)
                for name, value in definition_ir.externals.items()
                if isinstance(value, numbers.Number)
            }
        else:
            info["gt_constants"] = {}

        info["gt_options"] = {
            key: value for key, value in options.as_dict().items() if key not in ["build_info"]
        }

        return info
Example No. 2
0
    def __call__(
        self,
        args_data: Dict[str, Any],
        builder: Optional["StencilBuilder"] = None,
        **kwargs: Any,
    ) -> str:
        """Generate source code for a Python module containing a StencilObject.

        Args:
            args_data: Precomputed argument metadata. Its ``"field_info"`` and
                ``"parameter_info"`` entries are rendered (via ``repr``) into the module,
                and their keys are passed as the field and parameter names.
            builder: Optional builder. When truthy, it replaces ``self._builder`` before
                generation.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The rendered module source. It is reformatted with
            ``gt_utils.text.format_source`` when the ``"format_source"`` option is truthy.
        """
        if builder:
            self._builder = builder
        self.args_data = args_data

        definition_ir = self.builder.definition_ir

        # ``dict(...).items()`` yields (key, value) pairs; iterating the mapping
        # directly would yield only its keys and break the tuple unpacking.
        if definition_ir.sources is not None:
            sources = {
                key: gt_utils.text.format_source(value, line_length=self.SOURCE_LINE_LENGTH)
                for key, value in dict(definition_ir.sources).items()
            }
        else:
            sources = {}

        # Only numeric externals are exported as constants (as their repr strings).
        if definition_ir.externals:
            constants = {
                name: repr(value)
                for name, value in definition_ir.externals.items()
                if isinstance(value, numbers.Number)
            }
        else:
            constants = {}

        options = {
            key: value
            for key, value in self.builder.options.as_dict().items()
            if key not in ["build_info"]
        }

        parallel_axes = definition_ir.domain.parallel_axes or []
        sequential_axis = definition_ir.domain.sequential_axis.name
        domain_info = repr(
            gt_definitions.DomainInfo(
                parallel_axes=tuple(ax.name for ax in parallel_axes),
                sequential_axis=sequential_axis,
                # One dimension per parallel axis, plus one if a sequential axis name is set.
                ndims=len(parallel_axes) + (1 if sequential_axis else 0),
            )
        )

        module_source = self.template.render(
            imports=self.generate_imports(),
            module_members=self.generate_module_members(),
            class_name=self.builder.class_name,
            class_members=self.generate_class_members(),
            docstring=definition_ir.docstring,
            gt_backend=self.backend_name,
            gt_source=sources,
            gt_domain_info=domain_info,
            gt_field_info=repr(self.args_data["field_info"]),
            gt_parameter_info=repr(self.args_data["parameter_info"]),
            gt_constants=constants,
            gt_options=options,
            stencil_signature=self.generate_signature(),
            field_names=self.args_data["field_info"].keys(),
            param_names=self.args_data["parameter_info"].keys(),
            pre_run=self.generate_pre_run(),
            post_run=self.generate_post_run(),
            implementation=self.generate_implementation(),
        )
        if options["format_source"]:
            module_source = gt_utils.text.format_source(
                module_source, line_length=self.SOURCE_LINE_LENGTH
            )

        return module_source
Example No. 3
0
File: base.py — Project: twicki/gt4py
    def _generate_module_info(self) -> Dict[str, Any]:
        """Collect the metadata used to render a stencil module from the implementation IR.

        Reads ``self.implementation_ir``, ``self.definition_ir`` (for ``sources`` only)
        and ``self.options``.

        Returns:
            A dict with the keys ``"sources"``, ``"docstring"``, ``"domain_info"``,
            ``"field_info"``, ``"parameter_info"``, ``"gt_constants"``, ``"gt_options"``
            and ``"unreferenced"``. In ``"field_info"`` and ``"parameter_info"``,
            arguments listed in ``implementation_ir.unreferenced`` map to ``None``.
        """
        info: Dict[str, Any] = {}
        implementation_ir = self.implementation_ir

        # Assign the dict directly: "sources" is not yet a key of ``info``, so calling
        # ``.update`` on it would raise KeyError. ``dict(...).items()`` yields the
        # (key, value) pairs; iterating a mapping directly would yield only its keys.
        if self.definition_ir.sources is not None:
            info["sources"] = {
                key: gt_utils.text.format_source(value, line_length=self.SOURCE_LINE_LENGTH)
                for key, value in dict(self.definition_ir.sources).items()
            }
        else:
            info["sources"] = {}

        info["docstring"] = implementation_ir.docstring

        parallel_axes = implementation_ir.domain.parallel_axes or []
        sequential_axis = implementation_ir.domain.sequential_axis.name
        info["domain_info"] = repr(
            gt_definitions.DomainInfo(
                parallel_axes=tuple(ax.name for ax in parallel_axes),
                sequential_axis=sequential_axis,
                # One dimension per parallel axis, plus one if a sequential axis name is set.
                ndims=len(parallel_axes) + (1 if sequential_axis else 0),
            )
        )

        info["field_info"] = field_info = {}
        info["parameter_info"] = parameter_info = {}

        # Collect every field symbol accessed with READ_WRITE intent in any stage.
        out_fields = set()
        for ms in implementation_ir.multi_stages:
            for sg in ms.groups:
                for st in sg.stages:
                    for acc in st.accessors:
                        if (
                            isinstance(acc, gt_ir.FieldAccessor)
                            and acc.intent == gt_ir.AccessIntent.READ_WRITE
                        ):
                            out_fields.add(acc.symbol)

        for arg in implementation_ir.api_signature:
            if arg.name in implementation_ir.fields:
                access = (
                    gt_definitions.AccessKind.READ_WRITE
                    if arg.name in out_fields
                    else gt_definitions.AccessKind.READ_ONLY
                )
                if arg.name not in implementation_ir.unreferenced:
                    field_info[arg.name] = gt_definitions.FieldInfo(
                        access=access,
                        dtype=implementation_ir.fields[arg.name].data_type.dtype,
                        boundary=implementation_ir.fields_extents[arg.name].to_boundary(),
                    )
                else:
                    field_info[arg.name] = None
            else:
                if arg.name not in implementation_ir.unreferenced:
                    parameter_info[arg.name] = gt_definitions.ParameterInfo(
                        dtype=implementation_ir.parameters[arg.name].data_type.dtype
                    )
                else:
                    parameter_info[arg.name] = None

        # Only numeric externals are exported as constants (as their repr strings).
        if implementation_ir.externals:
            info["gt_constants"] = {
                name: repr(value)
                for name, value in implementation_ir.externals.items()
                if isinstance(value, numbers.Number)
            }
        else:
            info["gt_constants"] = {}

        info["gt_options"] = {
            key: value
            for key, value in self.options.as_dict().items()
            if key not in ["build_info"]
        }

        info["unreferenced"] = implementation_ir.unreferenced

        return info
Ejemplo n.º 4
0
    def __call__(self, stencil_id, implementation_ir):
        """Render the Python module source for a stencil from its implementation IR.

        Stores ``stencil_id`` and ``implementation_ir`` on ``self`` before calling the
        ``generate_*`` hooks, which are implemented in subclasses.

        Args:
            stencil_id: Identifier of the stencil; stored as ``self.stencil_id``.
            implementation_ir: Implementation IR; stored as ``self.implementation_ir``.

        Returns:
            The rendered module source, always passed through
            ``gt_utils.text.format_source``.
        """
        self.stencil_id = stencil_id
        self.implementation_ir = implementation_ir

        stencil_signature = self.generate_signature()

        # ``dict(...).items()`` yields (key, value) pairs; iterating the mapping
        # directly would yield only its keys and break the tuple unpacking.
        if implementation_ir.sources is not None:
            sources = {
                key: gt_utils.text.format_source(value, line_length=self.SOURCE_LINE_LENGTH)
                for key, value in dict(implementation_ir.sources).items()
            }
        else:
            sources = {}

        parallel_axes = implementation_ir.domain.parallel_axes or []
        sequential_axis = implementation_ir.domain.sequential_axis.name
        domain_info = repr(
            gt_definitions.DomainInfo(
                parallel_axes=tuple(ax.name for ax in parallel_axes),
                sequential_axis=sequential_axis,
                # One dimension per parallel axis, plus one if a sequential axis name is set.
                ndims=len(parallel_axes) + (1 if sequential_axis else 0),
            )
        )

        field_info = {}
        field_names = []
        parameter_info = {}
        param_names = []

        # Collect every field symbol accessed with READ_WRITE intent in any stage.
        out_fields = set()
        for ms in implementation_ir.multi_stages:
            for sg in ms.groups:
                for st in sg.stages:
                    for acc in st.accessors:
                        if (
                            isinstance(acc, gt_ir.FieldAccessor)
                            and acc.intent == gt_ir.AccessIntent.READ_WRITE
                        ):
                            out_fields.add(acc.symbol)

        # Unreferenced arguments keep their name in field_names/param_names but map to
        # ``None`` in the info dicts.
        for arg in implementation_ir.api_signature:
            if arg.name in implementation_ir.fields:
                access = (
                    gt_definitions.AccessKind.READ_WRITE
                    if arg.name in out_fields
                    else gt_definitions.AccessKind.READ_ONLY
                )
                if arg.name not in implementation_ir.unreferenced:
                    field_info[arg.name] = gt_definitions.FieldInfo(
                        access=access,
                        dtype=implementation_ir.fields[arg.name].data_type.dtype,
                        boundary=implementation_ir.fields_extents[arg.name].to_boundary(),
                    )
                else:
                    field_info[arg.name] = None
                field_names.append(arg.name)
            else:
                if arg.name not in implementation_ir.unreferenced:
                    parameter_info[arg.name] = gt_definitions.ParameterInfo(
                        dtype=implementation_ir.parameters[arg.name].data_type.dtype
                    )
                else:
                    parameter_info[arg.name] = None
                param_names.append(arg.name)

        # The template receives the info dicts as their repr strings.
        field_info = repr(field_info)
        parameter_info = repr(parameter_info)

        # Only numeric externals are exported as constants (as their repr strings).
        if implementation_ir.externals:
            gt_constants = {
                name: repr(value)
                for name, value in implementation_ir.externals.items()
                if isinstance(value, numbers.Number)
            }
        else:
            gt_constants = {}

        gt_options = dict(self.options.__dict__)
        if "build_info" in gt_options:
            del gt_options["build_info"]

        # Concrete implementation in the subclasses
        imports = self.generate_imports()
        module_members = self.generate_module_members()
        class_members = self.generate_class_members()
        implementation = self.generate_implementation()

        module_source = self.template.render(
            imports=imports,
            module_members=module_members,
            class_name=self.stencil_class_name,
            class_members=class_members,
            docstring=implementation_ir.docstring,
            gt_backend=self.backend_name,
            gt_source=sources,
            gt_domain_info=domain_info,
            gt_field_info=field_info,
            gt_parameter_info=parameter_info,
            gt_constants=gt_constants,
            gt_options=gt_options,
            stencil_signature=stencil_signature,
            field_names=field_names,
            param_names=param_names,
            # Synchronize all non-temporary, referenced fields.
            synchronization=self.generate_synchronization(
                [
                    k
                    for k in implementation_ir.fields.keys()
                    if k not in implementation_ir.temporary_fields
                    and k not in implementation_ir.unreferenced
                ]
            ),
            # Mark as modified only the written, non-temporary, referenced fields.
            mark_modified=self.generate_mark_modified(
                [
                    k
                    for k in out_fields
                    if k not in implementation_ir.temporary_fields
                    and k not in implementation_ir.unreferenced
                ]
            ),
            implementation=implementation,
        )
        module_source = gt_utils.text.format_source(
            module_source, line_length=self.SOURCE_LINE_LENGTH
        )

        return module_source