def run_model(self, beam_size, split):
    """Run beam-search decoding over the given split and collect predictions.

    Args:
        beam_size: Beam width to configure on the wrapped model
            (forwarded to ``self.model.module.set_beam_size``).
        split: Key into ``self.dataloaders`` selecting the dataset split.

    Returns:
        dict mapping ``question_id`` / ``topkscores`` / ``complete_seqs``
        to lists with one entry per batch, each entry taken from the batch
        dict after ``forward_model`` has populated it in-place.
    """
    # set beam-size
    self.model.module.set_beam_size(beam_size)
    predictions = {
        "question_id": [],
        "topkscores": [],
        "complete_seqs": [],
        # 'ocr_tokens': []
    }
    # Loop-invariant: hoisted out of the batch loop.
    save_keys = ["question_id", "topkscores", "complete_seqs"]
    with torch.no_grad():
        for batch in tqdm(self.dataloaders[split], desc="Beam Search Evaluation"):
            # Batch is updated inside the method, no outputs are needed
            forward_model(None, self.device, self.model, batch_dict=batch, beam_search=True)
            for key in save_keys:
                predictions[key].append(batch[key])
            # BUG FIX: a stray `break` here stopped evaluation after the
            # first batch (leftover debug code), so only one batch of
            # predictions was ever returned. Removed so the full split
            # is decoded.
    return predictions
def evaluate(
    dataloaders,
    task_cfg,
    device,
    model,
):
    """Return the batch-size-weighted mean score over the validation split.

    Temporarily puts ``model`` into eval mode, iterates ``dataloaders["val"]``
    under ``torch.no_grad()``, and restores training mode before returning.
    """
    weighted_scores = []
    sizes = []
    model.eval()
    with torch.no_grad():
        for batch_dict in tqdm(dataloaders["val"], desc="Validation"):
            # Only the score and batch size matter here; loss and
            # predictions are discarded.
            _, score, batch_size, _ = forward_model(
                task_cfg, device, model, batch_dict=batch_dict)
            weighted_scores.append(score * batch_size)
            sizes.append(batch_size)
    model.train()
    # Weighted average: each batch contributes proportionally to its size.
    return sum(weighted_scores) / sum(sizes)
def run_model_no_beam(self, split):
    """Run greedy (no-beam) evaluation on `split` and format predictions
    for EvalAI submission.

    Args:
        split: Key into ``self.dataloaders`` selecting the dataset split.

    Returns:
        list of dicts, each ``{"question_id": ..., "answer": ...}`` taken
        from the per-example predictions produced by ``forward_model``.
    """
    predictions = []
    self.model.eval()
    with torch.no_grad():
        for batch_dict in tqdm(self.dataloaders[split], desc=f"Eval on {split}"):
            # CLEANUP: the original also accumulated weighted scores and
            # batch sizes here, but never used them — only the predictions
            # are returned, so the dead accumulation was removed.
            _, _, _, batch_predictions = forward_model(
                {
                    "loss": "textvqa",
                    "metric": "textvqa"
                },
                self.device,
                self.model,
                batch_dict=batch_dict)
            predictions.extend(batch_predictions)
    # Reshape into the EvalAI submission format.
    evalai_preds = [{
        "question_id": x["question_id"],
        "answer": x["pred_answer"]
    } for x in predictions]
    return evalai_preds
def main():
    """Entry point: build model/optimizer, train with per-epoch validation,
    checkpoint the best model.

    Returns:
        Tuple of (checkpoint_path, model, dataloaders). When
        ``args.pretrained_eval`` is set, returns early with that path
        instead of training.
    """
    task_cfg, args, save_path = get_config()
    checkpoint_path = os.path.join(save_path, "best_model.tar")
    base_lr = task_cfg["lr"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # FIX: typo in log message ("Numer" -> "Number").
    logger.info(f"Device: {device}, Number of GPUs: {n_gpu}")

    dataloaders = load_datasets(task_cfg, ["train", "val", "test"])

    # Build the SA-M4C model from the two config sections.
    mmt_config = BertConfig.from_dict(task_cfg["SA-M4C"])
    text_bert_config = BertConfig.from_dict(task_cfg["TextBERT"])
    model = SAM4C(mmt_config, text_bert_config)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Training Parameters: {trainable_params}")

    optimizer_grouped_parameters = model.get_optimizer_parameters(base_lr)
    print(len(list(model.named_parameters())), len(optimizer_grouped_parameters))
    optimizer, warmup_scheduler = get_optim_scheduler(
        task_cfg, optimizer_grouped_parameters, base_lr)

    start_iter_id, global_step, start_epoch = 0, 0, 0
    model.to(device)
    # BUG FIX: the original called v.cuda() unconditionally, which crashes on
    # CPU-only machines even though `device` was computed above. Move any
    # optimizer state tensors to the selected device instead.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # When running only evaluation
    if args.pretrained_eval != "":
        logger.info(
            f"Dumping Evaluation results at: {os.path.dirname(args.pretrained_eval)}"
        )
        return args.pretrained_eval, model, dataloaders

    # This validation score is used for model-saving.
    best_val_step, best_val_score = -1, -1
    loss_values, score_values = [], []
    median_num_iter = len(dataloaders["train"])

    # Train loop
    model.train()
    for epoch_id in tqdm(range(start_epoch, args.num_train_epochs), desc="Epoch"):
        for step in tqdm(range(median_num_iter), desc="Iters"):
            assert model.training
            iter_id = start_iter_id + step + (epoch_id * median_num_iter)
            loss, score, _, _ = forward_model(task_cfg, device, model,
                                              dataloaders, "train")
            # Compute gradients
            loss.backward()
            clip_gradients(model, task_cfg["max_grad_norm"])
            # Apply and reset gradients
            optimizer.step()
            warmup_scheduler.step()
            model.zero_grad()
            # Increment loggers
            global_step += 1
            # BUG FIX: store plain floats, not live tensors — appending the
            # loss tensor keeps its autograd graph (and GPU memory) alive
            # until the logging window flushes.
            loss_values.append(float(loss))
            score_values.append(float(score))
            # Handle logging
            if step % 20 == 0 and step != 0:
                loss_avg = float(sum(loss_values) / len(loss_values))
                score_avg = float(sum(score_values) / len(score_values))
                loss_values, score_values = [], []
                log_str = f"Epoch: {epoch_id}: Iter: {iter_id}; loss = {loss_avg}; accuracy = {score_avg}"
                if step % 100 == 0:
                    log_str += f"\n lr rates = {[float(grp['lr']) for grp in optimizer.param_groups]}"
                logger.info(log_str)
        # Evaluate after every epoch
        curr_val_score = evaluate(
            dataloaders,
            task_cfg,
            device,
            model,
        )
        logger.info(
            f"[Validation] Current VQA: {curr_val_score} at {global_step} | Best VQA: {best_val_score} at {best_val_step}"
        )
        if curr_val_score > best_val_score:
            logger.info(f"Saving Checkpoint: {checkpoint_path}")
            # Unwrap DataParallel before saving so the checkpoint loads
            # without the "module." prefix.
            model_to_save = model.module if hasattr(model, "module") else model
            best_val_score, best_val_step = curr_val_score, global_step
            torch.save(
                {
                    "model_state_dict": model_to_save.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "warmup_scheduler_state_dict": warmup_scheduler.state_dict(),
                    "global_step": global_step,
                    "current_val_score": curr_val_score,
                    "epoch_id": epoch_id,
                },
                checkpoint_path,
            )
    print(
        f"Best Validation Score: {best_val_score}, Best Validation Epoch: {best_val_step}"
    )
    return checkpoint_path, model, dataloaders