Ejemplo n.º 1
0
def test_j():
    """Check GTIR legacy extents against implementation-IR extents for a J-only field.

    The extents from ``compute_legacy_extents(prepare_gtir(...))`` must equal
    ``implementation_ir.fields_extents`` for every field the implementation IR
    reports (debug backend).
    """

    def stencil(field_a: gs.Field[float], field_b: gs.Field[float, gs.J]):
        with computation(PARALLEL), interval(...):
            field_a = field_b[1] + field_b[-2]

    stencil_builder = StencilBuilder(stencil, backend=from_name("debug"))
    reference_extents = stencil_builder.implementation_ir.fields_extents
    gtir_extents = compute_legacy_extents(prepare_gtir(stencil_builder))

    # Only fields known to the implementation IR are compared.
    for field_name, expected in reference_extents.items():
        assert gtir_extents[field_name] == expected
Ejemplo n.º 2
0
def test_single_k_offset():
    """Check legacy extents (with ``mask_inwards=True``) for a single +1 K offset.

    The GTIR-derived extents must equal ``implementation_ir.fields_extents``
    for every field the implementation IR reports (debug backend).
    """

    def stencil(field_a: gs.Field[float], field_b: gs.Field[float]):
        with computation(PARALLEL), interval(...):
            field_a = field_b[0, 0, 1]

    stencil_builder = StencilBuilder(stencil, backend=from_name("debug"))
    reference_extents = stencil_builder.implementation_ir.fields_extents
    gtir_extents = compute_legacy_extents(
        prepare_gtir(stencil_builder), mask_inwards=True
    )

    # Only fields known to the implementation IR are compared.
    for field_name, expected in reference_extents.items():
        assert gtir_extents[field_name] == expected
Ejemplo n.º 3
0
def make_args_data_from_gtir(pipeline: GtirPipeline) -> ModuleData:
    """Collect argument metadata for a stencil from its fully processed GTIR.

    Field parameters are recorded in ``data.field_info``. Each entry holds:

    - an access kind: READ if the field occurs on a right-hand side of a
      ``ParAssignStmt``, WRITE if it is an assignment target;
    - the boundary, axes, data dims and dtype.

    All other parameters are recorded in ``data.parameter_info``. Parameters
    returned by ``get_unused_params_from_gtir`` are then overwritten with
    ``None``, and their names are stored, sorted, in ``data.unreferenced``.
    """
    data = ModuleData()
    stencil = pipeline.full()
    extents = compute_legacy_extents(stencil)

    # Names of fields that are the target of some parallel assignment.
    written: Set[str] = (
        stencil.iter_tree()
        .if_isinstance(gtir.ParAssignStmt)
        .getattr("left")
        .if_isinstance(gtir.FieldAccess)
        .getattr("name")
        .to_set()
    )

    # Names of fields accessed anywhere inside a parallel assignment's right-hand side.
    read: Set[str] = set()
    for rhs in stencil.iter_tree().if_isinstance(gtir.ParAssignStmt).getattr("right"):
        read.update(rhs.iter_tree().if_isinstance(gtir.FieldAccess).getattr("name").to_set())

    field_names = [decl.name for decl in stencil.params if isinstance(decl, gtir.FieldDecl)]
    for name in sorted(field_names):
        access = AccessKind.NONE
        if name in read:
            access |= AccessKind.READ
        if name in written:
            access |= AccessKind.WRITE
        boundary = extents[name].to_boundary()
        symbol = stencil.symtable_[name]
        data.field_info[name] = FieldInfo(
            access=access,
            boundary=boundary,
            axes=tuple(dimension_flags_to_names(symbol.dimensions).upper()),
            data_dims=tuple(symbol.data_dims),
            dtype=numpy.dtype(symbol.dtype.name.lower()),
        )

    # Every parameter that is not a field is treated as a scalar parameter.
    field_name_set = set(field_names)
    scalar_names = [decl.name for decl in stencil.params if decl.name not in field_name_set]
    for name in sorted(scalar_names):
        data.parameter_info[name] = ParameterInfo(
            dtype=numpy.dtype(stencil.symtable_[name].dtype.name.lower())
        )

    # Unused parameters stay listed, but their metadata is reset to ``None``.
    unused = get_unused_params_from_gtir(pipeline)
    for decl in sorted(unused, key=lambda d: d.name):
        if isinstance(decl, gtir.FieldDecl):
            data.field_info[decl.name] = None
        elif isinstance(decl, gtir.ScalarDecl):
            data.parameter_info[decl.name] = None

    data.unreferenced = sorted(decl.name for decl in unused)
    return data
Ejemplo n.º 4
0
 def generate_computation(self) -> Dict[str, Union[str, Dict]]:
     """Render the computation module for this stencil.

     Returns a single-entry mapping from the module file name
     (``<module_prefix>computation<module_postfix>.py`` taken from the
     builder's caching strategy) to the source produced by ``NpirGen`` from
     ``self.npir``. The source is passed through ``format_source("python", ...)``.
     ``NpirGen`` receives the legacy field extents computed from
     ``self.builder.gtir``.
     """
     caching = self.builder.caching
     file_name = caching.module_prefix + "computation" + caching.module_postfix + ".py"

     npir = self.npir
     extents = compute_legacy_extents(self.builder.gtir)
     source = format_source("python", NpirGen.apply(npir, field_extents=extents))
     return {file_name: source}
Ejemplo n.º 5
0
def test_offset_chain():
    """Check legacy extents for offsets propagated through a chain of computations.

    Three PARALLEL computations read each other's outputs (and a temporary)
    with offsets. The GTIR-derived extents must equal
    ``implementation_ir.fields_extents`` for every field the implementation IR
    reports (debug backend).
    """

    def stencil(field_a: gs.Field[float], field_b: gs.Field[float]):
        with computation(PARALLEL), interval(...):
            field_a = field_b[1, 0, 1]
        with computation(PARALLEL), interval(...):
            field_b = field_a[1, 0, 0]
        with computation(PARALLEL), interval(...):
            tmp = field_b[0, -1, 0] + field_b[0, 1, 0]
            field_a = tmp[0, 0, 0] + tmp[0, 0, -1]

    stencil_builder = StencilBuilder(stencil, backend=from_name("debug"))
    reference_extents = stencil_builder.implementation_ir.fields_extents
    gtir_extents = compute_legacy_extents(prepare_gtir(stencil_builder))

    # Only fields known to the implementation IR are compared.
    for field_name, expected in reference_extents.items():
        assert gtir_extents[field_name] == expected
Ejemplo n.º 6
0
def test_field_if():
    """Check legacy extents for fields read inside nested field-dependent ``if`` blocks.

    The GTIR-derived extents must equal ``implementation_ir.fields_extents``
    for every field the implementation IR reports (debug backend).
    """

    def stencil(field_a: gs.Field[float], field_b: gs.Field[float]):
        with computation(PARALLEL), interval(...):
            if field_b[0, 1, 0] < 0.1:
                if field_b[1, 0, 0] > 1.0:
                    field_a = 0
                else:
                    field_a = 1
            else:
                tmp = -field_b[0, 1, 0]
                field_a = tmp

    stencil_builder = StencilBuilder(stencil, backend=from_name("debug"))
    reference_extents = stencil_builder.implementation_ir.fields_extents
    gtir_extents = compute_legacy_extents(prepare_gtir(stencil_builder))

    # Only fields known to the implementation IR are compared.
    for field_name, expected in reference_extents.items():
        assert gtir_extents[field_name] == expected
Ejemplo n.º 7
0
def make_args_data_from_gtir(pipeline: GtirPipeline) -> ModuleData:
    """Collect argument metadata for a stencil from its fully processed GTIR.

    Field parameters are recorded in ``data.field_info`` with:

    - access ``READ_WRITE`` if the field is the target of some
      ``ParAssignStmt``, otherwise ``READ_ONLY``;
    - its boundary, axes and dtype.

    The remaining parameter names (``node.param_names`` minus the fields) are
    recorded in ``data.parameter_info``. Parameters returned by
    ``get_unused_params_from_gtir`` are then overwritten with ``None``, and
    their names are stored (as a set) in ``data.unreferenced``.

    All name collections are iterated in sorted order. The insertion order of
    ``data.field_info`` / ``data.parameter_info`` is therefore reproducible,
    instead of depending on set iteration order (and thus on the string hash
    seed).
    """
    data = ModuleData()
    node = pipeline.full()
    field_extents = compute_legacy_extents(node)

    # A set, so the per-field membership test below is O(1) rather than a list scan.
    write_fields = (
        node.iter_tree()
        .if_isinstance(gtir.ParAssignStmt)
        .getattr("left")
        .if_isinstance(gtir.FieldAccess)
        .getattr("name")
        .to_set()
    )

    referenced_field_params = {
        param.name for param in node.params if isinstance(param, gtir.FieldDecl)
    }
    for name in sorted(referenced_field_params):
        symbol = node.symtable_[name]
        data.field_info[name] = FieldInfo(
            access=AccessKind.READ_WRITE if name in write_fields else AccessKind.READ_ONLY,
            boundary=field_extents[name].to_boundary(),
            axes=list(dimension_flags_to_names(symbol.dimensions).upper()),
            dtype=numpy.dtype(symbol.dtype.name.lower()),
        )

    referenced_scalar_params = set(node.param_names).difference(referenced_field_params)
    for name in sorted(referenced_scalar_params):
        data.parameter_info[name] = ParameterInfo(
            dtype=numpy.dtype(node.symtable_[name].dtype.name.lower())
        )

    # Unused parameters stay listed, but their metadata is reset to ``None``.
    unref_params = get_unused_params_from_gtir(pipeline)
    for param in sorted(unref_params, key=lambda decl: decl.name):
        if isinstance(param, gtir.FieldDecl):
            data.field_info[param.name] = None
        elif isinstance(param, gtir.ScalarDecl):
            data.parameter_info[param.name] = None

    data.unreferenced = {param.name for param in unref_params}
    return data
Ejemplo n.º 8
0
    def generate_dace_args(self, gtir, sdfg):
        """Build the C++ argument expressions used to call the generated SDFG.

        Returns a list of strings, one per entry of
        ``sdfg.signature_arglist(with_types=False, for_call=True)`` and in that
        order. How each entry is produced:

        - ``__I``/``__J``/``__K``: passed through by name.
        - Transient arrays: ``__<name>_{I,J,K}_stride`` symbols are set to
          literal stride strings computed from the array shape.
        - Non-transient arrays (API fields): stride symbols and a pointer
          expression are built from ``gt::sid`` calls on ``__<name>_sid``.
        - Every other signature argument: passed through by name.
        """
        # Per-field (I, J, K) offsets. I and J are the negated lower bound of
        # each axis from the legacy extents. The K entry is replaced below.
        offset_dict: Dict[str, Tuple[int, int, int]] = {
            k: (-v[0][0], -v[1][0], -v[2][0]) for k, v in compute_legacy_extents(gtir).items()
        }
        # K offset: first element of each field's entry in compute_k_boundary().
        k_origins = {
            field_name: boundary[0] for field_name, boundary in compute_k_boundary(gtir).items()
        }
        # NOTE(review): assumes every field in compute_k_boundary() also has legacy
        # extents; otherwise offset_dict[name] raises KeyError — TODO confirm.
        for name, origin in k_origins.items():
            offset_dict[name] = (offset_dict[name][0], offset_dict[name][1], origin)

        # Domain-size symbols are forwarded under their own names.
        symbols = {f"__{var}": f"__{var}" for var in "IJK"}
        for name, array in sdfg.arrays.items():
            if array.transient:
                # Strides with K varying fastest, then J, then I. This uses
                # shape[1] and shape[2], so it presumably assumes a 3-D (I, J, K)
                # transient — TODO confirm.
                symbols[f"__{name}_K_stride"] = "1"
                symbols[f"__{name}_J_stride"] = str(array.shape[2])
                symbols[f"__{name}_I_stride"] = str(array.shape[1] * array.shape[2])
            else:
                # Spatial axes this array actually has, and how many trailing
                # (non-I/J/K) data dimensions remain.
                dims = [dim for dim, select in zip("IJK", array_dimensions(array)) if select]
                data_ndim = len(array.shape) - len(dims)

                # api field strides
                fmt = "gt::sid::get_stride<{dim}>(gt::sid::get_strides(__{name}_sid))"

                # One stride symbol per spatial axis, keyed by gt::stencil::dim::{i,j,k}.
                symbols.update(
                    {
                        f"__{name}_{dim}_stride": fmt.format(
                            dim=f"gt::stencil::dim::{dim.lower()}", name=name
                        )
                        for dim in dims
                    }
                )
                # Data-dimension strides are addressed by integral index 3, 4, ...
                symbols.update(
                    {
                        f"__{name}_d{dim}_stride": fmt.format(
                            dim=f"gt::integral_constant<int, {3 + dim}>", name=name
                        )
                        for dim in range(data_ndim)
                    }
                )

                # api field pointers
                fmt = """gt::sid::multi_shifted(
                             gt::sid::get_origin(__{name}_sid)(),
                             gt::sid::get_strides(__{name}_sid),
                             std::array<gt::int_t, {ndim}>{{{origin}}}
                         )"""
                # Shift the SID origin by minus the field's offset. This is done
                # only for the axes whose symbol (__I, __J, __K) occurs among the
                # free symbols of the array's shape.
                origin = tuple(
                    -offset_dict[name][idx]
                    for idx, var in enumerate("IJK")
                    if any(
                        dace.symbolic.pystr_to_symbolic(f"__{var}") in s.free_symbols
                        for s in array.shape
                        if hasattr(s, "free_symbols")
                    )
                )
                symbols[name] = fmt.format(
                    name=name, ndim=len(array.shape), origin=",".join(str(o) for o in origin)
                )
        # the remaining arguments are variables and can be passed by name
        for sym in sdfg.signature_arglist(with_types=False, for_call=True):
            if sym not in symbols:
                symbols[sym] = sym

        # return strings in order of sdfg signature
        return [symbols[s] for s in sdfg.signature_arglist(with_types=False, for_call=True)]
Ejemplo n.º 9
0
def make_args_data_from_gtir(pipeline: GtirPipeline, legacy=False) -> ModuleData:
    """Compute module data describing the stencil arguments from GTIR.

    Field parameters get a ``FieldInfo`` with:

    - access kind: READ if used on a right-hand side, WRITE if assigned;
    - boundary: I/J from the legacy extents, K from ``compute_k_boundary``;
    - axes, data dims and dtype.

    All other parameters get a ``ParameterInfo``. Parameters reported by
    ``get_unused_params_from_gtir`` are afterwards set to ``None`` and listed
    (sorted) in ``data.unreferenced``.

    Set ``legacy`` to reproduce the values of :func:`make_args_data_from_iir`.
    It passes ``mask_inwards=True`` to the extent computation and uses a zero
    K boundary ``(0, 0)`` for every field.
    """
    data = ModuleData()
    stencil = pipeline.full()

    # Names of fields that are the target of some parallel assignment.
    written: Set[str] = (
        stencil.iter_tree()
        .if_isinstance(gtir.ParAssignStmt)
        .getattr("left")
        .if_isinstance(gtir.FieldAccess)
        .getattr("name")
        .to_set()
    )

    # Names of fields accessed anywhere inside a parallel assignment's right-hand side.
    read: Set[str] = set()
    for rhs in stencil.iter_tree().if_isinstance(gtir.ParAssignStmt).getattr("right"):
        read.update(rhs.iter_tree().if_isinstance(gtir.FieldAccess).getattr("name").to_set())

    field_names = [decl.name for decl in stencil.params if isinstance(decl, gtir.FieldDecl)]
    extents = compute_legacy_extents(stencil, mask_inwards=legacy)
    if legacy:
        k_bounds = {name: (0, 0) for name in field_names}
    else:
        k_bounds = compute_k_boundary(stencil)

    for name in sorted(field_names):
        access = AccessKind.NONE
        if name in read:
            access |= AccessKind.READ
        if name in written:
            access |= AccessKind.WRITE
        # I/J come from the extents; K comes from the separately computed K boundary.
        horizontal = extents[name].to_boundary()[0:2]
        boundary = Boundary(*horizontal, k_bounds[name])
        symbol = stencil.symtable_[name]
        data.field_info[name] = FieldInfo(
            access=access,
            boundary=boundary,
            axes=tuple(dimension_flags_to_names(symbol.dimensions).upper()),
            data_dims=tuple(symbol.data_dims),
            dtype=numpy.dtype(symbol.dtype.name.lower()),
        )

    # Every parameter that is not a field is treated as a scalar parameter.
    field_name_set = set(field_names)
    scalar_names = [decl.name for decl in stencil.params if decl.name not in field_name_set]
    for name in sorted(scalar_names):
        data.parameter_info[name] = ParameterInfo(
            dtype=numpy.dtype(stencil.symtable_[name].dtype.name.lower())
        )

    # Unused parameters stay listed, but their metadata is reset to ``None``.
    unused = get_unused_params_from_gtir(pipeline)
    for decl in sorted(unused, key=lambda d: d.name):
        if isinstance(decl, gtir.FieldDecl):
            data.field_info[decl.name] = None
        elif isinstance(decl, gtir.ScalarDecl):
            data.parameter_info[decl.name] = None

    data.unreferenced = sorted(decl.name for decl in unused)
    return data