def run(net):
    """Train ``net`` and checkpoint it whenever validation accuracy improves.

    Data loaders come from ``prepare_data()``; the learning rate, epoch count
    and run name are read from the module-level ``hps`` mapping (keys ``'lr'``,
    ``'n_epochs'`` and ``'name'``). Per-epoch loss and accuracy for both splits
    are appended to the module-level ``logger``, which is saved every fifth
    epoch. The function returns nothing.

    Args:
        net: model to train; it is moved to the module-level ``device`` first.
    """
    trainloader, valloader = prepare_data()
    net = net.to(device)

    # Nesterov-momentum SGD with a small fixed weight decay.
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=hps['lr'],
        momentum=0.9,
        nesterov=True,
        weight_decay=0.0001,
    )
    # mode='max' because the monitored quantity is validation accuracy:
    # the LR is multiplied by 0.5 once it stops increasing (patience=10).
    scheduler = ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=10, verbose=True
    )
    criterion = nn.CrossEntropyLoss()

    best_acc_v = 0
    print("Training", hps['name'], "on", device)

    for epoch in range(hps['n_epochs']):
        epoch_no = epoch + 1

        # One pass over each split; record the metrics as soon as they exist.
        for step, loader, losses, accs, extra in (
            (train, trainloader, logger.loss_train, logger.acc_train, (optimizer,)),
            (evaluate, valloader, logger.loss_val, logger.acc_val, ()),
        ):
            acc, loss = step(net, loader, criterion, *extra)
            losses.append(loss)
            accs.append(acc)
            if step is train:
                acc_tr = acc
            else:
                acc_v = acc

        # The scheduler decides on a plateau from the validation accuracy.
        scheduler.step(acc_v)

        # Persist the logs every fifth epoch.
        if epoch_no % 5 == 0:
            logger.save(hps)

        # Checkpoint only on a strict improvement over the best seen so far.
        if acc_v > best_acc_v:
            save(net, hps)
            best_acc_v = acc_v
            saved_note = ('Network Saved',)
        else:
            saved_note = ()

        print('Epoch %2d' % epoch_no,
              'Train Accuracy: %2.2f %%' % acc_tr,
              'Val Accuracy: %2.2f %%' % acc_v,
              *saved_note,
              sep='\t\t')
示例#2
0
                                args.batch_accumulation),
        )

        avg_loss, iteration = train_loop(
            model,
            train_loader,
            optimizer,
            criterion,
            scheduler,
            args,
            iteration,
        )

        avg_val_loss, score, val_preds = evaluate(args,
                                                  model,
                                                  valid_loader,
                                                  criterion,
                                                  val_shape=len(valid_set))

        print("Epoch {}/{}: \t loss={:.4f} \t val_loss={:.4f} \t score={:.6f}".
              format(epoch + 1, args.epochs, avg_loss, avg_val_loss, score))
        with open(log_file, "a") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow([epoch + 1, avg_loss, avg_val_loss, score])

        torch.save(
            model.state_dict(),
            os.path.join(fold_checkpoints,
                         "model_on_epoch_{}.pth".format(epoch)),
        )
        val_preds_df = val_fold_df.copy()[["qa_id"] + args.target_columns]