コード例 #1
0
def video_transform(input_config, clip_input):
    """Run a clip through the resize / random-crop / to-tensor / normalize pipeline.

    Args:
        input_config: Mapping that provides the ``'mean'`` and ``'std'``
            values handed to ``transforms.Normalize``.
        clip_input: The clip to transform; it is passed unchanged to the
            composed transform.

    Returns:
        Whatever the composed transform returns for ``clip_input``.
    """
    # Build the normalizer first so a missing 'mean'/'std' key fails early.
    normalizer = transforms.Normalize(
        mean=input_config['mean'],
        std=input_config['std'],
    )

    # NOTE: CenterCrop and RandomHorizontalFlip are intentionally left out;
    # the original authors state they were not used in their paper.
    steps = [
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.ToTensor(),
        normalizer,
    ]
    pipeline = transforms.Compose(steps)

    return pipeline(clip_input)
コード例 #2
0
    # network
    if torch.cuda.is_available():
        cudnn.benchmark = True
        sym_net = torch.nn.DataParallel(sym_net).cuda()
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        sym_net = torch.nn.DataParallel(sym_net)
        criterion = torch.nn.CrossEntropyLoss()
    net = static_model(net=sym_net,
                       criterion=criterion,
                       model_prefix=args.model_prefix)
    net.load_checkpoint(epoch=args.load_epoch)

    # data iterator:
    data_root = "../dataset/{}".format(args.dataset)
    normalize = transforms.Normalize(mean=input_config['mean'],
                                     std=input_config['std'])
    val_sampler = sampler.RandomSampling(num=args.clip_length,
                                         interval=args.frame_interval,
                                         speed=[1.0, 1.0])
    val_loader = VideoIter(
        video_prefix=os.path.join(data_root, 'raw',
                                  'data'),  # change this part accordingly
        txt_list=os.path.join(
            data_root, 'raw', 'list_cvt',
            'testlist01.txt'),  # change this part accordingly
        sampler=val_sampler,
        force_color=True,
        video_transform=transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            # transforms.CenterCrop((224, 224)), # we did not use center crop in our paper
コード例 #3
0
def search_result(video_path):
    """Run a retrieval query for one video and return ranked results with AP scores.

    The function builds the network selected by the parsed command-line
    arguments, loads the requested checkpoint, prepares the query under
    ``./query/``, extracts a feature for every query clip and looks up its
    top-N matches among the features returned by ``get_feature_dict``.
    The accumulated results are also written to ``q_r.json`` in the current
    working directory after every batch.

    Args:
        video_path: Path of the query video relative to ``./static/data/``.
            NOTE(review): the value is concatenated into a filesystem path
            without validation; if it can originate from an untrusted
            client, reject ``..`` components before calling this function.

    Returns:
        dict: Maps each query sub-path yielded by the loader to a dict with

        * ``"result"`` -- the output of ``get_top_N`` for that query;
        * ``"AP"`` -- ``[ap10, ap50, ap200]``, i.e. ``cal_AP`` over the first
          10 / 50 / 200 results (all ``0`` when the query has no entry in
          the mapping returned by ``get_name_label``).

        The first item of every batch additionally carries the timing keys
        ``"time"``, ``"lmtime"`` and ``"datatime"`` (seconds).

    Raises:
        AssertionError: If CUDA is not available.
    """
    video_path = "./static/data/" + video_path
    b_time = time.time()

    # set args
    args = parser.parse_args()
    args = autofill(args)

    set_logger(log_file=args.log_file, debug_mode=args.debug_mode)
    logging.info("Start evaluation with args:\n" +
                 json.dumps(vars(args), indent=4, sort_keys=True))

    # set device states (must happen before torch touches CUDA)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpus)
    assert torch.cuda.is_available(), "CUDA is not available"

    # load dataset related configuration
    dataset_cfg = dataset.get_config(name=args.dataset)

    # create model
    sym_net, input_config = get_symbol(name=args.network, use_flow=False, **dataset_cfg)

    # network
    # NOTE(review): the assert above means the CPU branch is only reachable
    # when assertions are stripped (``python -O``); kept for that case.
    if torch.cuda.is_available():
        cudnn.benchmark = True
        sym_net = torch.nn.DataParallel(sym_net).cuda()
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        sym_net = torch.nn.DataParallel(sym_net)
        criterion = torch.nn.CrossEntropyLoss()
    net = static_model(net=sym_net,
                       criterion=criterion,
                       model_prefix=args.model_prefix)
    net.load_checkpoint(epoch=args.load_epoch)
    m_time = time.time()

    # gallery side: name -> label mapping and the stored feature vectors
    dict_name_label = get_name_label()
    Video_list, feature_list = get_feature_dict()
    all_feature = np.array(feature_list)
    d_time = time.time()

    # query side: stage the query video and write the list file for VideoIter
    get_query(video_path)
    extract_query_frame()
    data_root = "./query/"
    query_names = os.listdir(data_root + "videos")
    txt_path = "./query/list_cvt/search.txt"
    if os.path.exists(txt_path):
        os.remove(txt_path)
    with open(txt_path, "w") as f:
        # one line per query video: "<index>\t0\t<file name>"
        for idx, query_name in enumerate(query_names):
            f.write(str(idx) + "\t" + "0" + "\t" + query_name + "\n")

    normalize = transforms.Normalize(mean=input_config['mean'], std=input_config['std'])
    val_sampler = sampler.RandomSampling(num=args.clip_length,
                                         interval=args.frame_interval,
                                         speed=[1.0, 1.0])
    val_loader = VideoIter(video_prefix=os.path.join(data_root, 'videos'),
                           frame_prefix=os.path.join(data_root, 'frames'),
                           txt_list=os.path.join(data_root, 'list_cvt', 'search.txt'),
                           sampler=val_sampler,
                           force_color=True,
                           video_transform=transforms.Compose([
                               transforms.Resize((256, 256)),
                               transforms.CenterCrop((224, 224)),
                               transforms.ToTensor(),
                               normalize,
                           ]),
                           name='test',
                           return_item_subpath=True
                           )

    eval_iter = torch.utils.data.DataLoader(val_loader,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            num_workers=1,  # change this part accordingly
                                            pin_memory=True)

    net.net.eval()
    pr_time = time.time()

    total_round = 1  # change this part accordingly if you do not want an inf loop

    for _ in range(total_round):
        # results are reset per round; the last round's dict is returned
        dict_q_r = {}
        for data, _target, video_subpath in eval_iter:
            batch_start_time = time.time()
            feature = net.get_feature(data)
            feature = feature.detach().cpu().numpy()

            for i, query_subpath in enumerate(video_subpath):
                topN_re = get_top_N(Video_list, all_feature, args.topN, feature[i])
                if query_subpath in dict_name_label:
                    label = dict_name_label[query_subpath]
                    tmp_AP10 = cal_AP(topN_re[:10], label)
                    tmp_AP50 = cal_AP(topN_re[:50], label)
                    tmp_AP200 = cal_AP(topN_re[:200], label)
                else:
                    print("video is not in the database, AP=0")
                    tmp_AP10 = 0
                    tmp_AP50 = 0
                    tmp_AP200 = 0
                print(query_subpath, str(tmp_AP10), str(tmp_AP50), str(tmp_AP200))
                dict_q_r[query_subpath] = {
                    "result": topN_re,
                    "AP": [tmp_AP10, tmp_AP50, tmp_AP200],
                }
            batch_end_time = time.time()

            # Timing is attached to the first item of the batch only.
            first_entry = dict_q_r[video_subpath[0]]
            # batch processing time plus query staging / loader construction
            first_entry["time"] = batch_end_time - batch_start_time + pr_time - d_time
            # argument parsing, model construction and checkpoint loading
            first_entry["lmtime"] = m_time - b_time
            # loading the label mapping and the stored features
            first_entry["datatime"] = d_time - m_time

            # Use a context manager so the handle is closed after each dump
            # (the file is rewritten once per batch).
            with open("q_r.json", "w") as out_file:
                json.dump(dict_q_r, out_file)

    return dict_q_r
コード例 #4
0
        sym_net = torch.nn.DataParallel(sym_net)
        criterion = torch.nn.CrossEntropyLoss()
        criterion_domain = torch.nn.CrossEntropyLoss()
    net = static_model(net=sym_net,
                       criterion=criterion,
                       criterion_domain=criterion_domain,
                       DA_method=None,
                       model_prefix=args.model_prefix)
    net.load_checkpoint(epoch=args.load_epoch)

    # data iterator:
    data_root = "../dataset/{}".format(args.target_dataset)
    video_location = os.path.join(data_root, 'raw', 'test_data')

    target_mean, target_std = input_config['mean'], input_config['std']
    normalize = transforms.Normalize(mean=target_mean, std=target_std)

    val_sampler = sampler.RandomSampling(num=args.clip_length,
                                         interval=args.frame_interval,
                                         speed=[1.0, 1.0])
    val_loader = VideoIter(
        video_prefix=video_location,
        csv_list=os.path.join(data_root, 'raw', 'list_cvt', args.list_file),
        sampler=val_sampler,
        force_color=True,
        video_transform=transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),