def main():
    manager = Manager.init()
    models = [["model", MobileNetV2(**manager.args.model)]]
    manager.init_model(models)
    args = manager.args

    criterion = Criterion()
    optimizer, scheduler = Optimizer(models, args.optim).init()

    # Only use CUDA when it is both requested and available.
    args.cuda = args.cuda and torch.cuda.is_available()
    if args.cuda:
        for item in models:
            item[1].cuda()
        criterion.cuda()

    dataloader = DataLoader(args.dataloader, args.cuda)
    summary = manager.init_summary()
    trainer = Trainer(models, criterion, optimizer, scheduler, dataloader, summary, args.cuda)

    for epoch in range(args.runtime.start_epoch, args.runtime.num_epochs + args.runtime.start_epoch):
        try:
            print("epoch {}...".format(epoch))
            trainer.train(epoch)
            manager.save_checkpoint(models, epoch)
            if (epoch + 1) % args.runtime.test_every == 0:
                trainer.validate()
        except KeyboardInterrupt:
            print("Training has been interrupted\n")
            break
    trainer.test()
def main(args):
    # training_set, test_set = compile_dataset_files(args.data)
    w300_dataset = W300(args.data, train=True)
    train_dataset = w300_dataset.create_dataset(args)

    # Tensorboard
    if os.path.exists(tb_train_dir):
        shutil.rmtree(tb_train_dir)
    logdir = os.path.join(tb_train_dir, datetime.now().strftime("%Y%m%d-%H%M%S"))
    file_writer = tf.summary.create_file_writer(logdir)

    model = MobileNetV2(input_shape=(256, 256, 3), k=w300_dataset.expression_number())
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.MeanSquaredError()

    epoch_num = 0
    with file_writer.as_default():
        for step, data in enumerate(train_dataset):
            img, exp_params, vertex = data
            epoch_step = step % len(w300_dataset)
            if epoch_step == 0:
                epoch_num += 1
                print("starting epoch {}".format(epoch_num))

            with tf.GradientTape() as tape:
                prediction = model(img, training=True)
                # Keras losses expect (y_true, y_pred).
                loss = loss_fn(exp_params, prediction)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            if step % args.vis_freq == 0:
                print("epoch {} step {}".format(epoch_num, epoch_step))
                tf.summary.image("Training Image", img, max_outputs=1, step=step)
                summary = mesh_summary.mesh('GT Expression Mesh', vertices=vertex[0:], faces=w300_dataset.faces, step=step)
                file_writer.flush()
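# The eager loop above re-traces the Python-side work every batch. A common speed-up is to compile
# the per-batch update with tf.function. This is a minimal sketch only: `train_step` is an
# illustrative helper name, not part of the original script, and it assumes the same
# `import tensorflow as tf`, `model`, `optimizer`, and `loss_fn` objects used in main().
@tf.function
def train_step(model, optimizer, loss_fn, img, exp_params):
    # One gradient update on a single batch.
    with tf.GradientTape() as tape:
        prediction = model(img, training=True)
        loss = loss_fn(exp_params, prediction)  # Keras losses take (y_true, y_pred)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss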
                      prune_fname=vgg_60_fname, descent_idx=14)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_60.append(params)
        latency_60.append(latency)
elif args.model == "mobilenetv2":  # CIFAR100 only
    filters = [[32], [16], [24, 24], [32, 32, 32], [64, 64, 64, 64], [96, 96, 96], [160, 160, 160], [320], [1280]]
    ratios = np.arange(0.25, 2.1, 0.25)  # [0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]
    for ratio in ratios:
        # uniform
        new_config = MobileNetV2.prepare_filters(MobileNetV2, filters, ratio=ratio, neuralscale=False, num_classes=num_classes)
        model = MobileNetV2(filters=new_config, num_classes=num_classes, dataset=args.dataset)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_uni.append(params)
        latency_uni.append(latency)

    ## efficient
    mobilenetv2_0_fname = "mobilenetv2_0_eff_c100"
    mobilenetv2_2_fname = "mobilenetv2_2_eff_c100"
    mobilenetv2_5_fname = "mobilenetv2_5_eff_c100"
    mobilenetv2_10_fname = "mobilenetv2_10_eff_c100"
    mobilenetv2_30_fname = "mobilenetv2_30_eff_c100"
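# `compute_latency` and `compute_params_` are used throughout these comparison scripts but are
# defined elsewhere in the repository. As a rough illustration only (not the repository's
# implementation), latency is typically estimated by timing repeated forward passes on a fixed
# input; the helper name, input size, and repetition count below are assumptions.
import time

import torch


def estimate_latency(model, input_size=(1, 3, 32, 32), repeats=100):
    # Average wall-clock time of one forward pass in eval mode, in seconds (CPU).
    model.eval()
    dummy = torch.randn(*input_size)
    with torch.no_grad():
        model(dummy)  # warm-up
        start = time.time()
        for _ in range(repeats):
            model(dummy)
    return (time.time() - start) / repeats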
elif args.model == "mobilenetv2": original_filters = [[32],[16],[24,24],[32,32,32],[64,64,64,64],[96,96,96],[160,160,160],[320],[1280]] for idx, ratio in enumerate(ratios): if plot_growth: total_params = [] print("Ratio: {}".format(ratio)) for iteration in range(15): if args.model == "vgg": filters = VGG.prepare_filters(VGG, original_filters, ratio=ratio, neuralscale=True, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=iteration) filters = [cfg for cfg in list(filter(lambda a: a != 'M', filters))] elif args.model == "resnet18": filters = PreActResNet.prepare_filters(PreActResNet, original_filters, ratio=ratio, neuralscale=True, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=iteration) filters = [cfg for cfg in sum(filters,[])] elif args.model == "mobilenetv2": filters = MobileNetV2.prepare_filters(MobileNetV2, original_filters, ratio=ratio, neuralscale=True, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=iteration) filters = [cfg for cfg in sum(filters,[])][:-2] if plot_growth: total_params.append(filters) if plot_growth: # ax = sns.heatmap(np.array(total_params).T, linewidth=0.005, xticklabels=2, yticklabels=2, ax=axs[idx]) ax = sns.heatmap(np.array(total_params).T, linewidth=0.005, xticklabels=2, ax=axs[int(idx/2)][idx%2]) ax.set_title("Ratio={}".format(ratio)) # if idx!=0: # ax.tick_params(axis='y', which='both', width=0) if idx==1 or idx==3: ax.tick_params(axis='y', which='both', width=0, length=0) if idx==0 or idx==1: ax.tick_params(axis='x', which='both', width=0, length=0)
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Search for optimal configuration using architecture descent')
    parser.add_argument('--batch-size', type=int, default=256, metavar='N', help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=300, metavar='N', help='number of epochs to train (default: 300)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')
    parser.add_argument('--dataset', default="CIFAR10", type=str, help='dataset for experiment, choice: CIFAR10, CIFAR100', choices=["CIFAR10", "CIFAR100"])
    parser.add_argument('--model', default="vgg", type=str, help='model selection, choices: vgg, mobilenetv2, resnet18', choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--save', default='model', help='model and prune file')
    parser.add_argument('--no_prune', dest="prune_on", action='store_false', default=True, help='Turn off pruning')
    parser.add_argument('--warmup', type=int, default=10, help='number of warm-up or fine-tuning epochs before pruning (default: 10)')
    parser.add_argument('--morph', dest="morph", action='store_true', default=False, help='Prunes only 50 percent of neurons, for comparison with MorphNet')
    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################
    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.dataset == "CIFAR10":
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, **kwargs)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, **kwargs)
    elif args.dataset == "Imagenet":
        print("Using Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_sampler = None
        kwargs = {'num_workers': 16}
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, pin_memory=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size, shuffle=False, pin_memory=True, **kwargs)
    else:
        print("Dataset does not exist! [CIFAR10, CIFAR100]")
        exit()

    if args.dataset == "CIFAR10":
        num_classes = 10
    elif args.dataset == "CIFAR100":
        num_classes = 100

    param_list = []
    if args.morph:
        total_iter = 1
    else:
        total_iter = 15

    for iteration in range(total_iter):
        print("Iteration: {}".format(iteration))
        args.lr = 0.1

        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "vgg":
            model = vgg11(ratio=1, neuralscale=True, iteration=iteration, num_classes=num_classes, search=True)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=1, neuralscale=True, iteration=iteration, num_classes=num_classes, search=True, dataset=args.dataset)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=1, neuralscale=True, iteration=iteration, num_classes=num_classes, search=True, dataset=args.dataset)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))

        # for model saving
        model_path = "saved_models"
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        log_save_folder = "%s/%s" % (model_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)
        model_save_path = "%s/%s" % (log_save_folder, args.save) + "_checkpoint.t7"
        model_state_dict = model.state_dict()
        if args.save:
            print("Model will be saved to {}".format(model_save_path))
            save_checkpoint({'state_dict': model_state_dict}, False, filename=model_save_path)
        else:
            print("Save path not defined. Model will not be saved.")

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        ######################
        ##  Set up pruning  ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum([np.prod(par.shape) for par in parameters_for_update])
        print("Total number of parameters, w/o usage of bn consts: ", total_size_params)

        optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        if args.prune_on:
            pruning_parameters_list = prepare_pruning_list(model)
            print("Total pruning layers:", len(pruning_parameters_list))
            if args.morph:
                prune_neurons = 0.5
            else:
                prune_neurons = 0.95
            pruning_engine = pruner(pruning_parameters_list, iteration=iteration, prune_fname=args.save, classes=num_classes, model=args.model, prune_neurons=prune_neurons)

        ###############
        ## Training  ##
        ###############
        for epoch in range(1, args.epochs + 1):
            print("Epoch: {}".format(epoch))
            adjust_learning_rate(args, optimizer, epoch, search=True, warmup=args.warmup)

            # train model
            if args.prune_on:
                train_acc, train_loss = train(args, model, train_loader, optimizer, epoch, criterion, pruning_engine=pruning_engine, num_classes=num_classes)
                if train_acc == -1 and train_loss == -1:
                    break
            else:
                train_acc, train_loss = train(args, model, train_loader, optimizer, epoch, criterion, num_classes=num_classes)

        # =========================
        # store searched parameters
        # =========================
        alpha, beta = fit_params(iteration=iteration, prune_fname=args.save, classes=num_classes, model=args.model)  # search for params of each layer
        param_list.append([alpha, beta])

        # =======================
        # save scaling parameters
        # =======================
        pickle_save = {
            "param": param_list,
        }
        pickle_out = open("prune_record/" + args.save + ".pk", "wb")
        pickle.dump(pickle_save, pickle_out)
        pickle_out.close()
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Efficient Filter Scaling of Convolutional Neural Network')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=300, metavar='N', help='number of epochs to train (default: 300)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')
    parser.add_argument('--dataset', default="CIFAR10", type=str, help='dataset for experiment, choice: CIFAR10, CIFAR100', choices=["CIFAR10", "CIFAR100"])
    parser.add_argument('--model', default="resnet18", type=str, help='model selection, choices: vgg, mobilenetv2, resnet18', choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--save', default='model', help='model file')
    parser.add_argument('--prune_fname', default='filename', help='prune save file')
    parser.add_argument('--descent_idx', type=int, default=14, help='Iteration for Architecture Descent')
    parser.add_argument('--morph', dest="morph", action='store_true', default=False, help='Prunes only 50 percent of neurons, for comparison with MorphNet')
    parser.add_argument('--uniform', dest="uniform", action='store_true', default=False, help='Use uniform scaling instead of NeuralScale')
    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################
    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.dataset == "CIFAR10":
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, **kwargs)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, **kwargs)
    else:
        print("Dataset does not exist! [CIFAR10, CIFAR100]")
        exit()

    if args.dataset == 'CIFAR10':
        num_classes = 10
    elif args.dataset == 'CIFAR100':
        num_classes = 100

    ratios = np.arange(0.25, 2.1, 0.25)  # [0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]
    pruned_filters = None
    neuralscale = True  # turn NeuralScale on by default
    if args.uniform:
        neuralscale = False
    if args.morph:
        neuralscale = False
    if args.model == "vgg":
        if args.dataset == "CIFAR10":
            pruned_filters = [64, 128, 249, 253, 268, 175, 87, 152]  # VGG C10
        elif args.dataset == "CIFAR100":
            pruned_filters = [63, 125, 204, 215, 234, 174, 120, 241]  # VGG C100
        else:
            print("{} not supported for {}".format(args.dataset, args.model))
            exit()
    elif args.model == "resnet18":
        if args.dataset == "CIFAR100":
            pruned_filters = [48, 46, 40, 41, 54, 91, 75, 73, 95, 157, 149, 149, 156, 232, 216, 140, 190]  # resnet18 c100
        elif args.dataset == "tinyimagenet":
            print("Please use imagenet_ratio_swipe.py instead.")
            exit()
        else:
            print("{} not supported for {}".format(args.dataset, args.model))
            exit()
    elif args.model == "mobilenetv2":
        if args.dataset == "CIFAR100":
            pruned_filters = [28, 16, 24, 21, 30, 31, 26, 56, 50, 49, 46, 83, 70, 58, 120, 101, 68, 134, 397]  # mobilenetv2 c100
        elif args.dataset == "tinyimagenet":
            print("Please use imagenet_ratio_swipe.py instead.")
            exit()
        else:
            print("{} not supported for {}".format(args.dataset, args.model))
            exit()

    for ratio in ratios:
        print("Current ratio: {}".format(ratio))

        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "vgg":
            model = vgg11(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))

        # for model saving
        model_path = "saved_models"
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        log_save_folder = "%s/%s" % (model_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)
        model_save_path = "%s/%s" % (log_save_folder, args.save) + "_checkpoint.t7"
        model_state_dict = model.state_dict()
        if args.save:
            print("Model will be saved to {}".format(model_save_path))
            save_checkpoint({'state_dict': model_state_dict}, False, filename=model_save_path)
        else:
            print("Save path not defined. Model will not be saved.")

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        ######################
        ##  Set up pruning  ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum([np.prod(par.shape) for par in parameters_for_update])
        print("Total number of parameters, w/o usage of bn consts: ", total_size_params)

        optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        ###############
        ## Training  ##
        ###############
        best_test_acc = 0
        train_acc_plt = []
        train_loss_plt = []
        test_acc_plt = []
        test_loss_plt = []
        epoch_plt = []
        for epoch in range(1, args.epochs + 1):
            adjust_learning_rate(args, optimizer, epoch)
            print("Epoch: {}".format(epoch))

            # train model
            train_acc, train_loss = train(args, model, train_loader, optimizer, epoch, criterion)

            # evaluate on validation set
            test_acc, test_loss = validate(args, test_loader, model, criterion, epoch, optimizer=optimizer)

            # remember best prec@1 and save checkpoint
            is_best = test_acc > best_test_acc
            best_test_acc = max(test_acc, best_test_acc)
            model_state_dict = model.state_dict()
            if args.save:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model_state_dict,
                    'best_prec1': test_acc,
                }, is_best, filename=model_save_path)

            train_acc_plt.append(train_acc)
            train_loss_plt.append(train_loss)
            test_acc_plt.append(test_acc)
            test_loss_plt.append(test_loss)
            epoch_plt.append(epoch)

            pickle_save = {
                "ratio": ratio,
                "train_acc": train_acc_plt,
                "train_loss": train_loss_plt,
                "test_acc": test_acc_plt,
                "test_loss": test_loss_plt,
            }
            plot_path = "saved_plots"
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            log_save_folder = "%s/%s" % (plot_path, args.model)
            if not os.path.exists(log_save_folder):
                os.makedirs(log_save_folder)
            pickle_out = open("%s/%s_%s.pk" % (log_save_folder, args.save, int(ratio * 100)), "wb")
            pickle.dump(pickle_save, pickle_out)
            pickle_out.close()
def load_model(model_file):
    model = MobileNetV2()
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model
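# A minimal sketch of using load_model for a quick CPU sanity check. The checkpoint file name and
# the 224x224 input size are assumptions for illustration, not taken from this repository; torch is
# already imported by this module for torch.load.
if __name__ == "__main__":
    model = load_model("mobilenetv2_float.pth")  # hypothetical checkpoint path
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)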
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Efficient Filter Scaling of Convolutional Neural Network')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=150, metavar='N', help='number of epochs to train (default: 150)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')
    parser.add_argument('--dataset', default="tinyimagenet", type=str, help='dataset for experiment, choice: tinyimagenet', choices=["tinyimagenet"])
    parser.add_argument('--data', metavar='DIR', default='/DATA/tiny-imagenet-200', help='path to imagenet dataset')
    parser.add_argument('--model', default="resnet18", type=str, help='model selection, choices: mobilenetv2, resnet18', choices=["mobilenetv2", "resnet18"])
    parser.add_argument('--save', default='model', help='model file')
    parser.add_argument('--prune_fname', default='filename', help='prune save file')
    parser.add_argument('--descent_idx', type=int, default=14, help='Iteration for Architecture Descent')
    parser.add_argument('--morph', dest="morph", action='store_true', default=False, help='Prunes only 50 percent of neurons, for comparison with MorphNet')
    parser.add_argument('--uniform', dest="uniform", action='store_true', default=False, help='Use uniform scaling instead of NeuralScale')
    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################
    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.dataset == "tinyimagenet":
        print("Using tiny-Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'test')
        normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262])
        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomCrop(64, padding=4),
                transforms.RandomRotation(20),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_sampler = None
        kwargs = {'num_workers': 16}
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, pin_memory=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size, shuffle=False, pin_memory=True, **kwargs)
    else:
        print("Dataset does not exist! [tinyimagenet]")
        exit()

    if args.dataset == 'tinyimagenet':
        num_classes = 200
    else:
        print("Only tinyimagenet")
        exit()

    ratios = [0.25, 0.5, 0.75, 1.0]
    pruned_filters = None
    neuralscale = True  # turn on NeuralScale by default
    if args.uniform:
        neuralscale = False
    if args.morph:
        neuralscale = False
    if args.model == "resnet18":
        pruned_filters = [82, 90, 78, 80, 96, 180, 104, 96, 194, 312, 182, 178, 376, 546, 562, 454, 294]  # resnet18 tinyimagenet
    elif args.model == "mobilenetv2":
        pruned_filters = [28, 16, 24, 24, 32, 32, 30, 64, 59, 50, 41, 96, 73, 48, 160, 69, 47, 155, 360]  # mobilenetv2 tinyimagenet
    else:
        print("{} not supported.".format(args.model))
        exit()

    for ratio in ratios:
        print("Current ratio: {}".format(ratio))

        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "resnet18":
            model = PreActResNet18(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        else:
            print(args.model, "model not supported [resnet18 mobilenetv2] only")
            exit()
        print("{} set up.".format(args.model))

        # for model saving
        model_path = "saved_models"
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        log_save_folder = "%s/%s" % (model_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)
        model_save_path = "%s/%s" % (log_save_folder, args.save) + "_checkpoint.t7"
        model_state_dict = model.state_dict()
        if args.save:
            print("Model will be saved to {}".format(model_save_path))
            save_checkpoint({'state_dict': model_state_dict}, False, filename=model_save_path)
        else:
            print("Save path not defined. Model will not be saved.")

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        ######################
        ##  Set up pruning  ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum([np.prod(par.shape) for par in parameters_for_update])
        print("Total number of parameters, w/o usage of bn consts: ", total_size_params)

        optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        ###############
        ## Training  ##
        ###############
        best_test_acc = 0
        train_acc_plt = []
        train_loss_plt = []
        test_acc_plt = []
        test_loss_plt = []
        epoch_plt = []
        for epoch in range(1, args.epochs + 1):
            adjust_learning_rate_imagenet(args, optimizer, epoch, search=False)
            print("Epoch: {}".format(epoch))

            # train model
            train_acc, train_loss = train(args, model, train_loader, optimizer, epoch, criterion)

            # evaluate on validation set
            test_acc, test_loss = validate(args, test_loader, model, criterion, epoch, optimizer=optimizer)

            # remember best prec@1 and save checkpoint
            is_best = test_acc > best_test_acc
            best_test_acc = max(test_acc, best_test_acc)
            model_state_dict = model.state_dict()
            if args.save:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model_state_dict,
                    'best_prec1': test_acc,
                }, is_best, filename=model_save_path)

            train_acc_plt.append(train_acc)
            train_loss_plt.append(train_loss)
            test_acc_plt.append(test_acc)
            test_loss_plt.append(test_loss)
            epoch_plt.append(epoch)

            pickle_save = {
                "ratio": ratio,
                "train_acc": train_acc_plt,
                "train_loss": train_loss_plt,
                "test_acc": test_acc_plt,
                "test_loss": test_loss_plt,
            }
            plot_path = "saved_plots"
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            log_save_folder = "%s/%s" % (plot_path, args.model)
            if not os.path.exists(log_save_folder):
                os.makedirs(log_save_folder)
            pickle_out = open("%s/%s_%s.pk" % (log_save_folder, args.save, int(ratio * 100)), "wb")
            pickle.dump(pickle_save, pickle_out)
            pickle_out.close()
        model = PreActResNet18(ratio=ratio, neuralscale=True, num_classes=num_classes, prune_fname=fname, descent_idx=14)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_high.append(params)
        latency_high.append(latency)
elif args.model == "mobilenetv2":
    filters = [[32], [16], [24, 24], [32, 32, 32], [64, 64, 64, 64], [96, 96, 96], [160, 160, 160], [320], [1280]]
    if args.dataset == "CIFAR100":
        filters_prune = [28, 16, 24, 21, 30, 31, 26, 56, 50, 49, 46, 83, 70, 58, 120, 101, 68, 134, 397]
        ratios = np.arange(0.25, 2.1, 0.25)  # [0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]
    elif args.dataset == "tinyimagenet":
        filters_prune = [28, 16, 24, 24, 32, 32, 30, 64, 59, 50, 41, 96, 73, 48, 160, 69, 47, 155, 360]  # mobilenetv2 tinyimagenet
        ratios = [0.25, 0.5, 0.75, 1.0]
    for ratio in ratios:
        # convcut
        new_config = MobileNetV2.prepare_filters(MobileNetV2, filters, ratio=ratio, neuralscale=False, num_classes=num_classes)
        model = MobileNetV2(filters=new_config, num_classes=num_classes, dataset=args.dataset, convcut=True)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_convcut.append(params)
        latency_convcut.append(latency)

        # uniform
        new_config = MobileNetV2.prepare_filters(MobileNetV2, filters, ratio=ratio, neuralscale=False, num_classes=num_classes)
        model = MobileNetV2(filters=new_config, num_classes=num_classes, dataset=args.dataset)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_uni.append(params)
        latency_uni.append(latency)

        # pruned
        new_config = MobileNetV2.prepare_filters(MobileNetV2, filters, ratio=ratio, neuralscale=False, num_classes=num_classes, pruned_filters=filters_prune)
        model = MobileNetV2(filters=new_config, num_classes=num_classes, dataset=args.dataset)
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Fine-tune on pruned architecture')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=300, metavar='N', help='number of epochs to train (default: 300)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate (default: 0.001)')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')
    parser.add_argument('--dataset', default="CIFAR10", type=str, help='dataset for experiment, choice: CIFAR10, CIFAR100, tinyimagenet', choices=["CIFAR10", "CIFAR100", "tinyimagenet"])
    parser.add_argument('--data', metavar='DIR', default='/DATA/tiny-imagenet-200', help='path to imagenet dataset')
    parser.add_argument('--model', default="resnet18", type=str, help='model selection, choices: vgg, mobilenetv2, resnet18', choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
    parser.add_argument('--save', default='model', help='model file')
    parser.add_argument('--prune_fname', default='filename', help='prune save file')
    parser.add_argument('--descent_idx', type=int, default=14, help='Iteration for Architecture Descent')
    parser.add_argument('--method', type=int, default=0, help='sets pruning method')
    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################
    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.dataset == "CIFAR10":
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, **kwargs)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, **kwargs)
    elif args.dataset == "tinyimagenet":
        print("Using tiny-Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'test')
        normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262])
        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomCrop(64, padding=4),
                transforms.RandomRotation(20),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_sampler = None
        kwargs = {'num_workers': 16}
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, pin_memory=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size, shuffle=False, pin_memory=True, **kwargs)
    else:
        print("Dataset does not exist! [CIFAR10, CIFAR100, tinyimagenet]")
        exit()

    if args.dataset == 'CIFAR10':
        num_classes = 10
        args.epochs = 40
    elif args.dataset == 'CIFAR100':
        num_classes = 100
        args.epochs = 40
    elif args.dataset == 'tinyimagenet':
        num_classes = 200
        args.epochs = 20

    # ratios = [0.25, 0.75]
    ratios = [0.75]
    pruned_filters = None
    for ratio in ratios:
        print("Current ratio: {}".format(ratio))

        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "vgg":
            model = vgg11(ratio=1, neuralscale=False, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters, search=True)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=1, neuralscale=False, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters, search=True)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=1, neuralscale=False, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters, search=True)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))

        # optionally resume from a checkpoint
        if args.resume:
            model_path = "saved_models"
            if not os.path.exists(model_path):
                os.makedirs(model_path)
            log_save_folder = "%s/%s" % (model_path, args.model)
            model_resume_path = "%s/%s" % (log_save_folder, args.resume) + "_best_model.t7"
            if os.path.isfile(model_resume_path):
                print("=> loading checkpoint '{}'".format(model_resume_path))
                checkpoint = torch.load(model_resume_path)
                best_acc1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(model_resume_path, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(model_resume_path))

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        ######################
        ##  Set up pruning  ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum([np.prod(par.shape) for par in parameters_for_update])
        print("Total number of parameters, w/o usage of bn consts: ", total_size_params)

        optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        pruner = compare_pruner(model, method=args.method, size=ratio)
        pruner.prune_neurons(optimizer)

        ###############
        ## Training  ##
        ###############
        best_test_acc = 0
        train_acc_plt = []
        train_loss_plt = []
        test_acc_plt = []
        test_loss_plt = []
        epoch_plt = []
        for epoch in range(1, args.epochs + 1):
            print("Epoch: {}".format(epoch))

            # train model
            train_acc, train_loss = train(args, model, train_loader, optimizer, epoch, criterion)

            # evaluate on validation set
            test_acc, test_loss = validate(args, test_loader, model, criterion, epoch, optimizer=optimizer)

            # remember best prec@1 and save checkpoint
            is_best = test_acc > best_test_acc
            best_test_acc = max(test_acc, best_test_acc)
            print(best_test_acc)

            train_acc_plt.append(train_acc)
            train_loss_plt.append(train_loss)
            test_acc_plt.append(test_acc)
            test_loss_plt.append(test_loss)
            epoch_plt.append(epoch)

            pickle_save = {
                "ratio": ratio,
                "train_acc": train_acc_plt,
                "train_loss": train_loss_plt,
                "test_acc": test_acc_plt,
                "test_loss": test_loss_plt,
            }
            plot_path = "saved_plots"
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            log_save_folder = "%s/%s" % (plot_path, args.model)
            if not os.path.exists(log_save_folder):
                os.makedirs(log_save_folder)
            pickle_out = open("%s/%s_%s.pk" % (log_save_folder, args.save, int(ratio * 100)), "wb")
            pickle.dump(pickle_save, pickle_out)
            pickle_out.close()
                          descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=ratio, efficient_scale=False, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters, search=True)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=ratio, efficient_scale=False, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters, search=True)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))

        # for model saving
        model_path = "saved_models"
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        log_save_folder = "%s/%s" % (model_path, args.model)
        if not os.path.exists(log_save_folder):
    new_config = PreActResNet.prepare_filters(PreActResNet, filters, ratio=0.75, neuralscale=False, num_classes=num_classes)
    model = PreActResNet18(filters=new_config, num_classes=num_classes, dataset=args.dataset)
    latency = compute_latency(model)
    params = compute_params_(model)
elif args.model == "mobilenetv2":
    filters = [[32], [16], [24, 24], [32, 32, 32], [64, 64, 64, 64], [96, 96, 96], [160, 160, 160], [320], [1280]]
    new_config = MobileNetV2.prepare_filters(MobileNetV2, filters, ratio=0.75, neuralscale=False, num_classes=num_classes)
    model = MobileNetV2(filters=new_config, num_classes=num_classes, dataset=args.dataset)
    latency = compute_latency(model)
    params = compute_params_(model)

ratio = 0.75
uni_test_loss = []
uni_test_acc = []
uni_test_loss_tmp = []
uni_test_acc_tmp = []
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Train a base network under ratio=1 (default configuration) for pruning')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=128, metavar='N', help='input batch size for testing (default: 128)')
    parser.add_argument('--epochs', type=int, default=300, metavar='N', help='number of epochs to train (default: 300)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
    # parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
    #                     help='learning rate (default: 0.01)')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')
    parser.add_argument('--lr-decay-every', type=int, default=100, help='learning rate decay by 10 every X epochs')
    parser.add_argument('--lr-decay-scalar', type=float, default=0.1, help='--')
    parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
    parser.add_argument('--dataset', default="CIFAR10", type=str, help='dataset for experiment, choice: MNIST, CIFAR10, CIFAR100, Imagenet, tinyimagenet', choices=["MNIST", "CIFAR10", "CIFAR100", "Imagenet", "tinyimagenet"])
    parser.add_argument('--data', metavar='DIR', default='/DATA/tiny-imagenet-200', help='path to tinyimagenet dataset')
    parser.add_argument('--model', default="resnet18", type=str, help='model selection, choices: vgg, mobilenetv2, resnet18', choices=["vgg", "mobilenetv2", "resnet18", "mobilenet"])
    parser.add_argument('--r', dest="resume", action='store_true', default=False, help='Resume from checkpoint')
    parser.add_argument('--save', default='model', help='model file')
    parser.add_argument('--prune_fname', default='filename', help='prune save file')
    parser.add_argument('--descent_idx', type=int, default=14, help='Iteration for Architecture Descent')
    parser.add_argument('--s', type=float, default=0.0001, help='scale sparse rate (default: 0.0001)')
    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################
    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.dataset == "CIFAR10":
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, **kwargs)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=False, download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, **kwargs)
    elif args.dataset == "Imagenet":
        print("Using Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_sampler = None
        kwargs = {'num_workers': 16}
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, pin_memory=True, **kwargs)
        if args.use_test_as_train:
            train_loader = torch.utils.data.DataLoader(
                torchvision.datasets.ImageFolder(
                    valdir,
                    transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=args.batch_size, shuffle=(train_sampler is None), **kwargs)
        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size, shuffle=False, pin_memory=True, **kwargs)
    elif args.dataset == "tinyimagenet":
        print("Using tiny-Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'test')
        normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262])
        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomCrop(64, padding=4),
                transforms.RandomRotation(20),
                # transforms.RandomResizedCrop(64),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_sampler = None
        kwargs = {'num_workers': 16}
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, pin_memory=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size, shuffle=False, pin_memory=True, **kwargs)
    else:
        print("Dataset does not exist! [CIFAR10, CIFAR100, Imagenet, tinyimagenet]")
        exit()

    if args.dataset == 'CIFAR10':
        num_classes = 10
    elif args.dataset == 'CIFAR100':
        num_classes = 100
    elif args.dataset == 'tinyimagenet':
        num_classes = 200

    ratio = 1
    pruned_filters = None

    ###########
    ## Model ##
    ###########
    print("Setting Up Model...")
    if args.model == "vgg":
        model = vgg11(ratio=ratio, neuralscale=False, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
    elif args.model == "resnet18":
        model = PreActResNet18(ratio=ratio, neuralscale=False, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters, search=True)
    elif args.model == "mobilenetv2":
        model = MobileNetV2(ratio=ratio, neuralscale=False, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters, search=True)
    else:
        print(args.model, "model not supported")
        exit()
    print("{} set up.".format(args.model))

    # for model saving
    model_path = "saved_models"
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    log_save_folder = "%s/%s" % (model_path, args.model)
    if not os.path.exists(log_save_folder):
        os.makedirs(log_save_folder)
    model_save_path = "%s/%s" % (log_save_folder, args.save) + "_checkpoint.t7"
    model_state_dict = model.state_dict()
    if args.save:
        print("Model will be saved to {}".format(model_save_path))
        save_checkpoint({'state_dict': model_state_dict}, False, filename=model_save_path)
    else:
        print("Save path not defined. Model will not be saved.")

    # Assume cuda is available and uses single GPU
    model.cuda()
    cudnn.benchmark = True

    # define objective
    criterion = nn.CrossEntropyLoss()

    ######################
    ##  Set up pruning  ##
    ######################
    # remove updates from gate layers, because we want them to be 0 or 1 constantly
    parameters_for_update = []
    parameters_for_update_named = []
    for name, m in model.named_parameters():
        if "gate" not in name:
            parameters_for_update.append(m)
            parameters_for_update_named.append((name, m))
        else:
            print("skipping parameter", name, "shape:", m.shape)

    total_size_params = sum([np.prod(par.shape) for par in parameters_for_update])
    print("Total number of parameters, w/o usage of bn consts: ", total_size_params)

    optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    ###############
    ## Training  ##
    ###############
    best_test_acc = 0
    train_acc_plt = []
    train_loss_plt = []
    test_acc_plt = []
    test_loss_plt = []
    epoch_plt = []
    if num_classes == 200:  # tinyimagenet
        args.epochs = 150
    else:
        args.epochs = 300
    for epoch in range(1, args.epochs + 1):
        if num_classes == 200:  # tinyimagenet
            adjust_learning_rate_imagenet(args, optimizer, epoch, search=False)
        else:
            adjust_learning_rate(args, optimizer, epoch)
        print("Epoch: {}".format(epoch))

        # train model
        train_acc, train_loss = train(args, model, train_loader, optimizer, epoch, criterion)

        # evaluate on validation set
        test_acc, test_loss = validate(args, test_loader, model, criterion, epoch, optimizer=optimizer)

        # remember best prec@1 and save checkpoint
        is_best = test_acc > best_test_acc
        best_test_acc = max(test_acc, best_test_acc)
        model_state_dict = model.state_dict()
        if args.save:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_state_dict,
                'best_prec1': test_acc,
            }, is_best, filename=model_save_path)

        train_acc_plt.append(train_acc)
        train_loss_plt.append(train_loss)
        test_acc_plt.append(test_acc)
        test_loss_plt.append(test_loss)
        epoch_plt.append(epoch)

        pickle_save = {
            "ratio": ratio,
            "train_acc": train_acc_plt,
            "train_loss": train_loss_plt,
            "test_acc": test_acc_plt,
            "test_loss": test_loss_plt,
        }
        plot_path = "saved_plots"
        if not os.path.exists(plot_path):
            os.makedirs(plot_path)
        log_save_folder = "%s/%s" % (plot_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)
        pickle_out = open("%s/%s_%s.pk" % (log_save_folder, args.save, int(ratio * 100)), "wb")
        pickle.dump(pickle_save, pickle_out)
        pickle_out.close()
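# Each training run above pickles its per-epoch metrics ("ratio", "train_acc", "train_loss",
# "test_acc", "test_loss") to saved_plots/<model>/<save>_<ratio*100>.pk. A minimal sketch of
# reading one of these records back and plotting test accuracy; the function name and the example
# path are illustrative only.
import pickle

import matplotlib.pyplot as plt


def plot_test_accuracy(pk_path="saved_plots/resnet18/model_100.pk"):
    with open(pk_path, "rb") as pickle_in:
        record = pickle.load(pickle_in)
    plt.plot(range(1, len(record["test_acc"]) + 1), record["test_acc"])
    plt.xlabel("epoch")
    plt.ylabel("test accuracy")
    plt.title("ratio = {}".format(record["ratio"]))
    plt.show()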