def test_mask_propagation(parallel_k):
    """The mask of the produced mask block is also attached to the first body statement."""
    mask_stmt = MaskStmtFactory()
    block = OirToNpir().visit(
        mask_stmt,
        ctx=OirToNpir.ComputationContext(),
        parallel_k=parallel_k,
    )
    first_stmt = block.body[0]
    assert first_stmt.mask == block.mask
def test_mask_stmt_to_mask_block(parallel_k):
    """An empty MaskStmt lowers to a block whose mask is a FieldSlice and whose body is empty."""
    empty_mask_stmt = MaskStmtFactory(body=[])
    translator = OirToNpir()
    context = OirToNpir.ComputationContext()
    result = translator.visit(empty_mask_stmt, ctx=context, parallel_k=parallel_k)
    assert isinstance(result.mask, npir.FieldSlice)
    assert result.body == []
def test_field_access_to_field_slice(parallel_k):
    """A Cartesian FieldAccess lowers to a FieldSlice carrying parallel_k and the i offset."""
    offset = common.CartesianOffset(i=-1, j=2, k=0)
    access = oir.FieldAccess(name="a", offset=offset, dtype=common.DataType.FLOAT64)
    context = OirToNpir.ComputationContext()
    field_slice = OirToNpir().visit(access, ctx=context, parallel_k=parallel_k)
    assert field_slice.k_offset.parallel is parallel_k
    assert field_slice.i_offset.offset.value == -1
def test_stencil_to_computation() -> None:
    """Lowering a stencil yields its API fields, arguments and a single vertical pass.

    Only the field parameter ``a`` is expected among the API field declarations.
    Both the field ``a`` and the scalar ``b`` are expected among the arguments.
    """
    stencil = StencilFactory(
        name="stencil",
        params=[
            FieldDeclFactory(name="a", dtype=common.DataType.FLOAT64),
            oir.ScalarDecl(name="b", dtype=common.DataType.INT32),
        ],
        vertical_loops__0__sections__0__horizontal_executions__0__body=[
            AssignStmtFactory(
                left=FieldAccessFactory(name="a"), right=ScalarAccessFactory(name="b")
            )
        ],
    )
    computation = OirToNpir().visit(stencil)
    # Use a set comprehension instead of wrapping a generator in set().
    assert {decl.name for decl in computation.api_field_decls} == {"a"}
    assert set(computation.arguments) == {"a", "b"}
    assert len(computation.vertical_passes) == 1
def test_vertical_loop_section_to_vertical_pass():
    """A section without horizontal executions lowers to a pass with an empty body."""
    section = VerticalLoopSectionFactory(horizontal_executions=[])
    translated = OirToNpir().visit(section, loop_order=common.LoopOrder.PARALLEL)
    assert translated.body == []
def test_field_access_to_field_slice_variablek() -> None:
    """A variable-k FieldAccess keeps zero i/j offsets and its k offset expression."""
    k_expr = oir.Literal(value="1", dtype=common.DataType.INT32)
    field_access = FieldAccessFactory(offset=oir.VariableKOffset(k=k_expr))
    field_slice = OirToNpir().visit(field_access)
    horizontal_offsets = (field_slice.i_offset, field_slice.j_offset)
    assert horizontal_offsets == (0, 0)
    assert field_slice.k_offset.k.value == "1"
def make_block_and_transform(**kwargs) -> npir.HorizontalBlock:
    """Lower a HorizontalExecution built from ``kwargs`` inside a stencil and return the result.

    The execution becomes the only horizontal execution of the first section of the
    first vertical loop. The function returns the first body element of the first
    vertical pass of the lowered computation.
    """
    execution = HorizontalExecutionFactory(**kwargs)
    stencil = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[execution]
    )
    computation = OirToNpir().visit(stencil)
    first_pass = computation.vertical_passes[0]
    return first_pass.body[0]
def test_native_func_call():
    """A sqrt NativeFuncCall on a field lowers to a vector expression."""
    field_arg = oir.FieldAccess(
        name="a",
        offset=common.CartesianOffset.zero(),
        dtype=common.DataType.FLOAT64,
    )
    sqrt_call = oir.NativeFuncCall(func=common.NativeFunction.SQRT, args=[field_arg])
    context = OirToNpir.ComputationContext()
    lowered = OirToNpir().visit(sqrt_call, parallel_k=True, ctx=context)
    assert isinstance(lowered, npir.VectorExpression)
def test_literal_broadcast() -> None:
    """A literal on the right-hand side of an assignment is wrapped in a Broadcast."""
    literal = oir.Literal(value="42", dtype=common.DataType.FLOAT32)
    stmt = AssignStmtFactory(left__dtype=common.DataType.FLOAT32, right=literal)
    rhs = OirToNpir().visit(stmt).right
    assert isinstance(rhs, npir.Broadcast)
    inner = rhs.expr
    assert (inner.value, inner.dtype) == ("42", common.DataType.FLOAT32)
def test_assign_stmt_to_vector_assign(parallel_k):
    """An AssignStmt lowers to a VectorAssign whose two sides both carry parallel_k."""
    float64 = common.DataType.FLOAT64
    target = oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=float64)
    source = oir.FieldAccess(
        name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=float64
    )
    stmt = oir.AssignStmt(left=target, right=source)
    context = OirToNpir.ComputationContext()
    vector_assign = OirToNpir().visit(stmt, ctx=context, parallel_k=parallel_k, mask=None)
    assert isinstance(vector_assign, npir.VectorAssign)
    for side in (vector_assign.left, vector_assign.right):
        assert side.k_offset.parallel is parallel_k
def test_binary_op_to_vector_arithmetic():
    """Adding two int literals lowers to VectorArithmetic with both operands broadcast."""
    int32 = common.DataType.INT32
    addition = oir.BinaryOp(
        op=common.ArithmeticOperator.ADD,
        left=oir.Literal(dtype=int32, value="2"),
        right=oir.Literal(dtype=int32, value="2"),
    )
    arithmetic = OirToNpir().visit(addition)
    assert isinstance(arithmetic, npir.VectorArithmetic)
    assert isinstance(arithmetic.left, npir.BroadCast)
    assert isinstance(arithmetic.right, npir.BroadCast)
def test_literal(broadcast):
    """An OIR literal lowers to an npir literal, wrapped in a BroadCast when requested.

    In both cases the underlying literal keeps the original dtype, kind and value.
    """
    oir_literal = oir.Literal(value="42", dtype=common.DataType.INT32)
    result = OirToNpir().visit(oir_literal, broadcast=broadcast)
    if not broadcast:
        npir_literal = result
    else:
        assert isinstance(result, npir.BroadCast)
        assert isinstance(result, npir.VectorExpression)
        npir_literal = result.expr
    assert oir_literal.dtype == npir_literal.dtype
    assert oir_literal.kind == npir_literal.kind
    assert oir_literal.value == npir_literal.value
def test_cast(broadcast):
    """An int-to-float64 Cast lowers to npir.Cast, wrapped in a BroadCast when requested."""
    int_literal = oir.Literal(value="42", dtype=common.DataType.INT32)
    itof = oir.Cast(dtype=common.DataType.FLOAT64, expr=int_literal)
    result = OirToNpir().visit(itof, broadcast=broadcast)
    if broadcast:
        assert isinstance(result, npir.BroadCast)
        assert isinstance(result, npir.VectorExpression)
        cast = result.expr
    else:
        assert isinstance(result, npir.Cast)
        cast = result
    assert cast.dtype == itof.dtype
    assert cast.expr.value == "42"
def test_temp_assign(parallel_k):
    """Assigning to a temporary records exactly one temp definition in the context.

    The definition stored under ``a`` has a VectorTemp on the left and an
    EmptyTemp on the right.
    """
    float64 = common.DataType.FLOAT64
    stmt = oir.AssignStmt(
        left=oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=float64),
        right=oir.FieldAccess(
            name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=float64
        ),
    )
    ctx = OirToNpir.ComputationContext()
    symbols = {"a": TemporaryFactory(name="a")}
    # Only the side effect on ctx matters here; the returned node is ignored.
    OirToNpir().visit(stmt, ctx=ctx, parallel_k=parallel_k, mask=None, symtable=symbols)
    assert len(ctx.temp_defs) == 1
    temp_def = ctx.temp_defs["a"]
    assert isinstance(temp_def.left, npir.VectorTemp)
    assert isinstance(temp_def.right, npir.EmptyTemp)
def _make_npir(self) -> npir.Computation:
    """Lower the builder's GTIR to an npir Computation.

    The GTIR is first converted to OIR. The OIR then runs through the pipeline from
    the ``oir_pipeline`` backend option, which defaults to a ``DefaultPipeline``
    skipping the cache-related passes listed below. The result is translated with
    ``OirToNpir``.
    """
    base_oir = GTIRToOIR().visit(self.builder.gtir)
    oir_pipeline = self.builder.options.backend_opts.get(
        "oir_pipeline",
        DefaultPipeline(
            skip=[
                IJCacheDetection,
                KCacheDetection,
                PruneKCacheFills,
                PruneKCacheFlushes,
                FillFlushToLocalKCaches,
            ]
        ),
    )
    # Named ``oir_node`` rather than ``oir`` so the ``oir`` module is not shadowed.
    oir_node = oir_pipeline.run(base_oir)
    return OirToNpir().visit(oir_node)
def _make_npir(self) -> npir.Computation:
    """Lower the builder's GTIR to npir, then turn scalars into temporaries.

    The GTIR is converted to OIR. The OIR runs through the ``oir_pipeline``
    backend option, which defaults to a ``DefaultPipeline`` that skips the listed
    cache passes. It is then translated with ``OirToNpir`` and post-processed by
    ``ScalarsToTemporaries``.
    """
    base_oir = GTIRToOIR().visit(self.builder.gtir)
    default_pipeline = DefaultPipeline(
        skip=[IJCacheDetection, KCacheDetection, PruneKCacheFills, PruneKCacheFlushes]
    )
    pipeline = self.builder.options.backend_opts.get("oir_pipeline", default_pipeline)
    optimized_oir = pipeline.run(base_oir)
    npir_computation = OirToNpir().visit(optimized_oir)
    return ScalarsToTemporaries().visit(npir_computation)
def test_stencil_to_computation():
    """A stencil lowers to a computation with field/all params and a single vertical pass."""
    field_a = FieldDeclFactory(name="a", dtype=common.DataType.FLOAT64)
    scalar_b = oir.ScalarDecl(name="b", dtype=common.DataType.INT32)
    assignment = AssignStmtFactory(
        left=FieldAccessFactory(name="a"), right=ScalarAccessFactory(name="b")
    )
    stencil = StencilFactory(
        name="stencil",
        params=[field_a, scalar_b],
        vertical_loops__0__sections__0__horizontal_executions__0__body=[assignment],
    )
    computation = OirToNpir().visit(stencil)
    assert computation.field_params == ["a"]
    assert computation.params == ["a", "b"]
    assert len(computation.vertical_passes) == 1
def test_mask_propagation() -> None:
    """The first statement lowered from a MaskStmt carries the lowered mask."""
    mask_stmt = MaskStmtFactory()
    lowered_stmts = OirToNpir().visit(mask_stmt)
    expected_mask = OirToNpir().visit(mask_stmt.mask)
    assert lowered_stmts[0].mask == expected_mask
def test_mask_stmt_to_assigns() -> None:
    """A MaskStmt with one assignment lowers to one statement masked by a FieldSlice."""
    mask_stmt = MaskStmtFactory(body=[AssignStmtFactory()])
    assign_stmts = OirToNpir().visit(mask_stmt)
    # Check the count first so an empty result fails with an assertion,
    # not with an IndexError on ``assign_stmts[0]``.
    assert len(assign_stmts) == 1
    assert isinstance(assign_stmts[0].mask, npir.FieldSlice)
def test_binary_op_to_npir(oir_node: oir.Expr, npir_type: npir.Expr) -> None:
    """Lowering ``oir_node`` produces an instance of ``npir_type``."""
    # NOTE(review): npir_type is used as a class in isinstance(); a more precise
    # annotation would be a type of npir.Expr. Left unchanged to keep the signature.
    lowered = OirToNpir().visit(oir_node)
    assert isinstance(lowered, npir_type)
def test_field_access_to_field_slice_cartesian() -> None:
    """Cartesian offsets of a FieldAccess map one-to-one onto the FieldSlice offsets."""
    access = FieldAccessFactory(offset__i=-1, offset__j=2, offset__k=0)
    field_slice = OirToNpir().visit(access)
    offsets = (field_slice.i_offset, field_slice.j_offset, field_slice.k_offset)
    assert offsets == (-1, 2, 0)
def test_native_func_call() -> None:
    """A NativeFuncCall whose first argument is a field access lowers to npir.NativeFuncCall."""
    call = NativeFuncCallFactory(args__0=FieldAccessFactory())
    lowered = OirToNpir().visit(call)
    assert isinstance(lowered, npir.NativeFuncCall)
def test_cast(oir_int_expr, npir_type) -> None:
    """Casting ``oir_int_expr`` to FLOAT64 lowers to an instance of ``npir_type``."""
    cast_node = oir.Cast(dtype=common.DataType.FLOAT64, expr=oir_int_expr)
    lowered = OirToNpir().visit(cast_node)
    assert isinstance(lowered, npir_type)
def test_horizontal_execution_to_vector_assigns() -> None:
    """An empty HorizontalExecution lowers to a node with an empty body."""
    empty_execution = HorizontalExecutionFactory(body=[])
    block = OirToNpir().visit(empty_execution)
    assert block.body == []
def test_mask_stmt_to_assigns() -> None:
    """A MaskStmt with one assignment lowers to one statement conditioned on a FieldSlice.

    In this lowering the mask is expected as the ``cond`` of the right-hand side of
    the produced assignment.
    """
    mask_stmt = MaskStmtFactory(body=[AssignStmtFactory()])
    assign_stmts = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2))
    # Check the count first so an empty result fails with an assertion,
    # not with an IndexError on ``assign_stmts[0]``.
    assert len(assign_stmts) == 1
    assert isinstance(assign_stmts[0].right.cond, npir.FieldSlice)
def test_mask_propagation() -> None:
    """The lowered mask becomes the condition on the first produced assignment's RHS."""
    mask_stmt = MaskStmtFactory()
    zero_extent = Extent.zeros(ndims=2)
    lowered_stmts = OirToNpir().visit(mask_stmt, extent=zero_extent)
    expected_cond = OirToNpir().visit(mask_stmt.mask)
    assert lowered_stmts[0].right.cond == expected_cond
def test_horizontal_execution_to_vector_assigns():
    """A HorizontalExecution without statements lowers to a region with an empty body."""
    execution = HorizontalExecutionFactory(body=[])
    region = OirToNpir().visit(execution)
    assert region.body == []
def test_vertical_loop_to_vertical_passes():
    """A loop whose first section has no horizontal executions yields an empty first pass."""
    loop = VerticalLoopFactory(sections__0__horizontal_executions=[])
    passes = OirToNpir().visit(loop)
    first_pass = passes[0]
    assert first_pass.body == []