# Example #1
def main_ui():
    """Command-line entry point for editing a track database.

    Currently supports one operation, ``--divide``: cell ``--cell`` is
    divided at frame ``--at`` into the first two ids given by ``--into``,
    the new child ids reported by ``TrackDB.divide_cell`` are printed, and
    the database is saved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--trackdata", "-t", type=str)
    parser.add_argument("--divide", action="store_true", default=False)
    parser.add_argument("--cell", type=int)
    parser.add_argument("--at", type=int)
    parser.add_argument("--into", type=int, nargs="+")
    # TODO: further edit operations (split_cell_at_frame, check_consistency,
    # set_parent/set_child, set_cell_state, new_cell, view_tree, view_cell)
    # were sketched here as options but are not exposed by this CLI yet.
    args = parser.parse_args()

    # Validate before opening the database. divide_cell needs a cell, a
    # frame and two child ids; report a usage error instead of crashing
    # with IndexError/TypeError further down.
    if args.divide and (
        args.cell is None
        or args.at is None
        or args.into is None
        or len(args.into) < 2
    ):
        parser.error("--divide requires --cell, --at and two ids for --into")

    td = TrackDB(args.trackdata)

    if args.divide:
        child_a, child_b = args.into[:2]  # extra ids beyond two are ignored
        children = td.divide_cell(args.at, args.cell, child_a, child_b)
        print(
            f"At frame {args.at}, cell {args.cell} divided into "
            f"({child_a}->{children[0]}) and ({child_b}->{children[1]})"
        )
        td.save()
# Example #2
def main():
    """Add newly segmented ellipses for one frame to the track database.

    Reads the segmentation stored in the ``simp`` subdirectory next to the
    frame image and extracts ellipses from the selected tile region. Each
    ellipse is shifted back to full-frame coordinates. Only ellipses that
    share at most ``overlap_threshold`` pixels with cells already stored for
    that frame are added to the database, which is then saved.
    """
    trackdb = "/home/nmurphy/Dropbox/work/projects/bf_pulse/bf10_track.sqllite"
    filepaths = {
        "basedir": "/media/nmurphy/BF_Data_Orange/proc_data/iphox_movies/",
        "dataset": "BF10_timelapse",
        "lookat": "Column_2",
    }
    image_path = "{basedir}/{dataset}/{lookat}/{lookat}_t{{0:03d}}_ch00.tif".format(
        **filepaths)

    frame = 55
    width = 64  # tile edge length in pixels
    c_start, c_end = 0, 16  # column range, in tiles
    r_start, r_end = 0, 16  # row range, in tiles
    # New ellipses sharing more pixels than this with existing cells are dropped.
    overlap_threshold = 50

    this_image = image_path.format(frame)
    segmt = skimage.io.imread(
        os.path.join(
            os.path.dirname(this_image),
            "simp",
            os.path.basename(this_image).replace("_ch00", ""),
        ))
    td = TrackDB(trackdb)

    # Offsets and slices must use the same tile width, otherwise the
    # shifted ellipses would not line up with full-frame coordinates.
    r_offset = r_start * width
    c_offset = c_start * width
    rows = slice(r_offset, r_end * width)
    cols = slice(c_offset, c_end * width)
    # NOTE(review): segmt[0] and segmt[1] are passed as the two inputs of
    # get_ellipses; their meaning is defined there — confirm.
    ellipses = get_ellipses(segmt[0, rows, cols], segmt[1, rows, cols])

    ellipses_shift = {
        i: shift_ellipse(e, r_offset, c_offset)
        for i, e in enumerate(ellipses)
    }

    pre_cells = td.get_cells_in_frame(frame, states=None)  # ie all
    print(pre_cells)
    if pre_cells:
        old_seg = get_cell_mask_cellids(segmt.shape[1:], td, pre_cells, frame)
        new_seg = get_cell_mask_cells(segmt.shape[1:], ellipses_shift)
        # Labels of new-ellipse pixels that fall on an existing cell.
        # NOTE(review): this assumes get_cell_mask_cells labels each ellipse
        # with its dict key. Label 0 is excluded by `new_seg > 0`, so
        # ellipse 0 (from enumerate) can never be dropped — confirm.
        overlap = new_seg[(old_seg > 0) & (new_seg > 0)]
        counts = np.bincount(overlap)
        for c in np.unique(overlap):
            print(c, counts[c])
            if counts[c] > overlap_threshold:
                ellipses_shift.pop(c)

    td.add_new_ellipses_to_frame(ellipses_shift.values(), frame)
    td.save()
# Example #3
def main():
    """Build image/target training patches for frames 50-55 and save them.

    For every frame listed in ``frames``:
    1. Load and process the image.
    2. Build its target images from the track database.
    3. Crop both to the given tile rows/cols.
    4. Cut the crops into 64-pixel patches (4 rotations, 32-pixel step).
    5. Append the patches to the running X/Y arrays.

    The result is written to ``training_50-55_1.mat``.
    """
    trackdb = "/home/nmurphy/Dropbox/work/projects/bf_pulse/bf10_track.sqllite"
    filepaths = {
        "basedir": "/media/nmurphy/BF_Data_Orange/proc_data/iphox_movies/",
        "dataset": "BF10_timelapse",
        "lookat": "Column_2",
    }
    # frame -> ((row_start, row_end), (col_start, col_end)), in tile units
    frames = {
        50: ((0, 16), (13, 23)),
        51: ((0, 16), (13, 16)),
        52: ((0, 16), (13, 16)),
        53: ((0, 16), (13, 16)),
        54: ((0, 16), (13, 16)),
        55: ((0, 16), (13, 16)),
    }
    pattern = "{basedir}/{dataset}/{lookat}/{lookat}_t{{0:03d}}_ch00.tif"
    image_path = pattern.format(**filepaths)

    td = TrackDB(trackdb)

    tile = 64
    X_train, Y_train = None, None

    for frame, ((r_lo, r_hi), (c_lo, c_hi)) in frames.items():
        row_sl = slice(r_lo * tile, r_hi * tile)
        col_sl = slice(c_lo * tile, c_hi * tile)

        im = proc_image(skimage.io.imread(image_path.format(frame)))
        targets = make_target_images(frame, im.shape[:2], td)

        image_patches, _ = creat_data_sets_steps(
            im[row_sl, col_sl], size=tile, rotations=4, step=32
        )
        X_train = appender(X_train, image_patches)

        target_patches, _ = creat_data_sets_steps(
            targets[row_sl, col_sl, :], size=tile, rotations=4, step=32
        )
        Y_train = appender(Y_train, target_patches)

    scipy.io.savemat("training_50-55_1.mat", {"train_X": X_train, "train_Y": Y_train})
# Example #4
def main():
    """Track cells forward from the frame given by ``--frame``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--frame", type=int, required=True)
    frame = parser.parse_args().frame

    filepaths = dict(
        basedir="/media/nmurphy/BF_Data_Orange/proc_data/iphox_movies/",
        dataset="BF10_timelapse",
        lookat="Column_2",
    )
    pattern = "{basedir}/{dataset}/{lookat}/{lookat}_t{{0:03d}}_ch00.tif"
    image_path = pattern.format(**filepaths)

    db_path = "/home/nmurphy/Dropbox/work/projects/bf_pulse/bf10_track.sqllite"
    track_cells_from_frame(TrackDB(db_path), image_path, frame)
# Example #5
    def __init__(
        self,
        filepattern,
        track_path,
        image_start=1,
        viewmethod="gauss",
        vmax=1.3,
        image_range=(None, None),
        just_suspicious_cells=False,
    ):
        """Load the track database and build the side-by-side review figure.

        Args:
            filepattern: Path pattern for the frame images, formatted with a
                frame number.
            track_path: Path passed to ``TrackDB``.
            image_start: First frame to display.
            viewmethod: Key into ``self.view_methods``; only ``"gauss"``
                (Gaussian blur, sigma 1.1) is defined.
            vmax: Stored as ``self.vmaxs``.
            image_range: Passed to ``self._get_image_range``.
            just_suspicious_cells: Stored as ``self.just_suspicious_cells``;
                presumably used to filter which cells are reviewed — confirm.
        """
        self.filepattern = filepattern
        # Regex form of the pattern: every "{N:Md}" field captures a number.
        self.parse_re = re.compile(re.sub(r"{\d?:\d+d}", r"(\d+)", self.filepattern))
        self.image_range = self._get_image_range(image_range)

        self.red_chan = "ch00"
        self.green_chan = "ch01"
        self.blue_chan = "ch02"

        self.color_mode = "trackstatus"
        self.just_suspicious_cells = just_suspicious_cells

        self.cell_interactor_style = {
            "edgecolor": "white",
            "facecolor": "none",
            "linewidth": 1,
        }

        self.trackdata = TrackDB(track_path)

        print(self.filepattern.format(self.image_range[0]))
        init_shape = get_shape(self.filepattern.format(self.image_range[0]))
        self.vmaxs = vmax
        self.cells = self.trackdata.get_cell_list()
        self.current_cell_id = 0
        self.current_image = image_start
        # name -> (dimensionality, transform applied to each raw frame)
        self.view_methods = {"gauss": (2, lambda x: skimage.filters.gaussian(x, 1.1))}
        dim, self.view_method = self.view_methods[viewmethod]
        # A 3-dimensional view adds a trailing channel axis to the image size.
        self.img_size = init_shape + (dim,) if dim == 3 else init_shape

        #######################
        ## Setting up GUI
        #######################
        self.fig = plt.figure()
        self.cmrand = matplotlib.colors.ListedColormap(np.random.rand(10000, 3))

        # Rebind matplotlib's default shortcuts so single letters stay free.
        plt.rcParams["keymap.save"] = ["ctrl+s"]  # free up s
        plt.rcParams["keymap.fullscreen"] = ["ctrl+f"]  # free up f
        plt.rcParams["keymap.home"] = [""]  # free up a and r
        plt.rcParams["keymap.back"] = ["backspace"]
        plt.rcParams["keymap.grid"] = [""]  # g

        self.number_of_steps = 2

        # Top half: one axis per consecutive frame, side by side.
        self.ax_images = [
            self.fig.add_axes(
                [(i / self.number_of_steps), 0.5, (1 / self.number_of_steps), 0.45]
            )
            for i in range(self.number_of_steps)
        ]

        # Bottom half: one overlap axis per consecutive frame pair, shifted by
        # half a panel so it sits between the two frames it compares.
        shift = (1 / self.number_of_steps) * 0.5
        self.ax_overlaps = [
            self.fig.add_axes(
                [shift + (i / self.number_of_steps), 0.0, 1 / self.number_of_steps, 0.5]
            )
            for i in range(self.number_of_steps - 1)
        ]

        # Independent placeholder arrays (``[arr] * n`` would alias one array).
        self.bg_images = [
            np.zeros(init_shape, dtype=np.uint16) for _ in range(self.number_of_steps)
        ]
        self.overlap_images = [
            np.zeros((*init_shape, 3), dtype=np.uint16)
            for _ in range(self.number_of_steps - 1)
        ]

        imgcmap = "hot"

        self.art_images = [
            ax.imshow(bg, cmap=plt.get_cmap(imgcmap), vmin=0, interpolation="nearest")
            for ax, bg in zip(self.ax_images, self.bg_images)
        ]

        self.art_overlaps = [
            ax.imshow(bg, cmap=plt.get_cmap(imgcmap), vmin=0, interpolation="nearest")
            for ax, bg in zip(self.ax_overlaps, self.overlap_images)
        ]

        # Link every view to the first image axis so that panning/zooming one
        # moves them all. Axes.sharex/sharey replaces the deprecated (and,
        # since matplotlib 3.8, immutable) get_shared_x_axes().join API.
        leader = self.ax_images[0]
        for follower in self.ax_images[1:] + self.ax_overlaps:
            follower.sharex(leader)
            follower.sharey(leader)

        self.main_bg_ellipses = [None] * self.number_of_steps
        self.main_ov_ellipses = [None] * (self.number_of_steps - 1)
        self.ui_selectors = []

        self.move_to_image(self.current_image)
        self.go_to_next_unapproved()
# Example #6
class State:
    """Interactive matplotlib viewer for reviewing cell tracking.

    The figure has two rows:
    - Top: ``number_of_steps`` consecutive frames side by side.
    - Bottom: an overlap image for each consecutive pair, aligned on the
      current cell.

    Keyboard shortcuts (see ``on_key_press``) approve or reject the track of
    the current cell in the next frame, save edited ellipses, and move
    between frames and cells.
    """

    def __init__(
        self,
        filepattern,
        track_path,
        image_start=1,
        viewmethod="gauss",
        vmax=1.3,
        image_range=(None, None),
        just_suspicious_cells=False,
    ):
        """Load the track database, build the figure and show the first cell to review.

        Args:
            filepattern: Path pattern for the frame images. It is formatted
                with the frame number, and in ``move_to_image`` a second time
                with ``"r"``.
            track_path: Path passed to ``TrackDB``.
            image_start: First frame to display.
            viewmethod: Key into ``self.view_methods``; only ``"gauss"``
                (Gaussian blur, sigma 1.1) is defined.
            vmax: Factor applied to each frame's maximum to set its colour limit.
            image_range: ``(None, None)`` to discover frames on disk, or
                ``(start, end)`` (see ``_get_image_range``).
            just_suspicious_cells: If True, ``go_to_next_unapproved`` only
                stops at cells whose length or width changed by more than 5.
        """
        self.filepattern = filepattern
        # Regex form of the pattern: every "{N:Md}" field captures a number.
        self.parse_re = re.compile(re.sub(r"{\d?:\d+d}", r"(\d+)", self.filepattern))
        self.image_range = self._get_image_range(image_range)

        self.red_chan = "ch00"
        self.green_chan = "ch01"
        self.blue_chan = "ch02"

        self.color_mode = "trackstatus"
        self.just_suspicious_cells = just_suspicious_cells

        self.cell_interactor_style = {
            "edgecolor": "white",
            "facecolor": "none",
            "linewidth": 1,
        }

        self.trackdata = TrackDB(track_path)

        print(self.filepattern.format(self.image_range[0]))
        init_shape = get_shape(self.filepattern.format(self.image_range[0]))
        self.vmaxs = vmax
        self.cells = self.trackdata.get_cell_list()
        # 0 is a placeholder; go_to_next_unapproved starts from the first
        # cell when the current id is not in self.cells.
        self.current_cell_id = 0
        self.current_image = image_start
        # name -> (dimensionality, transform applied to each raw frame)
        self.view_methods = {"gauss": (2, lambda x: skimage.filters.gaussian(x, 1.1))}
        dim, self.view_method = self.view_methods[viewmethod]
        # A 3-dimensional view adds a trailing channel axis to the image size.
        self.img_size = init_shape + (dim,) if dim == 3 else init_shape

        #######################
        ## Setting up GUI
        #######################
        self.fig = plt.figure()
        self.cmrand = matplotlib.colors.ListedColormap(np.random.rand(10000, 3))

        # Rebind matplotlib's default shortcuts so the letters used by
        # on_key_press reach it.
        plt.rcParams["keymap.save"] = ["ctrl+s"]  # free up s
        plt.rcParams["keymap.fullscreen"] = ["ctrl+f"]  # free up f
        plt.rcParams["keymap.home"] = [""]  # free up a and r
        plt.rcParams["keymap.back"] = ["backspace"]
        plt.rcParams["keymap.grid"] = [""]  # g

        self.number_of_steps = 2

        # Top half: one axis per consecutive frame, side by side.
        self.ax_images = [
            self.fig.add_axes(
                [(i / self.number_of_steps), 0.5, (1 / self.number_of_steps), 0.45]
            )
            for i in range(self.number_of_steps)
        ]

        # Bottom half: one overlap axis per consecutive frame pair, shifted by
        # half a panel so it sits between the two frames it compares.
        shift = (1 / self.number_of_steps) * 0.5
        self.ax_overlaps = [
            self.fig.add_axes(
                [shift + (i / self.number_of_steps), 0.0, 1 / self.number_of_steps, 0.5]
            )
            for i in range(self.number_of_steps - 1)
        ]

        # Independent placeholder arrays (``[arr] * n`` would alias one array).
        self.bg_images = [
            np.zeros(init_shape, dtype=np.uint16) for _ in range(self.number_of_steps)
        ]
        self.overlap_images = [
            np.zeros((*init_shape, 3), dtype=np.uint16)
            for _ in range(self.number_of_steps - 1)
        ]

        imgcmap = "hot"

        self.art_images = [
            ax.imshow(bg, cmap=plt.get_cmap(imgcmap), vmin=0, interpolation="nearest")
            for ax, bg in zip(self.ax_images, self.bg_images)
        ]

        self.art_overlaps = [
            ax.imshow(bg, cmap=plt.get_cmap(imgcmap), vmin=0, interpolation="nearest")
            for ax, bg in zip(self.ax_overlaps, self.overlap_images)
        ]

        # Link every view to the first image axis so that panning/zooming one
        # moves them all. Axes.sharex/sharey replaces the deprecated (and,
        # since matplotlib 3.8, immutable) get_shared_x_axes().join API.
        leader = self.ax_images[0]
        for follower in self.ax_images[1:] + self.ax_overlaps:
            follower.sharex(leader)
            follower.sharey(leader)

        self.main_bg_ellipses = [None] * self.number_of_steps
        self.main_ov_ellipses = [None] * (self.number_of_steps - 1)
        self.ui_selectors = []

        self.move_to_image(self.current_image)
        self.go_to_next_unapproved()

    def _get_image_range(self, image_range):
        """Return the sorted list of frame numbers that can be displayed.

        Args:
            image_range: ``(None, None)`` to discover frames by globbing for
                files that match ``self.filepattern``. Otherwise a
                ``(start, end)`` pair; ``end`` is exclusive, as in ``range``.
                NOTE(review): confirm callers do not expect an inclusive end.

        Returns:
            A list of frame numbers.
        """
        if image_range == (None, None):
            prefix = self.filepattern.split("{")[0]
            suffix = self.filepattern.split("}")[1]
            frames = []
            for path in glob(prefix + "*" + suffix):
                match = self.parse_re.match(path)
                # The glob wildcard is looser than the number regex; skip
                # files that only share the prefix/suffix.
                if match is not None:
                    frames.append(int(match.groups()[0]))
            return sorted(frames)
        start, end = image_range
        return list(range(start, end))

    def move_ui_to_image(self, imagen):
        """Show frame ``imagen`` and re-center every axis on the current cell."""
        self.move_to_image(imagen)
        self.move_to_cell_all_axes(self.current_cell_id)

        self.update_ui()

    def make_title(self, frame, status):
        """Return the axis title ``"Image:<frame> Cell#<current id> **<status>**"``."""
        vals = (frame, status, self.current_cell_id, len(self.cells))
        return "Image:{0} Cell#{2} **{1}**".format(*vals)

    def move_to_image(self, number):
        """Load frames ``number`` .. ``number + number_of_steps - 1`` into the top axes.

        Returns:
            True on success. If ``number`` is not in ``self.image_range``, a
            message is printed and None is returned.
        """
        if number not in self.image_range:
            print(
                "{0} is not in the list of images {1}".format(number, self.image_range)
            )
            return None

        self.current_image = number
        self.bg_images = [
            self.view_method(tifffile.imread(self.filepattern.format(n).format("r")))
            for n in range(number, number + self.number_of_steps)
        ]

        for i, (art, ax, bg) in enumerate(
            zip(self.art_images, self.ax_images, self.bg_images)
        ):
            art.set_data(bg)
            art.set_clim(vmax=bg.max() * self.vmaxs)
            # make_title reads the frame number back from this attribute.
            ax.name = number + i

        self.fig.canvas.draw_idle()
        return True

    def go_to_next_unapproved(self):
        """Select the next cell, in ``self.cells`` order, that still needs review.

        A cell is a candidate when it exists in frame ``current_image + 1``
        with ``trackstatus == "auto"``. With ``just_suspicious_cells``, only
        candidates whose length or width changed by more than 5 since
        ``current_image`` are shown.

        Returns:
            True if a cell was selected, False when no cell is left.
        """
        try:
            start = self.cells.index(self.current_cell_id) + 1
        except ValueError:
            # Current id (e.g. the initial placeholder 0) is not a tracked
            # cell: scan from the beginning instead of crashing.
            start = 0

        for next_cell in self.cells[start:]:
            print("checking cell", next_cell)
            try:
                cell_props = self.trackdata.get_cell_properties(
                    self.current_image + 1, next_cell
                )
            except Exception:
                print("It doesnt exist in the next frame", next_cell)
                continue

            if cell_props["trackstatus"] != "auto":
                print("Its got no status or is null skipping", next_cell)
                continue

            if self.just_suspicious_cells:
                how_it_was = self.trackdata.get_cell_properties(
                    self.current_image, next_cell
                )
                diff_len = abs(how_it_was["length"] - cell_props["length"])
                diff_width = abs(how_it_was["width"] - cell_props["width"])
                print(f"{diff_len}: length_diff")
                print(f"{diff_width}: width_diff")
                look_at_it = diff_len > 5 or diff_width > 5
            else:
                look_at_it = True

            if not look_at_it:
                print(f"{next_cell} looks too similar, skipping")
                continue

            print("Lets look at ", next_cell)
            self.move_to_cell_all_axes(next_cell)
            return True

        print("No cells left")
        return False

    def update_track_data(self, interactive_cell, frame):
        """Write the geometry of an edited cell back into the track data.

        Does nothing when ``interactive_cell`` is None. Otherwise the
        interactor's ``get_position_props()`` plus ``{"status": "checked"}``
        are stored for ``self.current_cell_id`` at ``frame``.
        """
        if interactive_cell is not None:
            print("updating the data structure")
            # NOTE(review): reads via get_cell_params and sets a "status" key,
            # while the rest of the class uses get_cell_properties and
            # "trackstatus" — confirm both are intended.
            print("OLD", self.trackdata.get_cell_params(frame, self.current_cell_id))
            properties = interactive_cell.get_position_props()
            properties.update({"status": "checked"})
            print("GUI", properties)
            self.trackdata.set_cell_properties(frame, self.current_cell_id, properties)
            print("Saved", self.trackdata.get_cell_params(frame, self.current_cell_id))

    def move_to_cell_all_axes(self, cell):
        """Make ``cell`` current: draw its ellipses, rebuild overlaps, and zoom on it."""
        frame = self.current_image
        self.current_cell_id = cell

        def get_cell(f, c):
            # Properties of cell c in frame f, or None if it is not there.
            try:
                return self.trackdata.get_cell_properties(f, c)
            except Exception:
                return None

        current_cells = [get_cell(frame + i, cell) for i in range(self.number_of_steps)]

        for i, (ax, cp) in enumerate(zip(self.ax_images, current_cells)):
            try:
                self.main_bg_ellipses[i].remove()
            except AttributeError:
                pass  # empty slot (None)
            # Clear the slot so a removed interactor is never removed twice or
            # saved by save_segmentation when the cell is absent from frame i.
            self.main_bg_ellipses[i] = None
            if cp is not None:
                ax.set_title(self.make_title(ax.name, cp["trackstatus"]))
                ellipse_data = self.trackdata.cell_properties_to_params(cp)
                self.main_bg_ellipses[i] = cell_editor.CellInteractor(
                    ax, *ellipse_data, **self.cell_interactor_style
                )

        def make_overlap_image(im1, im2, cel1, cel2):
            """Stack im1 (red) and im2 (green) into an 8-bit RGB image.

            im2 is rolled so that cel2's centre lands on cel1's. If the cell
            is missing from either frame, im2 is left unshifted.
            """
            if cel1 is not None and cel2 is not None:
                dy = int(cel1["row"] - cel2["row"])
                dx = int(cel1["col"] - cel2["col"])
            else:
                dy = dx = 0
            shifted = np.roll(im2, (dy, dx), axis=(0, 1))
            joined = np.dstack([im1, shifted, np.zeros_like(im1)])
            return skimage.exposure.rescale_intensity(
                joined, in_range=(0, np.max(joined)), out_range=(0, 255)
            ).astype(np.uint8)

        self.overlap_images = [
            make_overlap_image(im1, im2, cel1, cel2)
            for im1, im2, cel1, cel2 in zip(
                self.bg_images[:-1],
                self.bg_images[1:],
                current_cells[:-1],
                current_cells[1:],
            )
        ]
        for art, overlap in zip(self.art_overlaps, self.overlap_images):
            art.set_data(overlap)

        # Center all views on the first frame that contains the cell.
        anchor = next((c for c in current_cells if c is not None), None)
        if anchor is not None:
            z_width = 50
            for a in self.ax_images + self.ax_overlaps:
                a.set_xlim(anchor["col"] - z_width, anchor["col"] + z_width)
                # Inverted y limits keep image orientation (row 0 at the top).
                a.set_ylim(anchor["row"] + z_width, anchor["row"] - z_width)

        self.fig.canvas.draw_idle()

    def update_ui(self):
        """Request a redraw of the figure."""
        self.fig.canvas.draw_idle()

    def save_segmentation(self):
        """Write every displayed (edited) ellipse back to its frame and save the database."""
        for f, ic in enumerate(self.main_bg_ellipses):
            self.update_track_data(ic, self.current_image + f)
        self.trackdata.save()

    def set_track_status(self, cell_id, frame, judgement):
        """Set ``trackstatus`` of ``cell_id`` at ``frame`` to ``judgement`` and save."""
        print(judgement, cell_id, frame)
        self.trackdata.set_cell_properties(frame, cell_id, {"trackstatus": judgement})
        self.trackdata.save()

    def on_key_press(self, event):
        """Dispatch a matplotlib key event to the matching action.

        Keys:
        - t / x: approve / disapprove the cell's track in the next frame.
        - w: save the edited ellipses.
        - d or pagedown / a or pageup: move to the next / previous frame.
        - right: jump to the next cell that needs review.
        """
        event_dict = {
            "t": lambda: self.set_track_status(
                self.current_cell_id, self.current_image + 1, "approved"
            ),
            "x": lambda: self.set_track_status(
                self.current_cell_id, self.current_image + 1, "disapprove"
            ),
            "w": self.save_segmentation,
            "pagedown": lambda: self.move_ui_to_image(self.current_image + 1),
            "d": lambda: self.move_ui_to_image(self.current_image + 1),
            "pageup": lambda: self.move_ui_to_image(self.current_image - 1),
            "a": lambda: self.move_ui_to_image(self.current_image - 1),
            "right": self.go_to_next_unapproved,
        }
        action = event_dict.get(event.key)
        if action is None:
            print("Pressing {0} does nothing yet".format(event.key))
            return
        # Called outside any try so KeyErrors raised by the action surface.
        action()
# Example #7
    def __init__(
            self,
            filepattern,
            track_path,
            compiled_path,
            image_start=1,
            cell_start: int = 1,
            viewmethod="straight",
            vmax=1.0,
            image_range=(None, None),
    ):
        """Build the cell-editing figure: image view, cell tree, cell-id box and traces.

        Args:
            filepattern: Path pattern for the frame images; its ``"{N:Md}"``
                fields are turned into a number-capturing regex
                (``self.parse_re``).
            track_path: Path passed to ``TrackDB`` (also kept as
                ``self.trackdata_path``).
            compiled_path: Path passed to
                ``compiledtracks.load_compiled_data`` with
                ``fail_silently=True``, so ``self.compileddata`` may be None.
            image_start: First frame to display.
            cell_start: Initially selected cell id (converted with ``int``).
            viewmethod: Key into ``self.view_methods`` selecting how raw
                frames are transformed for display.
            vmax: Stored as ``self.vmaxs``.
            image_range: Passed to ``self._get_image_range``.

        The statement order matters: ``make_title`` and the event handlers
        are wired up while attributes are still being created, and the last
        line renders the first frame.
        """
        self.filepattern = filepattern
        # Regex form of the pattern: every "{N:Md}" field captures a number.
        self.parse_re = re.compile(
            re.sub(r"{\d?:\d+d}", r"(\d+)", self.filepattern))
        # print(self.parse_re)
        self.image_range = self._get_image_range(image_range)
        # print(self.image_range)

        self.red_chan = "ch00"
        self.green_chan = "ch01"
        self.blue_chan = "ch02"
        # self.red_chan = "_r"

        self.color_mode = "trackstatus"

        # Matplotlib style kwargs applied to the editable cell outline.
        self.cell_interactor_style = {
            "edgecolor": "white",
            "facecolor": "none",
            "linewidth": 1,
        }

        self.trackdata_path = track_path
        self.compiled_path = compiled_path

        self.trackdata = TrackDB(track_path)  # , max(self.image_range))
        self.compileddata = compiledtracks.load_compiled_data(
            self.compiled_path, fail_silently=True)
        if self.compileddata is not None:
            # Derived green/red ratio column, computed once at load time.
            self.compileddata["gr"] = (self.compileddata["green"] /
                                       self.compileddata["red"])

        # TODO: When the GUI is created, the axes are sized to the initial
        # img_size. Replacing the image with set_data does not update the
        # axes, so anything plotted on top (e.g. numbers) would land in the
        # wrong place while the background is rescaled. To avoid that, the
        # real image shape is read up front and used as the initial size.
        print(self.filepattern.format(self.image_range[0]))
        init_shape = get_shape(self.filepattern.format(self.image_range[0]))
        self.vmaxs = vmax
        self.cells = self.trackdata.get_cell_list()
        self.current_cell_id = int(cell_start)
        self.current_image = image_start
        self.current_path = ""
        # name -> (dimensionality, transform applied to each raw frame);
        # a dimensionality of 3 adds a trailing channel axis to img_size.
        self.view_methods = {
            "straight": (2, lambda x: x),
            # "laphat": (3, lambda x: laphat_segment.laphat_segment_v1_view(x, cell_width_pixels=9)),
            "3color": (3, self.bit16_to_bit8),
            "laphat": (3, self.custom_laphat),
            "hat": (2, self.just_gauss_hat),
            "gauss": (2, lambda x: skimage.filters.gaussian(x, 1.1)),
            "gamma": (3, self.gamma),
        }
        dim, self.view_method = self.view_methods[viewmethod]
        if dim == 3:
            self.img_size = init_shape + (dim, )
        else:
            self.img_size = init_shape

        #######################
        ## Setting up GUI
        #######################
        self.fig = plt.figure()
        # Layout: 4 rows x 2 columns.
        # - Row 0: image viewer (left) and cell tree (right).
        # - Rows 1-3 span both columns: cell-id text box, compiled trace,
        #   and per-cell trace.
        self.gridspec = gridspec.GridSpec(4,
                                          2,
                                          height_ratios=[1, 0.1, 0.1, 0.1],
                                          width_ratios=[0.6, 0.4])
        # Colormap of 10000 random colours.
        self.cmrand = matplotlib.colors.ListedColormap(np.random.rand(
            10000, 3))
        # self.rand_colors = np.random.rand(1000, 3)

        # Rebind matplotlib's default shortcuts so single letters stay free.
        plt.rcParams["keymap.save"] = ["ctrl+s"]  # free up s
        plt.rcParams["keymap.fullscreen"] = ["ctrl+f"]  # free up f
        plt.rcParams["keymap.home"] = [""]  # free up a and r
        plt.rcParams["keymap.back"] = ["backspace"]
        plt.rcParams["keymap.grid"] = [""]  # g

        # Main image axis. The .name attribute tags which panel an axis is.
        self.ax_img = plt.subplot(self.gridspec[0, 0])
        self.ax_img.set_title(self.make_title())
        self.ax_img.name = "cell_viewer"
        self.bg_img = np.zeros(init_shape, dtype=np.uint16)
        self.text_labels = []

        # imgcmap = "bone"
        # imgcmap = "viridis"
        imgcmap = "hot"

        # Greyscale (2-D) placeholders get a colormap; 3-D ones are shown as-is.
        # NOTE(review): vmax=0.3 looks like a placeholder, presumably reset
        # when real frames are loaded — confirm.
        if len(self.bg_img.shape) == 2:
            self.art_img = self.ax_img.imshow(
                self.bg_img,
                cmap=plt.get_cmap(imgcmap),
                vmin=0,
                vmax=0.3,
                interpolation="nearest",
            )
        elif len(self.bg_img.shape) == 3:
            self.art_img = self.ax_img.imshow(self.bg_img,
                                              vmin=0,
                                              interpolation="nearest")

        self.interactive_cell = None
        self.non_edit_cells = []
        self.large_ellipse_coll = None

        # Connection ids returned by mpl_connect for the click handlers.
        self.ui_selectors = []
        self.ui_selectors.append(
            self.fig.canvas.mpl_connect("button_press_event",
                                        self.select_cell_id_from_tree))

        self.ax_tree = plt.subplot(self.gridspec[0, 1])
        self.ax_tree.name = "cell_picker"
        self.node_locs = {}

        # Text box for typing a cell id; submissions go to self.text_update.
        self.text_box_contents = None
        self.ax_entry = plt.subplot(self.gridspec[1, 0:2])
        self.text_box = matplotlib.widgets.TextBox(self.ax_entry,
                                                   "Cell id",
                                                   initial="")
        self.text_box.on_submit(self.text_update)

        self.ax_cell_compiled_trace = plt.subplot(self.gridspec[2, 0:2])
        # self.ax_cell_compiled_trace.set_xlim(0, self.trackdata.get_max_frames())
        # self.ax_cell_compiled_trace.name = "compiled_trace"
        self.compiled_plots = {}
        # self.art_cell_id_select = self.ax_cell_id_select.scatter(self.cells, np.ones_like(self.cells), c=self.cells, cmap=self.cmrand)

        self.ui_selectors.append(
            self.fig.canvas.mpl_connect("button_press_event",
                                        self.select_cell))
        self.ui_selectors.append(
            self.fig.canvas.mpl_connect("button_press_event",
                                        self.select_frame))

        # Per-cell trace axis, sharing its x (frame) axis with the compiled trace.
        self.ax_cell_trace = plt.subplot(self.gridspec[3, 0:2],
                                         sharex=self.ax_cell_compiled_trace)
        self.comp_bar = None
        self.trace_bar = None
        self.cell_trace_plots = []
        # Grey band marking the currently displayed frame.
        self.cur_image_art = self.ax_cell_trace.axvspan(
            self.current_image - 0.5,
            self.current_image + 0.5,
            color="grey",
            alpha=0.4)
        # NOTE(review): no labelled artists exist on this axis yet, so this
        # legend is likely empty (matplotlib may warn) — confirm it is needed.
        self.ax_cell_trace.legend()
        self.ax_cell_trace.name = "cell_trace"

        self.read_numbers = None
        self.read_in = ""

        # Render the initial frame now that every artist exists.
        self.move_ui_to_image(self.current_image)
# Example #8
class State:
    """Interactive matplotlib editor for tracked cell outlines in a time-lapse.

    The figure holds:
    - an image viewer with an editable ellipse for the current cell;
    - a lineage-tree axes;
    - a text box for typing a cell id;
    - two trace axes.

    Edits go through a ``TrackDB`` opened from ``track_path``.
    ``on_key_press`` maps single keys to editing actions.
    """

    # Numeric frame field inside ``filepattern``, e.g. ``{0:03d}``.
    _FRAME_FIELD_RE = re.compile(r"\{\d?:\d+d\}")

    # Edge colour per track status when ``color_mode == "trackstatus"``;
    # unknown statuses are drawn gray.
    _TRACKSTATUS_COLORS = {
        "auto": "orange",
        "disapprove": "red",
        "approved": "green",
        "migrated": "blue",
        "manual": "blue",
    }

    def __init__(
            self,
            filepattern,
            track_path,
            compiled_path,
            image_start=1,
            cell_start: int = 1,
            viewmethod="straight",
            vmax=1.0,
            image_range=(None, None),
    ):
        """Open the track data, discover the frames and build the figure.

        Args:
            filepattern: path template containing a numeric frame field such
                as ``{0:03d}``; ``filepattern.format(frame)`` gives the image
                path of ``frame``.
            track_path: path handed to ``TrackDB``.
            compiled_path: path handed to
                ``compiledtracks.load_compiled_data`` (``fail_silently=True``).
            image_start: frame shown first.
            cell_start: id of the cell selected first.
            viewmethod: key of ``self.view_methods`` choosing how raw images
                are transformed for display.
            vmax: fraction of the displayed image maximum used as the
                colour-scale maximum.
            image_range: inclusive ``(first, last)`` frame bounds; ``None``
                leaves that side unbounded.
        """
        self.filepattern = filepattern
        # Recovers the frame number from a path produced by filepattern.
        self.parse_re = self._compile_parse_re(self.filepattern)
        self.image_range = self._get_image_range(image_range)

        # Channel tags; other channels are located by substituting red_chan
        # in the current image path.
        self.red_chan = "ch00"
        self.green_chan = "ch01"
        self.blue_chan = "ch02"

        # "trackstatus" colours other cells by status; anything else uses
        # the random colour map.
        self.color_mode = "trackstatus"

        self.cell_interactor_style = {
            "edgecolor": "white",
            "facecolor": "none",
            "linewidth": 1,
        }

        self.trackdata_path = track_path
        self.compiled_path = compiled_path

        self.trackdata = TrackDB(track_path)
        self.compileddata = compiledtracks.load_compiled_data(
            self.compiled_path, fail_silently=True)
        if self.compileddata is not None:
            self.compileddata["gr"] = (self.compileddata["green"] /
                                       self.compileddata["red"])

        # TODO: the image axes keep the extent of the first image drawn and
        # later set_data calls do not rescale them, so overlays end up
        # misplaced if the size differs.  Read the real image size up front.
        print(self.filepattern.format(self.image_range[0]))
        init_shape = get_shape(self.filepattern.format(self.image_range[0]))
        self.vmaxs = vmax
        self.cells = self.trackdata.get_cell_list()
        self.current_cell_id = int(cell_start)
        self.current_image = image_start
        self.current_path = ""
        # name -> (dimensionality of the displayed image, transform)
        self.view_methods = {
            "straight": (2, lambda x: x),
            "3color": (3, self.bit16_to_bit8),
            "laphat": (3, self.custom_laphat),
            "hat": (2, self.just_gauss_hat),
            "gauss": (2, lambda x: skimage.filters.gaussian(x, 1.1)),
            "gamma": (3, self.gamma),
        }
        dim, self.view_method = self.view_methods[viewmethod]
        if dim == 3:
            self.img_size = init_shape + (dim, )
        else:
            self.img_size = init_shape

        #######################
        ## Setting up GUI
        #######################
        self.fig = plt.figure()
        self.gridspec = gridspec.GridSpec(4,
                                          2,
                                          height_ratios=[1, 0.1, 0.1, 0.1],
                                          width_ratios=[0.6, 0.4])
        self.cmrand = matplotlib.colors.ListedColormap(np.random.rand(
            10000, 3))

        # Release default matplotlib shortcuts that on_key_press uses.
        plt.rcParams["keymap.save"] = ["ctrl+s"]  # free up s
        plt.rcParams["keymap.fullscreen"] = ["ctrl+f"]  # free up f
        plt.rcParams["keymap.home"] = [""]  # free up a and r
        plt.rcParams["keymap.back"] = ["backspace"]
        plt.rcParams["keymap.grid"] = [""]  # g

        self.ax_img = plt.subplot(self.gridspec[0, 0])
        self.ax_img.set_title(self.make_title())
        self.ax_img.name = "cell_viewer"
        self.bg_img = np.zeros(init_shape, dtype=np.uint16)
        self.text_labels = []

        imgcmap = "hot"

        if len(self.bg_img.shape) == 2:
            self.art_img = self.ax_img.imshow(
                self.bg_img,
                cmap=plt.get_cmap(imgcmap),
                vmin=0,
                vmax=0.3,
                interpolation="nearest",
            )
        elif len(self.bg_img.shape) == 3:
            self.art_img = self.ax_img.imshow(self.bg_img,
                                              vmin=0,
                                              interpolation="nearest")

        self.interactive_cell = None
        self.non_edit_cells = []
        self.large_ellipse_coll = None

        self.ui_selectors = []
        self.ui_selectors.append(
            self.fig.canvas.mpl_connect("button_press_event",
                                        self.select_cell_id_from_tree))

        self.ax_tree = plt.subplot(self.gridspec[0, 1])
        self.ax_tree.name = "cell_picker"
        self.node_locs = {}

        self.text_box_contents = None
        self.ax_entry = plt.subplot(self.gridspec[1, 0:2])
        self.text_box = matplotlib.widgets.TextBox(self.ax_entry,
                                                   "Cell id",
                                                   initial="")
        self.text_box.on_submit(self.text_update)

        self.ax_cell_compiled_trace = plt.subplot(self.gridspec[2, 0:2])
        self.compiled_plots = {}

        self.ui_selectors.append(
            self.fig.canvas.mpl_connect("button_press_event",
                                        self.select_cell))
        self.ui_selectors.append(
            self.fig.canvas.mpl_connect("button_press_event",
                                        self.select_frame))

        self.ax_cell_trace = plt.subplot(self.gridspec[3, 0:2],
                                         sharex=self.ax_cell_compiled_trace)
        self.comp_bar = None
        self.trace_bar = None
        self.cell_trace_plots = []
        self.cur_image_art = self.ax_cell_trace.axvspan(
            self.current_image - 0.5,
            self.current_image + 0.5,
            color="grey",
            alpha=0.4)
        self.ax_cell_trace.legend()
        self.ax_cell_trace.name = "cell_trace"

        # Pending callback for a number typed on the keyboard, and the
        # digits typed so far (see on_key_press).
        self.read_numbers = None
        self.read_in = ""

        self.move_ui_to_image(self.current_image)

    @classmethod
    def _compile_parse_re(cls, filepattern):
        """Return a regex capturing the frame number(s) of a path.

        Every ``{N:0Md}``-style field of *filepattern* becomes a ``(\\d+)``
        group and all other text is matched literally.  The regex is built by
        joining escaped pieces.  Passing ``\\d`` through an ``re.sub``
        replacement string is a "bad escape" error on Python 3.7+.
        """
        pieces = cls._FRAME_FIELD_RE.split(filepattern)
        return re.compile(r"(\d+)".join(re.escape(p) for p in pieces))

    def text_update(self, text):
        """Store the integer typed in the "Cell id" text box.

        Input that is not an integer is reported and ignored, so the
        previously stored id is kept.
        """
        try:
            value = int(text.strip())
        except ValueError:
            print("Cell id must be an integer, got:", repr(text))
            return
        self.text_box_contents = value
        print("text_contents are now:", self.text_box_contents)

    def _get_image_range(self, image_range):
        """Return the sorted frame numbers that have an image on disk.

        Candidate paths come from globbing ``self.filepattern`` with each frame
        field replaced by ``*``.  Paths that ``self.parse_re`` cannot parse are
        skipped.

        Args:
            image_range: inclusive ``(first, last)`` bounds; either may be
                ``None`` to leave that side unbounded.

        Raises:
            ValueError: if ``self.filepattern`` has no ``{N:0Md}`` frame field.
        """
        first, last = image_range
        pieces = self._FRAME_FIELD_RE.split(self.filepattern)
        if len(pieces) < 2:
            raise ValueError(
                "filepattern needs a numeric frame field such as {0:03d}: " +
                self.filepattern)

        frames = []
        for path in glob("*".join(pieces)):
            found = self.parse_re.match(path)
            if found is None:
                continue
            frame = int(found.group(1))
            if first is not None and frame < first:
                continue
            if last is not None and frame > last:
                continue
            frames.append(frame)
        return sorted(frames)

    def select_cell_id_from_tree(self, event):
        """Select the lineage-tree node closest to a click on the tree axes.

        Clicks outside the "cell_picker" axes, or more than 1 data unit away
        from every node in ``self.node_locs``, are ignored.
        """
        axes = event.inaxes
        if axes is None or getattr(axes, "name", None) != "cell_picker":
            return None
        if event.xdata is None or event.ydata is None or not self.node_locs:
            return None
        closest, distance = min(
            ((cell_id, np.hypot(x - event.xdata, y - event.ydata))
             for cell_id, (x, y) in self.node_locs.items()),
            key=lambda item: item[1],
        )
        if distance < 1:
            self.move_to_cell(closest)

    def move_ui_to_image(self, imagen):
        """Load frame ``imagen``, re-select the current cell and redraw."""
        self.move_to_image(imagen)
        self.move_to_cell(self.current_cell_id)

        self.update_ui()

    def make_title(self):
        """Return the image-axes title: frame, current cell and cell count."""
        vals = (self.current_image, "", self.current_cell_id, len(self.cells))
        return "Image:{0} {1} Cell#{2}: of {3}".format(*vals)

    def show_non_edit_cells(self, frame):
        """Draw every cell of ``frame`` except the current one as outlines.

        Any previously drawn collection is removed first.  The new cells are
        kept in ``self.non_edit_cells`` so that clicks can be mapped back to
        cell ids.
        """
        if self.large_ellipse_coll is not None:
            self.large_ellipse_coll.remove()

        def create_ellipses(cell_id):
            cell_params = self.trackdata.get_cell_params(frame, cell_id)
            cell_props = self.trackdata.get_cell_properties(frame, cell_id)
            cell = cell_editor.get_cell(*cell_params)
            cell.cell_id = cell_id
            cell.trackstatus = cell_props["trackstatus"]
            return cell

        self.non_edit_cells = [
            create_ellipses(c)
            for c in self.trackdata.get_cells_in_frame(frame, states=None)
            if c != self.current_cell_id
        ]
        coll = matplotlib.collections.PatchCollection(
            self.non_edit_cells,
            facecolor="none",
            linewidth=2,
        )
        coll.set_edgecolor(
            [self.get_cell_color(c) for c in self.non_edit_cells])
        self.large_ellipse_coll = self.ax_img.add_collection(coll)

    def get_cell_color(self, cell):
        """Return the outline colour for ``cell``.

        In "trackstatus" mode the colour is looked up from the cell's track
        status, with gray for unknown statuses.  Otherwise it is the random
        colour map entry for the cell id.
        """
        if self.color_mode == "trackstatus":
            return self._TRACKSTATUS_COLORS.get(cell.trackstatus, "gray")
        return self.cmrand(cell.cell_id)

    def exagerate_image(self, img):
        """Stack a white top-hat, a double-Sobel edge map and ``img`` as RGB."""
        img_sobel = skimage.filters.sobel(img)
        img_sobel_sobel = skimage.filters.sobel(img_sobel)
        img_sobel_sobel_01 = skimage.exposure.rescale_intensity(
            img_sobel_sobel, out_range=(0, 1))
        img_gauss_01 = skimage.exposure.rescale_intensity(img,
                                                          out_range=(0, 1.0))
        cell_width = 11
        cell_disc = skimage.morphology.disk(cell_width / 2)
        img_hat = skimage.morphology.white_tophat(img_gauss_01,
                                                  selem=cell_disc)
        img_hat_01 = skimage.exposure.rescale_intensity(img_hat,
                                                        out_range=(0, 1.0))
        return np.dstack([img_hat_01, img_sobel_sobel_01, img])

    def bit16_to_bit8(self, im):
        """Return an 8-bit RGB composite of the current frame.

        If ``im`` is not 3-D it is taken as the red channel.  Green and blue
        are then read from the current frame's path with ``red_chan`` replaced
        by ``green_chan`` and ``blue_chan``; a missing file becomes zeros.
        Each channel is Gaussian-smoothed (sigma 1.1) and rescaled from a
        fixed input range to 0-255.
        """
        def try_read(fid, ch):
            try:
                imc = skimage.io.imread(
                    self.filepattern.format(self.current_image).replace(
                        self.red_chan, ch))
            except FileNotFoundError:
                imc = np.zeros_like(im)
            return imc

        if len(im.shape) != 3:
            # Only the red channel was read; fetch the others from disk.
            im = np.dstack([im] + [
                try_read(self.filepattern, c)
                for c in [self.green_chan, self.blue_chan]
            ])
        red_max = 0.08
        grn_max = 0.05

        def modify(im):
            return skimage.filters.gaussian(im, sigma=1.1)

        imr = skimage.exposure.rescale_intensity(modify(im[:, :, 0]),
                                                 in_range=(0, red_max),
                                                 out_range=(0, 255)).astype(
                                                     np.uint8)
        img = skimage.exposure.rescale_intensity(modify(im[:, :, 1]),
                                                 in_range=(0, grn_max),
                                                 out_range=(0, 255)).astype(
                                                     np.uint8)
        imb = skimage.exposure.rescale_intensity(modify(im[:, :, 2]),
                                                 in_range=(0, 6897),
                                                 out_range=(0, 255)).astype(
                                                     np.uint8)
        return np.dstack([imr, img, imb])

    def custom_laphat(self, img):
        """Stack Laplacian, white top-hat and squared intensity as RGB."""
        cell_width_pixels = 7.5
        img_gauss = skimage.filters.gaussian(img[:, :], 1.1)
        img_gauss_01 = skimage.exposure.rescale_intensity(img_gauss,
                                                          out_range=(0, 1.0))
        cell_disc = skimage.morphology.disk(cell_width_pixels / 2)
        img_hat = skimage.morphology.white_tophat(img_gauss_01,
                                                  selem=cell_disc)
        img_lap = skimage.filters.laplace(img_gauss, ksize=10)
        img_lap[img_lap < 0] = 0
        img_lap = img_lap * (1.0 / img_lap.max())
        img_hat = img_hat * (1.0 / img_hat.max())
        return np.dstack([img_lap, img_hat, img_gauss_01**2])

    def just_gauss_hat(self, img):
        """Return the white top-hat of the smoothed, 0-1 rescaled image."""
        cell_width_pixels = 7.5
        img_gauss = skimage.filters.gaussian(img[:, :], 1.1)
        img_gauss_01 = skimage.exposure.rescale_intensity(img_gauss,
                                                          out_range=(0, 1.0))
        cell_disc = skimage.morphology.disk(cell_width_pixels / 2)
        return skimage.morphology.white_tophat(img_gauss_01, selem=cell_disc)

    def gamma(self, img):
        """Stack a gamma-adjusted first channel and its top-hat as RGB."""
        cell_width_pixels = 3
        cell_disc = skimage.morphology.disk(cell_width_pixels / 2)
        img_gauss = skimage.filters.gaussian(img[:, :, 0], 1.0)
        img_gauss_01 = skimage.exposure.rescale_intensity(img_gauss,
                                                          out_range=(0, 1.0))
        img_hat = skimage.morphology.white_tophat(img_gauss_01,
                                                  selem=cell_disc)
        exim = skimage.exposure.adjust_gamma(img_gauss_01, gamma=0.9)
        return np.dstack([exim, img_hat, np.zeros_like(exim)])

    def move_to_image(self, number):
        """Load frame ``number`` into ``self.bg_img`` via the view method.

        Frames not in ``self.image_range`` are reported and ignored.
        """
        if number not in self.image_range:
            print("{0} is not in the list of images {1}".format(
                number, self.image_range))
            return None

        self.current_image = number
        self.current_path = self.filepattern.format(number)
        # Hack for 3-colour images stored as separate files: a "{}" left in
        # the path is filled with "r"; a path without one is unchanged.
        img = tifffile.imread(self.current_path.format("r"))
        self.bg_img = self.view_method(img)

    def guess_next_cell_location(self, direction=+1):
        """Let auto_match guess the current cell in the neighbouring frame."""
        self.trackdata = auto_match.guess_next_cell_gui(
            self.trackdata,
            self.current_cell_id,
            self.current_image,
            "",
            self.filepattern,
            direction=direction,
        )
        self.move_ui_to_image(self.current_image)

    def plot_cell_compiled_trace(self, frame, cell):
        """Plot the compiled red/green traces of ``cell`` and its lineage.

        A grey band marks ``frame``.  For each channel:
        - the cell's own trace is drawn as a solid line;
        - the lineage trace is smoothed with a Gaussian (sigma=5) and drawn
          dotted.  The lineage is the one ending at the first final
          descendant of ``cell``.

        Lines are created on the first call and updated in place afterwards.
        """
        if self.comp_bar is not None:
            self.comp_bar.remove()
        self.comp_bar = self.ax_cell_compiled_trace.axvspan(frame - 0.5,
                                                            frame + 0.5,
                                                            color="grey",
                                                            alpha=0.4)

        plots = [("red", "red"), ("green", "green")]
        leaves = self.trackdata.get_final_decendants(cell)
        leaf = leaves[0]
        lineage = self.trackdata.get_cell_lineage(leaf)
        if self.compileddata is None:
            return

        for ch, cl in plots:
            lin_frames, lin_data = compiledtracks.get_channel_of_cell(
                self.compileddata, lineage, ch)
            print("cell lineage", lineage)
            print("cell frames", lin_frames)
            if len(lin_frames) == 0:
                print("nothing to print here")
                # clear() drops every artist on the axes, so forget the
                # now-detached handles instead of updating them later.
                self.ax_cell_compiled_trace.clear()
                self.compiled_plots = {}
                self.comp_bar = None
                break
            lin_smooth = scipy.ndimage.gaussian_filter1d(lin_data,
                                                         sigma=5,
                                                         mode="mirror")
            frames, data = compiledtracks.get_channel_of_cell(
                self.compileddata, cell, ch)
            if self.compiled_plots.get(ch) is None:
                lin_line, = self.ax_cell_compiled_trace.plot(lin_frames,
                                                             lin_smooth,
                                                             color=cl,
                                                             linestyle=":")
                line, = self.ax_cell_compiled_trace.plot(frames,
                                                         data,
                                                         color=cl)
                self.compiled_plots[ch] = line
                self.compiled_plots[ch + "lin"] = lin_line
            else:
                self.compiled_plots[ch].set_data(frames, data)
                # Keep showing the smoothed lineage, as when it was created.
                self.compiled_plots[ch + "lin"].set_data(
                    lin_frames, lin_smooth)
            self.ax_cell_compiled_trace.set_ylim(
                bottom=0, top=np.nanmax(lin_data))
            self.ax_cell_compiled_trace.set_xlim(
                left=0, right=self.trackdata.metadata["max_frames"])

    def plot_cell_path(self, frame, cell):
        """Plot length/width of ``cell`` over time with state markers.

        A grey band marks ``frame``.  If ``cell`` is unknown, all trace lines
        are removed and their slots reset to ``None``.
        """
        if self.trace_bar is not None:
            self.trace_bar.remove()
        self.trace_bar = self.ax_cell_trace.axvspan(frame - 0.5,
                                                    frame + 0.5,
                                                    color="grey",
                                                    alpha=0.4)

        if cell not in self.trackdata.get_cell_list():
            for line in self.cell_trace_plots:
                # Slots are None after a previous reset; nothing to remove.
                if line is not None:
                    line.remove()
            self.cell_trace_plots = [None] * len(self.cell_trace_plots)
            return None

        states = np.array(self.trackdata.cells[cell]["state"])
        present = states > 0
        frames, = np.where(present)
        lengths = np.array(self.trackdata.cells[cell]["length"])
        widths = np.array(self.trackdata.cells[cell]["width"])

        plots = [
            (frames, lengths[present], "None"),
            (frames, widths[present], "None"),
        ]

        statesymb = [(2, "$⚭$"), (3, "$☠$"), (4, "$☉$"), (5, "$🟟$")]
        for state, symb in statesymb:
            loc, = np.where(states == state)
            plots.append((loc, lengths[loc], symb))

        # Grow the slot list without discarding lines already on the axes.
        missing = len(plots) - len(self.cell_trace_plots)
        if missing > 0:
            self.cell_trace_plots = self.cell_trace_plots + [None] * missing

        for i, (x, y, m) in enumerate(plots):
            if self.cell_trace_plots[i] is None:
                line, = self.ax_cell_trace.plot(x, y, markersize=10, marker=m)
                self.cell_trace_plots[i] = line
            else:
                self.cell_trace_plots[i].set_data(x, y)

    def move_to_cell(self, cell):
        """Make ``cell`` current: create it if unknown and show its ellipse.

        The editable ellipse is shown only when the cell's state in the
        current frame is non-zero.  Other cells are redrawn as outlines.
        """
        number = self.current_image
        self.current_cell_id = cell
        self.ax_img.set_title(self.make_title())
        if cell not in self.trackdata.get_cell_list():
            self.trackdata.create_cell(cell)
        if self.trackdata.get_cell_state(number, cell) != 0:
            ellipse_data = self.trackdata.get_cell_params(number, cell)
            if self.interactive_cell is None:
                self.interactive_cell = cell_editor.CellInteractor(
                    self.ax_img, *ellipse_data, **self.cell_interactor_style)
            else:
                self.interactive_cell.set_cell_props(*ellipse_data)
        else:
            if self.interactive_cell is not None:
                self.interactive_cell.remove()
                self.interactive_cell = None
                self.fig.canvas.draw_idle()
        # Trace plots (plot_cell_compiled_trace / plot_cell_path) are
        # currently disabled here.
        self.show_non_edit_cells(self.current_image)
        self.fig.canvas.draw_idle()

    def update_ui(self):
        """Push ``self.bg_img`` to the image artist and refresh the title."""
        self.art_img.set_data(self.bg_img)
        self.art_img.set_clim(vmax=self.bg_img.max() * self.vmaxs)
        self.ax_img.set_title(self.make_title())
        self.fig.canvas.draw_idle()

    def update_tree(self, recalculate_parents=False):
        """Redraw the lineage tree; no-op with at most one cell."""
        if len(self.trackdata.get_cell_list()) <= 1:
            return None
        if recalculate_parents:
            self.trackdata = track_data.set_possible_parents(self.trackdata)
        self.ax_tree.clear()
        cell_colors = [(0, 0, 0, 1)
                       ] + [self.cmrand(int(i)) for i in self.cells]
        cell_finals = {
            i: self.trackdata.get_final_frame(i)
            for i in self.cells
        }
        self.ax_tree, self.node_locs = self.trackdata.plot_tree(
            self.ax_tree, cell_colors, cell_finals)
        self.ax_tree.set_ylim(self.trackdata.get_max_frames() + 5, -2)
        self.ax_tree.tick_params(axis="y", which="minor")
        self.fig.canvas.draw_idle()

    def select_cell(self, event):
        """Switch to the non-edited cell clicked in the image viewer.

        Clicks on the editable ellipse are left to the cell editor.
        """
        if event.inaxes is None:
            return None
        if event.inaxes.name == "cell_viewer":
            if self.interactive_cell is not None:
                hits_edit, _ = self.interactive_cell.ellipse.contains(event)
                if hits_edit:
                    return None
            if self.large_ellipse_coll is None:
                return None
            hit, details = self.large_ellipse_coll.contains(event)
            if hit:
                cid = self.non_edit_cells[details["ind"][0]].cell_id
                self.move_to_cell(cid)

    def select_frame(self, event):
        """Jump to the frame nearest a click on a trace axes."""
        if event.inaxes is None:
            return None
        if event.inaxes.name in ("cell_trace", "compiled_trace"):
            select = int(np.round(event.xdata))
            print("Selecting frame", select)
            self.move_ui_to_image(select)

    def make_current_cell_like_previous_frame(self):
        """Copy the current cell's parameters from the previous frame."""
        prev_frame = self.current_image - 1
        if self.trackdata.get_cell_state(prev_frame,
                                         self.current_cell_id) != 0:
            self.trackdata.copy_cell_info_from_frame(self.current_cell_id,
                                                     prev_frame,
                                                     self.current_image)
            ellipse_data = self.trackdata.get_cell_params(
                self.current_image, self.current_cell_id)
            if self.interactive_cell is None:
                self.interactive_cell = cell_editor.CellInteractor(
                    self.ax_img, *ellipse_data, **self.cell_interactor_style)
            self.interactive_cell.set_cell_props(*ellipse_data)
        else:
            print(
                "Cell ",
                self.current_cell_id,
                " is not present in frame",
                self.current_image,
            )
        self.fig.canvas.draw_idle()

    def update_track_data(self):
        """Write the edited ellipse of the current cell into the track data."""
        if self.interactive_cell is not None:
            print("updating the data structure")
            print(
                "OLD",
                self.trackdata.get_cell_params(self.current_image,
                                               self.current_cell_id),
            )
            properties = self.interactive_cell.get_position_props()
            properties.update({"status": "checked"})
            print("GUI", properties)
            self.trackdata.set_cell_properties(self.current_image,
                                               self.current_cell_id,
                                               properties)
            print(
                "Saved",
                self.trackdata.get_cell_params(self.current_image,
                                               self.current_cell_id),
            )

    def save_segmentation(self):
        """Store the pending ellipse edit and save the track data."""
        self.update_track_data()
        self.trackdata.save()

    def add_new_cell_to_frame(self, cell_id):
        """Add ``cell_id`` to the current frame at the centre of the view."""
        x0, x1 = self.ax_img.get_xlim()
        y0, y1 = self.ax_img.get_ylim()
        xpos = x0 + (x1 - x0) / 2
        ypos = y1 + (y0 - y1) / 2
        temp_properties = {
            "row": ypos,
            "col": xpos,
            "length": 22,
            "width": 6,
            "angle": 0,
            "state": "there",
        }

        try:
            self.trackdata.add_cell_to_frame(self.current_image, cell_id,
                                             temp_properties)
            print("Adding cell {0}".format(cell_id))
            self.current_cell_id = cell_id
            self.move_to_cell(cell_id)
            self.update_track_data()
        except track_db.SchnitzExistsError:
            print(f"Cell {cell_id} exists in frame {self.current_image}")
            return None
        self.fig.canvas.draw_idle()

    def set_cell_state(self, state):
        """Set the current cell's state from its index in metadata "states".

        Metadata is subscripted like everywhere else in this class
        (``metadata["states"]``, ``metadata["max_frames"]``).
        """
        state_name = self.trackdata.metadata["states"][state]
        self.trackdata.set_cell_state(self.current_image, self.current_cell_id,
                                      state_name)

    def track_cell(self):
        """Interpolate the current cell into the next frame via cell_tracker."""
        print("tracking")
        print(
            "initial frame{0}".format(self.current_image),
            self.trackdata.get_cell_properties(self.current_image,
                                               self.current_cell_id),
        )
        print(
            "next frame{0}".format(self.current_image + 1),
            self.trackdata.get_cell_properties(self.current_image + 1,
                                               self.current_cell_id),
        )
        self.trackdata = cell_tracker.gui_interpolate(self.trackdata,
                                                      self.current_cell_id,
                                                      self.current_image)
        print("tracked")
        print(
            "after frame{0}".format(self.current_image),
            self.trackdata.get_cell_properties(self.current_image,
                                               self.current_cell_id),
        )
        print(
            "next frame{0}".format(self.current_image + 1),
            self.trackdata.get_cell_properties(self.current_image + 1,
                                               self.current_cell_id),
        )
        self.move_ui_to_image(self.current_image)

    def delete_cell_in_frame(self, cell, frame):
        """Blank ``cell`` in ``frame``, save, and redraw that frame."""
        self.trackdata.blank_cell_params(frame, cell)
        self.trackdata.save()
        self.move_ui_to_image(frame)

    def approve_track_status(self, cell_id, frame):
        """Mark ``cell_id`` in ``frame`` as trackstatus "approved"."""
        self.trackdata.set_cell_properties(frame, cell_id,
                                           {"trackstatus": "approved"})

    def set_schntiz_cell_id(self, cell_id, frame):
        """Renumber ``cell_id`` in ``frame`` to the id typed in the text box.

        Nothing happens until a valid id has been entered in the text box.
        """
        if self.text_box_contents is None:
            print("Type a new cell id in the text box first")
            return None
        self.trackdata.set_cell_id(frame, cell_id, self.text_box_contents)
        self.trackdata.save()
        print("set id to ", self.text_box_contents)
        self.move_to_cell(self.text_box_contents)
        print("moving to ", self.text_box_contents)
        self.update_ui()

    def _read_number_key(self, key):
        """Handle one key press while a number is being typed.

        Digits are accumulated in ``self.read_in``.  "enter" passes the
        number to ``self.read_numbers`` and leaves number mode; an empty
        entry is ignored.  "esc" leaves number mode without calling back.
        """
        if key is not None and len(key) == 1 and key in "0123456789":
            self.read_in += key
        elif key == "enter":
            if not self.read_in:
                print("No number typed yet; type digits or press esc.")
                return
            print("Got number: {0}".format(self.read_in))
            callback, number = self.read_numbers, int(self.read_in)
            try:
                callback(number)
            finally:
                self.read_numbers = None
                self.read_in = ""
        elif key == "esc":
            self.read_numbers = None
            self.read_in = ""
            print("not adding number any more.")

    def on_key_press(self, event):
        """Dispatch a key press to the matching editing action.

        Single-key actions come from ``event_dict``.  "A" and "S" enter
        number mode: the digits typed next, confirmed with "enter", are
        passed to ``add_new_cell_to_frame`` or ``set_cell_state``.  "s"
        toggles the outlines of the non-edited cells.
        """
        key = event.key
        event_dict = {
            "t":
            lambda: self.approve_track_status(self.current_cell_id, self.
                                              current_image),
            "i":
            lambda: self.set_schntiz_cell_id(self.current_cell_id, self.
                                             current_image),
            "w":
            self.save_segmentation,
            "v":
            lambda: self.add_new_cell_to_frame(self.trackdata.get_max_cell_id(
            ) + 1),
            "a":
            lambda: self.add_new_cell_to_frame(self.current_cell_id),
            "c":
            self.make_current_cell_like_previous_frame,
            "X":
            lambda: self.delete_cell_in_frame(self.current_cell_id, self.
                                              current_image),
            "g":
            lambda: self.guess_next_cell_location(direction=+1),
            "f":
            lambda: self.guess_next_cell_location(direction=-1),
            "pagedown":
            lambda: self.move_ui_to_image(self.current_image + 1),
            "pageup":
            lambda: self.move_ui_to_image(self.current_image - 1),
        }
        # Look the key up once; errors raised by the action itself are not
        # swallowed.
        action = event_dict.get(key)
        if action is not None:
            action()

        if key == "A":
            print("add cell number:")
            self.read_numbers = self.add_new_cell_to_frame
            self.read_in = ""
        elif self.read_numbers is not None:
            self._read_number_key(key)
        elif key == "S":
            print("set cell_state:", self.trackdata.metadata["states"])
            self.read_numbers = self.set_cell_state
            self.read_in = ""
        elif key == "s":
            if self.large_ellipse_coll is not None:
                self.large_ellipse_coll.set_visible(
                    not self.large_ellipse_coll.get_visible())
            self.fig.canvas.draw_idle()
        elif key not in event_dict:
            print("Pressing {0} does nothing yet".format(key))
def main():
    """Auto-track cells from ``--frame`` into the following frame.

    Cells that exist in ``--frame`` but are not yet present in ``--frame + 1``
    get a predicted next-frame location from
    ``auto_match.predict_next_location_simple``. Those predictions seed a label
    image on the next frame's segmentation (read from a ``model3`` directory
    beside the raw image). Ellipses are fitted with ``get_ellipses`` and added
    to the track database with ``{"trackstatus": "auto"}``, and the database is
    saved.

    Every path and the search distance can be overridden on the command line.
    The defaults reproduce the values that were previously hard-coded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frame", type=int, required=True)
    parser.add_argument(
        "--trackdb",
        type=str,
        default="/home/nmurphy/Dropbox/work/projects/bf_pulse/bf10_track.sqllite",
    )
    parser.add_argument(
        "--basedir",
        type=str,
        default="/media/nmurphy/BF_Data_Orange/proc_data/iphox_movies/",
    )
    parser.add_argument("--dataset", type=str, default="BF10_timelapse")
    parser.add_argument("--lookat", type=str, default="Column_2")
    parser.add_argument("--search_distance", type=int, default=40)
    args = parser.parse_args()

    # The doubled braces survive this first .format(). The result is still a
    # template that takes the frame number as its single positional field.
    image_path = "{basedir}/{dataset}/{lookat}/{lookat}_t{{0:03d}}_ch00.tif".format(
        basedir=args.basedir, dataset=args.dataset, lookat=args.lookat
    )

    td = TrackDB(args.trackdb)

    now_frame = args.frame
    next_frame = now_frame + 1
    next_image = image_path.format(next_frame)

    now_cells = td.get_dataframe_of_cell_properties_in_frame(now_frame)
    next_cells = td.get_dataframe_of_cell_properties_in_frame(next_frame)

    # If a cell is already tracked in the next frame, remove it from the set
    # of cells still to be tracked.
    already_tracked = next_cells["cell_id"].values
    now_cells = now_cells.loc[~now_cells["cell_id"].isin(already_tracked), :]

    if now_cells.empty:
        # Nothing left to track. Skip loading the segmentation and fitting.
        print(
            f"All cells in frame {now_frame} are already tracked in frame "
            f"{next_frame}; nothing to do."
        )
        return

    # Map each cell id to its (row, col) center in the current frame.
    source_centers = {
        cell_id: (props["row"], props["col"])
        for cell_id, props in now_cells.set_index("cell_id")[["row", "col"]]
        .to_dict(orient="index")
        .items()
    }

    next_pos = auto_match.predict_next_location_simple(
        image_path,
        source_centers.items(),
        now_frame,
        1,
        search_w=args.search_distance,
    )

    segmented = skimage.io.imread(
        os.path.join(
            os.path.dirname(next_image),
            "model3",
            os.path.basename(next_image).replace("_ch00", ""),
        )
    )

    # Paint each predicted position with its cell id. Each `position` is
    # indexed as a 2-D array whose columns 0 and 1 are used as row and column
    # indices. NOTE(review): labels are uint16, so cell ids are assumed to be
    # <= 65535. Confirm this holds for long movies.
    center_labels = np.zeros_like(segmented[0, :, :], dtype=np.uint16)
    for cell_id, position in next_pos.items():
        center_labels[position[:, 0], position[:, 1]] = cell_id

    ellipses = get_ellipses(segmented[0, :, :], center_labels)

    td.add_new_ellipses_to_frame(ellipses, next_frame, {"trackstatus": "auto"})
    td.save()