import os

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchsummary import summary

# Local helpers assumed by this script (exact module paths may differ):
# from datasets import MnistDataset          # MNIST wrapper used below
# from transforms import RandomRotation      # seedable rotation transform
# from models import ModelM3, ModelM5, ModelM7

def run(p_seed=0, p_epochs=150, p_kernel_size=5, p_logdir="temp"):
    # random number generator seed ------------------------------------------------#
    SEED = p_seed
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)

    # kernel size of model --------------------------------------------------------#
    KERNEL_SIZE = p_kernel_size

    # number of epochs ------------------------------------------------------------#
    NUM_EPOCHS = p_epochs

    # file names ------------------------------------------------------------------#
    if not os.path.exists("../logs/%s" % p_logdir):
        os.makedirs("../logs/%s" % p_logdir)
    OUTPUT_FILE = "../logs/%s/log%03d.out" % (p_logdir, SEED)
    MODEL_FILE = "../logs/%s/model%03d.pth" % (p_logdir, SEED)

    # enable GPU usage ------------------------------------------------------------#
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if not use_cuda:
        print("WARNING: CUDA is not available; aborting instead of training on the CPU.")
        exit(0)

    # data augmentation methods ---------------------------------------------------#
    transform = transforms.Compose([
        RandomRotation(20, seed=SEED),
        transforms.RandomAffine(0, translate=(0.2, 0.2)),
    ])

    # data loader -----------------------------------------------------------------#
    train_dataset = MnistDataset(training=True, transform=transform)
    test_dataset = MnistDataset(training=False, transform=None)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=120, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False)

    # model selection -------------------------------------------------------------#
    if KERNEL_SIZE == 3:
        model = ModelM3().to(device)
    elif KERNEL_SIZE == 5:
        model = ModelM5().to(device)
    elif KERNEL_SIZE == 7:
        model = ModelM7().to(device)
    else:
        raise ValueError("unsupported kernel size: %d" % KERNEL_SIZE)
    summary(model, (1, 28, 28))

    # hyperparameter selection ----------------------------------------------------#
    ema = EMA(model, decay=0.999)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

    # truncate (clear) the result file --------------------------------------------#
    f = open(OUTPUT_FILE, 'w')
    f.close()

    # global variables ------------------------------------------------------------#
    g_step = 0
    max_correct = 0

    # training and evaluation loop ------------------------------------------------#
    for epoch in range(NUM_EPOCHS):
        #--------------------------------------------------------------------------#
        # train process                                                            #
        #--------------------------------------------------------------------------#
        model.train()
        train_loss = 0
        train_corr = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            train_pred = output.argmax(dim=1, keepdim=True)
            train_corr += train_pred.eq(target.view_as(train_pred)).sum().item()
            train_loss += F.nll_loss(output, target, reduction='sum').item()
            loss.backward()
            optimizer.step()
            g_step += 1
            ema(model, g_step)
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{:05d}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
        train_loss /= len(train_loader.dataset)
        train_accuracy = 100 * train_corr / len(train_loader.dataset)

        #--------------------------------------------------------------------------#
        # test process                                                             #
        #--------------------------------------------------------------------------#
        model.eval()
        ema.assign(model)          # evaluate with the averaged (EMA) weights
        test_loss = 0
        correct = 0
        total_pred = np.zeros(0)
        total_target = np.zeros(0)
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                total_pred = np.append(total_pred, pred.cpu().numpy())
                total_target = np.append(total_target, target.cpu().numpy())
                correct += pred.eq(target.view_as(pred)).sum().item()
        if max_correct < correct:
            torch.save(model.state_dict(), MODEL_FILE)
            max_correct = correct
            print("Best accuracy! correct images: %5d" % correct)
        ema.resume(model)          # restore the raw training weights

        #--------------------------------------------------------------------------#
        # output                                                                   #
        #--------------------------------------------------------------------------#
        test_loss /= len(test_loader.dataset)
        test_accuracy = 100 * correct / len(test_loader.dataset)
        best_test_accuracy = 100 * max_correct / len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%) (best: {:.2f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      test_accuracy, best_test_accuracy))

        f = open(OUTPUT_FILE, 'a')
        f.write(" %3d %12.6f %9.3f %12.6f %9.3f %9.3f\n" % (
            epoch, train_loss, train_accuracy,
            test_loss, test_accuracy, best_test_accuracy))
        f.close()

        #--------------------------------------------------------------------------#
        # update learning rate scheduler                                           #
        #--------------------------------------------------------------------------#
        lr_scheduler.step()
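
#------------------------------------------------------------------------------#
# NOTE: the EMA helper instantiated above is not defined in this file. The     #
# class below is a minimal sketch, assuming the conventional exponential-      #
# moving-average pattern its usage implies: __call__(model, step) updates a    #
# shadow copy of the weights after every optimizer step, assign() swaps the    #
# averaged weights into the model for evaluation, and resume() restores the    #
# raw training weights. The actual implementation in this repo may differ.     #
#------------------------------------------------------------------------------#
class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.backup = {}
        # one shadow tensor per trainable parameter
        self.shadow = {name: p.data.clone()
                       for name, p in model.named_parameters() if p.requires_grad}

    def __call__(self, model, step):
        # warm-up: the effective decay grows towards self.decay as training proceeds
        decay = min(self.decay, (1.0 + step) / (10.0 + step))
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = decay * self.shadow[name] + (1.0 - decay) * p.data

    def assign(self, model):
        # swap the averaged weights into the model (keep a backup of the raw ones);
        # swapping in place avoids keeping a second full copy of the model alive
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.backup[name] = p.data.clone()
                p.data.copy_(self.shadow[name])

    def resume(self, model):
        # restore the raw training weights saved by assign()
        for name, p in model.named_parameters():
            if p.requires_grad:
                p.data.copy_(self.backup[name])
        self.backup = {}

# Hypothetical entry point showing how run() would be invoked with its defaults.
if __name__ == "__main__":
    run(p_seed=0, p_epochs=150, p_kernel_size=5, p_logdir="temp")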

#------------------------------------------------------------------------------#
# fragment from a separate semi-supervised training script: the tail of its    #
# per-iteration loop body (the loop header and loss computation precede this   #
# excerpt and are not shown)                                                   #
#------------------------------------------------------------------------------#
    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], train_idx)

    # log losses to TensorBoard and print running averages every vis_idx steps
    if train_idx % args.vis_idx == args.vis_idx - 1:
        writer.add_scalar('training/total_loss', loss.item(), train_idx)
        writer.add_scalar('training/sup_loss', loss_sup.item(), train_idx)
        if args.mod == 'semisup':
            writer.add_scalar('training/unsup_loss', loss_unsup, train_idx)
            print('[%d] loss: %.3f loss_sup: %.3f loss_unsup: %.3f' % (
                train_idx, running_loss[0] / 100,
                running_loss[1] / 100, running_loss[2] / 100))
        else:
            print('[%d] loss: %.3f loss_sup: %.3f' % (
                train_idx, running_loss[0] / 100, running_loss[1] / 100))
        running_loss = [0.0, 0.0, 0.0]

    # evaluate with the EMA weights swapped in every eval_idx steps
    if train_idx % args.eval_idx == args.eval_idx - 1:
        ema.assign(net)
        curr_val = eval_model(net, validloader, writer, train_idx)
        ema.resume(net)
        # save the model and update the best validation score so far
        if curr_val > best_val:
            best_val = curr_val
            torch.save(net.state_dict(), args.model_path)

    # re-create exhausted dataset iterators so training can continue indefinitely
    if train_idx % len(trainloader_sup_iter) == len(trainloader_sup_iter) - 1:
        trainloader_sup_iter = iter(trainloader_sup)
        if args.mod == 'semisup':
            trainloader_unsup_iter = iter(trainloader_unsup)

print('Finished Training')
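
#------------------------------------------------------------------------------#
# NOTE: eval_model() is referenced above but not defined in this fragment. The #
# sketch below assumes it returns a validation score (accuracy here) that is   #
# compared against best_val, and that it logs the score to TensorBoard under a #
# hypothetical 'validation/accuracy' tag. The real function may use a          #
# different metric or tag.                                                     #
#------------------------------------------------------------------------------#
def eval_model(net, validloader, writer, train_idx):
    device = next(net.parameters()).device
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in validloader:
            data, target = data.to(device), target.to(device)
            pred = net(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    net.train()
    accuracy = 100.0 * correct / total
    writer.add_scalar('validation/accuracy', accuracy, train_idx)
    return accuracy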