def _build_writer(self, global_step=0):
    r"""
    Create the summary writer and keep it in ``self._state["writer"]``.

    The output directory is read from ``self._hyper_params["log_dir"]`` and
    passed through ``ensure_dir`` before the writer is constructed.

    Arguments
    ---------
    global_step: int
        forwarded to the writer as ``purge_step``
        # NOTE(review): presumably the step from which stale events are
        # purged when resuming -- confirm against the SummaryWriter in use
    """
    writer_dir = self._hyper_params["log_dir"]
    ensure_dir(writer_dir)
    writer_kwargs = dict(
        log_dir=writer_dir,
        purge_step=global_step,
        filename_suffix="",
    )
    self._state["writer"] = SummaryWriter(**writer_kwargs)
def save_snapshot(self):
    r"""
    Write a snapshot of the current epoch to disk.

    The snapshot holds the epoch number plus the state dicts of the
    (unwrapped) model and of the optimizer. Saving is repeated until the
    snapshot file is found on disk.
    """
    current_epoch = self._state["epoch"]
    target_dir, target_file = self._infer_snapshot_dir_file_from_epoch(
        current_epoch)
    payload = dict(
        epoch=current_epoch,
        model_state_dict=unwrap_model(self._model).state_dict(),
        optimizer_state_dict=self._optimizer.state_dict(),
    )
    ensure_dir(target_dir)
    # save, then verify; keep saving for as long as the file does not show up
    while True:
        torch.save(payload, target_file)
        if osp.exists(target_file):
            break
        logger.info("retrying")
    saved_msg = "Snapshot saved at: %s" % target_file
    logger.info(saved_msg)
def test(self):
    r"""
    Run test

    For every name in ``self._hyper_params["dataset_names"]`` the result
    directory ``<exp_save>/<dataset_name>/<exp_name>/baseline`` is prepared,
    the tracker is run and the results are evaluated with the
    ``'default_hp'`` setting.

    Returns
    -------
    dict
        ``{"main_performance": score}`` where ``score`` is the ``"JF"`` entry
        of the evaluation result of the last dataset in the list

    Raises
    ------
    ValueError
        if ``dataset_names`` is empty, i.e. there is nothing to evaluate
    """
    # set dir
    self.tracker_name = self._hyper_params["exp_name"]
    dataset_names = list(self._hyper_params["dataset_names"])
    if not dataset_names:
        # without this guard the return statement below would hit an
        # unbound ``eval_result`` and fail with an opaque UnboundLocalError
        raise ValueError(
            "hyper-parameter `dataset_names` is empty: nothing to test")
    for dataset_name in dataset_names:
        self.dataset_name = dataset_name
        self.tracker_dir = os.path.join(self._hyper_params["exp_save"],
                                        self.dataset_name)
        self.save_root_dir = os.path.join(self.tracker_dir,
                                          self.tracker_name, "baseline")
        ensure_dir(self.save_root_dir)
        # track videos
        self.run_tracker()
        # evaluation (overwritten on each pass: only the last one is reported)
        eval_result = self.evaluation('default_hp')
    return dict(main_performance=eval_result["JF"])
def test(self) -> Dict:
    r"""
    Run the tracker on each configured dataset and evaluate it.

    Per dataset, the attributes ``dataset_name``, ``tracker_dir`` and
    ``save_root_dir`` are updated on ``self`` before tracking starts.

    Returns
    -------
    Dict
        evaluation result of the last dataset in
        ``self._hyper_params["dataset_names"]``; ``None`` when that list
        is empty
    """
    self.tracker_name = self._hyper_params["exp_name"]
    latest_result = None
    for name in self._hyper_params["dataset_names"]:
        self.dataset_name = name
        # <exp_save>/<dataset_name>
        dataset_root = os.path.join(self._hyper_params["exp_save"],
                                    self.dataset_name)
        self.tracker_dir = dataset_root
        # <exp_save>/<dataset_name>/<exp_name>/baseline
        baseline_dir = os.path.join(dataset_root, self.tracker_name,
                                    "baseline")
        self.save_root_dir = baseline_dir
        ensure_dir(baseline_dir)
        # tracking first, evaluation afterwards
        self.run_tracker()
        latest_result = self.evaluation()
    return latest_result
def track_single_video(self, tracker, video, v_id=0):
    r"""
    track frames in single video with VOT rules

    The tracker is (re-)initialized from the ground truth on the frame with
    index ``start_frame`` and updated on the frames after it. When the overlap
    between prediction and ground truth is falsy, the frame is counted as
    lost and the following frames are skipped until re-initialization five
    frames later.

    One entry per frame is written to
    ``<save_root_dir>/<video name>/<video name>_001.txt``:
    ``1`` (or the ground truth for non-VOT datasets) on initialization,
    ``2`` on a lost frame, ``0`` on a skipped frame, otherwise the predicted
    box as comma-separated values.

    Arguments
    ---------
    tracker: PipelineBase
        pipeline
    video: str
        video name
    v_id: int
        video id

    Returns
    -------
    lost_times: int
        number of lost frames
    speed: float
        index of the last frame divided by the accumulated tracking time in
        seconds; ``0.0`` when no time was measured (e.g. video without frames)
    """
    # both helpers live in the same module: import it once. The module name
    # is absolute, so no ``package`` anchor is needed.
    region_utils = importlib.import_module(
        "siamfcpp.evaluation.vot_benchmark.pysot.utils.region")
    vot_overlap = region_utils.vot_overlap
    vot_float2str = region_utils.vot_float2str

    regions = []
    video = self.dataset[video]
    image_files, gt = video['image_files'], video['gt']
    start_frame, lost_times, toc = 0, 0, 0
    skip_after_lost = 5  # frames between a lost frame and re-initialization
    f = 0  # keeps the summary below well-defined for a video without frames
    for f, image_file in enumerate(tqdm(image_files)):
        im = vot_benchmark.get_img(image_file)
        # image loading is deliberately excluded from the timing
        tic = cv2.getTickCount()
        if f == start_frame:  # init
            cx, cy, w, h = vot_benchmark.get_axis_aligned_bbox(gt[f])
            location = vot_benchmark.cxy_wh_2_rect((cx, cy), (w, h))
            tracker.init(im, location)
            regions.append(1 if 'VOT' in self.dataset_name else gt[f])
        elif f > start_frame:  # tracking
            location = tracker.update(im)
            # ground truth: first eight values, passed on as a flat tuple
            gt_polygon = tuple(gt[f][i] for i in range(8))
            # prediction (x, y, w, h) -> four corners, flat tuple
            left, top = location[0], location[1]
            right = location[0] + location[2]
            bottom = location[1] + location[3]
            pred_polygon = (left, top, right, top, right, bottom, left,
                            bottom)
            b_overlap = vot_overlap(gt_polygon, pred_polygon,
                                    (im.shape[1], im.shape[0]))
            if b_overlap:
                regions.append(location)
            else:  # lost
                regions.append(2)
                lost_times += 1
                start_frame = f + skip_after_lost
        else:  # skip
            regions.append(0)
        toc += cv2.getTickCount() - tic

    toc /= cv2.getTickFrequency()
    # guard against a zero time measurement instead of dividing by zero
    speed = f / toc if toc > 0 else 0.0

    # save result
    result_dir = join(self.save_root_dir, video['name'])
    ensure_dir(result_dir)
    result_path = join(result_dir, '{:s}_001.txt'.format(video['name']))
    with open(result_path, "w") as result_file:
        for region in regions:
            if isinstance(region, int):
                # status code (init / lost / skipped)
                result_file.write("{:d}\n".format(region))
            else:
                result_file.write(
                    ','.join([vot_float2str("%.4f", i)
                              for i in region]) + '\n')
    logger.info(
        '({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps Lost: {:d} '
        .format(v_id, video['name'], toc, speed, lost_times))

    return lost_times, speed
if __name__ == '__main__': # parsing parser = make_parser() parsed_args = parser.parse_args() # experiment config exp_cfg_path = osp.realpath(parsed_args.config) root_cfg.merge_from_file(exp_cfg_path) # resolve config root_cfg = complete_path_wt_root_in_cfg(root_cfg, ROOT_PATH) root_cfg = root_cfg.train task, task_cfg = specify_task(root_cfg) task_cfg.freeze() # log config log_dir = osp.join(task_cfg.exp_save, task_cfg.exp_name, "logs") ensure_dir(log_dir) logger.configure( handlers=[ dict(sink=sys.stderr, level="INFO"), dict(sink=osp.join(log_dir, "train_log.txt"), enqueue=True, serialize=True, diagnose=True, backtrace=True, level="INFO") ], extra={"common_to_all": "default"}, ) # backup config logger.info("Load experiment configuration at: %s" % exp_cfg_path) logger.info(