Exemplo n.º 1
0
 def initUnet(self):
     """Create the UNet for the current image size using the weight file chosen in the selector."""
     weight_file = self.weight_selection.file
     print("loading weight file: ", weight_file)
     image_shape = self.cp.getImage().getShape()
     input_shape = (image_shape[0], image_shape[1], 1)
     self.unet = UNet(input_shape, 1, d=8, weights=weight_file)
Exemplo n.º 2
0
    def detect(self, im, img, frame):
        """Segment cells in one frame and store the result in ClickPoints.

        Args:
            im: ClickPoints image entry that the mask and ellipse markers are attached to.
            img: 2-D pixel array of that image.
            frame: frame index; stored on every cell and used for the timestamp lookup.

        Returns:
            (cells, probability_map): the cell dicts returned by mask_to_cells_edge,
            each extended in place with "frames", "timestamp" and "area", and the
            raw network output for the frame.
        """
        if self.unet is None:
            self.unet = UNet((img.shape[0], img.shape[1], 1), 1, d=8)
        # normalise to zero mean / unit std; the cast has to wrap the whole
        # expression, otherwise only the std scalar becomes float32
        img = ((img - np.mean(img)) / np.std(img)).astype(np.float32)
        timestamp = getTimestamp(self.vidcap, frame)

        probability_map = self.unet.predict(img[None, :, :, None])[0, :, :, 0]
        prediction_mask = probability_map > 0.5
        cells, prediction_mask = mask_to_cells_edge(prediction_mask,
                                                    img,
                                                    self.config,
                                                    self.rmin, {},
                                                    edge_dist=15,
                                                    return_mask=True)

        # annotate every cell in place with its frame, timestamp and ellipse area
        for cell in cells:
            cell.update({
                "frames": frame,
                "timestamp": timestamp,
                "area": np.pi * (cell["long_axis"] * cell["short_axis"]) / 4
            })

        self.db.setMask(image=im, data=prediction_mask.astype(np.uint8))
        self.db.deleteEllipses(type=self.marker_type_cell2, image=im)
        self.drawEllipse(pd.DataFrame(cells), self.marker_type_cell2)

        return cells, probability_map
Exemplo n.º 3
0
    def detect(self):
        """Segment the currently displayed ClickPoints image and mark every region.

        The thresholded network output is written as the image mask; each
        connected region of that mask is then stored as an ellipse marker of
        type ``self.marker_type_cell2``.
        """
        # fetch the current image once so mask and ellipses refer to the same entry
        im = self.cp.getImage()
        img = im.data
        if self.unet is None:
            self.unet = UNet((img.shape[0], img.shape[1], 1), 1, d=8)
        # normalise to zero mean / unit std; the cast has to wrap the whole
        # expression, otherwise only the std scalar becomes float32
        img = ((img - np.mean(img)) / np.std(img)).astype(np.float32)
        prediction_mask = self.unet.predict(img[None, :, :, None])[0, :, :, 0] > 0.5
        self.db.setMask(image=im, data=prediction_mask.astype(np.uint8))
        self.cp.reloadMask()

        labeled = label(prediction_mask)

        # iterate over all detected regions
        for region in regionprops(labeled, img):
            y, x = region.centroid
            # map the regionprops orientation onto the angle convention used by
            # the ellipse marker (converted to degrees below)
            if region.orientation > 0:
                ellipse_angle = np.pi / 2 - region.orientation
            else:
                ellipse_angle = -np.pi / 2 - region.orientation
            self.db.setEllipse(image=im,
                               x=x,
                               y=y,
                               width=region.major_axis_length,
                               height=region.minor_axis_length,
                               angle=ellipse_angle * 180 / np.pi,
                               type=self.marker_type_cell2)
    def __call__(self, data):
        """Pipeline stage: predict cell masks for one batch of images.

        Messages whose ``data["type"]`` is "start" or "end" are passed through
        unchanged. For any other message the image batch stored under
        ``data["data_info"]`` is normalised per image, run through the UNet, and
        the thresholded masks are written in place into the array stored under
        ``data["mask_info"]``. If ``self.write_clickpoints_masks`` is set, the
        masks are also written to the ".cdb" file derived from
        ``data["filename"]``. The message is then yielded to the next stage.
        """
        if data["type"] == "start" or data["type"] == "end":
            yield data
            return

        # heavy imports are only needed for actual image batches
        import numpy as np
        import tensorflow as tf
        from deformationcytometer.detection.includes.UNETmodel import UNet

        log("2detect", "prepare", 1, data["index"])

        def preprocess(img):
            # normalise every image of the batch (axis 0) to zero mean / unit std
            img = img - np.mean(img, axis=(1, 2))[:, None, None]
            img = img / np.std(img, axis=(1, 2))[:, None, None]
            return img.astype(np.float32)

        data_storage_numpy = self.data_storage.get_stored(data["data_info"])
        data_storage_mask_numpy = self.data_storage.get_stored(
            data["mask_info"])

        # (re)initialize the unet if there is none yet or the image size changed
        im = data_storage_numpy[0]
        if self.unet is None or self.unet.shape[:2] != im.shape:
            if self.network_weights is not None and self.network_weights != "":
                self.unet = UNet((im.shape[0], im.shape[1], 1),
                                 1,
                                 d=8,
                                 weights=self.network_weights)
            else:
                self.unet = UNet((im.shape[0], im.shape[1], 1), 1, d=8)

        # predict cell masks from the image batch
        im_batch = preprocess(data_storage_numpy)
        with tf.device('/GPU:0'):
            prediction_mask_batch = self.unet.predict(
                im_batch[:, :, :, None])[:, :, :, 0] > 0.5
        data_storage_mask_numpy[:] = prediction_mask_batch

        if self.write_clickpoints_masks:
            # only import clickpoints when masks are actually written
            import clickpoints
            with clickpoints.DataFile(data["filename"][:-4] + ".cdb") as cdb:
                # store one mask per frame index of this batch
                for mask, index in zip(data_storage_mask_numpy,
                                       range(data["index"],
                                             data["end_index"])):
                    cdb.setMask(frame=index, data=mask.astype(np.uint8))

        data["config"].update({"network": self.network_weights})

        log("2detect", "prepare", 0, data["index"])
        yield data
class Segmentation:
    """Hold a UNet and turn its predictions into segmentation masks and cell lists.

    The class can
    1) store a network,
    2) extract cells either from the predicted cell area (``mask_to_cells``)
       or from the cell boundary only (``mask_to_cells_edge``, chosen with
       ``edge_only=True``),
    3) return the segmentation mask.
    """

    def __init__(self, network_path=None, img_shape=None, pixel_size=None, r_min=None, frame_data=None, edge_dist=15,
                 channel_width=0, edge_only=False, return_mask=True, d=8, **kwargs):
        """
        Args:
            network_path: weight file loaded into the UNet.
            img_shape: image shape; its first two entries set the network input size.
            pixel_size: stored as ``self.pixel_size`` and ``config["pixel_size_m"]``.
            r_min: minimal radius passed to the cell extraction functions.
            frame_data: dict passed to the cell extraction functions; a fresh
                empty dict is used when None.
            edge_dist: passed through to the cell extraction functions.
            channel_width: stored as ``config["channel_width_px"]``.
            edge_only: use ``mask_to_cells_edge`` instead of ``mask_to_cells``.
            return_mask: ask ``mask_to_cells_edge`` to also return its mask.
            d: passed to UNet as ``d``.
            **kwargs: accepted and ignored.
        """
        self.unet = UNet((img_shape[0], img_shape[1], 1), 1, d=d)
        self.unet.load_weights(network_path)
        self.pixel_size = pixel_size
        self.r_min = r_min
        # the former test `frame_data is not frame_data` was always False, so a
        # caller-supplied frame_data was silently replaced by {}
        self.frame_data = frame_data if frame_data is not None else {}
        self.edge_dist = edge_dist
        self.config = {
            "channel_width_px": channel_width,
            "pixel_size_m": pixel_size,
        }
        self.edge_only = edge_only
        self.return_mask = return_mask

    def search_cells(self, prediction_mask, img):
        """Extract cells from a binary prediction mask.

        Returns:
            (prediction_mask, cells). The mask is the one returned by
            ``mask_to_cells_edge`` when both ``edge_only`` and ``return_mask``
            are set; otherwise it is the input mask unchanged.
        """
        if not self.edge_only:
            cells = mask_to_cells(prediction_mask, img, self.config, self.r_min, self.frame_data, self.edge_dist)
            return prediction_mask, cells

        result = mask_to_cells_edge(prediction_mask, img, self.config, self.r_min, self.frame_data,
                                    self.edge_dist, return_mask=self.return_mask)
        if self.return_mask:
            cells, prediction_mask = result
        else:
            cells = result
        return prediction_mask, cells

    def segmentation(self, img):
        """Segment a single image or a batch of images.

        Args:
            img: a 2-D image, or a 4-D batch with the image index on axis 0 and
                the single channel read from axis 3.

        Returns:
            (prediction_mask, cells). For a batch the mask is None and ``cells``
            holds the cells of all images; for a single image the mask comes
            from ``search_cells``.

        Raises:
            ValueError: if ``img`` is neither 2-D nor 4-D (a subclass of the
                previously raised ``Exception``).
        """
        # image batch
        if len(img.shape) == 4:
            img = preprocess_batch(img)
            prediction_mask = self.unet.predict(img) > 0.5
            cells = []
            for mask_i, img_i in zip(prediction_mask[:, :, :, 0], img[:, :, :, 0]):
                _, cells_i = self.search_cells(mask_i, img_i)
                cells.extend(cells_i)
            # per-image masks are not returned for batches
            return None, cells
        # single image
        if len(img.shape) == 2:
            # the cast has to wrap the whole expression, not just the std scalar
            img = ((img - np.mean(img)) / np.std(img)).astype(np.float32)
            prediction_mask = self.unet.predict(img[None, :, :, None])[0, :, :, 0] > 0.5
            return self.search_cells(prediction_mask, img)
        raise ValueError("incorrect image shape: img.shape == " + str(img.shape))
async def detect_masks():
    """Consume image batches, predict masks with a UNet and forward the results.

    Batches of ``(images, indices)`` are read from ``image_batch_queue`` until
    ``image_count`` images have been processed. Each batch is put on
    ``mask_queue`` as ``[images, indices, masks]``. The UNet is built lazily
    from the size of the first image received.
    """
    processed = 0
    network = None
    while processed < image_count:
        images_in_batch, indices_in_batch = await image_batch_queue.get()

        # build the network once the image size is known
        if network is None:
            first_image = images_in_batch[0]
            network = UNet((first_image.shape[0], first_image.shape[1], 1), 1, d=8)

        probabilities = network.predict(images_in_batch[:, :, :, None])
        masks = probabilities[:, :, :, 0] > 0.5

        processed += len(indices_in_batch)
        await mask_queue.put([images_in_batch, indices_in_batch, masks])
Exemplo n.º 8
0
def process_detect_masks(video, image_batch_queue, mask_queue):
    """Worker entry point: turn image batches into thresholded UNet masks.

    Args:
        video: not used by this function.
        image_batch_queue: consumed via ``queue_iterator``; each item is
            ``(batch_images, batch_image_indices)``.
        mask_queue: receives ``[batch_images, batch_image_indices, masks]`` per
            batch and a final ``0`` once the input is exhausted.
    """
    import logging
    import os

    # configure TensorFlow logging before UNETmodel is imported
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
    logging.getLogger('tensorflow').setLevel(logging.FATAL)
    from deformationcytometer.detection.includes.UNETmodel import UNet

    network = None
    for images, indices in queue_iterator(image_batch_queue):
        # build the network once the image size is known
        if network is None:
            first_image = images[0]
            network = UNet((first_image.shape[0], first_image.shape[1], 1), 1, d=8)

        masks = network.predict(images[:, :, :, None])[:, :, :, 0] > 0.5
        mask_queue.put([images, indices, masks])

    # presumably an end-of-stream marker for the consumer — confirm with the reader side
    mask_queue.put(0)
Exemplo n.º 9
0
class Addon(clickpoints.Addon):
    signal_update_plot = QtCore.Signal()
    signal_plot_finished = QtCore.Signal()
    disp_text_existing = "displaying existing data"
    disp_text_new = "displaying new data"

    def __init__(self, *args, **kwargs):
        """Build the addon window and load any existing results.

        Sets up:
        - the ClickPoints marker types and the probability-map layer;
        - the Qt widgets (weight file selection, detection buttons, filter
          thresholds, plot buttons, matplotlib plot, progress bar);
        - the result/config file paths derived from the opened file.

        It then loads existing and previous-addon data, draws the initial
        stress-strain plot, and displays the loaded cells as ellipses in the
        worker thread.
        """
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # qthread and signals for update cell detection and loading ellipse at add on launch
        self.thread = Worker(run_function=None)
        self.thread.thread_started.connect(self.start_pbar)
        self.thread.thread_finished.connect(self.finish_pbar)
        self.thread.thread_progress.connect(self.update_pbar)

        self.stop = False
        self.plot_data = np.array([[], []])
        self.unet = None
        self.layout = QtWidgets.QVBoxLayout(self)

        # Setting up marker Types
        self.marker_type_cell1 = self.db.setMarkerType("cell", "#0a2eff",
                                                       self.db.TYPE_Ellipse)
        self.marker_type_cell2 = self.db.setMarkerType("cell new", "#Fa2eff",
                                                       self.db.TYPE_Ellipse)
        self.cp.reloadTypes()

        # finding and setting path to store network probability map
        self.prob_folder = os.environ["CLICKPOINTS_TMP"]
        self.prob_path = self.db.setPath(self.prob_folder)
        self.prob_layer = self.db.setLayer("prob_map")

        # NOTE(review): the base-class initializer and the QVBoxLayout are created
        # a second time here (both already happened at the top of this method);
        # the repeated calls look redundant — confirm against clickpoints.Addon
        # before removing them.
        clickpoints.Addon.__init__(self, *args, **kwargs)

        # set the title and layout
        self.setWindowTitle("DeformationCytometer - ClickPoints")
        self.layout = QtWidgets.QVBoxLayout(self)

        # weight file selection
        self.weight_selection = SetFile(store_path,
                                        filetype="weight file (*.h5)")
        self.weight_selection.fileSeleted.connect(self.initUnet)
        self.layout.addLayout(self.weight_selection)

        # update segmentation
        # in range of frames
        seg_layout = QtWidgets.QHBoxLayout()
        self.update_detection_button = QtWidgets.QPushButton(
            "update cell detection")
        self.update_detection_button.setToolTip(
            tooltip_strings["update cell detection"])
        self.update_detection_button.clicked.connect(
            partial(self.start_threaded, self.detect_all))
        seg_layout.addWidget(self.update_detection_button, stretch=5)
        # on single frame
        self.update_single_detection_button = QtWidgets.QPushButton(
            "single detection")
        self.update_single_detection_button.setToolTip(
            tooltip_strings["single detection"])
        self.update_single_detection_button.clicked.connect(self.detect_single)
        seg_layout.addWidget(self.update_single_detection_button, stretch=1)
        self.layout.addLayout(seg_layout)

        # regularity and solidity thresholds
        validator = QtGui.QDoubleValidator(0, 100, 3)
        filter_layout = QtWidgets.QHBoxLayout()
        reg_label = QtWidgets.QLabel("irregularity")
        filter_layout.addWidget(reg_label)
        self.reg_box = QtWidgets.QLineEdit("1.06")
        self.reg_box.setToolTip(tooltip_strings["irregularity"])
        self.reg_box.setValidator(validator)
        filter_layout.addWidget(self.reg_box,
                                stretch=1)  # TODO implement text edited method
        sol_label = QtWidgets.QLabel("solidity")
        filter_layout.addWidget(sol_label)
        self.sol_box = QtWidgets.QLineEdit("0.96")
        self.sol_box.setToolTip(tooltip_strings["solidity"])
        self.sol_box.setValidator(validator)
        filter_layout.addWidget(self.sol_box, stretch=1)
        rmin_label = QtWidgets.QLabel("min radius [µm]")
        filter_layout.addWidget(rmin_label)
        self.rmin_box = QtWidgets.QLineEdit("6")
        self.rmin_box.setToolTip(tooltip_strings["min radius"])
        self.rmin_box.setValidator(validator)
        filter_layout.addWidget(self.rmin_box, stretch=1)
        filter_layout.addStretch(stretch=4)
        self.layout.addLayout(filter_layout)

        # plotting buttons
        layout = QtWidgets.QHBoxLayout()
        self.button_stressstrain = QtWidgets.QPushButton("stress-strain")
        self.button_stressstrain.clicked.connect(self.plot_stress_strain)
        self.button_stressstrain.setToolTip(tooltip_strings["stress-strain"])
        layout.addWidget(self.button_stressstrain)
        self.button_kpos = QtWidgets.QPushButton("k-pos")
        self.button_kpos.clicked.connect(self.plot_k_pos)
        self.button_kpos.setToolTip(tooltip_strings["k-pos"])
        layout.addWidget(self.button_kpos)
        self.button_reg_sol = QtWidgets.QPushButton("regularity-solidity")
        self.button_reg_sol.clicked.connect(self.plot_irreg)
        self.button_reg_sol.setToolTip(tooltip_strings["regularity-solidity"])
        layout.addWidget(self.button_reg_sol)
        self.button_kHist = QtWidgets.QPushButton("k histogram")
        self.button_kHist.clicked.connect(self.plot_kHist)
        self.button_kHist.setToolTip(tooltip_strings["k histogram"])
        layout.addWidget(self.button_kHist)
        self.button_alphaHist = QtWidgets.QPushButton("alpha histogram")
        self.button_alphaHist.clicked.connect(self.plot_alphaHist)
        self.button_alphaHist.setToolTip(tooltip_strings["alpha histogram"])
        layout.addWidget(self.button_alphaHist)
        self.button_kalpha = QtWidgets.QPushButton("k-alpha")
        self.button_kalpha.clicked.connect(self.plot_k_alpha)
        self.button_kalpha.setToolTip(tooltip_strings["k-alpha"])
        layout.addWidget(self.button_kalpha)
        # button to switch between display of loaded and newly generated data
        frame = QtWidgets.QFrame()  # vertical separating line
        frame.setFrameShape(QtWidgets.QFrame.VLine)
        frame.setLineWidth(3)
        layout.addWidget(frame)
        self.switch_data_button = QtWidgets.QPushButton(
            self.disp_text_existing)
        self.switch_data_button.clicked.connect(self.switch_display_data)
        self.switch_data_button.setToolTip(
            tooltip_strings[self.disp_text_existing])
        layout.addWidget(self.switch_data_button)
        self.layout.addLayout(layout)

        # matplotlib widgets to draw plots
        self.plot = MatplotlibWidget(self)
        self.plot_data = np.array([[], []])
        self.layout.addWidget(self.plot)
        self.layout.addWidget(NavigationToolbar(self.plot, self))
        self.plot.figure.canvas.mpl_connect('button_press_event',
                                            self.button_press_callback)

        # progress bar
        self.progressbar = QtWidgets.QProgressBar()
        self.layout.addWidget(self.progressbar)
        # progressbar label
        pbar_info_layout = QtWidgets.QHBoxLayout()
        self.pbarLable = QtWidgets.QLabel("")
        pbar_info_layout.addWidget(self.pbarLable, stretch=1)
        pbar_info_layout.addStretch(stretch=2)
        # button to stop thread execution
        self.stop_button = QtWidgets.QPushButton("stop")
        self.stop_button.clicked.connect(self.quit_thread)
        self.stop_button.setToolTip(tooltip_strings["stop"])
        pbar_info_layout.addWidget(self.stop_button, stretch=1)
        self.layout.addLayout(pbar_info_layout)

        # setting paths for data, config and image
        # identifying the full path to the video. If an existing ClickPoints database is opened, the path
        # is likely relative to the database location.
        self.filename = self.db.getImage(0).get_full_filename()
        if not os.path.isabs(self.filename):
            self.filename = str(
                Path(self.db._database_filename).parent.joinpath(
                    Path(self.filename)))

        self.config_file = self.constructFileNames("_config.txt")
        self.result_file = self.constructFileNames("_result.txt")
        self.addon_result_file = self.constructFileNames("_addon_result.txt")
        self.addon_evaluated_file = self.constructFileNames(
            "_addon_evaluated.csv")
        self.addon_config_file = self.constructFileNames("_addon_config.txt")
        self.vidcap = imageio.get_reader(self.filename)

        # reading in config and data
        self.data_all_existing = pd.DataFrame()
        self.data_mean_existing = pd.DataFrame()
        self.data_all_new = pd.DataFrame()
        self.data_mean_new = pd.DataFrame()
        if self.config_file.exists() and self.result_file.exists():
            self.config = getConfig(self.config_file)
            # ToDo: replace with a flag// also maybe some sort of "recalculation" feature
            # Trying to get regularity and solidity from the config
            if "irregularity" in self.config.keys(
            ) and "solidity" in self.config.keys():
                solidity_threshold = self.config["solidity"]
                irregularity_threshold = self.config["irregularity"]
            else:
                solidity_threshold = self.sol_threshold
                irregularity_threshold = self.reg_threshold
            # reading unfiltered data (from results.txt) and data from evaluated.csv
            # unfiltered data (self.data_all_existing) is used to display regularity and solidity scatter plot
            # everything else is from evaluated.csv (self.data_mean_existing)
            self.data_all_existing, self.data_mean_existing = self.load_data(
                self.result_file, solidity_threshold, irregularity_threshold)
        else:  # get a default config if no config is found
            self.config = getConfig(default_config_path)

        ## loading data from previous addon action
        has_addon_results = self.addon_result_file.exists()
        if has_addon_results:
            self.data_all_new, self.data_mean_new = self.load_data(
                self.addon_result_file, self.sol_threshold, self.reg_threshold)
        # create an addon config file
        # presence of this file allows easy implementation of the load_data and tank threading pipelines when
        # calculating new data. Without a config file next to the video the default config is copied, since
        # copying the missing file would raise FileNotFoundError.
        if not self.addon_config_file.exists():
            source_config = self.config_file if self.config_file.exists() else default_config_path
            shutil.copy(source_config, self.addon_config_file)

        self.plot_data_frame = self.data_all
        # initialize plot
        self.plot_stress_strain()

        # Displaying the loaded cells. This is in separate thread as it takes up to 20 seconds.
        # Old markers are removed before any display starts, so the deletion can no longer wipe
        # ellipses that a running display thread has just drawn.
        self.db.deleteEllipses(type=self.marker_type_cell1)
        self.db.deleteEllipses(type=self.marker_type_cell2)

        existing_cells = self.data_all_existing
        new_cells = self.data_all_new if has_addon_results else None

        def display_loaded_ellipses():
            # both marker types are drawn in one worker run: two back-to-back
            # start_threaded calls can drop one of them if Worker behaves like a
            # QThread (start() does nothing while the thread is still running)
            self.display_ellipses(type=self.marker_type_cell1,
                                  data=existing_cells)
            if new_cells is not None:
                self.display_ellipses(type=self.marker_type_cell2,
                                      data=new_cells)

        self.start_threaded(display_loaded_ellipses)

        print("loading finished")

    def constructFileNames(self, replace):
        """Return the path of a file next to ``self.filename`` with its extension replaced.

        Example: "/data/video.tif" with ``replace="_config.txt"`` gives
        ``Path("/data/video_config.txt")``.

        Only the trailing extension is replaced. The former ``str.replace`` also
        rewrote ".tif"/".cdb" occurring elsewhere in the path (e.g. in a folder
        name). It also returned None for any other extension, which made later
        ``.exists()`` calls fail.
        """
        stem, _extension = os.path.splitext(self.filename)
        return Path(stem + replace)

    # slots to update the progress bar from another thread (update cell detection and display_ellipse)
    @pyqtSlot(tuple, str)
    def start_pbar(self, prange, text):
        """Configure the progress bar when a worker run starts.

        Args:
            prange: ``(minimum, maximum)`` of the progress bar.
            text: label shown next to the progress bar.
        """
        lowest, highest = prange[0], prange[1]
        bar = self.progressbar
        bar.setMinimum(lowest)
        bar.setMaximum(highest)
        self.pbarLable.setText(text)

    @pyqtSlot(int)
    def update_pbar(self, value):
        """Move the progress bar to ``value``."""
        bar = self.progressbar
        bar.setValue(value)

    @pyqtSlot(int)
    def finish_pbar(self, value):
        """Set the final progress value and label the bar as finished."""
        bar = self.progressbar
        bar.setValue(value)
        done_text = "finished"
        self.pbarLable.setText(done_text)

    # Dynamic switch between existing and new data
    def switch_display_data(self):
        """Toggle the button between existing and new data and redraw the last plot."""
        showing_existing = self.switch_data_button.text() == self.disp_text_existing
        new_label = self.disp_text_new if showing_existing else self.disp_text_existing
        self.switch_data_button.setText(new_label)
        # redraw whichever plot function was used last
        self.plot_type()

    @property
    def data_all(self):
        """Unfiltered per-cell data of the set selected by the switch button.

        Returns None if the button shows neither known label.
        """
        label = self.switch_data_button.text()
        if label == self.disp_text_existing:
            return self.data_all_existing
        if label == self.disp_text_new:
            return self.data_all_new
        return None

    @property
    def data_mean(self):
        """Evaluated data of the set selected by the switch button.

        Returns None if the button shows neither known label.
        """
        label = self.switch_data_button.text()
        if label == self.disp_text_existing:
            return self.data_mean_existing
        if label == self.disp_text_new:
            return self.data_mean_new
        return None

    # solidity and regularity and rmin properties
    @property
    def sol_threshold(self):
        """Solidity threshold currently entered in the "solidity" field, as float."""
        entered = self.sol_box.text()
        return float(entered)

    @property
    def reg_threshold(self):
        """Irregularity threshold currently entered in the "irregularity" field, as float."""
        entered = self.reg_box.text()
        return float(entered)

    @property
    def rmin(self):
        """Value of the "min radius [µm]" field, as float."""
        entered = self.rmin_box.text()
        return float(entered)

    # handling thread entrance and exit
    def start_threaded(self, run_function):
        """Run ``run_function`` in the addon's worker thread.

        ``self.stop`` is reset first. Functions running in the thread (e.g.
        detect_all) poll it to leave their loops early.
        """
        self.stop = False
        worker = self.thread
        worker.run_function = run_function
        worker.start()

    def quit_thread(self):
        """Ask the running worker function to stop, then quit the worker thread."""
        self.stop = True  # polled by loops such as detect_all
        worker = self.thread
        worker.quit()

    def load_data(self, file, solidity_threshold, irregularity_threshold):
        """Load per-cell results and the matching evaluated data.

        Args:
            file: result file read with ``getData``.
            solidity_threshold: currently not used by this method.
            irregularity_threshold: currently not used by this method.

        Returns:
            (data_all, data_mean):
            - ``data_all`` is the per-cell data from ``file``, with an "area"
              column added if it was missing.
            - ``data_mean`` is what ``load_all_data_new`` returns for the
              "_evaluated_new.csv" file belonging to image 0 of the database.
            Two empty DataFrames are returned if ``file`` holds no data.
        """
        data_all = getData(file)
        # check for emptiness first: an empty result may lack the "long_axis"
        # column, in which case computing the area raised a KeyError before the
        # empty case could be reported
        if len(data_all) == 0:
            print("no data loaded from file '%s'" % file)
            return pd.DataFrame(), pd.DataFrame()
        if "area" not in data_all.keys():
            data_all["area"] = data_all["long_axis"] * data_all[
                "short_axis"] * np.pi / 4

        # use a "read sol from config" flag here
        # NOTE(review): the evaluated data is always read from the file next to
        # image 0, independent of ``file`` — confirm this is intended for addon results
        data_mean, _config_eval = load_all_data_new(
            self.db.getImage(0).get_full_filename().replace(
                ".tif", "_evaluated_new.csv"),
            do_group=False,
            do_excude=False)
        return data_all, data_mean

    # plotting functions
    # wrapper for all scatter plots; handles empty and data log10 transform
    def plot_scatter(self,
                     data,
                     type1,
                     type2,
                     funct1=doNothing,
                     funct2=doNothing):
        """Density scatter plot of two columns of ``data`` on the addon's axes.

        Args:
            data: table indexed by column name (e.g. a DataFrame).
            type1: column plotted on the x axis.
            type2: column plotted on the y axis.
            funct1: transform applied to the x column before plotting
                (e.g. np.log10).
            funct2: transform applied to the y column before plotting.

        On success, ``self.plot_data`` holds the plotted values as a 2 x N array
        and ``self.plot_data_frame`` holds ``data``. button_press_callback uses
        these to map a click back to a cell. A missing column, an all-NaN axis
        or a failing density estimate leaves the plot empty.
        """
        self.init_newPlot()
        try:
            x = funct1(data[type1])
            y = funct2(data[type2])
        except KeyError:
            self.plot.draw()
            return
        # nothing to plot if either axis holds only NaNs (both axes are checked)
        if np.all(np.isnan(x)) or np.all(np.isnan(y)):
            return
        try:
            plotDensityScatter(x,
                               y,
                               cmap='viridis',
                               alpha=1,
                               skip=1,
                               y_factor=1,
                               s=5,
                               levels=None,
                               loglog=False,
                               ax=self.plot.axes)
            self.plot_data = np.array([x, y])
            self.plot_data_frame = data
            self.plot.axes.set_xlabel(type1)
            self.plot.axes.set_ylabel(type2)
        except (ValueError, np.linalg.LinAlgError):
            print("kernel density estimation failed? not enough cells found?")
            return

    # clearing axis and plot.data
    def init_newPlot(self):
        """Forget the previously plotted points, clear the axes and redraw."""
        self.plot_data = np.array([[], []])  # empty 2 x 0: nothing to click on
        axes = self.plot.axes
        axes.clear()
        self.plot.draw()

    def plot_alphaHist(self):
        """Density histogram of "alpha_cell" from the currently selected evaluated data."""
        self.plot_type = self.plot_alphaHist
        self.init_newPlot()
        try:
            x = self.data_mean["alpha_cell"]
        except KeyError:
            return
        if not np.any(~np.isnan(x)):
            return
        plot_density_hist(x, ax=self.plot.axes, color="C1")
        # the ticks below span 0..1; the former xlim of (1, 1) was an empty range
        self.plot.axes.set_xlim((0, 1))
        self.plot.axes.xaxis.set_ticks(np.arange(0, 1, 0.2))
        self.plot.axes.grid()
        self.plot.draw()

    def plot_kHist(self):
        """Density histogram of log10("k_cell") from the currently selected evaluated data."""
        self.plot_type = self.plot_kHist
        self.init_newPlot()
        try:
            k_values = np.array(self.data_mean["k_cell"])
        except KeyError:
            return
        # nothing to show when every value is NaN (or there are none)
        if np.all(np.isnan(k_values)):
            return
        axes = self.plot.axes
        plot_density_hist(np.log10(k_values), ax=axes, color="C0")
        axes.set_xlim((1, 4))
        axes.xaxis.set_ticks(np.arange(5))
        axes.grid()
        self.plot.draw()

    def plot_k_alpha(self):
        """Scatter of log10 k over alpha from the currently selected evaluated data."""
        self.plot_type = self.plot_k_alpha
        self.plot_scatter(self.data_mean, "alpha_cell", "k_cell", funct2=np.log10)
        axes = self.plot.axes
        axes.set_ylabel("log10 k")
        axes.set_xlabel("alpha")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_k_size(self):
        """Scatter of "w_k_cell" over "area" from the evaluated data.

        Pass ``self.data_all`` instead for unfiltered data.
        """
        self.plot_type = self.plot_k_size
        self.plot_scatter(self.data_mean, "area", "w_k_cell")
        axes = self.plot.axes
        axes.set_ylabel("w_k")
        axes.set_xlabel("area")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_k_pos(self):
        """Scatter of "w_k_cell" over "rp" (radial position) from the evaluated data.

        Pass ``self.data_all`` instead for unfiltered data.
        """
        self.plot_type = self.plot_k_pos
        self.plot_scatter(self.data_mean, "rp", "w_k_cell")
        axes = self.plot.axes
        axes.set_ylabel("w_k")
        axes.set_xlabel("radiale position")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_irreg(self):
        """Unfiltered solidity/irregularity scatter with the current thresholds as dashed lines.

        Uses the per-cell data rather than the evaluated data, so badly
        segmented cells are easy to spot.
        """
        self.plot_type = self.plot_irreg
        self.plot_scatter(self.data_all, "solidity", "irregularity",
                          funct1=doNothing, funct2=doNothing)
        axes = self.plot.axes
        axes.axvline(self.sol_threshold, ls="--")
        axes.axhline(self.reg_threshold, ls="--")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_stress_strain(self):
        """Scatter of strain over stress from the currently selected evaluated data."""
        self.plot_type = self.plot_stress_strain
        self.plot_scatter(self.data_mean, "stress", "strain")
        axes = self.plot.axes
        axes.set_xlim((-10, 400))
        self.plot.figure.tight_layout()
        self.plot.draw()

    # Jump to cell in ClickPoints window when clicking near a data point in the scatter plot
    def button_press_callback(self, event):
        """Jump to the cell whose plotted point is nearest to a left click.

        Each axis is divided by its mean before measuring distances, so both
        axes contribute comparably. ClickPoints then jumps to the frame of that
        cell and centers the view on its position.
        """
        # only react to the left mouse button, inside the axes, with data plotted
        if event.button != 1 or event.inaxes is None or self.plot_data.size == 0:
            return
        xy = np.array([event.xdata, event.ydata])
        scale = np.nanmean(self.plot_data, axis=1)
        distance = np.linalg.norm(self.plot_data / scale[:, None] -
                                  xy[:, None] / scale[:, None],
                                  axis=0)
        nearest_point = np.nanargmin(distance)
        print("clicked", xy)
        # the columns of plot_data follow the row order of plot_data_frame, so
        # the lookup must be positional; label-based indexing picked the wrong
        # cell (or raised KeyError) when the index was not 0..N-1
        cell = self.plot_data_frame.iloc[nearest_point]
        self.cp.jumpToFrame(int(cell["frames"]))
        self.cp.centerOn(cell["x"], cell["y"])

    # not sure what this is for ^^
    def buttonPressedEvent(self):
        """Show the addon window."""
        # NOTE(review): presumably called by ClickPoints when the addon button is pressed — confirm
        window = self
        window.show()

    ## cell detection
    def initUnet(self):
        """Create the UNet for the current image size using the weight file chosen in the selector."""
        weight_file = self.weight_selection.file
        print("loading weight file: ", weight_file)
        image_shape = self.cp.getImage().getShape()
        input_shape = (image_shape[0], image_shape[1], 1)
        self.unet = UNet(input_shape, 1, d=8, weights=weight_file)

    # cell detection and evaluation on multiple frames
    def detect_all(self):
        """Run cell detection on every frame of the ClickPoints frame range.

        The detected cells become ``self.data_all_new`` and are saved to the
        addon result file. Afterwards tank treading and the evaluation are run,
        which fills ``self.data_mean_new``. Progress is reported through the
        worker-thread signals, and the loop stops early once ``self.stop`` is set.
        """
        start_frame, end_frame = self.cp.getFrameRange()[:2]
        info = "cell detection frame %d to %d" % (start_frame, end_frame)
        print(info)

        self.data_all_new = pd.DataFrame()
        self.data_mean_new = pd.DataFrame()
        self.db.deleteEllipses(type=self.marker_type_cell2)
        self.thread.thread_started.emit((start_frame, end_frame), info)

        # collect the cell dicts and build the DataFrame once: appending row by
        # row is quadratic, and DataFrame.append was removed in pandas 2.0
        detected_cells = []
        for frame in range(start_frame, end_frame):
            if self.stop:  # stop signal from "stop" button
                break
            im = self.db.getImage(frame=frame)
            cells, _probability_map = self.detect(im, im.data, frame)
            detected_cells.extend(cells)
            self.thread.thread_progress.emit(frame)
            # reloading the mask and ellipse display in ClickPoints// may not be necessary to do it in batches
            if frame % 10 == 0:
                self.cp.reloadMask()
                self.cp.reloadMarker()

        self.cp.reloadMask()
        self.cp.reloadMarker()
        self.data_all_new = pd.DataFrame(detected_cells)
        if self.data_all_new.empty:
            # without any cell the "timestamp"/"frames" columns do not exist
            print("no cells detected")
            self.thread.thread_finished.emit(end_frame)
            return
        self.data_all_new["timestamp"] = self.data_all_new["timestamp"].astype(
            float)
        self.data_all_new["frames"] = self.data_all_new["frames"].astype(int)
        # save data to addon_result.txt file
        save_cells_to_file(self.addon_result_file,
                           self.data_all_new.to_dict("records"))
        # tank threading
        print("tank threading")
        # catching error if no velocities could be identified (e.g. when only few cells are identified)
        try:
            self.tank_treading(self.data_all_new)
            # further evaluation
            print("evaluation")
            if self.addon_evaluated_file.exists():
                os.remove(self.addon_evaluated_file)
            self.data_all_new, self.data_mean_new = self.load_data(
                self.addon_result_file, self.sol_threshold, self.reg_threshold)
        except ValueError as e:
            print(e)
            self.data_mean_new = self.data_all_new.copy()
        self.thread.thread_finished.emit(end_frame)
        print("finished")

    # tank treading: saves results to an "_addon_tt.csv" file
    def tank_treading(self, data):
        """Measure the tank-treading speed of every tracked cell and save it to CSV.

        ``getVelocity`` and ``correctCenter`` are called on ``data`` first; their
        return values are ignored, so they presumably modify the frame in place
        (which the caller then also sees). Rows failing the solidity /
        irregularity thresholds are dropped, and for each ``cell_id`` the cropped
        cell images are tracked with ``doTracking``.

        The per-cell results (columns ``id``, ``tt``, ``tt_r2``) are written to
        ``<self.filename without its last 4 chars>_addon_tt.csv``.

        Args:
            data: per-detection DataFrame; needs at least the ``solidity``,
                ``irregularity``, ``cell_id`` and ``timestamp`` columns.
        """
        # TODO implement tank treading for non video database
        image_reader = CachedImageReader(str(self.filename))
        getVelocity(data, self.config)
        correctCenter(data, self.config)
        data = data[(data.solidity > self.sol_threshold)
                    & (data.irregularity < self.reg_threshold)]
        results = []
        for cell_id in pd.unique(data["cell_id"]):
            cell_data = data[data.cell_id == cell_id]
            crops, shifts, valid = getCroppedImages(image_reader, cell_data)
            # tracking needs more than one crop of the same cell
            if len(crops) <= 1:
                continue
            crops = crops[valid]
            # times relative to the first detection of this cell; the 1e-3
            # factor presumably converts ms -> s (TODO confirm)
            elapsed = (cell_data.timestamp - cell_data.iloc[0].timestamp) * 1e-3
            speed, r2 = doTracking(crops,
                                   data0=cell_data,
                                   times=np.array(elapsed),
                                   pixel_size=self.config["pixel_size"])
            results.append([cell_id, speed, r2])
        tt_results = pd.DataFrame(results, columns=["id", "tt", "tt_r2"])
        # str() so that slicing also works when self.filename is a pathlib.Path
        tt_results.to_csv(str(self.filename)[:-4] + "_addon_tt.csv")

    # Detection in single frame. Also saves the network probability map to the second ClickPoints layer
    # tif file of the probability map is saved to ClickPoints temporary folder.
    def detect_single(self):
        """Run detection on the current frame and add its probability map as an image layer.

        The network output (assumed to lie in [0, 1] — it is scaled by 255 and
        cast to uint8) is written as ``<frame>prob_map.tiff`` into
        ``self.prob_folder`` and registered in the database on
        ``self.prob_layer``.
        """
        im = self.cp.getImage()
        # use the same image object for pixel data and frame number instead of
        # fetching the current image a second time
        frame = im.frame
        _, probability_map = self.detect(im, im.data, frame)
        self.cp.reloadMask()
        self.cp.reloadMarker()

        # writing probability map as an additional layer
        filename = os.path.join(self.prob_folder, "%dprob_map.tiff" % frame)
        Image.fromarray(
            (probability_map * 255).astype(np.uint8)).save(filename)
        try:
            self.db.setImage(filename=filename,
                             sort_index=frame,
                             layer=self.prob_layer,
                             path=self.prob_path)
        except peewee.IntegrityError:
            # the image entry already exists; overwriting the file above is
            # sufficient, so this is deliberately ignored
            pass

    # Base detection function. Includes filters for objects without fully closed boundaries, objects close to
    # the horizontal image edge and objects with a radius smaller than self.rmin.
    def detect(self, im, img, frame):
        """Detect cells in one image, store the mask and draw the fitted ellipses.

        The U-Net is created lazily on first use with the image's shape. The
        image is standardized (zero mean, unit std) and cast to float32; network
        outputs above 0.5 form the mask passed to ``mask_to_cells_edge``.
        ``mask_to_cells_edge`` applies the boundary / edge / ``self.rmin`` filters
        with ``edge_dist=15``.

        Args:
            im: ClickPoints image entry the mask and ellipses are attached to.
            img: 2D pixel array of that image.
            frame: frame index, stored in every cell's ``"frames"`` field.

        Returns:
            ``(cells, probability_map)`` — the list of cell dicts and the raw
            network output for ``img``.
        """
        if self.unet is None:
            self.unet = UNet((img.shape[0], img.shape[1], 1), 1, d=8)
        # Standardize and cast the whole array to float32 (the original only
        # cast the std scalar). A constant image has std 0: only subtract the
        # mean then, instead of producing an all-NaN input.
        std = np.std(img)
        img = img - np.mean(img)
        if std > 0:
            img = img / std
        img = img.astype(np.float32)
        timestamp = getTimestamp(self.vidcap, frame)

        probability_map = self.unet.predict(img[None, :, :, None])[0, :, :, 0]
        prediction_mask = probability_map > 0.5
        cells, prediction_mask = mask_to_cells_edge(prediction_mask,
                                                    img,
                                                    self.config,
                                                    self.rmin, {},
                                                    edge_dist=15,
                                                    return_mask=True)

        # annotate every cell dict in place with frame, timestamp and the
        # ellipse area (pi * long_axis * short_axis / 4)
        for cell in cells:
            cell.update({
                "frames": frame,
                "timestamp": timestamp,
                "area": np.pi * (cell["long_axis"] * cell["short_axis"]) / 4
            })

        self.db.setMask(image=im, data=prediction_mask.astype(np.uint8))
        self.db.deleteEllipses(type=self.marker_type_cell2, image=im)
        self.drawEllipse(pd.DataFrame(cells), self.marker_type_cell2)

        return cells, probability_map

    def keyPressEvent(self, event):
        """Trigger single-frame detection when the "G" key is pressed; ignore other keys."""
        if event.key() != QtCore.Qt.Key_G:
            return
        print("detecting")
        self.detect_single()
        print("detecting finished")

    # Display all ellipses at launch
    def display_ellipses(self, type="cell", data=None):
        """Draw ellipse markers for all cells in ``data`` in batches of 200 rows.

        Progress is reported through the ``self.thread`` signals, and the marker
        display is reloaded after every batch. A truthy ``self.stop`` aborts
        between batches.

        Args:
            type: marker type passed on to ``drawEllipse``.
            data: DataFrame of cells; defaults to ``self.data_all_existing``.
        """
        batch_size = 200
        if data is None:
            data = self.data_all_existing
        if len(data) == 0:
            return

        self.thread.thread_started.emit((0, len(data)), "displaying ellipses")
        for block in range(0, len(data), batch_size):
            if self.stop:
                break
            # positional slicing stops at the end of the frame, so the last
            # (shorter) batch needs no special case
            self.drawEllipse(data.iloc[block:block + batch_size], type)
            self.thread.thread_progress.emit(block)
            self.cp.reloadMarker()  # not sure how thread safe this is
        self.thread.thread_finished.emit(len(data))

    # base ellipse display function
    def drawEllipse(self, data_block, type):
        """Add one ellipse marker per row of ``data_block`` to the database.

        Marker width/height are the axis lengths divided by
        ``self.config["pixel_size"]``. Each marker is labelled with the row's
        strain, solidity and irregularity. Nothing is written for an empty block.

        Args:
            data_block: DataFrame with ``frames``, ``x``, ``y``, ``long_axis``,
                ``short_axis``, ``angle``, ``solidity`` and ``irregularity``
                columns.
            type: marker type for the new ellipses.
        """
        if not len(data_block):
            return

        long_axis = data_block["long_axis"]
        short_axis = data_block["short_axis"]
        # strain = (long - short) / geometric mean of both axes
        strains = (long_axis - short_axis) / np.sqrt(long_axis * short_axis)
        labels = [
            f"strain {s:.3f}\nsolidity {sol:.2f}\nirreg. {irr:.3f}"
            for s, sol, irr in zip(strains, data_block['solidity'],
                                   data_block['irregularity'])
        ]
        self.db.setEllipses(
            frame=list(data_block["frames"]),
            x=list(data_block["x"]),
            y=list(data_block["y"]),
            width=list(long_axis / self.config["pixel_size"]),
            height=list(short_axis / self.config["pixel_size"]),
            angle=list(data_block["angle"]),
            type=type,
            text=labels,
        )
Exemplo n.º 10
0
# open a reader for `video` (defined earlier in the original script; presumably a file path — TODO confirm)
vidcap = imageio.get_reader(video)
# load the configuration belonging to `video` via the project's getConfig helper
config = getConfig(video)

# initialize the progressbar
with tqdm.tqdm(total=len(vidcap)) as progressbar:
    # iterate over image batches
    for batch_images, batch_image_indices in batch_iterator(
            vidcap, batch_size, preprocess):
        # update the description of the progressbar
        progressbar.set_description(f"{len(cells)} good cells")

        # initialize the unet in the first iteration
        if unet is None:
            im = batch_images[0]
            unet = UNet((im.shape[0], im.shape[1], 1),
                        1,
                        d=8,
                        weights=network_weight)

        # predict the images
        prediction_mask_batch = unet.predict(
            batch_images[:, :, :, None])[:, :, :, 0] > 0.5

        # iterate over the predicted images
        for batch_index in range(len(batch_image_indices)):
            image_index = batch_image_indices[batch_index]
            im = batch_images[batch_index]
            prediction_mask = prediction_mask_batch[batch_index]

            # get the images in the detected mask
            cells.extend(
                mask_to_cells_edge(prediction_mask,
Exemplo n.º 11
0
import os
# TF_CPP_MIN_LOG_LEVEL is read when TensorFlow's native library is loaded, so it
# must be set *before* tensorflow (or any module importing it) is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '0' # makes sure debug info is printed
import tensorflow

from deformationcytometer.detection.includes.UNETmodel import UNet
from Neural_Network.includes.training_functions import *
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import time

# U-Net for 1000x1000x1 inputs (same constructor arguments as the benchmark below)
u = UNet((1000,1000,1), 1, 8)

# one random 1000x1000 "image" as network input
rand_n = np.random.randint(0,10,(1, 1000, 1000,1))
from tensorflow.python.client import device_lib
# print the devices TensorFlow can see (the bare call discarded its result)
print(device_lib.list_local_devices())
print(get_available_gpus())

# a few predictions before timing; the first call presumably includes model build/warm-up time
for i in range(4):
    print(np.sum(u.predict(rand_n)))
ns = []
dts = []
for n in [1, 5, 10, 20, 30, 40, 100]:
    try:
        print(n)
        rand_n = np.random.randint(0, 10, (n, 1000, 1000, 1))
        get_available_gpus()
        t1 = time.time()
        for i in range(4):
            np.sum(u.predict(rand_n))
        t2 = time.time()
Exemplo n.º 12
0
class Addon(clickpoints.Addon):
    """ClickPoints add-on to browse deformation-cytometry results.

    Plots stress-strain, y-strain and y-angle scatter plots of the per-cell
    means read from ``<image>_result.txt``. Clicking a point jumps to that
    cell's frame and position. Pressing "G" runs the U-Net on the current image
    and draws the detected regions as "cell2" ellipses. On a frame change the
    stored cells of that frame are drawn as "cell" ellipses.
    """
    data = None
    data2 = None
    unet = None

    signal_update_plot = QtCore.Signal()
    signal_plot_finished = QtCore.Signal()
    image_plot = None
    last_update = 0
    updating = False
    exporting = False
    exporting_index = 0

    def __init__(self, *args, **kwargs):
        # initialize the base class and create the top-level layout exactly
        # once: a second QVBoxLayout(self) cannot become the widget's layout,
        # so everything added to it would not be laid out.
        clickpoints.Addon.__init__(self, *args, **kwargs)
        self.layout = QtWidgets.QVBoxLayout(self)

        # Check if the marker type is present
        self.marker_type_cell = self.db.setMarkerType("cell", "#0a2eff",
                                                      self.db.TYPE_Ellipse)
        self.marker_type_cell2 = self.db.setMarkerType("cell2", "#Fa2eff",
                                                       self.db.TYPE_Ellipse)
        self.cp.reloadTypes()

        self.loadData()

        # set the title
        self.setWindowTitle("DeformationCytometer - ClickPoints")

        # add plot buttons
        layout = QtWidgets.QHBoxLayout()
        self.button_stressstrain = QtWidgets.QPushButton("stress-strain")
        self.button_stressstrain.clicked.connect(self.plot_stress_strain)
        layout.addWidget(self.button_stressstrain)

        self.button_stressy = QtWidgets.QPushButton("y-strain")
        self.button_stressy.clicked.connect(self.plot_y_strain)
        layout.addWidget(self.button_stressy)

        self.button_y_angle = QtWidgets.QPushButton("y-angle")
        self.button_y_angle.clicked.connect(self.plot_y_angle)
        layout.addWidget(self.button_y_angle)

        self.layout.addLayout(layout)

        # add a plot widget
        self.plot = MatplotlibWidget(self)
        self.layout.addWidget(self.plot)
        self.layout.addWidget(NavigationToolbar(self.plot, self))
        self.plot.figure.canvas.mpl_connect('button_press_event',
                                            self.button_press_callback)

        # add a progress bar
        self.progressbar = QtWidgets.QProgressBar()
        self.layout.addWidget(self.progressbar)

        # load config and results stored next to the first image
        filename = self.db.getImage(0).get_full_filename()
        print(filename.replace(".tif", "_config.txt"))
        self.config = getConfig(filename.replace(".tif", "_config.txt"))
        self.data = getData(filename.replace(".tif", "_result.txt"))

        getVelocity(self.data, self.config)

        # center correction is optional: continue without it if it fails
        try:
            correctCenter(self.data, self.config)
        except ValueError as err:
            print("center correction skipped:", err)

        # one row per cell: average all detections that share a cell_id
        self.data = self.data.groupby(['cell_id']).mean()

        self.data = filterCells(self.data, self.config)
        self.data.reset_index(drop=True, inplace=True)

        # presumably adds the "stress" and "strain" columns read by the plots
        getStressStrain(self.data, self.config)

    def button_press_callback(self, event):
        """Jump to the cell whose plotted point is closest to a left click.

        Distances are computed on ``self.plot_data`` (set by the last plot
        call) after dividing each axis by its mean, so both axes contribute
        on a comparable scale.
        """
        # only react to the left mouse button
        if event.button != 1:
            return
        # if the user didn't click inside an axis do nothing
        if event.inaxes is None:
            return
        # click position in data coordinates
        xy = np.array([event.xdata, event.ydata])
        scale = np.mean(self.plot_data, axis=1)
        distance = np.linalg.norm(self.plot_data / scale[:, None] -
                                  xy[:, None] / scale[:, None],
                                  axis=0)
        print(self.plot_data.shape, xy[:, None].shape, distance.shape)
        nearest_dist = np.min(distance)
        print("distance ", nearest_dist)
        nearest_point = np.argmin(distance)

        # self.data has a reset 0..n-1 index, so positional access matches
        # the column order of self.plot_data
        cell = self.data.iloc[nearest_point]
        print("clicked", xy, cell["stress"], " ", cell["strain"], cell)

        # jump to the frame in time (frames is a per-cell mean, i.e. a float)
        self.cp.jumpToFrame(int(cell["frames"]))
        # and to the xy position
        self.cp.centerOn(cell["x"], cell["y"])

    def plot_stress_strain(self):
        """Density scatter plot of stress vs. strain for all cells."""
        self.plot.axes.clear()

        x = self.data.stress
        y = self.data.strain
        plotDensityScatter(x, y, ax=self.plot.axes)

        self.plot_data = np.array([x, y])
        self.plot.axes.set_xlabel("stress")
        self.plot.axes.set_ylabel("strain")
        self.plot.axes.set_xlim(-10, 400)
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_y_strain(self):
        """Scatter plot of y position vs. strain for all cells."""
        # self.data is a DataFrame, so use named columns (numpy-style
        # self.data[:, 2] indexing raises on a DataFrame)
        y = self.data.y
        strain_values = self.data.strain

        self.plot.axes.clear()

        self.plot_data = np.array([y, strain_values])
        self.plot.axes.plot(y, strain_values, "o")
        self.plot.axes.set_xlabel("y")
        self.plot.axes.set_ylabel("strain")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def plot_y_angle(self):
        """Scatter plot of y position vs. ellipse angle for all cells."""
        y = self.data.y
        angle = self.data.angle

        self.plot.axes.clear()

        self.plot_data = np.array([y, angle])
        self.plot.axes.plot(y, angle, "o")
        self.plot.axes.set_xlabel("y")
        self.plot.axes.set_ylabel("angle")
        self.plot.figure.tight_layout()
        self.plot.draw()

    def export(self):
        pass

    def detect(self):
        """Segment the current image with the U-Net and draw each region as a "cell2" ellipse.

        The binary mask (network output > 0.5) is stored as the image's mask.
        For every connected region an ellipse is added at the region centroid,
        with its major/minor axis lengths and orientation.
        """
        im = self.cp.getImage()
        img = im.data
        if self.unet is None:
            self.unet = UNet((img.shape[0], img.shape[1], 1), 1, d=8)
        # standardize and cast the whole array (not only the std) to float32;
        # for a constant image (std 0) only subtract the mean to avoid NaNs
        std = np.std(img)
        img = img - np.mean(img)
        if std > 0:
            img = img / std
        img = img.astype(np.float32)
        prediction_mask = self.unet.predict(img[None, :, :, None])[0, :, :,
                                                                   0] > 0.5
        self.db.setMask(image=im, data=prediction_mask.astype(np.uint8))
        print(prediction_mask.shape)
        self.cp.reloadMask()
        print(prediction_mask)

        labeled = label(prediction_mask)

        # iterate over all detected regions
        for region in regionprops(labeled, img):
            y, x = region.centroid
            # map the region orientation (radians) to the ellipse angle,
            # converted to degrees below
            if region.orientation > 0:
                ellipse_angle = np.pi / 2 - region.orientation
            else:
                ellipse_angle = -np.pi / 2 - region.orientation
            self.db.setEllipse(image=im,
                               x=x,
                               y=y,
                               width=region.major_axis_length,
                               height=region.minor_axis_length,
                               angle=ellipse_angle * 180 / np.pi,
                               type=self.marker_type_cell2)

    def keyPressEvent(self, event):
        print(event.key(), QtCore.Qt.Key_G)
        if event.key() == QtCore.Qt.Key_G:
            print("detect")
            self.detect()

    def loadData(self):
        """Load the pixel size and the raw result table for the current image, once.

        The pixel size (in micrometer) is the camera pixel size divided by
        objective magnification times coupler, all read from the image's
        ``_config.txt``. ``self.data2`` / ``self.frames`` hold the
        ``_result.txt`` table.
        """
        # guard on data2 (set below) rather than self.data: self.data is set in
        # __init__ from another source, which would prevent a retry when no
        # image was available on the first call (leaving pixel_size unset)
        if self.data2 is not None:
            return
        im = self.cp.getImage()
        if im is not None:
            config = configparser.ConfigParser()
            config.read(im.filename.replace(".tif", "_config.txt"))

            magnification = float(config['MICROSCOPE']['objective'].split()[0])
            coupler = float(config['MICROSCOPE']['coupler'].split()[0])
            camera_pixel_size = float(
                config['CAMERA']['camera pixel size'].split()[0])

            self.pixel_size = camera_pixel_size / (magnification * coupler
                                                   )  # in micrometer

            self.data2 = np.genfromtxt(im.filename.replace(
                ".tif", "_result.txt"),
                                       dtype=float,
                                       skip_header=2)
            self.frames = self.data2[:, 0].astype("int")

    def frameChangedEvent(self):
        """Draw the stored cells of the new frame as "cell" ellipses (if it has none yet)."""
        self.loadData()
        im = self.cp.getImage()
        if im is not None and self.data is not None and im.ellipses.count(
        ) == 0:
            # NOTE(review): self.data holds per-cell means, so `frames` is the
            # mean frame of each cell; only cells whose mean equals this frame
            # are drawn — confirm this is intended.
            for _, element in self.data[self.data.frames ==
                                        im.frame].iterrows():
                print("element")
                x_pos = element.x
                y_pos = element.y
                long = element.long_axis
                short = element.short_axis
                angle = element.angle

                Irregularity = element.irregularity  # ratio of circumference of the binarized image to the circumference of the ellipse
                Solidity = element.solidity  # percentage of binary pixels within convex hull polygon

                D = np.sqrt(long *
                            short)  # diameter of undeformed (circular) cell
                strain = (long - short) / D

                self.db.setEllipse(
                    image=im,
                    x=x_pos,
                    y=y_pos,
                    width=long / self.pixel_size,
                    height=short / self.pixel_size,
                    angle=angle,
                    type=self.marker_type_cell,
                    text=
                    f"timestamp {element.timestamp}\nstrain {strain:.3f}\nsolidity {Solidity:.2f}\nirreg. {Irregularity:.3f}",  #\nvelocity {element.velocity:.3f}\n {element.velocity_partner}"
                )

    def buttonPressedEvent(self):
        self.show()
Exemplo n.º 13
0
# string of the function name
metric_name = metric.__name__

# Random seed for reproducibility. np.random.seed below only seeds NumPy's global RNG
# (this seed is also passed to write_training_data); the network's own weight
# initialization is presumably not covered, so you should also load a fixed weight
# file to the U-net (parameter "weight_path")!
seed = 100

# Function that manipulates the ground truth
# First argument must be a 2-dimensional integer array, must also return a 2-dimensional integer array
# Our strategy is to train the network only to recognize the edge (3 pixel thickness) of cells.
# !!! Note that this approach is also relevant in the downstream cell detection process
# (e.g. in the "mask_to_cells_edge" function) and cannot simply be changed here.
mask_function = extract_edge

# Constructing the Neural Network and loading weight files
np.random.seed(seed)
unet = UNet((im_shape[0], im_shape[1], 1), 1, d=8, weights=weight_path)

# defining paths to write training data (inputs, ground-truth masks, and the "w_data" directory)
dir_x = os.path.join(dir_training_data, "X_data")
dir_y = os.path.join(dir_training_data, "y_data")
dir_w = os.path.join(dir_training_data, "w_data")

# loading training data from ClickPoints databases and optionally writing it to the disk.
if not use_existing_data:
    write_training_data(cdb_files_list,
                        dir_x,
                        dir_y,
                        seed,
                        test_size,
                        dir_w=dir_w,
                        final_shape=im_shape,
Exemplo n.º 14
0
    def __call__(self, data):
        """Batch incoming images, predict their cell masks and pass them on (generator).

        ``"start"`` messages are forwarded immediately. ``"image"`` messages are
        buffered until ``self.batch_size`` images are collected, or until an
        ``"end"`` message arrives with images still buffered. The batch is then
        run through the U-Net. Each image dict gets a boolean ``"mask"`` and the
        network weights in its ``"config"``, and the images are yielded in order.
        When both ``write_clickpoints_file`` and ``write_clickpoints_masks`` are
        set, the masks are also stored in the ``.cdb`` file derived from
        ``data["filename"]``. Finally an ``"end"`` message is forwarded as well.
        """
        import contextlib
        from deformationcytometer.detection.includes.UNETmodel import UNet
        import numpy as np
        from deformationcytometer.detection.includes.regionprops import preprocess, getTimestamp

        # initialize the batch if necessary
        if self.batch is None:
            self.batch = []

        if data["type"] == "start":
            yield data
            return

        if data["type"] == "image":
            # add the new data
            self.batch.append(data)

        # if the batch is full or all images of the .tif file have been loaded
        if len(self.batch) == self.batch_size or (data["type"] == "end" and len(self.batch)):
            log("2detect", "prepare", 1, self.batch[0]["index"])
            batch = self.batch
            self.batch = []

            # initialize the unet if necessary
            if self.unet is None:
                im = batch[0]["im"]
                self.unet = UNet((im.shape[0], im.shape[1], 1), 1, d=8, weights=self.network_weights)

            # predict cell masks from the image batch
            im_batch = np.dstack([item["im"] for item in batch])
            im_batch = preprocess(im_batch).transpose(2, 0, 1)
            prediction_mask_batch = self.unet.predict(im_batch[:, :, :, None])[:, :, :, 0] > 0.5

            import clickpoints
            # open the .cdb only when masks should be written; otherwise use a
            # no-op context whose target is None
            if write_clickpoints_file and write_clickpoints_masks:
                cdb_context = clickpoints.DataFile(data["filename"][:-4] + ".cdb")
            else:
                cdb_context = contextlib.nullcontext()
            with cdb_context as cdb:
                # iterate over all images and return them; a separate loop name
                # keeps `data` bound to the incoming message for the end check
                for i, item in enumerate(batch):
                    item["mask"] = prediction_mask_batch[i]
                    if cdb is not None:
                        cdb.setMask(frame=item["index"], data=item["mask"].astype(np.uint8))
                    item["config"].update({"network": self.network_weights})
                    log("2detect", "prepare", 0, item["index"])
                    yield item
                    if i < len(batch) - 1:
                        log("2detect", "prepare", 1, item["index"] + 1)

        if data["type"] == "end":
            # forward the end message like the start message; a `return data`
            # inside a generator only sets StopIteration.value, which a for-loop
            # consumer never sees
            yield data
Exemplo n.º 15
0
# second reader for `video` via getRawVideo; not used in the visible part of this script (TODO confirm purpose)
vidcap2 = getRawVideo(video)
# iterating the progress bar iterates the frames of `vidcap` (defined earlier in the script)
progressbar = tqdm.tqdm(vidcap)

# collected cell records (their count is shown in the progress bar)
cells = []

# read the first frame only to size the batch buffer
im = vidcap.get_data(0)
# reusable float32 buffer for up to `batch_size` preprocessed frames
batch_images = np.zeros([batch_size, im.shape[0], im.shape[1]],
                        dtype=np.float32)
# frame indices of the images currently stored in `batch_images`
batch_image_indices = []
# images per second of the last network prediction (shown in the progress bar)
ips = 0
for image_index, im in enumerate(progressbar):
    progressbar.set_description(
        f"{image_index} {len(cells)} good cells ({ips} ips)")

    if unet is None:
        unet = UNet((im.shape[0], im.shape[1], 1), 1, d=8)

    batch_images[len(batch_image_indices)] = preprocess(im)
    batch_image_indices.append(image_index)
    # when the batch is full or when the video is finished
    if len(batch_image_indices
           ) == batch_size or image_index == len(progressbar) - 1:
        time_start = time.time()
        with tf.device('/gpu:0'):
            prediction_mask_batch = unet.predict(
                batch_images[:len(batch_image_indices), :, :, None])[:, :, :,
                                                                     0] > 0.5
        ips = len(batch_image_indices) / (time.time() - time_start)

        for batch_index in range(len(batch_image_indices)):
            image_index = batch_image_indices[batch_index]