def eval_pred(pred, answer_index, round_id, gt_relevance):
    """
    Evaluate the predicted results and report metrics. Only for the val split.

    Parameters
    ----------
    pred: ndarray of shape (n_samples, n_rounds, n_options).
    answer_index: ndarray of shape (n_samples, n_rounds).
    round_id: ndarray of shape (n_samples, ).
    gt_relevance: ndarray of shape (n_samples, n_options).

    Returns
    -------
    None
    """
    # Convert to torch tensors to use visdialch.metrics
    pred = torch.Tensor(pred)
    answer_index = torch.Tensor(answer_index).long()
    round_id = torch.Tensor(round_id).long()
    gt_relevance = torch.Tensor(gt_relevance)

    sparse_metrics = SparseGTMetrics()
    ndcg = NDCG()

    sparse_metrics.observe(pred, answer_index)
    # NDCG is computed only on the dense-annotated round (round_id is 1-based).
    pred = pred[torch.arange(pred.size(0)), round_id - 1, :]
    ndcg.observe(pred, gt_relevance)

    all_metrics = {}
    all_metrics.update(sparse_metrics.retrieve(reset=True))
    all_metrics.update(ndcg.retrieve(reset=True))
    for metric_name, metric_value in all_metrics.items():
        print(f"{metric_name}: {metric_value}")
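
# Usage sketch (assumption): `eval_pred` is typically fed dumped val-split
# predictions. The random arrays below are placeholders with the documented
# shapes, not real model outputs, and `_eval_pred_smoke_test` is a
# hypothetical helper name, not part of the original pipeline.
def _eval_pred_smoke_test(n_samples=4, n_rounds=10, n_options=100):
    pred = np.random.rand(n_samples, n_rounds, n_options)
    answer_index = np.random.randint(0, n_options, size=(n_samples, n_rounds))
    round_id = np.random.randint(1, n_rounds + 1, size=(n_samples,))  # 1-based
    gt_relevance = np.random.rand(n_samples, n_options)
    eval_pred(pred, answer_index, round_id, gt_relevance)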
def get_1round_batch_data(batch, rnd):
    # Slice out a single dialog round `rnd` from a full 10-round batch.
    # NOTE: the leading condition is reconstructed; it is assumed to cover keys
    # that are not indexed per round (e.g. image features).
    temp_train_batch = {}
    for key in batch:
        if key in ['img_feat']:
            temp_train_batch[key] = batch[key].to(device)
        elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind']:
            # Per-round tensors: keep only round `rnd`.
            temp_train_batch[key] = batch[key][:, rnd].to(device)
        elif key in ['hist_len', 'hist']:
            # Dialog history grows with the round: keep rounds 0..rnd.
            temp_train_batch[key] = batch[key][:, :rnd + 1].to(device)
        else:
            pass
    return temp_train_batch


# Run validation: score all 100 options for each of the 10 rounds, then report metrics.
model.eval()
for i, batch in enumerate(val_dataloader):
    batchsize = batch['img_ids'].shape[0]
    rnd = 0
    temp_train_batch = get_1round_batch_data(batch, rnd)
    output = model(temp_train_batch).view(-1, 1, 100).detach()
    for rnd in range(1, 10):
        temp_train_batch = get_1round_batch_data(batch, rnd)
        output = torch.cat(
            (output, model(temp_train_batch).view(-1, 1, 100).detach()), dim=1)
    sparse_metrics.observe(output, batch["ans_ind"])
    if "relevance" in batch:
        output = output[torch.arange(output.size(0)),
                        batch["round_id"] - 1, :]
        ndcg.observe(output.view(-1, 100),
                     batch["relevance"].contiguous().view(-1, 100))
    # if i > 5:  # for debugging (like --overfit)
    #     break

all_metrics = {}
all_metrics.update(sparse_metrics.retrieve(reset=True))
all_metrics.update(ndcg.retrieve(reset=True))
for metric_name, metric_value in all_metrics.items():
    print(f"{metric_name}: {metric_value}")
model.train()
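
# Sketch (assumption, not part of the original script): the same per-round
# forward pass factored into a helper. `model` is assumed to return one score
# per answer option, reshaped to (batch_size, 1, 100) per round; concatenating
# the 10 rounds reproduces the (batch_size, 10, 100) tensor that
# SparseGTMetrics/NDCG expect above.
def predict_all_rounds(model, batch, n_rounds=10, n_options=100):
    round_scores = []
    for rnd in range(n_rounds):
        temp_batch = get_1round_batch_data(batch, rnd)
        round_scores.append(model(temp_batch).view(-1, 1, n_options).detach())
    return torch.cat(round_scores, dim=1)  # (batch_size, n_rounds, n_options)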
def train(config, args, dataloader_dic, device,
          finetune: bool = False,
          load_pthpath: str = "",
          finetune_regression: bool = False,
          dense_scratch_train: bool = False,
          dense_annotation_type: str = "default"):
    """
    :param config:
    :param args:
    :param dataloader_dic:
    :param device:
    :param finetune:
    :param load_pthpath:
    :param finetune_regression:
    :param dense_scratch_train: when we want to train from scratch on only the
        2000 dense annotations
    :param dense_annotation_type: default
    :return:
    """
    # =========================================================================
    # SETUP BEFORE TRAINING LOOP
    # =========================================================================
    train_dataset = dataloader_dic["train_dataset"]
    train_dataloader = dataloader_dic["train_dataloader"]
    val_dataloader = dataloader_dic["val_dataloader"]
    val_dataset = dataloader_dic["val_dataset"]

    model = get_model(config, args, train_dataset, device)

    if finetune and not dense_scratch_train:
        assert load_pthpath != "", "Please provide a path" \
                                   " for the pre-trained model before" \
                                   " starting fine-tuning"
        print("\nBegin fine-tuning:")

    optimizer, scheduler, iterations, lr_scheduler_type = get_solver(
        config, args, train_dataset, val_dataset, model, finetune=finetune)

    start_time = datetime.datetime.strftime(datetime.datetime.utcnow(),
                                            '%d-%b-%Y-%H:%M:%S')
    if args.save_dirpath == 'checkpoints/':
        args.save_dirpath += '%s+%s/%s' % (
            config["model"]["encoder"], config["model"]["decoder"], start_time)
    summary_writer = SummaryWriter(log_dir=args.save_dirpath)
    checkpoint_manager = CheckpointManager(model, optimizer, args.save_dirpath,
                                           config=config)
    sparse_metrics = SparseGTMetrics()
    ndcg = NDCG()

    best_val_loss = np.inf  # SA: initially the loss can be any number
    best_val_ndcg = 0.0

    # If loading from checkpoint, adjust start epoch and load parameters.
    # SA: 1. if finetuning -> load from saved model
    #     2. train -> default load_pthpath = ""
    #     3. else load pthpath
    if (not finetune and load_pthpath == "") or dense_scratch_train:
        start_epoch = 1
    else:
        # "path/to/checkpoint_xx.pth" -> xx
        # To cater for fine-tuning from models saved as "best_ndcg" checkpoints
        try:
            start_epoch = int(load_pthpath.split("_")[-1][:-4]) + 1
        except ValueError:
            start_epoch = 1

        model_state_dict, optimizer_state_dict = load_checkpoint(load_pthpath)

        # SA: updating last epoch
        checkpoint_manager.update_last_epoch(start_epoch)

        if isinstance(model, nn.DataParallel):
            model.module.load_state_dict(model_state_dict)
        else:
            model.load_state_dict(model_state_dict)

        # SA: for fine-tuning, the optimizer should start from its own learning rate
        if not finetune:
            optimizer.load_state_dict(optimizer_state_dict)
        else:
            print("Optimizer not loaded. Different optimizer for fine-tuning.")
        print("Loaded model from {}".format(load_pthpath))

    # =========================================================================
    # TRAINING LOOP
    # =========================================================================

    # Forever-increasing counter to keep track of iterations (for tensorboard log).
    global_iteration_step = (start_epoch - 1) * iterations
    running_loss = 0.0  # New
    train_begin = datetime.datetime.utcnow()  # New

    if finetune:
        end_epoch = start_epoch + config["solver"]["num_epochs_curriculum"] - 1
        if finetune_regression:
            # criterion = nn.MSELoss(reduction='mean')
            # criterion = nn.KLDivLoss(reduction='mean')
            criterion = nn.MultiLabelSoftMarginLoss()
    else:
        end_epoch = config["solver"]["num_epochs"]
        # SA: normal training
        criterion = get_loss_criterion(config, train_dataset)

    # SA: end_epoch + 1 => the for loop also runs the last epoch
    for epoch in range(start_epoch, end_epoch + 1):
        # ---------------------------------------------------------------------
        # ON EPOCH START (combine dataloaders if training on train + val)
        # ---------------------------------------------------------------------
        if config["solver"]["training_splits"] == "trainval":
            combined_dataloader = itertools.chain(train_dataloader,
                                                  val_dataloader)
        else:
            combined_dataloader = itertools.chain(train_dataloader)

        print(f"\nTraining for epoch {epoch}:")
        for i, batch in enumerate(tqdm(combined_dataloader)):
            for key in batch:
                batch[key] = batch[key].to(device)

            optimizer.zero_grad()
            output = model(batch)

            if finetune:
                target = batch["gt_relevance"]
                # Same as for NDCG validation: only one round has dense annotations.
                output = output[torch.arange(output.size(0)),
                                batch["round_id"] - 1, :]
                # SA: todo regression loss
                if finetune_regression:
                    batch_loss = mse_loss(output, target, criterion)
                else:
                    batch_loss = compute_ndcg_type_loss(output, target)
            else:
                batch_loss = get_batch_criterion_loss_value(
                    config, batch, criterion, output)

            batch_loss.backward()
            optimizer.step()

            # ------------------------------------------------------------------
            # Update running loss and decay learning rates
            # ------------------------------------------------------------------
            if running_loss > 0.0:
                running_loss = 0.95 * running_loss + 0.05 * batch_loss.item()
            else:
                running_loss = batch_loss.item()

            # SA: lambda_lr was configured to reduce lr after milestone epochs
            if lr_scheduler_type == "lambda_lr":
                scheduler.step(global_iteration_step)

            global_iteration_step += 1
            if global_iteration_step % 100 == 0:
                # Print elapsed time, running average loss, learning rate, iteration, epoch
                print("[{}][Epoch: {:3d}][Iter: {:6d}][Loss: {:6f}][lr: {:8f}]".format(
                    datetime.datetime.utcnow() - train_begin, epoch,
                    global_iteration_step, running_loss,
                    optimizer.param_groups[0]['lr']))

                # tensorboardX
                summary_writer.add_scalar("train/loss", batch_loss,
                                          global_iteration_step)
                summary_writer.add_scalar("train/lr",
                                          optimizer.param_groups[0]["lr"],
                                          global_iteration_step)

        torch.cuda.empty_cache()

        # ---------------------------------------------------------------------
        # ON EPOCH END (checkpointing and validation)
        # ---------------------------------------------------------------------
        if not finetune:
            checkpoint_manager.step(epoch=epoch)
        else:
            print("Validating before checkpointing.")

        # SA: ideally this would be a separate function, but refactoring is too much work
        # Validate and report automatic metrics.
        if args.validate:
            # Switch dropout, batchnorm etc. to the correct mode.
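            # Validation pass (descriptive summary): run the full val split
            # once, accumulate the same loss used during training, and collect
            # sparse retrieval metrics + NDCG; the aggregated numbers drive the
            # "best loss" and "best NDCG" checkpointing below.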
            model.eval()
            val_loss = 0

            print(f"\nValidation after epoch {epoch}:")
            for i, batch in enumerate(tqdm(val_dataloader)):
                for key in batch:
                    batch[key] = batch[key].to(device)
                with torch.no_grad():
                    output = model(batch)

                if finetune:
                    target = batch["gt_relevance"]
                    # Same as for NDCG validation: only one round has dense annotations.
                    out_ndcg = output[torch.arange(output.size(0)),
                                      batch["round_id"] - 1, :]
                    # SA: todo regression loss
                    if finetune_regression:
                        batch_loss = mse_loss(out_ndcg, target, criterion)
                    else:
                        batch_loss = compute_ndcg_type_loss(out_ndcg, target)
                else:
                    batch_loss = get_batch_criterion_loss_value(
                        config, batch, criterion, output)
                val_loss += batch_loss.item()

                sparse_metrics.observe(output, batch["ans_ind"])
                if "gt_relevance" in batch:
                    output = output[torch.arange(output.size(0)),
                                    batch["round_id"] - 1, :]
                    ndcg.observe(output, batch["gt_relevance"])

            all_metrics = {}
            all_metrics.update(sparse_metrics.retrieve(reset=True))
            all_metrics.update(ndcg.retrieve(reset=True))
            for metric_name, metric_value in all_metrics.items():
                print(f"{metric_name}: {metric_value}")
            summary_writer.add_scalars("metrics", all_metrics,
                                       global_iteration_step)

            model.train()
            torch.cuda.empty_cache()

            val_loss = val_loss / len(val_dataloader)
            print(f"Validation loss for epoch {epoch} is {val_loss}")
            print(f"Validation loss for the last batch is {batch_loss}")
            # Log the epoch-averaged validation loss.
            summary_writer.add_scalar("val/loss", val_loss,
                                      global_iteration_step)

            if val_loss < best_val_loss:
                print(f" Best model found at epoch {epoch}! Saving now.")
                best_val_loss = val_loss
                if dense_annotation_type == "default":
                    checkpoint_manager.save_best()
            else:
                print(f" Not saving the model at epoch {epoch}!")

            # SA: saving the best model both for loss and for ndcg now
            val_ndcg = all_metrics["ndcg"]
            if val_ndcg > best_val_ndcg:
                print(f" Best ndcg model found at epoch {epoch}! Saving now.")
                best_val_ndcg = val_ndcg
                if dense_annotation_type == "default":
                    checkpoint_manager.save_best(ckpt_name="best_ndcg")
                else:
                    # SA: trying for dense annotations
                    ckpt_name = f"best_ndcg_annotation_{dense_annotation_type}"
                    checkpoint_manager.save_best(ckpt_name=ckpt_name)
            else:
                print(f" Not saving the model at epoch {epoch}!")

            # SA: "reduce_lr_on_plateau" works only with validate for now
            if lr_scheduler_type == "reduce_lr_on_plateau":
                # scheduler.step(val_loss)
                # SA: loss should decrease while ndcg should increase!
                # Can also change the mode of ReduceLROnPlateau to "max".
                scheduler.step(-1 * val_ndcg)
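            # Alternative (sketch, not wired in): if get_solver constructed the
            # scheduler with mode="max", the negation above would be unnecessary:
            #   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            #       optimizer, mode="max", factor=0.5, patience=1)
            #   ...
            #   scheduler.step(val_ndcg)
            # `factor` and `patience` here are illustrative values, not taken
            # from the project config.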