def train(train_dataset, val_dataset, configs):
    """Train an AlexNet classifier and save its weights to /opt/output/model.pt.

    Args:
        train_dataset: dataset yielding (image, label) pairs for training.
        val_dataset: dataset yielding (image, label) pairs for validation.
        configs: dict with keys "batch_size", "lr", and "epochs".

    Side effects:
        Prints per-epoch training loss/accuracy and validation accuracy,
        then writes the final model state dict to "/opt/output/model.pt".
    """
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=configs["batch_size"], shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=configs["batch_size"], shuffle=False
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = AlexNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=model.parameters(), lr=configs["lr"])

    for epoch in range(configs["epochs"]):
        # ---- training phase ----
        model.train()
        running_loss = 0.0
        correct = 0
        for inputs, labels in tqdm(train_loader):
            # Labels may arrive as (N, 1); squeeze to (N,) for CrossEntropyLoss.
            inputs, labels = inputs.to(device), labels.squeeze().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()
        # BUG FIX: loss.item() is already a per-batch mean, so average over the
        # number of batches, not the number of samples.  Also report the train
        # accuracy that was previously computed but never printed.
        print("[%d] loss: %.4f acc: %.4f" % (
            epoch + 1,
            running_loss / len(train_loader),
            correct / len(train_dataset),
        ))

        # ---- validation phase ----
        model.eval()
        correct = 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader):
                inputs, labels = inputs.to(device), labels.squeeze().to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
        print("Accuracy of the network on the %d test images: %.4f %%" % (
            len(val_dataset), 100. * correct / len(val_dataset)))

    # Save only the weights (state dict), not the full pickled module.
    torch.save(model.state_dict(), "/opt/output/model.pt")
def train(data_train, data_val, num_classes, num_epoch, milestones):
    """Train an AlexNet classifier with Adam + MultiStepLR and return the model.

    Args:
        data_train: DataLoader of (image, label) training batches.
        data_val: DataLoader of (image, label) validation batches.
        num_classes: number of output classes for the AlexNet head.
        num_epoch: total number of training epochs.
        milestones: epoch indices at which MultiStepLR decays the LR by 0.1.

    Returns:
        The model after the final epoch.

    NOTE(review): only the *epoch number* of the best validation accuracy is
    recorded — the best weights are never snapshotted, so the returned model
    is the last-epoch model, not the best one.  Confirm this is intended.
    """
    model = AlexNet(num_classes, pretrain=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    lr_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    since = time.time()
    best_acc = 0
    best = 0
    for epoch in range(num_epoch):
        print('Epoch {}/{}'.format(epoch + 1, num_epoch))
        print('-' * 10)

        # ---- training phase ----
        model.train()
        running_loss = 0.0
        running_corrects = 0
        seen = 0
        for i, (inputs, labels) in enumerate(data_train):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # BUG FIX: count raw correct predictions instead of summing
            # per-batch accuracies; the old average was biased whenever the
            # last batch was smaller than the others.  Also drop the
            # deprecated `.data` access on labels.
            running_corrects += torch.sum(preds == labels).item()
            seen += inputs.size(0)
            print("\rIteration: {}/{}, Loss: {}.".format(i + 1, len(data_train), loss.item()), end="")
            sys.stdout.flush()
        avg_loss = running_loss / len(data_train)
        t_acc = running_corrects / seen

        # ---- validation phase ----
        model.eval()
        running_loss = 0.0
        running_corrects = 0
        seen = 0
        with torch.no_grad():
            for inputs, labels in data_val:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels).item()
                seen += inputs.size(0)
        val_loss = running_loss / len(data_val)
        val_acc = running_corrects / seen

        print()
        print('Train Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, t_acc))
        print('Val Loss: {:.4f} Acc: {:.4f}'.format(val_loss, val_acc))
        print('lr rate: {:.6f}'.format(optimizer.param_groups[0]['lr']))
        print()

        if val_acc > best_acc:
            best_acc = val_acc
            best = epoch + 1
        # Decay the learning rate once per epoch, after the optimizer steps.
        lr_scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Validation Accuracy: {}, Epoch: {}'.format(best_acc, best))
    return model