Example #1
def train(model,
          criterion,
          dataset,
          logger,
          train_csv_logger,
          val_csv_logger,
          test_csv_logger,
          args,
          epoch_offset,
          train=True):
    model = model.cuda()
    # process the generalization adjustment: a single value is broadcast to every
    # group (e.g. '2.0' with 4 groups gives [2.0, 2.0, 2.0, 2.0]); otherwise one
    # comma-separated value per group is expected
    adjustments = [float(c) for c in args.generalization_adjustment.split(',')]
    assert len(adjustments) in (1, dataset['train_data'].n_groups)
    if len(adjustments) == 1:
        adjustments = np.array(adjustments * dataset['train_data'].n_groups)
    else:
        adjustments = np.array(adjustments)

    train_loss_computer = LossComputer(criterion,
                                       is_robust=args.robust,
                                       dataset=dataset['train_data'],
                                       alpha=args.alpha,
                                       gamma=args.gamma,
                                       adj=adjustments,
                                       step_size=args.robust_step_size)

    # BERT uses its own scheduler and optimizer
    if args.model.startswith('bert'):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': args.weight_decay
            },
            {
                'params': [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.lr,
                          eps=args.adam_epsilon)
        t_total = len(dataset['train_loader']) * args.n_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    else:
        params_list = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
        scheduler = None

    best_val_acc = 0
    for epoch in range(epoch_offset, epoch_offset + args.n_epochs):
        logger.write('\nEpoch [%d]:\n' % epoch)
        logger.write(f'Training:\n')
        if train:
            run_epoch(epoch,
                      model,
                      optimizer,
                      dataset['train_loader'],
                      train_loss_computer,
                      logger,
                      train_csv_logger,
                      args,
                      is_training=True,
                      show_progress=args.show_progress,
                      log_every=args.log_every,
                      scheduler=scheduler)

        logger.write(f'\nValidation:\n')
        val_loss_computer = LossComputer(criterion,
                                         is_robust=args.robust,
                                         dataset=dataset['val_data'],
                                         step_size=args.robust_step_size,
                                         alpha=args.alpha)
        run_epoch(epoch,
                  model,
                  optimizer,
                  dataset['val_loader'],
                  val_loss_computer,
                  logger,
                  val_csv_logger,
                  args,
                  is_training=False)

        if dataset['test_data'] is not None:
            test_loss_computer = LossComputer(criterion,
                                              is_robust=args.robust,
                                              dataset=dataset['test_data'],
                                              step_size=args.robust_step_size,
                                              alpha=args.alpha)
            run_epoch(epoch,
                      model,
                      optimizer,
                      dataset['test_loader'],
                      test_loss_computer,
                      None,
                      test_csv_logger,
                      args,
                      is_training=False)

        # Inspect learning rates
        if (epoch + 1) % 1 == 0:
            for param_group in optimizer.param_groups:
                curr_lr = param_group['lr']
                logger.write('Current lr: %f\n' % curr_lr)

        if epoch % args.save_step == 0:
            torch.save(model, os.path.join(args.log_dir,
                                           '%d_model.pth' % epoch))

        if args.save_last:
            torch.save(model, os.path.join(args.log_dir, 'last_model.pth'))

        if args.save_best:
            if args.robust or args.reweight_groups:
                curr_val_acc = min(val_loss_computer.avg_group_acc)
            else:
                curr_val_acc = val_loss_computer.avg_acc
            logger.write(f'Current validation accuracy: {curr_val_acc}\n')
            if curr_val_acc > best_val_acc:
                best_val_acc = curr_val_acc
                torch.save(model, os.path.join(args.log_dir, 'best_model.pth'))
                logger.write(f'Best model saved at epoch {epoch}\n')
        logger.write('\n')
    return model
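
The BERT branch above splits the parameters into a decayed group and a no-decay group before handing them to AdamW with a linear warmup schedule. Below is a minimal, self-contained sketch of that setup; the TinyHead module and the literal hyperparameter values are stand-ins for `model` and the `args` fields, `get_linear_schedule_with_warmup` is assumed to come from Hugging Face `transformers` as in the snippet, and `torch.optim.AdamW` is used in place of the `AdamW` imported there.

# Minimal sketch of the no-decay parameter grouping + AdamW + linear warmup
# setup above. Assumptions: torch.optim.AdamW stands in for the AdamW used in
# the snippet, and the module / hyperparameter values below are placeholders.
import torch
import torch.nn as nn
from transformers import get_linear_schedule_with_warmup


class TinyHead(nn.Module):
    """Toy stand-in whose parameter names mimic a BERT head (LayerNorm + classifier)."""

    def __init__(self):
        super().__init__()
        self.LayerNorm = nn.LayerNorm(16)
        self.classifier = nn.Linear(16, 2)

    def forward(self, x):
        return self.classifier(self.LayerNorm(x))


model = TinyHead()
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},    # args.weight_decay
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)

steps_per_epoch, n_epochs, warmup_steps = 100, 3, 10   # len(train_loader), args.n_epochs, args.warmup_steps
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=steps_per_epoch * n_epochs)

for step in range(steps_per_epoch * n_epochs):
    optimizer.step()    # a real loop would compute a loss and backprop first
    scheduler.step()    # linear warmup for the first 10 steps, then linear decay to 0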
Example #2
def train(model, criterion, dataset, logger, train_csv_logger, val_csv_logger,
          test_csv_logger, args, epoch_offset):
    model = model.cuda()

    # process the generalization adjustment
    adjustments = [float(c) for c in args.generalization_adjustment.split(',')]
    assert len(adjustments) in (1, dataset['train_data'].n_groups)
    if len(adjustments) == 1:
        adjustments = np.array(adjustments * dataset['train_data'].n_groups)
    else:
        adjustments = np.array(adjustments)

    train_loss_computer = LossComputer(
        criterion,
        is_robust=args.robust,
        dataset=dataset['train_data'],
        alpha=args.alpha,
        gamma=args.gamma,
        adj=adjustments,
        step_size=args.robust_step_size,
        normalize_loss=args.use_normalized_loss,
        btl=args.btl,
        min_var_weight=args.minimum_variational_weight,
        sp=args.sp,
        mode=args.mode,
        ratio=args.ratio)

    # BERT uses its own scheduler and optimizer
    if args.model == 'bert':
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': args.weight_decay
            },
            {
                'params': [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.lr,
                          eps=args.adam_epsilon)
        t_total = len(dataset['train_loader']) * args.n_epochs
        print(f'\nt_total is {t_total}\n')
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)
    else:
        if args.adam:
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                model.parameters()),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                               model.parameters()),
                                        lr=args.lr,
                                        momentum=0.9,
                                        weight_decay=args.weight_decay)
        if args.scheduler:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                'min',
                factor=0.1,
                patience=5,
                threshold=0.0001,
                min_lr=0,
                eps=1e-08)
        else:
            scheduler = None

    best_val_acc = 0
    for epoch in range(epoch_offset, epoch_offset + args.n_epochs):
        logger.write('\nEpoch [%d]:\n' % epoch)
        logger.write(f'Training:\n')
        run_epoch(epoch,
                  model,
                  optimizer,
                  dataset['train_loader'],
                  train_loss_computer,
                  logger,
                  train_csv_logger,
                  args,
                  is_training=True,
                  show_progress=args.show_progress,
                  log_every=args.log_every,
                  scheduler=scheduler)

        logger.write(f'\nValidation:\n')
        val_loss_computer = LossComputer(criterion,
                                         is_robust=args.robust,
                                         dataset=dataset['val_data'],
                                         step_size=args.robust_step_size,
                                         alpha=args.alpha)
        run_epoch(epoch,
                  model,
                  optimizer,
                  dataset['val_loader'],
                  val_loss_computer,
                  logger,
                  val_csv_logger,
                  args,
                  is_training=False)

        # Test set; don't print to avoid peeking
        # if dataset['test_data'] is not None:
        #     test_loss_computer = LossComputer(
        #         criterion,
        #         is_robust=args.robust,
        #         dataset=dataset['test_data'],
        #         step_size=args.robust_step_size,
        #         alpha=args.alpha)
        #     run_epoch(
        #         epoch, model, optimizer,
        #         dataset['test_loader'],
        #         test_loss_computer,
        #         None, test_csv_logger, args,
        #         is_training=False)

        # Inspect learning rates
        if (epoch + 1) % 1 == 0:
            for param_group in optimizer.param_groups:
                curr_lr = param_group['lr']
                logger.write('Current lr: %f\n' % curr_lr)

        if args.scheduler and args.model != 'bert':
            if args.robust:
                val_loss, _ = val_loss_computer.compute_robust_loss_greedy(
                    val_loss_computer.avg_group_loss,
                    val_loss_computer.avg_group_loss)
            else:
                val_loss = val_loss_computer.avg_actual_loss
            scheduler.step(val_loss)  # scheduler step to update the LR at the end of the epoch

        if epoch % args.save_step == 0:
            torch.save(model, os.path.join(args.log_dir,
                                           '%d_model.pth' % epoch))

        if args.save_last:
            torch.save(model, os.path.join(args.log_dir, 'last_model.pth'))

        if args.save_best:
            if args.robust or args.reweight_groups:
                curr_val_acc = min(val_loss_computer.avg_group_acc)
            else:
                curr_val_acc = val_loss_computer.avg_acc
            logger.write(f'Current validation accuracy: {curr_val_acc}\n')
            if curr_val_acc > best_val_acc:
                best_val_acc = curr_val_acc
                torch.save(model, os.path.join(args.log_dir, 'best_model.pth'))
                logger.write(f'Best model saved at epoch {epoch}\n')

        if args.automatic_adjustment:
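            # estimate each group's generalization gap (validation group loss minus
            # the exponentially averaged training loss) and scale it by the square
            # root of the group size to form the new per-group adjustment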
            gen_gap = val_loss_computer.avg_group_loss - train_loss_computer.exp_avg_loss
            adjustments = gen_gap * torch.sqrt(
                train_loss_computer.group_counts)
            train_loss_computer.adj = adjustments
            logger.write('Adjustments updated\n')
            for group_idx in range(train_loss_computer.n_groups):
                logger.write(
                    f'  {train_loss_computer.get_group_name(group_idx)}:\t'
                    f'adj = {train_loss_computer.adj[group_idx]:.3f}\n')
        logger.write('\n')
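
In the non-BERT branch above, the ReduceLROnPlateau scheduler is stepped once per epoch on the (robust) validation loss. A minimal, self-contained sketch of that interaction follows; the toy linear model, the synthetic validation-loss curve, and the hyperparameter values are placeholders for `model`, `val_loss_computer.avg_actual_loss`, and the `args` fields.

# Minimal sketch of the SGD + ReduceLROnPlateau branch above. The model, the
# synthetic validation-loss curve and the hyperparameters are placeholders.
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.1, patience=5, threshold=0.0001, min_lr=0, eps=1e-08)

for epoch in range(20):
    # run_epoch(...) would train and evaluate here
    val_loss = max(0.5, 1.0 - 0.1 * epoch)    # improves, then plateaus from epoch 5 on
    scheduler.step(val_loss)                  # LR drops by `factor` after `patience` stagnant epochs
    print(epoch, optimizer.param_groups[0]['lr'])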
Example #3
def train(
    model,
    criterion,
    dataset,
    logger,
    train_csv_logger,
    val_csv_logger,
    test_csv_logger,
    args,
    epoch_offset,
    csv_name=None,
    wandb=None,
):
    model = model.to(device)

    # process the generalization adjustment
    adjustments = [float(c) for c in args.generalization_adjustment.split(",")]
    assert len(adjustments) in (1, dataset["train_data"].n_groups)
    if len(adjustments) == 1:
        adjustments = np.array(adjustments * dataset["train_data"].n_groups)
    else:
        adjustments = np.array(adjustments)

    train_loss_computer = LossComputer(
        criterion,
        loss_type=args.loss_type,
        dataset=dataset["train_data"],
        alpha=args.alpha,
        gamma=args.gamma,
        adj=adjustments,
        step_size=args.robust_step_size,
        normalize_loss=args.use_normalized_loss,
        btl=args.btl,
        min_var_weight=args.minimum_variational_weight,
        joint_dro_alpha=args.joint_dro_alpha,
    )

    # BERT uses its own scheduler and optimizer
    if (args.model.startswith("bert") and args.use_bert_params):
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.lr,
                          eps=args.adam_epsilon)
        t_total = len(dataset["train_loader"]) * args.n_epochs
        print(f"\nt_total is {t_total}\n")
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)
    else:
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr,
            momentum=0.9,
            weight_decay=args.weight_decay,
        )
        if args.scheduler:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                "min",
                factor=0.1,
                patience=5,
                threshold=0.0001,
                min_lr=0,
                eps=1e-08,
            )
        else:
            scheduler = None

    best_val_acc = 0
    for epoch in range(epoch_offset, epoch_offset + args.n_epochs):
        logger.write("\nEpoch [%d]:\n" % epoch)
        logger.write(f"Training:\n")
        run_epoch(
            epoch,
            model,
            optimizer,
            dataset["train_loader"],
            train_loss_computer,
            logger,
            train_csv_logger,
            args,
            is_training=True,
            csv_name=csv_name,
            show_progress=args.show_progress,
            log_every=args.log_every,
            scheduler=scheduler,
            wandb_group="train",
            wandb=wandb,
        )

        logger.write(f"\nValidation:\n")
        val_loss_computer = LossComputer(
            criterion,
            loss_type=args.loss_type,
            dataset=dataset["val_data"],
            alpha=args.alpha,
            gamma=args.gamma,
            adj=adjustments,
            step_size=args.robust_step_size,
            normalize_loss=args.use_normalized_loss,
            btl=args.btl,
            min_var_weight=args.minimum_variational_weight,
            joint_dro_alpha=args.joint_dro_alpha,
        )
        run_epoch(
            epoch,
            model,
            optimizer,
            dataset["val_loader"],
            val_loss_computer,
            logger,
            val_csv_logger,
            args,
            is_training=False,
            csv_name=csv_name,
            wandb_group="val",
            wandb=wandb,
        )

        # Test set; don't print to avoid peeking
        if dataset["test_data"] is not None:
            test_loss_computer = LossComputer(
                criterion,
                loss_type=args.loss_type,
                dataset=dataset["test_data"],
                step_size=args.robust_step_size,
                alpha=args.alpha,
                gamma=args.gamma,
                adj=adjustments,
                normalize_loss=args.use_normalized_loss,
                btl=args.btl,
                min_var_weight=args.minimum_variational_weight,
                joint_dro_alpha=args.joint_dro_alpha,
            )
            run_epoch(
                epoch,
                model,
                optimizer,
                dataset["test_loader"],
                test_loss_computer,
                None,
                test_csv_logger,
                args,
                is_training=False,
                csv_name=csv_name,
                wandb_group="test",
                wandb=wandb,
            )

        # Inspect learning rates
        if (epoch + 1) % 1 == 0:
            for param_group in optimizer.param_groups:
                curr_lr = param_group["lr"]
                logger.write("Current lr: %f\n" % curr_lr)

        if args.scheduler and args.model != "bert":
            if args.loss_type == "group_dro":
                val_loss, _ = val_loss_computer.compute_robust_loss_greedy(
                    val_loss_computer.avg_group_loss,
                    val_loss_computer.avg_group_loss)
            else:
                val_loss = val_loss_computer.avg_actual_loss
            scheduler.step(val_loss)  # scheduler step to update the LR at the end of the epoch

        if epoch % args.save_step == 0:
            torch.save(model, os.path.join(args.log_dir,
                                           "%d_model.pth" % epoch))

        if args.save_last:
            torch.save(model, os.path.join(args.log_dir, "last_model.pth"))

        if args.save_best:
            if args.loss_type == "group_dro" or args.reweight_groups:
                curr_val_acc = min(val_loss_computer.avg_group_acc)
            else:
                curr_val_acc = val_loss_computer.avg_acc
            logger.write(f"Current validation accuracy: {curr_val_acc}\n")
            if curr_val_acc > best_val_acc:
                best_val_acc = curr_val_acc
                torch.save(model, os.path.join(args.log_dir, "best_model.pth"))
                logger.write(f"Best model saved at epoch {epoch}\n")

        if args.automatic_adjustment:
            gen_gap = val_loss_computer.avg_group_loss - train_loss_computer.exp_avg_loss
            adjustments = gen_gap * torch.sqrt(
                train_loss_computer.group_counts)
            train_loss_computer.adj = adjustments
            logger.write("Adjustments updated\n")
            for group_idx in range(train_loss_computer.n_groups):
                logger.write(
                    f"  {train_loss_computer.get_group_name(group_idx)}:\t"
                    f"adj = {train_loss_computer.adj[group_idx]:.3f}\n")
        logger.write("\n")