def affine_registrate(mov_img,
                      fix_img,
                      lr=None,
                      niter=None,
                      init_transform=None):
    """Register ``mov_img`` onto ``fix_img`` with an airlab affine transform.

    The registration uses an MSE image loss and the Adam optimizer and runs
    verbosely.

    Args:
        mov_img: moving airlab image.
        fix_img: fixed airlab image.
        lr: Adam learning rate. ``None`` (the default) uses
            ``torch.optim.Adam``'s own default learning rate; passing
            ``lr=None`` straight to Adam would raise.
        niter: number of optimisation iterations, forwarded unchanged to
            ``set_number_of_iterations``. NOTE(review): callers presumably
            must pass an int here — confirm what airlab does with ``None``.
        init_transform: optional object whose ``init_al_transform`` method is
            called with the freshly created transformation to seed it.

    Returns:
        The optimised airlab ``AffineTransformation``.
    """
    registration = al.PairwiseRegistration(verbose=True)

    transformation = al.transformation.pairwise.AffineTransformation(mov_img)
    if init_transform is not None:
        init_transform.init_al_transform(transformation)
    registration.set_transformation(transformation)

    image_loss = al.loss.pairwise.MSE(fix_img, mov_img)
    registration.set_image_loss([image_loss])

    # Only forward lr when given: Adam rejects lr=None.
    adam_kwargs = {} if lr is None else {"lr": lr}
    optimizer = torch.optim.Adam(transformation.parameters(), **adam_kwargs)
    registration.set_optimizer(optimizer)

    registration.set_number_of_iterations(niter)
    registration.start()

    return transformation
# Example no. 2
# 0
def main():
    """Multi-level B-spline registration of two 4DCT POPI images.

    Loads two images (and landmark points, if available) from ``/tmp/``.
    It then removes the bed, normalises intensities with a common
    min/max, and crops both images to a joint domain. Registration uses a
    4-level pyramid (size/8, size/4, size/2, size/1): each level optimises a
    B-spline transformation and is seeded with the upsampled flow of the
    previous level.

    Per-level and final displacement fields, the warped image, and the
    cropped images and masks are written as ``.vtk`` files to the same
    directory. When landmarks are present, the target registration error
    (TRE) is printed.
    """
    start = time.time()

    # set the used data type
    dtype = th.float32
    # set the device for the computation to CPU
    device = th.device("cpu")

    # In order to use a GPU uncomment the following line. The number is the device index of the used GPU
    # Here, the GPU with the index 0 is used.
    # device = th.device("cuda:0")

    # directory to load the inputs from and to store the results in
    tmp_directory = "/tmp/"

    loader = al.ImageLoader(tmp_directory)

    # Images:
    p1_name = "4DCT_POPI_1"
    p1_img_nr = "image_00"
    p2_name = "4DCT_POPI_1"
    p2_img_nr = "image_50"

    print("loading images")
    (fixed_image, fixed_points) = loader.load(p1_name, p1_img_nr)
    (moving_image, moving_points) = loader.load(p2_name, p2_img_nr)
    fixed_image.to(dtype, device)
    moving_image.to(dtype, device)

    # landmarks are only evaluated when both point sets could be loaded
    using_landmarks = fixed_points is not None and moving_points is not None

    if using_landmarks:
        initial_tre = al.Points.TRE(fixed_points, moving_points)
        print("initial TRE: "+str(initial_tre))

    print("preprocessing images")
    (fixed_image, _fixed_body_mask) = al.remove_bed_filter(fixed_image)
    (moving_image, _moving_body_mask) = al.remove_bed_filter(moving_image)

    # normalize image intensities using common minimum and common maximum
    fixed_image, moving_image = al.utils.normalize_images(fixed_image, moving_image)

    # center-of-mass pre-alignment is only needed for inter-subject registration
    cm_alignment = p1_name != p2_name

    # crop both images to their joint domain and compute the body masks
    f_image, f_mask, m_image, m_mask, cm_displacement = al.get_joint_domain_images(fixed_image, moving_image,
                                                                                   cm_alignment=cm_alignment,
                                                                                   compute_masks=True)

    # apply the same center-of-mass shift to the moving landmarks
    if cm_displacement is not None and using_landmarks:
        # broadcast the shift over all points; writing into a zeros_like
        # buffer keeps the dtype of moving_points, as the per-row loop did
        moving_points_aligned = np.zeros_like(moving_points)
        moving_points_aligned[:] = moving_points + cm_displacement
        print("aligned TRE: " + str(al.Points.TRE(fixed_points, moving_points_aligned)))
    else:
        moving_points_aligned = moving_points

    # create image pyramid size/8, size/4, size/2, size/1
    pyramid_factors = [[8, 8, 8], [4, 4, 4], [2, 2, 2]]
    fixed_image_pyramid = al.create_image_pyramid(f_image, pyramid_factors)
    fixed_mask_pyramid = al.create_image_pyramid(f_mask, pyramid_factors)
    moving_image_pyramid = al.create_image_pyramid(m_image, pyramid_factors)
    moving_mask_pyramid = al.create_image_pyramid(m_mask, pyramid_factors)

    # per-level settings, coarsest level first
    constant_flow = None
    regularisation_weight = [1e-2, 1e-1, 1e-0, 1e+2]
    number_of_iterations = [300, 200, 100, 50]
    sigma = [[9, 9, 9], [9, 9, 9], [9, 9, 9], [9, 9, 9]]
    step_size = [1e-2, 4e-3, 2e-3, 2e-3]

    print("perform registration")
    for level, (mov_im_level, mov_msk_level, fix_im_level, fix_msk_level) in enumerate(zip(moving_image_pyramid,
                                                                                           moving_mask_pyramid,
                                                                                           fixed_image_pyramid,
                                                                                           fixed_mask_pyramid)):

        print("---- Level "+str(level)+" ----")
        registration = al.PairwiseRegistration()

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(mov_im_level.size,
                                                                          sigma=sigma[level],
                                                                          order=3,
                                                                          dtype=dtype,
                                                                          device=device)

        # seed finer levels with the upsampled flow of the previous level
        if level > 0:
            constant_flow = al.transformation.utils.upsample_displacement(constant_flow,
                                                                          mov_im_level.size,
                                                                          interpolation="linear")

            transformation.set_constant_flow(constant_flow)

        registration.set_transformation(transformation)

        # masked Mean Squared Error as image loss
        image_loss = al.loss.pairwise.MSE(fix_im_level, mov_im_level, mov_msk_level, fix_msk_level)

        registration.set_image_loss([image_loss])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(mov_im_level.spacing)
        regulariser.SetWeight(regularisation_weight[level])

        registration.set_regulariser_displacement([regulariser])

        # define the optimizer
        optimizer = th.optim.Adam(transformation.parameters(), lr=step_size[level], amsgrad=True)

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        # store current flow field for the next level
        constant_flow = transformation.get_flow()

        current_displacement = transformation.get_displacement()
        # generate SimpleITK displacement field at full resolution and calculate TRE
        tmp_displacement = al.transformation.utils.upsample_displacement(current_displacement.clone().to(device='cpu'),
                                                                         m_image.size, interpolation="linear")
        tmp_displacement = al.transformation.utils.unit_displacement_to_dispalcement(tmp_displacement)  # unit measures to image domain measures
        tmp_displacement = al.create_displacement_image_from_image(tmp_displacement, m_image)
        tmp_displacement.write(f"{tmp_directory}bspline_displacement_image_level_{level}.vtk")

        # in order to not invert the displacement field, the fixed points are transformed to match the moving points
        if using_landmarks:
            print("TRE on that level: "+str(al.Points.TRE(moving_points_aligned, al.Points.transform(fixed_points, tmp_displacement))))

    # create final result from the finest level
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(m_image, displacement)
    displacement = al.transformation.utils.unit_displacement_to_dispalcement(displacement)  # unit measures to image domain measures
    displacement = al.create_displacement_image_from_image(displacement, m_image)

    end = time.time()

    # in order to not invert the displacement field, the fixed points are transformed to match the moving points
    if using_landmarks:
        print("Initial TRE: "+str(initial_tre))
        fixed_points_transformed = al.Points.transform(fixed_points, displacement)
        print("Final TRE: " + str(al.Points.TRE(moving_points_aligned, fixed_points_transformed)))

    # write result images
    print("writing results")
    warped_image.write(f"{tmp_directory}bspline_warped_image.vtk")
    m_image.write(f"{tmp_directory}bspline_moving_image.vtk")
    m_mask.write(f"{tmp_directory}bspline_moving_mask.vtk")
    f_image.write(f"{tmp_directory}bspline_fixed_image.vtk")
    f_mask.write(f"{tmp_directory}bspline_fixed_mask.vtk")
    displacement.write(f"{tmp_directory}bspline_displacement_image.vtk")

    if using_landmarks:
        al.Points.write(f"{tmp_directory}bspline_fixed_points_transformed.vtk", fixed_points_transformed)
        al.Points.write(f"{tmp_directory}bspline_moving_points_aligned.vtk", moving_points_aligned)

    print("=================================================================")
    print("Registration done in: ", end - start, " seconds")
# Example no. 3
# 0
def affine_register(im1,
                    im2,
                    iterations=1000,
                    lr=0.01,
                    transform_type='similarity',
                    gpu_device=0,
                    opt_cm=True,
                    sigma=None,
                    order=2,
                    pyramid=None,
                    loss_fn='mse',
                    use_mask=False,
                    interpolation='bicubic'):
    """Register image ``im2`` onto ``im1`` with airlab and plot the result.

    Both inputs are converted from BGR to grayscale with OpenCV and turned
    into airlab tensor images on the selected device. They are then jointly
    normalised and intensity-inverted (object 1, background 0) for the
    registration. The original intensities are restored before warping and
    plotting.

    Args:
        im1: fixed image, an array accepted by ``cv2.cvtColor(..., COLOR_BGR2GRAY)``.
        im2: moving image, same format as ``im1``.
        iterations: Adam iterations per registration run.
        lr: Adam learning rate.
        transform_type: one of 'similarity', 'affine', 'rigid',
            'non_parametric', 'bspline', 'wendland'.
        gpu_device: CUDA device index; a negative value selects the CPU.
        opt_cm: forwarded as ``opt_cm`` to the parametric transforms and as
            ``diffeomorphic`` to the others.
        sigma: defaults to ``[[11, 11], [11, 11], [3, 3]]``. NOTE: for
            'bspline'/'wendland' the per-level sigma is currently overridden
            with ``(1, 1)``, so this value has no effect.
        order: B-spline order for 'bspline'; ``cp_scale`` for 'wendland'.
        pyramid: down-sampling factors for 'bspline'/'wendland'; defaults to
            ``[[4, 4], [2, 2]]``. NOTE: one registration run is performed per
            pyramid level. Each run starts from a fresh transformation and
            uses the full-resolution images.
        loss_fn: 'mse', 'ncc', 'lcc', 'mi', 'ngf' (also accepted as the legacy
            key 'mgf') or 'ssim'. An unknown name raises ``KeyError``.
        use_mask: must be falsy; masking is not implemented.
        interpolation: currently unused.

    Returns:
        ``(displacement, warped_image, transformation_param, final_loss)``.
        ``transformation_param`` is ``_phi_z`` for the parametric transforms,
        ``trans_parameters`` for 'non_parametric' and ``_kernel`` for
        'bspline'/'wendland'.
    """
    # None defaults avoid sharing mutable lists between calls
    if sigma is None:
        sigma = [[11, 11], [11, 11], [3, 3]]
    if pyramid is None:
        pyramid = [[4, 4], [2, 2]]
    assert not use_mask, "Masking not implemented"
    assert transform_type in [
        'similarity', 'affine', 'rigid', 'non_parametric', 'bspline',
        'wendland'
    ]
    parametric = ('similarity', 'affine', 'rigid')
    start = time.perf_counter()

    # set the used data type
    dtype = th.float32
    # a negative gpu_device selects the CPU
    device_name = ("cuda:{}".format(gpu_device)
                   if gpu_device >= 0 else 'cpu')
    device = th.device(device_name)

    loss_fns = dict(mse=al.loss.pairwise.MSE,
                    ncc=al.loss.pairwise.NCC,
                    lcc=al.loss.pairwise.LCC,
                    mi=al.loss.pairwise.MI,
                    ngf=al.loss.pairwise.NGF,
                    mgf=al.loss.pairwise.NGF,  # legacy (misspelled) key
                    ssim=al.loss.pairwise.SSIM)
    # look the loss up before doing any work so a bad name fails fast
    image_loss_fn = loss_fns[loss_fn]

    # convert both images to grayscale airlab tensor images
    fixed_image = al.utils.image.create_tensor_image_from_itk_image(
        sitk.GetImageFromArray(cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)),
        dtype=dtype,
        device=device)
    moving_image = al.utils.image.create_tensor_image_from_itk_image(
        sitk.GetImageFromArray(cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)),
        dtype=dtype,
        device=device)

    fixed_image, moving_image = al.utils.normalize_images(
        fixed_image, moving_image)

    # convert intensities so that the object intensities are 1 and the background 0. This is important in order to
    # calculate the center of mass of the object
    fixed_image.image = 1 - fixed_image.image
    moving_image.image = 1 - moving_image.image

    # create pairwise registration object (shared by all levels)
    registration = al.PairwiseRegistration()

    transforms = dict(
        similarity=al.transformation.pairwise.SimilarityTransformation,
        affine=al.transformation.pairwise.AffineTransformation,
        rigid=al.transformation.pairwise.RigidTransformation,
        non_parametric=al.transformation.pairwise.NonParametricTransformation,
        wendland=al.transformation.pairwise.WendlandKernelTransformation,
        bspline=al.transformation.pairwise.BsplineTransformation)

    if transform_type in parametric:
        transform_opts = dict(opt_cm=opt_cm)
        transform_args = [moving_image]
        # a single dummy level
        sigma, fixed_image_pyramid, moving_image_pyramid = [[]], [[]], [[]]
    else:
        transform_opts = dict(diffeomorphic=opt_cm, device=device_name)
        transform_args = [moving_image.size]
        if transform_type in ('bspline', 'wendland'):
            transform_opts['sigma'] = sigma
            fixed_image_pyramid = al.create_image_pyramid(fixed_image, pyramid)
            moving_image_pyramid = al.create_image_pyramid(
                moving_image, pyramid)
        else:
            # non_parametric: a single level holding the full-resolution images
            sigma = [[]]
            fixed_image_pyramid = [[fixed_image]]
            moving_image_pyramid = [[moving_image]]
        if transform_type == 'bspline':
            transform_opts['order'] = order
        if transform_type == 'wendland':
            transform_opts['cp_scale'] = order

    for level, (mov_im_level, fix_im_level) in enumerate(
            zip(moving_image_pyramid, fixed_image_pyramid)):

        if transform_type == 'non_parametric':
            transform_args[0] = mov_im_level[level].size
        elif transform_type in ('bspline', 'wendland'):
            # sigma must be a positive tuple of ints; a smaller sigma loses
            # fewer microarchitectural details. NOTE: this overrides the
            # ``sigma`` argument.
            transform_opts['sigma'] = (1, 1)

        transformation = transforms[transform_type](*transform_args,
                                                    **transform_opts)

        if transform_type in parametric:
            # initialize the translation with the center of mass of the fixed image
            transformation.init_translation(fixed_image)

        registration.set_transformation(transformation)

        # NOTE: the loss uses the full-resolution images, not the pyramid
        # level images, and the previous level's flow is not carried over.
        image_loss = image_loss_fn(fixed_image, moving_image)
        registration.set_image_loss([image_loss])

        # choose the Adam optimizer to minimize the objective
        optimizer = th.optim.Adam(transformation.parameters(),
                                  lr=lr,
                                  amsgrad=True)

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(iterations)

        # start the registration
        registration.start()

    # set the intensities back to the original for the visualisation
    fixed_image.image = 1 - fixed_image.image
    moving_image.image = 1 - moving_image.image

    # warp the moving image with the final transformation result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)

    # TODO add saved displacement, registration objects

    # TODO log final registration loss

    end = time.perf_counter()

    print("=================================================================")

    print("Registration done in:", end - start, "s")
    if transform_type in parametric:
        print("Result parameters:")
        transformation.print()

    # plot the results
    plt.subplot(131)
    plt.imshow(fixed_image.numpy(), cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(132)
    plt.imshow(moving_image.numpy(), cmap='gray')
    plt.title('Moving Image')

    plt.subplot(133)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Moving Image')

    # explicit fallback so the name is always bound
    transformation_param = None
    if transform_type in parametric:
        transformation_param = transformation._phi_z
    elif transform_type == 'non_parametric':
        transformation_param = transformation.trans_parameters
    elif transform_type in ('bspline', 'wendland'):
        transformation_param = transformation._kernel
    return (displacement, warped_image, transformation_param,
            registration.loss.data.item())
# Example no. 4
# 0
def affine_register(im1,
                    im2,
                    iterations=1000,
                    lr=0.01,
                    transform_type='similarity',
                    gpu_device=0):
    """Register image ``im2`` onto ``im1`` with a parametric airlab transform.

    Both inputs are converted from BGR to grayscale with OpenCV, jointly
    normalised and intensity-inverted (object 1, background 0) so that the
    center-of-mass initialisation works. Registration uses an MSE loss and
    Adam with amsgrad. The result is printed and plotted.

    NOTE(review): this file defines ``affine_register`` more than once; at
    import time the last definition wins.

    Args:
        im1: fixed image, an array accepted by ``cv2.cvtColor(..., COLOR_BGR2GRAY)``.
        im2: moving image, same format as ``im1``.
        iterations: number of Adam iterations.
        lr: Adam learning rate.
        transform_type: 'similarity', 'affine' or 'rigid'.
        gpu_device: CUDA device index; a negative value selects the CPU.

    Returns:
        ``(displacement, warped_image, transformation._phi_z, final_loss)``.
    """
    assert transform_type in ['similarity', 'affine', 'rigid']
    start = time.time()

    # set the used data type
    dtype = th.float32
    # a negative gpu_device selects the CPU instead of an invalid "cuda:-N"
    device = th.device(
        "cuda:{}".format(gpu_device) if gpu_device >= 0 else 'cpu')

    # convert both images to grayscale airlab tensor images
    fixed_image = al.utils.image.create_tensor_image_from_itk_image(
        sitk.GetImageFromArray(cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)),
        dtype=dtype,
        device=device)
    moving_image = al.utils.image.create_tensor_image_from_itk_image(
        sitk.GetImageFromArray(cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)),
        dtype=dtype,
        device=device)

    fixed_image, moving_image = al.utils.normalize_images(
        fixed_image, moving_image)

    # convert intensities so that the object intensities are 1 and the background 0. This is important in order to
    # calculate the center of mass of the object
    fixed_image.image = 1 - fixed_image.image
    moving_image.image = 1 - moving_image.image

    # create pairwise registration object
    registration = al.PairwiseRegistration()

    transforms = dict(
        similarity=al.transformation.pairwise.SimilarityTransformation,
        affine=al.transformation.pairwise.AffineTransformation,
        rigid=al.transformation.pairwise.RigidTransformation)

    # choose the transformation model
    transformation = transforms[transform_type](moving_image, opt_cm=True)
    # initialize the translation with the center of mass of the fixed image
    transformation.init_translation(fixed_image)

    registration.set_transformation(transformation)

    # choose the Mean Squared Error as image loss
    image_loss = al.loss.pairwise.MSE(fixed_image, moving_image)

    registration.set_image_loss([image_loss])

    # choose the Adam optimizer to minimize the objective
    optimizer = th.optim.Adam(transformation.parameters(), lr=lr, amsgrad=True)

    registration.set_optimizer(optimizer)
    registration.set_number_of_iterations(iterations)

    # start the registration
    registration.start()

    # set the intensities back to the original for the visualisation
    fixed_image.image = 1 - fixed_image.image
    moving_image.image = 1 - moving_image.image

    # warp the moving image with the final transformation result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)

    end = time.time()

    print("=================================================================")

    print("Registration done in:", end - start, "s")
    print("Result parameters:")
    transformation.print()

    # plot the results
    plt.subplot(131)
    plt.imshow(fixed_image.numpy(), cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(132)
    plt.imshow(moving_image.numpy(), cmap='gray')
    plt.title('Moving Image')

    plt.subplot(133)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Moving Image')
    return (displacement, warped_image, transformation._phi_z,
            registration.loss.data.item())
def affine_reg(img_draw, img_ref, output_path, lr=0.01, iter=200):
    """Rigidly register ``img_draw`` onto ``img_ref`` and save an overlay plot.

    Both images are read from disk as float32 and rescaled to [0, 1].
    Registration runs on the CPU with an NCC image loss and ``iter`` Adam
    steps. A two-panel figure ("Raw" before and "Transformed" after
    registration) is saved to ``output_path``.

    Args:
        img_draw: path of the moving image.
        img_ref: path of the fixed (reference) image.
        output_path: file the comparison figure is written to.
        lr: Adam learning rate.
        iter: number of optimisation iterations. The name shadows the builtin
            but is kept for keyword callers.

    Returns:
        ``(init_loss, final_loss, |param[0]|, translate, scale, warped_image)``.
        ``translate`` is the Euclidean norm of ``(param[1], param[2])``.
        ``scale`` is the RMS deviation of ``param[3]`` and ``param[4]`` from 1.
        NOTE(review): indexing ``param[4]`` assumes the project's
        RigidTransformation exposes at least five parameters (presumably
        angle, tx, ty, sx, sy) — confirm.
    """
    dtype = th.float32
    # computation runs on the CPU; use th.device("cuda:0") for the first GPU
    device = th.device("cpu")

    def _load(path):
        # read as float32 and rescale intensities to [0, 1]
        itk_img = sitk.ReadImage(path, sitk.sitkFloat32)
        itk_img = sitk.RescaleIntensity(itk_img, 0, 1)
        return al.create_tensor_image_from_itk_image(itk_img,
                                                     dtype=dtype,
                                                     device=device)

    fixed_image = _load(img_ref)
    moving_image = _load(img_draw)

    # create pairwise registration object
    registration = al.PairwiseRegistration(dtype=dtype, device=device)

    # rigid transformation model
    transformation = al.transformation.pairwise.RigidTransformation(
        moving_image.size, dtype=dtype, device=device)
    registration.set_transformation(transformation)

    # normalised cross-correlation as image loss
    image_loss = al.loss.pairwise.NCC(fixed_image, moving_image)
    registration.set_image_loss([image_loss])

    # choose the Adam optimizer to minimize the objective
    optimizer = th.optim.Adam(transformation.parameters(), lr=lr)

    registration.set_optimizer(optimizer)
    registration.set_number_of_iterations(iter)

    # start the registration
    registration.start()

    # warp the moving image with the final transformation result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)

    param = transformation.trans_parameters.detach().numpy()
    translate = np.hypot(param[1], param[2])
    scale = np.sqrt((np.square(param[3] - 1) + np.square(param[4] - 1)) * 0.5)
    # NOTE(review): init_loss / img_loss are not set anywhere in this file;
    # presumably a customised PairwiseRegistration records them — confirm.
    init_loss = registration.init_loss.detach().numpy()
    final_loss = registration.img_loss.detach().numpy()

    fixed = fixed_image.numpy()
    moved = moving_image.numpy()
    warp = warped_image.numpy()

    def _overlay(other, threshold):
        # white canvas; reference pixels < 1 in gray, then pixels of `other`
        # below `threshold` in blue, drawn on top
        canvas = np.ones((fixed.shape[0], fixed.shape[1], 3))
        canvas[fixed < 1] = [0.5, 0.5, 0.5]
        canvas[other < threshold] = [0.4, 0.62, 0.78]
        return canvas

    csfont = {'fontname': 'Times New Roman', 'size': 35}
    panels = ((121, _overlay(moved, 1), 'Raw'),
              (122, _overlay(warp, 0.5), 'Transformed'))

    fig = plt.figure(figsize=(13, 6))
    try:
        for position, canvas, title in panels:
            plt.subplot(position)
            plt.xticks([])
            plt.yticks([])
            plt.imshow(canvas)
            plt.title(title, **csfont)
        plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        # release the figure even if saving fails
        plt.close(fig)

    return init_loss, final_loss, np.abs(
        param[0]), translate, scale, warped_image
def main():
    """Register the 2D affine test image pair with a similarity transform.

    Reads ``./data/affine_test_image_2d_fixed.png`` and
    ``./data/affine_test_image_2d_moving.png``, normalises them jointly and
    runs 1000 Adam (amsgrad, lr=0.01) iterations on an MSE loss. It then
    prints the timing and resulting parameters and shows the fixed, moving
    and warped images side by side.
    """
    t_begin = time.time()

    float_type = th.float32
    # computation runs on the CPU; swap in th.device("cuda:0") to use the first GPU
    compute_device = th.device("cpu")

    fixed_image, moving_image = al.utils.normalize_images(
        al.read_image_as_tensor("./data/affine_test_image_2d_fixed.png",
                                dtype=float_type, device=compute_device),
        al.read_image_as_tensor("./data/affine_test_image_2d_moving.png",
                                dtype=float_type, device=compute_device),
    )

    def _flip(*images):
        # invert intensities in place (object <-> background); the object must
        # be 1 and the background 0 for the center-of-mass computation
        for img in images:
            img.image = 1 - img.image

    _flip(fixed_image, moving_image)

    registration = al.PairwiseRegistration()

    # similarity model, translation seeded from the fixed image's center of mass
    transformation = al.transformation.pairwise.SimilarityTransformation(moving_image, opt_cm=True)
    transformation.init_translation(fixed_image)
    registration.set_transformation(transformation)

    registration.set_image_loss([al.loss.pairwise.MSE(fixed_image, moving_image)])
    registration.set_optimizer(
        th.optim.Adam(transformation.parameters(), lr=0.01, amsgrad=True))
    registration.set_number_of_iterations(1000)
    registration.start()

    # restore the original intensities before visualising
    _flip(fixed_image, moving_image)

    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image, displacement)

    t_end = time.time()

    print("=================================================================")
    print("Registration done in:", t_end - t_begin, "s")
    print("Result parameters:")
    transformation.print()

    panels = (
        (131, fixed_image, 'Fixed Image'),
        (132, moving_image, 'Moving Image'),
        (133, warped_image, 'Warped Moving Image'),
    )
    for position, img, caption in panels:
        plt.subplot(position)
        plt.imshow(img.numpy(), cmap='gray')
        plt.title(caption)

    plt.show()
def main():
    """Diffeomorphic multi-level B-spline registration of the C-to-O test pair.

    Builds synthetic 256-pixel test images with ``create_C_2_O_test_images``
    and registers them over a 3-level pyramid (size/4, size/2, size/1). Each
    level is seeded with the upsampled flow of the previous level. The final
    displacement is used to warp the shaded moving image. Its inverse is
    applied to warp that result back. The images and the displacement
    magnitudes are plotted.
    """
    start = time.time()

    # set the used data type
    dtype = th.float32
    # set the device for the computation to CPU
    device = th.device("cpu")

    # In order to use a GPU uncomment the following line. The number is the device index of the used GPU
    # Here, the GPU with the index 0 is used.
    # device = th.device("cuda:0")

    # create test image data
    fixed_image, moving_image, shaded_image = create_C_2_O_test_images(
        256, dtype=dtype, device=device)

    # create image pyramid size/4, size/2, size/1
    pyramid_factors = [[4, 4], [2, 2]]
    fixed_image_pyramid = al.create_image_pyramid(fixed_image, pyramid_factors)
    moving_image_pyramid = al.create_image_pyramid(moving_image,
                                                   pyramid_factors)

    # flow of the previous level; None until the coarsest level has run
    constant_flow = None
    # per-level settings, coarsest level first
    regularisation_weight = [1, 5, 50]
    number_of_iterations = [500, 500, 500]
    sigma = [[11, 11], [11, 11], [3, 3]]

    for level, (mov_im_level, fix_im_level) in enumerate(
            zip(moving_image_pyramid, fixed_image_pyramid)):

        registration = al.PairwiseRegistration(verbose=True)

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(
            mov_im_level.size,
            sigma=sigma[level],
            order=3,
            dtype=dtype,
            device=device,
            diffeomorphic=True)

        # seed finer levels with the upsampled flow of the previous level
        if level > 0:
            constant_flow = al.transformation.utils.upsample_displacement(
                constant_flow, mov_im_level.size, interpolation="linear")
            transformation.set_constant_flow(constant_flow)

        registration.set_transformation(transformation)

        # choose the Mean Squared Error as image loss
        image_loss = al.loss.pairwise.MSE(fix_im_level, mov_im_level)

        registration.set_image_loss([image_loss])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(
            mov_im_level.spacing)
        regulariser.SetWeight(regularisation_weight[level])

        registration.set_regulariser_displacement([regulariser])

        # define the optimizer (Adam with its default learning rate)
        optimizer = th.optim.Adam(transformation.parameters())

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        constant_flow = transformation.get_flow()

    # create final result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(shaded_image,
                                                      displacement)
    displacement = al.create_displacement_image_from_image(
        displacement, moving_image)

    # create inverse displacement field
    inverse_displacement = transformation.get_inverse_displacement()
    inverse_warped_image = al.transformation.utils.warp_image(
        warped_image, inverse_displacement)
    inverse_displacement = al.create_displacement_image_from_image(
        inverse_displacement, moving_image)

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start)
    print("Result parameters:")

    # plot the results
    plt.subplot(241)
    plt.imshow(fixed_image.numpy(), cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(242)
    plt.imshow(moving_image.numpy(), cmap='gray')
    plt.title('Moving Image')

    plt.subplot(243)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Shaded Moving Image')

    plt.subplot(244)
    plt.imshow(displacement.magnitude().numpy(), cmap='jet')
    plt.title('Magnitude Displacement')

    # second row: shaded image round trip and inverse displacement
    plt.subplot(245)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Shaded Moving Image')

    plt.subplot(246)
    plt.imshow(shaded_image.numpy(), cmap='gray')
    plt.title('Shaded Moving Image')

    plt.subplot(247)
    plt.imshow(inverse_warped_image.numpy(), cmap='gray')
    plt.title('Inverse Warped Shaded Moving Image')

    plt.subplot(248)
    plt.imshow(inverse_displacement.magnitude().numpy(), cmap='jet')
    plt.title('Magnitude Inverse Displacement')

    plt.show()
# Example no. 8
# 0
def main(ref_scan='46', new_scan='47', z_min=800, z_max=900, method="ncc"):
    """Rigidly register two CT scan sub-volumes and save the displacement field.

    The moving scan ``new_scan`` is registered onto the fixed scan
    ``ref_scan`` with a rigid transformation optimised by Adam (lr=0.1,
    500 iterations). Only the z range ``(z_min, z_max)`` and the in-plane
    window ``(125, 825)`` of each scan are loaded via
    ``read_image_array_to_tensor``, because the full volumes do not fit into
    the memory of a single GPU. The displacement field (converted to image
    domain measures) is written to
    ``displacement_scan<ref>_<new>_Z<z_min>_<z_max>_M<method>.vtk`` in the
    current working directory.

    Args:
        ref_scan: identifier of the fixed (reference) scan.
        new_scan: identifier of the moving scan.
        z_min: lower z bound of the loaded sub-volume.
        z_max: upper z bound of the loaded sub-volume.
        method: image similarity, ``"mse"`` or ``"ncc"``.

    Raises:
        ValueError: if ``method`` is neither ``"mse"`` nor ``"ncc"``.
    """
    start = time.time()

    # Map the supported similarity names to airlab loss classes and validate
    # the choice *before* loading the (large) volumes, so a typo fails fast.
    loss_classes = {
        'mse': al.loss.pairwise.MSE,
        'ncc': al.loss.pairwise.NCC,
    }
    if method not in loss_classes:
        raise ValueError("unknown method")

    # Compute on the CPU unless GPUtil reports a GPU below 50 % load and
    # memory use; in that case the first such GPU is used.
    device = th.device("cpu")
    deviceIDs = GPUtil.getAvailable(order='first',
                                    limit=100,
                                    maxLoad=0.5,
                                    maxMemory=0.5,
                                    includeNan=False,
                                    excludeID=[],
                                    excludeUUID=[])
    if len(deviceIDs) > 0:
        basic.outputlogMessage("using the GPU (ID:%d) for computing" %
                               deviceIDs[0])
        device = th.device("cuda:%d" % deviceIDs[0])

    # Due to the memory limit of a single GPU only a z slab of each scan is
    # loaded; keep z_max - z_min small.
    fixed_image = read_image_array_to_tensor(ref_scan, [(z_min, z_max),
                                                        (125, 825),
                                                        (125, 825)], device)
    moving_image = read_image_array_to_tensor(new_scan, [(z_min, z_max),
                                                         (125, 825),
                                                         (125, 825)], device)

    # create pairwise registration object
    registration = al.PairwiseRegistration()

    # rigid transformation model, translation initialised from the fixed image
    transformation = al.transformation.pairwise.RigidTransformation(
        moving_image, opt_cm=True)
    transformation.init_translation(fixed_image)
    registration.set_transformation(transformation)

    image_loss = loss_classes[method](fixed_image, moving_image)
    registration.set_image_loss([image_loss])

    # choose the Adam optimizer to minimize the objective
    optimizer = th.optim.Adam(transformation.parameters(), lr=0.1)
    registration.set_optimizer(optimizer)
    registration.set_number_of_iterations(500)

    # start the registration
    registration.start()

    displacement = transformation.get_displacement()

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start, " s")
    print("Result parameters:")
    transformation.print()

    # unit measures to image domain measures, then wrap as a displacement
    # image on the moving image's grid so it can be exported via ITK
    displacement = al.transformation.utils.unit_displacement_to_displacement(
        displacement)
    displacement = al.create_displacement_image_from_image(
        displacement, moving_image)
    save_disp = "displacement_scan%s_%s_Z%d_%d_M%s" % (ref_scan, new_scan,
                                                       z_min, z_max, method)
    sitk.WriteImage(displacement.itk(), save_disp + '.vtk')
Esempio n. 9
0
def main():
    """Rigidly register two 2D test PNG images and plot the outcome.

    Both images are read from ``./data`` as float32 and rescaled to [0, 1].
    A rigid transformation is fitted with an MSE loss and Adam (lr=0.01,
    100 iterations). The fixed, moving and warped moving images are then
    shown side by side.
    """
    start = time.time()

    # Data type and device for every tensor. Replace the device with
    # th.device("cuda:0") to compute on the first GPU instead of the CPU.
    dtype = th.float32
    device = th.device("cpu")

    def load_unit_range(path):
        # read as float32, rescale intensities to [0, 1], wrap as airlab image
        itk_image = sitk.RescaleIntensity(
            sitk.ReadImage(path, sitk.sitkFloat32), 0, 1)
        return al.create_tensor_image_from_itk_image(itk_image,
                                                     dtype=dtype,
                                                     device=device)

    fixed_image = load_unit_range("./data/affine_test_image_2d_fixed.png")
    moving_image = load_unit_range("./data/affine_test_image_2d_moving.png")

    registration = al.PairwiseRegistration(dtype=dtype, device=device)

    # rigid model, optimised against the mean squared error with Adam
    transformation = al.transformation.pairwise.RigidTransformation(
        moving_image.size, dtype=dtype, device=device)
    registration.set_transformation(transformation)
    registration.set_image_loss(
        [al.loss.pairwise.MSE(fixed_image, moving_image)])
    registration.set_optimizer(
        th.optim.Adam(transformation.parameters(), lr=0.01))
    registration.set_number_of_iterations(100)

    registration.start()

    # resample the moving image with the fitted transformation
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)

    end = time.time()

    print("=================================================================")
    print("Registration done in: ", end - start)
    print("Result parameters:")
    transformation.print()

    panels = (
        (131, fixed_image, 'Fixed Image'),
        (132, moving_image, 'Moving Image'),
        (133, warped_image, 'Warped Moving Image'),
    )
    for position, image, title in panels:
        plt.subplot(position)
        plt.imshow(image.numpy(), cmap='gray')
        plt.title(title)

    plt.show()
Esempio n. 10
0
def main():
    """Rigidly register two synthetic 3D cube volumes, export and plot them.

    The fixed volume is a 64^3 block of zeros with a 16^3 cube of ones at
    [16:32] on every axis. The moving volume holds the same cube shifted by
    ``object_shift`` voxels towards index 0. A rigid transformation is fitted
    with MSE and Adam (lr=0.1, 500 iterations). The intensity-inverted
    images and the displacement field are written as VTK files to the
    working directory, and slice 16 of each volume is plotted.
    """
    start = time.time()

    # Run on the CPU unless GPUtil finds a GPU below 50 % load and memory
    # use; the first such GPU is then used.
    device = th.device("cpu")
    available_gpus = GPUtil.getAvailable(order='first',
                                         limit=100,
                                         maxLoad=0.5,
                                         maxMemory=0.5,
                                         includeNan=False,
                                         excludeID=[],
                                         excludeUUID=[])
    if available_gpus:
        gpu_id = available_gpus[0]
        print("using the GPU (ID:%d) for computing" % gpu_id)
        device = th.device("cuda:%d" % gpu_id)

    object_shift = 5

    def cube_volume(shift):
        # 64^3 zero volume holding a 16^3 block of ones starting at 16 - shift
        data = th.zeros(64, 64, 64).to(device=device)
        first, last = 16 - shift, 32 - shift
        data[first:last, first:last, first:last] = 1.0
        return al.Image(data, [64, 64, 64], [1, 1, 1], [0, 0, 0])

    fixed_image = cube_volume(0)
    moving_image = cube_volume(object_shift)

    registration = al.PairwiseRegistration()

    # rigid model whose translation is initialised from the fixed image
    transformation = al.transformation.pairwise.RigidTransformation(
        moving_image, opt_cm=True)
    transformation.init_translation(fixed_image)
    registration.set_transformation(transformation)

    # mean squared error similarity, minimised with Adam
    registration.set_image_loss(
        [al.loss.pairwise.MSE(fixed_image, moving_image)])
    registration.set_optimizer(
        th.optim.Adam(transformation.parameters(), lr=0.1))
    registration.set_number_of_iterations(500)

    registration.start()

    # invert intensities (1 - value) for visualisation and export
    for image in (fixed_image, moving_image):
        image.image = 1 - image.image

    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)

    end = time.time()

    print("=================================================================")
    print("Registration done in: ", end - start, " s")
    print("Result parameters:")
    transformation.print()

    for image, file_name in ((warped_image, 'rigid_warped_image.vtk'),
                             (moving_image, 'rigid_moving_image.vtk'),
                             (fixed_image, 'rigid_fixed_image.vtk')):
        sitk.WriteImage(image.itk(), file_name)

    # unit measures -> image domain measures, on the moving image's grid
    displacement_image = al.create_displacement_image_from_image(
        al.transformation.utils.unit_displacement_to_displacement(
            displacement),
        moving_image)
    sitk.WriteImage(displacement_image.itk(), 'displacement.vtk')

    panels = ((131, fixed_image, 'Fixed Image Slice'),
              (132, moving_image, 'Moving Image Slice'),
              (133, warped_image, 'Warped Moving Image Slice'))
    for position, image, title in panels:
        plt.subplot(position)
        plt.imshow(image.numpy()[16, :, :], cmap='gray')
        plt.title(title)

    plt.show()
Esempio n. 11
0
def main():
    """Register two 2D test PNG images with a multi-resolution B-spline model.

    The images are read from ``./data`` as float32 and rescaled to [0, 1].
    They are registered coarse-to-fine on an image pyramid (size/4, size/2,
    size/1). Each level fits a cubic B-spline transformation with an MSE
    loss and a diffusion regulariser on the displacement. The displacement
    found at one level is upsampled and set as the constant displacement of
    the next level. Finally the fixed, moving and warped images and the
    displacement magnitude are plotted.
    """
    start = time.time()

    # Data type and device for all tensors. Replace the device with
    # th.device("cuda:0") to compute on the first GPU instead of the CPU.
    dtype = th.float32
    device = th.device("cpu")

    def load_unit_range(path):
        # read as float32, rescale intensities to [0, 1], wrap as airlab image
        itk_image = sitk.RescaleIntensity(
            sitk.ReadImage(path, sitk.sitkFloat32), 0, 1)
        return al.create_tensor_image_from_itk_image(itk_image,
                                                     dtype=dtype,
                                                     device=device)

    fixed_image = load_unit_range("./data/affine_test_image_2d_fixed.png")
    moving_image = load_unit_range("./data/affine_test_image_2d_moving.png")

    # image pyramids with levels size/4, size/2, size/1
    fixed_image_pyramide = al.create_image_pyramide(fixed_image,
                                                    [[4, 4], [2, 2]])
    moving_image_pyramide = al.create_image_pyramide(moving_image,
                                                     [[4, 4], [2, 2]])

    # per-level settings, coarse to fine
    regularisation_weight = [1, 5, 50]
    number_of_iterations = [500, 500, 500]
    sigma = [[11, 11], [11, 11], [3, 3]]

    constant_displacement = None
    level_pairs = zip(moving_image_pyramide, fixed_image_pyramide)
    for level, (mov_im_level, fix_im_level) in enumerate(level_pairs):
        registration = al.PairwiseRegistration(dtype=dtype, device=device)

        transformation = al.transformation.pairwise.BsplineTransformation(
            mov_im_level.size,
            sigma=sigma[level],
            order=3,
            dtype=dtype,
            device=device)

        if level:
            # reuse the coarser level's result, resampled to this level's size
            constant_displacement = \
                al.transformation.utils.upsample_displacement(
                    constant_displacement,
                    mov_im_level.size,
                    interpolation="linear")
            transformation.set_constant_displacement(constant_displacement)

        registration.set_transformation(transformation)
        registration.set_image_loss(
            [al.loss.pairwise.MSE(fix_im_level, mov_im_level)])

        # diffusion regulariser on the displacement, weighted per level
        regulariser = al.regulariser.displacement.DiffusionRegulariser(
            mov_im_level.spacing)
        regulariser.SetWeight(regularisation_weight[level])
        registration.set_regulariser_displacement([regulariser])

        registration.set_optimizer(th.optim.Adam(transformation.parameters()))
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        constant_displacement = transformation.get_displacement()

    # final result from the finest level's transformation
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)
    displacement = al.create_displacement_image_from_image(displacement,
                                                           moving_image)

    end = time.time()

    print("=================================================================")
    print("Registration done in: ", end - start)
    print("Result parameters:")

    def panel(position, image, cmap, title):
        plt.subplot(position)
        plt.imshow(image.numpy(), cmap=cmap)
        plt.title(title)

    panel(221, fixed_image, 'gray', 'Fixed Image')
    panel(222, moving_image, 'gray', 'Moving Image')
    panel(223, warped_image, 'gray', 'Warped Moving Image')
    panel(224, displacement.magnitude(), 'jet', 'Magnitude Displacement')

    plt.show()
Esempio n. 12
0
def register_images(fixed_image, moving_image, shaded_image=None, *,
                    regularisation_weight=None, number_of_iterations=None,
                    sigma=None):
    """Non-rigidly register ``moving_image`` onto ``fixed_image``.

    Runs a coarse-to-fine B-spline registration on an image pyramid with the
    levels size/4, size/2, size/1. The pyramid factors are two-component, so
    2D images are assumed. Each level minimises an MSE image loss plus a
    diffusion regulariser on the displacement with Adam. The flow found at
    one level is upsampled and used as the constant flow of the next, finer
    level.

    Args:
        fixed_image: fixed (reference) airlab image. Its ``dtype`` and
            ``device`` are used for the B-spline transformations.
        moving_image: airlab image registered onto ``fixed_image``.
        shaded_image: optional image warped with the final displacement
            instead of ``moving_image``. Presumably it lives on the same grid
            as ``moving_image`` (TODO confirm).
        regularisation_weight: per-level diffusion regulariser weight, coarse
            to fine. Defaults to ``[1, 5, 50]``.
        number_of_iterations: per-level optimiser iterations, coarse to fine.
            Defaults to ``[10, 10, 10]``.
        sigma: per-level ``sigma`` passed to ``BsplineTransformation``
            (presumably the control-point spacing; TODO confirm), coarse to
            fine. Defaults to ``[[11, 11], [11, 11], [3, 3]]``.

    Returns:
        ``(warped_image, displacement_image)``: the warped shaded (or moving)
        image and the final displacement as an airlab displacement image.

    Raises:
        ValueError: if the pyramid is empty or a per-level setting has fewer
            entries than the pyramid has levels.
    """
    start = time.time()

    if regularisation_weight is None:
        regularisation_weight = [1, 5, 50]
    if number_of_iterations is None:
        number_of_iterations = [10, 10, 10]
    if sigma is None:
        sigma = [[11, 11], [11, 11], [3, 3]]

    # create image pyramid size/4, size/2, size/1
    fixed_image_pyramid = al.create_image_pyramid(fixed_image,
                                                  [[4, 4], [2, 2]])
    moving_image_pyramid = al.create_image_pyramid(moving_image,
                                                   [[4, 4], [2, 2]])
    levels = list(zip(moving_image_pyramid, fixed_image_pyramid))
    if not levels:
        raise ValueError("image pyramid is empty")

    # Validate up front: a too-short list would otherwise only fail with an
    # IndexError after the coarser levels have already been optimised.
    for setting_name, values in (
            ("regularisation_weight", regularisation_weight),
            ("number_of_iterations", number_of_iterations),
            ("sigma", sigma)):
        if len(values) < len(levels):
            raise ValueError(
                "%s needs one entry per pyramid level (%d levels), got %d"
                % (setting_name, len(levels), len(values)))

    constant_flow = None
    for level, (mov_im_level, fix_im_level) in enumerate(levels):

        registration = al.PairwiseRegistration(verbose=True)

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(
            mov_im_level.size,
            sigma=sigma[level],
            order=3,
            dtype=fixed_image.dtype,
            device=fixed_image.device)

        if level > 0:
            # continue from the coarser level's flow, resampled to this size
            constant_flow = al.transformation.utils.upsample_displacement(
                constant_flow, mov_im_level.size, interpolation="linear")
            transformation.set_constant_flow(constant_flow)

        registration.set_transformation(transformation)

        # choose the Mean Squared Error as image loss
        image_loss = al.loss.pairwise.MSE(fix_im_level, mov_im_level)
        registration.set_image_loss([image_loss])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(
            mov_im_level.spacing)
        regulariser.SetWeight(regularisation_weight[level])
        registration.set_regulariser_displacement([regulariser])

        # define the optimizer
        optimizer = th.optim.Adam(transformation.parameters())
        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        constant_flow = transformation.get_flow()

    # create final result; warp the shaded image when one is given
    displacement = transformation.get_displacement()
    source_image = moving_image if shaded_image is None else shaded_image
    warped_image = al.transformation.utils.warp_image(source_image,
                                                      displacement)
    displacement_image = al.create_displacement_image_from_image(
        displacement, moving_image)

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start)
    print("Result parameters:")
    return warped_image, displacement_image
    def three_dim_affine_reg(self):
        """Register the preprocessed moving volume onto the stationary volume.

        Builds airlab images from the ``self.preprocessed_*`` tensors on
        ``self.device``. The model is an affine transformation when
        ``self.reg_type == "affine"``; otherwise it is a diffeomorphic cubic
        B-spline transformation using ``self.sigma``. The loss selected by
        ``self.loss_fnc`` ("MSE", "MI" or "CC"; anything else falls back to
        MSE) is minimised with Adam (lr=0.1) for ``const.ITERATIONS``
        iterations. Afterwards the same slice of the fixed, moving and warped
        volumes is plotted; ``plt.show()`` opens a window.

        Side effects:
            Sets ``self.affine_transformation_matrix`` (``None`` for the
            B-spline model), ``self.affine_transformation_object`` and
            ``self.displacement``.

        Returns:
            ``(warped_image, transformation, displacement)``. The image
            intensities are inverted (``1 - image``) before warping for
            visualisation, so ``warped_image`` is the warp of the inverted
            moving image.
        """
        start = time.time()

        # data type and device used for the computation
        dtype = torch.float32
        device = self.device

        # Creating the airlab image objects for registration
        new_stationary_img_tnsr = self.preprocessed_stationary_img_tnsr.to(
            device=device)
        new_moving_img_tnsr = self.preprocessed_moving_img_tnsr.to(
            device=device)
        fixed_image = al.Image(new_stationary_img_tnsr, self.img_shape,
                               self.preprocessed_stationary_img_voxel_dim,
                               self.preprocessed_stationary_img_centre)
        moving_image = al.Image(new_moving_img_tnsr, self.img_shape,
                                self.preprocessed_moving_img_voxel_dim,
                                self.preprocessed_moving_img_centre)

        # printing image properties
        for label, image in (("fixed", fixed_image), ("moving", moving_image)):
            print(" ============= %s image size, spacing, origin and datatype"
                  " ===================" % label)
            print(image.size)
            print(image.spacing)
            print(image.origin)
            print(image.dtype)
        print(" ============= ============== ===================")

        # create pairwise registration object
        registration = al.PairwiseRegistration()

        is_affine = self.reg_type == "affine"
        if is_affine:
            print("Using Affine transformation")
            print(" ============= ============== ===================")
            transformation = al.transformation.pairwise.AffineTransformation(
                moving_image, opt_cm=True)
            # In the airlab API (TODO confirm for the installed version),
            # init_translation, print() and transformation_matrix exist only
            # on the rigid-family models (Rigid/Similarity/Affine), not on
            # BsplineTransformation, so they are used for this branch only.
            transformation.init_translation(fixed_image)
        else:
            print("Using B-spline transformation")
            print(" ============= ============== ===================")
            # Same dtype/device as the images; a hard-coded 'cpu' here breaks
            # the registration whenever self.device is a GPU.
            transformation = al.transformation.pairwise.BsplineTransformation(
                image_size=moving_image.size,
                sigma=self.sigma,
                diffeomorphic=True,
                order=3,
                dtype=dtype,
                device=device)
        registration.set_transformation(transformation)

        # choose the image loss
        if self.loss_fnc == "MSE":
            print("Using Mean squared error loss")
            image_loss = al.loss.pairwise.MSE(fixed_image, moving_image)
        elif self.loss_fnc == "MI":
            print("Using Mutual information loss")
            image_loss = al.loss.pairwise.MI(fixed_image,
                                             moving_image,
                                             bins=20,
                                             sigma=3)
        elif self.loss_fnc == "CC":
            print("Using Cross corelation loss")
            image_loss = al.loss.pairwise.NCC(fixed_image, moving_image)
        else:
            print(
                "No valid option chosen among MSE/MI/CC, using MSE as default"
            )
            image_loss = al.loss.pairwise.MSE(fixed_image, moving_image)

        registration.set_image_loss([image_loss])

        # choose the Adam optimizer to minimize the objective
        optimizer = torch.optim.Adam(transformation.parameters(), lr=0.1)

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(const.ITERATIONS)

        # start the registration
        registration.start()

        # set the intensities for the visualisation
        fixed_image.image = 1 - fixed_image.image
        moving_image.image = 1 - moving_image.image

        # warp the moving image with the final transformation result
        displacement = transformation.get_displacement()
        warped_image = al.transformation.utils.warp_image(
            moving_image, displacement)

        end = time.time()

        print(" ============= ============== ===================")

        print("Registration done in: ", end - start, " s")
        print("Result parameters:")
        transformation_matrix = None
        if is_affine:
            transformation_matrix = transformation.transformation_matrix
            transformation.print()
            print(" ============= ============== ===================")
            print(transformation_matrix)
            print(" ============= ============== ===================")

        # Plot the same slice of all three volumes so they are comparable;
        # clamp the index so smaller volumes do not raise an IndexError.
        fixed_np = fixed_image.numpy()
        moving_np = moving_image.numpy()
        warped_np = warped_image.numpy()
        slice_index = min(90, fixed_np.shape[0] - 1, moving_np.shape[0] - 1,
                          warped_np.shape[0] - 1)
        panels = ((131, fixed_np, 'Fixed Image Slice'),
                  (132, moving_np, 'Moving Image Slice'),
                  (133, warped_np, 'Warped Moving Image Slice'))
        for position, volume, title in panels:
            plt.subplot(position)
            plt.imshow(volume[slice_index, :, :], cmap='gray')
            plt.title(title)
        plt.show()

        self.affine_transformation_matrix = transformation_matrix
        self.affine_transformation_object = transformation
        self.displacement = displacement

        return warped_image, transformation, displacement
Esempio n. 14
0
def main():
    """Rigidly register two synthetic 3D cube volumes on the CPU and plot them.

    The fixed volume is a 64^3 block of zeros with a 16^3 cube of ones at
    [16:32] on every axis. The moving volume holds the same cube shifted by
    10 voxels towards index 0. A rigid transformation is fitted with MSE and
    Adam (lr=0.1, 500 iterations). Slice 16 of the intensity-inverted fixed,
    moving and warped volumes is then plotted.
    """
    start = time.time()

    # Compute on the CPU; use th.device("cuda:0") instead to run on GPU 0.
    device = th.device("cpu")

    shift = 10

    def cube(offset):
        # 64^3 zero volume holding a 16^3 block of ones starting at 16 - offset
        span = slice(16 - offset, 32 - offset)
        volume = th.zeros(64, 64, 64).to(device=device)
        volume[span, span, span] = 1.0
        return al.Image(volume, [64, 64, 64], [1, 1, 1], [0, 0, 0])

    fixed_image = cube(0)
    moving_image = cube(shift)

    registration = al.PairwiseRegistration()

    # rigid model whose translation is initialised from the fixed image
    transformation = al.transformation.pairwise.RigidTransformation(
        moving_image, opt_cm=True)
    transformation.init_translation(fixed_image)
    registration.set_transformation(transformation)

    # mean squared error similarity, minimised with Adam
    loss = al.loss.pairwise.MSE(fixed_image, moving_image)
    registration.set_image_loss([loss])
    adam = th.optim.Adam(transformation.parameters(), lr=0.1)
    registration.set_optimizer(adam)
    registration.set_number_of_iterations(500)

    registration.start()

    # invert intensities (1 - value) for the visualisation
    fixed_image.image = 1 - fixed_image.image
    moving_image.image = 1 - moving_image.image

    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)

    end = time.time()

    print("=================================================================")
    print("Registration done in: ", end - start, " s")
    print("Result parameters:")
    transformation.print()

    def show_slice(position, image, title):
        plt.subplot(position)
        plt.imshow(image.numpy()[16, :, :], cmap='gray')
        plt.title(title)

    show_slice(131, fixed_image, 'Fixed Image Slice')
    show_slice(132, moving_image, 'Moving Image Slice')
    show_slice(133, warped_image, 'Warped Moving Image Slice')

    plt.show()