예제 #1
0
def test_mask_propagation(parallel_k):
    """Translating a MaskStmt gives its first body statement the block's own mask."""
    stmt = MaskStmtFactory()
    ctx = OirToNpir.ComputationContext()
    block = OirToNpir().visit(stmt, ctx=ctx, parallel_k=parallel_k)

    first_stmt = block.body[0]
    assert first_stmt.mask == block.mask
예제 #2
0
def test_mask_stmt_to_mask_block(parallel_k):
    """An empty-bodied MaskStmt becomes a block with a FieldSlice mask and an empty body."""
    empty_mask_stmt = MaskStmtFactory(body=[])
    translated = OirToNpir().visit(
        empty_mask_stmt, ctx=OirToNpir.ComputationContext(), parallel_k=parallel_k
    )

    assert isinstance(translated.mask, npir.FieldSlice)
    assert translated.body == []
예제 #3
0
def test_field_access_to_field_slice(parallel_k):
    """An OIR FieldAccess is translated to an NPIR FieldSlice carrying its offsets.

    Checks the result type, the field name, the i and j offsets, and that the
    k offset's ``parallel`` flag follows the ``parallel_k`` argument.
    """
    field_access = oir.FieldAccess(
        name="a",
        offset=common.CartesianOffset(i=-1, j=2, k=0),
        dtype=common.DataType.FLOAT64,
    )

    ctx = OirToNpir.ComputationContext()
    field_slice = OirToNpir().visit(field_access, ctx=ctx, parallel_k=parallel_k)

    assert isinstance(field_slice, npir.FieldSlice)
    assert field_slice.name == "a"
    assert field_slice.k_offset.parallel is parallel_k
    assert field_slice.i_offset.offset.value == -1
    # The non-zero j offset must be carried over too, not just the i offset.
    assert field_slice.j_offset.offset.value == 2
예제 #4
0
def test_assign_stmt_to_vector_assign(parallel_k):
    """An AssignStmt becomes a VectorAssign whose both sides honour ``parallel_k``."""
    float64 = common.DataType.FLOAT64
    lhs = oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=float64)
    rhs = oir.FieldAccess(name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=float64)
    assign_stmt = oir.AssignStmt(left=lhs, right=rhs)

    result = OirToNpir().visit(
        assign_stmt,
        ctx=OirToNpir.ComputationContext(),
        parallel_k=parallel_k,
        mask=None,
    )

    assert isinstance(result, npir.VectorAssign)
    # Left side is checked first, then the right side.
    for side in (result.left, result.right):
        assert side.k_offset.parallel is parallel_k
예제 #5
0
def test_native_func_call():
    """A NativeFuncCall (sqrt of a field access) translates to a vector expression."""
    arg = oir.FieldAccess(
        name="a",
        offset=common.CartesianOffset.zero(),
        dtype=common.DataType.FLOAT64,
    )
    sqrt_call = oir.NativeFuncCall(func=common.NativeFunction.SQRT, args=[arg])

    translated = OirToNpir().visit(
        sqrt_call, ctx=OirToNpir.ComputationContext(), parallel_k=True
    )
    assert isinstance(translated, npir.VectorExpression)
예제 #6
0
def test_binary_op_to_vector_arithmetic():
    """Adding two scalar literals gives VectorArithmetic with broadcast operands."""

    def int32_two():
        # A fresh INT32 literal "2" for each operand.
        return oir.Literal(dtype=common.DataType.INT32, value="2")

    binop = oir.BinaryOp(op=common.ArithmeticOperator.ADD, left=int32_two(), right=int32_two())
    arithmetic = OirToNpir().visit(binop)

    assert isinstance(arithmetic, npir.VectorArithmetic)
    for operand in (arithmetic.left, arithmetic.right):
        assert isinstance(operand, npir.BroadCast)
예제 #7
0
def test_literal(broadcast):
    """A Literal keeps dtype, kind and value; with ``broadcast`` it is wrapped in BroadCast."""
    oir_literal = oir.Literal(value="42", dtype=common.DataType.INT32)
    translated = OirToNpir().visit(oir_literal, broadcast=broadcast)

    if broadcast:
        assert isinstance(translated, npir.BroadCast)
        assert isinstance(translated, npir.VectorExpression)
        inner = translated.expr
    else:
        inner = translated

    for attr in ("dtype", "kind", "value"):
        assert getattr(oir_literal, attr) == getattr(inner, attr)
예제 #8
0
def test_cast(broadcast):
    """An int-to-float Cast keeps its dtype and operand; ``broadcast`` wraps it in BroadCast."""
    int_literal = oir.Literal(value="42", dtype=common.DataType.INT32)
    itof = oir.Cast(dtype=common.DataType.FLOAT64, expr=int_literal)
    translated = OirToNpir().visit(itof, broadcast=broadcast)

    if broadcast:
        assert isinstance(translated, npir.BroadCast)
        assert isinstance(translated, npir.VectorExpression)
        cast = translated.expr
    else:
        assert isinstance(translated, npir.Cast)
        cast = translated

    assert cast.dtype == itof.dtype
    assert cast.expr.value == "42"
예제 #9
0
def test_temp_assign(parallel_k):
    """Assigning to a symbol listed as a temporary records exactly one entry in ``ctx.temp_defs``.

    The recorded definition has a VectorTemp on the left and an EmptyTemp on the right.
    """
    float64 = common.DataType.FLOAT64
    stmt = oir.AssignStmt(
        left=oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=float64),
        right=oir.FieldAccess(
            name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=float64
        ),
    )
    ctx = OirToNpir.ComputationContext()
    symtable = {"a": TemporaryFactory(name="a")}

    # Only the side effect on ctx matters here; the returned node is ignored.
    OirToNpir().visit(stmt, ctx=ctx, parallel_k=parallel_k, mask=None, symtable=symtable)

    assert len(ctx.temp_defs) == 1
    temp_def = ctx.temp_defs["a"]
    assert isinstance(temp_def.left, npir.VectorTemp)
    assert isinstance(temp_def.right, npir.EmptyTemp)
예제 #10
0
def test_stencil_to_computation():
    """A stencil with one field and one scalar parameter becomes a one-pass Computation."""
    params = [
        FieldDeclFactory(name="a", dtype=common.DataType.FLOAT64),
        oir.ScalarDecl(name="b", dtype=common.DataType.INT32),
    ]
    body = [
        AssignStmtFactory(left=FieldAccessFactory(name="a"), right=ScalarAccessFactory(name="b"))
    ]
    stencil = StencilFactory(
        name="stencil",
        params=params,
        vertical_loops__0__sections__0__horizontal_executions__0__body=body,
    )

    computation = OirToNpir().visit(stencil)

    assert computation.field_params == ["a"]
    assert computation.params == ["a", "b"]
    assert len(computation.vertical_passes) == 1
예제 #11
0
def test_horizontal_execution_to_vector_assigns():
    """An empty HorizontalExecution translates to a region with an empty body."""
    region = OirToNpir().visit(HorizontalExecutionFactory(body=[]))
    assert region.body == []
예제 #12
0
def test_vertical_loop_section_to_vertical_pass():
    """A section without horizontal executions yields a vertical pass with an empty body."""
    section = VerticalLoopSectionFactory(horizontal_executions=[])
    translator = OirToNpir()
    vertical_pass = translator.visit(section, loop_order=common.LoopOrder.PARALLEL)

    assert vertical_pass.body == []
예제 #13
0
def test_vertical_loop_to_vertical_passes():
    """A loop whose first section is empty yields a first vertical pass with an empty body."""
    loop = VerticalLoopFactory(sections__0__horizontal_executions=[])
    passes = OirToNpir().visit(loop)

    first_pass = passes[0]
    assert first_pass.body == []
예제 #14
0
 def _make_npir(self) -> npir.Computation:
     """Lower ``self.builder.gtir`` to NPIR, going through OIR.

     The OIR stage is run through an ``OirPipeline`` with an empty pass list,
     so no OIR optimizations are applied yet.
     """
     oir_stencil = GTIRToOIR().visit(self.builder.gtir)
     # TODO (ricoh) apply optimizations, skip only the ones that fail
     pipelined = OirPipeline(oir_stencil).apply([])
     return OirToNpir().visit(pipelined)