Exemplo n.º 1
0
    def _flush_stmts(
        cls,
        loop_order: common.LoopOrder,
        section: oir.VerticalLoopSection,
        flushing_fields: Dict[str, str],
        symtable: Dict[str, Any],
    ) -> List[oir.AssignStmt]:
        """Generate flush statements for the given loop section.

        For every cached field that is written somewhere in ``section``, emit
        an assignment copying the cache value back into the field, both
        accessed at zero offset. Fields that are only read need no flush.

        Args:
            loop_order: forward or backward order (currently unused by this
                method).
            section: loop section to split.
            flushing_fields: mapping from field names to cache names.
            symtable: symbol table; ``symtable[field].dtype`` supplies the
                dtype used for both the field access and its cache access.

        Returns:
            A list of flush statements, ordered like ``flushing_fields``.
        """
        write_fields = AccessCollector.apply(section).write_fields()
        return [
            oir.AssignStmt(
                left=oir.FieldAccess(
                    name=field,
                    dtype=symtable[field].dtype,
                    offset=common.CartesianOffset.zero(),
                ),
                right=oir.FieldAccess(
                    name=cache,
                    # The cache shares the dtype of the field it mirrors.
                    dtype=symtable[field].dtype,
                    offset=common.CartesianOffset.zero(),
                ),
            )
            for field, cache in flushing_fields.items()
            if field in write_fields
        ]
Exemplo n.º 2
0
    def visit_FieldIfStmt(self,
                          node: gtir.FieldIfStmt,
                          *,
                          mask: oir.Expr = None,
                          ctx: Context,
                          **kwargs: Any) -> List[oir.Stmt]:
        """Lower a field-wise ``if`` into masked OIR statements.

        The condition is evaluated once into a fresh boolean temporary, which
        is registered in ``ctx.temp_fields``. The true branch is then visited
        with that temporary as its mask. The false branch, if any, is visited
        with the temporary's negation. When an enclosing ``mask`` is given, it
        is AND-ed in front of both branch masks.

        Args:
            node: the GTIR field-if statement to lower.
            mask: mask of an enclosing conditional, or ``None`` at top level.
            ctx: lowering context collecting the new temporary declaration.
            **kwargs: forwarded unchanged when visiting the branch bodies.

        Returns:
            The mask assignment, followed by the lowered true-branch
            statements and, if present, the lowered false-branch statements.
        """
        # id() is unique among live objects, so each FieldIfStmt in the tree
        # gets its own mask temporary.
        mask_field_decl = oir.Temporary(name=f"mask_{id(node)}",
                                        dtype=DataType.BOOL,
                                        dimensions=(True, True, True))
        ctx.temp_fields.append(mask_field_decl)

        def restrict(cond: oir.Expr) -> oir.Expr:
            # Nested conditionals must also honour the enclosing mask.
            if mask is None:
                return cond
            return oir.BinaryOp(op=LogicalOperator.AND,
                                left=mask,
                                right=cond,
                                loc=node.loc)

        stmts = [
            oir.AssignStmt(
                left=oir.FieldAccess(
                    name=mask_field_decl.name,
                    offset=CartesianOffset.zero(),
                    dtype=DataType.BOOL,
                    loc=node.loc,
                ),
                right=self.visit(node.cond),
            )
        ]

        current_mask = oir.FieldAccess(
            name=mask_field_decl.name,
            offset=CartesianOffset.zero(),
            dtype=mask_field_decl.dtype,
            loc=node.loc,
        )
        stmts.extend(
            self.visit(node.true_branch.body,
                       mask=restrict(current_mask),
                       ctx=ctx,
                       **kwargs))

        if node.false_branch:
            # Carry the source location like every other node built here.
            inverted_mask = oir.UnaryOp(op=UnaryOperator.NOT,
                                        expr=current_mask,
                                        loc=node.loc)
            stmts.extend(
                self.visit(node.false_branch.body,
                           mask=restrict(inverted_mask),
                           ctx=ctx,
                           **kwargs))

        return stmts
Exemplo n.º 3
0
def test_assign_stmt_to_vector_assign(parallel_k):
    # Left side: plain access to "a"; right side: "b" shifted in i and j.
    lhs = oir.FieldAccess(
        name="a", offset=common.CartesianOffset.zero(), dtype=common.DataType.FLOAT64
    )
    rhs = oir.FieldAccess(
        name="b",
        offset=common.CartesianOffset(i=-1, j=22, k=0),
        dtype=common.DataType.FLOAT64,
    )
    assign_stmt = oir.AssignStmt(left=lhs, right=rhs)

    result = OirToNpir().visit(
        assign_stmt,
        ctx=OirToNpir.ComputationContext(),
        parallel_k=parallel_k,
        mask=None,
    )

    assert isinstance(result, npir.VectorAssign)
    # Both sides must inherit the requested k-parallelism.
    for side in (result.left, result.right):
        assert side.k_offset.parallel is parallel_k
Exemplo n.º 4
0
 def visit_FieldAccess(self, node: gtir.FieldAccess, **kwargs: Any) -> oir.FieldAccess:
     """Translate a GTIR field access into its OIR counterpart.

     The offset and data index are lowered recursively; the name and dtype are
     copied over. Extra keyword arguments are accepted but not used.
     """
     # NOTE(review): node.loc is not propagated here — confirm whether the
     # source location should be kept on the OIR node.
     lowered_offset = self.visit(node.offset)
     lowered_index = self.visit(node.data_index)
     return oir.FieldAccess(name=node.name,
                            offset=lowered_offset,
                            data_index=lowered_index,
                            dtype=node.dtype)
Exemplo n.º 5
0
    def visit_FieldIfStmt(self,
                          node: gtir.FieldIfStmt,
                          *,
                          mask: oir.Expr = None,
                          ctx: Context,
                          **kwargs: Any) -> None:
        """Lower a field-wise ``if`` by visiting each branch under a mask.

        ``_create_mask`` stores the lowered condition in a boolean temporary
        named after ``node.id_``. The true branch is visited under that mask,
        and the false branch (if any) under its negation. An enclosing
        ``mask``, when truthy, is AND-ed in front of both. Only ``mask`` and
        ``ctx`` are passed on to the branch visits; extra keyword arguments
        are not forwarded.
        """
        cond_decl = _create_mask(ctx, f"mask_{node.id_}", self.visit(node.cond))
        cond_access = oir.FieldAccess(name=cond_decl.name,
                                      offset=CartesianOffset.zero(),
                                      dtype=cond_decl.dtype)

        def with_outer(inner):
            # Combine with the enclosing mask only when one is present.
            if not mask:
                return inner
            return oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=inner)

        self.visit(node.true_branch.body, mask=with_outer(cond_access), ctx=ctx)

        if not node.false_branch:
            return
        negated = oir.UnaryOp(op=UnaryOperator.NOT, expr=cond_access)
        self.visit(node.false_branch.body, mask=with_outer(negated), ctx=ctx)
Exemplo n.º 6
0
 def visit_FieldAccess(self, node: gtir.FieldAccess) -> oir.FieldAccess:
     """Lower a GTIR field access to OIR, keeping its source location.

     The offset and data index are visited recursively; the name, dtype and
     loc are copied from ``node`` unchanged.
     """
     lowered_offset = self.visit(node.offset)
     lowered_index = self.visit(node.data_index)
     return oir.FieldAccess(
         name=node.name,
         offset=lowered_offset,
         data_index=lowered_index,
         dtype=node.dtype,
         loc=node.loc,
     )
Exemplo n.º 7
0
 def visit_FieldAccess(
     self, node: oir.FieldAccess, *, name_map: Dict[str, str], **kwargs: Any
 ) -> oir.FieldAccess:
     """Rename accesses to fields listed in ``name_map``.

     Accesses whose field name is not a key of ``name_map`` are returned as-is
     (the same object). Otherwise a new access is built with the mapped name
     and the original dtype and offset.
     """
     if node.name not in name_map:
         return node
     # NOTE(review): only name, dtype and offset are carried over; any other
     # attributes of ``node`` (e.g. data_index or loc, if the node type has
     # them) are dropped — confirm that is intended.
     renamed = name_map[node.name]
     return oir.FieldAccess(name=renamed, dtype=node.dtype, offset=node.offset)
Exemplo n.º 8
0
 def visit_FieldAccess(self, node: oir.FieldAccess):
     """Redirect an access through ``self._field_table`` when its field is listed.

     Listed fields become a new access with the mapped name, keeping the
     offset, dtype and data_index. Unlisted fields are returned unchanged
     (the same object).
     """
     if node.name in self._field_table:
         return oir.FieldAccess(
             name=self._field_table[node.name],
             offset=node.offset,
             dtype=node.dtype,
             data_index=node.data_index,
         )
     return node
Exemplo n.º 9
0
    def _fill_stmts(
        cls,
        loop_order: common.LoopOrder,
        section: oir.VerticalLoopSection,
        filling_fields: Dict[str, str],
        first_unfilled: Dict[str, int],
        symtable: Dict[str, Any],
    ) -> Tuple[List[oir.AssignStmt], Dict[str, int]]:
        """Generate fill statements for the given loop section.

        For each cached field, copy every k-offset between the first unfilled
        entry and the section's upper fill limit from the field into its cache.

        Args:
            loop_order: forward or backward order; for non-forward loops the
                direction-normalized offsets are negated.
            section: loop section to split.
            filling_fields: mapping from field names to cache names.
            first_unfilled: direction-normalized offset of the first unfilled
                cache entry for each field. It is updated in place and also
                returned.
            symtable: symbol table; ``symtable[field].dtype`` supplies the
                dtype used for both the field access and its cache access.

        Returns:
            A list of fill statements and the (same, updated) `first_unfilled` map.
        """
        limits = cls._fill_limits(loop_order, section)
        direction = 1 if loop_order == common.LoopOrder.FORWARD else -1
        statements = []
        for field, cache in filling_fields.items():
            lower, upper = limits.get(field, (0, 0))
            # Never start below the first entry not yet filled.
            lower = max(lower, first_unfilled.get(field, lower))
            for normalized in range(lower, upper + 1):
                # One offset object is shared by the cache and field access.
                k_offset = common.CartesianOffset(i=0, j=0, k=direction * normalized)
                dtype = symtable[field].dtype
                statements.append(
                    oir.AssignStmt(
                        left=oir.FieldAccess(name=cache, dtype=dtype, offset=k_offset),
                        right=oir.FieldAccess(name=field, dtype=dtype, offset=k_offset),
                    )
                )
            # NOTE(review): stores `upper` rather than `upper + 1`; presumably
            # this accounts for offsets shifting by one level between sections
            # — confirm.
            first_unfilled[field] = upper
        return statements, first_unfilled
Exemplo n.º 10
0
def test_field_access_to_field_slice(parallel_k):
    shifted = common.CartesianOffset(i=-1, j=2, k=0)
    field_access = oir.FieldAccess(name="a", offset=shifted, dtype=common.DataType.FLOAT64)

    translator = OirToNpir()
    field_slice = translator.visit(
        field_access, ctx=OirToNpir.ComputationContext(), parallel_k=parallel_k
    )

    # The k dimension follows the requested parallelism; the i shift is kept.
    assert field_slice.k_offset.parallel is parallel_k
    assert field_slice.i_offset.offset.value == -1
Exemplo n.º 11
0
def test_temp_assign(parallel_k):
    target = oir.FieldAccess(
        name="a", offset=common.CartesianOffset.zero(), dtype=common.DataType.FLOAT64
    )
    source = oir.FieldAccess(
        name="b",
        offset=common.CartesianOffset(i=-1, j=22, k=0),
        dtype=common.DataType.FLOAT64,
    )
    ctx = OirToNpir.ComputationContext()

    # "a" is declared as a temporary in the symtable, so assigning to it is
    # expected to record exactly one temp definition in the context.
    OirToNpir().visit(
        oir.AssignStmt(left=target, right=source),
        ctx=ctx,
        parallel_k=parallel_k,
        mask=None,
        symtable={"a": TemporaryFactory(name="a")},
    )

    assert len(ctx.temp_defs) == 1
    temp_def = ctx.temp_defs["a"]
    assert isinstance(temp_def.left, npir.VectorTemp)
    assert isinstance(temp_def.right, npir.EmptyTemp)
Exemplo n.º 12
0
def test_native_func_call():
    arg = oir.FieldAccess(
        name="a", offset=common.CartesianOffset.zero(), dtype=common.DataType.FLOAT64
    )
    sqrt_call = oir.NativeFuncCall(func=common.NativeFunction.SQRT, args=[arg])

    ctx = OirToNpir.ComputationContext()
    lowered = OirToNpir().visit(sqrt_call, parallel_k=True, ctx=ctx)

    # A native function call is expected to lower to a vector expression.
    assert isinstance(lowered, npir.VectorExpression)
Exemplo n.º 13
0
def _create_mask(ctx: "GTIRToOIR.Context", name: str, cond: oir.Expr) -> oir.Temporary:
    """Declare a boolean mask temporary and schedule its computation.

    Registers an ``oir.Temporary`` called ``name`` with dtype BOOL and
    ``dimensions=(True, True, True)`` on ``ctx``. It then adds a horizontal
    execution whose single statement assigns ``cond`` to that temporary at
    zero offset.

    Returns:
        The newly declared mask temporary.
    """
    decl = oir.Temporary(name=name, dtype=DataType.BOOL, dimensions=(True, True, True))
    ctx.add_decl(decl)

    assign_mask = oir.AssignStmt(
        left=oir.FieldAccess(name=decl.name, offset=CartesianOffset.zero(), dtype=decl.dtype),
        right=cond,
    )
    ctx.add_horizontal_execution(
        oir.HorizontalExecution(body=[assign_mask], declarations=[])
    )
    return decl
Exemplo n.º 14
0
 def visit_FieldAccess(self, node: gtir.FieldAccess, **kwargs: Any) -> oir.FieldAccess:
     """Convert a GTIR field access to OIR by copying name, offset and dtype.

     The offset object is reused as-is rather than visited; extra keyword
     arguments are ignored.
     """
     # NOTE(review): other attributes of ``node`` (e.g. data_index or loc, if
     # present on this node type) are not transferred — confirm this is intended.
     name, offset, dtype = node.name, node.offset, node.dtype
     return oir.FieldAccess(name=name, offset=offset, dtype=dtype)