Example #1
0
    def visit_VerticalLoop(self, node: gtir.VerticalLoop, *,
                           ctx: Context) -> oir.VerticalLoop:
        """Lower a GTIR vertical loop into an OIR vertical loop with one section.

        Each top-level statement of ``node.body`` becomes its own
        ``oir.HorizontalExecution``. The local scalars recorded in ``ctx``
        while visiting that statement become its declarations. The loop's
        temporaries are appended to ``ctx.temp_fields`` only after the whole
        body has been visited.
        """
        executions: List[oir.HorizontalExecution] = []
        for body_stmt in node.body:
            # Every horizontal execution starts with a fresh set of local scalars.
            ctx.reset_local_scalars()
            lowered = self.visit(body_stmt, ctx=ctx)
            # A visit may return a single statement or a list of them.
            if isinstance(lowered, oir.Stmt):
                lowered = [lowered]
            executions.append(
                oir.HorizontalExecution(
                    body=utils.flatten_list(lowered),
                    declarations=ctx.local_scalars,
                ))

        new_temporaries = [
            oir.Temporary(name=tmp.name, dtype=tmp.dtype, dimensions=tmp.dimensions)
            for tmp in node.temporaries
        ]
        ctx.temp_fields += new_temporaries

        section = oir.VerticalLoopSection(
            interval=self.visit(node.interval),
            horizontal_executions=executions,
            loc=node.loc,
        )
        return oir.VerticalLoop(loop_order=node.loop_order, sections=[section])
Example #2
0
    def visit_FieldIfStmt(self,
                          node: gtir.FieldIfStmt,
                          *,
                          mask: oir.Expr = None,
                          ctx: Context,
                          **kwargs: Any) -> List[oir.Stmt]:
        """Lower a field-dependent ``if`` into masked OIR statements.

        The condition is evaluated once into a fresh 3D boolean temporary,
        which is registered in ``ctx.temp_fields``. The true branch is then
        visited under that mask, and the false branch (if present) under its
        negation. When an enclosing ``mask`` is given (nested field-ifs), it is
        AND-combined with the local mask for both branches.

        Args:
            node: the GTIR field-if statement to lower.
            mask: mask expression of an enclosing field-if, or ``None``.
            ctx: lowering context; receives the mask temporary declaration.
            **kwargs: forwarded to the branch visits.

        Returns:
            The assignment that fills the mask temporary, followed by the
            lowered statements of the true and (optional) false branch.
        """
        mask_field_decl = oir.Temporary(name=f"mask_{id(node)}",
                                        dtype=DataType.BOOL,
                                        dimensions=(True, True, True))
        ctx.temp_fields.append(mask_field_decl)

        def _and_with_outer(local_mask: oir.Expr) -> oir.Expr:
            # Nested field-if: a branch may only run where the enclosing mask
            # also holds. Compare against None explicitly instead of relying
            # on the truthiness of an IR node.
            if mask is None:
                return local_mask
            return oir.BinaryOp(op=LogicalOperator.AND,
                                left=mask,
                                right=local_mask,
                                loc=node.loc)

        # Evaluate the condition exactly once, into the mask temporary.
        stmts: List[oir.Stmt] = [
            oir.AssignStmt(
                left=oir.FieldAccess(
                    name=mask_field_decl.name,
                    offset=CartesianOffset.zero(),
                    dtype=DataType.BOOL,
                    loc=node.loc,
                ),
                right=self.visit(node.cond),
            )
        ]

        current_mask = oir.FieldAccess(
            name=mask_field_decl.name,
            offset=CartesianOffset.zero(),
            dtype=mask_field_decl.dtype,
            loc=node.loc,
        )
        stmts.extend(
            self.visit(node.true_branch.body,
                       mask=_and_with_outer(current_mask),
                       ctx=ctx,
                       **kwargs))

        if node.false_branch:
            negated_mask = oir.UnaryOp(op=UnaryOperator.NOT,
                                       expr=current_mask,
                                       loc=node.loc)
            stmts.extend(
                self.visit(node.false_branch.body,
                           mask=_and_with_outer(negated_mask),
                           ctx=ctx,
                           **kwargs))

        return stmts
Example #3
0
    def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context,
                           **kwargs: Any) -> oir.VerticalLoop:
        """Lower a GTIR vertical loop into a single-section OIR vertical loop.

        The body is visited with ``ctx``. The contents of
        ``ctx.horizontal_executions`` after that visit become the horizontal
        executions of the loop's only section. Every GTIR temporary is
        declared via ``ctx.add_decl``. No caches are attached.
        """
        # Start from an empty execution buffer. NOTE(review): the body visit
        # presumably fills it, e.g. via ctx.add_horizontal_execution.
        ctx.horizontal_executions.clear()
        self.visit(node.body, ctx=ctx)

        for temp in node.temporaries:
            ctx.add_decl(oir.Temporary(name=temp.name, dtype=temp.dtype))

        return oir.VerticalLoop(
            loop_order=node.loop_order,
            sections=[
                oir.VerticalLoopSection(
                    interval=self.visit(node.interval, **kwargs),
                    # Pass a copy: the buffer is cleared in place at the start
                    # of the next visit_VerticalLoop. That would otherwise also
                    # empty this section if the node keeps the same list object.
                    horizontal_executions=list(ctx.horizontal_executions),
                )
            ],
            caches=[],
        )
Example #4
0
def _create_mask(ctx: "GTIRToOIR.Context", name: str, cond: oir.Expr) -> oir.Temporary:
    """Declare the 3D boolean temporary ``name`` and schedule filling it with ``cond``.

    The temporary is registered through ``ctx.add_decl``. A dedicated
    ``oir.HorizontalExecution`` that assigns ``cond`` to the temporary (at
    zero offset, with no local declarations) is then appended through
    ``ctx.add_horizontal_execution``.

    Returns:
        The declaration of the new mask temporary.
    """
    decl = oir.Temporary(name=name, dtype=DataType.BOOL, dimensions=(True, True, True))
    ctx.add_decl(decl)

    target = oir.FieldAccess(name=decl.name, offset=CartesianOffset.zero(), dtype=decl.dtype)
    fill_stmt = oir.AssignStmt(left=target, right=cond)
    ctx.add_horizontal_execution(oir.HorizontalExecution(body=[fill_stmt], declarations=[]))
    return decl
Example #5
0
def sdfg_arrays_to_oir_decls(
        sdfg: dace.SDFG) -> Tuple[List[oir.Decl], List[oir.Temporary]]:
    """Translate the data containers and free symbols of ``sdfg`` into OIR declarations.

    Containers are mapped as follows:

    - non-transient arrays become ``oir.FieldDecl`` parameters;
    - transient arrays become ``oir.Temporary`` declarations;
    - scalars become ``oir.ScalarDecl`` parameters;
    - every SDFG symbol not returned by ``internal_symbols`` also becomes an
      ``oir.ScalarDecl`` parameter.

    Returns:
        ``(params, decls)``. ``params`` holds the parameters in
        ``sdfg.arrays`` order, followed by the symbol parameters. ``decls``
        holds the temporaries.
    """
    params = []
    decls = []

    array: dace.data.Data
    for name, array in sdfg.arrays.items():
        dtype = common.typestr_to_data_type(dace_dtype_to_typestr(array.dtype))
        if not isinstance(array, dace.data.Array):
            # Any non-array container is expected to be a scalar.
            assert isinstance(array, dace.data.Scalar)
            params.append(oir.ScalarDecl(name=name, dtype=dtype))
            continue

        dimensions = array_dimensions(array)
        # Shape entries after the first sum(dimensions) ones are the data dims
        # (assumes `dimensions` is a per-axis bool mask -- TODO confirm).
        data_dims = array.shape[sum(dimensions):]
        if array.transient:
            decls.append(
                oir.Temporary(name=name, dtype=dtype, dimensions=dimensions, data_dims=data_dims))
        else:
            params.append(
                oir.FieldDecl(name=name, dtype=dtype, dimensions=dimensions, data_dims=data_dims))

    reserved_symbols = internal_symbols(sdfg)
    params.extend(
        oir.ScalarDecl(name=sym,
                       dtype=common.typestr_to_data_type(stype.as_numpy_dtype().str))
        for sym, stype in sdfg.symbols.items()
        if sym not in reserved_symbols
    )
    return params, decls
Example #6
0
    def visit_VerticalLoop(
        self,
        node: oir.VerticalLoop,
        *,
        new_tmps: List[oir.Temporary],
        symtable: Dict[str, Any],
        new_symbol_name: Callable[[str], str],
        **kwargs: Any,
    ) -> oir.VerticalLoop:
        """Replace filling/flushing k-caches with plain k-caches on new temporaries.

        Each ``oir.KCache`` with ``fill`` or ``flush`` set gets a new temporary
        name, shared between the fill and the flush of the same field. The
        temporary's declaration is appended to ``new_tmps``. Explicit fill and
        flush statements are generated for every (possibly split) section.
        The cache is then replaced by a ``KCache`` on the temporary with
        ``fill=False, flush=False``. Loops without such caches are returned
        unchanged.

        Args:
            node: vertical loop to transform.
            new_tmps: output list. Declarations of the cache temporaries are
                appended here; they are added to the stencil later.
            symtable: symbol table used to look up the cached fields' dtypes.
            new_symbol_name: generator for new symbol names.
            **kwargs: forwarded to the section visits.
        """
        filling_fields: Dict[str, str] = {
            c.name: new_symbol_name(c.name)
            for c in node.caches
            if isinstance(c, oir.KCache) and c.fill
        }
        flushing_fields: Dict[str, str] = {
            # Reuse the fill temporary, so a filled-and-flushed field maps to one name.
            c.name: filling_fields[c.name] if c.name in filling_fields else new_symbol_name(c.name)
            for c in node.caches
            if isinstance(c, oir.KCache) and c.flush
        }

        # Both maps agree on every shared key, so a plain dict merge loses
        # nothing. Unlike a set union, it also iterates in a deterministic order
        # (cache declaration order). The new temporaries and caches created
        # below inherit that order, so the generated code is reproducible.
        filling_or_flushing_fields: Dict[str, str] = {**filling_fields, **flushing_fields}

        if not filling_or_flushing_fields:
            return node

        # new temporaries used for caches, declarations are later added to stencil
        for field_name, tmp_name in filling_or_flushing_fields.items():
            new_tmps.append(
                oir.Temporary(
                    name=tmp_name, dtype=symtable[field_name].dtype, dimensions=(True, True, True)
                )
            )

        if filling_fields:
            # split sections where more than one fill operations are required at the entry level
            first_unfilled: Dict[str, int] = dict()
            split_sections: List[oir.VerticalLoopSection] = []
            for section in node.sections:
                # NOTE(review): the second return value is discarded and
                # `first_unfilled` is never updated in this loop. In contrast,
                # `_fill_stmts` below threads its state through. Confirm this
                # is intended.
                split_section, previous_fills = self._split_section_with_multiple_fills(
                    node.loop_order, section, filling_fields, first_unfilled, new_symbol_name
                )
                split_sections += split_section
        else:
            split_sections = node.sections

        # generate cache fill and flush statements
        first_unfilled = dict()
        sections = []
        for section in split_sections:
            fills, first_unfilled = self._fill_stmts(
                node.loop_order, section, filling_fields, first_unfilled, symtable
            )
            flushes = self._flush_stmts(node.loop_order, section, flushing_fields, symtable)
            sections.append(
                self.visit(
                    section,
                    fills=fills,
                    flushes=flushes,
                    name_map=filling_or_flushing_fields,
                    symtable=symtable,
                    **kwargs,
                )
            )

        # replace cache declarations
        caches = [c for c in node.caches if c.name not in filling_or_flushing_fields] + [
            oir.KCache(name=f, fill=False, flush=False) for f in filling_or_flushing_fields.values()
        ]

        return oir.VerticalLoop(loop_order=node.loop_order, sections=sections, caches=caches)