def test_vector_assign() -> None:
    """A field-to-field vector assignment renders as a sliced numpy assignment."""
    generator = npir_gen.NpirGen()
    assignment = VectorAssignFactory(left__name="a", right__name="b")
    generated = generator.visit(assignment)
    expected = "a_[i:I, j:J, k_:(k_ + 1)] = b_[i:I, j:J, k_:(k_ + 1)]"
    assert generated == expected
def test_vector_unary_not() -> None:
    """A NOT unary op on a field slice is emitted as ``np.bitwise_not``."""
    generator = npir_gen.NpirGen()
    negation = npir.VectorUnaryOp(
        expr=FieldSliceFactory(name="a"),
        op=common.UnaryOperator.NOT,
    )
    code = generator.visit(negation)
    assert code == "(np.bitwise_not(a_[i:I, j:J, k_:(k_ + 1)]))"
def test_broadcast_literal(defined_dtype: common.DataType) -> None:
    """A broadcast literal renders as a numpy scalar constructor of its dtype.

    The generated code is expected to start with ``np.<dtype>(42)``, where
    ``<dtype>`` is the lower-cased name of ``defined_dtype``.
    """
    result = npir_gen.NpirGen().visit(
        npir.BroadCast(expr=npir.Literal(dtype=defined_dtype, value="42"))
    )
    # Escape the dot: an unescaped "." would also accept e.g. "npXint64(42)".
    match = re.match(r"np\.(\w*?)\(42\)", result)
    assert match, f"unexpected generated code: {result!r}"
    assert match.group(1) == defined_dtype.name.lower()
def test_vector_arithmetic() -> None:
    """An ADD between two field slices renders as a parenthesized ``+`` expression."""
    generator = npir_gen.NpirGen()
    addition = npir.VectorArithmetic(
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
        op=common.ArithmeticOperator.ADD,
    )
    code = generator.visit(addition)
    expected = "(a_[i:I, j:J, k_:(k_ + 1)] + b_[i:I, j:J, k_:(k_ + 1)])"
    assert code == expected
def test_mask_block_broadcast() -> None:
    """A broadcast boolean mask is materialized into ``mask1_`` with ``np.full``."""
    generator = npir_gen.NpirGen()
    true_literal = npir.Literal(
        dtype=common.DataType.BOOL, value=common.BuiltInLiteral.TRUE
    )
    block = npir.MaskBlock(
        body=[],
        mask=npir.BroadCast(expr=true_literal),
        mask_name="mask1",
    )
    code = generator.visit(block)
    assert code == "mask1_ = np.full((I - i, J - j, K - k), np.bool(True))\n"
def test_horizontal_block() -> None:
    """An empty horizontal block still emits the i/j bounds preamble."""
    code = npir_gen.NpirGen().visit(npir.HorizontalBlock(body=[]))
    print(code)
    preamble = re.compile(
        r"(#.*?\n)?i, I = _di_ - 0, _dI_ \+ 0\nj, J = _dj_ - 0, _dJ_ \+ 0\n",
        re.MULTILINE,
    )
    assert preamble.match(code)
def test_temp_with_extent_definition() -> None:
    """A temporary with a non-zero extent gets a padded shape and a non-zero ShimmedView offset."""
    generator = npir_gen.NpirGen()
    temp_assign = VectorAssignFactory(temp_init=True, temp_name="a")
    extents = {"a": Extent((0, 1), (-2, 3))}
    code = generator.visit(temp_assign, field_extents=extents)
    expected = (
        "a_ = ShimmedView(np.zeros((_dI_ + 1, _dJ_ + 5, _dK_), dtype=np.int64), [0, 2, 0])"
    )
    assert code == expected
def test_native_function() -> None:
    """The ``MIN`` native function maps to ``np.minimum`` over the rendered arguments."""
    generator = npir_gen.NpirGen()
    call = NativeFuncCallFactory(
        func=common.NativeFunction.MIN,
        args=[FieldSliceFactory(name=field) for field in ("a", "b")],
    )
    code = generator.visit(call)
    assert code == "np.minimum(a_[i:I, j:J, k_:(k_ + 1)], b_[i:I, j:J, k_:(k_ + 1)])"
def test_vertical_pass_par() -> None:
    """A factory-default vertical pass binds ``k, K`` to ``_dk_, _dK_``."""
    generator = npir_gen.NpirGen()
    code = generator.visit(VerticalPassFactory(body=[], temp_defs=[]))
    print(code)
    pattern = r"(#.*?\n)?k, K = _dk_, _dK_\n"
    assert re.match(pattern, code, flags=re.MULTILINE)
def test_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """A literal cast renders as ``np.array(np.<from>(42), dtype=np.<to>)``.

    ``<from>`` is the lower-cased name of ``defined_dtype`` (the literal's dtype)
    and ``<to>`` the lower-cased name of ``other_dtype`` (the cast target).
    """
    result = npir_gen.NpirGen().visit(
        npir.Cast(dtype=other_dtype, expr=npir.Literal(dtype=defined_dtype, value="42"))
    )
    # Dots are escaped so that "np." must appear literally; an unescaped "."
    # would match any character there.
    match = re.match(r"^np\.(\w*?)\(np\.(\w*)\(42\), dtype=np\.(\w*)\)", result)
    assert match, f"unexpected generated code: {result!r}"
    constructor, source_dtype, target_dtype = match.groups()
    assert constructor == "array"
    assert source_dtype == defined_dtype.name.lower()
    assert target_dtype == other_dtype.name.lower()
def test_mask_block_other() -> None:
    """A logical-AND mask is assigned to ``mask1_`` via ``np.bitwise_and``."""
    generator = npir_gen.NpirGen()
    conjunction = npir.VectorLogic(
        op=common.LogicalOperator.AND,
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
    )
    block = npir.MaskBlock(body=[], mask=conjunction, mask_name="mask1")
    code = generator.visit(block)
    assert code.startswith("mask1_ = np.bitwise_and(a_[i:I")
def test_verticall_pass_start_start_forward() -> None:
    """A forward pass with upper bound ``start + 5`` emits an ascending ``k_`` loop."""
    generator = npir_gen.NpirGen()
    vertical_pass = VerticalPassFactory(
        body=[],
        temp_defs=[],
        upper=common.AxisBound.from_start(offset=5),
        direction=common.LoopOrder.FORWARD,
    )
    code = generator.visit(vertical_pass)
    print(code)
    header = re.compile(
        r"(#.*?\n)?k, K = _dk_, _dk_ \+ 5\nfor k_ in range\(k, K\):\n",
        re.MULTILINE,
    )
    assert header.match(code)
def test_verticall_pass_end_end_backward() -> None:
    """A backward pass between ``end - 4`` and ``end - 1`` emits a descending ``k_`` loop."""
    generator = npir_gen.NpirGen()
    vertical_pass = VerticalPassFactory(
        body=[],
        temp_defs=[],
        lower=common.AxisBound.from_end(offset=-4),
        upper=common.AxisBound.from_end(offset=-1),
        direction=common.LoopOrder.BACKWARD,
    )
    code = generator.visit(vertical_pass)
    print(code)
    header = re.compile(
        r"(#.*?\n)?k, K = _dK_ \- 4, _dK_ \- 1\nfor k_ in range\(K-1, k-1, -1\):\n",
        re.MULTILINE,
    )
    assert header.match(code)
def test_vertical_pass_temp_def() -> None:
    """Temporary definitions are emitted before the k bounds and the backward k loop."""
    result = npir_gen.NpirGen().visit(
        VerticalPassFactory(
            temp_defs=[
                VectorAssignFactory(temp_init=True, temp_name="a"),
            ],
            body=[],
            lower=common.AxisBound.from_end(offset=-4),
            upper=common.AxisBound.from_end(offset=-1),
            direction=common.LoopOrder.BACKWARD,
        )
    )
    # "np." is escaped so only a literal numpy attribute access is accepted;
    # the pattern is split per generated line for readability.
    match = re.match(
        (
            r"(#.*?\n)?"
            r"a_ = ShimmedView\(np\.zeros\(_domain_, dtype=np\.int64\), \[0, 0, 0\]\)\n"
            r"k, K = _dK_ \- 4, _dK_ \- 1\n"
            r"for k_ in range\(K-1, k-1, -1\):\n"
        ),
        result,
        re.MULTILINE,
    )
    assert match, f"unexpected generated code: {result!r}"
def test_vertical_pass_seq() -> None:
    """A forward pass from ``start + 1`` to ``end - 2`` emits an ascending ``k_`` loop."""
    generator = npir_gen.NpirGen()
    vertical_pass = VerticalPassFactory(
        temp_defs=[],
        body=[],
        lower=common.AxisBound.from_start(offset=1),
        upper=common.AxisBound.from_end(offset=-2),
        direction=common.LoopOrder.FORWARD,
    )
    code = generator.visit(vertical_pass)
    print(code)
    pattern = r"(#.*?\n)?k, K = _dk_ \+ 1, _dK_ - 2\nfor k_ in range\(k, K\):\n"
    assert re.compile(pattern, re.MULTILINE).match(code)
def test_computation() -> None:
    """An empty computation renders a module importing numpy and defining ``run``."""
    generator = npir_gen.NpirGen()
    computation = npir.Computation(
        params=[],
        field_params=[],
        field_decls=[],
        vertical_passes=[],
    )
    code = generator.visit(computation, field_extents={})
    print(code)
    module_pattern = (
        r"import numpy as np\n\n\n"
        r"def run\(\*, _domain_, _origin_\):\n"
        r"\n?"
        r"( .*?\n)*"
    )
    assert re.match(module_pattern, code, flags=re.MULTILINE)
def test_mask_block_slice_mask() -> None:
    """A mask block whose mask is a field slice named like ``mask_name`` generates nothing."""
    generator = npir_gen.NpirGen()
    block = npir.MaskBlock(
        body=[],
        mask=FieldSliceFactory(name="mask1"),
        mask_name="mask1",
    )
    code = generator.visit(block)
    assert code == ""
def test_datatype() -> None:
    """``DataType.FLOAT64`` is rendered as the numpy dtype ``np.float64``."""
    result = npir_gen.NpirGen().visit(common.DataType.FLOAT64)
    # A literal prefix check avoids regex metacharacters: with re.match, the
    # unescaped "." in "np.float64" would accept any character after "np".
    assert result.startswith("np.float64"), f"unexpected generated code: {result!r}"
def test_temp_definition() -> None:
    """A temporary without extents is allocated over ``_domain_`` with a zero offset."""
    generator = npir_gen.NpirGen()
    temp_assign = VectorAssignFactory(temp_init=True, temp_name="a")
    code = generator.visit(temp_assign)
    expected = "a_ = ShimmedView(np.zeros(_domain_, dtype=np.int64), [0, 0, 0])"
    assert code == expected
def test_field_slice_sequential_k() -> None:
    """With ``parallel_k=False`` the k offset is applied to the ``k_`` loop variable."""
    generator = npir_gen.NpirGen()
    field_slice = FieldSliceFactory(name="a_field", parallel_k=False, offsets=(-1, 0, 4))
    code = generator.visit(field_slice)
    assert code == "a_field_[(i - 1):(I - 1), j:J, (k_ + 4):(k_ + 4 + 1)]"
def test_field_slice_parallel_k() -> None:
    """With ``parallel_k=True`` the k offset shifts the whole ``k:K`` range."""
    generator = npir_gen.NpirGen()
    field_slice = FieldSliceFactory(name="another_field", parallel_k=True, offsets=(0, 0, -3))
    code = generator.visit(field_slice)
    assert code == "another_field_[i:I, j:J, (k - 3):(K - 3)]"
def test_sequential_offset() -> None:
    """A k-axis offset of 5 renders as a single-level slice at ``k_ + 5``."""
    generator = npir_gen.NpirGen()
    rendered = generator.visit(npir.AxisOffset.k(5))
    print(rendered)
    assert rendered == "(k_ + 5):(k_ + 5 + 1)"
def test_sequential_offset_zero() -> None:
    """A zero k-axis offset renders without any arithmetic on ``k_``."""
    generator = npir_gen.NpirGen()
    rendered = generator.visit(npir.AxisOffset.k(0))
    print(rendered)
    assert rendered == "k_:(k_ + 1)"
def test_numerical_offset_zero() -> None:
    """A zero numerical offset renders as the empty string."""
    generator = npir_gen.NpirGen()
    zero_offset = npir.NumericalOffset(value=0)
    rendered = generator.visit(zero_offset)
    assert rendered == ""
def test_parallel_offset() -> None:
    """A negative i-axis offset shifts both ends of the ``i:I`` range."""
    generator = npir_gen.NpirGen()
    rendered = generator.visit(npir.AxisOffset.i(-3))
    print(rendered)
    assert rendered == "(i - 3):(I - 3)"
def test_parallel_offset_zero() -> None:
    """A zero j-axis offset renders as the plain ``j:J`` range."""
    generator = npir_gen.NpirGen()
    rendered = generator.visit(npir.AxisOffset.j(0))
    print(rendered)
    assert rendered == "j:J"