# Example no. 1
def train(train_dataset, val_dataset, configs):
    """Train an AlexNet classifier and save the final weights to disk.

    Args:
        train_dataset: torch Dataset yielding (image, label) pairs for training.
        val_dataset: torch Dataset yielding (image, label) pairs for validation.
        configs: dict with keys "batch_size" (int), "lr" (float), "epochs" (int).

    Side effects:
        Writes the trained model's state dict to /opt/output/model.pt.
    """
    train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=configs["batch_size"],
            shuffle=True,
    )

    val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=configs["batch_size"],
            shuffle=False,
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = AlexNet().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=model.parameters(), lr=configs["lr"])

    for epoch in range(configs["epochs"]):

        model.train()
        running_loss = 0.0
        correct = 0

        # Iterate the loader directly so tqdm can infer the total batch count
        # (tqdm(enumerate(...)) cannot, and the index was unused anyway).
        for inputs, labels in tqdm(train_loader):

            # labels.squeeze(): collapse a trailing singleton dim so
            # CrossEntropyLoss sees a 1-D class-index target.
            inputs, labels = inputs.to(device), labels.squeeze().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()

            running_loss += loss.item()

        # BUG FIX: loss.item() is already a per-batch *mean*, so the epoch
        # average divides by the number of batches, not the sample count.
        # Also report the train accuracy that was computed but never printed.
        print("[%d] loss: %.4f, train acc: %.4f %%" %
                  (epoch + 1,
                   running_loss / len(train_loader),
                   100. * correct / len(train_dataset)))

        model.eval()
        correct = 0

        with torch.no_grad():

            for inputs, labels in tqdm(val_loader):

                inputs, labels = inputs.to(device), labels.squeeze().to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels).sum().item()

        # Message fixed: this loop measures the *validation* set, not a test set.
        print("Accuracy of the network on the %d validation images: %.4f %%" %
                (len(val_dataset), 100. * correct / len(val_dataset)))

    torch.save(model.state_dict(), "/opt/output/model.pt")
# NOTE(review): this second `train` shadows the `train` defined earlier in the
# file — only this definition is reachable after import. Consider renaming one.
def train(data_train, data_val, num_classes, num_epoch, milestones):
    """Train an AlexNet from scratch, validating after every epoch.

    Args:
        data_train: iterable of (inputs, labels) training batches.
        data_val: iterable of (inputs, labels) validation batches.
        num_classes: number of output classes passed to AlexNet.
        num_epoch: total number of training epochs.
        milestones: epoch indices at which MultiStepLR decays the LR by 0.1.

    Returns:
        The model after the final epoch (not necessarily the best one).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = AlexNet(num_classes, pretrain=False).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    lr_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    start_time = time.time()
    best_acc = 0
    best_epoch = 0

    for epoch in range(num_epoch):
        print('Epoch {}/{}'.format(epoch + 1, num_epoch))
        print('-' * 10)

        # ---- training pass ----
        epoch_loss = 0.0
        epoch_corrects = 0
        model.train()
        with torch.set_grad_enabled(True):
            for step, (inputs, labels) in enumerate(data_train):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                # per-batch fraction correct; averaged over batches below
                epoch_corrects += torch.sum(preds == labels.data) * 1. / inputs.size(0)
                print("\rIteration: {}/{}, Loss: {}.".format(step + 1, len(data_train), loss.item()), end="")
                sys.stdout.flush()

        avg_loss = epoch_loss / len(data_train)
        t_acc = epoch_corrects.double() / len(data_train)

        # ---- validation pass ----
        epoch_loss = 0.0
        epoch_corrects = 0
        model.eval()
        with torch.set_grad_enabled(False):
            for step, (inputs, labels) in enumerate(data_val):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                epoch_loss += loss.item()
                epoch_corrects += torch.sum(preds == labels.data) * 1. / inputs.size(0)

        val_loss = epoch_loss / len(data_val)
        val_acc = epoch_corrects.double() / len(data_val)

        print()
        print('Train Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, t_acc))
        print('Val Loss: {:.4f} Acc: {:.4f}'.format(val_loss, val_acc))
        print('lr rate: {:.6f}'.format(optimizer.param_groups[0]['lr']))
        print()

        # track the best validation accuracy seen so far (reported at the end)
        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch + 1

        lr_scheduler.step()

    time_elapsed = time.time() - start_time
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Validation Accuracy: {}, Epoch: {}'.format(best_acc, best_epoch))

    return model