예제 #1
0
 def externalUpdate(self, new_img):
     """Replace the displayed image while keeping the current voxel cursor.

     When new_img has more than three dimensions only new_img.subImage(0)
     is shown. Complex data gets the abs_xform prefilter. Afterwards the
     norm, axes sizes and slices are recomputed for the new data.
     """
     # Remember the cursor before the old SlicerImage is discarded.
     previous_cursor = self.image.vox_coords
     volume = new_img.subImage(0) if len(new_img.shape) > 3 else new_img
     self.image = SlicerImage(volume)
     self.image.vox_coords = previous_cursor
     if iscomplex(self.image[:]):
         # complex-valued data is displayed through the abs_xform prefilter
         self.image.prefilter = abs_xform
     self.setNorm()
     self.setUpAxesSize()
     center = self.image.zyx_coords()
     self.updateSlices(center)
예제 #2
0
 def externalUpdate(self, new_img):
     """Replace the displayed image while keeping the current voxel cursor.

     When new_img has more than three dimensions only new_img.subImage(0)
     is shown. Complex data gets the abs_xform prefilter. Afterwards the
     norm, axes sizes and slices are recomputed for the new data.
     """
     # Remember the cursor before the old SlicerImage is discarded.
     previous_cursor = self.image.vox_coords
     volume = new_img.subImage(0) if len(new_img.shape) > 3 else new_img
     self.image = SlicerImage(volume)
     self.image.vox_coords = previous_cursor
     if iscomplex(self.image[:]):
         # complex-valued data is displayed through the abs_xform prefilter
         self.image.prefilter = abs_xform
     self.setNorm()
     self.setUpAxesSize()
     center = self.image.zyx_coords()
     self.updateSlices(center)
예제 #3
0
    def initoverlay(self, action):
        """Prompt for an image file and display it as an overlay.

        A file chooser filtered to *.hdr / *.nii is shown. The chosen image
        is only installed when its physical extent (shape * dr, reordered
        by slicing()) equals the base image's. Otherwise a message is
        printed and the viewer state is left untouched.

        action: the GTK action that triggered this handler (unused).
        """
        image_filter = gtk.FileFilter()
        image_filter.add_pattern("*.hdr")
        image_filter.add_pattern("*.nii")
        image_filter.set_name("Recon Images")
        fname = ask_fname(self,
                          "Choose file to overlay...",
                          action="open",
                          filter=image_filter)
        if not fname:
            return
        # Validate the candidate in a local first. Assigning
        # self.overlay_img before the check left a rejected image installed
        # without matching self.overlays / self.overlay_norm.
        overlay_img = SlicerImage(readImage(fname, vrange=(0, 0)))

        # physical extent of each image, reordered by its slicing()
        img_dims = N.take(
            N.array(self.image.shape) * self.image.dr, self.image.slicing())
        ovl_dims = N.take(
            N.array(overlay_img.shape) * overlay_img.dr,
            overlay_img.slicing())
        if not (img_dims == ovl_dims).all():
            print("Overlay failed because physical dimensions do not align...")
            print("base image dimensions (zyx): [%3.1f %3.1f %3.1f] (mm)"
                  % tuple(img_dims))
            print("overlay image dimensions (zyx): [%3.1f %3.1f %3.1f] (mm)"
                  % tuple(ovl_dims))
            return
        self.overlay_img = overlay_img
        # presumably setNorm() derives self.overlay_norm from
        # self.overlay_img, so it has to run after the assignment above
        self.setNorm()
        (ax, cor, sag) = self.overlay_img.slicing()
        self.ax_overlay = OverLay(self.ax_plot,
                                  ax,
                                  norm=self.overlay_norm,
                                  interpolation=self.ax_plot.interpolation)
        self.cor_overlay = OverLay(self.cor_plot,
                                   cor,
                                   norm=self.overlay_norm,
                                   interpolation=self.cor_plot.interpolation)
        self.sag_overlay = OverLay(self.sag_plot,
                                   sag,
                                   norm=self.overlay_norm,
                                   interpolation=self.sag_plot.interpolation)
        self.overlays = [self.ax_overlay, self.cor_overlay, self.sag_overlay]
        self.updateSlices(self.image.zyx_coords(),
                          sliceplots=self.overlays,
                          image=self.overlay_img,
                          norm=self.overlay_norm)
예제 #4
0
 def initoverlay(self, action):
     """Prompt for an image file and display it as an overlay.

     A file chooser filtered to *.hdr / *.nii is shown. The chosen image is
     only installed when its physical extent (shape * dr, reordered by
     slicing()) equals the base image's. Otherwise a message is printed and
     the viewer state is left untouched.

     action: the GTK action that triggered this handler (unused).
     """
     image_filter = gtk.FileFilter()
     image_filter.add_pattern("*.hdr")
     image_filter.add_pattern("*.nii")
     image_filter.set_name("Recon Images")
     fname = ask_fname(self, "Choose file to overlay...", action="open",
                       filter=image_filter)
     if not fname:
         return
     # Validate the candidate in a local first. Assigning self.overlay_img
     # before the check left a rejected image installed without matching
     # self.overlays / self.overlay_norm.
     overlay_img = SlicerImage(readImage(fname, vrange=(0,0)))

     # physical extent of each image, reordered by its slicing()
     img_dims = N.take(N.array(self.image.shape) * self.image.dr,
                       self.image.slicing())
     ovl_dims = N.take(N.array(overlay_img.shape) * overlay_img.dr,
                       overlay_img.slicing())
     if not (img_dims == ovl_dims).all():
         print("Overlay failed because physical dimensions do not align...")
         print("base image dimensions (zyx): [%3.1f %3.1f %3.1f] (mm)"
               % tuple(img_dims))
         print("overlay image dimensions (zyx): [%3.1f %3.1f %3.1f] (mm)"
               % tuple(ovl_dims))
         return
     self.overlay_img = overlay_img
     # presumably setNorm() derives self.overlay_norm from self.overlay_img,
     # so it has to run after the assignment above
     self.setNorm()
     (ax, cor, sag) = self.overlay_img.slicing()
     self.ax_overlay = OverLay(self.ax_plot, ax,
                               norm=self.overlay_norm,
                               interpolation=self.ax_plot.interpolation)
     self.cor_overlay = OverLay(self.cor_plot, cor,
                                norm=self.overlay_norm,
                                interpolation=self.cor_plot.interpolation)
     self.sag_overlay = OverLay(self.sag_plot, sag,
                                norm=self.overlay_norm,
                                interpolation=self.sag_plot.interpolation)
     self.overlays = [self.ax_overlay, self.cor_overlay, self.sag_overlay]
     self.updateSlices(self.image.zyx_coords(),
                       sliceplots=self.overlays,
                       image=self.overlay_img,
                       norm=self.overlay_norm)
예제 #5
0
    def __init__(self, image, title="spmclone", parent=None):
        """Build the three-plane viewer window for *image* and show it.

        image: image object with a .shape. When it has more than three
            dimensions only image.subImage(0) is displayed.
        title: window title.
        parent: optional owning window. When given, this window is placed on
            the parent's screen and parent._plotter_died() is called when it
            is destroyed. Without a parent, destroying the window calls
            gtk.main_quit() and this constructor runs gtk.main() itself, so
            it does not return until that main loop exits.
        """
        gtk.Window.__init__(self)
        # If __init__ is re-run on an existing window, drop its old contents.
        children = self.get_children()
        if children:
            self.remove(children[0])
            self.hide_all()
        table = gtk.Table(4, 2)

        # Keep the caller's object with its original class; self.image is
        # the SlicerImage wrapper used for display.
        self.img_obj = image
        # Explicit branch rather than `cond and A or B`, which silently
        # falls through to B whenever A happens to be falsy.
        if len(image.shape) > 3:
            self.image = SlicerImage(image.subImage(0))
        else:
            self.image = SlicerImage(image)
        if iscomplex(self.image[:]):
            self.image.prefilter = abs_xform

        self.overlay_img = None
        self.zoom = 0
        self.slice_patches = None
        self.setNorm()
        self.dimlengths = self.image.dr * N.array(self.image.shape)
        zyx_lim = self.image.extents()
        # I'm using [ax,cor,sag] such that this list informs each
        # sliceplot what dimension it slices in the image array
        # (eg for coronal data, the coronal plot slices the 0th dim)
        ax, cor, sag = self.image.slicing()
        origin = [0, 0, 0]
        # Make the ortho plots ---
        self.ax_plot = SlicePlot(self.image.data_xform(ax, origin),
                                 0,
                                 0,
                                 ax,
                                 norm=self.norm,
                                 extent=(zyx_lim[2] + zyx_lim[1]))

        self.cor_plot = SlicePlot(self.image.data_xform(cor, origin),
                                  0,
                                  0,
                                  cor,
                                  norm=self.norm,
                                  extent=(zyx_lim[2] + zyx_lim[0]))

        self.sag_plot = SlicePlot(self.image.data_xform(sag, origin),
                                  0,
                                  0,
                                  sag,
                                  norm=self.norm,
                                  extent=(zyx_lim[1] + zyx_lim[0]))

        # Although it doesn't matter 99% of the time, this list is
        # expected to be ordered this way
        self.sliceplots = [self.ax_plot, self.cor_plot, self.sag_plot]

        # menu bar
        merge = gtk.UIManager()
        merge.insert_action_group(self._create_action_group(), 0)
        merge.add_ui_from_string(ui_info)
        self.menubar = merge.get_widget("/MenuBar")

        # Layout (gtk.Table rows x 2 columns):
        #   row 0: menubar across both columns
        #   row 1: coronal (left), sagittal (right)
        #   row 2: axial (left)
        #   row 3: status box (left), display options (right)
        table.attach(self.menubar, 0, 2, 0, 1)
        self.menubar.set_size_request(600, 30)
        table.attach(self.cor_plot, 0, 1, 1, 2)
        self.cor_plot.set_size_request(250, 250)
        table.attach(self.sag_plot, 1, 2, 1, 2)
        self.sag_plot.set_size_request(250, 250)
        table.attach(self.ax_plot, 0, 1, 2, 3)
        self.ax_plot.set_size_request(250, 250)

        self.displaybox = DisplayInfo(self.image)
        self.displaybox.attach_toggle(self.crosshair_hider)
        self.displaybox.attach_imginterp(self.interp_handler)
        self.displaybox.attach_imgframe(self.zoom_handler)
        self.displaybox.attach_imgspace(self.rediculous_handler)
        table.attach(self.displaybox, 1, 2, 3, 4)
        self.displaybox.set_size_request(300, 300)
        self.statusbox = DisplayStatus(
            tuple(self.image.vox_coords.tolist()),
            tuple(self.image.zyx_coords().tolist()),
            self.image[tuple(self.image.vox_coords)])
        table.attach(self.statusbox, 0, 1, 3, 4)
        self.statusbox.set_size_request(300, 300)
        self.connect("configure_event", self.resize_handler)
        self.connect_crosshair_id = []
        self.connectCrosshairEvents()

        self.set_data("ui-manager", merge)
        self.add_accel_group(merge.get_accel_group())
        self.set_default_size(600, 730)
        self.set_border_width(3)
        self.set_title(title)
        self.add(table)
        self.show_all()
        if parent:
            self.set_screen(parent.get_screen())
            self.destroy_handle = self.connect(
                'destroy', lambda x: parent._plotter_died())
        else:
            # Stand-alone window: closing it ends the GTK main loop that is
            # started (and blocks) right here.
            self.connect("destroy", lambda x: gtk.main_quit())
            gtk.main()
예제 #6
0
class spmclone(gtk.Window):
    """Three-plane (axial / coronal / sagittal) image viewer window.

    Clicking or dragging in one plot moves the shared voxel cursor and
    redraws the other two planes. An overlay image with the same physical
    extent can be loaded from the Tools menu.
    """

    def __init__(self, image, title="spmclone", parent=None):
        """Build the viewer for *image* and show it.

        image: image object with a .shape. When it has more than three
            dimensions only image.subImage(0) is displayed.
        title: window title.
        parent: optional owning window. When given, this window is placed on
            the parent's screen and parent._plotter_died() is called when it
            is destroyed. Without a parent, destroying the window calls
            gtk.main_quit() and this constructor runs gtk.main() itself, so
            it does not return until that main loop exits.
        """
        gtk.Window.__init__(self)
        # load_new_image() re-runs __init__ on a live window, so clear out
        # the previous contents first.
        children = self.get_children()
        if children:
            self.remove(children[0])
            self.hide_all()
        table = gtk.Table(4, 2)

        # Keep the caller's object with its original class (it is handed to
        # recon_gui); self.image is the SlicerImage wrapper used for display.
        self.img_obj = image
        # Explicit branch rather than `cond and A or B`, which silently
        # falls through to B whenever A happens to be falsy.
        if len(image.shape) > 3:
            self.image = SlicerImage(image.subImage(0))
        else:
            self.image = SlicerImage(image)
        if iscomplex(self.image[:]):
            self.image.prefilter = abs_xform

        self.overlay_img = None
        self.zoom = 0
        self.slice_patches = None
        self.setNorm()
        self.dimlengths = self.image.dr * N.array(self.image.shape)
        zyx_lim = self.image.extents()
        # I'm using [ax,cor,sag] such that this list informs each
        # sliceplot what dimension it slices in the image array
        # (eg for coronal data, the coronal plot slices the 0th dim)
        ax, cor, sag = self.image.slicing()
        origin = [0, 0, 0]
        # Make the ortho plots ---
        self.ax_plot = SlicePlot(self.image.data_xform(ax, origin),
                                 0,
                                 0,
                                 ax,
                                 norm=self.norm,
                                 extent=(zyx_lim[2] + zyx_lim[1]))

        self.cor_plot = SlicePlot(self.image.data_xform(cor, origin),
                                  0,
                                  0,
                                  cor,
                                  norm=self.norm,
                                  extent=(zyx_lim[2] + zyx_lim[0]))

        self.sag_plot = SlicePlot(self.image.data_xform(sag, origin),
                                  0,
                                  0,
                                  sag,
                                  norm=self.norm,
                                  extent=(zyx_lim[1] + zyx_lim[0]))

        # Although it doesn't matter 99% of the time, this list is
        # expected to be ordered this way
        self.sliceplots = [self.ax_plot, self.cor_plot, self.sag_plot]

        # menu bar
        merge = gtk.UIManager()
        merge.insert_action_group(self._create_action_group(), 0)
        merge.add_ui_from_string(ui_info)
        self.menubar = merge.get_widget("/MenuBar")

        # Layout (gtk.Table rows x 2 columns):
        #   row 0: menubar across both columns
        #   row 1: coronal (left), sagittal (right)
        #   row 2: axial (left)
        #   row 3: status box (left), display options (right)
        table.attach(self.menubar, 0, 2, 0, 1)
        self.menubar.set_size_request(600, 30)
        table.attach(self.cor_plot, 0, 1, 1, 2)
        self.cor_plot.set_size_request(250, 250)
        table.attach(self.sag_plot, 1, 2, 1, 2)
        self.sag_plot.set_size_request(250, 250)
        table.attach(self.ax_plot, 0, 1, 2, 3)
        self.ax_plot.set_size_request(250, 250)

        self.displaybox = DisplayInfo(self.image)
        self.displaybox.attach_toggle(self.crosshair_hider)
        self.displaybox.attach_imginterp(self.interp_handler)
        self.displaybox.attach_imgframe(self.zoom_handler)
        self.displaybox.attach_imgspace(self.rediculous_handler)
        table.attach(self.displaybox, 1, 2, 3, 4)
        self.displaybox.set_size_request(300, 300)
        self.statusbox = DisplayStatus(
            tuple(self.image.vox_coords.tolist()),
            tuple(self.image.zyx_coords().tolist()),
            self.image[tuple(self.image.vox_coords)])
        table.attach(self.statusbox, 0, 1, 3, 4)
        self.statusbox.set_size_request(300, 300)
        self.connect("configure_event", self.resize_handler)
        self.connect_crosshair_id = []
        self.connectCrosshairEvents()

        self.set_data("ui-manager", merge)
        self.add_accel_group(merge.get_accel_group())
        self.set_default_size(600, 730)
        self.set_border_width(3)
        self.set_title(title)
        self.add(table)
        self.show_all()
        if parent:
            self.set_screen(parent.get_screen())
            self.destroy_handle = self.connect(
                'destroy', lambda x: parent._plotter_died())
        else:
            # Stand-alone window: closing it ends the GTK main loop that is
            # started (and blocks) right here.
            self.connect("destroy", lambda x: gtk.main_quit())
            gtk.main()

    #-------------------------------------------------------------------------
    def setNorm(self):
        """Set the whitepoint and blackpoint (uses raw data, not scaled).

        self.norm spans roughly the 1st..99th percentile of the base image,
        taken from its sorted values. When an overlay is loaded,
        self.overlay_norm is built from the overlay's 1st/99th percentiles.
        """
        x = N.sort(self.image[:].flatten())
        npts = x.shape[0]
        # round(.99 * npts) can equal npts (e.g. npts=50 -> round(49.5) ==
        # 50), which would index past the end; clamp to the last element.
        last = npts - 1
        p01 = x[min(int(round(.01 * npts)), last)]
        p99 = x[min(int(round(.99 * npts)), last)]
        self.norm = P.normalize(vmin=p01, vmax=p99)
        # setNorm() runs from __init__ before overlay_img is assigned.
        if getattr(self, "overlay_img", None) is not None:
            p01 = P.prctile(self.overlay_img[:], 1.0)
            p99 = P.prctile(self.overlay_img[:], 99.)
            self.overlay_norm = P.normalize(vmin=p01, vmax=p99)

    #-------------------------------------------------------------------------
    def xy(self, slice_idx):
        """Return the zyx-axis indices shown horizontally and vertically.

        slice_idx is one of the values from self.image.slicing(). The result
        is (x, y) as indices into zyx order for the plot slicing that
        dimension.
        """
        (ax, cor, sag) = self.image.slicing()
        x, y = {
            ax: (2, 1),  # x,y
            cor: (2, 0),  # x,z
            sag: (1, 0),  # y,z
        }.get(slice_idx)
        return x, y

    #-------------------------------------------------------------------------
    def updateSlices(self, zyx, sliceplots=None, image=None, norm=None):
        """Redraw each sliceplot with the planes of *image* through *zyx*.

        sliceplots, image and norm default to the base plots, base image and
        base norm. Pass the overlay equivalents to refresh the overlays.
        """
        if not sliceplots:
            sliceplots = self.sliceplots
        if image is None:
            image = self.image
        if norm is None:
            norm = self.norm
        for sliceplot in sliceplots:
            idx = sliceplot.slice_idx
            sliceplot.setData(image.data_xform(idx, zyx), norm=norm)
            if self.slice_patches is not None:
                p_idx = int(self.image.vox_coords[idx])
                sliceplot.showPatches(self.slice_patches[idx][p_idx])

    #-------------------------------------------------------------------------
    def updateCrosshairs(self):
        """Move every plot's crosshairs to the current cursor position."""
        zyx = self.image.zyx_coords().tolist()
        for s, sliceplot in enumerate(self.sliceplots):
            # drop the s-th zyx coordinate; the remaining pair is
            # (up-down, left-right) for this plot
            ud, lr = zyx[:s] + zyx[s + 1:]
            sliceplot.setCrosshairs(lr, ud)

    #-------------------------------------------------------------------------
    def externalUpdate(self, new_img):
        """Replace the displayed image while keeping the current voxel cursor.

        When new_img has more than three dimensions only new_img.subImage(0)
        is shown. Complex data gets the abs_xform prefilter.
        """
        old_vox = self.image.vox_coords
        if len(new_img.shape) > 3:
            self.image = SlicerImage(new_img.subImage(0))
        else:
            self.image = SlicerImage(new_img)
        self.image.vox_coords = old_vox
        if iscomplex(self.image[:]):
            self.image.prefilter = abs_xform
        self.setNorm()
        # NOTE(review): self.dimlengths still describes the previous image;
        # this assumes new_img has the same geometry -- confirm.
        self.setUpAxesSize()
        self.updateSlices(self.image.zyx_coords())

    #-------------------------------------------------------------------------
    def setUpAxesSize(self):
        """Scale the axes appropriately for the image dimensions.

        All three plots share one length-to-pixel scale. That scale is set so
        that a length of dimlengths[-1] spans 85% of the smaller canvas side
        of the plot slicing dimension 0. Each axes is then centred in its
        canvas.
        """
        ref_size = self.dimlengths[-1]
        slicing = self.image.slicing()
        idx = slicing.index(0)
        xy_imgsize = .85 * min(*self.sliceplots[idx].get_width_height())
        for sliceplot in self.sliceplots:
            ax = sliceplot.getAxes()
            s_idx = sliceplot.slice_idx
            # Remove the sliced dimension by *position*. list.remove(value)
            # dropped the first equal length instead, which removes the
            # wrong axis when two dimensions have the same physical length.
            dims_copy = self.dimlengths.tolist()
            del dims_copy[s_idx]
            if self.image.is_xpose(s_idx):
                dims_copy.reverse()
            slice_y, slice_x = dims_copy
            height = xy_imgsize * slice_y / ref_size
            width = xy_imgsize * slice_x / ref_size
            canvas_x, canvas_y = sliceplot.get_width_height()
            # axes position is given as fractions of the canvas
            w = width / canvas_x
            h = height / canvas_y
            l = (1.0 - w) / 2.
            b = (1.0 - h) / 2.
            ax.set_position([l, b, w, h])
            sliceplot.draw_idle()

    #-------------------------------------------------------------------------
    def connectCrosshairEvents(self, mode="enable"):
        """Connect (mode="enable") or disconnect the mouse handlers.

        Enabling registers press/release/motion callbacks on every sliceplot,
        storing three connection ids per plot in plot order, and calls
        toggleCrosshairs(mode=True). Any other mode disconnects those ids and
        calls toggleCrosshairs(mode=False).
        """
        if mode == "enable":
            self._dragging = False
            for sliceplot in self.sliceplots:
                for event_name, handler in (
                        ("button_press_event", self.SPMouseDown),
                        ("button_release_event", self.SPMouseUp),
                        ("motion_notify_event", self.SPMouseMotion)):
                    self.connect_crosshair_id.append(
                        sliceplot.mpl_connect(event_name, handler))
                sliceplot.toggleCrosshairs(mode=True)
        elif self.connect_crosshair_id:
            # Plot i owns ids [3*i, 3*i + 3). Indexing with i, i+1, i+2
            # overlapped between plots and never released the later ids.
            for i, sliceplot in enumerate(self.sliceplots):
                for cid in self.connect_crosshair_id[3 * i:3 * i + 3]:
                    sliceplot.mpl_disconnect(cid)
                sliceplot.toggleCrosshairs(mode=False)
            self.connect_crosshair_id = []

    #-------------------------------------------------------------------------
    def SPMouseDown(self, event):
        """Start a drag in the clicked axes and move the cursor there."""
        # for a new mouse down event, reset the mouse positions
        self._mouse_lr = self._mouse_ud = None
        # event.inaxes is falsy when the click is outside any axes
        self._dragging = event.inaxes
        self.updateCoords(event)

    #-------------------------------------------------------------------------
    def SPMouseUp(self, event):
        """Finish a drag, applying the final mouse position."""
        # if not dragging, no business being here!
        if self._dragging:
            self.updateCoords(event)
            self._dragging = False

    #-------------------------------------------------------------------------
    def SPMouseMotion(self, event):
        """Follow the mouse while a drag is in progress."""
        if self._dragging:
            self.updateCoords(event)

    #-------------------------------------------------------------------------
    def updateCoords(self, event):
        "Update all the necessary sliceplot data based on a mouse click."
        # The tasks here are:
        # 1 find zyx_coords of mouse click and translate to vox_coords
        # 2 update the transverse sliceplots based on vox_coords
        # 2a update the transverse overlays if present
        # 3 update crosshairs on all sliceplots
        # 4 update voxel space and zyx space texts
        sliceplot = event.canvas
        # using terminology up-down, left-right to avoid confusion with y,x
        ud, lr = sliceplot.getEventCoords(event)
        if self._mouse_lr == lr and self._mouse_ud == ud:
            return
        if lr is None or ud is None:
            return
        self._mouse_lr, self._mouse_ud = (lr, ud)
        # trans_sliceplots are the transverse plots that get
        # updated from where the mouse clicked
        (ax, cor, sag) = self.image.slicing()
        trans_sliceplots = {
            self.sliceplots[ax]: (self.sliceplots[sag], self.sliceplots[cor]),
            self.sliceplots[cor]: (self.sliceplots[sag], self.sliceplots[ax]),
            self.sliceplots[sag]: (self.sliceplots[cor], self.sliceplots[ax]),
        }.get(sliceplot)
        trans_idx = (trans_sliceplots[0].slice_idx,
                     trans_sliceplots[1].slice_idx)

        # where do left-right and up-down cut across in zyx space?
        trans_ax = self.image.transverse_slicing(sliceplot.slice_idx)
        zyx_clicked = self.image.zyx_coords()
        zyx_clicked[trans_ax[0]] = lr
        zyx_clicked[trans_ax[1]] = ud
        vox = self.image.zyx2vox(zyx_clicked)

        self.image.vox_coords[trans_idx[0]] = vox[trans_idx[0]]
        self.image.vox_coords[trans_idx[1]] = vox[trans_idx[1]]
        self.updateSlices(zyx_clicked, sliceplots=trans_sliceplots)

        if self.overlay_img is not None:
            # basically do the same thing over again wrt the overlay dims
            (ax_o, cor_o, sag_o) = self.overlay_img.slicing()
            trans_overlays = {
                self.sliceplots[ax]:
                (self.overlays[sag_o], self.overlays[cor_o]),
                self.sliceplots[cor]:
                (self.overlays[sag_o], self.overlays[ax_o]),
                self.sliceplots[sag]:
                (self.overlays[cor_o], self.overlays[ax_o]),
            }.get(sliceplot)
            trans_idx = (trans_overlays[0].slice_idx,
                         trans_overlays[1].slice_idx)
            vox = self.overlay_img.zyx2vox(zyx_clicked)
            self.overlay_img.vox_coords[trans_idx[0]] = vox[trans_idx[0]]
            self.overlay_img.vox_coords[trans_idx[1]] = vox[trans_idx[1]]
            self.updateSlices(zyx_clicked,
                              sliceplots=trans_overlays,
                              image=self.overlay_img,
                              norm=self.overlay_norm)

        self.updateCrosshairs()
        # make text to update the statusbox label's
        self.statusbox.set_vox_text(self.image.vox_coords[::-1].tolist())
        self.statusbox.set_zyx_text(self.image.zyx_coords()[::-1].tolist())
        point = tuple(self.image.vox_coords)
        self.statusbox.set_intensity_text(self.image[point])

    #-------------------------------------------------------------------------
    def rediculous_handler(self, cbox):
        """Placeholder callback for the image-space selector; does nothing."""
        print("You've hit a useless button!")
        return

    #-------------------------------------------------------------------------
    def interp_handler(self, cbox):
        """Apply the interpolation chosen in *cbox* to the base sliceplots."""
        interp_method = interp_lookup[cbox.get_active()]
        for sliceplot in self.sliceplots:
            sliceplot.setInterpo(interp_method)

    #-------------------------------------------------------------------------
    def crosshair_hider(self, toggle):
        """Pass `not toggle.get_active()` to each plot's toggleCrosshairs."""
        hidden = (not toggle.get_active())
        for sliceplot in self.sliceplots:
            sliceplot.toggleCrosshairs(mode=hidden)

    #-------------------------------------------------------------------------
    def zoom_handler(self, cbox):
        """Change the view range of the sliceplots to be NxN mm.

        The combo-box index selects N: 0 -> full extent, 1 -> 160, 2 -> 80,
        3 -> 40, 4 -> 20, 5 -> 10. A zoomed view is centred on the cursor.
        """
        self.zoom = {
            0: 0,
            1: 160,
            2: 80,
            3: 40,
            4: 20,
            5: 10,
        }.get(cbox.get_active(), 0)
        r_center = self.image.zyx_coords()
        if self.zoom:
            r_neg = r_center - N.array([self.zoom / 2.] * 3)
            r_pos = r_center + N.array([self.zoom / 2.] * 3)
            # list() keeps zyx_lim indexable on Python 3 as well
            zyx_lim = list(zip(r_neg, r_pos))
            self.dimlengths = N.array([self.zoom] * 3)
        else:
            self.dimlengths = self.image.dr * N.array(self.image.shape)
            zyx_lim = self.image.extents()

        for plot in self.sliceplots:
            x, y = self.image.transverse_slicing(plot.slice_idx)
            plot.setXYlim(zyx_lim[x], zyx_lim[y])
            plot.setCrosshairs(r_center[x], r_center[y])

        self.updateSlices(self.image.zyx_coords())
        self.setUpAxesSize()

    #-------------------------------------------------------------------------
    def resize_handler(self, window, event):
        """Rescale the axes whenever the window gets a configure_event."""
        self.setUpAxesSize()

    #-------------------------------------------------------------------------
    def initoverlay(self, action):
        """Prompt for an image file and display it as an overlay.

        A file chooser filtered to *.hdr / *.nii is shown. The chosen image
        is only installed when its physical extent (shape * dr, reordered
        by slicing()) equals the base image's. Otherwise a message is
        printed and the current state, including any existing overlay, is
        left untouched. A previously loaded overlay is replaced.
        """
        image_filter = gtk.FileFilter()
        image_filter.add_pattern("*.hdr")
        image_filter.add_pattern("*.nii")
        image_filter.set_name("Recon Images")
        fname = ask_fname(self,
                          "Choose file to overlay...",
                          action="open",
                          filter=image_filter)
        if not fname:
            return
        # Validate the candidate in a local first. Assigning
        # self.overlay_img before the check left a rejected image installed
        # without matching self.overlays / self.overlay_norm, which
        # updateCoords() then dereferences.
        overlay_img = SlicerImage(readImage(fname, vrange=(0, 0)))

        # physical extent of each image, reordered by its slicing()
        img_dims = N.take(
            N.array(self.image.shape) * self.image.dr, self.image.slicing())
        ovl_dims = N.take(
            N.array(overlay_img.shape) * overlay_img.dr,
            overlay_img.slicing())
        if not (img_dims == ovl_dims).all():
            print("Overlay failed because physical dimensions do not align...")
            print("base image dimensions (zyx): [%3.1f %3.1f %3.1f] (mm)"
                  % tuple(img_dims))
            print("overlay image dimensions (zyx): [%3.1f %3.1f %3.1f] (mm)"
                  % tuple(ovl_dims))
            return
        # Remove any previous overlay so its layers (and toolbox window)
        # are not left behind underneath the new ones.
        self.killoverlay(None)
        self.overlay_img = overlay_img
        # setNorm() only builds self.overlay_norm once overlay_img is set
        self.setNorm()
        (ax, cor, sag) = self.overlay_img.slicing()
        self.ax_overlay = OverLay(self.ax_plot,
                                  ax,
                                  norm=self.overlay_norm,
                                  interpolation=self.ax_plot.interpolation)
        self.cor_overlay = OverLay(self.cor_plot,
                                   cor,
                                   norm=self.overlay_norm,
                                   interpolation=self.cor_plot.interpolation)
        self.sag_overlay = OverLay(self.sag_plot,
                                   sag,
                                   norm=self.overlay_norm,
                                   interpolation=self.sag_plot.interpolation)
        self.overlays = [self.ax_overlay, self.cor_overlay, self.sag_overlay]
        self.updateSlices(self.image.zyx_coords(),
                          sliceplots=self.overlays,
                          image=self.overlay_img,
                          norm=self.overlay_norm)

    #-------------------------------------------------------------------------
    def launch_overlay_toolbox(self, action):
        """Open the overlay adjustment window, or raise it if already open.

        Does nothing when no overlay is loaded.
        """
        if self.overlay_img is not None:
            if not hasattr(self, "overlay_tools") or not self.overlay_tools:
                self.overlay_tools = OverlayToolWin(self.overlays, self)
            else:
                self.overlay_tools.present()

    #-------------------------------------------------------------------------
    def killoverlay(self, action):
        """Remove the current overlay, if any, and its toolbox window."""
        if self.overlay_img is not None:
            for overlay in self.overlays:
                overlay.removeSelf()
            if hasattr(self, "overlay_tools") and self.overlay_tools:
                self.overlay_tools.destroy()
                del self.overlay_tools
            self.overlay_img = None
            self.overlay_norm = None

    #-------------------------------------------------------------------------
    def load_new_image(self, action):
        """Prompt for a new image and rebuild this window around it.

        The file is read as NIfTI, falling back to Analyze if that fails.
        Any overlay is dropped first.
        """
        image_filter = gtk.FileFilter()
        image_filter.add_pattern("*.hdr")
        image_filter.add_pattern("*.nii")
        image_filter.set_name("Recon Images")
        fname = ask_fname(self,
                          "Choose file to open...",
                          action="open",
                          filter=image_filter)
        if not fname:
            return
        try:
            img = readImage(fname, "nifti")
        except Exception:
            # not readable as NIfTI; fall back to the Analyze reader
            img = readImage(fname, "analyze")
        self.killoverlay(None)
        # NOTE(review): no parent is passed, so __init__ takes its
        # parent=None path (another destroy->main_quit handler and a nested
        # gtk.main()) -- confirm this is intended.
        self.__init__(img)

    #-------------------------------------------------------------------------
    def launch_sliceview(self, action):
        """Open self.image in sliceview with this window as parent."""
        from recon.visualization.sliceview import sliceview
        sliceview(self.image, parent=self)

    #-------------------------------------------------------------------------
    def launch_recon_gui(self, action):
        """Open recon_gui on the original image object (self.img_obj)."""
        from recon.visualization.recon_gui import recon_gui
        recon_gui(image=self.img_obj, parent=self)

    #-------------------------------------------------------------------------
    def _plotter_died(self):
        """Child-window destroy hook; nothing needs cleaning up here."""
        pass

    #-------------------------------------------------------------------------
    def _create_action_group(self):
        """Return the gtk.ActionGroup backing the File and Tools menus."""
        entries = (
            ("FileMenu", None, "_File"),
            ("Open Image", gtk.STOCK_OPEN, "_Open Image", "<control>O",
             "Opens and plots a new image", self.load_new_image),
            ("Quit", gtk.STOCK_QUIT, "_Quit", "<control>Q", "Quits",
             lambda action: self.destroy()),
            ("ToolsMenu", None, "_Tools"),
            ("Load Overlay", None, "_Load Overlay", "",
             "Load an image to overlay", self.initoverlay),
            ("Unload Overlay", None, "_Unload Overlay", "",
             "Unload the overlay", self.killoverlay),
            ("Overlay Adjustment Toolbox", None, "_Overlay Adjustment Toolbox",
             "", "launch overlay toolbox", self.launch_overlay_toolbox),
            ("Plot In Sliceview", None, "_Plot In Sliceview", "",
             "opens image in sliceview", self.launch_sliceview),
            ("Run Recon GUI", None, "_Run Recon GUI", "", "opens gui",
             self.launch_recon_gui),
        )

        action_group = gtk.ActionGroup("WindowActions")
        action_group.add_actions(entries)
        return action_group
예제 #7
0
class spmclone (gtk.Window):
    
    def __init__(self, image, title="spmclone", parent=None):
        """Build the three-plane (axial/coronal/sagittal) viewer for ``image``.

        image  -- image object to display; when it has more than three
                  dimensions only ``image.subImage(0)`` is shown.  The object
                  itself is kept unchanged in ``self.img_obj``.
        title  -- window title.
        parent -- optional parent window.  When given, this window uses the
                  parent's screen and calls ``parent._plotter_died()`` on
                  destroy.  When omitted, destroy quits the gtk main loop and
                  this constructor runs ``gtk.main()`` itself, so it only
                  returns after the window is closed.
        """
        gtk.Window.__init__(self)
        # load_new_image() re-runs __init__ on an existing window, so throw
        # away any widget tree left over from the previous image.
        children = self.get_children()
        if children:
            self.remove(children[0])
            self.hide_all()
        table = gtk.Table(4, 2)

        # keep the image object with its original class identification
        # (self.image below is a SlicerImage wrapper around it)
        self.img_obj = image
        # Explicit if/else: the previous ``cond and A or B`` form fell
        # through to B whenever the SlicerImage A happened to be falsy.
        if len(image.shape) > 3:
            self.image = SlicerImage(image.subImage(0))
        else:
            self.image = SlicerImage(image)
        if iscomplex(self.image[:]):
            # NOTE(review): abs_xform presumably displays the magnitude of
            # complex data
            self.image.prefilter = abs_xform

        self.overlay_img = None
        self.zoom = 0
        self.slice_patches = None
        self.setNorm()
        # length of each array dimension: dr (presumably voxel size, mm)
        # times the number of voxels
        self.dimlengths = self.image.dr * N.array(self.image.shape)
        zyx_lim = self.image.extents()
        # I'm using [ax,cor,sag] such that this list informs each
        # sliceplot what dimension it slices in the image array
        # (eg for coronal data, the coronal plot slices the 0th dim)
        ax, cor, sag = self.image.slicing()
        origin = [0,0,0]
        # Make the ortho plots ---
        self.ax_plot=SlicePlot(self.image.data_xform(ax, origin), 0, 0, ax,
                               norm=self.norm,
                               extent=(zyx_lim[2] + zyx_lim[1]))

        self.cor_plot=SlicePlot(self.image.data_xform(cor, origin), 0, 0, cor,
                                norm=self.norm,
                                extent=(zyx_lim[2] + zyx_lim[0]))

        self.sag_plot=SlicePlot(self.image.data_xform(sag, origin), 0, 0, sag,
                                norm=self.norm,
                                extent=(zyx_lim[1] + zyx_lim[0]))

        # Although it doesn't matter 99% of the time, this list is
        # expected to be ordered this way (updateCrosshairs relies on it)
        self.sliceplots = [self.ax_plot, self.cor_plot, self.sag_plot]

        # menu bar
        merge = gtk.UIManager()
        merge.insert_action_group(self._create_action_group(), 0)
        merge.add_ui_from_string(ui_info)
        self.menubar = merge.get_widget("/MenuBar")

        # layout of the 4x2 table:
        #   row 0: menubar (both columns)
        #   row 1: coronal | sagittal
        #   row 2: axial   | (empty)
        #   row 3: status  | display options
        table.attach(self.menubar, 0, 2, 0, 1)
        self.menubar.set_size_request(600,30)
        table.attach(self.cor_plot, 0, 1, 1, 2)
        self.cor_plot.set_size_request(250,250)
        table.attach(self.sag_plot, 1, 2, 1, 2)
        self.sag_plot.set_size_request(250,250)
        table.attach(self.ax_plot, 0, 1, 2, 3)
        self.ax_plot.set_size_request(250,250)

        self.displaybox = DisplayInfo(self.image)
        self.displaybox.attach_toggle(self.crosshair_hider)
        self.displaybox.attach_imginterp(self.interp_handler)
        self.displaybox.attach_imgframe(self.zoom_handler)
        self.displaybox.attach_imgspace(self.rediculous_handler)
        table.attach(self.displaybox, 1, 2, 3, 4)
        self.displaybox.set_size_request(300,300)
        self.statusbox = DisplayStatus(tuple(self.image.vox_coords.tolist()),
                                       tuple(self.image.zyx_coords().tolist()),
                                       self.image[tuple(self.image.vox_coords)])
        table.attach(self.statusbox, 0, 1, 3, 4)
        self.statusbox.set_size_request(300,300)
        self.connect("configure_event", self.resize_handler)
        self.connect_crosshair_id = []
        self.connectCrosshairEvents()

        self.set_data("ui-manager", merge)
        self.add_accel_group(merge.get_accel_group())
        self.set_default_size(600,730)
        self.set_border_width(3)
        self.set_title(title)
        self.add(table)
        self.show_all()
        if parent:
            self.set_screen(parent.get_screen())
            self.destroy_handle = self.connect('destroy',
                                               lambda x: parent._plotter_died())
        else:
            # standalone: this window owns the gtk main loop
            self.connect("destroy", lambda x: gtk.main_quit())
            gtk.main()

    #-------------------------------------------------------------------------
    def setNorm(self):
        """sets the whitepoint and blackpoint (uses raw data, not scaled)

        self.norm maps the 1st..99th percentile of the base image's values
        onto the colour range.  If an overlay is loaded, self.overlay_norm
        is set the same way from the overlay data (via P.prctile).
        """
        vals = N.sort(self.image[:].flatten())
        npts = vals.shape[0]
        # Clamp the percentile positions: with 50 or fewer values
        # round(.99*npts) == npts, which indexed one past the end.
        last = npts - 1
        p01 = vals[min(int(round(.01*npts)), last)]
        p99 = vals[min(int(round(.99*npts)), last)]
        self.norm = P.normalize(vmin = p01, vmax = p99)
        # compare against None, as killoverlay/launch_overlay_toolbox do,
        # instead of relying on the image object's truth value
        if getattr(self, "overlay_img", None) is not None:
            p01 = P.prctile(self.overlay_img[:], 1.0)
            p99 = P.prctile(self.overlay_img[:], 99.)
            self.overlay_norm = P.normalize(vmin = p01, vmax = p99)

    #-------------------------------------------------------------------------
    def xy(self, slice_idx):
        """Return the (x, y) zyx-axis indices drawn for slice dimension ``slice_idx``.

        The axial plane shows (x, y) = (2, 1), the coronal plane (x, z) =
        (2, 0) and the sagittal plane (y, z) = (1, 0).
        """
        ax_dim, cor_dim, sag_dim = self.image.slicing()
        axis_pairs = {ax_dim: (2, 1), cor_dim: (2, 0), sag_dim: (1, 0)}
        horiz, vert = axis_pairs.get(slice_idx)
        return horiz, vert
    #-------------------------------------------------------------------------
    def updateSlices(self, zyx, sliceplots=None, image=None, norm=None):
        """Redraw sliceplots with the planes of ``image`` passing through ``zyx``.

        zyx        -- point in zyx space the displayed planes pass through.
        sliceplots -- plots (or overlays) to refresh; all three base plots
                      when omitted or empty.
        image      -- SlicerImage to draw from; ``self.image`` when omitted.
        norm       -- colour normalisation; ``self.norm`` when omitted.
        """
        if not sliceplots:
            sliceplots = self.sliceplots
        # Explicit None checks: an image object need not have a meaningful
        # (or even well-defined) truth value.
        if image is None:
            image = self.image
        if norm is None:
            norm = self.norm
        for sliceplot in sliceplots:
            idx = sliceplot.slice_idx
            sliceplot.setData(image.data_xform(idx, zyx), norm=norm)
            if self.slice_patches is not None:
                # NOTE(review): the patch lookup always uses the base image's
                # voxel coordinates, even when drawing an overlay image.
                p_idx = int(self.image.vox_coords[idx])
                sliceplot.showPatches(self.slice_patches[idx][p_idx])

    #-------------------------------------------------------------------------
    def updateCrosshairs(self):
        """Move every sliceplot's crosshairs to the current cursor position.

        Relies on ``self.sliceplots`` being ordered [axial, coronal,
        sagittal]: plot ``s`` drops zyx component ``s`` and uses the
        remaining two as (up-down, left-right).
        """
        # the cursor does not move while we redraw, so read it once
        zyx = self.image.zyx_coords().tolist()
        for s, sliceplot in enumerate(self.sliceplots):
            ud, lr = zyx[:s] + zyx[s+1:]
            sliceplot.setCrosshairs(lr, ud)

    #-------------------------------------------------------------------------
    def externalUpdate(self, new_img):
        """Swap in ``new_img`` as the displayed image, keeping the cursor voxel.

        Only the first sub-image of a >3D image is shown.  The colour
        normalisation and axes layout are recomputed and all three planes
        are redrawn at the preserved cursor position.
        """
        cursor_vox = self.image.vox_coords
        source = new_img.subImage(0) if len(new_img.shape) > 3 else new_img
        self.image = SlicerImage(source)
        self.image.vox_coords = cursor_vox
        if iscomplex(self.image[:]):
            self.image.prefilter = abs_xform
        self.setNorm()
        self.setUpAxesSize()
        self.updateSlices(self.image.zyx_coords())
        
    #-------------------------------------------------------------------------
    def setUpAxesSize(self):
        """Scale the axes appropriately for the image dimensions

        The last entry of ``self.dimlengths`` is the reference length and
        is drawn across 85% of the smaller canvas side of the plot that
        slices array dimension 0.  Every plot's axes are then sized in
        proportion to its two in-plane lengths and centred in its canvas.
        """
        ref_size = self.dimlengths[-1]
        slicing = self.image.slicing()
        idx = slicing.index(0)
        xy_imgsize = .85 * min(*self.sliceplots[idx].get_width_height())
        for sliceplot in self.sliceplots:
            ax = sliceplot.getAxes()
            s_idx = sliceplot.slice_idx
            # Drop the sliced dimension *by position*.  list.remove(value)
            # dropped the first equal length instead, which reordered the
            # in-plane pair whenever two dimensions had the same length.
            in_plane = self.dimlengths.tolist()
            del in_plane[s_idx]
            if self.image.is_xpose(s_idx):
                in_plane.reverse()
            slice_y, slice_x = in_plane
            height = xy_imgsize*slice_y/ref_size
            width = xy_imgsize*slice_x/ref_size
            canvas_x, canvas_y = sliceplot.get_width_height()
            # axes size and lower-left corner as fractions of the canvas,
            # centring the axes
            w = width/canvas_x
            h = height/canvas_y
            l = (1.0 - w)/2.
            b = (1.0 - h)/2.
            ax.set_position([l,b,w,h])
            sliceplot.draw_idle()

    #-------------------------------------------------------------------------
    def connectCrosshairEvents(self, mode="enable"):
        """Connect (mode="enable") or disconnect the mouse handlers driving the crosshairs.

        For each sliceplot, in ``self.sliceplots`` order, three matplotlib
        callback ids (press, release, motion) are appended to
        ``self.connect_crosshair_id``.  Any other ``mode`` disconnects them
        and hides the crosshairs.
        """
        handlers = (("button_press_event", self.SPMouseDown),
                    ("button_release_event", self.SPMouseUp),
                    ("motion_notify_event", self.SPMouseMotion))
        if mode == "enable":
            self._dragging = False
            for sliceplot in self.sliceplots:
                for event_name, callback in handlers:
                    self.connect_crosshair_id.append(
                        sliceplot.mpl_connect(event_name, callback))
                sliceplot.toggleCrosshairs(mode=True)
        elif self.connect_crosshair_id:
            n_ids = len(handlers)
            for n, sliceplot in enumerate(self.sliceplots):
                # plot n owns ids [3n, 3n+3); indexing with n, n+1, n+2
                # disconnected the wrong callbacks and leaked the rest
                start = n * n_ids
                for cid in self.connect_crosshair_id[start:start + n_ids]:
                    sliceplot.mpl_disconnect(cid)
                sliceplot.toggleCrosshairs(mode=False)
            self.connect_crosshair_id = []

    #-------------------------------------------------------------------------
    def SPMouseDown(self, event):
        """Begin a crosshair drag on mouse press and jump to the pressed point."""
        # forget the last position so updateCoords() never skips a new press
        self._mouse_ud = None
        self._mouse_lr = None
        # matplotlib sets event.inaxes to None outside any axes, so this is
        # only truthy for presses inside a plot
        self._dragging = event.inaxes
        self.updateCoords(event)

    #-------------------------------------------------------------------------
    def SPMouseUp(self, event):
        """Finish a crosshair drag, applying the release position."""
        if not self._dragging:
            # release without a matching in-axes press
            return
        self.updateCoords(event)
        self._dragging = False

    #-------------------------------------------------------------------------
    def SPMouseMotion(self, event):
        """Track the pointer with the crosshairs while a drag is in progress."""
        if not self._dragging:
            return
        self.updateCoords(event)

    #-------------------------------------------------------------------------
    def updateCoords(self, event):
        """Update all the necessary sliceplot data based on a mouse click.

        The tasks here are:
        1 find zyx_coords of the mouse event and translate to vox_coords
        2 update the transverse sliceplots based on vox_coords
        2a update the transverse overlays if present
        3 update crosshairs on all sliceplots
        4 update voxel space, zyx space and intensity texts
        Events at the same position as the previous one, and events for
        which getEventCoords returns None, are ignored.
        """
        sliceplot = event.canvas
        # using terminology up-down, left-right to avoid confusion with y,x
        ud,lr = sliceplot.getEventCoords(event)
        if self._mouse_lr == lr and self._mouse_ud == ud:
            return
        if lr is None or ud is None:
            return
        self._mouse_lr, self._mouse_ud = (lr, ud)
        # trans_sliceplots are the transverse plots that get
        # updated from where the mouse clicked
        (ax, cor, sag) = self.image.slicing()
        trans_sliceplots = {
            self.sliceplots[ax]: (self.sliceplots[sag], self.sliceplots[cor]),
            self.sliceplots[cor]: (self.sliceplots[sag], self.sliceplots[ax]),
            self.sliceplots[sag]: (self.sliceplots[cor], self.sliceplots[ax]),
            }.get(sliceplot)
        trans_idx = (trans_sliceplots[0].slice_idx,
                     trans_sliceplots[1].slice_idx)

        # where do left-right and up-down cut across in zyx space?
        trans_ax = self.image.transverse_slicing(sliceplot.slice_idx)
        zyx_clicked = self.image.zyx_coords()
        zyx_clicked[trans_ax[0]] = lr
        zyx_clicked[trans_ax[1]] = ud
        vox = self.image.zyx2vox(zyx_clicked)

        self.image.vox_coords[trans_idx[0]] = vox[trans_idx[0]]
        self.image.vox_coords[trans_idx[1]] = vox[trans_idx[1]]
        self.updateSlices(zyx_clicked, sliceplots=trans_sliceplots)

        # explicit None test, consistent with killoverlay(), rather than
        # relying on the overlay image object's truth value
        if self.overlay_img is not None:
            # basically do the same thing over again wrt the overlay dims
            (ax_o, cor_o, sag_o) = self.overlay_img.slicing()
            trans_overlays = {
                self.sliceplots[ax]: (self.overlays[sag_o], self.overlays[cor_o]),
                self.sliceplots[cor]: (self.overlays[sag_o], self.overlays[ax_o]),
                self.sliceplots[sag]: (self.overlays[cor_o], self.overlays[ax_o]),
                }.get(sliceplot)
            trans_idx = (trans_overlays[0].slice_idx,
                         trans_overlays[1].slice_idx)
            vox = self.overlay_img.zyx2vox(zyx_clicked)
            self.overlay_img.vox_coords[trans_idx[0]] = vox[trans_idx[0]]
            self.overlay_img.vox_coords[trans_idx[1]] = vox[trans_idx[1]]
            self.updateSlices(zyx_clicked, sliceplots=trans_overlays,
                              image=self.overlay_img,
                              norm=self.overlay_norm)

        self.updateCrosshairs()
        # make text to update the statusbox label's
        self.statusbox.set_vox_text(self.image.vox_coords[::-1].tolist())
        self.statusbox.set_zyx_text(self.image.zyx_coords()[::-1].tolist())
        point = tuple(self.image.vox_coords)
        self.statusbox.set_intensity_text(self.image[point])

    #-------------------------------------------------------------------------
    def rediculous_handler(self, cbox):
        """Callback attached via DisplayInfo.attach_imgspace; it does nothing useful.

        It only prints a notice; ``cbox`` is ignored.
        """
        print("You've hit a useless button!")

    #-------------------------------------------------------------------------
    def interp_handler(self, cbox):
        """Apply the interpolation method selected in ``cbox`` to every sliceplot."""
        method = interp_lookup[cbox.get_active()]
        for plot in self.sliceplots:
            plot.setInterpo(method)

    #-------------------------------------------------------------------------
    def crosshair_hider(self, toggle):
        """Set crosshair visibility on all sliceplots from ``toggle``'s state.

        Passes mode=True (the value connectCrosshairEvents uses when
        enabling crosshairs) while the toggle is inactive, mode=False while
        it is active.
        """
        crosshair_mode = not toggle.get_active()
        for plot in self.sliceplots:
            plot.toggleCrosshairs(mode=crosshair_mode)

    #-------------------------------------------------------------------------
    def zoom_handler(self, cbox):
        """Changes the view range of the sliceplots to be NxN mm

        Combo entry 0 (or no selection) shows the full extent of the image;
        entries 1-5 show a 160, 80, 40, 20 or 10 mm window centred on the
        current cursor position.
        """
        widths = dict(enumerate((0, 160, 80, 40, 20, 10)))
        self.zoom = widths.get(cbox.get_active(), 0)
        center = self.image.zyx_coords()
        if not self.zoom:
            self.dimlengths = self.image.dr * N.array(self.image.shape)
            limits = self.image.extents()
        else:
            half = N.array([self.zoom/2.]*3)
            limits = list(zip(center - half, center + half))
            self.dimlengths = N.array([self.zoom]*3)

        for plot in self.sliceplots:
            x_dim, y_dim = self.image.transverse_slicing(plot.slice_idx)
            plot.setXYlim(limits[x_dim], limits[y_dim])
            plot.setCrosshairs(center[x_dim], center[y_dim])

        self.updateSlices(self.image.zyx_coords())
        self.setUpAxesSize()
    #-------------------------------------------------------------------------
    def resize_handler(self, window, event):
        """Refit the plot axes whenever the window gets a configure event."""
        return self.setUpAxesSize()
    #-------------------------------------------------------------------------
    def initoverlay(self, action):
        image_filter = gtk.FileFilter()
        image_filter.add_pattern("*.hdr")
        image_filter.add_pattern("*.nii")
        image_filter.set_name("Recon Images")
        fname = ask_fname(self, "Choose file to overlay...", action="open",
                          filter=image_filter)
        if not fname:
            return
        img = readImage(fname, vrange=(0,0))
        self.overlay_img = SlicerImage(img)
            
        img_dims = N.take(N.array(self.image.shape) * self.image.dr,
                          self.image.slicing())
        ovl_dims = N.take(N.array(self.overlay_img.shape) * self.overlay_img.dr,
                          self.overlay_img.slicing())
        if not (img_dims == ovl_dims).all():
            print img_dims, ovl_dims
            print "Overlay failed because physical dimensions do not align..."
            print "base image dimensions (zyx): [%3.1f %3.1f %3.1f] (mm)"%tuple(img_dims)
            print "overlay image dimenensions (zyx: [%3.1f %3.1f %3.1f] (mm)"%tuple(ovl_dims)
            return
        self.setNorm()
        (ax, cor, sag) = self.overlay_img.slicing()
        self.ax_overlay = OverLay(self.ax_plot, ax,
                                  norm=self.overlay_norm,
                                  interpolation=self.ax_plot.interpolation)
        self.cor_overlay = OverLay(self.cor_plot, cor,
                                   norm=self.overlay_norm,
                                   interpolation=self.cor_plot.interpolation)
        self.sag_overlay = OverLay(self.sag_plot, sag,
                                   norm=self.overlay_norm,
                                   interpolation=self.sag_plot.interpolation)
        self.overlays = [self.ax_overlay, self.cor_overlay, self.sag_overlay]
        self.updateSlices(self.image.zyx_coords(),
                          sliceplots=self.overlays,
                          image=self.overlay_img,
                          norm=self.overlay_norm)

    #-------------------------------------------------------------------------
    def launch_overlay_toolbox(self, action):
        """Open, or raise if already open, the overlay adjustment toolbox.

        Does nothing when no overlay is loaded.
        """
        if self.overlay_img is None:
            return
        tools = getattr(self, "overlay_tools", None)
        if tools:
            tools.present()
        else:
            self.overlay_tools = OverlayToolWin(self.overlays, self)

    #-------------------------------------------------------------------------
    def killoverlay(self, action):
        """Remove the current overlay, its plot artists and its toolbox window.

        A no-op when no overlay is loaded; ``action`` is ignored.
        """
        if self.overlay_img is None:
            return
        for ovl in self.overlays:
            ovl.removeSelf()
        tools = getattr(self, "overlay_tools", None)
        if tools:
            tools.destroy()
            del self.overlay_tools
        self.overlay_img = None
        self.overlay_norm = None

    #-------------------------------------------------------------------------
    def load_new_image(self, action):
        """Ask for an image file and rebuild the viewer around it.

        The file is read as NIfTI first and, if that raises, as ANALYZE.
        Any overlay is removed and ``__init__`` is re-run on this window.
        """
        image_filter = gtk.FileFilter()
        image_filter.add_pattern("*.hdr")
        image_filter.add_pattern("*.nii")
        image_filter.set_name("Recon Images")
        fname = ask_fname(self, "Choose file to open...", action="open",
                          filter=image_filter)
        if not fname:
            return
        try:
            img = readImage(fname, "nifti")
        except Exception:
            # A bare ``except:`` also swallowed KeyboardInterrupt/SystemExit;
            # only fall back to ANALYZE on ordinary read errors.
            img = readImage(fname, "analyze")
        self.killoverlay(None)
        # NOTE(review): __init__ is called without ``parent``, so it connects
        # another destroy -> gtk.main_quit handler and enters a nested
        # gtk.main(); confirm this is intended.
        self.__init__(img)

    #-------------------------------------------------------------------------
    def launch_sliceview(self, action):
        """Show the current SlicerImage in a separate sliceview window.

        ``action`` is the gtk.Action that fired the menu callback; unused.
        """
        from recon.visualization.sliceview import sliceview as open_sliceview
        open_sliceview(self.image, parent=self)
    #-------------------------------------------------------------------------
    def launch_recon_gui(self, action):
        """Open the recon GUI on the original image object kept in ``img_obj``.

        ``action`` is the gtk.Action that fired the menu callback; unused.
        """
        from recon.visualization.recon_gui import recon_gui as open_recon_gui
        open_recon_gui(image=self.img_obj, parent=self)
    #-------------------------------------------------------------------------
    def _plotter_died(self):
        """Hook for child windows opened with parent=self; intentionally a no-op.

        NOTE(review): presumably called from a child's destroy handler, the
        way this class calls ``parent._plotter_died()`` on its own parent.
        """
        return None
    #-------------------------------------------------------------------------
    def _create_action_group(self):
        """Build the gtk.ActionGroup behind the File and Tools menus.

        Entries use the gtk.ActionGroup.add_actions layout
        (name, stock_id, label, accelerator, tooltip, callback); the two
        menu headers only supply (name, stock_id, label).
        """
        file_menu = (
            ("FileMenu", None, "_File"),
            ("Open Image", gtk.STOCK_OPEN, "_Open Image", "<control>O",
             "Opens and plots a new image", self.load_new_image),
            ("Quit", gtk.STOCK_QUIT, "_Quit", "<control>Q", "Quits",
             lambda action: self.destroy()),
        )
        tools_menu = (
            ("ToolsMenu", None, "_Tools"),
            ("Load Overlay", None, "_Load Overlay", "",
             "Load an image to overlay", self.initoverlay),
            ("Unload Overlay", None, "_Unload Overlay", "",
             "Unload the overlay", self.killoverlay),
            ("Overlay Adjustment Toolbox", None, "_Overlay Adjustment Toolbox",
             "", "launch overlay toolbox", self.launch_overlay_toolbox),
            ("Plot In Sliceview", None, "_Plot In Sliceview", "",
             "opens image in sliceview", self.launch_sliceview),
            ("Run Recon GUI", None, "_Run Recon GUI", "", "opens gui",
             self.launch_recon_gui),
        )

        group = gtk.ActionGroup("WindowActions")
        group.add_actions(file_menu + tools_menu)
        return group