def evaluate_epoch(self, epoch):
    """Run one pass over the validation loader and accumulate per-loss totals.

    Returns:
        (epoch_results, uid2ans): dict of summed losses / counts keyed by the
        names in ``self.args.LOSSES_NAME``, and a uid -> answer dict when the
        'qa' loss is active (otherwise None).
    """
    LOSSES_NAME = self.args.LOSSES_NAME
    # One running total plus a sample count per configured loss.
    epoch_results = {}
    for loss_name in LOSSES_NAME:
        epoch_results[loss_name] = 0.
        epoch_results[f'{loss_name}_count'] = 0
    uid2ans = {}
    self.model.eval()
    with torch.no_grad():
        if self.verbose:
            # Progress meters only exist on the verbose (main) process.
            loss_meter = LossMeter()
            loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))]
            pbar = tqdm(total=len(self.val_loader), ncols=250)
        for step_i, batch in enumerate(self.val_loader):
            # Under DDP the real module is wrapped; unwrap to reach valid_step.
            if self.args.distributed:
                results = self.model.module.valid_step(batch)
            else:
                results = self.model.valid_step(batch)
            # Collect QA predictions keyed by example uid when QA is trained.
            if 'qa' in self.args.losses:
                qa_pred = results['qa_pred']
                for uid, ans in zip(batch['uid'], qa_pred):
                    uid2ans[uid] = ans
            # Accumulate only keys we pre-registered; tensors are reduced
            # to Python floats so totals stay on the CPU.
            for k, v in results.items():
                if k in epoch_results:
                    if isinstance(v, int):
                        epoch_results[k] += v
                    elif isinstance(v, torch.Tensor):
                        epoch_results[k] += v.item()
            if self.verbose:
                desc_str = f'Valid Epoch {epoch} |'
                for i, (loss_name, loss_meter) in enumerate(zip(LOSSES_NAME, loss_meters)):
                    if loss_name in results:
                        # Mean loss for this step: summed loss / sample count.
                        loss_meter.update(results[f'{loss_name}'] / results[f'{loss_name}_count'])
                        if len(loss_meter) > 0:
                            loss_count = epoch_results[f'{loss_name}_count']
                            desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}'
                pbar.set_description(desc_str)
                pbar.update(1)
        # Keep all ranks in lockstep before/after leaving the no_grad scope.
        dist.barrier()
        if self.verbose:
            pbar.close()
    dist.barrier()
    if 'qa' not in self.args.losses:
        uid2ans = None
    return epoch_results, uid2ans
def train(self):
    """Full fine-tuning loop for the MMT (multimodal translation) task.

    Trains for ``self.args.epochs`` epochs with optional fp16 (native AMP or
    apex), gradient clipping, and gradient accumulation; validates each epoch
    on BLEU, checkpoints the best model, then evaluates the best checkpoint
    on the test split(s) and uploads results to wandb.

    Bug fix vs. previous revision: when ``self.test_loader`` was a single
    loader (not a list), the else-branch referenced an unbound name
    ``loader`` and raised NameError; it is now bound explicitly.
    """
    if self.verbose:
        loss_meter = LossMeter()
        best_valid = 0.
        best_epoch = 0
        # wandb project is derived from backbone + vision usage.
        if 't5' in self.args.backbone:
            if self.args.use_vision:
                project_name = "VLT5_MMT"
            else:
                project_name = "T5_MMT"
        elif 'bart' in self.args.backbone:
            if self.args.use_vision:
                project_name = "VLBart_MMT"
            else:
                project_name = "Bart_MMT"
        wandb.init(project=project_name)
        wandb.run.name = self.args.run_name
        wandb.config.update(self.args)
        wandb.watch(self.model)
        # Snapshot the source files alongside the run.
        src_dir = Path(__file__).resolve().parent
        base_path = str(src_dir.parent)
        src_dir = str(src_dir)
        wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)
    if self.args.distributed:
        dist.barrier()

    global_step = 0
    for epoch in range(self.args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if self.args.distributed:
            # Reshuffle shards deterministically per epoch.
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=120)

        epoch_results = {
            'loss': 0.,
        }

        for step_i, batch in enumerate(self.train_loader):
            # Forward pass (unwrap DDP module when distributed).
            if self.args.fp16 and _use_native_amp:
                with autocast():
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)
            else:
                if self.args.distributed:
                    results = self.model.module.train_step(batch)
                else:
                    results = self.model.train_step(batch)

            loss = results['loss']

            # Backward pass through whichever AMP path is active.
            if self.args.fp16 and _use_native_amp:
                self.scaler.scale(loss).backward()
            elif self.args.fp16 and _use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            loss = loss.detach()

            # Gradient clipping (unscale first under native AMP).
            if self.args.clip_grad_norm > 0:
                if self.args.fp16 and _use_native_amp:
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optim), self.args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)

            # Gradient accumulation: step only every N batches (and on the
            # final batch); step 0 only accumulates.
            update = True
            if self.args.gradient_accumulation_steps > 1:
                if step_i == 0:
                    update = False
                elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len(self.train_loader) - 1:
                    update = True
                else:
                    update = False

            if update:
                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()
                if self.lr_scheduler:
                    self.lr_scheduler.step()
                # Cheaper than optim.zero_grad(): drop grads entirely.
                for param in self.model.parameters():
                    param.grad = None
            global_step += 1

            for k, v in results.items():
                if k in epoch_results:
                    epoch_results[k] += v.item()

            # Current LR for display; get_last_lr() exists from torch 1.4.
            if self.lr_scheduler:
                if version.parse(torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr

            if self.verbose:
                loss_meter.update(loss.item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step}'
                desc_str += f' | Loss {loss_meter.val:4f}'
                pbar.set_description(desc_str)
                pbar.update(1)

        if self.verbose:
            pbar.close()

            # Validation (main process only; BLEU drives checkpointing).
            valid_results = self.evaluate(self.val_loader)
            valid_score = valid_results['BLEU']
            if valid_score > best_valid:
                best_valid = valid_score
                best_epoch = epoch
                self.save("BEST")

            log_str = ''
            log_str += pformat(valid_results)
            log_str += "\nEpoch %d: Valid BLEU %0.4f" % (epoch, valid_score)
            log_str += "\nEpoch %d: Best BLEU %0.4f\n" % (best_epoch, best_valid)

            wandb_log_dict = {}
            wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(self.train_loader)
            for score_name, score in valid_results.items():
                wandb_log_dict[f'Valid/{score_name}'] = score
            wandb.log(wandb_log_dict, step=epoch)
            print(log_str)

        if self.args.distributed:
            dist.barrier()

    if self.verbose:
        self.save("LAST")

        # Test Set: reload the best checkpoint and evaluate each test split.
        best_path = os.path.join(self.args.output, 'BEST')
        self.load(best_path)

        # Normalize to a list so single- and multi-split cases share one path.
        # (Previously the single-loader branch referenced an unbound `loader`.)
        if isinstance(self.test_loader, list):
            test_loaders = self.test_loader
        else:
            test_loaders = [self.test_loader]

        for loader in test_loaders:
            split = loader.dataset.source
            dump_path = os.path.join(self.args.output, f'submit_{split}_raw.txt')
            test_results = self.evaluate(loader, dump_path=dump_path)
            wandb_log_dict = {}
            for score_name, score in test_results.items():
                wandb_log_dict[f'{split}/{score_name}'] = score
            wandb.log(wandb_log_dict, step=epoch)
            log_str = f'{split} set results\n'
            log_str += pformat(test_results)
            print(log_str)
            wandb.save(dump_path, base_path=self.args.output)
            print('\nUploaded', dump_path)

        wandb.log({'finished': True})
def train(self):
    """Fine-tuning loop for RefCOCOg referring-expression grounding.

    Logs oracle accuracies up front, trains with optional fp16 and gradient
    clipping, optionally tracks train accuracy, validates each epoch, keeps
    the best checkpoint, then reports test-set accuracy via wandb.
    """
    if self.verbose:
        # Oracle accuracy: fraction of examples whose target region exists
        # among the candidate boxes — an upper bound on achievable accuracy.
        n_correct = 0
        n_total = 0
        for batch in self.val_loader:
            exists_target = batch['exists_target']
            n_correct += exists_target.sum().item()
            n_total += len(exists_target)
        print(f'Val Oracle acc: {n_correct / n_total * 100:.2f}%')
        n_correct = 0
        n_total = 0
        for batch in self.test_loader:
            exists_target = batch['exists_target']
            n_correct += exists_target.sum().item()
            n_total += len(exists_target)
        print(f'Test Oracle acc: {n_correct / n_total * 100:.2f}%')
    if self.verbose:
        loss_meter = LossMeter()
        best_valid_acc = 0.
        best_epoch = 0
        # wandb project derived from backbone; '_GT' suffix for GT boxes.
        if 't5' in self.args.backbone:
            project_name = "VLT5_RefCOCOg"
        elif 'bart' in self.args.backbone:
            project_name = "VLBart_RefCOCOg"
        if self.args.RefCOCO_GT:
            project_name += '_GT'
        wandb.init(project=project_name)
        wandb.run.name = self.args.run_name
        wandb.config.update(self.args)
        wandb.watch(self.model)
        # Snapshot source files alongside the run.
        src_dir = Path(__file__).resolve().parent
        base_path = str(src_dir.parent)
        src_dir = str(src_dir)
        wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)
    if self.args.distributed:
        dist.barrier()
    global_step = 0
    for epoch in range(self.args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if self.args.distributed:
            # Deterministic per-epoch reshuffle across ranks.
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=120)
        epoch_results = {
            'loss': 0,
        }
        n_correct = 0
        n_total = 0
        for step_i, batch in enumerate(self.train_loader):
            # Tell train_step whether to also compute per-example correctness.
            batch['log_train_accuracy'] = self.args.log_train_accuracy
            if self.args.fp16 and _use_native_amp:
                with autocast():
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)
            else:
                if self.args.distributed:
                    results = self.model.module.train_step(batch)
                else:
                    results = self.model.train_step(batch)
            loss = results['loss']
            # Backward via whichever AMP path is active.
            if self.args.fp16 and _use_native_amp:
                self.scaler.scale(loss).backward()
            elif self.args.fp16 and _use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            loss = loss.detach()
            # Update Parameters
            if self.args.clip_grad_norm > 0:
                if self.args.fp16 and _use_native_amp:
                    # Must unscale before clipping under native AMP.
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optim), self.args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
            if self.args.fp16 and _use_native_amp:
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                self.optim.step()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            # Drop grads instead of zeroing (cheaper than optim.zero_grad()).
            for param in self.model.parameters():
                param.grad = None
            global_step += 1
            for k, v in results.items():
                if k in epoch_results:
                    epoch_results[k] += v.item()
            # Current LR for display only.
            if self.lr_scheduler:
                if version.parse(
                        torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr
            if self.args.log_train_accuracy:
                correct = results['correct']
                n_correct += sum(correct)
                n_total += len(correct)
            if self.verbose:
                loss_meter.update(loss.item())
                # acc_meter.update(results['acc'].item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                desc_str += f'Loss {loss_meter.val:.3f} |'
                # desc_str += f' Acc {acc_meter.val:.3f} |'
                if self.args.log_train_accuracy:
                    desc_str += f' Correct {n_correct:.0f}'
                    desc_str += f' (Acc {n_correct/n_total*100:.1f}%)'
                pbar.set_description(desc_str)
                pbar.update(1)
        if self.args.distributed:
            dist.barrier()
        if self.verbose:
            pbar.close()
        if self.args.log_train_accuracy:
            # All-reduce train counts across ranks before reporting.
            train_score_dict = {'n_correct': n_correct, 'n_total': n_total}
            train_score_dict = reduce_dict(train_score_dict, self.args.gpu)
        # Validation
        # valid_score_dict = self.evaluate(self.val_loader)
        # valid_score_dict = reduce_dict(valid_score_dict, self.args.gpu)
        if self.verbose:
            if self.args.log_train_accuracy:
                train_acc = train_score_dict[
                    'n_correct'] / train_score_dict['n_total'] * 100
                train_n_correct = int(train_score_dict['n_correct'])
                train_n_total = int(train_score_dict['n_total'])
            # Validation (main process only).
            valid_score_dict = self.evaluate(self.val_loader)
            valid_acc = valid_score_dict['n_correct'] / valid_score_dict[
                'n_total'] * 100
            valid_n_correct = int(valid_score_dict['n_correct'])
            valid_n_total = int(valid_score_dict['n_total'])
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                best_epoch = epoch
                self.save("BEST")
            log_str = ''
            if self.args.log_train_accuracy:
                log_str += f"\nEpoch {epoch}: Train"
                log_str += f" Acc {train_acc:.2f}% |"
                log_str += f" # correct {train_n_correct} # total {train_n_total}"
            log_str += f"\nEpoch {epoch}: Valid"
            log_str += f" Acc {valid_acc:.2f}% |"
            log_str += f" # correct {valid_n_correct} # total {valid_n_total}"
            log_str += f"\nEpoch {best_epoch}: Best Acc {best_valid_acc:.2f}%\n"
            wandb_log_dict = {}
            if self.args.log_train_accuracy:
                wandb_log_dict['Train/Acc'] = train_acc
            wandb_log_dict['Valid/Acc'] = valid_acc
            wandb.log(wandb_log_dict, step=epoch)
            print(log_str)
        if self.args.distributed:
            dist.barrier()
    if self.verbose:
        self.save("LAST")
        # Test Set: evaluate the best checkpoint.
        best_path = os.path.join(self.args.output, 'BEST')
        self.load(best_path)
        test_score_dict = self.evaluate(self.test_loader)
        test_acc = test_score_dict['n_correct'] / test_score_dict[
            'n_total'] * 100
        test_n_correct = int(test_score_dict['n_correct'])
        test_n_total = int(test_score_dict['n_total'])
        wandb_log_dict = {}
        wandb_log_dict['Test/Acc'] = test_acc
        wandb.log(wandb_log_dict, step=epoch)
        log_str = ''
        log_str += f"\nTest Acc {test_acc:.2f}%"
        log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}"
        print(log_str)
        wandb.log({'finished': True})
    if self.args.distributed:
        dist.barrier()
def train(cont=False):
    """Train a semantic-segmentation model (SETR / TransUNet / UNet) on Cityscapes.

    Args:
        cont: when True, resume from ``best_ckpt_src`` (model, optimizer,
            epoch, best loss) and fast-forward the LR scheduler; otherwise
            start fresh under nn.DataParallel.

    Relies on module-level config globals (net, data_dir, batch_size, lrate,
    epoch_num, iteration_num, print_freq, tensorboard_freq, etc.).
    Checkpoints to ``ckpt_src`` whenever validation loss improves; stops
    early after ``early_stop_tolerance`` epochs without improvement.
    """
    # for tensorboard tracking
    logger = get_logger()
    logger.info("(1) Initiating Training ... ")
    logger.info("Training on device: {}".format(device))
    writer = SummaryWriter()
    # init model — SETR variants also return auxiliary heads.
    aux_layers = None
    if net == "SETR-PUP":
        aux_layers, model = get_SETR_PUP()
    elif net == "SETR-MLA":
        aux_layers, model = get_SETR_MLA()
    elif net == "TransUNet-Base":
        model = get_TransUNet_base()
    elif net == "TransUNet-Large":
        model = get_TransUNet_large()
    elif net == "UNet":
        model = UNet(CLASS_NUM)
    # prepare dataset
    cluster_model = get_clustering_model(logger)
    train_dataset = CityscapeDataset(img_dir=data_dir, img_dim=IMG_DIM,
                                     mode="train", cluster_model=cluster_model)
    valid_dataset = CityscapeDataset(img_dir=data_dir, img_dim=IMG_DIM,
                                     mode="val", cluster_model=cluster_model)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    logger.info("(2) Dataset Initiated. ")
    # optimizer — epoch count either fixed or derived from iteration budget.
    epochs = epoch_num if epoch_num > 0 else iteration_num // len(
        train_loader) + 1
    optim = SGD(model.parameters(), lr=lrate, momentum=momentum,
                weight_decay=wdecay)
    # optim = Adam(model.parameters(), lr=lrate)
    scheduler = lr_scheduler.MultiStepLR(
        optim, milestones=[int(epochs * fine_tune_ratio)], gamma=0.1)
    cur_epoch = 0
    best_loss = float('inf')
    epochs_since_improvement = 0
    # for continue training — replay scheduler steps to restore LR state.
    if cont:
        model, optim, cur_epoch, best_loss = load_ckpt_continue_training(
            best_ckpt_src, model, optim, logger)
        logger.info("Current best loss: {0}".format(best_loss))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i in range(cur_epoch):
                scheduler.step()
    else:
        model = nn.DataParallel(model)
        model = model.to(device)
    logger.info("(3) Model Initiated ... ")
    logger.info("Training model: {}".format(net) + ". Training Started.")
    # loss
    ce_loss = CrossEntropyLoss()
    if use_dice_loss:
        dice_loss = DiceLoss(CLASS_NUM)
    # loop over epochs
    iter_count = 0
    epoch_bar = tqdm.tqdm(total=epochs, desc="Epoch", position=cur_epoch,
                          leave=True)
    logger.info("Total epochs: {0}. Starting from epoch {1}.".format(
        epochs, cur_epoch + 1))
    for e in range(epochs - cur_epoch):
        epoch = e + cur_epoch
        # Training.
        model.train()
        trainLossMeter = LossMeter()
        train_batch_bar = tqdm.tqdm(total=len(train_loader),
                                    desc="TrainBatch", position=0, leave=True)
        for batch_num, (orig_img, mask_img) in enumerate(train_loader):
            orig_img, mask_img = orig_img.float().to(
                device), mask_img.float().to(device)
            # SETR variants return (pred, aux) when aux heads exist.
            if net == "TransUNet-Base" or net == "TransUNet-Large":
                pred = model(orig_img)
            elif net == "SETR-PUP" or net == "SETR-MLA":
                if aux_layers is not None:
                    pred, _ = model(orig_img)
                else:
                    pred = model(orig_img)
            elif net == "UNet":
                pred = model(orig_img)
            loss_ce = ce_loss(pred, mask_img[:].long())
            if use_dice_loss:
                # Equal-weight blend of CE and Dice.
                loss_dice = dice_loss(pred, mask_img, softmax=True)
                loss = 0.5 * (loss_ce + loss_dice)
            else:
                loss = loss_ce
            # Backward Propagation, Update weight and metrics
            optim.zero_grad()
            loss.backward()
            optim.step()
            # update learning rate — poly-style factor applied per iteration.
            # NOTE(review): the factor multiplies the *current* (already
            # decayed) LR each step, so the decay compounds, and it also
            # stacks with scheduler.step() below — looks unintended; confirm
            # against the paper's poly-LR schedule before changing.
            for param_group in optim.param_groups:
                orig_lr = param_group['lr']
                param_group['lr'] = orig_lr * (1.0 - iter_count / iteration_num)**0.9
            iter_count += 1
            # Update loss
            trainLossMeter.update(loss.item())
            # print status
            if (batch_num + 1) % print_freq == 0:
                status = 'Epoch: [{0}][{1}/{2}]\t' \
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(train_loader), loss=trainLossMeter)
                logger.info(status)
            # log loss to tensorboard
            if (batch_num + 1) % tensorboard_freq == 0:
                writer.add_scalar(
                    'Train_Loss_{0}'.format(tensorboard_freq),
                    trainLossMeter.avg,
                    epoch * (len(train_loader) / tensorboard_freq) + (batch_num + 1) / tensorboard_freq)
            train_batch_bar.update(1)
        writer.add_scalar('Train_Loss_epoch', trainLossMeter.avg, epoch)
        # Validation.
        model.eval()
        validLossMeter = LossMeter()
        valid_batch_bar = tqdm.tqdm(total=len(valid_loader),
                                    desc="ValidBatch", position=0, leave=True)
        with torch.no_grad():
            for batch_num, (orig_img, mask_img) in enumerate(valid_loader):
                orig_img, mask_img = orig_img.float().to(
                    device), mask_img.float().to(device)
                if net == "TransUNet-Base" or net == "TransUNet-Large":
                    pred = model(orig_img)
                elif net == "SETR-PUP" or net == "SETR-MLA":
                    if aux_layers is not None:
                        pred, _ = model(orig_img)
                    else:
                        pred = model(orig_img)
                elif net == "UNet":
                    pred = model(orig_img)
                loss_ce = ce_loss(pred, mask_img[:].long())
                if use_dice_loss:
                    loss_dice = dice_loss(pred, mask_img, softmax=True)
                    loss = 0.5 * (loss_ce + loss_dice)
                else:
                    loss = loss_ce
                # Update loss
                validLossMeter.update(loss.item())
                # print status
                if (batch_num + 1) % print_freq == 0:
                    status = 'Validation: [{0}][{1}/{2}]\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(valid_loader), loss=validLossMeter)
                    logger.info(status)
                # log loss to tensorboard
                if (batch_num + 1) % tensorboard_freq == 0:
                    writer.add_scalar(
                        'Valid_Loss_{0}'.format(tensorboard_freq),
                        validLossMeter.avg,
                        epoch * (len(valid_loader) / tensorboard_freq) + (batch_num + 1) / tensorboard_freq)
                valid_batch_bar.update(1)
        valid_loss = validLossMeter.avg
        writer.add_scalar('Valid_Loss_epoch', valid_loss, epoch)
        logger.info("Validation Loss of epoch [{0}/{1}]: {2}\n".format(
            epoch + 1, epochs, valid_loss))
        # update optim scheduler
        scheduler.step()
        # save checkpoint only on improvement; count stagnant epochs otherwise.
        is_best = valid_loss < best_loss
        best_loss_tmp = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            logger.info("Epochs since last improvement: %d\n" %
                        (epochs_since_improvement, ))
            if epochs_since_improvement == early_stop_tolerance:
                break  # early stopping.
        else:
            epochs_since_improvement = 0
            state = {
                'epoch': epoch,
                'loss': best_loss_tmp,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
            }
            torch.save(state, ckpt_src)
            logger.info("Checkpoint updated.")
            best_loss = best_loss_tmp
        epoch_bar.update(1)
    writer.close()
def train(self):
    """Train the discriminative (BCE-over-answers) GQA head.

    Each epoch: forward batches through the model, accumulate gradients with
    optional update-frequency gating, track train accuracy via the evaluator,
    validate, and checkpoint the best model. Metrics go to TensorBoard.

    Bug fix vs. previous revision: the LR readout used ``self.scheduler``,
    which does not exist (the attribute used everywhere else — including the
    ``.step()`` call in this very loop — is ``self.lr_scheduler``), so the
    AttributeError fallback silently displayed the constant ``args.lr``.
    """
    if self.verbose:
        loss_meter = LossMeter()
        quesid2ans = {}
        best_valid = 0.
        # Oracle score: upper bound given answer-vocabulary coverage.
        print("Valid Oracle: %0.2f" % (self.oracle_score(self.val_loader) * 100))

        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=self.args.log_dir)
        # Record scalar-compatible hyperparameters only.
        hparam_dict = {}
        for k, v in self.args.__dict__.items():
            if type(v) in [int, float, str, bool, torch.Tensor]:
                hparam_dict[k] = v
        metric_dict = {}
        self.writer.add_hparams(hparam_dict, metric_dict)

    if self.args.distributed:
        dist.barrier()

    self.optim.zero_grad()
    for epoch in range(self.args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if self.args.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=150)

        quesid2ans = {}
        for step_i, batch in enumerate(self.train_loader):
            # Gradient accumulation: only step the optimizer every
            # update_freq batches (and on the final batch).
            update = True
            if self.args.update_freq > 1:
                if step_i == 0:
                    update = False
                elif step_i % self.args.update_freq == 0 or step_i == len(
                        self.train_loader) - 1:
                    update = True
                else:
                    update = False

            vis_feats = batch['vis_feats'].cuda()
            boxes = batch['boxes'].cuda()
            input_ids = batch['word_ids'].cuda()
            target = batch['targets'].cuda()
            ques_id = batch['question_ids']
            B = len(batch['word_ids'])

            results = self.model(
                input_ids=input_ids,
                visual_feats=vis_feats,
                visual_pos=boxes,
                attention_mask=input_ids > 0,  # non-pad token mask
            )
            logit = results['logit']
            assert logit.size() == target.size()
            assert logit.size() == (B, self.num_answers)
            loss = self.bce_loss(logit, target)

            loss.backward()
            if update:
                if not self.args.no_clip_grad:
                    nn.utils.clip_grad_norm_(self.model.parameters(),
                                             self.args.clip_grad_norm)
                self.optim.step()
                self.lr_scheduler.step()
                # Drop grads (cheaper than zero_grad()).
                for param in self.model.parameters():
                    param.grad = None

            # LR readout for the progress bar (fixed: was self.scheduler).
            try:
                lr = self.lr_scheduler.get_last_lr()[0]
            except AttributeError:
                lr = self.args.lr

            if self.verbose:
                loss_meter.update(loss.item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                desc_str += f'Loss {loss_meter.val:4f} |'

                # Record argmax predictions for train-accuracy reporting.
                score, predict = logit.max(1)
                predict = predict.cpu().numpy()
                target = target.cpu().numpy()
                for qid, pred in zip(ques_id, predict):
                    pred_ans = self.train_loader.dataset.raw_dataset.label2ans[pred]
                    quesid2ans[qid] = pred_ans

                pbar.set_description(desc_str)
                pbar.update(1)
            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            pbar.close()
            score = self.train_loader.evaluator.evaluate(quesid2ans) * 100.
            log_str = "\nEpoch %d: Train %0.2f" % (epoch, score)
            if not self.args.dry:
                self.writer.add_scalar(f'GQA/Train/score', score, epoch)

            # Validation
            valid_score = self.evaluate(self.val_loader) * 100.
            if valid_score > best_valid:
                best_valid = valid_score
                self.save("BEST")
            log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
            log_str += "\nEpoch %d: Best %0.2f\n" % (epoch, best_valid)
            if not self.args.dry:
                self.writer.add_scalar(f'GQA/Valid/score', valid_score, epoch)
            print(log_str)
            self.logger.info(log_str)

        if self.args.distributed:
            dist.barrier()

    if self.verbose:
        self.save("LAST")
def train(self):
    """Fine-tuning loop for VQA with generative answers.

    Trains with optional fp16 and gradient clipping, validates each epoch
    (checkpointing on raw accuracy), then evaluates the best checkpoint on
    the test split and optionally produces a submission file. Exits the
    process at the end (distributed workers terminate together).
    """
    if self.verbose:
        loss_meter = LossMeter()
        best_valid = 0.
        best_epoch = 0
        # wandb project derived from backbone + vision usage.
        if 't5' in self.args.backbone:
            if self.args.use_vision:
                project_name = "VLT5_VQA"
            else:
                project_name = "T5_VQA"
        elif 'bart' in self.args.backbone:
            if self.args.use_vision:
                project_name = "VLBart_VQA"
            else:
                project_name = "Bart_VQA"
        wandb.init(project=project_name)
        wandb.run.name = self.args.run_name
        wandb.config.update(self.args)
        wandb.watch(self.model)
        # Snapshot source files alongside the run.
        src_dir = Path(__file__).resolve().parent
        base_path = str(src_dir.parent)
        src_dir = str(src_dir)
        wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)
    if self.args.distributed:
        dist.barrier()
    global_step = 0
    for epoch in range(self.args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if self.args.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=120)
        epoch_results = {
            'loss': 0.,
        }
        quesid2ans = {}
        for step_i, batch in enumerate(self.train_loader):
            # Forward (unwrap the DDP module when distributed).
            if self.args.fp16 and _use_native_amp:
                with autocast():
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)
            else:
                if self.args.distributed:
                    results = self.model.module.train_step(batch)
                else:
                    results = self.model.train_step(batch)
            loss = results['loss']
            # Backward via whichever AMP path is active.
            if self.args.fp16 and _use_native_amp:
                self.scaler.scale(loss).backward()
            elif self.args.fp16 and _use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            loss = loss.detach()
            # Update Parameters
            if self.args.clip_grad_norm > 0:
                if self.args.fp16 and _use_native_amp:
                    # Must unscale before clipping under native AMP.
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(
                        self.optim), self.args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
            if self.args.fp16 and _use_native_amp:
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                self.optim.step()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            # Drop grads (cheaper than optim.zero_grad()).
            for param in self.model.parameters():
                param.grad = None
            global_step += 1
            for k, v in results.items():
                if k in epoch_results:
                    epoch_results[k] += v.item()
            # Current LR for display only.
            if self.lr_scheduler:
                if version.parse(torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr
            if self.verbose:
                loss_meter.update(loss.item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f}'
                desc_str += f' | Loss {loss_meter.val:4f}'
                pbar.set_description(desc_str)
                pbar.update(1)
        if self.args.distributed:
            dist.barrier()
        if self.verbose:
            pbar.close()
        # Validation (all ranks run evaluate; only verbose rank reports).
        score_dict = self.evaluate(self.val_loader)
        if self.verbose:
            valid_score = score_dict['topk_score'] * 100.
            valid_score_raw = score_dict['overall']
            # Checkpoint on raw score; epoch 0 always saves a baseline.
            if valid_score_raw > best_valid or epoch == 0:
                best_valid = valid_score_raw
                best_epoch = epoch
                self.save("BEST")
            log_str = ''
            log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % (epoch, valid_score_raw, valid_score)
            log_str += "\nEpoch %d: Best Raw %0.2f\n" % (best_epoch, best_valid)
            wandb_log_dict = {}
            wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(self.train_loader)
            wandb_log_dict['Valid/score'] = valid_score
            wandb_log_dict['Valid/raw_score'] = score_dict['overall']
            for qtype, score in score_dict['perQuestionType'].items():
                wandb_log_dict[f'Valid_Qtypes/{qtype}'] = score
            for atype, score in score_dict['perAnswerType'].items():
                # 'yes/no' would create a nested wandb namespace; flatten it.
                if atype == 'yes/no':
                    atype = 'yes_no'
                wandb_log_dict[f'Valid_Atypes/{atype}'] = score
            wandb.log(wandb_log_dict, step=epoch)
            print(log_str)
        if self.args.distributed:
            dist.barrier()
    if self.verbose:
        self.save("LAST")
    # Test Set: reload best checkpoint and predict on the test split.
    best_path = os.path.join(self.args.output, 'BEST')
    self.load(best_path)
    quesid2ans = self.predict(self.test_loader)
    if self.verbose:
        evaluator = self.test_loader.evaluator
        score_dict = evaluator.evaluate(quesid2ans)
        evaluator.dump_result(quesid2ans)
        acc_dict_all = evaluator.evaluate_raw(quesid2ans)
        acc_dict_answerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=True)
        acc_dict_unanswerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=False)
        wandb_log_dict = {}
        wandb_log_dict['Test/overall'] = acc_dict_all['overall']
        wandb_log_dict['Test/topk_optimal'] = acc_dict_answerable['overall']
        wandb_log_dict['Test/topk_not_optimal'] = acc_dict_unanswerable['overall']
        for qtype, score in acc_dict_all['perQuestionType'].items():
            wandb_log_dict[f'Test_Qtypes/{qtype}'] = score
        for atype, score in acc_dict_all['perAnswerType'].items():
            if atype == 'yes/no':
                atype = 'yes_no'
            wandb_log_dict[f'Test_Atypes/{atype}'] = score
        print(wandb_log_dict)
        wandb.log(wandb_log_dict)
        # Optional leaderboard submission file.
        if self.args.submit:
            dump_path = os.path.join(self.args.output, 'submit.json')
            self.predict(self.submit_test_loader, dump_path)
            wandb.save(dump_path, base_path=self.args.output)
        wandb.log({'finished': True})
    if self.args.distributed:
        dist.barrier()
        exit()
def train(self):
    """Fine-tuning loop for GQA with generative answers (VL-T5/VL-BART).

    Trains with optional fp16 and gradient clipping, validates on testdev
    each epoch (checkpointing the best), then writes a submission file from
    the best checkpoint and exits the process.
    """
    if self.verbose:
        loss_meter = LossMeter()
        best_valid = 0.
        best_epoch = 0
        # wandb project derived from backbone + vision usage.
        if 't5' in self.args.backbone:
            if self.args.use_vision:
                project_name = "VLT5_GQA"
            else:
                project_name = "T5_GQA"
        elif 'bart' in self.args.backbone:
            if self.args.use_vision:
                project_name = "VLBart_GQA"
            else:
                project_name = "Bart_GQA"
        wandb.init(project=project_name)
        wandb.run.name = self.args.run_name
        wandb.config.update(self.args)
        wandb.watch(self.model)
        # Snapshot source files alongside the run.
        src_dir = Path(__file__).resolve().parent
        base_path = str(src_dir.parent)
        src_dir = str(src_dir)
        wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)
    if self.args.distributed:
        dist.barrier()
    # torch.autograd.set_detect_anomaly(True)
    # print(f'GPU{self.args.gpu} before training starts')
    global_step = 0
    for epoch in range(self.args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if self.args.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=120)
        epoch_results = {
            'loss': 0.,
        }
        quesid2ans = {}
        # NOTE(review): train_acc / train_acc_steps / last_acc_step are
        # never read below — dead leftovers from a removed accuracy probe.
        train_acc = 0.
        train_acc_steps = int(len(self.train_loader) * 0.05)
        last_acc_step = 0
        # print(f'GPU{self.args.gpu} before training loop')
        for step_i, batch in enumerate(self.train_loader):
            # print(f'GPU{self.args.gpu} inside training loop')
            # print(batch)
            # self.optim.zero_grad()
            # Forward (unwrap the DDP module when distributed).
            if self.args.fp16 and _use_native_amp:
                with autocast():
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)
            else:
                if self.args.distributed:
                    results = self.model.module.train_step(batch)
                else:
                    results = self.model.train_step(batch)
            loss = results['loss']
            # print(f'GPU{self.args.gpu} after loss')
            # Backward via whichever AMP path is active.
            if self.args.fp16 and _use_native_amp:
                self.scaler.scale(loss).backward()
            elif self.args.fp16 and _use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            # print(f'GPU{self.args.gpu} after backward')
            loss = loss.detach()
            # Update Parameters
            if self.args.clip_grad_norm > 0:
                if self.args.fp16 and _use_native_amp:
                    # Must unscale before clipping under native AMP.
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optim), self.args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
            if self.args.fp16 and _use_native_amp:
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                self.optim.step()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            # self.model.zero_grad()
            # Drop grads (cheaper than optim.zero_grad()).
            for param in self.model.parameters():
                param.grad = None
            global_step += 1
            for k, v in results.items():
                if k in epoch_results:
                    epoch_results[k] += v.item()
            # Current LR for display only.
            if self.lr_scheduler:
                if version.parse(
                        torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr
            if self.verbose:
                loss_meter.update(loss.item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f}'
                desc_str += f' | Loss {loss_meter.val:4f}'
                pbar.set_description(desc_str)
                pbar.update(1)
        if self.args.distributed:
            dist.barrier()
        if self.verbose:
            pbar.close()
            log_str = ''
            # Validation (scored on testdev; checkpoint best).
            valid_score = self.evaluate(self.val_loader) * 100.
            if valid_score > best_valid:
                best_valid = valid_score
                best_epoch = epoch
                self.save("BEST")
            log_str += "\nEpoch %d: Testdev %0.2f" % (epoch, valid_score)
            log_str += "\nEpoch %d: Best %0.2f\n" % (best_epoch, best_valid)
            wandb_log_dict = {}
            wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(
                self.train_loader)
            wandb_log_dict['Testdev/score'] = valid_score
            wandb.log(wandb_log_dict, step=epoch)
            print(log_str)
        if self.args.distributed:
            dist.barrier()
    if self.verbose:
        self.save("LAST")
        # Test Set: write submission from the best checkpoint.
        best_path = os.path.join(self.args.output, 'BEST')
        self.load(best_path)
        dump_path = os.path.join(self.args.output, 'submit.json')
        self.predict(self.test_loader, dump_path=dump_path)
        wandb.save(dump_path, base_path=self.args.output)
        wandb.log({'finished': True})
    if self.args.distributed:
        dist.barrier()
        exit()
def evaluate_epoch(self, epoch):
    """Run one validation pass over ``self.val_loader`` for the pretraining tasks.

    Tasks are cycled round-robin over ``self.args.MASK_MODALITY`` (one task per
    batch, chosen by ``step_i % len(MASK_MODALITY)``), so each loss is only
    evaluated on a slice of the loader.

    Returns:
        (epoch_results, uid2ans): accumulated per-loss sums/counts, and a
        ``uid -> answer string`` dict for the QA task (``None`` when
        ``args.task_qa`` is off).
    """
    LOSSES_NAME = self.args.LOSSES_NAME

    # Maps each display loss name to the batch-level task that produces it.
    task_dict = {
        'Mask_LM': 'word_mask',
        'Matched': 'matched',
        'Mask_Obj': 'vis_mask',
        'Mask_Attr': 'vis_mask',
        'Mask_Feat': 'vis_mask',
        'QA': 'qa'
    }

    epoch_results = {
        'lm_loss': 0,
        'vis_loss': 0,
        'matched_loss': 0,
        'qa_loss': 0,
        'obj_loss': 0,
        'feat_loss': 0,
        'attr_loss': 0,
    }
    # One counter per loss so averages can be formed later
    # (each loss is only updated on its own task's batches).
    for k in list(epoch_results.keys()):
        if k[-4:] == 'loss':
            epoch_results[f'{k}_count'] = 0

    uid2ans = {}

    self.model.eval()
    with torch.no_grad():
        if self.verbose:
            loss_meter = LossMeter()
            loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))]
            loss_counts = [0 for _ in range(len(LOSSES_NAME))]
            pbar = tqdm(total=len(self.val_loader), ncols=180)

        for step_i, batch in enumerate(self.val_loader):
            # Deterministic round-robin over tasks (was random.choice).
            # task = random.choice(self.args.MASK_MODALITY)
            task_i = step_i % len(self.args.MASK_MODALITY)
            task = self.args.MASK_MODALITY[task_i]
            if self.args.vis_mask_COCO_only or self.args.vis_mask_COCOVG_only:
                # Visual masking restricted to COCO(-VG) captions: swap in the
                # COCO-specific text (and cluster ids when clustering is on).
                if task == 'vis_mask':
                    batch['word_id'] = batch['COCO_word_id']
                    if self.args.clustering:
                        batch['cluster_id'] = batch['COCO_cluster_id']

            # self.forward is a project method dispatching on `task`
            # (defined elsewhere in this trainer).
            results = self.forward(batch, task)

            # Book-keep how many batches contributed to each loss.
            if task == 'vis_mask':
                epoch_results['vis_loss_count'] += 1
                if 'Mask_Obj' in LOSSES_NAME:
                    epoch_results['obj_loss_count'] += 1
                if 'Mask_Feat' in LOSSES_NAME:
                    epoch_results['feat_loss_count'] += 1
                if 'Mask_Attr' in LOSSES_NAME:
                    epoch_results['attr_loss_count'] += 1
            elif task == 'word_mask':
                epoch_results['lm_loss_count'] += 1
            elif task == 'matched':
                epoch_results['matched_loss_count'] += 1
            elif task == 'qa':
                epoch_results['qa_loss_count'] += 1

            if self.args.task_qa:
                # Decode predicted answer ids into answer strings via the
                # training set's answer table.
                qa_pred = results['qa_pred']
                for uid, ans_id in zip(batch['uid'], qa_pred.cpu().numpy()):
                    ans = self.train_loader.dataset.answer_table.id2ans(
                        ans_id)
                    uid2ans[uid] = ans

            for k, v in results.items():
                if k in epoch_results:
                    epoch_results[k] += v.item()

            if self.verbose:
                desc_str = f'Valid Epoch {epoch} | '
                if self.args.task_qa:
                    # QA loss is tracked on every batch (it runs regardless of
                    # the sampled task); its meter is the last in the list.
                    loss_meter = loss_meters[-1]
                    loss_meter.update(results['qa_loss'].item())
                    loss_counts[-1] += 1
                for i, (loss_name, loss_meter) in enumerate(
                        zip(LOSSES_NAME, loss_meters)):
                    if task_dict[loss_name] == task:
                        if task == 'vis_mask':
                            if loss_name == 'Mask_Obj':
                                loss_meter.update(
                                    results['obj_loss'].item())
                            elif loss_name == 'Mask_Attr':
                                loss_meter.update(
                                    results['attr_loss'].item())
                            elif loss_name == 'Mask_Feat':
                                loss_meter.update(
                                    results['feat_loss'].item())
                        elif task == 'word_mask':
                            loss_meter.update(results['lm_loss'].item())
                        elif task == 'matched':
                            loss_meter.update(
                                results['matched_loss'].item())
                        # QA already handled above, outside this loop.
                        loss_counts[i] += 1
                    if len(loss_meter) > 0:
                        loss_count = loss_counts[i]
                        if loss_name in [
                                'Mask_LM', 'Matched', 'Mask_Obj',
                                'Mask_Attr', 'Mask_Feat', 'QA'
                        ]:
                            desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.2f}'
                        else:
                            desc_str += f' {loss_name} {loss_meter.val:.2f}'
                pbar.set_description(desc_str)
                pbar.update(1)

        # NOTE(review): barrier is unconditional here (train() guards it with
        # args.distributed) — presumably this path only runs distributed; verify.
        dist.barrier()

        if self.verbose:
            pbar.close()

        dist.barrier()

        if not self.args.task_qa:
            uid2ans = None

        return epoch_results, uid2ans
def train(self):
    """Train the VCR model (joint Q->A and QA->R heads).

    Runs ``args.epochs`` epochs with optional fp16 (native AMP or apex),
    gradient clipping and gradient accumulation; validates after each epoch,
    keeps the best checkpoint by Q->AR accuracy, then dumps test-set
    predictions from the best checkpoint. Calls ``exit()`` at the end.
    W&B logging and progress bars run only on the verbose (main) process.
    """
    if self.verbose:
        loss_meter = LossMeter()
        qa_loss_meter = LossMeter()
        qar_loss_meter = LossMeter()
        best_valid_Q_AR = 0.
        best_epoch = 0

        # W&B project name derives from backbone + vision usage.
        if 't5' in self.args.backbone:
            if self.args.use_vision:
                project_name = "VLT5_VCR"
            else:
                project_name = "T5_VCR"
        elif 'bart' in self.args.backbone:
            if self.args.use_vision:
                project_name = "VLBart_VCR"
            else:
                project_name = "Bart_VCR"

        wandb.init(project=project_name)
        wandb.run.name = self.args.run_name
        wandb.config.update(self.args)
        wandb.watch(self.model)

        # Snapshot the source files of this directory into the W&B run.
        src_dir = Path(__file__).resolve().parent
        base_path = str(src_dir.parent)
        src_dir = str(src_dir)
        wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

    if self.args.distributed:
        dist.barrier()

    global_step = 0
    for epoch in range(self.args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if self.args.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=200)

        epoch_results = {
            'loss': 0,
        }

        # Per-epoch accuracy counters (Q->A, QA->R, Q->AR).
        Q_A_results = 0
        QA_R_results = 0
        Q_AR_results = 0
        n_total = 0

        # Running sums between optimizer updates (gradient accumulation).
        n_accu = 0
        train_loss = 0
        train_qa_loss = 0
        train_qar_loss = 0

        for step_i, batch in enumerate(self.train_loader):
            batch['log_train_accuracy'] = self.args.log_train_accuracy

            if self.args.fp16 and _use_native_amp:
                with autocast():
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)
            else:
                if self.args.distributed:
                    results = self.model.module.train_step(batch)
                else:
                    results = self.model.train_step(batch)

            loss = results['loss']
            # Scale the loss so accumulated gradients average correctly.
            if self.args.gradient_accumulation_steps > 1:
                loss /= self.args.gradient_accumulation_steps

            if self.args.fp16 and _use_native_amp:
                self.scaler.scale(loss).backward()
            elif self.args.fp16 and _use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Gradient clipping (must unscale first under native AMP).
            if self.args.clip_grad_norm > 0:
                if self.args.fp16 and _use_native_amp:
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optim),
                        self.args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)

            # Step only every gradient_accumulation_steps batches (and on the
            # final batch of the epoch).
            update = True
            if self.args.gradient_accumulation_steps > 1:
                update = ((step_i + 1) % self.args.gradient_accumulation_steps
                          == 0) or (step_i == len(self.train_loader) - 1)
            n_accu += 1

            if update:
                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                # Cheaper than optim.zero_grad(): drop grads entirely.
                for param in self.model.parameters():
                    param.grad = None
                global_step += 1

            for k, v in results.items():
                if k in epoch_results:
                    epoch_results[k] += v.item()

            # Current LR for display only.
            if self.lr_scheduler:
                if version.parse(
                        torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr

            if self.verbose:
                train_loss += loss.detach().item()
                train_qa_loss += results['qa_loss'].item(
                ) / self.args.gradient_accumulation_steps
                train_qar_loss += results['qar_loss'].item(
                ) / self.args.gradient_accumulation_steps

                if self.args.log_train_accuracy:
                    qa_pred = results['qa_pred']
                    qar_pred = results['qar_pred']

                    a_labels = batch['answer_labels']
                    r_labels = batch['rationale_labels']

                    Q_A_correct = a_labels == qa_pred
                    QA_R_correct = r_labels == qar_pred
                    # Q->AR counts only when both answer and rationale match.
                    Q_AR_correct = Q_A_correct & QA_R_correct

                    Q_A_results += sum(Q_A_correct)
                    QA_R_results += sum(QA_R_correct)
                    Q_AR_results += sum(Q_AR_correct)
                    n_total += len(qa_pred)

                if update:
                    # Average the accumulated sums over the micro-batches
                    # that contributed to this optimizer step.
                    if self.args.gradient_accumulation_steps > 1:
                        train_loss *= self.args.gradient_accumulation_steps / n_accu
                        train_qa_loss *= self.args.gradient_accumulation_steps / n_accu
                        train_qar_loss *= self.args.gradient_accumulation_steps / n_accu

                    loss_meter.update(train_loss)
                    qa_loss_meter.update(train_qa_loss)
                    qar_loss_meter.update(train_qar_loss)

                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step} |'
                    desc_str += f' Loss {loss_meter.val:.3f} |'
                    desc_str += f' QA Loss {qa_loss_meter.val:.3f} |'
                    desc_str += f' QAR Loss {qar_loss_meter.val:.3f} |'

                    train_loss = 0
                    train_qa_loss = 0
                    train_qar_loss = 0
                    n_accu = 0

                    if self.args.log_train_accuracy:
                        desc_str += f' Q -> A {Q_A_results} ({Q_A_results/n_total*100:.1f}%)'
                        desc_str += f' QA -> R {QA_R_results} ({QA_R_results/n_total*100:.1f}%)'
                        desc_str += f' Q -> AR {Q_AR_results} ({Q_AR_results/n_total*100:.1f}%)'

                    pbar.set_description(desc_str)
                pbar.update(1)

        if self.verbose:
            pbar.close()

        if self.args.log_train_accuracy:
            # Aggregate train accuracy counters across processes.
            train_score_dict = {
                'Q_A': Q_A_results,
                'QA_R': QA_R_results,
                'Q_AR': Q_AR_results,
                'n_total': n_total
            }
            train_score_dict = reduce_dict(train_score_dict, self.args.gpu)

        # Validation
        valid_score_dict = self.evaluate_val(self.val_loader)

        if self.verbose:
            if self.args.log_train_accuracy:
                train_Q_A = train_score_dict['Q_A'] / train_score_dict[
                    'n_total'] * 100
                train_QA_R = train_score_dict['QA_R'] / train_score_dict[
                    'n_total'] * 100
                train_Q_AR = train_score_dict['Q_AR'] / train_score_dict[
                    'n_total'] * 100
                train_n_total = int(train_score_dict['n_total'])

            valid_Q_A = valid_score_dict['Q_A'] / valid_score_dict[
                'n_total'] * 100
            valid_QA_R = valid_score_dict['QA_R'] / valid_score_dict[
                'n_total'] * 100
            valid_Q_AR = valid_score_dict['Q_AR'] / valid_score_dict[
                'n_total'] * 100
            valid_n_total = int(valid_score_dict['n_total'])

            # Model selection on Q->AR (the strictest metric).
            if valid_Q_AR > best_valid_Q_AR:
                best_valid_Q_AR = valid_Q_AR
                best_epoch = epoch
                self.save("BEST")

            log_str = ''
            if self.args.log_train_accuracy:
                log_str += f"\nEpoch {epoch}: Train |"
                log_str += f" # examples: {train_n_total} |"
                log_str += f" Q -> A {train_Q_A:.2f}%"
                log_str += f" QA -> R {train_QA_R:.2f}%"
                log_str += f" Q -> AR {train_Q_AR:.2f}%"

            log_str += f"\nEpoch {epoch}: Valid |"
            log_str += f" # examples: {valid_n_total} |"
            log_str += f" Q -> A {valid_Q_A:.2f}%"
            log_str += f" QA -> R {valid_QA_R:.2f}%"
            log_str += f" Q -> AR {valid_Q_AR:.2f}%"
            log_str += f"\nBest Epoch {best_epoch}: Q -> AR {best_valid_Q_AR:.2f}%\n"

            wandb_log_dict = {}
            if self.args.log_train_accuracy:
                wandb_log_dict['Train/Q_A'] = train_Q_A
                wandb_log_dict['Train/QA_R'] = train_QA_R
                wandb_log_dict['Train/Q_AR'] = train_Q_AR

            wandb_log_dict['Valid/Q_A'] = valid_Q_A
            wandb_log_dict['Valid/QA_R'] = valid_QA_R
            wandb_log_dict['Valid/Q_AR'] = valid_Q_AR

            wandb.log(wandb_log_dict, step=epoch)
            print(log_str)

        if self.args.distributed:
            dist.barrier()

    if self.verbose:
        self.save("LAST")

    # Test Set: reload the best checkpoint on all processes, dump predictions
    # from the main process only.
    best_path = os.path.join(self.args.output, 'BEST')
    self.load(best_path)

    if self.verbose:
        dump_path = os.path.join(self.args.output, 'test_submit.csv')
        print('Dumping test set results at', dump_path)
        self.evaluate_test(self.test_loader, dump_path=dump_path)
        wandb.save(dump_path, base_path=self.args.output)
        print('Done!')
        wandb.log({'finished': True})

    if self.args.distributed:
        dist.barrier()
    exit()
def train(self):
    """Train the NLVR classifier for ``args.epochs`` epochs.

    Each question is paired with two images; the question tokens are
    duplicated (``unsqueeze(1).repeat(1, 2, 1)``) so every image gets a copy
    before the forward pass. Progress, TensorBoard logging, evaluation and
    checkpointing happen only on the verbose (main) process.

    Fixes over the previous revision:
      * The per-epoch TP/FN/FP/TN counter was named ``results`` and was
        clobbered every step by ``results = self.model(...)`` (a dict), so
        ``results[0] += ...`` operated on the model output instead of the
        counters. The counter is now ``nlvr_results`` (matching the sibling
        VL-T5 NLVR trainer in this file).
      * The LR readout queried ``self.scheduler``, which is never assigned
        (the object stepped below is ``self.lr_scheduler``); the resulting
        AttributeError fallback pinned the displayed LR to ``args.lr``.
        It now queries ``self.lr_scheduler``.
    """
    if self.verbose:
        loss_meter = LossMeter()
        best_valid = 0.

        # TensorBoard + hparams are set up on the main process only.
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=self.args.log_dir)

        hparam_dict = {}
        for k, v in self.args.__dict__.items():
            # add_hparams only accepts scalar-ish values.
            if type(v) in [int, float, str, bool, torch.Tensor]:
                hparam_dict[k] = v
        metric_dict = {}
        self.writer.add_hparams(hparam_dict, metric_dict)

    if self.args.distributed:
        dist.barrier()

    self.optim.zero_grad()

    for epoch in range(self.args.epochs):
        self.model.train()
        if self.args.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=150)

        # Confusion counts for the epoch: [TP, FN, FP, TN].
        nlvr_results = np.zeros(4, dtype=int)
        quesid2ans = {}
        for step_i, batch in enumerate(self.train_loader):
            vis_feats = batch['vis_feats'].cuda()
            boxes = batch['boxes'].cuda()

            ques_id = batch['question_ids']
            B = len(ques_id)

            input_ids = batch['word_ids'].cuda()
            # Duplicate each question so both paired images get a copy.
            input_ids = input_ids.unsqueeze(1).repeat(1, 2, 1).view(B * 2, -1)
            label = batch['labels'].cuda()

            results = self.model(
                input_ids=input_ids,
                visual_feats=vis_feats,
                visual_pos=boxes,
                attention_mask=input_ids > 0,
            )
            logit = results['logit']
            loss = self.mce_loss(logit, label)
            loss.backward()

            # Gradient accumulation: step every update_freq batches and on
            # the last batch; step 0 never updates.
            update = True
            if self.args.update_freq > 1:
                if step_i == 0:
                    update = False
                elif step_i % self.args.update_freq == 0 or step_i == len(
                        self.train_loader) - 1:
                    update = True
                else:
                    update = False

            if update:
                if not self.args.no_clip_grad:
                    nn.utils.clip_grad_norm_(self.model.parameters(),
                                             self.args.clip_grad_norm)
                self.optim.step()
                self.lr_scheduler.step()
                for param in self.model.parameters():
                    param.grad = None

            # Display LR; fall back to the configured LR when the scheduler
            # doesn't expose get_last_lr (older torch).
            try:
                lr = self.lr_scheduler.get_last_lr()[0]
            except AttributeError:
                lr = self.args.lr

            if self.verbose:
                loss_meter.update(loss.item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                desc_str += f'Loss {loss_meter.val:4f} |'

                score, predict = logit.max(1)
                predict = predict.cpu().numpy()
                label = label.cpu().numpy()

                for qid, pred in zip(ques_id, predict):
                    quesid2ans[qid] = pred

                nlvr_results[0] += sum((label == 1) & (predict == 1))
                nlvr_results[1] += sum((label == 1) & (predict == 0))
                nlvr_results[2] += sum((label == 0) & (predict == 1))
                nlvr_results[3] += sum((label == 0) & (predict == 0))
                n_total = sum(nlvr_results)

                desc_str += f' TP {nlvr_results[0]} ({nlvr_results[0]/n_total*100:.1f}%)'
                desc_str += f' FN {nlvr_results[1]} ({nlvr_results[1]/n_total*100:.1f}%)'
                desc_str += f' FP {nlvr_results[2]} ({nlvr_results[2]/n_total*100:.1f}%)'
                desc_str += f' TN {nlvr_results[3]} ({nlvr_results[3]/n_total*100:.1f}%)'

                pbar.set_description(desc_str)
                pbar.update(1)

        if self.args.distributed:
            dist.barrier()

        if self.verbose:
            pbar.close()

            score = self.train_loader.evaluator.evaluate(quesid2ans) * 100.
            log_str = "\nEpoch %d: Train %0.2f" % (epoch, score)
            if not self.args.dry:
                self.writer.add_scalar(f'NLVR/Train/score', score, epoch)

            # Validation
            valid_score = self.evaluate(self.val_loader) * 100.
            if valid_score > best_valid:
                best_valid = valid_score
                self.save("BEST")
            log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
            log_str += "\nEpoch %d: Best %0.2f\n" % (epoch, best_valid)
            if not self.args.dry:
                self.writer.add_scalar(f'NLVR/Valid/score', valid_score,
                                       epoch)

            print(log_str)
            self.logger.info(log_str)

        if self.args.distributed:
            dist.barrier()

    if self.verbose:
        self.save("LAST")
def train(self):
    """Train the VL-T5/VL-BART NLVR model for ``args.epochs`` epochs.

    Supports fp16 (native AMP or apex) and gradient clipping; tracks a running
    TP/FN/FP/TN confusion display, validates each epoch, keeps the best
    checkpoint by validation accuracy, then evaluates the best checkpoint on
    the test set and dumps a submission CSV. W&B logging and progress bars run
    only on the verbose (main) process.
    """
    if self.verbose:
        loss_meter = LossMeter()
        quesid2ans = {}
        best_valid = 0.
        best_epoch = 0

        if 't5' in self.args.backbone:
            project_name = "VLT5_NLVR"
        elif 'bart' in self.args.backbone:
            project_name = "VLBART_NLVR"

        wandb.init(project=project_name)
        wandb.run.name = self.args.run_name
        wandb.config.update(self.args)
        wandb.watch(self.model)

        # Snapshot this directory's sources into the W&B run.
        src_dir = Path(__file__).resolve().parent
        base_path = str(src_dir.parent)
        src_dir = str(src_dir)
        wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

    if self.args.distributed:
        dist.barrier()

    global_step = 0
    for epoch in range(self.args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if self.args.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=150)

        # Confusion counts for the epoch: [TP, FN, FP, TN].
        nlvr_results = np.zeros(4, dtype=int)
        epoch_results = {
            'loss': 0,
        }
        quesid2ans = {}
        train_acc = 0.
        # NOTE(review): the two values below are computed but unused here.
        train_acc_steps = int(len(self.train_loader) * 0.05)
        last_acc_step = 0

        for step_i, batch in enumerate(self.train_loader):
            self.optim.zero_grad()
            if self.args.fp16 and _use_native_amp:
                with autocast():
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)
            else:
                if self.args.distributed:
                    results = self.model.module.train_step(batch)
                else:
                    results = self.model.train_step(batch)

            loss = results['loss']

            if self.args.fp16 and _use_native_amp:
                self.scaler.scale(loss).backward()
            elif self.args.fp16 and _use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            loss = loss.detach()

            # Gradient clipping (unscale first under native AMP).
            if self.args.clip_grad_norm > 0:
                if self.args.fp16 and _use_native_amp:
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(
                        self.optim), self.args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)

            if self.args.fp16 and _use_native_amp:
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                self.optim.step()

            if self.lr_scheduler:
                self.lr_scheduler.step()
            # Cheaper than optim.zero_grad(): drop grads entirely.
            for param in self.model.parameters():
                param.grad = None

            global_step += 1

            for k, v in results.items():
                if k in epoch_results:
                    epoch_results[k] += v.item()

            # Current LR for display only.
            if self.lr_scheduler:
                if version.parse(torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr

            if self.verbose:
                loss_meter.update(loss.item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                desc_str += f'Loss {loss_meter.val:4f} |'

                pred_ans = results['pred_ans_id']
                ques_ids = batch['question_ids']
                for qid, ans in zip(ques_ids, pred_ans):
                    quesid2ans[qid] = ans

                label = batch['labels'].cpu().numpy()
                predict = results['pred_ans_id']

                nlvr_results[0] += sum((label == 1) & (predict == 1))
                nlvr_results[1] += sum((label == 1) & (predict == 0))
                nlvr_results[2] += sum((label == 0) & (predict == 1))
                nlvr_results[3] += sum((label == 0) & (predict == 0))
                n_total = sum(nlvr_results)

                desc_str += f' TP {nlvr_results[0]} ({nlvr_results[0]/n_total*100:.1f}%)'
                desc_str += f' FN {nlvr_results[1]} ({nlvr_results[1]/n_total*100:.1f}%)'
                desc_str += f' FP {nlvr_results[2]} ({nlvr_results[2]/n_total*100:.1f}%)'
                desc_str += f' TN {nlvr_results[3]} ({nlvr_results[3]/n_total*100:.1f}%)'

                pbar.set_description(desc_str)
                pbar.update(1)

        if self.args.distributed:
            dist.barrier()

        if self.verbose:
            pbar.close()

            log_str = ''

            train_acc = self.train_loader.evaluator.evaluate_train(
                quesid2ans) * 100.
            train_score = train_acc
            log_str += "\nEpoch %d: Train %0.2f" % (epoch, train_score)

            # Validation
            valid_score_dict = self.evaluate(self.val_loader)
            valid_acc = valid_score_dict['accuracy'] * 100.
            valid_score = valid_acc

            # Model selection on validation accuracy.
            if valid_score > best_valid:
                best_valid = valid_score
                best_epoch = epoch
                self.save("BEST")

            log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
            log_str += "\nEpoch %d: Best %0.2f\n" % (best_epoch, best_valid)

            wandb_log_dict = {}
            wandb_log_dict['Train/accuracy'] = train_acc
            for score_name, score in valid_score_dict.items():
                wandb_log_dict[f'Valid/{score_name}'] = score * 100.
            wandb.log(wandb_log_dict, step=epoch)

            print(log_str)

        if self.args.distributed:
            dist.barrier()

    if self.verbose:
        self.save("LAST")

        # Test Set: reload best checkpoint and dump a submission file.
        best_path = os.path.join(self.args.output, 'BEST')
        self.load(best_path)

        log_str = 'Test set results\n'
        dump_path = os.path.join(self.args.output, 'submit.csv')
        test_score_dict = self.evaluate(self.test_loader, dump_path=dump_path)
        wandb.save(dump_path, base_path=self.args.output)

        wandb_log_dict = {}
        for score_name, score in test_score_dict.items():
            wandb_log_dict[f'Test/{score_name}'] = score * 100.
        # NOTE(review): `epoch` here is the last loop value.
        wandb.log(wandb_log_dict, step=epoch)

        from pprint import pformat
        log_str += pformat(test_score_dict)
        print(log_str)

        wandb.log({'finished': True})
def train(self):
    """Multi-task training loop (VQA / GQA / NLVR / VCR / RefCOCOg / COCO
    Caption / Multi30K MMT).

    Each batch carries a ``task`` tag; per-task loss meters and counters are
    maintained for display. After every epoch each enabled task is validated
    and its best score tracked; after all epochs, test-set predictions are
    dumped for each enabled task. Calls ``exit()`` at the end. W&B logging
    and progress bars run only on the verbose (main) process.
    """
    if self.verbose:
        vqa_loss_meter = LossMeter()
        refcoco_loss_meter = LossMeter()
        quesid2ans = {}
        best_vqa_valid = 0.
        best_vqa_epoch = 0

        # Best validation score / epoch trackers, one pair per task.
        # gqa
        best_gqa_valid = 0
        best_gqa_epoch = 0
        # nlvr
        best_nlvr_valid = 0
        best_nlvr_epoch = 0
        # vcr
        best_valid_Q_AR = 0
        best_vcr_epoch = 0
        # refcoco
        best_refcoco_valid = 0
        best_refcoco_epoch = 0
        # caption
        best_caption_valid = 0
        best_caption_epoch = 0
        # mmt
        best_mmt_valid = 0
        best_mmt_epoch = 0

        # Only the T5 backbone is supported by this multi-task trainer.
        assert 't5' in self.args.backbone

        self.setup_wandb()

    if self.args.distributed:
        dist.barrier()

    global_step = 0
    for epoch in range(self.args.epochs):
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if self.args.distributed:
            # NOTE(review): the multi-task loader exposes set_epoch directly
            # (no .sampler), unlike the single-task trainers in this file.
            self.train_loader.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=250)

        epoch_results = {
            'loss': 0.,
        }

        # How many batches of each task were seen this epoch (display only).
        task_counter = {
            'vqa': 0,
            'gqa': 0,
            'nlvr': 0,
            'refcoco': 0,
            'vcr': 0,
            'caption': 0,
            'mmt': 0,
        }

        # vqa
        quesid2ans = {}
        train_acc = 0.
        # refcoco train accuracy counters
        n_correct = 0
        n_total = 0

        for step_i, batch in enumerate(self.train_loader):
            task = batch['task']
            task_counter[task] += 1

            batch['log_train_accuracy'] = self.args.log_train_accuracy

            if self.args.fp16 and _use_native_amp:
                with autocast():
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)
            else:
                if self.args.distributed:
                    results = self.model.module.train_step(batch)
                else:
                    results = self.model.train_step(batch)

            loss = results['loss']

            if self.args.fp16 and _use_native_amp:
                self.scaler.scale(loss).backward()
            elif self.args.fp16 and _use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            loss = loss.detach()

            # Gradient clipping (unscale first under native AMP).
            if self.args.clip_grad_norm > 0:
                if self.args.fp16 and _use_native_amp:
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optim),
                        self.args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)

            if self.args.fp16 and _use_native_amp:
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                self.optim.step()

            if self.lr_scheduler:
                self.lr_scheduler.step()
            # Cheaper than optim.zero_grad(): drop grads entirely.
            for param in self.model.parameters():
                param.grad = None

            global_step += 1

            for k, v in results.items():
                if k in epoch_results:
                    epoch_results[k] += v.item()

            # Current LR for display only.
            if self.lr_scheduler:
                if version.parse(
                        torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr

            if self.args.log_train_accuracy and task == 'refcoco':
                correct = results['correct']
                n_correct += sum(correct)
                n_total += len(correct)

            if self.verbose:
                if task == 'vqa':
                    vqa_loss_meter.update(loss.item())
                elif task == 'refcoco':
                    refcoco_loss_meter.update(loss.item())

                desc_str = f'Epoch {epoch} | LR {lr:.6f}'
                desc_str += f" |"
                if 'vqa' in self.args.tasks:
                    desc_str += f" VQA {task_counter['vqa']}"
                if 'gqa' in self.args.tasks:
                    desc_str += f" GQA {task_counter['gqa']}"
                if 'nlvr' in self.args.tasks:
                    desc_str += f" NLVR {task_counter['nlvr']}"
                if 'vcr' in self.args.tasks:
                    desc_str += f" VCR {task_counter['vcr']}"
                if 'refcoco' in self.args.tasks:
                    desc_str += f" RefCOCOg {task_counter['refcoco']}"
                if 'caption' in self.args.tasks:
                    desc_str += f" COCO {task_counter['caption']}"
                if 'mmt' in self.args.tasks:
                    desc_str += f" MMT {task_counter['mmt']}"

                if len(vqa_loss_meter) > 0:
                    desc_str += f' | VQA Loss {vqa_loss_meter.val:4f}'
                if len(refcoco_loss_meter) > 0:
                    desc_str += f' | RefCOCOg Loss {refcoco_loss_meter.val:.3f}'

                if self.args.log_train_accuracy and n_total > 0:
                    desc_str += f' | RefCOCOg Acc'
                    desc_str += f' Correct {n_correct:.0f}'
                    desc_str += f' (Acc {n_correct/n_total*100:.1f}%)'

                pbar.set_description(desc_str)
                pbar.update(1)

        if self.args.distributed:
            dist.barrier()

        if self.verbose:
            pbar.close()
            self.save("Epoch%02d" % (epoch + 1))

        if self.args.log_train_accuracy:
            # Aggregate refcoco train-accuracy counters across processes.
            train_score_dict = {'n_correct': n_correct, 'n_total': n_total}
            train_score_dict = reduce_dict(train_score_dict, self.args.gpu)

        if self.verbose:
            # Validation (main process only), one section per enabled task.
            log_str = ''
            wandb_log_dict = {}

            if 'vqa' in self.args.tasks:
                # VQA
                vqa_val_loader = self.val_loader['vqa']
                score_dict = self.vqa_evaluate(vqa_val_loader)
                valid_score = score_dict['topk_score'] * 100.
                valid_score_raw = score_dict['overall']
                if valid_score_raw > best_vqa_valid or epoch == 0:
                    best_vqa_valid = valid_score_raw
                    best_vqa_epoch = epoch
                log_str += f"VQA"
                log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % (
                    epoch, valid_score_raw, valid_score)
                log_str += "\nEpoch %d: Best Raw %0.2f\n" % (
                    best_vqa_epoch, best_vqa_valid)
                wandb_log_dict['VQA/Valid/score'] = valid_score
                wandb_log_dict['VQA/Valid/raw_score'] = score_dict[
                    'overall']

            if 'gqa' in self.args.tasks:
                # GQA
                gqa_val_loader = self.val_loader['gqa']
                valid_score = self.gqa_evaluate(gqa_val_loader) * 100
                if valid_score > best_gqa_valid or epoch == 0:
                    best_gqa_valid = valid_score
                    best_gqa_epoch = epoch
                wandb_log_dict['GQA/Valid/Acc'] = valid_score
                log_str += f"GQA"
                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_gqa_epoch,
                                                         best_gqa_valid)

            if 'nlvr' in self.args.tasks:
                # NLVR
                nlvr_val_loader = self.val_loader['nlvr']
                valid_score_dict = self.nlvr_evaluate(nlvr_val_loader)
                valid_acc = valid_score_dict['accuracy'] * 100.
                if valid_acc > best_nlvr_valid or epoch == 0:
                    best_nlvr_valid = valid_acc
                    best_nlvr_epoch = epoch
                wandb_log_dict['NLVR/Valid/Acc'] = valid_acc
                log_str += f"NLVR"
                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_acc)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_nlvr_epoch,
                                                         best_nlvr_valid)

            if 'vcr' in self.args.tasks:
                # VCR
                vcr_val_loader = self.val_loader['vcr']
                valid_score_dict = self.vcr_evaluate(vcr_val_loader)
                valid_Q_A = valid_score_dict['Q_A'] / valid_score_dict[
                    'n_total'] * 100
                valid_QA_R = valid_score_dict['QA_R'] / valid_score_dict[
                    'n_total'] * 100
                valid_Q_AR = valid_score_dict['Q_AR'] / valid_score_dict[
                    'n_total'] * 100
                valid_n_total = int(valid_score_dict['n_total'])
                if valid_Q_AR > best_valid_Q_AR or epoch == 0:
                    best_valid_Q_AR = valid_Q_AR
                    best_vcr_epoch = epoch
                wandb_log_dict['VCR/Valid/Q_A'] = valid_Q_A
                wandb_log_dict['VCR/Valid/QA_R'] = valid_QA_R
                wandb_log_dict['VCR/Valid/Q_AR'] = valid_Q_AR
                log_str += f"VCR"
                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_Q_AR)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_vcr_epoch,
                                                         best_valid_Q_AR)

            if 'refcoco' in self.args.tasks:
                # RefCOCO
                refcoco_val_loader = self.val_loader['refcoco']
                if self.args.log_train_accuracy:
                    train_acc = train_score_dict[
                        'n_correct'] / train_score_dict['n_total'] * 100
                    train_n_correct = int(train_score_dict['n_correct'])
                    train_n_total = int(train_score_dict['n_total'])

                valid_score_dict = self.refcoco_evaluate(
                    refcoco_val_loader)
                valid_acc = valid_score_dict[
                    'n_correct'] / valid_score_dict['n_total'] * 100
                valid_n_correct = int(valid_score_dict['n_correct'])
                valid_n_total = int(valid_score_dict['n_total'])

                if valid_acc > best_refcoco_valid or epoch == 0:
                    best_refcoco_valid = valid_acc
                    best_refcoco_epoch = epoch

                if self.args.log_train_accuracy:
                    wandb_log_dict['RefCOCO/Train/Acc'] = train_acc
                wandb_log_dict['RefCOCO/Valid/Acc'] = valid_acc

                log_str += f"RefCOCOg"
                if self.args.log_train_accuracy:
                    log_str += f"\nEpoch {epoch}: Train"
                    log_str += f" Acc {train_acc:.2f}% |"
                    log_str += f" # correct {train_n_correct} # total {train_n_total}"
                log_str += f"\nEpoch {epoch}: Valid"
                log_str += f" Acc {valid_acc:.2f}% |"
                log_str += f" # correct {valid_n_correct} # total {valid_n_total}"
                log_str += f"\nEpoch {best_refcoco_epoch}: Best Acc {best_refcoco_valid:.2f}%\n"

            if 'caption' in self.args.tasks:
                # COCO Caption
                caption_val_loader = self.val_loader['caption']
                valid_results = self.caption_evaluate(caption_val_loader)
                valid_score = valid_results['CIDEr'] * 100
                if valid_score > best_caption_valid or epoch == 0:
                    best_caption_valid = valid_score
                    best_caption_epoch = epoch
                for score_name, score in valid_results.items():
                    wandb_log_dict[
                        f'Caption/Valid/{score_name}'] = score * 100
                log_str += f"COCO Caption"
                log_str += "\nEpoch %d: Valid CIDEr %0.2f" % (epoch,
                                                              valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (
                    best_caption_epoch, best_caption_valid)

            if 'mmt' in self.args.tasks:
                # MMT
                mmt_val_loader = self.val_loader['mmt']
                valid_results = self.mmt_evaluate(mmt_val_loader)
                valid_score = valid_results['BLEU']
                if valid_score > best_mmt_valid:
                    best_mmt_valid = valid_score
                    best_mmt_epoch = epoch
                for score_name, score in valid_results.items():
                    wandb_log_dict[f'MMT/Valid/{score_name}'] = score
                log_str += f"Multi30K En-De"
                log_str += "\nEpoch %d: Valid BLEU %0.2f" % (epoch,
                                                             valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_mmt_epoch,
                                                         best_mmt_valid)

            wandb.log(wandb_log_dict, step=epoch)
            print(log_str)

        if self.args.distributed:
            dist.barrier()

    # Test Set (main process only): dump predictions per enabled task.
    if self.verbose:
        self.save("LAST")

        log_str = ''
        wandb_log_dict = {}

        if 'vqa' in self.args.tasks:
            # VQA: Karpathy test split + official submission file.
            vqa_test_loader = self.test_loader['vqa']
            evaluator = vqa_test_loader.evaluator
            dump_path = os.path.join(self.args.output,
                                     'karpathy_test_predict.json')
            quesid2ans = self.vqa_predict(vqa_test_loader, dump_path)
            wandb.save(dump_path, base_path=self.args.output)

            acc_dict_all = evaluator.evaluate_raw(quesid2ans)
            acc_dict_answerable = evaluator.evaluate_raw(
                quesid2ans, is_topk_optimal=True)
            acc_dict_unanswerable = evaluator.evaluate_raw(
                quesid2ans, is_topk_optimal=False)

            wandb_log_dict['VQA/Test/overall'] = acc_dict_all['overall']
            wandb_log_dict['VQA/Test/topk_optimal'] = acc_dict_answerable[
                'overall']
            wandb_log_dict[
                'VQA/Test/topk_not_optimal'] = acc_dict_unanswerable[
                    'overall']

            vqa_submit_test_loader = self.test_loader['vqa_submit']
            dump_path = os.path.join(self.args.output, 'vqa_submit.json')
            self.vqa_predict(vqa_submit_test_loader, dump_path=dump_path)
            wandb.save(dump_path, base_path=self.args.output)

        if 'nlvr' in self.args.tasks:
            # NLVR
            nlvr_test_loader = self.test_loader['nlvr']
            dump_path = os.path.join(self.args.output, 'nlvr_submit.csv')
            test_score_dict = self.nlvr_evaluate(nlvr_test_loader,
                                                 dump_path=dump_path)
            wandb.save(dump_path, base_path=self.args.output)
            for score_name, score in test_score_dict.items():
                wandb_log_dict[f'NLVR/Test/{score_name}'] = score * 100.

        if 'refcoco' in self.args.tasks:
            # RefCOCO
            refcoco_test_loader = self.test_loader['refcoco']
            test_score_dict = self.refcoco_evaluate(refcoco_test_loader)
            test_acc = test_score_dict['n_correct'] / test_score_dict[
                'n_total'] * 100
            test_n_correct = int(test_score_dict['n_correct'])
            test_n_total = int(test_score_dict['n_total'])
            wandb_log_dict['RefCOCO/test/Acc'] = test_acc
            # NOTE(review): plain assignment (not '+=') discards any earlier
            # log_str content — verify this overwrite is intentional.
            log_str = 'RefCOCOg'
            log_str += f"\nTest Acc {test_acc:.2f}%"
            log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}"

        if 'caption' in self.args.tasks:
            # COCO Caption
            caption_test_loader = self.test_loader['caption']
            test_results = self.caption_evaluate(caption_test_loader)
            for score_name, score in test_results.items():
                wandb_log_dict[f'Caption/Test/{score_name}'] = score

        if 'mmt' in self.args.tasks:
            # MMT: evaluate each test year split and dump raw outputs.
            mmt_test2016_loader = self.test_loader['mmt_test2016']
            mmt_test2017_loader = self.test_loader['mmt_test2017']
            mmt_test2018_loader = self.test_loader['mmt_test2018']
            for loader in [
                    mmt_test2016_loader, mmt_test2017_loader,
                    mmt_test2018_loader
            ]:
                split = loader.dataset.source
                dump_path = os.path.join(self.args.output,
                                         f'submit_{split}_raw.txt')
                test_results = self.mmt_evaluate(loader,
                                                 dump_path=dump_path)
                for score_name, score in test_results.items():
                    wandb_log_dict[f'MMT/{split}/{score_name}'] = score
                log_str += f'{split} set results\n'
                log_str += pformat(test_results)

        print(log_str)
        wandb.log(wandb_log_dict, step=self.args.epochs)

        wandb.log({'finished': True})

    if self.args.distributed:
        dist.barrier()
    exit()
def train(self) -> None:
    """Run the full fine-tuning loop for COCO Caption.

    Per epoch: train over ``self.train_loader``, then evaluate on
    ``self.val_loader`` and checkpoint the best model by CIDEr. After all
    epochs, reload the "BEST" checkpoint and evaluate on ``self.test_loader``.

    Only the verbose (main) process drives wandb logging, tqdm progress,
    checkpoint saving, and printing; all processes synchronize via
    ``dist.barrier()`` when ``self.args.distributed`` is set.
    """
    if self.verbose:
        # Running-average loss display and best-checkpoint tracking live only
        # on the verbose process.
        loss_meter = LossMeter()
        best_valid = 0.
        best_epoch = 0

        # Initialize wandb once; guarded so repeated train() calls (e.g. from
        # a multitask driver) do not re-init.
        if not self.wandb_initialized:

            if 't5' in self.args.backbone:
                project_name = "VLT5_COCOCaption"
            elif 'bart' in self.args.backbone:
                project_name = "VLBart_COCOCaption"

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Snapshot all *.py sources in this directory to the wandb run.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

            self.wandb_initialized = True

    if self.args.distributed:
        dist.barrier()

    global_step = 0
    epochs = self.args.epochs
    for epoch in range(epochs):
        # Shift the epoch index when resuming from a checkpoint so samplers
        # and logs continue from where training left off.
        if self.start_epoch is not None:
            epoch += self.start_epoch
        self.model.train()
        if self.args.distributed:
            # Reseed the distributed sampler so each epoch gets a new shuffle.
            self.train_loader.sampler.set_epoch(epoch)
        if self.verbose:
            pbar = tqdm(total=len(self.train_loader), ncols=120)

        # Accumulates per-epoch sums for keys returned by train_step; only
        # 'loss' is tracked here.
        epoch_results = {
            'loss': 0.,
        }

        for step_i, batch in enumerate(self.train_loader):

            # Forward pass: under native-AMP autocast when fp16 is enabled,
            # otherwise plain fp32. DDP wraps the model, so the underlying
            # module is reached via .module when distributed.
            if self.args.fp16 and _use_native_amp:
                with autocast():
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)
            else:
                if self.args.distributed:
                    results = self.model.module.train_step(batch)
                else:
                    results = self.model.train_step(batch)

            loss = results['loss']

            # Backward pass: scaled by GradScaler (native AMP), by apex
            # amp.scale_loss, or plain backward.
            if self.args.fp16 and _use_native_amp:
                self.scaler.scale(loss).backward()
            elif self.args.fp16 and _use_apex:
                with amp.scale_loss(loss, self.optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Detach so the logging below does not hold the graph alive.
            loss = loss.detach()

            # Update Parameters
            if self.args.clip_grad_norm > 0:
                if self.args.fp16 and _use_native_amp:
                    # Gradients must be unscaled before clipping under AMP.
                    self.scaler.unscale_(self.optim)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optim),
                        self.args.clip_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip_grad_norm)

            # Gradient accumulation: only step the optimizer every
            # gradient_accumulation_steps batches (and on the last batch).
            # NOTE(review): step_i == 0 forces update=False, so the very
            # first batch's gradients are carried into the next update —
            # confirm this off-by-one is intentional.
            update = True
            if self.args.gradient_accumulation_steps > 1:
                if step_i == 0:
                    update = False
                elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len(
                        self.train_loader) - 1:
                    update = True
                else:
                    update = False

            if update:
                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()

                # self.model.zero_grad()
                # Setting .grad = None instead of zero_grad() frees gradient
                # memory rather than zero-filling it.
                for param in self.model.parameters():
                    param.grad = None
                global_step += 1

            # Fold tracked scalars (currently just 'loss') into epoch sums.
            for k, v in results.items():
                if k in epoch_results:
                    epoch_results[k] += v.item()

            # Fetch the current learning rate for the progress bar.
            # get_last_lr() exists from torch 1.4; older versions use
            # get_lr(). Without a scheduler, fall back to the optimizer's
            # own get_lr() (custom optimizers) or the configured args.lr.
            if self.lr_scheduler:
                if version.parse(
                        torch.__version__) >= version.parse("1.4"):
                    lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    lr = self.lr_scheduler.get_lr()[0]
            else:
                try:
                    lr = self.optim.get_lr()[0]
                except AttributeError:
                    lr = self.args.lr

            if self.verbose:
                loss_meter.update(loss.item())
                desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step}'
                # NOTE(review): ':4f' is width-4 with default 6-digit
                # precision — likely a typo for ':.4f'; confirm intent.
                desc_str += f' | Loss {loss_meter.val:4f}'
                pbar.set_description(desc_str)
                pbar.update(1)

        if self.args.distributed:
            dist.barrier()

        if self.verbose:
            pbar.close()

        # format ex)
        # {'Bleu_1': 0.9999999997500004,
        # 'Bleu_2': 0.5773502690332603,
        # 'Bleu_3': 4.3679023223468616e-06,
        # 'Bleu_4': 1.4287202142987477e-08,
        # 'CIDEr': 3.333333333333333,
        # 'METEOR': 0.43354749322305886,
        # 'ROUGE_L': 0.75,
        # 'SPICE': 0.6666666666666666}

        # Validation
        valid_results = self.evaluate(self.val_loader)

        if self.verbose:
            # Checkpoint on CIDEr improvement (epoch 0 always saves so a
            # "BEST" checkpoint exists).
            valid_score = valid_results['CIDEr']
            if valid_score > best_valid or epoch == 0:
                best_valid = valid_score
                best_epoch = epoch
                self.save("BEST")

            log_str = ''
            log_str += pformat(valid_results)
            log_str += "\nEpoch %d: Valid CIDEr %0.4f" % (epoch, valid_score)
            log_str += "\nEpoch %d: Best CIDEr %0.4f\n" % (best_epoch, best_valid)

            wandb_log_dict = {}
            wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(
                self.train_loader)

            for score_name, score in valid_results.items():
                wandb_log_dict[f'Valid/{score_name}'] = score

            wandb_log_dict[f'Valid/best_epoch'] = best_epoch

            wandb.log(wandb_log_dict, step=epoch)

            print(log_str)

        if self.args.distributed:
            dist.barrier()

    if self.verbose:
        self.save("LAST")

    # Test Set
    # Reload the best validation checkpoint on all processes before testing.
    best_path = os.path.join(self.args.output, 'BEST')
    self.load(best_path)

    if self.verbose:
        wandb.save(best_path, base_path=self.args.output)
        print(f'\nUploaded checkpoint {best_epoch}', best_path)

    test_results = self.evaluate(self.test_loader)

    if self.verbose:
        wandb_log_dict = {}
        for score_name, score in test_results.items():
            wandb_log_dict[f'Test/{score_name}'] = score
        # 'epoch' here is the last epoch index from the loop above.
        wandb.log(wandb_log_dict, step=epoch)

        log_str = 'Test set results\n'
        log_str += pformat(test_results)
        print(log_str)

    if self.args.distributed:
        dist.barrier()