def test_mask_propagation(parallel_k):
    """The first body statement of a translated MaskStmt carries the block's mask.

    ``parallel_k`` is a pytest fixture defined outside this view; it is passed
    straight through to ``OirToNpir.visit``.
    """
    stmt = MaskStmtFactory()
    translator = OirToNpir()
    block = translator.visit(stmt, ctx=OirToNpir.ComputationContext(), parallel_k=parallel_k)
    first_inner = block.body[0]
    assert first_inner.mask == block.mask
def test_mask_stmt_to_mask_block(parallel_k):
    """A MaskStmt with an empty body translates to a block with a FieldSlice mask and an empty body.

    ``parallel_k`` is a pytest fixture defined outside this view.
    """
    empty_stmt = MaskStmtFactory(body=[])
    translator = OirToNpir()
    result = translator.visit(
        empty_stmt,
        ctx=OirToNpir.ComputationContext(),
        parallel_k=parallel_k,
    )
    mask, body = result.mask, result.body
    assert isinstance(mask, npir.FieldSlice)
    assert body == []
def test_field_access_to_field_slice(parallel_k):
    """An ``oir.FieldAccess`` is translated into an ``npir.FieldSlice``.

    Checks three things about the result:

    - the node type;
    - that the k offset's ``parallel`` flag is the very ``parallel_k`` object
      passed to ``visit``;
    - the i offset value.

    ``parallel_k`` is a pytest fixture defined outside this view.
    """
    field_access = oir.FieldAccess(
        name="a",
        offset=common.CartesianOffset(i=-1, j=2, k=0),
        dtype=common.DataType.FLOAT64,
    )
    ctx = OirToNpir.ComputationContext()
    field_slice = OirToNpir().visit(field_access, ctx=ctx, parallel_k=parallel_k)

    # The test name promises a FieldSlice. Pin the node type so that a different
    # node which merely exposes the same offset attributes cannot pass.
    assert isinstance(field_slice, npir.FieldSlice)
    assert field_slice.k_offset.parallel is parallel_k
    assert field_slice.i_offset.offset.value == -1
def test_assign_stmt_to_vector_assign(parallel_k):
    """An ``oir.AssignStmt`` becomes an ``npir.VectorAssign``.

    Both sides of the result must have their k offset's ``parallel`` flag set
    to the ``parallel_k`` object passed to ``visit``. ``parallel_k`` is a
    pytest fixture defined outside this view.
    """
    float64 = common.DataType.FLOAT64
    target = oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=float64)
    source = oir.FieldAccess(
        name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=float64
    )
    stmt = oir.AssignStmt(left=target, right=source)

    context = OirToNpir.ComputationContext()
    result = OirToNpir().visit(stmt, ctx=context, parallel_k=parallel_k, mask=None)

    assert isinstance(result, npir.VectorAssign)
    for side in (result.left, result.right):
        assert side.k_offset.parallel is parallel_k
def test_native_func_call():
    """A SQRT ``oir.NativeFuncCall`` on one field translates to an ``npir.VectorExpression``.

    Unlike its neighbours, this test does not use the ``parallel_k`` fixture
    and always passes ``parallel_k=True``.
    """
    argument = oir.FieldAccess(
        name="a",
        offset=common.CartesianOffset.zero(),
        dtype=common.DataType.FLOAT64,
    )
    call = oir.NativeFuncCall(func=common.NativeFunction.SQRT, args=[argument])

    translator = OirToNpir()
    translated = translator.visit(call, parallel_k=True, ctx=OirToNpir.ComputationContext())
    assert isinstance(translated, npir.VectorExpression)
def test_temp_assign(parallel_k):
    """Assigning to a symbol registered as a temporary records one entry in ``ctx.temp_defs``.

    The symbol table maps ``"a"`` to a ``TemporaryFactory`` node. After
    translation, the context must hold exactly one temp definition, keyed
    ``"a"``, with an ``npir.VectorTemp`` on the left and an ``npir.EmptyTemp``
    on the right. ``parallel_k`` is a pytest fixture defined outside this view.
    """
    dtype = common.DataType.FLOAT64
    stmt = oir.AssignStmt(
        left=oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=dtype),
        right=oir.FieldAccess(
            name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=dtype
        ),
    )
    ctx = OirToNpir.ComputationContext()
    symbols = {"a": TemporaryFactory(name="a")}

    # Only the side effect on ``ctx`` matters here; the returned node is ignored.
    OirToNpir().visit(stmt, ctx=ctx, parallel_k=parallel_k, mask=None, symtable=symbols)

    temp_defs = ctx.temp_defs
    assert len(temp_defs) == 1
    definition = temp_defs["a"]
    assert isinstance(definition.left, npir.VectorTemp)
    assert isinstance(definition.right, npir.EmptyTemp)