예제 #1
0
def reg(args, state, fix_image_filename, moving_image_filename, iterations):
    """Register one image pair with the recurrent model and save per-step figures.

    Builds the R2NN model from the checkpoint, runs ``eval_rnn`` for
    ``iterations`` steps and, for every entry of the returned parameter list,
    writes five PNGs to ``args.o``:

    * ``displacement_sum_<t>.png``: magnitude of ``displacement_pixel[t]`` plus
      the step's ellipse;
    * ``displacement_local_<t>.png``: magnitude of ``single_displacement[t]``;
    * ``warped_loacl_input_<t>.png`` / ``warped_loacl_output_<t>.png``:
      ``warped_local_image[t]`` and ``warped_local_image[t + 1]`` (output with
      the ellipse);
    * ``diff_image_<t>.png``: a constant all-ones canvas (title card).

    Finally ``fixed_image.png`` is written.

    Args:
        args: runtime options; reads ``args.gpu_id`` (a negative id runs on
            CPU) and ``args.o`` (output directory, created if missing).
        state: checkpoint dict with ``'args'`` (model options) and ``'model'``
            (state dict).
        fix_image_filename: path of the fixed image.
        moving_image_filename: path of the moving image.
        iterations: number of recurrent steps passed to ``eval_rnn``.

    Raises:
        ValueError: if the checkpoint's model type or image loss is unsupported.
    """
    args_state = state['args']

    image_size = [256, 256]

    gpu_id = args.gpu_id
    if gpu_id >= 0:
        device = th.device("cuda:" + str(gpu_id))
        th.cuda.set_device(gpu_id)
    else:
        # th.device("cuda:-1") raises, so a negative id has to mean CPU.
        device = th.device("cpu")

    if args_state.model == "R2NN":
        model = gru.GRU_Registration(image_size, 2, args=args_state, device=device)
    else:
        raise ValueError('model type {0} is not known'.format(args_state.model))

    model.eval()
    model.load_state_dict(state['model'])
    print("model loaded")

    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()

    if args_state.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        # Fail early: without a loss object the evaluation below cannot run.
        raise ValueError('image loss {0} is not supported'.format(args_state.image_loss))

    grid = compute_grid([image_size[0], image_size[1]], device=device)

    os.makedirs(args.o, exist_ok=True)

    def load_normalized(filename):
        # Read as float32, add two leading singleton dims, move to `device`,
        # standardise to zero mean / unit std and clip to [-2, 2].
        # Assumes the file holds a single 2-D slice -- TODO confirm.
        image = sitk.ReadImage(filename, sitk.sitkFloat32)
        image = th.tensor(sitk.GetArrayFromImage(image)).squeeze().unsqueeze_(0).unsqueeze_(0)
        image = image.to(device=device)
        image = image - th.mean(image)
        image = image / th.std(image)
        image.clamp_(-2, 2)
        return image

    def add_ellipse(pos, sigma, angle):
        # Outline of the local transformation's support on the current axes.
        plt.gca().add_patch(Ellipse(pos, width=sigma[0], height=sigma[1],
                                    angle=angle,
                                    edgecolor='white',
                                    facecolor='none',
                                    linewidth=2))

    def save_figure(fig, name):
        # Hide axes, save tightly without padding, then release the figure.
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.axis('off')
        plt.savefig(os.path.join(args.o, name), bbox_inches='tight', pad_inches=0)
        plt.close()

    fixed_image = load_normalized(fix_image_filename)
    moving_image = load_normalized(moving_image_filename)

    (image_loss_f, warped_image, displacement, displacement_param,
     displacement_pixel, single_displacement, warped_local_image) = eval_rnn(
        iterations, model, fixed_image, moving_image, image_loss, grid)

    # Set to True to burn captions into the saved figures.
    show_text = False
    white_text = {'color': 'w', 'fontsize': 18}
    red_text = {'color': 'r', 'fontsize': 18}

    for idx, param in enumerate(displacement_param):
        step = str(idx)

        # Ellipse geometry mapped from normalised [-1, 1] coordinates to pixels.
        # The factor 255 presumably matches the 256-pixel image_size -- confirm.
        sigma = 2 * param[0].squeeze().cpu().numpy() * 255
        pos = ((param[2].squeeze().cpu().numpy() + 1) / 2) * 255
        angle = -(param[3].cpu().numpy() * 180.0) / np.pi

        # Magnitude of displacement_pixel[idx] (captioned "transformation sum").
        field = displacement_pixel[idx]
        displacement_sum = th.sqrt(field[0, 0, ...].pow(2) + field[0, 1, ...].pow(2))
        fig = plt.imshow(displacement_sum.cpu().squeeze().numpy(), cmap='jet', vmax=0.08, vmin=0)
        add_ellipse(pos, sigma, angle)
        if show_text:
            plt.text(8, 250, r"transformation sum: $t=" + step + "$", white_text)
        save_figure(fig, "displacement_sum_" + step + ".png")

        # Magnitude of this step's local displacement.
        field = single_displacement[idx]
        displacement_local = th.sqrt(field[0, 0, ...].pow(2) + field[0, 1, ...].pow(2))
        fig = plt.imshow(displacement_local.cpu().squeeze().numpy(), cmap='jet', vmax=0.03, vmin=0)
        if show_text:
            plt.text(32, 15, r"network output: $t=" + step + "$", white_text)
            plt.text(8, 250, r"transformation local: $t=" + step + "$", white_text)
        save_figure(fig, "displacement_local_" + step + ".png")

        # Image fed into this step.
        fig = plt.imshow(warped_local_image[idx].cpu().squeeze().numpy(), cmap='gray')
        if show_text:
            if idx == 0:
                plt.text(16, 250, r"moving image", red_text)
            else:
                plt.text(16, 250, r"warped image: $t=" + str(idx - 1) + "$", red_text)
            plt.text(3, 15, r"input: $t=" + step + "$", red_text)
        # NOTE: the "loacl" spelling is kept so existing output file names stay stable.
        save_figure(fig, "warped_loacl_input_" + step + ".png")

        # Image produced by this step, with the step's ellipse.
        fig = plt.imshow(warped_local_image[idx + 1].cpu().squeeze().numpy(), cmap='gray')
        add_ellipse(pos, sigma, angle)
        if show_text:
            plt.text(16, 250, r"warped image: $t=" + step + "$", red_text)
        save_figure(fig, "warped_loacl_output_" + step + ".png")

        # Constant all-ones canvas used as a title card. Built with ones_like
        # instead of filling a view: on CPU `.cpu()` returns the same storage,
        # so an in-place fill would overwrite warped_local_image[idx].
        diff_image = th.ones_like(warped_local_image[idx]).cpu().squeeze().numpy()
        fig = plt.imshow(diff_image, cmap='gray')
        if show_text:
            plt.text(128, 150, "Recurrent Registration\n" "Neural Networks for\n" "Deformable Image\n""Registration",
                     white_text, ha='center', wrap=True, )
        save_figure(fig, "diff_image_" + step + ".png")

    fig = plt.imshow(fixed_image.cpu().squeeze().numpy(), cmap='gray')
    if show_text:
        plt.text(64, 250, r"fixed image", red_text)
        plt.text(180, 15, r"network ", red_text)
    save_figure(fig, "fixed_image.png")
예제 #2
0
def reg(args, state, fix_image_filename, moving_image_filename, iterations):
    """Plot how the local-transformation parameters evolve over the recurrent steps.

    Builds the R2NN model from the checkpoint and registers the moving image to
    the fixed image with ``eval_rnn``. For every step it collects:

    * the two entries of ``param[0]`` (shape);
    * the two entries of ``param[1]`` (weight);
    * ``param[3]`` scaled by ``-180/pi`` (angle, presumably radians to degrees).

    Three line plots are exported with ``matplotlib2tikz`` to
    ``/tmp/shape.tex``, ``/tmp/weight.tex`` and ``/tmp/test.tex``, and each is
    shown with ``plt.show()``.

    Args:
        args: runtime options; reads ``args.gpu_id`` (a negative id runs on
            CPU) and ``args.o`` (output directory, created if missing).
        state: checkpoint dict with ``'args'`` (model options) and ``'model'``
            (state dict).
        fix_image_filename: path of the fixed image.
        moving_image_filename: path of the moving image.
        iterations: number of recurrent steps passed to ``eval_rnn``.

    Raises:
        ValueError: if the checkpoint's model type or image loss is unsupported.
    """
    args_state = state['args']

    image_size = [256, 256]

    gpu_id = args.gpu_id
    if gpu_id >= 0:
        device = th.device("cuda:" + str(gpu_id))
        th.cuda.set_device(gpu_id)
    else:
        # th.device("cuda:-1") raises, so a negative id has to mean CPU.
        device = th.device("cpu")

    if args_state.model == "R2NN":
        model = gru.GRU_Registration(image_size,
                                     2,
                                     args=args_state,
                                     device=device)
    else:
        raise ValueError('model type {0} is not known'.format(
            args_state.model))

    model.eval()
    model.load_state_dict(state['model'])
    print("model loaded")

    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()

    if args_state.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        # Fail early: without a loss object the evaluation below cannot run.
        raise ValueError('image loss {0} is not supported'.format(
            args_state.image_loss))

    grid = compute_grid([image_size[0], image_size[1]], device=device)

    os.makedirs(args.o, exist_ok=True)

    def load_normalized(filename):
        # Read as float32, add two leading singleton dims, move to `device`,
        # standardise to zero mean / unit std and clip to [-2, 2].
        image = sitk.ReadImage(filename, sitk.sitkFloat32)
        image = th.tensor(sitk.GetArrayFromImage(
            image)).squeeze().unsqueeze_(0).unsqueeze_(0)
        image = image.to(device=device)
        image = image - th.mean(image)
        image = image / th.std(image)
        image.clamp_(-2, 2)
        return image

    fixed_image = load_normalized(fix_image_filename)
    moving_image = load_normalized(moving_image_filename)

    (image_loss_f, warped_image, displacement, displacement_param,
     displacement_pixel, single_displacement, warped_local_image) = eval_rnn(
        iterations, model, fixed_image, moving_image, image_loss, grid)

    shape_x = [param[0][0].cpu().numpy() for param in displacement_param]
    shape_y = [param[0][1].cpu().numpy() for param in displacement_param]
    weight_x = [param[1][0].cpu().numpy() for param in displacement_param]
    weight_y = [param[1][1].cpu().numpy() for param in displacement_param]
    angle = [-(param[3].cpu().numpy() * 180.0) / np.pi
             for param in displacement_param]

    # Raw strings: "\s" is an invalid escape in a normal string literal.
    # NOTE(review): shape_y is labelled sigma_x and shape_x sigma_y; this may be
    # a deliberate (row, column) vs (x, y) swap -- confirm before changing.
    plt.plot(shape_y, label=r'shape size $\sigma_x$')
    plt.plot(shape_x, label=r'shape size $\sigma_y$')
    plt.xlabel("Time steps $t$")
    plt.ylabel(r"Shape size $\sigma_t$ of the local transformation")
    plt.legend(bbox_to_anchor=(0.5, 0.9), loc=2, borderaxespad=0.)
    # Export before show(): once show() returns, the figure is typically gone.
    matplotlib2tikz.save("/tmp/shape.tex")
    plt.show()

    plt.plot(weight_x, label='weight $v_x$')
    plt.plot(weight_y, label='weight $v_y$')
    plt.xlabel("Time steps $t$")
    plt.ylabel("Weights $v_t$ of the local transformations ")
    plt.legend(bbox_to_anchor=(0.5, 0.9), loc=2, borderaxespad=0.)
    matplotlib2tikz.save("/tmp/weight.tex")
    plt.show()

    plt.plot(angle)
    plt.xlabel("Time steps $t$")
    plt.ylabel("Angle of the local transformation (degrees)")
    plt.grid(True)
    matplotlib2tikz.save("/tmp/test.tex")
    plt.show()
예제 #3
0
def test(args, state, image_size=(256, 256)):
    """Register every test image to its slice's fixed image and record errors.

    Expected layout under ``args.test_path``:
    ``<patient>/<examination>/<slice>/`` holding ``*.dcm`` images and a
    ``landmarks/`` folder with ``landmarks_<image stem>.vtk`` point files.
    The fixed image of each slice is chosen by ``get_fixe_image_filename``
    (defined elsewhere). Every ``.dcm`` image is registered to it, its
    landmarks are warped with the predicted displacement, and the TRE is
    computed.

    Outputs under ``args.o``:

    * ``image_data/<patient>/<exa>/<slice>/``: fixed, moving and warped
      images, displacement fields and landmark files (``.vtk``);
    * ``error_<patient>.csv`` and ``tre_<patient>.csv``: one row per slice;
    * ``error_all_patients.csv``: one row per patient with the mean loss.

    All CSV files are removed at the start, because rows are appended.

    Args:
        args: reads ``gpu_id`` (a negative id runs on CPU), ``test_path`` and
            ``o``.
        state: checkpoint dict with ``'args'`` (model options) and ``'model'``
            (state dict).
        image_size: size passed to the model and to ``compute_grid``.

    Raises:
        ValueError: if the checkpoint's model type or image loss is unsupported.
    """
    args_state = state['args']
    # Accept any sequence (the default is an immutable tuple); the model and
    # grid receive a list as before.
    image_size = list(image_size)

    gpu_id = args.gpu_id
    if gpu_id >= 0:
        device = th.device("cuda:" + str(gpu_id))
        th.cuda.set_device(gpu_id)
    else:
        # th.device("cuda:-1") raises, so a negative id has to mean CPU.
        device = th.device("cpu")

    patients = sorted(os.listdir(args.test_path))

    # Choose the fixed ("mean") image of every slice up front. The list is
    # consumed below in the same traversal order via slice_index_global.
    mean_image_filenames = []
    for patient in patients:
        patient_path = os.path.join(args.test_path, patient)
        for exa in sorted(os.listdir(patient_path)):
            exa_path = os.path.join(patient_path, exa)
            for image_slice in sorted(os.listdir(exa_path)):
                slice_path = os.path.join(exa_path, image_slice)
                images = [f for f in sorted(os.listdir(slice_path))
                          if os.path.isfile(os.path.join(slice_path, f))]
                mean_image_filenames.append(get_fixe_image_filename(slice_path, images))

    print(len(mean_image_filenames))
    print(mean_image_filenames)

    if args_state.model == "R2NN":
        model = gru.GRU_Registration(image_size, 2, args=args_state, device=device)
    else:
        raise ValueError('model type {0} is not known'.format(args_state.model))

    model.eval()
    model.load_state_dict(state['model'])
    print("model loaded")

    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()

    if args_state.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        # Fail early: without a loss object the evaluation below cannot run.
        raise ValueError('image loss {0} is not supported'.format(args_state.image_loss))

    grid = compute_grid([image_size[0], image_size[1]], device=device)

    # Only "R2NN" models can be constructed above, so the recurrent evaluator
    # is the only reachable choice.
    evaluate_net = eval_rnn

    out_path_image_data = os.path.join(args.o, "image_data")
    os.makedirs(out_path_image_data, exist_ok=True)

    # Start from clean CSV files: every one of them is opened in append mode.
    stale_csv_names = ["error_all_patients.csv"]
    for patient in patients:
        stale_csv_names += ["error_" + patient + ".csv", "tre_" + patient + ".csv"]
    for csv_name in stale_csv_names:
        csv_path = os.path.join(args.o, csv_name)
        if os.path.exists(csv_path):
            os.remove(csv_path)

    def to_normalized_tensor(itk_image):
        # Add two leading singleton dims, move to `device`, standardise to zero
        # mean / unit std and clip to [-2, 2].
        image = th.tensor(sitk.GetArrayFromImage(itk_image)).squeeze().unsqueeze_(0).unsqueeze_(0)
        image = image.to(device=device)
        image = image - th.mean(image)
        image = image / th.std(image)
        image.clamp_(-2, 2)
        return image

    def to_itk_displacement_array(field):
        # Reorder the network displacement to a channels-last float64 CPU
        # tensor, then scale each component by (size - 1) / 2. This presumably
        # maps normalised [-1, 1] offsets to pixel units -- confirm against
        # compute_grid.
        field = field.flip(2)
        field = field.transpose(1, 2).transpose(2, 3)
        field = field.squeeze().to(dtype=th.float64, device='cpu')
        for dim in range(field.shape[-1]):
            field[..., dim] = float(field.shape[-dim - 2] - 1) * field[..., dim] / 2.0
        return field

    def safe_mean(total, count):
        # Average that tolerates empty folders (keeps the zero total).
        return total / count if count else total

    def append_csv_row(csv_name, row):
        # newline='' as required by the csv module to avoid blank rows on Windows.
        with open(os.path.join(args.o, csv_name), 'a', newline='') as csv_file:
            csv.writer(csv_file, delimiter=',').writerow(row)

    # All images are treated as having unit spacing and zero origin.
    image_spacing = [1, 1]
    image_origin = [0, 0]

    slice_index_global = 0
    global_eval_error = []

    for patient in patients:
        examinations = sorted(os.listdir(os.path.join(args.test_path, patient)))
        image_loss_examination = 0
        for exa in examinations:
            slices = sorted(os.listdir(os.path.join(args.test_path, patient, exa)))
            image_loss_slices = 0
            for image_slice in slices:
                slice_path = os.path.join(args.test_path, patient, exa, image_slice)
                image_filenames = [f for f in sorted(os.listdir(slice_path)) if f.endswith(".dcm")]

                output_path = os.path.join(out_path_image_data, patient, exa, image_slice)
                os.makedirs(output_path, exist_ok=True)

                fixed_name = mean_image_filenames[slice_index_global]
                # File stems drop the last four characters (e.g. ".dcm").
                fixed_stem = fixed_name[:-4]
                fix_image_filename = os.path.join(slice_path, fixed_name)

                # fix_image_filename already contains slice_path; joining it
                # again would duplicate the prefix for a relative test_path.
                fixed_itk = sitk.ReadImage(fix_image_filename, sitk.sitkFloat32)

                # load fixed image landmarks
                fixed_image_points = Points.read(
                    os.path.join(slice_path, "landmarks", "landmarks_" + fixed_stem + ".vtk"))

                fixed_image = to_normalized_tensor(fixed_itk)

                sitk.WriteImage(sitk.GetImageFromArray(fixed_image.detach().cpu().squeeze().numpy()),
                                os.path.join(output_path, "fixed_" + fixed_stem + ".vtk"))

                image_loss_images = 0
                # The first two columns identify the slice and its fixed image.
                image_loss_images_csv = [image_slice, fix_image_filename]
                tre_slice = [image_slice, fix_image_filename]

                for image_filename in image_filenames:
                    moving_stem = image_filename[:-4]
                    moving_itk = sitk.ReadImage(os.path.join(slice_path, image_filename), sitk.sitkFloat32)
                    moving_image = to_normalized_tensor(moving_itk)

                    # load moving image landmarks
                    moving_image_points = Points.read(
                        os.path.join(slice_path, "landmarks", "landmarks_" + moving_stem + ".vtk"))

                    start = time.time()
                    image_loss_f, warped_image, displacement = evaluate_net(args_state, model, fixed_image,
                                                                            moving_image, image_loss, grid)
                    stop = time.time()

                    displacement = to_itk_displacement_array(displacement)
                    itk_displacement = sitk.GetImageFromArray(displacement.numpy(), isVector=True)
                    itk_displacement.SetSpacing(image_spacing)
                    itk_displacement.SetOrigin(image_origin)

                    moving_points_transformed = Points.transform(moving_image_points, itk_displacement)
                    tre_slice.append(Points.TRE(moving_points_transformed, fixed_image_points))

                    print("Time", stop-start)

                    image_loss_images_csv.append(image_loss_f)

                    Points.write(os.path.join(output_path, "warped_points_" + moving_stem + ".vtk"),
                                 moving_points_transformed)
                    Points.write(os.path.join(output_path, "moving_points_" + moving_stem + ".vtk"),
                                 moving_image_points)
                    Points.write(os.path.join(output_path, "fixed_points_" + moving_stem + ".vtk"),
                                 fixed_image_points)

                    image_loss_images += image_loss_f

                    sitk.WriteImage(sitk.GetImageFromArray(warped_image.detach().cpu().squeeze().numpy()),
                                    os.path.join(output_path, "warped_" + moving_stem + ".vtk"))
                    sitk.WriteImage(sitk.GetImageFromArray(moving_image.detach().cpu().squeeze().numpy()),
                                    os.path.join(output_path, "moving_" + moving_stem + ".vtk"))
                    # `displacement` is already a squeezed float64 CPU tensor here.
                    sitk.WriteImage(sitk.GetImageFromArray(displacement.numpy(), isVector=True),
                                    os.path.join(output_path, "displacement_" + moving_stem + ".vtk"))

                slice_index_global += 1

                append_csv_row("error_" + patient + ".csv", image_loss_images_csv)
                append_csv_row("tre_" + patient + ".csv", tre_slice)

                image_loss_slices += safe_mean(image_loss_images, len(image_filenames))

            image_loss_examination += safe_mean(image_loss_slices, len(slices))

        image_loss_examination = safe_mean(image_loss_examination, len(examinations))
        global_eval_error.append(image_loss_examination)

        append_csv_row("error_all_patients.csv", [patient, examinations, image_loss_examination])
예제 #4
0
def reg(args, state, fix_image_filename, moving_image_filename, iterations):
    """Register one image pair and save displacement snapshots with ellipses.

    Builds the R2NN model from the checkpoint, runs ``eval_rnn`` for
    ``iterations`` steps and writes to ``args.o``:

    * ``disp_2.png``, ``disp_4.png``, ``disp_8.png``: magnitude of
      ``displacement_pixel[n - 1]`` with the first ``n`` ellipses. A snapshot
      is skipped when fewer than ``n`` steps were returned.
    * ``disp_25.png``: magnitude of the final ``displacement`` with all
      ellipses. The name is kept for compatibility; it is written for any
      ``iterations``.
    * ``fixed_image.png``, ``moving_image.png``, ``warped_image.png``, and
      ``displacement.png`` (final magnitude without ellipses).

    Args:
        args: runtime options; reads ``args.gpu_id`` (a negative id runs on
            CPU) and ``args.o`` (output directory, created if missing).
        state: checkpoint dict with ``'args'`` (model options) and ``'model'``
            (state dict).
        fix_image_filename: path of the fixed image.
        moving_image_filename: path of the moving image.
        iterations: number of recurrent steps passed to ``eval_rnn``.

    Raises:
        ValueError: if the checkpoint's model type or image loss is unsupported.
    """
    args_state = state['args']

    image_size = [256, 256]

    gpu_id = args.gpu_id
    if gpu_id >= 0:
        device = th.device("cuda:" + str(gpu_id))
        th.cuda.set_device(gpu_id)
    else:
        # th.device("cuda:-1") raises, so a negative id has to mean CPU.
        device = th.device("cpu")

    if args_state.model == "R2NN":
        model = gru.GRU_Registration(image_size,
                                     2,
                                     args=args_state,
                                     device=device)
    else:
        raise ValueError('model type {0} is not known'.format(
            args_state.model))

    model.eval()
    model.load_state_dict(state['model'])
    print("model loaded")

    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()

    if args_state.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        # Fail early: without a loss object the evaluation below cannot run.
        raise ValueError('image loss {0} is not supported'.format(
            args_state.image_loss))

    grid = compute_grid([image_size[0], image_size[1]], device=device)

    os.makedirs(args.o, exist_ok=True)

    def load_normalized(filename):
        # Read as float32, add two leading singleton dims, move to `device`,
        # standardise to zero mean / unit std and clip to [-2, 2].
        image = sitk.ReadImage(filename, sitk.sitkFloat32)
        image = th.tensor(sitk.GetArrayFromImage(
            image)).squeeze().unsqueeze_(0).unsqueeze_(0)
        image = image.to(device=device)
        image = image - th.mean(image)
        image = image / th.std(image)
        image.clamp_(-2, 2)
        return image

    def ellipse_geometry(param):
        # (sigma, pos, angle) mapped from normalised [-1, 1] coordinates to
        # pixels. The factor 255 presumably matches the 256-pixel image_size.
        sigma = 2 * param[0].squeeze().cpu().numpy() * 255
        pos = ((param[2].squeeze().cpu().numpy() + 1) / 2) * 255
        angle = -(param[3].cpu().numpy() * 180.0) / np.pi
        return sigma, pos, angle

    def magnitude(field):
        # Euclidean norm of the first two channels of batch element 0.
        return th.sqrt(field[0, 0, ...].pow(2) + field[0, 1, ...].pow(2))

    def save_plain(fig, name):
        # Hide axes, save tightly without padding, then release the figure.
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.axis('off')
        plt.savefig(os.path.join(args.o, name),
                    bbox_inches='tight',
                    pad_inches=0)
        plt.close()

    def show_magnitude(mag):
        return plt.imshow(mag.cpu().squeeze().numpy(),
                          cmap='jet',
                          vmax=0.08,
                          vmin=0)

    def save_magnitude_with_ellipses(mag, shapes, name):
        # Displacement magnitude overlaid with one white ellipse per step.
        fig = show_magnitude(mag)
        ax = plt.gca()
        for sigma, pos, angle in shapes:
            ax.add_patch(
                Ellipse(pos,
                        width=sigma[0],
                        height=sigma[1],
                        angle=angle,
                        edgecolor='white',
                        facecolor='none',
                        linewidth=2))
        save_plain(fig, name)

    fixed_image = load_normalized(fix_image_filename)
    moving_image = load_normalized(moving_image_filename)

    image_loss_f, warped_image, displacement, displacement_param, displacement_pixel = eval_rnn(
        iterations, model, fixed_image, moving_image, image_loss, grid)

    # Compute every step's ellipse once and reuse it for all snapshots.
    ellipses = [ellipse_geometry(param) for param in displacement_param]

    for n_steps in (2, 4, 8):
        if n_steps > len(displacement_pixel):
            continue  # the run produced fewer steps than this snapshot needs
        save_magnitude_with_ellipses(magnitude(displacement_pixel[n_steps - 1]),
                                     ellipses[:n_steps],
                                     "disp_{0}.png".format(n_steps))

    final_magnitude = magnitude(displacement)
    save_magnitude_with_ellipses(final_magnitude, ellipses, "disp_25.png")

    save_plain(plt.imshow(fixed_image.cpu().squeeze().numpy(), cmap='gray'),
               "fixed_image.png")
    save_plain(plt.imshow(moving_image.cpu().squeeze().numpy(), cmap='gray'),
               "moving_image.png")
    save_plain(plt.imshow(warped_image.cpu().squeeze().numpy(), cmap='gray'),
               "warped_image.png")
    save_plain(show_magnitude(final_magnitude), "displacement.png")
예제 #5
0
def train_sync(args):
    """Train the GRU registration model on image pairs in an endless loop.

    Optionally resumes from the checkpoint at ``args.model_state`` (model
    weights, optimizer state and the train/eval counters).  For every batch
    the recurrent model is unrolled for up to ``args.rnn_iter`` steps; each
    step adds the model output to an accumulated displacement that warps the
    moving image (via ``F.grid_sample``) towards the fixed image.  The image
    loss averaged over the executed steps, plus a TV regulariser on the
    displacement (and optionally an entropy term), is back-propagated.

    Side effects: runs ``evaluater.evaluation`` every ``args.eval_interval``
    trained iterations, writes ``state_agent_sync.pt`` into ``args.o`` every
    ``args.save_model`` iterations, prints per-iteration losses and sends
    loss curves / example images to a Visdom server on ``args.port``.

    There is no exit condition: the function only stops when interrupted or
    when an exception is raised.

    Args:
        args: parsed option namespace holding the training configuration
            (paths, optimizer, loss, schedule and visualisation settings).

    Raises:
        ValueError: if ``args.rnn_iter`` is smaller than 1, or if
            ``args.optimizer`` / ``args.image_loss`` is not supported.
    """
    if args.rnn_iter < 1:
        # Losses are averaged over the executed recurrent steps, so at least
        # one step is required.
        raise ValueError(
            'rnn_iter must be at least 1, got {0}'.format(args.rnn_iter))

    continue_optimization = False
    eval_iteration = 0

    if args.model_state != "":
        # NOTE: th.load unpickles the file; only resume from trusted files.
        state = th.load(args.model_state, map_location='cpu')
        continue_optimization = True
        eval_iteration = state['eval_counter']

    gpu_id = args.gpu_ids[0]

    th.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = th.device("cuda:" + str(gpu_id))

    viz = vis.Visdom(port=args.port)

    data_manager = dm.DataManager(args.training_path,
                                  normalize_std=args.normalize_std,
                                  random_sampling=args.random_img_pair)
    image_size = data_manager.image_size()

    # DataLoader parameters
    loader_params = {
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': args.nb_workers,
        'pin_memory': True
    }

    training_generator = data.DataLoader(data_manager, **loader_params)

    model = gru.GRU_Registration(image_size, 2, device=device, args=args)

    if continue_optimization:
        model.load_state_dict(state['model'])

    model.train()

    num_params = sum(
        np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print("number of parameters model", num_params)

    evaluater = Evaluater(args, image_size, eval_iteration=eval_iteration)

    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()
        th.cuda.manual_seed(args.seed)
        # Makes the later bare .cuda() calls target gpu_id.
        th.cuda.set_device(gpu_id)

    # The optimizer is created after model.cuda() so it references the
    # on-device parameters.
    if args.optimizer == 'RMSprop':
        optimizer = th.optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optimizer == 'Adam':
        optimizer = th.optim.Adam(model.parameters(),
                                  lr=args.lr,
                                  amsgrad=args.amsgrad)
    elif args.optimizer == 'Rprop':
        optimizer = th.optim.Rprop(model.parameters(), lr=args.lr)
    else:
        raise ValueError('optimizer {0} is not known'.format(args.optimizer))

    if continue_optimization:
        # Call (not assign) load_state_dict so the saved state is restored.
        # Note that this also restores the checkpoint's param_groups (lr).
        optimizer.load_state_dict(state['optimizer'])

    if args.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        raise ValueError(
            'image loss {0} is not supported'.format(args.image_loss))

    regulariser = dl.IsotropicTVRegulariser([1.0, 1.0])

    grid = compute_grid([image_size[0], image_size[1]]).cuda()

    train_counter = 0
    if continue_optimization:
        train_counter = state['train_counter']
    loss_plot = None

    print("start optimization")
    # With the difference loss the accumulated value is the loss *decrease*,
    # which should be maximised, hence the sign flip.
    scale = -1 if args.use_diff_loss else 1

    while True:

        for fixed_image, moving_image in training_generator:

            fixed_image = fixed_image.cuda()
            moving_image = moving_image.cuda()
            # The last batch of an epoch may be smaller than args.batch_size
            # (the loader does not drop it), so size buffers from the data.
            batch_size = fixed_image.size(0)

            warped_image = moving_image

            # Start from the identity transform: zero displacement on the grid.
            displacement = th.zeros(batch_size,
                                    2,
                                    image_size[0],
                                    image_size[1],
                                    device=fixed_image.device,
                                    dtype=fixed_image.dtype)

            displacement_trans = displacement.transpose(1, 2).transpose(
                2, 3) + grid

            # Independent of the model: loss of the unwarped pair.
            loss_start, _ = image_loss(displacement_trans, fixed_image,
                                       warped_image)

            # Skip pairs that are already well aligned.  Done before the
            # evaluation check: skipped batches do not advance train_counter,
            # so they must not re-trigger the same evaluation.
            if args.early_stopping > 0:
                if loss_start.item() < args.early_stopping:
                    continue

            if train_counter % args.eval_interval == 0:
                print("Start evaluation")
                evaluater.evaluation(model)
                # The evaluation may switch the model to eval mode; make sure
                # training resumes in train mode (no-op otherwise).
                model.train()

            image_loss_epoch = 0
            model.reset()
            model.zero_grad()

            if args.entropy_regularizer_weight > 0:
                shapes = th.zeros(1,
                                  1,
                                  image_size[0],
                                  image_size[1],
                                  device=fixed_image.device,
                                  dtype=fixed_image.dtype)
                single_entropy = 0

            start = time.time()
            n_steps = 0  # number of recurrent steps actually executed
            for _ in range(args.rnn_iter):
                n_steps += 1

                # Images assumed (batch, 1, H, W) -- TODO confirm; the model
                # sees fixed and current warped image stacked on dim 1.
                net_input = th.cat((fixed_image, warped_image), dim=1)
                net_output = model(net_input)

                displacement = displacement + net_output[0]

                if args.entropy_regularizer_weight > 0:
                    f_x = net_output[1] / (th.sum(net_output[1]) +
                                           1e-5) + 1e-5
                    shapes = shapes + f_x
                    single_entropy = single_entropy + compute_entropy(f_x)

                displacement_trans = displacement.transpose(1, 2).transpose(
                    2, 3) + grid
                warped_image = F.grid_sample(moving_image, displacement_trans)

                loss_, _ = image_loss(displacement_trans, fixed_image,
                                      warped_image)

                if args.use_diff_loss:
                    image_loss_epoch = image_loss_epoch + (loss_start - loss_)
                    loss_start = loss_
                else:
                    image_loss_epoch = image_loss_epoch + loss_

                if args.early_stopping > 0:
                    if loss_.item() < args.early_stopping:
                        break

                if args.stop_on_reverse:
                    if loss_.item() <= loss_start.item():
                        loss_start = loss_
                    else:
                        break

            displacement_loss = args.reg_weight * regulariser(displacement)

            loss = scale * image_loss_epoch / n_steps + displacement_loss

            if args.entropy_regularizer_weight > 0:
                entropy_loss = (compute_entropy(shapes / n_steps) +
                                single_entropy /
                                n_steps) * args.entropy_regularizer_weight
                loss = loss - entropy_loss
                entropy_loss_value = entropy_loss.item()
            else:
                entropy_loss_value = 0

            optimizer.zero_grad()
            loss.backward()

            if args.clip_gradients:
                th.nn.utils.clip_grad_norm_(model.parameters(), 1)

            optimizer.step()

            end = time.time()

            if train_counter % args.save_model == 0:
                state = {
                    'train_counter': train_counter,
                    'eval_counter': evaluater.eval_iterations,
                    'args': args,
                    'agent_id': -1,
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict()
                }

                path = os.path.join(args.o, "state_agent_sync.pt")
                th.save(state, path)

            image_loss_value = image_loss_epoch.item() / n_steps
            displacement_loss_value = displacement_loss.item()

            print("iter ", train_counter, "image loss ", image_loss_value,
                  "displacement loss ", displacement_loss_value, "loss ",
                  loss.item(), "time", end - start)

            # One row with the three loss curves at x = train_counter.
            loss_value_ = np.column_stack(
                np.array([
                    image_loss_value, displacement_loss_value,
                    entropy_loss_value
                ]))
            x_value = np.column_stack(np.ones(3) * train_counter)
            if loss_plot is None:
                opts = dict(title=("loss_value"),
                            width=1000,
                            height=500,
                            showlegend=True)
                loss_plot = viz.line(X=x_value, Y=loss_value_, opts=opts)
            else:
                loss_plot = viz.line(X=x_value,
                                     Y=loss_value_,
                                     win=loss_plot,
                                     update='append')

            if train_counter % 250 == 0:

                fixed_image_vis = imfilter.normalize_image(
                    fixed_image[0, ...]).cpu().unsqueeze(0)
                moving_image_vis = imfilter.normalize_image(
                    moving_image[0, ...]).cpu().unsqueeze(0)

                displacement_vis = imfilter.normalize_image(
                    displacement[0, ...]).cpu().unsqueeze(0).detach()
                warped_image_vis = imfilter.normalize_image(
                    warped_image[0, ...]).cpu().unsqueeze(0).detach()

                # Checkerboard of moving vs fixed (before registration).
                checkerboard_image = sitk.GetArrayFromImage(
                    sitk.CheckerBoard(
                        sitk.GetImageFromArray(
                            moving_image_vis.squeeze().numpy()),
                        sitk.GetImageFromArray(
                            fixed_image_vis.squeeze().numpy()), [20, 20]))
                checkerboard_image_vis_nor_reg = th.Tensor(
                    checkerboard_image).unsqueeze(0).unsqueeze(0)

                # Checkerboard of warped vs fixed (after registration).
                checkerboard_image = sitk.GetArrayFromImage(
                    sitk.CheckerBoard(
                        sitk.GetImageFromArray(
                            warped_image_vis.squeeze().numpy()),
                        sitk.GetImageFromArray(
                            fixed_image_vis.squeeze().numpy()), [20, 20]))
                checkerboard_image_vis = th.Tensor(
                    checkerboard_image).unsqueeze(0).unsqueeze(0)

                image_stack = th.cat(
                    (fixed_image_vis, moving_image_vis,
                     displacement_vis[:, 0, ...].unsqueeze(1),
                     displacement_vis[:, 1, ...].unsqueeze(1),
                     warped_image_vis, checkerboard_image_vis,
                     checkerboard_image_vis_nor_reg),
                    dim=0)

                opts = dict(title="results")
                viz.images(image_stack, opts=opts, win=2)

            train_counter += 1