def yolo_config(config, args):
    # if (args.prune_threshold > 0.005):
    #     print("WARNING: Prune threshold seems too large.")
    #     if input("Input y if you are sure you want to continue.") != 'y':
    #         return
    device = 'cpu' if args.no_cuda else 'cuda:0'
    model = config['model'](config['config_path'], device=device)
    wrapper = YoloWrapper(device, model)
    lr0 = 0.001  # lr0 = args.lr
    optimizer = config['optimizer'](filter(lambda x: x.requires_grad,
                                           model.parameters()),
                                    lr=lr0,
                                    momentum=args.momentum)
    writer = SummaryWriter()

    print("Loading dataloaders..")
    train_dataloader = LoadImagesAndLabels(config['datasets']['train'],
                                           batch_size=args.batch_size,
                                           img_size=config['image_size'])
    val_dataloader = LoadImagesAndLabels(config['datasets']['test'],
                                         batch_size=args.batch_size,
                                         img_size=config['image_size'])

    # Either load pretrained weights or train a baseline, saving it before
    # any pruning happens.
    if args.pretrained_weights:
        model.load_state_dict(
            torch.load(args.pretrained_weights,
                       map_location=torch.device(device)))
    else:
        wrapper.train(train_dataloader, val_dataloader, args.epochs,
                      optimizer, lr0)
        torch.save(model.state_dict(), "YOLOv3-gate-prepruned.pt")

    with torch.no_grad():
        pre_prune_mAP, _, _ = wrapper.test(val_dataloader,
                                           img_size=config['image_size'],
                                           batch_size=args.batch_size)

    prune_perc = 0. if args.start_at_prune_rate is None else args.start_at_prune_rate
    prune_iter = 0
    curr_mAP = pre_prune_mAP

    if args.tensorboard:
        writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
        writer.add_scalar('prune/percentage', prune_perc, prune_iter)
        for name, param in wrapper.model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(f'prune/preprune/{name}', param,
                                     prune_iter)

    # Prune in 5% increments, retraining after each step, until the mAP drop
    # relative to the pre-pruning baseline exceeds args.prune_threshold.
    thresh_reached, _ = reached_threshold(args.prune_threshold, curr_mAP,
                                          pre_prune_mAP)
    while not thresh_reached:
        prune_iter += 1
        prune_perc += 5.
        masks = weight_prune(model, prune_perc)
        model.set_mask(masks)
        print(f"Just pruned with prune_perc={prune_perc}, "
              f"now has {prune_rate(model, verbose=False)}% zeros")

        if not args.no_retrain:
            print(f"Retraining at prune percentage {prune_perc}..")
            curr_mAP, best_weights = wrapper.train(train_dataloader,
                                                   val_dataloader, 3,
                                                   optimizer, lr0)
            print("Loading best weights from training epochs..")
            model.load_state_dict(best_weights)
            print(f"Just finished training with prune_perc={prune_perc}, "
                  f"now has {prune_rate(model, verbose=False)}% zeros")
        else:
            with torch.no_grad():
                curr_mAP, _, _ = wrapper.test(val_dataloader,
                                              img_size=config['image_size'],
                                              batch_size=args.batch_size)

        if args.tensorboard:
            writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
            writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        thresh_reached, diff = reached_threshold(args.prune_threshold,
                                                 curr_mAP, pre_prune_mAP)
        print(f"mAP achieved: {curr_mAP}")
        print(f"Change in mAP: {diff}")

    prune_perc = prune_rate(model)

    if args.save_model:
        # torch.save(model.state_dict(), f'{config["name"]}-pruned-{datetime.datetime.now().strftime("%Y%m%d%H%M")}.pt')
        # torch.save(model.state_dict(), "YOLOv3-prune-perc-" + str(prune_perc) + ".pt")
        torch.save(model.state_dict(),
                   "YOLOv3-gate-pruned-modelcompression.pt")

    if args.tensorboard:
        for name, param in wrapper.model.named_parameters():
            if 'weight' in name:
                writer.add_histogram(f'prune/postprune/{name}', param,
                                     prune_iter + 1)

    print(f"Pruned model: {config['name']}")
    print(f"Pre-pruning mAP: {pre_prune_mAP}")
    print(f"Post-pruning mAP: {curr_mAP}")
    print(f"Percentage of zeroes: {prune_perc}")
    return wrapper
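# `reached_threshold` is called by every config here but is not defined in
# this snippet. A minimal sketch of the assumed contract (not necessarily the
# repo's actual implementation): signal a stop once the metric has fallen more
# than `threshold` below its pre-pruning baseline, and also return the signed
# change so callers can log it.
def reached_threshold(threshold, curr_metric, pre_prune_metric):
    diff = curr_metric - pre_prune_metric
    return (pre_prune_metric - curr_metric) > threshold, diff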
def frcnn_config(config, args):
    classes = (
        '__background__',  # always index 0
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
        'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
        'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
    model = config['model'](
        classes
        # model_path=args.pretrained_weights
    )
    model.create_architecture()
    device = 'cpu' if args.no_cuda else 'cuda:0'
    wrapper = FasterRCNNWrapper(device, model)
    if args.tensorboard:
        writer = SummaryWriter()

    if args.pretrained_weights:
        print("Loading weights ", args.pretrained_weights)
        state_dict = torch.load(args.pretrained_weights,
                                map_location=torch.device(device))
        if 'model' in state_dict.keys():
            state_dict = state_dict['model']
        model.load_state_dict(state_dict)
    else:
        wrapper.train(args.batch_size, args.lr, args.epochs)

    pre_prune_mAP = wrapper.test()
    # pre_prune_mAP = 0.6772
    prune_perc = 0. if args.start_at_prune_rate is None else args.start_at_prune_rate
    prune_iter = 0
    curr_mAP = pre_prune_mAP

    if args.tensorboard:
        writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
        writer.add_scalar('prune/percentage', prune_perc, prune_iter)
        for name, param in wrapper.model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(f'prune/preprune/{name}', param,
                                     prune_iter)

    thresh_reached, _ = reached_threshold(args.prune_threshold, curr_mAP,
                                          pre_prune_mAP)
    while not thresh_reached:
        prune_iter += 1
        prune_perc += 5.
        masks = weight_prune(model, prune_perc)
        model.set_mask(masks)

        if not args.no_retrain:
            print(f"Retraining at prune percentage {prune_perc}..")
            curr_mAP, best_weights = wrapper.train(args.batch_size, args.lr,
                                                   args.epochs)
            print("Loading best weights from epoch at mAP ", curr_mAP)
            model.load_state_dict(best_weights)
        else:
            with torch.no_grad():
                curr_mAP = wrapper.test()

        if args.tensorboard:
            writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
            writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        thresh_reached, diff = reached_threshold(args.prune_threshold,
                                                 curr_mAP, pre_prune_mAP)
        print(f"mAP achieved: {curr_mAP}")
        print(f"Change in mAP: {diff}")

    prune_perc = prune_rate(model)

    if args.save_model:
        torch.save(
            model.state_dict(),
            f'{config["name"]}-pruned-{datetime.datetime.now().strftime("%Y%m%d%H%M")}.pt'
        )

    if args.tensorboard:
        for name, param in wrapper.model.named_parameters():
            if 'weight' in name:
                writer.add_histogram(f'prune/postprune/{name}', param,
                                     prune_iter + 1)

    print(f"Pruned model: {config['name']}")
    print(f"Pre-pruning mAP: {pre_prune_mAP}")
    print(f"Post-pruning mAP: {curr_mAP}")
    print(f"Percentage of zeroes: {prune_perc}")
    return wrapper
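# `prune_rate` is likewise assumed rather than shown: it reports (and returns)
# the percentage of zero-valued weights in the model. A sketch consistent with
# how it is called above, including the `verbose` flag:
def prune_rate(model, verbose=True):
    total, zeros = 0, 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            layer_zeros = (param == 0).sum().item()
            total += param.numel()
            zeros += layer_zeros
            if verbose:
                print(f"{name}: {100. * layer_zeros / param.numel():.2f}% zeros")
    return 100. * zeros / max(total, 1)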
def classifier_config(config, args):
    model = config['model']()
    device = 'cpu' if args.no_cuda else 'cuda:0'
    if args.tensorboard:
        writer = SummaryWriter()

    train_data = config['dataset']('./data',
                                   train=True,
                                   download=True,
                                   transform=transforms.Compose(
                                       config['transforms']))
    test_data = config['dataset']('./data',
                                  train=False,
                                  download=True,
                                  transform=transforms.Compose(
                                      config['transforms']))
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)
    optimizer = config['optimizer'](model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    wrapper = Classifier(model, device, train_loader, test_loader)

    if args.pretrained_weights:
        print("Loading pretrained weights..")
        model.load_state_dict(
            torch.load(args.pretrained_weights,
                       map_location=torch.device(device)))
    else:
        wrapper.train(args.log_interval, optimizer, args.epochs,
                      config['loss_fn'])

    pre_prune_accuracy = wrapper.test(config['loss_fn'])
    prune_perc = 0. if args.start_at_prune_rate is None else args.start_at_prune_rate
    prune_iter = 0
    curr_accuracy = pre_prune_accuracy

    if args.tensorboard:
        writer.add_scalar('prune/accuracy', curr_accuracy, prune_iter)
        writer.add_scalar('prune/percentage', prune_perc, prune_iter)
        for name, param in wrapper.model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(f'prune/preprune/{name}', param,
                                     prune_iter)

    thresh_reached, _ = reached_threshold(args.prune_threshold,
                                          curr_accuracy, pre_prune_accuracy)
    while not thresh_reached:
        print(f"Testing at prune percentage {prune_perc}..")
        curr_accuracy = wrapper.test(config["loss_fn"])
        prune_iter += 1
        prune_perc += 5.
        # masks = weight_prune(model, prune_perc)
        masks = weight_prune(model, prune_perc, layerwise_thresh=True)
        model.set_mask(masks)

        if not args.no_retrain:
            print(f"Retraining at prune percentage {prune_perc}..")
            curr_accuracy, best_weights = wrapper.train(
                args.log_interval, optimizer, args.epochs, config['loss_fn'])
            print("Loading best weights from training epochs..")
            model.load_state_dict(best_weights)
        else:
            with torch.no_grad():
                curr_accuracy = wrapper.test(config['loss_fn'])

        if args.tensorboard:
            writer.add_scalar('prune/accuracy', curr_accuracy, prune_iter)
            writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        thresh_reached, diff = reached_threshold(args.prune_threshold,
                                                 curr_accuracy,
                                                 pre_prune_accuracy)
        print(f"Accuracy achieved: {curr_accuracy}")
        print(f"Change in accuracy: {diff}")

    prune_perc = prune_rate(model)

    if args.save_model:
        torch.save(
            model.state_dict(),
            f'./models/{config["name"]}-pruned-{datetime.datetime.now().strftime("%Y%m%d%H%M")}.pt'
        )

    if args.tensorboard:
        for name, param in wrapper.model.named_parameters():
            if 'weight' in name:
                writer.add_histogram(f'prune/postprune/{name}', param,
                                     prune_iter + 1)

    print(f"Pruned model: {config['name']}")
    print(f"Pre-pruning accuracy: {pre_prune_accuracy}")
    print(f"Post-pruning accuracy: {curr_accuracy}")
    print(f"Percentage of zeroes: {prune_perc}")
    return wrapper
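# `weight_prune` builds the 0/1 masks consumed by `model.set_mask(...)`. A
# hedged sketch of magnitude-based pruning as used above (the actual
# implementation may differ): zero the smallest `prune_perc` percent of
# weights, either against one global threshold or per layer when
# `layerwise_thresh=True`.
import numpy as np
import torch

def weight_prune(model, prune_perc, layerwise_thresh=False):
    weights = [p for n, p in model.named_parameters()
               if 'weight' in n and p.dim() > 1]
    if layerwise_thresh:
        threshs = [np.percentile(w.detach().abs().cpu().numpy(), prune_perc)
                   for w in weights]
    else:
        all_w = np.concatenate(
            [w.detach().abs().cpu().numpy().ravel() for w in weights])
        threshs = [np.percentile(all_w, prune_perc)] * len(weights)
    # Mask is 1 where the weight survives, 0 where it is pruned.
    return [(w.abs() > float(t)).float() for w, t in zip(weights, threshs)]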
def main():
    global args, best_prec1
    args = parser.parse_args()
    pruning = False
    chkpoint = False

    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        # model = models.__dict__[args.arch](pretrained=True)
        model = alexnet(pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        # model = models.__dict__[args.arch]()
        model = alexnet(pretrained=False)

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint; masks are stored alongside the
    # weights, so split them out before loading
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            params = {
                k: v
                for k, v in checkpoint['state_dict'].items() if 'mask' not in k
            }
            mask_params = {
                k: v
                for k, v in checkpoint['state_dict'].items() if 'mask' in k
            }
            args.start_epoch = checkpoint['epoch']
            # saved_iter = checkpoint['iter']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(params)
            model.set_masks(list(mask_params.values()))
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            prune_rate(model)
            chkpoint = True
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.prune > 0 and not chkpoint:
        # prune
        print("=> pruning...")
        masks = weight_prune(model, args.prune)
        model.set_masks(masks)
        pruning = True

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'ilsvrc12_train_lmdb_224_pytorch')
    valdir = os.path.join(args.data, 'ilsvrc12_val_lmdb_224_pytorch')
    # traindir = os.path.join(args.data, 'ILSVRC2012_img_train')
    # valdir = os.path.join(args.data, 'ILSVRC2012_img_val_sorted')
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                  std=[0.229, 0.224, 0.225])

    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]))

    if args.distributed:
        # NOTE: `train_dataset` is only created in the commented-out
        # ImageFolder path above, so this branch would currently fail.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=args.batch_size,
    #     shuffle=(train_sampler is None),
    #     num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    train_loader = Loader('train',
                          traindir,
                          batch_size=args.batch_size,
                          num_workers=args.workers,
                          cuda=True)
    val_loader = Loader('val',
                        valdir,
                        batch_size=args.batch_size,
                        num_workers=args.workers,
                        cuda=True)
    # val_loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transforms.Compose([
    #         transforms.Resize(256),
    #         transforms.CenterCrop(224),
    #         transforms.ToTensor(),
    #         normalize,
    #     ])),
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    if pruning and not chkpoint:
        # Prune weights validation
        print("--- {}% parameters pruned ---".format(args.prune))
---".format(args.prune)) validate(val_loader, model, criterion) for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) adjust_learning_rate(optimizer, epoch) # train for one epoch train(train_loader, model, criterion, optimizer, epoch) # evaluate on validation set prec1 = validate(val_loader, model, criterion) # remember best prec@1 and save checkpoint is_best = prec1 > best_prec1 best_prec1 = max(prec1, best_prec1) save_checkpoint( { 'epoch': epoch + 1, 'arch': args.arch, 'state_dict': model.state_dict(), 'best_prec1': best_prec1, 'optimizer': optimizer.state_dict(), 'iter': 0, }, is_best, path=args.logfolder) print("--- After retraining ---") prune_rate(model) torch.save(model.state_dict(), os.path.join(args.logfolder, 'alexnet_pruned.pkl'))
# NOTE: this excerpt begins mid-statement; the test-set constructor is
# reconstructed here, with the dataset class and root path assumed.
test_dataset = datasets.MNIST('./data',
                              train=False,
                              transform=transforms.ToTensor())
loader_test = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=param['test_batch_size'],
                                          shuffle=True)

# Load the pretrained model
net = MLP()
net.load_state_dict(torch.load('models/mlp_pretrained.pkl'))
if torch.cuda.is_available():
    print('CUDA enabled.')
    net.cuda()
print("--- Pretrained network loaded ---")
test(net, loader_test)

# Prune the weights
masks = weight_prune(net, param['pruning_perc'])
net.set_masks(masks)
print("--- {}% parameters pruned ---".format(param['pruning_perc']))
test(net, loader_test)

# Retraining (loader_train is defined earlier in the full script)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(net.parameters(),
                                lr=param['learning_rate'],
                                weight_decay=param['weight_decay'])
train(net, criterion, optimizer, param, loader_train)

# Check accuracy and nonzero weights in each layer
print("--- After retraining ---")
test(net, loader_test)
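# What `net.set_masks(masks)` does under the hood: a sketch of the common
# masked-layer pattern (not necessarily this repo's exact code). Each layer
# stores its 0/1 mask and multiplies it into the weight on every forward
# pass, so pruned connections stay zero while the survivors are retrained.
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinear(nn.Linear):
    def set_mask(self, mask):
        mask = mask.detach().clone()
        if hasattr(self, 'mask'):
            self.mask = mask
        else:
            self.register_buffer('mask', mask)
        self.weight.data *= self.mask  # zero pruned weights immediately

    def forward(self, x):
        weight = self.weight * self.mask if hasattr(self, 'mask') else self.weight
        return F.linear(x, weight, self.bias)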
    # NOTE: the excerpt begins inside an SNR sweep; its loop header matches
    # the one under `hp.plot` at the end of this snippet.
    enc = encoder(ip)
    enc = enc + torch.randn_like(enc, device=device) / scal
    op = decoder(enc)
    errs[i] = error_rate(op, labels)
plt.semilogy(xx, errs + 1 / 10**hp.e_prec, label='All weights')

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=hp.lr)

print("--- Pretrained network loaded ---")
test()

# Prune the weights
masks = weight_prune(net, hp.pp)
i = 0
for part in net:  # part in [encoder, decoder]
    for p in part[::2]:  # conveniently skips biases
        p.set_mask(masks[i])
        i += 1
print("--- {}% parameters pruned ---".format(hp.pp))
test()

if hp.plot:
    for i, snr in enumerate(snrs):
        print(i)
        scal = np.sqrt(snr * 2 * hp.k / hp.n)
        labels, ip = generate_input(amt=10**hp.e_prec)
        enc = encoder(ip)
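# `error_rate` is assumed (it is not defined in this excerpt) to measure the
# fraction of decoder outputs whose argmax disagrees with the transmitted
# labels; a minimal sketch under that assumption:
def error_rate(outputs, labels):
    preds = outputs.argmax(dim=-1)
    return (preds != labels).float().mean().item()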