# Example #1
# 0
def main():
    """Train a model, log to Weights & Biases, and checkpoint the best epoch.

    Reads options via ``get_args()`` and records them in the wandb config.
    Seeds torch for reproducibility and builds the loaders, model and
    optimizer through ``initialize``. Then, for each epoch in
    ``1..args.epochs_number``, runs ``train`` and ``evaluation``, and calls
    ``save_checkpoint`` whenever validation accuracy strictly improves on the
    best value seen so far.
    """
    args = get_args()
    wandb.init()
    wandb.config.update(args)

    # Reproducibility: seed the CPU and all CUDA devices, make cuDNN use
    # deterministic algorithms, and disable cuDNN autotuning.
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The flag lives on torch.backends.cudnn; assigning
    # `torch.backends.deterministic` only creates an attribute nobody reads.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Passed to initialize(); presumably False requests a fresh (not loaded)
    # model -- TODO confirm against initialize().
    loaded_model = False

    train_loader, valid_loader, model, optimizer = initialize(args, loaded_model)
    # Gradient scaler for CUDA automatic mixed precision; consumed by train().
    scaler = torch.cuda.amp.GradScaler()

    wandb.watch(model)
    best_acc = 0
    run_avg = RunningAverage()

    for epoch in range(1, args.epochs_number + 1):
        # Each epoch starts from fresh train/validation running statistics.
        run_avg.reset_train()
        run_avg.reset_val()

        train(args, model, train_loader, epoch, optimizer, scaler, run_avg)
        val_acc = evaluation(args, model, valid_loader, epoch, run_avg)

        # Checkpoint only on a strict improvement in validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint(model, optimizer, args, epoch)
def test_running_average():
    """Exercise RunningAverage's train/val running means and its reset methods.

    Feeds known loss/accuracy sequences into the update methods. The second
    argument is always 1, presumably a per-update count or weight (TODO
    confirm). The test checks each running average against the arithmetic
    mean of the values fed so far. It then checks that ``reset_train()`` and
    ``reset_val()`` zero every sum, counter and average.
    """
    import math

    def close(actual, expected):
        # Averages are float sums divided by counts; compare with a tolerance
        # so the test does not hinge on exact binary rounding.
        return math.isclose(actual, expected, rel_tol=1e-9, abs_tol=1e-12)

    running_average = RunningAverage()

    # --- training statistics ---
    train_losses = [1, 0.5, 0.3]
    train_accuracies = [0.3, 0.5, 0.1]
    for loss, acc in zip(train_losses, train_accuracies):
        running_average.update_train_loss_avg(loss, 1)
        running_average.update_train_acc_avg(acc, 1)

    assert close(running_average.train_loss_run_avg, 0.6)   # (1 + 0.5 + 0.3) / 3
    assert close(running_average.train_acc_run_avg, 0.3)    # (0.3 + 0.5 + 0.1) / 3

    running_average.update_train_loss_avg(0.2, 1)
    running_average.update_train_acc_avg(0.1, 1)

    assert close(running_average.train_loss_run_avg, 0.5)   # (1.8 + 0.2) / 4
    assert close(running_average.train_acc_run_avg, 0.25)   # (0.9 + 0.1) / 4

    # --- validation statistics ---
    val_losses = [1, 0.7, 1.3]
    val_accuracies = [0.3, 0.4, 8.3]
    for loss, acc in zip(val_losses, val_accuracies):
        running_average.update_val_loss_avg(loss, 1)
        running_average.update_val_acc_avg(acc, 1)

    assert close(running_average.val_loss_run_avg, 1)       # (1 + 0.7 + 1.3) / 3
    assert close(running_average.val_acc_run_avg, 3)        # (0.3 + 0.4 + 8.3) / 3

    running_average.update_val_loss_avg(3, 1)
    running_average.update_val_acc_avg(7, 1)

    assert close(running_average.val_loss_run_avg, 1.5)     # (3 + 3) / 4
    assert close(running_average.val_acc_run_avg, 4)        # (9 + 7) / 4

    # --- reset ---
    running_average.reset_train()
    running_average.reset_val()

    reset_fields = (
        "sum_train_loss",
        "sum_train_acc",
        "train_loss_counter",
        "train_acc_counter",
        "train_loss_run_avg",
        "train_acc_run_avg",
        "sum_val_loss",
        "sum_val_acc",
        "val_loss_counter",
        "val_acc_counter",
        "val_loss_run_avg",
        "val_acc_run_avg",
    )
    for field in reset_fields:
        # A reset must restore an exact zero, so no tolerance here.
        assert getattr(running_average, field) == 0, field