              epoch + 1, p, loss.item()))

    scheduler.step()
    model.eval()
    val_loss, val_acc = test(model, val_loader, device)
    print("__________________________________________")
    print("\nValidation Acc: ", val_acc)
    print("Validation Loss: ", val_loss)
    print("\nTraining Loss: ", current_loss / count)
    print("Learning Rate: ", scheduler.get_last_lr()[0])
    print("__________________________________________")
    model.train()

torch.save(model.state_dict(), models[model_id] + "_state_dict.pt")
torch.save(model, models[model_id] + "_model.pt")

# After training: evaluate the final model on the train and test splits
model.eval()
test_loader = get_data_loader(train=False, batch_size=test_size,
                              split='test', model=models[model_id])
_, train_acc = test(model, train_loader, device)
print("Final Train Accuracy: ", train_acc)
_, test_acc = test(model, test_loader, device)
print("Final Test Accuracy: ", test_acc)
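# The `test` helper called above is defined elsewhere in the repository and is
# not part of this snippet. Below is a minimal sketch of a compatible
# implementation, assuming it returns (mean loss, accuracy) over a loader;
# the signature and metrics are inferred from the call sites, not taken from
# the actual source.
def test(model, loader, device):
    model.eval()
    loss_fn = nn.CrossEntropyLoss(reduction='sum')
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            total_loss += loss_fn(logits, y).item()  # summed here, averaged below
            correct += (logits.argmax(dim=1) == y).sum().item()
            count += y.size(0)
    return total_loss / count, correct / count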
def main():
    """
    This code implements ADMM-based training of a CNN.
    """
    #model = LeNet5()
    model = VGG(n_class=10)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Path to the pre-trained baseline model
    model_path = 'saved_model/pre_train_models/cifar10_vgg_acc_0.943'
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    #data_transforms = transforms.Compose([transforms.CenterCrop(32),
    #                                      transforms.ToTensor(),
    #                                      transforms.Normalize((0.1307,), (0.3081,))])

    ## Training data, with standard CIFAR-10 augmentation
    train_data = datasets.CIFAR10(
        'data/', train=True, download=False,
        transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ]))
    #train_data = datasets.MNIST(root='data/', download=False, train=True,
    #                            transform=data_transforms)
    """
    N_train = len(train_data)
    val_split = 0.1
    N_val = int(val_split * N_train)
    train_data, val_data = torch.utils.data.random_split(
        train_data, (N_train - N_val, N_val))
    """

    ## Test data (no random augmentation at evaluation time)
    test_data = datasets.CIFAR10(
        'data/', train=False, download=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ]))
    #test_data = datasets.MNIST(root='data/', download=False, train=False,
    #                           transform=data_transforms)

    batch_size = 128
    num_epochs = 50
    log_step = 100

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    #optimizer = torch.optim.SGD(model.parameters(), lr=5e-4,
    #                            momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[15, 30], gamma=0.1)

    ####### ADMM Training ##############
    ## Parameters
    fc_prune = False       # True if the fully connected layers are also pruned
    prune_type = 'filter'  # Type of structured pruning at the convolutional layers
    # prune_ratio (fraction of weights kept in the fc layers) is assumed to be
    # defined elsewhere; it is only used when fc_prune is True.

    # Number of non-zero filters kept at each convolutional layer
    l = {
        'conv1': 32,
        'conv2': 64,
        'conv3': 128,
        'conv4': 128,
        'conv5': 256,
        'conv6': 256,
        'conv7': 256,
        'conv8': 256
    }

    # ADMM parameters
    rho_val = 1.5e-3
    num_admm_steps = 10
    Z = {}
    U = {}
    rho = {}
    best_accuracy = 0
    all_acc = False

    ## Initialization of the auxiliary variable Z and the dual variable U
    for name, net in model.named_modules():
        if isinstance(net, nn.Conv2d):
            Z[name] = net.weight.clone().detach().requires_grad_(False)
            Z[name] = Projection_structured(Z[name], l[name], prune_type)
            U[name] = torch.zeros_like(net.weight, requires_grad=False)
            rho[name] = rho_val
        elif fc_prune and isinstance(net, nn.Linear):
            Z[name] = net.weight.clone().detach().requires_grad_(False)
            l_unst = int(len(net.weight.data.reshape(-1, )) * prune_ratio)
            Z[name], _ = Projection_unstructured(Z[name], l_unst)
            U[name] = torch.zeros_like(net.weight, requires_grad=False)
            rho[name] = rho_val
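    # The loop below alternates the three standard ADMM updates. This is a
    # summary of the code that follows; the exact form of the W-step objective
    # is an inference from rho, Z, and U being passed into train_model_admm:
    #   W-step: train W on  loss(W) + sum_l (rho_l / 2) * ||W_l - Z_l + U_l||_F^2
    #   Z-step: Z_l = Proj(W_l + U_l), projection onto the sparsity constraint
    #   U-step: U_l = U_l + W_l - Z_l  (scaled dual ascent)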
    ## ADMM loop
    for i in range(num_admm_steps):
        print('ADMM step number {}'.format(i))

        # First train the VGG model (W-step)
        train_model_admm(model, train_data, batch_size, loss_fn, optimizer,
                         scheduler, num_epochs, log_step, Z, U, rho, fc_prune,
                         device)

        # Update the auxiliary variable Z (Z-step)
        for name, net in model.named_modules():
            if isinstance(net, nn.Conv2d):
                Z[name] = Projection_structured(
                    net.weight.detach() + U[name], l[name], prune_type)
            elif fc_prune and isinstance(net, nn.Linear):
                l_unst = int(len(net.weight.data.reshape(-1, )) * prune_ratio)
                Z[name], _ = Projection_unstructured(
                    net.weight.detach() + U[name], l_unst)

        # Update the dual variable U (U-step)
        for name, net in model.named_modules():
            if isinstance(net, nn.Conv2d) or (fc_prune
                                              and isinstance(net, nn.Linear)):
                U[name] = U[name] + net.weight.detach() - Z[name]

        ## Check the test accuracy
        model.eval()
        test_accuracy = eval_accuracy_data(test_data, model, batch_size, device)
        print('Test accuracy is', test_accuracy)
        if test_accuracy > best_accuracy:
            print('Saving model with test accuracy {:.3f}'.format(test_accuracy))
            torch.save(
                model.state_dict(),
                'saved_model/admm_model/cifar10_vgg_acc_{:.3f}'.format(
                    test_accuracy))
            if all_acc:
                print('Removing model with test accuracy {:.3f}'.format(
                    best_accuracy))
                os.remove(
                    'saved_model/admm_model/cifar10_vgg_acc_{:.3f}'.format(
                        best_accuracy))
            best_accuracy = test_accuracy
            all_acc = True
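# Projection_structured and Projection_unstructured are imported from elsewhere
# in the repository and are not part of this snippet. The sketches below only
# illustrate what such Euclidean projections plausibly compute, based on how
# they are called above; the magnitude-ranking criterion and the exact return
# values are assumptions, not the repository's actual implementation.
def Projection_structured(W, l, prune_type='filter'):
    """Keep the l filters (output channels) of W with the largest L2 norm and
    zero out the rest. Only the 'filter' prune type is sketched."""
    assert prune_type == 'filter'
    norms = W.reshape(W.shape[0], -1).norm(dim=1)  # one L2 norm per filter
    keep = torch.topk(norms, l).indices
    out = torch.zeros_like(W)
    out[keep] = W[keep]
    return out


def Projection_unstructured(W, l_unst):
    """Keep the l_unst individual weights of W with the largest magnitude.
    Returns the projected tensor and the binary mask that was applied."""
    flat = W.reshape(-1)
    keep = torch.topk(flat.abs(), l_unst).indices
    mask = torch.zeros_like(flat)
    mask[keep] = 1.0
    return (flat * mask).reshape(W.shape), mask.reshape(W.shape)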