def main(opts):
    """Run NLVR2 inference from a training checkpoint and write a CSV.

    The training hyper-parameters are read from
    ``{opts.train_dir}/log/hps.json`` and decide which dataset class,
    collate function and model class are used.  The checkpoint
    ``{opts.train_dir}/ckpt/model_step_{opts.ckpt}.pt`` is loaded into the
    model, ``evaluate`` is run over the evaluation data loader, and every
    ``(id, answer)`` pair it yields is written as one ``id,answer`` line to
    ``{opts.output_dir}/results.csv``.

    Args:
        opts: option namespace.  Attributes read here: ``train_dir``,
            ``img_db``, ``compressed_db``, ``txt_db``, ``batch_size``,
            ``n_workers``, ``pin_mem``, ``ckpt``, ``fp16`` and
            ``output_dir``.

    Raises:
        ValueError: if the ``model`` entry of the training
            hyper-parameters is not ``'paired'``, ``'paired-attn'`` or
            ``'triplet'``.
    """
    hvd.init()
    device = torch.device("cuda")  # support single GPU only

    # Read the hyper-parameters inside a context manager so the file
    # handle is closed deterministically instead of being leaked.
    with open(f'{opts.train_dir}/log/hps.json') as hps_file:
        train_opts = Struct(json.load(hps_file))

    # Pick the dataset / model / collate function matching the model type
    # that was used for training.
    if train_opts.model == 'paired':
        EvalDatasetCls = Nlvr2PairedEvalDataset
        ModelCls = UniterForNlvr2Paired
        eval_collate_fn = nlvr2_paired_eval_collate
    elif train_opts.model == 'paired-attn':
        EvalDatasetCls = Nlvr2PairedEvalDataset
        ModelCls = UniterForNlvr2PairedAttn
        eval_collate_fn = nlvr2_paired_eval_collate
    elif train_opts.model == 'triplet':
        EvalDatasetCls = Nlvr2TripletEvalDataset
        ModelCls = UniterForNlvr2Triplet
        eval_collate_fn = nlvr2_triplet_eval_collate
    else:
        raise ValueError('unrecognized model type')

    # Image features use the same filtering settings as training.
    img_db = DetectFeatLmdb(opts.img_db,
                            train_opts.conf_th, train_opts.max_bb,
                            train_opts.min_bb, train_opts.num_bb,
                            opts.compressed_db)
    # NOTE(review): -1 presumably disables the text-length limit --
    # confirm against TxtTokLmdb.
    txt_db = TxtTokLmdb(opts.txt_db, -1)
    dset = EvalDatasetCls(txt_db, img_db, train_opts.use_img_type)

    # An explicit command-line batch size overrides the training-time
    # validation batch size.
    batch_size = (train_opts.val_batch_size if opts.batch_size is None
                  else opts.batch_size)
    sampler = TokenBucketSampler(dset.lens, bucket_size=BUCKET_SIZE,
                                 batch_size=batch_size, droplast=False)
    eval_dataloader = DataLoader(dset, batch_sampler=sampler,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=eval_collate_fn)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    # Prepare model
    ckpt_file = f'{opts.train_dir}/ckpt/model_step_{opts.ckpt}.pt'
    checkpoint = torch.load(ckpt_file)
    model_config = UniterConfig.from_json_file(
        f'{opts.train_dir}/log/model.json')
    model = ModelCls(model_config, img_dim=IMG_DIM)
    # The type embedding is initialised before the weights are loaded,
    # as in the original statement order.
    model.init_type_embedding()
    # strict=False: key mismatches between checkpoint and model are
    # tolerated rather than raised (PyTorch load_state_dict semantics).
    model.load_state_dict(checkpoint, strict=False)
    model.to(device)
    model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    results = evaluate(model, eval_dataloader, device)

    # write results
    # exist_ok=True avoids the check-then-create race of a separate
    # existence test followed by makedirs.
    os.makedirs(opts.output_dir, exist_ok=True)
    with open(f'{opts.output_dir}/results.csv', 'w') as f:
        for id_, ans in results:
            f.write(f'{id_},{ans}\n')
    print('all results written')
def load_model(self):
    """Build ``self.model`` and restore weights when a checkpoint is set.

    A ``UniterModel`` is created from the JSON config file named by
    ``self.config['config']`` and wrapped in ``MemeUniter`` with
    ``self.config['n_classes']`` output classes.  When ``self.model_file``
    is truthy, the checkpoint is loaded with ``torch.load`` and its
    ``'model_state_dict'`` entry is loaded into the wrapped model.  When
    it is falsy, the model keeps the weights it was constructed with.

    Raises:
        KeyError: if a checkpoint file is given but does not contain a
            ``'model_state_dict'`` entry.
    """
    # Load pretrained model
    if self.model_file:
        checkpoint = torch.load(self.model_file)
        LOGGER.info('Using UNITER model {}'.format(self.model_file))
    else:
        # No checkpoint configured; there is nothing to restore later.
        checkpoint = None

    uniter_config = UniterConfig.from_json_file(self.config['config'])
    uniter_model = UniterModel(uniter_config, img_dim=IMG_DIM)

    self.model = MemeUniter(uniter_model=uniter_model,
                            hidden_size=uniter_model.config.hidden_size,
                            n_classes=self.config['n_classes'])

    # Only restore weights when a checkpoint was actually loaded.
    # Indexing an empty placeholder dict here raised KeyError whenever
    # no model file was configured.
    if checkpoint is not None:
        self.model.load_state_dict(checkpoint['model_state_dict'])