def profiling(model, use_cuda):
    """profiling on either gpu or cpu"""
    print('Start model profiling, use_cuda: {}.'.format(use_cuda))
    if getattr(FLAGS, 'autoslim', False):
        flops, params = model_profiling(
            model, FLAGS.image_size, FLAGS.image_size, use_cuda=use_cuda,
            verbose=getattr(FLAGS, 'profiling_verbose', False))
    elif getattr(FLAGS, 'slimmable_training', False):
        for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
            model.apply(lambda m: setattr(m, 'width_mult', width_mult))
            print('Model profiling with width mult {}x:'.format(width_mult))
            flops, params = model_profiling(
                model, FLAGS.image_size, FLAGS.image_size, use_cuda=use_cuda,
                verbose=getattr(FLAGS, 'profiling_verbose', False))
    else:
        flops, params = model_profiling(
            model, FLAGS.image_size, FLAGS.image_size, use_cuda=use_cuda,
            verbose=getattr(FLAGS, 'profiling_verbose', True))
    return flops, params
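
# The model_profiling() calls above report FLOPs and parameter counts per
# width. As a rough illustration of what such a count typically includes for
# a single Conv2d layer, a minimal sketch is given below; the helper name is
# hypothetical and this is NOT the repo's model_profiling implementation.
def _conv2d_profile_sketch(conv, out_h, out_w):
    """Return (multiply-adds, parameter count) for one torch.nn.Conv2d call,
    assuming an output feature map of size out_h x out_w."""
    # weight already has shape (out_c, in_c // groups, k_h, k_w),
    # so grouped convolutions are accounted for by numel()
    params = conv.weight.numel()
    if conv.bias is not None:
        params += conv.bias.numel()
    # one multiply-add per weight element per output spatial location,
    # e.g. a 3x3 conv from 64 to 128 channels on a 32x32 output map costs
    # 64 * 128 * 9 * 32 * 32 ~= 75.5M multiply-adds
    macs = conv.weight.numel() * out_h * out_w
    return macs, params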
def slimming(loader, model, criterion):
    """network slimming by slimmable network"""
    model.eval()
    bn_calibration_init(model)
    # start from the full-width network and greedily shrink it
    model.apply(lambda m: setattr(m, 'width_mult', 1.0))
    if getattr(FLAGS, 'distributed', False):
        layers = get_conv_layers(model.module)
    else:
        raise NotImplementedError
    print('Totally {} layers to slim.'.format(len(layers)))
    error = np.zeros(len(layers))
    # get data (a single fixed batch is reused to score every candidate)
    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(0)
    input, target = next(iter(loader))
    input = input.cuda()
    target = target.cuda()
    # start to slim
    print('Start to slim...')
    flops = 10e10
    FLAGS.autoslim_target_flops = sorted(FLAGS.autoslim_target_flops)
    autoslim_target_flop = FLAGS.autoslim_target_flops.pop()
    while True:
        flops, params = model_profiling(
            model, FLAGS.image_size, FLAGS.image_size,
            verbose=getattr(FLAGS, 'profiling_verbose', False))
        if flops < autoslim_target_flop:
            if len(FLAGS.autoslim_target_flops) == 0:
                break
            else:
                # record this configuration and move on to the next budget
                print('Find autoslim net at flops {}'.format(
                    autoslim_target_flop))
                autoslim_target_flop = FLAGS.autoslim_target_flops.pop()
        # try removing one channel group from each layer in turn and
        # measure the resulting error on the held batch
        for i in range(len(layers)):
            torch.cuda.empty_cache()
            error[i] = 0.
            outc = layers[i].out_channels - layers[i].divisor
            if outc <= 0 or outc > layers[i].out_channels_max:
                error[i] += 1.
                continue
            layers[i].out_channels -= layers[i].divisor
            loss, error_batch = forward_loss(
                model, criterion, input, target, None, return_acc=True)
            error[i] += error_batch
            layers[i].out_channels += layers[i].divisor
        # permanently shrink the layer whose removal hurts accuracy least
        best_index = np.argmin(error)
        print(*[f'{element:.4f}' for element in error])
        layers[best_index].out_channels -= layers[best_index].divisor
        print('Adjust layer {} for {} to {}, error: {}.'.format(
            best_index, -layers[best_index].divisor,
            layers[best_index].out_channels, error[best_index]))
    return
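
# slimming() above and the 'cal' phase in run_one_epoch() rely on
# bn_calibration_init to reset BatchNorm statistics before re-estimating them
# on calibration data. A minimal sketch of what such a reset usually does is
# shown here; the helper name is hypothetical and this may differ from the
# repo's actual bn_calibration_init.
def _bn_calibration_reset_sketch(m):
    """Reset a BatchNorm module so post-training statistics can be collected."""
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        # forget the running mean/var accumulated during training
        m.reset_running_stats()
        # keep updating statistics during the calibration forward passes
        m.training = True
        # momentum=None makes PyTorch use a cumulative moving average
        m.momentum = None
# usage would mirror the code above: model.apply(_bn_calibration_reset_sketch)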
def get_model():
    """get model"""
    model_lib = importlib.import_module(FLAGS.model)
    model = model_lib.Model(FLAGS.num_classes, input_size=FLAGS.image_size)
    # sanity-check forward pass on a dummy input
    inputs = Variable(torch.randn((2, 3, 32, 32)))
    output = model(inputs)
    print(output)
    print(output.size())
    # input('pause')
    if getattr(FLAGS, 'distributed', False):
        gpu_id = init_dist()
        if getattr(FLAGS, 'distributed_all_reduce', False):
            # seems faster
            model_wrapper = AllReduceDistributedDataParallel(model.cuda())
        else:
            model_wrapper = torch.nn.parallel.DistributedDataParallel(
                model.cuda(), [gpu_id], gpu_id)
    else:
        model_wrapper = torch.nn.DataParallel(model).cuda()
    return model, model_wrapper
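
# get_model() assumes FLAGS.model is an importable dotted module path whose
# module exposes a `Model(num_classes, input_size=...)` class returning
# logits. A hypothetical stand-in satisfying that contract, purely for
# illustration (it is not one of the repo's models):
class _ModelContractSketch(torch.nn.Module):
    def __init__(self, num_classes, input_size=32):
        super(_ModelContractSketch, self).__init__()
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.classifier = torch.nn.Linear(3, num_classes)

    def forward(self, x):
        # (N, 3, H, W) -> (N, 3) -> (N, num_classes)
        return self.classifier(self.pool(x).flatten(1))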
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # seed
    set_random_seed()
    # for universally slimmable networks only
    if getattr(FLAGS, 'universally_slimmable_training', False):
        if getattr(FLAGS, 'test_only', False):
            if getattr(FLAGS, 'width_mult_list_test', None) is not None:
                FLAGS.test_only = False
                # skip training and goto BN calibration
                FLAGS.skip_training = True
        else:
            FLAGS.width_mult_list = FLAGS.width_mult_range
    # model
    model, model_wrapper = get_model()
    if getattr(FLAGS, 'label_smoothing', 0):
        criterion = CrossEntropyLossSmooth(reduction='none')
    else:
        criterion = torch.nn.CrossEntropyLoss(reduction='none')
    if getattr(FLAGS, 'inplace_distill', False):
        soft_criterion = CrossEntropyLossSoft(reduction='none')
    else:
        soft_criterion = None
    # check pretrained
    if getattr(FLAGS, 'pretrained', False):
        checkpoint = torch.load(
            FLAGS.pretrained, map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                print('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        print('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = get_optimizer(model_wrapper)
    # check resume training
    if os.path.exists(os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt')):
        checkpoint = torch.load(
            os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'),
            map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = get_lr_scheduler(optimizer)
        lr_scheduler.last_epoch = last_epoch
        best_val = checkpoint['best_val']
        train_meters, val_meters = checkpoint['meters']
        print('Loaded checkpoint {} at epoch {}.'.format(
            FLAGS.log_dir, last_epoch))
    else:
        lr_scheduler = get_lr_scheduler(optimizer)
        last_epoch = lr_scheduler.last_epoch
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        # if start from scratch, print model and do profiling
        print(model_wrapper)
        if getattr(FLAGS, 'profiling', False):
            if 'gpu' in FLAGS.profiling:
                profiling(model, use_cuda=True)
            if 'cpu' in FLAGS.profiling:
                profiling(model, use_cuda=False)
        if getattr(FLAGS, 'profiling_only', False):
            return

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(
        train_transforms, val_transforms, test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    # autoslim only
    if getattr(FLAGS, 'autoslim', False):
        with torch.no_grad():
            slimming(train_loader, model_wrapper, criterion)
        return

    if getattr(FLAGS, 'test_only', False) and (test_loader is not None):
        print('Start testing.')
        test_meters = get_meters('test')
        with torch.no_grad():
            if getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    model_wrapper.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))
                    run_one_epoch(
                        last_epoch, test_loader, model_wrapper, criterion,
                        optimizer, test_meters, phase='test')
            else:
                run_one_epoch(
                    last_epoch, test_loader, model_wrapper, criterion,
                    optimizer, test_meters, phase='test')
        return

    if getattr(FLAGS, 'nonuniform_diff_seed', False):
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    print('Start training.')
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        if getattr(FLAGS, 'skip_training', False):
            print('Skip training at epoch: {}'.format(epoch))
            break
        lr_scheduler.step()
        # train
        results = run_one_epoch(
            epoch, train_loader, model_wrapper, criterion, optimizer,
            train_meters, phase='train', soft_criterion=soft_criterion)
        # val
        if val_meters is not None:
            val_meters['best_val'].cache(best_val)
        with torch.no_grad():
            results = run_one_epoch(
                epoch, val_loader, model_wrapper, criterion, optimizer,
                val_meters, phase='val')
        if is_master() and results['top1_error'] < best_val:
            best_val = results['top1_error']
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                },
                os.path.join(FLAGS.log_dir, 'best_model.pt'))
            print('New best validation top1 error: {:.3f}'.format(best_val))
        # save latest checkpoint
        if is_master():
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'last_epoch': epoch,
                    'best_val': best_val,
                    'meters': (train_meters, val_meters),
                },
                os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'))

    if getattr(FLAGS, 'calibrate_bn', False):
        if getattr(FLAGS, 'universally_slimmable_training', False):
            # need to rebuild model according to width_mult_list_test
            width_mult_list = FLAGS.width_mult_range.copy()
            for width in FLAGS.width_mult_list_test:
                if width not in FLAGS.width_mult_list:
                    width_mult_list.append(width)
            FLAGS.width_mult_list = width_mult_list
            new_model, new_model_wrapper = get_model()
            profiling(new_model, use_cuda=True)
            new_model_wrapper.load_state_dict(
                model_wrapper.state_dict(), strict=False)
            model_wrapper = new_model_wrapper
        cal_meters = get_meters('cal')
        print('Start calibration.')
        results = run_one_epoch(
            -1, train_loader, model_wrapper, criterion, optimizer,
            cal_meters, phase='cal')
        print('Start validation after calibration.')
        with torch.no_grad():
            results = run_one_epoch(
                -1, val_loader, model_wrapper, criterion, optimizer,
                cal_meters, phase='val')
        if is_master():
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                },
                os.path.join(FLAGS.log_dir, 'best_model_bn_calibrated.pt'))
    return
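
# When FLAGS.label_smoothing is set, train_val_test() above uses a smoothed
# cross-entropy (CrossEntropyLossSmooth, defined elsewhere in this repo). The
# sketch below shows one common per-sample formulation of label smoothing;
# the function name is hypothetical and the repo's class may differ in
# details (e.g. where the smoothing value comes from).
def _label_smooth_ce_sketch(logits, target, eps=0.1):
    """Per-sample label-smoothed cross entropy (reduction='none' style)."""
    import torch.nn.functional as F
    num_classes = logits.size(1)
    log_prob = F.log_softmax(logits, dim=1)
    # smoothed one-hot target: eps spread uniformly, the rest on the true class
    smooth = torch.full_like(log_prob, eps / num_classes)
    smooth.scatter_(1, target.unsqueeze(1), 1.0 - eps + eps / num_classes)
    return -(smooth * log_prob).sum(dim=1)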
def run_one_epoch(
        epoch, loader, model, criterion, optimizer, meters, phase='train',
        soft_criterion=None):
    """run one epoch for train/val/test/cal"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], 'Invalid phase.'
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        if phase == 'cal':
            model.apply(bn_calibration_init)
    # change learning rate in each iteration
    if getattr(FLAGS, 'universally_slimmable_training', False):
        max_width = FLAGS.width_mult_range[1]
        min_width = FLAGS.width_mult_range[0]
    elif getattr(FLAGS, 'slimmable_training', False):
        max_width = max(FLAGS.width_mult_list)
        min_width = min(FLAGS.width_mult_list)
    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)
    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            # change learning rate if necessary
            lr_schedule_per_iteration(optimizer, epoch, batch_idx)
            optimizer.zero_grad()
            if getattr(FLAGS, 'slimmable_training', False):
                if getattr(FLAGS, 'universally_slimmable_training', False):
                    # universally slimmable model (us-nets)
                    widths_train = []
                    for _ in range(
                            getattr(FLAGS, 'num_sample_training', 2) - 2):
                        widths_train.append(
                            random.uniform(min_width, max_width))
                    widths_train = [max_width, min_width] + widths_train
                    for width_mult in widths_train:
                        # the sandwich rule
                        if width_mult in [max_width, min_width]:
                            model.apply(
                                lambda m: setattr(m, 'width_mult', width_mult))
                        elif getattr(FLAGS, 'nonuniform', False):
                            model.apply(lambda m: setattr(
                                m, 'width_mult',
                                lambda: random.uniform(min_width, max_width)))
                        else:
                            model.apply(lambda m: setattr(
                                m, 'width_mult', width_mult))
                        # always track largest model and smallest model
                        if is_master() and width_mult in [
                                max_width, min_width]:
                            meter = meters[str(width_mult)]
                        else:
                            meter = None
                        # inplace distillation
                        if width_mult == max_width:
                            loss, soft_target = forward_loss(
                                model, criterion, input, target, meter,
                                return_soft_target=True)
                        else:
                            if getattr(FLAGS, 'inplace_distill', False):
                                loss = forward_loss(
                                    model, criterion, input, target, meter,
                                    soft_target=soft_target.detach(),
                                    soft_criterion=soft_criterion)
                            else:
                                loss = forward_loss(
                                    model, criterion, input, target, meter)
                        loss.backward()
                else:
                    # slimmable model (s-nets)
                    for width_mult in sorted(
                            FLAGS.width_mult_list, reverse=True):
                        model.apply(
                            lambda m: setattr(m, 'width_mult', width_mult))
                        if is_master():
                            meter = meters[str(width_mult)]
                        else:
                            meter = None
                        if width_mult == max_width:
                            loss, soft_target = forward_loss(
                                model, criterion, input, target, meter,
                                return_soft_target=True)
                        else:
                            if getattr(FLAGS, 'inplace_distill', False):
                                loss = forward_loss(
                                    model, criterion, input, target, meter,
                                    soft_target=soft_target.detach(),
                                    soft_criterion=soft_criterion)
                            else:
                                loss = forward_loss(
                                    model, criterion, input, target, meter)
                        loss.backward()
            else:
                loss = forward_loss(
                    model, criterion, input, target, meters)
                loss.backward()
            if (getattr(FLAGS, 'distributed', False) and
                    getattr(FLAGS, 'distributed_all_reduce', False)):
                allreduce_grads(model)
            optimizer.step()
            if is_master() and getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    meter = meters[str(width_mult)]
                    meter['lr'].cache(optimizer.param_groups[0]['lr'])
            elif is_master():
                meters['lr'].cache(optimizer.param_groups[0]['lr'])
            else:
                pass
        else:
            if getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    model.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))
                    if is_master():
                        meter = meters[str(width_mult)]
                    else:
                        meter = None
                    forward_loss(model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)
    if is_master() and getattr(FLAGS, 'slimmable_training', False):
        for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
            results = flush_scalar_meters(meters[str(width_mult)])
            print('{:.1f}s\t{}\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, str(width_mult), epoch,
                FLAGS.num_epochs) + ', '.join(
                    '{}: {:.3f}'.format(k, v) for k, v in results.items()))
    elif is_master():
        results = flush_scalar_meters(meters)
        print('{:.1f}s\t{}\t{}/{}: '.format(
            time.time() - t_start, phase, epoch, FLAGS.num_epochs) +
            ', '.join('{}: {:.3f}'.format(k, v)
                      for k, v in results.items()))
    else:
        results = None
    return results
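
# In run_one_epoch() above, inplace distillation trains every width except
# the largest against the largest width's detached predictions (soft_target),
# using a soft cross-entropy (CrossEntropyLossSoft, defined elsewhere in this
# repo). The helper below is a hypothetical sketch of one common way such a
# criterion is written, not necessarily the repo's exact implementation.
def _soft_cross_entropy_sketch(student_logits, teacher_probs):
    """Per-sample cross entropy of student predictions against soft targets."""
    import torch.nn.functional as F
    log_prob = F.log_softmax(student_logits, dim=1)
    # teacher_probs is expected to already be a probability distribution,
    # e.g. the softmax of the largest width's logits, detached from the graph
    return -(teacher_probs * log_prob).sum(dim=1)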