Exemplo n.º 1
0
def main():
    start = time.time()

    # set the used data type
    dtype = th.float32
    # set the device for the computaion to CPU
    device = th.device("cpu")

    # In order to use a GPU uncomment the following line. The number is the device index of the used GPU
    # Here, the GPU with the index 0 is used.
    #device = th.device("cuda:0")

    # load the image data and normalize to [0, 1]
    itkImg = sitk.ReadImage("./data/affine_test_image_2d_fixed.png", sitk.sitkFloat32)
    itkImg = sitk.RescaleIntensity(itkImg, 0, 1)
    fixed_image = al.create_tensor_image_from_itk_image(itkImg, dtype=dtype, device=device)

    itkImg = sitk.ReadImage("./data/affine_test_image_2d_moving.png", sitk.sitkFloat32)
    itkImg = sitk.RescaleIntensity(itkImg, 0, 1)
    moving_image = al.create_tensor_image_from_itk_image(itkImg, dtype=dtype, device=device)

    # create pairwise registration object
    registration = al.PairwiseRegistration(dtype=dtype, device=device)

    # choose the affine transformation model
    transformation = al.transformation.pairwise.RigidTransformation(moving_image.size, dtype=dtype, device=device)

    registration.set_transformation(transformation)

    # choose the Mean Squared Error as image loss
    image_loss = al.loss.pairwise.MSE(fixed_image, moving_image)

    registration.set_image_loss([image_loss])

    # choose the Adam optimizer to minimize the objective
    optimizer = th.optim.Adam(transformation.parameters(), lr=0.01)

    registration.set_optimizer(optimizer)
    registration.set_number_of_iterations(100)

    # start the registration
    registration.start()

    # warp the moving image with the final transformation result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image, displacement)

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start)
    print("Result parameters:")
    transformation.print()

    # plot the results
    plt.subplot(131)
    plt.imshow(fixed_image.numpy(), cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(132)
    plt.imshow(moving_image.numpy(), cmap='gray')
    plt.title('Moving Image')

    plt.subplot(133)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Moving Image')

    plt.show()
def affine_reg(img_draw, img_ref, output_path, lr=0.01, iter=200):
    # set the used data type
    dtype = th.float32
    # set the device for the computaion to CPU
    device = th.device("cpu")

    # In order to use a GPU uncomment the following line. The number is the device index of the used GPU
    # Here, the GPU with the index 0 is used.
    # device = th.device("cuda:0")

    # load the image data and normalize to [0, 1]
    itkImg = sitk.ReadImage(img_ref, sitk.sitkFloat32)
    itkImg = sitk.RescaleIntensity(itkImg, 0, 1)
    fixed_image = al.create_tensor_image_from_itk_image(itkImg,
                                                        dtype=dtype,
                                                        device=device)
    fsize = fixed_image.numpy().flatten().size

    itkImg = sitk.ReadImage(img_draw, sitk.sitkFloat32)
    itkImg = sitk.RescaleIntensity(itkImg, 0, 1)
    moving_image = al.create_tensor_image_from_itk_image(itkImg,
                                                         dtype=dtype,
                                                         device=device)

    # create pairwise registration object
    registration = al.PairwiseRegistration(dtype=dtype, device=device)

    # choose the affine transformation model
    transformation = al.transformation.pairwise.RigidTransformation(
        moving_image.size, dtype=dtype, device=device)
    registration.set_transformation(transformation)

    # choose the scaling regulariser and the diffusion regulariser
    # scale_reg = al.regulariser.parameter.ScalingRegulariser('trans_parameters')
    # scale_reg.set_weight(0.0001)
    # registration.set_regulariser_parameter([scale_reg])
    # dis_reg = al.regulariser.displacement.DiffusionRegulariser(moving_image.spacing, size_average=False)
    # registration.set_regulariser_displacement([dis_reg])
    # dis_reg.set_weight(0.1)

    # choose the Mean Squared Error as image loss
    image_loss = al.loss.pairwise.NCC(fixed_image, moving_image)
    #init_loss = np.sum(np.square(fixed_image.numpy() - moving_image.numpy()))/fsize
    registration.set_image_loss([image_loss])

    # choose the Adam optimizer to minimize the objective
    optimizer = th.optim.Adam(transformation.parameters(), lr=lr)

    registration.set_optimizer(optimizer)
    registration.set_number_of_iterations(iter)

    # start the registration
    registration.start()

    # warp the moving image with the final transformation result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image,
                                                      displacement)

    param = transformation.trans_parameters.detach().numpy()
    translate = np.sqrt(np.square(param[1]) + np.square(param[2]))
    scale = np.sqrt((np.square(param[3] - 1) + np.square(param[4] - 1)) * 0.5)
    init_loss = registration.init_loss.detach().numpy()
    final_loss = registration.img_loss.detach().numpy()

    # plot the results
    # plt.subplot(131)
    # plt.imshow(fixed_image.numpy(), cmap='gray')
    # plt.title('Ref')

    plt.figure(figsize=(13, 6))
    plt.subplot(121)
    plt.xticks([])
    plt.yticks([])
    fixed = fixed_image.numpy()
    moved = moving_image.numpy()
    warp = warped_image.numpy()
    csfont = {'fontname': 'Times New Roman', 'size': 35}

    p1 = np.ones((fixed.shape[0], fixed.shape[1], 3))
    p1[fixed < 1] = [0.5, 0.5, 0.5]
    p1[moved < 1] = [0.4, 0.62, 0.78]
    plt.imshow(p1)
    plt.title('Raw', **csfont)

    plt.subplot(122)
    plt.xticks([])
    plt.yticks([])
    p2 = np.ones((fixed.shape[0], fixed.shape[1], 3))
    p2[fixed < 1] = [0.5, 0.5, 0.5]
    p2[warp < 0.5] = [0.4, 0.62, 0.78]
    plt.imshow(p2)
    plt.title('Transformed', **csfont)

    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.close()

    return init_loss, final_loss, np.abs(
        param[0]), translate, scale, warped_image
Exemplo n.º 3
0
def main():
    start = time.time()

    # set the used data type
    dtype = th.float32
    # set the device for the computaion to CPU
    device = th.device("cpu")

    # In order to use a GPU uncomment the following line. The number is the device index of the used GPU
    # Here, the GPU with the index 0 is used.
    #device = th.device("cuda:0")

    # load the image data and normalize to [0, 1]
    itkImg = sitk.ReadImage("./data/affine_test_image_2d_fixed.png", sitk.sitkFloat32)
    itkImg = sitk.RescaleIntensity(itkImg, 0, 1)
    fixed_image = al.create_tensor_image_from_itk_image(itkImg, dtype=dtype, device=device)

    itkImg = sitk.ReadImage("./data/affine_test_image_2d_moving.png", sitk.sitkFloat32)
    itkImg = sitk.RescaleIntensity(itkImg, 0, 1)
    moving_image = al.create_tensor_image_from_itk_image(itkImg, dtype=dtype, device=device)

    # create image pyramide size/4, size/2, size/1
    fixed_image_pyramide = al.create_image_pyramide(fixed_image, [[4, 4], [2, 2]])
    moving_image_pyramide = al.create_image_pyramide(moving_image, [[4, 4], [2, 2]])

    constant_displacement = None
    regularisation_weight = [1, 5, 50]
    number_of_iterations = [500, 500, 500]
    sigma = [[11, 11], [11, 11], [3, 3]]

    for level, (mov_im_level, fix_im_level) in enumerate(zip(moving_image_pyramide, fixed_image_pyramide)):

        registration = al.PairwiseRegistration(dtype=dtype, device=device)

        # define the transformation
        transformation = al.transformation.pairwise.BsplineTransformation(mov_im_level.size,
                                                                          sigma=sigma[level],
                                                                          order=3,
                                                                          dtype=dtype,
                                                                          device=device)

        if level > 0:
            constant_displacement = al.transformation.utils.upsample_displacement(constant_displacement,
                                                                                  mov_im_level.size,
                                                                                  interpolation="linear")

            transformation.set_constant_displacement(constant_displacement)

        registration.set_transformation(transformation)

        # choose the Mean Squared Error as image loss
        image_loss = al.loss.pairwise.MSE(fix_im_level, mov_im_level)

        registration.set_image_loss([image_loss])

        # define the regulariser for the displacement
        regulariser = al.regulariser.displacement.DiffusionRegulariser(mov_im_level.spacing)
        regulariser.SetWeight(regularisation_weight[level])

        registration.set_regulariser_displacement([regulariser])

        #define the optimizer
        optimizer = th.optim.Adam(transformation.parameters())

        registration.set_optimizer(optimizer)
        registration.set_number_of_iterations(number_of_iterations[level])

        registration.start()

        constant_displacement = transformation.get_displacement()

    # create final result
    displacement = transformation.get_displacement()
    warped_image = al.transformation.utils.warp_image(moving_image, displacement)
    displacement = al.create_displacement_image_from_image(displacement, moving_image)

    end = time.time()

    print("=================================================================")

    print("Registration done in: ", end - start)
    print("Result parameters:")

    # plot the results
    plt.subplot(221)
    plt.imshow(fixed_image.numpy(), cmap='gray')
    plt.title('Fixed Image')

    plt.subplot(222)
    plt.imshow(moving_image.numpy(), cmap='gray')
    plt.title('Moving Image')

    plt.subplot(223)
    plt.imshow(warped_image.numpy(), cmap='gray')
    plt.title('Warped Moving Image')

    plt.subplot(224)
    plt.imshow(displacement.magnitude().numpy(), cmap='jet')
    plt.title('Magnitude Displacement')

    plt.show()