Ejemplo n.º 1
0
        help=r"completed epoch's number, latest or one model path")

    return parser


if __name__ == '__main__':
    # parsing: command-line arguments come from make_parser()
    parser = make_parser()
    parsed_args = parser.parse_args()
    # experiment config: merge the config file found at the absolute path of
    # parsed_args.config into root_cfg. root_cfg is read here before it is
    # rebound below, so it must already be defined at module level.
    exp_cfg_path = osp.realpath(parsed_args.config)
    root_cfg.merge_from_file(exp_cfg_path)
    # resolve config: complete the config's paths against ROOT_PATH
    # (presumably relative -> absolute; confirm in complete_path_wt_root_in_cfg),
    # then keep only the `train` section
    root_cfg = complete_path_wt_root_in_cfg(root_cfg, ROOT_PATH)
    root_cfg = root_cfg.train
    # pick the task name and its sub-config, then freeze the sub-config
    # (presumably making it read-only for the rest of the script)
    task, task_cfg = specify_task(root_cfg)
    task_cfg.freeze()
    # log config: logs live under <exp_save>/<exp_name>/logs;
    # ensure_dir presumably creates that directory when it is missing
    log_dir = osp.join(task_cfg.exp_save, task_cfg.exp_name, "logs")
    ensure_dir(log_dir)
    logger.configure(
        handlers=[
            dict(sink=sys.stderr, level="INFO"),
            dict(sink=osp.join(log_dir, "train_log.txt"),
                 enqueue=True,
                 serialize=True,
                 diagnose=True,
                 backtrace=True,
                 level="INFO")
        ],
        extra={"common_to_all": "default"},
Ejemplo n.º 2
0
    # -r/--resume: what to resume training from. Per the help text this is a
    # completed epoch's number, "latest", or a model path; "" by default.
    parser.add_argument(
        '-r',
        '--resume',
        default="",
        help=r"completed epoch's number, latest or one model path")
    parsed_args = parser.parse_args()

    # experiment config: merge the config file found at the absolute path of
    # parsed_args.config into root_cfg.
    # NOTE(review): root_cfg is read here and rebound below. That only works if
    # this code runs at module level (inside a function it would raise
    # UnboundLocalError) -- confirm the enclosing scope.
    exp_cfg_path = osp.realpath(parsed_args.config)
    root_cfg.merge_from_file(exp_cfg_path)

    # resolve config
    # NOTE(review): paths are completed against the current working directory,
    # so the script presumably has to be launched from the project root.
    ROOT_PATH = os.getcwd()
    root_cfg = complete_path_wt_root_in_cfg(root_cfg, ROOT_PATH)
    root_cfg = root_cfg.train
    task, task_cfg = specify_task(root_cfg)  # task is expected to be "track"
    task_cfg.freeze()

    # log config: logs live under <exp_save>/<exp_name>/logs
    log_dir = osp.join(task_cfg.exp_save, task_cfg.exp_name, "logs")
    ensure_dir(log_dir)
    logger.configure(
        handlers=[
            dict(sink=sys.stderr, level="INFO"),
            dict(sink=osp.join(log_dir, "train_log.txt"),
                 enqueue=True,
                 serialize=True,
                 diagnose=True,
                 backtrace=True,
                 level="INFO")
        ],
Ejemplo n.º 3
0
    # Optional override: a model path given on the command line replaces the
    # SiamTrack pretrained-weights path configured for the test/track model.
    if args.model_path:
        root_cfg.test.track.model.task_model.SiamTrack.pretrain_model_path = args.model_path

    logger.info("Load experiment configuration at: %s" % exp_cfg_path)
    # NOTE(review): config paths are completed against the current working
    # directory, so the script presumably must be launched from the project root.
    ROOT_PATH = os.getcwd()

    # turn the relative paths in the config into absolute paths
    root_cfg = complete_path_wt_root_in_cfg(root_cfg,
                                            ROOT_PATH)

    # keep the test section of the config (root_cfg['test']); it holds
    # per-task sub-sections such as root_cfg['track'] and root_cfg['vos']
    root_cfg = root_cfg.test
    # select the task (track or vos) and its sub-config
    task, task_cfg = specify_task(root_cfg)

    task_cfg.freeze()

    # presumably required because the testers start worker processes that use
    # CUDA, which needs the 'spawn' start method -- confirm
    torch.multiprocessing.set_start_method('spawn', force=True)

    # NOTE(review): "track" is hard-coded in the three builder calls below and
    # `task` is ignored; if specify_task selects "vos", the track builders are
    # still used with the vos sub-config. Confirm this script is track-only.
    # build the SiamFC++ model
    model = model_builder.build("track", task_cfg.model)
    # build the pipeline (configured with the pipeline hyper-parameters
    # from task_cfg.pipeline) around the model
    pipeline = pipeline_builder.build("track", task_cfg.pipeline,
                                      model)
    # build the testers (an iterable) and run each of them
    testers = tester_builder("track", task_cfg.tester, "tester", pipeline)

    for tester in testers:
        tester.test()
Ejemplo n.º 4
0
def _create_video_stream(args):
    """Open the frame source named by ``args.video`` and return it.

    * ``"webcam"``: camera 0 through ``cv2.VideoCapture``, requesting MJPG.
    * A path that is not an existing regular file: ``ImageFileVideoStream``
      (frame image files), with ``args.start_index`` as its ``init_counter``.
    * Any other path: a video file through ``cv2.VideoCapture``.
    """
    if args.video == "webcam":
        logger.info("Starting video stream...")
        vs = cv2.VideoCapture(0)
        vs.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'))
    elif not osp.isfile(args.video):
        logger.info("Starting from video frame image files...")
        vs = ImageFileVideoStream(args.video, init_counter=args.start_index)
    else:
        logger.info("Starting from video file...")
        vs = cv2.VideoCapture(args.video)
    return vs


def _create_video_writer(args, vs, resize_ratio):
    """Create the output writer for ``args.output``, or return None if unset.

    * An existing directory gets an ``ImageFileVideoWriter``.
    * Any other path gets an MJPG ``cv2.VideoWriter`` at 25 fps. Its frame
      size is the source frame size (read from ``vs``) scaled by
      ``resize_ratio``, matching the resized frames written in ``main``.
    """
    if not args.output:
        return None
    if osp.isdir(args.output):
        return ImageFileVideoWriter(args.output)
    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    width = vs.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = vs.get(cv2.CAP_PROP_FRAME_HEIGHT)
    return cv2.VideoWriter(
        args.output, fourcc, 25,
        (int(width * resize_ratio), int(height * resize_ratio)))


def main(args):
    """Run the interactive single-object tracking demo on a video source.

    The function proceeds in these steps:

    1. Merge ``args.config`` into ``cfg`` and resolve its paths against
       ``ROOT_PATH``.
    2. Build the model and pipeline of the task selected from the ``test``
       section, and move the pipeline to ``args.device``.
    3. Read frames from ``args.video`` (``"webcam"``, a frame-image directory
       or a video file).

    Once a target box is known (``args.init_bbox`` with 4 values, or drawn
    after pressing ``s``), each frame is tracked. The predicted box, the
    tracker update time and a 128x128 template thumbnail are drawn on it.
    Frames are scaled by ``args.resize``, shown in a window unless
    ``args.dump_only``, and written to ``args.output`` when given.

    Keys: ``s`` selects an object, ``c`` clears the target, ``q`` quits.
    The loop also stops as soon as a frame cannot be read (e.g. end of file).
    Capture, writer and windows are released even if an error occurs.
    """
    root_cfg = cfg
    root_cfg.merge_from_file(args.config)
    logger.info("Load experiment configuration at: %s" % args.config)

    # resolve config: complete paths against ROOT_PATH, keep the test section
    root_cfg = complete_path_wt_root_in_cfg(root_cfg, ROOT_PATH)
    root_cfg = root_cfg.test
    task, task_cfg = specify_task(root_cfg)
    task_cfg.freeze()
    window_name = task_cfg.exp_name
    # build model and pipeline, then move the pipeline to the chosen device
    model = model_builder.build(task, task_cfg.model)
    pipeline = pipeline_builder.build(task, task_cfg.pipeline, model)
    pipeline.set_device(torch.device(args.device))

    init_box = None
    template = None
    if len(args.init_bbox) == 4:
        init_box = args.init_bbox

    resize_ratio = args.resize
    dump_only = args.dump_only

    vs = _create_video_stream(args)
    vw = None
    try:
        vw = _create_video_writer(args, vs, resize_ratio)

        # loop over sequence
        while vs.isOpened():
            key = 255
            ret, frame = vs.read()
            logger.debug("frame: {}".format(ret))
            if not ret:
                # A failed read (end of a video file, broken source) leaves no
                # usable frame, while isOpened() can keep returning True:
                # without this break the loop would spin forever and the key
                # handlers below would index a missing frame.
                logger.info("No frame could be read, stopping.")
                break

            if template is not None:
                time_a = time.time()
                rect_pred = pipeline.update(frame)
                # time the tracker update only, not logging/drawing
                time_cost = time.time() - time_a
                logger.debug(rect_pred)
                bbox_pred = tuple(map(int, xywh2xyxy(rect_pred)))
                show_frame = frame.copy()
                cv2.putText(show_frame,
                            "track cost: {:.4f} s".format(time_cost),
                            (128, 20), cv2.FONT_HERSHEY_COMPLEX, font_size,
                            (0, 0, 255), font_width)
                cv2.rectangle(show_frame, bbox_pred[:2], bbox_pred[2:],
                              (0, 255, 0))
                # paste the 128x128 template thumbnail in the top-left corner
                show_frame[:128, :128] = template
            else:
                show_frame = frame
            show_frame = cv2.resize(
                show_frame,
                (int(show_frame.shape[1] * resize_ratio),
                 int(show_frame.shape[0] * resize_ratio)))
            if not dump_only:
                cv2.imshow(window_name, show_frame)
            if vw is not None:
                vw.write(show_frame)

            # poll the keyboard unless running unattended, i.e. with an
            # initial box already set AND an output writer
            if (init_box is None) or (vw is None):
                logger.debug("Press key s to select object.")
                key = cv2.waitKey(30) & 0xFF
            logger.debug("key: {}".format(key))
            if key == ord("q"):
                break
            elif key == ord("s"):
                # select the bounding box of the object to track (press ENTER
                # or SPACE after drawing the ROI); empty selections are ignored
                logger.debug("Select object to track")
                box = cv2.selectROI(window_name,
                                    frame,
                                    fromCenter=False,
                                    showCrosshair=True)
                if box[2] > 0 and box[3] > 0:
                    init_box = box
            elif key == ord("c"):
                logger.debug(
                    "init_box/template released, press key s to select object.")
                init_box = None
                template = None

            # (re)initialise the tracker as soon as a box exists without a
            # template; init_box is (x, y, w, h)
            if (init_box is not None) and (template is None):
                template = cv2.resize(
                    frame[int(init_box[1]):int(init_box[1] + init_box[3]),
                          int(init_box[0]):int(init_box[0] + init_box[2])],
                    (128, 128))
                pipeline.init(frame, init_box)
                logger.debug(
                    "pipeline initialized with bbox : {}".format(init_box))
    finally:
        vs.release()
        if vw is not None:
            vw.release()
        cv2.destroyAllWindows()
Ejemplo n.º 5
0

if __name__ == '__main__':
    # parsing
    parser = make_parser()
    parsed_args = parser.parse_args()

    # experiment config: merge the config file found at the absolute path of
    # parsed_args.config into the module-level root_cfg
    exp_cfg_path = os.path.realpath(parsed_args.config)
    root_cfg.merge_from_file(exp_cfg_path)
    logger.info("Load experiment configuration at: %s" % exp_cfg_path)

    # resolve config: complete paths against ROOT_PATH, keep the `test`
    # section and select the task. The task config is deliberately left
    # unfrozen here, presumably because the search loop below deep-copies it
    # and hands each copy to hpo.sample_and_update_hps for modification.
    root_cfg = complete_path_wt_root_in_cfg(root_cfg, ROOT_PATH)
    root_cfg = root_cfg.test
    task, task_cfg_origin = specify_task(root_cfg)

    # hpo config: a YAML file (loaded with yaml.safe_load) whose
    # ["test"][<task>] entry holds the search settings, including "exp_save".
    # parse_hp_path_and_range presumably turns it into per-hyper-parameter
    # (config path, value range) schedules -- confirm in the hpo module.
    with open(parsed_args.hpo_config, "r") as f:
        hpo_cfg = yaml.safe_load(f)
    hpo_cfg = hpo_cfg["test"][task]
    hpo_schedules = hpo.parse_hp_path_and_range(hpo_cfg)

    # CSV path for this search: <hpo exp_save>/hpo_<exp_name>.csv
    # (presumably where each sampled configuration's results are recorded)
    csv_file = os.path.join(hpo_cfg["exp_save"],
                            "hpo_{}.csv".format(task_cfg_origin["exp_name"]))

    while True:
        task_cfg = deepcopy(task_cfg_origin)
        hpo_exp_dict = hpo.sample_and_update_hps(task_cfg, hpo_schedules)
        if task == "track":
            testers = build_siamfcpp_tester(task_cfg)