Example #1
    # run inference and save the predicted masks (no gradients needed at test time)
    with torch.no_grad():
        test_and_save_mask(net, test_dataloader)

    print('Finished!')

if __name__ == '__main__':
    cudnn.benchmark = True  # let cuDNN auto-tune convolution algorithms for fixed-size inputs
    args = parse_args()

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    if args.resume is not None:
        cfg.RESUME = args.resume
    if args.weights is not None:
        cfg.MODEL = args.weights
    if args.exp_name is not None:
        cfg.EXP_NAME = args.exp_name

    print('Using config:')
    pprint.pprint(cfg)

    cfg.SAVE_DIR = os.path.join(cfg.SAVE_DIR, cfg.EXP_NAME)
    print('Output will be saved to %s.' % cfg.SAVE_DIR)
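
The script above consumes a parse_args() that is not shown in this excerpt. A minimal sketch of the interface it would need, assuming argparse; every flag and dest name below is inferred from the usage above, not taken from the original source:

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='test a trained network')
    parser.add_argument('--cfg', dest='cfg_file', default=None, type=str,
                        help='optional config file to merge into cfg')
    parser.add_argument('--set', dest='set_cfgs', default=None,
                        nargs=argparse.REMAINDER,
                        help='override config keys from the command line')
    parser.add_argument('--resume', default=None, type=str,
                        help='checkpoint to resume from')
    parser.add_argument('--weights', default=None, type=str,
                        help='trained model weights to load')
    parser.add_argument('--exp_name', default=None, type=str,
                        help='experiment name appended to cfg.SAVE_DIR')
    return parser.parse_args()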
Example #2
                        help='whether self-supervised checkpoints are used',
                        default='False',  # parsed as a string, not a bool
                        type=str)

    args = parser.parse_args()
    return args


if __name__ == '__main__':

    args = parse_args()
    print(args)

    # load the configurations
    test_cfg_file = args.test_cfg_file
    cfg_from_file(test_cfg_file)

    # load the test objects
    print('Testing with objects: ')
    print(cfg.TEST.OBJECTS)
    obj_list = cfg.TEST.OBJECTS
    with open('./datasets/ycb_video_classes.txt', 'r') as class_name_file:
        # splitlines() avoids a trailing empty entry when the file ends in a newline
        obj_list_all = class_name_file.read().splitlines()

    # pf config files, sorted so that indexing by class id lines up with the class list
    pf_config_files = sorted(glob.glob(args.pf_cfg_dir + '*.yml'))
    cfg_list = []
    for obj in obj_list:
        obj_idx = obj_list_all.index(obj)
        train_config_file = args.train_cfg_dir + '{}.yml'.format(obj)
        pf_config_file = pf_config_files[obj_idx]
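
Note that pf_config_files[obj_idx] only selects the right file if the sorted config filenames follow exactly the same order as ycb_video_classes.txt. A tiny illustration of that alignment assumption (class and file names here are hypothetical):

obj_list_all = ['002_master_chef_can', '003_cracker_box', '004_sugar_box']
pf_config_files = ['pf/002_master_chef_can.yml',   # as returned by sorted(glob.glob(...))
                   'pf/003_cracker_box.yml',
                   'pf/004_sugar_box.yml']

obj_idx = obj_list_all.index('003_cracker_box')            # -> 1
assert pf_config_files[obj_idx].endswith('003_cracker_box.yml')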
Example #3
def main():
    # load config
    cfg_from_file(args.config)

    dataset_root = os.path.join('dataset', args.dataset)

    # create model and load weights (load to CPU first, then move to GPU)
    net = ModelBuilder()
    checkpoint = torch.load(args.model, map_location='cpu')
    if 'state_dict' in checkpoint:
        net.load_state_dict(checkpoint['state_dict'])
    else:
        net.load_state_dict(checkpoint)
    net.cuda().eval()
    # create dataset
    dataset = DatasetFactory.create_dataset(name=args.dataset,
                                            dataset_root=dataset_root,
                                            load_img=False)

    model_name = args.save_name
    total_lost = 0
    if args.dataset in ['VOT2016', 'VOT2018', 'VOT2019']:
        # restart tracking
        for v_idx, video in enumerate(dataset):
            if args.video != '':
                # only evaluate the requested video
                if video.name != args.video:
                    continue
            frame_counter = 0
            lost_number = 0
            toc = 0
            pred_bboxes = []
            for idx, (img, gt_bbox) in enumerate(video):
                tic = cv2.getTickCount()
                if idx == frame_counter:
                    cx, cy, w, h = get_axis_aligned_bbox(np.array(gt_bbox))
                    target_pos = np.array([cx, cy])
                    target_sz = np.array([w, h])
                    state = CGACD_init(img, target_pos, target_sz, net)
                    pred_bbox = cxy_wh_2_rect(state['target_pos'],
                                              state['target_sz'])
                    pred_bboxes.append(1)  # VOT protocol: 1 marks the init frame
                elif idx > frame_counter:
                    state = CGACD_track(state, img)
                    pred_bbox = cxy_wh_2_rect(state['target_pos'],
                                              state['target_sz'])
                    pred_polygon = [
                        pred_bbox[0], pred_bbox[1],
                        pred_bbox[0] + pred_bbox[2], pred_bbox[1],
                        pred_bbox[0] + pred_bbox[2],
                        pred_bbox[1] + pred_bbox[3], pred_bbox[0],
                        pred_bbox[1] + pred_bbox[3]
                    ]
                    overlap = vot_overlap(gt_bbox, pred_polygon,
                                          (img.shape[1], img.shape[0]))
                    if overlap > 0:
                        # not lost
                        pred_bboxes.append(pred_bbox)
                    else:
                        # lost object
                        pred_bboxes.append(2)
                        frame_counter = idx + 5  # skip 5 frames
                        lost_number += 1
                else:
                    pred_bboxes.append(0)  # VOT protocol: 0 marks skipped frames
                toc += cv2.getTickCount() - tic
                if idx == 0:
                    cv2.destroyAllWindows()
                if args.vis and idx > frame_counter:
                    target_pos = state['target_pos']
                    target_sz = state['target_sz']
                    cv2.rectangle(img, (int(target_pos[0] - target_sz[0] / 2),
                                        int(target_pos[1] - target_sz[1] / 2)),
                                  (int(target_pos[0] + target_sz[0] / 2),
                                   int(target_pos[1] + target_sz[1] / 2)),
                                  (0, 255, 0), 3)
                    cv2.putText(img, str(idx), (40, 40),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2,
                                cv2.LINE_AA)
                    cv2.imshow(video.name, img)
                    cv2.moveWindow(video.name, 100, 100)
                    key = cv2.waitKey(1)
                    if key == 27:
                        break
            toc /= cv2.getTickFrequency()
            # save results
            video_path = os.path.join('result', args.dataset, model_name,
                                      'baseline', video.name)
            if not os.path.isdir(video_path):
                os.makedirs(video_path)
            result_path = os.path.join(video_path,
                                       '{}_001.txt'.format(video.name))
            with open(result_path, 'w') as f:
                for x in pred_bboxes:
                    if isinstance(x, int):
                        f.write("{:d}\n".format(x))
                    else:
                        f.write(','.join([vot_float2str("%.4f", i)
                                          for i in x]) + '\n')
            print(
                '({:3d}) Video: {:12s} Time: {:4.1f}s Speed: {:3.1f}fps Lost: {:d}'
                .format(v_idx + 1, video.name, toc, idx / toc, lost_number))
            total_lost += lost_number
        print("{:s} total lost: {:d}".format(model_name, total_lost))
    else:
        # OPE tracking
        for v_idx, video in enumerate(dataset):
            if args.video != '':
                # only evaluate the requested video
                if video.name != args.video:
                    continue
            toc = 0
            pred_bboxes = []
            track_times = []
            for idx, (img, gt_bbox) in enumerate(video):
                tic = cv2.getTickCount()
                if idx == 0:
                    if 'OTB' in args.dataset:
                        # OTB ground truth uses 1-based rect coordinates
                        target_pos, target_sz = rect1_2_cxy_wh(gt_bbox)
                    else:
                        cx, cy, w, h = get_axis_aligned_bbox(np.array(gt_bbox))
                        target_pos = np.array([cx, cy])
                        target_sz = np.array([w, h])
                    state = CGACD_init(img, target_pos, target_sz, net)
                    if 'OTB' in args.dataset:
                        pred_bbox = cxy_wh_2_rect1(state['target_pos'],
                                                   state['target_sz'])
                    else:
                        pred_bbox = cxy_wh_2_rect(state['target_pos'],
                                                  state['target_sz'])
                    pred_bboxes.append(pred_bbox)
                else:
                    state = CGACD_track(state, img)
                    pred_bbox = cxy_wh_2_rect(state['target_pos'],
                                              state['target_sz'])
                    pred_bboxes.append(pred_bbox)
                toc += cv2.getTickCount() - tic
                track_times.append(
                    (cv2.getTickCount() - tic) / cv2.getTickFrequency())
                if idx == 0:
                    cv2.destroyAllWindows()
                if args.vis and idx > 0:
                    target_pos = state['target_pos']
                    target_sz = state['target_sz']
                    cv2.rectangle(img, (int(target_pos[0] - target_sz[0] / 2),
                                        int(target_pos[1] - target_sz[1] / 2)),
                                  (int(target_pos[0] + target_sz[0] / 2),
                                   int(target_pos[1] + target_sz[1] / 2)),
                                  (0, 255, 0), 3)
                    cv2.putText(img, str(idx), (40, 40),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2,
                                cv2.LINE_AA)
                    cv2.imshow(video.name, img)
                    cv2.moveWindow(video.name, 100, 100)
                    key = cv2.waitKey(1)
                    if key == 27:
                        break
            toc /= cv2.getTickFrequency()
            if args.dataset == 'GOT-10k':
                video_path = os.path.join('result', args.dataset, model_name,
                                          video.name)
                if not os.path.isdir(video_path):
                    os.makedirs(video_path)
                result_path = os.path.join(video_path,
                                           '{}_001.txt'.format(video.name))
                with open(result_path, 'w') as f:
                    for x in pred_bboxes:
                        f.write(','.join([str(i) for i in x]) + '\n')
                result_path = os.path.join(video_path,
                                           '{}_time.txt'.format(video.name))
                with open(result_path, 'w') as f:
                    for x in track_times:
                        f.write("{:.6f}\n".format(x))
            else:
                model_path = os.path.join('result', args.dataset, model_name)
                if not os.path.isdir(model_path):
                    os.makedirs(model_path)
                result_path = os.path.join(model_path,
                                           '{}.txt'.format(video.name))
                with open(result_path, 'w') as f:
                    for x in pred_bboxes:
                        f.write(','.join([str(i) for i in x]) + '\n')
            print('({:3d}) Video: {:12s} Time: {:5.1f}s Speed: {:3.1f}fps'.format(
                v_idx + 1, video.name, toc, idx / toc))
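
The tracking loop above leans on a few box-format helpers from the repo. The sketches below show only the conventions the call sites imply; the exact bodies are assumptions, and in particular the repo's get_axis_aligned_bbox may use a different area-preserving reduction for rotated VOT polygons:

import numpy as np

def cxy_wh_2_rect(pos, sz):
    # (cx, cy), (w, h) -> 0-based (x, y, w, h)
    return np.array([pos[0] - sz[0] / 2, pos[1] - sz[1] / 2, sz[0], sz[1]])

def get_axis_aligned_bbox(region):
    # VOT ground truth may be an 8-value polygon; reduce it to cx, cy, w, h
    region = np.asarray(region)
    if region.size == 8:
        xs, ys = region[0::2], region[1::2]
        cx, cy = xs.mean(), ys.mean()
        w, h = xs.max() - xs.min(), ys.max() - ys.min()
    else:
        x, y, w, h = region
        cx, cy = x + w / 2, y + h / 2
    return cx, cy, w, h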
Example #4
                        dest='dis_dir',
                        help='relative dir of the distraction set',
                        default='../coco/val2017',
                        type=str)

    args = parser.parse_args()
    return args


if __name__ == '__main__':

    args = parse_args()
    print(args)

    cfg_file = '{}{}.yml'.format(args.cfg_dir, args.obj)
    cfg_from_file(cfg_file)

    print('Using config:')
    pprint.pprint(cfg)

    # report which GPU device will be used
    print('GPU device {:d}'.format(args.gpu_id))

    cfg.MODE = 'TRAIN'
    print(cfg.TRAIN.OBJECTS)
    print(cfg.TRAIN.RENDER_SZ)
    print(cfg.TRAIN.INPUT_IM_SIZE)

    if args.obj_ctg == 'ycb':
        model_path = './cad_models'
        dataset_train = ycb_multi_render_dataset(
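
cfg_from_file() appears in every example here: it loads a YAML file and merges it into the module-level cfg. A minimal sketch of that behavior, assuming PyYAML and an EasyDict-style config; the recursive merge is the common Fast R-CNN-style pattern and is not verified against this particular codebase:

import yaml
from easydict import EasyDict as edict

def _merge_into(src, dst):
    # recursively overwrite default cfg entries with values from the YAML file
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            _merge_into(value, dst[key])
        else:
            dst[key] = value

def cfg_from_file(filename):
    with open(filename, 'r') as f:
        yaml_cfg = edict(yaml.safe_load(f))
    _merge_into(yaml_cfg, cfg)  # cfg is the global default config object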
Example #5
def main(args):
    cfg_from_file(args.config)
    cfg.save_name = args.save_name
    cfg.save_path = args.save_path
    cfg.resume_file = args.resume_file
    cfg.config = args.config
    cfg.batch_size = args.batch_size
    cfg.num_workers = args.num_workers
    save_path = join(args.save_path, args.save_name)
    if not exists(save_path):
        makedirs(save_path)
    resume_file = args.resume_file
    init_log('global', logging.INFO)
    add_file_handler('global', os.path.join(save_path, 'logs.txt'),
                     logging.INFO)
    logger.info("Version Information: \n{}\n".format(commit()))
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))
    start_epoch = 0

    model = ModelBuilder().cuda()
    if cfg.backbone.pretrained:
        load_pretrain(model.backbone,
                      join('pretrained_net', cfg.backbone.pretrained))

    train_dataset = Datasets()
    val_dataset = Datasets(is_train=False)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=False,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=False,
                                             drop_last=True)

    if resume_file:
        if isfile(resume_file):
            logger.info("=> loading checkpoint '{}'".format(resume_file))
            model, start_epoch = restore_from(model, resume_file)
            start_epoch = start_epoch + 1
            # replay the per-epoch shuffles so the data order matches an uninterrupted run
            for i in range(start_epoch):
                train_loader.dataset.shuffle()
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, start_epoch - 1))
        else:
            logger.info("=> no checkpoint found at '{}'".format(resume_file))

    ngpus = torch.cuda.device_count()
    is_dataparallel = False
    if ngpus > 1:
        model = torch.nn.DataParallel(model, list(range(ngpus))).cuda()
        is_dataparallel = True

    if is_dataparallel:
        optimizer, lr_scheduler = build_opt_lr(model.module, start_epoch)
    else:
        optimizer, lr_scheduler = build_opt_lr(model, start_epoch)

    logger.info(lr_scheduler)
    logger.info("model prepare done")

    if args.log:
        writer = SummaryWriter(comment=args.save_name)

    for epoch in range(start_epoch, cfg.train.epoch):
        train_loader.dataset.shuffle()
        # rebuild the optimizer/scheduler when backbone layers are unfrozen
        # or the pretraining stage ends
        if (epoch == np.array(cfg.backbone.unfix_steps)).sum() > 0 \
                or epoch == cfg.train.pretrain_epoch:
            if is_dataparallel:
                optimizer, lr_scheduler = build_opt_lr(model.module, epoch)
            else:
                optimizer, lr_scheduler = build_opt_lr(model, epoch)
        lr_scheduler.step(epoch)
        record_dict_train = train(train_loader, model, optimizer, epoch)
        record_dict_val = validate(val_loader, model, epoch)
        message = 'Train Epoch: [{0}]\t'.format(epoch)
        for k, v in record_dict_train.items():
            message += '{name:s} {loss:.4f}\t'.format(name=k, loss=v)
        logger.info(message)
        message = 'Val Epoch: [{0}]\t'.format(epoch)
        for k, v in record_dict_val.items():
            message += '{name:s} {loss:.4f}\t'.format(name=k, loss=v)
        logger.info(message)

        if args.log:
            for k, v in record_dict_train.items():
                writer.add_scalar('train/' + k, v, epoch)
            for k, v in record_dict_val.items():
                writer.add_scalar('val/' + k, v, epoch)
        # checkpoint after every epoch; unwrap DataParallel before saving
        state_dict = model.module.state_dict() if is_dataparallel \
            else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': state_dict,
                'optimizer': optimizer.state_dict(),
                'cfg': cfg
            }, epoch, save_path)

    if args.log:
        writer.close()  # flush pending TensorBoard events before exiting
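
save_checkpoint() and restore_from() are helpers from the repo; the sketches below capture only what the call sites above require, and the on-disk filename is a hypothetical choice:

import os
import torch

def save_checkpoint(state, epoch, save_path):
    # one checkpoint file per epoch
    filename = os.path.join(save_path, 'checkpoint_e{:d}.pth'.format(epoch))
    torch.save(state, filename)

def restore_from(model, ckpt_file):
    # load weights and return the epoch the checkpoint was saved at
    ckpt = torch.load(ckpt_file, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    return model, ckpt['epoch']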