class SiamFC(BaseTracker):
    def __init__(self):
        super(SiamFC, self).__init__("SiamFC")
        # TODO: edit this path
        self.net_file = path_config.SIAMFC_MODEL
        self.tracker = TrackerSiamFC(net_path=self.net_file)

    def initialize(self, image_file, box):
        image = Image.open(image_file).convert("RGB")
        self.tracker.init(image, box)

    def track(self, image_file):
        image = Image.open(image_file).convert("RGB")
        return self.tracker.update(image)
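
# Minimal usage sketch for the wrapper above (kept as a comment so nothing
# runs on import). The frame paths and the initial box are hypothetical;
# `box` is assumed to be [x, y, w, h] in pixels, matching TrackerSiamFC.init:
#
#     tracker = SiamFC()
#     tracker.initialize("frames/0001.jpg", [100, 120, 40, 30])
#     for image_file in ["frames/0002.jpg", "frames/0003.jpg"]:
#         box = tracker.track(image_file)  # predicted [x, y, w, h]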

def main(mode='IR', visualization=False):
    assert mode in ['IR', 'RGB'], 'Only IR or RGB modes are supported for evaluation'

    # setup tracker
    net_path = 'model.pth'
    tracker = TrackerSiamFC(net_path=net_path)

    # setup experiments
    video_paths = glob.glob(os.path.join('dataset', 'test-dev', '*'))
    video_num = len(video_paths)
    output_dir = os.path.join('results', tracker.name)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    overall_performance = []

    # run tracking experiments and report performance
    for video_id, video_path in enumerate(video_paths, start=1):
        video_name = os.path.basename(video_path)
        video_file = os.path.join(video_path, '%s.mp4' % mode)
        res_file = os.path.join(video_path, '%s_label.json' % mode)
        with open(res_file, 'r') as f:
            label_res = json.load(f)

        init_rect = label_res['gt_rect'][0]
        capture = cv2.VideoCapture(video_file)

        frame_id = 0
        out_res = []
        while True:
            ret, frame = capture.read()
            if not ret:
                capture.release()
                break
            if frame_id == 0:
                tracker.init(frame, init_rect)  # initialization
                out = init_rect
                out_res.append(init_rect)
            else:
                out = tracker.update(frame)  # tracking
                out_res.append(out.tolist())
            if visualization:
                _gt = label_res['gt_rect'][frame_id]
                _exist = label_res['exist'][frame_id]
                if _exist:
                    cv2.rectangle(frame, (int(_gt[0]), int(_gt[1])),
                                  (int(_gt[0] + _gt[2]), int(_gt[1] + _gt[3])),
                                  (0, 255, 0))
                cv2.putText(frame, 'exist' if _exist else 'not exist',
                            (frame.shape[1] // 2 - 20, 30), 1, 2,
                            (0, 255, 0) if _exist else (0, 0, 255), 2)
                cv2.rectangle(frame, (int(out[0]), int(out[1])),
                              (int(out[0] + out[2]), int(out[1] + out[3])),
                              (0, 255, 255))
                cv2.imshow(video_name, frame)
                cv2.waitKey(1)
            frame_id += 1
        if visualization:
            cv2.destroyAllWindows()

        # save result
        output_file = os.path.join(output_dir,
                                   '%s_%s.txt' % (video_name, mode))
        with open(output_file, 'w') as f:
            json.dump({'res': out_res}, f)

        # eval here is the project's mixed-measure metric function,
        # not the Python builtin
        mixed_measure = eval(out_res, label_res)
        overall_performance.append(mixed_measure)
        print('[%03d/%03d] %20s %5s Mixed Measure: %.03f' %
              (video_id, video_num, video_name, mode, mixed_measure))

    print('[Overall] %5s Mixed Measure: %.03f\n' %
          (mode, np.mean(overall_performance)))
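
# Sketch of a script entry point; the mode and visualization values below
# are illustrative defaults, not from the original:
if __name__ == '__main__':
    main(mode='IR', visualization=False)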

class CellTracker:
    def __init__(self, siamese_model_path, unet_path, dataset_path, use_cuda,
                 new_w, new_h):
        self.dataset_path = dataset_path
        self.unet_path = unet_path
        self.new_w = new_w
        self.new_h = new_h
        self.tracker = TrackerSiamFC(net_path=siamese_model_path,
                                     use_cuda=use_cuda)
        if unet_path is not None:
            print("Loading pretrained model")
            self.seg_net, pretrained = self.load_unet(), True
        else:
            print("Did not load pretrained model")
            self.seg_net, pretrained = None, False
        self.train_images = None
        self.tracks = {}
        self.track_count = 0
        self.result = []
        self.set_01, self.set_02 = None, None

    def load_unet(self):
        model = create_model(self.unet_path, self.new_w, self.new_h)
        return model

    def load_evaluation_images(self, sequence, extension=".tif"):
        seg_dir = "/0{}".format(sequence)
        result = []
        print("Loading test images from {}".format(
            os.path.join(self.dataset_path + seg_dir, "*" + extension)))
        for img_path in glob.glob(
                os.path.join(self.dataset_path + seg_dir, "*" + extension)):
            # name = img_path.split("\\t")[-1].split(extension)[0]
            name = img_path.split("/t")[-1].split(extension)[0]
            img = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH)
            seg_img = None
            result.append(
                CellImage(img, name, self.dataset_path, sequence, seg_img))
        result = sorted(result, key=lambda x: x.image_name, reverse=False)
        return result

    # predicts a binary segmentation for the input image using the U-Net
    def predict_seg(self, input_img, thr_markers=240, thr_cell_mask=230):
        w = np.shape(input_img)[0]
        h = np.shape(input_img)[1]
        img = cv2.equalizeHist(
            np.minimum(input_img, 255).astype(np.uint8)) / 255
        img = img.reshape((1, w, h, 1)) - .5
        # pad to the network input size if one was given
        if self.new_w > 0 or self.new_h > 0:
            img2 = np.zeros((1, self.new_w, self.new_h, 1), dtype=np.float32)
            img2[:, :w, :h, :] = img
            img = img2
        prediction = self.seg_net.predict(img, batch_size=1)
        # prediction = prediction[0, :w, :h, 1]

        # New watershed
        # naive_seg = postprocess_cell_mask(prediction[0, :w, :h, 3] * 255,
        #                                   threshold=thr_cell_mask)
        # distance = ndi.distance_transform_edt(naive_seg)
        # local_maxi = peak_local_max(distance, labels=naive_seg,
        #                             footprint=np.ones((15, 15)),
        #                             indices=False)
        # markers = ndi.label(local_maxi)[0]
        # prediction = watershed(-distance, markers, mask=naive_seg)

        # Watershed
        # m = prediction[0, :w, :h, 1] * 255
        # c = prediction[0, :w, :h, 3] * 255
        # o = (img + .5) * 255
        # # postprocess the result of the prediction
        # idx, markers = postprocess_markers(m, threshold=thr_markers,
        #                                    erosion_size=1, circular=False,
        #                                    step=30)
        # cell_mask = postprocess_cell_mask(c, threshold=thr_cell_mask)
        # # correct the border
        # cell_mask = np.maximum(cell_mask, markers)
        # prediction = (watershed(-c, markers, mask=cell_mask) > 0) * 1.0

        # Previous U-Net
        # img = input_img / 255
        # img = torch.Tensor(list(transform.resize(
        #     input_img, (512, 512),
        #     mode='symmetric'))).unsqueeze(0).unsqueeze(0).permute(
        #         0, 2, 3, 1).numpy()
        # prediction = self.seg_net.predict(img)
        # prediction = cv2.resize(prediction[0, :, :, 0],
        #                         tuple(reversed(input_img.shape)))
        # prediction = np.array(prediction)
        # _, prediction = cv2.threshold(prediction, 0.6, 1, cv2.THRESH_BINARY)
        return prediction

    # predict segmentation for all frames:
    # def segment_images(self, sequence):
    #     for cell_img in tqdm(self.train_images[sequence]):
    #         cell_img.binary_seg = self.predict_seg(cell_img.image)

    # writes the frames with cell locations to a video file
    def store_footage(self, sequence, fps: int = 3):
        # output_data = np.array([x for x in self.result])
        # output_data = output_data.astype(np.uint8)
        # skvideo.io.vwrite("result {} 0{}.mp4".format(self.name, sequence),
        #                   output_data, inputdict={'-r': str(fps)})
        # define the codec and create a VideoWriter object
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(
            "{}/result_0{}.avi".format(self.dataset_path, sequence), fourcc,
            float(fps), (self.result[0].shape[1], self.result[0].shape[0]))
        for frame in self.result:
            out.write(frame.astype(np.uint8))
        out.release()

    # predicts the new location in frame2 of a cell located in frame1
    def predict_cell_location(self, frame1, frame2, cell):
        self.tracker.init(
            frame1, [cell.min_col, cell.min_row, cell.width, cell.height])
        [x, y, w, h] = self.tracker.update(frame2)
        # returned as [top-left x, top-left y, bottom-right x, bottom-right y]
        return [int(x), int(y), int(x + w), int(y + h)]

    # uses a random walker seeded with markers from the previous frame to
    # predict a new segmentation for collided cells
    @staticmethod
    def resegmentation(initial_segmentation, local_maxi):
        markers = measure.label(local_maxi)
        markers[~initial_segmentation] = -1
        labels = random_walker(initial_segmentation, markers)
        labels[labels == -1] = 0
        return labels

    # stores the tracks' descriptions in the required format
    def store_track(self, filename, sequence):
        store_path = os.path.join(self.dataset_path,
                                  "0{}_RES/".format(sequence), filename)
        keys = sorted(list(self.tracks.keys()))
        print("Storing track at {}".format(store_path))
        with open(store_path, 'w', encoding='utf-8') as file:
            for k in keys:
                file.write(str(self.tracks[k]) + "\n")

    # gets the backward matches and resegments when a collision is detected
    def get_new_detections_dict(self, previous_frame, prev_img, current_frame,
                                track_dict, alt=True):
        cur_img = np.stack((current_frame.image.astype(np.int16), ) * 3,
                           axis=-1)
        available_cells = current_frame.get_cell_locations()
        new_detections_dict = {c: [] for c in available_cells}
        prev_num_cells = len(new_detections_dict)
        cur_num_cells = 0
        while alt and prev_num_cells != cur_num_cells:
            prev_num_cells = cur_num_cells
            alt = False
            # cell detection backward pass
            for dest_cell in new_detections_dict.keys():
                # get the predicted location of the cell in the previous frame
                [tl_x, tl_y, br_x, br_y] = self.predict_cell_location(
                    cur_img, prev_img, deepcopy(dest_cell))
                # match all cells located in that area
                for cell_track in track_dict.keys():
                    c = cell_track.current_cell
                    if tl_x < c.centroid_x < br_x and \
                            tl_y < c.centroid_y < br_y:
                        try:
                            new_detections_dict[dest_cell].append(cell_track)
                        except Exception as e:
                            print("new_detections_dict.keys(): {}".format(
                                new_detections_dict.keys()))
                            print("dest_cell: {}".format(dest_cell))
                            print(new_detections_dict[dest_cell])
                            print("\n\n")
                            print(new_detections_dict)
                            raise e
                # forward checking:
                if len(new_detections_dict[dest_cell]) > 1:
                    final = []
                    for c_t in new_detections_dict[dest_cell]:
                        [tl_x, tl_y, br_x, br_y] = self.predict_cell_location(
                            prev_img, cur_img, c_t.current_cell)
                        pred_center_x, pred_center_y = \
                            int((br_x + tl_x) / 2), int((br_y + tl_y) / 2)
                        if tl_x < dest_cell.centroid_x < br_x and \
                                tl_y < dest_cell.centroid_y < br_y:
                            final.append(c_t)
                        elif dest_cell.min_col < pred_center_x < dest_cell.max_col and \
                                dest_cell.min_row < pred_center_y < dest_cell.max_row:
                            final.append(c_t)
                    new_detections_dict[dest_cell] = final
                # if two or more cells still match, a collision has occurred
                # and the frame has to be resegmented
                if len(new_detections_dict[dest_cell]) > 1:
                    alt = True
                    tl_x = min([
                        t.current_cell.min_col
                        for t in new_detections_dict[dest_cell]
                    ])
                    tl_y = min([
                        t.current_cell.min_row
                        for t in new_detections_dict[dest_cell]
                    ])
                    br_x = max([
                        t.current_cell.max_col
                        for t in new_detections_dict[dest_cell]
                    ])
                    br_y = max([
                        t.current_cell.max_row
                        for t in new_detections_dict[dest_cell]
                    ])
                    frame1_seg = previous_frame.seg_output[tl_y:br_y,
                                                           tl_x:br_x]
                    distance = np.zeros((int(frame1_seg.shape[0]),
                                         int(frame1_seg.shape[1])))
                    # seed a 2x2 marker at each colliding cell's centroid
                    for color, t in enumerate(new_detections_dict[dest_cell]):
                        a = max(0, int(t.current_cell.centroid_y) - tl_y - 1)
                        b = max(0, int(t.current_cell.centroid_x) - tl_x - 1)
                        distance[a:a + 2, b:b + 2] = np.full((2, 2), color + 1)
                    distance = cv2.resize(
                        distance,
                        (int(dest_cell.width), int(dest_cell.height)),
                        interpolation=cv2.INTER_NEAREST)
                    current_frame.seg_output[current_frame.seg_output ==
                                             dest_cell.color] = 0
                    new_seg = self.resegmentation(dest_cell.segmentation,
                                                  distance)
                    # plot the new segmentation result:
                    # self.multiplot([frame1_seg, dest_cell.segmentation,
                    #                 new_seg])
                    current_frame.seg_output[
                        dest_cell.min_row:dest_cell.max_row,
                        dest_cell.min_col:dest_cell.max_col] = np.maximum(
                            current_frame.seg_output[
                                dest_cell.min_row:dest_cell.max_row,
                                dest_cell.min_col:dest_cell.max_col], new_seg)
                    # relabel the frame and redetect the cells
                    current_frame.seg_output = measure.label(
                        current_frame.seg_output)
                    available_cells = current_frame.get_cell_locations()
                    new_detections_dict = {c: [] for c in available_cells}
                    cur_num_cells = len(new_detections_dict)
                    break
        return current_frame.seg_output, new_detections_dict, cur_img

    @staticmethod
    def multiplot(image_list):
        abc = ["A", "B", "C", "D", "E", "F", "G", "H"]
        fig, axes = plt.subplots(nrows=1, ncols=len(image_list),
                                 figsize=(10, 5))
        for num, x in enumerate(image_list):
            plt.subplot(1, len(image_list), num + 1)
            plt.title(abc[num], fontsize=25)
            plt.axis('off')
            plt.imshow(x)
        plt.savefig("segres.png", bbox_inches='tight')

    # makes sure that all cells in a given track share the same pixel value
    # in the segmentation
    @staticmethod
    def propagate_labels(frame, living_tracks, track, cell):
        free_label = max(np.unique(frame.seg_output)) + 1
        if track.cell_id != cell.color:
            # if another track already uses this label, move it to a free one
            for swap_id, swap_track in living_tracks:
                if swap_id != track.cell_id and \
                        swap_track.current_cell.color == track.cell_id:
                    frame.seg_output = np.where(
                        frame.seg_output == swap_track.current_cell.color,
                        free_label, frame.seg_output)
                    swap_track.current_cell.color = free_label
                    break
            frame.seg_output = np.where(frame.seg_output == cell.color,
                                        track.cell_id, frame.seg_output)
            cell.color = track.cell_id
        return frame, living_tracks

    # runs the tracking algorithm and exports results for evaluation
    def run_test(self, sequence, collision_detection=True, store_footage=0,
                 load_segs_from_file=None):
        set_01 = self.load_evaluation_images(sequence)
        self.result = []
        print("Segmenting footage:")
        # apply the initial segmentation to the footage
        if load_segs_from_file is None:
            for frame in tqdm(set_01, position=0):
                frame.binary_seg = self.predict_seg(frame.image)
                frame.binary_seg_to_output()
        else:
            imgs_list = sorted(make_list_of_imgs_only(
                os.listdir(load_segs_from_file), 'tif'), key=natural_keys)
            for frame in tqdm(set_01, position=0):
                img_indx = int(frame.image_name)
                img_path = os.path.join(load_segs_from_file,
                                        imgs_list[img_indx])
                print("reading img: {} at {}".format(img_indx, img_path))
                frame.binary_seg = mpimg.imread(img_path)
                frame.binary_seg_to_output()

        # load the first frame
        previous_frame = set_01[0]
        prev_img = np.stack((previous_frame.image.astype(np.int16), ) * 3,
                            axis=-1)
        # init tracks from the cells detected in the first frame
        self.tracks = {
            c_id + 1: CellTrack(c, 0, c_id + 1, 0)
            for c_id, c in enumerate(previous_frame.get_cell_locations())
        }
        self.track_count = len(self.tracks) + 1
        i = 0
        print("Tracking:")
        for current_frame in tqdm(set_01[1:]):
            track_dict = {t: [] for t in self.tracks.values() if t.alive}
            # solve collisions in the segmentation and locate all cells in
            # the new frame
            current_frame.seg_output, new_detections_dict, cur_img = \
                self.get_new_detections_dict(previous_frame, prev_img,
                                             current_frame, track_dict,
                                             collision_detection)
            # image for video visualisation
            cur_copy = cur_img.copy()
            # cell detection forward pass
            for cell_track in track_dict.keys():
                cell = cell_track.current_cell
                [tl_x, tl_y, br_x, br_y] = self.predict_cell_location(
                    prev_img, cur_img, cell)
                pred_center_x, pred_center_y = \
                    int((br_x + tl_x) / 2), int((br_y + tl_y) / 2)
                for c in new_detections_dict.keys():
                    if c.min_col < pred_center_x < c.max_col and \
                            c.min_row < pred_center_y < c.max_row:
                        track_dict[cell_track].append(c)
                    elif tl_x < c.centroid_x < br_x and \
                            tl_y < c.centroid_y < br_y:
                        track_dict[cell_track].append(c)
            # match cells from the previous frame to the newly located cells
            new_tracks = {}
            for track_id, cell_track in self.tracks.items():
                matched = False
                # if the cell is dead, carry it into the new track set
                if not cell_track.alive:
                    new_tracks[track_id] = cell_track
                # otherwise try to find a track continuation
                elif cell_track.alive:
                    forward_match = track_dict[cell_track]
                    # case 1->_
                    if not forward_match:
                        # check if 1<-1
                        for dest_cell, t in new_detections_dict.items():
                            if cell_track in t:
                                cell_track.add_cell(dest_cell)
                                matched = True
                                break
                        if matched:
                            # remove the cell from the free new detections
                            del new_detections_dict[cell_track.current_cell]
                        else:
                            # the cell has no match and has died
                            cell_track.alive = False
                        new_tracks[track_id] = cell_track
                    # case 1->1
                    elif len(forward_match) == 1:
                        dest_cell = forward_match[0]
                        # should not occur:
                        if dest_cell not in new_detections_dict:
                            cell_track.alive = False
                        # 1<-1 or _<-1 match
                        elif not new_detections_dict[dest_cell] or \
                                cell_track in new_detections_dict[dest_cell]:
                            cell_track.add_cell(dest_cell)
                            del new_detections_dict[dest_cell]
                        # 2<-1: death
                        else:
                            cell_track.alive = False
                        new_tracks[track_id] = cell_track
                    # case 1->1,2 (mitosis)
                    else:
                        available = [
                            c for c in forward_match
                            if c in new_detections_dict
                        ]
                        if not available or len(available) > 1:
                            cell_track.alive = False
                            for dest_cell in available:
                                del new_detections_dict[dest_cell]
                                new_track = CellTrack(dest_cell, i + 1,
                                                      self.track_count,
                                                      cell_track.cell_id)
                                new_tracks[self.track_count] = new_track
                                self.track_count += 1
                        else:
                            del new_detections_dict[available[0]]
                            cell_track.add_cell(available[0])
                        new_tracks[track_id] = cell_track
            # create new tracks for the unmatched cells in the new frame
            for dest_cell, t in new_detections_dict.items():
                new_track = CellTrack(dest_cell, i + 1, self.track_count, 0)
                new_tracks[self.track_count] = new_track
                self.track_count += 1
            living_tracks = [(track_id, track)
                             for track_id, track in new_tracks.items()
                             if track.last_frame == i + 1]
            # display cell locations on the frame for visual evaluation
            for track_id, track in living_tracks:
                c = track.current_cell
                current_frame, living_tracks = self.propagate_labels(
                    current_frame, living_tracks, track, c)
                cv2.rectangle(cur_copy, c.tl, c.br, track.display_color, 3)
                x, y = int(c.centroid_x), int(c.centroid_y)
                cv2.putText(cur_copy, str(track_id), (x, y),
                            cv2.FONT_HERSHEY_SIMPLEX, .4,
                            track.display_color, 1, cv2.LINE_AA)
            # stamp the frame number on the visualisation
            cv2.putText(cur_copy, str(i + 1), (5, 25),
                        cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 0, 0), 1,
                        cv2.LINE_AA)
            self.result.append(cur_copy)
            self.tracks = new_tracks
            previous_frame = current_frame
            prev_img = cur_img
            i += 1

        print("Saving results")
        for frame in tqdm(set_01):
            frame.store()
        self.store_track("res_track.txt", sequence)
        if store_footage:
            print("Creating video")
            self.store_footage(sequence=sequence, fps=3)
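
# Minimal usage sketch for CellTracker (kept as a comment so nothing runs on
# import). The model paths, dataset path, and padded input size below are
# hypothetical; run_test segments sequence "01", tracks the cells, writes
# res_track.txt into 01_RES/, and optionally renders a result video:
#
#     tracker = CellTracker(siamese_model_path="models/siamfc.pth",
#                           unet_path="models/unet.h5",
#                           dataset_path="data/cells",
#                           use_cuda=True, new_w=512, new_h=512)
#     tracker.run_test(sequence=1, collision_detection=True, store_footage=1)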