def test_binary_op_to_vector_arithmetic():
    """Check that adding two INT32 literals lowers to ``npir.VectorArithmetic``.

    Both operands of the lowered node are expected to be ``npir.BroadCast``
    wrappers around the scalar literals.
    """
    lhs = oir.Literal(dtype=common.DataType.INT32, value="2")
    rhs = oir.Literal(dtype=common.DataType.INT32, value="2")
    addition = oir.BinaryOp(op=common.ArithmeticOperator.ADD, left=lhs, right=rhs)

    lowered = OirToNpir().visit(addition)

    assert isinstance(lowered, npir.VectorArithmetic)
    for operand in (lowered.left, lowered.right):
        assert isinstance(operand, npir.BroadCast)
# Example #2
# 0
def test_field_access_to_field_slice_variablek() -> None:
    """A field access with a variable k offset lowers to a field slice.

    The i/j offsets of the slice must both be zero, and its k offset must
    still carry the literal expression ``"1"``.
    """
    k_expr = oir.Literal(value="1", dtype=common.DataType.INT32)
    access = FieldAccessFactory(offset=oir.VariableKOffset(k=k_expr))

    lowered = OirToNpir().visit(access)

    ij_offsets = (lowered.i_offset, lowered.j_offset)
    assert ij_offsets == (0, 0)
    variable_k = lowered.k_offset.k
    assert variable_k.value == "1"
# Example #3
# 0
def test_literal_broadcast() -> None:
    """Assigning a FLOAT32 literal wraps the right-hand side in ``npir.Broadcast``.

    The wrapped expression must keep the literal's value and dtype.
    """
    literal = oir.Literal(value="42", dtype=common.DataType.FLOAT32)
    assign = AssignStmtFactory(left__dtype=common.DataType.FLOAT32, right=literal)

    lowered_rhs = OirToNpir().visit(assign).right

    assert isinstance(lowered_rhs, npir.Broadcast)
    inner = lowered_rhs.expr
    assert (inner.value, inner.dtype) == ("42", common.DataType.FLOAT32)
def test_literal(broadcast):
    """Lower an INT32 literal, optionally requesting a broadcast.

    When ``broadcast`` is truthy the result must be an ``npir.BroadCast``
    (which is also an ``npir.VectorExpression``) wrapping the lowered
    literal; otherwise the result is compared directly. In both cases
    dtype, kind and value must match the source literal.

    NOTE(review): nothing visible here supplies ``broadcast``; it must come
    from a fixture or a ``pytest.mark.parametrize`` decorator elsewhere.
    """
    source = oir.Literal(value="42", dtype=common.DataType.INT32)
    lowered = OirToNpir().visit(source, broadcast=broadcast)

    if broadcast:
        assert isinstance(lowered, npir.BroadCast)
        assert isinstance(lowered, npir.VectorExpression)
        # Unwrap so the literal checks below apply to the inner node.
        lowered = lowered.expr

    for attr in ("dtype", "kind", "value"):
        assert getattr(source, attr) == getattr(lowered, attr)
def test_cast(broadcast):
    """Lower an INT32 -> FLOAT64 cast, optionally requesting a broadcast.

    When ``broadcast`` is truthy the result must be an ``npir.BroadCast``
    (also an ``npir.VectorExpression``) wrapping the cast; otherwise the
    result must be the cast itself. In both cases the (unwrapped) node must
    be an ``npir.Cast`` with the target dtype and the original literal
    value as its operand.

    NOTE(review): nothing visible here supplies ``broadcast``; it must come
    from a fixture or a ``pytest.mark.parametrize`` decorator elsewhere.
    """
    itof = oir.Cast(
        dtype=common.DataType.FLOAT64,
        expr=oir.Literal(value="42", dtype=common.DataType.INT32),
    )
    result = OirToNpir().visit(itof, broadcast=broadcast)

    if broadcast:
        assert isinstance(result, npir.BroadCast)
        assert isinstance(result, npir.VectorExpression)
        cast = result.expr
    else:
        cast = result
    # Check the unwrapped node's type in both modes; previously the
    # broadcast path never verified that the wrapped node was a Cast.
    assert isinstance(cast, npir.Cast)
    assert cast.dtype == itof.dtype
    assert cast.expr.value == "42"
# Example #6
# 0
 def visit_Literal(self, node: gtir.Literal) -> oir.Literal:
     """Translate a gtir literal into the equivalent oir literal.

     The literal's value is passed through ``self.visit``; dtype, kind and
     source location (``loc``) are copied over unchanged.
     """
     lowered_value = self.visit(node.value)
     return oir.Literal(
         value=lowered_value,
         dtype=node.dtype,
         kind=node.kind,
         loc=node.loc,
     )
# Example #7
# 0
 def visit_Literal(self, node: gtir.Literal, **kwargs: Any) -> oir.Literal:
     """Translate a gtir literal into an oir literal.

     The value is passed through ``self.visit``; dtype and kind are copied.
     Extra keyword arguments are accepted but ignored.
     """
     visited_value = self.visit(node.value)
     # NOTE(review): ``loc`` is not forwarded here, so source-location info
     # is dropped; confirm whether that is intended.
     return oir.Literal(value=visited_value, dtype=node.dtype, kind=node.kind)
# Example #8
# 0
def test_literal_broadcast() -> None:
    """Assigning a FLOAT32 literal wraps the right-hand side in ``npir.Broadcast``.

    The visitor is called with an empty ``local_assigns`` mapping; the
    wrapped expression must keep the literal's value and dtype.
    """
    f32 = common.DataType.FLOAT32
    statement = AssignStmtFactory(
        left__dtype=f32,
        right=oir.Literal(value="42", dtype=f32),
    )

    rhs = OirToNpir().visit(statement, local_assigns={}).right

    assert isinstance(rhs, npir.Broadcast)
    assert (rhs.expr.value, rhs.expr.dtype) == ("42", f32)


@pytest.mark.parametrize(
    "oir_int_expr, npir_type",
    [
        (oir.Literal(value="42", dtype=common.DataType.INT32), npir.ScalarCast),
        (FieldAccessFactory(dtype=common.DataType.INT32), npir.VectorCast),
    ],
)
def test_cast(oir_int_expr, npir_type) -> None:
    """Casting an INT32 expression to FLOAT64 selects the cast node by operand.

    A scalar literal operand is expected to lower to ``npir.ScalarCast``; a
    field-access operand to ``npir.VectorCast``.
    """
    to_float = oir.Cast(dtype=common.DataType.FLOAT64, expr=oir_int_expr)
    lowered = OirToNpir().visit(to_float)
    assert isinstance(lowered, npir_type)


def test_native_func_call() -> None:
    """A native function call whose first argument is a field access lowers to ``npir.NativeFuncCall``."""
    call = NativeFuncCallFactory(args__0=FieldAccessFactory())
    lowered = OirToNpir().visit(call)
    assert isinstance(lowered, npir.NativeFuncCall)