Esempio n. 1
0
    def __call__(
            self,
            definition_ir: StencilDefinition) -> Dict[str, Dict[str, str]]:
        """Generate the DaCe computation and binding sources for a stencil.

        The definition IR is converted to GTIR, then the following happens:

        - The GTIR passes run in order: unused-parameter pruning, dtype
          resolution, upcasting.
        - The result is lowered to OIR and optimized with ``self._optimize_oir``.
        - It is turned into an SDFG whose library nodes are expanded
          recursively.
        - The computation and binding sources are generated.

        Args:
            definition_ir: The stencil definition to compile.

        Returns:
            ``{"computation": {"computation.hpp": <source>},
            "bindings": {"bindings<ext>": <source>}}``, where ``<ext>`` is
            ``.cu`` when ``self.backend.GT_BACKEND_T == "gpu"`` and ``.cpp``
            otherwise.
        """
        gtir = DefIRToGTIR.apply(definition_ir)

        # Run the GTIR passes in order. The result is bound to a separate
        # name so that ``gtir`` itself is what the codegen calls below get.
        transformed = gtir
        for gtir_pass in (prune_unused_parameters, resolve_dtype, upcast):
            transformed = gtir_pass(transformed)

        optimized_oir = self._optimize_oir(
            gtir_to_oir.GTIRToOIR().visit(transformed))
        sdfg = OirSDFGBuilder().visit(optimized_oir)
        sdfg.expand_library_nodes(recursive=True)
        # TODO uncomment once the branch dace/linus-fixes-8 is merged into dace/master
        # sdfg.apply_strict_transformations(validate=True) # noqa: E800 Found commented out code

        # NOTE(review): codegen receives the un-pruned ``gtir`` rather than the
        # transformed IR. Presumably this keeps every declared parameter in
        # the generated bindings; confirm before changing.
        computation_src = DaCeComputationCodegen.apply(gtir, sdfg)
        bindings_src = DaCeBindingsCodegen.apply(
            gtir, sdfg, module_name=self.module_name)

        targets_gpu = self.backend.GT_BACKEND_T == "gpu"
        bindings_filename = "bindings" + (".cu" if targets_gpu else ".cpp")
        return {
            "computation": {"computation.hpp": computation_src},
            "bindings": {bindings_filename: bindings_src},
        }
Esempio n. 2
0
def test_all_parameters_used():
    """Pruning keeps every parameter that the stencil body references."""
    field_param = FieldDeclFactory()
    scalar_param = ScalarDeclFactory()
    # The only statement assigns the scalar to the field, so both are used.
    assignment = ParAssignStmtFactory(
        left__name=field_param.name, right__name=scalar_param.name)
    stencil = StencilFactory(
        params=[field_param, scalar_param],
        vertical_loops__0__body__0=assignment,
    )

    pruned = prune_unused_parameters(stencil)

    assert [field_param, scalar_param] == pruned.params
Esempio n. 3
0
def test_all_parameters_used():
    """Pruning keeps both parameters when the body assigns ``scalar`` to ``field``."""
    field_param = FieldDecl(name="field", dtype=A_ARITHMETIC_TYPE)
    scalar_param = ScalarDecl(name="scalar", dtype=A_ARITHMETIC_TYPE)

    builder = StencilBuilder()
    builder = builder.add_param(field_param).add_param(scalar_param)
    builder = builder.add_par_assign_stmt(
        ParAssignStmtBuilder("field", "scalar").build())
    stencil = builder.build()

    pruned = prune_unused_parameters(stencil)

    assert [field_param, scalar_param] == pruned.params
Esempio n. 4
0
 def __call__(self, definition_ir) -> Dict[str, Dict[str, str]]:
     """Generate computation and binding sources through the CUIR pipeline.

     The definition IR is converted to GTIR, then the following happens:

     - The GTIR passes run in order: unused-parameter pruning, dtype
       resolution, upcasting.
     - The result is lowered to OIR and optimized with ``self._optimize_oir``.
     - It is converted to CUIR.
     - The CUIR passes run: kernel fusion, extent computation, cache extents.
     - The computation and binding sources are generated.

     Args:
         definition_ir: The stencil definition to compile.

     Returns:
         ``{"computation": {"computation.hpp": <source>},
         "bindings": {"bindings.cu": <source>}}``. The bindings file name
         is always ``bindings.cu`` here.
     """
     ir = DefIRToGTIR.apply(definition_ir)
     for gtir_pass in (prune_unused_parameters, resolve_dtype, upcast):
         ir = gtir_pass(ir)
     oir = self._optimize_oir(gtir_to_oir.GTIRToOIR().visit(ir))

     cuir = oir_to_cuir.OIRToCUIR().visit(oir)
     # Each CUIR pass is instantiated right before it is applied, in order.
     for cuir_pass_cls in (
         kernel_fusion.FuseKernels,
         extent_analysis.ComputeExtents,
         extent_analysis.CacheExtents,
     ):
         cuir = cuir_pass_cls().visit(cuir)

     computation_src = cuir_codegen.CUIRCodegen.apply(cuir)
     bindings_src = GTCCudaBindingsCodegen.apply(
         cuir, module_name=self.module_name)
     return {
         "computation": {"computation.hpp": computation_src},
         "bindings": {"bindings.cu": bindings_src},
     }
Esempio n. 5
0
 def __call__(self, definition_ir) -> Dict[str, Dict[str, str]]:
     """Generate computation and binding sources through the GTC++ pipeline.

     The definition IR is converted to GTIR, then the following happens:

     - The GTIR passes run in order: unused-parameter pruning, dtype
       resolution, upcasting.
     - The result is lowered to OIR and optimized with ``self._optimize_oir``.
     - It is converted to GTC++ IR.
     - The computation and binding sources are generated; both codegens
       receive ``self.gt_backend_t``.

     Args:
         definition_ir: The stencil definition to compile.

     Returns:
         ``{"computation": {"computation.hpp": <source>},
         "bindings": {"bindings<ext>": <source>}}``, where ``<ext>`` is
         ``.cu`` when ``self.gt_backend_t == "gpu"`` and ``.cpp`` otherwise.
     """
     ir = DefIRToGTIR.apply(definition_ir)
     for gtir_pass in (prune_unused_parameters, resolve_dtype, upcast):
         ir = gtir_pass(ir)
     oir = self._optimize_oir(gtir_to_oir.GTIRToOIR().visit(ir))
     gtcpp = oir_to_gtcpp.OIRToGTCpp().visit(oir)

     computation_src = gtcpp_codegen.GTCppCodegen.apply(
         gtcpp, gt_backend_t=self.gt_backend_t)
     bindings_src = GTCppBindingsCodegen.apply(
         gtcpp, module_name=self.module_name, gt_backend_t=self.gt_backend_t)

     if self.gt_backend_t == "gpu":
         bindings_ext = ".cu"
     else:
         bindings_ext = ".cpp"
     return {
         "computation": {"computation.hpp": computation_src},
         "bindings": {"bindings" + bindings_ext: bindings_src},
     }