def train(self):
    """Fine-tune on RefCOCOg, keep the best validation checkpoint, then test it.

    Every process runs the optimisation loop over ``self.train_loader``.
    Only the ``self.verbose`` process prints, drives the progress bar, runs
    validation, writes checkpoints and logs to wandb.

    Each training step does the following:
    - forward via ``train_step`` (under ``autocast`` when native AMP is on);
    - backward (scaled for native AMP / apex);
    - optional gradient-norm clipping;
    - optimizer step, scheduler step, and gradient reset.

    After each epoch the verbose process evaluates ``self.val_loader``. It
    saves ``BEST`` whenever validation accuracy improves, and always on the
    first validation, so ``BEST`` exists for the test pass. After training,
    ``LAST`` is saved, ``BEST`` is reloaded and ``self.test_loader`` is
    evaluated.

    Raises:
        ValueError: on the verbose process, if ``self.args.backbone``
            contains neither ``'t5'`` nor ``'bart'``.
    """
    args = self.args
    use_native_amp = args.fp16 and _use_native_amp
    use_apex = args.fp16 and _use_apex

    if self.verbose:
        # "Oracle" accuracy: share of examples whose batch['exists_target']
        # flag is set, reported for the validation and test splits.
        for split_name, loader in (('Val', self.val_loader),
                                   ('Test', self.test_loader)):
            n_correct = 0
            n_total = 0
            for batch in loader:
                exists_target = batch['exists_target']
                n_correct += exists_target.sum().item()
                n_total += len(exists_target)
            print(f'{split_name} Oracle acc: {n_correct / n_total * 100:.2f}%')

    if self.verbose:
        loss_meter = LossMeter()
        best_valid_acc = 0.
        # None until the first validation; see the BEST-saving check below.
        best_epoch = None

        if 't5' in args.backbone:
            project_name = "VLT5_RefCOCOg"
        elif 'bart' in args.backbone:
            project_name = "VLBart_RefCOCOg"
        else:
            raise ValueError(f'Unsupported backbone: {args.backbone}')
        if args.RefCOCO_GT:
            project_name += '_GT'

        wandb.init(project=project_name)
        wandb.run.name = args.run_name
        wandb.config.update(args)
        wandb.watch(self.model)

        # Upload the .py files that sit next to this module with the run.
        src_dir = Path(__file__).resolve().parent
        base_path = str(src_dir.parent)
        src_dir = str(src_dir)
        wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

    if args.distributed:
        dist.barrier()

    # Under DDP, train_step is defined on the wrapped module.
    model = self.model.module if args.distributed else self.model

    global_step = 0
    for epoch in range(args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if args.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=120)

        n_correct = 0
        n_total = 0

        for batch in self.train_loader:
            batch['log_train_accuracy'] = args.log_train_accuracy

            # Forward
            if use_native_amp:
                with autocast():
                    results = model.train_step(batch)
            else:
                results = model.train_step(batch)

            # Backward
            loss = results['loss']
            if use_native_amp:
                self.scaler.scale(loss).backward()
            elif use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            loss = loss.detach()

            # Update parameters
            if args.clip_grad_norm > 0:
                if use_native_amp:
                    # Unscale first so the norm is measured on true gradients.
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), args.clip_grad_norm)
                elif use_apex:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optim), args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), args.clip_grad_norm)

            if use_native_amp:
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                self.optim.step()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            for param in self.model.parameters():
                param.grad = None
            global_step += 1

            if self.lr_scheduler:
                if version.parse(torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr

            if args.log_train_accuracy:
                correct = results['correct']
                n_correct += sum(correct)
                n_total += len(correct)

            if self.verbose:
                loss_meter.update(loss.item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                desc_str += f'Loss {loss_meter.val:.3f} |'
                if args.log_train_accuracy:
                    desc_str += f' Correct {n_correct:.0f}'
                    desc_str += f' (Acc {n_correct/n_total*100:.1f}%)'
                pbar.set_description(desc_str)
                pbar.update(1)

            if args.distributed:
                dist.barrier()

        if self.verbose:
            pbar.close()

        if args.log_train_accuracy:
            train_score_dict = reduce_dict(
                {'n_correct': n_correct, 'n_total': n_total}, args.gpu)

        if self.verbose:
            if args.log_train_accuracy:
                train_acc = (train_score_dict['n_correct']
                             / train_score_dict['n_total'] * 100)
                train_n_correct = int(train_score_dict['n_correct'])
                train_n_total = int(train_score_dict['n_total'])

            # Validation (verbose process only)
            valid_score_dict = self.evaluate(self.val_loader)
            valid_acc = (valid_score_dict['n_correct']
                         / valid_score_dict['n_total'] * 100)
            valid_n_correct = int(valid_score_dict['n_correct'])
            valid_n_total = int(valid_score_dict['n_total'])

            # Always save on the first validation so BEST exists for the
            # test pass even if accuracy never rises above 0.
            if best_epoch is None or valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                best_epoch = epoch
                self.save("BEST")

            log_str = ''
            if args.log_train_accuracy:
                log_str += f"\nEpoch {epoch}: Train"
                log_str += f" Acc {train_acc:.2f}% |"
                log_str += f" # correct {train_n_correct} # total {train_n_total}"
            log_str += f"\nEpoch {epoch}: Valid"
            log_str += f" Acc {valid_acc:.2f}% |"
            log_str += f" # correct {valid_n_correct} # total {valid_n_total}"
            log_str += f"\nEpoch {best_epoch}: Best Acc {best_valid_acc:.2f}%\n"

            wandb_log_dict = {}
            if args.log_train_accuracy:
                wandb_log_dict['Train/Acc'] = train_acc
            wandb_log_dict['Valid/Acc'] = valid_acc
            wandb.log(wandb_log_dict, step=epoch)

            print(log_str)

        if args.distributed:
            dist.barrier()

    if self.verbose:
        self.save("LAST")

    # Test set: reload the best checkpoint. The verbose process writes BEST
    # during validation, before the end-of-epoch barrier.
    best_path = os.path.join(args.output, 'BEST')
    self.load(best_path)
    test_score_dict = self.evaluate(self.test_loader)

    # wandb is only initialised on the verbose process.
    if self.verbose:
        test_acc = test_score_dict['n_correct'] / test_score_dict['n_total'] * 100
        test_n_correct = int(test_score_dict['n_correct'])
        test_n_total = int(test_score_dict['n_total'])

        wandb.log({'Test/Acc': test_acc}, step=epoch)

        log_str = ''
        log_str += f"\nTest Acc {test_acc:.2f}%"
        log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}"
        print(log_str)

        wandb.log({'finished': True})

    if args.distributed:
        dist.barrier()
def train(self):
    """Fine-tune on VCR with optional gradient accumulation, then dump test predictions.

    Every process runs the optimisation loop. The ``self.verbose`` process
    additionally drives the progress bar, saves checkpoints, prints, and
    logs to wandb.

    With ``args.gradient_accumulation_steps = k > 1``:
    - each micro-batch loss is divided by ``k``;
    - the optimizer steps after every ``k`` micro-batches and after the
      last batch of the epoch;
    - gradient clipping runs only right before an optimizer step, so it sees
      the fully accumulated gradients. ``GradScaler.unscale_`` may only be
      called once per step.

    After each epoch ``self.evaluate_val`` runs on every process. The
    verbose process saves ``BEST`` whenever validation Q -> AR improves,
    and always on the first validation.

    After training:
    - ``LAST`` is saved and ``BEST`` is reloaded;
    - the verbose process writes test predictions to
      ``<output>/test_submit.csv`` via ``self.evaluate_test``;
    - under DDP the process exits after the final barrier.

    Raises:
        ValueError: on the verbose process, if ``args.backbone`` contains
            neither ``'t5'`` nor ``'bart'``.
    """
    import sys

    args = self.args
    use_native_amp = args.fp16 and _use_native_amp
    use_apex = args.fp16 and _use_apex
    accum_steps = args.gradient_accumulation_steps

    if self.verbose:
        loss_meter = LossMeter()
        qa_loss_meter = LossMeter()
        qar_loss_meter = LossMeter()
        best_valid_Q_AR = 0.
        # None until the first validation; see the BEST-saving check below.
        best_epoch = None

        if 't5' in args.backbone:
            project_name = "VLT5_VCR" if args.use_vision else "T5_VCR"
        elif 'bart' in args.backbone:
            project_name = "VLBart_VCR" if args.use_vision else "Bart_VCR"
        else:
            raise ValueError(f'Unsupported backbone: {args.backbone}')

        wandb.init(project=project_name)
        wandb.run.name = args.run_name
        wandb.config.update(args)
        wandb.watch(self.model)

        # Upload the .py files that sit next to this module with the run.
        src_dir = Path(__file__).resolve().parent
        base_path = str(src_dir.parent)
        src_dir = str(src_dir)
        wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

    if args.distributed:
        dist.barrier()

    # Under DDP, train_step is defined on the wrapped module.
    model = self.model.module if args.distributed else self.model
    n_batches = len(self.train_loader)

    global_step = 0
    for epoch in range(args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if args.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=n_batches, ncols=200)

        Q_A_results = 0
        QA_R_results = 0
        Q_AR_results = 0
        n_total = 0
        # Micro-batches seen, and their summed losses, since the last
        # optimizer step.
        n_accu = 0
        train_loss = 0
        train_qa_loss = 0
        train_qar_loss = 0

        for step_i, batch in enumerate(self.train_loader):
            batch['log_train_accuracy'] = args.log_train_accuracy

            # Forward
            if use_native_amp:
                with autocast():
                    results = model.train_step(batch)
            else:
                results = model.train_step(batch)

            loss = results['loss']
            if accum_steps > 1:
                # Out-of-place so the tensor stored in results is not mutated.
                loss = loss / accum_steps

            # Backward (gradients accumulate until the next optimizer step)
            if use_native_amp:
                self.scaler.scale(loss).backward()
            elif use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Step after every `accum_steps` micro-batches and on the last one.
            update = (accum_steps <= 1
                      or (step_i + 1) % accum_steps == 0
                      or step_i == n_batches - 1)
            n_accu += 1

            if update:
                # Clip the fully accumulated gradients only. Calling
                # scaler.unscale_ more than once per step raises.
                if args.clip_grad_norm > 0:
                    if use_native_amp:
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), args.clip_grad_norm)
                    elif use_apex:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim), args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), args.clip_grad_norm)

                if use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()
                if self.lr_scheduler:
                    self.lr_scheduler.step()
                for param in self.model.parameters():
                    param.grad = None
                global_step += 1

            if self.lr_scheduler:
                if version.parse(torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr

            # Counted on every process so reduce_dict below sees all ranks.
            if args.log_train_accuracy:
                qa_pred = results['qa_pred']
                qar_pred = results['qar_pred']
                Q_A_correct = batch['answer_labels'] == qa_pred
                QA_R_correct = batch['rationale_labels'] == qar_pred
                Q_AR_correct = Q_A_correct & QA_R_correct
                Q_A_results += sum(Q_A_correct)
                QA_R_results += sum(QA_R_correct)
                Q_AR_results += sum(Q_AR_correct)
                n_total += len(qa_pred)

            if self.verbose:
                train_loss += loss.detach().item()
                train_qa_loss += results['qa_loss'].item() / accum_steps
                train_qar_loss += results['qar_loss'].item() / accum_steps

                if update:
                    if accum_steps > 1:
                        # Each summed term was divided by accum_steps; rescale
                        # to the mean over the n_accu micro-batches.
                        scale = accum_steps / n_accu
                        train_loss *= scale
                        train_qa_loss *= scale
                        train_qar_loss *= scale
                    loss_meter.update(train_loss)
                    qa_loss_meter.update(train_qa_loss)
                    qar_loss_meter.update(train_qar_loss)

                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step} |'
                    desc_str += f' Loss {loss_meter.val:.3f} |'
                    desc_str += f' QA Loss {qa_loss_meter.val:.3f} |'
                    desc_str += f' QAR Loss {qar_loss_meter.val:.3f} |'
                    if args.log_train_accuracy:
                        desc_str += f' Q -> A {Q_A_results} ({Q_A_results/n_total*100:.1f}%)'
                        desc_str += f' QA -> R {QA_R_results} ({QA_R_results/n_total*100:.1f}%)'
                        desc_str += f' Q -> AR {Q_AR_results} ({Q_AR_results/n_total*100:.1f}%)'
                    pbar.set_description(desc_str)
                pbar.update(1)

            if update:
                n_accu = 0
                train_loss = 0
                train_qa_loss = 0
                train_qar_loss = 0

        if self.verbose:
            pbar.close()

        if args.log_train_accuracy:
            train_score_dict = reduce_dict({
                'Q_A': Q_A_results,
                'QA_R': QA_R_results,
                'Q_AR': Q_AR_results,
                'n_total': n_total,
            }, args.gpu)

        # Validation runs on every process.
        valid_score_dict = self.evaluate_val(self.val_loader)

        if self.verbose:
            if args.log_train_accuracy:
                train_total = train_score_dict['n_total']
                train_Q_A = train_score_dict['Q_A'] / train_total * 100
                train_QA_R = train_score_dict['QA_R'] / train_total * 100
                train_Q_AR = train_score_dict['Q_AR'] / train_total * 100
                train_n_total = int(train_total)

            valid_total = valid_score_dict['n_total']
            valid_Q_A = valid_score_dict['Q_A'] / valid_total * 100
            valid_QA_R = valid_score_dict['QA_R'] / valid_total * 100
            valid_Q_AR = valid_score_dict['Q_AR'] / valid_total * 100
            valid_n_total = int(valid_total)

            # Always save on the first validation so BEST exists for the
            # test pass even if Q -> AR never rises above 0.
            if best_epoch is None or valid_Q_AR > best_valid_Q_AR:
                best_valid_Q_AR = valid_Q_AR
                best_epoch = epoch
                self.save("BEST")

            log_str = ''
            if args.log_train_accuracy:
                log_str += f"\nEpoch {epoch}: Train |"
                log_str += f" # examples: {train_n_total} |"
                log_str += f" Q -> A {train_Q_A:.2f}%"
                log_str += f" QA -> R {train_QA_R:.2f}%"
                log_str += f" Q -> AR {train_Q_AR:.2f}%"
            log_str += f"\nEpoch {epoch}: Valid |"
            log_str += f" # examples: {valid_n_total} |"
            log_str += f" Q -> A {valid_Q_A:.2f}%"
            log_str += f" QA -> R {valid_QA_R:.2f}%"
            log_str += f" Q -> AR {valid_Q_AR:.2f}%"
            log_str += f"\nBest Epoch {best_epoch}: Q -> AR {best_valid_Q_AR:.2f}%\n"

            wandb_log_dict = {}
            if args.log_train_accuracy:
                wandb_log_dict['Train/Q_A'] = train_Q_A
                wandb_log_dict['Train/QA_R'] = train_QA_R
                wandb_log_dict['Train/Q_AR'] = train_Q_AR
            wandb_log_dict['Valid/Q_A'] = valid_Q_A
            wandb_log_dict['Valid/QA_R'] = valid_QA_R
            wandb_log_dict['Valid/Q_AR'] = valid_Q_AR
            wandb.log(wandb_log_dict, step=epoch)

            print(log_str)

        if args.distributed:
            dist.barrier()

    if self.verbose:
        self.save("LAST")

    # Test set: reload the best checkpoint on every process.
    best_path = os.path.join(args.output, 'BEST')
    self.load(best_path)

    if self.verbose:
        dump_path = os.path.join(args.output, 'test_submit.csv')
        print('Dumping test set results at', dump_path)
        self.evaluate_test(self.test_loader, dump_path=dump_path)
        wandb.save(dump_path, base_path=args.output)
        print('Done!')
        wandb.log({'finished': True})

    if args.distributed:
        dist.barrier()
        # Under DDP the worker stops here once all ranks have finished.
        sys.exit()
def train(self):
    """Pre-train with the objectives named in ``args.LOSSES_NAME``, validating each epoch.

    If ``args.dry`` is set, one ``self.evaluate_epoch(epoch=0)`` pass runs
    before training.

    Every process trains. For each name in ``LOSSES_NAME``, the ``<name>``
    and ``<name>_count`` entries returned by ``train_step`` are summed over
    the epoch. The sums are combined across processes with
    ``reduce_dict(..., average=False)`` and printed and logged by the
    verbose process. Validation losses are handled the same way.

    When ``'qa'`` is in ``args.losses``, QA accuracy from the validation
    evaluator is also reduced, and the process with ``args.gpu == 0``
    prints and logs it. The verbose process saves an ``EpochNN`` checkpoint
    after every epoch.

    Barriers are issued only when ``args.distributed`` is set, so the
    method also runs without an initialised process group.

    Raises:
        ValueError: on the verbose process, if ``args.backbone`` contains
            neither ``'t5'`` nor ``'bart'``.
    """
    args = self.args
    loss_names = args.LOSSES_NAME
    use_native_amp = args.fp16 and _use_native_amp
    use_apex = args.fp16 and _use_apex

    def _barrier():
        # Synchronise ranks; a no-op outside distributed training.
        if args.distributed:
            dist.barrier()

    def _report_losses(split, reduced, epoch):
        # Print "<split> Loss" averages for every '*loss' key in `reduced`
        # and log each non-empty one to wandb under '<split> Loss/<name>'.
        avg_total = reduced['total_loss'] / reduced['total_loss_count']
        losses_str = f"{split} Loss: {avg_total:.3f}\n"
        for name, loss_sum in reduced.items():
            if not name.endswith('loss'):
                continue
            loss_count = int(reduced[name + '_count'])
            if loss_count > 0:
                avg_loss = loss_sum / loss_count
                losses_str += f"{name} ({loss_count}): {avg_loss:.3f} "
                wandb.log({f'{split} Loss/{name}': avg_loss}, step=epoch)
        losses_str += '\n'
        print(losses_str)

    if args.dry:
        self.evaluate_epoch(epoch=0)

    if self.verbose:
        loss_meters = [LossMeter() for _ in loss_names]

        if 't5' in args.backbone:
            project_name = "VLT5_Pretrain"
        elif 'bart' in args.backbone:
            project_name = "VLBart_Pretrain"
        else:
            raise ValueError(f'Unsupported backbone: {args.backbone}')

        wandb.init(project=project_name)
        wandb.run.name = args.run_name
        wandb.config.update(args)
        wandb.watch(self.model)

        # Upload the .py files that sit next to this module with the run.
        src_dir = Path(__file__).resolve().parent
        base_path = str(src_dir.parent)
        src_dir = str(src_dir)
        wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

    _barrier()

    # Under DDP, train_step is defined on the wrapped module.
    model = self.model.module if args.distributed else self.model

    global_step = 0
    for epoch in range(args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        if args.distributed:
            self.train_loader.sampler.set_epoch(epoch)

        # Train
        self.model.train()
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=250)

        # Per-loss running sum and the count that sum is averaged over.
        epoch_results = {}
        for loss_name in loss_names:
            epoch_results[loss_name] = 0.
            epoch_results[f'{loss_name}_count'] = 0

        for batch in self.train_loader:
            # Forward
            if use_native_amp:
                with autocast():
                    results = model.train_step(batch)
            else:
                results = model.train_step(batch)

            # Backward
            loss = results['loss']
            if use_native_amp:
                self.scaler.scale(loss).backward()
            elif use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Update parameters
            if args.clip_grad_norm > 0:
                if use_native_amp:
                    # Unscale first so the norm is measured on true gradients.
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), args.clip_grad_norm)
                elif use_apex:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optim), args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), args.clip_grad_norm)

            if use_native_amp:
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                self.optim.step()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            for param in self.model.parameters():
                param.grad = None
            global_step += 1

            if self.lr_scheduler:
                if version.parse(torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr

            for k, v in results.items():
                if k in epoch_results:
                    if isinstance(v, int):
                        epoch_results[k] += v
                    elif isinstance(v, torch.Tensor):
                        epoch_results[k] += v.item()

            if self.verbose:
                desc_str = f'Epoch {epoch} | LR {lr:.6f} |'
                for loss_name, loss_meter in zip(loss_names, loss_meters):
                    if loss_name in results:
                        loss_meter.update(
                            results[loss_name] / results[f'{loss_name}_count'])
                    if len(loss_meter) > 0:
                        loss_count = epoch_results[f'{loss_name}_count']
                        desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}'
                pbar.set_description(desc_str)
                pbar.update(1)

        if self.verbose:
            pbar.close()

        _barrier()

        train_results = reduce_dict(epoch_results, average=False)
        if self.verbose:
            _report_losses('Train', train_results, epoch)

        _barrier()

        # Validation
        valid_results, valid_uid2ans = self.evaluate_epoch(epoch=epoch)
        valid_results = reduce_dict(valid_results, average=False)
        if self.verbose:
            _report_losses('Valid', valid_results, epoch)

        if 'qa' in args.losses:
            evaluator = self.val_loader.dataset.evaluator
            dset2score, dset2cnt, score, cnt = evaluator.evaluate(valid_uid2ans)
            if len(dset2score) == 0:
                # Placeholder values when this process scored nothing,
                # presumably so every rank feeds the same keys into the
                # reductions below.
                # NOTE(review): the non-zero counts slightly bias accuracy.
                dset2score = {'vqa': 0, 'gqa': 0, 'visual7w': 0}
                dset2cnt = {'vqa': 1, 'gqa': 1, 'visual7w': 1}
                cnt = 3
                score = 0
            dset2score = reduce_dict(dset2score, average=False)
            dset2cnt = reduce_dict(dset2cnt, average=False)
            score_cnt_dict = reduce_dict({'score': score, 'cnt': cnt}, average=False)

            if args.gpu == 0:
                accu = score_cnt_dict['score'] / score_cnt_dict['cnt']
                dset2accu = {dset: dset2score[dset] / dset2cnt[dset]
                             for dset in dset2cnt}
                accu_str = "Overall QA Acc %0.4f" % (accu)
                wandb.log({'Valid QA Acc/Overall': accu}, step=epoch)
                for key in sorted(dset2accu):
                    accu_str += ", %s Acc %0.4f" % (key, dset2accu[key])
                    wandb.log({f'Valid QA Acc/{key}': dset2accu[key]}, step=epoch)
                print(accu_str)

        _barrier()

        if self.verbose:
            self.save("Epoch%02d" % (epoch + 1))

        _barrier()

    if self.verbose:
        wandb.log({'finished': True})
def train(self):
    """Multi-task fine-tuning over the tasks listed in ``args.tasks``.

    Each batch from ``self.train_loader`` names its task in
    ``batch['task']``. Every process runs one forward/backward/optimizer
    step per batch. The verbose process keeps per-task batch counters and
    VQA / RefCOCOg loss meters for the progress bar, and saves an
    ``EpochNN`` checkpoint after every epoch.

    After each epoch the verbose process validates every enabled task. It
    tracks the best score and epoch per task, where the first validation
    always counts as best, and logs the results to wandb.

    After training, the verbose process saves ``LAST`` and runs the test /
    submission passes for the enabled tasks using the final weights. Under
    DDP the process exits after the final barrier.
    """
    import sys

    args = self.args
    use_native_amp = args.fp16 and _use_native_amp
    use_apex = args.fp16 and _use_apex

    # (task key, label shown in the progress bar), in display order.
    task_labels = (
        ('vqa', 'VQA'),
        ('gqa', 'GQA'),
        ('nlvr', 'NLVR'),
        ('vcr', 'VCR'),
        ('refcoco', 'RefCOCOg'),
        ('caption', 'COCO'),
        ('mmt', 'MMT'),
    )

    if self.verbose:
        vqa_loss_meter = LossMeter()
        refcoco_loss_meter = LossMeter()

        # Best validation score and the epoch it was reached, per task.
        best_vqa_valid = 0.
        best_vqa_epoch = 0
        best_gqa_valid = 0
        best_gqa_epoch = 0
        best_nlvr_valid = 0
        best_nlvr_epoch = 0
        best_valid_Q_AR = 0
        best_vcr_epoch = 0
        best_refcoco_valid = 0
        best_refcoco_epoch = 0
        best_caption_valid = 0
        best_caption_epoch = 0
        best_mmt_valid = 0
        best_mmt_epoch = 0

        assert 't5' in args.backbone
        self.setup_wandb()

    if args.distributed:
        dist.barrier()

    # Under DDP, train_step is defined on the wrapped module.
    model = self.model.module if args.distributed else self.model

    global_step = 0
    for epoch in range(args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if args.distributed:
            self.train_loader.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=250)

        # Batches seen per task this epoch.
        task_counter = {task_name: 0 for task_name, _ in task_labels}
        # RefCOCOg training-accuracy counters.
        n_correct = 0
        n_total = 0

        for batch in self.train_loader:
            task = batch['task']
            task_counter[task] += 1
            batch['log_train_accuracy'] = args.log_train_accuracy

            # Forward
            if use_native_amp:
                with autocast():
                    results = model.train_step(batch)
            else:
                results = model.train_step(batch)

            # Backward
            loss = results['loss']
            if use_native_amp:
                self.scaler.scale(loss).backward()
            elif use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            loss = loss.detach()

            # Update parameters
            if args.clip_grad_norm > 0:
                if use_native_amp:
                    # Unscale first so the norm is measured on true gradients.
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), args.clip_grad_norm)
                elif use_apex:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optim), args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), args.clip_grad_norm)

            if use_native_amp:
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                self.optim.step()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            for param in self.model.parameters():
                param.grad = None
            global_step += 1

            if self.lr_scheduler:
                if version.parse(torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr

            if args.log_train_accuracy and task == 'refcoco':
                correct = results['correct']
                n_correct += sum(correct)
                n_total += len(correct)

            if self.verbose:
                if task == 'vqa':
                    vqa_loss_meter.update(loss.item())
                elif task == 'refcoco':
                    refcoco_loss_meter.update(loss.item())

                desc_str = f'Epoch {epoch} | LR {lr:.6f}'
                desc_str += " |"
                for task_name, label in task_labels:
                    if task_name in args.tasks:
                        desc_str += f" {label} {task_counter[task_name]}"
                if len(vqa_loss_meter) > 0:
                    desc_str += f' | VQA Loss {vqa_loss_meter.val:.4f}'
                if len(refcoco_loss_meter) > 0:
                    desc_str += f' | RefCOCOg Loss {refcoco_loss_meter.val:.3f}'
                if args.log_train_accuracy and n_total > 0:
                    desc_str += ' | RefCOCOg Acc'
                    desc_str += f' Correct {n_correct:.0f}'
                    desc_str += f' (Acc {n_correct/n_total*100:.1f}%)'
                pbar.set_description(desc_str)
                pbar.update(1)

            if args.distributed:
                dist.barrier()

        if self.verbose:
            pbar.close()
            self.save("Epoch%02d" % (epoch + 1))

        if args.log_train_accuracy:
            train_score_dict = reduce_dict(
                {'n_correct': n_correct, 'n_total': n_total}, args.gpu)

        if self.verbose:
            # Validation (verbose process only)
            log_str = ''
            wandb_log_dict = {}

            if 'vqa' in args.tasks:
                score_dict = self.vqa_evaluate(self.val_loader['vqa'])
                valid_score = score_dict['topk_score'] * 100.
                valid_score_raw = score_dict['overall']
                if valid_score_raw > best_vqa_valid or epoch == 0:
                    best_vqa_valid = valid_score_raw
                    best_vqa_epoch = epoch
                log_str += "VQA"
                log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % (
                    epoch, valid_score_raw, valid_score)
                log_str += "\nEpoch %d: Best Raw %0.2f\n" % (
                    best_vqa_epoch, best_vqa_valid)
                wandb_log_dict['VQA/Valid/score'] = valid_score
                wandb_log_dict['VQA/Valid/raw_score'] = score_dict['overall']

            if 'gqa' in args.tasks:
                valid_score = self.gqa_evaluate(self.val_loader['gqa']) * 100
                if valid_score > best_gqa_valid or epoch == 0:
                    best_gqa_valid = valid_score
                    best_gqa_epoch = epoch
                wandb_log_dict['GQA/Valid/Acc'] = valid_score
                log_str += "GQA"
                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_gqa_epoch, best_gqa_valid)

            if 'nlvr' in args.tasks:
                valid_score_dict = self.nlvr_evaluate(self.val_loader['nlvr'])
                valid_acc = valid_score_dict['accuracy'] * 100.
                if valid_acc > best_nlvr_valid or epoch == 0:
                    best_nlvr_valid = valid_acc
                    best_nlvr_epoch = epoch
                wandb_log_dict['NLVR/Valid/Acc'] = valid_acc
                log_str += "NLVR"
                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_acc)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_nlvr_epoch, best_nlvr_valid)

            if 'vcr' in args.tasks:
                valid_score_dict = self.vcr_evaluate(self.val_loader['vcr'])
                vcr_total = valid_score_dict['n_total']
                valid_Q_A = valid_score_dict['Q_A'] / vcr_total * 100
                valid_QA_R = valid_score_dict['QA_R'] / vcr_total * 100
                valid_Q_AR = valid_score_dict['Q_AR'] / vcr_total * 100
                if valid_Q_AR > best_valid_Q_AR or epoch == 0:
                    best_valid_Q_AR = valid_Q_AR
                    best_vcr_epoch = epoch
                wandb_log_dict['VCR/Valid/Q_A'] = valid_Q_A
                wandb_log_dict['VCR/Valid/QA_R'] = valid_QA_R
                wandb_log_dict['VCR/Valid/Q_AR'] = valid_Q_AR
                log_str += "VCR"
                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_Q_AR)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_vcr_epoch, best_valid_Q_AR)

            if 'refcoco' in args.tasks:
                if args.log_train_accuracy:
                    train_acc = (train_score_dict['n_correct']
                                 / train_score_dict['n_total'] * 100)
                    train_n_correct = int(train_score_dict['n_correct'])
                    train_n_total = int(train_score_dict['n_total'])

                valid_score_dict = self.refcoco_evaluate(self.val_loader['refcoco'])
                valid_acc = (valid_score_dict['n_correct']
                             / valid_score_dict['n_total'] * 100)
                valid_n_correct = int(valid_score_dict['n_correct'])
                valid_n_total = int(valid_score_dict['n_total'])
                if valid_acc > best_refcoco_valid or epoch == 0:
                    best_refcoco_valid = valid_acc
                    best_refcoco_epoch = epoch

                if args.log_train_accuracy:
                    wandb_log_dict['RefCOCO/Train/Acc'] = train_acc
                wandb_log_dict['RefCOCO/Valid/Acc'] = valid_acc

                log_str += "RefCOCOg"
                if args.log_train_accuracy:
                    log_str += f"\nEpoch {epoch}: Train"
                    log_str += f" Acc {train_acc:.2f}% |"
                    log_str += f" # correct {train_n_correct} # total {train_n_total}"
                log_str += f"\nEpoch {epoch}: Valid"
                log_str += f" Acc {valid_acc:.2f}% |"
                log_str += f" # correct {valid_n_correct} # total {valid_n_total}"
                log_str += f"\nEpoch {best_refcoco_epoch}: Best Acc {best_refcoco_valid:.2f}%\n"

            if 'caption' in args.tasks:
                valid_results = self.caption_evaluate(self.val_loader['caption'])
                valid_score = valid_results['CIDEr'] * 100
                if valid_score > best_caption_valid or epoch == 0:
                    best_caption_valid = valid_score
                    best_caption_epoch = epoch
                for score_name, score in valid_results.items():
                    wandb_log_dict[f'Caption/Valid/{score_name}'] = score * 100
                log_str += "COCO Caption"
                log_str += "\nEpoch %d: Valid CIDEr %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (
                    best_caption_epoch, best_caption_valid)

            if 'mmt' in args.tasks:
                valid_results = self.mmt_evaluate(self.val_loader['mmt'])
                valid_score = valid_results['BLEU']
                # `or epoch == 0` matches every other task's best tracking.
                if valid_score > best_mmt_valid or epoch == 0:
                    best_mmt_valid = valid_score
                    best_mmt_epoch = epoch
                for score_name, score in valid_results.items():
                    wandb_log_dict[f'MMT/Valid/{score_name}'] = score
                log_str += "Multi30K En-De"
                log_str += "\nEpoch %d: Valid BLEU %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_mmt_epoch, best_mmt_valid)

            wandb.log(wandb_log_dict, step=epoch)
            print(log_str)

        if args.distributed:
            dist.barrier()

    # Test set (verbose process only, using the final weights)
    if self.verbose:
        self.save("LAST")

        log_str = ''
        wandb_log_dict = {}

        if 'vqa' in args.tasks:
            vqa_test_loader = self.test_loader['vqa']
            evaluator = vqa_test_loader.evaluator
            dump_path = os.path.join(args.output, 'karpathy_test_predict.json')
            quesid2ans = self.vqa_predict(vqa_test_loader, dump_path)
            wandb.save(dump_path, base_path=args.output)

            acc_dict_all = evaluator.evaluate_raw(quesid2ans)
            acc_dict_answerable = evaluator.evaluate_raw(
                quesid2ans, is_topk_optimal=True)
            acc_dict_unanswerable = evaluator.evaluate_raw(
                quesid2ans, is_topk_optimal=False)
            wandb_log_dict['VQA/Test/overall'] = acc_dict_all['overall']
            wandb_log_dict['VQA/Test/topk_optimal'] = acc_dict_answerable['overall']
            wandb_log_dict['VQA/Test/topk_not_optimal'] = acc_dict_unanswerable['overall']

            vqa_submit_test_loader = self.test_loader['vqa_submit']
            dump_path = os.path.join(args.output, 'vqa_submit.json')
            self.vqa_predict(vqa_submit_test_loader, dump_path=dump_path)
            wandb.save(dump_path, base_path=args.output)

        if 'nlvr' in args.tasks:
            dump_path = os.path.join(args.output, 'nlvr_submit.csv')
            test_score_dict = self.nlvr_evaluate(
                self.test_loader['nlvr'], dump_path=dump_path)
            wandb.save(dump_path, base_path=args.output)
            for score_name, score in test_score_dict.items():
                wandb_log_dict[f'NLVR/Test/{score_name}'] = score * 100.

        if 'refcoco' in args.tasks:
            test_score_dict = self.refcoco_evaluate(self.test_loader['refcoco'])
            test_acc = test_score_dict['n_correct'] / test_score_dict['n_total'] * 100
            test_n_correct = int(test_score_dict['n_correct'])
            test_n_total = int(test_score_dict['n_total'])
            wandb_log_dict['RefCOCO/test/Acc'] = test_acc
            # Append rather than overwrite the report for other tasks.
            log_str += 'RefCOCOg'
            log_str += f"\nTest Acc {test_acc:.2f}%"
            log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}\n"

        if 'caption' in args.tasks:
            test_results = self.caption_evaluate(self.test_loader['caption'])
            # NOTE(review): validation logs caption scores x100, test logs raw.
            for score_name, score in test_results.items():
                wandb_log_dict[f'Caption/Test/{score_name}'] = score

        if 'mmt' in args.tasks:
            mmt_test_loaders = [
                self.test_loader[key]
                for key in ('mmt_test2016', 'mmt_test2017', 'mmt_test2018')
            ]
            for loader in mmt_test_loaders:
                split = loader.dataset.source
                dump_path = os.path.join(args.output, f'submit_{split}_raw.txt')
                test_results = self.mmt_evaluate(loader, dump_path=dump_path)
                for score_name, score in test_results.items():
                    wandb_log_dict[f'MMT/{split}/{score_name}'] = score
                log_str += f'{split} set results\n'
                log_str += pformat(test_results)
                log_str += '\n'

        print(log_str)
        wandb.log(wandb_log_dict, step=args.epochs)
        wandb.log({'finished': True})

    if args.distributed:
        dist.barrier()
        # Under DDP the worker stops here once all ranks have finished.
        sys.exit()