示例#1
0
def test_assign_with_mask_local() -> None:
    """A masked assignment to a local scalar is emitted via ``np.where``.

    The generated line must start with ``tmp = np.where(mask1`` and close its
    argument list with ``np.int32()``, the dtype declared for ``tmp`` in the
    symbol table.
    """
    result = NpirCodegen().visit(
        VectorAssignFactory(
            left=LocalScalarAccessFactory(name="tmp"),
            mask=FieldSliceFactory(name="mask1", dtype=common.DataType.BOOL),
        ),
        ctx=NpirCodegen.BlockContext(),
        symtable={"tmp": ScalarDeclFactory(name="tmp", dtype=common.DataType.INT32)},
    )
    print(result)
    # Escape the literal dots so "np." cannot match an arbitrary character.
    assert re.match(r"tmp = np\.where\(mask1.*, np\.int32\(\)\)", result) is not None
示例#2
0
def test_mask_block_other() -> None:
    """A logical-AND mask is assigned to ``mask1_`` via ``np.bitwise_and``."""
    codegen = NpirCodegen()
    and_mask = npir.VectorLogic(
        op=common.LogicalOperator.AND,
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
    )
    block = npir.MaskBlock(body=[], mask=and_mask, mask_name="mask1")
    generated = codegen.visit(block)
    assert generated.startswith("mask1_ = np.bitwise_and(a_[i:I")
示例#3
0
def test_broadcast_literal(defined_dtype: common.DataType) -> None:
    """Broadcasting a literal renders as ``np.<dtype>(42)`` for the literal's dtype."""
    result = NpirCodegen().visit(
        npir.BroadCast(expr=npir.Literal(dtype=defined_dtype, value="42")))
    print(result)
    # ``\.`` pins the literal "np." prefix; an unescaped dot matches any character.
    match = re.match(r"np\.(\w*?)\(42\)", result)
    assert match
    assert match.group(1) == defined_dtype.name.lower()
示例#4
0
def test_vector_unary_not() -> None:
    """Logical NOT on a field slice renders as a parenthesized ``np.bitwise_not``."""
    codegen = NpirCodegen()
    negation = npir.VectorUnaryOp(
        expr=FieldSliceFactory(name="a"),
        op=common.UnaryOperator.NOT,
    )
    expected = "(np.bitwise_not(a_[i:I, j:J, k_:(k_ + 1)]))"
    assert codegen.visit(negation) == expected
示例#5
0
def test_vector_assign() -> None:
    """Assigning one field slice to another renders as a plain slice assignment."""
    codegen = NpirCodegen()
    assignment = VectorAssignFactory(left__name="a", right__name="b")
    code = codegen.visit(assignment)
    assert code == "a_[i:I, j:J, k_:(k_ + 1)] = b_[i:I, j:J, k_:(k_ + 1)]"
示例#6
0
def test_full_computation_valid(tmp_path) -> None:
    """Generated source for a full computation imports and runs correctly.

    The computation assigns ``a = b + p``. With ``b == 3`` and ``p == 2``,
    every point of ``a`` inside the domain (shifted by ``a``'s origin) must
    become ``5``.
    """
    computation = ComputationFactory(
        vertical_passes__0__body__0__body__0=VectorAssignFactory(
            left__name="a",
            right=VectorArithmeticFactory(left__name="b",
                                          right=ParamAccessFactory(name="p"),
                                          op=common.ArithmeticOperator.ADD),
        ),
        param_decls=[ScalarDeclFactory(name="p")],
    )
    result = NpirCodegen().visit(computation)
    print(result)
    mod_path = tmp_path / "npir_codegen_1.py"
    mod_path.write_text(result)

    # Import the generated file as a module. tmp_path goes *first* on sys.path
    # so a same-named module elsewhere cannot shadow it. Any cached copy is
    # dropped beforehand, and both sys.path and sys.modules are restored
    # afterwards so this import does not leak into other tests.
    tmp_dir = str(tmp_path)
    sys.modules.pop("npir_codegen_1", None)
    sys.path.insert(0, tmp_dir)
    try:
        import npir_codegen_1 as mod
    finally:
        sys.path.remove(tmp_dir)
        sys.modules.pop("npir_codegen_1", None)

    a = np.zeros((10, 10, 10))
    b = np.ones_like(a) * 3
    p = 2
    mod.run(
        a=a,
        b=b,
        p=p,
        _domain_=(8, 5, 9),
        _origin_={
            "a": (1, 1, 0),
            "b": (0, 0, 0)
        },
    )
    assert (a[1:9, 1:6, 0:9] == 5).all()
示例#7
0
def test_computation() -> None:
    """A computation renders the module prelude followed by ``def run(...)``.

    Only the start of the output is anchored (``re.match``). The output must
    begin with the import header, then the ``Field`` class, then a ``run``
    signature taking the field arguments plus ``_domain_`` and ``_origin_``.
    """
    result = NpirCodegen().visit(
        ComputationFactory(
            vertical_passes__0__body__0__body__0=VectorAssignFactory(
                left__name="a", right__name="b"
            )
        )
    )
    print(result)
    match = re.match(
        (
            r"import numbers\n"
            r"from typing import Tuple\n+"
            r"import numpy as np\n"
            # Escape the dot so only the literal module path is accepted.
            r"import scipy\.special\n+"
            r"class Field:\n"
            r"(.*\n)+"
            r"def run\(\*, a, b, _domain_, _origin_\):\n"
            r"\n?"
            r"(    .*?\n)*"
        ),
        result,
        re.MULTILINE,
    )
    assert match
示例#8
0
def test_vector_assign(left, is_serial: bool) -> None:
    """Vector assignment slices along k according to ``is_serial``.

    A field-slice target gets the same ``[i:I, j:J, <k>]`` slice as the
    right-hand side. Any other target is emitted bare as ``left``.
    """
    generated = NpirCodegen().visit(
        VectorAssignFactory(left=left, right=FieldSliceFactory(name="right")),
        ctx=NpirCodegen.BlockContext(),
        is_serial=is_serial,
    )
    lhs, rhs = generated.split(" = ")

    if is_serial:
        k_slice = "k_:k_+1"
    else:
        k_slice = "k:K"
    index = f"[i:I, j:J, {k_slice}]"

    expected_lhs = "left" + index if isinstance(left, npir.FieldSlice) else "left"
    assert lhs == expected_lhs
    assert rhs == "right" + index
示例#9
0
def test_scalarliteral(defined_dtype: common.DataType) -> None:
    """A scalar literal renders as ``<ctor>(42)``, where ``<ctor>`` matches its dtype."""
    codegen = NpirCodegen()
    literal = npir.ScalarLiteral(dtype=defined_dtype, value="42")
    rendered = codegen.visit(literal)
    print(rendered)
    found = re.match(r"(.+?)\(42\)", rendered)
    assert found
    match_dtype(found.group(1), defined_dtype)
示例#10
0
def test_temp_definition() -> None:
    """A temporary declaration allocates a padded ``Field`` carrying its offset."""
    codegen = NpirCodegen()
    declaration = TemporaryDeclFactory(
        name="a",
        offset=(1, 2),
        padding=(3, 4),
        dtype=common.DataType.FLOAT32,
    )
    rendered = codegen.visit(declaration)
    print(rendered)
    expected = "a = Field.empty((_dI_ + 3, _dJ_ + 4, _dK_), np.float32, (1, 2, 0))"
    assert rendered == expected
示例#11
0
def test_vector_arithmetic() -> None:
    """Adding two field slices renders as a parenthesized ``+`` expression."""
    codegen = NpirCodegen()
    addition = npir.VectorArithmetic(
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
        op=common.ArithmeticOperator.ADD,
    )
    expected = "(a_[i:I, j:J, k_:(k_ + 1)] + b_[i:I, j:J, k_:(k_ + 1)])"
    assert codegen.visit(addition) == expected
示例#12
0
def test_vector_unary_not() -> None:
    """NOT on a boolean field slice renders via ``np.bitwise_not``."""
    codegen = NpirCodegen()
    mask_slice = FieldSliceFactory(name="mask", dtype=common.DataType.BOOL)
    negation = npir.VectorUnaryOp(op=common.UnaryOperator.NOT, expr=mask_slice)
    assert codegen.visit(negation) == "(np.bitwise_not(mask[i:I, j:J, k:K]))"
示例#13
0
def test_temp_with_extent_definition() -> None:
    """The initializing assignment of a temporary is sized from its field extent."""
    codegen = NpirCodegen()
    init_assign = VectorAssignFactory(temp_init=True, temp_name="a")
    extents = {"a": Extent((0, 1), (-2, 3))}
    generated = codegen.visit(init_assign, field_extents=extents)
    expected = (
        "a_ = ShimmedView(np.zeros((_dI_ + 1, _dJ_ + 5, _dK_), dtype=np.int64), [0, 2, 0])"
    )
    assert generated == expected
示例#14
0
def test_vertical_pass_par() -> None:
    """A parallel vertical pass emits ``k, K = _dk_, _dK_``, optionally after a comment."""
    parallel_pass = VerticalPassFactory(direction=common.LoopOrder.PARALLEL)
    code = NpirCodegen().visit(parallel_pass)
    print(code)
    # The bounds line may be preceded by a single comment line.
    pattern = r"(#.*?\n)?k, K = _dk_, _dK_\n"
    assert re.match(pattern, code, re.MULTILINE)
示例#15
0
def test_horizontal_block() -> None:
    """A horizontal block opens with a comment line, then i/j bounds with zero offsets."""
    code = NpirCodegen().visit(HorizontalBlockFactory())
    code = code.strip("\n")
    print(code)
    header = (
        r"#.*\n"
        r"i, I = _di_ - 0, _dI_ \+ 0\n"
        r"j, J = _dj_ - 0, _dJ_ \+ 0\n"
    )
    assert re.match(header, code, re.MULTILINE)
示例#16
0
def test_horizontal_block() -> None:
    """An empty horizontal block emits i/j bounds, optionally after one comment line."""
    empty_block = npir.HorizontalBlock(body=[])
    code = NpirCodegen().visit(empty_block)
    print(code)
    expected_start = (
        r"(#.*?\n)?"
        r"i, I = _di_ - 0, _dI_ \+ 0\n"
        r"j, J = _dj_ - 0, _dJ_ \+ 0\n"
    )
    assert re.match(expected_start, code, re.MULTILINE)
示例#17
0
def test_broadcast_literal(defined_dtype: common.DataType,
                           is_serial: bool) -> None:
    """Broadcasting a scalar literal renders as ``<ctor>(42)`` matching its dtype.

    ``is_serial`` is forwarded to the code generator so that both
    parametrizations actually exercise different rendering modes, rather
    than running the default mode twice.
    """
    result = NpirCodegen().visit(
        npir.Broadcast(
            expr=npir.ScalarLiteral(dtype=defined_dtype, value="42")),
        is_serial=is_serial,
    )
    print(result)
    match = re.match(r"(.*?)\(42\)", result)
    assert match
    match_dtype(match.groups()[0], defined_dtype)
示例#18
0
def test_vector_unary_op() -> None:
    """Negating a field slice in parallel mode renders as ``(-(...))`` over ``k:K``."""
    codegen = NpirCodegen()
    negation = npir.VectorUnaryOp(
        expr=FieldSliceFactory(name="a"),
        op=common.UnaryOperator.NEG,
    )
    rendered = codegen.visit(negation, is_serial=False)
    assert rendered == "(-(a[i:I, j:J, k:K]))"
示例#19
0
def test_scalar_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """A scalar cast nests ``np.<defined>(42)`` inside ``np.<other>(...)``."""
    result = NpirCodegen().visit(
        npir.ScalarCast(dtype=other_dtype, expr=npir.ScalarLiteral(dtype=defined_dtype, value="42"))
    )
    print(result)
    # Escape both dots so only the literal "np." prefixes are accepted.
    match = re.match(r"np\.(?P<other_dtype>\w*)\(np\.(?P<defined_dtype>\w*)\(42\)\)", result)
    assert match
    assert match.group("defined_dtype") == defined_dtype.name.lower()
    assert match.group("other_dtype") == other_dtype.name.lower()
示例#20
0
def test_vector_assign(left, is_serial: bool) -> None:
    """Vector assignment whose right-hand side has ``k_offset=-1``.

    In parallel mode both sides slice over ``k:K``, and the right side is
    shifted to ``k - 1:K - 1``. In serial mode a field-slice target uses
    ``k_:k_ + 1`` and any other target uses ``:``. The right side uses
    ``k_ - 1:k_``.
    """
    generated = NpirCodegen().visit(
        VectorAssignFactory(
            left=left,
            right=FieldSliceFactory(name="right", k_offset=-1),
        ),
        ctx=NpirCodegen.BlockContext(),
        is_serial=is_serial,
    )
    lhs, rhs = generated.split(" = ")

    if not is_serial:
        left_k, right_k = "k:K", "k - 1:K - 1"
    elif isinstance(left, npir.FieldSlice):
        left_k, right_k = "k_:k_ + 1", "k_ - 1:k_"
    else:
        left_k, right_k = ":", "k_ - 1:k_"

    assert lhs == f"left[i:I, j:J, {left_k}]"
    assert rhs == f"right[i:I, j:J, {right_k}]"
示例#21
0
def test_native_function() -> None:
    """``MIN`` over two field slices is emitted as ``np.minimum``."""
    codegen = NpirCodegen()
    operands = [FieldSliceFactory(name=field_name) for field_name in ("a", "b")]
    call = NativeFuncCallFactory(func=common.NativeFunction.MIN, args=operands)
    code = codegen.visit(call)
    assert code == "np.minimum(a_[i:I, j:J, k_:(k_ + 1)], b_[i:I, j:J, k_:(k_ + 1)])"
示例#22
0
def test_vector_arithmetic() -> None:
    """Adding two field slices in parallel mode slices both over ``k:K``."""
    codegen = NpirCodegen()
    summation = npir.VectorArithmetic(
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
        op=common.ArithmeticOperator.ADD,
    )
    code = codegen.visit(summation, is_serial=False)
    assert code == "(a[i:I, j:J, k:K] + b[i:I, j:J, k:K])"
示例#23
0
    def generate_computation(self) -> Dict[str, Union[str, Dict]]:
        """Generate the computation module source, keyed by its file name.

        The file name is ``<module_prefix>computation<module_postfix>.py``.
        The source comes from ``NpirCodegen.apply`` and is passed through
        ``format_source`` only when the builder's ``format_source`` option
        is set.
        """
        caching = self.builder.caching
        file_name = caching.module_prefix + "computation" + caching.module_postfix + ".py"

        code = NpirCodegen.apply(self.npir)
        if self.builder.options.format_source:
            code = format_source("python", code)
        return {file_name: code}
示例#24
0
def test_vertical_pass_par() -> None:
    """An empty vertical pass emits ``k, K = _dk_, _dK_``, optionally after a comment."""
    empty_pass = VerticalPassFactory(body=[], temp_defs=[])
    code = NpirCodegen().visit(empty_pass)
    print(code)
    assert re.match(r"(#.*?\n)?k, K = _dk_, _dK_\n", code, re.MULTILINE)
示例#25
0
def test_mask_block_broadcast() -> None:
    """A broadcast ``True`` mask is materialized with ``np.full`` over the block extent."""
    codegen = NpirCodegen()
    true_literal = npir.Literal(
        dtype=common.DataType.BOOL, value=common.BuiltInLiteral.TRUE)
    block = npir.MaskBlock(
        body=[],
        mask=npir.BroadCast(expr=true_literal),
        mask_name="mask1",
    )
    code = codegen.visit(block, is_serial=False)
    assert code == "mask1_ = np.full((I - i, J - j, K - k), np.bool(True))\n"
示例#26
0
def test_cast(defined_dtype: common.DataType,
              other_dtype: common.DataType) -> None:
    """Casting a literal renders as ``np.array(np.<defined>(42), dtype=np.<other>)``."""
    result = NpirCodegen().visit(
        npir.Cast(dtype=other_dtype,
                  expr=npir.Literal(dtype=defined_dtype, value="42")))
    print(result)
    # Escaped dots pin the literal "np." prefixes. re.match already anchors at
    # the start of the string, so no "^" is needed.
    match = re.match(r"np\.(\w*?)\(np\.(\w*)\(42\), dtype=np\.(\w*)\)", result)
    assert match
    ctor, source_dtype, target_dtype = match.groups()
    assert ctor == "array"
    assert source_dtype == defined_dtype.name.lower()
    assert target_dtype == other_dtype.name.lower()
示例#27
0
def test_scalar_cast(defined_dtype: common.DataType,
                     other_dtype: common.DataType) -> None:
    """A scalar cast nests the literal's constructor inside the target dtype's."""
    codegen = NpirCodegen()
    inner_literal = npir.ScalarLiteral(dtype=defined_dtype, value="42")
    rendered = codegen.visit(npir.ScalarCast(dtype=other_dtype, expr=inner_literal))
    print(rendered)
    pattern = r"(?P<other_dtype>.+)\((?P<defined_dtype>.+)\(42\)\)"
    found = re.match(pattern, rendered)
    assert found
    match_dtype(found.group("defined_dtype"), defined_dtype)
    match_dtype(found.group("other_dtype"), other_dtype)
示例#28
0
def test_vector_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """Casting a field slice renders as ``<name>[...].astype(np.<other_dtype>)``."""
    codegen = NpirCodegen()
    source_slice = npir.FieldSlice(
        name="a", i_offset=0, j_offset=0, k_offset=0, dtype=defined_dtype
    )
    code = codegen.visit(npir.VectorCast(dtype=other_dtype, expr=source_slice))
    print(code)
    found = re.match(r"(?P<name>\w+)\[.*]\.astype\(np\.(?P<dtype>\w+)\)", code)
    assert found
    assert found.group("name") == "a"
    assert found.group("dtype") == other_dtype.name.lower()
示例#29
0
def test_native_function() -> None:
    """``MIN`` of a field slice and a scalar parameter is emitted as ``np.minimum``."""
    result = NpirCodegen().visit(
        NativeFuncCallFactory(
            func=common.NativeFunction.MIN,
            args=[
                FieldSliceFactory(name="a"),
                ParamAccessFactory(name="p"),
            ],
        )
    )
    print(result)
    # Escape the dot so only the literal "np.minimum" prefix is accepted.
    match = re.match(r"np\.minimum\(a\[.*\],\s*p\)", result)
    assert match
示例#30
0
    def generate_computation(self) -> Dict[str, Union[str, Dict]]:
        """Generate the computation module source, keyed by its file name.

        The file name is ``<module_prefix>computation<module_postfix>.py``.
        The ``ignore_np_errstate`` backend option (default ``True`` when
        unset) is forwarded to ``NpirCodegen.apply``. The result is passed
        through ``format_source`` only when the builder's ``format_source``
        option is set.
        """
        caching = self.builder.caching
        file_name = caching.module_prefix + "computation" + caching.module_postfix + ".py"

        backend_opts = self.builder.options.backend_opts
        code = NpirCodegen.apply(
            self.npir,
            ignore_np_errstate=backend_opts.get("ignore_np_errstate", True),
        )
        if self.builder.options.format_source:
            code = format_source("python", code)
        return {file_name: code}