示例#1
0
def test_mask_propagation(parallel_k):
    """Lowering a ``MaskStmt`` hands the block's mask to its first inner statement.

    The default ``MaskStmtFactory`` product is lowered with ``OirToNpir``. The
    first statement of the resulting block must compare equal, on ``mask``, to
    the block itself.

    NOTE(review): indexes ``body[0]``, so this assumes the factory's default body
    is non-empty — confirm against ``MaskStmtFactory``.
    """
    source_stmt = MaskStmtFactory()
    lowered = OirToNpir().visit(
        source_stmt,
        ctx=OirToNpir.ComputationContext(),
        parallel_k=parallel_k,
    )
    first_inner = lowered.body[0]
    assert first_inner.mask == lowered.mask
示例#2
0
def test_mask_stmt_to_mask_block(parallel_k):
    """Lower a ``MaskStmt`` whose body is empty and inspect the resulting block.

    The lowered block's ``mask`` must be an ``npir.FieldSlice``, and its ``body``
    must compare equal to an empty list.
    """
    empty_mask_stmt = MaskStmtFactory(body=[])
    context = OirToNpir.ComputationContext()
    result = OirToNpir().visit(empty_mask_stmt, ctx=context, parallel_k=parallel_k)

    assert isinstance(result.mask, npir.FieldSlice)
    assert result.body == []
示例#3
0
def test_field_access_to_field_slice(parallel_k):
    """Lower an ``oir.FieldAccess`` with offset (i=-1, j=2, k=0).

    The result's k-offset must carry a ``parallel`` flag identical to
    ``parallel_k``. Its i-offset value must compare equal to ``-1``.
    """
    cartesian = common.CartesianOffset(i=-1, j=2, k=0)
    access = oir.FieldAccess(name="a", offset=cartesian, dtype=common.DataType.FLOAT64)

    field_slice = OirToNpir().visit(
        access, ctx=OirToNpir.ComputationContext(), parallel_k=parallel_k
    )

    assert field_slice.k_offset.parallel is parallel_k
    assert field_slice.i_offset.offset.value == -1
示例#4
0
def test_assign_stmt_to_vector_assign(parallel_k):
    """Lower a field-to-field ``oir.AssignStmt`` into an ``npir.VectorAssign``.

    On the result, both the left and right sides must have a k-offset whose
    ``parallel`` flag is identical to ``parallel_k``.
    """
    f64 = common.DataType.FLOAT64
    target = oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=f64)
    source = oir.FieldAccess(
        name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=f64
    )
    stmt = oir.AssignStmt(left=target, right=source)

    lowered = OirToNpir().visit(
        stmt, ctx=OirToNpir.ComputationContext(), parallel_k=parallel_k, mask=None
    )

    assert isinstance(lowered, npir.VectorAssign)
    # Left side is checked first, then right, matching a per-side assertion order.
    for side in (lowered.left, lowered.right):
        assert side.k_offset.parallel is parallel_k
示例#5
0
def test_native_func_call():
    """Lower an ``oir.NativeFuncCall`` applying ``SQRT`` to a single field.

    Lowering runs with ``parallel_k=True``. The result must be an
    ``npir.VectorExpression``.
    """
    operand = oir.FieldAccess(
        name="a", offset=common.CartesianOffset.zero(), dtype=common.DataType.FLOAT64
    )
    call = oir.NativeFuncCall(func=common.NativeFunction.SQRT, args=[operand])

    lowered = OirToNpir().visit(call, parallel_k=True, ctx=OirToNpir.ComputationContext())

    assert isinstance(lowered, npir.VectorExpression)
示例#6
0
def test_temp_assign(parallel_k):
    """Lowering an assignment whose target is in ``symtable`` records a temp definition.

    ``symtable`` maps ``a`` to a ``TemporaryFactory`` instance. After lowering,
    ``ctx.temp_defs`` must hold exactly one entry. The entry keyed ``"a"`` must
    have an ``npir.VectorTemp`` on the left and an ``npir.EmptyTemp`` on the
    right.
    """
    f64 = common.DataType.FLOAT64
    stmt = oir.AssignStmt(
        left=oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=f64),
        right=oir.FieldAccess(
            name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=f64
        ),
    )
    context = OirToNpir.ComputationContext()
    symbols = {"a": TemporaryFactory(name="a")}

    # Only the side effect on ``context`` is inspected; the return value is unused.
    OirToNpir().visit(
        stmt,
        ctx=context,
        parallel_k=parallel_k,
        mask=None,
        symtable=symbols,
    )

    assert len(context.temp_defs) == 1
    temp_def = context.temp_defs["a"]
    assert isinstance(temp_def.left, npir.VectorTemp)
    assert isinstance(temp_def.right, npir.EmptyTemp)