def setup_model(cfg, device=None):
    LOGGER.info("Setup model...")
    # has to be a BertConfig instance
    model_cfg = load_json(cfg.model_config)
    model_cfg = BertConfig(**model_cfg)
    # add downstream model config
    add_attr_list = [
        "num_labels", "classifier", "cls_hidden_scale", "loss_type",
    ]
    for k in add_attr_list:
        setattr(model_cfg, k, cfg[k])

    # We separate the CNN and the transformer in order to use a different
    # optimizer for each; the transformer still has a CNN layer inside,
    # used to down-sample the grid.
    LOGGER.info("setup e2e model")
    model = ClipBert(
        model_cfg, input_format=cfg.img_input_format,
        detectron2_model_cfg=cfg.detectron2_model_cfg,
        transformer_cls=ClipBertForSequenceClassification)
    if cfg.e2e_weights_path:
        LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
        load_state_dict_with_mismatch(model, cfg.e2e_weights_path)
    else:
        LOGGER.info(f"Loading cnn weights from {cfg.detectron2_weights_path}")
        LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}")
        model.load_separate_ckpt(
            cnn_weights_path=cfg.detectron2_weights_path,
            bert_weights_path=cfg.bert_weights_path)
    if cfg.freeze_cnn:
        model.freeze_cnn_backbone()
    model.to(device)

    LOGGER.info("Setup model done!")
    return model
def setup_model(cfg, device=None):
    LOGGER.info("Setup model...")
    # has to be a BertConfig instance
    model_cfg = load_json(cfg.model_config)
    model_cfg = BertConfig(**model_cfg)
    # add model-specific config; pixel random sampling is only used for pre-training
    add_attr_list = [
        "pixel_random_sampling_size",
    ]
    for k in add_attr_list:
        setattr(model_cfg, k, cfg[k])
    LOGGER.info(f"model_cfg {pprint.pformat(model_cfg.to_dict())}")

    # We separate the CNN and the transformer in order to use a different
    # optimizer for each; the transformer still has a CNN layer inside,
    # used to down-sample the grid.
    LOGGER.info("setup e2e model")
    model = ClipBert(
        model_cfg, input_format=cfg.img_input_format,
        detectron2_model_cfg=cfg.detectron2_model_cfg,
        transformer_cls=ClipBertForPreTraining)
    if cfg.e2e_weights_path:
        LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
        load_state_dict_with_mismatch(model, cfg.e2e_weights_path)
    else:
        LOGGER.info(f"Loading cnn weights from {cfg.detectron2_weights_path}")
        LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}")
        model.load_separate_ckpt(
            cnn_weights_path=cfg.detectron2_weights_path,
            bert_weights_path=cfg.bert_weights_path)
    if cfg.freeze_cnn:
        model.freeze_cnn_backbone()
    model.to(device)

    LOGGER.info("Setup model done!")
    return model
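# Hypothetical usage sketch (not part of the original files): `_demo_setup_model`
# is an illustrative name, and it assumes the Horovod launcher has already
# called hvd.init() and that `cfg` carries the attributes read by setup_model.
def _demo_setup_model(cfg):
    import torch
    import horovod.torch as hvd
    # each process binds to its local GPU, as the inference entry points do
    device = torch.device("cuda", hvd.local_rank())
    model = setup_model(cfg, device=device)
    model.eval()
    return model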
def start_inference(cfg):
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir,
        f"mc_results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
        f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}")

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg, join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing the keyword 'inference' or 'output_dir'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if k in stored_cfg and "inference" not in k and "output_dir" not in k:
            setattr(cfg, k, stored_cfg[k])

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    e2e_weights_path = join(
        cfg.output_dir, f"ckpt/model_step_{cfg.inference_model_step}.pt")
    cfg.e2e_weights_path = e2e_weights_path
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.
    val_loader = mk_msrvtt_mc_eval_dataloader(
        anno_path=cfg.inference_txt_db,
        lmdb_dir=cfg.inference_img_db,
        cfg=cfg, tokenizer=tokenizer,
    )

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f" Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    ret_results, ret_scores = inference_retrieval_mc(
        model, val_loader, cfg.inference_txt_db, cfg)

    if hvd.rank() == 0:
        save_json(cfg, join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)
        save_json(ret_results,
                  join(inference_res_dir, "mc_test_results.json"),
                  save_pretty=True)
        save_json(ret_scores,
                  join(inference_res_dir, "mc_test_scores.json"),
                  save_pretty=True)
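# Toy illustration (a sketch, not repo code; `_demo_cfg_overwrite` is a
# hypothetical name) of the cfg-overwrite rule above: training-time values win,
# except keys mentioning 'inference' or 'output_dir', which keep their
# command-line values.
def _demo_cfg_overwrite():
    from easydict import EasyDict as edict
    cfg = edict(num_frm=2, inference_n_clips=8, output_dir="/new_out")
    stored_cfg = edict(num_frm=16, inference_n_clips=1, output_dir="/old_out")
    for k in list(cfg.keys()):
        if k in stored_cfg and "inference" not in k and "output_dir" not in k:
            setattr(cfg, k, stored_cfg[k])
    assert cfg.num_frm == 16            # restored from the training args
    assert cfg.inference_n_clips == 8   # inference-time override kept
    assert cfg.output_dir == "/new_out"  # output_dir kept
    return cfg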
def inference_retrieval_mc(model, val_loader, eval_file_path, cfg, n_options=5):
    model.eval()
    pred_id2ans = dict()
    st = time.time()
    LOGGER.info(f"Evaluate retrieval MC: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")

    for batch in val_loader:
        # compile shared text part
        question_ids = batch["question_ids"]
        bsz = len(question_ids)
        del batch["question_ids"]
        mini_batch = dict()
        for k, v in batch.items():
            if k not in ["visual_inputs", "meta"]:
                mini_batch[k] = v

        # multi-clip test: scores across clips of the same video are pooled together
        # batch["visual_inputs"]  (B, T, C, H, W)
        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is evaluated
        num_clips = cfg.inference_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        for clip_idx in range(num_clips):
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg, n_options=n_options)
            logits.append(outputs["logits"].cpu())
        logits = torch.stack(logits)  # (num_clips, B, 1 or 2)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 1 or 2)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 1 or 2)
        elif pool_method == "lse":
            logits = logits.permute(1, 0, 2).contiguous()  # (B, num_clips, 1 or 2)
            logits = torch.logsumexp(
                logits, dim=1)  # torch.exp alone might be too large and unstable
        else:
            raise ValueError(
                f"Invalid value for pool_method, "
                f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if logits.shape[1] == 2:
            probs = F.softmax(logits, dim=1)[:, 1]
        else:
            probs = torch.sigmoid(logits.squeeze())  # B
        probs = probs.view(-1, n_options)  # (B, n_options)
        pred_answers = probs.max(1)[1].tolist()  # (B, )
        for qid, pred_ans in zip(question_ids, pred_answers):
            pred_id2ans[qid] = int(pred_ans)

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_mc_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    pred_id2ans,
                    join(eval_dir, f"tmp_results_mc_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        pred_id2ans = []
        for rk in range(n_gpu):
            pred_id2ans.append(
                load_json(join(eval_dir, f"tmp_results_mc_rank{rk}.json")))
        pred_id2ans = merge_dicts(pred_id2ans)
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_qa_metrics = val_loader.dataset.evaluate_qa_accuracy(
            pred_id2ans, force_same=True)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. "
            f"scores: {retrieval_qa_metrics}")
    else:
        retrieval_qa_metrics = None

    model.train()
    return pred_id2ans, retrieval_qa_metrics
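# A standalone sketch (not from the original repo; `_demo_score_agg_pooling`
# is a hypothetical name) of the three `score_agg_func` choices used above.
# Dummy tensors stand in for the stacked per-clip logits.
def _demo_score_agg_pooling():
    import torch
    num_clips, bsz, n_options = 4, 2, 5
    clip_logits = torch.randn(num_clips, bsz, n_options)
    mean_pooled = clip_logits.mean(0)       # (bsz, n_options)
    max_pooled = clip_logits.max(0)[0]      # (bsz, n_options)
    # log-sum-exp is a smooth max; more stable than exponentiating directly
    lse_pooled = torch.logsumexp(
        clip_logits.permute(1, 0, 2).contiguous(), dim=1)  # (bsz, n_options)
    return mean_pooled, max_pooled, lse_pooled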
def inference_retrieval(model, val_loader, eval_file_path, cfg):
    model.eval()
    retrieval_res = []  # list(dict): dict(vid_id=..., txt_id=..., score=...)
    st = time.time()
    eval_bsz = cfg.inference_batch_size if cfg.do_inference else cfg.eval_retrieval_batch_size
    LOGGER.info(f"Evaluate retrieval #video per GPU: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")

    for batch in val_loader:
        # each batch contains 1 video and N (=1000) captions
        n_mini_batches = math.ceil(len(batch["caption_ids"]) / eval_bsz)
        vid_id = batch["vid_id"]
        for idx in range(n_mini_batches):
            # compile shared text part
            mini_batch = dict()
            for k in ["text_input_ids", "text_input_mask", "labels"]:
                if batch[k] is not None:
                    mini_batch[k] = batch[k][idx * eval_bsz:(idx + 1) * eval_bsz]
                else:
                    mini_batch[k] = None
            caption_ids = batch["caption_ids"][idx * eval_bsz:(idx + 1) * eval_bsz]
            # bsz = len(caption_ids)
            mini_batch["n_examples_list"] = [len(caption_ids)]

            # multi-clip test: scores across clips of the same video are pooled together
            pool_method = cfg.score_agg_func
            # could be 1, where only a single clip is evaluated
            num_clips = cfg.inference_n_clips
            num_frm = cfg.num_frm
            # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
            new_visual_shape = (1, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
            visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
            logits = []
            for clip_idx in range(num_clips):
                mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
                outputs = forward_step(model, mini_batch, cfg)
                logits.append(outputs["logits"].cpu())
            logits = torch.stack(logits)  # (num_clips, B, 1 or 2)
            if pool_method == "mean":
                logits = logits.mean(0)  # (B, 1 or 2)
            elif pool_method == "max":
                logits = logits.max(0)[0]  # (B, 1 or 2)
            elif pool_method == "lse":
                logits = logits.permute(1, 0, 2).contiguous()  # (B, num_clips, 1 or 2)
                logits = torch.logsumexp(
                    logits, dim=1)  # torch.exp alone might be too large and unstable
            else:
                raise ValueError(
                    f"Invalid value for pool_method, "
                    f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

            if logits.shape[1] == 2:
                probs = F.softmax(logits, dim=1)[:, 1].tolist()
            else:
                probs = torch.sigmoid(logits.squeeze()).tolist()  # B
            for cap_id, score in zip(caption_ids, probs):
                retrieval_res.append(
                    dict(vid_id=vid_id, txt_id=cap_id, score=round(score, 4)))

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(retrieval_res,
                          join(eval_dir, f"tmp_results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        retrieval_res = []
        for rk in range(n_gpu):
            retrieval_res.extend(
                load_json(join(eval_dir, f"tmp_results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_metrics = eval_retrieval(
            retrieval_res, val_loader.dataset.gt_cap_id2vid_id,
            val_loader.dataset.id2data)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. "
            f"scores: {retrieval_metrics}")
    else:
        retrieval_metrics = None

    model.train()
    return retrieval_res, retrieval_metrics
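# Sketch with dummy tensors (not repo code; `_demo_logits_to_probs` is a
# hypothetical name) of how a matching probability is read off the classifier
# head above: a 2-way head uses the softmax probability of the positive class,
# while a single-logit head uses a sigmoid.
def _demo_logits_to_probs():
    import torch
    import torch.nn.functional as F
    two_way = torch.randn(8, 2)
    probs_2way = F.softmax(two_way, dim=1)[:, 1]     # (8,) P(match)
    one_logit = torch.randn(8, 1)
    probs_1way = torch.sigmoid(one_logit.squeeze())  # (8,) P(match)
    return probs_2way, probs_1way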
def start_inference(cfg):
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir,
        f"results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
        f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}")

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg, join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing the keyword 'inference'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if k in stored_cfg and "inference" not in k:
            setattr(cfg, k, stored_cfg[k])

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    e2e_weights_path = join(
        cfg.output_dir, f"ckpt/model_step_{cfg.inference_model_step}.pt")
    cfg.e2e_weights_path = e2e_weights_path
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.
    val_loader = mk_tgif_qa_dataloader(
        task_type=cfg.task,
        anno_path=cfg.inference_txt_db,
        lmdb_dir=cfg.inference_img_db,
        cfg=cfg, tokenizer=tokenizer,
        is_train=False,
        return_label=False)
    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
    val_loader = PrefetchLoader(val_loader, img_norm)

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f" Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    qa_results, qa_scores = validate(
        model, val_loader, cfg, global_step,
        eval_score=True)  # cfg.inference_split == "val"

    if hvd.rank() == 0:
        save_json(cfg, join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)
        save_json(qa_scores, join(inference_res_dir, "scores.json"),
                  save_pretty=True)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    qa_results,
                    join(inference_res_dir, f"results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        qa_results = []
        for rk in range(n_gpu):
            qa_results.extend(load_json(
                join(inference_res_dir, f"results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        save_json(qa_results, join(inference_res_dir, "results_all.json"))
        LOGGER.info('all results written')
def mk_tgif_qa_dataloader(task_type, anno_path, lmdb_dir, cfg, tokenizer,
                          is_train=True, return_label=True):
    """
    Returns:
        list(dict), each dict is
        action and transition:
            {"gif_name": "tumblr_nk172bbdPI1u1lr18o1_250",
             "question": "What does the butterfly do 10 or more than 10 times ?",
             "options": ["stuff marshmallow", "holds a phone towards face",
                         "fall over", "talk", "flap wings"],
             "answer": 4}
        frameqa:
            {"gif_name": "tumblr_no73q2fm0I1uuf348o1_250",
             "question": "what is being placed in the white ice cream cone ?",
             "answer": "cookie",
             "answer_type": "object"}
        msrvtt_qa:
            {"answer": "couch",
             "question": "what are three people sitting on?",
             "video_id": "video6513",
             "answer_type": "what"}
    """
    raw_datalist = load_jsonl(anno_path)
    LOGGER.info(f"Loaded data size {len(raw_datalist)}")
    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]
        LOGGER.info(f"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}")

    datalist = []
    qid = 0
    for raw_d in raw_datalist:
        d = dict(
            question=raw_d["question"],
            vid_id=raw_d["gif_name"] if "gif_name" in raw_d else raw_d["video_id"],
            answer=raw_d["answer"],  # int or str
            question_id=qid  # be careful, it is not unique across splits
        )
        qid += 1

        if task_type in ["action", "transition"]:
            d["options"] = raw_d["options"]
        elif task_type in ["frameqa", "msrvtt_qa"]:
            d["answer_type"] = raw_d["answer_type"]

        datalist.append(d)
    LOGGER.info(f"datalist {len(datalist)}")

    grouped = defaultdict(list)  # examples grouped by image/video id
    for d in datalist:
        grouped[d["vid_id"]].append(d)
    LOGGER.info(f"grouped {len(grouped)}")

    # each group has a single image with multiple questions
    group_datalist = mk_input_group(
        grouped,
        max_n_example_per_group=cfg.max_n_example_per_group if is_train else 1,  # force 1 in eval
        is_train=is_train
    )
    LOGGER.info(f"group_datalist {len(group_datalist)}")

    ans2label = load_json(cfg.ans2label_path)

    frm_sampling_strategy = cfg.frm_sampling_strategy
    if not is_train and frm_sampling_strategy == "rand":
        frm_sampling_strategy = "middle"
    dataset = ClipBertVideoQADataset(
        task_type=cfg.task,
        datalist=group_datalist,
        tokenizer=tokenizer,
        img_lmdb_dir=lmdb_dir,
        ans2label=ans2label,
        max_img_size=cfg.max_img_size,
        max_txt_len=cfg.max_txt_len,
        fps=cfg.fps,
        num_frm=cfg.num_frm,
        frm_sampling_strategy=frm_sampling_strategy,
        ensemble_n_clips=cfg.train_n_clips if is_train else cfg.inference_n_clips,
        return_label=return_label,
        is_train=is_train
    )
    LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
                f"each group {cfg.max_n_example_per_group if is_train else 1}")
    if cfg.do_inference:
        batch_size = cfg.inference_batch_size
    else:
        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
    sampler = DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank(),
        shuffle=is_train)
    vqa_collator = VideoQACollator(tokenizer=tokenizer,
                                   max_length=cfg.max_txt_len,
                                   task_type=cfg.task)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=cfg.n_workers,
                            pin_memory=cfg.pin_mem,
                            collate_fn=vqa_collator.collate_batch)
    return dataloader
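# Toy illustration of the group-by-video step above on hand-made data
# (`mk_input_group` itself is repo-specific and not reproduced here;
# `_demo_group_by_vid` is a hypothetical name).
def _demo_group_by_vid():
    from collections import defaultdict
    datalist = [
        dict(vid_id="video1", question="q0"),
        dict(vid_id="video1", question="q1"),
        dict(vid_id="video2", question="q2"),
    ]
    grouped = defaultdict(list)
    for d in datalist:
        grouped[d["vid_id"]].append(d)
    # two groups: video1 carries both of its questions, video2 carries one
    assert len(grouped) == 2 and len(grouped["video1"]) == 2
    return grouped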
def start_inference(cfg):
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir,
        f"results_{cfg.inference_split}"
        f"step_{cfg.inference_model_step}")

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg, join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing the keyword 'inference' and 'output_dir'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if (k in stored_cfg and "inference" not in k
                and k != "output_dir"):
            value = stored_cfg[k]
            # FIXME hardcoded path changes
            if isinstance(value, str) and value.startswith("/data"):
                value = value.replace("/data", "/storage")
            setattr(cfg, k, value)

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    cfg.detectron2_model_cfg = join(
        cfg.output_dir, "log/detectron2_model_cfg.yaml")
    e2e_weights_path = join(
        cfg.output_dir, f"ckpt/model_step_{cfg.inference_model_step}.pt")
    if exists(e2e_weights_path):
        cfg.e2e_weights_path = e2e_weights_path
    else:
        cfg.bert_weights_path = join(
            f"{cfg.output_dir}/ckpt",
            f"transformer_step_{cfg.inference_model_step}.pt")
        cfg.cnn_weights_path = join(
            cfg.output_dir, f"ckpt/cnn_step_{cfg.inference_model_step}.pt")
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.
    val_loader = mk_vqa_dataloader(
        anno_path=cfg.inference_txt_db,
        img_lmdb_dir=cfg.inference_img_db,
        cfg=cfg, tokenizer=tokenizer, is_train=False)
    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
    val_loader = PrefetchLoader(val_loader, img_norm)

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f" Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    vqa_results = validate(
        model, val_loader, cfg, global_step,
        eval_score=cfg.inference_split == "val")

    if hvd.rank() == 0:
        save_json(cfg, join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    vqa_results,
                    join(inference_res_dir, f"results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        vqa_results = []
        for rk in range(n_gpu):
            vqa_results.extend(
                load_json(join(inference_res_dir, f"results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        save_json(vqa_results, join(inference_res_dir, "results_all.json"))
        LOGGER.info('all results written')
def mk_vqa_dataloader(anno_path, img_lmdb_dir, cfg, tokenizer, is_train=True):
    """
    Returns a VQA dataloader; each underlying example is a dict:
        {
            "txt": str,  # question text
            "img_id": str,
            "question_id": int,
            # optional: "labels", "answer_type"
        }
    """
    if isinstance(anno_path, str):
        raw_datalist = load_jsonl(anno_path)
    else:
        raw_datalist = flat_list_of_lists([load_jsonl(p) for p in anno_path])
    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]

    datalist = []
    for raw_d in raw_datalist:
        d = dict(
            txt=raw_d["question"],
            img_id=raw_d["image_id"],
            question_id=raw_d["question_id"],
        )
        if "labels" in raw_d:  # deal with test sets
            d["labels"] = raw_d["labels"]
        if "answer_type" in raw_d:
            d["answer_type"] = raw_d["answer_type"]
        datalist.append(d)

    grouped = defaultdict(list)  # examples grouped by image/video id
    for d in datalist:
        grouped[d["img_id"]].append(d)

    # each group has a single image with multiple questions
    group_datalist = mk_input_group(
        grouped,
        max_n_example_per_group=cfg.max_n_example_per_group if is_train else 1,  # force 1 in eval
        is_train=is_train,
        example_unique_key="question_id"
    )

    ans2label = load_json(cfg.ans2label_path)

    dataset = ClipBertVQADataset(
        datalist=group_datalist,
        tokenizer=tokenizer,
        img_lmdb_dir=img_lmdb_dir,
        ans2label=ans2label,
        max_img_size=cfg.max_img_size,
        max_txt_len=cfg.max_txt_len
    )
    LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
                f"each group {cfg.max_n_example_per_group if is_train else 1}")
    if cfg.do_inference:
        batch_size = cfg.inference_batch_size
    else:
        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
    sampler = DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank(),
        shuffle=is_train)
    vqa_collator = VQACollator(tokenizer=tokenizer,
                               max_length=cfg.max_txt_len)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=cfg.n_workers,
                            pin_memory=cfg.pin_mem,
                            collate_fn=vqa_collator.collate_batch)
    return dataloader
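# Hypothetical usage sketch (`_demo_mk_vqa_dataloader` is an illustrative name;
# it mirrors the eval-loader construction in start_inference above and assumes
# `cfg` carries the same inference_* fields and tokenizer_dir).
def _demo_mk_vqa_dataloader(cfg):
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    val_loader = mk_vqa_dataloader(
        anno_path=cfg.inference_txt_db,
        img_lmdb_dir=cfg.inference_img_db,
        cfg=cfg, tokenizer=tokenizer, is_train=False)
    return val_loader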