def alignment_assertions(expected):
    """Assert that ``align_images`` undoes a known offset and rotation.

    A reference rectangular field (centre ``(0, 0)``, size ``(20, 25)``,
    rotation 0) and a "moving" copy centred at
    ``(-expected[0], -expected[1])`` with rotation ``-expected[2]`` are
    sampled on an integer grid spanning x in [-50, 50] and y in [-60, 60].
    ``align_images`` (with ``max_shift=20``) estimates the alignment, and
    ``shift_and_rotate`` applies its result to the moving image.

    The aligned image must match the reference within
    ``rtol=0.01, atol=0.01``. On failure, the reference, aligned and
    difference images are shown before the ``AssertionError`` is re-raised.

    Args:
        expected: indexable whose first three items are used as the x
            offset, y offset and rotation of the moving field. The rotation
            units presumably match ``create_rectangular_field_function``'s
            ``rotation`` argument; confirm there.
    """
    ref_field = create_rectangular_field_function((0, 0), (20, 25), 5, rotation=0)
    moving_field = create_rectangular_field_function(
        (-expected[0], -expected[1]), (20, 25), 5, rotation=-expected[2]
    )

    axes = (np.arange(-50, 51), np.arange(-60, 61))
    # Column vs. row vectors broadcast into a full 2-D sampling grid.
    x_grid = axes[0][:, None]
    y_grid = axes[1][None, :]

    ref_image = ref_field(x_grid, y_grid)
    moving_image = moving_field(x_grid, y_grid)

    alignment = align_images(axes, ref_image, axes, moving_image, max_shift=20)
    aligned_image = shift_and_rotate(axes, axes, moving_image, *alignment)

    try:
        assert np.allclose(aligned_image, ref_image, rtol=0.01, atol=0.01)
    except AssertionError:
        # Debug aid: show reference, aligned result and their difference.
        for debug_image in (ref_image, aligned_image, aligned_image - ref_image):
            plt.figure()
            plt.imshow(debug_image)
        plt.show()
        raise
def test_interpolated_rotation():
    """Check ``interpolated_rotation`` against directly generated fields.

    A rectangular field (size ``(5, 7)``, rotation 10) is sampled on a
    100 x 120 grid and wrapped with ``create_image_interpolation``. Two
    checks follow:

    * Rotating by 0 must reproduce the sampled image (``np.allclose`` with
      its default tolerances).
    * Rotating by -20 must match the same field generated with rotation 30,
      within ``atol=0.1``.

    Each failing check displays the relevant images (equal aspect) before
    re-raising the ``AssertionError``.
    """

    def show_images(*images):
        # Debug aid: one equal-aspect figure per image, then display them.
        for image in images:
            plt.figure()
            plt.imshow(image)
            plt.axis("equal")
        plt.show()

    ref_field = create_rectangular_field_function((0, 0), (5, 7), 2, rotation=30)
    moving_field = create_rectangular_field_function((0, 0), (5, 7), 2, rotation=10)

    grid = (np.linspace(-20, 20, 100), np.linspace(-30, 30, 120))
    x_mesh = grid[0][:, None]
    y_mesh = grid[1][None, :]

    ref_image = ref_field(x_mesh, y_mesh)
    moving_image = moving_field(x_mesh, y_mesh)

    interpolator = create_image_interpolation(grid, moving_image)

    # A zero-angle rotation should be an identity resampling.
    unrotated = interpolated_rotation(interpolator, grid, 0)
    try:
        assert np.allclose(unrotated, moving_image)
    except AssertionError:
        show_images(moving_image, unrotated)
        raise

    # 10 + (-20) is not 30, so the rotation sign convention of
    # interpolated_rotation presumably opposes the field's `rotation`
    # argument; verify against its implementation.
    rotated = interpolated_rotation(interpolator, grid, -20)
    try:
        assert np.allclose(ref_image, rotated, atol=1.0e-1)
    except AssertionError:
        show_images(ref_image, rotated, rotated - ref_image)
        raise