def validate(model, val_loader, eval_loader, cfg, train_global_step, eval_filepath):
    """Run ITM validation on `val_loader`, then full retrieval eval on `eval_loader`.

    use eval_score=False when doing inference on test sets where answers are not available

    Args:
        model: the model to evaluate; switched to eval() here and back to train() before logging.
        val_loader: loader yielding batches with keys "caption_ids", "labels", plus model inputs.
        eval_loader: loader passed through to `inference_retrieval` for retrieval metrics.
        cfg: config object; only `cfg.debug` is read here.
        train_global_step: current training step (unused in this body; kept for caller compatibility).
        eval_filepath: annotation file path forwarded to `inference_retrieval`.

    Side effects: logs scalars to TB_LOGGER and LOGGER on rank 0; no return value.
    NOTE: the three `all_gather_list` calls are collectives — every rank must reach
    them in the same order, so keep them outside any rank-conditional code.
    """
    model.eval()
    loss = 0.
    n_ex = 0
    n_corrects = 0
    st = time.time()
    debug_step = 5  # in debug mode, stop after this many batches
    for val_step, batch in enumerate(val_loader):
        # forward pass; caption_ids are only needed for retrieval, not for the model
        del batch["caption_ids"]
        outputs = forward_step(model, batch, cfg)
        targets = batch['labels']
        # loss may be a tensor (sum over examples) or absent/non-tensor -> count as 0
        loss += outputs["loss"].sum().item() if isinstance(
            outputs["loss"], torch.Tensor) else 0
        n_ex += len(targets)
        if outputs["logits"].shape[1] == 2:
            # two-class logits: argmax prediction
            n_corrects += (outputs["logits"].max(
                dim=-1)[1] == targets).sum().item()
        else:
            # single-logit case: sigmoid threshold at 0.5.
            # Only column 0 is compared — presumably the first caption per group
            # is the positive one; verify against the dataset construction.
            predictions = (torch.sigmoid(outputs["logits"]) > 0.5).long()
            predictions = predictions.view(outputs["loss"].shape[0], -1)
            targets = targets.view(outputs["loss"].shape[0], -1)
            matched = predictions[:, 0].squeeze() == targets[:, 0].squeeze()
            n_corrects += matched.sum().item()
        if cfg.debug and val_step >= debug_step:
            break

    # Sum partial statistics across all Horovod ranks (collectives — same order on all ranks).
    loss = sum(all_gather_list(loss))
    n_ex = sum(all_gather_list(n_ex))
    n_corrects = sum(all_gather_list(n_corrects))

    # Full retrieval evaluation; metrics are only non-None on rank 0.
    _, retrieval_metrics = inference_retrieval(model, eval_loader, eval_filepath, cfg)

    model.train()

    if hvd.rank() == 0:
        # average loss for each example
        acc = float(n_corrects / n_ex)
        val_log = {'valid/loss': float(loss / n_ex), 'valid/acc': acc}
        for ret_type, ret_m in retrieval_metrics.items():
            val_log.update({
                f"valid/{ret_type}_{k}": round(v, 4)
                for k, v in ret_m.items()
            })
        TB_LOGGER.log_scalar_dict(val_log)
        LOGGER.info(f"validation finished in {int(time.time() - st)} seconds."
                    f"itm_acc: {acc}. Retrieval res {retrieval_metrics}")
def inference_retrieval_mc(model, val_loader, eval_file_path, cfg, n_options=5):
    """Multiple-choice retrieval inference: pick 1 of `n_options` captions per question.

    Each rank scores its shard of the data; per-rank answer dicts are written to
    disk and merged on rank 0 (file-based gather, retried because remote blob
    storage fails occasionally).

    Args:
        model: model under evaluation; restored to train() before returning.
        val_loader: loader whose batches contain "question_ids", "visual_inputs",
            "n_examples_list" and text inputs; `val_loader.dataset` must provide
            `evaluate_qa_accuracy`.
        eval_file_path: annotation file path, used only to name the results dir.
        cfg: config; reads score_agg_func, inference_n_clips, num_frm, output_dir.
        n_options: number of caption choices per question (default 5).

    Returns:
        (pred_id2ans, retrieval_qa_metrics); metrics are None on non-zero ranks.
    """
    model.eval()
    pred_id2ans = dict()
    st = time.time()
    LOGGER.info(f"Evaluate retrieval MC: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")

    for batch in val_loader:
        # compile shared text part
        question_ids = batch["question_ids"]
        bsz = len(question_ids)
        del batch["question_ids"]
        mini_batch = dict()
        for k, v in batch.items():
            # visual inputs are sliced per clip below; "meta" is not a model input
            if k not in ["visual_inputs", "meta"]:
                mini_batch[k] = v

        # multi-frame test, scores across frames of the same video will be pooled together
        # batch["visual_inputs"] (B, T, C, H, W)
        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is evaluated
        num_clips = cfg.inference_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        for clip_idx in range(num_clips):
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg, n_options=n_options)
            logits.append(outputs["logits"].cpu())
        logits = torch.stack(logits)  # (num_frm, B, 1 or 2)
        # Pool per-clip scores into a single score per example.
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 1 or 2)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 1 or 2)
        elif pool_method == "lse":
            logits = logits.permute(
                1, 0, 2).contiguous()  # (B, num_frm, 5), pooling will be done in CE
            logits = torch.logsumexp(
                logits, dim=1)  # torch.exp alone might be too large and unstable
        else:
            raise ValueError(
                f"Invalid value for pool_method, "
                f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        # Convert logits to a positive-match probability per candidate.
        if logits.shape[1] == 2:
            probs = F.softmax(logits, dim=1)[:, 1]
        else:
            probs = torch.sigmoid(logits.squeeze())  # B
        # Regroup flat candidate scores into (B questions, n_options) and take argmax.
        probs = probs.view(-1, n_options)  # (B, 5)
        pred_answers = probs.max(1)[1].tolist()  # (B, )
        for qid, pred_ans in zip(question_ids, pred_answers):
            pred_id2ans[qid] = int(pred_ans)

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_mc_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    pred_id2ans,
                    join(eval_dir, f"tmp_results_mc_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1
    # dummy sync
    _ = None
    all_gather_list(_)
    # join results: rank 0 reads every rank's temp file and merges the dicts
    if n_gpu > 1 and hvd.rank() == 0:
        pred_id2ans = []
        for rk in range(n_gpu):
            pred_id2ans.append(
                load_json(join(eval_dir, f"tmp_results_mc_rank{rk}.json")))
        pred_id2ans = merge_dicts(pred_id2ans)
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_qa_metrics = val_loader.dataset.evaluate_qa_accuracy(
            pred_id2ans, force_same=True)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. scores: {retrieval_qa_metrics}"
        )
    else:
        retrieval_qa_metrics = None

    model.train()
    return pred_id2ans, retrieval_qa_metrics
def inference_retrieval(model, val_loader, eval_file_path, cfg):
    """Score every (video, caption) pair for text-to-video retrieval.

    Each batch holds one video and its full caption pool; captions are scored in
    mini-batches of `eval_bsz`. Per-rank results are written to disk and joined
    on rank 0 (file-based gather with retries for flaky remote storage).

    Args:
        model: model under evaluation; restored to train() before returning.
        val_loader: loader whose batches contain "caption_ids", "vid_id",
            "visual_inputs", text inputs and "labels"; `val_loader.dataset` must
            provide `gt_cap_id2vid_id` and `id2data` for metric computation.
        eval_file_path: annotation file path, used only to name the results dir.
        cfg: config; reads inference_batch_size / eval_retrieval_batch_size,
            score_agg_func, inference_n_clips, num_frm, output_dir, do_inference.

    Returns:
        (retrieval_res, retrieval_metrics); metrics are None on non-zero ranks.
    """
    model.eval()
    retrieval_res = []  # list(dict): dict(vid_id=..., txt_id=..., score=...)
    st = time.time()
    eval_bsz = cfg.inference_batch_size if cfg.do_inference else cfg.eval_retrieval_batch_size
    LOGGER.info(f"Evaluate retrieval #video per GPU: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")
    for batch in val_loader:
        # each batch contains 1 video and N (=1000) captions
        n_mini_batches = math.ceil(len(batch["caption_ids"]) / eval_bsz)
        vid_id = batch["vid_id"]
        for idx in range(n_mini_batches):
            # compile shared text part: slice the caption pool for this mini-batch
            mini_batch = dict()
            for k in ["text_input_ids", "text_input_mask", "labels"]:
                if batch[k] is not None:
                    mini_batch[k] = batch[k][idx * eval_bsz:(idx + 1) * eval_bsz]
                else:
                    mini_batch[k] = None
            caption_ids = batch["caption_ids"][idx * eval_bsz:(idx + 1) * eval_bsz]
            # bsz = len(caption_ids)
            mini_batch["n_examples_list"] = [len(caption_ids)]
            # multi-frame test, scores across frames of the same video will be pooled together
            pool_method = cfg.score_agg_func
            # could be 1, where only a single clip is evaluated
            num_clips = cfg.inference_n_clips
            num_frm = cfg.num_frm
            # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
            # leading dim is 1 because each batch holds a single video
            new_visual_shape = (1, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
            visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
            logits = []
            for clip_idx in range(num_clips):
                mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
                outputs = forward_step(model, mini_batch, cfg)
                logits.append(outputs["logits"].cpu())
            logits = torch.stack(logits)  # (num_frm, B, 1 or 2)
            # Pool per-clip scores into a single score per caption.
            if pool_method == "mean":
                logits = logits.mean(0)  # (B, 1 or 2)
            elif pool_method == "max":
                logits = logits.max(0)[0]  # (B, 1 or 2)
            elif pool_method == "lse":
                logits = logits.permute(1, 0, 2).contiguous(
                )  # (B, num_frm, 5), pooling will be done in CE
                logits = torch.logsumexp(
                    logits, dim=1)  # torch.exp alone might be too large and unstable
            else:
                raise ValueError(
                    f"Invalid value for pool_method, "
                    f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")
            # Probability of a positive (video, caption) match.
            if logits.shape[1] == 2:
                probs = F.softmax(logits, dim=1)[:, 1].tolist()
            else:
                probs = torch.sigmoid(logits.squeeze()).tolist()  # B
            for cap_id, score in zip(caption_ids, probs):
                retrieval_res.append(
                    dict(vid_id=vid_id, txt_id=cap_id, score=round(score, 4)))

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(retrieval_res,
                          join(eval_dir, f"tmp_results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1
    # dummy sync
    _ = None
    all_gather_list(_)
    # join results: rank 0 concatenates every rank's temp result list
    if n_gpu > 1 and hvd.rank() == 0:
        retrieval_res = []
        for rk in range(n_gpu):
            retrieval_res.extend(
                load_json(join(eval_dir, f"tmp_results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_metrics = eval_retrieval(retrieval_res,
                                           val_loader.dataset.gt_cap_id2vid_id,
                                           val_loader.dataset.id2data)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. scores: {retrieval_metrics}"
        )
    else:
        retrieval_metrics = None
    model.train()
    return retrieval_res, retrieval_metrics
def start_inference(cfg):
    """Entry point for TGIF-QA-style inference: restore a checkpoint, run
    validation on the inference split, then save per-rank and merged results.

    Reads paths/steps from `cfg` (inference_txt_db, inference_img_db,
    inference_model_step, ...), overwrites most cfg keys from the training run's
    stored args (skipping keys containing "inference"), and writes all outputs
    under `inference_res_dir`. Rank 0 performs directory creation and the final
    merge; the `all_gather_list(None)` calls are barriers that keep ranks in step.
    """
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir,
        f"results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
        f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}"
    )

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg, join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing the keyword 'inference'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if k in stored_cfg and "inference" not in k:
            setattr(cfg, k, stored_cfg[k])

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    e2e_weights_path = join(
        cfg.output_dir, f"ckpt/model_step_{cfg.inference_model_step}.pt")
    cfg.e2e_weights_path = e2e_weights_path
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(
        model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.
    val_loader = mk_tgif_qa_dataloader(
        task_type=cfg.task,
        anno_path=cfg.inference_txt_db,
        lmdb_dir=cfg.inference_img_db,
        cfg=cfg, tokenizer=tokenizer,
        is_train=False,
        return_label=False
    )
    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
    val_loader = PrefetchLoader(val_loader, img_norm)

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f" Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    qa_results, qa_scores = validate(
        model, val_loader, cfg, global_step,
        eval_score=True)  # cfg.inference_split == "val"

    if hvd.rank() == 0:
        save_json(cfg, join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)
        save_json(qa_scores, join(inference_res_dir, "scores.json"),
                  save_pretty=True)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    qa_results,
                    join(inference_res_dir, f"results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                # FIX: previously the exception was swallowed silently; log it
                # like the other save-retry loops in this file do.
                print(f"Saving exception: {e}")
                save_trial += 1
    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        qa_results = []
        for rk in range(n_gpu):
            qa_results.extend(load_json(
                join(inference_res_dir, f"results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        save_json(
            qa_results,
            join(inference_res_dir, "results_all.json"))
        LOGGER.info('all results written')
def validate(model, val_loader, cfg, train_global_step, eval_score=True):
    """use eval_score=False when doing inference on test sets where answers are not available

    Multi-clip TGIF-QA validation: each example's `num_clips` clips are scored
    separately and pooled via `cfg.score_agg_func`. Per-rank QA results/scores
    are combined with `all_gather_list` (collective — same call order required
    on every rank).

    Args:
        model: model under evaluation; restored to train() before returning.
        val_loader: loader whose batches contain "question_ids", "visual_inputs",
            "n_examples_list"; `val_loader.dataset` must provide `qid2data` and
            `evaluate_tgif_qa`.
        cfg: config; reads score_agg_func, inference_n_clips, num_frm, task, debug.
        train_global_step: current training step (unused in this body; kept for
            caller compatibility).
        eval_score: if False, skip metric computation (test sets with no answers).

    Returns:
        (qa_results, gathered_scores); gathered_scores is 0 when eval_score=False.
    """
    model.eval()
    loss = 0.
    n_ex = 0
    qa_results = []
    st = time.time()
    debug_step = 5  # in debug mode, stop after this many batches
    pbar = tqdm(total=len(val_loader))
    for val_step, batch in enumerate(val_loader):
        # forward pass
        question_ids = batch["question_ids"]
        bsz = len(question_ids)  # used to make visual feature copies
        del batch["question_ids"]
        # add visual part into the mini batch and perform inference
        mini_batch = dict()
        for k, v in batch.items():
            if k != "visual_inputs":
                mini_batch[k] = v

        n_ex += len(question_ids)
        # multi-frame test, scores across frames of the same video will be pooled together
        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is evaluated
        num_clips = cfg.inference_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        losses = []
        for clip_idx in range(num_clips):
            # (B, num_frm, C, H, W)
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg)
            logits.append(outputs["logits"].cpu())
            _loss = outputs["loss"].sum().item() if isinstance(
                outputs["loss"], torch.Tensor) else 0
            losses.append(_loss)
        # average per-clip loss so multi-clip eval is comparable to single-clip
        loss += (sum(losses) / num_clips)

        logits = torch.stack(logits)  # (num_frm, B, 5)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 5)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 5)
        elif pool_method == "lse":
            logits = logits.permute(1, 0, 2).contiguous()  # (B, num_frm, 5), pooling will be done in CE
            logits = torch.logsumexp(logits, dim=1)  # torch.exp alone might be too large and unstable
        else:
            raise ValueError(f"Invalid value for pool_method, "
                             f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if cfg.task in ["action", "transition", "frameqa", "msrvtt_qa"]:
            # cross entropy
            pred_labels = logits.max(dim=-1)[1].data.tolist()
        else:
            # mse: round to nearest integer score, clamped to the valid range [1, 10]
            preds = (logits + 0.5).long().clamp(min=1, max=10)
            # NOTE(review): squeeze().tolist() yields a scalar (not a list) when
            # the batch has a single example, which would break the zip below —
            # confirm batch sizes are always > 1 here.
            pred_labels = preds.data.squeeze().tolist()
        for qid, pred_label in zip(question_ids, pred_labels):
            qa_results.append(dict(
                question_id=qid,
                answer=pred_label,
                data=val_loader.dataset.qid2data[qid]
            ))
        pbar.update(1)
        if cfg.debug and val_step >= debug_step:
            break

    if cfg.debug:
        LOGGER.info(qa_results[:10])

    # Cross-rank reductions (collectives — keep order identical on all ranks).
    n_ex_per_rank = all_gather_list(n_ex)
    loss = sum(all_gather_list(loss))
    n_ex = sum(all_gather_list(n_ex))
    # average loss for each example
    val_log = {f'valid/loss': float(loss / n_ex)}
    if eval_score:
        LOGGER.info(f"QA Task [{cfg.task}], "
                    f"{len(qa_results)} qa_results,"
                    f"3 examples here: {qa_results[:3]}")
        vqa_scores = val_loader.dataset.evaluate_tgif_qa(qa_results)
        # print(f"{hvd.rank()}: {vqa_scores}")

        # Gather scores
        scores_per_rank = all_gather_list(vqa_scores)
        gathered_scores = {}
        if "ratios" in scores_per_rank[0]:
            # ratios[k] == [percentage, count]; sum counts across ranks, then
            # recompute the percentage against the global example count.
            gathered_ratios = {
                k: [0, 0] for k, _ in scores_per_rank[0]["ratios"].items()}
            # Gather ratios
            for rank_id in range(len(n_ex_per_rank)):
                current_ratios = scores_per_rank[rank_id]["ratios"]
                for k, v in current_ratios.items():
                    gathered_ratios[k][1] += v[1]
            for k, v in gathered_ratios.items():
                gathered_ratios[k][0] = get_rounded_percentage(
                    1. * v[1] / n_ex)
            gathered_scores["ratios"] = gathered_ratios

        # FIXME: Gather scores become complicated due to np.mean and dict format.
        for scores_k, _ in vqa_scores.items():
            if "ratio" in scores_k:
                continue
            # Weighted average of per-rank accuracies: weight "overall" by the
            # rank's example count, per-category scores by the category count.
            gathered_v = 0
            for rank_id, n in enumerate(n_ex_per_rank):
                curr_acc, curr_n_ex = 0, 0
                if "overall" in scores_k:
                    curr_acc = scores_per_rank[rank_id][scores_k] * n
                else:
                    if "ratios" in scores_per_rank[0]:
                        curr_n_ex = scores_per_rank[
                            rank_id]["ratios"][
                            scores_k.replace("acc", "ratio")][1]
                        curr_acc = scores_per_rank[rank_id][
                            scores_k] * curr_n_ex
                gathered_v += curr_acc
            if "overall" in scores_k:
                gathered_v = gathered_v * 1. / n_ex
            else:
                if "ratios" in scores_per_rank[0]:
                    _num = gathered_ratios[
                        scores_k.replace("acc", "ratio")][1]
                    gathered_v = gathered_v * 1. / _num if _num != 0 else 0

            if cfg.task in ["action", "transition", "frameqa", "msrvtt_qa"]:
                gathered_scores[scores_k] = get_rounded_percentage(
                    gathered_v)
            else:
                gathered_scores[scores_k] = round(gathered_v, 2)

        for k, v in gathered_scores.items():
            if "ratio" not in k:
                val_log[f'valid/{k}'] = v
    else:
        LOGGER.info("eval_score = False, no scores are calculated.")
        gathered_scores = 0

    TB_LOGGER.log_scalar_dict(val_log)
    LOGGER.info(f"validation finished in {int(time.time() - st)} seconds."
                f"{gathered_scores}")

    model.train()
    return qa_results, gathered_scores
def validate(model, val_loader, cfg):
    """Pretraining validation: MLM token accuracy and ITM example accuracy.

    Accepts either a single loader or a dict of named loaders; statistics are
    accumulated over all of them and reduced across ranks with `all_gather_list`
    (collectives — same order on every rank).

    Args:
        model: model under evaluation; restored to train() before returning.
        val_loader: a loader, or dict {name: loader}; batches must work with
            `forward_step(cfg, model, batch)` producing "mlm_labels",
            "mlm_scores", "mlm_loss", "itm_labels", "itm_scores", "itm_loss".
        cfg: config; reads use_mlm, use_itm, debug.

    Returns:
        dict with valid/mlm_loss, valid/mlm_acc, valid/itm_loss, valid/itm_acc
        (entries stay 0 for tasks that are disabled or saw no examples).
    """
    model.eval()
    mlm_loss = 0
    n_mlm_tokens = 0
    n_mlm_corrects = 0
    itm_loss = 0
    n_itm_ex = 0
    n_itm_corrects = 0
    st = time.time()
    # Defaults of 0 keep the final log line valid even if a task is disabled.
    val_log = {
        'valid/mlm_loss': 0,
        'valid/mlm_acc': 0,
        'valid/itm_loss': 0,
        'valid/itm_acc': 0
    }
    debug_step = 5  # in debug mode, stop each loader after this many batches
    # Normalize to a dict so single- and multi-loader callers share one code path.
    val_loaders = val_loader if isinstance(val_loader, dict) else {
        "unnamed_val_loader": val_loader
    }
    LOGGER.info(f"In total {len(val_loaders)} val loaders")
    for loader_name, val_loader in val_loaders.items():
        LOGGER.info(f"Loop val_loader {loader_name}.")
        for val_step, batch in enumerate(val_loader):
            # use iter to reset MetaLoader
            # forward pass
            outputs = forward_step(cfg, model, batch)

            # mlm
            mlm_labels = outputs["mlm_labels"]
            if cfg.use_mlm:
                mlm_loss += outputs["mlm_loss"].sum().item()
                mlm_mask = mlm_labels != -100  # (B, Lt)  -100 is the ignored label for cross entropy
                n_mlm_tokens += mlm_mask.sum().item()
                # accuracy only over positions that were actually masked
                n_mlm_corrects += (outputs["mlm_scores"][mlm_mask].max(
                    dim=-1)[1] == mlm_labels[mlm_mask]).sum().item()

            # itm
            if cfg.use_itm:
                itm_loss += outputs["itm_loss"].sum().item()
                n_itm_ex += len(outputs["itm_labels"])
                n_itm_corrects += (outputs["itm_scores"].max(
                    dim=-1)[1] == outputs["itm_labels"]).sum().item()

            if cfg.debug and val_step >= debug_step:
                break

    # Gather across all processes
    mlm_loss = sum(all_gather_list(mlm_loss))
    n_mlm_corrects = sum(all_gather_list(n_mlm_corrects))
    n_mlm_tokens = sum(all_gather_list(n_mlm_tokens))
    itm_loss = sum(all_gather_list(itm_loss))
    n_itm_corrects = sum(all_gather_list(n_itm_corrects))
    n_itm_ex = sum(all_gather_list(n_itm_ex))

    # Guard against division by zero when a task produced no examples/tokens.
    if n_mlm_tokens != 0:
        val_log.update({
            'valid/mlm_loss': float(mlm_loss / n_mlm_tokens),
            'valid/mlm_acc': float(n_mlm_corrects / n_mlm_tokens)
        })
    if n_itm_ex != 0:
        val_log.update({
            'valid/itm_loss': float(itm_loss / n_itm_ex),
            'valid/itm_acc': float(n_itm_corrects / n_itm_ex)
        })

    TB_LOGGER.log_scalar_dict(val_log)
    LOGGER.info(
        f"validation finished in {int(time.time() - st)} seconds, "
        f"[mlm_acc (per token)]: {val_log['valid/mlm_acc'] * 100:.2f} "
        f"[itm_acc (per example)]: {val_log['valid/itm_acc'] * 100:.2f} ")
    model.train()
    return val_log
def start_inference(cfg):
    """Entry point for VQA inference: restore a checkpoint (end-to-end weights,
    or separate transformer + CNN weights if no e2e file exists), run validation
    on the inference split, then save per-rank and merged results.

    Most cfg keys are overwritten from the training run's stored args (skipping
    keys containing "inference" and "output_dir"); "/data" path prefixes in the
    stored values are remapped to "/storage". Rank 0 performs directory creation
    and the final merge; the `all_gather_list(None)` calls are barriers.
    """
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    # NOTE(review): there is no separator between the split and "step", so this
    # produces e.g. "results_valstep_500" (the TGIF variant uses "/") — confirm
    # whether that naming is intentional before changing it.
    inference_res_dir = join(
        cfg.output_dir,
        f"results_{cfg.inference_split}"
        f"step_{cfg.inference_model_step}")

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg, join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing the keyword 'inference'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if (k in stored_cfg and "inference" not in k
                and k != "output_dir"):
            value = stored_cfg[k]
            # FIXME hardcode changes
            if isinstance(value, str) and value.startswith("/data"):
                value = value.replace("/data", "/storage")
            setattr(cfg, k, value)

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    cfg.detectron2_model_cfg = join(cfg.output_dir,
                                    "log/detectron2_model_cfg.yaml")
    e2e_weights_path = join(cfg.output_dir,
                            f"ckpt/model_step_{cfg.inference_model_step}.pt")
    if exists(e2e_weights_path):
        cfg.e2e_weights_path = e2e_weights_path
    else:
        # no end-to-end checkpoint: fall back to separate component checkpoints
        cfg.bert_weights_path = join(
            f"{cfg.output_dir}/ckpt",
            f"transformer_step_{cfg.inference_model_step}.pt")
        cfg.cnn_weights_path = join(
            cfg.output_dir, f"ckpt/cnn_step_{cfg.inference_model_step}.pt")
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.
    val_loader = mk_vqa_dataloader(anno_path=cfg.inference_txt_db,
                                   img_lmdb_dir=cfg.inference_img_db,
                                   cfg=cfg,
                                   tokenizer=tokenizer,
                                   is_train=False)
    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
    val_loader = PrefetchLoader(val_loader, img_norm)

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f" Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    # scores are only computed on the "val" split, where answers exist
    vqa_results = validate(model, val_loader, cfg, global_step,
                           eval_score=cfg.inference_split == "val")

    if hvd.rank() == 0:
        save_json(cfg, join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    vqa_results,
                    join(inference_res_dir, f"results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                # FIX: previously the exception was swallowed silently; log it
                # like the other save-retry loops in this file do.
                print(f"Saving exception: {e}")
                save_trial += 1
    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        vqa_results = []
        for rk in range(n_gpu):
            vqa_results.extend(
                load_json(join(inference_res_dir, f"results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        save_json(vqa_results,
                  join(inference_res_dir, "results_all.json"))
        LOGGER.info('all results written')
def validate(model, val_loader, cfg, train_global_step, eval_score=True):
    """use eval_score=False when doing inference on test sets where answers are not available

    VQA validation: collects per-example predicted answers, then (optionally)
    computes accuracy scores combined across ranks with `all_gather_list`
    (collectives — same call order required on every rank).

    Args:
        model: model under evaluation; restored to train() before returning.
        val_loader: loader whose dataset provides `label2ans` and `evaluate_vqa`.
        cfg: config; only `cfg.debug` is read here.
        train_global_step: current training step (unused in this body; kept for
            caller compatibility).
        eval_score: if False, skip metric computation (test sets with no answers).

    Returns:
        vqa_results: list of dict(question_id=..., answer=...). Note the
        gathered scores are logged but NOT returned by this variant.
    """
    model.eval()
    loss = 0.
    n_ex = 0
    vqa_results = []
    st = time.time()
    debug_step = 5  # in debug mode, stop after this many batches
    for val_step, batch in enumerate(val_loader):
        # forward pass
        outputs, question_ids = forward_step(model, batch)
        # loss may be a tensor (sum over examples) or absent/non-tensor -> count as 0
        loss += outputs["loss"].sum().item() if isinstance(
            outputs["loss"], torch.Tensor) else 0
        n_ex += len(question_ids)

        # map argmax label index back to the answer string
        pred_labels = outputs["logits"].max(dim=-1)[1].data.tolist()
        for qid, pred_label in zip(question_ids, pred_labels):
            vqa_results.append(
                dict(question_id=qid,
                     answer=val_loader.dataset.label2ans[pred_label]))

        if cfg.debug and val_step >= debug_step:
            break
    if cfg.debug:
        LOGGER.info(vqa_results[:10])

    # Cross-rank reductions (collectives — keep order identical on all ranks).
    n_ex_per_rank = all_gather_list(n_ex)
    loss = sum(all_gather_list(loss))
    n_ex = sum(all_gather_list(n_ex))

    # average loss for each example
    val_log = {'valid/loss': float(loss / n_ex)}
    if eval_score:
        LOGGER.info(f"Evaluate VQA scores for {len(vqa_results)} vqa_results,"
                    f"3 examples here: {vqa_results[:3]}")
        vqa_scores = val_loader.dataset.evaluate_vqa(vqa_results)

        # Gather scores
        scores_per_rank = all_gather_list(vqa_scores)
        gathered_scores = {}
        # ratios[k] == [percentage, count]; sum counts across ranks, then
        # recompute the percentage against the global example count.
        gathered_ratios = {
            k: [0, 0] for k, _ in scores_per_rank[0]["ratios"].items()
        }
        # Gather ratios
        for rank_id in range(len(n_ex_per_rank)):
            current_ratios = scores_per_rank[rank_id]["ratios"]
            for k, v in current_ratios.items():
                gathered_ratios[k][1] += v[1]
        for k, v in gathered_ratios.items():
            gathered_ratios[k][0] = get_rounded_percentage(1. * v[1] / n_ex)

        # FIXME: Gather scores become complicated due to np.mean and dict format.
        for scores_k, _ in vqa_scores.items():
            if "ratio" in scores_k:
                continue
            # Weighted average of per-rank accuracies: weight "overall" by the
            # rank's example count, per-category scores by the category count.
            gathered_v = 0
            for rank_id, n in enumerate(n_ex_per_rank):
                if "overall" in scores_k:
                    curr_acc = scores_per_rank[rank_id][scores_k] * n
                else:
                    curr_n_ex = scores_per_rank[rank_id]["ratios"][
                        scores_k.replace("acc", "ratio")][1]
                    curr_acc = scores_per_rank[rank_id][scores_k] * curr_n_ex
                gathered_v += curr_acc
            if "overall" in scores_k:
                gathered_v = gathered_v * 1. / n_ex
            else:
                gathered_v = gathered_v * 1. / gathered_ratios[
                    scores_k.replace("acc", "ratio")][1]
            gathered_scores[scores_k] = get_rounded_percentage(gathered_v)
        gathered_scores["ratios"] = gathered_ratios

        for k, v in gathered_scores.items():
            if "ratio" not in k:
                val_log[f'valid/{k}'] = v
    else:
        LOGGER.info("Seems you are doing inference on test set,"
                    "no scores are calculated.")
        gathered_scores = 0

    TB_LOGGER.log_scalar_dict(val_log)
    LOGGER.info(f"validation finished in {int(time.time() - st)} seconds."
                f"{gathered_scores}")

    model.train()
    return vqa_results