def get_model(arch_type):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # If you want to add an extra model, add a branch here
    else:
        print("\nWrong Model choice\n")
        exit()
    return model
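# A minimal alternative sketch (not part of the original script): the if/elif
# chain in get_model() can be replaced by a dictionary registry, which makes
# adding an architecture a one-line change. The constructors are the same ones
# get_model() uses and are assumed importable here; `sys` is an extra import
# introduced for this sketch.
import sys

def get_model_from_registry(arch_type):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    registry = {
        "fc1": fc1.fc1,
        "lenet5": LeNet5.LeNet5,
        "alexnet": AlexNet.AlexNet,
        "vgg16": vgg.vgg16,
        "resnet18": resnet.resnet18,
        "densenet121": densenet.densenet121,
    }
    if arch_type not in registry:
        sys.exit(f"\nWrong Model choice: {arch_type}\n")
    return registry[arch_type]().to(device)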
def main(args, ITE=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reinit = True if args.prune_type == "reinit" else False

    if args.save_dir:
        utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/")
        utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/")
        utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/")
    else:
        utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
        utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
        utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")

    # Data Loader
    # NOTE: these are the MNIST mean/std; they are reused unchanged for
    # FashionMNIST and CIFAR-100 below.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])
    if args.dataset == "mnist":
        traindataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
        testdataset = datasets.MNIST('../data', train=False, transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
    elif args.dataset == "cifar10":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        traindataset = datasets.CIFAR10('../data', train=True, download=True, transform=transform_train)
        testdataset = datasets.CIFAR10('../data', train=False, transform=transform_test)
        from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet
    elif args.dataset == "fashionmnist":
        traindataset = datasets.FashionMNIST('../data', train=True, download=True, transform=transform)
        testdataset = datasets.FashionMNIST('../data', train=False, transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
    elif args.dataset == "cifar100":
        traindataset = datasets.CIFAR100('../data', train=True, download=True, transform=transform)
        testdataset = datasets.CIFAR100('../data', train=False, transform=transform)
        from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet
    # If you want to add an extra dataset, add a branch here
    else:
        print("\nWrong Dataset choice\n")
        exit()

    if args.dataset == "cifar10":
        # trainsampler = torch.utils.data.RandomSampler(traindataset, replacement=True, num_samples=45000)  # 45K train dataset
        # train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False, sampler=trainsampler)
        train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    else:
        train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False)
    # train_loader = cycle(train_loader)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Importing Network Architecture
    # Initialize the Hessian dataloader: a single batch (default batch_num 1)
    # consumed by the Hessian-aware pruning criterion.
    for inputs, labels in train_loader:
        hessian_dataloader = (inputs, labels)
        break

    global model
    if args.arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif args.arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif args.arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif args.arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif args.arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif args.arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # If you want to add an extra model, add a branch here
    else:
        print("\nWrong Model choice\n")
        exit()
    model = nn.DataParallel(model)

    # Weight Initialization
    model.apply(weight_init)

    # Copying and Saving Initial State
    initial_state_dict = copy.deepcopy(model.state_dict())
    if args.save_dir:
        torch.save(
            model.state_dict(),
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/initial_state_dict_{args.prune_type}.pth")
    else:
        torch.save(
            model.state_dict(),
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth")

    # global total_params
    total_params = 0
    # Layer Looper
    for name, param in model.named_parameters():
        print(name, param.size())
        total_params += param.numel()

    # Making Initial Mask
    make_mask(model, total_params)

    # Optimizer and Loss
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
    # Warm-up schedule: scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[0, 15], gamma=0.1, last_epoch=-1)
    if args.warmup:
        scheduler_warmup = GradualWarmupScheduler(
            optimizer, multiplier=1, total_epoch=50,
            after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70
    criterion = nn.CrossEntropyLoss()  # Default was F.nll_loss; why are train and test losses different?

    # Pruning
    # NOTE: the first pruning iteration applies no compression
    bestacc = 0.0
    best_accuracy = 0
    ITERATION = args.prune_iterations
    comp = np.zeros(ITERATION, float)
    bestacc = np.zeros(ITERATION, float)
    step = 0
    all_loss = np.zeros(args.end_iter, float)
    all_accuracy = np.zeros(args.end_iter, float)

    for _ite in range(args.start_iter, ITERATION):
        if not _ite == 0:
            # NOTE: `resample` is used here but defined elsewhere; it is
            # expected as a module-level global set by the caller.
            prune_by_percentile(args.prune_percent,
                                resample=resample,
                                reinit=reinit,
                                total_params=total_params,
                                hessian_aware=args.hessian,
                                criterion=criterion,
                                dataloader=hessian_dataloader,
                                cuda=torch.cuda.is_available())
            if reinit:
                model.apply(weight_init)
                # if args.arch_type == "fc1":
                #     model = fc1.fc1().to(device)
                # elif args.arch_type == "lenet5":
                #     model = LeNet5.LeNet5().to(device)
                # elif args.arch_type == "alexnet":
                #     model = AlexNet.AlexNet().to(device)
                # elif args.arch_type == "vgg16":
                #     model = vgg.vgg16().to(device)
                # elif args.arch_type == "resnet18":
                #     model = resnet.resnet18().to(device)
                # elif args.arch_type == "densenet121":
                #     model = densenet.densenet121().to(device)
                # else:
                #     print("\nWrong Model choice\n")
                #     exit()
                step = 0
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        param_frac = param.numel() / total_params
                        if param_frac > 0.01:
                            weight_dev = param.device
                            param.data = torch.from_numpy(
                                param.data.cpu().numpy() * mask[step]).to(weight_dev)
                            step = step + 1
                step = 0
            else:
                original_initialization(mask, initial_state_dict, total_params)
        # optimizer = torch.optim.SGD([{'params': model.parameters(), 'initial_lr': 0.03}], lr=args.lr, momentum=0.9, weight_decay=1e-4)
        # scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[0, 14], gamma=0.1, last_epoch=-1)
        # scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=56, after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70
        print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")

        # Optimizer and Loss (re-created for every pruning level)
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
        # Warm-up schedule: scheduler_warmup is chained with scheduler_steplr
        scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[0, 15], gamma=0.1, last_epoch=-1)
        if args.warmup:
            scheduler_warmup = GradualWarmupScheduler(
                optimizer, multiplier=1, total_epoch=50,
                after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70

        # Print the table of Nonzeros in each layer
        comp1 = utils.print_nonzeros(model)
        comp[_ite] = comp1

        pbar = tqdm(range(args.end_iter))  # progress bar
        for iter_ in pbar:
            # Frequency for Testing
            if iter_ % args.valid_freq == 0:
                accuracy = test(model, test_loader, criterion)
                # Save Weights for each _ite
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    if args.save_dir:
                        torch.save(
                            model.state_dict(),
                            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/{_ite}_model_{args.prune_type}.pth")
                    else:
                        # torch.save(model, f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth")
                        torch.save(
                            model.state_dict(),
                            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth")

            # Training
            loss = train(model, train_loader, optimizer, criterion, total_params)
            all_loss[iter_] = loss
            all_accuracy[iter_] = accuracy

            # Warm up. The LR is read back unconditionally so the progress bar
            # can report it even when warmup is disabled.
            if args.warmup:
                scheduler_warmup.step()
            _lr = optimizer.param_groups[0]['lr']

            # Save the model during training
            if args.save_freq > 0 and iter_ % args.save_freq == 0:
                if args.save_dir:
                    torch.save(
                        model.state_dict(),
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/{_ite}_model_{args.prune_type}_epoch{iter_}.pth")
                else:
                    torch.save(
                        model.state_dict(),
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}_epoch{iter_}.pth")

            # Frequency for Printing Accuracy and Loss
            if iter_ % args.print_freq == 0:
                pbar.set_description(
                    f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}% Learning Rate: {_lr:.6f}')

        writer.add_scalar('Accuracy/test', best_accuracy, comp1)
        bestacc[_ite] = best_accuracy

        # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
        # NOTE: Loss is computed for every iteration while accuracy is computed
        # only every {args.valid_freq} iterations, so the saved accuracy stays
        # constant during the uncomputed iterations.
        # NOTE: the loss is normalized to [0, 100] for ease of plotting.
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_loss - np.min(all_loss)) / np.ptp(all_loss).astype(float),
                 c="blue", label="Loss")
        plt.plot(np.arange(1, (args.end_iter) + 1), all_accuracy, c="red", label="Accuracy")
        plt.title(f"Loss Vs Accuracy Vs Iterations ({args.dataset},{args.arch_type})")
        plt.xlabel("Iterations")
        plt.ylabel("Loss and Accuracy")
        plt.legend()
        plt.grid(color="gray")
        if args.save_dir:
            plt.savefig(
                f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
                dpi=1200)
        else:
            plt.savefig(
                f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
                dpi=1200)
        plt.close()

        # Dump Plot values
        if args.save_dir:
            all_loss.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_all_loss_{comp1}.dat")
            all_accuracy.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_all_accuracy_{comp1}.dat")
        else:
            all_loss.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat")
            all_accuracy.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat")

        # Dumping mask
        if args.save_dir:
            with open(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_mask_{comp1}.pkl", 'wb') as fp:
                pickle.dump(mask, fp)
        else:
            with open(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl", 'wb') as fp:
                pickle.dump(mask, fp)

        # Making variables into 0
        best_accuracy = 0
        all_loss = np.zeros(args.end_iter, float)
        all_accuracy = np.zeros(args.end_iter, float)

    # Dumping Values for Plotting
    if args.save_dir:
        comp.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_compression.dat")
        bestacc.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_bestaccuracy.dat")
    else:
        comp.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat")
        bestacc.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat")

    # Plotting
    a = np.arange(args.prune_iterations)
    plt.plot(a, bestacc, c="blue", label="Winning tickets")
    plt.title(f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{args.arch_type})")
    plt.xlabel("Unpruned Weights Percentage")
    plt.ylabel("test accuracy")
    plt.xticks(a, comp, rotation="vertical")
    plt.ylim(0, 100)
    plt.legend()
    plt.grid(color="gray")
    if args.save_dir:
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_AccuracyVsWeights.png",
            dpi=1200)
    else:
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png",
            dpi=1200)
    plt.close()
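# A hedged sketch of the argument namespace main() above expects. The field
# names are read off the `args.*` accesses in the function; the default values
# and help comments are illustrative assumptions, not the project's actual
# defaults. Note that `resample` (passed to prune_by_percentile) is expected
# as a module-level global set elsewhere.
import argparse

def build_example_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--start_iter", type=int, default=0)
    parser.add_argument("--end_iter", type=int, default=100)         # epochs per pruning level
    parser.add_argument("--prune_iterations", type=int, default=35)  # number of pruning levels
    parser.add_argument("--prune_percent", type=float, default=10)   # percent pruned per level
    parser.add_argument("--prune_type", type=str, default="lt", choices=["lt", "reinit"])
    parser.add_argument("--arch_type", type=str, default="resnet18")
    parser.add_argument("--dataset", type=str, default="cifar10")
    parser.add_argument("--valid_freq", type=int, default=1)   # test every N epochs
    parser.add_argument("--print_freq", type=int, default=1)
    parser.add_argument("--save_freq", type=int, default=0)    # 0 disables periodic checkpoints
    parser.add_argument("--save_dir", type=str, default="")
    parser.add_argument("--warmup", action="store_true")       # enable GradualWarmupScheduler
    parser.add_argument("--hessian", action="store_true")      # Hessian-aware pruning criterion
    return parser.parse_args([])  # returns the defaults; pass real argv in practice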
def main(args, ITE=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reinit = True if args.prune_type == "reinit" else False

    # Data Loader
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])
    if args.dataset == "mnist":
        traindataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
        testdataset = datasets.MNIST('../data', train=False, transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
    elif args.dataset == "cifar10":
        traindataset = datasets.CIFAR10('../data', train=True, download=True, transform=transform)
        testdataset = datasets.CIFAR10('../data', train=False, transform=transform)
        from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet
    elif args.dataset == "fashionmnist":
        traindataset = datasets.FashionMNIST('../data', train=True, download=True, transform=transform)
        testdataset = datasets.FashionMNIST('../data', train=False, transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
    elif args.dataset == "cifar100":
        traindataset = datasets.CIFAR100('../data', train=True, download=True, transform=transform)
        testdataset = datasets.CIFAR100('../data', train=False, transform=transform)
        from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet
    # If you want to add an extra dataset, add a branch here
    else:
        print("\nWrong Dataset choice\n")
        exit()

    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False)
    # train_loader = cycle(train_loader)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)

    # Importing Network Architecture
    global model
    if args.arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif args.arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif args.arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif args.arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif args.arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif args.arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # If you want to add an extra model, add a branch here
    else:
        print("\nWrong Model choice\n")
        exit()

    # Weight Initialization
    model.apply(weight_init)

    # Copying and Saving Initial State
    initial_state_dict = copy.deepcopy(model.state_dict())
    utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")
    torch.save(
        model,
        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth.tar")

    # Making Initial Mask
    make_mask(model)

    # Optimizer and Loss
    # NOTE: no lr is passed here, so Adam's default (1e-3) applies; the
    # optimizer is re-created with lr=args.lr for every pruning level below.
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()  # Default was F.nll_loss

    # Layer Looper
    for name, param in model.named_parameters():
        print(name, param.size())

    # Pruning
    # NOTE: the first pruning iteration applies no compression
    bestacc = 0.0
    best_accuracy = 0
    ITERATION = args.prune_iterations
    comp = np.zeros(ITERATION, float)
    bestacc = np.zeros(ITERATION, float)
    step = 0
    all_loss = np.zeros(args.end_iter, float)
    all_accuracy = np.zeros(args.end_iter, float)

    for _ite in range(args.start_iter, ITERATION):
        if not _ite == 0:
            prune_by_percentile(args.prune_percent, resample=resample, reinit=reinit)
            if reinit:
                model.apply(weight_init)
                # if args.arch_type == "fc1":
                #     model = fc1.fc1().to(device)
                # elif args.arch_type == "lenet5":
                #     model = LeNet5.LeNet5().to(device)
                # elif args.arch_type == "alexnet":
                #     model = AlexNet.AlexNet().to(device)
                # elif args.arch_type == "vgg16":
                #     model = vgg.vgg16().to(device)
                # elif args.arch_type == "resnet18":
                #     model = resnet.resnet18().to(device)
                # elif args.arch_type == "densenet121":
                #     model = densenet.densenet121().to(device)
                # else:
                #     print("\nWrong Model choice\n")
                #     exit()
                step = 0
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        weight_dev = param.device
                        param.data = torch.from_numpy(
                            param.data.cpu().numpy() * mask[step]).to(weight_dev)
                        step = step + 1
                step = 0
            else:
                original_initialization(mask, initial_state_dict)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
        print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")

        # Print the table of Nonzeros in each layer
        comp1 = utils.print_nonzeros(model)
        comp[_ite] = comp1

        pbar = tqdm(range(args.end_iter))
        for iter_ in pbar:
            # Frequency for Testing
            if iter_ % args.valid_freq == 0:
                accuracy = test(model, test_loader, criterion)
                # Save Weights
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")
                    torch.save(
                        model,
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth.tar")

            # Training
            if _ite == 0:
                loss = train(model, train_loader, optimizer, criterion)
                # Needs to be completed (see the hedged distill_step sketch
                # after this function):
                # teacher_model = ...
            else:
                loss = train_with_distill(model, train_loader, optimizer, teacher_model)
            all_loss[iter_] = loss
            all_accuracy[iter_] = accuracy

            # Frequency for Printing Accuracy and Loss
            if iter_ % args.print_freq == 0:
                pbar.set_description(
                    f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}%')

        writer.add_scalar('Accuracy/test', best_accuracy, comp1)
        bestacc[_ite] = best_accuracy

        # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
        # NOTE: Loss is computed for every iteration while accuracy is computed
        # only every {args.valid_freq} iterations, so the saved accuracy stays
        # constant during the uncomputed iterations.
        # NOTE: the loss is normalized to [0, 100] for ease of plotting.
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_loss - np.min(all_loss)) / np.ptp(all_loss).astype(float),
                 c="blue", label="Loss")
        plt.plot(np.arange(1, (args.end_iter) + 1), all_accuracy, c="red", label="Accuracy")
        plt.title(f"Loss Vs Accuracy Vs Iterations ({args.dataset},{args.arch_type})")
        plt.xlabel("Iterations")
        plt.ylabel("Loss and Accuracy")
        plt.legend()
        plt.grid(color="gray")
        utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
            dpi=1200)
        plt.close()

        # Dump Plot values
        utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
        all_loss.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat")
        all_accuracy.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat")

        # Dumping mask
        utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
        with open(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl", 'wb') as fp:
            pickle.dump(mask, fp)

        # Making variables into 0
        best_accuracy = 0
        all_loss = np.zeros(args.end_iter, float)
        all_accuracy = np.zeros(args.end_iter, float)

    # Dumping Values for Plotting
    utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
    comp.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat")
    bestacc.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat")

    # Plotting
    a = np.arange(args.prune_iterations)
    plt.plot(a, bestacc, c="blue", label="Winning tickets")
    plt.title(f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{args.arch_type})")
    plt.xlabel("Unpruned Weights Percentage")
    plt.ylabel("test accuracy")
    plt.xticks(a, comp, rotation="vertical")
    plt.ylim(0, 100)
    plt.legend()
    plt.grid(color="gray")
    utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
    plt.savefig(
        f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png",
        dpi=1200)
    plt.close()
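# The distillation branch in main() above calls train_with_distill() with a
# teacher_model that the original leaves unfinished ("teacher_model = ...").
# Below is a minimal sketch of a single soft-target distillation step in the
# usual Hinton-style (KL divergence between temperature-softened logits plus
# a hard-label term). Every name here (distill_step, T, alpha) is hypothetical
# and is NOT the project's actual train_with_distill() implementation.
import torch
import torch.nn.functional as F

def distill_step(student, teacher, inputs, targets, optimizer, T=4.0, alpha=0.9):
    """One optimization step mixing soft teacher targets with hard labels."""
    teacher.eval()
    with torch.no_grad():
        teacher_logits = teacher(inputs)
    student_logits = student(inputs)
    # KL divergence between temperature-softened distributions,
    # scaled by T^2 as in Hinton et al. (2015).
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    hard_loss = F.cross_entropy(student_logits, targets)
    loss = alpha * soft_loss + (1 - alpha) * hard_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()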
def main(args, ITE=0):
    import pandas as pd
    pd.set_option('display.width', 400)
    pd.set_option('display.max_columns', 10)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reinit = True if args.prune_type == "reinit" else False
    layerwise = True if args.prune_type == "layerwise" else False

    # Data Loader
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])
    transform_cifar10 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    if args.dataset == "mnist":
        traindataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
        testdataset = datasets.MNIST('../data', train=False, transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
    elif args.dataset == "cifar10":
        traindataset = datasets.CIFAR10('../data', train=True, download=True, transform=transform_cifar10)
        testdataset = datasets.CIFAR10('../data', train=False, transform=transform_cifar10)
        from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet, minivgg
    elif args.dataset == "fashionmnist":
        traindataset = datasets.FashionMNIST('../data', train=True, download=True, transform=transform)
        testdataset = datasets.FashionMNIST('../data', train=False, transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
    elif args.dataset == "cifar100":
        traindataset = datasets.CIFAR100('../data', train=True, download=True, transform=transform)
        testdataset = datasets.CIFAR100('../data', train=False, transform=transform)
        from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet
    # If you want to add an extra dataset, add a branch here
    else:
        print("\nWrong Dataset choice\n")
        exit()

    # Obtain the training indices that will be used for validation
    if args.early_stopping:
        print(' Splitting Validation sets ')
        trainset_size = int((1 - args.valid_size) * len(traindataset))
        valset_size = len(traindataset) - trainset_size
        trainset, valset = torch.utils.data.random_split(traindataset, [trainset_size, valset_size])
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False)
        valid_loader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False)
    else:
        print(' Eww, no validation set? ')
        # NOTE: validate() below still expects valid_loader, so this branch
        # effectively assumes args.early_stopping is set.
        train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False)
    # train_loader = cycle(train_loader)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)

    # Importing Network Architecture
    global model
    if args.arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif args.arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif args.arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif args.arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif args.arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif args.arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # If you want to add an extra model, add a branch here
    elif args.arch_type == "conv2":
        model = minivgg.conv2().to(device)
    elif args.arch_type == "conv4":
        model = minivgg.conv4().to(device)
    elif args.arch_type == "conv6":
        model = minivgg.conv6().to(device)
    else:
        print("\nWrong Model choice\n")
        exit()

    # Weight Initialization. Warning: this drops test accuracy, so I'm still
    # examining this function.
    model.apply(weight_init)

    # get time for file path
    import datetime
    now = datetime.datetime.now()
    now_ = now.strftime("%02m%02d%02H%02M_")

    # Copying and Saving Initial State
    print(' saving initial model... ')
    initial_state_dict = copy.deepcopy(model.state_dict())
    utils.checkdir(f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/")
    torch.save(
        model,
        f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth.tar")
    print(
        " initial model saved in ",
        f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth.tar")

    # Making Initial Mask
    make_mask(model)

    # Optimizer and Loss
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()  # Default was F.nll_loss

    # Layer Looper
    for name, param in model.named_parameters():
        print(name, param.size())

    # Pruning
    # NOTE First Pruning Iteration is of No Compression
    bestacc = 0.0
    best_accuracy = 0
    ITERATION = args.prune_iterations
    comp = np.zeros(ITERATION, float)
    bestacc = np.zeros(ITERATION, float)
    step = 0
    all_loss = np.zeros(args.end_iter, float)
    all_vloss = np.zeros(args.end_iter, float)
    all_accuracy = np.zeros(args.end_iter, float)

    for _ite in range(args.start_iter, ITERATION):
        # Early stopping parameter for each pruning iteration
        early_stopping = EarlyStopping(patience=99, verbose=True)  # we don't stop, party all night
        if not _ite == 0:
            prune_by_percentile(args.prune_percent,
                                args.fc_prune_percent,
                                resample=resample,
                                reinit=reinit,
                                layerwise=layerwise,
                                if_split=args.split_conv_and_fc)
            if reinit:
                model.apply(weight_init)
                # if args.arch_type == "fc1":
                #     model = fc1.fc1().to(device)
                # elif args.arch_type == "lenet5":
                #     model = LeNet5.LeNet5().to(device)
                # elif args.arch_type == "alexnet":
                #     model = AlexNet.AlexNet().to(device)
                # elif args.arch_type == "vgg16":
                #     model = vgg.vgg16().to(device)
                # elif args.arch_type == "resnet18":
                #     model = resnet.resnet18().to(device)
                # elif args.arch_type == "densenet121":
                #     model = densenet.densenet121().to(device)
                # else:
                #     print("\nWrong Model choice\n")
                #     exit()
                step = 0
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        weight_dev = param.device
                        param.data = torch.from_numpy(
                            param.data.cpu().numpy() * mask[step]).to(weight_dev)
                        step = step + 1
                step = 0
            else:
                original_initialization(mask, initial_state_dict)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
        time.sleep(0.25)
        print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")

        # Print the table of Nonzeros in each layer
        comp1 = utils.print_nonzeros(model)
        comp[_ite] = comp1

        # pbar = range(args.end_iter)
        pbar = tqdm(range(args.end_iter))
        stop_flag = False
        for iter_ in pbar:
            # Frequency for Testing
            if iter_ % args.valid_freq == 0:
                accuracy = test(model, test_loader, criterion)
                # Save Weights
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    # We don't save model per test-acc, will use validation-acc!
                    # utils.checkdir(f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/")
                    # torch.save(model, f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth.tar")

            # Training
            loss = train(model, train_loader, optimizer, criterion)
            all_loss[iter_] = loss
            all_accuracy[iter_] = accuracy

            # Validating
            valid_loss, loss_v = validate(model, valid_loader, optimizer, criterion)
            all_vloss[iter_] = valid_loss  # loss_v

            # Early stopping
            checkpoint_path = f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/"
            save_path = f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth.tar"
            # msg = early_stopping(valid_loss, model, checkpoint_path, save_path)
            early_stopping(valid_loss, model, checkpoint_path, save_path)

            # Frequency for Printing Accuracy and Loss
            if iter_ % args.print_freq == 0:
                time.sleep(0.25)
                pbar.set_description(
                    # f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}% \t' + msg)
                    f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} V-Loss: {valid_loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}%')
            if iter_ % 5 == 4:
                print('')

            if early_stopping.early_stop and not stop_flag:
                print("Early stopping")
                stop_flag = True
                # break

        writer.add_scalar('Accuracy/test', best_accuracy, comp1)
        bestacc[_ite] = best_accuracy

        # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
        # NOTE: Loss is computed for every iteration while accuracy is computed
        # only every {args.valid_freq} iterations, so the saved accuracy stays
        # constant during the uncomputed iterations.
        # NOTE: the losses are normalized to [0, 100] for ease of plotting.
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_loss - np.min(all_loss)) / np.ptp(all_loss).astype(float),
                 c="blue", label="Train loss")
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_vloss - np.min(all_vloss)) / np.ptp(all_vloss).astype(float),
                 c="green", label="Valid loss")
        plt.plot(np.arange(1, (args.end_iter) + 1), all_accuracy, c="red", label="Accuracy")
        plt.title(f"Loss Vs Accuracy Vs Iterations ({args.dataset},{now_}{args.arch_type})")
        plt.xlabel("Iterations")
        plt.ylabel("Loss and Accuracy")
        plt.legend()
        plt.grid(color="gray")
        utils.checkdir(f"{os.getcwd()}/plots/lt/{now_}{args.arch_type}/{args.dataset}/")
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
            dpi=300)
        plt.close()

        # Dump Plot values
        utils.checkdir(f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/")
        all_loss.dump(f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat")
        all_accuracy.dump(f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat")

        # Dumping mask
        utils.checkdir(f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/")
        with open(f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl", 'wb') as fp:
            pickle.dump(mask, fp)

        # Making variables into 0
        best_accuracy = 0
        all_loss = np.zeros(args.end_iter, float)
        all_accuracy = np.zeros(args.end_iter, float)

    # Dumping Values for Plotting
    utils.checkdir(f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/")
    comp.dump(f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat")
    bestacc.dump(f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat")

    # Plotting
    a = np.arange(args.prune_iterations)
    plt.plot(a, bestacc, c="blue", label="Winning tickets")
    plt.title(f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{now_}{args.arch_type})")
    plt.xlabel("Unpruned Weights Percentage")
    plt.ylabel("test accuracy")
    plt.xticks(a, comp, rotation="vertical")
    plt.ylim(0, 100)
    plt.legend()
    plt.grid(color="gray")
    utils.checkdir(f"{os.getcwd()}/plots/lt/{now_}{args.arch_type}/{args.dataset}/")
    plt.savefig(
        f"{os.getcwd()}/plots/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png",
        dpi=300)
    plt.close()

    print('Training ended~~~')
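# main() above relies on an EarlyStopping helper imported elsewhere and called
# as early_stopping(valid_loss, model, checkpoint_path, save_path). Below is a
# minimal sketch of such a class, assuming the common "checkpoint on
# improvement, stop after `patience` non-improving epochs" behavior; the
# class name and internals here are assumptions, and the repo's real
# implementation may differ.
import os
import numpy as np
import torch

class EarlyStoppingSketch:
    """Stops training after `patience` epochs without validation improvement."""

    def __init__(self, patience=7, verbose=False, delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, valid_loss, model, checkpoint_path, save_path):
        if valid_loss < self.best_loss - self.delta:
            # Improvement: checkpoint the model and reset the counter.
            os.makedirs(checkpoint_path, exist_ok=True)
            torch.save(model, save_path)
            if self.verbose:
                print(f"Validation loss improved "
                      f"({self.best_loss:.6f} -> {valid_loss:.6f}); saved {save_path}")
            self.best_loss = valid_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True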