Example 1
import torch

from visdialch.metrics import NDCG, SparseGTMetrics


def eval_pred(pred, answer_index, round_id, gt_relevance):
    """
    Evaluate the predicted results and report metrics. Only for the val split.

    Parameters
    ----------
    pred: ndarray of shape (n_samples, n_rounds, n_options).
    answer_index: ndarray of shape (n_samples, n_rounds).
    round_id: ndarray of shape (n_samples, ).
    gt_relevance: ndarray of shape (n_samples, n_options).

    Returns
    -------
    None
    """
    # Convert to torch tensors so that visdialch.metrics can be used.
    pred = torch.Tensor(pred)
    answer_index = torch.Tensor(answer_index).long()
    round_id = torch.Tensor(round_id).long()
    gt_relevance = torch.Tensor(gt_relevance)

    sparse_metrics = SparseGTMetrics()
    ndcg = NDCG()

    sparse_metrics.observe(pred, answer_index)
    # NDCG is computed only on the round that has dense gt_relevance annotations.
    pred = pred[torch.arange(pred.size(0)), round_id - 1, :]
    ndcg.observe(pred, gt_relevance)

    all_metrics = {}
    all_metrics.update(sparse_metrics.retrieve(reset=True))
    all_metrics.update(ndcg.retrieve(reset=True))
    for metric_name, metric_value in all_metrics.items():
        print(f"{metric_name}: {metric_value}")
Example 2
def get_1round_batch_data(batch, rnd):
    # Build a batch for dialog round `rnd`: per-round tensors are sliced at
    # that round, the history keeps rounds 0..rnd. The first branch condition
    # below is an assumption (per-dialog keys such as 'img_feat'); it is not
    # shown in the original snippet.
    temp_train_batch = {}
    for key in batch:
        if key in ['img_feat']:
            temp_train_batch[key] = batch[key].to(device)
        elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind']:
            temp_train_batch[key] = batch[key][:, rnd].to(device)
        elif key in ['hist_len', 'hist']:
            temp_train_batch[key] = batch[key][:, :rnd + 1].to(device)
        else:
            pass
    return temp_train_batch

sparse_metrics = SparseGTMetrics()
ndcg = NDCG()

model.eval()
for i, batch in enumerate(val_dataloader):
    batchsize = batch['img_ids'].shape[0]
    # Score round 0 first, then append rounds 1..9 -> output of shape (batch, 10, 100).
    rnd = 0
    temp_train_batch = get_1round_batch_data(batch, rnd)
    output = model(temp_train_batch).view(-1, 1, 100).detach()
    for rnd in range(1, 10):
        temp_train_batch = get_1round_batch_data(batch, rnd)
        output = torch.cat((output, model(temp_train_batch).view(-1, 1, 100).detach()), dim=1)
    sparse_metrics.observe(output, batch["ans_ind"])
    if "relevance" in batch:
        output = output[torch.arange(output.size(0)), batch["round_id"] - 1, :]
        ndcg.observe(output.view(-1, 100), batch["relevance"].contiguous().view(-1, 100))
    # if i > 5:  # for debugging (similar to --overfit)
    #     break
all_metrics = {}
all_metrics.update(sparse_metrics.retrieve(reset=True))
all_metrics.update(ndcg.retrieve(reset=True))
for metric_name, metric_value in all_metrics.items():
    print(f"{metric_name}: {metric_value}")
model.train()
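
The indexing output[torch.arange(...), round_id - 1, :] used above selects, for each dialog, only the round that carries dense relevance annotations. A tiny self-contained illustration with random stand-in tensors (shapes chosen to match the loop above):

import torch

output = torch.randn(4, 10, 100)        # scores for 4 dialogs, 10 rounds, 100 options
round_id = torch.tensor([3, 1, 10, 7])  # 1-based round with dense annotations
picked = output[torch.arange(output.size(0)), round_id - 1, :]
print(picked.shape)  # torch.Size([4, 100]) -- the per-dialog scores passed to ndcg.observe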
Example 3
def train(config,
          args,
          dataloader_dic,
          device,
          finetune: bool = False,
          load_pthpath: str = "",
          finetune_regression: bool = False,
          dense_scratch_train: bool = False,
          dense_annotation_type: str = "default"):
    """

    :param config:
    :param args:
    :param dataloader_dic:
    :param device:
    :param finetune:
    :param load_pthpath:
    :param finetune_regression:
    :param dense_scratch_train: when we want to start training only on 2000 annotations
    :param dense_annotation_type: default
    :return:
    """
    # =============================================================================
    #   SETUP BEFORE TRAINING LOOP
    # =============================================================================
    train_dataset = dataloader_dic["train_dataset"]
    train_dataloader = dataloader_dic["train_dataloader"]
    val_dataloader = dataloader_dic["val_dataloader"]
    val_dataset = dataloader_dic["val_dataset"]

    model = get_model(config, args, train_dataset, device)

    if finetune and not dense_scratch_train:
        assert load_pthpath != "", \
            "Please provide a path to a pre-trained model " \
            "before starting fine-tuning."
        print("\nBegin finetuning:")

    optimizer, scheduler, iterations, lr_scheduler_type = get_solver(
        config, args, train_dataset, val_dataset, model, finetune=finetune)

    start_time = datetime.datetime.strftime(datetime.datetime.utcnow(),
                                            '%d-%b-%Y-%H:%M:%S')
    if args.save_dirpath == 'checkpoints/':
        args.save_dirpath += '%s+%s/%s' % (
            config["model"]["encoder"], config["model"]["decoder"], start_time)
    summary_writer = SummaryWriter(log_dir=args.save_dirpath)
    checkpoint_manager = CheckpointManager(model,
                                           optimizer,
                                           args.save_dirpath,
                                           config=config)
    sparse_metrics = SparseGTMetrics()
    ndcg = NDCG()
    best_val_loss = np.inf  # SA: initially loss can be any number
    best_val_ndcg = 0.0
    # If loading from checkpoint, adjust start epoch and load parameters.

    # SA: 1. if finetuning -> load from saved model
    # 2. train -> default load_pthpath = ""
    # 3. else load pthpath
    if (not finetune and load_pthpath == "") or dense_scratch_train:
        start_epoch = 1
    else:
        # "path/to/checkpoint_xx.pth" -> xx
        ### To cater model finetuning from models with "best_ndcg" checkpoint
        try:
            start_epoch = int(load_pthpath.split("_")[-1][:-4]) + 1
        except:
            start_epoch = 1

        model_state_dict, optimizer_state_dict = load_checkpoint(load_pthpath)

        # SA: updating last epoch
        checkpoint_manager.update_last_epoch(start_epoch)

        if isinstance(model, nn.DataParallel):
            model.module.load_state_dict(model_state_dict)
        else:
            model.load_state_dict(model_state_dict)

        # SA: when finetuning, the optimizer should start from its own learning rate
        if not finetune:
            optimizer.load_state_dict(optimizer_state_dict)
        else:
            print("Optimizer not loaded. Different optimizer for finetuning.")
        print("Loaded model from {}".format(load_pthpath))

    # =============================================================================
    #   TRAINING LOOP
    # =============================================================================

    # Forever increasing counter to keep track of iterations (for tensorboard log).
    global_iteration_step = (start_epoch - 1) * iterations

    running_loss = 0.0  # exponential moving average of the batch loss
    train_begin = datetime.datetime.utcnow()

    if finetune:
        end_epoch = start_epoch + config["solver"]["num_epochs_curriculum"] - 1
        if finetune_regression:
            # criterion = nn.MSELoss(reduction='mean')
            # criterion = nn.KLDivLoss(reduction='mean')
            criterion = nn.MultiLabelSoftMarginLoss()
    else:
        end_epoch = config["solver"]["num_epochs"]
        # SA: normal training
        criterion = get_loss_criterion(config, train_dataset)

    # SA: end_epoch + 1 => for loop also doing last epoch
    for epoch in range(start_epoch, end_epoch + 1):
        # -------------------------------------------------------------------------
        #   ON EPOCH START  (combine dataloaders if training on train + val)
        # -------------------------------------------------------------------------
        if config["solver"]["training_splits"] == "trainval":
            combined_dataloader = itertools.chain(train_dataloader,
                                                  val_dataloader)
        else:
            combined_dataloader = itertools.chain(train_dataloader)

        print(f"\nTraining for epoch {epoch}:")
        for i, batch in enumerate(tqdm(combined_dataloader)):
            for key in batch:
                batch[key] = batch[key].to(device)

            optimizer.zero_grad()
            output = model(batch)

            if finetune:
                target = batch["gt_relevance"]
                # Same as for ndcg validation, only one round is present
                output = output[torch.arange(output.size(0)),
                                batch["round_id"] - 1, :]
                # SA: todo regression loss
                if finetune_regression:
                    batch_loss = mse_loss(output, target, criterion)
                else:
                    batch_loss = compute_ndcg_type_loss(output, target)
            else:
                batch_loss = get_batch_criterion_loss_value(
                    config, batch, criterion, output)

            batch_loss.backward()
            optimizer.step()

            # --------------------------------------------------------------------
            # update running loss and decay learning rates
            # --------------------------------------------------------------------
            if running_loss > 0.0:
                running_loss = 0.95 * running_loss + 0.05 * batch_loss.item()
            else:
                running_loss = batch_loss.item()

            # SA: lambda_lr was configured to reduce lr after milestone epochs
            if lr_scheduler_type == "lambda_lr":
                scheduler.step(global_iteration_step)

            global_iteration_step += 1

            if global_iteration_step % 100 == 0:
                # print current time, running average, learning rate, iteration, epoch
                print(
                    "[{}][Epoch: {:3d}][Iter: {:6d}][Loss: {:6f}][lr: {:8f}]".
                    format(datetime.datetime.utcnow() - train_begin, epoch,
                           global_iteration_step, running_loss,
                           optimizer.param_groups[0]['lr']))

                # tensorboardX
                summary_writer.add_scalar("train/loss", batch_loss,
                                          global_iteration_step)
                summary_writer.add_scalar("train/lr",
                                          optimizer.param_groups[0]["lr"],
                                          global_iteration_step)
        torch.cuda.empty_cache()

        # -------------------------------------------------------------------------
        #   ON EPOCH END  (checkpointing and validation)
        # -------------------------------------------------------------------------
        if not finetune:
            checkpoint_manager.step(epoch=epoch)
        else:
            print("Validating before checkpointing.")

        # SA: ideally this block would be a separate function; too much work to refactor
        # Validate and report automatic metrics.
        if args.validate:

            # Switch dropout, batchnorm etc to the correct mode.
            model.eval()
            val_loss = 0

            print(f"\nValidation after epoch {epoch}:")
            for i, batch in enumerate(tqdm(val_dataloader)):
                for key in batch:
                    batch[key] = batch[key].to(device)
                with torch.no_grad():
                    output = model(batch)
                    if finetune:
                        target = batch["gt_relevance"]
                        # Same as for ndcg validation, only one round is present
                        out_ndcg = output[torch.arange(output.size(0)),
                                          batch["round_id"] - 1, :]
                        # SA: todo regression loss
                        if finetune_regression:
                            batch_loss = mse_loss(out_ndcg, target, criterion)
                        else:
                            batch_loss = compute_ndcg_type_loss(
                                out_ndcg, target)
                    else:
                        batch_loss = get_batch_criterion_loss_value(
                            config, batch, criterion, output)

                    val_loss += batch_loss.item()
                sparse_metrics.observe(output, batch["ans_ind"])
                if "gt_relevance" in batch:
                    output = output[torch.arange(output.size(0)),
                                    batch["round_id"] - 1, :]
                    ndcg.observe(output, batch["gt_relevance"])

            all_metrics = {}
            all_metrics.update(sparse_metrics.retrieve(reset=True))
            all_metrics.update(ndcg.retrieve(reset=True))
            for metric_name, metric_value in all_metrics.items():
                print(f"{metric_name}: {metric_value}")
            summary_writer.add_scalars("metrics", all_metrics,
                                       global_iteration_step)

            model.train()
            torch.cuda.empty_cache()

            val_loss = val_loss / len(val_dataloader)
            print(f"Validation loss for epoch {epoch} is {val_loss}")
            print(f"Validation loss for the last batch is {batch_loss.item()}")

            summary_writer.add_scalar("val/loss", val_loss,
                                      global_iteration_step)

            if val_loss < best_val_loss:
                print(f"Best model found at epoch {epoch}! Saving now.")
                best_val_loss = val_loss
                if dense_annotation_type == "default":
                    checkpoint_manager.save_best()
            else:
                print(f"Not saving the model at epoch {epoch}!")

            # SA: Saving the best model both for loss and ndcg now
            val_ndcg = all_metrics["ndcg"]
            if val_ndcg > best_val_ndcg:
                print(f" Best ndcg model found at {epoch} epoch! Saving now.")
                best_val_ndcg = val_ndcg
                if dense_annotation_type == "default":
                    checkpoint_manager.save_best(ckpt_name="best_ndcg")
                else:
                    # SA: trying for dense annotations
                    ckpt_name = f"best_ndcg_annotation_{dense_annotation_type}"
                    checkpoint_manager.save_best(ckpt_name=ckpt_name)
            else:
                print(f" Not saving the model at {epoch} epoch!")

            # SA: "reduce_lr_on_plateau" works only with validate for now
            if lr_scheduler_type == "reduce_lr_on_plateau":
                # scheduler.step(val_loss)
                # SA: loss should decrease while ndcg should increase, so step
                # on the negated ndcg (alternatively, set the scheduler mode to "max").
                scheduler.step(-1 * val_ndcg)
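
When lr_scheduler_type is "reduce_lr_on_plateau", the scheduler is stepped on -val_ndcg so that a scheduler in "min" mode treats a stalling NDCG as a plateau. A standalone sketch of that trick; the model, optimizer, and metric values below are placeholders, not the objects built by train:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=1)

for epoch, val_ndcg in enumerate([0.52, 0.55, 0.55, 0.55, 0.56]):
    # Rising ndcg -> falling (-ndcg) -> no plateau; flat or falling ndcg
    # -> plateau -> the learning rate is reduced by `factor`.
    scheduler.step(-1 * val_ndcg)
    print(epoch, optimizer.param_groups[0]["lr"])

An equivalent alternative, as the comment in the code notes, is to construct ReduceLROnPlateau with mode="max" and step it on val_ndcg directly.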