Beispiel #1
0
def main():
    """Register two volumes of the "4DCT_POPI_1" data set with a B-spline model.

    "image_00" is used as the fixed image and "image_50" as the moving image.
    Both are loaded through ``al.ImageLoader``, passed through
    ``al.remove_bed_filter``, normalised jointly and cropped to a joint
    domain. The registration then runs level by level over an image pyramid,
    feeding the displacement of each level into the next one.

    Side effects: prints progress, timing and (when landmarks exist for both
    images) TRE values, and writes images, masks, displacement fields and
    landmark files as ``.vtk`` files below ``/tmp/``.
    """
    t_begin = time.time()

    # Single precision on the GPU with device index 0.
    # Use th.device("cpu") instead to run the computation on the CPU.
    dtype = th.float32
    device = th.device("cuda:0")

    # directory handed to the image loader
    tmp_directory = "/tmp/"
    loader = al.ImageLoader(tmp_directory)

    # Images:
    p1_name, p1_img_nr = "4DCT_POPI_1", "image_00"
    p2_name, p2_img_nr = "4DCT_POPI_1", "image_50"

    print("loading images")
    fixed_image, fixed_points = loader.load(p1_name, p1_img_nr)
    moving_image, moving_points = loader.load(p2_name, p2_img_nr)
    fixed_image.to(dtype, device)
    moving_image.to(dtype, device)

    # landmark evaluation needs a point set for both images
    using_landmarks = fixed_points is not None and moving_points is not None

    if using_landmarks:
        initial_tre = al.Points.TRE(fixed_points, moving_points)
        print("initial TRE: " + str(initial_tre))

    print("preprocessing images")
    # the body masks returned alongside the filtered images are not used here
    fixed_image, _ = al.remove_bed_filter(fixed_image)
    moving_image, _ = al.remove_bed_filter(moving_image)

    # normalize image intensities using common minimum and common maximum
    fixed_image, moving_image = al.utils.normalize_images(
        fixed_image, moving_image)

    # center of mass alignment is only requested for inter subject
    # registration, i.e. when the two data set names differ
    cm_alignment = p1_name != p2_name

    # crop both images to their joint domain and compute the masks
    f_image, f_mask, m_image, m_mask, cm_displacement = al.get_joint_domain_images(
        fixed_image,
        moving_image,
        cm_alignment=cm_alignment,
        compute_masks=True)

    # apply the same displacement to the moving landmarks
    if cm_displacement is not None and using_landmarks:
        moving_points_aligned = np.zeros_like(moving_points)
        for row in range(moving_points_aligned.shape[0]):
            moving_points_aligned[row, :] = (moving_points[row, :] +
                                             cm_displacement)
        print("aligned TRE: " +
              str(al.Points.TRE(fixed_points, moving_points_aligned)))
    else:
        moving_points_aligned = moving_points

    # create image pyramid size/8 size/4, size/2, size/1
    # (same order as before: fixed image, fixed mask, moving image, moving mask)
    (fixed_image_pyramid, fixed_mask_pyramid, moving_image_pyramid,
     moving_mask_pyramid) = (al.create_image_pyramid(
         volume, [[8, 8, 8], [4, 4, 4], [2, 2, 2]])
                             for volume in (f_image, f_mask, m_image, m_mask))

    # per-level hyper parameters, indexed by pyramid level
    regularisation_weight = [1e-2, 1e-1, 1e-0, 1e+2]
    number_of_iterations = [300, 200, 100, 50]
    sigma = [[9, 9, 9], [9, 9, 9], [9, 9, 9], [9, 9, 9]]
    step_size = [1e-2, 4e-3, 2e-3, 2e-3]

    tf_utils = al.transformation.utils
    previous_displacement = None

    print("perform registration")
    pyramid_levels = zip(moving_image_pyramid, moving_mask_pyramid,
                         fixed_image_pyramid, fixed_mask_pyramid)
    for level, (moving_lvl, moving_mask_lvl, fixed_lvl,
                fixed_mask_lvl) in enumerate(pyramid_levels):

        print("---- Level " + str(level) + " ----")
        registration = al.PairwiseRegistration()

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(
            moving_lvl.size,
            sigma=sigma[level],
            order=3,
            dtype=dtype,
            device=device)

        # start from the result of the previous level, resampled to this level
        if level > 0:
            previous_displacement = tf_utils.upsample_displacement(
                previous_displacement,
                moving_lvl.size,
                interpolation="linear")
            transformation.set_constant_displacement(previous_displacement)

        registration.set_transformation(transformation)

        # choose the Mean Squared Error as image loss
        # NOTE(review): the images are passed as (fixed, moving) but the masks
        # as (moving, fixed) - confirm this matches MSE's parameter order.
        registration.set_image_loss([
            al.loss.pairwise.MSE(fixed_lvl, moving_lvl, moving_mask_lvl,
                                 fixed_mask_lvl)
        ])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(
            moving_lvl.spacing)
        regulariser.SetWeight(regularisation_weight[level])
        registration.set_regulariser_displacement([regulariser])

        # define the optimizer
        registration.set_optimizer(
            th.optim.Adam(transformation.parameters(),
                          lr=step_size[level],
                          amsgrad=True))
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        # store current displacement field
        previous_displacement = transformation.get_displacement()

        # bring a CPU copy of the level result to full resolution, convert it
        # from unit measures to image domain measures and write it to disk
        level_displacement = tf_utils.upsample_displacement(
            previous_displacement.clone().to(device='cpu'),
            m_image.size,
            interpolation="linear")
        level_displacement = tf_utils.unit_displacement_to_dispalcement(
            level_displacement)
        level_displacement = al.create_displacement_image_from_image(
            level_displacement, m_image)
        level_displacement.write('/tmp/bspline_displacement_image_level_' +
                                 str(level) + '.vtk')

        # in order to not invert the displacement field, the fixed points are
        # transformed to match the moving points
        if using_landmarks:
            level_tre = al.Points.TRE(
                moving_points_aligned,
                al.Points.transform(fixed_points, level_displacement))
            print("TRE on that level: " + str(level_tre))

    # create final result from the transformation of the last level
    displacement = transformation.get_displacement()
    warped_image = tf_utils.warp_image(m_image, displacement)
    # unit measures to image domain measures
    displacement = tf_utils.unit_displacement_to_dispalcement(displacement)
    displacement = al.create_displacement_image_from_image(
        displacement, m_image)

    t_end = time.time()

    # in order to not invert the displacement field, the fixed points are
    # transformed to match the moving points
    if using_landmarks:
        print("Initial TRE: " + str(initial_tre))
        fixed_points_transformed = al.Points.transform(fixed_points,
                                                       displacement)
        print("Final TRE: " + str(
            al.Points.TRE(moving_points_aligned, fixed_points_transformed)))

    # write result images
    print("writing results")
    warped_image.write('/tmp/bspline_warped_image.vtk')
    m_image.write('/tmp/bspline_moving_image.vtk')
    m_mask.write('/tmp/bspline_moving_mask.vtk')
    f_image.write('/tmp/bspline_fixed_image.vtk')
    f_mask.write('/tmp/bspline_fixed_mask.vtk')
    displacement.write('/tmp/bspline_displacement_image.vtk')

    if using_landmarks:
        al.Points.write('/tmp/bspline_fixed_points_transformed.vtk',
                        fixed_points_transformed)
        al.Points.write('/tmp/bspline_moving_points_aligned.vtk',
                        moving_points_aligned)

    print("=================================================================")
    print("Registration done in: ", t_end - t_begin, " seconds")
def main():
    """Run a diffeomorphic multi-level B-spline registration on test images.

    The fixed, moving and shaded images come from
    ``create_C_2_O_test_images(256, ...)``. The moving image is registered to
    the fixed image level by level over an image pyramid; the flow of each
    level is upsampled and handed to the next level as constant flow. The
    final displacement is applied to the shaded image, and the inverse
    displacement is applied to that warped result.

    Side effects: prints the elapsed time and shows a 2x4 matplotlib figure
    with the images and the magnitudes of both displacement fields.
    """
    start = time.time()

    # set the used data type
    dtype = th.float32
    # run the computation on the CPU
    device = th.device("cpu")

    # In order to use a GPU uncomment the following line. The number is the device index of the used GPU
    # Here, the GPU with the index 0 is used.
    # device = th.device("cuda:0")

    # create test image data
    fixed_image, moving_image, shaded_image = create_C_2_O_test_images(
        256, dtype=dtype, device=device)

    # create image pyramid size/4, size/2, size/1
    fixed_image_pyramid = al.create_image_pyramid(fixed_image,
                                                  [[4, 4], [2, 2]])
    moving_image_pyramid = al.create_image_pyramid(moving_image,
                                                   [[4, 4], [2, 2]])

    # Flow carried from one level to the next. It is assigned at the end of
    # every level and only read from level 1 onwards.
    constant_flow = None
    # per-level hyper parameters, indexed by pyramid level
    regularisation_weight = [1, 5, 50]
    number_of_iterations = [500, 500, 500]
    sigma = [[11, 11], [11, 11], [3, 3]]

    for level, (mov_im_level, fix_im_level) in enumerate(
            zip(moving_image_pyramid, fixed_image_pyramid)):

        registration = al.PairwiseRegistration(verbose=True)

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(
            mov_im_level.size,
            sigma=sigma[level],
            order=3,
            dtype=dtype,
            device=device,
            diffeomorphic=True)

        # start from the flow of the previous level, resampled to this level
        if level > 0:
            constant_flow = al.transformation.utils.upsample_displacement(
                constant_flow, mov_im_level.size, interpolation="linear")
            transformation.set_constant_flow(constant_flow)

        registration.set_transformation(transformation)

        # choose the Mean Squared Error as image loss
        image_loss = al.loss.pairwise.MSE(fix_im_level, mov_im_level)

        registration.set_image_loss([image_loss])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(
            mov_im_level.spacing)
        regulariser.SetWeight(regularisation_weight[level])

        registration.set_regulariser_displacement([regulariser])

        # define the optimizer (Adam with its default settings)
        optimizer = th.optim.Adam(transformation.parameters())

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        # keep the flow of this level for the next one
        constant_flow = transformation.get_flow()

    # create final result: the shaded image (not the moving image that was
    # registered) is warped with the displacement of the last level
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(shaded_image,
                                                      displacement)
    displacement = al.create_displacement_image_from_image(
        displacement, moving_image)

    # create inverse displacement field and apply it to the warped image
    inverse_displacement = transformation.get_inverse_displacement()
    inverse_warped_image = al.transformation.utils.warp_image(
        warped_image, inverse_displacement)
    inverse_displacement = al.create_displacement_image_from_image(
        inverse_displacement, moving_image)

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start)
    print("Result parameters:")

    # plot the results, first row
    plt.subplot(241)
    plt.imshow(fixed_image.numpy(), cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(242)
    plt.imshow(moving_image.numpy(), cmap='gray')
    plt.title('Moving Image')

    plt.subplot(243)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Shaded Moving Image')

    plt.subplot(244)
    plt.imshow(displacement.magnitude().numpy(), cmap='jet')
    plt.title('Magnitude Displacement')

    # plot the results, second row (the warped image is shown again here)
    plt.subplot(245)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Shaded Moving Image')

    plt.subplot(246)
    plt.imshow(shaded_image.numpy(), cmap='gray')
    plt.title('Shaded Moving Image')

    plt.subplot(247)
    plt.imshow(inverse_warped_image.numpy(), cmap='gray')
    plt.title('Inverse Warped Shaded Moving Image')

    plt.subplot(248)
    plt.imshow(inverse_displacement.magnitude().numpy(), cmap='jet')
    plt.title('Magnitude Inverse Displacement')

    plt.show()
def register_images(fixed_image, moving_image, shaded_image=None):
    """Register ``moving_image`` to ``fixed_image`` with a multi-level B-spline model.

    The registration runs over an image pyramid; the flow of each level is
    upsampled and passed to the next level as constant flow.

    Args:
        fixed_image: image the registration is driven towards; its ``dtype``
            and ``device`` are used for the transformation.
        moving_image: image that is registered to ``fixed_image``.
        shaded_image: optional image to warp with the final displacement
            instead of ``moving_image``.

    Returns:
        Tuple ``(warped_image, displacement_image)``: the warped shaded (or
        moving) image and the displacement wrapped by
        ``al.create_displacement_image_from_image``.

    Side effects: prints the elapsed time.
    """
    t_begin = time.time()

    # create image pyramid size/4, size/2, size/1
    fixed_levels = al.create_image_pyramid(fixed_image, [[4, 4], [2, 2]])
    moving_levels = al.create_image_pyramid(moving_image, [[4, 4], [2, 2]])

    # Per-level hyper parameters, indexed by pyramid level. An earlier
    # configuration kept in this function used 500 iterations per level.
    regularisation_weight = [1, 5, 50]
    number_of_iterations = [10, 10, 10]
    sigma = [[11, 11], [11, 11], [3, 3]]

    carried_flow = None
    for level, (moving_lvl, fixed_lvl) in enumerate(
            zip(moving_levels, fixed_levels)):

        registration = al.PairwiseRegistration(verbose=True)

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(
            moving_lvl.size,
            sigma=sigma[level],
            order=3,
            dtype=fixed_image.dtype,
            device=fixed_image.device)

        # start from the flow of the previous level, resampled to this level
        if level > 0:
            carried_flow = al.transformation.utils.upsample_displacement(
                carried_flow, moving_lvl.size, interpolation="linear")
            transformation.set_constant_flow(carried_flow)

        registration.set_transformation(transformation)

        # choose the Mean Squared Error as image loss
        registration.set_image_loss(
            [al.loss.pairwise.MSE(fixed_lvl, moving_lvl)])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(
            moving_lvl.spacing)
        regulariser.SetWeight(regularisation_weight[level])
        registration.set_regulariser_displacement([regulariser])

        # define the optimizer (Adam with its default settings)
        registration.set_optimizer(th.optim.Adam(transformation.parameters()))
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        # keep the flow of this level for the next one
        carried_flow = transformation.get_flow()

    # create final result: warp the shaded image when one was given,
    # otherwise the moving image itself
    displacement = transformation.get_displacement()
    image_to_warp = moving_image if shaded_image is None else shaded_image
    warped_image = al.transformation.utils.warp_image(image_to_warp,
                                                      displacement)
    displacement_image = al.create_displacement_image_from_image(
        displacement, moving_image)

    t_end = time.time()

    print("=================================================================")

    print("Registration done in: ", t_end - t_begin)
    print("Result parameters:")
    return warped_image, displacement_image
Beispiel #4
0
def affine_register(im1,
                    im2,
                    iterations=1000,
                    lr=0.01,
                    transform_type='similarity',
                    gpu_device=0,
                    opt_cm=True,
                    sigma=[[11, 11], [11, 11], [3, 3]],
                    order=2,
                    pyramid=[[4, 4], [2, 2]],
                    loss_fn='mse',
                    use_mask=False,
                    interpolation='bicubic'):
    """Register ``im2`` (moving) to ``im1`` (fixed) and plot the result.

    Both inputs are converted to grayscale with
    ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``, wrapped as tensor images,
    normalised jointly and inverted (``1 - image``) for the duration of the
    registration.

    Args:
        im1: fixed image; presumably a BGR array as accepted by
            ``cv2.cvtColor`` - verify against callers.
        im2: moving image, same format as ``im1``.
        iterations: number of optimiser iterations per registration pass.
        lr: learning rate of the Adam optimiser.
        transform_type: one of 'similarity', 'affine', 'rigid',
            'non_parametric', 'bspline', 'wendland'.
        gpu_device: CUDA device index; a negative value selects the CPU.
        opt_cm: passed as ``opt_cm`` to the similarity/affine/rigid
            transformations and as ``diffeomorphic`` to all others.
        sigma: NOTE(review): currently has no effect - bspline/wendland
            always use ``(1, 1)``. The default list is never mutated.
        order: B-spline ``order`` for 'bspline', ``cp_scale`` for 'wendland'.
        pyramid: downsampling factors for ``al.create_image_pyramid``. Only
            used for 'bspline'/'wendland', where the number of pyramid levels
            sets the number of registration passes. The default list is never
            mutated.
        loss_fn: one of 'mse', 'ncc', 'lcc', 'mi', 'ngf' (also reachable
            under the legacy spelling 'mgf'), 'ssim'.
        use_mask: must be False; masking is not implemented.
        interpolation: NOTE(review): currently unused.

    Returns:
        Tuple ``(displacement, warped_image, transformation_param, loss)``
        where ``loss`` is ``registration.loss.data.item()`` after the last
        pass and ``transformation_param`` is ``_phi_z``, ``trans_parameters``
        or ``_kernel`` of the last transformation, depending on its type.

    Raises:
        AssertionError: if ``use_mask`` is set or ``transform_type`` is
            unknown.
        KeyError: if ``loss_fn`` is unknown.

    Side effects: prints timing information and draws three matplotlib
    subplots (``plt.show()`` is left to the caller).
    """
    assert use_mask == False, "Masking not implemented"
    assert transform_type in [
        'similarity', 'affine', 'rigid', 'non_parametric', 'bspline',
        'wendland'
    ]
    # Reject an unknown loss before any work is done. KeyError is kept as the
    # exception type because that is what the dictionary lookup used to raise.
    valid_losses = ('mse', 'ncc', 'lcc', 'mi', 'mgf', 'ngf', 'ssim')
    if loss_fn not in valid_losses:
        raise KeyError("unknown loss_fn {!r}; expected one of {}".format(
            loss_fn, ", ".join(valid_losses)))

    start = time.perf_counter()

    is_parametric = transform_type in ('similarity', 'affine', 'rigid')
    is_kernel_based = transform_type in ('bspline', 'wendland')

    # set the used data type
    dtype = th.float32
    # GPU with index gpu_device, or the CPU when the index is negative
    device_name = "cuda:{}".format(gpu_device) if gpu_device >= 0 else 'cpu'
    device = th.device(device_name)

    # load the image data
    fixed_image = al.utils.image.create_tensor_image_from_itk_image(
        sitk.GetImageFromArray(cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)),
        dtype=dtype,
        device=device)
    moving_image = al.utils.image.create_tensor_image_from_itk_image(
        sitk.GetImageFromArray(cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)),
        dtype=dtype,
        device=device)

    fixed_image, moving_image = al.utils.normalize_images(
        fixed_image, moving_image)

    # convert intensities so that the object intensities are 1 and the background 0. This is important in order to
    # calculate the center of mass of the object
    fixed_image.image = 1 - fixed_image.image
    moving_image.image = 1 - moving_image.image

    # create pairwise registration object
    registration = al.PairwiseRegistration()

    transforms = dict(
        similarity=al.transformation.pairwise.SimilarityTransformation,
        affine=al.transformation.pairwise.AffineTransformation,
        rigid=al.transformation.pairwise.RigidTransformation,
        non_parametric=al.transformation.pairwise.NonParametricTransformation,
        wendland=al.transformation.pairwise.WendlandKernelTransformation,
        bspline=al.transformation.pairwise.BsplineTransformation)

    # 'ngf' is the intended name of the NGF loss; 'mgf' is kept so that
    # existing callers using the old spelling keep working.
    loss_fns = dict(mse=al.loss.pairwise.MSE,
                    ncc=al.loss.pairwise.NCC,
                    lcc=al.loss.pairwise.LCC,
                    mi=al.loss.pairwise.MI,
                    mgf=al.loss.pairwise.NGF,
                    ngf=al.loss.pairwise.NGF,
                    ssim=al.loss.pairwise.SSIM)

    # Constructor arguments of the transformation and number of passes.
    number_of_passes = 1
    if is_parametric:
        transform_args = [moving_image]
        transform_opts = dict(opt_cm=opt_cm)
    else:
        transform_args = [moving_image.size]
        transform_opts = dict(diffeomorphic=opt_cm, device=device_name)
        if is_kernel_based:
            # for bspline, sigma must be positive tuple of ints
            # for bspline, smaller sigma tuple means less loss of
            # microarchitectural details
            transform_opts['sigma'] = (1, 1)
            # NOTE(review): the pyramid levels themselves are not used below;
            # only their count determines how many passes are run. Every pass
            # registers the full-resolution images with a fresh
            # transformation, so only the last pass affects the result.
            fixed_image_pyramid = al.create_image_pyramid(fixed_image, pyramid)
            moving_image_pyramid = al.create_image_pyramid(
                moving_image, pyramid)
            number_of_passes = min(len(moving_image_pyramid),
                                   len(fixed_image_pyramid))
        if transform_type == 'bspline':
            transform_opts['order'] = order
        if transform_type == 'wendland':
            transform_opts['cp_scale'] = order

    for _ in range(number_of_passes):

        transformation = transforms[transform_type](*transform_args,
                                                    **transform_opts)

        if is_parametric:
            # initialize the translation with the center of mass of the fixed image
            transformation.init_translation(fixed_image)

        registration.set_transformation(transformation)

        # image loss selected by loss_fn, on the full-resolution images
        image_loss = loss_fns[loss_fn](fixed_image, moving_image)

        registration.set_image_loss([image_loss])

        # choose the Adam optimizer to minimize the objective
        optimizer = th.optim.Adam(transformation.parameters(),
                                  lr=lr,
                                  amsgrad=True)

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(iterations)

        # start the registration
        registration.start()

    # set the intensities back to the original for the visualisation
    fixed_image.image = 1 - fixed_image.image
    moving_image.image = 1 - moving_image.image

    # warp the moving image with the final transformation result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)

    end = time.perf_counter()

    print("=================================================================")

    print("Registration done in:", end - start, "s")
    if is_parametric:
        print("Result parameters:")
        transformation.print()

    # plot the results
    plt.subplot(131)
    plt.imshow(fixed_image.numpy(), cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(132)
    plt.imshow(moving_image.numpy(), cmap='gray')
    plt.title('Moving Image')

    plt.subplot(133)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Moving Image')

    # parameters of the last transformation, by transformation family
    if is_parametric:
        transformation_param = transformation._phi_z
    elif transform_type == 'non_parametric':
        transformation_param = transformation.trans_parameters
    else:
        # 'bspline' or 'wendland' (guaranteed by the check at the top)
        transformation_param = transformation._kernel

    return (displacement, warped_image, transformation_param,
            registration.loss.data.item())
def main():
    """Register two 2D test images with a multi-level B-spline model and plot them.

    Reads ``./data/affine_test_image_2d_fixed.png`` and
    ``./data/affine_test_image_2d_moving.png``, rescales their intensities to
    [0, 1] and registers the moving image to the fixed image over an image
    pyramid. The displacement of each level is upsampled and handed to the
    next level as constant displacement.

    Side effects: prints the elapsed time and shows a 2x2 matplotlib figure
    with the fixed, moving and warped image and the displacement magnitude.
    """
    t_begin = time.time()

    # set the used data type
    dtype = th.float32
    # Run the computation on the CPU.
    # Use th.device("cuda:0") instead to run on the GPU with device index 0.
    device = th.device("cpu")

    def load_normalised(path):
        # read as float32, rescale intensities to [0, 1], wrap as tensor image
        itk_image = sitk.ReadImage(path, sitk.sitkFloat32)
        itk_image = sitk.RescaleIntensity(itk_image, 0, 1)
        return al.create_tensor_image_from_itk_image(itk_image,
                                                     dtype=dtype,
                                                     device=device)

    # load the image data and normalize to [0, 1]
    fixed_image = load_normalised("./data/affine_test_image_2d_fixed.png")
    moving_image = load_normalised("./data/affine_test_image_2d_moving.png")

    # create image pyramid size/4, size/2, size/1
    fixed_levels = al.create_image_pyramid(fixed_image, [[4, 4], [2, 2]])
    moving_levels = al.create_image_pyramid(moving_image, [[4, 4], [2, 2]])

    # per-level hyper parameters, indexed by pyramid level
    regularisation_weight = [1, 5, 50]
    number_of_iterations = [500, 500, 500]
    sigma = [[11, 11], [11, 11], [3, 3]]

    carried_displacement = None
    for level, (moving_lvl, fixed_lvl) in enumerate(
            zip(moving_levels, fixed_levels)):

        registration = al.PairwiseRegistration(dtype=dtype, device=device)

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(
            moving_lvl.size,
            sigma=sigma[level],
            order=3,
            dtype=dtype,
            device=device)

        # start from the result of the previous level, resampled to this level
        if level > 0:
            carried_displacement = al.transformation.utils.upsample_displacement(
                carried_displacement, moving_lvl.size, interpolation="linear")
            transformation.set_constant_displacement(carried_displacement)

        registration.set_transformation(transformation)

        # choose the Mean Squared Error as image loss
        registration.set_image_loss(
            [al.loss.pairwise.MSE(fixed_lvl, moving_lvl)])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(
            moving_lvl.spacing)
        regulariser.SetWeight(regularisation_weight[level])
        registration.set_regulariser_displacement([regulariser])

        # define the optimizer (Adam with its default settings)
        registration.set_optimizer(th.optim.Adam(transformation.parameters()))
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        # keep the displacement of this level for the next one
        carried_displacement = transformation.get_displacement()

    # create final result from the transformation of the last level
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)
    displacement = al.create_displacement_image_from_image(
        displacement, moving_image)

    t_end = time.time()

    print("=================================================================")

    print("Registration done in: ", t_end - t_begin)
    print("Result parameters:")

    # plot the results
    plt.subplot(221)
    plt.imshow(fixed_image.numpy(), cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(222)
    plt.imshow(moving_image.numpy(), cmap='gray')
    plt.title('Moving Image')

    plt.subplot(223)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Moving Image')

    plt.subplot(224)
    plt.imshow(displacement.magnitude().numpy(), cmap='jet')
    plt.title('Magnitude Displacement')

    plt.show()