def _build_writer(self, global_step=0):
    r"""
    Create the summary writer and keep it in ``self._state["writer"]``.

    The log directory is taken from ``self._hyper_params["log_dir"]`` and is
    passed to ``ensure_dir`` before the writer is constructed.

    Arguments
    ---------
    global_step: int
        forwarded to the writer as ``purge_step``
    """
    writer_dir = self._hyper_params["log_dir"]
    ensure_dir(writer_dir)
    writer_options = dict(
        log_dir=writer_dir,
        purge_step=global_step,
        filename_suffix="",
    )
    self._state["writer"] = SummaryWriter(**writer_options)
Пример #2
0
 def save_snapshot(self):
     r"""
     Persist the model and optimizer state of the current epoch.

     The epoch is read from ``self._state["epoch"]`` and the destination
     directory / file are obtained from
     ``self._infer_snapshot_dir_file_from_epoch``. The save is repeated until
     the snapshot file is found on disk.
     """
     current_epoch = self._state["epoch"]
     target_dir, target_file = self._infer_snapshot_dir_file_from_epoch(
         current_epoch)
     model_state = unwrap_model(self._model).state_dict()
     optimizer_state = self._optimizer.state_dict()
     payload = dict(
         epoch=current_epoch,
         model_state_dict=model_state,
         optimizer_state_dict=optimizer_state,
     )
     ensure_dir(target_dir)
     torch.save(payload, target_file)
     # keep writing until the file is actually visible on disk
     while True:
         if osp.exists(target_file):
             break
         logger.info("retrying")
         torch.save(payload, target_file)
     logger.info("Snapshot saved at: %s" % target_file)
Пример #3
0
 def test(self):
     r"""
     Run test

     For every entry of ``self._hyper_params["dataset_names"]`` the result
     directory ``<exp_save>/<dataset_name>/<exp_name>/baseline`` is prepared,
     the tracker is run and the result is evaluated.

     Returns
     -------
     dict
         ``main_performance`` holds the ``"JF"`` entry of the evaluation
         result of the last dataset in ``dataset_names``

     Raises
     ------
     ValueError
         if ``dataset_names`` is empty, i.e. there is nothing to evaluate
     """
     dataset_names = self._hyper_params["dataset_names"]
     # Without this check an empty list would only surface later as an
     # UnboundLocalError on ``eval_result`` at the return statement.
     if not dataset_names:
         raise ValueError(
             'hyper-parameter "dataset_names" is empty: nothing to test')
     # set dir
     self.tracker_name = self._hyper_params["exp_name"]
     eval_result = None
     for dataset_name in dataset_names:
         self.dataset_name = dataset_name
         self.tracker_dir = os.path.join(self._hyper_params["exp_save"],
                                         self.dataset_name)
         self.save_root_dir = os.path.join(self.tracker_dir,
                                           self.tracker_name, "baseline")
         ensure_dir(self.save_root_dir)
         # track videos
         self.run_tracker()
         # evaluation
         eval_result = self.evaluation('default_hp')
     # only the result of the last evaluated dataset is reported
     return dict(main_performance=eval_result["JF"])
Пример #4
0
 def test(self) -> Dict:
     r"""
     Run the tracker and its evaluation on every configured dataset.

     Returns
     -------
     Dict
         the value returned by ``self.evaluation()`` for the last entry of
         ``dataset_names``; ``None`` when ``dataset_names`` is empty
     """
     self.tracker_name = self._hyper_params["exp_name"]
     latest_result = None
     for name in self._hyper_params["dataset_names"]:
         self.dataset_name = name
         # results go under <exp_save>/<dataset_name>/<exp_name>/baseline
         exp_save_root = self._hyper_params["exp_save"]
         self.tracker_dir = os.path.join(exp_save_root, self.dataset_name)
         self.save_root_dir = os.path.join(
             self.tracker_dir, self.tracker_name, "baseline")
         ensure_dir(self.save_root_dir)
         # run on all videos first, then score the saved results
         self.run_tracker()
         latest_result = self.evaluation()
     return latest_result
Пример #5
0
    def track_single_video(self, tracker, video, v_id=0):
        r"""
        track frames in single video with VOT rules

        The tracker is (re-)initialized from the ground truth on the start
        frame, updated on the following frames and, once the prediction no
        longer overlaps the ground truth, re-initialized five frames later.
        One entry per frame is written to
        ``<save_root_dir>/<video name>/<video name>_001.txt``.

        Arguments
        ---------
        tracker: PipelineBase
            pipeline
        video: str
            video name
        v_id: int
            video id

        Returns
        -------
        tuple
            (number of times the target was lost, speed in frames per second);
            the speed is 0.0 when no time was measured (e.g. empty video)
        """
        # Imported at call time, as in the original code; both helpers come
        # from the same module, so a single import is enough.
        region = importlib.import_module(
            "siamfcpp.evaluation.vot_benchmark.pysot.utils.region")
        vot_overlap = region.vot_overlap
        vot_float2str = region.vot_float2str

        regions = []
        video = self.dataset[video]
        image_files, gt = video['image_files'], video['gt']
        start_frame, lost_times, toc = 0, 0, 0
        f = 0  # keeps the summary below well-defined for an empty video
        for f, image_file in enumerate(tqdm(image_files)):
            im = vot_benchmark.get_img(image_file)

            tic = cv2.getTickCount()
            if f == start_frame:  # init
                cx, cy, w, h = vot_benchmark.get_axis_aligned_bbox(gt[f])
                location = vot_benchmark.cxy_wh_2_rect((cx, cy), (w, h))
                tracker.init(im, location)
                # code 1 marks an initialization frame for VOT datasets
                regions.append(1 if 'VOT' in self.dataset_name else gt[f])
            elif f > start_frame:  # tracking
                location = tracker.update(im)

                # ground truth: 8 values taken as-is
                gt_polygon = tuple(gt[f][i] for i in range(8))
                # prediction: (x, y, w, h) rect expanded to its four corners
                left, top = location[0], location[1]
                right, bottom = left + location[2], top + location[3]
                pred_polygon = (left, top, right, top, right, bottom, left,
                                bottom)
                b_overlap = vot_overlap(gt_polygon, pred_polygon,
                                        (im.shape[1], im.shape[0]))

                if b_overlap:
                    regions.append(location)
                else:  # lost
                    regions.append(2)  # code 2 marks a failure
                    lost_times += 1
                    start_frame = f + 5  # skip 5 frames
            else:  # skip
                regions.append(0)  # code 0 marks a skipped frame
            toc += cv2.getTickCount() - tic

        toc /= cv2.getTickFrequency()

        # save result
        result_dir = join(self.save_root_dir, video['name'])
        ensure_dir(result_dir)
        result_path = join(result_dir, '{:s}_001.txt'.format(video['name']))
        with open(result_path, "w") as fout:
            for x in regions:
                if isinstance(x, int):
                    # status code line
                    fout.write("{:d}\n".format(x))
                else:
                    # region line: comma separated coordinates
                    fout.write(
                        ','.join([vot_float2str("%.4f", i) for i in x]) + '\n')

        # NOTE(review): speed uses the last frame index rather than the frame
        # count, as before, so reported numbers stay comparable.
        speed = f / toc if toc > 0 else 0.0
        logger.info(
            '({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps Lost: {:d} '
            .format(v_id, video['name'], toc, speed, lost_times))

        return lost_times, speed
Пример #6
0
if __name__ == '__main__':
    # parsing
    parser = make_parser()
    parsed_args = parser.parse_args()
    # experiment config
    exp_cfg_path = osp.realpath(parsed_args.config)
    root_cfg.merge_from_file(exp_cfg_path)
    # resolve config
    root_cfg = complete_path_wt_root_in_cfg(root_cfg, ROOT_PATH)
    root_cfg = root_cfg.train
    task, task_cfg = specify_task(root_cfg)
    task_cfg.freeze()
    # log config
    log_dir = osp.join(task_cfg.exp_save, task_cfg.exp_name, "logs")
    ensure_dir(log_dir)
    logger.configure(
        handlers=[
            dict(sink=sys.stderr, level="INFO"),
            dict(sink=osp.join(log_dir, "train_log.txt"),
                 enqueue=True,
                 serialize=True,
                 diagnose=True,
                 backtrace=True,
                 level="INFO")
        ],
        extra={"common_to_all": "default"},
    )
    # backup config
    logger.info("Load experiment configuration at: %s" % exp_cfg_path)
    logger.info(