示例#1
0
def test_mask_propagation(parallel_k):
    """The first statement of a translated MaskStmt carries the block's mask."""
    source_stmt = MaskStmtFactory()
    translated_block = OirToNpir().visit(
        source_stmt, ctx=OirToNpir.ComputationContext(), parallel_k=parallel_k
    )
    first_stmt = translated_block.body[0]
    assert first_stmt.mask == translated_block.mask
示例#2
0
def test_mask_stmt_to_mask_block(parallel_k):
    """An empty MaskStmt becomes a block with a FieldSlice mask and an empty body."""
    empty_mask_stmt = MaskStmtFactory(body=[])
    translated_block = OirToNpir().visit(
        empty_mask_stmt, ctx=OirToNpir.ComputationContext(), parallel_k=parallel_k
    )
    assert isinstance(translated_block.mask, npir.FieldSlice)
    assert translated_block.body == []
示例#3
0
def test_field_access_to_field_slice(parallel_k):
    """A Cartesian FieldAccess maps to a FieldSlice with the given parallel-k flag and i offset."""
    access = oir.FieldAccess(
        name="a",
        offset=common.CartesianOffset(i=-1, j=2, k=0),
        dtype=common.DataType.FLOAT64,
    )
    context = OirToNpir.ComputationContext()

    field_slice = OirToNpir().visit(access, ctx=context, parallel_k=parallel_k)

    assert field_slice.k_offset.parallel is parallel_k
    assert field_slice.i_offset.offset.value == -1
示例#4
0
def test_stencil_to_computation() -> None:
    """Translate a single-assignment stencil; check field decls, arguments and pass count."""
    field_param = FieldDeclFactory(name="a", dtype=common.DataType.FLOAT64)
    scalar_param = oir.ScalarDecl(name="b", dtype=common.DataType.INT32)
    body = [
        AssignStmtFactory(left=FieldAccessFactory(name="a"), right=ScalarAccessFactory(name="b"))
    ]
    stencil = StencilFactory(
        name="stencil",
        params=[field_param, scalar_param],
        vertical_loops__0__sections__0__horizontal_executions__0__body=body,
    )

    computation = OirToNpir().visit(stencil)

    # Only the field parameter appears among the API field declarations.
    assert {decl.name for decl in computation.api_field_decls} == {"a"}
    assert set(computation.arguments) == {"a", "b"}
    assert len(computation.vertical_passes) == 1
示例#5
0
def test_vertical_loop_section_to_vertical_pass():
    """A section with no horizontal executions becomes a vertical pass with an empty body."""
    empty_section = VerticalLoopSectionFactory(horizontal_executions=[])
    translated_pass = OirToNpir().visit(empty_section, loop_order=common.LoopOrder.PARALLEL)
    assert translated_pass.body == []
示例#6
0
def test_field_access_to_field_slice_variablek() -> None:
    """A variable-k FieldAccess keeps zero i/j offsets and carries its k expression."""
    k_expression = oir.Literal(value="1", dtype=common.DataType.INT32)
    access = FieldAccessFactory(offset=oir.VariableKOffset(k=k_expression))

    field_slice = OirToNpir().visit(access)

    horizontal_offsets = (field_slice.i_offset, field_slice.j_offset)
    assert horizontal_offsets == (0, 0)
    assert field_slice.k_offset.k.value == "1"
示例#7
0
def make_block_and_transform(**kwargs) -> npir.HorizontalBlock:
    """Build a stencil around ``HorizontalExecutionFactory(**kwargs)`` and translate it.

    Returns the first body entry of the first vertical pass of the resulting computation.
    """
    execution = HorizontalExecutionFactory(**kwargs)
    stencil = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[execution]
    )
    first_pass = OirToNpir().visit(stencil).vertical_passes[0]
    return first_pass.body[0]
示例#8
0
def test_native_func_call():
    """A SQRT NativeFuncCall over a field access translates to a VectorExpression."""
    argument = oir.FieldAccess(
        name="a",
        offset=common.CartesianOffset.zero(),
        dtype=common.DataType.FLOAT64,
    )
    call = oir.NativeFuncCall(func=common.NativeFunction.SQRT, args=[argument])

    translated = OirToNpir().visit(call, parallel_k=True, ctx=OirToNpir.ComputationContext())

    assert isinstance(translated, npir.VectorExpression)
示例#9
0
def test_literal_broadcast() -> None:
    """A literal on the right-hand side of an assignment is wrapped in a Broadcast."""
    literal = oir.Literal(value="42", dtype=common.DataType.FLOAT32)
    assign = AssignStmtFactory(left__dtype=common.DataType.FLOAT32, right=literal)

    rhs = OirToNpir().visit(assign).right

    assert isinstance(rhs, npir.Broadcast)
    assert (rhs.expr.value, rhs.expr.dtype) == ("42", common.DataType.FLOAT32)
示例#10
0
def test_assign_stmt_to_vector_assign(parallel_k):
    """A field-to-field AssignStmt becomes a VectorAssign whose sides share the parallel-k flag."""
    target = oir.FieldAccess(
        name="a", offset=common.CartesianOffset.zero(), dtype=common.DataType.FLOAT64
    )
    source = oir.FieldAccess(
        name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=common.DataType.FLOAT64
    )
    assign_stmt = oir.AssignStmt(left=target, right=source)
    context = OirToNpir.ComputationContext()

    vector_assign = OirToNpir().visit(
        assign_stmt, ctx=context, parallel_k=parallel_k, mask=None
    )

    assert isinstance(vector_assign, npir.VectorAssign)
    for side in (vector_assign.left, vector_assign.right):
        assert side.k_offset.parallel is parallel_k
示例#11
0
def test_binary_op_to_vector_arithmetic():
    """Adding two int literals yields a VectorArithmetic with broadcast operands."""
    lhs_literal = oir.Literal(dtype=common.DataType.INT32, value="2")
    rhs_literal = oir.Literal(dtype=common.DataType.INT32, value="2")
    binop = oir.BinaryOp(op=common.ArithmeticOperator.ADD, left=lhs_literal, right=rhs_literal)

    arithmetic = OirToNpir().visit(binop)

    assert isinstance(arithmetic, npir.VectorArithmetic)
    for operand in (arithmetic.left, arithmetic.right):
        assert isinstance(operand, npir.BroadCast)
示例#12
0
def test_literal(broadcast):
    """An OIR literal keeps dtype, kind and value; it is wrapped in BroadCast when requested."""
    oir_literal = oir.Literal(value="42", dtype=common.DataType.INT32)
    result = OirToNpir().visit(oir_literal, broadcast=broadcast)

    if broadcast:
        assert isinstance(result, npir.BroadCast)
        assert isinstance(result, npir.VectorExpression)
        translated = result.expr
    else:
        translated = result

    assert oir_literal.dtype == translated.dtype
    assert oir_literal.kind == translated.kind
    assert oir_literal.value == translated.value
示例#13
0
def test_cast(broadcast):
    """An int-to-float Cast translates to an npir.Cast, wrapped in BroadCast when requested.

    In both modes the (possibly unwrapped) node must be an ``npir.Cast`` with the
    target dtype and the original literal as its operand.
    """
    int_to_float = oir.Cast(
        dtype=common.DataType.FLOAT64,
        expr=oir.Literal(value="42", dtype=common.DataType.INT32),
    )
    result = OirToNpir().visit(int_to_float, broadcast=broadcast)

    if broadcast:
        assert isinstance(result, npir.BroadCast)
        assert isinstance(result, npir.VectorExpression)
        cast = result.expr
    else:
        cast = result

    # Previously only the non-broadcast branch checked the node type; check the
    # unwrapped node in both modes.
    assert isinstance(cast, npir.Cast)
    assert cast.dtype == int_to_float.dtype
    assert cast.expr.value == "42"
示例#14
0
def test_temp_assign(parallel_k):
    """Assigning to a symbol registered as a temporary records one temp definition.

    The recorded definition has a VectorTemp on the left and an EmptyTemp on the right.
    """
    target = oir.FieldAccess(
        name="a", offset=common.CartesianOffset.zero(), dtype=common.DataType.FLOAT64
    )
    source = oir.FieldAccess(
        name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=common.DataType.FLOAT64
    )
    assign_stmt = oir.AssignStmt(left=target, right=source)
    ctx = OirToNpir.ComputationContext()

    # Only the side effect on ``ctx`` is inspected; the return value is ignored.
    OirToNpir().visit(
        assign_stmt,
        ctx=ctx,
        parallel_k=parallel_k,
        mask=None,
        symtable={"a": TemporaryFactory(name="a")},
    )

    assert len(ctx.temp_defs) == 1
    temp_def = ctx.temp_defs["a"]
    assert isinstance(temp_def.left, npir.VectorTemp)
    assert isinstance(temp_def.right, npir.EmptyTemp)
示例#15
0
 def _make_npir(self) -> npir.Computation:
     """Lower ``self.builder.gtir`` to an NPIR computation via OIR.

     The OIR pass pipeline comes from the ``"oir_pipeline"`` backend option when it
     is set. Otherwise a ``DefaultPipeline`` that skips the passes listed below is used.
     """
     base_oir = GTIRToOIR().visit(self.builder.gtir)
     oir_pipeline = self.builder.options.backend_opts.get(
         "oir_pipeline",
         DefaultPipeline(skip=[
             IJCacheDetection,
             KCacheDetection,
             PruneKCacheFills,
             PruneKCacheFlushes,
             FillFlushToLocalKCaches,
         ]),
     )
     # Do not name this ``oir``: that would shadow the ``oir`` module for the whole
     # function and make any ``oir.X`` reference here raise UnboundLocalError.
     oir_node = oir_pipeline.run(base_oir)
     return OirToNpir().visit(oir_node)
示例#16
0
 def _make_npir(self) -> npir.Computation:
     """Lower ``self.builder.gtir`` to NPIR, then apply ScalarsToTemporaries.

     The OIR pass pipeline comes from the ``"oir_pipeline"`` backend option when it
     is set. Otherwise a ``DefaultPipeline`` that skips the listed passes is used.
     """
     base_oir = GTIRToOIR().visit(self.builder.gtir)
     backend_opts = self.builder.options.backend_opts
     fallback_pipeline = DefaultPipeline(
         skip=[IJCacheDetection, KCacheDetection, PruneKCacheFills, PruneKCacheFlushes]
     )
     pipeline = backend_opts.get("oir_pipeline", fallback_pipeline)
     optimized_oir = pipeline.run(base_oir)
     npir_computation = OirToNpir().visit(optimized_oir)
     return ScalarsToTemporaries().visit(npir_computation)
示例#17
0
def test_stencil_to_computation():
    """Translate a single-assignment stencil; check field params, params and pass count."""
    field_param = FieldDeclFactory(name="a", dtype=common.DataType.FLOAT64)
    scalar_param = oir.ScalarDecl(name="b", dtype=common.DataType.INT32)
    assignment = AssignStmtFactory(
        left=FieldAccessFactory(name="a"), right=ScalarAccessFactory(name="b")
    )
    stencil = StencilFactory(
        name="stencil",
        params=[field_param, scalar_param],
        vertical_loops__0__sections__0__horizontal_executions__0__body=[assignment],
    )

    computation = OirToNpir().visit(stencil)

    assert computation.field_params == ["a"]
    assert computation.params == ["a", "b"]
    assert len(computation.vertical_passes) == 1
示例#18
0
def test_mask_propagation() -> None:
    """The first statement produced from a MaskStmt carries the translated mask."""
    mask_stmt = MaskStmtFactory()
    first_assign = OirToNpir().visit(mask_stmt)[0]
    expected_mask = OirToNpir().visit(mask_stmt.mask)
    assert first_assign.mask == expected_mask
示例#19
0
def test_mask_stmt_to_assigns() -> None:
    """A MaskStmt with one assignment translates to exactly one FieldSlice-masked statement."""
    mask_stmt = MaskStmtFactory(body=[AssignStmtFactory()])
    assign_stmts = OirToNpir().visit(mask_stmt)
    # Check the length before indexing, so an empty result fails with a clear
    # assertion instead of an IndexError.
    assert len(assign_stmts) == 1
    assert isinstance(assign_stmts[0].mask, npir.FieldSlice)
示例#20
0
def test_binary_op_to_npir(oir_node: oir.Expr, npir_type: npir.Expr) -> None:
    """Translating ``oir_node`` produces an instance of ``npir_type``."""
    translated = OirToNpir().visit(oir_node)
    assert isinstance(translated, npir_type)
示例#21
0
def test_field_access_to_field_slice_cartesian() -> None:
    """Cartesian offsets of a FieldAccess become the FieldSlice's i/j/k offsets."""
    access = FieldAccessFactory(offset__i=-1, offset__j=2, offset__k=0)
    translated = OirToNpir().visit(access)
    offsets = (translated.i_offset, translated.j_offset, translated.k_offset)
    assert offsets == (-1, 2, 0)
示例#22
0
def test_native_func_call() -> None:
    """A NativeFuncCall over a field access translates to an npir.NativeFuncCall."""
    call = NativeFuncCallFactory(args__0=FieldAccessFactory())
    translated = OirToNpir().visit(call)
    assert isinstance(translated, npir.NativeFuncCall)
示例#23
0
def test_cast(oir_int_expr, npir_type) -> None:
    """Casting ``oir_int_expr`` to FLOAT64 translates to an instance of ``npir_type``."""
    cast_node = oir.Cast(dtype=common.DataType.FLOAT64, expr=oir_int_expr)
    translated = OirToNpir().visit(cast_node)
    assert isinstance(translated, npir_type)
示例#24
0
def test_horizontal_execution_to_vector_assigns() -> None:
    """An empty HorizontalExecution translates to a block with an empty body."""
    empty_execution = HorizontalExecutionFactory(body=[])
    assert OirToNpir().visit(empty_execution).body == []
示例#25
0
def test_mask_stmt_to_assigns() -> None:
    """A one-assignment MaskStmt yields one statement whose RHS condition is a FieldSlice."""
    mask_stmt = MaskStmtFactory(body=[AssignStmtFactory()])
    assign_stmts = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2))
    # Check the length before indexing, so an empty result fails with a clear
    # assertion instead of an IndexError.
    assert len(assign_stmts) == 1
    assert isinstance(assign_stmts[0].right.cond, npir.FieldSlice)
示例#26
0
def test_mask_propagation() -> None:
    """The MaskStmt's mask becomes the RHS condition of the first translated statement."""
    mask_stmt = MaskStmtFactory()
    first_stmt = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2))[0]
    expected_cond = OirToNpir().visit(mask_stmt.mask)
    assert first_stmt.right.cond == expected_cond
示例#27
0
def test_horizontal_execution_to_vector_assigns():
    """An empty HorizontalExecution translates to a region with an empty body."""
    empty_execution = HorizontalExecutionFactory(body=[])
    region = OirToNpir().visit(empty_execution)
    assert region.body == []
示例#28
0
def test_vertical_loop_to_vertical_passes():
    """A loop whose first section has no horizontal executions yields an empty first pass."""
    loop = VerticalLoopFactory(sections__0__horizontal_executions=[])
    passes = OirToNpir().visit(loop)

    assert passes[0].body == []