示例#1
0
def main(opts):
    """Run inference with a trained checkpoint and dump predictions to CSV.

    The training hyper-parameters are read back from
    ``{opts.train_dir}/log/hps.json`` so that the evaluation dataset, the
    collate function and the model class match the ones used for training.
    The checkpoint ``{opts.train_dir}/ckpt/model_step_{opts.ckpt}.pt`` is
    loaded into the model, ``evaluate`` is run over the whole dataset and
    each ``(id, answer)`` pair it returns is written as one line of
    ``{opts.output_dir}/results.csv``.

    Args:
        opts: option namespace. Attributes read here: ``train_dir``,
            ``img_db``, ``compressed_db``, ``txt_db``, ``batch_size``,
            ``n_workers``, ``pin_mem``, ``ckpt``, ``fp16``, ``output_dir``.

    Raises:
        ValueError: if the model type stored in the training
            hyper-parameters is not ``'paired'``, ``'paired-attn'`` or
            ``'triplet'``.
    """
    hvd.init()
    device = torch.device("cuda")  # support single GPU only

    # Use a context manager so the hyper-parameter file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(f'{opts.train_dir}/log/hps.json') as hps_file:
        train_opts = Struct(json.load(hps_file))

    # Pick dataset / collate function / model class for the trained variant.
    if 'paired' in train_opts.model:
        EvalDatasetCls = Nlvr2PairedEvalDataset
        eval_collate_fn = nlvr2_paired_eval_collate
        if train_opts.model == 'paired':
            ModelCls = UniterForNlvr2Paired
        elif train_opts.model == 'paired-attn':
            ModelCls = UniterForNlvr2PairedAttn
        else:
            raise ValueError('unrecognized model type')
    elif train_opts.model == 'triplet':
        EvalDatasetCls = Nlvr2TripletEvalDataset
        ModelCls = UniterForNlvr2Triplet
        eval_collate_fn = nlvr2_triplet_eval_collate
    else:
        raise ValueError('unrecognized model type')

    # Prepare data
    img_db = DetectFeatLmdb(opts.img_db, train_opts.conf_th, train_opts.max_bb,
                            train_opts.min_bb, train_opts.num_bb,
                            opts.compressed_db)
    txt_db = TxtTokLmdb(opts.txt_db, -1)
    dset = EvalDatasetCls(txt_db, img_db, train_opts.use_img_type)
    # An explicit command-line batch size overrides the one used in training.
    batch_size = (train_opts.val_batch_size
                  if opts.batch_size is None else opts.batch_size)
    sampler = TokenBucketSampler(dset.lens,
                                 bucket_size=BUCKET_SIZE,
                                 batch_size=batch_size,
                                 droplast=False)
    eval_dataloader = DataLoader(dset,
                                 batch_sampler=sampler,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=eval_collate_fn)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    # Prepare model
    ckpt_file = f'{opts.train_dir}/ckpt/model_step_{opts.ckpt}.pt'
    checkpoint = torch.load(ckpt_file)
    model_config = UniterConfig.from_json_file(
        f'{opts.train_dir}/log/model.json')
    model = ModelCls(model_config, img_dim=IMG_DIM)
    model.init_type_embedding()
    # strict=False: keys that do not match between checkpoint and model are
    # tolerated rather than raising.
    model.load_state_dict(checkpoint, strict=False)
    model.to(device)
    model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    results = evaluate(model, eval_dataloader, device)

    # write results
    # exist_ok=True avoids the check-then-create race of a separate
    # existence test followed by makedirs.
    os.makedirs(opts.output_dir, exist_ok=True)
    with open(f'{opts.output_dir}/results.csv', 'w') as f:
        f.writelines(f'{id_},{ans}\n' for id_, ans in results)
    print('all results written')
示例#2
0
    def load_model(self):
        """Build ``self.model`` and restore its weights when a file is given.

        A ``UniterModel`` is created from the JSON config at
        ``self.config['config']`` and wrapped in ``MemeUniter`` with
        ``self.config['n_classes']`` output classes. When ``self.model_file``
        is set, the checkpoint is loaded with ``torch.load`` and its
        ``'model_state_dict'`` entry is loaded into the wrapped model.
        When ``self.model_file`` is empty, the model keeps the weights it was
        constructed with.

        Raises:
            KeyError: if a checkpoint file was loaded but it has no
                ``'model_state_dict'`` entry.
        """
        # Load pretrained model
        if self.model_file:
            checkpoint = torch.load(self.model_file)
            LOGGER.info('Using UNITER model {}'.format(self.model_file))
        else:
            checkpoint = None

        uniter_config = UniterConfig.from_json_file(self.config['config'])
        uniter_model = UniterModel(uniter_config, img_dim=IMG_DIM)

        self.model = MemeUniter(uniter_model=uniter_model,
                                hidden_size=uniter_model.config.hidden_size,
                                n_classes=self.config['n_classes'])
        # Only restore weights when a checkpoint was actually loaded; the
        # previous empty-dict fallback always failed with KeyError on
        # 'model_state_dict' when no model file was configured.
        if checkpoint is not None:
            self.model.load_state_dict(checkpoint['model_state_dict'])