Example 1: a thin SiamFC wrapper that adapts TrackerSiamFC to a common BaseTracker interface.
from PIL import Image

import path_config  # project-local config; assumed to define SIAMFC_MODEL
from base_tracker import BaseTracker  # assumed import path for the base class
from siamfc import TrackerSiamFC


class SiamFC(BaseTracker):
    def __init__(self):
        super(SiamFC, self).__init__("SiamFC")
        # TODO: edit this path if the pretrained weights live elsewhere
        self.net_file = path_config.SIAMFC_MODEL
        self.tracker = TrackerSiamFC(net_path=self.net_file)

    def initialize(self, image_file, box):
        # box is the initial target location as [x, y, width, height]
        image = Image.open(image_file).convert("RGB")
        self.tracker.init(image, box)

    def track(self, image_file):
        # returns the predicted [x, y, width, height] for the new frame
        image = Image.open(image_file).convert("RGB")
        return self.tracker.update(image)
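
A minimal usage sketch for the wrapper above; frame_files (a sorted list of frame image paths) and init_box (an initial [x, y, w, h] box) are hypothetical inputs, not part of the original snippet:

tracker = SiamFC()
tracker.initialize(frame_files[0], init_box)  # hypothetical inputs
for frame_file in frame_files[1:]:
    box = tracker.track(frame_file)  # predicted [x, y, w, h] for each frame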
Example 2: running TrackerSiamFC over a directory of IR/RGB test videos and reporting a per-video mixed measure.
import glob
import json
import os

import cv2
import numpy as np

from siamfc import TrackerSiamFC


def main(mode='IR', visualization=False):
    assert mode in ['IR', 'RGB'], 'only IR and RGB modes are supported'
    # set up the tracker
    net_path = 'model.pth'
    tracker = TrackerSiamFC(net_path=net_path)

    # set up the experiment: one sub-directory per test video
    video_paths = glob.glob(os.path.join('dataset', 'test-dev', '*'))
    video_num = len(video_paths)
    output_dir = os.path.join('results', tracker.name)
    os.makedirs(output_dir, exist_ok=True)
    overall_performance = []

    # run tracking experiments and report performance
    for video_id, video_path in enumerate(video_paths, start=1):
        video_name = os.path.basename(video_path)
        video_file = os.path.join(video_path, '%s.mp4' % mode)
        label_file = os.path.join(video_path, '%s_label.json' % mode)
        with open(label_file, 'r') as f:
            label_res = json.load(f)

        init_rect = label_res['gt_rect'][0]
        capture = cv2.VideoCapture(video_file)

        frame_id = 0
        out_res = []
        while True:
            ret, frame = capture.read()
            if not ret:
                capture.release()
                break
            if frame_id == 0:
                tracker.init(frame, init_rect)  # initialization
                out = init_rect
                out_res.append(init_rect)
            else:
                out = tracker.update(frame)  # tracking; returns an [x, y, w, h] array
                out_res.append(out.tolist())
            if visualization:
                _gt = label_res['gt_rect'][frame_id]
                _exist = label_res['exist'][frame_id]
                if _exist:
                    cv2.rectangle(frame, (int(_gt[0]), int(_gt[1])),
                                  (int(_gt[0] + _gt[2]), int(_gt[1] + _gt[3])),
                                  (0, 255, 0))
                cv2.putText(frame, 'exist' if _exist else 'not exist',
                            (frame.shape[1] // 2 - 20, 30),
                            cv2.FONT_HERSHEY_PLAIN, 2,
                            (0, 255, 0) if _exist else (0, 0, 255), 2)

                cv2.rectangle(frame, (int(out[0]), int(out[1])),
                              (int(out[0] + out[2]), int(out[1] + out[3])),
                              (0, 255, 255))
                cv2.imshow(video_name, frame)
                cv2.waitKey(1)
            frame_id += 1
        if visualization:
            cv2.destroyAllWindows()
        # save result
        output_file = os.path.join(output_dir,
                                   '%s_%s.txt' % (video_name, mode))
        with open(output_file, 'w') as f:
            json.dump({'res': out_res}, f)

        # `eval` here is the benchmark's evaluation function, not the builtin
        mixed_measure = eval(out_res, label_res)
        overall_performance.append(mixed_measure)
        print('[%03d/%03d] %20s %5s Mixed Measure: %.03f' %
              (video_id, video_num, video_name, mode, mixed_measure))

    print('[Overall] %5s Mixed Measure: %.03f\n' %
          (mode, np.mean(overall_performance)))
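
Each result file written above is JSON of the form {"res": [[x, y, w, h], ...]}, one box per frame. Below is a minimal sketch of re-reading such a file and scoring it with a plain IoU; this illustrative metric is not the benchmark's eval, which also accounts for frames where the target is absent. The file name and gt_rects are placeholders:

import json

def iou(a, b):
    # a and b are [x, y, w, h] boxes
    iw = max(0.0, min(a[0] + a[2], b[0] + b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[1] + a[3], b[1] + b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0

with open('results/SiamFC/video_name_IR.txt') as f:  # placeholder path
    res = json.load(f)['res']
pairs = list(zip(res, gt_rects))  # gt_rects: ground-truth boxes, assumed loaded
print(sum(iou(p, g) for p, g in pairs) / len(pairs))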
Example 3: a CellTracker that combines a U-Net for segmentation with TrackerSiamFC for frame-to-frame cell association on Cell Tracking Challenge-style data.
import glob
import os
from copy import deepcopy

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from skimage import measure
from skimage.segmentation import random_walker
from tqdm import tqdm

from siamfc import TrackerSiamFC
# CellImage, CellTrack, create_model, make_list_of_imgs_only and natural_keys
# are project-local helpers; their import paths are not shown in the original
# snippet.


class CellTracker:
    def __init__(self, siamese_model_path, unet_path, dataset_path, use_cuda,
                 new_w, new_h):
        self.dataset_path = dataset_path
        self.unet_path = unet_path
        self.new_w = new_w
        self.new_h = new_h

        self.tracker = TrackerSiamFC(net_path=siamese_model_path,
                                     use_cuda=use_cuda)
        if unet_path is not None:
            print("Loading pretrained model")
            self.seg_net, pretrained = self.load_unet(), True
        else:
            print("Did not load pretrained model")
            self.seg_net, pretrained = None, False

        self.train_images = None
        self.tracks = {}
        self.track_count = 0
        self.result = []
        self.set_01, self.set_02 = None, None

    def load_unet(self):
        model = create_model(self.unet_path, self.new_w, self.new_h)
        return model

    def load_evaluation_images(self, sequence, extension=".tif"):
        seg_dir = "/0{}".format(sequence)
        result = []
        print("Loading test images from {}".format(
            os.path.join(self.dataset_path + seg_dir, "*" + extension)))
        for frame_id, img_path in enumerate(
                glob.glob(
                    os.path.join(self.dataset_path + seg_dir,
                                 "*" + extension))):
            # name = img_path.split("\\t")[-1].split(extension)[0]
            name = img_path.split("/t")[-1].split(extension)[0]
            # print("Image name: {}".format(name))
            img = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH)
            seg_img = None
            result.append(
                CellImage(img, name, self.dataset_path, sequence, seg_img))
        result = sorted(result, key=lambda x: x.image_name, reverse=False)
        return result

    # predicts a binary segmentation for the input image using the U-Net
    def predict_seg(self, input_img, thr_markers=240, thr_cell_mask=230):
        # note: w and h follow numpy axis order (w = rows, h = columns)
        w = np.shape(input_img)[0]
        h = np.shape(input_img)[1]
        # histogram-equalize and rescale intensities to [-0.5, 0.5]
        img = cv2.equalizeHist(np.minimum(input_img, 255).astype(
            np.uint8)) / 255
        img = img.reshape((1, w, h, 1)) - .5

        # zero-pad to the network's fixed input size, if one is configured
        if self.new_w > 0 or self.new_h > 0:
            img2 = np.zeros((1, self.new_w, self.new_h, 1), dtype=np.float32)
            img2[:, :w, :h, :] = img
            img = img2

        prediction = self.seg_net.predict(img, batch_size=1)
        # NOTE: cropping the padded prediction back to (w, h) and selecting a
        # channel is left disabled; the raw network output is returned as-is.
        # Alternative post-processing (marker-based watershed, legacy U-Net
        # resizing) is likewise disabled in this version of the pipeline.
        return prediction


    # writes the frames with cell locations to a video file
    def store_footage(self, sequence, fps: int = 3):
        # define the codec and create the VideoWriter object;
        # VideoWriter expects the frame size as (width, height)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(
            "{}/result_0{}.avi".format(self.dataset_path, sequence), fourcc,
            float(fps), (self.result[0].shape[1], self.result[0].shape[0]))
        for frame in self.result:
            out.write(frame.astype(np.uint8))
        out.release()

    # predicts the new location in frame2 of a cell located in frame1;
    # the tracker takes an [x, y, w, h] box and the corner box
    # [tl_x, tl_y, br_x, br_y] is returned
    def predict_cell_location(self, frame1, frame2, cell):
        self.tracker.init(
            frame1, [cell.min_col, cell.min_row, cell.width, cell.height])
        [x, y, w, h] = self.tracker.update(frame2)
        return [int(x), int(y), int(x + w), int(y + h)]

    # uses the random walker with markers from the previous frame to compute
    # a new segmentation for collided cells
    @staticmethod
    def resegmentation(initial_segmentation, local_maxi):
        markers = measure.label(local_maxi)
        markers[~initial_segmentation] = -1
        labels = random_walker(initial_segmentation, markers)
        labels[labels == -1] = 0
        return labels
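
As a standalone illustration of the marker convention used above (positive labels are seeds, -1 marks pixels outside the mask and is ignored, 0-valued pixels get assigned to a seed by the walker), here is a toy sketch with two touching blobs; the data is made up for illustration:

import numpy as np
from skimage import measure
from skimage.segmentation import random_walker

# toy binary mask: two overlapping rectangular "cells"
mask = np.zeros((20, 20), dtype=bool)
mask[5:15, 3:11] = True
mask[5:15, 9:17] = True

# one single-pixel seed per cell
seeds = np.zeros((20, 20), dtype=np.uint8)
seeds[10, 5] = 1
seeds[10, 14] = 1

markers = measure.label(seeds)         # distinct positive label per seed
markers[~mask] = -1                    # exclude everything outside the mask
labels = random_walker(mask, markers)  # fills the remaining 0-valued pixels
labels[labels == -1] = 0               # map the excluded region back to 0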

    # stores the track descriptions in the Cell Tracking Challenge result format
    def store_track(self, filename, sequence):
        store_path = os.path.join(self.dataset_path,
                                  "0{}_RES/".format(sequence), filename)
        keys = sorted(list(self.tracks.keys()))
        print("Storing track at {}".format(store_path))
        with open(store_path, 'w', encoding='utf-8') as file:
            for k in keys:
                file.write(str(self.tracks[k]) + "\n")
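
Each line written here is the string form of a CellTrack. In the Cell Tracking Challenge convention a track record is "label first_frame last_frame parent_label", so CellTrack.__str__ presumably renders those four fields; this is an inference from the res_track.txt file name, not something shown in the snippet.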

    # get the backward matches and resegment when a collision is detected:
    def get_new_detections_dict(self,
                                previous_frame,
                                prev_img,
                                current_frame,
                                track_dict,
                                alt=True):

        cur_img = np.stack((current_frame.image.astype(np.int16), ) * 3,
                           axis=-1)

        available_cells = current_frame.get_cell_locations()
        new_detections_dict = {c: [] for c in available_cells}

        prev_num_cells = len(new_detections_dict)
        cur_num_cells = 0

        while alt and prev_num_cells != cur_num_cells:
            prev_num_cells = cur_num_cells
            alt = False

            # cell detection backward pass
            for dest_cell in new_detections_dict.keys():

                # get the predicted location of the cell in the previous frame
                [tl_x, tl_y, br_x,
                 br_y] = self.predict_cell_location(cur_img, prev_img,
                                                    deepcopy(dest_cell))

                # match all cells located in that area
                for cell_track in track_dict.keys():
                    c = cell_track.current_cell

                    if tl_x < c.centroid_x < br_x and tl_y < c.centroid_y < br_y:
                        new_detections_dict[dest_cell].append(cell_track)

                # forward checking:
                if len(new_detections_dict[dest_cell]) > 1:
                    final = []
                    for c_t in new_detections_dict[dest_cell]:
                        [tl_x, tl_y, br_x, br_y] = self.predict_cell_location(
                            prev_img, cur_img, c_t.current_cell)
                        pred_center_x, pred_center_y = int(
                            (br_x + tl_x) / 2), int((br_y + tl_y) / 2)

                        if tl_x < dest_cell.centroid_x < br_x and tl_y < dest_cell.centroid_y < br_y:
                            final.append(c_t)

                        elif dest_cell.min_col < pred_center_x < dest_cell.max_col and \
                                dest_cell.min_row < pred_center_y < dest_cell.max_row:
                            final.append(c_t)
                    new_detections_dict[dest_cell] = final

                # if two or more cells match, a collision has occurred and the
                # frame has to be re-segmented
                if len(new_detections_dict[dest_cell]) > 1:
                    alt = True
                    tl_x = min([
                        t.current_cell.min_col
                        for t in new_detections_dict[dest_cell]
                    ])
                    tl_y = min([
                        t.current_cell.min_row
                        for t in new_detections_dict[dest_cell]
                    ])
                    br_x = max([
                        t.current_cell.max_col
                        for t in new_detections_dict[dest_cell]
                    ])
                    br_y = max([
                        t.current_cell.max_row
                        for t in new_detections_dict[dest_cell]
                    ])

                    frame1_seg = previous_frame.seg_output[tl_y:br_y,
                                                           tl_x:br_x]
                    distance = np.zeros(
                        (int(frame1_seg.shape[0]), int(frame1_seg.shape[1])))

                    for color, t in enumerate(new_detections_dict[dest_cell]):
                        a = max(0, int(t.current_cell.centroid_y) - tl_y - 1)
                        b = max(0, int(t.current_cell.centroid_x) - tl_x - 1)
                        distance[a:a + 2,
                                 b:b + 2] = np.array(np.full((2, 2),
                                                             color + 1))

                    distance = cv2.resize(
                        distance,
                        (int(dest_cell.width), int(dest_cell.height)),
                        interpolation=cv2.INTER_NEAREST)

                    current_frame.seg_output[current_frame.seg_output ==
                                             dest_cell.color] = 0
                    new_seg = self.resegmentation(dest_cell.segmentation,
                                                  distance)

                    current_frame.seg_output[
                        dest_cell.min_row:dest_cell.max_row,
                        dest_cell.min_col:dest_cell.max_col] = np.maximum(
                            current_frame.seg_output[
                                dest_cell.min_row:dest_cell.max_row,
                                dest_cell.min_col:dest_cell.max_col], new_seg)

                    # relabel the frame and redetect the cells
                    current_frame.seg_output = measure.label(
                        current_frame.seg_output)
                    available_cells = current_frame.get_cell_locations()
                    new_detections_dict = {c: [] for c in available_cells}
                    cur_num_cells = len(new_detections_dict)
                    break  # restart the backward pass over the re-detected cells
        return current_frame.seg_output, new_detections_dict, cur_img

    @staticmethod
    def multiplot(image_list):
        # plots the images side by side with letter labels and saves the figure
        abc = ["A", "B", "C", "D", "E", "F", "G", "H"]
        fig, axes = plt.subplots(nrows=1,
                                 ncols=len(image_list),
                                 figsize=(10, 5))

        for num, x in enumerate(image_list):
            plt.subplot(1, len(image_list), num + 1)
            plt.title(abc[num], fontsize=25)
            plt.axis('off')
            plt.imshow(x)
        plt.savefig("segres.png", bbox_inches='tight')

    # makes sure that all cells in a given track keep the same pixel value
    # in the segmentation output
    @staticmethod
    def propagate_labels(frame, living_tracks, track, cell):
        free_label = max(np.unique(frame.seg_output)) + 1

        if track.cell_id != cell.color:
            for swap_id, swap_track in living_tracks:
                if swap_id != track.cell_id and swap_track.current_cell.color == track.cell_id:
                    frame.seg_output = np.where(
                        frame.seg_output == swap_track.current_cell.color,
                        free_label, frame.seg_output)
                    swap_track.current_cell.color = free_label
                    break

            frame.seg_output = np.where(frame.seg_output == cell.color,
                                        track.cell_id, frame.seg_output)
            cell.color = track.cell_id
        return frame, living_tracks

    # runs the tracking algorithm and exports results for evaluation
    def run_test(self,
                 sequence,
                 collision_detection=True,
                 store_footage=0,
                 load_segs_from_file=None):
        set_01 = self.load_evaluation_images(sequence)
        self.result = []

        print("Segmenting footage:")
        # apply initial segmentation to footage:
        if load_segs_from_file is None:
            for frame in tqdm(set_01, position=0):
                frame.binary_seg = self.predict_seg(frame.image)
                frame.binary_seg_to_output()
        else:
            imgs_list = sorted(make_list_of_imgs_only(
                os.listdir(load_segs_from_file), 'tif'),
                               key=natural_keys)
            for frame in tqdm(set_01, position=0):
                img_indx = int(frame.image_name)
                img_path = os.path.join(load_segs_from_file,
                                        imgs_list[img_indx])
                print("reading img: {} at {}".format(img_indx, img_path))
                frame.binary_seg = mpimg.imread(img_path)
                frame.binary_seg_to_output()

        # load first frame
        previous_frame = set_01[0]
        prev_img = np.stack((previous_frame.image.astype(np.int16), ) * 3,
                            axis=-1)

        # init tracks from detected cells in first frame
        self.tracks = {
            c_id + 1: CellTrack(c, 0, c_id + 1, 0)
            for c_id, c in enumerate(previous_frame.get_cell_locations())
        }
        self.track_count = len(self.tracks) + 1

        i = 0
        print("Tracking:")
        for current_frame in tqdm(set_01[1:]):
            track_dict = {t: [] for t in self.tracks.values() if t.alive}

            # solve collisions in segmentation and locate all cells in the new frame:
            current_frame.seg_output, new_detections_dict, cur_img = \
                self.get_new_detections_dict(previous_frame, prev_img, current_frame, track_dict, collision_detection)

            # image for video visualisation
            cur_copy = cur_img.copy()

            # cell detection forward pass:
            for cell_track in track_dict.keys():
                cell = cell_track.current_cell

                [tl_x, tl_y, br_x,
                 br_y] = self.predict_cell_location(prev_img, cur_img, cell)
                pred_center_x, pred_center_y = int((br_x + tl_x) / 2), int(
                    (br_y + tl_y) / 2)

                for c in new_detections_dict.keys():
                    if c.min_col < pred_center_x < c.max_col and c.min_row < pred_center_y < c.max_row:
                        track_dict[cell_track].append(c)
                    elif tl_x < c.centroid_x < br_x and tl_y < c.centroid_y < br_y:
                        track_dict[cell_track].append(c)

            # match cells from the previous frame to newly located cells:
            new_tracks = {}
            for track_id, cell_track in self.tracks.items():
                matched = False

                # if the track is already dead, carry it over unchanged
                if not cell_track.alive:
                    new_tracks[track_id] = cell_track

                # otherwise try to find a continuation for the track
                else:

                    forward_match = track_dict[cell_track]

                    # case 1->_
                    if not forward_match:

                        # check if 1<-1:
                        for dest_cell, t in new_detections_dict.items():
                            if cell_track in t:
                                cell_track.add_cell(dest_cell)
                                matched = True
                                break

                        # remove the matched cell from the unassigned detections
                        if matched:
                            del new_detections_dict[cell_track.current_cell]

                        # else the cell has no match and has died:
                        else:
                            cell_track.alive = False
                        new_tracks[track_id] = cell_track

                    # case 1->1
                    elif len(forward_match) == 1:
                        dest_cell = forward_match[0]

                        # should not occur:
                        if dest_cell not in new_detections_dict:
                            cell_track.alive = False

                        # if 1<-1 or _<-1 match
                        elif not new_detections_dict[
                                dest_cell] or cell_track in new_detections_dict[
                                    dest_cell]:
                            cell_track.add_cell(dest_cell)
                            del new_detections_dict[dest_cell]

                        # if 2<-1 death
                        else:
                            cell_track.alive = False
                        new_tracks[track_id] = cell_track

                    # case 1->1,2 (mitosis)
                    else:
                        available = [
                            c for c in forward_match
                            if c in new_detections_dict
                        ]

                        if not available or len(available) > 1:
                            cell_track.alive = False
                            for dest_cell in available:
                                del new_detections_dict[dest_cell]
                                new_track = CellTrack(dest_cell, i + 1,
                                                      self.track_count,
                                                      cell_track.cell_id)
                                new_tracks[self.track_count] = new_track
                                self.track_count += 1
                        else:
                            del new_detections_dict[available[0]]
                            cell_track.add_cell(available[0])

                        new_tracks[track_id] = cell_track

            # create new tracks for the unmatched cells in the new frame:
            for dest_cell in new_detections_dict:
                new_track = CellTrack(dest_cell, i + 1, self.track_count, 0)
                new_tracks[self.track_count] = new_track
                self.track_count += 1

            living_tracks = [(track_id, track)
                             for track_id, track in new_tracks.items()
                             if track.last_frame == i + 1]
            # displaying cell locations on frame for visual evaluation:
            for track_id, track in living_tracks:
                c = track.current_cell
                current_frame, living_tracks = self.propagate_labels(
                    current_frame, living_tracks, track, c)
                cv2.rectangle(cur_copy, c.tl, c.br, track.display_color, 3)
                x, y = int(c.centroid_x), int(c.centroid_y)
                cv2.putText(cur_copy, str(track_id), (x, y),
                            cv2.FONT_HERSHEY_SIMPLEX, .4, track.display_color,
                            1, cv2.LINE_AA)
            cv2.putText(cur_copy, str(i + 1), (5, 25),
                        cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 0, 0), 1,
                        cv2.LINE_AA)

            self.result.append(cur_copy)
            self.tracks = new_tracks
            previous_frame = current_frame
            prev_img = cur_img
            i += 1

        print("Saving results")
        for frame in tqdm(set_01):
            frame.store()
        self.store_track("res_track.txt", sequence)

        if store_footage:
            print("Creating video")
            self.store_footage(sequence=sequence, fps=3)
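
A hedged end-to-end sketch of driving the class above; every path and size below is a placeholder, not a value taken from the original project:

if __name__ == '__main__':
    tracker = CellTracker(siamese_model_path='siamfc_model.pth',  # placeholder
                          unet_path='unet_model',                 # placeholder
                          dataset_path='data/cell_dataset',       # placeholder
                          use_cuda=True,
                          new_w=512, new_h=512)  # network input size, assumed
    # track sequence "01" and also export the annotated video
    tracker.run_test(sequence=1, collision_detection=True, store_footage=1)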