def test_write_after_read_with_offset():
    """GraphMerging must leave a write-after-offset-read pair unmerged.

    The first horizontal execution reads ``foo`` at an i-offset of 1 and the
    second one assigns to ``foo``. The test expects the pass to keep both
    executions, with bodies identical to the input. Presumably this is because
    fusing them would let the write interfere with the offset read at a
    neighbouring point.
    """
    testee = VerticalLoopSectionFactory(horizontal_executions=[
        HorizontalExecutionFactory(
            body=[AssignStmtFactory(right__name="foo", right__offset__i=1)]),
        HorizontalExecutionFactory(body=[AssignStmtFactory(left__name="foo")]),
    ])
    transformed = (optimize_horizontal_executions(
        StencilFactory(vertical_loops__0__sections__0=testee),
        GraphMerging,
    ).vertical_loops[0].sections[0])
    # Check the count first. zip() silently stops at the shorter sequence, so
    # without this a dropped or merged execution could go partly unchecked.
    assert len(transformed.horizontal_executions) == len(
        testee.horizontal_executions)
    for result, reference in zip(transformed.horizontal_executions,
                                 testee.horizontal_executions):
        assert result.body == reference.body


def test_nonzero_extent_merging():
    """GraphMerging fuses two executions that both read ``foo``.

    The second execution reads ``foo`` at a j-offset of 1. The pass is
    expected to produce exactly one execution whose body is the two original
    bodies concatenated in their original order.
    """
    testee = VerticalLoopSectionFactory(horizontal_executions=[
        HorizontalExecutionFactory(body=[AssignStmtFactory(
            right__name="foo")]),
        HorizontalExecutionFactory(
            body=[AssignStmtFactory(right__name="foo", right__offset__j=1)]),
    ])
    stencil = StencilFactory(vertical_loops__0__sections__0=testee)
    optimized = optimize_horizontal_executions(stencil, GraphMerging)
    merged = optimized.vertical_loops[0].sections[0].horizontal_executions

    assert len(merged) == 1
    # Expected: every original statement, in original execution order.
    expected_body = [
        stmt
        for execution in testee.horizontal_executions
        for stmt in execution.body
    ]
    assert merged[0].body == expected_body
        ]),
        HorizontalExecutionFactory(body=[
            assignment_1 := AssignStmtFactory(left__name="baz",
                                              right__name="bar")
        ]),
        HorizontalExecutionFactory(body=[
            assignment_2 := AssignStmtFactory(left__name="foo",
                                              right__name="foo")
        ]),
        HorizontalExecutionFactory(body=[
            assignment_3 := AssignStmtFactory(left__name="foo",
                                              right__name="baz")
        ], ),
    ])
    transformed = (optimize_horizontal_executions(
        StencilFactory(vertical_loops__0__sections__0=testee),
        GraphMerging,
    ).vertical_loops[0].sections[0])
    assert len(transformed.horizontal_executions) == 1
    transformed_order = transformed.horizontal_executions[0].body
    assert transformed_order.index(assignment_0) < transformed_order.index(
        assignment_2)
    assert transformed_order.index(assignment_1) < transformed_order.index(
        assignment_3)
    assert transformed_order.index(assignment_2) < transformed_order.index(
        assignment_3)


def test_mixed_merging():
    testee = VerticalLoopSectionFactory(horizontal_executions=[
        HorizontalExecutionFactory(
            body=[assignment_0 := AssignStmtFactory(left__name="foo")]),