def test():
    """Smoke test: build an AMC-configured MobileNet and report its
    parameter count and FLOPs via the compute_flops helpers."""
    from compute_flops import print_model_param_nums, print_model_param_flops

    model = MobileNet(amc=True)
    print_model_param_nums(model)
    print_model_param_flops(model)
def test(epoch):
    """Evaluate the global `net` on `testloader`, print size/FLOP stats, and
    checkpoint the weights when accuracy improves on the best seen so far.

    Args:
        epoch: current epoch index, stored in the checkpoint for resuming.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    # Parameter count in millions.
    # Fix: format spec was `{n_params:2f}` (min-width 2, default 6 decimals);
    # `.2f` gives the intended two-decimal output.
    n_params = sum(p.numel() for p in net.parameters()) / 10**6
    print(f'Total params: {n_params:.2f}M')
    print_model_param_flops(net, 32)  # FLOPs at 32x32 (CIFAR-style) input
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader),
                         'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                         (test_loss / (batch_idx + 1), 100. * correct / total,
                          correct, total))

    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the isdir-then-mkdir race of the original.
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
def print_statistics(self):
    """Print per-group parameter counts (in M) for every grouped sub-model,
    then the average FLOPs (G) and average params (M) across all groups.

    Reads `self.group_info` (group ids) and `self.model_list` (sub-models),
    which are iterated in lockstep.
    """
    num_params = []
    num_flops = []
    print("\n===== Metrics for grouped model ==========================\n")
    for group_id, model in zip(self.group_info, self.model_list):
        n_params = sum(p.numel() for p in model.parameters()) / 10**6
        num_params.append(n_params)
        # Fix: format spec was `:2f` (min-width 2), intended `.2f`.
        print(f'Grouped model for Class {group_id} '
              f'Total params: {n_params:.2f}M')
        num_flops.append(print_model_param_flops(model, 32))  # 32x32 input
    # Guard against an empty model list (avoids ZeroDivisionError).
    if num_flops:
        # Fix: `:3f` -> `.3f` for the intended three-decimal output.
        print(
            f"Average number of flops: {sum(num_flops) / len(num_flops) / 10**9:.3f} G"
        )
        print(f"Average number of param: {sum(num_params) / len(num_params)} M")
# NOTE(review): excerpt of a training script. The `if` matching the `else:`
# below is outside this view, and save_checkpoint() is cut off after its
# `elif` — indentation here is reconstructed and should be confirmed.
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=args.gpu_ids)
    # NOTE(review): `== True` comparison — `if args.swa:` would be idiomatic.
    if args.swa == True:
        swa_model = torch.nn.parallel.DistributedDataParallel(
            swa_model, device_ids=args.gpu_ids)
else:
    # Fallback branch: build the model from the registry by dataset/depth.
    model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
    if args.cuda:
        model.cuda()
    if len(args.gpu_ids) > 1:
        # model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=args.gpu_ids)

if args.dataset == 'imagenet':
    # ImageNet models are measured at 224x224 input resolution.
    pruned_flops = print_model_param_flops(model, 224)

optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)


def save_checkpoint(state, is_best, epoch, filepath, is_swa):
    # SWA checkpoints get a fixed file name; others are named by epoch tag.
    # NOTE(review): `is_best` is unused in the portion visible here.
    if is_swa:
        torch.save(state, os.path.join(filepath, 'swa.pth.tar'))
    else:
        if epoch == 'init':
            filepath = os.path.join(filepath, 'init.pth.tar')
            torch.save(state, filepath)
        elif 'EB' in str(epoch):
# NOTE(review): excerpt — the enclosing test() definition and its data loop
# start before this view; indentation is reconstructed.
        output = model(data)
        # NOTE(review): `size_average=False` is deprecated in modern PyTorch;
        # the replacement is `reduction='sum'` — confirm the torch version.
        test_loss += F.cross_entropy(
            output, target, size_average=False).data  # sum up batch loss
        pred = output.data.max(
            1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().numpy().sum()

    test_loss /= len(test_loader.dataset)
    #print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
    #    test_loss, correct, len(test_loader.dataset),
    #    100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


# Script body: evaluate the loaded model and report its size/FLOP statistics
# (FLOPs measured on CPU at 32x32 input).
acc = test(model)
total_params = print_model_param_nums(model.cpu())
total_flops = print_model_param_flops(model.cpu(), 32)
results = {
    'load': args.load,
    'dataset': args.dataset,
    'model_name': args.model_name,
    'arch': 'mobilenetv1',
    'acc': acc,
    'cfg': model.cfg,
    'total_params': total_params,
    'total_flops': total_flops,
}
print(results)
# define loss function (criterion) and optimizer num_classes = 1000 # Data loading code train_loader, val_loader = \ get_data_loader(args.data, train_batch_size=args.batch_size, test_batch_size=args.test_batch_size, workers=args.workers) ## loading pretrained model ## assert args.load assert os.path.isfile(args.load) print("=> loading checkpoint '{}'".format(args.load)) checkpoint = torch.load(args.load) model = mbnet(cfg=checkpoint['cfg']) total_flops = print_model_param_flops(model, 224, multiply_adds=False) print(total_flops) if args.use_cuda: model.cuda() selected_model_keys = [k for k in model.state_dict().keys() if not (k.endswith('.y') or k.endswith('.v') or k.startswith('net_params') or k.startswith('y_params') or k.startswith('v_params'))] saved_model_keys = checkpoint['state_dict'] from collections import OrderedDict new_state_dict = OrderedDict() if len(selected_model_keys) == len(saved_model_keys): for k0, k1 in zip(selected_model_keys, saved_model_keys): new_state_dict[k0] = checkpoint['state_dict'][k1] model_dict = model.state_dict()
# NOTE(review): excerpt — the first two lines are the tail of an earlier
# parser.add_argument(...) call whose start is outside this view.
                    metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if not os.path.exists(args.save):
    os.makedirs(args.save)

# Build the ResNet to prune and report its baseline FLOPs (32x32 input).
model = resnet(depth=args.depth, dataset=args.dataset)
total_flops = print_model_param_flops(model, input_res=32)
if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(
            args.model, checkpoint['epoch'], best_prec1))
    else:
        # NOTE(review): this reports args.resume although the path actually
        # checked is args.model — likely a copy-paste bug; confirm and fix.
        print("=> no checkpoint found at '{}'".format(args.resume))
# Optionally load a pretrained checkpoint into the model before pruning.
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = fix_robustness_ckpt(torch.load(args.model))
        # Loaded non-strictly: checkpoint may lack auxiliary keys.
        model.load_state_dict(checkpoint, strict=False)
    else:
        # Bug fix: the message previously interpolated args.resume even
        # though the path being checked is args.model.
        print("=> no checkpoint found at '{}'".format(args.model))
        exit()

# Report baseline size/FLOPs; ImageNet models are measured at 224x224 input,
# everything else (CIFAR-style) at 32x32. The two original branches differed
# only in this resolution, so they are collapsed here.
input_res = 224 if args.dataset == 'imagenet' else 32
print('original model param: ', print_model_param_nums(model))
print('original model flops: ', print_model_param_flops(model, input_res, True))

if args.cuda:
    model.cuda()

# Count the total number of BatchNorm channels so every scale factor can be
# gathered into a single vector for global threshold selection.
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
index = 0
# NOTE(review): excerpt — the `if` matching the `else:` below (presumably the
# refine/scratch branch of this script) is outside this view; indentation is
# reconstructed.
    model = models.__dict__[args.arch](pretrained=False, cfg=cfg_input)
    if args.cuda:
        model.cuda()
    if len(args.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)
        # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.gpu_ids, find_unused_parameters=True)
else:
    model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
    if args.cuda:
        model.cuda()
    if len(args.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)
        # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.gpu_ids, find_unused_parameters=True)

if args.dataset == 'imagenet':
    # FLOPs are computed on CPU, then the model is moved back to GPU.
    pruned_flops = print_model_param_flops(model.cpu(), 224)
    model.cuda()

optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)


def save_checkpoint(state, is_best, epoch, filepath):
    """Save a checkpoint; the file name depends on the epoch tag.

    'init' -> init.pth.tar, tags containing 'EB' -> <tag>.pth.tar,
    anything else -> ckpt<epoch>.pth.tar.
    NOTE(review): `is_best` is unused here.
    """
    if epoch == 'init':
        filepath = os.path.join(filepath, 'init.pth.tar')
        torch.save(state, filepath)
    elif 'EB' in str(epoch):
        filepath = os.path.join(filepath, epoch+'.pth.tar')
        torch.save(state, filepath)
    else:
        filename = os.path.join(filepath, 'ckpt'+str(epoch)+'.pth.tar')
        torch.save(state, filename)
def main():
    """Entry point: load a pruned MobileNetV2 checkpoint, then finetune or
    retrain it on ImageNet; saves checkpoints next to the loaded model."""
    global args, best_prec1, device
    args = parser.parse_args()
    # Linear LR scaling with the effective (multi-GPU) batch size.
    batch_size = args.batch_size * max(1, args.num_gpus)
    args.lr = args.lr * (batch_size / 256.)
    print(batch_size, args.lr, args.num_gpus)
    num_classes = 1000
    num_training_samples = 1281167  # ImageNet train-set size
    args.num_batches_per_epoch = num_training_samples // batch_size

    # Outputs go next to the loaded checkpoint, under retrain/ or finetune/.
    assert os.path.isfile(args.load) and args.load.endswith(".pth.tar")
    args.save = os.path.dirname(args.load)
    training_mode = 'retrain' if args.retrain else 'finetune'
    args.save = os.path.join(args.save, training_mode)
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    args.model_save_path = os.path.join(
        args.save,
        "epochs_{}_{}".format(args.epochs, os.path.basename(args.load)))

    args.distributed = args.world_size > 1
    ##########################################################
    ## create file handler which logs even debug messages
    #import logging
    #log = logging.getLogger()
    #log.setLevel(logging.INFO)
    #ch = logging.StreamHandler()
    #fh = logging.FileHandler(args.logging_file_path)
    #formatter = logging.Formatter('%(asctime)s - %(message)s')
    #ch.setFormatter(formatter)
    #fh.setFormatter(formatter)
    #log.addHandler(fh)
    #log.addHandler(ch)
    ##########################################################
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # Use CUDA
    args.use_cuda = torch.cuda.is_available() and not args.no_cuda

    # Random seed (fixed for reproducibility)
    random.seed(0)
    torch.manual_seed(0)
    if args.use_cuda:
        torch.cuda.manual_seed_all(0)
        device = 'cuda'
        cudnn.benchmark = True
    else:
        device = 'cpu'
    if args.evaluate == 1:
        device = 'cuda:0'

    # Rebuild the pruned architecture from the checkpoint's cfg, then report
    # its size/FLOPs (measured on CPU at 224x224, counting multiply-adds once).
    assert os.path.isfile(args.load)
    print("=> loading checkpoint '{}'".format(args.load))
    checkpoint = torch.load(args.load)
    model = mobilenetv2(cfg=checkpoint['cfg'])
    cfg = model.cfg
    total_params = print_model_param_nums(model.cpu())
    total_flops = print_model_param_flops(model.cpu(), 224, multiply_adds=False)
    print(total_params, total_flops)

    if not args.distributed:
        model = torch.nn.DataParallel(model).to(device)
    else:
        model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(model)

    ##### finetune #####
    # Retraining starts from scratch; finetuning restores the saved weights.
    if not args.retrain:
        model.load_state_dict(checkpoint['state_dict'])

    # define loss function (criterion) and optimizer
    if args.label_smoothing:
        criterion = CrossEntropyLabelSmooth(num_classes).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    ### all parameter ####
    # BatchNorm and bias parameters are excluded from weight decay.
    no_wd_params, wd_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if ".bn" in name or '.bias' in name:
                no_wd_params.append(param)
            else:
                wd_params.append(param)
    no_wd_params = nn.ParameterList(no_wd_params)
    wd_params = nn.ParameterList(wd_params)
    optimizer = torch.optim.SGD([
        {
            'params': no_wd_params,
            'weight_decay': 0.
        },
        {
            'params': wd_params,
            'weight_decay': args.weight_decay
        },
    ], args.lr, momentum=args.momentum)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_save_path):
            print("=> loading checkpoint '{}'".format(args.model_save_path))
            checkpoint = torch.load(args.model_save_path)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_save_path, checkpoint['epoch']))
        else:
            # Resume requested but no checkpoint present: silently start fresh.
            pass

    # Data loading code
    train_loader, val_loader = \
        get_data_loader(args.data,
                        train_batch_size=batch_size,
                        test_batch_size=32,
                        workers=args.workers)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # NOTE(review): train_sampler is not defined in this excerpt —
            # confirm it exists at module level before running distributed.
            train_sampler.set_epoch(epoch)
        #adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'cfg': cfg,
                #'m': args.m,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, args.model_save_path)

    print(' + Number of params: %.3fM' % (total_params / 1e6))
    print(' + Number of FLOPs: %.3fG' % (total_flops / 1e9))
# NOTE(review): MXNet/Gluon pruning excerpt — the surrounding loop over
# (mm0, mm1) module pairs and the Conv2D branch header start before this
# view; indentation is reconstructed.
                if idx0.size == 1:
                    idx0 = np.resize(idx0, (1, ))
                if idx1.size == 1:
                    idx1 = np.resize(idx1, (1, ))
                # Slice the conv weight along input then output channels.
                w1 = mm0.weight._data[0][:, idx0.tolist(), :, :]
                w1 = w1[idx1.tolist(), :, :, :]
                params[mm1.weight.name] = w1
            elif isinstance(mm0, nn.Dense):
                if layer_id_in_cfg == len(cfg_mask):
                    # First Dense after the last pruned conv: keep only the
                    # surviving input features from the last mask.
                    idx0 = np.squeeze(
                        np.argwhere(np.asarray(cfg_mask[-1].asnumpy())))
                    if idx0.size == 1:
                        idx0 = np.resize(idx0, (1, ))
                    params[mm1.weight.name] = mm0.weight._data[0][:, idx0]
                    params[mm1.bias.name] = mm0.bias._data[0]
                    layer_id_in_cfg += 1
                    continue
                # Later Dense layers are copied unchanged.
                params[mm1.weight.name] = mm0.weight._data[0]
                params[mm1.bias.name] = mm0.bias._data[0]

#print(params)
# Persist the pruned parameters, reload them into the new model, and report
# accuracy plus size/FLOP statistics.
pruned_model = '%s/%s-%s-pruned.params' % (args.save, args.dataset, model_name)
mxnet.ndarray.save(pruned_model, params)
newmodel.collect_params().load(pruned_model, ctx=context)
acc = test(newmodel)
num_parameters, flops = print_model_param_flops(newmodel, input_res=32)
print('\nTest-set accuracy after pruning: ', acc)
criterion = nn.CrossEntropyLoss().cuda() # Data loading code train_loader, val_loader = \ get_data_loader(args.data, train_batch_size=args.batch_size, test_batch_size=16, workers=args.workers) ## loading pretrained model ## assert args.load assert os.path.isfile(args.load) print("=> loading checkpoint '{}'".format(args.load)) checkpoint = torch.load(args.load) model = mbnet(cfg=checkpoint['cfg']) total_params = print_model_param_nums(model) total_flops = print_model_param_flops(model, 224, multiply_adds=False) print(total_params, total_flops) if args.use_cuda: model.cuda() selected_model_keys = [k for k in model.state_dict().keys() if not (k.endswith('.y') or k.endswith('.v') or k.startswith('net_params') or k.startswith('y_params') or k.startswith('v_params'))] saved_model_keys = checkpoint['state_dict'] from collections import OrderedDict new_state_dict = OrderedDict() if len(selected_model_keys) == len(saved_model_keys): for k0, k1 in zip(selected_model_keys, saved_model_keys): new_state_dict[k0] = checkpoint['state_dict'][k1] model_dict = model.state_dict()
# NOTE(review): excerpt — this weight-init code continues an isinstance
# chain (the Conv2d branch header is outside this view); indentation is
# reconstructed.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # He/Kaiming-style init for conv layers.
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # NOTE(review): `n` is assigned but unused here — confirm intent.
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def sp_mbnetv2(**kwargs):
    """
    Constructs a MobileNet V2 model
    """
    return SpMobileNetV2(**kwargs)


if __name__ == '__main__':
    # Smoke test: forward a dummy batch, then report the config and FLOPs.
    net = sp_mbnetv2()
    # NOTE(review): torch.autograd.Variable is a no-op wrapper in modern
    # PyTorch — a plain tensor would do.
    x = Variable(torch.FloatTensor(2, 3, 224, 224))
    y = net(x)
    print(y.data.shape)
    print_cfg(net.cfg)
    from compute_flops import print_model_param_nums, print_model_param_flops
    total_flops = print_model_param_flops(net.cpu(), 224, multiply_adds=False)
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) # print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size)) if idx0.size == 1: idx0 = np.resize(idx0, (1,)) if idx1.size == 1: idx1 = np.resize(idx1, (1,)) w1 = m0.weight.data[:, idx0.tolist(), :, :].clone() w1 = w1[idx1.tolist(), :, :, :].clone() m1.weight.data = w1.clone() elif isinstance(m0, nn.Linear): idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy()))) if idx0.size == 1: idx0 = np.resize(idx0, (1,)) m1.weight.data = m0.weight.data[:, idx0].clone() m1.bias.data = m0.bias.data.clone() flop_ramained = compute_flops.print_model_param_flops(model=newmodel.cpu(), input_res=32, multiply_adds=False) torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save,str(int(prune_ratio*100))+ 'pruned.pth.tar')) # print(newmodel) # model = newmodel # test(model) def test(model): kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} if args.dataset == 'cifar10': test_loader = torch.utils.data.DataLoader( datasets.CIFAR10('./data/dataset/cifar10', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), batch_size=args.test_batch_size, shuffle=True, **kwargs) elif args.dataset == 'cifar100':
# NOTE(review): excerpt — the first lines are the tail of a torch.save(...)
# call whose opening falls outside this view.
    'cfg': cfg,
    'state_dict': newmodel.state_dict()
}, os.path.join(args.save, 'pruned.pth.tar'))
# print(newmodel)
model = newmodel
print("after pruning")
acc = test(model)

# Calculate Flops and Params
origin_num_parameters = sum(
    [param.nelement() for param in origin_model.parameters()])
num_parameters = sum([param.nelement() for param in newmodel.parameters()])
param_reduction_percent = (
    (origin_num_parameters - num_parameters) / origin_num_parameters) * 100
# FLOPs measured on CPU at 32x32 input, reported in GFLOPs.
origin_flops = print_model_param_flops(origin_model.cpu(), input_res=32) / 1e9
new_flops = print_model_param_flops(newmodel.cpu(), input_res=32) / 1e9
flops_reduction_percent = ((origin_flops - new_flops) / origin_flops) * 100

# Persist a human-readable summary of the pruning result.
with open(os.path.join(args.save, "prune.txt"), "w") as fp:
    fp.write("Number of parameters Before: \n" + str(origin_num_parameters) +
             "\n" + "\n")
    fp.write("Number of parameters: \n" + str(num_parameters) + "\n" + "\n")
    fp.write("% of reduced parameters: \n" + str(param_reduction_percent) +
             "\n" + "\n" + "\n")
    fp.write("Number of Flops Before: \n" + str(origin_flops) + "G" + "\n" +
             "\n")
    fp.write("Number of Flops: \n" + str(new_flops) + "G" + "\n" + "\n")
    fp.write("% of reduced Flops: \n" + str(flops_reduction_percent) + "\n" +
             "\n")
# NOTE(review): excerpt — ends inside the args.resume branch.
# Rebuild the pruned model from the checkpoint's cfg; model_ref holds the
# trained weights so its channel-selection state can be copied over.
model = mwr.Model(num_classes, input_size=image_size, cfg=checkpoint['cfg'])
model_ref = mwr.Model(num_classes, input_size=image_size, cfg=checkpoint['cfg'])
# model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth, cfg=checkpoint['cfg'])
# model_ref = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth, cfg=checkpoint['cfg'])
model_ref.load_state_dict(checkpoint['state_dict'])

# Copy channel-selection indexes; modules are matched positionally.
for m0, m1 in zip(model.modules(), model_ref.modules()):
    if isinstance(m0, models.channel_selection):
        m0.indexes.data = m1.indexes.data.clone()

# model_base = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
model_base = mwr.Model(num_classes, input_size=image_size)
base_flops = print_model_param_flops(model_base, 32)
pruned_flops = print_model_param_flops(model, 32)
# Stretch the epoch budget so the pruned model receives roughly the same
# total training compute as 160 epochs of the base model.
args.epochs = int(160 * (base_flops / pruned_flops))

if args.cuda:
    model.cuda()

# optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay,
                      nesterov=True)

if args.resume:
    if os.path.isfile(args.resume):
# NOTE(review): excerpt — the loop enclosing the first `if` begins before
# this view, and the cfg bookkeeping continues after it; indentation is
# reconstructed.
    if isinstance(m, SpMbBlock) or (k == 2 and isinstance(m, SpConvBlock)):
        m.reset_yv_()

log.info('acc before splitting')
test(model)

for epoch in range(1, 1 + args.epochs):
    if epoch % 2 == 0:
        # Decay the splitting-direction (v) learning rate every other epoch.
        for param_group in optimizer_v.param_groups:
            param_group['lr'] *= 0.2
    min_eig_vals, min_eig_vecs = train(epoch)
    #break

########################################
#####        select neurons       ######
########################################
# FLOPs are computed on CPU, then the model is moved back to the device.
print_model_param_flops(model.cpu(), 32)
model.to(device)

# Total number of candidate neurons across all blocks.
total = 0
for m in min_eig_vals:
    total += len(m)

cfg_grow = []
cfg_mask = []
block_weigths_norm = []
if args.energy or args.params:
    ## flops ##
    cfg = model.cfg
    params_inc_per_neuron, flops_inc_per_neuron = [], []
def main():
    """Entry point: build and train a MobileNetV1 variant selected by flags
    (sr / amc / sp / grow) on ImageNet, saving checkpoints under args.save."""
    global best_prec1, log
    batch_size = args.batch_size * max(1, args.num_gpus)
    # NOTE(review): integer `//` — any batch_size < 256 zeroes the LR;
    # a sibling script in this family uses `(batch_size / 256.)`. Confirm.
    args.lr = args.lr * (batch_size // 256)
    print(batch_size, args.lr, args.num_gpus)
    num_classes = 1000
    num_training_samples = 1281167  # ImageNet train-set size
    args.num_batches_per_epoch = num_training_samples // batch_size

    assert args.exp_name
    args.save = os.path.join(args.save, args.exp_name)
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # Encode the run configuration into the checkpoint file name.
    hyper_str = "run_{}_lr_{}_decay_{}_b_{}_gpu_{}".format(args.epochs, args.lr, \
        args.lr_mode, batch_size, args.num_gpus)
    ## bn-based pruning base model ##
    if args.sr:
        hyper_str = "{}_sr_grow_{}_s_{}".format(hyper_str, args.m, args.s)
    ## using amc configuration ##
    elif args.amc:
        hyper_str = "{}_amc".format(hyper_str)
    elif args.sp:
        hyper_str = "{}_sp_base_{}".format(hyper_str, args.sp_cfg)
    else:
        hyper_str = "{}_grow_{}".format(hyper_str, args.m)
    args.model_save_path = \
        os.path.join(args.save, 'mbv1_{}.pth.tar'.format(hyper_str))
    #args.logging_file_path = \
    #    os.path.join(args.save, 'mbv1_{}.log'.format(hyper_str))
    #print(args.model_save_path, args.logging_file_path)
    ##########################################################
    ## create file handler which logs even debug messages
    #import logging
    #log = logging.getLogger()
    #log.setLevel(logging.INFO)
    #ch = logging.StreamHandler()
    #fh = logging.FileHandler(args.logging_file_path)
    #formatter = logging.Formatter('%(asctime)s - %(message)s')
    #ch.setFormatter(formatter)
    #fh.setFormatter(formatter)
    #log.addHandler(fh)
    #log.addHandler(ch)
    #########################################################

    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # Use CUDA
    use_cuda = torch.cuda.is_available()
    args.use_cuda = use_cuda

    # Random seed (fixed for reproducibility)
    random.seed(0)
    torch.manual_seed(0)
    if use_cuda:
        torch.cuda.manual_seed_all(0)
        device = 'cuda'
        cudnn.benchmark = True
    else:
        device = 'cpu'
    if args.evaluate == 1:
        device = 'cuda:0'

    # Build the model: fixed 'sp' base config, or AMC / grown configuration.
    if args.sp:
        model = mbnet(default=args.sp_cfg)
    else:
        #model = mobilenetv1(amc=args.amc, m=args.m)
        model = mbnet(amc=args.amc, m=args.m)
    print(model.cfg)
    cfg = model.cfg
    # Size/FLOPs measured on CPU at 224x224, counting multiply-adds once.
    total_params = print_model_param_nums(model.cpu())
    total_flops = print_model_param_flops(model.cpu(), 224, multiply_adds=False)
    print(total_params, total_flops)

    if not args.distributed:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    if args.label_smoothing:
        criterion = CrossEntropyLabelSmooth(num_classes).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    ### all parameter ####
    # BatchNorm and bias parameters are excluded from weight decay.
    no_wd_params, wd_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if ".bn" in name or '.bias' in name:
                no_wd_params.append(param)
            else:
                wd_params.append(param)
    no_wd_params = nn.ParameterList(no_wd_params)
    wd_params = nn.ParameterList(wd_params)
    optimizer = torch.optim.SGD([
        {
            'params': no_wd_params,
            'weight_decay': 0.
        },
        {
            'params': wd_params,
            'weight_decay': args.weight_decay
        },
    ], args.lr, momentum=args.momentum, nesterov=True)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_save_path):
            print("=> loading checkpoint '{}'".format(args.model_save_path))
            checkpoint = torch.load(args.model_save_path)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_save_path, checkpoint['epoch']))
        else:
            # Resume requested but no checkpoint present: silently start fresh.
            pass
            #print("=> no checkpoint found at '{}'".format(args.model_save_path))

    # Data loading code
    train_loader, val_loader = \
        get_data_loader(args.data,
                        train_batch_size=batch_size,
                        test_batch_size=32,
                        workers=args.workers)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # NOTE(review): train_sampler is not defined in this excerpt —
            # confirm it exists before running distributed.
            train_sampler.set_epoch(epoch)
        #adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'cfg': cfg,
                'sr': args.sr,
                'amc': args.amc,
                's': args.s,
                'args': args,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, args.model_save_path)
# NOTE(review): excerpt — the first tokens close a DataLoader/transform call
# whose opening is outside this view.
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)

if args.scratch:
    # Rebuild the model with the pruned cfg and stretch the epoch budget so
    # the small model sees roughly the same total training FLOPs as the
    # standard one over 160 epochs.
    checkpoint = torch.load(args.scratch)
    model = models.__dict__[args.arch](dataset=args.dataset,
                                       depth=args.depth,
                                       cfg=checkpoint['cfg'])
    model_ref = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
    flops_std = print_model_param_flops(model_ref, 32)
    flops_small = print_model_param_flops(model, 32)
    args.epochs = int(160 * (flops_std / flops_small))

if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)

if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
# NOTE(review): excerpt — likely inside an `if args.scratch:`-style branch
# whose header precedes this view; ends inside the args.resume branch.
checkpoint = torch.load(args.scratch)
if args.dataset == 'imagenet':
    model = models.__dict__[args.arch](pretrained=False, cfg=checkpoint['cfg'])
    model_ref = models.__dict__[args.arch](pretrained=False, cfg=checkpoint['cfg'])
    model_ref.load_state_dict(checkpoint['state_dict'])
else:
    model = models.__dict__[args.arch](dataset=args.dataset,
                                       depth=args.depth,
                                       cfg=checkpoint['cfg'])
    model_ref = models.__dict__[args.arch](dataset=args.dataset,
                                           depth=args.depth,
                                           cfg=checkpoint['cfg'])
    model_ref.load_state_dict(checkpoint['state_dict'])

# Carry the learned channel-selection indexes over to the fresh model;
# modules are matched positionally.
for m0, m1 in zip(model.modules(), model_ref.modules()):
    if isinstance(m0, models.channel_selection):
        m0.indexes.data = m1.indexes.data.clone()

if args.dataset == 'imagenet':
    # NOTE(review): model_base aliases model here, so base_flops equals
    # pruned_flops — the CIFAR variant below (commented out) used a separate
    # base model. Confirm this is intentional.
    model_base = model
    base_flops = print_model_param_flops(model_base, 224)
    pruned_flops = print_model_param_flops(model, 224)
else:
    pass
    # model_base = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
    # base_flops = print_model_param_flops(model_base, 32)
    # pruned_flops = print_model_param_flops(model, 32)
    # args.epochs = int(160 * (base_flops / pruned_flops))

if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)

if args.resume:
    if os.path.isfile(args.resume):