def train(model_name, model, trainloader, testloader, device, opt, nb_epochs, lr=0.001):
    """Train `model` for `nb_epochs` epochs and dump per-epoch history to JSON.

    Args:
        model_name: tag used in the output history filenames.
        model: the network to train (moved to `device` by the caller).
        trainloader: iterable of (inputs, labels) training batches.
        testloader: loader passed to the project-level `test()` helper each epoch.
        device: torch device the batches are moved to.
        opt: one of 'sgd', 'adam', 'lbfgs' (LBFGSNew variant).
        nb_epochs: number of training epochs.
        lr: learning rate for sgd/adam (ignored by the LBFGSNew branch).

    Returns:
        None. Side effects: trains `model` in place, prints per-epoch stats,
        and writes 'history_loss_mnist_<model_name>_<opt>.json' and
        'history_acc_mnist_<model_name>_<opt>.json'.

    Raises:
        NotImplementedError: if `opt` is not one of the three supported names.
    """
    history_loss = []
    history_acc = []
    criterion = nn.CrossEntropyLoss()
    print("Using optimizer: ", opt)
    # TODO adjust optimizer hyperparameters
    if opt == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    elif opt == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif opt == 'lbfgs':
        # Project-local LBFGS variant; `lr` is intentionally not forwarded here.
        optimizer = LBFGSNew(model.parameters(), history_size=7, max_iter=2,
                             line_search_fn=True, batch_mode=True)
    else:
        raise NotImplementedError

    for epoch in range(nb_epochs):
        # --- train for one epoch ---
        model.train()
        running_loss = 0.0
        num_batches = 0  # counted explicitly: avoids NameError on an empty loader
        for batch_idx, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            if opt == 'lbfgs':
                def closure():
                    # LBFGS may call this several times per step; only
                    # zero/backward when gradients are actually enabled.
                    if torch.is_grad_enabled():
                        optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if loss.requires_grad:
                        loss.backward()
                    return loss
                optimizer.step(closure)
                # Re-evaluate the post-step loss for logging only; no_grad
                # avoids building a throwaway autograd graph every batch.
                with torch.no_grad():
                    loss = criterion(model(inputs), labels)
            else:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # --- evaluate for this epoch ---
        epoch_loss = running_loss / max(num_batches, 1)
        epoch_acc = test(model, testloader, device)
        print("Epoch {} train loss: {}, test acc: {}".format(
            epoch + 1, epoch_loss, epoch_acc))
        history_loss.append(epoch_loss)
        history_acc.append(epoch_acc)

    print('Finished Training')
    with open('history_loss_mnist' + '_' + model_name + '_' + opt + '.json', 'w') as f:
        json.dump(history_loss, f)
    with open('history_acc_mnist' + '_' + model_name + '_' + opt + '.json', 'w') as f:
        json.dump(history_acc, f)
net.train() # initialize for training (BN,dropout) start_time=time.time() use_lbfgs=True # train network for epoch in range(20): running_loss=0.0 for i,data in enumerate(trainloader,0): # get the inputs inputs,labels=data # wrap them in variable inputs,labels=Variable(inputs).to(mydevice),Variable(labels).to(mydevice) if not use_lbfgs: # zero gradients optimizer.zero_grad() # forward+backward optimize outputs=net(inputs) loss=criterion(outputs,labels) loss.backward() optimizer.step() else: if not wide_resnet: layer1=torch.cat([x.view(-1) for x in net.layer1.parameters()]) layer2=torch.cat([x.view(-1) for x in net.layer2.parameters()]) layer3=torch.cat([x.view(-1) for x in net.layer3.parameters()]) layer4=torch.cat([x.view(-1) for x in net.layer4.parameters()]) def closure(): if torch.is_grad_enabled(): optimizer.zero_grad()