def test_assign_with_mask_local() -> None:
    """A masked assignment to a local scalar is generated as an ``np.where`` expression.

    ``tmp`` is declared in the symbol table as an INT32 scalar; the generated
    ``np.where`` call is expected to take ``mask1...`` as its condition and end
    with ``np.int32()`` as its last argument.
    """
    result = NpirCodegen().visit(
        VectorAssignFactory(
            left=LocalScalarAccessFactory(name="tmp"),
            mask=FieldSliceFactory(name="mask1", dtype=common.DataType.BOOL),
        ),
        ctx=NpirCodegen.BlockContext(),
        symtable={"tmp": ScalarDeclFactory(name="tmp", dtype=common.DataType.INT32)},
    )
    print(result)
    # Literal dots are escaped so "np.where" / "np.int32" must appear verbatim
    # (an unescaped "." would accept any character in that position).
    assert re.match(r"tmp = np\.where\(mask1.*, np\.int32\(\)\)", result) is not None
def test_mask_block_other() -> None:
    """A mask block whose mask is a logical AND of two slices starts with ``np.bitwise_and``."""
    and_mask = npir.VectorLogic(
        op=common.LogicalOperator.AND,
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
    )
    block = npir.MaskBlock(body=[], mask=and_mask, mask_name="mask1")
    generated = NpirCodegen().visit(block)
    assert generated.startswith("mask1_ = np.bitwise_and(a_[i:I")
def test_broadcast_literal(defined_dtype: common.DataType) -> None:
    """Broadcasting a literal emits ``np.<dtype>(42)`` for the literal's dtype."""
    result = NpirCodegen().visit(
        npir.BroadCast(expr=npir.Literal(dtype=defined_dtype, value="42"))
    )
    print(result)
    # "np\." is escaped so the prefix must literally be "np." rather than any
    # character followed by the dtype name.
    match = re.match(r"np\.(\w*?)\(42\)", result)
    assert match
    assert match.groups()[0] == defined_dtype.name.lower()
def test_vector_unary_not() -> None:
    """Logical NOT on a field slice becomes a parenthesised ``np.bitwise_not`` call."""
    operand = FieldSliceFactory(name="a")
    negation = npir.VectorUnaryOp(expr=operand, op=common.UnaryOperator.NOT)
    generated = NpirCodegen().visit(negation)
    assert generated == "(np.bitwise_not(a_[i:I, j:J, k_:(k_ + 1)]))"
def test_vector_assign() -> None:
    """A plain vector assignment writes the ``b`` slice into the ``a`` slice."""
    assignment = VectorAssignFactory(left__name="a", right__name="b")
    generated = NpirCodegen().visit(assignment)
    assert generated == "a_[i:I, j:J, k_:(k_ + 1)] = b_[i:I, j:J, k_:(k_ + 1)]"
def test_full_computation_valid(tmp_path) -> None:
    """Generate a complete module, import it from ``tmp_path`` and check its numeric result.

    The computation assigns ``a = b + p``. With ``b == 3`` and ``p == 2``, every
    point of ``a`` inside the domain (shifted by ``a``'s origin) must equal 5.
    """
    computation = ComputationFactory(
        vertical_passes__0__body__0__body__0=VectorAssignFactory(
            left__name="a",
            right=VectorArithmeticFactory(
                left__name="b",
                right=ParamAccessFactory(name="p"),
                op=common.ArithmeticOperator.ADD,
            ),
        ),
        param_decls=[ScalarDeclFactory(name="p")],
    )
    result = NpirCodegen().visit(computation)
    print(result)

    module_name = "npir_codegen_1"
    mod_path = tmp_path / f"{module_name}.py"
    mod_path.write_text(result)

    # Expose tmp_path on sys.path only while importing, and evict the module from
    # sys.modules before and after. Otherwise tmp_path leaks into sys.path for the
    # rest of the session, and a stale cached "npir_codegen_1" could be reused.
    sys.modules.pop(module_name, None)
    sys.path.append(str(tmp_path))
    try:
        import npir_codegen_1 as mod
    finally:
        sys.path.remove(str(tmp_path))
        sys.modules.pop(module_name, None)

    a = np.zeros((10, 10, 10))
    b = np.ones_like(a) * 3
    p = 2
    mod.run(
        a=a,
        b=b,
        p=p,
        _domain_=(8, 5, 9),
        _origin_={"a": (1, 1, 0), "b": (0, 0, 0)},
    )
    assert (a[1:9, 1:6, 0:9] == 5).all()
def test_computation() -> None:
    """The generated module has the expected imports, a ``Field`` class and ``run``."""
    assign = VectorAssignFactory(left__name="a", right__name="b")
    computation = ComputationFactory(vertical_passes__0__body__0__body__0=assign)
    generated = NpirCodegen().visit(computation)
    print(generated)
    expected_layout = (
        r"import numbers\n"
        r"from typing import Tuple\n+"
        r"import numpy as np\n"
        r"import scipy.special\n+"
        r"class Field:\n"
        r"(.*\n)+"
        r"def run\(\*, a, b, _domain_, _origin_\):\n"
        r"\n?"
        r"( .*?\n)*"
    )
    assert re.match(expected_layout, generated, re.MULTILINE)
def test_vector_assign(left, is_serial: bool) -> None:
    """Check the k-slice of a vector assignment in serial and parallel mode.

    In serial mode the k-slice is ``k_:k_+1``; otherwise it is ``k:K``. A
    non-FieldSlice left-hand side is emitted as the bare name ``left``.
    """
    generated = NpirCodegen().visit(
        VectorAssignFactory(left=left, right=FieldSliceFactory(name="right")),
        ctx=NpirCodegen.BlockContext(),
        is_serial=is_serial,
    )
    lhs, rhs = generated.split(" = ")
    if is_serial:
        k_slice = "k_:k_+1"
    else:
        k_slice = "k:K"
    if not isinstance(left, npir.FieldSlice):
        expected_lhs = "left"
    else:
        expected_lhs = "left[i:I, j:J, " + k_slice + "]"
    assert lhs == expected_lhs
    assert rhs == "right[i:I, j:J, " + k_slice + "]"
def test_scalarliteral(defined_dtype: common.DataType) -> None:
    """A scalar literal is emitted as ``<dtype constructor>(42)`` for its dtype."""
    literal = npir.ScalarLiteral(dtype=defined_dtype, value="42")
    generated = NpirCodegen().visit(literal)
    print(generated)
    found = re.match(r"(.+?)\(42\)", generated)
    assert found
    (dtype_repr,) = found.groups()
    match_dtype(dtype_repr, defined_dtype)
def test_temp_definition() -> None:
    """A temporary declaration allocates an empty ``Field`` padded by its padding and offset by its offset."""
    declaration = TemporaryDeclFactory(
        name="a", offset=(1, 2), padding=(3, 4), dtype=common.DataType.FLOAT32
    )
    generated = NpirCodegen().visit(declaration)
    print(generated)
    expected = "a = Field.empty((_dI_ + 3, _dJ_ + 4, _dK_), np.float32, (1, 2, 0))"
    assert generated == expected
def test_vector_arithmetic() -> None:
    """Adding two field slices yields a parenthesised ``+`` of the two slices."""
    addition = npir.VectorArithmetic(
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
        op=common.ArithmeticOperator.ADD,
    )
    generated = NpirCodegen().visit(addition)
    assert generated == "(a_[i:I, j:J, k_:(k_ + 1)] + b_[i:I, j:J, k_:(k_ + 1)])"
def test_vector_unary_not() -> None:
    """NOT on a boolean mask slice becomes ``np.bitwise_not`` over the full k range."""
    bool_mask = FieldSliceFactory(name="mask", dtype=common.DataType.BOOL)
    negation = npir.VectorUnaryOp(op=common.UnaryOperator.NOT, expr=bool_mask)
    generated = NpirCodegen().visit(negation)
    assert generated == "(np.bitwise_not(mask[i:I, j:J, k:K]))"
def test_temp_with_extent_definition() -> None:
    """Initialising a temporary with an extent emits a padded ``np.zeros`` wrapped in ``ShimmedView``.

    For ``Extent((0, 1), (-2, 3))`` the expected shape grows by 1 in I and 5 in J,
    and the expected view offset is ``[0, 2, 0]``.
    """
    extents = {"a": Extent((0, 1), (-2, 3))}
    generated = NpirCodegen().visit(
        VectorAssignFactory(temp_init=True, temp_name="a"),
        field_extents=extents,
    )
    expected = (
        "a_ = ShimmedView(np.zeros((_dI_ + 1, _dJ_ + 5, _dK_), dtype=np.int64), [0, 2, 0])"
    )
    assert generated == expected
def test_vertical_pass_par() -> None:
    """A parallel vertical pass opens with the ``k, K`` bounds, optionally after a comment line."""
    parallel_pass = VerticalPassFactory(direction=common.LoopOrder.PARALLEL)
    generated = NpirCodegen().visit(parallel_pass)
    print(generated)
    header_pattern = r"(#.*?\n)?" r"k, K = _dk_, _dK_\n"
    assert re.match(header_pattern, generated, re.MULTILINE)
def test_horizontal_block() -> None:
    """A horizontal block starts with a comment line followed by the I and J bounds."""
    generated = NpirCodegen().visit(HorizontalBlockFactory())
    generated = generated.strip("\n")
    print(generated)
    bounds_pattern = (
        r"#.*\n"
        r"i, I = _di_ - 0, _dI_ \+ 0\n"
        r"j, J = _dj_ - 0, _dJ_ \+ 0\n"
    )
    assert re.match(bounds_pattern, generated, re.MULTILINE)
def test_horizontal_block() -> None:
    """An empty horizontal block emits the I and J bounds, optionally preceded by a comment."""
    empty_block = npir.HorizontalBlock(body=[])
    generated = NpirCodegen().visit(empty_block)
    print(generated)
    found = re.match(
        r"(#.*?\n)?i, I = _di_ - 0, _dI_ \+ 0\nj, J = _dj_ - 0, _dJ_ \+ 0\n",
        generated,
        re.MULTILINE,
    )
    assert found
def test_broadcast_literal(defined_dtype: common.DataType, is_serial: bool) -> None:
    """Broadcasting a scalar literal emits ``<dtype constructor>(42)``.

    ``is_serial`` is now forwarded to the visitor. Previously it was accepted
    but ignored, so varying it had no effect on what was tested.
    """
    result = NpirCodegen().visit(
        npir.Broadcast(expr=npir.ScalarLiteral(dtype=defined_dtype, value="42")),
        is_serial=is_serial,
    )
    print(result)
    match = re.match(r"(.*?)\(42\)", result)
    assert match
    match_dtype(match.groups()[0], defined_dtype)
def test_vector_unary_op() -> None:
    """Arithmetic negation of a slice in parallel mode is ``(-(slice))`` over ``k:K``."""
    negation = npir.VectorUnaryOp(
        expr=FieldSliceFactory(name="a"),
        op=common.UnaryOperator.NEG,
    )
    generated = NpirCodegen().visit(negation, is_serial=False)
    assert generated == "(-(a[i:I, j:J, k:K]))"
def test_scalar_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """Casting a scalar literal wraps ``np.<defined>(42)`` in ``np.<other>(...)``."""
    result = NpirCodegen().visit(
        npir.ScalarCast(
            dtype=other_dtype,
            expr=npir.ScalarLiteral(dtype=defined_dtype, value="42"),
        )
    )
    print(result)
    # Both "np." prefixes are escaped; the inner one was previously a bare "."
    # that matched any character.
    match = re.match(
        r"np\.(?P<other_dtype>\w*)\(np\.(?P<defined_dtype>\w*)\(42\)\)", result
    )
    assert match
    assert match.group("defined_dtype") == defined_dtype.name.lower()
    assert match.group("other_dtype") == other_dtype.name.lower()
def test_vector_assign(left, is_serial: bool) -> None:
    """Check the k-slices of an assignment whose right side has ``k_offset=-1``.

    Serial mode uses per-level ``k_`` slices (or ``:`` for a non-FieldSlice left
    side). Parallel mode uses ``k:K`` and the right side is shifted by -1.
    """
    generated = NpirCodegen().visit(
        VectorAssignFactory(left=left, right=FieldSliceFactory(name="right", k_offset=-1)),
        ctx=NpirCodegen.BlockContext(),
        is_serial=is_serial,
    )
    lhs, rhs = generated.split(" = ")
    if is_serial:
        left_k = "k_:k_ + 1" if isinstance(left, npir.FieldSlice) else ":"
        right_k = "k_ - 1:k_"
    else:
        left_k = "k:K"
        right_k = "k - 1:K - 1"
    assert lhs == f"left[i:I, j:J, {left_k}]"
    assert rhs == f"right[i:I, j:J, {right_k}]"
def test_native_function() -> None:
    """MIN over two field slices is emitted as ``np.minimum`` of the slices."""
    operands = [FieldSliceFactory(name="a"), FieldSliceFactory(name="b")]
    call = NativeFuncCallFactory(func=common.NativeFunction.MIN, args=operands)
    generated = NpirCodegen().visit(call)
    assert generated == "np.minimum(a_[i:I, j:J, k_:(k_ + 1)], b_[i:I, j:J, k_:(k_ + 1)])"
def test_vector_arithmetic() -> None:
    """In parallel mode, adding two slices uses the full ``k:K`` range on both sides."""
    addition = npir.VectorArithmetic(
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
        op=common.ArithmeticOperator.ADD,
    )
    generated = NpirCodegen().visit(addition, is_serial=False)
    assert generated == "(a[i:I, j:J, k:K] + b[i:I, j:J, k:K])"
def generate_computation(self) -> Dict[str, Union[str, Dict]]:
    """Render ``self.npir`` into the Python source of the computation module.

    Returns:
        A one-entry dict mapping ``<module_prefix>computation<module_postfix>.py``
        to the generated source. The source is passed through ``format_source``
        when the builder's ``format_source`` option is set.
    """
    prefix = self.builder.caching.module_prefix
    postfix = self.builder.caching.module_postfix
    file_name = "".join((prefix, "computation", postfix, ".py"))
    code = NpirCodegen.apply(self.npir)
    if self.builder.options.format_source:
        code = format_source("python", code)
    return {file_name: code}
def test_vertical_pass_par() -> None:
    """An empty vertical pass opens with the ``k, K`` bounds, optionally after a comment line."""
    empty_pass = VerticalPassFactory(body=[], temp_defs=[])
    generated = NpirCodegen().visit(empty_pass)
    print(generated)
    header_pattern = r"(#.*?\n)?" r"k, K = _dk_, _dK_\n"
    assert re.match(header_pattern, generated, re.MULTILINE)
def test_mask_block_broadcast() -> None:
    """A mask block with a broadcast TRUE literal fills the whole domain with ``True``."""
    true_literal = npir.Literal(
        dtype=common.DataType.BOOL, value=common.BuiltInLiteral.TRUE
    )
    block = npir.MaskBlock(
        body=[],
        mask=npir.BroadCast(expr=true_literal),
        mask_name="mask1",
    )
    generated = NpirCodegen().visit(block, is_serial=False)
    assert generated == "mask1_ = np.full((I - i, J - j, K - k), np.bool(True))\n"
def test_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """Casting a literal emits ``np.array(np.<defined>(42), dtype=np.<other>)``."""
    result = NpirCodegen().visit(
        npir.Cast(dtype=other_dtype, expr=npir.Literal(dtype=defined_dtype, value="42"))
    )
    print(result)
    # All three "np." prefixes are escaped so each must appear verbatim.
    match = re.match(r"^np\.(\w*?)\(np\.(\w*)\(42\), dtype=np\.(\w*)\)", result)
    assert match
    assert match.groups()[0] == "array"
    assert match.groups()[1] == defined_dtype.name.lower()
    assert match.groups()[2] == other_dtype.name.lower()
def test_scalar_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """A scalar cast is emitted as ``<other ctor>(<defined ctor>(42))``."""
    inner = npir.ScalarLiteral(dtype=defined_dtype, value="42")
    cast = npir.ScalarCast(dtype=other_dtype, expr=inner)
    generated = NpirCodegen().visit(cast)
    print(generated)
    found = re.match(r"(?P<other_dtype>.+)\((?P<defined_dtype>.+)\(42\)\)", generated)
    assert found
    match_dtype(found.group("defined_dtype"), defined_dtype)
    match_dtype(found.group("other_dtype"), other_dtype)
def test_vector_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """A vector cast indexes the field and calls ``.astype(np.<other dtype>)``."""
    source_slice = npir.FieldSlice(
        name="a", i_offset=0, j_offset=0, k_offset=0, dtype=defined_dtype
    )
    generated = NpirCodegen().visit(npir.VectorCast(dtype=other_dtype, expr=source_slice))
    print(generated)
    found = re.match(r"(?P<name>\w+)\[.*]\.astype\(np\.(?P<dtype>\w+)\)", generated)
    assert found
    assert found.group("name") == "a"
    assert found.group("dtype") == other_dtype.name.lower()
def test_native_function() -> None:
    """MIN of a field slice and a scalar parameter is emitted as ``np.minimum(a[...], p)``."""
    result = NpirCodegen().visit(
        NativeFuncCallFactory(
            func=common.NativeFunction.MIN,
            args=[
                FieldSliceFactory(name="a"),
                ParamAccessFactory(name="p"),
            ],
        )
    )
    print(result)
    # "np\." is escaped so the call must literally be "np.minimum".
    match = re.match(r"np\.minimum\(a\[.*\],\s*p\)", result)
    assert match
def generate_computation(self) -> Dict[str, Union[str, Dict]]:
    """Render ``self.npir`` into the Python source of the computation module.

    The ``ignore_np_errstate`` backend option (default ``True``) is forwarded to
    ``NpirCodegen.apply``.

    Returns:
        A one-entry dict mapping ``<module_prefix>computation<module_postfix>.py``
        to the generated source. The source is passed through ``format_source``
        when the builder's ``format_source`` option is set.
    """
    prefix = self.builder.caching.module_prefix
    postfix = self.builder.caching.module_postfix
    file_name = "".join((prefix, "computation", postfix, ".py"))
    backend_opts = self.builder.options.backend_opts
    ignore_errstate = backend_opts.get("ignore_np_errstate", True)
    code = NpirCodegen.apply(self.npir, ignore_np_errstate=ignore_errstate)
    if self.builder.options.format_source:
        code = format_source("python", code)
    return {file_name: code}