Example #1
File: test.py Project: xurantju/densecap
def get_dataset(args):
    # process text
    if args.val_data_folder == 'validation':
        text_proc, raw_data = get_vocab_and_sentences(args.dataset_file,
                                                      args.max_sentence_len)
    else:
        raw_data = {}
        text_proc, _ = get_vocab_and_sentences(
            '/data6/users/xuran7/myresearch/dense_videocap/third_party/densecap/data/anet/anet_annotations_trainval.json',
            args.max_sentence_len)
        with open(args.dataset_file, 'r') as data_file:
            data_all = json.load(data_file)
        # annotation keys drop the first two characters of each video id
        # (e.g. a 'v_' prefix), so strip them here as well
        for vid in data_all:
            raw_data[vid[2:]] = dict(subset='testing', annotations={})
    # Create the dataset and data loader instance
    test_dataset = ANetTestDataset(args.feature_root,
                                   args.slide_window_size,
                                   text_proc,
                                   raw_data,
                                   args.val_data_folder,
                                   learn_mask=args.learn_mask)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=anet_test_collate_fn)

    return test_loader, text_proc
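
The non-'validation' branch above implies that args.dataset_file for the test split is a flat JSON list of prefixed video ids, since each entry is indexed positionally and sliced with [2:] before being used as an annotation key. A tiny self-contained illustration of that format (file name and ids are hypothetical):

import json

with open('test_ids.json', 'w') as f:
    json.dump(['v_abc123', 'v_def456'], f)  # hypothetical prefixed ids

with open('test_ids.json', 'r') as f:
    data_all = json.load(f)

# the [2:] slice strips the two-character prefix, as in the branch above
raw_data = {vid[2:]: dict(subset='testing', annotations={}) for vid in data_all}
print(raw_data)  # {'abc123': {...}, 'def456': {...}}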
Example #2
def get_dataset(args):
    # process text
    text_proc, raw_data = get_vocab_and_sentences(args.dataset_file,
                                                  args.max_sentence_len)

    # Create the dataset and data loader instance
    train_dataset = ANetDataset(args,
                                args.train_data_folder,
                                text_proc,
                                raw_data,
                                test=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    valid_dataset = ANetDataset(args,
                                args.val_data_folder,
                                text_proc,
                                raw_data,
                                test=False)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)
    return train_loader, valid_loader
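
For reference, here is a minimal sketch of the argument object these get_dataset variants consume, built with argparse.Namespace. Only the attribute names are taken from the snippets above; every value is a hypothetical placeholder.

from argparse import Namespace

# hypothetical placeholder values; attribute names mirror what
# get_dataset reads from args in the examples above
args = Namespace(
    dataset_file='data/anet_annotations_trainval.json',
    max_sentence_len=20,
    train_data_folder='training',
    val_data_folder='validation',
    batch_size=32,
    num_workers=4,
)
# train_loader, valid_loader = get_dataset(args)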
Example #3
def get_dataset(args):
    # process text
    text_proc, raw_data = get_vocab_and_sentences(args.dataset_file,
                                                  args.max_sentence_len)

    # Create the dataset and data loader instance
    train_dataset = ANetDataset(args,
                                args.train_data_folder,
                                text_proc,
                                raw_data,
                                test=False)
    train_sampler = None

    # batch size forced to be 1 here, and batch is implemented in training with periodical zero_grad
    train_loader = DataLoader(train_dataset,
                              batch_size=1,
                              shuffle=True,
                              num_workers=args.num_workers)

    if args.calc_pos_neg:
        from model.loss_func import add_pem_cls_num, add_pem_reg_num, add_tem_num
        bm_mask = get_mask(args.temporal_scale, args.max_duration)
        num_pos_neg = [0] * 7

        for i in range(len(train_dataset)):
            (sentence, img_feat, label_confidence,
             label_start, label_end, _) = train_dataset[i]

            num_pos_neg[0:3] = add_pem_reg_num(label_confidence, bm_mask,
                                               num_pos_neg[0:3])
            num_pos_neg[3:5] = add_pem_cls_num(label_confidence, bm_mask,
                                               num_pos_neg[3:5])
            num_pos_neg[5:7] = add_tem_num(label_start, label_end,
                                           num_pos_neg[5:7])
        np.savetxt('results/num_pos_neg.txt', np.array(num_pos_neg))
    else:
        num_pos_neg = list(np.loadtxt('results/num_pos_neg.txt'))

    valid_dataset = ANetDataset(args,
                                args.val_data_folder,
                                text_proc,
                                raw_data,
                                test=False)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=args.num_workers)

    return train_loader, valid_loader, text_proc, train_sampler, num_pos_neg
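
The comment in this example notes that batch_size is forced to 1 and batching is emulated in training "with periodical zero_grad", i.e. gradient accumulation. A self-contained sketch of that pattern is below; the linear model, loss, and step counts are stand-ins, not the repo's actual model.

import torch
import torch.nn as nn

model = nn.Linear(8, 1)  # stand-in for the real captioning model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accum_steps = 32         # hypothetical effective batch size

for step in range(128):            # stand-in for iterating train_loader
    x = torch.randn(1, 8)          # batch_size=1, matching the loader above
    # scale the loss so the update averages over the virtual batch
    loss = model(x).pow(2).mean() / accum_steps
    loss.backward()                # gradients accumulate across iterations
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()      # the "periodical zero_grad"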
Example #4
def get_dataset(args):
    # process text
    text_proc, raw_data = get_vocab_and_sentences(args.dataset_file, args.max_sentence_len)

    # Create the dataset and data loader instance
    test_dataset = ANetTestDataset(args.feature_root,
                                   args.slide_window_size,
                                   text_proc, raw_data, args.val_data_folder,
                                   learn_mask=args.learn_mask)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False, num_workers=args.num_workers,
                             collate_fn=anet_test_collate_fn)

    return test_loader, text_proc
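
Examples #1 and #4 build an inference-only loader (shuffle=False with a custom collate function). The usual way to consume such a loader is a no-grad evaluation loop; the sketch below is self-contained with a stand-in model and dataset, since the real batch layout is defined by anet_test_collate_fn.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(8, 1)                      # stand-in for the real model
dataset = TensorDataset(torch.randn(16, 8))  # stand-in for ANetTestDataset
loader = DataLoader(dataset, batch_size=4, shuffle=False)

model.eval()                  # disable dropout/batch-norm updates
with torch.no_grad():         # no autograd bookkeeping at inference time
    for (feats,) in loader:
        outputs = model(feats)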
Example #5
def get_dataset(args):
    # process text
    text_proc, raw_data = get_vocab_and_sentences(args.dataset_file,
                                                  args.max_sentence_len)

    # Create the dataset and data loader instance
    train_dataset = ANetDataset(
        args.feature_root,
        args.train_data_folder,
        args.slide_window_size,
        args.dur_file,
        args.kernel_list,
        text_proc,
        raw_data,
        args.pos_thresh,
        args.neg_thresh,
        args.stride_factor,
        args.dataset,
        save_samplelist=args.save_train_samplelist,
        load_samplelist=args.load_train_samplelist,
        sample_listpath=args.train_samplelist_path,
    )

    # dist parallel, optional
    args.distributed = args.world_size > 1
    if args.distributed and args.cuda:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                rank=args.rank,
                                world_size=args.world_size)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=(train_sampler is None),
                              sampler=train_sampler,
                              num_workers=args.num_workers,
                              collate_fn=anet_collate_fn)

    valid_dataset = ANetDataset(args.feature_root,
                                args.val_data_folder,
                                args.slide_window_size,
                                args.dur_file,
                                args.kernel_list,
                                text_proc,
                                raw_data,
                                args.pos_thresh,
                                args.neg_thresh,
                                args.stride_factor,
                                args.dataset,
                                save_samplelist=args.save_valid_samplelist,
                                load_samplelist=args.load_valid_samplelist,
                                sample_listpath=args.valid_samplelist_path)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.valid_batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=anet_collate_fn)

    return train_loader, valid_loader, text_proc, train_sampler
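
One reason this variant returns train_sampler: when a DistributedSampler is in use, the training loop is expected to call set_epoch each epoch so the shuffle order differs across epochs. That is standard PyTorch behavior, not specific to this repo; the single-process sketch below exists only to make the pattern runnable.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# single-process group just so the sketch runs on its own
dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)

dataset = TensorDataset(torch.arange(8).float())
train_sampler = DistributedSampler(dataset)
train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler)

for epoch in range(2):
    train_sampler.set_epoch(epoch)  # reseed so each epoch shuffles differently
    for (batch,) in train_loader:
        pass  # training step goes here

dist.destroy_process_group()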
Example #6
def get_dataset(args):
    # process text
    text_proc, train_raw_data, val_raw_data = get_vocab_and_sentences(args.train_dataset_file, verbose=False)

    print("len vocab:", len(text_proc.vocab))

    # Create the dataset and data loader instance
    en_train_dataset = ANetDataset(args.feature_root,
                                   args.train_data_folder,
                                   text_proc, train_raw_data,
                                   language="en",
                                   save_samplelist=args.save_train_samplelist,
                                   load_samplelist=args.load_train_samplelist,
                                   sample_listpath=args.train_samplelist_path,
                                   verbose=False)
    ch_train_dataset = ANetDataset(args.feature_root,
                                   args.train_data_folder,
                                   text_proc, train_raw_data,
                                   language="ch",
                                   save_samplelist=args.save_train_samplelist,
                                   load_samplelist=args.load_train_samplelist,
                                   sample_listpath=args.train_samplelist_path,
                                   verbose=False)

    print("size of English train dataset:", len(en_train_dataset))
    print("size of Chinese train dataset:", len(ch_train_dataset))

    # distributed data parallel, optional
    args.distributed = args.world_size > 1
    if args.distributed and torch.cuda.is_available():
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)
        # sample over the English dataset; the same sampler is reused for the
        # Chinese loader below, which assumes the two datasets are index-aligned
        train_sampler = torch.utils.data.distributed.DistributedSampler(en_train_dataset)
    else:
        train_sampler = None

    en_train_loader = DataLoader(en_train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=(train_sampler is None), sampler=train_sampler,
                                 num_workers=args.num_workers,
                                 collate_fn=anet_collate_fn)

    ch_train_loader = DataLoader(ch_train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=(train_sampler is None), sampler=train_sampler,
                                 num_workers=args.num_workers,
                                 collate_fn=anet_collate_fn)

    en_valid_dataset = ANetDataset(args.feature_root,
                                   args.val_data_folder,
                                   text_proc, val_raw_data,
                                   language="en",
                                   dset="validation",
                                   save_samplelist=args.save_valid_samplelist,
                                   load_samplelist=args.load_valid_samplelist,
                                   sample_listpath=args.valid_samplelist_path,
                                   verbose=False)

    en_valid_loader = DataLoader(en_valid_dataset,
                                 batch_size=args.valid_batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=anet_collate_fn)
   
    ch_valid_dataset = ANetDataset(args.feature_root,
                                   args.val_data_folder,
                                   text_proc, val_raw_data,
                                   language="ch",
                                   dset="validation",
                                   save_samplelist=args.save_valid_samplelist,
                                   load_samplelist=args.load_valid_samplelist,
                                   sample_listpath=args.valid_samplelist_path,
                                   verbose=False)

    ch_valid_loader = DataLoader(ch_valid_dataset,
                                 batch_size=args.valid_batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=anet_collate_fn)

    return ({"entrain": en_train_loader, "chtrain": ch_train_loader,
             "envalid": en_valid_loader, "chvalid": ch_valid_loader},
            text_proc, train_sampler)