Example #1
0
def test_drr_registration():
    """Register the Sawbones L2/L3 CT against two X-ray images and print the result.

    Each X-ray (``0.bmp`` and ``90.bmp``) is read as 8-bit grayscale, cast to
    float32 and min-max rescaled so its darkest pixel maps to 0 and its
    brightest to 1. Two calibrated cameras and a CPU RayBox holding the CT
    volume form the DrrSet. Registration starts from a fixed initial pose
    vector (presumably Tx, Ty, Tz, Rx, Ry, Rz — confirm against DrrSet).
    """
    import numpy as np
    from PIL import Image
    from camera import Camera
    from raybox import RayBox
    from drr_set import DrrSet
    from drr_registration import DrrRegistration
    from utils import read_rho, str_to_mat

    def normalized_xray(path):
        # Grayscale -> float32 -> shift/scale into the [min, max] -> [0, 1] range.
        pixels = np.array(Image.open(path).convert('L')).astype(np.float32)
        lowest = pixels.min()
        return (pixels - lowest) / (pixels.max() - lowest)

    xray1 = normalized_xray('Test_Data/Sawbones_L2L3/0.bmp')
    xray2 = normalized_xray('Test_Data/Sawbones_L2L3/90.bmp')

    # Per-view 3x4 ``m`` and 3x3 ``k`` calibration matrices, one per X-ray.
    extrinsics = [
        str_to_mat('[-0.785341, -0.068020, -0.615313, -5.901115; 0.559239, 0.348323, -0.752279, -4.000824; 0.265498, -0.934903, -0.235514, -663.099792]'),
        str_to_mat('[-0.214846, 0.964454, 0.153853, 12.792526; 0.557581, 0.250463, -0.791436, -6.176056; -0.801838, -0.084251, -0.591572, -627.625305]'),
    ]
    intrinsics = [
        str_to_mat('[3510.918213, 0.000000, 368.718994; 0.000000, 3511.775635, 398.527802; 0.000000, 0.000000, 1.000000]'),
        str_to_mat('[3533.860352, 0.000000, 391.703888; 0.000000, 3534.903809, 395.485229; 0.000000, 0.000000, 1.000000]'),
    ]
    cam1, cam2 = [Camera(ext, intr) for ext, intr in zip(extrinsics, intrinsics)]

    raybox = RayBox('cpu')
    volume, spacing = read_rho('Test_Data/Sawbones_L2L3/sawbones.nii.gz')
    raybox.set_rho(volume, spacing)

    registration = DrrRegistration(xray1, xray2, DrrSet(cam1, cam2, raybox))
    initial_pose = np.array([-98.92, -106.0, -185.0, -35.0, 25.0, 175])
    print(registration.register(initial_pose))
Example #2
0
def test_drr_sawbones():
    """Render DRRs of the Sawbones L2/L3 CT and save the first view.

    Two identical cameras (same extrinsics ``m`` and intrinsics ``k``) look at
    the CT volume. The maxima of both ray-sum images are printed and the first
    image is written to ``raysums1.png`` with a fixed 0..1 gray scale.
    """
    import numpy as np
    import SimpleITK as sitk
    from camera import Camera
    from raybox import RayBox
    import matplotlib.pyplot as plt
    h = np.int32(460)
    w = np.int32(460)
    m = np.array([[0, 0, -1, 143],
                  [1, 0, 0, -96],
                  [0, -1, 0, -770]])
    k = np.array([[1001, 0, 204.5, 0],
                  [0, 1001, 137.3, 0],
                  [0, 0, 1, 0]])
    # Scale both focal-length entries by 1.5. ``k`` is float64 (it contains
    # 204.5 / 137.3), so the scaled values are not truncated.
    k[0, 0], k[1, 1] = 1.5 * k[0, 0], 1.5 * k[1, 1]
    # SimpleITK returns the voxel array in (z, y, x) order; the transpose moves
    # the first axis last, giving (y, x, z).
    rho = sitk.GetArrayFromImage(
        sitk.ReadImage('Test_Data/Sawbones_L2L3/sawbones.nii.gz')
    ).transpose((1, 2, 0)).astype(np.float32)
    # Spacing reordered with the same first-two-axes swap.
    sp = np.array([0.375, 0.375, 0.625], dtype=np.float32)[[1, 0, 2]]
    # Flip the volume along its second axis.
    rho = rho[::, ::-1, ::]
    cam1 = Camera(m, k, h=h, w=w)
    cam2 = Camera(m, k, h=h, w=w)
    raybox = RayBox('cpu')
    raybox.set_cams(cam1, cam2)
    raybox.set_rho(rho, sp)
    # NOTE(review): the RayBox is created in 'cpu' mode and switched to 'gpu'
    # only for tracing — confirm this is intended.
    raybox.mode = 'gpu'
    raysums1, raysums2 = raybox.trace_rays()
    print(raysums1.max(), raysums2.max())
    plt.imsave('raysums1.png', raysums1, cmap='gray', vmin=0, vmax=1)
Example #3
0
def test_raybox_class():
    """Trace a small constant-density box with two identical cameras in GPU mode.

    The density grid has ``n - 1`` = 7 entries per axis, all ones, with unit
    spacing. The maximum of each ray-sum image is printed and the first image
    is saved to ``raysums1.png``.
    """
    import numpy as np
    from camera import Camera
    from raybox import RayBox
    import matplotlib.pyplot as plt
    # Other 4x4 extrinsics that were tried for this test:
    #   [[1, 0, 0, 2], [0, 0, -1, 1], [0, 1, 0, -4], [0, 0, 0, 1]]
    #   [[1, 0, 0, 2],
    #    [0, 0.39073113, -0.92050485, -0.64241966],
    #    [0, 0.92050485, 0.39073113, -4.07275054],
    #    [0, 0, 0, 1]]
    m = np.array([
        [0, -1, 0, -1],
        [0, 0, -1, 1],
        [1, 0, 0, -3],
        [0, 0, 0, 1],
    ])
    h = np.int32(768)
    w = np.int32(768)
    grid_nodes = np.array([8, 8, 8], dtype=np.int32)
    spacing = np.array([1, 1, 1], dtype=np.float32)
    # Focal lengths equal to the image size, principal point at the center.
    half_h = h / 2
    half_w = w / 2
    k = np.array([
        [2 * half_h, 0, 1 * half_h, 0],
        [0, 2 * half_w, 1 * half_w, 0],
        [0, 0, 1, 0],
    ])
    # One density value fewer than grid_nodes along every axis.
    rho = np.ones((grid_nodes - 1).tolist(), dtype=np.float32)
    first_cam = Camera(m, k, h=h, w=w)
    second_cam = Camera(m, k, h=h, w=w)
    raybox = RayBox('gpu')
    raybox.set_cams(first_cam, second_cam)
    raybox.set_rho(rho, spacing)
    first_sums, second_sums = raybox.trace_rays()
    print(first_sums.max(), second_sums.max())
    plt.imsave('raysums1.png', first_sums, cmap='gray')
Example #4
0
def test_trace_rays():
    """Trace an 8x8 image through a 2x2x2 all-ones density grid.

    Uses the class-level ``RayBox.trace_rays(start, node_counts, spacing,
    rho, cam)`` call and min-max normalises the resulting image. Nothing is
    asserted or saved; the test only checks that tracing runs.
    """
    from camera import Camera
    from raybox import RayBox
    import numpy as np
    import matplotlib.pyplot as plt  # NOTE(review): imported but unused.
    extrinsics = np.array([
        [1, 0, 0, 2],
        [0, 0, -1, 1],
        [0, 1, 0, -4],
        [0, 0, 0, 1],
    ])
    h = 8
    w = 8
    half_h = h / 2
    half_w = w / 2
    intrinsics = np.array([
        [2 * half_h, 0, 1 * half_h, 0],
        [0, 2 * half_w, 1 * half_w, 0],
        [0, 0, 1, 0],
    ])
    cam = Camera(m=extrinsics, k=intrinsics, h=h, w=w)
    box_start = np.array([-3, -2, 0])
    node_counts = [3, 3, 3]
    spacing = [1, 1, 1]
    rho = np.ones(tuple(count - 1 for count in node_counts))
    image = RayBox.trace_rays(box_start, node_counts, spacing, rho, cam)
    low, high = image.min(), image.max()
    image = (image - low) / (high - low)
Example #5
0
def test_ray_minmax_intersec():
    """Check RayBox.get_ray_minmax_intersec against hand-computed intersections.

    The same six rays are checked against two boxes with the same start corner:
    one with 3 nodes and unit spacing, one with 5 nodes and 0.5 spacing. The
    expected entry/exit points are identical for both (they presumably span the
    same region, (n - 1) * spacing = 2 per axis). A ray that misses the box
    must yield ``(None, None)``.

    Prints ``Test N OK`` for N = 1..12 as each case passes.
    """
    import numpy as np
    from raybox import RayBox
    from ray import Ray
    i = np.array([0, -4, 0])
    j = np.array([3, 0, 1])
    k = np.array([4, 4/3, 4/3])
    l = np.array([3, -4, 1])
    m = np.array([3, 2, 1])
    n = np.array([0, 0, 0])
    o = np.array([1, 0, 0])
    # (ray start, ray end, expected first point, expected second point);
    # (None, None) expectations mean the ray misses the box.
    cases = [
        (i, j, j, k),
        (l, j, j, m),
        (l, n, None, None),
        (l, o, None, None),
        (i, n, None, None),
        (i, o, None, None),
    ]
    # (start corner, node counts, spacing); the second entry is the
    # non-unit-spacing variant.
    box_configs = [
        ([2, 0, 0], [3, 3, 3], [1, 1, 1]),
        ([2, 0, 0], [5, 5, 5], [0.5, 0.5, 0.5]),
    ]
    test_no = 0
    for box_start, box_nodes, box_spacing in box_configs:
        raybox = RayBox(box_start, box_nodes, box_spacing)
        for start, end, expected1, expected2 in cases:
            test_no += 1
            # Always unpack into fresh names so a stale value from a previous
            # case can never satisfy the assertion.
            pt1, pt2 = raybox.get_ray_minmax_intersec(Ray(start, end))
            if expected1 is None:
                assert pt1 is None and pt2 is None
            else:
                np.testing.assert_almost_equal(pt1, expected1)
                np.testing.assert_almost_equal(pt2, expected2)
            print('Test {} OK'.format(test_no))
Example #6
0
def test_get_radiological_path():
    """Check RayBox.get_radiological_path against hand-computed path lengths.

    For every ray, ``alphas[0]`` and ``alphas[1]`` are first verified as
    parametric positions along ``start + alpha * (end - start)`` that land on
    known points. The remaining six entries look like per-axis (min, max)
    pairs — confirm against RayBox. The radiological path through the 3x3x3
    unit box is then compared to the expected length to 2 decimals.

    Prints ``Test N OK`` for N = 1..6 as each case passes.
    """
    import numpy as np
    from raybox import RayBox
    from ray import Ray
    inf = float("inf")
    c = np.array([-3, -2, 0])
    i = np.array([0, 4, 0])
    j = np.array([-2, 0, 1])
    k = np.array([3, 0, 0])
    l = np.array([-1, -1, 1])
    m = np.array([-2, 4, 1])
    n = np.array([-3, -1.5, 1.5])
    o = np.array([-3, -2, 1.5])
    p = np.array([-2, -2, 1])
    q = np.array([0, 0, 0])
    r = np.array([-2, -2, 2])
    s = np.array([-1.2, -2, 1.2])
    t = np.array([-1.2, 0, 1])
    # Rays are built up front, in the same order as before, then the box.
    rij = Ray(i, j)
    rkl = Ray(k, l)
    rmj = Ray(m, j)
    rql = Ray(q, l)
    ril = Ray(i, l)
    rml = Ray(m, l)
    raybox = RayBox(c, [3, 3, 3], [1, 1, 1])
    # (ray, start, end, alphas, point at alphas[0], point at alphas[1],
    #  expected radiological path length)
    cases = [
        (rkl, k, l, (1.0, 1.5, 1.0, 1.5, -0.0, 2.0, 0.0, 2.0), l, n, 2.12),
        (rij, i, j, (1.0, 1.5, 0.5, 1.5, 1.0, 1.5, 0.0, 2.0), j, o, 2.29),
        (rmj, m, j, (1.0, 1.5, -inf, inf, 1.0, 1.5, -inf, inf), j, p, 2),
        (rql, q, l, (1.0, 2.0, 1.0, 3.0, -0.0, 2.0, 0.0, 2.0), l, r, 1.73),
        (ril, i, l, (1.0, 1.2, 1.0, 3.0, 0.8, 1.2, 0.0, 2.0), l, s, 1.04),
        (rml, m, l, (0.8, 1.0, -1.0, 1.0, 0.8, 1.2, -inf, inf), t, l, 1.02),
    ]
    for number, (ray, start, end, alphas, first, second, length) in enumerate(cases, 1):
        step = end - start
        np.testing.assert_almost_equal(start + (alphas[0]) * step, first)
        np.testing.assert_almost_equal(start + (alphas[1]) * step, second)
        np.testing.assert_almost_equal(
            raybox.get_radiological_path(alphas, ray), length, decimal=2)
        print('Test {} OK'.format(number))
Example #7
0
 def __init__(self, parent=None):
     """Build the DRR viewer window.

     Sets up the File/Edit menus, four image widgets, the parameter,
     threshold and recenter widgets, their layout, all signal-slot
     connections, and the RayBox -> DrrSet -> DrrRegistration back end.

     Args:
         parent: optional Qt parent widget, forwarded to the base class.
     """
     super().__init__(parent)
     self.setWindowTitle('DRR Viewer')
     self.setMinimumSize(QtCore.QSize(800, 720))
     size_policy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred,
                                         QtWidgets.QSizePolicy.Preferred)
     size_policy.setHeightForWidth(True)
     self.setSizePolicy(size_policy)
     self.central_widg = QtWidgets.QWidget(self)
     # File dialog
     # A single dialog instance is reused by the menu slots; each slot is
     # expected to reconfigure its mode/filter before exec_().
     self.file_dialog = QtWidgets.QFileDialog()
     self.file_dialog.setViewMode(QtWidgets.QFileDialog.List)
     self.file_dialog_out = ''
     # Menu
     self.menu = self.menuBar()
     self.file_menu = self.menu.addMenu('File')
     # Menu entry -> [slot, ''] for the 'Open ...' actions below. The second
     # item presumably stores the last selected path(s) — TODO confirm.
     self.fd_in_out = {
         'CT': [self.on_ct_menu, ''],
         'Camera Files': [self.on_cam_menu, ''],
         'C-Arm Images': [self.on_carm_menu, '']
     }
     # One 'Open <entry>' action per input type, wired to its slot.
     for entry in ['CT', 'Camera Files', 'C-Arm Images']:
         action = QtWidgets.QAction('Open {}'.format(entry), self)
         action.triggered.connect(self.fd_in_out[entry][0])
         self.file_menu.addAction(action)
     save_drr1_action = QtWidgets.QAction('Save first DRR as ...', self)
     save_drr1_action.triggered.connect(self.on_save_drr1)
     self.file_menu.addAction(save_drr1_action)
     save_drr2_action = QtWidgets.QAction('Save second DRR as ...', self)
     save_drr2_action.triggered.connect(self.on_save_drr2)
     self.file_menu.addAction(save_drr2_action)
     save_params_action = QtWidgets.QAction('Save parameters as ...', self)
     save_params_action.triggered.connect(self.on_save_params)
     self.file_menu.addAction(save_params_action)
     load_params_action = QtWidgets.QAction('Load parameters from ...',
                                            self)
     load_params_action.triggered.connect(self.on_load_params)
     self.file_menu.addAction(load_params_action)
     save_setup_action = QtWidgets.QAction('Save current setup as ...',
                                           self)
     save_setup_action.triggered.connect(self.on_save_setup)
     self.file_menu.addAction(save_setup_action)
     # Edit menu is created here and populated in the 'Logic' section below.
     self.edit_menu = self.menu.addMenu('Edit')
     # Status bar
     self.status_bar = self.statusBar()
     # Images
     # Four image views; 1 and 3 go in the left column, 2 and 4 in the right.
     self.img1_widg = ImageWidget(self.central_widg)
     self.img2_widg = ImageWidget(self.central_widg)
     self.img3_widg = ImageWidget(self.central_widg)
     self.img4_widg = ImageWidget(self.central_widg)
     # Parameter widgs
     # Any parameter edit triggers on_refresh.
     self.params_widg = ParametersWidget(self.central_widg)
     self.params_widg.params_edited.connect(self.on_refresh)
     # Threshold widg
     self.threshold_widg = ThresholdWidget(self.central_widg)
     # Refresh button
     self.refresh_butn = QtWidgets.QPushButton('Refresh', self)
     # Recenter button
     self.recenter_widg = RecenterWidget(self)
     # Layout
     refresh_layout = QtWidgets.QVBoxLayout()
     refresh_layout.addWidget(self.refresh_butn)
     refresh_layout.addWidget(self.recenter_widg)
     # TODO: Add Labels (AP, LAT, left, right)
     # TODO: Require at least AP and LAT
     # TODO: 4x4 grid layout
     left_imgs_layout = QtWidgets.QVBoxLayout()
     left_imgs_layout.addWidget(self.img1_widg)
     left_imgs_layout.addWidget(self.img3_widg)
     right_imgs_layout = QtWidgets.QVBoxLayout()
     right_imgs_layout.addWidget(self.img2_widg)
     right_imgs_layout.addWidget(self.img4_widg)
     # layout = QtWidgets.QGridLayout()
     # layout.addLayout(left_imgs_layout, 0, 0, 1, 2)
     # layout.addLayout(right_imgs_layout, 0, 3, 1, 2)
     # layout.addWidget(self.threshold_widg, 0, 2, 1, 1)
     # layout.addWidget(self.params_widg, 1, 0, 1, 3)
     # layout.addLayout(refresh_layout, 1, 3, 1, 2)
     # layout.setRowStretch(0, 1)
     # Top row: left images | threshold control | right images (the image
     # columns stretch, the threshold does not). Bottom row: parameter panel |
     # refresh/recenter column. Stretch factors 2 and 0 make the top row take
     # the extra vertical space.
     layout = QtWidgets.QVBoxLayout()
     top_layout = QtWidgets.QHBoxLayout()
     top_layout.addLayout(left_imgs_layout, 1)
     top_layout.addWidget(self.threshold_widg, 0)
     top_layout.addLayout(right_imgs_layout, 1)
     bottom_layout = QtWidgets.QHBoxLayout()
     bottom_layout.addWidget(self.params_widg, 1)
     bottom_layout.addLayout(refresh_layout, 0)
     layout.addLayout(top_layout, 2)
     layout.addLayout(bottom_layout, 0)
     self.central_widg.setLayout(layout)
     self.setCentralWidget(self.central_widg)
     # Focus
     self.setFocusPolicy(QtCore.Qt.ClickFocus)
     # Connect signals and slots
     # NOTE(review): base_pixmap_N, alpha and drrN are assumed to be Qt
     # signals declared on the enclosing class (not visible here) — confirm.
     self.base_pixmap_1.connect(self.img1_widg.on_base_pixmap)
     self.base_pixmap_2.connect(self.img2_widg.on_base_pixmap)
     self.base_pixmap_3.connect(self.img3_widg.on_base_pixmap)
     self.base_pixmap_4.connect(self.img4_widg.on_base_pixmap)
     self.params_widg.alpha_slider.valueChanged.connect(
         self.on_alphaslider_update)
     self.refresh_butn.released.connect(self.on_refresh_btn)
     self.threshold_widg.new_threshold.connect(self.on_new_threshold)
     self.recenter_widg.new_center.connect(self.on_new_center)
     # One alpha value is broadcast to all four image widgets.
     self.alpha.connect(self.img1_widg.on_alpha)
     self.alpha.connect(self.img2_widg.on_alpha)
     self.alpha.connect(self.img3_widg.on_alpha)
     self.alpha.connect(self.img4_widg.on_alpha)
     self.drr1.connect(self.img1_widg.on_drr)
     self.drr2.connect(self.img2_widg.on_drr)
     self.drr3.connect(self.img3_widg.on_drr)
     self.drr4.connect(self.img4_widg.on_drr)
     self.img1_widg.roi_finalize.connect(self.on_roi_finalize_1)
     self.img2_widg.roi_finalize.connect(self.on_roi_finalize_2)
     self.img3_widg.roi_finalize.connect(self.on_roi_finalize_3)
     self.img4_widg.roi_finalize.connect(self.on_roi_finalize_4)
     # Logic
     # Back end: each object wraps the previous one.
     self.raybox = RayBox()
     self.drr_set = DrrSet(self.raybox)
     self.drr_registration = DrrRegistration(self.drr_set)
     runoptim_action = QtWidgets.QAction('Run Optimizer', self)
     runoptim_action.triggered.connect(self.on_runoptim_action)
     self.edit_menu.addAction(runoptim_action)
     popup_3d_action = QtWidgets.QAction('Pop up 3D visualizer', self)
     self.edit_menu.addAction(popup_3d_action)
     popup_3d_action.triggered.connect(self.plot_set)
     # Checkable actions report their new state through toggled(bool).
     gpu_mode_action = QtWidgets.QAction('GPU mode', self)
     gpu_mode_action.setCheckable(True)
     gpu_mode_action.toggled.connect(self.on_toggled_gpu_mode)
     self.edit_menu.addAction(gpu_mode_action)
     # Auto refresh starts disabled, matching the unchecked action below.
     self.autorefresh = False
     autorefresh_action = QtWidgets.QAction('Auto refresh', self)
     autorefresh_action.setCheckable(True)
     autorefresh_action.toggled.connect(self.on_autorefresh_toggle)
     self.edit_menu.addAction(autorefresh_action)
     # ROI selection is toggled on all four image widgets at once.
     roiselection_action = QtWidgets.QAction('Enable ROI selection', self)
     roiselection_action.setCheckable(True)
     roiselection_action.toggled.connect(self.img1_widg.on_enable_roi)
     roiselection_action.toggled.connect(self.img2_widg.on_enable_roi)
     roiselection_action.toggled.connect(self.img3_widg.on_enable_roi)
     roiselection_action.toggled.connect(self.img4_widg.on_enable_roi)
     self.edit_menu.addAction(roiselection_action)
Example #8
0
class MainWindow(QtWidgets.QMainWindow):
    """Main window of the DRR viewer.

    Holds four image views, a pose-parameter panel (Tx, Ty, Tz, Rx, Ry, Rz), a
    threshold control and a recenter control. It loads CT volumes, camera
    calibration files and C-arm images through the File menu, renders DRRs via
    RayBox/DrrSet, and runs DrrRegistration from the Edit menu.
    """
    new_ct = QtCore.Signal(list)
    input_cams_sig = QtCore.Signal(list)
    base_pixmap_1 = QtCore.Signal(QtGui.QPixmap)
    base_pixmap_2 = QtCore.Signal(QtGui.QPixmap)
    base_pixmap_3 = QtCore.Signal(QtGui.QPixmap)
    base_pixmap_4 = QtCore.Signal(QtGui.QPixmap)
    drr1 = QtCore.Signal(np.ndarray)
    drr2 = QtCore.Signal(np.ndarray)
    drr3 = QtCore.Signal(np.ndarray)
    drr4 = QtCore.Signal(np.ndarray)
    alpha = QtCore.Signal(float)

    # Number of views the UI is wired for (img1..4_widg, drr1..4,
    # base_pixmap_1..4).
    _MAX_VIEWS = 4
    # Pose parameter names, in the index order of drr_set.params and
    # params_widg.set_params.
    _PARAM_NAMES = ('Tx', 'Ty', 'Tz', 'Rx', 'Ry', 'Rz')
    # Decimal number as written by _params_str (optional sign, no exponent).
    _FLOAT_PATTERN = r'([-+]?[0-9]*\.?[0-9]+)'

    def __init__(self, parent=None):
        """Build menus, widgets, layout, signal wiring and the DRR back end.

        Args:
            parent: optional Qt parent widget, forwarded to the base class.
        """
        super().__init__(parent)
        self.setWindowTitle('DRR Viewer')
        self.setMinimumSize(QtCore.QSize(800, 720))
        size_policy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred,
                                            QtWidgets.QSizePolicy.Preferred)
        size_policy.setHeightForWidth(True)
        self.setSizePolicy(size_policy)
        self.central_widg = QtWidgets.QWidget(self)
        # File dialog (one instance, reconfigured by every slot before use)
        self.file_dialog = QtWidgets.QFileDialog()
        self.file_dialog.setViewMode(QtWidgets.QFileDialog.List)
        self.file_dialog_out = ''
        # Menu
        self.menu = self.menuBar()
        self.file_menu = self.menu.addMenu('File')
        # Menu entry -> [slot, last selected path(s)]
        self.fd_in_out = {
            'CT': [self.on_ct_menu, ''],
            'Camera Files': [self.on_cam_menu, ''],
            'C-Arm Images': [self.on_carm_menu, '']
        }
        for entry in ['CT', 'Camera Files', 'C-Arm Images']:
            action = QtWidgets.QAction('Open {}'.format(entry), self)
            action.triggered.connect(self.fd_in_out[entry][0])
            self.file_menu.addAction(action)
        save_drr1_action = QtWidgets.QAction('Save first DRR as ...', self)
        save_drr1_action.triggered.connect(self.on_save_drr1)
        self.file_menu.addAction(save_drr1_action)
        save_drr2_action = QtWidgets.QAction('Save second DRR as ...', self)
        save_drr2_action.triggered.connect(self.on_save_drr2)
        self.file_menu.addAction(save_drr2_action)
        save_params_action = QtWidgets.QAction('Save parameters as ...', self)
        save_params_action.triggered.connect(self.on_save_params)
        self.file_menu.addAction(save_params_action)
        load_params_action = QtWidgets.QAction('Load parameters from ...',
                                               self)
        load_params_action.triggered.connect(self.on_load_params)
        self.file_menu.addAction(load_params_action)
        save_setup_action = QtWidgets.QAction('Save current setup as ...',
                                              self)
        save_setup_action.triggered.connect(self.on_save_setup)
        self.file_menu.addAction(save_setup_action)
        self.edit_menu = self.menu.addMenu('Edit')
        # Status bar
        self.status_bar = self.statusBar()
        # Images: 1 and 3 in the left column, 2 and 4 in the right column
        self.img1_widg = ImageWidget(self.central_widg)
        self.img2_widg = ImageWidget(self.central_widg)
        self.img3_widg = ImageWidget(self.central_widg)
        self.img4_widg = ImageWidget(self.central_widg)
        # Parameter widgs
        self.params_widg = ParametersWidget(self.central_widg)
        self.params_widg.params_edited.connect(self.on_refresh)
        # Threshold widg
        self.threshold_widg = ThresholdWidget(self.central_widg)
        # Refresh button
        self.refresh_butn = QtWidgets.QPushButton('Refresh', self)
        # Recenter button
        self.recenter_widg = RecenterWidget(self)
        # Layout
        refresh_layout = QtWidgets.QVBoxLayout()
        refresh_layout.addWidget(self.refresh_butn)
        refresh_layout.addWidget(self.recenter_widg)
        # TODO: Add Labels (AP, LAT, left, right)
        # TODO: Require at least AP and LAT
        # TODO: 4x4 grid layout
        left_imgs_layout = QtWidgets.QVBoxLayout()
        left_imgs_layout.addWidget(self.img1_widg)
        left_imgs_layout.addWidget(self.img3_widg)
        right_imgs_layout = QtWidgets.QVBoxLayout()
        right_imgs_layout.addWidget(self.img2_widg)
        right_imgs_layout.addWidget(self.img4_widg)
        # Top row: images | threshold | images; bottom row: params | buttons.
        layout = QtWidgets.QVBoxLayout()
        top_layout = QtWidgets.QHBoxLayout()
        top_layout.addLayout(left_imgs_layout, 1)
        top_layout.addWidget(self.threshold_widg, 0)
        top_layout.addLayout(right_imgs_layout, 1)
        bottom_layout = QtWidgets.QHBoxLayout()
        bottom_layout.addWidget(self.params_widg, 1)
        bottom_layout.addLayout(refresh_layout, 0)
        layout.addLayout(top_layout, 2)
        layout.addLayout(bottom_layout, 0)
        self.central_widg.setLayout(layout)
        self.setCentralWidget(self.central_widg)
        # Focus
        self.setFocusPolicy(QtCore.Qt.ClickFocus)
        # Connect signals and slots
        self.base_pixmap_1.connect(self.img1_widg.on_base_pixmap)
        self.base_pixmap_2.connect(self.img2_widg.on_base_pixmap)
        self.base_pixmap_3.connect(self.img3_widg.on_base_pixmap)
        self.base_pixmap_4.connect(self.img4_widg.on_base_pixmap)
        self.params_widg.alpha_slider.valueChanged.connect(
            self.on_alphaslider_update)
        self.refresh_butn.released.connect(self.on_refresh_btn)
        self.threshold_widg.new_threshold.connect(self.on_new_threshold)
        self.recenter_widg.new_center.connect(self.on_new_center)
        self.alpha.connect(self.img1_widg.on_alpha)
        self.alpha.connect(self.img2_widg.on_alpha)
        self.alpha.connect(self.img3_widg.on_alpha)
        self.alpha.connect(self.img4_widg.on_alpha)
        self.drr1.connect(self.img1_widg.on_drr)
        self.drr2.connect(self.img2_widg.on_drr)
        self.drr3.connect(self.img3_widg.on_drr)
        self.drr4.connect(self.img4_widg.on_drr)
        self.img1_widg.roi_finalize.connect(self.on_roi_finalize_1)
        self.img2_widg.roi_finalize.connect(self.on_roi_finalize_2)
        self.img3_widg.roi_finalize.connect(self.on_roi_finalize_3)
        self.img4_widg.roi_finalize.connect(self.on_roi_finalize_4)
        # Logic: RayBox -> DrrSet -> DrrRegistration
        self.raybox = RayBox()
        self.drr_set = DrrSet(self.raybox)
        self.drr_registration = DrrRegistration(self.drr_set)
        runoptim_action = QtWidgets.QAction('Run Optimizer', self)
        runoptim_action.triggered.connect(self.on_runoptim_action)
        self.edit_menu.addAction(runoptim_action)
        popup_3d_action = QtWidgets.QAction('Pop up 3D visualizer', self)
        self.edit_menu.addAction(popup_3d_action)
        popup_3d_action.triggered.connect(self.plot_set)
        gpu_mode_action = QtWidgets.QAction('GPU mode', self)
        gpu_mode_action.setCheckable(True)
        gpu_mode_action.toggled.connect(self.on_toggled_gpu_mode)
        self.edit_menu.addAction(gpu_mode_action)
        self.autorefresh = False
        autorefresh_action = QtWidgets.QAction('Auto refresh', self)
        autorefresh_action.setCheckable(True)
        autorefresh_action.toggled.connect(self.on_autorefresh_toggle)
        self.edit_menu.addAction(autorefresh_action)
        roiselection_action = QtWidgets.QAction('Enable ROI selection', self)
        roiselection_action.setCheckable(True)
        roiselection_action.toggled.connect(self.img1_widg.on_enable_roi)
        roiselection_action.toggled.connect(self.img2_widg.on_enable_roi)
        roiselection_action.toggled.connect(self.img3_widg.on_enable_roi)
        roiselection_action.toggled.connect(self.img4_widg.on_enable_roi)
        self.edit_menu.addAction(roiselection_action)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _img_widgs(self):
        """Return the four image widgets in view order (1..4)."""
        return (self.img1_widg, self.img2_widg, self.img3_widg, self.img4_widg)

    def _params_str(self):
        """Format the current pose as ``Name = value`` lines with 4 decimals."""
        params = self.drr_set.params
        return '\n'.join(
            '{} = {:.4f}'.format(name, params[idx])
            for idx, name in enumerate(self._PARAM_NAMES))

    @staticmethod
    def _save_drr(fpath, drr):
        """Write *drr* to *fpath* as a grayscale image on a fixed 0..1 scale."""
        plt.imsave(fpath, drr, cmap='gray', vmin=0, vmax=1)

    @staticmethod
    def _require_group(pattern, text, source, flags=0):
        """Return group 1 of the first *pattern* match in *text*.

        Raises:
            ValueError: if *pattern* does not occur in *text*; the message
                names *source* (typically the file path) to help the user.
        """
        match = re.search(pattern, text, flags)
        if match is None:
            raise ValueError(
                'Could not find {!r} in {}'.format(pattern, source))
        return match.group(1)

    def _save_drr_via_dialog(self, img_widg, label):
        """Ask for a PNG path and save *img_widg*'s DRR there, if it has one."""
        if img_widg.drr is None:
            self.status_bar.showMessage('No {} DRR to save yet'.format(label))
            return
        self.file_dialog.setFileMode(QtWidgets.QFileDialog.AnyFile)
        self.file_dialog.setAcceptMode(QtWidgets.QFileDialog.AcceptSave)
        self.file_dialog.setNameFilter('Image Files (*.png)')
        if self.file_dialog.exec_():
            self._save_drr(self.file_dialog.selectedFiles()[0], img_widg.drr)

    def _apply_params_and_draw(self):
        """Push the panel's pose into drr_set and re-render all DRRs."""
        params = self.params_widg.get_params()
        self.drr_set.set_tfm_params(*params)
        self.draw_drrs()

    # ------------------------------------------------------------------
    # Slots
    # ------------------------------------------------------------------

    @QtCore.Slot(bool)
    def on_autorefresh_toggle(self, checked):
        """Enable/disable redrawing the DRRs whenever parameters are edited."""
        self.autorefresh = bool(checked)

    @QtCore.Slot()
    def on_save_setup(self):
        """Save the pose and every available DRR under a chosen path prefix.

        Writes ``<prefix>.txt`` with the pose and ``<prefix>_drr<N>.png`` for
        each view N that currently holds a DRR.
        """
        self.file_dialog.setFileMode(QtWidgets.QFileDialog.AnyFile)
        self.file_dialog.setAcceptMode(QtWidgets.QFileDialog.AcceptSave)
        self.file_dialog.setNameFilter('Any files (*)')
        if not self.file_dialog.exec_():
            return
        fpath = self.file_dialog.selectedFiles()[0]
        with open('{}.txt'.format(fpath), 'w') as f:
            f.write(self._params_str())
        for idx, img_widg in enumerate(self._img_widgs(), 1):
            # Views without a DRR are skipped rather than passing None to
            # plt.imsave (previously only views 3 and 4 were guarded).
            if img_widg.drr is not None:
                self._save_drr('{}_drr{:d}.png'.format(fpath, idx),
                               img_widg.drr)

    @QtCore.Slot()
    def on_save_params(self):
        """Save the current pose to a text file chosen by the user."""
        self.file_dialog.setFileMode(QtWidgets.QFileDialog.AnyFile)
        self.file_dialog.setAcceptMode(QtWidgets.QFileDialog.AcceptSave)
        self.file_dialog.setNameFilter('Text Files (*.txt)')
        if self.file_dialog.exec_():
            fpath = self.file_dialog.selectedFiles()[0]
            with open(fpath, 'w') as f:
                f.write(self._params_str())

    @QtCore.Slot()
    def on_load_params(self):
        """Load a pose file (``Tx = ...`` ... ``Rz = ...``) into the panel.

        Keys are matched case-insensitively. If any of the six values is
        missing, nothing is applied and the problem is shown on the status bar.
        """
        self.file_dialog.setFileMode(QtWidgets.QFileDialog.ExistingFile)
        self.file_dialog.setAcceptMode(QtWidgets.QFileDialog.AcceptOpen)
        self.file_dialog.setNameFilter('Text Files (*.txt)')
        if not self.file_dialog.exec_():
            return
        fpath = self.file_dialog.selectedFiles()[0]
        with open(fpath, 'r') as f:
            s = f.read()
        try:
            values = [
                float(self._require_group(
                    r'{}\s*=\s*{}'.format(name, self._FLOAT_PATTERN),
                    s, fpath, re.IGNORECASE))
                for name in self._PARAM_NAMES
            ]
        except ValueError as err:
            self.status_bar.showMessage(str(err))
            return
        self.params_widg.set_params(*values)

    @QtCore.Slot()
    def on_save_drr1(self):
        """Ask for a PNG path and save the first view's DRR."""
        self._save_drr_via_dialog(self.img1_widg, 'first')

    @QtCore.Slot()
    def on_save_drr2(self):
        """Ask for a PNG path and save the second view's DRR."""
        self._save_drr_via_dialog(self.img2_widg, 'second')

    @QtCore.Slot()
    def on_runoptim_action(self):
        """Run the registration from the panel's pose and apply the result.

        ``register`` is expected to return an object exposing the optimum as
        ``.x`` (scipy-style result, presumably — confirm in DrrRegistration).
        """
        params = np.array(self.params_widg.get_params())
        res = self.drr_registration.register(params)
        print(res)
        new_params = res.x.tolist()
        self.params_widg.set_params(*new_params)
        self.drr_set.set_tfm_params(*new_params)
        self.draw_drrs()

    @QtCore.Slot(bool)
    def on_toggled_gpu_mode(self, checked):
        """Switch ray tracing between 'gpu' (checked) and 'cpu' mode."""
        self.raybox.mode = 'gpu' if checked else 'cpu'

    @QtCore.Slot()
    def plot_set(self):
        """Show the camera set in 3D and print each camera's ``m`` matrix."""
        self.drr_set.plot_camera_set()
        for idx, cam in enumerate(self.drr_set.cams, 1):
            print('Cam', idx)
            print(cam.m)

    @QtCore.Slot(list)
    def on_new_center(self, center):
        """Move the DRR set to *center* and reflect the new pose in the panel."""
        self.drr_set.move_to(np.array(center))
        self.params_widg.set_params(*(self.drr_set.params))

    @QtCore.Slot(int)
    def on_alphaslider_update(self, val):
        """Convert the slider value (percent) to a fraction and broadcast it."""
        alpha = val / 100
        print('alpha', alpha)
        self.alpha.emit(alpha)

    @QtCore.Slot(float)
    def on_new_threshold(self, val):
        """Forward a new threshold value to the RayBox."""
        self.raybox.set_threshold(val)

    @QtCore.Slot()
    def on_ct_menu(self):
        """Pick CT file(s) and load the first one as the density volume."""
        self.file_dialog.setNameFilter('Any files (*)')
        self.file_dialog.setFileMode(QtWidgets.QFileDialog.ExistingFiles)
        self.file_dialog.setAcceptMode(QtWidgets.QFileDialog.AcceptOpen)
        if self.file_dialog.exec_():
            fpaths = self.file_dialog.selectedFiles()
            self.fd_in_out['CT'][1] = fpaths
            self.set_rho(fpaths)

    @QtCore.Slot()
    def on_cam_menu(self):
        """Pick camera calibration text files and build the cameras from them.

        Parse errors are reported on the status bar instead of raising.
        """
        self.file_dialog.setNameFilter('Text Files (*.txt)')
        self.file_dialog.setFileMode(QtWidgets.QFileDialog.ExistingFiles)
        self.file_dialog.setAcceptMode(QtWidgets.QFileDialog.AcceptOpen)
        if self.file_dialog.exec_():
            fpaths = self.file_dialog.selectedFiles()
            self.fd_in_out['Camera Files'][1] = fpaths
            print(fpaths)
            try:
                self.init_cams_from_path(fpaths)
            except ValueError as err:
                self.status_bar.showMessage(str(err))

    @QtCore.Slot()
    def on_carm_menu(self):
        """Pick up to four C-arm images, display them and pass them on.

        Each image is shown in the matching view via ``base_pixmap_<N>`` and
        all of them are handed to ``drr_registration.set_xrays``.
        """
        # Qt separates name-filter patterns with spaces, not commas.
        self.file_dialog.setNameFilter('Image Files (*.png *.bmp)')
        self.file_dialog.setFileMode(QtWidgets.QFileDialog.ExistingFiles)
        self.file_dialog.setAcceptMode(QtWidgets.QFileDialog.AcceptOpen)
        if not self.file_dialog.exec_():
            return
        fpaths = self.file_dialog.selectedFiles()
        if len(fpaths) > self._MAX_VIEWS:
            # Only base_pixmap_1..4 exist; reject before emitting anything.
            self.status_bar.showMessage(
                'At most {:d} C-Arm images are supported, got {:d}'.format(
                    self._MAX_VIEWS, len(fpaths)))
            return
        self.fd_in_out['C-Arm Images'][1] = fpaths
        xrays = []
        for idx, fpath in enumerate(fpaths, 1):
            xrays.append(read_image_as_np(fpath))
            signal = getattr(self, 'base_pixmap_{:d}'.format(idx))
            signal.emit(QtGui.QPixmap(fpath))
        self.drr_registration.set_xrays(*xrays)

    @QtCore.Slot()
    def on_refresh(self):
        """Re-render after a parameter edit, but only if auto refresh is on."""
        if self.autorefresh:
            self._apply_params_and_draw()

    @QtCore.Slot()
    def on_refresh_btn(self):
        """Apply the panel's pose and re-render unconditionally."""
        self._apply_params_and_draw()

    def draw_drrs(self):
        """Trace all views and emit each DRR on its ``drr<N>`` signal.

        Raises:
            ValueError: if the RayBox returns more DRRs than there are views.
        """
        drrs = self.raybox.trace_rays()
        if len(drrs) > self._MAX_VIEWS:
            raise ValueError('Got {:d} DRRs but only {:d} views exist'.format(
                len(drrs), self._MAX_VIEWS))
        for idx, drr in enumerate(drrs, 1):
            signal = getattr(self, 'drr{:d}'.format(idx))
            signal.emit(drr)

    def set_rho(self, fpaths):
        """Load the first path in *fpaths* as density volume into the RayBox."""
        rho, sp = read_rho(fpaths[0])
        self.raybox.set_rho(rho, sp)
        print('Set Rho')

    def init_cams_from_path(self, fpaths):
        """Build one Camera per calibration text file and install them.

        Each file must contain ``M = [...]`` and ``K = [...]`` matrices (parsed
        with ``str_to_mat``) and integer ``H = ...`` / ``W = ...`` entries; the
        key letters may be upper or lower case.

        Raises:
            ValueError: if a file lacks one of these entries.
        """
        cams = []
        for fpath in fpaths:
            with open(fpath) as f:
                s = f.read()
            m = str_to_mat(self._require_group(r'[Mm]\s*=\s*\[(.*)\]', s, fpath))
            k = str_to_mat(self._require_group(r'[Kk]\s*=\s*\[(.*)\]', s, fpath))
            h = int(self._require_group(r'[Hh]\s*=\s*([0-9]+)', s, fpath))
            w = int(self._require_group(r'[Ww]\s*=\s*([0-9]+)', s, fpath))
            cams.append(Camera(m=m, k=k, h=h, w=w))
        # TODO: Add to drr_set
        self.drr_set.set_cams(*cams)
        print('Set cams')

    # ROI slots: store the 4-tuple reported by each image widget as that
    # view's registration mask (meaning of a..d defined by ImageWidget —
    # presumably ROI bounds; confirm).

    @QtCore.Slot(float, float, float, float)
    def on_roi_finalize_1(self, a, b, c, d):
        """Store view 1's ROI as ``drr_registration.mask1``."""
        self.drr_registration.mask1 = (a, b, c, d)

    @QtCore.Slot(float, float, float, float)
    def on_roi_finalize_2(self, a, b, c, d):
        """Store view 2's ROI as ``drr_registration.mask2``."""
        self.drr_registration.mask2 = (a, b, c, d)

    @QtCore.Slot(float, float, float, float)
    def on_roi_finalize_3(self, a, b, c, d):
        """Store view 3's ROI as ``drr_registration.mask3``."""
        self.drr_registration.mask3 = (a, b, c, d)

    @QtCore.Slot(float, float, float, float)
    def on_roi_finalize_4(self, a, b, c, d):
        """Store view 4's ROI as ``drr_registration.mask4``."""
        self.drr_registration.mask4 = (a, b, c, d)