Example #1
0
def test_scalarliteral(defined_dtype: common.DataType) -> None:
    """Check that a scalar literal is rendered as ``<dtype-text>(42)``.

    The text in front of ``(42)`` is handed to ``match_dtype`` together with
    ``defined_dtype`` so that helper can validate the dtype spelling.
    """
    literal = npir.ScalarLiteral(dtype=defined_dtype, value="42")
    generated = NpirCodegen().visit(literal)
    print(generated)
    found = re.match(r"(.+?)\(42\)", generated)
    assert found
    match_dtype(found.group(1), defined_dtype)
Example #2
0
def test_broadcast_literal(defined_dtype: common.DataType,
                           is_serial: bool) -> None:
    """Check that a broadcast scalar literal still renders ``<dtype-text>(42)``.

    The text in front of ``(42)`` is validated against ``defined_dtype`` via
    ``match_dtype``.

    NOTE(review): ``is_serial`` is accepted (presumably supplied by a pytest
    fixture/parametrization) but is not used here — confirm whether it should
    be forwarded to the code generator.
    """
    broadcast = npir.Broadcast(
        expr=npir.ScalarLiteral(dtype=defined_dtype, value="42"))
    code = NpirCodegen().visit(broadcast)
    print(code)
    hit = re.match(r"(.*?)\(42\)", code)
    assert hit
    match_dtype(hit.group(1), defined_dtype)
Example #3
0
def test_scalar_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """Check that a scalar cast renders as ``np.<other>(np.<defined>(42))``.

    The captured dtype names must equal the lower-cased enum member names of
    ``other_dtype`` (outer call) and ``defined_dtype`` (inner literal).
    """
    result = NpirCodegen().visit(
        npir.ScalarCast(dtype=other_dtype, expr=npir.ScalarLiteral(dtype=defined_dtype, value="42"))
    )
    print(result)
    # Both "np." prefixes are escaped; an unescaped "." would accept any
    # character after "np" and let malformed output slip through.
    match = re.match(r"np\.(?P<other_dtype>\w+)\(np\.(?P<defined_dtype>\w+)\(42\)\)", result)
    assert match
    assert match.group("defined_dtype") == defined_dtype.name.lower()
    assert match.group("other_dtype") == other_dtype.name.lower()
Example #4
0
def test_scalar_cast(defined_dtype: common.DataType,
                     other_dtype: common.DataType) -> None:
    """Check that a scalar cast renders as ``<other>(<defined>(42))``.

    Both captured dtype spellings are validated with ``match_dtype``: first the
    inner literal's dtype, then the cast target dtype.
    """
    inner = npir.ScalarLiteral(dtype=defined_dtype, value="42")
    cast = npir.ScalarCast(dtype=other_dtype, expr=inner)
    rendered = NpirCodegen().visit(cast)
    print(rendered)
    pattern = r"(?P<other_dtype>.+)\((?P<defined_dtype>.+)\(42\)\)"
    m = re.match(pattern, rendered)
    assert m
    for group_name, expected in (("defined_dtype", defined_dtype),
                                 ("other_dtype", other_dtype)):
        match_dtype(m.group(group_name), expected)
Example #5
0
def test_broadcast_literal(defined_dtype: common.DataType, is_serial: bool) -> None:
    """Check the ``np.full`` rendering of a broadcast scalar literal.

    With ``lower=(0, 0)`` and ``upper=(0, 0)`` the I and J extents added to
    ``_dI_``/``_dJ_`` must both be ``0``. The K bound must be ``1`` when
    ``is_serial`` is true and ``K - k`` otherwise. The fill value must be
    ``np.<dtype>(42)`` with the lower-cased ``defined_dtype`` name.
    """
    result = NpirCodegen().visit(
        npir.Broadcast(expr=npir.ScalarLiteral(dtype=defined_dtype, value="42")),
        is_serial=is_serial,
        lower=(0, 0),
        upper=(0, 0),
    )
    print(result)
    match = re.match(
        r"np\.full\(\(_dI_\s*\+\s*(?P<iext>\d+)\s*,\s*_dJ_\s*\+\s*(?P<jext>\d+)\s*,\s*(?P<kbounds>[^\)]+)\),\s*np\.(?P<dtype>\w+)\(42\)\)",
        result,
    )
    assert match
    # Match.group with several names returns a tuple of the captured values.
    assert match.group("iext", "jext") == ("0", "0")
    # Compute the expectation first: written inline as
    # ``assert x == "1" if is_serial else "K - k"`` the conditional binds
    # loosest, so the non-serial branch asserted a non-empty string and
    # could never fail.
    expected_kbounds = "1" if is_serial else "K - k"
    assert match.group("kbounds") == expected_kbounds
    assert match.group("dtype") == defined_dtype.name.lower()
Example #6
0
def test_scalarliteral(defined_dtype: common.DataType) -> None:
    """Check that a scalar literal renders as ``np.<dtype>(42)``.

    The captured dtype name must equal the lower-cased enum member name of
    ``defined_dtype``.
    """
    result = NpirCodegen().visit(npir.ScalarLiteral(dtype=defined_dtype, value="42"))
    print(result)
    # Escape the dot: a bare "." after "np" would accept any character.
    match = re.match(r"np\.(\w+)\(42\)", result)
    assert match
    assert match.group(1) == defined_dtype.name.lower()