Example #1
0
    def __init__(self, *args, **kwargs):
        """Build the window: histogram canvas, toolbar and two sliders.

        All positional and keyword arguments are forwarded unchanged to
        ``clickpoints.Addon.__init__``.
        """
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # window title and the vertical layout that stacks every widget
        self.setWindowTitle("Fine Contrast Adjust - ClickPoints")
        self.layout = QtWidgets.QVBoxLayout(self)

        # matplotlib canvas followed by its navigation toolbar
        self.plot = MatplotlibWidget(self)
        self.layout.addWidget(self.plot)
        toolbar = NavigationToolbar(self.plot, self)
        self.layout.addWidget(toolbar)

        # three initially empty line artists: the histogram curve and one
        # marker line per slider (red and magenta)
        axes = self.plot.axes
        (self.hist_plot,) = axes.plot([], [])
        (self.vline1,) = axes.plot([], [], color="r")
        (self.vline2,) = axes.plot([], [], color="m")

        # two horizontal sliders, each connected to its own handler
        for attribute, handler in (("slider", self.setValue),
                                   ("slider2", self.setValue2)):
            widget = QtWidgets.QSlider()
            widget.setOrientation(QtCore.Qt.Horizontal)
            self.layout.addWidget(widget)
            setattr(self, attribute, widget)
            widget.valueChanged.connect(handler)
Example #2
0
class Addon(clickpoints.Addon):
    """ClickPoints add-on window with an intensity histogram and two sliders.

    The sliders forward their values to the ``GammaCorrection`` module of the
    ClickPoints window (``updateBrightnes`` / ``updateContrast``) and each one
    moves a vertical marker line in the histogram plot.
    """

    def __init__(self, *args, **kwargs):
        """Create the window: histogram plot, toolbar and the two sliders.

        All arguments are forwarded unchanged to ``clickpoints.Addon.__init__``.
        """
        clickpoints.Addon.__init__(self, *args, **kwargs)
        # set the title and layout
        self.setWindowTitle("Fine Contrast Adjust - ClickPoints")
        self.layout = QtWidgets.QVBoxLayout(self)

        # add a plot widget
        self.plot = MatplotlibWidget(self)
        self.layout.addWidget(self.plot)
        self.layout.addWidget(NavigationToolbar(self.plot, self))

        # empty artists: the histogram curve and one marker line per slider
        self.hist_plot, = self.plot.axes.plot([], [])
        self.vline1, = self.plot.axes.plot([], [], color="r")
        self.vline2, = self.plot.axes.plot([], [], color="m")

        # first slider -> setValue
        self.slider = QtWidgets.QSlider()
        self.slider.setOrientation(QtCore.Qt.Horizontal)
        self.layout.addWidget(self.slider)
        self.slider.valueChanged.connect(self.setValue)

        # second slider -> setValue2
        self.slider2 = QtWidgets.QSlider()
        self.slider2.setOrientation(QtCore.Qt.Horizontal)
        self.layout.addWidget(self.slider2)
        self.slider2.valueChanged.connect(self.setValue2)

    def setValue(self, value):
        """Pass the first slider's value on and move the red marker line."""
        # "updateBrightnes" is the method name as called by the original code
        self.cp.window.GetModule("GammaCorrection").updateBrightnes(value)
        self.vline1.set_data([value, value], [0, 1])
        self.plot.draw()

    def setValue2(self, value):
        """Pass the second slider's value on and move the magenta marker line."""
        self.cp.window.GetModule("GammaCorrection").updateContrast(value)
        self.vline2.set_data([value, value], [0, 1])
        self.plot.draw()

    def frameChangedEvent(self):
        """Refresh slider ranges and redraw the histogram of the current image."""
        # look the module up once instead of once per use
        max_value = self.cp.window.GetModule("GammaCorrection").max_value

        # keep both sliders in step with the current maximum value
        self.slider.setRange(0, max_value)
        self.slider2.setRange(0, max_value)

        im = self.cp.getImage().data
        # every 4th pixel in both directions is enough for the histogram
        hist, bins = np.histogram(im[::4, ::4].ravel(),
                                  bins=np.linspace(0, max_value + 1, 256),
                                  density=True)
        self.hist_plot.set_data(bins[:-1], hist)

        # with no samples the density is NaN; matplotlib rejects NaN limits,
        # so only rescale the y axis when there is a usable peak
        peak = np.max(hist)
        if np.isfinite(peak) and peak > 0:
            self.plot.axes.set_ylim(0, peak * 1.2)
        self.plot.axes.set_xlim(0, max_value)
        self.plot.draw()

    def buttonPressedEvent(self):
        """Update ranges and histogram, then show the window."""
        # frameChangedEvent sets the range of both sliders
        self.frameChangedEvent()
        self.show()
Example #3
0
    def __init__(self, *args, **kwargs):
        """Assemble the kymograph window and register options and marker types.

        Positional and keyword arguments are handed on unchanged to
        ``clickpoints.Addon.__init__``.
        """
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # window title and the vertical layout holding all widgets
        self.setWindowTitle("Kymograph - ClickPoints")
        self.layout = QtWidgets.QVBoxLayout(self)

        def add_spin_option(key, display_name, default, value_type, tooltip,
                            label, is_float):
            # register the option, create a spin box that shows its current
            # value, and link the two
            self.addOption(key=key,
                           display_name=display_name,
                           default=default,
                           value_type=value_type,
                           tooltip=tooltip)
            spin_box = AddQSpinBox(self.layout,
                                   label,
                                   value=self.getOption(key),
                                   float=is_float)
            self.linkOption(key, spin_box)
            return spin_box

        # number of frames that go into the kymograph
        self.input_count = add_spin_option(
            key="frames",
            display_name="Frames",
            default=50,
            value_type="int",
            tooltip="How many images to use for the kymograph.",
            label="Frames:",
            is_float=False)

        # width in pixels of the slice cut along each line
        self.input_width = add_spin_option(
            key="width",
            display_name="Width",
            default=1,
            value_type="int",
            tooltip="The width of the slice to cut from the image.",
            label="Width:",
            is_float=False)

        # scaling of the length axis
        self.input_scale1 = add_spin_option(
            key="scaleLength",
            display_name="Scale Length",
            default=1,
            value_type="float",
            tooltip="What is distance a pixel represents.",
            label="Scale Length:",
            is_float=True)

        # scaling of the time axis
        self.input_scale2 = add_spin_option(
            key="scaleTime",
            display_name="Scale Time",
            default=1,
            value_type="float",
            tooltip="What is the time difference between two images.",
            label="Scale Time:",
            is_float=True)

        # colormap selection: "None" followed by all matplotlib colormaps
        self.addOption(key="colormap",
                       display_name="Colormap",
                       default="None",
                       value_type="string",
                       tooltip="The colormap to use for the kymograph.")
        colormap_names = ["None", *plt.colormaps()]
        self.input_colormap = AddQComboBox(
            self.layout,
            "Colormap:",
            selectedValue=self.getOption("colormap"),
            values=colormap_names)
        self.input_colormap.setEditable(True)
        self.linkOption("colormap", self.input_colormap)

        # single-column table listing the line markers
        table = QtWidgets.QTableWidget(0, 1, self)
        self.tableWidget = table
        self.layout.addWidget(table)
        self.row_headers = ["Line Length"]
        table.setHorizontalHeaderLabels(self.row_headers)
        table.setMinimumHeight(180)
        self.setMinimumWidth(500)
        table.setCurrentCell(0, 0)
        table.cellClicked.connect(self.cellSelected)

        # marker types: the kymograph line and the "kymograph_end" marker
        self.my_type = self.db.setMarkerType("kymograph",
                                             "#ef7fff",
                                             self.db.TYPE_Line,
                                             text="#$marker_id")
        self.my_type2 = self.db.setMarkerType("kymograph_end",
                                              "#df00ff",
                                              self.db.TYPE_Normal)
        self.cp.reloadTypes()

        # plot canvas with toolbar; mouse clicks go to button_press_callback
        self.plot = MatplotlibWidget(self)
        self.layout.addWidget(self.plot)
        toolbar = NavigationToolbar(self.plot, self)
        self.layout.addWidget(toolbar)
        self.plot.figure.canvas.mpl_connect('button_press_event',
                                            self.button_press_callback)

        # row with the two export buttons
        button_row = QtWidgets.QHBoxLayout()
        self.button_export = QtWidgets.QPushButton("Export")
        self.button_export.clicked.connect(self.export)
        button_row.addWidget(self.button_export)
        self.button_export2 = QtWidgets.QPushButton("Export All")
        self.button_export2.clicked.connect(self.export2)
        button_row.addWidget(self.button_export2)
        self.layout.addLayout(button_row)

        # progress bar
        self.progressbar = QtWidgets.QProgressBar()
        self.layout.addWidget(self.progressbar)

        # route the two signals to their handlers
        self.signal_update_plot.connect(self.updatePlotImageEvent)
        self.signal_plot_finished.connect(self.plotFinishedEvent)

        # fill the table; no row is selected yet
        self.updateTable()
        self.selected = None
Example #4
0
class Addon(clickpoints.Addon):
    """ClickPoints add-on that builds kymographs from line markers."""
    # emitted by run() after each processed frame
    signal_update_plot = QtCore.Signal()
    # emitted by run() when it leaves its loop (last frame reached or stopped)
    signal_plot_finished = QtCore.Signal()
    # image artist returned by imshow in updatePlot(); None until first plot
    image_plot = None
    # time.time() stamp of the last redraw, used to throttle plot updates
    last_update = 0
    # True while updateTable() is running
    updating = False
    # True while the "Export All" sequence started by export2() is active
    exporting = False
    # table row currently handled by the "Export All" sequence
    exporting_index = 0

    def __init__(self, *args, **kwargs):
        """Assemble the kymograph window and register options and marker types.

        Positional and keyword arguments are handed on unchanged to
        ``clickpoints.Addon.__init__``.
        """
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # window title and the vertical layout holding all widgets
        self.setWindowTitle("Kymograph - ClickPoints")
        self.layout = QtWidgets.QVBoxLayout(self)

        def add_spin_option(key, display_name, default, value_type, tooltip,
                            label, is_float):
            # register the option, create a spin box that shows its current
            # value, and link the two
            self.addOption(key=key,
                           display_name=display_name,
                           default=default,
                           value_type=value_type,
                           tooltip=tooltip)
            spin_box = AddQSpinBox(self.layout,
                                   label,
                                   value=self.getOption(key),
                                   float=is_float)
            self.linkOption(key, spin_box)
            return spin_box

        # number of frames that go into the kymograph
        self.input_count = add_spin_option(
            key="frames",
            display_name="Frames",
            default=50,
            value_type="int",
            tooltip="How many images to use for the kymograph.",
            label="Frames:",
            is_float=False)

        # width in pixels of the slice cut along each line
        self.input_width = add_spin_option(
            key="width",
            display_name="Width",
            default=1,
            value_type="int",
            tooltip="The width of the slice to cut from the image.",
            label="Width:",
            is_float=False)

        # scaling of the length axis
        self.input_scale1 = add_spin_option(
            key="scaleLength",
            display_name="Scale Length",
            default=1,
            value_type="float",
            tooltip="What is distance a pixel represents.",
            label="Scale Length:",
            is_float=True)

        # scaling of the time axis
        self.input_scale2 = add_spin_option(
            key="scaleTime",
            display_name="Scale Time",
            default=1,
            value_type="float",
            tooltip="What is the time difference between two images.",
            label="Scale Time:",
            is_float=True)

        # colormap selection: "None" followed by all matplotlib colormaps
        self.addOption(key="colormap",
                       display_name="Colormap",
                       default="None",
                       value_type="string",
                       tooltip="The colormap to use for the kymograph.")
        colormap_names = ["None", *plt.colormaps()]
        self.input_colormap = AddQComboBox(
            self.layout,
            "Colormap:",
            selectedValue=self.getOption("colormap"),
            values=colormap_names)
        self.input_colormap.setEditable(True)
        self.linkOption("colormap", self.input_colormap)

        # single-column table listing the line markers
        table = QtWidgets.QTableWidget(0, 1, self)
        self.tableWidget = table
        self.layout.addWidget(table)
        self.row_headers = ["Line Length"]
        table.setHorizontalHeaderLabels(self.row_headers)
        table.setMinimumHeight(180)
        self.setMinimumWidth(500)
        table.setCurrentCell(0, 0)
        table.cellClicked.connect(self.cellSelected)

        # marker types: the kymograph line and the "kymograph_end" marker
        self.my_type = self.db.setMarkerType("kymograph",
                                             "#ef7fff",
                                             self.db.TYPE_Line,
                                             text="#$marker_id")
        self.my_type2 = self.db.setMarkerType("kymograph_end",
                                              "#df00ff",
                                              self.db.TYPE_Normal)
        self.cp.reloadTypes()

        # plot canvas with toolbar; mouse clicks go to button_press_callback
        self.plot = MatplotlibWidget(self)
        self.layout.addWidget(self.plot)
        toolbar = NavigationToolbar(self.plot, self)
        self.layout.addWidget(toolbar)
        self.plot.figure.canvas.mpl_connect('button_press_event',
                                            self.button_press_callback)

        # row with the two export buttons
        button_row = QtWidgets.QHBoxLayout()
        self.button_export = QtWidgets.QPushButton("Export")
        self.button_export.clicked.connect(self.export)
        button_row.addWidget(self.button_export)
        self.button_export2 = QtWidgets.QPushButton("Export All")
        self.button_export2.clicked.connect(self.export2)
        button_row.addWidget(self.button_export2)
        self.layout.addLayout(button_row)

        # progress bar
        self.progressbar = QtWidgets.QProgressBar()
        self.layout.addWidget(self.progressbar)

        # route the two signals to their handlers
        self.signal_update_plot.connect(self.updatePlotImageEvent)
        self.signal_plot_finished.connect(self.plotFinishedEvent)

        # fill the table; no row is selected yet
        self.updateTable()
        self.selected = None

    def button_press_callback(self, event):
        """Jump to the frame and image position under a left click in the plot.

        Clicks with any other mouse button, or outside the axes, are ignored.
        """
        if event.button != 1 or event.inaxes is None:
            return
        # undo the axis scaling to get back to kymograph pixels
        position = event.xdata / self.input_scale1.value()
        frame_offset = event.ydata / self.h / self.input_scale2.value()
        # show the frame that belongs to the clicked row ...
        target_frame = self.bar.image.sort_index + int(frame_offset)
        self.cp.jumpToFrame(target_frame)
        # ... and centre the view on the matching point of the line
        self.cp.centerOn(*self.getLinePoint(self.bar, position))

    def cellSelected(self, row, column):
        """Remember the clicked table row and redraw the plot for it.

        ``column`` is accepted for the signal signature but not used.
        """
        self.selected = row
        self.updatePlot()

    def setTableText(self, row, column, text):
        """Write ``str(text)`` into a table cell or a vertical header.

        ``column == -1`` addresses the vertical header item of ``row``; any
        other value addresses the cell at (``row``, ``column``).  A missing
        item is created first; a newly created item in column 2 is made
        selectable and enabled only.
        """
        table = self.tableWidget
        is_header = column == -1
        cell = (table.verticalHeaderItem(row)
                if is_header else table.item(row, column))
        if cell is None:
            cell = QtWidgets.QTableWidgetItem("")
            if is_header:
                table.setVerticalHeaderItem(row, cell)
            else:
                table.setItem(row, column, cell)
                if column == 2:
                    cell.setFlags(QtCore.Qt.ItemIsSelectable
                                  | QtCore.Qt.ItemIsEnabled)
        cell.setText(str(text))

    def updateTable(self):
        """Reload the kymograph lines from the database and refill the table.

        Rebuilds ``self.bars`` (the line markers of type ``self.my_type``) and
        ``self.bar_dict`` (marker id -> table row), resizes the table to match
        and rewrites every row.  ``self.updating`` is True while this runs and
        is reset even if a row update raises.
        """
        self.updating = True
        try:
            # fetch the markers once and use that single snapshot for the row
            # count, the row contents and the id lookup so they cannot disagree
            self.bars = [bar for bar in self.db.getLines(type=self.my_type)]
            self.bar_dict = {}
            self.tableWidget.setRowCount(len(self.bars))
            self.last_image_id = None
            for idx, bar in enumerate(self.bars):
                self.updateRow(idx)
                self.bar_dict[bar.id] = idx
        finally:
            self.updating = False

    def updateRow(self, idx):
        """Write marker id (header) and line length (column 0) of row ``idx``."""
        marker = self.bars[idx]
        header_text = "#%d" % marker.id
        self.setTableText(idx, -1, header_text)
        self.setTableText(idx, 0, marker.length())

    def getLinePoint(self, line, percentage):
        """Return the (x, y) point at distance ``percentage`` along ``line``.

        Despite its name, ``percentage`` is a distance in the same units as
        the line coordinates, measured from the start point.  With
        ``self.mirror`` set, the y end points are swapped and the distance is
        measured from the other end.
        """
        start_x, end_x = line.x1, line.x2
        start_y, end_y = line.y1, line.y2
        if self.mirror:
            start_y, end_y = end_y, start_y
        dx = end_x - start_x
        dy = end_y - start_y
        length = np.sqrt(dx**2 + dy**2)
        if self.mirror:
            percentage = length - percentage
        point_x = start_x + dx * percentage / length
        point_y = start_y + dy * percentage / length
        return point_x, point_y

    def getLine(self, image, line, height, image_entry=None):
        """Cut a strip of pixels out of ``image`` along ``line``.

        The strip consists of ``self.h`` parallel profiles, spaced one pixel
        apart perpendicular to the line and centred on it.  Every sample is a
        bilinear interpolation of the 2x2 pixel neighbourhood at the sample
        position.

        Args:
            image: 2D array, or 3D array with a trailing channel axis.
            line: object providing the end points ``x1``, ``y1``, ``x2``,
                ``y2``.
            height: unused; the number of profiles is taken from ``self.h``.
            image_entry: optional entry whose ``offset`` (``x`` / ``y``), if
                set, shifts the start point relative to ``self.start_offx`` /
                ``self.start_offy``.

        Returns:
            Array with ``self.h`` rows and ``ceil(line length)`` columns (plus
            the channel axis for 3D input), summed in the dtype of ``image``.
            With ``self.mirror`` the columns are reversed, otherwise the rows.
        """
        x1 = line.x1
        x2 = line.x2
        y1 = line.y1
        y2 = line.y2
        if self.mirror:
            y1, y2 = y2, y1
        w = x2 - x1
        h = y2 - y1
        length = np.sqrt(w**2 + h**2)
        # unit vector perpendicular to the line
        w2 = h / length
        h2 = -w / length

        if image_entry and image_entry.offset:
            offx, offy = image_entry.offset.x, image_entry.offset.y
        else:
            offx, offy = 0, 0
        # compensate the offset of this image relative to the start image
        x1 -= offx - self.start_offx
        y1 -= offy - self.start_offy

        # np.linspace needs an integer sample count; np.ceil returns a float,
        # which current NumPy versions reject with a TypeError
        num_samples = int(np.ceil(length))
        is_color = len(image.shape) == 3

        datas = []
        # j: perpendicular offset of each profile, centred on the line
        for j in np.arange(0, self.h) - self.h / 2. + 0.5:
            data = []
            # i: relative position along the line, from 0 to 1
            for i in np.linspace(0, 1, num_samples):
                x = x1 + w * i + w2 * j
                y = y1 + h * i + h2 * j
                xp = x - np.floor(x)
                yp = y - np.floor(y)
                # 2x2 bilinear weights for the neighbourhood of (x, y)
                v = np.dot(
                    np.array([[1 - yp, yp]]).T, np.array([[1 - xp, xp]]))
                if is_color:
                    data.append(
                        np.sum(image[int(y):int(y) + 2,
                                     int(x):int(x) + 2, :] * v[:, :, None],
                               axis=(0, 1),
                               dtype=image.dtype))
                else:
                    data.append(
                        np.sum(image[int(y):int(y) + 2,
                                     int(x):int(x) + 2] * v,
                               dtype=image.dtype))
            datas.append(data)

        if self.mirror:
            return np.array(datas)[:, ::-1]
        return np.array(datas)[::-1, :]

    def updatePlot(self):
        """Rebuild the kymograph of the selected line and start filling it.

        Does nothing when no table row is selected.  Otherwise the plot is
        cleared, the first strip is cut from the image the line belongs to,
        an empty buffer for ``self.n`` strips is allocated and shown, and
        ``run`` is started via ``run_threaded`` for the following frames.
        """
        if self.selected is None:
            return
        # with n = -1 the exit test in run() (index >= self.n - 1) is true, so
        # a loop that is still running leaves at its next iteration
        self.n = -1
        # presumably stops a worker started earlier — confirm in
        # clickpoints.Addon
        self.terminate()
        # mirror the sampling when the database option "rotation" is 180
        self.mirror = False
        if self.db.getOption("rotation") == 180:
            self.mirror = True

        self.bar = self.bars[self.selected]
        self.plot.axes.clear()
        image_start = self.bar.image
        # remember the offset of the start image; getLine() shifts later
        # images relative to it
        if image_start.offset:
            self.start_offx, self.start_offy = image_start.offset.x, image_start.offset.y
        else:
            self.start_offx, self.start_offy = 0, 0
        # strip height in pixels (the "Width" option)
        self.h = self.input_width.value()
        if int(self.input_count.value()) == 0:
            # frame count 0: run up to the first "kymograph_end" marker on a
            # later image
            image = self.bar.image.sort_index
            end_marker = self.db.table_marker.select().where(
                self.db.table_marker.type == self.my_type2).join(
                    self.db.table_image).where(
                        self.db.table_image.sort_index > image).limit(1)
            # NOTE(review): end_marker[0] presumably fails when no such marker
            # exists — confirm the intended handling
            self.n = end_marker[0].image.sort_index - image
        else:
            self.n = int(self.input_count.value())
        self.progressbar.setRange(0, self.n - 1)
        data = image_start.data
        line_cut = self.getLine(data, self.bar, self.h, image_start)
        self.w = line_cut.shape[1]

        # buffer for all n strips, stacked along the first axis
        if len(data.shape) == 3:
            self.current_data = np.zeros(
                (self.h * self.n, self.w, data.shape[2]), dtype=line_cut.dtype)
        else:
            self.current_data = np.zeros((self.h * self.n, self.w),
                                         dtype=line_cut.dtype)
        self.current_data[0:self.h, :] = line_cut

        # axis extent in scaled units: x by "Scale Length", y by "Scale Time"
        extent = (0, self.current_data.shape[1] * self.input_scale1.value(),
                  self.current_data.shape[0] * self.input_scale2.value(), 0)
        if self.input_colormap.currentText() != "None":
            if len(self.current_data.shape) == 3:
                # a colormap needs one channel: weighted sum of the first
                # three channels
                data_gray = np.dot(self.current_data[..., :3],
                                   [0.299, 0.587, 0.114])
                self.image_plot = self.plot.axes.imshow(
                    data_gray,
                    cmap=self.input_colormap.currentText(),
                    extent=extent)
            else:
                self.image_plot = self.plot.axes.imshow(
                    self.current_data,
                    cmap=self.input_colormap.currentText(),
                    extent=extent)
        else:
            self.image_plot = self.plot.axes.imshow(self.current_data,
                                                    cmap="gray",
                                                    extent=extent)
        # the axis labels are fixed strings; they do not follow the scale
        # options
        self.plot.axes.set_xlabel(u"distance (µm)")
        self.plot.axes.set_ylabel("time (s)")
        self.plot.figure.tight_layout()
        self.plot.draw()

        self.last_update = time.time()

        # continue with the frame after the start image; its argument becomes
        # run()'s start_frame
        self.run_threaded(image_start.sort_index + 1, self.run)

    def updatePlotImageEvent(self):
        """Redraw the kymograph image and advance the progress bar.

        Redraws are limited to one per 0.1 s, except that the update for the
        last frame (``self.index >= self.n - 1``) is always drawn.
        """
        now = time.time()
        if now - self.last_update < 0.1 and self.index < self.n - 1:
            return
        self.last_update = now

        if self.image_plot:
            frame = self.current_data
            needs_gray = (len(frame.shape) == 3
                          and self.input_colormap.currentText() != "None")
            if needs_gray:
                # a colormap needs one channel: weighted sum of the first
                # three channels
                frame = np.dot(frame[..., :3], [0.299, 0.587, 0.114])
            self.image_plot.set_data(frame)

        self.plot.draw()
        self.progressbar.setValue(self.index)

    def run(self, start_frame=0):
        """Append one strip per frame to the kymograph buffer.

        Iterates the images from ``start_frame`` on, writes each strip into
        ``self.current_data`` (strip 0 was written by ``updatePlot``) and
        emits ``signal_update_plot`` after every frame.  The loop ends, with
        ``signal_plot_finished`` emitted, once ``self.n - 1`` is reached or
        ``self.cp.stop`` is set.
        """
        frames = self.db.getImageIterator(start_frame)
        for index, image in enumerate(frames, start=1):
            self.index = index
            line_cut = self.getLine(image.data, self.bar, self.h, image)
            first_row = index * self.h
            last_row = (index + 1) * self.h
            self.current_data[first_row:last_row, :] = line_cut
            self.signal_update_plot.emit()
            if index >= self.n - 1 or self.cp.stop:
                self.signal_plot_finished.emit()
                break

    def plotFinishedEvent(self):
        """Continue the "Export All" sequence after a kymograph is finished.

        Does nothing unless ``self.exporting`` is set.  Otherwise the current
        kymograph is exported and the next table row is selected; after the
        last row the export state is reset.
        """
        if not self.exporting:
            return
        self.export()
        self.exporting_index += 1
        if self.exporting_index >= len(self.bars):
            # all rows done: leave export mode
            self.exporting_index = 0
            self.exporting = False
            return
        self.cellSelected(self.exporting_index, 0)

    def export(self):
        """Save the current kymograph as ``.npz`` data and as a ``.png`` image.

        Both files are named ``kymograph<marker id>`` and written relative to
        the current working directory.  A colour kymograph is reduced to one
        channel first when a colormap other than "None" is selected.
        """
        filename = "kymograph%d.%s"

        # colour data that should get a colormap is converted to grayscale;
        # anything else is saved as it is
        needs_gray = (len(self.current_data.shape) == 3
                      and self.input_colormap.currentText() != "None")
        if needs_gray:
            data_gray = np.dot(self.current_data[..., :3],
                               [0.299, 0.587, 0.114])
        else:
            data_gray = self.current_data

        # raw data as a numpy archive
        npz_name = filename % (self.bar.id, "npz")
        np.savez(npz_name, data_gray)

        # image file, using "gray" when no colormap is selected
        cmap = self.input_colormap.currentText()
        if cmap == "None":
            cmap = "gray"
        plt.imsave(filename % (self.bar.id, "png"), data_gray, cmap=cmap)

        # log the export on the console
        print("Exported", npz_name)

    def export2(self):
        """Start the "Export All" sequence at the first table row.

        The following rows are handled by ``plotFinishedEvent`` each time a
        kymograph is finished.
        """
        self.exporting_index = 0
        self.exporting = True
        self.cellSelected(0, 0)

    def markerMoveEvent(self, marker):
        """Refresh table row and plot when a kymograph line marker is moved.

        Markers of any other type are ignored.
        """
        if marker.type != self.my_type:
            return
        row = self.bar_dict[marker.id]
        # replace the stored marker with the moved one
        self.bars[row] = marker
        self.tableWidget.selectRow(row)
        self.updateRow(row)
        self.selected = row
        self.updatePlot()

    def markerAddEvent(self, entry):
        """Rebuild the table after a marker was added (``entry`` is unused)."""
        self.updateTable()

    def markerRemoveEvent(self, entry):
        """Rebuild the table after a marker was removed (``entry`` is unused)."""
        self.updateTable()

    def buttonPressedEvent(self):
        """Show the add-on window."""
        self.show()
Example #5
0
    def __init__(self, *args, **kwargs):
        """Set up the DeformationCytometer add-on.

        Creates the worker thread, the marker types, the GUI (weight file
        selection, detection buttons, filter fields, plot buttons, plot
        canvas, progress bar), derives the data/config file names from the
        first image, loads existing results and starts displaying the stored
        ellipses.  All arguments are forwarded to
        ``clickpoints.Addon.__init__``.
        """
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # qthread and signals for update cell detection and loading ellipse at add on launch
        self.thread = Worker(run_function=None)
        self.thread.thread_started.connect(self.start_pbar)
        self.thread.thread_finished.connect(self.finish_pbar)
        self.thread.thread_progress.connect(self.update_pbar)

        self.stop = False
        self.plot_data = np.array([[], []])
        self.unet = None
        self.layout = QtWidgets.QVBoxLayout(self)

        # Setting up marker Types
        self.marker_type_cell1 = self.db.setMarkerType("cell", "#0a2eff",
                                                       self.db.TYPE_Ellipse)
        self.marker_type_cell2 = self.db.setMarkerType("cell new", "#Fa2eff",
                                                       self.db.TYPE_Ellipse)
        self.cp.reloadTypes()

        # finding and setting path to store network probability map
        # raises KeyError if the CLICKPOINTS_TMP environment variable is unset
        self.prob_folder = os.environ["CLICKPOINTS_TMP"]
        self.prob_path = self.db.setPath(self.prob_folder)
        self.prob_layer = self.db.setLayer("prob_map")

        # NOTE(review): the base-class initialiser was already called at the
        # top of this method; this second call looks unintentional — confirm
        # before removing
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # set the title and layout
        self.setWindowTitle("DeformationCytometer - ClickPoints")
        # NOTE(review): self.layout was already created above; this replaces
        # the attribute with a second QVBoxLayout — confirm which is intended
        self.layout = QtWidgets.QVBoxLayout(self)

        # weight file selection
        # store_path is not defined in this method — presumably module level
        self.weight_selection = SetFile(store_path,
                                        filetype="weight file (*.h5)")
        self.weight_selection.fileSeleted.connect(self.initUnet)
        self.layout.addLayout(self.weight_selection)

        # update segmentation
        # in range of frames
        seg_layout = QtWidgets.QHBoxLayout()
        self.update_detection_button = QtWidgets.QPushButton(
            "update cell detection")
        self.update_detection_button.setToolTip(
            tooltip_strings["update cell detection"])
        self.update_detection_button.clicked.connect(
            partial(self.start_threaded, self.detect_all))
        seg_layout.addWidget(self.update_detection_button, stretch=5)
        # on single frame
        self.update_single_detection_button = QtWidgets.QPushButton(
            "single detection")
        self.update_single_detection_button.setToolTip(
            tooltip_strings["single detection"])
        self.update_single_detection_button.clicked.connect(self.detect_single)
        seg_layout.addWidget(self.update_single_detection_button, stretch=1)
        self.layout.addLayout(seg_layout)

        # irregularity, solidity and minimum radius fields; all three share
        # one validator accepting 0..100 with 3 decimals
        validator = QtGui.QDoubleValidator(0, 100, 3)
        filter_layout = QtWidgets.QHBoxLayout()
        reg_label = QtWidgets.QLabel("irregularity")
        filter_layout.addWidget(reg_label)
        self.reg_box = QtWidgets.QLineEdit("1.06")
        self.reg_box.setToolTip(tooltip_strings["irregularity"])
        self.reg_box.setValidator(validator)
        filter_layout.addWidget(self.reg_box,
                                stretch=1)  # TODO implement text edited method
        sol_label = QtWidgets.QLabel("solidity")
        filter_layout.addWidget(sol_label)
        self.sol_box = QtWidgets.QLineEdit("0.96")
        self.sol_box.setToolTip(tooltip_strings["solidity"])
        self.sol_box.setValidator(validator)
        filter_layout.addWidget(self.sol_box, stretch=1)
        rmin_label = QtWidgets.QLabel("min radius [µm]")
        filter_layout.addWidget(rmin_label)
        self.rmin_box = QtWidgets.QLineEdit("6")
        self.rmin_box.setToolTip(tooltip_strings["min radius"])
        self.rmin_box.setValidator(validator)
        filter_layout.addWidget(self.rmin_box, stretch=1)
        filter_layout.addStretch(stretch=4)
        self.layout.addLayout(filter_layout)

        # plotting buttons
        layout = QtWidgets.QHBoxLayout()
        self.button_stressstrain = QtWidgets.QPushButton("stress-strain")
        self.button_stressstrain.clicked.connect(self.plot_stress_strain)
        self.button_stressstrain.setToolTip(tooltip_strings["stress-strain"])
        layout.addWidget(self.button_stressstrain)
        self.button_kpos = QtWidgets.QPushButton("k-pos")
        self.button_kpos.clicked.connect(self.plot_k_pos)
        self.button_kpos.setToolTip(tooltip_strings["k-pos"])
        layout.addWidget(self.button_kpos)
        self.button_reg_sol = QtWidgets.QPushButton("regularity-solidity")
        self.button_reg_sol.clicked.connect(self.plot_irreg)
        self.button_reg_sol.setToolTip(tooltip_strings["regularity-solidity"])
        layout.addWidget(self.button_reg_sol)
        self.button_kHist = QtWidgets.QPushButton("k histogram")
        self.button_kHist.clicked.connect(self.plot_kHist)
        self.button_kHist.setToolTip(tooltip_strings["k histogram"])
        layout.addWidget(self.button_kHist)
        self.button_alphaHist = QtWidgets.QPushButton("alpha histogram")
        self.button_alphaHist.clicked.connect(self.plot_alphaHist)
        self.button_alphaHist.setToolTip(tooltip_strings["alpha histogram"])
        layout.addWidget(self.button_alphaHist)
        self.button_kalpha = QtWidgets.QPushButton("k-alpha")
        self.button_kalpha.clicked.connect(self.plot_k_alpha)
        self.button_kalpha.setToolTip(tooltip_strings["k-alpha"])
        layout.addWidget(self.button_kalpha)
        # button to switch between display of loaded and newly generated data
        frame = QtWidgets.QFrame()  # vertical separating line (VLine)
        frame.setFrameShape(QtWidgets.QFrame.VLine)
        frame.setLineWidth(3)
        layout.addWidget(frame)
        self.switch_data_button = QtWidgets.QPushButton(
            self.disp_text_existing)
        self.switch_data_button.clicked.connect(self.switch_display_data)
        self.switch_data_button.setToolTip(
            tooltip_strings[self.disp_text_existing])
        layout.addWidget(self.switch_data_button)
        self.layout.addLayout(layout)

        # matplotlib widgets to draw plots
        self.plot = MatplotlibWidget(self)
        # same empty array as assigned near the top of this method
        self.plot_data = np.array([[], []])
        self.layout.addWidget(self.plot)
        self.layout.addWidget(NavigationToolbar(self.plot, self))
        self.plot.figure.canvas.mpl_connect('button_press_event',
                                            self.button_press_callback)

        # progress bar
        self.progressbar = QtWidgets.QProgressBar()
        self.layout.addWidget(self.progressbar)
        # progress bar label
        pbar_info_layout = QtWidgets.QHBoxLayout()
        self.pbarLable = QtWidgets.QLabel("")
        pbar_info_layout.addWidget(self.pbarLable, stretch=1)
        pbar_info_layout.addStretch(stretch=2)
        # button to stop thread execution
        self.stop_button = QtWidgets.QPushButton("stop")
        self.stop_button.clicked.connect(self.quit_thread)
        self.stop_button.setToolTip(tooltip_strings["stop"])
        pbar_info_layout.addWidget(self.stop_button, stretch=1)
        self.layout.addLayout(pbar_info_layout)

        # setting paths for data, config and image
        # identifying the full path to the video. If an existing ClickPoints
        # database is opened, the path is likely relative to the database
        # location.
        self.filename = self.db.getImage(0).get_full_filename()
        if not os.path.isabs(self.filename):
            self.filename = str(
                Path(self.db._database_filename).parent.joinpath(
                    Path(self.filename)))

        # constructFileNames is not visible here; the .exists() calls below
        # imply it returns path objects — TODO confirm
        self.config_file = self.constructFileNames("_config.txt")
        self.result_file = self.constructFileNames("_result.txt")
        self.addon_result_file = self.constructFileNames("_addon_result.txt")
        self.addon_evaluated_file = self.constructFileNames(
            "_addon_evaluated.csv")
        self.addon_config_file = self.constructFileNames("_addon_config.txt")
        self.vidcap = imageio.get_reader(self.filename)

        # reading in config and data
        self.data_all_existing = pd.DataFrame()
        self.data_mean_existing = pd.DataFrame()
        self.data_all_new = pd.DataFrame()
        self.data_mean_new = pd.DataFrame()
        if self.config_file.exists() and self.result_file.exists():
            self.config = getConfig(self.config_file)
            # ToDo: replace with a flag// also maybe some sort of "recalculation" feature
            # Trying to get regularity and solidity from the config
            if "irregularity" in self.config.keys(
            ) and "solidity" in self.config.keys():
                solidity_threshold = self.config["solidity"]
                irregularity_threshold = self.config["irregularity"]
            else:
                # sol_threshold / reg_threshold are not assigned in this
                # method — presumably defined elsewhere in the class
                solidity_threshold = self.sol_threshold
                irregularity_threshold = self.reg_threshold
            # reading unfiltered data (from results.txt) and data from evaluated.csv
            # unfiltered data (self.data_all_existing) is used to display regularity and solidity scatter plot
            # everything else is from evaluated.csv (self.data_mean_existing)
            self.data_all_existing, self.data_mean_existing = self.load_data(
                self.result_file, solidity_threshold, irregularity_threshold)
        else:  # get a default config if no config is found
            self.config = getConfig(default_config_path)

        ## loading data from previous addon action
        if self.addon_result_file.exists():
            self.data_all_new, self.data_mean_new = self.load_data(
                self.addon_result_file, self.sol_threshold, self.reg_threshold)
            self.start_threaded(
                partial(self.display_ellipses,
                        type=self.marker_type_cell2,
                        data=self.data_all_new))
        # create an addon config file
        # presence of this file allows easy implementation of the load_data and tank threading pipelines when
        # calculating new data
        if not self.addon_config_file.exists():
            # NOTE(review): self.config_file may not exist here (the else
            # branch above falls back to default_config_path) — confirm this
            # copy cannot fail
            shutil.copy(self.config_file, self.addon_config_file)

        # self.data_all is not assigned in this method — presumably defined
        # elsewhere in the class
        self.plot_data_frame = self.data_all
        # initialize plot
        self.plot_stress_strain()

        # Displaying the loaded cells. This is in separate thread as it takes up to 20 seconds.
        # NOTE(review): ellipses of marker_type_cell2 are deleted here after a
        # display thread for that type may already have been started above —
        # confirm the intended order
        self.db.deleteEllipses(type=self.marker_type_cell1)
        self.db.deleteEllipses(type=self.marker_type_cell2)
        self.start_threaded(
            partial(self.display_ellipses,
                    type=self.marker_type_cell1,
                    data=self.data_all_existing))

        print("loading finished")
Example #6
0
class Addon(clickpoints.Addon):
    signal_update_plot = QtCore.Signal()
    signal_plot_finished = QtCore.Signal()
    disp_text_existing = "displaying existing data"
    disp_text_new = "displaying new data"

    def __init__(self, *args, **kwargs):
        """Build the addon window and load the data belonging to the video.

        The constructor sets up the worker thread and its progress signals,
        registers the two ellipse marker types, creates all widgets (weight
        file selection, detection buttons, filter inputs, plot buttons, plot
        canvas, progress bar), resolves the companion file paths of the
        opened video and finally loads and displays previously computed data.

        All positional and keyword arguments are forwarded unchanged to
        ``clickpoints.Addon.__init__``.
        """
        # the base class is initialised exactly once
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # qthread and signals for update cell detection and loading ellipses at addon launch
        self.thread = Worker(run_function=None)
        self.thread.thread_started.connect(self.start_pbar)
        self.thread.thread_finished.connect(self.finish_pbar)
        self.thread.thread_progress.connect(self.update_pbar)

        self.stop = False
        self.plot_data = np.array([[], []])
        self.unet = None

        # set the title and the (single) top-level layout
        self.setWindowTitle("DeformationCytometer - ClickPoints")
        self.layout = QtWidgets.QVBoxLayout(self)

        # Setting up marker Types
        self.marker_type_cell1 = self.db.setMarkerType("cell", "#0a2eff",
                                                       self.db.TYPE_Ellipse)
        self.marker_type_cell2 = self.db.setMarkerType("cell new", "#Fa2eff",
                                                       self.db.TYPE_Ellipse)
        self.cp.reloadTypes()

        # finding and setting path to store network probability map
        self.prob_folder = os.environ["CLICKPOINTS_TMP"]
        self.prob_path = self.db.setPath(self.prob_folder)
        self.prob_layer = self.db.setLayer("prob_map")

        # weight file selection
        self.weight_selection = SetFile(store_path,
                                        filetype="weight file (*.h5)")
        self.weight_selection.fileSeleted.connect(self.initUnet)
        self.layout.addLayout(self.weight_selection)

        # update segmentation
        # in range of frames
        seg_layout = QtWidgets.QHBoxLayout()
        self.update_detection_button = QtWidgets.QPushButton(
            "update cell detection")
        self.update_detection_button.setToolTip(
            tooltip_strings["update cell detection"])
        self.update_detection_button.clicked.connect(
            partial(self.start_threaded, self.detect_all))
        seg_layout.addWidget(self.update_detection_button, stretch=5)
        # on single frame
        self.update_single_detection_button = QtWidgets.QPushButton(
            "single detection")
        self.update_single_detection_button.setToolTip(
            tooltip_strings["single detection"])
        self.update_single_detection_button.clicked.connect(self.detect_single)
        seg_layout.addWidget(self.update_single_detection_button, stretch=1)
        self.layout.addLayout(seg_layout)

        # regularity and solidity thresholds
        validator = QtGui.QDoubleValidator(0, 100, 3)
        filter_layout = QtWidgets.QHBoxLayout()
        reg_label = QtWidgets.QLabel("irregularity")
        filter_layout.addWidget(reg_label)
        self.reg_box = QtWidgets.QLineEdit("1.06")
        self.reg_box.setToolTip(tooltip_strings["irregularity"])
        self.reg_box.setValidator(validator)
        filter_layout.addWidget(self.reg_box,
                                stretch=1)  # TODO implement text edited method
        sol_label = QtWidgets.QLabel("solidity")
        filter_layout.addWidget(sol_label)
        self.sol_box = QtWidgets.QLineEdit("0.96")
        self.sol_box.setToolTip(tooltip_strings["solidity"])
        self.sol_box.setValidator(validator)
        filter_layout.addWidget(self.sol_box, stretch=1)
        rmin_label = QtWidgets.QLabel("min radius [µm]")
        filter_layout.addWidget(rmin_label)
        self.rmin_box = QtWidgets.QLineEdit("6")
        self.rmin_box.setToolTip(tooltip_strings["min radius"])
        self.rmin_box.setValidator(validator)
        filter_layout.addWidget(self.rmin_box, stretch=1)
        filter_layout.addStretch(stretch=4)
        self.layout.addLayout(filter_layout)

        # plotting buttons
        layout = QtWidgets.QHBoxLayout()
        self.button_stressstrain = QtWidgets.QPushButton("stress-strain")
        self.button_stressstrain.clicked.connect(self.plot_stress_strain)
        self.button_stressstrain.setToolTip(tooltip_strings["stress-strain"])
        layout.addWidget(self.button_stressstrain)
        self.button_kpos = QtWidgets.QPushButton("k-pos")
        self.button_kpos.clicked.connect(self.plot_k_pos)
        self.button_kpos.setToolTip(tooltip_strings["k-pos"])
        layout.addWidget(self.button_kpos)
        self.button_reg_sol = QtWidgets.QPushButton("regularity-solidity")
        self.button_reg_sol.clicked.connect(self.plot_irreg)
        self.button_reg_sol.setToolTip(tooltip_strings["regularity-solidity"])
        layout.addWidget(self.button_reg_sol)
        self.button_kHist = QtWidgets.QPushButton("k histogram")
        self.button_kHist.clicked.connect(self.plot_kHist)
        self.button_kHist.setToolTip(tooltip_strings["k histogram"])
        layout.addWidget(self.button_kHist)
        self.button_alphaHist = QtWidgets.QPushButton("alpha histogram")
        self.button_alphaHist.clicked.connect(self.plot_alphaHist)
        self.button_alphaHist.setToolTip(tooltip_strings["alpha histogram"])
        layout.addWidget(self.button_alphaHist)
        self.button_kalpha = QtWidgets.QPushButton("k-alpha")
        self.button_kalpha.clicked.connect(self.plot_k_alpha)
        self.button_kalpha.setToolTip(tooltip_strings["k-alpha"])
        layout.addWidget(self.button_kalpha)
        # button to switch between display of loaded and newly generated data
        frame = QtWidgets.QFrame()  # vertical separating line
        frame.setFrameShape(QtWidgets.QFrame.VLine)
        frame.setLineWidth(3)
        layout.addWidget(frame)
        self.switch_data_button = QtWidgets.QPushButton(
            self.disp_text_existing)
        self.switch_data_button.clicked.connect(self.switch_display_data)
        self.switch_data_button.setToolTip(
            tooltip_strings[self.disp_text_existing])
        layout.addWidget(self.switch_data_button)
        self.layout.addLayout(layout)

        # matplotlib widgets to draw plots
        self.plot = MatplotlibWidget(self)
        self.layout.addWidget(self.plot)
        self.layout.addWidget(NavigationToolbar(self.plot, self))
        self.plot.figure.canvas.mpl_connect('button_press_event',
                                            self.button_press_callback)

        # progress bar
        self.progressbar = QtWidgets.QProgressBar()
        self.layout.addWidget(self.progressbar)
        # progress bar label
        pbar_info_layout = QtWidgets.QHBoxLayout()
        self.pbarLable = QtWidgets.QLabel("")
        pbar_info_layout.addWidget(self.pbarLable, stretch=1)
        pbar_info_layout.addStretch(stretch=2)
        # button to stop thread execution
        self.stop_button = QtWidgets.QPushButton("stop")
        self.stop_button.clicked.connect(self.quit_thread)
        self.stop_button.setToolTip(tooltip_strings["stop"])
        pbar_info_layout.addWidget(self.stop_button, stretch=1)
        self.layout.addLayout(pbar_info_layout)

        # setting paths for data, config and image
        # Identifying the full path to the video. If an existing ClickPoints
        # database is opened, the path is likely relative to the database
        # location.
        self.filename = self.db.getImage(0).get_full_filename()
        if not os.path.isabs(self.filename):
            self.filename = str(
                Path(self.db._database_filename).parent.joinpath(
                    Path(self.filename)))

        self.config_file = self.constructFileNames("_config.txt")
        self.result_file = self.constructFileNames("_result.txt")
        self.addon_result_file = self.constructFileNames("_addon_result.txt")
        self.addon_evaluated_file = self.constructFileNames(
            "_addon_evaluated.csv")
        self.addon_config_file = self.constructFileNames("_addon_config.txt")
        self.vidcap = imageio.get_reader(self.filename)

        # reading in config and data
        self.data_all_existing = pd.DataFrame()
        self.data_mean_existing = pd.DataFrame()
        self.data_all_new = pd.DataFrame()
        self.data_mean_new = pd.DataFrame()
        if self.config_file.exists() and self.result_file.exists():
            self.config = getConfig(self.config_file)
            # ToDo: replace with a flag// also maybe some sort of "recalculation" feature
            # Trying to get regularity and solidity from the config
            if "irregularity" in self.config.keys(
            ) and "solidity" in self.config.keys():
                solidity_threshold = self.config["solidity"]
                irregularity_threshold = self.config["irregularity"]
            else:
                solidity_threshold = self.sol_threshold
                irregularity_threshold = self.reg_threshold
            # reading unfiltered data (from results.txt) and data from evaluated.csv
            # unfiltered data (self.data_all_existing) is used to display regularity and solidity scatter plot
            # everything else is from evaluated.csv (self.data_mean_existing)
            self.data_all_existing, self.data_mean_existing = self.load_data(
                self.result_file, solidity_threshold, irregularity_threshold)
        else:  # get a default config if no config is found
            self.config = getConfig(default_config_path)

        # Remove the markers of a previous session before anything is drawn,
        # so that freshly displayed ellipses are never deleted again.
        self.db.deleteEllipses(type=self.marker_type_cell1)
        self.db.deleteEllipses(type=self.marker_type_cell2)

        ## loading data from previous addon action
        if self.addon_result_file.exists():
            self.data_all_new, self.data_mean_new = self.load_data(
                self.addon_result_file, self.sol_threshold, self.reg_threshold)

        # create an addon config file
        # presence of this file allows easy implementation of the load_data and tank treading pipelines when
        # calculating new data
        if not self.addon_config_file.exists():
            # fall back to the default config when the video has no config
            # file of its own (copying a missing file would raise)
            config_source = (self.config_file if self.config_file.exists()
                             else default_config_path)
            shutil.copy(config_source, self.addon_config_file)

        self.plot_data_frame = self.data_all
        # initialize plot
        self.plot_stress_strain()

        # Displaying the loaded cells. This is in a separate thread as it takes up to 20 seconds.
        # Both marker sets are drawn one after the other inside a single
        # worker run, so the shared worker is started only once.
        existing_cells = self.data_all_existing
        new_cells = self.data_all_new

        def display_loaded_ellipses():
            self.display_ellipses(type=self.marker_type_cell1,
                                  data=existing_cells)
            if not self.stop:
                self.display_ellipses(type=self.marker_type_cell2,
                                      data=new_cells)

        self.start_threaded(display_loaded_ellipses)

        print("loading finished")

    def constructFileNames(self, replace):
        """Return the path of a companion file of the opened video/database.

        The extension of ``self.filename`` is stripped and ``replace`` is
        appended to the remaining file name, e.g. ``video.tif`` with
        ``"_config.txt"`` gives ``video_config.txt`` in the same folder.

        Only the final extension is touched: a ``.tif``/``.cdb`` occurring
        elsewhere in the path is left alone, and files with any other
        extension also yield a path instead of ``None``.

        :param replace: suffix (including its extension) of the companion file.
        :return: ``pathlib.Path`` of the companion file.
        """
        source = Path(self.filename)
        return source.with_name(source.stem + replace)

    # slots to update the progress bar from another thread (update cell detection and display_ellipse)
    @pyqtSlot(tuple, str)  # the decorator is not really necessary
    def start_pbar(self, prange, text):
        """Prepare the progress bar for a new task.

        :param prange: sequence whose first two items become the minimum and
            maximum of the progress bar.
        :param text: message shown in the label next to the progress bar.
        """
        bar = self.progressbar
        bar.setMinimum(prange[0])
        bar.setMaximum(prange[1])
        self.pbarLable.setText(text)

    @pyqtSlot(int)
    def update_pbar(self, value):
        """Move the progress bar to ``value``."""
        bar = self.progressbar
        bar.setValue(value)

    @pyqtSlot(int)
    def finish_pbar(self, value):
        """Set the progress bar to its final ``value`` and mark the task as finished."""
        bar, label = self.progressbar, self.pbarLable
        bar.setValue(value)
        label.setText("finished")

    # Dynamic switch between existing and new data
    def switch_display_data(self):
        """Toggle between showing the existing and the newly computed data.

        The text of the switch button encodes which data set is active; it is
        flipped here and the currently selected plot is redrawn afterwards.
        """
        showing_existing = (
            self.switch_data_button.text() == self.disp_text_existing)
        text = self.disp_text_new if showing_existing else self.disp_text_existing
        self.switch_data_button.setText(text)
        # redraw the active plot with the data set that is selected now
        self.plot_type()

    @property
    def data_all(self):
        """Unfiltered data of the data set selected by the switch button.

        Returns ``self.data_all_existing`` or ``self.data_all_new`` depending
        on the button text, and ``None`` if the text matches neither label.
        """
        shown = self.switch_data_button.text()
        if shown == self.disp_text_existing:
            return self.data_all_existing
        if shown == self.disp_text_new:
            return self.data_all_new
        return None

    @property
    def data_mean(self):
        """Evaluated data of the data set selected by the switch button.

        Returns ``self.data_mean_existing`` or ``self.data_mean_new``
        depending on the button text, and ``None`` if the text matches
        neither label.
        """
        shown = self.switch_data_button.text()
        if shown == self.disp_text_existing:
            return self.data_mean_existing
        if shown == self.disp_text_new:
            return self.data_mean_new
        return None

    # solidity and regularity and rmin properties
    @property
    def sol_threshold(self):
        """Content of the "solidity" input box converted to ``float``."""
        entered = self.sol_box.text()
        return float(entered)

    @property
    def reg_threshold(self):
        """Content of the "irregularity" input box converted to ``float``."""
        entered = self.reg_box.text()
        return float(entered)

    @property
    def rmin(self):
        """Content of the "min radius" input box converted to ``float``."""
        entered = self.rmin_box.text()
        return float(entered)

    # handling thread entrance and exit
    def start_threaded(self, run_function):
        """Run ``run_function`` on the addon's worker thread.

        :param run_function: callable handed to the worker before it is
            started.
        """
        # clear the stop flag first; the threaded functions poll it to leave
        # their loops
        self.stop = False
        worker = self.thread
        worker.run_function = run_function
        worker.start()

    def quit_thread(self):
        """Ask the running threaded function to stop and quit the worker."""
        # the threaded functions check this flag inside their loops
        self.stop = True
        worker = self.thread
        worker.quit()

    def load_data(self, file, solidity_threshold, irregularity_threshold):
        """Load the raw detection data and the evaluated data.

        :param file: result file read with ``getData``.
        :param solidity_threshold: currently unused; kept for the callers.
        :param irregularity_threshold: currently unused; kept for the callers.
        :return: tuple ``(data_all, data_mean)``. Two empty data frames are
            returned when ``file`` contains no data.
        """
        data_all = getData(file)
        # Check for missing data first: an empty frame may not even have the
        # axis columns needed to compute the area.
        if len(data_all) == 0:
            print("no data loaded from file '%s'" % file)
            return pd.DataFrame(), pd.DataFrame()

        if "area" not in data_all.keys():
            data_all["area"] = data_all["long_axis"] * data_all[
                "short_axis"] * np.pi / 4

        # use a "read sol from config" flag here
        # NOTE(review): the evaluated file is always derived from the first
        # image of the database, independent of ``file`` - confirm that this
        # is intended for the addon result file as well.
        evaluated_file = self.db.getImage(0).get_full_filename().replace(
            ".tif", "_evaluated_new.csv")
        data_mean, _ = load_all_data_new(evaluated_file,
                                         do_group=False,
                                         do_excude=False)
        return data_all, data_mean

    # plotting functions
    # wrapper for all scatter plots; handles empty and data log10 transform
    def plot_scatter(self,
                     data,
                     type1,
                     type2,
                     funct1=doNothing,
                     funct2=doNothing):
        """Draw a density scatter plot of two columns of ``data``.

        The axes are cleared first. Nothing is plotted when a column is
        missing, when either transformed column consists only of NaNs, or
        when the density estimation fails.

        :param data: table providing the columns ``type1`` and ``type2``.
        :param type1: column plotted on the x axis (also used as x label).
        :param type2: column plotted on the y axis (also used as y label).
        :param funct1: transformation applied to the x column.
        :param funct2: transformation applied to the y column.
        """
        self.init_newPlot()
        try:
            x = funct1(data[type1])
            y = funct2(data[type2])
        except KeyError:
            self.plot.draw()
            return
        # both coordinates need at least one valid value
        if np.all(np.isnan(x)) or np.all(np.isnan(y)):
            return
        try:
            plotDensityScatter(x,
                               y,
                               cmap='viridis',
                               alpha=1,
                               skip=1,
                               y_factor=1,
                               s=5,
                               levels=None,
                               loglog=False,
                               ax=self.plot.axes)
            # remembered so that a click in the plot can be mapped back to a row
            self.plot_data = np.array([x, y])
            self.plot_data_frame = data
            self.plot.axes.set_xlabel(type1)
            self.plot.axes.set_ylabel(type2)
        except (ValueError, np.LinAlgError):
            print("kernel density estimation failed? not enough cells found?")
            return

    # clearing axis and plot.data
    def init_newPlot(self):
        """Forget the stored plot coordinates and show an empty plot."""
        self.plot_data = np.array([[], []])
        canvas = self.plot
        canvas.axes.clear()
        canvas.draw()

    def plot_alphaHist(self):
        """Plot the density histogram of the ``alpha_cell`` column.

        The plot stays empty if the column is missing or holds only NaNs.
        """
        # remembered so that switching the data set redraws this plot
        self.plot_type = self.plot_alphaHist
        self.init_newPlot()
        try:
            x = self.data_mean["alpha_cell"]
        except KeyError:
            return
        if not np.any(~np.isnan(x)):
            return
        plot_density_hist(x, ax=self.plot.axes, color="C1")
        # show the range covered by the ticks set below; identical lower and
        # upper limits would be a degenerate axis range
        self.plot.axes.set_xlim((0, 1))
        self.plot.axes.xaxis.set_ticks(np.arange(0, 1, 0.2))
        self.plot.axes.grid()
        self.plot.draw()

    def plot_kHist(self):
        """Plot the density histogram of log10 of the ``k_cell`` column.

        The plot stays empty if the column is missing or holds only NaNs.
        """
        # remembered so that switching the data set redraws this plot
        self.plot_type = self.plot_kHist
        self.init_newPlot()
        try:
            k_values = np.array(self.data_mean["k_cell"])
        except KeyError:
            return
        if np.all(np.isnan(k_values)):
            return
        axes = self.plot.axes
        plot_density_hist(np.log10(k_values), ax=axes, color="C0")
        axes.set_xlim((1, 4))
        axes.xaxis.set_ticks(np.arange(5))
        axes.grid()
        self.plot.draw()

    def plot_k_alpha(self):
        """Scatter plot of log10 of ``k_cell`` against ``alpha_cell``."""
        # remembered so that switching the data set redraws this plot
        self.plot_type = self.plot_k_alpha
        self.plot_scatter(self.data_mean, "alpha_cell", "k_cell",
                          funct2=np.log10)
        axes = self.plot.axes
        axes.set_ylabel("log10 k")
        axes.set_xlabel("alpha")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_k_size(self):
        """Scatter plot of ``w_k_cell`` against ``area``."""
        # remembered so that switching the data set redraws this plot
        self.plot_type = self.plot_k_size
        # use self.data_all for unfiltered data
        self.plot_scatter(self.data_mean, "area", "w_k_cell")
        axes = self.plot.axes
        axes.set_ylabel("w_k")
        axes.set_xlabel("area")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_k_pos(self):
        """Scatter plot of ``w_k_cell`` against the ``rp`` column."""
        # remembered so that switching the data set redraws this plot
        self.plot_type = self.plot_k_pos
        # use self.data_all for unfiltered data
        self.plot_scatter(self.data_mean, "rp", "w_k_cell")
        axes = self.plot.axes
        axes.set_ylabel("w_k")
        axes.set_xlabel("radiale position")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_irreg(self):
        """Scatter plot of irregularity against solidity with both thresholds.

        The unfiltered single-cell data is used so that detection errors are
        easy to spot; the current thresholds are drawn as dashed lines.
        """
        # remembered so that switching the data set redraws this plot
        self.plot_type = self.plot_irreg
        # the transformations default to the identity
        self.plot_scatter(self.data_all, "solidity", "irregularity")
        axes = self.plot.axes
        axes.axvline(self.sol_threshold, ls="--")
        axes.axhline(self.reg_threshold, ls="--")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_stress_strain(self):
        """Scatter plot of the ``strain`` column against the ``stress`` column."""
        # remembered so that switching the data set redraws this plot
        self.plot_type = self.plot_stress_strain
        self.plot_scatter(self.data_mean, "stress", "strain")
        canvas = self.plot
        canvas.axes.set_xlim((-10, 400))
        canvas.figure.tight_layout()
        canvas.draw()

    # Jump to cell in ClickPoints window when clicking near a data point in the scatter plot
    def button_press_callback(self, event):
        """Jump to the cell whose data point is nearest to a click in the plot.

        Both coordinates are divided by the mean of the plotted values before
        the distance is computed, so that differently scaled axes contribute
        comparably.

        :param event: matplotlib mouse event of the click.
        """
        # react to the left mouse button only; do nothing if the plot is empty
        # or the click was outside of the axes
        if event.button != 1 or event.inaxes is None or self.plot_data.size == 0:
            return
        xy = np.array([event.xdata, event.ydata])
        scale = np.nanmean(self.plot_data, axis=1)
        distance = np.linalg.norm(self.plot_data / scale[:, None] -
                                  xy[:, None] / scale[:, None],
                                  axis=0)
        # nanargmin raises if no finite distance exists
        if np.all(np.isnan(distance)):
            return
        nearest_point = int(np.nanargmin(distance))
        print("clicked", xy)
        # nearest_point is a position in the plotted arrays, so the matching
        # row has to be looked up by position and not by index label
        cells = self.plot_data_frame
        self.cp.jumpToFrame(int(cells["frames"].iloc[nearest_point]))
        self.cp.centerOn(cells["x"].iloc[nearest_point],
                         cells["y"].iloc[nearest_point])

    # not sure what this is for ^^
    def buttonPressedEvent(self):
        """Show the addon window."""
        self.show()

    ## cell detection
    def initUnet(self):
        """Create the network for the current image size with the selected weights."""
        weight_file = self.weight_selection.file
        print("loading weight file: ", weight_file)
        shape = self.cp.getImage().getShape()
        input_shape = (shape[0], shape[1], 1)
        self.unet = UNet(input_shape, 1, d=8, weights=weight_file)

    # cell detection and evaluation on multiple frames
    def detect_all(self):
        """Detect and evaluate cells in the frame range selected in ClickPoints.

        Existing "new" results and their markers are discarded. Every frame
        of the range is segmented with ``self.detect``. The detected cells
        are written to the addon result file and evaluated (tank treading,
        then reloading through ``self.load_data``). The loop can be aborted
        through ``self.stop``.
        """
        frame_range = self.cp.getFrameRange()
        first_frame, last_frame = frame_range[0], frame_range[1]
        info = "cell detection frame %d to %d" % (first_frame, last_frame)
        print(info)

        self.data_all_new = pd.DataFrame()
        self.data_mean_new = pd.DataFrame()
        self.db.deleteEllipses(type=self.marker_type_cell2)
        self.thread.thread_started.emit(tuple(frame_range[:2]), info)
        # Cells are collected in a list and converted once at the end;
        # growing a DataFrame row by row is slow and DataFrame.append no
        # longer exists in current pandas versions.
        detected_cells = []
        for frame in range(first_frame, last_frame):
            if self.stop:  # stop signal from "stop" button
                break
            im = self.db.getImage(frame=frame)
            cells, probability_map = self.detect(im, im.data, frame)
            detected_cells.extend(cells)
            self.thread.thread_progress.emit(frame)
            # reloading the mask and ellipse display in ClickPoints// may not be necessary to do it in batches
            if frame % 10 == 0:
                self.cp.reloadMask()
                self.cp.reloadMarker()

        self.cp.reloadMask()
        self.cp.reloadMarker()

        self.data_all_new = pd.DataFrame(detected_cells)
        if len(self.data_all_new) == 0:
            # nothing to save or evaluate; the columns accessed below would
            # not exist
            print("no cells detected")
            self.thread.thread_finished.emit(last_frame)
            print("finished")
            return

        self.data_all_new["timestamp"] = self.data_all_new["timestamp"].astype(
            float)
        self.data_all_new["frames"] = self.data_all_new["frames"].astype(int)
        # save data to addon_result.txt file
        save_cells_to_file(self.addon_result_file,
                           self.data_all_new.to_dict("records"))
        # tank treading
        print("tank threading")
        # catching error if no velocities could be identified (e.g. when only few cells are identified)
        try:
            self.tank_treading(self.data_all_new)
            # further evaluation
            print("evaluation")
            if self.addon_evaluated_file.exists():
                os.remove(self.addon_evaluated_file)
            self.data_all_new, self.data_mean_new = self.load_data(
                self.addon_result_file, self.sol_threshold, self.reg_threshold)
        except ValueError as e:
            print(e)
            self.data_mean_new = self.data_all_new.copy()
        self.thread.thread_finished.emit(last_frame)
        print("finished")

    # tank treading: saves results to an "_addon_tt.csv" file
    def tank_treading(self, data):
        """Compute the tank treading result of every cell and save it as csv.

        ``data`` is passed through ``getVelocity`` and ``correctCenter``,
        filtered with the current solidity and irregularity thresholds and
        then processed cell by cell. The table with the columns ``id``,
        ``tt`` and ``tt_r2`` is written next to the video with the suffix
        ``_addon_tt.csv``.

        :param data: table of detected cells.
        """
        # TODO implement tank treading for non video database
        image_reader = CachedImageReader(str(self.filename))
        getVelocity(data, self.config)
        correctCenter(data, self.config)

        passes_filter = ((data.solidity > self.sol_threshold)
                         & (data.irregularity < self.reg_threshold))
        data = data[passes_filter]

        results = []
        for cell_id in pd.unique(data["cell_id"]):
            track = data[data.cell_id == cell_id]
            crops, shifts, valid = getCroppedImages(image_reader, track)
            # tracking is skipped unless there is more than one crop
            if len(crops) <= 1:
                continue
            crops = crops[valid]
            elapsed = (track.timestamp - track.iloc[0].timestamp) * 1e-3
            speed, r2 = doTracking(crops,
                                   data0=track,
                                   times=np.array(elapsed),
                                   pixel_size=self.config["pixel_size"])
            results.append([cell_id, speed, r2])

        tt_table = pd.DataFrame(results, columns=["id", "tt", "tt_r2"])
        tt_table.to_csv(self.filename[:-4] + "_addon_tt.csv")

    # Detection in single frame. Also saves the network probability map to the second ClickPoints layer
    # tif file of the probability map is saved to ClickPoints temporary folder.
    def detect_single(self):
        """Run the cell detection on the image currently shown in ClickPoints.

        The probability map of the network is additionally written as a tiff
        file to the temporary ClickPoints folder and registered as an image
        of the probability map layer.
        """
        im = self.cp.getImage()
        img = self.cp.getImage().data
        frame = im.frame
        cells, probability_map = self.detect(im, img, frame)
        self.cp.reloadMask()
        self.cp.reloadMarker()

        # writing probability map as an additional layer
        filename = os.path.join(self.prob_folder, "%dprob_map.tiff" % frame)
        scaled_map = (probability_map * 255).astype(np.uint8)
        Image.fromarray(scaled_map).save(filename)
        # If the image is already registered the database raises an
        # IntegrityError; overwriting the image file above is sufficient then.
        image_entry = dict(filename=filename,
                           sort_index=frame,
                           layer=self.prob_layer,
                           path=self.prob_path)
        try:
            self.db.setImage(**image_entry)
        except peewee.IntegrityError:
            pass

    # Base detection function. Includes filters for objects without fully closed boundaries, objects close to
    # the horizontal image edge and objects with a radius smaller than self.rmin.
    def detect(self, im, img, frame):
        """Segment one image and store the mask and ellipses in the database.

        :param im: database image the mask and the markers are attached to.
        :param img: pixel data of the image.
        :param frame: frame number, stored with every cell and used to look
            up the timestamp.
        :return: tuple ``(cells, probability_map)`` with the list of detected
            cells and the network output for the image.
        """
        # fall back to a network without loaded weights if none was initialised
        if self.unet is None:
            self.unet = UNet((img.shape[0], img.shape[1], 1), 1, d=8)
        img = (img - np.mean(img)) / np.std(img).astype(np.float32)
        timestamp = getTimestamp(self.vidcap, frame)

        network_input = img[None, :, :, None]
        probability_map = self.unet.predict(network_input)[0, :, :, 0]
        prediction_mask = probability_map > 0.5
        cells, prediction_mask = mask_to_cells_edge(prediction_mask,
                                                    img,
                                                    self.config,
                                                    self.rmin, {},
                                                    edge_dist=15,
                                                    return_mask=True)

        # add frame, timestamp and ellipse area to every detected cell
        for cell in cells:
            cell.update({
                "frames": frame,
                "timestamp": timestamp,
                "area": np.pi * (cell["long_axis"] * cell["short_axis"]) / 4
            })

        self.db.setMask(image=im, data=prediction_mask.astype(np.uint8))
        self.db.deleteEllipses(type=self.marker_type_cell2, image=im)
        self.drawEllipse(pd.DataFrame(cells), self.marker_type_cell2)

        return cells, probability_map

    def keyPressEvent(self, event):
        """Run the single frame detection when the key "G" is pressed."""
        if event.key() != QtCore.Qt.Key_G:
            return
        print("detecting")
        self.detect_single()
        print("detecting finished")

    # Display all ellipses at launch
    def display_ellipses(self, type="cell", data=None):
        """Draw the ellipses of ``data`` in batches and report the progress.

        :param type: marker type of the ellipses.
        :param data: table of cells; ``self.data_all_existing`` is used when
            omitted.
        """
        batch_size = 200
        if data is None:
            data = self.data_all_existing
        total = len(data)
        if total == 0:
            return

        self.thread.thread_started.emit((0, total), "displaying ellipses")
        for start in range(0, total, batch_size):
            if self.stop:
                break
            # slicing past the end simply yields the remaining rows
            data_block = data.iloc[start:start + batch_size]
            self.drawEllipse(data_block, type)
            self.thread.thread_progress.emit(start)
            self.cp.reloadMarker()  # not sure how thread safe this is
        self.thread.thread_finished.emit(total)

    # base ellipse display function
    def drawEllipse(self, data_block, type):
        """Write the cells of ``data_block`` as ellipse markers to the database.

        Every marker gets a text with the strain, solidity and irregularity
        of the cell. The axis lengths are divided by the pixel size of the
        config before they are stored.

        :param data_block: table of cells; nothing happens if it is empty.
        :param type: marker type of the ellipses.
        """
        if len(data_block) == 0:
            return

        long_axis = data_block["long_axis"]
        short_axis = data_block["short_axis"]
        strains = (long_axis - short_axis) / np.sqrt(long_axis * short_axis)

        # one marker text per cell
        text = [
            f"strain {s:.3f}\nsolidity {sol:.2f}\nirreg. {irr:.3f}"
            for s, sol, irr in zip(strains, data_block['solidity'],
                                   data_block['irregularity'])
        ]

        pixel_size = self.config["pixel_size"]
        self.db.setEllipses(
            frame=list(data_block["frames"]),
            x=list(data_block["x"]),
            y=list(data_block["y"]),
            width=list(long_axis / pixel_size),
            height=list(short_axis / pixel_size),
            angle=list(data_block["angle"]),
            type=type,
            text=text)
Example #7
0
    def __init__(self, *args, **kwargs):
        """Build the window of the fluorescence diffusion addon.

        Registers the addon options and the "connect" line marker type and
        creates two rows of widgets (intensity and diffusion), each with a
        button, a table and a plot, followed by a progress bar.
        """
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # window title and top-level layout
        self.setWindowTitle("Fluorescence Diffusion - ClickPoints")
        self.layout = QtWidgets.QVBoxLayout(self)

        # addon options: (key, display name, default, value type)
        option_specs = (
            ("delta_t", "Delta t", 2, "float"),
            ("color_channel", "Color Channel", 1, "int"),
            ("output_folder", "Output Folder", "output", "string"),
        )
        for key, display_name, default, value_type in option_specs:
            self.addOption(key=key,
                           display_name=display_name,
                           default=default,
                           value_type=value_type)

        # create the line type "connect" if it is missing
        if not self.db.getMarkerType("connect"):
            self.db.setMarkerType("connect", [0, 255, 255], self.db.TYPE_Line)
            self.cp.reloadTypes()

        # ---- intensity row: controls on the left, plot on the right ----
        intensity_row = QtWidgets.QHBoxLayout()
        self.layout_intensity = intensity_row
        self.layout.addLayout(intensity_row)

        intensity_controls = QtWidgets.QVBoxLayout()
        self.layout_intensity1 = intensity_controls
        intensity_row.addLayout(intensity_controls)

        self.input_delta_t = AddQSpinBox(intensity_controls,
                                         "Delta T:",
                                         value=self.getOption("delta_t"),
                                         float=True)
        self.input_delta_t.setSuffix(" s")
        self.linkOption("delta_t", self.input_delta_t)

        self.input_color = AddQSpinBox(intensity_controls,
                                       "Color Channel:",
                                       value=self.getOption("color_channel"),
                                       float=False)
        self.linkOption("color_channel", self.input_color)

        self.button_update = QtWidgets.QPushButton("Calculate Intensities")
        intensity_controls.addWidget(self.button_update)
        self.button_update.clicked.connect(self.updateIntensities)

        # the table listing the line objects
        self.tableWidget = QtWidgets.QTableWidget(0, 1, self)
        intensity_controls.addWidget(self.tableWidget)

        intensity_plot_column = QtWidgets.QVBoxLayout()
        self.layout_intensity_plot = intensity_plot_column
        intensity_row.addLayout(intensity_plot_column)
        self.plot_intensity = MatplotlibWidget(self)
        intensity_plot_column.addWidget(self.plot_intensity)
        intensity_plot_column.addWidget(
            NavigationToolbar(self.plot_intensity, self))

        # separator between the two rows
        self.layout.addWidget(AddHLine())

        # ---- diffusion row: controls on the left, plot on the right ----
        diffusion_row = QtWidgets.QHBoxLayout()
        self.layout_diffusion = diffusion_row
        self.layout.addLayout(diffusion_row)

        diffusion_controls = QtWidgets.QVBoxLayout()
        self.layout_diffusion1 = diffusion_controls
        diffusion_row.addLayout(diffusion_controls)

        self.button_calculate = QtWidgets.QPushButton("Calculate Diffusion")
        diffusion_controls.addWidget(self.button_calculate)
        self.button_calculate.clicked.connect(self.calculateDiffusion)

        # second table, placed below the diffusion button
        self.tableWidget2 = QtWidgets.QTableWidget(0, 1, self)
        diffusion_controls.addWidget(self.tableWidget2)

        diffusion_plot_column = QtWidgets.QVBoxLayout()
        self.layout_diffusion_plot = diffusion_plot_column
        diffusion_row.addLayout(diffusion_plot_column)
        self.plot_diffusion = MatplotlibWidget(self)
        diffusion_plot_column.addWidget(self.plot_diffusion)
        diffusion_plot_column.addWidget(
            NavigationToolbar(self.plot_diffusion, self))

        # add a progress bar
        self.progressbar = QtWidgets.QProgressBar()
        self.layout.addWidget(self.progressbar)

        self.diffusionConstants = []
Example #8
0
    def __init__(self, *args, **kwargs):
        """Build the addon window and load and evaluate the data of the video.

        Registers the two ellipse marker types, creates the plot buttons, the
        plot canvas and the progress bar, and then reads the config and
        result file that belong to the first image of the database. The data
        is averaged per cell, filtered and extended with stress and strain.
        """
        # the base class is initialised exactly once
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # the single top-level layout of the window
        self.layout = QtWidgets.QVBoxLayout(self)

        # Check if the marker type is present
        self.marker_type_cell = self.db.setMarkerType("cell", "#0a2eff",
                                                      self.db.TYPE_Ellipse)
        self.marker_type_cell2 = self.db.setMarkerType("cell2", "#Fa2eff",
                                                       self.db.TYPE_Ellipse)
        self.cp.reloadTypes()

        self.loadData()

        # set the title
        self.setWindowTitle("DeformationCytometer - ClickPoints")

        # add export buttons
        layout = QtWidgets.QHBoxLayout()
        self.button_stressstrain = QtWidgets.QPushButton("stress-strain")
        self.button_stressstrain.clicked.connect(self.plot_stress_strain)
        layout.addWidget(self.button_stressstrain)

        self.button_stressy = QtWidgets.QPushButton("y-strain")
        self.button_stressy.clicked.connect(self.plot_y_strain)
        layout.addWidget(self.button_stressy)

        self.button_y_angle = QtWidgets.QPushButton("y-angle")
        self.button_y_angle.clicked.connect(self.plot_y_angle)
        layout.addWidget(self.button_y_angle)

        self.layout.addLayout(layout)

        # add a plot widget
        self.plot = MatplotlibWidget(self)
        self.layout.addWidget(self.plot)
        self.layout.addWidget(NavigationToolbar(self.plot, self))
        self.plot.figure.canvas.mpl_connect('button_press_event',
                                            self.button_press_callback)

        # add a progress bar
        self.progressbar = QtWidgets.QProgressBar()
        self.layout.addWidget(self.progressbar)

        # connect slots
        # self.signal_update_plot.connect(self.updatePlotImageEvent)
        # self.signal_plot_finished.connect(self.plotFinishedEvent)

        # initialize the table
        # self.updateTable()
        # self.selected = None

        # companion files of the video, derived from the first image
        filename = self.db.getImage(0).get_full_filename()
        config_path = filename.replace(".tif", "_config.txt")
        result_path = filename.replace(".tif", "_result.txt")
        print(config_path)
        self.config = getConfig(config_path)
        self.data = getData(result_path)

        getVelocity(self.data, self.config)

        # the center correction is best effort; a failure is reported instead
        # of being dropped silently
        try:
            correctCenter(self.data, self.config)
        except ValueError as err:
            print("center correction failed:", err)

        self.data = self.data.groupby(['cell_id']).mean()

        self.data = filterCells(self.data, self.config)
        self.data.reset_index(drop=True, inplace=True)

        getStressStrain(self.data, self.config)
Example #9
0
class Addon(clickpoints.Addon):
    """ClickPoints add-on for inspecting deformation cytometer results.

    The window offers three plot buttons, a matplotlib canvas and a progress
    bar.  Clicking into the plot jumps to the nearest plotted cell in
    ClickPoints, and changing the frame adds ellipse markers for the cells
    of that frame.
    """

    # per-cell result table; assigned in __init__ (indexed like a pandas
    # DataFrame by the methods of this class)
    data = None
    # raw result array read with np.genfromtxt in loadData()
    data2 = None
    # segmentation network, created on first use in detect()
    unet = None

    # Qt signals; no slot is connected to them in this class
    signal_update_plot = QtCore.Signal()
    signal_plot_finished = QtCore.Signal()
    # NOTE(review): the following attributes are not referenced by any method
    # of this class visible here -- presumably leftovers, confirm before use
    image_plot = None
    last_update = 0
    updating = False
    exporting = False
    exporting_index = 0
    def __init__(self, *args, **kwargs):
        """Build the add-on window and load the measurement data.

        All arguments are forwarded unchanged to ``clickpoints.Addon``.

        The constructor
        * registers the ellipse marker types "cell" and "cell2",
        * creates the plot buttons, the matplotlib canvas with its toolbar
          and a progress bar,
        * reads the "_config.txt" and "_result.txt" files that sit next to
          the first image of the database and stores the processed per-cell
          table in ``self.data``.
        """
        # The base class (and with it the Qt widget) must be initialised
        # exactly once, and a widget can only own a single top-level layout.
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # set the title and layout
        self.setWindowTitle("DeformationCytometer - ClickPoints")
        self.layout = QtWidgets.QVBoxLayout(self)

        # make sure both ellipse marker types are present
        self.marker_type_cell = self.db.setMarkerType("cell", "#0a2eff",
                                                      self.db.TYPE_Ellipse)
        self.marker_type_cell2 = self.db.setMarkerType("cell2", "#Fa2eff",
                                                       self.db.TYPE_Ellipse)
        self.cp.reloadTypes()

        # pixel size and raw result array for the currently shown image
        self.loadData()

        # add the plot buttons
        layout = QtWidgets.QHBoxLayout()
        self.button_stressstrain = QtWidgets.QPushButton("stress-strain")
        self.button_stressstrain.clicked.connect(self.plot_stress_strain)
        layout.addWidget(self.button_stressstrain)

        self.button_stressy = QtWidgets.QPushButton("y-strain")
        self.button_stressy.clicked.connect(self.plot_y_strain)
        layout.addWidget(self.button_stressy)

        self.button_y_angle = QtWidgets.QPushButton("y-angle")
        self.button_y_angle.clicked.connect(self.plot_y_angle)
        layout.addWidget(self.button_y_angle)

        self.layout.addLayout(layout)

        # add a plot widget; clicks into it select the nearest data point
        self.plot = MatplotlibWidget(self)
        self.layout.addWidget(self.plot)
        self.layout.addWidget(NavigationToolbar(self.plot, self))
        self.plot.figure.canvas.mpl_connect('button_press_event',
                                            self.button_press_callback)

        # add a progress bar
        self.progressbar = QtWidgets.QProgressBar()
        self.layout.addWidget(self.progressbar)

        # the config and result files are looked up next to the first image
        filename = self.db.getImage(0).get_full_filename()
        config_filename = filename.replace(".tif", "_config.txt")
        result_filename = filename.replace(".tif", "_result.txt")
        print(config_filename)
        self.config = getConfig(config_filename)
        self.data = getData(result_filename)

        # NOTE(review): the helpers below are called for their effect on
        # self.data (their return values are not used) -- presumably they
        # add columns in place; confirm against their definitions.
        getVelocity(self.data, self.config)

        try:
            correctCenter(self.data, self.config)
        except ValueError as err:
            # best effort: continue with the uncorrected data, but say so
            print("correctCenter failed, using uncorrected data:", err)

        # one averaged row per cell
        self.data = self.data.groupby(['cell_id']).mean()

        self.data = filterCells(self.data, self.config)
        # row labels become 0..n-1, i.e. identical to the row positions
        self.data.reset_index(drop=True, inplace=True)

        getStressStrain(self.data, self.config)

    def button_press_callback(self, event):
        """Jump to the cell whose plotted point is closest to a mouse click.

        Connected to the matplotlib 'button_press_event' of the plot canvas.
        Only left clicks inside the axes are handled.  The nearest point is
        searched in ``self.plot_data`` (the 2 x N array stored by the plot
        methods); ClickPoints is then moved to the frame and position of the
        corresponding row of ``self.data``.

        Args:
            event: the matplotlib mouse event.
        """
        # only react to the left mouse button
        if event.button != 1:
            return
        # if the user has not clicked on an axis do nothing
        if event.inaxes is None:
            return
        # nothing has been plotted yet (or the plot is empty), so there is
        # no point that could be selected
        plot_data = getattr(self, "plot_data", None)
        if plot_data is None or plot_data.size == 0:
            return

        xy = np.array([event.xdata, event.ydata])
        # divide each axis by its mean value so that both axes contribute
        # comparably to the distance even if their units differ
        scale = np.mean(plot_data, axis=1)
        distance = np.linalg.norm((plot_data - xy[:, None]) / scale[:, None],
                                  axis=0)
        print(plot_data.shape, xy[:, None].shape, distance.shape)
        nearest_point = int(np.argmin(distance))
        print("distance ", distance[nearest_point])

        # report the point as it was plotted; this is valid for every plot
        # type and does not depend on positional columns of the table
        nearest_xy = plot_data[:, nearest_point]
        print(np.linalg.norm(nearest_xy - xy))
        print("clicked", xy, nearest_xy[0], " ", nearest_xy[1],
              self.data.iloc[nearest_point])

        # jump to the frame in time
        self.cp.jumpToFrame(self.data.frames.iloc[nearest_point])
        # and to the xy position
        self.cp.centerOn(self.data.x.iloc[nearest_point],
                         self.data.y.iloc[nearest_point])

    def plot_stress_strain(self):
        """Show strain over stress as a density scatter plot.

        Reads the "stress" and "strain" columns of ``self.data`` and stores
        the plotted coordinates in ``self.plot_data`` so that a click into
        the plot can be mapped back to a cell.
        """
        stress = self.data.stress
        strain_values = self.data.strain

        axes = self.plot.axes
        axes.clear()
        plotDensityScatter(stress, strain_values, ax=axes)

        # 2 x N array (row 0: x values, row 1: y values) for the click lookup
        self.plot_data = np.array([stress, strain_values])
        axes.set_xlabel("stress")
        axes.set_ylabel("strain")
        axes.set_xlim(-10, 400)
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_y_strain(self):
        """Plot the strain of every cell over its y position.

        ``self.data`` is a table with named columns (it is grouped, iterated
        with ``iterrows`` and read via ``self.data.strain`` elsewhere in this
        class), so the columns are addressed by name instead of NumPy-style
        ``[:, i]`` indexing.  The plotted coordinates are stored in
        ``self.plot_data`` for the click lookup.
        """
        y = self.data.y
        strain_values = self.data.strain

        axes = self.plot.axes
        axes.clear()

        # 2 x N array (row 0: x values, row 1: y values) for the click lookup
        self.plot_data = np.array([y, strain_values])
        axes.plot(y, strain_values, "o")
        axes.set_xlabel("y")
        axes.set_ylabel("strain")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_y_angle(self):
        """Plot the ellipse angle of every cell over its y position.

        The "y" and "angle" columns of ``self.data`` are addressed by name,
        as the table does not support NumPy-style ``[:, i]`` indexing.  The
        plotted coordinates are stored in ``self.plot_data`` for the click
        lookup.
        """
        y = self.data.y
        angle = self.data.angle

        axes = self.plot.axes
        axes.clear()

        # 2 x N array (row 0: x values, row 1: y values) for the click lookup
        self.plot_data = np.array([y, angle])
        axes.plot(y, angle, "o")
        axes.set_xlabel("y")
        axes.set_ylabel("angle")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def export(self):
        """Placeholder without any effect; exporting is not implemented."""
        return None

    def buttonPressedEvent(self):
        """Show the add-on window.

        NOTE(review): this class defines ``buttonPressedEvent`` a second time
        at its end with the same body; the later definition is the one that
        Python binds, so this one is redundant.  Presumably ClickPoints calls
        the method when the add-on button is pressed -- confirm against
        ``clickpoints.Addon``.
        """
        self.show()

    def detect(self):
        """Segment the current image and store the found regions as ellipses.

        The image is normalised, passed through the U-Net (created on first
        use for the size of the current image) and thresholded at 0.5.  The
        resulting mask is written to the database, and every connected
        region of the mask is added as an ellipse marker of type "cell2".
        Does nothing if no image is currently available.
        """
        im = self.cp.getImage()
        if im is None:
            # no image to work on
            return
        # assumes a 2D (height, width) image, as the indexing below adds the
        # batch and channel axes -- TODO confirm for multi-channel images
        img = im.data
        if self.unet is None:
            self.unet = UNet((img.shape[0], img.shape[1], 1), 1, d=8)

        # normalise to zero mean and unit standard deviation; a constant
        # image has a standard deviation of 0 and is only mean-centred
        std = np.std(img)
        if std == 0:
            std = 1.0
        # the cast has to cover the whole normalised image, not just the
        # scalar standard deviation
        img = ((img - np.mean(img)) / std).astype(np.float32)

        prediction = self.unet.predict(img[None, :, :, None])[0, :, :, 0]
        prediction_mask = prediction > 0.5
        self.db.setMask(image=im, data=prediction_mask.astype(np.uint8))
        print(prediction_mask.shape)
        self.cp.reloadMask()
        print(prediction_mask)

        labeled = label(prediction_mask)

        # iterate over all detected regions
        for region in regionprops(labeled, img):
            y, x = region.centroid
            # convert the region orientation (radians) to the ellipse angle
            if region.orientation > 0:
                ellipse_angle = np.pi / 2 - region.orientation
            else:
                ellipse_angle = -np.pi / 2 - region.orientation
            self.db.setEllipse(image=im,
                               x=x,
                               y=y,
                               width=region.major_axis_length,
                               height=region.minor_axis_length,
                               angle=ellipse_angle * 180 / np.pi,
                               type=self.marker_type_cell2)

    def keyPressEvent(self, event):
        """Start the cell detection when the "G" key is pressed.

        Every key press is echoed to stdout together with the key code of
        "G"; all other keys have no further effect.
        """
        pressed_key = event.key()
        print(pressed_key, QtCore.Qt.Key_G)
        if pressed_key != QtCore.Qt.Key_G:
            return
        print("detect")
        self.detect()

    def loadData(self):
        """Load pixel size and raw result array for the current image, once.

        Reads the "_config.txt" file belonging to the current image to
        compute ``self.pixel_size`` and the "_result.txt" file to fill
        ``self.data2`` and ``self.frames``.  Returns without doing anything
        if the data has already been loaded or no image is available.
        """
        # Guard on the attribute this method actually fills.  Testing
        # self.data (which is assigned by __init__) would block loading for
        # good if no image was available during the first call.
        if self.data2 is not None:
            return
        im = self.cp.getImage()
        if im is None:
            return

        # NOTE(review): __init__ resolves the same kind of files through
        # get_full_filename(); im.filename is used here -- confirm that it
        # resolves to the right directory.
        config_filename = im.filename.replace(".tif", "_config.txt")
        result_filename = im.filename.replace(".tif", "_result.txt")

        config = configparser.ConfigParser()
        config.read(config_filename)

        # the first whitespace-separated token of each entry is the number
        magnification = float(config['MICROSCOPE']['objective'].split()[0])
        coupler = float(config['MICROSCOPE']['coupler'].split()[0])
        camera_pixel_size = float(
            config['CAMERA']['camera pixel size'].split()[0])

        # in micrometer
        self.pixel_size = camera_pixel_size / (magnification * coupler)

        # the first two lines of the result file are skipped as header
        self.data2 = np.genfromtxt(result_filename, dtype=float,
                                   skip_header=2)
        # the first column is used as the frame number of each row
        self.frames = self.data2[:, 0].astype("int")

    def frameChangedEvent(self):
        """Add ellipse markers for the cells of the current frame.

        Markers are only created when an image and the result table are
        available and the image does not have any ellipses yet.  Every row
        of ``self.data`` whose "frames" value equals the frame of the image
        becomes one ellipse of type "cell" with a text label.
        """
        self.loadData()
        im = self.cp.getImage()
        if im is None or self.data is None:
            return
        # images that already carry ellipses are left untouched
        if im.ellipses.count() != 0:
            return

        cells_in_frame = self.data[self.data.frames == im.frame]
        for _, element in cells_in_frame.iterrows():
            print("element")
            center_x = element.x
            center_y = element.y
            major_axis = element.long_axis
            minor_axis = element.short_axis
            ellipse_angle = element.angle

            # ratio of circumference of the binarized image to the
            # circumference of the ellipse
            Irregularity = element.irregularity
            # percentage of binary pixels within convex hull polygon
            Solidity = element.solidity

            # diameter of undeformed (circular) cell
            undeformed_diameter = np.sqrt(major_axis * minor_axis)
            strain = (major_axis - minor_axis) / undeformed_diameter

            label_text = f"timestamp {element.timestamp}\nstrain {strain:.3f}\nsolidity {Solidity:.2f}\nirreg. {Irregularity:.3f}"
            # the axes are divided by the pixel size (in micrometer) for the
            # marker size
            self.db.setEllipse(
                image=im,
                x=center_x,
                y=center_y,
                width=major_axis / self.pixel_size,
                height=minor_axis / self.pixel_size,
                angle=ellipse_angle,
                type=self.marker_type_cell,
                text=label_text,
            )

    def buttonPressedEvent(self):
        """Show the add-on window.

        NOTE(review): this is the second definition of ``buttonPressedEvent``
        in this class; it replaces the identical one defined earlier, which
        could be removed.  Presumably ClickPoints calls the method when the
        add-on button is pressed -- confirm against ``clickpoints.Addon``.
        """
        self.show()