def calculate_flops(current_model):
    """Print the FLOPs of ``current_model`` as a fraction of a reference model.

    The reference network is built from the module-level ``args``:
    ``models.resnet_expand.resnet56`` when ``args.expand`` is set and the
    architecture is ``"resnet56"``, or ``models.__dict__[args.arch]`` when
    ``args.expand`` is unset and the architecture name matches ``vgg.+``.
    Both models are moved to the CPU and measured with an input size of 32.

    Raises:
        NotImplementedError: for any other ``args.expand`` / ``args.arch``
            combination.
    """
    if args.expand and args.arch == "resnet56":
        baseline = models.resnet_expand.resnet56(num_classes=num_classes)
    elif not args.expand and re.match("vgg.+", args.arch):
        baseline = models.__dict__[args.arch](num_classes=num_classes)
    else:
        raise NotImplementedError()

    # Measure the candidate first, then the baseline, both on the CPU.
    measured = count_model_param_flops(current_model.cpu(), 32)
    reference = count_model_param_flops(baseline.cpu(), 32)
    print("FLOPs remains {}".format(measured / reference))
def main():
    """Train the model selected by ``args.model`` with a FLOPs-scaled schedule.

    Parses the command line into the module-level ``args``, builds the model
    chosen by ``args.model`` together with its reference model, and derives
    the epoch budget from ``flops(reference) / flops(model)``: 180 epochs with
    a step of 60 when the ratio is at least 2, otherwise ``int(90 * ratio)``
    epochs with a step of one third of that.  It then runs either a single
    validation pass (``args.evaluate``) or the train / validate / checkpoint
    loop.

    Raises:
        ValueError: if ``args.model`` is neither ``'resnet-2x'`` nor
            ``'vgg-5x'``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)
    args.distributed = args.world_size > 1

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    #################################################################################
    if args.model == 'resnet-2x':
        model = models.resnet_2x()
        model_ref = models.resnet50_official()
    elif args.model == 'vgg-5x':
        model = models.vgg_5x()
        model_ref = models.vgg_official()
    else:
        # Without this branch an unknown name would surface later as a
        # NameError on `model_ref`, which hides the real cause.
        raise ValueError(
            "Unsupported model '{}': expected 'resnet-2x' or 'vgg-5x'".format(
                args.model))

    flops_std = count_model_param_flops(model_ref, 224)
    flops_small = count_model_param_flops(model, 224)
    ratio = flops_std / flops_small
    if ratio >= 2:
        args.epochs = 180
        step_size = 60
    else:
        args.epochs = int(90 * ratio)
        step_size = int(args.epochs / 3)
    #################################################################################

    if not args.distributed:
        if args.arch.startswith(('alexnet', 'vgg')):
            # Only the feature extractor is wrapped for these architectures.
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # Shuffling is delegated to the sampler whenever one is in use.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, step_size)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)
def main():
    """Train a weight-masked network from scratch with a FLOPs-scaled schedule.

    Reads the module-level ``args``.  Builds the data loaders, a network
    ``net`` and a reference ``net_ref`` (replaced by ``checkpoint['state_dict']``
    when ``args.resume`` points at a file), scales ``args.epochs`` and
    ``args.schedule`` by ``flops(net) / flops(net_ref)``, re-initialises every
    ``nn.Conv2d`` weight of ``net`` and zeroes the positions that are zero in
    the reference, then either validates once (``args.evaluate``) or runs the
    training loop, writing ``checkpoint.pth.tar`` after each epoch.

    The log file is opened in a ``with`` block so that it is closed on every
    exit path, including the early return taken by ``args.evaluate``.
    """
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log_file = os.path.join(args.save_path,
                            'log_seed_{}.txt'.format(args.manualSeed))
    with open(log_file, 'w') as log:
        print_log('save path : {}'.format(args.save_path), log)
        state = {k: v for k, v in args._get_kwargs()}
        print_log(state, log)
        print_log("Random Seed: {}".format(args.manualSeed), log)
        print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
        print_log("torch  version : {}".format(torch.__version__), log)
        print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)
        print_log("Compress Rate: {}".format(args.rate), log)
        print_log("Layer Begin: {}".format(args.layer_begin), log)
        print_log("Layer End: {}".format(args.layer_end), log)
        print_log("Layer Inter: {}".format(args.layer_inter), log)
        print_log("Epoch prune: {}".format(args.epoch_prune), log)

        # Init dataset
        if not os.path.isdir(args.data_path):
            os.makedirs(args.data_path)

        if args.dataset == 'cifar10':
            mean = [x / 255 for x in [125.3, 123.0, 113.9]]
            std = [x / 255 for x in [63.0, 62.1, 66.7]]
        elif args.dataset == 'cifar100':
            mean = [x / 255 for x in [129.3, 124.1, 112.4]]
            std = [x / 255 for x in [68.2, 65.4, 70.4]]
        else:
            assert False, "Unknow dataset : {}".format(args.dataset)

        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

        # NOTE(review): the normalisation chain above only accepts cifar10 and
        # cifar100, so the svhn / stl10 / imagenet branches below cannot be
        # reached while assertions are enabled.
        if args.dataset == 'cifar10':
            train_data = dset.CIFAR10(args.data_path, train=True,
                                      transform=train_transform, download=True)
            test_data = dset.CIFAR10(args.data_path, train=False,
                                     transform=test_transform, download=True)
            num_classes = 10
        elif args.dataset == 'cifar100':
            train_data = dset.CIFAR100(args.data_path, train=True,
                                       transform=train_transform, download=True)
            test_data = dset.CIFAR100(args.data_path, train=False,
                                      transform=test_transform, download=True)
            num_classes = 100
        elif args.dataset == 'svhn':
            train_data = dset.SVHN(args.data_path, split='train',
                                   transform=train_transform, download=True)
            test_data = dset.SVHN(args.data_path, split='test',
                                  transform=test_transform, download=True)
            num_classes = 10
        elif args.dataset == 'stl10':
            train_data = dset.STL10(args.data_path, split='train',
                                    transform=train_transform, download=True)
            test_data = dset.STL10(args.data_path, split='test',
                                   transform=test_transform, download=True)
            num_classes = 10
        elif args.dataset == 'imagenet':
            assert False, 'Do not finish imagenet code'
        else:
            assert False, 'Do not support dataset : {}'.format(args.dataset)

        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_data,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)

        print_log("=> creating model '{}'".format(args.arch), log)
        # Init model, criterion, and optimizer
        net = models.__dict__[args.arch](num_classes)
        net_ref = models.__dict__[args.arch](num_classes)
        print_log("=> network :\n {}".format(net), log)

        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
        net_ref = torch.nn.DataParallel(net_ref,
                                        device_ids=list(range(args.ngpu)))

        # define loss function (criterion) and optimizer
        criterion = torch.nn.CrossEntropyLoss()

        optimizer = torch.optim.SGD(net.parameters(),
                                    state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)

        if args.use_cuda:
            net.cuda()
            criterion.cuda()

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print_log("=> loading checkpoint '{}'".format(args.resume), log)
                checkpoint = torch.load(args.resume)
                # NOTE(review): net_ref is rebound to the stored object itself
                # and later used via .modules(), so the checkpoint presumably
                # holds a whole module under 'state_dict' - confirm against
                # the script that wrote it.
                net_ref = checkpoint['state_dict']
                print_log(
                    "=> loaded checkpoint '{}' (epoch {})".format(
                        args.resume, checkpoint['epoch']), log)
            else:
                print_log("=> no checkpoint found at '{}'".format(args.resume),
                          log)
        else:
            print_log(
                "=> do not use any checkpoint for {} model".format(args.arch),
                log)

        # Scale the training budget by the FLOPs ratio of the two networks.
        flops_std = count_model_param_flops(net, 32)
        flops_small = count_model_param_flops(net_ref, 32)
        ratio = flops_std / flops_small
        args.epochs = int(400 * ratio)
        print("Total epochs %d" % args.epochs)
        schedule = args.schedule
        args.schedule = [
            1,
            int(schedule[1] * ratio),
            int(schedule[2] * ratio),
            int(schedule[3] * ratio)
        ]
        print(args.schedule)

        recorder = RecorderMeter(args.epochs)

        ###################################################################################################################
        # Re-initialise each conv layer, then zero the positions that are
        # zero in the reference network.
        for m, m_ref in zip(net.modules(), net_ref.modules()):
            if isinstance(m, nn.Conv2d):
                weight_copy = m_ref.weight.data.abs().clone()
                mask = weight_copy.gt(0).float().cuda()
                n = mask.sum() / float(m.in_channels)
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.weight.data.mul_(mask)
        ###################################################################################################################

        if args.evaluate:
            time1 = time.time()
            validate(test_loader, net, criterion, log)
            time2 = time.time()
            print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
            return

        m = Mask(net)

        m.init_length()

        comp_rate = args.rate
        print("-" * 10 + "one epoch begin" + "-" * 10)
        print("the compression rate now is %f" % comp_rate)

        val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

        print(" accu before is: %.3f %%" % val_acc_1)

        if args.use_cuda:
            net = net.cuda()
        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)
        print(" accu after is: %s %%" % val_acc_2)

        # Main loop
        start_time = time.time()
        epoch_time = AverageMeter()
        for epoch in range(args.start_epoch, args.epochs):
            current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                         args.gammas,
                                                         args.schedule)

            need_hour, need_mins, need_secs = convert_secs2time(
                epoch_time.avg * (args.epochs - epoch))
            need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
                need_hour, need_mins, need_secs)

            print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                      + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

            num_parameters = get_conv_zero_param(net)
            print_log('Zero parameters: {}'.format(num_parameters), log)
            num_parameters = sum(
                [param.nelement() for param in net.parameters()])
            print_log('Parameters: {}'.format(num_parameters), log)

            # train for one epoch
            train_acc, train_los = train(train_loader, net, criterion,
                                         optimizer, epoch, log)

            # evaluate on validation set
            val_acc, val_los = validate(test_loader, net, criterion, log)

            # Record this epoch's validation result.  The pre-loop values
            # (val_los_2 / val_acc_2) must not be reused here, otherwise the
            # recorder and `is_best` never reflect training progress.
            is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                      val_acc)

            save_checkpoint(
                {
                    'arch': args.arch,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save_path, 'checkpoint.pth.tar')

            # measure elapsed time
            epoch_time.update(time.time() - start_time)
            start_time = time.time()
# Tail of a pruning script; the statements that define `i`, `m`, `cfg`,
# `newmodel`, `output_name` and `savepath` are outside this excerpt.
#
# Visible steps, in order:
#   1. Take entry `i` of `bn3_masks`, check that its second element has as
#      many entries as `m.expand_layer.idx`, and replace that index array
#      with the positions where the mask is non-zero.
#   2. Save `cfg`, the new model's state dict and `bn3_masks` to
#      `<args.save>/<output_name>.pth.tar`.
#   3. Count FLOPs of the new model on the GPU with an input size of 224,
#      print the summary from `pruning_summary_resnet50`, and run `test`.
#   4. Append the FLOPs figure and the summary to the file at `savepath`.
#
# NOTE(review): `.squeeze()` yields a 0-d array when exactly one mask entry is
# non-zero; a sibling script appends `.reshape(-1)` for that case - confirm
# whether the same is needed here.
# NOTE(review): the two banner lines are written without a trailing newline,
# so the summary starts on the same line as the banner unless `summary`
# begins with one - confirm.
mask = bn3_masks[i] assert mask[1].shape[0] == m.expand_layer.idx.shape[0] m.expand_layer.idx = np.argwhere( mask[1].clone().cpu().numpy()).squeeze() torch.save( { 'cfg': cfg, 'state_dict': newmodel.state_dict(), "bn3_masks": bn3_masks }, os.path.join(args.save, '{}.pth.tar'.format(output_name))) # print(newmodel) model = newmodel flops = count_model_param_flops(model.cuda(), 224) print("FLOPs after pruning: {}".format(flops)) summary = pruning_summary_resnet50(model, False) print(summary) # evaluate model test(model, args) with open(savepath, "a") as fp: fp.write("FLOPs after pruning: {} \n".format(flops)) fp.write("\n\n\n") fp.write("************MODEL SUMMARY************") fp.write(summary) fp.write("*************************************")
# Tail of a pruning script; the loop that owns the leading `continue` and the
# definitions of `i`, `m`, `cfg`, `model`, `newmodel`, `num_classes`,
# `test_loader`, `output_name` and `savepath` are outside this excerpt.
#
# Visible steps, in order:
#   1. Take entry `i` of `bn2_masks`, check that its second element has as
#      many entries as `m.expand_layer.idx`, and replace that index array with
#      the flattened positions where the mask is non-zero (`reshape(-1)` keeps
#      the result 1-D even for a single index).
#   2. Save `cfg`, the new model's state dict and the masks to
#      `<args.save>/<output_name>.pth.tar`.
#   3. Switch `enable_aux_fc` off on both models, count FLOPs of the old and
#      the new model on the CPU with an input size of 32, print the summary
#      from `pruning_summary_resnet56`, and run `test` on the new model.
#   4. Append both FLOPs figures and the summary to the file at `savepath`.
#
# NOTE(review): the masks are stored under the key "bn3_masks" although the
# variable is `bn2_masks` - presumably for compatibility with a loader that
# expects that key; confirm.
# NOTE(review): the "Pruned completed" message formats `load_acc`, while
# `pruned_acc` is assigned just above and never used in this excerpt - this
# looks like the wrong variable; confirm before changing.
continue mask = bn2_masks[i] assert mask[1].shape[0] == m.expand_layer.idx.shape[0] m.expand_layer.idx = np.argwhere( mask[1].clone().cpu().numpy()).squeeze().reshape(-1) torch.save( { 'cfg': cfg, 'state_dict': newmodel.state_dict(), "bn3_masks": bn2_masks }, os.path.join(args.save, '{}.pth.tar'.format(output_name))) model.enable_aux_fc = False newmodel.enable_aux_fc = False flops_ref = count_model_param_flops(model.cpu(), 32) model = newmodel flops = count_model_param_flops(model.cpu(), 32) summary = pruning_summary_resnet56(model, num_classes=num_classes) print(summary) pruned_acc = test(model, test_loader) print("=> Pruned completed. Test acc: {}".format(load_acc)) with open(savepath, "a") as fp: fp.write("FLOPs before pruning: {} \n".format(flops_ref)) fp.write("FLOPs after pruning: {} \n".format(flops)) fp.write("\n\n\n") fp.write("************MODEL SUMMARY************") fp.write(summary)
def main():
    """Train an ImageNet model from scratch under a weight mask from a reference.

    Parses the command line into the module-level ``args``, optionally seeds
    the RNGs, creates ``args.save`` if missing, builds ``model`` and
    ``model_ref`` from ``models.__dict__[args.arch]``, scales ``args.epochs``
    by ``flops(model) / flops(model_ref)``, loads ``args.resume`` into
    ``model_ref`` when it is a file, re-initialises every ``nn.Conv2d`` weight
    of ``model`` and zeroes the positions that are zero in ``model_ref``, then
    runs a single validation pass (``args.evaluate``) or the
    train / validate / checkpoint loop.
    """
    global args, best_prec1
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if not os.path.exists(args.save):
        # Was `os.maskdit`, which does not exist in `os` and raised
        # AttributeError whenever the directory was missing.
        os.makedirs(args.save)

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
        model_ref = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
        model_ref = models.__dict__[args.arch]()

    ######################################################################################################
    # NOTE(review): both networks come from the same constructor and FLOPs are
    # counted before the checkpoint is loaded into model_ref below; whether
    # this ratio can differ from 1 depends on count_model_param_flops -
    # confirm the intended ordering.
    flops_std = count_model_param_flops(model)
    flops_small = count_model_param_flops(model_ref)
    args.epochs = int(90 * flops_std / flops_small)
    step_size = int(args.epochs / 3)
    print("Scratch-B training total epochs %d" % args.epochs)
    ######################################################################################################

    if args.gpu is not None:
        model = model.cuda(args.gpu)
        model_ref = model_ref.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model_ref.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
        model_ref = torch.nn.parallel.DistributedDataParallel(model_ref)
    else:
        if args.arch.startswith(('alexnet', 'vgg')):
            # Only the feature extractor is wrapped for these architectures.
            model.features = torch.nn.DataParallel(model.features)
            model_ref.features = torch.nn.DataParallel(model_ref.features)
            model.cuda()
            model_ref.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
            model_ref = torch.nn.DataParallel(model_ref).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model_ref.load_state_dict(checkpoint['state_dict'])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set some weights to zero, according to model_ref ---------------------------------
    for m, m_ref in zip(model.modules(), model_ref.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m_ref.weight.data.abs().clone()
            mask = weight_copy.gt(0).float().cuda()
            n = mask.sum() / float(m.in_channels)
            m.weight.data.normal_(0, math.sqrt(2. / n))
            m.weight.data.mul_(mask)
    # ----------------------------------------------------------------------------------

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # Shuffling is delegated to the sampler whenever one is in use.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, step_size)

        #####################################################################################################
        num_parameters = get_conv_zero_param(model)
        print('Zero parameters: {}'.format(num_parameters))
        num_parameters = sum(
            [param.nelement() for param in model.parameters()])
        print('Parameters: {}'.format(num_parameters))
        #####################################################################################################

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.save)
    return
# Excerpt of a pruning script followed by the first lines of `accuracy`; the
# definitions of `model`, `thre`, `cfg`, `pruned` and `total` precede this
# excerpt and the body of `accuracy` continues past it.
#
# Visible steps, in order:
#   1. For every BatchNorm2d / BatchNorm1d module, build a 0/1 mask of the
#      scale weights whose absolute value exceeds `thre`, multiply weight and
#      bias by it in place, add the number of masked entries to `pruned`, and
#      append the surviving count to `cfg` and the mask to `cfg_mask`.
#      MaxPool2d modules append the marker 'M' to `cfg`.
#   2. Call `compute_flops.count_model_param_flops`, save `cfg` and the
#      model's state dict to `<args.save>/pruned.pth.tar`, and compute
#      `pruned_ratio`.
#
# NOTE(review): the FLOPs call passes `model=None` and discards the result -
# presumably `model=model` was intended; confirm against compute_flops.
# NOTE(review): `pruned_ratio` is not used anywhere in this excerpt.
# NOTE(review): `accuracy` is cut off after `correct` is computed, so its
# return value is not visible here.
cfg_mask = [] for k, m in enumerate(model.modules()): if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): weight_copy = m.weight.data.abs().clone() mask = weight_copy.gt(thre) mask = mask.float().cuda() pruned = pruned + mask.shape[0] - torch.sum(mask) m.weight.data.mul_(mask) m.bias.data.mul_(mask) cfg.append(int(torch.sum(mask))) cfg_mask.append(mask.clone()) print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'. format(k, mask.shape[0], int(torch.sum(mask)))) elif isinstance(m, nn.MaxPool2d): cfg.append('M') compute_flops.count_model_param_flops(model=None, input_res=224, multiply_adds=False) torch.save({'cfg': cfg, 'state_dict': model.state_dict()}, os.path.join(args.save, 'pruned.pth.tar')) pruned_ratio = pruned/total print('Pre-processing Successful!') def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred))
def main():
    """Run the prune-while-training loop configured by the module-level ``args``.

    Builds the CIFAR data loaders and the network chosen by ``args.arch``,
    optionally restores ``checkpoint.pth.tar`` from the ``args.resume``
    directory, applies the mask produced by ``mask.Mask`` once up front and
    again every ``args.epoch_prune`` epochs (and on the last epoch), and
    finally saves the network to ``args.model_save`` and reports its
    parameter count and FLOPs.

    Early exits: after validation when ``args.evaluate`` is set and a
    checkpoint was loaded, and after ``original_train.main()`` when
    ``args.original_train`` is set.  The log file is opened in a ``with``
    block so it is closed on all of these paths.
    """
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    if args.resume:
        if not os.path.isdir(args.resume):
            os.makedirs(args.resume)

    # save_checkpoint() receives the directory and file name separately, so
    # the load path is built with os.path.join as well; plain concatenation
    # only matched when args.resume ended with a path separator.
    checkpoint_path = (os.path.join(args.resume, 'checkpoint.pth.tar')
                       if args.resume else None)

    log_file = os.path.join(args.save_path, '{}.txt'.format(args.description))
    with open(log_file, 'w') as log:
        print_log('save path : {}'.format(args.save_path), log)
        state = {k: v for k, v in args._get_kwargs()}
        print_log(state, log)
        print_log("Random Seed: {}".format(args.manualSeed), log)
        print_log("use cuda: {}".format(args.use_cuda), log)
        print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
        print_log("torch  version : {}".format(torch.__version__), log)
        print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)
        print_log("Compress Rate: {}".format(args.rate), log)
        print_log("Epoch prune: {}".format(args.epoch_prune), log)
        print_log("description: {}".format(args.description), log)

        # Init data loader
        if args.dataset == 'cifar10':
            train_loader = dataset.cifar10DataLoader(True, args.batch_size, True, args.workers)
            test_loader = dataset.cifar10DataLoader(False, args.batch_size, False, args.workers)
            num_classes = 10
        elif args.dataset == 'cifar100':
            train_loader = dataset.cifar100DataLoader(True, args.batch_size, True, args.workers)
            test_loader = dataset.cifar100DataLoader(False, args.batch_size, False, args.workers)
            num_classes = 100
        elif args.dataset == 'imagenet':
            assert False, 'Do not finish imagenet code'
        else:
            assert False, 'Do not support dataset : {}'.format(args.dataset)

        # Init model
        if args.arch == 'cifarvgg16':
            net = models.vgg16_cifar(True, num_classes)
        elif args.arch == 'resnet32':
            net = models.resnet32(num_classes)
        elif args.arch == 'resnet56':
            net = models.resnet56(num_classes)
        elif args.arch == 'resnet110':
            net = models.resnet110(num_classes)
        else:
            assert False, 'Not finished'

        print_log("=> network:\n {}".format(net), log)
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
        # define loss function (criterion) and optimizer
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)
        if args.use_cuda:
            net.cuda()
            criterion.cuda()

        recorder = RecorderMeter(args.epochs)
        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(checkpoint_path):
                print_log("=> loading checkpoint '{}'".format(checkpoint_path), log)
                checkpoint = torch.load(checkpoint_path)
                recorder = checkpoint['recorder']
                args.start_epoch = checkpoint['epoch']
                if args.use_state_dict:
                    net.load_state_dict(checkpoint['state_dict'])
                else:
                    # NOTE(review): the optimizer above was built from the
                    # previous `net`'s parameters; after this rebinding it
                    # presumably no longer updates the loaded network -
                    # confirm.
                    net = checkpoint['state_dict']
                optimizer.load_state_dict(checkpoint['optimizer'])
                print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)

                if args.evaluate:
                    time1 = time.time()
                    validate(test_loader, net, criterion, args.use_cuda, log)
                    time2 = time.time()
                    print('validate function took %0.3f ms' % ((time2 - time1) * 1000.0))
                    return
            else:
                print_log("=> no checkpoint found at '{}'".format(args.resume), log)
        else:
            print_log("=> not use any checkpoint for {} model".format(args.description), log)

        if args.original_train:
            original_train.args.arch = args.arch
            original_train.args.dataset = args.dataset
            original_train.main()
            return

        comp_rate = args.rate
        m = mask.Mask(net, args.use_cuda)
        print("-" * 10 + "one epoch begin" + "-" * 10)
        print("the compression rate now is %f" % comp_rate)

        val_acc_1, val_los_1 = validate(test_loader, net, criterion, args.use_cuda, log)
        print(" accu before is: %.3f %%" % val_acc_1)

        m.model = net
        print('before pruning')
        m.init_mask(comp_rate, args.last_index)
        m.do_mask()
        print('after pruning')
        m.print_weights_zero()
        net = m.model  # update net

        if args.use_cuda:
            net = net.cuda()
        val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)
        print(" accu after is: %.3f %%" % val_acc_2)

        start_time = time.time()
        epoch_time = AverageMeter()
        for epoch in range(args.start_epoch, args.epochs):
            current_learning_rate = adjust_learning_rate(args.learning_rate, optimizer, epoch, args.gammas, args.schedule)
            need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
            need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
            print_log(
                '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                       need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}]'.format(recorder.max_accuracy(False)), log)
            train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, args.use_cuda, log)
            validate(test_loader, net, criterion, args.use_cuda, log)
            if (epoch % args.epoch_prune == 0 or epoch == args.epochs - 1):
                m.model = net
                print('before pruning')
                m.print_weights_zero()
                m.init_mask(comp_rate, args.last_index)
                m.do_mask()
                print('after pruning')
                m.print_weights_zero()
                net = m.model
                if args.use_cuda:
                    net = net.cuda()

            # The recorded accuracy is the one measured after the mask step.
            val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)

            is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)
            if args.resume:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': net,
                    'recorder': recorder,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.resume, 'checkpoint.pth.tar')
                print('save ckpt done')

            epoch_time.update(time.time() - start_time)
            start_time = time.time()
        torch.save(net, args.model_save)
        flops.print_model_param_nums(net)
        flops.count_model_param_flops(net, 32, False)
# Tail of a pruning script; the enclosing loop and the definitions of `idx0`,
# `end_mask`, `m0`, `m1`, `layer_id_in_cfg`, `cfg_mask`, `cfg`, `model` and
# `newmodel` are outside this excerpt.
#
# Visible steps, in order:
#   1. Turn `end_mask` into an array of its non-zero positions (`idx1`);
#      `idx0` and `idx1` are resized to shape (1,) when they hold a single
#      index so that `.tolist()` yields a list.
#   2. Select columns `idx0` of `m0.weight`; unless `layer_id_in_cfg` equals
#      `len(cfg_mask)`, also select rows `idx1` of the weight and entries
#      `idx1` of the bias.  Otherwise the bias is copied unchanged.
#   3. Check that the shapes match `m1` and copy weight and bias into it.
#   4. Save `cfg` and the new model's state dict to
#      `<args.save>/pruned.pth.tar`, print the model, and print the result of
#      `test(newmodel)`.
#   5. Print the ratio of the new model's FLOPs to the original model's, both
#      counted with an input size of 32.
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) if idx0.size == 1: idx0 = np.resize(idx0, (1, )) if idx1.size == 1: idx1 = np.resize(idx1, (1, )) w1 = m0.weight.data[:, idx0.tolist()].clone() if layer_id_in_cfg != len(cfg_mask): w1 = w1[idx1.tolist(), :].clone() bias1 = m0.bias.data[idx1.tolist()].clone() else: bias1 = m0.bias.data.clone() assert m1.weight.data.shape == w1.shape assert m1.bias.data.shape == bias1.shape m1.weight.data = w1.clone() m1.bias.data = bias1.clone() torch.save({ 'cfg': cfg, 'state_dict': newmodel.state_dict() }, os.path.join(args.save, 'pruned.pth.tar')) print(newmodel) pruned_acc = test(newmodel) print("Accuracy after pruning: {}".format(pruned_acc)) # calculate FLOPs base_flops = count_model_param_flops(model, 32) pruned_flops = count_model_param_flops(newmodel, 32) flops_ratio = pruned_flops / base_flops print("Pruning FLOPs: {}".format(flops_ratio))
class AverageMeter(object):
    """Keep the most recent value and the weighted running mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, mean, weighted total and sample count to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as seen ``n`` times and recompute the mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


print("Starting evaluating...")
# acc = test()
print("Skip evaluation. Aborted.")

print("Computing FLOPs...")
print("cfg: ", cfg)

# Count FLOPs on the GPU with an input size of 224: pruned model first, then
# the unpruned one.
flops = count_model_param_flops(new_model.cuda(), 224)
flops_unpruned = count_model_param_flops(model.cuda(), 224)
print("FLOPs after pruning: {}".format(flops))
print("FLOPs Unpruned: {}".format(flops_unpruned))