Ejemplo n.º 1
0
def main():
    """Multi-resolution diffeomorphic B-spline registration demo on 2D test images.

    Creates fixed, moving and shaded moving test images with
    ``create_C_2_O_test_images`` (size 256) and registers moving onto fixed
    on a three-level image pyramid (size/4, size/2, size/1).

    At each level a diffeomorphic cubic ``BsplineTransformation`` is optimised
    with Adam against an MSE image loss plus a ``DiffusionRegulariser``. From
    the second level on, the flow of the previous level is upsampled linearly
    and installed with ``set_constant_flow``.

    The final displacement warps the shaded moving image; the inverse
    displacement warps that result back. Everything is plotted in a 2x4 grid.
    """
    start = time.time()

    # set the used data type
    dtype = th.float32
    # set the device for the computation to CPU
    device = th.device("cpu")

    # In order to use a GPU uncomment the following line. The number is the device index of the used GPU
    # Here, the GPU with the index 0 is used.
    # device = th.device("cuda:0")

    # create test image data
    fixed_image, moving_image, shaded_image = create_C_2_O_test_images(
        256, dtype=dtype, device=device)

    # create image pyramid size/4, size/2, size/1
    fixed_image_pyramid = al.create_image_pyramid(fixed_image,
                                                  [[4, 4], [2, 2]])
    moving_image_pyramid = al.create_image_pyramid(moving_image,
                                                   [[4, 4], [2, 2]])

    # flow of the previous (coarser) level; bound here so the name that the
    # loop actually carries between levels always exists
    constant_flow = None
    # per-level settings, ordered coarse -> fine (one entry per pyramid level)
    regularisation_weight = [1, 5, 50]
    number_of_iterations = [500, 500, 500]
    sigma = [[11, 11], [11, 11], [3, 3]]

    for level, (mov_im_level, fix_im_level) in enumerate(
            zip(moving_image_pyramid, fixed_image_pyramid)):

        registration = al.PairwiseRegistration(verbose=True)

        # define the diffeomorphic cubic B-spline transformation of this level
        transformation = al.transformation.pairwise.BsplineTransformation(
            mov_im_level.size,
            sigma=sigma[level],
            order=3,
            dtype=dtype,
            device=device,
            diffeomorphic=True)

        if level > 0:
            # upsample the coarser level's flow to this level's size and
            # install it as the constant flow of the new transformation
            constant_flow = al.transformation.utils.upsample_displacement(
                constant_flow, mov_im_level.size, interpolation="linear")
            transformation.set_constant_flow(constant_flow)

        registration.set_transformation(transformation)

        # choose the Mean Squared Error as image loss
        image_loss = al.loss.pairwise.MSE(fix_im_level, mov_im_level)

        registration.set_image_loss([image_loss])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(
            mov_im_level.spacing)
        regulariser.SetWeight(regularisation_weight[level])

        registration.set_regulariser_displacement([regulariser])

        # define the optimizer (Adam with its default learning rate)
        optimizer = th.optim.Adam(transformation.parameters())

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        # keep this level's flow to initialise the next (finer) level
        constant_flow = transformation.get_flow()

    # create final result: warp the shaded moving image with the displacement
    # of the last (full resolution) level
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(shaded_image,
                                                      displacement)
    displacement = al.create_displacement_image_from_image(
        displacement, moving_image)

    # create inverse displacement field and map the warped image back with it
    inverse_displacement = transformation.get_inverse_displacement()
    inverse_warped_image = al.transformation.utils.warp_image(
        warped_image, inverse_displacement)
    inverse_displacement = al.create_displacement_image_from_image(
        inverse_displacement, moving_image)

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start)

    # plot the results, first row: fixed, moving, warped, |displacement|
    plt.subplot(241)
    plt.imshow(fixed_image.numpy(), cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(242)
    plt.imshow(moving_image.numpy(), cmap='gray')
    plt.title('Moving Image')

    plt.subplot(243)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Shaded Moving Image')

    plt.subplot(244)
    plt.imshow(displacement.magnitude().numpy(), cmap='jet')
    plt.title('Magnitude Displacement')

    # second row: warped, shaded original, inverse-warped, |inverse displacement|
    plt.subplot(245)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Shaded Moving Image')

    plt.subplot(246)
    plt.imshow(shaded_image.numpy(), cmap='gray')
    plt.title('Shaded Moving Image')

    plt.subplot(247)
    plt.imshow(inverse_warped_image.numpy(), cmap='gray')
    plt.title('Inverse Warped Shaded Moving Image')

    plt.subplot(248)
    plt.imshow(inverse_displacement.magnitude().numpy(), cmap='jet')
    plt.title('Magnitude Inverse Displacement')

    plt.show()
Ejemplo n.º 2
0
def main():
    """Register two phases of the POPI 4DCT data set with a multi-level B-spline model.

    Loads ``image_00`` (fixed) and ``image_50`` (moving) of ``4DCT_POPI_1``
    via ``al.ImageLoader`` and removes the bed from both. It then normalises
    the intensities with a common minimum/maximum and maps both volumes onto
    a joint domain with masks. Finally, it registers them over four levels
    (size/8, size/4, size/2, size/1) with a cubic B-spline transformation, a
    masked MSE loss and a diffusion regulariser.

    When both images come with landmarks, the target registration error (TRE)
    is reported at these points:

    * initially;
    * after the optional centre-of-mass alignment;
    * after every level;
    * at the end.

    Result volumes, masks, displacement fields (and landmarks) are written as
    VTK files into ``tmp_directory``.
    """
    import os

    start = time.time()

    # set the used data type
    dtype = th.float32
    # set the device for the computation to CPU
    device = th.device("cpu")

    # In order to use a GPU uncomment the following line. The number is the device index of the used GPU
    # Here, the GPU with the index 0 is used.
    # device = th.device("cuda:0")

    # directory given to the image loader and used to store results
    tmp_directory = "/tmp/"

    def result_path(file_name):
        # location of an output file inside tmp_directory
        return os.path.join(tmp_directory, file_name)

    loader = al.ImageLoader(tmp_directory)

    # Images:
    p1_name = "4DCT_POPI_1"
    p1_img_nr = "image_00"
    p2_name = "4DCT_POPI_1"
    p2_img_nr = "image_50"

    print("loading images")
    (fixed_image, fixed_points) = loader.load(p1_name, p1_img_nr)
    (moving_image, moving_points) = loader.load(p2_name, p2_img_nr)
    # NOTE(review): the return value of Image.to is discarded, so this relies
    # on it converting the image in place — confirm against airlab
    fixed_image.to(dtype, device)
    moving_image.to(dtype, device)

    # landmarks are only evaluated when both images come with points
    using_landmarks = fixed_points is not None and moving_points is not None

    if using_landmarks:
        initial_tre = al.Points.TRE(fixed_points, moving_points)
        print("initial TRE: "+str(initial_tre))

    print("preprocessing images")
    fixed_image, _ = al.remove_bed_filter(fixed_image)
    moving_image, _ = al.remove_bed_filter(moving_image)

    # normalize image intensities using common minimum and common maximum
    fixed_image, moving_image = al.utils.normalize_images(fixed_image, moving_image)

    # only perform center of mass alignment if inter subject registration is performed
    cm_alignment = p1_name != p2_name

    # map both images onto a joint domain and compute masks for them
    f_image, f_mask, m_image, m_mask, cm_displacement = al.get_joint_domain_images(fixed_image, moving_image,
                                                                                   cm_alignment=cm_alignment,
                                                                                   compute_masks=True)

    # align also moving points
    if cm_displacement is not None and using_landmarks:
        moving_points_aligned = np.zeros_like(moving_points)
        for i in range(moving_points_aligned.shape[0]):
            moving_points_aligned[i, :] = moving_points[i, :] + cm_displacement
        print("aligned TRE: " + str(al.Points.TRE(fixed_points, moving_points_aligned)))
    else:
        moving_points_aligned = moving_points

    # create image pyramid size/8, size/4, size/2, size/1
    pyramid_factors = [[8, 8, 8], [4, 4, 4], [2, 2, 2]]
    fixed_image_pyramid = al.create_image_pyramid(f_image, pyramid_factors)
    fixed_mask_pyramid = al.create_image_pyramid(f_mask, pyramid_factors)
    moving_image_pyramid = al.create_image_pyramid(m_image, pyramid_factors)
    moving_mask_pyramid = al.create_image_pyramid(m_mask, pyramid_factors)

    # flow of the previous (coarser) level
    constant_flow = None
    # per-level settings, ordered coarse -> fine (one entry per pyramid level)
    regularisation_weight = [1e-2, 1e-1, 1e-0, 1e+2]
    number_of_iterations = [300, 200, 100, 50]
    sigma = [[9, 9, 9], [9, 9, 9], [9, 9, 9], [9, 9, 9]]
    step_size = [1e-2, 4e-3, 2e-3, 2e-3]

    print("perform registration")
    for level, (mov_im_level, mov_msk_level, fix_im_level, fix_msk_level) in enumerate(zip(moving_image_pyramid,
                                                                                           moving_mask_pyramid,
                                                                                           fixed_image_pyramid,
                                                                                           fixed_mask_pyramid)):

        print("---- Level "+str(level)+" ----")
        registration = al.PairwiseRegistration()

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(mov_im_level.size,
                                                                          sigma=sigma[level],
                                                                          order=3,
                                                                          dtype=dtype,
                                                                          device=device)

        if level > 0:
            # upsample the coarser flow to this level and install it as constant flow
            constant_flow = al.transformation.utils.upsample_displacement(constant_flow,
                                                                          mov_im_level.size,
                                                                          interpolation="linear")

            transformation.set_constant_flow(constant_flow)

        registration.set_transformation(transformation)

        # choose the masked Mean Squared Error as image loss
        image_loss = al.loss.pairwise.MSE(fix_im_level, mov_im_level, mov_msk_level, fix_msk_level)

        registration.set_image_loss([image_loss])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(mov_im_level.spacing)
        regulariser.SetWeight(regularisation_weight[level])

        registration.set_regulariser_displacement([regulariser])

        # define the optimizer
        optimizer = th.optim.Adam(transformation.parameters(), lr=step_size[level], amsgrad=True)

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        # store current flow field
        constant_flow = transformation.get_flow()

        current_displacement = transformation.get_displacement()
        # generate SimpleITK displacement field and calculate TRE
        tmp_displacement = al.transformation.utils.upsample_displacement(current_displacement.clone().to(device='cpu'),
                                                                         m_image.size, interpolation="linear")
        # NOTE(review): spelled "dispalcement" here, whereas other examples call
        # unit_displacement_to_displacement — confirm against the installed airlab version
        tmp_displacement = al.transformation.utils.unit_displacement_to_dispalcement(tmp_displacement)  # unit measures to image domain measures
        tmp_displacement = al.create_displacement_image_from_image(tmp_displacement, m_image)
        tmp_displacement.write(result_path('bspline_displacement_image_level_'+str(level)+'.vtk'))

        # in order to not invert the displacement field, the fixed points are transformed to match the moving points
        if using_landmarks:
            print("TRE on that level: "+str(al.Points.TRE(moving_points_aligned, al.Points.transform(fixed_points, tmp_displacement))))

    # create final result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(m_image, displacement)
    displacement = al.transformation.utils.unit_displacement_to_dispalcement(displacement)  # unit measures to image domain measures
    displacement = al.create_displacement_image_from_image(displacement, m_image)

    end = time.time()

    # in order to not invert the displacement field, the fixed points are transformed to match the moving points
    if using_landmarks:
        print("Initial TRE: "+str(initial_tre))
        fixed_points_transformed = al.Points.transform(fixed_points, displacement)
        print("Final TRE: " + str(al.Points.TRE(moving_points_aligned, fixed_points_transformed)))

    # write result images
    print("writing results")
    warped_image.write(result_path('bspline_warped_image.vtk'))
    m_image.write(result_path('bspline_moving_image.vtk'))
    m_mask.write(result_path('bspline_moving_mask.vtk'))
    f_image.write(result_path('bspline_fixed_image.vtk'))
    f_mask.write(result_path('bspline_fixed_mask.vtk'))
    displacement.write(result_path('bspline_displacement_image.vtk'))

    if using_landmarks:
        al.Points.write(result_path('bspline_fixed_points_transformed.vtk'), fixed_points_transformed)
        al.Points.write(result_path('bspline_moving_points_aligned.vtk'), moving_points_aligned)

    print("=================================================================")
    print("Registration done in: ", end - start, " seconds")
Ejemplo n.º 3
0
def main():
    """Rigidly register a cropped slab of two CT scans and save the displacement.

    Reads a crop of scans ``ref_scan`` (fixed) and ``new_scan`` (moving) with
    ``read_image_array_to_tensor``: the z range ``(z_min, z_max)`` and the
    window ``(125, 825)`` in the other two axes. It then optimises a
    ``RigidTransformation`` with Adam (lr 0.1, 500 iterations) against the
    loss selected by ``method``. The displacement field is written as
    ``displacement_scan<ref>_<new>_Z<zmin>_<zmax>_M<method>.vtk`` in the
    current directory.

    Raises:
        ValueError: if ``method`` is neither ``"mse"`` nor ``"ncc"``.
    """
    start = time.time()

    # Compute on the CPU unless GPUtil reports an available GPU (load and
    # memory usage below 50%); in that case the first such GPU is used.
    device = th.device("cpu")
    deviceIDs = GPUtil.getAvailable(order='first',
                                    limit=100,
                                    maxLoad=0.5,
                                    maxMemory=0.5,
                                    includeNan=False,
                                    excludeID=[],
                                    excludeUUID=[])
    if deviceIDs:
        basic.outputlogMessage("using the GPU (ID:%d) for computing" %
                               deviceIDs[0])
        device = th.device("cuda:%d" % deviceIDs[0])

    ref_scan = '46'
    new_scan = '47'
    z_min = 800
    z_max = 900
    method = "ncc"  # image loss: "mse" or "ncc"

    # Resolve the image loss first so an unsupported ``method`` fails before
    # any scan data is read.
    loss_classes = {
        'mse': al.loss.pairwise.MSE,
        'ncc': al.loss.pairwise.NCC,
    }
    if method not in loss_classes:
        raise ValueError("unknown method: %s" % method)

    # Only the slab between z_min and z_max (plus a fixed 125..825 window in
    # the other two axes) is read, because one GPU's memory is limited.
    crop_ranges = [(z_min, z_max), (125, 825), (125, 825)]
    fixed_image = read_image_array_to_tensor(ref_scan, crop_ranges, device)
    moving_image = read_image_array_to_tensor(new_scan, crop_ranges, device)

    # create pairwise registration object
    registration = al.PairwiseRegistration()

    # choose the rigid transformation model and initialise its translation
    # from the fixed image
    transformation = al.transformation.pairwise.RigidTransformation(
        moving_image, opt_cm=True)
    transformation.init_translation(fixed_image)

    registration.set_transformation(transformation)

    # image loss selected by ``method``
    image_loss = loss_classes[method](fixed_image, moving_image)
    registration.set_image_loss([image_loss])

    # choose the Adam optimizer to minimize the objective
    optimizer = th.optim.Adam(transformation.parameters(), lr=0.1)

    registration.set_optimizer(optimizer)
    registration.set_number_of_iterations(500)

    # start the registration
    registration.start()

    # displacement field of the final transformation
    displacement = transformation.get_displacement()

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start, " s")
    print("Result parameters:")
    transformation.print()

    displacement = al.transformation.utils.unit_displacement_to_displacement(
        displacement)  # unit measures to image domain measures
    displacement = al.create_displacement_image_from_image(
        displacement, moving_image)
    save_disp = "displacement_scan%s_%s_Z%d_%d_M%s" % (ref_scan, new_scan,
                                                       z_min, z_max, method)
    sitk.WriteImage(displacement.itk(), save_disp + '.vtk')
Ejemplo n.º 4
0
def main():
    """Demons registration demo on the 2D "C" to "O" test images.

    A dense ``NonParametricTransformation`` is optimised with Adam (lr 0.07,
    1000 iterations) against an MSE loss inside ``al.DemonsRegistraion``.
    A Gaussian demons regulariser (sigma ``[2, 2]``) is used. The resulting
    displacement warps the shaded moving image. Fixed, moving and warped
    images plus the displacement magnitude are plotted in a 2x2 grid.
    """
    t_begin = time.time()

    # Precision and device of every tensor in this demo. To run on the GPU
    # with index 0 instead, use th.device("cuda:0") as the device.
    precision = th.float32
    target = th.device("cpu")

    fixed_image, moving_image, shaded_image = create_C_2_O_test_images(
        256, dtype=precision, device=target)

    demons = al.DemonsRegistraion(dtype=precision, device=target)

    # dense (non-parametric) transformation on the moving image grid
    transformation = al.transformation.pairwise.NonParametricTransformation(
        moving_image.size, dtype=precision, device=target)
    demons.set_transformation(transformation)

    mse = al.loss.pairwise.MSE(fixed_image, moving_image)
    demons.set_image_loss([mse])

    smoother = al.regulariser.demons.GaussianRegulariser(
        moving_image.spacing, sigma=[2, 2], dtype=precision, device=target)
    demons.set_regulariser([smoother])

    adam = th.optim.Adam(transformation.parameters(), lr=0.07)
    demons.set_optimizer(adam)
    demons.set_number_of_iterations(1000)

    demons.start()

    # warp the shaded moving image with the estimated displacement
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(shaded_image,
                                                      displacement)

    t_finish = time.time()

    displacement = al.create_displacement_image_from_image(
        displacement, moving_image)

    print("=================================================================")
    print("Registration done in: ", t_finish - t_begin)

    panels = (
        (fixed_image.numpy(), 'gray', 'Fixed Image'),
        (moving_image.numpy(), 'gray', 'Moving Image'),
        (warped_image.numpy(), 'gray', 'Warped Moving Image'),
        (displacement.magnitude().numpy(), 'jet', 'Magnitude Displacement'),
    )
    for grid_position, (pixels, colour_map, caption) in enumerate(panels, start=221):
        plt.subplot(grid_position)
        plt.imshow(pixels, cmap=colour_map)
        plt.title(caption)

    plt.show()
Ejemplo n.º 5
0
def main():
    """Rigidly register two synthetic 3D cube volumes, save and plot the result.

    The fixed volume contains a cube of ones. The moving volume holds the same
    cube shifted by ``object_shift`` voxels along every axis. A
    ``RigidTransformation`` is optimised with Adam (lr 0.1, 500 iterations)
    against an MSE loss.

    Afterwards the intensities of both inputs are inverted (``1 - image``) for
    visualisation, and the inverted moving volume is warped. The warped,
    moving and fixed volumes and the displacement field are written as VTK
    files in the current directory, and one slice of each volume is plotted.
    """
    start = time.time()

    # set the used data type
    dtype = th.float32
    # Compute on the CPU unless GPUtil reports an available GPU (load and
    # memory usage below 50%); in that case the first such GPU is used.
    device = th.device("cpu")
    deviceIDs = GPUtil.getAvailable(order='first',
                                    limit=100,
                                    maxLoad=0.5,
                                    maxMemory=0.5,
                                    includeNan=False,
                                    excludeID=[],
                                    excludeUUID=[])
    if deviceIDs:
        print("using the GPU (ID:%d) for computing" % deviceIDs[0])
        device = th.device("cuda:%d" % deviceIDs[0])

    # geometry of the synthetic volumes, in voxels
    volume_size = 64
    cube_start, cube_end = 16, 32
    object_shift = 5
    plot_slice = cube_start  # index of the slice shown in the plots

    fixed_cube = slice(cube_start, cube_end)
    fixed_image = th.zeros(volume_size, volume_size, volume_size,
                           dtype=dtype, device=device)
    fixed_image[fixed_cube, fixed_cube, fixed_cube] = 1.0
    fixed_image = al.Image(fixed_image, [volume_size] * 3, [1, 1, 1], [0, 0, 0])

    moving_cube = slice(cube_start - object_shift, cube_end - object_shift)
    moving_image = th.zeros(volume_size, volume_size, volume_size,
                            dtype=dtype, device=device)
    moving_image[moving_cube, moving_cube, moving_cube] = 1.0
    moving_image = al.Image(moving_image, [volume_size] * 3, [1, 1, 1], [0, 0, 0])

    # create pairwise registration object
    registration = al.PairwiseRegistration()

    # choose the rigid transformation model and initialise its translation
    # from the fixed image
    transformation = al.transformation.pairwise.RigidTransformation(
        moving_image, opt_cm=True)
    transformation.init_translation(fixed_image)

    registration.set_transformation(transformation)

    # choose the Mean Squared Error as image loss
    image_loss = al.loss.pairwise.MSE(fixed_image, moving_image)

    registration.set_image_loss([image_loss])

    # choose the Adam optimizer to minimize the objective
    optimizer = th.optim.Adam(transformation.parameters(), lr=0.1)

    registration.set_optimizer(optimizer)
    registration.set_number_of_iterations(500)

    # start the registration
    registration.start()

    # invert the intensities for the visualisation; note that everything
    # below (warped image, written files, plots) uses the inverted volumes
    fixed_image.image = 1 - fixed_image.image
    moving_image.image = 1 - moving_image.image

    # warp the moving image with the final transformation result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start, " s")
    print("Result parameters:")
    transformation.print()

    sitk.WriteImage(warped_image.itk(), 'rigid_warped_image.vtk')
    sitk.WriteImage(moving_image.itk(), 'rigid_moving_image.vtk')
    sitk.WriteImage(fixed_image.itk(), 'rigid_fixed_image.vtk')

    displacement = al.transformation.utils.unit_displacement_to_displacement(
        displacement)  # unit measures to image domain measures
    displacement = al.create_displacement_image_from_image(
        displacement, moving_image)
    sitk.WriteImage(displacement.itk(), 'displacement.vtk')

    # plot one slice (along the first axis) of each volume
    plt.subplot(131)
    plt.imshow(fixed_image.numpy()[plot_slice, :, :], cmap='gray')
    plt.title('Fixed Image Slice')

    plt.subplot(132)
    plt.imshow(moving_image.numpy()[plot_slice, :, :], cmap='gray')
    plt.title('Moving Image Slice')

    plt.subplot(133)
    plt.imshow(warped_image.numpy()[plot_slice, :, :], cmap='gray')
    plt.title('Warped Moving Image Slice')

    plt.show()
Ejemplo n.º 6
0
def main():
    """B-spline registration of the 2D affine test PNGs over a three-level pyramid.

    Both PNGs are read as float32 and rescaled to [0, 1]. They are then
    registered coarse-to-fine (size/4, size/2, size/1) with a cubic
    ``BsplineTransformation``, an MSE loss and a ``DiffusionRegulariser``.
    From the second level on, the previous level's displacement is upsampled
    and installed with ``set_constant_displacement``. The final displacement
    warps the moving image, and the results are plotted in a 2x2 grid.
    """
    t0 = time.time()

    dtype = th.float32
    device = th.device("cpu")
    # To compute on the GPU with index 0 instead, use:
    # device = th.device("cuda:0")

    def load_unit_range(path):
        # read as float32, rescale intensities to [0, 1], wrap as airlab image
        itk_image = sitk.RescaleIntensity(sitk.ReadImage(path, sitk.sitkFloat32), 0, 1)
        return al.create_tensor_image_from_itk_image(itk_image, dtype=dtype, device=device)

    fixed_image = load_unit_range("./data/affine_test_image_2d_fixed.png")
    moving_image = load_unit_range("./data/affine_test_image_2d_moving.png")

    # pyramid levels size/4, size/2, size/1
    fixed_levels = al.create_image_pyramide(fixed_image, [[4, 4], [2, 2]])
    moving_levels = al.create_image_pyramide(moving_image, [[4, 4], [2, 2]])

    # per-level settings, coarse -> fine
    weights = [1, 5, 50]
    iterations = [500, 500, 500]
    sigmas = [[11, 11], [11, 11], [3, 3]]
    previous_displacement = None

    for level, (mov_im_level, fix_im_level) in enumerate(zip(moving_levels, fixed_levels)):
        registration = al.PairwiseRegistration(dtype=dtype, device=device)

        transformation = al.transformation.pairwise.BsplineTransformation(
            mov_im_level.size, sigma=sigmas[level], order=3, dtype=dtype, device=device)

        if level:
            previous_displacement = al.transformation.utils.upsample_displacement(
                previous_displacement, mov_im_level.size, interpolation="linear")
            transformation.set_constant_displacement(previous_displacement)

        registration.set_transformation(transformation)
        registration.set_image_loss([al.loss.pairwise.MSE(fix_im_level, mov_im_level)])

        smoothness = al.regulariser.displacement.DiffusionRegulariser(mov_im_level.spacing)
        smoothness.SetWeight(weights[level])
        registration.set_regulariser_displacement([smoothness])

        registration.set_optimizer(th.optim.Adam(transformation.parameters()))
        registration.set_number_of_iterations(iterations[level])

        registration.start()

        previous_displacement = transformation.get_displacement()

    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image, displacement)
    displacement = al.create_displacement_image_from_image(displacement, moving_image)

    t1 = time.time()

    print("=================================================================")
    print("Registration done in: ", t1 - t0)
    print("Result parameters:")

    panels = (
        (fixed_image.numpy(), 'gray', 'Fixed Image'),
        (moving_image.numpy(), 'gray', 'Moving Image'),
        (warped_image.numpy(), 'gray', 'Warped Moving Image'),
        (displacement.magnitude().numpy(), 'jet', 'Magnitude Displacement'),
    )
    for grid_position, (pixels, colour_map, caption) in enumerate(panels, start=221):
        plt.subplot(grid_position)
        plt.imshow(pixels, cmap=colour_map)
        plt.title(caption)

    plt.show()
Ejemplo n.º 7
0
def register_images(fixed_image, moving_image, shaded_image=None, *,
                    pyramid_factors=None, regularisation_weight=None,
                    number_of_iterations=None, sigma=None, verbose=True):
    """Register ``moving_image`` to ``fixed_image`` with a multi-level B-spline model.

    The images are registered coarse-to-fine on an image pyramid. At each
    level a cubic ``BsplineTransformation`` is optimised with Adam against an
    MSE image loss plus a ``DiffusionRegulariser``. From the second level on,
    the previous level's flow is upsampled linearly and installed with
    ``set_constant_flow``.

    Args:
        fixed_image: fixed airlab image; its ``dtype`` and ``device`` are used
            for the transformations.
        moving_image: moving airlab image; also defines the grid of the
            returned displacement image.
        shaded_image: optional image that is warped instead of
            ``moving_image`` (e.g. for visualisation).
        pyramid_factors: down-sampling factors for ``al.create_image_pyramid``.
            Defaults to ``[[4, 4], [2, 2]]`` (size/4, size/2, size/1).
        regularisation_weight: diffusion regulariser weight per level, coarse
            to fine. Defaults to ``[1, 5, 50]``.
        number_of_iterations: optimiser iterations per level. Defaults to
            ``[10, 10, 10]``, a short run; pass larger values (e.g.
            ``[500, 500, 500]``) for a more thorough optimisation.
        sigma: ``sigma`` argument of ``BsplineTransformation`` per level.
            Defaults to ``[[11, 11], [11, 11], [3, 3]]``.
        verbose: forwarded to ``al.PairwiseRegistration``.

    Returns:
        tuple: ``(warped_image, displacement_image)``. These are the warped
        shaded/moving image and the final displacement as a displacement image
        on the grid of ``moving_image``.

    Raises:
        ValueError: if the per-level lists differ in length, or do not match
            the number of pyramid levels.
    """
    start = time.time()

    # fresh default lists per call (no shared mutable defaults)
    if pyramid_factors is None:
        pyramid_factors = [[4, 4], [2, 2]]
    if regularisation_weight is None:
        regularisation_weight = [1, 5, 50]
    if number_of_iterations is None:
        number_of_iterations = [10, 10, 10]
    if sigma is None:
        sigma = [[11, 11], [11, 11], [3, 3]]

    # every level needs exactly one entry in each per-level list
    num_levels = len(number_of_iterations)
    if len(regularisation_weight) != num_levels or len(sigma) != num_levels:
        raise ValueError(
            "regularisation_weight, number_of_iterations and sigma must have "
            "the same length (one entry per pyramid level)")

    fixed_image_pyramid = list(al.create_image_pyramid(fixed_image, pyramid_factors))
    moving_image_pyramid = list(al.create_image_pyramid(moving_image, pyramid_factors))

    # zip() would silently drop levels on a mismatch, so check explicitly
    if len(fixed_image_pyramid) != num_levels:
        raise ValueError(
            "got %d per-level settings for %d pyramid levels"
            % (num_levels, len(fixed_image_pyramid)))

    dtype = fixed_image.dtype
    device = fixed_image.device

    # flow of the previous (coarser) level
    constant_flow = None

    for level, (mov_im_level, fix_im_level) in enumerate(
            zip(moving_image_pyramid, fixed_image_pyramid)):

        registration = al.PairwiseRegistration(verbose=verbose)

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(
            mov_im_level.size,
            sigma=sigma[level],
            order=3,
            dtype=dtype,
            device=device)

        if level > 0:
            # upsample the coarser flow to this level and install it
            constant_flow = al.transformation.utils.upsample_displacement(
                constant_flow, mov_im_level.size, interpolation="linear")
            transformation.set_constant_flow(constant_flow)

        registration.set_transformation(transformation)

        # choose the Mean Squared Error as image loss
        image_loss = al.loss.pairwise.MSE(fix_im_level, mov_im_level)

        registration.set_image_loss([image_loss])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(
            mov_im_level.spacing)
        regulariser.SetWeight(regularisation_weight[level])

        registration.set_regulariser_displacement([regulariser])

        # define the optimizer (Adam with its default learning rate)
        optimizer = th.optim.Adam(transformation.parameters())

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        # keep this level's flow to initialise the next (finer) level
        constant_flow = transformation.get_flow()

    # create final result with the last (full resolution) level
    displacement = transformation.get_displacement()
    source_image = moving_image if shaded_image is None else shaded_image
    warped_image = al.transformation.utils.warp_image(source_image, displacement)
    displacement_image = al.create_displacement_image_from_image(
        displacement, moving_image)

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start)
    return warped_image, displacement_image