Exemplo n.º 1
0
def main():
    """Evaluate a trained FPN checkpoint on the validation split of a dataset.

    Reads command-line options via ``parse_args()``, selects the validation
    imdb name and anchor settings for ``args.dataset``, loads
    ``<load_dir>/<net>/<dataset>/<exp_name>/config.yaml`` and the
    ``fpn_<checksession>_<checkepoch>.pth`` checkpoint from that directory,
    runs ``validate`` over the split, pickles the detections to
    ``detections.pkl`` in the output directory, then evaluates them and
    prints the elapsed time.

    Raises:
        NotImplementedError: if ``args.dataset`` or ``args.net`` is not
            supported.
        FileNotFoundError: if the checkpoint directory does not exist.
    """
    args = parse_args()

    print('Called with args:')
    print(args)

    # Exact dataset name -> (validation imdb name, anchor scales).
    val_datasets = {
        "pascal_voc": ("voc_2007_test", '[8, 16, 32]'),
        "pascal_voc_0712": ("voc_0712_test", '[8, 16, 32]'),
        "coco": ("coco_2014_minival", '[4, 8, 16, 32]'),
        "imagenet": ("imagenet_val", '[8, 16, 32]'),
        "vg": ("vg_150-50-50_minival", '[4, 8, 16, 32]'),
    }
    if args.dataset in val_datasets:
        args.imdbval_name, anchor_scales = val_datasets[args.dataset]
    elif 'mot_2017' in args.dataset or 'mot19_cvpr' in args.dataset:
        # NOTE(review): args.imdbval_name is not assigned in this branch, so
        # it must already be provided by parse_args(); confirm.
        anchor_scales = '[4, 8, 16, 32]'
    else:
        raise NotImplementedError
    set_cfgs = ['ANCHOR_SCALES', anchor_scales, 'ANCHOR_RATIOS', '[0.5,1,2]']

    # Validate the backbone before any data is loaded. Previously an unknown
    # net dropped into pdb and then crashed with a NameError on `FPN`.
    resnet_depths = {'res50': 50, 'res101': 101, 'res152': 152}
    if args.net not in resnet_depths:
        raise NotImplementedError(f"Network is not defined: {args.net}")

    input_dir = os.path.join(args.load_dir, args.net, args.dataset,
                             args.exp_name)
    if not os.path.exists(input_dir):
        raise FileNotFoundError(
            'There is no input directory for loading network from ' +
            input_dir)
    load_name = os.path.join(input_dir,
                             f"fpn_{args.checksession}_{args.checkepoch}.pth")

    # The config saved next to the checkpoint is applied first; the dataset
    # anchor overrides are applied on top of it.
    cfg_file = os.path.join(input_dir, 'config.yaml')
    cfg_from_file(cfg_file)
    cfg_from_list(set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)
    np.random.seed(cfg.RNG_SEED)

    # data
    imdb, roidb, ratio_list, ratio_index = combined_roidb(
        args.imdbval_name, False)
    imdb.competition_mode(on=True)
    print('{:d} roidb entries'.format(len(roidb)))
    output_dir = get_output_dir(imdb, args.exp_name)
    dataset = roibatchLoader(roidb,
                             ratio_list,
                             ratio_index,
                             1,
                             imdb.num_classes,
                             training=False,
                             normalize=False)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True)

    # network
    print("load checkpoint %s" % (load_name))
    FPN = FPNResNet(imdb.classes, resnet_depths[args.net], pretrained=False)
    FPN.create_architecture()
    # Load tensors onto the CPU. The model is still on the CPU here and is
    # moved to the GPU below only when --cuda is set, so a GPU-saved
    # checkpoint can also be evaluated on a CPU-only machine.
    checkpoint = torch.load(load_name, map_location='cpu')
    FPN.load_state_dict(checkpoint['model'])
    print('load model successfully!')

    if args.cuda:
        cfg.CUDA = True
        FPN.cuda()
    elif torch.cuda.is_available():
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    start = time.time()
    det_file = os.path.join(output_dir, 'detections.pkl')

    all_boxes = validate(FPN,
                         dataloader,
                         imdb,
                         vis=args.vis,
                         cuda=args.cuda,
                         soft_nms=args.soft_nms,
                         score_thresh=args.score_thresh)

    # Persist the raw detections before evaluating, so they survive a
    # failure in the evaluation step.
    os.makedirs(output_dir, exist_ok=True)
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, output_dir)

    end = time.time()
    print("test time: %0.4fs" % (end - start))
Exemplo n.º 2
0
            '[4, 8, 16, 32, 64]', 'MAX_NUM_GT_BOXES', '100'
        ]
    elif args.dataset == "mot19_cvpr_seq_4":
        args.imdb_name = "mot19_cvpr_seq_train_4"
        args.imdbval_name = "mot19_cvpr_seq_val_4"
        set_cfgs = [
            'FPN_ANCHOR_SCALES', '[32, 64, 128, 256, 512]', 'FPN_FEAT_STRIDES',
            '[4, 8, 16, 32, 64]', 'MAX_NUM_GT_BOXES', '100'
        ]
    else:
        raise NotImplementedError

    # load config from pre file
    if args.pre_file is not None:
        cfg_file = args.pre_file
        cfg_from_file(cfg_file)

    # load changes from current config file
    cfg_file = f"cfgs/{args.net}{'_ls' if args.lscale else ''}.yml"
    cfg_from_file(cfg_file)

    # load changes from set_cfg list
    cfg_from_list(set_cfgs)
    cfg.CUDA = args.cuda

    print('Using config:')
    pprint.pprint(cfg)

    # set seeds and make deterministic
    torch.backends.cudnn.fastest = False
    torch.backends.cudnn.benchmark = False