def test_merge_read_after_write_k_parallel_noseq_flat(
    merge_blocks_pass: PassType,
) -> None:
    """Check that a K-offset read-after-write merges when K is parallel.

    The body writes ``tmp`` from ``in`` at zero offset, then reads ``tmp``
    into ``out`` at a K offset of -1. All of I, J and K are parallel axes
    and the iteration order is PARALLEL, so the merge pass is expected to
    produce a single block.

    Note: this test previously shared its name with another test in this
    module. The later definition rebound the module attribute, which hid
    this one from pytest collection. The distinct name makes it run.
    """
    transform_data = make_transform_data(
        name="read_after_write_forbidden",
        # type ignore is due to attribclass
        domain=Domain(  # type: ignore
            parallel_axes=[Axis(name=idx) for idx in ["I", "J", "K"]]  # type: ignore
        ),
        fields=["out", "in"],
        body=[("tmp", "in", (0, 0, 0)), ("out", "tmp", (0, 0, -1))],
        iteration_order=IterationOrder.PARALLEL,
    )
    transform_data = merge_blocks_pass(transform_data)
    # allowed to be merged, because k-axis is not sequential
    assert len(transform_data.blocks) == 1
def test_merge_read_after_write_k_parallel_noseq(
    merge_blocks_pass: AnalysisPass,
) -> None:
    """Merge a K-offset read-after-write when no axis is sequential.

    ``tmp`` is assigned from ``in`` at zero offset, then ``out`` reads
    ``tmp`` at a K offset of -1. I, J and K are all parallel and the
    computation block runs in PARALLEL order, so the merge pass is
    expected to fold both statements into one block.
    """
    # type ignores are due to attribclass
    all_parallel = Domain(  # type: ignore
        parallel_axes=[Axis(name=axis_name) for axis_name in ["I", "J", "K"]]  # type: ignore
    )
    definition = TDefinition(
        name="read_after_write_forbidden_noseq",
        domain=all_parallel,
        fields=["out", "in"],
    )
    computation = TComputationBlock(order=IterationOrder.PARALLEL).add_statements(
        TAssign("tmp", "in", (0, 0, 0)),
        TAssign("out", "tmp", (0, 0, -1)),
    )
    merged = merge_blocks_pass(definition.add_blocks(computation).build_transform())
    # allowed to be merged, because k-axis is not sequential
    assert len(merged.blocks) == 1
def ijk_domain() -> Domain:
    """Return a domain where I and J are parallel and K is the sequential axis."""
    i_axis, j_axis, k_axis = (Axis(name=letter) for letter in "IJK")
    return Domain(parallel_axes=[i_axis, j_axis], sequential_axis=k_axis)