Example #1
def build_dataloader(dataset, collate_fn, is_train, opts):
    """Build a DataLoader wrapped in a PrefetchLoader.

    Training uses TokenBucketSampler, which groups examples of similar
    token length into batches to minimize padding.
    """
    batch_size = opts.train_batch_size if is_train else opts.val_batch_size
    if is_train:
        sampler = TokenBucketSampler(
            dataset.lens,
            bucket_size=BUCKET_SIZE,
            batch_size=batch_size,
            droplast=True,  # is_train is always True on this branch
        )
        dataloader = DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=opts.n_workers,
            pin_memory=opts.pin_mem,
            collate_fn=collate_fn,
        )
    else:
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=opts.n_workers,
            shuffle=False,
            pin_memory=opts.pin_mem,
            collate_fn=collate_fn,
        )
    # PrefetchLoader overlaps host-to-GPU batch transfer with compute.
    dataloader = PrefetchLoader(dataloader)
    return dataloader
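
A minimal usage sketch for the function above. It assumes a dataset object exposing a .lens sequence of per-example token counts (required by the sampler) and an opts namespace with the fields referenced in the body; train_dset, val_dset, my_collate_fn, and every concrete value are hypothetical placeholders.

from argparse import Namespace

# Placeholder values, not the project's defaults.
opts = Namespace(train_batch_size=8192,  # token budget per training batch
                 val_batch_size=128,
                 n_workers=4,
                 pin_mem=True)

train_loader = build_dataloader(train_dset, my_collate_fn, True, opts)
val_loader = build_dataloader(val_dset, my_collate_fn, False, opts)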
Example #2
def main(opts):
    """Evaluate a trained NLVR2 checkpoint and write predictions to CSV."""
    hvd.init()
    device = torch.device("cuda")  # supports a single GPU only
    # Hyper-parameters saved at training time pick the model/dataset classes.
    train_opts = Struct(json.load(open(f'{opts.train_dir}/log/hps.json')))

    # Choose dataset, collate function, and model class by training variant.
    if 'paired' in train_opts.model:
        EvalDatasetCls = Nlvr2PairedEvalDataset
        eval_collate_fn = nlvr2_paired_eval_collate
        if train_opts.model == 'paired':
            ModelCls = UniterForNlvr2Paired
        elif train_opts.model == 'paired-attn':
            ModelCls = UniterForNlvr2PairedAttn
        else:
            raise ValueError('unrecognized model type')
    elif train_opts.model == 'triplet':
        EvalDatasetCls = Nlvr2TripletEvalDataset
        ModelCls = UniterForNlvr2Triplet
        eval_collate_fn = nlvr2_triplet_eval_collate
    else:
        raise ValueError('unrecognized model type')

    img_db = DetectFeatLmdb(opts.img_db, train_opts.conf_th, train_opts.max_bb,
                            train_opts.min_bb, train_opts.num_bb,
                            opts.compressed_db)
    txt_db = TxtTokLmdb(opts.txt_db, -1)  # -1: no text-length cap at eval
    dset = EvalDatasetCls(txt_db, img_db, train_opts.use_img_type)
    batch_size = (train_opts.val_batch_size
                  if opts.batch_size is None else opts.batch_size)
    sampler = TokenBucketSampler(dset.lens,
                                 bucket_size=BUCKET_SIZE,
                                 batch_size=batch_size,
                                 droplast=False)
    eval_dataloader = DataLoader(dset,
                                 batch_sampler=sampler,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=eval_collate_fn)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    # Prepare model
    ckpt_file = f'{opts.train_dir}/ckpt/model_step_{opts.ckpt}.pt'
    checkpoint = torch.load(ckpt_file)
    model_config = UniterConfig.from_json_file(
        f'{opts.train_dir}/log/model.json')
    model = ModelCls(model_config, img_dim=IMG_DIM)
    model.init_type_embedding()  # extra image-type embedding used by NLVR2
    model.load_state_dict(checkpoint, strict=False)
    model.to(device)
    model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')  # no-op if fp16 is off

    results = evaluate(model, eval_dataloader, device)
    # write results
    if not exists(opts.output_dir):
        os.makedirs(opts.output_dir)
    with open(f'{opts.output_dir}/results.csv', 'w') as f:
        for id_, ans in results:
            f.write(f'{id_},{ans}\n')
    print('all results written')
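
Example #2 reads the hyper-parameters, model config, and checkpoint from the training directory, so a caller only supplies paths and a step number. Below is a sketch of the argparse wiring such a script typically sits behind; the flag names are inferred from the opts fields used above, not confirmed against the original script.

from argparse import ArgumentParser

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--txt_db', required=True, help='text LMDB to evaluate')
    parser.add_argument('--img_db', required=True, help='image feature LMDB')
    parser.add_argument('--train_dir', required=True, help='dir holding log/ and ckpt/')
    parser.add_argument('--ckpt', type=int, required=True, help='training step to load')
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--compressed_db', action='store_true')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--n_workers', type=int, default=4)
    parser.add_argument('--pin_mem', action='store_true')
    main(parser.parse_args())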
Example #3
def build_dataloader(dataset, collate_fn, is_train, opts):
    """Token-bucket batching for both splits; returns the raw DataLoader."""
    if is_train:
        batch_size = opts.train_batch_size
    else:
        batch_size = opts.val_batch_size
    sampler = TokenBucketSampler(dataset.lens, bucket_size=BUCKET_SIZE,
                                 batch_size=batch_size, droplast=is_train)
    loader = DataLoader(dataset, batch_sampler=sampler,
                        num_workers=opts.n_workers, pin_memory=opts.pin_mem,
                        collate_fn=collate_fn)
    return loader
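
For intuition about the contract a batch_sampler must satisfy (yield lists of dataset indices), here is a deliberately simplified stand-in for TokenBucketSampler: visit indices in length order, then cut a new batch whenever a token budget would be exceeded. This illustrates the bucketing idea only; the real sampler is more involved.

def toy_token_bucket_batches(lens, max_tokens):
    # Visit examples in length order so each batch needs little padding.
    order = sorted(range(len(lens)), key=lens.__getitem__)
    batch, budget = [], 0
    for i in order:
        if batch and budget + lens[i] > max_tokens:
            yield batch
            batch, budget = [], 0
        batch.append(i)
        budget += lens[i]
    if batch:
        yield batch

# list(toy_token_bucket_batches([5, 3, 9, 2], max_tokens=8))
# -> [[3, 1], [0], [2]]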
Example #4
def create_dataloader(img_path, txt_path, batch_size, is_train,
                      dset_cls, collate_fn, opts):
    """Open the image/text LMDBs, build `dset_cls`, return a prefetching loader."""
    img_db = DetectFeatLmdb(img_path, opts.conf_th, opts.max_bb, opts.min_bb,
                            opts.num_bb, opts.compressed_db)
    # Cap text length only during training; -1 disables the cap at eval.
    txt_db = TxtTokLmdb(txt_path, opts.max_txt_len if is_train else -1)
    dset = dset_cls(txt_db, img_db, opts.use_img_type)
    sampler = TokenBucketSampler(dset.lens, bucket_size=BUCKET_SIZE,
                                 batch_size=batch_size, droplast=is_train)
    loader = DataLoader(dset, batch_sampler=sampler,
                        num_workers=opts.n_workers, pin_memory=opts.pin_mem,
                        collate_fn=collate_fn)
    return PrefetchLoader(loader)
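
A hedged usage sketch for create_dataloader: the paths are placeholders, and the dataset/collate names are assumed training-side counterparts of the NLVR2 'paired' eval classes referenced in Example #2, not verified names.

train_loader = create_dataloader('/db/nlvr2_img_train', '/db/nlvr2_txt_train.db',
                                 batch_size=8192, is_train=True,
                                 dset_cls=Nlvr2PairedDataset,      # assumed name
                                 collate_fn=nlvr2_paired_collate,  # assumed name
                                 opts=opts)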
Example #5

def build_dataloader(dataset, collate_fn, is_train, opts):
    """Class-balanced random sampling for training; token buckets for eval."""
    batch_size = (opts.train_batch_size if is_train else opts.val_batch_size)
    if is_train:
        train_sampler = WeightedRandomSampler(dataset.weights_by_class,
                                              len(dataset),
                                              replacement=True)
        dataloader = DataLoader(dataset,
                                sampler=train_sampler,
                                num_workers=opts.n_workers,
                                batch_size=32,  # NOTE: hardcoded, ignores opts.train_batch_size
                                pin_memory=opts.pin_mem,
                                collate_fn=collate_fn)
    else:
        sampler = TokenBucketSampler(dataset.lens,
                                     bucket_size=BUCKET_SIZE,
                                     batch_size=batch_size,
                                     droplast=False)  # is_train is always False here
        dataloader = DataLoader(dataset,
                                batch_sampler=sampler,
                                num_workers=opts.n_workers,
                                pin_memory=opts.pin_mem,
                                collate_fn=collate_fn)
    dataloader = PrefetchLoader(dataloader)
    return dataloader
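
The training branch of Example #5 assumes the dataset exposes weights_by_class, one sampling weight per example. A minimal sketch of how such a vector is commonly built (inverse class frequency); the labels attribute and the helper are assumptions, and only the WeightedRandomSampler call comes from the example.

from collections import Counter

def inverse_frequency_weights(labels):
    # Weight each example by 1 / (count of its class) so the sampler
    # draws every class with roughly equal probability.
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]

# e.g. dataset.weights_by_class = inverse_frequency_weights(dataset.labels)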
Example #6
def main(opts):
    """Distributed VQA evaluation: every rank scores a shard; rank 0 writes."""
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, rank, opts.fp16))

    hps_file = f"{opts.output_dir}/log/hps.json"
    model_opts = Struct(json.load(open(hps_file)))

    ans2label_file = f"{opts.output_dir}/ckpt/ans2label.json"
    ans2label = json.load(open(ans2label_file))
    label2ans = {label: ans for ans, label in ans2label.items()}

    # load DBs and image dirs
    eval_img_db = DetectFeatLmdb(
        opts.img_db,
        model_opts.conf_th,
        model_opts.max_bb,
        model_opts.min_bb,
        model_opts.num_bb,
        opts.compressed_db,
    )
    eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
    eval_dataset = VqaEvalDataset(len(ans2label), eval_txt_db, eval_img_db)

    # Prepare model
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f"{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt"
    checkpoint = torch.load(ckpt_file)
    model = UniterForVisualQuestionAnswering.from_pretrained(
        f"{opts.output_dir}/log/model.json",
        checkpoint,
        img_dim=IMG_DIM,
        num_answer=len(ans2label),
    )
    model.to(device)
    if opts.fp16:
        model = amp.initialize(model, enabled=True, opt_level="O2")

    sampler = TokenBucketSampler(
        eval_dataset.lens,
        bucket_size=BUCKET_SIZE,
        batch_size=opts.batch_size,
        droplast=False,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_sampler=sampler,
        num_workers=opts.n_workers,
        pin_memory=opts.pin_mem,
        collate_fn=vqa_eval_collate,
    )
    eval_dataloader = PrefetchLoader(eval_dataloader)

    val_log, results, logits = evaluate(model, eval_dataloader, label2ans,
                                        opts.save_logits)
    result_dir = f"{opts.output_dir}/results_test"
    if not exists(result_dir) and rank == 0:
        os.makedirs(result_dir)

    # Merge the per-rank result lists into one flat list.
    all_results = list(concat(all_gather_list(results)))
    if opts.save_logits:
        all_logits = {}
        for id2logit in all_gather_list(logits):
            all_logits.update(id2logit)
    if rank == 0:
        with open(f"{result_dir}/"
                  f"results_{opts.checkpoint}_all.json", "w") as f:
            json.dump(all_results, f)
        if opts.save_logits:
            np.savez(f"{result_dir}/logits_{opts.checkpoint}_all.npz",
                     **all_logits)
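
Example #6 fans evaluation out across Horovod ranks and funnels the merged results through rank 0. The written artifacts can be read back as below; the file names follow the code above, while the directory, step, and record schema are placeholders since evaluate() is not shown.

import json
import numpy as np

result_dir = 'output/results_test'  # placeholder
ckpt = 6000                         # placeholder step number

with open(f'{result_dir}/results_{ckpt}_all.json') as f:
    all_results = json.load(f)      # merged list from every rank

logits = np.load(f'{result_dir}/logits_{ckpt}_all.npz')
first_key = logits.files[0]         # keys are the ids emitted by evaluate()
print(len(all_results), first_key, logits[first_key].shape)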