コード例 #1
0
ファイル: test.py プロジェクト: Chris210634/Siammask
def main():
    """Evaluate a SiamMask-style tracker on a VOT or VOS benchmark.

    Parses the command line (module-level ``parser``), builds the model named
    by ``args.arch``, optionally loads weights from ``--resume``, and runs every
    video of ``args.dataset`` (or only ``args.video`` when it is non-empty)
    through ``track_vos`` (mask output) or ``track_vot`` (box tracking).
    Aggregated results are written to the ``'global'`` logger. If no video is
    evaluated, a warning is logged instead of crashing on empty result lists.

    ``args``, ``logger`` and ``v_id`` are published as module globals,
    presumably for the tracking helpers to read (TODO confirm).
    """
    global args, logger, v_id
    args = parser.parse_args()
    cfg = load_config(args)

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        # ArgumentParser.error() prints usage and exits, so `model` is always
        # bound past this point.
        parser.error('invalid architecture: {}'.format(args.arch))

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    device = torch.device('cuda' if (torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset)

    # VOS or VOT? Mask output only for the segmentation datasets and only with --mask.
    if args.dataset in ['DAVIS', 'DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    # Optional tracker hyper-parameters from the config; None when the key is absent.
    hp = cfg.get('hp')
    # Extra flag passed to track_vos; presumably enables multi-object handling
    # for these datasets (TODO confirm in track_vos).
    multi_object = args.dataset in ['DAVIS2017', 'ytb_vos']

    total_lost = 0  # VOT: sum of the `lost` values returned by track_vot
    iou_lists = []  # VOS: per-video IoU arrays; after concatenation, axis 1 lines up with `thrs`
    speed_list = []  # per-video speed, reported below as FPS

    for v_id, video in enumerate(dataset.keys(), start=1):
        # When a specific video is requested, skip all others.
        if args.video != '' and video != args.video:
            continue

        if vos_enable:
            iou_list, speed = track_vos(model, dataset[video], hp,
                                        args.mask, args.refine, multi_object, device=device)
            iou_lists.append(iou_list)
        else:
            lost, speed = track_vot(model, dataset[video], hp,
                                    args.mask, args.refine, device=device)
            total_lost += lost
        speed_list.append(speed)

    # Nothing was tracked (empty dataset, or args.video matched no sequence):
    # np.concatenate([]) would raise and np.mean([]) would yield nan.
    if not speed_list:
        logger.warning('No video was evaluated (dataset: {}, video: {!r})'.format(
            args.dataset, args.video))
        return

    # report final result
    if vos_enable:
        # `thrs` is a module-level sequence of segmentation thresholds defined elsewhere.
        mean_ious = np.mean(np.concatenate(iou_lists), axis=0)
        for thr, iou in zip(thrs, mean_ious):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))
コード例 #2
0
ファイル: tune_vot.py プロジェクト: UM-ARM-Lab/SiamMask
def main():
    """Run a randomised grid search over tracker hyper-parameters.

    Every combination of ``args.penalty_k``, ``args.window_influence``,
    ``args.lr`` and ``args.search_region`` (used as ``instance_size``) is
    merged over the config's ``hp`` section and handed to ``tune`` for each
    video of ``args.dataset``. Videos and the first three candidate lists are
    shuffled in place with NumPy's global RNG, presumably so that parallel
    workers start on different combinations. If a ``finish.flag`` file exists
    in the working directory, the search stops before the next video.

    Reads the module-level ``args`` and ``device``. Publishes the current
    video's ``ims``/``gt``/``image_files`` as module globals, presumably for
    ``tune`` to consume (TODO confirm).
    """
    global ims, gt, image_files

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    # Candidate values for each searched hyper-parameter.
    params = dict(
        penalty_k=args.penalty_k,
        window_influence=args.window_influence,
        lr=args.lr,
        instance_size=args.search_region,
    )

    # Size of the full Cartesian grid.
    num_search = 1
    for candidates in params.values():
        num_search *= len(candidates)

    print(params)
    print(num_search)

    cfg = load_config(args)
    if args.arch == 'Custom':
        from custom import Custom
        model_cls = Custom
    else:
        model_cls = models.__dict__[args.arch]
    model = model_cls(anchors=cfg['anchors'])

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    model = model.to(device)

    default_hp = cfg.get('hp', {})

    # Checkpoint file name without directory and extension: "a/b/ckpt.pth" -> "ckpt".
    checkpoint_name = args.resume.split('/')[-1].split('.')[0]
    p = {
        'network': model,
        'network_name': args.arch + '_' + checkpoint_name,
        'dataset': args.dataset,
    }

    dataset_info = load_dataset(args.dataset)
    videos = list(dataset_info.keys())
    np.random.shuffle(videos)

    for video in videos:
        print(video)
        if isfile('finish.flag'):
            return

        sequence = dataset_info[video]
        p['video'] = video
        ims = None
        image_files = sequence['image_files']
        gt = sequence['gt']

        for name in ('penalty_k', 'window_influence', 'lr'):
            np.random.shuffle(params[name])

        grid = ((k, w, r, s)
                for k in params['penalty_k']
                for w in params['window_influence']
                for r in params['lr']
                for s in params['instance_size'])
        for penalty_k, window_influence, lr, instance_size in grid:
            trial_hp = default_hp.copy()
            trial_hp.update({'penalty_k': penalty_k,
                             'window_influence': window_influence,
                             'lr': lr,
                             'instance_size': instance_size})
            p['hp'] = trial_hp
            tune(p)
コード例 #3
0
def main():
    """Evaluate a SiamMask-style tracker on a VOT or VOS benchmark.

    Parses the command line, loads the config (network structure,
    hyper-parameters, ...), and builds the ``Custom`` model. It then optionally
    loads weights from ``--resume`` and runs every video of ``args.dataset``
    (or only ``args.video`` when it is non-empty) through ``track_vos`` (mask
    output) or ``track_vot`` (box tracking). Aggregated results go to the
    ``'global'`` logger. If no video is evaluated, a warning is logged instead
    of crashing on empty result lists.
    """
    # Command-line arguments, logger and current video index are module
    # globals (presumably read by the tracking helpers - TODO confirm).
    global args, logger, v_id
    args = parser.parse_args()
    # Config file contents: mainly network structure, hyper-parameters, etc.
    cfg = load_config(args)
    # Initialise logging and optionally mirror it to a file on disk.
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    # Record the run configuration in the log.
    logger = logging.getLogger('global')
    logger.info(args)

    # setup model: only the 'Custom' architecture is supported.
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        # ArgumentParser.error() prints usage and exits the process.
        parser.error('invalid architecture: {}'.format(args.arch))
    # Load the model weights.
    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model = load_pretrain(model, args.resume)
    # Evaluation mode: dropout is disabled and batch-norm uses running statistics.
    model.eval()
    # Use the GPU when available unless --cpu was given.
    device = torch.device('cuda' if (
        torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset)

    # VOS or VOT? Mask output is only supported for these three datasets.
    if args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    # Optional tracker hyper-parameters from the config (None when absent).
    hp = cfg.get('hp')
    # For DAVIS2017 / ytb_vos, track_vos runs multi-object tracking.
    multi_object = args.dataset in ['DAVIS2017', 'ytb_vos']

    total_lost = 0  # VOT: sum of the `lost` values returned by track_vot
    iou_lists = []  # VOS: per-video IoU arrays
    speed_list = []  # per-video speed, reported below as FPS
    # Process each video.
    for v_id, video in enumerate(dataset.keys(), start=1):
        # When a specific video is requested, skip all others.
        if args.video != '' and video != args.video:
            continue
        if vos_enable:
            iou_list, speed = track_vos(
                model,
                dataset[video],
                hp,
                args.mask,
                args.refine,
                multi_object,
                device=device)
            iou_lists.append(iou_list)
        else:
            lost, speed = track_vot(model,
                                    dataset[video],
                                    hp,
                                    args.mask,
                                    args.refine,
                                    device=device)
            total_lost += lost
        speed_list.append(speed)

    # Nothing was tracked (empty dataset, or args.video matched no sequence):
    # np.concatenate([]) would raise and np.mean([]) would yield nan.
    if not speed_list:
        logger.warning('No video was evaluated (dataset: {}, video: {!r})'.format(
            args.dataset, args.video))
        return

    # report final result
    if vos_enable:
        # `thrs` is a module-level sequence of segmentation thresholds defined elsewhere.
        mean_ious = np.mean(np.concatenate(iou_lists), axis=0)
        for thr, iou in zip(thrs, mean_ious):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(
                thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))
コード例 #4
0
ファイル: test.py プロジェクト: SrikrishnaBhat/SiamMask
def main():
    """Run a SiamMask-style tracker over a VOT or VOS benchmark.

    Parses the command line, builds the ``Custom`` model, optionally loads
    weights from ``--resume``, and runs every video of ``args.dataset`` (or
    only ``args.video`` when it is non-empty) through ``track_vos`` (mask
    output) or ``track_vot`` (box tracking). This variant does not aggregate
    VOS IoU or speed. Presumably ``track_vos`` stores its own outputs (TODO
    confirm). In VOT mode the total lost count is logged at the end.

    ``args``, ``logger`` and ``v_id`` are published as module globals,
    presumably for the tracking helpers to read (TODO confirm).
    """
    global args, logger, v_id
    args = parser.parse_args()
    cfg = load_config(args)

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        # ArgumentParser.error() prints usage and exits the process.
        parser.error('invalid architecture: {}'.format(args.arch))

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    device = torch.device('cuda' if (
        torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset, args.dir_type)

    # VOS or VOT?
    if args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    # Optional tracker hyper-parameters from the config (None when absent).
    hp = cfg.get('hp')

    # VOT accumulator. It must be initialised here; otherwise
    # `total_lost += lost` raises UnboundLocalError on the first VOT video.
    total_lost = 0

    for v_id, video in enumerate(dataset.keys(), start=1):
        # When a specific video is requested, skip all others.
        if args.video != '' and video != args.video:
            continue

        if vos_enable:
            _iou_list, _speed = track_vos(
                model,
                dataset[video],
                hp,
                args.mask,
                args.refine,
                args.dataset in ['DAVIS2017', 'ytb_vos'],
                device=device)
        else:
            lost, _speed = track_vot(model,
                                     dataset[video],
                                     hp,
                                     args.mask,
                                     args.refine,
                                     device=device)
            total_lost += lost

    if not vos_enable:
        logger.info('Total Lost: {:d}'.format(total_lost))
コード例 #5
0
ファイル: tune1.py プロジェクト: dasari4321/SION
def main():
    """Grid-search tracker hyper-parameters for a ``ModelPublish`` tracker.

    For every video of ``args.dataset`` (visited in random order), each
    combination of ``args.penalty_k``, ``args.window_influence``, ``args.lr``
    and ``args.search_region`` (as ``instance_size``) is merged over
    ``default_hp`` and passed to ``tune``. If ``finish.flag`` exists in the
    working directory, the search stops before the next video.

    Reads the module-level ``args``, ``cfg`` and ``device``. Publishes the
    current video's ``ims``/``gt``/``image_files`` as module globals,
    presumably for ``tune`` (TODO confirm).
    """
    global ims, gt, image_files

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    params = {'penalty_k': args.penalty_k,
              'window_influence': args.window_influence,
              'lr': args.lr,
              'instance_size': args.search_region}

    num_search = len(params['penalty_k']) * len(params['window_influence']) * \
        len(params['lr']) * len(params['instance_size'])

    print(params)
    print(num_search)
    # `cfg` is the module-level config object, updated in place from the file.
    cfg.merge_from_file(args.config)

    # NOTE(review): no checkpoint is loaded here; ModelPublish() is used as
    # constructed (presumably it loads its own weights - TODO confirm).
    model = ModelPublish()
    model.eval()
    model = model.to(device)
    # Build the tracker once, after the model is in eval mode and on `device`.
    tracker = build_tracker(model)

    default_hp = {
        "seg_thr": 0.30,
        "penalty_k": 0.04,
        "window_influence": 0.42,
        "lr": 0.25
    }
    # default_hp is never mutated (each trial works on a copy), so this list
    # is identical every time it is printed.
    default_hp_values = [float(x) for x in default_hp.values()]

    p = dict()
    p['network'] = tracker
    # args.resume only names the run here; its weights are not loaded above.
    p['network_name'] = args.arch + '_' + args.resume.split('/')[-1].split('.')[0]
    p['dataset'] = args.dataset
    p['hp'] = default_hp.copy()
    print(default_hp_values)

    dataset_info = load_dataset(args.dataset)
    videos = list(dataset_info.keys())
    np.random.shuffle(videos)

    for video in videos:
        print(video)
        if isfile('finish.flag'):
            return

        p['video'] = video
        ims = None
        image_files = dataset_info[video]['image_files']
        gt = dataset_info[video]['gt']

        # Shuffle in place so each video visits the grid in a different order.
        np.random.shuffle(params['penalty_k'])
        np.random.shuffle(params['window_influence'])
        np.random.shuffle(params['lr'])
        for penalty_k in params['penalty_k']:
            for window_influence in params['window_influence']:
                for lr in params['lr']:
                    for instance_size in params['instance_size']:
                        p['hp'] = default_hp.copy()
                        p['hp'].update({'penalty_k': penalty_k,
                                        'window_influence': window_influence,
                                        'lr': lr,
                                        'instance_size': instance_size,
                                        })
                        tune(p)
        print(default_hp_values)
コード例 #6
0
def main():
    """Evaluate a SiamMask-style tracker on a VOT or VOS benchmark.

    Parses the command line, loads the config, and builds the ``Custom``
    model. It then optionally loads weights from ``--resume`` and runs every
    video of ``args.dataset`` (or only ``args.video`` when it is non-empty)
    through ``track_vos`` (mask/segmentation output) or ``track_vot`` (box
    tracking). Aggregated results go to the ``'global'`` logger. If no video is
    evaluated, a warning is logged instead of crashing on empty result lists.
    """
    # Module globals, presumably read by track_vos / track_vot (TODO confirm).
    global args, logger, v_id
    args = parser.parse_args()  # arguments given to this test script
    # Load the JSON config (original author note: also sets args.arch - confirm in load_config).
    cfg = load_config(args)
    print(cfg)

    init_log('global', logging.INFO)
    if args.log != "":
        # Additionally write the 'global' log to the file given by --log.
        add_file_handler('global', args.log,
                         logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    # setup model: only 'Custom' (the paper's network) is supported; any
    # other architecture is rejected.
    if args.arch == 'Custom':
        from custom import Custom
        # Anchor settings come from the 'anchors' section of the config.
        model = Custom(anchors=cfg['anchors'])
    else:
        # ArgumentParser.error() prints usage and exits the process.
        parser.error('invalid architecture: {}'.format(args.arch))

    if args.resume:  # fail early if --resume does not point to a file
        assert isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        # Presumably tolerant of key mismatches between checkpoint and model (TODO confirm).
        model = load_pretrain(
            model, args.resume)
    model.eval()
    device = torch.device('cuda' if (
        torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)

    # setup dataset: a mapping from video name to that video's data.
    dataset = load_dataset(
        args.dataset)

    # VOS or VOT? Mask output is only supported for these three datasets.
    if args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    # Optional tracker hyper-parameters from the config (None when absent).
    hp = cfg.get('hp')

    total_lost = 0  # VOT: sum of the `lost` values returned by track_vot (a count, not a loss function)
    iou_lists = []  # VOS (segmentation): per-video IoU arrays
    speed_list = []  # per-video speed, reported below as FPS

    # v_id is the 1-based video index, video is the video name.
    for v_id, video in enumerate(dataset.keys(), start=1):
        # When a specific video is requested (args.video != ''), skip all others.
        if args.video != '' and video != args.video:
            continue

        if vos_enable:  # segmentation task (mutually exclusive with tracking)
            iou_list, speed = track_vos(
                model,
                dataset[video],
                hp,
                args.mask,
                args.refine,
                args.dataset in ['DAVIS2017', 'ytb_vos'],
                device=device)
            # NOTE(review): exact type of iou_list is defined by track_vos; the
            # concatenate/mean below treats it as an array whose axis 1 lines up with `thrs`.
            iou_lists.append(iou_list)
        else:  # tracking task
            lost, speed = track_vot(model,
                                    dataset[video],
                                    hp,
                                    args.mask,
                                    args.refine,
                                    device=device)
            total_lost += lost
        speed_list.append(speed)

    # Nothing was tracked (empty dataset, or args.video matched no sequence):
    # np.concatenate([]) would raise and np.mean([]) would yield nan.
    if not speed_list:
        logger.warning('No video was evaluated (dataset: {}, video: {!r})'.format(
            args.dataset, args.video))
        return

    # report final result
    if vos_enable:
        # `thrs` is a module-level sequence of segmentation thresholds defined elsewhere.
        mean_ious = np.mean(np.concatenate(iou_lists), axis=0)
        for thr, iou in zip(thrs, mean_ious):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(
                thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))